Spaces:
Configuration error
Configuration error
import streamlit as st | |
import torch | |
import torch.nn as nn | |
from transformers import AutoTokenizer | |
import yfinance as yf | |
import pandas as pd | |
import numpy as np | |
from datetime import datetime, timedelta | |
import plotly.graph_objects as go | |
import plotly.express as px | |
from plotly.subplots import make_subplots | |
import time | |
import logging | |
# --- Streamlit page setup ------------------------------------------------------
# Keep this the first Streamlit call of the script (older Streamlit releases
# reject set_page_config once any other st.* command has run).
_PAGE_SETTINGS = {
    "page_title": "Financial Transformer Analysis",
    "page_icon": "📈",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)

# --- Logging -------------------------------------------------------------------
# Root logger at INFO; this module writes through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Model components: the network architecture is defined inline below (nothing is imported).
@st.cache_resource
def load_model_components():
    """Build the FinancialTransformer model and its tokenizer, cached across reruns.

    ``st.cache_resource`` makes Streamlit build the (model, tokenizer) pair once
    per server process and reuse it on every rerun, instead of re-downloading the
    tokenizer and re-allocating the network on each widget interaction.

    NOTE: no checkpoint or state dict is loaded, so the weights are randomly
    initialised. The model's outputs are untrained placeholders, not forecasts.

    Returns:
        tuple ``(model, tokenizer)``: a ``FinancialTransformer`` in eval mode and
        the ``distilbert-base-uncased`` tokenizer.
    """

    class MultiLayerSemanticExtractor(nn.Module):
        """Stack of Linear -> LayerNorm -> ReLU -> Dropout blocks plus a linear head.

        ``forward`` returns ``(head_output, [output of every hidden block])``.
        """

        def __init__(self, input_dim: int, hidden_dims: list, output_dim: int):
            super().__init__()
            self.layers = nn.ModuleList()
            prev_dim = input_dim
            for hidden_dim in hidden_dims:
                self.layers.append(nn.Sequential(
                    nn.Linear(prev_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1)
                ))
                prev_dim = hidden_dim
            self.output_layer = nn.Linear(prev_dim, output_dim)

        def forward(self, x):
            layer_outputs = []
            for layer in self.layers:
                x = layer(x)
                layer_outputs.append(x)
            final_output = self.output_layer(x)
            return final_output, layer_outputs

    class FinancialTransformer(nn.Module):
        """Transformer encoder over text tokens, conditioned on financial features.

        The financial feature vector is embedded by ``MultiLayerSemanticExtractor``
        and added to every token embedding. The pooled encoder output feeds three
        heads: price change (1), trend logits (3) and volatility (1).
        """

        def __init__(self, vocab_size=10000, d_model=512, nhead=8, num_layers=6,
                     feature_dim=7, semantic_dims=(256, 128, 64)):
            super().__init__()
            self.d_model = d_model
            self.feature_dim = feature_dim
            self.embedding = nn.Embedding(vocab_size, d_model)
            # Learned positional table; supports sequences up to 1000 tokens.
            self.pos_encoding = nn.Parameter(torch.randn(1000, d_model))
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
                dropout=0.1, batch_first=True
            )
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
            self.semantic_extractor = MultiLayerSemanticExtractor(
                input_dim=feature_dim, hidden_dims=list(semantic_dims), output_dim=d_model
            )
            self.feature_projection = nn.Linear(d_model, d_model)
            self.price_predictor = nn.Linear(d_model, 1)
            self.trend_classifier = nn.Linear(d_model, 3)
            self.volatility_predictor = nn.Linear(d_model, 1)

        def forward(self, text_tokens, financial_features, attention_mask=None):
            """Run one forward pass.

            Args:
                text_tokens: token ids of shape (batch, seq_len).
                financial_features: input to the semantic extractor; presumably
                    (batch, feature_dim). A 2-D embedding is broadcast over the
                    sequence below.
                attention_mask: optional Hugging Face style mask of shape
                    (batch, seq_len): 1 for real tokens, 0 for padding.

            Returns:
                dict with ``price_change``, ``trend`` (logits), ``volatility``,
                ``semantic_layers`` and ``transformer_output``.
            """
            batch_size, seq_len = text_tokens.shape
            text_emb = self.embedding(text_tokens)
            pos_emb = self.pos_encoding[:seq_len].unsqueeze(0).repeat(batch_size, 1, 1)
            text_emb = text_emb + pos_emb

            financial_emb, semantic_layers = self.semantic_extractor(financial_features)
            financial_emb = self.feature_projection(financial_emb)
            if financial_emb.dim() == 2:
                # One feature vector per sample: add it to every token position.
                financial_emb = financial_emb.unsqueeze(1).repeat(1, seq_len, 1)
            combined_emb = text_emb + financial_emb

            if attention_mask is None:
                transformer_output = self.transformer(combined_emb)
                pooled_output = transformer_output.mean(1)
            else:
                # nn.TransformerEncoder wants src_key_padding_mask to be True at the
                # positions to IGNORE and only accepts bool/float masks. A Hugging
                # Face attention_mask is an integer tensor with 1 for REAL tokens, so
                # it must be inverted and converted rather than passed through.
                padding_mask = attention_mask == 0
                transformer_output = self.transformer(
                    combined_emb, src_key_padding_mask=padding_mask
                )
                keep = attention_mask.unsqueeze(-1).to(transformer_output.dtype)
                transformer_output = transformer_output * keep
                # Mean over real tokens only; the clamp avoids 0/0 on an all-padding row.
                pooled_output = transformer_output.sum(1) / keep.sum(1).clamp(min=1.0)

            predictions = {
                'price_change': self.price_predictor(pooled_output),
                'trend': self.trend_classifier(pooled_output),
                'volatility': self.volatility_predictor(pooled_output),
                'semantic_layers': semantic_layers,
                'transformer_output': transformer_output
            }
            return predictions

    # Tokenizer for the textual market context.
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    # Fallback for tokenizers that ship without a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = FinancialTransformer(
        vocab_size=tokenizer.vocab_size,
        d_model=512,
        nhead=8,
        num_layers=6,
        # extract_semantic_features produces 7 columns per bar: the normalised
        # close price plus the 6 series from calculate_technical_indicators.
        # A value of 6 makes the extractor's first Linear layer fail with a
        # shape mismatch on every analysis.
        feature_dim=7,
        semantic_dims=[256, 128, 64]
    )
    model.eval()
    return model, tokenizer
