"""Raon-VisionEncoder model."""
import importlib
import os
import sys
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from .configuration_raonve import RaonVEConfig
# HF Hub repo id that _ensure_raon_package downloads from; set via set_repo_id().
# None means the default "KRAFTON/Raon-VisionEncoder" repo is used.
_raon_repo_id = None
def set_repo_id(repo_id):
    """Record the HF Hub repo id that _ensure_raon_package should download from."""
    global _raon_repo_id
    _raon_repo_id = repo_id
def _ensure_raon_package():
"""Import raon_vision_encoder, downloading from HF Hub if needed."""
try:
clip_mod = importlib.import_module("raon_vision_encoder.clip")
return clip_mod.CustomTextCLIP
except (ImportError, ModuleNotFoundError):
pass
from huggingface_hub import snapshot_download
repo_id = _raon_repo_id or "KRAFTON/Raon-VisionEncoder"
repo_dir = snapshot_download(repo_id, allow_patterns=["raon_vision_encoder/**"])
sys.path.insert(0, repo_dir)
for key in list(sys.modules.keys()):
if key.startswith("raon_vision_encoder"):
del sys.modules[key]
clip_mod = importlib.import_module("raon_vision_encoder.clip")
return clip_mod.CustomTextCLIP
class RaonVEPreTrainedModel(PreTrainedModel):
    """Base class hooking RaonVEConfig into transformers' PreTrainedModel machinery."""

    config_class = RaonVEConfig
    # No prefix: checkpoint keys map directly onto this module's attributes.
    base_model_prefix = ""
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        # Deliberate no-op: submodules are constructed fully initialized by
        # CustomTextCLIP (see RaonVEModel.__init__), so no re-init is wanted.
        pass
class RaonVEModel(RaonVEPreTrainedModel):
    """CLIP-style dual encoder wrapping raon_vision_encoder's CustomTextCLIP.

    Exposes the vision tower (``visual``), text tower (``text``) and the
    learned logit scale/bias of the wrapped model as direct attributes.
    """

    config_class = RaonVEConfig

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Remember which repo the raon_vision_encoder package should be
        # fetched from before the parent loader triggers __init__.
        set_repo_id(str(pretrained_model_name_or_path))
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    def __init__(self, config: RaonVEConfig):
        super().__init__(config)
        vc = config.vision_config
        vision_cfg = {
            "image_size": vc.image_size,
            "timm_model_name": vc.timm_model_name,
            "timm_model_pretrained": vc.timm_model_pretrained,
            "timm_pool": vc.timm_pool,
            "timm_proj": vc.timm_proj,
        }
        tc = config.text_config
        text_cfg = {
            key: getattr(tc, key)
            for key in (
                "context_length",
                "vocab_size",
                "width",
                "heads",
                "layers",
                "mlp_ratio",
                "no_causal_mask",
                "proj_bias",
                "pool_type",
                "hf_tokenizer_name",
                "tokenizer_kwargs",
                "norm_kwargs",
                "act_kwargs",
            )
        }
        CustomTextCLIP = _ensure_raon_package()
        clip_model = CustomTextCLIP(
            embed_dim=config.embed_dim,
            vision_cfg=vision_cfg,
            text_cfg=text_cfg,
            init_logit_bias=config.init_logit_bias,
        )
        # Re-home the wrapped model's submodules/parameters onto this module.
        self.visual = clip_model.visual
        self.text = clip_model.text
        self.logit_scale = clip_model.logit_scale
        self.logit_bias = clip_model.logit_bias
        # Enable NaFlex (variable-resolution 1D patch sequences) by default.
        self.visual._setup_1d_forward()
        self.post_init()

    def encode_image(self, pixel_values, pixel_attention_mask=None, spatial_shapes=None):
        """Encode images to normalized feature vectors [B, 1152].

        Pass the output of processor(images=...) directly via **inputs.
        """
        extra = {}
        if pixel_attention_mask is not None:
            extra["patch_valid_mask"] = pixel_attention_mask
        if spatial_shapes is not None:
            extra["spatial_shapes"] = spatial_shapes
        return F.normalize(self.visual(pixel_values, **extra), dim=-1)

    def encode_text(self, input_ids):
        """Encode text to normalized feature vectors [B, 1152].

        Pass the output of processor(text=...) directly via **inputs.
        """
        return F.normalize(self.text(input_ids), dim=-1)

    def forward(self, pixel_values=None, input_ids=None, pixel_attention_mask=None, spatial_shapes=None):
        """Encode whichever modalities were provided; absent ones yield None."""
        image_features = (
            self.encode_image(
                pixel_values,
                pixel_attention_mask=pixel_attention_mask,
                spatial_shapes=spatial_shapes,
            )
            if pixel_values is not None
            else None
        )
        text_features = self.encode_text(input_ids) if input_ids is not None else None
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }

    @staticmethod
    def get_processor(pretrained_model_name_or_path, **kwargs):
        """Get the processor for this model."""
        return RaonVEProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)
