| | |
| | """ |
| | Interactive chat script for any model with automatic chat template support. |
| | Usage: python chat_with_models.py <model_folder_name> [--assistant] |
| | """ |
| |
|
| | import os |
| | import sys |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer, StoppingCriteria, StoppingCriteriaList |
| | import warnings |
| | import argparse |
| |
|
| | |
# Install a process-wide "ignore" filter for every warning category, so
# warnings raised by any code running after this line are not shown.
warnings.filterwarnings("ignore")
| |
|
class StopSequenceCriteria(StoppingCriteria):
    """Stop generation as soon as a stop string shows up in the new text.

    On every call the tokens produced after the prompt are decoded and
    searched for each entry of ``stop_sequences``.  The first entry found is
    stored in ``triggered_stop_sequence`` so the caller can trim it from the
    final text afterwards.

    Only the first row of the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, tokenizer, stop_sequences, prompt_length):
        """Store the search configuration.

        Args:
            tokenizer: Object providing ``decode(ids, skip_special_tokens=True)``.
            stop_sequences: Iterable of strings that should end generation.
            prompt_length: Number of leading tokens in ``input_ids`` that
                belong to the prompt; they are never searched.
        """
        self.tokenizer = tokenizer
        self.stop_sequences = stop_sequences
        self.prompt_length = prompt_length
        # Set by __call__ to the stop string that ended generation, if any.
        self.triggered_stop_sequence = None

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the text generated so far contains a stop string.

        Args:
            input_ids: 2-D tensor of token ids (prompt followed by the
                tokens generated so far).
            scores: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.
        """
        # Nothing has been generated yet, so there is nothing to search.
        if input_ids.shape[1] <= self.prompt_length:
            return False

        new_tokens = input_ids[0][self.prompt_length:]
        new_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        for stop_seq in self.stop_sequences:
            if stop_seq in new_text:
                # Remember which sequence fired; the attribute was previously
                # never assigned, so callers could not tell what to trim.
                self.triggered_stop_sequence = stop_seq
                return True
        return False
