File size: 5,150 Bytes
aff3c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""

import torch
from segment_anything import sam_model_registry
# NOTE: import-time side effect -- sets the process-wide PyTorch flag that
# permits TF32 math for CUDA matrix multiplications (see the PyTorch CUDA notes).
torch.backends.cuda.matmul.allow_tf32 = True
from torch import nn 
import torch.nn.functional as F

class Transformer(nn.Module):
    """SAM image encoder with a smaller patch size and a per-pixel readout.

    The image encoder is taken from ``sam_model_registry[backbone]`` and then
    modified: its patch-embedding convolution is replaced by a ``ps`` x ``ps``
    convolution initialised from the subsampled original kernel, its position
    embedding is subsampled, and ``window_size`` is set to 0 on every block.
    A 1x1 convolution (``self.out``) followed by a fixed transposed
    convolution (``self.W2``) maps the 256-channel neck output to ``nout``
    channels.

    Args:
        backbone (str): Key into ``sam_model_registry``.
        ps (int): Token (patch) size. The pretrained 16x16 kernel is
            subsampled with step ``16 // ps``.
        nout (int): Number of readout channels. Per the original note,
            changing it prevents pretrained Cellpose-SAM weights from loading
            correctly.
        bsize (int): Tile size used to derive the position-embedding stride,
            ``(1024 // 16) // (bsize // ps)``.
        rdrop (float): Upper end of the per-block skip probability applied
            during training (see ``forward``).
        checkpoint (str, optional): Passed to the SAM model builder; the
            default ``None`` is described in the original comment as "not
            loading SAM".
        dtype (torch.dtype): The module is cast to this dtype at the end of
            ``load_model`` when it is not ``torch.float32``.
    """

    def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4,
                 checkpoint=None, dtype=torch.float32):
        super().__init__()
        # Reference printouts of the encoder sub-modules (kept from the
        # original author's notes):
        #   print(self.encoder.patch_embed)
        #       PatchEmbed(
        #       (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
        #       )
        #   print(self.encoder.neck)
        #       Sequential(
        #       (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        #       (1): LayerNorm2d()
        #       (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        #       (3): LayerNorm2d()
        #       )

        # instantiate the vit model, default to not loading SAM
        # checkpoint = sam_vit_l_0b3195.pth is standard pretrained SAM
        self.encoder = sam_model_registry[backbone](checkpoint).image_encoder
        w = self.encoder.patch_embed.proj.weight.detach()
        nchan = w.shape[0]

        # change token size to ps x ps: new conv, weights taken from every
        # (16 // ps)-th row/column of the original kernel
        self.ps = ps
        self.encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
        self.encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]

        # adjust position embeddings for new bsize and new token size
        ds = (1024 // 16) // (bsize // ps)
        self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:, ::ds, ::ds],
                                              requires_grad=True)

        # readout weights for nout output channels
        # if nout is changed, weights will not load correctly from pretrained Cellpose-SAM
        self.nout = nout
        self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1)

        # W2 reshapes token space to pixel space, not trainable: an identity
        # matrix viewed as a (nout*ps^2, nout, ps, ps) transposed-conv kernel
        self.W2 = nn.Parameter(
            torch.eye(self.nout * ps**2).reshape(self.nout * ps**2, self.nout, ps, ps),
            requires_grad=False)

        # fraction of layers to drop at random during training
        self.rdrop = rdrop

        # average diameter of ROIs from training images from fine-tuning
        self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False)
        # average diameter of ROIs during main training
        self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False)

        # set attention to global in every layer
        for blk in self.encoder.blocks:
            blk.window_size = 0

        self.dtype = dtype

    def forward(self, x, feat=None):
        """Run the encoder and the pixel readout.

        Args:
            x (torch.Tensor): Input batch fed to the patch embedding (whose
                convolution is built with 3 input channels).
            feat (torch.Tensor, optional): If given, it goes through the same
                patch embedding and modulates the tokens as
                ``x + x * feat * 0.5``.

        Returns:
            tuple: ``(x1, placeholder)``. ``x1`` is the readout after the
            stride-``ps`` transposed convolution (presumably
            ``(batch, nout, H, W)`` -- confirm against callers).
            ``placeholder`` is a freshly sampled ``torch.randn`` tensor of
            shape ``(batch, 256)``; it carries no information and is kept only
            for backwards compatibility.
        """
        # same progression as SAM until readout
        x = self.encoder.patch_embed(x)
        if feat is not None:
            feat = self.encoder.patch_embed(feat)
            x = x + x * feat * 0.5

        if self.encoder.pos_embed is not None:
            x = x + self.encoder.pos_embed

        if self.training and self.rdrop > 0:
            # Per-sample random block skipping: block i is skipped with
            # probability linspace(0, rdrop, nlay)[i]; mask == 1 means "skip".
            nlay = len(self.encoder.blocks)
            rdrop = (torch.rand((len(x), nlay), device=x.device) <
                     torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
            for i, blk in enumerate(self.encoder.blocks):
                mask = rdrop[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                x = x * mask + blk(x) * (1 - mask)
        else:
            for blk in self.encoder.blocks:
                x = blk(x)

        # move the last axis to position 1 before the convolutional neck
        x = self.encoder.neck(x.permute(0, 3, 1, 2))

        # readout is changed here: 1x1 conv to nout*ps^2 channels, then the
        # fixed W2 kernel spreads each token's channels over a ps x ps patch
        x1 = self.out(x)
        x1 = F.conv_transpose2d(x1, self.W2, stride=self.ps, padding=0)

        # maintain the second output of feature size 256 for backwards compatibility
        return x1, torch.randn((x.shape[0], 256), device=x.device)

    def load_model(self, PATH, device, strict=False):
        """Load weights from a state-dict file into this module.

        A leading ``"module."`` (as written by DataParallel /
        DistributedDataParallel wrappers) is removed from every key that has
        it. After loading, the module is cast to ``self.dtype`` if that is not
        ``torch.float32``.

        Args:
            PATH (str): Path of the file to read with ``torch.load``.
            device: Passed to ``torch.load`` as ``map_location``.
            strict (bool): Forwarded to ``load_state_dict``.
        """
        state_dict = torch.load(PATH, map_location=device, weights_only=True)

        prefix = "module."
        # Check every key rather than only the first one, so an empty
        # checkpoint does not raise and un-prefixed keys are never truncated.
        # When nothing is prefixed the loaded mapping is passed through as is.
        if any(key.startswith(prefix) for key in state_dict):
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
        self.load_state_dict(state_dict, strict=strict)

        if self.dtype != torch.float32:
            # Module.to converts parameters in place; no rebinding needed
            self.to(self.dtype)

    @property
    def device(self):
        """
        Get the device of the model.

        Returns:
            torch.device: The device of the first parameter of the model.
        """
        return next(self.parameters()).device

    def save_model(self, filename):
        """
        Save the model to a file.

        Args:
            filename (str): The path to the file where the state dict will be saved.
        """
        torch.save(self.state_dict(), filename)