import math
import warnings
import random
from dataclasses import dataclass
from typing import List, Optional, Union, Dict, Any
from collections import defaultdict
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from diffusers.utils import BaseOutput
|
|
|
|
| def default(value, default_value): |
| return value if value is not None else default_value |
|
|
|
|
| def ensure_list(value): |
| if value is None: |
| return [] |
| if isinstance(value, (list, tuple)): |
| return list(value) |
| return [value] |
|
|
|
|
| class Resolution(object): |
| def __init__(self, size, *args): |
| if isinstance(size, str): |
| if 'x' in size: |
| size = size.split('x') |
| size = (int(size[0]), int(size[1])) |
| else: |
| size = int(size) |
| if len(args) > 0: |
| size = (size, args[0]) |
| if isinstance(size, int): |
| size = (size, size) |
|
|
| self.h = self.height = size[0] |
| self.w = self.width = size[1] |
| self.r = self.ratio = self.height / self.width |
|
|
| def __getitem__(self, idx): |
| if idx == 0: |
| return self.h |
| elif idx == 1: |
| return self.w |
| else: |
| raise IndexError(f'Index {idx} out of range') |
|
|
| def __str__(self): |
| return f'{self.h}x{self.w}' |
|
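
# Illustrative usage sketch (not part of the original module): Resolution accepts an "HxW"
# string, a single int (square), or two ints, and exposes height/width/ratio (= h / w).
def _example_resolution_usage():
    square = Resolution(1024)            # -> 1024x1024, ratio 1.0
    portrait = Resolution("1280x768")    # parsed from an "HxW" string
    landscape = Resolution(768, 1280)    # height and width passed separately
    return str(square), portrait.ratio, (landscape.h, landscape.w)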
|
|
|
| class ResolutionGroup(object): |
| def __init__( |
| self, |
| base_size=None, |
| step=None, |
| align=1, |
| min_multiple=0.5, |
| max_multiple=2.0, |
| max_entries=33, |
| presets=None, |
| ): |
| self.align = int(align) |
| if self.align <= 0: |
| raise ValueError(f'align must be positive, got {align}') |
| self.base_size = base_size |
| if base_size is None or not isinstance(base_size, int): |
| raise ValueError(f'base_size must be an int, but got {base_size!r}') |
| if base_size % self.align != 0: |
| raise ValueError(f'base_size {base_size} is not divisible by align {self.align}') |
|
|
| self.min_multiple = float(min_multiple) |
| self.max_multiple = float(max_multiple) |
| if not (0 < self.min_multiple < self.max_multiple): |
| raise ValueError( |
| f"min_multiple ({self.min_multiple}) must be positive and smaller than max_multiple ({self.max_multiple})" |
| ) |
|
|
| self.max_entries = max_entries if max_entries is None else int(max_entries) |
| if self.max_entries is not None and self.max_entries <= 0: |
| raise ValueError(f'max_entries must be positive when provided, got {self.max_entries}') |
|
|
| if step is None: |
| step = max(self.align, base_size // 16) |
| else: |
| if step <= 0: |
| raise ValueError(f'step must be positive, got {step}') |
| step = max(self.align, int(math.ceil(step / self.align)) * self.align) |
| if step > base_size * max(1, int(self.max_multiple * 2)): |
| raise ValueError( |
| f'step {step} is too large for base_size {base_size} and max_multiple {self.max_multiple}' |
| ) |
|
|
| min_height = max(self.align, int(math.ceil(base_size * self.min_multiple / self.align)) * self.align) |
| min_width = min_height |
| max_height = int(math.floor(base_size * self.max_multiple / self.align)) * self.align |
| max_width = max_height |
| if min_height >= max_height: |
| raise ValueError( |
| f'min height ({min_height}) must be smaller than max height ({max_height})' |
| ) |
|
|
| span_up = max_height - base_size |
| span_down = base_size - min_height |
| if self.max_entries is not None: |
| allowed_steps = max(self.max_entries - 1, 0) |
|
|
| def _steps_for(candidate: int) -> int: |
| up_steps = math.ceil(span_up / candidate) if span_up else 0 |
| down_steps = math.ceil(span_down / candidate) if span_down else 0 |
| return up_steps + down_steps |
|
|
| candidate_step = step |
| while _steps_for(candidate_step) > allowed_steps: |
| candidate_step += self.align |
| step = candidate_step |
|
|
| self.step = step |
|
|
| self.min_height = min_height |
| self.min_width = min_width |
| self.max_height = max_height |
| self.max_width = max_width |
| if self.min_height >= self.max_height: |
| raise ValueError( |
| f'min height ({self.min_height}) must be smaller than max height ({self.max_height})' |
| ) |
|
|
| self.presets = presets |
| if presets: |
| self.data = self._build_from_presets(presets) |
| else: |
| self.data = self._calc_by_step() |
|
|
| if not self.data: |
| raise ValueError('ResolutionGroup has no valid entries') |
|
|
| self.ratio = np.array([x.ratio for x in self.data]) |
| self.attr = ['' for _ in range(len(self.data))] |
| self.prefix_space = 0 |
| self._lookup = {(res.height, res.width): res for res in self.data} |
|
|
| def __len__(self): |
| return len(self.data) |
|
|
| def __getitem__(self, idx): |
| return self.data[idx] |
|
|
| def __repr__(self): |
| prefix = self.prefix_space * ' ' |
| prefix_close = (self.prefix_space - 4) * ' ' |
| res_str = f'ResolutionGroup(base_size={self.base_size}, step={self.step}, data=' |
| attr_maxlen = max([len(x) for x in self.attr] + [5]) |
| res_str += \ |
| f'\n{prefix}ID: height width ratio {" " * max(0, attr_maxlen - 4)}count h/16 w/16 tokens\n{prefix}' |
| res_str += \ |
| ('\n' + prefix).join([f'{i:2d}: ({x.h:4d}, {x.w:4d}) {self.ratio[i]:.4f} {self.attr[i]:>{attr_maxlen}s} ' |
| f'({x.h // 16:3d}, {x.w // 16:3d}) {x.h // 16 * x.w // 16:6d}' |
| for i, x in enumerate(self.data)]) |
| res_str += f'\n{prefix_close})' |
| return res_str |
|
|
| def _calc_by_step(self): |
        assert self.align <= self.step, f'align {self.align} must not be greater than step {self.step}'
|
|
| min_height = self.min_height |
| min_width = self.min_width |
| max_height = self.max_height |
| max_width = self.max_width |
|
|
| resolutions = [Resolution(self.base_size, self.base_size)] |
|
|
| cur_height, cur_width = self.base_size, self.base_size |
| while True: |
| if cur_height >= max_height and cur_width <= min_width: |
| break |
|
|
| cur_height = min(cur_height + self.step, max_height) |
| cur_width = max(cur_width - self.step, min_width) |
| resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align)) |
|
|
| cur_height, cur_width = self.base_size, self.base_size |
| while True: |
| if cur_height <= min_height and cur_width >= max_width: |
| break |
|
|
| cur_height = max(cur_height - self.step, min_height) |
| cur_width = min(cur_width + self.step, max_width) |
| resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align)) |
|
|
| resolutions = sorted(resolutions, key=lambda x: x.ratio) |
| return self._enforce_entry_limit(resolutions) |
|
|
| def _build_from_presets(self, presets): |
| resolutions = [] |
| seen = set() |
| for entry in presets: |
| if isinstance(entry, str): |
| if 'x' not in entry: |
| continue |
| parts = entry.split('x') |
| if len(parts) != 2: |
| continue |
| try: |
| height = int(parts[0]) |
| width = int(parts[1]) |
| except ValueError: |
| continue |
| elif isinstance(entry, (list, tuple)) and len(entry) == 2: |
| height, width = entry |
| try: |
| height = int(height) |
| width = int(width) |
| except ValueError: |
| continue |
| else: |
| continue |
|
|
| height = max(self.align, int(round(height / self.align)) * self.align) |
| width = max(self.align, int(round(width / self.align)) * self.align) |
|
|
| if not (self.min_height <= height <= self.max_height and self.min_width <= width <= self.max_width): |
| continue |
|
|
| key = (height, width) |
| if key in seen: |
| continue |
| seen.add(key) |
| resolutions.append(Resolution(height, width)) |
|
|
| if not any(reso.height == self.base_size and reso.width == self.base_size for reso in resolutions): |
| resolutions.append(Resolution(self.base_size, self.base_size)) |
|
|
| resolutions = sorted(resolutions, key=lambda x: x.ratio) |
| return self._enforce_entry_limit(resolutions) |
|
|
| def _enforce_entry_limit(self, resolutions): |
| if self.max_entries is None or len(resolutions) <= self.max_entries: |
| return resolutions |
|
|
| step = (len(resolutions) - 1) / (self.max_entries - 1) if self.max_entries > 1 else 1 |
| indices = [] |
| for i in range(self.max_entries): |
| idx = int(round(i * step)) |
| idx = max(0, min(idx, len(resolutions) - 1)) |
| if indices and idx <= indices[-1]: |
| idx = indices[-1] + 1 |
| if idx >= len(resolutions): |
| idx = len(resolutions) - 1 |
| indices.append(idx) |
|
|
| indices = sorted(set(indices)) |
| for idx in range(len(resolutions)): |
| if len(indices) >= self.max_entries: |
| break |
| if idx not in indices: |
| indices.append(idx) |
| indices.sort() |
|
|
| base_index = next((i for i, reso in enumerate(resolutions) |
| if reso.height == self.base_size and reso.width == self.base_size), None) |
| if base_index is not None and base_index not in indices: |
| indices[len(indices) // 2] = base_index |
| indices.sort() |
|
|
| return [resolutions[i] for i in indices[:self.max_entries]] |
|
|
| def get_target_size(self, width, height): |
| key = (int(round(height)), int(round(width))) |
| reso = self._lookup.get(key) |
| if reso is None: |
| ratio = height / width |
| idx = np.argmin(np.abs(self.ratio - ratio)) |
| reso = self.data[idx] |
| return reso.w, reso.h |
|
|
| def get_base_size_and_ratio_index(self, width, height): |
| key = (int(round(height)), int(round(width))) |
| reso = self._lookup.get(key) |
| if reso is not None: |
| return self.base_size, self.data.index(reso) |
|
|
| ratio = height / width |
| idx = np.argmin(np.abs(self.ratio - ratio)) |
| return self.base_size, idx |
|
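
# Illustrative usage sketch (not part of the original module; the sizes are made-up values):
# build an aspect-ratio bucket table around a base resolution and snap an arbitrary image
# size to the closest bucket.
def _example_resolution_group_usage():
    group = ResolutionGroup(base_size=1024, step=64, align=16)
    target_w, target_h = group.get_target_size(width=1280, height=720)  # nearest-ratio bucket
    base_size, ratio_idx = group.get_base_size_and_ratio_index(width=1280, height=720)
    return (target_w, target_h), base_size, ratio_idx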
|
|
|
| class ImageInfo: |
| """ Class to store image information for processing and generation. """ |
|
|
| def __init__( |
| self, |
| image_type: str = None, |
| image_tensor: torch.Tensor = None, |
| image_width: int = None, |
| image_height: int = None, |
| token_width: int = None, |
| token_height: int = None, |
| image_token_length: int = None, |
| base_size: int = None, |
| ratio_index: int = None, |
| **kwargs, |
| ): |
| self.image_type = image_type |
| self.image_tensor = image_tensor |
| self.image_width = image_width |
| self.w = image_width |
| self.image_height = image_height |
| self.h = image_height |
| self.token_width = token_width |
| self.tk_w = token_width |
| self.token_height = token_height |
| self.tk_h = token_height |
| self.image_token_length = default( |
| image_token_length, |
| token_width * token_height if token_width is not None and token_height is not None else None |
| ) |
| self.base_size = base_size |
| self.ratio_index = ratio_index |
|
|
| self.add_timestep_token = kwargs.get("add_timestep_token", True) |
| self.add_guidance_token = kwargs.get("add_guidance_token", False) |
| self.use_front_boi_token = kwargs.get("use_front_boi_token", True) |
| self.add_image_shape_token = kwargs.get("add_image_shape_token", True) |
|
|
| def __getitem__(self, key: str) -> Any: |
| """Allow dictionary-like access to attributes.""" |
| if hasattr(self, key): |
| return getattr(self, key) |
| raise KeyError(f"Key '{key}' not found in ImageInfo") |
|
|
| def __setitem__(self, key: str, value: Any) -> None: |
| """Allow dictionary-like assignment to attributes.""" |
| if hasattr(self, key): |
| setattr(self, key, value) |
| else: |
| raise KeyError(f"Key '{key}' not found in ImageInfo") |
|
|
| def __contains__(self, key: str) -> bool: |
| """Check if the key exists in the ImageInfo object.""" |
| return hasattr(self, key) |
|
|
| def __repr__(self): |
| return (f"ImageInfo(image_type={self.image_type}, image_tensor={self.image_tensor}, " |
| f"image_width={self.image_width}, image_height={self.image_height}, " |
| f"token_width={self.token_width}, token_height={self.token_height}, " |
| f"image_token_length={self.image_token_length}, " |
                f"base_size={self.base_size}, ratio_index={self.ratio_index})")
|
|
| @property |
| def meta_info(self): |
| |
| if self.image_type in ["vae", "gen_image"]: |
| return dict( |
| token_length=self.image_token_length, |
| add_timestep_token=self.add_timestep_token, |
| add_guidance_token=self.add_guidance_token, |
| use_front_boi_token=self.use_front_boi_token, |
| add_image_shape_token=self.add_image_shape_token, |
| base_size=self.base_size, |
| ratio_idx=self.ratio_index, |
| |
| token_height=self.token_height, |
| token_width=self.token_width, |
| |
| image_height=self.image_height, |
| image_width=self.image_width, |
| ) |
| elif self.image_type in ["vit"]: |
| return dict( |
| token_length=self.image_token_length, |
| use_front_boi_token=self.use_front_boi_token, |
| add_image_shape_token=self.add_image_shape_token, |
| |
| token_height=self.token_height, |
| token_width=self.token_width, |
| |
| image_height=self.image_height, |
| image_width=self.image_width, |
| ) |
| else: |
| raise ValueError(f"Unknown image type '{self.image_type}'") |
|
|
    @property
    def num_special_tokens(self):
        # <boi>/<eoi> plus the optional <timestep>, <guidance>, and <img_size_*>/<img_ratio_*> tokens.
        if self.image_type in ["vae", "src_image", "gen_image"]:
            count = (
                2 +
                (1 if self.add_timestep_token else 0) +
                (1 if self.add_guidance_token else 0) +
                (2 if self.add_image_shape_token else 0)
            )
        else:
            raise ValueError(f"Unknown image_type: {self.image_type}")
        return count
|
|
| def copy(self, copy_image_tensor=True): |
| if copy_image_tensor and self.image_tensor is None: |
| raise ValueError("image_tensor is None, cannot copy") |
| return ImageInfo( |
| image_type=self.image_type, |
| image_tensor=self.image_tensor.clone() if copy_image_tensor else None, |
| image_width=self.image_width, |
| image_height=self.image_height, |
| token_width=self.token_width, |
| token_height=self.token_height, |
| image_token_length=self.image_token_length, |
| base_size=self.base_size, |
| ratio_index=self.ratio_index, |
| ) |
|
|
| def zeros_(self): |
| self.image_tensor = torch.zeros_like(self.image_tensor) |
|
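
# Illustrative usage sketch (not part of the original module): a typical ImageInfo for a
# 1024x1024 image to be generated, assuming a 16x downsampling from pixels to tokens.
def _example_image_info():
    return ImageInfo(
        image_type="gen_image",
        image_width=1024,
        image_height=1024,
        token_width=64,
        token_height=64,     # image_token_length defaults to token_width * token_height
        base_size=1024,
        ratio_index=0,
    )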
|
|
|
class ImageTensor(torch.Tensor):
    """A torch.Tensor subclass that carries its ImageInfo (`i`) and optional vision-encoder kwargs."""

    i: ImageInfo
    vision_encoder_kwargs: dict
|
|
|
|
| class JointImageInfo(object): |
| def __init__(self, vae_image_info: ImageInfo, vision_image_info: ImageInfo, vision_encoder_kwargs: dict = None): |
| self.vae_image_info = vae_image_info |
| self.vision_image_info = vision_image_info |
| self.vision_encoder_kwargs = vision_encoder_kwargs |
|
|
| |
| self.image_type = "joint_image" |
| self.image_token_length = vae_image_info.image_token_length + vision_image_info.image_token_length |
|
|
| self.add_timestep_token = vae_image_info.add_timestep_token |
| self.use_front_boi_token = vae_image_info.use_front_boi_token |
| self.add_image_shape_token = vae_image_info.add_image_shape_token |
|
|
| def __repr__(self): |
| return f"JointImageInfo(vae_image={self.vae_image_info}, vision_image={self.vision_image_info})" |
|
|
| @property |
| def meta_info(self): |
| |
| return dict( |
| token_length=[self.vae_image_info.image_token_length, self.vision_image_info.image_token_length], |
| add_timestep_token=self.add_timestep_token, |
| use_front_boi_token=self.use_front_boi_token, |
| add_image_shape_token=self.add_image_shape_token, |
| base_size=self.vae_image_info.base_size, |
| ratio_idx=self.vae_image_info.ratio_index, |
| |
| token_height=[self.vae_image_info.token_height, self.vision_image_info.token_height], |
| token_width=[self.vae_image_info.token_width, self.vision_image_info.token_width], |
| |
| image_height=[self.vae_image_info.image_height, self.vision_image_info.image_height], |
| image_width=[self.vae_image_info.image_width, self.vision_image_info.image_width], |
| ) |
|
|
    @property
    def num_special_tokens(self):
        # <boi>/<eoi>, optional <timestep> and image shape tokens, plus one <joint_img_sep>
        # separating the VAE tokens from the ViT tokens.
        return (
            2 +
            (1 if self.add_timestep_token else 0) +
            (2 if self.add_image_shape_token else 0) +
            1
        )
|
|
| def copy(self, copy_image_tensor=True): |
| if copy_image_tensor and ( |
| self.vae_image_info.image_tensor is None or self.vision_image_info.image_tensor is None): |
| raise ValueError("image_tensor is None, cannot copy") |
| return JointImageInfo( |
| self.vae_image_info.copy(copy_image_tensor), |
| self.vision_image_info.copy(copy_image_tensor), |
| self.vision_encoder_kwargs, |
| ) |
|
|
| def zeros_(self): |
| self.vae_image_info.zeros_() |
| self.vision_image_info.zeros_() |
|
|
|
|
| class JointImage(object): |
| def __init__(self, vae_image: ImageTensor, vision_image: ImageTensor): |
| self.vae_image = vae_image |
| self.vision_image = vision_image |
| self.i = JointImageInfo(vae_image.i, vision_image.i) |
|
|
|
|
@dataclass
class TokenizerEncodeOutput(BaseOutput):
| tokens: torch.Tensor = None |
| timestep_scatter_index: Optional[torch.Tensor] = None |
| guidance_scatter_index: Optional[torch.Tensor] = None |
| text_slices: Optional[List[slice]] = None |
| gen_image_slices: Optional[List[slice]] = None |
| joint_image_slices: Optional[List[slice]] = None |
| cond_vae_image_slices: Optional[List[slice]] = None |
| cond_vit_image_slices: Optional[List[slice]] = None |
| text_mask: Optional[torch.Tensor] = None |
| gen_image_mask: Optional[torch.Tensor] = None |
| cond_vae_image_mask: Optional[torch.Tensor] = None |
| cond_vit_image_mask: Optional[torch.Tensor] = None |
| real_pos: Optional[torch.Tensor] = None |
| all_image_slices: Optional[List[slice]] = None |
| cond_timestep_scatter_index: Optional[torch.Tensor] = None |
| gen_timestep_scatter_index: Optional[torch.Tensor] = None |
|
|
|
|
| class Conversation: |
| roles: List[str] = ["User", "Assistant"] |
| sep: str = "\n\n" |
|
|
|
|
| class TokenizerWrapper(object): |
| def __init__(self, tokenizer): |
| if isinstance(tokenizer, str): |
| self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) |
| else: |
| self.tokenizer = tokenizer |
|
|
| |
| self.bos_token_id = self.tokenizer.bos_token_id |
| self.eos_token_id = self.tokenizer.eos_token_id |
| self.pad_token_id = self.tokenizer.pad_token_id |
| self.boi_token_id = self.tokenizer.convert_tokens_to_ids("<boi>") |
| self.eoi_token_id = self.tokenizer.convert_tokens_to_ids("<eoi>") |
| self.img_token_id = self.tokenizer.convert_tokens_to_ids("<img>") |
| self.cfg_token_id = self.tokenizer.convert_tokens_to_ids("<cfg>") |
| self.end_answer_token_id = self.tokenizer.convert_tokens_to_ids("</answer>") |
| self.end_recaption_token_id = self.tokenizer.convert_tokens_to_ids("</recaption>") |
| self.ratio_token_offset = self.tokenizer.convert_tokens_to_ids("<img_ratio_0>") |
| self.special_token_map = self.tokenizer.added_tokens_encoder |
|
|
    def pad(self, tensor_list, dim=0, pad_val=None):
        # Right-pad every tensor along `dim` up to the longest one in the list. The padding is
        # applied to the last dimension, so this is intended for 1-D token tensors.
        if pad_val is None:
            pad_val = self.pad_token_id
        max_len = max([t.shape[dim] for t in tensor_list])
        padded_tensor_list = []
        for t in tensor_list:
            if t.shape[dim] < max_len:
                assert pad_val is not False, "Padding is disabled (pad_val=False) but the tensors have different lengths."
                t = F.pad(t, (0, max_len - t.shape[dim]), value=pad_val)
            padded_tensor_list.append(t)
        return padded_tensor_list
|
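    # Illustrative usage sketch (not part of the original class): pad a ragged batch of
    # token sequences to a common length with the tokenizer's <pad> id before stacking, e.g.
    #   padded = wrapper.pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
    #   batch = torch.stack(padded)   # shape (2, 3); the shorter row is right-padded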
|
| def encode(self, *args, **kwargs): |
| return self.tokenizer.encode(*args, **kwargs) |
|
|
| def decode(self, *args, **kwargs): |
| return self.tokenizer.decode(*args, **kwargs) |
|
|
| def encode_text( |
| self, |
| *texts, |
| uncond_enabled: Optional[Union[bool, List[bool]]] = None, |
| uncond_p: Optional[float] = None, |
| max_length: Optional[int] = None, |
| pad: Optional[str] = None, |
| return_lengths: bool = False, |
| ): |
| """ |
        Encode text for AR-style text-to-image / instruction-tuning tasks.
        Supports encoding multiple texts at once. Each text can be conditioned or unconditioned
        independently, based on `uncond_enabled` and a shared `uncond_p`.
        Note: no <bos> token is added here; it is prepended by `encode_sequence` when `add_bos=True`.

        Parameters
        ----------
        texts: str or List[str]
            Texts to be encoded. Pre-tokenized lists of ids are passed through unchanged.
        uncond_enabled: bool or List[bool]
            Flags indicating whether each text may be unconditioned.
            If False, the text is never unconditioned.
            If True, the text is replaced by <cfg> tokens with probability `uncond_p`.
        uncond_p: float
            Probability of dropping the text to <cfg> tokens. Only applies when `uncond_enabled` is True.
        max_length: int
            Maximum total length of the encoded text.
        pad: Optional[str]
            Padding side, either 'left' or 'right'. Requires `max_length`.
        return_lengths: bool
            Whether to also return the length of each encoded text.
| """ |
| if pad is not None: |
| assert max_length is not None, "max_length should be provided when pad is not None." |
|
|
| if uncond_enabled is None: |
| uncond_enabled = [True] * len(texts) |
| elif isinstance(uncond_enabled, bool): |
| uncond_enabled = [uncond_enabled] * len(texts) |
        assert len(uncond_enabled) == len(texts), (
            f"Length of uncond_enabled should be equal to the number of texts, "
            f"but got {len(uncond_enabled)} ({uncond_enabled}) and {len(texts)} ({texts})."
        )
|
|
        # A single random draw decides whether all uncond-enabled texts in this call are
        # replaced by <cfg> tokens, so the drop is consistent across the concatenated pieces.
        do_uncond_drop = (uncond_p is not None) and (random.random() < uncond_p)
| text_tokens, lengths = [], [] |
| cum_length = 0 |
| for text, uncond_flag in zip(texts, uncond_enabled): |
| |
| if max_length is not None and cum_length >= max_length: |
| warnings.warn( |
| f"Text length exceeds the max_length({max_length}). The remaining texts will be ignored: " |
| f"{text[:80]}..." |
| ) |
| break |
| |
| if isinstance(text, str): |
| text_token = self.tokenizer.encode(text, add_special_tokens=False) |
| else: |
| text_token = text |
| if uncond_flag and do_uncond_drop: |
| text_token = [self.cfg_token_id] * len(text_token) |
| |
| if max_length is not None and (cum_length + len(text_token)) > max_length: |
| text_token = text_token[:max_length - cum_length] |
| text_tokens.extend(text_token) |
| lengths.append(len(text_token)) |
| cum_length += len(text_token) |
|
|
| |
| if pad is not None and (pad_length := max_length - len(text_tokens)) > 0: |
| if pad == 'left': |
| text_tokens = [self.pad_token_id] * pad_length + text_tokens |
| elif pad == 'right': |
| text_tokens = text_tokens + [self.pad_token_id] * pad_length |
| else: |
| raise ValueError(f"Unsupported padding method: {pad}.") |
|
|
| if return_lengths: |
| return text_tokens, lengths |
| return text_tokens |
|
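    # Illustrative usage sketch (assumes a loaded TokenizerWrapper named `wrapper`; not part
    # of the original class): encode two text pieces, letting only the second one be dropped
    # to <cfg> tokens for classifier-free guidance.
    #   tokens, lengths = wrapper.encode_text(
    #       "A photo of", "a corgi on the beach",
    #       uncond_enabled=[False, True], uncond_p=0.1,
    #       max_length=77, pad="right", return_lengths=True,
    #   )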
|
| @staticmethod |
| def _check_key_number_matched(keys, data): |
| |
| assert set(keys) == set(data.keys()), ( |
| f"Keys in the template and token source should be matched, but got {set(keys)} and {list(data.keys())}." |
| ) |
| key_counts = {k: 0 for k in keys} |
| for key in keys: |
| key_counts[key] += 1 |
| for key, count in key_counts.items(): |
| assert len(data[key]) == count, ( |
| f"Number of `{key}` in the token source should be matched with the template, but got " |
| f"{data[key]}({len(data[key])}) and {count}." |
| ) |
|
|
| def _add_image_meta_info_token(self, token_seq, token_count, extra_token_pos, add_timestep_token=False, |
| add_image_shape_token=False, base_size=None, ratio_idx=None, image_type=None, |
| add_guidance_token=False): |
| if add_image_shape_token: |
| token_seq.extend([ |
| self.special_token_map[f"<img_size_{base_size}>"], |
| self.special_token_map[f"<img_ratio_{ratio_idx}>"] |
| ]) |
| token_count += 2 |
| if add_timestep_token: |
| token_seq.extend([self.special_token_map["<timestep>"]]) |
| extra_token_pos['timestep'].append(token_count) |
| if image_type is not None: |
| if image_type == "gen_image": |
| extra_token_pos['gen_timestep'].append(token_count) |
| elif image_type in ["joint_image"]: |
| extra_token_pos['cond_timestep'].append(token_count) |
| else: |
| raise ValueError(f"Unsupported image type: {image_type}.") |
| token_count += 1 |
| if add_guidance_token: |
| token_seq.extend([self.special_token_map["<guidance>"]]) |
| extra_token_pos['guidance'].append(token_count) |
| token_count += 1 |
| return token_count |
|
|
| @staticmethod |
| def _shorten_text(text): |
| import re |
| text = re.sub(r"(<img>)+", lambda m: f"[<img>]{{{len(m.group(0)) // 5}}}", text) |
| text = re.sub(r"(<pad>)+", lambda m: f"[<pad>]{{{len(m.group(0)) // 5}}}", text) |
| return text |
|
|
| def encode_sequence( |
| self, |
| template: str, |
| token_source: Dict[str, List], |
| total_length=None, |
| add_timestep_token=False, |
| add_guidance_token=False, |
| last_key_only_prefix=False, |
| add_eos=True, |
| use_front_boi_token=True, |
| add_pad=True, |
| add_bos=True, |
| drop_last: Union[str, bool] = 'auto', |
| add_image_shape_token=False, |
| ): |
| """ |
| Encode a sequence based on the template (e.g., `text-image` for t2i, `text-image-image` for instruction tuning) |
| and token source. |
| |
| Parameters |
| ---------- |
| template: str |
| Template of the sequence. E.g., "text-gen_image" means the sequence is composed of text and an image. |
| "text-text-gen_image" means the sequence is composed of two sections of text and an image. |
| token_source: Dict[str, List] |
| Token source for each key in the template, in order. |
| - text: List[Dict]. |
| - gen_image: List[Dict]. |
| - joint_image: List[Dict]. |
| total_length: int |
            Total length of the encoded sequence, including padding tokens.
| add_timestep_token: bool |
| Whether to add timestep token before the image tokens. |
            (right after the <img_size_*><img_ratio_*> tokens)
| add_guidance_token: bool |
| Whether to add guidance token before the image tokens. |
| last_key_only_prefix: bool |
| Whether to only use the modal prefix in the last key. |
| add_eos: bool or 'auto' |
| Whether to add eos token at the end of the sequence. If True, always add eos token. If 'auto', |
| add eos token only when the total_length is not reached and the last token is not <eos>. |
        use_front_boi_token: bool
            Whether to put the <boi> token in front of the image size, ratio, and timestep tokens.
| add_pad: bool or 'auto' |
| Whether to add padding tokens to the sequence. If True and total_length is not reached, add padding tokens. |
| add_bos: bool |
| Whether to add bos token at the beginning of the sequence. |
| drop_last: bool or 'auto' |
| - If auto, drop last tokens exceeding the total_length if the total_length is provided. If cut point is |
              in the middle of the image tokens, an error will be raised.
| - If True, drop last tokens exceeding the total_length. If cut point is in the middle of the image tokens, |
| all the successive image tokens will be dropped. |
| - If False, keep the last tokens exceeding the total_length, even if the total_length is reached. |
| add_image_shape_token: bool |
| Whether to add image shape token before the image tokens. (Right before the <timestep> token) |
| |
| Returns |
| ------- |
| token_seq: list |
| Encoded token sequence. |
| extra_token_pos: dict |
| Positions of extra tokens. |
| """ |
| if last_key_only_prefix: |
| assert add_eos is not True, "add_eos should not be True when last_key_only_prefix is True." |
| if drop_last is True and total_length is None: |
| raise ValueError("total_length should be provided when drop_last is True.") |
|
|
| keys = template.split('-') |
| modal_length = len(keys) |
| index_indicator = {k: 0 for k in token_source} |
| for k, v in token_source.items(): |
| assert isinstance(v, (list, tuple)), ( |
| f"Value of `{k}` in the token source should be a list or tuple, but got {type(v)}." |
| ) |
| self._check_key_number_matched(keys, token_source) |
|
|
| token_seq = [] |
| token_count = 0 |
| extra_token_pos = defaultdict(list) |
| if add_bos: |
| token_seq.append(self.bos_token_id) |
| token_count += 1 |
| |
| |
| |
| |
| |
| drop_last_break = False |
| for i, key in enumerate(keys): |
| source = token_source[key][index_indicator[key]] |
| if key == "text": |
| token_seq.extend(source) |
| extra_token_pos["<text>_start"].append(token_count) |
| token_count += len(source) |
| extra_token_pos["<text>_end"].append(token_count - 1) |
|
|
| elif key == "gen_image": |
| if isinstance(source, int): |
| source = {'length': source} |
| extra_count = 2 + ( |
| 1 if source.get('timestep', add_timestep_token) else 0) + ( |
| 1 if source.get('guidance', add_guidance_token) else 0) + ( |
| 2 if source.get('image_shape', add_image_shape_token) else 0 |
| ) |
| if drop_last is True and token_count + extra_count + source['length'] > total_length: |
| drop_last_break = True |
| break |
| if source.get('front_boi', use_front_boi_token): |
| token_seq.append(self.boi_token_id) |
| extra_token_pos["boi"].append(token_count) |
| token_count += 1 |
| token_count = self._add_image_meta_info_token( |
| token_seq=token_seq, |
| token_count=token_count, |
| extra_token_pos=extra_token_pos, |
| add_timestep_token=source.get('timestep', add_timestep_token), |
| add_guidance_token=source.get('guidance', add_guidance_token), |
| add_image_shape_token=source.get('image_shape', add_image_shape_token), |
| base_size=source.get('base_size'), |
| ratio_idx=source.get('ratio_idx'), |
| image_type=key, |
| ) |
| if not source.get('front_boi', use_front_boi_token): |
| token_seq.append(self.boi_token_id) |
| extra_token_pos["boi"].append(token_count) |
| token_count += 1 |
| if last_key_only_prefix and i == modal_length - 1: |
| pass |
| else: |
| token_seq.extend( |
| [self.img_token_id] * source['length'] + |
| [self.eoi_token_id] |
| ) |
| extra_token_pos["<img>_start"].append(token_count) |
| extra_token_pos["<all_img>_start"].append(token_count) |
| token_count += source['length'] |
| extra_token_pos["<img>_end"].append(token_count - 1) |
| extra_token_pos["<all_img>_end"].append(token_count - 1) |
| extra_token_pos["eoi"].append(token_count) |
| token_count += 1 |
|
|
| elif key == "joint_image": |
| assert isinstance(source['length'], list) and len( |
| source['length']) == 2, "joint_image length should be a list of two integers" |
| extra_count = 2 + 1 + ( |
| 1 if source.get('timestep', add_timestep_token) else 0) + ( |
| 2 if source.get('image_shape', add_image_shape_token) else 0 |
| ) |
| if drop_last is True and token_count + extra_count + sum(source['length']) > total_length: |
| drop_last_break = True |
| break |
| if source.get('front_boi', use_front_boi_token): |
| token_seq.append(self.boi_token_id) |
| extra_token_pos["boi"].append(token_count) |
| token_count += 1 |
| token_count = self._add_image_meta_info_token( |
| token_seq=token_seq, |
| token_count=token_count, |
| extra_token_pos=extra_token_pos, |
| add_timestep_token=source.get('timestep', add_timestep_token), |
| add_image_shape_token=source.get('image_shape', add_image_shape_token), |
| base_size=source.get('base_size'), |
| ratio_idx=source.get('ratio_idx'), |
| image_type=key, |
| ) |
| if not source.get('front_boi', use_front_boi_token): |
| token_seq.append(self.boi_token_id) |
| extra_token_pos["boi"].append(token_count) |
| token_count += 1 |
| if last_key_only_prefix and i == modal_length - 1: |
| pass |
| else: |
| token_seq.extend( |
| [self.img_token_id] * source['length'][0] |
| ) |
| extra_token_pos["<vae_img>_start"].append(token_count) |
| extra_token_pos["<joint_img>_start"].append(token_count) |
| extra_token_pos["<all_img>_start"].append(token_count) |
| token_count += source['length'][0] |
| extra_token_pos["<vae_img>_end"].append(token_count - 1) |
| extra_token_pos["<all_img>_end"].append(token_count - 1) |
|
|
| token_seq.extend( |
| [self.special_token_map["<joint_img_sep>"]] |
| ) |
| extra_token_pos["joint_img_sep"].append(token_count) |
| token_count += 1 |
|
|
| token_seq.extend( |
| [self.img_token_id] * source['length'][1] |
| ) |
| extra_token_pos["<vit_img>_start"].append(token_count) |
| extra_token_pos["<all_img>_start"].append(token_count) |
| token_count += source['length'][1] |
| extra_token_pos["<vit_img>_end"].append(token_count - 1) |
| extra_token_pos["<joint_img>_end"].append(token_count - 1) |
| extra_token_pos["<all_img>_end"].append(token_count - 1) |
|
|
| token_seq.extend( |
| [self.eoi_token_id] |
| ) |
| extra_token_pos["eoi"].append(token_count) |
| token_count += 1 |
|
|
| else: |
| raise ValueError(f"Not supported key: {key}") |
| index_indicator[key] += 1 |
|
|
| if add_eos is True and not drop_last_break: |
| |
| token_seq.append(self.eos_token_id) |
| extra_token_pos["eos"].append(token_count) |
| token_count += 1 |
| elif add_eos == 'auto' and not drop_last_break: |
| |
| if token_seq[-1] != self.eos_token_id and (total_length is None or token_count < total_length): |
| token_seq.append(self.eos_token_id) |
| extra_token_pos["eos"].append(token_count) |
| token_count += 1 |
|
|
| if total_length: |
| |
| if token_count > total_length and drop_last: |
| |
| for start_key, end_key in [ |
| ("<img>_start", "<img>_end"), ("<joint_img>_start", "<joint_img>_end"), |
| ("<vae_img>_start", "<vae_img>_end"), ("<vit_img>_start", "<vit_img>_end"), |
| ]: |
| if start_key in extra_token_pos and end_key in extra_token_pos: |
| assert all( |
                        (start >= total_length or end + 1 <= total_length)
| for start, end in zip(extra_token_pos[start_key], extra_token_pos[end_key]) |
| ), ("Clip position should not be in the middle of the image tokens.\n" |
| f"Below is the text:\n{self._shorten_text(self.tokenizer.decode(token_seq))}") |
| token_seq = token_seq[:total_length] |
|
|
| |
| pad_num = max(0, total_length - len(token_seq)) |
| if add_pad and pad_num: |
| token_seq.extend([self.pad_token_id] * pad_num) |
| extra_token_pos["first_pad"].append(token_count) |
|
|
| return token_seq, extra_token_pos |
|
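    # Illustrative usage sketch (assumes a loaded TokenizerWrapper named `wrapper`; the token
    # lengths and sizes are made-up numbers, not part of the original class): build a
    # text-to-image sequence of <bos> + prompt tokens + image meta tokens + 4096 <img> slots.
    #   prompt_tokens = wrapper.encode_text("a corgi on the beach")
    #   token_seq, extra_token_pos = wrapper.encode_sequence(
    #       template="text-gen_image",
    #       token_source={
    #           "text": [prompt_tokens],
    #           "gen_image": [dict(length=4096, timestep=True, image_shape=True,
    #                              base_size=1024, ratio_idx=0)],
    #       },
    #       total_length=8192,
    #   )
    #   # extra_token_pos["<img>_start"] / ["<img>_end"] mark where the image tokens sit.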
|
| def batch_gen_infer( |
| self, |
| infer_fn, |
| prompt_list: list, |
| negative_prompt_list: list = None, |
| infer_fn_kwargs_list: List[Dict[str, int]] = None, |
| do_classifier_free_guidance=False, |
| condition_repeat_times: int = 1, |
| uncondition_repeat_times: int = 1, |
| ): |
| """ |
| Batch inference for the AR-like model training of the text-to-image/instruction tuning tasks. |
| |
| Parameters |
| ---------- |
| infer_fn: callable |
| Inference function to encode the prompt. |
| prompt_list: list |
| List of prompts. Each element can be a single prompt or a list of prompts passed to the infer_fn. |
| negative_prompt_list: list |
| List of negative prompts. Only used when do_classifier_free_guidance is True. If None, will use <cfg> |
| token sequence as negative prompt. |
| infer_fn_kwargs_list: List[Dict[str, int]] |
| List of keyword arguments for the infer_fn. |
| do_classifier_free_guidance: bool |
| Whether to do classifier-free guidance. |
| condition_repeat_times: int |
| Support multi-condition. |
| uncondition_repeat_times: int |
| Support multi-uncondition. |
| """ |
| if infer_fn_kwargs_list is None: |
| infer_fn_kwargs_list = [{} for _ in prompt_list] |
|
|
| |
| cond_results_list = None |
| uncond_results_list = None |
| output_type_list = [] |
|
|
| for prompt_idx, (prompt, infer_fn_kwargs) in enumerate(zip(prompt_list, infer_fn_kwargs_list)): |
| if not isinstance(prompt, (list, tuple)): |
| prompt = [prompt] |
| cond_kwargs = {"uncond_p": 0.0} if do_classifier_free_guidance else {} |
| results = infer_fn( |
| *prompt, |
| **infer_fn_kwargs, |
| **cond_kwargs, |
| ) |
| output_type_list.append((type(results), len(results) if isinstance(results, (list, tuple)) else 1)) |
| if isinstance(results, dict): |
| raise ValueError("Make batch on dict is not supported. Please return list or tuple for infer_fn.") |
| if not isinstance(results, (list, tuple)): |
| results = (results,) |
| if cond_results_list is None: |
| cond_results_list = [[] for _ in results] |
| uncond_results_list = [[] for _ in results] |
| for i, result in enumerate(results): |
| cond_results_list[i].append(result) |
|
|
| if do_classifier_free_guidance: |
| if negative_prompt_list is None: |
| uncond_kwargs = {"uncond_p": 1.0} |
| uncond_results = infer_fn( |
| *prompt, |
| **infer_fn_kwargs, |
| **uncond_kwargs, |
| ) |
| else: |
| negative_prompt = negative_prompt_list[prompt_idx] |
| if not isinstance(negative_prompt, (list, tuple)): |
| negative_prompt = [negative_prompt] |
| uncond_results = infer_fn( |
| *negative_prompt, |
| **infer_fn_kwargs, |
| ) |
                # Mirror the conditional branch: normalize a single output to a tuple so the
                # per-output lists stay aligned between conditional and unconditional results.
                if not isinstance(uncond_results, (list, tuple)):
                    uncond_results = (uncond_results,)
                for i, result in enumerate(uncond_results):
                    uncond_results_list[i].append(result)
|
|
| assert all(output_type_list[0] == n for n in output_type_list), \ |
| f"Number of outputs should be equal for all samples, but got {output_type_list}." |
| output_type, output_num = output_type_list[0] |
|
|
| def make_batch(batch_cond_item, batch_uncond_item): |
| |
| first = batch_cond_item[0] |
| if isinstance(first, torch.Tensor): |
| stacked_item = torch.stack(self.pad( |
| batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times, |
| )) |
|
|
| elif first is None: |
| assert all(item is None for item in batch_cond_item + batch_uncond_item), \ |
| (f"The first cond item is None, but some items are not None:\n\n" |
| f"condition: {batch_cond_item}\n\n" |
| f"uncondition: {batch_uncond_item}") |
| stacked_item = None |
|
|
| elif isinstance(first, (list, tuple)): |
| |
| stacked_item = batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times |
|
|
| elif isinstance(first, TokenizerEncodeOutput): |
| stacked_item = {} |
| |
| for key in list(first.keys()): |
| merged_list = [cond_item[key] for cond_item in batch_cond_item] * condition_repeat_times + \ |
| [uncond_item[key] for uncond_item in batch_uncond_item] * uncondition_repeat_times |
| if isinstance(first[key], torch.Tensor): |
| if 'mask' in key: |
| pad_val = 0.0 |
| elif key == 'tokens': |
| pad_val = self.special_token_map["<pad>"] |
| else: |
| pad_val = False |
| stacked_item[key] = torch.stack(self.pad(merged_list, pad_val=pad_val), dim=0) |
| elif isinstance(first[key], list): |
| stacked_item[key] = merged_list |
| elif first[key] is None: |
| pass |
| else: |
| raise ValueError(f"Unsupported type of {key}: {type(first[key])}.") |
| stacked_item = TokenizerEncodeOutput(stacked_item) |
|
|
| else: |
| raise TypeError(f"Making batch on type {type(first)} is not supported.") |
|
|
| return stacked_item |
|
|
| stacked_outputs = [] |
| for cond_results, uncond_results in zip(cond_results_list, uncond_results_list): |
| stacked_outputs.append(make_batch(cond_results, uncond_results)) |
|
|
| if output_type == list: |
| return stacked_outputs |
| elif output_type == tuple: |
| return tuple(stacked_outputs) |
| elif output_num == 1: |
| return stacked_outputs[0] |
| else: |
| raise ValueError(f"Unsupported output type: {output_type}.") |
|
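    # Illustrative note (not part of the original class): batch_gen_infer calls `infer_fn`
    # once per prompt with uncond_p=0.0 and, when CFG is enabled, once more with uncond_p=1.0
    # (or with the matching negative prompt). It then stacks
    # [cond * condition_repeat_times, uncond * uncondition_repeat_times] along the batch
    # dimension, padding tensors to a common length, so with repeat factors of 1 the batch is
    # ordered [cond_0, ..., cond_{B-1}, uncond_0, ..., uncond_{B-1}].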
|
| @staticmethod |
| def parse_extra_token_pos(extra_token_pos, prefix, tokens, rng=None): |
| if rng is None: |
| rng = slice(None) |
| image_slices = [ |
| slice(start, end + 1) |
| for start, end in zip(extra_token_pos[f'<{prefix}>_start'][rng], extra_token_pos[f'<{prefix}>_end'][rng]) |
| ] if f'<{prefix}>_start' in extra_token_pos and f'<{prefix}>_end' in extra_token_pos else [] |
| if image_slices: |
| image_mask = torch.zeros_like(tokens, dtype=torch.bool) |
| for image_slice in image_slices: |
| image_mask[image_slice] = True |
| else: |
| image_mask = None |
| return image_slices, image_mask |
|
|
| def encode_general( |
| self, |
| sections: Optional[List[Dict[str, Any]]] = None, |
| max_token_length: Optional[int] = None, |
| add_eos='auto', |
| use_text_mask=True, |
| add_pad='auto', |
| add_bos=True, |
| drop_last='auto', |
| ): |
| """ |
| General encode function to encode a sequence with multiple sections of text and images. |
| Each section is a dict with a `type` key and other keys depending on the type. |
| Supported section types: |
| - text: dict with keys: |
| - text: str or List[int], text to be encoded. Either `text` or `tokens` should be provided. |
| - tokens: List[int], pre-encoded text tokens. Either `text` or `tokens` should be provided. |
| - uncond_enabled: bool, whether to enable uncondition for this text section. |
| - uncond_p: float, probability to drop the text section for uncondition. |
| - max_length: int, maximum length of the text section. |
| - ignore: bool, whether to ignore this text section in the text mask. |
| - start_offset: int, start offset of the text mask. |
| - end_offset: int, end offset of the text mask. |
| - gen_image: dict with keys: |
| - token_length: int, number of image tokens. |
| - add_timestep_token: bool, whether to add timestep token before the image tokens. |
| - add_guidance_token: bool, whether to add guidance token before the image tokens. |
| - use_front_boi_token: bool, whether to put the <boi> token at the front of size, ratio and timestep tokens. |
| - add_image_shape_token: bool, whether to add image shape token before the image tokens. |
| - base_size: int, base size of the image. |
| - ratio_idx: int, ratio index of the image. |
| - joint_image: dict with keys: |
| - token_length: List[int], number of image tokens for the two images. |
| - add_timestep_token: bool, whether to add timestep token before the image tokens. |
| - use_front_boi_token: bool, whether to put the <boi> token at the front of size, ratio and timestep tokens. |
| - add_image_shape_token: bool, whether to add image shape token before the image tokens. |
| - base_size: int, base size of the image. |
| - ratio_idx: int, ratio index of the image. |
| |
| Parameters |
| ---------- |
| sections: List[Dict[str, Any]] |
| List of sections to be encoded. |
| max_token_length: int |
| Maximum length of the encoded token sequence. |
| add_eos: bool or 'auto' |
| Whether to add eos token at the end of the sequence. If True, always add eos |
| token. If 'auto', add eos token only when the total_length is not reached and the last token is not <eos>. |
| use_text_mask: bool |
| Whether to generate text mask. |
| add_pad: bool or 'auto' |
| Whether to add padding tokens to the sequence. If True and total_length is not reached, |
| add padding tokens. |
| add_bos: bool |
| Whether to add bos token at the beginning of the sequence. |
| drop_last: bool or 'auto' |
| - If auto, drop last tokens exceeding the total_length if the total_length is provided. |
            If cut point is in the middle of the image tokens, an error will be raised.
| - If True, drop last tokens exceeding the total_length. If cut point is in the |
| middle of the image tokens, all the successive image tokens will be dropped. |
| - If False, keep the last tokens exceeding the total_length, even if the total_length |
| is reached. |
| |
| Returns |
| ------- |
| TokenizerEncodeOutput |
| Encoded token sequence and extra information. |
| """ |
| if sections is None: |
| raise ValueError("sections must be provided.") |
| template = '-'.join([section['type'] for section in sections]) |
|
|
| sections = deepcopy(sections) |
| token_source = defaultdict(list) |
| text_mask_specs = [] |
| for section in sections: |
| if section['type'] == 'text': |
| text = self.encode_text( |
| section['text'] if 'text' in section else section['tokens'], |
| uncond_enabled=section.get('uncond_enabled'), |
| uncond_p=section.get('uncond_p'), |
| max_length=section.get('max_length'), |
| ) |
| token_source['text'].append(text) |
| text_mask_specs.append(dict( |
| ignore=section.get('ignore', False), |
| start_offset=section.get('start_offset', 0), |
| end_offset=section.get('end_offset', 0), |
| )) |
| elif section['type'] == 'gen_image': |
| token_source['gen_image'].append(dict( |
| length=section['token_length'], |
| timestep=section.get('add_timestep_token', False), |
| guidance=section.get('add_guidance_token', False), |
| front_boi=section.get('use_front_boi_token', False), |
| image_shape=section.get('add_image_shape_token', False), |
| base_size=section.get('base_size'), |
| ratio_idx=section.get('ratio_idx'), |
| )) |
| elif section['type'] == 'joint_image': |
| token_source['joint_image'].append(dict( |
| length=section['token_length'], |
| timestep=section.get('add_timestep_token', False), |
| front_boi=section.get('use_front_boi_token', False), |
| image_shape=section.get('add_image_shape_token', False), |
| base_size=section.get('base_size'), |
| ratio_idx=section.get('ratio_idx'), |
| )) |
| else: |
| raise ValueError(f"Invalid section type: {section['type']}") |
|
|
| |
| full_token_seq, extra_token_pos = self.encode_sequence( |
| template=template, |
| token_source=dict(token_source), |
| total_length=max_token_length, |
| add_eos=add_eos, |
| add_pad=add_pad, |
| add_bos=add_bos, |
| drop_last=drop_last, |
| ) |
| full_seq_token_tensor = torch.tensor(full_token_seq, dtype=torch.long) |
|
|
| timestep_scatter_index = torch.tensor(extra_token_pos['timestep'], dtype=torch.long) \ |
| if 'timestep' in extra_token_pos else None |
| guidance_scatter_index = torch.tensor(extra_token_pos['guidance'], dtype=torch.long) \ |
| if 'guidance' in extra_token_pos else None |
| cond_timestep_scatter_index = torch.tensor(extra_token_pos['cond_timestep'], dtype=torch.long) \ |
| if 'cond_timestep' in extra_token_pos else None |
| gen_timestep_scatter_index = torch.tensor(extra_token_pos['gen_timestep'], dtype=torch.long) \ |
| if 'gen_timestep' in extra_token_pos else None |
|
|
| |
| gen_image_slices, gen_image_mask = self.parse_extra_token_pos(extra_token_pos, 'img', full_seq_token_tensor) |
| |
| joint_image_slices, _ = self.parse_extra_token_pos(extra_token_pos, 'joint_img', full_seq_token_tensor) |
| |
| cond_vae_image_slices, cond_vae_image_mask = self.parse_extra_token_pos( |
| extra_token_pos, 'vae_img', full_seq_token_tensor) |
| |
| cond_vit_image_slices, cond_vit_image_mask = self.parse_extra_token_pos( |
| extra_token_pos, 'vit_img', full_seq_token_tensor) |
| |
| all_image_slices = [ |
| slice(start, end + 1) |
| for start, end in zip(extra_token_pos['<all_img>_start'], extra_token_pos['<all_img>_end']) |
| ] if '<all_img>_start' in extra_token_pos and '<all_img>_end' in extra_token_pos else [] |
|
|
| |
| text_slices = [ |
| slice(start, end + 1) |
| for start, end in zip(extra_token_pos['<text>_start'], extra_token_pos['<text>_end']) |
| ] if '<text>_start' in extra_token_pos and '<text>_end' in extra_token_pos else [] |
| assert len(text_slices) <= len(text_mask_specs), \ |
| (f"Number of text slices ({len(text_slices)}) should be less than or equal to " |
| f"number of text mask specs ({len(text_mask_specs)})") |
| if use_text_mask: |
| text_mask = torch.zeros_like(full_seq_token_tensor, dtype=torch.float32) |
| for text_slice, mask_spec in zip(text_slices, text_mask_specs): |
| if not mask_spec['ignore']: |
| real_slice = slice( |
| text_slice.start + mask_spec['start_offset'], |
| text_slice.stop + mask_spec['end_offset'] |
| ) |
| text_mask[real_slice] = 1.0 |
| else: |
| text_mask = None |
|
|
| |
| real_pos = torch.tensor(extra_token_pos.get('first_pad', [full_seq_token_tensor.shape[0]]), dtype=torch.long) |
|
|
| return TokenizerEncodeOutput( |
| tokens=full_seq_token_tensor, |
| timestep_scatter_index=timestep_scatter_index, |
| guidance_scatter_index=guidance_scatter_index, |
| text_slices=text_slices, |
| gen_image_slices=gen_image_slices, |
| joint_image_slices=joint_image_slices, |
| cond_vae_image_slices=cond_vae_image_slices, |
| cond_vit_image_slices=cond_vit_image_slices, |
| text_mask=text_mask, |
| gen_image_mask=gen_image_mask, |
| cond_vae_image_mask=cond_vae_image_mask, |
| cond_vit_image_mask=cond_vit_image_mask, |
| real_pos=real_pos, |
| all_image_slices=all_image_slices, |
| cond_timestep_scatter_index=cond_timestep_scatter_index, |
| gen_timestep_scatter_index=gen_timestep_scatter_index, |
| ) |
|
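    # Illustrative usage sketch (assumes a loaded TokenizerWrapper named `wrapper`; the sizes
    # are made-up, not part of the original class): a minimal text-to-image section list.
    #   out = wrapper.encode_general(
    #       sections=[
    #           dict(type="text", text="a corgi on the beach", uncond_enabled=True, uncond_p=0.0),
    #           dict(type="gen_image", token_length=4096, add_timestep_token=True,
    #                add_image_shape_token=True, base_size=1024, ratio_idx=0),
    #       ],
    #       max_token_length=8192,
    #   )
    #   # out.tokens is the full sequence; out.gen_image_slices locates the <img> placeholders.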
|
| def get_cot_sections(self, cot_text, uncond_kwargs, cot_max_length=None, drop_think=False): |
| if not cot_text: |
| return [] |
| if '<think>' in cot_text and '</think>' in cot_text: |
| before_think_sec = cot_text.split('<think>')[0] |
| after_think_sec = cot_text.split('</think>')[1] |
| think_sec = cot_text.split('<think>')[1].split('</think>')[0] |
            return self.get_cot_sections(before_think_sec, uncond_kwargs, cot_max_length, drop_think=drop_think) + \
                ([
                    dict(type="text", text="<think>"),
                    dict(type="text", text=think_sec, max_length=cot_max_length, **uncond_kwargs),
                    dict(type="text", text="</think>")
                ] if not drop_think else []) + \
                self.get_cot_sections(after_think_sec, uncond_kwargs, cot_max_length, drop_think=drop_think)
|
|
| if '<recaption>' in cot_text and '</recaption>' in cot_text: |
| before_recaption_sec = cot_text.split('<recaption>')[0] |
| after_recaption_sec = cot_text.split('</recaption>')[1] |
| recaption_sec = cot_text.split('<recaption>')[1].split('</recaption>')[0] |
            return self.get_cot_sections(before_recaption_sec, uncond_kwargs, cot_max_length, drop_think=drop_think) + \
                [
                    dict(type="text", text="<recaption>"),
                    dict(type="text", text=recaption_sec, max_length=cot_max_length, **uncond_kwargs),
                    dict(type="text", text="</recaption>")
                ] + \
                self.get_cot_sections(after_recaption_sec, uncond_kwargs, cot_max_length, drop_think=drop_think)
|
|
| return [ |
| dict(type="text", text=cot_text, **uncond_kwargs), |
| ] |
|
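    # Illustrative note (not part of the original class): for a chain-of-thought string such as
    # "<think>plan the layout</think><recaption>a corgi on the beach</recaption>", this yields
    # one text section per literal tag plus one (optionally length-limited, CFG-droppable)
    # section per tag body; with drop_think=True the <think> block is omitted while the
    # <recaption> block is kept.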
|
| def apply_general_template( |
| self, |
| message_list, |
| max_length=None, |
| add_assistant_prefix=False, |
| answer="auto", |
| bot_task="auto", |
| sequence_template="instruct", |
| uncond_p=0.0, |
| cfg_factor=1, |
| batchify=False, |
| image_base_size=1024, |
| drop_think=False, |
| ): |
| |
| if batchify: |
| assert isinstance(message_list[0], list), \ |
| f"When batchify is True, message_list should be a list of list, but got [{type(message_list[0])}, ...]." |
| return self.batch_gen_infer( |
| infer_fn=self.apply_general_template, |
                prompt_list=[[] for _ in message_list],
| infer_fn_kwargs_list=[dict( |
| message_list=message_list_i, |
| max_length=max_length, |
| add_assistant_prefix=add_assistant_prefix, |
| answer=answer, |
| bot_task=bot_task, |
| sequence_template=sequence_template, |
| image_base_size=image_base_size, |
| drop_think=drop_think, |
| ) for message_list_i in message_list], |
| do_classifier_free_guidance=cfg_factor > 1, |
| condition_repeat_times=1, |
| uncondition_repeat_times=cfg_factor - 1, |
| ) |
|
|
| conv = Conversation() |
| uncond_kwargs = dict(uncond_enabled=uncond_p == 1.0, uncond_p=uncond_p) |
|
|
| def process_successive_message(_message_list, _cur_message_idx, role, prefix, suffix, |
| answer_prefix="", answer_suffix=""): |
| _sub_sections = [] |
            while _cur_message_idx < len(_message_list) and _message_list[_cur_message_idx]['role'] == role:
| message = _message_list[_cur_message_idx] |
| if message['type'] == 'text': |
| text = message['content'] |
| if role == "system": |
| _sub_sections.append(dict(type="text", text=text)) |
| elif role == "assistant": |
| if ("<recaption>" in text and "</recaption>" in text) or ( |
| "<think>" in text and "</think>" in text): |
| _sub_sections.extend(self.get_cot_sections(text, uncond_kwargs, drop_think=drop_think)) |
| else: |
| _sub_sections.append(dict(type="text", text=text, **uncond_kwargs)) |
| else: |
| _sub_sections.append(dict( |
| type="text", text=f"{answer_prefix}{text}{answer_suffix}", **uncond_kwargs)) |
| elif message['type'] == 'gen_image': |
| info = message['content'] |
| assert isinstance(info, ImageInfo), f"Expected ImageInfo, but got {type(info)}" |
| if role == "assistant": |
| _sub_sections.append(dict(type="text", text=answer_prefix)) |
| _sub_sections.append(dict(type=message['type'], **info.meta_info)) |
| if role == "assistant": |
| _sub_sections.append(dict(type="text", text=answer_suffix)) |
| elif message['type'] == 'joint_image': |
| info = message['content'] |
| assert isinstance(info, JointImageInfo), f"Expected JointImageInfo, but got {type(info)}" |
| _sub_sections.append(dict(type=message['type'], **info.meta_info)) |
| else: |
| raise ValueError(f"Unknown message type: {message['type']}") |
| _cur_message_idx += 1 |
| if len(_sub_sections) > 0: |
| |
| _sub_sections.insert(0, dict(type='text', text=prefix)) |
| _sub_sections.append(dict(type='text', text=suffix)) |
| return _sub_sections, _cur_message_idx |
|
|
| |
| if (answer == "auto" and sequence_template == "instruct") or answer is True: |
| answer_prefix, answer_suffix = "<answer>", "</answer>" |
| else: |
| answer_prefix, answer_suffix = "", "" |
| if sequence_template == "pretrain": |
| system_suffix = "" |
| user_prefix = "" |
| user_suffix = "" |
| bot_prefix = "" |
| bot_suffix = "" |
| else: |
| system_suffix = f"{conv.sep}" |
| user_prefix = f"{conv.roles[0]}: " |
| user_suffix = f"{conv.sep}" |
| bot_prefix = f"{conv.roles[1]}: " |
| bot_suffix = f"{conv.sep}" |
|
|
| |
| sections = [] |
| cur_message_idx = 0 |
| final_role = None |
| while cur_message_idx < len(message_list): |
| |
| sub_sections, cur_message_idx = process_successive_message( |
| message_list, cur_message_idx, role="system", prefix="", suffix=system_suffix) |
| |
| sections.extend(sub_sections) |
| if len(sub_sections) > 0: |
| final_role = "system" |
|
|
| |
| sub_sections, cur_message_idx = process_successive_message( |
| message_list, cur_message_idx, role="user", prefix=user_prefix, suffix=user_suffix) |
| |
| sections.extend(sub_sections) |
| if len(sub_sections) > 0: |
| final_role = "user" |
|
|
| |
| sub_sections, cur_message_idx = process_successive_message( |
| message_list, cur_message_idx, role="assistant", prefix=bot_prefix, suffix=bot_suffix, |
| answer_prefix=answer_prefix, answer_suffix=answer_suffix, |
| ) |
| |
| sections.extend(sub_sections) |
| if len(sub_sections) > 0: |
| final_role = "assistant" |
|
|
| if add_assistant_prefix: |
| if final_role == "assistant": |
| |
| _bot_prefix = "" |
| |
| if len(sections) > 0 and sections[-1]['type'] == 'text' and sections[-1]['text'] == bot_suffix: |
| sections = sections[:-1] |
| else: |
| _bot_prefix = bot_prefix |
| |
| bot_response_prefix = dict( |
| auto=_bot_prefix, |
| image="", |
| think=f"{_bot_prefix}<think>", |
| recaption=f"{_bot_prefix}<recaption>", |
| img_ratio=f"{_bot_prefix}{answer_prefix}<boi><img_size_{image_base_size}>", |
| )[bot_task] |
| sections.append(dict(type='text', text=bot_response_prefix)) |
|
|
| output = self.encode_general( |
| sections=sections, |
| use_text_mask=False, |
| add_eos=False, |
| add_pad=False, |
| ) |
|
|
| if max_length is not None: |
| if output.tokens.shape[-1] > max_length: |
| raise ValueError( |
| f"Encoded token length {output.tokens.shape[-1]} exceeds max_length {max_length}.\n" |
| f"Please set a larger max_length or check the input messages:\n{message_list}" |
| ) |
|
|
| return output, sections |
|
|
| def apply_chat_template( |
| self, |
| batch_prompt: Optional[List[str]] = None, |
| batch_message_list: Optional[List[List[Dict[str, Any]]]] = None, |
| mode: str = "gen_text", |
| batch_gen_image_info: Optional[List[ImageInfo]] = None, |
| batch_cond_image_info: Optional[Union[List[JointImageInfo], List[List[JointImageInfo]]]] = None, |
| batch_system_prompt: Optional[List[str]] = None, |
| batch_cot_text: Optional[List[str]] = None, |
| max_length: Optional[int] = None, |
| bot_task: str = "auto", |
| image_base_size: int = 1024, |
| sequence_template: str = "pretrain", |
| cfg_factor: int = 1, |
| add_assistant_prefix: Optional[bool] = None, |
| drop_think: bool = False, |
| ) -> Dict[str, Any]: |
| assert bot_task in ["image", "auto", "think", "recaption", "img_ratio"], \ |
| f"bot_task should be one of ['image', 'auto', 'think', 'recaption', 'img_ratio'], but got {bot_task}." |
|
|
| if batch_message_list is None: |
| |
| batch_size = len(batch_prompt) |
|
|
| |
| if not isinstance(batch_system_prompt, list): |
| batch_system_prompt = [batch_system_prompt] * batch_size |
| if not isinstance(batch_gen_image_info, list): |
| batch_gen_image_info = [batch_gen_image_info] * batch_size |
| if batch_cot_text is not None: |
| assert len(batch_cot_text) == batch_size, \ |
| (f"batch_cot_text should have the same length as batch_size ({batch_size}), " |
| f"but got {len(batch_cot_text)}.") |
| else: |
| batch_cot_text = [None] * batch_size |
| if batch_cond_image_info is not None: |
| assert len(batch_cond_image_info) == batch_size, \ |
| (f"batch_cond_image_info should have the same length as batch_size ({batch_size}), " |
| f"but got {len(batch_cond_image_info)}.") |
| batch_cond_image_info = [ |
| cond_image_info if isinstance(cond_image_info, list) else [cond_image_info] |
| for cond_image_info in batch_cond_image_info |
| ] |
| else: |
| batch_cond_image_info = [[] for _ in range(batch_size)] |
|
|
| |
| batch_message_list = [] |
| for prompt, system_prompt, cot_text, gen_image_info, cond_image_info_list in zip( |
| batch_prompt, batch_system_prompt, batch_cot_text, batch_gen_image_info, |
| batch_cond_image_info, |
| ): |
| message_list = [] |
| |
| if system_prompt: |
| message_list.append(dict( |
| role="system", type="text", content=system_prompt, context_type="str")) |
| |
| |
| if len(cond_image_info_list) > 0: |
| message_list.extend([ |
| dict(role="user", type="joint_image", content=cond_image_info, context_type="image_info") |
| for cond_image_info in cond_image_info_list |
| ]) |
| |
| message_list.append(dict( |
| role="user", type="text", content=prompt, context_type="str")) |
| |
| if cot_text is not None: |
| message_list.append(dict(role="assistant", type="text", content=cot_text, context_type="str")) |
| if mode == "gen_image": |
| message_list.append(dict( |
| role="assistant", type="gen_image", content=gen_image_info, context_type="image_info")) |
| |
| batch_message_list.append(message_list) |
|
|
| output, sections = self.apply_general_template( |
| message_list=batch_message_list, |
| max_length=max_length, |
| add_assistant_prefix=default(add_assistant_prefix, mode != "gen_image"), |
| bot_task=bot_task, |
| sequence_template=sequence_template, |
| cfg_factor=cfg_factor, |
| batchify=True, |
| image_base_size=image_base_size, |
| drop_think=drop_think, |
| ) |
| return dict(output=output, sections=sections) |
|
|
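
# Illustrative end-to-end sketch (assumptions: a tokenizer checkpoint that defines the special
# tokens used above, and a 64x64-token latent grid for a 1024x1024 image; neither is part of
# this module's contract).
def _example_apply_chat_template(tokenizer_name_or_path: str):
    wrapper = TokenizerWrapper(tokenizer_name_or_path)
    gen_image_info = ImageInfo(
        image_type="gen_image", image_width=1024, image_height=1024,
        token_width=64, token_height=64, base_size=1024, ratio_index=0,
    )
    result = wrapper.apply_chat_template(
        batch_prompt=["a corgi on the beach"],
        mode="gen_image",
        batch_gen_image_info=[gen_image_info],
        sequence_template="pretrain",
        cfg_factor=2,  # one conditional and one unconditional row per prompt
    )
    return result["output"], result["sections"]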