| import argparse
|
| import random
|
|
|
| import torch
|
| import os
|
| import json
|
| from tqdm import tqdm
|
| import shortuuid
|
| from pycocotools import mask
|
| import numpy as np
|
| import cv2
|
| from psalm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
|
| DEFAULT_IM_END_TOKEN, DEFAULT_SEG_TOKEN, SEG_TOKEN_INDEX, CLS_TOKEN_INDEX
|
| from psalm.model.builder import load_pretrained_model
|
| from psalm.utils import disable_torch_init
|
| from psalm.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
| from psalm.eval.segmentation_evaluation.panoptic_evaluation import my_coco_panoptic_evaluator, my_SemSegEvaluator
|
| from transformers import StoppingCriteria, StoppingCriteriaList
|
|
|
| from torch.utils.data import Dataset, DataLoader
|
|
|
| from psalm import conversation as conversation_lib
|
| from detectron2.data.datasets import load_sem_seg
|
| from PIL import Image
|
| import math
|
| import copy
|
| from detectron2.structures import BoxMode
|
| from detectron2.data import MetadataCatalog, DatasetCatalog
|
|
|
| from typing import Dict, Optional, Sequence, List
|
| from dataclasses import dataclass, field
|
| from psalm.train.train_datasets import COCO_semantic_dataset
|
| import transformers
|
| from segmentation_evaluation import openseg_classes
|
|
|
# Category tables for each open-vocabulary benchmark, produced by the
# prompt-engineered lists in segmentation_evaluation.openseg_classes.
# Each entry is a dict carrying at least "id", "name" and "color".
PASCAL_CTX_459_CATEGORIES=openseg_classes.get_pascal_ctx_459_categories_with_prompt_eng()

PASCAL_CTX_459_COLORS = [k["color"] for k in PASCAL_CTX_459_CATEGORIES]
PASCAL_CTX_59_CATEGORIES=openseg_classes.get_pascal_ctx_59_categories_with_prompt_eng()

PASCAL_CTX_59_COLORS = [k["color"] for k in PASCAL_CTX_59_CATEGORIES]
# NOTE(review): [1:] drops the first of the 21 VOC entries (presumably the
# 'background' class) to obtain the 20-class vocabulary — confirm.
PASCAL_VOC_20_CATEGORIES = openseg_classes.get_pascal_21_categories_with_prompt_eng()[1:]

PASCAL_VOC_20_COLORS = [k["color"] for k in PASCAL_VOC_20_CATEGORIES]
ADE20K_150_CATEGORIES = openseg_classes.get_ade20k_categories_with_prompt_eng()

ADE20k_COLORS = [k["color"] for k in ADE20K_150_CATEGORIES]
|
@dataclass
class DataArguments:
    """Command-line arguments for open-vocabulary semantic-segmentation
    evaluation (parsed via transformers.HfArgumentParser in __main__)."""

    # Path to the training data (not used by this evaluation script).
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    # Flipped to True in evaluation() once the image processor is attached.
    is_multimodal: bool = False
    # Checkpoint directory handed to load_pretrained_model.
    model_path: Optional[str] = field(default="/path/to/model")
    # Mask2Former config used to build the segmentation head.
    mask_config: Optional[str] = field(default="./psalm/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml")
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = field(default=None)
    model_map_name: str = 'psalm'
    # Conversation template key (see psalm.conversation conv_templates).
    version: str = 'llava_phi'
    # Directory where the evaluator writes its result files.
    output_dir: str = './output/panoptic_segmentation'
    segmentation: bool = True
    eval_batch_size: int = 1
    dataloader_num_workers: int = 4
    seg_task: Optional[str] = field(default="semantic")
    # '||'-separated task names from OV_SEM_DICT, e.g. "ctx_59||pc_20".
    ov_task_list: Optional[str] = field(default="ctx_59")
