princemaxp committed on
Commit be72f4b · verified · 1 Parent(s): 2a46434

Update app.py

Files changed (1)
  1. app.py +62 -42
app.py CHANGED
@@ -1,76 +1,96 @@
  import gradio as gr
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
- from datasets import load_dataset, Dataset
- from huggingface_hub import login
  import os

- # --- Hugging Face Dataset Setup ---
- HF_TOKEN = os.environ.get("dataset_HF_TOKEN")  # Secret in your HF Space
- login(token=HF_TOKEN)

- dataset_name = "YOUR_USERNAME/guardian-ai-qna"  # Replace YOUR_USERNAME
- try:
-     dataset = load_dataset(dataset_name)
- except:
-     # If dataset is empty or not yet created, create an empty one
-     dataset = Dataset.from_dict({"question": [], "answer": []})
-
- # --- Load model & tokenizer ---
- model_id = "google/gemma-2b-it"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(model_id)
-
- generator = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     device=-1  # CPU, change to 0 if GPU available
- )
-
- # --- System instruction ---
  SYSTEM_PROMPT = """You are Guardian AI, a friendly cybersecurity educator.
  Your goal is to explain cybersecurity concepts in simple, engaging language with examples.
  Always keep answers clear, short, and focused on security awareness.
  """

- # --- Save Q&A to dataset ---
- def save_qna(question, answer):
-     global dataset
-     new_entry = Dataset.from_dict({"question": [question], "answer": [answer]})
-     dataset = dataset.concat(new_entry)
-     dataset.push_to_hub(dataset_name, private=False)  # push updates

- # --- Chat function ---
  def chat(history, user_input):
-     prompt = SYSTEM_PROMPT + "\nUser: " + user_input + "\nGuardian AI:"
      result = generator(
          prompt,
          max_new_tokens=200,
          do_sample=True,
          temperature=0.7,
          top_p=0.9
-     )[0]['generated_text']
-
      response = result.split("Guardian AI:")[-1].strip()
      history.append((user_input, response))
-
-     # Save to dataset
      save_qna(user_input, response)
-
      return history, history

- # --- Gradio UI ---
  with gr.Blocks() as demo:
      gr.Markdown("## 🛡️ Guardian AI – Cybersecurity Educator")
-     chatbot = gr.Chatbot()
      state = gr.State([])
-
      with gr.Row():
          with gr.Column(scale=8):
              user_input = gr.Textbox(show_label=False, placeholder="Ask me about cybersecurity...")
          with gr.Column(scale=2):
              send_btn = gr.Button("Send")
-
      send_btn.click(chat, [state, user_input], [chatbot, state])
      user_input.submit(chat, [state, user_input], [chatbot, state])

  import gradio as gr
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ from datasets import load_dataset, Dataset, concatenate_datasets
  import os

+ # -------------------------------
+ # Config
+ # -------------------------------
+ HF_TOKEN = os.environ["dataset_HF_TOKEN"]
+ DATASET_ID = "your-username/guardian-ai-qna"  # replace with your HF username
+ MODEL_ID = "google/gemma-2b-it"

  SYSTEM_PROMPT = """You are Guardian AI, a friendly cybersecurity educator.
  Your goal is to explain cybersecurity concepts in simple, engaging language with examples.
  Always keep answers clear, short, and focused on security awareness.
+ Use the examples from the Q&A memory to improve your answers.
  """

+ # -------------------------------
+ # Load model & tokenizer
+ # -------------------------------
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
+ generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
+
+ # -------------------------------
+ # Dataset functions
+ # -------------------------------
+ def load_qna_dataset():
+     try:
+         dataset = load_dataset(DATASET_ID, use_auth_token=HF_TOKEN)["train"]
+     except:
+         dataset = Dataset.from_dict({"question": [], "answer": []})
+     return dataset
+
+ def save_qna(user_input, response):
+     dataset = load_qna_dataset()
+     new_entry = Dataset.from_dict({"question": [user_input], "answer": [response]})
+     dataset = concatenate_datasets([dataset, new_entry])
+     dataset.push_to_hub(DATASET_ID, token=HF_TOKEN)

+ def retrieve_similar_qna(user_input, top_k=3):
+     dataset = load_qna_dataset()
+     if len(dataset) == 0:
+         return ""
+     # Simple keyword-based retrieval
+     # You can upgrade to semantic search later
+     relevant = []
+     for q, a in zip(dataset["question"], dataset["answer"]):
+         if any(word in user_input.lower() for word in q.lower().split()):
+             relevant.append(f"Q: {q}\nA: {a}")
+         if len(relevant) >= top_k:
+             break
+     return "\n".join(relevant)
+
+ # -------------------------------
+ # Chat function
+ # -------------------------------
  def chat(history, user_input):
+     # Retrieve past Q&A for context
+     context = retrieve_similar_qna(user_input)
+     prompt = SYSTEM_PROMPT
+     if context:
+         prompt += f"\n\nMemory of past Q&A:\n{context}"
+     prompt += f"\n\nUser: {user_input}\nGuardian AI:"
+
      result = generator(
          prompt,
          max_new_tokens=200,
          do_sample=True,
          temperature=0.7,
          top_p=0.9
+     )[0]["generated_text"]
+
      response = result.split("Guardian AI:")[-1].strip()
      history.append((user_input, response))
      save_qna(user_input, response)
      return history, history

+ # -------------------------------
+ # Gradio UI
+ # -------------------------------
  with gr.Blocks() as demo:
      gr.Markdown("## 🛡️ Guardian AI – Cybersecurity Educator")
+     chatbot = gr.Chatbot(type="messages")  # Updated type to avoid deprecation warning
      state = gr.State([])
+
      with gr.Row():
          with gr.Column(scale=8):
              user_input = gr.Textbox(show_label=False, placeholder="Ask me about cybersecurity...")
          with gr.Column(scale=2):
              send_btn = gr.Button("Send")
+
      send_btn.click(chat, [state, user_input], [chatbot, state])
      user_input.submit(chat, [state, user_input], [chatbot, state])
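
Note on the retrieval step: the comment inside retrieve_similar_qna() says the keyword matching can be upgraded to semantic search later. Below is a minimal sketch of that upgrade, reusing load_qna_dataset() from app.py above; it assumes the sentence-transformers package is installed, and the model name "all-MiniLM-L6-v2" and the helper name retrieve_similar_qna_semantic are illustrative only, not part of this commit.

# Sketch only (not in the commit): embedding-based retrieval over the stored Q&A pairs.
from sentence_transformers import SentenceTransformer, util

embedder = SentenceTransformer("all-MiniLM-L6-v2")  # small, CPU-friendly encoder (assumed choice)

def retrieve_similar_qna_semantic(user_input, top_k=3):
    dataset = load_qna_dataset()
    if len(dataset) == 0:
        return ""
    questions = dataset["question"]
    answers = dataset["answer"]
    # Embed stored questions and the incoming query, then rank by cosine similarity
    corpus_emb = embedder.encode(questions, convert_to_tensor=True)
    query_emb = embedder.encode(user_input, convert_to_tensor=True)
    scores = util.cos_sim(query_emb, corpus_emb)[0]
    top_idx = scores.argsort(descending=True)[:top_k].tolist()
    return "\n".join(f"Q: {questions[i]}\nA: {answers[i]}" for i in top_idx)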
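
Note on the Chatbot change: the UI now creates gr.Chatbot(type="messages"), but chat() still appends (user, assistant) tuples to history. With the messages type, Gradio expects a list of {"role": ..., "content": ...} dicts. The sketch below shows how chat() could build that history, reusing generator, SYSTEM_PROMPT, retrieve_similar_qna and save_qna defined above; it is an assumption about the intended behavior, not part of this commit.

# Sketch only (not in the commit): chat() emitting role/content dicts for type="messages".
def chat(history, user_input):
    context = retrieve_similar_qna(user_input)
    prompt = SYSTEM_PROMPT
    if context:
        prompt += f"\n\nMemory of past Q&A:\n{context}"
    prompt += f"\n\nUser: {user_input}\nGuardian AI:"

    result = generator(prompt, max_new_tokens=200, do_sample=True,
                       temperature=0.7, top_p=0.9)[0]["generated_text"]
    response = result.split("Guardian AI:")[-1].strip()

    # Append message dicts instead of (user, assistant) tuples
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    save_qna(user_input, response)
    return history, history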