# Hugging Face Spaces app — topic/sentiment classification + summarization (Indonesian).
# NOTE(review): the original page header ("Spaces: / Runtime error") was scrape residue,
# not code; the Space was failing at runtime when this snapshot was taken.
# Third-party dependencies: PyTorch, Hugging Face Transformers, Gradio.
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
import gradio as gr

# Tokenizer and encoder backbone used by the classification heads (IndoBERT).
tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")
encoder = AutoModel.from_pretrained("indobenchmark/indobert-base-p1")
# Multi-task model definition.
class MultiTaskModel(nn.Module):
    """Shared-encoder model with two independent linear heads.

    The encoder's [CLS] token representation feeds a topic classifier and
    a sentiment classifier.

    Args:
        encoder: a transformer encoder whose output exposes
            ``last_hidden_state`` of shape (batch, seq, hidden_size).
        hidden_size: width of the encoder's hidden states (768 for
            indobert-base).
        num_topic_labels: number of topic classes.
        num_sentiment_labels: number of sentiment classes.
    """

    def __init__(self, encoder, hidden_size=768, num_topic_labels=5, num_sentiment_labels=3):
        super(MultiTaskModel, self).__init__()
        # Attribute names must stay stable: they form the state_dict keys
        # that the saved checkpoint ("model.pt") is loaded against.
        self.encoder = encoder
        self.topik_classifier = nn.Linear(hidden_size, num_topic_labels)
        self.sentiment_classifier = nn.Linear(hidden_size, num_sentiment_labels)

    def forward(self, input_ids, attention_mask):
        """Return ``(topic_logits, sentiment_logits)`` for a token batch."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # First position of the sequence is the [CLS] summary vector.
        pooled = encoded.last_hidden_state[:, 0, :]
        return self.topik_classifier(pooled), self.sentiment_classifier(pooled)
# Build the multi-task model and restore fine-tuned weights (CPU inference).
model = MultiTaskModel(encoder)
model.load_state_dict(torch.load("model.pt", map_location=torch.device("cpu")))
model.eval()

# Tokenizer/model pair for Indonesian abstractive summarization.
sum_tokenizer = AutoTokenizer.from_pretrained("flax-community/bart-base-indonesian-summarization")
sum_model = AutoModelForSeq2SeqLM.from_pretrained("flax-community/bart-base-indonesian-summarization")

# Index -> human-readable label maps; order must match the training labels.
topik_labels = ["Produk", "Layanan", "Pengiriman", "Pembatalan", "Lainnya"]
sentiment_labels = ["Negatif", "Netral", "Positif"]
# Analysis function.
def analyze_text(text):
    """Classify *text* (topic + sentiment) and generate a short summary.

    Args:
        text: raw Indonesian conversation text.

    Returns:
        A markdown-formatted report string with topic, sentiment and summary.
    """
    # --- Classification ---
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        # BUG FIX: BERT tokenizers also emit `token_type_ids`, which
        # MultiTaskModel.forward(input_ids, attention_mask) does not accept,
        # so the original `model(**inputs)` raised a TypeError at runtime.
        # Pass only the tensors the model expects.
        topik_logits, sentiment_logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    topik = torch.argmax(topik_logits, dim=1).item()
    sentiment = torch.argmax(sentiment_logits, dim=1).item()

    # --- Summarization ---
    input_summary = f"Ringkas percakapan berikut: {text}"
    inputs_sum = sum_tokenizer.encode(input_summary, return_tensors="pt", max_length=512, truncation=True)
    # Greedy decoding (do_sample=False) keeps the summary deterministic.
    summary_ids = sum_model.generate(inputs_sum, max_length=50, min_length=5, do_sample=False)
    summary = sum_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # --- Final report ---
    result = f"""**HASIL ANALISIS**
Topik: {topik_labels[topik]}
Sentimen: {sentiment_labels[sentiment]}
Ringkasan: {summary}"""
    return result
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## Analisis Topik, Sentimen, dan Ringkasan Pelanggan")

    with gr.Row():
        input_text = gr.Textbox(label="Masukkan Teks Percakapan")
        output_text = gr.Textbox(label="Hasil Analisis")

    with gr.Row():
        clear_btn = gr.Button("Clear")
        submit_btn = gr.Button("Analisa")

    # Wiring: run the analysis on submit; blank both boxes on clear.
    submit_btn.click(analyze_text, inputs=input_text, outputs=output_text)
    clear_btn.click(lambda: ("", ""), inputs=[], outputs=[input_text, output_text])

demo.launch()