Buckets:

hf-doc-build/doc-dev / diffusers /pr_12249 /en /training /distributed_inference.html
rtrm's picture
download
raw
62.6 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Distributed inference&quot;,&quot;local&quot;:&quot;distributed-inference&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Accelerate&quot;,&quot;local&quot;:&quot;accelerate&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;PyTorch Distributed&quot;,&quot;local&quot;:&quot;pytorch-distributed&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;device_map&quot;,&quot;local&quot;:&quot;devicemap&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Context parallelism&quot;,&quot;local&quot;:&quot;context-parallelism&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Ring Attention&quot;,&quot;local&quot;:&quot;ring-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Ulysses Attention&quot;,&quot;local&quot;:&quot;ulysses-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Unified Attention&quot;,&quot;local&quot;:&quot;unified-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;parallel_config&quot;,&quot;local&quot;:&quot;parallelconfig&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/diffusers/pr_12249/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/entry/start.d78ccb7e.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/scheduler.53228c21.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/singletons.7109d99c.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/index.e93d0901.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/paths.85ae17be.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/entry/app.cd68c46f.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/preload-helper.20fd9bc3.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/index.100fac89.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/nodes/0.78828106.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/nodes/308.27b51b91.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/CopyLLMTxtMenu.133e28e0.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/globals.7f7f1b26.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/IconCopy.38cf8f56.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.d8195636.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12249/en/_app/immutable/chunks/CodeBlock.d30a6509.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Distributed inference&quot;,&quot;local&quot;:&quot;distributed-inference&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Accelerate&quot;,&quot;local&quot;:&quot;accelerate&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;PyTorch Distributed&quot;,&quot;local&quot;:&quot;pytorch-distributed&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;device_map&quot;,&quot;local&quot;:&quot;devicemap&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Context parallelism&quot;,&quot;local&quot;:&quot;context-parallelism&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Ring Attention&quot;,&quot;local&quot;:&quot;ring-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Ulysses Attention&quot;,&quot;local&quot;:&quot;ulysses-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Unified Attention&quot;,&quot;local&quot;:&quot;unified-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;parallel_config&quot;,&quot;local&quot;:&quot;parallelconfig&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 
dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="distributed-inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#distributed-inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 
84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Distributed inference</span></h1> <p data-svelte-h="svelte-eljq9f">Distributed inference splits the workload across multiple GPUs. It is a useful technique for fitting larger models in memory and can process multiple prompts for higher throughput.</p> <p data-svelte-h="svelte-og4l62">This guide will show you how to use <a href="https://huggingface.co/docs/accelerate/index" rel="nofollow">Accelerate</a> and <a href="https://pytorch.org/tutorials/beginner/dist_overview.html" rel="nofollow">PyTorch Distributed</a> for distributed inference.</p> <h2 class="relative group"><a id="accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> 
<span>Accelerate</span></h2> <p data-svelte-h="svelte-nz1tds">Accelerate is a library designed to simplify inference and training on multiple accelerators by handling the setup, allowing users to focus on their PyTorch code.</p> <p data-svelte-h="svelte-1st21a7">Install Accelerate with the following command.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->uv pip install accelerate<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-s7f0yp">Initialize a <a href="https://huggingface.co/docs/accelerate/main/en/package_reference/state#accelerate.PartialState" rel="nofollow">accelerate.PartialState</a> class in a Python file to create a distributed environment. 
The <a href="https://huggingface.co/docs/accelerate/main/en/package_reference/state#accelerate.PartialState" rel="nofollow">accelerate.PartialState</a> class manages process management, device control and distribution, and process coordination.</p> <p data-svelte-h="svelte-16d1xin">Move the <a href="/docs/diffusers/pr_12249/en/api/pipelines/overview#diffusers.DiffusionPipeline">DiffusionPipeline</a> to <code>accelerate.PartialState.device</code> to assign a GPU to each process.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> PartialState
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;Qwen/Qwen-Image&quot;</span>, torch_dtype=torch.float16
)
distributed_state = PartialState()
pipeline.to(distributed_state.device)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-dxnfoj">Use the <a href="https://huggingface.co/docs/accelerate/main/en/package_reference/state#accelerate.PartialState.split_between_processes" rel="nofollow">split_between_processes</a> utility as a context manager to automatically distribute the prompts between the number of processes.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">with</span> distributed_state.split_between_processes([<span class="hljs-string">&quot;a dog&quot;</span>, <span class="hljs-string">&quot;a cat&quot;</span>]) <span class="hljs-keyword">as</span> prompt:
    result = pipeline(prompt).images[<span class="hljs-number">0</span>]
    result.save(<span class="hljs-string">f&quot;result_<span class="hljs-subst">{distributed_state.process_index}</span>.png&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xu0wll">Call <code>accelerate launch</code> to run the script and use the <code>--num_processes</code> argument to set the number of GPUs to use.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch --num_processes=2 run_distributed.py<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-djm11n"><p>Refer to this minimal example <a href="https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba" rel="nofollow">script</a> for running inference across multiple GPUs. 
To learn more, take a look at the <a href="https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate" rel="nofollow">Distributed Inference with 🤗 Accelerate</a> guide.</p></blockquote> <h2 class="relative group"><a id="pytorch-distributed" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#pytorch-distributed"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>PyTorch Distributed</span></h2> <p data-svelte-h="svelte-nym58p">PyTorch <a href="https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html" rel="nofollow">DistributedDataParallel</a> enables <a href="https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=data_parallelism" rel="nofollow">data parallelism</a>, which replicates the same model on each device, to process different batches of data in parallel.</p> <p data-svelte-h="svelte-1q3d069">Import <code>torch.distributed</code> and <code>torch.multiprocessing</code> into a Python file to set up the distributed process group and to spawn the processes for inference on each GPU.</p> <div class="code-block relative "><div 
class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> torch.distributed <span class="hljs-keyword">as</span> dist
<span class="hljs-keyword">import</span> torch.multiprocessing <span class="hljs-keyword">as</span> mp
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;Qwen/Qwen-Image&quot;</span>, torch_dtype=torch.float16,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1vva4uo">Create a function for inference with <a href="https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group" rel="nofollow">init_process_group</a>. This method creates a distributed environment with the backend type, the <code>rank</code> of the current process, and the <code>world_size</code> or number of processes participating (for example, 2 GPUs would be <code>world_size=2</code>).</p> <p data-svelte-h="svelte-1329uwh">Move the pipeline to <code>rank</code> and use <code>get_rank</code> to assign a GPU to each process. Each process handles a different prompt.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title 
function_">run_inference</span>(<span class="hljs-params">rank, world_size</span>):
    dist.init_process_group(<span class="hljs-string">&quot;nccl&quot;</span>, rank=rank, world_size=world_size)
    pipeline.to(rank)
    <span class="hljs-keyword">if</span> torch.distributed.get_rank() == <span class="hljs-number">0</span>:
        prompt = <span class="hljs-string">&quot;a dog&quot;</span>
    <span class="hljs-keyword">elif</span> torch.distributed.get_rank() == <span class="hljs-number">1</span>:
        prompt = <span class="hljs-string">&quot;a cat&quot;</span>
    image = pipeline(prompt).images[<span class="hljs-number">0</span>]
    image.save(<span class="hljs-string">f&quot;./<span class="hljs-subst">{prompt.replace(<span class="hljs-string">&#x27; &#x27;</span>, <span class="hljs-string">&#x27;_&#x27;</span>)}</span>.png&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-84sp40">Use <a href="https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn" rel="nofollow">mp.spawn</a> to create the number of processes defined in <code>world_size</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">main</span>():
    world_size = <span class="hljs-number">2</span>
    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=<span class="hljs-literal">True</span>)
<span class="hljs-keyword">if</span> __name__ == <span class="hljs-string">&quot;__main__&quot;</span>:
    main()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-17g6264">Call <code>torchrun</code> to run the inference script and use the <code>--nproc_per_node</code> argument to set the number of GPUs to use.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->torchrun --nproc_per_node=2 run_distributed.py<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="devicemap" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#devicemap"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 
11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>device_map</span></h2> <p data-svelte-h="svelte-19ricyq">The <code>device_map</code> argument enables distributed inference by automatically placing model components on separate GPUs. This is especially useful when a model doesn’t fit on a single GPU. You can use <code>device_map</code> to selectively load and unload the required model components at a given stage as shown in the example below (assumes two GPUs are available).</p> <p data-svelte-h="svelte-uepizv">Set <code>device_map=&quot;balanced&quot;</code> to evenly distribute the text encoders on all available GPUs. You can use the <code>max_memory</code> argument to allocate a maximum amount of memory for each text encoder. 
Don’t load any other pipeline components to avoid memory usage.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> FluxPipeline
<span class="hljs-keyword">import</span> torch
prompt = <span class="hljs-string">&quot;&quot;&quot;
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
&quot;&quot;&quot;</span>
pipeline = FluxPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
transformer=<span class="hljs-literal">None</span>,
vae=<span class="hljs-literal">None</span>,
device_map=<span class="hljs-string">&quot;balanced&quot;</span>,
max_memory={<span class="hljs-number">0</span>: <span class="hljs-string">&quot;16GB&quot;</span>, <span class="hljs-number">1</span>: <span class="hljs-string">&quot;16GB&quot;</span>},
torch_dtype=torch.bfloat16
)
<span class="hljs-keyword">with</span> torch.no_grad():
    <span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Encoding prompts.&quot;</span>)
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=<span class="hljs-literal">None</span>, max_sequence_length=<span class="hljs-number">512</span>
    )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-m088j3">After the text embeddings are computed, remove them from the GPU to make space for the diffusion transformer.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> gc
