import torch


def generate(model, input_ids, generation_config=None, left_padding=None, **kwargs):
    """Greedy-decode a continuation of ``input_ids``, optionally left-padding the prompt.

    Args:
        model: Callable such that ``model(input_ids).logits`` can be indexed as
            ``[:, -1, :]``. Its ``generation_config`` attribute is used when
            ``generation_config`` is not given, and ``config.pad_token_id`` is
            the last fallback for the pad token.
        input_ids: Prompt token ids; ``shape[0]`` is used as the batch size and
            ``shape[1]`` as the current sequence length.
        generation_config: Object providing ``max_length`` and ``pad_token_id``
            (and optionally ``max_new_tokens``). Defaults to
            ``model.generation_config``.
        left_padding: If not ``None``, the number of pad tokens to prepend to
            every row of the prompt. ``0`` is accepted and adds nothing.
        **kwargs: ``max_new_tokens`` (new-token budget, only consulted when
            ``generation_config.max_length`` is unset) and ``pad_token``
            (overrides the configured pad token id). Other keys are ignored.

    Returns:
        The (possibly padded) prompt with the generated tokens appended along
        the last dimension.

    Raises:
        ValueError: If ``left_padding`` is not a non-negative integer, if no pad
            token can be resolved while padding is requested, or if neither a
            maximum length nor a new-token budget is available.
    """
    print("✨ using a custom generation method ✨")

    generation_config = generation_config or model.generation_config
    cur_length = input_ids.shape[1]

    # An absolute `max_length` wins when set; otherwise fall back to a relative
    # new-token budget, which is resolved into a length only after padding.
    max_length = generation_config.max_length
    max_new_tokens = None
    if not max_length:
        max_new_tokens = kwargs.get("max_new_tokens")
        if max_new_tokens is None:
            max_new_tokens = getattr(generation_config, "max_new_tokens", None)
        if max_new_tokens is None:
            raise ValueError(
                "Either generation_config.max_length or max_new_tokens must be set"
            )

    if left_padding is not None:
        if not isinstance(left_padding, int) or left_padding < 0:
            raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")

        # Explicit `is None` checks: a pad token id of 0 is valid and must not
        # be skipped just because it is falsy.
        pad_token = kwargs.pop("pad_token", None)
        if pad_token is None:
            pad_token = generation_config.pad_token_id
        if pad_token is None:
            pad_token = model.config.pad_token_id
        if pad_token is None:
            raise ValueError("pad_token is not defined")

        batch_size = input_ids.shape[0]
        # Allocate directly with the prompt's dtype/device so the concatenation
        # needs no device copy or dtype promotion.
        pad_tensor = torch.full(
            (batch_size, left_padding),
            pad_token,
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        input_ids = torch.cat((pad_tensor, input_ids), dim=1)
        cur_length = input_ids.shape[1]

    if max_new_tokens is not None:
        # Computed after padding so pad tokens do not eat the new-token budget.
        max_length = cur_length + max_new_tokens

    # Greedy loop: re-run the model on the whole sequence and append the
    # highest-scoring token at the last position of every row.
    while cur_length < max_length:
        logits = model(input_ids).logits
        next_token_logits = logits[:, -1, :]
        next_tokens = torch.argmax(next_token_logits, dim=-1)
        input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
        cur_length += 1

    return input_ids