# test / app.py
# ylingag's picture
# Update app.py
# 49394d0 verified
import html
import json
import os
import re

import gradio as gr
import nltk
import spacy
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, AutoModelForSeq2SeqLM
# Download the NLTK data needed for sentence tokenization.
# Older NLTK releases read the "punkt" resource; recent releases (3.9+) load
# "punkt_tab" in sent_tokenize instead, so make sure both are present.
for _nltk_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_nltk_resource}")
    except LookupError:
        nltk.download(_nltk_resource)
# JSON summary file that describes the best fine-tuned model.
SUMMARY_FILE = "training_summary.json"

# Class index -> sentiment label (assumed consistent with the previous files).
LABEL_MAP = dict(enumerate(("Negative", "Neutral", "Positive")))

# Display color used for each sentiment label.
COLOR_MAP = dict(Negative="red", Neutral="blue", Positive="green")
# Module-level handles for the models and NLP components.
# Everything starts out as None and is filled in by load_models_and_components().
loaded_model = loaded_tokenizer = None  # sentiment classifier and its tokenizer
best_model_summary = None  # "best_model_details" entry of the summary file
summarizer = None  # summarization pipeline
nlp = None  # spaCy pipeline, used for NER
def load_models_and_components():
    """Load the sentiment model, the summarizer and the NER pipeline into module globals.

    The sentiment model location is read from ``SUMMARY_FILE``: its
    ``best_model_details`` entry supplies ``best_model_path``, with a fallback
    to a top-level ``best_model_path`` key for the older summary format.

    The summarization pipeline and the spaCy model are optional: if either
    fails to load, a warning is printed and the corresponding global stays
    ``None``. Note that any error raised while loading the sentiment model
    aborts the function before the optional components are attempted.

    Raises:
        FileNotFoundError: If ``SUMMARY_FILE`` or the best-model path does not exist.
        ValueError: If the summary has no usable ``best_model_details`` entry.
        RuntimeError: If the sentiment tokenizer or model cannot be loaded.
    """
    global loaded_model, loaded_tokenizer, best_model_summary, summarizer, nlp

    # Load sentiment analysis model from training
    if not os.path.exists(SUMMARY_FILE):
        raise FileNotFoundError(f"Error: Could not find training summary file {SUMMARY_FILE}. Please run the fine-tuning and testing scripts first.")

    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(SUMMARY_FILE, "r", encoding="utf-8") as f:
        summary_data = json.load(f)

    if "best_model_details" not in summary_data or not summary_data["best_model_details"]:
        raise ValueError(f"Error: Best model information not found or incomplete in {SUMMARY_FILE}.")

    best_model_summary = summary_data["best_model_details"]
    best_model_path = best_model_summary.get("best_model_path")
    if not best_model_path:
        best_model_path = summary_data.get("best_model_path")  # Compatible with older format

    if not best_model_path or not os.path.exists(best_model_path):
        raise FileNotFoundError(f"Error: Best model path {best_model_path} not found or invalid.")

    # Use .get() so a summary without "model_name" cannot crash a log line.
    print(f"Loading sentiment model {best_model_summary.get('model_name', 'N/A')} from {best_model_path}...")
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(best_model_path)
        loaded_model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
        loaded_model.eval()  # Set to evaluation mode
        print("Sentiment model loaded successfully.")
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Failed to load sentiment model: {e}") from e

    # Load summarization model (optional)
    print("Loading summarization model...")
    try:
        summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
        print("Summarization model loaded successfully.")
    except Exception as e:
        print(f"Warning: Failed to load summarization model: {e}")
        print("Will continue without summarization capability.")
        summarizer = None

    # Load spaCy model for NER (Named Entity Recognition) (optional)
    print("Loading NER model...")
    try:
        # Download the model if it's not already installed
        if not spacy.util.is_package("en_core_web_sm"):
            spacy.cli.download("en_core_web_sm")
        nlp = spacy.load("en_core_web_sm")
        print("NER model loaded successfully.")
    except Exception as e:
        print(f"Warning: Failed to load NER model: {e}")
        print("Will continue without NER capability.")
        nlp = None
def predict_sentiment(text):
    """Predict sentiment for a single piece of text.

    Args:
        text: The text to classify. It is tokenized with truncation to at
            most 128 tokens.

    Returns:
        A ``(label, index)`` tuple. On success ``label`` is looked up in
        ``LABEL_MAP`` (``"Unknown (<index>)"`` for an unmapped index) and
        ``index`` is the argmax over the model logits. When the model is not
        loaded, the text is empty, or prediction fails, the first element is
        a message string and the second is ``None``.
    """
    # Compare against None explicitly: truthiness of model/tokenizer objects
    # may go through __len__ and is not a reliable "is it loaded" test.
    if loaded_model is None or loaded_tokenizer is None:
        return "Error: Model not loaded.", None
    if not text or not text.strip():
        return "Please enter text for analysis.", None
    try:
        inputs = loaded_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        # Run on the GPU when one is available; model and inputs must share a device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        loaded_model.to(device)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = loaded_model(**inputs)
        prediction_idx = torch.argmax(outputs.logits, dim=-1).item()
        sentiment = LABEL_MAP.get(prediction_idx, f"Unknown ({prediction_idx})")
        return sentiment, prediction_idx
    except Exception as e:
        print(f"Error during sentiment prediction: {e}")
        return f"Error: {str(e)}", None
