princemaxp committed on
Commit a2f0146 · verified · 1 Parent(s): a8452a4

Update app.py

Files changed (1)
  1. app.py +104 -65
app.py CHANGED
@@ -1,91 +1,130 @@
- import os
  import gradio as gr
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
  from datasets import load_dataset, Dataset

- # ---------- CONFIG ----------
- MODEL_ID = "YOUR_MODEL_ID_HF"  # Replace with your HF model ID
  DATASET_NAME = "guardian-ai-qna"
- SYSTEM_PROMPT = "You are Guardian AI, a cybersecurity expert. Answer concisely."
-
- # ---------- LOAD TOKENIZER & MODEL ----------
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
- device = 0 if torch.cuda.is_available() else -1
- generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)

- # ---------- LOAD DATASET ----------
  try:
-     dataset = load_dataset("huggingface", DATASET_NAME, split="train")
  except:
      dataset = Dataset.from_dict({"question": [], "answer": []})

- # ---------- EMBEDDING HELPER ----------
- from sentence_transformers import SentenceTransformer, util
- embedder = SentenceTransformer("all-MiniLM-L6-v2")

- # Cache embeddings in memory
  if len(dataset) > 0:
      dataset_embeddings = embedder.encode(dataset["question"], convert_to_tensor=True)
  else:
-     dataset_embeddings = []

- # ---------- SAVE QNA FUNCTION ----------
  def save_qna(question, answer):
      global dataset, dataset_embeddings
      new_entry = Dataset.from_dict({"question": [question], "answer": [answer]})
      dataset = Dataset.from_dict({
-         "question": dataset["question"] + [question],
-         "answer": dataset["answer"] + [answer]
      })
-     # update embeddings
-     dataset_embeddings.append(embedder.encode(question, convert_to_tensor=True))
-     # push to HF dataset
-     dataset.push_to_hub(DATASET_NAME, token=os.environ.get("HF_TOKEN"))
-
- # ---------- RETRIEVE SIMILAR QNA ----------
- def retrieve_similar_qna(query, top_k=3):
-     if len(dataset) == 0:
-         return ""
-     query_emb = embedder.encode(query, convert_to_tensor=True)
-     similarities = util.cos_sim(query_emb, dataset_embeddings)[0]
-     top_results = similarities.topk(k=min(top_k, len(similarities)))
-     context = ""
-     for idx in top_results.indices:
-         context += f"Q: {dataset[idx]['question']}\nA: {dataset[idx]['answer']}\n"
-     return context
-
- # ---------- CHAT FUNCTION ----------
- def chat(history, user_input):
-     context = retrieve_similar_qna(user_input)
-     prompt = SYSTEM_PROMPT
-     if context:
-         prompt += f"\n\nMemory of past Q&A:\n{context}"
-     prompt += f"\n\nUser: {user_input}\nGuardian AI:"

-     with torch.no_grad():
-         result = generator(
-             prompt,
-             max_new_tokens=150,
-             do_sample=True,
-             temperature=0.6,
-             top_p=0.85
-         )[0]["generated_text"]
-
-     response = result.split("Guardian AI:")[-1].strip()
-     history.append((user_input, response))
      save_qna(user_input, response)
      return history, history

- # ---------- GRADIO APP ----------
  with gr.Blocks() as app:
      chatbot = gr.Chatbot()
-     state = gr.State([])
-     with gr.Row():
-         user_msg = gr.Textbox(label="Type your message")
-         send_btn = gr.Button("Send")
-
-     send_btn.click(chat, [state, user_msg], [chatbot, state])

- app.launch(share=True)
+ import time
  import gradio as gr
  from datasets import load_dataset, Dataset
+ from huggingface_hub import hf_hub_download
+ from sentence_transformers import SentenceTransformer, util
+ import torch

+ # ---------------------------
+ # CONFIGURATION
+ # ---------------------------
+ HF_TOKEN = "<YOUR_HF_TOKEN>"  # set your HF token
  DATASET_NAME = "guardian-ai-qna"
+ MAX_QUESTIONS = 5  # max questions per TIME_WINDOW
+ TIME_WINDOW = 3600  # 1 hour in seconds
+ EMBED_MODEL = "all-MiniLM-L6-v2"  # small but effective embedding model

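Review note: the previous version read the token with `os.environ.get("HF_TOKEN")`; hardcoding `HF_TOKEN` in app.py leaks the secret if the Space repo is public. A minimal sketch of keeping the environment-variable approach (assuming the token is stored as a Space secret named `HF_TOKEN`):

```python
import os

# Read the write token from the environment (e.g. a Space secret) instead of
# committing it to the repository; fail fast if it is missing.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise RuntimeError("HF_TOKEN is not set; add it as a Space secret.")
```

Also note that `hf_hub_download` is imported but never used below.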
+ # ---------------------------
+ # LOAD OR CREATE DATASET
+ # ---------------------------
  try:
+     dataset = load_dataset(DATASET_NAME, use_auth_token=HF_TOKEN)
+     dataset = dataset["train"]
  except:
      dataset = Dataset.from_dict({"question": [], "answer": []})

+ # ---------------------------
+ # EMBEDDING MODEL
+ # ---------------------------
+ embedder = SentenceTransformer(EMBED_MODEL)

+ # Precompute embeddings for existing Q&A
  if len(dataset) > 0:
      dataset_embeddings = embedder.encode(dataset["question"], convert_to_tensor=True)
  else:
+     dataset_embeddings = torch.empty((0, embedder.get_sentence_embedding_dimension()))
+
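Review note: newer `datasets` releases deprecate `use_auth_token` in favor of `token`, and the bare `except:` hides every failure, including a mistyped dataset name. A narrower sketch with the same fallback behavior:

```python
from datasets import load_dataset, Dataset

try:
    # `token` supersedes the deprecated `use_auth_token` argument
    dataset = load_dataset(DATASET_NAME, token=HF_TOKEN)["train"]
except Exception:
    # First run: start with an empty Q&A store
    dataset = Dataset.from_dict({"question": [], "answer": []})
```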
+ # ---------------------------
+ # USER RATE LIMITING
+ # ---------------------------
+ user_limits = {}
+
+ def check_rate_limit(session_id):
+     current_time = time.time()
+     if session_id not in user_limits:
+         user_limits[session_id] = {"count": 0, "start_time": current_time}
+
+     user_data = user_limits[session_id]
+     if current_time - user_data["start_time"] > TIME_WINDOW:
+         user_data["count"] = 0
+         user_data["start_time"] = current_time
+
+     if user_data["count"] >= MAX_QUESTIONS:
+         return False, f"You have reached the max of {MAX_QUESTIONS} questions. Please wait before asking more."
+
+     user_data["count"] += 1
+     return True, None
+
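The limiter is a fixed window per `session_id`, held in process memory, so counts reset whenever the Space restarts. A quick usage example:

```python
# Simulate one session exhausting its fixed-window quota
for i in range(MAX_QUESTIONS + 1):
    allowed, message = check_rate_limit("session-123")
    print(i, allowed, message)
# Calls 0..4 print (True, None); call 5 prints the "reached the max"
# message until TIME_WINDOW seconds have passed since the window opened.
```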
+ # ---------------------------
+ # HELPER FUNCTIONS
+ # ---------------------------
+ def find_similar_answer(user_input):
+     if len(dataset) == 0:
+         return None
+
+     query_emb = embedder.encode(user_input, convert_to_tensor=True)
+     scores = util.cos_sim(query_emb, dataset_embeddings)
+     top_idx = torch.argmax(scores)
+     top_score = scores[0][top_idx].item()
+
+     if top_score > 0.6:  # threshold for similarity
+         return dataset["answer"][top_idx]
+     return None

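Since `util.cos_sim` returns a `(1, N)` score matrix here, `torch.argmax(scores)` over the flattened tensor is exactly the column index of the best-matching stored question. A small illustrative check (the query strings are placeholders):

```python
# Shape sanity check for the retrieval path
q = embedder.encode("example query", convert_to_tensor=True)
scores = util.cos_sim(q, dataset_embeddings)  # shape (1, N)
assert scores.shape[0] == 1

# Anything scoring at or below the 0.6 cutoff yields None
print(find_similar_answer("something completely unrelated"))
```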
  def save_qna(question, answer):
      global dataset, dataset_embeddings
      new_entry = Dataset.from_dict({"question": [question], "answer": [answer]})
      dataset = Dataset.from_dict({
+         "question": dataset["question"] + new_entry["question"],
+         "answer": dataset["answer"] + new_entry["answer"]
      })
+
+     # update embeddings incrementally
+     new_emb = embedder.encode([question], convert_to_tensor=True)
+     if len(dataset_embeddings) == 0:
+         dataset_embeddings = new_emb
+     else:
+         dataset_embeddings = torch.vstack([dataset_embeddings, new_emb])
+
+     # save to HF dataset (push to hub)
+     dataset.push_to_hub(DATASET_NAME, token=HF_TOKEN)

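Review note: `push_to_hub` re-uploads the dataset on every saved message, adding a network round-trip per chat turn. A sketch of one way to amortize it; `PUSH_EVERY` and the counter are illustrative and not part of this commit:

```python
PUSH_EVERY = 10  # illustrative batch size
_unpushed = 0    # module-level counter of rows not yet uploaded

def maybe_push_to_hub():
    """Upload only once every PUSH_EVERY new rows."""
    global _unpushed
    _unpushed += 1
    if _unpushed >= PUSH_EVERY:
        dataset.push_to_hub(DATASET_NAME, token=HF_TOKEN)
        _unpushed = 0
```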
+ # ---------------------------
+ # MAIN CHAT FUNCTION
+ # ---------------------------
+ def chat(history, user_input, session_id="default"):
+     # Rate limit check
+     allowed, message = check_rate_limit(session_id)
+     if not allowed:
+         history.append(("System", message))
+         return history, history
+
+     # Check existing similar Q&A
+     response = find_similar_answer(user_input)
+
+     if not response:
+         # Fallback / simple generative response
+         response = f"Guardian AI: Sorry, I don’t know the answer yet. I’m learning!"
+
+     # Save new Q&A for incremental learning
      save_qna(user_input, response)
+
+     # Update chat history
+     history.append((user_input, response))
      return history, history

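Review note: the fallback reply is also passed to `save_qna`, so once stored, a similar future question can retrieve "Sorry, I don’t know…" as its answer. A minimal guard (sketch):

```python
# Persist only Q&A pairs that carry a real answer, not the fallback text
if not response.startswith("Guardian AI: Sorry"):
    save_qna(user_input, response)
```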
+ # ---------------------------
+ # GRADIO INTERFACE
+ # ---------------------------
  with gr.Blocks() as app:
      chatbot = gr.Chatbot()
+     msg = gr.Textbox(label="Your question")
+     session_state = gr.State("default")  # default session
+
+     def user_submit(message, history, session_id):
+         return chat(history, message, session_id)
+
+     msg.submit(user_submit, inputs=[msg, chatbot, session_state], outputs=[chatbot, chatbot])

+ # Launch app
+ app.launch(server_name="0.0.0.0", server_port=7860, share=True)
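Review note: `session_state` is always `"default"`, so the rate limit above is shared by all visitors rather than applied per user. One way to scope it per client, using Gradio's `gr.Request` (a sketch; a cookie- or uuid-based `gr.State` would be more robust behind proxies):

```python
def user_submit(message, history, request: gr.Request):
    # Rough per-user key: the client IP from the incoming request
    session_id = request.client.host if request else "default"
    return chat(history, message, session_id)

msg.submit(user_submit, inputs=[msg, chatbot], outputs=[chatbot, chatbot])
```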