| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ Classes to support Flax Speech-Encoder-Decoder architectures""" |
|
|
| import os |
| from functools import partial |
| from typing import Optional, Tuple, Union, Dict |
|
|
| import flax |
| import flax.linen as nn |
| import jax |
| import jax.numpy as jnp |
| from flax.core.frozen_dict import FrozenDict, unfreeze |
| from jax import lax |
| from jax.random import PRNGKey |
| import numpy as np |
|
|
| from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput |
| from transformers.modeling_flax_utils import FlaxPreTrainedModel |
| from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput |
| from transformers.generation_flax_utils import FlaxLogitsProcessorList |
| from models import ( |
| FlaxWav2Vec2Model, |
| FlaxWav2Vec2Module, |
| FlaxBartForCausalLM, |
| FlaxBartForCausalLMModule, |
| BartConfig, |
| Wav2Vec2Config, |
| SpeechEncoderDecoderConfig, |
| ) |
|
|
| logger = logging.get_logger(__name__) |
|
|
| _CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig" |
|
|
| SPEECH_ENCODER_DECODER_START_DOCSTRING = r""" |
| This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech |
| autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is |
| loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via |
| [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder |
| and should be fine-tuned on a downstream generative task, like summarization. |
| |
| The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation |
| tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation |
| Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi |
| Zhou, Wei Li, Peter J. Liu. |
| |
| Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech |
| Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech |
| translation yields a significant performance improvement. |
| |
    After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
| models (see the examples for more information). |
| |
| This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the |
| library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads |
| etc.) |
| |
| This model is also a Flax Linen |
| [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a |
| regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. |
| |
| Parameters: |
| config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model. |
| Initializing with a config file does not load the weights associated with the model, only the |
| configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. |
| dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): |
| The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and |
| `jax.numpy.bfloat16` (on TPUs). |
| |
| This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If |
| specified all the computation will be performed with the given `dtype`. |
| |
| **Note that this only specifies the dtype of the computation and does not influence the dtype of model |
| parameters.** |
| |
| If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and |
| [`~FlaxPreTrainedModel.to_bf16`]. |
| """ |
|
|
| SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r""" |
| Args: |
| inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): |
| Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* |
| or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile |
| library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or |
| [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type |
| *torch.FloatTensor*. |
| attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
| Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| |
| - 1 for tokens that are **not masked**, |
| - 0 for tokens that are **masked**. |
| |
| [What are attention masks?](../glossary#attention-mask) |
| decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
| Indices of decoder input sequence tokens in the vocabulary. |
| |
| Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| [`PreTrainedTokenizer.__call__`] for details. |
| |
| [What are input IDs?](../glossary#input-ids) |
| |
| If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see |
| `past_key_values`). |
| |
| For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be |
| created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` |
| and prepending them with the `decoder_start_token_id`. |
| decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
| Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also |
| be used by default. |
| decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
| Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the |
| range `[0, config.decoder.max_position_embeddings - 1]`. |
| output_hidden_states (`bool`, *optional*): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| more detail. |
| return_dict (`bool`, *optional*): |
| If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple. |
| """ |
|
|
| SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r""" |
| Args: |
| inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): |
| Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* |
| or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile |
| library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or |
| [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type |
| *torch.FloatTensor*. |
| attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
| Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| |
| - 1 for tokens that are **not masked**, |
| - 0 for tokens that are **masked**. |
| |
| [What are attention masks?](../glossary#attention-mask) |
| output_attentions (`bool`, *optional*): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| tensors for more detail. |
| output_hidden_states (`bool`, *optional*): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| more detail. |
| return_dict (`bool`, *optional*): |
| If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple. |
| """ |
|
|
| SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r""" |
| Args: |
| decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
| Indices of decoder input sequence tokens in the vocabulary. |
| |
| Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| [`PreTrainedTokenizer.__call__`] for details. |
| |
| [What are decoder input IDs?](../glossary#decoder-input-ids) |
| |
| If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see |
| `past_key_values`). |
| |
| For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be |
| created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` |
| and prepending them with the `decoder_start_token_id`. |
| encoder_outputs (`tuple(tuple(jnp.ndarray)`): |
| Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) |
| `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of |
| hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. |
| encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
| Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| |
| - 1 for tokens that are **not masked**, |
| - 0 for tokens that are **masked**. |
| |
| [What are attention masks?](../glossary#attention-mask) |
| decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
| Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also |
| be used by default. |
| decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
| Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the |
| range `[0, config.decoder.max_position_embeddings - 1]`. |
| past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): |
| Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast |
| auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. |
| output_attentions (`bool`, *optional*): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| tensors for more detail. |
| output_hidden_states (`bool`, *optional*): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| more detail. |
| return_dict (`bool`, *optional*): |
| If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a |
| plain tuple. |
| """ |
|
|
@flax.struct.dataclass
class FlaxBeamSearchOutput(ModelOutput):
    """
    Flax base class for outputs of generation models using beam search.

    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
        scores (`jnp.ndarray` of shape `(batch_size,)`):
            The scores (log probabilities) of the generated sequences.
    """

    # Defaults of `None` let the dataclass be built incrementally by the
    # generation loop before both fields are known.
    sequences: jnp.ndarray = None
    scores: jnp.ndarray = None
|
|
|
|
@flax.struct.dataclass
class BeamSearchState:
    """
    Loop state threaded through a beam-search decoding loop.

    Registered as a `flax.struct.dataclass` so it is a pytree and can be
    carried through `jax.lax` control-flow primitives. Field semantics are
    inferred from the names — NOTE(review): confirm against the beam-search
    loop that consumes this state (not visible in this file).
    """

    cur_len: jnp.ndarray  # current decoded length
    running_sequences: jnp.ndarray  # beams still being extended
    running_scores: jnp.ndarray  # log-prob scores of the running beams
    sequences: jnp.ndarray  # finalized beams
    scores: jnp.ndarray  # scores of the finalized beams
    is_sent_finished: jnp.ndarray  # per-beam "finished" flags
    model_kwargs: Dict[str, jnp.ndarray]  # model inputs carried across steps (e.g. cache)
|
|
|
|
|
|
|
|
class FlaxSpeechEncoderDecoderModule(nn.Module):
    """
    Flax module combining a Wav2Vec2 speech encoder with a Bart causal-LM decoder.

    The encoder maps raw speech (or precomputed features) to hidden states; an
    optional dense projection maps them to the decoder's hidden size before the
    decoder consumes them via cross-attention.
    """

    config: SpeechEncoderDecoderConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        encoder_config = self.config.encoder
        decoder_config = self.config.decoder

        # The encoder/decoder architectures are hard-wired to Wav2Vec2 and Bart
        # in this implementation (unlike the generic upstream class).
        encoder_module = FlaxWav2Vec2Module
        decoder_module = FlaxBartForCausalLMModule

        self.encoder = encoder_module(encoder_config, dtype=self.dtype)
        self.decoder = decoder_module(decoder_config, dtype=self.dtype)

        # Project encoder hidden states into the decoder's hidden size when the
        # two differ and the decoder does not declare its own
        # `cross_attention_hidden_size`.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            self.enc_to_dec_proj = nn.Dense(
                self.decoder.config.hidden_size,
                kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
                dtype=self.dtype,
            )
        else:
            self.enc_to_dec_proj = None

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers, i.e. the number
        of encoder frames produced for raw-audio input(s) of `input_lengths`
        samples. `add_adapter` defaults to the encoder config's setting.
        """

        add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D conv output-length formula with no padding:
            # floor((len - kernel) / stride) + 1
            return (input_length - kernel_size) // stride + 1

        # Apply the formula once per conv layer of the feature extractor.
        for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        # The optional adapter stacks further stride-`adapter_stride`, kernel-1
        # convolutions on top of the extractor output.
        if add_adapter:
            for _ in range(self.config.encoder.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)

        return input_lengths

    def _get_encoder_module(self):
        # Accessor used by `Module.apply(..., method=...)` forwards in the model class.
        return self.encoder

    def _get_projection_module(self):
        # May be `None` when no encoder-to-decoder projection is needed (see `setup`).
        return self.enc_to_dec_proj

    def _get_decoder_module(self):
        # Accessor used by `Module.apply(..., method=...)` forwards in the model class.
        return self.decoder

    def __call__(
        self,
        inputs,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        encoder_outputs=None,
        extract_features=None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_features: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
        freeze_feature_encoder: bool = False,
    ):
        """
        Run the full encoder -> (projection) -> decoder pipeline.

        Pass `encoder_outputs` to skip re-running the encoder. When
        `output_features` is True, the encoder's extracted features are
        returned directly and the decoder is not run.
        """
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs,
                attention_mask=attention_mask,
                extract_features=extract_features,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                output_features=output_features,
                return_dict=return_dict,
                deterministic=deterministic,
                freeze_feature_encoder=freeze_feature_encoder,
            )

        # Feature-extraction-only mode: skip projection and decoding entirely.
        if output_features:
            return encoder_outputs

        encoder_hidden_states = encoder_outputs[0]

        # Optionally project encoder hidden states to the decoder's width.
        if self.enc_to_dec_proj is not None:
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # Convert the sample-level attention mask to one frame-level mask that
        # matches the encoder output length (Wav2Vec2-style helper).
        if attention_mask is not None:
            encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
                encoder_hidden_states.shape[1], attention_mask
            )
        else:
            encoder_attention_mask = None

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return FlaxSeq2SeqLMOutput(
            logits=decoder_outputs.logits,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_hidden_states,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
|
|
|
|
| @add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING) |
| class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): |
| r""" |
| [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture |
| with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one |
| as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the |
| encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder. |
| """ |
|
|
| config_class = SpeechEncoderDecoderConfig |
| base_model_prefix: str = "speech_encoder_decoder" |
| module_class = FlaxSpeechEncoderDecoderModule |
|
|
| def __init__( |
| self, |
| config: SpeechEncoderDecoderConfig, |
| input_shape: Optional[Tuple] = None, |
| seed: int = 0, |
| dtype: jnp.dtype = jnp.float32, |
| _do_init: bool = True, |
| **kwargs |
| ): |
|
|
| if not _do_init: |
| raise ValueError( |
| "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." |
| ) |
|
|
| if config.decoder.cross_attention_hidden_size is not None: |
| |
| if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: |
| raise ValueError( |
| "If `cross_attention_hidden_size` is specified in the decoder's configuration, " |
| "it has to be equal to the encoder's `hidden_size`. " |
| f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` " |
| f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`." |
| ) |
|
|
| |
| config.tie_word_embeddings = False |
| module = self.module_class(config=config, dtype=dtype, **kwargs) |
|
|
| if input_shape is None: |
| |
| encoder_input_length = 1024 |
| decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length) |
| input_shape = ((1, encoder_input_length), (1, decoder_input_length)) |
|
|
| super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) |
|
|
| def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: |
| encoder_input_shape, decoder_input_shape = input_shape |
|
|
| |
| inputs = jnp.zeros(encoder_input_shape, dtype="f4") |
| attention_mask = jnp.ones_like(inputs, dtype="i4") |
| decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4") |
| decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
|
|
| batch_size, sequence_length = inputs.shape |
|
|
| decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape |
| if not decoder_batch_size == batch_size: |
| raise ValueError( |
| f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder." |
| ) |
| decoder_position_ids = jnp.broadcast_to( |
| jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length) |
| ) |
|
|
| params_rng, dropout_rng = jax.random.split(rng) |
| rngs = {"params": params_rng, "dropout": dropout_rng} |
|
|
| return self.module.init( |
| rngs, |
| inputs, |
| attention_mask, |
| decoder_input_ids, |
| decoder_attention_mask, |
| decoder_position_ids, |
| )["params"] |
|
|
| def init_cache(self, batch_size, max_length, encoder_outputs): |
| r""" |
| Args: |
| batch_size (`int`): |
| batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. |
| max_length (`int`): |
| maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized |
| cache. |
| encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): |
| `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: |
| `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) |
| is a sequence of hidden-states at the output of the last layer of the encoder. Used in the |
| cross-attention of the decoder. |
| """ |
| |
| decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") |
| decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
| decoder_position_ids = jnp.broadcast_to( |
| jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape |
| ) |
|
|
| def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): |
| decoder_module = module._get_decoder_module() |
| return decoder_module( |
| input_ids=decoder_input_ids, |
| attention_mask=decoder_attention_mask, |
| position_ids=decoder_position_ids, |
| **kwargs, |
| ) |
|
|
| init_variables = self.module.init( |
| jax.random.PRNGKey(0), |
| decoder_input_ids=decoder_input_ids, |
| decoder_attention_mask=decoder_attention_mask, |
| decoder_position_ids=decoder_position_ids, |
| encoder_hidden_states=encoder_outputs[0], |
| init_cache=True, |
| method=_decoder_forward, |
| ) |
| return unfreeze(init_variables["cache"]) |
|
|
    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Compute the encoder output length(s) for raw-audio input length(s)
        `input_lengths`. Thin delegation to the underlying module's
        `_get_feat_extract_output_lengths`.
        """
        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
|
|
    @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def encode(
        self,
        inputs: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        extract_features: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        freeze_feature_encoder: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Run only the speech encoder and return its outputs.

        Returns:

        Example:

        ```python
        >>> from transformers import FlaxSpeechEncoderDecoderModel
        >>> import jax.numpy as jnp

        >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
        ... )

        >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
        >>> encoder_outputs = model.encode(inputs)
        ```"""
        # Fall back to config defaults for unspecified output options.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Default: attend to every input frame.
        if attention_mask is None:
            attention_mask = jnp.ones_like(inputs, dtype="i4")

        if extract_features is not None:
            extract_features = jnp.array(extract_features, dtype="f4")

        # Handle any PRNG if needed (dropout is only active when `train=True`).
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, inputs, attention_mask, **kwargs):
            # Run only the encoder sub-module of the combined module.
            encode_module = module._get_encoder_module()
            return encode_module(inputs, attention_mask, **kwargs)

        outputs = self.module.apply(
            {"params": params or self.params},
            inputs=jnp.array(inputs, dtype="f4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            extract_features=extract_features,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_features=output_features,
            return_dict=return_dict,
            deterministic=not train,
            freeze_feature_encoder=freeze_feature_encoder,
            rngs=rngs,
            method=_encoder_forward,
        )

        # Re-wrap so callers always get a `FlaxBaseModelOutput`; skipped when raw
        # extracted features were requested (different output structure).
        if return_dict and not output_features:
            outputs = FlaxBaseModelOutput(
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        return outputs
|
|
    @add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Run only the text decoder (with optional fast auto-regressive cache).

        Returns:

        Example:

        ```python
        >>> from transformers import FlaxSpeechEncoderDecoderModel
        >>> import jax.numpy as jnp

        >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
        ... )

        >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
        >>> encoder_outputs = model.encode(inputs)

        >>> decoder_start_token_id = model.config.decoder.bos_token_id
        >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        # Fall back to config defaults for unspecified output options.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        # Default masks: attend everywhere.
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            # With a cache, position ids cannot be inferred from the (length-1)
            # input slice, so the caller must provide them.
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed (dropout only when `train=True`).
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        params = {"params": params or self.params}

        # When a cache is passed, it is handed to `apply` as an extra variable
        # collection and marked mutable so the attention layers can update it
        # in place; the updated cache comes back as the second return value.
        if past_key_values:
            params["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
        ):
            projection_module = module._get_projection_module()
            decoder_module = module._get_decoder_module()

            # Optionally project encoder hidden states to the decoder's width.
            if projection_module is not None:
                encoder_hidden_states = projection_module(encoder_hidden_states)

            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                encoder_hidden_states,
                **kwargs,
            )

        outputs = self.module.apply(
            params,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # Append the updated cache to the model output.
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs
|
|
| @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING) |
| @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) |
| def __call__( |
| self, |
| inputs: jnp.ndarray, |
| attention_mask: Optional[jnp.ndarray] = None, |
| extract_features: Optional[jnp.ndarray] = None, |
| decoder_input_ids: Optional[jnp.ndarray] = None, |
| decoder_attention_mask: Optional[jnp.ndarray] = None, |
| decoder_position_ids: Optional[jnp.ndarray] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| output_features: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| train: bool = False, |
| freeze_feature_encoder: bool = False, |
| params: dict = None, |
| dropout_rng: PRNGKey = None, |
| ): |
| r""" |
| Returns: |
| |
| Examples: |
| |
| ```python |
| >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer |
| |
| >>> # load a fine-tuned wav2vec2-2-bart model |
| >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large") |
| >>> # load output tokenizer |
| >>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large") |
| |
| >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) |
| |
| >>> # use bart's special bos, pad and eos tokens |
| >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id |
| >>> model.config.pad_token_id = model.decoder.config.pad_token_id |
| >>> model.config.eos_token_id = model.decoder.config.eos_token_id |
| |
| >>> outputs = model.generate(inputs) |
| # Assert something? More interesting input? dtype correct? |
| ``` |
| """ |
|
|
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
| |
| if attention_mask is None: |
| attention_mask = jnp.ones_like(inputs, dtype="i4") |
|
|
| if extract_features is not None: |
| inputs = None |
| extract_features = jnp.array(extract_features, dtype="f4") |
| else: |
| inputs = jnp.array(inputs, dtype="f4") |
|
|
| |
| if decoder_input_ids is None: |
| raise ValueError( |
| "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument." |
| ) |
| if decoder_attention_mask is None: |
| decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
| if decoder_position_ids is None: |
| batch_size, sequence_length = decoder_input_ids.shape |
| decoder_position_ids = jnp.broadcast_to( |
| jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
| ) |
|
|
| |
| rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} |
|
|
| return self.module.apply( |
| {"params": params or self.params}, |
| inputs=inputs, |
| attention_mask=jnp.array(attention_mask, dtype="i4"), |
| extract_features=extract_features, |
| decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), |
| decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), |
| decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| output_features=output_features, |
| return_dict=return_dict, |
| deterministic=not train, |
| freeze_feature_encoder=freeze_feature_encoder, |
| rngs=rngs, |
| ) |
|
|
| def prepare_inputs_for_generation( |
| self, |
| decoder_input_ids, |
| max_length, |
| attention_mask: Optional[jnp.DeviceArray] = None, |
| decoder_attention_mask: Optional[jnp.DeviceArray] = None, |
| encoder_outputs=None, |
| **kwargs |
| ): |
| |
| batch_size, seq_length = decoder_input_ids.shape |
|
|
| past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) |
| |
| |
| |
| extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") |
| if decoder_attention_mask is not None: |
| decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 |
| extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) |
| else: |
| decoder_position_ids = jnp.broadcast_to( |
| jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) |
| ) |
|
|
| return { |
| "past_key_values": past_key_values, |
| "encoder_outputs": encoder_outputs, |
| "encoder_attention_mask": attention_mask, |
| "decoder_attention_mask": extended_attention_mask, |
| "decoder_position_ids": decoder_position_ids, |
| } |
|
|
| def update_inputs_for_generation(self, model_outputs, model_kwargs): |
| model_kwargs["past_key_values"] = model_outputs.past_key_values |
| model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 |
| return model_kwargs |
|
|
| @classmethod |
| def from_encoder_decoder_pretrained( |
| cls, |
| encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, |
| decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, |
| *model_args, |
| **kwargs |
| ) -> FlaxPreTrainedModel: |
| r""" |
| Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model |
| checkpoints. |
| |
| Params: |
| encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*): |
| Information necessary to initiate the encoder. Can be either: |
| |
| - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. |
| Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a |
| user or organization name, like `dbmdz/bert-base-german-cased`. |
| - A path to a *directory* containing model weights saved using |
| [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. |
| |
| decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`): |
| Information necessary to initiate the decoder. Can be either: |
| |
| - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. |
| Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a |
| user or organization name, like `dbmdz/bert-base-german-cased`. |
| - A path to a *directory* containing model weights saved using |
| [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. |
| |
| model_args (remaining positional arguments, *optional*): |
| All remaning positional arguments will be passed to the underlying model's `__init__` method. |
| |
| kwargs (remaining dictionary of keyword arguments, *optional*): |
| Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., |
| `output_attentions=True`). |
| |
| - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. |
| - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. |
| - To update the parent model configuration, do not use a prefix for each configuration parameter. |
| |
| Behaves differently depending on whether a `config` is provided or automatically loaded. |
| |
| Example: |
| |
| ```python |
| >>> from transformers import FlaxSpeechEncoderDecoderModel |
| |
| >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized |
| >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained( |
| ... "facebook/wav2vec2-large-lv60", "facebook/bart-large" |
| ... ) |
| >>> # saving model after fine-tuning |
| >>> model.save_pretrained("./wav2vec2-2-bart-large") |
| >>> # load fine-tuned model |
| >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large") |
| ```""" |
|
|
| kwargs_encoder = { |
| argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") |
| } |
|
|
| kwargs_decoder = { |
| argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") |
| } |
|
|
| |
| for key in kwargs_encoder.keys(): |
| del kwargs["encoder_" + key] |
| for key in kwargs_decoder.keys(): |
| del kwargs["decoder_" + key] |
|
|
| |
| |
| |
| encoder = kwargs_encoder.pop("model", None) |
| if encoder is None: |
| if encoder_pretrained_model_name_or_path is None: |
| raise ValueError( |
| "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " |
| "to be defined." |
| ) |
|
|
| if "config" not in kwargs_encoder: |
| |
| encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained( |
| encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True |
| ) |
| if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: |
| logger.info( |
| f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " |
| "from a decoder model. Cross-attention and casual mask are disabled." |
| ) |
| encoder_config.is_decoder = False |
| encoder_config.add_cross_attention = False |
|
|
| kwargs_encoder["config"] = encoder_config |
|
|
| |
| encoder = FlaxWav2Vec2Model.from_pretrained( |
| encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder |
| ) |
|
|
| decoder = kwargs_decoder.pop("model", None) |
| if decoder is None: |
| if decoder_pretrained_model_name_or_path is None: |
| raise ValueError( |
| "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " |
| "to be defined." |
| ) |
|
|
| if "config" not in kwargs_decoder: |
| |
| decoder_config, kwargs_decoder = BartConfig.from_pretrained( |
| decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True |
| ) |
| if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: |
| logger.info( |
| f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " |
| f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} " |
| f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for " |
| "cross attention layers." |
| ) |
| decoder_config.is_decoder = True |
| decoder_config.add_cross_attention = True |
|
|
| kwargs_decoder["config"] = decoder_config |
|
|
| if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: |
| logger.warning( |
| f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " |
| f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " |
| "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " |
| "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " |
| "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" |
| ) |
|
|
| |
| decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) |
|
|
| |
| dtype = kwargs.pop("dtype", jnp.float32) |
| config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) |
|
|
| |
| config.tie_word_embeddings = False |
|
|
| |
| model = cls(config, dtype=dtype) |
| model.params["encoder"] = encoder.params |
| model.params["decoder"] = decoder.params |
|
|
| return model |
|
|
| def _beam_search( |
| self, |
| input_ids: None, |
| max_length: Optional[int] = None, |
| pad_token_id: Optional[int] = None, |
| eos_token_id: Optional[int] = None, |
| length_penalty: Optional[float] = None, |
| early_stopping: Optional[bool] = None, |
| logits_processor: Optional[FlaxLogitsProcessorList] = None, |
| trace: bool = True, |
| params: Optional[Dict[str, jnp.ndarray]] = None, |
| model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, |
| ): |
| """ |
| This beam search function is heavily inspired by Flax's official example: |
| https://github.com/google/flax/blob/master/examples/wmt/train.py#L254 |
| """ |
|
|
| def flatten_beam_dim(tensor): |
| """Flattens the first two dimensions of a non-scalar array.""" |
| |
| if tensor.ndim == 0 or tensor.ndim == 1: |
| return tensor |
| elif tensor.ndim == 6: |
| return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:]) |
| return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:]) |
|
|
| def unflatten_beam_dim(tensor, batch_size, num_beams): |
| """Unflattens the first, flat batch*beam dimension of a non-scalar array.""" |
| |
| if tensor.ndim == 0 or tensor.ndim == 1: |
| return tensor |
| if tensor.ndim == 5: |
| return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:]) |
| return tensor.reshape((batch_size, num_beams) + tensor.shape[1:]) |
|
|
| def gather_beams(nested, beam_indices, batch_size, new_num_beams): |
| """ |
| Gathers the beam slices indexed by beam_indices into new beam array. |
| """ |
| batch_indices = jnp.reshape( |
| jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams) |
| ) |
|
|
| def gather_fn(tensor): |
| |
| if tensor.ndim == 0 or tensor.ndim == 1: |
| return tensor |
| if tensor.ndim == 6: |
| return tensor[:, batch_indices, beam_indices] |
| return tensor[batch_indices, beam_indices] |
|
|
| return jax.tree_map(gather_fn, nested) |
|
|
| |
| max_length = max_length if max_length is not None else self.config.max_length |
| pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
| eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
| length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty |
| early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping |
|
|
| batch_size, num_beams, cur_len = input_ids.shape |
|
|
| eos_token_id = jnp.array(eos_token_id) |
| pad_token_id = jnp.array(pad_token_id) |
| cur_len = jnp.array(cur_len) |
|
|
| |
| sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32) |
| running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32) |
| running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0)) |
|
|
| |
| is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_) |
|
|
| |
| running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1]) |
| scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7) |
|
|
| |
| |
| model = self.decode if self.config.is_encoder_decoder else self |
|
|
| |
| if "encoder_outputs" in model_kwargs: |
| model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim( |
| model_kwargs["encoder_outputs"]["last_hidden_state"] |
| ) |
| if "attention_mask" in model_kwargs: |
| model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"]) |
|
|
| |
| model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs) |
|
|
| |
| state = BeamSearchState( |
| cur_len=cur_len, |
| running_sequences=running_sequences, |
| running_scores=running_scores, |
| sequences=sequences, |
| scores=scores, |
| is_sent_finished=is_sent_finished, |
| model_kwargs=model_kwargs, |
| ) |
|
|
| def beam_search_cond_fn(state): |
| """beam search state termination condition fn.""" |
|
|
| |
| not_max_length_yet = state.cur_len < max_length |
|
|
| |
| best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty) |
| worst_finished_score = jnp.where( |
| state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7) |
| ) |
| improvement_still_possible = jnp.all(worst_finished_score < best_running_score) |
|
|
| |
| still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping) |
|
|
| return not_max_length_yet & still_open_beam & improvement_still_possible |
|
|
| def beam_search_body_fn(state, input_ids_length=1): |
| """beam search state update fn.""" |
| |
| |
| |
| |
| |
| |
| input_token = flatten_beam_dim( |
| lax.dynamic_slice( |
| state.running_sequences, |
| (0, 0, state.cur_len - input_ids_length), |
| (batch_size, num_beams, input_ids_length), |
| ) |
| ) |
| model_outputs = model(input_token, params=params, **state.model_kwargs) |
|
|
| logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams) |
| cache = jax.tree_map( |
| lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values |
| ) |
|
|
| |
| logits = self._adapt_logits_for_beam_search(logits) |
|
|
| |
| |
| |
| |
| log_probs = jax.nn.log_softmax(logits) |
| log_probs = logits_processor( |
| flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len |
| ) |
| log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams) |
| log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2) |
| vocab_size = log_probs.shape[2] |
| log_probs = log_probs.reshape((batch_size, num_beams * vocab_size)) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| beams_to_keep = 2 * num_beams |
| topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep) |
| topk_beam_indices = topk_indices // vocab_size |
| topk_running_sequences = gather_beams( |
| state.running_sequences, topk_beam_indices, batch_size, beams_to_keep |
| ) |
| topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) |
| topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len)) |
|
|
| |
| |
| |
| |
| |
| |
| did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id |
| running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7) |
| |
| |
| |
| next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1) |
| next_running_sequences, next_running_scores = gather_beams( |
| [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams |
| ) |
|
|
| |
| |
| |
| |
| |
| topk_log_probs = topk_log_probs / (state.cur_len**length_penalty) |
| beams_in_batch_are_full = ( |
| jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape) |
| & early_stopping |
| ) |
| add_penalty = ~did_topk_just_finished | beams_in_batch_are_full |
| topk_log_probs += add_penalty * np.array(-1.0e7) |
|
|
| |
| |
| |
| |
| merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1) |
| merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1) |
| merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1) |
| topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1) |
| next_sequences, next_scores, next_is_sent_finished = gather_beams( |
| [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams |
| ) |
|
|
| |
| |
| |
| next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams) |
| next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams) |
| model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache) |
| next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs) |
|
|
| return BeamSearchState( |
| cur_len=state.cur_len + 1, |
| running_scores=next_running_scores, |
| running_sequences=next_running_sequences, |
| scores=next_scores, |
| sequences=next_sequences, |
| is_sent_finished=next_is_sent_finished, |
| model_kwargs=next_model_kwargs, |
| ) |
|
|
| |
| if input_ids.shape[-1] > 1: |
| state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state) |
|
|
| if not trace: |
| state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state) |
| else: |
| state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state) |
|
|
| |
| |
| none_finished = jnp.any(state.is_sent_finished, axis=1) |
| sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences) |
| scores = jnp.where(none_finished[:, None], state.scores, state.running_scores) |
|
|
| |
| sequences = sequences[:, :] |
| scores = scores[:, -1] |
|
|
| return FlaxBeamSearchOutput(sequences=sequences, scores=scores) |
|
|