|
import html
import json
import os
import re

import gradio as gr
import nltk
import spacy
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, AutoModelForSeq2SeqLM
|
|
|
|
|
# Make sure the sentence-tokenizer data needed by nltk.sent_tokenize is
# installed.  Older NLTK releases read the "punkt" resource; newer releases
# look for "punkt_tab" instead, so both are checked and whichever is missing
# is downloaded.
for _nltk_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_nltk_resource}")
    except LookupError:
        # Best effort: a failed download is not treated as fatal here; a
        # missing resource will surface later when tokenization is attempted.
        nltk.download(_nltk_resource)
|
|
|
# JSON file read at start-up to locate and describe the best sentiment model.
SUMMARY_FILE = "training_summary.json"

# Predicted class index -> sentiment label.
LABEL_MAP = {
    0: "Negative",
    1: "Neutral",
    2: "Positive",
}

# Sentiment label -> CSS colour used when rendering sentences as HTML.
COLOR_MAP = {"Negative": "red", "Neutral": "blue", "Positive": "green"}

# Model handles shared across the module; they stay None until loaded.
loaded_model = loaded_tokenizer = None
best_model_summary = None
summarizer = nlp = None
|
|
|
def load_models_and_components():
    """Load the sentiment, summarization and NER models into module globals.

    The sentiment model location is read from ``SUMMARY_FILE``; its tokenizer
    and weights are mandatory.  The summarization pipeline and the spaCy NER
    model are optional: if either fails to load a warning is printed and the
    corresponding global is set to ``None``.

    Raises:
        FileNotFoundError: If ``SUMMARY_FILE`` or the model path it names
            does not exist.
        ValueError: If the summary has no non-empty ``best_model_details``.
        RuntimeError: If the sentiment tokenizer or model cannot be loaded.
    """
    global loaded_model, loaded_tokenizer, best_model_summary, summarizer, nlp

    if not os.path.exists(SUMMARY_FILE):
        raise FileNotFoundError(f"Error: Could not find training summary file {SUMMARY_FILE}. Please run the fine-tuning and testing scripts first.")

    with open(SUMMARY_FILE, "r", encoding="utf-8") as f:
        summary_data = json.load(f)

    if "best_model_details" not in summary_data or not summary_data["best_model_details"]:
        raise ValueError(f"Error: Best model information not found or incomplete in {SUMMARY_FILE}.")

    # Published before the path checks so the model details remain available
    # for display even when the weights themselves cannot be found.
    best_model_summary = summary_data["best_model_details"]

    # The path may be stored inside the details or at the top level.
    best_model_path = best_model_summary.get("best_model_path") or summary_data.get("best_model_path")

    if not best_model_path or not os.path.exists(best_model_path):
        raise FileNotFoundError(f"Error: Best model path {best_model_path} not found or invalid.")

    # A missing "model_name" must not abort loading just because of a log line.
    model_name = best_model_summary.get("model_name", "N/A")
    print(f"Loading sentiment model {model_name} from {best_model_path}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(best_model_path)
        model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Failed to load sentiment model: {e}") from e
    # Assign both globals together so a failure above never leaves a
    # tokenizer without its model.
    loaded_tokenizer, loaded_model = tokenizer, model
    print("Sentiment model loaded successfully.")

    print("Loading summarization model...")
    try:
        summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
        print("Summarization model loaded successfully.")
    except Exception as e:
        # Optional component: degrade gracefully instead of failing start-up.
        print(f"Warning: Failed to load summarization model: {e}")
        print("Will continue without summarization capability.")
        summarizer = None

    print("Loading NER model...")
    try:
        if not spacy.util.is_package("en_core_web_sm"):
            spacy.cli.download("en_core_web_sm")
        nlp = spacy.load("en_core_web_sm")
        print("NER model loaded successfully.")
    except Exception as e:
        # Optional component: degrade gracefully instead of failing start-up.
        print(f"Warning: Failed to load NER model: {e}")
        print("Will continue without NER capability.")
        nlp = None
|
|
|
def predict_sentiment(text):
    """Predict the sentiment of a single piece of text.

    Args:
        text: The string to classify.

    Returns:
        A ``(label, index)`` tuple.  On success ``label`` is looked up in
        ``LABEL_MAP`` (``"Unknown (<index>)"`` for an unmapped class) and
        ``index`` is the arg-max class id.  If the model is not loaded, the
        text is blank, or inference fails, ``label`` is a message string and
        ``index`` is ``None``.
    """
    # Test for "not loaded" with an identity check: an object's truthiness
    # can depend on __len__/__bool__, which is unrelated to whether it exists.
    if loaded_model is None or loaded_tokenizer is None:
        return "Error: Model not loaded.", None

    if not text or not text.strip():
        return "Please enter text for analysis.", None

    try:
        inputs = loaded_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)

        # Run on the GPU when one is available; model and inputs must share a device.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        loaded_model.to(device)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = loaded_model(**inputs)

        prediction_idx = torch.argmax(outputs.logits, dim=-1).item()
        sentiment = LABEL_MAP.get(prediction_idx, f"Unknown ({prediction_idx})")
        return sentiment, prediction_idx
    except Exception as e:
        print(f"Error during sentiment prediction: {e}")
        return f"Error: {str(e)}", None
