"""
Helion-2.5-Rnd Inference Pipeline

High-level pipeline for easy model usage.
"""

import logging
import time
from typing import Dict, List, Optional, Union

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class StopOnTokens(StoppingCriteria):
    """Stop generation when specific tokens are generated."""

    def __init__(self, stop_token_ids: List[int]):
        self.stop_token_ids = stop_token_ids

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs
    ) -> bool:
        # Only the most recently generated token can be a new stop token.
        return input_ids[0, -1].item() in self.stop_token_ids
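
# Note: recent transformers versions also accept a list of ids via the
# eos_token_id argument to generate(), which covers the simple case handled
# by StopOnTokens; the explicit criteria class is kept for clarity and for
# older library versions.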


class HelionPipeline:
    """High-level inference pipeline for the Helion model."""

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        torch_dtype=torch.bfloat16,
        load_in_8bit: bool = False,
        trust_remote_code: bool = True
    ):
        """
        Initialize the Helion pipeline.

        Args:
            model_path: Local path to the model or a Hugging Face model ID
            device: Device to load the model on
            torch_dtype: Torch data type for the model weights
            load_in_8bit: Whether to load the model in 8-bit precision
            trust_remote_code: Whether to trust remote code in the model repo
        """
        logger.info(f"Loading Helion model from {model_path}")

        self.device = device
        self.model_path = model_path

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=trust_remote_code
        )

        # Causal LM tokenizers often ship without a pad token; fall back to
        # EOS so padded batches and generate() calls do not fail.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            device_map="auto" if device == "cuda" else None,
            load_in_8bit=load_in_8bit,
            trust_remote_code=trust_remote_code
        )

        if device != "cuda" and not load_in_8bit:
            self.model = self.model.to(device)

        self.model.eval()

        # Keep only ids the tokenizer could actually resolve:
        # convert_tokens_to_ids may return None for tokens that are missing
        # from the vocabulary.
        self.stop_token_ids = [
            token_id
            for token_id in (
                self.tokenizer.eos_token_id,
                self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
            )
            if token_id is not None
        ]

        logger.info("Model loaded successfully")

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
        do_sample: bool = True,
        num_return_sequences: int = 1,
        **kwargs
    ) -> Union[str, List[str]]:
        """
        Generate text from a prompt.

        Args:
            prompt: Input prompt
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            repetition_penalty: Repetition penalty
            do_sample: Whether to sample instead of decoding greedily
            num_return_sequences: Number of sequences to return

            **kwargs: Additional generation parameters

        Returns:
            Generated text, or a list of texts when num_return_sequences > 1
        """
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.max_position_embeddings
        ).to(self.device)

        stopping_criteria = StoppingCriteriaList([
            StopOnTokens(self.stop_token_ids)
        ])

        with torch.no_grad():
            start_time = time.time()

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                num_return_sequences=num_return_sequences,
                stopping_criteria=stopping_criteria,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )

            generation_time = time.time() - start_time

        # Strip the prompt tokens so only the completion is returned.
        generated_texts = []
        for output in outputs:
            text = self.tokenizer.decode(
                output[inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )
            generated_texts.append(text.strip())

        logger.info(f"Generated {len(generated_texts)} sequences in {generation_time:.2f}s")

        if num_return_sequences == 1:
            return generated_texts[0]
        return generated_texts

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Chat completion.

        Args:
            messages: List of message dictionaries with 'role' and 'content'
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional generation parameters

        Returns:
            Assistant response
        """
        prompt = self._format_chat_prompt(messages)

        response = self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **kwargs
        )

        return response

    def _format_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Format messages into a ChatML-style chat prompt."""
        formatted = ""

        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"

        formatted += "<|im_start|>assistant\n"
        return formatted
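
    # Note: when the tokenizer ships its own chat template, the prompt above
    # can instead be built with the stock transformers helper; a minimal
    # sketch, assuming the template matches the ChatML markers used here:
    #
    #     prompt = self.tokenizer.apply_chat_template(
    #         messages, tokenize=False, add_generation_prompt=True
    #     )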

    def batch_generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        batch_size: int = 4,
        **kwargs
    ) -> List[str]:
        """
        Generate completions for multiple prompts in batches.

        Args:
            prompts: List of input prompts
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            batch_size: Batch size for processing
            **kwargs: Additional generation parameters

        Returns:
            List of generated texts, in the same order as the prompts
        """
        all_outputs = []

        # Sampling parameters such as temperature only take effect when
        # sampling is enabled; default to sampling, matching generate().
        kwargs.setdefault("do_sample", True)

        # Decoder-only models need left padding for batched generation so
        # that every prompt ends at the same position; restore the original
        # setting afterwards.
        original_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"

        try:
            for i in range(0, len(prompts), batch_size):
                batch = prompts[i:i + batch_size]

                inputs = self.tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=self.model.config.max_position_embeddings
                ).to(self.device)

                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        pad_token_id=self.tokenizer.pad_token_id,
                        **kwargs
                    )

                # With left padding, every sequence in the batch shares the
                # same (padded) prompt length.
                prompt_length = inputs['input_ids'].shape[1]
                for output in outputs:
                    text = self.tokenizer.decode(
                        output[prompt_length:],
                        skip_special_tokens=True
                    )
                    all_outputs.append(text.strip())
        finally:
            self.tokenizer.padding_side = original_padding_side

        logger.info(f"Generated {len(all_outputs)} outputs")
        return all_outputs

    def stream_generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ):
        """
        Stream generation token by token.

        Args:
            prompt: Input prompt
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional generation parameters

        Yields:
            Generated tokens as decoded strings
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        stopping_criteria = StoppingCriteriaList([
            StopOnTokens(self.stop_token_ids)
        ])

        kwargs.setdefault("do_sample", True)

        with torch.no_grad():
            for _ in range(max_new_tokens):
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=1,
                    temperature=temperature,
                    stopping_criteria=stopping_criteria,
                    pad_token_id=self.tokenizer.pad_token_id,
                    **kwargs
                )

                new_token_id = outputs[0, -1].item()

                if new_token_id in self.stop_token_ids:
                    break

                new_token = self.tokenizer.decode([new_token_id])
                yield new_token

                # Feed the full sequence back in as the next prompt.
                inputs = {
                    'input_ids': outputs,
                    'attention_mask': torch.ones_like(outputs)
                }
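
    # Note: the loop above re-runs generate() from scratch for every token,
    # recomputing the full prefix each step. For efficient streaming,
    # transformers provides TextIteratorStreamer; a minimal sketch, assuming
    # generation runs in a background thread:
    #
    #     from threading import Thread
    #     from transformers import TextIteratorStreamer
    #
    #     streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
    #     Thread(
    #         target=self.model.generate,
    #         kwargs={**inputs, "max_new_tokens": max_new_tokens,
    #                 "streamer": streamer},
    #     ).start()
    #     for token_text in streamer:
    #         yield token_text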

    def get_embeddings(self, text: str) -> torch.Tensor:
        """
        Get an embedding for a piece of text.

        Args:
            text: Input text

        Returns:
            Embedding tensor (mean of the last hidden layer over tokens)
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            # Mean-pool the final hidden layer across the sequence dimension.
            embeddings = outputs.hidden_states[-1].mean(dim=1)

        return embeddings
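
    # Note: for batched, padded inputs the plain mean above would also
    # average over pad positions; a mask-aware variant would look like this
    # (sketch, not part of the original API):
    #
    #     mask = inputs['attention_mask'].unsqueeze(-1)
    #     hidden = outputs.hidden_states[-1]
    #     embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)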

    def score_text(self, text: str) -> float:
        """
        Calculate a perplexity score for text.

        Args:
            text: Input text

        Returns:
            Perplexity (exp of the mean per-token loss; lower is better)
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, labels=inputs['input_ids'])
            loss = outputs.loss
            perplexity = torch.exp(loss).item()

        return perplexity

    def cleanup(self):
        """Release model and tokenizer resources."""
        del self.model
        del self.tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Pipeline cleaned up")


class ConversationPipeline(HelionPipeline):
    """Pipeline with conversation history management."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conversation_history: List[Dict[str, str]] = []
        self.system_prompt: Optional[str] = None

    def set_system_prompt(self, prompt: str):
        """Set the system prompt for the conversation."""
        self.system_prompt = prompt

    def add_message(self, role: str, content: str):
        """Add a message to the conversation history."""
        self.conversation_history.append({
            'role': role,
            'content': content
        })

    def generate_response(
        self,
        user_message: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate a response in the conversation context.

        Args:
            user_message: User's message
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional generation parameters

        Returns:
            Assistant response
        """
        messages = []

        if self.system_prompt:
            messages.append({
                'role': 'system',
                'content': self.system_prompt
            })

        messages.extend(self.conversation_history)
        messages.append({
            'role': 'user',
            'content': user_message
        })

        response = self.chat(
            messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **kwargs
        )

        # Record both sides of the turn so later calls see the full context.
        self.add_message('user', user_message)
        self.add_message('assistant', response)

        return response

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation_history.clear()
        logger.info("Conversation history reset")


def main():
    """Example usage."""
    pipeline = HelionPipeline(
        model_path="DeepXR/Helion-2.5-Rnd",
        device="cuda"
    )

    # Single-prompt generation
    prompt = "Explain quantum computing in simple terms:"
    response = pipeline.generate(prompt, max_new_tokens=256)
    print(f"Response: {response}\n")

    # Chat completion
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"}
    ]
    response = pipeline.chat(messages)
    print(f"Chat response: {response}\n")

    # Batched generation
    prompts = [
        "Write a haiku about AI:",
        "Explain machine learning:",
        "What is Python?"
    ]
    responses = pipeline.batch_generate(prompts, batch_size=2)
    for i, resp in enumerate(responses):
        print(f"Batch {i+1}: {resp}\n")

    # Multi-turn conversation with history
    conv_pipeline = ConversationPipeline(
        model_path="DeepXR/Helion-2.5-Rnd",
        device="cuda"
    )
    conv_pipeline.set_system_prompt("You are a helpful coding assistant.")

    response1 = conv_pipeline.generate_response("How do I sort a list in Python?")
    print(f"Conv 1: {response1}\n")

    response2 = conv_pipeline.generate_response("Can you show me an example?")
    print(f"Conv 2: {response2}\n")


if __name__ == "__main__":
    main()