# NOTE: "Spaces: Running" page-status residue was captured here along with the file; it is not part of the program.
# src/streamlit_app.py | |
import streamlit as st | |
import pandas as pd | |
from serpapi import Client | |
from prophet import Prophet | |
import plotly.express as px | |
import time | |
import google.generativeai as genai | |
import os | |
from matplotlib import pyplot as plt | |
from statsmodels.tsa.seasonal import seasonal_decompose | |
# --- Page Configuration ---
# Browser-tab title, icon and wide layout for the whole app.
st.set_page_config(
    page_title="Trend-AI Marketing Dashboard",
    page_icon="🧠",
    layout="wide",
)
# --- Data Fetching Functions (no caching decorator is applied to them here) ---
def fetch_data_from_serpapi(api_key, keywords, timeframe, geo):
    """Fetch Google Trends data for ``keywords`` through SerpApi.

    Args:
        api_key: SerpApi key, sent as the ``api_key`` request parameter.
        keywords: Search terms to compare; joined with ", " into one query.
        timeframe: Value sent as the request's ``date`` parameter.
        geo: Value sent as the request's ``geo`` parameter.

    Returns:
        A dict with two DataFrames: ``"interest_over_time"`` (one column per
        keyword, indexed by the parsed start date of each point) and
        ``"interest_by_region"``. Either frame stays empty when the response
        has no such section. Returns ``None`` after reporting through
        ``st.error`` when the request or the time-series parsing fails.
    """

    def _parse_date_string(date_str):
        """Return the start of a date label (single date or 'a – b' range)."""
        if '–' not in date_str:
            return pd.to_datetime(date_str)
        parts = date_str.split('–')
        start = parts[0].strip()
        # "Dec 29, 2019 – Jan 4, 2020": the start already ends with its own
        # year, so it must not get the end year appended.
        if start[-4:].isdigit():
            return pd.to_datetime(start)
        # "Jan 5 – 11, 2020": borrow the year from the end of the range.
        year = parts[1].split(',')[-1].strip()
        try:
            return pd.to_datetime(f"{start}, {year}")
        except (ValueError, OverflowError):
            # Same fallback as before: parse the start part on its own.
            return pd.to_datetime(start)

    def _point_value(entry):
        """Pick the numeric reading of one keyword at one point in time."""
        # NOTE(review): SerpApi is expected to expose an integer
        # `extracted_value` next to the string `value` (which can be "<1");
        # confirm against the API reference. Falls back to `value` otherwise.
        return entry.get('extracted_value', entry.get('value', 0))

    params = {
        "engine": "google_trends",
        "q": ", ".join(keywords),
        "date": timeframe,
        "geo": geo,
        "api_key": api_key,
    }
    all_data = {"interest_over_time": pd.DataFrame(), "interest_by_region": pd.DataFrame()}

    try:
        client = Client()
        results = client.search(params)
        if 'interest_over_time' in results:
            timeline_data = results['interest_over_time']['timeline_data']
            dates = [_parse_date_string(item['date']) for item in timeline_data]
            # values[i] is assumed to line up with keywords[i], as in the query.
            data = {
                keyword: [_point_value(item['values'][i]) for item in timeline_data]
                for i, keyword in enumerate(keywords)
            }
            all_data["interest_over_time"] = pd.DataFrame(data, index=dates)
    except Exception as e:
        st.error(f"An error occurred with the SerpApi request: {e}")
        return None

    if 'interest_by_region' in results:
        # The region section is optional: a parsing problem here must not
        # throw away the time series that was already built above.
        try:
            df_region = pd.DataFrame(results['interest_by_region']).set_index('geoName')
            if keywords and len(df_region.columns) > 0:
                renamed_col = df_region.columns[0]
                df_region = df_region.rename(columns={renamed_col: 'Interest'})
            all_data["interest_by_region"] = df_region
        except (KeyError, ValueError, TypeError) as e:
            st.warning(f"Regional data could not be parsed and was skipped: {e}")

    return all_data
