| """ |
| Chat interface for CosmicFish model downloaded from Hugging Face Hub. |
| Uses safetensors format only for secure model loading. |
| """ |
|
|
| import os |
| import sys |
| import time |
| import argparse |
| import torch |
| import numpy as np |
| from termcolor import colored |
| import logging |
| import readline |
| import re |
| import textwrap |
| import random |
| from collections import defaultdict |
| import json |
|
|
| |
| try: |
| from transformers import GPT2Tokenizer |
| from huggingface_hub import hf_hub_download, snapshot_download |
| HF_AVAILABLE = True |
| except ImportError: |
| HF_AVAILABLE = False |
| print("Required libraries not available.") |
| print("Install with: pip install transformers huggingface-hub") |
| sys.exit(1) |
|
|
| |
| try: |
| from safetensors.torch import load_file |
| SAFETENSORS_AVAILABLE = True |
| except ImportError: |
| SAFETENSORS_AVAILABLE = False |
| print("Safetensors not available. Install with: pip install safetensors") |
| sys.exit(1) |
|
|
| |
# Log INFO and above to stdout with a timestamped format.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Hugging Face Hub repository used when --model_repo is not given.
DEFAULT_MODEL_REPO = "MistyozAI/CosmicFish-90M"

# Instruction text placed before the conversation turns in every prompt.
DEFAULT_PROMPT_TEMPLATE = (
    "Below is a conversation between a helpful AI assistant and a human. "
    "The assistant is knowledgeable, friendly, and provides detailed and accurate responses.\n\n"
)
|
|
|
|
class CosmicConfig:
    """Architecture hyper-parameters for a CosmicFish model.

    The defaults describe a 10-layer, 16-head, 640-wide model with a
    512-token context. ``n_query_groups`` only takes effect when ``use_gqa``
    is True; otherwise it is overridden with ``n_head``. ``dropout`` is
    stored but not read by the inference model in this file.

    Raises:
        AssertionError: if ``n_head`` is not divisible by the effective
            ``n_query_groups``.
    """

    def __init__(self, vocab_size=50257, block_size=512, n_layer=10, n_head=16,
                 n_embd=640, bias=True, dropout=0.0, n_query_groups=4, eps=1e-6,
                 use_rotary=True, use_swiglu=True, use_qk_norm=False, use_gqa=True):
        # Sizes.
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        # Layer options / numerics.
        self.bias = bias
        self.dropout = dropout
        self.eps = eps
        # Architectural switches.
        self.use_rotary = use_rotary
        self.use_swiglu = use_swiglu
        self.use_qk_norm = use_qk_norm
        self.use_gqa = use_gqa

        if use_gqa:
            groups = n_query_groups
        else:
            groups = n_head
        self.n_query_groups = groups

        assert n_head % groups == 0, "n_head must be divisible by n_query_groups"
