{ "cells": [ { "cell_type": "markdown", "id": "f09246e8", "metadata": {}, "source": [ "# Easy Dev for Post-hoc OOD Detectors\n", "\n", "This notebook integrates some simple post-hoc OOD detection methods.\n", "\n", "We choose ImageNet-1K as in-distribution (ID) and load a pretrained vision transformer (ViT)." ] }, { "cell_type": "code", "execution_count": 1, "id": "4dc297c1", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "id": "8b218878", "metadata": {}, "source": [ "## Load Models and Dataset" ] }, { "cell_type": "code", "execution_count": 13, "id": "f155277c", "metadata": {}, "outputs": [], "source": [ "from openood.utils import config\n", "from openood.datasets import get_dataloader, get_ood_dataloader\n", "from openood.evaluators import get_evaluator\n", "from openood.networks import get_network" ] }, { "cell_type": "code", "execution_count": 14, "id": "48968bc2", "metadata": {}, "outputs": [], "source": [ "# load config files for cifar10 baseline\n", "config_files = [\n", " './configs/datasets/cifar10/cifar10.yml',\n", " './configs/datasets/cifar10/cifar10_ood.yml',\n", " './configs/networks/resnet18_32x32.yml',\n", " './configs/pipelines/test/test_ood.yml',\n", " './configs/preprocessors/base_preprocessor.yml',\n", " './configs/postprocessors/msp.yml',\n", "]\n", "config = config.Config(*config_files)\n", "# modify config \n", "config.network.checkpoint = './results/cifar10_resnet18_32x32_base_e100_lr0.1/best.ckpt'\n", "config.network.pretrained = True\n", "config.num_workers = 8\n", "config.save_output = False\n", "config.parse_refs()" ] }, { "cell_type": "code", "execution_count": 15, "id": "a99ab6de", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dataset:\n", " image_size: 32\n", " name: cifar10\n", " num_classes: 10\n", " num_gpus: 1\n", " num_machines: 1\n", " num_workers: 8\n", " pre_size: 32\n", " split_names: ['train', 'val', 'test']\n", " test:\n", " batch_size: 200\n", " data_dir: ./data/images_classic/\n", " dataset_class: ImglistDataset\n", " imglist_pth: ./data/benchmark_imglist/cifar10/test_cifar10.txt\n", " interpolation: bilinear\n", " shuffle: False\n", " train:\n", " batch_size: 128\n", " data_dir: ./data/images_classic/\n", " dataset_class: ImglistDataset\n", " imglist_pth: ./data/benchmark_imglist/cifar10/train_cifar10.txt\n", " interpolation: bilinear\n", " shuffle: True\n", " val:\n", " batch_size: 200\n", " data_dir: ./data/images_classic/\n", " dataset_class: ImglistDataset\n", " imglist_pth: ./data/benchmark_imglist/cifar10/val_cifar10.txt\n", " interpolation: bilinear\n", " shuffle: False\n", "evaluator:\n", " name: ood\n", "exp_name: cifar10_resnet18_32x32_test_ood_ood_msp_default\n", "merge_option: default\n", "machine_rank: 0\n", "mark: default\n", "network:\n", " checkpoint: ./results/cifar10_resnet18_32x32_base_e100_lr0.1/best.ckpt\n", " name: resnet18_32x32\n", " num_classes: 10\n", " num_gpus: 1\n", " pretrained: True\n", "num_gpus: 1\n", "num_machines: 1\n", "num_workers: 8\n", "ood_dataset:\n", " batch_size: 128\n", " dataset_class: ImglistDataset\n", " farood:\n", " datasets: ['mnist', 'svhn', 'texture', 'place365']\n", " mnist:\n", " data_dir: ./data/images_classic/\n", " imglist_pth: ./data/benchmark_imglist/cifar10/test_mnist.txt\n", " place365:\n", " data_dir: ./data/images_classic/\n", " imglist_pth: ./data/benchmark_imglist/cifar10/test_places365.txt\n", " svhn:\n", " data_dir: ./data/images_classic/\n", " imglist_pth: ./data/benchmark_imglist/cifar10/test_svhn.txt\n", " texture:\n", " data_dir: ./data/images_classic/\n", " imglist_pth: ./data/benchmark_imglist/cifar10/test_texture.txt\n", " image_size: 32\n", " interpolation: bilinear\n", " name: cifar10_ood\n", " nearood:\n", " cifar100:\n", " data_dir: ./data/images_classic/\n", " imglist_pth: ./data/benchmark_imglist/cifar10/test_cifar100.txt\n", " datasets: ['cifar100', 'tin']\n", " tin:\n", " data_dir: ./data/images_classic/\n", " imglist_pth: ./data/benchmark_imglist/osr_tin20/test_tin.txt\n", " num_classes: 10\n", " num_gpus: 1\n", " num_machines: 1\n", " num_workers: 8\n", " shuffle: False\n", " split_names: ['val', 'nearood', 'farood']\n", " val:\n", " data_dir: ./data/images_classic/\n", " imglist_pth: ./data/benchmark_imglist/cifar10/val_cifar100.txt\n", "output_dir: ./results/\n", "pipeline:\n", " name: test_ood\n", "postprocessor:\n", " name: msp\n", "preprocessor:\n", " name: base\n", "recorder:\n", " save_csv: True\n", " save_scores: True\n", "save_output: False" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "config" ] }, { "cell_type": "code", "execution_count": 16, "id": "35d30b38", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model Loading resnet18_32x32 Completed!\n" ] } ], "source": [ "# get dataloader\n", "id_loader_dict = get_dataloader(config)\n", "ood_loader_dict = get_ood_dataloader(config)\n", "# init network\n", "net = get_network(config.network).cuda()\n", "# init ood evaluator\n", "evaluator = get_evaluator(config)" ] }, { "cell_type": "markdown", "id": "2043b57f", "metadata": {}, "source": [ "## Feature Extraction" ] }, { "cell_type": "code", "execution_count": 6, "id": "f430f169", "metadata": {}, "outputs": [], "source": [ "from tqdm import tqdm\n", "import numpy as np\n", "import torch\n", "import os.path as osp" ] }, { "cell_type": "code", "execution_count": 7, "id": "512fe29c", "metadata": {}, "outputs": [], "source": [ "def save_arr_to_dir(arr, dir):\n", " with open(dir, 'wb') as f:\n", " np.save(f, arr)" ] }, { "cell_type": "code", "execution_count": 8, "id": "418778f7", "metadata": {}, "outputs": [], "source": [ "save_root = f'./results/{config.exp_name}'" ] }, { "cell_type": "code", "execution_count": 15, "id": "e2f12809", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Extracting reults...: 100%|██████████| 45/45 [00:03<00:00, 13.48it/s]\n", "Extracting reults...: 100%|██████████| 5/5 [00:00<00:00, 6.84it/s]\n" ] } ], "source": [ "# save id (test & val) results\n", "net.eval()\n", "modes = ['test', 'val']\n", "for mode in modes:\n", " dl = id_loader_dict[mode]\n", " dataiter = iter(dl)\n", " \n", " logits_list = []\n", " feature_list = []\n", " label_list = []\n", " \n", " for i in tqdm(range(1,\n", " len(dataiter) + 1),\n", " desc='Extracting reults...',\n", " position=0,\n", " leave=True):\n", " batch = next(dataiter)\n", " data = batch['data'].cuda()\n", " label = batch['label']\n", " with torch.no_grad():\n", " logits_cls, feature = net(data, return_feature=True)\n", " logits_list.append(logits_cls.data.to('cpu').numpy())\n", " feature_list.append(feature.data.to('cpu').numpy())\n", " label_list.append(label.numpy())\n", "\n", " logits_arr = np.concatenate(logits_list)\n", " feature_arr = np.concatenate(feature_list)\n", " label_arr = np.concatenate(label_list)\n", " \n", " save_arr_to_dir(logits_arr, osp.join(save_root, 'id', f'{mode}_logits.npy'))\n", " save_arr_to_dir(feature_arr, osp.join(save_root, 'id', f'{mode}_feature.npy'))\n", " save_arr_to_dir(label_arr, osp.join(save_root, 'id', f'{mode}_labels.npy'))" ] }, { "cell_type": "code", "execution_count": 41, "id": "5cded214", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Extracting reults...: 100%|██████████| 71/71 [00:01<00:00, 35.57it/s]\n", "Extracting reults...: 100%|██████████| 141/141 [00:04<00:00, 32.65it/s]\n", "Extracting reults...: 100%|██████████| 547/547 [00:15<00:00, 34.35it/s]\n", "Extracting reults...: 100%|██████████| 204/204 [00:05<00:00, 34.51it/s]\n", "Extracting reults...: 100%|██████████| 45/45 [00:06<00:00, 7.47it/s]\n", "Extracting reults...: 100%|██████████| 286/286 [00:19<00:00, 14.37it/s]\n" ] } ], "source": [ "# save ood results\n", "net.eval()\n", "ood_splits = ['nearood', 'farood']\n", "for ood_split in ood_splits:\n", " for dataset_name, ood_dl in ood_loader_dict[ood_split].items():\n", " dataiter = iter(ood_dl)\n", " \n", " logits_list = []\n", " feature_list = []\n", " label_list = []\n", "\n", " for i in tqdm(range(1,\n", " len(dataiter) + 1),\n", " desc='Extracting reults...',\n", " position=0,\n", " leave=True):\n", " batch = next(dataiter)\n", " data = batch['data'].cuda()\n", " label = batch['label']\n", "\n", " with torch.no_grad():\n", " logits_cls, feature = net(data, return_feature=True)\n", " logits_list.append(logits_cls.data.to('cpu').numpy())\n", " feature_list.append(feature.data.to('cpu').numpy())\n", " label_list.append(label.numpy())\n", "\n", " logits_arr = np.concatenate(logits_list)\n", " feature_arr = np.concatenate(feature_list)\n", " label_arr = np.concatenate(label_list)\n", "\n", " save_arr_to_dir(logits_arr, osp.join(save_root, ood_split, f'{dataset_name}_logits.npy'))\n", " save_arr_to_dir(feature_arr, osp.join(save_root, ood_split, f'{dataset_name}_feature.npy'))\n", " save_arr_to_dir(label_arr, osp.join(save_root, ood_split, f'{dataset_name}_labels.npy'))" ] }, { "cell_type": "markdown", "id": "0b69a9ed", "metadata": {}, "source": [ "## MSP Evaluation" ] }, { "cell_type": "code", "execution_count": 9, "id": "db3f705b", "metadata": {}, "outputs": [], "source": [ "# build msp method (pass in pre-saved logits)\n", "def msp_postprocess(logits):\n", " score = torch.softmax(logits, dim=1)\n", " conf, pred = torch.max(score, dim=1)\n", " return pred, conf" ] }, { "cell_type": "code", "execution_count": 10, "id": "a1cac34c", "metadata": {}, "outputs": [], "source": [ "# load logits, feature, label for this benchmark\n", "results = dict()\n", "# for id\n", "modes = ['val', 'test']\n", "results['id'] = dict()\n", "for mode in modes:\n", " results['id'][mode] = dict()\n", " results['id'][mode]['feature'] = np.load(osp.join(save_root, 'id', f'{mode}_feature.npy'))\n", " results['id'][mode]['logits'] = np.load(osp.join(save_root, 'id', f'{mode}_logits.npy'))\n", " results['id'][mode]['labels'] = np.load(osp.join(save_root, 'id', f'{mode}_labels.npy'))\n", "\n", "# for ood\n", "split_types = ['nearood', 'farood']\n", "for split_type in split_types:\n", " results[split_type] = dict()\n", " dataset_names = config['ood_dataset'][split_type].datasets\n", " for dataset_name in dataset_names:\n", " results[split_type][dataset_name] = dict()\n", " results[split_type][dataset_name]['feature'] = np.load(osp.join(save_root, split_type, f'{dataset_name}_feature.npy'))\n", " results[split_type][dataset_name]['logits'] = np.load(osp.join(save_root, split_type, f'{dataset_name}_logits.npy'))\n", " results[split_type][dataset_name]['labels'] = np.load(osp.join(save_root, split_type, f'{dataset_name}_labels.npy'))\n" ] }, { "cell_type": "code", "execution_count": 60, "id": "79697d8c", "metadata": {}, "outputs": [], "source": [ "def print_nested_dict(dict_obj, indent = 0):\n", " ''' Pretty Print nested dictionary with given indent level \n", " '''\n", " # Iterate over all key-value pairs of dictionary\n", " for key, value in dict_obj.items():\n", " # If value is dict type, then print nested dict \n", " if isinstance(value, dict):\n", " print(' ' * indent, key, ':', '{')\n", " print_nested_dict(value, indent + 2)\n", " print(' ' * indent, '}')\n", " else:\n", " print(' ' * indent, key, ':', value.shape)\n" ] }, { "cell_type": "code", "execution_count": 61, "id": "2dfe3614", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " id : {\n", " val : {\n", " feature : (1000, 512)\n", " logits : (1000, 10)\n", " labels : (1000,)\n", " }\n", " test : {\n", " feature : (9000, 512)\n", " logits : (9000, 10)\n", " labels : (9000,)\n", " }\n", " }\n", " nearood : {\n", " cifar100 : {\n", " feature : (9000, 512)\n", " logits : (9000, 10)\n", " labels : (9000,)\n", " }\n", " tin : {\n", " feature : (18000, 512)\n", " logits : (18000, 10)\n", " labels : (18000,)\n", " }\n", " }\n", " farood : {\n", " mnist : {\n", " feature : (70000, 512)\n", " logits : (70000, 10)\n", " labels : (70000,)\n", " }\n", " svhn : {\n", " feature : (26032, 512)\n", " logits : (26032, 10)\n", " labels : (26032,)\n", " }\n", " texture : {\n", " feature : (5640, 512)\n", " logits : (5640, 10)\n", " labels : (5640,)\n", " }\n", " place365 : {\n", " feature : (36500, 512)\n", " logits : (36500, 10)\n", " labels : (36500,)\n", " }\n", " }\n" ] } ], "source": [ "print_nested_dict(results)" ] }, { "cell_type": "code", "execution_count": 23, "id": "a3af6815", "metadata": {}, "outputs": [], "source": [ "# get pred, conf, gt from MSP postprocessor (can change to your custom_postprocessor here)\n", "postprocess_results = dict()\n", "# id\n", "modes = ['val', 'test']\n", "postprocess_results['id'] = dict()\n", "for mode in modes:\n", " pred, conf = msp_postprocess(torch.from_numpy(results['id'][mode]['logits']))\n", " pred, conf = pred.numpy(), conf.numpy()\n", " gt = results['id'][mode]['labels']\n", " postprocess_results['id'][mode] = [pred, conf, gt]\n", "\n", "# ood\n", "split_types = ['nearood', 'farood']\n", "for split_type in split_types:\n", " postprocess_results[split_type] = dict()\n", " dataset_names = config['ood_dataset'][split_type].datasets\n", " for dataset_name in dataset_names:\n", " pred, conf = msp_postprocess(torch.from_numpy(results[split_type][dataset_name]['logits']))\n", " pred, conf = pred.numpy(), conf.numpy()\n", " gt = results[split_type][dataset_name]['labels']\n", " gt = -1 * np.ones_like(gt) # hard set to -1 here\n", " postprocess_results[split_type][dataset_name] = [pred, conf, gt]" ] }, { "cell_type": "code", "execution_count": 14, "id": "41e3dac6", "metadata": {}, "outputs": [], "source": [ "def print_all_metrics(metrics):\n", " [fpr, auroc, aupr_in, aupr_out,\n", " ccr_4, ccr_3, ccr_2, ccr_1, accuracy] \\\n", " = metrics\n", " print('FPR@95: {:.2f}, AUROC: {:.2f}'.format(100 * fpr, 100 * auroc),\n", " end=' ',\n", " flush=True)\n", " print('AUPR_IN: {:.2f}, AUPR_OUT: {:.2f}'.format(\n", " 100 * aupr_in, 100 * aupr_out),\n", " flush=True)\n", " print('CCR: {:.2f}, {:.2f}, {:.2f}, {:.2f},'.format(\n", " ccr_4 * 100, ccr_3 * 100, ccr_2 * 100, ccr_1 * 100),\n", " end=' ',\n", " flush=True)\n", " print('ACC: {:.2f}'.format(accuracy * 100), flush=True)\n", " print(u'\\u2500' * 70, flush=True) " ] }, { "cell_type": "code", "execution_count": 28, "id": "b16e6f5b", "metadata": {}, "outputs": [], "source": [ "from openood.evaluators.metrics import compute_all_metrics\n", "def eval_ood(postprocess_results):\n", " [id_pred, id_conf, id_gt] = postprocess_results['id']['test']\n", " split_types = ['nearood', 'farood']\n", "\n", " for split_type in split_types:\n", " metrics_list = []\n", " print(f\"Performing evaluation on {split_type} datasets...\")\n", " dataset_names = config['ood_dataset'][split_type].datasets\n", " \n", " for dataset_name in dataset_names:\n", " [ood_pred, ood_conf, ood_gt] = postprocess_results[split_type][dataset_name]\n", "\n", " pred = np.concatenate([id_pred, ood_pred])\n", " conf = np.concatenate([id_conf, ood_conf])\n", " label = np.concatenate([id_gt, ood_gt])\n", " print(f'Computing metrics on {dataset_name} dataset...')\n", "\n", " ood_metrics = compute_all_metrics(conf, label, pred)\n", " print_all_metrics(ood_metrics)\n", " metrics_list.append(ood_metrics)\n", " print('Computing mean metrics...', flush=True)\n", " metrics_list = np.array(metrics_list)\n", " metrics_mean = np.mean(metrics_list, axis=0) \n", " print_all_metrics(metrics_mean)\n", "\n", " " ] }, { "cell_type": "code", "execution_count": 29, "id": "335686fe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Performing evaluation on nearood datasets...\n", "Computing metrics on cifar100 dataset...\n", "FPR@95: 62.01, AUROC: 87.11 AUPR_IN: 85.93, AUPR_OUT: 85.25\n", "CCR: 0.36, 1.84, 11.31, 68.38, ACC: 95.34\n", "──────────────────────────────────────────────────────────────────────\n", "Computing metrics on tin dataset...\n", "FPR@95: 60.31, AUROC: 86.61 AUPR_IN: 73.79, AUPR_OUT: 91.74\n", "CCR: 0.14, 0.88, 7.78, 67.00, ACC: 95.34\n", "──────────────────────────────────────────────────────────────────────\n", "Computing mean metrics...\n", "FPR@95: 61.16, AUROC: 86.86 AUPR_IN: 79.86, AUPR_OUT: 88.50\n", "CCR: 0.25, 1.36, 9.54, 67.69, ACC: 95.34\n", "──────────────────────────────────────────────────────────────────────\n", "Performing evaluation on farood datasets...\n", "Computing metrics on mnist dataset...\n", "FPR@95: 58.56, AUROC: 89.92 AUPR_IN: 66.95, AUPR_OUT: 98.10\n", "CCR: 5.49, 11.13, 31.94, 77.40, ACC: 95.34\n", "──────────────────────────────────────────────────────────────────────\n", "Computing metrics on svhn dataset...\n", "FPR@95: 52.26, AUROC: 90.76 AUPR_IN: 77.86, AUPR_OUT: 95.65\n", "CCR: 0.08, 0.73, 17.60, 80.54, ACC: 95.34\n", "──────────────────────────────────────────────────────────────────────\n", "Computing metrics on texture dataset...\n", "FPR@95: 59.75, AUROC: 88.73 AUPR_IN: 91.28, AUPR_OUT: 80.60\n", "CCR: 0.09, 0.34, 9.69, 76.00, ACC: 95.34\n", "──────────────────────────────────────────────────────────────────────\n", "Computing metrics on place365 dataset...\n", "FPR@95: 58.70, AUROC: 88.03 AUPR_IN: 65.24, AUPR_OUT: 96.04\n", "CCR: 0.39, 2.11, 13.42, 71.66, ACC: 95.34\n", "──────────────────────────────────────────────────────────────────────\n", "Computing mean metrics...\n", "FPR@95: 57.32, AUROC: 89.36 AUPR_IN: 75.33, AUPR_OUT: 92.60\n", "CCR: 1.51, 3.58, 18.16, 76.40, ACC: 95.34\n", "──────────────────────────────────────────────────────────────────────\n" ] } ], "source": [ "eval_ood(postprocess_results)" ] }, { "cell_type": "code", "execution_count": null, "id": "f2ce6263", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "interpreter": { "hash": "9bfdd80c4cd5b3ca30f79c8858286326028dc154b9efddfc3ea147df9fc4c063" }, "kernelspec": { "display_name": "Python 3.8.12 64-bit ('ood': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }