import base64
import os
import re
from io import BytesIO
from typing import Optional, TypedDict, Union

import fitz
import gradio as gr
import matplotlib.pyplot as plt
import requests
from dotenv import load_dotenv
from langchain.tools import Tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from PIL import Image
# Load variables from a local .env file into os.environ before reading any keys.
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Re-export the OpenAI key into the environment for ChatOpenAI. os.environ only
# accepts str values, so assigning None (key missing) raised a confusing TypeError
# at import time; only set it when a value exists and let the client presumably
# report the missing key itself.
if OPENAI_API_KEY:
    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

# Alpha Vantage key used for symbol search, quotes, overviews and charts (may be None).
STOCK_API_KEY = os.getenv("STOCK_API_KEY")
# Tavily key used by the news search (may be None).
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

# Shared chat model used for tool selection and for summaries.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.3)
# Offline answers for popular companies: no network call, no API key, and no
# spending of Alpha Vantage's small request quota for the most common lookups.
_KNOWN_SYMBOLS = {
    "apple": "AAPL",
    "tesla": "TSLA",
    "microsoft": "MSFT",
    "amazon": "AMZN",
    "meta": "META",
    "google": "GOOGL",
    "alphabet": "GOOGL",
    "netflix": "NFLX",
    "nvidia": "NVDA",
    "intel": "INTC",
    "accenture": "ACN",
}


def get_stock_symbol(company_name: str) -> str:
    """Resolve a company name to a ticker symbol.

    Known companies are answered from ``_KNOWN_SYMBOLS`` without any network call.
    Otherwise Alpha Vantage ``SYMBOL_SEARCH`` is queried; the first United States
    listing is preferred, falling back to the first match returned.

    Args:
        company_name: Free-text company name, e.g. ``"Tesla"``.

    Returns:
        The ticker symbol, or ``""`` when the lookup fails, is rate limited, or
        finds no match.

    Raises:
        ValueError: if a network lookup is needed but ``STOCK_API_KEY`` is unset.
    """
    clean_name = company_name.lower().strip()
    known = _KNOWN_SYMBOLS.get(clean_name)
    if known:
        print(f"[Fallback] Returning known symbol for {company_name}: {known}")
        return known

    # Only the remote search needs the key; the known symbols above work without it.
    if not STOCK_API_KEY:
        raise ValueError("Missing Alpha Vantage API key!")

    params = {
        "function": "SYMBOL_SEARCH",
        # Sent via params= so requests URL-encodes names such as "AT&T" or "Coca Cola".
        "keywords": company_name.strip(),
        "apikey": STOCK_API_KEY,
    }
    try:
        response = requests.get("https://www.alphavantage.co/query", params=params, timeout=15)
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a body that is not valid JSON.
        print(f"β Exception during symbol lookup: {e}")
        return ""

    print("[DEBUG] Symbol search API response:", data)

    if not isinstance(data, dict):
        print(f"β No matches found for: {company_name}")
        return ""

    # Throttling is reported under "Note"; newer responses appear to use "Information".
    limit_message = data.get("Note") or data.get("Information")
    if limit_message:
        print("β Alpha Vantage API limit hit:", limit_message)
        return ""

    if "Error Message" in data:
        print("β Alpha Vantage Error:", data["Error Message"])
        return ""

    matches = data.get("bestMatches") or []
    if not matches:
        print(f"β No matches found for: {company_name}")
        return ""

    for match in matches:
        if "united states" in match.get("4. region", "").lower():
            symbol = match.get("1. symbol", "")
            print(f"[Match] Found US symbol for {company_name}: {symbol}")
            return symbol

    symbol = matches[0].get("1. symbol", "")
    print(f"[Fallback] Using first match for {company_name}: {symbol}")
    return symbol
