# Source: iac-coder-1.5b / scripts/train_iac.py
# Uploaded by Tejas86 (commit 6d66e19, verified)
"""
IaC (Infrastructure as Code) Language Model Fine-Tuning Script (GPU version)
=============================================================================
Base model: Qwen/Qwen2.5-Coder-1.5B (Apache-2.0, 1.5B params, code-specialized)
Dataset: bigcode/commitpackft (HCL + YAML + Dockerfile + Shell + Nix + Makefile + SaltStack + TOML)
Method: LoRA SFT via TRL SFTTrainer
Recipe: Based on DocCGen (arxiv:2406.11925) + Astraios (arxiv:2401.00788)
Hardware: a10g-large (24GB VRAM) recommended
Expected training time: ~3-4 hours
Usage:
pip install trl transformers peft datasets torch accelerate trackio
python train_iac.py
"""
import os
import torch
from datasets import load_dataset, concatenate_datasets
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
# ── Trackio setup ────────────────────────────────────────────────────────────
# Route experiment metrics to the hosted Trackio dashboard Space
# (picked up by the "trackio" reporter configured in SFTConfig below).
os.environ["TRACKIO_SPACE_ID"] = "Tejas86/iac-coder-trackio"
os.environ["TRACKIO_PROJECT"] = "iac-coder"
# ── Configuration ────────────────────────────────────────────────────────────
MODEL_NAME = "Qwen/Qwen2.5-Coder-1.5B"  # base checkpoint to fine-tune
HUB_MODEL_ID = "Tejas86/iac-coder-1.5b"  # destination repo on the Hub
MAX_SEQ_LENGTH = 2048  # max tokens per training example (SFTConfig.max_length)
BATCH_SIZE = 2  # per-device micro-batch size
GRAD_ACCUM = 8 # effective batch = 16
NUM_EPOCHS = 3
LEARNING_RATE = 2e-4  # peak LR for the linear schedule below
# ── Load IaC datasets ────────────────────────────────────────────────────────
# Files are read straight from the Hub via the hf:// filesystem protocol.
# NOTE(review): assumes commitpackft stores each language split at
# data/<config>/data.jsonl — confirm against the dataset repo layout.
BASE_URL = "hf://datasets/bigcode/commitpackft/data"
# commitpackft config name -> human-readable label used in the chat prompts.
IAC_CONFIGS = {
    "hcl": "Terraform HCL",
    "yaml": "YAML (Ansible/K8s)",
    "dockerfile": "Dockerfile",
    "shell": "Shell/Bash script",
    "nix": "Nix expression",
    "makefile": "Makefile",
    "saltstack": "SaltStack state",
    "toml": "TOML configuration",
}
print("Loading IaC datasets from bigcode/commitpackft...")
raw_datasets = {}  # config name -> raw (unprocessed) Dataset
for config_name, display_name in IAC_CONFIGS.items():
    # Loaded as generic JSONL rather than via the dataset's named configs.
    ds = load_dataset("json", data_files=f"{BASE_URL}/{config_name}/data.jsonl", split="train")
    raw_datasets[config_name] = ds
    print(f" {display_name:30s}: {len(ds):>6} samples")
# ── Format into conversational messages ──────────────────────────────────────
def to_messages(example, lang_display):
    """Convert one commitpackft record into a chat-format training sample.

    Builds a system/user/assistant conversation: an "edit" prompt when the
    record has previous file contents, otherwise a "generate" prompt.
    Returns ``{"messages": None}`` for unusable rows (missing commit subject
    or missing new contents) so callers can filter them out.
    """
    subject = (example.get("subject") or "").strip()
    before = (example.get("old_contents") or "").strip()
    after = (example.get("new_contents") or "").strip()

    # Reject rows with no task description or no target file contents.
    if not (after and subject):
        return {"messages": None}

    # Edit task when a prior version of the file exists; otherwise generation.
    if before:
        prompt = (
            f"Task: {subject}\n\n"
            f"Modify the following {lang_display} file:\n"
            f"```\n{before}\n```"
        )
    else:
        prompt = f"Task: {subject}\n\nGenerate the {lang_display} code."

    system_msg = (
        "You are an expert DevOps engineer specializing in Infrastructure "
        f"as Code. Generate clean, production-ready {lang_display} code."
    )
    return {
        "messages": [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": f"```\n{after}\n```"},
        ]
    }
print("\nProcessing datasets...")
processed_datasets = []  # per-language Datasets in chat format
for config_name, display_name in IAC_CONFIGS.items():
    ds = raw_datasets[config_name]
    # ld=display_name binds the loop variable eagerly so the lambda does not
    # late-bind to the last iteration's value under num_proc workers.
    processed = ds.map(
        lambda ex, ld=display_name: to_messages(ex, ld),
        remove_columns=ds.column_names, num_proc=4,
    )
    # Drop rows to_messages rejected (missing subject or new contents).
    processed = processed.filter(lambda x: x["messages"] is not None, num_proc=4)
    # Cap large configs for balance
    if config_name == "yaml" and len(processed) > 30000:
        processed = processed.shuffle(seed=42).select(range(30000))
    elif config_name == "shell" and len(processed) > 15000:
        processed = processed.shuffle(seed=42).select(range(15000))
    print(f" {display_name:30s}: {len(processed):>6}")
    processed_datasets.append(processed)
# Merge all languages and shuffle so batches mix languages.
train_dataset = concatenate_datasets(processed_datasets).shuffle(seed=42)
print(f"\nTotal: {len(train_dataset)} samples")
# Hold out ~2% for eval, clamped to the range [100, 1000] examples.
eval_size = min(1000, max(100, int(len(train_dataset) * 0.02)))
splits = train_dataset.train_test_split(test_size=eval_size, seed=42)
train_dataset = splits["train"]
eval_dataset = splits["test"]
print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
# ── LoRA Config ──────────────────────────────────────────────────────────────
# Rank-16 adapters on every attention and MLP projection of the decoder block.
_LORA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=_LORA_TARGET_MODULES,
    bias="none",
    task_type="CAUSAL_LM",
)
# ── Training Config ──────────────────────────────────────────────────────────
training_args = SFTConfig(
    # Output / Hub publishing
    output_dir="./iac-coder-1.5b-output",
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    hub_strategy="checkpoint",
    # Optimization schedule
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="linear",
    warmup_steps=150,
    weight_decay=0.01,
    max_grad_norm=1.0,
    # Sequence handling
    max_length=MAX_SEQ_LENGTH,
    packing=False,
    # Logging / reporting
    logging_steps=25,
    logging_first_step=True,
    disable_tqdm=True,
    report_to="trackio",
    run_name="iac-coder-1.5b-lora-sft",
    # Evaluation and checkpointing
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # Reproducibility
    seed=42,
    data_seed=42,
)
# ── Train ─────────────────────────────────────────────────────────────────────
print("\nInitializing trainer...")
# Passing the model name as a string lets SFTTrainer load the model and
# tokenizer itself and wrap the model with LoRA adapters from peft_config.
trainer = SFTTrainer(
    model=MODEL_NAME,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)
print("Starting training...")
trainer.train()
print("Pushing to Hub...")
trainer.push_to_hub(commit_message="Final IaC-Coder 1.5B LoRA SFT model")
print(f"\n✅ Model: https://huggingface.co/{HUB_MODEL_ID}")
# Fixed: plain string literal — the original used an f-string with no
# placeholders (ruff F541).
print("📊 Dashboard: https://huggingface.co/spaces/Tejas86/iac-coder-trackio")