import os
from typing import Optional

from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.utilities import rank_zero_only


class WandbOfflineCheckpointCallback(Callback):
    """
    Custom callback to log model checkpoints to WandB even when offline=True.
    Lightning's WandbLogger forbids log_model=True with offline=True.
    This callback manually calls experiment.save() on the checkpoint directory.
    """

    def __init__(self, save_dir: Optional[str] = None):
        super().__init__()
        self.save_dir = save_dir
        self.best_model_path = None

    @rank_zero_only
    def on_train_epoch_end(self, trainer, pl_module):
        # Only act when a WandbLogger and a ModelCheckpoint are both configured
        if isinstance(trainer.logger, WandbLogger) and trainer.checkpoint_callback:
            self._save_checkpoints(trainer.logger, trainer)

    @rank_zero_only
    def on_fit_end(self, trainer, pl_module):
        # Final pass so the last checkpoint written is also mirrored
        if isinstance(trainer.logger, WandbLogger) and trainer.checkpoint_callback:
            self._save_checkpoints(trainer.logger, trainer)

    def _save_checkpoints(self, logger, trainer):
        dirpath = trainer.checkpoint_callback.dirpath
        if not dirpath or not os.path.exists(dirpath):
            return

        # wandb's save() preserves the directory structure relative to base_path;
        # we mirror only the latest .safetensors export, not the whole directory

        # Identify the last checkpoint path
        last_ckpt = trainer.checkpoint_callback.last_model_path

        # Fallback: if last_model_path is not set, but save_last is True,
        # check for 'last.ckpt' explicitly.
        if (not last_ckpt) and trainer.checkpoint_callback.save_last:
            potential_last = os.path.join(dirpath, "last.ckpt")
            if os.path.exists(potential_last):
                last_ckpt = potential_last

        if last_ckpt and os.path.exists(last_ckpt):
            # Construct expected safetensors path
            base_name = os.path.splitext(os.path.basename(last_ckpt))[0]
            sf_path = os.path.join(dirpath, f"{base_name}.safetensors")

            if os.path.exists(sf_path):
                # Policy="now" ensures it's copied to wandb directory immediately
                logger.experiment.save(
                    sf_path, base_path=os.path.dirname(dirpath), policy="now"
                )

        # Clean up broken symlinks in the wandb run directory. When
        # ModelCheckpoint deletes an old checkpoint (e.g. due to save_top_k),
        # the symlink wandb created still points at the missing file, which
        # makes 'wandb sync' fail.
        wandb_dir = logger.experiment.dir
        # save() used base_path=os.path.dirname(dirpath), so the mirrored files
        # live under the checkpoint directory's basename inside the run dir
        wandb_ckpt_dir = os.path.join(wandb_dir, os.path.basename(dirpath))

        if os.path.exists(wandb_ckpt_dir):
            for filename in os.listdir(wandb_ckpt_dir):
                filepath = os.path.join(wandb_ckpt_dir, filename)
                # exists() follows symlinks, so a link whose target was deleted
                # evaluates as broken here
                if os.path.islink(filepath) and not os.path.exists(filepath):
                    os.remove(filepath)
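

if __name__ == "__main__":
    # --- Usage sketch (illustrative, not part of the original module) ---
    # A minimal smoke test wiring the callback into a Trainer next to a
    # ModelCheckpoint with save_last=True, so last_model_path gets populated.
    # The _ToyModel and the synthetic dataset below are assumptions for the
    # demo only. Note: the .safetensors export this callback mirrors is
    # assumed to be produced elsewhere in the pipeline; without it, the save
    # step is a no-op and only the symlink cleanup runs.
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import ModelCheckpoint

    class _ToyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

    data = DataLoader(
        TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=8
    )
    trainer = pl.Trainer(
        max_epochs=1,
        logger=WandbLogger(project="offline-demo", offline=True),
        callbacks=[
            ModelCheckpoint(dirpath="checkpoints", save_last=True),
            WandbOfflineCheckpointCallback(),
        ],
    )
    trainer.fit(_ToyModel(), data)
    # After training, upload the offline run with: wandb sync <run directory>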