├── LICENSE ├── README.md ├── attack ├── Attack.py ├── base │ ├── BaseAttack.py │ ├── DI.py │ ├── IFGSM.py │ ├── MI.py │ ├── RRB.py │ └── __init__.py ├── comparing │ ├── TransferAttack.py │ └── __init__.py ├── ours │ ├── OSFD.py │ └── __init__.py └── utils │ ├── __init__.py │ ├── buffer.py │ ├── mmdet.py │ ├── pipelines.py │ └── registry.py ├── config ├── attack_faster_rcnn.yaml ├── attack_swin.yaml ├── attack_vfnet.yaml ├── attack_yolov3.yaml ├── base.yaml └── default.py ├── data └── image_index.txt ├── run_perturb.py ├── tools ├── datasets │ ├── split_dataset.py │ └── voc12_to_coco.py ├── project │ ├── base_config.py │ ├── config.py │ ├── logger.py │ └── recorder.py └── utils.py └── ummdet ├── checkpoints ├── eval_cfg │ ├── detr_r50_8x2_150e_coco.py │ ├── faster_rcnn_r101_caffe_fpn_mstrain_3x_coco.py │ ├── fcos_r50_caffe_fpn_gn-head_mstrain_640-800_2x_coco.py │ ├── mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco.py │ ├── vfnet_r50_fpn_mstrain_2x_coco.py │ ├── yolof_r50_c5_8x8_1x_coco.py │ ├── yolov3_d53_mstrain-608_273e_coco.py │ └── yolox_l_8x8_300e_coco.py └── train_cfg │ ├── faster_rcnn_r101_caffe_fpn_mstrain_3x_coco.py │ ├── mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco.py │ ├── vfnet_r50_fpn_mstrain_2x_coco.py │ └── yolov3_d53_mstrain-608_273e_coco.py ├── components ├── __init__.py └── coco.py └── detectors ├── __init__.py ├── faster_rcnn.py ├── mask_rcnn.py ├── model_hook.py ├── vfnet.py ├── vfnet_head.py └── yolo.py /README.md: -------------------------------------------------------------------------------- 1 | # Transferable Adversarial Attacks for Object Detection using Object-Aware Significant Feature Distortion 2 | 3 | --- 4 | This is the official implementation of the **OSFD** adversarial attack method. 5 | 6 | ## Installation 7 | 8 | --- 9 | ### 1. Prepare Environment 10 | ```shell 11 | # Create a conda environment 12 | conda create -n OSFD python=3.10 -y 13 | 14 | # Install an appropriate version of PyTorch 15 | conda activate OSFD 16 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch 17 | 18 | # Install OpenMMLab's openmim 19 | pip install -U openmim 20 | 21 | # Install mmcv and mmdet using openmim 22 | mim install "mmcv-full==1.7.1" 23 | mim install "mmdet==2.28.2" 24 | 25 | # Install other libs 26 | pip install imgaug 27 | pip install tensorboard 28 | 29 | ``` 30 | 31 | ### 2. Prepare Dataset 32 | ```shell 33 | # Download VOC dataset: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html 34 | cd OSFD/data 35 | wget --no-check-certificate http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 36 | tar -xvf VOCtrainval_11-May-2012.tar 37 | 38 | # Construct a subset for attack 39 | (->OSFD) cd .. 40 | conda activate OSFD 41 | python tools/datasets/split_dataset.py --num 2000 --voc_type "VOC2012" --paper True 42 | 43 | # Convert VOC dataset to COCO format 44 | python tools/datasets/voc12_to_coco.py --num 2000 --img_size 800 --voc_type "VOC2012" 45 | ``` 46 | 47 | ### 3. Prepare Models 48 | ```shell 49 | cd OSFD/ummdet/checkpoints/models/ 50 | conda activate OSFD 51 | # Download model checkpoints and configs 52 | mim download mmdet --config yolov3_d53_mstrain-608_273e_coco --dest . 53 | mim download mmdet --config yolof_r50_c5_8x8_1x_coco --dest . 54 | mim download mmdet --config yolox_l_8x8_300e_coco --dest . 55 | mim download mmdet --config vfnet_r50_fpn_mstrain_2x_coco --dest . 56 | mim download mmdet --config fcos_r50_caffe_fpn_gn-head_mstrain_640-800_2x_coco --dest . 
57 | mim download mmdet --config faster_rcnn_r101_caffe_fpn_mstrain_3x_coco --dest . 58 | mim download mmdet --config mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco --dest . 59 | mim download mmdet --config detr_r50_8x2_150e_coco --dest . 60 | 61 | # Modify model configs for attack 62 | # ummdet 63 | # └─checkpoints 64 | # ├─train_cfg: The white-box model configs to be attacked. 65 | # ├─eval_cfg: White-box and black-box configs for all models. 66 | # └─models: Model files with the '.pth' suffix. 67 | 68 | ``` 69 | 70 | 71 | 72 | ## Attack 73 | 74 | --- 75 | All attack configuration lives in the YAML files; refer to base.yaml to customize your own. 76 | 77 | **model_name** can be: `yolov3`, `vfnet`, `faster_rcnn`, `swin`. 78 | 79 | Before attacking, set the absolute OSFD path (`project_path`) in each attack_{model_name}.yaml file. 80 | ```shell 81 | # Main attack script 82 | (->OSFD) cd ../../.. 83 | python run_perturb.py config/attack_{model_name}.yaml 84 | ``` -------------------------------------------------------------------------------- /attack/Attack.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mmcv 3 | import logging 4 | import functools 5 | 6 | from mmcv.parallel import scatter 7 | from attack.utils.buffer import Buffer 8 | from attack.utils.mmdet import init_detector, init_cfg, init_dataloader, single_gpu_test, load_noise_for_eval 9 | from attack.utils.pipelines import PreProcessData, CalculateLoss, UpdateNoise, PostProcessData 10 | from attack.utils.registry import BASEATK, TSFATK 11 | # keep this import 12 | import attack.base, attack.comparing, attack.ours, ummdet.detectors, ummdet.components 13 | 14 | 15 | class Attack: 16 | def __init__(self, yml_cfg) -> None: 17 | self.yml_cfg = yml_cfg 18 | self.model_cfg = init_cfg(yml_cfg.source_model_cfg[0], yml_cfg.dataset_cfg) 19 | self.model = init_detector(self.model_cfg, 20 | yml_cfg.source_model_cfg[1]) 21 | self.dataloader = init_dataloader(self.model_cfg, 22 | samples_per_gpu=yml_cfg.dataloader_cfg.get("batch_size", 1), 23 | workers_per_gpu=yml_cfg.dataloader_cfg.get("cpu_num", 0), 24 | persistent_workers=yml_cfg.dataloader_cfg.get("persistent_workers", False)) 25 | # Init attack methods 26 | self.base_attack = self.init_base_attack() 27 | self.transfer_attack = self.init_transfer_attack() 28 | 29 | # Init attack pipelines 30 | self.buffer = Buffer(os.environ.get("tmp_dir")) 31 | self.buffer.update_buffer_types(yml_cfg.buffer) 32 | self.pipeline = self.init_attack_pipeline() 33 | 34 | # Init Hooks 35 | self.attack_step_hooks = [] 36 | 37 | def init_base_attack(self): 38 | base_attack = dict() 39 | for method in self.yml_cfg.attack_base: 40 | base_attack[method] = BASEATK.build(self.yml_cfg.default_cfg.base_attack[method]) 41 | return base_attack 42 | 43 | def init_transfer_attack(self): 44 | transfer_attack = dict() 45 | method = self.yml_cfg.attack_transfer 46 | transfer_attack[method] = TSFATK.build(self.yml_cfg.default_cfg.transfer_attack[method]) 47 | return transfer_attack 48 | 49 | def init_attack_pipeline(self): 50 | pre_process_data_pipeline = PreProcessData(self.base_attack, self.transfer_attack) 51 | calculate_loss_pipeline = CalculateLoss(self.base_attack, self.transfer_attack) 52 | update_noise_pipeline = UpdateNoise(self.base_attack) 53 | post_process_data_pipeline = PostProcessData(self.base_attack) 54 | return [pre_process_data_pipeline, calculate_loss_pipeline, update_noise_pipeline, post_process_data_pipeline] 55 | 56 | def attack_step(self, results):
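        """Run one attack iteration: every pipeline stage except the final
        PostProcessData stage (pre-process -> loss -> noise update), followed by
        any registered per-step hooks. PostProcessData itself is applied once
        per batch at the end of attack_epoch."""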
57 | for p in self.pipeline[:-1]: 58 | results = p(results) 59 | for hook in self.attack_step_hooks: 60 | results = hook(results) 61 | return results 62 | 63 | def attack_epoch(self): 64 | mmcv.print_log("Generating adversarial examples.", logger="verbose_logger") 65 | losses = [0. for _ in range(self.yml_cfg.attack_cfg["steps"])] 66 | for idx, data in enumerate(mmcv.track_iter_progress(self.dataloader)): 67 | if "cuda" in self.yml_cfg.device: 68 | data = scatter(data, [0])[0] 69 | results = dict(idx=str(idx), data=data, buffer=self.buffer, model=self.model, 70 | epsilon=self.yml_cfg.attack_cfg["epsilon"]) 71 | for step in range(self.yml_cfg.attack_cfg["steps"]): 72 | results["step"] = step 73 | results = self.attack_step(results) 74 | # log the loss 75 | losses[step] += results.pop("loss_combined").item() 76 | self.pipeline[-1](results) 77 | return losses 78 | 79 | def eval(self, mode="clean"): 80 | metric_dict = dict() 81 | for model_name, (model_cfg_fp, model_checkpoint_fp) in self.yml_cfg.models_zoo.items(): 82 | config = init_cfg(model_cfg_fp, self.yml_cfg.dataset_cfg) 83 | model = init_detector(config, model_checkpoint_fp) 84 | dataloader = init_dataloader(config, 85 | samples_per_gpu=self.yml_cfg.dataloader_cfg.get("eval_batch_size", 1), 86 | workers_per_gpu=self.yml_cfg.dataloader_cfg.get("eval_cpu_num", 0), 87 | persistent_workers=self.yml_cfg.dataloader_cfg.get("persistent_workers", False)) 88 | if not "clean" == mode: 89 | func_load_noise = functools.partial(load_noise_for_eval, buffer=self.buffer, 90 | buffer_batch_size=self.yml_cfg.dataloader_cfg["batch_size"]) 91 | else: 92 | func_load_noise = None 93 | results = single_gpu_test(model, dataloader, func_load_noise, mode) 94 | 95 | eval_kwargs = config.get('evaluation', {}).copy() 96 | # hard-code way to remove EvalHook args 97 | for key in ['interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 'rule', 'dynamic_intervals']: 98 | eval_kwargs.pop(key, None) 99 | eval_kwargs.update(dict(**self.yml_cfg.eval_cfg)) 100 | metric = dataloader.dataset.evaluate(results, logger=mmcv.get_logger("eval_logger"), 101 | **eval_kwargs) 102 | metric_dict[model_name] = metric 103 | mmcv.print_log(f"The metric of {model_name} is: {metric}", 104 | logger=mmcv.get_logger("eval_logger"), level=logging.INFO) 105 | return metric_dict -------------------------------------------------------------------------------- /attack/base/BaseAttack.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class BaseAttack(object): 5 | 6 | def __init__(self) -> None: 7 | super().__init__() 8 | self.device = os.environ["device"] 9 | 10 | def preprocess_data(self, results): 11 | return results 12 | 13 | def combine_losses(self, results): 14 | return results 15 | 16 | def process_gradients(self, results): 17 | return results 18 | 19 | def update_noise(self, results): 20 | return results 21 | 22 | def postprocess_data(self, results): 23 | return results -------------------------------------------------------------------------------- /attack/base/DI.py: -------------------------------------------------------------------------------- 1 | from .IFGSM import IFGSM 2 | import torch 3 | import random 4 | import torch.nn.functional as F 5 | from attack.utils.registry import BASEATK 6 | 7 | 8 | @BASEATK.register_module() 9 | class DI(IFGSM): 10 | def __init__(self, alpha=1.0, prob=0.7, scale=1.1) -> None: 11 | super().__init__(alpha) 12 | self.prob = prob 13 | self.scale = scale 14 | 15 | def preprocess_data(self, 
results): 16 | super().preprocess_data(results) 17 | data_adv_imgs = results.get("data_adv_imgs") 18 | if random.random() < self.prob: 19 | data_adv_imgs = self.input_diversity(data_adv_imgs, scale=self.scale) 20 | results["data_adv_imgs"] = data_adv_imgs 21 | return results 22 | 23 | def combine_losses(self, results): 24 | results = super().combine_losses(results) 25 | return results 26 | 27 | def process_gradients(self, results): 28 | results = super().process_gradients(results) 29 | return results 30 | 31 | def update_noise(self, results): 32 | results = super().update_noise(results) 33 | return results 34 | 35 | def postprocess_data(self, results): 36 | results = super().postprocess_data(results) 37 | return results 38 | 39 | @staticmethod 40 | def input_diversity(imgs, scale=1.1): 41 | padded_list = [] 42 | for idx in range(imgs.shape[0]): 43 | input_tensor = imgs[idx].unsqueeze(0) 44 | ori_size = input_tensor.shape[2] 45 | new_size = random.randint(ori_size, int(scale * ori_size)) 46 | rescaled = F.interpolate(input_tensor, size=(new_size, new_size), mode='bilinear', align_corners=True) 47 | rem = int(scale * ori_size) - new_size 48 | pad_left = random.randint(0, rem) 49 | pad_top = random.randint(0, rem) 50 | padded = F.pad(rescaled, (pad_left, rem - pad_left, pad_top, rem - pad_top), mode='constant', value=0.) 51 | padded = F.interpolate(padded, size=(ori_size, ori_size), mode='bilinear', align_corners=True) 52 | padded_list.append(padded) 53 | return torch.cat(padded_list, dim=0) -------------------------------------------------------------------------------- /attack/base/IFGSM.py: -------------------------------------------------------------------------------- 1 | from .BaseAttack import BaseAttack 2 | from attack.utils.registry import BASEATK 3 | import torch 4 | 5 | 6 | @BASEATK.register_module() 7 | class IFGSM(BaseAttack): 8 | def __init__(self, alpha=1.0) -> None: 9 | super().__init__() 10 | self.alpha = alpha 11 | 12 | def preprocess_data(self, results): 13 | if results.get("data_adv_imgs") is None: 14 | data_adv_imgs = results["data_clean_imgs"] + results["noise"] 15 | results["data_adv_imgs"] = data_adv_imgs 16 | return results 17 | 18 | def combine_losses(self, results): 19 | # Sum the loss 20 | losses = results.pop("losses", None) 21 | if losses is not None: 22 | loss = torch.cat([l.get("loss_item")[None] for l in losses]).sum() 23 | results["loss_combined"] = loss 24 | return results 25 | 26 | def process_gradients(self, results): 27 | results = super().process_gradients(results) 28 | return results 29 | 30 | def update_noise(self, results): 31 | data_adv_imgs = results.pop("data_adv_imgs", None) 32 | if data_adv_imgs is not None: 33 | epsilon = results["epsilon"] 34 | gradients_adv = results.pop("gradients_adv", None) 35 | noise = results.pop("noise", None) 36 | noise = noise + self.alpha * torch.sign(gradients_adv) 37 | noise = torch.clamp(noise, min=-epsilon, max=epsilon) 38 | results["noise"] = noise 39 | return results 40 | 41 | def postprocess_data(self, results): 42 | results = super().postprocess_data(results) 43 | return results 44 | 45 | -------------------------------------------------------------------------------- /attack/base/MI.py: -------------------------------------------------------------------------------- 1 | from .IFGSM import IFGSM 2 | from attack.utils.registry import BASEATK 3 | import torch 4 | 5 | 6 | @BASEATK.register_module() 7 | class MI(IFGSM): 8 | def __init__(self, alpha=1.0, momentum=1.0) -> None: 9 | super().__init__(alpha) 10 | 
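        # Decay factor for the accumulated gradient (the MI-FGSM momentum term).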
self.momentum = momentum 11 | 12 | def preprocess_data(self, results): 13 | results = super().preprocess_data(results) 14 | return results 15 | 16 | def combine_losses(self, results): 17 | results = super().combine_losses(results) 18 | return results 19 | 20 | def process_gradients(self, results): 21 | buffer = results["buffer"] 22 | idx = results["idx"] 23 | gradients_adv = results.pop("gradients_adv") 24 | 25 | # Load from buffer 26 | gradients_last_momentum = results.pop("gradients_with_momentum", None) 27 | if gradients_last_momentum is None: 28 | gradients_last_momentum = buffer.load("momentum", idx) 29 | if gradients_last_momentum is None: 30 | gradients_last_momentum = 0. 31 | gradients_with_momentum = self.momentum * gradients_last_momentum + \ 32 | gradients_adv / torch.mean(torch.abs(gradients_adv), dim=[1, 2, 3], keepdim=True) 33 | 34 | results["gradients_with_momentum"] = gradients_with_momentum 35 | results["gradients_adv"] = gradients_with_momentum 36 | return results 37 | 38 | def update_noise(self, results): 39 | results = super().update_noise(results) 40 | return results 41 | 42 | def postprocess_data(self, results): 43 | buffer = results['buffer'] 44 | idx = results['idx'] 45 | gradients_with_momentum = results['gradients_with_momentum'] 46 | buffer.dump("momentum", idx, gradients_with_momentum) 47 | return results -------------------------------------------------------------------------------- /attack/base/RRB.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from .IFGSM import IFGSM 4 | import torch 5 | import torch.nn.functional as F 6 | from torchvision.transforms.functional import rotate 7 | from attack.utils.registry import BASEATK 8 | 9 | 10 | @BASEATK.register_module() 11 | class RRB(IFGSM): 12 | def __init__(self, alpha=1.0, prob=1.0, theta=10., l_s=10, rho=1.0, s_max=1.1, sigma=4.0) -> None: 13 | super().__init__(alpha) 14 | self.prob = prob 15 | self.theta = theta 16 | self.l_s = l_s 17 | self.rho = rho 18 | self.s_max = s_max 19 | self.sigma = sigma 20 | 21 | def preprocess_data(self, results): 22 | super().preprocess_data(results) 23 | data_adv_imgs = results.get("data_adv_imgs") 24 | data_gt_bboxes = results['data_gt_bboxes'] 25 | data_adv_imgs_list = [] 26 | for i in range(2): 27 | input_diversity = [self.random_axis_rotation, self.adaptive_random_resizing][i % 2] 28 | if random.random() < self.prob: 29 | data_adv_imgs = input_diversity(data_adv_imgs, theta=self.theta, 30 | label_boxes=data_gt_bboxes, rho=self.rho, 31 | s_max=self.s_max, l_s=self.l_s) 32 | data_adv_imgs_list.append(data_adv_imgs) 33 | data_adv_imgs = torch.cat(data_adv_imgs_list, dim=0) 34 | data_adv_imgs = self.gaussian_blur(data_adv_imgs, sigma=self.sigma) 35 | results["data_adv_imgs"] = data_adv_imgs 36 | return results 37 | 38 | def combine_losses(self, results): 39 | results = super().combine_losses(results) 40 | return results 41 | 42 | def process_gradients(self, results): 43 | results = super().process_gradients(results) 44 | return results 45 | 46 | def update_noise(self, results): 47 | results = super().update_noise(results) 48 | return results 49 | 50 | def postprocess_data(self, results): 51 | results = super().postprocess_data(results) 52 | return results 53 | 54 | @staticmethod 55 | def random_axis_rotation(imgs, theta=10., l_s=10, label_boxes=None, **kwargs): 56 | device = imgs.device 57 | result_list = [] 58 | for idx in range(imgs.shape[0]): 59 | input_tensor = imgs[idx].unsqueeze(0) 60 | 61 | # 
Select the rotation axis randomly 62 | boxes = label_boxes[idx % len(label_boxes)] 63 | boxes_centers = (boxes[:, :2] + boxes[:, 2:]) / 2 64 | centers = torch.cat([boxes_centers, torch.tensor([[input_tensor.shape[-2] // 2, 65 | input_tensor.shape[-1] // 2]], device=device)], dim=0) 66 | if l_s == 0: 67 | centers_with_random = centers 68 | else: 69 | centers_with_random = centers + torch.randint_like(centers, low=-l_s, high=l_s) 70 | center_x, center_y = random.choice(centers_with_random) 71 | angle = random.random() * 2 * theta - theta 72 | result = rotate(input_tensor, angle, center=[int(center_x), int(center_y)]) 73 | result_list.append(result) 74 | return torch.cat(result_list, dim=0) 75 | 76 | @staticmethod 77 | def adaptive_random_resizing(imgs, rho=1.0, s_max=1.1, label_boxes=None, **kwargs): 78 | padded_list = [] 79 | for idx in range(imgs.shape[0]): 80 | input_tensor = imgs[idx].unsqueeze(0) 81 | ori_size_h = input_tensor.shape[2] 82 | ori_size_w = input_tensor.shape[3] 83 | 84 | # Extract info of boxes 85 | boxes = label_boxes[idx % len(label_boxes)].cpu().numpy() 86 | random_box_idx = random.randint(0, len(boxes) - 1) 87 | box = boxes[random_box_idx] 88 | box_w = box[2] - box[0] 89 | box_h = box[3] - box[1] 90 | 91 | # Apply Transformations 92 | scale_h = min(1 + rho * (box_h / ori_size_h), s_max) 93 | scale_w = min(1 + rho * (box_w / ori_size_w), s_max) 94 | new_size_h = random.randint(ori_size_h, int(scale_h * ori_size_h)) 95 | new_size_w = random.randint(ori_size_w, int(scale_w * ori_size_w)) 96 | rescaled = F.interpolate(input_tensor, size=(new_size_h, new_size_w), mode='bilinear', align_corners=True) 97 | rem_h = int(scale_h * ori_size_h) - new_size_h 98 | rem_w = int(scale_w * ori_size_w) - new_size_w 99 | pad_left = random.randint(0, rem_w) 100 | pad_top = random.randint(0, rem_h) 101 | padded = F.pad(rescaled, (pad_left, rem_w - pad_left, pad_top, rem_h - pad_top), mode='constant', value=0.) 102 | padded = F.interpolate(padded, size=(ori_size_h, ori_size_w), mode='bilinear', align_corners=True) 103 | padded_list.append(padded) 104 | return torch.cat(padded_list, dim=0) 105 | 106 | @staticmethod 107 | def gaussian_blur(imgs, sigma=1.0, **kwargs): 108 | return torch.clamp(imgs + torch.randn_like(imgs) * sigma, 0., 255.) 
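Note that `gaussian_blur` above adds per-pixel Gaussian noise rather than applying a convolution kernel. New base attacks hook into the same pipeline through the `BASEATK` registry; below is a minimal, illustrative sketch of such an extension (the class `GradClip` and its `clip` parameter are not part of the repository, and the hard-coded priority lists in `attack/utils/pipelines.py`, a matching entry in `config/default.py`, and the YAML `base_attack` list would all be needed before it would run):

```python
import torch

from attack.base.IFGSM import IFGSM
from attack.utils.registry import BASEATK


@BASEATK.register_module()
class GradClip(IFGSM):
    """Illustrative only: clamp the raw gradients before the IFGSM sign update."""

    def __init__(self, alpha=1.0, clip=1.0) -> None:
        super().__init__(alpha)
        self.clip = clip

    def process_gradients(self, results):
        # "gradients_adv" is filled from noise.grad by UpdateNoise before this hook.
        results["gradients_adv"] = torch.clamp(results["gradients_adv"],
                                               min=-self.clip, max=self.clip)
        return results
```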
-------------------------------------------------------------------------------- /attack/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .IFGSM import IFGSM 2 | from .DI import DI 3 | from .MI import MI 4 | from .RRB import RRB 5 | -------------------------------------------------------------------------------- /attack/comparing/TransferAttack.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | class TransferAttack(object): 6 | 7 | def __init__(self) -> None: 8 | super().__init__() 9 | self.device = os.environ["device"] 10 | 11 | @staticmethod 12 | def extract_cln_features(results): 13 | model = results["model"] 14 | data_clean_imgs = results["data_clean_imgs"] 15 | normalizer = results["normalizer"] 16 | data_imgs = normalizer(data_clean_imgs) 17 | with torch.no_grad(): 18 | feat_cln = model.backbone(data_imgs) 19 | return feat_cln 20 | 21 | def preprocess_data(self, results): 22 | buffer = results["buffer"] 23 | idx = results["idx"] 24 | # extract clean features 25 | feats_cln = results.get("feats_cln", None) 26 | if feats_cln is None: 27 | feats_cln = buffer.load_or_create_and_dump("feats_cln", idx, 28 | function=self.extract_cln_features, 29 | parameters=results) 30 | results["feats_cln"] = feats_cln 31 | return results 32 | 33 | def prepare_losses(self, results): 34 | loss_params = [] 35 | feats_adv = results.pop("feats_adv") 36 | feats_cln = results["feats_cln"] 37 | # loss container 38 | num_stage = len(feats_cln) 39 | num_sample = results["num_sample"] 40 | num_groups = int(len(feats_adv[0]) / num_sample) 41 | for stage_idx in range(num_stage): 42 | for group_idx in range(num_groups): 43 | feats_adv_group = feats_adv[stage_idx][group_idx * num_sample: (group_idx + 1) * num_sample] 44 | for sample_idx in range(num_sample): 45 | loss_params.append({ 46 | "group_idx": group_idx, 47 | "stage_idx": stage_idx, 48 | "sample_idx": sample_idx, 49 | "feat_cln": feats_cln[stage_idx][sample_idx], 50 | "feat_adv": feats_adv_group[sample_idx] 51 | }) 52 | results["loss_params"] = loss_params 53 | return results 54 | 55 | def losses(self, results): 56 | return results 57 | -------------------------------------------------------------------------------- /attack/comparing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wakuwu/OSFD/3744caf69e60b46012a6895c095f18c33db491a9/attack/comparing/__init__.py -------------------------------------------------------------------------------- /attack/ours/OSFD.py: -------------------------------------------------------------------------------- 1 | from attack.comparing.TransferAttack import TransferAttack 2 | from attack.utils.registry import TSFATK 3 | 4 | import torch.nn.functional as F 5 | 6 | 7 | @TSFATK.register_module() 8 | class OSFD(TransferAttack): 9 | def __init__(self, k=3.0) -> None: 10 | TransferAttack.__init__(self) 11 | self.k = k 12 | 13 | def preprocess_data(self, results): 14 | results = super().preprocess_data(results) 15 | return results 16 | 17 | def prepare_losses(self, results): 18 | results = super().prepare_losses(results) 19 | return results 20 | 21 | def losses(self, results): 22 | loss_params = results.pop("loss_params") 23 | for param in loss_params: 24 | feat_cln = param.pop("feat_cln") 25 | feat_adv = param.pop("feat_adv") 26 | l = F.mse_loss(self.k * feat_cln, feat_adv) 27 | param["loss_item"] = l 28 | 
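        # Each loss_item is the MSE between the k-amplified clean features and
        # the adversarial features; the sign-gradient ascent in IFGSM.update_noise
        # then maximizes this object-aware significant feature distortion.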
results["losses"] = loss_params 29 | return results -------------------------------------------------------------------------------- /attack/ours/__init__.py: -------------------------------------------------------------------------------- 1 | from .OSFD import OSFD 2 | -------------------------------------------------------------------------------- /attack/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wakuwu/OSFD/3744caf69e60b46012a6895c095f18c33db491a9/attack/utils/__init__.py -------------------------------------------------------------------------------- /attack/utils/buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import mmcv 4 | import torch 5 | from mmcv.parallel import scatter 6 | 7 | 8 | class Buffer: 9 | 10 | def __init__(self, buffer_dir) -> None: 11 | self.buffer_dir = buffer_dir 12 | # memory buffer 13 | self.buffer_dict = dict() 14 | # The buffer location corresponding to the buffer variable vname 15 | self.buffer_type_dict = dict() 16 | self.global_buffer_type = "memory" 17 | self.device = os.environ["device"] 18 | 19 | def update_buffer_type(self, vname, buffer_type): 20 | self.buffer_type_dict[vname] = buffer_type 21 | if "disk" == buffer_type: 22 | dump_dir = osp.join(self.buffer_dir, vname) 23 | mmcv.mkdir_or_exist(dump_dir) 24 | 25 | def update_buffer_types(self, buffer_type_dict): 26 | self.buffer_type_dict.update(buffer_type_dict) 27 | self.global_buffer_type = self.buffer_type_dict.get("global", "memory") 28 | 29 | def load_or_create_and_dump_var(self, vname, function=None, parameters=None): 30 | """Load from memory, if it doesn't exist, execute the function and buffer the result to memory""" 31 | content = self.buffer_dict.get(vname) 32 | if content is None and function is not None: 33 | if parameters is None: 34 | content = function() 35 | else: 36 | content = function(parameters) 37 | self.buffer_dict[vname] = content 38 | return content 39 | 40 | def load_or_create_and_dump(self, vname, vid, buffer_type=None, 41 | function=None, parameters=None, 42 | backend="torch"): 43 | content = self.load(vname, vid, backend=backend) 44 | if content is None and function is not None: 45 | if parameters is None: 46 | content = function() 47 | else: 48 | content = function(parameters) 49 | self.dump(vname, vid, content, buffer_type=buffer_type, backend=backend) 50 | return content 51 | 52 | def dump(self, vname, vid, vcontent, buffer_type=None, backend="torch"): 53 | vid = str(vid) 54 | if buffer_type is None: 55 | buffer_type = self.buffer_type_dict.get(vname, self.global_buffer_type) 56 | self.update_buffer_type(vname, buffer_type) 57 | vcontent = self.scatter_to_cpu(vcontent) 58 | if "memory" == buffer_type: 59 | buffer = self.buffer_dict.get(vname, dict()) 60 | buffer[vid] = vcontent 61 | self.buffer_dict[vname] = buffer 62 | elif "disk" == buffer_type: 63 | dump_dir = osp.join(self.buffer_dir, vname) 64 | if "torch" == backend: 65 | torch.save(vcontent, osp.join(dump_dir, vid + ".pth")) 66 | return 67 | 68 | def load(self, vname, vids, backend="torch"): 69 | contents = [] 70 | buffer_type = self.buffer_type_dict.get(vname, self.global_buffer_type) 71 | if not isinstance(vids, list): 72 | vids = [vids] 73 | vids = [str(vid) for vid in vids] 74 | for vid in vids: 75 | if "memory" == buffer_type: 76 | buffer = self.buffer_dict.get(vname, dict()) 77 | content = buffer.get(vid, None) 78 | 
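                # A memory-buffer miss stays None, keeping contents index-aligned with vids.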
contents.append(content) 79 | elif "disk" == buffer_type: 80 | dump_dir = osp.join(self.buffer_dir, vname) 81 | if "torch" == backend: 82 | content_fp = osp.join(dump_dir, vid + ".pth") 83 | if osp.exists(content_fp): 84 | contents.append(torch.load(content_fp)) 85 | else: 86 | contents.append(None) 87 | contents = scatter(contents, [0])[0] 88 | if len(vids) == 1: 89 | return contents[0] 90 | return contents 91 | 92 | @staticmethod 93 | def scatter_to_cpu(inputs): 94 | if isinstance(inputs, torch.Tensor): 95 | return inputs.detach().cpu() 96 | elif isinstance(inputs, (list, tuple)): 97 | if isinstance(inputs, tuple): 98 | inputs = list(inputs) 99 | return [Buffer.scatter_to_cpu(x) for x in inputs] 100 | elif isinstance(inputs, dict): 101 | for k, v in inputs.items(): 102 | inputs[k] = Buffer.scatter_to_cpu(v) 103 | return inputs -------------------------------------------------------------------------------- /attack/utils/mmdet.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | import os.path as osp 4 | import functools as ft 5 | import mmcv 6 | import torch 7 | from torchvision import transforms 8 | from mmcv.parallel import scatter 9 | from mmcv.runner import load_checkpoint 10 | from mmdet.core import get_classes, encode_mask_results 11 | from mmdet.models import build_detector 12 | from mmdet.datasets import build_dataset, build_dataloader, replace_ImageToTensor 13 | 14 | 15 | def init_cfg(config, dataset_cfg): 16 | if isinstance(config, (str, Path)): 17 | config = mmcv.Config.fromfile(config) 18 | if 'pretrained' in config.model: 19 | config.model.pretrained = None 20 | elif 'init_cfg' in config.model.backbone: 21 | config.model.backbone.init_cfg = None 22 | # modify the dataset settings 23 | config.data.test.ann_file = osp.join(os.environ["project_path"], dataset_cfg["ann_file"]) 24 | config.data.test.img_prefix = osp.join(os.environ["project_path"], dataset_cfg["img_prefix"]) 25 | return config 26 | 27 | 28 | def init_detector(config, checkpoint=None, device='cuda:0'): 29 | """Initialize a detector from config file.""" 30 | model = build_detector(config.model, train_cfg=config.get('train_cfg'), test_cfg=config.get('test_cfg')) 31 | if checkpoint is not None: 32 | checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') 33 | if 'CLASSES' in checkpoint.get('meta', {}): 34 | model.CLASSES = checkpoint['meta']['CLASSES'] 35 | else: 36 | model.CLASSES = get_classes('coco') 37 | model.cfg = config 38 | model.to(device) 39 | model.eval() 40 | return model 41 | 42 | 43 | def init_dataloader(cfg, samples_per_gpu=1, workers_per_gpu=2, persistent_workers=False): 44 | test_dataloader_default_args = dict( 45 | samples_per_gpu=samples_per_gpu, workers_per_gpu=workers_per_gpu, 46 | dist=False, shuffle=False, persistent_workers=persistent_workers) 47 | # in case the test dataset is concatenated 48 | if isinstance(cfg.data.test, dict): 49 | cfg.data.test.test_mode = True 50 | if samples_per_gpu > 1: 51 | # Replace 'ImageToTensor' to 'DefaultFormatBundle' 52 | cfg.data.test.pipeline = replace_ImageToTensor( 53 | cfg.data.test.pipeline) 54 | test_loader_cfg = { 55 | **test_dataloader_default_args, 56 | **cfg.data.get('test_dataloader', {}) 57 | } 58 | # build the dataloader 59 | dataset = build_dataset(cfg.data.test) 60 | data_loader = build_dataloader(dataset, **test_loader_cfg) 61 | return data_loader 62 | 63 | 64 | def parse_data(data): 65 | """parse the data container in mmdet""" 66 | data_imgs = 
data.get("img", [None])[0] 67 | data_img_metas = data.get("img_metas", [[None]])[0] 68 | data_gt_bboxes = data.get("gt_bboxes", [[None]])[0] 69 | data_gt_labels = data.get("gt_labels", [[None]])[0] 70 | return data_imgs, data_img_metas, data_gt_bboxes, data_gt_labels 71 | 72 | 73 | def get_normalize_tools(data_img_meta): 74 | """clean_imgs always in rgb channel""" 75 | device = os.environ["device"] 76 | # denormalized 77 | img_norm_cfg = data_img_meta.get("img_norm_cfg") 78 | mean = img_norm_cfg.get("mean") 79 | std = img_norm_cfg.get("std") 80 | to_rgb = img_norm_cfg.get("to_rgb") 81 | normalizer = ft.partial(imnormalize, mean=torch.from_numpy(mean).to(device), 82 | std=torch.from_numpy(std).to(device), to_rgb=to_rgb) 83 | denormalizer = ft.partial(imdenormalize, mean=torch.from_numpy(mean).to(device), 84 | std=torch.from_numpy(std).to(device), to_rgb=to_rgb) 85 | return normalizer, denormalizer 86 | 87 | 88 | def imnormalize(img, mean, std, to_rgb): 89 | """ 90 | Normalize an image with mean and std. 91 | Must convert the img to bgr first. 92 | """ 93 | imgs = img 94 | if not to_rgb: 95 | if img.ndim == 4: 96 | imgs = img[:, [2, 1, 0], ...] 97 | elif img.ndim == 3: 98 | imgs = img[[2, 1, 0], ...] 99 | imgs = torch.div(torch.sub(imgs, mean[..., None, None]), std[..., None, None]) 100 | return imgs 101 | 102 | 103 | def imdenormalize(img, mean, std, to_rgb): 104 | """Denormalize an image with mean and std.""" 105 | imgs = torch.mul(img, std[..., None, None]) + mean[..., None, None] 106 | if not to_rgb: 107 | if imgs.ndim == 4: 108 | imgs = imgs[:, [2, 1, 0], ...] 109 | elif imgs.ndim == 3: 110 | imgs = imgs[[2, 1, 0], ...] 111 | return imgs 112 | 113 | 114 | def load_noise_for_eval(buffer, buffer_batch_size, sample_idx_range): 115 | batch_idx_start = sample_idx_range[0] // buffer_batch_size 116 | batch_idx_end = sample_idx_range[1] // buffer_batch_size 117 | idx_bias = sample_idx_range[0] % buffer_batch_size 118 | idx_length = sample_idx_range[1] - sample_idx_range[0] + 1 119 | idxes = [str(i) for i in range(batch_idx_start, batch_idx_end + 1)] 120 | noises = buffer.load("noise_current", idxes) 121 | if isinstance(noises, list) or isinstance(noises, tuple): 122 | noises = torch.cat(noises, dim=0) 123 | noises = noises[idx_bias:idx_bias + idx_length] 124 | return noises 125 | 126 | 127 | def single_gpu_test(model, 128 | data_loader, 129 | func_load_noise=None, 130 | mode="clean"): 131 | model.eval() 132 | results = [] 133 | dataset = data_loader.dataset 134 | prog_bar = mmcv.ProgressBar(len(dataset)) 135 | normalizer, denormalizer, resizer = None, None, None 136 | eval_batch_size = data_loader.batch_size 137 | for i, data in enumerate(data_loader): 138 | # scatter data to gpus 139 | data = scatter(data, [0])[0] 140 | 141 | # Load Noise 142 | if not "clean" == mode and func_load_noise is not None: 143 | data_imgs, data_img_metas, _, _ = parse_data(data) 144 | if normalizer is None or denormalizer is None: 145 | normalizer, denormalizer = get_normalize_tools(data_img_metas[0]) 146 | data_clean_imgs = denormalizer(data_imgs) 147 | if resizer is None: 148 | resizer = transforms.Resize(data_clean_imgs.shape[-2:], antialias=True) 149 | sample_idx_range = (i * eval_batch_size, i * eval_batch_size + len(data_imgs) - 1) 150 | noises = func_load_noise(sample_idx_range=sample_idx_range) 151 | noises_resized = resizer.forward(noises) 152 | # Quantization 153 | data_adv_imgs = torch.round(data_clean_imgs + noises_resized).float() 154 | data_adv_imgs = normalizer(torch.clamp(data_adv_imgs, min=0., 
max=255.)) 155 | data["img"][0] = data_adv_imgs 156 | 157 | with torch.no_grad(): 158 | result = model(return_loss=False, rescale=True, **data) 159 | 160 | # encode mask results 161 | if isinstance(result[0], tuple): 162 | result = [(bbox_results, encode_mask_results(mask_results)) 163 | for bbox_results, mask_results in result] 164 | 165 | batch_size = len(result) 166 | results.extend(result) 167 | 168 | for _ in range(batch_size): 169 | prog_bar.update() 170 | return results 171 | -------------------------------------------------------------------------------- /attack/utils/pipelines.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from attack.utils.mmdet import parse_data, get_normalize_tools 4 | 5 | 6 | class PreProcessData: 7 | def __init__(self, base_attack, transfer_attack) -> None: 8 | self.base_priority = ["DI", "RRB", "MI", "IFGSM"] 9 | self.device = os.environ["device"] 10 | self.base_attack = base_attack 11 | self.transfer_attack = transfer_attack 12 | 13 | @staticmethod 14 | def init_data_container(results): 15 | buffer = results["buffer"] 16 | data = results.pop("data", None) 17 | if data is None: 18 | return results 19 | data_imgs, data_img_metas, data_gt_bboxes, data_gt_labels = parse_data(data) 20 | normalizer, denormalizer = \ 21 | buffer.load_or_create_and_dump_var("normalize_tools", get_normalize_tools, data_img_metas[0]) 22 | data_clean_imgs = denormalizer(data_imgs) 23 | results["data_img_metas"] = data_img_metas 24 | results["data_gt_bboxes"] = data_gt_bboxes 25 | results["data_gt_labels"] = data_gt_labels 26 | results["normalizer"] = normalizer 27 | results["data_clean_imgs"] = data_clean_imgs 28 | results["num_sample"] = len(data_clean_imgs) 29 | return results 30 | 31 | def init_noise(self, results): 32 | buffer = results["buffer"] 33 | idx = results["idx"] 34 | noise = results.get("noise", None) 35 | if noise is None: 36 | noise_dir = os.environ["noise_dir"] 37 | noise = buffer.load(noise_dir, idx) 38 | if noise is None: 39 | noise = torch.randint_like(results["data_clean_imgs"], low=-2, high=3).float().to(self.device) 40 | noise.requires_grad = True 41 | # Init the gradients of noise 42 | results["noise"] = noise 43 | return results 44 | 45 | @staticmethod 46 | def init_attack_forward(results): 47 | model = results["model"] 48 | normalizer = results["normalizer"] 49 | data_adv_imgs = results["data_adv_imgs"] 50 | feats_adv = model.backbone(normalizer(torch.clamp(data_adv_imgs, min=0., max=255.))) 51 | results["feats_adv"] = feats_adv 52 | return results 53 | 54 | def __call__(self, results): 55 | results = self.init_data_container(results) 56 | # Init Noise 57 | results = self.init_noise(results) 58 | # Base Attack 59 | for method in self.base_priority: 60 | if method in self.base_attack.keys(): 61 | method = self.base_attack[method] 62 | results = method.preprocess_data(results) 63 | # Tsf Attack 64 | for method in self.transfer_attack.keys(): 65 | method = self.transfer_attack[method] 66 | results = method.preprocess_data(results) 67 | # Init features of adv images 68 | results = self.init_attack_forward(results) 69 | return results 70 | 71 | 72 | class CalculateLoss: 73 | def __init__(self, base_attack, transfer_attack) -> None: 74 | self.base_priority = ["MI", "DI", "RRB", "IFGSM"] 75 | self.base_attack = base_attack 76 | self.transfer_attack = transfer_attack 77 | 78 | def __call__(self, results): 79 | # Tsf Attack 80 | for method in self.transfer_attack.keys(): 81 | method = 
self.transfer_attack[method] 82 | results = method.prepare_losses(results) 83 | results = method.losses(results) 84 | # Combine losses 85 | for method in self.base_priority: 86 | if method in self.base_attack.keys(): 87 | method = self.base_attack[method] 88 | results = method.combine_losses(results) 89 | # Backward to get gradients 90 | loss_combined = results["loss_combined"] 91 | loss_combined.backward() 92 | return results 93 | 94 | 95 | class UpdateNoise: 96 | def __init__(self, base_attack) -> None: 97 | self.base_gradients_priority = ["MI", "DI", "RRB", "IFGSM"] 98 | self.base_update_priority = ["MI", "DI", "RRB", "IFGSM"] 99 | self.base_attack = base_attack 100 | 101 | @torch.no_grad() 102 | def __call__(self, results): 103 | with torch.no_grad(): 104 | results["gradients_adv"] = results["noise"].grad 105 | # Handle gradients 106 | for method in self.base_gradients_priority: 107 | if method in self.base_attack.keys(): 108 | method = self.base_attack[method] 109 | results = method.process_gradients(results) 110 | # Update noise 111 | for method in self.base_update_priority: 112 | if method in self.base_attack.keys(): 113 | method = self.base_attack[method] 114 | results = method.update_noise(results) 115 | return results 116 | 117 | 118 | class PostProcessData: 119 | def __init__(self, base_attack) -> None: 120 | super().__init__() 121 | self.base_attack = base_attack 122 | 123 | @staticmethod 124 | def buffer_noise(results): 125 | buffer = results["buffer"] 126 | idx = results["idx"] 127 | noise = results["noise"] 128 | buffer.dump("noise_current", idx, noise) 129 | return results 130 | 131 | def __call__(self, results): 132 | for method in self.base_attack.keys(): 133 | method = self.base_attack[method] 134 | results = method.postprocess_data(results) 135 | results = self.buffer_noise(results) 136 | return results 137 | -------------------------------------------------------------------------------- /attack/utils/registry.py: -------------------------------------------------------------------------------- 1 | from mmcv import Registry 2 | 3 | TSFATK = Registry("TSFATK") 4 | BASEATK = Registry("BASEATK") 5 | -------------------------------------------------------------------------------- /config/attack_faster_rcnn.yaml: -------------------------------------------------------------------------------- 1 | base: "config/base.yaml" 2 | 3 | global: 4 | project_path: "/home/{===Set Here===}/OSFD/" 5 | buffer: 6 | global: "disk" 7 | 8 | options: 9 | eval_clean: false 10 | 11 | saving_settings: 12 | logging: true 13 | noise: false 14 | adv_img: false 15 | others: true 16 | best_black: true 17 | 18 | attack: 19 | source: "faster_rcnn_r101_caffe_fpn_mstrain_3x_coco" 20 | steps: 10 21 | max_epoch: 20 22 | epsilon: 5 23 | method: 24 | base_attack: ['MI', 'RRB'] 25 | transfer_attack: "OSFD" 26 | method_settings: 27 | OSFD: 28 | k: 3.0 29 | RRB: 30 | theta: 7. 
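        # RRB hyper-parameters (defaults in config/default.py): theta = max rotation
        # angle, l_s = rotation-center jitter in pixels, rho = box-size resize factor,
        # s_max = max resize scale, sigma = std of the additive Gaussian noise.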
31 | l_s: 10 32 | rho: 0.8 33 | s_max: 1.10 34 | sigma: 6.0 35 | dataset: 36 | ann_file: "data/Voc12_CoCo_800_2000/annotations/instances_train2017.json" 37 | img_prefix: "data/Voc12_CoCo_800_2000/train2017/" 38 | dataloader: 39 | batch_size: 1 40 | persistent_workers: false 41 | cpu_num: 2 42 | eval_cpu_num: 2 43 | -------------------------------------------------------------------------------- /config/attack_swin.yaml: -------------------------------------------------------------------------------- 1 | base: "config/base.yaml" 2 | 3 | global: 4 | project_path: "/home/{===Set Here===}/OSFD/" 5 | buffer: 6 | global: "disk" 7 | 8 | options: 9 | eval_clean: false 10 | 11 | saving_settings: 12 | logging: true 13 | noise: false 14 | adv_img: false 15 | others: true 16 | best_black: true 17 | 18 | attack: 19 | source: "mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco" 20 | steps: 10 21 | max_epoch: 20 22 | epsilon: 5 23 | method: 24 | base_attack: ['MI', 'RRB'] 25 | transfer_attack: "OSFD" 26 | method_settings: 27 | OSFD: 28 | k: 3.0 29 | RRB: 30 | theta: 7. 31 | l_s: 10 32 | rho: 0.8 33 | s_max: 1.10 34 | sigma: 6.0 35 | dataset: 36 | ann_file: "data/Voc12_CoCo_800_2000/annotations/instances_train2017.json" 37 | img_prefix: "data/Voc12_CoCo_800_2000/train2017/" 38 | dataloader: 39 | batch_size: 1 40 | persistent_workers: false 41 | cpu_num: 2 42 | eval_cpu_num: 2 43 | -------------------------------------------------------------------------------- /config/attack_vfnet.yaml: -------------------------------------------------------------------------------- 1 | base: "config/base.yaml" 2 | 3 | global: 4 | project_path: "/home/{===Set Here===}/OSFD/" 5 | buffer: 6 | global: "disk" 7 | 8 | options: 9 | eval_clean: false 10 | 11 | saving_settings: 12 | logging: true 13 | noise: false 14 | adv_img: false 15 | others: true 16 | best_black: true 17 | 18 | attack: 19 | source: "vfnet_r50_fpn_mstrain_2x_coco" 20 | steps: 10 21 | max_epoch: 20 22 | epsilon: 5 23 | method: 24 | base_attack: ['MI', 'RRB'] 25 | transfer_attack: "OSFD" 26 | method_settings: 27 | OSFD: 28 | k: 3.0 29 | RRB: 30 | theta: 7. 31 | l_s: 10 32 | rho: 0.8 33 | s_max: 1.10 34 | sigma: 6.0 35 | dataset: 36 | ann_file: "data/Voc12_CoCo_800_2000/annotations/instances_train2017.json" 37 | img_prefix: "data/Voc12_CoCo_800_2000/train2017/" 38 | dataloader: 39 | batch_size: 1 40 | persistent_workers: false 41 | cpu_num: 2 42 | eval_cpu_num: 2 43 | -------------------------------------------------------------------------------- /config/attack_yolov3.yaml: -------------------------------------------------------------------------------- 1 | base: "config/base.yaml" 2 | 3 | global: 4 | project_path: "/home/{===Set Here===}/OSFD/" 5 | buffer: 6 | global: "disk" 7 | 8 | options: 9 | eval_clean: false 10 | 11 | saving_settings: 12 | logging: true 13 | noise: false 14 | adv_img: false 15 | others: true 16 | best_black: true 17 | 18 | attack: 19 | source: "yolov3_d53_mstrain-608_273e_coco" 20 | steps: 10 21 | max_epoch: 20 22 | epsilon: 5 23 | method: 24 | base_attack: ['MI', 'RRB'] 25 | transfer_attack: "OSFD" 26 | method_settings: 27 | OSFD: 28 | k: 3.0 29 | RRB: 30 | theta: 7. 
31 | l_s: 10 32 | rho: 0.8 33 | s_max: 1.10 34 | sigma: 6.0 35 | dataset: 36 | ann_file: "data/Voc12_CoCo_800_2000/annotations/instances_train2017.json" 37 | img_prefix: "data/Voc12_CoCo_800_2000/train2017/" 38 | dataloader: 39 | batch_size: 1 40 | persistent_workers: false 41 | cpu_num: 2 42 | eval_cpu_num: 2 43 | -------------------------------------------------------------------------------- /config/base.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | # ============ Absolute path to the OSFD project directory. ============ 3 | project_path: "/home/.../OSFD/" 4 | 5 | # ============ Default device, cannot be modified. ============ 6 | device: "cuda:0" 7 | 8 | # ============ Random seed for all. ============ 9 | seed: 2023 10 | 11 | # ============ Cache intermediate variables to reduce the number of calculations. ============ 12 | buffer: 13 | # [disk | memory] 14 | # noise / amplification / momentum / grad_cln_ref / global 15 | # ============ global defines the default cache location for all variables; ============ 16 | # ============ alternatively, specify a cache location per variable. ============ 17 | global: "memory" 18 | 19 | options: 20 | # ============ Whether to evaluate the raw mAP of benign (clean) samples. ============ 21 | eval_clean: true 22 | 23 | # ============ Whether to enable debug mode; if enabled, no adversarial examples are stored. ============ 24 | debug: false 25 | 26 | saving_settings: 27 | # ============ Whether to record terminal output to the log. ============ 28 | logging: true 29 | 30 | # ============ Whether to enable tensorboard. ============ 31 | tboard: false 32 | 33 | # ============ Whether to save the noise tensor. ============ 34 | noise: true 35 | 36 | # ============ Whether to save adversarial images. ============ 37 | adv_img: true 38 | 39 | # ============ Whether to save other files, such as the per-step loss values. ============ 40 | others: true 41 | 42 | # ============ Whether to save the white-box attack result if its mAP reaches 0. ============ 43 | best_white: false 44 | 45 | # ============ Whether to save the best black-box transfer results, default true. ============ 46 | best_black: true 47 | 48 | # ============ Prefix and suffix of the result saving directory for easy differentiation of experimental results. ============ 49 | saving_dir: 50 | prefix: "" 51 | suffix: "" 52 | 53 | models: 54 | train_cfg_dir: "ummdet/checkpoints/train_cfg/" 55 | eval_cfg_dir: "ummdet/checkpoints/eval_cfg/" 56 | models_dir: "ummdet/checkpoints/models/" 57 | 58 | # ============ All model config file and checkpoint names. 
============ 59 | detectors: 60 | - name: "yolov3_d53_mstrain-608_273e_coco" 61 | ckpt: "yolov3_d53_mstrain-608_273e_coco_20210518_115020-a2c3acb8" 62 | - name: "yolof_r50_c5_8x8_1x_coco" 63 | ckpt: "yolof_r50_c5_8x8_1x_coco_20210425_024427-8e864411" 64 | - name: "yolox_l_8x8_300e_coco" 65 | ckpt: "yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23" 66 | - name: "fcos_r50_caffe_fpn_gn-head_mstrain_640-800_2x_coco" 67 | ckpt: "fcos_r50_caffe_fpn_gn-head_mstrain_640-800_2x_coco-d92ceeea" 68 | - name: "faster_rcnn_r101_caffe_fpn_mstrain_3x_coco" 69 | ckpt: "faster_rcnn_r101_caffe_fpn_mstrain_3x_coco_20210526_095742-a7ae426d" 70 | - name: "vfnet_r50_fpn_mstrain_2x_coco" 71 | ckpt: "vfnet_r50_fpn_mstrain_2x_coco_20201027-7cc75bd2" 72 | - name: "mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco" 73 | ckpt: "mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco_20210906_131725-bacf6f7b" 74 | - name: "detr_r50_8x2_150e_coco" 75 | ckpt: "detr_r50_8x2_150e_coco_20201130_194835-2c4b8974" 76 | 77 | attack: 78 | default_cfg: "config/default.py" 79 | # ============ Number of attack steps per epoch. ============ 80 | steps: 10 81 | 82 | # ============ Maximum number of epochs; after each epoch, mAP is evaluated on all models. ============ 83 | max_epoch: 20 84 | 85 | # ============ L-infinity norm bound for adversarial attacks. ============ 86 | max_norm: 5 87 | method: 88 | # ============ Base attack methods to integrate: any single method, or any combination of methods excluding IFGSM. ============ 89 | base_attack: ['IFGSM', 'MI', 'DI', 'RRB'] 90 | # ============ The transferable adversarial attack method. ============ 91 | transfer_attack: "OSFD" 92 | method_settings: 93 | null 94 | 95 | # ============ Path to the dataset to be attacked (COCO-format images with labels). ============ 96 | dataset: 97 | ann_file: "data/coco/annotations/instances_train2017.json" 98 | img_prefix: "data/coco/train2017/" 99 | dataloader: 100 | # ============ Number of images attacked simultaneously per batch, ============ 101 | # ============ which slightly affects the mAP results; ============ 102 | # ============ keep it at 1 (or a fixed size) for a fair comparison. ============ 103 | batch_size: 1 104 | 105 | # ============ Test-stage batch size; arbitrary, with no effect on results. ============ 106 | eval_batch_size: 15 107 | 108 | persistent_workers: false 109 | # ============ Number of dataloader worker processes during the attack. ============ 110 | cpu_num: 0 111 | # ============ Number of dataloader worker processes during the test. 
============ 112 | eval_cpu_num: 0 113 | eval_cfg: 114 | metric: [ "bbox" ] 115 | metric_items: [ "mAP", "mAP_50", "mAP_75", "AR@100", "AR@300", "AR@1000" ] 116 | classwise: true 117 | -------------------------------------------------------------------------------- /config/default.py: -------------------------------------------------------------------------------- 1 | base_attack = dict( 2 | IFGSM=dict( 3 | type="IFGSM", 4 | alpha=1.0 5 | ), 6 | MI=dict( 7 | type="MI", 8 | alpha=1.0, 9 | momentum=1.0 10 | ), 11 | DI=dict( 12 | type="DI", 13 | alpha=1.0, 14 | prob=1.0, 15 | scale=1.1 16 | ), 17 | RRB=dict( 18 | type="RRB", 19 | alpha=1.0, 20 | prob=1.0, 21 | theta=7., 22 | l_s=10, 23 | rho=0.8, 24 | s_max=1.1, 25 | sigma=6.0 26 | ) 27 | ) 28 | 29 | transfer_attack = dict( 30 | OSFD=dict( 31 | type="OSFD", 32 | k=3.0 33 | ) 34 | ) 35 | -------------------------------------------------------------------------------- /run_perturb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import mmcv 4 | import torch 5 | import logging 6 | import argparse 7 | import traceback 8 | 9 | from attack.Attack import Attack 10 | from tools.project.config import ConfigYaml 11 | from tools.project.logger import Logger 12 | from tools.project.recorder import RecorderMemory, RecorderDisk 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description="Generate perturbed images.") 17 | parser.add_argument('config', default="config/attack_yolov3.yaml", help='attack config file path') 18 | return parser.parse_args() 19 | 20 | 21 | def run_main(): 22 | # load config 23 | yml_cfg = ConfigYaml.build(parse_args().config) 24 | # Init Logger 25 | logger = Logger(yml_cfg, parse_args().config) 26 | 27 | # Init Recorder 28 | if "disk" == yml_cfg.buffer.get("noise", yml_cfg.buffer["global"]): 29 | recorder = RecorderDisk(yml_cfg) 30 | else: 31 | recorder = RecorderMemory(yml_cfg) 32 | 33 | # Init Attack 34 | attack = Attack(yml_cfg) 35 | 36 | # eval clean mAP 37 | if yml_cfg.options.get("eval_clean"): 38 | clean_metric_dict = attack.eval(mode="clean") 39 | recorder.record_clean(clean_metric_dict) 40 | logger.logging_clean_metric(recorder) 41 | 42 | # main loop 43 | try: 44 | for epoch in range(yml_cfg.attack_cfg.get("max_epoch", 100)): 45 | losses = attack.attack_epoch() 46 | metric_dict = attack.eval(mode="noise") 47 | recorder.update_epoch(epoch, buffer=attack.buffer, 48 | loss_step_list=losses, metric_dict=metric_dict, 49 | samples_num=len(attack.dataloader.dataset)) 50 | logger.logging_epoch(epoch, recorder=recorder) 51 | except Exception as e: 52 | mmcv.print_log(f"##### PID {os.getpid()} exit: error. 
#####", logger=mmcv.get_logger("verbose_logger"), level=logging.INFO) 53 | mmcv.print_log(str(e.args), logger=mmcv.get_logger("error_logger"), level=logging.ERROR) 54 | mmcv.print_log(traceback.format_exc(), logger=mmcv.get_logger("error_logger"), level=logging.ERROR) 55 | save_exit(logger=logger, recorder=recorder, attack=attack, save_type="best_black", exit_flag=True, exit_code=-1) 56 | save_exit(logger=logger, recorder=recorder, attack=attack, save_type="best_black", exit_flag=True, exit_code=0) 57 | 58 | 59 | def save_exit(logger, recorder, attack, save_type="best_black", exit_flag=False, exit_code=0): 60 | logger.saving_results(recorder=recorder, attack=attack, save_type=save_type) 61 | if exit_flag: 62 | logger.close_logger() 63 | torch.cuda.empty_cache() 64 | sys.exit(exit_code) 65 | 66 | 67 | if __name__ == '__main__': 68 | run_main() 69 | 70 | -------------------------------------------------------------------------------- /tools/datasets/split_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import mmcv 5 | import numpy as np 6 | import os.path as osp 7 | 8 | random_seed = 2023 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser(description="Construct a subset for attack.") 13 | parser.add_argument('--voc_type', default="VOC2012", help='voc dataset type') 14 | parser.add_argument('-n', '--num', type=int, default=2000, help='number of examples') 15 | parser.add_argument('--paper', type=bool, default=False, help='similar dataset for paper') 16 | return parser.parse_args() 17 | 18 | 19 | if __name__ == '__main__': 20 | args = parse_args() 21 | VOC_subset_num = args.num 22 | VOCdevkit_dir = "./data/VOCdevkit" 23 | VOC_type = args.voc_type 24 | VOC_subset_path = VOCdevkit_dir 25 | VOC_subset_type = f"{VOC_type}_{VOC_subset_num}" 26 | 27 | if args.paper: 28 | subset_index = mmcv.list_from_file("data/image_index.txt") 29 | else: 30 | # load split list 31 | split_list = np.array(mmcv.list_from_file(osp.join(VOCdevkit_dir, VOC_type, "ImageSets/Main/trainval.txt"))) 32 | # construct the subset 33 | np.random.seed(random_seed) 34 | shuffled_index = np.random.permutation(len(split_list)) 35 | subset_index = split_list[shuffled_index[0: VOC_subset_num]].tolist() 36 | subset_index.sort() 37 | # 38 | split_save_path = osp.join(VOC_subset_path, VOC_subset_type, "ImageSets/Main/") 39 | os.makedirs(split_save_path, exist_ok=True) 40 | with open(osp.join(split_save_path, "trainval.txt"), 'w') as f: 41 | for idx in subset_index: 42 | f.writelines(idx+"\n") 43 | # copy images and labels 44 | image_save_path = osp.join(VOC_subset_path, VOC_subset_type, "JPEGImages/") 45 | label_save_path = osp.join(VOC_subset_path, VOC_subset_type, "Annotations/") 46 | os.makedirs(image_save_path, exist_ok=True) 47 | os.makedirs(label_save_path, exist_ok=True) 48 | for idx in mmcv.track_iter_progress(subset_index): 49 | image_name = idx + ".jpg" 50 | shutil.copyfile(osp.join(VOCdevkit_dir, VOC_type, "JPEGImages/", image_name), 51 | osp.join(image_save_path, image_name)) 52 | label_name = idx + ".xml" 53 | shutil.copyfile(osp.join(VOCdevkit_dir, VOC_type, "Annotations/", label_name), 54 | osp.join(label_save_path, label_name)) -------------------------------------------------------------------------------- /tools/datasets/voc12_to_coco.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | 4 | import mmcv 5 | import numpy as np 6 
| import argparse 7 | import imgaug.augmenters as iaa 8 | import xml.etree.ElementTree as ET 9 | import functools as ft 10 | from PIL import Image 11 | 12 | VOC_type = "" 13 | VOCdevkit_dir = "" 14 | CoCo_output_dir = "" 15 | 16 | voc_extend_name_list = [ 17 | 'person', 'bicycle', 'car', 'motorbike', 'aeroplane', 'bus', 'train', 18 | 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 19 | 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 20 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 21 | 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 22 | 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 23 | 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 24 | 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 25 | 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 26 | 'sofa', 'pottedplant', 'bed', 'diningtable', 'toilet', 'tvmonitor', 27 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 28 | 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 29 | 'scissors', 'teddy bear', 'hair drier', 'toothbrush' 30 | ] 31 | coco_name_list = [ 32 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 33 | 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 34 | 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 35 | 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 36 | 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 37 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 38 | 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 39 | 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 40 | 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 41 | 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 42 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 43 | 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 44 | 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 45 | 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' 46 | ] 47 | name_dict = dict(zip(voc_extend_name_list, coco_name_list)) 48 | 49 | 50 | def coco_classes(): 51 | return [ 52 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 53 | 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 54 | 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 55 | 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 56 | 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 57 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 58 | 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 59 | 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 60 | 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 61 | 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 62 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 63 | 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 64 | 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 65 | 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' 66 | ] 67 | 68 | 69 | label_ids = {name: i + 1 for i, name in enumerate(coco_classes())} 70 | 71 | 72 | def box_modified(width, height, bbox: list, pad_size=800): 73 | big = max(width, height) 74 | x1, y1, x2, y2 = bbox 75 | # Normalize 76 | xc_n = (x1 + x2 + big - width) / (2. * big) 77 | yc_n = (y1 + y2 + big - height) / (2. 
* big) 78 | ww_n = max((x2 - x1) / (1.0 * big), 0.) 79 | hh_n = max((y2 - y1) / (1.0 * big), 0.) 80 | xmin = (xc_n - ww_n / 2) * pad_size 81 | ymin = (yc_n - hh_n / 2) * pad_size 82 | xmax = (xc_n + ww_n / 2) * pad_size 83 | ymax = (yc_n + hh_n / 2) * pad_size 84 | return [xmin, ymin, xmax, ymax] 85 | 86 | 87 | def parse_xml(args, pad_size=800): 88 | xml_path, img_path = args 89 | tree = ET.parse(xml_path) 90 | root = tree.getroot() 91 | size = root.find('size') 92 | w = int(size.find('width').text) 93 | h = int(size.find('height').text) 94 | bboxes = [] 95 | labels = [] 96 | bboxes_ignore = [] 97 | labels_ignore = [] 98 | for obj in root.findall('object'): 99 | name = obj.find('name').text 100 | new_name = name_dict.get(str(name)) 101 | obj.find('name').text = new_name 102 | label = label_ids[new_name] 103 | difficult = int(obj.find('difficult').text) 104 | bnd_box = obj.find('bndbox') 105 | bbox = [ 106 | int(bnd_box.find('xmin').text), 107 | int(bnd_box.find('ymin').text), 108 | int(bnd_box.find('xmax').text), 109 | int(bnd_box.find('ymax').text) 110 | ] 111 | bbox = box_modified(width=w, height=h, bbox=bbox, pad_size=pad_size) 112 | if difficult: 113 | bboxes_ignore.append(bbox) 114 | labels_ignore.append(label) 115 | else: 116 | bboxes.append(bbox) 117 | labels.append(label) 118 | if not bboxes: 119 | bboxes = np.zeros((0, 4)) 120 | labels = np.zeros((0,)) 121 | else: 122 | bboxes = np.array(bboxes, ndmin=2) - 1 123 | labels = np.array(labels) 124 | if not bboxes_ignore: 125 | bboxes_ignore = np.zeros((0, 4)) 126 | labels_ignore = np.zeros((0,)) 127 | else: 128 | bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1 129 | labels_ignore = np.array(labels_ignore) 130 | annotation = { 131 | 'filename': img_path, 132 | 'width': pad_size, 133 | 'height': pad_size, 134 | 'ann': { 135 | 'bboxes': bboxes.astype(np.float32), 136 | 'labels': labels.astype(np.int64), 137 | 'bboxes_ignore': bboxes_ignore.astype(np.float32), 138 | 'labels_ignore': labels_ignore.astype(np.int64) 139 | } 140 | } 141 | return annotation 142 | 143 | 144 | def cvt_annotations(split, pad_size): 145 | annotations = [] 146 | img_names = split 147 | xml_paths = [ 148 | osp.join(VOCdevkit_dir, VOC_type, f'Annotations/{img_name}.xml') 149 | for img_name in img_names 150 | ] 151 | img_paths = [ 152 | f'JPEGImages/{img_name}.png' for img_name in img_names 153 | ] 154 | parse_xml_ft = ft.partial(parse_xml, pad_size=pad_size) 155 | part_annotations = mmcv.track_progress(parse_xml_ft, 156 | list(zip(xml_paths, img_paths))) 157 | annotations.extend(part_annotations) 158 | return annotations 159 | 160 | 161 | def cvt_to_coco_json(annotations): 162 | image_id = 0 163 | annotation_id = 0 164 | coco = dict() 165 | coco['images'] = [] 166 | coco['type'] = 'instance' 167 | coco['categories'] = [] 168 | coco['annotations'] = [] 169 | image_set = set() 170 | 171 | def addAnnItem(annotation_id, image_id, category_id, bbox, difficult_flag): 172 | annotation_item = dict() 173 | annotation_item['segmentation'] = [] 174 | 175 | seg = [] 176 | # bbox[] is x1,y1,x2,y2 177 | # left_top 178 | seg.append(int(bbox[0])) 179 | seg.append(int(bbox[1])) 180 | # left_bottom 181 | seg.append(int(bbox[0])) 182 | seg.append(int(bbox[3])) 183 | # right_bottom 184 | seg.append(int(bbox[2])) 185 | seg.append(int(bbox[3])) 186 | # right_top 187 | seg.append(int(bbox[2])) 188 | seg.append(int(bbox[1])) 189 | 190 | annotation_item['segmentation'].append(seg) 191 | 192 | xywh = np.array( 193 | [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]]) 194 | 
annotation_item['area'] = int(xywh[2] * xywh[3]) 195 | if difficult_flag == 1: 196 | annotation_item['ignore'] = 0 197 | annotation_item['iscrowd'] = 1 198 | else: 199 | annotation_item['ignore'] = 0 200 | annotation_item['iscrowd'] = 0 201 | annotation_item['image_id'] = int(image_id) 202 | annotation_item['bbox'] = xywh.astype(int).tolist() 203 | annotation_item['category_id'] = int(category_id) 204 | annotation_item['id'] = int(annotation_id) 205 | coco['annotations'].append(annotation_item) 206 | return annotation_id + 1 207 | 208 | for category_id, name in enumerate(coco_classes()): 209 | category_item = dict() 210 | category_item['supercategory'] = str('none') 211 | category_item['id'] = int(category_id) + 1 212 | category_item['name'] = str(name) 213 | coco['categories'].append(category_item) 214 | 215 | for ann_dict in annotations: 216 | file_name = ann_dict['filename'] 217 | ann = ann_dict['ann'] 218 | assert file_name not in image_set 219 | image_item = dict() 220 | image_item['id'] = int(image_id) 221 | image_item['file_name'] = str(file_name) 222 | image_item['height'] = int(ann_dict['height']) 223 | image_item['width'] = int(ann_dict['width']) 224 | coco['images'].append(image_item) 225 | image_set.add(file_name) 226 | 227 | bboxes = ann['bboxes'][:, :4] 228 | labels = ann['labels'] 229 | for bbox_id in range(len(bboxes)): 230 | bbox = bboxes[bbox_id] 231 | label = labels[bbox_id] 232 | annotation_id = addAnnItem( 233 | annotation_id, image_id, label, bbox, difficult_flag=0) 234 | 235 | bboxes_ignore = ann['bboxes_ignore'][:, :4] 236 | labels_ignore = ann['labels_ignore'] 237 | for bbox_id in range(len(bboxes_ignore)): 238 | bbox = bboxes_ignore[bbox_id] 239 | label = labels_ignore[bbox_id] 240 | annotation_id = addAnnItem( 241 | annotation_id, image_id, label, bbox, difficult_flag=1) 242 | 243 | image_id += 1 244 | 245 | return coco 246 | 247 | 248 | def generateImg(imgs_list: list, size=800): 249 | print("Generating images ...") 250 | voc_raw_imgs_path = osp.join(VOCdevkit_dir, VOC_type, "JPEGImages/") 251 | coco_raw_imgs_path = osp.join(CoCo_output_dir, "train2017/JPEGImages") 252 | os.makedirs(coco_raw_imgs_path, exist_ok=True) 253 | transform = iaa.Sequential([ 254 | iaa.PadToAspectRatio(1.0, position="center-center").to_deterministic() 255 | ]) 256 | 257 | for image_file in mmcv.track_iter_progress(imgs_list): 258 | # Transform raw imgs 259 | raw_image = np.array(Image.open(osp.join(voc_raw_imgs_path, image_file)).convert('RGB'), dtype=np.uint8) 260 | raw_image_transformed = Image.fromarray(transform(image=raw_image)) 261 | raw_image_transformed = raw_image_transformed.resize((size, size)) 262 | raw_image_transformed.save(osp.join(coco_raw_imgs_path, osp.splitext(image_file)[0] + '.png')) 263 | print("Generating images done.") 264 | 265 | 266 | def generateLabel(split, size=800): 267 | print("Generating labels ...") 268 | # the train, val, trainval and test splits all share the same annotation file 269 | annotations = cvt_annotations(split, pad_size=size) 270 | annotations_json = cvt_to_coco_json(annotations) 271 | for split_name in ['train', 'val', 'trainval', 'test']: 272 | dataset_name = 'instances' + '_' + split_name + '2017' 273 | out_dir = osp.join(CoCo_output_dir, "annotations") 274 | os.makedirs(out_dir, exist_ok=True) 275 | print(f'processing {dataset_name} ...') 276 | 277 | mmcv.dump(annotations_json, osp.join(out_dir, dataset_name + '.json')) 278 | print("Generating labels done.") 279 | 280 | 281 | def parse_args(): 282 | parser = argparse.ArgumentParser(description="Convert VOC dataset to COCO type.") 283 | parser.add_argument('--voc_type', default="VOC2012", help='voc dataset type') 284 | parser.add_argument('-n', '--num', type=int, default=2000, help='number of examples') 285 | parser.add_argument('-s', '--img_size', type=int, default=800, help='image size after processing') 286 | return parser.parse_args() 287 | 288 | 289 | if __name__ == '__main__': 290 | print(os.getcwd()) 291 | args = parse_args() 292 | img_size = args.img_size 293 | VOC_type = f"{args.voc_type}_{args.num}" 294 | VOCdevkit_dir = "./data/VOCdevkit" 295 | CoCo_output_dir = f"./data/Voc12_CoCo_{img_size}_{args.num}" 296 | 297 | os.makedirs(CoCo_output_dir, exist_ok=True) 298 | 299 | # Get the list of images 300 | split_list = mmcv.list_from_file(osp.join(VOCdevkit_dir, VOC_type, "ImageSets/Main/trainval.txt")) 301 | 302 | # Generate imgs 303 | generateImg(imgs_list=[number + ".jpg" for number in split_list], size=img_size) 304 | 305 | # Generate labels 306 | generateLabel(split_list, size=img_size) 307 | 308 | -------------------------------------------------------------------------------- /tools/project/base_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import codecs 5 | import random 6 | import numpy as np 7 | from tools.utils import get_utc8_time 8 | 9 | 10 | class BaseConfig: 11 | def __init__(self, path: str) -> None: 12 | if not path: 13 | raise ValueError('Please specify the configuration file path.') 14 | if not os.path.exists(path): 15 | raise FileNotFoundError('File {} does not exist'.format(path)) 16 | if path.endswith('yml') or path.endswith('yaml'): 17 | self._dic = self._parse_from_yaml(path) 18 | else: 19 | raise RuntimeError('Config file should be in yaml format!') 20 | self._timestamp = get_utc8_time() 21 | self._setting_env() 22 | self._setting_seed() 23 | 24 | def _setting_env(self): 25 | os.environ["timestamp"] = self._timestamp 26 | os.environ["project_path"] = self._dic["global"]["project_path"] 27 | 28 | def _setting_seed(self): 29 | seed = self._dic["global"]["seed"] 30 | if seed is None: 31 | seed = 0 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | np.random.seed(seed) 36 | random.seed(seed) 37 | torch.backends.cudnn.deterministic = True 38 | 39 | @staticmethod 40 | def _parse_from_yaml(path: str, full=True): 41 | """Parse a yaml file and build config""" 42 | with codecs.open(path, 'r', 'utf-8') as file: 43 | dic = yaml.load(file, Loader=yaml.FullLoader) 44 | if full and 'base' in dic: 45 | project_dir = dic["global"]["project_path"] 46 | base_path = dic.pop('base') 47 | base_path = os.path.join(project_dir, base_path) 48 | base_dic = BaseConfig._parse_from_yaml(base_path) 49 | dic = BaseConfig._update_dic(dic, base_dic) 50 | return dic 51 | 52 | @staticmethod 53 | def _update_dic(dic, base_dic): 54 | """Recursively merge dic into base_dic; values from dic take precedence""" 55 | base_dic = base_dic.copy() 56 | for key, val in dic.items(): 57 | if isinstance(val, dict) and key in base_dic and base_dic.get(key) is not None: 58 | base_dic[key] = BaseConfig._update_dic(val, base_dic[key]) 59 | else: 60 | base_dic[key] = val 61 | dic = base_dic 62 | return dic 63 | 64 | def __str__(self): 65 | return yaml.dump(self._dic) 66 | -------------------------------------------------------------------------------- /tools/project/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import mmcv
5 | 6 | from tools.project.base_config import BaseConfig 7 | 8 | 9 | class ConfigYaml(BaseConfig): 10 | def __init__(self, path: str) -> None: 11 | super(ConfigYaml, self).__init__(path) 12 | os.environ["device"] = self.device 13 | 14 | @staticmethod 15 | def build(path): 16 | yaml_config = ConfigYaml(path) 17 | # Init Config 18 | _config_dict = dict() 19 | for attribute in yaml_config.__dir__(): 20 | if "_" != attribute[0] and "build" != attribute: 21 | _config_dict[attribute] = getattr(yaml_config, attribute) 22 | return mmcv.Config(_config_dict) 23 | 24 | @property 25 | def options(self): 26 | return self._dic.get("options") 27 | 28 | @property 29 | def saving_settings(self): 30 | return self._dic.get("saving_settings") 31 | 32 | @property 33 | def saving_dir(self): 34 | return self.saving_settings.get("saving_dir") 35 | 36 | @property 37 | def debug_mode(self): 38 | return self._dic.get("options").get("debug") 39 | 40 | @property 41 | def project_path(self): 42 | project_path = self._dic.get("global").get("project_path") 43 | os.environ["project_path"] = project_path 44 | return project_path 45 | 46 | @property 47 | def device(self): 48 | device = self._dic.get("global").get("device") 49 | return device 50 | 51 | @property 52 | def buffer(self): 53 | buffer_type = self._dic.get("global").get("buffer") 54 | noise_type = buffer_type.get("noise", buffer_type["global"]) 55 | buffer_type["noise_current"] = noise_type 56 | buffer_type["best_current"] = noise_type 57 | return buffer_type 58 | 59 | @property 60 | def attack_cfg(self): 61 | return self._dic.get("attack") 62 | 63 | @property 64 | def attack_base(self): 65 | return self.attack_cfg.get("method").get("base_attack") 66 | 67 | @property 68 | def attack_transfer(self): 69 | return self.attack_cfg.get("method").get("transfer_attack") 70 | 71 | @property 72 | def default_cfg(self): 73 | default_cfg_fp = osp.join(self.project_path, self.attack_cfg.get("default_cfg")) 74 | default_cfg = mmcv.Config.fromfile(default_cfg_fp) 75 | # update default_cfg 76 | method_settings = self.attack_cfg.get("method").get("method_settings") 77 | if method_settings is not None: 78 | for base_method in self.attack_base: 79 | if method_settings.get(base_method) is not None: 80 | default_cfg.base_attack[base_method].update(method_settings.get(base_method)) 81 | if method_settings.get(self.attack_transfer) is not None: 82 | default_cfg.transfer_attack[self.attack_transfer].update(method_settings.get(self.attack_transfer)) 83 | return default_cfg 84 | 85 | @property 86 | def dataloader_cfg(self): 87 | return self.attack_cfg.get("dataloader") 88 | 89 | @property 90 | def dataset_cfg(self): 91 | return self.attack_cfg.get("dataset") 92 | 93 | @property 94 | def eval_cfg(self): 95 | return self.attack_cfg.get("eval_cfg") 96 | 97 | @property 98 | def source_model_name(self): 99 | return self.attack_cfg.get("source") 100 | 101 | @property 102 | def source_model_cfg(self): 103 | cfg_dir = osp.join(self.project_path, self._dic.get("models").get("train_cfg_dir")) 104 | ckpt_dir = osp.join(self.project_path, self._dic.get("models").get("models_dir")) 105 | for model in self._dic.get("models").get("detectors"): 106 | model_name = model.get("name") 107 | if model_name == self.source_model_name: 108 | ckpt = model.get("ckpt") 109 | model_cfg_fp = osp.join(cfg_dir, model_name + ".py") 110 | model_checkpoint_fp = osp.join(ckpt_dir, ckpt + ".pth") 111 | return model_cfg_fp, model_checkpoint_fp 112 | 113 | @property 114 | def models_zoo(self): 115 | model_dict = dict() 116 | 
cfg_dir = osp.join(self.project_path, self._dic.get("models").get("eval_cfg_dir")) 117 | ckpt_dir = osp.join(self.project_path, self._dic.get("models").get("models_dir")) 118 | for model in self._dic.get("models").get("detectors"): 119 | model_name = model.get("name") 120 | ckpt = model.get("ckpt") 121 | model_cfg_fp = osp.join(cfg_dir, model_name + ".py") 122 | model_checkpoint_fp = osp.join(ckpt_dir, ckpt + ".pth") 123 | model_dict[model_name] = (model_cfg_fp, model_checkpoint_fp) 124 | return model_dict 125 | -------------------------------------------------------------------------------- /tools/project/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mmcv 3 | import torch 4 | import shutil 5 | import logging 6 | import traceback 7 | import os.path as osp 8 | import torchvision.utils as vutils 9 | from mmcv.parallel import scatter 10 | from attack.Attack import Attack 11 | from attack.utils.mmdet import parse_data, get_normalize_tools 12 | from tools.utils import get_file_name 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | 16 | class Logger: 17 | 18 | all_base_methods = ["MI", "DI", "RRB", "IFGSM"] 19 | 20 | def __init__(self, yaml_cfg, yaml_cfg_path) -> None: 21 | self.attack_logger = None 22 | self.eval_logger = None 23 | self.verbose_logger = None 24 | self.tb_writer = None 25 | 26 | self.config = yaml_cfg 27 | self.saving_path_manager = dict() 28 | 29 | if not self.config.debug_mode or "disk" in self.config.buffer.values(): 30 | self._init_saving_setting(yaml_cfg_path) 31 | 32 | # assist variables 33 | self.already_log_step_idx = 0 34 | self.already_log_clean = False 35 | 36 | # print timestamp 37 | mmcv.print_log("\nPID: " + str(os.getpid()), mmcv.get_logger("verbose_logger")) 38 | mmcv.print_log(os.environ.get("timestamp"), mmcv.get_logger("verbose_logger")) 39 | 40 | def _init_saving_setting(self, yaml_cfg_path): 41 | saving_settings = self.config.saving_settings 42 | project_path = self.config.project_path 43 | 44 | base_methods = self.config.attack_cfg.get("method").get("base_attack") 45 | base_method_name = [] 46 | for m in self.all_base_methods: 47 | if m in base_methods: 48 | base_method_name.append(m) 49 | base_method_name = ''.join([x[0] for x in base_method_name]) 50 | perturb_params = '_' + str(self.config.attack_cfg.get("epsilon")) + 'p' 51 | 52 | tsf_method = self.config.attack_cfg.get("method").get("transfer_attack") 53 | 54 | saving_dir = osp.join(project_path, "data/results", 55 | self.config.saving_dir.get("prefix", ""), 56 | self.config.source_model_name, 57 | base_method_name + perturb_params, 58 | tsf_method, 59 | self.config.saving_dir.get("suffix", ""), 60 | os.environ.get("timestamp")) 61 | mmcv.mkdir_or_exist(saving_dir) 62 | self.saving_path_manager["global_dir"] = saving_dir 63 | 64 | # buffer 65 | if "disk" in self.config.buffer.values(): 66 | tmp_dir = osp.join(saving_dir, "tmp") 67 | mmcv.mkdir_or_exist(tmp_dir) 68 | os.environ["tmp_dir"] = tmp_dir 69 | if "disk" == self.config.buffer.get("noise", self.config.buffer["global"]): 70 | tmp_current_noise_buffer = osp.join(saving_dir, "tmp", "noise_current") 71 | tmp_best_noise_buffer = osp.join(saving_dir, "tmp", "noise_best") 72 | tmp_helper = osp.join(saving_dir, "tmp", "helper") 73 | 74 | mmcv.mkdir_or_exist(tmp_current_noise_buffer) 75 | mmcv.mkdir_or_exist(tmp_best_noise_buffer) 76 | 77 | os.environ["tmp_current_noise_buffer"] = tmp_current_noise_buffer 78 | os.environ["tmp_best_noise_buffer"] = tmp_best_noise_buffer 
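# The three directories implement a cheap swap scheme: "noise_current" holds the perturbations of the running epoch, "noise_best" the best-so-far, and "helper" is scratch space that RecorderDisk.record_best uses to rotate the other two via os.rename instead of copying tensors.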
79 | os.environ["tmp_helper"] = tmp_helper 80 | 81 | if not self.config.debug_mode: 82 | # config file backup dir 83 | config_dir = osp.join(saving_dir, "config") 84 | mmcv.mkdir_or_exist(config_dir) 85 | self._config_backup(config_dir, yaml_cfg_path) 86 | 87 | # log file dir 88 | if saving_settings.get("logging"): 89 | log_dir = osp.join(saving_dir, "log") 90 | mmcv.mkdir_or_exist(log_dir) 91 | self._init_logger(log_dir) 92 | 93 | # tensorboard dir 94 | if saving_settings.get("tboard"): 95 | tensorboard_dir = osp.join(saving_dir, "tensorboard") 96 | mmcv.mkdir_or_exist(tensorboard_dir) 97 | self._init_tensorboard(tensorboard_dir) 98 | 99 | saving_best_white = saving_settings.get("best_white") 100 | saving_best_black = saving_settings.get("best_black") 101 | 102 | # noise dir 103 | if saving_settings.get("noise"): 104 | noise_dir = osp.join(saving_dir, "noise") 105 | self.saving_path_manager["noise_dir"] = noise_dir 106 | if saving_best_white: 107 | best_white = osp.join(noise_dir, "best_white") 108 | mmcv.mkdir_or_exist(best_white) 109 | if saving_best_black: 110 | best_black = osp.join(noise_dir, "best_black") 111 | mmcv.mkdir_or_exist(best_black) 112 | 113 | # adv imgs dir 114 | if saving_settings.get("adv_img"): 115 | adv_imgs_dir = osp.join(saving_dir, "adv_imgs") 116 | self.saving_path_manager["adv_imgs_dir"] = adv_imgs_dir 117 | if saving_best_white: 118 | best_white = osp.join(adv_imgs_dir, "best_white") 119 | mmcv.mkdir_or_exist(best_white) 120 | if saving_best_black: 121 | best_black = osp.join(adv_imgs_dir, "best_black") 122 | mmcv.mkdir_or_exist(best_black) 123 | 124 | # others 125 | if saving_settings.get("others"): 126 | others_dir = osp.join(saving_dir, "others") 127 | self.saving_path_manager["others_dir"] = others_dir 128 | if saving_best_white: 129 | best_white = osp.join(others_dir, "best_white") 130 | mmcv.mkdir_or_exist(best_white) 131 | if saving_best_black: 132 | best_black = osp.join(others_dir, "best_black") 133 | mmcv.mkdir_or_exist(best_black) 134 | return 135 | 136 | def _init_logger(self, log_dir): 137 | """Init file loggers""" 138 | attack_logger_fp = osp.join(log_dir, "attack_log.txt") 139 | eval_logger_fp = osp.join(log_dir, "eval_log.txt") 140 | error_logger_fp = osp.join(log_dir, "error_log.txt") 141 | verbose_logger_fp = osp.join(log_dir, "verbose_log.txt") 142 | self.attack_logger = mmcv.get_logger("attack_logger", attack_logger_fp) 143 | self.eval_logger = mmcv.get_logger("eval_logger", eval_logger_fp) 144 | self.error_logger = mmcv.get_logger("error_logger", error_logger_fp) 145 | self.verbose_logger = mmcv.get_logger("verbose_logger", verbose_logger_fp) 146 | return 147 | 148 | def _config_backup(self, config_dir, yaml_cfg_path): 149 | """copy current configuration to results dir""" 150 | # backup base yaml config 151 | project_path = self.config.project_path 152 | shutil.copyfile(osp.join(project_path, "config/base.yaml"), 153 | osp.join(config_dir, "base.yaml")) 154 | # backup default config 155 | shutil.copyfile(osp.join(project_path, "config/default.py"), 156 | osp.join(config_dir, "default.py")) 157 | # backup attack yaml config 158 | shutil.copyfile(osp.join(project_path, yaml_cfg_path), 159 | osp.join(config_dir, osp.split(yaml_cfg_path)[-1])) 160 | # backup the white-box model's train config from ummdet 161 | shutil.copyfile(osp.join(project_path, "ummdet/checkpoints/train_cfg", self.config.source_model_name + ".py"), 162 | osp.join(config_dir, self.config.source_model_name + ".py")) 163 | return 164 | 165 | def _init_tensorboard(self, log_dir): 166 | """Init
TB""" 167 | self.tb_writer = SummaryWriter(log_dir=log_dir, flush_secs=60) 168 | 169 | def logging_clean_metric(self, recorder): 170 | if self.tb_writer is not None: 171 | """write clean metric to TB""" 172 | black_metric, white_metric = getattr(recorder, "_calculate_metric")(recorder.config.source_model_name, 173 | recorder.all_metric_dict[0]) 174 | self.tb_writer.add_scalar("metric_white", white_metric, 0) 175 | self.tb_writer.add_scalar("metric_black", black_metric, 0) 176 | 177 | def logging_epoch(self, epoch, recorder): 178 | epoch = epoch + 1 179 | accumulate_steps = recorder.config.attack_cfg.get("steps") * (epoch - 1) 180 | # record loss in step 181 | for step, loss in enumerate(recorder.loss_steps_one_epoch): 182 | if self.tb_writer is not None: 183 | self.tb_writer.add_scalar("loss_step", loss, accumulate_steps + step + 1) 184 | mmcv.print_log(f"step: {accumulate_steps + step + 1} loss: {loss}", logger=self.attack_logger, 185 | level=logging.INFO) 186 | 187 | # record loss in epoch 188 | if self.tb_writer is not None: 189 | self.tb_writer.add_scalar("loss_epoch", recorder.loss_epoch, accumulate_steps + 1) 190 | # record metric 191 | self.tb_writer.add_scalar("metric_white", recorder.white_metric, accumulate_steps + 1) 192 | self.tb_writer.add_scalar("metric_black", recorder.black_metric, accumulate_steps + 1) 193 | 194 | mmcv.print_log(f"epoch: {epoch} loss: {recorder.loss_epoch}\n" 195 | f"metric_white: {recorder.white_metric} metric_black: {recorder.black_metric}", 196 | logger=self.attack_logger, 197 | level=logging.INFO) 198 | 199 | def saving_results(self, recorder, attack: Attack, save_type="best_black"): 200 | """save all results from memory to disk""" 201 | # All results will be threw away in debug mode. 202 | if self.config.debug_mode: 203 | return 204 | saving_settings = self.config.saving_settings 205 | if not (saving_settings.get(save_type) is not None and saving_settings.get(save_type)): 206 | return 207 | 208 | # saving adversarial images and noise 209 | if saving_settings.get("adv_img") or saving_settings.get("noise"): 210 | dataloader = attack.dataloader 211 | noise_dir = "noise_current" if "best_white" == save_type else "noise_best" 212 | try: 213 | mmcv.print_log("Start saving adversarial images and noise ...", 214 | logger=self.verbose_logger, level=logging.INFO) 215 | normalizer, denormalizer = None, None 216 | for batch_idx, data in enumerate(mmcv.track_iter_progress(dataloader)): 217 | data = scatter(data, [0])[0] 218 | data_imgs, data_img_metas, _, _ = parse_data(data) 219 | data_noises = attack.buffer.load(noise_dir, str(batch_idx)) 220 | if saving_settings.get("adv_img"): 221 | if normalizer is None or denormalizer is None: 222 | normalizer, denormalizer = get_normalize_tools(data_img_metas[0]) 223 | data_clean_imgs = denormalizer(data_imgs) 224 | # Quantization 225 | data_adv_imgs = torch.clamp(torch.round(data_clean_imgs + data_noises), min=0., max=255.) 
226 | 227 | for idx, image_metas in enumerate(data_img_metas): 228 | image_name = get_file_name(image_metas.get("ori_filename"), with_ext=False) 229 | if saving_settings.get("noise"): 230 | torch.save(data_noises[idx].detach().cpu(), 231 | osp.join(self.saving_path_manager.get("noise_dir"), 232 | save_type, image_name + ".pth")) 233 | if saving_settings.get("adv_img"): 234 | vutils.save_image(data_adv_imgs[idx].detach().cpu(), 235 | osp.join(self.saving_path_manager.get("adv_imgs_dir"), 236 | f"{save_type}", 237 | image_name + ".png"), 238 | normalize=True, 239 | value_range=(0, 255)) 240 | 241 | mmcv.print_log("Saving adversarial images and noise done!", logger=self.verbose_logger, 242 | level=logging.INFO) 243 | except Exception as e: 244 | mmcv.print_log(str(e.args), logger=self.error_logger, 245 | level=logging.ERROR) 246 | mmcv.print_log(traceback.format_exc(), logger=self.error_logger, 247 | level=logging.ERROR) 248 | 249 | # saving others to pkl 250 | if saving_settings.get("others"): 251 | try: 252 | mmcv.print_log("Start saving others...", logger=self.verbose_logger, level=logging.INFO) 253 | others_dir = self.saving_path_manager.get("others_dir") 254 | mmcv.dump(recorder.loss_list, osp.join(others_dir, f"{save_type}", "loss_step.pkl")) 255 | mmcv.dump(recorder.all_metric_dict, osp.join(others_dir, f"{save_type}", "all_metric_dict.pkl")) 256 | best_params = dict() 257 | best_params["best_epoch"] = recorder.best_epoch 258 | best_params["best_white_metric"] = recorder.best_white_metric 259 | best_params["best_black_metric"] = recorder.best_black_metric 260 | best_params["best_metric_dict"] = recorder.best_metric_dict 261 | mmcv.dump(best_params, osp.join(others_dir, f"{save_type}", "best_params.pkl")) 262 | mmcv.dump(best_params, osp.join(others_dir, f"{save_type}", "best_params.yaml")) 263 | mmcv.print_log("Saving others done!", logger=self.verbose_logger, level=logging.INFO) 264 | except Exception as e: 265 | mmcv.print_log(str(e.args), logger=self.error_logger, 266 | level=logging.ERROR) 267 | mmcv.print_log(traceback.format_exc(), logger=self.error_logger, 268 | level=logging.ERROR) 269 | return 270 | 271 | def close_logger(self): 272 | # Close the TensorBoard writer 273 | if self.tb_writer is not None: 274 | self.tb_writer.close() 275 | # clear tmp dir 276 | if "disk" in self.config.buffer.values(): 277 | shutil.rmtree(os.environ["tmp_dir"]) 278 | -------------------------------------------------------------------------------- /tools/project/recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import numpy as np 4 | 5 | 6 | class Recorder(object): 7 | 8 | def __init__(self, args) -> None: 9 | self.config = args 10 | # best params 11 | self.best_epoch = 0 12 | self.best_white_metric = None 13 | self.best_black_metric = None 14 | self.best_metric_dict = None 15 | # current 16 | self.metric_dict = None 17 | self.noise_dict = None 18 | self.white_metric = None 19 | self.black_metric = None 20 | self.loss_epoch = None 21 | self.loss_steps_one_epoch = None 22 | # saving to pkl 23 | self.loss_list = [] 24 | # key: epoch, value: eval_dict 25 | self.all_metric_dict = dict() 26 | os.environ["noise_dir"] = "noise_current" 27 | 28 | def _record_best(self, epoch): 29 | self.best_metric_dict = copy.deepcopy(self.metric_dict) 30 | self.best_epoch = epoch 31 | self.best_white_metric = self.white_metric 32 | self.best_black_metric = self.black_metric 33 | 34 | def _update_epoch(self, epoch, loss_step_list, metric_dict,
samples_num, **kwargs): 35 | self.metric_dict = metric_dict 36 | # record loss 37 | loss_sum_epoch_ndarray = np.array(loss_step_list) 38 | loss_mean_epoch_ndarray = loss_sum_epoch_ndarray / samples_num 39 | self.loss_epoch = loss_mean_epoch_ndarray.mean() 40 | self.loss_steps_one_epoch = loss_mean_epoch_ndarray.tolist() 41 | self.loss_list.extend(self.loss_steps_one_epoch) 42 | # record metric 43 | self.white_metric, self.black_metric = self._calculate_metric(self.config.source_model_name, metric_dict) 44 | # record all metrics 45 | self.all_metric_dict[epoch] = metric_dict 46 | 47 | def record_clean(self, clean_metric_dict): 48 | self.all_metric_dict[0] = clean_metric_dict 49 | 50 | @staticmethod 51 | def _calculate_metric(source_name, metric_dict): 52 | tmp_dict = dict() 53 | white_metric = metric_dict.get(source_name, tmp_dict).get("bbox_mAP") 54 | black_metric = 0 55 | for name, metric in metric_dict.items(): 56 | if name != source_name: 57 | black_metric += metric.get("bbox_mAP") 58 | black_metric /= (len(metric_dict) - 1) 59 | return white_metric, black_metric 60 | 61 | 62 | class RecorderMemory(Recorder): 63 | 64 | def __init__(self, args) -> None: 65 | super(RecorderMemory, self).__init__(args) 66 | 67 | def record_best(self, epoch, buffer): 68 | noise_dict = buffer.buffer_dict["noise_current"] 69 | buffer.buffer_dict.pop("noise_best", None) 70 | buffer.buffer_dict["noise_best"] = copy.deepcopy(noise_dict) 71 | self._record_best(epoch) 72 | return 73 | 74 | def update_epoch(self, epoch, buffer, loss_step_list, metric_dict, samples_num, **kwargs): 75 | self._update_epoch(epoch, loss_step_list, metric_dict, samples_num, **kwargs) 76 | # Update the optimal solution 77 | if self.best_black_metric is None or self.black_metric < self.best_black_metric: 78 | self.record_best(epoch, buffer) 79 | 80 | 81 | class RecorderDisk(Recorder): 82 | 83 | def __init__(self, args) -> None: 84 | super(RecorderDisk, self).__init__(args) 85 | 86 | def record_best(self, epoch): 87 | best_dir = os.environ.get("tmp_best_noise_buffer") 88 | current_dir = os.environ.get("tmp_current_noise_buffer") 89 | helper_dir = os.environ.get("tmp_helper") 90 | # stale best -> helper 91 | os.rename(best_dir, helper_dir) 92 | # current -> best 93 | os.rename(current_dir, best_dir) 94 | # helper -> current 95 | os.rename(helper_dir, current_dir) 96 | self._record_best(epoch) 97 | 98 | def update_epoch(self, epoch, loss_step_list, metric_dict, samples_num, **kwargs): 99 | self._update_epoch(epoch, loss_step_list, metric_dict, samples_num, **kwargs) 100 | # Update the optimal solution 101 | if self.best_black_metric is None or self.black_metric < self.best_black_metric: 102 | self.record_best(epoch) 103 | # make sure that the noise loaded for the next epoch is the best one 104 | os.environ["noise_dir"] = "noise_best" 105 | else: 106 | os.environ["noise_dir"] = "noise_current" 107 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from datetime import datetime, timedelta 3 | 4 | 5 | def get_utc8_time(): 6 | utc8_now = datetime.utcnow() + timedelta(hours=8)  # shift to UTC+8, as the function name promises 7 | return utc8_now.strftime('%Y-%m-%d-%H-%M-%f') 8 | 9 | 10 | def get_file_name(fp, with_ext=True): 11 | if with_ext: 12 | return osp.split(fp)[-1] 13 | return osp.splitext(osp.split(fp)[-1])[0] 14 | 15 | -------------------------------------------------------------------------------- /ummdet/checkpoints/eval_cfg/detr_r50_8x2_150e_coco.py:
-------------------------------------------------------------------------------- 1 | dataset_type = 'CocoDataset' 2 | data_root = 'data/coco/' 3 | img_norm_cfg = dict( 4 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | dict(type='RandomFlip', flip_ratio=0.5), 9 | dict( 10 | type='AutoAugment', 11 | policies=[[{ 12 | 'type': 13 | 'Resize', 14 | 'img_scale': [(480, 1333), (512, 1333), (544, 1333), (576, 1333), 15 | (608, 1333), (640, 1333), (672, 1333), (704, 1333), 16 | (736, 1333), (768, 1333), (800, 1333)], 17 | 'multiscale_mode': 18 | 'value', 19 | 'keep_ratio': 20 | True 21 | }], 22 | [{ 23 | 'type': 'Resize', 24 | 'img_scale': [(400, 1333), (500, 1333), (600, 1333)], 25 | 'multiscale_mode': 'value', 26 | 'keep_ratio': True 27 | }, { 28 | 'type': 'RandomCrop', 29 | 'crop_type': 'absolute_range', 30 | 'crop_size': (384, 600), 31 | 'allow_negative_crop': True 32 | }, { 33 | 'type': 34 | 'Resize', 35 | 'img_scale': [(480, 1333), (512, 1333), (544, 1333), 36 | (576, 1333), (608, 1333), (640, 1333), 37 | (672, 1333), (704, 1333), (736, 1333), 38 | (768, 1333), (800, 1333)], 39 | 'multiscale_mode': 40 | 'value', 41 | 'override': 42 | True, 43 | 'keep_ratio': 44 | True 45 | }]]), 46 | dict( 47 | type='Normalize', 48 | mean=[123.675, 116.28, 103.53], 49 | std=[58.395, 57.12, 57.375], 50 | to_rgb=True), 51 | dict(type='Pad', size_divisor=1), 52 | dict(type='DefaultFormatBundle'), 53 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 54 | ] 55 | test_pipeline = [ 56 | dict(type='LoadImageFromFile'), 57 | dict( 58 | type='MultiScaleFlipAug', 59 | img_scale=(1333, 800), 60 | flip=False, 61 | transforms=[ 62 | dict(type='Resize', keep_ratio=True), 63 | dict(type='RandomFlip'), 64 | dict( 65 | type='Normalize', 66 | mean=[123.675, 116.28, 103.53], 67 | std=[58.395, 57.12, 57.375], 68 | to_rgb=True), 69 | dict(type='Pad', size_divisor=1), 70 | dict(type='ImageToTensor', keys=['img']), 71 | dict(type='Collect', keys=['img']) 72 | ]) 73 | ] 74 | data = dict( 75 | samples_per_gpu=2, 76 | workers_per_gpu=2, 77 | train=dict( 78 | type='CocoDataset', 79 | ann_file='data/coco/annotations/instances_train2017.json', 80 | img_prefix='data/coco/train2017/', 81 | pipeline=[ 82 | dict(type='LoadImageFromFile'), 83 | dict(type='LoadAnnotations', with_bbox=True), 84 | dict(type='RandomFlip', flip_ratio=0.5), 85 | dict( 86 | type='AutoAugment', 87 | policies=[[{ 88 | 'type': 89 | 'Resize', 90 | 'img_scale': [(480, 1333), (512, 1333), (544, 1333), 91 | (576, 1333), (608, 1333), (640, 1333), 92 | (672, 1333), (704, 1333), (736, 1333), 93 | (768, 1333), (800, 1333)], 94 | 'multiscale_mode': 95 | 'value', 96 | 'keep_ratio': 97 | True 98 | }], 99 | [{ 100 | 'type': 'Resize', 101 | 'img_scale': [(400, 1333), (500, 1333), 102 | (600, 1333)], 103 | 'multiscale_mode': 'value', 104 | 'keep_ratio': True 105 | }, { 106 | 'type': 'RandomCrop', 107 | 'crop_type': 'absolute_range', 108 | 'crop_size': (384, 600), 109 | 'allow_negative_crop': True 110 | }, { 111 | 'type': 112 | 'Resize', 113 | 'img_scale': [(480, 1333), (512, 1333), 114 | (544, 1333), (576, 1333), 115 | (608, 1333), (640, 1333), 116 | (672, 1333), (704, 1333), 117 | (736, 1333), (768, 1333), 118 | (800, 1333)], 119 | 'multiscale_mode': 120 | 'value', 121 | 'override': 122 | True, 123 | 'keep_ratio': 124 | True 125 | }]]), 126 | dict( 127 | type='Normalize', 128 | mean=[123.675, 116.28, 103.53], 129 | std=[58.395, 
57.12, 57.375], 130 | to_rgb=True), 131 | dict(type='Pad', size_divisor=1), 132 | dict(type='DefaultFormatBundle'), 133 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 134 | ]), 135 | val=dict( 136 | type='CocoDataset', 137 | ann_file='data/coco/annotations/instances_val2017.json', 138 | img_prefix='data/coco/val2017/', 139 | pipeline=[ 140 | dict(type='LoadImageFromFile'), 141 | dict( 142 | type='MultiScaleFlipAug', 143 | img_scale=(1333, 800), 144 | flip=False, 145 | transforms=[ 146 | dict(type='Resize', keep_ratio=True), 147 | dict(type='RandomFlip'), 148 | dict( 149 | type='Normalize', 150 | mean=[123.675, 116.28, 103.53], 151 | std=[58.395, 57.12, 57.375], 152 | to_rgb=True), 153 | dict(type='Pad', size_divisor=1), 154 | dict(type='ImageToTensor', keys=['img']), 155 | dict(type='Collect', keys=['img']) 156 | ]) 157 | ]), 158 | test=dict( 159 | type='CocoDataset', 160 | ann_file='data/Voc12_CoCo_800/annotations/instances_train2017.json', 161 | img_prefix='data/Voc12_CoCo_800/train2017/', 162 | pipeline=[ 163 | dict(type='LoadImageFromFile'), 164 | dict( 165 | type='MultiScaleFlipAug', 166 | img_scale=(1333, 800), 167 | flip=False, 168 | transforms=[ 169 | dict(type='Resize', keep_ratio=True), 170 | dict(type='RandomFlip'), 171 | dict( 172 | type='Normalize', 173 | mean=[123.675, 116.28, 103.53], 174 | std=[58.395, 57.12, 57.375], 175 | to_rgb=True), 176 | dict(type='Pad', size_divisor=1), 177 | dict(type='ImageToTensor', keys=['img']), 178 | dict(type='Collect', keys=['img']) 179 | ]) 180 | ])) 181 | evaluation = dict(interval=1, metric='bbox') 182 | checkpoint_config = dict(interval=1) 183 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 184 | custom_hooks = [dict(type='NumClassCheckHook')] 185 | dist_params = dict(backend='nccl') 186 | log_level = 'INFO' 187 | load_from = None 188 | resume_from = None 189 | workflow = [('train', 1)] 190 | opencv_num_threads = 0 191 | mp_start_method = 'fork' 192 | auto_scale_lr = dict(enable=False, base_batch_size=16) 193 | model = dict( 194 | type='DETR', 195 | backbone=dict( 196 | type='ResNet', 197 | depth=50, 198 | num_stages=4, 199 | out_indices=(3, ), 200 | frozen_stages=1, 201 | norm_cfg=dict(type='BN', requires_grad=False), 202 | norm_eval=True, 203 | style='pytorch', 204 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), 205 | bbox_head=dict( 206 | type='DETRHead', 207 | num_classes=80, 208 | in_channels=2048, 209 | transformer=dict( 210 | type='Transformer', 211 | encoder=dict( 212 | type='DetrTransformerEncoder', 213 | num_layers=6, 214 | transformerlayers=dict( 215 | type='BaseTransformerLayer', 216 | attn_cfgs=[ 217 | dict( 218 | type='MultiheadAttention', 219 | embed_dims=256, 220 | num_heads=8, 221 | dropout=0.1) 222 | ], 223 | feedforward_channels=2048, 224 | ffn_dropout=0.1, 225 | operation_order=('self_attn', 'norm', 'ffn', 'norm'))), 226 | decoder=dict( 227 | type='DetrTransformerDecoder', 228 | return_intermediate=True, 229 | num_layers=6, 230 | transformerlayers=dict( 231 | type='DetrTransformerDecoderLayer', 232 | attn_cfgs=dict( 233 | type='MultiheadAttention', 234 | embed_dims=256, 235 | num_heads=8, 236 | dropout=0.1), 237 | feedforward_channels=2048, 238 | ffn_dropout=0.1, 239 | operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 240 | 'ffn', 'norm')))), 241 | positional_encoding=dict( 242 | type='SinePositionalEncoding', num_feats=128, normalize=True), 243 | loss_cls=dict( 244 | type='CrossEntropyLoss', 245 | bg_cls_weight=0.1, 246 | use_sigmoid=False, 247 
| loss_weight=1.0, 248 | class_weight=1.0), 249 | loss_bbox=dict(type='L1Loss', loss_weight=5.0), 250 | loss_iou=dict(type='GIoULoss', loss_weight=2.0)), 251 | train_cfg=dict( 252 | assigner=dict( 253 | type='HungarianAssigner', 254 | cls_cost=dict(type='ClassificationCost', weight=1.0), 255 | reg_cost=dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'), 256 | iou_cost=dict(type='IoUCost', iou_mode='giou', weight=2.0))), 257 | test_cfg=dict(max_per_img=100)) 258 | optimizer = dict( 259 | type='AdamW', 260 | lr=0.0001, 261 | weight_decay=0.0001, 262 | paramwise_cfg=dict( 263 | custom_keys=dict(backbone=dict(lr_mult=0.1, decay_mult=1.0)))) 264 | optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2)) 265 | lr_config = dict(policy='step', step=[100]) 266 | runner = dict(type='EpochBasedRunner', max_epochs=150) 267 | -------------------------------------------------------------------------------- /ummdet/checkpoints/eval_cfg/faster_rcnn_r101_caffe_fpn_mstrain_3x_coco.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1) 2 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 3 | custom_hooks = [dict(type='NumClassCheckHook')] 4 | dist_params = dict(backend='nccl') 5 | log_level = 'INFO' 6 | load_from = None 7 | resume_from = None 8 | workflow = [('train', 1)] 9 | opencv_num_threads = 0 10 | mp_start_method = 'fork' 11 | auto_scale_lr = dict(enable=False, base_batch_size=16) 12 | dataset_type = 'CocoDataset' 13 | data_root = 'data/coco/' 14 | img_norm_cfg = dict( 15 | mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 16 | train_pipeline = [ 17 | dict(type='LoadImageFromFile'), 18 | dict(type='LoadAnnotations', with_bbox=True), 19 | dict( 20 | type='Resize', 21 | img_scale=[(1333, 640), (1333, 800)], 22 | multiscale_mode='range', 23 | keep_ratio=True), 24 | dict(type='RandomFlip', flip_ratio=0.5), 25 | dict( 26 | type='Normalize', 27 | mean=[103.53, 116.28, 123.675], 28 | std=[1.0, 1.0, 1.0], 29 | to_rgb=False), 30 | dict(type='Pad', size_divisor=32), 31 | dict(type='DefaultFormatBundle'), 32 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 33 | ] 34 | test_pipeline = [ 35 | dict(type='LoadImageFromFile'), 36 | dict( 37 | type='MultiScaleFlipAug', 38 | img_scale=(1333, 800), 39 | flip=False, 40 | transforms=[ 41 | dict(type='Resize', keep_ratio=True), 42 | dict(type='RandomFlip'), 43 | dict( 44 | type='Normalize', 45 | mean=[103.53, 116.28, 123.675], 46 | std=[1.0, 1.0, 1.0], 47 | to_rgb=False), 48 | dict(type='Pad', size_divisor=32), 49 | dict(type='ImageToTensor', keys=['img']), 50 | dict(type='Collect', keys=['img']) 51 | ]) 52 | ] 53 | data = dict( 54 | samples_per_gpu=2, 55 | workers_per_gpu=2, 56 | train=dict( 57 | type='RepeatDataset', 58 | times=3, 59 | dataset=dict( 60 | type='CocoDataset', 61 | ann_file='data/coco/annotations/instances_train2017.json', 62 | img_prefix='data/coco/train2017/', 63 | pipeline=[ 64 | dict(type='LoadImageFromFile'), 65 | dict(type='LoadAnnotations', with_bbox=True), 66 | dict( 67 | type='Resize', 68 | img_scale=[(1333, 640), (1333, 800)], 69 | multiscale_mode='range', 70 | keep_ratio=True), 71 | dict(type='RandomFlip', flip_ratio=0.5), 72 | dict( 73 | type='Normalize', 74 | mean=[103.53, 116.28, 123.675], 75 | std=[1.0, 1.0, 1.0], 76 | to_rgb=False), 77 | dict(type='Pad', size_divisor=32), 78 | dict(type='DefaultFormatBundle'), 79 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 80 | ])), 81 | val=dict( 82 | 
type='CocoDataset', 83 | ann_file='data/coco/annotations/instances_val2017.json', 84 | img_prefix='data/coco/val2017/', 85 | pipeline=[ 86 | dict(type='LoadImageFromFile'), 87 | dict( 88 | type='MultiScaleFlipAug', 89 | img_scale=(1333, 800), 90 | flip=False, 91 | transforms=[ 92 | dict(type='Resize', keep_ratio=True), 93 | dict(type='RandomFlip'), 94 | dict( 95 | type='Normalize', 96 | mean=[103.53, 116.28, 123.675], 97 | std=[1.0, 1.0, 1.0], 98 | to_rgb=False), 99 | dict(type='Pad', size_divisor=32), 100 | dict(type='ImageToTensor', keys=['img']), 101 | dict(type='Collect', keys=['img']) 102 | ]) 103 | ]), 104 | test=dict( 105 | type='CocoDataset', 106 | ann_file='data/Voc12_CoCo_800/annotations/instances_train2017.json', 107 | img_prefix='data/Voc12_CoCo_800/train2017/', 108 | pipeline=[ 109 | dict(type='LoadImageFromFile'), 110 | dict( 111 | type='MultiScaleFlipAug', 112 | img_scale=(1333, 800), 113 | flip=False, 114 | transforms=[ 115 | dict(type='Resize', keep_ratio=True), 116 | dict(type='RandomFlip'), 117 | dict( 118 | type='Normalize', 119 | mean=[103.53, 116.28, 123.675], 120 | std=[1.0, 1.0, 1.0], 121 | to_rgb=False), 122 | dict(type='Pad', size_divisor=32), 123 | dict(type='ImageToTensor', keys=['img']), 124 | dict(type='Collect', keys=['img']) 125 | ]) 126 | ])) 127 | evaluation = dict(interval=1, metric='bbox') 128 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 129 | optimizer_config = dict(grad_clip=None) 130 | lr_config = dict( 131 | policy='step', 132 | warmup='linear', 133 | warmup_iters=500, 134 | warmup_ratio=0.001, 135 | step=[9, 11]) 136 | runner = dict(type='EpochBasedRunner', max_epochs=12) 137 | model = dict( 138 | type='FasterRCNN', 139 | backbone=dict( 140 | type='ResNet', 141 | depth=101, 142 | num_stages=4, 143 | out_indices=(0, 1, 2, 3), 144 | frozen_stages=1, 145 | norm_cfg=dict(type='BN', requires_grad=False), 146 | norm_eval=True, 147 | style='caffe', 148 | init_cfg=dict( 149 | type='Pretrained', 150 | checkpoint='open-mmlab://detectron2/resnet101_caffe')), 151 | neck=dict( 152 | type='FPN', 153 | in_channels=[256, 512, 1024, 2048], 154 | out_channels=256, 155 | num_outs=5), 156 | rpn_head=dict( 157 | type='RPNHead', 158 | in_channels=256, 159 | feat_channels=256, 160 | anchor_generator=dict( 161 | type='AnchorGenerator', 162 | scales=[8], 163 | ratios=[0.5, 1.0, 2.0], 164 | strides=[4, 8, 16, 32, 64]), 165 | bbox_coder=dict( 166 | type='DeltaXYWHBBoxCoder', 167 | target_means=[0.0, 0.0, 0.0, 0.0], 168 | target_stds=[1.0, 1.0, 1.0, 1.0]), 169 | loss_cls=dict( 170 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 171 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 172 | roi_head=dict( 173 | type='StandardRoIHead', 174 | bbox_roi_extractor=dict( 175 | type='SingleRoIExtractor', 176 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 177 | out_channels=256, 178 | featmap_strides=[4, 8, 16, 32]), 179 | bbox_head=dict( 180 | type='Shared2FCBBoxHead', 181 | in_channels=256, 182 | fc_out_channels=1024, 183 | roi_feat_size=7, 184 | num_classes=80, 185 | bbox_coder=dict( 186 | type='DeltaXYWHBBoxCoder', 187 | target_means=[0.0, 0.0, 0.0, 0.0], 188 | target_stds=[0.1, 0.1, 0.2, 0.2]), 189 | reg_class_agnostic=False, 190 | loss_cls=dict( 191 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 192 | loss_bbox=dict(type='L1Loss', loss_weight=1.0))), 193 | train_cfg=dict( 194 | rpn=dict( 195 | assigner=dict( 196 | type='MaxIoUAssigner', 197 | pos_iou_thr=0.7, 198 | neg_iou_thr=0.3, 199 | 
min_pos_iou=0.3, 200 | match_low_quality=True, 201 | ignore_iof_thr=-1), 202 | sampler=dict( 203 | type='RandomSampler', 204 | num=256, 205 | pos_fraction=0.5, 206 | neg_pos_ub=-1, 207 | add_gt_as_proposals=False), 208 | allowed_border=-1, 209 | pos_weight=-1, 210 | debug=False), 211 | rpn_proposal=dict( 212 | nms_pre=2000, 213 | max_per_img=1000, 214 | nms=dict(type='nms', iou_threshold=0.7), 215 | min_bbox_size=0), 216 | rcnn=dict( 217 | assigner=dict( 218 | type='MaxIoUAssigner', 219 | pos_iou_thr=0.5, 220 | neg_iou_thr=0.5, 221 | min_pos_iou=0.5, 222 | match_low_quality=False, 223 | ignore_iof_thr=-1), 224 | sampler=dict( 225 | type='RandomSampler', 226 | num=512, 227 | pos_fraction=0.25, 228 | neg_pos_ub=-1, 229 | add_gt_as_proposals=True), 230 | pos_weight=-1, 231 | debug=False)), 232 | test_cfg=dict( 233 | rpn=dict( 234 | nms_pre=1000, 235 | max_per_img=1000, 236 | nms=dict(type='nms', iou_threshold=0.7), 237 | min_bbox_size=0), 238 | rcnn=dict( 239 | score_thr=0.05, 240 | nms=dict(type='nms', iou_threshold=0.5), 241 | max_per_img=100))) 242 | -------------------------------------------------------------------------------- /ummdet/checkpoints/eval_cfg/fcos_r50_caffe_fpn_gn-head_mstrain_640-800_2x_coco.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'CocoDataset' 2 | data_root = 'data/coco/' 3 | img_norm_cfg = dict( 4 | mean=[102.9801, 115.9465, 122.7717], std=[1.0, 1.0, 1.0], to_rgb=False) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | dict( 9 | type='Resize', 10 | img_scale=[(1333, 640), (1333, 800)], 11 | multiscale_mode='value', 12 | keep_ratio=True), 13 | dict(type='RandomFlip', flip_ratio=0.5), 14 | dict( 15 | type='Normalize', 16 | mean=[102.9801, 115.9465, 122.7717], 17 | std=[1.0, 1.0, 1.0], 18 | to_rgb=False), 19 | dict(type='Pad', size_divisor=32), 20 | dict(type='DefaultFormatBundle'), 21 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 22 | ] 23 | test_pipeline = [ 24 | dict(type='LoadImageFromFile'), 25 | dict( 26 | type='MultiScaleFlipAug', 27 | img_scale=(1333, 800), 28 | flip=False, 29 | transforms=[ 30 | dict(type='Resize', keep_ratio=True), 31 | dict(type='RandomFlip'), 32 | dict( 33 | type='Normalize', 34 | mean=[102.9801, 115.9465, 122.7717], 35 | std=[1.0, 1.0, 1.0], 36 | to_rgb=False), 37 | dict(type='Pad', size_divisor=32), 38 | dict(type='ImageToTensor', keys=['img']), 39 | dict(type='Collect', keys=['img']) 40 | ]) 41 | ] 42 | data = dict( 43 | samples_per_gpu=2, 44 | workers_per_gpu=2, 45 | train=dict( 46 | type='CocoDataset', 47 | ann_file='data/coco/annotations/instances_train2017.json', 48 | img_prefix='data/coco/train2017/', 49 | pipeline=[ 50 | dict(type='LoadImageFromFile'), 51 | dict(type='LoadAnnotations', with_bbox=True), 52 | dict( 53 | type='Resize', 54 | img_scale=[(1333, 640), (1333, 800)], 55 | multiscale_mode='value', 56 | keep_ratio=True), 57 | dict(type='RandomFlip', flip_ratio=0.5), 58 | dict( 59 | type='Normalize', 60 | mean=[102.9801, 115.9465, 122.7717], 61 | std=[1.0, 1.0, 1.0], 62 | to_rgb=False), 63 | dict(type='Pad', size_divisor=32), 64 | dict(type='DefaultFormatBundle'), 65 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 66 | ]), 67 | val=dict( 68 | type='CocoDataset', 69 | ann_file='data/coco/annotations/instances_val2017.json', 70 | img_prefix='data/coco/val2017/', 71 | pipeline=[ 72 | dict(type='LoadImageFromFile'), 73 | dict( 74 | type='MultiScaleFlipAug', 75 | img_scale=(1333, 
800), 76 | flip=False, 77 | transforms=[ 78 | dict(type='Resize', keep_ratio=True), 79 | dict(type='RandomFlip'), 80 | dict( 81 | type='Normalize', 82 | mean=[102.9801, 115.9465, 122.7717], 83 | std=[1.0, 1.0, 1.0], 84 | to_rgb=False), 85 | dict(type='Pad', size_divisor=32), 86 | dict(type='ImageToTensor', keys=['img']), 87 | dict(type='Collect', keys=['img']) 88 | ]) 89 | ]), 90 | test=dict( 91 | type='CocoDataset', 92 | ann_file='data/Voc12_CoCo_800/annotations/instances_train2017.json', 93 | img_prefix='data/Voc12_CoCo_800/train2017/', 94 | pipeline=[ 95 | dict(type='LoadImageFromFile'), 96 | dict( 97 | type='MultiScaleFlipAug', 98 | img_scale=(1333, 800), 99 | flip=False, 100 | transforms=[ 101 | dict(type='Resize', keep_ratio=True), 102 | dict(type='RandomFlip'), 103 | dict( 104 | type='Normalize', 105 | mean=[102.9801, 115.9465, 122.7717], 106 | std=[1.0, 1.0, 1.0], 107 | to_rgb=False), 108 | dict(type='Pad', size_divisor=32), 109 | dict(type='ImageToTensor', keys=['img']), 110 | dict(type='Collect', keys=['img']) 111 | ]) 112 | ])) 113 | evaluation = dict(interval=1, metric='bbox') 114 | optimizer = dict( 115 | type='SGD', 116 | lr=0.01, 117 | momentum=0.9, 118 | weight_decay=0.0001, 119 | paramwise_cfg=dict(bias_lr_mult=2.0, bias_decay_mult=0.0)) 120 | optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) 121 | lr_config = dict( 122 | policy='step', 123 | warmup='constant', 124 | warmup_iters=500, 125 | warmup_ratio=0.3333333333333333, 126 | step=[16, 22]) 127 | runner = dict(type='EpochBasedRunner', max_epochs=24) 128 | checkpoint_config = dict(interval=1) 129 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 130 | custom_hooks = [dict(type='NumClassCheckHook')] 131 | dist_params = dict(backend='nccl') 132 | log_level = 'INFO' 133 | load_from = None 134 | resume_from = None 135 | workflow = [('train', 1)] 136 | opencv_num_threads = 0 137 | mp_start_method = 'fork' 138 | auto_scale_lr = dict(enable=False, base_batch_size=16) 139 | model = dict( 140 | type='FCOS', 141 | backbone=dict( 142 | type='ResNet', 143 | depth=50, 144 | num_stages=4, 145 | out_indices=(0, 1, 2, 3), 146 | frozen_stages=1, 147 | norm_cfg=dict(type='BN', requires_grad=False), 148 | norm_eval=True, 149 | style='caffe', 150 | init_cfg=dict( 151 | type='Pretrained', 152 | checkpoint='open-mmlab://detectron/resnet50_caffe')), 153 | neck=dict( 154 | type='FPN', 155 | in_channels=[256, 512, 1024, 2048], 156 | out_channels=256, 157 | start_level=1, 158 | add_extra_convs='on_output', 159 | num_outs=5, 160 | relu_before_extra_convs=True), 161 | bbox_head=dict( 162 | type='FCOSHead', 163 | num_classes=80, 164 | in_channels=256, 165 | stacked_convs=4, 166 | feat_channels=256, 167 | strides=[8, 16, 32, 64, 128], 168 | loss_cls=dict( 169 | type='FocalLoss', 170 | use_sigmoid=True, 171 | gamma=2.0, 172 | alpha=0.25, 173 | loss_weight=1.0), 174 | loss_bbox=dict(type='IoULoss', loss_weight=1.0), 175 | loss_centerness=dict( 176 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), 177 | train_cfg=dict( 178 | assigner=dict( 179 | type='MaxIoUAssigner', 180 | pos_iou_thr=0.5, 181 | neg_iou_thr=0.4, 182 | min_pos_iou=0, 183 | ignore_iof_thr=-1), 184 | allowed_border=-1, 185 | pos_weight=-1, 186 | debug=False), 187 | test_cfg=dict( 188 | nms_pre=1000, 189 | min_bbox_size=0, 190 | score_thr=0.05, 191 | nms=dict(type='nms', iou_threshold=0.5), 192 | max_per_img=100)) 193 | -------------------------------------------------------------------------------- 
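# All eval_cfg files in this directory follow the same pattern: stock mmdet train/val pipelines, with data.test repointed at the converted VOC-as-COCO subset under data/Voc12_CoCo_800. A minimal sketch of how such a config can be inspected, assuming only the standard mmcv 1.x Config API:
#
#   from mmcv import Config
#   cfg = Config.fromfile('ummdet/checkpoints/eval_cfg/fcos_r50_caffe_fpn_gn-head_mstrain_640-800_2x_coco.py')
#   print(cfg.data.test.ann_file)
#   # -> data/Voc12_CoCo_800/annotations/instances_train2017.json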
/ummdet/checkpoints/eval_cfg/mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='MaskRCNN', 3 | backbone=dict( 4 | type='SwinTransformer', 5 | embed_dims=96, 6 | depths=[2, 2, 6, 2], 7 | num_heads=[3, 6, 12, 24], 8 | window_size=7, 9 | mlp_ratio=4, 10 | qkv_bias=True, 11 | qk_scale=None, 12 | drop_rate=0.0, 13 | attn_drop_rate=0.0, 14 | drop_path_rate=0.2, 15 | patch_norm=True, 16 | out_indices=(0, 1, 2, 3), 17 | with_cp=False, 18 | convert_weights=True, 19 | init_cfg=dict( 20 | type='Pretrained', 21 | checkpoint= 22 | 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' 23 | )), 24 | neck=dict( 25 | type='FPN', 26 | in_channels=[96, 192, 384, 768], 27 | out_channels=256, 28 | num_outs=5), 29 | rpn_head=dict( 30 | type='RPNHead', 31 | in_channels=256, 32 | feat_channels=256, 33 | anchor_generator=dict( 34 | type='AnchorGenerator', 35 | scales=[8], 36 | ratios=[0.5, 1.0, 2.0], 37 | strides=[4, 8, 16, 32, 64]), 38 | bbox_coder=dict( 39 | type='DeltaXYWHBBoxCoder', 40 | target_means=[0.0, 0.0, 0.0, 0.0], 41 | target_stds=[1.0, 1.0, 1.0, 1.0]), 42 | loss_cls=dict( 43 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 44 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 45 | roi_head=dict( 46 | type='StandardRoIHead', 47 | bbox_roi_extractor=dict( 48 | type='SingleRoIExtractor', 49 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 50 | out_channels=256, 51 | featmap_strides=[4, 8, 16, 32]), 52 | bbox_head=dict( 53 | type='Shared2FCBBoxHead', 54 | in_channels=256, 55 | fc_out_channels=1024, 56 | roi_feat_size=7, 57 | num_classes=80, 58 | bbox_coder=dict( 59 | type='DeltaXYWHBBoxCoder', 60 | target_means=[0.0, 0.0, 0.0, 0.0], 61 | target_stds=[0.1, 0.1, 0.2, 0.2]), 62 | reg_class_agnostic=False, 63 | loss_cls=dict( 64 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 65 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 66 | mask_roi_extractor=dict( 67 | type='SingleRoIExtractor', 68 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 69 | out_channels=256, 70 | featmap_strides=[4, 8, 16, 32]), 71 | mask_head=dict( 72 | type='FCNMaskHead', 73 | num_convs=4, 74 | in_channels=256, 75 | conv_out_channels=256, 76 | num_classes=80, 77 | loss_mask=dict( 78 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), 79 | train_cfg=dict( 80 | rpn=dict( 81 | assigner=dict( 82 | type='MaxIoUAssigner', 83 | pos_iou_thr=0.7, 84 | neg_iou_thr=0.3, 85 | min_pos_iou=0.3, 86 | match_low_quality=True, 87 | ignore_iof_thr=-1), 88 | sampler=dict( 89 | type='RandomSampler', 90 | num=256, 91 | pos_fraction=0.5, 92 | neg_pos_ub=-1, 93 | add_gt_as_proposals=False), 94 | allowed_border=-1, 95 | pos_weight=-1, 96 | debug=False), 97 | rpn_proposal=dict( 98 | nms_pre=2000, 99 | max_per_img=1000, 100 | nms=dict(type='nms', iou_threshold=0.7), 101 | min_bbox_size=0), 102 | rcnn=dict( 103 | assigner=dict( 104 | type='MaxIoUAssigner', 105 | pos_iou_thr=0.5, 106 | neg_iou_thr=0.5, 107 | min_pos_iou=0.5, 108 | match_low_quality=True, 109 | ignore_iof_thr=-1), 110 | sampler=dict( 111 | type='RandomSampler', 112 | num=512, 113 | pos_fraction=0.25, 114 | neg_pos_ub=-1, 115 | add_gt_as_proposals=True), 116 | mask_size=28, 117 | pos_weight=-1, 118 | debug=False)), 119 | test_cfg=dict( 120 | rpn=dict( 121 | nms_pre=1000, 122 | max_per_img=1000, 123 | nms=dict(type='nms', iou_threshold=0.7), 124 | min_bbox_size=0), 125 | 
rcnn=dict( 126 | score_thr=0.05, 127 | nms=dict(type='nms', iou_threshold=0.5), 128 | max_per_img=100, 129 | mask_thr_binary=0.5))) 130 | dataset_type = 'CocoDataset' 131 | data_root = 'data/coco/' 132 | img_norm_cfg = dict( 133 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 134 | train_pipeline = [ 135 | dict(type='LoadImageFromFile'), 136 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 137 | dict(type='RandomFlip', flip_ratio=0.5), 138 | dict( 139 | type='AutoAugment', 140 | policies=[[{ 141 | 'type': 142 | 'Resize', 143 | 'img_scale': [(480, 1333), (512, 1333), (544, 1333), (576, 1333), 144 | (608, 1333), (640, 1333), (672, 1333), (704, 1333), 145 | (736, 1333), (768, 1333), (800, 1333)], 146 | 'multiscale_mode': 147 | 'value', 148 | 'keep_ratio': 149 | True 150 | }], 151 | [{ 152 | 'type': 'Resize', 153 | 'img_scale': [(400, 1333), (500, 1333), (600, 1333)], 154 | 'multiscale_mode': 'value', 155 | 'keep_ratio': True 156 | }, { 157 | 'type': 'RandomCrop', 158 | 'crop_type': 'absolute_range', 159 | 'crop_size': (384, 600), 160 | 'allow_negative_crop': True 161 | }, { 162 | 'type': 163 | 'Resize', 164 | 'img_scale': [(480, 1333), (512, 1333), (544, 1333), 165 | (576, 1333), (608, 1333), (640, 1333), 166 | (672, 1333), (704, 1333), (736, 1333), 167 | (768, 1333), (800, 1333)], 168 | 'multiscale_mode': 169 | 'value', 170 | 'override': 171 | True, 172 | 'keep_ratio': 173 | True 174 | }]]), 175 | dict( 176 | type='Normalize', 177 | mean=[123.675, 116.28, 103.53], 178 | std=[58.395, 57.12, 57.375], 179 | to_rgb=True), 180 | dict(type='Pad', size_divisor=32), 181 | dict(type='DefaultFormatBundle'), 182 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']) 183 | ] 184 | test_pipeline = [ 185 | dict(type='LoadImageFromFile'), 186 | dict( 187 | type='MultiScaleFlipAug', 188 | img_scale=(1333, 800), 189 | flip=False, 190 | transforms=[ 191 | dict(type='Resize', keep_ratio=True), 192 | dict(type='RandomFlip'), 193 | dict( 194 | type='Normalize', 195 | mean=[123.675, 116.28, 103.53], 196 | std=[58.395, 57.12, 57.375], 197 | to_rgb=True), 198 | dict(type='Pad', size_divisor=32), 199 | dict(type='ImageToTensor', keys=['img']), 200 | dict(type='Collect', keys=['img']) 201 | ]) 202 | ] 203 | data = dict( 204 | samples_per_gpu=2, 205 | workers_per_gpu=2, 206 | train=dict( 207 | type='CocoDataset', 208 | ann_file='data/coco/annotations/instances_train2017.json', 209 | img_prefix='data/coco/train2017/', 210 | pipeline=[ 211 | dict(type='LoadImageFromFile'), 212 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 213 | dict(type='RandomFlip', flip_ratio=0.5), 214 | dict( 215 | type='AutoAugment', 216 | policies=[[{ 217 | 'type': 218 | 'Resize', 219 | 'img_scale': [(480, 1333), (512, 1333), (544, 1333), 220 | (576, 1333), (608, 1333), (640, 1333), 221 | (672, 1333), (704, 1333), (736, 1333), 222 | (768, 1333), (800, 1333)], 223 | 'multiscale_mode': 224 | 'value', 225 | 'keep_ratio': 226 | True 227 | }], 228 | [{ 229 | 'type': 'Resize', 230 | 'img_scale': [(400, 1333), (500, 1333), 231 | (600, 1333)], 232 | 'multiscale_mode': 'value', 233 | 'keep_ratio': True 234 | }, { 235 | 'type': 'RandomCrop', 236 | 'crop_type': 'absolute_range', 237 | 'crop_size': (384, 600), 238 | 'allow_negative_crop': True 239 | }, { 240 | 'type': 241 | 'Resize', 242 | 'img_scale': [(480, 1333), (512, 1333), 243 | (544, 1333), (576, 1333), 244 | (608, 1333), (640, 1333), 245 | (672, 1333), (704, 1333), 246 | (736, 1333), (768, 1333), 247 | (800, 1333)], 248 | 
'multiscale_mode': 249 | 'value', 250 | 'override': 251 | True, 252 | 'keep_ratio': 253 | True 254 | }]]), 255 | dict( 256 | type='Normalize', 257 | mean=[123.675, 116.28, 103.53], 258 | std=[58.395, 57.12, 57.375], 259 | to_rgb=True), 260 | dict(type='Pad', size_divisor=32), 261 | dict(type='DefaultFormatBundle'), 262 | dict( 263 | type='Collect', 264 | keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']) 265 | ]), 266 | val=dict( 267 | type='CocoDataset', 268 | ann_file='data/coco/annotations/instances_val2017.json', 269 | img_prefix='data/coco/val2017/', 270 | pipeline=[ 271 | dict(type='LoadImageFromFile'), 272 | dict( 273 | type='MultiScaleFlipAug', 274 | img_scale=(1333, 800), 275 | flip=False, 276 | transforms=[ 277 | dict(type='Resize', keep_ratio=True), 278 | dict(type='RandomFlip'), 279 | dict( 280 | type='Normalize', 281 | mean=[123.675, 116.28, 103.53], 282 | std=[58.395, 57.12, 57.375], 283 | to_rgb=True), 284 | dict(type='Pad', size_divisor=32), 285 | dict(type='ImageToTensor', keys=['img']), 286 | dict(type='Collect', keys=['img']) 287 | ]) 288 | ]), 289 | test=dict( 290 | type='CocoDataset', 291 | ann_file='data/Voc12_CoCo_800/annotations/instances_train2017.json', 292 | img_prefix='data/Voc12_CoCo_800/train2017/', 293 | pipeline=[ 294 | dict(type='LoadImageFromFile'), 295 | dict( 296 | type='MultiScaleFlipAug', 297 | img_scale=(1333, 800), 298 | flip=False, 299 | transforms=[ 300 | dict(type='Resize', keep_ratio=True), 301 | dict(type='RandomFlip'), 302 | dict( 303 | type='Normalize', 304 | mean=[123.675, 116.28, 103.53], 305 | std=[58.395, 57.12, 57.375], 306 | to_rgb=True), 307 | dict(type='Pad', size_divisor=32), 308 | dict(type='ImageToTensor', keys=['img']), 309 | dict(type='Collect', keys=['img']) 310 | ]) 311 | ])) 312 | evaluation = dict(metric=['bbox', 'segm']) 313 | optimizer = dict( 314 | type='AdamW', 315 | lr=0.0001, 316 | betas=(0.9, 0.999), 317 | weight_decay=0.05, 318 | paramwise_cfg=dict( 319 | custom_keys=dict( 320 | absolute_pos_embed=dict(decay_mult=0.0), 321 | relative_position_bias_table=dict(decay_mult=0.0), 322 | norm=dict(decay_mult=0.0)))) 323 | optimizer_config = dict(grad_clip=None) 324 | lr_config = dict( 325 | policy='step', 326 | warmup='linear', 327 | warmup_iters=1000, 328 | warmup_ratio=0.001, 329 | step=[27, 33]) 330 | runner = dict(type='EpochBasedRunner', max_epochs=36) 331 | checkpoint_config = dict(interval=1) 332 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 333 | custom_hooks = [dict(type='NumClassCheckHook')] 334 | dist_params = dict(backend='nccl') 335 | log_level = 'INFO' 336 | load_from = None 337 | resume_from = None 338 | workflow = [('train', 1)] 339 | opencv_num_threads = 0 340 | mp_start_method = 'fork' 341 | auto_scale_lr = dict(enable=False, base_batch_size=16) 342 | pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' 343 | -------------------------------------------------------------------------------- /ummdet/checkpoints/eval_cfg/vfnet_r50_fpn_mstrain_2x_coco.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'CocoDataset' 2 | data_root = 'data/coco/' 3 | img_norm_cfg = dict( 4 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | dict( 9 | type='Resize', 10 | img_scale=[(1333, 480), (1333, 960)], 11 | multiscale_mode='range', 12 | keep_ratio=True), 
13 | dict(type='RandomFlip', flip_ratio=0.5), 14 | dict( 15 | type='Normalize', 16 | mean=[123.675, 116.28, 103.53], 17 | std=[58.395, 57.12, 57.375], 18 | to_rgb=True), 19 | dict(type='Pad', size_divisor=32), 20 | dict(type='DefaultFormatBundle'), 21 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 22 | ] 23 | test_pipeline = [ 24 | dict(type='LoadImageFromFile'), 25 | dict( 26 | type='MultiScaleFlipAug', 27 | img_scale=(1333, 800), 28 | flip=False, 29 | transforms=[ 30 | dict(type='Resize', keep_ratio=True), 31 | dict(type='RandomFlip'), 32 | dict( 33 | type='Normalize', 34 | mean=[123.675, 116.28, 103.53], 35 | std=[58.395, 57.12, 57.375], 36 | to_rgb=True), 37 | dict(type='Pad', size_divisor=32), 38 | dict(type='DefaultFormatBundle'), 39 | dict(type='Collect', keys=['img']) 40 | ]) 41 | ] 42 | data = dict( 43 | samples_per_gpu=2, 44 | workers_per_gpu=2, 45 | train=dict( 46 | type='CocoDataset', 47 | ann_file='data/coco/annotations/instances_train2017.json', 48 | img_prefix='data/coco/train2017/', 49 | pipeline=[ 50 | dict(type='LoadImageFromFile'), 51 | dict(type='LoadAnnotations', with_bbox=True), 52 | dict( 53 | type='Resize', 54 | img_scale=[(1333, 480), (1333, 960)], 55 | multiscale_mode='range', 56 | keep_ratio=True), 57 | dict(type='RandomFlip', flip_ratio=0.5), 58 | dict( 59 | type='Normalize', 60 | mean=[123.675, 116.28, 103.53], 61 | std=[58.395, 57.12, 57.375], 62 | to_rgb=True), 63 | dict(type='Pad', size_divisor=32), 64 | dict(type='DefaultFormatBundle'), 65 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 66 | ]), 67 | val=dict( 68 | type='CocoDataset', 69 | ann_file='data/coco/annotations/instances_val2017.json', 70 | img_prefix='data/coco/val2017/', 71 | pipeline=[ 72 | dict(type='LoadImageFromFile'), 73 | dict( 74 | type='MultiScaleFlipAug', 75 | img_scale=(1333, 800), 76 | flip=False, 77 | transforms=[ 78 | dict(type='Resize', keep_ratio=True), 79 | dict(type='RandomFlip'), 80 | dict( 81 | type='Normalize', 82 | mean=[123.675, 116.28, 103.53], 83 | std=[58.395, 57.12, 57.375], 84 | to_rgb=True), 85 | dict(type='Pad', size_divisor=32), 86 | dict(type='DefaultFormatBundle'), 87 | dict(type='Collect', keys=['img']) 88 | ]) 89 | ]), 90 | test=dict( 91 | type='CocoDataset', 92 | ann_file='data/Voc12_CoCo_800/annotations/instances_train2017.json', 93 | img_prefix='data/Voc12_CoCo_800/train2017/', 94 | pipeline=[ 95 | dict(type='LoadImageFromFile'), 96 | dict( 97 | type='MultiScaleFlipAug', 98 | img_scale=(1333, 800), 99 | flip=False, 100 | transforms=[ 101 | dict(type='Resize', keep_ratio=True), 102 | dict(type='RandomFlip'), 103 | dict( 104 | type='Normalize', 105 | mean=[123.675, 116.28, 103.53], 106 | std=[58.395, 57.12, 57.375], 107 | to_rgb=True), 108 | dict(type='Pad', size_divisor=32), 109 | dict(type='DefaultFormatBundle'), 110 | dict(type='Collect', keys=['img']) 111 | ]) 112 | ])) 113 | evaluation = dict(interval=1, metric='bbox') 114 | optimizer = dict( 115 | type='SGD', 116 | lr=0.01, 117 | momentum=0.9, 118 | weight_decay=0.0001, 119 | paramwise_cfg=dict(bias_lr_mult=2.0, bias_decay_mult=0.0)) 120 | optimizer_config = dict(grad_clip=None) 121 | lr_config = dict( 122 | policy='step', 123 | warmup='linear', 124 | warmup_iters=500, 125 | warmup_ratio=0.1, 126 | step=[16, 22]) 127 | runner = dict(type='EpochBasedRunner', max_epochs=24) 128 | checkpoint_config = dict(interval=1) 129 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 130 | custom_hooks = [dict(type='NumClassCheckHook')] 131 | dist_params = 
dict(backend='nccl') 132 | log_level = 'INFO' 133 | load_from = None 134 | resume_from = None 135 | workflow = [('train', 1)] 136 | opencv_num_threads = 0 137 | mp_start_method = 'fork' 138 | auto_scale_lr = dict(enable=False, base_batch_size=16) 139 | model = dict( 140 | type='VFNet', 141 | backbone=dict( 142 | type='ResNet', 143 | depth=50, 144 | num_stages=4, 145 | out_indices=(0, 1, 2, 3), 146 | frozen_stages=1, 147 | norm_cfg=dict(type='BN', requires_grad=True), 148 | norm_eval=True, 149 | style='pytorch', 150 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), 151 | neck=dict( 152 | type='FPN', 153 | in_channels=[256, 512, 1024, 2048], 154 | out_channels=256, 155 | start_level=1, 156 | add_extra_convs='on_output', 157 | num_outs=5, 158 | relu_before_extra_convs=True), 159 | bbox_head=dict( 160 | type='VFNetHead', 161 | num_classes=80, 162 | in_channels=256, 163 | stacked_convs=3, 164 | feat_channels=256, 165 | strides=[8, 16, 32, 64, 128], 166 | center_sampling=False, 167 | dcn_on_last_conv=False, 168 | use_atss=True, 169 | use_vfl=True, 170 | loss_cls=dict( 171 | type='VarifocalLoss', 172 | use_sigmoid=True, 173 | alpha=0.75, 174 | gamma=2.0, 175 | iou_weighted=True, 176 | loss_weight=1.0), 177 | loss_bbox=dict(type='GIoULoss', loss_weight=1.5), 178 | loss_bbox_refine=dict(type='GIoULoss', loss_weight=2.0)), 179 | train_cfg=dict( 180 | assigner=dict(type='ATSSAssigner', topk=9), 181 | allowed_border=-1, 182 | pos_weight=-1, 183 | debug=False), 184 | test_cfg=dict( 185 | nms_pre=1000, 186 | min_bbox_size=0, 187 | score_thr=0.05, 188 | nms=dict(type='nms', iou_threshold=0.6), 189 | max_per_img=100)) 190 | -------------------------------------------------------------------------------- /ummdet/checkpoints/eval_cfg/yolof_r50_c5_8x8_1x_coco.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'CocoDataset' 2 | data_root = 'data/coco/' 3 | img_norm_cfg = dict( 4 | mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 9 | dict(type='RandomFlip', flip_ratio=0.5), 10 | dict(type='RandomShift', shift_ratio=0.5, max_shift_px=32), 11 | dict( 12 | type='Normalize', 13 | mean=[103.53, 116.28, 123.675], 14 | std=[1.0, 1.0, 1.0], 15 | to_rgb=False), 16 | dict(type='Pad', size_divisor=32), 17 | dict(type='DefaultFormatBundle'), 18 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 19 | ] 20 | test_pipeline = [ 21 | dict(type='LoadImageFromFile'), 22 | dict( 23 | type='MultiScaleFlipAug', 24 | img_scale=(1333, 800), 25 | flip=False, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='RandomFlip'), 29 | dict( 30 | type='Normalize', 31 | mean=[103.53, 116.28, 123.675], 32 | std=[1.0, 1.0, 1.0], 33 | to_rgb=False), 34 | dict(type='Pad', size_divisor=32), 35 | dict(type='ImageToTensor', keys=['img']), 36 | dict(type='Collect', keys=['img']) 37 | ]) 38 | ] 39 | data = dict( 40 | samples_per_gpu=8, 41 | workers_per_gpu=8, 42 | train=dict( 43 | type='CocoDataset', 44 | ann_file='data/coco/annotations/instances_train2017.json', 45 | img_prefix='data/coco/train2017/', 46 | pipeline=[ 47 | dict(type='LoadImageFromFile'), 48 | dict(type='LoadAnnotations', with_bbox=True), 49 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 50 | dict(type='RandomFlip', flip_ratio=0.5), 51 | dict(type='RandomShift', 
shift_ratio=0.5, max_shift_px=32), 52 | dict( 53 | type='Normalize', 54 | mean=[103.53, 116.28, 123.675], 55 | std=[1.0, 1.0, 1.0], 56 | to_rgb=False), 57 | dict(type='Pad', size_divisor=32), 58 | dict(type='DefaultFormatBundle'), 59 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 60 | ]), 61 | val=dict( 62 | type='CocoDataset', 63 | ann_file='data/coco/annotations/instances_val2017.json', 64 | img_prefix='data/coco/val2017/', 65 | pipeline=[ 66 | dict(type='LoadImageFromFile'), 67 | dict( 68 | type='MultiScaleFlipAug', 69 | img_scale=(1333, 800), 70 | flip=False, 71 | transforms=[ 72 | dict(type='Resize', keep_ratio=True), 73 | dict(type='RandomFlip'), 74 | dict( 75 | type='Normalize', 76 | mean=[103.53, 116.28, 123.675], 77 | std=[1.0, 1.0, 1.0], 78 | to_rgb=False), 79 | dict(type='Pad', size_divisor=32), 80 | dict(type='ImageToTensor', keys=['img']), 81 | dict(type='Collect', keys=['img']) 82 | ]) 83 | ]), 84 | test=dict( 85 | type='CocoDataset', 86 | ann_file='data/Voc12_CoCo_800/annotations/instances_train2017.json', 87 | img_prefix='data/Voc12_CoCo_800/train2017/', 88 | pipeline=[ 89 | dict(type='LoadImageFromFile'), 90 | dict( 91 | type='MultiScaleFlipAug', 92 | img_scale=(1333, 800), 93 | flip=False, 94 | transforms=[ 95 | dict(type='Resize', keep_ratio=True), 96 | dict(type='RandomFlip'), 97 | dict( 98 | type='Normalize', 99 | mean=[103.53, 116.28, 123.675], 100 | std=[1.0, 1.0, 1.0], 101 | to_rgb=False), 102 | dict(type='Pad', size_divisor=32), 103 | dict(type='ImageToTensor', keys=['img']), 104 | dict(type='Collect', keys=['img']) 105 | ]) 106 | ])) 107 | evaluation = dict(interval=1, metric='bbox') 108 | optimizer = dict( 109 | type='SGD', 110 | lr=0.12, 111 | momentum=0.9, 112 | weight_decay=0.0001, 113 | paramwise_cfg=dict( 114 | norm_decay_mult=0.0, 115 | custom_keys=dict(backbone=dict(lr_mult=0.3333333333333333)))) 116 | optimizer_config = dict(grad_clip=None) 117 | lr_config = dict( 118 | policy='step', 119 | warmup='linear', 120 | warmup_iters=1500, 121 | warmup_ratio=0.00066667, 122 | step=[8, 11]) 123 | runner = dict(type='EpochBasedRunner', max_epochs=12) 124 | checkpoint_config = dict(interval=1) 125 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 126 | custom_hooks = [dict(type='NumClassCheckHook')] 127 | dist_params = dict(backend='nccl') 128 | log_level = 'INFO' 129 | load_from = None 130 | resume_from = None 131 | workflow = [('train', 1)] 132 | opencv_num_threads = 0 133 | mp_start_method = 'fork' 134 | auto_scale_lr = dict(enable=False, base_batch_size=64) 135 | model = dict( 136 | type='YOLOF', 137 | backbone=dict( 138 | type='ResNet', 139 | depth=50, 140 | num_stages=4, 141 | out_indices=(3, ), 142 | frozen_stages=1, 143 | norm_cfg=dict(type='BN', requires_grad=False), 144 | norm_eval=True, 145 | style='caffe', 146 | init_cfg=dict( 147 | type='Pretrained', 148 | checkpoint='open-mmlab://detectron/resnet50_caffe')), 149 | neck=dict( 150 | type='DilatedEncoder', 151 | in_channels=2048, 152 | out_channels=512, 153 | block_mid_channels=128, 154 | num_residual_blocks=4, 155 | block_dilations=[2, 4, 6, 8]), 156 | bbox_head=dict( 157 | type='YOLOFHead', 158 | num_classes=80, 159 | in_channels=512, 160 | reg_decoded_bbox=True, 161 | anchor_generator=dict( 162 | type='AnchorGenerator', 163 | ratios=[1.0], 164 | scales=[1, 2, 4, 8, 16], 165 | strides=[32]), 166 | bbox_coder=dict( 167 | type='DeltaXYWHBBoxCoder', 168 | target_means=[0.0, 0.0, 0.0, 0.0], 169 | target_stds=[1.0, 1.0, 1.0, 1.0], 170 | add_ctr_clamp=True, 171 | 
ctr_clamp=32), 172 | loss_cls=dict( 173 | type='FocalLoss', 174 | use_sigmoid=True, 175 | gamma=2.0, 176 | alpha=0.25, 177 | loss_weight=1.0), 178 | loss_bbox=dict(type='GIoULoss', loss_weight=1.0)), 179 | train_cfg=dict( 180 | assigner=dict( 181 | type='UniformAssigner', pos_ignore_thr=0.15, neg_ignore_thr=0.7), 182 | allowed_border=-1, 183 | pos_weight=-1, 184 | debug=False), 185 | test_cfg=dict( 186 | nms_pre=1000, 187 | min_bbox_size=0, 188 | score_thr=0.05, 189 | nms=dict(type='nms', iou_threshold=0.6), 190 | max_per_img=100)) 191 | -------------------------------------------------------------------------------- /ummdet/checkpoints/eval_cfg/yolov3_d53_mstrain-608_273e_coco.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1) 2 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 3 | custom_hooks = [dict(type='NumClassCheckHook')] 4 | dist_params = dict(backend='nccl') 5 | log_level = 'INFO' 6 | load_from = None 7 | resume_from = None 8 | workflow = [('train', 1)] 9 | opencv_num_threads = 0 10 | mp_start_method = 'fork' 11 | auto_scale_lr = dict(enable=False, base_batch_size=64) 12 | model = dict( 13 | type='YOLOV3', 14 | backbone=dict( 15 | type='Darknet', 16 | depth=53, 17 | out_indices=(3, 4, 5), 18 | init_cfg=dict(type='Pretrained', checkpoint='open-mmlab://darknet53')), 19 | neck=dict( 20 | type='YOLOV3Neck', 21 | num_scales=3, 22 | in_channels=[1024, 512, 256], 23 | out_channels=[512, 256, 128]), 24 | bbox_head=dict( 25 | type='YOLOV3Head', 26 | num_classes=80, 27 | in_channels=[512, 256, 128], 28 | out_channels=[1024, 512, 256], 29 | anchor_generator=dict( 30 | type='YOLOAnchorGenerator', 31 | base_sizes=[[(116, 90), (156, 198), (373, 326)], 32 | [(30, 61), (62, 45), (59, 119)], 33 | [(10, 13), (16, 30), (33, 23)]], 34 | strides=[32, 16, 8]), 35 | bbox_coder=dict(type='YOLOBBoxCoder'), 36 | featmap_strides=[32, 16, 8], 37 | loss_cls=dict( 38 | type='CrossEntropyLoss', 39 | use_sigmoid=True, 40 | loss_weight=1.0, 41 | reduction='sum'), 42 | loss_conf=dict( 43 | type='CrossEntropyLoss', 44 | use_sigmoid=True, 45 | loss_weight=1.0, 46 | reduction='sum'), 47 | loss_xy=dict( 48 | type='CrossEntropyLoss', 49 | use_sigmoid=True, 50 | loss_weight=2.0, 51 | reduction='sum'), 52 | loss_wh=dict(type='MSELoss', loss_weight=2.0, reduction='sum')), 53 | train_cfg=dict( 54 | assigner=dict( 55 | type='GridAssigner', 56 | pos_iou_thr=0.5, 57 | neg_iou_thr=0.5, 58 | min_pos_iou=0)), 59 | test_cfg=dict( 60 | nms_pre=1000, 61 | min_bbox_size=0, 62 | score_thr=0.05, 63 | conf_thr=0.005, 64 | nms=dict(type='nms', iou_threshold=0.45), 65 | max_per_img=100)) 66 | dataset_type = 'CocoDataset' 67 | data_root = 'data/coco/' 68 | img_norm_cfg = dict(mean=[0, 0, 0], std=[255.0, 255.0, 255.0], to_rgb=True) 69 | train_pipeline = [ 70 | dict(type='LoadImageFromFile', to_float32=True), 71 | dict(type='LoadAnnotations', with_bbox=True), 72 | dict(type='Expand', mean=[0, 0, 0], to_rgb=True, ratio_range=(1, 2)), 73 | dict( 74 | type='MinIoURandomCrop', 75 | min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9), 76 | min_crop_size=0.3), 77 | dict(type='Resize', img_scale=[(320, 320), (608, 608)], keep_ratio=True), 78 | dict(type='RandomFlip', flip_ratio=0.5), 79 | dict(type='PhotoMetricDistortion'), 80 | dict( 81 | type='Normalize', 82 | mean=[0, 0, 0], 83 | std=[255.0, 255.0, 255.0], 84 | to_rgb=True), 85 | dict(type='Pad', size_divisor=32), 86 | dict(type='DefaultFormatBundle'), 87 | dict(type='Collect', keys=['img', 'gt_bboxes', 
'gt_labels']) 88 | ] 89 | test_pipeline = [ 90 | dict(type='LoadImageFromFile'), 91 | dict( 92 | type='MultiScaleFlipAug', 93 | img_scale=(608, 608), 94 | flip=False, 95 | transforms=[ 96 | dict(type='Resize', keep_ratio=True), 97 | dict(type='RandomFlip'), 98 | dict( 99 | type='Normalize', 100 | mean=[0, 0, 0], 101 | std=[255.0, 255.0, 255.0], 102 | to_rgb=True), 103 | dict(type='Pad', size_divisor=32), 104 | dict(type='ImageToTensor', keys=['img']), 105 | dict(type='Collect', keys=['img']) 106 | ]) 107 | ] 108 | data = dict( 109 | samples_per_gpu=8, 110 | workers_per_gpu=4, 111 | train=dict( 112 | type='CocoDataset', 113 | ann_file='data/coco/annotations/instances_train2017.json', 114 | img_prefix='data/coco/train2017/', 115 | pipeline=[ 116 | dict(type='LoadImageFromFile', to_float32=True), 117 | dict(type='LoadAnnotations', with_bbox=True), 118 | dict( 119 | type='Expand', mean=[0, 0, 0], to_rgb=True, 120 | ratio_range=(1, 2)), 121 | dict( 122 | type='MinIoURandomCrop', 123 | min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9), 124 | min_crop_size=0.3), 125 | dict( 126 | type='Resize', 127 | img_scale=[(320, 320), (608, 608)], 128 | keep_ratio=True), 129 | dict(type='RandomFlip', flip_ratio=0.5), 130 | dict(type='PhotoMetricDistortion'), 131 | dict( 132 | type='Normalize', 133 | mean=[0, 0, 0], 134 | std=[255.0, 255.0, 255.0], 135 | to_rgb=True), 136 | dict(type='Pad', size_divisor=32), 137 | dict(type='DefaultFormatBundle'), 138 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 139 | ]), 140 | val=dict( 141 | type='CocoDataset', 142 | ann_file='data/coco/annotations/instances_val2017.json', 143 | img_prefix='data/coco/val2017/', 144 | pipeline=[ 145 | dict(type='LoadImageFromFile'), 146 | dict( 147 | type='MultiScaleFlipAug', 148 | img_scale=(608, 608), 149 | flip=False, 150 | transforms=[ 151 | dict(type='Resize', keep_ratio=True), 152 | dict(type='RandomFlip'), 153 | dict( 154 | type='Normalize', 155 | mean=[0, 0, 0], 156 | std=[255.0, 255.0, 255.0], 157 | to_rgb=True), 158 | dict(type='Pad', size_divisor=32), 159 | dict(type='ImageToTensor', keys=['img']), 160 | dict(type='Collect', keys=['img']) 161 | ]) 162 | ]), 163 | test=dict( 164 | type='CocoDataset', 165 | ann_file='data/Voc12_CoCo_800/annotations/instances_train2017.json', 166 | img_prefix='data/Voc12_CoCo_800/train2017/', 167 | pipeline=[ 168 | dict(type='LoadImageFromFile'), 169 | dict( 170 | type='MultiScaleFlipAug', 171 | img_scale=(608, 608), 172 | flip=False, 173 | transforms=[ 174 | dict(type='Resize', keep_ratio=True), 175 | dict(type='RandomFlip'), 176 | dict( 177 | type='Normalize', 178 | mean=[0, 0, 0], 179 | std=[255.0, 255.0, 255.0], 180 | to_rgb=True), 181 | dict(type='Pad', size_divisor=32), 182 | dict(type='ImageToTensor', keys=['img']), 183 | dict(type='Collect', keys=['img']) 184 | ]) 185 | ])) 186 | optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0005) 187 | optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) 188 | lr_config = dict( 189 | policy='step', 190 | warmup='linear', 191 | warmup_iters=2000, 192 | warmup_ratio=0.1, 193 | step=[218, 246]) 194 | runner = dict(type='EpochBasedRunner', max_epochs=273) 195 | evaluation = dict(interval=1, metric=['bbox']) 196 | -------------------------------------------------------------------------------- /ummdet/checkpoints/eval_cfg/yolox_l_8x8_300e_coco.py: -------------------------------------------------------------------------------- 1 | optimizer = dict( 2 | type='SGD', 3 | lr=0.01, 4 | momentum=0.9, 5 | weight_decay=0.0005, 6 
| nesterov=True, 7 | paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0)) 8 | optimizer_config = dict(grad_clip=None) 9 | lr_config = dict( 10 | policy='YOLOX', 11 | warmup='exp', 12 | by_epoch=False, 13 | warmup_by_epoch=True, 14 | warmup_ratio=1, 15 | warmup_iters=5, 16 | num_last_epochs=15, 17 | min_lr_ratio=0.05) 18 | runner = dict(type='EpochBasedRunner', max_epochs=300) 19 | checkpoint_config = dict(interval=10) 20 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 21 | custom_hooks = [ 22 | dict(type='YOLOXModeSwitchHook', num_last_epochs=15, priority=48), 23 | dict(type='SyncNormHook', num_last_epochs=15, interval=10, priority=48), 24 | dict( 25 | type='ExpMomentumEMAHook', 26 | resume_from=None, 27 | momentum=0.0001, 28 | priority=49) 29 | ] 30 | dist_params = dict(backend='nccl') 31 | log_level = 'INFO' 32 | load_from = None 33 | resume_from = None 34 | workflow = [('train', 1)] 35 | opencv_num_threads = 0 36 | mp_start_method = 'fork' 37 | auto_scale_lr = dict(enable=False, base_batch_size=64) 38 | img_scale = (640, 640) 39 | model = dict( 40 | type='YOLOX', 41 | input_size=(640, 640), 42 | random_size_range=(15, 25), 43 | random_size_interval=10, 44 | backbone=dict(type='CSPDarknet', deepen_factor=1.0, widen_factor=1.0), 45 | neck=dict( 46 | type='YOLOXPAFPN', 47 | in_channels=[256, 512, 1024], 48 | out_channels=256, 49 | num_csp_blocks=3), 50 | bbox_head=dict( 51 | type='YOLOXHead', num_classes=80, in_channels=256, feat_channels=256), 52 | train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)), 53 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))) 54 | data_root = 'data/coco/' 55 | dataset_type = 'CocoDataset' 56 | train_pipeline = [ 57 | dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0), 58 | dict( 59 | type='RandomAffine', scaling_ratio_range=(0.1, 2), 60 | border=(-320, -320)), 61 | dict( 62 | type='MixUp', 63 | img_scale=(640, 640), 64 | ratio_range=(0.8, 1.6), 65 | pad_val=114.0), 66 | dict(type='YOLOXHSVRandomAug'), 67 | dict(type='RandomFlip', flip_ratio=0.5), 68 | dict(type='Resize', img_scale=(640, 640), keep_ratio=True), 69 | dict( 70 | type='Pad', 71 | pad_to_square=True, 72 | pad_val=dict(img=(114.0, 114.0, 114.0))), 73 | dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), 74 | dict(type='DefaultFormatBundle'), 75 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 76 | ] 77 | train_dataset = dict( 78 | type='MultiImageMixDataset', 79 | dataset=dict( 80 | type='CocoDataset', 81 | ann_file='data/coco/annotations/instances_train2017.json', 82 | img_prefix='data/coco/train2017/', 83 | pipeline=[ 84 | dict(type='LoadImageFromFile'), 85 | dict(type='LoadAnnotations', with_bbox=True) 86 | ], 87 | filter_empty_gt=False), 88 | pipeline=[ 89 | dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0), 90 | dict( 91 | type='RandomAffine', 92 | scaling_ratio_range=(0.1, 2), 93 | border=(-320, -320)), 94 | dict( 95 | type='MixUp', 96 | img_scale=(640, 640), 97 | ratio_range=(0.8, 1.6), 98 | pad_val=114.0), 99 | dict(type='YOLOXHSVRandomAug'), 100 | dict(type='RandomFlip', flip_ratio=0.5), 101 | dict(type='Resize', img_scale=(640, 640), keep_ratio=True), 102 | dict( 103 | type='Pad', 104 | pad_to_square=True, 105 | pad_val=dict(img=(114.0, 114.0, 114.0))), 106 | dict( 107 | type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), 108 | dict(type='DefaultFormatBundle'), 109 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 110 | ]) 111 | 
test_pipeline = [ 112 | dict(type='LoadImageFromFile'), 113 | dict( 114 | type='MultiScaleFlipAug', 115 | img_scale=(640, 640), 116 | flip=False, 117 | transforms=[ 118 | dict(type='Resize', keep_ratio=True), 119 | dict(type='RandomFlip'), 120 | dict( 121 | type='Pad', 122 | pad_to_square=True, 123 | pad_val=dict(img=(114.0, 114.0, 114.0))), 124 | dict(type='DefaultFormatBundle'), 125 | dict(type='Collect', keys=['img']) 126 | ]) 127 | ] 128 | data = dict( 129 | samples_per_gpu=8, 130 | workers_per_gpu=4, 131 | persistent_workers=True, 132 | train=dict( 133 | type='MultiImageMixDataset', 134 | dataset=dict( 135 | type='CocoDataset', 136 | ann_file='data/coco/annotations/instances_train2017.json', 137 | img_prefix='data/coco/train2017/', 138 | pipeline=[ 139 | dict(type='LoadImageFromFile'), 140 | dict(type='LoadAnnotations', with_bbox=True) 141 | ], 142 | filter_empty_gt=False), 143 | pipeline=[ 144 | dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0), 145 | dict( 146 | type='RandomAffine', 147 | scaling_ratio_range=(0.1, 2), 148 | border=(-320, -320)), 149 | dict( 150 | type='MixUp', 151 | img_scale=(640, 640), 152 | ratio_range=(0.8, 1.6), 153 | pad_val=114.0), 154 | dict(type='YOLOXHSVRandomAug'), 155 | dict(type='RandomFlip', flip_ratio=0.5), 156 | dict(type='Resize', img_scale=(640, 640), keep_ratio=True), 157 | dict( 158 | type='Pad', 159 | pad_to_square=True, 160 | pad_val=dict(img=(114.0, 114.0, 114.0))), 161 | dict( 162 | type='FilterAnnotations', 163 | min_gt_bbox_wh=(1, 1), 164 | keep_empty=False), 165 | dict(type='DefaultFormatBundle'), 166 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 167 | ]), 168 | val=dict( 169 | type='CocoDataset', 170 | ann_file='data/coco/annotations/instances_val2017.json', 171 | img_prefix='data/coco/val2017/', 172 | pipeline=[ 173 | dict(type='LoadImageFromFile'), 174 | dict( 175 | type='MultiScaleFlipAug', 176 | img_scale=(640, 640), 177 | flip=False, 178 | transforms=[ 179 | dict(type='Resize', keep_ratio=True), 180 | dict(type='RandomFlip'), 181 | dict( 182 | type='Pad', 183 | pad_to_square=True, 184 | pad_val=dict(img=(114.0, 114.0, 114.0))), 185 | dict(type='DefaultFormatBundle'), 186 | dict(type='Collect', keys=['img']) 187 | ]) 188 | ]), 189 | test=dict( 190 | type='CocoDataset', 191 | ann_file='data/Voc12_CoCo_800/annotations/instances_train2017.json', 192 | img_prefix='data/Voc12_CoCo_800/train2017/', 193 | pipeline=[ 194 | dict(type='LoadImageFromFile'), 195 | dict( 196 | type='MultiScaleFlipAug', 197 | img_scale=(640, 640), 198 | flip=False, 199 | transforms=[ 200 | dict(type='Resize', keep_ratio=True), 201 | dict(type='RandomFlip'), 202 | dict( 203 | type='Pad', 204 | pad_to_square=True, 205 | pad_val=dict(img=(114.0, 114.0, 114.0))), 206 | dict(type='DefaultFormatBundle'), 207 | dict(type='Collect', keys=['img']) 208 | ]) 209 | ])) 210 | max_epochs = 300 211 | num_last_epochs = 15 212 | interval = 10 213 | evaluation = dict( 214 | save_best='auto', interval=10, dynamic_intervals=[(285, 1)], metric='bbox') 215 | -------------------------------------------------------------------------------- /ummdet/checkpoints/train_cfg/faster_rcnn_r101_caffe_fpn_mstrain_3x_coco.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1) 2 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 3 | custom_hooks = [dict(type='NumClassCheckHook')] 4 | dist_params = dict(backend='nccl') 5 | log_level = 'INFO' 6 | load_from = None 7 | resume_from = None 8 | 
workflow = [('train', 1)]
opencv_num_threads = 0
mp_start_method = 'fork'
auto_scale_lr = dict(enable=False, base_batch_size=16)
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='Resize',
        img_scale=[(1333, 640), (1333, 800)],
        multiscale_mode='range',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='Normalize',
        mean=[103.53, 116.28, 123.675],
        std=[1.0, 1.0, 1.0],
        to_rgb=False),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(
                type='Normalize',
                mean=[103.53, 116.28, 123.675],
                std=[1.0, 1.0, 1.0],
                to_rgb=False),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type='RepeatDataset',
        times=3,
        dataset=dict(
            type='CocoDataset',
            ann_file='data/coco/annotations/instances_train2017.json',
            img_prefix='data/coco/train2017/',
            pipeline=[
                dict(type='LoadImageFromFile'),
                dict(type='LoadAnnotations', with_bbox=True),
                dict(
                    type='Resize',
                    img_scale=[(1333, 640), (1333, 800)],
                    multiscale_mode='range',
                    keep_ratio=True),
                dict(type='RandomFlip', flip_ratio=0.5),
                dict(
                    type='Normalize',
                    mean=[103.53, 116.28, 123.675],
                    std=[1.0, 1.0, 1.0],
                    to_rgb=False),
                dict(type='Pad', size_divisor=32),
                dict(type='DefaultFormatBundle'),
                dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
            ])),
    val=dict(
        type='CocoDataset',
        ann_file='data/coco/annotations/instances_val2017.json',
        img_prefix='data/coco/val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[103.53, 116.28, 123.675],
                        std=[1.0, 1.0, 1.0],
                        to_rgb=False),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]),
    test=dict(
        type='CocoDatasetAdv',  # Load Labels for FIA / NAA / RPA
        ann_file='data/coco/annotations/instances_val2017.json',
        img_prefix='data/coco/val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True),  # Load Labels for FIA / NAA / RPA
            dict(
                type='MultiScaleFlipAug',
                img_scale=(1333, 800),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[103.53, 116.28, 123.675],
                        std=[1.0, 1.0, 1.0],
                        to_rgb=False),
                    dict(type='Pad', size_divisor=32),
                    # dict(type='ImageToTensor', keys=['img']),
                    dict(type='DefaultFormatBundle'),  # Load Labels for FIA / NAA / RPA
                    # dict(type='Collect', keys=['img'])
                    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])  # Load Labels for FIA / NAA / RPA
                ])
        ]))
evaluation = dict(interval=1, metric='bbox')
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[9, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
model = dict(
    type='FasterRCNNAdv',
    backbone=dict(
        type='ResNet',
        depth=101,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='caffe',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='open-mmlab://detectron2/resnet101_caffe')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[0.0, 0.0, 0.0, 0.0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0.0, 0.0, 0.0, 0.0],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100)))
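The `CocoDatasetAdv` test split above is what distinguishes the attack configs from the stock mmdet configs: it loads ground-truth boxes and labels during *test*-mode iteration so that feature-level attacks (FIA / NAA / RPA) can compute detection losses. Below is a minimal sketch, assuming the standard mmdet 2.x dataset APIs, of how such a config is consumed; the config path is real, but the snippet itself is illustrative and not part of the repository:

```python
# Illustrative only: iterate the attack test split defined by the config above.
from mmcv import Config
from mmdet.datasets import build_dataset, build_dataloader

import ummdet.components  # registers CocoDatasetAdv with the DATASETS registry

cfg = Config.fromfile(
    'ummdet/checkpoints/train_cfg/faster_rcnn_r101_caffe_fpn_mstrain_3x_coco.py')
dataset = build_dataset(cfg.data.test)  # builds CocoDatasetAdv
loader = build_dataloader(
    dataset, samples_per_gpu=1, workers_per_gpu=0, dist=False, shuffle=False)
batch = next(iter(loader))
# Thanks to LoadAnnotations plus the modified Collect, the batch dict carries
# 'img', 'img_metas', 'gt_bboxes' and 'gt_labels' even in test mode.
```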
-------------------------------------------------------------------------------- /ummdet/checkpoints/train_cfg/mask_rcnn_swin-t-p4-w7_fpn_ms-crop-3x_coco.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='MaskRCNNAdv', 3 | backbone=dict( 4 | type='SwinTransformer', 5 | embed_dims=96, 6 | depths=[2, 2, 6, 2], 7 | num_heads=[3, 6, 12, 24], 8 | window_size=7, 9 | mlp_ratio=4, 10 | qkv_bias=True, 11 | qk_scale=None, 12 | drop_rate=0.0, 13 | attn_drop_rate=0.0, 14 | drop_path_rate=0.2, 15 | patch_norm=True, 16 | out_indices=(0, 1, 2, 3), 17 | with_cp=False, 18 | convert_weights=True, 19 | init_cfg=dict( 20 | type='Pretrained', 21 | checkpoint= 22 | 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' 23 | )), 24 | neck=dict( 25 | type='FPN', 26 | in_channels=[96, 192, 384, 768], 27 | out_channels=256, 28 | num_outs=5), 29 | rpn_head=dict( 30 | type='RPNHead', 31 | in_channels=256, 32 | feat_channels=256, 33 | anchor_generator=dict( 34 | type='AnchorGenerator', 35 | scales=[8], 36 | ratios=[0.5, 1.0, 2.0], 37 | strides=[4, 8, 16, 32, 64]), 38 | bbox_coder=dict( 39 | type='DeltaXYWHBBoxCoder', 40 | target_means=[0.0, 0.0, 0.0, 0.0], 41 | target_stds=[1.0, 1.0, 1.0, 1.0]), 42 | loss_cls=dict( 43 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 44 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 45 | roi_head=dict( 46 | type='StandardRoIHead', 47 | bbox_roi_extractor=dict( 48 | type='SingleRoIExtractor', 49 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 50 | out_channels=256, 51 | featmap_strides=[4, 8, 16, 32]), 52 | bbox_head=dict( 53 | type='Shared2FCBBoxHead', 54 | in_channels=256, 55 | fc_out_channels=1024, 56 | roi_feat_size=7, 57 | num_classes=80, 58 | bbox_coder=dict( 59 | type='DeltaXYWHBBoxCoder', 60 | target_means=[0.0, 0.0, 0.0, 0.0], 61 | target_stds=[0.1, 0.1, 0.2, 0.2]), 62 | reg_class_agnostic=False, 63 | loss_cls=dict( 64 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 65 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 66 | mask_roi_extractor=dict( 67 | type='SingleRoIExtractor', 68 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 69 | out_channels=256, 70 | featmap_strides=[4, 8, 16, 32]), 71 | mask_head=dict( 72 | type='FCNMaskHead', 73 | num_convs=4, 74 | in_channels=256, 75 | conv_out_channels=256, 76 | num_classes=80, 77 | loss_mask=dict( 78 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), 79 | train_cfg=dict( 80 | rpn=dict( 81 | assigner=dict( 82 | type='MaxIoUAssigner', 83 | pos_iou_thr=0.7, 84 | neg_iou_thr=0.3, 85 | min_pos_iou=0.3, 86 | match_low_quality=True, 87 | ignore_iof_thr=-1), 88 | sampler=dict( 89 | type='RandomSampler', 90 | num=256, 91 | pos_fraction=0.5, 92 | neg_pos_ub=-1, 93 | add_gt_as_proposals=False), 94 | allowed_border=-1, 95 | pos_weight=-1, 96 | debug=False), 97 | rpn_proposal=dict( 98 | nms_pre=2000, 99 | max_per_img=1000, 100 | nms=dict(type='nms', iou_threshold=0.7), 101 | min_bbox_size=0), 102 | rcnn=dict( 103 | assigner=dict( 104 | type='MaxIoUAssigner', 105 | pos_iou_thr=0.5, 106 | neg_iou_thr=0.5, 107 | min_pos_iou=0.5, 108 | match_low_quality=True, 109 | ignore_iof_thr=-1), 110 | sampler=dict( 111 | type='RandomSampler', 112 | num=512, 113 | pos_fraction=0.25, 114 | neg_pos_ub=-1, 115 | add_gt_as_proposals=True), 116 | mask_size=28, 117 | pos_weight=-1, 118 | debug=False)), 119 | test_cfg=dict( 120 | rpn=dict( 121 | nms_pre=1000, 122 | 
max_per_img=1000, 123 | nms=dict(type='nms', iou_threshold=0.7), 124 | min_bbox_size=0), 125 | rcnn=dict( 126 | score_thr=0.05, 127 | nms=dict(type='nms', iou_threshold=0.5), 128 | max_per_img=100, 129 | mask_thr_binary=0.5))) 130 | dataset_type = 'CocoDataset' 131 | data_root = 'data/coco/' 132 | img_norm_cfg = dict( 133 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 134 | train_pipeline = [ 135 | dict(type='LoadImageFromFile'), 136 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 137 | dict(type='RandomFlip', flip_ratio=0.5), 138 | dict( 139 | type='AutoAugment', 140 | policies=[[{ 141 | 'type': 142 | 'Resize', 143 | 'img_scale': [(480, 1333), (512, 1333), (544, 1333), (576, 1333), 144 | (608, 1333), (640, 1333), (672, 1333), (704, 1333), 145 | (736, 1333), (768, 1333), (800, 1333)], 146 | 'multiscale_mode': 147 | 'value', 148 | 'keep_ratio': 149 | True 150 | }], 151 | [{ 152 | 'type': 'Resize', 153 | 'img_scale': [(400, 1333), (500, 1333), (600, 1333)], 154 | 'multiscale_mode': 'value', 155 | 'keep_ratio': True 156 | }, { 157 | 'type': 'RandomCrop', 158 | 'crop_type': 'absolute_range', 159 | 'crop_size': (384, 600), 160 | 'allow_negative_crop': True 161 | }, { 162 | 'type': 163 | 'Resize', 164 | 'img_scale': [(480, 1333), (512, 1333), (544, 1333), 165 | (576, 1333), (608, 1333), (640, 1333), 166 | (672, 1333), (704, 1333), (736, 1333), 167 | (768, 1333), (800, 1333)], 168 | 'multiscale_mode': 169 | 'value', 170 | 'override': 171 | True, 172 | 'keep_ratio': 173 | True 174 | }]]), 175 | dict( 176 | type='Normalize', 177 | mean=[123.675, 116.28, 103.53], 178 | std=[58.395, 57.12, 57.375], 179 | to_rgb=True), 180 | dict(type='Pad', size_divisor=32), 181 | dict(type='DefaultFormatBundle'), 182 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']) 183 | ] 184 | test_pipeline = [ 185 | dict(type='LoadImageFromFile'), 186 | dict( 187 | type='MultiScaleFlipAug', 188 | img_scale=(1333, 800), 189 | flip=False, 190 | transforms=[ 191 | dict(type='Resize', keep_ratio=True), 192 | dict(type='RandomFlip'), 193 | dict( 194 | type='Normalize', 195 | mean=[123.675, 116.28, 103.53], 196 | std=[58.395, 57.12, 57.375], 197 | to_rgb=True), 198 | dict(type='Pad', size_divisor=32), 199 | dict(type='ImageToTensor', keys=['img']), 200 | dict(type='Collect', keys=['img']) 201 | ]) 202 | ] 203 | data = dict( 204 | samples_per_gpu=2, 205 | workers_per_gpu=2, 206 | train=dict( 207 | type='CocoDataset', 208 | ann_file='data/coco/annotations/instances_train2017.json', 209 | img_prefix='data/coco/train2017/', 210 | pipeline=[ 211 | dict(type='LoadImageFromFile'), 212 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 213 | dict(type='RandomFlip', flip_ratio=0.5), 214 | dict( 215 | type='AutoAugment', 216 | policies=[[{ 217 | 'type': 218 | 'Resize', 219 | 'img_scale': [(480, 1333), (512, 1333), (544, 1333), 220 | (576, 1333), (608, 1333), (640, 1333), 221 | (672, 1333), (704, 1333), (736, 1333), 222 | (768, 1333), (800, 1333)], 223 | 'multiscale_mode': 224 | 'value', 225 | 'keep_ratio': 226 | True 227 | }], 228 | [{ 229 | 'type': 'Resize', 230 | 'img_scale': [(400, 1333), (500, 1333), 231 | (600, 1333)], 232 | 'multiscale_mode': 'value', 233 | 'keep_ratio': True 234 | }, { 235 | 'type': 'RandomCrop', 236 | 'crop_type': 'absolute_range', 237 | 'crop_size': (384, 600), 238 | 'allow_negative_crop': True 239 | }, { 240 | 'type': 241 | 'Resize', 242 | 'img_scale': [(480, 1333), (512, 1333), 243 | (544, 1333), (576, 1333), 244 | (608, 1333), (640, 
1333), 245 | (672, 1333), (704, 1333), 246 | (736, 1333), (768, 1333), 247 | (800, 1333)], 248 | 'multiscale_mode': 249 | 'value', 250 | 'override': 251 | True, 252 | 'keep_ratio': 253 | True 254 | }]]), 255 | dict( 256 | type='Normalize', 257 | mean=[123.675, 116.28, 103.53], 258 | std=[58.395, 57.12, 57.375], 259 | to_rgb=True), 260 | dict(type='Pad', size_divisor=32), 261 | dict(type='DefaultFormatBundle'), 262 | dict( 263 | type='Collect', 264 | keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']) 265 | ]), 266 | val=dict( 267 | type='CocoDataset', 268 | ann_file='data/coco/annotations/instances_val2017.json', 269 | img_prefix='data/coco/val2017/', 270 | pipeline=[ 271 | dict(type='LoadImageFromFile'), 272 | dict( 273 | type='MultiScaleFlipAug', 274 | img_scale=(1333, 800), 275 | flip=False, 276 | transforms=[ 277 | dict(type='Resize', keep_ratio=True), 278 | dict(type='RandomFlip'), 279 | dict( 280 | type='Normalize', 281 | mean=[123.675, 116.28, 103.53], 282 | std=[58.395, 57.12, 57.375], 283 | to_rgb=True), 284 | dict(type='Pad', size_divisor=32), 285 | dict(type='ImageToTensor', keys=['img']), 286 | dict(type='Collect', keys=['img']) 287 | ]) 288 | ]), 289 | test=dict( 290 | type='CocoDatasetAdv', # Load Labels for FIA / NAA / RPA 291 | ann_file='data/coco/annotations/instances_val2017.json', 292 | img_prefix='data/coco/val2017/', 293 | pipeline=[ 294 | dict(type='LoadImageFromFile'), 295 | dict(type='LoadAnnotations', with_bbox=True), # Load Labels for FIA / NAA / RPA 296 | dict( 297 | type='MultiScaleFlipAug', 298 | img_scale=(1333, 800), 299 | flip=False, 300 | transforms=[ 301 | dict(type='Resize', keep_ratio=True), 302 | dict(type='RandomFlip'), 303 | dict( 304 | type='Normalize', 305 | mean=[123.675, 116.28, 103.53], 306 | std=[58.395, 57.12, 57.375], 307 | to_rgb=True), 308 | dict(type='Pad', size_divisor=32), 309 | # dict(type='ImageToTensor', keys=['img']), 310 | dict(type='DefaultFormatBundle'), # Load Labels for FIA / NAA / RPA 311 | # dict(type='Collect', keys=['img']) 312 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) # Load Labels for FIA / NAA / RPA 313 | ]) 314 | ])) 315 | evaluation = dict(metric=['bbox', 'segm']) 316 | optimizer = dict( 317 | type='AdamW', 318 | lr=0.0001, 319 | betas=(0.9, 0.999), 320 | weight_decay=0.05, 321 | paramwise_cfg=dict( 322 | custom_keys=dict( 323 | absolute_pos_embed=dict(decay_mult=0.0), 324 | relative_position_bias_table=dict(decay_mult=0.0), 325 | norm=dict(decay_mult=0.0)))) 326 | optimizer_config = dict(grad_clip=None) 327 | lr_config = dict( 328 | policy='step', 329 | warmup='linear', 330 | warmup_iters=1000, 331 | warmup_ratio=0.001, 332 | step=[27, 33]) 333 | runner = dict(type='EpochBasedRunner', max_epochs=36) 334 | checkpoint_config = dict(interval=1) 335 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 336 | custom_hooks = [dict(type='NumClassCheckHook')] 337 | dist_params = dict(backend='nccl') 338 | log_level = 'INFO' 339 | load_from = None 340 | resume_from = None 341 | workflow = [('train', 1)] 342 | opencv_num_threads = 0 343 | mp_start_method = 'fork' 344 | auto_scale_lr = dict(enable=False, base_batch_size=16) 345 | pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' 346 | -------------------------------------------------------------------------------- /ummdet/checkpoints/train_cfg/vfnet_r50_fpn_mstrain_2x_coco.py: -------------------------------------------------------------------------------- 1 | dataset_type = 
'CocoDataset' 2 | data_root = 'data/coco/' 3 | img_norm_cfg = dict( 4 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | dict( 9 | type='Resize', 10 | img_scale=[(1333, 480), (1333, 960)], 11 | multiscale_mode='range', 12 | keep_ratio=True), 13 | dict(type='RandomFlip', flip_ratio=0.5), 14 | dict( 15 | type='Normalize', 16 | mean=[123.675, 116.28, 103.53], 17 | std=[58.395, 57.12, 57.375], 18 | to_rgb=True), 19 | dict(type='Pad', size_divisor=32), 20 | dict(type='DefaultFormatBundle'), 21 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 22 | ] 23 | test_pipeline = [ 24 | dict(type='LoadImageFromFile'), 25 | dict( 26 | type='MultiScaleFlipAug', 27 | img_scale=(1333, 800), 28 | flip=False, 29 | transforms=[ 30 | dict(type='Resize', keep_ratio=True), 31 | dict(type='RandomFlip'), 32 | dict( 33 | type='Normalize', 34 | mean=[123.675, 116.28, 103.53], 35 | std=[58.395, 57.12, 57.375], 36 | to_rgb=True), 37 | dict(type='Pad', size_divisor=32), 38 | dict(type='DefaultFormatBundle'), 39 | dict(type='Collect', keys=['img']) 40 | ]) 41 | ] 42 | data = dict( 43 | samples_per_gpu=2, 44 | workers_per_gpu=2, 45 | train=dict( 46 | type='CocoDataset', 47 | ann_file='data/coco/annotations/instances_train2017.json', 48 | img_prefix='data/coco/train2017/', 49 | pipeline=[ 50 | dict(type='LoadImageFromFile'), 51 | dict(type='LoadAnnotations', with_bbox=True), 52 | dict( 53 | type='Resize', 54 | img_scale=[(1333, 480), (1333, 960)], 55 | multiscale_mode='range', 56 | keep_ratio=True), 57 | dict(type='RandomFlip', flip_ratio=0.5), 58 | dict( 59 | type='Normalize', 60 | mean=[123.675, 116.28, 103.53], 61 | std=[58.395, 57.12, 57.375], 62 | to_rgb=True), 63 | dict(type='Pad', size_divisor=32), 64 | dict(type='DefaultFormatBundle'), 65 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 66 | ]), 67 | val=dict( 68 | type='CocoDataset', 69 | ann_file='data/coco/annotations/instances_val2017.json', 70 | img_prefix='data/coco/val2017/', 71 | pipeline=[ 72 | dict(type='LoadImageFromFile'), 73 | dict( 74 | type='MultiScaleFlipAug', 75 | img_scale=(1333, 800), 76 | flip=False, 77 | transforms=[ 78 | dict(type='Resize', keep_ratio=True), 79 | dict(type='RandomFlip'), 80 | dict( 81 | type='Normalize', 82 | mean=[123.675, 116.28, 103.53], 83 | std=[58.395, 57.12, 57.375], 84 | to_rgb=True), 85 | dict(type='Pad', size_divisor=32), 86 | dict(type='DefaultFormatBundle'), 87 | dict(type='Collect', keys=['img']) 88 | ]) 89 | ]), 90 | test=dict( 91 | type='CocoDatasetAdv', # Load Labels for FIA / NAA / RPA 92 | ann_file='data/coco/annotations/instances_val2017.json', 93 | img_prefix='data/coco/val2017/', 94 | pipeline=[ 95 | dict(type='LoadImageFromFile'), 96 | dict(type='LoadAnnotations', with_bbox=True), # Load Labels for FIA / NAA / RPA 97 | dict( 98 | type='MultiScaleFlipAug', 99 | img_scale=(1333, 800), 100 | flip=False, 101 | transforms=[ 102 | dict(type='Resize', keep_ratio=True), 103 | dict(type='RandomFlip'), 104 | dict( 105 | type='Normalize', 106 | mean=[123.675, 116.28, 103.53], 107 | std=[58.395, 57.12, 57.375], 108 | to_rgb=True), 109 | dict(type='Pad', size_divisor=32), 110 | dict(type='DefaultFormatBundle'), 111 | # dict(type='Collect', keys=['img']) 112 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) # Load Labels for FIA / NAA / RPA 113 | ]) 114 | ])) 115 | evaluation = dict(interval=1, metric='bbox') 116 | optimizer = dict( 117 | 
type='SGD', 118 | lr=0.01, 119 | momentum=0.9, 120 | weight_decay=0.0001, 121 | paramwise_cfg=dict(bias_lr_mult=2.0, bias_decay_mult=0.0)) 122 | optimizer_config = dict(grad_clip=None) 123 | lr_config = dict( 124 | policy='step', 125 | warmup='linear', 126 | warmup_iters=500, 127 | warmup_ratio=0.1, 128 | step=[16, 22]) 129 | runner = dict(type='EpochBasedRunner', max_epochs=24) 130 | checkpoint_config = dict(interval=1) 131 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 132 | custom_hooks = [dict(type='NumClassCheckHook')] 133 | dist_params = dict(backend='nccl') 134 | log_level = 'INFO' 135 | load_from = None 136 | resume_from = None 137 | workflow = [('train', 1)] 138 | opencv_num_threads = 0 139 | mp_start_method = 'fork' 140 | auto_scale_lr = dict(enable=False, base_batch_size=16) 141 | model = dict( 142 | type='VFNetAdv', 143 | backbone=dict( 144 | type='ResNet', 145 | depth=50, 146 | num_stages=4, 147 | out_indices=(0, 1, 2, 3), 148 | frozen_stages=1, 149 | norm_cfg=dict(type='BN', requires_grad=True), 150 | norm_eval=True, 151 | style='pytorch', 152 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), 153 | neck=dict( 154 | type='FPN', 155 | in_channels=[256, 512, 1024, 2048], 156 | out_channels=256, 157 | start_level=1, 158 | add_extra_convs='on_output', 159 | num_outs=5, 160 | relu_before_extra_convs=True), 161 | bbox_head=dict( 162 | type='VFNetHeadAdv', 163 | num_classes=80, 164 | in_channels=256, 165 | stacked_convs=3, 166 | feat_channels=256, 167 | strides=[8, 16, 32, 64, 128], 168 | center_sampling=False, 169 | dcn_on_last_conv=False, 170 | use_atss=True, 171 | use_vfl=True, 172 | loss_cls=dict( 173 | type='VarifocalLoss', 174 | use_sigmoid=True, 175 | alpha=0.75, 176 | gamma=2.0, 177 | iou_weighted=True, 178 | loss_weight=1.0), 179 | loss_bbox=dict(type='GIoULoss', loss_weight=1.5), 180 | loss_bbox_refine=dict(type='GIoULoss', loss_weight=2.0)), 181 | train_cfg=dict( 182 | assigner=dict(type='ATSSAssigner', topk=9), 183 | allowed_border=-1, 184 | pos_weight=-1, 185 | debug=False), 186 | test_cfg=dict( 187 | nms_pre=1000, 188 | min_bbox_size=0, 189 | score_thr=0.05, 190 | nms=dict(type='nms', iou_threshold=0.6), 191 | max_per_img=100)) 192 | -------------------------------------------------------------------------------- /ummdet/checkpoints/train_cfg/yolov3_d53_mstrain-608_273e_coco.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1) 2 | log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')]) 3 | custom_hooks = [dict(type='NumClassCheckHook')] 4 | dist_params = dict(backend='nccl') 5 | log_level = 'INFO' 6 | load_from = None 7 | resume_from = None 8 | workflow = [('train', 1)] 9 | opencv_num_threads = 0 10 | mp_start_method = 'fork' 11 | auto_scale_lr = dict(enable=False, base_batch_size=64) 12 | model = dict( 13 | type='YOLOV3Adv', 14 | backbone=dict( 15 | type='Darknet', 16 | depth=53, 17 | out_indices=(3, 4, 5), 18 | init_cfg=dict(type='Pretrained', checkpoint='open-mmlab://darknet53')), 19 | neck=dict( 20 | type='YOLOV3Neck', 21 | num_scales=3, 22 | in_channels=[1024, 512, 256], 23 | out_channels=[512, 256, 128]), 24 | bbox_head=dict( 25 | type='YOLOV3Head', 26 | num_classes=80, 27 | in_channels=[512, 256, 128], 28 | out_channels=[1024, 512, 256], 29 | anchor_generator=dict( 30 | type='YOLOAnchorGenerator', 31 | base_sizes=[[(116, 90), (156, 198), (373, 326)], 32 | [(30, 61), (62, 45), (59, 119)], 33 | [(10, 13), (16, 30), (33, 23)]], 34 | 
strides=[32, 16, 8]), 35 | bbox_coder=dict(type='YOLOBBoxCoder'), 36 | featmap_strides=[32, 16, 8], 37 | loss_cls=dict( 38 | type='CrossEntropyLoss', 39 | use_sigmoid=True, 40 | loss_weight=1.0, 41 | reduction='sum'), 42 | loss_conf=dict( 43 | type='CrossEntropyLoss', 44 | use_sigmoid=True, 45 | loss_weight=1.0, 46 | reduction='sum'), 47 | loss_xy=dict( 48 | type='CrossEntropyLoss', 49 | use_sigmoid=True, 50 | loss_weight=2.0, 51 | reduction='sum'), 52 | loss_wh=dict(type='MSELoss', loss_weight=2.0, reduction='sum')), 53 | train_cfg=dict( 54 | assigner=dict( 55 | type='GridAssigner', 56 | pos_iou_thr=0.5, 57 | neg_iou_thr=0.5, 58 | min_pos_iou=0)), 59 | test_cfg=dict( 60 | nms_pre=1000, 61 | min_bbox_size=0, 62 | score_thr=0.05, 63 | conf_thr=0.005, 64 | nms=dict(type='nms', iou_threshold=0.45), 65 | max_per_img=100)) 66 | dataset_type = 'CocoDataset' 67 | data_root = 'data/coco/' 68 | img_norm_cfg = dict(mean=[0, 0, 0], std=[255.0, 255.0, 255.0], to_rgb=True) 69 | train_pipeline = [ 70 | dict(type='LoadImageFromFile', to_float32=True), 71 | dict(type='LoadAnnotations', with_bbox=True), 72 | dict(type='Expand', mean=[0, 0, 0], to_rgb=True, ratio_range=(1, 2)), 73 | dict( 74 | type='MinIoURandomCrop', 75 | min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9), 76 | min_crop_size=0.3), 77 | dict(type='Resize', img_scale=[(320, 320), (608, 608)], keep_ratio=True), 78 | dict(type='RandomFlip', flip_ratio=0.5), 79 | dict(type='PhotoMetricDistortion'), 80 | dict( 81 | type='Normalize', 82 | mean=[0, 0, 0], 83 | std=[255.0, 255.0, 255.0], 84 | to_rgb=True), 85 | dict(type='Pad', size_divisor=32), 86 | dict(type='DefaultFormatBundle'), 87 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 88 | ] 89 | test_pipeline = [ 90 | dict(type='LoadImageFromFile'), 91 | dict( 92 | type='MultiScaleFlipAug', 93 | img_scale=(608, 608), 94 | flip=False, 95 | transforms=[ 96 | dict(type='Resize', keep_ratio=True), 97 | dict(type='RandomFlip'), 98 | dict( 99 | type='Normalize', 100 | mean=[0, 0, 0], 101 | std=[255.0, 255.0, 255.0], 102 | to_rgb=True), 103 | dict(type='Pad', size_divisor=32), 104 | dict(type='ImageToTensor', keys=['img']), 105 | dict(type='Collect', keys=['img']) 106 | ]) 107 | ] 108 | data = dict( 109 | samples_per_gpu=8, 110 | workers_per_gpu=4, 111 | train=dict( 112 | type='CocoDataset', 113 | ann_file='data/coco/annotations/instances_train2017.json', 114 | img_prefix='data/coco/train2017/', 115 | pipeline=[ 116 | dict(type='LoadImageFromFile', to_float32=True), 117 | dict(type='LoadAnnotations', with_bbox=True), 118 | dict( 119 | type='Expand', mean=[0, 0, 0], to_rgb=True, 120 | ratio_range=(1, 2)), 121 | dict( 122 | type='MinIoURandomCrop', 123 | min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9), 124 | min_crop_size=0.3), 125 | dict( 126 | type='Resize', 127 | img_scale=[(320, 320), (608, 608)], 128 | keep_ratio=True), 129 | dict(type='RandomFlip', flip_ratio=0.5), 130 | dict(type='PhotoMetricDistortion'), 131 | dict( 132 | type='Normalize', 133 | mean=[0, 0, 0], 134 | std=[255.0, 255.0, 255.0], 135 | to_rgb=True), 136 | dict(type='Pad', size_divisor=32), 137 | dict(type='DefaultFormatBundle'), 138 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 139 | ]), 140 | val=dict( 141 | type='CocoDataset', 142 | ann_file='data/coco/annotations/instances_val2017.json', 143 | img_prefix='data/coco/val2017/', 144 | pipeline=[ 145 | dict(type='LoadImageFromFile'), 146 | dict( 147 | type='MultiScaleFlipAug', 148 | img_scale=(608, 608), 149 | flip=False, 150 | transforms=[ 151 | dict(type='Resize', 
                    keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[0, 0, 0],
                        std=[255.0, 255.0, 255.0],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img'])
                ])
        ]),
    test=dict(
        type='CocoDatasetAdv',  # Load Labels for FIA / NAA / RPA
        ann_file='data/coco/annotations/instances_val2017.json',
        img_prefix='data/coco/val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True),  # Load Labels for FIA / NAA / RPA
            dict(
                type='MultiScaleFlipAug',
                img_scale=(608, 608),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[0, 0, 0],
                        std=[255.0, 255.0, 255.0],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    # dict(type='ImageToTensor', keys=['img']),
                    dict(type='DefaultFormatBundle'),  # Load Labels for FIA / NAA / RPA
                    # dict(type='Collect', keys=['img'])
                    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])  # Load Labels for FIA / NAA / RPA
                ])
        ]))
optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=2000,
    warmup_ratio=0.1,
    step=[218, 246])
runner = dict(type='EpochBasedRunner', max_epochs=273)
evaluation = dict(interval=1, metric=['bbox'])
-------------------------------------------------------------------------------- /ummdet/components/__init__.py: --------------------------------------------------------------------------------
from .coco import CocoDatasetAdv
-------------------------------------------------------------------------------- /ummdet/detectors/__init__.py: --------------------------------------------------------------------------------
from .yolo import YOLOV3Adv
from .vfnet import VFNetAdv
from .vfnet_head import VFNetHeadAdv
from .faster_rcnn import FasterRCNNAdv
from .mask_rcnn import MaskRCNNAdv
-------------------------------------------------------------------------------- /ummdet/detectors/faster_rcnn.py: --------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .model_hook import ModelHook
from mmdet.models import DETECTORS, TwoStageDetector


@DETECTORS.register_module()
class FasterRCNNAdv(TwoStageDetector, ModelHook):
    """Implementation of `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_"""

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None,
                 init_cfg=None):
        super(FasterRCNNAdv, self).__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
-------------------------------------------------------------------------------- /ummdet/detectors/mask_rcnn.py: --------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
--------------------------------------------------------------------------------
/ummdet/detectors/model_hook.py:
--------------------------------------------------------------------------------
from mmdet.models import SingleStageDetector, TwoStageDetector
import os
import torch


class ModelHook(object):
    """Forward hook mixin for the FIA / NAA / RPA attacks."""

    def __init__(self) -> None:
        super().__init__()

    def forward_bottom(self, backbone_features, img_metas, gt_bboxes, gt_labels, return_loss=False):
        # Run neck + heads on precomputed backbone features and accumulate the
        # detection losses, so gradients can flow back to the input image.
        batch_input_shape = tuple(img_metas[0]["img_shape"][0:2])
        for img_meta in img_metas:
            img_meta['batch_input_shape'] = batch_input_shape
        losses = None
        if isinstance(self, SingleStageDetector):
            losses = self.forward_bottom_one_stage(backbone_features, img_metas, gt_bboxes, gt_labels)
        elif isinstance(self, TwoStageDetector):
            losses = self.forward_bottom_two_stage(backbone_features, img_metas, gt_bboxes, gt_labels)
        if losses is None:
            raise TypeError('ModelHook must be mixed into a SingleStageDetector or TwoStageDetector')

        # Accumulate model losses
        loss_accumulate = torch.tensor(0.0, device=os.environ["device"])
        for loss_name, loss_data in losses.items():
            if 'loss' in loss_name:
                if isinstance(loss_data, torch.Tensor):
                    loss_accumulate = loss_accumulate + loss_data
                elif isinstance(loss_data, list):
                    for loss_item in loss_data:
                        loss_accumulate = loss_accumulate + loss_item
        # Backpropagate here unless the caller asks for the raw loss
        if not return_loss:
            loss_accumulate.backward()
            return
        else:
            return loss_accumulate

    def forward_bottom_one_stage(self, x, img_metas, gt_bboxes, gt_labels):
        if self.with_neck:
            x = self.neck(x)
        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, gt_labels)
        return losses

    def forward_bottom_two_stage(self, x, img_metas, gt_bboxes, gt_labels):
        if self.with_neck:
            x = self.neck(x)
        losses = dict()
        # RPN forward and loss
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes=gt_bboxes,
                gt_labels=None,
                gt_bboxes_ignore=None,
                gt_masks=None,
                proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = None
        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels)
        losses.update(roi_losses)
        return losses
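# ---------------------------------------------------------------------------
# Editor's note -- illustrative sketch, not part of this file. A typical call
# from a feature-level attack step might look like the following (the variable
# names are assumptions, not project API): the adversarial image is pushed
# through the backbone only, then forward_bottom() drives neck + heads and
# backpropagates the summed detection loss so adv_img.grad is populated.
#
#   import os
#   os.environ['device'] = 'cuda:0'           # read by forward_bottom()
#   adv_img.requires_grad_(True)              # preprocessed image batch
#   feats = model.backbone(adv_img)           # tuple of feature maps
#   model.forward_bottom(feats, img_metas, gt_bboxes, gt_labels)
#   grad = adv_img.grad                       # drives the perturbation update
# ---------------------------------------------------------------------------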
--------------------------------------------------------------------------------
/ummdet/detectors/vfnet.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .model_hook import ModelHook
from mmdet.models import SingleStageDetector, DETECTORS


@DETECTORS.register_module()
class VFNetAdv(SingleStageDetector, ModelHook):
    """Implementation of `VarifocalNet
    (VFNet) <https://arxiv.org/abs/2008.13367>`_"""

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(VFNetAdv, self).__init__(backbone, neck, bbox_head, train_cfg,
                                       test_cfg, pretrained, init_cfg)
--------------------------------------------------------------------------------
/ummdet/detectors/yolo.py:
--------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
from .model_hook import ModelHook
from mmdet.models import SingleStageDetector, DETECTORS


@DETECTORS.register_module()
class YOLOV3Adv(SingleStageDetector, ModelHook):

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(YOLOV3Adv, self).__init__(backbone, neck, bbox_head, train_cfg,
                                        test_cfg, pretrained, init_cfg)
--------------------------------------------------------------------------------