def calculate_technical_indicators(data):
    """Compute the technical indicators used as model features and in the charts.

    Every series is causal: a row only depends on that row and earlier rows.
    Warm-up rows no longer borrow statistics computed over the whole series,
    which leaked future prices into the first rows.

    Args:
        data: DataFrame with ``Close`` and ``Volume`` columns, assumed to be
            ordered oldest-to-newest.

    Returns:
        dict of pandas Series aligned with ``data.index``, in this key order:
        ``sma_20``, ``sma_50``, ``rsi``, ``volatility``, ``volume_ratio``,
        ``price_change``. The order matters: feature columns are built from it.
    """
    close = data['Close']
    indicators = {}

    # Simple moving averages. Until a full window is available, use the mean of
    # the bars seen so far instead of the mean of the entire series.
    indicators['sma_20'] = close.rolling(window=20, min_periods=1).mean()
    indicators['sma_50'] = close.rolling(window=50, min_periods=1).mean()

    # 14-period RSI (simple-average variant).
    delta = close.diff()
    gain = delta.where(delta > 0, 0).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    # Fill the finished RSI (not the intermediate 100/(1+rs) term): warm-up rows
    # and 0/0 windows read as neutral 50.
    indicators['rsi'] = (100 - 100 / (1 + rs)).fillna(50)

    # Rolling volatility. A sample std needs two points, so the first bar gets 0.
    indicators['volatility'] = close.rolling(window=20, min_periods=2).std().fillna(0.0)

    # Volume relative to its 20-bar average; neutral 1 during warm-up.
    indicators['volume_ratio'] = (data['Volume'] / data['Volume'].rolling(window=20).mean()).fillna(1)

    # Bar-to-bar return.
    indicators['price_change'] = close.pct_change().fillna(0)
    return indicators
