|
|
|
|
|
|
|
import os |
|
# Default thread/parallelism settings. setdefault never overwrites, so any
# value already exported in the environment wins.
# NOTE(review): presumably these must be in place before First_Pass pulls in
# its numeric/tokenizer libraries -- confirm before moving below the imports.
for _env_key, _env_default in (
    ("OMP_NUM_THREADS", "1"),
    ("TOKENIZERS_PARALLELISM", "false"),
):
    os.environ.setdefault(_env_key, _env_default)
del _env_key, _env_default
|
|
|
import hashlib, hmac, secrets, sqlite3, time, csv |
|
from datetime import datetime |
|
from typing import Optional, Tuple |
|
|
|
import gradio as gr |
|
from First_Pass import ask |
|
|
|
# Page title and intro text rendered at the top of the Gradio UI.
# The em dash and thumbs emoji were restored from mis-decoded UTF-8 bytes.
TITLE = "Askstein — CT Rigidity / FE Q&A"
DESC = (
    "Login or register to use the chatbot. Choose your lab, then ask. "
    "After each answer, leave a 👍 or 👎 — feedback is saved for review."
)
|
|
|
|
|
|
|
# Display label -> short lab code. The other two lookups are derived from this
# mapping so the three structures cannot drift apart.
LAB_TO_CODE = dict(
    [
        ("Nazarian Lab", "nazarian"),
        ("Freedman Lab", "freedman"),
        ("Alboro Lab", "alboro"),
    ]
)

# Dropdown choices, in declaration order.
LAB_LABELS = list(LAB_TO_CODE)

# Reverse lookup: short code -> display label.
CODE_TO_LAB = dict(zip(LAB_TO_CODE.values(), LAB_TO_CODE.keys()))
|
|
|
def to_lab_code(label: str) -> str:
    """Translate a lab display label into its short code.

    Labels missing from LAB_TO_CODE map to "nazarian".
    """
    try:
        return LAB_TO_CODE[label]
    except KeyError:
        return "nazarian"
|
|
|
def to_lab_label(code: str) -> str:
    """Translate a short lab code into its display label.

    The lookup is case-insensitive; empty, None, or unknown codes map to
    "Nazarian Lab".
    """
    normalized = code.lower() if code else ""
    if normalized in CODE_TO_LAB:
        return CODE_TO_LAB[normalized]
    return "Nazarian Lab"
|
|
|
|
|
|
|
# Directory holding the SQLite database and generated CSV exports.
# Overridable through the DATA_DIR environment variable; created on import.
DATA_DIR = os.path.abspath(os.environ.get("DATA_DIR", "./data"))
os.makedirs(DATA_DIR, exist_ok=True)

# SQLite database file inside DATA_DIR.
DB_PATH = os.path.join(DATA_DIR, "askstein.db")
|
|
|
def _db():
    """Open and return a new SQLite connection to DB_PATH.

    The connection is created with check_same_thread=False and has the
    journal_mode=WAL and synchronous=NORMAL pragmas applied. The caller is
    responsible for closing it.
    """
    connection = sqlite3.connect(DB_PATH, check_same_thread=False)
    for pragma in ("PRAGMA journal_mode=WAL;", "PRAGMA synchronous=NORMAL;"):
        connection.execute(pragma)
    return connection
|
|
|
def _ensure_user_columns(conn: sqlite3.Connection):
    """Add first_name, last_name and lab_choice to ``users`` when absent.

    Reads the current column list with PRAGMA table_info and issues one
    ALTER TABLE per missing column, then commits. Calling it again on an
    already-migrated table issues no ALTER statements.
    """
    migrations = (
        ("first_name", "ALTER TABLE users ADD COLUMN first_name TEXT DEFAULT ''"),
        ("last_name", "ALTER TABLE users ADD COLUMN last_name TEXT DEFAULT ''"),
        ("lab_choice", "ALTER TABLE users ADD COLUMN lab_choice TEXT DEFAULT 'nazarian'"),
    )

    # Index 1 of each table_info row is the column name.
    table_info = conn.execute("PRAGMA table_info(users)").fetchall()
    existing = {info[1] for info in table_info}

    for column, ddl in migrations:
        if column not in existing:
            conn.execute(ddl)
    conn.commit()
