Pixel-1 / modeling_pixel.py
Raziel1234's picture
Update modeling_pixel.py
8c7686d verified
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_pixel import TopAIImageConfig
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + BN(Conv(ReLU(BN(Conv(x)))))``.

    Both convolutions are 3x3 with padding 1 (stride 1) and keep the channel
    count, so the branch output has the same shape as the input and can be
    added straight back onto it. No activation is applied after the addition.

    The branch is stored as ``self.block`` (an ``nn.Sequential``); its indices
    (0 conv, 1 bn, 2 relu, 3 conv, 4 bn) are part of the state-dict keys, so
    the layer order must stay fixed for saved weights to load.
    """

    def __init__(self, channels):
        super().__init__()
        stages = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the convolutional branch evaluated on ``x``."""
        branch_out = self.block(x)
        return x + branch_out
class TopAIImageGenerator(PreTrainedModel):
    """Decode an embedding vector into an image through a stack of upsampling layers.

    ``text_projection`` maps each ``config.input_dim``-dimensional embedding to
    a ``hidden_dim x 4 x 4`` feature map. The decoder then applies five stride-2
    transposed convolutions (kernel 4, padding 1), each doubling the spatial
    size (4 -> 8 -> 16 -> 32 -> 64 -> 128), followed by a 3x3 convolution to
    ``config.image_channels`` and a ``Tanh``. The output is therefore
    ``(N, config.image_channels, 128, 128)`` with values in ``[-1, 1]``.

    Module names and the index of every layer inside ``self.decoder``
    (``decoder.0`` ... ``decoder.10``) form the state-dict keys and must match
    the weights stored in the Safetensors checkpoint; do not reorder, wrap, or
    insert layers.
    """

    config_class = TopAIImageConfig
    # Works around an AttributeError raised by `transformers`: the library
    # expects this attribute to exist and to be a dict (it calls `.keys()` on
    # it), and nothing in this class sets it. As a class attribute, this single
    # dict object is shared by every instance.
    # NOTE(review): recent `transformers` versions appear to populate this in
    # `PreTrainedModel.post_init()`, which `__init__` below never calls; calling
    # `self.post_init()` may be the proper fix -- confirm against the pinned
    # transformers version before changing.
    all_tied_weights_keys = {}

    def __init__(self, config):
        """Build the projection layer and the decoder from ``config``.

        Reads ``config.hidden_dim`` (512 per the original author's note --
        confirm in ``TopAIImageConfig``), ``config.input_dim`` and
        ``config.image_channels``.
        """
        super().__init__(config)
        h = config.hidden_dim
        # A single linear layer produces the entire initial 4x4 feature map.
        self.text_projection = nn.Linear(config.input_dim, 4 * 4 * h)
        # Channel widths after the first stage are fixed (256/128/64/32) so the
        # layer shapes match the checkpoint exactly.
        self.decoder = nn.Sequential(
            # 0: h -> h, 4x4 -> 8x8 (per the original note, this layer is where
            #    an earlier checkpoint shape mismatch occurred)
            self._upsample(h, h),
            # 1: residual refinement at h channels
            ResidualBlock(h),
            # 2: h -> 256, 8x8 -> 16x16
            self._upsample(h, 256),
            # 3: residual refinement at 256 channels
            ResidualBlock(256),
            # 4: 256 -> 128, 16x16 -> 32x32
            self._upsample(256, 128),
            # 5: 128 -> 64, 32x32 -> 64x64
            self._upsample(128, 64),
            # 6-8: 64 -> 32, 64x64 -> 128x128. Written inline rather than via
            #      _upsample, so these are three separate Sequential entries.
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            # 9: project to image channels (RGB per the original note)
            nn.Conv2d(32, config.image_channels, kernel_size=3, padding=1),
            # 10: squash pixel values into [-1, 1]
            nn.Tanh(),
        )

    @staticmethod
    def _upsample(i, o):
        """Return a ConvTranspose2d -> BatchNorm2d -> ReLU upsampling stage.

        Maps ``i`` input channels to ``o`` output channels and doubles the
        spatial size (kernel 4, stride 2, padding 1). It uses no instance
        state, hence a staticmethod; ``self._upsample(...)`` still works.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(i, o, 4, 2, 1, bias=False),
            nn.BatchNorm2d(o),
            nn.ReLU(True),
        )

    def forward(self, text_embeddings):
        """Generate images from embeddings.

        Args:
            text_embeddings: tensor whose last dimension is ``config.input_dim``.
                Any leading dimensions are folded into the batch by
                ``view(-1, ...)``; a single 1-D vector yields a batch of one.

        Returns:
            Tensor of shape ``(N, config.image_channels, 128, 128)`` with values
            in ``[-1, 1]``, where ``N`` is the product of the leading dimensions.
        """
        # Project, then reshape into the initial (N, hidden_dim, 4, 4) feature map.
        x = self.text_projection(text_embeddings)
        x = x.view(-1, self.config.hidden_dim, 4, 4)
        return self.decoder(x)