def get_financial_overview(symbol: str) -> str:
    """Build a Markdown summary of key fundamentals for ``symbol``.

    Queries Alpha Vantage ``OVERVIEW`` and formats selected fields. A field that is
    absent, empty, or the string ``"None"`` renders as ``N/A``.

    Args:
        symbol: Ticker symbol, e.g. ``"AAPL"``.

    Returns:
        The Markdown overview, or a user-facing error message for network/HTTP
        failures, rate limiting, or a symbol with no data. Network errors are
        reported in the returned text rather than raised.
    """
    params = {"function": "OVERVIEW", "symbol": symbol, "apikey": STOCK_API_KEY}
    try:
        response = requests.get("https://www.alphavantage.co/query", params=params, timeout=15)
    except requests.RequestException as exc:
        return f"β Error fetching financial overview: {exc}"

    if response.status_code != 200:
        return f"β Error fetching financial overview: {response.status_code}"

    try:
        data = response.json()
    except ValueError:
        data = None  # non-JSON body: treated like "no data" below
    if not isinstance(data, dict):
        data = {}

    if "Symbol" not in data:
        # A throttled request also lacks "Symbol"; say so instead of "no data".
        if "Note" in data or "Information" in data:
            return "β API rate limit reached. Try again later."
        return f"β No financial data found for {symbol}. Try another company."

    def format_value(key, prefix=""):
        """Return ``prefix + value``, or ``"N/A"`` when the field is missing."""
        val = data.get(key)
        return f"{prefix}{val}" if val and val != "None" else "N/A"

    description = data.get("Description")
    if not description or description == "None":
        description = "No description available."
    elif len(description) > 400:
        # Keep the card short; only mark truncation when text was actually cut.
        description = description[:400] + "..."

    return (
        f"π **Financial Overview for {data.get('Name', symbol)}**\n\n"
        f"β’ **P/E Ratio:** {format_value('PERatio')}\n"
        f"β’ **EPS:** {format_value('EPS')}\n"
        f"β’ **Profit Margin:** {format_value('ProfitMargin')}\n"
        f"β’ **Operating Margin:** {format_value('OperatingMarginTTM')}\n"
        # The "$" is applied inside format_value so missing values read "N/A", not "$N/A".
        f"β’ **Market Cap:** {format_value('MarketCapitalization', prefix='$')}\n"
        f"β’ **Revenue (TTM):** {format_value('RevenueTTM', prefix='$')}\n"
        f"β’ **Gross Profit:** {format_value('GrossProfitTTM', prefix='$')}\n"
        f"β’ **Return on Equity:** {format_value('ReturnOnEquityTTM')}\n"
        f"β’ **Analyst Target Price:** {format_value('AnalystTargetPrice', prefix='$')}\n\n"
        f"π **Description:** {description}"
    )
def get_company_news(company_name: str) -> dict:
    """Search Tavily for this week's news about ``company_name`` (up to 3 articles).

    Returns:
        On success ``{"success": True, "raw": [...], "news": str}``, where ``raw`` is
        Tavily's result list and ``news`` is a plain-text digest (title and URL per
        article). Otherwise ``{"success": False, "error": str}``. Failures are
        reported in the dict, never raised.
    """
    headers = {"Authorization": f"Bearer {TAVILY_API_KEY}"}
    payload = {
        "query": f"{company_name} latest news",
        # Tavily's result-count field is "max_results"; the previous "num_results"
        # key was ignored, so the API returned its default count instead of 3.
        "max_results": 3,
        "topic": "news",
        "time_range": "week",
    }

    try:
        response = requests.post(
            "https://api.tavily.com/search", headers=headers, json=payload, timeout=15
        )
        if response.status_code != 200:
            return {"success": False, "error": f"β Tavily API error: {response.status_code}"}

        results = response.json().get("results", [])
        if not results:
            return {"success": False, "error": "β No news found."}

        return {
            "success": True,
            "raw": results,
            "news": "\n\n".join(f"π° {r['title']}\nπ {r['url']}" for r in results),
        }
    # Best-effort contract: any failure (network, bad JSON, missing fields) becomes an
    # error dict, so callers only need to check "success".
    except Exception as e:
        return {"success": False, "error": f"β Exception: {str(e)}"}
def get_stock_quote(symbol: str) -> str:
    """Return the latest price and trading day for ``symbol`` as one line of text.

    Uses Alpha Vantage ``GLOBAL_QUOTE``. Every failure (network error, invalid JSON,
    rate limit, API error, missing quote fields) is returned as a user-facing
    message instead of being raised.
    """
    params = {"function": "GLOBAL_QUOTE", "symbol": symbol, "apikey": STOCK_API_KEY}
    try:
        data = requests.get(
            "https://www.alphavantage.co/query", params=params, timeout=15
        ).json()
    except (requests.RequestException, ValueError) as exc:
        return f"β API Error: {exc}"

    if not isinstance(data, dict):
        return "β Price data unavailable or symbol invalid."

    # Throttling is reported under "Note"; newer responses appear to use "Information".
    if "Note" in data or "Information" in data:
        return "β API rate limit reached. Try again later."
    if "Error Message" in data:
        return f"β API Error: {data['Error Message']}"

    quote = data.get("Global Quote")
    if not isinstance(quote, dict):
        quote = {}
    price = quote.get("05. price")
    trading_day = quote.get("07. latest trading day")
    # An unknown symbol yields no (or empty) quote fields.
    if not price or not trading_day:
        return "β Price data unavailable or symbol invalid."
    return f"π Price: ${price}, Last Trade: {trading_day}"
