File size: 3,326 Bytes
f9cae72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import bisect
import json
import math
import os

import numpy as np
import torch
from PIL import Image


class SingleValueEncoder(torch.nn.Module):
    """Encode a scalar conditioning value into a sequence of token embeddings.

    The scalar is turned into a sinusoidal (timestep-style) embedding,
    projected through a two-layer MLP, broadcast across ``length`` positions,
    and offset by a learned positional embedding.
    """

    def __init__(self, dim_in=256, dim_out=4096, length=32):
        """
        Args:
            dim_in: Width of the sinusoidal embedding fed to the MLP.
            dim_out: Width of each output token embedding.
            length: Number of output tokens.
        """
        super().__init__()
        self.dim_in = dim_in
        self.length = length
        self.prefer_value_embedder = torch.nn.Sequential(
            torch.nn.Linear(dim_in, dim_out),
            torch.nn.SiLU(),
            torch.nn.Linear(dim_out, dim_out),
        )
        self.positional_embedding = torch.nn.Parameter(torch.randn(self.length, dim_out))

    def get_timestep_embedding(self, timesteps, embedding_dim, max_period=10000):
        """Return a (batch, embedding_dim) sinusoidal embedding of ``timesteps``.

        The output is ``[cos | sin]`` halves; ``embedding_dim`` is assumed even.
        Computed in float32 regardless of the input dtype.
        """
        half_dim = embedding_dim // 2
        exponent = -math.log(max_period) * torch.arange(
            0, half_dim, dtype=torch.float32, device=timesteps.device
        ) / half_dim
        emb = timesteps[:, None].float() * torch.exp(exponent)[None, :]
        emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
        return emb

    def forward(self, value, dtype):
        """Map a 1-element ``value`` tensor to a (length, dim_out) embedding.

        Args:
            value: 1-D tensor holding the scalar conditioning value.
            dtype: Target dtype for the output embeddings.
        """
        # Bug fix: the sinusoidal width was hard-coded to 256, which broke any
        # instance constructed with a non-default ``dim_in`` (the first Linear
        # layer expects ``dim_in`` features). Use ``self.dim_in`` instead.
        emb = self.get_timestep_embedding(value * 1000, self.dim_in).to(dtype)
        emb = self.prefer_value_embedder(emb).squeeze(0)
        # Broadcast the single embedding across all positions, then add the
        # learned per-position offsets.
        base_embeddings = emb.expand(self.length, -1)
        positional_embedding = self.positional_embedding.to(dtype=base_embeddings.dtype, device=base_embeddings.device)
        learned_embeddings = base_embeddings + positional_embedding
        return learned_embeddings


class ValueFormatModel(torch.nn.Module):
    """Produce per-attention-block key/value caches from a scalar value.

    Every block (named ``double_i`` / ``single_i``) owns its own pair of
    ``SingleValueEncoder`` projections, one for keys and one for values.
    """

    def __init__(self, num_double_blocks=5, num_single_blocks=20, dim=3072, num_heads=24, length=512):
        super().__init__()
        double_names = [f"double_{i}" for i in range(num_double_blocks)]
        single_names = [f"single_{i}" for i in range(num_single_blocks)]
        self.block_names = double_names + single_names
        self.proj_k = torch.nn.ModuleDict(
            {name: SingleValueEncoder(dim_out=dim, length=length) for name in self.block_names}
        )
        self.proj_v = torch.nn.ModuleDict(
            {name: SingleValueEncoder(dim_out=dim, length=length) for name in self.block_names}
        )
        self.num_heads = num_heads
        self.length = length

    @torch.no_grad()
    def process_inputs(self, pipe, scale, **kwargs):
        """Wrap ``scale`` as a 1-element tensor on the pipeline's dtype/device."""
        value = torch.Tensor([scale]).to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"value": value}

    def forward(self, value, **kwargs):
        """Return ``{"kv_cache": {block_name: (k, v)}}``.

        Each ``k``/``v`` is reshaped to (1, length, num_heads, head_dim).
        """
        def project(encoder, inp):
            # Encode the scalar, then split the channel dim across heads.
            return encoder(inp, inp.dtype).view(1, self.length, self.num_heads, -1)

        kv_cache = {
            name: (project(self.proj_k[name], value), project(self.proj_v[name], value))
            for name in self.block_names
        }
        return {"kv_cache": kv_cache}


class DataAnnotator:
    """Annotate an image with a normalized edge-density score.

    ``scores.json`` (located next to this file) is expected to hold an
    ascending list of reference edge-density values; an image's raw score is
    ranked against that distribution to yield a value in [0, 1].
    """

    def __init__(self):
        # NOTE(review): assumes scores.json contains a sorted (ascending)
        # list of numbers — get_score relies on that ordering.
        with open(os.path.join(os.path.dirname(__file__), "scores.json"), "r") as f:
            self.scores = json.load(f)

    def get_score(self, x):
        """Return the fraction of reference scores strictly below ``x``.

        Idiom fix: the original hand-rolled a binary search that is exactly
        ``bisect.bisect_left``; use the stdlib implementation instead.
        """
        return bisect.bisect_left(self.scores, x) / len(self.scores)

    def __call__(self, image, **kwargs):
        """Compute ``{"scale": ...}`` for an image path via Canny edge density.

        Args:
            image: Path to an image file readable by OpenCV.
        """
        import cv2  # local import: OpenCV is only needed when annotating
        image = cv2.imread(image, cv2.IMREAD_GRAYSCALE)
        edges = cv2.Canny(image, 100, 200)
        # Mean edge intensity is the raw "detail" measure, then ranked
        # against the reference distribution.
        scale = edges.astype(np.float32).mean().tolist()
        return {"scale": self.get_score(scale)}


# Module-level exports: the conditioning model class, its default checkpoint
# filename, and the dataset annotator. NOTE(review): presumably consumed by a
# framework template loader keyed on these names — consumer not visible here.
TEMPLATE_MODEL = ValueFormatModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator