Elastic model: Wan2.2-T2V-A14B

Overview


Elastic models are produced by TheStage AI ANNA (Automated Neural Networks Accelerator). ANNA lets you control model size, latency, and quality with a simple slider movement by routing different compression algorithms to different layers. For each supported model, we produce a series of optimized variants:

  • S: The fastest model, with accuracy degradation of less than 2%.

Models can be accessed via TheStage AI Python SDK: ElasticModels.

Installation


System Requirements

| Property | Value |
|----------|-------|
| GPU | H100, B200 |
| Python Version | 3.10–3.12 |
| CPU | Intel/AMD x86_64 |
| CUDA Version | 12.8+ |

TheStage AI Access token setup

Install the TheStage AI CLI and set up your API token:

pip install thestage
thestage config set --access-token <YOUR_ACCESS_TOKEN>

ElasticModels installation

Install TheStage Elastic Models package:

pip install thestage
pip install 'thestage-elastic-models[nvidia]' \
    --extra-index-url https://thestage.jfrog.io/artifactory/api/pypi/pypi-thestage-ai-production/simple
pip install diffusers==0.35.1
pip install flash_attn==2.7.3 --no-build-isolation
pip uninstall -y apex
pip install opencv-python==4.11.0.86 imageio-ffmpeg==0.6.0 imageio ftfy
pip install hf_transfer

Usage example


Elastic Models provide the same interface as HuggingFace Diffusers. To run inference with our models, simply replace the diffusers import with elastic_models.diffusers. Here is an example of how to use the Wan2.2-T2V-A14B model:

import torch
from elastic_models.diffusers import WanPipeline
from diffusers.utils import export_to_video

model_name = 'Wan-AI/Wan2.2-T2V-A14B-Diffusers'
device = torch.device("cuda")
dtype = torch.bfloat16

pipe = WanPipeline.from_pretrained(
    model_name,
    torch_dtype=dtype,
    # 'original' for original model
    # 'S' for accelerated model
    mode='S'
)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
pipe.to(device)

prompt = "A beautiful woman in a red dress dancing"

with torch.no_grad():
    output = pipe(
        prompt=prompt,
        negative_prompt="",
        height=480,
        width=480,
        num_frames=81,
        num_inference_steps=40,
        guidance_scale=3.0,
        guidance_scale_2=2.0,
        generator=torch.Generator("cuda").manual_seed(42),
    )

    video = output.frames[0]
    export_to_video(video, "output.mp4", fps=16)

Compiled versions are currently available only for 81-frame generation at 480x480 resolution; other configurations are not yet supported. Stay tuned for updates!
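Because a compiled model exists for only one configuration, it can be useful to fail fast before launching a long generation. Below is a hypothetical helper (not part of the elastic_models API) that validates requested settings against the single currently compiled configuration:

```python
# Hypothetical helper: the set of (mode, height, width, num_frames)
# tuples for which a compiled model currently exists.
SUPPORTED_CONFIGS = {("S", 480, 480, 81)}

def check_supported(mode: str, height: int, width: int, num_frames: int) -> None:
    """Raise ValueError if no compiled model exists for these settings."""
    if mode == "original":
        return  # the original model has no compilation constraints
    if (mode, height, width, num_frames) not in SUPPORTED_CONFIGS:
        raise ValueError(
            f"No compiled {mode} model for {height}x{width} with "
            f"{num_frames} frames; use 480x480 with 81 frames."
        )

check_supported("S", 480, 480, 81)  # OK: matches the compiled configuration
```

Calling this guard before `pipe(...)` turns an unsupported request into an immediate, descriptive error rather than a failure mid-run.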

Quality Benchmarks


We have used VBench to evaluate the quality of videos generated by different sizes of Wan2.2-T2V-A14B models compared to the original model. The evaluation metrics cover subject consistency, background consistency, motion smoothness, dynamic degree, aesthetic quality, and imaging quality.

Quality Benchmark Results

| Metric / Model Size | S | M | L | XL | Original |
|---------------------|------|-----|-----|-----|----------|
| Subject Consistency | 0.96 | N/A | N/A | N/A | 0.96 |
| Background Consistency | 0.96 | N/A | N/A | N/A | 0.96 |
| Motion Smoothness | 0.98 | N/A | N/A | N/A | 0.98 |
| Dynamic Degree | 0.29 | N/A | N/A | N/A | 0.29 |
| Aesthetic Quality | 0.62 | N/A | N/A | N/A | 0.62 |
| Imaging Quality | 0.68 | N/A | N/A | N/A | 0.68 |
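The claimed accuracy degradation of less than 2% can be verified directly from the reported scores. A quick sketch using the table's values:

```python
# VBench scores from the table above, as (S, Original) pairs.
scores = {
    "Subject Consistency":    (0.96, 0.96),
    "Background Consistency": (0.96, 0.96),
    "Motion Smoothness":      (0.98, 0.98),
    "Dynamic Degree":         (0.29, 0.29),
    "Aesthetic Quality":      (0.62, 0.62),
    "Imaging Quality":        (0.68, 0.68),
}

def max_relative_degradation(scores):
    """Largest relative drop of the S variant vs. the original model."""
    return max((orig - s) / orig for s, orig in scores.values())

print(max_relative_degradation(scores))  # 0.0: S matches the original on every metric
```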

Dataset


  • VBench: A comprehensive benchmark suite for video generative models. It evaluates videos across multiple dimensions of quality including temporal consistency, motion naturalness, aesthetic appeal, and prompt faithfulness. See VBench GitHub for details.

Metrics


  • Subject Consistency: Measures how consistently the main subject appears across frames.
  • Background Consistency: Measures temporal consistency of the background across frames.
  • Motion Smoothness: Evaluates the naturalness and smoothness of motion between frames.
  • Dynamic Degree: Measures the amount of motion/dynamics present in the generated video.
  • Aesthetic Quality: Assesses the overall visual appeal and aesthetic of generated frames.
  • Imaging Quality: Evaluates the per-frame visual quality and sharpness of generated content.
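VBench's actual metrics rely on learned feature extractors (see the VBench GitHub for the real implementations). Purely as an illustration of the idea behind the consistency metrics, a consistency-style score can be sketched as the mean cosine similarity between feature vectors of consecutive frames:

```python
import numpy as np

def temporal_consistency(frame_features: np.ndarray) -> float:
    """Mean cosine similarity between consecutive frame feature vectors.

    frame_features: array of shape (num_frames, feature_dim).
    Illustrative stand-in for VBench's learned-feature metrics, not
    the benchmark's actual implementation.
    """
    f = frame_features / np.linalg.norm(frame_features, axis=1, keepdims=True)
    return float(np.mean(np.sum(f[:-1] * f[1:], axis=1)))

# Identical frames give a perfect score of 1.0.
features = np.ones((81, 16))
print(temporal_consistency(features))  # 1.0
```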

Latency Benchmarks


We have measured the latency of different sizes of the Wan2.2-T2V-A14B model (S, original) on various GPUs. Measurements were taken while generating 480x480 videos with 81 frames.

Latency Benchmark Results

Latency (in seconds) for generating a 480x480 video with 81 frames on various hardware setups.

| GPU / Model Size | S | M | L | XL | Original |
|------------------|----|-----|-----|-----|----------|
| H100 | 90 | N/A | N/A | N/A | 180 |
| B200 | 99 | N/A | N/A | N/A | 117 |
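The speedup of the S variant over the original differs notably by GPU; a quick check using the table's numbers:

```python
# Latencies in seconds from the table above (480x480, 81 frames).
latency = {
    "H100": {"S": 90, "Original": 180},
    "B200": {"S": 99, "Original": 117},
}

for gpu, t in latency.items():
    speedup = t["Original"] / t["S"]
    print(f"{gpu}: {speedup:.2f}x speedup")
# H100: 2.00x speedup
# B200: 1.18x speedup
```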

Benchmarking Methodology


The benchmarking was performed on a single GPU with a batch size of 1. Each model was run for 5 iterations, and the average latency was calculated.

Algorithm summary:

  1. Load the Wan2.2-T2V-A14B model with the specified size (S, original).
  2. Move the model to the GPU.
  3. Prepare a sample prompt for video generation.
  4. Run the model for a number of iterations (e.g., 5) and measure the time taken for each iteration. On each iteration:
    • Synchronize the GPU to flush any previous operations.
    • Record the start time.
    • Generate the video using the model.
    • Synchronize the GPU again.
    • Record the end time and calculate the latency for that iteration.
  5. Calculate the average latency over all iterations.

Reproduce benchmarking


import time
import torch
from elastic_models.diffusers import WanPipeline

model_name = 'Wan-AI/Wan2.2-T2V-A14B-Diffusers'
device = torch.device("cuda")

pipe = WanPipeline.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    mode='S'
)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
pipe.to(device)

prompt = "A beautiful woman in a red dress dancing"
generate_kwargs = {
    "prompt": prompt,
    "negative_prompt": "",
    "height": 480,
    "width": 480,
    "num_frames": 81,
    "num_inference_steps": 40,
    "guidance_scale": 3.0,
    "guidance_scale_2": 2.0,
}

def evaluate_pipeline():
    torch.cuda.synchronize()
    start_time = time.time()
    with torch.no_grad():
        pipe(**generate_kwargs)
    torch.cuda.synchronize()
    return time.time() - start_time

# Warm-up
for _ in range(2):
    evaluate_pipeline()

# Benchmarking
num_runs = 5
total_time = sum(evaluate_pipeline() for _ in range(num_runs))
average_latency = total_time / num_runs
print(f"Average Latency over {num_runs} runs: {average_latency:.2f} seconds")
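Averaging alone hides run-to-run variance. A small extension could keep the per-run latencies and report mean plus sample standard deviation; the timings below are illustrative placeholders, not measured values:

```python
import statistics

def summarize(latencies):
    """Mean and sample standard deviation of per-run latencies (seconds)."""
    return statistics.mean(latencies), statistics.stdev(latencies)

# Example with hypothetical per-run timings:
mean, std = summarize([90.2, 89.8, 90.5, 89.9, 90.1])
print(f"{mean:.2f} +/- {std:.2f} s")  # prints "90.10 +/- 0.27 s"
```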
