Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
@st.cache_resource
def pipeline_getter():
    """Load the tokenizer, the fine-tuned classifier and the label mapping.

    Streamlit re-executes the whole script on every widget interaction, so the
    loader is wrapped in ``st.cache_resource``: the objects are built on the
    first run and the same instances are handed back on later reruns instead
    of being reloaded each time.

    Returns:
        A ``(tokenizer, model, mapping)`` tuple: the ``distilbert-base-cased``
        tokenizer, the ``KemmerEdition/my-distill-classifier`` sequence
        classification model, and the squeezed values array read from
        ``./categories.csv``.
    """
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
    model = AutoModelForSequenceClassification.from_pretrained(
        'KemmerEdition/my-distill-classifier'
    )
    # NOTE(review): presumably one category label per row, in the same order
    # as the model's output logits -- confirm against the training code.
    mapping = pd.read_csv('./categories.csv').values.squeeze()
    return tokenizer, model, mapping


# Module-level handles; the prediction function reads these as globals.
tokenizer, model, mapping = pipeline_getter()
def _count_selected_categories(cumulative_probs, confidence_level, max_categories):
    """Return how many top-ranked categories to keep.

    Keeps the shortest prefix whose cumulative score reaches
    ``confidence_level``, capped at ``max_categories`` (but never fewer than
    one). When the threshold is never reached, every category is a candidate
    and only the cap applies.
    """
    total = len(cumulative_probs)
    reached = np.flatnonzero(cumulative_probs >= confidence_level)
    needed = int(reached[0]) + 1 if reached.size else total
    return min(needed, max(max_categories, 1), total)


def predict_article_categories_with_confidence(
    text_data,
    abstract_text=None,
    confidence_level=0.95,
    max_categories=9
):
    """Score a paper against every category and pick the most likely ones.

    The title (and optional abstract, passed as the tokenizer's text pair) is
    run through the module-level ``tokenizer`` and ``model``. Categories are
    ranked by their sigmoid score and taken in that order until the running
    sum of scores reaches ``confidence_level`` or ``max_categories`` entries
    have been taken, whichever happens first.

    Args:
        text_data: Paper title.
        abstract_text: Optional abstract; ``None`` to classify the title only.
        confidence_level: Cumulative-score threshold at which selection stops.
        max_categories: Upper bound on the number of returned categories; at
            least one category is always returned.

    Returns:
        A dict with keys:
            ``probabilities``: flat array of per-category sigmoid scores;
            ``predicted_categories``: selected labels, highest score first;
            ``confidence``: cumulative score of the selected categories;
            ``top_category``: label with the highest score;
            ``used_categories``: number of selected categories.
    """
    tokenized_input = tokenizer(
        text=text_data,
        text_pair=abstract_text,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**tokenized_input).logits
    probs = torch.sigmoid(logits).detach().numpy().flatten()

    sorted_indices = np.argsort(probs)[::-1]
    # NOTE(review): sigmoid scores are independent per category and need not
    # sum to 1, so this running sum is a heuristic score rather than a
    # probability mass -- confirm this matches how the model was trained.
    cumulative_probs = np.cumsum(probs[sorted_indices])

    # Previously the selection stayed empty when the threshold was never
    # reached and there were fewer categories than max_categories; now all
    # categories are kept in that case.
    count = _count_selected_categories(
        cumulative_probs, confidence_level, max_categories
    )
    selected_indices = sorted_indices[:count]

    return {
        'probabilities': probs,
        'predicted_categories': [mapping[idx] for idx in selected_indices],
        'confidence': cumulative_probs[count - 1],
        'top_category': mapping[sorted_indices[0]],
        'used_categories': count
    }
# Page-wide CSS for the header, the input/result boxes and the category badges.
st.markdown("""
<style>
.header {
font-size: 36px !important;
color: #1f77b4;
margin-bottom: 20px;
}
.input-box {
background-color: #f0f2f6;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
}
.result-box {
background-color: #e6f3ff;
padding: 20px;
border-radius: 10px;
margin-top: 20px;
}
.category-badge {
display: inline-block;
background-color: #1f77b4;
color: white;
padding: 5px 10px;
margin: 5px;
border-radius: 15px;
font-size: 14px;
}
</style>
""", unsafe_allow_html=True)

st.markdown('<div class="header">Classificator of Paper from arxiv</div>', unsafe_allow_html=True)

# Input area: paper title plus optional abstract.
with st.container():
    # NOTE(review): each st.markdown call is rendered as its own element, so
    # this opening tag and the closing tag below presumably do not wrap the
    # widgets between them -- verify in the rendered page.
    st.markdown('<div class="input-box">', unsafe_allow_html=True)
    title_input = st.text_input('**Here you can write title:**', placeholder="e.g. Quantum Machine Learning Approaches")
    abstract_input = st.text_area('**Here you can write summary from arxiv:**',
                                  placeholder="Paste the abstract here for more accurate categorization...",
                                  height=150)
    st.markdown('</div>', unsafe_allow_html=True)

# Selection parameters, side by side.
col1, col2 = st.columns(2)
with col1:
    confidence_level = st.slider('**Confidence level (%)**', 80, 100, 95)
with col2:
    max_categories = st.slider('**Maximum categories**', 1, 10, 3)

if st.button('**Press F (just press)**', type="primary"):
    # Strip first so whitespace-only input counts as "nothing entered"
    # instead of being sent to the model.
    title_text = title_input.strip()
    abstract_text = abstract_input.strip()
    if title_text:
        with st.spinner('Analyzing paper content...'):
            result = predict_article_categories_with_confidence(
                title_text,
                abstract_text or None,
                # The slider is in percent; the predictor expects a fraction.
                confidence_level=confidence_level / 100,
                max_categories=max_categories
            )
        with st.container():
            st.markdown('<div class="result-box">', unsafe_allow_html=True)
            st.subheader("Categorization Results")
            st.markdown("**Most likely category:**")
            top_probability = np.max(result["probabilities"])
            st.markdown(
                f'<div class="category-badge">{result["top_category"]} (p={top_probability:.3f})</div>',
                unsafe_allow_html=True)
            # The first predicted category is the top one already shown above.
            if len(result["predicted_categories"]) > 1:
                st.markdown("Additional categories:")
                for category in result["predicted_categories"][1:]:
                    st.markdown(f'<div class="category-badge">{category}</div>', unsafe_allow_html=True)
            st.markdown("---")
    else:
        st.warning("Please enter at least the paper title")