| |
| |
|
|
"""
open_ended_question_generator_secure.py

End-to-end script to generate open-ended questions from one or more contexts, with:
- Robust parsing of numbered-list model output
- CLI accepting a single inline context or a batch file (TXT/MD or CSV)
- Reproducibility (seed)
- Device auto-select (CUDA / MPS / CPU)
- Export to JSON / CSV / TXT
- Optional password-based authenticated encryption via Fernet (AES-128-CBC + HMAC-SHA256; key derived with PBKDF2-HMAC-SHA256)
- Optional decryption mode

Dependencies:
    pip install torch transformers cryptography

Example:
    python open_ended_question_generator_secure.py --generate \
        --context "AGI for cosmology" --n 5 --model gpt2-large \
        --out questions.json --format json --encrypt --password "your-secret"
"""
|
|
| import os |
| import re |
| import csv |
| import json |
| import argparse |
| import getpass |
| import base64 |
| import sys |
| from typing import List, Dict, Tuple, Optional |
|
|
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
| |
| try: |
| from cryptography.fernet import Fernet |
| from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC |
| from cryptography.hazmat.primitives import hashes |
| from cryptography.hazmat.backends import default_backend |
| except Exception: |
| Fernet = None |
|
|
|
|
| |
| |
| |
def select_device() -> torch.device:
    """Choose the torch device used for generation.

    Preference order: Apple MPS (when this torch build exposes the backend and
    it reports itself available), then CUDA, then CPU as the fallback.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        device_name = "mps"
    elif torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    return torch.device(device_name)
|
|
|
|
| |
| |
| |
# Instruction prompt sent to the model. ``{context}`` and ``{n}`` are filled in
# by build_prompt(); the "1. ..." lines only illustrate the expected list shape.
PROMPT_TEMPLATE = """You are a master at generating deep, open-ended, and thought-provoking questions.
Each question must be:
- Self-contained and understandable without extra context.
- Exploratory (not answerable with yes/no).
- Written in clear, engaging language.

Context:
{context}

Output exactly {n} questions as a numbered list, one per line, formatted like:
1. ...
2. ...
3. ...
No extra commentary, no headings, no explanations \u2014 just the list.
"""


def build_prompt(context: str, n: int) -> str:
    """Return the generation prompt asking for *n* questions about *context*.

    Leading/trailing whitespace is stripped from the context. Only the
    template's own ``{context}``/``{n}`` fields are substituted, so braces that
    appear inside the context text are inserted verbatim.
    """
    return PROMPT_TEMPLATE.format(context=context.strip(), n=n)
|
|
# A numbered list item such as "1. text" or "2) text"; group 2 captures the item text.
_Q_LINE_RE = re.compile(r"^\s*(\d+)[.)]\s+(.*\S)\s*$")


def normalize_q(q: str) -> str:
    """Return *q* trimmed and ending with a question mark.

    Trailing sentence punctuation (``.``, ``,``, ``;``, ``:``) is dropped before
    the ``?`` is appended, so "Explain X." becomes "Explain X?" rather than
    "Explain X.?". Text that already ends in ``?`` is only trimmed.
    """
    q = q.strip()
    if q.endswith("?"):
        return q
    return q.rstrip(".,;:").rstrip() + "?"


def parse_questions_from_text(text: str, n: int) -> List[str]:
    """Extract up to *n* unique questions from numbered-list lines in *text*.

    Lines that are not numbered items are ignored. Each item is normalized
    with normalize_q(). Items containing no letters or digits are skipped;
    this covers the "..." placeholders used in the prompt template. Duplicates
    are dropped case-insensitively, keeping the first occurrence.

    Returns an empty list when *n* is not positive.
    """
    if n <= 0:
        return []
    seen = set()
    unique: List[str] = []
    for line in text.splitlines():
        m = _Q_LINE_RE.match(line)
        if not m:
            continue
        q = normalize_q(m.group(2))
        if not any(ch.isalnum() for ch in q):
            # Placeholder / punctuation-only item, not a real question.
            continue
        key = q.lower().strip()
        if key in seen:
            continue
        seen.add(key)
        unique.append(q)
        if len(unique) == n:
            break
    return unique
|
|
|
|
| |
| |
| |
def load_model_and_tokenizer(model_name: str, device: torch.device):
    """Load a causal LM and its tokenizer by Hugging Face name.

    The model is moved to *device*. If the tokenizer has no pad token id but
    does have an EOS token id, the EOS id is reused as the pad id, so
    generate() can be given an explicit ``pad_token_id``.

    Returns:
        ``(model, tokenizer)``.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    lm = AutoModelForCausalLM.from_pretrained(model_name)
    lm.to(device)

    lacks_pad = tok.pad_token_id is None
    if lacks_pad and tok.eos_token_id is not None:
        tok.pad_token_id = tok.eos_token_id
    return lm, tok
