| from enum import Enum |
|
|
| import numpy as np |
| import torch |
| import torch.distributed as dist |
|
|
| from transformers import PreTrainedModel |
| from typing import List, Optional |
|
|
| IGNORE_INDEX = -100 |
| IMAGE_TOKEN_INDEX = -200 |
|
|
| DEFAULT_EOS_TOKEN = '</s>' |
| DEFAULT_BOS_TOKEN = '<s>' |
| DEFAULT_UNK_TOKEN = '<unk>' |
|
|
| DEFAULT_IMAGE_TOKEN = "<image>" |
| DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" |
| DEFAULT_IM_START_TOKEN = "<im_start>" |
| DEFAULT_IM_END_TOKEN = "<im_end>" |
| DEFAULT_BBOX_TOKEN = "<bbox>" |
|
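# Illustrative note (an assumption based on common LLaVA-style pipelines, not
# stated in this file): the tokenizer replaces each DEFAULT_IMAGE_TOKEN
# ("<image>") in the prompt with the IMAGE_TOKEN_INDEX sentinel inside
# `input_ids`, e.g.
#
#     input_ids = torch.tensor([bos_id, tok_a, IMAGE_TOKEN_INDEX, tok_b])
#
# where `bos_id`, `tok_a`, `tok_b` are hypothetical token ids.
# prepare_inputs_labels_for_multimodal() below splices image features in at
# those sentinel positions; IGNORE_INDEX marks positions excluded from the loss.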
|
| def prepare_inputs_labels_for_multimodal( |
| llm: PreTrainedModel, |
| input_ids: torch.LongTensor = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| past_key_values: Optional[List[torch.FloatTensor]] = None, |
| labels: Optional[torch.LongTensor] = None, |
| pixel_values: Optional[torch.FloatTensor] = None, |
        **kwargs):
    """Splice image features into the token sequence.

    Replaces each IMAGE_TOKEN_INDEX sentinel in `input_ids` with the
    corresponding rows of `pixel_values`, and returns a kwargs dict with
    `inputs_embeds` ready for the LLM forward pass.
    """
    # Text-only call: nothing to splice in, forward the inputs unchanged.
    if pixel_values is None:
| kwargs.update({ |
| 'input_ids': input_ids, |
| 'position_ids': position_ids, |
| 'attention_mask': attention_mask, |
| 'past_key_values': past_key_values, |
| 'inputs_embeds': None, |
| 'labels': labels |
| }) |
| return kwargs |
|
|
    # Remember which optional inputs were supplied so their `None` defaults
    # can be restored on return, then fill in working defaults.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
| attention_mask = torch.ones_like(input_ids, dtype=torch.bool) |
| else: |
| attention_mask = attention_mask.bool() |
| if position_ids is None: |
| position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) |
| if labels is None: |
| labels = torch.full_like(input_ids, IGNORE_INDEX) |
|
|
| |
    # Strip padding positions; from here on, each sample is a variable-length
    # 1-D sequence.
    input_ids = [
| cur_input_ids[cur_attention_mask] |
| for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) |
| ] |
| labels = [ |
| cur_labels[cur_attention_mask] |
| for cur_labels, cur_attention_mask in zip(labels, attention_mask) |
| ] |
|
|
| new_inputs_embeds = [] |
| new_labels = [] |
| new_input_ids = [] |
| cur_image_idx = 0 |
| for batch_idx, cur_input_ids in enumerate(input_ids): |
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        if num_images == 0:
            # Text-only sample: concatenate an empty slice of the image
            # features so the vision tower still participates in the autograd
            # graph.
            cur_pixel_values = pixel_values[cur_image_idx]
            cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
            cur_inputs_embeds = torch.cat([cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
| new_inputs_embeds.append(cur_inputs_embeds) |
| new_labels.append(labels[batch_idx]) |
| new_input_ids.append(cur_input_ids) |
| cur_image_idx += 1 |
| continue |
|
|
        # Split the sequence at every image token; the sentinels -1 and len()
        # bracket the text segments around the image positions.
        image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
| cur_input_ids_noim = [] |
| cur_labels = labels[batch_idx] |
| cur_labels_noim = [] |
| for i in range(len(image_token_indices) - 1): |
| cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]]) |
| cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]]) |
|
|
        # Embed all text tokens in one call, then split back into segments.
        split_sizes = [x.shape[0] for x in cur_labels_noim]
        cur_inputs_embeds = llm.get_input_embeddings()(torch.cat(cur_input_ids_noim))
| cur_inputs_embeds_no_im = torch.split(cur_inputs_embeds, split_sizes, dim=0) |
| cur_new_inputs_embeds = [] |
| cur_new_labels = [] |
| cur_new_input_ids = [] |
|
|
        # Re-interleave: text segment i, then the features of image i. Image
        # positions get IGNORE_INDEX labels and IMAGE_TOKEN_INDEX input ids.
        for i in range(num_images + 1):
| cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i]) |
| cur_new_labels.append(cur_labels_noim[i]) |
| cur_new_input_ids.append(cur_input_ids_noim[i]) |
| if i < num_images: |
| cur_pixel_values = pixel_values[cur_image_idx] |
| cur_image_idx += 1 |
| cur_new_inputs_embeds.append(cur_pixel_values) |
| cur_new_labels.append(torch.full((cur_pixel_values.shape[0], ), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) |
| cur_new_input_ids.append(torch.full((cur_pixel_values.shape[0], ), IMAGE_TOKEN_INDEX, device=cur_input_ids.device, dtype=cur_input_ids.dtype)) |
| |
        # Flatten the interleaved segments back into single tensors.
        cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
| cur_new_labels = torch.cat(cur_new_labels) |
| cur_new_input_ids = torch.cat(cur_new_input_ids) |
|
|
| new_inputs_embeds.append(cur_new_inputs_embeds) |
| new_labels.append(cur_new_labels) |
| new_input_ids.append(cur_new_input_ids) |
|
|
| |
    # Right-pad every sample to the longest sequence in the batch.
    max_len = max(x.shape[0] for x in new_inputs_embeds)
| batch_size = len(new_inputs_embeds) |
|
|
| new_inputs_embeds_padded = [] |
| new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) |
| new_input_ids_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device) |
| attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) |
| position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) |
|
|
| for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_inputs_embeds, new_labels, new_input_ids)): |
| cur_len = cur_new_embed.shape[0] |
| new_inputs_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)) |
| if cur_len > 0: |
| new_labels_padded[i, :cur_len] = cur_new_labels |
| new_input_ids_padded[i, :cur_len] = cur_new_input_ids |
| attention_mask[i, :cur_len] = True |
| position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) |
|
|
| new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0) |
|
|
    # Restore `None` for optional inputs the caller did not supply.
    if _labels is None:
| new_labels = None |
| else: |
| new_labels = new_labels_padded |
|
|
| new_input_ids = new_input_ids_padded |
|
|
| if _attention_mask is None: |
| attention_mask = None |
| else: |
| attention_mask = attention_mask.to(dtype=_attention_mask.dtype) |
|
|
| if _position_ids is None: |
| position_ids = None |
|
|
    # Return inputs_embeds instead of input_ids; 'new_input_ids' carries the
    # image-aligned token ids for callers that need them.
    kwargs.update({
        'input_ids': None,
| 'position_ids': position_ids, |
| 'attention_mask': attention_mask, |
| 'past_key_values': past_key_values, |
| 'inputs_embeds': new_inputs_embeds, |
| 'labels': new_labels, |
| 'new_input_ids': new_input_ids |
| }) |
| return kwargs |
|
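# Usage sketch (illustrative; the argument shapes are assumptions, not stated
# in this file): `pixel_values` holds one (num_patches, hidden_size) feature
# tensor per image, already projected into the LLM embedding space. Note that
# 'new_input_ids' is an extra key most HF model forwards will not accept, so
# pop it before calling the model.
#
#     mm_inputs = prepare_inputs_labels_for_multimodal(
#         llm=llm,                      # a transformers PreTrainedModel
#         input_ids=input_ids,          # contains IMAGE_TOKEN_INDEX sentinels
#         attention_mask=attention_mask,
#         labels=labels,
#         pixel_values=pixel_values,    # one feature tensor per image
#     )
#     aligned_ids = mm_inputs.pop('new_input_ids')
#     outputs = llm(**mm_inputs)        # forwarded via inputs_embeds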
|
class Summary(Enum):
    """Which statistic AverageMeter.summary() reports."""
    NONE = 0
| AVERAGE = 1 |
| SUM = 2 |
| COUNT = 3 |
|
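# Illustrative example: AverageMeter("acc", ":.2f", Summary.AVERAGE) reports
# its running average in the summary() line; Summary.NONE suppresses it.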
|
| class AverageMeter(object): |
| """Computes and stores the average and current value""" |
|
|
| def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): |
| self.name = name |
| self.fmt = fmt |
| self.summary_type = summary_type |
| self.reset() |
|
|
| def reset(self): |
| self.val = 0 |
| self.avg = 0 |
| self.sum = 0 |
| self.count = 0 |
|
|
| def update(self, val, n=1): |
| self.val = val |
| self.sum += val * n |
| self.count += n |
| self.avg = self.sum / self.count |
|
|
    def all_reduce(self):
        # Aggregate `sum` and `count` across all ranks in the default process
        # group; requires torch.distributed to be initialized.
        device = "cuda" if torch.cuda.is_available() else "cpu"
| if isinstance(self.sum, np.ndarray): |
| total = torch.tensor( |
| self.sum.tolist() |
| + [ |
| self.count, |
| ], |
| dtype=torch.float32, |
| device=device, |
| ) |
| else: |
| total = torch.tensor( |
| [self.sum, self.count], dtype=torch.float32, device=device |
| ) |
|
|
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        if total.shape[0] > 2:
            # Vector-valued sum (ndarray case): the last element is the count.
            self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
        else:
            self.sum, self.count = total.tolist()
        self.avg = self.sum / (self.count + 1e-5)
|
|
| def __str__(self): |
| fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" |
| return fmtstr.format(**self.__dict__) |
|
|
| def summary(self): |
| fmtstr = "" |
| if self.summary_type is Summary.NONE: |
| fmtstr = "" |
| elif self.summary_type is Summary.AVERAGE: |
| fmtstr = "{name} {avg:.3f}" |
| elif self.summary_type is Summary.SUM: |
| fmtstr = "{name} {sum:.3f}" |
| elif self.summary_type is Summary.COUNT: |
| fmtstr = "{name} {count:.3f}" |
| else: |
| raise ValueError("invalid summary type %r" % self.summary_type) |
|
|
| return fmtstr.format(**self.__dict__) |
|
|
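# Usage sketch (illustrative):
#
#     losses = AverageMeter("loss", ":.4f")
#     for step in range(1, 5):
#         losses.update(val=1.0 / step, n=8)   # per-batch value, batch size 8
#     print(losses)             # "loss <last value> (<running average>)"
#     # losses.all_reduce()     # in DDP runs, aggregate across ranks first
#     print(losses.summary())   # e.g. "loss 0.521" for Summary.AVERAGE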
|
|
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    """Compute per-class intersection, union, and target areas for K classes.

    Note: torch.histc requires floating-point input, and `output` is modified
    in place where `target == ignore_index`.
    """
| assert output.dim() in [1, 2, 3] |
| assert output.shape == target.shape |
| output = output.view(-1) |
| target = target.view(-1) |
| output[target == ignore_index] = ignore_index |
| intersection = output[output == target] |
| area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1) |
| area_output = torch.histc(output, bins=K, min=0, max=K - 1) |
| area_target = torch.histc(target, bins=K, min=0, max=K - 1) |
| area_union = area_output + area_target - area_intersection |
| return area_intersection, area_union, area_target |
|
|
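# Usage sketch (illustrative, assuming a CUDA device): per-class IoU for K=3
# classes. The .float() and .clone() calls reflect the histc and in-place
# caveats noted in the docstring above.
#
#     pred   = torch.tensor([0, 1, 2, 2, 1]).float().cuda()
#     target = torch.tensor([0, 1, 2, 1, 255]).float().cuda()  # 255 ignored
#     inter, union, tgt = intersectionAndUnionGPU(pred.clone(), target, K=3)
#     iou = inter / (union + 1e-10)    # guard against empty classes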
|
|
| class ProgressMeter(object): |
| def __init__(self, num_batches, meters, prefix=""): |
| self.batch_fmtstr = self._get_batch_fmtstr(num_batches) |
| self.meters = meters |
| self.prefix = prefix |
|
|
| def display(self, batch): |
| entries = [self.prefix + self.batch_fmtstr.format(batch)] |
| entries += [str(meter) for meter in self.meters] |
| print("\t".join(entries)) |
|
|
| def display_summary(self): |
| entries = [" *"] |
| entries += [meter.summary() for meter in self.meters] |
| print(" ".join(entries)) |
|
|
| def _get_batch_fmtstr(self, num_batches): |
        num_digits = len(str(num_batches))
| fmt = "{:" + str(num_digits) + "d}" |
| return "[" + fmt + "/" + fmt.format(num_batches) + "]" |
|
|
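# Usage sketch (illustrative), combining ProgressMeter with AverageMeter:
#
#     batch_time = AverageMeter("time", ":.3f")
#     losses = AverageMeter("loss", ":.4f")
#     progress = ProgressMeter(100, [batch_time, losses], prefix="Epoch[0] ")
#     progress.display(10)        # "Epoch[0] [ 10/100]\ttime ...\tloss ..."
#     progress.display_summary()  # " * time 0.000 loss 0.000" style line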
|
|
def dict_to_cuda(input_dict):
    # Move all tensors (including tensors nested in lists) to GPU in place.
    for k, v in input_dict.items():
        if isinstance(v, torch.Tensor):
            input_dict[k] = v.cuda(non_blocking=True)
        elif isinstance(v, list) and len(v) > 0:
            input_dict[k] = [
                ele.cuda(non_blocking=True) if isinstance(ele, torch.Tensor) else ele
                for ele in v
            ]
    return input_dict
|
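# Usage sketch (illustrative; assumes a CUDA device is available):
#
#     batch = {
#         "input_ids": torch.randint(0, 100, (2, 8)),
#         "masks": [torch.rand(4, 4), torch.rand(4, 4)],
#         "paths": ["a.jpg", "b.jpg"],  # non-tensors pass through untouched
#     }
#     batch = dict_to_cuda(batch)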
|