|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
from transformers.modeling_outputs import ImageClassifierOutput |
|
|
|
|
|
|
|
|
class PrunedResNetConfig(PretrainedConfig):
    """HF config carrying per-layer channel widths for a pruned ResNet-50.

    ``channel_config`` maps module names (e.g. ``"conv1"``,
    ``"layer1.0.conv2"``) to their pruned output-channel counts so a
    ``PrunedResNet50`` can be rebuilt with matching tensor shapes;
    ``num_classes`` sizes the classifier head.
    """

    model_type = "resnet"

    def __init__(
        self,
        channel_config: dict[str, int] | None = None,
        num_classes: int = 1000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.channel_config = channel_config
        self.num_classes = num_classes
|
|
|
|
|
|
|
|
class PrunedResNet50(PreTrainedModel):
    """ResNet-50 whose per-layer channel widths come from a pruning config.

    The layout mirrors the canonical ResNet-50 (7x7 stem + four bottleneck
    stages with 3/4/6/3 blocks), except that every conv's channel count is
    read from ``config.channel_config`` instead of being hard-coded, so
    pruned checkpoints reload with matching tensor shapes.
    """

    config_class = PrunedResNetConfig
    _tied_weights_keys = []

    def __init__(self, config: PrunedResNetConfig):
        super().__init__(config)
        self.config = config
        c = config.channel_config
        # channel_config defaults to None in the config class; fail early
        # with a clear message instead of an opaque TypeError on c["conv1"].
        if c is None:
            raise ValueError(
                "PrunedResNet50 requires config.channel_config to be set"
            )
        # Stem: 7x7 stride-2 conv, BN, ReLU, 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(
            3, c["conv1"], kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(c["conv1"])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages with the standard ResNet-50 block counts (3/4/6/3);
        # stages 2-4 halve the spatial resolution in their first block.
        self.layer1 = self._make_layer(c, stage_idx=1, layers=3, stride=1)
        self.layer2 = self._make_layer(c, stage_idx=2, layers=4, stride=2)
        self.layer3 = self._make_layer(c, stage_idx=3, layers=6, stride=2)
        self.layer4 = self._make_layer(c, stage_idx=4, layers=3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # The classifier input width is whatever the last bottleneck emits.
        last_channel = c["layer4.2.conv3"]
        self.fc = nn.Linear(last_channel, config.num_classes)
        self.post_init()

    def _make_layer(self, c, stage_idx, layers, stride):
        """Build one stage of ``layers`` Bottleneck blocks.

        Only the first block carries the stage stride and (optionally) a
        projection shortcut — keyed as ``layer{stage_idx}.0.downsample.0``
        in the channel config; the remaining blocks run at stride 1 with
        identity shortcuts.
        """
        blocks = []
        blocks.append(
            Bottleneck(
                inplanes=c[f"layer{stage_idx}.0.in"],
                planes=[
                    c[f"layer{stage_idx}.0.conv1"],
                    c[f"layer{stage_idx}.0.conv2"],
                    c[f"layer{stage_idx}.0.conv3"],
                ],
                stride=stride,
                downsample_planes=c.get(f"layer{stage_idx}.0.downsample.0", None),
            )
        )
        for i in range(1, layers):
            blocks.append(
                Bottleneck(
                    inplanes=c[f"layer{stage_idx}.{i}.in"],
                    planes=[
                        c[f"layer{stage_idx}.{i}.conv1"],
                        c[f"layer{stage_idx}.{i}.conv2"],
                        c[f"layer{stage_idx}.{i}.conv3"],
                    ],
                )
            )
        return nn.Sequential(*blocks)

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Run the backbone and classifier head.

        Args:
            pixel_values: input image batch of shape (N, 3, H, W).
            labels: optional class-index tensor; when given, a
                cross-entropy loss is computed over the logits.

        Returns:
            ImageClassifierOutput with ``logits`` and optional ``loss``.
        """
        x = self.conv1(pixel_values)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.fc(x)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))
        return ImageClassifierOutput(logits=logits, loss=loss)
|
|
|
|
|
|
|
|
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 reduce -> 3x3 -> 1x1 expand).

    ``planes`` supplies the three conv output widths explicitly (rather
    than deriving them from an expansion factor), so pruned widths can
    differ per block. When ``downsample_planes`` is given, a 1x1
    projection shortcut (conv + BN) is built to match the residual path.
    """

    def __init__(self, inplanes, planes, stride=1, downsample_planes=None):
        super().__init__()
        width1, width2, width3 = planes

        # 1x1 reduction.
        self.conv1 = nn.Conv2d(inplanes, width1, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width1)

        # 3x3 spatial conv; carries the block's stride.
        self.conv2 = nn.Conv2d(
            width1, width2, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(width2)

        # 1x1 expansion.
        self.conv3 = nn.Conv2d(width2, width3, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width3)

        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut only when the caller supplies a width for it;
        # otherwise the identity shortcut is used in forward().
        if downsample_planes is None:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    downsample_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(downsample_planes),
            )

    def forward(self, x):
        """Apply the three convs and add the (possibly projected) shortcut."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
|
|
|