|
|
|
def init_db():
    """Create the ``users`` and ``feedback`` tables if missing, then migrate.

    After the CREATE TABLE statements are committed, _ensure_user_columns adds
    the first_name / last_name / lab_choice columns to ``users`` when they are
    not present. The connection is closed even if a statement raises.
    """
    conn = _db()
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                email TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL,
                salt TEXT NOT NULL,
                created_at TEXT NOT NULL
                -- first_name, last_name, lab_choice added via migration
            )
        """)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS feedback (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER,
                question TEXT NOT NULL,
                answer_preview TEXT NOT NULL,
                rating INTEGER NOT NULL,  -- +1 = thumbs up, -1 = thumbs down
                created_at TEXT NOT NULL,
                FOREIGN KEY (user_id) REFERENCES users(id)
            )
        """)
        conn.commit()

        _ensure_user_columns(conn)
    finally:
        # Previously the connection leaked if table creation or migration failed.
        conn.close()
|
|
|
|
|
|
|
def _hash_password(password: str, salt_hex: Optional[str]=None) -> Tuple[str, str]:
    """Derive a PBKDF2-HMAC-SHA256 hash of ``password``.

    Args:
        password: Plain-text password; encoded as UTF-8 before hashing.
        salt_hex: Hex-encoded salt to reuse. When empty or None, a fresh
            16-byte random salt is generated.

    Returns:
        ``(hash_hex, salt_hex)``: the 32-byte derived key and the salt used,
        both hex-encoded.
    """
    salt = bytes.fromhex(salt_hex) if salt_hex else secrets.token_bytes(16)
    derived = hashlib.pbkdf2_hmac(
        "sha256",
        password.encode("utf-8"),
        salt,
        100_000,
        dklen=32,
    )
    return derived.hex(), salt.hex()
|
|
|
def _verify_password(password: str, salt_hex: str, expected_hash_hex: str) -> bool:
    """Return True when ``password`` hashed with ``salt_hex`` equals the stored hash.

    The hex digests are compared with hmac.compare_digest.
    """
    candidate_hash = _hash_password(password, salt_hex)[0]
    return hmac.compare_digest(candidate_hash, expected_hash_hex)
|
|
|
|
|
|
|
def register_user(first: str, last: str, email: str, password: str, lab_label: str) -> Tuple[bool, str]:
    """Validate the registration form and insert a new row into ``users``.

    Args:
        first: First name; required, at most 80 characters after stripping.
        last: Last name; required, at most 80 characters after stripping.
        email: Email address; stripped and lower-cased, must contain "@" and
            be at most 200 characters.
        password: Plain-text password, at least 8 characters. Only its salted
            hash is stored.
        lab_label: Lab display label; must be a key of LAB_TO_CODE.

    Returns:
        ``(ok, message)`` where ``message`` is suitable for showing to the user.
    """
    first = (first or "").strip()
    last = (last or "").strip()
    email = (email or "").strip().lower()
    lab_label = (lab_label or "").strip()

    if not first or len(first) > 80:
        return False, "Please enter a valid first name."
    if not last or len(last) > 80:
        return False, "Please enter a valid last name."
    if not email or "@" not in email or len(email) > 200:
        return False, "Please enter a valid email."
    if not password or len(password) < 8:
        return False, "Password must be at least 8 characters."
    # Validate the label itself: going through to_lab_code() first silently
    # turned every unknown/empty label into "nazarian", so this check could
    # never fail.
    if lab_label not in LAB_TO_CODE:
        return False, "Please select a lab."
    lab_code = LAB_TO_CODE[lab_label]

    pw_hash, salt = _hash_password(password)
    conn = None
    try:
        conn = _db()
        conn.execute(
            "INSERT INTO users (email, password_hash, salt, created_at, first_name, last_name, lab_choice) VALUES (?, ?, ?, ?, ?, ?, ?)",
            (email, pw_hash, salt, datetime.utcnow().isoformat(), first, last, lab_code),
        )
        conn.commit()
        return True, "Registration successful! You can log in now."
    except sqlite3.IntegrityError:
        # The only uniqueness constraint on users is the email column.
        return False, "This email is already registered."
    except Exception as e:
        return False, f"Registration failed: {e}"
    finally:
        # conn stays None when _db() itself failed; closing unconditionally
        # used to raise UnboundLocalError and hide the real error.
        if conn is not None:
            conn.close()