|
|
|
|
|
|
|
|
|
@dataclass
class DataCollatorForCOCODatasetV2(object):
    """Collate examples for supervised fine-tuning.

    Pads ``input_ids`` (with the tokenizer's pad id) and ``labels`` (with
    IGNORE_INDEX) to the batch maximum, truncates both to the tokenizer's
    ``model_max_length``, stacks images when their shapes agree, and forwards
    the remaining per-sample segmentation fields either padded/stacked or as
    plain Python lists.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Gather the two text fields across the batch, preserving order.
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        # Truncate to the model's context window.
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            # Non-pad positions attend; pad positions are masked out.
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            # Stack only when every image tensor has the same shape;
            # otherwise hand the model a plain list of tensors.
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images
            # NOTE: this mutates the incoming instances — the already-consumed
            # keys are deleted so the remainder can be forwarded wholesale as
            # per-sample seg_info. Order matters: delete BEFORE building seg_info.
            for instance in instances:
                for key in ['input_ids', 'labels', 'image']:
                    del instance[key]
            batch['seg_info'] = [instance for instance in instances]

            if 'dataset_type' in instances[0]:
                batch['dataset_type'] = [instance['dataset_type'] for instance in instances]

            # Variable-length per-sample tensors below: pad (value -1) when
            # shapes disagree, stack when they all match.
            if 'class_name_ids' in instances[0]:
                class_name_ids = [instance['class_name_ids'] for instance in instances]
                if any(x.shape != class_name_ids[0].shape for x in class_name_ids):
                    batch['class_name_ids'] = torch.nn.utils.rnn.pad_sequence(
                        class_name_ids,
                        batch_first=True,
                        padding_value=-1,
                    )
                else:
                    batch['class_name_ids'] = torch.stack(class_name_ids, dim=0)
            if 'token_refer_id' in instances[0]:
                # Kept as a list: lengths differ per sample.
                token_refer_id = [instance['token_refer_id'] for instance in instances]
                batch['token_refer_id'] = token_refer_id
            if 'cls_indices' in instances[0]:
                cls_indices = [instance['cls_indices'] for instance in instances]
                if any(x.shape != cls_indices[0].shape for x in cls_indices):
                    batch['cls_indices'] = torch.nn.utils.rnn.pad_sequence(
                        cls_indices,
                        batch_first=True,
                        padding_value=-1,
                    )
                else:
                    batch['cls_indices'] = torch.stack(cls_indices, dim=0)
            if 'random_idx' in instances[0]:
                random_idxs = [instance['random_idx'] for instance in instances]
                batch['random_idx'] = torch.stack(random_idxs, dim=0)
            if 'class_name_embedding_indices' in instances[0]:
                # Binary masks over the prompt; pad with 0 (= "not a <cls> slot").
                class_name_embedding_indices = [instance['class_name_embedding_indices'] for instance in instances]
                class_name_embedding_indices = torch.nn.utils.rnn.pad_sequence(
                    class_name_embedding_indices,
                    batch_first=True,
                    padding_value=0)
                batch['class_name_embedding_indices'] = class_name_embedding_indices
            if 'refer_embedding_indices' in instances[0]:
                refer_embedding_indices = [instance['refer_embedding_indices'] for instance in instances]
                refer_embedding_indices = torch.nn.utils.rnn.pad_sequence(
                    refer_embedding_indices,
                    batch_first=True,
                    padding_value=0)
                batch['refer_embedding_indices'] = refer_embedding_indices
            if 'class_id_mapping' in instances[0]:
                # Plain dicts (sample position -> original class id); list, not tensor.
                class_id_mapping = [instance['class_id_mapping'] for instance in instances]
                batch['class_id_mapping'] = class_id_mapping

        return batch
|
|
|
def _get_ctx459_meta():
    """Build the metadata dict for Pascal-Context-459: a dataset-id ->
    contiguous-id mapping plus the ordered list of class names."""
    stuff_ids = [category["id"] for category in PASCAL_CTX_459_CATEGORIES]
    assert len(stuff_ids) == 459, len(stuff_ids)

    # Dataset ids map onto 0..458 in list order.
    id_mapping = {dataset_id: contiguous for contiguous, dataset_id in enumerate(stuff_ids)}
    names = [category["name"] for category in PASCAL_CTX_459_CATEGORIES]

    return {
        "stuff_dataset_id_to_contiguous_id": id_mapping,
        "stuff_classes": names,
    }
|
|
|
def _get_ctx59_meta():
    """Build the metadata dict for Pascal-Context-59: a dataset-id ->
    contiguous-id mapping plus the ordered list of class names."""
    stuff_ids = [entry["id"] for entry in PASCAL_CTX_59_CATEGORIES]
    assert len(stuff_ids) == 59, len(stuff_ids)

    # Map each dataset id onto its position (0..58) in the category list.
    id_mapping = dict(zip(stuff_ids, range(len(stuff_ids))))

    return {
        "stuff_dataset_id_to_contiguous_id": id_mapping,
        "stuff_classes": [entry["name"] for entry in PASCAL_CTX_59_CATEGORIES],
    }
|
|
|
def _get_pascal20_meta():
    """Build the metadata dict for Pascal-VOC-20: a dataset-id ->
    contiguous-id mapping plus the ordered list of class names."""
    stuff_ids = [record["id"] for record in PASCAL_VOC_20_CATEGORIES]
    assert len(stuff_ids) == 20, len(stuff_ids)

    id_mapping = {}
    for contiguous, dataset_id in enumerate(stuff_ids):
        # Positions 0..19 in list order become the contiguous ids.
        id_mapping[dataset_id] = contiguous
    class_names = [record["name"] for record in PASCAL_VOC_20_CATEGORIES]

    return {
        "stuff_dataset_id_to_contiguous_id": id_mapping,
        "stuff_classes": class_names,
    }
|
|
|
def get_ade150_metadata():
    """Build the ADE20K-150 metadata dict: thing/stuff class names and
    colors, plus dataset-id -> contiguous-id mappings for both."""
    thing_categories = [c for c in ADE20K_150_CATEGORIES if c["isthing"] == 1]

    meta = {
        "thing_classes": [c["name"] for c in thing_categories],
        "thing_colors": [c["color"] for c in thing_categories],
        # Every category (things included) participates as "stuff".
        "stuff_classes": [c["name"] for c in ADE20K_150_CATEGORIES],
        "stuff_colors": [c["color"] for c in ADE20K_150_CATEGORIES],
    }

    thing_map = {}
    stuff_map = {}
    for contiguous, cat in enumerate(ADE20K_150_CATEGORIES):
        if cat["isthing"]:
            thing_map[cat["id"]] = contiguous
        # Stuff mapping covers every category, thing or not.
        stuff_map[cat["id"]] = contiguous

    meta["thing_dataset_id_to_contiguous_id"] = thing_map
    meta["stuff_dataset_id_to_contiguous_id"] = stuff_map

    return meta
|
|
|
# Per-task configuration for open-vocabulary semantic-segmentation eval.
# Keys are the names accepted via --ov_task_list (split on '||').
# Fields:
#   json_path       dataset root. NOTE(review): hard-coded absolute paths —
#                   must be adapted to the local environment.
#   image_path      image directory, relative to json_path
#   gt_path         Detectron2-format semantic GT directory, relative to json_path
#   ignore_label    label value excluded from evaluation
#   tot_cls         number of classes in the task vocabulary
#   gt_ext/image_ext  file extensions forwarded to load_sem_seg
#   get_mete_method zero-arg callable returning the task's metadata dict
OV_SEM_DICT={
    'ade_150':
        {
            'json_path': '/home/hk/yyma/data/ov_sem_data/ADEChallengeData2016',
            'image_path': 'images/validation',
            'gt_path': 'annotations_detectron2/validation',
            'ignore_label': 255,
            'tot_cls': 150,
            'gt_ext': "png",
            'image_ext': "jpg",
            'get_mete_method': get_ade150_metadata
        },
    'ctx_459':
        {
            'json_path': '/home/hk/yyma/data/ov_sem_data/pascal_ctx_d2',
            'image_path': 'images/validation',
            'gt_path': 'annotations_ctx459/validation',
            # 459-class GT is stored as 16-bit tif, hence the 65535 ignore value.
            'ignore_label': 65535,
            'tot_cls':459,
            'gt_ext':"tif",
            'image_ext':"jpg",
            'get_mete_method':_get_ctx459_meta
        },
    'ctx_59':
        {
            'json_path': '/home/hk/yyma/data/ov_sem_data/pascal_ctx_d2',
            'image_path': 'images/validation',
            'gt_path': 'annotations_ctx59/validation',
            'ignore_label': 255,
            'tot_cls': 59,
            'gt_ext': "png",
            'image_ext': "jpg",
            'get_mete_method': _get_ctx59_meta
        },
    'pc_20':
        {
            'json_path': '/home/hk/yyma/data/ov_sem_data/pascal_voc_d2',
            'image_path': 'images/validation',
            'gt_path': 'annotations_pascal20/validation',
            'ignore_label': 255,
            'tot_cls': 20,
            'gt_ext': "png",
            'image_ext': "jpg",
            'get_mete_method': _get_pascal20_meta
        }
}
|
|
|
|
|
class common_semantic_dataset(COCO_semantic_dataset):
    """Open-vocabulary semantic-segmentation eval dataset over a Detectron2
    ``load_sem_seg`` file list (ADE-150, Pascal-Context-459/59, Pascal-VOC-20).

    Each item yields the preprocessed image/instances dict augmented with the
    tokenized prompt and the per-sample class-name token ids used for
    class-embedding lookup.
    """

    def __init__(self, task_name, tokenizer, data_args, is_train=True):
        # NOTE(review): super(common_semantic_dataset).__init__() creates an
        # *unbound* super proxy, so COCO_semantic_dataset.__init__ never runs;
        # every attribute this class needs is set explicitly below. Confirm
        # this is intentional before "fixing" it — the parent __init__ likely
        # takes different arguments.
        super(common_semantic_dataset).__init__()
        task_info = OV_SEM_DICT[task_name]  # raises KeyError for unknown tasks
        self.semantic_image_path = os.path.join(task_info['json_path'],task_info['image_path'])
        self.semantic_gt_path = os.path.join(task_info['json_path'],task_info['gt_path'])
        # Task metadata dict: class names (+ dataset-id -> contiguous-id map).
        self.cate = task_info['get_mete_method']()
        # List of per-image records discovered by Detectron2's load_sem_seg.
        self.data = load_sem_seg(gt_root=self.semantic_gt_path, image_root=self.semantic_image_path, gt_ext=task_info["gt_ext"],
                                 image_ext=task_info["image_ext"])

        self.tokenizer = tokenizer
        self.data_args = data_args
        self.mask_format = 'polygon'
        self.common_id_to_cont_id = self.cate['stuff_dataset_id_to_contiguous_id'] if 'stuff_dataset_id_to_contiguous_id' in self.cate else None
        self.common_class_name = self.cate['stuff_classes']
        # Contiguous class ids 0..C-1 (positions within common_class_name).
        self.common_class_id = list(range(len(self.common_class_name)))
        self.ignore_label = task_info['ignore_label']
        self.total_class = task_info['tot_cls']

    def preprocess_class_name(self, CLS_token='[SEG]', current_sample_class_name=None):
        """Tokenize every candidate class name, appending one CLS_token id to
        each; return the flattened id tensor plus a parallel tensor mapping
        every token to the index of the class it belongs to."""
        tokenized = [self.tokenizer.encode(class_name, add_special_tokens=False) for class_name in
                     current_sample_class_name]
        tokenized_class_names = [tokens + [self.tokenizer.encode(CLS_token, add_special_tokens=False)[0]] for tokens in
                                 tokenized]
        # Flatten into one long sequence of all class-name tokens.
        class_name_id = [token for sublist in tokenized_class_names for token in sublist]
        class_name_id = torch.tensor(class_name_id)
        # Per flattened token: index of its owning class (for pooling).
        cls_indices = [idx for idx, sublist in enumerate(tokenized_class_names) for _ in sublist]
        cls_indices = torch.tensor(cls_indices)

        return class_name_id, cls_indices

    def __len__(self):
        # One item per image/GT pair found by load_sem_seg.
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]

        data_dict = data

        # The image processor may be shared or a per-task dict; semantic
        # evaluation uses the 'semantic' entry when a dict is supplied.
        if isinstance(self.data_args.image_processor, dict):
            processor = self.data_args.image_processor['semantic']
        else:
            processor = self.data_args.image_processor
        data_dict = processor.preprocess(data_dict, mask_format=self.mask_format,ignore_label=self.ignore_label)

        # NOTE(review): `instruction` is never used below — dead assignment.
        instruction = 'Panoptic Segmentation: You need to segment all objects '
        prefix_inst = 'This is an image <image>, Please do Panoptic Segmentation.'

        num_class = self.total_class
        full2sample_mapping = {}
        # Subsample the candidate classes when the vocabulary exceeds the task
        # budget: keep all ground-truth classes, fill with random negatives.
        # NOTE(review): len(common_class_id) == tot_cls for every task in
        # OV_SEM_DICT, so this branch appears unreachable here — confirm.
        if len(self.common_class_id) > num_class:
            current_sample_class_id = data_dict['instances'].gt_classes.numpy().tolist()
            num_negatives = num_class - 1 - len(current_sample_class_id)
            potential_negative_ids = list(set(self.common_class_id) - set(current_sample_class_id))
            negative_sample_ids = np.random.choice(potential_negative_ids, num_negatives, replace=False)
            pick_class_id = current_sample_class_id + list(negative_sample_ids)
        else:
            pick_class_id = self.common_class_id

        # Original (full-vocabulary) class id -> position within this sample.
        for new_id, original_id in enumerate(pick_class_id):
            full2sample_mapping[original_id] = new_id
        # For very large vocabularies keep only the first comma-separated
        # name variant to shorten the prompt.
        if len(pick_class_id) > 200:
            current_sample_class_name = [self.common_class_name[id].split(',')[0] for id in
                                         pick_class_id] + ['background']
        else:
            current_sample_class_name = [self.common_class_name[id] for id in
                                         pick_class_id] + ['background']

        # One <cls> placeholder per candidate class (including 'background').
        category = '<cls>, ' * (len(current_sample_class_name) - 1) + '<cls>.'

        sources_value = f'\nThis is all the candidate categories: {category}\n'

        sources = [[{'from': 'human', 'value': prefix_inst + sources_value},
                    {'from': 'gpt', 'value': '\nSure, the segmentation result is <seg>'}]]

        # Inherited helper: converts the conversation into input_ids/labels
        # carrying the image/seg/cls special-token indices.
        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        input_ids = text_dict['input_ids'][0]
        labels = text_dict['labels'][0]

        class_name_ids, cls_indices = self.preprocess_class_name(current_sample_class_name=current_sample_class_name)
        # Binary mask over the prompt marking the <cls> placeholder positions.
        class_name_embedding_indices = torch.zeros_like(input_ids)
        class_name_embedding_indices[input_ids == CLS_TOKEN_INDEX] = 1

        data_dict['input_ids'] = text_dict['input_ids'][0]
        data_dict['labels'] = text_dict['labels'][0]

        data_dict['class_name_ids'] = class_name_ids
        data_dict['cls_indices'] = cls_indices
        data_dict['class_name_embedding_indices'] = class_name_embedding_indices
        # Inverse mapping: sample position -> original class id. Used at eval
        # time to scatter predicted channels back onto the full class list.
        data_dict['class_id_mapping'] = {value: key for key, value in full2sample_mapping.items()}
        return data_dict
|
|
|
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the most recent token equals any token in *stops*.

    Only batch element 0 is inspected, so this criterion is intended for
    batch-size-1 generation.
    """

    def __init__(self, stops=None, encounters=1):
        """
        Args:
            stops: iterable of stop token ids (ints or 0-d tensors). Defaults
                to no stop tokens. A None sentinel replaces the original
                mutable-default ``stops=[]``, which shared one list object
                across all instances.
            encounters: retained for interface compatibility; the criterion
                fires on the first match regardless of this value.
        """
        super().__init__()
        self.stops = stops if stops is not None else []
        # Previously accepted but silently discarded; stored for visibility.
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """Return True iff the newest token of sequence 0 is a stop token."""
        last_token = input_ids[0][-1]
        for stop in self.stops:
            if stop == last_token:
                return True
        return False
