"""Data loader using webdataset.
Reference:
https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
https://github.com/huggingface/open-muse/blob/main/training/data.py
"""
import math
from typing import List, Union, Text
import webdataset as wds
import numpy as np
import torch
from torch.utils.data import default_collate
from torchvision import transforms
from torch.utils.data import Dataset
import linecache
import json
from PIL import Image
import random
import cv2
import numpy as np
from tqdm import tqdm
Image.MAX_IMAGE_PIXELS = None
def load_json(sample):
    """Parse the raw ``json`` entry of a sample into a Python object.

    Args:
        sample: A mapping whose ``"json"`` value is UTF-8 encoded JSON bytes.

    Returns:
        The same ``sample`` object; its ``"json"`` entry is replaced in place
        by the parsed value.
    """
    raw_text = sample['json'].decode('utf-8')
    parsed = json.loads(raw_text)
    sample['json'] = parsed
    return sample
def filter_keys(key_set):
    """Build a function that keeps only whitelisted keys of a mapping.

    Args:
        key_set: Container of keys to retain (membership is tested with ``in``).

    Returns:
        A callable that takes a dictionary and returns a new dict holding only
        the entries whose key is in ``key_set``, in their original order.
    """
    def _f(dictionary):
        kept = {}
        for name, value in dictionary.items():
            if name in key_set:
                kept[name] = value
        return kept
    return _f
def filter_by_res_ratio(min_res=256, min_ratio=0.5, max_ratio=2.0):
    """Build a predicate that filters samples by resolution and aspect ratio.

    Args:
        min_res: Minimum allowed size of the longer image side.
        min_ratio: Minimum allowed height / width ratio (inclusive).
        max_ratio: Maximum allowed height / width ratio (inclusive).

    Returns:
        A callable taking a sample whose ``"json"`` entry is a mapping with
        ``"original_height"`` and ``"original_width"``. It returns True when
        the sample passes every threshold, False otherwise.
    """
    def _f(sample):
        cfg = sample['json']
        h, w = cfg['original_height'], cfg['original_width']
        # Degenerate metadata (zero or negative size) cannot describe a usable
        # image; reject it instead of dividing by zero below.
        if h <= 0 or w <= 0:
            return False
        ratio = h / w
        return min_ratio <= ratio <= max_ratio and max(h, w) >= min_res
    return _f
def calculate_laplacian_variance(image):
    """Return the variance of the image's Laplacian, a sharpness/blur measure.

    Args:
        image: Anything ``np.array`` can convert to an image array. A
            3-dimensional array is converted with ``cv2.COLOR_RGB2GRAY``
            (i.e. it is treated as RGB); any other array is used as-is.

    Returns:
        The variance of the ``cv2.CV_64F`` Laplacian of the grayscale image.
    """
    pixels = np.array(image)
    # Only arrays with a channel axis need collapsing to a single channel.
    has_channels = len(pixels.shape) == 3
    gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY) if has_channels else pixels
    return cv2.Laplacian(gray, cv2.CV_64F).var()
def get_dynamic_length(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=256, mean_tokens=128, scaling_factor=2.0):
    """Map a Laplacian variance to a token length with a tanh-shaped curve.

    The value is converted to a z-score and squashed with
    ``0.5 * (1 + tanh(scaling_factor * z))`` into [0, 1], which is then mapped
    linearly onto ``[min_tokens, max_tokens]``.

    NOTE: at ``laplacian_value == mean`` the result is the midpoint of
    ``[min_tokens, max_tokens]`` (144 with the defaults), not ``mean_tokens``.
    ``mean_tokens`` is only used as the fallback when ``std`` is not positive.

    Args:
        laplacian_value: Sharpness measure of the image.
        mean: Reference mean of the Laplacian values.
        std: Reference standard deviation of the Laplacian values.
        min_tokens: Token length approached for values far below ``mean``.
        max_tokens: Token length approached for values far above ``mean``.
        mean_tokens: Value returned when ``std <= 0``.
        scaling_factor: Steepness of the curve; larger values reach
            ``min_tokens`` / ``max_tokens`` sooner.

    Returns:
        The token length as an int.
    """
    # A non-positive std makes the z-score undefined; fall back to a constant.
    if std <= 0:
        return mean_tokens

    z_score = (laplacian_value - mean) / std
    # Position in [0, 1]: 0.5 at the mean, tending to 0 / 1 far below / above.
    normalized_position = 0.5 * (1 + math.tanh(scaling_factor * z_score))
    token_length = min_tokens + normalized_position * (max_tokens - min_tokens)
    return int(round(token_length))
