File size: 4,827 Bytes
fbe73ed
e959a70
fbe73ed
78db3ec
fbe73ed
e959a70
 
 
fbe73ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e959a70
 
 
fbe73ed
 
 
e959a70
 
fbe73ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e959a70
 
fbe73ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e959a70
 
 
 
dbd06e6
a8452a4
fbe73ed
e959a70
fbe73ed
 
 
 
e959a70
fbe73ed
e959a70
fbe73ed
 
e959a70
 
 
fbe73ed
 
 
 
 
 
e959a70
 
fbe73ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e959a70
 
 
fbe73ed
 
e959a70
 
fbe73ed
e959a70
 
 
fbe73ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# app.py

import os
import re

import gradio as gr
import pandas as pd
import requests
from datasets import load_dataset, Dataset, concatenate_datasets
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

# =========================
# CONFIG
# =========================
# Hub repo IDs: MAIN_DATASET is writable (new Q/A pairs are pushed back to it);
# KEYWORDS_DATASET supplies the cybersecurity keyword list used for routing.
MAIN_DATASET = "princemaxp/guardian-ai-qna"
KEYWORDS_DATASET = "princemaxp/cybersecurity-keywords"

# Tokens from secrets
HF_MODEL_TOKEN = os.environ.get("HF_TOKEN")  # used by the text-generation pipeline
MAIN_DATASET_TOKEN = os.environ.get("DATASET_HF_TOKEN")  # read/write: main + keywords datasets
TRENDYOL_TOKEN = os.environ.get("DATASET_TRENDYOL_TOKEN")  # read-only external dataset
ROW_TOKEN = os.environ.get("DATASET_ROW_CYBERQA_TOKEN")  # read-only external dataset
SHAREGPT_TOKEN = os.environ.get("DATASET_SHAREGPT_TOKEN")  # read-only external dataset
RENDER_API_URL = os.environ.get("RENDER_API_URL")  # fallback Q&A endpoint; may be unset (None)

# =========================
# LOAD DATASETS
# =========================

# Main dataset (writable). Fall back to an empty in-memory dataset with the
# expected schema so the app still starts when the Hub is unreachable.
try:
    # NOTE: switched from the deprecated `use_auth_token=` to `token=` for
    # consistency with push_to_hub() and pipeline() elsewhere in this file.
    main_ds = load_dataset(MAIN_DATASET, split="train", token=MAIN_DATASET_TOKEN)
except Exception:
    main_ds = Dataset.from_dict({"question": [], "answer": []})

# External datasets (read-only): (repo_id, access token) pairs.
external_datasets = [
    ("trendyol/cybersecurity-defense-v2", TRENDYOL_TOKEN),
    ("Rowden/CybersecurityQAA", ROW_TOKEN),
    ("Nitral-AI/Cybersecurity-ShareGPT", SHAREGPT_TOKEN),
]

# Load whichever external datasets are reachable; a failure of one must not
# take down the app, so each is attempted independently.
ext_ds_list = []
for name, token in external_datasets:
    try:
        ds = load_dataset(name, split="train", token=token)
        ext_ds_list.append(ds)
    except Exception as e:
        print(f"⚠ Could not load {name}: {e}")

# Keyword dataset (CSV): one cybersecurity keyword per row, used to decide
# whether a question should be routed to the cybersecurity pipeline.
try:
    kw_ds = load_dataset(KEYWORDS_DATASET, split="train", token=MAIN_DATASET_TOKEN)
    keywords = set(kw_ds["keyword"])
except Exception as e:
    print(f"⚠ Could not load keywords dataset: {e}")
    keywords = set()

# =========================
# MODELS
# =========================
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # sentence embeddings for similarity search
chat_model = pipeline("text-generation", model="gpt2", token=HF_MODEL_TOKEN)  # general-chat fallback generator

# Precompute embeddings
def compute_embeddings(dataset):
    """Encode every "question" in *dataset* as a tensor of embeddings.

    Returns None for an empty dataset so callers can treat "no index"
    uniformly.
    """
    if not len(dataset):
        return None
    questions = dataset["question"]
    return embedder.encode(questions, convert_to_tensor=True)

# Cache question embeddings up front so each query needs only one cos-sim pass.
main_embs = compute_embeddings(main_ds)
ext_embs = [compute_embeddings(ds) for ds in ext_ds_list]

# =========================
# HELPERS
# =========================

