| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import math |
|
|
| import flax.linen as nn |
| import jax.numpy as jnp |
|
|
|
|
def get_sinusoidal_embeddings(
    timesteps: jnp.ndarray,
    embedding_dim: int,
    freq_shift: float = 1,
    min_timescale: float = 1,
    max_timescale: float = 1.0e4,
    flip_sin_to_cos: bool = False,
    scale: float = 1.0,
) -> jnp.ndarray:
    """Returns the positional encoding (same as Tensor2Tensor).

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
            These may be fractional.
        embedding_dim: The number of output channels. Must be even.
        freq_shift: Subtracted from the number of timescales when computing
            the geometric frequency spacing (Tensor2Tensor uses 1).
        min_timescale: The smallest time unit (should probably be 0.0).
        max_timescale: The largest time unit.
        flip_sin_to_cos: If True, the output is ordered [cos | sin] along the
            channel axis instead of the default [sin | cos].
        scale: Multiplier applied to the phases before taking sin/cos.

    Returns:
        a Tensor of timing signals [N, embedding_dim]
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
    num_timescales = float(embedding_dim // 2)
    # Geometric progression of inverse timescales: channel k gets frequency
    # min_timescale * (max/min) ** (-k / (num_timescales - freq_shift)).
    log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
    inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
    # Outer product: one phase per (timestep, frequency) pair -> [N, dim/2].
    emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)

    scaled_time = scale * emb

    if flip_sin_to_cos:
        signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
    else:
        signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
    signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
    return signal
|
|
|
|
class FlaxTimestepEmbedding(nn.Module):
    r"""
    Time step Embedding Module. Learns embeddings for input time steps.

    Args:
        time_embed_dim (`int`, *optional*, defaults to `32`):
            Time step embedding dimension
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    time_embed_dim: int = 32
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, temb):
        # Two-layer MLP with a SiLU nonlinearity in between; both projections
        # map to `time_embed_dim` channels.
        hidden = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
        activated = nn.silu(hidden)
        projected = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(activated)
        return projected
|
|
|
|
class FlaxTimesteps(nn.Module):
    r"""
    Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239

    Args:
        dim (`int`, *optional*, defaults to `32`):
            Time step embedding dimension
    """

    dim: int = 32
    flip_sin_to_cos: bool = False
    freq_shift: float = 1

    @nn.compact
    def __call__(self, timesteps):
        # Stateless: simply delegates to the sinusoidal embedding helper,
        # forwarding this module's configuration.
        embeddings = get_sinusoidal_embeddings(
            timesteps,
            embedding_dim=self.dim,
            freq_shift=self.freq_shift,
            flip_sin_to_cos=self.flip_sin_to_cos,
        )
        return embeddings
|
|