aifinancegen / app.py
mset's picture
Update app.py
654fd3c verified
import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import yfinance as yf
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import time
import logging
# Page configuration: browser-tab title and icon, wide layout, and the
# sidebar expanded when the page first loads.
st.set_page_config(
page_title="Financial Transformer Analysis",
page_icon="📈",
layout="wide",
initial_sidebar_state="expanded"
)
# Logging configuration: root logger at INFO level plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# The model classes are declared inside the loader below, so the network and
# the tokenizer are built together behind a single st.cache_resource call.
@st.cache_resource
def load_model_components():
    """Build the transformer model and its tokenizer (wrapped in st.cache_resource).

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is a freshly initialised
        ``FinancialTransformer`` (this function loads no trained weights, so
        its outputs are those of a randomly initialised network) and
        ``tokenizer`` is the ``distilbert-base-uncased`` tokenizer.
    """

    class MultiLayerSemanticExtractor(nn.Module):
        """Stack of Linear/LayerNorm/ReLU/Dropout layers with a linear head."""

        def __init__(self, input_dim: int, hidden_dims: list, output_dim: int):
            super().__init__()
            self.layers = nn.ModuleList()
            prev_dim = input_dim
            for hidden_dim in hidden_dims:
                self.layers.append(nn.Sequential(
                    nn.Linear(prev_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                ))
                prev_dim = hidden_dim
            self.output_layer = nn.Linear(prev_dim, output_dim)

        def forward(self, x):
            """Return ``(final_output, layer_outputs)`` for input ``x``."""
            layer_outputs = []
            for layer in self.layers:
                x = layer(x)
                layer_outputs.append(x)
            return self.output_layer(x), layer_outputs

    class FinancialTransformer(nn.Module):
        """Transformer encoder over text tokens fused with financial features."""

        def __init__(self, vocab_size=10000, d_model=512, nhead=8, num_layers=6,
                     feature_dim=6, semantic_dims=(256, 128, 64)):
            # ``semantic_dims`` defaults to a tuple so the default value is not
            # a mutable object shared between instances.
            super().__init__()
            self.d_model = d_model
            self.feature_dim = feature_dim
            self.embedding = nn.Embedding(vocab_size, d_model)
            # Learned positional table; supports sequences of up to 1000 tokens.
            self.pos_encoding = nn.Parameter(torch.randn(1000, d_model))
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
                dropout=0.1, batch_first=True
            )
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
            self.semantic_extractor = MultiLayerSemanticExtractor(
                input_dim=feature_dim,
                hidden_dims=list(semantic_dims),
                output_dim=d_model,
            )
            self.feature_projection = nn.Linear(d_model, d_model)
            self.price_predictor = nn.Linear(d_model, 1)
            self.trend_classifier = nn.Linear(d_model, 3)
            self.volatility_predictor = nn.Linear(d_model, 1)

        def forward(self, text_tokens, financial_features, attention_mask=None):
            """Run the model.

            Args:
                text_tokens: 2-D tensor of token ids, ``(batch, seq_len)``.
                financial_features: tensor whose last dimension is
                    ``feature_dim``; a 2-D ``(batch, feature_dim)`` input is
                    broadcast over every token position.
                attention_mask: optional ``(batch, seq_len)`` mask in the
                    Hugging Face convention (non-zero = real token,
                    zero = padding).

            Returns:
                dict with keys ``price_change``, ``trend``, ``volatility``,
                ``semantic_layers`` and ``transformer_output``.
            """
            _, seq_len = text_tokens.shape
            text_emb = self.embedding(text_tokens)
            # Broadcasting over the batch dimension replaces the explicit repeat.
            text_emb = text_emb + self.pos_encoding[:seq_len].unsqueeze(0)

            financial_emb, semantic_layers = self.semantic_extractor(financial_features)
            financial_emb = self.feature_projection(financial_emb)
            if financial_emb.dim() == 2:
                # One feature vector per sample: share it across all positions.
                financial_emb = financial_emb.unsqueeze(1)
            combined_emb = text_emb + financial_emb

            if attention_mask is not None:
                keep_mask = attention_mask.to(torch.bool)
                # PyTorch's key-padding mask marks the positions to IGNORE,
                # i.e. the inverse of the tokenizer's attention mask, and it
                # must be a bool (or float) tensor rather than int64.
                padding_mask = ~keep_mask
            else:
                keep_mask = None
                padding_mask = None

            transformer_output = self.transformer(
                combined_emb, src_key_padding_mask=padding_mask
            )

            if keep_mask is not None:
                weights = keep_mask.unsqueeze(-1).to(transformer_output.dtype)
                transformer_output = transformer_output * weights
                # Mean over real tokens only; the clamp avoids dividing by zero.
                pooled_output = transformer_output.sum(1) / weights.sum(1).clamp(min=1.0)
            else:
                pooled_output = transformer_output.mean(1)

            return {
                'price_change': self.price_predictor(pooled_output),
                'trend': self.trend_classifier(pooled_output),
                'volatility': self.volatility_predictor(pooled_output),
                'semantic_layers': semantic_layers,
                'transformer_output': transformer_output,
            }

    # Load the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # extract_semantic_features() emits 7 columns per row: the normalized close
    # plus the six indicators from calculate_technical_indicators() (sma_20,
    # sma_50, rsi, volatility, volume_ratio, price_change). The model input
    # width has to match, otherwise the first Linear layer rejects the input.
    feature_dim = 7

    # Initialise the model.
    model = FinancialTransformer(
        vocab_size=tokenizer.vocab_size,
        d_model=512,
        nhead=8,
        num_layers=6,
        feature_dim=feature_dim,
        semantic_dims=(256, 128, 64),
    )
    return model, tokenizer
