File size: 3,728 Bytes
d1c870a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fdff7e
d1c870a
 
1b55ac6
 
 
86b4e16
 
 
d1c870a
0ae91f3
d1c870a
0ae91f3
 
 
 
 
 
f693b94
d1c870a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643c698
d1c870a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
import os

# Page-level settings (title, icon, layout, sidebar, menu).
# NOTE: this call has to stay the first Streamlit command in the script.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items=None,
    page_title="QDF Classifier",
    page_icon="🔍",
)

# Hugging Face model repository and optional access token (read from the
# HF_TOKEN environment variable; None when the variable is unset).
MODEL_ID = "dejanseo/QDF-large"
HF_TOKEN = os.getenv("HF_TOKEN")


@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer():
    """Load the classifier and its tokenizer once per server process.

    Streamlit re-executes this script on every user interaction. Without
    caching, the model and tokenizer would be re-created on each rerun.
    ``st.cache_resource`` keeps a single shared instance and hands it back
    on subsequent reruns (the objects are shared across sessions).

    Returns:
        A ``(model, tokenizer)`` tuple; the model is switched to eval mode.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        low_cpu_mem_usage=True,
    ).eval()
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    return loaded_model, loaded_tokenizer


# Module-level names used by the rest of the script.
model, tokenizer = _load_model_and_tokenizer()

def classify(prompt: str):
    """Classify a single prompt with the module-level model.

    The prompt is tokenized (truncated to at most 512 tokens), passed through
    the model without gradient tracking, and the logits are turned into
    probabilities with a softmax over the last dimension.

    Args:
        prompt: Text to classify.

    Returns:
        A ``(pred, confidence)`` tuple: the index of the highest-probability
        class and that class's probability as a Python float.
    """
    encoded = tokenizer(
        prompt,
        max_length=512,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        scores = model(**encoded).logits
        # Drop the batch dimension so the result can be indexed by class id.
        distribution = torch.softmax(scores, dim=-1).squeeze().cpu()

    best_class = torch.argmax(distribution).item()
    return best_class, distribution[best_class].item()

# Custom CSS: pull in the Montserrat font from Google Fonts, force it onto the
# common HTML elements and the "css-" prefixed classes, and hide the page header.
_STYLE_OVERRIDES = """
    <style>
        @import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;600&display=swap');

        html, body, div, span, input, label, textarea, button, h1, h2, p, table {
            font-family: 'Montserrat', sans-serif !important;
        }

        [class^="css-"], [class*=" css-"] {
            font-family: 'Montserrat', sans-serif !important;
        }

        header {visibility: hidden;}
    </style>
"""

# unsafe_allow_html is needed so the <style> tag is emitted as markup.
st.markdown(_STYLE_OVERRIDES, unsafe_allow_html=True)

# UI: heading, attribution and a one-line description of the tool.
st.title("QDF Classifier")
st.write("Built by [**Dejan AI**](https://dejan.ai)")
st.write("This classifier determines whether query deserves freshness.")

# Sample queries, one per line. Shown as the text area's placeholder and used
# as the input when the text area is left empty.
example_text = "\n".join([
    "how would a cat describe a dog",
    "how to reset a Nest thermostat",
    "write a poem about time",
    "is there a train strike in London today",
    "summarize the theory of relativity",
    "who won the champions league last year",
    "explain quantum computing to a child",
    "weather in tokyo tomorrow",
    "generate a social media post for Earth Day",
    "what is the latest iPhone model",
])

user_input = st.text_area(
    label="Enter one search query per line:",
    placeholder=example_text,
)

def _parse_prompts(text):
    """Return the non-blank lines of *text*, each stripped of surrounding whitespace."""
    return [stripped for stripped in map(str.strip, text.split("\n")) if stripped]


if st.button("Classify"):
    # Use the typed queries; fall back to the built-in examples when the text
    # area contains nothing but whitespace. Both sources go through the same
    # parser so blank lines are dropped consistently.
    prompts = _parse_prompts(user_input) or _parse_prompts(example_text)

    if not prompts:
        st.warning("Please enter at least one prompt.")
    else:
        info_box = st.info("Processing... results will appear below one by one.")
        table_placeholder = st.empty()
        results = []

        # Loop-invariant column setup: show Confidence as a 0-1 progress bar.
        column_config = {
            "Confidence": st.column_config.ProgressColumn(
                label="Confidence",
                min_value=0.0,
                max_value=1.0,
                format="%.4f",
            )
        }

        for p in prompts:
            with st.spinner(f"Classifying: {p[:50]}..."):
                label, conf = classify(p)
                results.append({
                    "Prompt": p,
                    # Class index 1 is reported as "Yes", anything else as "No".
                    "QDF": "Yes" if label == 1 else "No",
                    "Confidence": round(conf, 4),
                })
                # Redraw the table in place after every prompt so results
                # appear incrementally. A read-only dataframe is used because
                # this is output only: an editable widget would trigger a rerun
                # on edit, and the button-gated results would disappear.
                table_placeholder.dataframe(
                    pd.DataFrame(results),
                    column_config=column_config,
                    hide_index=True,
                )
        info_box.empty()

        # Promo message shown only after results
        st.subheader("Working together.")
        st.write("[**Schedule a call**](https://dejan.ai/call/) to see how we can help you.")