|
|
|
|
|
|
|
""" |
|
open_ended_question_generator_secure.py |
|
|
|
End-to-end script to generate open-ended questions from context(s) with: |
|
- Robust list-formatted parsing |
|
- CLI with single or batch inputs (TXT/CSV) |
|
- Reproducibility (seed) |
|
- Device auto-select (CUDA / MPS / CPU) |
|
- Export to JSON / CSV / TXT |
|
- Optional password-based authenticated encryption via Fernet (AES-128-CBC + HMAC-SHA256, key derived with PBKDF2-HMAC-SHA256)
|
- Optional decryption utility |
|
|
|
Dependencies: |
|
pip install torch transformers cryptography |
|
|
|
Example: |
|
python open_ended_question_generator_secure.py --generate \
--context "AGI for cosmology" --n 5 --model gpt2-large \
--out questions.json --format json --encrypt --password "your-secret"
|
""" |
|
|
|
import os |
|
import re |
|
import csv |
|
import json |
|
import argparse |
|
import getpass |
|
import base64 |
|
import sys |
|
from typing import List, Dict, Tuple, Optional |
|
|
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
try: |
|
from cryptography.fernet import Fernet |
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC |
|
from cryptography.hazmat.primitives import hashes |
|
from cryptography.hazmat.backends import default_backend |
|
except Exception: |
|
Fernet = None |
|
|
|
|
|
|
|
|
|
|
|
def select_device() -> torch.device:
    """Choose the torch device to run on.

    Apple MPS is checked first, then CUDA; CPU is the fallback when
    neither accelerator reports itself as available.
    """
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if has_mps:
        name = "mps"
    elif torch.cuda.is_available():
        name = "cuda"
    else:
        name = "cpu"
    return torch.device(name)
|
|
|
|
|
|
|
|
|
|
|
# Instruction prompt sent to the model. `{context}` and `{n}` are filled in by
# build_prompt(). The "1. ... / 2. ... / 3. ..." lines only illustrate the
# expected output format.
PROMPT_TEMPLATE = """You are a master at generating deep, open-ended, and thought-provoking questions.
Each question must be:
- Self-contained and understandable without extra context.
- Exploratory (not answerable with yes/no).
- Written in clear, engaging language.

Context:
{context}

Output exactly {n} questions as a numbered list, one per line, formatted like:
1. ...
2. ...
3. ...
No extra commentary, no headings, no explanations — just the list.
"""


def build_prompt(context: str, n: int) -> str:
    """Return the generation prompt for one context.

    Args:
        context: Source text the questions should be about; surrounding
            whitespace is stripped before insertion.
        n: Number of questions the prompt asks the model to produce.

    Returns:
        PROMPT_TEMPLATE with `context` and `n` substituted.
    """
    return PROMPT_TEMPLATE.format(context=context.strip(), n=n)
|
|
|
# Matches one numbered-list line such as "  3. Some question text".
# Group 1 is the item number; group 2 is the item text with outer whitespace
# removed.
_Q_LINE_RE = re.compile(r"^\s*(\d+)\.\s+(.*\S)\s*$")


def normalize_q(q: str) -> str:
    """Return `q` stripped and guaranteed to end with a question mark.

    Text that already ends in "?" is only stripped. Otherwise trailing
    sentence punctuation is dropped before "?" is appended, so
    "Explain entropy." becomes "Explain entropy?" rather than
    "Explain entropy.?".
    """
    q = q.strip()
    if not q.endswith("?"):
        q = q.rstrip(".!;:, \t") + "?"
    return q


def parse_questions_from_text(text: str, n: int) -> List[str]:
    """Extract up to `n` distinct questions from numbered-list text.

    Only lines of the form "<number>. <text>" are considered. Items with
    no letters or digits (for example a literal "..." placeholder) are
    skipped, and duplicates are detected case-insensitively while keeping
    the first occurrence.

    Args:
        text: Raw text to scan, typically model output.
        n: Maximum number of questions to return; values below 1 yield [].

    Returns:
        Normalized questions in order of first appearance.
    """
    seen = set()
    unique: List[str] = []
    for line in text.splitlines():
        if len(unique) >= n:
            break
        match = _Q_LINE_RE.match(line)
        if match is None:
            continue
        raw = match.group(2)
        # A placeholder such as "..." carries no question content.
        if not any(ch.isalnum() for ch in raw):
            continue
        question = normalize_q(raw)
        key = question.lower().strip()
        if key in seen:
            continue
        seen.add(key)
        unique.append(question)
    return unique