|
|
|
def login_user(email: str, password: str) -> Tuple[bool, str, Optional[int], Optional[str], Optional[str], Optional[str]]:
    """Check credentials against the ``users`` table.

    Args:
        email: Email address; stripped and lower-cased before lookup.
        password: Plain-text password to verify against the stored salted hash.

    Returns:
        ``(ok, message, user_id, first_name, last_name, lab_code)``. On failure
        the last four items are None. On success missing names become "" and a
        missing lab code becomes "nazarian".
    """
    email = (email or "").strip().lower()
    if not email or not password:
        return False, "Missing email or password.", None, None, None, None

    conn = None
    try:
        conn = _db()
        cur = conn.cursor()
        cur.execute(
            "SELECT id, password_hash, salt, first_name, last_name, lab_choice FROM users WHERE email = ?",
            (email,),
        )
        row = cur.fetchone()
        if not row:
            # Same message as a wrong password so the response text does not
            # reveal whether the email is registered.
            return False, "Invalid email or password.", None, None, None, None
        uid, pw_hash, salt, first, last, lab_code = row
        if _verify_password(password, salt, pw_hash):
            return True, "Login successful.", uid, first or "", last or "", (lab_code or "nazarian")
        return False, "Invalid email or password.", None, None, None, None
    except Exception as e:
        return False, f"Login failed: {e}", None, None, None, None
    finally:
        # conn stays None when _db() itself failed; closing unconditionally
        # used to raise UnboundLocalError and hide the real error.
        if conn is not None:
            conn.close()
|
|
|
def save_feedback(user_id: Optional[int], question: str, answer: str, rating: int) -> Tuple[bool, str]:
    """Store one feedback row for a question/answer pair.

    Args:
        user_id: Id of the rating user, or None.
        question: The question that was asked.
        answer: The full answer; only a single-line preview of at most 350
            characters (plus an ellipsis when truncated) is stored.
        rating: +1 for thumbs up, -1 for thumbs down; coerced with int().

    Returns:
        ``(ok, message)`` where ``message`` is suitable for showing to the user.
    """
    preview = (answer or "").replace("\n", " ").strip()
    if len(preview) > 350:
        # Ellipsis restored from mis-decoded UTF-8 bytes.
        preview = preview[:350] + "…"

    conn = None
    try:
        conn = _db()
        conn.execute(
            "INSERT INTO feedback (user_id, question, answer_preview, rating, created_at) VALUES (?, ?, ?, ?, ?)",
            (user_id, question, preview, int(rating), datetime.utcnow().isoformat()),
        )
        conn.commit()
        return True, "Thanks for your feedback!"
    except Exception as e:
        return False, f"Could not save feedback: {e}"
    finally:
        # conn stays None when _db() itself failed; closing unconditionally
        # used to raise UnboundLocalError and hide the real error.
        if conn is not None:
            conn.close()
|
|
|
def export_feedback_csv(all_users: bool, user_id: Optional[int]=None) -> str:
    """Write feedback rows (joined with user info) to a CSV file in DATA_DIR.

    Args:
        all_users: When true, export every feedback row; otherwise only rows
            whose ``user_id`` matches.
        user_id: User to export for when ``all_users`` is false.

    Returns:
        Path of the written file, named
        ``feedback_export_<all|user_id>_<unix_ts>.csv``. Rows are ordered
        newest first and the lab code is written as its display label.
    """
    ts = int(time.time())
    out_path = os.path.join(DATA_DIR, f"feedback_export_{'all' if all_users else user_id}_{ts}.csv")

    # One query with an optional filter instead of two near-identical copies.
    query = """
        SELECT f.id, f.created_at, f.rating, f.question, f.answer_preview,
               u.email, u.first_name, u.last_name, u.lab_choice
        FROM feedback f LEFT JOIN users u ON f.user_id = u.id
    """
    params: tuple = ()
    if not all_users:
        query += " WHERE f.user_id = ?"
        params = (user_id,)
    query += " ORDER BY f.id DESC"

    conn = _db()
    try:
        rows = conn.execute(query, params).fetchall()
    finally:
        # Previously the connection leaked when the query raised.
        conn.close()

    with open(out_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["id","created_at","rating","question","answer_preview","user_email","first_name","last_name","lab"])
        for r in rows:
            record = list(r)
            # Last column is the lab code; it is NULL when the LEFT JOIN found
            # no user, in which case the default lab label is written.
            record[-1] = to_lab_label(record[-1] or "nazarian")
            w.writerow(record)
    return out_path
