File size: 4,191 Bytes
b6a400a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#!/usr/bin/env python3
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "trackio",
# ]
# ///

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio
import os

# Startup banner.
print("🚀 Medium-Scale SFT Training with Trackio")
print("=" * 60)

# Initialize Trackio and sync this run to a Hugging Face Space dashboard.
# The `config` dict below is only recorded as run metadata — it does not
# configure training. Keep it in sync with SFTConfig / LoraConfig further down.
# NOTE(review): with report_to="trackio" the Trainer's own Trackio integration
# may also call trackio.init(); confirm the training metrics actually land in
# this project/space.
print("\n📊 Initializing Trackio...")
trackio.init(
    project="medium-sft-training",
    space_id="evalstate/trl-trackio-dashboard",
    config={
        "model": "Qwen/Qwen2.5-0.5B",
        "dataset": "trl-lib/Capybara",
        "dataset_size": 1000,  # rows loaded, before the 90/10 train/eval split
        "num_epochs": 3,
        "learning_rate": 2e-5,
        "batch_size": 4,
        "gradient_accumulation": 4,
        "lora_r": 16,
        "lora_alpha": 32,
        "hardware": "a10g-large",
    },
)
print("✅ Trackio initialized!")
print("📈 Dashboard: https://huggingface.co/spaces/evalstate/trl-trackio-dashboard")

# Load only the first 1,000 rows of the Capybara "train" split (slice syntax
# in `split` keeps the download/processing small for this medium-scale run).
print("\n📊 Loading dataset...")
dataset = load_dataset("trl-lib/Capybara", split="train[:1000]")
print(f"✅ Dataset loaded: {len(dataset)} examples")

# Hub namespace for the output repo. Falls back to "evalstate" when
# HF_USERNAME is unset *or empty* — an empty value would otherwise yield the
# invalid repo id "/qwen-capybara-medium".
username = os.environ.get("HF_USERNAME") or "evalstate"

# Training configuration - production settings
print("\n⚙️  Configuring training...")
config = SFTConfig(
    # Output and Hub settings
    output_dir="qwen-capybara-medium",
    push_to_hub=True,
    hub_model_id=f"{username}/qwen-capybara-medium",
    # "every_save" pushes the current model files to the repo each time a
    # checkpoint is saved; it does NOT upload the checkpoint-* subfolders
    # (use "checkpoint" for the latest one or "all_checkpoints" for all).
    hub_strategy="every_save",

    # Training parameters - 3 epochs over the train split (1K rows loaded,
    # 10% of them held out for eval below)
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # Effective batch size per device = 4 * 4 = 16

    # Learning rate and schedule: linear warmup over the first 10% of steps,
    # then cosine decay.
    learning_rate=2e-5,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",

    # Logging and checkpointing
    logging_steps=10,  # Log every 10 steps
    save_strategy="steps",
    save_steps=50,  # Save every 50 steps
    save_total_limit=3,  # Keep only the 3 most recent checkpoints on disk

    # Evaluation on the held-out split, aligned with the save interval
    eval_strategy="steps",
    eval_steps=50,

    # Optimization
    bf16=True,  # bfloat16 mixed precision (supported on A10G)
    gradient_checkpointing=True,  # Trade recompute for lower activation memory

    # Trackio monitoring
    report_to="trackio",
)

# LoRA adapter configuration - larger rank than the demo script (r=16, alpha=32).
print("🔧 Setting up LoRA (r=16)...")
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",  # bias terms are not trained
    task_type="CAUSAL_LM",
    # Attach adapters to all four attention projections (query/key/value/output).
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# Hold out 10% of the loaded rows for evaluation. The fixed seed makes the
# split identical across runs, so eval metrics are comparable between runs.
print("\n🔀 Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
print(f"   Train: {len(train_dataset)} examples")
print(f"   Eval: {len(eval_dataset)} examples")

# Build the SFT trainer. The base model is given as a Hub id string, which
# SFTTrainer resolves and loads itself; peft_config adds the LoRA adapters.
print("\n🎯 Initializing trainer...")
trainer = SFTTrainer(
    # Model and run settings
    model="Qwen/Qwen2.5-0.5B",
    args=config,
    peft_config=peft_config,
    # Data produced by the train/eval split above
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Summarize the run using values read from `config` instead of re-typed
# constants, so this printout cannot drift from the real hyperparameters.
# The step count is an estimate: optimizer steps per epoch = train rows //
# effective batch (at least 1), assuming a single training device.
effective_batch_size = (
    config.per_device_train_batch_size * config.gradient_accumulation_steps
)
steps_per_epoch = max(len(train_dataset) // effective_batch_size, 1)
total_steps = int(steps_per_epoch * config.num_train_epochs)
print("\n📊 Training Info:")
print(f"   Total steps: ~{total_steps}")
print(f"   Epochs: {config.num_train_epochs:g}")
print(f"   Effective batch size: {effective_batch_size}")
print("   Expected time: ~45-60 minutes")
print(f"   Checkpoints saved every {config.save_steps:g} steps")

# Train, then push the final model to the Hub. trackio.finish() sits in
# `finally` so the Trackio run is closed even when training or the push
# raises; the original exception still propagates afterwards.
print("\n🏃 Starting training...")
print("📈 Watch live metrics: https://huggingface.co/spaces/evalstate/trl-trackio-dashboard")
print("-" * 60)
try:
    trainer.train()

    # Save to Hub (only reached if training completed without raising)
    print("\n💾 Pushing final model to Hub...")
    trainer.push_to_hub()
finally:
    # Finish Trackio
    print("\n📊 Finalizing Trackio metrics...")
    trackio.finish()

# Final summary. The model links are derived from config.hub_model_id, the
# repo the trainer actually pushed to, so they cannot disagree with it.
print("\n" + "=" * 60)
print("✅ Training complete!")
print(f"📦 Model: https://huggingface.co/{config.hub_model_id}")
print("📊 Metrics: https://huggingface.co/spaces/evalstate/trl-trackio-dashboard")
print("💡 Try the model with:")
print('   from transformers import pipeline')
print(f'   generator = pipeline("text-generation", model="{config.hub_model_id}")')
print("=" * 60)