def generate_stock_chart(symbol: str) -> Optional[BytesIO]:
    """Render the most recent 30 daily closing prices of ``symbol`` as a PNG.

    Data comes from Alpha Vantage ``TIME_SERIES_DAILY``.

    Returns:
        An in-memory PNG rewound to position 0, or ``None`` when the request fails,
        the response has no daily series (e.g. rate limit or unknown symbol), or
        fewer than two data points exist. Problems are printed, not raised.
    """
    params = {"function": "TIME_SERIES_DAILY", "symbol": symbol, "apikey": STOCK_API_KEY}
    try:
        response = requests.get("https://www.alphavantage.co/query", params=params, timeout=15)
    except requests.RequestException as exc:
        print(f"β HTTP Error: {exc}")
        return None

    if response.status_code != 200:
        print(f"β HTTP Error: {response.status_code}")
        return None

    try:
        data = response.json()
    except ValueError:
        data = {}
    if not isinstance(data, dict):
        data = {}

    if "Time Series (Daily)" not in data:
        print(
            "β Error in response:",
            data.get("Note") or data.get("Information") or data.get("Error Message") or "Unknown issue",
        )
        return None

    timeseries = data["Time Series (Daily)"]
    if len(timeseries) < 2:
        print("β Not enough data to plot chart")
        return None

    # The series is keyed by date strings (YYYY-MM-DD), which sort chronologically;
    # sorting explicitly avoids depending on the order the API lists them in.
    dates = sorted(timeseries)[-30:]
    prices = [float(timeseries[day]["4. close"]) for day in dates]

    # Draw on an explicit figure rather than pyplot's implicit "current" one, and
    # close it even if plotting or saving fails so figures never accumulate.
    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.plot(dates, prices, marker="o")
        ax.set_title(f"{symbol.upper()} - Last 30 Days Price")
        ax.tick_params(axis="x", labelrotation=45, labelsize=8)
        ax.grid(True)
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    buf.seek(0)
    return buf
def extract_pdf_text(file, max_chars: int = 1500) -> str:
    """Return at most the first ``max_chars`` characters of a PDF's text.

    Args:
        file: A filesystem path (``str`` or ``os.PathLike``; presumably the path
            Gradio passes for an upload), raw PDF ``bytes``, or a binary
            file-like object with ``read()``.
        max_chars: Maximum number of characters returned (default 1500). The text
            is embedded in an LLM prompt by ``agent_response``.

    Returns:
        The concatenated page text, truncated to ``max_chars``.
    """
    if isinstance(file, (str, os.PathLike)):
        doc = fitz.open(file)
    elif isinstance(file, (bytes, bytearray)):
        doc = fitz.open(stream=bytes(file), filetype="pdf")
    else:
        doc = fitz.open(stream=file.read(), filetype="pdf")

    pages = []
    collected = 0
    # `with` closes the document and its file handle. Stop once enough text is
    # collected instead of extracting every page of a long report.
    with doc:
        for page in doc:
            page_text = page.get_text()
            pages.append(page_text)
            collected += len(page_text)
            if collected >= max_chars:
                break
    return "".join(pages)[:max_chars]
# Every capability the agent can choose from, registered under the display name
# that select_tool asks the LLM to answer with.
tools = [
    Tool(name=tool_name, func=tool_func, description=tool_description)
    for tool_name, tool_func, tool_description in (
        ("Get Symbol", get_stock_symbol, "Find stock symbol for a company"),
        ("Get Quote", get_stock_quote, "Get real-time stock price"),
        ("Get Overview", get_financial_overview, "Get company financials"),
        ("Get News", get_company_news, "Fetch company-related news"),
        ("Get Chart", generate_stock_chart, "Generate 30-day price chart"),
    )
]

