├── .gitignore ├── LICENSE ├── README.md ├── data ├── __init__.py ├── datasets.py ├── samplers.py └── threeaugment.py ├── detection ├── .gitignore ├── README.md ├── checkpoint.py ├── configs │ ├── _base_ │ │ ├── datasets │ │ │ ├── cityscapes_detection.py │ │ │ ├── cityscapes_instance.py │ │ │ ├── coco_detection.py │ │ │ ├── coco_instance.py │ │ │ ├── coco_instance_semantic.py │ │ │ ├── deepfashion.py │ │ │ ├── lvis_v0.5_instance.py │ │ │ ├── lvis_v1_instance.py │ │ │ ├── voc0712.py │ │ │ └── wider_face.py │ │ ├── default_runtime.py │ │ ├── models │ │ │ ├── cascade_mask_rcnn_pvtv2_b2_fpn.py │ │ │ ├── cascade_mask_rcnn_r50_fpn.py │ │ │ ├── cascade_rcnn_r50_fpn.py │ │ │ ├── fast_rcnn_r50_fpn.py │ │ │ ├── faster_rcnn_r50_caffe_c4.py │ │ │ ├── faster_rcnn_r50_caffe_dc5.py │ │ │ ├── faster_rcnn_r50_fpn.py │ │ │ ├── mask_rcnn_r50_caffe_c4.py │ │ │ ├── mask_rcnn_r50_fpn.py │ │ │ ├── retinanet_r50_fpn.py │ │ │ ├── rpn_r50_caffe_c4.py │ │ │ ├── rpn_r50_fpn.py │ │ │ └── ssd300.py │ │ └── schedules │ │ │ ├── schedule_1x.py │ │ │ ├── schedule_20e.py │ │ │ └── schedule_2x.py │ ├── mask_rcnn_repvit_m1_1_fpn_1x_coco.py │ ├── mask_rcnn_repvit_m1_5_fpn_1x_coco.py │ └── mask_rcnn_repvit_m2_3_fpn_1x_coco.py ├── dist_test.sh ├── dist_train.sh ├── eval.sh ├── logs │ ├── repvit_m1_1_coco.json │ ├── repvit_m1_5_coco.json │ └── repvit_m2_3_coco.json ├── mmcv_custom │ └── runner │ │ ├── checkpoint.py │ │ ├── epoch_based_runner.py │ │ └── optimizer.py ├── mmdet_custom │ └── apis │ │ └── train.py ├── repvit.py ├── slurm_train.sh ├── test.py ├── train.py └── train.sh ├── engine.py ├── eval.sh ├── export_coreml.py ├── figures ├── latency.png └── repvit_m0_9_latency.png ├── flops.py ├── logs ├── repvit_m0_9_distill_300e.txt ├── repvit_m0_9_distill_450e.txt ├── repvit_m1_0_distill_300e.txt ├── repvit_m1_0_distill_450e.txt ├── repvit_m1_1_distill_300e.txt ├── repvit_m1_1_distill_450e.txt ├── repvit_m1_5_distill_300e.txt ├── repvit_m1_5_distill_450e.txt ├── repvit_m2_3_distill_300e.txt └── repvit_m2_3_distill_450e.txt ├── losses.py ├── main.py ├── model ├── __init__.py └── repvit.py ├── requirements.txt ├── sam ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── app │ ├── .gitattributes │ ├── app.py │ ├── assets │ │ ├── .DS_Store │ │ ├── picture1.jpg │ │ ├── picture2.jpg │ │ ├── picture3.jpg │ │ ├── picture4.jpg │ │ ├── picture5.jpg │ │ └── picture6.jpg │ ├── requirements.txt │ └── utils │ │ ├── __init__.py │ │ ├── tools.py │ │ └── tools_gradio.py ├── assets │ ├── logo2.png │ ├── mask_box.jpg │ ├── mask_comparision.jpg │ ├── mask_point.jpg │ ├── model_diagram.jpg │ ├── notebook1.png │ └── notebook2.png ├── figures │ └── comparison.png ├── linter.sh ├── notebooks │ ├── automatic_mask_generator_example.ipynb │ ├── coreml_example.ipynb │ ├── images │ │ ├── picture1.jpg │ │ └── picture2.jpg │ └── predictor_example.ipynb ├── repvit_sam │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ │ ├── __init__.py │ │ ├── common.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ ├── repvit.py │ │ ├── sam.py │ │ ├── tiny_vit_sam.py │ │ └── transformer.py │ ├── predictor.py │ └── utils │ │ ├── __init__.py │ │ ├── amg.py │ │ ├── coreml.py │ │ ├── onnx.py │ │ └── transforms.py ├── scripts │ ├── amg.py │ ├── export_coreml_decoder.py │ ├── export_coreml_encoder.py │ └── export_onnx_model.py ├── setup.cfg └── setup.py ├── segmentation ├── .gitignore ├── README.md ├── align_resize.py ├── configs │ ├── _base_ │ │ ├── datasets │ │ │ └── ade20k.py 
│ │ ├── default_runtime.py │ │ ├── models │ │ │ └── fpn_r50.py │ │ └── schedules │ │ │ ├── schedule_160k.py │ │ │ ├── schedule_20k.py │ │ │ ├── schedule_40k.py │ │ │ └── schedule_80k.py │ └── sem_fpn │ │ ├── fpn_repvit_m1_1_ade20k_40k.py │ │ ├── fpn_repvit_m1_5_ade20k_40k.py │ │ └── fpn_repvit_m2_3_ade20k_40k.py ├── eval.sh ├── logs │ ├── repvit_m1_1_ade20k.json │ ├── repvit_m1_5_ade20k.json │ └── repvit_m2_3_ade20k.json ├── repvit.py ├── tools │ ├── analyze_logs.py │ ├── benchmark.py │ ├── browse_dataset.py │ ├── convert_datasets │ │ ├── chase_db1.py │ │ ├── cityscapes.py │ │ ├── coco_stuff10k.py │ │ ├── coco_stuff164k.py │ │ ├── drive.py │ │ ├── hrf.py │ │ ├── pascal_context.py │ │ ├── stare.py │ │ └── voc_aug.py │ ├── deploy_test.py │ ├── dist_test.sh │ ├── dist_train.sh │ ├── get_flops.py │ ├── model_converters │ │ ├── mit2mmseg.py │ │ ├── swin2mmseg.py │ │ └── vit2mmseg.py │ ├── onnx2tensorrt.py │ ├── print_config.py │ ├── publish_model.py │ ├── pytorch2onnx.py │ ├── pytorch2torchscript.py │ ├── slurm_test.sh │ ├── slurm_train.sh │ ├── test.py │ ├── torchserve │ │ ├── mmseg2torchserve.py │ │ ├── mmseg_handler.py │ │ └── test_torchserve.py │ └── train.py └── train.sh ├── speed_gpu.py ├── train.sh └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | coreml 3 | pretrain 4 | **/__pycache__ 5 | pretrain 6 | ignore 7 | *.zip 8 | checkpoints 9 | trt -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/data/__init__.py -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Build trainining/testing datasets 3 | ''' 4 | import os 5 | import json 6 | 7 | from torchvision import datasets, transforms 8 | from torchvision.datasets.folder import ImageFolder, default_loader 9 | import torch 10 | 11 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from timm.data import create_transform 13 | 14 | try: 15 | from timm.data import TimmDatasetTar 16 | except ImportError: 17 | # for higher version of timm 18 | from timm.data import ImageDataset as TimmDatasetTar 19 | 20 | class INatDataset(ImageFolder): 21 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 22 | category='name', loader=default_loader): 23 | self.transform = transform 24 | self.loader = loader 25 | self.target_transform = target_transform 26 | self.year = year 27 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 28 | path_json = os.path.join( 29 | root, f'{"train" if train else "val"}{year}.json') 30 | with open(path_json) as json_file: 31 | data = json.load(json_file) 32 | 33 | with open(os.path.join(root, 'categories.json')) as json_file: 34 | data_catg = json.load(json_file) 35 | 36 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 37 | 38 | with open(path_json_for_targeter) as json_file: 39 | data_for_targeter = json.load(json_file) 40 | 41 | targeter = {} 42 | indexer = 0 43 | for elem in data_for_targeter['annotations']: 44 | king = [] 45 | king.append(data_catg[int(elem['category_id'])][category]) 46 | if king[0] not in targeter.keys(): 47 | targeter[king[0]] = indexer 48 | indexer 
+= 1 49 | self.nb_classes = len(targeter) 50 | 51 | self.samples = [] 52 | for elem in data['images']: 53 | cut = elem['file_name'].split('/') 54 | target_current = int(cut[2]) 55 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 56 | 57 | categors = data_catg[target_current] 58 | target_current_true = targeter[categors[category]] 59 | self.samples.append((path_current, target_current_true)) 60 | 61 | # __getitem__ and __len__ inherited from ImageFolder 62 | 63 | 64 | def build_dataset(is_train, args): 65 | transform = build_transform(is_train, args) 66 | 67 | if args.data_set == 'CIFAR': 68 | dataset = datasets.CIFAR100( 69 | args.data_path, train=is_train, transform=transform) 70 | nb_classes = 100 71 | elif args.data_set == 'IMNET': 72 | prefix = 'train' if is_train else 'val' 73 | data_dir = os.path.join(args.data_path, f'{prefix}.tar') 74 | if os.path.exists(data_dir): 75 | dataset = TimmDatasetTar(data_dir, transform=transform) 76 | else: 77 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 78 | dataset = datasets.ImageFolder(root, transform=transform) 79 | nb_classes = 1000 80 | elif args.data_set == 'IMNETEE': 81 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 82 | dataset = datasets.ImageFolder(root, transform=transform) 83 | nb_classes = 10 84 | elif args.data_set == 'FLOWERS': 85 | root = os.path.join(args.data_path, 'train' if is_train else 'test') 86 | dataset = datasets.ImageFolder(root, transform=transform) 87 | if is_train: 88 | dataset = torch.utils.data.ConcatDataset( 89 | [dataset for _ in range(100)]) 90 | nb_classes = 102 91 | elif args.data_set == 'INAT': 92 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 93 | category=args.inat_category, transform=transform) 94 | nb_classes = dataset.nb_classes 95 | elif args.data_set == 'INAT19': 96 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 97 | category=args.inat_category, transform=transform) 98 | nb_classes = dataset.nb_classes 99 | return dataset, nb_classes 100 | 101 | 102 | def build_transform(is_train, args): 103 | resize_im = args.input_size > 32 104 | if is_train: 105 | # this should always dispatch to transforms_imagenet_train 106 | transform = create_transform( 107 | input_size=args.input_size, 108 | is_training=True, 109 | color_jitter=args.color_jitter, 110 | auto_augment=args.aa, 111 | interpolation=args.train_interpolation, 112 | re_prob=args.reprob, 113 | re_mode=args.remode, 114 | re_count=args.recount, 115 | ) 116 | if not resize_im: 117 | # replace RandomResizedCropAndInterpolation with 118 | # RandomCrop 119 | transform.transforms[0] = transforms.RandomCrop( 120 | args.input_size, padding=4) 121 | return transform 122 | 123 | t = [] 124 | if args.finetune: 125 | t.append( 126 | transforms.Resize((args.input_size, args.input_size), 127 | interpolation=3) 128 | ) 129 | else: 130 | if resize_im: 131 | size = int((256 / 224) * args.input_size) 132 | t.append( 133 | # to maintain same ratio w.r.t. 
224 images 134 | transforms.Resize(size, interpolation=3), 135 | ) 136 | t.append(transforms.CenterCrop(args.input_size)) 137 | 138 | t.append(transforms.ToTensor()) 139 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 140 | return transforms.Compose(t) 141 | -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Build samplers for data loading 3 | ''' 4 | import torch 5 | import torch.distributed as dist 6 | import math 7 | 8 | 9 | class RASampler(torch.utils.data.Sampler): 10 | """Sampler that restricts data loading to a subset of the dataset for distributed, 11 | with repeated augmentation. 12 | It ensures that different each augmented version of a sample will be visible to a 13 | different process (GPU) 14 | Heavily based on torch.utils.data.DistributedSampler 15 | """ 16 | 17 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 18 | if num_replicas is None: 19 | if not dist.is_available(): 20 | raise RuntimeError( 21 | "Requires distributed package to be available") 22 | num_replicas = dist.get_world_size() 23 | if rank is None: 24 | if not dist.is_available(): 25 | raise RuntimeError( 26 | "Requires distributed package to be available") 27 | rank = dist.get_rank() 28 | self.dataset = dataset 29 | self.num_replicas = num_replicas 30 | self.rank = rank 31 | self.epoch = 0 32 | self.num_samples = int( 33 | math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 34 | self.total_size = self.num_samples * self.num_replicas 35 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 36 | self.num_selected_samples = int(math.floor( 37 | len(self.dataset) // 256 * 256 / self.num_replicas)) 38 | self.shuffle = shuffle 39 | 40 | def __iter__(self): 41 | # deterministically shuffle based on epoch 42 | g = torch.Generator() 43 | g.manual_seed(self.epoch) 44 | if self.shuffle: 45 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 46 | else: 47 | indices = list(range(len(self.dataset))) 48 | 49 | # add extra samples to make it evenly divisible 50 | indices = [ele for ele in indices for i in range(3)] 51 | indices += indices[:(self.total_size - len(indices))] 52 | assert len(indices) == self.total_size 53 | 54 | # subsample 55 | indices = indices[self.rank:self.total_size:self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices[:self.num_selected_samples]) 59 | 60 | def __len__(self): 61 | return self.num_selected_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch 65 | -------------------------------------------------------------------------------- /data/threeaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3Augment implementation from (https://github.com/facebookresearch/deit/blob/main/augment.py) 3 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino) 4 | and timm DA(https://github.com/rwightman/pytorch-image-models) 5 | Can be called by adding "--ThreeAugment" to the command line 6 | """ 7 | import torch 8 | from torchvision import transforms 9 | 10 | from timm.data.transforms import str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor 11 | 12 | import numpy as np 13 | from torchvision import datasets, transforms 14 | import random 15 | 16 | 17 | 18 | from PIL import ImageFilter, ImageOps 19 | import 
torchvision.transforms.functional as TF 20 | 21 | 22 | class GaussianBlur(object): 23 | """ 24 | Apply Gaussian Blur to the PIL image. 25 | """ 26 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.): 27 | self.prob = p 28 | self.radius_min = radius_min 29 | self.radius_max = radius_max 30 | 31 | def __call__(self, img): 32 | do_it = random.random() <= self.prob 33 | if not do_it: 34 | return img 35 | 36 | img = img.filter( 37 | ImageFilter.GaussianBlur( 38 | radius=random.uniform(self.radius_min, self.radius_max) 39 | ) 40 | ) 41 | return img 42 | 43 | class Solarization(object): 44 | """ 45 | Apply Solarization to the PIL image. 46 | """ 47 | def __init__(self, p=0.2): 48 | self.p = p 49 | 50 | def __call__(self, img): 51 | if random.random() < self.p: 52 | return ImageOps.solarize(img) 53 | else: 54 | return img 55 | 56 | class gray_scale(object): 57 | """ 58 | Apply Solarization to the PIL image. 59 | """ 60 | def __init__(self, p=0.2): 61 | self.p = p 62 | self.transf = transforms.Grayscale(3) 63 | 64 | def __call__(self, img): 65 | if random.random() < self.p: 66 | return self.transf(img) 67 | else: 68 | return img 69 | 70 | 71 | 72 | class horizontal_flip(object): 73 | """ 74 | Apply Solarization to the PIL image. 75 | """ 76 | def __init__(self, p=0.2,activate_pred=False): 77 | self.p = p 78 | self.transf = transforms.RandomHorizontalFlip(p=1.0) 79 | 80 | def __call__(self, img): 81 | if random.random() < self.p: 82 | return self.transf(img) 83 | else: 84 | return img 85 | 86 | 87 | 88 | def new_data_aug_generator(args = None): 89 | img_size = args.input_size 90 | remove_random_resized_crop = False 91 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 92 | primary_tfl = [] 93 | scale=(0.08, 1.0) 94 | interpolation='bicubic' 95 | if remove_random_resized_crop: 96 | primary_tfl = [ 97 | transforms.Resize(img_size, interpolation=3), 98 | transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'), 99 | transforms.RandomHorizontalFlip() 100 | ] 101 | else: 102 | primary_tfl = [ 103 | RandomResizedCropAndInterpolation( 104 | img_size, scale=scale, interpolation=interpolation), 105 | transforms.RandomHorizontalFlip() 106 | ] 107 | 108 | 109 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0), 110 | Solarization(p=1.0), 111 | GaussianBlur(p=1.0)])] 112 | 113 | if args.color_jitter is not None and not args.color_jitter==0: 114 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter)) 115 | final_tfl = [ 116 | transforms.ToTensor(), 117 | transforms.Normalize( 118 | mean=torch.tensor(mean), 119 | std=torch.tensor(std)) 120 | ] 121 | return transforms.Compose(primary_tfl+secondary_tfl+final_tfl) 122 | -------------------------------------------------------------------------------- /detection/.gitignore: -------------------------------------------------------------------------------- 1 | work_dirs 2 | pretrain 3 | data 4 | det_pretrain -------------------------------------------------------------------------------- /detection/README.md: -------------------------------------------------------------------------------- 1 | # Object Detection and Instance Segmentation 2 | 3 | Detection and instance segmentation on MS COCO 2017 is implemented based on [MMDetection](https://github.com/open-mmlab/mmdetection). 
4 | 5 | ## Models 6 | | Model | $AP^b$ | $AP_{50}^b$ | $AP_{75}^b$ | $AP^m$ | $AP_{50}^m$ | $AP_{75}^m$ | Latency | Ckpt | Log | 7 | |:---------------|:----:|:---:|:--:|:--:|:--:|:--:|:--:|:--:|:--:| 8 | | RepViT-M1.1 | 39.8 | 61.9 | 43.5 | 37.2 | 58.8 | 40.1 | 4.9ms | [M1.1](https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_1_coco.pth) | [M1.1](./logs/repvit_m1_1_coco.json) | 9 | | RepViT-M1.5 | 41.6 | 63.2 | 45.3 | 38.6 | 60.5 | 41.5 | 6.4ms | [M1.5](https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_5_coco.pth) | [M1.5](./logs/repvit_m1_5_coco.json) | 10 | | RepViT-M2.3 | 44.6 | 66.1 | 48.8 | 40.8 | 63.6 | 43.9 | 9.9ms | [M2.3](https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_3_coco.pth) | [M2.3](./logs/repvit_m2_3_coco.json) | 11 | 12 | ## Installation 13 | 14 | Install [mmcv-full](https://github.com/open-mmlab/mmcv) and [MMDetection v2.28.2](https://github.com/open-mmlab/mmdetection/tree/v2.28.2). 15 | Later versions should work as well. 16 | The easiest way is to install them via [MIM](https://github.com/open-mmlab/mim): 17 | ``` 18 | pip install -U openmim 19 | mim install mmcv-full==1.7.1 20 | mim install mmdet==2.28.2 21 | ``` 22 | 23 | ## Data preparation 24 | 25 | Prepare the COCO 2017 dataset according to the [instructions in MMDetection](https://github.com/open-mmlab/mmdetection/blob/master/docs/en/1_exist_data_model.md#test-existing-models-on-standard-datasets). 26 | The dataset should be organized as follows: 27 | ``` 28 | detection 29 | ├── data 30 | │ ├── coco 31 | │ │ ├── annotations 32 | │ │ ├── train2017 33 | │ │ ├── val2017 34 | │ │ ├── test2017 35 | ``` 36 | 37 | ## Testing 38 | 39 | We provide a multi-GPU testing script. Specify the config file, the checkpoint, and the number of GPUs to use: 40 | ``` 41 | ./dist_test.sh config_file path/to/checkpoint #GPUs --eval bbox segm 42 | ``` 43 | 44 | For example, to test RepViT-M1.1 on COCO 2017 on an 8-GPU machine: 45 | 46 | ``` 47 | ./dist_test.sh configs/mask_rcnn_repvit_m1_1_fpn_1x_coco.py path/to/repvit_m1_1_coco.pth 8 --eval bbox segm 48 | ``` 49 | 50 | ## Training 51 | Download the ImageNet-1K pretrained weights into `./pretrain`. 52 | 53 | We provide the PyTorch distributed data parallel (DDP) training script `dist_train.sh`. For example, to train RepViT-M1.1 on an 8-GPU machine: 54 | ``` 55 | ./dist_train.sh configs/mask_rcnn_repvit_m1_1_fpn_1x_coco.py 8 56 | ``` 57 | Tip: remember to specify both the config file and the number of GPUs.
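For a quick single-image sanity check without the distributed scripts, the MMDetection Python API can be used directly. The snippet below is a minimal sketch, assuming MMDetection v2.28.x, that it is run from the `detection/` directory (so that `repvit.py`, which is assumed to register the RepViT backbone, is importable), and placeholder paths for the checkpoint and test image:

```python
# Minimal inference sketch (not one of the provided scripts).
from mmdet.apis import inference_detector, init_detector, show_result_pyplot

import repvit  # noqa: F401  # assumption: importing detection/repvit.py registers the RepViT backbone

config_file = 'configs/mask_rcnn_repvit_m1_1_fpn_1x_coco.py'
checkpoint_file = 'path/to/repvit_m1_1_coco.pth'  # checkpoint from the table above

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo.jpg')  # per-class boxes and masks
show_result_pyplot(model, 'demo.jpg', result, score_thr=0.3)
```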
58 | 59 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/cityscapes_detection.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'CityscapesDataset' 3 | data_root = 'data/cityscapes/' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | train_pipeline = [ 7 | dict(type='LoadImageFromFile'), 8 | dict(type='LoadAnnotations', with_bbox=True), 9 | dict( 10 | type='Resize', img_scale=[(2048, 800), (2048, 1024)], keep_ratio=True), 11 | dict(type='RandomFlip', flip_ratio=0.5), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='Pad', size_divisor=32), 14 | dict(type='DefaultFormatBundle'), 15 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 16 | ] 17 | test_pipeline = [ 18 | dict(type='LoadImageFromFile'), 19 | dict( 20 | type='MultiScaleFlipAug', 21 | img_scale=(2048, 1024), 22 | flip=False, 23 | transforms=[ 24 | dict(type='Resize', keep_ratio=True), 25 | dict(type='RandomFlip'), 26 | dict(type='Normalize', **img_norm_cfg), 27 | dict(type='Pad', size_divisor=32), 28 | dict(type='ImageToTensor', keys=['img']), 29 | dict(type='Collect', keys=['img']), 30 | ]) 31 | ] 32 | data = dict( 33 | samples_per_gpu=1, 34 | workers_per_gpu=2, 35 | train=dict( 36 | type='RepeatDataset', 37 | times=8, 38 | dataset=dict( 39 | type=dataset_type, 40 | ann_file=data_root + 41 | 'annotations/instancesonly_filtered_gtFine_train.json', 42 | img_prefix=data_root + 'leftImg8bit/train/', 43 | pipeline=train_pipeline)), 44 | val=dict( 45 | type=dataset_type, 46 | ann_file=data_root + 47 | 'annotations/instancesonly_filtered_gtFine_val.json', 48 | img_prefix=data_root + 'leftImg8bit/val/', 49 | pipeline=test_pipeline), 50 | test=dict( 51 | type=dataset_type, 52 | ann_file=data_root + 53 | 'annotations/instancesonly_filtered_gtFine_test.json', 54 | img_prefix=data_root + 'leftImg8bit/test/', 55 | pipeline=test_pipeline)) 56 | evaluation = dict(interval=1, metric='bbox') 57 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/cityscapes_instance.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'CityscapesDataset' 3 | data_root = 'data/cityscapes/' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | train_pipeline = [ 7 | dict(type='LoadImageFromFile'), 8 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 9 | dict( 10 | type='Resize', img_scale=[(2048, 800), (2048, 1024)], keep_ratio=True), 11 | dict(type='RandomFlip', flip_ratio=0.5), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='Pad', size_divisor=32), 14 | dict(type='DefaultFormatBundle'), 15 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), 16 | ] 17 | test_pipeline = [ 18 | dict(type='LoadImageFromFile'), 19 | dict( 20 | type='MultiScaleFlipAug', 21 | img_scale=(2048, 1024), 22 | flip=False, 23 | transforms=[ 24 | dict(type='Resize', keep_ratio=True), 25 | dict(type='RandomFlip'), 26 | dict(type='Normalize', **img_norm_cfg), 27 | dict(type='Pad', size_divisor=32), 28 | dict(type='ImageToTensor', keys=['img']), 29 | dict(type='Collect', keys=['img']), 30 | ]) 31 | ] 32 | data = dict( 33 | samples_per_gpu=1, 34 | workers_per_gpu=2, 35 | train=dict( 36 | type='RepeatDataset', 37 | times=8, 38 | 
dataset=dict( 39 | type=dataset_type, 40 | ann_file=data_root + 41 | 'annotations/instancesonly_filtered_gtFine_train.json', 42 | img_prefix=data_root + 'leftImg8bit/train/', 43 | pipeline=train_pipeline)), 44 | val=dict( 45 | type=dataset_type, 46 | ann_file=data_root + 47 | 'annotations/instancesonly_filtered_gtFine_val.json', 48 | img_prefix=data_root + 'leftImg8bit/val/', 49 | pipeline=test_pipeline), 50 | test=dict( 51 | type=dataset_type, 52 | ann_file=data_root + 53 | 'annotations/instancesonly_filtered_gtFine_test.json', 54 | img_prefix=data_root + 'leftImg8bit/test/', 55 | pipeline=test_pipeline)) 56 | evaluation = dict(metric=['bbox', 'segm']) 57 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/coco_detection.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'CocoDataset' 3 | data_root = 'data/coco/' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | train_pipeline = [ 7 | dict(type='LoadImageFromFile'), 8 | dict(type='LoadAnnotations', with_bbox=True), 9 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 10 | dict(type='RandomFlip', flip_ratio=0.5), 11 | dict(type='Normalize', **img_norm_cfg), 12 | dict(type='Pad', size_divisor=32), 13 | dict(type='DefaultFormatBundle'), 14 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 15 | ] 16 | test_pipeline = [ 17 | dict(type='LoadImageFromFile'), 18 | dict( 19 | type='MultiScaleFlipAug', 20 | img_scale=(1333, 800), 21 | flip=False, 22 | transforms=[ 23 | dict(type='Resize', keep_ratio=True), 24 | dict(type='RandomFlip'), 25 | dict(type='Normalize', **img_norm_cfg), 26 | dict(type='Pad', size_divisor=32), 27 | dict(type='ImageToTensor', keys=['img']), 28 | dict(type='Collect', keys=['img']), 29 | ]) 30 | ] 31 | data = dict( 32 | samples_per_gpu=2, 33 | workers_per_gpu=2, 34 | train=dict( 35 | type=dataset_type, 36 | ann_file=data_root + 'annotations/instances_train2017.json', 37 | img_prefix=data_root + 'train2017/', 38 | pipeline=train_pipeline), 39 | val=dict( 40 | type=dataset_type, 41 | ann_file=data_root + 'annotations/instances_val2017.json', 42 | img_prefix=data_root + 'val2017/', 43 | pipeline=test_pipeline), 44 | test=dict( 45 | type=dataset_type, 46 | ann_file=data_root + 'annotations/instances_val2017.json', 47 | img_prefix=data_root + 'val2017/', 48 | pipeline=test_pipeline)) 49 | evaluation = dict(interval=1, metric='bbox') 50 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/coco_instance.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'CocoDataset' 3 | data_root = 'data/coco/' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | train_pipeline = [ 7 | dict(type='LoadImageFromFile'), 8 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 9 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 10 | dict(type='RandomFlip', flip_ratio=0.5), 11 | dict(type='Normalize', **img_norm_cfg), 12 | dict(type='Pad', size_divisor=32), 13 | dict(type='DefaultFormatBundle'), 14 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), 15 | ] 16 | test_pipeline = [ 17 | dict(type='LoadImageFromFile'), 18 | dict( 19 | type='MultiScaleFlipAug', 20 | img_scale=(1333, 800), 21 | 
flip=False, 22 | transforms=[ 23 | dict(type='Resize', keep_ratio=True), 24 | dict(type='RandomFlip'), 25 | dict(type='Normalize', **img_norm_cfg), 26 | dict(type='Pad', size_divisor=32), 27 | dict(type='ImageToTensor', keys=['img']), 28 | dict(type='Collect', keys=['img']), 29 | ]) 30 | ] 31 | data = dict( 32 | samples_per_gpu=2, 33 | workers_per_gpu=2, 34 | train=dict( 35 | type=dataset_type, 36 | ann_file=data_root + 'annotations/instances_train2017.json', 37 | img_prefix=data_root + 'train2017/', 38 | pipeline=train_pipeline), 39 | val=dict( 40 | type=dataset_type, 41 | ann_file=data_root + 'annotations/instances_val2017.json', 42 | img_prefix=data_root + 'val2017/', 43 | pipeline=test_pipeline), 44 | test=dict( 45 | type=dataset_type, 46 | ann_file=data_root + 'annotations/instances_val2017.json', 47 | img_prefix=data_root + 'val2017/', 48 | pipeline=test_pipeline)) 49 | evaluation = dict(metric=['bbox', 'segm']) 50 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/coco_instance_semantic.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'CocoDataset' 3 | data_root = 'data/coco/' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | train_pipeline = [ 7 | dict(type='LoadImageFromFile'), 8 | dict( 9 | type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True), 10 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 11 | dict(type='RandomFlip', flip_ratio=0.5), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='Pad', size_divisor=32), 14 | dict(type='SegRescale', scale_factor=1 / 8), 15 | dict(type='DefaultFormatBundle'), 16 | dict( 17 | type='Collect', 18 | keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']), 19 | ] 20 | test_pipeline = [ 21 | dict(type='LoadImageFromFile'), 22 | dict( 23 | type='MultiScaleFlipAug', 24 | img_scale=(1333, 800), 25 | flip=False, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='RandomFlip', flip_ratio=0.5), 29 | dict(type='Normalize', **img_norm_cfg), 30 | dict(type='Pad', size_divisor=32), 31 | dict(type='ImageToTensor', keys=['img']), 32 | dict(type='Collect', keys=['img']), 33 | ]) 34 | ] 35 | data = dict( 36 | samples_per_gpu=2, 37 | workers_per_gpu=2, 38 | train=dict( 39 | type=dataset_type, 40 | ann_file=data_root + 'annotations/instances_train2017.json', 41 | img_prefix=data_root + 'train2017/', 42 | seg_prefix=data_root + 'stuffthingmaps/train2017/', 43 | pipeline=train_pipeline), 44 | val=dict( 45 | type=dataset_type, 46 | ann_file=data_root + 'annotations/instances_val2017.json', 47 | img_prefix=data_root + 'val2017/', 48 | pipeline=test_pipeline), 49 | test=dict( 50 | type=dataset_type, 51 | ann_file=data_root + 'annotations/instances_val2017.json', 52 | img_prefix=data_root + 'val2017/', 53 | pipeline=test_pipeline)) 54 | evaluation = dict(metric=['bbox', 'segm']) 55 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/deepfashion.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'DeepFashionDataset' 3 | data_root = 'data/DeepFashion/In-shop/' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | train_pipeline = [ 7 | dict(type='LoadImageFromFile'), 8 | 
dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 9 | dict(type='Resize', img_scale=(750, 1101), keep_ratio=True), 10 | dict(type='RandomFlip', flip_ratio=0.5), 11 | dict(type='Normalize', **img_norm_cfg), 12 | dict(type='Pad', size_divisor=32), 13 | dict(type='DefaultFormatBundle'), 14 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), 15 | ] 16 | test_pipeline = [ 17 | dict(type='LoadImageFromFile'), 18 | dict( 19 | type='MultiScaleFlipAug', 20 | img_scale=(750, 1101), 21 | flip=False, 22 | transforms=[ 23 | dict(type='Resize', keep_ratio=True), 24 | dict(type='RandomFlip'), 25 | dict(type='Normalize', **img_norm_cfg), 26 | dict(type='Pad', size_divisor=32), 27 | dict(type='ImageToTensor', keys=['img']), 28 | dict(type='Collect', keys=['img']), 29 | ]) 30 | ] 31 | data = dict( 32 | imgs_per_gpu=2, 33 | workers_per_gpu=1, 34 | train=dict( 35 | type=dataset_type, 36 | ann_file=data_root + 'annotations/DeepFashion_segmentation_query.json', 37 | img_prefix=data_root + 'Img/', 38 | pipeline=train_pipeline, 39 | data_root=data_root), 40 | val=dict( 41 | type=dataset_type, 42 | ann_file=data_root + 'annotations/DeepFashion_segmentation_query.json', 43 | img_prefix=data_root + 'Img/', 44 | pipeline=test_pipeline, 45 | data_root=data_root), 46 | test=dict( 47 | type=dataset_type, 48 | ann_file=data_root + 49 | 'annotations/DeepFashion_segmentation_gallery.json', 50 | img_prefix=data_root + 'Img/', 51 | pipeline=test_pipeline, 52 | data_root=data_root)) 53 | evaluation = dict(interval=5, metric=['bbox', 'segm']) 54 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/lvis_v0.5_instance.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | _base_ = 'coco_instance.py' 3 | dataset_type = 'LVISV05Dataset' 4 | data_root = 'data/lvis_v0.5/' 5 | data = dict( 6 | samples_per_gpu=2, 7 | workers_per_gpu=2, 8 | train=dict( 9 | _delete_=True, 10 | type='ClassBalancedDataset', 11 | oversample_thr=1e-3, 12 | dataset=dict( 13 | type=dataset_type, 14 | ann_file=data_root + 'annotations/lvis_v0.5_train.json', 15 | img_prefix=data_root + 'train2017/')), 16 | val=dict( 17 | type=dataset_type, 18 | ann_file=data_root + 'annotations/lvis_v0.5_val.json', 19 | img_prefix=data_root + 'val2017/'), 20 | test=dict( 21 | type=dataset_type, 22 | ann_file=data_root + 'annotations/lvis_v0.5_val.json', 23 | img_prefix=data_root + 'val2017/')) 24 | evaluation = dict(metric=['bbox', 'segm']) 25 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/lvis_v1_instance.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | _base_ = 'coco_instance.py' 3 | dataset_type = 'LVISV1Dataset' 4 | data_root = 'data/lvis_v1/' 5 | data = dict( 6 | samples_per_gpu=2, 7 | workers_per_gpu=2, 8 | train=dict( 9 | _delete_=True, 10 | type='ClassBalancedDataset', 11 | oversample_thr=1e-3, 12 | dataset=dict( 13 | type=dataset_type, 14 | ann_file=data_root + 'annotations/lvis_v1_train.json', 15 | img_prefix=data_root)), 16 | val=dict( 17 | type=dataset_type, 18 | ann_file=data_root + 'annotations/lvis_v1_val.json', 19 | img_prefix=data_root), 20 | test=dict( 21 | type=dataset_type, 22 | ann_file=data_root + 'annotations/lvis_v1_val.json', 23 | img_prefix=data_root)) 24 | evaluation = dict(metric=['bbox', 'segm']) 25 | 
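Note the `_delete_=True` key in the `train` dict of the two LVIS configs above: during config inheritance, mmcv's `Config` discards the `train` entry inherited from `coco_instance.py` instead of merging into it, so the `ClassBalancedDataset` wrapper fully replaces the plain `CocoDataset` definition. A minimal sketch to verify the merged result, assuming mmcv-full is installed and the snippet is run from the `detection/` directory:

```python
from mmcv import Config

cfg = Config.fromfile('configs/_base_/datasets/lvis_v1_instance.py')
# Because of _delete_=True, the inherited CocoDataset train entry is dropped
# and only the ClassBalancedDataset wrapper defined here remains.
print(cfg.data.train.type)          # ClassBalancedDataset
print(cfg.data.train.dataset.type)  # LVISV1Dataset
```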
-------------------------------------------------------------------------------- /detection/configs/_base_/datasets/voc0712.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'VOCDataset' 3 | data_root = 'data/VOCdevkit/' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | train_pipeline = [ 7 | dict(type='LoadImageFromFile'), 8 | dict(type='LoadAnnotations', with_bbox=True), 9 | dict(type='Resize', img_scale=(1000, 600), keep_ratio=True), 10 | dict(type='RandomFlip', flip_ratio=0.5), 11 | dict(type='Normalize', **img_norm_cfg), 12 | dict(type='Pad', size_divisor=32), 13 | dict(type='DefaultFormatBundle'), 14 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 15 | ] 16 | test_pipeline = [ 17 | dict(type='LoadImageFromFile'), 18 | dict( 19 | type='MultiScaleFlipAug', 20 | img_scale=(1000, 600), 21 | flip=False, 22 | transforms=[ 23 | dict(type='Resize', keep_ratio=True), 24 | dict(type='RandomFlip'), 25 | dict(type='Normalize', **img_norm_cfg), 26 | dict(type='Pad', size_divisor=32), 27 | dict(type='ImageToTensor', keys=['img']), 28 | dict(type='Collect', keys=['img']), 29 | ]) 30 | ] 31 | data = dict( 32 | samples_per_gpu=2, 33 | workers_per_gpu=2, 34 | train=dict( 35 | type='RepeatDataset', 36 | times=3, 37 | dataset=dict( 38 | type=dataset_type, 39 | ann_file=[ 40 | data_root + 'VOC2007/ImageSets/Main/trainval.txt', 41 | data_root + 'VOC2012/ImageSets/Main/trainval.txt' 42 | ], 43 | img_prefix=[data_root + 'VOC2007/', data_root + 'VOC2012/'], 44 | pipeline=train_pipeline)), 45 | val=dict( 46 | type=dataset_type, 47 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt', 48 | img_prefix=data_root + 'VOC2007/', 49 | pipeline=test_pipeline), 50 | test=dict( 51 | type=dataset_type, 52 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt', 53 | img_prefix=data_root + 'VOC2007/', 54 | pipeline=test_pipeline)) 55 | evaluation = dict(interval=1, metric='mAP') 56 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/wider_face.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'WIDERFaceDataset' 3 | data_root = 'data/WIDERFace/' 4 | img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[1, 1, 1], to_rgb=True) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile', to_float32=True), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | dict( 9 | type='PhotoMetricDistortion', 10 | brightness_delta=32, 11 | contrast_range=(0.5, 1.5), 12 | saturation_range=(0.5, 1.5), 13 | hue_delta=18), 14 | dict( 15 | type='Expand', 16 | mean=img_norm_cfg['mean'], 17 | to_rgb=img_norm_cfg['to_rgb'], 18 | ratio_range=(1, 4)), 19 | dict( 20 | type='MinIoURandomCrop', 21 | min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), 22 | min_crop_size=0.3), 23 | dict(type='Resize', img_scale=(300, 300), keep_ratio=False), 24 | dict(type='Normalize', **img_norm_cfg), 25 | dict(type='RandomFlip', flip_ratio=0.5), 26 | dict(type='DefaultFormatBundle'), 27 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 28 | ] 29 | test_pipeline = [ 30 | dict(type='LoadImageFromFile'), 31 | dict( 32 | type='MultiScaleFlipAug', 33 | img_scale=(300, 300), 34 | flip=False, 35 | transforms=[ 36 | dict(type='Resize', keep_ratio=False), 37 | dict(type='Normalize', **img_norm_cfg), 38 | dict(type='ImageToTensor', keys=['img']), 39 | dict(type='Collect', 
keys=['img']), 40 | ]) 41 | ] 42 | data = dict( 43 | samples_per_gpu=60, 44 | workers_per_gpu=2, 45 | train=dict( 46 | type='RepeatDataset', 47 | times=2, 48 | dataset=dict( 49 | type=dataset_type, 50 | ann_file=data_root + 'train.txt', 51 | img_prefix=data_root + 'WIDER_train/', 52 | min_size=17, 53 | pipeline=train_pipeline)), 54 | val=dict( 55 | type=dataset_type, 56 | ann_file=data_root + 'val.txt', 57 | img_prefix=data_root + 'WIDER_val/', 58 | pipeline=test_pipeline), 59 | test=dict( 60 | type=dataset_type, 61 | ann_file=data_root + 'val.txt', 62 | img_prefix=data_root + 'WIDER_val/', 63 | pipeline=test_pipeline)) 64 | -------------------------------------------------------------------------------- /detection/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1) 2 | # yapf:disable 3 | log_config = dict( 4 | interval=50, 5 | hooks=[ 6 | dict(type='TextLoggerHook'), 7 | # dict(type='TensorboardLoggerHook') 8 | ]) 9 | # yapf:enable 10 | custom_hooks = [dict(type='NumClassCheckHook')] 11 | 12 | dist_params = dict(backend='nccl') 13 | log_level = 'INFO' 14 | load_from = None 15 | resume_from = None 16 | workflow = [('train', 1)] 17 | -------------------------------------------------------------------------------- /detection/configs/_base_/models/fast_rcnn_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='FastRCNN', 4 | pretrained='torchvision://resnet50', 5 | backbone=dict( 6 | type='ResNet', 7 | depth=50, 8 | num_stages=4, 9 | out_indices=(0, 1, 2, 3), 10 | frozen_stages=1, 11 | norm_cfg=dict(type='BN', requires_grad=True), 12 | norm_eval=True, 13 | style='pytorch'), 14 | neck=dict( 15 | type='FPN', 16 | in_channels=[256, 512, 1024, 2048], 17 | out_channels=256, 18 | num_outs=5), 19 | roi_head=dict( 20 | type='StandardRoIHead', 21 | bbox_roi_extractor=dict( 22 | type='SingleRoIExtractor', 23 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 24 | out_channels=256, 25 | featmap_strides=[4, 8, 16, 32]), 26 | bbox_head=dict( 27 | type='Shared2FCBBoxHead', 28 | in_channels=256, 29 | fc_out_channels=1024, 30 | roi_feat_size=7, 31 | num_classes=80, 32 | bbox_coder=dict( 33 | type='DeltaXYWHBBoxCoder', 34 | target_means=[0., 0., 0., 0.], 35 | target_stds=[0.1, 0.1, 0.2, 0.2]), 36 | reg_class_agnostic=False, 37 | loss_cls=dict( 38 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 39 | loss_bbox=dict(type='L1Loss', loss_weight=1.0))), 40 | # model training and testing settings 41 | train_cfg=dict( 42 | rcnn=dict( 43 | assigner=dict( 44 | type='MaxIoUAssigner', 45 | pos_iou_thr=0.5, 46 | neg_iou_thr=0.5, 47 | min_pos_iou=0.5, 48 | match_low_quality=False, 49 | ignore_iof_thr=-1), 50 | sampler=dict( 51 | type='RandomSampler', 52 | num=512, 53 | pos_fraction=0.25, 54 | neg_pos_ub=-1, 55 | add_gt_as_proposals=True), 56 | pos_weight=-1, 57 | debug=False)), 58 | test_cfg=dict( 59 | rcnn=dict( 60 | score_thr=0.05, 61 | nms=dict(type='nms', iou_threshold=0.5), 62 | max_per_img=100))) 63 | -------------------------------------------------------------------------------- /detection/configs/_base_/models/faster_rcnn_r50_caffe_c4.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='BN', requires_grad=False) 3 | model = dict( 4 | type='FasterRCNN', 5 | pretrained='open-mmlab://detectron2/resnet50_caffe', 6 | 
backbone=dict( 7 | type='ResNet', 8 | depth=50, 9 | num_stages=3, 10 | strides=(1, 2, 2), 11 | dilations=(1, 1, 1), 12 | out_indices=(2, ), 13 | frozen_stages=1, 14 | norm_cfg=norm_cfg, 15 | norm_eval=True, 16 | style='caffe'), 17 | rpn_head=dict( 18 | type='RPNHead', 19 | in_channels=1024, 20 | feat_channels=1024, 21 | anchor_generator=dict( 22 | type='AnchorGenerator', 23 | scales=[2, 4, 8, 16, 32], 24 | ratios=[0.5, 1.0, 2.0], 25 | strides=[16]), 26 | bbox_coder=dict( 27 | type='DeltaXYWHBBoxCoder', 28 | target_means=[.0, .0, .0, .0], 29 | target_stds=[1.0, 1.0, 1.0, 1.0]), 30 | loss_cls=dict( 31 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 32 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 33 | roi_head=dict( 34 | type='StandardRoIHead', 35 | shared_head=dict( 36 | type='ResLayer', 37 | depth=50, 38 | stage=3, 39 | stride=2, 40 | dilation=1, 41 | style='caffe', 42 | norm_cfg=norm_cfg, 43 | norm_eval=True), 44 | bbox_roi_extractor=dict( 45 | type='SingleRoIExtractor', 46 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 47 | out_channels=1024, 48 | featmap_strides=[16]), 49 | bbox_head=dict( 50 | type='BBoxHead', 51 | with_avg_pool=True, 52 | roi_feat_size=7, 53 | in_channels=2048, 54 | num_classes=80, 55 | bbox_coder=dict( 56 | type='DeltaXYWHBBoxCoder', 57 | target_means=[0., 0., 0., 0.], 58 | target_stds=[0.1, 0.1, 0.2, 0.2]), 59 | reg_class_agnostic=False, 60 | loss_cls=dict( 61 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 62 | loss_bbox=dict(type='L1Loss', loss_weight=1.0))), 63 | # model training and testing settings 64 | train_cfg=dict( 65 | rpn=dict( 66 | assigner=dict( 67 | type='MaxIoUAssigner', 68 | pos_iou_thr=0.7, 69 | neg_iou_thr=0.3, 70 | min_pos_iou=0.3, 71 | match_low_quality=True, 72 | ignore_iof_thr=-1), 73 | sampler=dict( 74 | type='RandomSampler', 75 | num=256, 76 | pos_fraction=0.5, 77 | neg_pos_ub=-1, 78 | add_gt_as_proposals=False), 79 | allowed_border=0, 80 | pos_weight=-1, 81 | debug=False), 82 | rpn_proposal=dict( 83 | nms_pre=12000, 84 | max_per_img=2000, 85 | nms=dict(type='nms', iou_threshold=0.7), 86 | min_bbox_size=0), 87 | rcnn=dict( 88 | assigner=dict( 89 | type='MaxIoUAssigner', 90 | pos_iou_thr=0.5, 91 | neg_iou_thr=0.5, 92 | min_pos_iou=0.5, 93 | match_low_quality=False, 94 | ignore_iof_thr=-1), 95 | sampler=dict( 96 | type='RandomSampler', 97 | num=512, 98 | pos_fraction=0.25, 99 | neg_pos_ub=-1, 100 | add_gt_as_proposals=True), 101 | pos_weight=-1, 102 | debug=False)), 103 | test_cfg=dict( 104 | rpn=dict( 105 | nms_pre=6000, 106 | max_per_img=1000, 107 | nms=dict(type='nms', iou_threshold=0.7), 108 | min_bbox_size=0), 109 | rcnn=dict( 110 | score_thr=0.05, 111 | nms=dict(type='nms', iou_threshold=0.5), 112 | max_per_img=100))) 113 | -------------------------------------------------------------------------------- /detection/configs/_base_/models/faster_rcnn_r50_caffe_dc5.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='BN', requires_grad=False) 3 | model = dict( 4 | type='FasterRCNN', 5 | pretrained='open-mmlab://detectron2/resnet50_caffe', 6 | backbone=dict( 7 | type='ResNet', 8 | depth=50, 9 | num_stages=4, 10 | strides=(1, 2, 2, 1), 11 | dilations=(1, 1, 1, 2), 12 | out_indices=(3, ), 13 | frozen_stages=1, 14 | norm_cfg=norm_cfg, 15 | norm_eval=True, 16 | style='caffe'), 17 | rpn_head=dict( 18 | type='RPNHead', 19 | in_channels=2048, 20 | feat_channels=2048, 21 | anchor_generator=dict( 22 | 
type='AnchorGenerator', 23 | scales=[2, 4, 8, 16, 32], 24 | ratios=[0.5, 1.0, 2.0], 25 | strides=[16]), 26 | bbox_coder=dict( 27 | type='DeltaXYWHBBoxCoder', 28 | target_means=[.0, .0, .0, .0], 29 | target_stds=[1.0, 1.0, 1.0, 1.0]), 30 | loss_cls=dict( 31 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 32 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 33 | roi_head=dict( 34 | type='StandardRoIHead', 35 | bbox_roi_extractor=dict( 36 | type='SingleRoIExtractor', 37 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 38 | out_channels=2048, 39 | featmap_strides=[16]), 40 | bbox_head=dict( 41 | type='Shared2FCBBoxHead', 42 | in_channels=2048, 43 | fc_out_channels=1024, 44 | roi_feat_size=7, 45 | num_classes=80, 46 | bbox_coder=dict( 47 | type='DeltaXYWHBBoxCoder', 48 | target_means=[0., 0., 0., 0.], 49 | target_stds=[0.1, 0.1, 0.2, 0.2]), 50 | reg_class_agnostic=False, 51 | loss_cls=dict( 52 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 53 | loss_bbox=dict(type='L1Loss', loss_weight=1.0))), 54 | # model training and testing settings 55 | train_cfg=dict( 56 | rpn=dict( 57 | assigner=dict( 58 | type='MaxIoUAssigner', 59 | pos_iou_thr=0.7, 60 | neg_iou_thr=0.3, 61 | min_pos_iou=0.3, 62 | match_low_quality=True, 63 | ignore_iof_thr=-1), 64 | sampler=dict( 65 | type='RandomSampler', 66 | num=256, 67 | pos_fraction=0.5, 68 | neg_pos_ub=-1, 69 | add_gt_as_proposals=False), 70 | allowed_border=0, 71 | pos_weight=-1, 72 | debug=False), 73 | rpn_proposal=dict( 74 | nms_pre=12000, 75 | max_per_img=2000, 76 | nms=dict(type='nms', iou_threshold=0.7), 77 | min_bbox_size=0), 78 | rcnn=dict( 79 | assigner=dict( 80 | type='MaxIoUAssigner', 81 | pos_iou_thr=0.5, 82 | neg_iou_thr=0.5, 83 | min_pos_iou=0.5, 84 | match_low_quality=False, 85 | ignore_iof_thr=-1), 86 | sampler=dict( 87 | type='RandomSampler', 88 | num=512, 89 | pos_fraction=0.25, 90 | neg_pos_ub=-1, 91 | add_gt_as_proposals=True), 92 | pos_weight=-1, 93 | debug=False)), 94 | test_cfg=dict( 95 | rpn=dict( 96 | nms=dict(type='nms', iou_threshold=0.7), 97 | nms_pre=6000, 98 | max_per_img=1000, 99 | min_bbox_size=0), 100 | rcnn=dict( 101 | score_thr=0.05, 102 | nms=dict(type='nms', iou_threshold=0.5), 103 | max_per_img=100))) 104 | -------------------------------------------------------------------------------- /detection/configs/_base_/models/faster_rcnn_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='FasterRCNN', 4 | pretrained='torchvision://resnet50', 5 | backbone=dict( 6 | type='ResNet', 7 | depth=50, 8 | num_stages=4, 9 | out_indices=(0, 1, 2, 3), 10 | frozen_stages=1, 11 | norm_cfg=dict(type='BN', requires_grad=True), 12 | norm_eval=True, 13 | style='pytorch'), 14 | neck=dict( 15 | type='FPN', 16 | in_channels=[256, 512, 1024, 2048], 17 | out_channels=256, 18 | num_outs=5), 19 | rpn_head=dict( 20 | type='RPNHead', 21 | in_channels=256, 22 | feat_channels=256, 23 | anchor_generator=dict( 24 | type='AnchorGenerator', 25 | scales=[8], 26 | ratios=[0.5, 1.0, 2.0], 27 | strides=[4, 8, 16, 32, 64]), 28 | bbox_coder=dict( 29 | type='DeltaXYWHBBoxCoder', 30 | target_means=[.0, .0, .0, .0], 31 | target_stds=[1.0, 1.0, 1.0, 1.0]), 32 | loss_cls=dict( 33 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 34 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 35 | roi_head=dict( 36 | type='StandardRoIHead', 37 | bbox_roi_extractor=dict( 38 | type='SingleRoIExtractor', 39 | 
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 40 | out_channels=256, 41 | featmap_strides=[4, 8, 16, 32]), 42 | bbox_head=dict( 43 | type='Shared2FCBBoxHead', 44 | in_channels=256, 45 | fc_out_channels=1024, 46 | roi_feat_size=7, 47 | num_classes=80, 48 | bbox_coder=dict( 49 | type='DeltaXYWHBBoxCoder', 50 | target_means=[0., 0., 0., 0.], 51 | target_stds=[0.1, 0.1, 0.2, 0.2]), 52 | reg_class_agnostic=False, 53 | loss_cls=dict( 54 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 55 | loss_bbox=dict(type='L1Loss', loss_weight=1.0))), 56 | # model training and testing settings 57 | train_cfg=dict( 58 | rpn=dict( 59 | assigner=dict( 60 | type='MaxIoUAssigner', 61 | pos_iou_thr=0.7, 62 | neg_iou_thr=0.3, 63 | min_pos_iou=0.3, 64 | match_low_quality=True, 65 | ignore_iof_thr=-1), 66 | sampler=dict( 67 | type='RandomSampler', 68 | num=256, 69 | pos_fraction=0.5, 70 | neg_pos_ub=-1, 71 | add_gt_as_proposals=False), 72 | allowed_border=-1, 73 | pos_weight=-1, 74 | debug=False), 75 | rpn_proposal=dict( 76 | nms_pre=2000, 77 | max_per_img=1000, 78 | nms=dict(type='nms', iou_threshold=0.7), 79 | min_bbox_size=0), 80 | rcnn=dict( 81 | assigner=dict( 82 | type='MaxIoUAssigner', 83 | pos_iou_thr=0.5, 84 | neg_iou_thr=0.5, 85 | min_pos_iou=0.5, 86 | match_low_quality=False, 87 | ignore_iof_thr=-1), 88 | sampler=dict( 89 | type='RandomSampler', 90 | num=512, 91 | pos_fraction=0.25, 92 | neg_pos_ub=-1, 93 | add_gt_as_proposals=True), 94 | pos_weight=-1, 95 | debug=False)), 96 | test_cfg=dict( 97 | rpn=dict( 98 | nms_pre=1000, 99 | max_per_img=1000, 100 | nms=dict(type='nms', iou_threshold=0.7), 101 | min_bbox_size=0), 102 | rcnn=dict( 103 | score_thr=0.05, 104 | nms=dict(type='nms', iou_threshold=0.5), 105 | max_per_img=100) 106 | # soft-nms is also supported for rcnn testing 107 | # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) 108 | )) 109 | -------------------------------------------------------------------------------- /detection/configs/_base_/models/mask_rcnn_r50_caffe_c4.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='BN', requires_grad=False) 3 | model = dict( 4 | type='MaskRCNN', 5 | pretrained='open-mmlab://detectron2/resnet50_caffe', 6 | backbone=dict( 7 | type='ResNet', 8 | depth=50, 9 | num_stages=3, 10 | strides=(1, 2, 2), 11 | dilations=(1, 1, 1), 12 | out_indices=(2, ), 13 | frozen_stages=1, 14 | norm_cfg=norm_cfg, 15 | norm_eval=True, 16 | style='caffe'), 17 | rpn_head=dict( 18 | type='RPNHead', 19 | in_channels=1024, 20 | feat_channels=1024, 21 | anchor_generator=dict( 22 | type='AnchorGenerator', 23 | scales=[2, 4, 8, 16, 32], 24 | ratios=[0.5, 1.0, 2.0], 25 | strides=[16]), 26 | bbox_coder=dict( 27 | type='DeltaXYWHBBoxCoder', 28 | target_means=[.0, .0, .0, .0], 29 | target_stds=[1.0, 1.0, 1.0, 1.0]), 30 | loss_cls=dict( 31 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 32 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 33 | roi_head=dict( 34 | type='StandardRoIHead', 35 | shared_head=dict( 36 | type='ResLayer', 37 | depth=50, 38 | stage=3, 39 | stride=2, 40 | dilation=1, 41 | style='caffe', 42 | norm_cfg=norm_cfg, 43 | norm_eval=True), 44 | bbox_roi_extractor=dict( 45 | type='SingleRoIExtractor', 46 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 47 | out_channels=1024, 48 | featmap_strides=[16]), 49 | bbox_head=dict( 50 | type='BBoxHead', 51 | with_avg_pool=True, 52 | roi_feat_size=7, 53 | in_channels=2048, 
54 | num_classes=80, 55 | bbox_coder=dict( 56 | type='DeltaXYWHBBoxCoder', 57 | target_means=[0., 0., 0., 0.], 58 | target_stds=[0.1, 0.1, 0.2, 0.2]), 59 | reg_class_agnostic=False, 60 | loss_cls=dict( 61 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 62 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 63 | mask_roi_extractor=None, 64 | mask_head=dict( 65 | type='FCNMaskHead', 66 | num_convs=0, 67 | in_channels=2048, 68 | conv_out_channels=256, 69 | num_classes=80, 70 | loss_mask=dict( 71 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), 72 | # model training and testing settings 73 | train_cfg=dict( 74 | rpn=dict( 75 | assigner=dict( 76 | type='MaxIoUAssigner', 77 | pos_iou_thr=0.7, 78 | neg_iou_thr=0.3, 79 | min_pos_iou=0.3, 80 | match_low_quality=True, 81 | ignore_iof_thr=-1), 82 | sampler=dict( 83 | type='RandomSampler', 84 | num=256, 85 | pos_fraction=0.5, 86 | neg_pos_ub=-1, 87 | add_gt_as_proposals=False), 88 | allowed_border=0, 89 | pos_weight=-1, 90 | debug=False), 91 | rpn_proposal=dict( 92 | nms_pre=12000, 93 | max_per_img=2000, 94 | nms=dict(type='nms', iou_threshold=0.7), 95 | min_bbox_size=0), 96 | rcnn=dict( 97 | assigner=dict( 98 | type='MaxIoUAssigner', 99 | pos_iou_thr=0.5, 100 | neg_iou_thr=0.5, 101 | min_pos_iou=0.5, 102 | match_low_quality=False, 103 | ignore_iof_thr=-1), 104 | sampler=dict( 105 | type='RandomSampler', 106 | num=512, 107 | pos_fraction=0.25, 108 | neg_pos_ub=-1, 109 | add_gt_as_proposals=True), 110 | mask_size=14, 111 | pos_weight=-1, 112 | debug=False)), 113 | test_cfg=dict( 114 | rpn=dict( 115 | nms_pre=6000, 116 | nms=dict(type='nms', iou_threshold=0.7), 117 | max_per_img=1000, 118 | min_bbox_size=0), 119 | rcnn=dict( 120 | score_thr=0.05, 121 | nms=dict(type='nms', iou_threshold=0.5), 122 | max_per_img=100, 123 | mask_thr_binary=0.5))) 124 | -------------------------------------------------------------------------------- /detection/configs/_base_/models/mask_rcnn_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='MaskRCNN', 4 | pretrained='torchvision://resnet50', 5 | backbone=dict( 6 | type='ResNet', 7 | depth=50, 8 | num_stages=4, 9 | out_indices=(0, 1, 2, 3), 10 | frozen_stages=1, 11 | norm_cfg=dict(type='BN', requires_grad=True), 12 | norm_eval=True, 13 | style='pytorch'), 14 | neck=dict( 15 | type='FPN', 16 | in_channels=[256, 512, 1024, 2048], 17 | out_channels=256, 18 | num_outs=5), 19 | rpn_head=dict( 20 | type='RPNHead', 21 | in_channels=256, 22 | feat_channels=256, 23 | anchor_generator=dict( 24 | type='AnchorGenerator', 25 | scales=[8], 26 | ratios=[0.5, 1.0, 2.0], 27 | strides=[4, 8, 16, 32, 64]), 28 | bbox_coder=dict( 29 | type='DeltaXYWHBBoxCoder', 30 | target_means=[.0, .0, .0, .0], 31 | target_stds=[1.0, 1.0, 1.0, 1.0]), 32 | loss_cls=dict( 33 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 34 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 35 | roi_head=dict( 36 | type='StandardRoIHead', 37 | bbox_roi_extractor=dict( 38 | type='SingleRoIExtractor', 39 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 40 | out_channels=256, 41 | featmap_strides=[4, 8, 16, 32]), 42 | bbox_head=dict( 43 | type='Shared2FCBBoxHead', 44 | in_channels=256, 45 | fc_out_channels=1024, 46 | roi_feat_size=7, 47 | num_classes=80, 48 | bbox_coder=dict( 49 | type='DeltaXYWHBBoxCoder', 50 | target_means=[0., 0., 0., 0.], 51 | target_stds=[0.1, 0.1, 0.2, 0.2]), 52 | reg_class_agnostic=False, 53 | 
loss_cls=dict( 54 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 55 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 56 | mask_roi_extractor=dict( 57 | type='SingleRoIExtractor', 58 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 59 | out_channels=256, 60 | featmap_strides=[4, 8, 16, 32]), 61 | mask_head=dict( 62 | type='FCNMaskHead', 63 | num_convs=4, 64 | in_channels=256, 65 | conv_out_channels=256, 66 | num_classes=80, 67 | loss_mask=dict( 68 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), 69 | # model training and testing settings 70 | train_cfg=dict( 71 | rpn=dict( 72 | assigner=dict( 73 | type='MaxIoUAssigner', 74 | pos_iou_thr=0.7, 75 | neg_iou_thr=0.3, 76 | min_pos_iou=0.3, 77 | match_low_quality=True, 78 | ignore_iof_thr=-1), 79 | sampler=dict( 80 | type='RandomSampler', 81 | num=256, 82 | pos_fraction=0.5, 83 | neg_pos_ub=-1, 84 | add_gt_as_proposals=False), 85 | allowed_border=-1, 86 | pos_weight=-1, 87 | debug=False), 88 | rpn_proposal=dict( 89 | nms_pre=2000, 90 | max_per_img=1000, 91 | nms=dict(type='nms', iou_threshold=0.7), 92 | min_bbox_size=0), 93 | rcnn=dict( 94 | assigner=dict( 95 | type='MaxIoUAssigner', 96 | pos_iou_thr=0.5, 97 | neg_iou_thr=0.5, 98 | min_pos_iou=0.5, 99 | match_low_quality=True, 100 | ignore_iof_thr=-1), 101 | sampler=dict( 102 | type='RandomSampler', 103 | num=512, 104 | pos_fraction=0.25, 105 | neg_pos_ub=-1, 106 | add_gt_as_proposals=True), 107 | mask_size=28, 108 | pos_weight=-1, 109 | debug=False)), 110 | test_cfg=dict( 111 | rpn=dict( 112 | nms_pre=1000, 113 | max_per_img=1000, 114 | nms=dict(type='nms', iou_threshold=0.7), 115 | min_bbox_size=0), 116 | rcnn=dict( 117 | score_thr=0.05, 118 | nms=dict(type='nms', iou_threshold=0.5), 119 | max_per_img=100, 120 | mask_thr_binary=0.5))) 121 | -------------------------------------------------------------------------------- /detection/configs/_base_/models/retinanet_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='RetinaNet', 4 | pretrained='torchvision://resnet50', 5 | backbone=dict( 6 | type='ResNet', 7 | depth=50, 8 | num_stages=4, 9 | out_indices=(0, 1, 2, 3), 10 | frozen_stages=1, 11 | norm_cfg=dict(type='BN', requires_grad=True), 12 | norm_eval=True, 13 | style='pytorch'), 14 | neck=dict( 15 | type='FPN', 16 | in_channels=[256, 512, 1024, 2048], 17 | out_channels=256, 18 | start_level=1, 19 | add_extra_convs='on_input', 20 | num_outs=5), 21 | bbox_head=dict( 22 | type='RetinaHead', 23 | num_classes=80, 24 | in_channels=256, 25 | stacked_convs=4, 26 | feat_channels=256, 27 | anchor_generator=dict( 28 | type='AnchorGenerator', 29 | octave_base_scale=4, 30 | scales_per_octave=3, 31 | ratios=[0.5, 1.0, 2.0], 32 | strides=[8, 16, 32, 64, 128]), 33 | bbox_coder=dict( 34 | type='DeltaXYWHBBoxCoder', 35 | target_means=[.0, .0, .0, .0], 36 | target_stds=[1.0, 1.0, 1.0, 1.0]), 37 | loss_cls=dict( 38 | type='FocalLoss', 39 | use_sigmoid=True, 40 | gamma=2.0, 41 | alpha=0.25, 42 | loss_weight=1.0), 43 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 44 | # model training and testing settings 45 | train_cfg=dict( 46 | assigner=dict( 47 | type='MaxIoUAssigner', 48 | pos_iou_thr=0.5, 49 | neg_iou_thr=0.4, 50 | min_pos_iou=0, 51 | ignore_iof_thr=-1), 52 | allowed_border=-1, 53 | pos_weight=-1, 54 | debug=False), 55 | test_cfg=dict( 56 | nms_pre=1000, 57 | min_bbox_size=0, 58 | score_thr=0.05, 59 | nms=dict(type='nms', iou_threshold=0.5), 60 | max_per_img=100)) 61 
| -------------------------------------------------------------------------------- /detection/configs/_base_/models/rpn_r50_caffe_c4.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='RPN', 4 | pretrained='open-mmlab://detectron2/resnet50_caffe', 5 | backbone=dict( 6 | type='ResNet', 7 | depth=50, 8 | num_stages=3, 9 | strides=(1, 2, 2), 10 | dilations=(1, 1, 1), 11 | out_indices=(2, ), 12 | frozen_stages=1, 13 | norm_cfg=dict(type='BN', requires_grad=False), 14 | norm_eval=True, 15 | style='caffe'), 16 | neck=None, 17 | rpn_head=dict( 18 | type='RPNHead', 19 | in_channels=1024, 20 | feat_channels=1024, 21 | anchor_generator=dict( 22 | type='AnchorGenerator', 23 | scales=[2, 4, 8, 16, 32], 24 | ratios=[0.5, 1.0, 2.0], 25 | strides=[16]), 26 | bbox_coder=dict( 27 | type='DeltaXYWHBBoxCoder', 28 | target_means=[.0, .0, .0, .0], 29 | target_stds=[1.0, 1.0, 1.0, 1.0]), 30 | loss_cls=dict( 31 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 32 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 33 | # model training and testing settings 34 | train_cfg=dict( 35 | rpn=dict( 36 | assigner=dict( 37 | type='MaxIoUAssigner', 38 | pos_iou_thr=0.7, 39 | neg_iou_thr=0.3, 40 | min_pos_iou=0.3, 41 | ignore_iof_thr=-1), 42 | sampler=dict( 43 | type='RandomSampler', 44 | num=256, 45 | pos_fraction=0.5, 46 | neg_pos_ub=-1, 47 | add_gt_as_proposals=False), 48 | allowed_border=0, 49 | pos_weight=-1, 50 | debug=False)), 51 | test_cfg=dict( 52 | rpn=dict( 53 | nms_pre=12000, 54 | max_per_img=2000, 55 | nms=dict(type='nms', iou_threshold=0.7), 56 | min_bbox_size=0))) 57 | -------------------------------------------------------------------------------- /detection/configs/_base_/models/rpn_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='RPN', 4 | pretrained='torchvision://resnet50', 5 | backbone=dict( 6 | type='ResNet', 7 | depth=50, 8 | num_stages=4, 9 | out_indices=(0, 1, 2, 3), 10 | frozen_stages=1, 11 | norm_cfg=dict(type='BN', requires_grad=True), 12 | norm_eval=True, 13 | style='pytorch'), 14 | neck=dict( 15 | type='FPN', 16 | in_channels=[256, 512, 1024, 2048], 17 | out_channels=256, 18 | num_outs=5), 19 | rpn_head=dict( 20 | type='RPNHead', 21 | in_channels=256, 22 | feat_channels=256, 23 | anchor_generator=dict( 24 | type='AnchorGenerator', 25 | scales=[8], 26 | ratios=[0.5, 1.0, 2.0], 27 | strides=[4, 8, 16, 32, 64]), 28 | bbox_coder=dict( 29 | type='DeltaXYWHBBoxCoder', 30 | target_means=[.0, .0, .0, .0], 31 | target_stds=[1.0, 1.0, 1.0, 1.0]), 32 | loss_cls=dict( 33 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 34 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 35 | # model training and testing settings 36 | train_cfg=dict( 37 | rpn=dict( 38 | assigner=dict( 39 | type='MaxIoUAssigner', 40 | pos_iou_thr=0.7, 41 | neg_iou_thr=0.3, 42 | min_pos_iou=0.3, 43 | ignore_iof_thr=-1), 44 | sampler=dict( 45 | type='RandomSampler', 46 | num=256, 47 | pos_fraction=0.5, 48 | neg_pos_ub=-1, 49 | add_gt_as_proposals=False), 50 | allowed_border=0, 51 | pos_weight=-1, 52 | debug=False)), 53 | test_cfg=dict( 54 | rpn=dict( 55 | nms_pre=2000, 56 | max_per_img=1000, 57 | nms=dict(type='nms', iou_threshold=0.7), 58 | min_bbox_size=0))) 59 | -------------------------------------------------------------------------------- /detection/configs/_base_/models/ssd300.py: 
-------------------------------------------------------------------------------- 1 | # model settings 2 | input_size = 300 3 | model = dict( 4 | type='SingleStageDetector', 5 | pretrained='open-mmlab://vgg16_caffe', 6 | backbone=dict( 7 | type='SSDVGG', 8 | input_size=input_size, 9 | depth=16, 10 | with_last_pool=False, 11 | ceil_mode=True, 12 | out_indices=(3, 4), 13 | out_feature_indices=(22, 34), 14 | l2_norm_scale=20), 15 | neck=None, 16 | bbox_head=dict( 17 | type='SSDHead', 18 | in_channels=(512, 1024, 512, 256, 256, 256), 19 | num_classes=80, 20 | anchor_generator=dict( 21 | type='SSDAnchorGenerator', 22 | scale_major=False, 23 | input_size=input_size, 24 | basesize_ratio_range=(0.15, 0.9), 25 | strides=[8, 16, 32, 64, 100, 300], 26 | ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]]), 27 | bbox_coder=dict( 28 | type='DeltaXYWHBBoxCoder', 29 | target_means=[.0, .0, .0, .0], 30 | target_stds=[0.1, 0.1, 0.2, 0.2])), 31 | # model training and testing settings 32 | train_cfg=dict( 33 | assigner=dict( 34 | type='MaxIoUAssigner', 35 | pos_iou_thr=0.5, 36 | neg_iou_thr=0.5, 37 | min_pos_iou=0., 38 | ignore_iof_thr=-1, 39 | gt_max_assign_all=False), 40 | smoothl1_beta=1., 41 | allowed_border=-1, 42 | pos_weight=-1, 43 | neg_pos_ratio=3, 44 | debug=False), 45 | test_cfg=dict( 46 | nms_pre=1000, 47 | nms=dict(type='nms', iou_threshold=0.45), 48 | min_bbox_size=0, 49 | score_thr=0.02, 50 | max_per_img=200)) 51 | cudnn_benchmark = True 52 | -------------------------------------------------------------------------------- /detection/configs/_base_/schedules/schedule_1x.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict( 6 | policy='step', 7 | warmup='linear', 8 | warmup_iters=500, 9 | warmup_ratio=1e-6, # 0.001 10 | step=[8, 11]) 11 | runner = dict(type='EpochBasedRunner', max_epochs=12) 12 | -------------------------------------------------------------------------------- /detection/configs/_base_/schedules/schedule_20e.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict( 6 | policy='step', 7 | warmup='linear', 8 | warmup_iters=500, 9 | warmup_ratio=0.001, 10 | step=[16, 19]) 11 | runner = dict(type='EpochBasedRunner', max_epochs=20) 12 | -------------------------------------------------------------------------------- /detection/configs/_base_/schedules/schedule_2x.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict( 6 | policy='step', 7 | warmup='linear', 8 | warmup_iters=500, 9 | warmup_ratio=0.001, 10 | step=[16, 22]) 11 | runner = dict(type='EpochBasedRunner', max_epochs=24) 12 | -------------------------------------------------------------------------------- /detection/configs/mask_rcnn_repvit_m1_1_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '_base_/models/mask_rcnn_r50_fpn.py', 3 | '_base_/datasets/coco_instance.py', 4 | '_base_/schedules/schedule_1x.py', 5 | '_base_/default_runtime.py' 6 | ] 7 
| # optimizer 8 | model = dict( 9 | backbone=dict( 10 | type='repvit_m1_1', 11 | init_cfg=dict( 12 | type='Pretrained', 13 | checkpoint='pretrain/repvit_m1_1_distill_300e.pth', 14 | ), 15 | out_indices = [2,6,20,24] 16 | ), 17 | neck=dict( 18 | type='FPN', 19 | in_channels=[64, 128, 256, 512], 20 | out_channels=256, 21 | num_outs=5)) 22 | # optimizer 23 | optimizer = dict(_delete_=True, type='AdamW', lr=0.0002, weight_decay=0.05) # 0.0001 24 | optimizer_config = dict(grad_clip=None) 25 | -------------------------------------------------------------------------------- /detection/configs/mask_rcnn_repvit_m1_5_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '_base_/models/mask_rcnn_r50_fpn.py', 3 | '_base_/datasets/coco_instance.py', 4 | '_base_/schedules/schedule_1x.py', 5 | '_base_/default_runtime.py' 6 | ] 7 | # optimizer 8 | model = dict( 9 | backbone=dict( 10 | type='repvit_m1_5', 11 | init_cfg=dict( 12 | type='Pretrained', 13 | checkpoint='pretrain/repvit_m1_5_distill_300e.pth', 14 | ), 15 | out_indices=[4, 10, 36, 42] 16 | ), 17 | neck=dict( 18 | type='FPN', 19 | in_channels=[64, 128, 256, 512], 20 | out_channels=256, 21 | num_outs=5)) 22 | # optimizer 23 | optimizer = dict(_delete_=True, type='AdamW', lr=0.0002, weight_decay=0.05) # 0.0001 24 | optimizer_config = dict(grad_clip=None) 25 | -------------------------------------------------------------------------------- /detection/configs/mask_rcnn_repvit_m2_3_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '_base_/models/mask_rcnn_r50_fpn.py', 3 | '_base_/datasets/coco_instance.py', 4 | '_base_/schedules/schedule_1x.py', 5 | '_base_/default_runtime.py' 6 | ] 7 | # optimizer 8 | model = dict( 9 | backbone=dict( 10 | type='repvit_m2_3', 11 | init_cfg=dict( 12 | type='Pretrained', 13 | checkpoint='pretrain/repvit_m2_3_distill_450e.pth', 14 | ), 15 | out_indices=[6, 14, 50, 54] 16 | ), 17 | neck=dict( 18 | type='FPN', 19 | in_channels=[80, 160, 320, 640], 20 | out_channels=256, 21 | num_outs=5)) 22 | # optimizer 23 | optimizer = dict(_delete_=True, type='AdamW', lr=0.0002, weight_decay=0.05) # 0.0001 24 | optimizer_config = dict(grad_clip=None) 25 | -------------------------------------------------------------------------------- /detection/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | NNODES=${NNODES:-1} 7 | NODE_RANK=${NODE_RANK:-0} 8 | PORT=${PORT:-29500} 9 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 10 | 11 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 12 | NCCL_P2P_DISABLE=1 \ 13 | python -m torch.distributed.launch \ 14 | --nnodes=$NNODES \ 15 | --node_rank=$NODE_RANK \ 16 | --master_addr=$MASTER_ADDR \ 17 | --nproc_per_node=$GPUS \ 18 | --master_port=$PORT \ 19 | $(dirname "$0")/test.py \ 20 | $CONFIG \ 21 | $CHECKPOINT \ 22 | --launcher pytorch \ 23 | ${@:4} 24 | -------------------------------------------------------------------------------- /detection/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | NNODES=${NNODES:-1} 6 | NODE_RANK=${NODE_RANK:-0} 7 | PORT=${PORT:-29500} 8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 9 | 10 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 11 | NCCL_P2P_DISABLE=1 \ 12 | python -m torch.distributed.launch \ 13 | --nnodes=$NNODES \ 14 | 
--node_rank=$NODE_RANK \ 15 | --master_addr=$MASTER_ADDR \ 16 | --nproc_per_node=$GPUS \ 17 | --master_port=$PORT \ 18 | $(dirname "$0")/train.py \ 19 | $CONFIG \ 20 | --seed 0 \ 21 | --launcher pytorch ${@:3} 22 | -------------------------------------------------------------------------------- /detection/eval.sh: -------------------------------------------------------------------------------- 1 | PORT=12345 ./dist_test.sh configs/mask_rcnn_repvit_m1_1_fpn_1x_coco.py det_pretrain/repvit_m1_1_coco.pth 8 --eval bbox segm -------------------------------------------------------------------------------- /detection/mmcv_custom/runner/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | import os.path as osp 3 | import time 4 | from tempfile import TemporaryDirectory 5 | 6 | import torch 7 | from torch.optim import Optimizer 8 | 9 | import mmcv 10 | from mmcv.parallel import is_module_wrapper 11 | from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict 12 | 13 | def save_checkpoint(model, filename, optimizer=None, meta=None): 14 | """Save checkpoint to file. 15 | 16 | The checkpoint will have 4 fields: ``meta``, ``state_dict`` and 17 | ``optimizer``, ``amp``. By default ``meta`` will contain version 18 | and time info. 19 | 20 | Args: 21 | model (Module): Module whose params are to be saved. 22 | filename (str): Checkpoint filename. 23 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. 24 | meta (dict, optional): Metadata to be saved in checkpoint. 25 | """ 26 | if meta is None: 27 | meta = {} 28 | elif not isinstance(meta, dict): 29 | raise TypeError(f'meta must be a dict or None, but got {type(meta)}') 30 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) 31 | 32 | if is_module_wrapper(model): 33 | model = model.module 34 | 35 | if hasattr(model, 'CLASSES') and model.CLASSES is not None: 36 | # save class name to the meta 37 | meta.update(CLASSES=model.CLASSES) 38 | 39 | checkpoint = { 40 | 'meta': meta, 41 | 'state_dict': weights_to_cpu(get_state_dict(model)) 42 | } 43 | # save optimizer state dict in the checkpoint 44 | if isinstance(optimizer, Optimizer): 45 | checkpoint['optimizer'] = optimizer.state_dict() 46 | elif isinstance(optimizer, dict): 47 | checkpoint['optimizer'] = {} 48 | for name, optim in optimizer.items(): 49 | checkpoint['optimizer'][name] = optim.state_dict() 50 | 51 | # save amp state dict in the checkpoint 52 | checkpoint['amp'] = apex.amp.state_dict() 53 | 54 | if filename.startswith('pavi://'): 55 | try: 56 | from pavi import modelcloud 57 | from pavi.exception import NodeNotFoundError 58 | except ImportError: 59 | raise ImportError( 60 | 'Please install pavi to load checkpoint from modelcloud.') 61 | model_path = filename[7:] 62 | root = modelcloud.Folder() 63 | model_dir, model_name = osp.split(model_path) 64 | try: 65 | model = modelcloud.get(model_dir) 66 | except NodeNotFoundError: 67 | model = root.create_training_model(model_dir) 68 | with TemporaryDirectory() as tmp_dir: 69 | checkpoint_file = osp.join(tmp_dir, model_name) 70 | with open(checkpoint_file, 'wb') as f: 71 | torch.save(checkpoint, f) 72 | f.flush() 73 | model.create_file(checkpoint_file, name=model_name) 74 | else: 75 | mmcv.mkdir_or_exist(osp.dirname(filename)) 76 | # immediately flush buffer 77 | with open(filename, 'wb') as f: 78 | torch.save(checkpoint, f) 79 | f.flush() -------------------------------------------------------------------------------- 
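Note: save_checkpoint in mmcv_custom/runner/checkpoint.py above stores checkpoint['amp'] = apex.amp.state_dict() but never imports apex, so the function raises NameError as written unless NVIDIA Apex has been imported elsewhere. A hedged sketch of the optional import this code presumably relies on (an assumption, not part of the file above):

    # Assumed guard: import NVIDIA Apex optionally so the 'amp' field can be
    # skipped cleanly when Apex is not installed.
    try:
        import apex
    except ImportError:
        apex = None

    def amp_state_or_none():
        # Return the Apex AMP state dict when Apex is available, else None,
        # so a caller can omit the 'amp' entry from the checkpoint dict.
        return apex.amp.state_dict() if apex is not None else None

With such a guard, the 'amp' entry would only be written when Apex is actually present.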
/detection/mmcv_custom/runner/epoch_based_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | import os.path as osp 3 | import platform 4 | import shutil 5 | 6 | import torch 7 | from torch.optim import Optimizer 8 | 9 | import mmcv 10 | from mmcv.runner import RUNNERS, EpochBasedRunner 11 | from .checkpoint import save_checkpoint 12 | 13 | @RUNNERS.register_module() 14 | class EpochBasedRunnerAmp(EpochBasedRunner): 15 | """Epoch-based Runner with AMP support. 16 | 17 | This runner train models epoch by epoch. 18 | """ 19 | 20 | def save_checkpoint(self, 21 | out_dir, 22 | filename_tmpl='epoch_{}.pth', 23 | save_optimizer=True, 24 | meta=None, 25 | create_symlink=True): 26 | """Save the checkpoint. 27 | 28 | Args: 29 | out_dir (str): The directory that checkpoints are saved. 30 | filename_tmpl (str, optional): The checkpoint filename template, 31 | which contains a placeholder for the epoch number. 32 | Defaults to 'epoch_{}.pth'. 33 | save_optimizer (bool, optional): Whether to save the optimizer to 34 | the checkpoint. Defaults to True. 35 | meta (dict, optional): The meta information to be saved in the 36 | checkpoint. Defaults to None. 37 | create_symlink (bool, optional): Whether to create a symlink 38 | "latest.pth" to point to the latest checkpoint. 39 | Defaults to True. 40 | """ 41 | if meta is None: 42 | meta = dict(epoch=self.epoch + 1, iter=self.iter) 43 | elif isinstance(meta, dict): 44 | meta.update(epoch=self.epoch + 1, iter=self.iter) 45 | else: 46 | raise TypeError( 47 | f'meta should be a dict or None, but got {type(meta)}') 48 | if self.meta is not None: 49 | meta.update(self.meta) 50 | 51 | filename = filename_tmpl.format(self.epoch + 1) 52 | filepath = osp.join(out_dir, filename) 53 | optimizer = self.optimizer if save_optimizer else None 54 | save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) 55 | # in some environments, `os.symlink` is not supported, you may need to 56 | # set `create_symlink` to False 57 | if create_symlink: 58 | dst_file = osp.join(out_dir, 'latest.pth') 59 | if platform.system() != 'Windows': 60 | mmcv.symlink(filename, dst_file) 61 | else: 62 | shutil.copy(filepath, dst_file) 63 | 64 | def resume(self, 65 | checkpoint, 66 | resume_optimizer=True, 67 | map_location='default'): 68 | if map_location == 'default': 69 | if torch.cuda.is_available(): 70 | device_id = torch.cuda.current_device() 71 | checkpoint = self.load_checkpoint( 72 | checkpoint, 73 | map_location=lambda storage, loc: storage.cuda(device_id)) 74 | else: 75 | checkpoint = self.load_checkpoint(checkpoint) 76 | else: 77 | checkpoint = self.load_checkpoint( 78 | checkpoint, map_location=map_location) 79 | 80 | self._epoch = checkpoint['meta']['epoch'] 81 | self._iter = checkpoint['meta']['iter'] 82 | if 'optimizer' in checkpoint and resume_optimizer: 83 | if isinstance(self.optimizer, Optimizer): 84 | self.optimizer.load_state_dict(checkpoint['optimizer']) 85 | elif isinstance(self.optimizer, dict): 86 | for k in self.optimizer.keys(): 87 | self.optimizer[k].load_state_dict( 88 | checkpoint['optimizer'][k]) 89 | else: 90 | raise TypeError( 91 | 'Optimizer should be dict or torch.optim.Optimizer ' 92 | f'but got {type(self.optimizer)}') 93 | 94 | if 'amp' in checkpoint: 95 | apex.amp.load_state_dict(checkpoint['amp']) 96 | self.logger.info('load amp state dict') 97 | 98 | self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter) 99 | 
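EpochBasedRunnerAmp above is registered in mmcv's RUNNERS registry, so a downstream config selects it by its registered name; its resume() also calls apex.amp.load_state_dict without importing apex, so the same optional-import caveat as for checkpoint.py applies. A hypothetical config fragment (values mirror the 1x schedule shown earlier; this repository's provided configs keep the default runner) requesting the custom runner:

    # Hypothetical config fragment: pick the AMP-aware runner by its
    # registered name instead of the default EpochBasedRunner.
    runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)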
-------------------------------------------------------------------------------- /detection/mmcv_custom/runner/optimizer.py: -------------------------------------------------------------------------------- 1 | from mmcv.runner import OptimizerHook, HOOKS 2 | 3 | 4 | @HOOKS.register_module() 5 | class DistOptimizerHook(OptimizerHook): 6 | """Optimizer hook for distributed training.""" 7 | 8 | def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False): 9 | self.grad_clip = grad_clip 10 | self.coalesce = coalesce 11 | self.bucket_size_mb = bucket_size_mb 12 | self.update_interval = update_interval 13 | self.use_fp16 = use_fp16 14 | 15 | def before_run(self, runner): 16 | runner.optimizer.zero_grad() 17 | 18 | def after_train_iter(self, runner): 19 | runner.outputs['loss'] /= self.update_interval 20 | if self.use_fp16: 21 | with apex.amp.scale_loss(runner.outputs['loss'], runner.optimizer) as scaled_loss: 22 | scaled_loss.backward() 23 | else: 24 | runner.outputs['loss'].backward() 25 | if self.every_n_iters(runner, self.update_interval): 26 | if self.grad_clip is not None: 27 | self.clip_grads(runner.model.parameters()) 28 | runner.optimizer.step() 29 | runner.optimizer.zero_grad() -------------------------------------------------------------------------------- /detection/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | WORK_DIR=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-4} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-12} 12 | SRUN_ARGS=${SRUN_ARGS:-""} 13 | PY_ARGS=${@:5} 14 | 15 | export NCCL_P2P_DISABLE=1 16 | export MASTER_PORT=24680 17 | 18 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 19 | srun -p ${PARTITION} \ 20 | --job-name=${JOB_NAME} \ 21 | --gres=gpu:${GPUS_PER_NODE} \ 22 | --ntasks=${GPUS} \ 23 | --ntasks-per-node=${GPUS_PER_NODE} \ 24 | --cpus-per-task=${CPUS_PER_TASK} \ 25 | --kill-on-bad-exit=1 \ 26 | --mem 250G \ 27 | ${SRUN_ARGS} \ 28 | python -u train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} 29 | -------------------------------------------------------------------------------- /detection/train.sh: -------------------------------------------------------------------------------- 1 | ./dist_train.sh configs/mask_rcnn_repvit_m1_1_fpn_1x_coco.py 8 -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train and eval functions used in main.py 3 | """ 4 | import math 5 | import sys 6 | from typing import Iterable, Optional 7 | 8 | import torch 9 | 10 | from timm.data import Mixup 11 | from timm.utils import accuracy, ModelEma 12 | 13 | from losses import DistillationLoss 14 | import utils 15 | 16 | def set_bn_state(model): 17 | for m in model.modules(): 18 | if isinstance(m, torch.nn.modules.batchnorm._BatchNorm): 19 | m.eval() 20 | 21 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 22 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 23 | device: torch.device, epoch: int, loss_scaler, 24 | clip_grad: float = 0, 25 | clip_mode: str = 'norm', 26 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 27 | set_training_mode=True, 28 | set_bn_eval=False,): 29 | model.train(set_training_mode) 30 | if set_bn_eval: 31 | set_bn_state(model) 32 | metric_logger = utils.MetricLogger(delimiter=" 
") 33 | metric_logger.add_meter('lr', utils.SmoothedValue( 34 | window_size=1, fmt='{value:.6f}')) 35 | header = 'Epoch: [{}]'.format(epoch) 36 | print_freq = 100 37 | 38 | for samples, targets in metric_logger.log_every( 39 | data_loader, print_freq, header): 40 | samples = samples.to(device, non_blocking=True) 41 | targets = targets.to(device, non_blocking=True) 42 | 43 | if mixup_fn is not None: 44 | samples, targets = mixup_fn(samples, targets) 45 | 46 | with torch.cuda.amp.autocast(): 47 | outputs = model(samples) 48 | loss = criterion(samples, outputs, targets) 49 | 50 | loss_value = loss.item() 51 | 52 | if not math.isfinite(loss_value): 53 | print("Loss is {}, stopping training".format(loss_value)) 54 | sys.exit(1) 55 | 56 | optimizer.zero_grad() 57 | 58 | # this attribute is added by timm on one optimizer (adahessian) 59 | is_second_order = hasattr( 60 | optimizer, 'is_second_order') and optimizer.is_second_order 61 | loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode, 62 | parameters=model.parameters(), create_graph=is_second_order) 63 | 64 | torch.cuda.synchronize() 65 | if model_ema is not None: 66 | model_ema.update(model) 67 | 68 | metric_logger.update(loss=loss_value) 69 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 70 | # gather the stats from all processes 71 | metric_logger.synchronize_between_processes() 72 | print("Averaged stats:", metric_logger) 73 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 74 | 75 | 76 | @torch.no_grad() 77 | def evaluate(data_loader, model, device): 78 | criterion = torch.nn.CrossEntropyLoss() 79 | 80 | metric_logger = utils.MetricLogger(delimiter=" ") 81 | header = 'Test:' 82 | 83 | # switch to evaluation mode 84 | model.eval() 85 | 86 | for images, target in metric_logger.log_every(data_loader, 10, header): 87 | images = images.to(device, non_blocking=True) 88 | target = target.to(device, non_blocking=True) 89 | 90 | # compute output 91 | with torch.cuda.amp.autocast(): 92 | output = model(images) 93 | loss = criterion(output, target) 94 | 95 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 96 | 97 | batch_size = images.shape[0] 98 | metric_logger.update(loss=loss.item()) 99 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 100 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 101 | # gather the stats from all processes 102 | metric_logger.synchronize_between_processes() 103 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 104 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 105 | 106 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 107 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | python main.py --eval --model repvit_m1_1 --resume pretrain/repvit_m1_1_distill_300e.pth --data-path ~/imagenet -------------------------------------------------------------------------------- /export_coreml.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm import create_model 4 | import model 5 | 6 | import utils 7 | 8 | import torch 9 | import torchvision 10 | from argparse import ArgumentParser 11 | 12 | parser = ArgumentParser() 13 | 14 | parser.add_argument('--model', default='repvit_m1_1', type=str) 15 | parser.add_argument('--resolution', default=224, 
type=int) 16 | parser.add_argument('--ckpt', default=None, type=str) 17 | 18 | if __name__ == "__main__": 19 | # Load a pre-trained version of MobileNetV2 20 | args = parser.parse_args() 21 | model = create_model(args.model, distillation=True) 22 | if args.ckpt: 23 | model.load_state_dict(torch.load(args.ckpt)['model']) 24 | utils.replace_batchnorm(model) 25 | model.eval() 26 | 27 | # Trace the model with random data. 28 | resolution = args.resolution 29 | example_input = torch.rand(1, 3, resolution, resolution) 30 | traced_model = torch.jit.trace(model, example_input) 31 | out = traced_model(example_input) 32 | 33 | import coremltools as ct 34 | 35 | # Using image_input in the inputs parameter: 36 | # Convert to Core ML neural network using the Unified Conversion API. 37 | model = ct.convert( 38 | traced_model, 39 | inputs=[ct.ImageType(shape=example_input.shape)] 40 | ) 41 | 42 | # Save the converted model. 43 | model.save(f"coreml/{args.model}_{resolution}.mlmodel") -------------------------------------------------------------------------------- /figures/latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/figures/latency.png -------------------------------------------------------------------------------- /figures/repvit_m0_9_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/figures/repvit_m0_9_latency.png -------------------------------------------------------------------------------- /flops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from timm import create_model 4 | import model 5 | import utils 6 | from fvcore.nn import FlopCountAnalysis 7 | 8 | T0 = 5 9 | T1 = 10 10 | 11 | for n, batch_size, resolution in [ 12 | ('repvit_m0_9', 1024, 224), 13 | ]: 14 | inputs = torch.randn(1, 3, resolution, 15 | resolution) 16 | model = create_model(n, num_classes=1000) 17 | utils.replace_batchnorm(model) 18 | n_parameters = sum(p.numel() 19 | for p in model.parameters() if p.requires_grad) 20 | print('number of params:', n_parameters / 1e6) 21 | flops = FlopCountAnalysis(model, inputs) 22 | print("flops: ", flops.total() / 1e9) -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the knowledge distillation loss, proposed in deit 3 | """ 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | 8 | class DistillationLoss(torch.nn.Module): 9 | """ 10 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 11 | taking a teacher model prediction and using it as additional supervision. 
12 | """ 13 | 14 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 15 | distillation_type: str, alpha: float, tau: float): 16 | super().__init__() 17 | self.base_criterion = base_criterion 18 | self.teacher_model = teacher_model 19 | assert distillation_type in ['none', 'soft', 'hard'] 20 | self.distillation_type = distillation_type 21 | self.alpha = alpha 22 | self.tau = tau 23 | 24 | def forward(self, inputs, outputs, labels): 25 | """ 26 | Args: 27 | inputs: The original inputs that are feed to the teacher model 28 | outputs: the outputs of the model to be trained. It is expected to be 29 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 30 | in the first position and the distillation predictions as the second output 31 | labels: the labels for the base criterion 32 | """ 33 | outputs_kd = None 34 | if not isinstance(outputs, torch.Tensor): 35 | # assume that the model outputs a tuple of [outputs, outputs_kd] 36 | outputs, outputs_kd = outputs 37 | base_loss = self.base_criterion(outputs, labels) 38 | if self.distillation_type == 'none': 39 | return base_loss 40 | 41 | if outputs_kd is None: 42 | raise ValueError("When knowledge distillation is enabled, the model is " 43 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 44 | "class_token and the dist_token") 45 | # don't backprop throught the teacher 46 | with torch.no_grad(): 47 | teacher_outputs = self.teacher_model(inputs) 48 | 49 | if self.distillation_type == 'soft': 50 | T = self.tau 51 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 52 | # with slight modifications 53 | distillation_loss = F.kl_div( 54 | F.log_softmax(outputs_kd / T, dim=1), 55 | F.log_softmax(teacher_outputs / T, dim=1), 56 | reduction='sum', 57 | log_target=True 58 | ) * (T * T) / outputs_kd.numel() 59 | elif self.distillation_type == 'hard': 60 | distillation_loss = F.cross_entropy( 61 | outputs_kd, teacher_outputs.argmax(dim=1)) 62 | 63 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 64 | return loss 65 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import model.repvit -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | timm==0.5.4 3 | fvcore -------------------------------------------------------------------------------- /sam/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyo 3 | *.pyd 4 | __py 5 | **/__pycache__/ 6 | repvit_sam.egg-info 7 | weights/*.pt 8 | *.pt 9 | *.onnx -------------------------------------------------------------------------------- /sam/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and 
orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /sam/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to segment-anything 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to segment-anything, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /sam/app/.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | assets/sa_1309.jpg filter=lfs diff=lfs merge=lfs -text 37 | assets/sa_192.jpg filter=lfs diff=lfs merge=lfs -text 38 | assets/sa_414.jpg filter=lfs diff=lfs merge=lfs -text 39 | assets/sa_862.jpg filter=lfs diff=lfs merge=lfs -text 40 | -------------------------------------------------------------------------------- /sam/app/assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/app/assets/.DS_Store -------------------------------------------------------------------------------- /sam/app/assets/picture1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/app/assets/picture1.jpg -------------------------------------------------------------------------------- /sam/app/assets/picture2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/app/assets/picture2.jpg -------------------------------------------------------------------------------- /sam/app/assets/picture3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/app/assets/picture3.jpg -------------------------------------------------------------------------------- /sam/app/assets/picture4.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/app/assets/picture4.jpg -------------------------------------------------------------------------------- /sam/app/assets/picture5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/app/assets/picture5.jpg -------------------------------------------------------------------------------- /sam/app/assets/picture6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/app/assets/picture6.jpg -------------------------------------------------------------------------------- /sam/app/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | timm 4 | opencv-python 5 | -------------------------------------------------------------------------------- /sam/app/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/app/utils/__init__.py -------------------------------------------------------------------------------- /sam/assets/logo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/assets/logo2.png -------------------------------------------------------------------------------- /sam/assets/mask_box.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/assets/mask_box.jpg -------------------------------------------------------------------------------- /sam/assets/mask_comparision.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/assets/mask_comparision.jpg -------------------------------------------------------------------------------- /sam/assets/mask_point.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/assets/mask_point.jpg -------------------------------------------------------------------------------- /sam/assets/model_diagram.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/assets/model_diagram.jpg -------------------------------------------------------------------------------- /sam/assets/notebook1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/assets/notebook1.png -------------------------------------------------------------------------------- /sam/assets/notebook2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/assets/notebook2.png -------------------------------------------------------------------------------- /sam/figures/comparison.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/figures/comparison.png -------------------------------------------------------------------------------- /sam/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | { 5 | black --version | grep -E "23\." > /dev/null 6 | } || { 7 | echo "Linter requires 'black==23.*' !" 8 | exit 1 9 | } 10 | 11 | ISORT_VERSION=$(isort --version-number) 12 | if [[ "$ISORT_VERSION" != 5.12* ]]; then 13 | echo "Linter requires isort==5.12.0 !" 14 | exit 1 15 | fi 16 | 17 | echo "Running isort ..." 18 | isort . --atomic 19 | 20 | echo "Running black ..." 21 | black -l 100 . 22 | 23 | echo "Running flake8 ..." 24 | if [ -x "$(command -v flake8)" ]; then 25 | flake8 . 26 | else 27 | python3 -m flake8 . 28 | fi 29 | 30 | echo "Running mypy..." 31 | 32 | mypy --exclude 'setup.py|notebooks' . 33 | -------------------------------------------------------------------------------- /sam/notebooks/coreml_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\n", 10 | "\n", 11 | "import torch\n", 12 | "import numpy as np\n", 13 | "import cv2\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import coremltools as ct\n", 16 | "import math\n", 17 | "from repvit_sam.utils.transforms import ResizeLongestSide\n", 18 | "import torch.nn.functional as F\n", 19 | "\n", 20 | "\n", 21 | "def show_mask(mask, ax):\n", 22 | " color = np.array([30/255, 144/255, 255/255, 0.6])\n", 23 | " h, w = mask.shape[-2:]\n", 24 | " mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n", 25 | " ax.imshow(mask_image)\n", 26 | " \n", 27 | "def show_points(coords, labels, ax, marker_size=375):\n", 28 | " pos_points = coords[labels==1]\n", 29 | " neg_points = coords[labels==0]\n", 30 | " ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n", 31 | " ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) \n", 32 | "\n", 33 | "def preprocess(x, img_size=1024):\n", 34 | " \"\"\"Normalize pixel values and pad to a square input.\"\"\"\n", 35 | " # Normalize colors\n", 36 | " transform = ResizeLongestSide(img_size)\n", 37 | " x = transform.apply_image(x)\n", 38 | " x = torch.as_tensor(x)\n", 39 | " x = x.permute(2, 0, 1).contiguous()\n", 40 | "\n", 41 | " pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)\n", 42 | " pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)\n", 43 | " x = (x - pixel_mean) / pixel_std\n", 44 | "\n", 45 | " # Pad\n", 46 | " h, w = x.shape[-2:]\n", 47 | " padh = img_size - h\n", 48 | " padw = img_size - w\n", 49 | " x = F.pad(x, (0, padw, 0, padh))\n", 50 | " return x, transform\n", 51 | "\n", 52 | "def postprocess(raw_image, masks):\n", 53 | " def resize_longest_image_size(\n", 54 | " input_image_size, longest_side: int\n", 55 | " ):\n", 56 | " scale = longest_side / max(input_image_size)\n", 57 | " transformed_size = [int(math.floor(scale * each + 0.5)) for each in input_image_size]\n", 58 | " return transformed_size\n", 59 | "\n", 60 | " prepadded_size = 
resize_longest_image_size(raw_image.shape[:2], masks.shape[2])\n", 61 | " masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore\n", 62 | "\n", 63 | " h, w = raw_image.shape[:2]\n", 64 | " masks = F.interpolate(torch.tensor(masks), size=(h, w), mode=\"bilinear\", align_corners=False)\n", 65 | " masks = masks > 0\n", 66 | " return masks" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "!python3 ../scripts/export_coreml_encoder.py --resolution 1024 --model repvit --samckpt ../weights/repvit_sam.pt\n", 76 | "!python3 ../scripts/export_coreml_decoder.py --checkpoint ../weights/repvit_sam.pt --model-type repvit" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "encoder = ct.models.MLModel('coreml/repvit_1024.mlpackage')" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "decoder = ct.models.MLModel('coreml/sam_decoder.mlpackage')" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "raw_image = cv2.imread('../../app/assets/picture3.jpg')\n", 104 | "raw_image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)\n", 105 | "image, transform = preprocess(raw_image)\n", 106 | "image_embedding= list(encoder.predict({'x_1': image.numpy()[None, ...]}).values())[0]" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "input_point = np.array([[553, 808]])\n", 116 | "input_label = np.array([1])\n", 117 | "\n", 118 | "coreml_coord = input_point[None, :, :].astype(np.float32)\n", 119 | "coreml_label = input_label[None, :].astype(np.float32)\n", 120 | "\n", 121 | "coreml_coord = transform.apply_coords(coreml_coord, raw_image.shape[:2]).astype(np.float32)\n", 122 | "\n", 123 | "coreml_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)\n", 124 | "coreml_has_mask_input = np.zeros(1, dtype=np.float32)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "ort_inputs = {\n", 134 | " \"image_embeddings\": image_embedding,\n", 135 | " \"point_coords\": coreml_coord,\n", 136 | " \"point_labels\": coreml_label,\n", 137 | " \"mask_input\": coreml_mask_input,\n", 138 | " \"has_mask_input\": coreml_has_mask_input,\n", 139 | "}" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "low_res_logits, score, masks = decoder.predict(ort_inputs).values()\n", 149 | "plt.figure(figsize=(10,10))\n", 150 | "plt.imshow(raw_image)\n", 151 | "show_mask(postprocess(raw_image, masks), plt.gca())\n", 152 | "show_points(input_point, input_label, plt.gca())\n", 153 | "plt.axis('off')\n", 154 | "plt.show() " 155 | ] 156 | } 157 | ], 158 | "metadata": { 159 | "kernelspec": { 160 | "display_name": "Python 3", 161 | "language": "python", 162 | "name": "python3" 163 | }, 164 | "language_info": { 165 | "codemirror_mode": { 166 | "name": "ipython", 167 | "version": 3 168 | }, 169 | "file_extension": ".py", 170 | "mimetype": "text/x-python", 171 | "name": "python", 172 | "nbconvert_exporter": "python", 173 | "pygments_lexer": "ipython3", 174 | "version": "3.10.7" 175 | }, 176 | 
"orig_nbformat": 4 177 | }, 178 | "nbformat": 4, 179 | "nbformat_minor": 2 180 | } 181 | -------------------------------------------------------------------------------- /sam/notebooks/images/picture1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/notebooks/images/picture1.jpg -------------------------------------------------------------------------------- /sam/notebooks/images/picture2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-MIG/RepViT/298f42075eda5d2e6102559fad260c970769d34e/sam/notebooks/images/picture2.jpg -------------------------------------------------------------------------------- /sam/repvit_sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | build_sam_vit_t, 13 | sam_model_registry, 14 | ) 15 | from .predictor import SamPredictor 16 | from .automatic_mask_generator import SamAutomaticMaskGenerator 17 | -------------------------------------------------------------------------------- /sam/repvit_sam/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | from .tiny_vit_sam import TinyViT 13 | from .repvit import RepViT -------------------------------------------------------------------------------- /sam/repvit_sam/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /sam/repvit_sam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam/repvit_sam/utils/coreml.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from math import floor 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import functional as F 11 | 12 | from typing import Tuple, List 13 | 14 | from ..modeling import Sam 15 | from .amg import calculate_stability_score 16 | 17 | 18 | class SamCoreMLModel(nn.Module): 19 | """ 20 | This model should not be called directly, but is used in ONNX export. 21 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 22 | with some functions modified to enable model tracing. Also supports extra 23 | options controlling what information. See the ONNX export script for details. 
24 | """ 25 | 26 | def __init__( 27 | self, 28 | model: Sam, 29 | orig_img_size: List, 30 | return_single_mask: bool, 31 | use_stability_score: bool = False, 32 | return_extra_metrics: bool = False, 33 | ) -> None: 34 | super().__init__() 35 | self.mask_decoder = model.mask_decoder 36 | self.model = model 37 | self.img_size = model.image_encoder.img_size 38 | self.return_single_mask = return_single_mask 39 | self.use_stability_score = use_stability_score 40 | self.stability_score_offset = 1.0 41 | self.return_extra_metrics = return_extra_metrics 42 | self.orig_img_size = orig_img_size 43 | 44 | @staticmethod 45 | def resize_longest_image_size( 46 | input_image_size: List, longest_side: int 47 | ) -> List: 48 | scale = longest_side / max(input_image_size) 49 | transformed_size = [int(floor(scale * each + 0.5)) for each in input_image_size] 50 | return transformed_size 51 | 52 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 53 | point_coords = point_coords + 0.5 54 | point_coords = point_coords / self.img_size 55 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 56 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 57 | 58 | point_embedding = point_embedding * (point_labels != -1) 59 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 60 | point_labels == -1 61 | ) 62 | 63 | for i in range(self.model.prompt_encoder.num_point_embeddings): 64 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 65 | i 66 | ].weight * (point_labels == i) 67 | 68 | return point_embedding 69 | 70 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 71 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 72 | mask_embedding = mask_embedding + ( 73 | 1 - has_mask_input 74 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 75 | return mask_embedding 76 | 77 | def mask_postprocessing(self, masks: torch.Tensor) -> torch.Tensor: 78 | masks = F.interpolate( 79 | masks, 80 | size=(self.img_size, self.img_size), 81 | mode="bilinear", 82 | align_corners=False, 83 | ) 84 | 85 | prepadded_size = self.resize_longest_image_size(self.orig_img_size, self.img_size) 86 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 87 | 88 | h, w = self.orig_img_size[0], self.orig_img_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor 115 | ): 116 | sparse_embedding = self._embed_points(point_coords, point_labels) 117 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 118 | 119 | masks, scores = self.model.mask_decoder.predict_masks( 120 | image_embeddings=image_embeddings, 121 | image_pe=self.model.prompt_encoder.get_dense_pe(), 122 | sparse_prompt_embeddings=sparse_embedding, 123 | dense_prompt_embeddings=dense_embedding, 124 | ) 125 | 126 | if self.use_stability_score: 127 | scores = calculate_stability_score( 128 | masks, self.model.mask_threshold, self.stability_score_offset 129 | ) 130 | 131 | if self.return_single_mask: 132 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 133 | 134 | upscaled_masks = self.mask_postprocessing(masks) 135 | 136 | if self.return_extra_metrics: 137 | stability_scores = calculate_stability_score( 138 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 139 | ) 140 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 141 | return upscaled_masks, scores, stability_scores, areas, masks 142 | 143 | return upscaled_masks, scores, masks -------------------------------------------------------------------------------- /sam/repvit_sam/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from math import floor 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import functional as F 11 | 12 | from typing import Tuple, List 13 | 14 | from ..modeling import Sam 15 | from .amg import calculate_stability_score 16 | 17 | 18 | class SamOnnxModel(nn.Module): 19 | """ 20 | This model should not be called directly, but is used in ONNX export. 21 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 22 | with some functions modified to enable model tracing. Also supports extra 23 | options controlling what information. See the ONNX export script for details. 
24 | """ 25 | 26 | def __init__( 27 | self, 28 | model: Sam, 29 | return_single_mask: bool, 30 | use_stability_score: bool = False, 31 | return_extra_metrics: bool = False, 32 | ) -> None: 33 | super().__init__() 34 | self.mask_decoder = model.mask_decoder 35 | self.model = model 36 | self.img_size = model.image_encoder.img_size 37 | self.return_single_mask = return_single_mask 38 | self.use_stability_score = use_stability_score 39 | self.stability_score_offset = 1.0 40 | self.return_extra_metrics = return_extra_metrics 41 | 42 | @staticmethod 43 | def resize_longest_image_size( 44 | input_image_size: torch.Tensor, longest_side: int 45 | ) -> torch.Tensor: 46 | input_image_size = input_image_size.to(torch.float32) 47 | scale = longest_side / torch.max(input_image_size) 48 | transformed_size = scale * input_image_size 49 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 50 | return transformed_size 51 | 52 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 53 | point_coords = point_coords + 0.5 54 | point_coords = point_coords / self.img_size 55 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 56 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 57 | 58 | point_embedding = point_embedding * (point_labels != -1) 59 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 60 | point_labels == -1 61 | ) 62 | 63 | for i in range(self.model.prompt_encoder.num_point_embeddings): 64 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 65 | i 66 | ].weight * (point_labels == i) 67 | 68 | return point_embedding 69 | 70 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 71 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 72 | mask_embedding = mask_embedding + ( 73 | 1 - has_mask_input 74 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 75 | return mask_embedding 76 | 77 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 78 | masks = F.interpolate( 79 | masks, 80 | size=(self.img_size, self.img_size), 81 | mode="bilinear", 82 | align_corners=False, 83 | ) 84 | 85 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 86 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 87 | 88 | orig_im_size = orig_im_size.to(torch.int64) 89 | h, w = orig_im_size[0], orig_im_size[1] 90 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 91 | return masks 92 | 93 | def select_masks( 94 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 95 | ) -> Tuple[torch.Tensor, torch.Tensor]: 96 | # Determine if we should return the multiclick mask or not from the number of points. 97 | # The reweighting is used to avoid control flow. 
98 | score_reweight = torch.tensor( 99 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 100 | ).to(iou_preds.device) 101 | score = iou_preds + (num_points - 2.5) * score_reweight 102 | best_idx = torch.argmax(score, dim=1) 103 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 104 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 105 | 106 | return masks, iou_preds 107 | 108 | @torch.no_grad() 109 | def forward( 110 | self, 111 | image_embeddings: torch.Tensor, 112 | point_coords: torch.Tensor, 113 | point_labels: torch.Tensor, 114 | mask_input: torch.Tensor, 115 | has_mask_input: torch.Tensor, 116 | orig_im_size: torch.Tensor, 117 | ): 118 | sparse_embedding = self._embed_points(point_coords, point_labels) 119 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 120 | 121 | masks, scores = self.model.mask_decoder.predict_masks( 122 | image_embeddings=image_embeddings, 123 | image_pe=self.model.prompt_encoder.get_dense_pe(), 124 | sparse_prompt_embeddings=sparse_embedding, 125 | dense_prompt_embeddings=dense_embedding, 126 | ) 127 | 128 | if self.use_stability_score: 129 | scores = calculate_stability_score( 130 | masks, self.model.mask_threshold, self.stability_score_offset 131 | ) 132 | 133 | if self.return_single_mask: 134 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 135 | 136 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 137 | 138 | if self.return_extra_metrics: 139 | stability_scores = calculate_stability_score( 140 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 141 | ) 142 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 143 | return upscaled_masks, scores, stability_scores, areas, masks 144 | 145 | return upscaled_masks, scores, masks 146 | 147 | -------------------------------------------------------------------------------- /sam/repvit_sam/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /sam/scripts/export_coreml_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from repvit_sam import sam_model_registry 10 | from repvit_sam.utils.coreml import SamCoreMLModel 11 | 12 | import argparse 13 | import warnings 14 | 15 | parser = argparse.ArgumentParser( 16 | description="Export the SAM prompt encoder and mask decoder to an ONNX model." 17 | ) 18 | 19 | parser.add_argument( 20 | "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint." 21 | ) 22 | 23 | parser.add_argument( 24 | "--output", type=str, required=False, help="The filename to save the ONNX model to." 
25 | ) 26 | 27 | parser.add_argument( 28 | "--model-type", 29 | type=str, 30 | required=True, 31 | help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.", 32 | ) 33 | 34 | parser.add_argument( 35 | "--return-single-mask", 36 | action="store_true", 37 | default=True, 38 | help=( 39 | "If true, the exported ONNX model will only return the best mask, " 40 | "instead of returning multiple masks. For high resolution images " 41 | "this can improve runtime when upscaling masks is expensive." 42 | ), 43 | ) 44 | 45 | parser.add_argument( 46 | "--opset", 47 | type=int, 48 | default=17, 49 | help="The ONNX opset version to use. Must be >=11", 50 | ) 51 | 52 | parser.add_argument( 53 | "--quantize-out", 54 | type=str, 55 | default=None, 56 | help=( 57 | "If set, will quantize the model and save it with this name. " 58 | "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize." 59 | ), 60 | ) 61 | 62 | parser.add_argument( 63 | "--gelu-approximate", 64 | action="store_true", 65 | help=( 66 | "Replace GELU operations with approximations using tanh. Useful " 67 | "for some runtimes that have slow or unimplemented erf ops, used in GELU." 68 | ), 69 | ) 70 | 71 | parser.add_argument( 72 | "--use-stability-score", 73 | action="store_true", 74 | help=( 75 | "Replaces the model's predicted mask quality score with the stability " 76 | "score calculated on the low resolution masks using an offset of 1.0. " 77 | ), 78 | ) 79 | 80 | parser.add_argument( 81 | "--return-extra-metrics", 82 | action="store_true", 83 | help=( 84 | "The model will return five results: (masks, scores, stability_scores, " 85 | "areas, low_res_logits) instead of the usual three. This can be " 86 | "significantly slower for high resolution outputs." 
87 | ), 88 | ) 89 | 90 | parser.add_argument('--precision', default='fp16', type=str) 91 | 92 | @torch.no_grad() 93 | def run_export( 94 | model_type: str, 95 | checkpoint: str, 96 | output: str, 97 | opset: int, 98 | return_single_mask: bool, 99 | gelu_approximate: bool = False, 100 | use_stability_score: bool = False, 101 | return_extra_metrics=False, 102 | ): 103 | print("Loading model...") 104 | sam = sam_model_registry[model_type](checkpoint=checkpoint) 105 | 106 | onnx_model = SamCoreMLModel( 107 | model=sam, 108 | orig_img_size=[1024, 1024], 109 | return_single_mask=return_single_mask, 110 | use_stability_score=use_stability_score, 111 | return_extra_metrics=return_extra_metrics, 112 | ) 113 | onnx_model.eval() 114 | 115 | dynamic_axes = { 116 | "point_coords": {1: "num_points"}, 117 | "point_labels": {1: "num_points"}, 118 | } 119 | 120 | embed_dim = sam.prompt_encoder.embed_dim 121 | embed_size = sam.prompt_encoder.image_embedding_size 122 | mask_input_size = [4 * x for x in embed_size] 123 | dummy_inputs = { 124 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), 125 | "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), 126 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), 127 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), 128 | "has_mask_input": torch.tensor([1], dtype=torch.float), 129 | } 130 | traced_model = torch.jit.trace(onnx_model, example_inputs=list(dummy_inputs.values())) 131 | out = traced_model(**dummy_inputs) 132 | 133 | output_names = ["masks", "iou_predictions", "low_res_masks"] 134 | 135 | import coremltools as ct 136 | 137 | # Using image_input in the inputs parameter: 138 | # Convert to Core ML neural network using the Unified Conversion API. 139 | model = ct.convert( 140 | traced_model, 141 | inputs=[ 142 | ct.TensorType(name='image_embeddings', shape=dummy_inputs['image_embeddings'].shape), 143 | ct.TensorType(name='point_coords', shape=ct.Shape(shape=(1, ct.RangeDim(lower_bound=0, upper_bound=5,default=1), 2))), 144 | ct.TensorType(name='point_labels', shape=ct.Shape(shape=(1, ct.RangeDim(lower_bound=0, upper_bound=5,default=1)))), 145 | ct.TensorType(name='mask_input', shape=dummy_inputs['mask_input'].shape), 146 | ct.TensorType(name='has_mask_input', shape=dummy_inputs['has_mask_input'].shape), 147 | ], 148 | compute_precision=ct.precision.FLOAT16 if args.precision=='fp16' else ct.precision.FLOAT32 149 | ) 150 | 151 | # Save the converted model. 
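# Note: the save path below is hardcoded to coreml/sam_decoder.mlpackage; the --output argument
# parsed above is accepted by run_export but not used when saving.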
152 | model.save(f"coreml/sam_decoder.mlpackage") 153 | 154 | 155 | 156 | def to_numpy(tensor): 157 | return tensor.cpu().numpy() 158 | 159 | 160 | if __name__ == "__main__": 161 | args = parser.parse_args() 162 | run_export( 163 | model_type=args.model_type, 164 | checkpoint=args.checkpoint, 165 | output=args.output, 166 | opset=args.opset, 167 | return_single_mask=args.return_single_mask, 168 | gelu_approximate=args.gelu_approximate, 169 | use_stability_score=args.use_stability_score, 170 | return_extra_metrics=args.return_extra_metrics, 171 | ) 172 | -------------------------------------------------------------------------------- /sam/scripts/export_coreml_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from timm import create_model 4 | 5 | import torch 6 | import torchvision 7 | from argparse import ArgumentParser 8 | from timm.models import create_model 9 | import repvit_sam.modeling 10 | 11 | parser = ArgumentParser() 12 | 13 | parser.add_argument('--model', default='vit_t', type=str) 14 | parser.add_argument('--resolution', default=224, type=int) 15 | parser.add_argument('--ckpt', default=None, type=str) 16 | parser.add_argument('--samckpt', default=None, type=str) 17 | parser.add_argument('--precision', default='fp16', type=str) 18 | 19 | if __name__ == "__main__": 20 | # Load a pre-trained version of MobileNetV2 21 | args = parser.parse_args() 22 | model = create_model(args.model) 23 | if args.ckpt: 24 | model.load_state_dict(torch.load(args.ckpt)['model']) 25 | if args.samckpt: 26 | state = torch.load(args.samckpt, map_location='cpu') 27 | new_state = {} 28 | for k, v in state.items(): 29 | if not 'image_encoder' in k: 30 | continue 31 | new_state[k.replace('image_encoder.', '')] = v 32 | model.load_state_dict(new_state) 33 | model.eval() 34 | 35 | # Trace the model with random data. 36 | resolution = args.resolution 37 | example_input = torch.rand(1, 3, resolution, resolution) 38 | traced_model = torch.jit.trace(model, example_input) 39 | out = traced_model(example_input) 40 | 41 | import coremltools as ct 42 | 43 | # Using image_input in the inputs parameter: 44 | # Convert to Core ML neural network using the Unified Conversion API. 45 | model = ct.convert( 46 | traced_model, 47 | inputs=[ct.TensorType(shape=example_input.shape)], 48 | compute_precision=ct.precision.FLOAT16 if args.precision=='fp16' else ct.precision.FLOAT32 49 | ) 50 | 51 | # Save the converted model. 52 | model.save(f"coreml/{args.model}_{resolution}.mlpackage") 53 | -------------------------------------------------------------------------------- /sam/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=3 4 | include_trailing_comma=True 5 | known_standard_library=numpy,setuptools 6 | skip_glob=*/__init__.py 7 | known_myself=repvit_sam 8 | known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort 9 | no_lines_before=STDLIB,THIRDPARTY 10 | sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER 11 | default_section=FIRSTPARTY 12 | -------------------------------------------------------------------------------- /sam/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from setuptools import find_packages, setup 8 | 9 | setup( 10 | name="repvit_sam", 11 | version="1.0", 12 | install_requires=[], 13 | packages=find_packages(exclude="notebooks"), 14 | extras_require={ 15 | "all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"], 16 | "dev": ["flake8", "isort", "black", "mypy"], 17 | }, 18 | ) 19 | -------------------------------------------------------------------------------- /segmentation/.gitignore: -------------------------------------------------------------------------------- 1 | pretrain 2 | work_dirs 3 | data 4 | seg_pretrain -------------------------------------------------------------------------------- /segmentation/README.md: -------------------------------------------------------------------------------- 1 | # Semantic Segmentation 2 | 3 | Segmentation on ADE20K is implemented based on [MMSegmentation](https://github.com/open-mmlab/mmsegmentation). 4 | 5 | ## Models 6 | | Model | mIoU | Latency | Ckpt | Log | 7 | |:---------------|:----:|:---:|:--:|:--:| 8 | | RepViT-M1.1 | 40.6 | 4.9ms | [M1.1](https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_1_ade20k.pth) | [M1.1](./logs/repvit_m1_1_ade20k.json) | 9 | | RepViT-M1.5 | 43.6 | 6.4ms | [M1.5](https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_5_ade20k.pth) | [M1.5](./logs/repvit_m1_5_ade20k.json) | 10 | | RepViT-M2.3 | 46.1 | 9.9ms | [M2.3](https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_3_ade20k.pth) | [M2.3](./logs/repvit_m2_3_ade20k.json) | 11 | 12 | The backbone latency is measured with image crops of 512x512 on iPhone 12 by Core ML Tools. 13 | 14 | ## Requirements 15 | Install [mmcv-full](https://github.com/open-mmlab/mmcv) and [MMSegmentation v0.30.0](https://github.com/open-mmlab/mmsegmentation/tree/v0.30.0). 16 | Later versions should work as well. 17 | The easiest way is to install them via [MIM](https://github.com/open-mmlab/mim): 18 | ``` 19 | pip install -U openmim 20 | mim install mmcv-full==1.7.1 21 | mim install mmseg==0.30.0 22 | ``` 23 | 24 | ## Data preparation 25 | 26 | We benchmark RepViT on the challenging ADE20K dataset, which can be downloaded and prepared following the [instructions in MMSeg](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets). 27 | The data should be organized as: 28 | ``` 29 | ├── segmentation 30 | │ ├── data 31 | │ │ ├── ade 32 | │ │ │ ├── ADEChallengeData2016 33 | │ │ │ │ ├── annotations 34 | │ │ │ │ │ ├── training 35 | │ │ │ │ │ ├── validation 36 | │ │ │ │ ├── images 37 | │ │ │ │ │ ├── training 38 | │ │ │ │ │ ├── validation 39 | 40 | ``` 41 | 42 | 43 | 44 | ## Testing 45 | 46 | We provide a multi-GPU testing script; specify the config file, the checkpoint, and the number of GPUs to use: 47 | ``` 48 | ./tools/dist_test.sh config_file path/to/checkpoint #GPUs --eval mIoU 49 | ``` 50 | 51 | For example, to test RepViT-M1.1 on ADE20K on an 8-GPU machine: 52 | 53 | ``` 54 | ./tools/dist_test.sh configs/sem_fpn/fpn_repvit_m1_1_ade20k_40k.py path/to/repvit_m1_1_ade20k.pth 8 --eval mIoU 55 | ``` 56 | 57 | ## Training 58 | Download the ImageNet-1K pretrained weights into `./pretrain`. 59 | 60 | We provide a PyTorch distributed data parallel (DDP) training script, `dist_train.sh`. For example, to train RepViT-M1.1 on an 8-GPU machine: 61 | ``` 62 | ./tools/dist_train.sh configs/sem_fpn/fpn_repvit_m1_1_ade20k_40k.py 8 63 | ``` 64 | Tip: remember to pass both the config file and the number of GPUs. 
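## Single-image inference

For a quick sanity check without the distributed launcher, you can run inference on one image with the MMSegmentation v0.30.x Python API. The snippet below is only a sketch: the checkpoint and image paths are placeholders, and the `repvit` / `align_resize` imports assume the segmentation folder ships these modules (mirroring the detection setup) so that the custom backbone and the `AlignResize` test-time transform get registered.

```
# Minimal inference sketch (assumes mmcv-full and mmseg are installed as above; run from the segmentation directory).
from mmseg.apis import inference_segmentor, init_segmentor

import align_resize  # registers the AlignResize transform used by the test pipeline
import repvit        # assumed module name; registers the RepViT backbones with MMSegmentation

config_file = 'configs/sem_fpn/fpn_repvit_m1_1_ade20k_40k.py'
checkpoint_file = 'seg_pretrain/repvit_m1_1_ade20k.pth'  # downloaded from the table above

model = init_segmentor(config_file, checkpoint_file, device='cuda:0')  # use device='cpu' without a GPU
result = inference_segmentor(model, 'demo.jpg')
model.show_result('demo.jpg', result, out_file='demo_pred.png', opacity=0.5)
```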
65 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/datasets/ade20k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'ADE20KDataset' 3 | data_root = 'data/ade/ADEChallengeData2016' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | crop_size = (512, 512) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations', reduce_zero_label=True), 10 | dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 16 | dict(type='DefaultFormatBundle'), 17 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 18 | ] 19 | test_pipeline = [ 20 | dict(type='LoadImageFromFile'), 21 | dict( 22 | type='MultiScaleFlipAug', 23 | img_scale=(2048, 512), 24 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 25 | flip=False, 26 | transforms=[ 27 | dict(type='AlignResize', keep_ratio=True, size_divisor=32), 28 | dict(type='RandomFlip'), 29 | dict(type='Normalize', **img_norm_cfg), 30 | dict(type='ImageToTensor', keys=['img']), 31 | dict(type='Collect', keys=['img']), 32 | ]) 33 | ] 34 | data = dict( 35 | samples_per_gpu=4, 36 | workers_per_gpu=4, 37 | train=dict( 38 | type='RepeatDataset', 39 | times=50, 40 | dataset=dict( 41 | type=dataset_type, 42 | data_root=data_root, 43 | img_dir='images/training', 44 | ann_dir='annotations/training', 45 | pipeline=train_pipeline)), 46 | val=dict( 47 | type=dataset_type, 48 | data_root=data_root, 49 | img_dir='images/validation', 50 | ann_dir='annotations/validation', 51 | pipeline=test_pipeline), 52 | test=dict( 53 | type=dataset_type, 54 | data_root=data_root, 55 | img_dir='images/validation', 56 | ann_dir='annotations/validation', 57 | pipeline=test_pipeline)) 58 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=50, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=False), 6 | # dict(type='TensorboardLoggerHook') 7 | ]) 8 | # yapf:enable 9 | dist_params = dict(backend='nccl') 10 | log_level = 'INFO' 11 | load_from = None 12 | resume_from = None 13 | workflow = [('train', 1)] 14 | cudnn_benchmark = True 15 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/models/fpn_r50.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained='open-mmlab://resnet50_v1c', 6 | backbone=dict( 7 | type='ResNetV1c', 8 | depth=50, 9 | num_stages=4, 10 | out_indices=(0, 1, 2, 3), 11 | dilations=(1, 1, 1, 1), 12 | strides=(1, 2, 2, 2), 13 | norm_cfg=norm_cfg, 14 | norm_eval=False, 15 | style='pytorch', 16 | contract_dilation=True), 17 | neck=dict( 18 | type='FPN', 19 | in_channels=[256, 512, 1024, 2048], 20 | out_channels=256, 21 | num_outs=4), 22 | decode_head=dict( 23 | type='FPNHead', 24 | in_channels=[256, 256, 256, 256], 25 | in_index=[0, 1, 2, 3], 26 | 
feature_strides=[4, 8, 16, 32], 27 | channels=128, 28 | dropout_ratio=0.1, 29 | num_classes=19, 30 | norm_cfg=norm_cfg, 31 | align_corners=False, 32 | loss_decode=dict( 33 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 34 | # model training and testing settings 35 | train_cfg=dict(), 36 | test_cfg=dict(mode='whole')) 37 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/schedules/schedule_160k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=160000) 8 | checkpoint_config = dict(by_epoch=False, interval=16000) 9 | evaluation = dict(interval=16000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/schedules/schedule_20k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=20000) 8 | checkpoint_config = dict(by_epoch=False, interval=2000) 9 | evaluation = dict(interval=2000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/schedules/schedule_40k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=40000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=4000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/schedules/schedule_80k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=80000) 8 | checkpoint_config = dict(by_epoch=False, interval=8000) 9 | evaluation = dict(interval=8000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /segmentation/configs/sem_fpn/fpn_repvit_m1_1_ade20k_40k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/fpn_r50.py', 3 | '../_base_/datasets/ade20k.py', 4 | '../_base_/default_runtime.py' 5 | ] 6 | # model settings 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | type='repvit_m1_1', 11 | style='pytorch', 12 | init_cfg=dict( 13 | type='Pretrained', 14 | checkpoint='pretrain/repvit_m1_1_distill_300e.pth', 15 | ), 16 | out_indices = [3,7,21,24] 17 | ), 18 | 
neck=dict(in_channels=[64, 128, 256, 512]), 19 | decode_head=dict(num_classes=150)) 20 | 21 | gpu_multiples = 2 # we use 8 gpu instead of 4 in mmsegmentation, so lr*2 and max_iters/2 22 | # optimizer 23 | optimizer = dict(type='AdamW', lr=0.0001 * gpu_multiples, weight_decay=0.0001) 24 | optimizer_config = dict() 25 | # learning policy 26 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-6, by_epoch=False) 27 | # runtime settings 28 | runner = dict(type='IterBasedRunner', max_iters=80000 // gpu_multiples) 29 | checkpoint_config = dict(by_epoch=False, interval=8000 // gpu_multiples) 30 | evaluation = dict(interval=8000 // gpu_multiples, metric='mIoU') 31 | -------------------------------------------------------------------------------- /segmentation/configs/sem_fpn/fpn_repvit_m1_5_ade20k_40k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/fpn_r50.py', 3 | '../_base_/datasets/ade20k.py', 4 | '../_base_/default_runtime.py' 5 | ] 6 | # model settings 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | type='repvit_m1_5', 11 | style='pytorch', 12 | init_cfg=dict( 13 | type='Pretrained', 14 | checkpoint='pretrain/repvit_m1_5_distill_300e.pth', 15 | ), 16 | out_indices=[5, 11, 37, 42] 17 | ), 18 | neck=dict(in_channels=[64, 128, 256, 512]), 19 | decode_head=dict(num_classes=150)) 20 | 21 | gpu_multiples = 2 # we use 8 gpu instead of 4 in mmsegmentation, so lr*2 and max_iters/2 22 | # optimizer 23 | optimizer = dict(type='AdamW', lr=0.0001 * gpu_multiples, weight_decay=0.0001) 24 | optimizer_config = dict() 25 | # learning policy 26 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-6, by_epoch=False) 27 | # runtime settings 28 | runner = dict(type='IterBasedRunner', max_iters=80000 // gpu_multiples) 29 | checkpoint_config = dict(by_epoch=False, interval=8000 // gpu_multiples) 30 | evaluation = dict(interval=8000 // gpu_multiples, metric='mIoU') 31 | -------------------------------------------------------------------------------- /segmentation/configs/sem_fpn/fpn_repvit_m2_3_ade20k_40k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/fpn_r50.py', 3 | '../_base_/datasets/ade20k.py', 4 | '../_base_/default_runtime.py' 5 | ] 6 | # model settings 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | type='repvit_m2_3', 11 | style='pytorch', 12 | init_cfg=dict( 13 | type='Pretrained', 14 | checkpoint='pretrain/repvit_m2_3_distill_450e.pth', 15 | ), 16 | out_indices=[7, 15, 51, 54] 17 | ), 18 | neck=dict(in_channels=[80, 160, 320, 640]), 19 | decode_head=dict(num_classes=150)) 20 | 21 | gpu_multiples = 2 # we use 8 gpu instead of 4 in mmsegmentation, so lr*2 and max_iters/2 22 | # optimizer 23 | optimizer = dict(type='AdamW', lr=0.0001 * gpu_multiples, weight_decay=0.0001) 24 | optimizer_config = dict() 25 | # learning policy 26 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-6, by_epoch=False) 27 | # runtime settings 28 | runner = dict(type='IterBasedRunner', max_iters=80000 // gpu_multiples) 29 | checkpoint_config = dict(by_epoch=False, interval=8000 // gpu_multiples) 30 | evaluation = dict(interval=8000 // gpu_multiples, metric='mIoU') 31 | -------------------------------------------------------------------------------- /segmentation/eval.sh: -------------------------------------------------------------------------------- 1 | PORT=12345 ./tools/dist_test.sh configs/sem_fpn/fpn_repvit_m1_1_ade20k_40k.py 
seg_pretrain/repvit_m1_1_ade20k.pth 8 --eval mIoU -------------------------------------------------------------------------------- /segmentation/tools/analyze_logs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | """Modified from https://github.com/open- 3 | mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py.""" 4 | import argparse 5 | import json 6 | from collections import defaultdict 7 | 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | 11 | 12 | def plot_curve(log_dicts, args): 13 | if args.backend is not None: 14 | plt.switch_backend(args.backend) 15 | sns.set_style(args.style) 16 | # if legend is None, use {filename}_{key} as legend 17 | legend = args.legend 18 | if legend is None: 19 | legend = [] 20 | for json_log in args.json_logs: 21 | for metric in args.keys: 22 | legend.append(f'{json_log}_{metric}') 23 | assert len(legend) == (len(args.json_logs) * len(args.keys)) 24 | metrics = args.keys 25 | 26 | num_metrics = len(metrics) 27 | for i, log_dict in enumerate(log_dicts): 28 | epochs = list(log_dict.keys()) 29 | for j, metric in enumerate(metrics): 30 | print(f'plot curve of {args.json_logs[i]}, metric is {metric}') 31 | plot_epochs = [] 32 | plot_iters = [] 33 | plot_values = [] 34 | # In some log files, iters number is not correct, `pre_iter` is 35 | # used to prevent generate wrong lines. 36 | pre_iter = -1 37 | for epoch in epochs: 38 | epoch_logs = log_dict[epoch] 39 | if metric not in epoch_logs.keys(): 40 | continue 41 | if metric in ['mIoU', 'mAcc', 'aAcc']: 42 | plot_epochs.append(epoch) 43 | plot_values.append(epoch_logs[metric][0]) 44 | else: 45 | for idx in range(len(epoch_logs[metric])): 46 | if pre_iter > epoch_logs['iter'][idx]: 47 | continue 48 | pre_iter = epoch_logs['iter'][idx] 49 | plot_iters.append(epoch_logs['iter'][idx]) 50 | plot_values.append(epoch_logs[metric][idx]) 51 | ax = plt.gca() 52 | label = legend[i * num_metrics + j] 53 | if metric in ['mIoU', 'mAcc', 'aAcc']: 54 | ax.set_xticks(plot_epochs) 55 | plt.xlabel('epoch') 56 | plt.plot(plot_epochs, plot_values, label=label, marker='o') 57 | else: 58 | plt.xlabel('iter') 59 | plt.plot(plot_iters, plot_values, label=label, linewidth=0.5) 60 | plt.legend() 61 | if args.title is not None: 62 | plt.title(args.title) 63 | if args.out is None: 64 | plt.show() 65 | else: 66 | print(f'save curve to: {args.out}') 67 | plt.savefig(args.out) 68 | plt.cla() 69 | 70 | 71 | def parse_args(): 72 | parser = argparse.ArgumentParser(description='Analyze Json Log') 73 | parser.add_argument( 74 | 'json_logs', 75 | type=str, 76 | nargs='+', 77 | help='path of train log in json format') 78 | parser.add_argument( 79 | '--keys', 80 | type=str, 81 | nargs='+', 82 | default=['mIoU'], 83 | help='the metric that you want to plot') 84 | parser.add_argument('--title', type=str, help='title of figure') 85 | parser.add_argument( 86 | '--legend', 87 | type=str, 88 | nargs='+', 89 | default=None, 90 | help='legend of each plot') 91 | parser.add_argument( 92 | '--backend', type=str, default=None, help='backend of plt') 93 | parser.add_argument( 94 | '--style', type=str, default='dark', help='style of plt') 95 | parser.add_argument('--out', type=str, default=None) 96 | args = parser.parse_args() 97 | return args 98 | 99 | 100 | def load_json_logs(json_logs): 101 | # load and convert json_logs to log_dict, key is epoch, value is a sub dict 102 | # keys of sub dict is different metrics 103 | # value of sub dict is a 
list of corresponding values of all iterations 104 | log_dicts = [dict() for _ in json_logs] 105 | for json_log, log_dict in zip(json_logs, log_dicts): 106 | with open(json_log, 'r') as log_file: 107 | for line in log_file: 108 | log = json.loads(line.strip()) 109 | # skip lines without `epoch` field 110 | if 'epoch' not in log: 111 | continue 112 | epoch = log.pop('epoch') 113 | if epoch not in log_dict: 114 | log_dict[epoch] = defaultdict(list) 115 | for k, v in log.items(): 116 | log_dict[epoch][k].append(v) 117 | return log_dicts 118 | 119 | 120 | def main(): 121 | args = parse_args() 122 | json_logs = args.json_logs 123 | for json_log in json_logs: 124 | assert json_log.endswith('.json') 125 | log_dicts = load_json_logs(json_logs) 126 | plot_curve(log_dicts, args) 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /segmentation/tools/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import time 4 | 5 | import torch 6 | from mmcv import Config 7 | from mmcv.parallel import MMDataParallel 8 | from mmcv.runner import load_checkpoint, wrap_fp16_model 9 | 10 | from mmseg.datasets import build_dataloader, build_dataset 11 | from mmseg.models import build_segmentor 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser(description='MMSeg benchmark a model') 16 | parser.add_argument('config', help='test config file path') 17 | parser.add_argument('checkpoint', help='checkpoint file') 18 | parser.add_argument( 19 | '--log-interval', type=int, default=50, help='interval of logging') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def main(): 25 | args = parse_args() 26 | 27 | cfg = Config.fromfile(args.config) 28 | # set cudnn_benchmark 29 | torch.backends.cudnn.benchmark = False 30 | cfg.model.pretrained = None 31 | cfg.data.test.test_mode = True 32 | 33 | # build the dataloader 34 | # TODO: support multiple images per gpu (only minor changes are needed) 35 | dataset = build_dataset(cfg.data.test) 36 | data_loader = build_dataloader( 37 | dataset, 38 | samples_per_gpu=1, 39 | workers_per_gpu=cfg.data.workers_per_gpu, 40 | dist=False, 41 | shuffle=False) 42 | 43 | # build the model and load checkpoint 44 | cfg.model.train_cfg = None 45 | model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) 46 | fp16_cfg = cfg.get('fp16', None) 47 | if fp16_cfg is not None: 48 | wrap_fp16_model(model) 49 | load_checkpoint(model, args.checkpoint, map_location='cpu') 50 | 51 | model = MMDataParallel(model, device_ids=[0]) 52 | 53 | model.eval() 54 | 55 | # the first several iterations may be very slow so skip them 56 | num_warmup = 5 57 | pure_inf_time = 0 58 | total_iters = 200 59 | 60 | # benchmark with 200 image and take the average 61 | for i, data in enumerate(data_loader): 62 | 63 | torch.cuda.synchronize() 64 | start_time = time.perf_counter() 65 | 66 | with torch.no_grad(): 67 | model(return_loss=False, rescale=True, **data) 68 | 69 | torch.cuda.synchronize() 70 | elapsed = time.perf_counter() - start_time 71 | 72 | if i >= num_warmup: 73 | pure_inf_time += elapsed 74 | if (i + 1) % args.log_interval == 0: 75 | fps = (i + 1 - num_warmup) / pure_inf_time 76 | print(f'Done image [{i + 1:<3}/ {total_iters}], ' 77 | f'fps: {fps:.2f} img / s') 78 | 79 | if (i + 1) == total_iters: 80 | fps = (i + 1 - num_warmup) / pure_inf_time 81 | print(f'Overall fps: 
{fps:.2f} img / s') 82 | break 83 | 84 | 85 | if __name__ == '__main__': 86 | main() 87 | -------------------------------------------------------------------------------- /segmentation/tools/browse_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | from pathlib import Path 5 | 6 | import mmcv 7 | import numpy as np 8 | from mmcv import Config 9 | 10 | from mmseg.datasets.builder import build_dataset 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description='Browse a dataset') 15 | parser.add_argument('config', help='train config file path') 16 | parser.add_argument( 17 | '--show-origin', 18 | default=False, 19 | action='store_true', 20 | help='if True, omit all augmentation in pipeline,' 21 | ' show origin image and seg map') 22 | parser.add_argument( 23 | '--skip-type', 24 | type=str, 25 | nargs='+', 26 | default=['DefaultFormatBundle', 'Normalize', 'Collect'], 27 | help='skip some useless pipeline,if `show-origin` is true, ' 28 | 'all pipeline except `Load` will be skipped') 29 | parser.add_argument( 30 | '--output-dir', 31 | default='./output', 32 | type=str, 33 | help='If there is no display interface, you can save it') 34 | parser.add_argument('--show', default=False, action='store_true') 35 | parser.add_argument( 36 | '--show-interval', 37 | type=int, 38 | default=999, 39 | help='the interval of show (ms)') 40 | parser.add_argument( 41 | '--opacity', 42 | type=float, 43 | default=0.5, 44 | help='the opacity of semantic map') 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def imshow_semantic(img, 50 | seg, 51 | class_names, 52 | palette=None, 53 | win_name='', 54 | show=False, 55 | wait_time=0, 56 | out_file=None, 57 | opacity=0.5): 58 | """Draw `result` over `img`. 59 | 60 | Args: 61 | img (str or Tensor): The image to be displayed. 62 | seg (Tensor): The semantic segmentation results to draw over 63 | `img`. 64 | class_names (list[str]): Names of each classes. 65 | palette (list[list[int]]] | np.ndarray | None): The palette of 66 | segmentation map. If None is given, random palette will be 67 | generated. Default: None 68 | win_name (str): The window name. 69 | wait_time (int): Value of waitKey param. 70 | Default: 0. 71 | show (bool): Whether to show the image. 72 | Default: False. 73 | out_file (str or None): The filename to write the image. 74 | Default: None. 75 | opacity(float): Opacity of painted segmentation map. 76 | Default 0.5. 77 | Must be in (0, 1] range. 
78 | Returns: 79 | img (Tensor): Only if not `show` or `out_file` 80 | """ 81 | img = mmcv.imread(img) 82 | img = img.copy() 83 | if palette is None: 84 | palette = np.random.randint(0, 255, size=(len(class_names), 3)) 85 | palette = np.array(palette) 86 | assert palette.shape[0] == len(class_names) 87 | assert palette.shape[1] == 3 88 | assert len(palette.shape) == 2 89 | assert 0 < opacity <= 1.0 90 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) 91 | for label, color in enumerate(palette): 92 | color_seg[seg == label, :] = color 93 | # convert to BGR 94 | color_seg = color_seg[..., ::-1] 95 | 96 | img = img * (1 - opacity) + color_seg * opacity 97 | img = img.astype(np.uint8) 98 | # if out_file specified, do not show image in window 99 | if out_file is not None: 100 | show = False 101 | 102 | if show: 103 | mmcv.imshow(img, win_name, wait_time) 104 | if out_file is not None: 105 | mmcv.imwrite(img, out_file) 106 | 107 | if not (show or out_file): 108 | warnings.warn('show==False and out_file is not specified, only ' 109 | 'result image will be returned') 110 | return img 111 | 112 | 113 | def _retrieve_data_cfg(_data_cfg, skip_type, show_origin): 114 | if show_origin is True: 115 | # only keep pipeline of Loading data and ann 116 | _data_cfg['pipeline'] = [ 117 | x for x in _data_cfg.pipeline if 'Load' in x['type'] 118 | ] 119 | else: 120 | _data_cfg['pipeline'] = [ 121 | x for x in _data_cfg.pipeline if x['type'] not in skip_type 122 | ] 123 | 124 | 125 | def retrieve_data_cfg(config_path, skip_type, show_origin=False): 126 | cfg = Config.fromfile(config_path) 127 | train_data_cfg = cfg.data.train 128 | if isinstance(train_data_cfg, list): 129 | for _data_cfg in train_data_cfg: 130 | if 'pipeline' in _data_cfg: 131 | _retrieve_data_cfg(_data_cfg, skip_type, show_origin) 132 | elif 'dataset' in _data_cfg: 133 | _retrieve_data_cfg(_data_cfg['dataset'], skip_type, 134 | show_origin) 135 | else: 136 | raise ValueError 137 | elif 'dataset' in train_data_cfg: 138 | _retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin) 139 | else: 140 | _retrieve_data_cfg(train_data_cfg, skip_type, show_origin) 141 | return cfg 142 | 143 | 144 | def main(): 145 | args = parse_args() 146 | cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin) 147 | dataset = build_dataset(cfg.data.train) 148 | progress_bar = mmcv.ProgressBar(len(dataset)) 149 | for item in dataset: 150 | filename = os.path.join(args.output_dir, 151 | Path(item['filename']).name 152 | ) if args.output_dir is not None else None 153 | imshow_semantic( 154 | item['img'], 155 | item['gt_semantic_seg'], 156 | dataset.CLASSES, 157 | dataset.PALETTE, 158 | show=args.show, 159 | wait_time=args.show_interval, 160 | out_file=filename, 161 | opacity=args.opacity, 162 | ) 163 | progress_bar.update() 164 | 165 | 166 | if __name__ == '__main__': 167 | main() 168 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/chase_db1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
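# Converts the CHASE_DB1 retinal vessel dataset into mmsegmentation's images/annotations layout,
# using the first 60 files for training and the rest for validation (see TRAINING_LEN below).
# Example invocation (path is a placeholder): python tools/convert_datasets/chase_db1.py /path/to/CHASEDB1.zip -o data/CHASE_DB1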
2 | import argparse 3 | import os 4 | import os.path as osp 5 | import tempfile 6 | import zipfile 7 | 8 | import mmcv 9 | 10 | CHASE_DB1_LEN = 28 * 3 11 | TRAINING_LEN = 60 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert CHASE_DB1 dataset to mmsegmentation format') 17 | parser.add_argument('dataset_path', help='path of CHASEDB1.zip') 18 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 19 | parser.add_argument('-o', '--out_dir', help='output path') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def main(): 25 | args = parse_args() 26 | dataset_path = args.dataset_path 27 | if args.out_dir is None: 28 | out_dir = osp.join('data', 'CHASE_DB1') 29 | else: 30 | out_dir = args.out_dir 31 | 32 | print('Making directories...') 33 | mmcv.mkdir_or_exist(out_dir) 34 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 35 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 36 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 37 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 38 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 39 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 40 | 41 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 42 | print('Extracting CHASEDB1.zip...') 43 | zip_file = zipfile.ZipFile(dataset_path) 44 | zip_file.extractall(tmp_dir) 45 | 46 | print('Generating training dataset...') 47 | 48 | assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \ 49 | 'len(os.listdir(tmp_dir)) != {}'.format(CHASE_DB1_LEN) 50 | 51 | for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: 52 | img = mmcv.imread(osp.join(tmp_dir, img_name)) 53 | if osp.splitext(img_name)[1] == '.jpg': 54 | mmcv.imwrite( 55 | img, 56 | osp.join(out_dir, 'images', 'training', 57 | osp.splitext(img_name)[0] + '.png')) 58 | else: 59 | # The annotation img should be divided by 128, because some of 60 | # the annotation imgs are not standard. We should set a 61 | # threshold to convert the nonstandard annotation imgs. The 62 | # value divided by 128 is equivalent to '1 if value >= 128 63 | # else 0' 64 | mmcv.imwrite( 65 | img[:, :, 0] // 128, 66 | osp.join(out_dir, 'annotations', 'training', 67 | osp.splitext(img_name)[0] + '.png')) 68 | 69 | for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: 70 | img = mmcv.imread(osp.join(tmp_dir, img_name)) 71 | if osp.splitext(img_name)[1] == '.jpg': 72 | mmcv.imwrite( 73 | img, 74 | osp.join(out_dir, 'images', 'validation', 75 | osp.splitext(img_name)[0] + '.png')) 76 | else: 77 | mmcv.imwrite( 78 | img[:, :, 0] // 128, 79 | osp.join(out_dir, 'annotations', 'validation', 80 | osp.splitext(img_name)[0] + '.png')) 81 | 82 | print('Removing the temporary files...') 83 | 84 | print('Done!') 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
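# Converts Cityscapes polygon annotations to *_labelTrainIds.png maps and writes train/val/test split files.
# Requires the cityscapesscripts package (json2labelImg is imported below).
# Example invocation (path is a placeholder): python tools/convert_datasets/cityscapes.py data/cityscapes --nproc 8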
2 | import argparse 3 | import os.path as osp 4 | 5 | import mmcv 6 | from cityscapesscripts.preparation.json2labelImg import json2labelImg 7 | 8 | 9 | def convert_json_to_label(json_file): 10 | label_file = json_file.replace('_polygons.json', '_labelTrainIds.png') 11 | json2labelImg(json_file, label_file, 'trainIds') 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert Cityscapes annotations to TrainIds') 17 | parser.add_argument('cityscapes_path', help='cityscapes data path') 18 | parser.add_argument('--gt-dir', default='gtFine', type=str) 19 | parser.add_argument('-o', '--out-dir', help='output path') 20 | parser.add_argument( 21 | '--nproc', default=1, type=int, help='number of process') 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | def main(): 27 | args = parse_args() 28 | cityscapes_path = args.cityscapes_path 29 | out_dir = args.out_dir if args.out_dir else cityscapes_path 30 | mmcv.mkdir_or_exist(out_dir) 31 | 32 | gt_dir = osp.join(cityscapes_path, args.gt_dir) 33 | 34 | poly_files = [] 35 | for poly in mmcv.scandir(gt_dir, '_polygons.json', recursive=True): 36 | poly_file = osp.join(gt_dir, poly) 37 | poly_files.append(poly_file) 38 | if args.nproc > 1: 39 | mmcv.track_parallel_progress(convert_json_to_label, poly_files, 40 | args.nproc) 41 | else: 42 | mmcv.track_progress(convert_json_to_label, poly_files) 43 | 44 | split_names = ['train', 'val', 'test'] 45 | 46 | for split in split_names: 47 | filenames = [] 48 | for poly in mmcv.scandir( 49 | osp.join(gt_dir, split), '_polygons.json', recursive=True): 50 | filenames.append(poly.replace('_gtFine_polygons.json', '')) 51 | with open(osp.join(out_dir, f'{split}.txt'), 'w') as f: 52 | f.writelines(f + '\n' for f in filenames) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/coco_stuff164k.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import shutil 4 | from functools import partial 5 | from glob import glob 6 | 7 | import mmcv 8 | import numpy as np 9 | from PIL import Image 10 | 11 | COCO_LEN = 123287 12 | 13 | clsID_to_trID = { 14 | 0: 0, 15 | 1: 1, 16 | 2: 2, 17 | 3: 3, 18 | 4: 4, 19 | 5: 5, 20 | 6: 6, 21 | 7: 7, 22 | 8: 8, 23 | 9: 9, 24 | 10: 10, 25 | 12: 11, 26 | 13: 12, 27 | 14: 13, 28 | 15: 14, 29 | 16: 15, 30 | 17: 16, 31 | 18: 17, 32 | 19: 18, 33 | 20: 19, 34 | 21: 20, 35 | 22: 21, 36 | 23: 22, 37 | 24: 23, 38 | 26: 24, 39 | 27: 25, 40 | 30: 26, 41 | 31: 27, 42 | 32: 28, 43 | 33: 29, 44 | 34: 30, 45 | 35: 31, 46 | 36: 32, 47 | 37: 33, 48 | 38: 34, 49 | 39: 35, 50 | 40: 36, 51 | 41: 37, 52 | 42: 38, 53 | 43: 39, 54 | 45: 40, 55 | 46: 41, 56 | 47: 42, 57 | 48: 43, 58 | 49: 44, 59 | 50: 45, 60 | 51: 46, 61 | 52: 47, 62 | 53: 48, 63 | 54: 49, 64 | 55: 50, 65 | 56: 51, 66 | 57: 52, 67 | 58: 53, 68 | 59: 54, 69 | 60: 55, 70 | 61: 56, 71 | 62: 57, 72 | 63: 58, 73 | 64: 59, 74 | 66: 60, 75 | 69: 61, 76 | 71: 62, 77 | 72: 63, 78 | 73: 64, 79 | 74: 65, 80 | 75: 66, 81 | 76: 67, 82 | 77: 68, 83 | 78: 69, 84 | 79: 70, 85 | 80: 71, 86 | 81: 72, 87 | 83: 73, 88 | 84: 74, 89 | 85: 75, 90 | 86: 76, 91 | 87: 77, 92 | 88: 78, 93 | 89: 79, 94 | 91: 80, 95 | 92: 81, 96 | 93: 82, 97 | 94: 83, 98 | 95: 84, 99 | 96: 85, 100 | 97: 86, 101 | 98: 87, 102 | 99: 88, 103 | 100: 89, 104 | 101: 90, 105 | 102: 91, 106 | 103: 92, 107 | 104: 93, 108 | 105: 94, 109 | 106: 95, 110 | 
107: 96, 111 | 108: 97, 112 | 109: 98, 113 | 110: 99, 114 | 111: 100, 115 | 112: 101, 116 | 113: 102, 117 | 114: 103, 118 | 115: 104, 119 | 116: 105, 120 | 117: 106, 121 | 118: 107, 122 | 119: 108, 123 | 120: 109, 124 | 121: 110, 125 | 122: 111, 126 | 123: 112, 127 | 124: 113, 128 | 125: 114, 129 | 126: 115, 130 | 127: 116, 131 | 128: 117, 132 | 129: 118, 133 | 130: 119, 134 | 131: 120, 135 | 132: 121, 136 | 133: 122, 137 | 134: 123, 138 | 135: 124, 139 | 136: 125, 140 | 137: 126, 141 | 138: 127, 142 | 139: 128, 143 | 140: 129, 144 | 141: 130, 145 | 142: 131, 146 | 143: 132, 147 | 144: 133, 148 | 145: 134, 149 | 146: 135, 150 | 147: 136, 151 | 148: 137, 152 | 149: 138, 153 | 150: 139, 154 | 151: 140, 155 | 152: 141, 156 | 153: 142, 157 | 154: 143, 158 | 155: 144, 159 | 156: 145, 160 | 157: 146, 161 | 158: 147, 162 | 159: 148, 163 | 160: 149, 164 | 161: 150, 165 | 162: 151, 166 | 163: 152, 167 | 164: 153, 168 | 165: 154, 169 | 166: 155, 170 | 167: 156, 171 | 168: 157, 172 | 169: 158, 173 | 170: 159, 174 | 171: 160, 175 | 172: 161, 176 | 173: 162, 177 | 174: 163, 178 | 175: 164, 179 | 176: 165, 180 | 177: 166, 181 | 178: 167, 182 | 179: 168, 183 | 180: 169, 184 | 181: 170, 185 | 255: 255 186 | } 187 | 188 | 189 | def convert_to_trainID(maskpath, out_mask_dir, is_train): 190 | mask = np.array(Image.open(maskpath)) 191 | mask_copy = mask.copy() 192 | for clsID, trID in clsID_to_trID.items(): 193 | mask_copy[mask == clsID] = trID 194 | seg_filename = osp.join( 195 | out_mask_dir, 'train2017', 196 | osp.basename(maskpath).split('.')[0] + 197 | '_labelTrainIds.png') if is_train else osp.join( 198 | out_mask_dir, 'val2017', 199 | osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png') 200 | Image.fromarray(mask_copy).save(seg_filename, 'PNG') 201 | 202 | 203 | def parse_args(): 204 | parser = argparse.ArgumentParser( 205 | description=\ 206 | 'Convert COCO Stuff 164k annotations to mmsegmentation format') # noqa 207 | parser.add_argument('coco_path', help='coco stuff path') 208 | parser.add_argument('-o', '--out_dir', help='output path') 209 | parser.add_argument( 210 | '--nproc', default=16, type=int, help='number of process') 211 | args = parser.parse_args() 212 | return args 213 | 214 | 215 | def main(): 216 | args = parse_args() 217 | coco_path = args.coco_path 218 | nproc = args.nproc 219 | 220 | out_dir = args.out_dir or coco_path 221 | out_img_dir = osp.join(out_dir, 'images') 222 | out_mask_dir = osp.join(out_dir, 'annotations') 223 | 224 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017')) 225 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017')) 226 | 227 | if out_dir != coco_path: 228 | shutil.copytree(osp.join(coco_path, 'images'), out_img_dir) 229 | 230 | train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png')) 231 | train_list = [file for file in train_list if '_labelTrainIds' not in file] 232 | test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png')) 233 | test_list = [file for file in test_list if '_labelTrainIds' not in file] 234 | assert (len(train_list) + 235 | len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( 236 | len(train_list), len(test_list)) 237 | 238 | if args.nproc > 1: 239 | mmcv.track_parallel_progress( 240 | partial( 241 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 242 | train_list, 243 | nproc=nproc) 244 | mmcv.track_parallel_progress( 245 | partial( 246 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 247 | test_list, 248 | nproc=nproc) 249 | else: 250 | 
mmcv.track_progress( 251 | partial( 252 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 253 | train_list) 254 | mmcv.track_progress( 255 | partial( 256 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 257 | test_list) 258 | 259 | print('Done!') 260 | 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/drive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import tempfile 6 | import zipfile 7 | 8 | import cv2 9 | import mmcv 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser( 14 | description='Convert DRIVE dataset to mmsegmentation format') 15 | parser.add_argument( 16 | 'training_path', help='the training part of DRIVE dataset') 17 | parser.add_argument( 18 | 'testing_path', help='the testing part of DRIVE dataset') 19 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 20 | parser.add_argument('-o', '--out_dir', help='output path') 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | def main(): 26 | args = parse_args() 27 | training_path = args.training_path 28 | testing_path = args.testing_path 29 | if args.out_dir is None: 30 | out_dir = osp.join('data', 'DRIVE') 31 | else: 32 | out_dir = args.out_dir 33 | 34 | print('Making directories...') 35 | mmcv.mkdir_or_exist(out_dir) 36 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 37 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 38 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 39 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 40 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 41 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 42 | 43 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 44 | print('Extracting training.zip...') 45 | zip_file = zipfile.ZipFile(training_path) 46 | zip_file.extractall(tmp_dir) 47 | 48 | print('Generating training dataset...') 49 | now_dir = osp.join(tmp_dir, 'training', 'images') 50 | for img_name in os.listdir(now_dir): 51 | img = mmcv.imread(osp.join(now_dir, img_name)) 52 | mmcv.imwrite( 53 | img, 54 | osp.join( 55 | out_dir, 'images', 'training', 56 | osp.splitext(img_name)[0].replace('_training', '') + 57 | '.png')) 58 | 59 | now_dir = osp.join(tmp_dir, 'training', '1st_manual') 60 | for img_name in os.listdir(now_dir): 61 | cap = cv2.VideoCapture(osp.join(now_dir, img_name)) 62 | ret, img = cap.read() 63 | mmcv.imwrite( 64 | img[:, :, 0] // 128, 65 | osp.join(out_dir, 'annotations', 'training', 66 | osp.splitext(img_name)[0] + '.png')) 67 | 68 | print('Extracting test.zip...') 69 | zip_file = zipfile.ZipFile(testing_path) 70 | zip_file.extractall(tmp_dir) 71 | 72 | print('Generating validation dataset...') 73 | now_dir = osp.join(tmp_dir, 'test', 'images') 74 | for img_name in os.listdir(now_dir): 75 | img = mmcv.imread(osp.join(now_dir, img_name)) 76 | mmcv.imwrite( 77 | img, 78 | osp.join( 79 | out_dir, 'images', 'validation', 80 | osp.splitext(img_name)[0].replace('_test', '') + '.png')) 81 | 82 | now_dir = osp.join(tmp_dir, 'test', '1st_manual') 83 | if osp.exists(now_dir): 84 | for img_name in os.listdir(now_dir): 85 | cap = cv2.VideoCapture(osp.join(now_dir, img_name)) 86 | ret, img = cap.read() 87 | # The annotation img should be divided by 
128, because some of 88 | # the annotation imgs are not standard. We should set a 89 | # threshold to convert the nonstandard annotation imgs. The 90 | # value divided by 128 is equivalent to '1 if value >= 128 91 | # else 0' 92 | mmcv.imwrite( 93 | img[:, :, 0] // 128, 94 | osp.join(out_dir, 'annotations', 'validation', 95 | osp.splitext(img_name)[0] + '.png')) 96 | 97 | now_dir = osp.join(tmp_dir, 'test', '2nd_manual') 98 | if osp.exists(now_dir): 99 | for img_name in os.listdir(now_dir): 100 | cap = cv2.VideoCapture(osp.join(now_dir, img_name)) 101 | ret, img = cap.read() 102 | mmcv.imwrite( 103 | img[:, :, 0] // 128, 104 | osp.join(out_dir, 'annotations', 'validation', 105 | osp.splitext(img_name)[0] + '.png')) 106 | 107 | print('Removing the temporary files...') 108 | 109 | print('Done!') 110 | 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/hrf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import tempfile 6 | import zipfile 7 | 8 | import mmcv 9 | 10 | HRF_LEN = 15 11 | TRAINING_LEN = 5 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert HRF dataset to mmsegmentation format') 17 | parser.add_argument('healthy_path', help='the path of healthy.zip') 18 | parser.add_argument( 19 | 'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip') 20 | parser.add_argument('glaucoma_path', help='the path of glaucoma.zip') 21 | parser.add_argument( 22 | 'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip') 23 | parser.add_argument( 24 | 'diabetic_retinopathy_path', 25 | help='the path of diabetic_retinopathy.zip') 26 | parser.add_argument( 27 | 'diabetic_retinopathy_manualsegm_path', 28 | help='the path of diabetic_retinopathy_manualsegm.zip') 29 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 30 | parser.add_argument('-o', '--out_dir', help='output path') 31 | args = parser.parse_args() 32 | return args 33 | 34 | 35 | def main(): 36 | args = parse_args() 37 | images_path = [ 38 | args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path 39 | ] 40 | annotations_path = [ 41 | args.healthy_manualsegm_path, args.glaucoma_manualsegm_path, 42 | args.diabetic_retinopathy_manualsegm_path 43 | ] 44 | if args.out_dir is None: 45 | out_dir = osp.join('data', 'HRF') 46 | else: 47 | out_dir = args.out_dir 48 | 49 | print('Making directories...') 50 | mmcv.mkdir_or_exist(out_dir) 51 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 52 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 53 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 54 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 55 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 56 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 57 | 58 | print('Generating images...') 59 | for now_path in images_path: 60 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 61 | zip_file = zipfile.ZipFile(now_path) 62 | zip_file.extractall(tmp_dir) 63 | 64 | assert len(os.listdir(tmp_dir)) == HRF_LEN, \ 65 | 'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN) 66 | 67 | for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: 68 | img = mmcv.imread(osp.join(tmp_dir, filename)) 69 | 
mmcv.imwrite( 70 | img, 71 | osp.join(out_dir, 'images', 'training', 72 | osp.splitext(filename)[0] + '.png')) 73 | for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: 74 | img = mmcv.imread(osp.join(tmp_dir, filename)) 75 | mmcv.imwrite( 76 | img, 77 | osp.join(out_dir, 'images', 'validation', 78 | osp.splitext(filename)[0] + '.png')) 79 | 80 | print('Generating annotations...') 81 | for now_path in annotations_path: 82 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 83 | zip_file = zipfile.ZipFile(now_path) 84 | zip_file.extractall(tmp_dir) 85 | 86 | assert len(os.listdir(tmp_dir)) == HRF_LEN, \ 87 | 'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN) 88 | 89 | for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: 90 | img = mmcv.imread(osp.join(tmp_dir, filename)) 91 | # The annotation img should be divided by 128, because some of 92 | # the annotation imgs are not standard. We should set a 93 | # threshold to convert the nonstandard annotation imgs. The 94 | # value divided by 128 is equivalent to '1 if value >= 128 95 | # else 0' 96 | mmcv.imwrite( 97 | img[:, :, 0] // 128, 98 | osp.join(out_dir, 'annotations', 'training', 99 | osp.splitext(filename)[0] + '.png')) 100 | for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: 101 | img = mmcv.imread(osp.join(tmp_dir, filename)) 102 | mmcv.imwrite( 103 | img[:, :, 0] // 128, 104 | osp.join(out_dir, 'annotations', 'validation', 105 | osp.splitext(filename)[0] + '.png')) 106 | 107 | print('Done!') 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/pascal_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
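# Usage sketch (paths are illustrative, not shipped with this repo): the script expects the
# VOC2010 devkit plus the Detail API annotation JSON (e.g. trainval_merged.json) and requires
# the `detail` package to be importable.
#   python tools/convert_datasets/pascal_context.py data/VOCdevkit data/VOCdevkit/VOC2010/trainval_merged.json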
2 | import argparse 3 | import os.path as osp 4 | from functools import partial 5 | 6 | import mmcv 7 | import numpy as np 8 | from detail import Detail 9 | from PIL import Image 10 | 11 | _mapping = np.sort( 12 | np.array([ 13 | 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284, 14 | 158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59, 15 | 440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355, 16 | 85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115 17 | ])) 18 | _key = np.array(range(len(_mapping))).astype('uint8') 19 | 20 | 21 | def generate_labels(img_id, detail, out_dir): 22 | 23 | def _class_to_index(mask, _mapping, _key): 24 | # assert the values 25 | values = np.unique(mask) 26 | for i in range(len(values)): 27 | assert (values[i] in _mapping) 28 | index = np.digitize(mask.ravel(), _mapping, right=True) 29 | return _key[index].reshape(mask.shape) 30 | 31 | mask = Image.fromarray( 32 | _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key)) 33 | filename = img_id['file_name'] 34 | mask.save(osp.join(out_dir, filename.replace('jpg', 'png'))) 35 | return osp.splitext(osp.basename(filename))[0] 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser( 40 | description='Convert PASCAL VOC annotations to mmsegmentation format') 41 | parser.add_argument('devkit_path', help='pascal voc devkit path') 42 | parser.add_argument('json_path', help='annoation json filepath') 43 | parser.add_argument('-o', '--out_dir', help='output path') 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def main(): 49 | args = parse_args() 50 | devkit_path = args.devkit_path 51 | if args.out_dir is None: 52 | out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext') 53 | else: 54 | out_dir = args.out_dir 55 | json_path = args.json_path 56 | mmcv.mkdir_or_exist(out_dir) 57 | img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages') 58 | 59 | train_detail = Detail(json_path, img_dir, 'train') 60 | train_ids = train_detail.getImgs() 61 | 62 | val_detail = Detail(json_path, img_dir, 'val') 63 | val_ids = val_detail.getImgs() 64 | 65 | mmcv.mkdir_or_exist( 66 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext')) 67 | 68 | train_list = mmcv.track_progress( 69 | partial(generate_labels, detail=train_detail, out_dir=out_dir), 70 | train_ids) 71 | with open( 72 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', 73 | 'train.txt'), 'w') as f: 74 | f.writelines(line + '\n' for line in sorted(train_list)) 75 | 76 | val_list = mmcv.track_progress( 77 | partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids) 78 | with open( 79 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', 80 | 'val.txt'), 'w') as f: 81 | f.writelines(line + '\n' for line in sorted(val_list)) 82 | 83 | print('Done!') 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/voc_aug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
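# Usage sketch (paths are illustrative): devkit_path points at VOCdevkit and aug_path at the
# extracted augmented (SBD) release containing dataset/cls/*.mat and dataset/{train,val}.txt.
#   python tools/convert_datasets/voc_aug.py data/VOCdevkit data/VOCdevkit/VOCaug --nproc 8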
2 | import argparse 3 | import os.path as osp 4 | from functools import partial 5 | 6 | import mmcv 7 | import numpy as np 8 | from PIL import Image 9 | from scipy.io import loadmat 10 | 11 | AUG_LEN = 10582 12 | 13 | 14 | def convert_mat(mat_file, in_dir, out_dir): 15 | data = loadmat(osp.join(in_dir, mat_file)) 16 | mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8) 17 | seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png')) 18 | Image.fromarray(mask).save(seg_filename, 'PNG') 19 | 20 | 21 | def generate_aug_list(merged_list, excluded_list): 22 | return list(set(merged_list) - set(excluded_list)) 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser( 27 | description='Convert PASCAL VOC annotations to mmsegmentation format') 28 | parser.add_argument('devkit_path', help='pascal voc devkit path') 29 | parser.add_argument('aug_path', help='pascal voc aug path') 30 | parser.add_argument('-o', '--out_dir', help='output path') 31 | parser.add_argument( 32 | '--nproc', default=1, type=int, help='number of process') 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def main(): 38 | args = parse_args() 39 | devkit_path = args.devkit_path 40 | aug_path = args.aug_path 41 | nproc = args.nproc 42 | if args.out_dir is None: 43 | out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug') 44 | else: 45 | out_dir = args.out_dir 46 | mmcv.mkdir_or_exist(out_dir) 47 | in_dir = osp.join(aug_path, 'dataset', 'cls') 48 | 49 | mmcv.track_parallel_progress( 50 | partial(convert_mat, in_dir=in_dir, out_dir=out_dir), 51 | list(mmcv.scandir(in_dir, suffix='.mat')), 52 | nproc=nproc) 53 | 54 | full_aug_list = [] 55 | with open(osp.join(aug_path, 'dataset', 'train.txt')) as f: 56 | full_aug_list += [line.strip() for line in f] 57 | with open(osp.join(aug_path, 'dataset', 'val.txt')) as f: 58 | full_aug_list += [line.strip() for line in f] 59 | 60 | with open( 61 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 62 | 'train.txt')) as f: 63 | ori_train_list = [line.strip() for line in f] 64 | with open( 65 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 66 | 'val.txt')) as f: 67 | val_list = [line.strip() for line in f] 68 | 69 | aug_train_list = generate_aug_list(ori_train_list + full_aug_list, 70 | val_list) 71 | assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format( 72 | AUG_LEN) 73 | 74 | with open( 75 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 76 | 'trainaug.txt'), 'w') as f: 77 | f.writelines(line + '\n' for line in aug_train_list) 78 | 79 | aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list) 80 | assert len(aug_list) == AUG_LEN - len( 81 | ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN - 82 | len(ori_train_list)) 83 | with open( 84 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'), 85 | 'w') as f: 86 | f.writelines(line + '\n' for line in aug_list) 87 | 88 | print('Done!') 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /segmentation/tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | NCCL_P2P_DISABLE=1 \ 9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 11 | 
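# Example invocation (the checkpoint path is illustrative); everything after GPUS is
# forwarded to tools/test.py, for example an evaluation metric:
#   bash tools/dist_test.sh configs/sem_fpn/fpn_repvit_m1_1_ade20k_40k.py ckpt/repvit_m1_1_ade20k.pth 8 --eval mIoU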
-------------------------------------------------------------------------------- /segmentation/tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | NNODES=${NNODES:-1} 6 | NODE_RANK=${NODE_RANK:-0} 7 | PORT=${PORT:-29500} 8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 9 | 10 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 11 | NCCL_P2P_DISABLE=1 \ 12 | python -m torch.distributed.launch \ 13 | --nnodes=$NNODES \ 14 | --node_rank=$NODE_RANK \ 15 | --master_addr=$MASTER_ADDR \ 16 | --nproc_per_node=$GPUS \ 17 | --master_port=$PORT \ 18 | $(dirname "$0")/train.py \ 19 | $CONFIG \ 20 | --launcher pytorch ${@:3} 21 | -------------------------------------------------------------------------------- /segmentation/tools/get_flops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | 4 | from mmcv import Config 5 | from mmcv.cnn import get_model_complexity_info 6 | 7 | from mmseg.models import build_segmentor 8 | import sys 9 | sys.path.append("..") 10 | import xformer 11 | import pvt 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description='Train a segmentor') 15 | parser.add_argument('config', help='train config file path') 16 | parser.add_argument( 17 | '--shape', 18 | type=int, 19 | nargs='+', 20 | default=[2048, 1024], 21 | help='input image size') 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | def main(): 27 | 28 | args = parse_args() 29 | 30 | if len(args.shape) == 1: 31 | input_shape = (3, args.shape[0], args.shape[0]) 32 | elif len(args.shape) == 2: 33 | input_shape = (3, ) + tuple(args.shape) 34 | else: 35 | raise ValueError('invalid input shape') 36 | 37 | cfg = Config.fromfile(args.config) 38 | cfg.model.pretrained = None 39 | model = build_segmentor( 40 | cfg.model, 41 | train_cfg=cfg.get('train_cfg'), 42 | test_cfg=cfg.get('test_cfg')).cuda() 43 | model.eval() 44 | 45 | if hasattr(model, 'forward_dummy'): 46 | model.forward = model.forward_dummy 47 | else: 48 | raise NotImplementedError( 49 | 'FLOPs counter is currently not currently supported with {}'. 50 | format(model.__class__.__name__)) 51 | 52 | flops, params = get_model_complexity_info(model, input_shape) 53 | split_line = '=' * 30 54 | print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( 55 | split_line, input_shape, flops, params)) 56 | print('!!!Please be cautious if you use the results in papers. ' 57 | 'You may need to check if all ops are supported and verify that the ' 58 | 'flops computation is correct.') 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /segmentation/tools/model_converters/mit2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
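# Usage sketch (file names are illustrative): converts an official SegFormer (MiT) checkpoint
# into the MMSegmentation key layout so it can be loaded as a pretrained backbone.
#   python tools/model_converters/mit2mmseg.py mit_b0.pth pretrained/mit_b0_mmseg.pth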
2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_mit(ckpt): 12 | new_ckpt = OrderedDict() 13 | # Process the concat between q linear weights and kv linear weights 14 | for k, v in ckpt.items(): 15 | if k.startswith('head'): 16 | continue 17 | # patch embedding conversion 18 | elif k.startswith('patch_embed'): 19 | stage_i = int(k.split('.')[0].replace('patch_embed', '')) 20 | new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0') 21 | new_v = v 22 | if 'proj.' in new_k: 23 | new_k = new_k.replace('proj.', 'projection.') 24 | # transformer encoder layer conversion 25 | elif k.startswith('block'): 26 | stage_i = int(k.split('.')[0].replace('block', '')) 27 | new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1') 28 | new_v = v 29 | if 'attn.q.' in new_k: 30 | sub_item_k = k.replace('q.', 'kv.') 31 | new_k = new_k.replace('q.', 'attn.in_proj_') 32 | new_v = torch.cat([v, ckpt[sub_item_k]], dim=0) 33 | elif 'attn.kv.' in new_k: 34 | continue 35 | elif 'attn.proj.' in new_k: 36 | new_k = new_k.replace('proj.', 'attn.out_proj.') 37 | elif 'attn.sr.' in new_k: 38 | new_k = new_k.replace('sr.', 'sr.') 39 | elif 'mlp.' in new_k: 40 | string = f'{new_k}-' 41 | new_k = new_k.replace('mlp.', 'ffn.layers.') 42 | if 'fc1.weight' in new_k or 'fc2.weight' in new_k: 43 | new_v = v.reshape((*v.shape, 1, 1)) 44 | new_k = new_k.replace('fc1.', '0.') 45 | new_k = new_k.replace('dwconv.dwconv.', '1.') 46 | new_k = new_k.replace('fc2.', '4.') 47 | string += f'{new_k} {v.shape}-{new_v.shape}' 48 | # norm layer conversion 49 | elif k.startswith('norm'): 50 | stage_i = int(k.split('.')[0].replace('norm', '')) 51 | new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2') 52 | new_v = v 53 | else: 54 | new_k = k 55 | new_v = v 56 | new_ckpt[new_k] = new_v 57 | return new_ckpt 58 | 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser( 62 | description='Convert keys in official pretrained segformer to ' 63 | 'MMSegmentation style.') 64 | parser.add_argument('src', help='src model path or url') 65 | # The dst path must be a full path of the new checkpoint. 66 | parser.add_argument('dst', help='save path') 67 | args = parser.parse_args() 68 | 69 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 70 | if 'state_dict' in checkpoint: 71 | state_dict = checkpoint['state_dict'] 72 | elif 'model' in checkpoint: 73 | state_dict = checkpoint['model'] 74 | else: 75 | state_dict = checkpoint 76 | weight = convert_mit(state_dict) 77 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 78 | torch.save(weight, args.dst) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /segmentation/tools/model_converters/swin2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
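# Usage sketch (file names are illustrative): same src/dst interface, for official Swin checkpoints.
#   python tools/model_converters/swin2mmseg.py swin_tiny_patch4_window7_224.pth pretrained/swin_tiny_mmseg.pth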
2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_swin(ckpt): 12 | new_ckpt = OrderedDict() 13 | 14 | def correct_unfold_reduction_order(x): 15 | out_channel, in_channel = x.shape 16 | x = x.reshape(out_channel, 4, in_channel // 4) 17 | x = x[:, [0, 2, 1, 3], :].transpose(1, 18 | 2).reshape(out_channel, in_channel) 19 | return x 20 | 21 | def correct_unfold_norm_order(x): 22 | in_channel = x.shape[0] 23 | x = x.reshape(4, in_channel // 4) 24 | x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) 25 | return x 26 | 27 | for k, v in ckpt.items(): 28 | if k.startswith('head'): 29 | continue 30 | elif k.startswith('layers'): 31 | new_v = v 32 | if 'attn.' in k: 33 | new_k = k.replace('attn.', 'attn.w_msa.') 34 | elif 'mlp.' in k: 35 | if 'mlp.fc1.' in k: 36 | new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') 37 | elif 'mlp.fc2.' in k: 38 | new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') 39 | else: 40 | new_k = k.replace('mlp.', 'ffn.') 41 | elif 'downsample' in k: 42 | new_k = k 43 | if 'reduction.' in k: 44 | new_v = correct_unfold_reduction_order(v) 45 | elif 'norm.' in k: 46 | new_v = correct_unfold_norm_order(v) 47 | else: 48 | new_k = k 49 | new_k = new_k.replace('layers', 'stages', 1) 50 | elif k.startswith('patch_embed'): 51 | new_v = v 52 | if 'proj' in k: 53 | new_k = k.replace('proj', 'projection') 54 | else: 55 | new_k = k 56 | else: 57 | new_v = v 58 | new_k = k 59 | 60 | new_ckpt[new_k] = new_v 61 | 62 | return new_ckpt 63 | 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser( 67 | description='Convert keys in official pretrained swin models to' 68 | 'MMSegmentation style.') 69 | parser.add_argument('src', help='src model path or url') 70 | # The dst path must be a full path of the new checkpoint. 71 | parser.add_argument('dst', help='save path') 72 | args = parser.parse_args() 73 | 74 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 75 | if 'state_dict' in checkpoint: 76 | state_dict = checkpoint['state_dict'] 77 | elif 'model' in checkpoint: 78 | state_dict = checkpoint['model'] 79 | else: 80 | state_dict = checkpoint 81 | weight = convert_swin(state_dict) 82 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 83 | torch.save(weight, args.dst) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /segmentation/tools/model_converters/vit2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
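# Usage sketch (file names are illustrative): handles timm and DeiT ViT checkpoints.
#   python tools/model_converters/vit2mmseg.py deit_small_patch16_224.pth pretrained/deit_small_mmseg.pth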
2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_vit(ckpt): 12 | 13 | new_ckpt = OrderedDict() 14 | 15 | for k, v in ckpt.items(): 16 | if k.startswith('head'): 17 | continue 18 | if k.startswith('norm'): 19 | new_k = k.replace('norm.', 'ln1.') 20 | elif k.startswith('patch_embed'): 21 | if 'proj' in k: 22 | new_k = k.replace('proj', 'projection') 23 | else: 24 | new_k = k 25 | elif k.startswith('blocks'): 26 | if 'norm' in k: 27 | new_k = k.replace('norm', 'ln') 28 | elif 'mlp.fc1' in k: 29 | new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') 30 | elif 'mlp.fc2' in k: 31 | new_k = k.replace('mlp.fc2', 'ffn.layers.1') 32 | elif 'attn.qkv' in k: 33 | new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_') 34 | elif 'attn.proj' in k: 35 | new_k = k.replace('attn.proj', 'attn.attn.out_proj') 36 | else: 37 | new_k = k 38 | new_k = new_k.replace('blocks.', 'layers.') 39 | else: 40 | new_k = k 41 | new_ckpt[new_k] = v 42 | 43 | return new_ckpt 44 | 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser( 48 | description='Convert keys in timm pretrained vit models to ' 49 | 'MMSegmentation style.') 50 | parser.add_argument('src', help='src model path or url') 51 | # The dst path must be a full path of the new checkpoint. 52 | parser.add_argument('dst', help='save path') 53 | args = parser.parse_args() 54 | 55 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 56 | if 'state_dict' in checkpoint: 57 | # timm checkpoint 58 | state_dict = checkpoint['state_dict'] 59 | elif 'model' in checkpoint: 60 | # deit checkpoint 61 | state_dict = checkpoint['model'] 62 | else: 63 | state_dict = checkpoint 64 | weight = convert_vit(state_dict) 65 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 66 | torch.save(weight, args.dst) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /segmentation/tools/print_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | 4 | from mmcv import Config, DictAction 5 | 6 | from mmseg.apis import init_segmentor 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='Print the whole config') 11 | parser.add_argument('config', help='config file path') 12 | parser.add_argument( 13 | '--graph', action='store_true', help='print the models graph') 14 | parser.add_argument( 15 | '--options', nargs='+', action=DictAction, help='arguments in dict') 16 | args = parser.parse_args() 17 | 18 | return args 19 | 20 | 21 | def main(): 22 | args = parse_args() 23 | 24 | cfg = Config.fromfile(args.config) 25 | if args.options is not None: 26 | cfg.merge_from_dict(args.options) 27 | print(f'Config:\n{cfg.pretty_text}') 28 | # dump config 29 | cfg.dump('example.py') 30 | # dump models graph 31 | if args.graph: 32 | model = init_segmentor(args.config, device='cpu') 33 | print(f'Model graph:\n{str(model)}') 34 | with open('example-graph.txt', 'w') as f: 35 | f.writelines(str(model)) 36 | 37 | 38 | if __name__ == '__main__': 39 | main() 40 | -------------------------------------------------------------------------------- /segmentation/tools/publish_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
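# Usage sketch (paths are illustrative): drops the optimizer state and appends the first
# 8 hex characters of the checkpoint's SHA-256 to the published file name.
#   python tools/publish_model.py work_dirs/fpn_repvit_m1_1_ade20k_40k/latest.pth repvit_m1_1_ade20k.pth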
2 | import argparse 3 | import subprocess 4 | 5 | import torch 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser( 10 | description='Process a checkpoint to be published') 11 | parser.add_argument('in_file', help='input checkpoint filename') 12 | parser.add_argument('out_file', help='output checkpoint filename') 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | def process_checkpoint(in_file, out_file): 18 | checkpoint = torch.load(in_file, map_location='cpu') 19 | # remove optimizer for smaller file size 20 | if 'optimizer' in checkpoint: 21 | del checkpoint['optimizer'] 22 | # if it is necessary to remove some sensitive data in checkpoint['meta'], 23 | # add the code here. 24 | torch.save(checkpoint, out_file) 25 | sha = subprocess.check_output(['sha256sum', out_file]).decode() 26 | final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8]) 27 | subprocess.Popen(['mv', out_file, final_file]) 28 | 29 | 30 | def main(): 31 | args = parse_args() 32 | process_checkpoint(args.in_file, args.out_file) 33 | 34 | 35 | if __name__ == '__main__': 36 | main() 37 | -------------------------------------------------------------------------------- /segmentation/tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-4} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-4} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /segmentation/tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | GPUS=${GPUS:-8} 9 | GPUS_PER_NODE=${GPUS_PER_NODE:-4} 10 | CPUS_PER_TASK=${CPUS_PER_TASK:-12} 11 | SRUN_ARGS=${SRUN_ARGS:-""} 12 | PY_ARGS=${@:4} 13 | 14 | export NCCL_P2P_DISABLE=1 15 | export MASTER_PORT=13579 16 | 17 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 18 | srun -p ${PARTITION} \ 19 | --job-name=${JOB_NAME} \ 20 | --gres=gpu:${GPUS_PER_NODE} \ 21 | --ntasks=${GPUS} \ 22 | --ntasks-per-node=${GPUS_PER_NODE} \ 23 | --cpus-per-task=${CPUS_PER_TASK} \ 24 | --kill-on-bad-exit=1 \ 25 | --mem 250G \ 26 | ${SRUN_ARGS} \ 27 | python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} 28 | -------------------------------------------------------------------------------- /segmentation/tools/torchserve/mmseg2torchserve.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
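# Usage sketch (model name and store folder are illustrative): packages a config plus checkpoint
# into a TorchServe .mar archive; `torch-model-archiver` must be installed.
#   python tools/torchserve/mmseg2torchserve.py ${CONFIG} ${CHECKPOINT} --output-folder model_store --model-name repvit_fpn_ade20k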
2 | from argparse import ArgumentParser, Namespace 3 | from pathlib import Path 4 | from tempfile import TemporaryDirectory 5 | 6 | import mmcv 7 | 8 | try: 9 | from model_archiver.model_packaging import package_model 10 | from model_archiver.model_packaging_utils import ModelExportUtils 11 | except ImportError: 12 | package_model = None 13 | 14 | 15 | def mmseg2torchserve( 16 | config_file: str, 17 | checkpoint_file: str, 18 | output_folder: str, 19 | model_name: str, 20 | model_version: str = '1.0', 21 | force: bool = False, 22 | ): 23 | """Converts mmsegmentation model (config + checkpoint) to TorchServe 24 | `.mar`. 25 | 26 | Args: 27 | config_file: 28 | In MMSegmentation config format. 29 | The contents vary for each task repository. 30 | checkpoint_file: 31 | In MMSegmentation checkpoint format. 32 | The contents vary for each task repository. 33 | output_folder: 34 | Folder where `{model_name}.mar` will be created. 35 | The file created will be in TorchServe archive format. 36 | model_name: 37 | If not None, used for naming the `{model_name}.mar` file 38 | that will be created under `output_folder`. 39 | If None, `{Path(checkpoint_file).stem}` will be used. 40 | model_version: 41 | Model's version. 42 | force: 43 | If True, if there is an existing `{model_name}.mar` 44 | file under `output_folder` it will be overwritten. 45 | """ 46 | mmcv.mkdir_or_exist(output_folder) 47 | 48 | config = mmcv.Config.fromfile(config_file) 49 | 50 | with TemporaryDirectory() as tmpdir: 51 | config.dump(f'{tmpdir}/config.py') 52 | 53 | args = Namespace( 54 | **{ 55 | 'model_file': f'{tmpdir}/config.py', 56 | 'serialized_file': checkpoint_file, 57 | 'handler': f'{Path(__file__).parent}/mmseg_handler.py', 58 | 'model_name': model_name or Path(checkpoint_file).stem, 59 | 'version': model_version, 60 | 'export_path': output_folder, 61 | 'force': force, 62 | 'requirements_file': None, 63 | 'extra_files': None, 64 | 'runtime': 'python', 65 | 'archive_format': 'default' 66 | }) 67 | manifest = ModelExportUtils.generate_manifest_json(args) 68 | package_model(args, manifest) 69 | 70 | 71 | def parse_args(): 72 | parser = ArgumentParser( 73 | description='Convert mmseg models to TorchServe `.mar` format.') 74 | parser.add_argument('config', type=str, help='config file path') 75 | parser.add_argument('checkpoint', type=str, help='checkpoint file path') 76 | parser.add_argument( 77 | '--output-folder', 78 | type=str, 79 | required=True, 80 | help='Folder where `{model_name}.mar` will be created.') 81 | parser.add_argument( 82 | '--model-name', 83 | type=str, 84 | default=None, 85 | help='If not None, used for naming the `{model_name}.mar`' 86 | 'file that will be created under `output_folder`.' 87 | 'If None, `{Path(checkpoint_file).stem}` will be used.') 88 | parser.add_argument( 89 | '--model-version', 90 | type=str, 91 | default='1.0', 92 | help='Number used for versioning.') 93 | parser.add_argument( 94 | '-f', 95 | '--force', 96 | action='store_true', 97 | help='overwrite the existing `{model_name}.mar`') 98 | args = parser.parse_args() 99 | 100 | return args 101 | 102 | 103 | if __name__ == '__main__': 104 | args = parse_args() 105 | 106 | if package_model is None: 107 | raise ImportError('`torch-model-archiver` is required.' 
108 | 'Try: pip install torch-model-archiver') 109 | 110 | mmseg2torchserve(args.config, args.checkpoint, args.output_folder, 111 | args.model_name, args.model_version, args.force) 112 | -------------------------------------------------------------------------------- /segmentation/tools/torchserve/mmseg_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import base64 3 | import os 4 | 5 | import cv2 6 | import mmcv 7 | import torch 8 | from mmcv.cnn.utils.sync_bn import revert_sync_batchnorm 9 | from ts.torch_handler.base_handler import BaseHandler 10 | 11 | from mmseg.apis import inference_segmentor, init_segmentor 12 | 13 | 14 | class MMsegHandler(BaseHandler): 15 | 16 | def initialize(self, context): 17 | properties = context.system_properties 18 | self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | self.device = torch.device(self.map_location + ':' + 20 | str(properties.get('gpu_id')) if torch.cuda. 21 | is_available() else self.map_location) 22 | self.manifest = context.manifest 23 | 24 | model_dir = properties.get('model_dir') 25 | serialized_file = self.manifest['model']['serializedFile'] 26 | checkpoint = os.path.join(model_dir, serialized_file) 27 | self.config_file = os.path.join(model_dir, 'config.py') 28 | 29 | self.model = init_segmentor(self.config_file, checkpoint, self.device) 30 | self.model = revert_sync_batchnorm(self.model) 31 | self.initialized = True 32 | 33 | def preprocess(self, data): 34 | images = [] 35 | 36 | for row in data: 37 | image = row.get('data') or row.get('body') 38 | if isinstance(image, str): 39 | image = base64.b64decode(image) 40 | image = mmcv.imfrombytes(image) 41 | images.append(image) 42 | 43 | return images 44 | 45 | def inference(self, data, *args, **kwargs): 46 | results = [inference_segmentor(self.model, img) for img in data] 47 | return results 48 | 49 | def postprocess(self, data): 50 | output = [] 51 | 52 | for image_result in data: 53 | _, buffer = cv2.imencode('.png', image_result[0].astype('uint8')) 54 | content = buffer.tobytes() 55 | output.append(content) 56 | return output 57 | -------------------------------------------------------------------------------- /segmentation/tools/torchserve/test_torchserve.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from io import BytesIO 3 | 4 | import matplotlib.pyplot as plt 5 | import mmcv 6 | import requests 7 | 8 | from mmseg.apis import inference_segmentor, init_segmentor 9 | 10 | 11 | def parse_args(): 12 | parser = ArgumentParser( 13 | description='Compare result of torchserve and pytorch,' 14 | 'and visualize them.') 15 | parser.add_argument('img', help='Image file') 16 | parser.add_argument('config', help='Config file') 17 | parser.add_argument('checkpoint', help='Checkpoint file') 18 | parser.add_argument('model_name', help='The model name in the server') 19 | parser.add_argument( 20 | '--inference-addr', 21 | default='127.0.0.1:8080', 22 | help='Address and port of the inference server') 23 | parser.add_argument( 24 | '--result-image', 25 | type=str, 26 | default=None, 27 | help='save server output in result-image') 28 | parser.add_argument( 29 | '--device', default='cuda:0', help='Device used for inference') 30 | 31 | args = parser.parse_args() 32 | return args 33 | 34 | 35 | def main(args): 36 | url = 'http://' + args.inference_addr + '/predictions/' + args.model_name 37 | with 
open(args.img, 'rb') as image: 38 | tmp_res = requests.post(url, image) 39 | content = tmp_res.content 40 | if args.result_image: 41 | with open(args.result_image, 'wb') as out_image: 42 | out_image.write(content) 43 | plt.imshow(mmcv.imread(args.result_image, 'grayscale')) 44 | plt.show() 45 | else: 46 | plt.imshow(plt.imread(BytesIO(content))) 47 | plt.show() 48 | model = init_segmentor(args.config, args.checkpoint, args.device) 49 | image = mmcv.imread(args.img) 50 | result = inference_segmentor(model, image) 51 | plt.imshow(result[0]) 52 | plt.show() 53 | 54 | 55 | if __name__ == '__main__': 56 | args = parse_args() 57 | main(args) 58 | -------------------------------------------------------------------------------- /segmentation/train.sh: -------------------------------------------------------------------------------- 1 | ./tools/dist_train.sh configs/sem_fpn/fpn_repvit_m1_1_ade20k_40k.py 8 2 | -------------------------------------------------------------------------------- /speed_gpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from timm import create_model 4 | import model 5 | import utils 6 | torch.autograd.set_grad_enabled(False) 7 | 8 | T0 = 5 9 | T1 = 10 10 | 11 | def throughput(name, model, device, batch_size, resolution=224): 12 | inputs = torch.randn(batch_size, 3, resolution, resolution, device=device) 13 | torch.cuda.empty_cache() 14 | torch.cuda.synchronize() 15 | start = time.time() 16 | while time.time() - start < T0: 17 | model(inputs) 18 | timing = [] 19 | torch.cuda.synchronize() 20 | while sum(timing) < T1: 21 | start = time.time() 22 | model(inputs) 23 | torch.cuda.synchronize() 24 | timing.append(time.time() - start) 25 | timing = torch.as_tensor(timing, dtype=torch.float32) 26 | print(name, device, batch_size / timing.mean().item(), 27 | 'images/s @ batch size', batch_size) 28 | 29 | device = "cuda:0" 30 | 31 | from argparse import ArgumentParser 32 | 33 | parser = ArgumentParser() 34 | 35 | parser.add_argument('--model', default='repvit_m0_9', type=str) 36 | parser.add_argument('--resolution', default=224, type=int) 37 | parser.add_argument('--batch-size', default=2048, type=int) 38 | 39 | if __name__ == "__main__": 40 | args = parser.parse_args() 41 | model_name = args.model 42 | batch_size = args.batch_size 43 | resolution = args.resolution 44 | torch.cuda.empty_cache() 45 | inputs = torch.randn(batch_size, 3, resolution, 46 | resolution, device=device) 47 | model = create_model(model_name, num_classes=1000) 48 | utils.replace_batchnorm(model) 49 | model.to(device) 50 | model.eval() 51 | throughput(model_name, model, device, batch_size, resolution=resolution) 52 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | NCCL_P2P_DISABLE=1 python -m torch.distributed.launch --nproc_per_node=8 --master_port 12346 --use_env main.py --model repvit_m0_9 --data-path ~/imagenet --dist-eval 2 | --------------------------------------------------------------------------------
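A quick way to exercise the GPU throughput benchmark above; the flags mirror the argparse options in speed_gpu.py, and the batch size should be adjusted to fit GPU memory:

python speed_gpu.py --model repvit_m0_9 --resolution 224 --batch-size 1024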