#!/usr/bin/env python3
"""Basic encode-decode example for VibeToken.
Demonstrates how to:
1. Load the tokenizer from config and checkpoint
2. Encode an image to discrete tokens
3. Decode tokens back to an image
4. Save the reconstructed image
Usage:
# Auto mode (recommended)
python examples/encode_decode.py --auto \
--config configs/vibetoken_ll.yaml \
--checkpoint path/to/checkpoint.bin \
--image path/to/image.jpg \
--output reconstructed.png
# Manual mode
python examples/encode_decode.py \
--config configs/vibetoken_ll.yaml \
--checkpoint path/to/checkpoint.bin \
--image path/to/image.jpg \
--output reconstructed.png \
--encoder_patch_size 16,32 \
--decoder_patch_size 16
"""
import argparse
from pathlib import Path
import torch
from PIL import Image
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
def parse_patch_size(value):
    """Parse a patch-size command-line argument.

    Args:
        value: ``None``, a single integer string (e.g. ``'16'``), or two
            comma-separated integers as H,W (e.g. ``'16,32'``). Surrounding
            whitespace and a trailing comma are tolerated.

    Returns:
        ``None`` if *value* is ``None``, an ``int`` for a single size, or an
        ``(int, int)`` tuple for an H,W pair.

    Raises:
        ValueError: if *value* is not one or two comma-separated integers.
    """
    if value is None:
        return None
    # Strip whitespace and drop empty segments so '16,' or ' 16 , 32 ' parse cleanly
    # (the previous version crashed on a trailing comma via int('')).
    parts = [p.strip() for p in value.split(',') if p.strip()]
    try:
        sizes = tuple(int(p) for p in parts)
    except ValueError as err:
        raise ValueError(f"Invalid patch size {value!r}: parts must be integers") from err
    if len(sizes) == 1:
        return sizes[0]
    if len(sizes) == 2:
        return sizes
    # Previously extra parts were silently discarded; reject them instead.
    raise ValueError(f"Patch size must be 1 or 2 integers, got {value!r}")
def main():
    """Run the VibeToken encode-decode round trip.

    Loads the tokenizer from config + checkpoint, encodes the input image to
    discrete tokens, decodes them back to pixels, saves the reconstruction,
    and reports PSNR against the (possibly center-cropped) input.
    """
    parser = argparse.ArgumentParser(description="VibeToken encode-decode example")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--output", type=str, default="reconstructed.png", help="Output image path")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
    # Auto mode
    parser.add_argument("--auto", action="store_true",
                        help="Auto mode: automatically determine optimal settings")
    parser.add_argument("--height", type=int, default=None, help="Output height (default: input height)")
    parser.add_argument("--width", type=int, default=None, help="Output width (default: input width)")
    parser.add_argument("--encoder_patch_size", type=str, default=None,
                        help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--decoder_patch_size", type=str, default=None,
                        help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--num_tokens", type=int, default=None, help="Number of tokens to encode")
    args = parser.parse_args()

    # Fall back to CPU when CUDA was requested but is unavailable.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"

    print(f"Loading tokenizer from {args.config}")
    tokenizer = VibeTokenTokenizer.from_config(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        device=args.device,
    )
    print(f"Tokenizer loaded: codebook_size={tokenizer.codebook_size}, "
          f"num_latent_tokens={tokenizer.num_latent_tokens}")

    # Load image
    print(f"Loading image from {args.image}")
    image = Image.open(args.image).convert("RGB")
    original_size = image.size  # PIL convention: (W, H)
    print(f"Original image size: {original_size[0]}x{original_size[1]}")

    if args.auto:
        # AUTO MODE: auto_preprocess_image decides the crop and patch size.
        # Manual-only flags are ignored here; previously this was silent —
        # warn so users aren't surprised their settings had no effect.
        if args.encoder_patch_size or args.decoder_patch_size or args.num_tokens:
            print("Warning: --encoder_patch_size/--decoder_patch_size/--num_tokens "
                  "are ignored in --auto mode")
        print("\n=== AUTO MODE ===")
        image, patch_size, info = auto_preprocess_image(image, verbose=True)
        encoder_patch_size = patch_size
        decoder_patch_size = patch_size
        # info['cropped_size'] is (W, H); decode wants height/width separately.
        height, width = info['cropped_size'][1], info['cropped_size'][0]
        print("=================\n")

        # Encode to tokens
        print("Encoding image to tokens...")
        print(f"  Using encoder patch size: {encoder_patch_size}")
        tokens = tokenizer.encode(image, patch_size=encoder_patch_size)
        print(f"Token shape: {tokens.shape}")

        # Decode back to image
        print(f"Decoding tokens to image ({width}x{height})...")
        print(f"  Using decoder patch size: {decoder_patch_size}")
        reconstructed = tokenizer.decode(
            tokens, height=height, width=width, patch_size=decoder_patch_size
        )
    else:
        # MANUAL MODE
        # Parse patch sizes (None means "use the tokenizer's default").
        encoder_patch_size = parse_patch_size(args.encoder_patch_size)
        decoder_patch_size = parse_patch_size(args.decoder_patch_size)

        # Always center crop to ensure dimensions divisible by 32.
        image = center_crop_to_multiple(image, multiple=32)
        cropped_size = image.size  # (W, H)
        if cropped_size != original_size:
            print(f"Center cropped to {cropped_size[0]}x{cropped_size[1]} (divisible by 32)")

        # Encode to tokens
        print("Encoding image to tokens...")
        if encoder_patch_size is not None:
            print(f"  Using encoder patch size: {encoder_patch_size}")
        tokens = tokenizer.encode(image, patch_size=encoder_patch_size, num_tokens=args.num_tokens)
        print(f"Token shape: {tokens.shape}")
        if tokenizer.model.quantize_mode == "mvq":
            # Multi-codebook quantization: tokens are (batch, codebooks, sequence).
            print(f"  - Batch size: {tokens.shape[0]}")
            print(f"  - Num codebooks: {tokens.shape[1]}")
            print(f"  - Sequence length: {tokens.shape[2]}")
        else:
            # Single codebook: tokens are (batch, sequence).
            print(f"  - Batch size: {tokens.shape[0]}")
            print(f"  - Sequence length: {tokens.shape[1]}")

        # Decode back to image (use cropped size as default)
        height = args.height or cropped_size[1]
        width = args.width or cropped_size[0]
        print(f"Decoding tokens to image ({width}x{height})...")
        if decoder_patch_size is not None:
            print(f"  Using decoder patch size: {decoder_patch_size}")
        reconstructed = tokenizer.decode(
            tokens, height=height, width=width, patch_size=decoder_patch_size
        )

    print(f"Reconstructed image shape: {reconstructed.shape}")

    # Convert to PIL and save
    output_images = tokenizer.to_pil(reconstructed)
    output_path = Path(args.output)
    output_images[0].save(output_path)
    print(f"Saved reconstructed image to {output_path}")

    # Compute PSNR (compare with the cropped input, which matches decode size).
    import numpy as np
    original_np = np.array(image).astype(np.float32)
    recon_np = np.array(output_images[0]).astype(np.float32)
    if original_np.shape == recon_np.shape:
        mse = np.mean((original_np - recon_np) ** 2)
        if mse > 0:
            psnr = 20 * np.log10(255.0 / np.sqrt(mse))
            print(f"PSNR: {psnr:.2f} dB")
        else:
            # mse == 0 means a bit-exact reconstruction; previously this
            # case printed nothing at all.
            print("PSNR: inf dB (exact reconstruction)")
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()