def is_cybersecurity_question(text: str, keyword_set=None) -> bool:
    """Return True if *text* mentions any known cybersecurity keyword.

    Fixes two matching bugs in the original word-split approach:
    - trailing punctuation defeated matching ("phishing?" != "phishing"),
      so tokens are now extracted with a word-character regex;
    - multi-word keywords (e.g. "sql injection") could never equal a single
      word, so they are matched as substrings of the lowercased text.

    keyword_set: optional iterable of keywords; defaults to the module-level
    `keywords` set loaded from KEYWORDS_DATASET (kept for backward
    compatibility with existing callers).
    """
    kws = keywords if keyword_set is None else keyword_set
    lowered = text.lower()
    words = set(re.findall(r"\w+", lowered))
    for kw in kws:
        kw_l = kw.lower()
        if " " in kw_l:
            if kw_l in lowered:
                return True
        elif kw_l in words:
            return True
    return False

def search_dataset(user_query, dataset, dataset_embeddings, threshold=0.7):
    """Return the stored answer whose question best matches *user_query*.

    Looks up the highest cosine similarity between the query embedding and
    the precomputed *dataset_embeddings*; returns None when the dataset is
    empty, has no embeddings, or the best score falls below *threshold*.
    """
    if dataset_embeddings is None or len(dataset) == 0:
        return None
    query_emb = embedder.encode(user_query, convert_to_tensor=True)
    scores = util.cos_sim(query_emb, dataset_embeddings)[0]
    top = scores.argmax().item()
    if scores[top].item() < threshold:
        return None
    return dataset[top]["answer"]

def call_render(question):
    """POST *question* to the Render fallback API and return its answer text.

    Any failure (network error, bad JSON, non-200 status) is reported as a
    human-readable string rather than raised, so the chat flow never crashes
    on a flaky fallback service.
    """
    payload = {"question": question}
    try:
        resp = requests.post(RENDER_API_URL, json=payload, timeout=15)
        if resp.status_code != 200:
            return f"Render error: {resp.status_code}"
        return resp.json().get("answer", "No answer from Render.")
    except Exception as e:
        return f"Render request failed: {e}"

def save_to_main(question, answer):
    """Append a Q/A pair to the in-memory main dataset and sync it to the Hub.

    Rebuilds `main_embs` so the new row is immediately searchable.
    The Hub push is best-effort: previously a push failure propagated and
    crashed the whole answer flow even though the answer was already known,
    so it is now caught and logged in the file's ⚠-warning style.
    """
    global main_ds, main_embs
    new_ds = Dataset.from_dict({"question": [question], "answer": [answer]})
    main_ds = concatenate_datasets([main_ds, new_ds])
    # NOTE: re-encodes the whole dataset on every save; acceptable while the
    # dataset is small, but worth incrementalizing if it grows.
    main_embs = compute_embeddings(main_ds)
    try:
        main_ds.push_to_hub(MAIN_DATASET, token=MAIN_DATASET_TOKEN)
    except Exception as e:
        print(f"⚠ Could not push to {MAIN_DATASET}: {e}")

# =========================
# ANSWER FLOW
# =========================

def get_answer(user_query):
    """Answer *user_query*, routing cybersecurity questions through the
    dataset-search cascade and everything else through the chat model.

    Cascade for cybersecurity questions:
      1. curated main dataset;
      2. external datasets (hits are cached back into the main dataset);
      3. Render API fallback.

    Bug fix: Render failure strings ("Render error: ...", "Render request
    failed: ...") were previously saved into the main dataset as if they were
    real answers, permanently polluting it; failures are now returned to the
    user but not persisted.
    """
    if not is_cybersecurity_question(user_query):
        # General question → use the chat model directly.
        gen = chat_model(user_query, max_length=100, num_return_sequences=1)
        return gen[0]["generated_text"]

    # 1. Curated main dataset.
    ans = search_dataset(user_query, main_ds, main_embs)
    if ans:
        return ans

    # 2. External read-only datasets; cache any hit for future lookups.
    for ds, emb in zip(ext_ds_list, ext_embs):
        ans = search_dataset(user_query, ds, emb)
        if ans:
            save_to_main(user_query, ans)
            return ans

    # 3. Fallback: Render API. Persist only genuine answers.
    ans = call_render(user_query)
    if not ans.startswith(("Render error:", "Render request failed:")):
        save_to_main(user_query, ans)
    return ans

# =========================
# GRADIO UI
# =========================

def chatbot(user_input):
    """Gradio callback: delegate the user's message to the answer pipeline."""
    reply = get_answer(user_input)
    return reply

# Single-textbox chat UI; Gradio renders the returned string as plain text.
iface = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(lines=2, placeholder="Ask me anything..."),
    outputs="text",
    title="Guardian AI Chatbot",
    description="Cybersecurity-focused chatbot with general fallback"
)

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the standard HF Spaces port).
    iface.launch(server_name="0.0.0.0", server_port=7860)