import os
import re
import subprocess

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    default_data_collator,
)

from myolmoe import MyOlmoeForCausalLM, OlmoeConfig
| |
|
def _load_config():
    """Load the OLMoE config, preferring an explicit config.json when present."""
    config_path = os.path.join("myolmoe", "config.json")
    if os.path.exists(config_path):
        return OlmoeConfig.from_json_file(config_path)
    return OlmoeConfig.from_pretrained("myolmoe")


def _latest_checkpoint(output_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory.

    Returns None when ``output_dir`` does not exist or holds no checkpoints.
    Uses fullmatch so stray directories like ``checkpoint-123-old`` are ignored
    (the old ``re.match`` + ``split('-')[-1]`` combination crashed on those).
    """
    if not os.path.isdir(output_dir):
        return None
    best_step, best_path = -1, None
    for entry in os.listdir(output_dir):
        match = re.fullmatch(r"checkpoint-(\d+)", entry)
        if match:
            step = int(match.group(1))
            if step > best_step:
                best_step, best_path = step, os.path.join(output_dir, entry)
    return best_path


def _save_small_expert_weights(model, param_names, path):
    """Save only the small-expert / small-gate parameters of ``model`` to ``path``.

    Tensors are detached and moved to CPU so the file contains plain weights
    rather than live, device-bound autograd Parameters.
    """
    state_dict = {
        name: param.detach().cpu()
        for name, param in model.named_parameters()
        if name in param_names
    }
    if not state_dict:
        print("[ERROR] No small_expert or small_gate parameters found to save!")
        return
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(state_dict, path)
    print(
        f"[INFO] Saved {len(state_dict)} small_expert/small_gate parameters "
        f"to {path}"
    )


class CustomTrainer(Trainer):
    """Trainer whose loss computation tolerates extra collator keys."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # `num_items_in_batch` is injected by recent Trainer versions but is
        # not a model input; drop it before the forward pass.
        inputs = {k: v for k, v in inputs.items() if k != "num_items_in_batch"}

        model.train()
        with torch.set_grad_enabled(True):
            outputs = model(**inputs)
            loss = outputs.loss

        # With almost everything frozen, a graph-less loss means the small
        # experts/gates were not actually wired into the forward pass.
        if not loss.requires_grad:
            raise RuntimeError("Loss doesn't require gradients. Check model parameters.")

        return (loss, outputs) if return_outputs else loss


class GitPushCallback(TrainerCallback):
    """Commit and push the working tree every time a checkpoint is saved."""

    def on_save(self, args, state, control, **kwargs):
        try:
            print("Saving checkpoint to Git repo...")
            subprocess.run(["git", "add", "."], check=True)

            # `git diff --cached --quiet` exits 0 when the index is clean,
            # so skip the commit when there is nothing new.
            result = subprocess.run(["git", "diff", "--cached", "--quiet"])
            if result.returncode == 0:
                print("No changes to commit.")
                return

            subprocess.run(
                ["git", "commit", "-m", f"Checkpoint at step {state.global_step}"],
                check=True,
            )
            subprocess.run(["git", "push"], check=True)
            print("Checkpoint pushed successfully.")
        except subprocess.CalledProcessError as e:
            # Best-effort: a failed push should never abort training.
            print(f"Git push failed: {e}")


class SmallExpertSaveCallback(TrainerCallback):
    """On every checkpoint save, also dump just the small-expert weights."""

    def __init__(self, model, trainable_params):
        self.model = model
        # Set for O(1) membership tests while filtering named_parameters().
        self.trainable_params = set(trainable_params)

    def on_save(self, args, state, control, **kwargs):
        checkpoint_dir = os.path.join(
            args.output_dir, f"checkpoint-{state.global_step}"
        )
        path = os.path.join(checkpoint_dir, "small_experts_and_gates.bin")
        _save_small_expert_weights(self.model, self.trainable_params, path)


def main():
    """Fine-tune only the small-expert/small-gate parameters of an OLMoE model.

    Loads the model and tokenizer from ./myolmoe, tokenizes the Tulu v2 SFT
    mixture, freezes everything except parameters whose names contain
    "small_experts" or "small_gate", smoke-tests gradient flow on one batch,
    then trains (resuming from the latest ./checkpoints/checkpoint-* if any)
    and saves the small-expert weights to ./final_model.
    """
    print("Starting my COOL OLMoE training script for small experts")

    config = _load_config()

    model = MyOlmoeForCausalLM.from_pretrained(
        "myolmoe",
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        ignore_mismatched_sizes=True,
    )

    tokenizer = AutoTokenizer.from_pretrained("myolmoe")
    # Reuse EOS for padding; the tokenizer ships without a pad token.
    tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset("allenai/tulu-v2-sft-mixture", split="train")

    def tokenize_function(examples):
        """Flatten chat messages into 'Role: content' lines and tokenize."""
        texts = []
        for message_list in examples["messages"]:
            formatted = ""
            for msg in message_list:
                role = msg["role"]
                content = msg["content"]
                if role == "user":
                    formatted += f"User: {content}\n"
                elif role == "assistant":
                    formatted += f"Assistant: {content}\n"
                else:
                    formatted += f"{role.capitalize()}: {content}\n"
            texts.append(formatted)

        tokenized = tokenizer(
            texts,
            truncation=True,
            max_length=4096,
            padding="max_length",
        )
        # Causal LM objective: labels mirror the inputs (shift happens in-model).
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=4,
    )

    training_args = TrainingArguments(
        output_dir="./checkpoints",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        num_train_epochs=3,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=20,
        save_total_limit=1,
        bf16=True,
        gradient_checkpointing=False,
        report_to="tensorboard",
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        max_grad_norm=1.0,
    )

    # Freeze everything, then re-enable only the small expert/gate parameters.
    for param in model.parameters():
        param.requires_grad = False

    trainable_params = []
    for name, param in model.named_parameters():
        if "small_experts" in name or "small_gate" in name:
            param.requires_grad = True
            trainable_params.append(name)

    if trainable_params:
        print(f"[INFO] Found {len(trainable_params)} small_expert/small_gate parameters.")
    else:
        print("[WARNING] No small_expert or small_gate parameters found in model!")

    unfrozen = [name for name, param in model.named_parameters() if param.requires_grad]
    if unfrozen:
        print(f"[INFO] {len(unfrozen)} parameters are unfrozen and trainable.")
        for name in unfrozen:
            print(f" - {name}")
    else:
        print("[ERROR] No parameters were unfrozen! Training will not update anything.")

    print(f"Total trainable parameters: {len(trainable_params)}")

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"Parameter {name} requires grad: {param.requires_grad}")

    def data_collator(features):
        """Default collation plus a flag asking the model for router logits."""
        batch = default_data_collator(features)
        batch["output_router_logits"] = True
        return batch

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        callbacks=[
            GitPushCallback(),
            SmallExpertSaveCallback(model, trainable_params),
        ],
    )

    # Smoke-test gradient flow on a single example before full training.
    print("Testing gradient flow...")
    test_loader = DataLoader(tokenized_dataset, batch_size=1, collate_fn=data_collator)
    test_batch = next(iter(test_loader))

    device = next(model.parameters()).device
    test_batch = {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in test_batch.items()
    }

    model.train()
    outputs = model(**test_batch)
    loss = outputs.loss
    print(f"Initial loss: {loss.item()}")

    loss.backward()
    print("Gradients computed successfully")

    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f"Parameter {name} received gradients")

    # Clear the smoke-test gradients so training starts from a clean slate.
    model.zero_grad()

    checkpoint_dir = _latest_checkpoint(training_args.output_dir)
    if checkpoint_dir:
        print(f"Resuming from checkpoint: {checkpoint_dir}")

    print("Starting training...")
    trainer.train(resume_from_checkpoint=checkpoint_dir)

    print("Saving small experts and gates...")
    _save_small_expert_weights(
        model,
        set(trainable_params),
        os.path.join("./final_model", "small_experts_and_gates.bin"),
    )
    config.save_pretrained("./final_model")
| |
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()