File size: 4,862 Bytes
33c2790
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# modeling_resnet.py
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ImageClassifierOutput


class PrunedResNetConfig(PretrainedConfig):
    """Configuration holding the channel layout of a pruned ResNet-50.

    Args:
        channel_config: Mapping from layer keys (e.g. ``"conv1"``,
            ``"layer1.0.in"``, ``"layer1.0.conv2"``) to channel counts. It is
            stored unchanged; ``PrunedResNet50`` reads it to size its layers.
        num_classes: Number of outputs of the final linear classifier.
        **kwargs: Passed straight through to ``PretrainedConfig.__init__``.
    """

    model_type = "resnet"

    def __init__(
        self,
        channel_config: dict[str, int] | None = None,
        num_classes: int = 1000,
        **kwargs,
    ):
        # Let the base class consume its own keyword arguments first.
        super().__init__(**kwargs)
        self.channel_config = channel_config
        self.num_classes = num_classes


class PrunedResNet50(PreTrainedModel):
    """ResNet-50-shaped image classifier whose channel widths come from the config.

    The topology is fixed:
    * a stem: 7x7 stride-2 conv, BatchNorm, ReLU, and a 3x3 stride-2 max-pool;
    * four stages of 3, 4, 6 and 3 ``Bottleneck`` blocks (stage 1 at stride 1,
      stages 2-4 at stride 2);
    * global average pooling followed by a linear head.

    Only the widths vary. ``config.channel_config`` must contain:

    * ``"conv1"``: stem output channels;
    * for every block ``layer{s}.{i}``: ``.in``, ``.conv1``, ``.conv2`` and
      ``.conv3``;
    * optionally ``layer{s}.0.downsample.0``: output width of a projection
      shortcut on a stage's first block (absent means no projection).

    ``layer4.2.conv3`` also sets the input width of the classifier ``fc``.
    """

    config_class = PrunedResNetConfig
    # PreTrainedModel defaults this to "input_ids". This model's primary input
    # is the image tensor, so helpers that look up the main input in a batch
    # need to find "pixel_values" instead.
    main_input_name = "pixel_values"
    _tied_weights_keys = []

    def __init__(self, config: PrunedResNetConfig):
        super().__init__(config)
        self.config = config
        c = config.channel_config
        if c is None:
            # Fail with an explicit message instead of "'NoneType' object is
            # not subscriptable" on the first lookup below.
            raise ValueError(
                "config.channel_config is None; PrunedResNet50 needs a mapping "
                "of layer names to channel counts to build its layers."
            )
        self.conv1 = nn.Conv2d(
            3, c["conv1"], kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(c["conv1"])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Block counts (3, 4, 6, 3) give the ResNet-50 layout.
        self.layer1 = self._make_layer(c, stage_idx=1, layers=3, stride=1)
        self.layer2 = self._make_layer(c, stage_idx=2, layers=4, stride=2)
        self.layer3 = self._make_layer(c, stage_idx=3, layers=6, stride=2)
        self.layer4 = self._make_layer(c, stage_idx=4, layers=3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        last_channel = c["layer4.2.conv3"]
        self.fc = nn.Linear(last_channel, config.num_classes)
        self.post_init()

    def _make_layer(self, c, stage_idx, layers, stride):
        """Build stage ``layer{stage_idx}`` as an ``nn.Sequential`` of blocks.

        Block ``i`` reads its widths from the ``layer{stage_idx}.{i}.*`` keys
        of ``c``. Only the first block (``i == 0``) receives ``stride`` and,
        if ``c`` has a ``layer{stage_idx}.0.downsample.0`` entry, a projection
        shortcut. The remaining ``layers - 1`` blocks run at stride 1 with an
        identity shortcut.
        """
        blocks = []
        for i in range(layers):
            prefix = f"layer{stage_idx}.{i}"
            is_first = i == 0
            blocks.append(
                Bottleneck(
                    inplanes=c[f"{prefix}.in"],
                    planes=[
                        c[f"{prefix}.conv1"],
                        c[f"{prefix}.conv2"],
                        c[f"{prefix}.conv3"],
                    ],
                    stride=stride if is_first else 1,
                    # Downsample entries are consulted for the first block
                    # only; later blocks always use an identity shortcut.
                    downsample_planes=(
                        c.get(f"{prefix}.downsample.0") if is_first else None
                    ),
                )
            )
        return nn.Sequential(*blocks)

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Classify a batch of images.

        Args:
            pixel_values: Image batch for ``conv1``, which expects 3 input
                channels (presumably shaped ``(batch, 3, H, W)``; confirm with
                the preprocessing pipeline).
            labels: Optional target class indices. They are flattened with
                ``view(-1)`` and scored with ``nn.CrossEntropyLoss``, which
                applies log-softmax to the logits internally.
            **kwargs: Accepted and ignored, so extra arguments from callers do
                not raise.

        Returns:
            ``ImageClassifierOutput`` whose ``logits`` have ``config.num_classes``
            columns, and whose ``loss`` is ``None`` when ``labels`` is not given.

        Raises:
            ValueError: if ``pixel_values`` is ``None``.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        x = self.conv1(pixel_values)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.fc(x)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))
        return ImageClassifierOutput(logits=logits, loss=loss)


class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convs) with free per-conv widths.

    The main path is conv1/bn1/ReLU -> conv2 (3x3, carries ``stride``)/bn2/ReLU
    -> conv3/bn3. It is added to the shortcut and passed through a final ReLU.

    Args:
        inplanes: Number of input channels.
        planes: Three ints ``(c1, c2, c3)``: the output widths of ``conv1``,
            ``conv2`` and ``conv3``.
        stride: Stride of the 3x3 conv and of the projection shortcut.
        downsample_planes: If not ``None``, the shortcut becomes a 1x1 conv +
            BatchNorm projection (``self.downsample``) with this many output
            channels. Otherwise the shortcut is the identity.

    Raises:
        ValueError: if the shortcut cannot be added to the main path. This
            happens when its channel count (``downsample_planes``, or
            ``inplanes`` without a projection) differs from ``c3``, or when
            ``stride != 1`` without a projection.
    """

    def __init__(self, inplanes, planes, stride=1, downsample_planes=None):
        super().__init__()
        c1, c2, c3 = planes  # The 3 conv widths inside the bottleneck

        # Validate the residual add up front. Otherwise a bad (e.g. hand-edited
        # pruned) channel config only fails at `out += identity` in forward, or
        # silently broadcasts when the shortcut has a single channel.
        shortcut_channels = inplanes if downsample_planes is None else downsample_planes
        if shortcut_channels != c3:
            raise ValueError(
                f"Bottleneck residual mismatch: main path outputs {c3} channels "
                f"but the shortcut carries {shortcut_channels} "
                f"(inplanes={inplanes}, downsample_planes={downsample_planes})."
            )
        if downsample_planes is None and stride != 1:
            raise ValueError(
                f"Bottleneck with stride={stride} needs downsample_planes so the "
                "shortcut is strided like the main path."
            )

        self.conv1 = nn.Conv2d(inplanes, c1, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(c1)

        self.conv2 = nn.Conv2d(
            c1, c2, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(c2)

        self.conv3 = nn.Conv2d(c2, c3, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(c3)

        # A single ReLU module is shared by all three activation sites.
        self.relu = nn.ReLU(inplace=True)

        self.downsample = None
        if downsample_planes is not None:
            # Sequential keeps the submodule names "downsample.0" (conv) and
            # "downsample.1" (BN).
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    downsample_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(downsample_planes),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out