# BETA VERSION - NEEDS FURTHER DEVELOPMENT
# Read these to catch up on what this file is attempting to do
# https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
# https://pytorch.org/docs/stable/quantization.html#model-preparation-for-eager-mode-static-quantization
# Torch implementation of these models - mine is heavily based on these with some minor adjustments
#
# I've added squeeze-and-excitation layers (a MobileNetV3 feature) to the MobileNetV2 backbone, but I did not put in
# NAS (unnecessary since we're not optimising for mobile) or hardswish (because I prefer ReLU / think it is better)
# https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py#L117
# https://github.com/pytorch/vision/blob/11bf27e37190b320216c349e39b085fb33aefed1/torchvision/models/mobilenetv3.py#L56
# This is an adapted version of MobileNet, somewhere between versions 2 and 3, as some features of 3 were not required. There are
# also some additions for our particular use case from miscellaneous sources
from torchvision import transforms
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset
import ClassUtils
from torch.ao.quantization import QuantStub, DeQuantStub
from torchvision.models.mobilenetv2 import _make_divisible
import time
import random
import os
import matplotlib.pyplot as plt
# Squeeze: summarising global context by pooling feature maps into a single value
# Excitation: Learning attention weights for each channel to prioritise the most relevant ones
class SqueezeExcitation(nn.Module):
def __init__(self, input_channels:int, squeeze_factor: int = 4):
super().__init__()
# If channels are a multiple of 8, they're optimised by the hardware
squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
self.squeeze = nn.Conv2d(input_channels, squeeze_channels, 1)
self.relu = nn.ReLU(inplace=True)
self.unsqueeze = nn.Conv2d(squeeze_channels, input_channels, 1)
self.quant = nn.quantized.FloatFunctional()
    # Scale returns the feature attention map: how much attention should be paid to each input channel, in range [0, 1]
    # Inplace is used to save memory on operations - it might not be necessary in our case since we aren't using edge devices
    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
# Squeeze
scale = F.adaptive_avg_pool2d(input, 1)
scale = self.squeeze(scale)
# Excite
scale = self.relu(scale)
scale = self.unsqueeze(scale)
return F.hardsigmoid(scale, inplace=inplace)
def forward(self, input: Tensor) -> Tensor:
# print(self._scale(input, True))
# print(input)
return self.quant.mul(self._scale(input, True), input)
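# Quick sanity check for the SE block (a sketch; the dummy shapes are illustrative, not taken from our dataset):
#   se = SqueezeExcitation(32)                 # squeeze_channels = _make_divisible(32 // 4, 8) = 8
#   out = se(torch.randn(1, 32, 56, 56))       # same shape back, each channel rescaled by a weight in [0, 1]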
# The basic building block of our convolutional neural network
# - qconfig should automatically insert fakeQuantisation operations during training, so there is no need to manually place them now
class ConvBNReLu(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super().__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            # No point applying a bias (constant additive term) if the next layer is a batch normalisation layer
nn.BatchNorm2d(out_planes, momentum=0.1),
nn.ReLU(inplace=True)
)
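# e.g. with kernel_size=3 and stride=2, a (N, C, 224, 224) input comes out as (N, out_planes, 112, 112):
# padding = (3 - 1) // 2 = 1 preserves the spatial size, and the stride of 2 then halves it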
# Like typical residual blocks but inverted (narrow -> wide -> narrow), with depth-wise convolutions instead of standard ones,
# to reduce the number of parameters required compared to the usual residual blocks
class InvertedResidual(nn.Module):
def __init__(self, inpt, oupt, stride, expnd_ratio, kernel_size=3, se_layer=None):
super().__init__()
self.stride = stride
assert stride in [1, 2]
intermediate_channels = int(round(inpt * expnd_ratio))
        # If stride != 1, downsampling occurs, so the residual connection cannot be used
        self.use_residual = (stride == 1) and (inpt == oupt)
        # Squeeze-and-excitation layer - applied after the dw and pw convolutions, but before the residual
        self.se_layer = se_layer
layers = []
if expnd_ratio != 1:
# Pointwise convolution to increase the channels
layers.append(ConvBNReLu(inpt, intermediate_channels, kernel_size=1))
layers.extend([
            # Depthwise convolution - each channel is convolved independently
ConvBNReLu(intermediate_channels, intermediate_channels, stride=stride, groups=intermediate_channels),
            # Pointwise convolution - linear combination to reduce channels back to the expected number
nn.Conv2d(intermediate_channels, oupt, 1, 1, 0, bias=False),
nn.BatchNorm2d(oupt, momentum=0.25)
])
self.conv = nn.Sequential(*layers)
def forward(self, x):
outpt = self.conv(x)
if self.se_layer is not None:
outpt = self.se_layer(outpt)
if self.use_residual:
return x + outpt
else:
return outpt
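# e.g. InvertedResidual(24, 24, stride=1, expnd_ratio=6) expands 24 -> 144 channels for the depthwise step,
# projects back down to 24, and (since stride == 1 and in == out) adds the block input as a residual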
# Same as the inverted residual, but replaces the plain addition with a quantization-friendly operation
class QuantizableInvertedResidual(InvertedResidual):
def __init__(self, inpt, outpt, stride, expnd_ratio, se_layer=None):
super().__init__(inpt, outpt, stride, expnd_ratio, se_layer=se_layer)
self.skip_add = nn.quantized.FloatFunctional()
    # Overrides forward to use a quantization-friendly version of the addition
def forward(self, x):
outpt = self.conv(x)
if self.se_layer is not None:
outpt = self.se_layer(outpt)
if self.use_residual:
return self.skip_add.add(x, outpt)
else:
return outpt
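# FloatFunctional.add is needed because plain `+` on quantized tensors isn't supported in eager-mode quantization;
# the functional module also carries its own observer so the output scale/zero-point can be calibrated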
# The MobileNetV2 Architecture + some features from V3 (squeeze and excitation) but I didn't add NAS since we aren't running this on mobile
# And I prefer ReLU over hardswish
class MobileNetV2_5(nn.Module):
def __init__(self, class_num=2, width_mult=1.0, round_nearest=8):
super().__init__()
layers = []
input_channel = 32
last_channel = 1280
# Just straight up copying this from the torchvision implementation
self.residual_params = [
# expnd_ratio, outpt_channels, num_blocks, stride
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
        first_conv_output_channels = _make_divisible(self.residual_params[0][1] * width_mult, round_nearest)
layers.append(
ConvBNReLu(3,
first_conv_output_channels,
kernel_size=3,
stride=2,
)
)
prev_input_channels = first_conv_output_channels
# Main body of feature extraction
for expnd, oupt_c, num_blocks, strd in self.residual_params:
# output channels must be a multiple of 8 for hardware optimisation
output_channel = _make_divisible(oupt_c * width_mult, round_nearest)
for i in range(num_blocks):
stride = strd if i == 0 else 1
                se_layer = SqueezeExcitation(output_channel) if i == 0 else None
layers.append(QuantizableInvertedResidual(prev_input_channels, output_channel, stride, expnd_ratio=expnd, se_layer=se_layer))
prev_input_channels = output_channel
self.last_channel = _make_divisible(last_channel * max(width_mult, 1.0), round_nearest)
        # We could put this in the classifier, but I want that to be lightweight so that we could do transfer learning
        # on the head only while keeping the feature extraction part of the model frozen.
layers.append(
ConvBNReLu(prev_input_channels, self.last_channel, kernel_size=1)
)
self.feature_extraction = nn.Sequential(*layers)
self.avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(0.125),
            nn.Linear(self.last_channel, class_num)
        )
# This bit is also just straight up copied from torch's implementation - I'm not touching it in case it gets messed up
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def forward(self, x: Tensor) -> Tensor:
x = self.feature_extraction(x)
x = self.avg_pooling(x)
x = torch.flatten(x, 1)
print("eyo")
x = self.classifier(x)
return x
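# Rough usage sketch for the float (non-quantized) model; the 224x224 input size here is just an example:
#   model = MobileNetV2_5(class_num=2)
#   logits = model(torch.randn(1, 3, 224, 224))    # -> shape (1, 2)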
class QuantizableMobileNetV2_5(MobileNetV2_5):
def __init__(self, class_num=2, width_mult=1.0, round_nearest=8):
super().__init__(class_num=class_num, width_mult=width_mult, round_nearest=round_nearest)
self.quant = QuantStub()
self.dequant = DeQuantStub()
def _forward_impl(self, x: Tensor) -> Tensor:
x = self.feature_extraction(x)
        # This was for debugging errors in the shape of feature maps as they pass through - not deleting in case it's useful later
# for idx, layer in enumerate(self.feature_extraction):
# x = layer(x)
# print(f"Feature extraction layer {idx}, output shape: {x.shape}")
x = self.avg_pooling(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
def forward(self, x: Tensor) -> Tensor:
x = self.quant(x)
x = self._forward_impl(x)
        x = self.dequant(x)
return x
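# QuantStub/DeQuantStub mark where tensors enter and leave the quantized region, so after convert() the whole
# feature extractor + classifier runs on quantized tensors and only the final logits are dequantized back to float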
def train_single_epoch(model, loss_fnc, optimiser, data_loader, device):
model.train()
running_loss = 0
running_time = 0.0
for images, labels in data_loader:
start_time = time.time()
print(".", end=" ")
images, labels = images.to(device), labels.to(device)
preds = model(images)
        loss = loss_fnc(preds, labels)
        # Clear the gradients accumulated from the previous batch before backpropagating
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
running_loss += loss.item()
running_time += time.time() - start_time
print(f"{(time.time() - start_time):.2f}, {(running_time):.2f}", end=" ")
print(f"loss of {running_loss}")
return
def print_size_of_model(model):
torch.save(model.state_dict(), "temp.p")
print('Size (MB):', os.path.getsize("temp.p")/1e6)
os.remove('temp.p')
def adjust_quantisation_engine():
# Adjust according to what your device supports
print(torch.backends.quantized.supported_engines)
torch.backends.quantized.engine = 'qnnpack'
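# Note: 'qnnpack' targets ARM; on x86 machines 'fbgemm' is normally the supported engine (check the printout above)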
def train_model(model, dataloader, loss_function, optimiser, epoch_number=25, const_save=False, save=True):
for epoch in range(epoch_number):
print("IT IS EPOCH", epoch)
train_single_epoch(model, loss_function, optimiser, dataloader, torch.device('cpu'))
        # Gradually freeze the quantisation observers and the batch norm statistics after a few epochs
if epoch > 3:
# Freeze quantizer parameters
model.apply(torch.ao.quantization.disable_observer)
if epoch > 2:
# Freeze batch norm mean and variance estimates
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
if const_save:
quantized_model = torch.ao.quantization.convert(model.eval(), inplace=False)
quantized_model.eval()
            # Saving each intermediate model since they're so small, and this lets us load any of them for performance comparisons later
torch.save(quantized_model.state_dict(), "quantStateDict"+str(epoch+1)+".pth")
print(f"the above was Epoch {epoch+1} of {epoch_number} \nThe model has a size of", end=" ")
print_size_of_model(quantized_model)
        else:
            print(f"the above was Epoch {epoch+1} of {epoch_number}")
    if save:
        # Convert once at the end so a final quantised checkpoint exists even when const_save is False
        quantized_model = torch.ao.quantization.convert(model.eval(), inplace=False)
        torch.save(quantized_model.state_dict(), "full_quantStateDict.pth")
return model
learning_rate = 1e-3
batch_size = 64
data_size = 2560
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Update to whatever you call your model
modelName = "quantStateDict8.pth"
load = False
model = QuantizableMobileNetV2_5()
# Adjust according to what your device supports
torch.backends.quantized.engine = 'qnnpack'
# For QAT we need the QAT qconfig (fake-quant modules) rather than the post-training default_qconfig
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
optimiser = torch.optim.SGD(model.parameters(), lr=learning_rate)
torch.ao.quantization.prepare_qat(model, inplace=True)
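# The linked static-quantization tutorial also fuses Conv+BN+ReLU modules (torch.ao.quantization.fuse_modules)
# before preparing for QAT - that step hasn't been added here yet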
dataset = ClassUtils.CrosswalkDataset("zebra_annotations/classification_data")
train_loader = DataLoader(
Subset(dataset, random.sample(list(range(0, int(len(dataset) * 0.95))), data_size)),
batch_size=batch_size, shuffle=True)
test_loader = DataLoader(
Subset(dataset, random.sample(list(range(int(len(dataset) * 0.95), len(dataset))), 256)),
batch_size=batch_size, shuffle=False)
loss_function = nn.BCEWithLogitsLoss()
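# BCEWithLogitsLoss expects float targets with the same shape as the model output, so with class_num=2 the dataset
# labels are assumed to be one-hot-style float vectors of length 2 (as the labels[i][0] check below implies)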
model_updated = train_model(model, train_loader, loss_function, optimiser, epoch_number=8, const_save=True)
quantized_model = torch.ao.quantization.convert(model_updated.eval(), inplace=True)
if load:
model_loaded_state_dict = torch.load(modelName)
quantized_model.load_state_dict(model_loaded_state_dict)
for images, labels in test_loader:
preds = torch.sigmoid(quantized_model(images))
for i in range(len(preds)):
        print(preds[i])
# plt.imshow(torch.permute(images[i], (1, 2, 0)).detach().numpy())
# plt.title(f"Prediction: {preds[i]}, Actual: {labels[i][0] == 1}")
# plt.axis("off")
# plt.show()
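# A rough accuracy check could look something like this (a sketch; it assumes the one-hot-style labels used above):
# correct, total = 0, 0
# with torch.no_grad():
#     for images, labels in test_loader:
#         preds = torch.sigmoid(quantized_model(images))
#         correct += ((preds[:, 0] > 0.5) == (labels[:, 0] == 1)).sum().item()
#         total += len(labels)
# print(f"Accuracy: {correct / total:.3f}")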