|
|
|
def generate_summary(text):
    """Generate a summary for longer text.

    Args:
        text: The text to summarize.

    Returns:
        The summary string, or a message string when the summarizer is not
        available, the stripped text is shorter than 50 characters, or
        summarization fails.
    """
    if summarizer is None:
        return "Summarization model not available."

    if not text or len(text.strip()) < 50:
        return "Text too short for summarization."

    # Aim for about a quarter of the input's word count (capped at 1024 words).
    budget = min(1024, len(text.split())) // 4
    if budget >= 30:
        max_len, min_len = budget, 30
    else:
        # For inputs under ~120 words the budget drops below the usual
        # minimum of 30, which would request min_length > max_length.
        # Keep a small floor and scale the minimum down with it.
        max_len = max(budget, 10)
        min_len = max_len // 2

    try:
        summary = summarizer(
            text,
            max_length=max_len,
            min_length=min_len,
            do_sample=False,
            # Let the pipeline truncate over-long inputs instead of failing.
            truncation=True,
        )
        return summary[0]['summary_text']
    except Exception as e:
        print(f"Error during summarization: {e}")
        return f"Summarization error: {str(e)}"
|
|
|
def identify_entities(text):
    """Collect the distinct location and organization names found in *text*.

    Returns a dict with two sorted lists, ``"locations"`` (entities labelled
    GPE or LOC) and ``"organizations"`` (entities labelled ORG).  A message
    string is returned instead when the NER model is unavailable, the text is
    blank, or processing raises.
    """
    if not nlp:
        return "NER model not available."

    if not (text and text.strip()):
        return "Please enter text for entity analysis."

    try:
        # Sets drop duplicates as entities are collected.
        places, orgs = set(), set()
        for entity in nlp(text).ents:
            label = entity.label_
            if label in ("GPE", "LOC"):
                places.add(entity.text)
            elif label == "ORG":
                orgs.add(entity.text)

        return {"locations": sorted(places), "organizations": sorted(orgs)}
    except Exception as e:
        print(f"Error during entity identification: {e}")
        return f"Entity identification error: {str(e)}"
|
|
|
def format_entities(entities):
    """Format identified entities as an HTML fragment for display.

    Args:
        entities: Either a message string, which is returned unchanged, or a
            mapping with ``"locations"`` and ``"organizations"`` lists of
            entity names.  A missing or empty list is shown as
            "None identified".

    Returns:
        An HTML string with a heading followed by one paragraph per category.
    """
    if isinstance(entities, str):
        return entities

    # (heading, key in ``entities``, CSS colour) for each category, in display order.
    sections = (
        ("Locations", "locations", "red"),
        ("Organizations", "organizations", "green"),
    )

    parts = ["<h3>Interested Parties</h3>"]
    for heading, key, color in sections:
        names = entities.get(key) or []
        if names:
            # Entity names come from the analysed text, so escape them before
            # embedding them in markup.
            spans = ", ".join(
                f"<span style='color: {color}'>{html.escape(str(name))}</span>"
                for name in names
            )
            parts.append(f"<p><b>{heading}:</b> {spans}</p>")
        else:
            parts.append(f"<p><b>{heading}:</b> None identified</p>")

    return "".join(parts)
|
|
|
def analyze_text_sentiment_by_sentence(text):
    """Analyze the sentiment of each sentence and render it as coloured HTML.

    The text is split into sentences with ``nltk.sent_tokenize``; each
    sentence of at least three non-blank characters is classified with
    ``predict_sentiment`` and wrapped in a ``<span>`` whose colour comes from
    ``COLOR_MAP`` (black for any label not in the map).

    Args:
        text: The text to analyse.

    Returns:
        The HTML string, or a message string when the text is blank, no
        sentence qualifies, or processing fails.
    """
    if not text or not text.strip():
        return "Please enter text for analysis."

    try:
        spans = []
        for sentence in nltk.sent_tokenize(text):
            if len(sentence.strip()) < 3:
                continue

            sentiment, _ = predict_sentiment(sentence)
            color = COLOR_MAP.get(sentiment, "black")

            # The sentence is user-supplied text: escape it so it is shown
            # literally rather than interpreted as markup.
            spans.append(f"<span style='color: {color}'>{html.escape(sentence)}</span> ")

        formatted_result = "".join(spans)
        return formatted_result if formatted_result else "No valid sentences found for analysis."
    except Exception as e:
        print(f"Error during sentence-level sentiment analysis: {e}")
        return f"Error: {str(e)}"
