"""Data loader using webdataset.
Reference:
https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
https://github.com/huggingface/open-muse/blob/main/training/data.py
"""
import json
import linecache
import math
import random
from typing import List, Union, Text

import cv2
import numpy as np
import torch
import webdataset as wds
from PIL import Image
from torch.utils.data import Dataset, default_collate
from torchvision import transforms

# Disable PIL's decompression-bomb check so very large images can be loaded.
Image.MAX_IMAGE_PIXELS = None

def load_json(sample):
    sample['json'] = json.loads(sample['json'].decode('utf-8'))
    return sample


def filter_keys(key_set):
    def _f(dictionary):
        return {k: v for k, v in dictionary.items() if k in key_set}
    return _f

def filter_by_res_ratio(min_res=256, min_ratio=0.5, max_ratio=2.0):
    def _f(sample):
        cfg = sample['json']
        h, w = cfg['original_height'], cfg['original_width']
        ratio = h / w
        longer_side = max(h, w)
        return min_ratio <= ratio <= max_ratio and longer_side >= min_res
    return _f

def calculate_laplacian_variance(image):
    """Calculate the variance of the Laplacian, a standard measure of image sharpness/blur."""
    image = np.array(image)
    # Convert to grayscale if the input is RGB.
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        gray = image
    # Higher Laplacian variance indicates a sharper image.
    laplacian = cv2.Laplacian(gray, cv2.CV_64F)
    return laplacian.var()
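
# Example usage (illustrative sketch; the threshold value is an assumption,
# not something defined in this file):
#   img = Image.open("sample.jpg")
#   if calculate_laplacian_variance(img) < 100.0:  # low variance -> likely blurry
#       ...  # e.g., skip or down-weight the sample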

def get_dynamic_length(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=256, mean_tokens=128):
    """Maps a Laplacian variance to a token length via a tanh (sigmoid-shaped) curve.

    The mapping is monotonic: a value at the mean maps to the midpoint of
    [min_tokens, max_tokens], values far below the mean saturate toward
    min_tokens, and values far above it saturate toward max_tokens.
    mean_tokens is only used as a fallback when std is non-positive.
    """
    # Prevent division by zero and handle edge cases.
    if std <= 0:
        return mean_tokens
    z_score = (laplacian_value - mean) / std
    # tanh mapping: z_score = 0 (at the mean) gives the midpoint of the range;
    # large positive/negative z_scores saturate toward max_tokens/min_tokens.
    scaling_factor = 2.0  # Controls how quickly the mapping saturates.
    normalized_position = 0.5 * (1 + math.tanh(scaling_factor * z_score))
    # Map the normalized position onto [min_tokens, max_tokens].
    token_length = min_tokens + normalized_position * (max_tokens - min_tokens)
    return int(round(token_length))
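
# Worked example (sketch; input values are illustrative, using the defaults):
#   get_dynamic_length(2734)         # at the mean: 0.5 * (32 + 256) -> 144
#   get_dynamic_length(2734 + 3239)  # one std above the mean -> ~252 (near saturation)
#   get_dynamic_length(0)            # well below the mean -> ~39 (near min_tokens)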

def get_dynamic_length_v2(laplacian_value, mean=2734, std=3239, min_tokens=32, max_tokens=128, mean_tokens=128):
    """Maps a Laplacian variance to a token length via a piecewise-linear mapping.

    laplacian_value = 0 maps to min_tokens, the mean maps to mean_tokens, and
    values above the mean scale linearly up to max_tokens (clamped).
    """
    # Prevent division by zero and handle edge cases.
    if std <= 0:
        return mean_tokens
    if laplacian_value <= mean:
        # Linear interpolation between min_tokens and mean_tokens.
        ratio = laplacian_value / mean
        token_length = min_tokens + (mean_tokens - min_tokens) * ratio
    else:
        # Linear interpolation between mean_tokens and max_tokens.
        ratio = (laplacian_value - mean) / mean  # How far past the mean.
        token_length = mean_tokens + (max_tokens - mean_tokens) * ratio
    # Clamp to the valid range.
    token_length = max(min_tokens, min(max_tokens, token_length))
    return int(round(token_length))
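
# Worked example (sketch, using the defaults mean=2734, min_tokens=32, mean_tokens=128):
#   get_dynamic_length_v2(0)     # -> 32 (min_tokens)
#   get_dynamic_length_v2(1367)  # halfway to the mean -> 32 + 96 * 0.5 = 80
#   get_dynamic_length_v2(2734)  # at the mean -> 128; note that with the default
#                                # max_tokens=128 the upper branch is a no-op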

def get_laplacian_attention_mask(sample):
    """Process a sample to add its Laplacian variance and a matching attention mask."""
    # Create a new dict to avoid modifying the input.
    processed = dict(sample)
    var = calculate_laplacian_variance(processed["image"])
    length = get_dynamic_length(var)
    # Build a binary mask with the first `length + 1` positions active.
    # Note: the mask size is hard-coded to 128, while get_dynamic_length defaults
    # to max_tokens=256; lengths >= 128 simply activate the whole mask.
    attention_mask = torch.zeros((128,), dtype=torch.float32)
    attention_mask[:length + 1] = 1.0
    processed["laplacian_var"] = var
    processed["attention_mask"] = attention_mask
    return processed

def get_uniform_attention_mask(min_tokens=32, max_tokens=128):
    """Returns a per-sample map that adds a uniformly random attention mask."""
    def _f(dictionary):
        # Sample a length uniformly from [min_tokens, max_tokens].
        length = torch.randint(min_tokens, max_tokens + 1, (1,)).item()
        # Build a binary mask of size max_tokens with the first `length + 1`
        # positions active (the slice clamps at max_tokens).
        attention_mask = torch.zeros((max_tokens,), dtype=torch.float32)
        attention_mask[:length + 1] = 1.0
        dictionary["attention_mask"] = attention_mask
        return dictionary
    return _f
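
# Example (illustrative): with min_tokens=2 and max_tokens=4, a drawn length of 2
# yields tensor([1., 1., 1., 0.]) of shape (4,) -- mask[:length + 1] activates
# length + 1 positions, capped at max_tokens.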

def process_recap_text(p):
    """Returns a map that, with probability p, replaces 'text' with a cleaned 'recap_txt'."""
    def _f(dictionary):
        if "recap_txt" in dictionary:
            if random.random() < p:
                recap_prefixes = ["The image " + v for v in ['depicts', 'displays', 'showcases', 'features', 'shows']]
                # Decode to string and strip surrounding whitespace.
                text = dictionary["recap_txt"].decode("utf-8").strip()
                # Strip a leading boilerplate phrase such as "The image depicts".
                for phrase in recap_prefixes:
                    if text.startswith(phrase):
                        text = text[len(phrase):].strip()
                        # Capitalize the first letter of the remainder.
                        text = text[0].upper() + text[1:] if text else ""
                        break
                dictionary["text"] = text.encode("utf-8")
        return dictionary
    return _f
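
# Example (illustrative input): a recap caption b"The image depicts a red bicycle."
# becomes b"A red bicycle." with probability p; otherwise the sample's original
# 'text' field is left untouched.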

def identity(x):
    return x

class ImageTransform:
    def __init__(self,
                 resize_shorter_edge: int = 256,
                 crop_size: int = 256,
                 random_crop: bool = True,
                 random_flip: bool = True,
                 normalize_mean: List[float] = [0., 0., 0.],
                 normalize_std: List[float] = [1., 1., 1.]):
        """Initializes the ImageTransform with the specified augmentation parameters.

        Args:
            resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
            crop_size: An integer, the size to crop the input image to.
            random_crop: A boolean, whether to use random crop augmentation during training.
            random_flip: A boolean, whether to use random flipping augmentation during training.
            normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
            normalize_std: A list of float, the normalization std used to normalize the image tensor.
        """
        train_transform = []
        interpolation = transforms.InterpolationMode.BICUBIC
        train_transform.append(
            transforms.Resize(resize_shorter_edge, interpolation=interpolation, antialias=True))
        if random_crop:
            train_transform.append(transforms.RandomCrop(crop_size))
        else:
            train_transform.append(transforms.CenterCrop(crop_size))
        if random_flip:
            train_transform.append(transforms.RandomHorizontalFlip())
        train_transform.append(transforms.ToTensor())
        # normalize_mean = [0, 0, 0] with normalize_std = [1, 1, 1] keeps images in [0, 1];
        # normalize_mean = [0.5, 0.5, 0.5] with normalize_std = [0.5, 0.5, 0.5] maps them to [-1, 1].
        train_transform.append(transforms.Normalize(normalize_mean, normalize_std))
        self.train_transform = transforms.Compose(train_transform)
        self.eval_transform = transforms.Compose(
            [
                # Always resize to crop_size during eval so results can be compared
                # against reference numbers on ImageNet etc.
                transforms.Resize(crop_size, interpolation=interpolation, antialias=True),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor(),
                transforms.Normalize(normalize_mean, normalize_std)
            ]
        )
        print(f"self.train_transform: {self.train_transform}")
        print(f"self.eval_transform: {self.eval_transform}")

class SimpleImageDataset:
    def __init__(
        self,
        train_shards_path: Union[Text, List[Text]],
        eval_shards_path: Union[Text, List[Text]],
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers_per_gpu: int = 12,
        resize_shorter_edge: int = 256,
        crop_size: int = 256,
        random_crop=True,
        random_flip=True,
        normalize_mean: List[float] = [0., 0., 0.],
        normalize_std: List[float] = [1., 1., 1.],
        dataset_with_class_label: bool = True,
        dataset_with_text_label: bool = False,
        res_ratio_filtering=False,
        min_tokens=32,
        max_tokens=128,
    ):
        """Initializes the SimpleImageDataset class.

        Args:
            train_shards_path: A string or list of strings, path to the training data shards in webdataset format.
            eval_shards_path: A string or list of strings, path to the evaluation data shards in webdataset format.
            num_train_examples: An integer, total number of training examples.
            per_gpu_batch_size: An integer, number of examples per GPU batch.
            global_batch_size: An integer, total number of examples in a batch across all GPUs.
            num_workers_per_gpu: An integer, number of workers per GPU.
            resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
            crop_size: An integer, the size to crop the input image to.
            random_crop: A boolean, whether to use random crop augmentation during training.
            random_flip: A boolean, whether to use random flipping augmentation during training.
            normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
            normalize_std: A list of float, the normalization std used to normalize the image tensor.
            dataset_with_class_label: A boolean, whether samples carry a class label ('cls').
            dataset_with_text_label: A boolean, whether samples carry a text caption ('txt').
            res_ratio_filtering: A boolean, whether to filter samples by resolution and aspect ratio.
            min_tokens: An integer, minimum active length for the uniform attention mask.
            max_tokens: An integer, maximum active length (and size) of the uniform attention mask.
        """
        transform = ImageTransform(
            resize_shorter_edge, crop_size, random_crop, random_flip,
            normalize_mean, normalize_std)
        if dataset_with_class_label:
            train_processing_pipeline = [
                wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]),
                           handler=wds.warn_and_continue),
                wds.rename(
                    image="jpg;png;jpeg;webp",
                    class_id="cls",
                    handler=wds.warn_and_continue,
                ),
                wds.map(filter_keys(set(["image", "class_id", "filename"]))),
                wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
                wds.map_dict(
                    image=transform.train_transform,
                    class_id=lambda x: int(x),
                    attention_mask=lambda x: x,
                    handler=wds.warn_and_continue,
                ),
            ]
        elif dataset_with_text_label:
            train_processing_pipeline = [
                wds.map(load_json),
                wds.select(filter_by_res_ratio()) if res_ratio_filtering else wds.map(identity),
                wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]),
                           only=["webp", "png", "jpg", "jpeg", "txt"], handler=wds.warn_and_continue),
                wds.rename(
                    image="jpg;png;jpeg;webp",
                    text="txt",
                    handler=wds.warn_and_continue,
                ),
                wds.map(filter_keys(set(["image", "text", "__key__"]))),
                wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
                wds.map_dict(
                    image=transform.train_transform,
                    attention_mask=lambda x: x,
                    handler=wds.warn_and_continue,
                ),
            ]
        else:
            raise NotImplementedError

        test_processing_pipeline = [
            wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"]),
                       handler=wds.warn_and_continue),
            wds.rename(
                image="jpg;png;jpeg;webp",
                class_id="cls",
                handler=wds.warn_and_continue,
            ),
            wds.map(filter_keys(set(["image", "class_id", "filename"]))),
            wds.map(get_uniform_attention_mask(min_tokens=min_tokens, max_tokens=max_tokens)),
            wds.map_dict(
                image=transform.eval_transform,
                class_id=lambda x: int(x),
                attention_mask=lambda x: x,
                handler=wds.warn_and_continue,
            ),
        ]
        # Create the train dataset and loader.
        pipeline = [
            wds.ResampledShards(train_shards_path),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(bufsize=5000, initial=1000),
            *train_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]
        # Round the batch counts up so that every worker sees the same number of batches.
        num_worker_batches = math.ceil(num_train_examples /
                                       (global_batch_size * num_workers_per_gpu))
        num_batches = num_worker_batches * num_workers_per_gpu
        num_samples = num_batches * global_batch_size
        # Each worker iterates over the complete dataset.
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=True,
        )
        # Attach metadata to the dataloader instance for convenience.
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples
        # Create the eval dataset and loader.
        pipeline = [
            wds.SimpleShardList(eval_shards_path),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *test_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=True,
        )

    @property
    def train_dataset(self):
        return self._train_dataset

    @property
    def train_dataloader(self):
        return self._train_dataloader

    @property
    def eval_dataset(self):
        return self._eval_dataset

    @property
    def eval_dataloader(self):
        return self._eval_dataloader
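
# Usage sketch (illustrative; shard paths and sizes are placeholders):
#   dataset = SimpleImageDataset(
#       train_shards_path="data/train-{000000..000099}.tar",
#       eval_shards_path="data/val-{000000..000009}.tar",
#       num_train_examples=1_000_000,
#       per_gpu_batch_size=32,
#       global_batch_size=256,  # per_gpu_batch_size * number of GPUs
#   )
#   for batch in dataset.train_dataloader:
#       images, class_ids = batch["image"], batch["class_id"]
#       masks = batch["attention_mask"]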

class PretokenizedDataSetJSONL(Dataset):
    def __init__(self, data_path):
        super().__init__()
        self.jsonl_file = data_path
        with open(self.jsonl_file) as f:
            self.num_lines = sum(1 for _ in f)
        # Ensure linecache's view of the file is up to date.
        linecache.checkcache(self.jsonl_file)
        print("Number of data:", self.num_lines)

    def __len__(self):
        return self.num_lines

    def __getitem__(self, idx):
        # linecache is 1-indexed.
        line = linecache.getline(self.jsonl_file, idx + 1).strip()
        data = json.loads(line)
        return torch.tensor(data["class_id"]), torch.tensor(data["tokens"])
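
# Usage sketch (illustrative; the file path is a placeholder):
#   from torch.utils.data import DataLoader
#   ds = PretokenizedDataSetJSONL("tokens.jsonl")  # each line: {"class_id": ..., "tokens": [...]}
#   loader = DataLoader(ds, batch_size=64, shuffle=True, num_workers=4)
#   class_ids, tokens = next(iter(loader))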

class PretokenizedWebDataset(SimpleImageDataset):
    def __init__(
        self,
        train_shards_path: Union[Text, List[Text]],
        eval_shards_path: Union[Text, List[Text]],
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers_per_gpu: int,
        resize_shorter_edge: int = 256,
        crop_size: int = 256,
        random_crop=True,
        random_flip=True,
        normalize_mean: List[float] = [0., 0., 0.],
        normalize_std: List[float] = [1., 1., 1.],
        process_recap=False,
        use_recap_prob=0.95,
    ):
        """Initializes the PretokenizedWebDataset class.

        Text-to-image datasets are pretokenized with careful filtering
        (Tab. 7 in the supplementary material) to speed up training.
        """
        transform = ImageTransform(
            resize_shorter_edge, crop_size, random_crop, random_flip,
            normalize_mean, normalize_std)
        def decode_npy(x):
            # Pretokenized tokens are stored as raw float16 bytes.
            arr = np.frombuffer(x, dtype=np.float16)
            return torch.tensor(arr)

        def decode_text(x):
            return x.decode("utf-8")

        train_processing_pipeline = [
            wds.rename(
                tokens="token.npy",
                text="txt",
                handler=wds.warn_and_continue,
            ),
            wds.map(process_recap_text(use_recap_prob)) if process_recap else wds.map(identity),
            wds.map(filter_keys(set(["tokens", "text", "aes_score", "__key__"]))),
            wds.map_dict(
                tokens=decode_npy,
                text=decode_text,
                handler=wds.warn_and_continue,
            ),
        ]
        test_processing_pipeline = [
            wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])),
            wds.rename(
                image="jpg;png;jpeg;webp",
                handler=wds.warn_and_continue,
            ),
            wds.map_dict(
                image=transform.eval_transform,
                handler=wds.warn_and_continue,
            ),
        ]
        # Create the train dataset and loader.
        pipeline = [
            wds.ResampledShards(train_shards_path),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(bufsize=5000, initial=1000),
            *train_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]
        # Round the batch counts up so that every worker sees the same number of batches.
        num_worker_batches = math.ceil(num_train_examples /
                                       (global_batch_size * num_workers_per_gpu))
        num_batches = num_worker_batches * num_workers_per_gpu
        num_samples = num_batches * global_batch_size
        # Each worker iterates over the complete dataset.
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=True,
        )
        # Attach metadata to the dataloader instance for convenience.
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples
        # Create the eval dataset and loader.
        pipeline = [
            wds.SimpleShardList(eval_shards_path),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *test_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=True,
        )
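
# Usage sketch (illustrative; shard paths and sizes are placeholders):
#   dataset = PretokenizedWebDataset(
#       train_shards_path="data/t2i-{000000..000499}.tar",
#       eval_shards_path="data/t2i-val-{000000..000009}.tar",
#       num_train_examples=10_000_000,
#       per_gpu_batch_size=64,
#       global_batch_size=512,
#       num_workers_per_gpu=12,
#       process_recap=True,  # rewrite recap captions with probability use_recap_prob
#   )
#   for batch in dataset.train_dataloader:
#       tokens, text = batch["tokens"], batch["text"]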