import os

import gradio as gr
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.chains import LLMChain

# Set environment variables (fail fast if the key is missing instead of
# raising a confusing TypeError on the os.environ assignment)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise EnvironmentError("OPENAI_API_KEY is not set.")
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
os.environ["LANGCHAIN_VERBOSE"] = "true"

# Histogram-style reference code (note: it renders a seaborn bar chart of
# per-month counts, which is what "histogram" requests over categorical
# month data map to here)
hist_prompt_template = """```python
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

simulated_data = {
    'Month': ['January', 'February', 'March', 'April', 'May', 'June',
              'July', 'August', 'September', 'October', 'November', 'December'],
    'Executed_Operations': [5, 8, 6, 7, 9, 10, 12, 11, 8, 7, 6, 5]
}
df = pd.DataFrame(simulated_data)

plt.figure(figsize=(10, 6))
sns.barplot(x='Month', y='Executed_Operations', data=df, palette='viridis')
plt.xticks(rotation=45)
plt.title('Executed Science Operations per Month in 2006')
plt.xlabel('Month')
plt.ylabel('Number of Executed Operations')
plt.tight_layout()
plt.show()
```"""

# Line graph-specific reference code
graph_prompt_template = """```python
import matplotlib.pyplot as plt
import pandas as pd

data = {
    'Month': ['January', 'February', 'March', 'April', 'May', 'June',
              'July', 'August', 'September', 'October', 'November', 'December'],
    'Executed_Science_Operations': [5, 8, 7, 6, 9, 10, 8, 7, 6, 5, 9, 10],
    'Calibration_Operations': [2, 3, 2, 4, 3, 2, 3, 4, 3, 2, 3, 4]
}
df = pd.DataFrame(data)

plt.figure(figsize=(10, 6))
plt.plot(df['Month'], df['Executed_Science_Operations'], marker='o', label='Executed Science Operations')
plt.plot(df['Month'], df['Calibration_Operations'], marker='s', label='Calibration Operations')
plt.title('Spitzer Space Telescope Operations - 2006')
plt.xlabel('Month')
plt.ylabel('Number of Operations')
plt.xticks(rotation=45)
plt.legend()
plt.tight_layout()
plt.show()
```"""

# Master prompt with reference-code injection
default_prompt_template = """You are a Python data visualization assistant.

The following is the content of a text file uploaded by a user:
-------------------
{file_content}
-------------------
{user_query}
-------------------
Use this code for reference to generate the code.
{generated_code}
-------------------

Your task:
- Assume this content is tabular or semi-structured data.
- Generate a valid Python script using matplotlib, seaborn, or plotly based on the user query to visualize the data.
- Modify the code as required to match the user query.
- Only output the code block (no extra explanation).
- The code should be executable as-is.
- Include data parsing if required.

Python Code:"""

prompt = PromptTemplate(
    input_variables=["file_content", "user_query", "generated_code"],
    template=default_prompt_template,
)
human_prompt = HumanMessagePromptTemplate(prompt=prompt)
chat_prompt_template = ChatPromptTemplate.from_messages([human_prompt])

chat_model = ChatOpenAI(temperature=0, model_name="gpt-4o")
chain = LLMChain(prompt=chat_prompt_template, llm=chat_model)


def generate_and_plot(file, query):
    """Ask the LLM for plotting code for the uploaded file, execute it, return the figure."""
    try:
        # Read file content (truncated to 1000 characters to keep the prompt small)
        if hasattr(file, "read"):
            file_content = file.read().decode("utf-8")[:1000]
        elif isinstance(file, str) or hasattr(file, "name"):
            file_path = file.name if hasattr(file, "name") else file
            with open(file_path, "r", encoding="utf-8") as f:
                file_content = f.read()[:1000]
        else:
            return "Unsupported file type."

        # Choose reference code based on the user query
        if "histogram" in query.lower():
            generated_code_hint = hist_prompt_template
        elif "line graph" in query.lower() or "graph plot" in query.lower():
            generated_code_hint = graph_prompt_template
        else:
            generated_code_hint = ""

        # Generate code
        generated_code = chain.run(
            file_content=file_content,
            user_query=query,
            generated_code=generated_code_hint,
        )

        # Execute the generated code. Caution: exec() runs arbitrary model
        # output; only do this in a trusted or sandboxed environment.
        global_env = {}
        cleaned_code = generated_code.replace("```python", "").replace("```", "")
        exec(cleaned_code, global_env)

        # Return the plot: prefer an explicit `fig` (plotly), then fall back
        # to the current matplotlib figure, then to a placeholder.
        if "fig" in global_env:
            fig = global_env["fig"]
            try:
                import plotly.graph_objects as go
                if isinstance(fig, go.Figure):
                    return fig
            except ImportError:
                pass
        if "plt" in global_env:
            import matplotlib.pyplot as plt
            return plt.gcf()

        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No figure was generated", ha='center', va='center')
        ax.axis("off")
        return fig
    except Exception as e:
        return f"Error: {e}"
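
# Optional local smoke test: a minimal sketch (not wired into the Gradio app)
# showing how generate_and_plot can be exercised from a script. The file name,
# sample data, and query below are made up for illustration; calling this
# makes a real OpenAI API call.
def _smoke_test_plotting(path="sample_ops.txt"):
    with open(path, "w", encoding="utf-8") as f:
        f.write("Month,Executed_Operations\nJanuary,5\nFebruary,8\nMarch,6\n")
    # generate_and_plot accepts a plain path string as well as a file object
    return generate_and_plot(path, "Draw a histogram of executed operations per month")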
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from tempfile import NamedTemporaryFile

# Initialize LLM
llm = ChatOpenAI(temperature=0, model_name="gpt-4o")

# Keep a global QA chain so follow-up questions reuse the loaded PDF
qa_chain = None


def load_pdf_and_create_qa_chain(pdf_file):
    global qa_chain

    # Save the uploaded file to a temporary path if it arrives as a stream
    if hasattr(pdf_file, "read"):
        with NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
            tmp_file.write(pdf_file.read())
            tmp_file_path = tmp_file.name
    else:
        tmp_file_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file

    # Load document
    loader = PyPDFLoader(tmp_file_path)
    documents = loader.load()

    # Split into chunks
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = text_splitter.split_documents(documents)

    # Embed and store in the vector DB
    embeddings = OpenAIEmbeddings()
    db = Chroma.from_documents(texts, embeddings)

    # Set up retriever
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 2})

    # Create the RAG QA chain
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return "PDF loaded and ready! You can now ask questions."


def ask_question(query):
    global qa_chain
    if qa_chain is None:
        return "Please upload a PDF first."
    try:
        response = qa_chain.invoke({"query": query})
        return response["result"]
    except Exception as e:
        return f"Error answering question: {e}"


def process_file(file, query):
    """Route the upload: PDFs go to the RAG QA chain, text/CSV files to the plotter."""
    if file is None:
        return "Please upload a file.", None

    filename = (file.name if hasattr(file, "name") else "").lower()
    if filename.endswith(".pdf"):
        load_pdf_and_create_qa_chain(file)
        answer = ask_question(query)
        return answer, None
    elif filename.endswith((".txt", ".csv")):
        plot = generate_and_plot(file, query)
        return "Here is your plot:", plot
    else:
        return "Unsupported file type. Upload a .pdf, .txt, or .csv file.", None
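
# Optional: a minimal sketch of exercising the RAG pipeline outside Gradio.
# "paper.pdf" is a hypothetical path; running this assumes the file exists,
# the API key is set, and the embedding/LLM calls are billed against it.
def _smoke_test_qa(pdf_path="paper.pdf"):
    print(load_pdf_and_create_qa_chain(pdf_path))
    print(ask_question("Summarize the main findings of this document."))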
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Astronomy ChatBot with Plotting and Summarizer")
    with gr.Row():
        file_input = gr.File(label="Upload your file (.txt, .csv, .pdf)",
                             file_types=[".txt", ".csv", ".pdf"])
        query_input = gr.Textbox(label="Enter your question or plotting instruction")
    submit_btn = gr.Button("Submit")
    output_text = gr.Textbox(label="Response")
    output_plot = gr.Plot(label="Generated Plot")

    submit_btn.click(fn=process_file, inputs=[file_input, query_input],
                     outputs=[output_text, output_plot])

demo.launch(debug=True, share=True)