Spaces:
Runtime error
Runtime error
| import copy | |
| import asyncio | |
| import google.generativeai as genai | |
| from pingpong import PingPong | |
| from pingpong.pingpong import PPManager | |
| from pingpong.pingpong import PromptFmt | |
| from pingpong.pingpong import UIFmt | |
| from pingpong.gradio import GradioChatUIFmt | |
class GeminiChatPromptFmt(PromptFmt):
    """Prompt formatter that renders conversation data as Gemini chat messages.

    Every message produced here is a dict with a ``"role"`` key and a
    ``"parts"`` list holding a single text string.
    """

    # Both methods are invoked on the class itself (``build_prompts`` passes
    # this class as ``fmt`` and calls ``fmt.ctx(...)`` / ``fmt.prompt(...)``),
    # so they have to be classmethods; as plain functions those calls would
    # bind the first argument to ``cls`` and fail with a missing-argument
    # TypeError.
    @classmethod
    def ctx(cls, context):
        """Wrap the system context in a message dict.

        Args:
            context: The system context text, or ``None``.

        Returns:
            ``None`` when ``context`` is ``None`` or an empty string;
            otherwise a dict with role ``"system"`` whose ``"parts"`` list
            contains ``context``.
        """
        if context is None or context == "":
            return None
        return {
            "role": "system",
            "parts": [context],
        }

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        """Convert one exchange into a list of chat messages.

        Args:
            pingpong: Object exposing ``ping`` (the user text) and ``pong``
                (the model reply, possibly ``None``).
            truncate_size: Maximum number of characters kept from ``ping``
                and from ``pong``; ``None`` keeps the full text.

        Returns:
            A list starting with the ``"user"`` message, followed by a
            ``"model"`` message only when the (truncated) reply is non-empty.
        """
        ping = pingpong.ping[:truncate_size]
        pong = "" if pingpong.pong is None else pingpong.pong[:truncate_size]

        result = [{"role": "user", "parts": [ping]}]
        if pong != "":
            result.append({"role": "model", "parts": [pong]})
        return result
class GeminiChatPPManager(PPManager):
    """Conversation manager that builds a Gemini chat message list."""

    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=GeminiChatPromptFmt, truncate_size: int=None):
        """Build the flat list of chat messages for a range of exchanges.

        The context produced by ``fmt.ctx(self.ctx)`` is not emitted as its
        own message; its text is prepended to the first selected exchange's
        ``ping`` behind a ``SYSTEM:`` header.

        Args:
            from_idx: Index of the first exchange to include.
            to_idx: Index one past the last exchange to include. ``-1`` or
                any value at or beyond the number of stored exchanges means
                "up to the end".
            fmt: Formatter providing ``ctx`` and ``prompt``.
            truncate_size: Forwarded to ``fmt.prompt`` as ``truncate_size``.

        Returns:
            The concatenation of ``fmt.prompt(...)`` results for the selected
            exchanges, in order. Empty when no exchange is selected.
        """
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        # Deep-copy only the exchanges that are actually used: the first one
        # is modified below and the stored history must stay untouched.
        selected = copy.deepcopy(self.pingpongs[from_idx:to_idx])

        ctx = fmt.ctx(self.ctx)
        ctx = ctx["parts"][0] if ctx is not None else ""

        if selected:
            first = selected[0]
            first.ping = f"SYSTEM: {ctx} ----------- \n" + first.ping

        results = []
        for pingpong in selected:
            results += fmt.prompt(pingpong, truncate_size=truncate_size)
        return results
class GradioGeminiChatPPManager(GeminiChatPPManager):
    """Gemini conversation manager that can also render exchanges for a UI."""

    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        """Return ``fmt.ui(...)`` for each exchange in the requested range.

        Args:
            from_idx: Index of the first exchange to render.
            to_idx: Index one past the last exchange to render. ``-1`` or
                any value at or beyond the number of stored exchanges means
                "up to the end".
            fmt: Formatter whose ``ui`` is applied to every selected exchange.

        Returns:
            A list with one ``fmt.ui`` result per selected exchange, in order.
        """
        total = len(self.pingpongs)
        stop = total if (to_idx == -1 or to_idx >= total) else to_idx
        return [fmt.ui(exchange) for exchange in self.pingpongs[from_idx:stop]]
def init(api_key):
    """Configure the ``google.generativeai`` module with ``api_key``.

    Args:
        api_key: Passed unchanged to ``genai.configure`` as ``api_key``.
    """
    genai.configure(api_key=api_key)
def _default_gen_text():
    """Return a freshly built dict of the default text-generation settings."""
    return dict(
        temperature=0.9,
        top_p=0.8,
        top_k=32,
        max_output_tokens=2048,
    )
def _default_safety_settings():
    """Return the default safety settings as a new list of dicts.

    There is one entry per harm category listed below, each pairing the
    category name with the ``"BLOCK_NONE"`` threshold.
    """
    categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    return [
        {"category": name, "threshold": "BLOCK_NONE"}
        for name in categories
    ]
async def _word_generator(sentence):
    """Asynchronously yield the pieces of ``sentence`` split on single spaces.

    After each piece is yielded the generator sleeps for 0.03 seconds plus
    0.005 seconds per character of that piece, which paces the output.
    Splitting uses ``str.split(" ")``, so consecutive spaces produce empty
    pieces and an empty ``sentence`` yields one empty string.
    """
    base_pause = 0.03
    per_char_pause = 0.005
    for piece in sentence.split(" "):
        yield piece
        # Pause after handing the piece out, scaled by its length.
        await asyncio.sleep(base_pause + (len(piece) * per_char_pause))
async def gen_text(
    prompts,
    gen_config=_default_gen_text(),
    safety_settings=_default_safety_settings(),
    stream=True,
    model_name="gemini-1.0-pro",
):
    """Send a conversation to a Gemini model and yield the reply piece by piece.

    The last entry of ``prompts`` is sent as the new user message; all
    earlier entries are supplied as chat history.

    Args:
        prompts: Sequence of message dicts (``"role"`` / ``"parts"``). Must
            contain at least one entry; the first element of the last
            entry's ``"parts"`` is the text that gets sent.
        gen_config: Passed to ``genai.GenerativeModel`` as
            ``generation_config``. NOTE: the default is built once, when the
            function is defined, and shared by every call that omits it.
        safety_settings: Passed to ``genai.GenerativeModel`` as
            ``safety_settings``. The default is likewise built once and
            shared.
        stream: Forwarded to ``send_message_async``.
        model_name: Name of the model to instantiate. Defaults to the
            previously hard-coded ``"gemini-1.0-pro"``.

    Yields:
        Space-delimited pieces of each response chunk's text. Concatenating
        everything yielded for a chunk reproduces that chunk's text exactly.

    Raises:
        IndexError: If ``prompts`` is empty.
    """
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=gen_config,
        safety_settings=safety_settings,
    )

    user_prompt = prompts[-1]
    history = prompts[:-1]

    convo = model.start_chat(history=history)
    resps = await convo.send_message_async(
        user_prompt["parts"][0], stream=stream
    )

    async for resp in resps:
        text = resp.text
        # ``split(" ")`` produces ``count(" ") + 1`` pieces, so this is the
        # index of the final piece of the chunk.
        last_idx = text.count(" ")
        idx = 0
        async for word in _word_generator(text):
            # Re-insert the separator only between pieces. Adding one after
            # the final piece as well would put a spurious space at every
            # chunk boundary of the reassembled reply.
            yield word if idx == last_idx else word + " "
            idx += 1