import os
import asyncio
import datetime
from typing import AsyncIterable

from dotenv import load_dotenv
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

load_dotenv()

GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
GROQ_API_BASE = os.environ.get("GROQ_API_BASE")
GROQ_MODEL_NAME = os.environ.get("OPENAI_MODEL_NAME")  # note: the model name is read from OPENAI_MODEL_NAME
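
# Expected .env entries (variable names taken from the lookups above; the
# example base URL is an assumption for Groq's OpenAI-compatible endpoint,
# and the model value mirrors the one mentioned later in this file):
#   GROQ_API_KEY=gsk_...
#   GROQ_API_BASE=https://api.groq.com/openai/v1
#   OPENAI_MODEL_NAME=mixtral-8x7b-32768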

def read_pattern_files(pattern: str) -> tuple[str, str]:
    """Read the system.md and user.md prompt files for the given pattern."""
    system_file = "system.md"
    user_file = "user.md"
    system_content = ""
    user_content = ""
    pattern_dir = "patterns"

    # Construct the full paths
    system_file_path = os.path.abspath(os.path.join(pattern_dir, pattern, system_file))
    user_file_path = os.path.abspath(os.path.join(pattern_dir, pattern, user_file))
    print(system_file_path)
    print(user_file_path)

    # Check if system.md exists
    if os.path.exists(system_file_path):
        with open(system_file_path, "r") as file:
            system_content = file.read()

    # Check if user.md exists
    if os.path.exists(user_file_path):
        with open(user_file_path, "r") as file:
            user_content = file.read()

    return system_content, user_content
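
# Example on-disk layout assumed by read_pattern_files (the pattern name
# "summarize" is illustrative only, not taken from the original source):
#   patterns/
#     summarize/
#       system.md   <- system prompt for the pattern
#       user.md     <- user prompt prefix for the pattern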

async def generate_pattern(pattern: str, query: str) -> AsyncIterable[str]:
    callback = AsyncIteratorCallbackHandler()
    chat = ChatOpenAI(
        openai_api_base=GROQ_API_BASE,
        api_key=GROQ_API_KEY,
        temperature=0.0,
        model_name=GROQ_MODEL_NAME,  # e.g. "mixtral-8x7b-32768"
        streaming=True,  # important: stream tokens through the callback handler
        verbose=True,
        callbacks=[callback],
    )

    system, usr_content = read_pattern_files(pattern=pattern)
    print("Sys Content --> ")
    print(system)
    print("User Content --> ")
    print(usr_content)

    # Append the {text} placeholder so the query is injected into the user prompt.
    human = usr_content + "{text}"
    prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])
    chain = prompt | chat | StrOutputParser()

    # Run the chain in the background; tokens arrive through the callback handler.
    task = asyncio.create_task(chain.ainvoke({"text": query}))

    index = 0
    try:
        async for token in callback.aiter():
            print(index, ": ", token, ": ", datetime.datetime.now().time())
            index += 1
            yield token
    except Exception as e:
        print(f"Caught exception: {e}")
    finally:
        callback.done.set()
        await task
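
# Minimal usage sketch (an assumption, not part of the original app): consume
# the async generator from the command line. The pattern name "summarize" and
# the query string are illustrative only.
if __name__ == "__main__":
    async def _demo() -> None:
        # Print streamed tokens as they arrive from the chain.
        async for chunk in generate_pattern(pattern="summarize", query="LangChain streams tokens."):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())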