class RaonVEProcessor:
    """Image and text processor for Raon-VisionEncoder.

    Preprocesses images into NaFlex patch sequences and tokenizes text.

    Args:
        max_num_patches: Maximum number of patches per image (controls resolution).
            Higher values preserve more detail. Default: 256.
    """

    DEFAULT_MAX_PATCHES = 256

    def __init__(self, patch_size=16, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), tokenizer=None):
        from torchvision import transforms as T

        self.patch_size = patch_size
        self.mean, self.std = mean, std
        self.tokenizer = tokenizer
        # PIL image -> float tensor in [0, 1] -> per-channel normalization.
        self._post = T.Compose([T.ToTensor(), T.Normalize(mean=list(mean), std=list(std))])

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a processor from a local checkpoint directory or a HF Hub repo id."""
        import json
        from pathlib import Path as _Path

        root = _Path(pretrained_model_name_or_path)
        if root.is_dir():
            cfg_path = root / "config.json"
        else:
            from huggingface_hub import hf_hub_download

            cfg_path = hf_hub_download(pretrained_model_name_or_path, "config.json")
        with open(cfg_path) as f:
            cfg = json.load(f)
        vision = cfg.get("vision_config", {})
        text = cfg.get("text_config", {})
        # Recover the patch size from the timm model name (e.g. "..._patch16_...").
        patch_size = 16
        for piece in vision.get("timm_model_name", "").split("_"):
            if piece.startswith("patch") and piece[5:].isdigit():
                patch_size = int(piece[5:])
                break
        tokenizer = None
        if text.get("hf_tokenizer_name"):
            _ensure_raon_package()
            tok_mod = importlib.import_module("raon_vision_encoder.tokenizer")
            tokenizer = tok_mod.HFTokenizer(
                text["hf_tokenizer_name"],
                context_length=text.get("context_length", 64),
                tokenizer_mode=text.get("tokenizer_mode"),
                **text.get("tokenizer_kwargs", {}),
            )
        return cls(patch_size=patch_size, tokenizer=tokenizer)

    def __call__(self, images=None, text=None, max_num_patches=None, return_tensors="pt"):
        """Process images and/or text.

        Args:
            images: PIL Image or list of PIL Images.
            text: String or list of strings.
            max_num_patches: Resolution budget (default: 256). Higher = more detail.

        Returns:
            Dict with 'pixel_values', 'pixel_attention_mask', 'spatial_shapes' for images
            and/or 'input_ids' for text.
        """
        from PIL import Image

        out = {}
        if images is not None:
            budget = max_num_patches or self.DEFAULT_MAX_PATCHES
            _ensure_raon_package()
            transform_mod = importlib.import_module("raon_vision_encoder.transform")
            get_size = transform_mod.get_image_size_for_max_num_patches
            image_list = [images] if isinstance(images, Image.Image) else images
            ps = self.patch_size
            patch_rows, valid_masks, grid_shapes = [], [], []
            for image in image_list:
                rgb = image.convert("RGB")
                w, h = rgb.size
                th, tw = get_size(h, w, ps, budget)
                tensor = self._post(rgb.resize((tw, th), Image.BICUBIC))
                gh, gw = th // ps, tw // ps
                n_patches = gh * gw
                # [C, gh, ps, gw, ps] -> [gh, gw, C, ps, ps] -> [n, C*ps*ps]
                flat = tensor.reshape(3, gh, ps, gw, ps).permute(1, 3, 0, 2, 4).reshape(n_patches, 3 * ps * ps)
                padded = torch.zeros(budget, ps * ps * 3)
                padded[:n_patches] = flat
                mask = torch.zeros(budget, dtype=torch.bool)
                mask[:n_patches] = True
                patch_rows.append(padded)
                valid_masks.append(mask)
                grid_shapes.append(torch.tensor([gh, gw]))
            out["pixel_values"] = torch.stack(patch_rows)
            out["pixel_attention_mask"] = torch.stack(valid_masks)
            out["spatial_shapes"] = torch.stack(grid_shapes)
        if text is not None:
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not initialized.")
            out["input_ids"] = self.tokenizer([text] if isinstance(text, str) else text)
        return out
|