| | import torch |
| | import torch.nn as nn |
| | import re |
| | import math |
# NOTE(review): module-level alias for nn.BatchNorm2d. BasicBlockWOnorm and
# ResNetWOnorm use this alias even though their names say "without norm" —
# confirm whether they were meant to use a no-op norm instead.
BatchNorm2d=nn.BatchNorm2d
def conv3x3(in_planes, out_planes, stride=1):
    """Build a 3x3 convolution with padding 1 and no bias.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).

    Returns:
        An ``nn.Conv2d`` that preserves spatial size when ``stride == 1``.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
| |
|
class BasicBlockLayerNorm(nn.Module):
    """ResNet basic block that normalizes with LayerNorm instead of BatchNorm.

    Computes ``relu(norm2(conv2(relu(norm1(conv1(x))))) + residual)`` where the
    residual is ``x`` or ``downsample(x)``.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        norm_shape: ``normalized_shape`` for both LayerNorms; must match the
            block's output feature map shape (e.g. ``[C, H, W]``).
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to ``x`` so the residual matches
            the block output's channels/spatial size.
        dcn: optional deformable-conv config dict. Only the fallback
            (plain conv) path is implemented; otherwise NotImplementedError.
    """

    expansion = 1

    def __init__(self, inplanes, planes, norm_shape, stride=1, downsample=None, dcn=None):
        super(BasicBlockLayerNorm, self).__init__()
        self.with_dcn = dcn is not None
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.LayerNorm(norm_shape)
        self.relu = nn.ReLU(inplace=True)
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        # Short-circuit keeps `fallback_on_stride` unevaluated when dcn is None.
        # (The previous dead `self.conv2 = conv3x3(...)` assignment, which was
        # always overwritten below or followed by a raise, has been removed.)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                                   padding=1, bias=False)
        else:
            # A real deformable conv (and its conv2_offset) is not wired up here.
            raise NotImplementedError
        self.bn2 = nn.LayerNorm(norm_shape)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # BUG FIX: conv2 used to be applied unconditionally here AND again
        # inside the branch below, convolving the features twice. It must
        # run exactly once.
        if not self.with_dcn:
            out = self.conv2(out)
        elif self.with_modulated_dcn:
            offset_mask = self.conv2_offset(out)
            offset = offset_mask[:, :18, :, :]
            mask = offset_mask[:, -9:, :, :].sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(out)
            out = self.conv2(out, offset)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
| |
|
class BasicBlock(nn.Module):
    """Standard ResNet basic block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN,
    plus an identity (or downsampled) residual connection.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 convolution (spatial downsampling).
        downsample: optional module applied to ``x`` so the residual matches
            the block output's channels/spatial size.
        dcn: optional deformable-conv config dict. Only the fallback
            (plain conv) path is implemented; otherwise NotImplementedError.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
        super(BasicBlock, self).__init__()
        self.with_dcn = dcn is not None
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        # Short-circuit keeps `fallback_on_stride` unevaluated when dcn is None.
        # (The previous dead `self.conv2 = conv3x3(...)` assignment, which was
        # always overwritten below or followed by a raise, has been removed.)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                                   padding=1, bias=False)
        else:
            # A real deformable conv (and its conv2_offset) is not wired up here.
            raise NotImplementedError
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # BUG FIX: conv2 used to be applied unconditionally here AND again
        # inside the branch below, convolving the features twice. It must
        # run exactly once.
        if not self.with_dcn:
            out = self.conv2(out)
        elif self.with_modulated_dcn:
            offset_mask = self.conv2_offset(out)
            offset = offset_mask[:, :18, :, :]
            mask = offset_mask[:, -9:, :, :].sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(out)
            out = self.conv2(out, offset)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