def get_dynamic_length_v2(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=128, mean_tokens=128):
    """Map a Laplacian variance to a token length with a piecewise-linear curve.

    ``0`` maps to ``min_tokens`` and ``mean`` maps to ``mean_tokens``. Above
    the mean the length grows linearly towards ``max_tokens``, which is
    reached at ``2 * mean``. The result is clamped to
    ``[min_tokens, max_tokens]``.

    Args:
        laplacian_value: Sharpness measure of the image.
        mean: Reference mean of the Laplacian values; used as the scale of the
            linear interpolation.
        std: Reference standard deviation. It is not used by the mapping
            itself; a non-positive value selects the ``mean_tokens`` fallback.
        min_tokens: Smallest token length returned by the mapping.
        max_tokens: Largest token length returned by the mapping.
        mean_tokens: Token length at ``mean``; also the fallback value.

    Returns:
        The token length as an int.
    """
    # The interpolation divides by `mean`, so a non-positive mean has to take
    # the same fallback path as a non-positive std.
    if std <= 0 or mean <= 0:
        return mean_tokens

    if laplacian_value <= mean:
        # Interpolate between min_tokens (at 0) and mean_tokens (at mean).
        ratio = laplacian_value / mean
        token_length = min_tokens + (mean_tokens - min_tokens) * ratio
    else:
        # Interpolate between mean_tokens (at mean) and max_tokens (at 2*mean).
        ratio = (laplacian_value - mean) / mean
        token_length = mean_tokens + (max_tokens - mean_tokens) * ratio

    token_length = max(min_tokens, min(max_tokens, token_length))
    return int(round(token_length))
def get_laplacian_attention_mask(sample):
    """Return a copy of ``sample`` extended with sharpness-based mask fields.

    The Laplacian variance of ``sample["image"]`` is mapped to a length with
    ``get_dynamic_length`` (default arguments), and the first ``length + 1``
    entries of a 128-element float32 mask are set to 1.

    NOTE(review): ``get_dynamic_length`` defaults to ``max_tokens=256`` while
    the mask has 128 entries, so every length >= 127 gives an all-ones mask.
    Confirm this is intended.

    Args:
        sample: Mapping with an ``"image"`` entry.

    Returns:
        A new dict holding the entries of ``sample`` plus ``"laplacian_var"``
        and ``"attention_mask"``. The input mapping is left untouched.
    """
    # Work on a shallow copy so the caller's mapping is not modified.
    enriched = dict(sample)

    sharpness = calculate_laplacian_variance(enriched["image"])
    active = get_dynamic_length(sharpness)

    mask = torch.zeros((128,), dtype=torch.float32)
    mask[:active+1] = 1.0

    enriched["laplacian_var"] = sharpness
    enriched["attention_mask"] = mask
    return enriched
def get_uniform_attention_mask(min_tokens=32, max_tokens=128):
    """Build a map function that attaches a random prefix attention mask.

    Args:
        min_tokens: Lower bound (inclusive) of the sampled length.
        max_tokens: Upper bound (inclusive) of the sampled length; also the
            number of entries in the mask.

    Returns:
        A callable that draws ``length`` uniformly from
        ``[min_tokens, max_tokens]``, stores a float32 mask of ``max_tokens``
        entries under ``"attention_mask"`` in the given dictionary (modified
        in place) and returns that dictionary. Positions ``0..length`` are 1,
        i.e. ``length + 1`` ones, capped at the mask size.
    """
    def _f(dictionary):
        # randint's upper bound is exclusive, hence max_tokens + 1.
        length = torch.randint(min_tokens, max_tokens+1, (1,)).item()
        positions = torch.arange(max_tokens)
        dictionary["attention_mask"] = (positions <= length).to(torch.float32)
        return dictionary
    return _f
