# Source: Hugging Face Space by Poojashetty357, commit 6d90ced ("Update app.py", verified).
import os
import requests
import base64
import matplotlib.pyplot as plt
from io import BytesIO
import fitz # PyMuPDF
from dotenv import load_dotenv
import gradio as gr
from langchain_openai import ChatOpenAI
from langchain.tools import Tool
from langgraph.graph import StateGraph, END
from typing import TypedDict, Optional
from PIL import Image
# Load environment variables
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Inject into environment explicitly (safety).
# os.environ only accepts str values: assigning None (key not configured)
# would raise an opaque TypeError at import time, so only export a real value
# and let the missing key surface from the OpenAI client instead.
if OPENAI_API_KEY is not None:
    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
STOCK_API_KEY = os.getenv("STOCK_API_KEY")  # Alpha Vantage key (may be None)
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")  # Tavily search key (may be None)
# Initialize LLM
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.3)
# Tool functions
def get_stock_symbol(company_name: str) -> str:
    """Resolve a company name to a stock ticker symbol.

    Well-known names are answered from a local table without any API call.
    Everything else goes through Alpha Vantage's SYMBOL_SEARCH endpoint,
    preferring a match whose region is the United States.

    Args:
        company_name: Free-text company name, e.g. "Apple" or "IBM".

    Returns:
        The ticker symbol, or "" when nothing matched, the API reported an
        error or usage notice, or the request/response failed.

    Raises:
        ValueError: If a remote lookup is needed but STOCK_API_KEY is unset.
    """
    # Fallback for well-known company names; answered locally, so these work
    # even without an Alpha Vantage key.
    known_symbols = {
        "apple": "AAPL",
        "tesla": "TSLA",
        "microsoft": "MSFT",
        "amazon": "AMZN",
        "meta": "META",
        "google": "GOOGL",
        "alphabet": "GOOGL",
        "netflix": "NFLX",
        "nvidia": "NVDA",
        "intel": "INTC",
        "accenture": "ACN",
    }
    clean_name = company_name.lower().strip()
    if clean_name in known_symbols:
        print(f"[Fallback] Returning known symbol for {company_name}: {known_symbols[clean_name]}")
        return known_symbols[clean_name]

    if not STOCK_API_KEY:
        raise ValueError("Missing Alpha Vantage API key!")

    params = {"function": "SYMBOL_SEARCH", "keywords": company_name, "apikey": STOCK_API_KEY}
    try:
        # params= URL-encodes the name, so spaces or "&" cannot corrupt the
        # query string; the timeout keeps a stalled request from hanging the UI.
        response = requests.get("https://www.alphavantage.co/query", params=params, timeout=15)
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        print(f"❌ Exception during symbol lookup: {e}")
        return ""
    print("[DEBUG] Symbol search API response:", data)

    if not isinstance(data, dict):
        print(f"❌ No matches found for: {company_name}")
        return ""
    # Check for rate limiting / usage notices (Alpha Vantage may send either key).
    for notice_key in ("Note", "Information"):
        if notice_key in data:
            print("❌ Alpha Vantage API limit hit:", data[notice_key])
            return ""
    # Check for error messages
    if "Error Message" in data:
        print("❌ Alpha Vantage Error:", data["Error Message"])
        return ""

    matches = data.get("bestMatches") or []
    if not matches:
        print(f"❌ No matches found for: {company_name}")
        return ""
    # Prefer US-based symbols
    for match in matches:
        region = (match.get("4. region") or "").lower()
        if "united states" in region:
            symbol = match.get("1. symbol", "")
            print(f"[Match] Found US symbol for {company_name}: {symbol}")
            return symbol
    # Fallback: return top match
    symbol = matches[0].get("1. symbol", "")
    print(f"[Fallback] Using first match for {company_name}: {symbol}")
    return symbol