|
|
|
|
|
def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks.

    Leading chunks hold ceil(len(lst)/n) items each; the last chunk may be
    shorter, so fewer than n chunks can be returned.

    Args:
        lst: the list to split.
        n: desired (positive) number of chunks.

    Returns:
        A list of contiguous slices of *lst*; [] when *lst* is empty.
    """
    # Guard the empty case: ceil(0/n) == 0 would make range()'s step zero,
    # which raised ValueError in the original implementation.
    if not lst:
        return []
    chunk_size = math.ceil(len(lst) / n)
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
|
|
|
|
|
def get_chunk(lst, n, k):
    """Return chunk *k* (0-based) of *lst* divided into n roughly equal chunks."""
    return split_list(lst, n)[k]
|
|
|
|
|
def evaluation(data_args, ov_task=None):
    """Evaluate the PSALM model on one open-vocabulary semantic task.

    Loads the checkpoint at ``data_args.model_path``, builds a
    ``common_semantic_dataset`` for ``ov_task`` (a key of OV_SEM_DICT), runs
    ``model.eval_seg`` over the whole split, and scores the predictions with
    ``my_SemSegEvaluator``.

    Args:
        data_args: parsed DataArguments (model path, batch size, output dir, ...).
        ov_task: task name in OV_SEM_DICT; when set, the task name and its
            mIoU are printed instead of the full result dict.

    Returns:
        The evaluator's result dict ({} when the evaluator returns None).
    """
    disable_torch_init()
    model_path = os.path.expanduser(data_args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, None, model_name, mask_config=data_args.mask_config, model_args=data_args)
    data_args.image_processor = image_processor
    data_args.is_multimodal = True

    conversation_lib.default_conversation = conversation_lib.conv_templates[data_args.version]
    eval_dataset = common_semantic_dataset(task_name=ov_task, tokenizer=tokenizer,
                                           data_args=data_args)
    data_collator = DataCollatorForCOCODatasetV2(tokenizer=tokenizer)
    dataloader_params = {
        "batch_size": data_args.eval_batch_size,
        "num_workers": data_args.dataloader_num_workers,
    }
    eval_dataloader = DataLoader(eval_dataset, batch_size=dataloader_params['batch_size'], collate_fn=data_collator,
                                 num_workers=dataloader_params['num_workers'])

    def load_instruction_dataset():
        # Factory for Detectron2's catalog: the evaluator resolves the
        # dataset by its registered name.
        return eval_dataset

    try:
        DatasetCatalog.register('instruction_dataset', load_instruction_dataset)
    # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
    # Registering the same name twice raises (from the second task onwards).
    except Exception:
        print('dataset have been loaded')

    # Fall back to the 'common_*' attributes set by common_semantic_dataset
    # when the COCO-specific ones are absent.
    cont_id = eval_dataset.coco_id_to_cont_id if hasattr(eval_dataset, 'coco_id_to_cont_id') else eval_dataset.common_id_to_cont_id
    class_name = eval_dataset.coco_class_name[:-1] if hasattr(eval_dataset, 'coco_class_name') else eval_dataset.common_class_name
    ignore_label = 255 if not hasattr(eval_dataset, 'ignore_label') else eval_dataset.ignore_label
    evaluator = my_SemSegEvaluator('instruction_dataset',
                                   output_dir=data_args.output_dir, dataset_id_to_cont_id=cont_id,
                                   class_name=class_name, ignore_label=ignore_label)
    evaluator.reset()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(dtype=torch.float16, device=device).eval()
    with torch.no_grad():
        for idx, inputs in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader)):
            # Move tensors onto the eval device; leave lists/dicts as-is.
            inputs = {k: v.to(device=device) if torch.is_tensor(v) else v for k, v in inputs.items()}
            outputs = model.eval_seg(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                images=inputs['images'],
                seg_info=inputs['seg_info'],
                class_name_embedding_indices=inputs['class_name_embedding_indices'],
                class_name_ids=inputs['class_name_ids'],
                cls_indices=inputs['cls_indices'],
                labels=inputs['labels'],
                is_thing_list=eval_dataset.coco_is_thing if hasattr(eval_dataset, 'coco_is_thing') else None
            )
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            if hasattr(eval_dataset, 'common_class_name'):
                # The prompt carried a (possibly subsampled) class list, so
                # scatter each predicted channel back onto its original class
                # id before the evaluator scores the full vocabulary.
                class_id_mapping = inputs['class_id_mapping'][0]
                sem_seg = outputs[0]['sem_seg']
                sem_mask = torch.zeros(len(eval_dataset.common_class_name),
                                       sem_seg.shape[1], sem_seg.shape[2]).to(sem_seg.device)
                for i in range(sem_seg.shape[0]):
                    real_id = class_id_mapping[i]
                    sem_mask[real_id, :, :] = sem_seg[i, :, :]
                outputs = [{'sem_seg': sem_mask}]

            evaluator.process(inputs['seg_info'], outputs)

    results = evaluator.evaluate()
    if ov_task is not None:
        print(f'current ov_task is {ov_task}')
        print(results['sem_seg']['mIoU'])
    else:
        print(results)

    # evaluate() may return None (e.g. on non-main ranks); normalize.
    if results is None:
        results = {}
    return results
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse CLI flags straight into the DataArguments dataclass.
    parser = transformers.HfArgumentParser(DataArguments)
    data_args = parser.parse_args_into_dataclasses()[0]
    ov_task_list = data_args.ov_task_list
    if ov_task_list is None:
        # NOTE(review): evaluation() then builds common_semantic_dataset with
        # task_name=None, which is not a key of OV_SEM_DICT — verify this path.
        evaluation(data_args)
    else:
        # '||'-separated task names; each task is evaluated independently.
        ov_task_list = ov_task_list.split('||')
        for ov_task in ov_task_list:
            evaluation(data_args,ov_task)
|
|
|