import torch import torch.nn as nn from transformers import PreTrainedModel from .configuration_pixel import TopAIImageConfig class ResidualBlock(nn.Module): def __init__(self, channels): super().__init__() self.block = nn.Sequential( nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(channels), nn.ReLU(True), nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(channels) ) def forward(self, x): return x + self.block(x) class TopAIImageGenerator(PreTrainedModel): config_class = TopAIImageConfig # תיקון ה-AttributeError: חייב להיות מילון (dict) כדי שתהיה לו מתודה .keys() all_tied_weights_keys = {} def __init__(self, config): super().__init__(config) # שימוש ב-hidden_dim מהקונפיג (512) h = config.hidden_dim self.text_projection = nn.Linear(config.input_dim, 4 * 4 * h) # בנייה דינמית שמתאימה בדיוק למשקולות ב-Safetensors self.decoder = nn.Sequential( # שכבה 0: מ-512 ל-512 (כאן היה ה-Mismatch) self._upsample(h, h), # שכבה 1 ResidualBlock(h), # שכבה 2: מ-512 ל-256 self._upsample(h, 256), # שכבה 3 ResidualBlock(256), # שכבה 4: מ-256 ל-128 self._upsample(256, 128), # שכבה 5: מ-128 ל-64 self._upsample(128, 64), # שכבה 6: המעבר הסופי ל-32 פילטרים nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), nn.BatchNorm2d(32), nn.ReLU(True), # שכבה 9: המרה לערוצי תמונה (RGB) nn.Conv2d(32, config.image_channels, kernel_size=3, padding=1), nn.Tanh() ) def _upsample(self, i, o): return nn.Sequential( nn.ConvTranspose2d(i, o, 4, 2, 1, bias=False), nn.BatchNorm2d(o), nn.ReLU(True) ) def forward(self, text_embeddings): # שינוי צורה למפת מאפיינים ראשונית x = self.text_projection(text_embeddings) x = x.view(-1, self.config.hidden_dim, 4, 4) return self.decoder(x)