import copy
import json
import logging
import os
import random

import cv2
import numpy as np
import pycocotools.mask as maskUtils
import torch
import torchvision.transforms as T
from mmengine import print_log
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms.functional import InterpolationMode

from xtuner.registry import BUILDER

from .encode_fn import video_lisa_encode_fn
|
|
| SEG_QUESTIONS = [ |
| "Please segment the object according to the description: {class_name}", |
| ] |
|
|
| SEG_QUESTIONS_SHORT = [ |
| "Can you segment the {class_name} in this image?", |
| "Please segment {class_name} in this image.", |
| "What is {class_name} in this image? Please respond with segmentation mask.", |
| "What is {class_name} in this image? Please output segmentation mask.", |
|
|
| "Can you segment the {class_name} in this image", |
| "Please segment {class_name} in this image", |
| "What is {class_name} in this image? Please respond with segmentation mask", |
| "What is {class_name} in this image? Please output segmentation mask", |
|
|
| "Could you provide a segmentation mask for the {class_name} in this image?", |
| "Please identify and segment the {class_name} in this image.", |
| "Where is the {class_name} in this picture? Please respond with a segmentation mask.", |
| "Can you highlight the {class_name} in this image with a segmentation mask?", |
|
|
| "Could you provide a segmentation mask for the {class_name} in this image", |
| "Please identify and segment the {class_name} in this image", |
| "Where is the {class_name} in this picture? Please respond with a segmentation mask", |
| "Can you highlight the {class_name} in this image with a segmentation mask", |
| ] |
|
|
| ANSWER_LIST = [ |
| "It is [SEG].", |
| "Sure, [SEG].", |
| "Sure, it is [SEG].", |
| "Sure, the segmentation result is [SEG].", |
| "[SEG].", |
| ] |
|
|
| class VideoSAM2Dataset(Dataset): |
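    """Referring video segmentation dataset over SAM-2 style data.

    Loads videos and RLE masklet annotations from `sam2_folder`, pairs each
    sampled object with a referring expression from `expression_file`, and
    produces multi-frame conversations whose answers contain [SEG] tokens.
    """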
| IMAGENET_MEAN = (0.485, 0.456, 0.406) |
| IMAGENET_STD = (0.229, 0.224, 0.225) |
| IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>' |
| IMG_START_TOKEN = '<img>' |
| IMG_END_TOKEN = '</img>' |
|
|
| FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>' |
| FAST_IMG_START_TOKEN = '<fast_img>' |
| FAST_IMG_END_TOKEN = '</fast_img>' |
|
|
| def __init__(self, |
| sam2_folder, |
| expression_file, |
| extra_image_processor=None, |
| tokenizer=None, |
| select_number=5, |
| sampled_frames=5, |
| offline_processed_text_folder=None, |
| template_map_fn=None, |
| max_length=8196, |
| lazy=True, |
| repeats=1, |
| special_tokens=None, |
| use_fast=False, |
| n_fast_images=50, |
| fast_pool_size=4, |
| mode='long', |
| frame_contiguous_sample=False, |
| ): |
| assert mode in ['long', 'long_short', 'short'] |
| self.mode = mode |
| self.cur_mode = mode |
| assert lazy is True |
| self.tokenizer = BUILDER.build(tokenizer) |
| self.select_number = select_number |
| self.sampled_frames = sampled_frames |
| assert offline_processed_text_folder or (expression_file and tokenizer) |
| self.lazy = lazy |
|
|
| self.max_length = max_length |
|
|
| self.template_map_fn = template_map_fn |
| if isinstance(self.template_map_fn, dict) and self.lazy: |
| _type = self.template_map_fn['type'] |
| del self.template_map_fn['type'] |
| self.template_map_fn = _type(**self.template_map_fn) |
|
|
        if offline_processed_text_folder and expression_file:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`expression_file` are set, and we load dataset from '
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)
|
|
| if offline_processed_text_folder is not None: |
| raise NotImplementedError |
| else: |
| video_ids, anno_dict = self.json_file_preprocess(expression_file) |
| if self.lazy: |
| self.video_ids = video_ids |
| self.anno_dict = anno_dict |
| else: |
| raise NotImplementedError |
|
|
| self.sam2_folder = sam2_folder |
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)
        else:
            # Keep the attribute defined so later `is not None` checks work.
            self.extra_image_processor = None
        self.down_ratio = 1
| self.repeats = repeats |
|
|
| self._system = '' |
|
|
        # InternVL-style visual tokenization: 448x448 input with 14px patches
        # and a 0.5x pixel-shuffle downsample -> (448/14)^2 * 0.25 = 256
        # image tokens per frame.
        self.downsample_ratio = 0.5
        self.image_size = 448
        patch_size = 14
        self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
|
|
| self.transformer = T.Compose([ |
| T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), |
| T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC), |
| T.ToTensor(), |
| T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD) |
| ]) |
|
|
| if special_tokens is not None: |
| self.tokenizer.add_tokens(special_tokens, special_tokens=True) |
|
|
| self.use_fast = use_fast |
| self.n_fast_images = n_fast_images |
| self.fast_pool_size = fast_pool_size |
|
|
| self.frame_contiguous_sample = frame_contiguous_sample |
|
|
| |
| self.save_folder = './work_dirs/video_debug/' |
| self.cur_number = 0 |
|
|
| print("Video res dataset (ref-sam2), include {} items.".format(len(self.video_ids))) |
|
|
| def __len__(self): |
| return len(self.video_ids) * self.repeats |
|
|
    @property
    def modality_length(self):
        # Nominal per-sample length used for length-grouped batch sampling.
        return [20000] * len(self.video_ids)
|
|
| def real_len(self): |
| return len(self.video_ids) |
|
|
    def json_file_preprocess(self, expression_file):
        # The expression file maps video_id -> {'video_path', 'anno_path', 'objects'}.
        with open(expression_file, 'r') as f:
            expression_datas = json.load(f)
        video_ids = list(expression_datas.keys())
        return video_ids, expression_datas
|
|
    def dataset_map_fn(self, objects_expression_infos, n_frames, n_fast_frames=0):
        # 'long' uses the full formatted expression, 'short' samples one of the
        # short captions, and 'long_short' picks between the two per sample.
        if self.mode == 'long':
            expressions = [object_info['formated'] for object_info in objects_expression_infos]
            self.cur_mode = self.mode
        elif self.mode == 'short':
            expressions = [random.choice(object_info['short_caps']) for object_info in objects_expression_infos]
            self.cur_mode = self.mode
        else:
            if random.random() < 0.5:
                expressions = [object_info['formated'] for object_info in objects_expression_infos]
                self.cur_mode = 'long'
            else:
                expressions = [random.choice(object_info['short_caps']) for object_info in objects_expression_infos]
                self.cur_mode = 'short'
        text_dict = self.prepare_text(n_frames, expressions, num_image_tokens=self.patch_token,
                                      n_fast_frames=n_fast_frames)
        return {'conversation': text_dict['conversation']}
|
|
| def prepare_text(self, n_frames, expressions, num_image_tokens=256, n_fast_frames=0): |
|
|
| if self.use_fast: |
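            # Each fast frame contributes fast_pool_size * fast_pool_size pooled
            # context tokens inside the <fast_img> span.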
| fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \ |
| f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_frames * self.fast_pool_size * self.fast_pool_size}' \ |
| f'{self.FAST_IMG_END_TOKEN}' + '\n' |
| else: |
| fast_frame_token_str = '' |
|
|
| frame_token_str = f'{self.IMG_START_TOKEN}' \ |
| f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \ |
| f'{self.IMG_END_TOKEN}' |
|
|
| questions = [] |
| answers = [] |
| for i, exp in enumerate(expressions): |
            if self.cur_mode == 'short':
                question_template = random.choice(SEG_QUESTIONS_SHORT)
                # Strip the leading article so the caption slots into the
                # question templates ("A red car" -> "red car").
                if exp.startswith("A "):
                    exp = exp[len("A "):]
| else: |
| question_template = random.choice(SEG_QUESTIONS) |
| questions.append(question_template.format(class_name=exp)) |
| answers.append(random.choice(ANSWER_LIST)) |
| qa_list = [] |
| for i, (question, answer) in enumerate(zip(questions, answers)): |
| if i == 0: |
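                # Only the first human turn carries the per-frame image tokens
                # (plus the fast-frame tokens when `use_fast` is enabled).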
| frame_tokens = frame_token_str + '\n' |
| |
| frame_tokens = frame_tokens * n_frames |
| frame_tokens = frame_tokens.strip() |
| frame_tokens = fast_frame_token_str + frame_tokens |
| qa_list.append( |
| {'from': 'human', 'value': frame_tokens + question} |
| ) |
| else: |
| qa_list.append( |
| {'from': 'human', 'value': question} |
| ) |
| qa_list.append( |
| {'from': 'gpt', 'value': answer} |
| ) |
|
|
| input = '' |
| conversation = [] |
| for msg in qa_list: |
| if msg['from'] == 'human': |
| input += msg['value'] |
| elif msg['from'] == 'gpt': |
| conversation.append({'input': input, 'output': msg['value']}) |
| input = '' |
| else: |
| raise NotImplementedError |
|
|
| |
| conversation[0].update({'system': self._system}) |
| return {'conversation': conversation} |
|
|
| def __getitem__(self, index): |
| index = index % self.real_len() |
| video_id = self.video_ids[index] |
| expression_dict = self.anno_dict[video_id] |
| object_ids = list(expression_dict['objects'].keys()) |
|
|
| video_path = os.path.join(self.sam2_folder, expression_dict['video_path']) |
| anno_path = os.path.join(self.sam2_folder, expression_dict['anno_path']) |
|
|
| video_frames = get_video_frames(video_path) |
|
|
        if self.use_fast:
            # Uniformly sample `n_fast_images` frames across the whole video
            # for the low-resolution fast branch.
            fast_interval = len(video_frames) / (self.n_fast_images + 1e-4)
            sampled_fast_frame_idxs = [min(int(i * fast_interval), len(video_frames) - 1) for i in range(self.n_fast_images)]
            fast_video_frames = [video_frames[_idx] for _idx in sampled_fast_frame_idxs]
        else:
            fast_video_frames = None

        # Masklets are annotated at 1/4 of the video frame rate, so keep every
        # 4th frame to align the RGB frames with the mask annotations.
        video_frames = video_frames[::4]
|
|
| |
        with open(anno_path, 'r') as f:
            mask_data = json.load(f)
        # Decode per-frame COCO RLEs into (H, W, n_objects) arrays.
        masklents = decode_masklet(mask_data['masklet'])
|
|
| n_frames = len(masklents) |
| n_objects = len(object_ids) |
|
|
| |
        # Sample up to `select_number` objects: without replacement when there
        # are enough objects, otherwise allow repeats.
        if n_objects > self.select_number:
            selected_indexes = np.random.choice(n_objects, self.select_number, replace=False)
        else:
            selected_indexes = np.random.choice(n_objects, self.select_number, replace=True)
|
|
| selected_object_ids = [object_ids[_idx] for _idx in selected_indexes] |
| objects_expression_infos = [expression_dict['objects'][_idx] for _idx in selected_object_ids] |
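        # Keep only the mask channels belonging to the selected objects.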
| _masklents = [] |
| for _mask in masklents: |
| _mask_selected = [] |
| for _idx in selected_object_ids: |
| _mask_selected.append(_mask[:, :, int(_idx)]) |
| _mask_selected = np.stack(_mask_selected, axis=2) |
| _masklents.append(_mask_selected) |
| masklents = _masklents |
|
|
| |
| |
        # Sample `sampled_frames` frames: either a random contiguous window or
        # random frames without replacement; fall back to sampling with
        # replacement when the clip is too short.
        if n_frames > self.sampled_frames + 1:
            if self.frame_contiguous_sample and random.random() < 0.5:
                # Contiguous window starting at a random frame.
                selected_start_frame = np.random.choice(n_frames - self.sampled_frames, 1, replace=False)
                selected_frame_indexes = [selected_start_frame[0] + _i for _i in range(self.sampled_frames)]
            else:
                selected_frame_indexes = np.random.choice(n_frames, self.sampled_frames, replace=False)
        else:
            selected_frame_indexes = np.random.choice(n_frames, self.sampled_frames, replace=True)
        selected_frame_indexes.sort()
|
|
| video_frames = [video_frames[_idx] for _idx in selected_frame_indexes] |
| masklents = [masklents[_idx] for _idx in selected_frame_indexes] |
|
|
| data_dict = self.dataset_map_fn(objects_expression_infos, len(video_frames), n_fast_frames=self.n_fast_images) |
| result = self.template_map_fn(data_dict) |
| data_dict.update(result) |
| result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True) |
| data_dict.update(result) |
|
|
| pixel_values = [] |
| extra_pixel_values = [] |
        for frame in video_frames:
            # OpenCV decodes frames as BGR; convert to RGB before PIL.
            frame = frame[:, :, ::-1]
            frame_image = Image.fromarray(frame).convert('RGB')
| if self.extra_image_processor is not None: |
| g_image = np.array(frame_image) |
| g_image = self.extra_image_processor.apply_image(g_image) |
| g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous() |
| extra_pixel_values.append(g_pixel_values) |
|
|
| frame_image = self.transformer(frame_image) |
| pixel_values.append(frame_image) |
|
|
| pixel_values = torch.stack(pixel_values, dim=0) |
| data_dict['pixel_values'] = pixel_values |
| if self.extra_image_processor is not None: |
| data_dict['g_pixel_values'] = extra_pixel_values |
|
|
| |
        if self.use_fast:
            fast_pixel_values = []
            for frame_image in fast_video_frames:
                # OpenCV decodes frames as BGR; convert to RGB before PIL.
                frame = frame_image[:, :, ::-1]
                frame_image = Image.fromarray(frame).convert('RGB')
                frame_image = self.transformer(frame_image)
                fast_pixel_values.append(frame_image)
|
|
| fast_pixel_values = torch.stack(fast_pixel_values, dim=0) |
| data_dict['fast_pixel_values'] = fast_pixel_values |
|
|
| |
        # (T, H, W, N) -> (N, T, H, W) -> (N*T, H, W): per-object, per-frame masks.
        masklents = np.stack(masklents, axis=0)
        masklents = torch.from_numpy(masklents).permute(3, 0, 1, 2)
        masklents = masklents.flatten(0, 1)
| |
| |
| |
| data_dict['masks'] = masklents |
| data_dict['type'] = 'video' |
| return data_dict |
|
|
    def visualization_debug(self, data_dict):
        save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
        os.makedirs(save_folder, exist_ok=True)
        self.cur_number += 1

        show_images = []

        pixel_values = data_dict['pixel_values']
        save_folder_image = os.path.join(save_folder, 'image')
        os.makedirs(save_folder_image, exist_ok=True)
        for i_image, image_pixel_value in enumerate(pixel_values):
            # Invert the ImageNet normalization applied in self.transformer.
            for c in range(3):
                image_pixel_value[c] = image_pixel_value[c] * self.IMAGENET_STD[c] + self.IMAGENET_MEAN[c]
            image_pixel_value = image_pixel_value * 255
            image_pixel_value = image_pixel_value.permute(1, 2, 0)
            image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
            show_images.append(image_pixel_value)
            cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)
|
|
| |
| input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False) |
| with open(os.path.join(save_folder, 'text.json'), 'w') as f: |
| json.dump([input_text], f) |
|
|
| |
        save_folder_mask = os.path.join(save_folder, 'mask')
        os.makedirs(save_folder_mask, exist_ok=True)
| n_frames = len(pixel_values) |
| masks = data_dict['masks'] |
| _, h, w = masks.shape |
| masks = masks.reshape(-1, n_frames, h, w) |
| for i_obj, obj_masks in enumerate(masks): |
            save_folder_mask_obj_folder = os.path.join(save_folder_mask, 'obj_{}'.format(i_obj))
            os.makedirs(save_folder_mask_obj_folder, exist_ok=True)
            for i_frame, f_mask in enumerate(obj_masks):
                f_mask = f_mask.numpy() * 255
                # Paint the mask into the blue channel (cv2 writes BGR) and
                # blend it over the frame.
                f_mask = np.stack([f_mask, f_mask * 0, f_mask * 0], axis=2)
                f_mask = show_images[i_frame] * 0.3 + 0.7 * f_mask
                f_mask = f_mask.astype(np.uint8)
                cv2.imwrite(os.path.join(save_folder_mask_obj_folder, '{}.png'.format(i_frame)), f_mask)
| return |
|
|
def get_video_frames(video_path):
    """Decode all frames of a video with OpenCV (BGR order)."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print_log('Error: cannot open video file {}.'.format(video_path),
                  logger='current', level=logging.ERROR)
        return []

    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)

    cap.release()
    return frames
|
|
|
|
def images_to_video(frames, video_name, fps=6):
    """Write a list of BGR frames to an mp4 file."""
    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
    for frame in frames:
        video.write(frame)
    video.release()
|
|
def decode_masklet(masklet):
    """Decode per-frame COCO RLE annotations into (H, W, n_objects) arrays."""
    masks = []
    for _rle in masklet:
        masks.append(maskUtils.decode(_rle))
    return masks
|
|
def draw_mask(image, mask):
    """Blend a binary mask over an image as a 50% blue (BGR) overlay."""
    obj_mask = mask * 255
    obj_mask = np.stack([obj_mask, obj_mask * 0, obj_mask * 0], axis=2)
    obj_mask = obj_mask * 0.5 + copy.deepcopy(image) * 0.5
    return obj_mask.astype(np.uint8)
|
|
def add_mask2images(frames, masklets):
    """Return one list of mask-overlaid frames per object."""
    show_videos = []
    for i_frame, (frame, masks) in enumerate(zip(frames, masklets)):
        n_obj = masks.shape[-1]
        if i_frame == 0:
            show_videos = [[] for _ in range(n_obj)]
        for i_obj in range(n_obj):
            show_videos[i_obj].append(draw_mask(copy.deepcopy(frame), masks[:, :, i_obj]))
    return show_videos
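

if __name__ == '__main__':
    # Minimal smoke-test sketch. The paths, tokenizer, and chat template below
    # are assumptions for illustration, not the project's actual config.
    from transformers import AutoTokenizer
    from xtuner.dataset.map_fns import template_map_fn_factory
    from xtuner.utils import PROMPT_TEMPLATE

    dataset = VideoSAM2Dataset(
        sam2_folder='./data/sam2',                       # hypothetical path
        expression_file='./data/sam2/expressions.json',  # hypothetical path
        tokenizer=dict(
            type=AutoTokenizer.from_pretrained,
            pretrained_model_name_or_path='internlm/internlm2-chat-1_8b',  # assumed model
            trust_remote_code=True),
        template_map_fn=dict(
            type=template_map_fn_factory,
            template=PROMPT_TEMPLATE.internlm2_chat),
        special_tokens=['[SEG]', '<IMG_CONTEXT>', '<img>', '</img>'],
    )
    sample = dataset[0]
    # Expect (n_frames, 3, 448, 448) pixels and (n_objects * n_frames, H, W) masks.
    print(sample['pixel_values'].shape, sample['masks'].shape)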