| | import torch |
| | import transformers |
| | from torch import nn |
| | from transformers.modeling_outputs import SemanticSegmenterOutput |
| |
|
| |
|
class FaceSegmenterConfig(transformers.PretrainedConfig):
    """Configuration for the face-segmentation models defined in this file.

    Stores the class-id <-> label-name mappings and the number of output
    classes, which the model classes read as ``config.num_classes``.

    Keyword Args:
        id2label: Optional mapping from class id to label name. Keys may be
            ints or int-convertible values (e.g. the strings ``"0"``,
            ``"1"``, ...); they are normalised to ``int``. Defaults to
            ``_id2label``.
        label2id: Optional mapping from label name to class id. When it is
            omitted (or ``None``) it is derived by inverting ``id2label`` so
            the two mappings always agree; for the default labels this is
            identical to ``_label2id``.
        num_classes: Optional number of output classes. Defaults to
            ``len(id2label)``.
        **kwargs: All keyword arguments are also forwarded unchanged to
            ``transformers.PretrainedConfig.__init__``.
    """

    model_type = "image-segmentation"

    # Default class-id -> label-name mapping (18 classes).
    _id2label = {
        0: "skin",
        1: "l_brow",
        2: "r_brow",
        3: "l_eye",
        4: "r_eye",
        5: "eye_g",
        6: "l_ear",
        7: "r_ear",
        8: "ear_r",
        9: "nose",
        10: "mouth",
        11: "u_lip",
        12: "l_lip",
        13: "neck",
        14: "neck_l",
        15: "cloth",
        16: "hair",
        17: "hat",
    }

    # Default label-name -> class-id mapping (inverse of ``_id2label``).
    _label2id = {
        "skin": 0,
        "l_brow": 1,
        "r_brow": 2,
        "l_eye": 3,
        "r_eye": 4,
        "eye_g": 5,
        "l_ear": 6,
        "r_ear": 7,
        "ear_r": 8,
        "nose": 9,
        "mouth": 10,
        "u_lip": 11,
        "l_lip": 12,
        "neck": 13,
        "neck_l": 14,
        "cloth": 15,
        "hair": 16,
        "hat": 17,
    }

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        id2label = kwargs.get("id2label")
        if id2label is None:
            id2label = self._id2label
        # Build a fresh dict instead of popping/re-inserting keys in place:
        # the in-place variant mutated the shared class-level default (or
        # the caller's own dict) on every instantiation.
        self.id2label = {
            int(label_id): label for label_id, label in id2label.items()
        }

        label2id = kwargs.get("label2id")
        if label2id is None:
            # Invert id2label so a custom id2label never ends up paired
            # with the unrelated default label2id.
            label2id = {
                label: label_id for label_id, label in self.id2label.items()
            }
        # Copy so instances never alias the class-level default or the
        # caller's dict.
        self.label2id = dict(label2id)

        self.num_classes = kwargs.get("num_classes", len(self.id2label))
| |
|
| |
|
def encode_down(c_in: int, c_out: int):
    """Build a double-convolution block: two (Conv3x3 -> BatchNorm -> ReLU) stages.

    Args:
        c_in: Number of channels of the incoming feature map.
        c_out: Number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` of six layers. Both convolutions use
        ``kernel_size=3`` with ``padding=1``, so height and width are
        preserved.
    """
    stages = []
    # The first stage maps c_in -> c_out, the second keeps c_out -> c_out.
    for stage_in in (c_in, c_out):
        stages.append(nn.Conv2d(stage_in, c_out, kernel_size=3, padding=1))
        stages.append(nn.BatchNorm2d(c_out))
        stages.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stages)
| |
|
| |
|
def decode_up(c: int):
    """Build a transposed convolution that halves the channel count.

    Args:
        c: Number of input channels.

    Returns:
        An ``nn.ConvTranspose2d`` mapping ``c`` channels to ``c // 2``
        channels with ``kernel_size=2`` and ``stride=2`` (doubles height
        and width).
    """
    half = c // 2
    return nn.ConvTranspose2d(c, half, kernel_size=2, stride=2)