|
|
|
def analyze_financial_text(text):
    """Run summarization, sentence-level sentiment and entity extraction.

    Returns a ``(sentiment_html, summary, entities_html)`` tuple.  For blank
    input the tuple holds three placeholder messages instead.
    """
    has_content = bool(text) and bool(text.strip())
    if not has_content:
        return (
            "Please enter text for analysis.",
            "No summary available.",
            "No entities identified.",
        )

    summary_text = generate_summary(text)
    sentiment_html = analyze_text_sentiment_by_sentence(text)
    entities_html = format_entities(identify_entities(text))

    return sentiment_html, summary_text, entities_html
|
|
|
|
|
# Load all models once at import time.  A failure is reported but not fatal,
# so the module still imports and the panel below can explain what went wrong.
try:
    load_models_and_components()
except Exception as e:
    print(f"Initial model loading failed: {e}")

# Markdown describing the sentiment model, built from the training summary.
model_info = "### Model Information\n"
if best_model_summary:
    model_name = best_model_summary.get("model_name", "N/A")
    accuracy = best_model_summary.get("accuracy_percent", "N/A")
    run_time = best_model_summary.get("run_time_sec", "N/A")  # read but not displayed
    # "or {}" guards against an explicit null in the JSON, which would
    # otherwise make the .get() calls below raise during import.
    hyperparams = best_model_summary.get("hyperparameters") or {}

    # Append the percent sign only to a real value, so a missing accuracy
    # reads "N/A" rather than "N/A%".
    accuracy_text = "N/A" if accuracy in (None, "N/A") else f"{accuracy}%"

    model_info += "".join([
        f"- **Model Name**: {model_name}\n",
        f"- **Model Accuracy**: {accuracy_text}\n",
        "- **Description**: The model is trained and fine-tuned using the financial news dataset to improve its sensitivity in recognizing financial sentiment.\n",
        "\n### Hyperparameters\n",
        f"- **Learning Rate**: {hyperparams.get('learning_rate', 'N/A')}\n",
        f"- **Batch Size**: {hyperparams.get('batch_size', 'N/A')}\n",
        f"- **Number of Epochs**: {hyperparams.get('num_epochs', 'N/A')}\n",
    ])
else:
    model_info += "Model information loading failed. Please check the `training_summary.json` file and backend logs."
|
|
|
|
|
# Name of the application, used for the Blocks title and the page heading.
app_title = "ISOM5240_financial_tone"

# Short description rendered beneath the heading.
app_description = "".join((
    "Analyze financial news text to extract summary, sentiment, and identify interested parties. ",
    "The sentiment analysis model is fine-tuned on financial news data.",
))
|
|
|
with gr.Blocks(title=app_title) as iface:
    # Page header.
    gr.Markdown(f"# {app_title}")
    gr.Markdown(app_description)

    # Top row: text input and trigger button (scale=2) beside the model card (scale=1).
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Financial News Text",
                placeholder="Enter a longer financial news text here for analysis...",
                lines=10,
            )
            analyze_btn = gr.Button("Start Analysis", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown(model_info)

    # Summary section.
    with gr.Row(), gr.Column():
        gr.Markdown("### Text Summary")
        summary_output = gr.Textbox(label="Summary", lines=3)

    # Sentence-level sentiment section with its colour legend.
    with gr.Row(), gr.Column():
        gr.Markdown("### Sentiment Analysis (Sentence-level)")
        for legend_entry in (
            "- <span style='color: green'>Green</span>: Positive",
            "- <span style='color: blue'>Blue</span>: Neutral",
            "- <span style='color: red'>Red</span>: Negative",
        ):
            gr.Markdown(legend_entry)
        sentiment_output = gr.HTML(label="Sentiment")

    # Entities section.
    with gr.Row(), gr.Column():
        entities_output = gr.HTML(label="Interested Parties")

    # The outputs are listed in the order analyze_financial_text returns them.
    analyze_btn.click(
        fn=analyze_financial_text,
        inputs=[input_text],
        outputs=[sentiment_output, summary_output, entities_output],
    )

    # Sample inputs offered for the text box.
    example_texts = [
        ["The Federal Reserve announced today that interest rates will remain unchanged. Markets responded positively, with the S&P 500 gaining 1.2%. However, smaller tech companies in Silicon Valley expressed concerns about potential future rate hikes affecting their access to capital."],
        ["Apple Inc. reported record quarterly revenue of $91.8 billion, an increase of 9% from the year-ago quarter. The company's CEO Tim Cook attributed this success to strong international sales, particularly in European markets and China. However, supply chain disruptions in Taiwan may impact future quarters."],
    ]
    gr.Examples(example_texts, inputs=input_text)
|
|
|
def main():
    """Announce start-up and launch the Gradio interface ``iface``."""
    print("Starting Gradio application...")
    # share=True requests a public share link in addition to the local server.
    iface.launch(share=True)


if __name__ == "__main__":
    main()