# News-monitoring AI analysis app (Hugging Face Space).
import gradio as gr | |
import pandas as pd | |
import torch | |
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, AutoModel | |
import plotly.graph_objects as go | |
import logging | |
import io | |
from rapidfuzz import fuzz | |
import time | |
import os | |
from typing import List, Set, Tuple | |
import asyncio | |
# Configure logging
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by every class/function in this file.
logger = logging.getLogger(__name__)
class GPUTaskManager:
    """Retry wrapper for async GPU tasks that may hit OOM or quota errors.

    Between retries it invokes an optional cleanup callback and flushes the
    CUDA allocator cache before waiting ``retry_delay`` seconds.
    """

    def __init__(self, max_retries=3, retry_delay=30, cleanup_callback=None):
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.cleanup_callback = cleanup_callback

    async def run_with_retry(self, task_func, *args, **kwargs):
        """Await ``task_func(*args, **kwargs)``, retrying GPU-related failures.

        Non-GPU exceptions propagate immediately; GPU memory/quota errors are
        retried up to ``max_retries`` attempts, then re-raised.
        """
        for attempt in range(self.max_retries):
            try:
                return await task_func(*args, **kwargs)
            except Exception as exc:
                message = str(exc)
                retryable = ("CUDA out of memory" in message
                             or "GPU quota" in message)
                # Re-raise unless this is a GPU error with attempts remaining.
                if not (retryable and attempt < self.max_retries - 1):
                    raise
                if self.cleanup_callback:
                    self.cleanup_callback()
                torch.cuda.empty_cache()
                await asyncio.sleep(self.retry_delay)
def batch_process(items, batch_size=3):
    """Split *items* into consecutive chunks of at most *batch_size* elements."""
    batches = []
    for start in range(0, len(items), batch_size):
        batches.append(items[start:start + batch_size])
    return batches
class ProcessControl:
    """Cooperative stop flag shared between UI callbacks and the worker loop."""

    def __init__(self):
        self.stop_requested, self.error = False, None

    def request_stop(self):
        """Ask the running job to halt at its next checkpoint."""
        self.stop_requested = True

    def should_stop(self):
        """Return True once a stop has been requested."""
        return self.stop_requested

    def reset(self):
        """Clear any prior stop request and recorded error."""
        self.stop_requested, self.error = False, None

    def set_error(self, error):
        """Record *error* and simultaneously request a stop."""
        self.error = error
        self.stop_requested = True