|
|
def generate_questions_once(
    model,
    tokenizer,
    device: torch.device,
    context: str,
    n: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> List[str]:
    """Run one sampling pass and parse up to *n* questions from the model's continuation.

    Args:
        model: causal LM exposing ``generate()``.
        tokenizer: tokenizer matching *model*; its pad/eos ids are passed to generate().
        device: device the prompt tensors are moved to.
        context: context text inserted into the prompt.
        n: number of questions requested in the prompt and kept from the output.
        max_new_tokens, temperature, top_p, top_k: sampling settings for generate().

    Returns:
        Up to *n* unique, normalized questions. There may be fewer, or none, if
        the model did not produce a parseable numbered list.
    """
    prompt = build_prompt(context, n)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # For decoder-only models generate() returns prompt + continuation. Decode
    # only the new tokens: the prompt's own "1. ..." example lines would
    # otherwise be parsed as questions.
    new_tokens = output[0][prompt_len:]
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return parse_questions_from_text(decoded, n)
|
|
def generate_questions(
    model,
    tokenizer,
    device: torch.device,
    context: str,
    n: int = 3,
    max_new_tokens: int = 200,
    temperature: float = 0.95,
    top_p: float = 0.95,
    top_k: int = 50,
    seed: Optional[int] = None,
    attempts: int = 3,
) -> List[str]:
    """Generate *n* distinct questions for *context*, retrying if needed.

    Runs up to *attempts* sampling passes via generate_questions_once() and
    keeps questions not seen before (case-insensitive).

    The first pass uses *temperature* exactly as given. Each retry raises it by
    0.1 to diversify the output. Retries never exceed 1.2, or the caller's
    temperature if that is higher.

    If fewer than *n* questions are collected, a warning is printed to stderr
    and the list is padded with placeholders: the last collected question plus
    " (expand)", or a generic question when nothing was collected. Callers
    therefore always get ``max(n, 0)`` items.

    Args:
        seed: if given, seeds torch's RNG (and all CUDA devices when *device*
            is CUDA) before the first pass, for reproducible sampling.
        attempts: maximum number of generation passes.

    Returns:
        A list of ``max(n, 0)`` question strings.
    """
    if seed is not None:
        torch.manual_seed(seed)
        if device.type == "cuda":
            torch.cuda.manual_seed_all(seed)

    collected: List[str] = []
    seen = set()
    # Retries get hotter, but never hotter than 1.2 unless the user asked for more.
    temp_cap = max(temperature, 1.2)
    for attempt in range(attempts):
        if len(collected) >= n:
            break
        temp = min(temp_cap, temperature + 0.1 * attempt)
        qs = generate_questions_once(
            model, tokenizer, device, context, n, max_new_tokens, temp, top_p, top_k
        )
        for q in qs:
            key = q.lower().strip()
            if key not in seen and len(collected) < n:
                collected.append(q)
                seen.add(key)

    if len(collected) < n:
        print(
            f"Warning: only {len(collected)} of {n} questions generated; "
            "padding with placeholder questions.",
            file=sys.stderr,
        )
    while len(collected) < n:
        if collected:
            collected.append(collected[-1] + " (expand)")
        else:
            collected.append("What deeper questions arise from this context?")
    return collected[:n]
|
|
|
|
| |
| |
| |
def load_contexts(source_text: Optional[str], source_file: Optional[str]) -> List[Tuple[str, str]]:
    """
    Return a list of ``(context_id, context_text)`` pairs.

    - If *source_text* contains non-whitespace text, returns a single item
      ``("context_1", stripped_text)`` and ignores *source_file*.
    - If *source_file* is a ``.csv``: requires a 'context' column. Every
      non-empty cell becomes a context, with id ``context_<row number>``, so ids
      skip rows whose context is empty.
    - Otherwise (TXT/MD): splits on lines containing only '---'. Without
      separators, the whole file is one context.

    Files are read as UTF-8. A leading byte-order mark is tolerated.

    Raises:
        ValueError: if neither input is given, the CSV lacks a 'context'
            column (including an empty CSV), or no non-empty context is found.
        FileNotFoundError: if *source_file* is not an existing file.
    """
    if source_text and source_text.strip():
        return [("context_1", source_text.strip())]
    if not source_file:
        raise ValueError("Either --context or --context-file is required.")
    if not os.path.isfile(source_file):
        raise FileNotFoundError(f"Context file not found: {source_file}")

    out: List[Tuple[str, str]] = []
    ext = os.path.splitext(source_file)[1].lower()
    # "utf-8-sig" strips a leading BOM (common in Excel/Notepad exports). With
    # plain "utf-8" the first CSV header would read "\ufeffcontext".
    if ext == ".csv":
        with open(source_file, "r", encoding="utf-8-sig", newline="") as f:
            reader = csv.DictReader(f)
            # fieldnames is None for an empty file.
            if not reader.fieldnames or "context" not in reader.fieldnames:
                raise ValueError("CSV must have a 'context' column.")
            for i, row in enumerate(reader, start=1):
                ctx = (row.get("context") or "").strip()
                if ctx:
                    out.append((f"context_{i}", ctx))
    else:
        with open(source_file, "r", encoding="utf-8-sig") as f:
            content = f.read()
        parts = [p.strip() for p in re.split(r"^\s*---\s*$", content, flags=re.MULTILINE)]
        parts = [p for p in parts if p]
        out = [(f"context_{i}", ctx) for i, ctx in enumerate(parts, start=1)]

    # Same error for both formats instead of silently returning no work.
    if not out:
        raise ValueError("No context found in file.")
    return out
|
|
|
|
| |
| |
| |
def write_json(out_path: str, rows: List[Dict]):
    """Serialize *rows* to *out_path* as a JSON array.

    The file is UTF-8 and indented by two spaces. Non-ASCII text is written
    as-is rather than escaped.
    """
    with open(out_path, mode="w", encoding="utf-8") as handle:
        json.dump(rows, handle, indent=2, ensure_ascii=False)
|
|
def write_csv(out_path: str, rows: List[Dict], n: int):
    """Write *rows* to *out_path* as UTF-8 CSV.

    Columns are ``context_id, context, q1 .. qn``. Missing question keys are
    written as empty cells; extra keys make csv.DictWriter raise ValueError.
    """
    columns = ["context_id", "context", *(f"q{k}" for k in range(1, n + 1))]
    with open(out_path, "w", encoding="utf-8", newline="") as handle:
        dict_writer = csv.DictWriter(handle, fieldnames=columns)
        dict_writer.writeheader()
        dict_writer.writerows(rows)
|
|
def write_txt(out_path: str, rows: List[Dict], n: int):
    """Write *rows* to *out_path* as human-readable UTF-8 text.

    Each row produces a ``[context_id]`` header line, the stripped context, and
    ``1. ...`` through ``n. ...`` question lines, followed by a blank line.
    Every row must contain keys ``q1`` .. ``qn``.
    """
    with open(out_path, "w", encoding="utf-8") as handle:
        for row in rows:
            print(f"[{row['context_id']}]", file=handle)
            print(row["context"].strip(), file=handle)
            for idx in range(1, n + 1):
                print(f"{idx}. {row[f'q{idx}']}", file=handle)
            print(file=handle)
|
|
|
|
| |
| |
| |
# Signature bytes written at the very start of every encrypted output file.
MAGIC = b"QSEC1"


def require_crypto():
    """Fail fast when the optional ``cryptography`` dependency is unavailable.

    Raises:
        RuntimeError: if the guarded cryptography import failed, leaving the
            module-level ``Fernet`` set to ``None``.
    """
    if Fernet is not None:
        return
    raise RuntimeError("Encryption requested but 'cryptography' is not installed. Run: pip install cryptography")
|
|
def derive_key_from_password(password: str, salt: bytes) -> bytes:
    """Derive a Fernet key from *password* and *salt*.

    Uses PBKDF2-HMAC-SHA256 with 200,000 iterations and a 32-byte output,
    returned URL-safe base64-encoded as Fernet expects. The encrypted file
    format stores only the salt, so these parameters must stay fixed or
    existing files become undecryptable.

    Implemented with the standard-library ``hashlib.pbkdf2_hmac``. This is the
    same RFC 8018 PBKDF2 as cryptography's ``PBKDF2HMAC``, so keys are
    unchanged, and it avoids the deprecated ``backend=default_backend()``
    argument.
    """
    import hashlib  # local import: only key derivation needs it

    raw_key = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 200_000, dklen=32)
    return base64.urlsafe_b64encode(raw_key)
