|
import gradio as gr |
|
from transformers import pipeline |
|
|
|
|
|
# Hugging Face model repo that supplies both the fine-tuned classifier
# weights and its matching tokenizer.
_MODEL_ID = "AfroLogicInsect/topic-model-analysis-model"

# Text-classification pipeline asked to return a score for every label, not
# just the best one. For a single input string this module expects a nested
# list of {"label": ..., "score": ...} dicts (the caller indexes [0]).
# NOTE(review): recent transformers releases deprecate return_all_scores in
# favour of top_k=None, which returns a *flat* list for a single string --
# confirm the consumer handles that shape before switching.
topic_pipeline = pipeline(
    task="text-classification",
    tokenizer=_MODEL_ID,
    model=_MODEL_ID,
    return_all_scores=True,
)
|
|
|
def predict_topics(text): |
|
if not text.strip(): |
|
return [["Please enter some text", 0.0]] |
|
|
|
results = topic_pipeline(text) |
|
sorted_results = sorted(results[0], key=lambda x: x['score'], reverse=True)[:5] |
|
|
|
|
|
return [[res['label'], round(res['score'], 3)] for res in sorted_results] |
|
|
|
# Free-text box the user types the document to classify into.
_text_box = gr.Textbox(label="Enter text")

# Results table. type="array" has Gradio exchange the table as a plain list
# of row lists, which is the shape predict_topics returns.
_topic_table = gr.Dataframe(
    type="array",
    label="Top 5 Predicted Topics",
    headers=["Topic", "Confidence"],
)

# Single-input / single-output web UI around the classifier.
iface = gr.Interface(
    fn=predict_topics,
    inputs=_text_box,
    outputs=_topic_table,
)

# NOTE(review): launch() runs at import time, and share=True asks Gradio for
# a public share link. Consider an `if __name__ == "__main__":` guard if
# this module is ever imported rather than run directly.
iface.launch(share=True)