|
|
|
|
class RMSNorm(torch.nn.Module):
    """RMS normalization over the last dimension with a learned per-feature scale.

    Computes ``weight * x / sqrt(mean(x**2, dim=-1) + eps)``. There is no
    mean subtraction and no bias term.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        denom = (x * x).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.weight * (x / denom)
|
|
|
|
def precompute_freqs_cis(dim, end, theta=10000.0):
    """Build the rotary-embedding table of unit complex numbers.

    Args:
        dim: per-head feature size; one frequency is produced per pair of features.
        end: number of positions (rows) to generate.
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[t, i]`` is
        ``exp(1j * t * theta ** (-2i / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
|
|
|
|
def apply_rotary_emb(xq, xk, freqs_cis):
    """Rotate query/key feature pairs by position-dependent angles (RoPE).

    Consecutive feature pairs of the last dimension are treated as complex
    numbers and multiplied by the first ``seq_len`` rows of ``freqs_cis``.
    The sequence length is read from dimension 2 of ``xq``; callers in this
    file pass ``(batch, heads, seq, head_dim)`` tensors.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the original shapes and dtypes.

    Raises:
        ValueError: if ``freqs_cis`` has fewer rows than the sequence length.
    """
    def _as_complex(t):
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)

    seq_len = q_complex.size(2)
    available = freqs_cis.size(0)
    if available < seq_len:
        raise ValueError(f"freqs_cis has only {available} values but sequence length is {seq_len}")

    rotation = freqs_cis[:seq_len].unsqueeze(0)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
|
|
|
|
class GroupedQueryAttention(torch.nn.Module):
    """Causal self-attention with grouped key/value heads (GQA).

    A single projection ``c_attn`` produces ``n_head`` query heads plus
    ``kv_heads`` key heads and ``kv_heads`` value heads. Key/value heads are
    repeated to match the query heads before attention. Optional rotary
    embeddings and per-head RMS q/k normalization are applied before the
    attention scores.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        head_dim = config.n_embd // config.n_head
        self.head_dim = head_dim
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.n_query_groups = config.n_query_groups

        # With GQA, n_query_groups is the number of query heads sharing one
        # key/value head, hence kv_heads = n_head // n_query_groups. The fused
        # projection size depends on this, so it must match the loaded weights.
        self.kv_heads = config.n_head // config.n_query_groups if config.use_gqa else config.n_head
        qkv_proj_size = (config.n_head + 2 * self.kv_heads) * head_dim

        self.c_attn = torch.nn.Linear(config.n_embd, qkv_proj_size, bias=config.bias)
        self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Use PyTorch's fused attention when available; otherwise fall back to
        # an explicit lower-triangular mask.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # Non-persistent: the mask is fully determined by block_size, so it
            # must not appear in state_dict. A persistent buffer would make
            # strict load_state_dict fail on checkpoints saved without it.
            causal_mask = torch.tril(torch.ones(config.block_size, config.block_size))
            self.register_buffer(
                "bias",
                causal_mask.view(1, 1, config.block_size, config.block_size),
                persistent=False,
            )

        self.qk_norm = getattr(config, 'use_qk_norm', False)
        if self.qk_norm:
            norm_eps = getattr(config, 'eps', 1e-6)
            self.q_norm = RMSNorm(head_dim, eps=norm_eps)
            self.k_norm = RMSNorm(head_dim, eps=norm_eps)

    def forward(self, x, freqs_cis=None):
        """Attend over ``x`` of shape (B, T, C) and return a tensor of the same shape.

        Args:
            x: input activations; C must equal ``n_embd``.
            freqs_cis: optional rotary table (see ``precompute_freqs_cis``).
        """
        B, T, C = x.size()
        head_dim = self.head_dim

        q_size = self.n_head * head_dim
        kv_size = self.kv_heads * head_dim
        q, k, v = self.c_attn(x).split([q_size, kv_size, kv_size], dim=2)

        # (B, T, heads * head_dim) -> (B, heads, T, head_dim)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.kv_heads, head_dim).transpose(1, 2)
        v = v.view(B, T, self.kv_heads, head_dim).transpose(1, 2)

        # Share each key/value head across its group of query heads.
        if self.kv_heads < self.n_head:
            repeats = self.n_head // self.kv_heads
            k = k.repeat_interleave(repeats, dim=1)
            v = v.repeat_interleave(repeats, dim=1)

        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis)

        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (k.size(-1) ** -0.5)
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = torch.nn.functional.softmax(att, dim=-1)
            y = att @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
|
|
|
|
class Block(torch.nn.Module):
    """Pre-norm transformer block: ``x + attn(ln_1(x))`` then ``x + mlp(ln_2(x))``."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, eps=config.eps)
        self.ln_2 = RMSNorm(config.n_embd, eps=config.eps)
        self.attn = GroupedQueryAttention(config)

        hidden_dim = 4 * config.n_embd
        self.use_swiglu = config.use_swiglu
        if self.use_swiglu:
            # Gated MLP; the submodule names are state_dict keys and must stay.
            self.mlp = torch.nn.ModuleDict(dict(
                gate=torch.nn.Linear(config.n_embd, hidden_dim, bias=config.bias),
                up=torch.nn.Linear(config.n_embd, hidden_dim, bias=config.bias),
                down=torch.nn.Linear(hidden_dim, config.n_embd, bias=config.bias),
                act=torch.nn.SiLU(),
            ))
        else:
            # Classic GELU MLP.
            self.mlp = torch.nn.ModuleDict(dict(
                c_fc=torch.nn.Linear(config.n_embd, hidden_dim, bias=config.bias),
                c_proj=torch.nn.Linear(hidden_dim, config.n_embd, bias=config.bias),
                act=torch.nn.GELU(),
            ))

    def mlpf(self, x):
        """Apply the feed-forward sub-layer.

        This is a method rather than a lambda stored on the instance. A lambda
        closing over ``self.mlp`` cannot be pickled, and ``copy.deepcopy``
        would hand the copy a closure that still uses the original module's
        weights.
        """
        m = self.mlp
        if self.use_swiglu:
            # SiLU is applied to ``up`` and the result is multiplied by ``gate``.
            # Presumably this matches training; changing it alters pretrained outputs.
            return m.down(m.act(m.up(x)) * m.gate(x))
        return m.c_proj(m.act(m.c_fc(x)))

    def forward(self, x, freqs_cis=None):
        x = x + self.attn(self.ln_1(x), freqs_cis)
        x = x + self.mlpf(self.ln_2(x))
        return x
|
|
|
|
class CosmicFish(torch.nn.Module):
    """
    CosmicFish model for inference only.
    Features: Rotary Positional Embeddings, Grouped-Query Attention, SwiGLU, RMSNorm
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = torch.nn.ModuleDict(dict(
            wte=torch.nn.Embedding(config.vocab_size, config.n_embd),
            h=torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd, eps=config.eps),
        ))

        self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: token embedding and output projection share one matrix.
        self.transformer.wte.weight = self.lm_head.weight

        if config.use_rotary:
            head_dim = config.n_embd // config.n_head
            # Plain attribute (not a buffer): it is not in state_dict and is not
            # moved by model.to(); _rope_table() moves it lazily and caches it.
            self.freqs_cis = precompute_freqs_cis(head_dim, config.block_size)
        else:
            self.freqs_cis = None
            self.transformer.wpe = torch.nn.Embedding(config.block_size, config.n_embd)

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model.

        Shared (tied) tensors are counted once. With ``non_embedding`` True,
        learned positional-embedding weights are subtracted. These only exist
        when ``use_rotary`` is False.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and hasattr(self.transformer, 'wpe'):
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _rope_table(self, device):
        """Return the rotary table on ``device``.

        The moved copy is cached so the transfer happens once, not on every forward.
        """
        if self.freqs_cis is not None and self.freqs_cis.device != device:
            self.freqs_cis = self.freqs_cis.to(device)
        return self.freqs_cis

    def forward(self, idx, targets=None):
        """Run the model on token ids ``idx`` of shape (b, t).

        Returns:
            ``(logits, loss)``. With ``targets`` the logits cover every position
            and ``loss`` is the cross-entropy (ignoring index -1). Without
            ``targets`` only the last position's logits are returned (shape
            (b, 1, vocab)) and ``loss`` is None.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        tok_emb = self.transformer.wte(idx)

        if self.config.use_rotary:
            x = tok_emb
            freqs_cis = self._rope_table(device)
        else:
            pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
            x = tok_emb + self.transformer.wpe(pos)
            freqs_cis = None

        for block in self.transformer.h:
            x = block(x, freqs_cis)

        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference shortcut: only the last position is needed for sampling.
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Generate text by sampling from the model, token by token.

        Args:
            idx: prompt token ids of shape (b, t).
            max_new_tokens: number of tokens to append.
            temperature: logits are divided by this before sampling.
            top_k: keep only the ``top_k`` most likely tokens. None or a value
                <= 0 disables the filter; values above the vocabulary size are
                clamped.

        Returns:
            ``idx`` with ``max_new_tokens`` sampled ids appended.
        """
        for _ in range(max_new_tokens):
            # Crop to the context window the model was built for.
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None and top_k > 0:
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
|
|
|
|
class RepetitionPenaltyLogitsProcessor:
    """Discourage re-emitting tokens that already appear in the context.

    For every id in ``input_ids``, a positive logit is divided by ``penalty``
    and a non-positive logit is multiplied by it. With ``penalty > 1`` both
    operations make the token less likely.
    """

    def __init__(self, penalty=1.2):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        """Penalize ``scores`` in place at the positions named by ``input_ids``; return ``scores``."""
        seen = scores.gather(1, input_ids)
        penalized = torch.where(seen > 0, seen / self.penalty, seen * self.penalty)
        scores.scatter_(1, input_ids, penalized)
        return scores
|
|
|
|
class CosmicFishChatSession:
    """Chat session for CosmicFish model from Hugging Face Hub.

    Keeps the conversation history and builds prompts from it. Streams
    sampled tokens back to the caller and implements the ``/`` commands.
    """

    def __init__(self, model, tokenizer, config):
        """Initialize chat session with model and configuration.

        Args:
            model: CosmicFish model; inputs are placed on its parameters' device.
            tokenizer: object with ``encode``/``decode`` (the GPT-2 tokenizer in this script).
            config: settings object providing the attributes read below.
                ``temperature`` and ``debug_mode`` are also written back by
                ``/temp`` and ``/debug``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = next(model.parameters()).device
        self.history = []          # list of (role, text); role is "human" or "assistant"
        self.history_tokens = []   # concatenated per-entry tokenizations of self.history
        self.max_history_tokens = config.max_history_tokens
        self.prompt_template = config.prompt_template
        self.human_prefix = config.human_prefix
        self.assistant_prefix = config.assistant_prefix
        self.end_of_turn = config.end_of_turn
        self.block_size = config.block_size
        self.debug_mode = config.debug_mode
        self.repetition_penalty = config.repetition_penalty
        self.min_tokens_to_generate = config.min_tokens_to_generate
        self.max_retries = 20

        self.fallback_responses = [
            "I'd be happy to help with that. Could you provide more details about what specific information you're looking for?",
            "That's a topic I can provide information about. What specific aspects would you like to know?",
            "I understand your question. I can share factual information on this topic if you could specify what aspects you're interested in.",
            "I can help with your question. To give you the most relevant information, could you clarify what specific details you're looking for?",
            "I'd be glad to address your question. To provide the most helpful response, could you specify what particular aspects of this topic interest you?"
        ]

        self.generation_failure_message = "I'm sorry, but I'm having difficulty generating a response to that prompt. Could you try rephrasing your question or asking something else?"

        # Running totals reported by /stats.
        self.total_prompt_tokens = 0
        self.total_generated_tokens = 0

        # Substrings signalling that the model started a new turn or a new
        # document. Empty strings (e.g. from --human_prefix "") are dropped:
        # "" is contained in every text and would stop every response at once.
        candidate_markers = [
            f"{self.human_prefix}",
            "Human:",
            "\nHuman:",
            "\nH:",
            "H:",
            "<|endoftext|>",
            "Below is a conversation",
            "\nA:",
            "A:",
            "</s>",
            "User:",
            "\nUser:"
        ]
        self.end_markers = [marker for marker in candidate_markers if marker]

        if config.display_welcome:
            self._print_welcome_message()

    def _print_welcome_message(self):
        """Print the banner with the model size and the list of slash commands."""
        welcome_text = f"""
{'=' * 80}
Welcome to CosmicFish chat interface

This is a {self.model.get_num_params() / 1e6:.1f}M parameter model.
CosmicFish is an efficient LLM with an advanced architecture.

Type your prompts and CosmicFish will respond.

Special commands:
- /help: Show this help message
- /clear: Clear the conversation history
- /exit or /quit: Exit the chat
- /stats: Show token usage statistics
- /save [filename]: Save the conversation
- /load [filename]: Load a conversation
- /temp [value]: Set temperature (between 0.1 and 2.0)
- /penalty [value]: Set repetition penalty (1.0-2.0)
- /debug: Toggle debug mode


Note: CosmicFish may generate incorrect or fictional responses. Verify facts if needed.

Visit https://cosmicfish.ai for more info


Developed by Mistyoz AI (https://www.mistyoz.com)
{'=' * 80}
"""
        print(colored(welcome_text, 'cyan'))

    def _prefix_for(self, role):
        """Return the speaker prefix used when rendering a history entry of ``role``."""
        return self.human_prefix if role == "human" else self.assistant_prefix

    def _entry_tokens(self, role, text):
        """Tokenize one history entry exactly as it is rendered into the prompt."""
        return self._tokenize(f"{self._prefix_for(role)}{text}{self.end_of_turn}")

    def _format_prompt(self, user_input):
        """Build the full prompt: template, past turns, the new user turn, and an open assistant prefix."""
        past_turns = "".join(
            f"{self._prefix_for(role)}{text}{self.end_of_turn}" for role, text in self.history
        )
        return (
            f"{self.prompt_template}{past_turns}"
            f"{self.human_prefix}{user_input}{self.end_of_turn}{self.assistant_prefix}"
        )

    def _tokenize(self, text):
        """Tokenize text and return token IDs."""
        return self.tokenizer.encode(text)

    def _update_history(self, user_input, response):
        """Append a completed turn, update token counters, and trim history to the token budget."""
        self.history.append(("human", user_input))
        self.history.append(("assistant", response))

        user_tokens = self._entry_tokens("human", user_input)
        response_tokens = self._entry_tokens("assistant", response)
        self.history_tokens.extend(user_tokens)
        self.history_tokens.extend(response_tokens)

        self.total_prompt_tokens += len(user_tokens)
        self.total_generated_tokens += len(response_tokens)

        self._trim_history_if_needed()

    def _trim_history_if_needed(self):
        """Drop the oldest pairs of entries until ``history_tokens`` fits ``max_history_tokens``.

        ``history_tokens`` is the concatenation of ``_entry_tokens`` for each
        entry. Removing the two oldest entries therefore removes exactly their
        own token counts from the front of the list. Those counts must be
        measured on the entries being removed, not on the pair that follows.
        """
        while len(self.history_tokens) > self.max_history_tokens and len(self.history) >= 2:
            removed, self.history = self.history[:2], self.history[2:]
            dropped = sum(len(self._entry_tokens(role, text)) for role, text in removed)
            self.history_tokens = self.history_tokens[dropped:]

        if not self.history:
            # Nothing left to account for; keep the two structures consistent.
            self.history_tokens = []

    def _should_stop_generation(self, text):
        """Check if generation should stop based on end markers."""
        return any(marker in text for marker in self.end_markers)

    def _truncate_at_end_marker(self, text):
        """Cut ``text`` at its earliest end marker and strip trailing whitespace.

        Text containing no marker is returned unchanged. This keeps strings
        like "\\n\\nHuman:" out of the stored response, and therefore out of
        later prompts.
        """
        positions = [pos for pos in (text.find(marker) for marker in self.end_markers) if pos != -1]
        if not positions:
            return text
        return text[:min(positions)].rstrip()

    def _clean_token_text(self, text):
        """Replace replacement/mojibake characters in a decoded token with an ASCII apostrophe.

        Tokens are decoded one at a time by the sampler. Presumably this is
        aimed at fragments of multi-byte characters (such as a curly
        apostrophe) that decode to U+FFFD on their own. Note that any other
        split multi-byte character is also turned into apostrophes.
        """
        text = text.replace('��', "'")
        text = text.replace('�', "'")
        text = text.replace('\ufffd', "'")
        text = text.replace('\uFFFD', "'")
        text = text.replace('’', "'")
        text = text.replace('â€Å"', "'")
        text = text.replace('�', "'")
        text = text.replace('â€"', "'")
        text = text.replace('â€"', "'")
        return text

    def generate_with_repetition_penalty(self, input_ids, max_new_tokens, temperature, top_k, penalty=1.2, live=False):
        """Custom generate function with repetition penalty and optional live generation.

        Args:
            input_ids: prompt ids, a 2-D LongTensor (batch, T). The stop checks
                read row 0 only; this file always passes a batch of one.
            max_new_tokens: upper bound on sampled tokens.
            temperature: logits are divided by this value.
            top_k: keep the k most likely tokens; None or a value <= 0 disables
                the filter. Values above the vocabulary size are clamped.
            penalty: repetition penalty, applied when > 1.0 to every id in the
                current context window (prompt included).
            live: see Returns.

        Returns:
            With ``live=True``, a generator yielding
            ``(token_text, text_so_far, should_stop)`` tuples. With
            ``live=False``, the prompt plus sampled ids as a tensor, or None if
            nothing was sampled.
        """
        stream = self._sample_tokens(input_ids, max_new_tokens, temperature, top_k, penalty, live)
        if live:
            return stream
        # In non-live mode the sampler never yields; its result is carried by
        # StopIteration.value. A plain function with ``yield`` could not return it.
        while True:
            try:
                next(stream)
            except StopIteration as finished:
                return finished.value

    def _sample_tokens(self, input_ids, max_new_tokens, temperature, top_k, penalty, live):
        """Sampling loop behind ``generate_with_repetition_penalty`` (a generator)."""
        model = self.model
        model.eval()

        generated = input_ids.clone()
        prompt_length = input_ids.size(1)
        live_buffer = ""
        rep_processor = RepetitionPenaltyLogitsProcessor(penalty=penalty)
        tokens_generated = 0
        min_tokens = self.min_tokens_to_generate

        eot_token_id = getattr(self.tokenizer, 'eos_token_id', None)
        if eot_token_id is None:
            eot_token_id = 50256  # GPT-2 <|endoftext|>

        for _ in range(max_new_tokens):
            # The model only accepts block_size positions; keep the most recent ones.
            if generated.size(1) > self.block_size:
                context = generated[:, -self.block_size:]
            else:
                context = generated

            with torch.no_grad():
                logits, _ = model(context)

            next_token_logits = logits[:, -1, :] / temperature

            if penalty > 1.0:
                next_token_logits = rep_processor(context, next_token_logits)

            if top_k is not None and top_k > 0:
                k = min(top_k, next_token_logits.size(-1))
                kth_best = torch.topk(next_token_logits, k)[0][..., -1, None]
                next_token_logits[next_token_logits < kth_best] = float('-inf')

            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            token_id = next_token.item()

            if token_id == eot_token_id:
                if live:
                    yield "", live_buffer, True
                break

            generated = torch.cat((generated, next_token), dim=1)
            tokens_generated += 1

            if live:
                next_token_text = self._clean_token_text(self.tokenizer.decode([token_id]))
                live_buffer += next_token_text

                # The end-of-text marker may also arrive as literal text.
                eot_marker_pos = live_buffer.find("<|endoftext|>")
                if eot_marker_pos != -1:
                    live_buffer = live_buffer[:eot_marker_pos]
                    yield "", live_buffer, True
                    break

                should_stop = tokens_generated >= min_tokens and self._should_stop_generation(live_buffer)
                yield next_token_text, live_buffer, should_stop
                if should_stop:
                    break

            elif tokens_generated >= min_tokens:
                # Inspect only freshly generated text. The prompt itself
                # contains markers such as "Human:" and must not trigger a stop.
                recent_text = self.tokenizer.decode(generated[0, prompt_length:][-20:].tolist())
                if self._should_stop_generation(recent_text):
                    break

        if live:
            return None

        if tokens_generated == 0:
            if self.debug_mode:
                print(colored("\n[No tokens generated in this attempt]", "red"))
            return None

        return generated

    def generate_response(self, user_input):
        """Build the prompt for ``user_input`` and return the streaming generator from ``_generate_live_response``."""
        prompt = self._format_prompt(user_input)
        input_ids = torch.tensor(self._tokenize(prompt), dtype=torch.long).unsqueeze(0).to(self.device)

        if input_ids.size(1) > self.block_size:
            # Keep the instruction template and the most recent tokens; drop the middle.
            keep_from_beginning = len(self._tokenize(self.prompt_template))
            keep_from_end = self.block_size - keep_from_beginning

            if keep_from_end <= 0:
                # No room beyond the template. (A 0 here would turn the slice
                # below into [:, -0:], i.e. the whole prompt.)
                input_ids = input_ids[:, :self.block_size]
            else:
                input_ids = torch.cat([
                    input_ids[:, :keep_from_beginning],
                    input_ids[:, -keep_from_end:]
                ], dim=1)

        start_time = time.time()
        return self._generate_live_response(input_ids, user_input, start_time)

    def _generate_live_response(self, input_ids, user_input, start_time):
        """Stream a response, retrying with adjusted temperatures on failure.

        Yields ``(token_text, text_so_far, False)`` for each streamed piece and
        finally ``("", final_text, True)``. An attempt fails when it raises or
        when its text has fewer than 10 non-whitespace-stripped characters.
        Even retries raise the temperature and odd retries lower it. After
        ``max_retries`` the failure message is used. The final text is
        recorded in the history.
        """
        live_text = ""
        tokens_generated = 0
        retry_count = 0
        generation_failed = False

        while retry_count <= self.max_retries:
            if retry_count > 0:
                if retry_count % 2 == 0:
                    temp_adjustment = min(0.2 * (retry_count // 2), 0.8)
                    current_temp = min(self.config.temperature + temp_adjustment, 1.8)
                else:
                    temp_adjustment = min(0.2 * ((retry_count + 1) // 2), 0.4)
                    current_temp = max(self.config.temperature - temp_adjustment, 0.2)

                if self.debug_mode:
                    print(colored(f"\n[Live retry {retry_count}: Using temperature {current_temp:.2f}]", "yellow"))
            else:
                current_temp = self.config.temperature

            live_text = ""
            tokens_generated = 0
            generation_failed = False

            try:
                for token_text, live_buffer, should_stop in self.generate_with_repetition_penalty(
                    input_ids,
                    max_new_tokens=self.config.max_new_tokens,
                    temperature=current_temp,
                    top_k=self.config.top_k,
                    penalty=self.repetition_penalty,
                    live=True
                ):
                    if should_stop:
                        # Store only the answer, not the marker that ended it.
                        live_text = self._truncate_at_end_marker(live_buffer)
                        break

                    if token_text:
                        live_text += token_text
                        tokens_generated += 1
                        yield token_text, live_text, False

                if not live_text or len(live_text.strip()) < 10:
                    if self.debug_mode:
                        print(colored("\n[Live generation produced empty or too short response, retrying]", "yellow"))
                    generation_failed = True
                    retry_count += 1
                    if retry_count <= self.max_retries:
                        print("\r" + " " * 80 + "\r", end="")
                else:
                    break

            except Exception as e:
                if self.debug_mode:
                    print(colored(f"\n[Live generation error: {str(e)}, retrying]", "red"))
                generation_failed = True
                retry_count += 1

        if generation_failed or not live_text or len(live_text.strip()) < 10:
            live_text = self.generation_failure_message
            if self.debug_mode:
                print(colored(f"\n[Returning failure message after {retry_count} live retries]", "red"))

        time_taken = time.time() - start_time
        tokens_per_second = tokens_generated / time_taken if time_taken > 0 else 0

        self._update_history(user_input, live_text)

        logger.debug(f"Generated {tokens_generated} tokens in {time_taken:.2f}s ({tokens_per_second:.2f} tokens/s)")

        yield "", live_text, True

    def execute_command(self, command):
        """Execute a special command prefixed with /.

        Returns False only for /exit and /quit, and True otherwise. Commands
        that take an argument (/penalty, /temp, /save, /load) also accept a
        missing argument and report it, instead of treating it as unknown.
        """
        command = command.strip()
        parts = command.split(maxsplit=1)
        name = parts[0] if parts else ""
        argument = parts[1].strip() if len(parts) > 1 else ""

        if command == '/help':
            self._print_welcome_message()
            return True

        if command == '/clear':
            self.history = []
            self.history_tokens = []
            print(colored("Conversation history cleared.", 'yellow'))
            return True

        if command in ('/exit', '/quit'):
            print(colored("Goodbye!", 'cyan'))
            return False

        if command == '/stats':
            prompt_tokens = self.total_prompt_tokens
            generated_tokens = self.total_generated_tokens
            total_tokens = prompt_tokens + generated_tokens
            stats = f"""
Token usage statistics:
- Prompt tokens: {prompt_tokens}
- Generated tokens: {generated_tokens}
- Total tokens: {total_tokens}
- Current history length: {len(self.history_tokens)} tokens
- Current repetition penalty: {self.repetition_penalty}
- Current temperature: {self.config.temperature}
- Model: CosmicFish ({self.model.get_num_params() / 1e6:.1f}M parameters)
- Source: {DEFAULT_MODEL_REPO}
- Format: Safetensors (secure)
"""
            print(colored(stats, 'yellow'))
            return True

        if command == '/debug':
            self.debug_mode = not self.debug_mode
            self.config.debug_mode = self.debug_mode
            mode = "enabled" if self.debug_mode else "disabled"
            print(colored(f"Debug mode {mode}", 'yellow'))
            return True

        if name == '/penalty':
            try:
                penalty = float(argument)
                if 1.0 <= penalty <= 2.0:
                    self.repetition_penalty = penalty
                    print(colored(f"Repetition penalty set to {penalty}", 'yellow'))
                else:
                    print(colored("Repetition penalty should be between 1.0 and 2.0", 'red'))
            except ValueError:
                print(colored("Invalid repetition penalty value. Please use a number between 1.0 and 2.0", 'red'))
            return True

        if name == '/temp':
            try:
                temp = float(argument)
                if 0.1 <= temp <= 2.0:
                    self.config.temperature = temp
                    print(colored(f"Temperature set to {temp}", 'yellow'))
                else:
                    print(colored("Temperature should be between 0.1 and 2.0", 'red'))
            except ValueError:
                print(colored("Invalid temperature value. Please use a number between 0.1 and 2.0", 'red'))
            return True

        if name == '/save':
            self._save_conversation(argument)
            return True

        if name == '/load':
            self._load_conversation(argument)
            return True

        print(colored(f"Unknown command: {command}. Type /help for available commands.", 'red'))
        return True

    def _save_conversation(self, filename):
        """Write the history to conversations/<filename>(.txt) in prompt format."""
        if not filename:
            print(colored("Please specify a filename: /save <filename>", 'red'))
            return

        try:
            os.makedirs('conversations', exist_ok=True)
            if not filename.endswith('.txt'):
                filename += '.txt'
            filepath = os.path.join('conversations', filename)

            with open(filepath, 'w', encoding='utf-8') as f:
                for role, text in self.history:
                    f.write(f"{self._prefix_for(role)}{text}{self.end_of_turn}")

            print(colored(f"Conversation saved to {filepath}", 'green'))
        except Exception as e:
            print(colored(f"Error saving conversation: {str(e)}", 'red'))

    def _parse_conversation(self, content):
        """Parse text written by ``_save_conversation`` back into ``(role, text)`` entries.

        The file is split on ``end_of_turn``. A chunk that starts with neither
        prefix is a later paragraph of the previous message (whose own
        ``end_of_turn`` caused the split), so it is re-attached rather than
        dropped. Chunks before the first recognised entry are ignored.
        """
        entries = []
        for chunk in content.split(self.end_of_turn):
            chunk = chunk.strip()
            if not chunk:
                continue
            if chunk.startswith(self.human_prefix):
                entries.append(("human", chunk[len(self.human_prefix):].strip()))
            elif chunk.startswith(self.assistant_prefix):
                entries.append(("assistant", chunk[len(self.assistant_prefix):].strip()))
            elif entries:
                role, text = entries[-1]
                entries[-1] = (role, f"{text}{self.end_of_turn}{chunk}")
        return entries

    def _load_conversation(self, filename):
        """Replace the history with conversations/<filename>(.txt) and echo it to the terminal."""
        if not filename:
            print(colored("Please specify a filename: /load <filename>", 'red'))
            return

        try:
            if not filename.endswith('.txt'):
                filename += '.txt'
            filepath = os.path.join('conversations', filename)

            if not os.path.exists(filepath):
                print(colored(f"File not found: {filepath}", 'red'))
                return

            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()

            self.history = self._parse_conversation(content)
            self.history_tokens = []
            for role, text in self.history:
                self.history_tokens.extend(self._entry_tokens(role, text))

            print(colored(f"Loaded conversation from {filepath} ({len(self.history) // 2} turns)", 'green'))

            # Echo the loaded turns, assuming they alternate human/assistant.
            for i in range(0, len(self.history), 2):
                user_text = self.history[i][1]
                print(colored(f"\nYou: {user_text}", 'green'))

                if i + 1 < len(self.history):
                    assistant_text = self.history[i + 1][1]
                    print(colored("CosmicFish: ", 'blue'), end="")
                    for line in assistant_text.split('\n'):
                        wrapped_lines = textwrap.wrap(line, width=100) if line.strip() else ['']
                        for wrapped_line in wrapped_lines:
                            print(wrapped_line)

        except Exception as e:
            print(colored(f"Error loading conversation: {str(e)}", 'red'))
|
|
|
|
def download_cosmicfish_from_hub(model_repo=DEFAULT_MODEL_REPO, device='cpu'):
    """Download and load CosmicFish model from Hugging Face Hub (safetensors only)

    Args:
        model_repo: Hub repository id.
        device: device the loaded model is moved to.

    Returns:
        ``(model, config)``: the model in eval mode and its ``CosmicConfig``.
        On any error a message is printed and the process exits with status 1.
    """
    print(colored(f"Downloading CosmicFish from Hugging Face: {model_repo}", "cyan"))

    try:
        print("Downloading model files...")
        # Fetch only the two files read below. Other files the repo may hold
        # (e.g. pickle-based checkpoints) are neither needed nor downloaded.
        cache_dir = snapshot_download(
            repo_id=model_repo,
            allow_patterns=["config.json", "model.safetensors"],
        )
        print(f"Model cached at: {cache_dir}")

        config_path = os.path.join(cache_dir, "config.json")
        with open(config_path, "r") as f:
            config_dict = json.load(f)

        config = CosmicConfig(
            vocab_size=config_dict["vocab_size"],
            block_size=config_dict["block_size"],
            n_layer=config_dict["n_layer"],
            n_head=config_dict["n_head"],
            n_embd=config_dict["n_embd"],
            bias=config_dict["bias"],
            dropout=0.0,
            eps=config_dict.get("eps", 1e-6),
            use_rotary=config_dict["use_rotary"],
            use_swiglu=config_dict["use_swiglu"],
            use_gqa=config_dict["use_gqa"],
            n_query_groups=config_dict["n_query_groups"],
            use_qk_norm=config_dict.get("use_qk_norm", False)
        )

        print("Creating model...")
        model = CosmicFish(config)

        print("Loading weights from safetensors...")
        safetensors_path = os.path.join(cache_dir, "model.safetensors")
        if not os.path.exists(safetensors_path):
            raise FileNotFoundError(f"model.safetensors not found in {cache_dir}. This model requires safetensors format.")

        state_dict = load_file(safetensors_path)

        # The embedding and lm_head weights are tied. A checkpoint may store
        # only one of the two keys; fill in whichever is missing so strict
        # loading succeeds.
        if 'lm_head.weight' not in state_dict and 'transformer.wte.weight' in state_dict:
            state_dict['lm_head.weight'] = state_dict['transformer.wte.weight']
        elif 'transformer.wte.weight' not in state_dict and 'lm_head.weight' in state_dict:
            state_dict['transformer.wte.weight'] = state_dict['lm_head.weight']

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        print(f"Model loaded: {model.get_num_params() / 1e6:.1f}M parameters")
        print(f"Device: {device}")
        return model, config

    except Exception as e:
        print(colored(f"Error downloading/loading model: {str(e)}", "red"))
        print(colored("Make sure you have internet connection and the model repo exists", "yellow"))
        sys.exit(1)
|
|
|
|
def load_tokenizer():
    """Return the stock ``gpt2`` tokenizer from Hugging Face, which this script pairs with CosmicFish."""
    pretrained_name = "gpt2"
    return GPT2Tokenizer.from_pretrained(pretrained_name)
|
|
|
|
def main():
    """Command-line entry point: parse options, load model and tokenizer, and run the chat loop."""
    parser = argparse.ArgumentParser(description="Chat with CosmicFish")

    # Model source and hardware.
    parser.add_argument("--model_repo", type=str, default=DEFAULT_MODEL_REPO,
                        help=f"Hugging Face model repository (default: {DEFAULT_MODEL_REPO})")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use (cuda or cpu)")

    # Sampling.
    parser.add_argument("--temperature", type=float, default=0.5,
                        help="Temperature for sampling (default: 0.5)")
    parser.add_argument("--max_tokens", type=int, default=512,
                        help="Maximum number of tokens to generate per response")
    parser.add_argument("--min_tokens", type=int, default=10,
                        help="Minimum number of tokens to generate per response")
    parser.add_argument("--top_k", type=int, default=40,
                        help="Top-k sampling (0 to disable)")
    parser.add_argument("--repetition_penalty", type=float, default=1.2,
                        help="Repetition penalty (1.0 = no penalty, 1.2 = mild, 1.5 = moderate)")

    # Prompt format.
    parser.add_argument("--human_prefix", type=str, default="Human: ",
                        help="Prefix for human messages")
    parser.add_argument("--assistant_prefix", type=str, default="Assistant: ",
                        help="Prefix for assistant messages")
    parser.add_argument("--end_of_turn", type=str, default="\n\n",
                        help="Delimiter between conversation turns")
    parser.add_argument("--instruction", type=str,
                        default=DEFAULT_PROMPT_TEMPLATE,
                        help="Instruction prompt to prepend to the conversation")
    parser.add_argument("--max_history", type=int, default=512,
                        help="Maximum number of tokens to keep in history")

    # UI.
    parser.add_argument("--no_welcome", action="store_true",
                        help="Don't display the welcome message")
    parser.add_argument("--debug", action="store_true",
                        help="Enable debug mode")

    args = parser.parse_args()

    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print(colored("CUDA is not available, falling back to CPU", "yellow"))
        device = "cpu"

    try:
        model, model_config = download_cosmicfish_from_hub(args.model_repo, device)
        tokenizer = load_tokenizer()

        class ChatConfig:
            """Settings read by CosmicFishChatSession; /temp and /debug update it in place."""

            def __init__(self, args, block_size):
                self.device = device
                self.temperature = args.temperature
                self.max_new_tokens = args.max_tokens
                self.min_tokens_to_generate = args.min_tokens
                # --top_k 0 is documented as "disable"; the sampler expects None for that.
                self.top_k = args.top_k if args.top_k > 0 else None
                self.human_prefix = args.human_prefix
                self.assistant_prefix = args.assistant_prefix
                self.end_of_turn = args.end_of_turn
                self.prompt_template = args.instruction
                self.max_history_tokens = args.max_history
                self.display_welcome = not args.no_welcome
                self.block_size = block_size
                self.debug_mode = args.debug
                self.repetition_penalty = args.repetition_penalty

        config = ChatConfig(args, model_config.block_size)
        chat = CosmicFishChatSession(model, tokenizer, config)

        print(colored("\nCosmicFish initialized from Hugging Face! Type your message (or /help for commands).\n", 'cyan'))

        while True:
            try:
                user_input = input(colored("You: ", 'green'))

                if user_input.startswith('/'):
                    if not chat.execute_command(user_input):
                        break
                    continue

                if not user_input.strip():
                    continue

                live_buffer = ""  # text already printed for this response
                response_generator = chat.generate_response(user_input)

                try:
                    print(colored("CosmicFish: ", 'blue'), end="")
                    sys.stdout.flush()

                    for token, live_text, is_done in response_generator:
                        if is_done:
                            # Nothing was streamed (e.g. the failure message): print the final text once.
                            if not live_buffer:
                                print(live_text, end="")
                            break
                        if token:
                            if "<|endoftext|>" in token:
                                token = token.replace("<|endoftext|>", "")
                                if token:
                                    print(token, end="", flush=True)
                                break
                            print(token, end="", flush=True)
                            live_buffer += token

                except KeyboardInterrupt:
                    # Stop the generator so no further sampling happens.
                    response_generator.close()
                    print("\n[Generation interrupted]")

                print()

            except EOFError:
                # stdin is closed (Ctrl-D or end of piped input). input() would
                # raise again on every iteration, so leave instead of spinning.
                print(colored("\nGoodbye!", 'cyan'))
                break

            except KeyboardInterrupt:
                print("\n\nKeyboard interrupt detected. Type /exit to quit or continue chatting.")

            except Exception as e:
                print(colored(f"\nError: {str(e)}", 'red'))
                logger.error(f"Error in chat loop: {str(e)}", exc_info=True)

    except Exception as e:
        print(colored(f"Error setting up chat: {str(e)}", 'red'))
        logger.error(f"Error setting up chat: {str(e)}", exc_info=True)
        sys.exit(1)
|
|
|
|
if __name__ == "__main__":
    # Last-resort guard: log anything main() did not handle, then exit non-zero.
    try:
        main()
    except Exception as fatal:
        logger.error(f"Fatal error: {str(fatal)}", exc_info=True)
        sys.exit(1)
|
|