def extract_semantic_features(data, indicators):
    """Stack the normalised close price and every indicator into one matrix.

    Column 0 is the z-scored ``Close``. Each following column is one indicator,
    in the dict's iteration order. ``rsi`` is mapped from [0, 100] to [-1, 1];
    every other indicator is z-scored. A 1e-8 term keeps constant series from
    dividing by zero.

    Args:
        data: DataFrame with a ``Close`` column.
        indicators: mapping of name -> pandas Series aligned with ``data``.

    Returns:
        numpy array of shape (len(data), 1 + len(indicators)).
    """
    def zscore(series):
        return (series - series.mean()) / (series.std() + 1e-8)

    columns = [zscore(data['Close']).values]
    columns.extend(
        ((series - 50) / 50 if name == 'rsi' else zscore(series)).values
        for name, series in indicators.items()
    )
    return np.column_stack(columns)
def create_market_context(symbol, data):
    """Describe the most recent bar of ``data`` as a short English sentence.

    Args:
        symbol: ticker shown in the text.
        data: DataFrame with ``Close``, ``Volume``, ``High`` and ``Low`` columns,
            assumed ordered oldest-to-newest.

    Returns:
        A sentence with the percent move from the previous bar, the last close,
        volume, high and low. With fewer than two rows, a generic placeholder.
    """
    if len(data) < 2:
        return f"Stock {symbol} trading data available."

    latest = data.iloc[-1]
    prev = data.iloc[-2]
    change = ((latest['Close'] - prev['Close']) / prev['Close']) * 100

    if change == 0:
        # Flat bars are common with 1-minute data; don't describe them as a decrease.
        movement = "has remained unchanged"
    else:
        direction = "increased" if change > 0 else "decreased"
        movement = f"has {direction} by {abs(change):.2f}%"

    # ',.0f' prints volume without a trailing '.0' even when the row was upcast
    # to float (as happens when iloc returns a mixed int/float row).
    return (
        f"Stock {symbol} {movement} "
        f"trading at ${latest['Close']:.2f} with volume {latest['Volume']:,.0f}. "
        f"High: ${latest['High']:.2f}, Low: ${latest['Low']:.2f}."
    )
def analyze_symbol(symbol, model, tokenizer):
    """Fetch recent market data for ``symbol`` and run the model on it.

    Downloads five days of 1-minute bars via yfinance, derives indicators and
    normalised features, and builds a short text context. The newest bar's
    feature vector and the tokenised context are then fed to ``model``.

    Args:
        symbol: ticker symbol passed to yfinance.
        model: the model returned by ``load_model_components``.
        tokenizer: the matching Hugging Face tokenizer.

    Returns:
        A dict with ``symbol``, ``current_price``, ``predicted_price_change``,
        ``predicted_trend`` ('Down'/'Stable'/'Up'), ``trend_confidence``,
        ``predicted_volatility``, ``market_context``, ``data`` and
        ``indicators``. Returns ``None`` if no data comes back or any step
        fails. Failures are shown with ``st.error`` and logged with a traceback.
    """
    try:
        data = yf.Ticker(symbol).history(period="5d", interval="1m")
        if data.empty:
            # Nothing to analyse (e.g. an unrecognised custom ticker). Say so
            # instead of silently dropping the symbol from the results.
            st.warning(f"No data returned for {symbol}")
            return None

        indicators = calculate_technical_indicators(data)
        features = extract_semantic_features(data, indicators)
        context = create_market_context(symbol, data)

        tokens = tokenizer(
            context, padding=True, truncation=True,
            max_length=512, return_tensors="pt"
        )
        # Only the newest bar is scored; slicing keeps the batch axis: (1, n_features).
        financial_features = torch.as_tensor(features[-1:], dtype=torch.float32)

        model.eval()
        with torch.no_grad():
            predictions = model(
                tokens['input_ids'],
                financial_features,
                attention_mask=tokens['attention_mask']
            )

        trend_probs = torch.softmax(predictions['trend'], dim=1)
        trend_labels = ['Down', 'Stable', 'Up']
        return {
            'symbol': symbol,
            'current_price': data['Close'].iloc[-1],
            'predicted_price_change': predictions['price_change'].item(),
            'predicted_trend': trend_labels[trend_probs.argmax().item()],
            'trend_confidence': trend_probs.max().item(),
            'predicted_volatility': predictions['volatility'].item(),
            'market_context': context,
            'data': data,
            'indicators': indicators
        }
    except Exception as e:
        # UI boundary: keep the app running for the remaining symbols, but keep
        # the full traceback in the server log so failures can be diagnosed.
        logger.exception("Analysis failed for %s", symbol)
        st.error(f"Errore nell'analisi di {symbol}: {str(e)}")
        return None
