dejanseo commited on
Commit
d1c870a
·
verified ·
1 Parent(s): 342f955

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import pandas as pd
6
+ import os
7
+
8
+ # Theme configuration - MUST BE FIRST STREAMLIT COMMAND
9
+ st.set_page_config(
10
+ page_title="QDF Classifier",
11
+ page_icon="🔍",
12
+ layout="wide",
13
+ initial_sidebar_state="collapsed",
14
+ menu_items=None
15
+ )
16
+
17
+ MODEL_ID = "dejanseo/QDF"
18
+ HF_TOKEN = os.getenv("HF_TOKEN")
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, token=HF_TOKEN)
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
23
+ model.to(device)
24
+ model.eval()
25
+
26
+ def classify(prompt: str):
27
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
28
+ with torch.no_grad():
29
+ logits = model(**inputs).logits
30
+ probs = torch.softmax(logits, dim=-1).squeeze().cpu()
31
+ pred = torch.argmax(probs).item()
32
+ confidence = probs[pred].item()
33
+ return pred, confidence
34
+
35
+ # Font and style overrides
36
+ st.markdown("""
37
+ <style>
38
+ @import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@400;600&display=swap');
39
+
40
+ html, body, div, span, input, label, textarea, button, h1, h2, p, table {
41
+ font-family: 'Montserrat', sans-serif !important;
42
+ }
43
+
44
+ [class^="css-"], [class*=" css-"] {
45
+ font-family: 'Montserrat', sans-serif !important;
46
+ }
47
+
48
+ header {visibility: hidden;}
49
+ </style>
50
+ """, unsafe_allow_html=True)
51
+
52
+ # UI
53
+ st.title("QDF Classifier")
54
+ st.write("Built by [**Dejan AI**](https://dejan.ai)")
55
+ st.write("This classifier determines whether query deserves freshness.")
56
+
57
+ # Placeholder example prompts
58
+ example_text = """how would a cat describe a dog
59
+ how to reset a Nest thermostat
60
+ write a poem about time
61
+ is there a train strike in London today
62
+ summarize the theory of relativity
63
+ who won the champions league last year
64
+ explain quantum computing to a child
65
+ weather in tokyo tomorrow
66
+ generate a social media post for Earth Day
67
+ what is the latest iPhone model"""
68
+
69
+ user_input = st.text_area(
70
+ "Enter one search query per line:",
71
+ placeholder=example_text
72
+ )
73
+
74
+ if st.button("Classify"):
75
+ raw_input = user_input.strip()
76
+ if raw_input:
77
+ prompts = [line.strip() for line in raw_input.split("\n") if line.strip()]
78
+ else:
79
+ prompts = [line.strip() for line in example_text.split("\n")]
80
+
81
+ if not prompts:
82
+ st.warning("Please enter at least one prompt.")
83
+ else:
84
+ info_box = st.info("Processing... results will appear below one by one.")
85
+ table_placeholder = st.empty()
86
+ results = []
87
+
88
+ for p in prompts:
89
+ with st.spinner(f"Classifying: {p[:50]}..."):
90
+ label, conf = classify(p)
91
+ results.append({
92
+ "Prompt": p,
93
+ "Grounding": "Yes" if label == 1 else "No",
94
+ "Confidence": round(conf, 4)
95
+ })
96
+ df = pd.DataFrame(results)
97
+ table_placeholder.data_editor(
98
+ df,
99
+ column_config={
100
+ "Confidence": st.column_config.ProgressColumn(
101
+ label="Confidence",
102
+ min_value=0.0,
103
+ max_value=1.0,
104
+ format="%.4f"
105
+ )
106
+ },
107
+ hide_index=True,
108
+ )
109
+ info_box.empty()
110
+
111
+ # Promo message shown only after results
112
+ st.subheader("Working together.")
113
+ st.write("[**Schedule a call**](https://dejan.ai/call/) to see how we can help you.")