|
import gradio as gr |
|
import torch |
|
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification |
|
from datetime import datetime |
|
import csv |
|
import os |
|
|
|
|
|
# Fine-tuned DistilBERT sequence classifier for course-review sentiment.
model = DistilBertForSequenceClassification.from_pretrained(
    "debojit01/course-review-sentiment"
)
# The tokenizer is loaded from the base DistilBERT checkpoint rather than from
# the fine-tuned repository above.
# NOTE(review): assumes the fine-tuned model kept the base uncased vocabulary
# — confirm.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# Human-readable class names, indexed by the model's output position.
# NOTE(review): order assumed to match the model's id2label mapping — confirm.
labels = [
    "negative",
    "neutral",
    "positive",
]
|
|
|
|
|
# CSV files that accumulate predictions and user corrections.
log_path = "logs.csv"
corrections_path = "corrections.csv"


def _init_csv(target, columns):
    """Create ``target`` containing only a header row, unless it already exists."""
    if os.path.exists(target):
        return
    with open(target, mode='w', newline='') as handle:
        csv.writer(handle).writerow(columns)


# Write each file's header once; existing files are left untouched.
_init_csv(log_path, ["timestamp", "input_text", "predicted_label"])
_init_csv(
    corrections_path,
    ["timestamp", "input_text", "predicted_label", "user_correction"],
)
|
|
|
def classify_review(text):
    """Classify a course review and append the prediction to the log CSV.

    Args:
        text: Review text entered by the user.

    Returns:
        A 3-tuple ``(scores, text, predicted_label)``:
        ``scores`` maps every name in ``labels`` to its softmax probability
        (the mapping ``gr.Label`` displays), ``text`` is the input echoed
        back, and ``predicted_label`` is the highest-probability label.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single string is tokenized as a batch of one; keep only that row so
    # argmax and the score mapping operate on one example's probabilities.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_label = labels[int(torch.argmax(probs, dim=-1))]

    # Explicit UTF-8 so reviews containing non-ASCII characters can be logged
    # even where the platform default encoding is not UTF-8.
    with open(log_path, mode='a', newline='', encoding='utf-8') as file:
        csv.writer(file).writerow([datetime.now().isoformat(), text, predicted_label])

    scores = {label: float(prob) for label, prob in zip(labels, probs)}
    return scores, text, predicted_label
|
|
|
def save_correction(text, predicted_label, user_correction):
    """Record a user's correction of a prediction in the corrections CSV.

    A row is appended only when the user selected a label that differs from
    the model's prediction.

    Args:
        text: The review text that was classified.
        predicted_label: The label the model predicted for ``text``.
        user_correction: The label chosen in the dropdown; may be ``None`` or
            empty when nothing has been selected.

    Returns:
        A status message for the UI describing what was (or was not) saved.
    """
    if not user_correction:
        # Without this guard a missing selection was written as a blank
        # correction and reported to the user as recorded.
        return "Please select the correct sentiment before submitting."

    if user_correction == predicted_label:
        # Nothing differs from the prediction, so nothing is written.
        return f"No correction recorded: the prediction was already {predicted_label}."

    # Explicit UTF-8 so non-ASCII review text can be written on any platform.
    with open(corrections_path, mode='a', newline='', encoding='utf-8') as file:
        csv.writer(file).writerow(
            [datetime.now().isoformat(), text, predicted_label, user_correction]
        )
    return f"Thanks! Correction recorded: {user_correction}"
|
|
|
|
|
with gr.Blocks() as demo:
    # NOTE(review): the first character of each heading string below looks
    # like a mis-encoded emoji; confirm the intended glyph before editing it.
    gr.Markdown("# π Course Review Sentiment Classifier")
    gr.Markdown("Enter a course review and get the sentiment prediction. You can correct the result if needed.")

    input_text = gr.Textbox(lines=4, placeholder="Enter course review here...")
    output_label = gr.Label(num_top_classes=3)
    predict_btn = gr.Button("Classify")

    # Correction controls start hidden and are revealed after a prediction.
    with gr.Row(visible=False) as correction_row:
        gr.Markdown("### β Is the prediction wrong?")
        correction_dropdown = gr.Dropdown(choices=labels, label="Correct Sentiment")
        submit_btn = gr.Button("Submit Correction")
        correction_status = gr.Textbox(interactive=False)

    # Invisible fields holding the last classified text and its predicted
    # label, so a correction can be saved against them.
    hidden_text = gr.Textbox(visible=False)
    hidden_pred = gr.Textbox(visible=False)

    def show_correction_ui():
        """Reveal the correction row and clear any previous status message."""
        return gr.update(visible=True), ""

    # `.then` below is registered without `inputs`, so its handler must take
    # no parameters; a handler with required parameters raises TypeError and
    # the correction row would never appear.
    predict_btn.click(
        classify_review,
        inputs=input_text,
        outputs=[output_label, hidden_text, hidden_pred],
    ).then(show_correction_ui, outputs=[correction_row, correction_status])

    submit_btn.click(
        save_correction,
        inputs=[hidden_text, hidden_pred, correction_dropdown],
        outputs=correction_status,
    )


if __name__ == "__main__":
    demo.launch()