def process_recap_text(p):
    """Build a map function that swaps the caption for a cleaned recaption.

    Args:
        p: Probability of using the ``"recap_txt"`` entry when it is present.

    Returns:
        A callable that takes a sample dictionary and returns it (modified in
        place). With probability ``p``, and only if ``"recap_txt"`` exists,
        the recaption bytes are decoded as UTF-8 and stripped. A leading
        "The image depicts/displays/showcases/features/shows" is removed and
        the first remaining character upper-cased. The result is stored
        UTF-8 encoded under ``"text"``. Otherwise the sample is unchanged.
    """
    def _f(dictionary):
        # Guard clause: no recaption available, or this draw keeps the
        # original caption. random.random() is only consumed when the key exists.
        if "recap_txt" not in dictionary or not (random.random() < p):
            return dictionary

        verbs = ('depicts', "displays", 'showcases', 'features', 'shows')
        caption = dictionary["recap_txt"].decode("utf-8").strip()

        # First matching prefix wins; the order keeps "showcases" ahead of "shows".
        matched = next(
            (lead for lead in ("The image " + verb for verb in verbs)
             if caption.startswith(lead)),
            None,
        )
        if matched is not None:
            caption = caption[len(matched):].strip()
            if caption:
                caption = caption[0].upper() + caption[1:]

        dictionary["text"] = caption.encode("utf-8")
        return dictionary
    return _f
def identity(x):
    """Return ``x`` unchanged; used as a no-op pipeline stage."""
    return x
class ImageTransform:
    """Container for the torchvision train and eval preprocessing pipelines."""

    def __init__(self,
                 resize_shorter_edge: int = 256,
                 crop_size: int = 256,
                 random_crop: bool = True,
                 random_flip: bool = True,
                 normalize_mean: List[float] = [0., 0., 0.],
                 normalize_std: List[float] = [1., 1., 1.]):
        """Builds ``self.train_transform`` and ``self.eval_transform``.

        Both pipelines use bicubic interpolation with antialiasing and print
        their composition to stdout once constructed.

        Args:
            resize_shorter_edge: An integer, the size passed to the training resize.
            crop_size: An integer, the crop size; also the eval resize size.
            random_crop: A boolean, use a random crop for training when True,
                a center crop otherwise.
            random_flip: A boolean, whether to add a random horizontal flip to
                the training pipeline.
            normalize_mean: A list of float, the mean used to normalize the image tensor.
            normalize_std: A list of float, the std used to normalize the image tensor.
        """
        interpolation = transforms.InterpolationMode.BICUBIC

        crop_cls = transforms.RandomCrop if random_crop else transforms.CenterCrop
        train_steps = [
            transforms.Resize(resize_shorter_edge, interpolation=interpolation, antialias=True),
            crop_cls(crop_size),
        ]
        if random_flip:
            train_steps.append(transforms.RandomHorizontalFlip())
        # mean = [0, 0, 0], std = [1, 1, 1] keeps images in [0, 1];
        # mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5] maps them into [-1, 1].
        train_steps += [
            transforms.ToTensor(),
            transforms.Normalize(normalize_mean, normalize_std),
        ]
        self.train_transform = transforms.Compose(train_steps)

        # Eval always resizes to crop_size so results stay comparable with
        # reference numbers on ImageNet etc.
        eval_steps = [
            transforms.Resize(crop_size, interpolation=interpolation, antialias=True),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(normalize_mean, normalize_std),
        ]
        self.eval_transform = transforms.Compose(eval_steps)

        print(f"self.train_transform: {self.train_transform}")
        print(f"self.eval_transform: {self.eval_transform}")