def calculate_technical_indicators(data):
    """Compute technical indicators from the ``Close`` and ``Volume`` columns.

    Returns a dict of Series keyed, in this order, by ``sma_20``, ``sma_50``,
    ``rsi``, ``volatility``, ``volume_ratio`` and ``price_change``. Values
    that are undefined during a rolling warm-up are filled with a neutral
    default (overall mean/std, 50 for RSI, 1 for the volume ratio, 0 for the
    percentage change).
    """
    close = data['Close']
    volume = data['Volume']

    def _sma(window):
        # Simple moving average, falling back to the overall mean.
        return close.rolling(window=window).mean().fillna(close.mean())

    # RSI building blocks: 14-period averages of up-moves and down-moves.
    step = close.diff()
    avg_gain = step.where(step > 0, 0).rolling(window=14).mean()
    avg_loss = (-step.where(step < 0, 0)).rolling(window=14).mean()
    rsi_complement = (100 / (1 + avg_gain / avg_loss)).fillna(50)

    rolling_volume = volume.rolling(window=20).mean()

    return {
        'sma_20': _sma(20),
        'sma_50': _sma(50),
        'rsi': 100 - rsi_complement,
        'volatility': close.rolling(window=20).std().fillna(close.std()),
        'volume_ratio': (volume / rolling_volume).fillna(1),
        'price_change': close.pct_change().fillna(0),
    }
def extract_semantic_features(data, indicators):
    """Assemble the per-row feature matrix.

    The first column is the z-scored ``Close`` price. One column follows for
    each entry of ``indicators`` in its iteration order: ``rsi`` is mapped
    with ``(value - 50) / 50`` and every other indicator is z-scored.

    Returns:
        numpy.ndarray with one row per row of ``data`` and
        ``1 + len(indicators)`` columns.
    """
    def _zscore(series):
        # A small epsilon keeps a constant series from dividing by zero.
        return (series - series.mean()) / (series.std() + 1e-8)

    columns = [_zscore(data['Close']).values]
    columns.extend(
        ((series - 50) / 50 if name == 'rsi' else _zscore(series)).values
        for name, series in indicators.items()
    )
    return np.column_stack(columns)
