9d51bd91a8
data sources
39 lines
1002 B
Python
39 lines
1002 B
Python
import os
|
|
import sys
|
|
import pprint
|
|
import google.generativeai as genai
|
|
|
|
api_key = os.getenv("GEMINI_API_KEY")
# Fail fast with an explicit exception instead of `assert`: assertions are
# stripped when Python runs with -O, which would let a missing key reach
# genai.configure(). An empty value is treated as missing as well.
if not api_key:
    raise RuntimeError("GEMINI_API_KEY environment variable is not set")

genai.configure(api_key=api_key)

# Cap each generated response at 500 output tokens.
max_500_config = genai.GenerationConfig(max_output_tokens=500)

# Model instance shared by the helpers in this module (e.g. get_response).
model = genai.GenerativeModel("gemini-pro", generation_config=max_500_config)

# Instruction prefixes; they are prepended verbatim to the content, so each
# ends with a newline to separate the instruction from the text.
SUMMARY_PROMPT = "Summarize following content:\n"
SUMMARY_PROMPT_ZH = "总结以下内容:\n"  # Chinese for "Summarize the following content:"
TRANSLATE_PROMPT = "Translate following context to Chinese:\n"
|
|
|
|
|
|
def get_response(text: str, prompt: str = SUMMARY_PROMPT) -> str:
    """Send ``prompt`` followed by ``text`` to the Gemini model and return its reply.

    Args:
        text: The content for the model to work on.
        prompt: Instruction placed directly in front of ``text``. No separator
            is inserted, so the prompt should end with its own delimiter
            (the module's prompt constants end with a newline).

    Returns:
        The ``.text`` of the response from the module-level ``model``.
    """
    full_prompt = prompt + text
    reply = model.generate_content(full_prompt)
    return reply.text
|
|
|
|
|
|
def get_response_list(articles: list, prompt: str = SUMMARY_PROMPT) -> str:
    """Placeholder for running ``prompt`` over several articles; not implemented.

    Args:
        articles: Intended collection of inputs (element type not yet defined).
        prompt: Instruction intended to be prepended to each input.

    Raises:
        NotImplementedError: Always, until this function is implemented. This
            avoids silently returning ``None`` in violation of the ``-> str``
            annotation.
    """
    raise NotImplementedError("get_response_list is not implemented yet")
|
|
|
|
|
|
if __name__ == "__main__":
    # Usage: <script> <path>
    # Summarizes the UTF-8 file at <path> in Chinese and pretty-prints the result.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <file>")

    # Read bytes and decode explicitly, so the text is passed through exactly
    # (no newline translation). `with` guarantees the handle is closed.
    with open(sys.argv[1], "rb") as f:
        content = f.read().decode("utf-8")

    pp = pprint.PrettyPrinter()
    pp.pprint(get_response(content, SUMMARY_PROMPT_ZH))
|