"""Gradio app that analyzes financial news text.

For a piece of text the app produces three outputs:

* a sentence-by-sentence sentiment rendering (HTML, colour-coded),
* an abstractive summary,
* the locations and organizations mentioned in the text (HTML).

The sentiment model is loaded from the path recorded in ``SUMMARY_FILE``;
the summarizer and the NER model are optional and the app degrades
gracefully when they cannot be loaded.
"""
import html
import json
import os
import re

import gradio as gr
import nltk
import spacy
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

# Download necessary NLTK data for sentence tokenization
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

SUMMARY_FILE = "training_summary.json"

# Assume label meanings are consistent with previous files
LABEL_MAP = {0: "Negative", 1: "Neutral", 2: "Positive"}

# Color coding for sentiment
COLOR_MAP = {
    "Negative": "red",
    "Neutral": "blue",
    "Positive": "green",
}

# Global loading of models and NLP components
loaded_model = None
loaded_tokenizer = None
best_model_summary = None
summarizer = None
nlp = None  # For NER


def load_models_and_components():
    """Load the sentiment model, the summarizer and the NER model into globals.

    The sentiment model is mandatory: any problem locating or loading it
    raises. The summarizer and the spaCy model are optional: on failure a
    warning is printed and the corresponding global is left as ``None``.

    Raises:
        FileNotFoundError: ``SUMMARY_FILE`` or the recorded model path is missing.
        ValueError: ``SUMMARY_FILE`` has no usable ``best_model_details`` entry.
        RuntimeError: the sentiment model or tokenizer failed to load.
    """
    global loaded_model, loaded_tokenizer, best_model_summary, summarizer, nlp

    # Load sentiment analysis model from training
    if not os.path.exists(SUMMARY_FILE):
        raise FileNotFoundError(
            f"Error: Could not find training summary file {SUMMARY_FILE}. "
            "Please run the fine-tuning and testing scripts first."
        )

    with open(SUMMARY_FILE, 'r', encoding="utf-8") as f:
        summary_data = json.load(f)

    if not summary_data.get("best_model_details"):
        raise ValueError(
            f"Error: Best model information not found or incomplete in {SUMMARY_FILE}."
        )

    best_model_summary = summary_data["best_model_details"]
    best_model_path = best_model_summary.get("best_model_path")
    if not best_model_path:
        best_model_path = summary_data.get("best_model_path")  # Compatible with older format

    if not best_model_path or not os.path.exists(best_model_path):
        raise FileNotFoundError(
            f"Error: Best model path {best_model_path} not found or invalid."
        )

    # .get keeps a missing "model_name" from aborting the load with a KeyError;
    # the UI code below treats that key as optional too.
    print(
        f"Loading sentiment model {best_model_summary.get('model_name', 'N/A')} "
        f"from {best_model_path}..."
    )
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(best_model_path)
        loaded_model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
        loaded_model.eval()  # Set to evaluation mode
        print("Sentiment model loaded successfully.")
    except Exception as e:
        raise RuntimeError(f"Failed to load sentiment model: {e}") from e

    # Load summarization model
    print("Loading summarization model...")
    try:
        summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
        print("Summarization model loaded successfully.")
    except Exception as e:
        print(f"Warning: Failed to load summarization model: {e}")
        print("Will continue without summarization capability.")
        summarizer = None

    # Load spaCy model for NER (Named Entity Recognition)
    print("Loading NER model...")
    try:
        # Download the model if it's not already downloaded
        if not spacy.util.is_package("en_core_web_sm"):
            spacy.cli.download("en_core_web_sm")
        nlp = spacy.load("en_core_web_sm")
        print("NER model loaded successfully.")
    except Exception as e:
        print(f"Warning: Failed to load NER model: {e}")
        print("Will continue without NER capability.")
        nlp = None