| |
|
class BasicBlockWOnorm(nn.Module):
    """Basic residual block variant.

    NOTE(review): despite the "WOnorm" (without-norm) name, this block uses
    BatchNorm2d exactly like ``BasicBlock`` (the file-level ``BatchNorm2d``
    alias resolves to ``nn.BatchNorm2d``) — confirm whether a no-op norm was
    intended.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to ``x`` so the residual matches
            the block output's channels/spatial size.
        dcn: optional deformable-conv config dict. Only the fallback
            (plain conv) path is implemented; otherwise NotImplementedError.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
        super(BasicBlockWOnorm, self).__init__()
        self.with_dcn = dcn is not None
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        # Short-circuit keeps `fallback_on_stride` unevaluated when dcn is None.
        # (The previous dead `self.conv2 = conv3x3(...)` assignment, which was
        # always overwritten below or followed by a raise, has been removed.)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                                   padding=1, bias=False)
        else:
            # A real deformable conv (and its conv2_offset) is not wired up here.
            raise NotImplementedError
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # BUG FIX: conv2 used to be applied unconditionally here AND again
        # inside the branch below, convolving the features twice. It must
        # run exactly once.
        if not self.with_dcn:
            out = self.conv2(out)
        elif self.with_modulated_dcn:
            offset_mask = self.conv2_offset(out)
            offset = offset_mask[:, :18, :, :]
            mask = offset_mask[:, -9:, :, :].sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(out)
            out = self.conv2(out, offset)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
| |
|
class ResNetWOnorm(nn.Module):
    """Two-stage convolutional projector built from BasicBlockWOnorm blocks.

    Channels go 256 -> 1024 -> 4096 (each stage downsamples by stride 2),
    then the feature map is flattened into a token sequence and projected
    by a linear layer to ``out_dim``.
    """

    def __init__(self, dcn=None, out_dim=4096):
        print('using resnet without batchnorm')
        self.dcn = dcn
        self.inplanes = 256
        super(ResNetWOnorm, self).__init__()
        self.layer1 = self._make_layer(
            BasicBlockWOnorm, 1024, 1, stride=2, dcn=dcn)
        self.layer2 = self._make_layer(
            BasicBlockWOnorm, 4096, 1, stride=2, dcn=dcn)
        self.fc = nn.Linear(4096, out_dim)

        # Kaiming-style init for convs; unit-affine init for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = (module.kernel_size[0] * module.kernel_size[1]
                           * module.out_channels)
                module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
        """Build one stage of ``blocks`` residual blocks.

        The first block carries the stride and (if needed) a 1x1-conv + BN
        downsample for the residual path; the rest keep the shape.
        """
        width = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != width:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, width,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )

        stage = [block(self.inplanes, planes, stride, downsample, dcn=dcn)]
        self.inplanes = width
        stage.extend(block(self.inplanes, planes, dcn=dcn)
                     for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        # (B, C, H, W) -> (B, H*W, C) token sequence.
        x = x.flatten(2).transpose(1, 2)
        return self.fc(x)
| |
|
class ResNetLayerNorm(nn.Module):
    """Two-stage convolutional projector using LayerNorm residual blocks.

    Channels go 256 -> 1024 -> 4096 with stride-2 downsampling per stage.
    The hard-coded norm shapes [1024, 32, 32] and [4096, 16, 16] assume a
    (B, 256, 64, 64) input (as in this file's __main__ smoke test) —
    NOTE(review): other input sizes will fail the LayerNorm shape check.
    """

    def __init__(self, dcn=None, out_dim=4096):
        print('using resnet with layernorm')
        self.dcn = dcn
        self.inplanes = 256
        super(ResNetLayerNorm, self).__init__()
        self.layer1 = self._make_layer(
            BasicBlockLayerNorm, 1024, 1, [1024, 32, 32], stride=2, dcn=dcn)
        self.layer2 = self._make_layer(
            BasicBlockLayerNorm, 4096, 1, [4096, 16, 16], stride=2, dcn=dcn)
        self.fc = nn.Linear(4096, out_dim)

        # Kaiming-style init for convs; unit-affine init for any batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, norm_shape, stride=1, dcn=None):
        """Build one stage of ``blocks`` LayerNorm residual blocks.

        Args:
            block: block class (expects ``norm_shape`` as third positional arg).
            planes: output channels of each block.
            blocks: number of blocks in the stage.
            norm_shape: normalized_shape forwarded to every block's LayerNorms.
            stride: stride of the first block.
            dcn: optional deformable-conv config forwarded to blocks.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.LayerNorm(norm_shape),
            )

        layers = []
        layers.append(block(self.inplanes, planes, norm_shape,
                            stride, downsample, dcn=dcn))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            # BUG FIX: norm_shape was omitted here, which raised a TypeError
            # for any stage with blocks > 1 (block requires it positionally).
            layers.append(block(self.inplanes, planes, norm_shape, dcn=dcn))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)

        # (B, C, H, W) -> (B, H*W, C) token sequence.
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        x = self.fc(x)

        return x