def generate_summary(text):
    """Generate a summary for longer text.

    Args:
        text: The text to summarize.

    Returns:
        The summary text on success; otherwise a message string saying that
        the summarizer is unavailable, the text is shorter than 50
        characters, or summarization failed.
    """
    if summarizer is None:
        return "Summarization model not available."
    if not text or len(text.strip()) < 50:
        return "Text too short for summarization."
    try:
        # Cap the summary at roughly a quarter of the (word-count-capped) input.
        word_count = len(text.split())
        max_len = max(min(1024, word_count) // 4, 10)
        # min_length=30 is only feasible when the cap allows it; for short
        # inputs relax it so that min_length never exceeds max_length.
        min_len = 30 if max_len >= 30 else max_len // 2
        # BART has a max input length, so ask the pipeline to truncate the
        # input instead of failing on very long texts.
        summary = summarizer(
            text,
            max_length=max_len,
            min_length=min_len,
            do_sample=False,
            truncation=True,
        )
        return summary[0]['summary_text']
    except Exception as e:
        print(f"Error during summarization: {e}")
        return f"Summarization error: {str(e)}"
def identify_entities(text):
    """Extract location and organization names from *text* with the spaCy pipeline.

    Returns:
        A dict with keys ``"locations"`` (entities labelled ``GPE`` or
        ``LOC``) and ``"organizations"`` (entities labelled ``ORG``), each a
        sorted list without duplicates. A message string is returned instead
        when the NER model is unavailable, the text is empty, or processing
        fails.
    """
    global nlp
    if not nlp:
        return "NER model not available."
    if not (text and text.strip()):
        return "Please enter text for entity analysis."
    try:
        parsed = nlp(text)
        # Sets drop duplicates; sorted() turns them back into ordered lists.
        place_names = {span.text for span in parsed.ents if span.label_ in ("GPE", "LOC")}
        org_names = {span.text for span in parsed.ents if span.label_ == "ORG"}
        return {
            "locations": sorted(place_names),
            "organizations": sorted(org_names),
        }
    except Exception as err:
        print(f"Error during entity identification: {err}")
        return f"Entity identification error: {str(err)}"
def format_entities(entities):
    """Format identified entities as an HTML fragment for display.

    Args:
        entities: Either a message string (returned unchanged) or a dict with
            ``"locations"`` and ``"organizations"`` lists of entity names.

    Returns:
        An HTML string with a heading, the locations in red and the
        organizations in green; an empty list renders as "None identified".
        Entity names are HTML-escaped because they originate from user text.
    """
    if isinstance(entities, str):  # Error message
        return entities

    def _render(label, names, color):
        # One paragraph per entity kind, with each name in a colored span.
        if not names:
            return f"<p><b>{label}:</b> None identified</p>"
        spans = ", ".join(
            f"<span style='color: {color}'>{html.escape(name)}</span>" for name in names
        )
        return f"<p><b>{label}:</b> {spans}</p>"

    return (
        "<h3>Interested Parties</h3>"
        + _render("Locations", entities["locations"], "red")
        + _render("Organizations", entities["organizations"], "green")
    )
def analyze_text_sentiment_by_sentence(text):
    """Analyze the sentiment of each sentence in the text and format it with colors.

    The text is split with ``nltk.sent_tokenize``; sentences shorter than 3
    characters after stripping are skipped. Each remaining sentence is
    classified with ``predict_sentiment`` and wrapped in a ``<span>`` colored
    according to ``COLOR_MAP`` (black for any label not in the map).

    Returns:
        The concatenated HTML spans, or a message string when the text is
        empty, no sentence qualifies, or an error occurs.
    """
    if not text or not text.strip():
        return "Please enter text for analysis."
    try:
        colored_spans = []
        for sentence in nltk.sent_tokenize(text):
            if len(sentence.strip()) < 3:  # Skip very short sentences
                continue
            sentiment, _ = predict_sentiment(sentence)
            color = COLOR_MAP.get(sentiment, "black")
            # Escape the sentence: it is raw user input rendered as HTML, so
            # characters such as "&" or "<" must not be treated as markup.
            colored_spans.append(f"<span style='color: {color}'>{html.escape(sentence)}</span> ")
        if not colored_spans:
            return "No valid sentences found for analysis."
        return "".join(colored_spans)
    except Exception as e:
        print(f"Error during sentence-level sentiment analysis: {e}")
        return f"Error: {str(e)}"
def analyze_financial_text(text):
    """Run every analysis step on *text* and collect the results.

    Returns:
        A 3-tuple ``(sentiment_html, summary, entities_html)``. For empty or
        whitespace-only input, three fixed placeholder messages are returned
        without calling any model.
    """
    if not (text and text.strip()):
        return "Please enter text for analysis.", "No summary available.", "No entities identified."

    # Same order as before: summary, then per-sentence sentiment, then entities.
    summary_text = generate_summary(text)
    sentiment_html = analyze_text_sentiment_by_sentence(text)
    entities_html = format_entities(identify_entities(text))
    return sentiment_html, summary_text, entities_html
# Attempt to load every model once, when the app starts.
# A failure is only reported: the Gradio interface is still built below,
# with limited functionality.
try:
    load_models_and_components()
except Exception as startup_error:
    print(f"Initial model loading failed: {startup_error}")
# Build the Markdown shown in the "Model Information" panel of the interface.
if best_model_summary:
    model_name = best_model_summary.get("model_name", "N/A")
    accuracy = best_model_summary.get("accuracy_percent", "N/A")
    run_time = best_model_summary.get("run_time_sec", "N/A")
    hyperparams = best_model_summary.get("hyperparameters", {})
    model_info = "".join([
        "### Model Information\n",
        f"- **Model Name**: {model_name}\n",
        f"- **Model Accuracy**: {accuracy}%\n",
        "- **Description**: The model is trained and fine-tuned using the financial news dataset to improve its sensitivity in recognizing financial sentiment.\n",
        # Hyperparameters section
        "\n### Hyperparameters\n",
        f"- **Learning Rate**: {hyperparams.get('learning_rate', 'N/A')}\n",
        f"- **Batch Size**: {hyperparams.get('batch_size', 'N/A')}\n",
        f"- **Number of Epochs**: {hyperparams.get('num_epochs', 'N/A')}\n",
    ])
else:
    # No summary was loaded: show a pointer to the likely cause instead.
    model_info = (
        "### Model Information\n"
        "Model information loading failed. Please check the `training_summary.json` file and backend logs."
    )
# Title and introductory text shown at the top of the Gradio interface.
app_title = "ISOM5240_financial_tone"
app_description = " ".join([
    "Analyze financial news text to extract summary, sentiment, and identify interested parties.",
    "The sentiment analysis model is fine-tuned on financial news data.",
])
# Page layout: input + model info, then summary, per-sentence sentiment, and entities.
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost — confirm component placement against the running app.
with gr.Blocks(title=app_title) as iface:
    gr.Markdown(f"# {app_title}")
    gr.Markdown(app_description)
    # Top row: text input and button on the left, model information on the right.
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                lines=10,
                label="Financial News Text",
                placeholder="Enter a longer financial news text here for analysis..."
            )
            analyze_btn = gr.Button("Start Analysis", variant="primary")
        with gr.Column(scale=1):
            gr.Markdown(model_info)
    # Summary output.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Text Summary")
            summary_output = gr.Textbox(label="Summary", lines=3)
    # Sentence-level sentiment output, with a legend for the color coding.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Sentiment Analysis (Sentence-level)")
            gr.Markdown("- <span style='color: green'>Green</span>: Positive")
            gr.Markdown("- <span style='color: blue'>Blue</span>: Neutral")
            gr.Markdown("- <span style='color: red'>Red</span>: Negative")
            sentiment_output = gr.HTML(label="Sentiment")
    # Locations and organizations found in the text.
    with gr.Row():
        with gr.Column():
            entities_output = gr.HTML(label="Interested Parties")
    # Set up the click event for the analyze button.
    # The output order matches the tuple returned by analyze_financial_text:
    # (sentiment HTML, summary, entities HTML).
    analyze_btn.click(
        fn=analyze_financial_text,
        inputs=[input_text],
        outputs=[sentiment_output, summary_output, entities_output]
    )
    # Add examples that fill the input box
    gr.Examples(
        [
            ["The Federal Reserve announced today that interest rates will remain unchanged. Markets responded positively, with the S&P 500 gaining 1.2%. However, smaller tech companies in Silicon Valley expressed concerns about potential future rate hikes affecting their access to capital."],
            ["Apple Inc. reported record quarterly revenue of $91.8 billion, an increase of 9% from the year-ago quarter. The company's CEO Tim Cook attributed this success to strong international sales, particularly in European markets and China. However, supply chain disruptions in Taiwan may impact future quarters."]
        ],
        inputs=input_text
    )
if __name__ == "__main__":
    print("Starting Gradio application...")
    # share=True will generate a public link in addition to the local one
    launch_options = {"share": True}
    iface.launch(**launch_options)