| import io |
| import os |
| import random |
| import re |
| from typing import Dict |
| import copy |
|
|
| import cv2 |
| import imageio |
| import numpy as np |
| import torch |
| import torchvision.transforms as T |
| import transformers |
| from PIL import Image |
| from torch.utils.data import ConcatDataset, WeightedRandomSampler |
| from torchvision.transforms.functional import InterpolationMode |
| from xtuner.utils import IGNORE_INDEX |
| IGNORE_TOKEN_ID = IGNORE_INDEX |
| from mmengine.config import ConfigDict |
|
|
| from ..utils import (get_conv_template, IMG_CONTEXT_TOKEN, IMG_START_TOKEN, |
| IMG_END_TOKEN, DEFAULT_VISION_PROMPT_TOKEN, VPT_START_TOKEN, |
| VPT_END_TOKEN, VPT_CONTEXT_TOKEN) |
|
|
| try: |
| from petrel_client.client import Client |
| from petrel_client.common.config import Config |
| except ImportError as E: |
| print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.') |
| import sys |
|
|
|
|
def preprocess(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1,
        object_tokens_str: str = "",
) -> Dict:
    """Render conversations with a prompt template, tokenize them and build labels.

    Every conversation in ``sources`` is rendered through the template returned
    by ``get_conv_template(template_name)``. Unless ``text_only`` is set, the
    first ``num_image`` ``<image>`` placeholders are expanded into image token
    spans and the first ``<OBJECT_TOKENS>`` placeholder is replaced by
    ``object_tokens_str``. Labels start as a copy of the input ids; the first
    token, the instruction part of every turn and everything past the accounted
    length are then overwritten with ``IGNORE_TOKEN_ID``.

    Args:
        template_name: Identifier forwarded to ``get_conv_template``.
        sources: Sequence of conversations. Each conversation is a sequence of
            dicts with a ``'from'`` key (``'human'`` or ``'gpt'``) and a
            ``'value'`` key holding the message text.
        tokenizer: Tokenizer used for encoding; its ``model_max_length``,
            ``pad_token_id`` and ``legacy`` attributes are read.
        num_image_token_list: Number of ``IMG_CONTEXT_TOKEN`` repetitions per
            image, indexed by image position.
        text_only: When true, no placeholder expansion is performed.
        group_by_length: When true, padding to ``max_length`` is disabled.
        use_packed_ds: When true, padding to ``max_length`` is disabled.
        ds_name: Dataset name, only used in the mismatch warning.
        num_image: Number of ``<image>`` placeholders expanded per conversation.
        object_tokens_str: Replacement text for ``<OBJECT_TOKENS>``.

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask`` (true where
        ``input_ids`` differs from ``tokenizer.pad_token_id``).

    Raises:
        AssertionError: If the roles of a conversation do not alternate in the
            order given by ``conv.roles``.
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply the prompt template to every conversation.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Drop a leading message that does not come from the first role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            conversation = conversation.replace('<OBJECT_TOKENS>', object_tokens_str, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize all rendered prompts in one batch.
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets so that only the text after `sep` in each turn keeps labels.
    sep = conv.sep + conv.roles[1] + ': '
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # NOTE(review): the -2 offset is tokenizer-specific; presumably it
            # compensates for tokens the tokenizer adds around the fragment.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # Non-legacy tokenizers need one token less on follow-up turns.
                instruction_len -= 1

            # Ignore the instruction part of the turn.
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                cur_len -= 1

        target[cur_len:] = IGNORE_TOKEN_ID

        # If the per-turn accounting disagrees with the real token count, the
        # mask cannot be trusted: drop the whole sample from the loss.
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_TOKEN_ID
            print(
                f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
            )
            sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
|
|
|
|
|
|
|
|
def preprocess_mpt(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1,
        object_tokens_str: str = ""
) -> Dict:
    """Tokenize templated conversations and mask everything but the answers.

    The rendered prompt is split on ``conv.sep``; the first three pieces form
    the first round and every following pair of pieces forms one more round.
    Within each round, the tokens up to and including ``conv.sep + conv.roles[1]``
    are masked with ``IGNORE_TOKEN_ID``.

    Args:
        template_name: Identifier forwarded to ``get_conv_template``.
        sources: Sequence of conversations, each a sequence of dicts with
            ``'from'`` (``'human'`` or ``'gpt'``) and ``'value'`` keys.
        tokenizer: Tokenizer used for encoding; ``model_max_length`` and
            ``pad_token_id`` are read from it.
        num_image_token_list: ``IMG_CONTEXT_TOKEN`` repetitions per image.
        text_only: When true, placeholders are left untouched.
        group_by_length: When true, padding to ``max_length`` is disabled.
        use_packed_ds: When true, padding to ``max_length`` is disabled.
        ds_name: Dataset name, only used in the mismatch warning.
        num_image: Number of ``<image>`` placeholders expanded per conversation.
        object_tokens_str: Replacement text for ``<OBJECT_TOKENS>``.

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask``.
    """
    conv = get_conv_template(template_name)
    human_role, assistant_role = conv.roles[0], conv.roles[1]
    role_of = {'human': human_role, 'gpt': assistant_role}

    # Render each dialogue through the conversation template.
    prompts = []
    for src_idx, dialogue in enumerate(sources):
        if role_of[dialogue[0]['from']] != human_role:
            # A dialogue must open with the human role; skip a stray opener.
            dialogue = dialogue[1:]

        conv.messages = []
        for msg_idx, message in enumerate(dialogue):
            speaker = role_of[message['from']]
            assert speaker == conv.roles[msg_idx % 2], f'{src_idx}'
            conv.append_message(speaker, message['value'])
        prompts.append(conv.get_prompt())

    if not text_only:
        expanded = []
        for prompt in prompts:
            for img_idx in range(num_image):
                span = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[img_idx]}{IMG_END_TOKEN}'
                prompt = prompt.replace('<image>', span, 1)
            prompt = prompt.replace('<OBJECT_TOKENS>', object_tokens_str, 1)
            expanded.append(prompt)
        prompts = expanded

    # Batch-encode; padding is only applied when neither bucketing mode is on.
    skip_padding = group_by_length or use_packed_ds
    encoded = tokenizer(
        prompts,
        return_tensors='pt',
        padding=False if skip_padding else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    )
    input_ids = encoded.input_ids
    labels = input_ids.clone()

    answer_sep = conv.sep + assistant_role
    for prompt, label_row in zip(prompts, labels):
        real_len = int(label_row.ne(tokenizer.pad_token_id).sum())

        pieces = prompt.split(conv.sep)
        # First round takes three pieces, every later round takes two.
        rounds = [conv.sep.join(pieces[:3])]
        rounds.extend(
            conv.sep.join(pieces[start:start + 2])
            for start in range(3, len(pieces), 2)
        )

        offset = 0
        for round_text in rounds:
            if round_text == '':
                break
            # NOTE(review): the +1 presumably accounts for the separator token
            # removed by the split above; confirm for the tokenizer in use.
            round_len = len(tokenizer(round_text).input_ids) + 1

            halves = round_text.split(answer_sep)
            if len(halves) != 2:
                break
            question = halves[0] + answer_sep
            question_len = len(tokenizer(question).input_ids)

            # The question part never contributes to the loss.
            label_row[offset: offset + question_len] = IGNORE_TOKEN_ID
            offset += round_len

        label_row[offset:] = IGNORE_TOKEN_ID

        if offset < tokenizer.model_max_length and offset != real_len:
            # Accounting disagrees with the real length: mask the whole sample.
            label_row[:] = IGNORE_TOKEN_ID
            print(
                f'WARNING: tokenization mismatch: {offset} vs. {real_len}.'
                f' #turn = {len(pieces) - 1}. (ignored). This dataset is {ds_name}.'
            )
            sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=labels,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
|
|
|
|
|
|
def preprocess_phi3_debug(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1,
        object_tokens_str: str = ""
) -> Dict:
    """Encode the first conversation of ``sources`` turn by turn in Phi-3 chat format.

    Human messages are accumulated until a ``'gpt'`` message closes the turn.
    Each turn is encoded as an instruction (labels set to ``IGNORE_INDEX``)
    followed by the answer plus ``<|end|>`` (labels equal to the token ids).
    The system prompt is prepended to the first turn only.

    Args:
        template_name: Unused; kept for signature parity with the siblings.
        sources: Sequence whose first element is the conversation to encode,
            a sequence of dicts with ``'from'`` and ``'value'`` keys.
        tokenizer: Tokenizer providing ``encode``, ``model_max_length`` and
            ``pad_token_id``.
        num_image_token_list: ``IMG_CONTEXT_TOKEN`` repetitions per image.
        text_only: When true, ``<image>`` placeholders are left untouched.
        group_by_length: Unused.
        use_packed_ds: Unused.
        ds_name: Unused.
        num_image: Number of ``<image>`` placeholders expanded per human message.
        object_tokens_str: Replacement text for ``<OBJECT_TOKENS>``.

    Returns:
        Dict with ``input_ids``, ``labels`` (both of shape ``(1, length)``) and
        ``attention_mask``.

    Raises:
        NotImplementedError: If a message has a ``'from'`` value other than
            ``'human'`` or ``'gpt'``.
    """
    conversations = sources[0]
    # Drop leading assistant messages so the dialogue starts with the user.
    while conversations and conversations[0]['from'] == 'gpt':
        conversations = conversations[1:]

    pending_input = ''
    out_conversation = []
    for msg in conversations:
        if msg['from'] == 'human':
            msg_value = msg['value']
            if not text_only:
                for i in range(num_image):
                    image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                    msg_value = msg_value.replace('<image>', image_tokens, 1)
            msg_value = msg_value.replace('<OBJECT_TOKENS>', object_tokens_str, 1).strip()
            pending_input += msg_value
        elif msg['from'] == 'gpt':
            out_conversation.append({
                'input': pending_input,
                'output': msg['value'].strip(),
            })
            pending_input = ''
        else:
            raise NotImplementedError

    system_text = '<|system|>\n{system}<|end|>\n'.format(
        system='You are an AI assistant whose name is Phi-3.')
    instruction_template = '<|user|>\n{input}<|end|><|assistant|>\n'
    suffix = '<|end|>'

    input_ids, labels = [], []
    for i, single_turn_conversation in enumerate(out_conversation):
        input_text = instruction_template.format(
            input=single_turn_conversation['input'])

        if i == 0:
            # Only the first turn carries the system prompt and special tokens.
            input_encode = tokenizer.encode(
                system_text + input_text, add_special_tokens=True)
        else:
            input_encode = tokenizer.encode(input_text, add_special_tokens=False)
        input_ids += input_encode
        labels += [IGNORE_INDEX] * len(input_encode)

        output_text = single_turn_conversation['output'] + suffix
        output_encode = tokenizer.encode(output_text, add_special_tokens=False)
        input_ids += output_encode
        # `+=` copies the elements, so labels never alias `output_encode`.
        labels += output_encode

    max_length = tokenizer.model_max_length
    full_length = len(input_ids)
    if full_length > max_length:
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
        # Report the length measured before truncation.
        print(
            f"Warning: input_ids length({full_length})"
            f"is longer than max_length, cut to {max_length}"
        )
    input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
    labels = torch.tensor(labels, dtype=torch.long).unsqueeze(0)
    return dict(
        input_ids=input_ids,
        labels=labels,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
|
|
|
|
|
|
def preprocess_phi3(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1,
        object_tokens_str: str = ""
) -> Dict:
    """Tokenize templated conversations for Phi-3 and mask non-answer tokens.

    Side effect: sets ``tokenizer.padding_side`` to ``'right'``.

    The rendered prompt is split on ``conv.sep``; the first three pieces form
    the first round and every following pair forms one more round. Labels are
    a copy of the input ids in which the first token, every ``<|endoftext|>``
    token, the instruction part of each round and everything past the
    accounted length are set to ``IGNORE_TOKEN_ID``. A sample whose accounted
    length disagrees with its real length is masked entirely and reported with
    a warning; processing then continues.

    Args:
        template_name: Identifier forwarded to ``get_conv_template``.
        sources: Sequence of conversations, each a sequence of dicts with
            ``'from'`` (``'human'`` or ``'gpt'``) and ``'value'`` keys.
        tokenizer: Tokenizer used for encoding; ``model_max_length`` and
            ``pad_token_id`` are read and ``padding_side`` is written.
        num_image_token_list: ``IMG_CONTEXT_TOKEN`` repetitions per image.
        text_only: When true, ``<image>`` placeholders are left untouched.
        group_by_length: When true, padding to ``max_length`` is disabled.
        use_packed_ds: When true, padding to ``max_length`` is disabled.
        ds_name: Dataset name, only used in the mismatch warning.
        num_image: Number of ``<image>`` placeholders expanded per conversation.
        object_tokens_str: Accepted for signature parity; this function does
            not substitute ``<OBJECT_TOKENS>``.

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask``.

    Raises:
        AssertionError: If the roles of a conversation do not alternate in the
            order given by ``conv.roles``.
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply the prompt template to every conversation.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Drop a leading message that does not come from the first role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize with right-side padding.
    tokenizer.padding_side = 'right'
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    sep = conv.sep + conv.roles[1]
    # The id lookup does not depend on the sample, so do it once.
    endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())

        turns = conversation.split(conv.sep)
        # First round takes three pieces, every later round takes two.
        re_turns = [conv.sep.join(turns[:3])]
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        target[target == endoftext_id] = IGNORE_TOKEN_ID

        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            # NOTE(review): the per-round offsets below are tokenizer-specific;
            # presumably follow-up rounds drop a token the tokenizer prepends.
            if i == 0:
                turn_len = len(tokenizer(turn).input_ids)
            else:
                turn_len = len(tokenizer(turn).input_ids) - 1
            parts = turn.split(sep)
            if len(parts) != 2:
                print("len(parts) != 2")
                break
            parts[0] += sep

            if i == 0:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1
            else:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Ignore the instruction part of the round.
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        # On an accounting mismatch, mask the sample and keep going so one bad
        # sample cannot terminate the process.
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_TOKEN_ID
            print(
                f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
            )
            sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
