├── .gitignore ├── README.md ├── __init__.py ├── demo.ipynb ├── engine_finetune.py ├── engine_pretrain.py ├── evaluate ├── __init__.py ├── evaluate_colorization.py ├── evaluate_reasoning.py ├── evaluate_segmentation.py ├── in_colorization_dataloader.py ├── mae_utils.py ├── pascal_dataloader.py ├── reasoning_dataloader.py ├── segmentation_utils.py └── splits │ ├── coco │ ├── trn │ │ ├── fold0.pkl │ │ ├── fold1.pkl │ │ ├── fold2.pkl │ │ └── fold3.pkl │ └── val │ │ ├── fold0.pkl │ │ ├── fold1.pkl │ │ ├── fold2.pkl │ │ └── fold3.pkl │ ├── fss │ ├── test.txt │ ├── trn.txt │ └── val.txt │ └── pascal │ ├── trn │ ├── fold0.txt │ ├── fold1.txt │ ├── fold2.txt │ └── fold3.txt │ └── val │ ├── fold0.txt │ ├── fold1.txt │ ├── fold2.txt │ └── fold3.txt ├── evaluate_detection ├── 2012_support_set.pth ├── 2012_val_flattened_set.pth ├── __init__.py ├── box_ops.py ├── canvas_ds.py ├── misc.py ├── transforms.py └── voc_orig.py ├── figures_dataset ├── __init__.py ├── df_train.csv ├── df_val.csv ├── download_links.py └── requirements.txt ├── main_pretrain.py ├── models_mae.py ├── models_vit.py ├── requirements.txt ├── tta.py ├── util ├── lr_sched.py ├── misc.py └── pos_embed.py ├── viz_utils.py └── vqgan.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/* 2 | output/* 3 | output_dir/* 4 | util/__pycache__/* 5 | evalaute/__pycache__/* 6 | slurm* 7 | wandb/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visual Prompting via Image Inpainting 2 | ### [Amir Bar*](https://amirbar.net), [Yossi Gandelsman*](https://yossi.gandelsman.com/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/), [Amir Globerson](http://www.cs.tau.ac.il/~gamir/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/) 3 | ![Visual Prompting](https://yossigandelsman.github.io/visual_prompt/images/teaser.png) 4 | 5 | This repository is the implementation of the paper, for more info about this work see [Project Page](https://yossigandelsman.github.io/visual_prompt/). 6 | You can experiment with visual prompting using this (demo)[demo.ipynb]. 7 | 8 | ## Abstract 9 | How does one adapt a pre-trained visual model to novel downstream tasks without task-specific finetuning or any model modification? Inspired by prompting in NLP, this paper investigates visual prompting: given input-output image example(s) of a new task at test time and a new input image, the goal is to automatically produce the correct output image, consistent with the example(s) task. We show that posing this problem as a simple image inpainting task - literally just filling in a hole in a concatenated visual prompt image - turns out to be surprisingly effective, given that the inpainting algorithm has been trained on the right data. We train masked auto-encoding models on a new dataset that we curated - 88k unlabeled figures from academic papers sources on Arxiv. We apply visual prompting to these pretrained models and demonstrate results on various downstream tasks, including foreground segmentation, single object detection, colorization, edge detection, etc. 
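Concretely, a visual prompt is just a 2x2 image grid: the example input-output pair occupies the top row, the query image sits in the bottom-left cell, and the bottom-right cell is left blank for the model to inpaint. A minimal sketch of this canvas construction (it mirrors `create_grid_from_images` in the evaluation dataloaders below; the cell size and padding are assumptions matching the evaluation transforms):

```
import torch

def build_prompt_canvas(support_img, support_mask, query_img, padding=1):
    # support_img / support_mask / query_img: [3, h, w] tensors (e.g. h = w = 111)
    c, h, w = support_img.shape
    canvas = torch.ones((c, 2 * h + 2 * padding, 2 * w + 2 * padding))
    canvas[:, :h, :w] = support_img      # top-left: example input
    canvas[:, :h, -w:] = support_mask    # top-right: example output
    canvas[:, -h:, :w] = query_img       # bottom-left: query input
    # bottom-right is left blank; the model inpaints the query output here
    return canvas
```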
10 | 11 | ## Computer Vision Figures Dataset 12 | To download the dataset, first install the dependencies: 13 | 14 | ``` 15 | cd figures_dataset 16 | sudo apt-get install poppler-utils 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | Then download the train/val splits: 21 | ``` 22 | python download_links.py --output_dir <output_dir> --split train 23 | python download_links.py --output_dir <output_dir> --split val 24 | ``` 25 | 26 | **Note**: the paper sources are hosted by arXiv, so the download might take 2-3 days.
For inquiries/questions, please email the authors directly. 27 | 28 | ## Train 29 | ### Prerequisites 30 | Install PyTorch and PyTorch Lightning. Set `cudatoolkit` to your CUDA version, or choose an installation using these [instructions](https://pytorch.org/get-started/previous-versions/#v18). 31 | ``` 32 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 pytorch-lightning==1.6.2 -c pytorch -c conda-forge 33 | ``` 34 | 35 | Then install the following requirements: 36 | ``` 37 | pip install -r requirements.txt 38 | ``` 39 | Download the pretrained VQGAN codebook checkpoint and config [_vqgan_imagenet_f16_1024_](https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/?p=%2F), and place both _last.ckpt_ and _model.yaml_ in the repository root. 40 | 41 | ### Pretrain a model on the CVF dataset with 8 V100 GPUs: 42 | ``` 43 | python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py --model mae_vit_large_patch16 --input_size 224 --batch_size 64 --mask_ratio 0.75 --warmup_epochs 15 --epochs 1000 --blr 1e-4 --save_ckpt_freq 100 --output_dir logs_dir/maevqgan --data_path <cvf_dataset_dir> 44 | ``` 45 | 46 | ## Evaluation 47 | 48 | ### Dataset preparation: 49 | 50 | Our evaluation pipeline is based on [Volumetric Aggregation Transformer](https://github.com/Seokju-Cho/Volumetric-Aggregation-Transformer). Please follow the dataset preparation steps for the PASCAL-5i dataset in that repository. 51 | 52 | ### Evaluate on Foreground Segmentation on PASCAL-5i, split [0-3]: 53 | ``` 54 | cd evaluate && python evaluate_segmentation.py \ 55 | --model mae_vit_large_patch16 \ 56 | --base_dir <pascal_base_dir> \ 57 | --output_dir <output_dir> \ 58 | --ckpt <ckpt_path> \ 59 | --split <split> \ 60 | --dataset_type pascal 61 | ``` 62 | The script saves a log.txt file with the results, as well as result visualizations. 63 | 64 | ### Evaluate on Reasoning Tasks: 65 | Set `dataset_type` to one of the reasoning tasks: 'color', 'shape', 'size', 'shape_color', 'size_color', or 'size_shape'. 66 | 67 | ``` 68 | python -m evaluate.evaluate_reasoning \ 69 | --model mae_vit_large_patch16 \ 70 | --output_dir <output_dir> \ 71 | --ckpt <ckpt_path> \ 72 | --dataset_type color \ 73 | --tta_option 0 74 | ``` 75 | 76 | `tta_option` selects the prompt-ensembling configuration; `tta_option=0` is the standard visual prompt. Other configurations are listed in `visual_prompting/evaluate/evaluate_reasoning.py:42`. 77 | The script saves a log.txt file with the results, as well as result visualizations. 78 | 79 | ### Evaluate on Single Object Detection: 80 | ``` 81 | cd evaluate && python evaluate_segmentation.py \ 82 | --task detection \ 83 | --model mae_vit_large_patch16 \ 84 | --base_dir <pascal_base_dir> \ 85 | --output_dir <output_dir> \ 86 | --ckpt <ckpt_path> \ 87 | --dataset_type pascal_det 88 | ``` 89 | The script saves a log.txt file with the results, as well as result visualizations. 90 | 91 | ### Evaluate on Colorization: 92 | ``` 93 | python -m evaluate.evaluate_colorization \ 94 | --model mae_vit_large_patch16 \ 95 | --output_dir <output_dir> \ 96 | --ckpt <ckpt_path> \ 97 | --data_path <imagenet_dir> 98 | ``` 99 | The script saves a log.txt file with the results, as well as result visualizations.
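For reference, the per-example core of all these evaluation scripts is the same few calls into `evaluate/mae_utils.py`; a condensed sketch (the checkpoint path is a placeholder, and the random tensor stands in for a normalized prompt canvas produced by one of the dataloaders):

```
import torch
from evaluate.mae_utils import prepare_model, generate_mask_for_evaluation, generate_image

device = 'cuda'
model = prepare_model('<ckpt_path>', arch='mae_vit_large_patch16').to(device)

canvas = torch.randn(3, 224, 224)  # stand-in for a normalized 2x2 prompt grid
ids_shuffle, len_keep = generate_mask_for_evaluation()  # keep all patches except the hole
_, im_paste, _ = generate_image(canvas.unsqueeze(0).to(device), model,
                                ids_shuffle.to(device), len_keep, device=device)
# im_paste: the full canvas with the model's prediction pasted into the masked quadrant
```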
100 | 101 | 102 | # Pretrained Models 103 | | Model | Pretraining | Epochs | Link | 104 | |-------------------|-------------|--------|------| 105 | | MAE-VQGAN (ViT-L) | CVF | 1000 | [link](https://drive.google.com/file/d/1Xe0-cypS4dcwqbPuT8wflqj0b1E9Ct7E/view?usp=sharing) | 106 | | MAE-VQGAN (ViT-L) | CVF + IN | 3400 | [link](https://drive.google.com/file/d/130vNSlqg3faHzVGGh_vkeUUh-2uVX3se/view?usp=sharing) | 107 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/__init__.py -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import torch 17 | 18 | from timm.data import Mixup 19 | from timm.utils import accuracy 20 | 21 | import util.misc as misc 22 | import util.lr_sched as lr_sched 23 | 24 | 25 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 26 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 27 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 28 | mixup_fn: Optional[Mixup] = None, log_writer=None, 29 | args=None): 30 | model.train(True) 31 | metric_logger = misc.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 20 35 | 36 | accum_iter = args.accum_iter 37 | 38 | optimizer.zero_grad() 39 | 40 | if log_writer is not None: 41 | print('log_dir: {}'.format(log_writer.log_dir)) 42 | 43 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 44 | 45 | # we use a per iteration (instead of per epoch) lr scheduler 46 | if data_iter_step % accum_iter == 0: 47 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 48 | 49 | samples = samples.to(device, non_blocking=True) 50 | targets = targets.to(device, non_blocking=True) 51 | 52 | if mixup_fn is not None: 53 | samples, targets = mixup_fn(samples, targets) 54 | 55 | with torch.cuda.amp.autocast(): 56 | outputs = model(samples) 57 | loss = criterion(outputs, targets) 58 | 59 | loss_value = loss.item() 60 | 61 | if not math.isfinite(loss_value): 62 | print("Loss is {}, stopping training".format(loss_value)) 63 | sys.exit(1) 64 | 65 | loss /= accum_iter 66 | loss_scaler(loss, optimizer, clip_grad=max_norm, 67 | parameters=model.parameters(), create_graph=False, 68 | update_grad=(data_iter_step + 1) % accum_iter == 0) 69 | if (data_iter_step + 1) % accum_iter == 0: 70 | optimizer.zero_grad() 71 | 72 | torch.cuda.synchronize() 73 | 74 | metric_logger.update(loss=loss_value) 75 | min_lr = 10. 76 | max_lr = 0. 
77 | for group in optimizer.param_groups: 78 | min_lr = min(min_lr, group["lr"]) 79 | max_lr = max(max_lr, group["lr"]) 80 | 81 | metric_logger.update(lr=max_lr) 82 | 83 | loss_value_reduce = misc.all_reduce_mean(loss_value) 84 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 85 | """ We use epoch_1000x as the x-axis in tensorboard. 86 | This calibrates different curves when batch size changes. 87 | """ 88 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 89 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 90 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 91 | 92 | # gather the stats from all processes 93 | metric_logger.synchronize_between_processes() 94 | print("Averaged stats:", metric_logger) 95 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 96 | 97 | 98 | @torch.no_grad() 99 | def evaluate(data_loader, model, device): 100 | criterion = torch.nn.CrossEntropyLoss() 101 | 102 | metric_logger = misc.MetricLogger(delimiter=" ") 103 | header = 'Test:' 104 | 105 | # switch to evaluation mode 106 | model.eval() 107 | 108 | for batch in metric_logger.log_every(data_loader, 10, header): 109 | images = batch[0] 110 | target = batch[-1] 111 | images = images.to(device, non_blocking=True) 112 | target = target.to(device, non_blocking=True) 113 | 114 | # compute output 115 | with torch.cuda.amp.autocast(): 116 | output = model(images) 117 | loss = criterion(output, target) 118 | 119 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 120 | 121 | batch_size = images.shape[0] 122 | metric_logger.update(loss=loss.item()) 123 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 124 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 125 | # gather the stats from all processes 126 | metric_logger.synchronize_between_processes() 127 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 128 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 129 | 130 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import math 12 | import sys 13 | from typing import Iterable 14 | 15 | import torch 16 | 17 | import util.misc as misc 18 | import util.lr_sched as lr_sched 19 | 20 | 21 | def train_one_epoch(model: torch.nn.Module, 22 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 23 | device: torch.device, epoch: int, loss_scaler, 24 | log_writer=None, 25 | args=None, 26 | epoch_size=1): 27 | model.train(True) 28 | metric_logger = misc.MetricLogger(delimiter=" ") 29 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 30 | header = 'Epoch: [{}]'.format(epoch) 31 | print_freq = 20 32 | 33 | accum_iter = args.accum_iter 34 | 35 | optimizer.zero_grad() 36 | 37 | if log_writer is not None: 38 | print('log_dir: {}'.format(log_writer.log_dir)) 39 | data_loader_i = iter(data_loader) 40 | for data_iter_step in metric_logger.log_every(range(epoch_size), print_freq, header): 41 | (batch, _) = next(data_loader_i) 42 | # we use a per iteration (instead of per epoch) lr scheduler 43 | if isinstance(batch, tuple): 44 | samples, visual_tokens = batch 45 | samples = samples.to(device, non_blocking=True) 46 | visual_tokens = visual_tokens.to(device, non_blocking=True) 47 | else: # hack for consistency 48 | samples = batch 49 | samples = samples.to(device, non_blocking=True) 50 | visual_tokens = samples 51 | 52 | if data_iter_step % accum_iter == 0: 53 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 54 | 55 | with torch.cuda.amp.autocast(): 56 | loss_dict, _, _ = model(samples, visual_tokens, mask_ratio=args.mask_ratio) 57 | 58 | loss = torch.stack([loss_dict[l] for l in loss_dict if 'unscaled' not in l]).sum() 59 | loss_value = loss.item() 60 | 61 | if not math.isfinite(loss_value): 62 | print("Loss is {}, stopping training".format(loss_value)) 63 | sys.exit(1) 64 | 65 | loss /= accum_iter 66 | loss_scaler(loss, optimizer, parameters=model.parameters(), 67 | update_grad=(data_iter_step + 1) % accum_iter == 0) 68 | if (data_iter_step + 1) % accum_iter == 0: 69 | optimizer.zero_grad() 70 | 71 | torch.cuda.synchronize() 72 | 73 | metric_logger.update(**{k: v.item() for k, v in loss_dict.items()}) 74 | 75 | lr = optimizer.param_groups[0]["lr"] 76 | metric_logger.update(lr=lr) 77 | 78 | loss_value_reduce = misc.all_reduce_mean(loss_value) 79 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 80 | """ We use epoch_1000x as the x-axis in tensorboard. 81 | This calibrates different curves when batch size changes. 
82 | """ 83 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 84 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 85 | log_writer.add_scalar('lr', lr, epoch_1000x) 86 | 87 | # gather the stats from all processes 88 | metric_logger.synchronize_between_processes() 89 | print("Averaged stats:", metric_logger) 90 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 91 | 92 | 93 | @torch.no_grad() 94 | def validate(model, data_loader, device, epoch, log_writer, args): 95 | model.eval() 96 | metric_logger = misc.MetricLogger(delimiter=" ") 97 | header = 'Epoch: [{}]'.format(epoch) 98 | print_freq = 50 99 | if log_writer is not None: 100 | print('log_dir: {}'.format(log_writer.log_dir)) 101 | 102 | for data_iter_step, (batch, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 103 | samples, visual_tokens = batch 104 | samples = samples.to(device, non_blocking=True) 105 | visual_tokens = visual_tokens.to(device, non_blocking=True) 106 | 107 | with torch.cuda.amp.autocast(): 108 | loss_dict, _, _ = model(samples, visual_tokens, mask_ratio=args.mask_ratio) 109 | 110 | loss = torch.stack([loss_dict[l] for l in loss_dict if 'unscaled' not in l]).sum() 111 | loss_value = loss.item() 112 | 113 | if not math.isfinite(loss_value): 114 | print("Loss is {}, stopping training".format(loss_value)) 115 | sys.exit(1) 116 | 117 | metric_logger.update(**{k: v.item() for k, v in loss_dict.items()}) 118 | 119 | # gather the stats from all processes 120 | metric_logger.synchronize_between_processes() 121 | print("Averaged stats for val:", metric_logger) 122 | return {'val_' + k: meter.global_avg for k, meter in metric_logger.meters.items()} 123 | -------------------------------------------------------------------------------- /evaluate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate/__init__.py -------------------------------------------------------------------------------- /evaluate/evaluate_colorization.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | cwd = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(os.path.dirname(cwd)) 5 | import os.path 6 | 7 | import torchvision 8 | from tqdm import trange 9 | from evaluate.in_colorization_dataloader import DatasetColorization 10 | from evaluate.reasoning_dataloader import * 11 | from evaluate.mae_utils import * 12 | import argparse 13 | from pathlib import Path 14 | 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 18 | parser.add_argument('--model', default='mae_vit_small_patch16', type=str, metavar='MODEL', 19 | help='Name of model to train') 20 | parser.add_argument('--output_dir', default='../output_dir/') 21 | parser.add_argument('--data_path') 22 | parser.add_argument('--device', default='cuda', 23 | help='device to use for training / testing') 24 | parser.add_argument('--seed', default=0, type=int) 25 | parser.add_argument('--tta_option', default=0, type=int) 26 | parser.add_argument('--ckpt', help='resume from checkpoint') 27 | 28 | parser.set_defaults(autoregressive=False) 29 | return parser 30 | 31 | 32 | def _generate_result_for_canvas(args, model, canvas): 33 | """canvas is already in the right range.""" 34 | ids_shuffle, len_keep = generate_mask_for_evaluation() 35 | _, im_paste, _ = 
generate_image(canvas.unsqueeze(0).to(args.device), model, ids_shuffle.to(args.device), 36 | len_keep, device=args.device) 37 | canvas = torch.einsum('chw->hwc', canvas) 38 | canvas = torch.clip((canvas.cpu().detach() * imagenet_std + imagenet_mean) * 255, 0, 255).int().numpy() 39 | assert canvas.shape == im_paste.shape, (canvas.shape, im_paste.shape) 40 | return np.uint8(canvas), np.uint8(im_paste) 41 | 42 | 43 | def calculate_metric(args, target, ours): 44 | ours = (np.transpose(ours/255., [2, 0, 1]) - imagenet_mean[:, None, None]) / imagenet_std[:, None, None] 45 | target = (np.transpose(target/255., [2, 0, 1]) - imagenet_mean[:, None, None]) / imagenet_std[:, None, None] 46 | 47 | target = target[:, 113:, 113:] 48 | ours = ours[:, 113:, 113:] 49 | mse = np.mean((target - ours)**2) 50 | return {'mse': mse} 51 | 52 | 53 | def evaluate(args): 54 | with open(os.path.join(args.output_dir, 'log.txt'), 'w') as log: 55 | log.write(str(args) + '\n') 56 | 57 | model = prepare_model(args.ckpt, arch=args.model) 58 | _ = model.to(args.device) 59 | # Build the transforms: 60 | padding = 1 61 | 62 | image_transform = torchvision.transforms.Compose( 63 | [torchvision.transforms.CenterCrop((224 // 2 - padding, 224 // 2 - padding)), 64 | torchvision.transforms.ToTensor()]) 65 | mask_transform = torchvision.transforms.Compose( 66 | [torchvision.transforms.CenterCrop((224 // 2 - padding, 224 // 2 - padding)), 67 | torchvision.transforms.Grayscale(3), 68 | torchvision.transforms.ToTensor()]) 69 | 70 | ds = DatasetColorization(args.data_path, image_transform, mask_transform) 71 | 72 | eval_dict = {'mse': 0.} 73 | 74 | for idx in trange(len(ds)): 75 | canvas = ds[idx]['grid'] 76 | canvas = (canvas - imagenet_mean[:, None, None]) / imagenet_std[:, None, None] 77 | original_image, generated_result = _generate_result_for_canvas(args, model, canvas) 78 | 79 | if args.output_dir: 80 | Image.fromarray(np.uint8(original_image)).save( 81 | os.path.join(args.output_dir, f'original_{idx}.png')) 82 | Image.fromarray(np.uint8(generated_result)).save( 83 | os.path.join(args.output_dir, f'generated_{idx}.png')) 84 | 85 | if args.output_dir: 86 | Image.fromarray(np.uint8(generated_result)).save( 87 | os.path.join(args.output_dir, f'generated_before_rounding_{idx}.png')) 88 | Image.fromarray(np.uint8(generated_result)).save( 89 | os.path.join(args.output_dir, f'generated_rounded_{idx}.png')) 90 | Image.fromarray(np.uint8(original_image)).save( 91 | os.path.join(args.output_dir, f'original_{idx}.png')) 92 | Image.fromarray(np.uint8(generated_result)).save( 93 | os.path.join(args.output_dir, f'generated_fixed_{idx}.png')) 94 | 95 | current_metric = calculate_metric(args, original_image, generated_result) 96 | with open(os.path.join(args.output_dir, 'log.txt'), 'a') as log: 97 | log.write(str(idx) + '\t' + str(current_metric) + '\n') 98 | for i, j in current_metric.items(): 99 | eval_dict[i] += (j / len(ds)) 100 | 101 | with open(os.path.join(args.output_dir, 'log.txt'), 'a') as log: 102 | log.write('all\t' + str(eval_dict) + '\n') 103 | 104 | 105 | if __name__ == '__main__': 106 | args = get_args() 107 | 108 | args = args.parse_args() 109 | seed = args.seed 110 | torch.manual_seed(seed) 111 | np.random.seed(seed) 112 | if args.output_dir: 113 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 114 | evaluate(args) 115 | -------------------------------------------------------------------------------- /evaluate/evaluate_reasoning.py: -------------------------------------------------------------------------------- 1 | 
import sys 2 | import os 3 | cwd = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(os.path.dirname(cwd)) 5 | import os.path 6 | from tqdm import trange 7 | from evaluate.reasoning_dataloader import * 8 | import cv2 9 | from evaluate.mae_utils import * 10 | import argparse 11 | from pathlib import Path 12 | from tta import TTA, reverse_trans 13 | 14 | 15 | def get_args(): 16 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 17 | parser.add_argument('--model', default='mae_vit_small_patch16', type=str, metavar='MODEL', 18 | help='Name of model to train') 19 | parser.add_argument('--output_dir', default='../output_dir/') 20 | parser.add_argument('--device', default='cuda', 21 | help='device to use for training / testing') 22 | parser.add_argument('--seed', default=0, type=int) 23 | parser.add_argument('--tta_option', default=0, type=int) 24 | parser.add_argument('--ckpt', help='resume from checkpoint') 25 | parser.add_argument('--dataset_type', default='color') 26 | return parser 27 | 28 | 29 | def get_default_mask_2rows_mask(): 30 | mask = np.zeros((14,14)) 31 | mask[:9] = 1 32 | mask[:, :7] = 1 33 | mask[: ,12:] = 1 34 | return mask 35 | 36 | 37 | def _generate_result_for_canvas(args, model, inpt_pairs): 38 | """canvas is already in the right range.""" 39 | 40 | final_imgs = [] 41 | rcs_ls = [ 42 | [TTA(shuffle_rows=False, shuffle_cols=False, transpose=False)], 43 | [TTA(shuffle_rows=False, shuffle_cols=True, transpose=True)], 44 | [TTA(shuffle_rows=False, shuffle_cols=True, transpose=True), TTA(shuffle_rows=False, shuffle_cols=False, transpose=False)], 45 | [TTA(shuffle_rows=False, shuffle_cols=True, transpose=True), TTA(shuffle_rows=False, shuffle_cols=False, transpose=True), TTA(shuffle_rows=False, shuffle_cols=False, transpose=False)], 46 | [TTA(shuffle_rows=False, shuffle_cols=False, transpose=True)] 47 | ][args.tta_option] 48 | for i in range(len(rcs_ls)): 49 | rcs = rcs_ls[i] 50 | canvas, len_keep_ps, ids_shuffle_ps, psuedo_gt_mask, v_order, shuffle_cols, transpose = rcs( 51 | inpt_pairs) 52 | input_image = torch.tensor(canvas).unsqueeze(0).to(args.device) 53 | _, im_paste, _ = generate_image(input_image, model, ids_shuffle_ps.to(args.device), len_keep_ps, device=args.device) 54 | im_paste = reverse_trans(im_paste, v_order, shuffle_cols, transpose) 55 | final_imgs.append(im_paste) 56 | if len(final_imgs) > 1: 57 | im_paste = np.mean(final_imgs, axis=0) 58 | else: 59 | im_paste = final_imgs[0] 60 | rcs = TTA(shuffle_rows=False, shuffle_cols=False, transpose=False) 61 | canvas, _, _, _, _, _, _ = rcs(inpt_pairs) 62 | input_image = torch.tensor(canvas).unsqueeze(0).to(args.device) 63 | canvas = torch.einsum('chw->hwc', input_image[0]) 64 | canvas = torch.clip((canvas.cpu().detach() * imagenet_std + imagenet_mean) * 255, 0, 255).int().numpy() 65 | assert canvas.shape == im_paste.shape, (canvas.shape, im_paste.shape) 66 | return np.uint8(canvas), np.uint8(im_paste) 67 | 68 | 69 | def is_square(mask): 70 | mask = np.uint8(mask) 71 | contours,_ = cv2.findContours(mask.copy(), 1, 1) # not copying here will throw an error 72 | if not contours: 73 | return None 74 | x,y,w,h = cv2.boundingRect(contours[0]) 75 | radius = max(h,w) // 2 76 | center_x = x + w/2 77 | center_y = y + h/2 78 | circle_mask = np.zeros_like(mask) 79 | circle_mask = cv2.circle(circle_mask, (int(center_x),int(center_y)), radius, 1, -1) 80 | circle_mask = circle_mask > 0 81 | square_mask = np.zeros_like(mask) 82 | square_mask[int(center_y)-radius: int(center_y)+radius, 
int(center_x)-radius:int(center_x)+radius] = 1 83 | square_mask = square_mask > 0 84 | mask = mask > 0 85 | square_shape = np.sum(np.float32(square_mask & mask)) / np.sum(np.float32(square_mask | mask)) 86 | circle_shape = np.sum(np.float32(circle_mask & mask)) / np.sum(np.float32(circle_mask | mask)) 87 | return square_shape > circle_shape 88 | 89 | def calculate_metric(args, target, ours): 90 | # Crop the right area: 91 | target = target[-74:, 113: 113+74] 92 | ours = ours[-74:, 113: 113+74] 93 | # Calculate accuracy: 94 | accuracy = np.sum(np.float32((target == ours).all(axis=2))) / (ours.shape[0] * ours.shape[1]) 95 | colors = np.unique(np.reshape(target, (-1, 3)), axis=0) 96 | assert colors.shape[0] == 2, colors # white and the expected color 97 | other_color = colors[0] if np.all(colors[1] == np.array([255,255,255])) else colors[1] 98 | seg_orig = ((target - other_color[np.newaxis, np.newaxis,:]) == 0).all(axis=2) 99 | seg_our = ((ours - other_color[np.newaxis, np.newaxis,:]) == 0).all(axis=2) 100 | color_blind_seg_our = (ours - np.array([[[255,255,255]]]) != 0).any(axis=2) 101 | iou = np.sum(np.float32(seg_orig & seg_our)) / np.sum(np.float32(seg_orig | seg_our)) 102 | color_blind_iou = np.sum(np.float32(seg_orig & color_blind_seg_our)) / np.sum(np.float32(seg_orig | color_blind_seg_our)) 103 | shape_accuracy = is_square(color_blind_seg_our) 104 | d = { 105 | 'iou': iou, 106 | 'color_blind_iou': color_blind_iou, 107 | 'accuracy': accuracy, 108 | } 109 | if shape_accuracy is not None: 110 | d['shape_accuracy'] = shape_accuracy 111 | return d 112 | 113 | def evaluate(args): 114 | with open(os.path.join(args.output_dir, 'log.txt'), 'w') as log: 115 | log.write(str(args)+'\n') 116 | 117 | ds = { 118 | 'size': SizeChangeTask, 119 | 'size_color': ChangeSizeColorTask, 120 | 'size_shape': ChangeSizeShapeTask, 121 | 'color': ColorChangeTask, 122 | 'shape': ShapeChangeTask, 123 | 'shape_color': ChangeShapeColorTask, 124 | }[args.dataset_type]() 125 | model = prepare_model(args.ckpt, arch=args.model) 126 | _ = model.to(args.device) 127 | # Build the transforms: 128 | figure_size = 74 129 | transforms = T.Compose([ 130 | T.Resize((figure_size, figure_size)), 131 | T.ToTensor(), 132 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 133 | num_support = 2 134 | eval_dict = {'iou': 0, 'color_blind_iou': 0, 'accuracy': 0, 'shape_accuracy': 0} 135 | for idx in trange(len(ds)): 136 | query_image, query_target = ds[idx] 137 | pairs = [] 138 | 139 | for k in range(num_support): 140 | 141 | idx2 = np.random.choice(np.arange(len(ds))) 142 | support_image, support_target = ds[idx2] 143 | pairs.append((support_image, support_target)) 144 | 145 | pairs.append((query_image, query_target)) 146 | 147 | inpt_pairs = [] 148 | for p in pairs: 149 | support_image, support_target = p 150 | support_image_ten = transforms(Image.fromarray(support_image)) 151 | support_target_ten = transforms(Image.fromarray(support_target)) 152 | inpt_pairs.append((support_image_ten, support_target_ten)) 153 | 154 | # Calculate the original_image and the result 155 | original_image, generated_result = _generate_result_for_canvas(args, model, inpt_pairs) 156 | original_image = round_image(original_image, ds.color_options() + [BLACK]) 157 | generated_result = round_image(generated_result, ds.color_options()+ [BLACK]) 158 | if args.output_dir: 159 | Image.fromarray(np.uint8(original_image)).save( 160 | os.path.join(args.output_dir, f'original_{idx}.png')) 161 | Image.fromarray(np.uint8(generated_result)).save( 162 | 
os.path.join(args.output_dir, f'generated_{idx}.png')) 163 | current_metric = calculate_metric(args, original_image, generated_result) 164 | with open(os.path.join(args.output_dir, 'log.txt'), 'a') as log: 165 | log.write(str(idx)+'\t'+str(current_metric)+'\n') 166 | for i, j in current_metric.items(): 167 | eval_dict[i] += (j / len(ds)) 168 | 169 | with open(os.path.join(args.output_dir, 'log.txt'), 'a') as log: 170 | log.write('all\t'+str(eval_dict)+'\n') 171 | 172 | 173 | if __name__ == '__main__': 174 | args = get_args() 175 | 176 | args = args.parse_args() 177 | seed = args.seed 178 | torch.manual_seed(seed) 179 | np.random.seed(seed) 180 | if args.output_dir: 181 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 182 | evaluate(args) 183 | -------------------------------------------------------------------------------- /evaluate/evaluate_segmentation.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from tqdm import trange 3 | import pascal_dataloader 4 | from evaluate_detection.box_ops import to_rectangle 5 | from evaluate_detection.canvas_ds import CanvasDataset 6 | from reasoning_dataloader import * 7 | import torchvision 8 | from mae_utils import * 9 | import argparse 10 | from pathlib import Path 11 | from segmentation_utils import * 12 | 13 | 14 | def get_args(): 15 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 16 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 17 | help='Name of model to train') 18 | parser.add_argument('--output_dir', default='../output_dir/') 19 | parser.add_argument('--device', default='cuda', 20 | help='device to use for training / testing') 21 | parser.add_argument('--base_dir', default='/shared/yossi_gandelsman/code/occlusionwalk/pascal', help='pascal base dir') 22 | parser.add_argument('--seed', default=0, type=int) 23 | parser.add_argument('--t', default=[0, 0, 0], type=float, nargs='+') 24 | parser.add_argument('--task', default='segmentation', choices=['segmentation', 'detection']) 25 | parser.add_argument('--ckpt', help='model checkpoint') 26 | parser.add_argument('--dataset_type', default='pascal', 27 | choices=['pascal', 'pascal_det']) 28 | parser.add_argument('--split', default=0, type=int) 29 | parser.add_argument('--purple', default=0, type=int) 30 | parser.add_argument('--flip', default=0, type=int) 31 | return parser 32 | 33 | 34 | def _generate_result_for_ens(args, model, canvases, method='sum'): 35 | ids_shuffle, len_keep = generate_mask_for_evaluation() 36 | num_patches = 14 37 | if method == 'sum': 38 | canvas, canvas2 = canvases[0], canvases[1] 39 | mask, orig_image, x1 = generate_raw_prediction(args.device, ids_shuffle, len_keep, model, canvas.unsqueeze(0)) 40 | _, _, x2 = generate_raw_prediction(args.device, ids_shuffle, len_keep, model, canvas2.unsqueeze(0)) 41 | x1 = torch.softmax(x1, dim=-1) 42 | x2 = torch.softmax(x2, dim=-1) 43 | y = ((x1 + x2)/2).argmax(dim=-1) 44 | im_paste, mask, orig_image = decode_raw_predicion(mask, model, num_patches, orig_image, y) 45 | 46 | elif method == 'sum_pre': 47 | canvas, canvas2 = canvases[0], canvases[1] 48 | mask, orig_image, x1 = generate_raw_prediction(args.device, ids_shuffle, len_keep, model, canvas.unsqueeze(0)) 49 | _, _, x2 = generate_raw_prediction(args.device, ids_shuffle, len_keep, model, canvas2.unsqueeze(0)) 50 | y = ((x1 + x2) / 2).argmax(dim=-1) 51 | im_paste, mask, orig_image = decode_raw_predicion(mask, model, num_patches, orig_image, y) 52 | 53 
| elif method == 'mult': 54 | canvas, canvas2 = canvases[0], canvases[1] 55 | mask, orig_image, x1 = generate_raw_prediction(args.device, ids_shuffle, len_keep, model, canvas.unsqueeze(0)) 56 | _, _, x2 = generate_raw_prediction(args.device, ids_shuffle, len_keep, model, canvas2.unsqueeze(0)) 57 | x1 = torch.softmax(x1, dim=-1) 58 | x2 = torch.softmax(x2, dim=-1) 59 | y = (x1 * x2).argmax(dim=-1) 60 | im_paste, mask, orig_image = decode_raw_predicion(mask, model, num_patches, orig_image, y) 61 | 62 | elif method == 'max': 63 | canvas, canvas2 = canvases[0], canvases[1] 64 | mask, orig_image, x1 = generate_raw_prediction(args.device, ids_shuffle, len_keep, model, canvas.unsqueeze(0)) 65 | _, _, x2 = generate_raw_prediction(args.device, ids_shuffle, len_keep, model, canvas2.unsqueeze(0)) 66 | x1 = torch.softmax(x1, dim=-1) 67 | x2 = torch.softmax(x2, dim=-1) 68 | y = torch.argmax(torch.max(torch.stack([x1,x2], dim=-1), dim=-1)[0], dim=-1) 69 | im_paste, mask, orig_image = decode_raw_predicion(mask, model, num_patches, orig_image, y) 70 | 71 | # elif method == 'union': 72 | # canvas, canvas2 = canvases[0], canvases[1] 73 | # mask, orig_image, x1 = generate_raw_prediction(args.device, ids_shuffle, len_keep, model, canvas.unsqueeze(0)) 74 | # _, _, x2 = generate_raw_prediction(args.device, ids_shuffle, len_keep, model, canvas2.unsqueeze(0)) 75 | # y1 = torch.argmax(x1, dim=-1) 76 | # y2 = torch.argmax(x2, dim=-1) 77 | # im_paste1, mask, orig_image = decode_raw_predicion(mask, model, num_patches, orig_image, y1) 78 | # im_paste2, mask, orig_image = decode_raw_predicion(mask, model, num_patches, orig_image, y2) 79 | # 80 | # 81 | else: 82 | raise ValueError("Wrong ens") 83 | canvas = torch.einsum('chw->hwc', canvas) 84 | canvas = torch.clip((canvas.cpu().detach() * imagenet_std + imagenet_mean) * 255, 0, 255).int().numpy() 85 | 86 | return np.uint8(canvas), np.uint8(im_paste[0]), mask 87 | 88 | 89 | def _generate_result_for_canvas(args, model, canvas): 90 | """canvas is already in the right range.""" 91 | ids_shuffle, len_keep = generate_mask_for_evaluation() 92 | _, im_paste, _ = generate_image(canvas.unsqueeze(0).to(args.device), model, ids_shuffle.to(args.device), 93 | len_keep, device=args.device) 94 | canvas = torch.einsum('chw->hwc', canvas) 95 | canvas = torch.clip((canvas.cpu().detach() * imagenet_std + imagenet_mean) * 255, 0, 255).int().numpy() 96 | assert canvas.shape == im_paste.shape, (canvas.shape, im_paste.shape) 97 | return np.uint8(canvas), np.uint8(im_paste) 98 | 99 | 100 | def evaluate(args): 101 | with open(os.path.join(args.output_dir, 'log.txt'), 'w') as log: 102 | log.write(str(args) + '\n') 103 | padding = 1 104 | image_transform = torchvision.transforms.Compose( 105 | [torchvision.transforms.Resize((224 // 2 - padding, 224 // 2 - padding), 3), 106 | torchvision.transforms.ToTensor()]) 107 | mask_transform = torchvision.transforms.Compose( 108 | [torchvision.transforms.Resize((224 // 2 - padding, 224 // 2 - padding), 3), 109 | torchvision.transforms.ToTensor()]) 110 | ds = { 111 | 'pascal': pascal_dataloader.DatasetPASCAL, 112 | 'pascal_det': CanvasDataset 113 | }[args.dataset_type](args.base_dir, fold=args.split, image_transform=image_transform, mask_transform=mask_transform, 114 | flipped_order=args.flip, purple=args.purple) 115 | model = prepare_model(args.ckpt, arch=args.model) 116 | _ = model.to(args.device) 117 | # Build the transforms: 118 | eval_dict = {'iou': 0, 'color_blind_iou': 0, 'accuracy': 0} 119 | for idx in trange(len(ds)): 120 | canvas = ds[idx]['grid'] 
121 | if args.dataset_type != 'pascal_det': 122 | canvas = (canvas - imagenet_mean[:, None, None]) / imagenet_std[:, None, None] 123 | # Calculate the original_image and the result 124 | original_image, generated_result = _generate_result_for_canvas(args, model, canvas) 125 | if args.output_dir: 126 | Image.fromarray(np.uint8(original_image)).save( 127 | os.path.join(args.output_dir, f'original_{idx}.png')) 128 | Image.fromarray(np.uint8(generated_result)).save( 129 | os.path.join(args.output_dir, f'generated_{idx}.png')) 130 | if args.purple: 131 | original_image = round_image(original_image, [YELLOW, PURPLE]) 132 | else: 133 | original_image = round_image(original_image, [WHITE, BLACK]) 134 | 135 | if args.output_dir: 136 | Image.fromarray(np.uint8(generated_result)).save( 137 | os.path.join(args.output_dir, f'generated_before_rounding_{idx}.png')) 138 | 139 | if args.purple: 140 | generated_result = round_image(generated_result, [YELLOW, PURPLE], t=args.t) 141 | else: 142 | generated_result = round_image(generated_result, [WHITE, BLACK], t=args.t) 143 | 144 | if args.output_dir: 145 | Image.fromarray(np.uint8(generated_result)).save( 146 | os.path.join(args.output_dir, f'generated_rounded_{idx}.png')) 147 | 148 | if args.task == 'detection': 149 | generated_result = to_rectangle(generated_result) 150 | 151 | if args.output_dir: 152 | Image.fromarray(np.uint8(original_image)).save( 153 | os.path.join(args.output_dir, f'original_{idx}.png')) 154 | Image.fromarray(np.uint8(generated_result)).save( 155 | os.path.join(args.output_dir, f'generated_fixed_{idx}.png')) 156 | if args.purple: 157 | current_metric = calculate_metric(args, original_image, generated_result, fg_color=YELLOW, bg_color=PURPLE) 158 | else: 159 | current_metric = calculate_metric(args, original_image, generated_result, fg_color=WHITE, bg_color=BLACK) 160 | with open(os.path.join(args.output_dir, 'log.txt'), 'a') as log: 161 | log.write(str(idx) + '\t' + str(current_metric) + '\n') 162 | for i, j in current_metric.items(): 163 | eval_dict[i] += (j / len(ds)) 164 | 165 | with open(os.path.join(args.output_dir, 'log.txt'), 'a') as log: 166 | log.write('all\t' + str(eval_dict) + '\n') 167 | 168 | 169 | if __name__ == '__main__': 170 | args = get_args() 171 | 172 | args = args.parse_args() 173 | seed = args.seed 174 | torch.manual_seed(seed) 175 | np.random.seed(seed) 176 | if args.output_dir: 177 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 178 | evaluate(args) 179 | -------------------------------------------------------------------------------- /evaluate/in_colorization_dataloader.py: -------------------------------------------------------------------------------- 1 | """Based on https://github.com/Seokju-Cho/Volumetric-Aggregation-Transformer/blob/main/data/pascal.py 2 | """ 3 | import os 4 | from PIL import Image 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchvision.datasets import ImageFolder 9 | 10 | 11 | class DatasetColorization(Dataset): 12 | def __init__(self, datapath, image_transform, mask_transform, padding: bool = 1, 13 | use_original_imgsize: bool = False, flipped_order: bool = False, 14 | reverse_support_and_query: bool = False, random: bool = False): 15 | self.padding = padding 16 | self.random = random 17 | self.use_original_imgsize = use_original_imgsize 18 | self.image_transform = image_transform 19 | self.reverse_support_and_query = reverse_support_and_query 20 | self.mask_transform = mask_transform 21 | self.ds = 
ImageFolder(os.path.join(datapath, 'val')) 22 | self.flipped_order = flipped_order 23 | np.random.seed(5) 24 | self.indices = np.random.choice(np.arange(0, len(self.ds)-1), size=1000, replace=False) 25 | 26 | 27 | def __len__(self): 28 | return 1000 29 | 30 | def create_grid_from_images(self, support_img, support_mask, query_img, query_mask): 31 | if self.reverse_support_and_query: 32 | support_img, support_mask, query_img, query_mask = query_img, query_mask, support_img, support_mask 33 | canvas = torch.ones((support_img.shape[0], 2 * support_img.shape[1] + 2 * self.padding, 34 | 2 * support_img.shape[2] + 2 * self.padding)) 35 | canvas[:, :support_img.shape[1], :support_img.shape[2]] = support_img 36 | if self.flipped_order: 37 | canvas[:, :support_img.shape[1], -support_img.shape[2]:] = query_img 38 | canvas[:, -query_img.shape[1]:, -support_img.shape[2]:] = query_mask 39 | canvas[:, -query_img.shape[1]:, :query_img.shape[2]] = support_mask 40 | else: 41 | canvas[:, -query_img.shape[1]:, :query_img.shape[2]] = query_img 42 | canvas[:, :support_img.shape[1], -support_img.shape[2]:] = support_mask 43 | canvas[:, -query_img.shape[1]:, -support_img.shape[2]:] = query_mask 44 | 45 | return canvas 46 | 47 | def __getitem__(self, idx): 48 | support_idx = np.random.choice(np.arange(0, len(self)-1)) 49 | idx = self.indices[idx] 50 | query, support = self.ds[idx], self.ds[support_idx] 51 | query_img, query_mask = self.mask_transform(query[0]), self.image_transform(query[0]) 52 | support_img, support_mask = self.mask_transform(support[0]), self.image_transform(support[0]) 53 | grid = self.create_grid_from_images(support_img, support_mask, query_img, query_mask) 54 | batch = {'query_img': query_img, 'query_mask': query_mask, 'support_img': support_img, 55 | 'support_mask': support_mask, 'grid': grid} 56 | 57 | return batch -------------------------------------------------------------------------------- /evaluate/pascal_dataloader.py: -------------------------------------------------------------------------------- 1 | """Based on https://github.com/Seokju-Cho/Volumetric-Aggregation-Transformer/blob/main/data/pascal.py 2 | """ 3 | import os 4 | from PIL import Image 5 | from scipy.io import loadmat 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | from mae_utils import PURPLE, YELLOW 10 | 11 | def create_grid_from_images_old(canvas, support_img, support_mask, query_img, query_mask): 12 | canvas[:, :support_img.shape[1], :support_img.shape[2]] = support_img 13 | canvas[:, -query_img.shape[1]:, :query_img.shape[2]] = query_img 14 | canvas[:, :support_img.shape[1], -support_img.shape[2]:] = support_mask 15 | canvas[:, -query_img.shape[1]:, -support_img.shape[2]:] = query_mask 16 | return canvas 17 | 18 | class DatasetPASCAL(Dataset): 19 | def __init__(self, datapath, fold, image_transform, mask_transform, padding: bool = 1, use_original_imgsize: bool = False, flipped_order: bool = False, 20 | reverse_support_and_query: bool = False, random: bool = False, ensemble: bool = False, purple: bool = False): 21 | self.fold = fold 22 | self.nfolds = 4 23 | self.flipped_order = flipped_order 24 | self.nclass = 20 25 | self.padding = padding 26 | self.random = random 27 | self.ensemble = ensemble 28 | self.purple = purple 29 | self.use_original_imgsize = use_original_imgsize 30 | 31 | self.img_path = os.path.join(datapath, 'VOCdevkit/VOC2012/JPEGImages/') 32 | self.ann_path = os.path.join(datapath, 'VOCdevkit/VOC2012/SegmentationClassAug/') 33 | self.image_transform = 
image_transform 34 | self.reverse_support_and_query = reverse_support_and_query 35 | self.mask_transform = mask_transform 36 | 37 | self.class_ids = self.build_class_ids() 38 | self.img_metadata = self.build_img_metadata() 39 | self.img_metadata_classwise = self.build_img_metadata_classwise() 40 | 41 | def __len__(self): 42 | return 1000 43 | 44 | def create_grid_from_images(self, support_img, support_mask, query_img, query_mask, flip: bool = False): 45 | if self.reverse_support_and_query: 46 | support_img, support_mask, query_img, query_mask = query_img, query_mask, support_img, support_mask 47 | canvas = torch.ones((support_img.shape[0], 2 * support_img.shape[1] + 2 * self.padding, 2 * support_img.shape[2] + 2 * self.padding)) 48 | canvas[:, :support_img.shape[1], :support_img.shape[2]] = support_img 49 | if flip: 50 | canvas[:, :support_img.shape[1], -support_img.shape[2]:] = query_img 51 | canvas[:, -query_img.shape[1]:, -support_img.shape[2]:] = query_mask 52 | canvas[:, -query_img.shape[1]:, :query_img.shape[2]] = support_mask 53 | else: 54 | canvas[:, -query_img.shape[1]:, :query_img.shape[2]] = query_img 55 | canvas[:, :support_img.shape[1], -support_img.shape[2]:] = support_mask 56 | canvas[:, -query_img.shape[1]:, -support_img.shape[2]:] = query_mask 57 | 58 | return canvas 59 | 60 | def __getitem__(self, idx): 61 | idx %= len(self.img_metadata) # for testing, as n_images < 1000 62 | query_name, support_name, class_sample_query, class_sample_support = self.sample_episode(idx) 63 | query_img, query_cmask, support_img, support_cmask, org_qry_imsize = self.load_frame(query_name, support_name) 64 | 65 | if self.image_transform: 66 | query_img = self.image_transform(query_img) 67 | query_mask, query_ignore_idx = self.extract_ignore_idx(query_cmask, class_sample_query, purple=self.purple) 68 | if self.mask_transform: 69 | query_mask = self.mask_transform(query_mask) 70 | 71 | if self.image_transform: 72 | support_img = self.image_transform(support_img) 73 | support_mask, support_ignore_idx = self.extract_ignore_idx(support_cmask, class_sample_support, purple=self.purple) 74 | if self.mask_transform: 75 | support_mask = self.mask_transform(support_mask) 76 | grid = self.create_grid_from_images(support_img, support_mask, query_img, query_mask, flip=self.flipped_order) 77 | if self.ensemble: 78 | grid2 = self.create_grid_from_images(support_img, support_mask, query_img, query_mask, (not self.flipped_order)) 79 | 80 | 81 | support_purple_mask, _ = self.extract_ignore_idx(support_cmask, class_sample_support, 82 | purple=True) 83 | if self.mask_transform: 84 | support_purple_mask = self.mask_transform(support_purple_mask) 85 | 86 | grid3 = self.create_grid_from_images(support_img, support_purple_mask, query_img, query_mask, 87 | flip=self.flipped_order) 88 | 89 | grid4 = self.create_grid_from_images(support_img, support_purple_mask, query_img, query_mask, 90 | flip=(not self.flipped_order)) 91 | 92 | grid = grid, grid2, grid3, grid4 93 | batch = {'query_img': query_img, 94 | 'query_mask': query_mask, 95 | 'query_name': query_name, 96 | 'query_ignore_idx': query_ignore_idx, 97 | 'org_query_imsize': org_qry_imsize, 98 | 'support_img': support_img, 99 | 'support_mask': support_mask, 100 | 'support_name': support_name, 101 | 'support_ignore_idx': support_ignore_idx, 102 | 'class_id': torch.tensor(class_sample_query), 103 | 'grid': grid} 104 | 105 | return batch 106 | 107 | def extract_ignore_idx(self, mask, class_id, purple): 108 | mask = np.array(mask) 109 | boundary = np.floor(mask / 255.) 
110 | if not purple: 111 | mask[mask != class_id + 1] = 0 112 | mask[mask == class_id + 1] = 255 113 | return Image.fromarray(mask), boundary 114 | color_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) 115 | for x in range(mask.shape[0]): 116 | for y in range(mask.shape[1]): 117 | if mask[x,y] != class_id + 1: 118 | color_mask[x, y] = np.array(PURPLE) 119 | else: 120 | color_mask[x, y] = np.array(YELLOW) 121 | return Image.fromarray(color_mask), boundary 122 | 123 | 124 | def load_frame(self, query_name, support_name): 125 | query_img = self.read_img(query_name) 126 | query_mask = self.read_mask(query_name) 127 | support_img = self.read_img(support_name) 128 | support_mask = self.read_mask(support_name) 129 | org_qry_imsize = query_img.size 130 | 131 | return query_img, query_mask, support_img, support_mask, org_qry_imsize 132 | 133 | def read_mask(self, img_name): 134 | r"""Return segmentation mask in PIL Image""" 135 | mask = Image.open(os.path.join(self.ann_path, img_name) + '.png') 136 | return mask 137 | 138 | def read_img(self, img_name): 139 | r"""Return RGB image in PIL Image""" 140 | return Image.open(os.path.join(self.img_path, img_name) + '.jpg') 141 | 142 | def sample_episode(self, idx): 143 | """Returns the index of the query, support and class.""" 144 | query_name, class_sample = self.img_metadata[idx] 145 | if not self.random: 146 | support_class = class_sample 147 | else: 148 | support_class = np.random.choice([k for k in self.img_metadata_classwise.keys() if self.img_metadata_classwise[k]], 1, replace=False)[0] 149 | while True: # keep sampling support set if query == support 150 | support_name = np.random.choice(self.img_metadata_classwise[support_class], 1, replace=False)[0] 151 | if query_name != support_name: 152 | break 153 | return query_name, support_name, class_sample, support_class 154 | 155 | def build_class_ids(self): 156 | nclass_trn = self.nclass // self.nfolds 157 | class_ids_val = [self.fold * nclass_trn + i for i in range(nclass_trn)] 158 | return class_ids_val 159 | 160 | def build_img_metadata(self): 161 | 162 | def read_metadata(split, fold_id): 163 | cwd = os.path.dirname(os.path.abspath(__file__)) 164 | fold_n_metadata = os.path.join(cwd, 'splits/pascal/%s/fold%d.txt' % (split, fold_id)) 165 | with open(fold_n_metadata, 'r') as f: 166 | fold_n_metadata = f.read().split('\n')[:-1] 167 | fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata] 168 | return fold_n_metadata 169 | 170 | img_metadata = [] 171 | img_metadata = read_metadata('val', self.fold) 172 | 173 | print('Total (val) images are : %d' % len(img_metadata)) 174 | 175 | return img_metadata 176 | 177 | def build_img_metadata_classwise(self): 178 | img_metadata_classwise = {} 179 | for class_id in range(self.nclass): 180 | img_metadata_classwise[class_id] = [] 181 | 182 | for img_name, img_class in self.img_metadata: 183 | img_metadata_classwise[img_class] += [img_name] 184 | return img_metadata_classwise 185 | 186 | class DatasetPASCALforFinetune(Dataset): 187 | def __init__(self, datapath, fold, image_transform, mask_transform, num_supports=1, padding: bool = 1, 188 | use_original_imgsize: bool = False, random: bool = False): 189 | self.fold = fold 190 | self.nfolds = 4 191 | self.nclass = 20 192 | self.padding = padding 193 | self.random = random 194 | self.use_original_imgsize = use_original_imgsize 195 | 196 | self.img_path = os.path.join(datapath, 'VOCdevkit/VOC2012/JPEGImages/') 197 | self.ann_path = os.path.join(datapath, 
'VOCdevkit/VOC2012/SegmentationClassAug/') 198 | self.image_transform = image_transform 199 | self.mask_transform = mask_transform 200 | 201 | self.class_ids = self.build_class_ids() 202 | self.img_metadata = self.build_img_metadata() 203 | self.img_metadata_classwise = self.build_img_metadata_classwise() 204 | self.num_supports = num_supports 205 | 206 | def __len__(self): 207 | return 1000 208 | 209 | def __getitem__(self, idx): 210 | idx %= len(self.img_metadata) # for testing, as n_images < 1000 211 | query_name, support_names, class_sample_query, class_sample_support = self.sample_episode(idx) 212 | query_img, query_cmask, support_imgs, support_cmasks, org_qry_imsize = self.load_frame(query_name, 213 | support_names) 214 | 215 | query_mask = self.extract_ignore_idx(query_cmask, class_sample_query)[0] 216 | 217 | if self.image_transform: 218 | query_img = self.image_transform(query_img) 219 | 220 | if self.mask_transform: 221 | query_mask = self.mask_transform(query_mask) 222 | 223 | if self.image_transform: 224 | for i in range(len(support_imgs)): 225 | support_imgs[i] = self.image_transform(support_imgs[i]) 226 | support_masks = [self.extract_ignore_idx(m, class_sample_support)[0] for m in support_cmasks] 227 | if self.mask_transform: 228 | for i in range(len(support_masks)): 229 | support_masks[i] = self.mask_transform(support_masks[i]) 230 | 231 | batch = {'query_img': query_img, 232 | 'query_mask': query_mask, 233 | 'query_name': query_name, 234 | 'org_query_imsize': org_qry_imsize, 235 | 'support_imgs': support_imgs, 236 | 'support_masks': support_masks, 237 | 'class_id': torch.tensor(class_sample_query), 238 | } 239 | 240 | return batch 241 | 242 | def extract_ignore_idx(self, mask, class_id): 243 | mask = np.array(mask) 244 | boundary = np.floor(mask / 255.) 245 | mask[mask != class_id + 1] = 0. 246 | mask[mask == class_id + 1] = 255. 
247 | return Image.fromarray(mask), boundary 248 | 249 | def load_frame(self, query_name, support_names): 250 | query_img = self.read_img(query_name) 251 | query_mask = self.read_mask(query_name) 252 | support_imgs = [] 253 | support_masks = [] 254 | for support_name in support_names: 255 | support_imgs.append(self.read_img(support_name)) 256 | support_masks.append(self.read_mask(support_name)) 257 | 258 | org_qry_imsize = query_img.size 259 | 260 | return query_img, query_mask, support_imgs, support_masks, org_qry_imsize 261 | 262 | def read_mask(self, img_name): 263 | r"""Return segmentation mask in PIL Image""" 264 | mask = Image.open(os.path.join(self.ann_path, img_name) + '.png') 265 | return mask 266 | 267 | def read_img(self, img_name): 268 | r"""Return RGB image in PIL Image""" 269 | return Image.open(os.path.join(self.img_path, img_name) + '.jpg') 270 | 271 | def sample_episode(self, idx): 272 | """Returns the index of the query, support and class.""" 273 | query_name, class_sample = self.img_metadata[idx] 274 | if not self.random: 275 | support_class = class_sample 276 | else: 277 | support_class = \ 278 | np.random.choice([k for k in self.img_metadata_classwise.keys() if self.img_metadata_classwise[k]], 1, 279 | replace=False)[0] 280 | 281 | support_names = [] 282 | while True: # keep sampling support set if query == support 283 | support_name = np.random.choice(self.img_metadata_classwise[support_class], 1, replace=False)[0] 284 | if query_name != support_name and support_name not in support_names: 285 | support_names.append(support_name) 286 | if len(support_names) >= self.num_supports: 287 | break 288 | return query_name, support_names, class_sample, support_class 289 | 290 | def build_class_ids(self): 291 | nclass_trn = self.nclass // self.nfolds 292 | class_ids_val = [self.fold * nclass_trn + i for i in range(nclass_trn)] 293 | return class_ids_val 294 | 295 | def build_img_metadata(self): 296 | 297 | def read_metadata(split, fold_id): 298 | cwd = os.path.dirname(os.path.abspath(__file__)) 299 | fold_n_metadata = os.path.join(cwd, 'splits/pascal/%s/fold%d.txt' % (split, fold_id)) 300 | with open(fold_n_metadata, 'r') as f: 301 | fold_n_metadata = f.read().split('\n')[:-1] 302 | fold_n_metadata = [[data.split('__')[0], int(data.split('__')[1]) - 1] for data in fold_n_metadata] 303 | return fold_n_metadata 304 | 305 | img_metadata = [] 306 | img_metadata = read_metadata('val', self.fold) 307 | 308 | print('Total (val) images are : %d' % len(img_metadata)) 309 | 310 | return img_metadata 311 | 312 | def build_img_metadata_classwise(self): 313 | img_metadata_classwise = {} 314 | for class_id in range(self.nclass): 315 | img_metadata_classwise[class_id] = [] 316 | 317 | for img_name, img_class in self.img_metadata: 318 | img_metadata_classwise[img_class] += [img_name] 319 | return img_metadata_classwise -------------------------------------------------------------------------------- /evaluate/reasoning_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as T 5 | import cv2 6 | import numpy as np 7 | 8 | h, w = 224, 224 9 | 10 | 11 | def get_img(center_coordinates=None, radius=None, color=None, shape=None): 12 | image = Image.new('RGB', size=(224, 224), color='white') 13 | w, h = image.size 14 | if center_coordinates is None: 15 | center_coords_x = np.random.randint(10, w-10) 16 | center_coords_y = np.random.randint(10, 
h-10) 17 | center_coordinates = (center_coords_x, center_coords_y) 18 | 19 | if radius is None: 20 | radius = np.random.randint(10, 50) 21 | 22 | if color is None: 23 | color = (0, 255, 0) 24 | 25 | thickness = -1 26 | image = np.array(image) 27 | 28 | if shape is None or shape == 'circle': 29 | image = cv2.circle(image, center_coordinates, radius, color, thickness) 30 | elif shape == 'rectangle': 31 | start_point = (center_coordinates[0] - 32 | radius, center_coordinates[1] - radius) 33 | end_point = (center_coordinates[0] + 34 | radius, center_coordinates[1] + radius) 35 | image = cv2.rectangle(image, start_point, end_point, color, thickness) 36 | else: 37 | raise ValueError("Wrong shape") 38 | 39 | return image, center_coordinates, radius, color 40 | 41 | 42 | WHITE = (255, 255, 255) 43 | BLACK = (0, 0, 0) 44 | RED = (255, 0, 0) 45 | GREEN = (0, 255, 0) 46 | BLUE = (0, 0, 255) 47 | 48 | 49 | def round_image(img, options=(WHITE, BLACK, RED, GREEN, BLUE), outputs=None, t=(0, 0, 0)): 50 | # img.shape == [224, 224, 3], img.dtype == torch.int32 51 | img = torch.tensor(img) 52 | t = torch.tensor((t)).to(img) 53 | options = torch.tensor(options) 54 | opts = options.view(len(options), 1, 1, 3).permute(1, 2, 3, 0).to(img) 55 | nn = (((img + t).unsqueeze(-1) - opts) ** 2).float().mean(dim=2) 56 | nn_indices = torch.argmin(nn, dim=-1) 57 | if outputs is None: 58 | outputs = options 59 | res_img = torch.tensor(outputs)[nn_indices] 60 | return res_img 61 | 62 | # fixed as circle 63 | class ColorChangeTask(Dataset): 64 | def __init__(self, transforms=None): 65 | self.transforms = transforms 66 | super(ColorChangeTask, self).__init__() 67 | 68 | def __len__(self): 69 | return 100 70 | 71 | def color_options(self, ): 72 | """The color options for the output. We take all of the colors that are in the image as possiblities""" 73 | return [BLUE, GREEN, WHITE] 74 | 75 | def __getitem__(self, index): 76 | image1, center_coordinates1, radius1, color1 = get_img( 77 | color=GREEN) 78 | image2, center_coordinates2, radius2, color2 = get_img( 79 | center_coordinates1, radius1, BLUE) # Get boxes 80 | if self.transforms is not None: 81 | return self.transforms(Image.fromarray(image1)), self.transforms(Image.fromarray(image2)) 82 | return image1, image2 83 | 84 | class SizeChangeTask(Dataset): 85 | def __init__(self, transforms=None): 86 | self.transforms = transforms 87 | super(SizeChangeTask, self).__init__() 88 | 89 | def __len__(self): 90 | return 100 91 | 92 | def color_options(self, ): 93 | """The color options for the output. We take all of the colors that are in the image as possiblities""" 94 | return [BLUE, GREEN, WHITE] 95 | 96 | def __getitem__(self, index): 97 | radius1 = 30 98 | radius2 = 20 99 | image1, center_coordinates1, radius1, color1 = get_img(color=GREEN, shape='circle', radius=radius1) 100 | image2, center_coordinates2, radius2, color1 = get_img( 101 | center_coordinates1, radius2, GREEN, shape='circle') # Get boxes 102 | if self.transforms is not None: 103 | return self.transforms(Image.fromarray(image1)), self.transforms(Image.fromarray(image2)) 104 | return image1, image2 105 | 106 | class ShapeChangeTask(Dataset): 107 | def __init__(self, transforms=None): 108 | self.transforms = transforms 109 | super(ShapeChangeTask, self).__init__() 110 | 111 | def __len__(self): 112 | return 100 113 | 114 | def color_options(self, ): 115 | """The color options for the output. 
We take all of the colors that are in the image as possibilities""" 116 | return [GREEN, WHITE] 117 | 118 | def __getitem__(self, index): 119 | image1, center_coordinates1, radius1, color1 = get_img( 120 | color=GREEN, shape='circle') 121 | image2, center_coordinates2, radius2, color2 = get_img( 122 | center_coordinates1, radius1, color1, shape='rectangle') # paired output image 123 | if self.transforms is not None: 124 | return self.transforms(Image.fromarray(image1)), self.transforms(Image.fromarray(image2)) 125 | return image1, image2 126 | 127 | 128 | # This task is actually very hard for MAEVQGAN or any other inpainting model; we discuss this issue in the paper's limitations section. 129 | class ChangeLocationTask(Dataset): 130 | def __init__(self, transforms=None): 131 | self.transforms = transforms 132 | super(ChangeLocationTask, self).__init__() 133 | 134 | def __len__(self): 135 | return 100 136 | 137 | def color_options(self): 138 | """The color options for the output. We take all of the colors that are in the image as possibilities""" 139 | return [GREEN, WHITE] 140 | 141 | def __getitem__(self, index): 142 | image1, center_coordinates1, radius1, color1 = get_img( 143 | color=GREEN, shape='circle') 144 | center_coordinates2 = ( 145 | 223 - center_coordinates1[0], center_coordinates1[1]) 146 | image2, center_coordinates2, radius2, color2 = get_img( 147 | center_coordinates2, radius1, color1, shape='circle') # paired output image 148 | if self.transforms is not None: 149 | return self.transforms(Image.fromarray(image1)), self.transforms(Image.fromarray(image2)) 150 | return image1, image2 151 | 152 | 153 | class ChangeLocationVFlipTask(Dataset): 154 | def __init__(self, transforms=None): 155 | self.transforms = transforms 156 | super(ChangeLocationVFlipTask, self).__init__() 157 | 158 | def __len__(self): 159 | return 100 160 | 161 | def color_options(self): 162 | """The color options for the output. We take all of the colors that are in the image as possibilities""" 163 | return [GREEN, WHITE] 164 | 165 | def __getitem__(self, index): 166 | image1, center_coordinates1, radius1, color1 = get_img( 167 | color=GREEN, shape='circle') 168 | center_coordinates2 = ( 169 | center_coordinates1[0], 223 - center_coordinates1[1]) 170 | image2, center_coordinates2, radius2, color2 = get_img( 171 | center_coordinates2, radius1, color1, shape='circle') # paired output image 172 | if self.transforms is not None: 173 | return self.transforms(Image.fromarray(image1)), self.transforms(Image.fromarray(image2)) 174 | return image1, image2 175 | 176 | 177 | class ChangeLocationTransposeTask(Dataset): 178 | def __init__(self, transforms=None): 179 | self.transforms = transforms 180 | super(ChangeLocationTransposeTask, self).__init__() 181 | 182 | def __len__(self): 183 | return 100 184 | 185 | def color_options(self): 186 | """The color options for the output. 
We take all of the colors that are in the image as possibilities""" 187 | return [GREEN, WHITE] 188 | 189 | def __getitem__(self, index): 190 | image1, center_coordinates1, radius1, color1 = get_img( 191 | color=GREEN, shape='circle') 192 | image1 = Image.fromarray(image1) 193 | image2 = np.array(image1.transpose(5)) # 5 == PIL.Image.TRANSPOSE (reflect across the main diagonal) 194 | image1 = np.array(image1) 195 | if self.transforms is not None: 196 | return self.transforms(Image.fromarray(image1)), self.transforms(Image.fromarray(image2)) 197 | return image1, image2 198 | 199 | 200 | class ChangeLocationHShift(Dataset): 201 | def __init__(self, transforms=None): 202 | self.transforms = transforms 203 | super(ChangeLocationHShift, self).__init__() 204 | 205 | def __len__(self): 206 | return 100 207 | 208 | def color_options(self): 209 | """The color options for the output. We take all of the colors that are in the image as possibilities""" 210 | return [GREEN, WHITE] 211 | 212 | def __getitem__(self, index): 213 | h = w = 224 214 | shift = 50 215 | center_coords_x = np.random.randint(10, w - 10 - shift) 216 | center_coords_y = np.random.randint(10, h - 10) 217 | center_coordinates = (center_coords_x, center_coords_y) 218 | 219 | image1, center_coordinates1, radius1, color1 = get_img( 220 | center_coordinates, color=GREEN, shape='circle') 221 | center_coordinates2 = ( 222 | center_coordinates[0] + shift, center_coordinates1[1]) 223 | image2, center_coordinates2, radius2, color2 = get_img( 224 | center_coordinates2, radius1, color1, shape='circle') # paired output image 225 | 226 | if self.transforms is not None: 227 | return self.transforms(Image.fromarray(image1)), self.transforms(Image.fromarray(image2)) 228 | return image1, image2 229 | 230 | 231 | class ChangeShapeColorTask(Dataset): 232 | def __init__(self, transforms=None): 233 | self.transforms = transforms 234 | super(ChangeShapeColorTask, self).__init__() 235 | 236 | def __len__(self): 237 | return 100 238 | 239 | def color_options(self): 240 | """The color options for the output. We take all of the colors that are in the image as possibilities""" 241 | return [GREEN, BLUE, WHITE] 242 | 243 | def __getitem__(self, index): 244 | image1, center_coordinates1, radius1, color1 = get_img( 245 | color=GREEN, shape='circle') 246 | image2, center_coordinates2, radius2, color2 = get_img( 247 | center_coordinates1, radius1, BLUE, shape='rectangle') # paired output image 248 | if self.transforms is not None: 249 | return self.transforms(Image.fromarray(image1)), self.transforms(Image.fromarray(image2)) 250 | return image1, image2 251 | 252 | 253 | class ChangeLocationColorTask(Dataset): 254 | def __init__(self, transforms=None): 255 | self.transforms = transforms 256 | super(ChangeLocationColorTask, self).__init__() 257 | 258 | def __len__(self): 259 | return 100 260 | 261 | def color_options(self): 262 | """The color options for the output. 
We take all of the colors that are in the image as possibilities""" 263 | return [GREEN, BLUE, WHITE] 264 | 265 | def __getitem__(self, index): 266 | h = w = 224 267 | shift = 50 268 | center_coords_x = np.random.randint(10, w - 10 - shift) 269 | center_coords_y = np.random.randint(10, h - 10) 270 | center_coordinates = (center_coords_x, center_coords_y) 271 | 272 | image1, center_coordinates1, radius1, color1 = get_img( 273 | center_coordinates, color=GREEN, shape='circle') 274 | center_coordinates2 = ( 275 | center_coordinates[0] + shift, center_coordinates1[1]) 276 | image2, center_coordinates2, radius2, color2 = get_img( 277 | center_coordinates2, radius1, color=BLUE, shape='circle') # paired output image 278 | 279 | if self.transforms is not None: 280 | return self.transforms(Image.fromarray(image1)), self.transforms(Image.fromarray(image2)) 281 | return image1, image2 282 | 283 | class ChangeSizeColorTask(Dataset): 284 | 285 | def __init__(self, transforms=None): 286 | self.transforms = transforms 287 | super(ChangeSizeColorTask, self).__init__() 288 | 289 | def __len__(self): 290 | return 100 291 | 292 | def color_options(self): 293 | """The color options for the output. We take all of the colors that are in the image as possibilities""" 294 | return [GREEN, BLUE, WHITE] 295 | 296 | def __getitem__(self, index): 297 | radius1 = 30 298 | radius2 = 20 299 | image1, center_coordinates1, radius1, color1 = get_img(color=GREEN, shape='circle', radius=radius1) 300 | image2, center_coordinates2, radius2, color1 = get_img( 301 | center_coordinates1, radius2, BLUE, shape='circle') # paired output image 302 | if self.transforms is not None: 303 | return self.transforms(Image.fromarray(image1)), self.transforms(Image.fromarray(image2)) 304 | return image1, image2 305 | 306 | 307 | class ChangeSizeShapeTask(Dataset): 308 | def __init__(self, transforms=None): 309 | self.transforms = transforms 310 | super(ChangeSizeShapeTask, self).__init__() 311 | 312 | def __len__(self): 313 | return 100 314 | 315 | def color_options(self): 316 | """The color options for the output. 
We take all of the colors that are in the image as possibilities""" 317 | return [GREEN, WHITE] 318 | 319 | def __getitem__(self, index): 320 | radius1 = 30 321 | radius2 = 20 322 | image1, center_coordinates1, radius1, color1 = get_img(color=GREEN, shape='circle', radius=radius1) 323 | image2, center_coordinates2, radius2, color1 = get_img( 324 | center_coordinates1, radius2, GREEN, shape='rectangle') # paired output image 325 | if self.transforms is not None: 326 | return self.transforms(Image.fromarray(image1)), self.transforms(Image.fromarray(image2)) 327 | return image1, image2 328 | 329 | 330 | def box_to_img(mask, target, border_width=4): 331 | if mask is None: 332 | mask = np.zeros((112, 112, 3)) 333 | h, w, _ = mask.shape 334 | for box in target['boxes']: 335 | x_min, y_min, x_max, y_max = list( 336 | (box * (h - 1)).round().int().numpy()) 337 | cv2.rectangle(mask, (x_min, y_min), (x_max, y_max), 338 | (255, 255, 255), border_width) 339 | return Image.fromarray(mask.astype('uint8')) 340 | 341 | 342 | def get_annotated_image(img, boxes, border_width=3, mode='draw', copy_img=True): 343 | if mode == 'draw': 344 | h, w, _ = img.shape 345 | if copy_img: 346 | image_copy = np.array(img.copy()) 347 | else: 348 | image_copy = np.array(Image.new('RGB', (w, h), color='black')) 349 | 350 | for box in boxes: 351 | box = box.numpy().astype('int') 352 | cv2.rectangle( 353 | image_copy, (box[0], box[1]), (box[2], box[3]), (255, 255, 255), border_width) 354 | elif mode == 'keep': 355 | h, w, _ = img.shape 356 | image_copy = np.array(Image.new('RGB', (w, h), color='white')) 357 | 358 | for box in boxes: 359 | box = box.numpy().astype('int') 360 | image_copy[box[1]:box[3], box[0]:box[2] 361 | ] = img[box[1]:box[3], box[0]:box[2]] 362 | 363 | return image_copy 364 | 365 | 366 | background_transforms = T.Compose([ 367 | T.Resize((224, 224)), 368 | T.Compose([ 369 | T.ToTensor(), 370 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 371 | ]) 372 | ]) 373 | 374 | 375 | def create_grid_from_images(canvas, pairs, padding, figure_size): # stack (img, label) pairs as rows: img left of center, label right of center 376 | 377 | for i in range(len(pairs)): 378 | img, label = pairs[i] 379 | start_row = i*(figure_size + padding) 380 | 381 | canvas[:, start_row:start_row + figure_size, 224//2 - figure_size :224//2] = img 382 | canvas[:, start_row:start_row + figure_size, 224//2 +1 : 224//2 +1 + figure_size] = label 383 | 384 | return canvas 385 | 386 | if __name__ == "__main__": 387 | ds = ChangeLocationTransposeTask() 388 | img1, img2 = ds[10] -------------------------------------------------------------------------------- /evaluate/segmentation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from evaluate.mae_utils import WHITE, YELLOW, PURPLE, BLACK 4 | 5 | 6 | def calculate_metric(args, target, ours, fg_color=WHITE, bg_color=BLACK): 7 | # Crop the region to evaluate (the bottom-right quadrant of the canvas): 8 | target = target[113:, 113:] 9 | ours = ours[113:, 113:] 10 | return _calc_metric(ours, target, fg_color, bg_color) 11 | 12 | 13 | def _calc_metric(ours, target, fg_color=WHITE, bg_color=BLACK): 14 | fg_color = np.array(fg_color) 15 | # Calculate accuracy: 16 | accuracy = np.sum(np.float32((target == ours).all(axis=2))) / (ours.shape[0] * ours.shape[1]) 17 | seg_orig = ((target - fg_color[np.newaxis, np.newaxis, :]) == 0).all(axis=2) 18 | seg_our = ((ours - fg_color[np.newaxis, np.newaxis, :]) == 0).all(axis=2) 19 | color_blind_seg_our = (ours - np.array([[bg_color]]) != 0).any(axis=2) 20 | iou = np.sum(np.float32(seg_orig & seg_our)) / 
np.sum(np.float32(seg_orig | seg_our)) 21 | color_blind_iou = np.sum(np.float32(seg_orig & color_blind_seg_our)) / np.sum( 22 | np.float32(seg_orig | color_blind_seg_our)) 23 | return {'iou': iou, 'color_blind_iou': color_blind_iou, 'accuracy': accuracy} 24 | 25 | 26 | def get_default_mask_1row_mask(): 27 | mask = np.zeros((14,14)) 28 | mask[:7] = 1 29 | mask[:, :7] = 1 30 | return mask -------------------------------------------------------------------------------- /evaluate/splits/coco/trn/fold0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate/splits/coco/trn/fold0.pkl -------------------------------------------------------------------------------- /evaluate/splits/coco/trn/fold1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate/splits/coco/trn/fold1.pkl -------------------------------------------------------------------------------- /evaluate/splits/coco/trn/fold2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate/splits/coco/trn/fold2.pkl -------------------------------------------------------------------------------- /evaluate/splits/coco/trn/fold3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate/splits/coco/trn/fold3.pkl -------------------------------------------------------------------------------- /evaluate/splits/coco/val/fold0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate/splits/coco/val/fold0.pkl -------------------------------------------------------------------------------- /evaluate/splits/coco/val/fold1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate/splits/coco/val/fold1.pkl -------------------------------------------------------------------------------- /evaluate/splits/coco/val/fold2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate/splits/coco/val/fold2.pkl -------------------------------------------------------------------------------- /evaluate/splits/coco/val/fold3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate/splits/coco/val/fold3.pkl -------------------------------------------------------------------------------- /evaluate/splits/fss/test.txt: -------------------------------------------------------------------------------- 1 | bus 2 | hotel_slipper 3 | burj_al 4 | reflex_camera 5 | abe's_flyingfish 6 | oiltank_car 7 | doormat 8 | fish_eagle 9 | barber_shaver 10 | motorbike 11 | feather_clothes 12 | wandering_albatross 13 | rice_cooker 14 | delta_wing 15 | fish 16 | nintendo_switch 17 | bustard 18 | diver 19 | 
minicooper 20 | cathedrale_paris 21 | big_ben 22 | combination_lock 23 | villa_savoye 24 | american_alligator 25 | gym_ball 26 | andean_condor 27 | leggings 28 | pyramid_cube 29 | jet_aircraft 30 | meatloaf 31 | reel 32 | swan 33 | osprey 34 | crt_screen 35 | microscope 36 | rubber_eraser 37 | arrow 38 | monkey 39 | mitten 40 | spiderman 41 | parthenon 42 | bat 43 | chess_king 44 | sulphur_butterfly 45 | quail_egg 46 | oriole 47 | iron_man 48 | wooden_boat 49 | anise 50 | steering_wheel 51 | groenendael 52 | dwarf_beans 53 | pteropus 54 | chalk_brush 55 | bloodhound 56 | moon 57 | english_foxhound 58 | boxing_gloves 59 | peregine_falcon 60 | pyraminx 61 | cicada 62 | screw 63 | shower_curtain 64 | tredmill 65 | bulb 66 | bell_pepper 67 | lemur_catta 68 | doughnut 69 | twin_tower 70 | astronaut 71 | nintendo_3ds 72 | fennel_bulb 73 | indri 74 | captain_america_shield 75 | kunai 76 | broom 77 | iphone 78 | earphone1 79 | flying_squirrel 80 | onion 81 | vinyl 82 | sydney_opera_house 83 | oyster 84 | harmonica 85 | egg 86 | breast_pump 87 | guitar 88 | potato_chips 89 | tunnel 90 | cuckoo 91 | rubick_cube 92 | plastic_bag 93 | phonograph 94 | net_surface_shoes 95 | goldfinch 96 | ipad 97 | mite_predator 98 | coffee_mug 99 | golden_plover 100 | f1_racing 101 | lapwing 102 | nintendo_gba 103 | pizza 104 | rally_car 105 | drilling_platform 106 | cd 107 | fly 108 | magpie_bird 109 | leaf_fan 110 | little_blue_heron 111 | carriage 112 | moist_proof_pad 113 | flying_snakes 114 | dart_target 115 | warehouse_tray 116 | nintendo_wiiu 117 | chiffon_cake 118 | bath_ball 119 | manatee 120 | cloud 121 | marimba 122 | eagle 123 | ruler 124 | soymilk_machine 125 | sled 126 | seagull 127 | glider_flyingfish 128 | doublebus 129 | transport_helicopter 130 | window_screen 131 | truss_bridge 132 | wasp 133 | snowman 134 | poached_egg 135 | strawberry 136 | spinach 137 | earphone2 138 | downy_pitch 139 | taj_mahal 140 | rocking_chair 141 | cablestayed_bridge 142 | sealion 143 | banana_boat 144 | pheasant 145 | stone_lion 146 | electronic_stove 147 | fox 148 | iguana 149 | rugby_ball 150 | hang_glider 151 | water_buffalo 152 | lotus 153 | paper_plane 154 | missile 155 | flamingo 156 | american_chamelon 157 | kart 158 | chinese_knot 159 | cabbage_butterfly 160 | key 161 | church 162 | tiltrotor 163 | helicopter 164 | french_fries 165 | water_heater 166 | snow_leopard 167 | goblet 168 | fan 169 | snowplow 170 | leafhopper 171 | pspgo 172 | black_bear 173 | quail 174 | condor 175 | chandelier 176 | hair_razor 177 | white_wolf 178 | toaster 179 | pidan 180 | pyramid 181 | chicken_leg 182 | letter_opener 183 | apple_icon 184 | porcupine 185 | chicken 186 | stingray 187 | warplane 188 | windmill 189 | bamboo_slip 190 | wig 191 | flying_geckos 192 | stonechat 193 | haddock 194 | australian_terrier 195 | hover_board 196 | siamang 197 | canton_tower 198 | santa_sledge 199 | arch_bridge 200 | curlew 201 | sushi 202 | beet_root 203 | accordion 204 | leaf_egg 205 | stealth_aircraft 206 | stork 207 | bucket 208 | hawk 209 | chess_queen 210 | ocarina 211 | knife 212 | whippet 213 | cantilever_bridge 214 | may_bug 215 | wagtail 216 | leather_shoes 217 | wheelchair 218 | shumai 219 | speedboat 220 | vacuum_cup 221 | chess_knight 222 | pumpkin_pie 223 | wooden_spoon 224 | bamboo_dragonfly 225 | ganeva_chair 226 | soap 227 | clearwing_flyingfish 228 | pencil_sharpener1 229 | cricket 230 | photocopier 231 | nintendo_sp 232 | samarra_mosque 233 | clam 234 | charge_battery 235 | flying_frog 236 | ferrari911 237 | polo_shirt 238 | 
echidna 239 | coin 240 | tower_pisa 241 | -------------------------------------------------------------------------------- /evaluate/splits/fss/trn.txt: -------------------------------------------------------------------------------- 1 | fountain 2 | taxi 3 | assult_rifle 4 | radio 5 | comb 6 | box_turtle 7 | igloo 8 | head_cabbage 9 | cottontail 10 | coho 11 | ashtray 12 | joystick 13 | sleeping_bag 14 | jackfruit 15 | trailer_truck 16 | shower_cap 17 | ibex 18 | kinguin 19 | squirrel 20 | ac_wall 21 | sidewinder 22 | remote_control 23 | marshmallow 24 | bolotie 25 | polar_bear 26 | rock_beauty 27 | tokyo_tower 28 | wafer 29 | red_bayberry 30 | electronic_toothbrush 31 | hartebeest 32 | cassette 33 | oil_filter 34 | bomb 35 | walnut 36 | toilet_tissue 37 | memory_stick 38 | wild_boar 39 | cableways 40 | chihuahua 41 | envelope 42 | bison 43 | poker 44 | pubg_lvl3helmet 45 | indian_cobra 46 | staffordshire 47 | park_bench 48 | wombat 49 | black_grouse 50 | submarine 51 | washer 52 | agama 53 | coyote 54 | feeder 55 | sarong 56 | buckingham_palace 57 | frog 58 | steam_locomotive 59 | acorn 60 | german_pointer 61 | obelisk 62 | polecat 63 | black_swan 64 | butterfly 65 | mountain_tent 66 | gorilla 67 | sloth_bear 68 | aubergine 69 | stinkhorn 70 | stole 71 | owl 72 | mooli 73 | pool_table 74 | collar 75 | lhasa_apso 76 | ambulance 77 | spade 78 | pufferfish 79 | paint_brush 80 | lark 81 | golf_ball 82 | hock 83 | fork 84 | drake 85 | bee_house 86 | mooncake 87 | wok 88 | cocacola 89 | water_bike 90 | ladder 91 | psp 92 | bassoon 93 | bear 94 | border_terrier 95 | petri_dish 96 | pill_bottle 97 | aircraft_carrier 98 | panther 99 | canoe 100 | baseball_player 101 | turtle 102 | espresso 103 | throne 104 | cornet 105 | coucal 106 | eletrical_switch 107 | bra 108 | snail 109 | backpack 110 | jacamar 111 | scroll_brush 112 | gliding_lizard 113 | raft 114 | pinwheel 115 | grasshopper 116 | green_mamba 117 | eft_newt 118 | computer_mouse 119 | vine_snake 120 | recreational_vehicle 121 | llama 122 | meerkat 123 | chainsaw 124 | ferret 125 | garbage_can 126 | kangaroo 127 | litchi 128 | carbonara 129 | housefinch 130 | modem 131 | tebby_cat 132 | thatch 133 | face_powder 134 | tomb 135 | apple 136 | ladybug 137 | killer_whale 138 | rocket 139 | airship 140 | surfboard 141 | lesser_panda 142 | jordan_logo 143 | banana 144 | nail_scissor 145 | swab 146 | perfume 147 | punching_bag 148 | victor_icon 149 | waffle_iron 150 | trimaran 151 | garlic 152 | flute 153 | langur 154 | starfish 155 | parallel_bars 156 | dandie_dinmont 157 | cosmetic_brush 158 | screwdriver 159 | brick_card 160 | balance_weight 161 | hornet 162 | carton 163 | toothpaste 164 | bracelet 165 | egg_tart 166 | pencil_sharpener2 167 | swimming_glasses 168 | howler_monkey 169 | camel 170 | dragonfly 171 | lionfish 172 | convertible 173 | mule 174 | usb 175 | conch 176 | papaya 177 | garbage_truck 178 | dingo 179 | radiator 180 | solar_dish 181 | streetcar 182 | trilobite 183 | bouzouki 184 | ringlet_butterfly 185 | space_shuttle 186 | waffle 187 | american_staffordshire 188 | violin 189 | flowerpot 190 | forklift 191 | manx 192 | sundial 193 | snowmobile 194 | chickadee_bird 195 | ruffed_grouse 196 | brick_tea 197 | paddle 198 | stove 199 | carousel 200 | spatula 201 | beaker 202 | gas_pump 203 | lawn_mower 204 | speaker 205 | tank 206 | tresher 207 | kappa_logo 208 | hare 209 | tennis_racket 210 | shopping_cart 211 | thimble 212 | tractor 213 | anemone_fish 214 | trolleybus 215 | steak 216 | capuchin 217 | red_breasted_merganser 218 | 
golden_retriever 219 | light_tube 220 | flatworm 221 | melon_seed 222 | digital_watch 223 | jacko_lantern 224 | brown_bear 225 | cairn 226 | mushroom 227 | chalk 228 | skull 229 | stapler 230 | potato 231 | telescope 232 | proboscis 233 | microphone 234 | torii 235 | baseball_bat 236 | dhole 237 | excavator 238 | fig 239 | snake 240 | bradypod 241 | pepitas 242 | prairie_chicken 243 | scorpion 244 | shotgun 245 | bottle_cap 246 | file_cabinet 247 | grey_whale 248 | one-armed_bandit 249 | banded_gecko 250 | flying_disc 251 | croissant 252 | toothbrush 253 | miniskirt 254 | pokermon_ball 255 | gazelle 256 | grey_fox 257 | esport_chair 258 | necklace 259 | ptarmigan 260 | watermelon 261 | besom 262 | pomelo 263 | radio_telescope 264 | studio_couch 265 | black_stork 266 | vestment 267 | koala 268 | brambling 269 | muscle_car 270 | window_shade 271 | space_heater 272 | sunglasses 273 | motor_scooter 274 | ladyfinger 275 | pencil_box 276 | titi_monkey 277 | chicken_wings 278 | mount_fuji 279 | giant_panda 280 | dart 281 | fire_engine 282 | running_shoe 283 | dumbbell 284 | donkey 285 | loafer 286 | hard_disk 287 | globe 288 | lifeboat 289 | medical_kit 290 | brain_coral 291 | paper_towel 292 | dugong 293 | seatbelt 294 | skunk 295 | military_vest 296 | cocktail_shaker 297 | zucchini 298 | quad_drone 299 | ocicat 300 | shih-tzu 301 | teapot 302 | tile_roof 303 | cheese_burger 304 | handshower 305 | red_wolf 306 | stop_sign 307 | mouse 308 | battery 309 | adidas_logo2 310 | earplug 311 | hummingbird 312 | brush_pen 313 | pistachio 314 | hamster 315 | air_strip 316 | indian_elephant 317 | otter 318 | cucumber 319 | scabbard 320 | hawthorn 321 | bullet_train 322 | leopard 323 | whale 324 | cream 325 | chinese_date 326 | jellyfish 327 | lobster 328 | skua 329 | single_log 330 | chicory 331 | bagel 332 | beacon 333 | pingpong_racket 334 | spoon 335 | yurt 336 | wallaby 337 | egret 338 | christmas_stocking 339 | mcdonald_uncle 340 | wrench 341 | spark_plug 342 | triceratops 343 | wall_clock 344 | jinrikisha 345 | pickup 346 | rhinoceros 347 | swimming_trunk 348 | band-aid 349 | spotted_salamander 350 | leeks 351 | marmot 352 | warthog 353 | cello 354 | stool 355 | chest 356 | toilet_plunger 357 | wardrobe 358 | cannon 359 | adidas_logo1 360 | drumstick 361 | lady_slipper 362 | puma_logo 363 | great_wall 364 | white_shark 365 | witch_hat 366 | vending_machine 367 | wreck 368 | chopsticks 369 | garfish 370 | african_elephant 371 | children_slide 372 | hornbill 373 | zebra 374 | boa_constrictor 375 | armour 376 | pineapple 377 | angora 378 | brick 379 | car_wheel 380 | wallet 381 | boston_bull 382 | hyena 383 | lynx 384 | crash_helmet 385 | terrapin_turtle 386 | persian_cat 387 | shift_gear 388 | cactus_ball 389 | fur_coat 390 | plate 391 | pen 392 | okra 393 | mario 394 | airedale 395 | cowboy_hat 396 | celery 397 | macaque 398 | candle 399 | goose 400 | raccoon 401 | brasscica 402 | almond 403 | maotai_bottle 404 | soccer_ball 405 | sports_car 406 | tobacco_pipe 407 | water_polo 408 | eggnog 409 | hook 410 | ostrich 411 | patas 412 | table_lamp 413 | teddy 414 | mongoose 415 | spoonbill 416 | redheart 417 | crane 418 | dinosaur 419 | kitchen_knife 420 | seal 421 | baboon 422 | golfcart 423 | roller_coaster 424 | avocado 425 | birdhouse 426 | yorkshire_terrier 427 | saluki 428 | basketball 429 | buckler 430 | harvester 431 | afghan_hound 432 | beam_bridge 433 | guinea_pig 434 | lorikeet 435 | shakuhachi 436 | motarboard 437 | statue_liberty 438 | police_car 439 | sulphur_crested 440 | gourd 441 | 
sombrero 442 | mailbox 443 | adhensive_tape 444 | night_snake 445 | bushtit 446 | mouthpiece 447 | beaver 448 | bathtub 449 | printer 450 | cumquat 451 | orange 452 | cleaver 453 | quill_pen 454 | panpipe 455 | diamond 456 | gypsy_moth 457 | cauliflower 458 | lampshade 459 | cougar 460 | traffic_light 461 | briefcase 462 | ballpoint 463 | african_grey 464 | kremlin 465 | barometer 466 | peacock 467 | paper_crane 468 | sunscreen 469 | tofu 470 | bedlington_terrier 471 | snowball 472 | carrot 473 | tiger 474 | mink 475 | cristo_redentor 476 | ladle 477 | keyboard 478 | maraca 479 | monitor 480 | water_snake 481 | can_opener 482 | mud_turtle 483 | bald_eagle 484 | carp 485 | cn_tower 486 | egyptian_cat 487 | hen_of_the_woods 488 | measuring_cup 489 | roller_skate 490 | kite 491 | sandwich_cookies 492 | sandwich 493 | persimmon 494 | chess_bishop 495 | coffin 496 | ruddy_turnstone 497 | prayer_rug 498 | rain_barrel 499 | neck_brace 500 | nematode 501 | rosehip 502 | dutch_oven 503 | goldfish 504 | blossom_card 505 | dough 506 | trench_coat 507 | sponge 508 | stupa 509 | wash_basin 510 | electric_fan 511 | spring_scroll 512 | potted_plant 513 | sparrow 514 | car_mirror 515 | gecko 516 | diaper 517 | leatherback_turtle 518 | strainer 519 | guacamole 520 | microwave 521 | -------------------------------------------------------------------------------- /evaluate/splits/fss/val.txt: -------------------------------------------------------------------------------- 1 | handcuff 2 | mortar 3 | matchstick 4 | wine_bottle 5 | dowitcher 6 | triumphal_arch 7 | gyromitra 8 | hatchet 9 | airliner 10 | broccoli 11 | olive 12 | pubg_lvl3backpack 13 | calculator 14 | toucan 15 | shovel 16 | sewing_machine 17 | icecream 18 | woodpecker 19 | pig 20 | relay_stick 21 | mcdonald_sign 22 | cpu 23 | peanut 24 | pumpkin 25 | sturgeon 26 | hammer 27 | hami_melon 28 | squirrel_monkey 29 | shuriken 30 | power_drill 31 | pingpong_ball 32 | crocodile 33 | carambola 34 | monarch_butterfly 35 | drum 36 | water_tower 37 | panda 38 | toilet_brush 39 | pay_phone 40 | yonex_icon 41 | cricketball 42 | revolver 43 | chimpanzee 44 | crab 45 | corn 46 | baseball 47 | rabbit 48 | croquet_ball 49 | artichoke 50 | abacus 51 | harp 52 | bell 53 | gas_tank 54 | scissors 55 | vase 56 | upright_piano 57 | typewriter 58 | bittern 59 | impala 60 | tray 61 | fire_hydrant 62 | beer_bottle 63 | sock 64 | soup_bowl 65 | spider 66 | cherry 67 | macaw 68 | toilet_seat 69 | fire_balloon 70 | french_ball 71 | fox_squirrel 72 | volleyball 73 | cornmeal 74 | folding_chair 75 | pubg_airdrop 76 | beagle 77 | skateboard 78 | narcissus 79 | whiptail 80 | cup 81 | arabian_camel 82 | badger 83 | stopwatch 84 | ab_wheel 85 | ox 86 | lettuce 87 | monocycle 88 | redshank 89 | vulture 90 | whistle 91 | smoothing_iron 92 | mashed_potato 93 | conveyor 94 | yoga_pad 95 | tow_truck 96 | siamese_cat 97 | cigar 98 | white_stork 99 | sniper_rifle 100 | stretcher 101 | tulip 102 | handkerchief 103 | basset 104 | iceberg 105 | gibbon 106 | lacewing 107 | thrush 108 | cheetah 109 | bighorn_sheep 110 | espresso_maker 111 | pretzel 112 | english_setter 113 | sandbar 114 | cheese 115 | daisy 116 | arctic_fox 117 | briard 118 | colubus 119 | balance_beam 120 | coffeepot 121 | soap_dispenser 122 | yawl 123 | consomme 124 | parking_meter 125 | cactus 126 | turnstile 127 | taro 128 | fire_screen 129 | digital_clock 130 | rose 131 | pomegranate 132 | bee_eater 133 | schooner 134 | ski_mask 135 | jay_bird 136 | plaice 137 | red_fox 138 | syringe 139 | camomile 140 | pickelhaube 
141 | blenheim_spaniel 142 | pear 143 | parachute 144 | common_newt 145 | bowtie 146 | cigarette 147 | oscilloscope 148 | laptop 149 | african_crocodile 150 | apron 151 | coconut 152 | sandal 153 | kwanyin 154 | lion 155 | eel 156 | balloon 157 | crepe 158 | armadillo 159 | kazoo 160 | lemon 161 | spider_monkey 162 | tape_player 163 | ipod 164 | bee 165 | sea_cucumber 166 | suitcase 167 | television 168 | pillow 169 | banjo 170 | rock_snake 171 | partridge 172 | platypus 173 | lycaenid_butterfly 174 | pinecone 175 | conversion_plug 176 | wolf 177 | frying_pan 178 | timber_wolf 179 | bluetick 180 | crayon 181 | giant_schnauzer 182 | orang 183 | scarerow 184 | kobe_logo 185 | loguat 186 | saxophone 187 | ceiling_fan 188 | cardoon 189 | equestrian_helmet 190 | louvre_pyramid 191 | hotdog 192 | ironing_board 193 | razor 194 | nagoya_castle 195 | loggerhead_turtle 196 | lipstick 197 | cradle 198 | strongbox 199 | raven 200 | kit_fox 201 | albatross 202 | flat-coated_retriever 203 | beer_glass 204 | ice_lolly 205 | sungnyemun 206 | totem_pole 207 | vacuum 208 | bolete 209 | mango 210 | ginger 211 | weasel 212 | cabbage 213 | refrigerator 214 | school_bus 215 | hippo 216 | tiger_cat 217 | saltshaker 218 | piano_keyboard 219 | windsor_tie 220 | sea_urchin 221 | microsd 222 | barbell 223 | swim_ring 224 | bulbul_bird 225 | water_ouzel 226 | ac_ground 227 | sweatshirt 228 | umbrella 229 | hair_drier 230 | hammerhead_shark 231 | tomato 232 | projector 233 | cushion 234 | dishwasher 235 | three-toed_sloth 236 | tiger_shark 237 | har_gow 238 | baby 239 | thor's_hammer 240 | nike_logo 241 | -------------------------------------------------------------------------------- /evaluate/splits/pascal/val/fold0.txt: -------------------------------------------------------------------------------- 1 | 2007_000033__01 2 | 2007_000061__04 3 | 2007_000129__02 4 | 2007_000346__05 5 | 2007_000529__04 6 | 2007_000559__05 7 | 2007_000572__02 8 | 2007_000762__05 9 | 2007_001288__01 10 | 2007_001289__03 11 | 2007_001311__02 12 | 2007_001408__05 13 | 2007_001568__01 14 | 2007_001630__02 15 | 2007_001761__01 16 | 2007_001884__01 17 | 2007_002094__03 18 | 2007_002266__01 19 | 2007_002376__01 20 | 2007_002400__03 21 | 2007_002619__01 22 | 2007_002719__04 23 | 2007_003088__05 24 | 2007_003131__04 25 | 2007_003188__02 26 | 2007_003349__03 27 | 2007_003571__04 28 | 2007_003621__02 29 | 2007_003682__03 30 | 2007_003861__04 31 | 2007_004052__01 32 | 2007_004143__03 33 | 2007_004241__04 34 | 2007_004468__05 35 | 2007_005074__04 36 | 2007_005107__02 37 | 2007_005294__05 38 | 2007_005304__05 39 | 2007_005428__05 40 | 2007_005509__01 41 | 2007_005600__01 42 | 2007_005705__04 43 | 2007_005828__01 44 | 2007_006076__03 45 | 2007_006086__05 46 | 2007_006449__02 47 | 2007_006946__01 48 | 2007_007084__03 49 | 2007_007235__02 50 | 2007_007341__01 51 | 2007_007470__01 52 | 2007_007477__04 53 | 2007_007836__02 54 | 2007_008051__03 55 | 2007_008084__03 56 | 2007_008204__05 57 | 2007_008670__03 58 | 2007_009088__03 59 | 2007_009258__02 60 | 2007_009323__03 61 | 2007_009458__05 62 | 2007_009687__05 63 | 2007_009817__03 64 | 2007_009911__01 65 | 2008_000120__04 66 | 2008_000123__03 67 | 2008_000533__03 68 | 2008_000725__02 69 | 2008_000911__05 70 | 2008_001013__04 71 | 2008_001040__04 72 | 2008_001135__04 73 | 2008_001260__04 74 | 2008_001404__02 75 | 2008_001514__03 76 | 2008_001531__02 77 | 2008_001546__01 78 | 2008_001580__04 79 | 2008_001966__03 80 | 2008_001971__01 81 | 2008_002043__03 82 | 2008_002269__02 83 | 2008_002358__01 84 | 
2008_002429__03 85 | 2008_002467__05 86 | 2008_002504__04 87 | 2008_002775__05 88 | 2008_002864__05 89 | 2008_003034__04 90 | 2008_003076__05 91 | 2008_003108__02 92 | 2008_003110__03 93 | 2008_003155__01 94 | 2008_003270__02 95 | 2008_003369__01 96 | 2008_003858__04 97 | 2008_003876__01 98 | 2008_003886__04 99 | 2008_003926__01 100 | 2008_003976__01 101 | 2008_004363__02 102 | 2008_004654__02 103 | 2008_004659__05 104 | 2008_004704__01 105 | 2008_004758__02 106 | 2008_004995__02 107 | 2008_005262__05 108 | 2008_005338__01 109 | 2008_005628__04 110 | 2008_005727__02 111 | 2008_005812__05 112 | 2008_005904__05 113 | 2008_006216__01 114 | 2008_006229__04 115 | 2008_006254__02 116 | 2008_006703__01 117 | 2008_007120__03 118 | 2008_007143__04 119 | 2008_007219__05 120 | 2008_007350__01 121 | 2008_007498__03 122 | 2008_007811__05 123 | 2008_007994__03 124 | 2008_008268__03 125 | 2008_008629__02 126 | 2008_008711__02 127 | 2008_008746__03 128 | 2009_000032__01 129 | 2009_000037__03 130 | 2009_000121__05 131 | 2009_000149__02 132 | 2009_000201__05 133 | 2009_000205__01 134 | 2009_000318__03 135 | 2009_000354__02 136 | 2009_000387__01 137 | 2009_000421__04 138 | 2009_000440__01 139 | 2009_000446__04 140 | 2009_000457__02 141 | 2009_000469__04 142 | 2009_000573__02 143 | 2009_000619__03 144 | 2009_000664__03 145 | 2009_000723__04 146 | 2009_000828__04 147 | 2009_000840__05 148 | 2009_000879__03 149 | 2009_000991__03 150 | 2009_000998__03 151 | 2009_001108__03 152 | 2009_001160__03 153 | 2009_001255__02 154 | 2009_001278__05 155 | 2009_001314__03 156 | 2009_001332__01 157 | 2009_001565__03 158 | 2009_001607__03 159 | 2009_001683__03 160 | 2009_001718__02 161 | 2009_001765__03 162 | 2009_001818__05 163 | 2009_001850__01 164 | 2009_001851__01 165 | 2009_001941__04 166 | 2009_002185__05 167 | 2009_002295__02 168 | 2009_002320__01 169 | 2009_002372__05 170 | 2009_002521__05 171 | 2009_002594__05 172 | 2009_002604__03 173 | 2009_002649__05 174 | 2009_002727__04 175 | 2009_002732__05 176 | 2009_002749__05 177 | 2009_002808__01 178 | 2009_002856__05 179 | 2009_002888__01 180 | 2009_002928__02 181 | 2009_003003__05 182 | 2009_003005__01 183 | 2009_003043__04 184 | 2009_003080__04 185 | 2009_003193__02 186 | 2009_003224__02 187 | 2009_003269__05 188 | 2009_003273__03 189 | 2009_003343__02 190 | 2009_003378__03 191 | 2009_003450__03 192 | 2009_003498__03 193 | 2009_003504__04 194 | 2009_003517__05 195 | 2009_003640__03 196 | 2009_003696__01 197 | 2009_003707__04 198 | 2009_003806__01 199 | 2009_003858__03 200 | 2009_003971__02 201 | 2009_004021__03 202 | 2009_004084__03 203 | 2009_004125__04 204 | 2009_004247__05 205 | 2009_004324__05 206 | 2009_004509__03 207 | 2009_004540__03 208 | 2009_004568__03 209 | 2009_004579__05 210 | 2009_004635__04 211 | 2009_004653__01 212 | 2009_004848__02 213 | 2009_004882__02 214 | 2009_004886__03 215 | 2009_004895__03 216 | 2009_004969__01 217 | 2009_005038__05 218 | 2009_005137__03 219 | 2009_005156__02 220 | 2009_005189__01 221 | 2009_005190__05 222 | 2009_005260__03 223 | 2009_005262__03 224 | 2009_005302__05 225 | 2010_000065__02 226 | 2010_000083__02 227 | 2010_000084__04 228 | 2010_000238__01 229 | 2010_000241__03 230 | 2010_000272__04 231 | 2010_000342__02 232 | 2010_000426__05 233 | 2010_000572__01 234 | 2010_000622__01 235 | 2010_000814__03 236 | 2010_000906__04 237 | 2010_000961__03 238 | 2010_001016__03 239 | 2010_001017__01 240 | 2010_001024__01 241 | 2010_001036__04 242 | 2010_001061__03 243 | 2010_001069__03 244 | 2010_001174__01 245 | 2010_001367__02 246 | 
2010_001367__05 247 | 2010_001448__01 248 | 2010_001830__05 249 | 2010_001995__03 250 | 2010_002017__05 251 | 2010_002030__02 252 | 2010_002142__03 253 | 2010_002147__01 254 | 2010_002150__04 255 | 2010_002200__01 256 | 2010_002310__01 257 | 2010_002536__02 258 | 2010_002546__04 259 | 2010_002693__02 260 | 2010_002939__01 261 | 2010_003127__01 262 | 2010_003132__01 263 | 2010_003168__03 264 | 2010_003362__03 265 | 2010_003365__01 266 | 2010_003418__03 267 | 2010_003468__05 268 | 2010_003473__03 269 | 2010_003495__01 270 | 2010_003547__04 271 | 2010_003716__01 272 | 2010_003771__03 273 | 2010_003781__05 274 | 2010_003820__03 275 | 2010_003912__02 276 | 2010_003915__01 277 | 2010_004041__04 278 | 2010_004056__05 279 | 2010_004208__04 280 | 2010_004314__01 281 | 2010_004419__01 282 | 2010_004520__05 283 | 2010_004529__05 284 | 2010_004551__05 285 | 2010_004556__03 286 | 2010_004559__03 287 | 2010_004662__04 288 | 2010_004772__04 289 | 2010_004828__05 290 | 2010_004994__03 291 | 2010_005252__04 292 | 2010_005401__04 293 | 2010_005428__03 294 | 2010_005496__05 295 | 2010_005531__03 296 | 2010_005534__01 297 | 2010_005582__05 298 | 2010_005664__02 299 | 2010_005705__04 300 | 2010_005718__01 301 | 2010_005762__05 302 | 2010_005877__01 303 | 2010_005888__01 304 | 2010_006034__01 305 | 2010_006070__02 306 | 2011_000066__05 307 | 2011_000112__03 308 | 2011_000185__03 309 | 2011_000234__04 310 | 2011_000238__04 311 | 2011_000412__02 312 | 2011_000435__04 313 | 2011_000456__03 314 | 2011_000482__03 315 | 2011_000585__02 316 | 2011_000669__03 317 | 2011_000747__05 318 | 2011_000874__01 319 | 2011_001114__01 320 | 2011_001161__04 321 | 2011_001263__01 322 | 2011_001287__03 323 | 2011_001407__01 324 | 2011_001421__03 325 | 2011_001434__01 326 | 2011_001589__04 327 | 2011_001624__01 328 | 2011_001793__04 329 | 2011_001880__01 330 | 2011_001988__02 331 | 2011_002064__02 332 | 2011_002098__05 333 | 2011_002223__02 334 | 2011_002295__03 335 | 2011_002327__01 336 | 2011_002515__01 337 | 2011_002675__01 338 | 2011_002713__02 339 | 2011_002754__04 340 | 2011_002863__05 341 | 2011_002929__01 342 | 2011_002975__04 343 | 2011_003003__02 344 | 2011_003030__03 345 | 2011_003145__03 346 | 2011_003271__05 347 | -------------------------------------------------------------------------------- /evaluate/splits/pascal/val/fold1.txt: -------------------------------------------------------------------------------- 1 | 2007_000452__09 2 | 2007_000464__10 3 | 2007_000491__10 4 | 2007_000663__06 5 | 2007_000663__07 6 | 2007_000727__06 7 | 2007_000727__07 8 | 2007_000804__09 9 | 2007_000830__09 10 | 2007_001299__10 11 | 2007_001321__07 12 | 2007_001457__09 13 | 2007_001677__09 14 | 2007_001717__09 15 | 2007_001763__08 16 | 2007_001774__08 17 | 2007_001884__06 18 | 2007_002268__08 19 | 2007_002387__10 20 | 2007_002445__08 21 | 2007_002470__08 22 | 2007_002539__06 23 | 2007_002597__08 24 | 2007_002643__07 25 | 2007_002903__10 26 | 2007_003011__09 27 | 2007_003051__07 28 | 2007_003101__06 29 | 2007_003106__08 30 | 2007_003137__06 31 | 2007_003143__07 32 | 2007_003169__08 33 | 2007_003195__06 34 | 2007_003201__10 35 | 2007_003503__06 36 | 2007_003503__07 37 | 2007_003621__06 38 | 2007_003711__06 39 | 2007_003786__06 40 | 2007_003841__10 41 | 2007_003917__07 42 | 2007_003991__08 43 | 2007_004193__09 44 | 2007_004392__09 45 | 2007_004405__09 46 | 2007_004510__09 47 | 2007_004712__09 48 | 2007_004856__08 49 | 2007_004866__08 50 | 2007_005074__07 51 | 2007_005114__10 52 | 2007_005296__07 53 | 2007_005331__07 54 | 2007_005460__08 55 | 
2007_005547__07 56 | 2007_005547__10 57 | 2007_005844__09 58 | 2007_005845__08 59 | 2007_005911__06 60 | 2007_005978__06 61 | 2007_006035__07 62 | 2007_006086__09 63 | 2007_006241__09 64 | 2007_006260__08 65 | 2007_006277__07 66 | 2007_006348__09 67 | 2007_006553__09 68 | 2007_006761__10 69 | 2007_006841__10 70 | 2007_007414__07 71 | 2007_007417__08 72 | 2007_007524__08 73 | 2007_007815__07 74 | 2007_007818__07 75 | 2007_007996__09 76 | 2007_008106__09 77 | 2007_008110__09 78 | 2007_008543__09 79 | 2007_008722__10 80 | 2007_008747__06 81 | 2007_008815__08 82 | 2007_008897__09 83 | 2007_008973__10 84 | 2007_009015__06 85 | 2007_009015__07 86 | 2007_009068__09 87 | 2007_009084__09 88 | 2007_009096__07 89 | 2007_009221__08 90 | 2007_009245__10 91 | 2007_009346__08 92 | 2007_009392__06 93 | 2007_009392__07 94 | 2007_009413__09 95 | 2007_009521__09 96 | 2007_009764__06 97 | 2007_009794__08 98 | 2007_009897__10 99 | 2007_009923__08 100 | 2007_009938__07 101 | 2008_000009__10 102 | 2008_000073__10 103 | 2008_000075__06 104 | 2008_000107__09 105 | 2008_000149__09 106 | 2008_000182__08 107 | 2008_000345__08 108 | 2008_000401__08 109 | 2008_000464__08 110 | 2008_000501__07 111 | 2008_000673__09 112 | 2008_000853__08 113 | 2008_000919__10 114 | 2008_001078__08 115 | 2008_001433__08 116 | 2008_001439__09 117 | 2008_001513__08 118 | 2008_001640__08 119 | 2008_001715__09 120 | 2008_001885__08 121 | 2008_002152__08 122 | 2008_002205__06 123 | 2008_002212__07 124 | 2008_002379__09 125 | 2008_002521__09 126 | 2008_002623__08 127 | 2008_002681__08 128 | 2008_002778__10 129 | 2008_002958__07 130 | 2008_003141__06 131 | 2008_003141__07 132 | 2008_003333__07 133 | 2008_003477__09 134 | 2008_003499__08 135 | 2008_003577__07 136 | 2008_003777__06 137 | 2008_003821__09 138 | 2008_003846__07 139 | 2008_004069__07 140 | 2008_004339__07 141 | 2008_004552__07 142 | 2008_004612__09 143 | 2008_004701__10 144 | 2008_005097__10 145 | 2008_005105__10 146 | 2008_005245__07 147 | 2008_005676__06 148 | 2008_006008__09 149 | 2008_006063__10 150 | 2008_006254__07 151 | 2008_006325__08 152 | 2008_006341__08 153 | 2008_006480__08 154 | 2008_006528__10 155 | 2008_006554__06 156 | 2008_006986__07 157 | 2008_007025__10 158 | 2008_007031__10 159 | 2008_007048__09 160 | 2008_007123__10 161 | 2008_007194__09 162 | 2008_007273__10 163 | 2008_007378__09 164 | 2008_007402__09 165 | 2008_007527__09 166 | 2008_007548__08 167 | 2008_007596__10 168 | 2008_007737__09 169 | 2008_007797__06 170 | 2008_007804__07 171 | 2008_007828__09 172 | 2008_008252__06 173 | 2008_008301__06 174 | 2008_008469__06 175 | 2008_008682__06 176 | 2009_000013__08 177 | 2009_000080__08 178 | 2009_000219__10 179 | 2009_000309__10 180 | 2009_000335__06 181 | 2009_000335__07 182 | 2009_000426__06 183 | 2009_000455__06 184 | 2009_000457__07 185 | 2009_000523__07 186 | 2009_000641__10 187 | 2009_000716__08 188 | 2009_000731__10 189 | 2009_000771__10 190 | 2009_000825__07 191 | 2009_000964__08 192 | 2009_001008__08 193 | 2009_001082__06 194 | 2009_001240__07 195 | 2009_001255__07 196 | 2009_001299__09 197 | 2009_001391__08 198 | 2009_001411__08 199 | 2009_001536__07 200 | 2009_001775__09 201 | 2009_001804__06 202 | 2009_001816__06 203 | 2009_001854__06 204 | 2009_002035__10 205 | 2009_002122__10 206 | 2009_002150__10 207 | 2009_002164__07 208 | 2009_002171__10 209 | 2009_002221__10 210 | 2009_002238__06 211 | 2009_002238__07 212 | 2009_002239__07 213 | 2009_002268__08 214 | 2009_002346__09 215 | 2009_002415__09 216 | 2009_002487__09 217 | 2009_002527__08 218 | 
2009_002535__06 219 | 2009_002549__10 220 | 2009_002571__09 221 | 2009_002618__07 222 | 2009_002635__10 223 | 2009_002753__08 224 | 2009_002936__08 225 | 2009_002990__07 226 | 2009_003003__07 227 | 2009_003059__10 228 | 2009_003071__09 229 | 2009_003269__07 230 | 2009_003304__06 231 | 2009_003387__07 232 | 2009_003406__07 233 | 2009_003494__09 234 | 2009_003507__09 235 | 2009_003542__10 236 | 2009_003549__07 237 | 2009_003569__10 238 | 2009_003589__07 239 | 2009_003703__06 240 | 2009_003771__08 241 | 2009_003773__10 242 | 2009_003849__09 243 | 2009_003895__09 244 | 2009_003904__08 245 | 2009_004072__06 246 | 2009_004140__09 247 | 2009_004217__09 248 | 2009_004248__08 249 | 2009_004455__07 250 | 2009_004504__08 251 | 2009_004590__06 252 | 2009_004594__07 253 | 2009_004687__09 254 | 2009_004721__08 255 | 2009_004732__06 256 | 2009_004748__07 257 | 2009_004789__06 258 | 2009_004859__09 259 | 2009_004867__06 260 | 2009_005158__08 261 | 2009_005219__08 262 | 2009_005231__06 263 | 2010_000003__09 264 | 2010_000160__07 265 | 2010_000163__08 266 | 2010_000372__07 267 | 2010_000427__10 268 | 2010_000530__07 269 | 2010_000552__08 270 | 2010_000573__06 271 | 2010_000628__07 272 | 2010_000639__09 273 | 2010_000682__06 274 | 2010_000683__08 275 | 2010_000724__08 276 | 2010_000907__10 277 | 2010_000941__08 278 | 2010_000952__07 279 | 2010_001000__10 280 | 2010_001010__10 281 | 2010_001070__08 282 | 2010_001206__06 283 | 2010_001292__08 284 | 2010_001331__08 285 | 2010_001351__08 286 | 2010_001403__06 287 | 2010_001403__07 288 | 2010_001534__08 289 | 2010_001553__07 290 | 2010_001579__09 291 | 2010_001646__06 292 | 2010_001656__08 293 | 2010_001692__10 294 | 2010_001699__09 295 | 2010_001767__07 296 | 2010_001851__09 297 | 2010_001913__08 298 | 2010_002017__07 299 | 2010_002017__09 300 | 2010_002025__08 301 | 2010_002137__08 302 | 2010_002146__08 303 | 2010_002305__08 304 | 2010_002336__09 305 | 2010_002348__08 306 | 2010_002361__07 307 | 2010_002390__10 308 | 2010_002422__08 309 | 2010_002512__08 310 | 2010_002531__08 311 | 2010_002546__06 312 | 2010_002623__09 313 | 2010_002693__08 314 | 2010_002693__09 315 | 2010_002763__08 316 | 2010_002763__10 317 | 2010_002868__06 318 | 2010_002900__08 319 | 2010_002902__07 320 | 2010_002921__09 321 | 2010_002929__07 322 | 2010_002988__07 323 | 2010_003123__07 324 | 2010_003183__10 325 | 2010_003231__07 326 | 2010_003239__10 327 | 2010_003275__08 328 | 2010_003276__07 329 | 2010_003293__06 330 | 2010_003302__09 331 | 2010_003325__09 332 | 2010_003381__07 333 | 2010_003402__08 334 | 2010_003409__09 335 | 2010_003446__07 336 | 2010_003453__07 337 | 2010_003468__08 338 | 2010_003531__09 339 | 2010_003675__08 340 | 2010_003746__07 341 | 2010_003758__08 342 | 2010_003764__08 343 | 2010_003768__07 344 | 2010_003772__06 345 | 2010_003781__08 346 | 2010_003813__07 347 | 2010_003854__07 348 | 2010_003971__08 349 | 2010_003971__09 350 | 2010_004104__08 351 | 2010_004120__08 352 | 2010_004320__08 353 | 2010_004322__10 354 | 2010_004348__06 355 | 2010_004369__08 356 | 2010_004472__07 357 | 2010_004479__08 358 | 2010_004635__10 359 | 2010_004763__09 360 | 2010_004783__09 361 | 2010_004789__10 362 | 2010_004815__08 363 | 2010_004825__09 364 | 2010_004861__08 365 | 2010_004946__07 366 | 2010_005013__07 367 | 2010_005021__08 368 | 2010_005021__09 369 | 2010_005063__06 370 | 2010_005108__08 371 | 2010_005118__06 372 | 2010_005160__06 373 | 2010_005166__10 374 | 2010_005284__06 375 | 2010_005344__08 376 | 2010_005421__08 377 | 2010_005432__07 378 | 2010_005501__07 379 | 
2010_005508__08 380 | 2010_005606__08 381 | 2010_005709__08 382 | 2010_005718__07 383 | 2010_005860__07 384 | 2010_005899__08 385 | 2010_006070__07 386 | 2011_000178__06 387 | 2011_000226__09 388 | 2011_000239__06 389 | 2011_000248__06 390 | 2011_000312__06 391 | 2011_000338__09 392 | 2011_000419__08 393 | 2011_000503__07 394 | 2011_000548__10 395 | 2011_000566__10 396 | 2011_000607__09 397 | 2011_000661__08 398 | 2011_000661__09 399 | 2011_000780__08 400 | 2011_000789__08 401 | 2011_000809__09 402 | 2011_000813__08 403 | 2011_000813__09 404 | 2011_000830__06 405 | 2011_000843__09 406 | 2011_000888__06 407 | 2011_000900__07 408 | 2011_000969__06 409 | 2011_001047__10 410 | 2011_001064__06 411 | 2011_001071__09 412 | 2011_001110__07 413 | 2011_001159__10 414 | 2011_001232__10 415 | 2011_001292__08 416 | 2011_001341__06 417 | 2011_001346__09 418 | 2011_001447__09 419 | 2011_001530__10 420 | 2011_001534__08 421 | 2011_001546__10 422 | 2011_001567__09 423 | 2011_001597__08 424 | 2011_001601__08 425 | 2011_001607__08 426 | 2011_001665__09 427 | 2011_001708__10 428 | 2011_001775__08 429 | 2011_001782__10 430 | 2011_001812__09 431 | 2011_002041__09 432 | 2011_002064__07 433 | 2011_002124__09 434 | 2011_002200__09 435 | 2011_002298__09 436 | 2011_002322__07 437 | 2011_002343__09 438 | 2011_002358__09 439 | 2011_002391__09 440 | 2011_002509__09 441 | 2011_002592__07 442 | 2011_002644__09 443 | 2011_002685__08 444 | 2011_002812__07 445 | 2011_002885__10 446 | 2011_003011__09 447 | 2011_003019__07 448 | 2011_003019__10 449 | 2011_003055__07 450 | 2011_003103__09 451 | 2011_003114__06 452 | -------------------------------------------------------------------------------- /evaluate/splits/pascal/val/fold2.txt: -------------------------------------------------------------------------------- 1 | 2007_000129__15 2 | 2007_000323__15 3 | 2007_000332__13 4 | 2007_000346__15 5 | 2007_000762__11 6 | 2007_000762__15 7 | 2007_000783__13 8 | 2007_000783__15 9 | 2007_000799__13 10 | 2007_000799__15 11 | 2007_000830__11 12 | 2007_000847__11 13 | 2007_000847__15 14 | 2007_000999__15 15 | 2007_001175__15 16 | 2007_001239__12 17 | 2007_001284__15 18 | 2007_001311__15 19 | 2007_001408__15 20 | 2007_001423__15 21 | 2007_001430__11 22 | 2007_001430__15 23 | 2007_001526__15 24 | 2007_001585__15 25 | 2007_001586__13 26 | 2007_001586__15 27 | 2007_001594__15 28 | 2007_001630__15 29 | 2007_001677__11 30 | 2007_001678__15 31 | 2007_001717__15 32 | 2007_001763__12 33 | 2007_001955__13 34 | 2007_002046__13 35 | 2007_002119__15 36 | 2007_002260__14 37 | 2007_002268__12 38 | 2007_002378__15 39 | 2007_002426__15 40 | 2007_002539__15 41 | 2007_002565__15 42 | 2007_002597__12 43 | 2007_002624__11 44 | 2007_002624__15 45 | 2007_002643__15 46 | 2007_002728__15 47 | 2007_002823__14 48 | 2007_002823__15 49 | 2007_002824__15 50 | 2007_002852__12 51 | 2007_003011__11 52 | 2007_003020__15 53 | 2007_003022__13 54 | 2007_003022__15 55 | 2007_003088__15 56 | 2007_003106__15 57 | 2007_003110__12 58 | 2007_003134__15 59 | 2007_003188__15 60 | 2007_003194__12 61 | 2007_003367__14 62 | 2007_003367__15 63 | 2007_003373__12 64 | 2007_003373__15 65 | 2007_003530__15 66 | 2007_003621__15 67 | 2007_003742__11 68 | 2007_003742__15 69 | 2007_003872__12 70 | 2007_004033__14 71 | 2007_004033__15 72 | 2007_004112__12 73 | 2007_004112__15 74 | 2007_004121__15 75 | 2007_004189__12 76 | 2007_004275__14 77 | 2007_004275__15 78 | 2007_004281__15 79 | 2007_004380__14 80 | 2007_004380__15 81 | 2007_004392__15 82 | 2007_004405__11 83 | 2007_004538__13 84 | 
2007_004538__15 85 | 2007_004644__12 86 | 2007_004712__11 87 | 2007_004712__15 88 | 2007_004722__13 89 | 2007_004722__15 90 | 2007_004902__13 91 | 2007_004902__15 92 | 2007_005114__13 93 | 2007_005114__15 94 | 2007_005149__12 95 | 2007_005173__14 96 | 2007_005173__15 97 | 2007_005281__15 98 | 2007_005304__15 99 | 2007_005331__13 100 | 2007_005331__15 101 | 2007_005354__14 102 | 2007_005354__15 103 | 2007_005509__15 104 | 2007_005547__15 105 | 2007_005608__14 106 | 2007_005608__15 107 | 2007_005696__12 108 | 2007_005759__14 109 | 2007_005803__11 110 | 2007_005844__11 111 | 2007_005845__15 112 | 2007_006028__15 113 | 2007_006076__15 114 | 2007_006086__11 115 | 2007_006117__15 116 | 2007_006171__12 117 | 2007_006171__15 118 | 2007_006241__11 119 | 2007_006364__13 120 | 2007_006364__15 121 | 2007_006373__15 122 | 2007_006444__12 123 | 2007_006444__15 124 | 2007_006560__15 125 | 2007_006647__14 126 | 2007_006647__15 127 | 2007_006698__15 128 | 2007_006802__15 129 | 2007_006841__15 130 | 2007_006864__15 131 | 2007_006866__13 132 | 2007_006866__15 133 | 2007_007007__11 134 | 2007_007007__15 135 | 2007_007109__13 136 | 2007_007109__15 137 | 2007_007195__15 138 | 2007_007203__15 139 | 2007_007211__14 140 | 2007_007235__15 141 | 2007_007417__12 142 | 2007_007493__15 143 | 2007_007498__11 144 | 2007_007498__15 145 | 2007_007651__11 146 | 2007_007651__15 147 | 2007_007688__14 148 | 2007_007748__13 149 | 2007_007748__15 150 | 2007_007795__15 151 | 2007_007810__11 152 | 2007_007810__15 153 | 2007_007815__15 154 | 2007_007836__15 155 | 2007_007849__15 156 | 2007_007996__15 157 | 2007_008110__15 158 | 2007_008204__15 159 | 2007_008222__12 160 | 2007_008256__13 161 | 2007_008256__15 162 | 2007_008260__12 163 | 2007_008374__15 164 | 2007_008415__12 165 | 2007_008430__15 166 | 2007_008596__13 167 | 2007_008596__15 168 | 2007_008708__15 169 | 2007_008802__13 170 | 2007_008897__15 171 | 2007_008944__15 172 | 2007_008964__12 173 | 2007_008964__15 174 | 2007_008980__12 175 | 2007_009068__15 176 | 2007_009084__12 177 | 2007_009084__14 178 | 2007_009251__13 179 | 2007_009251__15 180 | 2007_009258__15 181 | 2007_009320__15 182 | 2007_009331__12 183 | 2007_009331__13 184 | 2007_009331__15 185 | 2007_009413__11 186 | 2007_009413__15 187 | 2007_009521__11 188 | 2007_009562__12 189 | 2007_009592__12 190 | 2007_009654__15 191 | 2007_009655__15 192 | 2007_009684__15 193 | 2007_009687__15 194 | 2007_009691__14 195 | 2007_009691__15 196 | 2007_009706__11 197 | 2007_009750__15 198 | 2007_009756__14 199 | 2007_009756__15 200 | 2007_009841__13 201 | 2007_009938__14 202 | 2008_000080__12 203 | 2008_000213__15 204 | 2008_000215__15 205 | 2008_000223__15 206 | 2008_000233__15 207 | 2008_000234__15 208 | 2008_000239__12 209 | 2008_000270__12 210 | 2008_000270__15 211 | 2008_000271__15 212 | 2008_000359__15 213 | 2008_000474__15 214 | 2008_000510__15 215 | 2008_000573__11 216 | 2008_000573__15 217 | 2008_000602__13 218 | 2008_000630__15 219 | 2008_000661__12 220 | 2008_000661__15 221 | 2008_000662__15 222 | 2008_000666__15 223 | 2008_000673__15 224 | 2008_000700__15 225 | 2008_000725__15 226 | 2008_000731__15 227 | 2008_000763__11 228 | 2008_000763__15 229 | 2008_000765__13 230 | 2008_000782__14 231 | 2008_000795__15 232 | 2008_000811__14 233 | 2008_000811__15 234 | 2008_000863__12 235 | 2008_000943__12 236 | 2008_000992__15 237 | 2008_001013__15 238 | 2008_001028__15 239 | 2008_001070__12 240 | 2008_001074__15 241 | 2008_001076__15 242 | 2008_001150__14 243 | 2008_001170__15 244 | 2008_001231__15 245 | 2008_001249__15 246 | 
2008_001283__15 247 | 2008_001308__15 248 | 2008_001379__12 249 | 2008_001404__15 250 | 2008_001478__12 251 | 2008_001491__15 252 | 2008_001504__15 253 | 2008_001531__15 254 | 2008_001547__15 255 | 2008_001629__15 256 | 2008_001682__13 257 | 2008_001821__15 258 | 2008_001874__15 259 | 2008_001895__12 260 | 2008_001895__15 261 | 2008_001992__13 262 | 2008_001992__15 263 | 2008_002212__15 264 | 2008_002239__12 265 | 2008_002240__14 266 | 2008_002241__15 267 | 2008_002379__11 268 | 2008_002383__14 269 | 2008_002495__15 270 | 2008_002536__12 271 | 2008_002588__15 272 | 2008_002775__11 273 | 2008_002775__15 274 | 2008_002835__13 275 | 2008_002835__15 276 | 2008_002859__12 277 | 2008_002864__11 278 | 2008_002864__15 279 | 2008_002904__12 280 | 2008_002929__15 281 | 2008_002936__12 282 | 2008_002942__15 283 | 2008_002958__12 284 | 2008_003034__15 285 | 2008_003076__15 286 | 2008_003108__15 287 | 2008_003141__15 288 | 2008_003210__15 289 | 2008_003238__12 290 | 2008_003238__15 291 | 2008_003330__15 292 | 2008_003333__14 293 | 2008_003333__15 294 | 2008_003379__13 295 | 2008_003451__14 296 | 2008_003451__15 297 | 2008_003461__13 298 | 2008_003461__15 299 | 2008_003477__11 300 | 2008_003492__15 301 | 2008_003511__12 302 | 2008_003511__15 303 | 2008_003546__15 304 | 2008_003576__12 305 | 2008_003676__15 306 | 2008_003733__15 307 | 2008_003782__13 308 | 2008_003856__15 309 | 2008_003874__15 310 | 2008_004101__15 311 | 2008_004140__11 312 | 2008_004140__15 313 | 2008_004175__13 314 | 2008_004345__14 315 | 2008_004396__13 316 | 2008_004399__14 317 | 2008_004399__15 318 | 2008_004575__11 319 | 2008_004575__15 320 | 2008_004624__13 321 | 2008_004654__15 322 | 2008_004687__13 323 | 2008_004705__13 324 | 2008_005049__14 325 | 2008_005089__15 326 | 2008_005145__11 327 | 2008_005197__12 328 | 2008_005197__15 329 | 2008_005245__14 330 | 2008_005245__15 331 | 2008_005399__15 332 | 2008_005422__14 333 | 2008_005445__15 334 | 2008_005525__13 335 | 2008_005637__14 336 | 2008_005642__13 337 | 2008_005691__13 338 | 2008_005738__15 339 | 2008_005812__15 340 | 2008_005915__14 341 | 2008_006008__11 342 | 2008_006036__13 343 | 2008_006108__11 344 | 2008_006108__15 345 | 2008_006130__12 346 | 2008_006216__15 347 | 2008_006219__13 348 | 2008_006254__15 349 | 2008_006275__15 350 | 2008_006341__15 351 | 2008_006408__11 352 | 2008_006408__15 353 | 2008_006526__14 354 | 2008_006526__15 355 | 2008_006554__15 356 | 2008_006722__12 357 | 2008_006722__15 358 | 2008_006874__14 359 | 2008_006874__15 360 | 2008_006981__12 361 | 2008_007048__11 362 | 2008_007219__15 363 | 2008_007378__11 364 | 2008_007378__12 365 | 2008_007392__13 366 | 2008_007392__15 367 | 2008_007402__11 368 | 2008_007402__15 369 | 2008_007513__12 370 | 2008_007737__15 371 | 2008_007828__15 372 | 2008_007945__13 373 | 2008_007994__15 374 | 2008_008051__11 375 | 2008_008127__14 376 | 2008_008127__15 377 | 2008_008221__15 378 | 2008_008335__11 379 | 2008_008335__15 380 | 2008_008362__11 381 | 2008_008362__15 382 | 2008_008392__13 383 | 2008_008393__13 384 | 2008_008421__13 385 | 2008_008469__15 386 | 2009_000012__13 387 | 2009_000074__14 388 | 2009_000074__15 389 | 2009_000156__12 390 | 2009_000219__15 391 | 2009_000309__15 392 | 2009_000412__13 393 | 2009_000418__15 394 | 2009_000421__15 395 | 2009_000457__15 396 | 2009_000704__15 397 | 2009_000705__13 398 | 2009_000727__13 399 | 2009_000730__14 400 | 2009_000730__15 401 | 2009_000825__14 402 | 2009_000825__15 403 | 2009_000839__12 404 | 2009_000892__12 405 | 2009_000931__13 406 | 2009_000935__12 407 | 
2009_001215__11 408 | 2009_001215__15 409 | 2009_001299__15 410 | 2009_001433__13 411 | 2009_001433__15 412 | 2009_001535__12 413 | 2009_001663__15 414 | 2009_001687__12 415 | 2009_001687__15 416 | 2009_001718__15 417 | 2009_001768__15 418 | 2009_001854__15 419 | 2009_002012__12 420 | 2009_002042__15 421 | 2009_002097__13 422 | 2009_002155__12 423 | 2009_002165__13 424 | 2009_002185__15 425 | 2009_002239__14 426 | 2009_002239__15 427 | 2009_002317__14 428 | 2009_002317__15 429 | 2009_002346__12 430 | 2009_002346__15 431 | 2009_002372__15 432 | 2009_002382__14 433 | 2009_002382__15 434 | 2009_002415__11 435 | 2009_002445__12 436 | 2009_002487__11 437 | 2009_002539__12 438 | 2009_002571__11 439 | 2009_002584__15 440 | 2009_002649__15 441 | 2009_002651__14 442 | 2009_002651__15 443 | 2009_002732__15 444 | 2009_002975__13 445 | 2009_003003__11 446 | 2009_003003__15 447 | 2009_003063__12 448 | 2009_003065__15 449 | 2009_003071__11 450 | 2009_003071__15 451 | 2009_003123__11 452 | 2009_003196__14 453 | 2009_003217__12 454 | 2009_003241__12 455 | 2009_003269__15 456 | 2009_003323__13 457 | 2009_003323__15 458 | 2009_003466__12 459 | 2009_003481__13 460 | 2009_003494__15 461 | 2009_003507__11 462 | 2009_003576__14 463 | 2009_003576__15 464 | 2009_003756__12 465 | 2009_003804__13 466 | 2009_003810__12 467 | 2009_003849__11 468 | 2009_003849__15 469 | 2009_003903__13 470 | 2009_003928__12 471 | 2009_003991__11 472 | 2009_003991__15 473 | 2009_004033__12 474 | 2009_004043__14 475 | 2009_004043__15 476 | 2009_004140__11 477 | 2009_004221__15 478 | 2009_004455__14 479 | 2009_004497__13 480 | 2009_004507__12 481 | 2009_004507__15 482 | 2009_004581__12 483 | 2009_004592__12 484 | 2009_004738__14 485 | 2009_004738__15 486 | 2009_004848__15 487 | 2009_004859__11 488 | 2009_004859__15 489 | 2009_004942__13 490 | 2009_004987__14 491 | 2009_004987__15 492 | 2009_004994__12 493 | 2009_004994__15 494 | 2009_005038__11 495 | 2009_005038__15 496 | 2009_005078__14 497 | 2009_005087__15 498 | 2009_005217__13 499 | 2009_005217__15 500 | 2010_000003__12 501 | 2010_000038__13 502 | 2010_000038__15 503 | 2010_000087__14 504 | 2010_000087__15 505 | 2010_000110__12 506 | 2010_000110__15 507 | 2010_000159__12 508 | 2010_000174__11 509 | 2010_000174__15 510 | 2010_000216__12 511 | 2010_000238__15 512 | 2010_000256__15 513 | 2010_000422__12 514 | 2010_000530__15 515 | 2010_000559__15 516 | 2010_000639__12 517 | 2010_000666__13 518 | 2010_000666__15 519 | 2010_000738__15 520 | 2010_000788__12 521 | 2010_000874__13 522 | 2010_000904__12 523 | 2010_001024__15 524 | 2010_001124__12 525 | 2010_001251__14 526 | 2010_001264__12 527 | 2010_001313__14 528 | 2010_001313__15 529 | 2010_001367__15 530 | 2010_001376__12 531 | 2010_001451__13 532 | 2010_001553__14 533 | 2010_001563__12 534 | 2010_001563__15 535 | 2010_001579__11 536 | 2010_001579__15 537 | 2010_001692__15 538 | 2010_001699__15 539 | 2010_001734__15 540 | 2010_001767__15 541 | 2010_001851__11 542 | 2010_001908__12 543 | 2010_001956__12 544 | 2010_002017__15 545 | 2010_002137__15 546 | 2010_002161__13 547 | 2010_002161__15 548 | 2010_002228__12 549 | 2010_002251__14 550 | 2010_002251__15 551 | 2010_002271__14 552 | 2010_002336__11 553 | 2010_002396__14 554 | 2010_002396__15 555 | 2010_002480__12 556 | 2010_002623__15 557 | 2010_002691__13 558 | 2010_002763__15 559 | 2010_002792__15 560 | 2010_002902__15 561 | 2010_002929__15 562 | 2010_003014__15 563 | 2010_003060__12 564 | 2010_003187__12 565 | 2010_003207__14 566 | 2010_003239__15 567 | 2010_003325__11 568 | 
2010_003325__15 569 | 2010_003381__15 570 | 2010_003409__15 571 | 2010_003446__15 572 | 2010_003506__12 573 | 2010_003531__11 574 | 2010_003532__13 575 | 2010_003597__11 576 | 2010_003597__15 577 | 2010_003746__12 578 | 2010_003746__15 579 | 2010_003947__14 580 | 2010_003971__11 581 | 2010_004042__14 582 | 2010_004165__12 583 | 2010_004165__15 584 | 2010_004219__14 585 | 2010_004219__15 586 | 2010_004337__15 587 | 2010_004355__14 588 | 2010_004432__15 589 | 2010_004472__15 590 | 2010_004479__15 591 | 2010_004519__13 592 | 2010_004550__12 593 | 2010_004559__15 594 | 2010_004628__12 595 | 2010_004697__14 596 | 2010_004697__15 597 | 2010_004795__12 598 | 2010_004815__15 599 | 2010_004825__11 600 | 2010_004828__15 601 | 2010_004856__13 602 | 2010_004941__14 603 | 2010_004951__15 604 | 2010_005046__11 605 | 2010_005046__15 606 | 2010_005118__15 607 | 2010_005159__12 608 | 2010_005160__14 609 | 2010_005166__15 610 | 2010_005174__13 611 | 2010_005206__12 612 | 2010_005245__12 613 | 2010_005245__15 614 | 2010_005252__14 615 | 2010_005252__15 616 | 2010_005284__15 617 | 2010_005366__14 618 | 2010_005433__14 619 | 2010_005501__14 620 | 2010_005575__12 621 | 2010_005582__15 622 | 2010_005606__15 623 | 2010_005626__11 624 | 2010_005626__15 625 | 2010_005644__12 626 | 2010_005709__15 627 | 2010_005871__15 628 | 2010_005991__12 629 | 2010_005991__15 630 | 2010_005992__12 631 | 2011_000045__12 632 | 2011_000051__15 633 | 2011_000054__15 634 | 2011_000178__15 635 | 2011_000226__11 636 | 2011_000248__15 637 | 2011_000338__11 638 | 2011_000396__13 639 | 2011_000435__15 640 | 2011_000438__15 641 | 2011_000455__14 642 | 2011_000455__15 643 | 2011_000479__15 644 | 2011_000512__14 645 | 2011_000526__13 646 | 2011_000536__12 647 | 2011_000566__15 648 | 2011_000585__15 649 | 2011_000598__11 650 | 2011_000618__14 651 | 2011_000618__15 652 | 2011_000638__15 653 | 2011_000780__15 654 | 2011_000809__11 655 | 2011_000809__15 656 | 2011_000843__15 657 | 2011_000953__11 658 | 2011_000953__15 659 | 2011_001014__12 660 | 2011_001060__15 661 | 2011_001069__15 662 | 2011_001071__15 663 | 2011_001159__15 664 | 2011_001276__11 665 | 2011_001276__12 666 | 2011_001276__15 667 | 2011_001346__15 668 | 2011_001416__15 669 | 2011_001447__15 670 | 2011_001530__15 671 | 2011_001567__15 672 | 2011_001619__15 673 | 2011_001642__12 674 | 2011_001665__11 675 | 2011_001674__15 676 | 2011_001714__12 677 | 2011_001714__15 678 | 2011_001722__13 679 | 2011_001745__12 680 | 2011_001794__15 681 | 2011_001862__11 682 | 2011_001862__12 683 | 2011_001868__12 684 | 2011_001984__12 685 | 2011_001988__15 686 | 2011_002002__15 687 | 2011_002040__12 688 | 2011_002075__11 689 | 2011_002075__15 690 | 2011_002098__12 691 | 2011_002110__12 692 | 2011_002110__15 693 | 2011_002121__12 694 | 2011_002124__15 695 | 2011_002156__12 696 | 2011_002200__11 697 | 2011_002200__15 698 | 2011_002247__15 699 | 2011_002279__12 700 | 2011_002298__12 701 | 2011_002308__15 702 | 2011_002317__15 703 | 2011_002322__14 704 | 2011_002322__15 705 | 2011_002343__15 706 | 2011_002358__11 707 | 2011_002358__15 708 | 2011_002371__12 709 | 2011_002498__15 710 | 2011_002509__15 711 | 2011_002532__15 712 | 2011_002575__15 713 | 2011_002578__15 714 | 2011_002589__12 715 | 2011_002623__15 716 | 2011_002641__15 717 | 2011_002675__15 718 | 2011_002951__13 719 | 2011_002997__15 720 | 2011_003019__14 721 | 2011_003019__15 722 | 2011_003085__13 723 | 2011_003114__15 724 | 2011_003240__15 725 | 2011_003256__12 726 | 
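The fold files above and below share one format: each line is a PASCAL VOC image id and a 1-based class index joined by a double underscore, `<image_id>__<class>`, one entry per (image, class) pair to evaluate. The class suffixes run 11-15 in the file above and 16-20 in `fold3.txt` below, consistent with the usual PASCAL-5i protocol of splitting the 20 VOC categories into four folds of five. The helper below is an illustrative sketch of how such a file can be parsed; it is not code from the repository (the actual consumer is `evaluate/pascal_dataloader.py`).

```
from typing import List, Tuple

def read_fold_split(path: str) -> List[Tuple[str, int]]:
    """Parse a fold split file into (image_id, class_index) pairs."""
    pairs = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            image_id, class_idx = line.rsplit("__", 1)
            pairs.append((image_id, int(class_idx)))
    return pairs

# Example: "2007_000042__19" -> ("2007_000042", 19), i.e. image
# JPEGImages/2007_000042.jpg evaluated on VOC category 19 in this fold.
```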
-------------------------------------------------------------------------------- /evaluate/splits/pascal/val/fold3.txt: -------------------------------------------------------------------------------- 1 | 2007_000042__19 2 | 2007_000123__19 3 | 2007_000175__17 4 | 2007_000187__20 5 | 2007_000452__18 6 | 2007_000559__20 7 | 2007_000629__19 8 | 2007_000636__19 9 | 2007_000661__18 10 | 2007_000676__17 11 | 2007_000804__18 12 | 2007_000925__17 13 | 2007_001154__18 14 | 2007_001175__20 15 | 2007_001408__16 16 | 2007_001430__16 17 | 2007_001430__20 18 | 2007_001457__18 19 | 2007_001458__18 20 | 2007_001585__18 21 | 2007_001594__17 22 | 2007_001678__20 23 | 2007_001717__20 24 | 2007_001733__17 25 | 2007_001763__18 26 | 2007_001763__20 27 | 2007_002119__20 28 | 2007_002132__20 29 | 2007_002268__18 30 | 2007_002284__16 31 | 2007_002378__16 32 | 2007_002426__18 33 | 2007_002427__18 34 | 2007_002565__19 35 | 2007_002618__17 36 | 2007_002648__17 37 | 2007_002728__19 38 | 2007_003011__18 39 | 2007_003011__20 40 | 2007_003169__18 41 | 2007_003367__16 42 | 2007_003499__19 43 | 2007_003506__16 44 | 2007_003530__18 45 | 2007_003587__19 46 | 2007_003714__17 47 | 2007_003848__19 48 | 2007_003957__19 49 | 2007_004190__20 50 | 2007_004193__20 51 | 2007_004275__16 52 | 2007_004281__19 53 | 2007_004483__19 54 | 2007_004510__20 55 | 2007_004558__16 56 | 2007_004649__19 57 | 2007_004712__16 58 | 2007_004969__17 59 | 2007_005469__17 60 | 2007_005626__19 61 | 2007_005689__19 62 | 2007_005813__16 63 | 2007_005857__16 64 | 2007_005915__17 65 | 2007_006171__18 66 | 2007_006348__20 67 | 2007_006373__18 68 | 2007_006678__17 69 | 2007_006680__19 70 | 2007_006802__19 71 | 2007_007130__20 72 | 2007_007165__17 73 | 2007_007168__19 74 | 2007_007195__19 75 | 2007_007196__20 76 | 2007_007203__20 77 | 2007_007417__18 78 | 2007_007534__17 79 | 2007_007624__16 80 | 2007_007795__16 81 | 2007_007881__19 82 | 2007_007996__18 83 | 2007_008204__20 84 | 2007_008260__18 85 | 2007_008339__19 86 | 2007_008374__20 87 | 2007_008543__18 88 | 2007_008547__16 89 | 2007_009068__18 90 | 2007_009252__18 91 | 2007_009320__17 92 | 2007_009419__16 93 | 2007_009446__20 94 | 2007_009521__18 95 | 2007_009521__20 96 | 2007_009592__18 97 | 2007_009655__18 98 | 2007_009684__18 99 | 2007_009750__16 100 | 2008_000016__20 101 | 2008_000149__18 102 | 2008_000270__18 103 | 2008_000391__16 104 | 2008_000589__18 105 | 2008_000657__19 106 | 2008_001078__16 107 | 2008_001283__16 108 | 2008_001688__16 109 | 2008_001688__20 110 | 2008_001966__16 111 | 2008_002273__16 112 | 2008_002379__16 113 | 2008_002464__20 114 | 2008_002536__17 115 | 2008_002680__20 116 | 2008_002900__19 117 | 2008_002929__18 118 | 2008_003003__20 119 | 2008_003026__20 120 | 2008_003105__19 121 | 2008_003135__16 122 | 2008_003676__16 123 | 2008_003709__18 124 | 2008_003733__18 125 | 2008_003885__20 126 | 2008_004172__18 127 | 2008_004212__19 128 | 2008_004279__20 129 | 2008_004367__19 130 | 2008_004453__17 131 | 2008_004477__16 132 | 2008_004562__18 133 | 2008_004610__19 134 | 2008_004621__17 135 | 2008_004754__20 136 | 2008_004854__17 137 | 2008_004910__20 138 | 2008_005089__20 139 | 2008_005217__16 140 | 2008_005242__16 141 | 2008_005254__20 142 | 2008_005439__20 143 | 2008_005445__20 144 | 2008_005544__19 145 | 2008_005633__17 146 | 2008_005680__16 147 | 2008_006055__19 148 | 2008_006159__20 149 | 2008_006327__17 150 | 2008_006523__19 151 | 2008_006553__19 152 | 2008_006752__19 153 | 2008_006784__18 154 | 2008_006835__17 155 | 2008_007497__17 156 | 2008_007527__20 157 | 2008_007677__17 158 | 
2008_007814__17 159 | 2008_007828__20 160 | 2008_008103__18 161 | 2008_008221__19 162 | 2008_008434__16 163 | 2009_000022__19 164 | 2009_000039__17 165 | 2009_000087__18 166 | 2009_000096__18 167 | 2009_000136__20 168 | 2009_000242__18 169 | 2009_000391__20 170 | 2009_000418__16 171 | 2009_000418__18 172 | 2009_000487__18 173 | 2009_000488__16 174 | 2009_000488__20 175 | 2009_000628__19 176 | 2009_000675__17 177 | 2009_000704__20 178 | 2009_000712__19 179 | 2009_000732__18 180 | 2009_000845__19 181 | 2009_000924__17 182 | 2009_001300__19 183 | 2009_001333__19 184 | 2009_001363__20 185 | 2009_001505__17 186 | 2009_001644__16 187 | 2009_001644__18 188 | 2009_001644__20 189 | 2009_001684__16 190 | 2009_001731__18 191 | 2009_001768__17 192 | 2009_001775__16 193 | 2009_001775__18 194 | 2009_001991__17 195 | 2009_002082__17 196 | 2009_002094__20 197 | 2009_002202__19 198 | 2009_002265__19 199 | 2009_002291__19 200 | 2009_002346__18 201 | 2009_002366__20 202 | 2009_002390__18 203 | 2009_002487__16 204 | 2009_002562__20 205 | 2009_002568__19 206 | 2009_002571__16 207 | 2009_002571__18 208 | 2009_002573__20 209 | 2009_002584__16 210 | 2009_002638__19 211 | 2009_002732__18 212 | 2009_002887__19 213 | 2009_002982__19 214 | 2009_003105__19 215 | 2009_003123__18 216 | 2009_003299__19 217 | 2009_003311__19 218 | 2009_003433__19 219 | 2009_003523__20 220 | 2009_003551__20 221 | 2009_003564__16 222 | 2009_003564__18 223 | 2009_003607__18 224 | 2009_003666__17 225 | 2009_003857__20 226 | 2009_003895__18 227 | 2009_003895__20 228 | 2009_003938__19 229 | 2009_004099__18 230 | 2009_004140__18 231 | 2009_004255__19 232 | 2009_004298__18 233 | 2009_004687__18 234 | 2009_004730__19 235 | 2009_004799__19 236 | 2009_004993__18 237 | 2009_004993__20 238 | 2009_005148__19 239 | 2009_005220__19 240 | 2010_000256__18 241 | 2010_000284__18 242 | 2010_000309__17 243 | 2010_000318__20 244 | 2010_000330__16 245 | 2010_000639__16 246 | 2010_000738__20 247 | 2010_000764__19 248 | 2010_001011__17 249 | 2010_001079__17 250 | 2010_001104__19 251 | 2010_001149__18 252 | 2010_001151__19 253 | 2010_001246__16 254 | 2010_001256__17 255 | 2010_001327__18 256 | 2010_001367__20 257 | 2010_001522__17 258 | 2010_001557__17 259 | 2010_001577__17 260 | 2010_001699__16 261 | 2010_001734__19 262 | 2010_001752__20 263 | 2010_001767__18 264 | 2010_001773__16 265 | 2010_001851__16 266 | 2010_001951__19 267 | 2010_001962__18 268 | 2010_002106__17 269 | 2010_002137__16 270 | 2010_002137__18 271 | 2010_002232__17 272 | 2010_002531__18 273 | 2010_002682__19 274 | 2010_002921__20 275 | 2010_003014__18 276 | 2010_003123__16 277 | 2010_003302__16 278 | 2010_003514__19 279 | 2010_003541__17 280 | 2010_003597__18 281 | 2010_003781__16 282 | 2010_003956__19 283 | 2010_004149__19 284 | 2010_004226__17 285 | 2010_004382__16 286 | 2010_004479__20 287 | 2010_004757__16 288 | 2010_004757__18 289 | 2010_004783__18 290 | 2010_004825__16 291 | 2010_004857__20 292 | 2010_004951__19 293 | 2010_004980__19 294 | 2010_005180__18 295 | 2010_005187__16 296 | 2010_005305__20 297 | 2010_005606__18 298 | 2010_005706__19 299 | 2010_005719__17 300 | 2010_005727__19 301 | 2010_005788__17 302 | 2010_005860__16 303 | 2010_005871__19 304 | 2010_005991__18 305 | 2010_006054__19 306 | 2011_000070__18 307 | 2011_000173__18 308 | 2011_000283__19 309 | 2011_000291__19 310 | 2011_000310__18 311 | 2011_000436__17 312 | 2011_000521__19 313 | 2011_000747__16 314 | 2011_001005__18 315 | 2011_001060__19 316 | 2011_001281__19 317 | 2011_001350__17 318 | 2011_001567__18 319 | 
2011_001601__18 320 | 2011_001614__19 321 | 2011_001674__18 322 | 2011_001713__16 323 | 2011_001713__18 324 | 2011_001726__20 325 | 2011_001794__18 326 | 2011_001862__18 327 | 2011_001863__16 328 | 2011_001910__20 329 | 2011_002124__18 330 | 2011_002156__20 331 | 2011_002178__17 332 | 2011_002247__19 333 | 2011_002379__19 334 | 2011_002391__18 335 | 2011_002532__20 336 | 2011_002535__19 337 | 2011_002644__18 338 | 2011_002644__20 339 | 2011_002879__18 340 | 2011_002879__20 341 | 2011_003103__16 342 | 2011_003103__18 343 | 2011_003146__19 344 | 2011_003182__18 345 | 2011_003197__19 346 | 2011_003256__18 347 | -------------------------------------------------------------------------------- /evaluate_detection/2012_support_set.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate_detection/2012_support_set.pth -------------------------------------------------------------------------------- /evaluate_detection/2012_val_flattened_set.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate_detection/2012_val_flattened_set.pth -------------------------------------------------------------------------------- /evaluate_detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/evaluate_detection/__init__.py -------------------------------------------------------------------------------- /evaluate_detection/box_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def to_rectangle(img, start_h=113, start_w=113): 8 | ''' 9 | assuming image is a binary mask 10 | ''' 11 | # from matplotlib import pyplot as plt 12 | 13 | img_np = img.numpy().astype('uint8')[start_h:, start_w:, 0] 14 | num_labels, labels_im = cv2.connectedComponents(img_np) 15 | new_img = np.zeros((img_np.shape[0], img_np.shape[1])) 16 | indices = np.argsort([np.sum(labels_im == i) for i in range(num_labels)])[::-1] 17 | for i in indices: 18 | indices_y, indices_x = np.where(labels_im == i) 19 | if img_np[indices_y[0], indices_x[0]] != 255: 20 | continue 21 | new_img[np.min(indices_y): np.max(indices_y) + 1, np.min(indices_x): np.max(indices_x) + 1] = 255 22 | break 23 | new_img = torch.tensor(new_img) 24 | new_img = torch.stack([new_img, new_img, new_img], dim=-1) 25 | img[start_h:, start_w:] = new_img 26 | 27 | return img 28 | 29 | def box_cxcywh_to_xyxy(x): 30 | x_c, y_c, w, h = x.unbind(-1) 31 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 32 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 33 | return torch.stack(b, dim=-1) 34 | 35 | 36 | def box_xyxy_to_cxcywh(x): 37 | x0, y0, x1, y1 = x.unbind(-1) 38 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 39 | (x1 - x0), (y1 - y0)] 40 | return torch.stack(b, dim=-1) 41 | 42 | -------------------------------------------------------------------------------- /evaluate_detection/canvas_ds.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from evaluate_detection.voc_orig import VOCDetection as VOCDetectionOrig, make_transforms 3 | import cv2 4 | from evaluate.pascal_dataloader import create_grid_from_images_old as 
create_grid_from_images 5 | from PIL import Image 6 | from evaluate.mae_utils import * 7 | from matplotlib import pyplot as plt 8 | import torch 9 | import numpy as np 10 | import torchvision.transforms as T 11 | 12 | 13 | def box_to_img(mask, target, border_width=4): 14 | if mask is None: 15 | mask = np.zeros((112, 112, 3)) 16 | h, w, _ = mask.shape 17 | for box in target['boxes']: 18 | x_min, y_min, x_max, y_max = list((box * (h - 1)).round().int().numpy()) 19 | cv2.rectangle(mask, (x_min, y_min), (x_max, y_max), (255, 255, 255), border_width) 20 | return Image.fromarray(mask.astype('uint8')) 21 | 22 | 23 | def get_annotated_image(img, boxes, border_width=3, mode='draw', bgcolor='white', fg='image'): 24 | if mode == 'draw': 25 | image_copy = np.array(img.copy()) 26 | for box in boxes: 27 | box = box.numpy().astype('int') 28 | cv2.rectangle(image_copy, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), border_width) 29 | elif mode == 'keep': 30 | image_copy = np.array(Image.new('RGB', (img.shape[1], img.shape[0]), color=bgcolor)) 31 | 32 | for box in boxes: 33 | box = box.numpy().astype('int') 34 | if fg == 'image': 35 | image_copy[box[1]:box[3], box[0]:box[2]] = img[box[1]:box[3], box[0]:box[2]] 36 | elif fg == 'white': 37 | image_copy[box[1]:box[3], box[0]:box[2]] = 255 38 | 39 | 40 | 41 | 42 | return image_copy 43 | 44 | 45 | 46 | 47 | # ids_shuffle, len_keep = generate_mask_for_evaluation_2rows() 48 | 49 | class CanvasDataset(data.Dataset): 50 | 51 | def __init__(self, pascal_path='/shared/yossi_gandelsman/code/occlusionwalk/pascal', years=("2012",), random=False, **kwargs): 52 | self.train_ds = VOCDetectionOrig(pascal_path, years, image_sets=['train'], transforms=None) 53 | self.val_ds = VOCDetectionOrig(pascal_path, years, image_sets=['val'], transforms=None) 54 | self.background_transforms = T.Compose([ 55 | T.Resize((224, 224)), 56 | T.Compose([ 57 | T.ToTensor(), 58 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 59 | ]) 60 | ]) 61 | self.transforms = make_transforms('val') 62 | self.random = random 63 | 64 | 65 | def __len__(self): 66 | return len(self.val_ds) 67 | 68 | def __getitem__(self, idx): 69 | 70 | query_image, query_target = self.val_ds[idx] 71 | # should we run on all classes? 72 | label = np.random.choice(query_target['labels']).item() 73 | 74 | # how many supports should we use? 
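        # The loop below walks a shuffled copy of the train set and keeps the
        # first image containing the sampled class as the in-context "support"
        # example (any image when self.random is set). The support/query images
        # and their white-on-black box masks are then stitched by
        # create_grid_from_images into a single 224x224 canvas -- a 2x2 grid of
        # roughly 111x111 panels (rec_size = imgs_size // 2 - padding in
        # make_transforms), example pair on one row and query pair on the other;
        # see create_grid_from_images_old for the exact placement. At evaluation
        # time the query-mask quadrant is masked out and inpainted by the MAE
        # (cf. the commented-out generate_mask_for_evaluation_2rows above).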
75 | indices = np.arange(len(self.train_ds)) 76 | np.random.shuffle(indices) 77 | 78 | for idx in indices: 79 | support_image, support_target = self.train_ds[idx] 80 | if torch.any(support_target['labels'] == label).item() or self.random: 81 | break 82 | 83 | boxes = support_target['boxes'][torch.where(support_target['labels'] == label)[0]] 84 | support_image_copy = get_annotated_image(np.array(support_image), boxes, border_width=-1, mode='keep', bgcolor='black', fg='white') 85 | support_image_copy_pil = Image.fromarray(support_image_copy) 86 | 87 | boxes = query_target['boxes'][torch.where(query_target['labels'] == label)[0]] 88 | query_image_copy = get_annotated_image(np.array(query_image), boxes, border_width=-1, mode='keep', bgcolor='black', fg='white') 89 | query_image_copy_pil = Image.fromarray(query_image_copy) 90 | 91 | query_image_ten = self.transforms(query_image, None)[0] 92 | query_target_ten = self.transforms(query_image_copy_pil, None)[0] 93 | support_target_ten = self.transforms(support_image_copy_pil, None)[0] 94 | support_image_ten = self.transforms(support_image, None)[0] 95 | 96 | background_image = Image.new('RGB', (224, 224), color='white') 97 | background_image = self.background_transforms(background_image) 98 | canvas = create_grid_from_images(background_image, support_image_ten, support_target_ten, query_image_ten, 99 | query_target_ten) 100 | 101 | return {'grid': canvas} 102 | 103 | -------------------------------------------------------------------------------- /evaluate_detection/transforms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Transforms and data augmentation for both image + bbox. 12 | """ 13 | import random 14 | import torchvision.transforms.functional_tensor as TF 15 | 16 | import PIL 17 | import torch 18 | import torchvision.transforms as T 19 | import torchvision.transforms.functional as F 20 | from torchvision.transforms import RandomResizedCrop 21 | 22 | from evaluate_detection.box_ops import box_xyxy_to_cxcywh, box_cxcywh_to_xyxy 23 | from evaluate_detection.misc import interpolate 24 | 25 | 26 | def crop(image, target, region): 27 | cropped_image = F.crop(image, *region) 28 | 29 | target = target.copy() 30 | i, j, h, w = region 31 | 32 | # should we do something wrt the original size? 33 | target["size"] = torch.tensor([h, w]) 34 | 35 | fields = ["labels", "area", "iscrowd", "patches"] 36 | 37 | if "boxes" in target: 38 | boxes = target["boxes"] 39 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 40 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 41 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 42 | cropped_boxes = cropped_boxes.clamp(min=0) 43 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 44 | target["boxes"] = cropped_boxes.reshape(-1, 4) 45 | target["area"] = area 46 | fields.append("boxes") 47 | 48 | if "masks" in target: 49 | # FIXME should we update the area here if there are no boxes? 
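        # The slice below crops the masks to the same (i, j, h, w) window used
        # by F.crop above. Afterwards, the keep-filter drops instances whose
        # cropped box collapses to zero area (both x1 > x0 and y1 > y0 must
        # hold) or, when only masks exist, whose cropped mask is empty.
        # Worked example of the box arithmetic above: with region
        # (i=10, j=20, h=100, w=100) and box [15, 5, 130, 80],
        # [15, 5, 130, 80] - [20, 10, 20, 10] = [-5, -5, 110, 70]
        # -> min with (w, h), clamp(min=0) -> [0, 0, 100, 70], area = 7000.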
50 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 51 | fields.append("masks") 52 | 53 | # remove elements for which the boxes or masks that have zero area 54 | if "boxes" in target or "masks" in target: 55 | # favor boxes selection when defining which elements to keep 56 | # this is compatible with previous implementation 57 | if "boxes" in target: 58 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 59 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 60 | else: 61 | keep = target['masks'].flatten(1).any(1) 62 | 63 | for field in fields: 64 | if field in target: 65 | target[field] = target[field][keep] 66 | 67 | return cropped_image, target 68 | 69 | 70 | def hflip(image, target): 71 | flipped_image = F.hflip(image) 72 | 73 | w, h = image.size 74 | 75 | target = target.copy() 76 | if "boxes" in target: 77 | boxes = target["boxes"] 78 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 79 | target["boxes"] = boxes 80 | 81 | if "masks" in target: 82 | target['masks'] = target['masks'].flip(-1) 83 | 84 | return flipped_image, target 85 | 86 | 87 | def resize(image, target, size, max_size=None): 88 | # size can be min_size (scalar) or (w, h) tuple 89 | 90 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 91 | w, h = image_size 92 | if max_size is not None: 93 | min_original_size = float(min((w, h))) 94 | max_original_size = float(max((w, h))) 95 | if max_original_size / min_original_size * size > max_size: 96 | size = int(round(max_size * min_original_size / max_original_size)) 97 | 98 | if (w <= h and w == size) or (h <= w and h == size): 99 | return (h, w) 100 | 101 | if w < h: 102 | ow = size 103 | oh = int(size * h / w) 104 | else: 105 | oh = size 106 | ow = int(size * w / h) 107 | 108 | return (oh, ow) 109 | 110 | def get_size(image_size, size, max_size=None): 111 | if isinstance(size, (list, tuple)): 112 | return size[::-1] 113 | else: 114 | return get_size_with_aspect_ratio(image_size, size, max_size) 115 | 116 | size = get_size(image.size, size, max_size) 117 | rescaled_image = F.resize(image, size) 118 | 119 | if target is None: 120 | return rescaled_image, None 121 | 122 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 123 | ratio_width, ratio_height = ratios 124 | 125 | target = target.copy() 126 | if "boxes" in target: 127 | boxes = target["boxes"] 128 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 129 | target["boxes"] = scaled_boxes 130 | 131 | if "area" in target: 132 | area = target["area"] 133 | scaled_area = area * (ratio_width * ratio_height) 134 | target["area"] = scaled_area 135 | 136 | h, w = size 137 | target["size"] = torch.tensor([h, w]) 138 | 139 | if "masks" in target: 140 | target['masks'] = interpolate( 141 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 142 | 143 | return rescaled_image, target 144 | 145 | 146 | def pad(image, target, padding): 147 | # assumes that we only pad on the bottom right corners 148 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 149 | if target is None: 150 | return padded_image, None 151 | target = target.copy() 152 | # should we do something wrt the original size? 
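    # Padding is applied to the bottom/right only (F.pad(image, (0, 0, pad_x, pad_y))),
    # so existing box coordinates remain valid and only "size" needs updating.
    # NOTE: the next line passes padded_image[::-1] to torch.tensor; given the
    # convention used in resize() (target["size"] = torch.tensor([h, w])), the
    # intended expression is presumably padded_image.size[::-1], i.e. the PIL
    # (w, h) size tuple reversed to (h, w).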
153 | target["size"] = torch.tensor(padded_image[::-1]) 154 | if "masks" in target: 155 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 156 | return padded_image, target 157 | 158 | 159 | class RandomCrop(object): 160 | def __init__(self, size): 161 | self.size = size 162 | 163 | def __call__(self, img, target): 164 | region = T.RandomCrop.get_params(img, self.size) 165 | return crop(img, target, region) 166 | 167 | 168 | class RandomSizeCrop(object): 169 | def __init__(self, min_size: int, max_size: int): 170 | self.min_size = min_size 171 | self.max_size = max_size 172 | 173 | def __call__(self, img: PIL.Image.Image, target: dict): 174 | w = random.randint(self.min_size, min(img.width, self.max_size)) 175 | h = random.randint(self.min_size, min(img.height, self.max_size)) 176 | region = T.RandomCrop.get_params(img, [h, w]) 177 | return crop(img, target, region) 178 | 179 | 180 | class CenterCrop(object): 181 | def __init__(self, size): 182 | self.size = size 183 | 184 | def __call__(self, img, target): 185 | image_width, image_height = img.size 186 | crop_height, crop_width = self.size 187 | crop_top = int(round((image_height - crop_height) / 2.)) 188 | crop_left = int(round((image_width - crop_width) / 2.)) 189 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 190 | 191 | 192 | class RandomHorizontalFlip(object): 193 | def __init__(self, p=0.5): 194 | self.p = p 195 | 196 | def __call__(self, img, target): 197 | if random.random() < self.p: 198 | return hflip(img, target) 199 | return img, target 200 | 201 | 202 | class RandomResize(object): 203 | def __init__(self, sizes, max_size=None): 204 | assert isinstance(sizes, (list, tuple)) 205 | self.sizes = sizes 206 | self.max_size = max_size 207 | 208 | def __call__(self, img, target=None): 209 | size = random.choice(self.sizes) 210 | return resize(img, target, size, self.max_size) 211 | 212 | 213 | class RandomPad(object): 214 | def __init__(self, max_pad): 215 | self.max_pad = max_pad 216 | 217 | def __call__(self, img, target): 218 | pad_x = random.randint(0, self.max_pad) 219 | pad_y = random.randint(0, self.max_pad) 220 | return pad(img, target, (pad_x, pad_y)) 221 | 222 | 223 | class RandomSelect(object): 224 | """ 225 | Randomly selects between transforms1 and transforms2, 226 | with probability p for transforms1 and (1 - p) for transforms2 227 | """ 228 | def __init__(self, transforms1, transforms2, p=0.5): 229 | self.transforms1 = transforms1 230 | self.transforms2 = transforms2 231 | self.p = p 232 | 233 | def __call__(self, img, target): 234 | if random.random() < self.p: 235 | return self.transforms1(img, target) 236 | return self.transforms2(img, target) 237 | 238 | 239 | class ToTensor(object): 240 | def __call__(self, img, target): 241 | return F.to_tensor(img), target 242 | 243 | 244 | class RandomErasing(object): 245 | 246 | def __init__(self, *args, **kwargs): 247 | self.eraser = T.RandomErasing(*args, **kwargs) 248 | 249 | def __call__(self, img, target): 250 | return self.eraser(img), target 251 | 252 | 253 | class Normalize(object): 254 | def __init__(self, mean, std): 255 | self.mean = mean 256 | self.std = std 257 | 258 | def __call__(self, image, target=None): 259 | image = F.normalize(image, mean=self.mean, std=self.std) 260 | if target is None: 261 | return image, None 262 | target = target.copy() 263 | h, w = image.shape[-2:] 264 | if "boxes" in target: 265 | boxes = target["boxes"] 266 | boxes = boxes / torch.tensor([w, h, w, h], 
dtype=torch.float32) 267 | target["boxes"] = boxes 268 | return image, target 269 | 270 | 271 | class Compose(object): 272 | def __init__(self, transforms): 273 | self.transforms = transforms 274 | 275 | def __call__(self, image, target): 276 | for t in self.transforms: 277 | image, target = t(image, target) 278 | return image, target 279 | 280 | def __repr__(self): 281 | format_string = self.__class__.__name__ + "(" 282 | for t in self.transforms: 283 | format_string += "\n" 284 | format_string += " {0}".format(t) 285 | format_string += "\n)" 286 | return format_string 287 | -------------------------------------------------------------------------------- /evaluate_detection/voc_orig.py: -------------------------------------------------------------------------------- 1 | import evaluate_detection.transforms as T 2 | # partly taken from https://github.com/pytorch/vision/blob/master/torchvision/datasets/voc.py 3 | import functools 4 | import torch 5 | 6 | import os 7 | import tarfile 8 | import collections 9 | 10 | from torchvision.datasets import VisionDataset 11 | import xml.etree.ElementTree as ET 12 | from PIL import Image 13 | from torchvision.datasets.utils import download_url 14 | 15 | 16 | 17 | CLASS_NAMES = ( 18 | "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", 19 | "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", 20 | "pottedplant", "sheep", "sofa", "train", "tvmonitor" 21 | ) 22 | 23 | DATASET_YEAR_DICT = { 24 | '2012': { 25 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', 26 | 'filename': 'VOCtrainval_11-May-2012.tar', 27 | 'md5': '6cd6e144f989b92b3379bac3b3de84fd', 28 | 'base_dir': os.path.join('VOCdevkit', 'VOC2012') 29 | }, 30 | '2011': { 31 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar', 32 | 'filename': 'VOCtrainval_25-May-2011.tar', 33 | 'md5': '6c3384ef61512963050cb5d687e5bf1e', 34 | 'base_dir': os.path.join('TrainVal', 'VOCdevkit', 'VOC2011') 35 | }, 36 | '2010': { 37 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', 38 | 'filename': 'VOCtrainval_03-May-2010.tar', 39 | 'md5': 'da459979d0c395079b5c75ee67908abb', 40 | 'base_dir': os.path.join('VOCdevkit', 'VOC2010') 41 | }, 42 | '2009': { 43 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar', 44 | 'filename': 'VOCtrainval_11-May-2009.tar', 45 | 'md5': '59065e4b188729180974ef6572f6a212', 46 | 'base_dir': os.path.join('VOCdevkit', 'VOC2009') 47 | }, 48 | '2008': { 49 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar', 50 | 'filename': 'VOCtrainval_11-May-2012.tar', 51 | 'md5': '2629fa636546599198acfcfbfcf1904a', 52 | 'base_dir': os.path.join('VOCdevkit', 'VOC2008') 53 | }, 54 | '2007': { 55 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', 56 | 'filename': 'VOCtrainval_06-Nov-2007.tar', 57 | 'md5': 'c52e279531787c972589f7e41ab4ae64', 58 | 'base_dir': os.path.join('VOCdevkit', 'VOC2007') 59 | }, 60 | '2007-test': { 61 | 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar', 62 | 'filename': 'VOCtest_06-Nov-2007.tar', 63 | 'md5': 'b6e924de25625d8de591ea690078ad9f', 64 | 'base_dir': os.path.join('VOCdevkit', 'VOC2007') 65 | } 66 | } 67 | 68 | def make_transforms(image_set, imgs_size=224, padding=1): 69 | normalize = T.Compose([ 70 | T.ToTensor(), 71 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 72 | ]) 73 | 74 | rec_size = 
imgs_size // 2 - padding 75 | scales = [(rec_size, rec_size)] 76 | 77 | if image_set == 'train': 78 | return T.Compose([ 79 | T.RandomHorizontalFlip(), 80 | T.RandomResize(scales), 81 | normalize, 82 | ]) 83 | 84 | if image_set == 'val': 85 | return T.Compose([ 86 | T.RandomResize(scales), 87 | normalize, 88 | ]) 89 | 90 | raise ValueError(f'unknown {image_set}') 91 | class VOCDetection(VisionDataset): 92 | """`Pascal VOC `_ Detection Dataset. 93 | Args: 94 | root (string): Root directory of the VOC Dataset. 95 | year (string, optional): The dataset year, supports years 2007 to 2012. 96 | image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` 97 | download (bool, optional): If true, downloads the dataset from the internet and 98 | puts it in root directory. If dataset is already downloaded, it is not 99 | downloaded again. 100 | (default: alphabetic indexing of VOC's 20 classes). 101 | transform (callable, optional): A function/transform that takes in an PIL image 102 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 103 | target_transform (callable, required): A function/transform that takes in the 104 | target and transforms it. 105 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 106 | and returns a transformed version. 107 | """ 108 | 109 | def __init__(self, 110 | root, 111 | years='2012', 112 | image_sets='train', 113 | transform=None, 114 | target_transform=None, 115 | transforms=None, 116 | no_cats=False, 117 | keep_single_objs_only=1, 118 | filter_by_mask_size=1): 119 | super(VOCDetection, self).__init__(root, transforms, transform, target_transform) 120 | self.images = [] 121 | self.annotations = [] 122 | self.imgids = [] 123 | self.imgid2annotations = {} 124 | self.image_set = [] 125 | 126 | self.CLASS_NAMES = CLASS_NAMES 127 | self.MAX_NUM_OBJECTS = 64 128 | self.no_cats = no_cats 129 | base_dir = os.path.dirname(os.path.abspath(__file__)) 130 | # loading random per class support, randomly chosen from pascal 2012 train partition 131 | self.support_set = torch.load(os.path.join(base_dir, '2012_support_set.pth')) 132 | # load pascal 2012 val samples that have single object and occupy less than 20% of the image. 133 | self.val_flattened_set = torch.load(os.path.join(base_dir, '2012_val_flattened_set.pth')) 134 | 135 | 136 | for year, image_set in zip(years, image_sets): 137 | 138 | if year == "2007" and image_set == "test": 139 | year = "2007-test" 140 | valid_sets = ["train", "trainval", "val"] 141 | if year == "2007-test": 142 | valid_sets.append("test") 143 | 144 | base_dir = DATASET_YEAR_DICT[year]['base_dir'] 145 | voc_root = os.path.join(self.root, base_dir) 146 | image_dir = os.path.join(voc_root, 'JPEGImages') 147 | annotation_dir = os.path.join(voc_root, 'Annotations') 148 | 149 | if not os.path.isdir(voc_root): 150 | raise RuntimeError('Dataset not found or corrupted.' 
+ 151 | ' You can use download=True to download it') 152 | file_names = self.extract_fns(image_set, voc_root) 153 | self.image_set.extend(file_names) 154 | 155 | self.images.extend([os.path.join(image_dir, x + ".jpg") for x in file_names]) 156 | self.annotations.extend([os.path.join(annotation_dir, x + ".xml") for x in file_names]) 157 | 158 | self.imgids.extend(self.convert_image_id(x, to_integer=True) for x in file_names) 159 | self.imgid2annotations.update(dict(zip(self.imgids, self.annotations))) 160 | 161 | if keep_single_objs_only: 162 | single_indices = [] 163 | for index in range(len(self.imgids)): 164 | target, instances = self.load_instances(self.imgids[index]) 165 | if len(instances) == 1: 166 | single_indices.append(index) 167 | self.images = [self.images[i] for i in range(len(self.images)) if i in single_indices] 168 | self.annotations = [self.annotations[i] for i in range(len(self.annotations)) if i in single_indices] 169 | self.imgids = [self.imgids[i] for i in range(len(self.imgids)) if i in single_indices] 170 | 171 | if filter_by_mask_size: 172 | valid_mask_size_indices = [] 173 | for index in range(len(self.imgids)): 174 | target, instances = self.load_instances(self.imgids[index]) 175 | s = target['annotation']['size'] 176 | image_area = int(s['width'])*int(s['height']) 177 | instance_area = instances[0]['area'] 178 | frac = instance_area / image_area 179 | if frac < 0.2: 180 | valid_mask_size_indices.append(index) 181 | self.images = [self.images[i] for i in range(len(self.images)) if i in valid_mask_size_indices] 182 | self.annotations = [self.annotations[i] for i in range(len(self.annotations)) if i in valid_mask_size_indices] 183 | self.imgids = [self.imgids[i] for i in range(len(self.imgids)) if i in valid_mask_size_indices] 184 | 185 | 186 | 187 | assert (len(self.images) == len(self.annotations) == len(self.imgids)) 188 | 189 | @staticmethod 190 | def convert_image_id(img_id, to_integer=False, to_string=False, prefix='2021'): 191 | if to_integer: 192 | return int(prefix + img_id.replace('_', '')) 193 | if to_string: 194 | x = str(img_id) 195 | assert x.startswith(prefix) 196 | x = x[len(prefix):] 197 | if len(x) == 6: 198 | return x 199 | return x[:4] + '_' + x[4:] 200 | 201 | @functools.lru_cache(maxsize=None) 202 | def load_instances(self, img_id): 203 | tree = ET.parse(self.imgid2annotations[img_id]) 204 | target = self.parse_voc_xml(tree.getroot()) 205 | 206 | image_id = target['annotation']['filename'] 207 | instances = [] 208 | for obj in target['annotation']['object']: 209 | cls = obj["name"] 210 | # We include "difficult" samples in training. 211 | # Based on limited experiments, they don't hurt accuracy. 212 | difficult = int(obj["difficult"]) 213 | # if difficult == 1: 214 | # continue 215 | bbox = obj["bndbox"] 216 | bbox = [float(bbox[x]) for x in ["xmin", "ymin", "xmax", "ymax"]] 217 | # Original annotations are integers in the range [1, W or H] 218 | # Assuming they mean 1-based pixel indices (inclusive), 219 | # a box with annotation (xmin=1, xmax=W) covers the whole image. 
220 | # In coordinate space this is represented by (xmin=0, xmax=W) 221 | bbox[0] -= 1.0 222 | bbox[1] -= 1.0 223 | instance = dict( 224 | category_id=1 if self.no_cats else CLASS_NAMES.index(cls), 225 | bbox=bbox, 226 | area=(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]), 227 | difficult=difficult, 228 | image_id=img_id 229 | ) 230 | instances.append(instance) 231 | 232 | assert len(instances) <= self.MAX_NUM_OBJECTS 233 | return target, instances 234 | 235 | def extract_fns(self, image_set, voc_root): 236 | splits_dir = os.path.join(voc_root, 'ImageSets/Main') 237 | split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') 238 | with open(os.path.join(split_f), "r") as f: 239 | file_names = [x.strip() for x in f.readlines()] 240 | return file_names 241 | 242 | def __getitem__(self, idx): 243 | """ 244 | Args: 245 | index (int): Index 246 | Returns: 247 | tuple: (image, target) where target is a dictionary of the XML tree. 248 | """ 249 | index, label = self.val_flattened_set[idx] 250 | img = Image.open(self.images[index]).convert('RGB') 251 | target, instances = self.load_instances(self.imgids[index]) 252 | # keep instance with a same label 253 | w, h = map(target['annotation']['size'].get, ['width', 'height']) 254 | target = dict( 255 | image_id=torch.tensor([self.imgids[index]], dtype=torch.int64), 256 | labels=torch.tensor([i['category_id'] for i in instances], dtype=torch.int64), 257 | area=torch.tensor([i['area'] for i in instances], dtype=torch.float32), 258 | boxes=torch.as_tensor([i['bbox'] for i in instances], dtype=torch.float32), 259 | orig_size=torch.as_tensor([int(h), int(w)]), 260 | size=torch.as_tensor([int(h), int(w)]), 261 | iscrowd=torch.zeros(len(instances), dtype=torch.uint8) 262 | ) 263 | 264 | if self.transforms is not None: 265 | img, target = self.transforms(img, target) 266 | 267 | return img, target 268 | 269 | def __len__(self): 270 | return len(self.imgids) 271 | 272 | def parse_voc_xml(self, node): 273 | voc_dict = {} 274 | children = list(node) 275 | if children: 276 | def_dic = collections.defaultdict(list) 277 | for dc in map(self.parse_voc_xml, children): 278 | for ind, v in dc.items(): 279 | def_dic[ind].append(v) 280 | if node.tag == 'annotation': 281 | def_dic['object'] = [def_dic['object']] 282 | voc_dict = { 283 | node.tag: 284 | {ind: v[0] if len(v) == 1 else v 285 | for ind, v in def_dic.items()} 286 | } 287 | if node.text: 288 | text = node.text.strip() 289 | if not children: 290 | voc_dict[node.tag] = text 291 | return voc_dict 292 | 293 | 294 | def download_extract(url, root, filename, md5): 295 | download_url(url, root, filename, md5) 296 | with tarfile.open(os.path.join(root, filename), "r") as tar: 297 | tar.extractall(path=root) 298 | 299 | 300 | -------------------------------------------------------------------------------- /figures_dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amirbar/visual_prompting/6351e0d786c2d25bfdd5fe504153b1dfb345fa81/figures_dataset/__init__.py -------------------------------------------------------------------------------- /figures_dataset/download_links.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from multiprocessing import Pool 3 | 4 | import pandas as pd 5 | import glob 6 | import logging 7 | import shutil 8 | import os 9 | 10 | from PIL import Image 11 | from arxiv import Search, Client 12 | from pdf2image import convert_from_path 13 | import tarfile 
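# Pipeline implemented in this script: for each paper id listed in the CSV,
# download the LaTeX source tarball through the arxiv client, extract only
# figure-like members (see _check_file_name, which keeps common image/PDF
# extensions and skips architecture/template figures), rasterize PDFs with
# pdf2image, and save everything resized to a MAX_SIZE thumbnail
# (_resize_and_save).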
14 | import PIL 15 | from tqdm import tqdm 16 | import numpy as np 17 | 18 | logger = logging.getLogger(__name__) 19 | CHUNK_SIZE = 100 20 | MAX_SIZE = 1024 21 | big_slow_client = Client( 22 | page_size=CHUNK_SIZE, 23 | delay_seconds=10, 24 | num_retries=100 25 | ) 26 | 27 | 28 | def download_paper_from_ids(paper, paper_id, directory, filter_files=None): 29 | file_location = paper.download_source() 30 | assert file_location.endswith(".tar.gz") 31 | return extract_file(file_location, directory, paper_id, filter_files) 32 | 33 | 34 | def _resize_and_save(path, file, paper_id): 35 | saved_fp = os.path.join(path, file.name) 36 | try: 37 | if file.name.endswith(".pdf"): 38 | images = convert_from_path(saved_fp) 39 | output_path = os.path.join(path, f"{file.name}_{paper_id}_0.png") 40 | else: 41 | images = [Image.open(saved_fp)] 42 | output_path = saved_fp 43 | 44 | os.remove(saved_fp) 45 | 46 | if len(images) != 1: 47 | return 48 | 49 | img = images[0] 50 | img = img.convert('RGB') 51 | img.thumbnail((MAX_SIZE, MAX_SIZE)) 52 | img.save(output_path, "PNG") 53 | except PIL.Image.DecompressionBombError: 54 | print("DecompressionBombError") 55 | 56 | 57 | def extract_file(file_location, directory, paper_id, filter_files): 58 | tar_file = tarfile.open(file_location, 'r:gz') 59 | path = os.path.join(directory, paper_id) 60 | os.makedirs(path, exist_ok=True) 61 | 62 | while True: 63 | next_file = tar_file.next() 64 | if next_file is None: 65 | break 66 | if not _check_file_name(next_file.name): 67 | continue 68 | 69 | next_file.name = next_file.name.replace('/', '_') 70 | if filter_files is not None and next_file.name not in filter_files: 71 | continue 72 | tar_file.extract(next_file, path) 73 | # print('Extracted {}'.format(next_file.name)) 74 | _resize_and_save(path, next_file, paper_id) 75 | tar_file.close() 76 | os.remove(file_location) 77 | return path 78 | 79 | 80 | def url_to_id(url: str) -> str: 81 | """ 82 | Parse the given URL of the form `https://arxiv.org/abs/1907.13625` to the id `1907.13625`. 83 | 84 | Args: 85 | url: Input arxiv URL. 86 | 87 | Returns: 88 | str: ArXiv article ID. 
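    Examples:
        >>> url_to_id("https://arxiv.org/abs/1907.13625")
        '1907.13625'
        >>> url_to_id("https://arxiv.org/pdf/1907.13625.pdf")
        '1907.13625'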
89 | """ 90 | if url.endswith(".pdf"): 91 | url = url[:-4] 92 | 93 | return url.split("/")[-1] 94 | 95 | 96 | def clean_up_path(path): 97 | for f in glob.glob(os.path.join(path, '*')): 98 | if os.path.isdir(f): 99 | try: 100 | shutil.rmtree(f) 101 | except OSError as e: 102 | print(e) 103 | 104 | 105 | def _check_file_name(file_name: str): 106 | file_name = file_name.lower() 107 | file_suffix = file_name.split(".")[-1] 108 | if file_suffix in ['jpeg', 'jpg', 'png', 'gif', 'bmp', 'tiff', 'tif', 'pdf']: 109 | if 'arch' in file_name or 'pipeline' in file_name: 110 | # Remove architecture 111 | return False 112 | if 'CVPR_Template' in file_name: 113 | # Remove templates 114 | return False 115 | return True 116 | return False 117 | 118 | 119 | def main(args): 120 | assert args.split in ['train', 'val'] 121 | logging.basicConfig(level=logging.INFO) 122 | df = pd.read_csv('df_train.csv', dtype={'paper_id': str, 'img_name': str}) 123 | agg_df = df.groupby(['paper_id'])['img_name'].apply(list) 124 | if args.workers > 1: 125 | with Pool(args.workers) as p: 126 | print(p.map(download_figures, np.array_split(agg_df, args.workers))) 127 | else: 128 | download_figures(agg_df) 129 | 130 | 131 | def download_figures(agg_df): 132 | output_dir = os.path.join(args.output_dir, args.split) 133 | paper_ids = list(agg_df.index) 134 | num_chunks = len(paper_ids) // CHUNK_SIZE + 1 135 | for i in tqdm(range(num_chunks)): 136 | success_counter = 0 137 | skip_counter = 0 138 | fail_counter = 0 139 | for paper in big_slow_client.results(Search(id_list=paper_ids[CHUNK_SIZE * i:CHUNK_SIZE * (i + 1)])): 140 | paper_id = paper.get_short_id().split('v')[0] 141 | try: 142 | if args.skip_exists and os.path.exists(os.path.join(output_dir, paper_id)): 143 | skip_counter += 1 144 | continue 145 | download_paper_from_ids(paper, paper_id, output_dir, filter_files=agg_df[paper_id]) 146 | success_counter += 1 147 | except Exception as e: 148 | fail_counter += 1 149 | print(e) 150 | print( 151 | f"Downloaded successfully {success_counter}/{CHUNK_SIZE} papers. Skipped {skip_counter}/{CHUNK_SIZE}. Failed to download {fail_counter}/{CHUNK_SIZE}") 152 | 153 | 154 | if __name__ == "__main__": 155 | parser = argparse.ArgumentParser('Dataset downloader', add_help=False) 156 | parser.add_argument('--output_dir', default='dataset', type=str) 157 | parser.add_argument('--split', default='train', type=str) 158 | parser.add_argument('--skip_exists', default=True, type=str) 159 | parser.add_argument('--workers', default=1, type=int) 160 | args = parser.parse_args() 161 | 162 | main(args) 163 | -------------------------------------------------------------------------------- /figures_dataset/requirements.txt: -------------------------------------------------------------------------------- 1 | arxiv==1.4.2 2 | pdf2image==1.16.0 -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import numpy as np 14 | import os 15 | import time 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.backends.cudnn as cudnn 20 | from torch.utils.tensorboard import SummaryWriter 21 | import torchvision.transforms as transforms 22 | from torchvision import datasets 23 | import timm.optim.optim_factory as optim_factory 24 | import util.misc as misc 25 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 26 | import models_mae 27 | from engine_pretrain import train_one_epoch 28 | 29 | 30 | def get_args_parser(): 31 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 32 | parser.add_argument('--save_ckpt_freq', default=100, type=int) 33 | parser.add_argument('--batch_size', default=64, type=int, 34 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 35 | parser.add_argument('--epochs', default=400, type=int) 36 | parser.add_argument('--accum_iter', default=1, type=int, 37 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 38 | 39 | # Model parameters 40 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 41 | help='Name of model to train') 42 | 43 | parser.add_argument('--input_size', default=224, type=int, 44 | help='images input size') 45 | 46 | parser.add_argument('--mask_ratio', default=0.75, type=float, 47 | help='Masking ratio (percentage of removed patches).') 48 | 49 | # Optimizer parameters 50 | parser.add_argument('--weight_decay', type=float, default=0.05, 51 | help='weight decay (default: 0.05)') 52 | 53 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 54 | help='learning rate (absolute lr)') 55 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 56 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 57 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 58 | help='lower lr bound for cyclic schedulers that hit 0') 59 | 60 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 61 | help='epochs to warmup LR') 62 | 63 | parser.add_argument('--break_after_epoch', type=int, metavar='N', help='break training after X epochs, to tune hyperparams and avoid messing with training schedule') 64 | 65 | 66 | # Dataset parameters 67 | parser.add_argument('--data_path', default='/shared/yossi_gandelsman/arxiv/arxiv_data/', type=str, 68 | help='dataset path') 69 | parser.add_argument('--imagenet_percent', default=0.5, type=float) 70 | parser.add_argument('--subsample', action='store_true') 71 | parser.set_defaults(subsample=False) 72 | parser.add_argument('--output_dir', default='./output_dir', 73 | help='path where to save, empty for no saving') 74 | parser.add_argument('--log_dir', default='./output_dir', 75 | help='path where to tensorboard log') 76 | parser.add_argument('--device', default='cuda', 77 | help='device to use for training / testing') 78 | parser.add_argument('--seed', default=0, type=int) 79 | parser.add_argument('--resume', default='', 80 | help='resume from checkpoint') 81 | 82 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 83 | help='start epoch') 84 | 
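    # Note on --lr / --blr above: when --lr is not given, main() derives the
    # absolute learning rate with the linear scaling rule
    #   lr = blr * eff_batch_size / 256,
    #   eff_batch_size = batch_size * accum_iter * world_size.
    # e.g. with the defaults (blr=1e-3, batch_size=64, accum_iter=1) on 8 GPUs:
    # eff_batch_size = 512 and lr = 1e-3 * 512 / 256 = 2e-3.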
parser.add_argument('--num_workers', default=10, type=int) 85 | parser.add_argument('--pin_mem', action='store_true', 86 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 87 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 88 | parser.set_defaults(pin_mem=True) 89 | 90 | # distributed training parameters 91 | parser.add_argument('--world_size', default=1, type=int, 92 | help='number of distributed processes') 93 | parser.add_argument('--local_rank', default=-1, type=int) 94 | parser.add_argument('--dist_on_itp', action='store_true') 95 | parser.add_argument('--dist_url', default='env://', 96 | help='url used to set up distributed training') 97 | 98 | return parser 99 | 100 | 101 | def main(args): 102 | args.second_input_size = 224 103 | misc.init_distributed_mode(args) 104 | 105 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 106 | print("{}".format(args).replace(', ', ',\n')) 107 | 108 | device = torch.device(args.device) 109 | 110 | # fix the seed for reproducibility 111 | seed = args.seed + misc.get_rank() 112 | torch.manual_seed(seed) 113 | np.random.seed(seed) 114 | 115 | cudnn.benchmark = True 116 | 117 | # simple augmentation 118 | transforms_train = transforms.Compose([ 119 | transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 120 | transforms.RandomHorizontalFlip(), 121 | transforms.ToTensor(), 122 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 123 | 124 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transforms_train) 125 | print(dataset_train) 126 | 127 | if True: # args.distributed: 128 | num_tasks = misc.get_world_size() 129 | global_rank = misc.get_rank() 130 | sampler_train = torch.utils.data.DistributedSampler( 131 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 132 | ) 133 | print("Sampler_train = %s" % str(sampler_train)) 134 | else: 135 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 136 | 137 | log_writer = None 138 | 139 | data_loader_train = torch.utils.data.DataLoader( 140 | dataset_train, sampler=sampler_train, 141 | batch_size=args.batch_size, 142 | num_workers=args.num_workers, 143 | pin_memory=args.pin_mem, 144 | drop_last=True, 145 | ) 146 | 147 | # define the model 148 | model = models_mae.__dict__[args.model]() 149 | 150 | model.to(device) 151 | epoch_size = len(dataset_train) 152 | print(f'epoch_size is {epoch_size}') 153 | model_without_ddp = model 154 | print("Model = %s" % str(model_without_ddp)) 155 | 156 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 157 | 158 | if args.lr is None: # only base_lr is specified 159 | args.lr = args.blr * eff_batch_size / 256 160 | 161 | base_lr = (args.lr * 256 / eff_batch_size) 162 | print("base lr: %.2e" % base_lr) 163 | print("actual lr: %.2e" % args.lr) 164 | print("accumulate grad iterations: %d" % args.accum_iter) 165 | print("effective batch size: %d" % eff_batch_size) 166 | 167 | if args.distributed: 168 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 169 | model_without_ddp = model.module 170 | 171 | # following timm: set wd as 0 for bias and norm layers 172 | for k, v in model_without_ddp.named_parameters(): 173 | if 'vae' in k: 174 | v.requires_grad = False 175 | 176 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 177 | optimizer = 
torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 178 | print(optimizer) 179 | loss_scaler = NativeScaler() 180 | 181 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 182 | 183 | print(f"Start training for {args.epochs} epochs") 184 | start_time = time.time() 185 | for epoch in range(args.start_epoch, args.epochs): 186 | if args.distributed: 187 | data_loader_train.sampler.set_epoch(epoch) 188 | train_one_epoch( 189 | model, data_loader_train, 190 | optimizer, device, epoch, loss_scaler, 191 | log_writer=log_writer, 192 | args=args, 193 | epoch_size=epoch_size // eff_batch_size 194 | ) 195 | if args.output_dir and (epoch % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs): 196 | misc.save_model( 197 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 198 | loss_scaler=loss_scaler, epoch=epoch) 199 | 200 | total_time = time.time() - start_time 201 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 202 | print('Training time {}'.format(total_time_str)) 203 | if misc.is_main_process(): 204 | run.finish() 205 | 206 | if __name__ == '__main__': 207 | args = get_args_parser() 208 | args = args.parse_args() 209 | if args.output_dir: 210 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 211 | main(args) 212 | -------------------------------------------------------------------------------- /models_mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | import os.path 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from timm.models.vision_transformer import PatchEmbed, Block 18 | from util.pos_embed import get_2d_sincos_pos_embed 19 | from vqgan import get_vq_model 20 | 21 | 22 | class MaskedAutoencoderViT(nn.Module): 23 | """ Masked Autoencoder with VisionTransformer backbone 24 | """ 25 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 26 | embed_dim=1024, depth=24, num_heads=16, 27 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 28 | mlp_ratio=4., norm_layer=nn.LayerNorm): 29 | super().__init__() 30 | 31 | # -------------------------------------------------------------------------- 32 | # MAE encoder specifics 33 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 34 | num_patches = self.patch_embed.num_patches 35 | 36 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 37 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 38 | 39 | self.blocks = nn.ModuleList([ 40 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 41 | for i in range(depth)]) 42 | self.norm = norm_layer(embed_dim) 43 | self.vae = get_vq_model().eval() 44 | vocab_size = 1024 45 | 46 | # -------------------------------------------------------------------------- 47 | 48 | # -------------------------------------------------------------------------- 49 | # MAE decoder 
specifics 50 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 51 | 52 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 53 | 54 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 55 | 56 | self.decoder_blocks = nn.ModuleList([ 57 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 58 | for i in range(decoder_depth)]) 59 | 60 | self.decoder_norm = norm_layer(decoder_embed_dim) 61 | self.decoder_pred = nn.Linear(decoder_embed_dim, vocab_size, bias=True) # decoder to patch 62 | # -------------------------------------------------------------------------- 63 | self.initialize_weights() 64 | 65 | def initialize_weights(self): 66 | # initialization 67 | # initialize (and freeze) pos_embed by sin-cos embedding 68 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 69 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 70 | 71 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 72 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 73 | 74 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 75 | w = self.patch_embed.proj.weight.data 76 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 77 | 78 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 79 | torch.nn.init.normal_(self.cls_token, std=.02) 80 | torch.nn.init.normal_(self.mask_token, std=.02) 81 | 82 | # initialize nn.Linear and nn.LayerNorm 83 | self.apply(self._init_weights) 84 | 85 | def _init_weights(self, m): 86 | if isinstance(m, nn.Linear): 87 | # we use xavier_uniform following official JAX ViT: 88 | torch.nn.init.xavier_uniform_(m.weight) 89 | if isinstance(m, nn.Linear) and m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | elif isinstance(m, nn.LayerNorm): 92 | nn.init.constant_(m.bias, 0) 93 | nn.init.constant_(m.weight, 1.0) 94 | 95 | def patchify(self, imgs): 96 | """ 97 | imgs: (N, 3, H, W) 98 | x: (N, L, patch_size**2 *3) 99 | """ 100 | p = self.patch_embed.patch_size[0] 101 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 102 | 103 | h = w = imgs.shape[2] // p 104 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 105 | x = torch.einsum('nchpwq->nhwpqc', x) 106 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 107 | return x 108 | 109 | def unpatchify(self, x): 110 | """ 111 | x: (N, L, patch_size**2 *3) 112 | imgs: (N, 3, H, W) 113 | """ 114 | p = self.patch_embed.patch_size[0] 115 | h = w = int(x.shape[1]**.5) 116 | assert h * w == x.shape[1] 117 | 118 | x = x.reshape(shape=(x.shape[0], h, w, p, p, -1)) 119 | x = torch.einsum('nhwpqc->nchpwq', x) 120 | imgs = x.reshape(shape=(x.shape[0], -1, h * p, h * p)) 121 | return imgs 122 | 123 | def random_masking(self, x, mask_ratio): 124 | """ 125 | Perform per-sample random masking by per-sample shuffling. 126 | Per-sample shuffling is done by argsort random noise. 
127 | x: [N, L, D], sequence 128 | """ 129 | N, L, D = x.shape # batch, length, dim 130 | len_keep = int(L * (1 - mask_ratio)) 131 | 132 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 133 | 134 | # sort noise for each sample 135 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 136 | ids_restore = torch.argsort(ids_shuffle, dim=1) 137 | 138 | # keep the first subset 139 | ids_keep = ids_shuffle[:, :len_keep] 140 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 141 | 142 | # generate the binary mask: 0 is keep, 1 is remove 143 | mask = torch.ones([N, L], device=x.device) 144 | mask[:, :len_keep] = 0 145 | # unshuffle to get the binary mask 146 | mask = torch.gather(mask, dim=1, index=ids_restore) 147 | 148 | return x_masked, mask, ids_restore 149 | 150 | def forward_encoder(self, x, mask_ratio): 151 | # embed patches 152 | x = self.patch_embed(x) 153 | 154 | # add pos embed w/o cls token 155 | x = x + self.pos_embed[:, 1:, :] 156 | 157 | # masking: length -> length * mask_ratio 158 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 159 | 160 | # append cls token 161 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 162 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 163 | x = torch.cat((cls_tokens, x), dim=1) 164 | 165 | # apply Transformer blocks 166 | for blk in self.blocks: 167 | x = blk(x) 168 | x = self.norm(x) 169 | 170 | return x, mask, ids_restore 171 | 172 | def forward_decoder(self, x, ids_restore): 173 | # embed tokens 174 | x = self.decoder_embed(x) 175 | 176 | # append mask tokens to sequence 177 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 178 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 179 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 180 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 181 | 182 | # add pos embed 183 | x = x + self.decoder_pos_embed 184 | 185 | # apply Transformer blocks 186 | for blk in self.decoder_blocks: 187 | x = blk(x) 188 | x = self.decoder_norm(x) 189 | 190 | # predictor projection 191 | x = self.decoder_pred(x) 192 | 193 | # remove cls token 194 | x = x[:, 1:, :] 195 | 196 | return x 197 | 198 | def forward_loss(self, imgs, pred, mask): 199 | """ 200 | imgs: [N, 3, H, W] 201 | pred: [N, L, p*p*3] 202 | mask: [N, L], 0 is keep, 1 is remove, 203 | """ 204 | with torch.no_grad(): 205 | target = self.vae.get_codebook_indices(imgs).flatten(1) 206 | loss = nn.CrossEntropyLoss(reduction='none')(input=pred.permute(0, 2, 1), target=target) 207 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 208 | return loss 209 | 210 | def forward(self, imgs, visual_tokens=None, mask_ratio=0.75, inpt_mask=None): 211 | loss = {} 212 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 213 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 214 | if visual_tokens is not None: 215 | loss['mae'] = self.forward_loss(visual_tokens, pred, mask) 216 | return loss, pred, mask 217 | 218 | 219 | def mae_vit_small_patch16(**kwargs): 220 | model = MaskedAutoencoderViT( 221 | patch_size=16, embed_dim=384, depth=12, num_heads=12, 222 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 223 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 224 | return model 225 | 226 | 227 | def mae_vit_base_patch16_dec512d8b(**kwargs): 228 | model = MaskedAutoencoderViT( 229 | 
patch_size=16, embed_dim=768, depth=12, num_heads=12, 230 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 231 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 232 | return model 233 | 234 | 235 | def mae_vit_large_patch16_dec512d8b(**kwargs): 236 | model = MaskedAutoencoderViT( 237 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 238 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 239 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 240 | return model 241 | 242 | 243 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 244 | model = MaskedAutoencoderViT( 245 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 246 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 247 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 248 | return model 249 | 250 | 251 | # set recommended archs 252 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 253 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 254 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 255 | -------------------------------------------------------------------------------- /models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | import timm.models.vision_transformer 18 | 19 | 20 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 21 | """ Vision Transformer with support for global average pooling 22 | """ 23 | def __init__(self, global_pool=False, **kwargs): 24 | super(VisionTransformer, self).__init__(**kwargs) 25 | 26 | self.global_pool = global_pool 27 | if self.global_pool: 28 | norm_layer = kwargs['norm_layer'] 29 | embed_dim = kwargs['embed_dim'] 30 | self.fc_norm = norm_layer(embed_dim) 31 | 32 | del self.norm # remove the original norm 33 | 34 | def forward_features(self, x): 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | x = torch.cat((cls_tokens, x), dim=1) 40 | x = x + self.pos_embed 41 | x = self.pos_drop(x) 42 | 43 | for blk in self.blocks: 44 | x = blk(x) 45 | 46 | if self.global_pool: 47 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 48 | outcome = self.fc_norm(x) 49 | else: 50 | x = self.norm(x) 51 | outcome = x[:, 0] 52 | 53 | return outcome 54 | 55 | 56 | def vit_small_patch16(**kwargs): 57 | model = VisionTransformer( 58 | patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 59 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 60 | return model 61 | 62 | 63 | def vit_base_patch16(**kwargs): 64 | model = VisionTransformer( 65 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 66 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 67 | return model 68 | 69 | 70 | def 
vit_large_patch16(**kwargs): 71 | model = VisionTransformer( 72 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 73 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 74 | return model 75 | 76 | 77 | def vit_huge_patch14(**kwargs): 78 | model = VisionTransformer( 79 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 80 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 81 | return model -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.3.2 2 | omegaconf==2.1.2 3 | Pillow 4 | matplotlib 5 | einops 6 | scipy 7 | opencv-python -------------------------------------------------------------------------------- /tta.py: -------------------------------------------------------------------------------- 1 | from evaluate.reasoning_dataloader import background_transforms 2 | from evaluate.mae_utils import * 3 | import torch 4 | from PIL import Image 5 | import numpy as np 6 | 7 | DEVICE = 'cuda' 8 | h, w = 224, 224 9 | 10 | 11 | class RowColShuffle(torch.nn.Module): 12 | def __init__(self, shuffle_rows=False, shuffle_cols=False, num_rows=3): 13 | super(RowColShuffle, self).__init__() 14 | self.shuffle_rows = shuffle_rows 15 | self.shuffle_cols = shuffle_cols 16 | self.num_rows = num_rows 17 | 18 | def forward(self, pairs): 19 | background_image = Image.new('RGB', (224, 224), color='black') 20 | canvas = background_transforms(background_image) 21 | 22 | v_order = np.arange(0, self.num_rows) 23 | if self.shuffle_rows: 24 | np.random.shuffle(v_order) 25 | 26 | shuffle_cols = False 27 | if self.shuffle_cols: 28 | shuffle_cols = np.random.choice([True, False]) 29 | 30 | padding = 1 31 | figure_size = 74 32 | for i in range(len(pairs)): 33 | img, label = pairs[v_order[i]] 34 | start_row = i * (figure_size + padding) 35 | if shuffle_cols: 36 | img, label = label, img 37 | canvas[:, start_row:start_row + figure_size, 224 // 2 - figure_size:224 // 2] = img 38 | canvas[:, start_row:start_row + figure_size, 224 // 2 + 1: 224 // 2 + 1 + figure_size] = label 39 | 40 | pred_row_idx = np.where(v_order == 2)[0][0] 41 | pred_col_idx = 1 if not shuffle_cols else 0 42 | canvas = np.array(canvas) 43 | 44 | # keep all but occluded part 45 | mask_psuedo_gt = np.ones((14, 14)) 46 | row_mask_start = int(np.floor(14 * float(pred_row_idx) / 3)) 47 | row_mask_end = int(np.ceil(14 * float(pred_row_idx + 1) / 3)) + 1 48 | mask_psuedo_gt[row_mask_start:row_mask_end, 2 + pred_col_idx * 5:2 + pred_col_idx * 5 + 5] = 0 49 | 50 | # keep everything in 20% except for the occluded part 51 | mask = np.round(np.random.uniform(0, 1, (14, 14)) >= 0.5) 52 | mask[row_mask_start:row_mask_end, 2 + pred_col_idx * 5:2 + pred_col_idx * 5 + 5] = 0 53 | 54 | _mask = obtain_values_from_mask(mask) 55 | _mask_psuedo_gt = obtain_values_from_mask(mask_psuedo_gt) 56 | 57 | return canvas, len(_mask), fill_to_full(_mask), mask, len(_mask_psuedo_gt), fill_to_full( 58 | _mask_psuedo_gt), mask_psuedo_gt 59 | 60 | def shuffle_cols(self, canvas, num_cols, h_order, fig_size, border_size): 61 | new_canvas = np.zeros_like(canvas) 62 | for i in range(num_cols): 63 | col_start = 224 - fig_size + i * (fig_size + border_size) 64 | original_col_start = 224 - fig_size + h_order[i] * (fig_size + border_size) 65 | new_canvas[:, :, col_start:col_start + fig_size] = canvas[:, :, 66 | original_col_start: original_col_start + fig_size] 67 | return new_canvas 
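    # Note: despite the `start_col` / `old_start_col` names in shuffle_rows below,
    # the canvas is laid out as (C, H, W), so indexing dim 1 shifts horizontal
    # bands of sub-figures (rows); shuffle_cols above remaps the vertical bands
    # along dim 2 instead.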
68 | 69 | def shuffle_rows(self, canvas, num_rows, v_order, fig_size, border_size): 70 | new_canvas = np.zeros_like(canvas) 71 | 72 | for i in range(num_rows): 73 | start_col = i * (fig_size + border_size) 74 | old_start_col = v_order[i] * (fig_size + border_size) 75 | new_canvas[:, start_col:start_col + fig_size] = canvas[:, old_start_col: old_start_col + fig_size] 76 | 77 | return new_canvas 78 | 79 | 80 | def reverse_trans(im_paste, v_order, shuffle_cols, transpose): 81 | background_image = Image.new('RGB', (224, 224), color='black') 82 | new_canvas = np.array(background_image) 83 | 84 | if transpose: 85 | im_paste = np.transpose(im_paste, [1, 0, 2]) 86 | 87 | padding = 1 88 | figure_size = 74 89 | for i in range(len(v_order)): 90 | start_row = i * (figure_size + padding) 91 | img = im_paste[start_row: start_row + figure_size, 224 // 2 - figure_size - 1: 224 // 2 - 1] 92 | label = im_paste[start_row:start_row + figure_size, 224 // 2 + 1: 224 // 2 + 1 + figure_size] 93 | 94 | if shuffle_cols: 95 | img, label = label, img 96 | 97 | start_row = v_order[i] * (figure_size + padding) 98 | new_canvas[start_row:start_row + figure_size, 224 // 2 - figure_size - 1:224 // 2 - 1] = img 99 | new_canvas[start_row:start_row + figure_size, 224 // 2 + 1: 224 // 2 + 1 + figure_size] = label 100 | return new_canvas 101 | 102 | 103 | class TTA(torch.nn.Module): 104 | def __init__(self, shuffle_rows=False, shuffle_cols=False, transpose=False, num_rows=3): 105 | super(TTA, self).__init__() 106 | self.shuffle_rows = shuffle_rows 107 | self.shuffle_cols = shuffle_cols 108 | self.transpose = transpose 109 | self.num_rows = num_rows 110 | 111 | def forward(self, pairs): 112 | background_image = Image.new('RGB', (224, 224), color='black') 113 | canvas = background_transforms(background_image) 114 | 115 | v_order = np.arange(0, self.num_rows) 116 | if self.shuffle_rows: 117 | v_order = [2, 0, 1] 118 | 119 | shuffle_cols = False 120 | if self.shuffle_cols: 121 | shuffle_cols = True 122 | 123 | padding = 1 124 | figure_size = 74 125 | for i in range(len(pairs)): 126 | img, label = pairs[v_order[i]] 127 | start_row = i * (figure_size + padding) 128 | if shuffle_cols: 129 | img, label = label, img 130 | canvas[:, start_row:start_row + figure_size, 224 // 2 - figure_size - 1:224 // 2 - 1] = img 131 | canvas[:, start_row:start_row + figure_size, 224 // 2 + 1: 224 // 2 + 1 + figure_size] = label 132 | 133 | pred_row_idx = np.where(v_order == 2)[0][0] 134 | pred_col_idx = 1 if not shuffle_cols else 0 135 | canvas = np.array(canvas) 136 | 137 | # keep all but occluded part 138 | mask_psuedo_gt = np.ones((14, 14)) 139 | row_mask_start = int(np.floor(14 * float(pred_row_idx) / 3)) 140 | row_mask_end = int(np.ceil(14 * float(pred_row_idx + 1) / 3)) + 1 141 | mask_psuedo_gt[row_mask_start:row_mask_end, 2 + pred_col_idx * 5:2 + pred_col_idx * 5 + 5] = 0 142 | 143 | transpose_img = False 144 | if self.transpose: 145 | transpose_img = True 146 | 147 | if transpose_img: 148 | mask_psuedo_gt = np.transpose(mask_psuedo_gt, [1, 0]) 149 | canvas = np.transpose(canvas, [0, 2, 1]) 150 | 151 | _mask_psuedo_gt = obtain_values_from_mask(mask_psuedo_gt) 152 | return canvas, len(_mask_psuedo_gt), fill_to_full( 153 | _mask_psuedo_gt), mask_psuedo_gt, v_order, shuffle_cols, transpose_img 154 | 155 | def shuffle_cols(self, canvas, num_cols, h_order, fig_size, border_size): 156 | new_canvas = np.zeros_like(canvas) 157 | for i in range(num_cols): 158 | col_start = 224 - fig_size + i * (fig_size + border_size) 159 | original_col_start = 224 - 
fig_size + h_order[i] * (fig_size + border_size) 160 | new_canvas[:, :, col_start:col_start + fig_size] = canvas[:, :, 161 | original_col_start: original_col_start + fig_size] 162 | return new_canvas 163 | 164 | def shuffle_rows(self, canvas, num_rows, v_order, fig_size, border_size): 165 | new_canvas = np.zeros_like(canvas) 166 | 167 | for i in range(num_rows): 168 | start_col = i * (fig_size + border_size) 169 | old_start_col = v_order[i] * (fig_size + border_size) 170 | new_canvas[:, start_col:start_col + fig_size] = canvas[:, old_start_col: old_start_col + fig_size] 171 | 172 | return new_canvas 173 | 174 | 175 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 
45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | 
header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | def setup_for_distributed(is_master): 171 | """ 172 | This function disables printing when not in master process 173 | """ 174 | builtin_print = builtins.print 175 | 176 | def print(*args, **kwargs): 177 | force = kwargs.pop('force', False) 178 | force = force or (get_world_size() > 8) 179 | if is_master or force: 180 | now = datetime.datetime.now().time() 181 | builtin_print('[{}] '.format(now), end='') # print with time stamp 182 | builtin_print(*args, **kwargs) 183 | 184 | builtins.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if args.dist_on_itp: 218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 222 | os.environ['LOCAL_RANK'] = str(args.gpu) 223 | os.environ['RANK'] = str(args.rank) 224 | os.environ['WORLD_SIZE'] = str(args.world_size) 225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 227 | args.rank = int(os.environ["RANK"]) 228 | args.world_size = int(os.environ['WORLD_SIZE']) 229 | args.gpu = int(os.environ['LOCAL_RANK']) 230 | elif 'SLURM_PROCID' in os.environ: 231 | args.rank = int(os.environ['SLURM_PROCID']) 232 | args.gpu = args.rank % torch.cuda.device_count() 233 | else: 234 | print('Not using distributed mode') 235 | setup_for_distributed(is_master=True) # hack 236 | args.distributed = False 237 | return 238 | 239 | args.distributed = True 240 | 241 | torch.cuda.set_device(args.gpu) 242 | args.dist_backend = 'nccl' 243 | print('| distributed init (rank {}): {}, gpu {}'.format( 244 | args.rank, args.dist_url, args.gpu), flush=True) 245 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 246 | world_size=args.world_size, rank=args.rank) 247 | torch.distributed.barrier() 248 | setup_for_distributed(args.rank == 0) 249 | 250 | 251 | class NativeScalerWithGradNormCount: 252 | state_dict_key = "amp_scaler" 253 | 254 | def __init__(self): 255 | self._scaler = torch.cuda.amp.GradScaler() 256 | 257 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 258 | self._scaler.scale(loss).backward(create_graph=create_graph) 259 | if update_grad: 260 | if clip_grad is not None: 261 | assert parameters is not None 262 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 263 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 264 | else: 265 | self._scaler.unscale_(optimizer) 266 | norm = get_grad_norm_(parameters) 267 | self._scaler.step(optimizer) 268 | self._scaler.update() 269 | else: 270 | norm = None 271 | return 
norm 272 | 273 | def state_dict(self): 274 | return self._scaler.state_dict() 275 | 276 | def load_state_dict(self, state_dict): 277 | self._scaler.load_state_dict(state_dict) 278 | 279 | 280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 281 | if isinstance(parameters, torch.Tensor): 282 | parameters = [parameters] 283 | parameters = [p for p in parameters if p.grad is not None] 284 | norm_type = float(norm_type) 285 | if len(parameters) == 0: 286 | return torch.tensor(0.) 287 | device = parameters[0].grad.device 288 | if norm_type == inf: 289 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 290 | else: 291 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 292 | return total_norm 293 | 294 | 295 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 296 | output_dir = Path(args.output_dir) 297 | epoch_name = str(epoch) 298 | if loss_scaler is not None: 299 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 300 | for checkpoint_path in checkpoint_paths: 301 | to_save = { 302 | 'model': model_without_ddp.state_dict(), 303 | 'optimizer': optimizer.state_dict(), 304 | 'epoch': epoch, 305 | 'scaler': loss_scaler.state_dict(), 306 | 'args': args, 307 | } 308 | 309 | save_on_master(to_save, checkpoint_path) 310 | else: 311 | client_state = {'epoch': epoch} 312 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 313 | 314 | 315 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 316 | if args.resume: 317 | if args.resume.startswith('https'): 318 | checkpoint = torch.hub.load_state_dict_from_url( 319 | args.resume, map_location='cpu', check_hash=True) 320 | else: 321 | checkpoint = torch.load(args.resume, map_location='cpu') 322 | model_without_ddp.load_state_dict(checkpoint['model']) 323 | print("Resume checkpoint %s" % args.resume) 324 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 325 | optimizer.load_state_dict(checkpoint['optimizer']) 326 | args.start_epoch = checkpoint['epoch'] + 1 327 | if 'scaler' in checkpoint: 328 | loss_scaler.load_state_dict(checkpoint['scaler']) 329 | print("With optim & sched!") 330 | 331 | 332 | def all_reduce_mean(x): 333 | world_size = get_world_size() 334 | if world_size > 1: 335 | x_reduce = torch.tensor(x).cuda() 336 | dist.all_reduce(x_reduce) 337 | x_reduce /= world_size 338 | return x_reduce.item() 339 | else: 340 | return x -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /viz_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import torch.nn.functional as F 3 | matplotlib.use('Agg') 4 | import torch 5 | from PIL import Image 6 | from torchvision.transforms import transforms, ToPILImage 7 | from glob import glob 8 | from matplotlib import pyplot as plt 9 | import numpy as np 10 | 11 | demo_images = glob("./imgs/*") 12 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 13 | imagenet_std = np.array([0.229, 0.224, 0.225]) 14 | 15 | t = transforms.Compose([ 16 | transforms.Resize(256, interpolation=3), 17 | transforms.CenterCrop(224), 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 20 | 21 | 22 | @torch.no_grad() 23 | def get_demo_predictions(args, device, model): 24 | figs = get_demo_predictions_with_mask(args, model, t) 25 | return {"image_%s" % i: fig for i, fig in enumerate(figs)} 26 | 27 | 28 | def show_image(image, ax, in_reverse=True): 29 | # image is [H, W, 3] 30 | assert image.shape[2] == 3 31 | ax.imshow(image, vmin=0, vmax=255) 32 | ax.axis('off') 33 | return 34 | 35 | 36 | @torch.no_grad() 37 | def get_demo_predictions_with_mask(args, model, t): 38 | num_patches = 14 39 | imgs = [] 40 | for p in glob("./imgs/*"): 41 | with open(p, 'rb') as f: 42 | png = Image.open(f).convert('RGBA') 43 | background = Image.new('RGBA', png.size, (255, 255, 255)) 44 | img = Image.alpha_composite(background, png).convert('RGB').resize((args.input_size, args.input_size), 45 | resample=Image.LANCZOS) 
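        # Apply the eval transform defined at module level: resize to 256,
        # center-crop to 224, convert to a tensor, and normalize with the
        # ImageNet mean/std so the demo images match the pretraining inputs.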
46 | img = t(img) 47 | imgs.append(img) 48 | imgs = torch.stack(imgs, dim=0) 49 | x = imgs.cuda(non_blocking=True) 50 | _, y, mask = model(x.float(), mask_ratio=0.75) 51 | y = y.argmax(dim=-1) 52 | y = model.vae.quantize.get_codebook_entry(y.reshape(-1), [y.shape[0], y.shape[-1] // num_patches, y.shape[-1] // num_patches, -1]) 53 | y = model.vae.decode(y).detach().cpu() 54 | y = F.interpolate(y, size=(224, 224), mode='bilinear').permute(0, 2, 3, 1) 55 | y = torch.clip(y * 255, 0, 255).int() 56 | 57 | mask = mask.detach() 58 | mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0] ** 2 * 3) # (N, H*W, p*p*3) 59 | mask = model.unpatchify(mask) # 1 is removing, 0 is keeping 60 | mask = torch.einsum('nchw->nhwc', mask).detach().cpu() 61 | 62 | x = torch.einsum('nchw->nhwc', x).to(mask) 63 | 64 | # masked image 65 | im_masked = x * (1 - mask) 66 | im_masked = torch.clip((im_masked * imagenet_std + imagenet_mean) * 255, 0, 255).int() 67 | # MAE reconstruction pasted with visible patches 68 | x = torch.clip((x * imagenet_std + imagenet_mean) * 255, 0, 255).int() 69 | im_paste = (x * (1 - mask) + y * mask).int() 70 | 71 | # make the plt figure larger 72 | # plt.figure() 73 | figs = [] 74 | for k in range(0, len(imgs), 4): 75 | fig, ax = plt.subplots(4, 4, figsize=(10, 10)) 76 | plt.subplots_adjust(wspace=0, hspace=0) 77 | for i in range(len(imgs[k:k + 4])): 78 | show_image(x[k + i], ax[i, 0]) 79 | show_image(im_masked[k+i], ax[i, 1]) 80 | show_image(y[k + i], ax[i, 2], in_reverse=False) 81 | show_image(im_paste[k+i], ax[i, 3]) 82 | 83 | for j in range(4): 84 | ax[i, j].set_xticklabels([]) 85 | ax[i, j].set_yticklabels([]) 86 | ax[i, j].set_aspect('equal') 87 | figs.append(fig) 88 | 89 | # plt.show() 90 | 91 | return figs 92 | --------------------------------------------------------------------------------
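A minimal usage sketch for the model code above: the factory name, `forward` signature, and returned shapes come from `models_mae.py`, while the dummy input and the assumption that the VQGAN weights required by `vqgan.get_vq_model()` are available locally are illustrative only.

```
import torch
import models_mae

# Build the ViT-L/16 MAE variant registered in models_mae.py. Constructing the
# model also loads the frozen VQGAN used to produce the discrete target tokens.
model = models_mae.mae_vit_large_patch16().eval()

# Dummy batch of two 224x224 RGB images (already normalized upstream).
imgs = torch.randn(2, 3, 224, 224)

with torch.no_grad():
    # forward() returns (loss_dict, pred, mask); the loss dict is only populated
    # when `visual_tokens` is passed, so it comes back empty here.
    loss, pred, mask = model(imgs, mask_ratio=0.75)

print(pred.shape)  # (2, 196, 1024): per-patch logits over the VQGAN codebook
print(mask.shape)  # (2, 196): 1 marks masked patches, 0 marks kept ones
```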