├── README.md
├── configs
│   ├── apollosim
│   │   ├── anchor3dlane.py
│   │   └── anchor3dlane_iter.py
│   ├── once
│   │   ├── anchor3dlane.py
│   │   ├── anchor3dlane_effb3.py
│   │   └── anchor3dlane_iter.py
│   └── openlane
│       ├── anchor3dlane.py
│       ├── anchor3dlane_effb3.py
│       ├── anchor3dlane_iter.py
│       ├── anchor3dlane_iter_r50.py
│       ├── anchor3dlane_mf.py
│       └── anchor3dlane_mf_iter.py
├── data
│   ├── Apollosim
│   │   └── data_lists
│   │       ├── illus_chg
│   │       │   ├── test.txt
│   │       │   └── train.txt
│   │       ├── rare_subset
│   │       │   ├── test.txt
│   │       │   └── train.txt
│   │       └── standard
│   │           ├── test.txt
│   │           └── train.txt
│   ├── ONCE
│   │   └── data_lists
│   │       ├── train.txt
│   │       └── val.txt
│   └── OpenLane
│       └── data_lists
│           ├── training.txt
│           └── validation.txt
├── gen-efficientnet-pytorch
│   ├── BENCHMARK.md
│   ├── LICENSE
│   ├── README.md
│   ├── caffe2_benchmark.py
│   ├── caffe2_validate.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── loader.py
│   │   ├── tf_preprocessing.py
│   │   └── transforms.py
│   ├── geffnet
│   │   ├── __init__.py
│   │   ├── activations
│   │   │   ├── __init__.py
│   │   │   ├── activations.py
│   │   │   ├── activations_jit.py
│   │   │   └── activations_me.py
│   │   ├── config.py
│   │   ├── conv2d_layers.py
│   │   ├── efficientnet_builder.py
│   │   ├── gen_efficientnet.py
│   │   ├── helpers.py
│   │   ├── mobilenetv3.py
│   │   ├── model_factory.py
│   │   └── version.py
│   ├── hubconf.py
│   ├── onnx_export.py
│   ├── onnx_optimize.py
│   ├── onnx_to_caffe.py
│   ├── onnx_validate.py
│   ├── requirements.txt
│   ├── setup.py
│   ├── utils.py
│   └── validate.py
├── images
│   ├── pipeline.png
│   ├── vis_apollo.png
│   ├── vis_once.png
│   └── vis_openlane.png
├── mmseg
│   ├── __init__.py
│   ├── apis
│   │   ├── __init__.py
│   │   ├── inference.py
│   │   ├── test.py
│   │   ├── test_apollosim.py
│   │   ├── test_once.py
│   │   ├── test_openlane.py
│   │   └── train.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── evaluation
│   │   │   ├── __init__.py
│   │   │   ├── class_names.py
│   │   │   ├── eval_hooks.py
│   │   │   └── metrics.py
│   │   ├── hook
│   │   │   ├── __init__.py
│   │   │   └── wandblogger_hook.py
│   │   ├── optimizers
│   │   │   ├── __init__.py
│   │   │   └── layer_decay_optimizer_constructor.py
│   │   ├── seg
│   │   │   ├── __init__.py
│   │   │   ├── builder.py
│   │   │   └── sampler
│   │   │       ├── __init__.py
│   │   │       ├── base_pixel_sampler.py
│   │   │       └── ohem_pixel_sampler.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── dist_util.py
│   │       ├── misc.py
│   │       └── scatter.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── ade.py
│   │   ├── builder.py
│   │   ├── chase_db1.py
│   │   ├── cityscapes.py
│   │   ├── coco_stuff.py
│   │   ├── custom.py
│   │   ├── dark_zurich.py
│   │   ├── dataset_wrappers.py
│   │   ├── drive.py
│   │   ├── hrf.py
│   │   ├── isaid.py
│   │   ├── isprs.py
│   │   ├── lane_datasets
│   │   │   ├── __init__.py
│   │   │   ├── apollosim.py
│   │   │   ├── once.py
│   │   │   ├── openlane.py
│   │   │   └── openlane_temporal.py
│   │   ├── loveda.py
│   │   ├── night_driving.py
│   │   ├── pascal_context.py
│   │   ├── pipelines
│   │   │   ├── __init__.py
│   │   │   ├── compose.py
│   │   │   ├── formating.py
│   │   │   ├── formatting.py
│   │   │   ├── lane_format.py
│   │   │   ├── loading.py
│   │   │   ├── test_time_aug.py
│   │   │   └── transforms.py
│   │   ├── potsdam.py
│   │   ├── samplers
│   │   │   ├── __init__.py
│   │   │   └── distributed_sampler.py
│   │   ├── stare.py
│   │   ├── tools
│   │   │   ├── MinCostFlow.py
│   │   │   ├── __init__.py
│   │   │   ├── eval_apollosim.py
│   │   │   ├── eval_once.py
│   │   │   ├── eval_openlane.py
│   │   │   ├── utils.py
│   │   │   ├── vis_apollosim.py
│   │   │   ├── vis_once.py
│   │   │   └── vis_openlane.py
│   │   └── voc.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── beit.py
│   │   │   ├── bisenetv1.py
│   │   │   ├── bisenetv2.py
│   │   │   ├── cgnet.py
│   │   │   ├── efficientnet.py
│   │   │   ├── erfnet.py
│   │   │   ├── fast_scnn.py
│   │   │   ├── hrnet.py
│   │   │   ├── icnet.py
│   │   │   ├── mae.py
│   │   │   ├── mit.py
│   │   │   ├── mobilenet_v2.py
│   │   │   ├── mobilenet_v3.py
│   │   │   ├── resnest.py
│   │   │   ├── resnet.py
│   │   │   ├── resnext.py
│   │   │   ├── stdc.py
│   │   │   ├── swin.py
│   │   │   ├── twins.py
│   │   │   ├── unet.py
│   │   │   └── vit.py
│   │   ├── builder.py
│   │   ├── decode_heads
│   │   │   ├── __init__.py
│   │   │   ├── ann_head.py
│   │   │   ├── apc_head.py
│   │   │   ├── aspp_head.py
│   │   │   ├── cascade_decode_head.py
│   │   │   ├── cc_head.py
│   │   │   ├── da_head.py
│   │   │   ├── decode_head.py
│   │   │   ├── dm_head.py
│   │   │   ├── dnl_head.py
│   │   │   ├── dpt_head.py
│   │   │   ├── ema_head.py
│   │   │   ├── enc_head.py
│   │   │   ├── fcn_head.py
│   │   │   ├── fpn_head.py
│   │   │   ├── gc_head.py
│   │   │   ├── isa_head.py
│   │   │   ├── knet_head.py
│   │   │   ├── lraspp_head.py
│   │   │   ├── nl_head.py
│   │   │   ├── ocr_head.py
│   │   │   ├── point_head.py
│   │   │   ├── psa_head.py
│   │   │   ├── psp_head.py
│   │   │   ├── segformer_head.py
│   │   │   ├── segmenter_mask_head.py
│   │   │   ├── sep_aspp_head.py
│   │   │   ├── sep_fcn_head.py
│   │   │   ├── setr_mla_head.py
│   │   │   ├── setr_up_head.py
│   │   │   ├── stdc_head.py
│   │   │   └── uper_head.py
│   │   ├── lane_detector
│   │   │   ├── __init__.py
│   │   │   ├── anchor_3dlane.py
│   │   │   ├── anchor_3dlane_deform.py
│   │   │   ├── anchor_3dlane_multiframe.py
│   │   │   ├── assigner
│   │   │   │   ├── __init__.py
│   │   │   │   ├── distance_metric.py
│   │   │   │   ├── thresh_assigner.py
│   │   │   │   ├── topk_assigner.py
│   │   │   │   └── topk_fv_assigner.py
│   │   │   ├── msda.py
│   │   │   ├── position_encoding.py
│   │   │   ├── tools.py
│   │   │   ├── transformer.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── anchor.py
│   │   │       └── nms.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── accuracy.py
│   │   │   ├── cross_entropy_loss.py
│   │   │   ├── dice_loss.py
│   │   │   ├── focal_loss.py
│   │   │   ├── kornia_focal.py
│   │   │   ├── lane_loss.py
│   │   │   ├── lovasz_loss.py
│   │   │   └── utils.py
│   │   ├── necks
│   │   │   ├── __init__.py
│   │   │   ├── fpn.py
│   │   │   ├── ic_neck.py
│   │   │   ├── jpu.py
│   │   │   ├── mla_neck.py
│   │   │   └── multilevel_neck.py
│   │   ├── segmentors
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── cascade_encoder_decoder.py
│   │   │   └── encoder_decoder.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── embed.py
│   │       ├── inverted_residual.py
│   │       ├── make_divisible.py
│   │       ├── misc.py
│   │       ├── ops
│   │       │   ├── functions
│   │       │   │   ├── __init__.py
│   │       │   │   └── ms_deform_attn_func.py
│   │       │   ├── make.sh
│   │       │   ├── modules
│   │       │   │   ├── __init__.py
│   │       │   │   └── ms_deform_attn.py
│   │       │   ├── setup.py
│   │       │   ├── src
│   │       │   │   ├── cpu
│   │       │   │   │   ├── ms_deform_attn_cpu.cpp
│   │       │   │   │   └── ms_deform_attn_cpu.h
│   │       │   │   ├── cuda
│   │       │   │   │   ├── ms_deform_attn_cuda.cu
│   │       │   │   │   ├── ms_deform_attn_cuda.h
│   │       │   │   │   └── ms_deform_im2col_cuda.cuh
│   │       │   │   ├── ms_deform_attn.h
│   │       │   │   └── vision.cpp
│   │       │   └── test.py
│   │       ├── res_layer.py
│   │       ├── se_layer.py
│   │       ├── self_attention_block.py
│   │       ├── shape_convert.py
│   │       └── up_conv_block.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── encoding.py
│   │   └── wrappers.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── avs_metric.py
│   │   ├── collect_env.py
│   │   ├── logger.py
│   │   ├── misc.py
│   │   ├── set_env.py
│   │   └── util_distribution.py
│   └── version.py
├── requirements.txt
├── setup.cfg
├── setup.py
└── tools
    ├── analyze_logs.py
    ├── benchmark.py
    ├── browse_dataset.py
    ├── confusion_matrix.py
    ├── convert_datasets
    │   ├── apollosim.py
    │   ├── chase_db1.py
    │   ├── cityscapes.py
    │   ├── coco_stuff10k.py
    │   ├── coco_stuff164k.py
    │   ├── drive.py
    │   ├── hrf.py
    │   ├── isaid.py
    │   ├── loveda.py
    │   ├── once.py
    │   ├── openlane.py
    │   ├── openlane_pose.py
    │   ├── pascal_context.py
    │   ├── potsdam.py
    │   ├── stare.py
    │   ├── vaihingen.py
    │   └── voc_aug.py
    ├── deploy_test.py
    ├── dist_test.sh
    ├── dist_train.sh
    ├── dist_train_multinode.sh
    ├── get_flops.py
    ├── model_converters
    │   ├── beit2mmseg.py
    │   ├── mit2mmseg.py
    │   ├── stdc2mmseg.py
    │   ├── swin2mmseg.py
    │   ├── twins2mmseg.py
    │   ├── vit2mmseg.py
    │   └── vitjax2mmseg.py
    ├── onnx2tensorrt.py
    ├── print_config.py
    ├── publish_model.py
    ├── pytorch2onnx.py
    ├── pytorch2torchscript.py
    ├── script.sh
    ├── slurm_test.sh
    ├── slurm_train.sh
    ├── test.py
    ├── torchserve
    │   ├── mmseg2torchserve.py
    │   ├── mmseg_handler.py
    │   └── test_torchserve.py
    ├── train.py
    └── train_dist.py
/gen-efficientnet-pytorch/caffe2_benchmark.py:
--------------------------------------------------------------------------------
1 | """ Caffe2 benchmark script
2 | 
3 | This script runs a Caffe2 benchmark on an exported ONNX model.
4 | It is a useful tool for reporting model FLOPS.
5 | 
6 | Copyright 2020 Ross Wightman
7 | """
8 | import argparse
9 | from caffe2.python import core, workspace, model_helper
10 | from caffe2.proto import caffe2_pb2
11 | 
12 | 
13 | parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark')
14 | parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
15 |                     help='caffe2 model pb name prefix')
16 | parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
17 |                     help='caffe2 model init .pb')
18 | parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
19 |                     help='caffe2 model predict .pb')
20 | parser.add_argument('-b', '--batch-size', default=1, type=int,
21 |                     metavar='N', help='mini-batch size (default: 1)')
22 | parser.add_argument('--img-size', default=224, type=int,
23 |                     metavar='N', help='Input image dimension, uses model default if empty')
24 | 
25 | 
26 | def main():
27 |     args = parser.parse_args()
28 |     args.gpu_id = 0
29 |     if args.c2_prefix:
30 |         args.c2_init = args.c2_prefix + '.init.pb'
31 |         args.c2_predict = args.c2_prefix + '.predict.pb'
32 | 
33 |     model = model_helper.ModelHelper(name="le_net", init_params=False)
34 | 
35 |     # Bring in the init net from init_net.pb
36 |     init_net_proto = caffe2_pb2.NetDef()
37 |     with open(args.c2_init, "rb") as f:
38 |         init_net_proto.ParseFromString(f.read())
39 |     model.param_init_net = core.Net(init_net_proto)
40 | 
41 |     # Bring in the predict net from predict_net.pb
42 |     predict_net_proto = caffe2_pb2.NetDef()
43 |     with open(args.c2_predict, "rb") as f:
44 |         predict_net_proto.ParseFromString(f.read())
45 |     model.net = core.Net(predict_net_proto)
46 | 
47 |     # CUDA performance not impressive
48 |     #device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
49 |     #model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
50 |     #model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
51 | 
52 |     input_blob = model.net.external_inputs[0]
53 |     model.param_init_net.GaussianFill(
54 |         [],
55 |         input_blob.GetUnscopedName(),
56 |         shape=(args.batch_size, 3, args.img_size, args.img_size),
57 |         mean=0.0,
58 |         std=1.0)
59 |     workspace.RunNetOnce(model.param_init_net)
60 |     workspace.CreateNet(model.net, overwrite=True)
61 |     workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True)
62 | 
63 | 
64 | if __name__ == '__main__':
65 |     main()
66 | 
--------------------------------------------------------------------------------
/gen-efficientnet-pytorch/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import Dataset
2 | from .transforms import *
3 | from .loader import create_loader
4 | 
--------------------------------------------------------------------------------
/gen-efficientnet-pytorch/data/dataset.py:
--------------------------------------------------------------------------------
1 | """ Quick n simple image folder dataset
2 | 
3 | Copyright 2020 Ross Wightman
4 | """
5 | import torch.utils.data as data
6 | 
7 | import os
8 | import re
9 | import torch
10 | from PIL import Image
11 | 
12 | 
13 | IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']
14 | 
15 | 
16 | def natural_key(string_):
17 |     """See http://www.codinghorror.com/blog/archives/001018.html"""
18 |     return [int(s) if s.isdigit() else s for s in
re.split(r'(\d+)', string_.lower())] 19 | 20 | 21 | def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): 22 | if class_to_idx is None: 23 | class_to_idx = dict() 24 | build_class_idx = True 25 | else: 26 | build_class_idx = False 27 | labels = [] 28 | filenames = [] 29 | for root, subdirs, files in os.walk(folder, topdown=False): 30 | rel_path = os.path.relpath(root, folder) if (root != folder) else '' 31 | label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') 32 | if build_class_idx and not subdirs: 33 | class_to_idx[label] = None 34 | for f in files: 35 | base, ext = os.path.splitext(f) 36 | if ext.lower() in types: 37 | filenames.append(os.path.join(root, f)) 38 | labels.append(label) 39 | if build_class_idx: 40 | classes = sorted(class_to_idx.keys(), key=natural_key) 41 | for idx, c in enumerate(classes): 42 | class_to_idx[c] = idx 43 | images_and_targets = zip(filenames, [class_to_idx[l] for l in labels]) 44 | if sort: 45 | images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) 46 | if build_class_idx: 47 | return images_and_targets, classes, class_to_idx 48 | else: 49 | return images_and_targets 50 | 51 | 52 | class Dataset(data.Dataset): 53 | 54 | def __init__( 55 | self, 56 | root, 57 | transform=None, 58 | load_bytes=False): 59 | 60 | imgs, _, _ = find_images_and_targets(root) 61 | if len(imgs) == 0: 62 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 63 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 64 | self.root = root 65 | self.imgs = imgs 66 | self.transform = transform 67 | self.load_bytes = load_bytes 68 | 69 | def __getitem__(self, index): 70 | path, target = self.imgs[index] 71 | img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB') 72 | if self.transform is not None: 73 | img = self.transform(img) 74 | if target is None: 75 | target = torch.zeros(1).long() 76 | return img, target 77 | 78 | def __len__(self): 79 | return len(self.imgs) 80 | 81 | def filenames(self, indices=[], basename=False): 82 | if indices: 83 | if basename: 84 | return [os.path.basename(self.imgs[i][0]) for i in indices] 85 | else: 86 | return [self.imgs[i][0] for i in indices] 87 | else: 88 | if basename: 89 | return [os.path.basename(x[0]) for x in self.imgs] 90 | else: 91 | return [x[0] for x in self.imgs] 92 | -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/data/loader.py: -------------------------------------------------------------------------------- 1 | """ Fast Collate, CUDA Prefetcher 2 | 3 | Prefetcher and Fast Collate inspired by NVIDIA APEX example at 4 | https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | import torch 9 | import torch.utils.data 10 | from .transforms import * 11 | 12 | 13 | def fast_collate(batch): 14 | targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) 15 | batch_size = len(targets) 16 | tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 17 | for i in range(batch_size): 18 | tensor[i] += torch.from_numpy(batch[i][0]) 19 | 20 | return tensor, targets 21 | 22 | 23 | class PrefetchLoader: 24 | 25 | def __init__(self, 26 | loader, 27 | mean=IMAGENET_DEFAULT_MEAN, 28 | std=IMAGENET_DEFAULT_STD): 29 | self.loader = loader 30 | self.mean = torch.tensor([x * 
255 for x in mean]).cuda().view(1, 3, 1, 1) 31 | self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) 32 | 33 | def __iter__(self): 34 | stream = torch.cuda.Stream() 35 | first = True 36 | 37 | for next_input, next_target in self.loader: 38 | with torch.cuda.stream(stream): 39 | next_input = next_input.cuda(non_blocking=True) 40 | next_target = next_target.cuda(non_blocking=True) 41 | next_input = next_input.float().sub_(self.mean).div_(self.std) 42 | 43 | if not first: 44 | yield input, target 45 | else: 46 | first = False 47 | 48 | torch.cuda.current_stream().wait_stream(stream) 49 | input = next_input 50 | target = next_target 51 | 52 | yield input, target 53 | 54 | def __len__(self): 55 | return len(self.loader) 56 | 57 | @property 58 | def sampler(self): 59 | return self.loader.sampler 60 | 61 | 62 | def create_loader( 63 | dataset, 64 | input_size, 65 | batch_size, 66 | is_training=False, 67 | use_prefetcher=True, 68 | interpolation='bilinear', 69 | mean=IMAGENET_DEFAULT_MEAN, 70 | std=IMAGENET_DEFAULT_STD, 71 | num_workers=1, 72 | crop_pct=None, 73 | tensorflow_preprocessing=False 74 | ): 75 | if isinstance(input_size, tuple): 76 | img_size = input_size[-2:] 77 | else: 78 | img_size = input_size 79 | 80 | if tensorflow_preprocessing and use_prefetcher: 81 | from data.tf_preprocessing import TfPreprocessTransform 82 | transform = TfPreprocessTransform( 83 | is_training=is_training, size=img_size, interpolation=interpolation) 84 | else: 85 | transform = transforms_imagenet_eval( 86 | img_size, 87 | interpolation=interpolation, 88 | use_prefetcher=use_prefetcher, 89 | mean=mean, 90 | std=std, 91 | crop_pct=crop_pct) 92 | 93 | dataset.transform = transform 94 | 95 | loader = torch.utils.data.DataLoader( 96 | dataset, 97 | batch_size=batch_size, 98 | shuffle=False, 99 | num_workers=num_workers, 100 | collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate, 101 | ) 102 | if use_prefetcher: 103 | loader = PrefetchLoader( 104 | loader, 105 | mean=mean, 106 | std=std) 107 | 108 | return loader 109 | -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/geffnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .gen_efficientnet import * 2 | from .mobilenetv3 import * 3 | from .model_factory import create_model 4 | from .config import is_exportable, is_scriptable, set_exportable, set_scriptable 5 | from .activations import * -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/geffnet/activations/activations.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | from torch.nn import functional as F 10 | 11 | 12 | def swish(x, inplace: bool = False): 13 | """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) 14 | and also as Swish (https://arxiv.org/abs/1710.05941). 
15 | 
16 |     TODO Rename to SiLU with addition to PyTorch
17 |     """
18 |     return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
19 | 
20 | 
21 | class Swish(nn.Module):
22 |     def __init__(self, inplace: bool = False):
23 |         super(Swish, self).__init__()
24 |         self.inplace = inplace
25 | 
26 |     def forward(self, x):
27 |         return swish(x, self.inplace)
28 | 
29 | 
30 | def mish(x, inplace: bool = False):
31 |     """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
32 |     """
33 |     return x.mul(F.softplus(x).tanh())
34 | 
35 | 
36 | class Mish(nn.Module):
37 |     def __init__(self, inplace: bool = False):
38 |         super(Mish, self).__init__()
39 |         self.inplace = inplace
40 | 
41 |     def forward(self, x):
42 |         return mish(x, self.inplace)
43 | 
44 | 
45 | def sigmoid(x, inplace: bool = False):
46 |     return x.sigmoid_() if inplace else x.sigmoid()
47 | 
48 | 
49 | # PyTorch has this, but not with a consistent inplace argument interface
50 | class Sigmoid(nn.Module):
51 |     def __init__(self, inplace: bool = False):
52 |         super(Sigmoid, self).__init__()
53 |         self.inplace = inplace
54 | 
55 |     def forward(self, x):
56 |         return x.sigmoid_() if self.inplace else x.sigmoid()
57 | 
58 | 
59 | def tanh(x, inplace: bool = False):
60 |     return x.tanh_() if inplace else x.tanh()
61 | 
62 | 
63 | # PyTorch has this, but not with a consistent inplace argument interface
64 | class Tanh(nn.Module):
65 |     def __init__(self, inplace: bool = False):
66 |         super(Tanh, self).__init__()
67 |         self.inplace = inplace
68 | 
69 |     def forward(self, x):
70 |         return x.tanh_() if self.inplace else x.tanh()
71 | 
72 | 
73 | def hard_swish(x, inplace: bool = False):
74 |     inner = F.relu6(x + 3.).div_(6.)
75 |     return x.mul_(inner) if inplace else x.mul(inner)
76 | 
77 | 
78 | class HardSwish(nn.Module):
79 |     def __init__(self, inplace: bool = False):
80 |         super(HardSwish, self).__init__()
81 |         self.inplace = inplace
82 | 
83 |     def forward(self, x):
84 |         return hard_swish(x, self.inplace)
85 | 
86 | 
87 | def hard_sigmoid(x, inplace: bool = False):
88 |     if inplace:
89 |         return x.add_(3.).clamp_(0., 6.).div_(6.)
90 |     else:
91 |         return F.relu6(x + 3.) / 6.
92 | 
93 | 
94 | class HardSigmoid(nn.Module):
95 |     def __init__(self, inplace: bool = False):
96 |         super(HardSigmoid, self).__init__()
97 |         self.inplace = inplace
98 | 
99 |     def forward(self, x):
100 |         return hard_sigmoid(x, self.inplace)
101 | 
102 | 
103 | 
--------------------------------------------------------------------------------
/gen-efficientnet-pytorch/geffnet/activations/activations_jit.py:
--------------------------------------------------------------------------------
1 | """ Activations (jit)
2 | 
3 | A collection of jit-scripted activations fn and modules with a common interface so that they can
4 | easily be swapped. All have an `inplace` arg even if not used.
5 | 
6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
8 | versions if they contain in-place ops.
9 | 10 | Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | __all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit', 18 | 'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit'] 19 | 20 | 21 | @torch.jit.script 22 | def swish_jit(x, inplace: bool = False): 23 | """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) 24 | and also as Swish (https://arxiv.org/abs/1710.05941). 25 | 26 | TODO Rename to SiLU with addition to PyTorch 27 | """ 28 | return x.mul(x.sigmoid()) 29 | 30 | 31 | @torch.jit.script 32 | def mish_jit(x, _inplace: bool = False): 33 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 34 | """ 35 | return x.mul(F.softplus(x).tanh()) 36 | 37 | 38 | class SwishJit(nn.Module): 39 | def __init__(self, inplace: bool = False): 40 | super(SwishJit, self).__init__() 41 | 42 | def forward(self, x): 43 | return swish_jit(x) 44 | 45 | 46 | class MishJit(nn.Module): 47 | def __init__(self, inplace: bool = False): 48 | super(MishJit, self).__init__() 49 | 50 | def forward(self, x): 51 | return mish_jit(x) 52 | 53 | 54 | @torch.jit.script 55 | def hard_sigmoid_jit(x, inplace: bool = False): 56 | # return F.relu6(x + 3.) / 6. 57 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 58 | 59 | 60 | class HardSigmoidJit(nn.Module): 61 | def __init__(self, inplace: bool = False): 62 | super(HardSigmoidJit, self).__init__() 63 | 64 | def forward(self, x): 65 | return hard_sigmoid_jit(x) 66 | 67 | 68 | @torch.jit.script 69 | def hard_swish_jit(x, inplace: bool = False): 70 | # return x * (F.relu6(x + 3.) / 6) 71 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 
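# ---------------------------------------------------------------------------
# Illustrative sketch (not part of activations_jit.py): the hand-rolled
# activations defined above match the operators PyTorch later added natively,
# which is what the "TODO Rename to SiLU" note anticipates. Assumes
# torch >= 1.7 for F.silu / F.hardswish / F.hardsigmoid; the package itself
# only requires torch >= 1.2, so treat this purely as an optional sanity check.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F

x = torch.randn(64)

# swish(x) = x * sigmoid(x), i.e. the built-in SiLU
assert torch.allclose(x * torch.sigmoid(x), F.silu(x))

# hard_swish(x) = x * relu6(x + 3) / 6, i.e. the built-in hardswish
assert torch.allclose(x * F.relu6(x + 3.) / 6., F.hardswish(x))

# hard_sigmoid(x) = relu6(x + 3) / 6, i.e. the built-in hardsigmoid
assert torch.allclose(F.relu6(x + 3.) / 6., F.hardsigmoid(x))

# The jit variants compute the same values; their `inplace` flag is accepted
# only for interface compatibility and is deliberately ignored (see the module
# docstring above on why scripted kernels avoid in-place ops).
assert torch.allclose(swish_jit(x, inplace=True), F.silu(x))
# ---------------------------------------------------------------------------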
72 | 73 | 74 | class HardSwishJit(nn.Module): 75 | def __init__(self, inplace: bool = False): 76 | super(HardSwishJit, self).__init__() 77 | 78 | def forward(self, x): 79 | return hard_swish_jit(x) 80 | -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/geffnet/helpers.py: -------------------------------------------------------------------------------- 1 | """ Checkpoint loading / state_dict helpers 2 | Copyright 2020 Ross Wightman 3 | """ 4 | import torch 5 | import os 6 | from collections import OrderedDict 7 | try: 8 | from torch.hub import load_state_dict_from_url 9 | except ImportError: 10 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 11 | 12 | 13 | def load_checkpoint(model, checkpoint_path): 14 | if checkpoint_path and os.path.isfile(checkpoint_path): 15 | print("=> Loading checkpoint '{}'".format(checkpoint_path)) 16 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 17 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 18 | new_state_dict = OrderedDict() 19 | for k, v in checkpoint['state_dict'].items(): 20 | if k.startswith('module'): 21 | name = k[7:] # remove `module.` 22 | else: 23 | name = k 24 | new_state_dict[name] = v 25 | model.load_state_dict(new_state_dict) 26 | else: 27 | model.load_state_dict(checkpoint) 28 | print("=> Loaded checkpoint '{}'".format(checkpoint_path)) 29 | else: 30 | print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) 31 | raise FileNotFoundError() 32 | 33 | 34 | def load_pretrained(model, url, filter_fn=None, strict=True): 35 | if not url: 36 | print("=> Warning: Pretrained model URL is empty, using random initialization.") 37 | return 38 | 39 | state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu') 40 | 41 | input_conv = 'conv_stem' 42 | classifier = 'classifier' 43 | in_chans = getattr(model, input_conv).weight.shape[1] 44 | num_classes = getattr(model, classifier).weight.shape[0] 45 | 46 | input_conv_weight = input_conv + '.weight' 47 | pretrained_in_chans = state_dict[input_conv_weight].shape[1] 48 | if in_chans != pretrained_in_chans: 49 | if in_chans == 1: 50 | print('=> Converting pretrained input conv {} from {} to 1 channel'.format( 51 | input_conv_weight, pretrained_in_chans)) 52 | conv1_weight = state_dict[input_conv_weight] 53 | state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) 54 | else: 55 | print('=> Discarding pretrained input conv {} since input channel count != {}'.format( 56 | input_conv_weight, pretrained_in_chans)) 57 | del state_dict[input_conv_weight] 58 | strict = False 59 | 60 | classifier_weight = classifier + '.weight' 61 | pretrained_num_classes = state_dict[classifier_weight].shape[0] 62 | if num_classes != pretrained_num_classes: 63 | print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes)) 64 | del state_dict[classifier_weight] 65 | del state_dict[classifier + '.bias'] 66 | strict = False 67 | 68 | if filter_fn is not None: 69 | state_dict = filter_fn(state_dict) 70 | 71 | model.load_state_dict(state_dict, strict=strict) 72 | -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/geffnet/model_factory.py: -------------------------------------------------------------------------------- 1 | from .config import set_layer_config 2 | from .helpers import load_checkpoint 3 | 4 | from .gen_efficientnet import * 5 | from .mobilenetv3 import * 6 | 7 | 8 | def 
create_model( 9 | model_name='mnasnet_100', 10 | pretrained=None, 11 | num_classes=1000, 12 | in_chans=3, 13 | checkpoint_path='', 14 | **kwargs): 15 | 16 | model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs) 17 | 18 | if model_name in globals(): 19 | create_fn = globals()[model_name] 20 | model = create_fn(**model_kwargs) 21 | else: 22 | raise RuntimeError('Unknown model (%s)' % model_name) 23 | 24 | if checkpoint_path and not pretrained: 25 | load_checkpoint(model, checkpoint_path) 26 | 27 | return model 28 | -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/geffnet/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.0.2' 2 | -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch', 'math'] 2 | 3 | from geffnet import efficientnet_b0 4 | from geffnet import efficientnet_b1 5 | from geffnet import efficientnet_b2 6 | from geffnet import efficientnet_b3 7 | 8 | from geffnet import efficientnet_es 9 | 10 | from geffnet import efficientnet_lite0 11 | 12 | from geffnet import mixnet_s 13 | from geffnet import mixnet_m 14 | from geffnet import mixnet_l 15 | from geffnet import mixnet_xl 16 | 17 | from geffnet import mobilenetv2_100 18 | from geffnet import mobilenetv2_110d 19 | from geffnet import mobilenetv2_120d 20 | from geffnet import mobilenetv2_140 21 | 22 | from geffnet import mobilenetv3_large_100 23 | from geffnet import mobilenetv3_rw 24 | from geffnet import mnasnet_a1 25 | from geffnet import mnasnet_b1 26 | from geffnet import fbnetc_100 27 | from geffnet import spnasnet_100 28 | 29 | from geffnet import tf_efficientnet_b0 30 | from geffnet import tf_efficientnet_b1 31 | from geffnet import tf_efficientnet_b2 32 | from geffnet import tf_efficientnet_b3 33 | from geffnet import tf_efficientnet_b4 34 | from geffnet import tf_efficientnet_b5 35 | from geffnet import tf_efficientnet_b6 36 | from geffnet import tf_efficientnet_b7 37 | from geffnet import tf_efficientnet_b8 38 | 39 | from geffnet import tf_efficientnet_b0_ap 40 | from geffnet import tf_efficientnet_b1_ap 41 | from geffnet import tf_efficientnet_b2_ap 42 | from geffnet import tf_efficientnet_b3_ap 43 | from geffnet import tf_efficientnet_b4_ap 44 | from geffnet import tf_efficientnet_b5_ap 45 | from geffnet import tf_efficientnet_b6_ap 46 | from geffnet import tf_efficientnet_b7_ap 47 | from geffnet import tf_efficientnet_b8_ap 48 | 49 | from geffnet import tf_efficientnet_b0_ns 50 | from geffnet import tf_efficientnet_b1_ns 51 | from geffnet import tf_efficientnet_b2_ns 52 | from geffnet import tf_efficientnet_b3_ns 53 | from geffnet import tf_efficientnet_b4_ns 54 | from geffnet import tf_efficientnet_b5_ns 55 | from geffnet import tf_efficientnet_b6_ns 56 | from geffnet import tf_efficientnet_b7_ns 57 | from geffnet import tf_efficientnet_l2_ns_475 58 | from geffnet import tf_efficientnet_l2_ns 59 | 60 | from geffnet import tf_efficientnet_es 61 | from geffnet import tf_efficientnet_em 62 | from geffnet import tf_efficientnet_el 63 | 64 | from geffnet import tf_efficientnet_cc_b0_4e 65 | from geffnet import tf_efficientnet_cc_b0_8e 66 | from geffnet import tf_efficientnet_cc_b1_8e 67 | 68 | from geffnet import tf_efficientnet_lite0 69 | from geffnet import tf_efficientnet_lite1 70 | from 
geffnet import tf_efficientnet_lite2 71 | from geffnet import tf_efficientnet_lite3 72 | from geffnet import tf_efficientnet_lite4 73 | 74 | from geffnet import tf_mixnet_s 75 | from geffnet import tf_mixnet_m 76 | from geffnet import tf_mixnet_l 77 | 78 | from geffnet import tf_mobilenetv3_large_075 79 | from geffnet import tf_mobilenetv3_large_100 80 | from geffnet import tf_mobilenetv3_large_minimal_100 81 | from geffnet import tf_mobilenetv3_small_075 82 | from geffnet import tf_mobilenetv3_small_100 83 | from geffnet import tf_mobilenetv3_small_minimal_100 84 | 85 | -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/onnx_optimize.py: -------------------------------------------------------------------------------- 1 | """ ONNX optimization script 2 | 3 | Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc. 4 | 5 | NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 1.6 and ONNX 1.7), 6 | it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline). 7 | 8 | Copyright 2020 Ross Wightman 9 | """ 10 | import argparse 11 | import warnings 12 | 13 | import onnx 14 | from onnx import optimizer 15 | 16 | 17 | parser = argparse.ArgumentParser(description="Optimize ONNX model") 18 | 19 | parser.add_argument("model", help="The ONNX model") 20 | parser.add_argument("--output", required=True, help="The optimized model output filename") 21 | 22 | 23 | def traverse_graph(graph, prefix=''): 24 | content = [] 25 | indent = prefix + ' ' 26 | graphs = [] 27 | num_nodes = 0 28 | for node in graph.node: 29 | pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True) 30 | assert isinstance(gs, list) 31 | content.append(pn) 32 | graphs.extend(gs) 33 | num_nodes += 1 34 | for g in graphs: 35 | g_count, g_str = traverse_graph(g) 36 | content.append('\n' + g_str) 37 | num_nodes += g_count 38 | return num_nodes, '\n'.join(content) 39 | 40 | 41 | def main(): 42 | args = parser.parse_args() 43 | onnx_model = onnx.load(args.model) 44 | num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph) 45 | 46 | # Optimizer passes to perform 47 | passes = [ 48 | #'eliminate_deadend', 49 | 'eliminate_identity', 50 | 'eliminate_nop_dropout', 51 | 'eliminate_nop_pad', 52 | 'eliminate_nop_transpose', 53 | 'eliminate_unused_initializer', 54 | 'extract_constant_to_initializer', 55 | 'fuse_add_bias_into_conv', 56 | 'fuse_bn_into_conv', 57 | 'fuse_consecutive_concats', 58 | 'fuse_consecutive_reduce_unsqueeze', 59 | 'fuse_consecutive_squeezes', 60 | 'fuse_consecutive_transposes', 61 | #'fuse_matmul_add_bias_into_gemm', 62 | 'fuse_pad_into_conv', 63 | #'fuse_transpose_into_gemm', 64 | #'lift_lexical_references', 65 | ] 66 | 67 | # Apply the optimization on the original serialized model 68 | # WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing 69 | # 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401 70 | # It may be better to rely on onnxruntime optimizations, see onnx_validate.py script. 71 | warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX." 
72 | "Try onnxruntime optimization if this doesn't work.") 73 | optimized_model = optimizer.optimize(onnx_model, passes) 74 | 75 | num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph) 76 | print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str)) 77 | print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes)) 78 | 79 | # Save the ONNX model 80 | onnx.save(optimized_model, args.output) 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/onnx_to_caffe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import onnx 4 | from caffe2.python.onnx.backend import Caffe2Backend 5 | 6 | 7 | parser = argparse.ArgumentParser(description="Convert ONNX to Caffe2") 8 | 9 | parser.add_argument("model", help="The ONNX model") 10 | parser.add_argument("--c2-prefix", required=True, 11 | help="The output file prefix for the caffe2 model init and predict file. ") 12 | 13 | 14 | def main(): 15 | args = parser.parse_args() 16 | onnx_model = onnx.load(args.model) 17 | caffe2_init, caffe2_predict = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model) 18 | caffe2_init_str = caffe2_init.SerializeToString() 19 | with open(args.c2_prefix + '.init.pb', "wb") as f: 20 | f.write(caffe2_init_str) 21 | caffe2_predict_str = caffe2_predict.SerializeToString() 22 | with open(args.c2_prefix + '.predict.pb', "wb") as f: 23 | f.write(caffe2_predict_str) 24 | 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.2.0 2 | torchvision>=0.4.0 3 | -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | # Get the long description from the README file 10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | exec(open('geffnet/version.py').read()) 14 | setup( 15 | name='geffnet', 16 | version=__version__, 17 | description='(Generic) EfficientNets for PyTorch', 18 | long_description=long_description, 19 | long_description_content_type='text/markdown', 20 | url='https://github.com/rwightman/gen-efficientnet-pytorch', 21 | author='Ross Wightman', 22 | author_email='hello@rwightman.com', 23 | classifiers=[ 24 | # How mature is this project? 
Common values are 25 | # 3 - Alpha 26 | # 4 - Beta 27 | # 5 - Production/Stable 28 | 'Development Status :: 3 - Alpha', 29 | 'Intended Audience :: Education', 30 | 'Intended Audience :: Science/Research', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Programming Language :: Python :: 3.6', 33 | 'Programming Language :: Python :: 3.7', 34 | 'Programming Language :: Python :: 3.8', 35 | 'Topic :: Scientific/Engineering', 36 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 37 | 'Topic :: Software Development', 38 | 'Topic :: Software Development :: Libraries', 39 | 'Topic :: Software Development :: Libraries :: Python Modules', 40 | ], 41 | 42 | # Note that this is a string of words separated by whitespace, not a list. 43 | keywords='pytorch pretrained models efficientnet mixnet mobilenetv3 mnasnet', 44 | packages=find_packages(exclude=['data']), 45 | install_requires=['torch >= 1.4', 'torchvision'], 46 | python_requires='>=3.6', 47 | ) 48 | -------------------------------------------------------------------------------- /gen-efficientnet-pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class AverageMeter: 5 | """Computes and stores the average and current value""" 6 | def __init__(self): 7 | self.reset() 8 | 9 | def reset(self): 10 | self.val = 0 11 | self.avg = 0 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, n=1): 16 | self.val = val 17 | self.sum += val * n 18 | self.count += n 19 | self.avg = self.sum / self.count 20 | 21 | 22 | def accuracy(output, target, topk=(1,)): 23 | """Computes the precision@k for the specified values of k""" 24 | maxk = max(topk) 25 | batch_size = target.size(0) 26 | 27 | _, pred = output.topk(maxk, 1, True, True) 28 | pred = pred.t() 29 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 30 | 31 | res = [] 32 | for k in topk: 33 | correct_k = correct[:k].reshape(-1).float().sum(0) 34 | res.append(correct_k.mul_(100.0 / batch_size)) 35 | return res 36 | 37 | 38 | def get_outdir(path, *paths, inc=False): 39 | outdir = os.path.join(path, *paths) 40 | if not os.path.exists(outdir): 41 | os.makedirs(outdir) 42 | elif inc: 43 | count = 1 44 | outdir_inc = outdir + '-' + str(count) 45 | while os.path.exists(outdir_inc): 46 | count = count + 1 47 | outdir_inc = outdir + '-' + str(count) 48 | assert count < 100 49 | outdir = outdir_inc 50 | os.makedirs(outdir) 51 | return outdir 52 | 53 | -------------------------------------------------------------------------------- /images/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tusen-ai/Anchor3DLane/5dfea5724c4a26cfb01a68e52383bc694e9be201/images/pipeline.png -------------------------------------------------------------------------------- /images/vis_apollo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tusen-ai/Anchor3DLane/5dfea5724c4a26cfb01a68e52383bc694e9be201/images/vis_apollo.png -------------------------------------------------------------------------------- /images/vis_once.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tusen-ai/Anchor3DLane/5dfea5724c4a26cfb01a68e52383bc694e9be201/images/vis_once.png -------------------------------------------------------------------------------- /images/vis_openlane.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tusen-ai/Anchor3DLane/5dfea5724c4a26cfb01a68e52383bc694e9be201/images/vis_openlane.png -------------------------------------------------------------------------------- /mmseg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import mmcv 5 | from packaging.version import parse 6 | 7 | from .version import __version__, version_info 8 | 9 | MMCV_MIN = '1.3.13' 10 | MMCV_MAX = '1.7.0' 11 | 12 | 13 | def digit_version(version_str: str, length: int = 4): 14 | """Convert a version string into a tuple of integers. 15 | 16 | This method is usually used for comparing two versions. For pre-release 17 | versions: alpha < beta < rc. 18 | 19 | Args: 20 | version_str (str): The version string. 21 | length (int): The maximum number of version levels. Default: 4. 22 | 23 | Returns: 24 | tuple[int]: The version info in digits (integers). 25 | """ 26 | version = parse(version_str) 27 | assert version.release, f'failed to parse version {version_str}' 28 | release = list(version.release) 29 | release = release[:length] 30 | if len(release) < length: 31 | release = release + [0] * (length - len(release)) 32 | if version.is_prerelease: 33 | mapping = {'a': -3, 'b': -2, 'rc': -1} 34 | val = -4 35 | # version.pre can be None 36 | if version.pre: 37 | if version.pre[0] not in mapping: 38 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 39 | 'version checking may go wrong') 40 | else: 41 | val = mapping[version.pre[0]] 42 | release.extend([val, version.pre[-1]]) 43 | else: 44 | release.extend([val, 0]) 45 | 46 | elif version.is_postrelease: 47 | release.extend([1, version.post]) 48 | else: 49 | release.extend([0, 0]) 50 | return tuple(release) 51 | 52 | 53 | mmcv_min_version = digit_version(MMCV_MIN) 54 | mmcv_max_version = digit_version(MMCV_MAX) 55 | mmcv_version = digit_version(mmcv.__version__) 56 | 57 | 58 | assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ 59 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 60 | f'Please install mmcv>={mmcv_min_version}, <{mmcv_max_version}.' 61 | 62 | __all__ = ['__version__', 'version_info', 'digit_version'] 63 | -------------------------------------------------------------------------------- /mmseg/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .inference import inference_segmentor, init_segmentor, show_result_pyplot 3 | from .test import multi_gpu_test, single_gpu_test 4 | from .train import (get_root_logger, init_random_seed, set_random_seed, 5 | train_segmentor) 6 | from .test_apollosim import test_apollosim 7 | from .test_openlane import test_openlane 8 | from .test_once import test_once 9 | 10 | __all__ = [ 11 | 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', 12 | 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', 13 | 'show_result_pyplot', 'init_random_seed', 'test_apollosim', 'test_openlane', 'test_once' 14 | ] 15 | -------------------------------------------------------------------------------- /mmseg/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
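# ---------------------------------------------------------------------------
# Illustrative sketch (not part of core/__init__.py): how digit_version() from
# mmseg/__init__.py above orders version strings. Releases become fixed-length
# integer tuples, and pre-releases sort below the final release
# (alpha < beta < rc < release). Runnable wherever mmseg is importable.
# ---------------------------------------------------------------------------
from mmseg import digit_version

assert digit_version('1.7.0') == (1, 7, 0, 0, 0, 0)
assert digit_version('1.7.0rc1') < digit_version('1.7.0')     # rc maps to -1
assert digit_version('1.3.13a1') < digit_version('1.3.13b1')  # alpha < beta
# This is exactly the comparison used for the MMCV_MIN/MMCV_MAX check above.
# ---------------------------------------------------------------------------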
2 | from .builder import (OPTIMIZER_BUILDERS, build_optimizer, 3 | build_optimizer_constructor) 4 | from .evaluation import * # noqa: F401, F403 5 | from .hook import * # noqa: F401, F403 6 | from .optimizers import * # noqa: F401, F403 7 | from .seg import * # noqa: F401, F403 8 | from .utils import * # noqa: F401, F403 9 | 10 | __all__ = [ 11 | 'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor' 12 | ] 13 | -------------------------------------------------------------------------------- /mmseg/core/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | 4 | from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS 5 | from mmcv.utils import Registry, build_from_cfg 6 | 7 | OPTIMIZER_BUILDERS = Registry( 8 | 'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS) 9 | 10 | 11 | def build_optimizer_constructor(cfg): 12 | constructor_type = cfg.get('type') 13 | if constructor_type in OPTIMIZER_BUILDERS: 14 | return build_from_cfg(cfg, OPTIMIZER_BUILDERS) 15 | elif constructor_type in MMCV_OPTIMIZER_BUILDERS: 16 | return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS) 17 | else: 18 | raise KeyError(f'{constructor_type} is not registered ' 19 | 'in the optimizer builder registry.') 20 | 21 | 22 | def build_optimizer(model, cfg): 23 | optimizer_cfg = copy.deepcopy(cfg) 24 | constructor_type = optimizer_cfg.pop('constructor', 25 | 'DefaultOptimizerConstructor') 26 | paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None) 27 | optim_constructor = build_optimizer_constructor( 28 | dict( 29 | type=constructor_type, 30 | optimizer_cfg=optimizer_cfg, 31 | paramwise_cfg=paramwise_cfg)) 32 | optimizer = optim_constructor(model) 33 | return optimizer 34 | -------------------------------------------------------------------------------- /mmseg/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .class_names import get_classes, get_palette 3 | from .eval_hooks import DistEvalHook, EvalHook 4 | from .metrics import (eval_metrics, intersect_and_union, mean_dice, 5 | mean_fscore, mean_iou, pre_eval_to_metrics) 6 | 7 | __all__ = [ 8 | 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore', 9 | 'eval_metrics', 'get_classes', 'get_palette', 'pre_eval_to_metrics', 10 | 'intersect_and_union' 11 | ] 12 | -------------------------------------------------------------------------------- /mmseg/core/hook/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .wandblogger_hook import MMSegWandbHook 3 | 4 | __all__ = ['MMSegWandbHook'] 5 | -------------------------------------------------------------------------------- /mmseg/core/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .layer_decay_optimizer_constructor import ( 3 | LayerDecayOptimizerConstructor, LearningRateDecayOptimizerConstructor) 4 | 5 | __all__ = [ 6 | 'LearningRateDecayOptimizerConstructor', 'LayerDecayOptimizerConstructor' 7 | ] 8 | -------------------------------------------------------------------------------- /mmseg/core/seg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. 
All rights reserved. 2 | from .builder import build_pixel_sampler 3 | from .sampler import BasePixelSampler, OHEMPixelSampler 4 | 5 | __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler'] 6 | -------------------------------------------------------------------------------- /mmseg/core/seg/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import Registry, build_from_cfg 3 | 4 | PIXEL_SAMPLERS = Registry('pixel sampler') 5 | 6 | 7 | def build_pixel_sampler(cfg, **default_args): 8 | """Build pixel sampler for segmentation map.""" 9 | return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) 10 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base_pixel_sampler import BasePixelSampler 3 | from .ohem_pixel_sampler import OHEMPixelSampler 4 | 5 | __all__ = ['BasePixelSampler', 'OHEMPixelSampler'] 6 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/base_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BasePixelSampler(metaclass=ABCMeta): 6 | """Base class of pixel sampler.""" 7 | 8 | def __init__(self, **kwargs): 9 | pass 10 | 11 | @abstractmethod 12 | def sample(self, seg_logit, seg_label): 13 | """Placeholder for sample function.""" 14 | -------------------------------------------------------------------------------- /mmseg/core/seg/sampler/ohem_pixel_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ..builder import PIXEL_SAMPLERS 7 | from .base_pixel_sampler import BasePixelSampler 8 | 9 | 10 | @PIXEL_SAMPLERS.register_module() 11 | class OHEMPixelSampler(BasePixelSampler): 12 | """Online Hard Example Mining Sampler for segmentation. 13 | 14 | Args: 15 | context (nn.Module): The context of sampler, subclass of 16 | :obj:`BaseDecodeHead`. 17 | thresh (float, optional): The threshold for hard example selection. 18 | Below which, are prediction with low confidence. If not 19 | specified, the hard examples will be pixels of top ``min_kept`` 20 | loss. Default: None. 21 | min_kept (int, optional): The minimum number of predictions to keep. 22 | Default: 100000. 23 | """ 24 | 25 | def __init__(self, context, thresh=None, min_kept=100000): 26 | super(OHEMPixelSampler, self).__init__() 27 | self.context = context 28 | assert min_kept > 1 29 | self.thresh = thresh 30 | self.min_kept = min_kept 31 | 32 | def sample(self, seg_logit, seg_label): 33 | """Sample pixels that have high loss or with low prediction confidence. 
34 | 35 | Args: 36 | seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) 37 | seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) 38 | 39 | Returns: 40 | torch.Tensor: segmentation weight, shape (N, H, W) 41 | """ 42 | with torch.no_grad(): 43 | assert seg_logit.shape[2:] == seg_label.shape[2:] 44 | assert seg_label.shape[1] == 1 45 | seg_label = seg_label.squeeze(1).long() 46 | batch_kept = self.min_kept * seg_label.size(0) 47 | valid_mask = seg_label != self.context.ignore_index 48 | seg_weight = seg_logit.new_zeros(size=seg_label.size()) 49 | valid_seg_weight = seg_weight[valid_mask] 50 | if self.thresh is not None: 51 | seg_prob = F.softmax(seg_logit, dim=1) 52 | 53 | tmp_seg_label = seg_label.clone().unsqueeze(1) 54 | tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 55 | seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) 56 | sort_prob, sort_indices = seg_prob[valid_mask].sort() 57 | 58 | if sort_prob.numel() > 0: 59 | min_threshold = sort_prob[min(batch_kept, 60 | sort_prob.numel() - 1)] 61 | else: 62 | min_threshold = 0.0 63 | threshold = max(min_threshold, self.thresh) 64 | valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. 65 | else: 66 | if not isinstance(self.context.loss_decode, nn.ModuleList): 67 | losses_decode = [self.context.loss_decode] 68 | else: 69 | losses_decode = self.context.loss_decode 70 | losses = 0.0 71 | for loss_module in losses_decode: 72 | losses += loss_module( 73 | seg_logit, 74 | seg_label, 75 | weight=None, 76 | ignore_index=self.context.ignore_index, 77 | reduction_override='none') 78 | 79 | # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa 80 | _, sort_indices = losses[valid_mask].sort(descending=True) 81 | valid_seg_weight[sort_indices[:batch_kept]] = 1. 82 | 83 | seg_weight[valid_mask] = valid_seg_weight 84 | 85 | return seg_weight 86 | -------------------------------------------------------------------------------- /mmseg/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dist_util import check_dist_init, sync_random_seed 3 | from .misc import add_prefix 4 | from .scatter import scatter_mean, scatter_sum 5 | 6 | __all__ = ['add_prefix', 'check_dist_init', 'sync_random_seed', 'scatter_mean', 'scatter_sum'] 7 | -------------------------------------------------------------------------------- /mmseg/core/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | import torch.distributed as dist 5 | from mmcv.runner import get_dist_info 6 | 7 | 8 | def check_dist_init(): 9 | return dist.is_available() and dist.is_initialized() 10 | 11 | 12 | def sync_random_seed(seed=None, device='cuda'): 13 | """Make sure different ranks share the same seed. All workers must call 14 | this function, otherwise it will deadlock. This method is generally used in 15 | `DistributedSampler`, because the seed should be identical across all 16 | processes in the distributed group. 17 | 18 | In distributed sampling, different ranks should sample non-overlapped 19 | data in the dataset. Therefore, this function is used to make sure that 20 | each rank shuffles the data indices in the same order based 21 | on the same seed. Then different ranks could use different indices 22 | to select non-overlapped data from the same data list. 
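# ---------------------------------------------------------------------------
# Illustrative sketch (not part of dist_util.py): a smoke test for
# OHEMPixelSampler.sample() defined above. The `context` argument is normally
# the decode head; on the threshold path only `context.ignore_index` is read,
# so a SimpleNamespace stands in here. All shapes and values are made up.
# ---------------------------------------------------------------------------
import torch
from types import SimpleNamespace
from mmseg.core.seg import build_pixel_sampler

head = SimpleNamespace(ignore_index=255)  # stand-in for a decode head
sampler = build_pixel_sampler(
    dict(type='OHEMPixelSampler', thresh=0.7, min_kept=16), context=head)

seg_logit = torch.randn(2, 19, 32, 32)             # (N, C, H, W) logits
seg_label = torch.randint(0, 19, (2, 1, 32, 32))   # (N, 1, H, W) labels
seg_weight = sampler.sample(seg_logit, seg_label)  # (N, H, W); 1. marks kept pixels
assert seg_weight.shape == (2, 32, 32)
# ---------------------------------------------------------------------------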
23 | 24 | Args: 25 | seed (int, Optional): The seed. Default to None. 26 | device (str): The device where the seed will be put on. 27 | Default to 'cuda'. 28 | Returns: 29 | int: Seed to be used. 30 | """ 31 | 32 | if seed is None: 33 | seed = np.random.randint(2**31) 34 | assert isinstance(seed, int) 35 | 36 | rank, world_size = get_dist_info() 37 | 38 | if world_size == 1: 39 | return seed 40 | 41 | if rank == 0: 42 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 43 | else: 44 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 45 | dist.broadcast(random_num, src=0) 46 | return random_num.item() 47 | -------------------------------------------------------------------------------- /mmseg/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def add_prefix(inputs, prefix): 3 | """Add prefix for dict. 4 | 5 | Args: 6 | inputs (dict): The input dict with str keys. 7 | prefix (str): The prefix to add. 8 | 9 | Returns: 10 | 11 | dict: The dict with keys updated with ``prefix``. 12 | """ 13 | 14 | outputs = dict() 15 | for name, value in inputs.items(): 16 | outputs[f'{prefix}.{name}'] = value 17 | 18 | return outputs 19 | -------------------------------------------------------------------------------- /mmseg/core/utils/scatter.py: -------------------------------------------------------------------------------- 1 | # The following code are copied from pytorch_scatter https://github.com/rusty1s/pytorch_scatter 2 | # Copyright (c) 2020 Matthias Fey 3 | # MIT License 4 | from typing import Optional, Tuple 5 | import torch 6 | 7 | @torch.jit.script 8 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): 9 | if dim < 0: 10 | dim = other.dim() + dim 11 | if src.dim() == 1: 12 | for _ in range(dim): 13 | src = src.unsqueeze(0) 14 | for _ in range(other.dim()-src.dim()): 15 | src = src.unsqueeze(-1) 16 | src = src.expand_as(other) 17 | return src 18 | 19 | @torch.jit.script 20 | def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 21 | out: Optional[torch.Tensor] = None, 22 | dim_size: Optional[int] = None) -> torch.Tensor: 23 | index = broadcast(index, src, dim) 24 | if out is None: 25 | size = list(src.size()) 26 | if dim_size is not None: 27 | size[dim] = dim_size 28 | elif index.numel() == 0: 29 | size[dim] = 0 30 | else: 31 | size[dim] = int(index.max()) + 1 32 | out = torch.zeros(size, dtype=src.dtype, device=src.device) 33 | return out.scatter_add_(dim, index, src) 34 | else: 35 | return out.scatter_add_(dim, index, src) 36 | 37 | @torch.jit.script 38 | def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 39 | out: Optional[torch.Tensor] = None, 40 | dim_size: Optional[int] = None) -> torch.Tensor: 41 | 42 | out = scatter_sum(src, index, dim, out, dim_size) 43 | dim_size = out.size(dim) 44 | 45 | index_dim = dim 46 | if index_dim < 0: 47 | index_dim = index_dim + src.dim() 48 | if index.dim() <= index_dim: 49 | index_dim = index.dim() - 1 50 | 51 | ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) 52 | count = scatter_sum(ones, index, index_dim, None, dim_size) 53 | count.clamp_(1) 54 | count = broadcast(count, out, dim) 55 | if torch.is_floating_point(out): 56 | out.div_(count) 57 | else: 58 | assert 0 59 | # out.floor_divide_(count) 60 | return out -------------------------------------------------------------------------------- /mmseg/datasets/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .ade import ADE20KDataset 3 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset 4 | from .chase_db1 import ChaseDB1Dataset 5 | from .cityscapes import CityscapesDataset 6 | from .coco_stuff import COCOStuffDataset 7 | from .custom import CustomDataset 8 | from .dark_zurich import DarkZurichDataset 9 | from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset, 10 | RepeatDataset) 11 | from .drive import DRIVEDataset 12 | from .hrf import HRFDataset 13 | from .isaid import iSAIDDataset 14 | from .isprs import ISPRSDataset 15 | from .loveda import LoveDADataset 16 | from .night_driving import NightDrivingDataset 17 | from .pascal_context import PascalContextDataset, PascalContextDataset59 18 | from .potsdam import PotsdamDataset 19 | from .stare import STAREDataset 20 | from .voc import PascalVOCDataset 21 | 22 | from .lane_datasets.apollosim import APOLLOSIMDataset 23 | from .lane_datasets.openlane import OpenlaneDataset 24 | from .lane_datasets.openlane_temporal import OpenlaneMFDataset 25 | from .lane_datasets.once import ONCEDataset 26 | 27 | __all__ = [ 28 | 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset', 29 | 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset', 30 | 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset', 31 | 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', 32 | 'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset', 33 | 'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset', 34 | 'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset', 'APOLLOSIMDataset', 'OpenlaneDataset', 35 | 'OpenlaneMFDataset', 'ONCEDataset'] 36 | -------------------------------------------------------------------------------- /mmseg/datasets/chase_db1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class ChaseDB1Dataset(CustomDataset): 9 | """Chase_db1 dataset. 10 | 11 | In segmentation map annotation for Chase_db1, 0 stands for background, 12 | which is included in 2 categories. ``reduce_zero_label`` is fixed to False. 13 | The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 14 | '_1stHO.png'. 15 | """ 16 | 17 | CLASSES = ('background', 'vessel') 18 | 19 | PALETTE = [[120, 120, 120], [6, 230, 230]] 20 | 21 | def __init__(self, **kwargs): 22 | super(ChaseDB1Dataset, self).__init__( 23 | img_suffix='.png', 24 | seg_map_suffix='_1stHO.png', 25 | reduce_zero_label=False, 26 | **kwargs) 27 | assert self.file_client.exists(self.img_dir) 28 | -------------------------------------------------------------------------------- /mmseg/datasets/dark_zurich.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
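# ---------------------------------------------------------------------------
# Illustrative sketch (not part of dark_zurich.py): the imports in
# datasets/__init__.py above register every dataset class, so instances can be
# built from plain config dicts. The build_dataset() call is hypothetical --
# the exact keyword arguments depend on the concrete dataset class.
# ---------------------------------------------------------------------------
from mmseg.datasets import DATASETS, build_dataset

# The registry now also knows the lane datasets, e.g. 'APOLLOSIMDataset',
# 'OpenlaneDataset', 'OpenlaneMFDataset' and 'ONCEDataset'.
print(sorted(DATASETS.module_dict))

# dataset = build_dataset(dict(type='OpenlaneDataset', ...))  # args elided
# ---------------------------------------------------------------------------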
2 | from .builder import DATASETS
3 | from .cityscapes import CityscapesDataset
4 | 
5 | 
6 | @DATASETS.register_module()
7 | class DarkZurichDataset(CityscapesDataset):
8 |     """Dark Zurich dataset."""
9 | 
10 |     def __init__(self, **kwargs):
11 |         super().__init__(
12 |             img_suffix='_rgb_anon.png',
13 |             seg_map_suffix='_gt_labelTrainIds.png',
14 |             **kwargs)
15 | 
--------------------------------------------------------------------------------
/mmseg/datasets/drive.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | 
3 | from .builder import DATASETS
4 | from .custom import CustomDataset
5 | 
6 | 
7 | @DATASETS.register_module()
8 | class DRIVEDataset(CustomDataset):
9 |     """DRIVE dataset.
10 | 
11 |     In segmentation map annotation for DRIVE, 0 stands for background, which is
12 |     included in the 2 categories. ``reduce_zero_label`` is fixed to False. The
13 |     ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
14 |     '_manual1.png'.
15 |     """
16 | 
17 |     CLASSES = ('background', 'vessel')
18 | 
19 |     PALETTE = [[120, 120, 120], [6, 230, 230]]
20 | 
21 |     def __init__(self, **kwargs):
22 |         super(DRIVEDataset, self).__init__(
23 |             img_suffix='.png',
24 |             seg_map_suffix='_manual1.png',
25 |             reduce_zero_label=False,
26 |             **kwargs)
27 |         assert self.file_client.exists(self.img_dir)
28 | 
--------------------------------------------------------------------------------
/mmseg/datasets/hrf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | 
3 | from .builder import DATASETS
4 | from .custom import CustomDataset
5 | 
6 | 
7 | @DATASETS.register_module()
8 | class HRFDataset(CustomDataset):
9 |     """HRF dataset.
10 | 
11 |     In segmentation map annotation for HRF, 0 stands for background, which is
12 |     included in the 2 categories. ``reduce_zero_label`` is fixed to False. The
13 |     ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
14 |     '.png'.
15 |     """
16 | 
17 |     CLASSES = ('background', 'vessel')
18 | 
19 |     PALETTE = [[120, 120, 120], [6, 230, 230]]
20 | 
21 |     def __init__(self, **kwargs):
22 |         super(HRFDataset, self).__init__(
23 |             img_suffix='.png',
24 |             seg_map_suffix='.png',
25 |             reduce_zero_label=False,
26 |             **kwargs)
27 |         assert self.file_client.exists(self.img_dir)
28 | 
--------------------------------------------------------------------------------
/mmseg/datasets/isaid.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | 
3 | import mmcv
4 | from mmcv.utils import print_log
5 | 
6 | from ..utils import get_root_logger
7 | from .builder import DATASETS
8 | from .custom import CustomDataset
9 | 
10 | 
11 | @DATASETS.register_module()
12 | class iSAIDDataset(CustomDataset):
13 |     """iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images.
14 | 
15 |     Segmentation maps are annotated with 16 categories (background plus 15
16 |     object classes); ``ignore_index`` is fixed to 255. Both ``img_suffix``
17 |     and ``seg_map_suffix`` are fixed to '.png'.
18 |     """
19 | 
20 |     CLASSES = ('background', 'ship', 'store_tank', 'baseball_diamond',
21 |                'tennis_court', 'basketball_court', 'Ground_Track_Field',
22 |                'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter',
23 |                'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane',
24 |                'Harbor')
25 | 
26 |     PALETTE = [[0, 0, 0], [0, 0, 63], [0, 63, 63], [0, 63, 0], [0, 63, 127],
27 |                [0, 63, 191], [0, 63, 255], [0, 127, 63], [0, 127, 127],
28 |                [0, 0, 127], [0, 0, 191], [0, 0, 255], [0, 191, 127],
29 |                [0, 127, 191], [0, 127, 255], [0, 100, 155]]
30 | 
31 |     def __init__(self, **kwargs):
32 |         super(iSAIDDataset, self).__init__(
33 |             img_suffix='.png',
34 |             seg_map_suffix='.png',
35 |             ignore_index=255,
36 |             **kwargs)
37 |         assert self.file_client.exists(self.img_dir)
38 | 
39 |     def load_annotations(self,
40 |                          img_dir,
41 |                          img_suffix,
42 |                          ann_dir,
43 |                          seg_map_suffix=None,
44 |                          split=None):
45 |         """Load annotation from directory.
46 | 
47 |         Args:
48 |             img_dir (str): Path to image directory.
49 |             img_suffix (str): Suffix of images.
50 |             ann_dir (str|None): Path to annotation directory.
51 |             seg_map_suffix (str|None): Suffix of segmentation maps.
52 |             split (str|None): Split txt file. If specified, only files
53 |                 listed in the split are loaded. Otherwise, all images in
54 |                 img_dir/ann_dir will be loaded. Default: None
55 | 
56 |         Returns:
57 |             list[dict]: All image info of dataset.
58 |         """
59 | 
60 |         img_infos = []
61 |         if split is not None:
62 |             with open(split) as f:
63 |                 for line in f:
64 |                     name = line.strip()
65 |                     img_info = dict(filename=name + img_suffix)
66 |                     if ann_dir is not None:
67 |                         ann_name = name + '_instance_color_RGB'
68 |                         seg_map = ann_name + seg_map_suffix
69 |                         img_info['ann'] = dict(seg_map=seg_map)
70 |                     img_infos.append(img_info)
71 |         else:
72 |             for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
73 |                 img_info = dict(filename=img)
74 |                 if ann_dir is not None:
75 |                     seg_img = img
76 |                     seg_map = seg_img.replace(
77 |                         img_suffix, '_instance_color_RGB' + seg_map_suffix)
78 |                     img_info['ann'] = dict(seg_map=seg_map)
79 |                 img_infos.append(img_info)
80 | 
81 |         print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
82 |         return img_infos
83 | 
--------------------------------------------------------------------------------
/mmseg/datasets/isprs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .builder import DATASETS
3 | from .custom import CustomDataset
4 | 
5 | 
6 | @DATASETS.register_module()
7 | class ISPRSDataset(CustomDataset):
8 |     """ISPRS dataset.
9 | 
10 |     In segmentation map annotation for ISPRS, 0 is the ignore index.
11 |     ``reduce_zero_label`` should be set to True. The ``img_suffix`` and
12 |     ``seg_map_suffix`` are both fixed to '.png'.
13 | """ 14 | CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree', 15 | 'car', 'clutter') 16 | 17 | PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], 18 | [255, 255, 0], [255, 0, 0]] 19 | 20 | def __init__(self, **kwargs): 21 | super(ISPRSDataset, self).__init__( 22 | img_suffix='.png', 23 | seg_map_suffix='.png', 24 | reduce_zero_label=True, 25 | **kwargs) 26 | -------------------------------------------------------------------------------- /mmseg/datasets/lane_datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tusen-ai/Anchor3DLane/5dfea5724c4a26cfb01a68e52383bc694e9be201/mmseg/datasets/lane_datasets/__init__.py -------------------------------------------------------------------------------- /mmseg/datasets/loveda.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | import mmcv 5 | import numpy as np 6 | from PIL import Image 7 | 8 | from .builder import DATASETS 9 | from .custom import CustomDataset 10 | 11 | 12 | @DATASETS.register_module() 13 | class LoveDADataset(CustomDataset): 14 | """LoveDA dataset. 15 | 16 | In segmentation map annotation for LoveDA, 0 is the ignore index. 17 | ``reduce_zero_label`` should be set to True. The ``img_suffix`` and 18 | ``seg_map_suffix`` are both fixed to '.png'. 19 | """ 20 | CLASSES = ('background', 'building', 'road', 'water', 'barren', 'forest', 21 | 'agricultural') 22 | 23 | PALETTE = [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255], 24 | [159, 129, 183], [0, 255, 0], [255, 195, 128]] 25 | 26 | def __init__(self, **kwargs): 27 | super(LoveDADataset, self).__init__( 28 | img_suffix='.png', 29 | seg_map_suffix='.png', 30 | reduce_zero_label=True, 31 | **kwargs) 32 | 33 | def results2img(self, results, imgfile_prefix, indices=None): 34 | """Write the segmentation results to images. 35 | 36 | Args: 37 | results (list[ndarray]): Testing results of the 38 | dataset. 39 | imgfile_prefix (str): The filename prefix of the png files. 40 | If the prefix is "somepath/xxx", 41 | the png files will be named "somepath/xxx.png". 42 | indices (list[int], optional): Indices of input results, if not 43 | set, all the indices of the dataset will be used. 44 | Default: None. 45 | 46 | Returns: 47 | list[str: str]: result txt files which contains corresponding 48 | semantic segmentation images. 49 | """ 50 | 51 | mmcv.mkdir_or_exist(imgfile_prefix) 52 | result_files = [] 53 | for result, idx in zip(results, indices): 54 | 55 | filename = self.img_infos[idx]['filename'] 56 | basename = osp.splitext(osp.basename(filename))[0] 57 | 58 | png_filename = osp.join(imgfile_prefix, f'{basename}.png') 59 | 60 | # The index range of official requirement is from 0 to 6. 61 | output = Image.fromarray(result.astype(np.uint8)) 62 | output.save(png_filename) 63 | result_files.append(png_filename) 64 | 65 | return result_files 66 | 67 | def format_results(self, results, imgfile_prefix, indices=None): 68 | """Format the results into dir (standard format for LoveDA evaluation). 69 | 70 | Args: 71 | results (list): Testing results of the dataset. 72 | imgfile_prefix (str): The prefix of images files. It 73 | includes the file path and the prefix of filename, e.g., 74 | "a/b/prefix". 75 | indices (list[int], optional): Indices of input results, 76 | if not set, all the indices of the dataset will be used. 77 | Default: None. 
78 | 
79 |         Returns:
80 |             list[str]: Paths of the result image files. No temporary
81 |                 directory is created; the png files are saved directly
82 |                 under ``imgfile_prefix``.
83 |         """
84 |         if indices is None:
85 |             indices = list(range(len(self)))
86 | 
87 |         assert isinstance(results, list), 'results must be a list.'
88 |         assert isinstance(indices, list), 'indices must be a list.'
89 | 
90 |         result_files = self.results2img(results, imgfile_prefix, indices)
91 | 
92 |         return result_files
93 | 
--------------------------------------------------------------------------------
/mmseg/datasets/night_driving.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .builder import DATASETS
3 | from .cityscapes import CityscapesDataset
4 | 
5 | 
6 | @DATASETS.register_module()
7 | class NightDrivingDataset(CityscapesDataset):
8 |     """Night Driving dataset."""
9 | 
10 |     def __init__(self, **kwargs):
11 |         super().__init__(
12 |             img_suffix='_leftImg8bit.png',
13 |             seg_map_suffix='_gtCoarse_labelTrainIds.png',
14 |             **kwargs)
15 | 
--------------------------------------------------------------------------------
/mmseg/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .compose import Compose
3 | from .formatting import (Collect, ImageToTensor, ToDataContainer, ToTensor,
4 |                          Transpose, to_tensor)
5 | from .loading import LoadAnnotations, LoadImageFromFile, LoadAnnotationsList, LoadImageListFromFile
6 | from .test_time_aug import MultiScaleFlipAug
7 | from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
8 |                          PhotoMetricDistortion, RandomCrop, RandomCutOut,
9 |                          RandomFlip, RandomMosaic, RandomRotate, Rerange,
10 |                          Resize, RGB2Gray, SegRescale)
11 | from .lane_format import LaneFormat, MaskGenerate
12 | 
13 | __all__ = [
14 |     'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
15 |     'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
16 |     'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
17 |     'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
18 |     'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
19 |     'RandomMosaic', 'LaneFormat', 'MaskGenerate', 'LoadAnnotationsList',
20 |     'LoadImageListFromFile']
21 | 
--------------------------------------------------------------------------------
/mmseg/datasets/pipelines/compose.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import collections
3 | 
4 | from mmcv.utils import build_from_cfg
5 | 
6 | from ..builder import PIPELINES
7 | 
8 | 
9 | @PIPELINES.register_module()
10 | class Compose(object):
11 |     """Compose multiple transforms sequentially.
12 | 
13 |     Args:
14 |         transforms (Sequence[dict | callable]): Sequence of transform
15 |             objects or config dicts to be composed.
16 | """ 17 | 18 | def __init__(self, transforms): 19 | assert isinstance(transforms, collections.abc.Sequence) 20 | self.transforms = [] 21 | for transform in transforms: 22 | if isinstance(transform, dict): 23 | transform = build_from_cfg(transform, PIPELINES) 24 | self.transforms.append(transform) 25 | elif callable(transform): 26 | self.transforms.append(transform) 27 | else: 28 | raise TypeError('transform must be callable or a dict') 29 | 30 | def __call__(self, data): 31 | """Call function to apply transforms sequentially. 32 | 33 | Args: 34 | data (dict): A result dict contains the data to transform. 35 | 36 | Returns: 37 | dict: Transformed data. 38 | """ 39 | 40 | for t in self.transforms: 41 | data = t(data) 42 | if data is None: 43 | return None 44 | return data 45 | 46 | def __repr__(self): 47 | format_string = self.__class__.__name__ + '(' 48 | for t in self.transforms: 49 | format_string += '\n' 50 | format_string += f' {t}' 51 | format_string += '\n)' 52 | return format_string 53 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/formating.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | # flake8: noqa 3 | import warnings 4 | 5 | from .formatting import * 6 | 7 | warnings.warn('DeprecationWarning: mmseg.datasets.pipelines.formating will be ' 8 | 'deprecated in 2021, please replace it with ' 9 | 'mmseg.datasets.pipelines.formatting.') 10 | -------------------------------------------------------------------------------- /mmseg/datasets/pipelines/lane_format.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Source code for Anchor3DLane 3 | # Copyright (c) 2023 TuSimple 4 | # @Time : 2023/04/05 5 | # @Author : Shaofei Huang 6 | # nowherespyfly@gmail.com 7 | # -------------------------------------------------------- 8 | 9 | from collections.abc import Sequence 10 | 11 | import mmcv 12 | import numpy as np 13 | import torch 14 | from mmcv.parallel import DataContainer as DC 15 | from os import path as osp 16 | 17 | from ..builder import PIPELINES 18 | from .formatting import to_tensor 19 | import cv2 20 | import pdb 21 | 22 | @PIPELINES.register_module() 23 | class LaneFormat(object): 24 | """Default formatting bundle. 25 | 26 | It simplifies the pipeline of formatting common fields, including "img" 27 | and other lane data. These fields are formatted as follows. 28 | 29 | - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) 30 | """ 31 | 32 | def __call__(self, results): 33 | """Call function to transform and format common fields in results. 34 | 35 | Args: 36 | results (dict): Result dict contains the data to convert. 37 | 38 | Returns: 39 | dict: The result dict contains the data that is formatted with 40 | default bundle. 
41 |         """
42 |         if 'img' in results:
43 |             img = results['img']
44 |             if len(img.shape) < 3:
45 |                 img = np.expand_dims(img, -1)
46 |             if len(img.shape) > 3:
47 |                 # [H, W, 3, N] -> [3, H, W, N]
48 |                 img = np.ascontiguousarray(img.transpose(2, 0, 1, 3))
49 |             else:
50 |                 img = np.ascontiguousarray(img.transpose(2, 0, 1))
51 |             results['img'] = DC(to_tensor(img), stack=True)
52 |         if 'gt_3dlanes' in results:
53 |             results['gt_3dlanes'] = DC(to_tensor(results['gt_3dlanes'].astype(np.float32)))
54 |         if 'gt_2dlanes' in results:
55 |             results['gt_2dlanes'] = DC(to_tensor(results['gt_2dlanes'].astype(np.float32)))
56 |         if 'gt_camera_extrinsic' in results:
57 |             results['gt_camera_extrinsic'] = DC(to_tensor(results['gt_camera_extrinsic'][None, ...].astype(np.float32)), stack=True)
58 |         if 'gt_camera_intrinsic' in results:
59 |             results['gt_camera_intrinsic'] = DC(to_tensor(results['gt_camera_intrinsic'][None, ...].astype(np.float32)), stack=True)
60 |         if 'gt_project_matrix' in results:
61 |             results['gt_project_matrix'] = DC(to_tensor(results['gt_project_matrix'][None, ...].astype(np.float32)), stack=True)
62 |         if 'gt_homography_matrix' in results:
63 |             results['gt_homography_matrix'] = DC(to_tensor(results['gt_homography_matrix'][None, ...].astype(np.float32)), stack=True)
64 |         if 'gt_camera_pitch' in results:
65 |             results['gt_camera_pitch'] = DC(to_tensor([results['gt_camera_pitch']]))
66 |         if 'gt_camera_height' in results:
67 |             results['gt_camera_height'] = DC(to_tensor([results['gt_camera_height']]))
68 |         if 'prev_poses' in results:
69 |             results['prev_poses'] = DC(to_tensor(np.stack(results['prev_poses'], axis=0).astype(np.float32)), stack=True)  # [Np, 3, 4]
70 |         if 'mask' in results:
71 |             results['mask'] = DC(to_tensor(results['mask'][None, ...].astype(np.float32)), stack=True)
72 |         return results
73 | 
74 |     def __repr__(self):
75 |         return self.__class__.__name__
76 | 
77 | 
78 | @PIPELINES.register_module()
79 | class MaskGenerate(object):
80 |     def __init__(self, input_size):
81 |         self.input_size = input_size
82 | 
83 |     def __call__(self, results):
84 |         # np.bool was removed in NumPy 1.24+; build the all-False mask directly
85 |         mask = np.zeros((self.input_size[0], self.input_size[1]), dtype=bool)
86 |         results['mask'] = mask
87 |         return results
--------------------------------------------------------------------------------
/mmseg/datasets/potsdam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .builder import DATASETS
3 | from .custom import CustomDataset
4 | 
5 | 
6 | @DATASETS.register_module()
7 | class PotsdamDataset(CustomDataset):
8 |     """ISPRS Potsdam dataset.
9 | 
10 |     In segmentation map annotation for Potsdam dataset, 0 is the ignore index.
11 |     ``reduce_zero_label`` should be set to True. The ``img_suffix`` and
12 |     ``seg_map_suffix`` are both fixed to '.png'.
13 |     """
14 |     CLASSES = ('impervious_surface', 'building', 'low_vegetation', 'tree',
15 |                'car', 'clutter')
16 | 
17 |     PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
18 |                [255, 255, 0], [255, 0, 0]]
19 | 
20 |     def __init__(self, **kwargs):
21 |         super(PotsdamDataset, self).__init__(
22 |             img_suffix='.png',
23 |             seg_map_suffix='.png',
24 |             reduce_zero_label=True,
25 |             **kwargs)
26 | 
--------------------------------------------------------------------------------
/mmseg/datasets/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
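# The sampler exported here extends torch's DistributedSampler with a seed
# that is synchronized across ranks (see distributed_sampler.py below). A
# minimal usage sketch, assuming an already-initialized process group:
#
#   from torch.utils.data import DataLoader
#   sampler = DistributedSampler(dataset, shuffle=True, seed=0)
#   loader = DataLoader(dataset, batch_size=4, sampler=sampler)
#   for epoch in range(max_epochs):
#       sampler.set_epoch(epoch)  # re-shuffle differently each epoch
#       for batch in loader:
#           ...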
2 | from .distributed_sampler import DistributedSampler 3 | 4 | __all__ = ['DistributedSampler'] 5 | -------------------------------------------------------------------------------- /mmseg/datasets/samplers/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from __future__ import division 3 | from typing import Iterator, Optional 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torch.utils.data import DistributedSampler as _DistributedSampler 8 | 9 | from mmseg.core.utils import sync_random_seed 10 | from mmseg.utils import get_device 11 | 12 | 13 | class DistributedSampler(_DistributedSampler): 14 | """DistributedSampler inheriting from 15 | `torch.utils.data.DistributedSampler`. 16 | 17 | Args: 18 | datasets (Dataset): the dataset will be loaded. 19 | num_replicas (int, optional): Number of processes participating in 20 | distributed training. By default, world_size is retrieved from the 21 | current distributed group. 22 | rank (int, optional): Rank of the current process within num_replicas. 23 | By default, rank is retrieved from the current distributed group. 24 | shuffle (bool): If True (default), sampler will shuffle the indices. 25 | seed (int): random seed used to shuffle the sampler if 26 | :attr:`shuffle=True`. This number should be identical across all 27 | processes in the distributed group. Default: ``0``. 28 | """ 29 | 30 | def __init__(self, 31 | dataset: Dataset, 32 | num_replicas: Optional[int] = None, 33 | rank: Optional[int] = None, 34 | shuffle: bool = True, 35 | seed=0) -> None: 36 | super().__init__( 37 | dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 38 | 39 | # In distributed sampling, different ranks should sample 40 | # non-overlapped data in the dataset. Therefore, this function 41 | # is used to make sure that each rank shuffles the data indices 42 | # in the same order based on the same seed. Then different ranks 43 | # could use different indices to select non-overlapped data from the 44 | # same data list. 45 | device = get_device() 46 | self.seed = sync_random_seed(seed, device) 47 | 48 | def __iter__(self) -> Iterator: 49 | """ 50 | Yields: 51 | Iterator: iterator of indices for rank. 52 | """ 53 | # deterministically shuffle based on epoch 54 | if self.shuffle: 55 | g = torch.Generator() 56 | # When :attr:`shuffle=True`, this ensures all replicas 57 | # use a different random ordering for each epoch. 58 | # Otherwise, the next iteration of this sampler will 59 | # yield the same ordering. 60 | g.manual_seed(self.epoch + self.seed) 61 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 62 | else: 63 | indices = torch.arange(len(self.dataset)).tolist() 64 | 65 | # add extra samples to make it evenly divisible 66 | indices += indices[:(self.total_size - len(indices))] 67 | assert len(indices) == self.total_size 68 | 69 | # subsample 70 | indices = indices[self.rank:self.total_size:self.num_replicas] 71 | assert len(indices) == self.num_samples 72 | 73 | return iter(indices) 74 | -------------------------------------------------------------------------------- /mmseg/datasets/stare.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
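# Suffix convention illustrated with an assumed sample name: an image
# 'im0001.png' pairs with the annotation 'im0001.ah.png' through the
# seg_map_suffix configured below.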
2 | import os.path as osp 3 | 4 | from .builder import DATASETS 5 | from .custom import CustomDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class STAREDataset(CustomDataset): 10 | """STARE dataset. 11 | 12 | In segmentation map annotation for STARE, 0 stands for background, which is 13 | included in 2 categories. ``reduce_zero_label`` is fixed to False. The 14 | ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to 15 | '.ah.png'. 16 | """ 17 | 18 | CLASSES = ('background', 'vessel') 19 | 20 | PALETTE = [[120, 120, 120], [6, 230, 230]] 21 | 22 | def __init__(self, **kwargs): 23 | super(STAREDataset, self).__init__( 24 | img_suffix='.png', 25 | seg_map_suffix='.ah.png', 26 | reduce_zero_label=False, 27 | **kwargs) 28 | assert osp.exists(self.img_dir) 29 | -------------------------------------------------------------------------------- /mmseg/datasets/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tusen-ai/Anchor3DLane/5dfea5724c4a26cfb01a68e52383bc694e9be201/mmseg/datasets/tools/__init__.py -------------------------------------------------------------------------------- /mmseg/datasets/voc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | from .builder import DATASETS 5 | from .custom import CustomDataset 6 | 7 | 8 | @DATASETS.register_module() 9 | class PascalVOCDataset(CustomDataset): 10 | """Pascal VOC dataset. 11 | 12 | Args: 13 | split (str): Split txt file for Pascal VOC. 14 | """ 15 | 16 | CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 17 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 18 | 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 19 | 'train', 'tvmonitor') 20 | 21 | PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], 22 | [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], 23 | [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], 24 | [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], 25 | [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] 26 | 27 | def __init__(self, split, **kwargs): 28 | super(PascalVOCDataset, self).__init__( 29 | img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) 30 | assert osp.exists(self.img_dir) and self.split is not None 31 | -------------------------------------------------------------------------------- /mmseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
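# The builders re-exported below create models from config dicts. A sketch of
# constructing the lane detector from one of the repo's configs (the config
# path is assumed from the tree at the top of this repo):
#
#   from mmcv import Config
#   from mmseg.models import build_lanedetector
#   cfg = Config.fromfile('configs/openlane/anchor3dlane.py')
#   model = build_lanedetector(cfg.model)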
2 | from .backbones import * # noqa: F401,F403 3 | from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, ASSIGNER, MMSEGMENTORS, build_backbone, 4 | build_head, build_loss, build_segmentor, build_lanedetector, build_assigner, build_segmentor_multimodel) 5 | from .decode_heads import * # noqa: F401,F403 6 | from .losses import * # noqa: F401,F403 7 | from .necks import * # noqa: F401,F403 8 | from .segmentors import * # noqa: F401,F403 9 | from .lane_detector import * 10 | 11 | __all__ = [ 12 | 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'ASSIGNER', 'MMSEGMENTORS', 'build_backbone', 13 | 'build_head', 'build_loss', 'build_segmentor', 'build_lanedetector', 'build_assigner', 'build_segmentor_multimodel' 14 | ] 15 | -------------------------------------------------------------------------------- /mmseg/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .beit import BEiT 3 | from .bisenetv1 import BiSeNetV1 4 | from .bisenetv2 import BiSeNetV2 5 | from .cgnet import CGNet 6 | from .erfnet import ERFNet 7 | from .fast_scnn import FastSCNN 8 | from .hrnet import HRNet 9 | from .icnet import ICNet 10 | from .mae import MAE 11 | from .mit import MixVisionTransformer 12 | from .mobilenet_v2 import MobileNetV2 13 | from .mobilenet_v3 import MobileNetV3 14 | from .resnest import ResNeSt 15 | from .resnet import ResNet, ResNetV1c, ResNetV1d 16 | from .resnext import ResNeXt 17 | from .stdc import STDCContextPathNet, STDCNet 18 | from .swin import SwinTransformer 19 | from .twins import PCPVT, SVT 20 | from .unet import UNet 21 | from .vit import VisionTransformer 22 | from .efficientnet import EfficientNet 23 | 24 | __all__ = [ 25 | 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', 26 | 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', 27 | 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 28 | 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'ERFNet', 'PCPVT', 29 | 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'EfficientNet' 30 | ] 31 | -------------------------------------------------------------------------------- /mmseg/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
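# All registries below share the MMCV MODELS parent registry, so a class
# registered once can be referenced by its string name in configs. A minimal
# sketch of the pattern (MyLaneDetector is a hypothetical name):
#
#   import torch.nn as nn
#
#   @LANENET2S.register_module()
#   class MyLaneDetector(nn.Module):
#       ...
#
#   model = build_lanedetector(dict(type='MyLaneDetector'))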
2 | import warnings 3 | 4 | from mmcv.cnn import MODELS as MMCV_MODELS 5 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 6 | from mmcv.utils import Registry 7 | 8 | MODELS = Registry('models', parent=MMCV_MODELS) 9 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION) 10 | ASSIGNER = Registry('assigner') 11 | 12 | BACKBONES = MODELS 13 | NECKS = MODELS 14 | HEADS = MODELS 15 | LOSSES = MODELS 16 | SEGMENTORS = MODELS 17 | LANENET2S = MODELS 18 | MMSEGMENTORS = MODELS 19 | READERS = MODELS 20 | 21 | 22 | def build_backbone(cfg): 23 | """Build backbone.""" 24 | return BACKBONES.build(cfg) 25 | 26 | 27 | def build_neck(cfg): 28 | """Build neck.""" 29 | return NECKS.build(cfg) 30 | 31 | 32 | def build_head(cfg): 33 | """Build head.""" 34 | return HEADS.build(cfg) 35 | 36 | 37 | def build_loss(cfg): 38 | """Build loss.""" 39 | return LOSSES.build(cfg) 40 | 41 | def build_lanedetector(cfg): 42 | return LANENET2S.build(cfg) 43 | 44 | def build_assigner(cfg): 45 | """Build anchor-gt matching function""" 46 | return ASSIGNER.build(cfg) 47 | 48 | def build_segmentor_multimodel(cfg): 49 | return MMSEGMENTORS.build(cfg) 50 | 51 | def build_reader(cfg): 52 | return READERS.build(cfg) 53 | 54 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 55 | """Build segmentor.""" 56 | if train_cfg is not None or test_cfg is not None: 57 | warnings.warn( 58 | 'train_cfg and test_cfg is deprecated, ' 59 | 'please specify them in model', UserWarning) 60 | assert cfg.get('train_cfg') is None or train_cfg is None, \ 61 | 'train_cfg specified in both outer field and model field ' 62 | assert cfg.get('test_cfg') is None or test_cfg is None, \ 63 | 'test_cfg specified in both outer field and model field ' 64 | return SEGMENTORS.build( 65 | cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) 66 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
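# Every head below registers itself under HEADS, so configs can instantiate
# one by name. An illustrative example with assumed channel sizes:
#
#   from mmseg.models import build_head
#   head = build_head(dict(
#       type='FCNHead', in_channels=256, channels=64, num_classes=19))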
2 | from .ann_head import ANNHead 3 | from .apc_head import APCHead 4 | from .aspp_head import ASPPHead 5 | from .cc_head import CCHead 6 | from .da_head import DAHead 7 | from .dm_head import DMHead 8 | from .dnl_head import DNLHead 9 | from .dpt_head import DPTHead 10 | from .ema_head import EMAHead 11 | from .enc_head import EncHead 12 | from .fcn_head import FCNHead 13 | from .fpn_head import FPNHead 14 | from .gc_head import GCHead 15 | from .isa_head import ISAHead 16 | from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator 17 | from .lraspp_head import LRASPPHead 18 | from .nl_head import NLHead 19 | from .ocr_head import OCRHead 20 | from .point_head import PointHead 21 | from .psa_head import PSAHead 22 | from .psp_head import PSPHead 23 | from .segformer_head import SegformerHead 24 | from .segmenter_mask_head import SegmenterMaskTransformerHead 25 | from .sep_aspp_head import DepthwiseSeparableASPPHead 26 | from .sep_fcn_head import DepthwiseSeparableFCNHead 27 | from .setr_mla_head import SETRMLAHead 28 | from .setr_up_head import SETRUPHead 29 | from .stdc_head import STDCHead 30 | from .uper_head import UPerHead 31 | 32 | __all__ = [ 33 | 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead', 34 | 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', 35 | 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', 36 | 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', 37 | 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', 38 | 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', 39 | 'KernelUpdateHead', 'KernelUpdator' 40 | ] 41 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/cascade_decode_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from .decode_head import BaseDecodeHead 5 | 6 | 7 | class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta): 8 | """Base class for cascade decode head used in 9 | :class:`CascadeEncoderDecoder.""" 10 | 11 | def __init__(self, *args, **kwargs): 12 | super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs) 13 | 14 | @abstractmethod 15 | def forward(self, inputs, prev_output): 16 | """Placeholder of forward function.""" 17 | pass 18 | 19 | def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg, 20 | train_cfg): 21 | """Forward function for training. 22 | Args: 23 | inputs (list[Tensor]): List of multi-level img features. 24 | prev_output (Tensor): The output of previous decode head. 25 | img_metas (list[dict]): List of image info dict where each dict 26 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 27 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 28 | For details on the values of these keys see 29 | `mmseg/datasets/pipelines/formatting.py:Collect`. 30 | gt_semantic_seg (Tensor): Semantic segmentation masks 31 | used if the architecture supports semantic segmentation task. 32 | train_cfg (dict): The training config. 33 | 34 | Returns: 35 | dict[str, Tensor]: a dictionary of loss components 36 | """ 37 | seg_logits = self.forward(inputs, prev_output) 38 | losses = self.losses(seg_logits, gt_semantic_seg) 39 | 40 | return losses 41 | 42 | def forward_test(self, inputs, prev_output, img_metas, test_cfg): 43 | """Forward function for testing. 
44 | 45 | Args: 46 | inputs (list[Tensor]): List of multi-level img features. 47 | prev_output (Tensor): The output of previous decode head. 48 | img_metas (list[dict]): List of image info dict where each dict 49 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 50 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 51 | For details on the values of these keys see 52 | `mmseg/datasets/pipelines/formatting.py:Collect`. 53 | test_cfg (dict): The testing config. 54 | 55 | Returns: 56 | Tensor: Output segmentation map. 57 | """ 58 | return self.forward(inputs, prev_output) 59 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/cc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from ..builder import HEADS 5 | from .fcn_head import FCNHead 6 | 7 | try: 8 | from mmcv.ops import CrissCrossAttention 9 | except ModuleNotFoundError: 10 | CrissCrossAttention = None 11 | 12 | 13 | @HEADS.register_module() 14 | class CCHead(FCNHead): 15 | """CCNet: Criss-Cross Attention for Semantic Segmentation. 16 | 17 | This head is the implementation of `CCNet 18 | `_. 19 | 20 | Args: 21 | recurrence (int): Number of recurrence of Criss Cross Attention 22 | module. Default: 2. 23 | """ 24 | 25 | def __init__(self, recurrence=2, **kwargs): 26 | if CrissCrossAttention is None: 27 | raise RuntimeError('Please install mmcv-full for ' 28 | 'CrissCrossAttention ops') 29 | super(CCHead, self).__init__(num_convs=2, **kwargs) 30 | self.recurrence = recurrence 31 | self.cca = CrissCrossAttention(self.channels) 32 | 33 | def forward(self, inputs): 34 | """Forward function.""" 35 | x = self._transform_inputs(inputs) 36 | output = self.convs[0](x) 37 | for _ in range(self.recurrence): 38 | output = self.cca(output) 39 | output = self.convs[1](output) 40 | if self.concat_input: 41 | output = self.conv_cat(torch.cat([x, output], dim=1)) 42 | output = self.cls_seg(output) 43 | return output 44 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/fcn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from ..builder import HEADS 7 | from .decode_head import BaseDecodeHead 8 | 9 | 10 | @HEADS.register_module() 11 | class FCNHead(BaseDecodeHead): 12 | """Fully Convolution Networks for Semantic Segmentation. 13 | 14 | This head is implemented of `FCNNet `_. 15 | 16 | Args: 17 | num_convs (int): Number of convs in the head. Default: 2. 18 | kernel_size (int): The kernel size for convs in the head. Default: 3. 19 | concat_input (bool): Whether concat the input and output of convs 20 | before classification layer. 21 | dilation (int): The dilation rate for convs in the head. Default: 1. 
22 | """ 23 | 24 | def __init__(self, 25 | num_convs=2, 26 | kernel_size=3, 27 | concat_input=True, 28 | dilation=1, 29 | **kwargs): 30 | assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int) 31 | self.num_convs = num_convs 32 | self.concat_input = concat_input 33 | self.kernel_size = kernel_size 34 | super(FCNHead, self).__init__(**kwargs) 35 | if num_convs == 0: 36 | assert self.in_channels == self.channels 37 | 38 | conv_padding = (kernel_size // 2) * dilation 39 | convs = [] 40 | convs.append( 41 | ConvModule( 42 | self.in_channels, 43 | self.channels, 44 | kernel_size=kernel_size, 45 | padding=conv_padding, 46 | dilation=dilation, 47 | conv_cfg=self.conv_cfg, 48 | norm_cfg=self.norm_cfg, 49 | act_cfg=self.act_cfg)) 50 | for i in range(num_convs - 1): 51 | convs.append( 52 | ConvModule( 53 | self.channels, 54 | self.channels, 55 | kernel_size=kernel_size, 56 | padding=conv_padding, 57 | dilation=dilation, 58 | conv_cfg=self.conv_cfg, 59 | norm_cfg=self.norm_cfg, 60 | act_cfg=self.act_cfg)) 61 | if num_convs == 0: 62 | self.convs = nn.Identity() 63 | else: 64 | self.convs = nn.Sequential(*convs) 65 | if self.concat_input: 66 | self.conv_cat = ConvModule( 67 | self.in_channels + self.channels, 68 | self.channels, 69 | kernel_size=kernel_size, 70 | padding=kernel_size // 2, 71 | conv_cfg=self.conv_cfg, 72 | norm_cfg=self.norm_cfg, 73 | act_cfg=self.act_cfg) 74 | 75 | def _forward_feature(self, inputs): 76 | """Forward function for feature maps before classifying each pixel with 77 | ``self.cls_seg`` fc. 78 | 79 | Args: 80 | inputs (list[Tensor]): List of multi-level img features. 81 | 82 | Returns: 83 | feats (Tensor): A tensor of shape (batch_size, self.channels, 84 | H, W) which is feature map for last layer of decoder head. 85 | """ 86 | x = self._transform_inputs(inputs) 87 | feats = self.convs(x) 88 | if self.concat_input: 89 | feats = self.conv_cat(torch.cat([x, feats], dim=1)) 90 | return feats 91 | 92 | def forward(self, inputs): 93 | """Forward function.""" 94 | output = self._forward_feature(inputs) 95 | output = self.cls_seg(output) 96 | return output 97 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/fpn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch.nn as nn 4 | from mmcv.cnn import ConvModule 5 | 6 | from mmseg.ops import Upsample, resize 7 | from ..builder import HEADS 8 | from .decode_head import BaseDecodeHead 9 | 10 | 11 | @HEADS.register_module() 12 | class FPNHead(BaseDecodeHead): 13 | """Panoptic Feature Pyramid Networks. 14 | 15 | This head is the implementation of `Semantic FPN 16 | `_. 17 | 18 | Args: 19 | feature_strides (tuple[int]): The strides for input feature maps. 20 | stack_lateral. All strides suppose to be power of 2. The first 21 | one is of largest resolution. 
22 | """ 23 | 24 | def __init__(self, feature_strides, **kwargs): 25 | super(FPNHead, self).__init__( 26 | input_transform='multiple_select', **kwargs) 27 | assert len(feature_strides) == len(self.in_channels) 28 | assert min(feature_strides) == feature_strides[0] 29 | self.feature_strides = feature_strides 30 | 31 | self.scale_heads = nn.ModuleList() 32 | for i in range(len(feature_strides)): 33 | head_length = max( 34 | 1, 35 | int(np.log2(feature_strides[i]) - np.log2(feature_strides[0]))) 36 | scale_head = [] 37 | for k in range(head_length): 38 | scale_head.append( 39 | ConvModule( 40 | self.in_channels[i] if k == 0 else self.channels, 41 | self.channels, 42 | 3, 43 | padding=1, 44 | conv_cfg=self.conv_cfg, 45 | norm_cfg=self.norm_cfg, 46 | act_cfg=self.act_cfg)) 47 | if feature_strides[i] != feature_strides[0]: 48 | scale_head.append( 49 | Upsample( 50 | scale_factor=2, 51 | mode='bilinear', 52 | align_corners=self.align_corners)) 53 | self.scale_heads.append(nn.Sequential(*scale_head)) 54 | 55 | def forward(self, inputs): 56 | 57 | x = self._transform_inputs(inputs) 58 | 59 | output = self.scale_heads[0](x[0]) 60 | for i in range(1, len(self.feature_strides)): 61 | # non inplace 62 | output = output + resize( 63 | self.scale_heads[i](x[i]), 64 | size=output.shape[2:], 65 | mode='bilinear', 66 | align_corners=self.align_corners) 67 | 68 | output = self.cls_seg(output) 69 | return output 70 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/gc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.cnn import ContextBlock 4 | 5 | from ..builder import HEADS 6 | from .fcn_head import FCNHead 7 | 8 | 9 | @HEADS.register_module() 10 | class GCHead(FCNHead): 11 | """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. 12 | 13 | This head is the implementation of `GCNet 14 | `_. 15 | 16 | Args: 17 | ratio (float): Multiplier of channels ratio. Default: 1/4. 18 | pooling_type (str): The pooling type of context aggregation. 19 | Options are 'att', 'avg'. Default: 'avg'. 20 | fusion_types (tuple[str]): The fusion type for feature fusion. 21 | Options are 'channel_add', 'channel_mul'. Default: ('channel_add',) 22 | """ 23 | 24 | def __init__(self, 25 | ratio=1 / 4., 26 | pooling_type='att', 27 | fusion_types=('channel_add', ), 28 | **kwargs): 29 | super(GCHead, self).__init__(num_convs=2, **kwargs) 30 | self.ratio = ratio 31 | self.pooling_type = pooling_type 32 | self.fusion_types = fusion_types 33 | self.gc_block = ContextBlock( 34 | in_channels=self.channels, 35 | ratio=self.ratio, 36 | pooling_type=self.pooling_type, 37 | fusion_types=self.fusion_types) 38 | 39 | def forward(self, inputs): 40 | """Forward function.""" 41 | x = self._transform_inputs(inputs) 42 | output = self.convs[0](x) 43 | output = self.gc_block(output) 44 | output = self.convs[1](output) 45 | if self.concat_input: 46 | output = self.conv_cat(torch.cat([x, output], dim=1)) 47 | output = self.cls_seg(output) 48 | return output 49 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/lraspp_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | import torch.nn as nn 4 | from mmcv import is_tuple_of 5 | from mmcv.cnn import ConvModule 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | 11 | 12 | @HEADS.register_module() 13 | class LRASPPHead(BaseDecodeHead): 14 | """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. 15 | 16 | This head is the improved implementation of `Searching for MobileNetV3 17 | `_. 18 | 19 | Args: 20 | branch_channels (tuple[int]): The number of output channels in every 21 | each branch. Default: (32, 64). 22 | """ 23 | 24 | def __init__(self, branch_channels=(32, 64), **kwargs): 25 | super(LRASPPHead, self).__init__(**kwargs) 26 | if self.input_transform != 'multiple_select': 27 | raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform ' 28 | f'must be \'multiple_select\'. But received ' 29 | f'\'{self.input_transform}\'') 30 | assert is_tuple_of(branch_channels, int) 31 | assert len(branch_channels) == len(self.in_channels) - 1 32 | self.branch_channels = branch_channels 33 | 34 | self.convs = nn.Sequential() 35 | self.conv_ups = nn.Sequential() 36 | for i in range(len(branch_channels)): 37 | self.convs.add_module( 38 | f'conv{i}', 39 | nn.Conv2d( 40 | self.in_channels[i], branch_channels[i], 1, bias=False)) 41 | self.conv_ups.add_module( 42 | f'conv_up{i}', 43 | ConvModule( 44 | self.channels + branch_channels[i], 45 | self.channels, 46 | 1, 47 | norm_cfg=self.norm_cfg, 48 | act_cfg=self.act_cfg, 49 | bias=False)) 50 | 51 | self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1) 52 | 53 | self.aspp_conv = ConvModule( 54 | self.in_channels[-1], 55 | self.channels, 56 | 1, 57 | norm_cfg=self.norm_cfg, 58 | act_cfg=self.act_cfg, 59 | bias=False) 60 | self.image_pool = nn.Sequential( 61 | nn.AvgPool2d(kernel_size=49, stride=(16, 20)), 62 | ConvModule( 63 | self.in_channels[2], 64 | self.channels, 65 | 1, 66 | act_cfg=dict(type='Sigmoid'), 67 | bias=False)) 68 | 69 | def forward(self, inputs): 70 | """Forward function.""" 71 | inputs = self._transform_inputs(inputs) 72 | 73 | x = inputs[-1] 74 | 75 | x = self.aspp_conv(x) * resize( 76 | self.image_pool(x), 77 | size=x.size()[2:], 78 | mode='bilinear', 79 | align_corners=self.align_corners) 80 | x = self.conv_up_input(x) 81 | 82 | for i in range(len(self.branch_channels) - 1, -1, -1): 83 | x = resize( 84 | x, 85 | size=inputs[i].size()[2:], 86 | mode='bilinear', 87 | align_corners=self.align_corners) 88 | x = torch.cat([x, self.convs[i](inputs[i])], 1) 89 | x = self.conv_ups[i](x) 90 | 91 | return self.cls_seg(x) 92 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/nl_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.cnn import NonLocal2d 4 | 5 | from ..builder import HEADS 6 | from .fcn_head import FCNHead 7 | 8 | 9 | @HEADS.register_module() 10 | class NLHead(FCNHead): 11 | """Non-local Neural Networks. 12 | 13 | This head is the implementation of `NLNet 14 | `_. 15 | 16 | Args: 17 | reduction (int): Reduction factor of projection transform. Default: 2. 18 | use_scale (bool): Whether to scale pairwise_weight by 19 | sqrt(1/inter_channels). Default: True. 20 | mode (str): The nonlocal mode. Options are 'embedded_gaussian', 21 | 'dot_product'. Default: 'embedded_gaussian.'. 
22 | """ 23 | 24 | def __init__(self, 25 | reduction=2, 26 | use_scale=True, 27 | mode='embedded_gaussian', 28 | **kwargs): 29 | super(NLHead, self).__init__(num_convs=2, **kwargs) 30 | self.reduction = reduction 31 | self.use_scale = use_scale 32 | self.mode = mode 33 | self.nl_block = NonLocal2d( 34 | in_channels=self.channels, 35 | reduction=self.reduction, 36 | use_scale=self.use_scale, 37 | conv_cfg=self.conv_cfg, 38 | norm_cfg=self.norm_cfg, 39 | mode=self.mode) 40 | 41 | def forward(self, inputs): 42 | """Forward function.""" 43 | x = self._transform_inputs(inputs) 44 | output = self.convs[0](x) 45 | output = self.nl_block(output) 46 | output = self.convs[1](output) 47 | if self.concat_input: 48 | output = self.conv_cat(torch.cat([x, output], dim=1)) 49 | output = self.cls_seg(output) 50 | return output 51 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/sep_fcn_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import DepthwiseSeparableConvModule 3 | 4 | from ..builder import HEADS 5 | from .fcn_head import FCNHead 6 | 7 | 8 | @HEADS.register_module() 9 | class DepthwiseSeparableFCNHead(FCNHead): 10 | """Depthwise-Separable Fully Convolutional Network for Semantic 11 | Segmentation. 12 | 13 | This head is implemented according to `Fast-SCNN: Fast Semantic 14 | Segmentation Network `_. 15 | 16 | Args: 17 | in_channels(int): Number of output channels of FFM. 18 | channels(int): Number of middle-stage channels in the decode head. 19 | concat_input(bool): Whether to concatenate original decode input into 20 | the result of several consecutive convolution layers. 21 | Default: True. 22 | num_classes(int): Used to determine the dimension of 23 | final prediction tensor. 24 | in_index(int): Correspond with 'out_indices' in FastSCNN backbone. 25 | norm_cfg (dict | None): Config of norm layers. 26 | align_corners (bool): align_corners argument of F.interpolate. 27 | Default: False. 28 | loss_decode(dict): Config of loss type and some 29 | relevant additional options. 30 | dw_act_cfg (dict):Activation config of depthwise ConvModule. If it is 31 | 'default', it will be the same as `act_cfg`. Default: None. 32 | """ 33 | 34 | def __init__(self, dw_act_cfg=None, **kwargs): 35 | super(DepthwiseSeparableFCNHead, self).__init__(**kwargs) 36 | self.convs[0] = DepthwiseSeparableConvModule( 37 | self.in_channels, 38 | self.channels, 39 | kernel_size=self.kernel_size, 40 | padding=self.kernel_size // 2, 41 | norm_cfg=self.norm_cfg, 42 | dw_act_cfg=dw_act_cfg) 43 | 44 | for i in range(1, self.num_convs): 45 | self.convs[i] = DepthwiseSeparableConvModule( 46 | self.channels, 47 | self.channels, 48 | kernel_size=self.kernel_size, 49 | padding=self.kernel_size // 2, 50 | norm_cfg=self.norm_cfg, 51 | dw_act_cfg=dw_act_cfg) 52 | 53 | if self.concat_input: 54 | self.conv_cat = DepthwiseSeparableConvModule( 55 | self.in_channels + self.channels, 56 | self.channels, 57 | kernel_size=self.kernel_size, 58 | padding=self.kernel_size // 2, 59 | norm_cfg=self.norm_cfg, 60 | dw_act_cfg=dw_act_cfg) 61 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/setr_mla_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch
3 | import torch.nn as nn
4 | from mmcv.cnn import ConvModule
5 | 
6 | from mmseg.ops import Upsample
7 | from ..builder import HEADS
8 | from .decode_head import BaseDecodeHead
9 | 
10 | 
11 | @HEADS.register_module()
12 | class SETRMLAHead(BaseDecodeHead):
13 |     """Multi-level feature aggregation head of SETR.
14 | 
15 |     MLA head of `SETR <https://arxiv.org/abs/2012.15840>`_.
16 | 
17 |     Args:
18 |         mla_channels (int): Channels of conv-conv-4x of multi-level feature
19 |             aggregation. Default: 128.
20 |         up_scale (int): The scale factor of interpolate. Default: 4.
21 |     """
22 | 
23 |     def __init__(self, mla_channels=128, up_scale=4, **kwargs):
24 |         super(SETRMLAHead, self).__init__(
25 |             input_transform='multiple_select', **kwargs)
26 |         self.mla_channels = mla_channels
27 | 
28 |         num_inputs = len(self.in_channels)
29 | 
30 |         # Refer to self.cls_seg settings of BaseDecodeHead
31 |         assert self.channels == num_inputs * mla_channels
32 | 
33 |         self.up_convs = nn.ModuleList()
34 |         for i in range(num_inputs):
35 |             self.up_convs.append(
36 |                 nn.Sequential(
37 |                     ConvModule(
38 |                         in_channels=self.in_channels[i],
39 |                         out_channels=mla_channels,
40 |                         kernel_size=3,
41 |                         padding=1,
42 |                         norm_cfg=self.norm_cfg,
43 |                         act_cfg=self.act_cfg),
44 |                     ConvModule(
45 |                         in_channels=mla_channels,
46 |                         out_channels=mla_channels,
47 |                         kernel_size=3,
48 |                         padding=1,
49 |                         norm_cfg=self.norm_cfg,
50 |                         act_cfg=self.act_cfg),
51 |                     Upsample(
52 |                         scale_factor=up_scale,
53 |                         mode='bilinear',
54 |                         align_corners=self.align_corners)))
55 | 
56 |     def forward(self, inputs):
57 |         inputs = self._transform_inputs(inputs)
58 |         outs = []
59 |         for x, up_conv in zip(inputs, self.up_convs):
60 |             outs.append(up_conv(x))
61 |         out = torch.cat(outs, dim=1)
62 |         out = self.cls_seg(out)
63 |         return out
64 | 
--------------------------------------------------------------------------------
/mmseg/models/decode_heads/setr_up_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | from mmcv.cnn import ConvModule, build_norm_layer
4 | 
5 | from mmseg.ops import Upsample
6 | from ..builder import HEADS
7 | from .decode_head import BaseDecodeHead
8 | 
9 | 
10 | @HEADS.register_module()
11 | class SETRUPHead(BaseDecodeHead):
12 |     """Naive upsampling head and Progressive upsampling head of SETR.
13 | 
14 |     Naive or PUP head of `SETR <https://arxiv.org/abs/2012.15840>`_.
15 | 
16 |     Args:
17 |         norm_layer (dict): Config dict for input normalization.
18 |             Default: norm_layer=dict(type='LN', eps=1e-6, requires_grad=True).
19 |         num_convs (int): Number of decoder convolutions. Default: 1.
20 |         up_scale (int): The scale factor of interpolate. Default: 4.
21 |         kernel_size (int): The kernel size of convolution when decoding
22 |             feature information from backbone. Default: 3.
23 |         init_cfg (dict | list[dict] | None): Initialization config dict.
24 |             Default: dict(
25 |                 type='Constant', val=1.0, bias=0, layer='LayerNorm').
26 |     """
27 | 
28 |     def __init__(self,
29 |                  norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
30 |                  num_convs=1,
31 |                  up_scale=4,
32 |                  kernel_size=3,
33 |                  init_cfg=[
34 |                      dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'),
35 |                      dict(
36 |                          type='Normal',
37 |                          std=0.01,
38 |                          override=dict(name='conv_seg'))
39 |                  ],
40 |                  **kwargs):
41 | 
42 |         assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'
43 | 44 | super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs) 45 | 46 | assert isinstance(self.in_channels, int) 47 | 48 | _, self.norm = build_norm_layer(norm_layer, self.in_channels) 49 | 50 | self.up_convs = nn.ModuleList() 51 | in_channels = self.in_channels 52 | out_channels = self.channels 53 | for _ in range(num_convs): 54 | self.up_convs.append( 55 | nn.Sequential( 56 | ConvModule( 57 | in_channels=in_channels, 58 | out_channels=out_channels, 59 | kernel_size=kernel_size, 60 | stride=1, 61 | padding=int(kernel_size - 1) // 2, 62 | norm_cfg=self.norm_cfg, 63 | act_cfg=self.act_cfg), 64 | Upsample( 65 | scale_factor=up_scale, 66 | mode='bilinear', 67 | align_corners=self.align_corners))) 68 | in_channels = out_channels 69 | 70 | def forward(self, x): 71 | x = self._transform_inputs(x) 72 | 73 | n, c, h, w = x.shape 74 | x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() 75 | x = self.norm(x) 76 | x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() 77 | 78 | for up_conv in self.up_convs: 79 | x = up_conv(x) 80 | out = self.cls_seg(x) 81 | return out 82 | -------------------------------------------------------------------------------- /mmseg/models/decode_heads/stdc_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from ..builder import HEADS 6 | from .fcn_head import FCNHead 7 | 8 | 9 | @HEADS.register_module() 10 | class STDCHead(FCNHead): 11 | """This head is the implementation of `Rethinking BiSeNet For Real-time 12 | Semantic Segmentation `_. 13 | 14 | Args: 15 | boundary_threshold (float): The threshold of calculating boundary. 16 | Default: 0.1. 17 | """ 18 | 19 | def __init__(self, boundary_threshold=0.1, **kwargs): 20 | super(STDCHead, self).__init__(**kwargs) 21 | self.boundary_threshold = boundary_threshold 22 | # Using register buffer to make laplacian kernel on the same 23 | # device of `seg_label`. 24 | self.register_buffer( 25 | 'laplacian_kernel', 26 | torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1], 27 | dtype=torch.float32, 28 | requires_grad=False).reshape((1, 1, 3, 3))) 29 | self.fusion_kernel = torch.nn.Parameter( 30 | torch.tensor([[6. / 10], [3. / 10], [1. / 10]], 31 | dtype=torch.float32).reshape(1, 3, 1, 1), 32 | requires_grad=False) 33 | 34 | def losses(self, seg_logit, seg_label): 35 | """Compute Detail Aggregation Loss.""" 36 | # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv 37 | # parameters. However, it is a constant in original repo and other 38 | # codebase because it would not be added into computation graph 39 | # after threshold operation. 
40 | seg_label = seg_label.to(self.laplacian_kernel) 41 | boundary_targets = F.conv2d( 42 | seg_label, self.laplacian_kernel, padding=1) 43 | boundary_targets = boundary_targets.clamp(min=0) 44 | boundary_targets[boundary_targets > self.boundary_threshold] = 1 45 | boundary_targets[boundary_targets <= self.boundary_threshold] = 0 46 | 47 | boundary_targets_x2 = F.conv2d( 48 | seg_label, self.laplacian_kernel, stride=2, padding=1) 49 | boundary_targets_x2 = boundary_targets_x2.clamp(min=0) 50 | 51 | boundary_targets_x4 = F.conv2d( 52 | seg_label, self.laplacian_kernel, stride=4, padding=1) 53 | boundary_targets_x4 = boundary_targets_x4.clamp(min=0) 54 | 55 | boundary_targets_x4_up = F.interpolate( 56 | boundary_targets_x4, boundary_targets.shape[2:], mode='nearest') 57 | boundary_targets_x2_up = F.interpolate( 58 | boundary_targets_x2, boundary_targets.shape[2:], mode='nearest') 59 | 60 | boundary_targets_x2_up[ 61 | boundary_targets_x2_up > self.boundary_threshold] = 1 62 | boundary_targets_x2_up[ 63 | boundary_targets_x2_up <= self.boundary_threshold] = 0 64 | 65 | boundary_targets_x4_up[ 66 | boundary_targets_x4_up > self.boundary_threshold] = 1 67 | boundary_targets_x4_up[ 68 | boundary_targets_x4_up <= self.boundary_threshold] = 0 69 | 70 | boundary_targets_pyramids = torch.stack( 71 | (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), 72 | dim=1) 73 | 74 | boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2) 75 | boundary_targets_pyramid = F.conv2d(boundary_targets_pyramids, 76 | self.fusion_kernel) 77 | 78 | boundary_targets_pyramid[ 79 | boundary_targets_pyramid > self.boundary_threshold] = 1 80 | boundary_targets_pyramid[ 81 | boundary_targets_pyramid <= self.boundary_threshold] = 0 82 | 83 | loss = super(STDCHead, self).losses(seg_logit, 84 | boundary_targets_pyramid.long()) 85 | return loss 86 | -------------------------------------------------------------------------------- /mmseg/models/lane_detector/__init__.py: -------------------------------------------------------------------------------- 1 | from .anchor_3dlane import Anchor3DLane 2 | from .anchor_3dlane_deform import Anchor3DLaneDeform 3 | from .anchor_3dlane_multiframe import Anchor3DLaneMF 4 | from .utils import * 5 | from .assigner import * 6 | 7 | 8 | __all__ = ['Anchor3DLane', 'Anchor3DLaneMF', 'Anchor3DLaneDeform'] -------------------------------------------------------------------------------- /mmseg/models/lane_detector/assigner/__init__.py: -------------------------------------------------------------------------------- 1 | from .thresh_assigner import ThreshAssigner 2 | from .topk_assigner import TopkAssigner 3 | from .topk_fv_assigner import TopkFVAssigner 4 | 5 | __all__ = ['ThreshAssigner', 'TopkAssigner', 'TopkFVAssigner'] -------------------------------------------------------------------------------- /mmseg/models/lane_detector/assigner/thresh_assigner.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Source code for Anchor3DLane 3 | # Copyright (c) 2023 TuSimple 4 | # @Time : 2023/04/05 5 | # @Author : Shaofei Huang 6 | # nowherespyfly@gmail.com 7 | # -------------------------------------------------------- 8 | 9 | import torch 10 | from ...builder import ASSIGNER 11 | from .distance_metric import Euclidean_dis, Manhattan_dis, Partial_Euclidean_dis 12 | 13 | INFINITY = 987654. 
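# INFINITY above is a large sentinel distance used to mask out invalid
# proposal-target pairs. Assignment semantics of ThreshAssigner below: a
# proposal whose smallest distance to any target is below t_pos becomes a
# positive; above t_neg it becomes a negative candidate (of which at most
# neg_k are sampled); anything in between is ignored. E.g. with the defaults
# t_pos=3.5 / t_neg=4.5, min-distances of 2.0, 4.0 and 6.0 yield positive,
# ignored and negative respectively. An illustrative call sketch (tensor
# shapes follow the comments in the code):
#
#   assigner = ThreshAssigner(t_pos=3.5, t_neg=4.5, anchor_len=10)
#   pos, neg, tgt_idx = assigner.match_proposals_with_targets(
#       proposals, targets)  # proposals: [Np, 35], targets: [Nt, 35]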
14 | 15 | @ASSIGNER.register_module() 16 | class ThreshAssigner(object): 17 | def __init__(self, t_pos=3.5, t_neg=4.5, anchor_len=10, pos_k=5, neg_k=2000, metric='Euclidean', **kwargs): 18 | self.t_pos = t_pos 19 | self.t_neg = t_neg 20 | self.anchor_len = anchor_len 21 | self.metric = metric 22 | self.pos_k = pos_k 23 | self.neg_k = neg_k 24 | 25 | def match_proposals_with_targets(self, proposals, targets, **kwargs): 26 | valid_targets = targets[targets[:, 1] > 0] 27 | num_proposals = proposals.shape[0] # [Np, 35], [pos_score, neg_score, start_y, end_y, length, x_coord, z_coord, vis] 28 | num_targets = valid_targets.shape[0] # [Nt, 35], [1, 0, start_y, end_y, length, x_coord, z_coord, vis] 29 | 30 | proposals = torch.repeat_interleave(proposals, num_targets, dim=0) # [Np * Nt, 35], [a, b] -> [a, a, b, b] 31 | valid_targets = torch.cat(num_proposals * [valid_targets]) # [Np * Nt, 35], [c, d] -> [c, d, c, d] 32 | 33 | if self.metric == 'Euclidean': 34 | distances = Euclidean_dis(proposals, valid_targets, num_proposals, num_targets, anchor_len=self.anchor_len) 35 | elif self.metric == 'Manhattan': 36 | distances = Manhattan_dis(proposals, valid_targets, num_proposals, num_targets, anchor_len=self.anchor_len) 37 | elif self.metric == 'Partial_Euclidean': 38 | distances = Partial_Euclidean_dis(proposals, valid_targets, num_proposals, num_targets, anchor_len=self.anchor_len) 39 | else: 40 | raise ValueError(f"Unsupported distance metric: {self.metric}") 41 | 42 | positives = distances.min(dim=1)[0] < self.t_pos # [Np], [True, True, False, False, ...] 43 | all_negatives = distances.min(dim=1)[0] > self.t_neg # [Np] [False, False, True, False, ...] 44 | 45 | # randomly select negatives 46 | all_neg_indices = all_negatives.nonzero().view(-1) # [Num_neg] 47 | perm = torch.randperm(all_neg_indices.shape[0]) # [Num_neg] 48 | neg_k = min(self.neg_k, len(all_neg_indices)) 49 | negative_indices = all_neg_indices[perm[:neg_k]] # [neg_k] 50 | negatives = distances.new_zeros(num_proposals).to(torch.bool) # [Np] 51 | negatives[negative_indices] = True 52 | if positives.sum() == 0: 53 | target_positives_indices = torch.tensor([], device=positives.device, dtype=torch.long) 54 | else: 55 | target_positives_indices = distances[positives].argmin(dim=1) # [N_pos] 56 | 57 | return positives, negatives, target_positives_indices 58 | -------------------------------------------------------------------------------- /mmseg/models/lane_detector/assigner/topk_fv_assigner.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Source code for Anchor3DLane 3 | # Copyright (c) 2023 TuSimple 4 | # @Time : 2023/04/05 5 | # @Author : Shaofei Huang 6 | # nowherespyfly@gmail.com 7 | # -------------------------------------------------------- 8 | 9 | import torch 10 | import numpy as np 11 | from ...builder import ASSIGNER 12 | from .distance_metric import * 13 | 14 | INFINITY = 987654.
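def _pairwise_layout_demo():
    """Editor's sketch (not part of the original file).

    Both assigners flatten the Np x Nt proposal/target comparison into a
    single batch via repeat_interleave / cat; this shows the row layout.
    """
    proposals = torch.tensor([[0.], [1.]])           # Np = 2 rows: a, b
    targets = torch.tensor([[10.], [20.], [30.]])    # Nt = 3 rows: c, d, e
    p = torch.repeat_interleave(proposals, 3, dim=0)  # a, a, a, b, b, b
    t = torch.cat(2 * [targets])                      # c, d, e, c, d, e
    # Row i pairs proposal i // Nt with target i % Nt, so a row-wise
    # distance over (p, t) can be reshaped into an [Np, Nt] matrix.
    return p, t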
15 | 16 | @ASSIGNER.register_module() 17 | class TopkFVAssigner(object): 18 | def __init__(self, pos_k=30, neg_k=100, anchor_len=10, y_steps_3d=None, w2d=0.5, w3d=0.5, **kwargs): 19 | self.pos_k = pos_k 20 | self.neg_k = neg_k 21 | self.anchor_len = anchor_len 22 | self.y_steps_3d = np.array(y_steps_3d, dtype=np.float32) 23 | self.w2d = w2d 24 | self.w3d = w3d 25 | 26 | def match_proposals_with_targets(self, proposals, targets_3d, targets_2d, P_g2im, anchor_len_2d=72): 27 | 28 | num_proposals = proposals.shape[0] # [Np, 35], [pos_score, neg_score, start_y, end_y, length, x_coord, z_coord, vis] 29 | num_targets = targets_3d.shape[0] # [Nt, 35], [1, 0, start_y, end_y, length, x_coord, z_coord, vis] 30 | proposals = torch.repeat_interleave(proposals, num_targets, dim=0) # [Np * Nt, 35], [a, b] -> [a, a, b, b] 31 | targets_3d = torch.cat(num_proposals * [targets_3d]) # [Np * Nt, 35], [c, d] -> [c, d, c, d] 32 | targets_2d = torch.cat(num_proposals * [targets_2d]) # [Np * Nt, 35], [c, d] -> [c, d, c, d] 33 | distances_fv = FV_Euclidean(proposals, targets_2d, num_proposals, num_targets, anchor_len=self.anchor_len, 34 | y_steps_3d=self.y_steps_3d, P_g2im=P_g2im, anchor_len_2d=anchor_len_2d) / 360 35 | # [Np, Nt] 36 | distances_fv = distances_fv / distances_fv[distances_fv < INFINITY].max() 37 | distances_tv = Euclidean_dis(proposals, targets_3d, num_proposals, num_targets, anchor_len=self.anchor_len) 38 | distances_tv = distances_tv / distances_tv[distances_tv < INFINITY].max() 39 | distances = distances_fv * self.w2d + distances_tv * self.w3d 40 | # in case the same anchor is assigned to more than one target 41 | min_indices = distances.min(dim=1)[1] # [Np] 42 | range_indices = torch.arange(num_proposals).long().to(distances.device) 43 | invalid_mask = distances.new_ones(num_proposals, num_targets).to(torch.bool) # [Np, Nt] 44 | invalid_mask[range_indices, min_indices] = False 45 | distances[invalid_mask] = INFINITY 46 | 47 | # select topk anchors for each gt 48 | topk_distances, topk_indices = distances.topk(self.pos_k, dim=0, largest=False) # [pos_k, Nt] 49 | all_pos_indices = topk_indices.view(-1) # [pos_k * Nt] 50 | positives = distances.new_zeros(num_proposals).to(torch.bool) # [Np] 51 | positives[all_pos_indices] = True 52 | negatives = ~positives 53 | all_neg_indices = negatives.nonzero().view(-1) # [Num_neg] 54 | perm = torch.randperm(all_neg_indices.shape[0]) # [Num_neg] 55 | negative_indices = all_neg_indices[perm[:self.neg_k]] # [neg_k] 56 | negatives = distances.new_zeros(num_proposals).to(torch.bool) # [Np] 57 | negatives[negative_indices] = True 58 | 59 | if positives.sum() == 0: 60 | target_positives_indices = torch.tensor([], device=positives.device, dtype=torch.long) 61 | else: 62 | target_positives_indices = distances[positives].argmin(dim=1) # [N_pos] 63 | 64 | return positives, negatives, target_positives_indices -------------------------------------------------------------------------------- /mmseg/models/lane_detector/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .anchor import AnchorGenerator 2 | from .nms import nms_3d -------------------------------------------------------------------------------- /mmseg/models/lane_detector/utils/anchor.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Source code for Anchor3DLane 3 | # Copyright (c) 2023 TuSimple 4 | # @Time : 2023/04/05 5 | # @Author : Shaofei Huang 6 | # 
nowherespyfly@gmail.com 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import numpy as np 11 | import json 12 | import math 13 | import torch 14 | import cv2 15 | import mmcv 16 | from mmseg.datasets.tools.utils import projection_g2im, resample_laneline_in_y 17 | 18 | class AnchorGenerator(object): 19 | """Generates 3D lane anchors in normalized coordinates.""" 20 | def __init__(self, anchor_cfg, y_steps=None, x_min=None, x_max=None, y_max=100, norm=None): 21 | self.y_steps = y_steps 22 | if self.y_steps is None: 23 | self.y_steps = np.linspace(1, y_max, y_max) 24 | self.pitches = anchor_cfg['pitches'] 25 | self.yaws = anchor_cfg['yaws'] 26 | self.num_x = anchor_cfg['num_x'] 27 | self.anchor_len = len(self.y_steps) 28 | self.x_min = x_min 29 | self.x_max = x_max 30 | self.y_max = y_max 31 | self.norm = norm 32 | self.start_z = anchor_cfg.get('start_z', 0) 33 | 34 | def generate_anchors(self): 35 | anchors = [] 36 | starts = np.linspace(self.x_min, self.x_max, num=self.num_x, dtype=np.float32) 37 | idx = 0 38 | for start_x in starts: 39 | for pitch in self.pitches: 40 | for yaw in self.yaws: 41 | anchor = self.generate_anchor(start_x, pitch, yaw, start_z=self.start_z) 42 | if anchor is not None: 43 | anchors.append(anchor) 44 | idx += 1 45 | self.anchor_num = len(anchors) 46 | print("anchor:", len(anchors)) 47 | anchors = np.array(anchors) 48 | return anchors 49 | 50 | def generate_anchor(self, start_x, pitch, yaw, start_z=0, cut=True): 51 | # anchor: [pos_score, neg_score, start_y, end_y, length, x_coords * anchor_len, z_coords * anchor_len, vis_coords * anchor_len] 52 | anchor = np.zeros(2 + 2 + 1 + self.anchor_len * 3, dtype=np.float32) 53 | pitch = pitch * math.pi / 180. # degrees to radians 54 | yaw = yaw * math.pi / 180. 55 | anchor[2] = 0 56 | anchor[3] = 1 57 | anchor[5:5+self.anchor_len] = start_x + self.y_steps * math.tan(yaw) 58 | anchor[5+self.anchor_len:5+self.anchor_len*2] = start_z + self.y_steps * math.tan(pitch) 59 | anchor_vis = np.logical_and(anchor[5:5+self.anchor_len] > self.x_min, anchor[5:5+self.anchor_len] < self.x_max) 60 | if cut: 61 | if sum(anchor_vis) / self.anchor_len < 0.5: 62 | return None 63 | return anchor 64 | 65 | -------------------------------------------------------------------------------- /mmseg/models/lane_detector/utils/nms.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Source code for Anchor3DLane 3 | # Copyright (c) 2023 TuSimple 4 | # @Time : 2023/04/05 5 | # @Author : Shaofei Huang 6 | # nowherespyfly@gmail.com 7 | # -------------------------------------------------------- 8 | 9 | import torch 10 | 11 | @torch.no_grad() 12 | def nms_3d(proposals, scores, vises, thresh, anchor_len=10): 13 | # proposals: [N, 35], scores: [N] 14 | order = scores.argsort(descending=True) 15 | keep = [] 16 | while order.shape[0] > 0: 17 | i = order[0] 18 | keep.append(i) 19 | x1 = proposals[i][5:5+anchor_len] # [l] 20 | z1 = proposals[i][5+anchor_len:5+anchor_len*2] # [l] 21 | vis1 = vises[i] # [l] 22 | x2s = proposals[order[1:]][:, 5:5+anchor_len] # [n, l] 23 | z2s = proposals[order[1:]][:, 5+anchor_len:5+anchor_len*2] # [n, l] 24 | vis2s = vises[order[1:]] # [n, l] 25 | matched = vis1 * vis2s # [n, l] 26 | lengths = matched.sum(dim=1) # [n] 27 | dis = ((x1 - x2s) ** 2 + (z1 - z2s) ** 2) ** 0.5 # [n, l] 28 | dis = (dis * matched + 1e-6).sum(dim=1) / (lengths + 1e-6) # [n], in case there are no matched points 29 | inds = torch.where(dis > thresh)[0] # [n'] 30 | order = 
order[inds + 1] # [n'] 31 | 32 | return torch.tensor(keep) 33 | 34 | if __name__ == '__main__': 35 | anchors = [] 36 | x_base = [i + 0. for i in range(10)] 37 | z_base = [i / 10. for i in range(10)] 38 | vis = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 39 | anchor1 = [0, 0.99, 0, 0, 0] + x_base + z_base + vis 40 | anchor1 = torch.tensor(anchor1) 41 | anchor2 = anchor1.clone() 42 | anchor2[1] = 0.9 43 | anchor2[5:15] = anchor2[5:15] + torch.randn(10) - 0.5 44 | anchor2[33:35] = 0 45 | anchor3 = anchor1.clone() 46 | anchor3[1] = 0.6 47 | anchor3[15:25] = anchor3[15:25] + torch.randn(10) / 10 48 | anchor3[25:28] = 0 49 | anchor4 = anchor1.clone() 50 | anchor4[1] = 0.5 51 | anchor4[5:15] = anchor4[5:15] + 10 52 | anchor4[15:25] = anchor4[15:25] + 0.1 53 | anchor4[25:34] = 0 54 | anchor5 = anchor4.clone() 55 | anchor5[1] = 0.8 56 | anchor5[25:30] = 1 57 | anchors = torch.stack([anchor2, anchor1, anchor3, anchor4, anchor5], dim=0) # [5, 35] 58 | scores = anchors[:, 1] 59 | print(anchors) 60 | keep = nms_3d(anchors, scores, anchors[:, 25:35], 1.) # pass the vis coords as `vises` 61 | import pdb 62 | pdb.set_trace() 63 | -------------------------------------------------------------------------------- /mmseg/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .accuracy import Accuracy, accuracy 3 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 4 | cross_entropy, mask_cross_entropy) 5 | from .dice_loss import DiceLoss 6 | from .focal_loss import FocalLossSigmoid 7 | from .lovasz_loss import LovaszLoss 8 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 9 | from .lane_loss import LaneLoss 10 | 11 | __all__ = [ 12 | 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', 13 | 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 14 | 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss', 15 | 'FocalLossSigmoid', 'LaneLoss' 16 | ] 17 | -------------------------------------------------------------------------------- /mmseg/models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | 5 | def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): 6 | """Calculate accuracy according to the prediction and target. 7 | 8 | Args: 9 | pred (torch.Tensor): The model prediction, shape (N, num_class, ...) 10 | target (torch.Tensor): The target of each prediction, shape (N, ...) 11 | ignore_index (int | None): The label index to be ignored. Default: None 12 | topk (int | tuple[int], optional): If the predictions in ``topk`` 13 | match the target, the predictions will be regarded as 14 | correct ones. Defaults to 1. 15 | thresh (float, optional): If not None, predictions with scores under 16 | this threshold are considered incorrect. Defaults to None. 17 | 18 | Returns: 19 | float | tuple[float]: If the input ``topk`` is a single integer, 20 | the function will return a single float as accuracy. If 21 | ``topk`` is a tuple containing multiple integers, the 22 | function will return a tuple containing accuracies of 23 | each ``topk`` number. 24 | """ 25 | assert isinstance(topk, (int, tuple)) 26 | if isinstance(topk, int): 27 | topk = (topk, ) 28 | return_single = True 29 | else: 30 | return_single = False 31 | 32 | maxk = max(topk) 33 | if pred.size(0) == 0: 34 | accu = [pred.new_tensor(0.) 
for i in range(len(topk))] 35 | return accu[0] if return_single else accu 36 | assert pred.ndim == target.ndim + 1 37 | assert pred.size(0) == target.size(0) 38 | assert maxk <= pred.size(1), \ 39 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 40 | pred_value, pred_label = pred.topk(maxk, dim=1) 41 | # transpose to shape (maxk, N, ...) 42 | pred_label = pred_label.transpose(0, 1) 43 | correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) 44 | if thresh is not None: 45 | # Only prediction values larger than thresh are counted as correct 46 | correct = correct & (pred_value > thresh).t() 47 | if ignore_index is not None: 48 | correct = correct[:, target != ignore_index] 49 | res = [] 50 | eps = torch.finfo(torch.float32).eps 51 | for k in topk: 52 | # Avoid causing ZeroDivisionError when all pixels 53 | # of an image are ignored 54 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps 55 | if ignore_index is not None: 56 | total_num = target[target != ignore_index].numel() + eps 57 | else: 58 | total_num = target.numel() + eps 59 | res.append(correct_k.mul_(100.0 / total_num)) 60 | return res[0] if return_single else res 61 | 62 | 63 | class Accuracy(nn.Module): 64 | """Accuracy calculation module.""" 65 | 66 | def __init__(self, topk=(1, ), thresh=None, ignore_index=None): 67 | """Module to calculate the accuracy. 68 | 69 | Args: 70 | topk (tuple, optional): The criterion used to calculate the 71 | accuracy. Defaults to (1,). 72 | thresh (float, optional): If not None, predictions with scores 73 | under this threshold are considered incorrect. Default to None. 74 | """ 75 | super().__init__() 76 | self.topk = topk 77 | self.thresh = thresh 78 | self.ignore_index = ignore_index 79 | 80 | def forward(self, pred, target): 81 | """Forward function to calculate accuracy. 82 | 83 | Args: 84 | pred (torch.Tensor): Prediction of models. 85 | target (torch.Tensor): Target for each prediction. 86 | 87 | Returns: 88 | tuple[float]: The accuracies under different topk criterions. 89 | """ 90 | return accuracy(pred, target, self.topk, self.thresh, 91 | self.ignore_index) 92 | -------------------------------------------------------------------------------- /mmseg/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .fpn import FPN 3 | from .ic_neck import ICNeck 4 | from .jpu import JPU 5 | from .mla_neck import MLANeck 6 | from .multilevel_neck import MultiLevelNeck 7 | 8 | __all__ = [ 9 | 'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU' 10 | ] 11 | -------------------------------------------------------------------------------- /mmseg/models/necks/multilevel_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import ConvModule, xavier_init 4 | 5 | from mmseg.ops import resize 6 | from ..builder import NECKS 7 | 8 | 9 | @NECKS.register_module() 10 | class MultiLevelNeck(nn.Module): 11 | """MultiLevelNeck. 12 | 13 | A neck structure connect vit backbone and decoder_heads. 14 | 15 | Args: 16 | in_channels (List[int]): Number of input channels per scale. 17 | out_channels (int): Number of output channels (used at each scale). 18 | scales (List[float]): Scale factors for each input feature map. 19 | Default: [0.5, 1, 2, 4] 20 | norm_cfg (dict): Config dict for normalization layer. Default: None. 
21 | act_cfg (dict): Config dict for activation layer in ConvModule. 22 | Default: None. 23 | """ 24 | 25 | def __init__(self, 26 | in_channels, 27 | out_channels, 28 | scales=[0.5, 1, 2, 4], 29 | norm_cfg=None, 30 | act_cfg=None): 31 | super(MultiLevelNeck, self).__init__() 32 | assert isinstance(in_channels, list) 33 | self.in_channels = in_channels 34 | self.out_channels = out_channels 35 | self.scales = scales 36 | self.num_outs = len(scales) 37 | self.lateral_convs = nn.ModuleList() 38 | self.convs = nn.ModuleList() 39 | for in_channel in in_channels: 40 | self.lateral_convs.append( 41 | ConvModule( 42 | in_channel, 43 | out_channels, 44 | kernel_size=1, 45 | norm_cfg=norm_cfg, 46 | act_cfg=act_cfg)) 47 | for _ in range(self.num_outs): 48 | self.convs.append( 49 | ConvModule( 50 | out_channels, 51 | out_channels, 52 | kernel_size=3, 53 | padding=1, 54 | stride=1, 55 | norm_cfg=norm_cfg, 56 | act_cfg=act_cfg)) 57 | 58 | # default init_weights for conv(msra) and norm in ConvModule 59 | def init_weights(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | xavier_init(m, distribution='uniform') 63 | 64 | def forward(self, inputs): 65 | assert len(inputs) == len(self.in_channels) 66 | inputs = [ 67 | lateral_conv(inputs[i]) 68 | for i, lateral_conv in enumerate(self.lateral_convs) 69 | ] 70 | # for len(inputs) not equal to self.num_outs 71 | if len(inputs) == 1: 72 | inputs = [inputs[0] for _ in range(self.num_outs)] 73 | outs = [] 74 | for i in range(self.num_outs): 75 | x_resize = resize( 76 | inputs[i], scale_factor=self.scales[i], mode='bilinear') 77 | outs.append(self.convs[i](x_resize)) 78 | return tuple(outs) 79 | -------------------------------------------------------------------------------- /mmseg/models/segmentors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base import BaseSegmentor 3 | from .cascade_encoder_decoder import CascadeEncoderDecoder 4 | from .encoder_decoder import EncoderDecoder 5 | 6 | __all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder'] 7 | -------------------------------------------------------------------------------- /mmseg/models/segmentors/cascade_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from torch import nn 3 | 4 | from mmseg.core import add_prefix 5 | from mmseg.ops import resize 6 | from .. import builder 7 | from ..builder import SEGMENTORS 8 | from .encoder_decoder import EncoderDecoder 9 | 10 | 11 | @SEGMENTORS.register_module() 12 | class CascadeEncoderDecoder(EncoderDecoder): 13 | """Cascade Encoder Decoder segmentors. 14 | 15 | CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of 16 | CascadeEncoderDecoder are cascaded. The output of previous decoder_head 17 | will be the input of next decoder_head. 
18 | """ 19 | 20 | def __init__(self, 21 | num_stages, 22 | backbone, 23 | decode_head, 24 | neck=None, 25 | auxiliary_head=None, 26 | train_cfg=None, 27 | test_cfg=None, 28 | pretrained=None, 29 | init_cfg=None): 30 | self.num_stages = num_stages 31 | super(CascadeEncoderDecoder, self).__init__( 32 | backbone=backbone, 33 | decode_head=decode_head, 34 | neck=neck, 35 | auxiliary_head=auxiliary_head, 36 | train_cfg=train_cfg, 37 | test_cfg=test_cfg, 38 | pretrained=pretrained, 39 | init_cfg=init_cfg) 40 | 41 | def _init_decode_head(self, decode_head): 42 | """Initialize ``decode_head``""" 43 | assert isinstance(decode_head, list) 44 | assert len(decode_head) == self.num_stages 45 | self.decode_head = nn.ModuleList() 46 | for i in range(self.num_stages): 47 | self.decode_head.append(builder.build_head(decode_head[i])) 48 | self.align_corners = self.decode_head[-1].align_corners 49 | self.num_classes = self.decode_head[-1].num_classes 50 | 51 | def encode_decode(self, img, img_metas): 52 | """Encode images with backbone and decode into a semantic segmentation 53 | map of the same size as input.""" 54 | x = self.extract_feat(img) 55 | out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg) 56 | for i in range(1, self.num_stages): 57 | out = self.decode_head[i].forward_test(x, out, img_metas, 58 | self.test_cfg) 59 | out = resize( 60 | input=out, 61 | size=img.shape[2:], 62 | mode='bilinear', 63 | align_corners=self.align_corners) 64 | return out 65 | 66 | def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): 67 | """Run forward function and calculate loss for decode head in 68 | training.""" 69 | losses = dict() 70 | 71 | loss_decode = self.decode_head[0].forward_train( 72 | x, img_metas, gt_semantic_seg, self.train_cfg) 73 | 74 | losses.update(add_prefix(loss_decode, 'decode_0')) 75 | 76 | for i in range(1, self.num_stages): 77 | # forward test again, maybe unnecessary for most methods. 78 | if i == 1: 79 | prev_outputs = self.decode_head[0].forward_test( 80 | x, img_metas, self.test_cfg) 81 | else: 82 | prev_outputs = self.decode_head[i - 1].forward_test( 83 | x, prev_outputs, img_metas, self.test_cfg) 84 | loss_decode = self.decode_head[i].forward_train( 85 | x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg) 86 | losses.update(add_prefix(loss_decode, f'decode_{i}')) 87 | 88 | return losses 89 | -------------------------------------------------------------------------------- /mmseg/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .embed import PatchEmbed 3 | from .inverted_residual import InvertedResidual, InvertedResidualV3 4 | from .make_divisible import make_divisible 5 | from .res_layer import ResLayer 6 | from .se_layer import SELayer 7 | from .self_attention_block import SelfAttentionBlock 8 | from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, 9 | nlc_to_nchw) 10 | from .up_conv_block import UpConvBlock 11 | from .misc import get_paddings_indicator 12 | from .ops.modules import MSDeformAttn 13 | 14 | __all__ = [ 15 | 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', 16 | 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', 17 | 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'get_paddings_indicator', 18 | 'MSDeformAttn' 19 | ] 20 | -------------------------------------------------------------------------------- /mmseg/models/utils/make_divisible.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | def make_divisible(value, divisor, min_value=None, min_ratio=0.9): 3 | """Make divisible function. 4 | 5 | This function rounds the channel number to the nearest value that can be 6 | divided by the divisor. It is taken from the original tf repo. It ensures 7 | that all layers have a channel number that is divisible by the divisor. It 8 | can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa 9 | 10 | Args: 11 | value (int): The original channel number. 12 | divisor (int): The divisor to fully divide the channel number. 13 | min_value (int): The minimum value of the output channel. 14 | Default: None, which means the minimum value equals the divisor. 15 | min_ratio (float): The minimum ratio of the rounded channel number to 16 | the original channel number. Default: 0.9. 17 | 18 | Returns: 19 | int: The modified output channel number. 20 | """ 21 | 22 | if min_value is None: 23 | min_value = divisor 24 | new_value = max(min_value, int(value + divisor / 2) // divisor * divisor) 25 | # Make sure that rounding down does not reduce the value by more than (1 - min_ratio). 26 | if new_value < min_ratio * value: 27 | new_value += divisor 28 | return new_value 29 | -------------------------------------------------------------------------------- /mmseg/models/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_paddings_indicator(actual_num, max_num, axis=0): 4 | """Create a boolean mask from the actual number of valid entries in a padded tensor. 
5 | 6 | Args: 7 | actual_num (torch.Tensor): Actual number of valid entries for each batch element, shape [N]. 8 | max_num (int): Size of the padded dimension. 9 | 10 | Returns: 11 | torch.Tensor: Boolean mask of shape [N, max_num], True at valid (non-padded) positions. 12 | """ 13 | 14 | actual_num = torch.unsqueeze(actual_num, axis + 1) 15 | # tiled_actual_num: [N, M, 1] 16 | max_num_shape = [1] * len(actual_num.shape) 17 | max_num_shape[axis + 1] = -1 18 | max_num = torch.arange(max_num, dtype=torch.int, device=actual_num.device).view( 19 | max_num_shape 20 | ) 21 | # tiled_actual_num: [[3,3,3,3,3], [4,4,4,4,4], [2,2,2,2,2]] 22 | # tiled_max_num: [[0,1,2,3,4], [0,1,2,3,4], [0,1,2,3,4]] 23 | paddings_indicator = actual_num.int() > max_num 24 | # paddings_indicator shape: [batch_size, max_num] 25 | return paddings_indicator -------------------------------------------------------------------------------- /mmseg/models/utils/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /mmseg/models/utils/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | # Copyright (c) Facebook, Inc. and its affiliates. 11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 12 | 13 | python setup.py build install 14 | -------------------------------------------------------------------------------- /mmseg/models/utils/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /mmseg/models/utils/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | import os 13 | import glob 14 | 15 | import torch 16 | 17 | from torch.utils.cpp_extension import CUDA_HOME 18 | from torch.utils.cpp_extension import CppExtension 19 | from torch.utils.cpp_extension import CUDAExtension 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | requirements = ["torch", "torchvision"] 25 | 26 | def get_extensions(): 27 | this_dir = os.path.dirname(os.path.abspath(__file__)) 28 | extensions_dir = os.path.join(this_dir, "src") 29 | 30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 33 | 34 | sources = main_file + source_cpu 35 | extension = CppExtension 36 | extra_compile_args = {"cxx": []} 37 | define_macros = [] 38 | 39 | # Force cuda since torch ask for a device, not if cuda is in fact available. 40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ] 50 | else: 51 | if CUDA_HOME is None: 52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') 53 | else: 54 | raise NotImplementedError('No CUDA runtime is found. 
Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') 55 | 56 | sources = [os.path.join(extensions_dir, s) for s in sources] 57 | include_dirs = [extensions_dir] 58 | ext_modules = [ 59 | extension( 60 | "MultiScaleDeformableAttention", 61 | sources, 62 | include_dirs=include_dirs, 63 | define_macros=define_macros, 64 | extra_compile_args=extra_compile_args, 65 | ) 66 | ] 67 | return ext_modules 68 | 69 | setup( 70 | name="MultiScaleDeformableAttention", 71 | version="1.0", 72 | author="Weijie Su", 73 | url="https://github.com/fundamentalvision/Deformable-DETR", 74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 75 | packages=find_packages(exclude=("configs", "tests",)), 76 | ext_modules=get_extensions(), 77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 78 | ) 79 | -------------------------------------------------------------------------------- /mmseg/models/utils/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include <vector> 17 | 18 | #include <ATen/ATen.h> 19 | #include <ATen/cuda/CUDAContext.h> 20 | 21 | 22 | at::Tensor 23 | ms_deform_attn_cpu_forward( 24 | const at::Tensor &value, 25 | const at::Tensor &spatial_shapes, 26 | const at::Tensor &level_start_index, 27 | const at::Tensor &sampling_loc, 28 | const at::Tensor &attn_weight, 29 | const int im2col_step) 30 | { 31 | AT_ERROR("Not implemented on the CPU"); 32 | } 33 | 34 | std::vector<at::Tensor> 35 | ms_deform_attn_cpu_backward( 36 | const at::Tensor &value, 37 | const at::Tensor &spatial_shapes, 38 | const at::Tensor &level_start_index, 39 | const at::Tensor &sampling_loc, 40 | const at::Tensor &attn_weight, 41 | const at::Tensor &grad_output, 42 | const int im2col_step) 43 | { 44 | AT_ERROR("Not implemented on the CPU"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /mmseg/models/utils/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include <torch/extension.h> 18 | 19 | at::Tensor 20 | ms_deform_attn_cpu_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step); 27 | 28 | std::vector<at::Tensor> 29 | ms_deform_attn_cpu_backward( 30 | const at::Tensor &value, 31 | const at::Tensor &spatial_shapes, 32 | const at::Tensor &level_start_index, 33 | const at::Tensor &sampling_loc, 34 | const at::Tensor &attn_weight, 35 | const at::Tensor &grad_output, 36 | const int im2col_step); 37 | 38 | 39 | -------------------------------------------------------------------------------- /mmseg/models/utils/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include <torch/extension.h> 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step); 26 | 27 | std::vector<at::Tensor> ms_deform_attn_cuda_backward( 28 | const at::Tensor &value, 29 | const at::Tensor &spatial_shapes, 30 | const at::Tensor &level_start_index, 31 | const at::Tensor &sampling_loc, 32 | const at::Tensor &attn_weight, 33 | const at::Tensor &grad_output, 34 | const int im2col_step); 35 | 36 | -------------------------------------------------------------------------------- /mmseg/models/utils/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | 18 | #include "cpu/ms_deform_attn_cpu.h" 19 | 20 | #ifdef WITH_CUDA 21 | #include "cuda/ms_deform_attn_cuda.h" 22 | #endif 23 | 24 | 25 | at::Tensor 26 | ms_deform_attn_forward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const int im2col_step) 33 | { 34 | if (value.type().is_cuda()) 35 | { 36 | #ifdef WITH_CUDA 37 | return ms_deform_attn_cuda_forward( 38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector<at::Tensor> 47 | ms_deform_attn_backward( 48 | const at::Tensor &value, 49 | const at::Tensor &spatial_shapes, 50 | const at::Tensor &level_start_index, 51 | const at::Tensor &sampling_loc, 52 | const at::Tensor &attn_weight, 53 | const at::Tensor &grad_output, 54 | const int im2col_step) 55 | { 56 | if (value.type().is_cuda()) 57 | { 58 | #ifdef WITH_CUDA 59 | return ms_deform_attn_cuda_backward( 60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 61 | #else 62 | AT_ERROR("Not compiled with GPU support"); 63 | #endif 64 | } 65 | AT_ERROR("Not implemented on the CPU"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /mmseg/models/utils/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include "ms_deform_attn.h" 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /mmseg/models/utils/res_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.cnn import build_conv_layer, build_norm_layer 3 | from mmcv.runner import Sequential 4 | from torch import nn as nn 5 | 6 | 7 | class ResLayer(Sequential): 8 | """ResLayer to build ResNet style backbone. 9 | 10 | Args: 11 | block (nn.Module): block used to build ResLayer. 12 | inplanes (int): inplanes of block. 13 | planes (int): planes of block. 14 | num_blocks (int): number of blocks. 15 | stride (int): stride of the first block. Default: 1 16 | avg_down (bool): Use AvgPool instead of stride conv when 17 | downsampling in the bottleneck. 
Default: False 18 | conv_cfg (dict): dictionary to construct and config conv layer. 19 | Default: None 20 | norm_cfg (dict): dictionary to construct and config norm layer. 21 | Default: dict(type='BN') 22 | multi_grid (int | None): Multi grid dilation rates of last 23 | stage. Default: None 24 | contract_dilation (bool): Whether contract first dilation of each layer 25 | Default: False 26 | """ 27 | 28 | def __init__(self, 29 | block, 30 | inplanes, 31 | planes, 32 | num_blocks, 33 | stride=1, 34 | dilation=1, 35 | avg_down=False, 36 | conv_cfg=None, 37 | norm_cfg=dict(type='BN'), 38 | multi_grid=None, 39 | contract_dilation=False, 40 | **kwargs): 41 | self.block = block 42 | 43 | downsample = None 44 | if stride != 1 or inplanes != planes * block.expansion: 45 | downsample = [] 46 | conv_stride = stride 47 | if avg_down: 48 | conv_stride = 1 49 | downsample.append( 50 | nn.AvgPool2d( 51 | kernel_size=stride, 52 | stride=stride, 53 | ceil_mode=True, 54 | count_include_pad=False)) 55 | downsample.extend([ 56 | build_conv_layer( 57 | conv_cfg, 58 | inplanes, 59 | planes * block.expansion, 60 | kernel_size=1, 61 | stride=conv_stride, 62 | bias=False), 63 | build_norm_layer(norm_cfg, planes * block.expansion)[1] 64 | ]) 65 | downsample = nn.Sequential(*downsample) 66 | 67 | layers = [] 68 | if multi_grid is None: 69 | if dilation > 1 and contract_dilation: 70 | first_dilation = dilation // 2 71 | else: 72 | first_dilation = dilation 73 | else: 74 | first_dilation = multi_grid[0] 75 | layers.append( 76 | block( 77 | inplanes=inplanes, 78 | planes=planes, 79 | stride=stride, 80 | dilation=first_dilation, 81 | downsample=downsample, 82 | conv_cfg=conv_cfg, 83 | norm_cfg=norm_cfg, 84 | **kwargs)) 85 | inplanes = planes * block.expansion 86 | for i in range(1, num_blocks): 87 | layers.append( 88 | block( 89 | inplanes=inplanes, 90 | planes=planes, 91 | stride=1, 92 | dilation=dilation if multi_grid is None else multi_grid[i], 93 | conv_cfg=conv_cfg, 94 | norm_cfg=norm_cfg, 95 | **kwargs)) 96 | super(ResLayer, self).__init__(*layers) 97 | -------------------------------------------------------------------------------- /mmseg/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .encoding import Encoding 3 | from .wrappers import Upsample, resize 4 | 5 | __all__ = ['Upsample', 'resize', 'Encoding'] 6 | -------------------------------------------------------------------------------- /mmseg/ops/encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | 7 | class Encoding(nn.Module): 8 | """Encoding Layer: a learnable residual encoder. 9 | 10 | Input is of shape (batch_size, channels, height, width). 11 | Output is of shape (batch_size, num_codes, channels). 12 | 13 | Args: 14 | channels: dimension of the features or feature channels 15 | num_codes: number of code words 16 | """ 17 | 18 | def __init__(self, channels, num_codes): 19 | super(Encoding, self).__init__() 20 | # init codewords and smoothing factor 21 | self.channels, self.num_codes = channels, num_codes 22 | std = 1. 
/ ((num_codes * channels)**0.5) 23 | # [num_codes, channels] 24 | self.codewords = nn.Parameter( 25 | torch.empty(num_codes, channels, 26 | dtype=torch.float).uniform_(-std, std), 27 | requires_grad=True) 28 | # [num_codes] 29 | self.scale = nn.Parameter( 30 | torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), 31 | requires_grad=True) 32 | 33 | @staticmethod 34 | def scaled_l2(x, codewords, scale): 35 | num_codes, channels = codewords.size() 36 | batch_size = x.size(0) 37 | reshaped_scale = scale.view((1, 1, num_codes)) 38 | expanded_x = x.unsqueeze(2).expand( 39 | (batch_size, x.size(1), num_codes, channels)) 40 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 41 | 42 | scaled_l2_norm = reshaped_scale * ( 43 | expanded_x - reshaped_codewords).pow(2).sum(dim=3) 44 | return scaled_l2_norm 45 | 46 | @staticmethod 47 | def aggregate(assignment_weights, x, codewords): 48 | num_codes, channels = codewords.size() 49 | reshaped_codewords = codewords.view((1, 1, num_codes, channels)) 50 | batch_size = x.size(0) 51 | 52 | expanded_x = x.unsqueeze(2).expand( 53 | (batch_size, x.size(1), num_codes, channels)) 54 | encoded_feat = (assignment_weights.unsqueeze(3) * 55 | (expanded_x - reshaped_codewords)).sum(dim=1) 56 | return encoded_feat 57 | 58 | def forward(self, x): 59 | assert x.dim() == 4 and x.size(1) == self.channels 60 | # [batch_size, channels, height, width] 61 | batch_size = x.size(0) 62 | # [batch_size, height x width, channels] 63 | x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() 64 | # assignment_weights: [batch_size, channels, num_codes] 65 | assignment_weights = F.softmax( 66 | self.scaled_l2(x, self.codewords, self.scale), dim=2) 67 | # aggregate 68 | encoded_feat = self.aggregate(assignment_weights, x, self.codewords) 69 | return encoded_feat 70 | 71 | def __repr__(self): 72 | repr_str = self.__class__.__name__ 73 | repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ 74 | f'x{self.channels})' 75 | return repr_str 76 | -------------------------------------------------------------------------------- /mmseg/ops/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
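# Editor's note (illustrative shape check for the Encoding layer defined in
# encoding.py above; not part of this file):
#
#   >>> import torch
#   >>> from mmseg.ops import Encoding
#   >>> enc = Encoding(channels=64, num_codes=32)
#   >>> enc(torch.randn(2, 64, 16, 16)).shape
#   torch.Size([2, 32, 64])
#
# i.e. each of the 32 codewords aggregates soft-assigned residuals over all
# 16 x 16 spatial positions, independently of the input resolution.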
2 | import warnings 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def resize(input, 9 | size=None, 10 | scale_factor=None, 11 | mode='nearest', 12 | align_corners=None, 13 | warning=True): 14 | if warning: 15 | if size is not None and align_corners: 16 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 17 | output_h, output_w = tuple(int(x) for x in size) 18 | if output_h > input_h or output_w > input_w: 19 | if ((output_h > 1 and output_w > 1 and input_h > 1 20 | and input_w > 1) and (output_h - 1) % (input_h - 1) 21 | and (output_w - 1) % (input_w - 1)): 22 | warnings.warn( 23 | f'When align_corners={align_corners}, ' 24 | 'the output would be more aligned if ' 25 | f'input size {(input_h, input_w)} is `x+1` and ' 26 | f'out size {(output_h, output_w)} is `nx+1`') 27 | return F.interpolate(input, size, scale_factor, mode, align_corners) 28 | 29 | 30 | class Upsample(nn.Module): 31 | 32 | def __init__(self, 33 | size=None, 34 | scale_factor=None, 35 | mode='nearest', 36 | align_corners=None): 37 | super(Upsample, self).__init__() 38 | self.size = size 39 | if isinstance(scale_factor, tuple): 40 | self.scale_factor = tuple(float(factor) for factor in scale_factor) 41 | else: 42 | self.scale_factor = float(scale_factor) if scale_factor else None 43 | self.mode = mode 44 | self.align_corners = align_corners 45 | 46 | def forward(self, x): 47 | if not self.size: 48 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 49 | else: 50 | size = self.size 51 | return resize(x, size, None, self.mode, self.align_corners) 52 | -------------------------------------------------------------------------------- /mmseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .collect_env import collect_env 3 | from .logger import get_root_logger 4 | from .misc import find_latest_checkpoint 5 | from .set_env import setup_multi_processes 6 | from .util_distribution import build_ddp, build_dp, get_device 7 | from .avs_metric import Eval_Fmeasure, mask_iou 8 | 9 | __all__ = [ 10 | 'get_root_logger', 'collect_env', 'find_latest_checkpoint', 11 | 'setup_multi_processes', 'build_ddp', 'build_dp', 'get_device', 'mask_iou', 'Eval_Fmeasure' 12 | ] 13 | -------------------------------------------------------------------------------- /mmseg/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import collect_env as collect_base_env 3 | from mmcv.utils import get_git_hash 4 | 5 | import mmseg 6 | 7 | 8 | def collect_env(): 9 | """Collect the information of the running environments.""" 10 | env_info = collect_base_env() 11 | env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' 12 | 13 | return env_info 14 | 15 | 16 | if __name__ == '__main__': 17 | for name, val in collect_env().items(): 18 | print('{}: {}'.format(name, val)) 19 | -------------------------------------------------------------------------------- /mmseg/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | from mmcv.utils import get_logger 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO): 8 | """Get the root logger. 9 | 10 | The logger will be initialized if it has not been initialized. 
By default a 11 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 12 | also be added. The name of the root logger is the top-level package name, 13 | e.g., "mmseg". 14 | 15 | Args: 16 | log_file (str | None): The log filename. If specified, a FileHandler 17 | will be added to the root logger. 18 | log_level (int): The root logger level. Note that only the process of 19 | rank 0 is affected, while other processes will set the level to 20 | "Error" and be silent most of the time. 21 | 22 | Returns: 23 | logging.Logger: The root logger. 24 | """ 25 | 26 | logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) 27 | 28 | return logger 29 | -------------------------------------------------------------------------------- /mmseg/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import glob 3 | import os.path as osp 4 | import warnings 5 | 6 | 7 | def find_latest_checkpoint(path, suffix='pth'): 8 | """Find the latest checkpoint in a directory. 9 | 10 | It is used when resuming automatically; modified from 11 | https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py 12 | 13 | Args: 14 | path (str): The path to find checkpoints. 15 | suffix (str): File extension for the checkpoint. Defaults to pth. 16 | 17 | Returns: 18 | latest_path (str | None): File path of the latest checkpoint. 19 | """ 20 | if not osp.exists(path): 21 | warnings.warn("The path of the checkpoints doesn't exist.") 22 | return None 23 | if osp.exists(osp.join(path, f'latest.{suffix}')): 24 | return osp.join(path, f'latest.{suffix}') 25 | 26 | checkpoints = glob.glob(osp.join(path, f'*.{suffix}')) 27 | if len(checkpoints) == 0: 28 | warnings.warn('There are no checkpoints in the path.') 29 | return None 30 | latest = -1 31 | latest_path = '' 32 | for checkpoint in checkpoints: 33 | if len(checkpoint) < len(latest_path): 34 | continue 35 | # `count` is the iteration number, as checkpoints are saved as 36 | # 'iter_xx.pth' or 'epoch_xx.pth' and xx is the iteration number. 37 | count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0]) 38 | if count > latest: 39 | latest = count 40 | latest_path = checkpoint 41 | return latest_path 42 | -------------------------------------------------------------------------------- /mmseg/utils/set_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import platform 4 | 5 | import cv2 6 | import torch.multiprocessing as mp 7 | 8 | from ..utils import get_root_logger 9 | 10 | 11 | def setup_multi_processes(cfg): 12 | """Setup multi-processing environment variables.""" 13 | logger = get_root_logger() 14 | 15 | # set multi-process start method 16 | if platform.system() != 'Windows': 17 | mp_start_method = cfg.get('mp_start_method', None) 18 | current_method = mp.get_start_method(allow_none=True) 19 | if mp_start_method in ('fork', 'spawn', 'forkserver'): 20 | logger.info( 21 | f'Multi-processing start method `{mp_start_method}` is ' 22 | f'different from the previous setting `{current_method}`.' 
23 | f'It will be forcibly set to `{mp_start_method}`.') 24 | mp.set_start_method(mp_start_method, force=True) 25 | else: 26 | logger.info( 27 | f'Multi-processing start method is `{mp_start_method}`') 28 | 29 | # disable opencv multithreading to avoid system being overloaded 30 | opencv_num_threads = cfg.get('opencv_num_threads', None) 31 | if isinstance(opencv_num_threads, int): 32 | logger.info(f'OpenCV num_threads is `{opencv_num_threads}`') 33 | cv2.setNumThreads(opencv_num_threads) 34 | else: 35 | logger.info(f'OpenCV num_threads is `{cv2.getNumThreads()}`') 36 | 37 | if cfg.data.workers_per_gpu > 1: 38 | # setup OMP threads 39 | # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa 40 | omp_num_threads = cfg.get('omp_num_threads', None) 41 | if 'OMP_NUM_THREADS' not in os.environ: 42 | if isinstance(omp_num_threads, int): 43 | logger.info(f'OMP num threads is {omp_num_threads}') 44 | os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) 45 | else: 46 | logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"]}') 47 | 48 | # setup MKL threads 49 | if 'MKL_NUM_THREADS' not in os.environ: 50 | mkl_num_threads = cfg.get('mkl_num_threads', None) 51 | if isinstance(mkl_num_threads, int): 52 | logger.info(f'MKL num threads is {mkl_num_threads}') 53 | os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) 54 | else: 55 | logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}') 56 | -------------------------------------------------------------------------------- /mmseg/utils/util_distribution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | import torch 4 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 5 | 6 | from mmseg import digit_version 7 | 8 | dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel} 9 | 10 | ddp_factory = {'cuda': MMDistributedDataParallel} 11 | 12 | 13 | def build_dp(model, device='cuda', dim=0, *args, **kwargs): 14 | """Build a DataParallel module by device type. 15 | 16 | If device is cuda, return a MMDataParallel module; if device is mlu, 17 | return a MLUDataParallel module. 18 | 19 | Args: 20 | model (:class:`nn.Module`): module to be parallelized. 21 | device (str): device type, cuda, cpu or mlu. Defaults to cuda. 22 | dim (int): Dimension used to scatter the data. Defaults to 0. 23 | 24 | Returns: 25 | :class:`nn.Module`: parallelized module. 26 | """ 27 | if device == 'cuda': 28 | model = model.cuda() 29 | elif device == 'mlu': 30 | assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ 31 | 'Please use MMCV >= 1.5.0 for MLU training!' 32 | from mmcv.device.mlu import MLUDataParallel 33 | dp_factory['mlu'] = MLUDataParallel 34 | model = model.mlu() 35 | 36 | return dp_factory[device](model, dim=dim, *args, **kwargs) 37 | 38 | 39 | def build_ddp(model, device='cuda', *args, **kwargs): 40 | """Build a DistributedDataParallel module by device type. 41 | 42 | If device is cuda, return a MMDistributedDataParallel module; 43 | if device is mlu, return a MLUDistributedDataParallel module. 44 | 45 | Args: 46 | model (:class:`nn.Module`): module to be parallelized. 47 | device (str): device type, mlu or cuda. 48 | 49 | Returns: 50 | :class:`nn.Module`: parallelized module. 51 | 52 | References: 53 | .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel. 
54 | DistributedDataParallel.html 55 | """ 56 | assert device in ['cuda', 'mlu'], 'Only available for cuda or mlu devices.' 57 | if device == 'cuda': 58 | model = model.cuda() 59 | elif device == 'mlu': 60 | assert digit_version(mmcv.__version__) >= digit_version('1.5.0'), \ 61 | 'Please use MMCV >= 1.5.0 for MLU training!' 62 | from mmcv.device.mlu import MLUDistributedDataParallel 63 | ddp_factory['mlu'] = MLUDistributedDataParallel 64 | model = model.mlu() 65 | 66 | return ddp_factory[device](model, *args, **kwargs) 67 | 68 | 69 | def is_mlu_available(): 70 | """Returns a bool indicating if MLU is currently available.""" 71 | return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available() 72 | 73 | 74 | def get_device(): 75 | """Returns an available device, cpu, cuda or mlu.""" 76 | is_device_available = { 77 | 'cuda': torch.cuda.is_available(), 78 | 'mlu': is_mlu_available() 79 | } 80 | device_list = [k for k, v in is_device_available.items() if v] 81 | return device_list[0] if len(device_list) == 1 else 'cpu' 82 | -------------------------------------------------------------------------------- /mmseg/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | 3 | __version__ = '0.26.0' 4 | 5 | 6 | def parse_version_info(version_str): 7 | version_info = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | version_info.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | version_info.append(int(patch_version[0])) 14 | version_info.append(f'rc{patch_version[1]}') 15 | return tuple(version_info) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.5.1 2 | munkres==1.1.4 3 | numpy==1.22.3 4 | opencv_python==4.1.2.30 5 | ortools==9.3.10497 6 | packaging==21.3 7 | pandas==1.5.2 8 | Pillow==9.5.0 9 | prettytable==3.4.1 10 | scipy==1.4.1 11 | Shapely==1.8.4 12 | tabulate==0.8.10 13 | tensorflow==2.4.0 14 | tqdm==4.64.0 15 | ujson==5.5.0 16 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [yapf] 2 | based_on_style = pep8 3 | blank_line_before_nested_class_or_def = true 4 | split_before_expression_after_opening_paren = true 5 | 6 | [isort] 7 | line_length = 79 8 | multi_line_output = 0 9 | extra_standard_library = setuptools 10 | known_first_party = mmseg 11 | known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,packaging,prettytable,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,torch,ts 12 | no_lines_before = STDLIB,LOCALFOLDER 13 | default_section = THIRDPARTY 14 | 15 | # ignore-words-list needs to be lowercase format. For example, if we want to 16 | # ignore word "BA", then we need to append "ba" to ignore-words-list rather 17 | # than "BA" 18 | [codespell] 19 | skip = *.po,*.ts,*.ipynb 20 | count = 21 | quiet-level = 3 22 | ignore-words-list = formating,sur,hist,dota,ba 23 | -------------------------------------------------------------------------------- /tools/convert_datasets/chase_db1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
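# A hedged usage sketch (paths are hypothetical), kept here as a comment:
#     python tools/convert_datasets/chase_db1.py /downloads/CHASEDB1.zip \
#         --tmp_dir /tmp -o data/CHASE_DB1
# The script unpacks the 84 files (28 images x 3 annotation variants), sends
# the first 60 sorted files to training and the rest to validation, and
# binarizes annotation masks with the >=128 threshold implemented below.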
2 | import argparse 3 | import os 4 | import os.path as osp 5 | import tempfile 6 | import zipfile 7 | 8 | import mmcv 9 | 10 | CHASE_DB1_LEN = 28 * 3 11 | TRAINING_LEN = 60 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert CHASE_DB1 dataset to mmsegmentation format') 17 | parser.add_argument('dataset_path', help='path of CHASEDB1.zip') 18 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 19 | parser.add_argument('-o', '--out_dir', help='output path') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def main(): 25 | args = parse_args() 26 | dataset_path = args.dataset_path 27 | if args.out_dir is None: 28 | out_dir = osp.join('data', 'CHASE_DB1') 29 | else: 30 | out_dir = args.out_dir 31 | 32 | print('Making directories...') 33 | mmcv.mkdir_or_exist(out_dir) 34 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 35 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 36 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 37 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 38 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 39 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 40 | 41 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 42 | print('Extracting CHASEDB1.zip...') 43 | zip_file = zipfile.ZipFile(dataset_path) 44 | zip_file.extractall(tmp_dir) 45 | 46 | print('Generating training dataset...') 47 | 48 | assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \ 49 | 'len(os.listdir(tmp_dir)) != {}'.format(CHASE_DB1_LEN) 50 | 51 | for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: 52 | img = mmcv.imread(osp.join(tmp_dir, img_name)) 53 | if osp.splitext(img_name)[1] == '.jpg': 54 | mmcv.imwrite( 55 | img, 56 | osp.join(out_dir, 'images', 'training', 57 | osp.splitext(img_name)[0] + '.png')) 58 | else: 59 | # The annotation img should be divided by 128, because some of 60 | # the annotation imgs are not standard. We should set a 61 | # threshold to convert the nonstandard annotation imgs. The 62 | # value divided by 128 is equivalent to '1 if value >= 128 63 | # else 0' 64 | mmcv.imwrite( 65 | img[:, :, 0] // 128, 66 | osp.join(out_dir, 'annotations', 'training', 67 | osp.splitext(img_name)[0] + '.png')) 68 | 69 | for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: 70 | img = mmcv.imread(osp.join(tmp_dir, img_name)) 71 | if osp.splitext(img_name)[1] == '.jpg': 72 | mmcv.imwrite( 73 | img, 74 | osp.join(out_dir, 'images', 'validation', 75 | osp.splitext(img_name)[0] + '.png')) 76 | else: 77 | mmcv.imwrite( 78 | img[:, :, 0] // 128, 79 | osp.join(out_dir, 'annotations', 'validation', 80 | osp.splitext(img_name)[0] + '.png')) 81 | 82 | print('Removing the temporary files...') 83 | 84 | print('Done!') 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /tools/convert_datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
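# A hedged usage sketch (paths hypothetical): converts *_gtFine_polygons.json
# annotations to *_labelTrainIds.png and writes train/val/test filename lists:
#     python tools/convert_datasets/cityscapes.py data/cityscapes --nproc 8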
2 | import argparse 3 | import os.path as osp 4 | 5 | import mmcv 6 | from cityscapesscripts.preparation.json2labelImg import json2labelImg 7 | 8 | 9 | def convert_json_to_label(json_file): 10 | label_file = json_file.replace('_polygons.json', '_labelTrainIds.png') 11 | json2labelImg(json_file, label_file, 'trainIds') 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert Cityscapes annotations to TrainIds') 17 | parser.add_argument('cityscapes_path', help='cityscapes data path') 18 | parser.add_argument('--gt-dir', default='gtFine', type=str) 19 | parser.add_argument('-o', '--out-dir', help='output path') 20 | parser.add_argument( 21 | '--nproc', default=1, type=int, help='number of processes') 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | def main(): 27 | args = parse_args() 28 | cityscapes_path = args.cityscapes_path 29 | out_dir = args.out_dir if args.out_dir else cityscapes_path 30 | mmcv.mkdir_or_exist(out_dir) 31 | 32 | gt_dir = osp.join(cityscapes_path, args.gt_dir) 33 | 34 | poly_files = [] 35 | for poly in mmcv.scandir(gt_dir, '_polygons.json', recursive=True): 36 | poly_file = osp.join(gt_dir, poly) 37 | poly_files.append(poly_file) 38 | if args.nproc > 1: 39 | mmcv.track_parallel_progress(convert_json_to_label, poly_files, 40 | args.nproc) 41 | else: 42 | mmcv.track_progress(convert_json_to_label, poly_files) 43 | 44 | split_names = ['train', 'val', 'test'] 45 | 46 | for split in split_names: 47 | filenames = [] 48 | for poly in mmcv.scandir( 49 | osp.join(gt_dir, split), '_polygons.json', recursive=True): 50 | filenames.append(poly.replace('_gtFine_polygons.json', '')) 51 | with open(osp.join(out_dir, f'{split}.txt'), 'w') as f: 52 | f.writelines(name + '\n' for name in filenames) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /tools/convert_datasets/loveda.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
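# A hedged usage sketch (paths hypothetical); the source folder is expected to
# contain Train.zip, Val.zip and Test.zip as released by the LoveDA authors:
#     python tools/convert_datasets/loveda.py /downloads/LoveDA -o data/loveDA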
2 | import argparse 3 | import os 4 | import os.path as osp 5 | import shutil 6 | import tempfile 7 | import zipfile 8 | 9 | import mmcv 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser( 14 | description='Convert LoveDA dataset to mmsegmentation format') 15 | parser.add_argument('dataset_path', help='LoveDA folder path') 16 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 17 | parser.add_argument('-o', '--out_dir', help='output path') 18 | args = parser.parse_args() 19 | return args 20 | 21 | 22 | def main(): 23 | args = parse_args() 24 | dataset_path = args.dataset_path 25 | if args.out_dir is None: 26 | out_dir = osp.join('data', 'loveDA') 27 | else: 28 | out_dir = args.out_dir 29 | 30 | print('Making directories...') 31 | mmcv.mkdir_or_exist(out_dir) 32 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir')) 33 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) 34 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) 35 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test')) 36 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir')) 37 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) 38 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) 39 | 40 | assert 'Train.zip' in os.listdir(dataset_path), \ 41 | 'Train.zip is not in {}'.format(dataset_path) 42 | assert 'Val.zip' in os.listdir(dataset_path), \ 43 | 'Val.zip is not in {}'.format(dataset_path) 44 | assert 'Test.zip' in os.listdir(dataset_path), \ 45 | 'Test.zip is not in {}'.format(dataset_path) 46 | 47 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 48 | for dataset in ['Train', 'Val', 'Test']: 49 | zip_file = zipfile.ZipFile( 50 | os.path.join(dataset_path, dataset + '.zip')) 51 | zip_file.extractall(tmp_dir) 52 | data_type = dataset.lower() 53 | for location in ['Rural', 'Urban']: 54 | for image_type in ['images_png', 'masks_png']: 55 | if image_type == 'images_png': 56 | dst = osp.join(out_dir, 'img_dir', data_type) 57 | else: 58 | dst = osp.join(out_dir, 'ann_dir', data_type) 59 | if dataset == 'Test' and image_type == 'masks_png': 60 | continue 61 | else: 62 | src_dir = osp.join(tmp_dir, dataset, location, 63 | image_type) 64 | src_lst = os.listdir(src_dir) 65 | for file in src_lst: 66 | shutil.move(osp.join(src_dir, file), dst) 67 | print('Removing the temporary files...') 68 | 69 | print('Done!') 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /tools/convert_datasets/pascal_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
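# A hedged usage sketch (paths hypothetical); the Detail API
# (https://github.com/zhanghang1989/detail-api) must be installed, and the
# json file is the merged trainval annotation of PASCAL Context:
#     python tools/convert_datasets/pascal_context.py data/VOCdevkit \
#         data/VOCdevkit/VOC2010/trainval_merged.json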
2 | import argparse 3 | import os.path as osp 4 | from functools import partial 5 | 6 | import mmcv 7 | import numpy as np 8 | from detail import Detail 9 | from PIL import Image 10 | 11 | _mapping = np.sort( 12 | np.array([ 13 | 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284, 14 | 158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59, 15 | 440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355, 16 | 85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115 17 | ])) 18 | _key = np.array(range(len(_mapping))).astype('uint8') 19 | 20 | 21 | def generate_labels(img_id, detail, out_dir): 22 | 23 | def _class_to_index(mask, _mapping, _key): 24 | # assert that every value present in the mask is in the mapping 25 | values = np.unique(mask) 26 | for i in range(len(values)): 27 | assert (values[i] in _mapping) 28 | index = np.digitize(mask.ravel(), _mapping, right=True) 29 | return _key[index].reshape(mask.shape) 30 | 31 | mask = Image.fromarray( 32 | _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key)) 33 | filename = img_id['file_name'] 34 | mask.save(osp.join(out_dir, filename.replace('jpg', 'png'))) 35 | return osp.splitext(osp.basename(filename))[0] 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser( 40 | description='Convert PASCAL Context annotations to mmsegmentation format') 41 | parser.add_argument('devkit_path', help='pascal voc devkit path') 42 | parser.add_argument('json_path', help='annotation json filepath') 43 | parser.add_argument('-o', '--out_dir', help='output path') 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def main(): 49 | args = parse_args() 50 | devkit_path = args.devkit_path 51 | if args.out_dir is None: 52 | out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext') 53 | else: 54 | out_dir = args.out_dir 55 | json_path = args.json_path 56 | mmcv.mkdir_or_exist(out_dir) 57 | img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages') 58 | 59 | train_detail = Detail(json_path, img_dir, 'train') 60 | train_ids = train_detail.getImgs() 61 | 62 | val_detail = Detail(json_path, img_dir, 'val') 63 | val_ids = val_detail.getImgs() 64 | 65 | mmcv.mkdir_or_exist( 66 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext')) 67 | 68 | train_list = mmcv.track_progress( 69 | partial(generate_labels, detail=train_detail, out_dir=out_dir), 70 | train_ids) 71 | with open( 72 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', 73 | 'train.txt'), 'w') as f: 74 | f.writelines(line + '\n' for line in sorted(train_list)) 75 | 76 | val_list = mmcv.track_progress( 77 | partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids) 78 | with open( 79 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', 80 | 'val.txt'), 'w') as f: 81 | f.writelines(line + '\n' for line in sorted(val_list)) 82 | 83 | print('Done!') 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /tools/convert_datasets/voc_aug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
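# A hedged usage sketch (paths hypothetical); aug_path points at the unpacked
# SBD augmentation release, whose dataset/cls/*.mat masks are converted and
# merged with the original VOC2012 split lists:
#     python tools/convert_datasets/voc_aug.py data/VOCdevkit \
#         data/VOCdevkit/VOC_aug --nproc 8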
2 | import argparse 3 | import os.path as osp 4 | from functools import partial 5 | 6 | import mmcv 7 | import numpy as np 8 | from PIL import Image 9 | from scipy.io import loadmat 10 | 11 | AUG_LEN = 10582 12 | 13 | 14 | def convert_mat(mat_file, in_dir, out_dir): 15 | data = loadmat(osp.join(in_dir, mat_file)) 16 | mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8) 17 | seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png')) 18 | Image.fromarray(mask).save(seg_filename, 'PNG') 19 | 20 | 21 | def generate_aug_list(merged_list, excluded_list): 22 | return list(set(merged_list) - set(excluded_list)) 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser( 27 | description='Convert PASCAL VOC annotations to mmsegmentation format') 28 | parser.add_argument('devkit_path', help='pascal voc devkit path') 29 | parser.add_argument('aug_path', help='pascal voc aug path') 30 | parser.add_argument('-o', '--out_dir', help='output path') 31 | parser.add_argument( 32 | '--nproc', default=1, type=int, help='number of processes') 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def main(): 38 | args = parse_args() 39 | devkit_path = args.devkit_path 40 | aug_path = args.aug_path 41 | nproc = args.nproc 42 | if args.out_dir is None: 43 | out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug') 44 | else: 45 | out_dir = args.out_dir 46 | mmcv.mkdir_or_exist(out_dir) 47 | in_dir = osp.join(aug_path, 'dataset', 'cls') 48 | 49 | mmcv.track_parallel_progress( 50 | partial(convert_mat, in_dir=in_dir, out_dir=out_dir), 51 | list(mmcv.scandir(in_dir, suffix='.mat')), 52 | nproc=nproc) 53 | 54 | full_aug_list = [] 55 | with open(osp.join(aug_path, 'dataset', 'train.txt')) as f: 56 | full_aug_list += [line.strip() for line in f] 57 | with open(osp.join(aug_path, 'dataset', 'val.txt')) as f: 58 | full_aug_list += [line.strip() for line in f] 59 | 60 | with open( 61 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 62 | 'train.txt')) as f: 63 | ori_train_list = [line.strip() for line in f] 64 | with open( 65 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 66 | 'val.txt')) as f: 67 | val_list = [line.strip() for line in f] 68 | 69 | aug_train_list = generate_aug_list(ori_train_list + full_aug_list, 70 | val_list) 71 | assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format( 72 | AUG_LEN) 73 | 74 | with open( 75 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 76 | 'trainaug.txt'), 'w') as f: 77 | f.writelines(line + '\n' for line in aug_train_list) 78 | 79 | aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list) 80 | assert len(aug_list) == AUG_LEN - len( 81 | ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN - 82 | len(ori_train_list)) 83 | with open( 84 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'), 85 | 'w') as f: 86 | f.writelines(line + '\n' for line in aug_list) 87 | 88 | print('Done!') 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | CHECKPOINT=$2 3 | GPUS=$3 4 | NNODES=${NNODES:-1} 5 | NODE_RANK=${NODE_RANK:-0} 6 | PORT=${PORT:-29500} 7 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 8 | 9 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 10 | python -m torch.distributed.launch \ 11 | --nnodes=$NNODES \ 12 | --node_rank=$NODE_RANK \ 13 | --master_addr=$MASTER_ADDR \ 14 
| --nproc_per_node=$GPUS \ 15 | --master_port=$PORT \ 16 | $(dirname "$0")/test.py \ 17 | $CONFIG \ 18 | $CHECKPOINT \ 19 | --launcher pytorch \ 20 | ${@:4} 21 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | GPUS=$2 3 | NNODES=${NNODES:-1} 4 | NODE_RANK=${NODE_RANK:-0} 5 | PORT=${PORT:-49500} 6 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python -m torch.distributed.launch \ 10 | --nnodes=$NNODES \ 11 | --node_rank=$NODE_RANK \ 12 | --master_addr=$MASTER_ADDR \ 13 | --nproc_per_node=$GPUS \ 14 | --master_port=$PORT \ 15 | $(dirname "$0")/train_dist.py \ 16 | $CONFIG \ 17 | --launcher pytorch ${@:3} 18 | -------------------------------------------------------------------------------- /tools/dist_train_multinode.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | GPUS=$2 3 | NNODES=${NNODES:-2} 4 | NODE_RANK=$3 5 | PORT=${PORT:-12355} 6 | MASTER_ADDR=${MASTER_ADDR:-"10.12.1.84"} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python -m torch.distributed.launch \ 10 | --nnodes=$NNODES \ 11 | --node_rank=$NODE_RANK \ 12 | --master_addr=$MASTER_ADDR \ 13 | --nproc_per_node=$GPUS \ 14 | --master_port=$PORT \ 15 | $(dirname "$0")/train_dist.py \ 16 | $CONFIG \ 17 | --launcher pytorch ${@:4} 18 | -------------------------------------------------------------------------------- /tools/get_flops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | 4 | from mmcv import Config 5 | from mmcv.cnn import get_model_complexity_info 6 | 7 | from mmseg.models import build_segmentor 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser( 12 | description='Get the FLOPs of a segmentor') 13 | parser.add_argument('config', help='train config file path') 14 | parser.add_argument( 15 | '--shape', 16 | type=int, 17 | nargs='+', 18 | default=[360, 480], 19 | help='input image size') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def main(): 25 | 26 | args = parse_args() 27 | 28 | if len(args.shape) == 1: 29 | input_shape = (3, args.shape[0], args.shape[0]) 30 | elif len(args.shape) == 2: 31 | input_shape = (3, ) + tuple(args.shape) 32 | else: 33 | raise ValueError('invalid input shape') 34 | 35 | cfg = Config.fromfile(args.config) 36 | cfg.model.pretrained = None 37 | model = build_segmentor( 38 | cfg.model, 39 | train_cfg=cfg.get('train_cfg'), 40 | test_cfg=cfg.get('test_cfg')).cuda() 41 | model.eval() 42 | 43 | if hasattr(model, 'forward_dummy'): 44 | model.forward = model.forward_dummy 45 | else: 46 | raise NotImplementedError( 47 | 'FLOPs counter is currently not supported with {}'. 48 | format(model.__class__.__name__)) 49 | 50 | flops, params = get_model_complexity_info(model, input_shape) 51 | split_line = '=' * 30 52 | print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( 53 | split_line, input_shape, flops, params)) 54 | print('!!!Please be cautious if you use the results in papers. 
' 55 | 'You may need to check if all ops are supported and verify that the ' 56 | 'flops computation is correct.') 57 | 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /tools/model_converters/beit2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_beit(ckpt): 12 | new_ckpt = OrderedDict() 13 | 14 | for k, v in ckpt.items(): 15 | if k.startswith('blocks'): 16 | new_key = k.replace('blocks', 'layers') 17 | if 'norm' in new_key: 18 | new_key = new_key.replace('norm', 'ln') 19 | elif 'mlp.fc1' in new_key: 20 | new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0') 21 | elif 'mlp.fc2' in new_key: 22 | new_key = new_key.replace('mlp.fc2', 'ffn.layers.1') 23 | new_ckpt[new_key] = v 24 | elif k.startswith('patch_embed'): 25 | new_key = k.replace('patch_embed.proj', 'patch_embed.projection') 26 | new_ckpt[new_key] = v 27 | else: 28 | new_key = k 29 | new_ckpt[new_key] = v 30 | 31 | return new_ckpt 32 | 33 | 34 | def main(): 35 | parser = argparse.ArgumentParser( 36 | description='Convert keys in official pretrained beit models to ' 37 | 'MMSegmentation style.') 38 | parser.add_argument('src', help='src model path or url') 39 | # The dst path must be a full path of the new checkpoint. 40 | parser.add_argument('dst', help='save path') 41 | args = parser.parse_args() 42 | 43 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 44 | if 'state_dict' in checkpoint: 45 | state_dict = checkpoint['state_dict'] 46 | elif 'model' in checkpoint: 47 | state_dict = checkpoint['model'] 48 | else: 49 | state_dict = checkpoint 50 | weight = convert_beit(state_dict) 51 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 52 | torch.save(weight, args.dst) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /tools/model_converters/mit2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_mit(ckpt): 12 | new_ckpt = OrderedDict() 13 | # The q and kv linear weights are concatenated below into in_proj weights 14 | for k, v in ckpt.items(): 15 | if k.startswith('head'): 16 | continue 17 | # patch embedding conversion 18 | elif k.startswith('patch_embed'): 19 | stage_i = int(k.split('.')[0].replace('patch_embed', '')) 20 | new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0') 21 | new_v = v 22 | if 'proj.' in new_k: 23 | new_k = new_k.replace('proj.', 'projection.') 24 | # transformer encoder layer conversion 25 | elif k.startswith('block'): 26 | stage_i = int(k.split('.')[0].replace('block', '')) 27 | new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1') 28 | new_v = v 29 | if 'attn.q.' in new_k: 30 | sub_item_k = k.replace('q.', 'kv.') 31 | new_k = new_k.replace('q.', 'attn.in_proj_') 32 | new_v = torch.cat([v, ckpt[sub_item_k]], dim=0) 33 | elif 'attn.kv.' in new_k: 34 | continue 35 | elif 'attn.proj.' 
in new_k: 36 | new_k = new_k.replace('proj.', 'attn.out_proj.') 37 | elif 'attn.sr.' in new_k: 38 | pass  # spatial-reduction (sr) keys already use the expected names 39 | elif 'mlp.' in new_k: 40 | # fc weights become 1x1 convs in MMSegmentation's MixFFN, hence the reshape 41 | new_k = new_k.replace('mlp.', 'ffn.layers.') 42 | if 'fc1.weight' in new_k or 'fc2.weight' in new_k: 43 | new_v = v.reshape((*v.shape, 1, 1)) 44 | new_k = new_k.replace('fc1.', '0.') 45 | new_k = new_k.replace('dwconv.dwconv.', '1.') 46 | new_k = new_k.replace('fc2.', '4.') 47 | # the indices map onto MixFFN's submodules: fc1 -> 0, dwconv -> 1, fc2 -> 4 48 | # norm layer conversion 49 | elif k.startswith('norm'): 50 | stage_i = int(k.split('.')[0].replace('norm', '')) 51 | new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2') 52 | new_v = v 53 | else: 54 | new_k = k 55 | new_v = v 56 | new_ckpt[new_k] = new_v 57 | return new_ckpt 58 | 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser( 62 | description='Convert keys in official pretrained segformer to ' 63 | 'MMSegmentation style.') 64 | parser.add_argument('src', help='src model path or url') 65 | # The dst path must be a full path of the new checkpoint. 66 | parser.add_argument('dst', help='save path') 67 | args = parser.parse_args() 68 | 69 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 70 | if 'state_dict' in checkpoint: 71 | state_dict = checkpoint['state_dict'] 72 | elif 'model' in checkpoint: 73 | state_dict = checkpoint['model'] 74 | else: 75 | state_dict = checkpoint 76 | weight = convert_mit(state_dict) 77 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 78 | torch.save(weight, args.dst) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /tools/model_converters/stdc2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | 5 | import mmcv 6 | import torch 7 | from mmcv.runner import CheckpointLoader 8 | 9 | 10 | def convert_stdc(ckpt, stdc_type): 11 | new_state_dict = {} 12 | if stdc_type == 'STDC1': 13 | stage_lst = ['0', '1', '2.0', '2.1', '3.0', '3.1', '4.0', '4.1'] 14 | else: 15 | stage_lst = [ 16 | '0', '1', '2.0', '2.1', '2.2', '2.3', '3.0', '3.1', '3.2', '3.3', 17 | '3.4', '4.0', '4.1', '4.2' 18 | ] 19 | for k, v in ckpt.items(): 20 | ori_k = k 21 | flag = False 22 | if 'cp.' in k: 23 | k = k.replace('cp.', '') 24 | if 'features.' in k: 25 | num_layer = int(k.split('.')[1]) 26 | feature_key_lst = 'features.' + str(num_layer) + '.' 27 | stages_key_lst = 'stages.' + stage_lst[num_layer] + '.' 28 | k = k.replace(feature_key_lst, stages_key_lst) 29 | flag = True 30 | if 'conv_list' in k: 31 | k = k.replace('conv_list', 'layers') 32 | flag = True 33 | if 'avd_layer.' in k: 34 | if 'avd_layer.0' in k: 35 | k = k.replace('avd_layer.0', 'downsample.conv') 36 | elif 'avd_layer.1' in k: 37 | k = k.replace('avd_layer.1', 'downsample.bn') 38 | flag = True 39 | if flag: 40 | new_state_dict[k] = ckpt[ori_k] 41 | 42 | return new_state_dict 43 | 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser( 47 | description='Convert keys in official pretrained STDC1/2 to ' 48 | 'MMSegmentation style.') 49 | parser.add_argument('src', help='src model path') 50 | # The dst path must be a full path of the new checkpoint. 
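# A hedged invocation sketch (file names hypothetical):
#     python tools/model_converters/stdc2mmseg.py ./stdcnet1_official.pth \
#         ./pretrained/stdc1_mmseg.pth STDC1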
51 | parser.add_argument('dst', help='save path') 52 | parser.add_argument('type', help='model type: STDC1 or STDC2') 53 | args = parser.parse_args() 54 | 55 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 56 | if 'state_dict' in checkpoint: 57 | state_dict = checkpoint['state_dict'] 58 | elif 'model' in checkpoint: 59 | state_dict = checkpoint['model'] 60 | else: 61 | state_dict = checkpoint 62 | 63 | assert args.type in ['STDC1', 64 | 'STDC2'], 'STDC type should be STDC1 or STDC2!' 65 | weight = convert_stdc(state_dict, args.type) 66 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 67 | torch.save(weight, args.dst) 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /tools/model_converters/swin2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_swin(ckpt): 12 | new_ckpt = OrderedDict() 13 | 14 | def correct_unfold_reduction_order(x): 15 | out_channel, in_channel = x.shape 16 | x = x.reshape(out_channel, 4, in_channel // 4) 17 | x = x[:, [0, 2, 1, 3], :].transpose(1, 18 | 2).reshape(out_channel, in_channel) 19 | return x 20 | 21 | def correct_unfold_norm_order(x): 22 | in_channel = x.shape[0] 23 | x = x.reshape(4, in_channel // 4) 24 | x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) 25 | return x 26 | 27 | for k, v in ckpt.items(): 28 | if k.startswith('head'): 29 | continue 30 | elif k.startswith('layers'): 31 | new_v = v 32 | if 'attn.' in k: 33 | new_k = k.replace('attn.', 'attn.w_msa.') 34 | elif 'mlp.' in k: 35 | if 'mlp.fc1.' in k: 36 | new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') 37 | elif 'mlp.fc2.' in k: 38 | new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') 39 | else: 40 | new_k = k.replace('mlp.', 'ffn.') 41 | elif 'downsample' in k: 42 | new_k = k 43 | if 'reduction.' in k: 44 | new_v = correct_unfold_reduction_order(v) 45 | elif 'norm.' in k: 46 | new_v = correct_unfold_norm_order(v) 47 | else: 48 | new_k = k 49 | new_k = new_k.replace('layers', 'stages', 1) 50 | elif k.startswith('patch_embed'): 51 | new_v = v 52 | if 'proj' in k: 53 | new_k = k.replace('proj', 'projection') 54 | else: 55 | new_k = k 56 | else: 57 | new_v = v 58 | new_k = k 59 | 60 | new_ckpt[new_k] = new_v 61 | 62 | return new_ckpt 63 | 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser( 67 | description='Convert keys in official pretrained swin models to ' 68 | 'MMSegmentation style.') 69 | parser.add_argument('src', help='src model path or url') 70 | # The dst path must be a full path of the new checkpoint. 
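# A hedged invocation sketch (file names hypothetical); src may also be a
# download URL, since CheckpointLoader handles both:
#     python tools/model_converters/swin2mmseg.py \
#         ./swin_tiny_patch4_window7_224.pth ./pretrained/swin_tiny_mmseg.pth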
71 | parser.add_argument('dst', help='save path') 72 | args = parser.parse_args() 73 | 74 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 75 | if 'state_dict' in checkpoint: 76 | state_dict = checkpoint['state_dict'] 77 | elif 'model' in checkpoint: 78 | state_dict = checkpoint['model'] 79 | else: 80 | state_dict = checkpoint 81 | weight = convert_swin(state_dict) 82 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 83 | torch.save(weight, args.dst) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /tools/model_converters/twins2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_twins(args, ckpt): 12 | 13 | new_ckpt = OrderedDict() 14 | 15 | for k, v in list(ckpt.items()): 16 | new_v = v 17 | if k.startswith('head'): 18 | continue 19 | elif k.startswith('patch_embeds'): 20 | if 'proj.' in k: 21 | new_k = k.replace('proj.', 'projection.') 22 | else: 23 | new_k = k 24 | elif k.startswith('blocks'): 25 | # Keys shared by pcpvt and svt 26 | if 'attn.q.' in k: 27 | new_k = k.replace('q.', 'attn.in_proj_') 28 | new_v = torch.cat([v, ckpt[k.replace('attn.q.', 'attn.kv.')]], 29 | dim=0) 30 | elif 'mlp.fc1' in k: 31 | new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') 32 | elif 'mlp.fc2' in k: 33 | new_k = k.replace('mlp.fc2', 'ffn.layers.1') 34 | # Only pcpvt 35 | elif args.model == 'pcpvt': 36 | if 'attn.proj.' in k: 37 | new_k = k.replace('proj.', 'attn.out_proj.') 38 | else: 39 | new_k = k 40 | 41 | # Only svt 42 | else: 43 | if 'attn.proj.' in k: 44 | k_lst = k.split('.') 45 | if int(k_lst[2]) % 2 == 1: 46 | new_k = k.replace('proj.', 'attn.out_proj.') 47 | else: 48 | new_k = k 49 | else: 50 | new_k = k 51 | new_k = new_k.replace('blocks.', 'layers.') 52 | elif k.startswith('pos_block'): 53 | new_k = k.replace('pos_block', 'position_encodings') 54 | if 'proj.0.' in new_k: 55 | new_k = new_k.replace('proj.0.', 'proj.') 56 | else: 57 | new_k = k 58 | if 'attn.kv.' not in k: 59 | new_ckpt[new_k] = new_v 60 | return new_ckpt 61 | 62 | 63 | def main(): 64 | parser = argparse.ArgumentParser( 65 | description='Convert keys in timm pretrained Twins (pcpvt/svt) models to ' 66 | 'MMSegmentation style.') 67 | parser.add_argument('src', help='src model path or url') 68 | # The dst path must be a full path of the new checkpoint. 69 | parser.add_argument('dst', help='save path') 70 | parser.add_argument('model', help='model: pcpvt or svt') 71 | args = parser.parse_args() 72 | 73 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 74 | 75 | if 'state_dict' in checkpoint: 76 | # timm checkpoint 77 | state_dict = checkpoint['state_dict'] 78 | else: 79 | state_dict = checkpoint 80 | 81 | weight = convert_twins(args, state_dict) 82 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 83 | torch.save(weight, args.dst) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /tools/model_converters/vit2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
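# A hedged invocation sketch (file names hypothetical); accepts timm as well
# as DeiT checkpoints, since the loader below handles both the 'state_dict'
# and 'model' keys:
#     python tools/model_converters/vit2mmseg.py ./vit_base_p16_384.pth \
#         ./pretrained/vit_base_mmseg.pth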
2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_vit(ckpt): 12 | 13 | new_ckpt = OrderedDict() 14 | 15 | for k, v in ckpt.items(): 16 | if k.startswith('head'): 17 | continue 18 | if k.startswith('norm'): 19 | new_k = k.replace('norm.', 'ln1.') 20 | elif k.startswith('patch_embed'): 21 | if 'proj' in k: 22 | new_k = k.replace('proj', 'projection') 23 | else: 24 | new_k = k 25 | elif k.startswith('blocks'): 26 | if 'norm' in k: 27 | new_k = k.replace('norm', 'ln') 28 | elif 'mlp.fc1' in k: 29 | new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') 30 | elif 'mlp.fc2' in k: 31 | new_k = k.replace('mlp.fc2', 'ffn.layers.1') 32 | elif 'attn.qkv' in k: 33 | new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_') 34 | elif 'attn.proj' in k: 35 | new_k = k.replace('attn.proj', 'attn.attn.out_proj') 36 | else: 37 | new_k = k 38 | new_k = new_k.replace('blocks.', 'layers.') 39 | else: 40 | new_k = k 41 | new_ckpt[new_k] = v 42 | 43 | return new_ckpt 44 | 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser( 48 | description='Convert keys in timm pretrained vit models to ' 49 | 'MMSegmentation style.') 50 | parser.add_argument('src', help='src model path or url') 51 | # The dst path must be a full path of the new checkpoint. 52 | parser.add_argument('dst', help='save path') 53 | args = parser.parse_args() 54 | 55 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 56 | if 'state_dict' in checkpoint: 57 | # timm checkpoint 58 | state_dict = checkpoint['state_dict'] 59 | elif 'model' in checkpoint: 60 | # deit checkpoint 61 | state_dict = checkpoint['model'] 62 | else: 63 | state_dict = checkpoint 64 | weight = convert_vit(state_dict) 65 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 66 | torch.save(weight, args.dst) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /tools/print_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import warnings 4 | 5 | from mmcv import Config, DictAction 6 | 7 | from mmseg.apis import init_segmentor 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Print the whole config') 12 | parser.add_argument('config', help='config file path') 13 | parser.add_argument( 14 | '--graph', action='store_true', help='print the models graph') 15 | parser.add_argument( 16 | '--options', 17 | nargs='+', 18 | action=DictAction, 19 | help='--options is deprecated in favor of --cfg-options and it will ' 20 | 'not be supported in version v0.22.0. Override some settings in the ' 21 | 'used config, the key-value pair in xxx=yyy format will be merged ' 22 | 'into config file. If the value to be overwritten is a list, it ' 23 | 'should be like key="[a,b]" or key=a,b. It also allows nested ' 24 | 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' 25 | 'marks are necessary and that no white space is allowed.') 26 | parser.add_argument( 27 | '--cfg-options', 28 | nargs='+', 29 | action=DictAction, 30 | help='override some settings in the used config, the key-value pair ' 31 | 'in xxx=yyy format will be merged into config file. If the value to ' 32 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b. ' 33 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 34 | 'Note that the quotation marks are necessary and that no white space ' 35 | 'is allowed.') 36 | args = parser.parse_args() 37 | 38 | if args.options and args.cfg_options: 39 | raise ValueError( 40 | '--options and --cfg-options cannot be both ' 41 | 'specified, --options is deprecated in favor of --cfg-options. ' 42 | '--options will not be supported in version v0.22.0.') 43 | if args.options: 44 | warnings.warn('--options is deprecated in favor of --cfg-options, ' 45 | '--options will not be supported in version v0.22.0.') 46 | args.cfg_options = args.options 47 | 48 | return args 49 | 50 | 51 | def main(): 52 | args = parse_args() 53 | 54 | cfg = Config.fromfile(args.config) 55 | if args.cfg_options is not None: 56 | cfg.merge_from_dict(args.cfg_options) 57 | print(f'Config:\n{cfg.pretty_text}') 58 | # dump config 59 | cfg.dump('example.py') 60 | # dump models graph 61 | if args.graph: 62 | model = init_segmentor(args.config, device='cpu') 63 | print(f'Model graph:\n{str(model)}') 64 | with open('example-graph.txt', 'w') as f: 65 | f.writelines(str(model)) 66 | 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /tools/publish_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import subprocess 4 | 5 | import torch 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser( 10 | description='Process a checkpoint to be published') 11 | parser.add_argument('in_file', help='input checkpoint filename') 12 | parser.add_argument('out_file', help='output checkpoint filename') 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | def process_checkpoint(in_file, out_file): 18 | checkpoint = torch.load(in_file, map_location='cpu') 19 | # remove optimizer for smaller file size 20 | if 'optimizer' in checkpoint: 21 | del checkpoint['optimizer'] 22 | # if it is necessary to remove some sensitive data in checkpoint['meta'], 23 | # add the code here. 
24 | torch.save(checkpoint, out_file) 25 | sha = subprocess.check_output(['sha256sum', out_file]).decode() 26 | # str.rstrip strips characters, not a suffix, so cut '.pth' explicitly 27 | if out_file.endswith('.pth'): 28 | out_file_name = out_file[:-4] 29 | else: 30 | out_file_name = out_file 31 | final_file = out_file_name + '-{}.pth'.format(sha[:8]) 32 | subprocess.Popen(['mv', out_file, final_file]) 33 | 34 | 35 | def main(): 36 | args = parse_args() 37 | process_checkpoint(args.in_file, args.out_file) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /tools/script.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PYTHONPATH:/mnt/weka/scratch/shaofei.huang/Code/Anchor3dlane 2 | export PYTHONPATH=$PYTHONPATH:/mnt/weka/scratch/shaofei.huang/Code/Anchor3dlane/gen-efficientnet-pytorch 3 | 4 | CUDA_VISIBLE_DEVICES=0 python tools/test.py output/openlane/check/temporal_2stage/train_iter.py \ 5 | output/openlane/check/temporal_2stage/iter_60000.pth --show-dir output/once/check_train/baseline_2stage/test_60000 -------------------------------------------------------------------------------- /tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-4} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-4} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | GPUS=${GPUS:-4} 9 | GPUS_PER_NODE=${GPUS_PER_NODE:-4} 10 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 11 | SRUN_ARGS=${SRUN_ARGS:-""} 12 | PY_ARGS=${@:4} 13 | 14 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 15 | srun -p ${PARTITION} \ 16 | --job-name=${JOB_NAME} \ 17 | --gres=gpu:${GPUS_PER_NODE} \ 18 | --ntasks=${GPUS} \ 19 | --ntasks-per-node=${GPUS_PER_NODE} \ 20 | --cpus-per-task=${CPUS_PER_TASK} \ 21 | --kill-on-bad-exit=1 \ 22 | ${SRUN_ARGS} \ 23 | python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} 24 | -------------------------------------------------------------------------------- /tools/torchserve/mmseg_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import base64 3 | import os 4 | 5 | import cv2 6 | import mmcv 7 | import torch 8 | from mmcv.cnn.utils.sync_bn import revert_sync_batchnorm 9 | from ts.torch_handler.base_handler import BaseHandler 10 | 11 | from mmseg.apis import inference_segmentor, init_segmentor 12 | 13 | 14 | class MMsegHandler(BaseHandler): 15 | 16 | def initialize(self, context): 17 | properties = context.system_properties 18 | self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | self.device = torch.device(self.map_location + ':' + 20 | str(properties.get('gpu_id')) if torch.cuda. 
21 | is_available() else self.map_location) 22 | self.manifest = context.manifest 23 | 24 | model_dir = properties.get('model_dir') 25 | serialized_file = self.manifest['model']['serializedFile'] 26 | checkpoint = os.path.join(model_dir, serialized_file) 27 | self.config_file = os.path.join(model_dir, 'config.py') 28 | 29 | self.model = init_segmentor(self.config_file, checkpoint, self.device) 30 | self.model = revert_sync_batchnorm(self.model) 31 | self.initialized = True 32 | 33 | def preprocess(self, data): 34 | images = [] 35 | 36 | for row in data: 37 | image = row.get('data') or row.get('body') 38 | if isinstance(image, str): 39 | image = base64.b64decode(image) 40 | image = mmcv.imfrombytes(image) 41 | images.append(image) 42 | 43 | return images 44 | 45 | def inference(self, data, *args, **kwargs): 46 | results = [inference_segmentor(self.model, img) for img in data] 47 | return results 48 | 49 | def postprocess(self, data): 50 | output = [] 51 | 52 | for image_result in data: 53 | _, buffer = cv2.imencode('.png', image_result[0].astype('uint8')) 54 | content = buffer.tobytes() 55 | output.append(content) 56 | return output 57 | -------------------------------------------------------------------------------- /tools/torchserve/test_torchserve.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from argparse import ArgumentParser 3 | from io import BytesIO 4 | 5 | import matplotlib.pyplot as plt 6 | import mmcv 7 | import requests 8 | 9 | from mmseg.apis import inference_segmentor, init_segmentor 10 | 11 | 12 | def parse_args(): 13 | parser = ArgumentParser( 14 | description='Compare results of torchserve and pytorch, ' 15 | 'and visualize them.') 16 | parser.add_argument('img', help='Image file') 17 | parser.add_argument('config', help='Config file') 18 | parser.add_argument('checkpoint', help='Checkpoint file') 19 | parser.add_argument('model_name', help='The model name in the server') 20 | parser.add_argument( 21 | '--inference-addr', 22 | default='127.0.0.1:8080', 23 | help='Address and port of the inference server') 24 | parser.add_argument( 25 | '--result-image', 26 | type=str, 27 | default=None, 28 | help='save server output in result-image') 29 | parser.add_argument( 30 | '--device', default='cuda:0', help='Device used for inference') 31 | 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def main(args): 37 | url = 'http://' + args.inference_addr + '/predictions/' + args.model_name 38 | with open(args.img, 'rb') as image: 39 | tmp_res = requests.post(url, image) 40 | content = tmp_res.content 41 | if args.result_image: 42 | with open(args.result_image, 'wb') as out_image: 43 | out_image.write(content) 44 | plt.imshow(mmcv.imread(args.result_image, 'grayscale')) 45 | plt.show() 46 | else: 47 | plt.imshow(plt.imread(BytesIO(content))) 48 | plt.show() 49 | model = init_segmentor(args.config, args.checkpoint, args.device) 50 | image = mmcv.imread(args.img) 51 | result = inference_segmentor(model, image) 52 | plt.imshow(result[0]) 53 | plt.show() 54 | 55 | 56 | if __name__ == '__main__': 57 | args = parse_args() 58 | main(args) 59 | --------------------------------------------------------------------------------