"""Standalone PaddleOCR-VL vision encoder artifacts.

Extracts the vision tower and projector from a full PaddleOCR-VL checkpoint and wraps
them as standalone modules (split and combined exports).
"""
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.processing_utils import BatchFeature
from .configuration_paddleocr_vl import PaddleOCRVLConfig
from .image_processing_paddleocr_vl import PaddleOCRVLImageProcessor
from .modeling_paddleocr_vl import PaddleOCRVisionModel, Projector
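
# Artifact file names and weight-key prefixes. In the full PaddleOCR-VL checkpoint the
# vision tower lives under "visual." and the projector under "mlp_AR."; the standalone
# exports produced here keep "visual." and rename the projector prefix to "projector.".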
VISION_TOWER_CONFIG_NAME = "vision_tower_config.json"
VISION_TOWER_WEIGHTS_NAME = "vision_tower.safetensors"
PROJECTOR_CONFIG_NAME = "projector_config.json"
PROJECTOR_WEIGHTS_NAME = "projector.safetensors"
FULL_MODEL_CONFIG_NAME = "config.json"
FULL_MODEL_WEIGHTS_NAME = "model.safetensors"
FULL_VISUAL_PREFIX = "visual."
FULL_PROJECTOR_PREFIX = "mlp_AR."
STANDALONE_VISUAL_PREFIX = "visual."
STANDALONE_PROJECTOR_PREFIX = "projector."
IMAGE_PROCESSOR_TEMPORAL_PATCH_SIZE = 1
def _read_json(path: Union[str, Path]) -> Dict[str, Any]:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
def _write_json(path: Union[str, Path], payload: Dict[str, Any]) -> None:
with open(path, "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2, ensure_ascii=False)
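
# Grid specs may arrive as a [num_images, 3] tensor or as a nested sequence; both are
# normalized to a list of (t, h, w) int tuples, e.g. torch.tensor([[1, 32, 32]]) and
# [(1, 32, 32)] both become [(1, 32, 32)] (illustrative values).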
def _normalize_image_grid_thw(
image_grid_thw: Union[torch.Tensor, Sequence[Any]]
) -> List[Tuple[int, int, int]]:
if isinstance(image_grid_thw, torch.Tensor):
return [tuple(int(v) for v in row.tolist()) for row in image_grid_thw]
normalized: List[Tuple[int, int, int]] = []
for item in image_grid_thw:
if isinstance(item, torch.Tensor):
normalized.append(tuple(int(v) for v in item.tolist()))
else:
normalized.append(tuple(int(v) for v in item))
return normalized
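
# The three build_*_export_config helpers below emit self-describing configs for the
# combined encoder, the vision tower alone, and the projector alone. Each embeds the
# original full-model config under "full_model_config" so a standalone module can be
# reconstructed without the original config.json.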
def build_vision_encoder_export_config(
full_config: Union[PaddleOCRVLConfig, Dict[str, Any]]
) -> Dict[str, Any]:
if isinstance(full_config, PaddleOCRVLConfig):
full_config_dict = full_config.to_dict()
else:
full_config_dict = dict(full_config)
vision_config = dict(full_config_dict["vision_config"])
return {
"model_type": "paddleocr_vl_vision_encoder",
"architectures": ["PaddleOCRVLVisionEncoder"],
"source_model_type": full_config_dict.get("model_type", "paddleocr_vl"),
"source_architecture": "PaddleOCRVLForConditionalGeneration",
"text_hidden_size": full_config_dict["hidden_size"],
"image_token_id": full_config_dict.get("image_token_id"),
"vision_start_token_id": full_config_dict.get("vision_start_token_id"),
"vision_end_token_id": full_config_dict.get("vision_end_token_id"),
"torch_dtype": full_config_dict.get("torch_dtype"),
"vision_config": vision_config,
"projector": {
"merge_kernel_size": [2, 2],
"input_hidden_size": vision_config["hidden_size"],
"output_hidden_size": full_config_dict["hidden_size"],
},
"required_weight_prefixes": [
STANDALONE_VISUAL_PREFIX,
STANDALONE_PROJECTOR_PREFIX,
],
"source_weight_prefixes": {
"visual": FULL_VISUAL_PREFIX,
"projector": FULL_PROJECTOR_PREFIX,
},
"full_model_config": full_config_dict,
}
def build_vision_tower_export_config(
full_config: Union[PaddleOCRVLConfig, Dict[str, Any]]
) -> Dict[str, Any]:
combined = build_vision_encoder_export_config(full_config)
return {
"model_type": "paddleocr_vl_vision_tower",
"architectures": ["PaddleOCRVLVisionTower"],
"torch_dtype": combined.get("torch_dtype"),
"vision_config": combined["vision_config"],
"required_weight_prefixes": [STANDALONE_VISUAL_PREFIX],
"source_weight_prefixes": {"visual": FULL_VISUAL_PREFIX},
"full_model_config": combined["full_model_config"],
}
def build_projector_export_config(
full_config: Union[PaddleOCRVLConfig, Dict[str, Any]]
) -> Dict[str, Any]:
combined = build_vision_encoder_export_config(full_config)
return {
"model_type": "paddleocr_vl_projector",
"architectures": ["PaddleOCRVLProjector"],
"torch_dtype": combined.get("torch_dtype"),
"vision_config": combined["vision_config"],
"text_hidden_size": combined["text_hidden_size"],
"projector": combined["projector"],
"required_weight_prefixes": [STANDALONE_PROJECTOR_PREFIX],
"source_weight_prefixes": {"projector": FULL_PROJECTOR_PREFIX},
"full_model_config": combined["full_model_config"],
}
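
# Splits a full-model state dict into vision-tower and projector parts, e.g. (hypothetical
# key names) "visual.encoder.layers.0.self_attn.q_proj.weight" is kept as-is while
# "mlp_AR.0.weight" becomes "projector.0.weight". Raises if either group is empty.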
def remap_full_model_state_dict_to_vision_encoder_parts(
full_state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, List[str]]]:
visual_state_dict: Dict[str, torch.Tensor] = {}
projector_state_dict: Dict[str, torch.Tensor] = {}
consumed_visual: List[str] = []
consumed_projector: List[str] = []
for key, value in full_state_dict.items():
if key.startswith(FULL_VISUAL_PREFIX):
new_key = STANDALONE_VISUAL_PREFIX + key[len(FULL_VISUAL_PREFIX) :]
visual_state_dict[new_key] = value
consumed_visual.append(key)
elif key.startswith(FULL_PROJECTOR_PREFIX):
new_key = STANDALONE_PROJECTOR_PREFIX + key[len(FULL_PROJECTOR_PREFIX) :]
projector_state_dict[new_key] = value
consumed_projector.append(key)
if not consumed_visual:
raise ValueError("No visual.* weights were found in the full model state dict.")
if not consumed_projector:
raise ValueError("No mlp_AR.* weights were found in the full model state dict.")
return visual_state_dict, projector_state_dict, {
"visual": sorted(consumed_visual),
"projector": sorted(consumed_projector),
}
def remap_full_model_state_dict_to_vision_encoder(
full_state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, List[str]]]:
visual_state_dict, projector_state_dict, consumed = (
remap_full_model_state_dict_to_vision_encoder_parts(full_state_dict)
)
    remapped: Dict[str, torch.Tensor] = {}
remapped.update(visual_state_dict)
remapped.update(projector_state_dict)
return remapped, consumed
def _load_safetensors_state_dict(path: Union[str, Path]) -> Dict[str, torch.Tensor]:
try:
from safetensors.torch import load_file
except ImportError as e:
raise RuntimeError(
"Loading safetensors requires the `safetensors` package to be installed."
) from e
return load_file(str(path))
def _save_safetensors_state_dict(
state_dict: Dict[str, torch.Tensor], path: Union[str, Path]
) -> None:
try:
from safetensors.torch import save_file
except ImportError as e:
raise RuntimeError(
"Saving safetensors requires the `safetensors` package to be installed."
) from e
save_file(state_dict, str(path))
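
# Writes the split artifacts into `output_dir`:
#   vision_tower_config.json / vision_tower.safetensors
#   projector_config.json    / projector.safetensors
#   combined/vision_encoder_config.json / combined/vision_encoder.safetensors
# and returns a metadata dict with the written paths and the consumed source keys.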
def extract_and_save_vision_encoder_artifacts(
full_config: Union[PaddleOCRVLConfig, Dict[str, Any]],
full_state_dict: Dict[str, torch.Tensor],
output_dir: Union[str, Path],
) -> Dict[str, Any]:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
vision_tower_config = build_vision_tower_export_config(full_config)
projector_config = build_projector_export_config(full_config)
visual_state_dict, projector_state_dict, consumed = (
remap_full_model_state_dict_to_vision_encoder_parts(full_state_dict)
)
_save_safetensors_state_dict(
visual_state_dict, output_dir / VISION_TOWER_WEIGHTS_NAME
)
_write_json(output_dir / VISION_TOWER_CONFIG_NAME, vision_tower_config)
_save_safetensors_state_dict(
projector_state_dict, output_dir / PROJECTOR_WEIGHTS_NAME
)
_write_json(output_dir / PROJECTOR_CONFIG_NAME, projector_config)
combined_export_config = build_vision_encoder_export_config(full_config)
combined_state_dict, _ = remap_full_model_state_dict_to_vision_encoder(
full_state_dict
)
combined_dir = output_dir / "combined"
combined_dir.mkdir(parents=True, exist_ok=True)
_save_safetensors_state_dict(
combined_state_dict, combined_dir / "vision_encoder.safetensors"
)
_write_json(combined_dir / "vision_encoder_config.json", combined_export_config)
metadata = {
"vision_tower_config_path": str(output_dir / VISION_TOWER_CONFIG_NAME),
"vision_tower_weights_path": str(output_dir / VISION_TOWER_WEIGHTS_NAME),
"projector_config_path": str(output_dir / PROJECTOR_CONFIG_NAME),
"projector_weights_path": str(output_dir / PROJECTOR_WEIGHTS_NAME),
"combined_config_path": str(combined_dir / "vision_encoder_config.json"),
"combined_weights_path": str(combined_dir / "vision_encoder.safetensors"),
"num_exported_visual_tensors": len(visual_state_dict),
"num_exported_projector_tensors": len(projector_state_dict),
"consumed_full_model_keys": consumed,
}
return metadata
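
# Standalone vision tower. A minimal usage sketch (assuming `export_dir` holds artifacts
# written by `extract_and_save_vision_encoder_artifacts` and `img` is a PIL image):
#
#     tower = PaddleOCRVLVisionTower.from_pretrained(export_dir).eval()
#     with torch.no_grad():
#         out = tower.encode_images([img])
#     out["visual_embeds"]  # pre-projection patch embeddings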
class PaddleOCRVLVisionTower(torch.nn.Module):
def __init__(self, config: PaddleOCRVLConfig):
super().__init__()
self.config = config
self.visual = PaddleOCRVisionModel(config.vision_config)
self.export_config = build_vision_tower_export_config(config)
@staticmethod
def _resolve_full_config(config_payload: Dict[str, Any]) -> PaddleOCRVLConfig:
if config_payload.get("model_type") == "paddleocr_vl_vision_tower":
config_payload = config_payload["full_model_config"]
return PaddleOCRVLConfig(**config_payload)
@classmethod
def from_pretrained(cls, model_dir: Union[str, Path]) -> "PaddleOCRVLVisionTower":
model_dir = Path(model_dir)
config_path = model_dir / VISION_TOWER_CONFIG_NAME
weights_path = model_dir / VISION_TOWER_WEIGHTS_NAME
if config_path.exists():
config_payload = _read_json(config_path)
else:
config_payload = _read_json(model_dir / FULL_MODEL_CONFIG_NAME)
model = cls(cls._resolve_full_config(config_payload))
if weights_path.exists():
state_dict = _load_safetensors_state_dict(weights_path)
else:
full_state_dict = _load_safetensors_state_dict(model_dir / FULL_MODEL_WEIGHTS_NAME)
state_dict, _, _ = remap_full_model_state_dict_to_vision_encoder_parts(
full_state_dict
)
        # strict=False lets key mismatches flow into the explicit check below so the
        # error message names the standalone artifact being loaded.
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            raise RuntimeError(
                f"Failed to load standalone vision tower weights. Missing: {missing}, unexpected: {unexpected}"
            )
return model
def save_pretrained(self, output_dir: Union[str, Path]) -> None:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
_save_safetensors_state_dict(self.state_dict(), output_dir / VISION_TOWER_WEIGHTS_NAME)
_write_json(output_dir / VISION_TOWER_CONFIG_NAME, self.export_config)
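
    # Packs variable-sized images into flat per-patch tensors. Illustrative example with
    # two grids (1, 4, 4) and (1, 2, 2): cu_seqlens becomes [0, 16, 20], sample_indices is
    # 16 zeros followed by 4 ones, and siglip_position_ids restarts at 0 for each image
    # (arange(numel) modulo h * w).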
@staticmethod
def _build_visual_inputs(
pixel_values: torch.Tensor,
image_grid_thw: List[Tuple[int, int, int]],
device: torch.device,
) -> Tuple[
torch.Tensor,
torch.Tensor,
List[Tuple[int, int, int]],
torch.Tensor,
torch.Tensor,
]:
if pixel_values.dim() == 4:
pixel_values = pixel_values.unsqueeze(0)
elif pixel_values.dim() != 5:
raise ValueError(
"pixel_values must have shape [num_patches, C, H, W] or [1, num_patches, C, H, W]."
)
siglip_position_ids = []
sample_indices = []
cu_seqlens = [0]
for idx, thw in enumerate(image_grid_thw):
numel = int(np.prod(thw))
image_position_ids = torch.arange(numel, device=device) % int(np.prod(thw[1:]))
siglip_position_ids.append(image_position_ids)
sample_indices.append(torch.full((numel,), idx, dtype=torch.int64, device=device))
cu_seqlens.append(cu_seqlens[-1] + numel)
if siglip_position_ids:
siglip_position_ids = torch.cat(siglip_position_ids, dim=0)
sample_indices = torch.cat(sample_indices, dim=0)
else:
siglip_position_ids = torch.empty(0, dtype=torch.long, device=device)
sample_indices = torch.empty(0, dtype=torch.long, device=device)
cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
return pixel_values, siglip_position_ids, image_grid_thw, sample_indices, cu_seqlens_tensor
def forward(
self,
pixel_values: torch.Tensor,
image_grid_thw: Union[torch.Tensor, Sequence[Any]],
) -> Dict[str, Any]:
image_grid_thw_list = _normalize_image_grid_thw(image_grid_thw)
vision_dtype = next(self.visual.parameters()).dtype
pixel_values = pixel_values.to(dtype=vision_dtype)
device = pixel_values.device
(
pixel_values_5d,
siglip_position_ids,
image_grid_hws,
sample_indices,
cu_seqlens,
) = self._build_visual_inputs(pixel_values, image_grid_thw_list, device)
vision_outputs: BaseModelOutputWithPooling = self.visual(
pixel_values=pixel_values_5d,
image_grid_thw=image_grid_hws,
position_ids=siglip_position_ids,
vision_return_embed_list=True,
interpolate_pos_encoding=True,
sample_indices=sample_indices,
cu_seqlens=cu_seqlens,
return_pooler_output=False,
use_rope=True,
window_size=-1,
)
return {
"visual_embeds": vision_outputs.last_hidden_state,
"image_grid_thw": image_grid_thw_list,
"siglip_position_ids": siglip_position_ids,
"sample_indices": sample_indices,
"cu_seqlens": cu_seqlens,
}
def encode_images(
self,
images: Any,
image_processor: Optional[PaddleOCRVLImageProcessor] = None,
**processor_kwargs: Any,
) -> Dict[str, Any]:
image_processor = image_processor or PaddleOCRVLImageProcessor(
patch_size=self.config.vision_config.patch_size,
# The current image preprocessing implementation is image-only and asserts
# `temporal_patch_size == 1`, even though the vision model config may store 2.
temporal_patch_size=IMAGE_PROCESSOR_TEMPORAL_PATCH_SIZE,
merge_size=self.config.vision_config.spatial_merge_size,
)
encoded: BatchFeature = image_processor(
images=images, return_tensors="pt", **processor_kwargs
)
return self.forward(
pixel_values=encoded["pixel_values"], image_grid_thw=encoded["image_grid_thw"]
)
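
# Standalone projector (the full model's `mlp_AR` module): maps vision-tower patch
# embeddings into the language model's hidden size. The export config records a 2x2
# `merge_kernel_size`, so each image presumably contributes about (h * w) / 4
# language-side tokens.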
class PaddleOCRVLProjector(torch.nn.Module):
def __init__(self, config: PaddleOCRVLConfig):
super().__init__()
self.config = config
self.projector = Projector(config, config.vision_config)
self.export_config = build_projector_export_config(config)
@staticmethod
def _resolve_full_config(config_payload: Dict[str, Any]) -> PaddleOCRVLConfig:
if config_payload.get("model_type") == "paddleocr_vl_projector":
config_payload = config_payload["full_model_config"]
return PaddleOCRVLConfig(**config_payload)
@classmethod
def from_pretrained(cls, model_dir: Union[str, Path]) -> "PaddleOCRVLProjector":
model_dir = Path(model_dir)
config_path = model_dir / PROJECTOR_CONFIG_NAME
weights_path = model_dir / PROJECTOR_WEIGHTS_NAME
if config_path.exists():
config_payload = _read_json(config_path)
else:
config_payload = _read_json(model_dir / FULL_MODEL_CONFIG_NAME)
model = cls(cls._resolve_full_config(config_payload))
if weights_path.exists():
state_dict = _load_safetensors_state_dict(weights_path)
else:
full_state_dict = _load_safetensors_state_dict(model_dir / FULL_MODEL_WEIGHTS_NAME)
_, state_dict, _ = remap_full_model_state_dict_to_vision_encoder_parts(
full_state_dict
)
        # strict=False lets key mismatches flow into the explicit check below so the
        # error message names the standalone artifact being loaded.
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            raise RuntimeError(
                f"Failed to load standalone projector weights. Missing: {missing}, unexpected: {unexpected}"
            )
return model
def save_pretrained(self, output_dir: Union[str, Path]) -> None:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
_save_safetensors_state_dict(self.state_dict(), output_dir / PROJECTOR_WEIGHTS_NAME)
_write_json(output_dir / PROJECTOR_CONFIG_NAME, self.export_config)
def forward(
self,
visual_embeds: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
image_grid_thw: Union[torch.Tensor, Sequence[Any]],
) -> Dict[str, Any]:
image_grid_thw_list = _normalize_image_grid_thw(image_grid_thw)
image_embeds = self.projector(visual_embeds, image_grid_thw_list)
projector_dtype = next(self.projector.parameters()).dtype
projector_device = next(self.projector.parameters()).device
concat_image_embeds = (
torch.cat(image_embeds, dim=0)
if image_embeds
else torch.empty(
0,
self.config.hidden_size,
device=projector_device,
dtype=projector_dtype,
)
)
return {
"image_embeds": image_embeds,
"concat_image_embeds": concat_image_embeds,
"image_grid_thw": image_grid_thw_list,
}
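
# Convenience wrapper chaining the vision tower and projector: pixel_values -> patch
# embeddings -> text-space image embeddings, mirroring the vision path of the full
# PaddleOCRVLForConditionalGeneration model.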
class PaddleOCRVLVisionEncoder(torch.nn.Module):
def __init__(self, config: PaddleOCRVLConfig):
super().__init__()
self.config = config
self.vision_tower = PaddleOCRVLVisionTower(config)
self.projector = PaddleOCRVLProjector(config)
self.export_config = build_vision_encoder_export_config(config)
@classmethod
def from_pretrained(cls, model_dir: Union[str, Path]) -> "PaddleOCRVLVisionEncoder":
model_dir = Path(model_dir)
config_candidates = [
model_dir / FULL_MODEL_CONFIG_NAME,
model_dir / VISION_TOWER_CONFIG_NAME,
model_dir / PROJECTOR_CONFIG_NAME,
]
config_path = next((path for path in config_candidates if path.exists()), None)
if config_path is None:
raise FileNotFoundError(
"Could not find config.json, vision_tower_config.json, or projector_config.json."
)
config_payload = _read_json(config_path)
        if config_payload.get("model_type") == "paddleocr_vl_vision_tower":
            config = PaddleOCRVLVisionTower._resolve_full_config(config_payload)
        else:
            # Handles both the projector export config and a plain full-model config.json.
            config = PaddleOCRVLProjector._resolve_full_config(config_payload)
model = cls(config)
model.vision_tower = PaddleOCRVLVisionTower.from_pretrained(model_dir)
model.projector = PaddleOCRVLProjector.from_pretrained(model_dir)
return model
def forward(
self,
pixel_values: torch.Tensor,
image_grid_thw: Union[torch.Tensor, Sequence[Any]],
) -> Dict[str, Any]:
vision_outputs = self.vision_tower(
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
projector_outputs = self.projector(
visual_embeds=vision_outputs["visual_embeds"],
image_grid_thw=vision_outputs["image_grid_thw"],
)
return {
**vision_outputs,
**projector_outputs,
}
def encode_images(
self,
images: Any,
image_processor: Optional[PaddleOCRVLImageProcessor] = None,
**processor_kwargs: Any,
) -> Dict[str, Any]:
vision_outputs = self.vision_tower.encode_images(
images=images,
image_processor=image_processor,
**processor_kwargs,
)
projector_outputs = self.projector(
visual_embeds=vision_outputs["visual_embeds"],
image_grid_thw=vision_outputs["image_grid_thw"],
)
return {**vision_outputs, **projector_outputs}
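

if __name__ == "__main__":
    # End-to-end sketch (paths are placeholders): split a full PaddleOCR-VL checkpoint,
    # reload the standalone encoder, and embed a single image. Assumes the checkpoint
    # ships config.json plus a single model.safetensors shard and that PIL is installed.
    full_model_dir = Path("path/to/PaddleOCR-VL")
    export_dir = Path("path/to/vision_encoder_export")

    full_config = _read_json(full_model_dir / FULL_MODEL_CONFIG_NAME)
    full_state_dict = _load_safetensors_state_dict(full_model_dir / FULL_MODEL_WEIGHTS_NAME)
    metadata = extract_and_save_vision_encoder_artifacts(full_config, full_state_dict, export_dir)
    print(
        f"Exported {metadata['num_exported_visual_tensors']} visual tensors and "
        f"{metadata['num_exported_projector_tensors']} projector tensors to {export_dir}"
    )

    from PIL import Image

    encoder = PaddleOCRVLVisionEncoder.from_pretrained(export_dir).eval()
    with torch.no_grad():
        outputs = encoder.encode_images([Image.open("path/to/example.png").convert("RGB")])
    print(outputs["concat_image_embeds"].shape)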