| |
|
| |
|
class FaceUNet(nn.Module):
    """U-Net style encoder/decoder that emits per-pixel class logits.

    The encoder applies five convolutional stages (64, 128, 256, 512 and
    1024 channels) separated by 2x2 max-pooling. The decoder upsamples four
    times with transposed convolutions, concatenating the matching encoder
    output before each double-convolution block. A final 3x3 convolution
    maps the 64-channel result to ``num_classes`` channels.

    Args:
        num_classes: Number of output channels of the final convolution.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes

        # Encoder. The first stage is a single 3x3 convolution on a
        # 3-channel input; the remaining stages are double-conv blocks.
        self.down_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.down_2 = encode_down(64, 128)
        self.down_3 = encode_down(128, 256)
        self.down_4 = encode_down(256, 512)
        self.down_5 = encode_down(512, 1024)

        # One parameter-free pooling layer shared by all encoder stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder. Each up_N halves the channels; concatenating the skip
        # connection restores the count that up_cN then reduces again.
        self.up_1 = decode_up(1024)
        self.up_c1 = encode_down(1024, 512)
        self.up_2 = decode_up(512)
        self.up_c2 = encode_down(512, 256)
        self.up_3 = decode_up(256)
        self.up_c3 = encode_down(256, 128)
        self.up_4 = decode_up(128)
        self.up_c4 = encode_down(128, 64)

        # Classification head producing one channel per class.
        self.segment = nn.Conv2d(64, self.num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input tensor with 3 channels on dimension 1. Because the
                input is pooled four times and then upsampled four times,
                its height and width need to be multiples of 16 for the
                skip concatenations to line up.

        Returns:
            Logits tensor with ``num_classes`` channels on dimension 1.
        """
        # Encoder path; keep the output of every stage for the skips.
        skip_1 = self.down_1(x)
        skip_2 = self.down_2(self.pool(skip_1))
        skip_3 = self.down_3(self.pool(skip_2))
        skip_4 = self.down_4(self.pool(skip_3))
        features = self.down_5(self.pool(skip_4))

        # Decoder path, from the deepest skip connection to the shallowest.
        decoder_stages = (
            (self.up_1, self.up_c1, skip_4),
            (self.up_2, self.up_c2, skip_3),
            (self.up_3, self.up_c3, skip_2),
            (self.up_4, self.up_c4, skip_1),
        )
        for upsample, fuse, skip in decoder_stages:
            upsampled = upsample(features)
            features = fuse(torch.cat([skip, upsampled], 1))

        return self.segment(features)
| |
|
| |
|
class Segformer(transformers.PreTrainedModel):
    """``PreTrainedModel`` wrapper that returns the raw output of ``FaceUNet``.

    Args:
        config: A ``FaceSegmenterConfig``; ``config.num_classes`` sets the
            number of output channels of the wrapped network.
    """

    config_class = FaceSegmenterConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = FaceUNet(num_classes=config.num_classes)

    def forward(self, tensor):
        """Run the wrapped ``FaceUNet`` on ``tensor`` and return its output.

        Args:
            tensor: Input batch passed straight to ``FaceUNet.forward``.

        Returns:
            The tensor returned by ``FaceUNet.forward``.
        """
        # FaceUNet defines no ``forward_features`` method, so the previous
        # ``self.model.forward_features(tensor)`` raised AttributeError on
        # every call. Invoke the module itself instead.
        return self.model(tensor)
| |
|
| |
|
class SegformerForSemanticSegmentation(transformers.PreTrainedModel):
    """``PreTrainedModel`` wrapper around ``FaceUNet`` for semantic segmentation.

    Args:
        config: A ``FaceSegmenterConfig``; ``config.num_classes`` sets the
            number of logit channels.
    """

    config_class = FaceSegmenterConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = FaceUNet(num_classes=config.num_classes)

    def forward(self, pixel_values, labels=None):
        """Compute segmentation logits and, optionally, the training loss.

        Args:
            pixel_values: Input batch passed to ``FaceUNet``.
            labels: Optional target tensor. When given, the cross-entropy
                loss between the logits and ``labels`` is computed.
                # NOTE(review): presumably integer class indices with the
                # same spatial size as the logits -- confirm against caller.

        Returns:
            A ``SemanticSegmenterOutput`` with ``logits`` set, and ``loss``
            set only when ``labels`` is provided.
        """
        logits = self.model(pixel_values)
        values = {"logits": logits}
        if labels is not None:
            # ``torch.nn`` has no ``cross_entropy`` attribute; the functional
            # form lives in ``torch.nn.functional``. The previous call raised
            # AttributeError whenever labels were supplied.
            values["loss"] = torch.nn.functional.cross_entropy(logits, labels)
        return SemanticSegmenterOutput(**values)
| |
|