├── LICENSE ├── README.md ├── base ├── __init__.py ├── base_trainer.py ├── dataset.py └── parse_config.py ├── configs ├── config_zitspp.yml └── config_zitspp_finetune.yml ├── dataset └── dataloader.py ├── dnnlib ├── __init__.py └── util.py ├── inpainting_metric.py ├── logger ├── __init__.py ├── logger.py └── logger_config.json ├── networks ├── ade20k │ ├── __init__.py │ ├── base.py │ ├── color150.mat │ ├── mobilenet.py │ ├── object150_info.csv │ ├── resnet.py │ ├── segm_lib │ │ ├── __init__.py │ │ ├── nn │ │ │ ├── __init__.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── batchnorm.py │ │ │ │ ├── comm.py │ │ │ │ ├── replicate.py │ │ │ │ ├── tests │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── test_numeric_batchnorm.py │ │ │ │ │ └── test_sync_batchnorm.py │ │ │ │ └── unittest.py │ │ │ └── parallel │ │ │ │ ├── __init__.py │ │ │ │ └── data_parallel.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── dataloader.py │ │ │ ├── dataset.py │ │ │ ├── distributed.py │ │ │ └── sampler.py │ │ │ └── th.py │ └── utils.py ├── basic_module.py ├── discriminators.py ├── ffc.py ├── generators.py ├── inception.py ├── layers.py ├── losses.py ├── mat.py ├── pcp.py ├── trainer.py ├── transformer_layers.py ├── tsr.py ├── upsample.py ├── van.py └── vggNet.py ├── nms ├── cxx │ ├── README.md │ ├── lib │ │ └── solve_csa.so │ └── src │ │ ├── Exception.cc │ │ ├── Exception.hh │ │ ├── Random.cc │ │ ├── Random.hh │ │ ├── String.cc │ │ ├── String.hh │ │ ├── build.sh │ │ ├── csa.cc │ │ ├── csa.hh │ │ ├── csa_defs.h │ │ ├── csa_types.h │ │ ├── kofn.cc │ │ ├── kofn.hh │ │ ├── nms.cc │ │ ├── solve.cc │ │ └── solve.h ├── impl │ ├── bwmorph_thin.py │ ├── correspond_pixels.py │ ├── edges_eval_dir.py │ ├── edges_eval_plot.py │ └── toolbox.py ├── nms_temp.py └── nms_torch.py ├── requirements.txt ├── test.py ├── torch_utils ├── __init__.py ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py └── training_stats.py ├── trainers ├── impl │ ├── bwmorph_thin.py │ ├── correspond_pixels.py │ ├── edges_eval_dir.py │ ├── edges_eval_plot.py │ └── toolbox.py ├── lama_trainers.py ├── lsm_hawp │ ├── detector.py │ ├── lsm_hawp_model.py │ ├── model_config.py │ ├── multi_task_head.py │ └── stacked_hg.py ├── nms_temp.py ├── nms_torch.py └── pl_trainers.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # ZITS-PlusPlus 2 | ZITS++: Image Inpainting by Improving the Incremental Transformer on Structural Priors (TPAMI2023) 3 | 4 | [arxiv paper](https://arxiv.org/abs/2210.05950), 5 | [project page](https://ewrfcas.github.io/ZITS-PlusPlus/) 6 | 7 | [ZITS: CVPR2022 Version](https://github.com/DQiaole/ZITS_inpainting) 8 | 9 | ## TODO 10 | 11 | - [x] Releasing dataset and inference codes. 12 | - [x] Releasing pre-trained weights. 13 | - [ ] Releasing training codes. 14 | 15 | ## Dataset 16 | 17 | Test data HR-Flickr: [Download](https://1drv.ms/u/s!AqmYPmoRZryegRz79ueT2gVqWR4T?e=LTZMZM). 18 | 19 | Note that the HR-Flickr Dataset includes images obtained from [Flickr](https://www.flickr.com/). Use of the images must abide by the Flickr Terms of Use. 20 | We do not own the copyright of the images. 
They are solely provided for researchers and educators who wish to use the dataset for non-commercial research and/or educational purposes. 22 | 23 | ## Pre-trained Models 24 | 
25 | 1. model_256: [Download](https://1drv.ms/u/s!AqmYPmoRZryegR1XjcmbjLV2OTk1?e=ToOT2d). 26 | 
27 | 2. model_512: [Download](https://1drv.ms/u/s!AqmYPmoRZryegR9OPEgqq7LvgqJR?e=4Erzvr). 28 | 
29 | 3. LSM-HAWP (line detector from MST): [Download](https://drive.google.com/drive/folders/1yg4Nc20D34sON0Ni_IOezjJCFHXKGWUW). 30 | 
31 | ## Install 32 | 
33 | ```
34 | conda create -n zitspp python=3.8
35 | conda activate zitspp
36 | pip install -r requirements.txt
37 | cd nms/cxx/src
38 | source build.sh
39 | ```
40 | 
41 | ## Test 42 | 
43 | Please use model_256 for images whose short sides are 256 or shorter. For larger images, use model_512 instead. (A small dispatch sketch is given after the Acknowledgments.) 44 | 
45 | 256 images
46 | ```
47 | CUDA_VISIBLE_DEVICES=0 python test.py --config configs/config_zitspp.yml \
48 | --exp_name <exp_name> \
49 | --ckpt_resume ckpts/model_256/models/last.ckpt \
50 | --save_path ./outputs/model_256 \
51 | --img_dir <img_dir> \
52 | --mask_dir <mask_dir> \
53 | --wf_ckpt ckpts/best_lsm_hawp.pth \
54 | --use_ema \
55 | --test_size 256 \
56 | --object_removal # optional
57 | ```
58 | 
59 | 512 images
60 | ```
61 | CUDA_VISIBLE_DEVICES=0 python test.py --config configs/config_zitspp_finetune.yml \
62 | --exp_name <exp_name> \
63 | --ckpt_resume ckpts/model_512/models/last.ckpt \
64 | --save_path ./outputs/model_512 \
65 | --img_dir <img_dir> \
66 | --mask_dir <mask_dir> \
67 | --wf_ckpt ckpts/best_lsm_hawp.pth \
68 | --use_ema \
69 | --test_size 512 \
70 | --object_removal # optional
71 | ```
72 | 
73 | ## Acknowledgments 74 | 
75 | * This repo is built upon [MST](https://github.com/ewrfcas/MST_inpainting), [LaMa](https://github.com/saic-mdal/lama), and [ZITS](https://github.com/DQiaole/ZITS_inpainting).
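The two test commands above differ only in the config file, checkpoint directory, and `--test_size`. Below is a minimal dispatch sketch that picks the right pair from an input's short side; the `run_zitspp` helper and its PIL-based size probe are our own illustration, not part of the repo:

```python
import subprocess
from PIL import Image

def run_zitspp(img_dir, mask_dir, sample_img, exp_name="demo"):
    # pick model_256 when the short side is <= 256, otherwise model_512
    size = 256 if min(Image.open(sample_img).size) <= 256 else 512
    cfg = "configs/config_zitspp.yml" if size == 256 else "configs/config_zitspp_finetune.yml"
    subprocess.run([
        "python", "test.py", "--config", cfg,
        "--exp_name", exp_name,
        "--ckpt_resume", f"ckpts/model_{size}/models/last.ckpt",
        "--save_path", f"./outputs/model_{size}",
        "--img_dir", img_dir, "--mask_dir", mask_dir,
        "--wf_ckpt", "ckpts/best_lsm_hawp.pth",
        "--use_ema", "--test_size", str(size),
    ], check=True)
```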
76 | 77 | ## Cite 78 | 
79 | If you find our program helpful, please consider citing: 80 | 
81 | ```
82 | @article{cao2023zits++,
83 |   title={ZITS++: Image Inpainting by Improving the Incremental Transformer on Structural Priors},
84 |   author={Cao, Chenjie and Dong, Qiaole and Fu, Yanwei},
85 |   journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
86 |   year={2023},
87 |   publisher={IEEE}
88 | }
89 | ```
90 | 91 | 92 | 
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_trainer import BaseTrainer
2 | 
--------------------------------------------------------------------------------
/base/base_trainer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | 
4 | import torch
5 | 
6 | 
7 | class BaseTrainer:
8 |     """
9 |     Base class for all trainers
10 |     """
11 | 
12 |     def __init__(self, G, D, g_opt, d_opt, g_sche=None, d_sche=None, config=None, total_rank=1, writer=None, rank=0):
13 |         self.config = config
14 |         self.rank = rank
15 |         if rank == 0:
16 |             self.logger = config.get_logger('trainer')
17 |         else:
18 |             self.logger = None
19 | 
20 |         # setup GPU device if available, move models into configured device
21 |         # self.device, device_ids = self._prepare_device(total_rank)
22 |         self.device = rank
23 |         self.G = G
24 |         self.D = D
25 |         if hasattr(G, 'module'):
26 |             self.G_ema = copy.deepcopy(G.module).eval()
27 |         else:
28 |             self.G_ema = copy.deepcopy(G).eval()
29 |         self.g_opt = g_opt
30 |         self.d_opt = d_opt
31 |         self.g_sche = g_sche
32 |         self.d_sche = d_sche
33 | 
34 |         self.total_step = config['trainer']['total_step']
35 |         self.sample_period = config['trainer']['sample_period']
36 |         self.eval_period = config['trainer']['eval_period']
37 |         self.save_period = config['trainer']['save_period']
38 | 
39 |         self.sample_path = os.path.join(config.log_dir, 'samples')
40 |         os.makedirs(self.sample_path, exist_ok=True)
41 |         self.eval_path = os.path.join(config.log_dir, 'validation')
42 |         os.makedirs(self.eval_path, exist_ok=True)
43 | 
44 |         self.global_step = 0
45 |         self.best_metric = dict()
46 |         self.metric = dict()
47 | 
48 |         self.checkpoint_dir = config.save_dir
49 | 
50 |         # setup visualization writer instance
51 |         self.writer = writer
52 | 
53 |     def train(self):
54 |         raise NotImplementedError
55 | 
56 |     def _prepare_device(self, n_gpu_use):
57 |         """
58 |         setup GPU device if available, move models into configured device
59 |         """
60 |         n_gpu = torch.cuda.device_count()
61 |         if n_gpu_use > 0 and n_gpu == 0:
62 |             self.logger.warning("Warning: There's no GPU available on this machine, "
63 |                                 "training will be performed on CPU.")
64 |             n_gpu_use = 0
65 |         if n_gpu_use > n_gpu:
66 |             self.logger.warning("Warning: The number of GPUs configured to use is {}, but only {} are available "
67 |                                 "on this machine.".format(n_gpu_use, n_gpu))
68 |             n_gpu_use = n_gpu
69 |         device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
70 |         list_ids = list(range(n_gpu_use))
71 |         return device, list_ids
72 | 
73 |     def _save_checkpoint(self, postfix='last'):
74 |         """
75 |         Saving checkpoints
76 |         :param postfix: suffix for the checkpoint filename;
77 |             the state is saved as ckpt_<postfix>.pth,
78 |             e.g. 'last' for the most recent checkpoint
79 |         """
80 |         raw_G = self.G.module if hasattr(self.G, "module") else self.G
81 |         raw_D = self.D.module if hasattr(self.D, "module") else self.D
82 |         raw_Gema = self.G_ema.module if hasattr(self.G_ema, "module") else self.G_ema
83 |         state = {
84 |             'global_step': self.global_step,
85 |             'G_model': raw_G.state_dict(),
86 |             'D_model': raw_D.state_dict(),
87 |             'G_ema': raw_Gema.state_dict(),
88 |             'G_opt': self.g_opt.state_dict(),
89 |             'D_opt': self.d_opt.state_dict(),
90 |             'best_metric': self.best_metric,
91 |             'metric': self.metric,
92 |             'config': self.config
93 |         }
94 |         save_path = str(self.checkpoint_dir / f'ckpt_{postfix}.pth')
95 |         torch.save(state, save_path)
96 |         self.logger.info(f"Saving current model to: ckpt_{postfix}.pth ...")
97 | 
98 |     def _resume_checkpoint(self, resume_path):
99 |         """
100 |         Resume from saved checkpoints
101 |         :param resume_path: Checkpoint path to be resumed
102 |         """
103 |         resume_path = str(resume_path)
104 |         if self.rank == 0:
105 |             self.logger.info("Loading checkpoint: {} ...".format(resume_path))
106 |             print("Loading checkpoint: {} ...".format(resume_path))
107 |         checkpoint = torch.load(resume_path, map_location='cpu')
108 |         self.global_step = checkpoint['global_step']
109 |         self.best_metric = checkpoint['best_metric']
110 | 
111 |         state_dict = {}
112 |         for k, v in checkpoint['G_model'].items():
113 |             state_dict[k.replace('module.', '')] = v
114 |         if hasattr(self.G, 'module'):
115 |             self.G.module.load_state_dict(state_dict)
116 |         else:
117 |             self.G.load_state_dict(state_dict)
118 |         state_dict = {}
119 |         for k, v in checkpoint['D_model'].items():
120 |             state_dict[k.replace('module.', '')] = v
121 |         if hasattr(self.D, 'module'):
122 |             self.D.module.load_state_dict(state_dict)
123 |         else:
124 |             self.D.load_state_dict(state_dict)
125 | 
126 |         if hasattr(self.G_ema, 'module'):
127 |             self.G_ema.module.load_state_dict(checkpoint['G_ema'])
128 |         else:
129 |             self.G_ema.load_state_dict(checkpoint['G_ema'])
130 | 
131 |         # load opt
132 |         if self.rank == 0:
133 |             print("Loading optimizer: {} ...".format(resume_path))
134 |         self.g_opt.load_state_dict(checkpoint['G_opt'])
135 |         self.d_opt.load_state_dict(checkpoint['D_opt'])
136 | 
137 |         # load sche
138 |         for _ in range(self.global_step):
139 |             self.g_sche.step()
140 |             self.d_sche.step()
141 | 
142 |         if self.rank == 0:
143 |             print("Checkpoint loaded. Resume training from global step {}".format(self.global_step))
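`BaseTrainer` keeps an exponential-moving-average copy of the generator (`G_ema`) but leaves the actual update rule to its subclasses; the configs set `trainer.ema_beta: 0.995`. A minimal sketch of the implied EMA step follows; the function name and placement are illustrative only, not the repo's actual update code:

```python
import torch

@torch.no_grad()
def update_ema(G, G_ema, beta=0.995):
    # w_ema <- beta * w_ema + (1 - beta) * w
    for p_ema, p in zip(G_ema.parameters(), G.parameters()):
        p_ema.copy_(p.lerp(p_ema, beta))
    # buffers (e.g. BatchNorm running stats) are copied over directly
    for b_ema, b in zip(G_ema.buffers(), G.buffers()):
        b_ema.copy_(b)
```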
--------------------------------------------------------------------------------
/base/parse_config.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import shutil
4 | from datetime import datetime
5 | from functools import reduce, partial
6 | from operator import getitem
7 | from pathlib import Path
8 | 
9 | import yaml
10 | 
11 | from logger import setup_logging
12 | 
13 | 
14 | class ConfigParser:
15 |     def __init__(self, config, cfg_fname, resume=None, modification=None, run_id=None, mkdir=True):
16 |         """
17 |         Class to parse the configuration yaml file. Handles hyperparameters for training, initialization of modules, checkpoint saving
18 |         and the logging module.
19 |         :param config: Dict containing configurations and hyperparameters for training, i.e. the contents of a `config.yml` file.
20 |         :param resume: String, path to the checkpoint being loaded.
21 |         :param modification: Dict keychain:value, specifying position values to be replaced from config dict.
22 |         :param run_id: Unique identifier for training processes. Used to save checkpoints and training log.
Timestamp is being used as default 23 | """ 24 | # load config file and apply modification 25 | self._config = _update_config(config, modification) 26 | self.resume = resume 27 | 28 | # set save_dir where trained models and log will be saved. 29 | save_dir = Path('ckpts') 30 | 31 | if run_id is None: # use timestamp as default run-id 32 | if resume is not None: 33 | run_id = str(resume).split('/')[-1] 34 | else: 35 | run_id = datetime.now().strftime(r'%m%d_%H%M%S') 36 | self._save_dir = save_dir / run_id / 'models' 37 | self._log_dir = save_dir / run_id / 'log' 38 | 39 | # make directory for saving checkpoints and log. 40 | if mkdir: 41 | self.save_dir.mkdir(parents=True, exist_ok=True) 42 | self.log_dir.mkdir(parents=True, exist_ok=True) 43 | 44 | # save updated config file to the checkpoint dir 45 | if not os.path.exists(str(save_dir / run_id / 'config.yml')): 46 | shutil.copy(cfg_fname, str(save_dir / run_id / 'config.yml')) 47 | 48 | # configure logging module 49 | setup_logging(self.log_dir) 50 | self.log_levels = { 51 | 0: logging.WARNING, 52 | 1: logging.INFO, 53 | 2: logging.DEBUG 54 | } 55 | 56 | @classmethod 57 | def from_args(cls, args, options='', mkdir=True): 58 | """ 59 | Initialize this class from some cli arguments. Used in train, test. 60 | """ 61 | for opt in options: 62 | args.add_argument(*opt.flags, default=None, type=opt.type) 63 | import argparse 64 | if not isinstance(args, tuple) and not isinstance(args, argparse.Namespace): 65 | args = args.parse_args() 66 | if args.resume is not None: 67 | resume = Path(args.resume) 68 | cfg_fname = resume / 'config.yml' 69 | else: 70 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 71 | assert args.config is not None, msg_no_cfg 72 | resume = None 73 | cfg_fname = Path(args.config) 74 | 75 | config = yaml.load(open(cfg_fname), Loader=yaml.FullLoader) 76 | 77 | # parse custom cli options into dictionary 78 | modification = {opt.target: getattr(args, _get_opt_name(opt.flags)) for opt in options} 79 | return cls(config, cfg_fname, resume, modification, run_id=args.exp_name, mkdir=mkdir) 80 | 81 | def init_obj(self, name, module, *args, **kwargs): 82 | """ 83 | Finds a function handle with the name given as 'type' in config, and returns the 84 | instance initialized with corresponding arguments given. 85 | 86 | `object = config.init_obj('name', module, a, b=1)` 87 | is equivalent to 88 | `object = module.name(a, b=1)` 89 | """ 90 | module_name = self[name]['type'] 91 | module_args = dict(self[name]['args']) 92 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 93 | module_args.update(kwargs) 94 | return getattr(module, module_name)(*args, **module_args) 95 | 96 | def init_ftn(self, name, module, *args, **kwargs): 97 | """ 98 | Finds a function handle with the name given as 'type' in config, and returns the 99 | function with given arguments fixed with functools.partial. 100 | 101 | `function = config.init_ftn('name', module, a, b=1)` 102 | is equivalent to 103 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`. 
104 | """ 105 | module_name = self[name]['type'] 106 | module_args = dict(self[name]['args']) 107 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 108 | module_args.update(kwargs) 109 | return partial(getattr(module, module_name), *args, **module_args) 110 | 111 | def __getitem__(self, name): 112 | """Access items like ordinary dict.""" 113 | return self.config[name] 114 | 115 | def get_logger(self, name, verbosity=2): 116 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, 117 | self.log_levels.keys()) 118 | assert verbosity in self.log_levels, msg_verbosity 119 | logger = logging.getLogger(name) 120 | logger.setLevel(self.log_levels[verbosity]) 121 | return logger 122 | 123 | # setting read-only attributes 124 | @property 125 | def config(self): 126 | return self._config 127 | 128 | @property 129 | def save_dir(self): 130 | return self._save_dir 131 | 132 | @property 133 | def log_dir(self): 134 | return self._log_dir 135 | 136 | 137 | # helper functions to update config dict with custom cli options 138 | def _update_config(config, modification): 139 | if modification is None: 140 | return config 141 | 142 | for k, v in modification.items(): 143 | if v is not None: 144 | _set_by_path(config, k, v) 145 | return config 146 | 147 | 148 | def _get_opt_name(flags): 149 | for flg in flags: 150 | if flg.startswith('--'): 151 | return flg.replace('--', '') 152 | return flags[0].replace('--', '') 153 | 154 | 155 | def _set_by_path(tree, keys, value): 156 | """Set a value in a nested object in tree by sequence of keys.""" 157 | keys = keys.split(';') 158 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 159 | 160 | 161 | def _get_by_path(tree, keys): 162 | """Access a nested object in tree by sequence of keys.""" 163 | return reduce(getitem, keys, tree) 164 | -------------------------------------------------------------------------------- /configs/config_zitspp.yml: -------------------------------------------------------------------------------- 1 | train_flist: '/your_data_path/places365_standard/places2_all/train_list.txt' 2 | val_flist: '/your_data_path/places365_standard/places2_all/test_sub_list.txt' 3 | test_path: '/your_data_path/places365_standard/val_256img_for_eval' 4 | 5 | # see https://drive.google.com/drive/folders/1eU6VaTWGdgCXXWueCXilt6oxHdONgUgf?usp=sharing for downloading masks 6 | train_mask_flist: [ '/your_mask_path/irregular_mask/irregular_lama_mask_list.txt', 7 | '/your_mask_path/coco_mask/coco_mask_list.txt' ] 8 | test_mask_flist: '/your_mask_path/test_mask' 9 | 10 | batch_size: 32 # input batch size for training 11 | num_workers: 16 12 | sample_size: 12 13 | fp16: false 14 | 15 | # Dataset settings 16 | data_class: 'base.dataset.DynamicDataset_gradient_line' 17 | dataset: 18 | rect_mask_rate: 0.0 19 | train_line_path: "places2_train_wireframes" 20 | eval_line_path: "places2_val_wireframes" 21 | round: 64 22 | str_size: 256 23 | input_size: 256 # size for eval 24 | 25 | # model settings 26 | structure_upsample_class: 'networks.upsample.StructureUpsampling4' 27 | edgeline_tsr_class: 'networks.tsr.EdgeLineGPT256RelBCE_edge_pred_infer' 28 | grad_tsr_class: 'networks.tsr.GradientGPT256RelBCE' 29 | PLTrainer: 'trainers.pl_trainers.FinetunePLTrainer' 30 | 31 | g_class: 'networks.generators.FTRModel' 32 | g_args: 33 | use_gradient: True 34 | use_GFBlock: False 35 | activation: 'swish' 36 | use_VAN_between_FFC: False 37 | van_kernel_size: 21 38 | van_dilation: 3 39 | prior_ch: 5 40 | 
rezero_for_mpe: True 41 | rel_pos_num: 128 42 | 43 | d_class: 'networks.discriminators.NLayerDiscriminator' 44 | d_args: 45 | input_nc: 3 46 | 47 | 48 | # pretrained ckpt settings (used for finetuning only) 49 | resume_structure_upsample: none 50 | resume_edgeline_tsr: none 51 | resume_grad_tsr: none 52 | resume_ftr: none 53 | 54 | 55 | # Trainer settings 56 | trainer: 57 | fix_256: True 58 | Turning_Point: 10000 59 | total_step: 150000 60 | sample_period: 1000 61 | eval_period: 2000 62 | save_period: 1000 63 | logging_every: 50 64 | ema_beta: 0.995 65 | sample_with_center_mask: false 66 | # loss 67 | l1: 68 | use_l1: true 69 | weight_missing: 0 70 | weight_known: 10.0 71 | adversarial: 72 | weight: 10.0 73 | gp_coef: 0.001 74 | mask_as_fake_target: true 75 | allow_scale_mask: true 76 | extra_mask_weight_for_gen: 0.0 77 | use_unmasked_for_gen: true 78 | use_unmasked_for_discr: true 79 | mask_scale_mode: 'maxpool' 80 | perceptual: 81 | weight: 0 82 | resnet_pl: 83 | weight: 30.0 84 | # mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/ 85 | # wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth 86 | weights_path: './' 87 | feature_matching: 88 | weight: 100.0 89 | 90 | # opt settings 91 | optimizer: 92 | warmup_steps: 0 93 | decay_steps: [50000, 100000] 94 | decay_rate: 0.5 95 | g_opt: 96 | lr: 3.0e-4 97 | beta1: 0 98 | beta2: 0.99 99 | d_opt: 100 | lr: 1.0e-4 101 | beta1: 0 102 | beta2: 0.99 -------------------------------------------------------------------------------- /configs/config_zitspp_finetune.yml: -------------------------------------------------------------------------------- 1 | train_flist: '/your_data_path/places365_standard/places2_all/train_list.txt' 2 | val_flist: '/your_data_path/places365_standard/places2_all/test_sub_list.txt' 3 | test_path: '/your_data_path/places365_standard/val_256img_for_eval' 4 | 5 | # see https://drive.google.com/drive/folders/1eU6VaTWGdgCXXWueCXilt6oxHdONgUgf?usp=sharing for downloading masks 6 | train_mask_flist: [ '/your_mask_path/irregular_mask/irregular_lama_mask_list.txt', 7 | '/your_mask_path/coco_mask/coco_mask_list.txt' ] 8 | test_mask_flist: '/your_mask_path/test_mask' 9 | 10 | batch_size: 12 # input batch size for training 11 | num_workers: 12 12 | sample_size: 12 13 | fp16: false 14 | 15 | # Dataset settings 16 | data_class: 'base.dataset.DynamicDataset_gradient_line' 17 | dataset: 18 | rect_mask_rate: 0.0 19 | train_line_path: "places2_train_wireframes" 20 | eval_line_path: "places2_val_wireframes" 21 | round: 64 22 | str_size: 256 23 | input_size: 512 # size for eval 24 | 25 | # model settings 26 | structure_upsample_class: 'networks.upsample.StructureUpsampling4' 27 | edgeline_tsr_class: 'networks.tsr.EdgeLineGPT256RelBCE_edge_pred_infer' 28 | grad_tsr_class: 'networks.tsr.GradientGPT256RelBCE' 29 | PLTrainer: 'trainers.pl_trainers.FinetunePLTrainer_nms_threshold' 30 | 31 | g_class: 'networks.generators.FTRModel' 32 | g_args: 33 | use_gradient: False 34 | use_GFBlock: False 35 | activation: 'swish' 36 | use_VAN_between_FFC: False 37 | van_kernel_size: 21 38 | van_dilation: 3 39 | prior_ch: 3 40 | rezero_for_mpe: True 41 | rel_pos_num: 128 42 | 43 | d_class: 'networks.discriminators.NLayerDiscriminator' 44 | d_args: 45 | input_nc: 3 46 | 47 | 48 | # pretrained ckpt settings 49 | resume_structure_upsample: none 50 | resume_edgeline_tsr: none 51 | resume_grad_tsr: none 52 | resume_ftr: none 53 | 54 | 55 | # Trainer settings 56 | 
trainer: 57 | fix_256: False 58 | Turning_Point: 10000 59 | total_step: 150000 60 | sample_period: 1000 61 | eval_period: 2000 62 | save_period: 1000 63 | logging_every: 50 64 | ema_beta: 0.995 65 | sample_with_center_mask: false 66 | # loss 67 | l1: 68 | use_l1: true 69 | weight_missing: 0 70 | weight_known: 10.0 71 | adversarial: 72 | weight: 10.0 73 | gp_coef: 0.001 74 | mask_as_fake_target: true 75 | allow_scale_mask: true 76 | extra_mask_weight_for_gen: 0.0 77 | use_unmasked_for_gen: true 78 | use_unmasked_for_discr: true 79 | mask_scale_mode: 'maxpool' 80 | perceptual: 81 | weight: 0 82 | resnet_pl: 83 | weight: 30.0 84 | # mkdir -p ade20k/ade20k-resnet50dilated-ppm_deepsup/ 85 | # wget -P ade20k/ade20k-resnet50dilated-ppm_deepsup/ http://sceneparsing.csail.mit.edu/model/pytorch/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth 86 | weights_path: './' 87 | feature_matching: 88 | weight: 100.0 89 | 90 | # opt settings 91 | optimizer: 92 | warmup_steps: 0 93 | decay_steps: [50000, 100000] 94 | decay_rate: 0.5 95 | g_opt: 96 | lr: 3.0e-4 97 | beta1: 0 98 | beta2: 0.99 99 | d_opt: 100 | lr: 1.0e-4 101 | beta1: 0 102 | beta2: 0.99 103 | -------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms.functional as F 8 | from skimage.color import rgb2gray 9 | 10 | 11 | def resize(img, height, width, center_crop=False): 12 | imgh, imgw = img.shape[0:2] 13 | 14 | if center_crop and imgh != imgw: 15 | # center crop 16 | side = np.minimum(imgh, imgw) 17 | j = (imgh - side) // 2 18 | i = (imgw - side) // 2 19 | img = img[j:j + side, i:i + side, ...] 20 | 21 | if imgh > height and imgw > width: 22 | inter = cv2.INTER_AREA 23 | else: 24 | inter = cv2.INTER_LINEAR 25 | img = cv2.resize(img, (width, height), interpolation=inter) 26 | 27 | return img 28 | 29 | 30 | ones_filter = np.ones((3, 3), dtype=np.float32) 31 | d_filter1 = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32) 32 | d_filter2 = np.array([[0, 0, 0], [1, 1, 0], [1, 1, 0]], dtype=np.float32) 33 | d_filter3 = np.array([[0, 1, 1], [0, 1, 1], [0, 0, 0]], dtype=np.float32) 34 | d_filter4 = np.array([[0, 0, 0], [0, 1, 1], [0, 1, 1]], dtype=np.float32) 35 | 36 | 37 | def load_masked_position_encoding(mask): 38 | ori_mask = mask.copy() 39 | ori_h, ori_w = ori_mask.shape[0:2] 40 | ori_mask = ori_mask / 255 41 | mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA) 42 | mask[mask > 0] = 255 43 | h, w = mask.shape[0:2] 44 | mask3 = mask.copy() 45 | mask3 = 1. 
- (mask3 / 255.0) 46 | pos = np.zeros((h, w), dtype=np.int32) 47 | direct = np.zeros((h, w, 4), dtype=np.int32) 48 | i = 0 49 | while np.sum(1 - mask3) > 0: 50 | i += 1 51 | mask3_ = cv2.filter2D(mask3, -1, ones_filter) 52 | mask3_[mask3_ > 0] = 1 53 | sub_mask = mask3_ - mask3 54 | pos[sub_mask == 1] = i 55 | 56 | m = cv2.filter2D(mask3, -1, d_filter1) 57 | m[m > 0] = 1 58 | m = m - mask3 59 | direct[m == 1, 0] = 1 60 | 61 | m = cv2.filter2D(mask3, -1, d_filter2) 62 | m[m > 0] = 1 63 | m = m - mask3 64 | direct[m == 1, 1] = 1 65 | 66 | m = cv2.filter2D(mask3, -1, d_filter3) 67 | m[m > 0] = 1 68 | m = m - mask3 69 | direct[m == 1, 2] = 1 70 | 71 | m = cv2.filter2D(mask3, -1, d_filter4) 72 | m[m > 0] = 1 73 | m = m - mask3 74 | direct[m == 1, 3] = 1 75 | 76 | mask3 = mask3_ 77 | 78 | abs_pos = pos.copy() 79 | rel_pos = pos / (256 / 2) # to 0~1 maybe larger than 1 80 | rel_pos = (rel_pos * 128).astype(np.int32) 81 | rel_pos = np.clip(rel_pos, 0, 128 - 1) 82 | 83 | if ori_w != w or ori_h != h: 84 | rel_pos = cv2.resize(rel_pos, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST) 85 | rel_pos[ori_mask == 0] = 0 86 | direct = cv2.resize(direct, (ori_w, ori_h), interpolation=cv2.INTER_NEAREST) 87 | direct[ori_mask == 0, :] = 0 88 | 89 | return rel_pos, abs_pos, direct 90 | 91 | 92 | def to_tensor(img, norm=False): 93 | # img = Image.fromarray(img) 94 | img_t = F.to_tensor(img).float() 95 | if norm: 96 | img_t = F.normalize(img_t, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 97 | return img_t 98 | 99 | 100 | class InpaintingDataset(torch.utils.data.Dataset): 101 | def __init__(self, img_dir, mask_dir, test_size=None, use_gradient=False): 102 | super(InpaintingDataset, self).__init__() 103 | self.test_size = test_size 104 | self.use_gradient = use_gradient 105 | if img_dir.endswith(".txt"): 106 | with open(img_dir, 'r') as f: 107 | data = f.readlines() 108 | self.data = [d.strip() for d in data] 109 | else: 110 | data = glob(img_dir + '/*') 111 | self.data = sorted(data, key=lambda x: x.split('/')[-1]) 112 | 113 | mask_list = glob(mask_dir + '/*') 114 | self.mask_list = sorted(mask_list, key=lambda x: x.split('/')[-1]) 115 | 116 | print('Image num:', len(self.data)) 117 | print('Mask num:', len(self.mask_list)) 118 | 119 | 120 | def __len__(self): 121 | return len(self.data) 122 | 123 | def __getitem__(self, index): 124 | img = cv2.imread(self.data[index]) 125 | if self.test_size is not None: 126 | img = resize(img, self.test_size, self.test_size) 127 | img = img[:, :, ::-1] 128 | # resize/crop if needed 129 | imgh, imgw, _ = img.shape 130 | img_512 = resize(img, 512, 512) 131 | img_256 = resize(img, 256, 256) 132 | 133 | # load mask 134 | mask = cv2.imread(self.mask_list[index % len(self.mask_list)], cv2.IMREAD_GRAYSCALE) 135 | mask = cv2.resize(mask, (imgw, imgh), interpolation=cv2.INTER_NEAREST) 136 | mask = (mask > 127).astype(np.uint8) * 255 137 | 138 | mask_256 = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA) 139 | mask_256[mask_256 > 0] = 255 140 | mask_512 = cv2.resize(mask, (512, 512), interpolation=cv2.INTER_NEAREST) 141 | mask_512 = (mask_512 > 127).astype(np.uint8) * 255 142 | 143 | batch = dict() 144 | batch['image'] = to_tensor(img.copy(), norm=True) 145 | batch['img_256'] = to_tensor(img_256, norm=True) 146 | batch['mask'] = to_tensor(mask) 147 | batch['mask_256'] = to_tensor(mask_256) 148 | batch['mask_512'] = to_tensor(mask_512) 149 | batch['img_512'] = to_tensor(img_512) 150 | batch['imgh'] = imgh 151 | batch['imgw'] = imgw 152 | 153 | batch['name'] = 
os.path.basename(self.data[index]) 154 | 155 | # load pos encoding 156 | rel_pos, abs_pos, direct = load_masked_position_encoding(mask) 157 | batch['rel_pos'] = torch.LongTensor(rel_pos) 158 | batch['abs_pos'] = torch.LongTensor(abs_pos) 159 | batch['direct'] = torch.LongTensor(direct) 160 | 161 | # load gradient 162 | if self.use_gradient: 163 | img_gray = rgb2gray(img_256) * 255 164 | sobelx = cv2.Sobel(img_gray, cv2.CV_64F, 1, 0).astype(np.float32) 165 | sobely = cv2.Sobel(img_gray, cv2.CV_64F, 0, 1).astype(np.float32) 166 | 167 | img_gray = rgb2gray(img) * 255 168 | sobelx_hr = cv2.Sobel(img_gray, cv2.CV_64F, 1, 0).astype(np.float32) 169 | sobely_hr = cv2.Sobel(img_gray, cv2.CV_64F, 0, 1).astype(np.float32) 170 | 171 | batch['gradientx'] = torch.from_numpy(sobelx).unsqueeze(0).float() 172 | batch['gradienty'] = torch.from_numpy(sobely).unsqueeze(0).float() 173 | batch['gradientx_hr'] = torch.from_numpy(sobelx_hr).unsqueeze(0).float() 174 | batch['gradienty_hr'] = torch.from_numpy(sobely_hr).unsqueeze(0).float() 175 | 176 | return batch 177 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | 5 | from utils import read_json 6 | 7 | 8 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 9 | """ 10 | Setup logging configuration 11 | """ 12 | log_config = Path(log_config) 13 | if log_config.is_file(): 14 | config = read_json(log_config) 15 | # modify logging paths based on run config 16 | for _, handler in config['handlers'].items(): 17 | if 'filename' in handler: 18 | handler['filename'] = str(save_dir / handler['filename']) 19 | 20 | logging.config.dictConfig(config) 21 | else: 22 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 23 | logging.basicConfig(level=default_level) 24 | -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": 
"logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /networks/ade20k/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * -------------------------------------------------------------------------------- /networks/ade20k/color150.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ewrfcas/ZITS-PlusPlus/de8dd48b17aedd15824842adb7bcca7535daba84/networks/ade20k/color150.mat -------------------------------------------------------------------------------- /networks/ade20k/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This MobileNetV2 implementation is modified from the following repository: 3 | https://github.com/tonylins/pytorch-mobilenet-v2 4 | """ 5 | 6 | import math 7 | 8 | import torch.nn as nn 9 | 10 | from .segm_lib.nn import SynchronizedBatchNorm2d 11 | from .utils import load_url 12 | 13 | BatchNorm2d = SynchronizedBatchNorm2d 14 | 15 | 16 | __all__ = ['mobilenetv2'] 17 | 18 | 19 | model_urls = { 20 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar', 21 | } 22 | 23 | 24 | def conv_bn(inp, oup, stride): 25 | return nn.Sequential( 26 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 27 | BatchNorm2d(oup), 28 | nn.ReLU6(inplace=True) 29 | ) 30 | 31 | 32 | def conv_1x1_bn(inp, oup): 33 | return nn.Sequential( 34 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 35 | BatchNorm2d(oup), 36 | nn.ReLU6(inplace=True) 37 | ) 38 | 39 | 40 | class InvertedResidual(nn.Module): 41 | def __init__(self, inp, oup, stride, expand_ratio): 42 | super(InvertedResidual, self).__init__() 43 | self.stride = stride 44 | assert stride in [1, 2] 45 | 46 | hidden_dim = round(inp * expand_ratio) 47 | self.use_res_connect = self.stride == 1 and inp == oup 48 | 49 | if expand_ratio == 1: 50 | self.conv = nn.Sequential( 51 | # dw 52 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 53 | BatchNorm2d(hidden_dim), 54 | nn.ReLU6(inplace=True), 55 | # pw-linear 56 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 57 | BatchNorm2d(oup), 58 | ) 59 | else: 60 | self.conv = nn.Sequential( 61 | # pw 62 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 63 | BatchNorm2d(hidden_dim), 64 | nn.ReLU6(inplace=True), 65 | # dw 66 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 67 | BatchNorm2d(hidden_dim), 68 | nn.ReLU6(inplace=True), 69 | # pw-linear 70 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 71 | BatchNorm2d(oup), 72 | ) 73 | 74 | def forward(self, x): 75 | if self.use_res_connect: 76 | return x + self.conv(x) 77 | else: 78 | return self.conv(x) 79 | 80 | 81 | class MobileNetV2(nn.Module): 82 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 83 | super(MobileNetV2, self).__init__() 84 | block = InvertedResidual 85 | input_channel = 32 86 | last_channel = 1280 87 | interverted_residual_setting = [ 88 | # t, c, n, s 89 | [1, 16, 1, 1], 90 | [6, 24, 2, 2], 91 | [6, 32, 3, 2], 92 | [6, 64, 4, 2], 93 | [6, 96, 3, 1], 94 | [6, 160, 3, 2], 95 | [6, 320, 1, 1], 96 | ] 97 | 
98 | # building first layer 99 | assert input_size % 32 == 0 100 | input_channel = int(input_channel * width_mult) 101 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 102 | self.features = [conv_bn(3, input_channel, 2)] 103 | # building inverted residual blocks 104 | for t, c, n, s in interverted_residual_setting: 105 | output_channel = int(c * width_mult) 106 | for i in range(n): 107 | if i == 0: 108 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 109 | else: 110 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 111 | input_channel = output_channel 112 | # building last several layers 113 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 114 | # make it nn.Sequential 115 | self.features = nn.Sequential(*self.features) 116 | 117 | # building classifier 118 | self.classifier = nn.Sequential( 119 | nn.Dropout(0.2), 120 | nn.Linear(self.last_channel, n_class), 121 | ) 122 | 123 | self._initialize_weights() 124 | 125 | def forward(self, x): 126 | x = self.features(x) 127 | x = x.mean(3).mean(2) 128 | x = self.classifier(x) 129 | return x 130 | 131 | def _initialize_weights(self): 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 135 | m.weight.data.normal_(0, math.sqrt(2. / n)) 136 | if m.bias is not None: 137 | m.bias.data.zero_() 138 | elif isinstance(m, BatchNorm2d): 139 | m.weight.data.fill_(1) 140 | m.bias.data.zero_() 141 | elif isinstance(m, nn.Linear): 142 | n = m.weight.size(1) 143 | m.weight.data.normal_(0, 0.01) 144 | m.bias.data.zero_() 145 | 146 | 147 | def mobilenetv2(pretrained=False, **kwargs): 148 | """Constructs a MobileNet_V2 model. 149 | 150 | Args: 151 | pretrained (bool): If True, returns a model pre-trained on ImageNet 152 | """ 153 | model = MobileNetV2(n_class=1000, **kwargs) 154 | if pretrained: 155 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False) 156 | return model -------------------------------------------------------------------------------- /networks/ade20k/object150_info.csv: -------------------------------------------------------------------------------- 1 | Idx,Ratio,Train,Val,Stuff,Name 2 | 1,0.1576,11664,1172,1,wall 3 | 2,0.1072,6046,612,1,building;edifice 4 | 3,0.0878,8265,796,1,sky 5 | 4,0.0621,9336,917,1,floor;flooring 6 | 5,0.0480,6678,641,0,tree 7 | 6,0.0450,6604,643,1,ceiling 8 | 7,0.0398,4023,408,1,road;route 9 | 8,0.0231,1906,199,0,bed 10 | 9,0.0198,4688,460,0,windowpane;window 11 | 10,0.0183,2423,225,1,grass 12 | 11,0.0181,2874,294,0,cabinet 13 | 12,0.0166,3068,310,1,sidewalk;pavement 14 | 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul 15 | 14,0.0151,1804,190,1,earth;ground 16 | 15,0.0118,6666,796,0,door;double;door 17 | 16,0.0110,4269,411,0,table 18 | 17,0.0109,1691,160,1,mountain;mount 19 | 18,0.0104,3999,441,0,plant;flora;plant;life 20 | 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall 21 | 20,0.0103,3261,318,0,chair 22 | 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar 23 | 22,0.0074,709,75,1,water 24 | 23,0.0067,3296,315,0,painting;picture 25 | 24,0.0065,1191,106,0,sofa;couch;lounge 26 | 25,0.0061,1516,162,0,shelf 27 | 26,0.0060,667,69,1,house 28 | 27,0.0053,651,57,1,sea 29 | 28,0.0052,1847,224,0,mirror 30 | 29,0.0046,1158,128,1,rug;carpet;carpeting 31 | 30,0.0044,480,44,1,field 32 | 31,0.0044,1172,98,0,armchair 33 | 32,0.0044,1292,184,0,seat 34 | 33,0.0033,1386,138,0,fence;fencing 
35 | 34,0.0031,698,61,0,desk 36 | 35,0.0030,781,73,0,rock;stone 37 | 36,0.0027,380,43,0,wardrobe;closet;press 38 | 37,0.0026,3089,302,0,lamp 39 | 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub 40 | 39,0.0024,804,99,0,railing;rail 41 | 40,0.0023,1453,153,0,cushion 42 | 41,0.0023,411,37,0,base;pedestal;stand 43 | 42,0.0022,1440,162,0,box 44 | 43,0.0022,800,77,0,column;pillar 45 | 44,0.0020,2650,298,0,signboard;sign 46 | 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser 47 | 46,0.0019,367,36,0,counter 48 | 47,0.0018,311,30,1,sand 49 | 48,0.0018,1181,122,0,sink 50 | 49,0.0018,287,23,1,skyscraper 51 | 50,0.0018,468,38,0,fireplace;hearth;open;fireplace 52 | 51,0.0018,402,43,0,refrigerator;icebox 53 | 52,0.0018,130,12,1,grandstand;covered;stand 54 | 53,0.0018,561,64,1,path 55 | 54,0.0017,880,102,0,stairs;steps 56 | 55,0.0017,86,12,1,runway 57 | 56,0.0017,172,11,0,case;display;case;showcase;vitrine 58 | 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table 59 | 58,0.0017,930,109,0,pillow 60 | 59,0.0015,139,18,0,screen;door;screen 61 | 60,0.0015,564,52,1,stairway;staircase 62 | 61,0.0015,320,26,1,river 63 | 62,0.0015,261,29,1,bridge;span 64 | 63,0.0014,275,22,0,bookcase 65 | 64,0.0014,335,60,0,blind;screen 66 | 65,0.0014,792,75,0,coffee;table;cocktail;table 67 | 66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne 68 | 67,0.0014,1309,138,0,flower 69 | 68,0.0013,1112,113,0,book 70 | 69,0.0013,266,27,1,hill 71 | 70,0.0013,659,66,0,bench 72 | 71,0.0012,331,31,0,countertop 73 | 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove 74 | 73,0.0012,369,36,0,palm;palm;tree 75 | 74,0.0012,144,9,0,kitchen;island 76 | 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system 77 | 76,0.0010,324,33,0,swivel;chair 78 | 77,0.0009,304,27,0,boat 79 | 78,0.0009,170,20,0,bar 80 | 79,0.0009,68,6,0,arcade;machine 81 | 80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty 82 | 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle 83 | 82,0.0008,492,49,0,towel 84 | 83,0.0008,2510,269,0,light;light;source 85 | 84,0.0008,440,39,0,truck;motortruck 86 | 85,0.0008,147,18,1,tower 87 | 86,0.0008,583,56,0,chandelier;pendant;pendent 88 | 87,0.0007,533,61,0,awning;sunshade;sunblind 89 | 88,0.0007,1989,239,0,streetlight;street;lamp 90 | 89,0.0007,71,5,0,booth;cubicle;stall;kiosk 91 | 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box 92 | 91,0.0007,135,12,0,airplane;aeroplane;plane 93 | 92,0.0007,83,5,1,dirt;track 94 | 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes 95 | 94,0.0006,1003,104,0,pole 96 | 95,0.0006,182,12,1,land;ground;soil 97 | 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail 98 | 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway 99 | 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock 100 | 99,0.0006,965,114,0,bottle 101 | 100,0.0006,117,13,0,buffet;counter;sideboard 102 | 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card 103 | 102,0.0006,108,9,1,stage 104 | 103,0.0006,557,55,0,van 105 | 104,0.0006,52,4,0,ship 106 | 105,0.0005,99,5,0,fountain 107 | 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter 108 | 107,0.0005,292,31,0,canopy 109 | 108,0.0005,77,9,0,washer;automatic;washer;washing;machine 110 | 109,0.0005,340,38,0,plaything;toy 111 | 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium 112 | 
111,0.0005,465,49,0,stool 113 | 112,0.0005,50,4,0,barrel;cask 114 | 113,0.0005,622,75,0,basket;handbasket 115 | 114,0.0005,80,9,1,waterfall;falls 116 | 115,0.0005,59,3,0,tent;collapsible;shelter 117 | 116,0.0005,531,72,0,bag 118 | 117,0.0005,282,30,0,minibike;motorbike 119 | 118,0.0005,73,7,0,cradle 120 | 119,0.0005,435,44,0,oven 121 | 120,0.0005,136,25,0,ball 122 | 121,0.0005,116,24,0,food;solid;food 123 | 122,0.0004,266,31,0,step;stair 124 | 123,0.0004,58,12,0,tank;storage;tank 125 | 124,0.0004,418,83,0,trade;name;brand;name;brand;marque 126 | 125,0.0004,319,43,0,microwave;microwave;oven 127 | 126,0.0004,1193,139,0,pot;flowerpot 128 | 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna 129 | 128,0.0004,347,36,0,bicycle;bike;wheel;cycle 130 | 129,0.0004,52,5,1,lake 131 | 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine 132 | 131,0.0004,108,13,0,screen;silver;screen;projection;screen 133 | 132,0.0004,201,30,0,blanket;cover 134 | 133,0.0004,285,21,0,sculpture 135 | 134,0.0004,268,27,0,hood;exhaust;hood 136 | 135,0.0003,1020,108,0,sconce 137 | 136,0.0003,1282,122,0,vase 138 | 137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight 139 | 138,0.0003,453,57,0,tray 140 | 139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin 141 | 140,0.0003,397,44,0,fan 142 | 141,0.0003,92,8,1,pier;wharf;wharfage;dock 143 | 142,0.0003,228,18,0,crt;screen 144 | 143,0.0003,570,59,0,plate 145 | 144,0.0003,217,22,0,monitor;monitoring;device 146 | 145,0.0003,206,19,0,bulletin;board;notice;board 147 | 146,0.0003,130,14,0,shower 148 | 147,0.0003,178,28,0,radiator 149 | 148,0.0002,504,57,0,glass;drinking;glass 150 | 149,0.0002,775,96,0,clock 151 | 150,0.0002,421,56,0,flag 152 | -------------------------------------------------------------------------------- /networks/ade20k/resnet.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" 2 | 3 | import math 4 | 5 | import torch.nn as nn 6 | from torch.nn import BatchNorm2d 7 | 8 | from .utils import load_url 9 | 10 | __all__ = ['ResNet', 'resnet50'] 11 | 12 | 13 | model_urls = { 14 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | "3x3 convolution with padding" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, 
kernel_size=1, bias=False) 62 | self.bn1 = BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, num_classes=1000): 98 | self.inplanes = 128 99 | super(ResNet, self).__init__() 100 | self.conv1 = conv3x3(3, 64, stride=2) 101 | self.bn1 = BatchNorm2d(64) 102 | self.relu1 = nn.ReLU(inplace=True) 103 | self.conv2 = conv3x3(64, 64) 104 | self.bn2 = BatchNorm2d(64) 105 | self.relu2 = nn.ReLU(inplace=True) 106 | self.conv3 = conv3x3(64, 128) 107 | self.bn3 = BatchNorm2d(128) 108 | self.relu3 = nn.ReLU(inplace=True) 109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | 111 | self.layer1 = self._make_layer(block, 64, layers[0]) 112 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 113 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 114 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 115 | self.avgpool = nn.AvgPool2d(7, stride=1) 116 | self.fc = nn.Linear(512 * block.expansion, num_classes) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 121 | m.weight.data.normal_(0, math.sqrt(2. / n)) 122 | elif isinstance(m, BatchNorm2d): 123 | m.weight.data.fill_(1) 124 | m.bias.data.zero_() 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | nn.Conv2d(self.inplanes, planes * block.expansion, 131 | kernel_size=1, stride=stride, bias=False), 132 | BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for i in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | x = self.relu1(self.bn1(self.conv1(x))) 145 | x = self.relu2(self.bn2(self.conv2(x))) 146 | x = self.relu3(self.bn3(self.conv3(x))) 147 | x = self.maxpool(x) 148 | 149 | x = self.layer1(x) 150 | x = self.layer2(x) 151 | x = self.layer3(x) 152 | x = self.layer4(x) 153 | 154 | x = self.avgpool(x) 155 | x = x.view(x.size(0), -1) 156 | x = self.fc(x) 157 | 158 | return x 159 | 160 | 161 | def resnet50(pretrained=False, **kwargs): 162 | """Constructs a ResNet-50 model. 
163 | 
164 |     Args:
165 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
166 |     """
167 |     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
168 |     if pretrained:
169 |         model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
170 |     return model
171 | 
172 | 
173 | def resnet18(pretrained=False, **kwargs):
174 |     """Constructs a ResNet-18 model.
175 |     Args:
176 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
177 |     """
178 |     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
179 |     if pretrained:
180 |         # note: model_urls above defines no 'resnet18' entry, so pretrained=True would raise a KeyError here
181 |         model.load_state_dict(load_url(model_urls['resnet18']))
182 |     return model
--------------------------------------------------------------------------------
/networks/ade20k/segm_lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ewrfcas/ZITS-PlusPlus/de8dd48b17aedd15824842adb7bcca7535daba84/networks/ade20k/segm_lib/__init__.py
--------------------------------------------------------------------------------
/networks/ade20k/segm_lib/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
3 | 
--------------------------------------------------------------------------------
/networks/ade20k/segm_lib/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 | 
--------------------------------------------------------------------------------
/networks/ade20k/segm_lib/nn/modules/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import collections
12 | import queue
13 | import threading
14 | 
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 | 
17 | 
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as one-to-one pipe."""
20 | 
21 |     def __init__(self):
22 |         self._result = None
23 |         self._lock = threading.Lock()
24 |         self._cond = threading.Condition(self._lock)
25 | 
26 |     def put(self, result):
27 |         with self._lock:
28 |             assert self._result is None, "Previous result hasn't been fetched."
29 |             self._result = result
30 |             self._cond.notify()
31 | 
32 |     def get(self):
33 |         with self._lock:
34 |             if self._result is None:
35 |                 self._cond.wait()
36 | 
37 |             res = self._result
38 |             self._result = None
39 |             return res
40 | 
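# Hypothetical usage (not part of the original file): FutureResult is a
# one-producer/one-consumer slot.
#
#     future = FutureResult()
#     threading.Thread(target=lambda: future.put(42)).start()
#     assert future.get() == 42  # blocks until put() fires, then clears the slot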
41 | 
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 | 
45 | 
46 | class SlavePipe(_SlavePipeBase):
47 |     """Pipe for master-slave communication."""
48 | 
49 |     def run_slave(self, msg):
50 |         self.queue.put((self.identifier, msg))
51 |         ret = self.result.get()
52 |         self.queue.put(True)
53 |         return ret
54 | 
55 | 
56 | class SyncMaster(object):
57 |     """An abstract `SyncMaster` object.
58 | 
59 |     - During replication, as data parallel triggers a callback on each module, every slave device should
60 |       call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
62 |       collected and passed to a registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to be
64 |       passed back to each slave device.
65 |     """
66 | 
67 |     def __init__(self, master_callback):
68 |         """
69 | 
70 |         Args:
71 |             master_callback: a callback to be invoked after having collected messages from slave devices.
72 |         """
73 |         self._master_callback = master_callback
74 |         self._queue = queue.Queue()
75 |         self._registry = collections.OrderedDict()
76 |         self._activated = False
77 | 
78 |     def register_slave(self, identifier):
79 |         """
80 |         Register a slave device.
81 | 
82 |         Args:
83 |             identifier: an identifier, usually the device id.
84 | 
85 |         Returns: a `SlavePipe` object which can be used to communicate with the master device.
86 | 
87 |         """
88 |         if self._activated:
89 |             assert self._queue.empty(), 'Queue is not clean before next initialization.'
90 |             self._activated = False
91 |             self._registry.clear()
92 |         future = FutureResult()
93 |         self._registry[identifier] = _MasterRegistry(future)
94 |         return SlavePipe(identifier, self._queue, future)
95 | 
96 |     def run_master(self, master_msg):
97 |         """
98 |         Main entry for the master device in each forward pass.
99 |         The messages are first collected from each device (including the master device), and then
100 |         a callback is invoked to compute the message to be sent back to each device
101 |         (including the master device).
102 | 
103 |         Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 |             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 | 
107 |         Returns: the message to be sent back to the master device.
108 | 
109 |         """
110 |         self._activated = True
111 | 
112 |         intermediates = [(0, master_msg)]
113 |         for i in range(self.nr_slaves):
114 |             intermediates.append(self._queue.get())
115 | 
116 |         results = self._master_callback(intermediates)
117 |         assert results[0][0] == 0, 'The first result should belong to the master.'
118 | 
119 |         for i, res in results:
120 |             if i == 0:
121 |                 continue
122 |             self._registry[i].result.put(res)
123 | 
124 |         for i in range(self.nr_slaves):
125 |             assert self._queue.get() is True
126 | 
127 |         return results[0][1]
128 | 
129 |     @property
130 |     def nr_slaves(self):
131 |         return len(self._registry)
132 | 
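The registration/handshake protocol above is easiest to see end to end. Below is a minimal single-process sketch, assuming this module is importable as `comm`; `summing_callback`, the slave values, and the thread scaffolding are illustrative only, not part of the repo:

```python
import threading

from comm import SyncMaster  # assumption: this file is importable as `comm`

def summing_callback(intermediates):
    # intermediates == [(0, master_msg), (id_1, msg_1), ...]; the returned
    # results must keep the master's entry first (see the assert above).
    total = sum(msg for _, msg in intermediates)
    return [(identifier, total) for identifier, _ in intermediates]

master = SyncMaster(summing_callback)
pipes = [master.register_slave(i) for i in (1, 2)]  # register before run_master

replies = {}

def slave(pipe, value):
    # blocks inside run_slave() until the master distributes the result
    replies[pipe.identifier] = pipe.run_slave(value)

threads = [threading.Thread(target=slave, args=(p, v)) for p, v in zip(pipes, (10, 20))]
for t in threads:
    t.start()
print(master.run_master(1))  # prints 31 (= 1 + 10 + 20)
for t in threads:
    t.join()
assert replies == {1: 31, 2: 31}
```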
--------------------------------------------------------------------------------
/networks/ade20k/segm_lib/nn/modules/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import functools
12 | 
13 | from torch.nn.parallel.data_parallel import DataParallel
14 | 
15 | __all__ = [
16 |     'CallbackContext',
17 |     'execute_replication_callbacks',
18 |     'DataParallelWithCallback',
19 |     'patch_replication_callback'
20 | ]
21 | 
22 | 
23 | class CallbackContext(object):
24 |     pass
25 | 
26 | 
27 | def execute_replication_callbacks(modules):
28 |     """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 | 
31 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 | 
33 |     Note that, as all replicated modules are isomorphic, we assign each sub-module a context
34 |     (shared among multiple copies of this module on different devices).
35 |     Through this context, different copies can share some information.
36 | 
37 |     We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 |     of any slave copies.
39 |     """
40 |     master_copy = modules[0]
41 |     nr_modules = len(list(master_copy.modules()))
42 |     ctxs = [CallbackContext() for _ in range(nr_modules)]
43 | 
44 |     for i, module in enumerate(modules):
45 |         for j, m in enumerate(module.modules()):
46 |             if hasattr(m, '__data_parallel_replicate__'):
47 |                 m.__data_parallel_replicate__(ctxs[j], i)
48 | 
49 | 
50 | class DataParallelWithCallback(DataParallel):
51 |     """
52 |     Data Parallel with a replication callback.
53 | 
54 |     A replication callback `__data_parallel_replicate__` of each module is invoked after the module is created by
55 |     the original `replicate` function.
56 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 | 
58 |     Examples:
59 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 |         # sync_bn.__data_parallel_replicate__ will be invoked.
62 |     """
63 | 
64 |     def replicate(self, module, device_ids):
65 |         modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 |         execute_replication_callbacks(modules)
67 |         return modules
68 | 
69 | 
70 | def patch_replication_callback(data_parallel):
71 |     """
72 |     Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/nn/modules/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ewrfcas/ZITS-PlusPlus/de8dd48b17aedd15824842adb7bcca7535daba84/networks/ade20k/segm_lib/nn/modules/tests/__init__.py -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from sync_batchnorm.unittest import TorchTestCase 14 | from torch.autograd import Variable 15 | 16 | 17 | def handy_var(a, unbias=True): 18 | n = a.size(0) 19 | asum = a.sum(dim=0) 20 | as_sum = (a ** 2).sum(dim=0) # a square sum 21 | sumvar = as_sum - asum * asum / n 22 | if unbias: 23 | return sumvar / (n - 1) 24 | else: 25 | return sumvar / n 26 | 27 | 28 | class NumericTestCase(TorchTestCase): 29 | def testNumericBatchNorm(self): 30 | a = torch.rand(16, 10) 31 | bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False) # 2-D (N, C) input requires BatchNorm1d 32 | bn.train() 33 | 34 | a_var1 = Variable(a, requires_grad=True) 35 | b_var1 = bn(a_var1) 36 | loss1 = b_var1.sum() 37 | loss1.backward() 38 | 39 | a_var2 = Variable(a, requires_grad=True) 40 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 41 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 42 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 43 | b_var2 = (a_var2 - a_mean2) / a_std2 44 | loss2 = b_var2.sum() 45 | loss2.backward() 46 | 47 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 48 | self.assertTensorClose(bn.running_var, handy_var(a)) 49 | self.assertTensorClose(a_var1.data, a_var2.data) 50 | self.assertTensorClose(b_var1.data, b_var2.data) 51 | self.assertTensorClose(a_var1.grad, a_var2.grad) 52 | 53 | 54 | if __name__ == '__main__': 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
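Both batch-norm test files define the same `handy_var` helper (it reappears just below). As a quick illustration of what it computes (an editor's sketch, not part of the original tests), it recovers the per-feature variance from raw sums and agrees with `torch.var` along the batch dimension:

```python
import torch

def handy_var(a, unbias=True):
    # Variance from raw sums: sum(a^2) - (sum(a))^2 / n, optionally unbiased.
    n = a.size(0)
    asum = a.sum(dim=0)
    as_sum = (a ** 2).sum(dim=0)
    sumvar = as_sum - asum * asum / n
    return sumvar / (n - 1) if unbias else sumvar / n

a = torch.rand(16, 10)
assert torch.allclose(handy_var(a, unbias=True), a.var(dim=0, unbiased=True), atol=1e-5)
assert torch.allclose(handy_var(a, unbias=False), a.var(dim=0, unbiased=False), atol=1e-5)
```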
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 14 | from sync_batchnorm.unittest import TorchTestCase 15 | from torch.autograd import Variable 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | def _find_bn(module): 30 | for m in module.modules(): 31 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 32 | return m 33 | 34 | 35 | class SyncTestCase(TorchTestCase): 36 | def _syncParameters(self, bn1, bn2): 37 | bn1.reset_parameters() 38 | bn2.reset_parameters() 39 | if bn1.affine and bn2.affine: 40 | bn2.weight.data.copy_(bn1.weight.data) 41 | bn2.bias.data.copy_(bn1.bias.data) 42 | 43 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 44 | """Check the forward and backward for the customized batch normalization.""" 45 | bn1.train(mode=is_train) 46 | bn2.train(mode=is_train) 47 | 48 | if cuda: 49 | input = input.cuda() 50 | 51 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 52 | 53 | input1 = Variable(input, requires_grad=True) 54 | output1 = bn1(input1) 55 | output1.sum().backward() 56 | input2 = Variable(input, requires_grad=True) 57 | output2 = bn2(input2) 58 | output2.sum().backward() 59 | 60 | self.assertTensorClose(input1.data, input2.data) 61 | self.assertTensorClose(output1.data, output2.data) 62 | self.assertTensorClose(input1.grad, input2.grad) 63 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 64 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 65 | 66 | def testSyncBatchNormNormalTrain(self): 67 | bn = nn.BatchNorm1d(10) 68 | sync_bn = SynchronizedBatchNorm1d(10) 69 | 70 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 71 | 72 | def testSyncBatchNormNormalEval(self): 73 | bn = nn.BatchNorm1d(10) 74 | sync_bn = SynchronizedBatchNorm1d(10) 75 | 76 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 77 | 78 | def testSyncBatchNormSyncTrain(self): 79 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 80 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | 83 | bn.cuda() 84 | sync_bn.cuda() 85 | 86 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 87 | 88 | def testSyncBatchNormSyncEval(self): 89 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 90 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 92 | 93 | bn.cuda() 94 | sync_bn.cuda() 95 | 96 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 97 | 98 | def testSyncBatchNorm2DSyncTrain(self): 99 | bn = nn.BatchNorm2d(10) 100 | sync_bn = SynchronizedBatchNorm2d(10) 101 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 102 | 103 | bn.cuda() 104 | sync_bn.cuda() 105 | 106 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 107 | 108 | 109 | if __name__ == '__main__': 110 | unittest.main() 111 | -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/nn/modules/unittest.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, rtol=rtol, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import collections 4 | 5 | import torch 6 | import torch.cuda as cuda 7 | import torch.nn as nn 8 | from torch.nn.parallel._functions import Gather 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | 13 | def async_copy_to(obj, dev, main_stream=None): 14 | if torch.is_tensor(obj): 15 | v = obj.cuda(dev, non_blocking=True) 16 | if main_stream is not None: 17 | v.data.record_stream(main_stream) 18 | return v 19 | elif isinstance(obj, collections.Mapping): 20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 21 | elif isinstance(obj, collections.Sequence): 22 | return [async_copy_to(o, dev, main_stream) for o in obj] 23 | else: 24 | return obj 25 | 26 | 27 | def dict_gather(outputs, target_device, dim=0): 28 | """ 29 | Gathers variables from different GPUs on a specified device 30 | (-1 means the CPU), with dictionary support.
31 | """ 32 | def gather_map(outputs): 33 | out = outputs[0] 34 | if torch.is_tensor(out): 35 | # MJY(20180330) HACK:: force nr_dims > 0 36 | if out.dim() == 0: 37 | outputs = [o.unsqueeze(0) for o in outputs] 38 | return Gather.apply(target_device, dim, *outputs) 39 | elif out is None: 40 | return None 41 | elif isinstance(out, collections.Mapping): 42 | return {k: gather_map([o[k] for o in outputs]) for k in out} 43 | elif isinstance(out, collections.Sequence): 44 | return type(out)(map(gather_map, zip(*outputs))) 45 | return gather_map(outputs) 46 | 47 | 48 | class DictGatherDataParallel(nn.DataParallel): 49 | def gather(self, outputs, output_device): 50 | return dict_gather(outputs, output_device, dim=self.dim) 51 | 52 | 53 | class UserScatteredDataParallel(DictGatherDataParallel): 54 | def scatter(self, inputs, kwargs, device_ids): 55 | assert len(inputs) == 1 56 | inputs = inputs[0] 57 | inputs = _async_copy_stream(inputs, device_ids) 58 | inputs = [[i] for i in inputs] 59 | assert len(kwargs) == 0 60 | kwargs = [{} for _ in range(len(inputs))] 61 | 62 | return inputs, kwargs 63 | 64 | 65 | def user_scattered_collate(batch): 66 | return batch 67 | 68 | 69 | def _async_copy(inputs, device_ids): 70 | nr_devs = len(device_ids) 71 | assert type(inputs) in (tuple, list) 72 | assert len(inputs) == nr_devs 73 | 74 | outputs = [] 75 | for i, dev in zip(inputs, device_ids): 76 | with cuda.device(dev): 77 | outputs.append(async_copy_to(i, dev)) 78 | 79 | return tuple(outputs) 80 | 81 | 82 | def _async_copy_stream(inputs, device_ids): 83 | nr_devs = len(device_ids) 84 | assert type(inputs) in (tuple, list) 85 | assert len(inputs) == nr_devs 86 | 87 | outputs = [] 88 | streams = [_get_stream(d) for d in device_ids] 89 | for i, dev, stream in zip(inputs, device_ids, streams): 90 | with cuda.device(dev): 91 | main_stream = cuda.current_stream() 92 | with cuda.stream(stream): 93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 94 | main_stream.wait_stream(stream) 95 | 96 | return outputs 97 | 98 | 99 | """Adapted from: torch/nn/parallel/_functions.py""" 100 | # background streams used for copying 101 | _streams = None 102 | 103 | 104 | def _get_stream(device): 105 | """Gets a background stream for copying between CPU and GPU""" 106 | global _streams 107 | if device == -1: 108 | return None 109 | if _streams is None: 110 | _streams = [None] * cuda.device_count() 111 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 112 | return _streams[device] 113 | -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .th import * 2 | -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dataloader import DataLoader 3 | from .dataset import Dataset, TensorDataset, ConcatDataset 4 | -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | from torch import randperm 5 | from torch._utils import _accumulate 6 | 7 | 8 | class Dataset(object): 9 | """An abstract class representing a Dataset. 10 | 11 | All other datasets should subclass it. 
All subclasses should override 12 | ``__len__``, which provides the size of the dataset, and ``__getitem__``, 13 | supporting integer indexing in range from 0 to len(self) exclusive. 14 | """ 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def __add__(self, other): 23 | return ConcatDataset([self, other]) 24 | 25 | 26 | class TensorDataset(Dataset): 27 | """Dataset wrapping data and target tensors. 28 | 29 | Each sample will be retrieved by indexing both tensors along the first 30 | dimension. 31 | 32 | Arguments: 33 | data_tensor (Tensor): contains sample data. 34 | target_tensor (Tensor): contains sample targets (labels). 35 | """ 36 | 37 | def __init__(self, data_tensor, target_tensor): 38 | assert data_tensor.size(0) == target_tensor.size(0) 39 | self.data_tensor = data_tensor 40 | self.target_tensor = target_tensor 41 | 42 | def __getitem__(self, index): 43 | return self.data_tensor[index], self.target_tensor[index] 44 | 45 | def __len__(self): 46 | return self.data_tensor.size(0) 47 | 48 | 49 | class ConcatDataset(Dataset): 50 | """ 51 | Dataset to concatenate multiple datasets. 52 | Purpose: useful to assemble different existing datasets, possibly 53 | large-scale datasets as the concatenation operation is done in an 54 | on-the-fly manner. 55 | 56 | Arguments: 57 | datasets (iterable): List of datasets to be concatenated 58 | """ 59 | 60 | @staticmethod 61 | def cumsum(sequence): 62 | r, s = [], 0 63 | for e in sequence: 64 | l = len(e) 65 | r.append(l + s) 66 | s += l 67 | return r 68 | 69 | def __init__(self, datasets): 70 | super(ConcatDataset, self).__init__() 71 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 72 | self.datasets = list(datasets) 73 | self.cumulative_sizes = self.cumsum(self.datasets) 74 | 75 | def __len__(self): 76 | return self.cumulative_sizes[-1] 77 | 78 | def __getitem__(self, idx): 79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 80 | if dataset_idx == 0: 81 | sample_idx = idx 82 | else: 83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 84 | return self.datasets[dataset_idx][sample_idx] 85 | 86 | @property 87 | def cummulative_sizes(self): 88 | warnings.warn("cummulative_sizes attribute is renamed to " 89 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 90 | return self.cumulative_sizes 91 | 92 | 93 | class Subset(Dataset): 94 | def __init__(self, dataset, indices): 95 | self.dataset = dataset 96 | self.indices = indices 97 | 98 | def __getitem__(self, idx): 99 | return self.dataset[self.indices[idx]] 100 | 101 | def __len__(self): 102 | return len(self.indices) 103 | 104 | 105 | def random_split(dataset, lengths): 106 | """ 107 | Randomly split a dataset into non-overlapping new datasets of given lengths. 108 | 109 | 110 | Arguments: 111 | dataset (Dataset): Dataset to be split 112 | lengths (iterable): lengths of splits to be produced 113 | """ 114 | if sum(lengths) != len(dataset): 115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 116 | 117 | indices = randperm(sum(lengths)) 118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] 119 | -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/utils/data/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from
torch.distributed import get_world_size, get_rank 5 | 6 | from .sampler import Sampler 7 | 8 | 9 | class DistributedSampler(Sampler): 10 | """Sampler that restricts data loading to a subset of the dataset. 11 | 12 | It is especially useful in conjunction with 13 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 14 | process can pass a DistributedSampler instance as a DataLoader sampler, 15 | and load a subset of the original dataset that is exclusive to it. 16 | 17 | .. note:: 18 | Dataset is assumed to be of constant size. 19 | 20 | Arguments: 21 | dataset: Dataset used for sampling. 22 | num_replicas (optional): Number of processes participating in 23 | distributed training. 24 | rank (optional): Rank of the current process within num_replicas. 25 | """ 26 | 27 | def __init__(self, dataset, num_replicas=None, rank=None): 28 | if num_replicas is None: 29 | num_replicas = get_world_size() 30 | if rank is None: 31 | rank = get_rank() 32 | self.dataset = dataset 33 | self.num_replicas = num_replicas 34 | self.rank = rank 35 | self.epoch = 0 36 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 37 | self.total_size = self.num_samples * self.num_replicas 38 | 39 | def __iter__(self): 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | indices = list(torch.randperm(len(self.dataset), generator=g)) 44 | 45 | # add extra samples to make it evenly divisible 46 | indices += indices[:(self.total_size - len(indices))] 47 | assert len(indices) == self.total_size 48 | 49 | # subsample 50 | offset = self.num_samples * self.rank 51 | indices = indices[offset:offset + self.num_samples] 52 | assert len(indices) == self.num_samples 53 | 54 | return iter(indices) 55 | 56 | def __len__(self): 57 | return self.num_samples 58 | 59 | def set_epoch(self, epoch): 60 | self.epoch = epoch 61 | -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler(object): 5 | """Base class for all Samplers. 6 | 7 | Every Sampler subclass has to provide an __iter__ method, providing a way 8 | to iterate over indices of dataset elements, and a __len__ method that 9 | returns the length of the returned iterators. 10 | """ 11 | 12 | def __init__(self, data_source): 13 | pass 14 | 15 | def __iter__(self): 16 | raise NotImplementedError 17 | 18 | def __len__(self): 19 | raise NotImplementedError 20 | 21 | 22 | class SequentialSampler(Sampler): 23 | """Samples elements sequentially, always in the same order. 24 | 25 | Arguments: 26 | data_source (Dataset): dataset to sample from 27 | """ 28 | 29 | def __init__(self, data_source): 30 | self.data_source = data_source 31 | 32 | def __iter__(self): 33 | return iter(range(len(self.data_source))) 34 | 35 | def __len__(self): 36 | return len(self.data_source) 37 | 38 | 39 | class RandomSampler(Sampler): 40 | """Samples elements randomly, without replacement. 
41 | 42 | Arguments: 43 | data_source (Dataset): dataset to sample from 44 | """ 45 | 46 | def __init__(self, data_source): 47 | self.data_source = data_source 48 | 49 | def __iter__(self): 50 | return iter(torch.randperm(len(self.data_source)).long()) 51 | 52 | def __len__(self): 53 | return len(self.data_source) 54 | 55 | 56 | class SubsetRandomSampler(Sampler): 57 | """Samples elements randomly from a given list of indices, without replacement. 58 | 59 | Arguments: 60 | indices (list): a list of indices 61 | """ 62 | 63 | def __init__(self, indices): 64 | self.indices = indices 65 | 66 | def __iter__(self): 67 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 68 | 69 | def __len__(self): 70 | return len(self.indices) 71 | 72 | 73 | class WeightedRandomSampler(Sampler): 74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 75 | 76 | Arguments: 77 | weights (list): a list of weights, not necessarily summing to one 78 | num_samples (int): number of samples to draw 79 | replacement (bool): if ``True``, samples are drawn with replacement. 80 | If not, they are drawn without replacement, which means that when a 81 | sample index is drawn for a row, it cannot be drawn again for that row. 82 | """ 83 | 84 | def __init__(self, weights, num_samples, replacement=True): 85 | self.weights = torch.DoubleTensor(weights) 86 | self.num_samples = num_samples 87 | self.replacement = replacement 88 | 89 | def __iter__(self): 90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 91 | 92 | def __len__(self): 93 | return self.num_samples 94 | 95 | 96 | class BatchSampler(object): 97 | """Wraps another sampler to yield a mini-batch of indices. 98 | 99 | Args: 100 | sampler (Sampler): Base sampler. 101 | batch_size (int): Size of mini-batch.
102 | drop_last (bool): If ``True``, the sampler will drop the last batch if 103 | its size would be less than ``batch_size`` 104 | 105 | Example: 106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) 107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) 109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 110 | """ 111 | 112 | def __init__(self, sampler, batch_size, drop_last): 113 | self.sampler = sampler 114 | self.batch_size = batch_size 115 | self.drop_last = drop_last 116 | 117 | def __iter__(self): 118 | batch = [] 119 | for idx in self.sampler: 120 | batch.append(idx) 121 | if len(batch) == self.batch_size: 122 | yield batch 123 | batch = [] 124 | if len(batch) > 0 and not self.drop_last: 125 | yield batch 126 | 127 | def __len__(self): 128 | if self.drop_last: 129 | return len(self.sampler) // self.batch_size 130 | else: 131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 132 | -------------------------------------------------------------------------------- /networks/ade20k/segm_lib/utils/th.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy as np 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile'] 8 | 9 | def as_variable(obj): 10 | if isinstance(obj, Variable): 11 | return obj 12 | if isinstance(obj, collections.Sequence): 13 | return [as_variable(v) for v in obj] 14 | elif isinstance(obj, collections.Mapping): 15 | return {k: as_variable(v) for k, v in obj.items()} 16 | else: 17 | return Variable(obj) 18 | 19 | def as_numpy(obj): 20 | if isinstance(obj, collections.Sequence): 21 | return [as_numpy(v) for v in obj] 22 | elif isinstance(obj, collections.Mapping): 23 | return {k: as_numpy(v) for k, v in obj.items()} 24 | elif isinstance(obj, Variable): 25 | return obj.data.cpu().numpy() 26 | elif torch.is_tensor(obj): 27 | return obj.cpu().numpy() 28 | else: 29 | return np.array(obj) 30 | 31 | def mark_volatile(obj): 32 | if torch.is_tensor(obj): 33 | obj = Variable(obj) 34 | if isinstance(obj, Variable): 35 | obj.no_grad = True 36 | return obj 37 | elif isinstance(obj, collections.Mapping): 38 | return {k: mark_volatile(o) for k, o in obj.items()} 39 | elif isinstance(obj, collections.Sequence): 40 | return [mark_volatile(o) for o in obj] 41 | else: 42 | return obj 43 | -------------------------------------------------------------------------------- /networks/ade20k/utils.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" 2 | 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | 9 | try: 10 | from urllib import urlretrieve 11 | except ImportError: 12 | from urllib.request import urlretrieve 13 | 14 | 15 | def load_url(url, model_dir='./pretrained', map_location=None): 16 | if not os.path.exists(model_dir): 17 | os.makedirs(model_dir) 18 | filename = url.split('/')[-1] 19 | cached_file = os.path.join(model_dir, filename) 20 | if not os.path.exists(cached_file): 21 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 22 | urlretrieve(url, cached_file) 23 | return torch.load(cached_file, map_location=map_location) 24 | 25 | 26 | def color_encode(labelmap, colors, mode='RGB'): 27 | labelmap = labelmap.astype('int') 28 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), 29 | 
dtype=np.uint8) 30 | for label in np.unique(labelmap): 31 | if label < 0: 32 | continue 33 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ 34 | np.tile(colors[label], 35 | (labelmap.shape[0], labelmap.shape[1], 1)) 36 | 37 | if mode == 'BGR': 38 | return labelmap_rgb[:, :, ::-1] 39 | else: 40 | return labelmap_rgb 41 | -------------------------------------------------------------------------------- /networks/discriminators.py: -------------------------------------------------------------------------------- 1 | from networks.basic_module import * 2 | 3 | 4 | class NLayerDiscriminator(nn.Module): 5 | def __init__(self, config): 6 | super().__init__() 7 | kw = 4 8 | padw = int(np.ceil((kw - 1.0) / 2)) 9 | input_nc = config['input_nc'] 10 | 11 | self.conv1 = nn.Sequential( 12 | nn.Conv2d(in_channels=input_nc, out_channels=64, kernel_size=kw, stride=2, padding=padw), 13 | nn.LeakyReLU(0.2, inplace=True), 14 | ) 15 | 16 | self.act = nn.LeakyReLU(0.2, inplace=True) 17 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=kw, stride=2, padding=padw) 18 | self.bn2 = nn.BatchNorm2d(128) 19 | 20 | self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=kw, stride=2, padding=padw) 21 | self.bn3 = nn.BatchNorm2d(256) 22 | 23 | self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=kw, stride=2, padding=padw) 24 | self.bn4 = nn.BatchNorm2d(512) 25 | 26 | self.conv5 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=kw, stride=1, padding=padw) 27 | self.bn5 = nn.BatchNorm2d(512) 28 | 29 | self.conv6 = nn.Conv2d(512, 1, kernel_size=kw, stride=1, padding=padw) 30 | 31 | def forward(self, x): 32 | conv1 = self.conv1(x) 33 | 34 | conv2 = self.conv2(conv1) 35 | conv2 = self.bn2(conv2.to(torch.float32)) 36 | conv2 = self.act(conv2) 37 | 38 | conv3 = self.conv3(conv2) 39 | conv3 = self.bn3(conv3.to(torch.float32)) 40 | conv3 = self.act(conv3) 41 | 42 | conv4 = self.conv4(conv3) 43 | conv4 = self.bn4(conv4.to(torch.float32)) 44 | conv4 = self.act(conv4) 45 | 46 | conv5 = self.conv5(conv4) 47 | conv5 = self.bn5(conv5.to(torch.float32)) 48 | conv5 = self.act(conv5) 49 | 50 | conv6 = self.conv6(conv5) 51 | 52 | outputs = conv6 53 | 54 | return outputs, [conv1, conv2, conv3, conv4, conv5] 55 | 56 | 57 | class NLayerDiscriminatorSingleLogits(nn.Module): 58 | def __init__(self, config): 59 | super().__init__() 60 | kw = 4 61 | padw = int(np.ceil((kw - 1.0) / 2)) 62 | input_nc = config['input_nc'] 63 | 64 | self.conv1 = nn.Sequential( 65 | nn.Conv2d(in_channels=input_nc, out_channels=64, kernel_size=kw, stride=2, padding=padw), 66 | nn.LeakyReLU(0.2, inplace=True), 67 | ) 68 | 69 | self.act = nn.LeakyReLU(0.2, inplace=True) 70 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=kw, stride=2, padding=padw) 71 | self.bn2 = nn.BatchNorm2d(128) 72 | 73 | self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=kw, stride=2, padding=padw) 74 | self.bn3 = nn.BatchNorm2d(256) 75 | 76 | self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=kw, stride=2, padding=padw) 77 | self.bn4 = nn.BatchNorm2d(512) 78 | 79 | self.conv5 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=kw, stride=2, padding=padw) 80 | self.bn5 = nn.BatchNorm2d(512) 81 | 82 | self.conv6 = nn.Conv2d(512, 1, kernel_size=9, stride=1, padding=0) 83 | 84 | def forward(self, x): 85 | conv1 = self.conv1(x) 86 | 87 | conv2 = self.conv2(conv1) 88 | conv2 = self.bn2(conv2.to(torch.float32)) 89 | conv2 = self.act(conv2) 90 | 91 | conv3 = 
self.conv3(conv2) 92 | conv3 = self.bn3(conv3.to(torch.float32)) 93 | conv3 = self.act(conv3) 94 | 95 | conv4 = self.conv4(conv3) 96 | conv4 = self.bn4(conv4.to(torch.float32)) 97 | conv4 = self.act(conv4) 98 | 99 | conv5 = self.conv5(conv4) 100 | conv5 = self.bn5(conv5.to(torch.float32)) 101 | conv5 = self.act(conv5) 102 | 103 | conv6 = self.conv6(conv5) 104 | 105 | outputs = conv6 106 | 107 | return outputs, [conv1, conv2, conv3, conv4, conv5] 108 | 109 | 110 | class StyleDiscriminator(torch.nn.Module): 111 | def __init__(self, config): 112 | super().__init__() 113 | ch = config['ch'] 114 | input_nc = config['input_nc'] 115 | activation = config['act'] 116 | 117 | self.conv0 = DisFromRGB(input_nc, ch, activation) # 256 118 | self.conv1 = DisBlock(ch, ch * 2, activation) # 128 119 | self.conv2 = DisBlock(ch * 2, ch * 4, activation) # 64 120 | self.conv3 = DisBlock(ch * 4, ch * 8, activation) # 32 121 | self.conv4 = DisBlock(ch * 8, ch * 8, activation) # 16 122 | self.conv5 = Conv2dLayer(ch * 8, ch * 8, kernel_size=3, activation=activation) # 16 123 | self.conv6 = nn.Conv2d(ch * 8, 1, kernel_size=3, stride=1, padding=1) 124 | 125 | def forward(self, x): 126 | conv1 = self.conv1(self.conv0(x)) 127 | conv2 = self.conv2(conv1) 128 | conv3 = self.conv3(conv2) 129 | conv4 = self.conv4(conv3) 130 | conv5 = self.conv5(conv4) 131 | outputs = self.conv6(conv5) 132 | 133 | return outputs, [conv1, conv2, conv3, conv4, conv5] 134 | -------------------------------------------------------------------------------- /networks/inception.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | 5 | 6 | class InceptionV3(nn.Module): 7 | """Pretrained InceptionV3 network returning feature maps""" 8 | 9 | # Index of default block of inception to return, 10 | # corresponds to output of final average pooling 11 | DEFAULT_BLOCK_INDEX = 3 12 | 13 | # Maps feature dimensionality to their output blocks indices 14 | BLOCK_INDEX_BY_DIM = { 15 | 64: 0, # First max pooling features 16 | 192: 1, # Second max pooling features 17 | 768: 2, # Pre-aux classifier features 18 | 2048: 3 # Final average pooling features 19 | } 20 | 21 | def __init__(self, 22 | output_blocks=[DEFAULT_BLOCK_INDEX], 23 | resize_input=True, 24 | normalize_input=True, 25 | requires_grad=False): 26 | """Build pretrained InceptionV3 27 | Parameters 28 | ---------- 29 | output_blocks : list of int 30 | Indices of blocks to return features of. Possible values are: 31 | - 0: corresponds to output of first max pooling 32 | - 1: corresponds to output of second max pooling 33 | - 2: corresponds to output which is fed to aux classifier 34 | - 3: corresponds to output of final average pooling 35 | resize_input : bool 36 | If true, bilinearly resizes input to width and height 299 before 37 | feeding input to model. As the network without fully connected 38 | layers is fully convolutional, it should be able to handle inputs 39 | of arbitrary size, so resizing might not be strictly needed 40 | normalize_input : bool 41 | If true, normalizes the input to the statistics the pretrained 42 | Inception network expects 43 | requires_grad : bool 44 | If true, parameters of the model require gradient.
Possibly useful 45 | for finetuning the network 46 | """ 47 | super(InceptionV3, self).__init__() 48 | 49 | self.resize_input = resize_input 50 | self.normalize_input = normalize_input 51 | self.output_blocks = sorted(output_blocks) 52 | self.last_needed_block = max(output_blocks) 53 | 54 | assert self.last_needed_block <= 3, \ 55 | 'Last possible output block index is 3' 56 | 57 | self.blocks = nn.ModuleList() 58 | 59 | inception = models.inception_v3(pretrained=True) 60 | 61 | # Block 0: input to maxpool1 62 | block0 = [ 63 | inception.Conv2d_1a_3x3, 64 | inception.Conv2d_2a_3x3, 65 | inception.Conv2d_2b_3x3, 66 | nn.MaxPool2d(kernel_size=3, stride=2) 67 | ] 68 | self.blocks.append(nn.Sequential(*block0)) 69 | 70 | # Block 1: maxpool1 to maxpool2 71 | if self.last_needed_block >= 1: 72 | block1 = [ 73 | inception.Conv2d_3b_1x1, 74 | inception.Conv2d_4a_3x3, 75 | nn.MaxPool2d(kernel_size=3, stride=2) 76 | ] 77 | self.blocks.append(nn.Sequential(*block1)) 78 | 79 | # Block 2: maxpool2 to aux classifier 80 | if self.last_needed_block >= 2: 81 | block2 = [ 82 | inception.Mixed_5b, 83 | inception.Mixed_5c, 84 | inception.Mixed_5d, 85 | inception.Mixed_6a, 86 | inception.Mixed_6b, 87 | inception.Mixed_6c, 88 | inception.Mixed_6d, 89 | inception.Mixed_6e, 90 | ] 91 | self.blocks.append(nn.Sequential(*block2)) 92 | 93 | # Block 3: aux classifier to final avgpool 94 | if self.last_needed_block >= 3: 95 | block3 = [ 96 | inception.Mixed_7a, 97 | inception.Mixed_7b, 98 | inception.Mixed_7c, 99 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 100 | ] 101 | self.blocks.append(nn.Sequential(*block3)) 102 | 103 | for param in self.parameters(): 104 | param.requires_grad = requires_grad 105 | 106 | def forward(self, inp): 107 | """Get Inception feature maps 108 | Parameters 109 | ---------- 110 | inp : torch.autograd.Variable 111 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 112 | range (0, 1) 113 | Returns 114 | ------- 115 | List of torch.autograd.Variable, corresponding to the selected output 116 | block, sorted ascending by index 117 | """ 118 | outp = [] 119 | x = inp 120 | 121 | if self.resize_input: 122 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False) 123 | 124 | if self.normalize_input: 125 | x = x.clone() 126 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 127 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 128 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 129 | 130 | for idx, block in enumerate(self.blocks): 131 | x = block(x) 132 | if idx in self.output_blocks: 133 | outp.append(x) 134 | 135 | if idx == self.last_needed_block: 136 | break 137 | 138 | return outp 139 | -------------------------------------------------------------------------------- /networks/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def interpolate_mask(mask, shape, allow_scale_mask=False, mask_scale_mode='nearest'): 6 | assert mask is not None 7 | assert allow_scale_mask or shape == mask.shape[-2:] 8 | if shape != mask.shape[-2:] and allow_scale_mask: 9 | if mask_scale_mode == 'maxpool': 10 | mask = F.adaptive_max_pool2d(mask, shape) 11 | else: 12 | mask = F.interpolate(mask, size=shape, mode=mask_scale_mode) 13 | return mask 14 | 15 | 16 | def generator_loss(discr_fake_pred: torch.Tensor, mask=None, args=None): 17 | fake_loss = F.softplus(-discr_fake_pred) 18 | # == if masked region should be treated differently 19 | if (args['mask_as_fake_target'] and args['extra_mask_weight_for_gen'] > 0) or not args['use_unmasked_for_gen']: 20 | mask = interpolate_mask(mask, discr_fake_pred.shape[-2:], args['allow_scale_mask'], args['mask_scale_mode']) 21 | if not args['use_unmasked_for_gen']: 22 | fake_loss = fake_loss * mask 23 | else: 24 | pixel_weights = 1 + mask * args['extra_mask_weight_for_gen'] 25 | fake_loss = fake_loss * pixel_weights 26 | 27 | return fake_loss.mean() * args['weight'] 28 | 29 | 30 | def feature_matching_loss(fake_features, target_features, mask=None): 31 | if mask is None: 32 | res = torch.stack([F.mse_loss(fake_feat, target_feat) 33 | for fake_feat, target_feat in zip(fake_features, target_features)]).mean() 34 | else: 35 | res = 0 36 | norm = 0 37 | for fake_feat, target_feat in zip(fake_features, target_features): 38 | cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False) 39 | error_weights = 1 - cur_mask 40 | cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() 41 | res = res + cur_val 42 | norm += 1 43 | res = res / norm 44 | return res 45 | 46 | 47 | def make_r1_gp(discr_real_pred, real_batch): 48 | if torch.is_grad_enabled(): 49 | grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0] 50 | grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean() 51 | else: 52 | grad_penalty = 0 53 | real_batch.requires_grad = False 54 | 55 | return grad_penalty 56 | 57 | 58 | def discriminator_real_loss(real_batch, discr_real_pred, gp_coef, do_GP=True): 59 | real_loss = F.softplus(-discr_real_pred).mean() 60 | if do_GP: 61 | grad_penalty = (make_r1_gp(discr_real_pred, real_batch) * gp_coef).mean() 62 | # grad_penalty = torch.tensor(0.0) 63 | else: 64 | grad_penalty = 0 65 | 66 | return real_loss, grad_penalty 67 | 68 | def 
discriminator_fake_loss(discr_fake_pred: torch.Tensor, mask=None, args=None): 69 | 70 | fake_loss = F.softplus(discr_fake_pred) 71 | 72 | if not args['use_unmasked_for_discr'] or args['mask_as_fake_target']: 73 | # == if masked region should be treated differently 74 | mask = interpolate_mask(mask, discr_fake_pred.shape[-2:], args['allow_scale_mask'], args['mask_scale_mode']) 75 | # use_unmasked_for_discr=False only makes sense for fakes; 76 | # for reals there is no difference between the two regions 77 | fake_loss = fake_loss * mask 78 | if args['mask_as_fake_target']: 79 | fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred) 80 | 81 | sum_discr_loss = fake_loss 82 | return sum_discr_loss.mean() 83 | -------------------------------------------------------------------------------- /networks/pcp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from networks.ade20k import ModelBuilder 6 | from networks.vggNet import VGGFeatureExtractor 7 | 8 | IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] 9 | IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] 10 | 11 | 12 | class PerceptualLoss(nn.Module): 13 | """Perceptual loss with commonly used style loss. 14 | 15 | Args: 16 | layer_weights (dict): The weight for each layer of vgg feature. 17 | Here is an example: {'conv5_4': 1.}, which means the conv5_4 18 | feature layer (before relu5_4) will be extracted with weight 19 | 1.0 in calculating losses. 20 | vgg_type (str): The type of vgg network used as feature extractor. 21 | Default: 'vgg19'. 22 | use_input_norm (bool): If True, normalize the input image in vgg. 23 | Default: True. 24 | perceptual_weight (float): If `perceptual_weight > 0`, the perceptual 25 | loss will be calculated and multiplied by the 26 | weight. Default: 1.0. 27 | style_weight (float): If `style_weight > 0`, the style loss will be 28 | calculated and multiplied by the weight. 29 | Default: 0. 30 | norm_img (bool): If True, the image will be normed to [0, 1]. Note that 31 | this is different from `use_input_norm`, which norms the input 32 | in the forward function of vgg according to the statistics of the dataset. 33 | Importantly, the input image must be in range [-1, 1]. 34 | Default: False. 35 | criterion (str): Criterion used for perceptual loss. Default: 'l1'. 36 | """ 37 | 38 | def __init__(self, 39 | layer_weights, 40 | vgg_type='vgg19', 41 | use_input_norm=True, 42 | use_pcp_loss=True, 43 | use_style_loss=False, 44 | norm_img=True, 45 | criterion='l1'): 46 | super(PerceptualLoss, self).__init__() 47 | self.norm_img = norm_img 48 | self.use_pcp_loss = use_pcp_loss 49 | self.use_style_loss = use_style_loss 50 | self.layer_weights = layer_weights 51 | self.vgg = VGGFeatureExtractor( 52 | layer_name_list=list(layer_weights.keys()), 53 | vgg_type=vgg_type, 54 | use_input_norm=use_input_norm) 55 | 56 | self.criterion_type = criterion 57 | if self.criterion_type == 'l1': 58 | self.criterion = torch.nn.L1Loss() 59 | elif self.criterion_type == 'l2': 60 | self.criterion = torch.nn.MSELoss() # mean squared error (L2) 61 | elif self.criterion_type == 'fro': 62 | self.criterion = None 63 | else: 64 | raise NotImplementedError('%s criterion is not supported.' % self.criterion_type) 65 | 66 | def forward(self, x, gt): 67 | """Forward function. 68 | 69 | Args: 70 | x (Tensor): Input tensor with shape (n, c, h, w).
71 | gt (Tensor): Ground-truth tensor with shape (n, c, h, w). 72 | 73 | Returns: 74 | Tensor: Forward results. 75 | """ 76 | 77 | if self.norm_img: 78 | x = torch.clamp(x, -1, 1) 79 | x = (x + 1.) * 0.5 80 | gt = (gt + 1.) * 0.5 81 | 82 | # extract vgg features 83 | x_features = self.vgg(x) 84 | gt_features = self.vgg(gt.detach()) 85 | 86 | # calculate perceptual loss 87 | if self.use_pcp_loss: 88 | percep_loss = 0 89 | for k in x_features.keys(): 90 | if self.criterion_type == 'fro': 91 | percep_loss += torch.norm( 92 | x_features[k] - gt_features[k], 93 | p='fro') * self.layer_weights[k] 94 | else: 95 | percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] 96 | else: 97 | percep_loss = None 98 | 99 | # calculate style loss 100 | if self.use_style_loss: 101 | style_loss = 0 102 | for k in x_features.keys(): 103 | if self.criterion_type == 'fro': 104 | style_loss += torch.norm( 105 | self._gram_mat(x_features[k]) - 106 | self._gram_mat(gt_features[k]), 107 | p='fro') * self.layer_weights[k] 108 | else: 109 | style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) \ 110 | * self.layer_weights[k] 111 | else: 112 | style_loss = None 113 | 114 | return percep_loss, style_loss 115 | 116 | def _gram_mat(self, x): 117 | """Calculate Gram matrix. 118 | 119 | Args: 120 | x (torch.Tensor): Tensor with shape of (n, c, h, w). 121 | 122 | Returns: 123 | torch.Tensor: Gram matrix. 124 | """ 125 | n, c, h, w = x.size() 126 | features = x.view(n, c, w * h) 127 | features_t = features.transpose(1, 2) 128 | gram = features.bmm(features_t) / (c * h * w) 129 | return gram 130 | 131 | 132 | class ResNetPL(nn.Module): 133 | def __init__(self, weight=1, weights_path=None, arch_encoder='resnet50dilated', segmentation=True): 134 | super().__init__() 135 | self.impl = ModelBuilder.get_encoder(weights_path=weights_path, 136 | arch_encoder=arch_encoder, 137 | arch_decoder='ppm_deepsup', 138 | fc_dim=2048, 139 | segmentation=segmentation) 140 | self.impl.eval() 141 | for w in self.impl.parameters(): 142 | w.requires_grad_(False) 143 | 144 | self.weight = weight 145 | 146 | def forward(self, pred, target): 147 | # -1~1 to 0~1, then norm 148 | pred, target = torch.clamp(pred, -1, 1), torch.clamp(target, -1, 1) 149 | pred = (pred + 1) / 2 150 | target = (target + 1) / 2 151 | pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred) 152 | target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target) 153 | 154 | pred_feats = self.impl(pred, return_feature_maps=True) 155 | target_feats = self.impl(target, return_feature_maps=True) 156 | 157 | result = torch.stack([F.mse_loss(cur_pred, cur_target) 158 | for cur_pred, cur_target 159 | in zip(pred_feats, target_feats)]).sum() * self.weight 160 | return result 161 | -------------------------------------------------------------------------------- /networks/upsample.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class StructureUpsampling4(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | self.convs = nn.Sequential(nn.ReflectionPad2d(3), 10 | nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=0), 11 | nn.ReLU(inplace=True), 12 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)) 17 | self.out = 
nn.Sequential(nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)) 20 | 21 | def forward(self, line): 22 | x = line 23 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 24 | x = self.convs(x) 25 | x2 = self.out(x) 26 | 27 | return x, x2 28 | -------------------------------------------------------------------------------- /networks/van.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from timm.models.layers import DropPath, trunc_normal_ 6 | 7 | 8 | class Mlp(nn.Module): 9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 10 | super().__init__() 11 | out_features = out_features or in_features 12 | hidden_features = hidden_features or in_features 13 | self.fc1 = nn.Conv2d(in_features, hidden_features, 3, 1, 1) 14 | self.act = act_layer() 15 | self.fc2 = nn.Conv2d(hidden_features, out_features, 3, 1, 1) 16 | self.drop = nn.Dropout(drop) 17 | self.apply(self._init_weights) 18 | 19 | def _init_weights(self, m): 20 | if isinstance(m, nn.Linear): 21 | trunc_normal_(m.weight, std=.02) 22 | if isinstance(m, nn.Linear) and m.bias is not None: 23 | nn.init.constant_(m.bias, 0) 24 | elif isinstance(m, nn.LayerNorm): 25 | nn.init.constant_(m.bias, 0) 26 | nn.init.constant_(m.weight, 1.0) 27 | elif isinstance(m, nn.Conv2d): 28 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 29 | fan_out //= m.groups 30 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 31 | if m.bias is not None: 32 | m.bias.data.zero_() 33 | 34 | def forward(self, x): 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | 42 | 43 | class LKA(nn.Module): 44 | def __init__(self, dim, kernel_size=21, dilation=3): 45 | super().__init__() 46 | self.conv0 = nn.Conv2d(dim, dim, 2*dilation-1, padding=dilation-1, groups=dim) 47 | self.conv_spatial = nn.Conv2d(dim, dim, math.ceil(kernel_size / dilation), stride=1, 48 | padding=math.ceil((kernel_size - dilation - 1) / 2), groups=dim, 49 | dilation=dilation) 50 | self.conv1 = nn.Conv2d(dim, dim, 1) 51 | 52 | def forward(self, x): 53 | u = x.clone() 54 | attn = self.conv0(x) 55 | attn = self.conv_spatial(attn) 56 | attn = self.conv1(attn) 57 | 58 | return u * attn 59 | 60 | 61 | class Attention(nn.Module): 62 | def __init__(self, d_model, kernel_size=21, dilation=3, act_layer=nn.GELU): 63 | super().__init__() 64 | 65 | self.proj_1 = nn.Conv2d(d_model, d_model, 1) 66 | self.activation = act_layer() 67 | self.spatial_gating_unit = LKA(d_model, kernel_size=kernel_size, dilation=dilation) 68 | self.proj_2 = nn.Conv2d(d_model, d_model, 1) 69 | 70 | def forward(self, x): 71 | shortcut = x.clone() 72 | x = self.proj_1(x) 73 | x = self.activation(x) 74 | x = self.spatial_gating_unit(x) 75 | x = self.proj_2(x) 76 | x = x + shortcut 77 | return x 78 | 79 | 80 | class VANBlock(nn.Module): 81 | def __init__(self, dim, mlp_ratio=1., drop=0., drop_path=0., act_layer=nn.GELU, kernel_size=21, dilation=3): 82 | super().__init__() 83 | self.norm1 = nn.BatchNorm2d(dim) 84 | self.attn = Attention(dim, kernel_size=kernel_size, dilation=dilation, act_layer=act_layer) 85 | self.drop_path = DropPath(drop_path) if drop_path > 0.
else nn.Identity() 86 | 87 | self.norm2 = nn.BatchNorm2d(dim) 88 | mlp_hidden_dim = int(dim * mlp_ratio) 89 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 90 | layer_scale_init_value = 1e-2 91 | self.layer_scale_1 = nn.Parameter( 92 | layer_scale_init_value * torch.ones((dim)), requires_grad=True) 93 | self.layer_scale_2 = nn.Parameter( 94 | layer_scale_init_value * torch.ones((dim)), requires_grad=True) 95 | 96 | self.apply(self._init_weights) 97 | 98 | def _init_weights(self, m): 99 | if isinstance(m, nn.Linear): 100 | trunc_normal_(m.weight, std=.02) 101 | if isinstance(m, nn.Linear) and m.bias is not None: 102 | nn.init.constant_(m.bias, 0) 103 | elif isinstance(m, nn.LayerNorm): 104 | nn.init.constant_(m.bias, 0) 105 | nn.init.constant_(m.weight, 1.0) 106 | elif isinstance(m, nn.Conv2d): 107 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 108 | fan_out //= m.groups 109 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 110 | if m.bias is not None: 111 | m.bias.data.zero_() 112 | 113 | def forward(self, x): 114 | x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x))) 115 | x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x))) 116 | # x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(x)) 117 | # x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm1(x))) 118 | # x = self.norm2(x) 119 | return x 120 | -------------------------------------------------------------------------------- /nms/cxx/README.md: -------------------------------------------------------------------------------- 1 | ## NOTE 2 | I copied files from [davidstutz/extended-berkeley-segmentation-benchmark](https://github.com/davidstutz/extended-berkeley-segmentation-benchmark/tree/master/source) 3 | and wrote `src/solve.h`, `src/solve.cc`, `src/nms.cc`, `src/build.sh` 4 | 5 | 6 | ## BUILD 7 | ```shell 8 | cd src 9 | source build.sh 10 | ``` 11 | 12 | 13 | ## License 14 | **Same as [davidstutz/extended-berkeley-segmentation-benchmark](https://github.com/davidstutz/extended-berkeley-segmentation-benchmark/tree/master/source)** 15 | 16 | Licenses for source code corresponding to: 17 | 18 | D. Stutz. **Superpixel Segmentation using Depth Information.** Bachelor Thesis, RWTH Aachen University, 2014. 19 | 20 | D. Stutz. **Superpixel Segmentation: An Evaluation.** Pattern Recognition (J. Gall, P. Gehler, B. Leibe (Eds.)), Lecture Notes in Computer Science, vol. 9358, pages 555 - 562, 2015. 21 | 22 | Note that the source code is based on the following projects for which separate licenses apply: 23 | 24 | * [Berkeley Segmentation Benchmark](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html) 25 | 26 | Copyright (c) 2014-2018 David Stutz, RWTH Aachen University 27 | 28 | **Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use this software and associated documentation files (the "Software").** 29 | 30 | The authors hereby grant you a non-exclusive, non-transferable, free of charge right to copy, modify, merge, publish, distribute, and sublicense the Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects. 31 | 32 | Any other use, in particular any use for commercial purposes, is prohibited.
This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artefacts for commercial purposes. 33 | 34 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 35 | 36 | You understand and agree that the authors are under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Software. The authors nevertheless reserve the right to update, modify, or discontinue the Software at any time. 37 | 38 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. You agree to cite the corresponding papers (see above) in documents and papers that report on research using the Software. 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /nms/cxx/lib/solve_csa.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ewrfcas/ZITS-PlusPlus/de8dd48b17aedd15824842adb7bcca7535daba84/nms/cxx/lib/solve_csa.so -------------------------------------------------------------------------------- /nms/cxx/src/Exception.cc: -------------------------------------------------------------------------------- 1 | 2 | // Copyright (C) 2002 David R. Martin 3 | // 4 | // This program is free software; you can redistribute it and/or 5 | // modify it under the terms of the GNU General Public License as 6 | // published by the Free Software Foundation; either version 2 of the 7 | // License, or (at your option) any later version. 8 | // 9 | // This program is distributed in the hope that it will be useful, but 10 | // WITHOUT ANY WARRANTY; without even the implied warranty of 11 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 12 | // General Public License for more details. 13 | // 14 | // You should have received a copy of the GNU General Public License 15 | // along with this program; if not, write to the Free Software 16 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 17 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 18 | 19 | #include <stdlib.h> 20 | #include <string.h> 21 | #include "Exception.hh" 22 | 23 | Exception::Exception (const char* msg) 24 | : _msg (strdup (msg)) 25 | { 26 | } 27 | 28 | Exception::Exception (const Exception& that) 29 | : _msg (strdup (that._msg)) 30 | { 31 | } 32 | 33 | Exception::~Exception () 34 | { 35 | free (_msg); 36 | } 37 | 38 | const char* 39 | Exception::msg () const 40 | { 41 | return _msg; 42 | } 43 | 44 | -------------------------------------------------------------------------------- /nms/cxx/src/Exception.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __Exception_hh__ 3 | #define __Exception_hh__ 4 | 5 | // A simple exception class that contains an error message. 6 | 7 | // Copyright (C) 2002 David R.
Martin 8 | // 9 | // This program is free software; you can redistribute it and/or 10 | // modify it under the terms of the GNU General Public License as 11 | // published by the Free Software Foundation; either version 2 of the 12 | // License, or (at your option) any later version. 13 | // 14 | // This program is distributed in the hope that it will be useful, but 15 | // WITHOUT ANY WARRANTY; without even the implied warranty of 16 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 17 | // General Public License for more details. 18 | // 19 | // You should have received a copy of the GNU General Public License 20 | // along with this program; if not, write to the Free Software 21 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 22 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 23 | 24 | #include <iostream> 25 | 26 | class Exception 27 | { 28 | public: 29 | 30 | // Always construct exception with a message, so we can print 31 | // a useful error/log message. 32 | Exception (const char* msg); 33 | 34 | // We need to implement the copy constructor so that rethrowing 35 | // works. 36 | Exception (const Exception& that); 37 | 38 | virtual ~Exception (); 39 | 40 | // Retrieve the message that this exception carries. 41 | virtual const char* msg () const; 42 | 43 | protected: 44 | 45 | char* _msg; 46 | 47 | }; 48 | 49 | // write to output stream 50 | inline std::ostream& operator<< (std::ostream& out, const Exception& e) { 51 | out << e.msg(); 52 | return out; 53 | } 54 | 55 | #endif // __Exception_hh__ 56 | -------------------------------------------------------------------------------- /nms/cxx/src/Random.cc: -------------------------------------------------------------------------------- 1 | 2 | #include <stdlib.h> 3 | #include <stdio.h> 4 | #include <math.h> 5 | #include <sys/types.h> 6 | #include <sys/time.h> 7 | #include "Random.hh" 8 | #include "String.hh" 9 | #include "Exception.hh" 10 | 11 | // Copyright (C) 2002 David R. Martin 12 | // 13 | // This program is free software; you can redistribute it and/or 14 | // modify it under the terms of the GNU General Public License as 15 | // published by the Free Software Foundation; either version 2 of the 16 | // License, or (at your option) any later version. 17 | // 18 | // This program is distributed in the hope that it will be useful, but 19 | // WITHOUT ANY WARRANTY; without even the implied warranty of 20 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 21 | // General Public License for more details. 22 | // 23 | // You should have received a copy of the GNU General Public License 24 | // along with this program; if not, write to the Free Software 25 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 26 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html.
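The `_init` routine further down keeps only the low 48 bits of the seed and splits them little-endian into the three 16-bit words that `erand48()` consumes, which is what lets `reset()`/`reseed()` replay a stream exactly. A quick illustration of that bit layout (an editor's sketch, not part of the original source):

```python
seed = 0x123456789ABC  # any 48-bit seed value
xsubi = [(seed >> s) & 0xFFFF for s in (0, 16, 32)]
assert xsubi == [0x9ABC, 0x5678, 0x1234]  # _xsubi[0], _xsubi[1], _xsubi[2]
```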
27 | 28 | Random Random::rand; 29 | 30 | Random::Random () 31 | { 32 | reseed (0); 33 | } 34 | 35 | Random::Random (u_int64_t seed) 36 | { 37 | reseed (seed); 38 | } 39 | 40 | Random::Random (Random& that) 41 | { 42 | u_int64_t a = that.ui32 (); 43 | u_int64_t b = that.ui32 (); 44 | u_int64_t seed = (a << 32) | b; 45 | _init (seed); 46 | } 47 | 48 | void 49 | Random::reset () 50 | { 51 | _init (_seed); 52 | } 53 | 54 | void 55 | Random::reseed (u_int64_t seed) 56 | { 57 | if (seed == 0) { 58 | struct timeval t; 59 | gettimeofday (&t, NULL); 60 | u_int64_t a = (t.tv_usec >> 3) & 0xffff; 61 | u_int64_t b = t.tv_sec & 0xffff; 62 | u_int64_t c = (t.tv_sec >> 16) & 0xffff; 63 | seed = a | (b << 16) | (c << 32); 64 | } 65 | _init (seed); 66 | } 67 | 68 | void 69 | Random::_init (u_int64_t seed) 70 | { 71 | _seed = seed & 0xffffffffffffull; 72 | _xsubi[0] = (seed >> 0) & 0xffff; 73 | _xsubi[1] = (seed >> 16) & 0xffff; 74 | _xsubi[2] = (seed >> 32) & 0xffff; 75 | } 76 | 77 | -------------------------------------------------------------------------------- /nms/cxx/src/Random.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __Random_hh__ 3 | #define __Random_hh__ 4 | 5 | // Copyright (C) 2002 David R. Martin 6 | // 7 | // This program is free software; you can redistribute it and/or 8 | // modify it under the terms of the GNU General Public License as 9 | // published by the Free Software Foundation; either version 2 of the 10 | // License, or (at your option) any later version. 11 | // 12 | // This program is distributed in the hope that it will be useful, but 13 | // WITHOUT ANY WARRANTY; without even the implied warranty of 14 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 15 | // General Public License for more details. 16 | // 17 | // You should have received a copy of the GNU General Public License 18 | // along with this program; if not, write to the Free Software 19 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 20 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | // All random numbers are generated from a single seed. This is true 28 | // even when private random streams (seperate from the global 29 | // Random::rand stream) are spawned from existing streams, since the new 30 | // streams are seeded automatically from the parent's random stream. 31 | // Any random stream can be reset so that a sequence of random values 32 | // can be replayed. 33 | 34 | // If seed==0, then the seed is generated from the system clock. 35 | 36 | class Random 37 | { 38 | public: 39 | 40 | static Random rand; 41 | 42 | // These are defined in as the limits of int, but 43 | // here we need the limits of int32_t. 44 | static const int32_t int32_max = 2147483647; 45 | static const int32_t int32_min = -int32_max-1; 46 | static const u_int32_t u_int32_max = 4294967295u; 47 | 48 | // Seed from the system clock. 49 | Random (); 50 | 51 | // Specify seed. 52 | // If zero, seed from the system clock. 53 | Random (u_int64_t seed); 54 | 55 | // Spawn off a new random stream seeded from the parent's stream. 56 | Random (Random& that); 57 | 58 | // Restore initial seed so we can replay a random sequence. 59 | void reset (); 60 | 61 | // Set the seed. 62 | // If zero, seed from the system clock. 
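// Usage sketch (illustrative, not part of the original header):
//   Random rng (42);             // deterministic stream
//   double u = rng.fp ();        // uniform double in [0,1)
//   int32_t k = rng.i32 (0, 9);  // uniform int in [0..9]
//   rng.reset ();                // replay the same sequence from seed 42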
63 | void reseed (u_int64_t seed); 64 | 65 | // double in [0..1) or [a..b) 66 | inline double fp (); 67 | inline double fp (double a, double b); 68 | 69 | // 32-bit signed integer in [-2^31,2^31) or [a..b] 70 | inline int32_t i32 (); 71 | inline int32_t i32 (int32_t a, int32_t b); 72 | 73 | // 32-bit unsigned integer in [0,2^32) or [a..b] 74 | inline u_int32_t ui32 (); 75 | inline u_int32_t ui32 (u_int32_t a, u_int32_t b); 76 | 77 | protected: 78 | 79 | void _init (u_int64_t seed); 80 | 81 | // The original seed for this random stream. 82 | u_int64_t _seed; 83 | 84 | // The current state for this random stream. 85 | u_int16_t _xsubi[3]; 86 | 87 | }; 88 | 89 | inline u_int32_t 90 | Random::ui32 () 91 | { 92 | return ui32(0,u_int32_max); 93 | } 94 | 95 | inline u_int32_t 96 | Random::ui32 (u_int32_t a, u_int32_t b) 97 | { 98 | assert (a <= b); 99 | double x = fp (); 100 | return (u_int32_t) floor (x * ((double)b - (double)a + 1) + a); 101 | } 102 | 103 | inline int32_t 104 | Random::i32 () 105 | { 106 | return i32(int32_min,int32_max); 107 | } 108 | 109 | inline int32_t 110 | Random::i32 (int32_t a, int32_t b) 111 | { 112 | assert (a <= b); 113 | double x = fp (); 114 | return (int32_t) floor (x * ((double)b - (double)a + 1) + a); 115 | } 116 | 117 | inline double 118 | Random::fp () 119 | { 120 | return erand48 (_xsubi); 121 | } 122 | 123 | inline double 124 | Random::fp (double a, double b) 125 | { 126 | assert (a < b); 127 | return erand48 (_xsubi) * (b - a) + a; 128 | } 129 | 130 | #endif // __Random_hh__ 131 | 132 | -------------------------------------------------------------------------------- /nms/cxx/src/String.cc: -------------------------------------------------------------------------------- 1 | 2 | // Copyright (C) 2002 David R. Martin 3 | // 4 | // This program is free software; you can redistribute it and/or 5 | // modify it under the terms of the GNU General Public License as 6 | // published by the Free Software Foundation; either version 2 of the 7 | // License, or (at your option) any later version. 8 | // 9 | // This program is distributed in the hope that it will be useful, but 10 | // WITHOUT ANY WARRANTY; without even the implied warranty of 11 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 12 | // General Public License for more details. 13 | // 14 | // You should have received a copy of the GNU General Public License 15 | // along with this program; if not, write to the Free Software 16 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 17 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include "String.hh" 24 | 25 | String::String () 26 | { 27 | _length = 0; 28 | _size = defaultMinSize + 1; 29 | _text = new char [_size]; 30 | _text[_length] = '\0'; 31 | } 32 | 33 | String::String (const String& that) 34 | { 35 | _length = that._length; 36 | _size = that._size; 37 | _text = new char [_size]; 38 | memcpy (_text, that._text, _length + 1); 39 | } 40 | 41 | String::String (const char* fmt, ...) 
42 | { 43 | assert (fmt != NULL); 44 | 45 | _length = 0; 46 | _size = strlen (fmt) + 1; 47 | _text = new char [_size]; 48 | _text[_length] = '\0'; 49 | 50 | va_list ap; 51 | va_start (ap, fmt); 52 | _append (fmt, ap); 53 | va_end (ap); 54 | } 55 | 56 | String::~String () 57 | { 58 | assert (_text != NULL); 59 | delete [] _text; 60 | } 61 | 62 | String& 63 | String::operator= (const String& that) 64 | { 65 | if (&that == this) { return *this; } 66 | clear(); 67 | append ("%s", that.text()); 68 | return *this; 69 | } 70 | 71 | String& 72 | String::operator= (const char* s) 73 | { 74 | clear(); 75 | if (s != NULL) { 76 | append ("%s", s); 77 | } 78 | return *this; 79 | } 80 | 81 | void 82 | String::clear () 83 | { 84 | _length = 0; 85 | _text[0] = '\0'; 86 | } 87 | 88 | void 89 | String::append (char c) 90 | { 91 | _append (1, (const char*)&c); 92 | } 93 | 94 | void 95 | String::append (unsigned length, const char* s) 96 | { 97 | _append (length, s); 98 | } 99 | 100 | void 101 | String::append (const char* fmt, ...) 102 | { 103 | assert (fmt != NULL); 104 | va_list ap; 105 | va_start (ap, fmt); 106 | _append (fmt, ap); 107 | va_end (ap); 108 | } 109 | 110 | const char& 111 | String::operator[] (unsigned i) const 112 | { 113 | assert (i < _length); 114 | return _text[i]; 115 | } 116 | 117 | bool 118 | String::nextLine (FILE* fp) 119 | { 120 | assert (fp != NULL); 121 | 122 | const int bufLen = 128; 123 | char buf[bufLen]; 124 | 125 | clear (); 126 | 127 | while (fgets (buf, bufLen, fp) != NULL) { 128 | _append (strlen (buf), buf); 129 | if (_text[_length - 1] == '\n') { 130 | _length--; 131 | _text[_length] = '\0'; 132 | return true; 133 | } 134 | } 135 | 136 | if (_length > 0) { 137 | assert (_text[_length - 1] != '\n'); 138 | return true; 139 | } else { 140 | return false; 141 | } 142 | } 143 | 144 | void 145 | String::_append (unsigned length, const char* s) 146 | { 147 | _grow (length + _length + 1); 148 | if (length > 0) { 149 | memcpy (_text + _length, s, length); 150 | _length += length; 151 | _text[_length] = '\0'; 152 | } 153 | } 154 | 155 | // On solaris and linux, vsnprintf returns the number of characters needed 156 | // to format the entire string. 157 | // On irix, vsnprintf returns the number of characters written. This is 158 | // at most length(buf)-1. 159 | // On some sytems, vsnprintf returns -1 if there wasn't enough space. 160 | void 161 | String::_append (const char* fmt, va_list ap) 162 | { 163 | int bufLen = 128; 164 | char* buf; 165 | 166 | while (1) { 167 | buf = new char [bufLen]; 168 | int cnt = vsnprintf (buf, bufLen, fmt, ap); 169 | if (cnt < 0 || cnt >= bufLen - 1) { 170 | delete [] buf; 171 | bufLen *= 2; 172 | continue; 173 | } else { 174 | break; 175 | } 176 | } 177 | 178 | _append (strlen (buf), buf); 179 | delete [] buf; 180 | } 181 | 182 | void 183 | String::_grow (unsigned minSize) 184 | { 185 | if (minSize > _size) { 186 | char* old = _text; 187 | _size += minSize; 188 | _text = new char [_size]; 189 | memcpy (_text, old, _length + 1); 190 | delete [] old; 191 | } 192 | } 193 | 194 | -------------------------------------------------------------------------------- /nms/cxx/src/String.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __String_hh__ 3 | #define __String_hh__ 4 | 5 | // Class that makes it easy to construct strings in a safe manner. 6 | // The main bonus is the printf-style interface for creating and 7 | // appending strings. 
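// Usage sketch (illustrative, not part of the original header):
//   String s ("x=%d", 42);
//   s.append (", y=%s", "z");    // s now holds "x=42, y=z"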
8 | 9 | // This class implements strings so that they behave like intrinsic 10 | // types, i.e. assignment creates a copy, passing by value in a 11 | // function call creates a copy. 12 | 13 | // NOTE: Calling a constructor or append() method with a plain char* 14 | // is dangerous, since the string is interpreted by sprintf. To be 15 | // safe, always do append("%s",s) instead of append(s). 16 | 17 | // Copyright (C) 2002 David R. Martin 18 | // 19 | // This program is free software; you can redistribute it and/or 20 | // modify it under the terms of the GNU General Public License as 21 | // published by the Free Software Foundation; either version 2 of the 22 | // License, or (at your option) any later version. 23 | // 24 | // This program is distributed in the hope that it will be useful, but 25 | // WITHOUT ANY WARRANTY; without even the implied warranty of 26 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 27 | // General Public License for more details. 28 | // 29 | // You should have received a copy of the GNU General Public License 30 | // along with this program; if not, write to the Free Software 31 | // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 32 | // 02111-1307, USA, or see http://www.gnu.org/copyleft/gpl.html. 33 | 34 | #include 35 | #include 36 | #include 37 | #include 38 | 39 | class String 40 | { 41 | public: 42 | 43 | // Constructors. 44 | String (); 45 | String (const String& that); 46 | String (const char* fmt, ...); 47 | 48 | // Destructor. 49 | ~String (); 50 | 51 | // Assignment operators. 52 | String& operator= (const String& that); 53 | String& operator= (const char* s); 54 | 55 | // Accessors. 56 | unsigned length () const { return _length; } 57 | const char* text () const { return _text; } 58 | const char& operator[] (unsigned i) const; 59 | 60 | // Modifiers. 61 | void clear (); 62 | void append (char c); 63 | void append (unsigned length, const char* s); 64 | void append (const char* fmt, ...); 65 | 66 | // Load next line from file; newline is discarded. 67 | // Return true if new data; false on EOF. 68 | bool nextLine (FILE* fp); 69 | 70 | // Implicit convertion to const char* is useful so that other 71 | // modules that take strings as arguments don't have to know about 72 | // the String class, and the caller doesn't have to explicitly 73 | // call the text() method. 
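// That implicit conversion is also what lets the comparison operators at
// the bottom of this file pass String objects straight to strcmp.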
74 | operator const char* () const { return text(); } 75 | 76 | private: 77 | 78 | static const unsigned defaultMinSize = 16; 79 | 80 | void _append (unsigned length, const char* s); 81 | void _append (const char* fmt, va_list ap); 82 | 83 | void _grow (unsigned minSize); 84 | 85 | unsigned _length; 86 | unsigned _size; 87 | char* _text; 88 | 89 | }; 90 | 91 | // == operator 92 | inline int operator== (const String& x, const String& y) 93 | { return strcmp (x, y) == 0; } 94 | inline int operator== (const String& x, const char* y) 95 | { return strcmp (x, y) == 0; } 96 | inline int operator== (const char* x, const String& y) 97 | { return strcmp (x, y) == 0; } 98 | 99 | // != operator 100 | inline int operator!= (const String& x, const String& y) 101 | { return strcmp (x, y) != 0; } 102 | inline int operator!= (const String& x, const char* y) 103 | { return strcmp (x, y) != 0; } 104 | inline int operator!= (const char* x, const String& y) 105 | { return strcmp (x, y) != 0; } 106 | 107 | // < operator 108 | inline int operator< (const String& x, const String& y) 109 | { return strcmp (x, y) < 0; } 110 | inline int operator< (const String& x, const char* y) 111 | { return strcmp (x, y) < 0; } 112 | inline int operator< (const char* x, const String& y) 113 | { return strcmp (x, y) < 0; } 114 | 115 | // > operator 116 | inline int operator> (const String& x, const String& y) 117 | { return strcmp (x, y) > 0; } 118 | inline int operator> (const String& x, const char* y) 119 | { return strcmp (x, y) > 0; } 120 | inline int operator> (const char* x, const String& y) 121 | { return strcmp (x, y) > 0; } 122 | 123 | // <= operator 124 | inline int operator<= (const String& x, const String& y) 125 | { return strcmp (x, y) <= 0; } 126 | inline int operator<= (const String& x, const char* y) 127 | { return strcmp (x, y) <= 0; } 128 | inline int operator<= (const char* x, const String& y) 129 | { return strcmp (x, y) <= 0; } 130 | 131 | // >= operator 132 | inline int operator>= (const String& x, const String& y) 133 | { return strcmp (x, y) >= 0; } 134 | inline int operator>= (const String& x, const char* y) 135 | { return strcmp (x, y) >= 0; } 136 | inline int operator>= (const char* x, const String& y) 137 | { return strcmp (x, y) >= 0; } 138 | 139 | // write to output stream 140 | inline std::ostream& operator<< (std::ostream& out, const String& s) { 141 | out << (const char*)s; 142 | return out; 143 | } 144 | 145 | #endif // __String_hh__ 146 | -------------------------------------------------------------------------------- /nms/cxx/src/build.sh: -------------------------------------------------------------------------------- 1 | g++ solve.cc nms.cc csa.cc kofn.cc Random.cc Exception.cc String.cc -v -fPIC -DNOBLAS -shared -o ../lib/solve_csa.so -------------------------------------------------------------------------------- /nms/cxx/src/csa.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "csa.hh" 3 | 4 | char* CSA::err_messages[] = 5 | { 6 | "Can't read from the input file.", 7 | "Not a correct assignment problem line.", 8 | "Error reading a node descriptor from the input.", 9 | "Error reading an arc descriptor from the input.", 10 | "Unknown line type in the input", 11 | "Inconsistent number of arcs in the input.", 12 | "Parsing noncontiguous node ID numbers not implemented.", 13 | "Can't obtain enough memory to solve this problem.", 14 | }; 15 | 16 | char* CSA::nomem_msg = "Insufficient memory.\n"; 17 | 18 | CSA::CSA (int n, int m, const int* 
graph) 19 | { 20 | assert(n>0); 21 | assert(m>0); 22 | assert(graph!=NULL); 23 | assert((n%2)==0); 24 | _init(n,m); 25 | main(graph); 26 | } 27 | 28 | CSA::~CSA () 29 | { 30 | _delete(); 31 | } 32 | 33 | -------------------------------------------------------------------------------- /nms/cxx/src/csa_defs.h: -------------------------------------------------------------------------------- 1 | #define TRUE 1 2 | #define FALSE 0 3 | #define MAXLINE 100 4 | #define DEFAULT_SCALE_FACTOR 10 5 | #define DEFAULT_PO_COST_THRESH (2.0 * sqrt((double) n) * \ 6 | sqrt(sqrt((double) n))) 7 | #define DEFAULT_PO_WORK_THRESH 50 8 | #define DEFAULT_UPD_FAC 2 9 | #if defined(USE_SP_AUG_FORWARD) || defined(USE_SP_AUG_BACKWARD) 10 | #ifndef USE_SP_AUG 11 | #define USE_SP_AUG 12 | #endif 13 | #endif 14 | 15 | #ifdef USE_SP_AUG 16 | #define EXCESS_THRESH 127 17 | #else 18 | #define EXCESS_THRESH 0 19 | #endif 20 | 21 | #if defined(USE_P_UPDATE) || defined(STRONG_PO) 22 | #define WORK_TYPE unsigned 23 | #define REFINE_WORK relabelings 24 | #endif 25 | 26 | #if defined(DEBUG) && defined(ROUND_COSTS) 27 | #define MAGIC_MARKER 0xAAAAAAAA 28 | #endif 29 | 30 | #ifdef QUEUE_ORDER 31 | #define ACTIVE_TYPE queue 32 | #define create_active(size) active = q_create(size) 33 | #define make_active(v) enq(active, (char *) v) 34 | #define get_active_node(v) v = (lhs_ptr) deq(active) 35 | #else 36 | #define ACTIVE_TYPE stack 37 | #define create_active(size) active = st_create(size) 38 | #define make_active(v) st_push(active, (char *) v) 39 | #define get_active_node(v) v = (lhs_ptr) st_pop(active) 40 | #endif 41 | 42 | #define st_push(s, el) \ 43 | {\ 44 | *(s->top) = (char *) el;\ 45 | s->top++;\ 46 | } 47 | 48 | #define st_empty(s) (s->top == s->bottom) 49 | 50 | #define enq(q, el) \ 51 | {\ 52 | *(q->tail) = el;\ 53 | if (q->tail == q->end) q->tail = q->storage;\ 54 | else q->tail++;\ 55 | } 56 | 57 | #define q_empty(q) (q->head == q->tail ? 1 : 0) 58 | 59 | #define insert_list(node, head) \ 60 | {\ 61 | node->next = (*(head));\ 62 | (*(head))->prev = node;\ 63 | (*(head)) = node;\ 64 | node->prev = tail_rhs_node;\ 65 | } 66 | 67 | #define delete_list(node, head) \ 68 | {\ 69 | if (node->prev == tail_rhs_node)\ 70 | (*(head)) = node->next;\ 71 | node->prev->next = node->next;\ 72 | node->next->prev = node->prev;\ 73 | } 74 | 75 | /* 76 | The author hereby apologizes for the following incomprehensible 77 | muddle. Price-outs involve moving arcs around in the data structure, 78 | and it turns out to be faster to copy them field-by-field than to use 79 | memcpy() because they're so small. But the set of fields an arc has 80 | depends on lots of things, hence this mess. 
81 | */ 82 | 83 | #if defined(USE_PRICE_OUT) || defined(ROUND_COSTS) 84 | #ifdef STORE_REV_ARCS 85 | #ifdef ROUND_COSTS 86 | #define copy_lr_arc(a, b) \ 87 | {\ 88 | b->head = a->head;\ 89 | b->c_init = a->c_init;\ 90 | b->c = a->c;\ 91 | b->rev = a->rev;\ 92 | } 93 | #else /* ROUND_COSTS */ 94 | #define copy_lr_arc(a, b) \ 95 | {\ 96 | b->head = a->head;\ 97 | b->c = a->c;\ 98 | b->rev = a->rev;\ 99 | } 100 | #endif /* ROUND_COSTS */ 101 | 102 | #ifdef USE_P_UPDATE 103 | #define copy_rl_arc(a, b) \ 104 | { b->tail = a->tail; b->c = a->c; b->rev = a->rev; } 105 | #else /* USE_P_UPDATE */ 106 | #define copy_rl_arc(a, b) \ 107 | { b->tail = a->tail; b->rev = a->rev; } 108 | #endif /* USE_P_UPDATE */ 109 | 110 | #define exch_rl_arcs(a, b) \ 111 | {\ 112 | copy_rl_arc(b, tail_rl_arc);\ 113 | copy_rl_arc(a, b);\ 114 | copy_rl_arc(tail_rl_arc, a);\ 115 | } 116 | #else /* STORE_REV_ARCS */ 117 | #ifdef PREC_COSTS 118 | #define copy_lr_arc(a, b) \ 119 | {\ 120 | b->head = a->head;\ 121 | b->c = a->c;\ 122 | } 123 | #else /* PREC_COSTS */ 124 | #define copy_lr_arc(a, b) \ 125 | {\ 126 | b->head = a->head;\ 127 | b->c_init = a->c_init;\ 128 | b->c = a->c;\ 129 | } 130 | #endif /* PREC_COSTS */ 131 | #endif /* STORE_REV_ARCS */ 132 | 133 | #define exch_lr_arcs(a, b) \ 134 | {\ 135 | copy_lr_arc(b, tail_lr_arc);\ 136 | copy_lr_arc(a, b);\ 137 | copy_lr_arc(tail_lr_arc, a);\ 138 | } 139 | 140 | extern lr_aptr tail_lr_arc; 141 | #ifdef STORE_REV_ARCS 142 | extern rl_aptr tail_rl_arc; 143 | #endif 144 | 145 | #ifdef STORE_REV_ARCS 146 | #define price_in_rev(a) \ 147 | { \ 148 | register rl_aptr b_a = --a->head->back_arcs; \ 149 | register rl_aptr a_r = a->rev; \ 150 | if (b_a != a_r) \ 151 | { \ 152 | register lr_aptr b_r = b_a->rev; \ 153 | exch_rl_arcs(b_a, a_r); \ 154 | b_r->rev = a_r; \ 155 | a->rev = b_a; \ 156 | } \ 157 | } 158 | 159 | #define price_out_rev(a) \ 160 | { \ 161 | register rl_aptr b_a = a->head->back_arcs; \ 162 | register rl_aptr a_r = a->rev; \ 163 | if (b_a != a_r) \ 164 | { \ 165 | register lr_aptr b_r = b_a->rev; \ 166 | exch_rl_arcs(b_a, a_r); \ 167 | b_r->rev = a_r; \ 168 | a->rev = b_a; \ 169 | } \ 170 | a->head->back_arcs++; \ 171 | } 172 | 173 | #define handle_rev_pointers(a, b) { a->rev->rev = b; b->rev->rev = a; } 174 | #else /* STORE_REV_ARCS */ 175 | #define price_in_rev(a) /* do nothing */ 176 | #define price_out_rev(a) /* do nothing */ 177 | #define handle_rev_pointers(a, b) /* do nothing */ 178 | #endif /* STORE_REV_ARCS */ 179 | 180 | #define price_in_unm_arc(v, a) \ 181 | { \ 182 | register lr_aptr f_a = --v->first; \ 183 | price_in_rev(a); \ 184 | if (f_a != a) \ 185 | { \ 186 | if (v->matched == f_a) v->matched = a; \ 187 | handle_rev_pointers(a, f_a); \ 188 | exch_lr_arcs(a, f_a); \ 189 | } \ 190 | } 191 | 192 | #define price_in_mch_arc(v, a) \ 193 | { \ 194 | register lr_aptr f_a = --v->first; \ 195 | price_in_rev(a); \ 196 | a->head->node_info.priced_in = TRUE; \ 197 | if (f_a != a) \ 198 | { \ 199 | v->matched = f_a; \ 200 | handle_rev_pointers(a, f_a); \ 201 | exch_lr_arcs(a, f_a); \ 202 | } \ 203 | } 204 | 205 | #define price_out_unm_arc(v, a) \ 206 | { \ 207 | register lr_aptr f_a = v->first++; \ 208 | price_out_rev(a); \ 209 | if (f_a != a) \ 210 | { \ 211 | if (v->matched == f_a) v->matched = a; \ 212 | handle_rev_pointers(a, f_a); \ 213 | exch_lr_arcs(a, f_a); \ 214 | } \ 215 | } 216 | 217 | #define price_out_mch_arc(v, a) \ 218 | { \ 219 | register lr_aptr f_a = v->first++; \ 220 | price_out_rev(a); \ 221 | a->head->node_info.priced_in = FALSE; \ 222 | 
if (f_a != a) \ 223 | { \ 224 | v->matched = f_a; \ 225 | handle_rev_pointers(a, f_a); \ 226 | exch_lr_arcs(a, f_a); \ 227 | } \ 228 | } 229 | #endif /* USE_PRICE_OUT || ROUND_COSTS */ 230 | -------------------------------------------------------------------------------- /nms/cxx/src/csa_types.h: -------------------------------------------------------------------------------- 1 | #define PREC_COSTS 2 | 3 | #if defined(QUICK_MIN) && !defined(NUM_BEST) 4 | #define NUM_BEST 3 5 | #endif 6 | 7 | #if defined(USE_SP_AUG_FORWARD) || defined(USE_SP_AUG_BACKWARD) 8 | #ifndef USE_SP_AUG 9 | #define USE_SP_AUG 10 | #endif 11 | #endif 12 | 13 | #if defined(USE_P_UPDATE) || defined(BACK_PRICE_OUT) || \ 14 | defined(USE_SP_AUG_BACKWARD) 15 | #define STORE_REV_ARCS 16 | #endif 17 | 18 | typedef struct lhs_node { 19 | #if defined(QUICK_MIN) 20 | struct { 21 | /* 22 | flag used to indicate to 23 | double_push() that so few arcs 24 | are incident that best[] is 25 | useless. 26 | */ 27 | #ifdef QUICK_MIN 28 | unsigned few_arcs : 1; 29 | #endif 30 | } node_info; 31 | #ifdef QUICK_MIN 32 | /* 33 | list of arcs to consider first in 34 | calculating the minimum-reduced-cost 35 | incident arc; if we find it here, we 36 | need look no further. 37 | */ 38 | struct lr_arc *best[NUM_BEST]; 39 | /* 40 | bound on the reduced cost of an arc we 41 | can be certain still belongs among 42 | those in best[]. 43 | */ 44 | double next_best; 45 | #endif 46 | #endif 47 | #ifdef EXPLICIT_LHS_PRICES 48 | /* 49 | price of this node. 50 | */ 51 | double p; 52 | #endif 53 | /* 54 | first arc in the arc array associated 55 | with this node. 56 | */ 57 | struct lr_arc *priced_out; 58 | /* 59 | first priced-in arc in the arc array 60 | associated with this node. 61 | */ 62 | struct lr_arc *first; 63 | /* 64 | matching arc (if any) associated with 65 | this node; NULL if this node is 66 | unmatched. 67 | */ 68 | struct lr_arc *matched; 69 | #if defined(USE_P_UPDATE) 70 | /* 71 | price change required on this node (in 72 | units of epsilon) to ensure that its 73 | excess can reach a deficit in the 74 | admissible graph. computed and used in 75 | p_update(). 76 | */ 77 | long delta_reqd; 78 | #endif 79 | #ifdef USE_SP_AUG_BACKWARD 80 | struct lr_arc *aug_path; 81 | #endif 82 | } *lhs_ptr; 83 | 84 | typedef struct rhs_node { 85 | struct { 86 | #ifdef USE_P_REFINE 87 | /* 88 | depth-first search flags. 89 | dfs is to determine whether 90 | admissible graph contains a 91 | cycle in p_refine(). 92 | */ 93 | unsigned srchng : 1; 94 | unsigned srched : 1; 95 | #endif 96 | /* 97 | flag to indicate this node's 98 | matching arc (if any) is 99 | priced in. 100 | */ 101 | unsigned priced_in : 1; 102 | } node_info; 103 | /* 104 | lhs node this rhs node is matched to. 105 | */ 106 | lhs_ptr matched; 107 | /* 108 | price of this node. 109 | */ 110 | double p; 111 | #ifdef USE_SP_AUG_FORWARD 112 | struct lr_arc *aug_path; 113 | #endif 114 | #if defined(USE_P_REFINE) || defined(USE_P_UPDATE) || defined(USE_SP_AUG) 115 | /* 116 | number of epsilons of price change 117 | required at this node to accomplish 118 | p_refine()'s or p_update()'s goal. 119 | */ 120 | long key; 121 | /* 122 | fields to maintain buckets of nodes as 123 | lists in p_refine() and p_update(). 124 | */ 125 | struct rhs_node *prev, *next; 126 | #endif 127 | #ifdef STORE_REV_ARCS 128 | /* 129 | first back arc in the arc array 130 | associated with this node. 
131 | */ 132 | struct rl_arc *priced_out; 133 | /* 134 | first priced-in back arc in the arc 135 | array associated with this node. 136 | */ 137 | struct rl_arc *back_arcs; 138 | #endif 139 | } *rhs_ptr; 140 | 141 | #ifdef STORE_REV_ARCS 142 | typedef struct rl_arc { 143 | /* 144 | lhs node associated with this back 145 | arc. some would have liked the name 146 | head better. 147 | */ 148 | lhs_ptr tail; 149 | #if defined(USE_P_UPDATE) || defined(USE_SP_AUG_BACKWARD) 150 | /* 151 | cost of this back arc. this cost gets 152 | modified to incorporate other arc 153 | costs in p_update() and sp_aug(), 154 | while forward arc costs remain 155 | constant throughout. 156 | */ 157 | double c; 158 | #endif 159 | #if defined(USE_PRICE_OUT) || defined(USE_SP_AUG_BACKWARD) 160 | /* 161 | this arc's reverse in the forward arc 162 | list. 163 | */ 164 | struct lr_arc *rev; 165 | #endif 166 | } *rl_aptr; 167 | #endif 168 | 169 | typedef struct lr_arc { 170 | /* 171 | rhs node associated with this arc. 172 | */ 173 | rhs_ptr head; 174 | /* 175 | arc cost. 176 | */ 177 | double c; 178 | #ifdef USE_SP_AUG_FORWARD 179 | lhs_ptr tail; 180 | #endif 181 | #ifdef STORE_REV_ARCS 182 | /* 183 | this arc's reverse in the back arc 184 | list. 185 | */ 186 | struct rl_arc *rev; 187 | #endif 188 | } *lr_aptr; 189 | 190 | typedef struct stack_st { 191 | /* 192 | Sometimes stacks have lhs nodes, and 193 | other times they have rhs nodes. So 194 | there's a little type clash; 195 | everything gets cast to (char *) so we 196 | can use the same structure for both. 197 | */ 198 | char **bottom; 199 | char **top; 200 | } *stack; 201 | 202 | typedef struct queue_st { 203 | /* 204 | Sometimes queues have lhs nodes, and 205 | other times they have rhs nodes. So 206 | there's a little type clash; 207 | everything gets cast to (char *) so we 208 | can use the same structure for both. 209 | */ 210 | char **head; 211 | char **tail; 212 | char **storage; 213 | char **end; 214 | unsigned max_size; 215 | } *queue; 216 | -------------------------------------------------------------------------------- /nms/cxx/src/kofn.cc: -------------------------------------------------------------------------------- 1 | 2 | #include "Random.hh" 3 | #include "kofn.hh" 4 | 5 | // O(n) implementation. 6 | static void 7 | _kOfN_largeK (int k, int n, int* values) 8 | { 9 | assert (k > 0); 10 | assert (k <= n); 11 | int j = 0; 12 | for (int i = 0; i < n; i++) { 13 | double prob = (double) (k - j) / (n - i); 14 | assert (prob <= 1); 15 | double x = Random::rand.fp (); 16 | if (x < prob) { 17 | values[j++] = i; 18 | } 19 | } 20 | assert (j == k); 21 | } 22 | 23 | // O(k*lg(k)) implementation; constant factor is about 2x the constant 24 | // factor for the O(n) implementation. 25 | static void 26 | _kOfN_smallK (int k, int n, int* values) 27 | { 28 | assert (k > 0); 29 | assert (k <= n); 30 | if (k == 1) { 31 | values[0] = Random::rand.i32 (0, n - 1); 32 | return; 33 | } 34 | int leftN = n / 2; 35 | int rightN = n - leftN; 36 | int leftK = 0; 37 | int rightK = 0; 38 | for (int i = 0; i < k; i++) { 39 | int x = Random::rand.i32 (0, n - i - 1); 40 | if (x < leftN - leftK) { 41 | leftK++; 42 | } else { 43 | rightK++; 44 | } 45 | } 46 | if (leftK > 0) { _kOfN_smallK (leftK, leftN, values); } 47 | if (rightK > 0) { _kOfN_smallK (rightK, rightN, values + leftK); } 48 | for (int i = leftK; i < k; i++) { 49 | values[i] += leftN; 50 | } 51 | } 52 | 53 | // Return k randomly selected integers from the interval [0,n), in 54 | // increasing sorted order. 
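// Dispatch note for kOfN below: _kOfN_smallK costs O(k*lg k) with roughly
// twice the constant factor of the linear scan, so the O(n) scan is
// preferred once k*log2(k) >= n/2.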
55 | void 56 | kOfN (int k, int n, int* values) 57 | { 58 | assert (k >= 0); 59 | assert (n >= 0); 60 | if (k == 0) { return; } 61 | static double log2 = log (2); 62 | double klogk = k * log (k) / log2; 63 | if (klogk < n / 2) { 64 | _kOfN_smallK (k, n, values); 65 | } else { 66 | _kOfN_largeK (k, n, values); 67 | } 68 | } 69 | 70 | -------------------------------------------------------------------------------- /nms/cxx/src/kofn.hh: -------------------------------------------------------------------------------- 1 | 2 | #ifndef __kofn_hh__ 3 | #define __kofn_hh__ 4 | 5 | extern "C" { 6 | void kOfN (int k, int n, int* values); 7 | } 8 | 9 | #endif // __kofn_hh__ 10 | -------------------------------------------------------------------------------- /nms/cxx/src/nms.cc: -------------------------------------------------------------------------------- 1 | #include "solve.h" 2 | #include "math.h" 3 | 4 | 5 | inline float interp(const float* image, int h, int w, float x, float y) { 6 | x = x < 0 ? 0 : (x > w - 1.001 ? w - 1.001 : x); 7 | y = y < 0 ? 0 : (y > h - 1.001 ? h - 1.001 : y); 8 | int x0 = int(x), y0 = int(y); 9 | int x1 = x0 + 1, y1 = y0 + 1; 10 | float dx0 = x - x0, dy0 = y - y0; 11 | float dx1 = 1 - dx0, dy1 = 1 - dy0; 12 | float out = image[y0 * w + x0] * dx1 * dy1 + 13 | image[y0 * w + x1] * dx0 * dy1 + 14 | image[y1 * w + x0] * dx1 * dy0 + 15 | image[y1 * w + x1] * dx0 * dy0; 16 | return out; 17 | } 18 | 19 | 20 | void nms(float* out, const float* edge, const float* ori, int r, int s, float m, int w, int h) { 21 | for (int x = 0; x < w; ++x) { 22 | for (int y = 0; y < h; ++y) { 23 | float e = out[y * w + x] = edge[y * w + x]; 24 | if (e == 0) { 25 | continue; 26 | } 27 | e *= m; 28 | float cos_o = cos(ori[y * w + x]); 29 | float sin_o = sin(ori[y * w + x]); 30 | for (int d = -r; d <= r; ++d) { 31 | if (d != 0) { 32 | float e0 = interp(edge, h, w, x + d * cos_o, y + d * sin_o); 33 | if (e < e0) { 34 | out[y * w + x] = 0; 35 | break; 36 | } 37 | } 38 | } 39 | } 40 | } 41 | 42 | 43 | s = s > w / 2 ? w / 2 : s; 44 | s = s > h / 2 ? 
h / 2 : s; 45 | for (int x = 0; x < s; ++x) { 46 | for (int y = 0; y < h; ++y) { 47 | out[y * w + x] *= float(x) / s; 48 | out[y * w + w - 1 - x] *= float(x) / s; 49 | } 50 | } 51 | for (int x = 0; x < w; ++x) { 52 | for (int y = 0; y < s; ++y) { 53 | out[y * w + x] *= float(y) / s; 54 | out[(h - 1 - y) * w + x] *= float(y) / s; 55 | } 56 | } 57 | } -------------------------------------------------------------------------------- /nms/cxx/src/solve.cc: -------------------------------------------------------------------------------- 1 | #include "solve.h" 2 | 3 | 4 | void solve(int n, int m, const int* graph, int* out_graph) { 5 | CSA csa(2 * n, m, graph); 6 | assert(csa.edges() == n); 7 | for (int i = 0; i < n; ++i) { 8 | int a, b, c; 9 | csa.edge(i, a, b, c); 10 | out_graph[i * 3 + 0] = a - 1; 11 | out_graph[i * 3 + 1] = b - 1 - n; 12 | out_graph[i * 3 + 2] = c; 13 | } 14 | } 15 | 16 | -------------------------------------------------------------------------------- /nms/cxx/src/solve.h: -------------------------------------------------------------------------------- 1 | #include "csa.hh" 2 | 3 | extern "C" { 4 | void solve(int n, int m, const int* graph, int* out_graph); 5 | } 6 | 7 | extern "C" { 8 | void nms(float* out, const float* edge, const float* ori, int r, int s, float m, int w, int h); 9 | } 10 | 11 | -------------------------------------------------------------------------------- /nms/impl/edges_eval_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from scipy.interpolate import interp1d 6 | 7 | 8 | def edges_eval_plot(algs, nms=None, cols=None): 9 | """ 10 | See https://github.com/pdollar/edges/blob/master/edgesEvalPlot.m 11 | """ 12 | 13 | # parse inputs 14 | nms = nms or [] 15 | cols = cols or list("rgbkmr" * 100) 16 | cols = np.array(cols) 17 | if not isinstance(algs, list): 18 | algs = [algs] 19 | if not isinstance(nms, list): 20 | nms = [nms] 21 | nms = np.array(nms) 22 | 23 | # setup basic plot (isometric contour lines and human performance) 24 | plt.figure() 25 | ax = plt.gca() 26 | plt.box(True) 27 | plt.grid(True) 28 | plt.axhline(0.5, 0, 1, linewidth=2, color=[0.7, 0.7, 0.7]) 29 | for f in np.arange(0.1, 1, 0.1): 30 | r = np.arange(f, 1.01, 0.01) 31 | p = f * r / (2 * r - f) 32 | plt.plot(r, p, color=[0, 1, 0]) 33 | plt.plot(p, r, color=[0, 1, 0]) 34 | h = plt.plot(0.7235, 0.9014, marker="o", markersize=8, color=[0, 0.5, 0], 35 | markerfacecolor=[0, 0.5, 0], markeredgecolor=[0, 0.5, 0]) 36 | plt.xticks(np.linspace(0, 1, 11)) 37 | plt.yticks(np.linspace(0, 1, 11)) 38 | plt.xlabel("Recall") 39 | plt.ylabel("Precision") 40 | ax.set_aspect('equal', adjustable='box') 41 | plt.axis([0, 1, 0, 1]) 42 | 43 | # load results for every algorithm (pr=[T, R, P, F]) 44 | n = len(algs) 45 | hs, res, prs = [None] * n, np.zeros((n, 9), dtype=np.float32), [] 46 | for i, alg in enumerate(algs): 47 | a = "{}-eval".format(alg) 48 | pr = np.loadtxt(os.path.join(a, "eval_bdry_thr.txt")) 49 | pr = pr[pr[:, 1] >= 1e-3] 50 | _, o = np.unique(pr[:, 2], return_index=True) 51 | r50 = interp1d(pr[o, 2], pr[o, 1], bounds_error=False, fill_value=np.nan)(np.maximum(pr[o[0], 2], 0.5)) 52 | res[i, :8] = np.loadtxt(os.path.join(a, "eval_bdry.txt")) 53 | res[i, 8] = r50 54 | prs.append(pr) 55 | prs = np.stack(prs, axis=0) 56 | 57 | # sort algorithms by ODS score 58 | o = np.argsort(res[:, 3])[::-1] 59 | res, prs, cols = res[o, :], prs[o], cols[o] 60 | if nms: 61 | nms = nms[o] 62 | 63 | 
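    # NOTE: `nms` is a numpy array at this point, so the `if nms:` and
    # `if not nms:` truthiness tests in this function raise a ValueError as
    # soon as more than one algorithm name is supplied; `if len(nms) > 0:`
    # is the safe spelling. Columns of `res` used below: [3]=ODS, [6]=OIS,
    # [7]=AP, [8]=R50.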
# plot results for every algorithm (plot best last) 64 | for i in range(n - 1, -1, -1): 65 | hs[i] = plt.plot(prs[i, :, 1], prs[i, :, 2], linestyle="-", linewidth=3, color=cols[i])[0] 66 | prefix = "ODS={:.3f}, OIS={:.3f}, AP={:.3f}, R50={:.3f}".format(*res[i, [3, 6, 7, 8]]) 67 | if nms: 68 | prefix += " - {}".format(nms[i]) 69 | print(prefix) 70 | 71 | # show legend if nms provided (report best first) 72 | if not nms: 73 | plt.show() 74 | return 75 | 76 | nms = ["[F=.80] Human"] + ["[F={:.2f}] {}".format(res[i, 3], nms[i]) for i in range(n)] 77 | hs = h + hs 78 | plt.legend(hs, nms, loc="lower left") 79 | plt.show() 80 | 81 | -------------------------------------------------------------------------------- /nms/impl/toolbox.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | from scipy import signal 5 | 6 | 7 | def conv_tri(image, r, s=1): 8 | """ 2D image convolution with a triangle filter (no fast) 9 | See https://github.com/pdollar/toolbox/blob/master/channels/convTri.m 10 | Note: signal.convolve2d does not support float16('single' in MATLAB) 11 | """ 12 | if image.size == 0 or (r == 0 and s == 1): 13 | return image 14 | if r <= 1: 15 | p = 12 / r / (r + 2) - 2 16 | f = np.array([[1, p, 1]]) / (2 + p) 17 | r = 1 18 | else: 19 | f = np.array([list(range(1, r + 1)) + [r + 1] + list(range(r, 0, -1))]) / (r + 1) ** 2 20 | f = f.astype(image.dtype) 21 | image = np.pad(image, ((r, r), (r, r)), mode="symmetric") 22 | image = signal.convolve2d(signal.convolve2d(image, f, "valid"), f.T, "valid") 23 | if s > 1: 24 | t = int(np.floor(s / 2) + 1) 25 | image = image[t-1:image.shape[0]-(s-t)+1:s, t-1:image.shape[1]-(s-t)+1:s] 26 | return image 27 | 28 | 29 | def grad2(image): 30 | """ numerical gradients along x and y directions (no fast) 31 | See https://github.com/pdollar/toolbox/blob/master/channels/gradient2.m 32 | Note: np.gradient return [oy, ox], MATLAB version return [ox, oy] 33 | """ 34 | assert image.ndim == 2 35 | oy, ox = np.gradient(image) 36 | return ox, oy 37 | 38 | 39 | class Time: 40 | def __init__(self): 41 | self.time = None 42 | 43 | def set(self): 44 | self.time = time.time() 45 | 46 | def get(self): 47 | return time.time() - self.time 48 | 49 | 50 | -------------------------------------------------------------------------------- /nms/nms_temp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from ctypes import * 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from nms.impl.toolbox import conv_tri, grad2 8 | 9 | # NOTE: 10 | # In NMS, `if edge < interp: out = 0`, I found that sometimes edge is very close to interp. 11 | # `edge = 10e-8` and `interp = 11e-8` in C, while `edge = 10e-8` and `interp = 9e-8` in python. 12 | # ** Such slight differences (11e-8 - 9e-8 = 2e-8) in precision ** 13 | # ** would lead to very different results (`out = 0` in C and `out = edge` in python). 
** 14 | # Sadly, C implementation is not expected but needed :( 15 | solver = cdll.LoadLibrary("./nms/cxx/lib/solve_csa.so") 16 | c_float_pointer = POINTER(c_float) 17 | solver.nms.argtypes = [c_float_pointer, c_float_pointer, c_float_pointer, c_int, c_int, c_float, c_int, c_int] 18 | 19 | 20 | def nms_process_one_image(image, save_path=None, save=True): 21 | """" 22 | :param image: numpy array, edge, model output 23 | :param save_path: str, save path 24 | :param save: bool, if True, save .png 25 | :return: edge 26 | NOTE: in MATLAB, uint8(x) means round(x).astype(uint8) in numpy 27 | """ 28 | 29 | if save and save_path is not None: 30 | assert os.path.splitext(save_path)[-1] == ".png" 31 | edge = conv_tri(image, 1) 32 | ox, oy = grad2(conv_tri(edge, 4)) 33 | oxx, _ = grad2(ox) 34 | oxy, oyy = grad2(oy) 35 | ori = np.mod(np.arctan(oyy * np.sign(-oxy) / (oxx + 1e-5)), np.pi) 36 | out = np.zeros_like(edge) 37 | r, s, m, w, h = 1, 5, float(1.01), int(out.shape[1]), int(out.shape[0]) 38 | solver.nms(out.ctypes.data_as(c_float_pointer), 39 | edge.ctypes.data_as(c_float_pointer), 40 | ori.ctypes.data_as(c_float_pointer), 41 | r, s, m, w, h) 42 | edge = np.round(out * 255).astype(np.uint8) 43 | if save: 44 | cv2.imwrite(save_path, edge) 45 | return edge 46 | 47 | 48 | import torch 49 | 50 | 51 | def get_nms(edge_pred, binary_threshold=55): 52 | # edge_pred:[B,1,H,W] detached 53 | device = edge_pred.device 54 | edge_np = edge_pred.cpu().numpy() 55 | 56 | edges_nms = [] 57 | for i in range(edge_np.shape[0]): 58 | try: 59 | edge_nms = nms_process_one_image(edge_np[i, 0], save_path=None, save=False) 60 | edge_nms[edge_nms > binary_threshold] = 255 61 | edge_nms[edge_nms <= binary_threshold] = 0 62 | edge_nms = edge_nms / 255. 63 | except: 64 | edge_nms = edge_np[i, 0] 65 | edge_nms = torch.tensor(edge_nms, device=device, dtype=torch.float32)[None, ...] 66 | edges_nms.append(edge_nms) 67 | 68 | edges_nms = torch.stack(edges_nms, dim=0) 69 | return edges_nms 70 | -------------------------------------------------------------------------------- /nms/nms_torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | def conv_tri_torch(image, r): 6 | if r <= 1: 7 | p = 12 / r / (r + 2) - 2 8 | f = np.array([[1, p, 1]]) / (2 + p) 9 | r = 1 10 | else: 11 | f = np.array([list(range(1, r + 1)) + [r + 1] + list(range(r, 0, -1))]) / (r + 1) ** 2 12 | f = torch.tensor(f, dtype=image.dtype, device=image.device)[None, None, ...] 
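    # NOTE: constant (zero) padding below differs from the NumPy conv_tri in
    # nms/impl/toolbox.py, which pads with mode="symmetric", so the two
    # implementations can disagree slightly near image borders.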
13 | image = F.pad(image, (r, r, r, r), mode='constant') 14 | image = F.conv2d(F.conv2d(image, f, stride=1), f.permute(0, 1, 3, 2), stride=1) 15 | return image 16 | 17 | 18 | def grad_torch(image): 19 | image_np = image.cpu().numpy() 20 | oy, ox = [], [] 21 | for bi in range(image.shape[0]): 22 | oy_, ox_ = np.gradient(image_np[bi, 0]) 23 | oy.append(torch.from_numpy(oy_)) 24 | ox.append(torch.from_numpy(ox_)) 25 | oy = torch.stack(oy, dim=0).unsqueeze(1).to(device=image.device, dtype=image.dtype) 26 | ox = torch.stack(ox, dim=0).unsqueeze(1).to(device=image.device, dtype=image.dtype) 27 | return ox, oy 28 | 29 | 30 | def interp_torch(edge, h, w, cos_o, sin_o): 31 | grid_h = torch.arange(h, dtype=torch.float32) 32 | grid_w = torch.arange(w, dtype=torch.float32) 33 | grid_y, grid_x = torch.meshgrid(grid_h, grid_w) 34 | grid_y = grid_y.to(device=edge.device) 35 | grid_x = grid_x.to(device=edge.device) 36 | interp_res = [] 37 | for d in [-1, 1]: 38 | grid_y_ = torch.clamp(grid_y + d * sin_o, 0, h) 39 | grid_y_ = (grid_y_ / h - 0.5) * 2 40 | grid_x_ = torch.clamp(grid_x + d * cos_o, 0, w) 41 | grid_x_ = (grid_x_ / w - 0.5) * 2 42 | grid = torch.stack([grid_x_, grid_y_], dim=-1) 43 | interp_res_ = F.grid_sample(edge, grid, mode='bilinear', align_corners=False) 44 | interp_res.append(interp_res_) 45 | 46 | return interp_res 47 | 48 | 49 | def nms_torch(edge, ori, m, h, w): 50 | nms = edge.clone() 51 | mask = (edge != 0) 52 | mask = mask.to(torch.float32) 53 | cos_o = torch.cos(ori) 54 | sin_o = torch.sin(ori) 55 | interp_maps = interp_torch(edge, h, w, cos_o, sin_o) 56 | edgem = edge * m 57 | for interp_map in interp_maps: 58 | nms[edgem < interp_map] = 0 59 | nms = edge * (1 - mask) + nms * mask 60 | return nms 61 | 62 | 63 | def get_nms(edge_pred, binary_threshold=55): 64 | edge_pred = conv_tri_torch(edge_pred, r=1) 65 | edge_pred2 = conv_tri_torch(edge_pred, r=5) 66 | oxt, oyt = grad_torch(edge_pred2) 67 | oxxt, _ = grad_torch(oxt) 68 | oxyt, oyyt = grad_torch(oyt) 69 | orit = torch.arctan(oyyt * torch.sign(-oxyt) / (oxxt + 1e-5)) % np.pi 70 | orit = orit.squeeze(1) 71 | m, h, w = 1.001, int(edge_pred.shape[2]), int(edge_pred.shape[3]) 72 | edges_nms = nms_torch(edge_pred, orit, m, h, w) 73 | edges_nms = torch.round(edges_nms * 255) 74 | edges_nms[edges_nms > binary_threshold] = 255 75 | edges_nms[edges_nms <= binary_threshold] = 0 76 | edges_nms = edges_nms / 255. 77 | 78 | return edges_nms 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | lpips 2 | matplotlib==3.3.4 3 | numpy==1.20.1 4 | opencv_python==4.5.5.62 5 | pandas==1.2.4 6 | pytorch_lightning==1.2.9 7 | PyYAML==6.0 8 | Requests==2.31.0 9 | scikit_image==0.15.0 10 | scipy==1.3.3 11 | sync_batchnorm== 12 | timm==0.3.2 13 | torch==1.9.0+cu111 14 | torchvision==0.10.0+cu111 15 | tqdm==4.59.0 16 | yacs==0.1.8 -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import shutil 14 | from pathlib import Path 15 | 16 | import torch 17 | import torch.utils.cpp_extension 18 | from torch.utils.file_baton import FileBaton 19 | 20 | #---------------------------------------------------------------------------- 21 | # Global options. 22 | 23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 24 | 25 | #---------------------------------------------------------------------------- 26 | # Internal helper funcs. 27 | 28 | def _find_compiler_bindir(): 29 | patterns = [ 30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 34 | ] 35 | for pattern in patterns: 36 | matches = sorted(glob.glob(pattern)) 37 | if len(matches): 38 | return matches[-1] 39 | return None 40 | 41 | #---------------------------------------------------------------------------- 42 | # Main entry point for compiling and loading C++/CUDA plugins. 43 | 44 | _cached_plugins = dict() 45 | 46 | def get_plugin(module_name, sources, **build_kwargs): 47 | assert verbosity in ['none', 'brief', 'full'] 48 | 49 | # Already cached? 50 | if module_name in _cached_plugins: 51 | return _cached_plugins[module_name] 52 | 53 | # Print status. 54 | if verbosity == 'full': 55 | print(f'Setting up PyTorch plugin "{module_name}"...') 56 | elif verbosity == 'brief': 57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 58 | 59 | try: # pylint: disable=too-many-nested-blocks 60 | # Make sure we can find the necessary compiler binaries. 61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 62 | compiler_bindir = _find_compiler_bindir() 63 | if compiler_bindir is None: 64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 65 | os.environ['PATH'] += ';' + compiler_bindir 66 | 67 | # Compile and load. 68 | verbose_build = (verbosity == 'full') 69 | 70 | # Incremental build md5sum trickery. Copies all the input source files 71 | # into a cached build directory under a combined md5 digest of the input 72 | # source files. Copying is done only if the combined digest has changed. 73 | # This keeps input file timestamps and filenames the same as in previous 74 | # extension builds, allowing for fast incremental rebuilds. 
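        # (The digest below is an md5 over the raw bytes of every file in the
        # source directory, so editing any source file changes the digest and
        # triggers a fresh copy-and-rebuild.)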
75 | # 76 | # This optimization is done only in case all the source files reside in 77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 78 | # environment variable is set (we take this as a signal that the user 79 | # actually cares about this.) 80 | source_dirs_set = set(os.path.dirname(source) for source in sources) 81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): 82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) 83 | 84 | # Compute a combined hash digest for all source files in the same 85 | # custom op directory (usually .cu, .cpp, .py and .h files). 86 | hash_md5 = hashlib.md5() 87 | for src in all_source_files: 88 | with open(src, 'rb') as f: 89 | hash_md5.update(f.read()) 90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) 92 | 93 | if not os.path.isdir(digest_build_dir): 94 | os.makedirs(digest_build_dir, exist_ok=True) 95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock')) 96 | if baton.try_acquire(): 97 | try: 98 | for src in all_source_files: 99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) 100 | finally: 101 | baton.release() 102 | else: 103 | # Someone else is copying source files under the digest dir, 104 | # wait until done and continue. 105 | baton.wait() 106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] 107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, 108 | verbose=verbose_build, sources=digest_sources, **build_kwargs) 109 | else: 110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 111 | module = importlib.import_module(module_name) 112 | 113 | except: 114 | if verbosity == 'brief': 115 | print('Failed!') 116 | raise 117 | 118 | # Print status and add to cache. 119 | if verbosity == 'full': 120 | print(f'Done setting up PyTorch plugin "{module_name}".') 121 | elif verbosity == 'brief': 122 | print('Done.') 123 | _cached_plugins[module_name] = module 124 | return module 125 | 126 | #---------------------------------------------------------------------------- 127 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. 
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel.
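    // AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the lambda once per
    // supported dtype (double/float/half), and choose_bias_act_kernel<scalar_t>
    // (defined in bias_act.cu) returns the kernel instance matching p.act.
    // The "upfirdn2d_cuda" string is only a dispatch label used in error
    // messages; it does not affect which kernel runs.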
76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel<scalar_t>(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <c10/util/Half.h> 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template <class T> struct InternalType; 16 | template <> struct InternalType<double> { typedef double scalar_t; }; 17 | template <> struct InternalType<float> { typedef float scalar_t; }; 18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template <class T, int A> 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType<T>::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ?
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
            if (G == 1) y = x * (one - exp(-yy));
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain.
        y *= gain * dy;

        // Clamp.
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
    return NULL;
}

//------------------------------------------------------------------------
// Template specializations.

template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);

//------------------------------------------------------------------------
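For readers skimming the kernel, here is a minimal PyTorch sketch of the forward (`grad == 0`) semantics for the `lrelu` case; the name `bias_act_ref` and its defaults are illustrative, not part of the repo:

```
import torch

def bias_act_ref(x, b=None, dim=1, alpha=0.2, gain=1.0, clamp=None):
    # Bias along `dim`, leaky ReLU (act code 3), gain, then optional clamp --
    # mirroring the A == 3, G == 0 branch of bias_act_kernel above.
    if b is not None:
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
    x = torch.nn.functional.leaky_relu(x, alpha)
    x = x * gain
    if clamp is not None and clamp >= 0:
        x = x.clamp(-clamp, clamp)
    return x
```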
-------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: --------------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct bias_act_kernel_params
{
    const void* x;    // [sizeX]
    const void* b;    // [sizeB] or NULL
    const void* xref; // [sizeX] or NULL
    const void* yref; // [sizeX] or NULL
    const void* dy;   // [sizeX] or NULL
    void*       y;    // [sizeX]

    int         grad;
    int         act;
    float       alpha;
    float       gain;
    float       clamp;

    int         sizeX;
    int         sizeB;
    int         stepB;
    int         loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);

//------------------------------------------------------------------------
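These three sources are JIT-compiled into a single extension at runtime (the repo's `torch_utils/custom_ops.py` handles this). A hand-rolled equivalent with `torch.utils.cpp_extension` would look roughly like the sketch below; the plugin name and flags are illustrative:

```
from torch.utils.cpp_extension import load

# Illustrative only: compile bias_act.cpp/.cu into a loadable module.
_plugin = load(
    name='bias_act_plugin',
    sources=['torch_utils/ops/bias_act.cpp', 'torch_utils/ops/bias_act.cu'],
    extra_cuda_cflags=['--use_fast_math'],
)
# _plugin.bias_act(...) then matches the signature bound via PYBIND11_MODULE.
```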
-------------------------------------------------------------------------------- /torch_utils/ops/fma.py: --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""

import torch

#----------------------------------------------------------------------------

def fma(a, b, c): # => a * b + c
    return _FusedMultiplyAdd.apply(a, b, c)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        x = x.reshape(-1, *x.shape[extra_dims+1:])
    assert x.shape == shape
    return x

#----------------------------------------------------------------------------
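A quick, illustrative check of the contract (shapes chosen arbitrarily): `fma(a, b, c)` equals `a * b + c`, with `c` broadcast exactly as in `torch.addcmul`:

```
import torch
from torch_utils.ops.fma import fma

a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(4, 8)
c = torch.randn(8)                      # broadcast against a * b
out = fma(a, b, c)
assert torch.allclose(out, a * b + c)
out.sum().backward()                    # gradients flow through _FusedMultiplyAdd
```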
-------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.grid_sample` that
supports arbitrarily high order gradients between the input and output.
Only works on 2D images and assumes
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""

import warnings

import torch

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False # Enable the custom op by setting this to true.

#----------------------------------------------------------------------------

def grid_sample(input, grid):
    if _should_use_custom_op():
        return _GridSample2dForward.apply(input, grid)
    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)

#----------------------------------------------------------------------------

def _should_use_custom_op():
    if not enabled:
        return False
    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
        return True
    warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
    return False

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
        return grad_input, grad_grid

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)

        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid

#----------------------------------------------------------------------------
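The op is opt-in. A minimal usage sketch (tensor shapes are arbitrary); note that `enabled = True` only takes effect on the PyTorch versions whitelisted in `_should_use_custom_op()`:

```
import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True                 # opt in to the custom op
x = torch.randn(1, 3, 64, 64, requires_grad=True)  # [N, C, H, W]
grid = torch.rand(1, 32, 32, 2) * 2 - 1            # [N, H_out, W_out, 2] in [-1, 1]
y = grid_sample_gradfix.grid_sample(x, grid)
```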
-------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: --------------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------

static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
    TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(f.dim() == 2, "f must be rank 2");
    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
    TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
    TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
    TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
    TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");

    // Initialize CUDA kernel parameters.
    upfirdn2d_kernel_params p;
    p.x            = x.data_ptr();
    p.f            = f.data_ptr<float>();
    p.y            = y.data_ptr();
    p.up           = make_int2(upx, upy);
    p.down         = make_int2(downx, downy);
    p.pad0         = make_int2(padx0, pady0);
    p.flip         = (flip) ? 1 : 0;
    p.gain         = gain;
    p.inSize       = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.inStride     = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
    p.filterSize   = make_int2((int)f.size(1), (int)f.size(0));
    p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
    p.outSize      = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.outStride    = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
    p.sizeMajor    = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
    p.sizeMinor    = (p.inStride.z == 1) ? p.inSize.z : 1;

    // Choose CUDA kernel.
    upfirdn2d_kernel_spec spec;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        spec = choose_upfirdn2d_kernel<scalar_t>(p);
    });

    // Set looping options.
    p.loopMajor   = (p.sizeMajor - 1) / 16384 + 1;
    p.loopMinor   = spec.loopMinor;
    p.loopX       = spec.loopX;
    p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
    p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;

    // Compute grid size.
    dim3 blockSize, gridSize;
    if (spec.tileOutW < 0) // large
    {
        blockSize = dim3(4, 32, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
            p.launchMajor);
    }
    else // small
    {
        blockSize = dim3(256, 1, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
            p.launchMajor);
    }

    // Launch CUDA kernel.
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("upfirdn2d", &upfirdn2d);
}

//------------------------------------------------------------------------
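As a readability aid, here is a rough pure-PyTorch sketch of what the op computes: upsample by zero-insertion, pad, convolve with the FIR filter `f`, then downsample. This simplified reference assumes a single up/down factor, symmetric padding, and `flip=False`; the real kernel also handles per-axis factors, asymmetric padding, and gain:

```
import torch
import torch.nn.functional as F

def upfirdn2d_ref(x, f, up=1, down=1, pad=0):
    # x: [N, C, H, W] float32, f: [kh, kw] float32 FIR filter.
    n, c, h, w = x.shape
    y = x.reshape(n, c, h, 1, w, 1)
    y = F.pad(y, [0, up - 1, 0, 0, 0, up - 1])       # upsample by zero insertion
    y = y.reshape(n, c, h * up, w * up)
    y = F.pad(y, [pad, pad, pad, pad])               # pad both spatial axes
    w_f = f.flip([0, 1])[None, None].repeat(c, 1, 1, 1)
    y = F.conv2d(y, w_f, groups=c)                   # depthwise true convolution
    return y[:, :, ::down, ::down]                   # downsample by striding
```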
-------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: --------------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <cuda_runtime.h>

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct upfirdn2d_kernel_params
{
    const void* x;
    const float* f;
    void* y;

    int2 up;
    int2 down;
    int2 pad0;
    int flip;
    float gain;

    int4 inSize; // [width, height, channel, batch]
    int4 inStride;
    int2 filterSize; // [width, height]
    int2 filterStride;
    int4 outSize; // [width, height, channel, batch]
    int4 outStride;
    int sizeMinor;
    int sizeMajor;

    int loopMinor;
    int loopMajor;
    int loopX;
    int launchMinor;
    int launchMajor;
};

//------------------------------------------------------------------------
// CUDA kernel specialization.

struct upfirdn2d_kernel_spec
{
    void* kernel;
    int tileOutW;
    int tileOutH;
    int loopMinor;
    int loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------
-------------------------------------------------------------------------------- /trainers/impl/edges_eval_plot.py: --------------------------------------------------------------------------------
import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import interp1d


def edges_eval_plot(algs, nms=None, cols=None):
    """
    See https://github.com/pdollar/edges/blob/master/edgesEvalPlot.m
    """

    # parse inputs
    nms = nms or []
    cols = cols or list("rgbkmr" * 100)
    cols = np.array(cols)
    if not isinstance(algs, list):
        algs = [algs]
    if not isinstance(nms, list):
        nms = [nms]
    nms = np.array(nms)

    # setup basic plot (isometric contour lines and human performance)
    plt.figure()
    ax = plt.gca()
    plt.box(True)
    plt.grid(True)
    plt.axhline(0.5, 0, 1, linewidth=2, color=[0.7, 0.7, 0.7])
    for f in np.arange(0.1, 1, 0.1):
        r = np.arange(f, 1.01, 0.01)
        p = f * r / (2 * r - f)
        plt.plot(r, p, color=[0, 1, 0])
        plt.plot(p, r, color=[0, 1, 0])
    h = plt.plot(0.7235, 0.9014, marker="o", markersize=8, color=[0, 0.5, 0],
                 markerfacecolor=[0, 0.5, 0], markeredgecolor=[0, 0.5, 0])
    plt.xticks(np.linspace(0, 1, 11))
    plt.yticks(np.linspace(0, 1, 11))
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    ax.set_aspect('equal', adjustable='box')
    plt.axis([0, 1, 0, 1])

    # load results for every algorithm (pr=[T, R, P, F])
    n = len(algs)
    hs, res, prs = [None] * n, np.zeros((n, 9), dtype=np.float32), []
    for i, alg in enumerate(algs):
        a = "{}-eval".format(alg)
        pr = np.loadtxt(os.path.join(a, "eval_bdry_thr.txt"))
        pr = pr[pr[:, 1] >= 1e-3]
        _, o = np.unique(pr[:, 2], return_index=True)
        r50 = interp1d(pr[o, 2], pr[o, 1], bounds_error=False, fill_value=np.nan)(np.maximum(pr[o[0], 2], 0.5))
        res[i, :8] = np.loadtxt(os.path.join(a, "eval_bdry.txt"))
        res[i, 8] = r50
        prs.append(pr)
    prs = np.stack(prs, axis=0)

    # sort algorithms by ODS score
    o = np.argsort(res[:, 3])[::-1]
    res, prs, cols = res[o, :], prs[o], cols[o]
    if nms.size:  # `nms` is a numpy array here, so test its size explicitly
        nms = nms[o]

    # plot results for every algorithm (plot best last)
    for i in range(n - 1, -1, -1):
        hs[i] = plt.plot(prs[i, :, 1], prs[i, :, 2], linestyle="-", linewidth=3, color=cols[i])[0]
        prefix = "ODS={:.3f}, OIS={:.3f}, AP={:.3f}, R50={:.3f}".format(*res[i, [3, 6, 7, 8]])
        if nms.size:
            prefix += " - {}".format(nms[i])
        print(prefix)

    # show legend if nms provided (report best first)
    if not nms.size:
        plt.show()
        return

    nms = ["[F=.80] Human"] + ["[F={:.2f}] {}".format(res[i, 3], nms[i]) for i in range(n)]
    hs = h + hs
    plt.legend(hs, nms, loc="lower left")
    plt.show()

-------------------------------------------------------------------------------- /trainers/impl/toolbox.py: --------------------------------------------------------------------------------
import time

import numpy as np
from scipy import signal


def conv_tri(image, r, s=1):
    """ 2D image convolution with a triangle filter (not fast)
    See https://github.com/pdollar/toolbox/blob/master/channels/convTri.m
    Note: signal.convolve2d does not support float16 ('single' in MATLAB)
    """
    if image.size == 0 or (r == 0 and s == 1):
        return image
    if r <= 1:
        p = 12 / r / (r + 2) - 2
        f = np.array([[1, p, 1]]) / (2 + p)
        r = 1
    else:
        f = np.array([list(range(1, r + 1)) + [r + 1] + list(range(r, 0, -1))]) / (r + 1) ** 2
    f = f.astype(image.dtype)
    image = np.pad(image, ((r, r), (r, r)), mode="symmetric")
    image = signal.convolve2d(signal.convolve2d(image, f, "valid"), f.T, "valid")
    if s > 1:
        t = int(np.floor(s / 2) + 1)
        image = image[t-1:image.shape[0]-(s-t)+1:s, t-1:image.shape[1]-(s-t)+1:s]
    return image


def grad2(image):
    """ numerical gradients along x and y directions (not fast)
    See https://github.com/pdollar/toolbox/blob/master/channels/gradient2.m
    Note: np.gradient returns [oy, ox]; the MATLAB version returns [ox, oy]
    """
    assert image.ndim == 2
    oy, ox = np.gradient(image)
    return ox, oy


class Time:
    def __init__(self):
        self.time = None

    def set(self):
        self.time = time.time()

    def get(self):
        return time.time() - self.time
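An illustrative call (array values are arbitrary): smoothing an edge map with a radius-4 triangle filter and then taking numerical gradients, as done before orientation estimation in the NMS code further below:

```
import numpy as np
from trainers.impl.toolbox import conv_tri, grad2

edge = np.random.rand(64, 64).astype(np.float32)  # dummy edge map
smoothed = conv_tri(edge, 4)                      # triangle-filter smoothing
ox, oy = grad2(smoothed)                          # gradients along x and y
```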
-------------------------------------------------------------------------------- /trainers/lsm_hawp/lsm_hawp_model.py: --------------------------------------------------------------------------------
import os
import pickle

import numpy as np
import torch
import torchvision.transforms as transforms
from skimage import io
from skimage.transform import resize
from torchvision.transforms import functional as F
from tqdm import tqdm

from .detector import WireframeDetector


class ResizeImage(object):
    def __init__(self, image_height, image_width):
        self.image_height = image_height
        self.image_width = image_width

    def __call__(self, image):
        image = resize(image, (self.image_height, self.image_width))
        image = np.array(image, dtype=np.float32) / 255.0
        return image


class ToTensor(object):
    def __call__(self, image):
        return F.to_tensor(image)


class Normalize(object):
    def __init__(self, mean, std, to_255=True):
        self.mean = mean
        self.std = std
        self.to_255 = to_255

    def __call__(self, image):
        if self.to_255:
            image *= 255.0
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image


def to_device(data, device):
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        for key in data:
            if isinstance(data[key], torch.Tensor):
                data[key] = data[key].to(device)
        return data
    if isinstance(data, list):
        return [to_device(d, device) for d in data]
    return data  # leave other types unchanged


class LSM_HAWP:
    def __init__(self, threshold=0.6, size=512):
        self.lsm_hawp = WireframeDetector(is_cuda=True).cuda()
        self.transform = transforms.Compose([ResizeImage(size, size), ToTensor(),
                                             Normalize(mean=[109.730, 103.832, 98.681],
                                                       std=[22.275, 22.124, 23.229],
                                                       to_255=True)])
        self.threshold = threshold

    def wireframe_detect(self, img_paths, output_path):
        os.makedirs(output_path, exist_ok=True)
        self.lsm_hawp.eval()
        with torch.no_grad():
            for img_path in tqdm(img_paths):
                image = io.imread(img_path).astype(float)
                if len(image.shape) == 3:
                    image = image[:, :, :3]
                else:
                    image = image[:, :, None]
                    image = np.tile(image, [1, 1, 3])
                image = self.transform(image).unsqueeze(0).cuda()
                output = self.lsm_hawp(image)
                output = to_device(output, 'cpu')
                lines = []
                scores = []
                if output['num_proposals'] > 0:
                    lines_tmp = output['lines_pred'].numpy()
                    scores_tmp = output['lines_score'].tolist()
                    for line, score in zip(lines_tmp, scores_tmp):
                        if score > self.threshold:
                            # y1, x1, y2, x2
                            lines.append([line[1], line[0], line[3], line[2]])
                            scores.append(score)
                wireframe_info = {'lines': lines, 'scores': scores}
                with open(os.path.join(output_path, img_path.split('/')[-1].split('.')[0] + '.pkl'), 'wb') as w:
                    pickle.dump(wireframe_info, w)

    def wireframe_places2_detect(self, img_paths, output_path):
        os.makedirs(output_path, exist_ok=True)
        self.lsm_hawp.eval()
        with torch.no_grad():
            for img_path in tqdm(img_paths):
                sub_paths = img_path.split('/')
                idx = sub_paths.index('data_large')
                new_output = output_path + '/'.join(sub_paths[idx + 1:-1])
                os.makedirs(new_output, exist_ok=True)
                new_output = os.path.join(new_output, img_path.split('/')[-1].split('.')[0] + '.pkl')
                if os.path.exists(new_output):
                    continue
                try:
                    image = io.imread(img_path).astype(float)
                except Exception:
                    print('failed to load', img_path)
                    continue
                if len(image.shape) == 3:
                    image = image[:, :, :3]
                else:
                    image = image[:, :, None]
                    image = np.tile(image, [1, 1, 3])
                image = self.transform(image).unsqueeze(0).cuda()
                output = self.lsm_hawp(image)
                output = to_device(output, 'cpu')
                lines = []
                scores = []
                if output['num_proposals'] > 0:
                    lines_tmp = output['lines_pred'].numpy()
                    scores_tmp = output['lines_score'].tolist()
                    for line, score in zip(lines_tmp, scores_tmp):
                        if score > self.threshold:
                            # y1, x1, y2, x2
                            lines.append([line[1], line[0], line[3], line[2]])
                            scores.append(score)
                wireframe_info = {'lines': lines, 'scores': scores}
                with open(new_output, 'wb') as w:
                    pickle.dump(wireframe_info, w)
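A sketch of how the detector is typically driven (requires a CUDA device; paths are placeholders, and the checkpoint key name is assumed -- the weight file is the LSM-HAWP download linked in the README):

```
import glob
import torch
from trainers.lsm_hawp.lsm_hawp_model import LSM_HAWP

model = LSM_HAWP(threshold=0.8, size=512)
ckpt = torch.load('ckpts/best_lsm_hawp.pth')
model.lsm_hawp.load_state_dict(ckpt['model'])  # checkpoint key assumed
img_paths = glob.glob('examples/*.png')        # placeholder input images
model.wireframe_detect(img_paths, output_path='wireframes/')
```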
-------------------------------------------------------------------------------- /trainers/lsm_hawp/model_config.py: --------------------------------------------------------------------------------
from yacs.config import CfgNode as CN

def get_config():
    cfg = CN()
    MODELS = CN()

    MODELS.NAME = "Hourglass"
    HGNETS = CN()

    HGNETS.DEPTH = 4
    HGNETS.NUM_STACKS = 2
    HGNETS.NUM_BLOCKS = 1

    HGNETS.INPLANES = 64
    HGNETS.NUM_FEATS = 128

    MODELS.HGNETS = HGNETS
    MODELS.DEVICE = "cuda"
    MODELS.WEIGHTS = ""
    MODELS.HEAD_SIZE = [[3], [1], [1], [2], [2]]
    MODELS.OUT_FEATURE_CHANNELS = 256

    MODELS.LOSS_WEIGHTS = CN(new_allowed=True)

    PARSING_HEAD = CN()

    PARSING_HEAD.MAX_DISTANCE = 5.0

    PARSING_HEAD.N_STC_POSL = 300
    PARSING_HEAD.N_STC_NEGL = 40

    PARSING_HEAD.MATCHING_STRATEGY = 'junction'
    PARSING_HEAD.N_DYN_JUNC = 300
    PARSING_HEAD.N_DYN_POSL = 300
    PARSING_HEAD.N_DYN_NEGL = 300
    PARSING_HEAD.N_DYN_OTHR = 0
    PARSING_HEAD.N_DYN_OTHR2 = 300

    PARSING_HEAD.N_PTS0 = 32
    PARSING_HEAD.N_PTS1 = 8

    PARSING_HEAD.DIM_LOI = 128
    PARSING_HEAD.DIM_FC = 1024
    PARSING_HEAD.USE_RESIDUAL = True
    PARSING_HEAD.N_OUT_JUNC = 250
    PARSING_HEAD.N_OUT_LINE = 2500

    MODELS.PARSING_HEAD = PARSING_HEAD
    MODELS.SCALE = 1.0

    cfg.MODEL = MODELS

    return cfg
-------------------------------------------------------------------------------- /trainers/lsm_hawp/multi_task_head.py: --------------------------------------------------------------------------------
import torch
import torch.nn as nn


class MultitaskHead(nn.Module):
    def __init__(self, input_channels, num_class, head_size):
        super(MultitaskHead, self).__init__()

        m = int(input_channels / 4)
        heads = []
        for output_channels in sum(head_size, []):
            heads.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(m, output_channels, kernel_size=1),
                )
            )
        self.heads = nn.ModuleList(heads)
        assert num_class == sum(sum(head_size, []))

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=1)
-------------------------------------------------------------------------------- /trainers/lsm_hawp/stacked_hg.py: --------------------------------------------------------------------------------
"""
Hourglass network inserted in the pre-activated Resnet
Use lr=0.01 for current version
(c) Nan Xue (HAWP)
(c) Yichao Zhou (LCNN)
(c) YANG, Wei
"""
import torch.nn as nn
import torch.nn.functional as F


class Bottleneck2D(nn.Module):
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck2D, self).__init__()

        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class Hourglass(nn.Module):
    def __init__(self, block, num_blocks, planes, depth):
        super(Hourglass, self).__init__()
        self.depth = depth
        self.block = block
        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)

    def _make_residual(self, block, num_blocks, planes):
        layers = []
        for i in range(0, num_blocks):
            layers.append(block(planes * block.expansion, planes))
        return nn.Sequential(*layers)

    def _make_hour_glass(self, block, num_blocks, planes, depth):
        hg = []
        for i in range(depth):
            res = []
            for j in range(3):
                res.append(self._make_residual(block, num_blocks, planes))
            if i == 0:
                res.append(self._make_residual(block, num_blocks, planes))
            hg.append(nn.ModuleList(res))
        return nn.ModuleList(hg)

    def _hour_glass_forward(self, n, x):
        up1 = self.hg[n - 1][0](x)
        low1 = F.max_pool2d(x, 2, stride=2)
        low1 = self.hg[n - 1][1](low1)

        if n > 1:
            low2 = self._hour_glass_forward(n - 1, low1)
        else:
            low2 = self.hg[n - 1][3](low1)
        low3 = self.hg[n - 1][2](low2)
        up2 = F.interpolate(low3, scale_factor=2)
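        # `up1` is the skip branch kept at the current resolution; `up2` is the
        # recursively processed lower-resolution branch upsampled back to the
        # same size, so the sum below merges the two hourglass branches.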
        out = up1 + up2
        return out

    def forward(self, x):
        return self._hour_glass_forward(self.depth, x)


class HourglassNet(nn.Module):
    """Hourglass model from Newell et al ECCV 2016"""

    def __init__(self, input_channel, inplanes, num_feats, block, head, depth, num_stacks, num_blocks, num_classes):
        super(HourglassNet, self).__init__()

        self.inplanes = inplanes
        self.num_feats = num_feats
        self.num_stacks = num_stacks
        self.conv1 = nn.Conv2d(input_channel, self.inplanes, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_residual(block, self.inplanes, 1)
        self.layer2 = self._make_residual(block, self.inplanes, 1)
        self.layer3 = self._make_residual(block, self.num_feats, 1)
        self.maxpool = nn.MaxPool2d(2, stride=2)

        # build hourglass modules
        ch = self.num_feats * block.expansion
        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
        for i in range(num_stacks):
            hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
            res.append(self._make_residual(block, self.num_feats, num_blocks))
            fc.append(self._make_fc(ch, ch))
            score.append(head(ch, num_classes))
            if i < num_stacks - 1:
                fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
        self.hg = nn.ModuleList(hg)
        self.res = nn.ModuleList(res)
        self.fc = nn.ModuleList(fc)
        self.score = nn.ModuleList(score)
        self.fc_ = nn.ModuleList(fc_)
        self.score_ = nn.ModuleList(score_)

    def _make_residual(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                )
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_fc(self, inplanes, outplanes):
        bn = nn.BatchNorm2d(inplanes)
        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
        return nn.Sequential(conv, bn, self.relu)

    def forward(self, x):
        out = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        # if the input size is 256, keep the hourglass input/output at 128
        if x.shape[2] >= 256:
            x = self.maxpool(x)
        x = self.layer2(x)
        x = self.layer3(x)

        for i in range(self.num_stacks):
            y = self.hg[i](x)
            y = self.res[i](y)
            y = self.fc[i](y)
            score = self.score[i](y)
            out.append(score)

            if i < self.num_stacks - 1:
                fc_ = self.fc_[i](y)
                score_ = self.score_[i](score)
                x = x + fc_ + score_

        return out[::-1], y
-------------------------------------------------------------------------------- /trainers/nms_temp.py: --------------------------------------------------------------------------------
import os
from ctypes import *

import cv2
import numpy as np

from .impl.toolbox import conv_tri, grad2

# NOTE:
# In NMS, `if edge < interp: out = 0`, and I found that edge is sometimes very close to interp.
# `edge = 10e-8` and `interp = 11e-8` in C, while `edge = 10e-8` and `interp = 9e-8` in python.
# ** Such slight differences (11e-8 - 9e-8 = 2e-8) in precision **
# ** would lead to very different results (`out = 0` in C and `out = edge` in python). **
# Sadly, the C implementation is unwanted but necessary :(
solver = cdll.LoadLibrary("nms/cxx/lib/solve_csa.so")  # path assumed relative to the repo root
c_float_pointer = POINTER(c_float)
solver.nms.argtypes = [c_float_pointer, c_float_pointer, c_float_pointer, c_int, c_int, c_float, c_int, c_int]


def nms_process_one_image(image, save_path=None, save=True):
    """
    :param image: numpy array, edge, model output
    :param save_path: str, save path
    :param save: bool, if True, save .png
    :return: edge
    NOTE: in MATLAB, uint8(x) means round(x).astype(uint8) in numpy
    """

    if save and save_path is not None:
        assert os.path.splitext(save_path)[-1] == ".png"
    edge = conv_tri(image, 1)
    ox, oy = grad2(conv_tri(edge, 4))
    oxx, _ = grad2(ox)
    oxy, oyy = grad2(oy)
    ori = np.mod(np.arctan(oyy * np.sign(-oxy) / (oxx + 1e-5)), np.pi)
    out = np.zeros_like(edge)
    r, s, m, w, h = 1, 5, float(1.01), int(out.shape[1]), int(out.shape[0])
    solver.nms(out.ctypes.data_as(c_float_pointer),
               edge.ctypes.data_as(c_float_pointer),
               ori.ctypes.data_as(c_float_pointer),
               r, s, m, w, h)
    edge = np.round(out * 255).astype(np.uint8)
    if save:
        cv2.imwrite(save_path, edge)
    return edge


import torch


def get_nms(edge_pred, binary_threshold=55):
    # edge_pred: [B,1,H,W] detached
    device = edge_pred.device
    edge_np = edge_pred.cpu().numpy()

    edges_nms = []
    for i in range(edge_np.shape[0]):
        try:
            edge_nms = nms_process_one_image(edge_np[i, 0], save_path=None, save=False)
            edge_nms[edge_nms > binary_threshold] = 255
            edge_nms[edge_nms <= binary_threshold] = 0
            edge_nms = edge_nms / 255.
        except Exception:
            edge_nms = edge_np[i, 0]
        edge_nms = torch.tensor(edge_nms, device=device, dtype=torch.float32)[None, ...]
        edges_nms.append(edge_nms)

    edges_nms = torch.stack(edges_nms, dim=0)
    return edges_nms

-------------------------------------------------------------------------------- /trainers/nms_torch.py: --------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F

def conv_tri_torch(image, r):
    if r <= 1:
        p = 12 / r / (r + 2) - 2
        f = np.array([[1, p, 1]]) / (2 + p)
        r = 1
    else:
        f = np.array([list(range(1, r + 1)) + [r + 1] + list(range(r, 0, -1))]) / (r + 1) ** 2
    f = torch.tensor(f, dtype=image.dtype, device=image.device)[None, None, ...]
    image = F.pad(image, (r, r, r, r), mode='constant')
    image = F.conv2d(F.conv2d(image, f, stride=1), f.permute(0, 1, 3, 2), stride=1)
    return image


def grad_torch(image):
    image_np = image.cpu().numpy()
    oy, ox = [], []
    for bi in range(image.shape[0]):
        oy_, ox_ = np.gradient(image_np[bi, 0])
        oy.append(torch.from_numpy(oy_))
        ox.append(torch.from_numpy(ox_))
    oy = torch.stack(oy, dim=0).unsqueeze(1).to(device=image.device, dtype=image.dtype)
    ox = torch.stack(ox, dim=0).unsqueeze(1).to(device=image.device, dtype=image.dtype)
    return ox, oy


def interp_torch(edge, h, w, cos_o, sin_o):
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_y, grid_x = torch.meshgrid(grid_h, grid_w)
    grid_y = grid_y.to(device=edge.device)
    grid_x = grid_x.to(device=edge.device)
    interp_res = []
    for d in [-1, 1]:
        grid_y_ = torch.clamp(grid_y + d * sin_o, 0, h)
        grid_y_ = (grid_y_ / h - 0.5) * 2
        grid_x_ = torch.clamp(grid_x + d * cos_o, 0, w)
        grid_x_ = (grid_x_ / w - 0.5) * 2
        grid = torch.stack([grid_x_, grid_y_], dim=-1)
        interp_res_ = F.grid_sample(edge, grid, mode='bilinear', align_corners=False)
        interp_res.append(interp_res_)

    return interp_res


def nms_torch(edge, ori, m, h, w):
    nms = edge.clone()
    mask = (edge != 0)
    mask = mask.to(torch.float32)
    cos_o = torch.cos(ori)
    sin_o = torch.sin(ori)
    interp_maps = interp_torch(edge, h, w, cos_o, sin_o)
    edgem = edge * m
    for interp_map in interp_maps:
        nms[edgem < interp_map] = 0
    nms = edge * (1 - mask) + nms * mask
    return nms


def get_nms(edge_pred, binary_threshold=55):
    edge_pred = conv_tri_torch(edge_pred, r=1)
    edge_pred2 = conv_tri_torch(edge_pred, r=5)
    oxt, oyt = grad_torch(edge_pred2)
    oxxt, _ = grad_torch(oxt)
    oxyt, oyyt = grad_torch(oyt)
    orit = torch.arctan(oyyt * torch.sign(-oxyt) / (oxxt + 1e-5)) % np.pi
    orit = orit.squeeze(1)
    m, h, w = 1.001, int(edge_pred.shape[2]), int(edge_pred.shape[3])
    edges_nms = nms_torch(edge_pred, orit, m, h, w)
    edges_nms = torch.round(edges_nms * 255)
    edges_nms[edges_nms > binary_threshold] = 255
    edges_nms[edges_nms <= binary_threshold] = 0
    edges_nms = edges_nms / 255.

    return edges_nms
--------------------------------------------------------------------------------
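A minimal usage sketch for the torch-based NMS above (shapes assumed; `edge_pred` is a `[B, 1, H, W]` edge map in `[0, 1]`, e.g. a detached sigmoid output of the edge head):

```
import torch
from trainers.nms_torch import get_nms

edge_pred = torch.rand(2, 1, 256, 256)               # dummy edge probabilities
edges_nms = get_nms(edge_pred, binary_threshold=55)  # binarized {0, 1}, same shape
```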