class EventDetector:
    """Loads and orchestrates the app's NLP models.

    Bundles RU->EN / EN->RU translation pipelines, two sentiment models
    (FinBERT + Twitter-RoBERTa) combined by majority vote, an MT5 seq2seq
    model used for keyword-based event classification, and a multilingual
    sentence-embedding model for de-clustering similar news.

    All models are loaded eagerly in __init__ on CUDA when available,
    otherwise CPU. Call cleanup() to release GPU memory when finished.
    """

    def __init__(self):
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Initializing models on device: {device}")
            self.device = device
            self.initialize_models()
            # Sentence-embedding transformer used for declusterization of
            # near-duplicate news items.
            self.tokenizer_cluster = AutoTokenizer.from_pretrained('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
            self.model_cluster = AutoModel.from_pretrained('sentence-transformers/paraphrase-multilingual-mpnet-base-v2').to(device)
            self.initialized = True
            logger.info("All models initialized successfully")
        except Exception as e:
            logger.error(f"Error in EventDetector initialization: {str(e)}")
            raise

    def initialize_models(self):
        """Initialize translation, sentiment and MT5 models with proper error handling."""
        try:
            # Translation pipelines: RU -> EN feeds the sentiment models,
            # EN -> RU is kept for producing Russian-language output.
            self.translator = pipeline(
                "translation",
                model="Helsinki-NLP/opus-mt-ru-en",
                device=self.device
            )
            self.rutranslator = pipeline(
                "translation",
                model="Helsinki-NLP/opus-mt-en-ru",
                device=self.device
            )
            # Two independent sentiment models; their labels are combined by
            # majority vote in analyze_sentiment().
            self.finbert = pipeline(
                "sentiment-analysis",
                model="ProsusAI/finbert",
                device=self.device,
                truncation=True,
                max_length=512
            )
            self.roberta = pipeline(
                "sentiment-analysis",
                model="cardiffnlp/twitter-roberta-base-sentiment",
                device=self.device,
                truncation=True,
                max_length=512
            )
            # MT5 seq2seq model driving detect_events().
            self.model_name = "google/mt5-small"
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                legacy=True
            )
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(self.device)
        except Exception as e:
            logger.error(f"Model initialization error: {str(e)}")
            raise

    def process_text(self, text, entity):
        """Run the full analysis (translate, sentiment, event detection) on one item.

        Args:
            text: Russian news excerpt.
            entity: Name of the monitored entity the excerpt refers to.

        Returns:
            Dict with keys 'translated_text', 'sentiment', 'impact',
            'reasoning', 'event_type', 'event_summary'. Never raises;
            any failure yields a neutral fallback dict with the error
            recorded under 'reasoning'.
        """
        try:
            translated_text = self._translate_text(text)
            sentiment = self.analyze_sentiment(translated_text)
            event_type, event_summary = self.detect_events(text, entity)
            return {
                'translated_text': translated_text,
                'sentiment': sentiment,
                'impact': 'Неопределенный эффект',
                'reasoning': 'Автоматический анализ',
                'event_type': event_type,
                'event_summary': event_summary
            }
        except Exception as e:
            logger.error(f"Text processing error: {str(e)}")
            return {
                'translated_text': '',
                'sentiment': 'Neutral',
                'impact': 'Неопределенный эффект',
                'reasoning': f'Ошибка обработки: {str(e)}',
                'event_type': 'Нет',
                'event_summary': ''
            }

    def _translate_text(self, text):
        """Translate Russian text to English with proper error handling.

        Long inputs are split into 450-character chunks (below the model's
        input limit) and translated piecewise; on any error the original
        text is returned unchanged as a best-effort fallback.
        """
        try:
            if not text or not isinstance(text, str):
                return ""
            text = text.strip()
            if not text:
                return ""
            # NOTE(review): character-based chunking can split mid-word or
            # mid-sentence, which may degrade translation at chunk borders.
            max_length = 450
            chunks = [text[i:i + max_length] for i in range(0, len(text), max_length)]
            translated_chunks = []
            for chunk in chunks:
                result = self.translator(chunk)[0]['translation_text']
                translated_chunks.append(result)
                time.sleep(0.1)  # brief pause between chunks to ease GPU load
            return " ".join(translated_chunks)
        except Exception as e:
            logger.error(f"Translation error: {str(e)}")
            return text

    def analyze_sentiment(self, text):
        """Classify sentiment of English *text* by FinBERT + RoBERTa majority vote.

        Returns "Positive", "Negative" or "Neutral"; ties, empty input and
        any model error all resolve to "Neutral".
        """
        try:
            if not text or not isinstance(text, str) or not text.strip():
                return "Neutral"
            finbert_result = self.finbert(text)[0]
            roberta_result = self.roberta(text)[0]
            # Normalize each model's label to a common three-way scheme.
            sentiments = []
            for result in [finbert_result, roberta_result]:
                label = result['label'].lower()
                if 'positive' in label or 'pos' in label:
                    sentiments.append("Positive")
                elif 'negative' in label or 'neg' in label:
                    sentiments.append("Negative")
                else:
                    sentiments.append("Neutral")
            # Majority vote; a 1-1 split falls through to "Neutral".
            pos_count = sentiments.count("Positive")
            neg_count = sentiments.count("Negative")
            if neg_count > pos_count:
                return "Negative"
            elif pos_count > neg_count:
                return "Positive"
            return "Neutral"
        except Exception as e:
            logger.error(f"Sentiment analysis error: {str(e)}")
            return "Neutral"

    def detect_events(self, text, entity):
        """Classify the news event type for *entity* from *text*.

        Generates a short MT5 response to a classification prompt, then maps
        it to a category via keyword matching.

        Returns:
            (event_type, event_summary) where event_type is one of
            "Отчетность", "Суд", "РЦБ" or "Нет".
        """
        if not text or not entity:
            return "Нет", "Invalid input"
        try:
            prompt = f"<s>Classify news about {entity}: {text}</s>"
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_length=100,
                num_return_sequences=1,
                do_sample=False
            )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Keyword-based mapping of the generated text to event categories.
            if any(word in response.lower() for word in ['financial', 'revenue', 'profit']):
                return "Отчетность", "Financial report detected"
            elif any(word in response.lower() for word in ['court', 'lawsuit', 'legal']):
                return "Суд", "Legal proceedings detected"
            elif any(word in response.lower() for word in ['bonds', 'securities', 'debt']):
                return "РЦБ", "Securities-related news detected"
            return "Нет", "No specific event detected"
        except Exception as e:
            logger.error(f"Event detection error: {str(e)}")
            return "Нет", f"Error in event detection: {str(e)}"

    def cleanup(self):
        """Clean up GPU resources.

        Drops every model/pipeline reference — not just a subset — so the
        allocator can actually reclaim the memory freed by empty_cache().
        (The previous version leaked model_cluster, tokenizer_cluster,
        rutranslator and the MT5 tokenizer.)
        """
        try:
            self.model = None
            self.tokenizer = None
            self.translator = None
            self.rutranslator = None
            self.finbert = None
            self.roberta = None
            self.model_cluster = None
            self.tokenizer_cluster = None
            torch.cuda.empty_cache()
            self.initialized = False
            logger.info("Cleaned up GPU resources")
        except Exception as e:
            logger.error(f"Error in cleanup: {str(e)}")
