| | import torch.nn as nn |
| |
|
| |
|
def unwrap_model(model):
    """Return the bare module underneath a (Distributed)DataParallel wrapper.

    If ``model`` is not wrapped, it is returned unchanged.
    """
    parallel_wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    return model.module if isinstance(model, parallel_wrappers) else model
| |
|
| |
|
def get_label(lang_x, tokenizer, mode='colon'):
    """
    Build a language-modeling label tensor from input ids, masking the
    context portion of each row with -100 (the ignore index used by
    ``torch.nn.CrossEntropyLoss``).

    Parameters
    ----------
    lang_x : torch.Tensor
        Integer token-id tensor; assumed 2-D (batch, seq_len), since rows
        are masked via ``label[idx, :end_of_context]`` and ``label[:, 0]``.
    tokenizer :
        Tokenizer with '<image>' and '<|endofchunk|>' registered as
        additional special tokens, and a ``pad_token_id``.
    mode : str or int, optional
        'colon' -> mask everything up to and including the LAST ':' token
        in each row. An int ``k`` -> mask everything up to ``k`` tokens
        after the LAST '<image>' token in each row.

    Returns
    -------
    torch.Tensor
        A clone of ``lang_x`` where context tokens, padding, the first
        token of each row, and the media / end-of-chunk special tokens
        are all replaced by -100.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'colon' nor an int.
    """
    eoc_token = '<|endofchunk|>'
    media_token = '<image>'
    # NOTE(review): assumes tokenizer.encode(':') yields the ':' id first
    # (i.e. no BOS token is prepended) -- confirm for the tokenizer in use.
    colon_token_id = tokenizer.encode(':')[0]
    eoc_token_id = tokenizer.additional_special_tokens_ids[
        tokenizer.additional_special_tokens.index(eoc_token)
    ]
    media_token_id = tokenizer.additional_special_tokens_ids[
        tokenizer.additional_special_tokens.index(media_token)
    ]
    label = lang_x.clone()

    for idx in range(len(label)):
        if mode == 'colon':
            # Position just past the LAST colon in the row; raises
            # IndexError if the row contains no colon at all.
            indices = (label[idx] == colon_token_id).nonzero().flatten()
            end_of_context = indices[-1].item() + 1
        elif isinstance(mode, int):
            # Negative index of the LAST media token, shifted right by `mode`.
            end_of_context = (
                -label[idx].tolist()[::-1].index(media_token_id) - 1 + mode
            )
        else:
            # BUG FIX: an unrecognized mode previously left end_of_context
            # undefined (NameError on the first row) or silently reused the
            # previous row's value; fail loudly instead.
            raise ValueError(f"unsupported mode: {mode!r}")
        label[idx, :end_of_context] = -100  # mask the context portion

    # Never compute loss on padding, the first token, or special tokens.
    label[label == tokenizer.pad_token_id] = -100
    label[:, 0] = -100
    label[label == media_token_id] = -100
    label[label == eoc_token_id] = -100
    return label