|
|
|
|
|
|
|
|
|
|
|
def load_model_and_tokenizer(model_name: str, device: torch.device):
    """Load a causal language model and its tokenizer by name.

    The model is moved to `device`. When the tokenizer has no pad token id
    but does have an EOS token id, the EOS id is reused as the pad id.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)

    eos_id = tokenizer.eos_token_id
    if eos_id is not None and tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = eos_id

    return model, tokenizer
|
|
|
def generate_questions_once(
    model,
    tokenizer,
    device: torch.device,
    context: str,
    n: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> List[str]:
    """Run one sampling pass and parse the questions it produced.

    Args:
        model: Causal language model exposing `generate`.
        tokenizer: Tokenizer matching `model`.
        device: Device the encoded prompt is moved to.
        context: Source text the questions should be about.
        n: Number of questions requested from the model.
        max_new_tokens: Generation budget in new tokens.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cutoff.

    Returns:
        Up to `n` parsed questions; possibly fewer if the model's output
        did not contain enough numbered lines.
    """
    prompt = build_prompt(context, n)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]

    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # For a causal LM the returned sequence starts with the prompt tokens.
    # Decode only the continuation so that the prompt's own "1. ..." example
    # lines are not parsed as if they were generated questions.
    generated_ids = output[0][prompt_len:]
    decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return parse_questions_from_text(decoded, n)
|
|
|
def generate_questions(
    model,
    tokenizer,
    device: torch.device,
    context: str,
    n: int = 3,
    max_new_tokens: int = 200,
    temperature: float = 0.95,
    top_p: float = 0.95,
    top_k: int = 50,
    seed: Optional[int] = None,
    attempts: int = 3,
) -> List[str]:
    """Collect exactly `n` questions for `context`, retrying as needed.

    When `seed` is given, torch's RNG is seeded first (and CUDA's as well
    when `device` is a CUDA device). Up to `attempts` sampling passes are
    made; each pass raises the temperature by 0.1 over the previous one,
    clamped to the range [0.7, 1.2]. Questions are de-duplicated
    case-insensitively across passes. If there are still fewer than `n`
    questions, the list is padded with a generic question or with
    " (expand)" variants of the last collected one.

    Returns:
        A list of `n` questions (empty when `n` is not positive).
    """
    if seed is not None:
        torch.manual_seed(seed)
        if device.type == "cuda":
            torch.cuda.manual_seed_all(seed)

    collected: List[str] = []
    seen_keys = set()
    attempt = 0
    while attempt < attempts and len(collected) < n:
        # Nudge the temperature upward on every retry to diversify output.
        temp = min(1.2, max(0.7, temperature + 0.1 * attempt))
        attempt += 1
        batch = generate_questions_once(
            model, tokenizer, device, context, n, max_new_tokens, temp, top_p, top_k
        )
        for question in batch:
            if len(collected) >= n:
                break
            key = question.lower().strip()
            if key in seen_keys:
                continue
            collected.append(question)
            seen_keys.add(key)

    # Pad so that callers always receive n entries.
    while len(collected) < n:
        if collected:
            collected.append(collected[-1] + " (expand)")
        else:
            collected.append("What deeper questions arise from this context?")

    return collected[:n]
|
|
|
|
|
|
|
|
|
|
|
def load_contexts(source_text: Optional[str], source_file: Optional[str]) -> List[Tuple[str, str]]:
    """Return the contexts to process as a list of (context_id, context_text).

    - If `source_text` is provided, a single-item list is returned and
      `source_file` is ignored.
    - CSV file: expects a 'context' column; rows with an empty context are
      skipped, and ids follow the data-row number ("context_<row>").
    - Any other file (TXT/MD): split on lines containing only '---', or the
      whole file is one context when there is no separator.

    Files are read as UTF-8; a leading byte-order mark is tolerated.

    Raises:
        ValueError: Neither input was given, the CSV has no 'context'
            column (including a completely empty CSV), or the file yields
            no non-empty context.
        FileNotFoundError: `source_file` does not exist.
    """
    if source_text:
        return [("context_1", source_text.strip())]
    if not source_file:
        raise ValueError("Either --context or --context-file is required.")
    if not os.path.exists(source_file):
        raise FileNotFoundError(f"Context file not found: {source_file}")

    out: List[Tuple[str, str]] = []
    ext = os.path.splitext(source_file)[1].lower()
    if ext == ".csv":
        # utf-8-sig strips a BOM, which would otherwise hide the header name.
        with open(source_file, "r", encoding="utf-8-sig", newline="") as f:
            reader = csv.DictReader(f)
            # fieldnames is None when the file has no header row at all.
            if "context" not in (reader.fieldnames or []):
                raise ValueError("CSV must have a 'context' column.")
            for i, row in enumerate(reader, start=1):
                ctx = (row.get("context") or "").strip()
                if ctx:
                    out.append((f"context_{i}", ctx))
    else:
        with open(source_file, "r", encoding="utf-8-sig") as f:
            content = f.read()
        parts = re.split(r"^\s*---\s*$", content, flags=re.MULTILINE)
        stripped_parts = (p.strip() for p in parts)
        for i, ctx in enumerate((p for p in stripped_parts if p), start=1):
            out.append((f"context_{i}", ctx))

    # Both file types report an input without usable contexts the same way.
    if not out:
        raise ValueError("No context found in file.")
    return out
|
|
|
|
|
|
|
|
|
|
|
def write_json(out_path: str, rows: List[Dict]):
    """Write `rows` to `out_path` as UTF-8 JSON, indented by two spaces.

    Non-ASCII characters are written as-is rather than escaped.
    """
    with open(out_path, "w", encoding="utf-8") as handle:
        handle.write(json.dumps(rows, ensure_ascii=False, indent=2))
|
|
|
def write_csv(out_path: str, rows: List[Dict], n: int):
    """Write `rows` to `out_path` as a UTF-8 CSV file with a header row.

    Columns are "context_id", "context", then "q1" through "q<n>".
    """
    columns = ["context_id", "context"]
    columns.extend(f"q{idx}" for idx in range(1, n + 1))
    with open(out_path, "w", encoding="utf-8", newline="") as handle:
        sheet = csv.DictWriter(handle, fieldnames=columns)
        sheet.writeheader()
        sheet.writerows(rows)
|
|
|
def write_txt(out_path: str, rows: List[Dict], n: int):
    """Write `rows` to `out_path` as plain UTF-8 text.

    Each row becomes a "[context_id]" line, the stripped context, the
    questions q1..q<n> as a numbered list, and a trailing blank line.
    """
    with open(out_path, "w", encoding="utf-8") as handle:
        for row in rows:
            section = [f"[{row['context_id']}]", row["context"].strip()]
            section.extend(f"{idx}. {row[f'q{idx}']}" for idx in range(1, n + 1))
            handle.write("\n".join(section) + "\n\n")
|
|
|
|
|
|
|
|
|
|
|
# Marker bytes written at the start of every encrypted file and checked
# again before decryption.
MAGIC = b"QSEC1"


def require_crypto():
    """Raise RuntimeError when the module-level `Fernet` name is None.

    `Fernet` is set to None when the optional `cryptography` import failed.
    """
    if Fernet is not None:
        return
    raise RuntimeError("Encryption requested but 'cryptography' is not installed. Run: pip install cryptography")
|
|
|
def derive_key_from_password(password: str, salt: bytes) -> bytes:
    """Derive a Fernet-compatible key from a password and salt.

    Runs PBKDF2-HMAC-SHA256 with 200,000 iterations over the UTF-8 encoded
    password to obtain 32 bytes, then returns them urlsafe-base64 encoded.
    """
    password_bytes = password.encode("utf-8")
    raw_key = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=200_000,
        backend=default_backend(),
    ).derive(password_bytes)
    return base64.urlsafe_b64encode(raw_key)
|
|
|
def encrypt_file(in_path: str, out_path: str, password: str):
    """Encrypt the file at `in_path` into `out_path` using `password`.

    Output layout: MAGIC marker, a fresh 16-byte random salt, then the
    Fernet token produced with the key derived from `password` and that
    salt. The input file is left untouched.

    Raises:
        RuntimeError: `cryptography` is not installed.
    """
    require_crypto()

    with open(in_path, "rb") as src:
        plaintext = src.read()

    salt = os.urandom(16)
    token = Fernet(derive_key_from_password(password, salt)).encrypt(plaintext)

    with open(out_path, "wb") as dst:
        dst.write(b"".join((MAGIC, salt, token)))
|
|
|
def decrypt_file(in_path: str, out_path: str, password: str):
    """Decrypt a file written by encrypt_file() into `out_path`.

    Raises:
        RuntimeError: `cryptography` is not installed.
        ValueError: The input does not start with MAGIC or is too short to
            hold the marker, a 16-byte salt and at least one more byte.
    """
    require_crypto()

    with open(in_path, "rb") as src:
        blob = src.read()

    salt_start = len(MAGIC)
    salt_end = salt_start + 16
    if len(blob) <= salt_end or not blob.startswith(MAGIC):
        raise ValueError("Invalid or unsupported encrypted file.")

    key = derive_key_from_password(password, blob[salt_start:salt_end])
    plaintext = Fernet(key).decrypt(blob[salt_end:])

    with open(out_path, "wb") as dst:
        dst.write(plaintext)
|
|
|
|
|
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser covering generate and decrypt modes."""
    parser = argparse.ArgumentParser(description="Generate deep open-ended questions with optional encryption/decryption.")
    # Exactly one of --generate / --decrypt must be given.
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument("--generate", action="store_true", help="Generate questions from context(s).")
    mode.add_argument("--decrypt", action="store_true", help="Decrypt an encrypted file (no generation).")

    # Generation inputs and sampling controls.
    parser.add_argument("--context", type=str, help="Inline context text.")
    parser.add_argument("--context-file", type=str, help="Path to TXT/MD (split by ---) or CSV with 'context' column.")
    parser.add_argument("--n", type=int, default=3, help="Number of questions to generate per context.")
    parser.add_argument("--model", type=str, default="gpt2-large", help="HuggingFace model name.")
    parser.add_argument("--max-new-tokens", type=int, default=220, help="Max new tokens for generation.")
    parser.add_argument("--temperature", type=float, default=0.95, help="Sampling temperature.")
    parser.add_argument("--top-p", type=float, default=0.95, help="Top-p nucleus sampling.")
    parser.add_argument("--top-k", type=int, default=50, help="Top-k sampling.")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
    parser.add_argument("--attempts", type=int, default=3, help="Max attempts to reach exactly n questions.")

    # Output and encryption options.
    parser.add_argument("--out", type=str, default=None, help="Output file path. If omitted, prints to stdout.")
    parser.add_argument("--format", type=str, choices=["json", "csv", "txt"], default="json", help="Output format when generating.")
    parser.add_argument("--encrypt", action="store_true", help="Encrypt the output file after generation.")
    parser.add_argument("--password", type=str, default=None, help="Password for encryption/decryption. If omitted, prompts securely.")

    # Decryption-only options.
    parser.add_argument("--in", dest="in_path", type=str, help="Input file for decryption (encrypted).")
    parser.add_argument("--out-decrypted", type=str, help="Output file for decrypted plaintext.")
    return parser