def create_market_context(symbol, data):
    """Build a one-sentence textual summary of the latest bar.

    Args:
        symbol: ticker symbol used in the sentence.
        data: DataFrame with ``Close``, ``Volume``, ``High`` and ``Low``
            columns; the last two rows are compared.

    Returns:
        str: a generic sentence when fewer than two rows are available or the
        percentage change cannot be computed, otherwise a sentence with the
        direction and size of the move, the last price, volume, high and low.
    """
    fallback = f"Stock {symbol} trading data available."
    if len(data) < 2:
        return fallback

    latest = data.iloc[-1]
    prev = data.iloc[-2]
    change = ((latest['Close'] - prev['Close']) / prev['Close']) * 100
    if pd.isna(change):
        # Missing prices: do not describe a move that cannot be measured.
        return fallback

    if change > 0:
        movement = f"increased by {abs(change):.2f}%"
    elif change < 0:
        movement = f"decreased by {abs(change):.2f}%"
    else:
        # A flat close used to be reported as "decreased by 0.00%".
        movement = "remained unchanged"

    # Selecting a row from a mixed int/float frame upcasts it to float, which
    # rendered the volume as e.g. "1,500.0"; show it as a whole number.
    volume = latest['Volume']
    if pd.notna(volume):
        volume = int(volume)

    return (
        f"Stock {symbol} has {movement} "
        f"trading at ${latest['Close']:.2f} with volume {volume:,}. "
        f"High: ${latest['High']:.2f}, Low: ${latest['Low']:.2f}."
    )
def analyze_symbol(symbol, model, tokenizer):
    """Fetch recent data for ``symbol`` and run the model on the latest bar.

    Args:
        symbol: ticker symbol passed to ``yf.Ticker``.
        model: network returned by ``load_model_components``.
        tokenizer: tokenizer returned by ``load_model_components``.

    Returns:
        dict with the current price, predicted price change, trend label and
        confidence, predicted volatility, the textual market context, the raw
        data and the indicators; ``None`` when no data is returned or when
        any step fails (the error is logged and shown in the UI).
    """
    try:
        # Fetch 1-minute bars for the last 5 days.
        ticker = yf.Ticker(symbol)
        data = ticker.history(period="5d", interval="1m")
        if data.empty:
            return None

        indicators = calculate_technical_indicators(data)
        features = extract_semantic_features(data, indicators)

        # Fail with an explicit message instead of an opaque matmul shape
        # error when the feature matrix and the model input width disagree.
        expected_dim = getattr(model, 'feature_dim', None)
        if expected_dim is not None and features.shape[1] != expected_dim:
            raise ValueError(
                f"feature matrix has {features.shape[1]} columns but the model "
                f"expects {expected_dim}"
            )

        context = create_market_context(symbol, data)
        tokens = tokenizer(
            context, padding=True, truncation=True,
            max_length=512, return_tensors="pt"
        )

        # Only the most recent row of features is fed to the model.
        financial_features = torch.as_tensor(features[-1:], dtype=torch.float32)

        model.eval()
        with torch.no_grad():
            predictions = model(
                tokens['input_ids'],
                financial_features,
                attention_mask=tokens['attention_mask']
            )

        # Interpret the raw outputs.
        trend_probs = torch.softmax(predictions['trend'], dim=1)
        trend_labels = ['Down', 'Stable', 'Up']

        return {
            'symbol': symbol,
            'current_price': data['Close'].iloc[-1],
            'predicted_price_change': predictions['price_change'].item(),
            'predicted_trend': trend_labels[trend_probs.argmax().item()],
            'trend_confidence': trend_probs.max().item(),
            'predicted_volatility': predictions['volatility'].item(),
            'market_context': context,
            'data': data,
            'indicators': indicators
        }
    except Exception as e:
        # UI boundary: keep the app running, but record the full traceback
        # instead of discarding it.
        logger.exception("Analysis failed for %s", symbol)
        st.error(f"Errore nell'analisi di {symbol}: {str(e)}")
        return None
