# app.py
import os
import requests
import pandas as pd
from datasets import load_dataset, Dataset, concatenate_datasets
from sentence_transformers import SentenceTransformer, util
import gradio as gr
from transformers import pipeline
# =========================
# CONFIG
# =========================
MAIN_DATASET = "princemaxp/guardian-ai-qna"
KEYWORDS_DATASET = "princemaxp/cybersecurity-keywords"
# Tokens from secrets
HF_MODEL_TOKEN = os.environ.get("HF_TOKEN")
MAIN_DATASET_TOKEN = os.environ.get("DATASET_HF_TOKEN")
TRENDYOL_TOKEN = os.environ.get("DATASET_TRENDYOL_TOKEN")
ROW_TOKEN = os.environ.get("DATASET_ROW_CYBERQA_TOKEN")
SHAREGPT_TOKEN = os.environ.get("DATASET_SHAREGPT_TOKEN")
RENDER_API_URL = os.environ.get("RENDER_API_URL")
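# All of these are expected to come from the Space's secrets. Any that are
# unset resolve to None, in which case the corresponding load or request
# either falls back to anonymous access or fails and is caught by the
# try/except blocks below.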
# =========================
# LOAD DATASETS
# =========================
# Main dataset (writable)
try:
    main_ds = load_dataset(MAIN_DATASET, split="train", token=MAIN_DATASET_TOKEN)
except Exception:
    main_ds = Dataset.from_dict({"question": [], "answer": []})
# External datasets (read-only)
external_datasets = [
    ("trendyol/cybersecurity-defense-v2", TRENDYOL_TOKEN),
    ("Rowden/CybersecurityQAA", ROW_TOKEN),
    ("Nitral-AI/Cybersecurity-ShareGPT", SHAREGPT_TOKEN),
]
ext_ds_list = []
for name, token in external_datasets:
    try:
        ds = load_dataset(name, split="train", token=token)
        ext_ds_list.append(ds)
    except Exception as e:
        print(f"⚠ Could not load {name}: {e}")
# Keyword dataset (CSV)
try:
    kw_ds = load_dataset(KEYWORDS_DATASET, split="train", token=MAIN_DATASET_TOKEN)
    keywords = set(kw_ds["keyword"])
except Exception as e:
    print(f"⚠ Could not load keywords dataset: {e}")
    keywords = set()
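# The keywords dataset is assumed to expose a single "keyword" column with one
# term per row; anything else will raise above and leave the keyword set empty,
# which routes every question to the general chat model.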
# =========================
# MODELS
# =========================
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
chat_model = pipeline("text-generation", model="gpt2", token=HF_MODEL_TOKEN)
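# gpt2 is a small general-purpose model; it only serves the non-cybersecurity
# branch of get_answer below, while security questions are answered from the
# datasets or the Render API.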
# Precompute embeddings
def compute_embeddings(dataset):
    if len(dataset) == 0:
        return None
    return embedder.encode(dataset["question"], convert_to_tensor=True)
main_embs = compute_embeddings(main_ds)
ext_embs = [compute_embeddings(ds) for ds in ext_ds_list]
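# main_embs is rebuilt whenever save_to_main() grows the main dataset; the
# external embeddings are computed once at startup and never change.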
# =========================
# HELPERS
# =========================
def is_cybersecurity_question(text: str) -> bool:
    # Simple gate: exact single-word match against the keyword list, so
    # multi-word keywords (e.g. "sql injection") will never match here.
    words = text.lower().split()
    return any(kw.lower() in words for kw in keywords)
def search_dataset(user_query, dataset, dataset_embeddings, threshold=0.7):
    if dataset_embeddings is None or len(dataset) == 0:
        return None
    query_emb = embedder.encode(user_query, convert_to_tensor=True)
    cos_sim = util.cos_sim(query_emb, dataset_embeddings)[0]
    best_idx = cos_sim.argmax().item()
    best_score = cos_sim[best_idx].item()
    if best_score >= threshold:
        return dataset[best_idx]["answer"]
    return None
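# Example (hypothetical query): search_dataset("What is phishing?", main_ds, main_embs)
# returns the stored answer whose question is most similar to the query, or
# None when the best cosine score falls below the 0.7 threshold.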
def call_render(question):
    try:
        resp = requests.post(RENDER_API_URL, json={"question": question}, timeout=15)
        if resp.status_code == 200:
            return resp.json().get("answer", "No answer from Render.")
        return f"Render error: {resp.status_code}"
    except Exception as e:
        return f"Render request failed: {e}"
def save_to_main(question, answer):
    global main_ds, main_embs
    new_row = {"question": [question], "answer": [answer]}
    new_ds = Dataset.from_dict(new_row)
    main_ds = concatenate_datasets([main_ds, new_ds])
    main_embs = compute_embeddings(main_ds)
    main_ds.push_to_hub(MAIN_DATASET, token=MAIN_DATASET_TOKEN)
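# Caveat: each save re-encodes every stored question and re-uploads the whole
# dataset to the Hub. Cheap while the dataset is small; worth batching if it
# grows.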
# =========================
# ANSWER FLOW
# =========================
def get_answer(user_query):
    if is_cybersecurity_question(user_query):
        # 1. Check the main (writable) dataset
        ans = search_dataset(user_query, main_ds, main_embs)
        if ans:
            return ans
        # 2. Check external datasets; cache any hit back into the main dataset
        for ds, emb in zip(ext_ds_list, ext_embs):
            ans = search_dataset(user_query, ds, emb)
            if ans:
                save_to_main(user_query, ans)
                return ans
        # 3. Fallback: ask the Render API and cache its answer
        ans = call_render(user_query)
        save_to_main(user_query, ans)
        return ans
    else:
        # General question → use the local chat model
        gen = chat_model(user_query, max_new_tokens=100, num_return_sequences=1)
        return gen[0]["generated_text"]
# =========================
# GRADIO UI
# =========================
def chatbot(user_input):
    return get_answer(user_input)
iface = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(lines=2, placeholder="Ask me anything..."),
    outputs="text",
    title="Guardian AI Chatbot",
    description="Cybersecurity-focused chatbot with general fallback",
)
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)