| from mmdet.datasets import RefCocoDataset |
|
|
| from mmdet.datasets.transforms import LoadAnnotations |
| from mmdet.evaluation import RefSegMetric |
| import argparse |
| from mmengine.config import Config |
| from xtuner.model.utils import guess_load_checkpoint |
| from xtuner.registry import BUILDER |
| from xtuner.utils.constants import DEFAULT_IMAGE_TOKEN |
| from accelerate import Accelerator |
| from accelerate.utils import gather_object |
| from mmdet.structures.mask import BitmapMasks |
| from mmcv.transforms import LoadImageFromFile |
| from tqdm import tqdm |
| import torch |
| import torch.nn.functional as F |
| from time import time |
|
|
| from projects.f_llm.datasets.transforms import PILLoadImageFromFile, RefCOCO2PNG |
| from projects.lisa.datasets.refcoco_segm_dataset import ReferSegmDataset |
| from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn |
| from third_parts.segment_anything.utils.transforms import ResizeLongestSide |
| from pycocotools import mask as mask_utils |
| from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST |
# Resizer applied to the raw image array before it is handed to the model as
# `g_pixel_values`; configured with a target length of 1024.
extra_image_processor = ResizeLongestSide(target_length=1024)
| import copy |
| import torchvision.transforms as T |
| from torchvision.transforms.functional import InterpolationMode |
| from xtuner.utils import PROMPT_TEMPLATE |
# Prompt template used to wrap every question.
# NOTE(review): `template` is the object stored in PROMPT_TEMPLATE (not a
# copy), so the INSTRUCTION override below is presumably visible to anything
# else that reads PROMPT_TEMPLATE.phi3_chat afterwards -- confirm this is
# intended.
template = PROMPT_TEMPLATE.phi3_chat
template['INSTRUCTION'] = '<|user|>\n{input}<|end|><|assistant|>\n'

# System prompt candidate; it is immediately replaced by the empty string
# below, so no system prompt is actually prepended.
_system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'
_system = ''

# Prefix put in front of each question: the image placeholder on its own line.
begin_str = f'{DEFAULT_IMAGE_TOKEN}\n'
|
|
def _ensure_rgb(img):
    """Return *img* as-is when its mode is RGB, otherwise convert it to RGB."""
    return img if img.mode == 'RGB' else img.convert('RGB')