def create_price_chart(data, symbol):
    """Build a two-row figure: candlesticks on top, volume bars below.

    Args:
        data: DataFrame with ``Open``, ``High``, ``Low``, ``Close`` and
            ``Volume`` columns; its index is used as the x axis.
        symbol: ticker symbol used in the titles and the trace name.

    Returns:
        The plotly figure.
    """
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.1,
        subplot_titles=(f'{symbol} Price', 'Volume'),
        # row_heights is read top-to-bottom, so the price panel gets 70%.
        # The legacy row_width argument is read bottom-to-top, which gave the
        # volume panel the larger share.
        row_heights=[0.7, 0.3]
    )

    # Candlestick
    fig.add_trace(
        go.Candlestick(
            x=data.index,
            open=data['Open'],
            high=data['High'],
            low=data['Low'],
            close=data['Close'],
            name=symbol
        ),
        row=1, col=1
    )

    # Volume
    fig.add_trace(
        go.Bar(
            x=data.index,
            y=data['Volume'],
            name='Volume',
            marker_color='rgba(0,100,80,0.6)'
        ),
        row=2, col=1
    )

    fig.update_layout(
        title=f'{symbol} Real-Time Analysis',
        xaxis_title='Time',
        yaxis_title='Price ($)',
        height=600,
        showlegend=False
    )
    return fig
def create_indicators_chart(data, indicators):
    """Build a 2x2 figure with RSI, moving averages, volatility and volume ratio.

    Args:
        data: DataFrame whose index is the x axis and whose ``Close`` column
            is drawn alongside the moving averages.
        indicators: dict holding the ``rsi``, ``sma_20``, ``sma_50``,
            ``volatility`` and ``volume_ratio`` series.

    Returns:
        The plotly figure.
    """
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('RSI', 'Moving Averages', 'Volatility', 'Volume Ratio')
    )
    x_values = data.index

    # (row, col, series, trace name, line colour or None for the default)
    line_specs = [
        (1, 1, indicators['rsi'], 'RSI', None),
        (1, 2, data['Close'], 'Close', 'blue'),
        (1, 2, indicators['sma_20'], 'SMA 20', 'orange'),
        (1, 2, indicators['sma_50'], 'SMA 50', 'red'),
        (2, 1, indicators['volatility'], 'Volatility', None),
        (2, 2, indicators['volume_ratio'], 'Volume Ratio', None),
    ]
    for row, col, series, label, colour in line_specs:
        extra = {} if colour is None else {'line': dict(color=colour)}
        fig.add_trace(
            go.Scatter(x=x_values, y=series, name=label, **extra),
            row=row, col=col
        )

    # Dashed guide lines: RSI thresholds at 70 and 30, volume ratio at 1.
    guide_specs = [
        (1, 1, 70, "red"),
        (1, 1, 30, "green"),
        (2, 2, 1, "gray"),
    ]
    for row, col, level, colour in guide_specs:
        fig.add_hline(y=level, line_dash="dash", line_color=colour, row=row, col=col)

    fig.update_layout(height=600, showlegend=False)
    return fig
