---
license: apache-2.0
tags:
- vision
---

# VisualSplit

**VisualSplit** is a ViT-based model that explicitly factorises an image into **classical visual descriptors**, such as **edges**, **color segmentation**, and a **grayscale histogram**, and learns to reconstruct the image conditioned on those descriptors. This design yields **interpretable representations** in which geometry (edges), albedo/appearance (segmented colors), and global tone (histogram) can be reasoned about or varied independently.

> **Training data**: ImageNet-1K.

---

## Model Description

- **Inputs** (at inference):
  - An RGB image, which is converted to descriptors by the provided `FeatureExtractor` (edges, color segmentation, grayscale histogram); a rough illustration of these descriptors is sketched below.
- **Outputs**:
  - A reconstructed RGB image tensor (same spatial size as the model’s training resolution; `224×224` by default unless you trained otherwise).
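
For intuition, the snippet below computes rough stand-ins for these descriptors with PIL/NumPy. It is illustrative only: the repo's `FeatureExtractor` is the authoritative implementation, and the exact edge detector, segmentation method, and histogram binning used in training will differ.

```python
# Illustrative descriptor sketch (NOT the repo's FeatureExtractor):
# edge map for geometry, grayscale histogram for global tone,
# coarse color quantisation as a stand-in for color segmentation.
import numpy as np
from PIL import Image, ImageFilter

img = Image.open("input.jpg").convert("RGB").resize((224, 224))

# Edge map (geometry)
edges = img.convert("L").filter(ImageFilter.FIND_EDGES)

# Grayscale histogram (global tone): 256-bin intensity histogram
hist, _ = np.histogram(np.asarray(img.convert("L")), bins=256, range=(0, 256))

# Coarse color segmentation (appearance): quantise to a small palette
segmented = img.quantize(colors=8).convert("RGB")

print(edges.size, hist.shape, segmented.size)
```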

---

## Getting Started (Inference)

Below are two ways to run inference with the uploaded `model.safetensors`.

### 1) Minimal PyTorch + safetensors (load state dict)

```python
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# 1) Import the model & config from the VisualSplit repo
from visualsplit.models.CrossViT import CrossViTForPreTraining, CrossViTConfig
from visualsplit.utils import FeatureExtractor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) Build a config matching your training run (edit if you changed widths/depths)
config = CrossViTConfig(
    image_size=224,  # change if your training size differs
    patch_size=16,
    # ... any other config fields your repo exposes
)

model = CrossViTForPreTraining(config).to(device)
model.eval()

# 3) Download and load the state dict from this model repo
#    (replace REPO_ID with your Hugging Face model id, e.g. "HenryQUQ/visualsplit")
ckpt_path = hf_hub_download(repo_id="REPO_ID", filename="model.safetensors")
state_dict = load_file(ckpt_path)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

# 4) Prepare an input image and extract descriptors
from PIL import Image
from torchvision import transforms

image = Image.open("input.jpg").convert("RGB")
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
])
pixel_values = transform(image).unsqueeze(0).to(device)  # (1, 3, H, W)

# The FeatureExtractor provided by the repo returns the required descriptor tensors
extractor = FeatureExtractor().to(device)
with torch.no_grad():
    edge, gray_hist, segmented_rgb, _ = extractor(pixel_values)

# 5) Run inference (reconstruction)
with torch.no_grad():
    outputs = model(
        source_edge=edge,
        source_gray_level_histogram=gray_hist,
        source_segmented_rgb=segmented_rgb,
    )
# Your repo's forward returns may differ; adjust the key accordingly:
reconstructed = outputs["logits_reshape"]  # (1, 3, H, W)

# 6) Convert to PIL for visualisation
to_pil = transforms.ToPILImage()
recon_img = to_pil(reconstructed.squeeze(0).cpu().clamp(0, 1))
recon_img.save("reconstructed.png")
print("Saved to reconstructed.png")
```

### 2) Reproducing the notebook flow (`notebook/validation.ipynb`)

The repository provides a validation notebook that:
1. Loads the trained model,
2. Uses `FeatureExtractor` to compute **edges**, **color-segmented RGB**, and **grayscale histograms**,
3. Runs the model to obtain a reconstructed image,
4. Saves/visualises the result (see the sketch below).
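
If you want a quick in-script equivalent of the notebook's visualisation step, a minimal sketch is shown below. It assumes the `image` and `reconstructed` variables from the snippet in section 1 and uses `matplotlib`, which is not required by the model itself; the actual notebook may visualise things differently.

```python
# Minimal visualisation sketch (assumes `image` and `reconstructed` from section 1).
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(image)
axes[0].set_title("Input")
axes[1].imshow(reconstructed.squeeze(0).cpu().clamp(0, 1).permute(1, 2, 0).numpy())
axes[1].set_title("Reconstruction")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()
```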

---

## Installation & Requirements

```bash
# clone the VisualSplit code
git clone https://github.com/HenryQUQ/VisualSplit.git
cd VisualSplit

# install the core dependencies used by the examples above
pip install torch torchvision safetensors huggingface_hub pillow

# optionally install the repo itself in editable mode (if it provides a setup file)
# pip install -e .
```

---

## Training Data

- **Dataset**: **ImageNet-1K**.

> This repository only hosts the **trained checkpoint for inference**. Follow the GitHub repo for the full training pipeline and data preparation scripts.

---

## Model Sources

- **Code**: https://github.com/HenryQUQ/VisualSplit
- **Weights (this page)**: this Hugging Face model repo

---

## Citation

If you use this model or its ideas, please cite:

```bibtex
@inproceedings{Qu2025VisualSplit,
  title = {Exploring Image Representation with Decoupled Classical Visual Descriptors},
  author = {Qu, Chenyuan and Chen, Hao and Jiao, Jianbo},
  booktitle = {British Machine Vision Conference (BMVC)},
  year = {2025}
}
```

---