def load_all_offline_data(scenario_config):
    """Read the bundled CSV exports for one pre-configured scenario.

    ``scenario_config`` supplies ``"prefix"`` (file-name stem under ``data/``)
    and ``"keywords"`` (names given, in order, to the data columns of the
    over-time file). Returns a dict holding the ``"interest_over_time"`` frame
    (indexed by ``Date``) and the ``"interest_by_region"`` frame (indexed by
    ``Region`` with an ``Interest`` column). On any read/parse problem the
    error is shown with ``st.error`` and ``None`` is returned.
    """
    file_prefix = scenario_config["prefix"]
    keyword_names = scenario_config["keywords"]
    bundle = {"interest_over_time": pd.DataFrame(), "interest_by_region": pd.DataFrame()}
    try:
        # Over-time export: two leading lines are skipped before the header.
        timeline = pd.read_csv(f"data/{file_prefix}_over_time.csv", skiprows=2)
        timeline = timeline.rename(columns={timeline.columns[0]: 'Date'})
        # Remaining columns take the scenario's keyword names positionally.
        timeline = timeline.rename(columns=dict(zip(timeline.columns[1:], keyword_names)))
        timeline['Date'] = pd.to_datetime(timeline['Date'])
        timeline = timeline.set_index('Date')
        for name in keyword_names:
            if name in timeline.columns:
                # Non-numeric cells become NaN instead of raising.
                timeline[name] = pd.to_numeric(timeline[name], errors='coerce')
        bundle['interest_over_time'] = timeline

        # By-region export: one leading line is skipped before the header.
        regions = pd.read_csv(f"data/{file_prefix}_by_region.csv", skiprows=1)
        regions = regions.rename(
            columns={regions.columns[0]: 'Region', regions.columns[1]: 'Interest'}
        )
        regions['Interest'] = pd.to_numeric(regions['Interest'], errors='coerce')
        bundle['interest_by_region'] = regions.set_index('Region')
        return bundle
    except FileNotFoundError as e:
        st.error(f"Offline Data Error: Missing a primary CSV file: {e.filename}")
        return None
    except Exception as e:
        st.error(f"Error loading offline data: {e}")
        return None
# --- Main Application Logic ---
st.title("📈 Trend-AI: Marketing Intelligence Dashboard")
st.sidebar.header("Dashboard Controls")

# --- Final Sidebar UI using Session State and a Callback ---
# Pre-configured comparisons. Each key is the exact text placed in the keyword
# box; "prefix" names the offline CSV files used for that comparison.
scenarios = {
    "Nike, Adidas, Puma, Asics, Under Armour": {
        "prefix": "athletic_brands",
        "keywords": ["Nike", "Adidas", "Puma", "Asics", "Under Armour"],
        "geo_code": "",
        "timeframe_key": "Past 5 Years",
    },
    "Apple iPhone, Samsung Galaxy, Google Pixel": {
        "prefix": "smartphones",
        "keywords": ["Apple iPhone", "Samsung Galaxy", "Google Pixel"],
        "geo_code": "US",
        "timeframe_key": "Past 5 Years",
    },
    "Mahindra Scorpio, Maruti Suzuki Brezza, Hyundai Creta": {
        "prefix": "indian_suvs",
        "keywords": ["Mahindra Scorpio", "Maruti Suzuki Brezza", "Hyundai Creta"],
        "geo_code": "IN",
        "timeframe_key": "Past 5 Years",
    },
}
# The leading empty entry lets the dropdown represent "no scenario chosen".
scenario_options = [""]
scenario_options.extend(scenarios.keys())

# Seed the widget state on the first run only; the athletic-brands scenario is
# the default. All four entries are written together, keyed off keywords_input.
if 'keywords_input' not in st.session_state:
    for _state_key, _state_value in (
        ("keywords_input", "Nike, Adidas, Puma, Asics, Under Armour"),
        ("geo_selection", ""),
        ("timeframe_selection", "Past 5 Years"),
        ("scenario_selector", "Nike, Adidas, Puma, Asics, Under Armour"),
    ):
        st.session_state[_state_key] = _state_value
# Callback function to update other widgets when a scenario is chosen from the dropdown | |
def update_state_from_scenario():
    """Copy the chosen scenario's settings into the other widgets' state.

    Reads ``st.session_state.scenario_selector``; when it names an entry of
    ``scenarios``, the keyword text, region code and timeframe label held in
    session state are overwritten from that entry. An empty or unknown
    selection leaves the state untouched.
    """
    chosen = st.session_state.scenario_selector
    if not chosen or chosen not in scenarios:
        return
    preset = scenarios[chosen]
    # The scenario's dictionary key doubles as the keyword-box text.
    st.session_state.keywords_input = chosen
    st.session_state.geo_selection = preset["geo_code"]
    st.session_state.timeframe_selection = preset["timeframe_key"]
