"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""
import torch
import torch.nn.functional as F
from torch import nn
from segment_anything import sam_model_registry

# allow TF32 matmuls on Ampere and newer GPUs for faster training/inference
torch.backends.cuda.matmul.allow_tf32 = True

class Transformer(nn.Module):
def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4,
checkpoint=None, dtype=torch.float32):
super(Transformer, self).__init__()
"""
print(self.encoder.patch_embed)
PatchEmbed(
(proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
)
print(self.encoder.neck)
Sequential(
(0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): LayerNorm2d()
(2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(3): LayerNorm2d()
)
"""
        # instantiate the SAM ViT and keep only its image encoder;
        # checkpoint=None builds the architecture without loading SAM weights
        self.encoder = sam_model_registry[backbone](checkpoint).image_encoder
        w = self.encoder.patch_embed.proj.weight.detach()
        nchan = w.shape[0]
        # change the token size from 16 x 16 to ps x ps by subsampling
        # the pretrained 16 x 16 patch-embedding kernels (ps must divide 16)
        self.ps = ps
        self.encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
        self.encoder.patch_embed.proj.weight.data = w[:, :, ::16//ps, ::16//ps]
        # subsample the position embeddings from SAM's 64 x 64 token grid
        # (1024 // 16) to the new bsize // ps token grid
        ds = (1024 // 16) // (bsize // ps)
        self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:, ::ds, ::ds],
                                              requires_grad=True)
# readout weights for nout output channels
# if nout is changed, weights will not load correctly from pretrained Cellpose-SAM
self.nout = nout
self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1)
# W2 reshapes token space to pixel space, not trainable
self.W2 = nn.Parameter(torch.eye(self.nout * ps**2).reshape(self.nout*ps**2, self.nout, ps, ps),
requires_grad=False)
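        # (applying conv_transpose2d with this identity kernel and stride ps
        # rearranges each length nout * ps**2 token vector into an
        # nout x ps x ps pixel tile, i.e. a fixed pixel-shuffle readout)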
# fraction of layers to drop at random during training
self.rdrop = rdrop
        # average ROI diameter in the training images used for fine-tuning
self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False)
# average diameter of ROIs during main training
self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False)
        # use global attention in every block (window_size=0 disables windowing)
for blk in self.encoder.blocks:
blk.window_size = 0
self.dtype = dtype
def forward(self, x, feat=None):
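        """Run the encoder and readout on a batch of images.

        Args:
            x (torch.Tensor): images of shape (batch, 3, bsize, bsize).
            feat (torch.Tensor or None): optional image batch whose patch
                embedding multiplicatively modulates the patch embedding of x.

        Returns:
            tuple: output maps of shape (batch, nout, bsize, bsize), and a
            random (batch, 256) placeholder kept for backwards compatibility.
        """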
# same progression as SAM until readout
x = self.encoder.patch_embed(x)
        if feat is not None:
            # modulate the tokens of x by the tokens of feat
            feat = self.encoder.patch_embed(feat)
            x = x + x * feat * 0.5
if self.encoder.pos_embed is not None:
x = x + self.encoder.pos_embed
        if self.training and self.rdrop > 0:
            # stochastic depth: per sample, drop block i with probability
            # increasing linearly with depth from 0 to self.rdrop
            nlay = len(self.encoder.blocks)
            rdrop = (torch.rand((len(x), nlay), device=x.device) <
                     torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
            for i, blk in enumerate(self.encoder.blocks):
                # mask == 1 skips the block (identity); mask == 0 applies it
                mask = rdrop[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                x = x * mask + blk(x) * (1 - mask)
else:
for blk in self.encoder.blocks:
x = blk(x)
        # tokens are (batch, H, W, C); the neck convs expect (batch, C, H, W)
        x = self.encoder.neck(x.permute(0, 3, 1, 2))
        # readout is changed here: 1x1 conv to nout * ps**2 channels, then the
        # fixed identity kernel W2 rearranges tokens back to pixel space
        x1 = self.out(x)
        x1 = F.conv_transpose2d(x1, self.W2, stride=self.ps, padding=0)
        # keep a second output of feature size 256 for backwards compatibility
        return x1, torch.randn((x.shape[0], 256), device=x.device)
    def load_model(self, PATH, device, strict=False):
        """Load weights from a saved state dict.

        Args:
            PATH (str): path to the checkpoint file.
            device (torch.device): device to map the weights onto.
            strict (bool): passed to load_state_dict; False tolerates
                missing or unexpected keys.
        """
        state_dict = torch.load(PATH, map_location=device, weights_only=True)
        if next(iter(state_dict)).startswith("module."):
            # strip the 'module.' prefix added by DataParallel/DistributedDataParallel
            state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
        self.load_state_dict(state_dict, strict=strict)
        if self.dtype != torch.float32:
            self.to(self.dtype)  # Module.to casts parameters in place
@property
def device(self):
"""
Get the device of the model.
Returns:
torch.device: The device of the model.
"""
return next(self.parameters()).device
def save_model(self, filename):
"""
Save the model to a file.
Args:
filename (str): The path to the file where the model will be saved.
"""
torch.save(self.state_dict(), filename)
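
# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the library API): build
# the network without pretrained SAM weights and run a forward pass on random
# images. Assumes segment_anything is installed; shapes follow the defaults
# above (bsize=256, ps=8, nout=3). Note that vit_l is a large backbone, so
# this constructs a big randomly initialized model.
if __name__ == "__main__":
    net = Transformer(backbone="vit_l", ps=8, nout=3, bsize=256, checkpoint=None)
    net.eval()
    x = torch.randn(1, 3, 256, 256)  # (batch, channels, bsize, bsize)
    with torch.no_grad():
        flows, style = net(x)
    print(flows.shape)  # torch.Size([1, 3, 256, 256])
    print(style.shape)  # torch.Size([1, 256])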