def create_visualizations(df):
    """Build a sentiment pie chart and an event-type bar chart.

    Args:
        df: Results DataFrame with 'Sentiment' and 'Event_Type' columns.

    Returns:
        (sentiment_figure, events_figure) plotly figures, or (None, None)
        when *df* is empty/None or on any plotting error.
    """
    if df is None or df.empty:
        return None, None
    try:
        sentiments = df['Sentiment'].value_counts()
        # Map colors by label rather than by position: value_counts() orders
        # labels by frequency, so a positional color list would paint e.g.
        # "Positive" red whenever it happens to be the most common class.
        color_map = {
            'Negative': '#FF6B6B',
            'Positive': '#4ECDC4',
            'Neutral': '#95A5A6',
        }
        fig_sentiment = go.Figure(data=[go.Pie(
            labels=sentiments.index,
            values=sentiments.values,
            marker_colors=[color_map.get(label, '#95A5A6') for label in sentiments.index]
        )])
        fig_sentiment.update_layout(title="Распределение тональности")
        events = df['Event_Type'].value_counts()
        fig_events = go.Figure(data=[go.Bar(
            x=events.index,
            y=events.values,
            marker_color='#2196F3'
        )])
        fig_events.update_layout(title="Распределение событий")
        return fig_sentiment, fig_events
    except Exception as e:
        logger.error(f"Visualization error: {e}")
        return None, None
def create_interface():
    """Build the Gradio app: file upload, start/stop controls, results table and plots."""
    control = ProcessControl()
    with gr.Blocks() as app:
        gr.Markdown("# AI-анализ мониторинга новостей v.2.0")
        with gr.Row():
            file_input = gr.File(
                label="Загрузите Excel файл",
                file_types=[".xlsx"]
            )
        with gr.Row():
            analyze_btn = gr.Button("▶️ Начать анализ", variant="primary")
            stop_btn = gr.Button("⏹️ Остановить", variant="stop")
        progress = gr.Textbox(
            label="Статус обработки",
            value="Ожидание файла..."
        )
        stats = gr.DataFrame(label="Результаты анализа")
        with gr.Row():
            sentiment_plot = gr.Plot(label="Распределение тональности")
            events_plot = gr.Plot(label="Распределение событий")

        def stop_processing():
            # UI callback: flag the worker loop to halt at its next check.
            control.request_stop()
            return "Остановка обработки..."

        def process_file(file):
            # This callback is a *generator* (it streams progress via yield),
            # so every UI update must be yielded: a `return value` inside a
            # generator is discarded by Gradio. The previous version returned
            # the final results and error messages, which never reached the UI.
            detector = None
            try:
                if file is None:
                    yield None, None, None, "Пожалуйста, загрузите файл"
                    return
                # Clear any stop request left over from a previous run;
                # otherwise one stop would abort every subsequent analysis.
                control.reset()
                df = pd.read_excel(file.name)
                detector = EventDetector()
                processed_rows = []
                total = len(df)
                for _, row in df.iterrows():
                    if control.should_stop():
                        break
                    text = str(row.get('Выдержки из текста', '')).strip()
                    entity = str(row.get('Объект', '')).strip()
                    if text and entity:
                        results = detector.process_text(text, entity)
                        processed_rows.append({
                            'Объект': entity,
                            'Заголовок': str(row.get('Заголовок', '')),
                            'Sentiment': results['sentiment'],
                            'Event_Type': results['event_type'],
                            'Event_Summary': results['event_summary'],
                            'Текст': text[:1000]
                        })
                        # Stream an intermediate table every 10 processed rows.
                        if len(processed_rows) % 10 == 0:
                            yield pd.DataFrame(processed_rows), None, None, f"Обработано {len(processed_rows)}/{total} строк"
                final_df = pd.DataFrame(processed_rows)
                fig_sentiment, fig_events = create_visualizations(final_df)
                yield final_df, fig_sentiment, fig_events, "Обработка завершена!"
            except Exception as e:
                error_msg = f"Ошибка анализа: {str(e)}"
                logger.error(error_msg)
                yield None, None, None, error_msg
            finally:
                # Always release GPU resources, even on error or stop.
                if detector is not None:
                    detector.cleanup()

        stop_btn.click(fn=stop_processing, outputs=[progress])
        analyze_btn.click(
            fn=process_file,
            inputs=[file_input],
            outputs=[stats, sentiment_plot, events_plot, progress]
        )
    return app
if __name__ == "__main__":
    # Build and serve the Gradio application.
    create_interface().launch()