| |
|
|
| import argparse |
| import json |
| import os |
| import random |
| import sys |
|
|
|
|
def set_seed(seed: int):
    """Seed Python's, NumPy's, and PyTorch's RNGs for reproducible runs.

    Libraries that are not installed are skipped silently, so this is safe
    to call in environments without numpy or torch.
    """
    random.seed(seed)

    try:
        import numpy as np
    except ImportError:
        pass
    else:
        np.random.seed(seed)

    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)
    # Also seed every visible CUDA device so GPU sampling is reproducible.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
def main():
    """Smoke-test a converted RNNLM model by generating from sample prompts.

    Loads the model and tokenizer from ``--model_path``, reads a JSON list of
    prompt strings, and runs six generation configurations (varying
    max_new_tokens, max_new_sents, sampling, and temperature) through the
    custom ``RNNLMTextGenerationPipeline``, printing each result.

    Exits with status 1 if the model directory or prompts file is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path", "-m", default=".", help="Path to converted model")
    # default=None so the script-relative fallback below actually triggers;
    # a hard-coded "test_prompts.json" default made that branch dead code
    # and resolved relative to the CWD instead of the script directory.
    parser.add_argument(
        "--prompts", "-p", default=None,
        help="Path to JSON file with list of prompt strings "
             "(default: test_prompts.json next to this script)")
    parser.add_argument(
        "--seed", "-s", type=int, default=0,
        help="Random seed for reproducible generation (default: 0)")
    parser.add_argument(
        "--max_new_tokens", type=int, default=None,
        help="Max tokens to generate (default: 50)")
    parser.add_argument(
        "--max_new_sents", type=int, default=None,
        help="Max sentences in decoded output (default: pipeline default)")
    args = parser.parse_args()

    if args.seed is not None:
        set_seed(args.seed)
        print(f"Random seed set to {args.seed} for reproducibility")

    if not os.path.isdir(args.model_path):
        print(f"Error: Model path {args.model_path} does not exist.")
        sys.exit(1)

    # Fall back to the test_prompts.json shipped alongside this script.
    prompts_path = args.prompts
    if prompts_path is None:
        prompts_path = os.path.join(os.path.dirname(
            os.path.abspath(__file__)), "test_prompts.json")
    if not os.path.isfile(prompts_path):
        print(f"Error: Prompts file {prompts_path} does not exist.")
        sys.exit(1)

    # Heavy imports are deferred until after argument validation so that
    # --help and bad-path errors stay fast.
    print("Loading model and tokenizer...")
    from transformers import AutoModelForCausalLM

    model_path = os.path.abspath(args.model_path)
    # Register the custom "rnnlm" architecture so the Auto* factories can
    # resolve the converted checkpoint's config/model classes.
    from rnnlm_model import (
        RNNLMConfig,
        RNNLMForCausalLM,
        RNNLMTokenizer,
        RNNLMTextGenerationPipeline,
    )
    from transformers import AutoConfig
    AutoConfig.register("rnnlm", RNNLMConfig)
    AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)

    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True)
    tokenizer = RNNLMTokenizer.from_pretrained(model_path)

    print("Creating RNNLMTextGenerationPipeline (with entity adaptation)...")
    pipe = RNNLMTextGenerationPipeline(
        model=model,
        tokenizer=tokenizer,
    )

    with open(prompts_path) as f:
        test_prompts = json.load(f)

    # Shared generation kwargs; each test below overrides one knob.
    base_kwargs = dict(
        max_new_tokens=args.max_new_tokens if args.max_new_tokens is not None else 50,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
    )
    if args.max_new_sents is not None:
        base_kwargs["max_new_sents"] = args.max_new_sents

    def run_tests(kwargs):
        # Generate once per prompt with the given pipeline kwargs and print
        # the prompt/output pair.
        for i, prompt in enumerate(test_prompts):
            print(f"\n [{i + 1}/{len(test_prompts)}]")
            print(f" PROMPT: ``{prompt}``")
            output = pipe(prompt, **kwargs)
            print(f" GENERATED: ``{output[0]['generated_text']}``")

    print("\n--- Test 1: Basic generation (default params) ---")
    run_tests(base_kwargs)

    print("\n--- Test 2: max_new_tokens=20 ---")
    run_tests({**base_kwargs, "max_new_tokens": 20})

    print("\n--- Test 3: max_new_sents=2 ---")
    run_tests({**base_kwargs, "max_new_sents": 2})

    print("\n--- Test 4: max_new_sents=1 ---")
    run_tests({**base_kwargs, "max_new_sents": 1})

    print("\n--- Test 5: do_sample=False ---")
    run_tests({**base_kwargs, "do_sample": False})

    print("\n--- Test 6: temperature=0.3 ---")
    run_tests({**base_kwargs, "temperature": 0.3})
|
|
|
|
# Run the generation smoke test only when executed as a script,
# not when imported for its helpers.
if __name__ == "__main__":
    main()
|
|