# uas-nlp2 / app.py
# (Hugging Face Space header: ElizabethSrgh — "Update app.py", commit 53eb003, verified)
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
import gradio as gr
# Load the shared tokenizer and IndoBERT encoder used for classification.
tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")
encoder = AutoModel.from_pretrained("indobenchmark/indobert-base-p1")
# Multi-task model definition: one shared encoder, two linear heads.
class MultiTaskModel(nn.Module):
    """Shared-encoder multi-task head for topic and sentiment classification.

    The [CLS] (first-token) embedding from the encoder feeds two independent
    linear classifiers: one over topic labels, one over sentiment labels.
    """

    def __init__(self, encoder, hidden_size=768, num_topic_labels=5, num_sentiment_labels=3):
        super().__init__()
        self.encoder = encoder
        self.topik_classifier = nn.Linear(hidden_size, num_topic_labels)
        self.sentiment_classifier = nn.Linear(hidden_size, num_sentiment_labels)

    def forward(self, input_ids, attention_mask, **encoder_kwargs):
        """Return (topic_logits, sentiment_logits) for a batch.

        Extra keyword arguments (e.g. ``token_type_ids``, which BERT
        tokenizers also emit) are forwarded to the encoder. Previously
        ``model(**inputs)`` raised TypeError because forward() only
        accepted the two named tensors.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **encoder_kwargs)
        # Use the [CLS] token representation as the pooled sentence vector.
        cls_output = outputs.last_hidden_state[:, 0, :]
        topik_logits = self.topik_classifier(cls_output)
        sentiment_logits = self.sentiment_classifier(cls_output)
        return topik_logits, sentiment_logits
# Instantiate the multi-task model and load trained weights (CPU-only).
model = MultiTaskModel(encoder)
model.load_state_dict(torch.load("model.pt", map_location=torch.device("cpu")))
model.eval()
# Load tokenizer and seq2seq model for Indonesian summarization.
sum_tokenizer = AutoTokenizer.from_pretrained("flax-community/bart-base-indonesian-summarization")
sum_model = AutoModelForSeq2SeqLM.from_pretrained("flax-community/bart-base-indonesian-summarization")
# Index -> human-readable label mappings.
# NOTE(review): ordering must match the label encoding used at training time — confirm.
topik_labels = ["Produk", "Layanan", "Pengiriman", "Pembatalan", "Lainnya"]
sentiment_labels = ["Negatif", "Netral", "Positif"]
# Analysis pipeline: classification + summarization for one input text.
def analyze_text(text):
    """Classify *text* by topic and sentiment, then summarize it.

    Returns a Markdown-formatted string containing the predicted topic
    label, sentiment label, and generated summary.
    """
    # --- Classification ---
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        # Pass only the tensors forward() accepts: the tokenizer also emits
        # token_type_ids, so the previous model(**inputs) raised TypeError.
        topik_logits, sentiment_logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    topik = torch.argmax(topik_logits, dim=1).item()
    sentiment = torch.argmax(sentiment_logits, dim=1).item()
    # --- Summarization (greedy decoding, capped at 50 tokens) ---
    input_summary = f"Ringkas percakapan berikut: {text}"
    inputs_sum = sum_tokenizer.encode(input_summary, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = sum_model.generate(inputs_sum, max_length=50, min_length=5, do_sample=False)
    summary = sum_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    # --- Final output: map predicted indices through module-level labels ---
    result = f"""**HASIL ANALISIS**
Topik: {topik_labels[topik]}
Sentimen: {sentiment_labels[sentiment]}
Ringkasan: {summary}"""
    return result
# Gradio UI: single text input -> combined analysis output.
with gr.Blocks() as demo:
    gr.Markdown("## Analisis Topik, Sentimen, dan Ringkasan Pelanggan")
    with gr.Row():
        input_text = gr.Textbox(label="Masukkan Teks Percakapan")
        output_text = gr.Textbox(label="Hasil Analisis")
    with gr.Row():
        clear_btn = gr.Button("Clear")
        submit_btn = gr.Button("Analisa")
    # Wire events: "Analisa" runs the pipeline; "Clear" resets both boxes.
    submit_btn.click(analyze_text, inputs=input_text, outputs=output_text)
    clear_btn.click(lambda: ("", ""), inputs=[], outputs=[input_text, output_text])
demo.launch()