# Name -> Tool lookup used by select_tool and run_tool.
tool_map = {registered.name: registered for registered in tools}
def select_tool(state):
    """LangGraph node: store the name of the tool to run in ``state["chosen_tool"]``.

    A query that already starts with a tool name (as dropdown queries do, e.g.
    "Get Chart for Apple") uses that tool directly. Otherwise the LLM picks one
    from ``tool_map``. An unrecognised answer falls back to "Get Overview".
    """
    query = state["query"]

    # Honour an explicit tool choice instead of asking the LLM to guess it.
    lowered_query = query.strip().lower()
    for name in tool_map:
        if lowered_query.startswith(name.lower()):
            print("[Tool Selection] Tool Chosen:", name)
            state["chosen_tool"] = name
            return state

    options = ", ".join(tool_map.keys())
    prompt = (
        f'You are a smart financial agent. Based on the query: "{query}", pick the best tool from:\n'
        f"{options}\n"
        "Respond only with the tool name (exact match)."
    )
    reply = llm.invoke(prompt).content.strip()

    # Take the first line only and drop decoration models often add (periods,
    # quotes, backticks, bold markers); an empty reply no longer raises IndexError.
    first_line = reply.splitlines()[0] if reply else ""
    answer = first_line.replace(".", "").strip().strip("`'\"*").strip()
    names_by_lower = {name.lower(): name for name in tool_map}
    tool_name = names_by_lower.get(answer.lower(), answer)

    print("[Tool Selection] Tool Chosen:", tool_name)
    if tool_name not in tool_map:
        tool_name = "Get Overview"
    state["chosen_tool"] = tool_name
    return state
def _format_news(news_response: dict) -> str:
    """Render a get_company_news() response as Markdown, newest articles first."""
    if not news_response.get("success"):
        return news_response.get("error", "β Failed to fetch news.")

    # `or ""` keeps every sort key a string even when published_date is null,
    # which would otherwise make the comparison raise TypeError.
    articles = sorted(
        news_response.get("raw", []),
        key=lambda r: r.get("published_date") or "",
        reverse=True,
    )
    formatted = "\n\n".join(
        f"π° **{r.get('title')}**\nπ {r.get('published_date') or 'Unknown'}\nπ {r.get('url', '#')}"
        for r in articles
    )
    print("[News Results]\n", formatted)
    return formatted


def run_tool(state):
    """LangGraph node: run ``state["chosen_tool"]`` for the company named in the query.

    Sets ``state["result"]`` to the tool output (text, or the chart buffer for
    "Get Chart") or to a user-facing error message, and returns ``state``.
    """
    tool_name = state["chosen_tool"]
    query = state["query"]
    tool = tool_map.get(tool_name)

    if not tool:
        state["result"] = f"β Tool '{tool_name}' not found."
        return state

    company_name = extract_company_name(query)

    if tool_name == "Get News":
        # News is searched by company name, so skip the quota-limited symbol lookup.
        print("[Run Tool] Tool:", tool_name, "Company:", company_name)
        state["result"] = _format_news(get_company_news(company_name))
        return state

    symbol = get_stock_symbol(company_name)
    print("[Run Tool] Tool:", tool_name, "Company:", company_name, "Symbol:", symbol)

    if tool_name == "Get Symbol":
        # The lookup above is exactly this tool's job; reuse it instead of a second call.
        state["result"] = symbol or f"β Could not find symbol for '{company_name}'."
        return state

    if not symbol:
        state["result"] = f"β No stock symbol found for '{company_name}'."
        return state

    result = tool.run(symbol)
    state["result"] = result if result else f"β No data returned for '{symbol}' using '{tool_name}'."
    return state
def serve_pdf():
    """Return the path of the sample report offered by the download button."""
    sample_path = "docs/Apple-Q2-Report.pdf"  # relative path, resolved against the working dir
    return sample_path
def summarize_tool_result(state):
    """LangGraph node: produce the final text in ``state["summary"]``.

    - News results pass through unchanged, because they are already a formatted list.
    - Non-text results, such as the PNG buffer from "Get Chart", are not sent to the
      LLM. ``summary`` is set to ``None`` so callers that use
      ``summary or result`` fall back to the raw result.
    - Anything else is summarised by the LLM together with any uploaded report text.
    """
    result = state.get("result")

    if state.get("chosen_tool") == "Get News":
        state["summary"] = result
        return state

    if result is not None and not isinstance(result, str):
        # Interpolating a buffer would send "<_io.BytesIO object ...>" to the LLM,
        # and the resulting text would then hide the chart from the user.
        state["summary"] = None
        return state

    summary_input = result or ""
    doc_input = state.get("uploaded_content") or ""
    query = state.get("query", "")
    prompt = (
        f"Based on the following data and uploaded report, summarize investment insight for: {query}\n\n"
        f"{summary_input}\n\nReport:\n{doc_input}"
    )
    state["summary"] = llm.invoke(prompt).content.strip()
    return state