def create_price_chart(data, symbol):
    """Build a two-row figure: candlesticks on top, volume bars underneath.

    Args:
        data: DataFrame indexed by timestamp with ``Open``, ``High``, ``Low``,
            ``Close`` and ``Volume`` columns.
        symbol: ticker used in the titles and as the candlestick trace name.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.1,
        subplot_titles=(f'{symbol} Price', 'Volume'),
        # row_heights is ordered top-to-bottom: price 70%, volume 30%.
        # The legacy row_width argument is ordered bottom-to-top, so
        # row_width=[0.7, 0.3] gave the larger share to the volume panel.
        row_heights=[0.7, 0.3]
    )

    fig.add_trace(
        go.Candlestick(
            x=data.index,
            open=data['Open'],
            high=data['High'],
            low=data['Low'],
            close=data['Close'],
            name=symbol
        ),
        row=1, col=1
    )

    fig.add_trace(
        go.Bar(
            x=data.index,
            y=data['Volume'],
            name='Volume',
            marker_color='rgba(0,100,80,0.6)'
        ),
        row=2, col=1
    )

    fig.update_layout(
        title=f'{symbol} Real-Time Analysis',
        yaxis_title='Price ($)',
        height=600,
        showlegend=False,
        # The candlestick's default range slider sits under the top subplot and
        # collides with the volume panel in this stacked layout.
        xaxis_rangeslider_visible=False
    )
    # The x-axes are shared, so label the bottom one, where the time ticks are drawn.
    fig.update_xaxes(title_text='Time', row=2, col=1)
    return fig
def create_indicators_chart(data, indicators):
    """Build a 2x2 grid of indicator plots.

    Layout: RSI with 30/70 guide lines (top-left); close plus SMA 20/50
    (top-right); rolling volatility (bottom-left); volume ratio with a guide
    line at 1 (bottom-right).

    Args:
        data: DataFrame with a ``Close`` column; its index is the x-axis.
        indicators: dict produced by ``calculate_technical_indicators``.

    Returns:
        A ``plotly.graph_objects.Figure`` with the legend hidden.
    """
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('RSI', 'Moving Averages', 'Volatility', 'Volume Ratio')
    )
    timeline = data.index

    # (row, col, y values, trace name, line colour or None for the default style)
    series_specs = [
        (1, 1, indicators['rsi'], 'RSI', None),
        (1, 2, data['Close'], 'Close', 'blue'),
        (1, 2, indicators['sma_20'], 'SMA 20', 'orange'),
        (1, 2, indicators['sma_50'], 'SMA 50', 'red'),
        (2, 1, indicators['volatility'], 'Volatility', None),
        (2, 2, indicators['volume_ratio'], 'Volume Ratio', None),
    ]
    for row, col, values, label, colour in series_specs:
        line_style = {} if colour is None else {'line': dict(color=colour)}
        fig.add_trace(
            go.Scatter(x=timeline, y=values, name=label, **line_style),
            row=row, col=col
        )

    # Guide lines go in only after the traces. By default add_hline with
    # row/col skips subplots that still have no data.
    guide_lines = [
        (70, "red", 1, 1),
        (30, "green", 1, 1),
        (1, "gray", 2, 2),
    ]
    for level, colour, row, col in guide_lines:
        fig.add_hline(y=level, line_dash="dash", line_color=colour, row=row, col=col)

    fig.update_layout(height=600, showlegend=False)
    return fig
# Main user interface

# Number of refresh cycles the demo analysis loop runs before stopping.
_DEMO_ITERATIONS = 10


def _render_results(results, show_charts, show_indicators):
    """Render the summary table and, optionally, price and indicator charts."""
    results_df = pd.DataFrame([{
        'Symbol': r['symbol'],
        'Current Price': f"${r['current_price']:.2f}",
        'Predicted Change': f"{r['predicted_price_change']:.4f}",
        'Trend': r['predicted_trend'],
        'Confidence': f"{r['trend_confidence']:.2f}",
        'Volatility': f"{r['predicted_volatility']:.4f}"
    } for r in results])
    st.table(results_df)

    if show_charts:
        for column, result in zip(st.columns(len(results)), results):
            with column:
                st.plotly_chart(
                    create_price_chart(result['data'], result['symbol']),
                    use_container_width=True
                )

    if show_indicators:
        st.subheader("📊 Technical Indicators")
        for result in results:
            st.subheader(f"{result['symbol']} Indicators")
            st.plotly_chart(
                create_indicators_chart(result['data'], result['indicators']),
                use_container_width=True
            )


def _render_sidebar_info():
    """Render the static About and Disclaimer sections at the bottom of the sidebar."""
    st.sidebar.markdown("---")
    st.sidebar.markdown("### ℹ️ About")
    st.sidebar.markdown("""
