"""Small helpers around the Gemini API: raw prompting plus keyword extraction.

Run as a script: ``python this_file.py FILE`` prints five keywords extracted
from the UTF-8 text in FILE (requires the GEMINI_API_KEY environment variable).
"""

import os
import re
import sys
import pprint

import google.generativeai as genai


class aimodel:
    """Thin wrapper around a Gemini ``GenerativeModel``.

    The model is built with ``max_output_tokens=500``, so every reply is
    capped at 500 output tokens.
    """

    def __init__(self, api_key: str, model: str = "gemini-pro"):
        """Configure the Gemini client and construct the model.

        api_key: Gemini API key, passed to ``genai.configure``.
        model: name of the Gemini model to instantiate.
        """
        self.api_key = api_key
        self.model_name = model
        # genai.configure is called on the module, not stored on this
        # instance -- presumably process-wide SDK configuration; confirm
        # before creating several aimodel objects with different keys.
        genai.configure(api_key=api_key)
        max_500_config = genai.GenerationConfig(max_output_tokens=500)
        self.model = genai.GenerativeModel(
            self.model_name, generation_config=max_500_config
        )

    def get_response(self, text: str, prompt: str = "") -> str:
        """Send ``prompt + text`` to the model and return the reply text.

        text: the content to process.
        prompt: instruction prepended to ``text``. No separator is inserted,
            so the prompt should end with its own newline or space.
        Returns the ``text`` attribute of the model response.
        """
        response = self.model.generate_content(prompt + text)
        return response.text


SUMMARY_PROMPT = "Summarize the following content:\n"
SUMMARY_PROMPT_ZH = "总结以下内容:\n"
TRANSLATE_PROMPT = "Translate the following content to Chinese:\n"
# Keyword-extraction prompts keyed by language code ("EN" / "ZH").
CONTENT_TO_KEYWORD_PROMPT = {
    "EN": "Analyze the following content and generate 5 keywords:\n",
    "ZH": "分析以下内容并总结出五个关键词:\n",
}

# List markers a model typically puts in front of each keyword line:
# "1." / "1)" / "1:" / "1、" numbering, or "-", "*", "•" bullets. Digits must be
# followed by punctuation so keywords such as "5G" are left intact.
_LIST_MARKER_RE = re.compile(r"^\s*(?:\d+\s*[.)、:：]|[-*•])\s*")


def _parse_keyword_lines(raw: str) -> list:
    """Turn a newline-separated model reply into a list of keyword strings.

    Strips list markers and surrounding markdown bold (``**kw**``), drops
    blank lines, and skips header lines ending in a colon
    (e.g. "Here are 5 keywords:"). Multi-word keywords are kept whole.
    """
    keywords = []
    for line in raw.splitlines():
        keyword = _LIST_MARKER_RE.sub("", line).strip().strip("*").strip()
        if not keyword or keyword.endswith((":", "：")):
            continue
        keywords.append(keyword)
    return keywords


class aimodel_wrapper:
    """Higher-level text tasks built on top of an ``aimodel`` instance."""

    def __init__(self, aimodel):
        """aimodel: any object exposing ``get_response(text, prompt)``."""
        self.aimodel = aimodel

    def content_to_5_keywords(self, content: str, lang: str = "EN") -> list:
        """Ask the model for five keywords describing ``content``.

        content: text to analyse.
        lang: key into CONTENT_TO_KEYWORD_PROMPT ("EN" or "ZH").
        Returns the keywords parsed from the reply, one per non-empty line.
        The model is not guaranteed to return exactly five.
        Raises KeyError if ``lang`` has no prompt.
        """
        raw = self.aimodel.get_response(content, CONTENT_TO_KEYWORD_PROMPT[lang])
        return _parse_keyword_lines(raw)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} FILE")
    api_key = os.getenv("GEMINI_API_KEY")
    # Explicit check rather than `assert`, which is removed under `python -O`;
    # an empty key is rejected as well.
    if not api_key:
        sys.exit("GEMINI_API_KEY environment variable is not set")
    model = aimodel(api_key)
    wrapper = aimodel_wrapper(model)
    # `with` guarantees the file handle is closed.
    with open(sys.argv[1], encoding="utf-8") as f:
        content = f.read()
    pprint.pprint(wrapper.content_to_5_keywords(content, "EN"))