#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
open_ended_question_generator_secure.py

End-to-end script to generate open-ended questions from context(s) with:
- Robust list-formatted parsing
- CLI with single or batch inputs (TXT/CSV)
- Reproducibility (seed)
- Device auto-select (CUDA / MPS / CPU)
- Export to JSON / CSV / TXT
- Optional authenticated encryption via Fernet (AES-128-CBC + HMAC-SHA256,
  with PBKDF2 key derivation)
- Optional decryption utility

Dependencies:
    pip install torch transformers cryptography

Example:
    python open_ended_question_generator_secure.py --generate \
        --context "AGI for cosmology" --n 5 --model gpt2-large \
        --out questions.json --format json --encrypt --password "your-secret"
"""

import os
import re
import csv
import json
import argparse
import getpass
import base64
import sys
from typing import List, Dict, Tuple, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Optional encryption deps ---
try:
    from cryptography.fernet import Fernet
    from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.backends import default_backend
except Exception:
    Fernet = None  # Validated at runtime if encryption/decryption is used.


# ----------------------------
# Device selection
# ----------------------------
def select_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# ----------------------------
# Prompt and parsing
# ----------------------------
PROMPT_TEMPLATE = """You are a master at generating deep, open-ended, and thought-provoking questions.
Each question must be:
- Self-contained and understandable without extra context.
- Exploratory (not answerable with yes/no).
- Written in clear, engaging language.

Context:
{context}

Output exactly {n} questions as a numbered list, one per line, formatted like:
1. ...
2. ...
3. ...

No extra commentary, no headings, no explanations — just the list.
"""


def build_prompt(context: str, n: int) -> str:
    return PROMPT_TEMPLATE.format(context=context.strip(), n=n)


_Q_LINE_RE = re.compile(r"^\s*(\d+)\.\s+(.*\S)\s*$")


def normalize_q(q: str) -> str:
    q = q.strip()
    # Ensure it ends with a question mark for consistency
    if not q.endswith("?"):
        q += "?"
    return q
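
# Illustrative parsing contract (the sample strings below are assumptions, not
# real model output): numbered lines are matched, duplicates are dropped while
# preserving order, and each question is normalized to end with "?".
#
#   sample = ("Sure, here you go:\n"
#             "1. How might AGI reshape cosmology\n"
#             "2. How might AGI reshape cosmology\n"
#             "3. What evidence would settle it\n")
#   parse_questions_from_text(sample, 2)
#   -> ["How might AGI reshape cosmology?", "What evidence would settle it?"]
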
def parse_questions_from_text(text: str, n: int) -> List[str]:
    lines = text.splitlines()
    candidates = []
    for line in lines:
        m = _Q_LINE_RE.match(line)
        if m:
            q_text = normalize_q(m.group(2))
            candidates.append(q_text)
    # Deduplicate while preserving order
    seen = set()
    unique = []
    for q in candidates:
        key = q.lower().strip()
        if key not in seen:
            seen.add(key)
            unique.append(q)
    return unique[:n]


# ----------------------------
# Model loading and generation
# ----------------------------
def load_model_and_tokenizer(model_name: str, device: torch.device):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    model.eval()
    # For models like GPT-2 without a pad token
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return model, tokenizer


def generate_questions_once(
    model,
    tokenizer,
    device: torch.device,
    context: str,
    n: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> List[str]:
    prompt = build_prompt(context, n)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Decode only the continuation: decoding the full sequence would also feed
    # the prompt's own "1. ..." template lines into the numbered-list parser.
    prompt_len = inputs["input_ids"].shape[1]
    decoded = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    questions = parse_questions_from_text(decoded, n)
    return questions


def generate_questions(
    model,
    tokenizer,
    device: torch.device,
    context: str,
    n: int = 3,
    max_new_tokens: int = 200,
    temperature: float = 0.95,
    top_p: float = 0.95,
    top_k: int = 50,
    seed: Optional[int] = None,
    attempts: int = 3,
) -> List[str]:
    if seed is not None:
        torch.manual_seed(seed)
        if device.type == "cuda":
            torch.cuda.manual_seed_all(seed)

    collected: List[str] = []
    tried = 0
    while len(collected) < n and tried < attempts:
        tried += 1
        # Slightly adjust temperature on retries to improve variety
        temp = min(1.2, max(0.7, temperature + 0.1 * (tried - 1)))
        qs = generate_questions_once(
            model, tokenizer, device, context, n, max_new_tokens, temp, top_p, top_k
        )
        # Merge unique
        existing = {q.lower().strip() for q in collected}
        for q in qs:
            key = q.lower().strip()
            if key not in existing and len(collected) < n:
                collected.append(q)
                existing.add(key)

    # If still short, pad with simple variants (rare)
    while len(collected) < n:
        if collected:
            collected.append(collected[-1] + " (expand)")
        else:
            collected.append("What deeper questions arise from this context?")
    return collected[:n]
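
# Minimal end-to-end sketch (assumes the model weights can be downloaded;
# "gpt2" is an illustrative choice, any causal-LM checkpoint id works):
#
#   device = select_device()
#   model, tokenizer = load_model_and_tokenizer("gpt2", device)
#   qs = generate_questions(model, tokenizer, device,
#                           "AGI for cosmology", n=3, seed=42)
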
""" out: List[Tuple[str, str]] = [] if source_text: out.append(("context_1", source_text.strip())) return out if not source_file: raise ValueError("Either --context or --context-file is required.") if not os.path.exists(source_file): raise FileNotFoundError(f"Context file not found: {source_file}") ext = os.path.splitext(source_file)[1].lower() if ext == ".csv": with open(source_file, "r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) if "context" not in reader.fieldnames: raise ValueError("CSV must have a 'context' column.") for i, row in enumerate(reader, start=1): ctx = (row.get("context") or "").strip() if ctx: out.append((f"context_{i}", ctx)) else: # Plain text / markdown: split on '---' delimiter lines if present with open(source_file, "r", encoding="utf-8") as f: content = f.read() parts = re.split(r"^\s*---\s*$", content, flags=re.MULTILINE) parts = [p.strip() for p in parts if p.strip()] if not parts: raise ValueError("No context found in file.") for i, ctx in enumerate(parts, start=1): out.append((f"context_{i}", ctx)) return out # ---------------------------- # Output writers # ---------------------------- def write_json(out_path: str, rows: List[Dict]): with open(out_path, "w", encoding="utf-8") as f: json.dump(rows, f, ensure_ascii=False, indent=2) def write_csv(out_path: str, rows: List[Dict], n: int): fieldnames = ["context_id", "context"] + [f"q{i}" for i in range(1, n + 1)] with open(out_path, "w", encoding="utf-8", newline="") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() for r in rows: writer.writerow(r) def write_txt(out_path: str, rows: List[Dict], n: int): with open(out_path, "w", encoding="utf-8") as f: for r in rows: f.write(f"[{r['context_id']}]\n") f.write(r["context"].strip() + "\n") for i in range(1, n + 1): f.write(f"{i}. {r[f'q{i}']}\n") f.write("\n") # ---------------------------- # Encryption / Decryption # ---------------------------- MAGIC = b"QSEC1" def require_crypto(): if Fernet is None: raise RuntimeError("Encryption requested but 'cryptography' is not installed. 
Run: pip install cryptography") def derive_key_from_password(password: str, salt: bytes) -> bytes: kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=32, salt=salt, iterations=200_000, backend=default_backend(), ) key = kdf.derive(password.encode("utf-8")) return base64.urlsafe_b64encode(key) def encrypt_file(in_path: str, out_path: str, password: str): require_crypto() with open(in_path, "rb") as f: plaintext = f.read() salt = os.urandom(16) key = derive_key_from_password(password, salt) fernet = Fernet(key) ciphertext = fernet.encrypt(plaintext) with open(out_path, "wb") as f: f.write(MAGIC + salt + ciphertext) def decrypt_file(in_path: str, out_path: str, password: str): require_crypto() with open(in_path, "rb") as f: blob = f.read() if not blob.startswith(MAGIC) or len(blob) < len(MAGIC) + 16 + 1: raise ValueError("Invalid or unsupported encrypted file.") salt = blob[len(MAGIC):len(MAGIC)+16] ciphertext = blob[len(MAGIC)+16:] key = derive_key_from_password(password, salt) fernet = Fernet(key) plaintext = fernet.decrypt(ciphertext) with open(out_path, "wb") as f: f.write(plaintext) # ---------------------------- # Main CLI # ---------------------------- def main(): parser = argparse.ArgumentParser(description="Generate deep open-ended questions with optional encryption/decryption.") mode = parser.add_mutuallyExclusiveGroup(required=True) mode.add_argument("--generate", action="store_true", help="Generate questions from context(s).") mode.add_argument("--decrypt", action="store_true", help="Decrypt an encrypted file (no generation).") # Generation inputs parser.add_argument("--context", type=str, help="Inline context text.") parser.add_argument("--context-file", type=str, help="Path to TXT/MD (split by ---) or CSV with 'context' column.") parser.add_argument("--n", type=int, default=3, help="Number of questions to generate per context.") parser.add_argument("--model", type=str, default="gpt2-large", help="HuggingFace model name.") parser.add_argument("--max-new-tokens", type=int, default=220, help="Max new tokens for generation.") parser.add_argument("--temperature", type=float, default=0.95, help="Sampling temperature.") parser.add_argument("--top-p", type=float, default=0.95, help="Top-p nucleus sampling.") parser.add_argument("--top-k", type=int, default=50, help="Top-k sampling.") parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.") parser.add_argument("--attempts", type=int, default=3, help="Max attempts to reach exactly n questions.") # Output parser.add_argument("--out", type=str, default=None, help="Output file path. If omitted, prints to stdout.") parser.add_argument("--format", type=str, choices=["json", "csv", "txt"], default="json", help="Output format when generating.") parser.add_argument("--encrypt", action="store_true", help="Encrypt the output file after generation.") parser.add_argument("--password", type=str, default=None, help="Password for encryption/decryption. 
If omitted, prompts securely.") # Decryption I/O parser.add_argument("--in", dest="in_path", type=str, help="Input file for decryption (encrypted).") parser.add_argument("--out-decrypted", type=str, help="Output file for decrypted plaintext.") args = parser.parse_args() device = select_device() if args.decrypt: # Decrypt mode if not args.in_path or not args.out_decrypted: parser.error("--decrypt requires --in and --out-decrypted.") password = args.password or getpass.getpass("Enter password: ") decrypt_file(args.in_path, args.out_decrypted, password) print(f"Decrypted to: {args.out_decrypted}") return # Generate mode contexts = load_contexts(args.context, args.context_file) model, tokenizer = load_model_and_tokenizer(args.model, device) rows: List[Dict] = [] for ctx_id, ctx in contexts: qs = generate_questions( model=model, tokenizer=tokenizer, device=device, context=ctx, n=args.n, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, seed=args.seed, attempts=args.attempts, ) row = {"context_id": ctx_id, "context": ctx} for i, q in enumerate(qs, start=1): row[f"q{i}"] = q rows.append(row) # Output if args.out: out_path = args.out os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True) if args.format == "json": write_json(out_path, rows) elif args.format == "csv": write_csv(out_path, rows, args.n) else: write_txt(out_path, rows, args.n) if args.encrypt: password = args.password or getpass.getpass("Enter password: ") enc_path = out_path + ".enc" encrypt_file(out_path, enc_path, password) print(f"Saved: {out_path}") print(f"Encrypted copy: {enc_path}") else: print(f"Saved: {out_path}") else: # Print to stdout in selected format if args.format == "json": print(json.dumps(rows, ensure_ascii=False, indent=2)) elif args.format == "csv": # Minimal CSV to stdout fieldnames = ["context_id", "context"] + [f"q{i}" for i in range(1, args.n + 1)] writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames) writer.writeheader() for r in rows: writer.writerow(r) else: for r in rows: print(f"[{r['context_id']}]") print(r["context"].strip()) for i in range(1, args.n + 1): print(f"{i}. {r[f'q{i}']}") print() if __name__ == "__main__": main()