File size: 31,593 Bytes
29fc577
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
"""
train/sft.py — Supervised Fine-Tuning (SFT) entry point.

Loads a pretrained checkpoint and fine-tunes it on instruction/conversation
data using SFTDataset, which masks prompt tokens with ignore_index=-1 so only
the assistant response tokens contribute to the loss.

Launch single-GPU:
    python train/sft.py \\
        --base_checkpoint checkpoints/korean_1b_fp8_run1/checkpoint-0034000 \\
        --sft_data data/sft/train.jsonl \\
        --device cuda:0

Launch multi-GPU (DDP via torchrun, 7 GPU):
    torchrun --nproc_per_node=7 train/sft.py \\
        --base_checkpoint checkpoints/3b_final/checkpoint-0319772 \\
        --sft_data data/sft_combined/train_filtered.jsonl

KEY DIFFERENCES from pretrain.py:
  - Loads weights from a pretrained checkpoint via LLM.from_pretrained()
  - Uses SFTDataset (JSONL instruction data) instead of PackedDataset
  - Lower default learning rate (2e-5 vs 2e-4)
  - Fewer default steps (3000 vs 100000)
  - Copies tokenizer.json to checkpoint_dir for easy deployment
"""

from __future__ import annotations

import argparse
import os
import random
import signal
import shutil
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler

# B200 Tensor Core 최대 활용: TF32 matmul + cuDNN
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")  # TF32 precision for fp32 matmul

# Allow imports from the project root regardless of working directory.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

from model import LLM
from train.trainer import TrainConfig, Trainer
from train.utils import (
    cleanup_ddp,
    get_cosine_schedule_with_warmup,
    is_main_process,
    load_checkpoint,
    setup_ddp,
)

# ---------------------------------------------------------------------------
# Optional TransformerEngine import (FP8 support)
# ---------------------------------------------------------------------------
try:
    import transformer_engine.pytorch as te  # type: ignore[import]
    HAS_TE = True
except ImportError:
    te = None  # type: ignore[assignment]
    HAS_TE = False


# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Supervised Fine-Tuning (SFT) of a pretrained decoder-only LLM.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # --- Required paths -----------------------------------------------------
    parser.add_argument(
        "--base_checkpoint",
        type=Path,
        required=True,
        help=(
            "Path to the pretrained checkpoint directory. "
            "Must contain model.pt and config.yaml (produced by save_checkpoint)."
        ),
    )
    parser.add_argument(
        "--sft_data",
        type=Path,
        required=True,
        help="Path to the JSONL SFT training data file.",
    )

    # --- Optional paths -----------------------------------------------------
    parser.add_argument(
        "--val_data",
        type=Path,
        default=None,
        help="Optional path to JSONL SFT validation data file.",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=Path,
        default=Path("checkpoints/korean_1b_sft"),
        help="Root directory for saving SFT checkpoints.",
    )
    parser.add_argument(
        "--resume",
        type=Path,
        default=None,
        help="Path to an SFT checkpoint directory to resume fine-tuning from.",
    )
    parser.add_argument(
        "--tokenizer",
        type=Path,
        default=None,
        help=(
            "Override path to tokenizer.json. "
            "Defaults to <base_checkpoint>/tokenizer.json, "
            "then falls back to tokenizer/korean_sp/tokenizer.json."
        ),
    )
    parser.add_argument(
        "--log_file",
        type=Path,
        default=None,
        help=(
            "Path to a text file for structured training logs (rank-0 only). "
            "If omitted, logs go only to stdout."
        ),
    )

    # --- Training hyper-parameters ------------------------------------------
    parser.add_argument(
        "--max_steps",
        type=int,
        default=3000,
        help="Total number of optimiser steps.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Per-GPU micro-batch size.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=2e-5,
        help=(
            "Peak learning rate. "
            "SFT uses a much lower lr than pretraining (2e-5 vs 2e-4) "
            "to preserve pretrained representations."
        ),
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="AdamW weight decay. Lower than pretrain (0.01 vs 0.1).",
    )
    parser.add_argument(
        "--warmup_steps",
        type=int,
        default=100,
        help="Number of linear LR warmup steps.",
    )
    parser.add_argument(
        "--grad_accum",
        type=int,
        default=2,
        help="Gradient accumulation steps.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Base random seed (rank offset is added automatically in DDP).",
    )
    parser.add_argument(
        "--use_fp8",
        action="store_true",
        default=False,
        help=(
            "Enable TransformerEngine FP8 training "
            "(requires B200/H100, uses MXFP8BlockScaling)."
        ),
    )

    # --- Single-GPU device override (ignored when using torchrun) -----------
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help=(
            "Explicit device string (e.g. 'cuda:0'). "
            "Ignored when running under torchrun (DDP auto-assigns devices)."
        ),
    )

    parser.add_argument(
        "--config", type=Path, default=None,
        help="YAML config file. Values under 'train:' section are used as CLI defaults.",
    )
    parser.add_argument("--save_interval", type=int, default=500, help="Checkpoint save interval (steps).")
    parser.add_argument("--eval_interval", type=int, default=250, help="Validation eval interval (steps).")
    parser.add_argument("--neftune_alpha", type=float, default=5.0, help="NEFTune noise magnitude (0 to disable).")
    parser.add_argument("--no_fp8", action="store_true", default=False, help="Force disable FP8 even if pretrained config has use_fp8=True.")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of DataLoader worker processes.")
    parser.add_argument("--max_val_batches", type=int, default=0, help="Max validation batches (0=unlimited).")

    # First pass: just get --config
    args, remaining = parser.parse_known_args()

    # Load YAML config and apply values as defaults
    if args.config is not None:
        if not args.config.exists():
            raise FileNotFoundError(f"Config file not found: {args.config}")
        import yaml
        with open(args.config, "r") as f:
            yaml_cfg = yaml.safe_load(f)
        train_section = yaml_cfg.get("train", {})
        yaml_to_arg = {
            "max_steps": "max_steps",
            "batch_size": "batch_size",
            "lr": "lr",
            "weight_decay": "weight_decay",
            "warmup_steps": "warmup_steps",
            "grad_accum_steps": "grad_accum",
            "save_interval": "save_interval",
            "eval_interval": "eval_interval",
            "neftune_alpha": "neftune_alpha",
            "max_val_batches": "max_val_batches",
        }
        new_defaults = {}
        for yaml_key, arg_name in yaml_to_arg.items():
            if yaml_key in train_section:
                new_defaults[arg_name] = train_section[yaml_key]
        if new_defaults:
            parser.set_defaults(**new_defaults)

    return parser.parse_args()


# ---------------------------------------------------------------------------
# Seed helper
# ---------------------------------------------------------------------------


