rwitz committed
Commit 4c8f98f · verified · 1 Parent(s): 7ea791a

Update app.py

Files changed (1)
app.py  +27 -11
app.py CHANGED
@@ -1,30 +1,46 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModel
 import torch
+import torch.nn.functional as F

 # Load model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
 model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")

 def get_embedding(text):
-    if len(text) > 250:
-        return "❌ Error: Input exceeds 250 character limit."
-
     inputs = tokenizer(text, return_tensors="pt", truncation=True)
     with torch.no_grad():
         outputs = model(**inputs)
-    # Use [CLS] token embedding (or mean pooling)
-    embedding = outputs.last_hidden_state[:, 0, :].squeeze().tolist()
+    return outputs.last_hidden_state[:, 0, :]  # [CLS] token
+
+def compare_sentences(reference, comparisons):
+    if len(reference) > 250:
+        return "❌ Error: Reference exceeds 250 character limit."
+
+    comparison_list = [s.strip() for s in comparisons.strip().split('\n') if s.strip()]
+    if not comparison_list:
+        return "❌ Error: No comparison sentences provided."
+
+    if any(len(s) > 250 for s in comparison_list):
+        return "❌ Error: One or more comparison sentences exceed 250 characters."
+
+    ref_emb = get_embedding(reference)
+    comp_embs = torch.cat([get_embedding(s) for s in comparison_list], dim=0)
+
+    similarities = F.cosine_similarity(ref_emb, comp_embs).tolist()
+    results = "\n".join([f"Similarity with: \"{s}\"\n→ {round(score, 4)}" for s, score in zip(comparison_list, similarities)])

-    # Show only first 10 dimensions for readability
-    return f"✅ Embedding (first 10 values): {embedding[:10]}..."
+    return results

 demo = gr.Interface(
-    fn=get_embedding,
-    inputs=gr.Textbox(label="Enter a sentence (max 250 characters)", max_lines=3, placeholder="Type your sentence here...", lines=2),
+    fn=compare_sentences,
+    inputs=[
+        gr.Textbox(label="Reference Sentence (max 250 characters)", lines=2, placeholder="Type the reference sentence here..."),
+        gr.Textbox(label="Comparison Sentences (one per line, each max 250 characters)", lines=8, placeholder="Type comparison sentences here, one per line..."),
+    ],
     outputs="text",
-    title="Qwen3 Embedding Demo",
-    description="Generates sentence embeddings using Qwen/Qwen3-Embedding-0.6B. Input must be 250 characters or fewer."
+    title="Qwen3 Embedding Comparison Demo",
+    description="Enter a reference sentence and multiple comparison sentences (one per line). The model computes the cosine similarity between the reference and each comparison."
 )

 if __name__ == "__main__":
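
For reference, a minimal sketch (not part of the commit) of the similarity step that the new compare_sentences performs: F.cosine_similarity broadcasts the single [1, hidden] reference embedding against the [N, hidden] stack of comparison embeddings and returns one score per comparison sentence. The dummy tensors below stand in for real get_embedding() outputs so the snippet runs without downloading Qwen/Qwen3-Embedding-0.6B; the hidden size of 1024 is illustrative.

import torch
import torch.nn.functional as F

# Stand-ins for get_embedding() outputs: shapes [1, hidden] and [N, hidden].
# (1024 is an illustrative hidden size, not necessarily the model's.)
ref_emb = torch.randn(1, 1024)
comp_embs = torch.randn(3, 1024)

# The reference row is broadcast against each comparison row, giving one
# cosine score per comparison sentence, as in compare_sentences above.
similarities = F.cosine_similarity(ref_emb, comp_embs).tolist()
print([round(s, 4) for s in similarities])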