# Per-tile preprocessing: force RGB, resize to 448x448 with bicubic
# interpolation, convert to a tensor and normalise each channel
# (the mean/std values match the widely used ImageNet statistics).
transformer = T.Compose([
    T.Lambda(_ensure_rgb),
    T.Resize((448, 448), interpolation=InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])
|
|
def get_inputid_labels(conversations, image_token_str, max_length=8192):
    """Tokenize a human/gpt conversation into ``input_ids`` and ``labels``.

    Consecutive ``'human'`` messages are concatenated into one prompt; each
    ``'gpt'`` message closes the turn. Every turn is rendered through the
    module-level ``template`` and encoded with the module-level ``tokenizer``
    (assigned in the ``__main__`` section of this script). The module-level
    ``_system`` prompt, when non-empty, is prepended to the first turn.

    Args:
        conversations: Sequence of dicts with keys ``'from'`` (``'human'`` or
            ``'gpt'``) and ``'value'`` (the message text). Leading ``'gpt'``
            messages are skipped. The dicts are not modified.
        image_token_str: Text substituted for every ``'<image>'`` placeholder
            in human messages; ``None`` removes the placeholder.
        max_length: Maximum number of tokens kept; longer sequences are
            truncated from the end.

    Returns:
        A ``(input_ids, labels)`` pair of equal-length lists. Prompt positions
        carry ``-100`` in ``labels``; answer positions carry their token id.

    Raises:
        NotImplementedError: If a message comes from a speaker other than
            ``'human'`` or ``'gpt'``.
    """
    # A leading 'gpt' message has no question in front of it, so drop it.
    while conversations and conversations[0]['from'] == 'gpt':
        conversations = conversations[1:]

    image_replacement = '' if image_token_str is None else image_token_str

    # Pass 1: group messages into (prompt, answer) turns.
    turns = []
    pending_prompt = ''
    for msg in conversations:
        speaker = msg['from']
        if speaker == 'human':
            # Work on a local string so the caller's message dicts stay intact.
            text = msg['value'].replace('<image>', image_replacement)
            pending_prompt += text.strip()
        elif speaker == 'gpt':
            turns.append((pending_prompt, msg['value'].strip()))
            pending_prompt = ''
        else:
            raise NotImplementedError(f'unsupported speaker: {speaker!r}')

    # Pass 2: render and encode each turn.
    suffix = template.get('SUFFIX', None)
    input_ids, labels = [], []
    for turn_idx, (prompt, answer) in enumerate(turns):
        is_first_turn = turn_idx == 0
        input_text = template.INSTRUCTION.format(
            input=prompt, round=turn_idx + 1)
        if is_first_turn and _system:
            input_text = template.SYSTEM.format(system=_system) + input_text

        # Special tokens are only requested from the tokenizer on the first turn.
        prompt_ids = tokenizer.encode(
            input_text, add_special_tokens=is_first_turn)
        input_ids += prompt_ids
        labels += [-100] * len(prompt_ids)

        if suffix:
            answer += suffix
        answer_ids = tokenizer.encode(answer, add_special_tokens=False)
        input_ids += answer_ids
        labels += answer_ids

    # Slicing is a no-op when the sequence is already short enough.
    return input_ids[:max_length], labels[:max_length]
|
|
|
|
if __name__ == '__main__':
    # Script entry point: evaluates a referring-segmentation model on the
    # RefCOCO 'val' split, sharding samples across processes with Accelerate,
    # and reports cIoU / mIoU on the main process.
    #
    # Imports needed only by the evaluation loop are kept local to the script.
    import random

    import numpy as np
    from PIL import Image

    from projects.lisa.datasets.sem_seg_dataset import dynamic_preprocess

    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('config', help='config file path.')
    parser.add_argument('--checkpoint', default=None, type=str)
    args = parser.parse_args()

    accelerator = Accelerator()
    # Each process contributes one greeting; the gathered list is printed once.
    message = [f"Hello this is GPU {accelerator.process_index}"]
    messages = gather_object(message)
    accelerator.print(messages)

    cfg = Config.fromfile(args.config)
    # `tokenizer` must stay a module-level name: get_inputid_labels() reads it
    # as a global.
    tokenizer = BUILDER.build(cfg.tokenizer)
    tokenizer.add_tokens(['[SEG]'], special_tokens=True)

    model = BUILDER.build(cfg.model)
    if args.checkpoint is not None:
        state_dict = guess_load_checkpoint(args.checkpoint)
        # strict=False: mismatching keys are tolerated; only the unexpected
        # ones are reported.
        _, unexpected = model.load_state_dict(state_dict, strict=False)
        accelerator.print(f"Unexpected parameters: {unexpected}")

    model = model.to(device=accelerator.device)
    model.eval()
    model.to(torch.bfloat16)

    dataset = RefCocoDataset(
        data_root='data/coco/',
        data_prefix=dict(img_path='train2014/'),
        text_mode='select_first',
        ann_file='refcoco/instances.json',
        split_file='refcoco/refs(unc).p',
        split='val',
    )
    accelerator.wait_for_everyone()

    data_ids = list(range(len(dataset)))

    # Number of <IMG_CONTEXT> tokens per 448x448 tile; loop-invariant.
    patch_token = int((448 // 14) ** 2 * (0.5 ** 2))

    results = []
    with accelerator.split_between_processes(data_ids) as sub_ids:
        for idx in tqdm(sub_ids, disable=not accelerator.is_main_process):
            ann_info = dataset[idx]
            image = Image.open(ann_info['img_path']).convert('RGB')
            width, height = image.size

            # Resized copy of the full image, channel-first.
            g_image = extra_image_processor.apply_image(np.array(image))
            g_pixel_values = torch.from_numpy(g_image).permute(
                2, 0, 1).contiguous()

            # Tiles produced by dynamic_preprocess, each normalised by
            # `transformer` and stacked along a new leading dimension.
            tiles = dynamic_preprocess(image, 1, 12, 448, True)
            pixel_values = torch.stack([transformer(tile) for tile in tiles])
            num_image_tokens = pixel_values.shape[0] * patch_token
            image_token_str = (
                '<img>' + '<IMG_CONTEXT>' * num_image_tokens + '</img>')

            instances, phrases = ann_info['instances'], ann_info['text']
            for inst, phrase in zip(instances, phrases):
                # endswith() is safe for an empty phrase, unlike phrase[-1].
                if phrase.endswith('.'):
                    phrase = phrase[:-1]

                # Accumulate every segment of the instance into one mask;
                # any non-zero pixel counts as foreground below.
                binary_mask = np.zeros((height, width), dtype=np.uint8)
                for seg in inst["mask"]:
                    rles = mask_utils.frPyObjects([seg], height, width)
                    decoded = mask_utils.decode(rles).astype(np.uint8)
                    binary_mask += decoded.squeeze()

                question = random.choice(SEG_QUESTIONS).format(
                    class_name=phrase)
                conversation = [
                    {'from': 'human', 'value': begin_str + question},
                    {'from': 'gpt', 'value': ''},
                ]

                input_ids, labels = get_inputid_labels(
                    conversation, image_token_str)
                # Drop the last encoded token (presumably the end-of-turn
                # suffix appended to the empty answer -- TODO confirm).
                # NOTE(review): `labels` is not shortened to match.
                input_ids = input_ids[:-1]
                out_data_dict = {
                    'input_ids': torch.tensor(input_ids),
                    'labels': torch.tensor(labels),
                    'g_pixel_values': g_pixel_values,
                    'pixel_values': pixel_values,
                    'masks': binary_mask[None],
                }

                data_sample = glamm_collate_fn([out_data_dict])
                with torch.no_grad():
                    outputs = model(**data_sample, mode='predict')

                gt_masks = binary_mask[None] > 0
                pred_mask_logits = outputs['pred_mask_logits']
                if pred_mask_logits is None:
                    # No mask predicted: score an all-background prediction.
                    # gt_masks is a NumPy array, so torch.zeros_like() cannot
                    # be used on it; build the tensor from its shape instead.
                    pred_masks = torch.zeros(
                        gt_masks.shape, dtype=torch.bool)
                else:
                    pred_masks = pred_mask_logits.sigmoid().cpu() > 0.5

                assert len(pred_masks) == len(gt_masks), (
                    'prediction / ground-truth mask count mismatch')
                results.append(
                    dict(
                        pred_instances=dict(masks=pred_masks),
                        gt_masks=BitmapMasks(
                            masks=gt_masks,
                            height=gt_masks.shape[1],
                            width=gt_masks.shape[2],
                        ),
                    )
                )

    # Merge the per-process result lists.
    results = gather_object(results)

    if accelerator.is_main_process:
        accelerator.print(
            f"Collected {len(results)} result samples from all gpus")
        evaluator = RefSegMetric(metric=['cIoU', 'mIoU'])
        evaluator.process(data_batch=dict(), data_samples=results)
        metrics = evaluator.compute_metrics(evaluator.results)
        accelerator.print(f"Evaluation results on : {metrics}")
|
|