# VibeToken/examples/encode_decode.py
# (Provenance from the original file listing: uploaded by "APGASU",
#  commit 7bef20f, verified.)
#!/usr/bin/env python3
"""Basic encode-decode example for VibeToken.
Demonstrates how to:
1. Load the tokenizer from config and checkpoint
2. Encode an image to discrete tokens
3. Decode tokens back to an image
4. Save the reconstructed image
Usage:
# Auto mode (recommended)
python examples/encode_decode.py --auto \
--config configs/vibetoken_ll.yaml \
--checkpoint path/to/checkpoint.bin \
--image path/to/image.jpg \
--output reconstructed.png
# Manual mode
python examples/encode_decode.py \
--config configs/vibetoken_ll.yaml \
--checkpoint path/to/checkpoint.bin \
--image path/to/image.jpg \
--output reconstructed.png \
--encoder_patch_size 16,32 \
--decoder_patch_size 16
"""
import argparse
from pathlib import Path
import torch
from PIL import Image
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple
def parse_patch_size(value):
    """Parse a patch-size CLI string.

    Accepts ``None`` (passed through unchanged), a single integer such as
    ``"16"``, or a comma-separated pair such as ``"16,32"`` which is
    returned as an ``(int, int)`` tuple.
    """
    if value is None:
        return None
    if ',' not in value:
        return int(value)
    first, second = value.split(',')[:2]
    return (int(first), int(second))
def _describe_tokens(tokenizer, tokens):
    """Print the token tensor layout, accounting for the quantizer mode."""
    print(f"Token shape: {tokens.shape}")
    if tokenizer.model.quantize_mode == "mvq":
        # Multi-codebook VQ lays tokens out as (batch, codebooks, sequence);
        # NOTE(review): inferred from the indexing below — confirm upstream.
        print(f" - Batch size: {tokens.shape[0]}")
        print(f" - Num codebooks: {tokens.shape[1]}")
        print(f" - Sequence length: {tokens.shape[2]}")
    else:
        print(f" - Batch size: {tokens.shape[0]}")
        print(f" - Sequence length: {tokens.shape[1]}")


def _report_psnr(reference, reconstruction):
    """Print the PSNR (dB) between two equally-sized PIL images.

    Skips the report when the shapes differ (e.g. the decoder produced a
    different resolution than the preprocessed input).
    """
    import numpy as np

    ref = np.array(reference).astype(np.float32)
    rec = np.array(reconstruction).astype(np.float32)
    if ref.shape != rec.shape:
        return
    mse = np.mean((ref - rec) ** 2)
    if mse > 0:
        psnr = 20 * np.log10(255.0 / np.sqrt(mse))
        print(f"PSNR: {psnr:.2f} dB")
    else:
        # Fix: the original printed nothing on a bit-exact reconstruction,
        # making a perfect result look like a missing metric.
        print("PSNR: inf dB (exact reconstruction)")


def main():
    """CLI entry point: round-trip an image through the tokenizer.

    Parses arguments, loads the tokenizer from config + checkpoint,
    preprocesses the input image (auto or manual center-crop), encodes it
    to discrete tokens, decodes the tokens back to pixels, saves the
    result, and reports PSNR against the preprocessed input.
    """
    parser = argparse.ArgumentParser(description="VibeToken encode-decode example")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--output", type=str, default="reconstructed.png", help="Output image path")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")
    # Auto mode
    parser.add_argument("--auto", action="store_true",
                        help="Auto mode: automatically determine optimal settings")
    parser.add_argument("--height", type=int, default=None, help="Output height (default: input height)")
    parser.add_argument("--width", type=int, default=None, help="Output width (default: input width)")
    parser.add_argument("--encoder_patch_size", type=str, default=None,
                        help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--decoder_patch_size", type=str, default=None,
                        help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--num_tokens", type=int, default=None, help="Number of tokens to encode")
    args = parser.parse_args()

    # Gracefully fall back to CPU when CUDA was requested but is absent.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"

    print(f"Loading tokenizer from {args.config}")
    tokenizer = VibeTokenTokenizer.from_config(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        device=args.device,
    )
    print(f"Tokenizer loaded: codebook_size={tokenizer.codebook_size}, "
          f"num_latent_tokens={tokenizer.num_latent_tokens}")

    print(f"Loading image from {args.image}")
    image = Image.open(args.image).convert("RGB")
    original_size = image.size  # PIL convention: (W, H)
    print(f"Original image size: {original_size[0]}x{original_size[1]}")

    if args.auto:
        # AUTO MODE: auto_preprocess_image picks the crop and patch size.
        # NOTE(review): --height/--width/--num_tokens are ignored here —
        # confirm that is intended.
        print("\n=== AUTO MODE ===")
        image, patch_size, info = auto_preprocess_image(image, verbose=True)
        encoder_patch_size = patch_size
        decoder_patch_size = patch_size
        # info['cropped_size'] is (W, H); decode takes height/width.
        height, width = info['cropped_size'][1], info['cropped_size'][0]
        print("=================\n")

        print("Encoding image to tokens...")
        print(f" Using encoder patch size: {encoder_patch_size}")
        tokens = tokenizer.encode(image, patch_size=encoder_patch_size)
        print(f"Token shape: {tokens.shape}")

        print(f"Decoding tokens to image ({width}x{height})...")
        print(f" Using decoder patch size: {decoder_patch_size}")
        reconstructed = tokenizer.decode(
            tokens, height=height, width=width, patch_size=decoder_patch_size
        )
    else:
        # MANUAL MODE: honour user-supplied patch sizes and output size.
        encoder_patch_size = parse_patch_size(args.encoder_patch_size)
        decoder_patch_size = parse_patch_size(args.decoder_patch_size)

        # Always center crop so both dimensions are divisible by 32.
        image = center_crop_to_multiple(image, multiple=32)
        cropped_size = image.size  # (W, H)
        if cropped_size != original_size:
            print(f"Center cropped to {cropped_size[0]}x{cropped_size[1]} (divisible by 32)")

        print("Encoding image to tokens...")
        if encoder_patch_size:
            print(f" Using encoder patch size: {encoder_patch_size}")
        tokens = tokenizer.encode(image, patch_size=encoder_patch_size, num_tokens=args.num_tokens)
        _describe_tokens(tokenizer, tokens)

        # Default the output resolution to the cropped input size.
        height = args.height or cropped_size[1]
        width = args.width or cropped_size[0]
        print(f"Decoding tokens to image ({width}x{height})...")
        if decoder_patch_size:
            print(f" Using decoder patch size: {decoder_patch_size}")
        reconstructed = tokenizer.decode(
            tokens, height=height, width=width, patch_size=decoder_patch_size
        )

    print(f"Reconstructed image shape: {reconstructed.shape}")

    # Convert to PIL and save the first image in the batch.
    output_images = tokenizer.to_pil(reconstructed)
    output_path = Path(args.output)
    output_images[0].save(output_path)
    print(f"Saved reconstructed image to {output_path}")

    # Compare against the (possibly cropped) input, not the raw original.
    _report_psnr(image, output_images[0])
if __name__ == "__main__":
main()