Buckets:

hf-doc-build/doc-dev / diffusers /pr_12652 /en /training /distributed_inference.html
rtrm's picture
download
raw
68.4 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Distributed inference&quot;,&quot;local&quot;:&quot;distributed-inference&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Accelerate&quot;,&quot;local&quot;:&quot;accelerate&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;PyTorch Distributed&quot;,&quot;local&quot;:&quot;pytorch-distributed&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;device_map&quot;,&quot;local&quot;:&quot;devicemap&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Context parallelism&quot;,&quot;local&quot;:&quot;context-parallelism&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Ring Attention&quot;,&quot;local&quot;:&quot;ring-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Ulysses Attention&quot;,&quot;local&quot;:&quot;ulysses-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Unified Attention&quot;,&quot;local&quot;:&quot;unified-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Ulysses Anything Attention&quot;,&quot;local&quot;:&quot;ulysses-anything-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;parallel_config&quot;,&quot;local&quot;:&quot;parallelconfig&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/diffusers/pr_12652/en/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/entry/start.78b62fee.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/scheduler.53228c21.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/singletons.89d0b97a.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/index.e93d0901.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/paths.67f826e3.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/entry/app.062e1615.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/preload-helper.222e0275.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/index.100fac89.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/nodes/0.fe8af227.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/nodes/314.5f0da98f.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/CopyLLMTxtMenu.50ab6782.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/globals.7f7f1b26.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/IconCopy.38cf8f56.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.720a8c3c.js">
<link rel="modulepreload" href="/docs/diffusers/pr_12652/en/_app/immutable/chunks/CodeBlock.d30a6509.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Distributed inference&quot;,&quot;local&quot;:&quot;distributed-inference&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Accelerate&quot;,&quot;local&quot;:&quot;accelerate&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;PyTorch Distributed&quot;,&quot;local&quot;:&quot;pytorch-distributed&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;device_map&quot;,&quot;local&quot;:&quot;devicemap&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Context parallelism&quot;,&quot;local&quot;:&quot;context-parallelism&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Ring Attention&quot;,&quot;local&quot;:&quot;ring-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Ulysses Attention&quot;,&quot;local&quot;:&quot;ulysses-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Unified Attention&quot;,&quot;local&quot;:&quot;unified-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;Ulysses Anything Attention&quot;,&quot;local&quot;:&quot;ulysses-anything-attention&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3},{&quot;title&quot;:&quot;parallel_config&quot;,&quot;local&quot;:&quot;parallelconfig&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 h-7 max-sm:h-7 px-2 max-sm:px-1.5 text-sm font-medium text-gray-800 border 
border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0 hover:text-gray-800 dark:hover:text-gray-200"><svg class="sm:size-3.5 size-3" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-7 max-sm:h-7 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible sm:size-3.5 size-3 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="distributed-inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#distributed-inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 
256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Distributed inference</span></h1> <p data-svelte-h="svelte-eljq9f">Distributed inference splits the workload across multiple GPUs. It is a useful technique for fitting larger models in memory and can process multiple prompts for higher throughput.</p> <p data-svelte-h="svelte-og4l62">This guide will show you how to use <a href="https://huggingface.co/docs/accelerate/index" rel="nofollow">Accelerate</a> and <a href="https://pytorch.org/tutorials/beginner/dist_overview.html" rel="nofollow">PyTorch Distributed</a> for distributed inference.</p> <h2 class="relative group"><a id="accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 
28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Accelerate</span></h2> <p data-svelte-h="svelte-nz1tds">Accelerate is a library designed to simplify inference and training on multiple accelerators by handling the setup, allowing users to focus on their PyTorch code.</p> <p data-svelte-h="svelte-1st21a7">Install Accelerate with the following command.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->uv pip install accelerate<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-s7f0yp">Initialize a <a href="https://huggingface.co/docs/accelerate/main/en/package_reference/state#accelerate.PartialState" rel="nofollow">accelerate.PartialState</a> class in a Python file to create a distributed environment. 
The <a href="https://huggingface.co/docs/accelerate/main/en/package_reference/state#accelerate.PartialState" rel="nofollow">accelerate.PartialState</a> class manages process management, device control and distribution, and process coordination.</p> <p data-svelte-h="svelte-1hixvx7">Move the <a href="/docs/diffusers/pr_12652/en/api/pipelines/overview#diffusers.DiffusionPipeline">DiffusionPipeline</a> to <code>accelerate.PartialState.device</code> to assign a GPU to each process.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> PartialState
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;Qwen/Qwen-Image&quot;</span>, torch_dtype=torch.float16
)
distributed_state = PartialState()
pipeline.to(distributed_state.device)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-dxnfoj">Use the <a href="https://huggingface.co/docs/accelerate/main/en/package_reference/state#accelerate.PartialState.split_between_processes" rel="nofollow">split_between_processes</a> utility as a context manager to automatically distribute the prompts between the number of processes.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">with</span> distributed_state.split_between_processes([<span class="hljs-string">&quot;a dog&quot;</span>, <span class="hljs-string">&quot;a cat&quot;</span>]) <span class="hljs-keyword">as</span> prompt:
result = pipeline(prompt).images[<span class="hljs-number">0</span>]
result.save(<span class="hljs-string">f&quot;result_<span class="hljs-subst">{distributed_state.process_index}</span>.png&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xu0wll">Call <code>accelerate launch</code> to run the script and use the <code>--num_processes</code> argument to set the number of GPUs to use.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch --num_processes=2 run_distributed.py<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-djm11n"><p>Refer to this minimal example <a href="https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba" rel="nofollow">script</a> for running inference across multiple GPUs. 
To learn more, take a look at the <a href="https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate" rel="nofollow">Distributed Inference with 🤗 Accelerate</a> guide.</p></blockquote> <h2 class="relative group"><a id="pytorch-distributed" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#pytorch-distributed"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>PyTorch Distributed</span></h2> <p data-svelte-h="svelte-nym58p">PyTorch <a href="https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html" rel="nofollow">DistributedDataParallel</a> enables <a href="https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=data_parallelism" rel="nofollow">data parallelism</a>, which replicates the same model on each device, to process different batches of data in parallel.</p> <p data-svelte-h="svelte-1q3d069">Import <code>torch.distributed</code> and <code>torch.multiprocessing</code> into a Python file to set up the distributed process group and to spawn the processes for inference on each GPU.</p> <div class="code-block relative "><div 
class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> torch.distributed <span class="hljs-keyword">as</span> dist
<span class="hljs-keyword">import</span> torch.multiprocessing <span class="hljs-keyword">as</span> mp
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;Qwen/Qwen-Image&quot;</span>, torch_dtype=torch.float16,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1vva4uo">Create a function for inference with <a href="https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group" rel="nofollow">init_process_group</a>. This method creates a distributed environment with the backend type, the <code>rank</code> of the current process, and the <code>world_size</code> or number of processes participating (for example, 2 GPUs would be <code>world_size=2</code>).</p> <p data-svelte-h="svelte-1329uwh">Move the pipeline to <code>rank</code> and use <code>get_rank</code> to assign a GPU to each process. Each process handles a different prompt.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title 
function_">run_inference</span>(<span class="hljs-params">rank, world_size</span>):
dist.init_process_group(<span class="hljs-string">&quot;nccl&quot;</span>, rank=rank, world_size=world_size)
pipeline.to(rank)
<span class="hljs-keyword">if</span> torch.distributed.get_rank() == <span class="hljs-number">0</span>:
prompt = <span class="hljs-string">&quot;a dog&quot;</span>
<span class="hljs-keyword">elif</span> torch.distributed.get_rank() == <span class="hljs-number">1</span>:
prompt = <span class="hljs-string">&quot;a cat&quot;</span>
image = pipeline(prompt).images[<span class="hljs-number">0</span>]
image.save(<span class="hljs-string">f&quot;./<span class="hljs-subst">{<span class="hljs-string">&#x27;_&#x27;</span>.join(prompt.split())}</span>.png&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-84sp40">Use <a href="https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn" rel="nofollow">mp.spawn</a> to create the number of processes defined in <code>world_size</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">main</span>():
world_size = <span class="hljs-number">2</span>
mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=<span class="hljs-literal">True</span>)
<span class="hljs-keyword">if</span> __name__ == <span class="hljs-string">&quot;__main__&quot;</span>:
main()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-17g6264">Call <code>torchrun</code> to run the inference script and use the <code>--nproc_per_node</code> argument to set the number of GPUs to use.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->torchrun --nproc_per_node=2 run_distributed.py<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="devicemap" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#devicemap"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 
11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>device_map</span></h2> <p data-svelte-h="svelte-19ricyq">The <code>device_map</code> argument enables distributed inference by automatically placing model components on separate GPUs. This is especially useful when a model doesn’t fit on a single GPU. You can use <code>device_map</code> to selectively load and unload the required model components at a given stage as shown in the example below (assumes two GPUs are available).</p> <p data-svelte-h="svelte-uepizv">Set <code>device_map=&quot;balanced&quot;</code> to evenly distribute the text encoders across all available GPUs. You can use the <code>max_memory</code> argument to cap the amount of memory that can be allocated on each GPU. 
Don’t load any other pipeline components to avoid memory usage.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> FluxPipeline
<span class="hljs-keyword">import</span> torch
prompt = <span class="hljs-string">&quot;&quot;&quot;
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
&quot;&quot;&quot;</span>
pipeline = FluxPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
transformer=<span class="hljs-literal">None</span>,
vae=<span class="hljs-literal">None</span>,
device_map=<span class="hljs-string">&quot;balanced&quot;</span>,
max_memory={<span class="hljs-number">0</span>: <span class="hljs-string">&quot;16GB&quot;</span>, <span class="hljs-number">1</span>: <span class="hljs-string">&quot;16GB&quot;</span>},
torch_dtype=torch.bfloat16
)
<span class="hljs-keyword">with</span> torch.no_grad():
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Encoding prompts.&quot;</span>)
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
prompt=prompt, prompt_2=<span class="hljs-literal">None</span>, max_sequence_length=<span class="hljs-number">512</span>
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-m088j3">After the text embeddings are computed, delete the text encoders, tokenizers, and the pipeline to free GPU memory for the diffusion transformer.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> gc
<span class="hljs-keyword">def</span> <span class="hljs-title function_">flush</span>():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
<span class="hljs-keyword">del</span> pipeline.text_encoder
<span class="hljs-keyword">del</span> pipeline.text_encoder_2
<span class="hljs-keyword">del</span> pipeline.tokenizer
<span class="hljs-keyword">del</span> pipeline.tokenizer_2
<span class="hljs-keyword">del</span> pipeline
flush()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18jqegg">Set <code>device_map=&quot;auto&quot;</code> to automatically distribute the model on the two GPUs. This strategy places a model on the fastest device first before placing a model on a slower device like a CPU or hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel
<span class="hljs-keyword">import</span> torch
transformer = AutoModel.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
device_map=<span class="hljs-string">&quot;auto&quot;</span>,
torch_dtype=torch.bfloat16
)<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-eiyj9h"><p>Run <code>pipeline.hf_device_map</code> to see how the various models are distributed across devices. This is useful for tracking model device placement. You can also call <code>hf_device_map</code> on the transformer model to see how it is distributed.</p></blockquote> <p data-svelte-h="svelte-128ts8i">Add the transformer model to the pipeline and set the <code>output_type=&quot;latent&quot;</code> to generate the latents.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline = FluxPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>,
text_encoder=<span class="hljs-literal">None</span>,
text_encoder_2=<span class="hljs-literal">None</span>,
tokenizer=<span class="hljs-literal">None</span>,
tokenizer_2=<span class="hljs-literal">None</span>,
vae=<span class="hljs-literal">None</span>,
transformer=transformer,
torch_dtype=torch.bfloat16
)
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Running denoising.&quot;</span>)
height, width = <span class="hljs-number">768</span>, <span class="hljs-number">1360</span>
latents = pipeline(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
num_inference_steps=<span class="hljs-number">50</span>,
guidance_scale=<span class="hljs-number">3.5</span>,
height=height,
width=width,
output_type=<span class="hljs-string">&quot;latent&quot;</span>,
).images<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-jp9jpz">Remove the pipeline and transformer from memory and load a VAE to decode the latents. The VAE is typically small enough to be loaded on a single device.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoencoderKL
<span class="hljs-keyword">from</span> diffusers.image_processor <span class="hljs-keyword">import</span> VaeImageProcessor
vae = AutoencoderKL.from_pretrained(<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>, subfolder=<span class="hljs-string">&quot;vae&quot;</span>, torch_dtype=torch.bfloat16).to(<span class="hljs-string">&quot;cuda&quot;</span>)
vae_scale_factor = <span class="hljs-number">2</span> ** (<span class="hljs-built_in">len</span>(vae.config.block_out_channels) - <span class="hljs-number">1</span>)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
<span class="hljs-keyword">with</span> torch.no_grad():
<span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Running decoding.&quot;</span>)
latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
image = vae.decode(latents, return_dict=<span class="hljs-literal">False</span>)[<span class="hljs-number">0</span>]
image = image_processor.postprocess(image, output_type=<span class="hljs-string">&quot;pil&quot;</span>)
image[<span class="hljs-number">0</span>].save(<span class="hljs-string">&quot;split_transformer.png&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-3tqui3">By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.</p> <h2 class="relative group"><a id="context-parallelism" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#context-parallelism"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Context parallelism</span></h2> <p data-svelte-h="svelte-1bsyx0k"><a href="https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism" rel="nofollow">Context parallelism</a> splits input sequences across multiple GPUs to reduce memory usage. Each GPU processes its own slice of the sequence.</p> <p data-svelte-h="svelte-1u25fql">Use <a href="/docs/diffusers/pr_12652/en/api/models/overview#diffusers.ModelMixin.set_attention_backend">set_attention_backend()</a> to switch to a more optimized attention backend. 
Refer to this <a href="../optimization/attention_backends#available-backends">table</a> for a complete list of available backends.</p> <p data-svelte-h="svelte-1ppoxcn">Most attention backends are compatible with context parallelism. Open an <a href="https://github.com/huggingface/diffusers/issues/new" rel="nofollow">issue</a> if a backend is not compatible.</p> <h3 class="relative group"><a id="ring-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ring-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Ring Attention</span></h3> <p data-svelte-h="svelte-wnvqc8">Key (K) and value (V) representations communicate between devices using <a href="https://huggingface.co/papers/2310.01889" rel="nofollow">Ring Attention</a>. This ensures each split sees every other token’s K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. 
No single GPU holds the full sequence, which reduces communication latency.</p> <p data-svelte-h="svelte-7vz069">Pass a <a href="/docs/diffusers/pr_12652/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> to the <code>parallel_config</code> argument of the transformer model. The config supports the <code>ring_degree</code> argument that determines how many devices to use for Ring Attention.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> torch <span class="hljs-keyword">import</span> distributed <span class="hljs-keyword">as</span> dist
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline, ContextParallelConfig
<span class="hljs-keyword">def</span> <span class="hljs-title function_">setup_distributed</span>():
<span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> dist.is_initialized():
dist.init_process_group(backend=<span class="hljs-string">&quot;nccl&quot;</span>)
rank = dist.get_rank()
device = torch.device(<span class="hljs-string">f&quot;cuda:<span class="hljs-subst">{rank}</span>&quot;</span>)
torch.cuda.set_device(device)
<span class="hljs-keyword">return</span> device
<span class="hljs-keyword">def</span> <span class="hljs-title function_">main</span>():
device = setup_distributed()
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>, torch_dtype=torch.bfloat16
).to(device)
pipeline.transformer.set_attention_backend(<span class="hljs-string">&quot;_native_cudnn&quot;</span>)
cp_config = ContextParallelConfig(ring_degree=world_size)
pipeline.transformer.enable_parallelism(config=cp_config)
prompt = <span class="hljs-string">&quot;&quot;&quot;
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
&quot;&quot;&quot;</span>
<span class="hljs-comment"># Must specify generator so all ranks start with same latents (or pass your own)</span>
generator = torch.Generator().manual_seed(<span class="hljs-number">42</span>)
image = pipeline(
prompt,
guidance_scale=<span class="hljs-number">3.5</span>,
num_inference_steps=<span class="hljs-number">50</span>,
generator=generator,
).images[<span class="hljs-number">0</span>]
<span class="hljs-keyword">if</span> dist.get_rank() == <span class="hljs-number">0</span>:
image.save(<span class="hljs-string">f&quot;output.png&quot;</span>)
<span class="hljs-keyword">if</span> dist.is_initialized():
dist.destroy_process_group()
<span class="hljs-keyword">if</span> __name__ == <span class="hljs-string">&quot;__main__&quot;</span>:
main()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6zwmyc">The script above needs to be run with a distributed launcher, such as <a href="https://docs.pytorch.org/docs/stable/elastic/run.html" rel="nofollow">torchrun</a>, that is compatible with PyTorch. <code>--nproc-per-node</code> is set to the number of GPUs available.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->torchrun --nproc-per-node 2 above_script.py<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="ulysses-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ulysses-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" 
role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Ulysses Attention</span></h3> <p data-svelte-h="svelte-1b28cdv"><a href="https://huggingface.co/papers/2309.14509" rel="nofollow">Ulysses Attention</a> splits a sequence across GPUs and performs an <em>all-to-all</em> communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its subset of heads, then performs another all-to-all to regroup results by tokens for the next layer.</p> <p data-svelte-h="svelte-cslf3e"><a href="/docs/diffusers/pr_12652/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> supports Ulysses Attention through the <code>ulysses_degree</code> argument. 
This determines how many devices to use for Ulysses Attention.</p> <p data-svelte-h="svelte-171zdj3">Pass the <a href="/docs/diffusers/pr_12652/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> to <code>enable_parallelism()</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Depending on the number of GPUs available.</span>
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=<span class="hljs-number">2</span>))<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="unified-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#unified-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Unified Attention</span></h3> <p data-svelte-h="svelte-12jl1gc"><a href="https://huggingface.co/papers/2405.07719" rel="nofollow">Unified Sequence Parallelism</a> combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. 
It applies Ulysses’s <em>all-to-all</em> communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the <em>all-to-all</em> to restore the original layout.</p> <p data-svelte-h="svelte-76urs9">This hybrid approach leverages the strengths of both methods:</p> <ul data-svelte-h="svelte-1ih6klk"><li><strong>Ulysses Attention</strong> efficiently parallelizes across attention heads</li> <li><strong>Ring Attention</strong> handles very long sequences with minimal memory overhead</li> <li>Together, they enable 2D parallelization across both heads and sequence dimensions</li></ul> <p data-svelte-h="svelte-1dtioe1"><a href="/docs/diffusers/pr_12652/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> supports Unified Attention by specifying both <code>ulysses_degree</code> and <code>ring_degree</code>. The total number of devices used is <code>ulysses_degree * ring_degree</code>, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping).
Pass the <a href="/docs/diffusers/pr_12652/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> with both <code>ulysses_degree</code> and <code>ring_degree</code> set to a value greater than 1 to <code>enable_parallelism()</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=<span class="hljs-number">2</span>, ring_degree=<span class="hljs-number">2</span>))<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-1rulofq"><p>Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices).</p></blockquote> <p data-svelte-h="svelte-x9jokt">We ran a benchmark with Ulysses, Ring, and Unified Attention with <a 
href="https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532" rel="nofollow">this script</a> on a node of 4 H100 GPUs. The results are summarized as follows:</p> <table data-svelte-h="svelte-1u212r8"><thead><tr><th>CP Backend</th> <th>Time / Iter (ms)</th> <th>Steps / Sec</th> <th>Peak Memory (GB)</th></tr></thead> <tbody><tr><td>ulysses</td> <td>6670.789</td> <td>7.50</td> <td>33.85</td></tr> <tr><td>ring</td> <td>13076.492</td> <td>3.82</td> <td>56.02</td></tr> <tr><td>unified_balanced</td> <td>11068.705</td> <td>4.52</td> <td>33.85</td></tr></tbody></table> <p data-svelte-h="svelte-58w498">From the above table, it’s clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.</p> <h3 class="relative group"><a id="ulysses-anything-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ulysses-anything-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Ulysses Anything Attention</span></h3> <p data-svelte-h="svelte-fmnv6s">The default Ulysses Attention mechanism requires 
that the sequence length of hidden states be divisible by the number of devices. This imposes significant limitations on the practical application of Ulysses Attention. <a href="https://github.com/huggingface/diffusers/pull/12996" rel="nofollow">Ulysses Anything Attention</a> is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, thereby enhancing the versatility of Ulysses Attention in practical use.</p> <p data-svelte-h="svelte-1fwxcu0"><a href="/docs/diffusers/pr_12652/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> supports Ulysses Anything Attention by specifying both <code>ulysses_degree</code> and <code>ulysses_anything</code>. Please note that Ulysses Anything Attention is not currently supported by Unified Attention. Pass the <a href="/docs/diffusers/pr_12652/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> with <code>ulysses_degree</code> set to a value greater than 1 and <code>ulysses_anything=True</code> to <code>enable_parallelism()</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform 
-translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=<span class="hljs-number">2</span>, ulysses_anything=<span class="hljs-literal">True</span>))<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-qn72hj"><p>To avoid multiple forced CUDA sync caused by H2D and D2H transfers, please add the <strong>gloo</strong> backend in <code>init_process_group</code>. This will significantly reduce communication latency.</p></blockquote> <p data-svelte-h="svelte-gqggfk">We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulysses Anything Attention with <a href="https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999" rel="nofollow">this script</a> on a node of 4 L20 GPUs. 
The results are summarized as follows:</p> <table data-svelte-h="svelte-88f0ay"><thead><tr><th>CP Backend</th> <th>Time / Iter (ms)</th> <th>Steps / Sec</th> <th>Peak Memory (GB)</th> <th>Shape (HxW)</th></tr></thead> <tbody><tr><td>ulysses</td> <td>281.07</td> <td>3.56</td> <td>37.11</td> <td>1024x1024</td></tr> <tr><td>ring</td> <td>351.34</td> <td>2.85</td> <td>37.01</td> <td>1024x1024</td></tr> <tr><td>unified_balanced</td> <td>324.37</td> <td>3.08</td> <td>37.16</td> <td>1024x1024</td></tr> <tr><td>ulysses_anything</td> <td>280.94</td> <td>3.56</td> <td>37.11</td> <td>1024x1024</td></tr> <tr><td>ulysses</td> <td>failed</td> <td>failed</td> <td>failed</td> <td>1008x1008</td></tr> <tr><td>ring</td> <td>failed</td> <td>failed</td> <td>failed</td> <td>1008x1008</td></tr> <tr><td>unified_balanced</td> <td>failed</td> <td>failed</td> <td>failed</td> <td>1008x1008</td></tr> <tr><td>ulysses_anything</td> <td>278.40</td> <td>3.59</td> <td>36.99</td> <td>1008x1008</td></tr></tbody></table> <p data-svelte-h="svelte-ml9v1f">From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.</p> <h3 class="relative group"><a id="parallelconfig" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#parallelconfig"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 
0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>parallel_config</span></h3> <p data-svelte-h="svelte-1abk5kt">Pass <code>parallel_config</code> during model initialization to enable context parallelism.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->CKPT_ID = <span class="hljs-string">&quot;black-forest-labs/FLUX.1-dev&quot;</span>
cp_config = ContextParallelConfig(ring_degree=<span class="hljs-number">2</span>)
transformer = AutoModel.from_pretrained(
CKPT_ID,
subfolder=<span class="hljs-string">&quot;transformer&quot;</span>,
torch_dtype=torch.bfloat16,
parallel_config=cp_config
)
pipeline = DiffusionPipeline.from_pretrained(
CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/diffusers/blob/main/docs/source/en/training/distributed_inference.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
// Client bootstrap for this page: publishes path settings on a global, then
// loads the two entry modules and starts the app against this script's parent element.
// NOTE(review): presumably emitted by the SvelteKit build (names and hashes look generated) — do not hand-edit.
{
// Assigned without a declaration, so this becomes a global that the imported modules can read.
__sveltekit_1te7xiv = {
// Both paths point at the versioned docs root for this PR build.
assets: "/docs/diffusers/pr_12652/en",
base: "/docs/diffusers/pr_12652/en",
env: {}
};
// Must be read synchronously: document.currentScript is only set while this script is executing.
const element = document.currentScript.parentElement;
// One entry per id in node_ids below; both are null here.
const data = [null,null];
// Fetch both entry modules in parallel; start only once both have resolved.
Promise.all([
import("/docs/diffusers/pr_12652/en/_app/immutable/entry/start.78b62fee.js"),
import("/docs/diffusers/pr_12652/en/_app/immutable/entry/app.062e1615.js")
]).then(([kit, app]) => {
// Hand the app module, the mount element and the initial page state to the start entry point.
kit.start(app, element, {
node_ids: [0, 314],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
68.4 kB
·
Xet hash:
94ca156e31720267395ac140880cad97723cd061bdcc9fd4a5882f89ab7aed0c

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.