File size: 9,212 Bytes
acd771b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""Raon-VisionEncoder model."""

import importlib
import os
import sys
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel

from .configuration_raonve import RaonVEConfig


# Repo id recorded by from_pretrained(); _ensure_raon_package() reads it to
# know which HF Hub repository to download the modeling code from.
_raon_repo_id = None

def set_repo_id(repo_id):
    """Remember *repo_id* as the repository to fetch raon_vision_encoder from."""
    global _raon_repo_id
    _raon_repo_id = repo_id

def _ensure_raon_package():
    """Import raon_vision_encoder, downloading it from the HF Hub if needed.

    Returns:
        The ``CustomTextCLIP`` class from ``raon_vision_encoder.clip``.
    """
    try:
        clip_mod = importlib.import_module("raon_vision_encoder.clip")
        return clip_mod.CustomTextCLIP
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        pass

    # Local import: huggingface_hub is only required on the download path.
    from huggingface_hub import snapshot_download
    repo_id = _raon_repo_id or "KRAFTON/Raon-VisionEncoder"
    repo_dir = snapshot_download(repo_id, allow_patterns=["raon_vision_encoder/**"])
    # Avoid growing sys.path with duplicates when called repeatedly.
    if repo_dir not in sys.path:
        sys.path.insert(0, repo_dir)

    # Drop any partially-imported modules so the downloaded code is picked up.
    for key in list(sys.modules):
        if key.startswith("raon_vision_encoder"):
            del sys.modules[key]

    clip_mod = importlib.import_module("raon_vision_encoder.clip")
    return clip_mod.CustomTextCLIP


class RaonVEPreTrainedModel(PreTrainedModel):
    """Base class wiring RaonVEConfig into the transformers loading machinery."""

    config_class = RaonVEConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        # Weights are loaded fully initialized from the checkpoint;
        # no random initialization is needed here.
        pass


class RaonVEModel(RaonVEPreTrainedModel):
    """CLIP-style dual encoder: a timm vision tower plus a text transformer.

    The submodules are built by ``CustomTextCLIP`` from the dynamically
    downloaded ``raon_vision_encoder`` package and re-attached here so the
    model behaves like a regular ``PreTrainedModel``.
    """

    config_class = RaonVEConfig

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Record the repo so _ensure_raon_package() downloads the matching
        # modeling code before __init__ runs.
        set_repo_id(str(pretrained_model_name_or_path))
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    def __init__(self, config: RaonVEConfig):
        super().__init__(config)

        # Flatten the HF config objects into the plain dicts that
        # CustomTextCLIP expects for its vision/text towers.
        vision_cfg = {
            "image_size": config.vision_config.image_size,
            "timm_model_name": config.vision_config.timm_model_name,
            "timm_model_pretrained": config.vision_config.timm_model_pretrained,
            "timm_pool": config.vision_config.timm_pool,
            "timm_proj": config.vision_config.timm_proj,
        }
        text_cfg = {
            "context_length": config.text_config.context_length,
            "vocab_size": config.text_config.vocab_size,
            "width": config.text_config.width,
            "heads": config.text_config.heads,
            "layers": config.text_config.layers,
            "mlp_ratio": config.text_config.mlp_ratio,
            "no_causal_mask": config.text_config.no_causal_mask,
            "proj_bias": config.text_config.proj_bias,
            "pool_type": config.text_config.pool_type,
            "hf_tokenizer_name": config.text_config.hf_tokenizer_name,
            "tokenizer_kwargs": config.text_config.tokenizer_kwargs,
            "norm_kwargs": config.text_config.norm_kwargs,
            "act_kwargs": config.text_config.act_kwargs,
        }

        CustomTextCLIP = _ensure_raon_package()
        inner = CustomTextCLIP(
            embed_dim=config.embed_dim,
            vision_cfg=vision_cfg,
            text_cfg=text_cfg,
            init_logit_bias=config.init_logit_bias,
        )

        # Re-parent the submodules/parameters onto this PreTrainedModel so
        # state_dict keys line up with the checkpoint layout.
        self.visual = inner.visual
        self.text = inner.text
        self.logit_scale = inner.logit_scale
        self.logit_bias = inner.logit_bias

        # Enable NaFlex (variable-resolution patch sequences) by default.
        self.visual._setup_1d_forward()

        self.post_init()

    def encode_image(self, pixel_values, pixel_attention_mask=None, spatial_shapes=None):
        """Encode images to normalized feature vectors [B, 1152].
        Pass the output of processor(images=...) directly via **inputs.
        """
        kwargs = {}
        if pixel_attention_mask is not None:
            kwargs["patch_valid_mask"] = pixel_attention_mask
        if spatial_shapes is not None:
            kwargs["spatial_shapes"] = spatial_shapes
        # An empty **kwargs is a no-op, so a single call path suffices.
        features = self.visual(pixel_values, **kwargs)
        return F.normalize(features, dim=-1)

    def encode_text(self, input_ids):
        """Encode text to normalized feature vectors [B, 1152].
        Pass the output of processor(text=...) directly via **inputs.
        """
        features = self.text(input_ids)
        return F.normalize(features, dim=-1)

    def forward(self, pixel_values=None, input_ids=None, pixel_attention_mask=None, spatial_shapes=None):
        """Run either or both encoders.

        Returns:
            Dict with 'image_features' and 'text_features' (None when the
            corresponding input is absent), plus 'logit_scale' and 'logit_bias'.
        """
        image_features = None
        text_features = None

        if pixel_values is not None:
            image_features = self.encode_image(
                pixel_values,
                pixel_attention_mask=pixel_attention_mask,
                spatial_shapes=spatial_shapes,
            )
        if input_ids is not None:
            text_features = self.encode_text(input_ids)

        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }

    @staticmethod
    def get_processor(pretrained_model_name_or_path, **kwargs):
        """Get the processor for this model."""
        return RaonVEProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)


class RaonVEProcessor:
    """Image and text processor for Raon-VisionEncoder.

    Preprocesses images into NaFlex patch sequences and tokenizes text.

    Args:
        patch_size: Side length (pixels) of the square patches.
        mean: Per-channel normalization mean.
        std: Per-channel normalization std.
        tokenizer: Callable mapping a list of strings to input ids; may be
            None when the processor is used for images only.
        max_num_patches: Maximum number of patches per image (controls resolution).
            Higher values preserve more detail. Default: 256.
    """

    DEFAULT_MAX_PATCHES = 256

    def __init__(self, patch_size=16, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), tokenizer=None):
        # Local import keeps torchvision optional at module-import time.
        from torchvision import transforms as T
        self.patch_size = patch_size
        self.mean, self.std = mean, std
        self.tokenizer = tokenizer
        self._post = T.Compose([T.ToTensor(), T.Normalize(mean=list(mean), std=list(std))])

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a processor from a local directory or an HF Hub repo id."""
        import json
        from pathlib import Path as _Path
        if _Path(pretrained_model_name_or_path).is_dir():
            cfg_path = _Path(pretrained_model_name_or_path) / "config.json"
        else:
            from huggingface_hub import hf_hub_download
            cfg_path = hf_hub_download(pretrained_model_name_or_path, "config.json")
        with open(cfg_path) as f:
            cfg = json.load(f)
        v = cfg.get("vision_config", {})
        t = cfg.get("text_config", {})
        # Recover the patch size from timm model names like "..._patch16_...".
        ps = 16
        for part in v.get("timm_model_name", "").split("_"):
            if part.startswith("patch") and part[5:].isdigit():
                ps = int(part[5:])
                break
        tokenizer = None
        if t.get("hf_tokenizer_name"):
            _ensure_raon_package()
            tok_mod = importlib.import_module("raon_vision_encoder.tokenizer")
            tokenizer = tok_mod.HFTokenizer(
                t["hf_tokenizer_name"], context_length=t.get("context_length", 64),
                tokenizer_mode=t.get("tokenizer_mode"), **t.get("tokenizer_kwargs", {}),
            )
        return cls(patch_size=ps, tokenizer=tokenizer)

    def __call__(self, images=None, text=None, max_num_patches=None, return_tensors="pt"):
        """Process images and/or text.

        Args:
            images: PIL Image or list of PIL Images.
            text: String or list of strings.
            max_num_patches: Resolution budget (default: 256). Higher = more detail.
            return_tensors: Accepted for HF API compatibility; outputs are
                always PyTorch tensors.

        Returns:
            Dict with 'pixel_values', 'pixel_attention_mask', 'spatial_shapes' for images
            and/or 'input_ids' for text.

        Raises:
            RuntimeError: If text is given but no tokenizer was configured.
        """
        result = {}
        if images is not None:
            # Lazy import: text-only usage must not require Pillow.
            from PIL import Image
            mnp = max_num_patches or self.DEFAULT_MAX_PATCHES
            _ensure_raon_package()
            transform_mod = importlib.import_module("raon_vision_encoder.transform")
            get_size = transform_mod.get_image_size_for_max_num_patches
            imgs = [images] if isinstance(images, Image.Image) else images
            ps = self.patch_size
            all_p, all_m, all_s = [], [], []
            for img in imgs:
                img = img.convert("RGB")
                w, h = img.size
                # Target (height, width) within the patch budget — helper from
                # raon_vision_encoder; presumably returns multiples of ps.
                th, tw = get_size(h, w, ps, mnp)
                t = self._post(img.resize((tw, th), Image.BICUBIC))
                gh, gw = th // ps, tw // ps
                n = gh * gw
                # [C, gh, ps, gw, ps] -> [gh, gw, C, ps, ps] -> [n, C*ps*ps]
                patches = t.reshape(3, gh, ps, gw, ps).permute(1, 3, 0, 2, 4).reshape(n, 3 * ps * ps)
                # Pad every image to the same sequence length; the boolean
                # mask marks which patch slots hold real data.
                padded = torch.zeros(mnp, ps * ps * 3)
                padded[:n] = patches
                mask = torch.zeros(mnp, dtype=torch.bool)
                mask[:n] = True
                all_p.append(padded)
                all_m.append(mask)
                all_s.append(torch.tensor([gh, gw]))
            result["pixel_values"] = torch.stack(all_p)
            result["pixel_attention_mask"] = torch.stack(all_m)
            result["spatial_shapes"] = torch.stack(all_s)
        if text is not None:
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not initialized.")
            result["input_ids"] = self.tokenizer([text] if isinstance(text, str) else text)
        return result