def get_financial_overview(symbol: str) -> str:
    """Return a Markdown summary of key fundamentals for *symbol*.

    Data comes from Alpha Vantage's OVERVIEW endpoint.

    Args:
        symbol: Ticker symbol, e.g. "AAPL".

    Returns:
        Markdown text on success; otherwise a message starting with "❌"
        (network/HTTP failure, API notice, or no data for the symbol).
    """
    params = {"function": "OVERVIEW", "symbol": symbol, "apikey": STOCK_API_KEY}
    try:
        # params= URL-encodes the symbol; the timeout bounds a stalled request.
        response = requests.get("https://www.alphavantage.co/query", params=params, timeout=15)
    except requests.RequestException as e:
        return f"❌ Error fetching financial overview: {e}"
    if response.status_code != 200:
        return f"❌ Error fetching financial overview: {response.status_code}"
    try:
        data = response.json()
    except ValueError:  # body was not JSON
        data = None
    if not isinstance(data, dict):
        return f"❌ No financial data found for {symbol}. Try another company."

    # Rate-limit / usage notices arrive with HTTP 200, so inspect the payload
    # (same wording as get_stock_quote for the "Note" case).
    if "Note" in data:
        return "❌ API rate limit reached. Try again later."
    if "Information" in data:
        return f"❌ API Notice: {data['Information']}"
    if "Symbol" not in data:
        return f"❌ No financial data found for {symbol}. Try another company."

    def format_value(key, unit="", prefix=""):
        """Format ``data[key]`` with an optional prefix/unit, or "N/A" if absent."""
        val = data.get(key)
        # The API appears to use the literal string "None" for missing fields.
        return f"{prefix}{val}{unit}" if val and val != "None" else "N/A"

    description = data.get("Description") or "No description available."
    if len(description) > 400:
        # Only mark truncation when text was actually cut off.
        description = description[:400] + "..."

    return (
        f"πŸ“ˆ **Financial Overview for {data.get('Name', symbol)}**\n\n"
        f"β€’ **P/E Ratio:** {format_value('PERatio')}\n"
        f"β€’ **EPS:** {format_value('EPS')}\n"
        f"β€’ **Profit Margin:** {format_value('ProfitMargin')}\n"
        f"β€’ **Operating Margin:** {format_value('OperatingMarginTTM')}\n"
        f"β€’ **Market Cap:** {format_value('MarketCapitalization', prefix='$')}\n"
        f"β€’ **Revenue (TTM):** {format_value('RevenueTTM', prefix='$')}\n"
        f"β€’ **Gross Profit:** {format_value('GrossProfitTTM', prefix='$')}\n"
        f"β€’ **Return on Equity:** {format_value('ReturnOnEquityTTM')}\n"
        f"β€’ **Analyst Target Price:** {format_value('AnalystTargetPrice', prefix='$')}\n\n"
        f"πŸ“ **Description:** {description}"
    )
def get_company_news(company_name: str) -> dict:
    """Search Tavily for this week's news about *company_name*.

    Args:
        company_name: Company to search news for.

    Returns:
        ``{"success": True, "raw": [result dicts], "news": str}`` on success,
        ``{"success": False, "error": str}`` otherwise. Errors are returned,
        never raised, so callers can display them directly.
    """
    headers = {"Authorization": f"Bearer {TAVILY_API_KEY}"}
    payload = {
        "query": f"{company_name} latest news",
        # Tavily's result-count parameter is "max_results"; the former
        # "num_results" key is not a documented Tavily parameter.
        "max_results": 3,
        "topic": "news",
        "time_range": "week",
    }
    try:
        response = requests.post(
            "https://api.tavily.com/search", headers=headers, json=payload, timeout=15
        )
        if response.status_code != 200:
            return {"success": False, "error": f"❌ Tavily API error: {response.status_code}"}
        data = response.json()
        results = data.get("results", [])
        if not results:
            return {"success": False, "error": "❌ No news found."}
        return {
            "success": True,
            "raw": results,
            # .get() so one result missing a field doesn't discard all news.
            "news": "\n\n".join(
                f"πŸ“° {r.get('title', '')}\nπŸ”— {r.get('url', '')}" for r in results
            ),
        }
    except Exception as e:  # best-effort boundary: report, don't raise
        return {"success": False, "error": f"❌ Exception: {str(e)}"}
def get_stock_quote(symbol: str) -> str:
    """Fetch the latest traded price for *symbol* from Alpha Vantage GLOBAL_QUOTE.

    Args:
        symbol: Ticker symbol, e.g. "AAPL".

    Returns:
        A one-line price / last-trade-date string on success, otherwise a
        message starting with "❌". Network and decoding errors are reported
        rather than raised.
    """
    params = {"function": "GLOBAL_QUOTE", "symbol": symbol, "apikey": STOCK_API_KEY}
    try:
        # params= URL-encodes the symbol; the timeout bounds a stalled request.
        data = requests.get("https://www.alphavantage.co/query", params=params, timeout=15).json()
    except (requests.RequestException, ValueError) as e:
        return f"❌ Error fetching quote: {e}"
    if not isinstance(data, dict):
        return "❌ Price data unavailable or symbol invalid."
    # Check for rate limit, usage notice or error
    if "Note" in data:
        return "❌ API rate limit reached. Try again later."
    if "Information" in data:
        return f"❌ API Notice: {data['Information']}"
    if "Error Message" in data:
        return f"❌ API Error: {data['Error Message']}"
    try:
        quote = data["Global Quote"]
        return f"πŸ“ˆ Price: ${quote['05. price']}, Last Trade: {quote['07. latest trading day']}"
    except (KeyError, TypeError):
        # Unknown symbols come back without the expected quote fields.
        return "❌ Price data unavailable or symbol invalid."
def generate_stock_chart(symbol: str) -> Optional[BytesIO]:
    """Render the most recent (up to 30) daily closing prices of *symbol* as a PNG.

    Args:
        symbol: Ticker symbol, e.g. "AAPL".

    Returns:
        An in-memory PNG (BytesIO rewound to position 0), or None when the
        request fails, the API returns no time series, or fewer than two
        data points are available. Failures are reported via print().
    """
    # Build the Figure directly instead of through pyplot: pyplot keeps global
    # "current figure" state, which is unsafe in a multi-user web server and
    # leaks figures that are never closed.
    from matplotlib.figure import Figure

    params = {"function": "TIME_SERIES_DAILY", "symbol": symbol, "apikey": STOCK_API_KEY}
    try:
        response = requests.get("https://www.alphavantage.co/query", params=params, timeout=15)
    except requests.RequestException as e:
        print(f"❌ Request failed: {e}")
        return None
    if response.status_code != 200:
        print(f"❌ HTTP Error: {response.status_code}")
        return None
    try:
        data = response.json()
    except ValueError:
        print("❌ Error in response: body is not valid JSON")
        return None

    # Check for API limit or error
    if not isinstance(data, dict) or "Time Series (Daily)" not in data:
        details = data if isinstance(data, dict) else {}
        print(
            "❌ Error in response:",
            details.get("Note") or details.get("Information") or details.get("Error Message") or "Unknown issue",
        )
        return None

    timeseries = data["Time Series (Daily)"]
    # Keys are assumed to be YYYY-MM-DD dates, which sort chronologically as
    # strings; sorting avoids depending on the order the API lists them in.
    dates = sorted(timeseries, reverse=True)
    if len(dates) < 2:
        print("❌ Not enough data to plot chart")
        return None
    dates = dates[:30]  # Take latest 30 dates
    try:
        prices = [float(timeseries[date]["4. close"]) for date in dates]
    except (KeyError, TypeError, ValueError) as e:
        print(f"❌ Malformed price data: {e}")
        return None
    # Plot oldest -> newest from left to right.
    dates.reverse()
    prices.reverse()

    fig = Figure(figsize=(8, 3))
    ax = fig.subplots()
    ax.plot(dates, prices, marker="o")
    ax.set_title(f"{symbol.upper()} - Last 30 Days Price")
    ax.tick_params(axis="x", labelrotation=45, labelsize=8)
    ax.grid(True)
    fig.tight_layout()

    # Save to buffer
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return buf
def extract_pdf_text(file, max_chars: int = 1500) -> str:
    """Return up to *max_chars* characters of text from a PDF.

    Args:
        file: A filesystem path (str or os.PathLike) or a binary file-like
            object exposing ``.read()``.
        max_chars: Maximum number of characters to return (default 1500,
            the previously hard-coded limit).

    Returns:
        The text of the leading pages, truncated to *max_chars*.
    """
    if isinstance(file, (str, os.PathLike)):
        # If it's a file path
        doc = fitz.open(file)
    else:
        # If it's a file-like object (BytesIO)
        doc = fitz.open(stream=file.read(), filetype="pdf")

    chunks = []
    collected = 0
    # The context manager closes the document even if extraction fails.
    with doc:
        for page in doc:
            if collected >= max_chars:
                break  # enough text already; skip the remaining pages
            page_text = page.get_text()
            chunks.append(page_text)
            collected += len(page_text)
    return "".join(chunks)[:max_chars]