# Sidebar widgets. The keyword box and the scenario dropdown are bound to
# session state by key; choosing a scenario fires the callback, which rewrites
# the state that the region/timeframe defaults below are derived from.
keywords_input = st.sidebar.text_input("Enter Keywords", key="keywords_input")
st.sidebar.selectbox(
    "Or, select a popular comparison",
    options=scenario_options,
    key='scenario_selector',
    on_change=update_state_from_scenario,
    help="Selecting an option will pre-fill the controls and use reliable offline data.",
)

# Region picker: codes are the option values, names are what gets displayed.
country_list = ['', 'US', 'GB', 'CA', 'AU', 'DE', 'FR', 'IN', 'JP', 'BR', 'ZA']
country_names = [
    'Worldwide', 'United States', 'United Kingdom', 'Canada', 'Australia',
    'Germany', 'France', 'India', 'Japan', 'Brazil', 'South Africa',
]
country_dict = dict(zip(country_list, country_names))
geo_keys = list(country_dict.keys())
# Unknown stored code -> first entry (Worldwide).
if st.session_state.geo_selection in geo_keys:
    geo_default_index = geo_keys.index(st.session_state.geo_selection)
else:
    geo_default_index = 0
geo = st.sidebar.selectbox(
    "Select Region",
    options=country_list,
    format_func=lambda code: country_dict[code],
    index=geo_default_index,
)

# Timeframe picker: display label -> value passed on as the request's date range.
timeframe_options = {
    "Past Hour": "now 1-H",
    "Past 4 Hours": "now 4-H",
    "Past Day": "now 1-d",
    "Past 7 Days": "now 7-d",
    "Past 30 Days": "today 1-m",
    "Past 90 Days": "today 3-m",
    "Past 12 Months": "today 12-m",
    "Past 5 Years": "today 5-y",
    "All Time (Since 2004)": "all",
}
timeframe_keys = list(timeframe_options.keys())
# Unknown stored label -> position 7, which is "Past 5 Years" in the table above.
if st.session_state.timeframe_selection in timeframe_keys:
    timeframe_default_index = timeframe_keys.index(st.session_state.timeframe_selection)
else:
    timeframe_default_index = 7
timeframe_key = st.sidebar.selectbox(
    "Select Timeframe",
    options=timeframe_keys,
    index=timeframe_default_index,
)
timeframe = timeframe_options[timeframe_key]
# --- Main Dashboard ---
# Flow: resolve the keyword list -> load data (offline CSVs for a known
# scenario, SerpApi otherwise) -> render three tabs (trend analysis, Prophet
# forecast, Gemini-backed chat).
keywords_str = keywords_input
if keywords_str in scenarios:
    keywords = scenarios[keywords_str]["keywords"]
else:
    # Free-text input: comma separated, blanks dropped, capped at five terms.
    keywords = [k.strip() for k in keywords_str.split(',') if k.strip()][:5]

if not keywords:
    st.info("⬅️ Please enter keywords or select a popular comparison from the sidebar.")