Questo strumento utilizza un transformer multi-layer per analizzare
dati finanziari in tempo reale e generare predizioni.

**Features:**
- Analisi semantica multi-layer
- Indicatori tecnici avanzati
- Predizioni trend e volatilità
- Visualizzazioni interattive
""")

    st.sidebar.markdown("---")
    st.sidebar.markdown("### ⚠️ Disclaimer")
    st.sidebar.markdown("""
**ATTENZIONE**: Questo strumento è solo per scopi educativi.
Non costituisce consulenza finanziaria. Gli investimenti comportano rischi.
""")


def main():
    """Streamlit entry point: sidebar configuration plus a periodic analysis loop."""
    st.title("📈 Financial Transformer Real-Time Analysis")
    st.markdown("---")

    # Sidebar: symbol selection
    st.sidebar.header("⚙️ Configuration")
    popular_symbols = ['AAPL', 'GOOGL', 'MSFT', 'TSLA', 'AMZN', 'META', 'NVDA']
    selected_symbols = st.sidebar.multiselect(
        "Select Symbols",
        popular_symbols,
        default=['AAPL', 'GOOGL', 'MSFT']
    )
    custom_symbol = st.sidebar.text_input("Custom Symbol (optional)").strip().upper()
    # Build a new list rather than appending to the widget's return value, and
    # drop duplicates so a custom symbol already picked above isn't analysed twice.
    extra = [custom_symbol] if custom_symbol else []
    symbols = list(dict.fromkeys(selected_symbols + extra))

    # Sidebar: parameters
    st.sidebar.subheader("Parameters")
    update_interval = st.sidebar.slider("Update Interval (seconds)", 30, 300, 60)
    show_charts = st.sidebar.checkbox("Show Charts", True)
    show_indicators = st.sidebar.checkbox("Show Technical Indicators", True)

    with st.spinner("Loading model..."):
        model, tokenizer = load_model_components()

    start = st.sidebar.button("🚀 Start Analysis")
    # Render the static sidebar sections now, so they are visible during the
    # long analysis loop and also when the symbol validation below bails out.
    _render_sidebar_info()

    if not start:
        return
    if not symbols:
        st.error("Please select at least one symbol")
        return

    # Each iteration redraws these two areas instead of appending a new copy of
    # every table and chart to the page.
    status_area = st.empty()   # spinners and per-symbol errors/warnings
    results_area = st.empty()  # table and charts of the latest iteration

    for iteration in range(_DEMO_ITERATIONS):
        with status_area.container():
            results = []
            for symbol in symbols:
                with st.spinner(f"Analyzing {symbol}..."):
                    result = analyze_symbol(symbol, model, tokenizer)
                if result:
                    results.append(result)

        with results_area.container():
            st.subheader(f"Analysis Iteration {iteration + 1}")
            if results:
                _render_results(results, show_charts, show_indicators)

        # No wait after the final iteration.
        if iteration < _DEMO_ITERATIONS - 1:
            time.sleep(update_interval)
# Script entry point. Streamlit executes the file as __main__, so this also
# runs under `streamlit run`.
if __name__ == "__main__":
    main()