|
|
def encrypt_file(in_path: str, out_path: str, password: str):
    """Encrypt the file at *in_path* with *password* and write the result to *out_path*.

    A fresh random 16-byte salt is drawn on every call and passed to
    derive_key_from_password(). The output bytes are ``MAGIC + salt + token``,
    where ``token`` is the Fernet token for the entire plaintext.

    Raises:
        RuntimeError: if ``cryptography`` is not installed.
    """
    require_crypto()
    with open(in_path, "rb") as src:
        data = src.read()
    salt = os.urandom(16)
    token = Fernet(derive_key_from_password(password, salt)).encrypt(data)
    with open(out_path, "wb") as dst:
        dst.write(b"".join((MAGIC, salt, token)))
|
|
def decrypt_file(in_path: str, out_path: str, password: str):
    """Decrypt a file produced by encrypt_file() and write the plaintext to *out_path*.

    Expects the layout ``MAGIC + 16-byte salt + Fernet token``. *out_path* is
    only written after decryption succeeds.

    Raises:
        RuntimeError: if ``cryptography`` is not installed.
        ValueError: if the header is missing or truncated, or if the token
            cannot be decrypted (wrong password or corrupted/tampered file).
    """
    require_crypto()
    # require_crypto() passing means the guarded cryptography imports succeeded.
    from cryptography.fernet import InvalidToken

    salt_len = 16
    header_len = len(MAGIC) + salt_len
    with open(in_path, "rb") as f:
        blob = f.read()
    # Need the full header plus at least one byte of token.
    if not blob.startswith(MAGIC) or len(blob) <= header_len:
        raise ValueError("Invalid or unsupported encrypted file.")
    salt = blob[len(MAGIC):header_len]
    ciphertext = blob[header_len:]
    fernet = Fernet(derive_key_from_password(password, salt))
    try:
        plaintext = fernet.decrypt(ciphertext)
    except InvalidToken as err:
        # Fernet raises InvalidToken for any invalid token, so a wrong password
        # and a corrupted file cannot be told apart here.
        raise ValueError("Decryption failed: wrong password or corrupted file.") from err
    with open(out_path, "wb") as f:
        f.write(plaintext)
