# modeling_resnet.py
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ImageClassifierOutput


class PrunedResNetConfig(PretrainedConfig):
    """Configuration for :class:`PrunedResNet50`.

    Args:
        channel_config: Mapping from layer name to channel count. The model
            reads ``"conv1"`` plus, for every stage ``s`` and block ``i``,
            ``"layer{s}.{i}.in"``, ``"layer{s}.{i}.conv1"``,
            ``"layer{s}.{i}.conv2"``, ``"layer{s}.{i}.conv3"`` and the optional
            ``"layer{s}.0.downsample.0"``. Defaults to ``None``;
            :class:`PrunedResNet50` raises ``ValueError`` if it is still unset.
        num_classes: Number of logits produced by the classification head.
        **kwargs: Forwarded unchanged to :class:`PretrainedConfig`.
    """

    # NOTE(review): "resnet" is also the model_type of transformers' built-in
    # ResNetConfig, so auto-class registration/lookup by this name may resolve
    # to the built-in model. Kept as-is so existing saved configs still load —
    # confirm how this config is meant to be loaded.
    model_type = "resnet"

    def __init__(
        self,
        channel_config: dict[str, int] | None = None,
        num_classes=1000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.channel_config = channel_config
        self.num_classes = num_classes


class PrunedResNet50(PreTrainedModel):
    """ResNet-50-shaped image classifier with per-layer channel widths from config.

    The stage layout (3, 4, 6, 3 bottleneck blocks; strides 1, 2, 2, 2) is
    fixed; only channel counts vary, taken from ``config.channel_config``.
    Submodule paths (e.g. ``layer1.0.conv1``) mirror the keys of that dict.
    """

    config_class = PrunedResNetConfig
    _tied_weights_keys = []

    # Number of Bottleneck blocks in layer1..layer4.
    _stage_depths = (3, 4, 6, 3)

    def __init__(self, config: PrunedResNetConfig):
        super().__init__(config)  # also sets self.config
        c = config.channel_config
        if c is None:
            raise ValueError(
                "PrunedResNetConfig.channel_config must be set to build PrunedResNet50"
            )

        # Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(
            3, c["conv1"], kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(c["conv1"])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        depths = self._stage_depths
        # Attribute names layer1..layer4 must stay stable: they are the
        # state_dict key prefixes of saved checkpoints.
        self.layer1 = self._make_layer(c, stage_idx=1, layers=depths[0], stride=1)
        self.layer2 = self._make_layer(c, stage_idx=2, layers=depths[1], stride=2)
        self.layer3 = self._make_layer(c, stage_idx=3, layers=depths[2], stride=2)
        self.layer4 = self._make_layer(c, stage_idx=4, layers=depths[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # The head consumes the output width of the last block of layer4.
        last_channel = c[f"layer4.{depths[3] - 1}.conv3"]
        self.fc = nn.Linear(last_channel, config.num_classes)

        self.post_init()

    def _make_layer(self, c, stage_idx, layers, stride):
        """Build stage ``layer{stage_idx}`` as an ``nn.Sequential`` of Bottlenecks.

        Args:
            c: The channel_config mapping.
            stage_idx: Stage number (1-4); selects the ``"layer{stage_idx}..."`` keys.
            layers: Number of Bottleneck blocks in the stage.
            stride: Stride of the stage's first block.

        Only block 0 gets ``stride`` and the optional downsample projection
        (``"layer{stage_idx}.0.downsample.0"``). Later blocks use stride 1 and
        no projection, even if such a key exists for them.
        """
        blocks = []
        for i in range(layers):
            prefix = f"layer{stage_idx}.{i}"
            first = i == 0
            blocks.append(
                Bottleneck(
                    inplanes=c[f"{prefix}.in"],
                    planes=[
                        c[f"{prefix}.conv1"],
                        c[f"{prefix}.conv2"],
                        c[f"{prefix}.conv3"],
                    ],
                    stride=stride if first else 1,
                    downsample_planes=(
                        c.get(f"{prefix}.downsample.0") if first else None
                    ),
                )
            )
        return nn.Sequential(*blocks)

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Classify a batch of images.

        Args:
            pixel_values: Input image tensor fed to the 3-input-channel stem
                conv (presumably ``(batch, 3, H, W)`` float — confirm with caller).
            labels: Optional class-index targets; when given, a cross-entropy
                loss is computed.
            **kwargs: Accepted and ignored (e.g. extra Trainer arguments).

        Returns:
            ImageClassifierOutput with ``logits`` and ``loss`` (``None`` when
            ``labels`` is ``None``).

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        x = self.conv1(pixel_values)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.fc(x)

        loss = None
        if labels is not None:
            # CrossEntropyLoss applies log-softmax internally, so raw logits go in.
            # reshape (unlike view) also accepts non-contiguous label tensors.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.reshape(-1, self.config.num_classes), labels.reshape(-1)
            )

        return ImageClassifierOutput(logits=logits, loss=loss)


class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1 convs) with arbitrary widths.

    Args:
        inplanes: Input channel count.
        planes: ``(c1, c2, c3)`` output channels of conv1, conv2 and conv3.
        stride: Stride of the 3x3 conv and of the shortcut projection.
        downsample_planes: Output channels of the 1x1 conv + BN projection on
            the shortcut, or ``None`` for an identity shortcut.

    Raises:
        ValueError: If the shortcut cannot be added to the main path: an
            identity shortcut with ``inplanes != c3`` or ``stride != 1``, or a
            projection whose width differs from ``c3``.
    """

    def __init__(self, inplanes, planes, stride=1, downsample_planes=None):
        super().__init__()
        c1, c2, c3 = planes  # the three conv widths inside the bottleneck

        # forward() adds the shortcut to conv3's output, so both must have c3
        # channels and the same spatial size. Reject bad configs here instead
        # of with a shape error at the first forward pass.
        if downsample_planes is None:
            if inplanes != c3 or stride != 1:
                raise ValueError(
                    "Bottleneck without a downsample projection needs "
                    "inplanes == planes[2] and stride == 1 "
                    f"(got inplanes={inplanes}, planes[2]={c3}, stride={stride})"
                )
        elif downsample_planes != c3:
            raise ValueError(
                f"downsample_planes ({downsample_planes}) must equal "
                f"planes[2] ({c3}) so the shortcut matches the block output"
            )

        self.conv1 = nn.Conv2d(inplanes, c1, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(c1)
        self.conv2 = nn.Conv2d(
            c1, c2, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(c2)
        self.conv3 = nn.Conv2d(c2, c3, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(c3)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = None
        if downsample_planes is not None:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    downsample_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(downsample_planes),
            )

    def forward(self, x):
        """Return ``relu(bn3(conv3(...)) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out