class SimpleImageDataset:
    """Builds webdataset-backed train and eval dataloaders for image shards.

    The loaders and their underlying pipelines are exposed through the
    read-only properties ``train_dataset``, ``train_dataloader``,
    ``eval_dataset`` and ``eval_dataloader``.
    """

    def __init__(
        self,
        train_shards_path: Union[Text, List[Text]],
        eval_shards_path: Union[Text, List[Text]],
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers_per_gpu: int = 12,
        resize_shorter_edge: int = 256,
        crop_size: int = 256,
        random_crop = True,
        random_flip = True,
        normalize_mean: List[float] = [0., 0., 0.],
        normalize_std: List[float] = [1., 1., 1.],
        dataset_with_class_label: bool = True,
        dataset_with_text_label: bool = False,
        res_ratio_filtering = False,
        min_tokens = 32,
        max_tokens = 128,
    ):
        """Initializes the train and eval pipelines and their loaders.

        Args:
            train_shards_path: A string or list of string, path to the training data shards in webdataset format.
            eval_shards_path: A string or list of string, path to the evaluation data shards in webdataset format.
            num_train_examples: An integer, total number of training examples.
            per_gpu_batch_size: An integer, number of examples per GPU batch.
            global_batch_size: An integer, total number of examples in a batch across all GPUs.
            num_workers_per_gpu: An integer, number of workers per GPU. Zero
                loads data in the main process.
            resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
            crop_size: An integer, the size to crop the input image to.
            random_crop: A boolean, whether to use random crop augmentation during training.
            random_flip: A boolean, whether to use random flipping augmentation during training.
            normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
            normalize_std: A list of float, the normalization std used to normalize the image tensor.
            dataset_with_class_label: A boolean, use the class-label training
                pipeline (reads the ``cls`` entry). Takes precedence over
                ``dataset_with_text_label``.
            dataset_with_text_label: A boolean, use the text-label training
                pipeline (reads the ``txt`` and ``json`` entries).
            res_ratio_filtering: A boolean, in the text-label pipeline drop
                samples rejected by ``filter_by_res_ratio()`` (default thresholds).
            min_tokens: An integer, lower bound passed to ``get_uniform_attention_mask``.
            max_tokens: An integer, upper bound passed to ``get_uniform_attention_mask``.

        Raises:
            NotImplementedError: If neither ``dataset_with_class_label`` nor
                ``dataset_with_text_label`` is set.
        """
        transform = ImageTransform(
            resize_shorter_edge, crop_size, random_crop, random_flip,
            normalize_mean, normalize_std)

        if dataset_with_class_label:
            train_processing_pipeline = [
                wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]), handler=wds.warn_and_continue),
                wds.rename(
                    image="jpg;png;jpeg;webp",
                    class_id="cls",
                    handler=wds.warn_and_continue,
                ),
                wds.map(filter_keys({"image", "class_id", "filename"})),
                wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
                wds.map_dict(
                    image=transform.train_transform,
                    class_id=int,
                    attention_mask=identity,
                    handler=wds.warn_and_continue,
                ),
            ]
        elif dataset_with_text_label:
            train_processing_pipeline = [
                # The json metadata is parsed before image decoding so the
                # resolution filter can run on it.
                wds.map(load_json),
                wds.select(filter_by_res_ratio()) if res_ratio_filtering else wds.map(identity),
                wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]), only=["webp", "png", "jpg", "jpeg", "txt"], handler=wds.warn_and_continue),
                wds.rename(
                    image="jpg;png;jpeg;webp",
                    text="txt",
                    handler=wds.warn_and_continue,
                ),
                wds.map(filter_keys({"image", "text", "__key__"})),
                wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
                wds.map_dict(
                    image=transform.train_transform,
                    attention_mask=identity,
                    handler=wds.warn_and_continue,
                ),
            ]
        else:
            raise NotImplementedError

        # NOTE: the eval pipeline always reads class labels, independent of
        # which training pipeline was selected above.
        test_processing_pipeline = [
            wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]), handler=wds.warn_and_continue),
            wds.rename(
                image="jpg;png;jpeg;webp",
                class_id="cls",
                handler=wds.warn_and_continue,
            ),
            wds.map(filter_keys({"image", "class_id", "filename"})),
            wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
            wds.map_dict(
                image=transform.eval_transform,
                class_id=int,
                attention_mask=identity,
                handler=wds.warn_and_continue,
            ),
        ]

        # DataLoader rejects persistent_workers=True when there are no workers.
        use_persistent_workers = num_workers_per_gpu > 0
        # With zero workers the main process does the loading, so count it as
        # one worker when splitting the epoch (also avoids dividing by zero).
        effective_workers = max(1, num_workers_per_gpu)

        # Create train dataset and loader.
        pipeline = [
            wds.ResampledShards(train_shards_path),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(bufsize=5000,
                        initial=1000),
            *train_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]

        # Round the epoch up so every worker yields the same number of batches.
        num_worker_batches = math.ceil(
            num_train_examples / (global_batch_size * effective_workers))
        num_batches = num_worker_batches * effective_workers
        num_samples = num_batches * global_batch_size

        # Each worker is iterating over the complete dataset.
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=use_persistent_workers,
        )
        # Add meta-data to dataloader instance for convenience.
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

        # Create eval dataset and loader.
        pipeline = [
            wds.SimpleShardList(eval_shards_path),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *test_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=use_persistent_workers,
        )

    @property
    def train_dataset(self):
        """The training ``wds.DataPipeline``."""
        return self._train_dataset

    @property
    def train_dataloader(self):
        """The training ``wds.WebLoader`` (has ``num_batches`` / ``num_samples``)."""
        return self._train_dataloader

    @property
    def eval_dataset(self):
        """The evaluation ``wds.DataPipeline``."""
        return self._eval_dataset

    @property
    def eval_dataloader(self):
        """The evaluation ``wds.WebLoader``."""
        return self._eval_dataloader
