Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Distributed inference","local":"distributed-inference","sections":[{"title":"Accelerate","local":"accelerate","sections":[],"depth":2},{"title":"PyTorch Distributed","local":"pytorch-distributed","sections":[],"depth":2},{"title":"device_map","local":"devicemap","sections":[],"depth":2},{"title":"Context parallelism","local":"context-parallelism","sections":[{"title":"Ring Attention","local":"ring-attention","sections":[],"depth":3},{"title":"Ulysses Attention","local":"ulysses-attention","sections":[],"depth":3},{"title":"parallel_config","local":"parallelconfig","sections":[],"depth":3}],"depth":2}],"depth":1}"> | |
| <link href="/docs/diffusers/pr_11739/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/entry/start.56da5b2b.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/scheduler.53228c21.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/singletons.70d778ba.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/index.e93d0901.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/paths.a0989064.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/entry/app.794fd245.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/preload-helper.b5c24ab3.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/index.100fac89.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/nodes/0.d68fe2c3.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/nodes/308.b5a0dcdf.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/CopyLLMTxtMenu.ed0e3681.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/globals.7f7f1b26.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/IconCopy.38cf8f56.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.dd42f483.js"> | |
| <link rel="modulepreload" href="/docs/diffusers/pr_11739/en/_app/immutable/chunks/CodeBlock.d30a6509.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Distributed inference","local":"distributed-inference","sections":[{"title":"Accelerate","local":"accelerate","sections":[],"depth":2},{"title":"PyTorch Distributed","local":"pytorch-distributed","sections":[],"depth":2},{"title":"device_map","local":"devicemap","sections":[],"depth":2},{"title":"Context parallelism","local":"context-parallelism","sections":[{"title":"Ring Attention","local":"ring-attention","sections":[],"depth":3},{"title":"Ulysses Attention","local":"ulysses-attention","sections":[],"depth":3},{"title":"parallel_config","local":"parallelconfig","sections":[],"depth":3}],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <h1 class="relative group"><a id="distributed-inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#distributed-inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Distributed inference</span></h1> <p data-svelte-h="svelte-eljq9f">Distributed inference splits the workload across multiple GPUs. It a useful technique for fitting larger models in memory and can process multiple prompts for higher throughput.</p> <p data-svelte-h="svelte-og4l62">This guide will show you how to use <a href="https://huggingface.co/docs/accelerate/index" rel="nofollow">Accelerate</a> and <a href="https://pytorch.org/tutorials/beginner/dist_overview.html" rel="nofollow">PyTorch Distributed</a> for distributed inference.</p> <h2 class="relative group"><a id="accelerate" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#accelerate"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Accelerate</span></h2> <p data-svelte-h="svelte-nz1tds">Accelerate is a library designed to simplify inference and training on multiple accelerators by handling the setup, allowing users to focus on their PyTorch code.</p> <p data-svelte-h="svelte-1st21a7">Install Accelerate with the following command.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->uv pip install accelerate<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-s7f0yp">Initialize a <a href="https://huggingface.co/docs/accelerate/main/en/package_reference/state#accelerate.PartialState" rel="nofollow">accelerate.PartialState</a> class in a Python file to create a distributed environment. The <a href="https://huggingface.co/docs/accelerate/main/en/package_reference/state#accelerate.PartialState" rel="nofollow">accelerate.PartialState</a> class manages process management, device control and distribution, and process coordination.</p> <p data-svelte-h="svelte-2q3ora">Move the <a href="/docs/diffusers/pr_11739/en/api/pipelines/overview#diffusers.DiffusionPipeline">DiffusionPipeline</a> to <code>accelerate.PartialState.device</code> to assign a GPU to each process.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> PartialState | |
| <span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline | |
| pipeline = DiffusionPipeline.from_pretrained( | |
| <span class="hljs-string">"Qwen/Qwen-Image"</span>, torch_dtype=torch.float16 | |
| ) | |
| distributed_state = PartialState() | |
| pipeline.to(distributed_state.device)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-dxnfoj">Use the <a href="https://huggingface.co/docs/accelerate/main/en/package_reference/state#accelerate.PartialState.split_between_processes" rel="nofollow">split_between_processes</a> utility as a context manager to automatically distribute the prompts between the number of processes.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">with</span> distributed_state.split_between_processes([<span class="hljs-string">"a dog"</span>, <span class="hljs-string">"a cat"</span>]) <span class="hljs-keyword">as</span> prompt: | |
| result = pipeline(prompt).images[<span class="hljs-number">0</span>] | |
| result.save(<span class="hljs-string">f"result_<span class="hljs-subst">{distributed_state.process_index}</span>.png"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-xu0wll">Call <code>accelerate launch</code> to run the script and use the <code>--num_processes</code> argument to set the number of GPUs to use.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch run_distributed.py --num_processes=2<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-djm11n"><p>Refer to this minimal example <a href="https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba" rel="nofollow">script</a> for running inference across multiple GPUs. To learn more, take a look at the <a href="https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate" rel="nofollow">Distributed Inference with 🤗 Accelerate</a> guide.</p></blockquote> <h2 class="relative group"><a id="pytorch-distributed" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#pytorch-distributed"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>PyTorch Distributed</span></h2> <p data-svelte-h="svelte-nym58p">PyTorch <a href="https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html" rel="nofollow">DistributedDataParallel</a> enables <a href="https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=data_parallelism" rel="nofollow">data parallelism</a>, which replicates the same model on each device, to process different batches of data in parallel.</p> <p data-svelte-h="svelte-1q3d069">Import <code>torch.distributed</code> and <code>torch.multiprocessing</code> into a Python file to set up the distributed process group and to spawn the processes for inference on each GPU.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">import</span> torch.distributed <span class="hljs-keyword">as</span> dist | |
| <span class="hljs-keyword">import</span> torch.multiprocessing <span class="hljs-keyword">as</span> mp | |
| <span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline | |
| pipeline = DiffusionPipeline.from_pretrained( | |
| <span class="hljs-string">"Qwen/Qwen-Image"</span>, torch_dtype=torch.float16, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1vva4uo">Create a function for inference with <a href="https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group" rel="nofollow">init_process_group</a>. This method creates a distributed environment with the backend type, the <code>rank</code> of the current process, and the <code>world_size</code> or number of processes participating (for example, 2 GPUs would be <code>world_size=2</code>).</p> <p data-svelte-h="svelte-1329uwh">Move the pipeline to <code>rank</code> and use <code>get_rank</code> to assign a GPU to each process. Each process handles a different prompt.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">run_inference</span>(<span class="hljs-params">rank, world_size</span>): | |
| dist.init_process_group(<span class="hljs-string">"nccl"</span>, rank=rank, world_size=world_size) | |
| pipeline.to(rank) | |
| <span class="hljs-keyword">if</span> torch.distributed.get_rank() == <span class="hljs-number">0</span>: | |
| prompt = <span class="hljs-string">"a dog"</span> | |
| <span class="hljs-keyword">elif</span> torch.distributed.get_rank() == <span class="hljs-number">1</span>: | |
| prompt = <span class="hljs-string">"a cat"</span> | |
| image = sd(prompt).images[<span class="hljs-number">0</span>] | |
| image.save(<span class="hljs-string">f"./<span class="hljs-subst">{<span class="hljs-string">'_'</span>.join(prompt)}</span>.png"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-84sp40">Use <a href="https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn" rel="nofollow">mp.spawn</a> to create the number of processes defined in <code>world_size</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">main</span>(): | |
| world_size = <span class="hljs-number">2</span> | |
| mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=<span class="hljs-literal">True</span>) | |
| <span class="hljs-keyword">if</span> __name__ == <span class="hljs-string">"__main__"</span>: | |
| main()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-17g6264">Call <code>torchrun</code> to run the inference script and use the <code>--nproc_per_node</code> argument to set the number of GPUs to use.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->torchrun run_distributed.py --nproc_per_node=2<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="devicemap" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#devicemap"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>device_map</span></h2> <p data-svelte-h="svelte-19ricyq">The <code>device_map</code> argument enables distributed inference by automatically placing model components on separate GPUs. This is especially useful when a model doesn’t fit on a single GPU. You can use <code>device_map</code> to selectively load and unload the required model components at a given stage as shown in the example below (assumes two GPUs are available).</p> <p data-svelte-h="svelte-uepizv">Set <code>device_map="balanced"</code> to evenly distributes the text encoders on all available GPUs. You can use the <code>max_memory</code> argument to allocate a maximum amount of memory for each text encoder. Don’t load any other pipeline components to avoid memory usage.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> FluxPipeline | |
| <span class="hljs-keyword">import</span> torch | |
| prompt = <span class="hljs-string">""" | |
| cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California | |
| highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain | |
| """</span> | |
| pipeline = FluxPipeline.from_pretrained( | |
| <span class="hljs-string">"black-forest-labs/FLUX.1-dev"</span>, | |
| transformer=<span class="hljs-literal">None</span>, | |
| vae=<span class="hljs-literal">None</span>, | |
| device_map=<span class="hljs-string">"balanced"</span>, | |
| max_memory={<span class="hljs-number">0</span>: <span class="hljs-string">"16GB"</span>, <span class="hljs-number">1</span>: <span class="hljs-string">"16GB"</span>}, | |
| torch_dtype=torch.bfloat16 | |
| ) | |
| <span class="hljs-keyword">with</span> torch.no_grad(): | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"Encoding prompts."</span>) | |
| prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( | |
| prompt=prompt, prompt_2=<span class="hljs-literal">None</span>, max_sequence_length=<span class="hljs-number">512</span> | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-m088j3">After the text embeddings are computed, remove them from the GPU to make space for the diffusion transformer.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> gc | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">flush</span>(): | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| torch.cuda.reset_max_memory_allocated() | |
| torch.cuda.reset_peak_memory_stats() | |
| <span class="hljs-keyword">del</span> pipeline.text_encoder | |
| <span class="hljs-keyword">del</span> pipeline.text_encoder_2 | |
| <span class="hljs-keyword">del</span> pipeline.tokenizer | |
| <span class="hljs-keyword">del</span> pipeline.tokenizer_2 | |
| <span class="hljs-keyword">del</span> pipeline | |
| flush()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18jqegg">Set <code>device_map="auto"</code> to automatically distribute the model on the two GPUs. This strategy places a model on the fastest device first before placing a model on a slower device like a CPU or hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoModel | |
| <span class="hljs-keyword">import</span> torch | |
| transformer = AutoModel.from_pretrained( | |
| <span class="hljs-string">"black-forest-labs/FLUX.1-dev"</span>, | |
| subfolder=<span class="hljs-string">"transformer"</span>, | |
| device_map=<span class="hljs-string">"auto"</span>, | |
| torch_dtype=torch.bfloat16 | |
| )<!-- HTML_TAG_END --></pre></div> <blockquote class="tip" data-svelte-h="svelte-eiyj9h"><p>Run <code>pipeline.hf_device_map</code> to see how the various models are distributed across devices. This is useful for tracking model device placement. You can also call <code>hf_device_map</code> on the transformer model to see how it is distributed.</p></blockquote> <p data-svelte-h="svelte-128ts8i">Add the transformer model to the pipeline and set the <code>output_type="latent"</code> to generate the latents.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pipeline = FluxPipeline.from_pretrained( | |
| <span class="hljs-string">"black-forest-labs/FLUX.1-dev"</span>, | |
| text_encoder=<span class="hljs-literal">None</span>, | |
| text_encoder_2=<span class="hljs-literal">None</span>, | |
| tokenizer=<span class="hljs-literal">None</span>, | |
| tokenizer_2=<span class="hljs-literal">None</span>, | |
| vae=<span class="hljs-literal">None</span>, | |
| transformer=transformer, | |
| torch_dtype=torch.bfloat16 | |
| ) | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"Running denoising."</span>) | |
| height, width = <span class="hljs-number">768</span>, <span class="hljs-number">1360</span> | |
| latents = pipeline( | |
| prompt_embeds=prompt_embeds, | |
| pooled_prompt_embeds=pooled_prompt_embeds, | |
| num_inference_steps=<span class="hljs-number">50</span>, | |
| guidance_scale=<span class="hljs-number">3.5</span>, | |
| height=height, | |
| width=width, | |
| output_type=<span class="hljs-string">"latent"</span>, | |
| ).images<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-jp9jpz">Remove the pipeline and transformer from memory and load a VAE to decode the latents. The VAE is typically small enough to be loaded on a single device.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> AutoencoderKL | |
| <span class="hljs-keyword">from</span> diffusers.image_processor <span class="hljs-keyword">import</span> VaeImageProcessor | |
| vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder=<span class="hljs-string">"vae"</span>, torch_dtype=torch.bfloat16).to(<span class="hljs-string">"cuda"</span>) | |
| vae_scale_factor = <span class="hljs-number">2</span> ** (<span class="hljs-built_in">len</span>(vae.config.block_out_channels) - <span class="hljs-number">1</span>) | |
| image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) | |
| <span class="hljs-keyword">with</span> torch.no_grad(): | |
| <span class="hljs-built_in">print</span>(<span class="hljs-string">"Running decoding."</span>) | |
| latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor) | |
| latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor | |
| image = vae.decode(latents, return_dict=<span class="hljs-literal">False</span>)[<span class="hljs-number">0</span>] | |
| image = image_processor.postprocess(image, output_type=<span class="hljs-string">"pil"</span>) | |
| image[<span class="hljs-number">0</span>].save(<span class="hljs-string">"split_transformer.png"</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-3tqui3">By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.</p> <h2 class="relative group"><a id="context-parallelism" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#context-parallelism"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Context parallelism</span></h2> <p data-svelte-h="svelte-1bsyx0k"><a href="https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism" rel="nofollow">Context parallelism</a> splits input sequences across multiple GPUs to reduce memory usage. Each GPU processes its own slice of the sequence.</p> <p data-svelte-h="svelte-n1ikmm">Use <a href="/docs/diffusers/pr_11739/en/api/models/overview#diffusers.ModelMixin.set_attention_backend">set_attention_backend()</a> to switch to a more optimized attention backend. Refer to this <a href="../optimization/attention_backends#available-backends">table</a> for a complete list of available backends.</p> <p data-svelte-h="svelte-1ppoxcn">Most attention backends are compatible with context parallelism. Open an <a href="https://github.com/huggingface/diffusers/issues/new" rel="nofollow">issue</a> if a backend is not compatible.</p> <h3 class="relative group"><a id="ring-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ring-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Ring Attention</span></h3> <p data-svelte-h="svelte-wnvqc8">Key (K) and value (V) representations communicate between devices using <a href="https://huggingface.co/papers/2310.01889" rel="nofollow">Ring Attention</a>. This ensures each split sees every other token’s K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency.</p> <p data-svelte-h="svelte-3dsqkg">Pass a <a href="/docs/diffusers/pr_11739/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> to the <code>parallel_config</code> argument of the transformer model. The config supports the <code>ring_degree</code> argument that determines how many devices to use for Ring Attention.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| <span class="hljs-keyword">from</span> torch <span class="hljs-keyword">import</span> distributed <span class="hljs-keyword">as</span> dist | |
| <span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DiffusionPipeline, ContextParallelConfig | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">setup_distributed</span>(): | |
| <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> dist.is_initialized(): | |
| dist.init_process_group(backend=<span class="hljs-string">"nccl"</span>) | |
| rank = dist.get_rank() | |
| device = torch.device(<span class="hljs-string">f"cuda:<span class="hljs-subst">{rank}</span>"</span>) | |
| torch.cuda.set_device(device) | |
| <span class="hljs-keyword">return</span> device | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">main</span>(): | |
| device = setup_distributed() | |
| world_size = dist.get_world_size() | |
| pipeline = DiffusionPipeline.from_pretrained( | |
| <span class="hljs-string">"black-forest-labs/FLUX.1-dev"</span>, torch_dtype=torch.bfloat16 | |
| ).to(device) | |
| pipeline.transformer.set_attention_backend(<span class="hljs-string">"_native_cudnn"</span>) | |
| cp_config = ContextParallelConfig(ring_degree=world_size) | |
| pipeline.transformer.enable_parallelism(config=cp_config) | |
| prompt = <span class="hljs-string">""" | |
| cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California | |
| highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain | |
| """</span> | |
| <span class="hljs-comment"># Must specify generator so all ranks start with same latents (or pass your own)</span> | |
| generator = torch.Generator().manual_seed(<span class="hljs-number">42</span>) | |
| image = pipeline( | |
| prompt, | |
| guidance_scale=<span class="hljs-number">3.5</span>, | |
| num_inference_steps=<span class="hljs-number">50</span>, | |
| generator=generator, | |
| ).images[<span class="hljs-number">0</span>] | |
| <span class="hljs-keyword">if</span> dist.get_rank() == <span class="hljs-number">0</span>: | |
| image.save(<span class="hljs-string">f"output.png"</span>) | |
| <span class="hljs-keyword">if</span> dist.is_initialized(): | |
| dist.destroy_process_group() | |
| <span class="hljs-keyword">if</span> __name__ == <span class="hljs-string">"__main__"</span>: | |
| main()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6zwmyc">The script above needs to be run with a distributed launcher, such as <a href="https://docs.pytorch.org/docs/stable/elastic/run.html" rel="nofollow">torchrun</a>, that is compatible with PyTorch. <code>--nproc-per-node</code> is set to the number of GPUs available.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->torchrun --nproc-per-node 2 above_script.py<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="ulysses-attention" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ulysses-attention"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Ulysses Attention</span></h3> <p data-svelte-h="svelte-1b28cdv"><a href="https://huggingface.co/papers/2309.14509" rel="nofollow">Ulysses Attention</a> splits a sequence across GPUs and performs an <em>all-to-all</em> communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its head, then performs another all-to-all to regroup results by tokens for the next layer.</p> <p data-svelte-h="svelte-sritq9"><a href="/docs/diffusers/pr_11739/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> supports Ulysses Attention through the <code>ulysses_degree</code> argument. This determines how many devices to use for Ulysses Attention.</p> <p data-svelte-h="svelte-18qdbre">Pass the <a href="/docs/diffusers/pr_11739/en/api/parallel#diffusers.ContextParallelConfig">ContextParallelConfig</a> to <code>enable_parallelism()</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Depending on the number of GPUs available.</span> | |
| pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=<span class="hljs-number">2</span>))<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="parallelconfig" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#parallelconfig"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>parallel_config</span></h3> <p data-svelte-h="svelte-1abk5kt">Pass <code>parallel_config</code> during model initialization to enable context parallelism.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->CKPT_ID = <span class="hljs-string">"black-forest-labs/FLUX.1-dev"</span> | |
| cp_config = ContextParallelConfig(ring_degree=<span class="hljs-number">2</span>) | |
| transformer = AutoModel.from_pretrained( | |
| CKPT_ID, | |
| subfolder=<span class="hljs-string">"transformer"</span>, | |
| torch_dtype=torch.bfloat16, | |
| parallel_config=cp_config | |
| ) | |
| pipeline = DiffusionPipeline.from_pretrained( | |
| CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16, | |
| ).to(device)<!-- HTML_TAG_END --></pre></div> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/diffusers/blob/main/docs/source/en/training/distributed_inference.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| __sveltekit_kxvna0 = { | |
| assets: "/docs/diffusers/pr_11739/en", | |
| base: "/docs/diffusers/pr_11739/en", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/diffusers/pr_11739/en/_app/immutable/entry/start.56da5b2b.js"), | |
| import("/docs/diffusers/pr_11739/en/_app/immutable/entry/app.794fd245.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 308], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 57.2 kB
- Xet hash:
- feb346158de15cb25333038b35281f28cb1f7ae0e8e89cfd8c89fbc8256d8c7e
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.