"""Gradio demo: topic/sentiment classification and summarization of customer text.

Classification uses an IndoBERT encoder with two linear heads whose weights are
loaded from ``model.pt``; summarization uses a pretrained seq2seq checkpoint.
"""

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
import gradio as gr

# Label mapping. Index order must match the label ids used when model.pt was
# trained -- TODO(review): confirm against the training script.
topik_labels = ["Produk", "Layanan", "Pengiriman", "Pembatalan", "Lainnya"]
sentiment_labels = ["Negatif", "Netral", "Positif"]

# Load tokenizer and encoder for classification.
tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")
encoder = AutoModel.from_pretrained("indobenchmark/indobert-base-p1")


class MultiTaskModel(nn.Module):
    """Shared encoder with two linear heads (topic and sentiment).

    Both heads read the encoder's hidden state at sequence position 0.

    Args:
        encoder: Hugging Face encoder whose output exposes ``last_hidden_state``.
        hidden_size: Width of the encoder's hidden states.
        num_topic_labels: Number of topic classes (output size of the topic head).
        num_sentiment_labels: Number of sentiment classes (output size of the
            sentiment head).
    """

    def __init__(self, encoder, hidden_size=768, num_topic_labels=5, num_sentiment_labels=3):
        super().__init__()
        self.encoder = encoder
        self.topik_classifier = nn.Linear(hidden_size, num_topic_labels)
        self.sentiment_classifier = nn.Linear(hidden_size, num_sentiment_labels)

    def forward(self, input_ids, attention_mask):
        """Run the encoder once and apply both heads.

        Returns:
            Tuple ``(topik_logits, sentiment_logits)``, one row per batch item.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 is the first token ([CLS] for BERT-style tokenizers).
        cls_output = outputs.last_hidden_state[:, 0, :]
        topik_logits = self.topik_classifier(cls_output)
        sentiment_logits = self.sentiment_classifier(cls_output)
        return topik_logits, sentiment_logits


# Build the model and load the fine-tuned weights. Head sizes are derived from
# the label lists so predictions always index a valid label.
model = MultiTaskModel(
    encoder,
    num_topic_labels=len(topik_labels),
    num_sentiment_labels=len(sentiment_labels),
)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load("model.pt", map_location=torch.device("cpu")))
model.eval()

# Load tokenizer and model for summarization.
sum_tokenizer = AutoTokenizer.from_pretrained("flax-community/bart-base-indonesian-summarization")
sum_model = AutoModelForSeq2SeqLM.from_pretrained("flax-community/bart-base-indonesian-summarization")


def _classify(text):
    """Return ``(topic_label, sentiment_label)`` predicted for *text*."""
    # Explicit max_length: without it, truncation falls back to the tokenizer's
    # configured model_max_length. 512 is assumed to be the encoder's position
    # limit (BERT-base) -- TODO confirm via encoder.config.max_position_embeddings.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        # Pass only what MultiTaskModel.forward accepts: BERT tokenizers also
        # return token_type_ids, so model(**inputs) would raise TypeError.
        topik_logits, sentiment_logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    topik = torch.argmax(topik_logits, dim=1).item()
    sentiment = torch.argmax(sentiment_logits, dim=1).item()
    return topik_labels[topik], sentiment_labels[sentiment]


def _summarize(text):
    """Return a short summary of *text* (5-50 generated tokens, no sampling)."""
    input_summary = f"Ringkas percakapan berikut: {text}"
    inputs_sum = sum_tokenizer.encode(
        input_summary, return_tensors="pt", max_length=512, truncation=True
    )
    summary_ids = sum_model.generate(inputs_sum, max_length=50, min_length=5, do_sample=False)
    return sum_tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def analyze_text(text):
    """Classify and summarize one conversation for display in the UI.

    Args:
        text: Conversation text from the input textbox.

    Returns:
        A multi-line report with the topic, sentiment and summary, or a prompt
        message when *text* is empty or whitespace-only.
    """
    if text is None or not text.strip():
        return "Silakan masukkan teks percakapan terlebih dahulu."

    topik, sentiment = _classify(text)
    summary = _summarize(text)

    return (
        "**HASIL ANALISIS**\n"
        f"Topik: {topik}\n"
        f"Sentimen: {sentiment}\n"
        f"Ringkasan: {summary}"
    )


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Analisis Topik, Sentimen, dan Ringkasan Pelanggan")
    with gr.Row():
        input_text = gr.Textbox(label="Masukkan Teks Percakapan")
        output_text = gr.Textbox(label="Hasil Analisis")
    with gr.Row():
        clear_btn = gr.Button("Clear")
        submit_btn = gr.Button("Analisa")

    submit_btn.click(analyze_text, inputs=input_text, outputs=output_text)
    # Reset both the input and the output textbox.
    clear_btn.click(lambda: ("", ""), inputs=[], outputs=[input_text, output_text])

demo.launch()