| |
|
class ResNet(nn.Module):
    """Two-stage convolutional projector built from BasicBlock blocks.

    Channels go 256 -> 1024 -> 4096 (each stage downsamples by stride 2),
    then the feature map is flattened into a token sequence and projected
    by a linear layer to ``out_dim``.
    """

    def __init__(self, dcn=None, out_dim=4096):
        self.dcn = dcn
        self.inplanes = 256
        super(ResNet, self).__init__()
        self.layer1 = self._make_layer(
            BasicBlock, 1024, 1, stride=2, dcn=dcn)
        self.layer2 = self._make_layer(
            BasicBlock, 4096, 1, stride=2, dcn=dcn)
        self.fc = nn.Linear(4096, out_dim)

        # Kaiming-style init for convs; unit-affine init for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = (module.kernel_size[0] * module.kernel_size[1]
                           * module.out_channels)
                module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
        """Build one stage of ``blocks`` residual blocks.

        The first block carries the stride and (if needed) a 1x1-conv + BN
        downsample for the residual path; the rest keep the shape.
        """
        width = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != width:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, width,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )

        stage = [block(self.inplanes, planes, stride, downsample, dcn=dcn)]
        self.inplanes = width
        stage.extend(block(self.inplanes, planes, dcn=dcn)
                     for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        # (B, C, H, W) -> (B, H*W, C) token sequence.
        x = x.flatten(2).transpose(1, 2)
        return self.fc(x)
| |
|
class ResNetSwin(nn.Module):
    """Single-stage convolutional projector for Swin-style feature maps.

    Maps ``input_dim`` channels to 2048 with one stride-2 residual stage,
    flattens the feature map into a token sequence, and projects each token
    to ``out_dim`` with a linear layer.
    """

    def __init__(self, dcn=None, input_dim=1024, out_dim=4096):
        self.dcn = dcn
        self.inplanes = input_dim
        super(ResNetSwin, self).__init__()
        self.layer1 = self._make_layer(
            BasicBlock, 2048, 1, stride=2, dcn=dcn)
        self.fc = nn.Linear(2048, out_dim)

        # Kaiming-style init for convs; unit-affine init for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = (module.kernel_size[0] * module.kernel_size[1]
                           * module.out_channels)
                module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
        """Build one stage of ``blocks`` residual blocks.

        The first block carries the stride and (if needed) a 1x1-conv + BN
        downsample for the residual path; the rest keep the shape.
        """
        width = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != width:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, width,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )

        stage = [block(self.inplanes, planes, stride, downsample, dcn=dcn)]
        self.inplanes = width
        stage.extend(block(self.inplanes, planes, dcn=dcn)
                     for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.layer1(x)
        # (B, C, H, W) -> (B, H*W, C) token sequence.
        x = x.flatten(2).transpose(1, 2)
        return self.fc(x)
| |
|
class IdentityMap(nn.Module):
    """No-op projector: returns its input unchanged.

    Extra positional/keyword arguments are accepted and ignored so it can be
    dropped in wherever a projector with a richer signature is expected.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        # Pass-through; extra args are deliberately ignored.
        return x

    @property
    def config(self):
        # Mirrors the config key used by build_vision_projector.
        return {"mm_projector_type": 'identity'}
| |
|
| |
|
class SimpleResBlock(nn.Module):
    """Pre-norm residual MLP block: ``y = h + MLP(h)`` where ``h = LN(x)``.

    Args:
        channels: feature dimension of the input/output.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        normed = self.pre_norm(x)
        # Residual is taken from the normalized input, not the raw input.
        return normed + self.proj(normed)
| |
|
| |
|
def build_vision_projector(config, delay_load=False, **kwargs):
    """Instantiate the multimodal projector selected by ``config``.

    Supported ``config.mm_projector_type`` values:
      * ``'linear'``      -- single Linear(mm_hidden_size, hidden_size)
      * ``'conv'``        -- ResNetLayerNorm / ResNet / ResNetWOnorm,
                             chosen by ``with_layernorm`` / ``with_norm``
      * ``'swin_conv'``   -- ResNetSwin
      * ``'mlpNx_gelu'``  -- N-layer GELU MLP
      * ``'identity'``    -- pass-through IdentityMap

    Raises:
        ValueError: for any other projector type.
    """
    projector_type = getattr(config, 'mm_projector_type', 'linear')
    print("projector_type:", projector_type)

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    if projector_type == 'conv':
        with_norm = getattr(config, 'with_norm', True)
        with_layernorm = getattr(config, 'with_layernorm', True)
        out_dim = getattr(config, 'projector_outdim', 4096)
        # with_layernorm wins over with_norm (same precedence as before).
        if with_layernorm:
            return ResNetLayerNorm(out_dim=out_dim)
        return ResNet(out_dim=out_dim) if with_norm else ResNetWOnorm(out_dim=out_dim)

    if projector_type == 'swin_conv':
        out_dim = getattr(config, 'projector_outdim', 4096)
        input_dim = getattr(config, 'mm_input_embeds', 1024)
        return ResNetSwin(input_dim=input_dim, out_dim=out_dim)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    print("mlp_gelu_match:", mlp_gelu_match)
    if mlp_gelu_match:
        depth = int(mlp_gelu_match.group(1))
        layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(depth - 1):
            layers.append(nn.GELU())
            layers.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*layers)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
| |
|
if __name__ == '__main__':
    # Smoke test: build the LayerNorm conv projector and run a dummy batch
    # of (B=4, C=256, H=64, W=64) features through it.
    class Config:
        def __init__(self):
            self.mm_projector_type = 'conv'
            self.with_layernorm = True
            self.with_norm = False

    cfg = Config()
    net = build_vision_projector(cfg)
    image = torch.randn((4, 256, 64, 64))
    print(net)
    print(net(image).shape)