<span class="hljs-keyword">def</span> <span class="hljs-title function_">flush</span>():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
<span class="hljs-keyword">del</span> pipeline.text_encoder
<span class="hljs-keyword">del</span> pipeline.text_encoder_2
<span class="hljs-keyword">del</span> pipeline.tokenizer
<span class="hljs-keyword">del</span> pipeline.tokenizer_2
<span class="hljs-keyword">del</span> pipeline
flush()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18jqegg">Set <code>device_map=&quot;auto&quot;</code> to automatically distribute the model on the two GPUs. This strategy places a model on the fastest device first before placing a model on a slower device like a CPU or hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel
<span class="hljs-keyword">import</span> torch
transformer = AutoModel.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
device_map=<span class="hljs-string">&quot;auto&quot;</span>,
torch_dtype=torch.bfloat16
)<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-eiyj9h"><p>Run <code>pipeline.hf_device_map</code> to see how the various models are distributed across devices. This is useful for tracking model device placement. You can also call <code>hf_device_map</code> on the transformer model to see how it is distributed.</p></blockquote> <p data-svelte-h="svelte-128ts8i">Add the transformer model to the pipeline and set the <code>output_type=&quot;latent&quot;</code> to generate the latents.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline = FluxPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
text_encoder=<span class="hljs-literal">None</span>,
text_encoder_2=<span class="hljs-literal">None</span>,
tokenizer=<span class="hljs-literal">None</span>,
tokenizer_2=<span class="hljs-literal">None</span>,
vae=<span class="hljs-literal">None</span>,
transformer=transformer,
torch_dtype=torch.bfloat16
)
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Running denoising.&quot;</span>)
height, width = <span class="hljs-number">768</span>, <span class="hljs-number">1360</span>
latents = pipeline(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
num_inference_steps=<span class="hljs-number">50</span>,
guidance_scale=<span class="hljs-number">3.5</span>,
height=height,
width=width,
output_type=<span class="hljs-string">&quot;latent&quot;</span>,
).images<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-jp9jpz">Remove the pipeline and transformer from memory and load a VAE to decode the latents. The VAE is typically small enough to be loaded on a single device.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoencoderKL
<span class="hljs-keyword">from</span> diffusers.image_processor <span class="hljs-keyword">import</span> VaeImageProcessor
vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder=<span class="hljs-string">&quot;vae&quot;</span>, torch_dtype=torch.bfloat16).to(<span class="hljs-string">&quot;cuda&quot;</span>)
vae_scale_factor = <span class="hljs-number">2</span> ** (<span class="hljs-built_in">len</span>(vae.config.block_out_channels) - <span class="hljs-number">1</span>)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
<span class="hljs-keyword">with</span> torch.no_grad():
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Running decoding.&quot;</span>)
latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
image = vae.decode(latents, return_dict=<span class="hljs-literal">False</span>)[<span class="hljs-number">0</span>]
image = image_processor.postprocess(image, output_type=<span class="hljs-string">&quot;pil&quot;</span>)
image[<span class="hljs-number">0</span>].save(<span class="hljs-string">&quot;split_transformer.png&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-3tqui3">By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.</p> <h2 class="relative group"><a id="context-parallelism" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#context-parallelism"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Context parallelism</span></h2> <p data-svelte-h="svelte-1bsyx0k"><a href="https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism" rel="nofollow">Context parallelism</a> splits input sequences across multiple GPUs to reduce memory usage. Each GPU processes its own slice of the sequence.</p> <p data-svelte-h="svelte-15gx6qx">Use <a href="/docs/diffusers/pr_12249/en/api/models/overview#diffusers.ModelMixin.set_attention_backend">set_attention_backend()</a> to switch to a more optimized attention backend. 
Refer to this <a href="../optimization/attention_backends#available-backends">table</a> for a complete list of available backends.</p> <p data-svelte-h="svelte-1ppoxcn">Most attention backends are compatible with context parallelism. Open an <a href="https://github.com/huggingface/diffusers/issues/new" rel="nofollow">issue</a> if a backend is not compatible.</p> <h3 class="relative group"><a id="ring-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ring-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Ring Attention</span></h3> <p data-svelte-h="svelte-wnvqc8">Key (K) and value (V) representations communicate between devices using <a href="https://huggingface.co/papers/2310.01889" rel="nofollow">Ring Attention</a>. This ensures each split sees every other token’s K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. 
No single GPU holds the full sequence, which reduces communication latency.</p> <p data-svelte-h="svelte-1fbao21">Pass a <a href="/docs/diffusers/pr_12249/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> to the <code>parallel_config</code> argument of the transformer model. The config supports the <code>ring_degree</code> argument that determines how many devices to use for Ring Attention.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> torch <span class="hljs-keyword">import</span> distributed <span class="hljs-keyword">as</span> dist
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline, ContextParallelConfig
<span class="hljs-keyword">def</span> <span class="hljs-title function_">setup_distributed</span>():
<span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> dist.is_initialized():
dist.init_process_group(backend=<span class="hljs-string">&quot;nccl&quot;</span>)
rank = dist.get_rank()
device = torch.device(<span class="hljs-string">f&quot;cuda:<span class="hljs-subst">{rank}</span>&quot;</span>)
torch.cuda.set_device(device)
<span class="hljs-keyword">return</span> device
<span class="hljs-keyword">def</span> <span class="hljs-title function_">main</span>():
device = setup_distributed()
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>, torch_dtype=torch.bfloat16
).to(device)
pipeline.transformer.set_attention_backend(<span class="hljs-string">&quot;_native_cudnn&quot;</span>)
cp_config = ContextParallelConfig(ring_degree=world_size)
pipeline.transformer.enable_parallelism(config=cp_config)
prompt = <span class="hljs-string">&quot;&quot;&quot;
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
&quot;&quot;&quot;</span>
<span class="hljs-comment"># Must specify generator so all ranks start with same latents (or pass your own)</span>
generator = torch.Generator().manual_seed(<span class="hljs-number">42</span>)
image = pipeline(
prompt,
guidance_scale=<span class="hljs-number">3.5</span>,
num_inference_steps=<span class="hljs-number">50</span>,
generator=generator,
).images[<span class="hljs-number">0</span>]
<span class="hljs-keyword">if</span> dist.get_rank() == <span class="hljs-number">0</span>:
image.save(<span class="hljs-string">f&quot;output.png&quot;</span>)
<span class="hljs-keyword">if</span> dist.is_initialized():
dist.destroy_process_group()
<span class="hljs-keyword">if</span> __name__ == <span class="hljs-string">&quot;__main__&quot;</span>:
main()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6zwmyc">The script above needs to be run with a distributed launcher, such as <a href="https://docs.pytorch.org/docs/stable/elastic/run.html" rel="nofollow">torchrun</a>, that is compatible with PyTorch. <code>--nproc-per-node</code> is set to the number of GPUs available.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->torchrun --nproc-per-node 2 above_script.py<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="ulysses-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ulysses-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" 
role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Ulysses Attention</span></h3> <p data-svelte-h="svelte-1b28cdv"><a href="https://huggingface.co/papers/2309.14509" rel="nofollow">Ulysses Attention</a> splits a sequence across GPUs and performs an <em>all-to-all</em> communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its subset of heads, then performs another <em>all-to-all</em> to regroup results by tokens for the next layer.</p> <p data-svelte-h="svelte-16c9xr2"><a href="/docs/diffusers/pr_12249/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> supports Ulysses Attention through the <code>ulysses_degree</code> argument. 
This determines how many devices to use for Ulysses Attention.</p> <p data-svelte-h="svelte-e2saj7">Pass the <a href="/docs/diffusers/pr_12249/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> to <code>enable_parallelism()</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Depending on the number of GPUs available.</span>
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=<span class="hljs-number">2</span>))<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="unified-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#unified-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Unified Attention</span></h3> <p data-svelte-h="svelte-12jl1gc"><a href="https://huggingface.co/papers/2405.07719" rel="nofollow">Unified Sequence Parallelism</a> combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. 
It applies Ulysses’s <em>all-to-all</em> communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the <em>all-to-all</em> to restore the original layout.</p> <p data-svelte-h="svelte-76urs9">This hybrid approach leverages the strengths of both methods:</p> <ul data-svelte-h="svelte-1ih6klk"><li><strong>Ulysses Attention</strong> efficiently parallelizes across attention heads</li> <li><strong>Ring Attention</strong> handles very long sequences with minimal memory overhead</li> <li>Together, they enable 2D parallelization across both heads and sequence dimensions</li></ul> <p data-svelte-h="svelte-14swe5p"><a href="/docs/diffusers/pr_12249/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> supports Unified Attention by specifying both <code>ulysses_degree</code> and <code>ring_degree</code>. The total number of devices used is <code>ulysses_degree * ring_degree</code>, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping).
Pass the <a href="/docs/diffusers/pr_12249/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> with both <code>ulysses_degree</code> and <code>ring_degree</code> set to a value greater than 1 to <code>enable_parallelism()</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=<span class="hljs-number">2</span>, ring_degree=<span class="hljs-number">2</span>))<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-1rulofq"><p>Unified Attention is intended for setups where there are enough devices to arrange in a 2D grid (at least 4 devices).</p></blockquote> <p data-svelte-h="svelte-x9jokt">We ran a benchmark with Ulysses, Ring, and Unified Attention with <a 
href="https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532" rel="nofollow">this script</a> on a node of 4 H100 GPUs. The results are summarized as follows:</p> <table data-svelte-h="svelte-1u212r8"><thead><tr><th>CP Backend</th> <th>Time / Iter (ms)</th> <th>Steps / Sec</th> <th>Peak Memory (GB)</th></tr></thead> <tbody><tr><td>ulysses</td> <td>6670.789</td> <td>7.50</td> <td>33.85</td></tr> <tr><td>ring</td> <td>13076.492</td> <td>3.82</td> <td>56.02</td></tr> <tr><td>unified_balanced</td> <td>11068.705</td> <td>4.52</td> <td>33.85</td></tr></tbody></table> <p data-svelte-h="svelte-58w498">From the above table, it’s clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.</p> <h3 class="relative group"><a id="parallelconfig" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#parallelconfig"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>parallel_config</span></h3> <p data-svelte-h="svelte-1abk5kt">Pass <code>parallel_config</code> during model initialization to enable context 
parallelism.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->CKPT_ID = <span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>
cp_config = ContextParallelConfig(ring_degree=<span class="hljs-number">2</span>)
transformer = AutoModel.from_pretrained(
CKPT_ID,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
torch_dtype=torch.bfloat16,
parallel_config=cp_config
)
pipeline = DiffusionPipeline.from_pretrained(
CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/diffusers/blob/main/docs/source/en/training/distributed_inference.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
// Page bootstrap: publish path settings, then load the client entry modules
// and start the app against the element that contains this script tag.
// The braces keep the `const` bindings below scoped to this block.

// Assigned without a declaration, so in a classic (non-module) script this
// becomes a global. `assets` and `base` hold the same docs path prefix.
// NOTE(review): presumably read by the SvelteKit client runtime — confirm.
__sveltekit_15arw4w = {
assets: "/docs/diffusers/pr_12249/en",
base: "/docs/diffusers/pr_12249/en",
env: {}
};
// Must be read synchronously: `document.currentScript` only refers to this
// script while it is executing, not later inside the promise callback.
const element = document.currentScript.parentElement;
// Forwarded unchanged to `kit.start` as the `data` option.
// NOTE(review): looks like one entry per id in `node_ids` — confirm.
const data = [null,null];
// Load both entry modules in parallel; start only once both have resolved.
Promise.all([
import("/docs/diffusers/pr_12249/en/_app/immutable/entry/start.d78ccb7e.js"),
import("/docs/diffusers/pr_12249/en/_app/immutable/entry/app.cd68c46f.js")
]).then(([kit, app]) => {
// Hand the app module, the captured element, and the initial options to the
// `start` export of the first module.
kit.start(app, element, {
// NOTE(review): presumably the layout/page node indices for this route — confirm.
node_ids: [0, 308],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
62.6 kB
·
Xet hash:
dc6b8f0c3035e44ad516ca3181850472253b19cc7b1e70a3b8ba31bdcb0558ed

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.