def set_seed(seed: int) -> None:
    """Set deterministic seeds for Python, NumPy, and PyTorch."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# ---------------------------------------------------------------------------
# Optimizer parameter groups
# (Copied from pretrain.py to avoid circular import; identical logic)
# ---------------------------------------------------------------------------


def build_optimizer_param_groups(
    model: torch.nn.Module,
    weight_decay: float,
) -> list[dict]:
    """
    Split parameters into two groups:
      - decay group   : weight tensors with ndim >= 2 (Linear, etc.)
      - no-decay group: bias, LayerNorm/RMSNorm weights, and embedding weights

    This follows standard practice (e.g. GPT-style training).
    """
    decay_params: list[torch.nn.Parameter] = []
    no_decay_params: list[torch.nn.Parameter] = []

    # Module types whose parameters should never be decayed.
    no_decay_module_types = (
        torch.nn.Embedding,
        torch.nn.LayerNorm,
    )
    # Also skip any parameter whose name ends with '.bias'.
    no_decay_name_suffixes = ("bias",)

    # Collect module-level exclusions.
    no_decay_module_params: set[int] = set()
    for module in model.modules():
        if isinstance(module, no_decay_module_types):
            for param in module.parameters(recurse=False):
                no_decay_module_params.add(id(param))

    seen: set[int] = set()
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if id(param) in seen:
            continue
        seen.add(id(param))

        if (
            id(param) in no_decay_module_params
            or any(name.endswith(sfx) for sfx in no_decay_name_suffixes)
            or param.ndim < 2
        ):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]


# ---------------------------------------------------------------------------
# Tokenizer resolution helper
# ---------------------------------------------------------------------------


def _resolve_tokenizer_path(args: argparse.Namespace) -> Path:
    """
    Determine the tokenizer path in priority order:
      1. Explicit --tokenizer argument
      2. tokenizer.json inside the base_checkpoint directory
      3. Project default: tokenizer/korean_sp/tokenizer.json
    """
    if args.tokenizer is not None:
        p = Path(args.tokenizer)
        if not p.exists():
            raise FileNotFoundError(f"Tokenizer not found at --tokenizer path: {p}")
        return p

    ckpt_tok = args.base_checkpoint / "tokenizer.json"
    if ckpt_tok.exists():
        return ckpt_tok

    default_tok = _PROJECT_ROOT / "tokenizer" / "korean_sp" / "tokenizer.json"
    if default_tok.exists():
        return default_tok

    raise FileNotFoundError(
        "Could not locate tokenizer.json. Tried:\n"
        f"  1. {ckpt_tok}\n"
        f"  2. {default_tok}\n"
        "Use --tokenizer to specify an explicit path."
    )


# ---------------------------------------------------------------------------
# Dynamic padding collate function
# ---------------------------------------------------------------------------


def dynamic_collate_fn(batch: list) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Collate function that pads each batch to its own maximum sequence length
    instead of a fixed global max_seq_len.  This reduces wasted FLOPs on
    short sequences and speeds up SFT which tends to have highly variable
    response lengths.

    Pads to the batch-local max, aligned to 64 tokens (for Flash Attention
    efficiency), with a floor of 512 tokens so micro-batches are not too short.

    Args:
        batch: List of ``(input_ids, labels)`` tuples from SFTDataset.

    Returns:
        Tuple of ``(input_ids, labels, attention_mask)`` tensors shaped
        ``[B, max_len]``.
        ``input_ids``      is right-padded with 0 (pad token).
        ``labels``         is right-padded with -1 (cross-entropy ignore_index).
        ``attention_mask`` is 1 for real tokens, 0 for padding.
    """
    # 64-token alignment + minimum 512 floor
    raw_max = max(item[0].size(0) for item in batch)
    max_len = max(512, ((raw_max + 63) // 64) * 64)

    input_ids_list, labels_list, mask_list = [], [], []
    for ids, labs in batch:
        pad_len = max_len - ids.size(0)
        input_ids_list.append(F.pad(ids, (0, pad_len), value=0))
        labels_list.append(F.pad(labs, (0, pad_len), value=-1))
        mask_list.append(
            F.pad(torch.ones(ids.size(0), dtype=torch.long), (0, pad_len), value=0)
        )

    return (
        torch.stack(input_ids_list),
        torch.stack(labels_list),
        torch.stack(mask_list),
    )


# ---------------------------------------------------------------------------
# NEFTune helper
# ---------------------------------------------------------------------------


def add_neftune_hook(model: torch.nn.Module, noise_alpha: float = 10.0):
    """
    Register a forward hook on the model's input embedding layer that adds
    uniform noise scaled by noise_alpha during training (NEFTune).

    Reference: "NEFTune: Noisy Embeddings Improve Instruction Finetuning"
    (Jain et al., 2023). https://arxiv.org/abs/2310.05914

    Args:
        model:       Raw (non-DDP) model instance.
        noise_alpha: Noise magnitude parameter (paper default: 10).

    Returns:
        The hook handle (call ``handle.remove()`` to deactivate), or None if
        the embedding layer could not be located.
    """
    # Unwrap DDP if needed
    raw = model.module if hasattr(model, "module") else model

    # 1) Try the standard HuggingFace accessor first.
    embedding: torch.nn.Embedding | None = None
    if hasattr(raw, "get_input_embeddings"):
        try:
            emb = raw.get_input_embeddings()
            if isinstance(emb, torch.nn.Embedding):
                embedding = emb
        except Exception:
            pass

    # 2) Fallback: walk common attribute paths found in open-source LLMs.
    if embedding is None:
        for attr_path in [
            "embedding",
            "embed_tokens",
            "token_embedding",
            "wte",
            "word_embeddings",
            "tok_embeddings",
            "transformer.wte",
            "model.embed_tokens",
            "model.embedding",
        ]:
            obj = raw
            for part in attr_path.split("."):
                obj = getattr(obj, part, None)
                if obj is None:
                    break
            if obj is not None and isinstance(obj, torch.nn.Embedding):
                embedding = obj
                break

    if embedding is None:
        print("[WARN] NEFTune: embedding layer을 찾지 못함, NEFTune 비활성화")
        return None

    print(
        f"[INFO] NEFTune: {type(embedding).__name__} hook 등록 "
        f"(shape={tuple(embedding.weight.shape)}, alpha={noise_alpha})"
    )

    def _hook(
        module: torch.nn.Module,
        inp: tuple,
        out: torch.Tensor,
    ) -> torch.Tensor:
        if module.training:
            # out shape: [B, seq_len, d_model]
            mag = noise_alpha / ((out.size(1) * out.size(2)) ** 0.5)
            out = out + torch.empty_like(out).uniform_(-mag, mag)
        return out

    return embedding.register_forward_hook(_hook)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main() -> None:
    args = parse_args()

    # ---- Distributed setup -------------------------------------------------
    is_ddp = "RANK" in os.environ
    rank = 0
    local_rank = 0
    world_size = 1

    if is_ddp:
        rank, local_rank, world_size, device = setup_ddp()
    else:
        # Single-GPU: honour --device flag, else pick cuda:0 or cpu.
        if args.device is not None:
            device = torch.device(args.device)
        elif torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")

    # Per-rank seed so data shuffling differs across replicas.
    set_seed(args.seed + rank)

    # ---- NUMA affinity for optimal GPU↔CPU memory locality ---------------
    # B200 topology: GPU 0-3 → NUMA node 0 (cores 0-35)
    #                GPU 4-6 → NUMA node 1 (cores 36-71)  [7 GPU 환경]
    try:
        if local_rank < 4:
            os.sched_setaffinity(0, set(range(0, 36)))   # NUMA node 0
        else:
            os.sched_setaffinity(0, set(range(36, 72)))   # NUMA node 1
        if is_main_process():
            print(f"NUMA affinity: rank {rank} (GPU {local_rank}) → "
                  f"{'NUMA0 cores 0-35' if local_rank < 4 else 'NUMA1 cores 36-71'}")
    except (AttributeError, OSError) as e:
        if is_main_process():
            print(f"[WARN] NUMA affinity failed: {e}")

    # ---- Validate base checkpoint ------------------------------------------
    if not args.base_checkpoint.exists():
        raise FileNotFoundError(
            f"Base checkpoint directory not found: {args.base_checkpoint}"
        )
    for required_file in ("model.pt", "config.yaml"):
        if not (args.base_checkpoint / required_file).exists():
            raise FileNotFoundError(
                f"Expected {required_file} inside base checkpoint: {args.base_checkpoint}"
            )

    # ---- Load pretrained model ---------------------------------------------
    # LLM.from_pretrained() reads config.yaml + model.pt and returns the model on CPU.
    # We move it to the target device immediately after loading.
    #
    # NOTE: fp8_model_init() is intentionally NOT used here (same as pretrain.py).
    # MXFP8Tensor weights are incompatible with DDP's _broadcast_coalesced.
    # Weights stay in float32; TransformerEngine quantizes on-the-fly inside fp8_autocast.
    model = LLM.from_pretrained(args.base_checkpoint)

    # FP8 override: --no_fp8 forces BF16 even if pretrained config had use_fp8=True.
    # --use_fp8 enables FP8 if pretrained config had it disabled.
    if args.no_fp8:
        model.config.use_fp8 = False
    elif args.use_fp8:
        model.config.use_fp8 = True

    # Move model to target device in bfloat16 (more memory-efficient than fp32
    # for fine-tuning, and required when BF16 autocast + TE are active).
    model = model.to(device=device, dtype=torch.bfloat16)

    # ---- Gradient checkpointing ----------------------------------------
    # Trades activation memory for recomputation during backward pass.
    # Especially useful for large models / long sequences in SFT.
    if hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()
        if rank == 0:
            print("[INFO] Gradient checkpointing enabled")

    # FP8 alignment check: (batch_size × seq_len) must be divisible by 8.
    if model.config.use_fp8:
        seq_len = model.config.max_seq_len
        if (args.batch_size * seq_len) % 8 != 0:
            raise ValueError(
                f"FP8: batch_size × max_seq_len = {args.batch_size} × {seq_len} "
                f"= {args.batch_size * seq_len} must be divisible by 8."
            )

    if is_main_process():
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Pretrained model loaded: {total_params:,} parameters")
        print(f"LMConfig: {model.config}")

    # ---- Wrap in DDP -------------------------------------------------------
    if is_ddp:
        from torch.nn.parallel import DistributedDataParallel as DDP

        model = DDP(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            gradient_as_bucket_view=True,
            bucket_cap_mb=800,
            find_unused_parameters=False,
        )

    # ---- Tokenizer ---------------------------------------------------------
    tokenizer_path = _resolve_tokenizer_path(args)
    if is_main_process():
        print(f"Loading tokenizer from: {tokenizer_path}")

    # Use the fast tokenizers library (same as the rest of the project).
    from tokenizers import Tokenizer  # type: ignore[import]
    tokenizer = Tokenizer.from_file(str(tokenizer_path))

    # ---- Dataset & DataLoader ----------------------------------------------
    # Import SFTDataset (created separately alongside this file).
    # SFTDataset returns (input_ids, targets) where prompt token positions in
    # targets are filled with -1.  The Trainer._compute_loss already uses
    # ignore_index=-1, so only response tokens contribute to the gradient.
    from data.sft_dataset import SFTDataset  # type: ignore[import]

    train_dataset = SFTDataset(
        data_path=args.sft_data,
        tokenizer=tokenizer,
        max_seq_len=model.config.max_seq_len
        if not isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model.module.config.max_seq_len,
    )

    if is_ddp:
        train_sampler: DistributedSampler | RandomSampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=args.seed,
        )
        shuffle = False
    else:
        train_sampler = RandomSampler(train_dataset)
        shuffle = False  # Sampler is provided; DataLoader must not also shuffle.

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        # SFT datasets are typically small enough that 2–4 workers suffice.
        # We use 4 to balance I/O with CPU parsing overhead from JSONL.
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
        prefetch_factor=2,
        persistent_workers=True,
        collate_fn=dynamic_collate_fn,
    )

    # Optional validation loader.
    # NOTE: The current Trainer implementation does not yet accept a val_loader
    # argument; the eval_interval config field is reserved for future use.
    # We construct the loader here so that once Trainer gains eval support,
    # wiring it in requires only passing val_loader=val_loader below.
    val_loader: DataLoader | None = None
    if args.val_data is not None:
        if not args.val_data.exists():
            raise FileNotFoundError(f"Validation data not found: {args.val_data}")
        val_dataset = SFTDataset(
            data_path=args.val_data,
            tokenizer=tokenizer,
            max_seq_len=train_dataset.max_seq_len,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
            drop_last=False,
            collate_fn=dynamic_collate_fn,
        )
        if is_main_process():
            print(f"Validation dataset: {len(val_dataset):,} samples")

    # ---- Optimizer ---------------------------------------------------------
    # Use the same two-group split (weight_decay / no weight_decay) as pretrain.
    # Unwrap DDP to get the raw model's parameters.
    raw_model = getattr(model, "module", model)
    param_groups = build_optimizer_param_groups(raw_model, args.weight_decay)
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=args.lr,
        betas=(0.9, 0.95),
        eps=1e-8,
        fused=torch.cuda.is_available(),  # Use fused kernel when on CUDA.
    )

    # ---- TrainConfig -------------------------------------------------------
    # Set use_fp8 from the (possibly overridden) model config so Trainer builds
    # the correct FP8 recipe and wraps forward passes in fp8_autocast.
    use_fp8 = raw_model.config.use_fp8

    train_config = TrainConfig(
        max_steps=args.max_steps,
        checkpoint_dir=str(args.checkpoint_dir),
        grad_accum_steps=args.grad_accum,
        use_fp8=use_fp8,
        log_file=str(args.log_file) if args.log_file is not None else None,
        save_interval=args.save_interval,
        log_interval=10,
        eval_interval=args.eval_interval,
        max_val_batches=args.max_val_batches,
    )

    # ---- LR Scheduler ------------------------------------------------------
    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        warmup_steps=args.warmup_steps,
        total_steps=train_config.max_steps,
    )

    # ---- Resume from SFT checkpoint ----------------------------------------
    # When --resume is given we restore the SFT optimizer/scheduler state as
    # well so learning rate, momentum buffers, etc. are correctly restored.
    # NOTE: This resumes SFT training, NOT the pretrain checkpoint.
    #       The pretrain weights were already loaded above via from_pretrained().
    start_step = 0
    if args.resume is not None:
        if not args.resume.exists():
            raise FileNotFoundError(f"Resume checkpoint not found: {args.resume}")
        start_step, resume_loss = load_checkpoint(
            path=args.resume,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        if is_main_process():
            print(f"Resumed SFT from {args.resume} at step {start_step} (loss={resume_loss:.4f})")

    if args.resume is not None and isinstance(train_sampler, DistributedSampler):
        steps_per_epoch = len(train_loader)
        approx_epoch = start_step // steps_per_epoch if steps_per_epoch > 0 else 0
        train_sampler.set_epoch(approx_epoch)
        if is_main_process():
            print(f"[INFO] Resume: sampler epoch set to {approx_epoch}")

    # ---- Checkpoint directory ----------------------------------------------
    args.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # ---- Copy tokenizer to checkpoint dir for easy deployment later --------
    # This mirrors the tokenizer into the SFT checkpoint root so that the
    # final checkpoint directory is self-contained for convert_to_hf.py, etc.
    if is_main_process():
        dest_tok = args.checkpoint_dir / "tokenizer.json"
        if not dest_tok.exists():
            shutil.copy2(str(tokenizer_path), str(dest_tok))
            print(f"Tokenizer copied to {dest_tok}")

    # ---- Trainer -----------------------------------------------------------
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        config=train_config,
        device=device,
        rank=rank,
        sampler=train_sampler if is_ddp else None,
        val_loader=val_loader,
    )

    # ---- Signal handlers for graceful shutdown ----------------------------
    import signal as _signal_mod

    _trainer_ref = trainer

    def _graceful_shutdown_handler(signum, frame):
        sig_name = _signal_mod.Signals(signum).name
        if is_main_process():
            import datetime as _dt
            ts = _dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            msg = (
                f"[{ts}] [SIGNAL] Received {sig_name} (signum={signum}). "
                f"Initiating graceful shutdown..."
            )
            print(f"\n{msg}")
            if args.log_file is not None:
                try:
                    with open(args.log_file, "a", encoding="utf-8") as f:
                        f.write(msg + "\n")
                except Exception:
                    pass
        _trainer_ref.request_shutdown(sig_name)

    for _sig in (_signal_mod.SIGHUP, _signal_mod.SIGTERM):
        _signal_mod.signal(_sig, _graceful_shutdown_handler)

    # ---- SFT banner --------------------------------------------------------
    if is_main_process():
        import datetime

        inner_config = raw_model.config
        eff_batch_seqs = args.batch_size * args.grad_accum * world_size
        eff_tokens_per_step = eff_batch_seqs * inner_config.max_seq_len
        train_samples = len(train_dataset)
        precision_label = "FP8 (MXFP8BlockScaling)" if use_fp8 else "BF16"
        nccl_debug = os.environ.get("NCCL_DEBUG", "not set")
        omp_threads = os.environ.get("OMP_NUM_THREADS", "not set")

        print(
            f"\n{'='*70}\n"
            f"  LLM Supervised Fine-Tuning — "
            f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
            f"{'='*70}\n"
            f"  base ckpt : {args.base_checkpoint}\n"
            f"  sft data  : {args.sft_data} ({train_samples:,} samples)\n"
            f"  model     : {inner_config.num_params:,} params  |  "
            f"d_model={inner_config.d_model}  n_layers={inner_config.n_layers}\n"
            f"  precision : {precision_label}\n"
            f"  GPUs      : {world_size}  |  batch/GPU={args.batch_size}  "
            f"grad_accum={args.grad_accum}\n"
            f"  eff_batch : {eff_batch_seqs} seqs  "
            f"= {eff_tokens_per_step:,} tok/step\n"
            f"  max_steps : {train_config.max_steps:,}\n"
            f"  lr        : {args.lr:.2e}  "
            f"warmup={args.warmup_steps}  weight_decay={args.weight_decay}\n"
            f"  ckpt_dir  : {args.checkpoint_dir}\n"
            f"  env       : OMP_NUM_THREADS={omp_threads}  NCCL_DEBUG={nccl_debug}\n"
            f"{'='*70}\n"
        )

    # ---- NEFTune -----------------------------------------------------------
    # Add uniform noise to embeddings during training to improve instruction
    # following (Jain et al., 2023).  Hook is registered on the raw (non-DDP)
    # model so it survives DDP's internal module wrapping.
    neftune_alpha = getattr(args, 'neftune_alpha', 5.0)
    neftune_handle = add_neftune_hook(raw_model, noise_alpha=neftune_alpha)
    if rank == 0:
        if neftune_handle is not None:
            print(f"[INFO] NEFTune enabled (noise_alpha={neftune_alpha})")
        else:
            print("[WARN] NEFTune disabled - embedding layer not found")

    # ---- Train -------------------------------------------------------------
    try:
        trainer.train(start_step=start_step)
    except KeyboardInterrupt:
        if is_main_process():
            print("\n[INFO] SFT interrupted by user (KeyboardInterrupt).")
    except Exception as e:
        import traceback
        if is_main_process():
            tb = traceback.format_exc()
            print(f"\n[ERROR] SFT failed at rank {rank}:\n{tb}")
            if args.log_file is not None:
                with open(args.log_file, "a", encoding="utf-8") as f:
                    import datetime
                    f.write(
                        f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
                        f"[FATAL] {tb}\n"
                    )
        raise
    finally:
        # Remove NEFTune hook so the model is clean for inference/saving.
        if neftune_handle is not None:
            neftune_handle.remove()
        if is_ddp:
            cleanup_ddp()


if __name__ == "__main__":
    main()