| | |
| | """Basic encode-decode example for VibeToken. |
| | |
| | Demonstrates how to: |
| | 1. Load the tokenizer from config and checkpoint |
| | 2. Encode an image to discrete tokens |
| | 3. Decode tokens back to an image |
| | 4. Save the reconstructed image |
| | |
| | Usage: |
| | # Auto mode (recommended) |
| | python examples/encode_decode.py --auto \ |
| | --config configs/vibetoken_ll.yaml \ |
| | --checkpoint path/to/checkpoint.bin \ |
| | --image path/to/image.jpg \ |
| | --output reconstructed.png |
| | |
| | # Manual mode |
| | python examples/encode_decode.py \ |
| | --config configs/vibetoken_ll.yaml \ |
| | --checkpoint path/to/checkpoint.bin \ |
| | --image path/to/image.jpg \ |
| | --output reconstructed.png \ |
| | --encoder_patch_size 16,32 \ |
| | --decoder_patch_size 16 |
| | """ |
| |
|
| | import argparse |
| | from pathlib import Path |
| |
|
| | import torch |
| | from PIL import Image |
| |
|
| | import sys |
| | sys.path.insert(0, str(Path(__file__).parent.parent)) |
| |
|
| | from vibetoken import VibeTokenTokenizer, auto_preprocess_image, center_crop_to_multiple |
| |
|
| |
|
def parse_patch_size(value):
    """Parse a patch-size command-line string into an int or an (H, W) tuple.

    Accepts either a single integer (``"16"``) or exactly two comma-separated
    integers (``"16,32"``, read as height, width). Whitespace around each
    number is tolerated because ``int()`` strips it.

    Args:
        value: The raw string from the command line, or ``None`` when the
            option was not given.

    Returns:
        ``None`` if *value* is ``None``; an ``int`` for the single-value form;
        a ``(height, width)`` tuple of ints for the comma-separated form.

    Raises:
        ValueError: If *value* is not one integer or exactly two
            comma-separated integers (e.g. ``"16,32,64"``, ``"16,"``, ``""``).
    """
    if value is None:
        return None
    parts = value.split(",")
    if len(parts) == 1:
        return int(parts[0])
    if len(parts) == 2:
        return (int(parts[0]), int(parts[1]))
    # Previously extra components were silently dropped; reject them instead.
    raise ValueError(
        f"Invalid patch size {value!r}: expected a single int (e.g. '16') "
        f"or two comma-separated ints (e.g. '16,32')"
    )
| |
|
| |
|
def _build_arg_parser():
    """Return the argument parser for this example's command-line interface."""
    parser = argparse.ArgumentParser(description="VibeToken encode-decode example")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--output", type=str, default="reconstructed.png", help="Output image path")
    parser.add_argument("--device", type=str, default="cuda", help="Device (cuda/cpu)")

    # Auto mode delegates cropping and patch-size selection to auto_preprocess_image.
    parser.add_argument("--auto", action="store_true",
                        help="Auto mode: automatically determine optimal settings")

    # Options below are manual-mode settings (except --num_tokens, used in both modes).
    parser.add_argument("--height", type=int, default=None, help="Output height (default: input height)")
    parser.add_argument("--width", type=int, default=None, help="Output width (default: input width)")
    parser.add_argument("--encoder_patch_size", type=str, default=None,
                        help="Encoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--decoder_patch_size", type=str, default=None,
                        help="Decoder patch size: single int (e.g., 16) or tuple (e.g., 16,32 for H,W)")
    parser.add_argument("--num_tokens", type=int, default=None, help="Number of tokens to encode")
    return parser


def _print_token_shape(tokenizer, tokens):
    """Print the token tensor's shape followed by a per-axis breakdown."""
    print(f"Token shape: {tokens.shape}")
    print(f"  - Batch size: {tokens.shape[0]}")
    # In "mvq" quantize mode the tokens carry a codebook axis (axis 1) ahead of
    # the sequence axis; otherwise axis 1 is the sequence axis.
    if getattr(tokenizer.model, "quantize_mode", None) == "mvq":
        print(f"  - Num codebooks: {tokens.shape[1]}")
        print(f"  - Sequence length: {tokens.shape[2]}")
    else:
        print(f"  - Sequence length: {tokens.shape[1]}")