# Main interface
def main():
    """Render the Streamlit page and run the analysis loop on request.

    The sidebar collects the symbols and display options. When the start
    button is pressed, every selected symbol is analysed a fixed number of
    times, waiting ``update_interval`` seconds between passes. Each pass
    replaces the previous pass's table and charts.
    """
    st.title("📈 Financial Transformer Real-Time Analysis")
    st.markdown("---")

    # Sidebar
    st.sidebar.header("⚙️ Configuration")

    # Symbol selection (copied so the widget's own list is not mutated).
    popular_symbols = ['AAPL', 'GOOGL', 'MSFT', 'TSLA', 'AMZN', 'META', 'NVDA']
    selected_symbols = list(st.sidebar.multiselect(
        "Select Symbols",
        popular_symbols,
        default=['AAPL', 'GOOGL', 'MSFT']
    ))

    # Custom symbol: normalised, and skipped when it is already selected so
    # the same ticker is not analysed twice.
    custom_symbol = st.sidebar.text_input("Custom Symbol (optional)").strip().upper()
    if custom_symbol and custom_symbol not in selected_symbols:
        selected_symbols.append(custom_symbol)

    # Parameters
    st.sidebar.subheader("Parameters")
    update_interval = st.sidebar.slider("Update Interval (seconds)", 30, 300, 60)
    show_charts = st.sidebar.checkbox("Show Charts", True)
    show_indicators = st.sidebar.checkbox("Show Technical Indicators", True)

    # Load the model.
    with st.spinner("Loading model..."):
        model, tokenizer = load_model_components()

    # Analysis button. Its state is captured here so the About/Disclaimer
    # sections below are rendered right away, not only after the (long)
    # analysis loop finishes or never on an early return.
    start_requested = st.sidebar.button("🚀 Start Analysis")

    # Information
    st.sidebar.markdown("---")
    st.sidebar.markdown("### ℹ️ About")
    st.sidebar.markdown("""
Questo strumento utilizza un transformer multi-layer per analizzare
dati finanziari in tempo reale e generare predizioni.
**Features:**
- Analisi semantica multi-layer
- Indicatori tecnici avanzati
- Predizioni trend e volatilità
- Visualizzazioni interattive
""")

    # Disclaimer
    st.sidebar.markdown("---")
    st.sidebar.markdown("### ⚠️ Disclaimer")
    st.sidebar.markdown("""
**ATTENZIONE**: Questo strumento è solo per scopi educativi.
Non costituisce consulenza finanziaria. Gli investimenti comportano rischi.
""")

    if not start_requested:
        return
    if not selected_symbols:
        st.error("Please select at least one symbol")
        return

    # Placeholders whose content is replaced on every pass, so the page does
    # not grow with each iteration.
    results_placeholder = st.empty()
    charts_placeholder = st.empty()

    max_iterations = 10  # Limited for the demo
    for iteration in range(max_iterations):
        # Analyse every symbol.
        results = []
        for symbol in selected_symbols:
            with st.spinner(f"Analyzing {symbol}..."):
                result = analyze_symbol(symbol, model, tokenizer)
            if result:
                results.append(result)

        with results_placeholder.container():
            st.subheader(f"Analysis Iteration {iteration + 1}")
            if results:
                # Show the results as a table.
                results_df = pd.DataFrame([{
                    'Symbol': r['symbol'],
                    'Current Price': f"${r['current_price']:.2f}",
                    'Predicted Change': f"{r['predicted_price_change']:.4f}",
                    'Trend': r['predicted_trend'],
                    'Confidence': f"{r['trend_confidence']:.2f}",
                    'Volatility': f"{r['predicted_volatility']:.4f}"
                } for r in results])
                st.table(results_df)

        if results and (show_charts or show_indicators):
            with charts_placeholder.container():
                # Price charts, one column per symbol.
                if show_charts:
                    cols = st.columns(len(results))
                    for col, result in zip(cols, results):
                        with col:
                            st.plotly_chart(
                                create_price_chart(result['data'], result['symbol']),
                                use_container_width=True
                            )
                # Technical indicators.
                if show_indicators:
                    st.subheader("📊 Technical Indicators")
                    for result in results:
                        st.subheader(f"{result['symbol']} Indicators")
                        st.plotly_chart(
                            create_indicators_chart(result['data'], result['indicators']),
                            use_container_width=True
                        )
        else:
            # Nothing to draw in this pass: clear the previous charts.
            charts_placeholder.empty()

        # Wait for the next update, except after the last pass.
        if iteration < max_iterations - 1:
            time.sleep(update_interval)


if __name__ == "__main__":
    main()