| | import os |
| | import torch |
| | import numpy as np |
| | import torch.nn as nn |
| | import lightning.pytorch as pl |
| | from safetensors.torch import save_file, load_file |
| | from transformers import AutoImageProcessor, AutoModel |
| | from transformers.image_utils import load_image |
| |
|
class EmbeddingNetwork(nn.Module):
    """Three-layer MLP head mapping a backbone feature vector to a small embedding.

    With the defaults the layers are ``1280 -> 256 -> 128 -> 7`` with ReLU between
    them, i.e. exactly the sizes this class always hard-coded. Parameter names
    (``fc1``/``fc2``/``fc3``) are unchanged, so existing state dicts still load.

    Args:
        in_features: Size of the last dimension of the input; must equal the hidden
            size of whatever backbone produces the features fed in.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        out_features: Size of the last dimension of the output.
        dropout: Drop probability for ``dropout1`` / ``dropout2``.
    """

    def __init__(self, in_features=1280, hidden1=256, hidden2=128, out_features=7, dropout=0.33):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(hidden2, out_features)
        # ReLU has no parameters, so one instance is shared by both activation sites.
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``fc3(relu(fc2(relu(fc1(x)))))``; the last dim of the result is ``out_features``.

        NOTE(review): ``dropout1`` / ``dropout2`` are registered but not applied here,
        so train and eval mode give identical outputs. Confirm this is intentional
        before wiring them in — doing so would change training and any inference
        run without ``.eval()``.
        """
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        return self.fc3(x)
| |
|
| |
|
| |
|
class PLModule(pl.LightningModule):
    """Lightning wrapper that owns an ``EmbeddingNetwork`` and exposes it for prediction."""

    def __init__(self):
        super().__init__()
        # __init__ has no arguments, so there is currently nothing for this call to capture.
        self.save_hyperparameters()
        self.network = EmbeddingNetwork()

    def forward(self, x):
        """Run ``x`` through the wrapped embedding network and return its output."""
        result = self.network(x)
        return result

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Predict on one batch and return ``(predictions, batch[1])``.

        ``batch[0]`` is the network input; ``batch[1]`` is handed back untouched
        (presumably labels or ids — confirm against the dataloader).
        """
        inputs = batch[0]
        predictions = self.forward(inputs)
        passthrough = batch[1]
        return predictions, passthrough
| |
|
| |
|
| |
|
def main():
    """Compute and print the style embedding of one sample image.

    Pipeline: DINOv3 backbone (``facebook/dinov3-vith16plus-pretrain-lvd1689m``)
    -> CLS token (position 0 of ``last_hidden_state``) -> ``EmbeddingNetwork``
    head whose weights come from ``Style Embedder v4.safetensors``.

    The Hugging Face access token is read from the ``HF_TOKEN`` environment
    variable rather than being pasted into the source. When it is unset,
    ``token=None`` is passed and the Hub client uses its default credential lookup.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    embd_model = EmbeddingNetwork().to(device=device, dtype=torch.bfloat16)
    state_dict = load_file("Style Embedder v4.safetensors")
    embd_model.load_state_dict(state_dict)
    # Inference only: make sure any train-mode-only behaviour of the head is disabled.
    embd_model.eval()

    # Never hard-code access tokens in source files.
    token = os.environ.get("HF_TOKEN")
    # Processor and backbone come from the same checkpoint so preprocessing always
    # matches the model, and only one gated repository needs to be accessible.
    dino_checkpoint = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
    processor = AutoImageProcessor.from_pretrained(dino_checkpoint, do_resize=False, token=token)
    dino_model = AutoModel.from_pretrained(dino_checkpoint, token=token, device_map="auto",
                                           dtype=torch.bfloat16)

    image = load_image("images_for_style_embedding/6857740.webp")
    inputs = processor(images=image, return_tensors="pt").to(device=dino_model.device, dtype=torch.bfloat16)

    # No gradients are needed; inference_mode avoids building an autograd graph.
    with torch.inference_mode():
        output = dino_model(**inputs)
        last_hidden_states = output.last_hidden_state
        # device_map="auto" decides where the backbone's outputs live, so move the
        # CLS token onto the head's device explicitly.
        cls_token = last_hidden_states[:, 0, :].to(device)
        pred = embd_model(cls_token).cpu()
    print(pred)


if __name__ == "__main__":
    main()
| |
|