# Tools
# Each (name, callable, description) triple becomes one LangChain Tool,
# registered in this order.
tools = [
    Tool(name=spec_name, func=spec_func, description=spec_desc)
    for spec_name, spec_func, spec_desc in (
        ("Get Symbol", get_stock_symbol, "Find stock symbol for a company"),
        ("Get Quote", get_stock_quote, "Get real-time stock price"),
        ("Get Overview", get_financial_overview, "Get company financials"),
        ("Get News", get_company_news, "Fetch company-related news"),
        ("Get Chart", generate_stock_chart, "Generate 30-day price chart"),
    )
]
# Name -> Tool lookup used by the graph nodes.
tool_map = dict((registered.name, registered) for registered in tools)
# LangGraph Nodes
def select_tool(state):
    """Choose the tool for ``state["query"]`` and store it in ``state["chosen_tool"]``.

    Queries built by the dropdown UI start with an exact tool name (e.g.
    "Get Quote for Apple"); those are routed directly without an LLM call.
    Otherwise the LLM is asked to name a tool. An empty reply, or one that
    matches no registered tool, falls back to "Get Overview".

    Returns:
        The same (mutated) state dict.
    """
    query = state["query"]
    lowered_query = query.lower()
    tool_name = next(
        (name for name in tool_map if lowered_query.startswith(name.lower())), None
    )
    if tool_name is None:
        options = ", ".join(tool_map.keys())
        prompt = (
            f"You are a smart financial agent. Based on the query: \"{query}\", pick the best tool from:\n"
            f"{options}\n"
            "Respond only with the tool name (exact match)."
        )
        reply = llm.invoke(prompt).content.strip()
        # An empty reply would make splitlines()[0] raise IndexError.
        first_line = reply.splitlines()[0] if reply else ""
        # Models often end with a period or wrap the name in quotes/backticks/bold.
        candidate = first_line.replace(".", "").strip().strip("`\"'*").strip()
        # Accept case differences, otherwise keep the raw reply for the log.
        tool_name = next(
            (name for name in tool_map if name.lower() == candidate.lower()), candidate
        )
    print("[Tool Selection] Tool Chosen:", tool_name)
    if tool_name not in tool_map:
        tool_name = "Get Overview"
    state["chosen_tool"] = tool_name
    return state
def run_tool(state):
    """Run ``state["chosen_tool"]`` for the company named in ``state["query"]``.

    Stores the tool output in ``state["result"]``. That is text for most
    tools and the chart tool's return value for "Get Chart". Failures are
    stored as messages starting with "❌" rather than raised. The ticker is
    looked up only for tools that need it, and only once per run.

    Returns:
        The same (mutated) state dict.
    """
    tool_name = state["chosen_tool"]
    query = state["query"]
    tool = tool_map.get(tool_name)
    if not tool:
        state["result"] = f"❌ Tool '{tool_name}' not found."
        return state

    company_name = extract_company_name(query)

    # News is searched by company name, so skip the rate-limited symbol lookup.
    if tool_name == "Get News":
        print("[Run Tool] Tool:", tool_name, "Company:", company_name)
        state["result"] = _format_news_result(company_name)
        return state

    try:
        symbol = get_stock_symbol(company_name)
    except ValueError as err:
        # get_stock_symbol raises ValueError when the Alpha Vantage key is missing.
        state["result"] = f"❌ {err}"
        return state
    print("[Run Tool] Tool:", tool_name, "Company:", company_name, "Symbol:", symbol)

    if tool_name == "Get Symbol":
        # Reuse the lookup above instead of hitting the API a second time.
        state["result"] = symbol or f"❌ Could not find symbol for '{company_name}'."
        return state

    if not symbol:
        state["result"] = f"❌ No stock symbol found for '{company_name}'."
        return state

    # All other tools (Quote, Chart, Overview) take the ticker symbol.
    result = tool.run(symbol)
    if result:
        state["result"] = result
    else:
        state["result"] = f"❌ No data returned for '{symbol}' using '{tool_name}'."
    return state


def _format_news_result(company_name):
    """Fetch news for *company_name* and format it as Markdown, newest first."""
    news_response = get_company_news(company_name)
    if not news_response.get("success"):
        return news_response.get("error", "❌ Failed to fetch news.")
    results = sorted(
        news_response.get("raw", []),
        # "or ''" keeps the sort safe when a result has published_date=None.
        key=lambda r: r.get("published_date") or "",
        reverse=True,
    )
    formatted = "\n\n".join(
        f"πŸ“° **{r.get('title')}**\nπŸ“… {r.get('published_date', 'Unknown')}\nπŸ”— {r.get('url', '#')}"
        for r in results
    )
    print("[News Results]\n", formatted)
    return formatted
def serve_pdf():
    """Return the relative path of the bundled sample report for download."""
    sample_report = "docs/Apple-Q2-Report.pdf"
    return sample_report