def predict_sentiment(text):
    """Predict sentiment for a single piece of text.

    Returns:
        A ``(label, index)`` tuple. On success ``label`` is a value of
        ``LABEL_MAP`` (or ``"Unknown (<idx>)"``) and ``index`` is the argmax
        class index. On any failure ``label`` is a human-readable message and
        ``index`` is ``None``.
    """
    # Compare against None explicitly: truthiness of a tokenizer goes through
    # __len__, which is not what "was it loaded?" means.
    if loaded_model is None or loaded_tokenizer is None:
        return "Error: Model not loaded.", None
    if not text or not text.strip():
        return "Please enter text for analysis.", None

    try:
        inputs = loaded_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        loaded_model.to(device)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = loaded_model(**inputs)

        prediction_idx = torch.argmax(outputs.logits, dim=-1).item()
        sentiment = LABEL_MAP.get(prediction_idx, f"Unknown ({prediction_idx})")
        return sentiment, prediction_idx
    except Exception as e:
        print(f"Error during sentiment prediction: {e}")
        return f"Error: {str(e)}", None


def generate_summary(text):
    """Generate a summary for longer text.

    Returns the summary string, or a human-readable message when the
    summarizer is unavailable, the text is shorter than 50 characters, or
    summarization fails.
    """
    if summarizer is None:
        return "Summarization model not available."
    if not text or len(text.strip()) < 50:
        return "Text too short for summarization."

    try:
        # Target roughly a quarter of the input (capped at 1024 words), but
        # never less than one token.
        word_count = len(text.split())
        max_length = max(1, min(1024, word_count) // 4)
        # min_length must not exceed max_length; for short inputs the fixed
        # floor of 30 would otherwise contradict the computed ceiling.
        min_length = min(30, max_length)

        # truncation=True clips inputs that exceed the model's input limit
        # instead of letting the forward pass fail.
        summary = summarizer(
            text,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
            truncation=True,
        )
        return summary[0]['summary_text']
    except Exception as e:
        print(f"Error during summarization: {e}")
        return f"Summarization error: {str(e)}"


def identify_entities(text):
    """Identify locations and organizations in the text.

    Returns:
        ``{"locations": [...], "organizations": [...]}`` with de-duplicated,
        sorted entity strings, or a message string when NER is unavailable,
        the text is empty, or processing fails.
    """
    if nlp is None:
        return "NER model not available."
    if not text or not text.strip():
        return "Please enter text for entity analysis."

    try:
        doc = nlp(text)
        locations = set()
        organizations = set()

        for ent in doc.ents:
            if ent.label_ in ("GPE", "LOC"):  # Geopolitical entity or Location
                locations.add(ent.text)
            elif ent.label_ == "ORG":  # Organization
                organizations.add(ent.text)

        # Sets remove duplicates; sort for a stable display order
        return {
            "locations": sorted(locations),
            "organizations": sorted(organizations),
        }
    except Exception as e:
        print(f"Error during entity identification: {e}")
        return f"Entity identification error: {str(e)}"


def _format_entity_line(label, names, color):
    """Render one ``label: a, b, c`` paragraph with each name in ``color``."""
    if not names:
        return f"<p><b>{label}:</b> None identified</p>"
    # Entity text comes from user input, so escape it before embedding in HTML.
    spans = ", ".join(
        f'<span style="color: {color};">{html.escape(name)}</span>'
        for name in names
    )
    return f"<p><b>{label}:</b> {spans}</p>"


def format_entities(entities):
    """Format identified entities for display as an HTML fragment.

    ``entities`` is either the dict produced by ``identify_entities`` or a
    message string, which is passed through unchanged.
    """
    if isinstance(entities, str):  # Error message
        return entities

    return "".join([
        "<h3>Interested Parties</h3>",
        # Add locations in red
        _format_entity_line("Locations", entities["locations"], "red"),
        # Add organizations in green
        _format_entity_line("Organizations", entities["organizations"], "green"),
    ])


def analyze_text_sentiment_by_sentence(text):
    """Analyze sentiment of each sentence in the text and format with colors.

    Returns an HTML fragment in which every sentence is wrapped in a span
    coloured according to ``COLOR_MAP`` (black when the sentiment label is not
    in the map, e.g. when prediction failed), or a message string.
    """
    if not text or not text.strip():
        return "Please enter text for analysis."

    try:
        colored_sentences = []
        # Split text into sentences
        for sentence in nltk.sent_tokenize(text):
            if len(sentence.strip()) < 3:  # Skip very short sentences
                continue
            sentiment, _ = predict_sentiment(sentence)
            color = COLOR_MAP.get(sentiment, "black")
            # Escape the sentence: it is user input rendered by gr.HTML.
            colored_sentences.append(
                f'<span style="color: {color};">{html.escape(sentence)}</span>'
            )

        if not colored_sentences:
            return "No valid sentences found for analysis."
        return " ".join(colored_sentences)
    except Exception as e:
        print(f"Error during sentence-level sentiment analysis: {e}")
        return f"Error: {str(e)}"


def analyze_financial_text(text):
    """Master function that performs all analysis tasks.

    Returns:
        A ``(sentiment_html, summary, entities_html)`` tuple, in the order the
        Gradio click handler maps onto its outputs.
    """
    if not text or not text.strip():
        return (
            "Please enter text for analysis.",
            "No summary available.",
            "No entities identified.",
        )

    # Generate summary
    summary = generate_summary(text)

    # Perform sentence-level sentiment analysis
    sentiment_analysis = analyze_text_sentiment_by_sentence(text)

    # Identify entities
    entities = identify_entities(text)
    formatted_entities = format_entities(entities)

    return sentiment_analysis, summary, formatted_entities


# Try to load models at app startup
try:
    load_models_and_components()
except Exception as e:
    print(f"Initial model loading failed: {e}")
    # Gradio interface will still start, but functionality will be limited

# Build Gradio interface
model_info = "### Model Information\n"
if best_model_summary:
    model_name = best_model_summary.get("model_name", "N/A")
    accuracy = best_model_summary.get("accuracy_percent", "N/A")
    run_time = best_model_summary.get("run_time_sec", "N/A")  # read but not displayed
    hyperparams = best_model_summary.get("hyperparameters", {})

    model_info += f"- **Model Name**: {model_name}\n"
    model_info += f"- **Model Accuracy**: {accuracy}%\n"
    model_info += (
        "- **Description**: The model is trained and fine-tuned using the "
        "financial news dataset to improve its sensitivity in recognizing "
        "financial sentiment.\n"
    )

    # Add hyperparameters
    model_info += "\n### Hyperparameters\n"
    model_info += f"- **Learning Rate**: {hyperparams.get('learning_rate', 'N/A')}\n"
    model_info += f"- **Batch Size**: {hyperparams.get('batch_size', 'N/A')}\n"
    model_info += f"- **Number of Epochs**: {hyperparams.get('num_epochs', 'N/A')}\n"
else:
    model_info += (
        "Model information loading failed. Please check the "
        "`training_summary.json` file and backend logs."
    )

# Gradio interface definition
app_title = "ISOM5240_financial_tone"
app_description = (
    "Analyze financial news text to extract summary, sentiment, and identify interested parties. "
    "The sentiment analysis model is fine-tuned on financial news data."
)

with gr.Blocks(title=app_title) as iface:
    gr.Markdown(f"# {app_title}")
    gr.Markdown(app_description)

    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                lines=10,
                label="Financial News Text",
                placeholder="Enter a longer financial news text here for analysis...",
            )
            analyze_btn = gr.Button("Start Analysis", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown(model_info)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Text Summary")
            summary_output = gr.Textbox(label="Summary", lines=3)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Sentiment Analysis (Sentence-level)")
            gr.Markdown("- Green: Positive")
            gr.Markdown("- Blue: Neutral")
            gr.Markdown("- Red: Negative")
            sentiment_output = gr.HTML(label="Sentiment")

    with gr.Row():
        with gr.Column():
            entities_output = gr.HTML(label="Interested Parties")

    # Set up the click event for the analyze button
    analyze_btn.click(
        fn=analyze_financial_text,
        inputs=[input_text],
        outputs=[sentiment_output, summary_output, entities_output],
    )

    # Add examples
    gr.Examples(
        [
            [
                "The Federal Reserve announced today that interest rates will remain unchanged. "
                "Markets responded positively, with the S&P 500 gaining 1.2%. "
                "However, smaller tech companies in Silicon Valley expressed concerns about "
                "potential future rate hikes affecting their access to capital."
            ],
            [
                "Apple Inc. reported record quarterly revenue of $91.8 billion, an increase of 9% "
                "from the year-ago quarter. The company's CEO Tim Cook attributed this success to "
                "strong international sales, particularly in European markets and China. "
                "However, supply chain disruptions in Taiwan may impact future quarters."
            ],
        ],
        inputs=input_text,
    )

if __name__ == "__main__":
    print("Starting Gradio application...")
    # share=True will generate a public link
    iface.launch(share=True)