princemaxp committed
Commit e959a70 · verified · 1 Parent(s): 520e6f4

Update app.py

Files changed (1)
  1. app.py +116 -131
app.py CHANGED
@@ -1,135 +1,120 @@
- import os
- import time
- from datetime import datetime, timedelta
- import gradio as gr
- from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
- from huggingface_hub import HfFolder
- from sentence_transformers import SentenceTransformer, util
- import torch
  import requests
-
- # ================================
- # CONFIG
- # ================================
- MODEL_TOKEN = os.environ.get("HF_TOKEN")  # for model usage
- DATASET_TOKEN = os.environ.get("dataset_HF_TOKEN")  # for dataset updates
- DATASET_NAME = "guardian-ai-qna"
-
- MAX_QUERIES = 5  # max queries per user per window
- WINDOW_HOURS = 1  # time window for rate limiting
-
- # Rate limiter store
- user_queries = {}
-
- # Save dataset token for pushes
- HfFolder.save_token(DATASET_TOKEN)
-
- # Load or create dataset
- try:
-     dataset = load_dataset(DATASET_NAME, use_auth_token=DATASET_TOKEN)
- except:
-     dataset = DatasetDict({"train": Dataset.from_dict({"question": [], "answer": [], "embedding": []})})
-
- # Load embedding model
- embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
-
- # ================================
- # HELPER FUNCTIONS
- # ================================
-
- def check_rate_limit(user_id):
-     now = datetime.now()
-     queries = user_queries.get(user_id, [])
-     # Remove expired queries
-     queries = [q for q in queries if q > now - timedelta(hours=WINDOW_HOURS)]
-     user_queries[user_id] = queries
-
-     if len(queries) >= MAX_QUERIES:
-         next_allowed = min(queries) + timedelta(hours=WINDOW_HOURS)
-         wait_seconds = int((next_allowed - now).total_seconds())
-         return False, wait_seconds
-     return True, 0
-
- def log_query(user_id):
-     now = datetime.now()
-     user_queries.setdefault(user_id, []).append(now)
-
- def find_in_dataset(question, threshold=0.75):
-     if len(dataset["train"]) == 0:
-         return None
-     # Compute embedding for input
-     question_emb = embed_model.encode(question, convert_to_tensor=True)
-     # Load existing embeddings
-     existing_embs = torch.tensor(dataset["train"]["embedding"]) if dataset["train"]["embedding"] else None
-     if existing_embs is None or len(existing_embs) == 0:
-         return None
-     # Compute cosine similarities
-     similarities = util.cos_sim(question_emb, existing_embs)[0]
-     max_score, idx = torch.max(similarities, dim=0)
-     if max_score >= threshold:
-         return dataset["train"]["answer"][idx.item()]
      return None

- def save_qna(question, answer):
-     # Load your dataset
-     dataset = load_dataset("princemaxp/guardian-ai-qna", use_auth_token=os.environ["dataset_HF_TOKEN"])
-
-     new_ds = Dataset.from_dict({"question": [question], "answer": [answer]})
-
-     # Correct way to add new data
-     dataset["train"] = concatenate_datasets([dataset["train"], new_ds])
-
-     # Push updated dataset to HF
-     dataset.push_to_hub("username/guardian-ai-qna", token=os.environ["dataset_HF_TOKEN"])
-
- def call_render(question):
-     """
-     Replace this with your actual Render API call logic
-     that fetches the answer from the internet.
-     """
-     RENDER_API_URL = os.environ.get("RENDER_API_URL")
-     if not RENDER_API_URL:
-         return "Render API not configured."
-     resp = requests.post(RENDER_API_URL, json={"question": question})
-     if resp.status_code == 200:
-         return resp.json().get("answer", "No answer found.")
-     return "Error fetching answer from Render."
-
- # ================================
- # CHAT FUNCTION
- # ================================
-
- def chat(history, message, session_id):
-     # Rate limit
-     allowed, wait_seconds = check_rate_limit(session_id)
-     if not allowed:
-         return history + [(f"System", f"Rate limit reached. Try again in {wait_seconds//60} minutes.")], ""
-
-     log_query(session_id)
-
-     # Check dataset first (embedding-based)
-     response = find_in_dataset(message)
-     if response is None:
-         # Call Render API fallback
-         response = call_render(message)
-         # Save in dataset
-         save_qna(message, response)
-
-     history.append(("User", message))
-     history.append(("Guardian AI", response))
-     return history, ""
-
- # ================================
- # GRADIO UI
- # ================================
- with gr.Blocks() as demo:
-     gr.Markdown("## Guardian AI Chatbot")
-     chatbot = gr.Chatbot()
-     session_id = gr.Textbox(label="Session ID (unique per user)", value=str(time.time()), visible=False)
-     msg = gr.Textbox(label="Enter your message")
-     send_btn = gr.Button("Send")
-
-     send_btn.click(fn=chat, inputs=[chatbot, msg, session_id], outputs=[chatbot, msg])
-     msg.submit(fn=chat, inputs=[chatbot, msg, session_id], outputs=[chatbot, msg])
-
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
+ # app.py (Hugging Face Space)
+
  import requests
+ from datasets import load_dataset, Dataset, concatenate_datasets
+ from sentence_transformers import SentenceTransformer, util
+ import gradio as gr
+ import os
+
+ # ----------------------------
+ # Load datasets
+ # ----------------------------
+
+ # Your main dataset (where new Q/A will be saved)
+ main_ds_name = "princemaxp/cybersecurity-main"
+ main_ds = load_dataset(main_ds_name, split="train")
+
+ # External datasets (read-only)
+ external_datasets = [
+     ("Trendyol-Security/cybersecurity-defense-instruction-tuning-v2", "train"),
+     ("Rowden/CybersecurityQAA", "train"),
+     ("Nitral-AI/Cybersecurity-ShareGPT", "train"),
+ ]
+
+ ext_ds_list = [load_dataset(name, split=split) for name, split in external_datasets]
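+ # NOTE: the code below assumes each external dataset exposes flat
+ # "question"/"answer" columns; a dataset in another layout (e.g. ShareGPT-style
+ # "conversations") would need to be mapped to that schema first.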
+
+ # Keywords dataset (for classification)
+ keywords_ds = load_dataset("princemaxp/cybersecurity-keywords", split="train")
+
+ # ----------------------------
+ # Embedding model
+ # ----------------------------
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
+
+ # Precompute embeddings
+ main_q_embeddings = embedder.encode(main_ds["question"], convert_to_tensor=True)
+
+ ext_q_embeddings = []
+ for ds in ext_ds_list:
+     ext_q_embeddings.append(embedder.encode(ds["question"], convert_to_tensor=True))
+
+ keywords = keywords_ds["keyword"]
+ keyword_embeddings = embedder.encode(keywords, convert_to_tensor=True)
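+ # All of the embeddings above are computed once at startup and held in memory;
+ # only main_q_embeddings is refreshed later, inside save_to_main_dataset().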
+
+ # ----------------------------
+ # Helper functions
+ # ----------------------------
+
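+ # A query counts as on-topic when its embedding reaches the similarity
+ # threshold (default 0.65 cosine) against at least one keyword embedding.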
+ def is_cybersecurity_question(user_query, threshold=0.65):
+     query_embedding = embedder.encode(user_query, convert_to_tensor=True)
+     cos_sim = util.cos_sim(query_embedding, keyword_embeddings)
+     max_score = cos_sim.max().item()
+     return max_score >= threshold
+
+ def search_dataset(user_query, dataset, dataset_embeddings, top_k=1, threshold=0.7):
+     query_embedding = embedder.encode(user_query, convert_to_tensor=True)
+     cos_sim = util.cos_sim(query_embedding, dataset_embeddings)[0]
+     best_idx = cos_sim.argmax().item()
+     best_score = cos_sim[best_idx].item()
+
+     if best_score >= threshold:
+         return dataset[best_idx]["answer"]
      return None

+ def call_render_service(question):
+     url = os.getenv("RENDER_URL", "https://render-python-app-4ty3.onrender.com/answer")
+     try:
+         response = requests.post(url, json={"question": question}, timeout=15)
+         if response.status_code == 200:
+             return response.json().get("answer", "No answer received.")
+         return "Render service error."
+     except Exception as e:
+         return f"Render service failed: {str(e)}"
+
+ def save_to_main_dataset(question, answer):
+     global main_ds, main_q_embeddings
+     new_row = {"question": [question], "answer": [answer]}
+     new_ds = Dataset.from_dict(new_row)
+     main_ds = concatenate_datasets([main_ds, new_ds])
+     main_q_embeddings = embedder.encode(main_ds["question"], convert_to_tensor=True)
+     # Push back to HF Hub (requires HF token set as secret)
+     main_ds.push_to_hub(main_ds_name)
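+ # NOTE: this re-encodes every stored question on each save, which grows
+ # linearly with the dataset; an incremental variant is sketched below.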
+
+ def get_answer(user_query):
+     # Step 1: Check if cybersecurity-related
+     if not is_cybersecurity_question(user_query):
+         return "This doesn’t seem like a cybersecurity-related question. Please refine your query."
+
+     # Step 2: Search in main dataset
+     answer = search_dataset(user_query, main_ds, main_q_embeddings)
+     if answer:
+         return answer
+
+     # Step 3: Search in external datasets
+     for ds, emb in zip(ext_ds_list, ext_q_embeddings):
+         answer = search_dataset(user_query, ds, emb)
+         if answer:
+             save_to_main_dataset(user_query, answer)
+             return answer
+
+     # Step 4: Fallback → Render service
+     answer = call_render_service(user_query)
+     save_to_main_dataset(user_query, answer)
+     return answer
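+ # NOTE: Step 4 saves whatever call_render_service returns, so error strings
+ # ("Render service failed: ...") can also end up stored in the main dataset.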
+
+ # ----------------------------
+ # Gradio interface
+ # ----------------------------
+ def chatbot_interface(user_input):
+     return get_answer(user_input)
+
+ iface = gr.Interface(
+     fn=chatbot_interface,
+     inputs=gr.Textbox(lines=2, placeholder="Ask a cybersecurity question..."),
+     outputs="text",
+     title="Guardian AI Chatbot",
+     description="Ask me anything about cybersecurity!"
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
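
For reference, a minimal sketch of the incremental save mentioned above, assuming the same module-level names as the new app.py (main_ds, main_q_embeddings, embedder, main_ds_name); the function name save_to_main_dataset_incremental is hypothetical. It encodes only the new question and appends the result to the cached tensor, instead of re-encoding the whole corpus on every save:

import torch
from datasets import Dataset, concatenate_datasets

def save_to_main_dataset_incremental(question, answer):
    # Assumes app.py's globals: main_ds, main_q_embeddings, embedder, main_ds_name.
    global main_ds, main_q_embeddings
    new_ds = Dataset.from_dict({"question": [question], "answer": [answer]})
    main_ds = concatenate_datasets([main_ds, new_ds])
    # encode() on a one-element list returns a [1, dim] tensor we can append.
    new_emb = embedder.encode([question], convert_to_tensor=True)
    main_q_embeddings = torch.cat([main_q_embeddings, new_emb], dim=0)
    main_ds.push_to_hub(main_ds_name)  # still relies on the HF token secret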