def summarize_tool_result(state):
    """Turn the raw tool output into the user-facing ``state["summary"]``.

    The result is copied through unchanged, without an LLM call, when:
      * the tool was "Get News" (the links are displayed as-is);
      * the result is not text (e.g. the chart's BytesIO), because
        stringifying it into a prompt would replace the chart with text;
      * the result is an error message (starts with "❌"), so the LLM does
        not invent "investment insight" from a failure.
    Otherwise the LLM summarizes the result together with any
    ``uploaded_content``.

    Returns:
        The same (mutated) state dict.
    """
    result = state.get("result")
    is_text = isinstance(result, str)
    if (
        state.get("chosen_tool") == "Get News"
        or (result is not None and not is_text)
        or (is_text and result.startswith("❌"))
    ):
        state["summary"] = result
        return state

    summary_input = result or ""
    doc_input = state.get("uploaded_content", "")
    query = state.get("query", "")
    prompt = (
        f"Based on the following data and uploaded report, summarize investment insight for: {query}\n\n"
        f"{summary_input}\n\nReport:\n{doc_input}"
    )
    state["summary"] = llm.invoke(prompt).content.strip()
    return state
# LangGraph State
# Shared state passed between the graph nodes, declared in functional form.
# Note: "result" is annotated as text but may hold whatever the selected tool
# returns (the chart tool returns a BytesIO).
AgentState = TypedDict(
    "AgentState",
    {
        "query": str,                       # user / UI query string
        "chosen_tool": Optional[str],       # tool name set by select_tool
        "result": Optional[str],            # raw tool output set by run_tool
        "uploaded_content": Optional[str],  # report text read by summarize_tool_result
        "summary": Optional[str],           # final text set by summarize_tool_result
    },
)
# Linear pipeline: select_tool -> run_tool -> summarize -> END.
builder = StateGraph(AgentState)
_node_order = ["select_tool", "run_tool", "summarize"]
for _node_name, _node_fn in zip(_node_order, (select_tool, run_tool, summarize_tool_result)):
    builder.add_node(_node_name, _node_fn)
builder.set_entry_point("select_tool")
# Chain each node to the next one; the last node flows into END.
for _src, _dst in zip(_node_order, _node_order[1:] + [END]):
    builder.add_edge(_src, _dst)
graph = builder.compile()
def extract_company_name(query: str) -> str:
    """
    Extract the company name from a query string.

    1. A known company name appearing as a whole word (case-insensitive)
       is returned in its canonical spelling, e.g. "apple" -> "Apple".
    2. Otherwise, for queries shaped like the UI's "<action> for <company>",
       the text after the first " for " is returned (trailing "?"/"!" removed).
    3. Otherwise the stripped query itself is returned.
    """
    import re

    known_names = [
        "Apple", "Tesla", "Microsoft", "Amazon", "Accenture",
        "Meta", "Google", "Alphabet", "Nvidia", "Netflix", "Intel"
    ]
    for name in known_names:
        # Whole-word match so e.g. "Intelligent" does not map to "Intel".
        if re.search(rf"\b{re.escape(name)}\b", query, re.IGNORECASE):
            return name

    marker = " for "
    start = query.lower().find(marker)
    if start != -1:
        candidate = query[start + len(marker):].strip().rstrip("?!").strip()
        if candidate:
            return candidate
    return query.strip()  # fallback to full query
def agent_response(query, uploaded_file):
    """Answer *query*, optionally grounded in an uploaded PDF report.

    Returns:
        * If ``uploaded_file`` is given and "Summarize" is in the query: the
          LLM summary of the report as a plain string.
        * Otherwise: a ``(text, image)`` tuple for the Gradio outputs. For
          chart results, text is "" and image is a PIL image. For other
          results, image is None and text is "" when nothing was produced.
    The two return shapes are kept because the UI callbacks rely on them.
    """
    state = {"query": query}
    # Case 1: File-based summarization
    if uploaded_file and "Summarize" in query:
        state["uploaded_content"] = extract_pdf_text(uploaded_file)
        prompt = (
            f"You are a financial analyst. Based on the following uploaded financial report, "
            f"generate an investment insight summary for {query}. "
            f"Use specific details from the report and avoid general statements.\n\n"
            f"### Uploaded Report:\n{state['uploaded_content']}"
        )
        print("[PDF Summary] Invoking LLM with report text")
        summary = llm.invoke(prompt).content.strip()
        return summary  # Just text for Markdown

    # Case 2: Tool-based flow
    print("[Agent Response] Running LangGraph flow for query:", query)
    result = graph.invoke(state)

    # Check the raw tool output first: the summarize step may have replaced a
    # chart with text, which would otherwise hide the image.
    raw_result = result.get("result")
    if isinstance(raw_result, BytesIO):
        raw_result.seek(0)
        return "", Image.open(raw_result)

    final = result.get("summary") or raw_result
    if final is None:
        # Empty text lets callers detect "no response" instead of showing "None".
        return "", None
    return str(final), None