|
|
|
def export_users_csv() -> str:
    """Write all users (without password hash or salt) to a CSV file in DATA_DIR.

    Returns:
        Path of the written file, named ``users_export_<unix_ts>.csv``. Rows
        are ordered by id descending and the lab code is written as its
        display label.
    """
    ts = int(time.time())
    out_path = os.path.join(DATA_DIR, f"users_export_{ts}.csv")

    conn = _db()
    try:
        rows = conn.execute(
            "SELECT id, email, first_name, last_name, lab_choice, created_at FROM users ORDER BY id DESC"
        ).fetchall()
    finally:
        # Previously the connection leaked when the query raised.
        conn.close()

    with open(out_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["id","email","first_name","last_name","lab","created_at"])
        for r in rows:
            record = list(r)
            # Index 4 is lab_choice; translate the code to its display label.
            record[4] = to_lab_label(record[4] or "nazarian")
            w.writerow(record)
    return out_path
|
|
|
|
|
# Shared secret that unlocks the admin export controls. An unset or blank
# value leaves ADMIN_TOKEN empty.
ADMIN_TOKEN = os.environ.get("ADMIN_TOKEN", "").strip()

# Create / migrate the database schema at import time, before the UI is built.
init_db()
|
|
|
|
|
|
|
# Gradio UI: login/register view, chat view with feedback and personal export,
# and a token-protected admin export panel.
# NOTE(review): the nesting of layout containers below was reconstructed from
# statement order because the original indentation was lost -- confirm the
# row/column grouping against the intended layout.
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"## {TITLE}\n{DESC}")

    # ---- Session state -------------------------------------------------
    st_user_id = gr.State(value=None)
    st_user_email = gr.State(value=None)
    st_user_fname = gr.State(value="")
    st_user_lname = gr.State(value="")
    st_user_lab = gr.State(value="nazarian")
    st_last_q = gr.State(value="")
    st_last_a = gr.State(value="")
    st_can_fb = gr.State(value=False)  # True while the last answer is unrated
    st_admin_ok = gr.State(value=False)

    # ---- Auth view (shown until login succeeds) ---------------------------
    auth_view = gr.Column(visible=True)
    with auth_view:
        with gr.Tabs():
            with gr.Tab("Login"):
                login_email = gr.Textbox(label="Email", placeholder="you@example.com")
                login_pass = gr.Textbox(label="Password", type="password", placeholder="Your password")
                login_btn = gr.Button("Log in", variant="primary")
                login_msg = gr.Markdown("")
            with gr.Tab("Register"):
                reg_first = gr.Textbox(label="First name", placeholder="Ada")
                reg_last = gr.Textbox(label="Last name", placeholder="Lovelace")
                reg_email = gr.Textbox(label="Email", placeholder="you@example.com")
                reg_pass = gr.Textbox(label="Password (min 8 chars)", type="password")
                reg_lab = gr.Dropdown(choices=LAB_LABELS, value="Nazarian Lab", label="Lab", multiselect=False)
                reg_btn = gr.Button("Create account", variant="primary")
                reg_msg = gr.Markdown("")

    # ---- Chat view (shown after login) -------------------------------------
    chat_view = gr.Column(visible=False)
    with chat_view:
        welcome_md = gr.Markdown("### Chat")
        with gr.Row():
            inp = gr.Textbox(
                label="Your question",
                placeholder="e.g., How is bending rigidity (EI) computed from a cortical cross-section?",
                lines=3,
            )
        out = gr.Textbox(label="Askstein", lines=12)
        with gr.Row():
            btn_submit = gr.Button("Ask", variant="primary")
            btn_logout = gr.Button("Log out")

        # Feedback on the most recent answer.
        with gr.Row():
            fb_up = gr.Button("👍 Helpful", visible=True)
            fb_down = gr.Button("👎 Not helpful", visible=True)
        fb_status = gr.Markdown("", visible=True)

        # Export of the logged-in user's own feedback.
        with gr.Row():
            my_export_btn = gr.Button("Download my feedback (CSV)")
            my_export_file = gr.File(label="Your feedback CSV", visible=False)

    # ---- Admin panel (controls stay hidden until the token matches) --------
    admin_view = gr.Column(visible=True)
    with admin_view:
        gr.Markdown("### Admin (enter token to unlock)")
        admin_token_in = gr.Textbox(label="Admin token", type="password", placeholder="Set ADMIN_TOKEN env to use")
        admin_unlock = gr.Button("Unlock Admin")
        admin_status = gr.Markdown("")
        with gr.Group(visible=False) as admin_controls:
            all_export_btn = gr.Button("Export ALL feedback (CSV)")
            users_export_btn = gr.Button("Export users (CSV)")
            all_export_file = gr.File(label="All feedback CSV", visible=False)
            users_export_file = gr.File(label="Users CSV", visible=False)

    examples = gr.Examples(
        examples=[
            "Define axial rigidity (EA) and how it is estimated from CT-derived cortical masks.",
            "How does torsional rigidity (GJ) relate to polar moment of area in FE pre-processing?",
            "What are typical pitfalls when mapping Hounsfield units to elastic modulus?",
            "What boundary conditions are common in long-bone FE bending simulations?",
        ],
        inputs=inp,
    )

    # ---- Registration -------------------------------------------------------
    def on_register(first, last, email, password, lab_label):
        """Register a user and return the status message for reg_msg."""
        ok, msg = register_user(first, last, email, password, lab_label)
        # Exactly one value for the single output component; the original
        # returned three values for one output.
        return gr.update(value=msg)

    reg_btn.click(on_register, [reg_first, reg_last, reg_email, reg_pass, reg_lab], [reg_msg])

    # ---- Login --------------------------------------------------------------
    def on_login(email, password):
        """Authenticate, switch to the chat view and populate session state."""
        ok, msg, uid, first, last, lab_code = login_user(email, password)
        if not ok:
            return (
                gr.update(value=msg),  # login_msg
                gr.update(),           # welcome_md unchanged
                gr.update(),           # auth_view unchanged
                gr.update(),           # chat_view unchanged
                None, None, "", "", "nazarian",
            )

        normalized_email = email.strip().lower()
        label = to_lab_label(lab_code)
        welcome = f"Welcome, **{first} {last}** ({label}) — you are logged in as **{normalized_email}**."
        return (
            gr.update(value=""),  # login_msg lives in the view that is now hidden
            # Show the welcome text in the chat view, where it is visible.
            gr.update(value=f"### Chat\n\n{welcome}"),
            gr.update(visible=False),
            gr.update(visible=True),
            uid, normalized_email, first, last, lab_code,
        )

    login_btn.click(
        on_login,
        [login_email, login_pass],
        [login_msg, welcome_md, auth_view, chat_view,
         st_user_id, st_user_email, st_user_fname, st_user_lname, st_user_lab],
    )

    # ---- Logout -------------------------------------------------------------
    def on_logout():
        """Reset session state and return to the auth view."""
        return (
            None, None, "", "", "nazarian", "", "", False,
            gr.update(visible=True),     # auth_view
            gr.update(visible=False),    # chat_view
            gr.update(value=""),         # fb_status
            gr.update(value="### Chat"), # welcome_md back to its initial text
            gr.update(value=""),         # clear the previous user's answer
        )

    btn_logout.click(
        on_logout,
        [],
        [st_user_id, st_user_email, st_user_fname, st_user_lname, st_user_lab,
         st_last_q, st_last_a, st_can_fb, auth_view, chat_view, fb_status,
         welcome_md, out],
    )

    # ---- Ask ----------------------------------------------------------------
    def on_ask(user_id, q, lab_code):
        """Answer a question and remember the Q/A pair for feedback.

        Returns ``(answer_box_update, last_question, last_answer, can_feedback)``.
        """
        q = (q or "").strip()
        # Hiding the chat view alone does not stop a session that never logged
        # in from triggering this handler, so check the session explicitly.
        if not user_id:
            return gr.update(value="Please log in first."), "", "", False
        if not q:
            return gr.update(value="Please enter a question."), q, "", False
        try:
            # The lab label is appended to the question as a hint.
            lab_hint = to_lab_label(lab_code)
            q_eff = f"{q} ({lab_hint})"
            a = ask(q_eff)
        except Exception as e:
            a = f"[runtime error] {e}"
        return gr.update(value=a), q, a, True

    btn_submit.click(
        on_ask,
        [st_user_id, inp, st_user_lab],
        [out, st_last_q, st_last_a, st_can_fb],
    )
    inp.submit(
        on_ask,
        [st_user_id, inp, st_user_lab],
        [out, st_last_q, st_last_a, st_can_fb],
    )

    # ---- Feedback -----------------------------------------------------------
    def on_feedback(user_id, can_fb, last_q, last_a, rating):
        """Save a rating for the last answer; each answer can be rated once."""
        if not can_fb or not last_q or not last_a:
            return gr.update(value="No recent answer to rate."), False
        ok, msg = save_feedback(user_id, last_q, last_a, rating)
        return gr.update(value=msg), False

    fb_up.click(
        lambda uid, can, q, a: on_feedback(uid, can, q, a, +1),
        [st_user_id, st_can_fb, st_last_q, st_last_a],
        [fb_status, st_can_fb],
    )
    fb_down.click(
        lambda uid, can, q, a: on_feedback(uid, can, q, a, -1),
        [st_user_id, st_can_fb, st_last_q, st_last_a],
        [fb_status, st_can_fb],
    )

    # ---- Personal export ------------------------------------------------------
    def on_my_export(user_id):
        """Export the logged-in user's feedback and reveal the download."""
        if not user_id:
            return gr.update(visible=False), gr.update(value="Please log in first.")
        path = export_feedback_csv(all_users=False, user_id=user_id)
        return gr.update(value=path, visible=True), gr.update(value="")

    my_export_btn.click(
        on_my_export,
        [st_user_id],
        [my_export_file, fb_status],
    )

    # ---- Admin ----------------------------------------------------------------
    def on_admin_unlock(token):
        """Unlock the admin controls when the token matches ADMIN_TOKEN.

        An empty ADMIN_TOKEN disables unlocking entirely.
        """
        supplied = (token or "").strip()
        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        ok = bool(ADMIN_TOKEN) and hmac.compare_digest(
            supplied.encode("utf-8"), ADMIN_TOKEN.encode("utf-8")
        )
        if ok:
            return True, gr.update(value="Admin unlocked ✅"), gr.update(visible=True)
        msg = "Invalid token or ADMIN_TOKEN not set."
        return False, gr.update(value=msg), gr.update(visible=False)

    admin_unlock.click(
        on_admin_unlock,
        [admin_token_in],
        [st_admin_ok, admin_status, admin_controls],
    )

    def on_export_all(admin_ok):
        """Export all feedback; refused while the admin panel is locked."""
        if not admin_ok:
            return gr.update(visible=False), gr.update(value="Admin is locked.")
        path = export_feedback_csv(all_users=True)
        return gr.update(value=path, visible=True), gr.update(value="")

    def on_export_users(admin_ok):
        """Export all users; refused while the admin panel is locked."""
        if not admin_ok:
            return gr.update(visible=False), gr.update(value="Admin is locked.")
        path = export_users_csv()
        return gr.update(value=path, visible=True), gr.update(value="")

    all_export_btn.click(
        on_export_all,
        [st_admin_ok],
        [all_export_file, admin_status],
    )
    users_export_btn.click(
        on_export_users,
        [st_admin_ok],
        [users_export_file, admin_status],
    )
|
|
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces; the PORT environment variable overrides the
    # default port 7860.
    listen_port = int(os.environ.get("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)
|
|