else:
    st.header(f"Analysis for: {', '.join(keywords)}")
    tab_names = ["📈 Trend Analysis", "🔮 Future Forecast", "🤖 AI Co-pilot"]
    tab1, tab2, tab3 = st.tabs(tab_names)

    trends_data = None
    if keywords_str in scenarios:
        with st.spinner(f"Loading pre-configured analysis for '{keywords_str}'..."):
            trends_data = load_all_offline_data(scenarios[keywords_str])
    else:
        serpapi_key = os.environ.get("SERPAPI_API_KEY")  # Use os.environ.get for deployment
        if not serpapi_key:
            # Fallback for local dev using secrets.toml; a missing secrets file
            # raises, which is treated the same as a missing key.
            try:
                serpapi_key = st.secrets.get("SERPAPI_API_KEY")
            except Exception:
                serpapi_key = None
        if not serpapi_key:
            st.error("❌ SerpApi API Key not found in secrets.")
        else:
            with st.spinner(f"Fetching live data for '{keywords_str}'..."):
                trends_data = fetch_data_from_serpapi(serpapi_key, keywords, timeframe, geo)

    if trends_data:
        interest_df = trends_data.get("interest_over_time", pd.DataFrame())
        if not interest_df.empty:
            for col in interest_df.columns:
                if col != 'Date':
                    interest_df[col] = pd.to_numeric(interest_df[col], errors='coerce')

        with tab1:
            st.subheader("Search Interest Over Time")
            if not interest_df.empty:
                fig = px.line(interest_df, x=interest_df.index, y=interest_df.columns)
                st.plotly_chart(fig, use_container_width=True)

                st.subheader("📊 Overall Interest Share")
                interest_sum = interest_df.sum().reset_index()
                interest_sum.columns = ['keyword', 'total_interest']
                fig_pie = px.pie(interest_sum, names='keyword', values='total_interest')
                st.plotly_chart(fig_pie, use_container_width=True)

                st.markdown("---")
                st.subheader("🌍 Interest by Region")
                region_df = trends_data.get("interest_by_region")
                if region_df is not None and not region_df.empty:
                    fig_map = px.choropleth(
                        region_df,
                        locations=region_df.index,
                        locationmode='country names',
                        color='Interest',
                        hover_name=region_df.index,
                        color_continuous_scale=px.colors.sequential.Plasma,
                    )
                    st.plotly_chart(fig_map, use_container_width=True)
                    st.dataframe(region_df.sort_values(by="Interest", ascending=False), use_container_width=True)
                else:
                    st.warning("No regional data available.")

                # Monthly means feed the seasonality, YoY and decomposition
                # sections, so they are computed once.
                # NOTE(review): the 'M' alias is deprecated in recent pandas in
                # favour of 'ME'; kept because the pinned pandas version is not
                # visible here.
                monthly_means = interest_df.resample('M').mean()
                # Only offer keywords that actually have a data column;
                # indexing with a missing one would raise KeyError.
                selectable_keywords = [k for k in keywords if k in interest_df.columns]

                st.markdown("---")
                st.subheader("📅 Keyword Seasonality Analysis")
                monthly_df = monthly_means.copy()
                monthly_df['month'] = monthly_df.index.month
                seasonal_df = monthly_df.groupby('month')[interest_df.columns].mean().reset_index()
                for col in interest_df.columns:
                    col_min = seasonal_df[col].min()
                    col_range = seasonal_df[col].max() - col_min
                    if col_range != 0:
                        # Min-max scale each keyword to 0-100 across the calendar months.
                        seasonal_df[col] = (seasonal_df[col] - col_min) / col_range * 100
                    else:
                        seasonal_df[col] = 0
                seasonal_df['month'] = seasonal_df['month'].apply(
                    lambda x: pd.to_datetime(str(x), format='%m').strftime('%B')
                )
                seasonal_df.set_index('month', inplace=True)
                fig_heatmap = px.imshow(
                    seasonal_df.T,
                    labels=dict(x="Month", y="Keyword", color="Normalized Interest"),
                    aspect="auto",
                    color_continuous_scale="Viridis",
                )
                st.plotly_chart(fig_heatmap, use_container_width=True)

                st.markdown("---")
                st.subheader("🔬 Year-over-Year Growth Analysis")
                keyword_for_yoy = st.selectbox(
                    "Select keyword for YoY analysis", options=selectable_keywords, key='yoy_select'
                )
                if keyword_for_yoy:
                    monthly_yoy_df = monthly_means[[keyword_for_yoy]].copy()
                    # Change against the same month one year earlier.
                    monthly_yoy_df['YoY Growth (%)'] = monthly_yoy_df[keyword_for_yoy].pct_change(12) * 100
                    # The first 12 months have no prior year; show them as
                    # "n/a" instead of "+nan%".
                    # NOTE(review): Styler.applymap is deprecated in recent
                    # pandas in favour of Styler.map; kept for compatibility.
                    styled_yoy = monthly_yoy_df.style.format(
                        {'YoY Growth (%)': "{:+.2f}%"}, na_rep="n/a"
                    ).applymap(
                        lambda v: 'color: green;' if v > 0 else ('color: red;' if v < 0 else ''),
                        subset=['YoY Growth (%)'],
                    )
                    st.dataframe(styled_yoy, use_container_width=True)

                st.markdown("---")
                st.subheader("🔍 Trend Decomposition")
                keyword_to_decompose = st.selectbox(
                    "Select keyword to decompose", options=selectable_keywords, key='decomp_select'
                )
                if keyword_to_decompose:
                    monthly_decomp_df = monthly_means[keyword_to_decompose].dropna()
                    # Two full yearly cycles are required for period=12.
                    if len(monthly_decomp_df) >= 24:
                        decomposition = seasonal_decompose(monthly_decomp_df, model='additive', period=12)
                        fig_decomp = decomposition.plot()
                        fig_decomp.set_size_inches(10, 8)
                        st.pyplot(fig_decomp)
                        # Release the figure; otherwise every rerun leaves one open.
                        plt.close(fig_decomp)
                    else:
                        st.error("❌ Analysis Error: Decomposition requires at least 24 months of data.")
            else:
                st.warning("Could not fetch or load time-series data.")

        with tab2:
            st.subheader("Future Forecast with Prophet")
            keyword_to_forecast = st.selectbox("Select a keyword to forecast", options=keywords)
            if not interest_df.empty and keyword_to_forecast in interest_df.columns:
                if st.button(f"Generate 12-Month Forecast for '{keyword_to_forecast}'"):
                    prophet_df = interest_df[[keyword_to_forecast]].reset_index()
                    prophet_df.columns = ['ds', 'y']
                    # Prophet refuses to fit fewer than two non-NaN rows; warn
                    # instead of letting the exception take the app down.
                    if prophet_df['y'].notna().sum() < 2:
                        st.warning("Not enough data points to generate a forecast.")
                    else:
                        with st.spinner("Calculating future trends..."):
                            prophet_model = Prophet()
                            prophet_model.fit(prophet_df)
                            # NOTE(review): 365 daily steps; the table below
                            # therefore lists the last 12 days, not 12 months —
                            # confirm which was intended.
                            future = prophet_model.make_future_dataframe(periods=365)
                            forecast = prophet_model.predict(future)
                        st.success("Forecast generated!")
                        fig_forecast = prophet_model.plot(forecast)
                        st.pyplot(fig_forecast)
                        plt.close(fig_forecast)
                        st.subheader("Forecast Data")
                        st.dataframe(forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(12))
            else:
                st.warning("Trend data must be loaded first.")

        with tab3:
            st.subheader("AI Marketing Co-pilot")
            google_api_key = os.environ.get("GOOGLE_API_KEY")
            if not google_api_key:
                try:
                    google_api_key = st.secrets.get("GOOGLE_API_KEY")
                except Exception:
                    google_api_key = None
            if not google_api_key:
                st.info("Please add your Google AI API key to your secrets.")
            else:
                genai.configure(api_key=google_api_key)
                model = genai.GenerativeModel('gemini-1.5-flash-latest')

                # Compact text summary of the loaded data, prepended to every prompt.
                data_summary = ""
                if trends_data:
                    summary_ot = trends_data.get("interest_over_time", pd.DataFrame())
                    if not summary_ot.empty:
                        data_summary += "Time-series summary:\n" + summary_ot.describe().to_string() + "\n\n"
                    summary_regions = trends_data.get("interest_by_region", pd.DataFrame())
                    if not summary_regions.empty:
                        # Sort first so the "Top 5" label is true regardless of
                        # the order the rows arrived in.
                        if 'Interest' in summary_regions.columns:
                            summary_regions = summary_regions.sort_values(by='Interest', ascending=False)
                        data_summary += "Top 5 regions:\n" + summary_regions.head().to_string() + "\n\n"

                if "messages" not in st.session_state:
                    st.session_state.messages = []
                # Replay the conversation so far on every rerun.
                for message in st.session_state.messages:
                    with st.chat_message(message["role"]):
                        st.markdown(message["content"])

                if prompt := st.chat_input("Ask about trends or request a marketing campaign..."):
                    st.session_state.messages.append({"role": "user", "content": prompt})
                    with st.chat_message("user"):
                        st.markdown(prompt)
                    full_prompt = f"You are a marketing analyst AI. Based on this data summary:\n{data_summary}\n\nUser's Question: '{prompt}'"
                    with st.chat_message("assistant"):
                        message_placeholder = st.empty()
                        try:
                            response = model.generate_content(full_prompt, stream=True)
                            full_response_text = ""
                            for chunk in response:
                                full_response_text += chunk.text
                                # Trailing block cursor while the answer streams in.
                                message_placeholder.markdown(full_response_text + "▌")
                            message_placeholder.markdown(full_response_text)
                            st.session_state.messages.append({"role": "assistant", "content": full_response_text})
                        except Exception as e:
                            st.error(f"An error occurred with the AI model: {e}")
    else:
        st.error("Could not load or fetch data.")