|
|
|
|
| |
| |
| |
def _prompt_new_password() -> str:
    """Interactively read a non-empty encryption password, asking twice until both entries match."""
    while True:
        first = getpass.getpass("Enter password: ")
        if not first:
            print("Password must not be empty.", file=sys.stderr)
            continue
        if getpass.getpass("Confirm password: ") == first:
            return first
        print("Passwords do not match; please try again.", file=sys.stderr)


def _print_rows(rows: List[Dict], fmt: str, n: int) -> None:
    """Print generated rows to stdout as "json", "csv" or (any other value) plain text."""
    if fmt == "json":
        print(json.dumps(rows, ensure_ascii=False, indent=2))
    elif fmt == "csv":
        fieldnames = ["context_id", "context"] + [f"q{i}" for i in range(1, n + 1)]
        writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    else:
        for r in rows:
            print(f"[{r['context_id']}]")
            print(r["context"].strip())
            for i in range(1, n + 1):
                print(f"{i}. {r.get(f'q{i}', '')}")
            print()


def main():
    """Command-line entry point: generate questions (the default) or decrypt a file.

    Generation mode:
    - Loads contexts from ``--context`` or ``--context-file``.
    - Generates ``--n`` questions per context.
    - Prints the result to stdout, or writes it to ``--out``.
    - With ``--encrypt`` (which requires ``--out``), also writes an encrypted
      ``<out>.enc`` copy.

    ``--decrypt`` mode only decrypts ``--in`` into ``--out-decrypted``.

    Invalid arguments exit with status 2 via argparse. Unreadable input and
    decryption failures exit with a message rather than a traceback.
    """
    parser = argparse.ArgumentParser(description="Generate deep open-ended questions with optional encryption/decryption.")
    # argparse's method is add_mutually_exclusive_group. The group is optional:
    # plain `--context ...` invocations default to generation, and an explicit
    # --generate is still accepted.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("--generate", action="store_true", help="Generate questions from context(s) (default mode).")
    mode.add_argument("--decrypt", action="store_true", help="Decrypt an encrypted file (no generation).")

    # Generation options
    parser.add_argument("--context", type=str, help="Inline context text.")
    parser.add_argument("--context-file", type=str, help="Path to TXT/MD (split by ---) or CSV with 'context' column.")
    parser.add_argument("--n", type=int, default=3, help="Number of questions to generate per context.")
    parser.add_argument("--model", type=str, default="gpt2-large", help="HuggingFace model name.")
    parser.add_argument("--max-new-tokens", type=int, default=220, help="Max new tokens for generation.")
    parser.add_argument("--temperature", type=float, default=0.95, help="Sampling temperature.")
    parser.add_argument("--top-p", type=float, default=0.95, help="Top-p nucleus sampling.")
    parser.add_argument("--top-k", type=int, default=50, help="Top-k sampling.")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility.")
    parser.add_argument("--attempts", type=int, default=3, help="Max attempts to reach exactly n questions.")

    # Output / encryption options
    parser.add_argument("--out", type=str, default=None, help="Output file path. If omitted, prints to stdout.")
    parser.add_argument("--format", type=str, choices=["json", "csv", "txt"], default="json", help="Output format when generating.")
    parser.add_argument("--encrypt", action="store_true", help="Encrypt the output file after generation.")
    parser.add_argument("--password", type=str, default=None, help="Password for encryption/decryption. If omitted, prompts securely.")

    # Decryption options
    parser.add_argument("--in", dest="in_path", type=str, help="Input file for decryption (encrypted).")
    parser.add_argument("--out-decrypted", type=str, help="Output file for decrypted plaintext.")

    args = parser.parse_args()

    if args.decrypt:
        if not args.in_path or not args.out_decrypted:
            parser.error("--decrypt requires --in and --out-decrypted.")
        password = args.password or getpass.getpass("Enter password: ")
        try:
            decrypt_file(args.in_path, args.out_decrypted, password)
        except (OSError, ValueError, RuntimeError) as err:
            parser.exit(1, f"Decryption failed: {err}\n")
        print(f"Decrypted to: {args.out_decrypted}")
        return

    # Validate generation options before loading a (potentially large) model.
    if args.n < 1:
        parser.error("--n must be at least 1.")
    if args.attempts < 1:
        parser.error("--attempts must be at least 1.")
    if args.temperature <= 0:
        parser.error("--temperature must be greater than 0 (sampling is always enabled).")
    if args.encrypt and not args.out:
        parser.error("--encrypt requires --out (stdout output cannot be encrypted).")

    # Resolve the encryption password up front so the user is not prompted
    # (or surprised by a missing dependency) only after a long generation run.
    password = None
    if args.encrypt:
        try:
            require_crypto()
        except RuntimeError as err:
            parser.error(str(err))
        password = args.password or _prompt_new_password()

    try:
        contexts = load_contexts(args.context, args.context_file)
    except (OSError, ValueError) as err:
        parser.error(str(err))

    device = select_device()
    model, tokenizer = load_model_and_tokenizer(args.model, device)

    rows: List[Dict] = []
    for ctx_id, ctx in contexts:
        qs = generate_questions(
            model=model,
            tokenizer=tokenizer,
            device=device,
            context=ctx,
            n=args.n,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            seed=args.seed,
            attempts=args.attempts,
        )
        row = {"context_id": ctx_id, "context": ctx}
        for i, q in enumerate(qs, start=1):
            row[f"q{i}"] = q
        rows.append(row)

    if not args.out:
        _print_rows(rows, args.format, args.n)
        return

    out_path = args.out
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    if args.format == "json":
        write_json(out_path, rows)
    elif args.format == "csv":
        write_csv(out_path, rows, args.n)
    else:
        write_txt(out_path, rows, args.n)
    print(f"Saved: {out_path}")

    if args.encrypt:
        # NOTE: the plaintext file at out_path is kept; the .enc file is an additional copy.
        enc_path = out_path + ".enc"
        encrypt_file(out_path, enc_path, password)
        print(f"Encrypted copy: {enc_path}")


if __name__ == "__main__":
    main()