class AgentState(TypedDict):
    """State shared by the select_tool -> run_tool -> summarize graph nodes."""

    query: str  # user question, e.g. "Get Chart for Apple"
    chosen_tool: Optional[str]  # tool name written by select_tool
    # Tool output written by run_tool: text, or the PNG BytesIO from "Get Chart".
    result: Optional[Union[str, BytesIO]]
    uploaded_content: Optional[str]  # optional report text read by summarize_tool_result
    summary: Optional[str]  # final text written by summarize_tool_result


# Linear pipeline: pick a tool, run it, then summarise its output.
builder = StateGraph(AgentState)
builder.add_node("select_tool", select_tool)
builder.add_node("run_tool", run_tool)
builder.add_node("summarize", summarize_tool_result)
builder.set_entry_point("select_tool")
builder.add_edge("select_tool", "run_tool")
builder.add_edge("run_tool", "summarize")
builder.add_edge("summarize", END)
graph = builder.compile()
_KNOWN_COMPANY_NAMES = (
    "Apple", "Tesla", "Microsoft", "Amazon", "Accenture",
    "Meta", "Google", "Alphabet", "Nvidia", "Netflix", "Intel",
)
# Whole-word, case-insensitive matchers, so "pineapple", "intelligence" or
# "metadata" no longer match Apple / Intel / Meta.
_KNOWN_COMPANY_PATTERNS = tuple(
    (name, re.compile(rf"\b{re.escape(name)}\b", re.IGNORECASE))
    for name in _KNOWN_COMPANY_NAMES
)
_FOR_SPLITTER = re.compile(r"\s+for\s+", re.IGNORECASE)


def extract_company_name(query: str) -> str:
    """Extract the company a query is about.

    Returns the first known company name (in ``_KNOWN_COMPANY_NAMES`` order) that
    appears in ``query`` as a whole word, case-insensitively.

    If there is no match and the query has the form "<anything> for <company>",
    returns the text after the last "for", without trailing ``?.!``. This is the
    form the dropdown UI builds, e.g. "Get Overview for IBM" -> "IBM". Otherwise
    returns the stripped query.
    """
    for name, pattern in _KNOWN_COMPANY_PATTERNS:
        if pattern.search(query):
            return name

    cleaned = query.strip()
    pieces = _FOR_SPLITTER.split(cleaned)
    if len(pieces) > 1:
        tail = pieces[-1].strip().rstrip("?.!").strip()
        if tail:
            return tail
    return cleaned
def agent_response(query, uploaded_file):
    """Answer ``query`` with the agent, optionally using an uploaded PDF report.

    There are two modes, and they return different types:

    - If ``uploaded_file`` is given and ``"Summarize"`` appears in ``query``, the
      (truncated) PDF text is summarised directly by the LLM and a ``str`` is
      returned.
    - Otherwise the LangGraph tool pipeline runs and a ``(text, image)`` tuple is
      returned: ``("", PIL.Image)`` for a chart, ``(text, None)`` for anything else.
    """
    state = {"query": query}

    if uploaded_file and "Summarize" in query:
        state["uploaded_content"] = extract_pdf_text(uploaded_file)
        prompt = (
            f"You are a financial analyst. Based on the following uploaded financial report, "
            f"generate an investment insight summary for {query}. "
            f"Use specific details from the report and avoid general statements.\n\n"
            f"### Uploaded Report:\n{state['uploaded_content']}"
        )
        print("[PDF Summary] Invoking LLM with report text")
        return llm.invoke(prompt).content.strip()

    print("[Agent Response] Running LangGraph flow for query:", query)
    result = graph.invoke(state)

    raw_result = result.get("result")
    # Check the raw tool output first: a chart arrives there as a PNG buffer, and the
    # summarize node may have produced text that would otherwise hide the image.
    if isinstance(raw_result, BytesIO):
        raw_result.seek(0)
        return "", Image.open(raw_result)

    final = result.get("summary") or raw_result
    # Avoid showing the literal text "None" when the graph produced nothing.
    return ("" if final is None else str(final)), None
