| | """Processor class for MarkupDM.""" |
| |
|
| | import math |
| | import re |
| | import shutil |
| | import subprocess |
| | import tempfile |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| | from .fonts import FontManager |
| | from PIL import Image, ImageDraw |
| | from transformers import ( |
| | ImageProcessingMixin, |
| | PreTrainedModel, |
| | PreTrainedTokenizerBase, |
| | ProcessorMixin, |
| | ) |
| | from transformers.utils import logging |
| |
|
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Upper bound (in pixels) applied to the longer edge of a decoded image; see
# `MarkupDMProcessor.compute_safe_image_size`.
MAXIMUM_DECODE_IMAGE_SIZE = 4096
# File name template for images produced by `MarkupDMProcessor.decode`
# (zero-padded running index).
IMG_FORMAT = "{:03d}.png"
# File name template for font files written by
# `MarkupDMProcessor.render_preprocess`.
FONT_FORMAT = "{:03d}.ttf"
| |
|
| |
|
class MarkupDMProcessor(ProcessorMixin):
    """Convert between SVG markup with embedded images and token sequences.

    Encoding replaces every image file name that appears in the SVG text by a
    span of tokens ``<begin_of_image> W <image_sep> H <image_sep> codes...
    <end_of_image>``, where the image codes are shifted by ``len(tokenizer)``
    so that they live above the text vocabulary.  Decoding performs the
    inverse operation, and the ``render*`` methods rasterise a decoded example
    with headless Chrome.
    """

    attributes = ["tokenizer", "image_processor"]

    # Tokenizer used for the textual part of the markup.
    tokenizer_class = "AutoTokenizer"
    tokenizer: PreTrainedTokenizerBase

    # Image processor used before encoding / after decoding images.
    image_processor_class = "AutoImageProcessor"
    image_processor: ImageProcessingMixin

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        image_processor: ImageProcessingMixin,
    ):
        """Set up the processor.

        Args:
            tokenizer: Text tokenizer.  The image-related special tokens are
                added to it if ``<begin_of_image>`` is not registered yet.
            image_processor: Image processor; its ``size``, ``resample`` and
                ``postprocess`` members are used by other methods.
        """
        super().__init__(tokenizer, image_processor)

        # Register the image special tokens only once.
        if "<begin_of_image>" not in tokenizer.additional_special_tokens:
            self.extend_base_tokenizer(self.tokenizer)

        # Patterns for the "<begin_of_image>W<image_sep>H<image_sep>" header
        # of an image span and for integer width/height attributes of <svg>.
        boi = "<begin_of_image>"
        img_sep = "<image_sep>"
        self.re_img_size = re.compile(rf"{boi}(\d+){img_sep}(\d+){img_sep}")
        self.re_svg_width = re.compile(r'<svg[^>]*\bwidth="(\d+)"[^>]*>')
        self.re_svg_height = re.compile(r'<svg[^>]*\bheight="(\d+)"[^>]*>')

        # Created lazily by `set_font_manager`; required for rendering only.
        self.font_manager = None

    def extend_base_tokenizer(self, tokenizer: PreTrainedTokenizerBase) -> None:
        """Add the image special tokens to ``tokenizer`` in place.

        Also disables ``clean_up_tokenization_spaces`` on the tokenizer.
        """
        logger.info("Extending tokenizer...")
        tokenizer.clean_up_tokenization_spaces = False

        additional_special_tokens = [
            "<begin_of_image>",
            "<end_of_image>",
            "<image_sep>",
            "<image_token>",
        ]
        logger.info(f"Add special tokens: {additional_special_tokens}")
        # Append to (rather than replace) already registered special tokens.
        tokenizer.add_special_tokens(
            {"additional_special_tokens": additional_special_tokens},
            replace_additional_special_tokens=False,
        )

    def __call__(
        self,
        svg: str | None = None,
        images: list[Image.Image] | None = None,
        filenames: list[str] | None = None,
        vision_model: PreTrainedModel | None = None,
    ) -> dict:
        """Encode an SVG string and its images into model inputs.

        Args:
            svg: SVG markup in which images are referenced by file name.
            images: Images referenced by the markup.  A single non-list value
                is wrapped into a list; ``None`` or ``[]`` means no images.
            filenames: File names as they appear in ``svg``, aligned with
                ``images``.  May be omitted when there are no images.
            vision_model: Model used to turn images into discrete codes;
                required when images are given.

        Returns:
            The dict produced by `tokenize_example`.
        """
        if not isinstance(images, list):
            images = [images]

        if len(images) > 0 and images[0] is not None:
            output = self.preprocess_images(images)
            output = self.encode_images(output, vision_model)
        else:
            output = {"width": [], "height": [], "image_ids": []}
            # Without images there is nothing to look up in the markup, so a
            # missing `filenames` argument is equivalent to an empty list.
            if filenames is None:
                filenames = []

        output.update({"svg": svg, "filenames": filenames})
        output = self.tokenize_example(output)

        return output

    def preprocess_images(self, images: list[Image.Image]) -> dict:
        """Run the image processor on each image and collate the results.

        Returns:
            Dict with keys ``"image"`` (stacked tensor), ``"width"`` and
            ``"height"`` (lists, one entry per image).  The image processor is
            expected to return exactly these keys for a single image.
        """
        assert images is not None, "Images must be provided."
        output: dict = {"image": [], "width": [], "height": []}

        for image in images:
            processed = self.image_processor(image)
            for key, value in processed.items():
                output[key].append(value)

        # Batch the per-image tensors along a new leading dimension.
        output["image"] = torch.stack(output["image"])

        return output

    def encode_images(self, example: dict, vision_model: PreTrainedModel) -> dict:
        """Replace the ``"image"`` tensor of ``example`` by discrete codes.

        Args:
            example: Output of `preprocess_images`, or a dict with raw
                ``"images"`` (and no ``"width"``), which is preprocessed first.
            vision_model: Model exposing ``dtype``, ``device`` and
                ``model.encode``.

        Returns:
            ``example`` without ``"image"`` and with ``"image_ids"``: a list
            holding one flattened code tensor (on CPU) per image.
        """
        if "images" in example and "width" not in example:
            example = self.preprocess_images(example["images"])

        assert vision_model is not None, "Vision model must be provided."
        image = example.pop("image")
        image = image.to(dtype=vision_model.dtype, device=vision_model.device)
        with torch.inference_mode():
            # Only the code indices of the nested encoder output are needed.
            _, _, (_, _, image_ids) = vision_model.model.encode(image)
        example["image_ids"] = list(image_ids.view(image.size(0), -1).cpu())

        return example

    def tokenize_example(self, example: dict) -> dict:
        """Tokenize an SVG string, substituting file names by image tokens.

        Args:
            example: Dict with ``"svg"``, ``"filenames"``, ``"width"``,
                ``"height"`` and ``"image_ids"`` (none of them ``None``).

        Returns:
            Dict with 1D tensors ``"input_ids"``, ``"image_mask"`` (``True``
            where the id is an image code, i.e. ``>= len(tokenizer)``) and
            ``"image_pos_ids"`` (position of each image code inside its image,
            zero elsewhere).
        """
        for key in ["svg", "filenames", "width", "height", "image_ids"]:
            msg = f"Missing key: {key}."
            if key in ["width", "height", "image_ids"]:
                msg += " Images must be encoded first using `encode_images`."
            assert example.get(key, None) is not None, msg

        tokenizer = self.tokenizer
        bos_id = tokenizer.bos_token_id
        eos_id = tokenizer.eos_token_id
        # Fall back to EOS for tokenizers that define no BOS token.
        bos_id = bos_id if bos_id is not None else eos_id
        boi_id = tokenizer.convert_tokens_to_ids("<begin_of_image>")
        eoi_id = tokenizer.convert_tokens_to_ids("<end_of_image>")
        img_sep_id = tokenizer.convert_tokens_to_ids("<image_sep>")

        # Pre-compute the token span of every image, keyed by file name.
        name2token = {}
        for filename, image_ids, width, height in zip(
            example["filenames"],
            example["image_ids"],
            example["width"],
            example["height"],
        ):
            # Shift image codes above the text vocabulary.
            _image_ids = (image_ids + len(tokenizer)).tolist()
            W_tokens = tokenizer.encode(str(width))
            H_tokens = tokenizer.encode(str(height))

            image_tokens = [
                boi_id,
                *W_tokens,
                img_sep_id,
                *H_tokens,
                img_sep_id,
                *_image_ids,
                eoi_id,
            ]

            name2token[filename] = image_tokens

        # Walk through the markup: tokenize the text up to the next file name,
        # emit the image tokens for that name, and continue after it.
        tokens = [bos_id]
        svg = example["svg"]
        while svg:
            # Locate the earliest occurrence of any known file name.
            start, end = len(svg), len(svg)
            for name in name2token:
                if not name:
                    # An empty name matches at offset 0 with zero length and
                    # would keep this loop from ever consuming the text.
                    continue
                _start = svg.find(name)
                if -1 < _start < start:
                    start = _start
                    end = start + len(name)

            # Text before the file name (or the whole remainder).
            tokens += tokenizer.encode(svg[:start])

            # Image tokens replacing the file name, if one was found.
            if start < end:
                tokens += name2token[svg[start:end]]

            svg = svg[end:]

        tokens.append(eos_id)

        input_ids = torch.tensor(tokens)
        image_mask = input_ids >= len(tokenizer)

        # Positions restart at 0 for every image; this assumes all images have
        # the same number of codes as the first one.
        image_pos_ids = torch.zeros_like(input_ids)
        if len(example["image_ids"]) > 0:
            length = example["image_ids"][0].size(0)
            if length > 0:
                num_images = int(image_mask.sum().item()) // length
                image_pos_ids[image_mask] = torch.arange(length).repeat(num_images)

        return {
            "input_ids": input_ids,
            "image_mask": image_mask,
            "image_pos_ids": image_pos_ids,
        }

    def decode(
        self,
        tokens: torch.Tensor | np.ndarray,
        vision_model: PreTrainedModel | None = None,
    ) -> dict:
        """Decode a 1D token sequence back into SVG text and images.

        Every ``<begin_of_image> ... <end_of_image>`` span is decoded with
        `decode_image` and replaced in the text by a generated file name
        (``IMG_FORMAT``).  An unterminated span extends to the end of the
        sequence.  Text before the first BOS and from the first EOS onwards is
        dropped.

        Args:
            tokens: 1D token ids; must not contain FIM sentinel tokens.
            vision_model: Model forwarded to `decode_image`.

        Returns:
            Dict with ``"svg"`` (str), ``"images"`` and ``"filenames"``
            (aligned lists).
        """
        tokenizer = self.tokenizer
        bos = tokenizer.bos_token
        eos = tokenizer.eos_token
        bos = bos if bos is not None else eos

        # Fill-in-the-middle sequences have to be re-ordered by the caller.
        msg = "Should be reverted from FIM format before decoding."
        for fim_type in ["prefix", "middle", "suffix"]:
            token_id = tokenizer.convert_tokens_to_ids(f"<fim_{fim_type}>")
            if token_id is None:
                token_id = tokenizer.convert_tokens_to_ids(f"<|fim_{fim_type}|>")
            assert token_id is not None, f"{fim_type} token not found"
            assert token_id not in tokens, msg

        tokens = torch.asarray(tokens).detach().cpu()
        assert tokens.ndim == 1, "Tokens must be 1D."
        boi_id = tokenizer.convert_tokens_to_ids("<begin_of_image>")
        eoi_id = tokenizer.convert_tokens_to_ids("<end_of_image>")

        svg = ""
        images: list = []
        filenames: list = []
        while len(tokens) > 0:
            # Find the next image span [start, end).
            boi_idx = torch.where(tokens == boi_id)[0]
            if boi_idx.size(0) > 0:
                start = int(boi_idx[0].item())
                # Only an end marker at or after the begin marker can close
                # the span; stray earlier ones are left to be decoded as text.
                eoi_idx = torch.where(tokens[start:] == eoi_id)[0]
                if eoi_idx.size(0) > 0:
                    end = start + int(eoi_idx[0].item()) + 1
                else:
                    end = len(tokens)
            else:
                start, end = len(tokens), len(tokens)

            # Text in front of the image span (or the whole remainder).
            svg += tokenizer.decode(tokens[:start])

            if start < end:
                # Read the size header; fall back to the processor size.
                image_tokens = tokens[start:end]
                image_text = tokenizer.decode(image_tokens)
                matched = self.re_img_size.match(image_text)
                if matched is not None:
                    width, height = map(int, matched.groups())
                else:
                    width = self.image_processor.size
                    height = self.image_processor.size

                # Image codes are the ids above the text vocabulary.
                image_mask = image_tokens >= len(tokenizer)
                image_ids = image_tokens[image_mask] - len(tokenizer)
                image = self.decode_image(vision_model, image_ids, width, height)
                filename = IMG_FORMAT.format(len(images))
                svg += filename

                images.append(image)
                filenames.append(filename)

            tokens = tokens[end:]

        # Collapse runs of BOS / EOS into a single occurrence.  A callable
        # replacement keeps the token text from being read as a template.
        svg = re.sub(rf"({re.escape(bos)})+", lambda _m: bos, svg)
        svg = re.sub(rf"({re.escape(eos)})+", lambda _m: eos, svg)

        # Keep only what lies between the first BOS and the following EOS.
        i_bos = svg.find(bos)
        if i_bos > -1:
            svg = svg[i_bos + len(bos) :]
        # The BOS prefix is already gone, so search from the very beginning.
        i_eos = svg.find(eos)
        if i_eos > -1:
            svg = svg[:i_eos]

        return {"svg": svg, "images": images, "filenames": filenames}

    def decode_image(
        self,
        vision_model: PreTrainedModel | None = None,
        image_ids: torch.Tensor | np.ndarray | None = None,
        width: int | None = None,
        height: int | None = None,
        dummy_color: tuple[int, int, int, int] = (200,) * 4,
        pad_value: int = 0,
    ) -> Image.Image:
        """Decode image codes into an image of (at most) the requested size.

        Args:
            vision_model: Model exposing ``device``, ``model.encoder``,
                ``model.quantize`` and ``model.decode``.
            image_ids: 1D code indices.  Fewer codes than the latent grid
                needs are padded with ``pad_value`` and the padded area is
                painted over; surplus codes are ignored.
            width: Target width; falsy values fall back to the processor size.
            height: Target height; falsy values fall back to the processor
                size.
            dummy_color: RGBA colour of the placeholder returned when both
                ``vision_model`` and ``image_ids`` are ``None``.
            pad_value: Code index used for padding.

        Returns:
            The decoded image, resized to the size limited by
            `compute_safe_image_size`.
        """
        width = width or self.image_processor.size
        height = height or self.image_processor.size
        width, height = self.compute_safe_image_size(width, height)

        if vision_model is None and image_ids is None:
            # Nothing to decode: return a flat placeholder.
            return Image.new("RGBA", (width, height), dummy_color)

        # Side length of the (square) latent grid and number of codes in it.
        assert vision_model is not None, "Vision model must be provided."
        scale_factor = 2 ** (vision_model.model.encoder.num_resolutions - 1)
        latent_size = self.image_processor.size // scale_factor
        required_length = latent_size**2

        image_ids = torch.asarray(image_ids, device=vision_model.device)
        # More codes than grid cells cannot be reshaped; keep the leading ones.
        image_ids = image_ids[:required_length]
        code_length = image_ids.shape[0]
        if code_length < required_length:
            pad_size = required_length - code_length
            pad = torch.full((pad_size,), pad_value).to(image_ids)
            image_ids = torch.cat([image_ids, pad])

        with torch.inference_mode():
            codebook_entry = vision_model.model.quantize.get_codebook_entry(
                image_ids, (1, latent_size, latent_size, -1)
            )
            recon = vision_model.model.decode(codebook_entry)[0].float()

        img = self.image_processor.postprocess(
            recon, self.image_processor.size, self.image_processor.size
        )

        # Hide the region that was reconstructed from padding codes.
        if code_length < required_length:
            img = self.mask_padded_area(img, code_length, scale_factor)

        img = img.resize((width, height), resample=self.image_processor.resample)

        return img

    def compute_safe_image_size(self, width: int, height: int) -> tuple[int, int]:
        """Scale ``(width, height)`` down so the longer edge fits the limit.

        Sizes whose longer edge does not exceed ``MAXIMUM_DECODE_IMAGE_SIZE``
        are returned unchanged; otherwise both edges are scaled by the same
        factor and clamped to ``[1, MAXIMUM_DECODE_IMAGE_SIZE]``.
        """
        long_edge = max(width, height)
        if MAXIMUM_DECODE_IMAGE_SIZE < long_edge:
            scale = MAXIMUM_DECODE_IMAGE_SIZE / long_edge
            width = min(max(int(width * scale), 1), MAXIMUM_DECODE_IMAGE_SIZE)
            height = min(max(int(height * scale), 1), MAXIMUM_DECODE_IMAGE_SIZE)
        return width, height

    def mask_padded_area(
        self,
        img: Image.Image,
        code_length: int,
        scale_factor: int,
        fill: tuple[int, int, int, int] = (200, 200, 200, 255),
    ) -> Image.Image:
        """Paint over the part of ``img`` that lies past ``code_length`` codes.

        Codes are laid out row by row on a grid whose cells are
        ``scale_factor`` pixels wide.  The polygon starts at the cell with
        index ``code_length`` and covers the rest of that row and all rows
        below it.  ``img`` is modified in place and returned.
        """
        draw = ImageDraw.Draw(img, mode="RGBA")
        width, height = img.size
        zw = math.ceil(width / scale_factor)  # grid cells per row
        cw = code_length % zw  # column of the first padded cell
        ch = code_length // zw  # row of the first padded cell
        draw.polygon(
            [
                (cw * scale_factor, ch * scale_factor),
                (width, ch * scale_factor),
                (width, height),
                (0, height),
                (0, (ch + 1) * scale_factor),
                (cw * scale_factor, (ch + 1) * scale_factor),
            ],
            fill=fill,
        )
        return img

    def set_font_manager(self, fonts_path: str | None = None) -> None:
        """Create the `FontManager` that rendering needs from ``fonts_path``."""
        self.font_manager = FontManager(fonts_path)

    def render_preprocess(self, example: dict, out_dir: str | Path) -> None:
        """Write the fonts used by the SVG and inject a matching ``<style>``.

        For each distinct (family, weight, style) combination found on
        ``<text>`` tags, the font returned by the font manager is written to
        ``out_dir`` and an ``@font-face`` rule pointing at it is generated.
        ``example["svg"]`` is updated in place with the ``<style>`` element
        inserted right after the opening ``<svg>`` tag.

        Args:
            example: Dict with the SVG markup under ``"svg"``.
            out_dir: Directory for the font files; created if necessary.
        """
        msg = "Font manager is not set. Call `set_font_manager` first."
        assert self.font_manager is not None, msg

        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        svg = example["svg"]

        found = set()
        style_text = "text{dominant-baseline:text-before-edge}"
        for i, text_str in enumerate(re.findall("<text[^>]*>", svg)):
            matched = re.search('font-family="([^"]*)"', text_str)
            if matched is None:
                logger.warning(f"Font family not found in {text_str}")
                continue

            # Variant names as passed to the font manager.
            font_family = matched.group(1)
            is_bold = 'font-weight="bold"' in text_str
            is_italic = 'font-style="italic"' in text_str
            font_weight = "bold" if is_bold else "regular"
            if is_italic:
                font_style = "bolditalic" if is_bold else "italic"
            else:
                font_style = font_weight
            key = (font_family, font_weight, font_style)
            if key in found:
                continue

            font_bytes = self.font_manager.lookup(
                font_family=font_family,
                font_weight=font_weight,
                font_style=font_style,
            )

            # The lookup names above are not CSS values ("regular" and
            # "bolditalic" are invalid descriptors and would be ignored), so
            # translate them to the keywords @font-face understands.
            css_weight = "bold" if is_bold else "normal"
            css_style = "italic" if is_italic else "normal"

            font_path = FONT_FORMAT.format(i)
            font_face = "@font-face{"
            font_face += f"font-family:'{font_family}';"
            font_face += f"font-weight:{css_weight};"
            font_face += f"font-style:{css_style};"
            font_face += f"src:url('{font_path}');"
            font_face += "}"
            style_text += font_face

            Path(f"{out_dir}/{font_path}").write_bytes(font_bytes)
            found.add(key)

        # Insert the style sheet directly after the opening <svg ...> tag.
        matched = re.search("<svg[^>]*>", svg)
        assert matched is not None, "SVG tag not found"
        i = matched.span()[1]
        style = f"<style>{style_text}</style>"
        example["svg"] = svg[:i] + style + svg[i:]

    def render(self, example: dict, save_dir: str | Path | None = None) -> Image.Image:
        """Rasterise an example with headless Chrome.

        The SVG, its fonts and its images are written to a temporary
        directory, ``google-chrome`` takes a screenshot of the page, and the
        screenshot is returned at the SVG's size.  Note that
        ``example["svg"]`` is modified in place by `render_preprocess`.

        Args:
            example: Dict with ``"svg"``, ``"images"`` and ``"filenames"``.
                The ``<svg>`` tag must carry integer ``width`` and ``height``
                attributes.
            save_dir: If given, the working directory (HTML, fonts, images,
                screenshot) is copied there.

        Returns:
            The rendered image.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.render_preprocess(example, tmp_dir)

            # Canvas size as declared on the <svg> tag.
            matched = self.re_svg_width.search(example["svg"])
            assert matched is not None, "Width not found in SVG."
            width = int(matched.group(1))
            matched = self.re_svg_height.search(example["svg"])
            assert matched is not None, "Height not found in SVG."
            height = int(matched.group(1))

            # Minimal HTML page wrapping the SVG.
            html = '<!DOCTYPE html><html><body style="margin: 0px">'
            html += f"{example['svg']}</body></html>"

            Path(f"{tmp_dir}/index.html").write_text(html, encoding="utf-8")

            # Images are referenced by relative file name from the SVG.
            for img, filename in zip(example["images"], example["filenames"]):
                Path(f"{tmp_dir}/{filename}").parent.mkdir(parents=True, exist_ok=True)
                img.save(f"{tmp_dir}/{filename}")

            command = [
                "google-chrome",
                "--headless",
                "--disable-web-security",
                "--allow-running-insecure-content",
                "--no-sandbox",
                "--disable-infobars",
                "--hide-scrollbars",
                "--disable-dev-shm-usage",
                "--no-zygote",
                f"--window-size={width},{height}",
                f"--screenshot={tmp_dir}/screenshot.png",
                f"{tmp_dir}/index.html",
            ]
            subprocess.run(command, check=True, stderr=subprocess.DEVNULL)

            # `resize` returns an independent image, so the file can be closed
            # before the temporary directory is removed.
            size = (width, height)
            with Image.open(f"{tmp_dir}/screenshot.png") as screenshot:
                out = screenshot.resize(size, resample=Image.Resampling.LANCZOS)

            if save_dir is not None:
                shutil.copytree(tmp_dir, save_dir, dirs_exist_ok=True)

        return out
| |
|