Buckets:

hf-doc-build/doc-dev / diffusers /pr_11686 /en /using-diffusers /stable_diffusion_jax_how_to.html
download
raw
43.1 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;JAX/Flax&quot;,&quot;local&quot;:&quot;jaxflax&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Load a model&quot;,&quot;local&quot;:&quot;load-a-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Inference&quot;,&quot;local&quot;:&quot;inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using different prompts&quot;,&quot;local&quot;:&quot;using-different-prompts&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;How does parallelization work?&quot;,&quot;local&quot;:&quot;how-does-parallelization-work&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Resources&quot;,&quot;local&quot;:&quot;resources&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/diffusers/pr_11686/en/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/entry/start.2b9667fb.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/scheduler.8c3d61f6.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/singletons.756349ae.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/index.0997d446.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/paths.8d5937da.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/entry/app.a2a6117e.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/index.da70eac4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/nodes/0.a31d0923.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/nodes/309.c101e34d.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/Tip.1d9b8c37.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/CodeBlock.a9c4becf.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/DocNotebookDropdown.48852948.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/globals.7f7f1b26.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11686/en/_app/immutable/chunks/getInferenceSnippets.d00e08ac.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;JAX/Flax&quot;,&quot;local&quot;:&quot;jaxflax&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;Load a model&quot;,&quot;local&quot;:&quot;load-a-model&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Inference&quot;,&quot;local&quot;:&quot;inference&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Using different prompts&quot;,&quot;local&quot;:&quot;using-different-prompts&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;How does parallelization work?&quot;,&quot;local&quot;:&quot;how-does-parallelization-work&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Resources&quot;,&quot;local&quot;:&quot;resources&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="jaxflax" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#jaxflax"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 
0-79.196z" fill="currentColor"></path></svg></span></a> <span>JAX/Flax</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"> </button> </div> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"> </button> </div></div> <p data-svelte-h="svelte-12svr29">🤗 Diffusers supports Flax for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform. This guide shows you how to run inference with Stable Diffusion using JAX/Flax.</p> <p data-svelte-h="svelte-vd07ss">Before you begin, make sure you have the necessary libraries installed:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; 
border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># uncomment to install the necessary libraries in Colab</span>
<span class="hljs-comment">#!pip install -q jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy</span>
<span class="hljs-comment">#!pip install -q diffusers</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-15famve">You should also make sure you’re using a TPU backend. While JAX does not run exclusively on TPUs, you’ll get the best performance on a TPU because each server has 8 TPU accelerators working in parallel.</p> <p data-svelte-h="svelte-1jp1nmd">If you are running this guide in Colab, select <em>Runtime</em> in the menu above, select the option <em>Change runtime type</em>, and then select <em>TPU</em> under the <em>Hardware accelerator</em> setting. Import JAX and quickly check whether you’re using a TPU:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> jax
<span class="hljs-keyword">import</span> jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
num_devices = jax.device_count()
device_type = jax.devices()[<span class="hljs-number">0</span>].device_kind
<span class="hljs-built_in">print</span>(<span class="hljs-string">f&quot;Found <span class="hljs-subst">{num_devices}</span> JAX devices of type <span class="hljs-subst">{device_type}</span>.&quot;</span>)
<span class="hljs-keyword">assert</span> <span class="hljs-string">&quot;TPU&quot;</span> <span class="hljs-keyword">in</span> device_type, (
<span class="hljs-string">&quot;Available device is not a TPU, please select TPU from Runtime &gt; Change runtime type &gt; Hardware accelerator&quot;</span>
)
<span class="hljs-comment"># Found 8 JAX devices of type Cloud TPU.</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-5ufrkr">Great, now you can import the rest of the dependencies you’ll need:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> jax.numpy <span class="hljs-keyword">as</span> jnp
<span class="hljs-keyword">from</span> jax <span class="hljs-keyword">import</span> pmap
<span class="hljs-keyword">from</span> flax.jax_utils <span class="hljs-keyword">import</span> replicate
<span class="hljs-keyword">from</span> flax.training.common_utils <span class="hljs-keyword">import</span> shard
<span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> FlaxStableDiffusionPipeline<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="load-a-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#load-a-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Load a model</span></h2> <p data-svelte-h="svelte-349jsc">Flax is a functional framework, so models are stateless and parameters are stored outside of them. Loading a pretrained Flax pipeline returns <em>both</em> the pipeline and the model weights (or parameters). 
In this guide, you’ll use <code>bfloat16</code>, a more efficient half-float type that is supported by TPUs (you can also use <code>float32</code> for full precision if you want).</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->dtype = jnp.bfloat16
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
<span class="hljs-string">&quot;CompVis/stable-diffusion-v1-4&quot;</span>,
variant=<span class="hljs-string">&quot;bf16&quot;</span>,
dtype=dtype,
)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="inference" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#inference"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Inference</span></h2> <p data-svelte-h="svelte-19grb2g">TPUs usually have 8 devices working in parallel, so let’s use the same prompt for each device. This means you can perform inference on 8 devices at once, with each device generating one image. As a result, you’ll get 8 images in the same amount of time it takes for one chip to generate a single image!</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-5pbnhk">Learn more details in the <a href="#how-does-parallelization-work">How does parallelization work?</a> section.</p></div> <p data-svelte-h="svelte-ih3iqi">After replicating the prompt, get the tokenized text ids by calling the <code>prepare_inputs</code> function on the pipeline. 
The length of the tokenized text is set to 77 tokens as required by the configuration of the underlying CLIP text model.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->prompt = <span class="hljs-string">&quot;A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic&quot;</span>
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
<span class="hljs-comment"># (8, 77)</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-b3wqu4">Model parameters and inputs have to be replicated across the 8 parallel devices. The parameters dictionary is replicated with <a href="https://flax.readthedocs.io/en/latest/api_reference/flax.jax_utils.html#flax.jax_utils.replicate" rel="nofollow"><code>flax.jax_utils.replicate</code></a> which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Arrays are replicated using <code>shard</code>.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># parameters</span>
p_params = replicate(params)
<span class="hljs-comment"># arrays</span>
prompt_ids = shard(prompt_ids)
prompt_ids.shape
<span class="hljs-comment"># (8, 1, 77)</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-18q6dmb">This shape means each one of the 8 devices receives as an input a <code>jnp</code> array with shape <code>(1, 77)</code>, where <code>1</code> is the batch size per device. On TPUs with sufficient memory, you could have a batch size larger than <code>1</code> if you want to generate multiple images (per chip) at once.</p> <p data-svelte-h="svelte-j1l5pg">Next, create a random number generator to pass to the generation function. This is standard procedure in Flax, which is very serious and opinionated about random numbers. All functions that deal with random numbers are expected to receive a generator to ensure reproducibility, even when you’re training across multiple distributed devices.</p> <p data-svelte-h="svelte-1kfts0e">The helper function below uses a seed to initialize a random number generator. As long as you use the same seed, you’ll get the exact same results. Feel free to use different seeds when exploring results later in the guide.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform 
-translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">def</span> <span class="hljs-title function_">create_key</span>(<span class="hljs-params">seed=<span class="hljs-number">0</span></span>):
<span class="hljs-keyword">return</span> jax.random.PRNGKey(seed)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-bsl0kg">The key returned by the helper function, <code>rng</code>, is split 8 times so each device receives a different generator and generates a different image.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->rng = create_key(<span class="hljs-number">0</span>)
rng = jax.random.split(rng, jax.device_count())<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1sp1ijt">To take advantage of JAX’s optimized speed on a TPU, pass <code>jit=True</code> to the pipeline to compile the JAX code into an efficient representation and to ensure the model runs in parallel across the 8 devices.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-jgm5rf">You need to ensure all your inputs have the same shape in subsequent calls, otherwise JAX will need to recompile the code which is slower.</p></div> <p data-svelte-h="svelte-1xv75ai">The first inference run takes more time because it needs to compile the code, but subsequent calls (even with different inputs) are much faster. For example, it took more than a minute to compile on a TPU v2-8, but then it takes about <strong>7s</strong> on a future inference run!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full 
transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->%%time
images = pipeline(prompt_ids, p_params, rng, jit=<span class="hljs-literal">True</span>)[<span class="hljs-number">0</span>]
<span class="hljs-comment"># CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s</span>
<span class="hljs-comment"># Wall time: 1min 29s</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-l3b4rx">The returned array has shape <code>(8, 1, 512, 512, 3)</code> which should be reshaped to remove the second dimension and get 8 images of <code>512 × 512 × 3</code>. Then you can use the <a href="/docs/diffusers/pr_11686/en/api/utilities#diffusers.utils.numpy_to_pil">numpy_to_pil()</a> function to convert the arrays into images.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> diffusers.utils <span class="hljs-keyword">import</span> make_image_grid
images = images.reshape((images.shape[<span class="hljs-number">0</span>] * images.shape[<span class="hljs-number">1</span>],) + images.shape[-<span class="hljs-number">3</span>:])
images = pipeline.numpy_to_pil(images)
make_image_grid(images, rows=<span class="hljs-number">2</span>, cols=<span class="hljs-number">4</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-e1s2k"><img src="https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg" alt="img"></p> <h2 class="relative group"><a id="using-different-prompts" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#using-different-prompts"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Using different prompts</span></h2> <p data-svelte-h="svelte-9e9hdw">You don’t necessarily have to use the same prompt on all devices. 
For example, to generate images from 8 different prompts:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->prompts = [
<span class="hljs-string">&quot;Labrador in the style of Hokusai&quot;</span>,
<span class="hljs-string">&quot;Painting of a squirrel skating in New York&quot;</span>,
<span class="hljs-string">&quot;HAL-9000 in the style of Van Gogh&quot;</span>,
<span class="hljs-string">&quot;Times Square under water, with fish and a dolphin swimming around&quot;</span>,
<span class="hljs-string">&quot;Ancient Roman fresco showing a man working on his laptop&quot;</span>,
<span class="hljs-string">&quot;Close-up photograph of young black woman against urban background, high quality, bokeh&quot;</span>,
<span class="hljs-string">&quot;Armchair in the shape of an avocado&quot;</span>,
<span class="hljs-string">&quot;Clown astronaut in space, with Earth in the background&quot;</span>,
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=<span class="hljs-literal">True</span>).images
images = images.reshape((images.shape[<span class="hljs-number">0</span>] * images.shape[<span class="hljs-number">1</span>],) + images.shape[-<span class="hljs-number">3</span>:])
images = pipeline.numpy_to_pil(images)
make_image_grid(images, <span class="hljs-number">2</span>, <span class="hljs-number">4</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1nea4f2"><img src="https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg" alt="2×4 grid of the images generated from the eight prompts"></p> <h2 class="relative group"><a id="how-does-parallelization-work" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#how-does-parallelization-work"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>How does parallelization work?</span></h2> <p data-svelte-h="svelte-1217ir9">The Flax pipeline in 🤗 Diffusers automatically compiles the model and runs it in parallel on all available devices. Let’s take a closer look at how that process works.</p> <p data-svelte-h="svelte-1toapy0">JAX parallelization can be done in multiple ways. The easiest one revolves around using the <a href="https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html" rel="nofollow"><code>jax.pmap</code></a> function to achieve single-program multiple-data (SPMD) parallelization. It means running several copies of the same code, each on different data inputs. 
More sophisticated approaches are possible, and you can go over to the JAX <a href="https://jax.readthedocs.io/en/latest/index.html" rel="nofollow">documentation</a> to explore this topic in more detail if you are interested!</p> <p data-svelte-h="svelte-1pwzekg"><code>jax.pmap</code> does two things:</p> <ol data-svelte-h="svelte-12i8xv7"><li>Compiles (or “<code>jit</code>s”) the code, similar to <code>jax.jit()</code>. Compilation does not happen when you call <code>pmap</code> itself, but the first time the <code>pmap</code>ped function is called.</li> <li>Ensures the compiled code runs in parallel on all available devices.</li></ol> <p data-svelte-h="svelte-1gtj740">To demonstrate, call <code>pmap</code> on the pipeline’s <code>_generate</code> method (this is a private method that generates images and may be renamed or removed in future releases of 🤗 Diffusers):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; 
border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->p_generate = pmap(pipeline._generate)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1774hdu">After calling <code>pmap</code>, the prepared function <code>p_generate</code> will:</p> <ol data-svelte-h="svelte-12ig0qf"><li>Make a copy of the underlying function, <code>pipeline._generate</code>, on each device.</li> <li>Send each device a different portion of the input arguments (this is why it’s necessary to call the <em>shard</em> function). In this case, <code>prompt_ids</code> has shape <code>(8, 1, 77)</code> so the array is split into 8 and each copy of <code>_generate</code> receives an input with shape <code>(1, 77)</code>.</li></ol> <p data-svelte-h="svelte-w1m4t9">The most important thing to pay attention to here is the batch size (1 in this example), and the input dimensions that make sense for your code. You don’t have to change anything else to make the code work in parallel.</p> <p data-svelte-h="svelte-1ue3hjw">The first call to the pipeline takes longer, but subsequent calls are much faster. The <code>block_until_ready</code> function is used to correctly measure inference time because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. 
You don’t need to use that in your code; blocking occurs automatically when you want to use the result of a computation that has not yet been materialized.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
<span class="hljs-comment"># CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s</span>
<span class="hljs-comment"># Wall time: 1min 15s</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-735q12">Check your image dimensions to see if they’re correct:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->images.shape
<span class="hljs-comment"># (8, 1, 512, 512, 3)</span><!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="resources" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#resources"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Resources</span></h2> <p data-svelte-h="svelte-1rrs38l">To learn more about how JAX works with Stable Diffusion, you may be interested in reading:</p> <ul data-svelte-h="svelte-16herqb"><li><a href="https://hf.co/blog/sdxl_jax" rel="nofollow">Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e</a></li></ul> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/diffusers/blob/main/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
// Client bootstrap for this pre-rendered page: publish the deployment paths,
// load the two client entry modules, then start the app on the element that
// contains this script.
{
// Assigned without a declaration, so in this classic (non-module) script the
// object becomes a global. Presumably the imported entry modules read it to
// resolve asset and base URLs; confirm against the SvelteKit runtime.
__sveltekit_o82a48 = {
assets: "/docs/diffusers/pr_11686/en",
base: "/docs/diffusers/pr_11686/en",
env: {}
};
// Must be read synchronously here: document.currentScript is only set while
// this script is executing, so it would not be available inside the .then().
const element = document.currentScript.parentElement;
// Two null entries, matching the two ids in node_ids below. Presumably this
// means neither node has server-provided data; TODO confirm.
const data = [null,null];
// Fetch both entry modules in parallel and start only once both have loaded.
Promise.all([
import("/docs/diffusers/pr_11686/en/_app/immutable/entry/start.2b9667fb.js"),
import("/docs/diffusers/pr_11686/en/_app/immutable/entry/app.a2a6117e.js")
]).then(([kit, app]) => {
// Hand the app module, the mount element and the initial page state to the
// start module's start() function.
kit.start(app, element, {
// NOTE(review): looks like the indices of the route nodes rendered for this
// page (root layout 0, page 309); verify against the build manifest.
node_ids: [0, 309],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
43.1 kB
·
Xet hash:
483b19739529561d8bf8b67ca825be97435cb43cce9c2b4cf1a933cf2f5b9124

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.