# Tool names offered in the Stock Advisor dropdown, in display order.
dropdown_options = ["Get Overview", "Get News", "Get Symbol", "Get Quote", "Get Chart"]
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 AI Stock Advisor + Financial Report Summarizer")

    with gr.Tab("πŸ“Š Stock Advisor"):
        gr.Markdown("**Option 1:** Type your question (recommended) or use the dropdown below")
        free_query = gr.Textbox(label="Ask your stock-related question", placeholder="e.g., What is Tesla's stock price?")
        gr.Markdown("**Option 2:** Use dropdown + company name")
        company_input = gr.Textbox(label="Company Name", placeholder="e.g., Apple, Tesla", lines=1)
        dropdown = gr.Dropdown(choices=dropdown_options, label="What do you want to know?", value=dropdown_options[0])
        text_output = gr.Markdown(label="πŸ“˜ Text Summary")
        chart_output = gr.Image(label="πŸ“ˆ Stock Chart", type="pil")
        run_btn = gr.Button("πŸ” Analyze")
        clear_btn = gr.Button("πŸ—‘οΈ Clear")

        def handle_query(company_name, query_choice, free_text=""):
            """Run the agent for the free-text question, or for dropdown + company.

            The free-text box (Option 1) takes precedence when it is non-empty.
            Returns a (markdown_text, chart_image_or_None) pair for the outputs.
            """
            free_text = (free_text or "").strip()
            company_name = (company_name or "").strip()
            if free_text:
                # Option 1: send the user's question as typed.
                combined_query = free_text
            elif company_name:
                # Option 2: build "<tool> for <company>" from the dropdown.
                combined_query = f"{query_choice} for {company_name}"
            else:
                return "⚠️ Enter a company or a question", None

            # Run the agent (LLM + tool + logic)
            text_result, image_result = agent_response(combined_query, None)
            if not text_result and image_result is None:
                return "❌ Could not generate response. Try another company.", None
            return text_result, image_result

        def clear_all():
            """Reset every input and output of the Stock Advisor tab."""
            return "", dropdown_options[0], "", None, ""

        run_btn.click(fn=handle_query, inputs=[company_input, dropdown, free_query], outputs=[text_output, chart_output])
        clear_btn.click(fn=clear_all, inputs=[], outputs=[company_input, dropdown, text_output, chart_output, free_query])

    with gr.Tab("πŸ“„ Financial Report Summarizer"):
        gr.Markdown("### Upload a financial report and provide company name. The agent will analyze using tools + file.")
        with gr.Row():
            download_btn = gr.Button("πŸ“Ž Download Sample Report")
            sample_file_output = gr.File(label="Click to Download")
        file_input = gr.File(label="Upload PDF report", file_types=[".pdf"])
        company_name_input = gr.Textbox(label="Company Name", placeholder="e.g., Apple", lines=1)
        summary_output = gr.Markdown(label="πŸ“ AI Summary")
        with gr.Row():
            summarize_btn = gr.Button("πŸ“„ Analyze Report")
            clear_btn_2 = gr.Button("πŸ—‘οΈ Clear")

        def summarize_with_tool(company, file):
            """Summarize the uploaded report for *company*; returns Markdown text."""
            company = (company or "").strip()
            if not file or not company:
                return "⚠️ Please upload a file and enter a company name."
            combined_query = f"Summarize financial report for {company}"
            try:
                # With a file and "Summarize" in the query, agent_response returns text.
                return agent_response(combined_query, file)
            except Exception as err:  # UI boundary, e.g. an unreadable PDF
                print(f"❌ Report analysis failed: {err}")
                return f"❌ Could not analyze the report: {err}"

        def clear_report_fields():
            """Reset the report tab's file, company and summary fields."""
            return None, "", ""

        download_btn.click(fn=serve_pdf, outputs=sample_file_output)
        summarize_btn.click(fn=summarize_with_tool, inputs=[company_name_input, file_input], outputs=summary_output)
        clear_btn_2.click(fn=clear_report_fields, inputs=[], outputs=[file_input, company_name_input, summary_output])

demo.launch()