# Provenance (copied from the Hugging Face Hub file viewer):
# author: AfroLogicInsect · commit c7fc1c9 "Create app.py" (verified) · 908 bytes
import gradio as gr
from transformers import pipeline
# Repository holding both the fine-tuned topic classifier and its tokenizer.
_MODEL_REPO = "AfroLogicInsect/topic-model-analysis-model"

# Multi-class topic classifier. `return_all_scores=True` asks the pipeline for a
# score on every label instead of only the best one, so the caller can rank them.
# NOTE(review): recent transformers releases warn that `return_all_scores` is
# deprecated and suggest `top_k=None`; the output nesting can differ between the
# two, so check what predict_topics receives before switching.
topic_pipeline = pipeline(
    task="text-classification",
    model=_MODEL_REPO,
    tokenizer=_MODEL_REPO,
    return_all_scores=True,
)
def predict_topics(text, *, top_k=5):
    """Return the highest-scoring topics for *text* as table rows.

    Args:
        text: Input from the Gradio textbox. ``None`` and empty or
            whitespace-only strings are treated as "no input".
        top_k: Maximum (non-negative) number of topics to return. It is
            keyword-only so that Gradio, which passes only the textbox
            value, never sets it. Defaults to 5.

    Returns:
        A list of ``[label, score]`` rows sorted by descending score, each
        score rounded to 3 decimal places. For empty input a single
        placeholder row ``["Please enter some text", 0.0]`` is returned.
    """
    # `not text` also covers None, which would otherwise crash on `.strip()`.
    if not text or not text.strip():
        return [["Please enter some text", 0.0]]

    results = topic_pipeline(text)
    # For a single string the pipeline may return a nested list
    # ([[{label, score}, ...]], the legacy `return_all_scores=True` shape) or a
    # flat list of dicts, depending on transformers version and configuration.
    # Accept both so indexing never lands on a dict.
    scores = results[0] if results and isinstance(results[0], list) else results
    ranked = sorted(scores, key=lambda res: res["score"], reverse=True)[:top_k]
    # Format for gr.Dataframe(type="array"): one [label, score] row per topic.
    return [[res["label"], round(res["score"], 3)] for res in ranked]
# Single-textbox UI whose output is a two-column table of the top topics.
iface = gr.Interface(
    fn=predict_topics,
    inputs=gr.Textbox(label="Enter text"),
    outputs=gr.Dataframe(
        headers=["Topic", "Confidence"],
        label="Top 5 Predicted Topics",
        # "array" matches predict_topics' list-of-[label, score] return value.
        type="array",
    ),
)

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module does not start and block on a web server.
if __name__ == "__main__":
    # share=True asks Gradio for a publicly shareable link.
    # NOTE(review): confirm a public link is intended when running locally;
    # Gradio warns that share=True is not supported on Hugging Face Spaces.
    iface.launch(share=True)