def _report_psnr(reference, reconstruction):
    """Print the PSNR in dB between two PIL images.

    Assumes 8-bit pixel values (peak value 255). Prints a skip notice instead
    of a PSNR when the two images' array shapes differ, and ``inf`` when the
    images are identical.
    """
    import numpy as np

    ref = np.asarray(reference, dtype=np.float32)
    rec = np.asarray(reconstruction, dtype=np.float32)
    if ref.shape != rec.shape:
        print(f"Skipping PSNR: reconstruction shape {rec.shape} "
              f"differs from input shape {ref.shape}")
        return
    mse = float(np.mean((ref - rec) ** 2))
    if mse == 0:
        print("PSNR: inf dB (identical images)")
        return
    psnr = 20 * np.log10(255.0 / np.sqrt(mse))
    print(f"PSNR: {psnr:.2f} dB")


def main():
    """Encode an image to tokens, decode it back, save it and report PSNR.

    In ``--auto`` mode the crop and patch size come from
    ``auto_preprocess_image``; otherwise the image is center-cropped to a
    multiple of 32 and the patch sizes / output size come from the CLI flags.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    encoder_patch_size = decoder_patch_size = None
    if args.auto:
        # Tell the user about manual-only flags instead of silently dropping them.
        ignored = [
            flag for flag, value in (
                ("--height", args.height),
                ("--width", args.width),
                ("--encoder_patch_size", args.encoder_patch_size),
                ("--decoder_patch_size", args.decoder_patch_size),
            )
            if value is not None
        ]
        if ignored:
            print(f"Note: {', '.join(ignored)} ignored in auto mode")
    else:
        # Validate before loading the checkpoint so a typo fails fast and cleanly.
        try:
            encoder_patch_size = parse_patch_size(args.encoder_patch_size)
            decoder_patch_size = parse_patch_size(args.decoder_patch_size)
        except ValueError as err:
            parser.error(f"invalid patch size: {err}")

    # Match any CUDA device string (e.g. "cuda:1"), not only the bare "cuda".
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"

    print(f"Loading tokenizer from {args.config}")
    tokenizer = VibeTokenTokenizer.from_config(
        config_path=args.config,
        checkpoint_path=args.checkpoint,
        device=args.device,
    )
    print(f"Tokenizer loaded: codebook_size={tokenizer.codebook_size}, "
          f"num_latent_tokens={tokenizer.num_latent_tokens}")

    print(f"Loading image from {args.image}")
    # convert() returns a fully loaded copy, so the source file can be closed.
    with Image.open(args.image) as source:
        image = source.convert("RGB")
    original_size = image.size
    print(f"Original image size: {original_size[0]}x{original_size[1]}")

    if args.auto:
        print("\n=== AUTO MODE ===")
        image, patch_size, info = auto_preprocess_image(image, verbose=True)
        encoder_patch_size = decoder_patch_size = patch_size
        height, width = info['cropped_size'][1], info['cropped_size'][0]
        print("=================\n")
    else:
        # Center-crop so both dimensions are multiples of 32.
        image = center_crop_to_multiple(image, multiple=32)
        cropped_size = image.size
        if cropped_size != original_size:
            print(f"Center cropped to {cropped_size[0]}x{cropped_size[1]} (divisible by 32)")
        height = args.height or cropped_size[1]
        width = args.width or cropped_size[0]

    print("Encoding image to tokens...")
    if encoder_patch_size:
        print(f"  Using encoder patch size: {encoder_patch_size}")
    tokens = tokenizer.encode(image, patch_size=encoder_patch_size, num_tokens=args.num_tokens)
    _print_token_shape(tokenizer, tokens)

    print(f"Decoding tokens to image ({width}x{height})...")
    if decoder_patch_size:
        print(f"  Using decoder patch size: {decoder_patch_size}")
    reconstructed = tokenizer.decode(
        tokens, height=height, width=width, patch_size=decoder_patch_size
    )
    print(f"Reconstructed image shape: {reconstructed.shape}")

    output_images = tokenizer.to_pil(reconstructed)
    output_path = Path(args.output)
    # Allow --output to point into a directory that does not exist yet.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_images[0].save(output_path)
    print(f"Saved reconstructed image to {output_path}")

    # Compare against the (possibly cropped) image that was actually encoded.
    _report_psnr(image, output_images[0])
|
| |
|
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
| |
|