class PretoeknizedDataSetJSONL(Dataset):
    """Map-style dataset over a JSONL file of pretokenized examples.

    Each line is expected to be a JSON object with a ``"class_id"`` and a
    ``"tokens"`` entry. Lines are fetched on demand through ``linecache``.
    """

    def __init__(self, data_path):
        """Counts the lines of ``data_path`` and prints the total.

        Args:
            data_path: Path to the JSONL file.
        """
        super().__init__()
        self.jsonl_file = data_path
        # Close the handle deterministically instead of leaking it.
        with open(self.jsonl_file) as handle:
            self.num_lines = sum(1 for _ in handle)
        # Drop any stale linecache entry for this file.
        linecache.checkcache(self.jsonl_file)
        print("Number of data:", self.num_lines)

    def __len__(self):
        """Returns the number of lines in the file."""
        return self.num_lines

    def __getitem__(self, idx):
        """Returns ``(class_id, tokens)`` tensors for the zero-based ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        # linecache returns '' for a missing line, which would otherwise show
        # up as a JSON decode error; IndexError also ends plain iteration.
        if not 0 <= idx < self.num_lines:
            raise IndexError(f"index {idx} out of range for {self.num_lines} lines")
        # linecache line numbers are 1-based.
        line = linecache.getline(self.jsonl_file, idx + 1).strip()
        data = json.loads(line)
        return torch.tensor(data["class_id"]), torch.tensor(data["tokens"])
class PretokenizedWebDataset(SimpleImageDataset):
    """Webdataset loaders for pretokenized text-to-image training data.

    Training samples provide precomputed tokens and a caption; evaluation
    samples are images. The loaders are exposed through the properties
    inherited from ``SimpleImageDataset``. The parent ``__init__`` is not
    called, because this class builds its own pipelines.
    """

    def __init__(
        self,
        train_shards_path: Union[Text, List[Text]],
        eval_shards_path: Union[Text, List[Text]],
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers_per_gpu: int,
        resize_shorter_edge: int = 256,
        crop_size: int = 256,
        random_crop = True,
        random_flip = True,
        normalize_mean: List[float] = [0., 0., 0.],
        normalize_std: List[float] = [1., 1., 1.],
        process_recap = False,
        use_recap_prob = 0.95,
    ):
        """Initializes the PretokenizedWebDataset class.

        Text-to-image datasets are pretokenized with careful filtering (Tab. 7 in Supp.) to speed up the training

        Args:
            train_shards_path: A string or list of string, path to the training data shards in webdataset format.
            eval_shards_path: A string or list of string, path to the evaluation data shards in webdataset format.
            num_train_examples: An integer, total number of training examples.
            per_gpu_batch_size: An integer, number of examples per GPU batch.
            global_batch_size: An integer, total number of examples in a batch across all GPUs.
            num_workers_per_gpu: An integer, number of workers per GPU. Zero
                loads data in the main process.
            resize_shorter_edge: An integer, forwarded to ``ImageTransform``.
            crop_size: An integer, forwarded to ``ImageTransform``.
            random_crop: A boolean, forwarded to ``ImageTransform``.
            random_flip: A boolean, forwarded to ``ImageTransform``.
            normalize_mean: A list of float, forwarded to ``ImageTransform``.
            normalize_std: A list of float, forwarded to ``ImageTransform``.
            process_recap: A boolean, whether to apply ``process_recap_text``
                to training samples.
            use_recap_prob: A float, probability passed to ``process_recap_text``.
        """
        # Only the eval transform is used here; training samples are tokens.
        transform = ImageTransform(
            resize_shorter_edge, crop_size, random_crop, random_flip,
            normalize_mean, normalize_std)

        def decode_npy(x):
            # Interprets the raw bytes as a flat float16 buffer.
            # NOTE(review): assumes the entry holds no .npy header - confirm.
            arr = np.frombuffer(x, dtype=np.float16)
            return torch.tensor(arr)

        def decode_text(x):
            return x.decode("utf-8")

        # wds.map expects a per-sample callable. The disabled case must pass
        # `identity` itself rather than another wds.map(...) pipeline stage.
        recap_fn = process_recap_text(use_recap_prob) if process_recap else identity

        train_processing_pipeline = [
            wds.rename(
                tokens="token.npy",
                text="txt",
                handler=wds.warn_and_continue,
            ),
            wds.map(recap_fn),
            wds.map(filter_keys({"tokens", "text", "aes_score", "__key__"})),
            wds.map_dict(
                tokens=decode_npy,
                text=decode_text,
                handler=wds.warn_and_continue,
            ),
        ]

        test_processing_pipeline = [
            wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])),
            wds.rename(
                image="jpg;png;jpeg;webp",
                handler=wds.warn_and_continue,
            ),
            wds.map_dict(
                image=transform.eval_transform,
                handler=wds.warn_and_continue,
            ),
        ]

        # DataLoader rejects persistent_workers=True when there are no workers.
        use_persistent_workers = num_workers_per_gpu > 0
        # With zero workers the main process does the loading, so count it as
        # one worker when splitting the epoch (also avoids dividing by zero).
        effective_workers = max(1, num_workers_per_gpu)

        # Create train dataset and loader.
        pipeline = [
            wds.ResampledShards(train_shards_path),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(bufsize=5000,
                        initial=1000),
            *train_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]

        # Round the epoch up so every worker yields the same number of batches.
        num_worker_batches = math.ceil(
            num_train_examples / (global_batch_size * effective_workers))
        num_batches = num_worker_batches * effective_workers
        num_samples = num_batches * global_batch_size

        # Each worker is iterating over the complete dataset.
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=use_persistent_workers,
        )
        # Add meta-data to dataloader instance for convenience.
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

        # Create eval dataset and loader.
        pipeline = [
            wds.SimpleShardList(eval_shards_path),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *test_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=use_persistent_workers,
        )