# Stock Advisor dropdown choices in display order (the first is the default value).
# They are the same names the tools are registered under.
dropdown_options = ["Get Overview", "Get News", "Get Symbol", "Get Quote", "Get Chart"]
with gr.Blocks() as demo:
    gr.Markdown("# π§ AI Stock Advisor + Financial Report Summarizer")

    with gr.Tab("π Stock Advisor"):
        gr.Markdown("**Option 1:** Type your question (recommended) or use the dropdown below")
        free_query = gr.Textbox(
            label="Ask your stock-related question",
            placeholder="e.g., What is Tesla's stock price?",
        )

        gr.Markdown("**Option 2:** Use dropdown + company name")
        company_input = gr.Textbox(label="Company Name", placeholder="e.g., Apple, Tesla", lines=1)
        dropdown = gr.Dropdown(
            choices=dropdown_options,
            label="What do you want to know?",
            value=dropdown_options[0],
        )
        text_output = gr.Markdown(label="π Text Summary")
        chart_output = gr.Image(label="π Stock Chart", type="pil")
        run_btn = gr.Button("π Analyze")
        clear_btn = gr.Button("ποΈ Clear")

        def handle_query(company_name, query_choice, free_text=""):
            """Run the agent for the Stock Advisor tab.

            A non-empty free-text question (Option 1) takes precedence. Otherwise the
            dropdown choice and company name (Option 2) are combined as
            "<choice> for <company>".

            Returns:
                ``(markdown_text, pil_image_or_None)`` for the two output components.
            """
            question = (free_text or "").strip()
            if question:
                text_result, image_result = agent_response(question, None)
            else:
                company = (company_name or "").strip()
                if not company:
                    return "β οΈ Enter a company", None
                text_result, image_result = agent_response(f"{query_choice} for {company}", None)

            if not text_result and image_result is None:
                return "β Could not generate response. Try another company.", None
            return text_result, image_result

        def clear_all():
            """Reset every Stock Advisor input and output to its initial value."""
            return "", "", dropdown_options[0], "", None

        # The Option 1 text box must be an input of the handler, or questions typed
        # there are ignored; pressing Enter in it submits as well.
        advisor_inputs = [company_input, dropdown, free_query]
        advisor_outputs = [text_output, chart_output]
        run_btn.click(fn=handle_query, inputs=advisor_inputs, outputs=advisor_outputs)
        free_query.submit(fn=handle_query, inputs=advisor_inputs, outputs=advisor_outputs)
        clear_btn.click(
            fn=clear_all,
            inputs=[],
            outputs=[free_query, company_input, dropdown, text_output, chart_output],
        )

    with gr.Tab("π Financial Report Summarizer"):
        gr.Markdown("### Upload a financial report and provide company name. The agent will analyze using tools + file.")
        with gr.Row():
            download_btn = gr.Button("π Download Sample Report")
            sample_file_output = gr.File(label="Click to Download")

        file_input = gr.File(label="Upload PDF report", file_types=[".pdf"])
        company_name_input = gr.Textbox(label="Company Name", placeholder="e.g., Apple", lines=1)
        summary_output = gr.Markdown(label="π AI Summary")
        with gr.Row():
            summarize_btn = gr.Button("π Analyze Report")
            clear_btn_2 = gr.Button("ποΈ Clear")

        def summarize_with_tool(company, file):
            """Summarise the uploaded PDF for ``company``; returns Markdown text."""
            if not file or not (company or "").strip():
                return "β οΈ Please upload a file and enter a company name."
            combined_query = f"Summarize financial report for {company.strip()}"
            # "Summarize" plus a file selects agent_response's PDF branch, which returns a str.
            return agent_response(combined_query, file)

        def clear_report_fields():
            """Reset the uploaded file, company name and summary."""
            return None, "", ""

        download_btn.click(fn=serve_pdf, outputs=sample_file_output)
        summarize_btn.click(
            fn=summarize_with_tool,
            inputs=[company_name_input, file_input],
            outputs=summary_output,
        )
        clear_btn_2.click(
            fn=clear_report_fields,
            inputs=[],
            outputs=[file_input, company_name_input, summary_output],
        )

demo.launch()