def _print_rows(rows: List[Dict], fmt: str, n: int) -> None:
    """Print generated rows to stdout as "json", "csv" or plain text."""
    if fmt == "json":
        print(json.dumps(rows, ensure_ascii=False, indent=2))
    elif fmt == "csv":
        fieldnames = ["context_id", "context"] + [f"q{i}" for i in range(1, n + 1)]
        writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    else:
        for r in rows:
            print(f"[{r['context_id']}]")
            print(r["context"].strip())
            for i in range(1, n + 1):
                print(f"{i}. {r[f'q{i}']}")
            print()


def main():
    """Command-line entry point.

    In --decrypt mode, decrypts --in to --out-decrypted and exits. In
    --generate mode, loads the contexts and the model, generates --n
    questions per context, and either writes them to --out (optionally
    producing an encrypted "<out>.enc" copy alongside it) or prints them
    to stdout. Invalid option combinations end the program through
    `parser.error`.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.decrypt:
        if not args.in_path or not args.out_decrypted:
            parser.error("--decrypt requires --in and --out-decrypted.")
        password = args.password or getpass.getpass("Enter password: ")
        decrypt_file(args.in_path, args.out_decrypted, password)
        print(f"Decrypted to: {args.out_decrypted}")
        return

    # Validate before any expensive work (model download, generation).
    if args.n < 1:
        parser.error("--n must be at least 1.")
    if args.encrypt and not args.out:
        # Without a file there is nothing to encrypt; refusing is safer than
        # silently printing plaintext to stdout.
        parser.error("--encrypt requires --out.")

    # The device is only needed for generation, so pick it after the
    # decrypt branch has returned.
    device = select_device()
    contexts = load_contexts(args.context, args.context_file)
    model, tokenizer = load_model_and_tokenizer(args.model, device)

    rows: List[Dict] = []
    for ctx_id, ctx in contexts:
        qs = generate_questions(
            model=model,
            tokenizer=tokenizer,
            device=device,
            context=ctx,
            n=args.n,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            seed=args.seed,
            attempts=args.attempts,
        )
        row = {"context_id": ctx_id, "context": ctx}
        for i, q in enumerate(qs, start=1):
            row[f"q{i}"] = q
        rows.append(row)

    if not args.out:
        _print_rows(rows, args.format, args.n)
        return

    out_path = args.out
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    if args.format == "json":
        write_json(out_path, rows)
    elif args.format == "csv":
        write_csv(out_path, rows, args.n)
    else:
        write_txt(out_path, rows, args.n)
    print(f"Saved: {out_path}")

    if args.encrypt:
        password = args.password or getpass.getpass("Enter password: ")
        enc_path = out_path + ".enc"
        encrypt_file(out_path, enc_path, password)
        print(f"Encrypted copy: {enc_path}")


if __name__ == "__main__":
    main()