| |
|
class ModelChatter:
    """Terminal chat session around a causal language model.

    The tokenizer and model are read from ``<model_folder>/hf``.  The running
    dialogue is kept in ``conversation_history`` as a list of
    ``{"role": ..., "content": ...}`` dicts and is re-rendered through the
    tokenizer's chat template for every turn.
    """

    # Strings that mark the beginning of a new turn; generation stops when
    # one of them appears and it is trimmed from the stored reply.
    STOP_SEQUENCES = ("Question:", "Instruction:", "Answer:", "User:")

    # Jinja chat template rendering turns as "User: ..." / "Assistant: ...".
    _ASSISTANT_TEMPLATE = (
        "{% if not add_generation_prompt is defined %}"
        "{% set add_generation_prompt = false %}{% endif %}"
        "{% for msg in messages %}"
        "{% if msg.role=='user' %}"
        "{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}"
        "User: {{ msg.content }}"
        "{% elif msg.role=='assistant' %}"
        "{{ '\\n\\nAssistant: ' }}{{ msg.content }}"
        "{% endif %}{% endfor %}"
        "{% if add_generation_prompt %}{{ '\\n\\nAssistant: ' }}{% endif %}"
    )

    # Jinja chat template rendering turns as "Instruction: ..." / "Answer:...".
    _INSTRUCTION_TEMPLATE = (
        "{% if not add_generation_prompt is defined %}"
        "{% set add_generation_prompt = false %}{% endif %}"
        "{% for msg in messages %}"
        "{% if msg.role=='user' %}"
        "{% if loop.index > 1 %}{{ '\\n\\n' }}{% endif %}"
        "Instruction: {{ msg.content }}"
        "{% elif msg.role=='assistant' %}"
        "{{ '\\n\\nAnswer:' }}{{ msg.content }}"
        "{% endif %}{% endfor %}"
        "{% if add_generation_prompt %}{{ '\\n\\nAnswer:' }}{% endif %}"
    )

    def __init__(self, model_folder, force_assistant_template=False):
        """Remember where the model lives; nothing is loaded yet.

        Args:
            model_folder: Directory that contains an ``hf`` subdirectory
                with the tokenizer and model files.
            force_assistant_template: When True, ``load_model`` replaces
                the tokenizer's chat template with the User:/Assistant: one.
        """
        self.model_folder = model_folder
        self.hf_path = os.path.join(model_folder, 'hf')
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.device = None
        self.conversation_history = []
        self.force_assistant_template = force_assistant_template

    @staticmethod
    def _select_device():
        """Return the device string to run on: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return "cuda:0"
        # torch.backends.mps does not exist on older torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def _configure_chat_template(self):
        """Install a chat template on the tokenizer when forced or missing."""
        if self.force_assistant_template:
            print("π Forcing User: Assistant: chat template...")
            self.tokenizer.chat_template = self._ASSISTANT_TEMPLATE
            print("✅ User: Assistant: chat template forced")
        elif getattr(self.tokenizer, 'chat_template', None) is None:
            print("π No chat template found, assigning custom template...")
            self.tokenizer.chat_template = self._INSTRUCTION_TEMPLATE
            print("✅ Custom chat template assigned")
        else:
            print("✅ Model has existing chat template")

    @staticmethod
    def _strip_stop_sequence(text, stop_sequences):
        """Cut ``text`` at the earliest stop sequence it contains.

        Returns:
            ``(trimmed_text, matched_sequence)``; ``matched_sequence`` is
            ``None`` and the text is returned unchanged when nothing matched.
        """
        cut_at = len(text)
        matched = None
        for candidate in stop_sequences:
            if not candidate:
                # An empty string is found at index 0 and would erase the reply.
                continue
            position = text.find(candidate)
            if position != -1 and position < cut_at:
                cut_at, matched = position, candidate
        if matched is None:
            return text, None
        return text[:cut_at].rstrip(), matched

    def load_model(self):
        """Load the tokenizer and model and move the model to a device.

        Returns:
            True on success, False when any step raised (the error is printed).
        """
        try:
            print(f"π Loading {self.model_folder}...")

            self.tokenizer = AutoTokenizer.from_pretrained(self.hf_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self._configure_chat_template()

            device = self._select_device()
            # Half precision is kept for accelerators; on CPU float32 is used
            # because not every CPU kernel is implemented for float16.
            dtype = torch.float32 if device == "cpu" else torch.float16

            # NOTE(security): trust_remote_code=True executes Python code
            # shipped inside the checkpoint folder. Only load trusted models.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.hf_path,
                device_map=None,
                torch_dtype=dtype,
                trust_remote_code=True,
            )
            self.model.to(device)
            self.device = device

            print(f" π± Using device: {device}")

            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device_map="auto",
                torch_dtype=dtype,
            )

            print(f" ✅ {self.model_folder} loaded successfully")
            return True

        except Exception as e:
            print(f" β Failed to load {self.model_folder}: {str(e)}")
            return False

    def format_chat_prompt(self, user_message):
        """Append the user turn to the history and render the full prompt.

        Returns:
            The prompt string produced by the tokenizer's chat template, or
            ``None`` when rendering failed.  On failure the user turn is
            removed again so the history is left as it was.
        """
        self.conversation_history.append({"role": "user", "content": user_message})
        try:
            return self.tokenizer.apply_chat_template(
                self.conversation_history,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception as e:
            # Roll back, otherwise the next prompt would hold two user turns.
            self.conversation_history.pop()
            print(f"β Error formatting chat prompt: {str(e)}")
            return None

    def generate_response(self, user_message, max_length=512):
        """Generate and stream a reply to ``user_message``.

        The reply is printed token by token while it is produced and then
        stored in ``conversation_history`` with any stop sequence trimmed.

        Args:
            user_message: Text of the new user turn.
            max_length: Maximum number of new tokens to generate.

        Returns:
            An empty string on success (the text was already streamed), or
            an error message.  On error the history is restored to its state
            before the call.
        """
        turns_before = len(self.conversation_history)
        try:
            formatted_prompt = self.format_chat_prompt(user_message)
            if formatted_prompt is None:
                return "β Failed to format chat prompt"

            print("π€ Response: ", end="", flush=True)

            # If the template already emitted BOS, letting the tokenizer add
            # its special tokens again would duplicate it.
            bos_token = getattr(self.tokenizer, "bos_token", None)
            prompt_has_bos = bool(bos_token) and formatted_prompt.startswith(bos_token)
            encoded = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                add_special_tokens=not prompt_has_bos,
            )

            # Follow the model's actual device instead of re-probing hardware.
            target_device = self.model.device
            inputs = {name: tensor.to(target_device) for name, tensor in encoded.items()}
            prompt_length = inputs['input_ids'].shape[1]

            streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
            stopping_criteria = StopSequenceCriteria(
                self.tokenizer, self.STOP_SEQUENCES, prompt_length
            )

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    repetition_penalty=1.1,
                    pad_token_id=self.tokenizer.eos_token_id,
                    streamer=streamer,
                    eos_token_id=self.tokenizer.eos_token_id,
                    stopping_criteria=StoppingCriteriaList([stopping_criteria]),
                )

            generated_text = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )

            # Search the decoded text directly so trimming does not depend on
            # state kept by the stopping criterion.
            generated_text, stop_seq = self._strip_stop_sequence(
                generated_text, self.STOP_SEQUENCES
            )
            if stop_seq is not None:
                print(f"\nπ Stripped stop sequence '{stop_seq}' from response")

            self.conversation_history.append({"role": "assistant", "content": generated_text})

            # The reply was streamed to stdout already; nothing left to print.
            return ""

        except Exception as e:
            # Drop the unanswered user turn so later prompts stay well-formed.
            del self.conversation_history[turns_before:]
            return f"β Generation failed: {str(e)}"

    def reset_conversation(self):
        """Forget all previous turns."""
        self.conversation_history = []
        print("π Conversation history cleared!")

    def show_conversation_history(self):
        """Print every stored turn, separated by dashed lines."""
        if not self.conversation_history:
            print("π No conversation history yet.")
            return

        print("\nπ Conversation History:")
        print("=" * 50)
        last_index = len(self.conversation_history) - 1
        for index, message in enumerate(self.conversation_history):
            role = message["role"].capitalize()
            content = message["content"]
            print(f"{role}: {content}")
            if index < last_index:
                print("-" * 30)
        print("=" * 50)

    def _print_commands(self):
        """Print the banner and the command list shared by startup and 'help'."""
        print(f"\n㪠Chatting with {self.model_folder}")
        print("Commands:")
        print(" - Type your message to chat")
        print(" - Type 'quit' or 'exit' to end")
        print(" - Type 'help' for this message")
        print(" - Type 'reset' to clear conversation history")
        print(" - Type 'history' to show conversation history")
        print(" - Type 'clear' to clear screen")

    def interactive_chat(self):
        """Run the read-eval-print chat loop until the user quits.

        The loop ends on 'quit' / 'exit' / 'q', on Ctrl-C, or when stdin is
        closed.  Any other exception is reported and the loop continues.
        """
        self._print_commands()
        print("\nπ‘ Start chatting! (Works with any model)")

        while True:
            try:
                user_input = input("\nπ€ You: ").strip()
                if not user_input:
                    continue

                command = user_input.lower()
                if command in ('quit', 'exit', 'q'):
                    print("π Goodbye!")
                    break
                elif command == 'help':
                    self._print_commands()
                    print(" - Works with any model (auto-assigns chat template)")
                elif command == 'reset':
                    self.reset_conversation()
                elif command == 'history':
                    self.show_conversation_history()
                elif command == 'clear':
                    os.system('clear' if os.name == 'posix' else 'cls')
                else:
                    print(f"\nπ€ {self.model_folder}:")
                    response = self.generate_response(user_input)
                    # A non-empty return value is an error message; surface it
                    # on its own line instead of silently dropping it.
                    if response:
                        print(f"\n{response}")

            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin was closed; retrying input() would spin forever.
                print("\n\nπ Goodbye!")
                break
            except Exception as e:
                print(f"β Error: {str(e)}")
| |
|
def main(argv=None):
    """Parse the command line, load the model and start the chat loop.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) keeps the normal command-line behaviour.

    Exits with status 1 when the model folder or its ``hf`` subdirectory is
    missing, or when the model cannot be loaded.
    """
    parser = argparse.ArgumentParser(description="Interactive chat script for any model")
    parser.add_argument("model_folder", help="Name of the model folder")
    parser.add_argument(
        "--assistant",
        action="store_true",
        help="Force User: Assistant: chat template even if model has its own",
    )
    args = parser.parse_args(argv)

    model_folder = args.model_folder
    force_assistant_template = args.assistant

    # isdir rather than exists: a regular file with that name is not a model folder.
    if not os.path.isdir(model_folder):
        print(f"β Model folder '{model_folder}' not found!")
        sys.exit(1)

    hf_path = os.path.join(model_folder, 'hf')
    if not os.path.isdir(hf_path):
        print(f"β No 'hf' subdirectory found in '{model_folder}'!")
        sys.exit(1)

    print("π Model Chat Script")
    print("=" * 50)
    if force_assistant_template:
        print("π§ Forcing User: Assistant: chat template")
        print("=" * 50)

    chatter = ModelChatter(model_folder, force_assistant_template)

    if not chatter.load_model():
        print("β Failed to load model. Exiting.")
        sys.exit(1)

    print(f"✅ Model '{model_folder}' loaded successfully")

    chatter.interactive_chat()


if __name__ == "__main__":
    main()