| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
from typing import Optional

import flax.linen as nn
import jax
import jax.numpy as jnp
|
|
|
|
class FlaxUpsample2D(nn.Module):
    """Double the spatial size with nearest-neighbour resizing, then apply a 3x3 convolution.

    Attributes:
        out_channels: number of output features of the convolution.
        dtype: computation dtype handed to the convolution.
    """

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Stride 1 with one pixel of padding per side keeps the (already doubled)
        # spatial size unchanged through the 3x3 kernel.
        self.conv = nn.Conv(
            features=self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        """Resize axes 1 and 2 by a factor of two, then convolve.

        The input must be 4-D. Axes 1 and 2 are treated as height and width and
        axis 3 as channels (presumably NHWC, flax's channels-last convention —
        confirm against callers).
        """
        n, h, w, c = hidden_states.shape
        target_shape = (n, 2 * h, 2 * w, c)
        upsampled = jax.image.resize(hidden_states, shape=target_shape, method="nearest")
        return self.conv(upsampled)
|
|
|
|
class FlaxDownsample2D(nn.Module):
    """Reduce spatial resolution with a single strided 3x3 convolution.

    Attributes:
        out_channels: number of output features of the convolution.
        dtype: computation dtype handed to the convolution.
    """

    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Stride 2 with symmetric padding of 1 maps a spatial size S to
        # floor((S - 1) / 2) + 1, i.e. ceil(S / 2).
        self.conv = nn.Conv(
            features=self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        """Return the strided convolution of ``hidden_states``."""
        return self.conv(hidden_states)
|
|
|
|
class FlaxResnetBlock2D(nn.Module):
    """Residual block conditioned on an embedding ``temb``.

    Main path: GroupNorm -> swish -> 3x3 conv -> add projected ``temb`` ->
    GroupNorm -> swish -> dropout -> 3x3 conv. The block input is added back at
    the end, first passed through a 1x1 convolution when a shortcut projection
    is enabled.

    Attributes:
        in_channels: channel count the block is configured for. It is used here
            only to resolve the ``out_channels`` and ``use_nin_shortcut`` defaults.
        out_channels: channels produced by the block; ``None`` means ``in_channels``.
        dropout_prob: rate of the dropout layer applied before the second convolution.
        use_nin_shortcut: whether to project the residual with a 1x1 convolution.
            ``None`` enables the projection exactly when the channel count changes.
        dtype: computation dtype passed to the convolutions and the dense projection.
    """

    in_channels: int
    # Explicit Optional: PEP 484 disallows the implicit-Optional form ``int = None``.
    out_channels: Optional[int] = None
    dropout_prob: float = 0.0
    use_nin_shortcut: Optional[bool] = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        # flax GroupNorm requires the channel count to be divisible by num_groups (32).
        # NOTE(review): the norms are not given dtype=self.dtype, unlike the convs/dense;
        # they use flax's default dtype inference — confirm this is intended.
        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # Maps the (swish-activated) embedding to out_channels so it can be added to conv1's output.
        self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # By default the residual is projected only when its channel count differs from the output.
        use_nin_shortcut = (
            self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
        )

        # The attribute always exists; it is None when no projection is applied.
        self.conv_shortcut = (
            nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )
            if use_nin_shortcut
            else None
        )

    def __call__(self, hidden_states, temb, deterministic=True):
        """Apply the residual block.

        Args:
            hidden_states: input feature map, presumably (batch, height, width,
                in_channels) channels-last — TODO confirm against callers.
            temb: conditioning embedding, assumed (batch, dim). After projection,
                two ``expand_dims`` at axis 1 turn it into (batch, 1, 1, out_channels),
                which broadcasts over the spatial axes — TODO confirm.
            deterministic: forwarded to the dropout layer; ``True`` (the default)
                leaves activations unchanged.

        Returns:
            The main-path output plus the residual. The residual is 1x1-projected
            when ``conv_shortcut`` is set; otherwise the channel counts must
            already match for the addition to succeed.
        """
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        temb = self.time_emb_proj(nn.swish(temb))
        # Insert two singleton axes after the batch axis so temb broadcasts over height/width.
        temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
        hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return hidden_states + residual
|
|