|
|
|
|
|
|
def preprocess_internlm(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1,
        object_tokens_str: str = "",
) -> Dict:
    """Tokenize templated conversations and mask non-answer tokens (InternLM layout).

    Message texts are stripped before being added to the template; ``sources``
    itself is not modified. The rendered prompt is split on ``conv.roles[1]``
    and the pieces are walked in order: everything up to and including each
    assistant role marker is masked with ``IGNORE_TOKEN_ID``, the text between
    an assistant marker and the next user marker keeps its labels.

    Args:
        template_name: Identifier forwarded to ``get_conv_template``.
        sources: Sequence of conversations, each a sequence of dicts with
            ``'from'`` (``'human'`` or ``'gpt'``) and ``'value'`` keys.
        tokenizer: Tokenizer used for encoding; ``model_max_length`` and
            ``pad_token_id`` are read from it.
        num_image_token_list: ``IMG_CONTEXT_TOKEN`` repetitions per image.
        text_only: When true, placeholders are left untouched.
        group_by_length: When true, padding to ``max_length`` is disabled.
        use_packed_ds: When true, padding to ``max_length`` is disabled.
        ds_name: Dataset name, only used in the mismatch warning.
        num_image: Number of ``<image>`` placeholders expanded per conversation.
        object_tokens_str: Replacement text for ``<OBJECT_TOKENS>``.

    Returns:
        Dict with ``input_ids``, ``labels`` and ``attention_mask``.

    Raises:
        AssertionError: If the roles of a conversation do not alternate in the
            order given by ``conv.roles``.
        ValueError: If a middle piece does not contain ``conv.roles[0]``
            exactly once (the two-way unpacking of the split fails).
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply the prompt template to every conversation.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Drop a leading message that does not come from the first role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            # Strip locally; writing back would mutate the caller's data.
            conv.append_message(role, sentence['value'].strip())
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            conversation = conversation.replace('<OBJECT_TOKENS>', object_tokens_str, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize all rendered prompts in one batch.
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        parts = conversation.split(conv.roles[1])

        # Prefix up to and including the first assistant marker: masked.
        # NOTE(review): each "- 1" presumably removes a token the tokenizer
        # prepends to every fragment; confirm for the tokenizer in use.
        info = parts[0] + conv.roles[1]
        temp_len = len(tokenizer(info).input_ids) - 1
        target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
        cur_len = cur_len + temp_len

        for index in range(1, len(parts) - 1):
            info = parts[index]
            part1, part2 = info.split(conv.roles[0])
            # part1 is the answer text: advance without masking.
            temp_len = len(tokenizer(part1).input_ids) - 1
            cur_len = cur_len + temp_len
            # The next user message plus the assistant marker: masked.
            part = conv.roles[0] + part2 + conv.roles[1]
            temp_len = len(tokenizer(part).input_ids) - 1
            target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
            cur_len = cur_len + temp_len

        # Final answer: advance without masking.
        last_info = parts[-1]
        temp_len = len(tokenizer(last_info).input_ids) - 1
        cur_len = cur_len + temp_len

        target[cur_len:] = IGNORE_TOKEN_ID

        # Accounting disagrees with the real length: mask the whole sample.
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_TOKEN_ID
            print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
            sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
|
|
|
|
def preprocess_qwen2vl(conversations, object_tokens_str, num_images=0):
    """Convert a human/gpt conversation into role/content chat messages.

    The result starts with a fixed system message. Each ``'human'`` message
    becomes a ``"user"`` entry whose content interleaves ``{"type": "image"}``
    items (one per ``'<image>\\n'`` marker) with the surrounding non-blank text;
    each ``'gpt'`` message becomes an ``"assistant"`` text entry. Messages with
    any other ``'from'`` value are skipped.

    Args:
        conversations: Sequence of dicts with ``'from'`` and ``'value'`` keys.
            May be empty. A first message not from ``'human'`` is dropped.
        object_tokens_str: Replacement for the first ``<OBJECT_TOKENS>``
            placeholder of every human message.
        num_images: Expected total number of ``'<image>\\n'`` markers.

    Returns:
        The list of chat messages, or ``None`` when the number of image
        markers found differs from ``num_images``.
    """
    out_conversation_list = [{
        "role": "system",
        "content": [{
            "type": "text",
            "text": "You are a helpful assistant.",
        }],
    }]

    # Check for emptiness before peeking at the first message.
    if conversations and conversations[0]['from'] != 'human':
        conversations = conversations[1:]

    total_images = 0
    for msg in conversations:
        sender = msg['from']
        if sender == 'human':
            msg_value = msg['value']
            # Count markers before substitution so that text coming from
            # object_tokens_str is never counted as an image.
            cur_image_count = msg_value.count('<image>\n')
            total_images += cur_image_count
            msg_value = msg_value.replace('<OBJECT_TOKENS>', object_tokens_str, 1)

            if cur_image_count == 0:
                out_contents = [{"type": "text", "text": msg_value}]
            else:
                out_contents = []
                for idx, text_str in enumerate(msg_value.split('<image>\n')):
                    # One image item precedes each fragment after the first,
                    # capped at the number of markers counted above.
                    if 0 < idx <= cur_image_count:
                        out_contents.append({"type": "image"})
                    if text_str.strip():
                        out_contents.append({"type": "text", "text": text_str})
            out_conversation_list.append({
                "role": "user",
                "content": out_contents,
            })
        elif sender == 'gpt':
            out_conversation_list.append({
                "role": "assistant",
                "content": [{"type": "text", "text": msg['value']}],
            })

    if total_images != num_images:
        return None
    return out_conversation_list
|
|
|
|
def preprocess_llava(conversations, object_tokens_str, num_images=0):
    """Convert a human/gpt conversation into role/content chat messages.

    No system message is emitted. Each ``'human'`` message becomes a ``"user"``
    entry whose content interleaves ``{"type": "image"}`` items (one per
    ``'<image>\\n'`` marker) with the surrounding non-blank text; each ``'gpt'``
    message becomes an ``"assistant"`` text entry. Messages with any other
    ``'from'`` value are skipped.

    Args:
        conversations: Sequence of dicts with ``'from'`` and ``'value'`` keys.
            May be empty. A first message not from ``'human'`` is dropped.
        object_tokens_str: Replacement for the first ``<OBJECT_TOKENS>``
            placeholder of every human message.
        num_images: Expected total number of ``'<image>\\n'`` markers.

    Returns:
        The list of chat messages, or ``None`` when the number of image
        markers found differs from ``num_images``.
    """
    out_conversation_list = []

    # Check for emptiness before peeking at the first message.
    if conversations and conversations[0]['from'] != 'human':
        conversations = conversations[1:]

    total_images = 0
    for msg in conversations:
        sender = msg['from']
        if sender == 'human':
            msg_value = msg['value']
            # Count markers before substitution so that text coming from
            # object_tokens_str is never counted as an image.
            cur_image_count = msg_value.count('<image>\n')
            total_images += cur_image_count
            msg_value = msg_value.replace('<OBJECT_TOKENS>', object_tokens_str, 1)

            if cur_image_count == 0:
                out_contents = [{"type": "text", "text": msg_value}]
            else:
                out_contents = []
                for idx, text_str in enumerate(msg_value.split('<image>\n')):
                    # One image item precedes each fragment after the first,
                    # capped at the number of markers counted above.
                    if 0 < idx <= cur_image_count:
                        out_contents.append({"type": "image"})
                    if text_str.strip():
                        out_contents.append({"type": "text", "text": text_str})
            out_conversation_list.append({
                "role": "user",
                "content": out_contents,
            })
        elif sender == 'gpt':
            out_conversation_list.append({
                "role": "assistant",
                "content": [{"type": "text", "text": msg['value']}],
            })

    if total_images != num_images:
        return None
    return out_conversation_list
|
|