Buckets:

download
raw
109 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Reduce memory usage&quot;,&quot;local&quot;:&quot;reduce-memory-usage&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Multiple GPUs&quot;,&quot;local&quot;:&quot;multiple-gpus&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Sharded checkpoints&quot;,&quot;local&quot;:&quot;sharded-checkpoints&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Device placement&quot;,&quot;local&quot;:&quot;device-placement&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;VAE slicing&quot;,&quot;local&quot;:&quot;vae-slicing&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;VAE tiling&quot;,&quot;local&quot;:&quot;vae-tiling&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Offloading&quot;,&quot;local&quot;:&quot;offloading&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;CPU offloading&quot;,&quot;local&quot;:&quot;cpu-offloading&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Model offloading&quot;,&quot;local&quot;:&quot;model-offloading&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Group offloading&quot;,&quot;local&quot;:&quot;group-offloading&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;CUDA stream&quot;,&quot;local&quot;:&quot;cuda-stream&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4},{&quot;title&quot;:&quot;Offloading to disk&quot;,&quot;local&quot;:&quot;offloading-to-disk&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4}],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Layerwise 
casting&quot;,&quot;local&quot;:&quot;layerwise-casting&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;torch.channels_last&quot;,&quot;local&quot;:&quot;torchchannelslast&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;torch.jit.trace&quot;,&quot;local&quot;:&quot;torchjittrace&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Memory-efficient attention&quot;,&quot;local&quot;:&quot;memory-efficient-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/diffusers/pr_11686/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/entry/start.2b9667fb.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/scheduler.8c3d61f6.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/singletons.756349ae.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/index.0997d446.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/paths.8d5937da.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/entry/app.a2a6117e.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/index.da70eac4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/nodes/0.a31d0923.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/nodes/236.526edae7.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/Tip.1d9b8c37.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/CodeBlock.a9c4becf.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/getInferenceSnippets.d00e08ac.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/HfOption.6ab18950.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Reduce memory usage&quot;,&quot;local&quot;:&quot;reduce-memory-usage&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Multiple GPUs&quot;,&quot;local&quot;:&quot;multiple-gpus&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Sharded checkpoints&quot;,&quot;local&quot;:&quot;sharded-checkpoints&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Device placement&quot;,&quot;local&quot;:&quot;device-placement&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;VAE slicing&quot;,&quot;local&quot;:&quot;vae-slicing&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;VAE tiling&quot;,&quot;local&quot;:&quot;vae-tiling&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Offloading&quot;,&quot;local&quot;:&quot;offloading&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;CPU offloading&quot;,&quot;local&quot;:&quot;cpu-offloading&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Model offloading&quot;,&quot;local&quot;:&quot;model-offloading&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Group offloading&quot;,&quot;local&quot;:&quot;group-offloading&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;CUDA stream&quot;,&quot;local&quot;:&quot;cuda-stream&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4},{&quot;title&quot;:&quot;Offloading to disk&quot;,&quot;local&quot;:&quot;offloading-to-disk&quot;,&quot;sections&quot;:[],&quot;depth&quot;:4}],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Layerwise 
casting&quot;,&quot;local&quot;:&quot;layerwise-casting&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;torch.channels_last&quot;,&quot;local&quot;:&quot;torchchannelslast&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;torch.jit.trace&quot;,&quot;local&quot;:&quot;torchjittrace&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Memory-efficient attention&quot;,&quot;local&quot;:&quot;memory-efficient-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="reduce-memory-usage" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#reduce-memory-usage"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Reduce memory usage</span></h1> <p data-svelte-h="svelte-11rtgo9">Modern diffusion models like <a href="../api/pipelines/flux">Flux</a> and <a href="../api/pipelines/wan">Wan</a> have billions of parameters that take up a lot of memory on your hardware for inference. This is challenging because common GPUs often don’t have sufficient memory. 
To overcome the memory limitations, you can use more than one GPU (if available), offload some of the pipeline components to the CPU, and more.</p> <p data-svelte-h="svelte-vxpcfq">This guide will show you how to reduce your memory usage.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1wzvfzx">Keep in mind these techniques may need to be adjusted depending on the model. For example, a transformer-based diffusion model may not benefit equally from these memory optimizations as a UNet-based model.</p></div> <h2 class="relative group"><a id="multiple-gpus" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#multiple-gpus"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Multiple GPUs</span></h2> <p data-svelte-h="svelte-646d08">If you have access to more than one GPU, there are a few options for efficiently loading and distributing a large model across your hardware. 
These features are supported by the <a href="https://huggingface.co/docs/accelerate/index" rel="nofollow">Accelerate</a> library, so make sure it is installed first.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install -U accelerate<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="sharded-checkpoints" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#sharded-checkpoints"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 
1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Sharded checkpoints</span></h3> <p data-svelte-h="svelte-z6beln">Loading large checkpoints in several shards is useful because the shards are loaded one at a time. This keeps memory usage low, only requiring enough memory for the model size and the largest shard size. We recommend sharding when the fp32 checkpoint is greater than 5GB. The default shard size is 5GB.</p> <p data-svelte-h="svelte-iy1zzy">Shard a checkpoint in <a href="/docs/diffusers/pr_11686/en/api/pipelines/overview#diffusers.DiffusionPipeline.save_pretrained">save_pretrained()</a> with the <code>max_shard_size</code> parameter.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full 
transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel
unet = AutoModel.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>, subfolder=<span class="hljs-string">&quot;unet&quot;</span>
)
unet.save_pretrained(<span class="hljs-string">&quot;sdxl-unet-sharded&quot;</span>, max_shard_size=<span class="hljs-string">&quot;5GB&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1f9wrc7">Now you can use the sharded checkpoint, instead of the regular checkpoint, to save memory.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel, StableDiffusionXLPipeline
unet = AutoModel.from_pretrained(
<span class="hljs-string">&quot;username/sdxl-unet-sharded&quot;</span>, torch_dtype=torch.float16
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>,
unet=unet,
torch_dtype=torch.float16
).to(<span class="hljs-string">&quot;cuda&quot;</span>)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="device-placement" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#device-placement"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Device placement</span></h3> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1tuq9h4">Device placement is an experimental feature and the API may change. Only the <code>balanced</code> strategy is supported at the moment. 
We plan to support additional mapping strategies in the future.</p></div> <p data-svelte-h="svelte-1n4q4np">The <code>device_map</code> parameter controls how the model components in a pipeline or the layers in an individual model are distributed across devices.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">pipeline level </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">model level </div></div> <div class="language-select"><p data-svelte-h="svelte-13intml">The <code>balanced</code> device placement strategy evenly splits the pipeline across all available devices.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: 
transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel, StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>,
torch_dtype=torch.float16,
device_map=<span class="hljs-string">&quot;balanced&quot;</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-cfrzyx">You can inspect a pipeline’s device map with <code>hf_device_map</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">print</span>(pipeline.hf_device_map)
{<span class="hljs-string">&#x27;unet&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;vae&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;safety_checker&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;text_encoder&#x27;</span>: <span class="hljs-number">0</span>}<!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-2a94z6">When designing your own <code>device_map</code>, it should be a dictionary of a model’s specific module name or layer and a device identifier (an integer for GPUs, <code>cpu</code> for CPUs, and <code>disk</code> for disk).</p> <p data-svelte-h="svelte-c7w8iz">Call <code>hf_device_map</code> on a model to see how model layers are distributed and then design your own.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre 
class=""><!-- HTML_TAG_START --><span class="hljs-built_in">print</span>(transformer.hf_device_map)
{<span class="hljs-string">&#x27;pos_embed&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;time_text_embed&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;context_embedder&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;x_embedder&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;transformer_blocks&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.0&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.1&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.2&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.3&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.4&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.5&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.6&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.7&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.8&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.9&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.10&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.11&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.12&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span 
class="hljs-string">&#x27;single_transformer_blocks.13&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.14&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.15&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.16&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.17&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.18&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.19&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.20&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.21&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.22&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.23&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.24&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.25&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.26&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.27&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.28&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span 
class="hljs-string">&#x27;single_transformer_blocks.29&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.30&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.31&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.32&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.33&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.34&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.35&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.36&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.37&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;norm_out&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;proj_out&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ut2jly">For example, the <code>device_map</code> below places <code>single_transformer_blocks.10</code> through <code>single_transformer_blocks.20</code> on a second GPU (<code>1</code>).</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" 
height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel
device_map = {
<span class="hljs-string">&#x27;pos_embed&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;time_text_embed&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;context_embedder&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;x_embedder&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;transformer_blocks&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.0&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.1&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.2&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.3&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.4&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.5&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.6&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.7&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.8&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.9&#x27;</span>: <span class="hljs-number">0</span>, <span class="hljs-string">&#x27;single_transformer_blocks.10&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.11&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.12&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.13&#x27;</span>: <span 
class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.14&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.15&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.16&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.17&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.18&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.19&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.20&#x27;</span>: <span class="hljs-number">1</span>, <span class="hljs-string">&#x27;single_transformer_blocks.21&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.22&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.23&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.24&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.25&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.26&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.27&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.28&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.29&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span 
class="hljs-string">&#x27;single_transformer_blocks.30&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.31&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.32&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.33&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.34&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.35&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.36&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;single_transformer_blocks.37&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;norm_out&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>, <span class="hljs-string">&#x27;proj_out&#x27;</span>: <span class="hljs-string">&#x27;cpu&#x27;</span>
}
transformer = AutoModel.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
device_map=device_map,
torch_dtype=torch.bfloat16
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1oxgib7">Pass a dictionary mapping maximum memory usage to each device to enforce a limit. If a device is not in <code>max_memory</code>, it is ignored and pipeline components won’t be distributed to it.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel, StableDiffusionXLPipeline
max_memory = {<span class="hljs-number">0</span>:<span class="hljs-string">&quot;1GB&quot;</span>, <span class="hljs-number">1</span>:<span class="hljs-string">&quot;1GB&quot;</span>}
pipeline = StableDiffusionXLPipeline.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>,
torch_dtype=torch.float16,
device_map=<span class="hljs-string">&quot;balanced&quot;</span>,
max_memory=max_memory
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-whoqae">Diffusers uses the maximum memory of all devices by default, but if they don’t fit on the GPUs, then you’ll need to use a single GPU and offload to the CPU with the methods below.</p> <ul data-svelte-h="svelte-1l47qbr"><li><a href="/docs/diffusers/pr_11686/en/api/pipelines/overview#diffusers.DiffusionPipeline.enable_model_cpu_offload">enable_model_cpu_offload()</a> only works on a single GPU but a very large model may not fit on it</li> <li><a href="/docs/diffusers/pr_11686/en/api/pipelines/overview#diffusers.DiffusionPipeline.enable_sequential_cpu_offload">enable_sequential_cpu_offload()</a> may work but it is extremely slow and also limited to a single GPU</li></ul> <p data-svelte-h="svelte-1jikyeo">Use the <a href="/docs/diffusers/pr_11686/en/api/pipelines/overview#diffusers.DiffusionPipeline.reset_device_map">reset_device_map()</a> method to reset the <code>device_map</code>. This is necessary if you want to use methods like <code>.to()</code>, <a href="/docs/diffusers/pr_11686/en/api/pipelines/overview#diffusers.DiffusionPipeline.enable_sequential_cpu_offload">enable_sequential_cpu_offload()</a>, and <a href="/docs/diffusers/pr_11686/en/api/pipelines/overview#diffusers.DiffusionPipeline.enable_model_cpu_offload">enable_model_cpu_offload()</a> on a pipeline that was device-mapped.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z"
transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline.reset_device_map()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="vae-slicing" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#vae-slicing"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>VAE slicing</span></h2> <p data-svelte-h="svelte-d36gg7">VAE slicing saves memory by splitting large batches of inputs into a single batch of data and separately processing them. 
This method works best when generating more than one image at a time.</p> <p data-svelte-h="svelte-5ay0zr">For example, if you’re generating 4 images at once, decoding would increase peak activation memory by 4x. VAE slicing reduces this by only decoding 1 image at a time instead of all 4 images at once.</p> <p data-svelte-h="svelte-uguktm">Call <a href="/docs/diffusers/pr_11686/en/api/pipelines/controlnet#diffusers.StableDiffusionControlNetPipeline.enable_vae_slicing">enable_vae_slicing()</a> to enable sliced VAE. You can expect a small increase in performance when decoding multi-image batches and no performance impact for single-image batches.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel, StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>,
torch_dtype=torch.float16,
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
pipeline.enable_vae_slicing()
pipeline([<span class="hljs-string">&quot;An astronaut riding a horse on Mars&quot;</span>]*<span class="hljs-number">32</span>).images[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Max memory reserved: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)<!-- HTML_TAG_END --></pre></div> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-js9naq">The <a href="/docs/diffusers/pr_11686/en/api/models/autoencoder_kl_wan#diffusers.AutoencoderKLWan">AutoencoderKLWan</a> and <a href="/docs/diffusers/pr_11686/en/api/models/asymmetricautoencoderkl#diffusers.AsymmetricAutoencoderKL">AsymmetricAutoencoderKL</a> classes don’t support slicing.</p></div> <h2 class="relative group"><a id="vae-tiling" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#vae-tiling"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>VAE tiling</span></h2> <p data-svelte-h="svelte-vsq5zh">VAE tiling saves 
memory by dividing an image into smaller overlapping tiles instead of processing the entire image at once. This also reduces peak memory usage because the GPU is only processing a tile at a time.</p> <p data-svelte-h="svelte-12g6c34">Call <a href="/docs/diffusers/pr_11686/en/api/pipelines/latent_consistency_models#diffusers.LatentConsistencyModelPipeline.enable_vae_tiling">enable_vae_tiling()</a> to enable VAE tiling. The generated image may have some tone variation from tile-to-tile because they’re decoded separately, but there shouldn’t be any obvious seams between the tiles. Tiling is disabled for resolutions lower than a pre-specified (but configurable) limit. For example, this limit is 512x512 for the VAE in <a href="/docs/diffusers/pr_11686/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline">StableDiffusionPipeline</a>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: 
transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoPipelineForImage2Image
<span class="hljs-keyword">from</span> diffusers.utils <span class="hljs-keyword">import</span> load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>, torch_dtype=torch.float16
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
pipeline.enable_vae_tiling()
init_image = load_image(<span class="hljs-string">&quot;https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png&quot;</span>)
prompt = <span class="hljs-string">&quot;Astronaut in a jungle, cold color palette, muted colors, detailed, 8k&quot;</span>
pipeline(prompt, image=init_image, strength=<span class="hljs-number">0.5</span>).images[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Max memory reserved: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)<!-- HTML_TAG_END --></pre></div> <blockquote data-svelte-h="svelte-1t4v9zp"><p><strong>Warning:</strong> <a href="/docs/diffusers/pr_11686/en/api/models/autoencoder_kl_wan#diffusers.AutoencoderKLWan">AutoencoderKLWan</a> and <a href="/docs/diffusers/pr_11686/en/api/models/asymmetricautoencoderkl#diffusers.AsymmetricAutoencoderKL">AsymmetricAutoencoderKL</a> don’t support tiling.</p></blockquote> <h2 class="relative group"><a id="offloading" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#offloading"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Offloading</span></h2> <p data-svelte-h="svelte-1lghmz8">Offloading strategies move layers or models that are not currently active to the CPU to avoid increasing GPU memory. 
These strategies can be combined with quantization and torch.compile to balance inference speed and memory usage.</p> <p data-svelte-h="svelte-sey8xy">Refer to the <a href="./speed-memory-optims">Compile and offloading quantized models</a> guide for more details.</p> <h3 class="relative group"><a id="cpu-offloading" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#cpu-offloading"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>CPU offloading</span></h3> <p data-svelte-h="svelte-nwk8qi">CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn’t required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.</p> <p data-svelte-h="svelte-1jssif2">CPU offloading dramatically reduces memory usage, but it is also <strong>extremely slow</strong> because submodules are passed back and forth multiple times between devices. 
It can often be impractical due to how slow it is.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-1clymt">Don’t move the pipeline to CUDA before calling <a href="/docs/diffusers/pr_11686/en/api/pipelines/overview#diffusers.DiffusionPipeline.enable_sequential_cpu_offload">enable_sequential_cpu_offload()</a>, otherwise the amount of memory saved is only minimal (refer to this <a href="https://github.com/huggingface/diffusers/issues/1934" rel="nofollow">issue</a> for more details). This is a stateful operation that installs hooks on the model.</p></div> <p data-svelte-h="svelte-7bnuu1">Call <a href="/docs/diffusers/pr_11686/en/api/pipelines/overview#diffusers.DiffusionPipeline.enable_sequential_cpu_offload">enable_sequential_cpu_offload()</a> to enable it on a pipeline.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute 
bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-schnell&quot;</span>, torch_dtype=torch.bfloat16
)
pipeline.enable_sequential_cpu_offload()
pipeline(
prompt=<span class="hljs-string">&quot;An astronaut riding a horse on Mars&quot;</span>,
guidance_scale=<span class="hljs-number">0.</span>,
height=<span class="hljs-number">768</span>,
width=<span class="hljs-number">1360</span>,
num_inference_steps=<span class="hljs-number">4</span>,
max_sequence_length=<span class="hljs-number">256</span>,
).images[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Max memory reserved: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="model-offloading" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#model-offloading"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Model offloading</span></h3> <p data-svelte-h="svelte-lwbzkf">Model offloading moves entire models to the GPU instead of selectively moving <em>some</em> layers or model components. One of the main pipeline models, usually the text encoder, UNet, and VAE, is placed on the GPU while the other components are held on the CPU. Components like the UNet that run multiple times stay on the GPU until they’re completely finished and no longer needed. This eliminates the communication overhead of <a href="#cpu-offloading">CPU offloading</a> and makes model offloading a faster alternative. 
The tradeoff is memory savings won’t be as large.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-fx1eup">Keep in mind that if models are reused outside the pipeline after hooks have been installed (see <a href="https://huggingface.co/docs/accelerate/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module" rel="nofollow">Removing Hooks</a> for more details), you need to run the entire pipeline and models in the expected order to properly offload them. This is a stateful operation that installs hooks on the model.</p></div> <p data-svelte-h="svelte-z8i8mp">Call <a href="/docs/diffusers/pr_11686/en/api/pipelines/overview#diffusers.DiffusionPipeline.enable_model_cpu_offload">enable_model_cpu_offload()</a> to enable it on a pipeline.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 
transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-schnell&quot;</span>, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()
pipeline(
prompt=<span class="hljs-string">&quot;An astronaut riding a horse on Mars&quot;</span>,
guidance_scale=<span class="hljs-number">0.</span>,
height=<span class="hljs-number">768</span>,
width=<span class="hljs-number">1360</span>,
num_inference_steps=<span class="hljs-number">4</span>,
max_sequence_length=<span class="hljs-number">256</span>,
).images[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Max memory reserved: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-17t5qn1"><a href="/docs/diffusers/pr_11686/en/api/pipelines/overview#diffusers.DiffusionPipeline.enable_model_cpu_offload">enable_model_cpu_offload()</a> also helps when you’re using the <a href="/docs/diffusers/pr_11686/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.encode_prompt">encode_prompt()</a> method on its own to generate the text encoder’s hidden state.</p> <h3 class="relative group"><a id="group-offloading" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#group-offloading"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Group offloading</span></h3> <p data-svelte-h="svelte-6x5rs4">Group offloading moves groups of internal layers (<a href="https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html" 
rel="nofollow">torch.nn.ModuleList</a> or <a href="https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html" rel="nofollow">torch.nn.Sequential</a>) to the CPU. It uses less memory than <a href="#model-offloading">model offloading</a> and it is faster than <a href="#cpu-offloading">CPU offloading</a> because it reduces communication overhead.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-78t6ea">Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading’s device casting mechanism.</p></div> <p data-svelte-h="svelte-1q22l88">Call <a href="/docs/diffusers/pr_11686/en/api/models/overview#diffusers.ModelMixin.enable_group_offload">enable_group_offload()</a> to enable it for standard Diffusers model components that inherit from <a href="/docs/diffusers/pr_11686/en/api/models/overview#diffusers.ModelMixin">ModelMixin</a>. For other model components that don’t inherit from <a href="/docs/diffusers/pr_11686/en/api/models/overview#diffusers.ModelMixin">ModelMixin</a>, such as a generic <a href="https://pytorch.org/docs/stable/generated/torch.nn.Module.html" rel="nofollow">torch.nn.Module</a>, use <a href="/docs/diffusers/pr_11686/en/api/utilities#diffusers.hooks.apply_group_offloading">apply_group_offloading()</a> instead.</p> <p data-svelte-h="svelte-1agvgw7">The <code>offload_type</code> parameter can be set to <code>block_level</code> or <code>leaf_level</code>.</p> <ul data-svelte-h="svelte-8dkla6"><li><code>block_level</code> offloads groups of layers based on the <code>num_blocks_per_group</code> parameter. 
For example, if <code>num_blocks_per_group=2</code> on a model with 40 layers, 2 layers are onloaded and offloaded at a time (20 total onloads/offloads). This drastically reduces memory requirements.</li> <li><code>leaf_level</code> offloads individual layers at the lowest level and is equivalent to <a href="#cpu-offloading">CPU offloading</a>. But it can be made faster if you use streams without giving up inference speed.</li></ul> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> CogVideoXPipeline
<span class="hljs-keyword">from</span> diffusers.hooks <span class="hljs-keyword">import</span> apply_group_offloading
<span class="hljs-keyword">from</span> diffusers.utils <span class="hljs-keyword">import</span> export_to_video
onload_device = torch.device(<span class="hljs-string">&quot;cuda&quot;</span>)
offload_device = torch.device(<span class="hljs-string">&quot;cpu&quot;</span>)
pipeline = CogVideoXPipeline.from_pretrained(<span class="hljs-string">&quot;THUDM/CogVideoX-5b&quot;</span>, torch_dtype=torch.bfloat16)
<span class="hljs-comment"># Use the enable_group_offload method for Diffusers model implementations</span>
pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=<span class="hljs-string">&quot;leaf_level&quot;</span>)
pipeline.vae.enable_group_offload(onload_device=onload_device, offload_type=<span class="hljs-string">&quot;leaf_level&quot;</span>)
<span class="hljs-comment"># Use the apply_group_offloading method for other model components</span>
apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type=<span class="hljs-string">&quot;block_level&quot;</span>, num_blocks_per_group=<span class="hljs-number">2</span>)
prompt = (
<span class="hljs-string">&quot;A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. &quot;</span>
<span class="hljs-string">&quot;The panda&#x27;s fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other &quot;</span>
<span class="hljs-string">&quot;pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, &quot;</span>
<span class="hljs-string">&quot;casting a gentle glow on the scene. The panda&#x27;s face is expressive, showing concentration and joy as it plays. &quot;</span>
<span class="hljs-string">&quot;The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical &quot;</span>
<span class="hljs-string">&quot;atmosphere of this unique musical performance.&quot;</span>
)
video = pipeline(prompt=prompt, guidance_scale=<span class="hljs-number">6</span>, num_inference_steps=<span class="hljs-number">50</span>).frames[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Max memory allocated: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)
export_to_video(video, <span class="hljs-string">&quot;output.mp4&quot;</span>, fps=<span class="hljs-number">8</span>)<!-- HTML_TAG_END --></pre></div> <h4 class="relative group"><a id="cuda-stream" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#cuda-stream"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>CUDA stream</span></h4> <p data-svelte-h="svelte-1ik3m1u">The <code>use_stream</code> parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to <a href="#cpu-offloading">CPU offloading</a>. It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.</p> <p data-svelte-h="svelte-f4wc5s">Set <code>record_stream=True</code> for more of a speedup at the cost of slightly increased memory usage. 
Refer to the <a href="https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html" rel="nofollow">torch.Tensor.record_stream</a> docs to learn more.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1j8tvtw">When <code>use_stream=True</code> on VAEs with tiling enabled, make sure to do a dummy forward pass (possible with dummy inputs as well) before inference to avoid device mismatch errors. This may not work on all implementations, so feel free to open an issue if you encounter any problems.</p></div> <p data-svelte-h="svelte-17oamma">If you’re using <code>block_level</code> group offloading with <code>use_stream</code> enabled, the <code>num_blocks_per_group</code> parameter should be set to <code>1</code>, otherwise a warning will be raised.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform 
-translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=<span class="hljs-string">&quot;leaf_level&quot;</span>, use_stream=<span class="hljs-literal">True</span>, record_stream=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ixqgvf">The <code>low_cpu_mem_usage</code> parameter can be set to <code>True</code> to reduce CPU memory usage when using streams during group offloading. It is best for <code>leaf_level</code> offloading and when CPU memory is bottlenecked. Memory is saved by creating pinned tensors on the fly instead of pre-pinning them. However, this may increase overall execution time.</p> <h4 class="relative group"><a id="offloading-to-disk" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#offloading-to-disk"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Offloading to disk</span></h4> <p 
data-svelte-h="svelte-mq1z2c">Group offloading can consume significant system memory depending on the model size. On systems with limited memory, try group offloading onto the disk as a secondary memory.</p> <p data-svelte-h="svelte-1dgn5po">Set the <code>offload_to_disk_path</code> argument in either <a href="/docs/diffusers/pr_11686/en/api/models/overview#diffusers.ModelMixin.enable_group_offload">enable_group_offload()</a> or <a href="/docs/diffusers/pr_11686/en/api/utilities#diffusers.hooks.apply_group_offloading">apply_group_offloading()</a> to offload the model to the disk.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=<span class="hljs-string">&quot;leaf_level&quot;</span>, offload_to_disk_path=<span 
class="hljs-string">&quot;path/to/disk&quot;</span>)
apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type=<span class="hljs-string">&quot;block_level&quot;</span>, num_blocks_per_group=<span class="hljs-number">2</span>, offload_to_disk_path=<span class="hljs-string">&quot;path/to/disk&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ugmkgt">Refer to these <a href="https://github.com/huggingface/diffusers/pull/11682#issue-3129365363" rel="nofollow">two</a> <a href="https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126" rel="nofollow">tables</a> to compare the speed and memory trade-offs.</p> <h2 class="relative group"><a id="layerwise-casting" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#layerwise-casting"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Layerwise casting</span></h2> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1u3ymti">Combine layerwise casting with <a 
href="#group-offloading">group offloading</a> for even more memory savings.</p></div> <p data-svelte-h="svelte-1boiobn">Layerwise casting stores weights in a smaller data format (for example, <code>torch.float8_e4m3fn</code> and <code>torch.float8_e5m2</code>) to use less memory and upcasts those weights to a higher precision like <code>torch.float16</code> or <code>torch.bfloat16</code> for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-nfpqdo">Layerwise casting may not work with all models if the forward implementation contains internal typecasting of weights. The current implementation of layerwise casting assumes the forward pass is independent of the weight precision and the input datatypes are always specified in <code>compute_dtype</code> (see <a href="https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299" rel="nofollow">here</a> for an incompatible implementation).</p> <p data-svelte-h="svelte-5074v9">Layerwise casting may also fail on custom modeling implementations with <a href="https://huggingface.co/docs/peft/index" rel="nofollow">PEFT</a> layers. 
There are some checks available but they are not extensively tested or guaranteed to work in all cases.</p></div> <p data-svelte-h="svelte-1p9c692">Call <a href="/docs/diffusers/pr_11686/en/api/models/overview#diffusers.ModelMixin.enable_layerwise_casting">enable_layerwise_casting()</a> to set the storage and computation datatypes.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> CogVideoXPipeline, CogVideoXTransformer3DModel
<span class="hljs-keyword">from</span> diffusers.utils <span class="hljs-keyword">import</span> export_to_video
transformer = CogVideoXTransformer3DModel.from_pretrained(
<span class="hljs-string">&quot;THUDM/CogVideoX-5b&quot;</span>,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
torch_dtype=torch.bfloat16
)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipeline = CogVideoXPipeline.from_pretrained(<span class="hljs-string">&quot;THUDM/CogVideoX-5b&quot;</span>,
transformer=transformer,
torch_dtype=torch.bfloat16
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
prompt = (
<span class="hljs-string">&quot;A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. &quot;</span>
<span class="hljs-string">&quot;The panda&#x27;s fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other &quot;</span>
<span class="hljs-string">&quot;pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, &quot;</span>
<span class="hljs-string">&quot;casting a gentle glow on the scene. The panda&#x27;s face is expressive, showing concentration and joy as it plays. &quot;</span>
<span class="hljs-string">&quot;The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical &quot;</span>
<span class="hljs-string">&quot;atmosphere of this unique musical performance.&quot;</span>
)
video = pipeline(prompt=prompt, guidance_scale=<span class="hljs-number">6</span>, num_inference_steps=<span class="hljs-number">50</span>).frames[<span class="hljs-number">0</span>]
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Max memory allocated: <span class="hljs-subst">{torch.cuda.max_memory_allocated() / <span class="hljs-number">1024</span>**<span class="hljs-number">3</span>:<span class="hljs-number">.2</span>f}</span> GB&quot;</span>)
export_to_video(video, <span class="hljs-string">&quot;output.mp4&quot;</span>, fps=<span class="hljs-number">8</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-g24bcl">The <a href="/docs/diffusers/pr_11686/en/api/utilities#diffusers.hooks.apply_layerwise_casting">apply_layerwise_casting()</a> method can also be used if you need more control and flexibility. It can be partially applied to model layers by calling it on specific internal modules. Use the <code>skip_modules_pattern</code> or <code>skip_modules_classes</code> parameters to specify modules to avoid, such as the normalization and modulation layers.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> CogVideoXTransformer3DModel
<span class="hljs-keyword">from</span> diffusers.hooks <span class="hljs-keyword">import</span> apply_layerwise_casting
transformer = CogVideoXTransformer3DModel.from_pretrained(
<span class="hljs-string">&quot;THUDM/CogVideoX-5b&quot;</span>,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
torch_dtype=torch.bfloat16
)
<span class="hljs-comment"># skip the normalization layer</span>
apply_layerwise_casting(
transformer,
storage_dtype=torch.float8_e4m3fn,
compute_dtype=torch.bfloat16,
skip_modules_pattern=[<span class="hljs-string">&quot;norm&quot;</span>],
non_blocking=<span class="hljs-literal">True</span>,
)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="torchchannelslast" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchchannelslast"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>torch.channels_last</span></h2> <p data-svelte-h="svelte-1tu261v"><a href="https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html" rel="nofollow">torch.channels_last</a> flips how tensors are stored from <code>(batch size, channels, height, width)</code> to <code>(batch size, heigh, width, channels)</code>. 
This aligns the tensors with how the hardware sequentially accesses the tensors stored in memory and avoids skipping around in memory to access the pixel values.</p> <p data-svelte-h="svelte-dfwok5">Not all operators currently support the channels-last format and may result in worst performance, but it is still worth trying.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">print</span>(pipeline.unet.conv_out.state_dict()[<span class="hljs-string">&quot;weight&quot;</span>].stride()) <span class="hljs-comment"># (2880, 9, 3, 1)</span>
pipeline.unet.to(memory_format=torch.channels_last) <span class="hljs-comment"># in-place operation</span>
<span class="hljs-built_in">print</span>(
pipeline.unet.conv_out.state_dict()[<span class="hljs-string">&quot;weight&quot;</span>].stride()
) <span class="hljs-comment"># (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="torchjittrace" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#torchjittrace"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>torch.jit.trace</span></h2> <p data-svelte-h="svelte-15jhfav"><a href="https://pytorch.org/docs/stable/generated/torch.jit.trace.html" rel="nofollow">torch.jit.trace</a> records the operations a model performs on a sample input and creates a new, optimized representation of the model based on the recorded execution path. During tracing, the model is optimized to reduce overhead from Python and dynamic control flows and operations are fused together for more efficiency. 
The returned executable or <a href="https://pytorch.org/docs/stable/generated/torch.jit.ScriptFunction.html" rel="nofollow">ScriptFunction</a> can be compiled.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> time
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionPipeline
<span class="hljs-keyword">import</span> functools
<span class="hljs-comment"># torch disable grad</span>
torch.set_grad_enabled(<span class="hljs-literal">False</span>)
<span class="hljs-comment"># set variables</span>
n_experiments = <span class="hljs-number">2</span>
unet_runs_per_experiment = <span class="hljs-number">50</span>
<span class="hljs-comment"># load sample inputs</span>
<span class="hljs-keyword">def</span> <span class="hljs-title function_">generate_inputs</span>():
sample = torch.randn((<span class="hljs-number">2</span>, <span class="hljs-number">4</span>, <span class="hljs-number">64</span>, <span class="hljs-number">64</span>), device=<span class="hljs-string">&quot;cuda&quot;</span>, dtype=torch.float16)
timestep = torch.rand(<span class="hljs-number">1</span>, device=<span class="hljs-string">&quot;cuda&quot;</span>, dtype=torch.float16) * <span class="hljs-number">999</span>
encoder_hidden_states = torch.randn((<span class="hljs-number">2</span>, <span class="hljs-number">77</span>, <span class="hljs-number">768</span>), device=<span class="hljs-string">&quot;cuda&quot;</span>, dtype=torch.float16)
<span class="hljs-keyword">return</span> sample, timestep, encoder_hidden_states
pipeline = StableDiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;stable-diffusion-v1-5/stable-diffusion-v1-5&quot;</span>,
torch_dtype=torch.float16,
use_safetensors=<span class="hljs-literal">True</span>,
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
unet = pipeline.unet
unet.<span class="hljs-built_in">eval</span>()
unet.to(memory_format=torch.channels_last) <span class="hljs-comment"># use channels_last memory format</span>
unet.forward = functools.partial(unet.forward, return_dict=<span class="hljs-literal">False</span>) <span class="hljs-comment"># set return_dict=False as default</span>
<span class="hljs-comment"># warmup</span>
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">3</span>):
<span class="hljs-keyword">with</span> torch.inference_mode():
inputs = generate_inputs()
orig_output = unet(*inputs)
<span class="hljs-comment"># trace</span>
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;tracing..&quot;</span>)
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.<span class="hljs-built_in">eval</span>()
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;done tracing&quot;</span>)
<span class="hljs-comment"># warmup and optimize graph</span>
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">5</span>):
<span class="hljs-keyword">with</span> torch.inference_mode():
inputs = generate_inputs()
orig_output = unet_traced(*inputs)
<span class="hljs-comment"># benchmarking</span>
<span class="hljs-keyword">with</span> torch.inference_mode():
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(unet_runs_per_experiment):
orig_output = unet_traced(*inputs)
torch.cuda.synchronize()
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;unet traced inference took <span class="hljs-subst">{time.time() - start_time:<span class="hljs-number">.2</span>f}</span> seconds&quot;</span>)
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
<span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(unet_runs_per_experiment):
orig_output = unet(*inputs)
torch.cuda.synchronize()
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;unet inference took <span class="hljs-subst">{time.time() - start_time:<span class="hljs-number">.2</span>f}</span> seconds&quot;</span>)
<span class="hljs-comment"># save the model</span>
unet_traced.save(<span class="hljs-string">&quot;unet_traced.pt&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-16t9h37">Replace the pipeline’s UNet with the traced version.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionPipeline
<span class="hljs-keyword">from</span> dataclasses <span class="hljs-keyword">import</span> dataclass
<span class="hljs-meta">@dataclass</span>
<span class="hljs-keyword">class</span> <span class="hljs-title class_">UNet2DConditionOutput</span>:
sample: torch.Tensor
pipeline = StableDiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;stable-diffusion-v1-5/stable-diffusion-v1-5&quot;</span>,
torch_dtype=torch.float16,
use_safetensors=<span class="hljs-literal">True</span>,
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
<span class="hljs-comment"># use jitted unet</span>
unet_traced = torch.jit.load(<span class="hljs-string">&quot;unet_traced.pt&quot;</span>)
<span class="hljs-comment"># del pipeline.unet</span>
<span class="hljs-keyword">class</span> <span class="hljs-title class_">TracedUNet</span>(torch.nn.Module):
<span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self</span>):
<span class="hljs-built_in">super</span>().__init__()
self.in_channels = pipeline.unet.config.in_channels
self.device = pipeline.unet.device
<span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, latent_model_input, t, encoder_hidden_states</span>):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[<span class="hljs-number">0</span>]
<span class="hljs-keyword">return</span> UNet2DConditionOutput(sample=sample)
pipeline.unet = TracedUNet()
<span class="hljs-keyword">with</span> torch.inference_mode():
image = pipeline([prompt] * <span class="hljs-number">1</span>, num_inference_steps=<span class="hljs-number">50</span>).images[<span class="hljs-number">0</span>]<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="memory-efficient-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#memory-efficient-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Memory-efficient attention</span></h2> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-fs3wzo">Memory-efficient attention optimizes for memory usage <em>and</em> <a href="./fp16#scaled-dot-product-attention">inference speed</a>!</p></div> <p data-svelte-h="svelte-8q1sue">The Transformers attention mechanism is memory-intensive, especially for long sequences, so you can try using different and more memory-efficient attention types.</p> <p data-svelte-h="svelte-1idpc42">By default, if PyTorch &gt;= 2.0 is installed, <a 
href="https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" rel="nofollow">scaled dot-product attention (SDPA)</a> is used. You don’t need to make any additional changes to your code.</p> <p data-svelte-h="svelte-4p44md">SDPA supports <a href="https://github.com/Dao-AILab/flash-attention" rel="nofollow">FlashAttention</a> and <a href="https://github.com/facebookresearch/xformers" rel="nofollow">xFormers</a> as well as a native C++ PyTorch implementation. It automatically selects the most optimal implementation based on your input.</p> <p data-svelte-h="svelte-1e5e031">You can explicitly use xFormers with the <a href="/docs/diffusers/pr_11686/en/api/models/overview#diffusers.ModelMixin.enable_xformers_memory_efficient_attention">enable_xformers_memory_efficient_attention()</a> method.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># pip install xformers</span>
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
<span class="hljs-string">&quot;stabilityai/stable-diffusion-xl-base-1.0&quot;</span>,
torch_dtype=torch.float16,
).to(<span class="hljs-string">&quot;cuda&quot;</span>)
pipeline.enable_xformers_memory_efficient_attention()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-ipqgja">Call <a href="/docs/diffusers/pr_11686/en/api/models/overview#diffusers.ModelMixin.disable_xformers_memory_efficient_attention">disable_xformers_memory_efficient_attention()</a> to disable it.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline.disable_xformers_memory_efficient_attention()<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/diffusers/blob/main/docs/source/en/optimization/memory.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on 
GitHub</span></a> <p></p>
<script>
{
__sveltekit_o82a48 = {
assets: "/docs/diffusers/pr_11686/en",
base: "/docs/diffusers/pr_11686/en",
env: {}
};
const element = document.currentScript.parentElement;
const data = [null,null];
Promise.all([
import("/docs/diffusers/pr_11686/en/_app/immutable/entry/start.2b9667fb.js"),
import("/docs/diffusers/pr_11686/en/_app/immutable/entry/app.a2a6117e.js")
]).then(([kit, app]) => {
kit.start(app, element, {
node_ids: [0, 236],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
109 kB
·
Xet hash:
b66ee61f45564950322bcfe57b8120339293767c3192937dc5ee3af8cf861ea7

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.