├── .gitignore ├── README.md ├── datasets ├── __init__.py ├── cityscapes_remap.py ├── cityscapes_utils.py ├── coco_utils.py ├── corrupt_images.py ├── deepscene.py ├── deepscene_remap.py ├── mhp.py ├── mhp_remap.py ├── mhp_utils.py ├── nyu.py ├── nyu_dump.py ├── sun.py └── sun_remap.py ├── onnx_export.py ├── onnx_validate.py ├── train.py ├── transforms.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-segmentation 2 | Training of semantic segmentation networks with PyTorch 3 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | print("pytorch-segmentation/datasets/__init__.py") 2 | 3 | #from .mhp import * 4 | -------------------------------------------------------------------------------- /datasets/cityscapes_remap.py: -------------------------------------------------------------------------------- 1 | # 2 | # this script remaps the original Cityscapes class label ID's (range 0-33) 3 | # to lie within a range that the segmentation networks support (21 classes) 4 | # 5 | # see below for the mapping of the original class ID's to the new class ID's, 6 | # along with the class descriptions, some of which are combined from the originals. 7 | # 8 | # this script will overwrite the *_labelIds.png files from the gtCoarse/gtFine sets. 
9 | # to run it, launch these example commands for the desired train/train_extra/val sets: 10 | # 11 | # $ python3 cityscapes_remap.py /gtCoarse/train 12 | # $ python3 cityscapes_remap.py /gtCoarse/val 13 | # 14 | import os 15 | import copy 16 | import argparse 17 | 18 | from PIL import Image 19 | from multiprocessing import Pool as ProcessPool 20 | 21 | 22 | # 23 | # map of existing class label ID's (range 0-33) to new ID's (range 0-21) 24 | # 25 | LABEL_MAP = [0, # unlabeled 26 | 1, # ego vehicle 27 | 0, # rectification border 28 | 0, # out of roi 29 | 2, # static 30 | 2, # dynamic 31 | 2, # ground 32 | 3, # road 33 | 4, # sidewalk 34 | 3, # parking 35 | 3, # rail track 36 | 5, # building 37 | 6, # wall 38 | 7, # fence 39 | 7, # guard rail 40 | 3, # bridge 41 | 3, # tunnel 42 | 8, # pole 43 | 8, # polegroup 44 | 9, # traffic light 45 | 10, # traffic sign 46 | 11, # vegetation 47 | 12, # terrain 48 | 13, # sky 49 | 14, # person 50 | 14, # rider 51 | 15, # car 52 | 16, # truck 53 | 17, # bus 54 | 16, # caravan 55 | 16, # trailer 56 | 18, # train 57 | 19, # motorcycle 58 | 20] # bicycle 59 | 60 | # 61 | # new class label names, corresponding to remapped class ID's (range 0-21) 62 | # 63 | """ 64 | void 65 | ego_vehicle 66 | ground 67 | road 68 | sidewalk 69 | building 70 | wall 71 | fence 72 | pole 73 | traffic_light 74 | traffic_sign 75 | vegetation 76 | terrain 77 | sky 78 | person 79 | car 80 | truck 81 | bus 82 | train 83 | motorcycle 84 | bicycle 85 | """ 86 | 87 | def remap_labels(filename): 88 | print(filename) 89 | img = Image.open(filename) 90 | 91 | for y in range(img.height): 92 | for x in range(img.width): 93 | org_label = img.getpixel((x,y)) 94 | img.putpixel((x,y), LABEL_MAP[org_label]) 95 | 96 | img.save(filename) 97 | 98 | 99 | if __name__ == "__main__": 100 | 101 | parser = argparse.ArgumentParser(description='Remap Cityscapes Segmentation Labels') 102 | parser.add_argument('dir', metavar='DIR', help='path to data labels (e.g. gtCoarse/train, etc.)') 103 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 104 | help='number of data loading workers (default: 8)') 105 | args = parser.parse_args() 106 | 107 | img_list = [] 108 | 109 | for city in os.listdir(args.dir): 110 | img_dir = os.path.join(args.dir, city) 111 | 112 | for file_name in os.listdir(img_dir): 113 | if file_name.find("labelIds.png") == -1: 114 | continue 115 | 116 | img_list.append(os.path.join(img_dir, file_name)) 117 | 118 | with ProcessPool(processes=args.workers) as pool: 119 | pool.map(remap_labels, img_list) 120 | 121 | -------------------------------------------------------------------------------- /datasets/cityscapes_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import torch.utils.data 5 | import torchvision 6 | 7 | from PIL import Image 8 | 9 | 10 | # 11 | # note: it is slow to remap the label categories at runtime, 12 | # so use cityscapes_remap.py to do it in advance.
13 | # 14 | class FilterAndRemapCityscapesCategories(object): 15 | def __init__(self, categories, classes): 16 | self.categories = categories 17 | self.classes = classes 18 | print(self.classes) 19 | 20 | def __call__(self, image, anno): 21 | 22 | anno = copy.deepcopy(anno) 23 | for y in range(anno.height): 24 | for x in range(anno.width): 25 | org_label = anno.getpixel((x,y)) 26 | 27 | if org_label not in self.categories: 28 | anno.putpixel((x,y), 0) 29 | 30 | return image, anno 31 | 32 | 33 | def get_cityscapes(root, image_set, transforms): 34 | 35 | #CAT_LIST = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] 36 | 37 | #transforms = Compose([ 38 | # FilterAndRemapCityscapesCategories(CAT_LIST, torchvision.datasets.Cityscapes.classes), 39 | # transforms 40 | #]) 41 | 42 | dataset = torchvision.datasets.Cityscapes(root, split=image_set, mode='fine', target_type='semantic', 43 | transform=transforms, target_transform=transforms) 44 | 45 | return dataset 46 | -------------------------------------------------------------------------------- /datasets/coco_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import torch.utils.data 5 | import torchvision 6 | 7 | from PIL import Image 8 | from transforms import Compose 9 | from pycocotools import mask as coco_mask 10 | 11 | 12 | class FilterAndRemapCocoCategories(object): 13 | def __init__(self, categories, remap=True): 14 | self.categories = categories 15 | self.remap = remap 16 | 17 | def __call__(self, image, anno): 18 | anno = [obj for obj in anno if obj["category_id"] in self.categories] 19 | if not self.remap: 20 | return image, anno 21 | anno = copy.deepcopy(anno) 22 | for obj in anno: 23 | obj["category_id"] = self.categories.index(obj["category_id"]) 24 | return image, anno 25 | 26 | 27 | def convert_coco_poly_to_mask(segmentations, height, width): 28 | masks = [] 29 | for polygons in segmentations: 30 | rles = coco_mask.frPyObjects(polygons, height, width) 31 | mask = coco_mask.decode(rles) 32 | if len(mask.shape) < 3: 33 | mask = mask[..., None] 34 | mask = torch.as_tensor(mask, dtype=torch.uint8) 35 | mask = mask.any(dim=2) 36 | masks.append(mask) 37 | if masks: 38 | masks = torch.stack(masks, dim=0) 39 | else: 40 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 41 | return masks 42 | 43 | 44 | class ConvertCocoPolysToMask(object): 45 | def __call__(self, image, anno): 46 | w, h = image.size 47 | segmentations = [obj["segmentation"] for obj in anno] 48 | cats = [obj["category_id"] for obj in anno] 49 | if segmentations: 50 | masks = convert_coco_poly_to_mask(segmentations, h, w) 51 | cats = torch.as_tensor(cats, dtype=masks.dtype) 52 | # merge all instance masks into a single segmentation map 53 | # with its corresponding categories 54 | target, _ = (masks * cats[:, None, None]).max(dim=0) 55 | # discard overlapping instances 56 | target[masks.sum(0) > 1] = 255 57 | else: 58 | target = torch.zeros((h, w), dtype=torch.uint8) 59 | target = Image.fromarray(target.numpy()) 60 | return image, target 61 | 62 | 63 | def _coco_remove_images_without_annotations(dataset, cat_list=None): 64 | def _has_valid_annotation(anno): 65 | # if it's empty, there is no annotation 66 | if len(anno) == 0: 67 | return False 68 | # if more than 1k pixels occupied in the image 69 | return sum(obj["area"] for obj in anno) > 1000 70 | 71 | assert isinstance(dataset, torchvision.datasets.CocoDetection) 72 | ids = [] 73 | for ds_idx,
img_id in enumerate(dataset.ids): 74 | ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) 75 | anno = dataset.coco.loadAnns(ann_ids) 76 | if cat_list: 77 | anno = [obj for obj in anno if obj["category_id"] in cat_list] 78 | if _has_valid_annotation(anno): 79 | ids.append(ds_idx) 80 | 81 | dataset = torch.utils.data.Subset(dataset, ids) 82 | return dataset 83 | 84 | 85 | def get_coco(root, image_set, transforms): 86 | PATHS = { 87 | "train": ("train2017", os.path.join("annotations", "instances_train2017.json")), 88 | "val": ("val2017", os.path.join("annotations", "instances_val2017.json")), 89 | # "train": ("val2017", os.path.join("annotations", "instances_val2017.json")) 90 | } 91 | CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 92 | 1, 64, 20, 63, 7, 72] 93 | 94 | transforms = Compose([ 95 | FilterAndRemapCocoCategories(CAT_LIST, remap=True), 96 | ConvertCocoPolysToMask(), 97 | transforms 98 | ]) 99 | 100 | img_folder, ann_file = PATHS[image_set] 101 | img_folder = os.path.join(root, img_folder) 102 | ann_file = os.path.join(root, ann_file) 103 | 104 | dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) 105 | 106 | if image_set == "train": 107 | dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST) 108 | 109 | return dataset 110 | -------------------------------------------------------------------------------- /datasets/corrupt_images.py: -------------------------------------------------------------------------------- 1 | # 2 | # this script detects corrupted images in a directory, 3 | # and optionally moves them to a specified directory 4 | # 5 | import argparse 6 | import warnings 7 | import shutil 8 | 9 | from os import listdir 10 | from os import remove 11 | from os.path import join 12 | 13 | from PIL import Image 14 | 15 | parser = argparse.ArgumentParser(description='corrupt image remover') 16 | parser.add_argument('dir', metavar='DIR', help='path to directory') 17 | parser.add_argument('--move', type=str, default=None, help='optional path to directory that corrupt images are moved to') 18 | args = parser.parse_args() 19 | 20 | num_bad = 0 21 | 22 | warnings.filterwarnings("error") 23 | 24 | 25 | for filename in listdir(args.dir): 26 | file_low = filename.lower() 27 | if file_low.endswith('.png') or file_low.endswith('.jpg') or file_low.endswith('.jpeg') or file_low.endswith('.gif'): 28 | file_path = join(args.dir,filename) 29 | try: 30 | #img = Image.open(file_path) # open the image file 31 | #img.verify() # verify that it is, in fact an image 32 | 33 | img = Image.open(file_path) 34 | img.load() 35 | 36 | imgRGB = img.convert('RGB') 37 | 38 | #if img.width < 16 or img.height < 16: 39 | # print('Strange image dimensions ({:d}x{:d}): {:s}'.format(img.width, img.height, file_path)) 40 | 41 | if img.width < 16 or img.height < 16: 42 | print('Bad image dimensions ({:d}x{:d}): {:s}'.format(img.width, img.height, file_path)) # print out the names of corrupt files 43 | 44 | if args.move is not None: 45 | shutil.move(file_path, args.move) #remove(file_path) 46 | 47 | num_bad = num_bad + 1 48 | 49 | except (IOError, SyntaxError, UserWarning, RuntimeWarning) as e: 50 | print('Bad image: {:s}'.format(file_path)) # print out the names of corrupt files 51 | 52 | if args.move is not None: 53 | shutil.move(file_path, args.move) #remove(file_path) 54 | 55 | num_bad = num_bad + 1 56 | 57 | print('Detected {:d} corrupted images from {:s} '.format(num_bad, args.dir)) 58 | 
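#
# usage sketch (not part of the repository): how the COCO helpers in datasets/coco_utils.py
# above are typically consumed. Assumes the snippet is run from the repository root (so that
# transforms.py is importable), that pycocotools is installed, and that '/data/coco' is a
# placeholder path to an extracted COCO 2017 layout. The transform mirrors the validation
# branch of get_transform() in train.py.
#
import torch
import transforms as T
from datasets.coco_utils import get_coco

val_transform = T.Compose([
    T.Resize((320, 320)),                       # fixed-size resize (NEAREST for the label map)
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],     # ImageNet statistics, as in train.py
                std=[0.229, 0.224, 0.225]),
])

# get_coco() prepends FilterAndRemapCocoCategories and ConvertCocoPolysToMask, so each sample
# comes back as a (3xHxW float image tensor, HxW class-index label tensor) pair
dataset = get_coco('/data/coco', image_set='val', transforms=val_transform)

image, target = dataset[0]
print(image.shape, target.shape)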
-------------------------------------------------------------------------------- /datasets/deepscene.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import math 4 | import torch 5 | 6 | from PIL import Image 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | 10 | class DeepSceneSegmentation(Dataset): 11 | """http://deepscene.cs.uni-freiburg.de/""" 12 | 13 | def __init__(self, root_dir, image_set='train', train_extra=True, transforms=None): 14 | """ 15 | Parameters: 16 | root_dir (string): Root directory of the DeepScene Freiburg Forest dataset. 17 | image_set (string, optional): Select the image_set to use, ``train``, ``val`` 18 | train_extra (bool, optional): If True, use extra images during training 19 | transforms (callable, optional): Optional transform to be applied 20 | on a sample. 21 | """ 22 | self.root_dir = root_dir 23 | self.image_set = image_set 24 | self.transforms = transforms 25 | 26 | self.images = [] 27 | self.targets = [] 28 | 29 | if image_set == 'train': 30 | train_images, train_targets = self.gather_images(os.path.join(root_dir, 'train/rgb'), 31 | os.path.join(root_dir, 'train/GT_index')) 32 | 33 | self.images.extend(train_images) 34 | self.targets.extend(train_targets) 35 | 36 | if train_extra: 37 | extra_images, extra_targets = self.gather_images(os.path.join(root_dir, 'trainextra/rgb'), 38 | os.path.join(root_dir, 'trainextra/GT_index')) 39 | 40 | self.images.extend(extra_images) 41 | self.targets.extend(extra_targets) 42 | 43 | elif image_set == 'val': 44 | val_images, val_targets = self.gather_images(os.path.join(root_dir, 'test/rgb'), 45 | os.path.join(root_dir, 'test/GT_index')) 46 | 47 | self.images.extend(val_images) 48 | self.targets.extend(val_targets) 49 | 50 | def gather_images(self, images_path, labels_path): 51 | def sorted_alphanumeric(data): 52 | convert = lambda text: int(text) if text.isdigit() else text.lower() 53 | alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 54 | return sorted(data, key=alphanum_key) 55 | 56 | image_files = sorted_alphanumeric(os.listdir(images_path)) 57 | label_files = sorted_alphanumeric(os.listdir(labels_path)) 58 | 59 | if len(image_files) != len(label_files): 60 | print('warning: images path has a different number of files than labels path') 61 | print(' ({:d} files) - {:s}'.format(len(image_files), images_path)) 62 | print(' ({:d} files) - {:s}'.format(len(label_files), labels_path)) 63 | 64 | for n in range(len(image_files)): 65 | image_files[n] = os.path.join(images_path, image_files[n]) 66 | label_files[n] = os.path.join(labels_path, label_files[n]) 67 | 68 | #print('{:s} -> {:s}'.format(image_files[n], label_files[n])) 69 | 70 | return image_files, label_files 71 | 72 | def __len__(self): 73 | return len(self.images) 74 | 75 | def __getitem__(self, index): 76 | image = Image.open(self.images[index]).convert('RGB') 77 | target = Image.open(self.targets[index]) 78 | 79 | if self.transforms is not None: 80 | image, target = self.transforms(image, target) 81 | 82 | return image, target 83 | 84 | -------------------------------------------------------------------------------- /datasets/deepscene_remap.py: -------------------------------------------------------------------------------- 1 | # 2 | # this script remaps the original DeepScene freiburg_forest dataset annotations 3 | # from RGB images (with 6 classes) to single-channel index images (with 5 classes).
4 | # 5 | # the original dataset can be downloaded from: http://deepscene.cs.uni-freiburg.de/ 6 | # 7 | import os 8 | import copy 9 | import argparse 10 | 11 | from PIL import Image 12 | from multiprocessing import Pool as ProcessPool 13 | 14 | 15 | # 16 | # map of existing class label ID's and colors to new ID's 17 | # each entry consists of a tuple (new_ID, name, color) 18 | # 19 | CLASS_MAP = [ (0, 'trail', (170, 170, 170)), 20 | (1, 'grass', (0, 255, 0)), 21 | (2, 'vegetation', (102, 102, 51)), 22 | (3, 'obstacle', (0, 0, 0)), 23 | (4, 'sky', (0, 120, 255)), 24 | (2, 'void', (0, 60, 0)) ] # 'void' appears to be trees in the dataset, so it is mapped to vegetation 25 | 26 | 27 | def lookup_class(color): 28 | for c in CLASS_MAP: 29 | if color == c[2]: 30 | return c[0] 31 | 32 | print('could not find class with color ' + str(color)) 33 | return -1 34 | 35 | 36 | def remap_labels(args): 37 | input_path = args[0] 38 | output_path = args[1] 39 | colorized = args[2] 40 | 41 | print('{:s} -> {:s}'.format(input_path, output_path)) 42 | 43 | if os.path.isfile(output_path): 44 | print('skipping image {:s}, already exists'.format(output_path)) 45 | return 46 | 47 | img_input = Image.open(input_path) 48 | img_output = Image.new('RGB' if colorized is True else 'L', (img_input.width, img_input.height)) 49 | 50 | for y in range(img_input.height): 51 | for x in range(img_input.width): 52 | org_label = img_input.getpixel((x,y)) 53 | new_label = CLASS_MAP[lookup_class(org_label)][2 if colorized else 0] 54 | img_output.putpixel((x,y), new_label) 55 | 56 | img_output.save(output_path) 57 | 58 | 59 | if __name__ == "__main__": 60 | 61 | parser = argparse.ArgumentParser(description='Remap DeepScene Segmentation Images') 62 | parser.add_argument('input', type=str, metavar='IN', help='path to directory of annotated images to remap') 63 | parser.add_argument('output', type=str, metavar='OUT', help='path to directory to save remaped annotation images') 64 | parser.add_argument('--colorized', action='store_true', help='output colorized segmentation maps (RGB)') 65 | parser.add_argument('--workers', default=8, type=int, metavar='N', help='number of data loading workers (default: 8)') 66 | args = parser.parse_args() 67 | 68 | if not os.path.exists(args.output): 69 | os.makedirs(args.output) 70 | 71 | files = os.listdir(args.input) 72 | worker_args = [] 73 | 74 | for n in range(len(files)): 75 | worker_args.append((os.path.join(args.input, files[n]), os.path.join(args.output, files[n]), args.colorized)) 76 | 77 | #for n in worker_args: 78 | # remap_labels(n) 79 | 80 | with ProcessPool(processes=args.workers) as pool: 81 | pool.map(remap_labels, worker_args) 82 | 83 | -------------------------------------------------------------------------------- /datasets/mhp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | 5 | from PIL import Image 6 | from mhp_utils import mhp_image_list 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | 10 | class MHPSegmentation(Dataset): 11 | """https://lv-mhp.github.io/dataset""" 12 | 13 | def __init__(self, root_dir, image_set='train', transforms=None): 14 | """ 15 | Parameters: 16 | root_dir (string): Root directory of the extracted LV-MHP-V2 dataset. 17 | image_set (string, optional): Select the image_set to use, ``train``, ``val`` 18 | transforms (callable, optional): Optional transform to be applied 19 | on a sample. 
20 | """ 21 | self.root_dir = root_dir 22 | self.image_set = image_set 23 | self.transforms = transforms 24 | 25 | self.images = [] 26 | self.targets = [] 27 | 28 | img_list = mhp_image_list(os.path.join(root_dir, 'list/{:s}.txt'.format(image_set))) 29 | 30 | for img_index in img_list: 31 | img_filename = os.path.join(root_dir, image_set, 'images/{:d}.jpg'.format(img_index)) 32 | target_filename = os.path.join(root_dir, image_set, 'parsing_annos/{:d}.png'.format(img_index)) 33 | 34 | if os.path.isfile(img_filename) and os.path.isfile(target_filename): 35 | self.images.append(img_filename) 36 | self.targets.append(target_filename) 37 | 38 | def __len__(self): 39 | return len(self.images) 40 | 41 | def __getitem__(self, index): 42 | image = Image.open(self.images[index]).convert('RGB') 43 | target = Image.open(self.targets[index]) 44 | 45 | if self.transforms is not None: 46 | image, target = self.transforms(image, target) 47 | 48 | return image, target 49 | 50 | -------------------------------------------------------------------------------- /datasets/mhp_remap.py: -------------------------------------------------------------------------------- 1 | # 2 | # this script remaps the original MHP dataset (https://lv-mhp.github.io/) 3 | # class label ID's (range 0-58) to 21 classes that PyTorch FCN-ResNet uses. 4 | # 5 | # see below for the mapping of the original class ID's to the new class ID's, 6 | # along with the class descriptions, some of which are combined from the originals. 7 | # 8 | # this script reads the train/val image list from LV-MHP-V2 and given the path to 9 | # the parsing_annos directory containing the annotations, remaps them to the output. 10 | # 11 | # $ DATA= 12 | # $ python3 mhp_remap.py --list=$DATA/list/train.txt $DATA/train/parsing_annos $DATA/train/parsing_annos_21 13 | # $ python3 mhp_remap.py --list=$DATA/list/val.txt $DATA/val/parsing_annos $DATA/val/parsing_annos_21 14 | # 15 | import os 16 | import copy 17 | import argparse 18 | 19 | from PIL import Image 20 | from mhp_utils import mhp_image_list 21 | from multiprocessing import Pool as ProcessPool 22 | 23 | 24 | # 25 | # map of existing class label ID's (range 0-58) to new ID's (range 0-21) 26 | # 27 | LABEL_MAP = [0, # Background 28 | 1, # Cap/hat 29 | 1, # Helmet 30 | 2, # Face 31 | 3, # Hair 32 | 4, # Left-arm 33 | 4, # Right-arm 34 | 5, # Left-hand 35 | 5, # Right-hand 36 | 19, # Protector 37 | 9, # Bikini/bra 38 | 7, # Jacket/windbreaker/hoodie 39 | 6, # Tee-shirt 40 | 6, # Polo-shirt 41 | 6, # Sweater 42 | 8, # Singlet 43 | 10, # Torso-skin 44 | 11, # Pants 45 | 12, # Shorts/swim-shorts 46 | 12, # Skirt 47 | 13, # Stockings 48 | 13, # Socks 49 | 14, # Left-boot 50 | 14, # Right-boot 51 | 14, # Left-shoe 52 | 14, # Right-shoe 53 | 14, # Left-highheel 54 | 14, # Right-highheel 55 | 14, # Left-sandal 56 | 14, # Right-sandal 57 | 15, # Left-leg 58 | 15, # Right-leg 59 | 16, # Left-foot 60 | 16, # Right-foot 61 | 7, # Coat 62 | 8, # Dress 63 | 8, # Robe 64 | 8, # Jumpsuit 65 | 8, # Other-full-body-clothes 66 | 1, # Headwear 67 | 17, # Backpack 68 | 20, # Ball 69 | 20, # Bats 70 | 19, # Belt 71 | 20, # Bottle 72 | 17, # Carrybag 73 | 17, # Cases 74 | 18, # Sunglasses 75 | 18, # Eyewear 76 | 19, # Glove 77 | 19, # Scarf 78 | 20, # Umbrella 79 | 17, # Wallet/purse 80 | 19, # Watch 81 | 19, # Wristband 82 | 19, # Tie 83 | 19, # Other-accessory 84 | 6, # Other-upper-body-clothes 85 | 11] # Other-lower-body-clothes 86 | 87 | # 88 | # new class label names, corresponding to remapped class ID's (range 0-21) 89 | # 90 | 
""" 91 | void 92 | hat/helmet/headwear 93 | face 94 | hair 95 | arm 96 | hand 97 | shirt 98 | jacket/coat 99 | dress/robe 100 | bikini/bra 101 | torso_skin 102 | pants 103 | shorts 104 | socks/stockings 105 | shoe/boot 106 | leg 107 | foot 108 | backpack/purse/bag 109 | sunglasses/eyewear 110 | other_accessory 111 | other_item 112 | """ 113 | 114 | def remap_labels(args): 115 | input_dir = args[0] 116 | output_dir = args[1] 117 | img_index = args[2] 118 | 119 | src_images = 0 120 | img_output = None 121 | 122 | # check if this image has already been processed (i.e. by a previous run) 123 | output_path = os.path.join(output_dir, '{:d}.png'.format(img_index)) 124 | 125 | if os.path.isfile(output_path): 126 | print('skipping image {:d}, already exists'.format(img_index)) 127 | return 128 | 129 | # determine the number of source images for this frame 130 | for n in range(1,30): 131 | if os.path.isfile(os.path.join(input_dir, '{:d}_{:02d}_01.png'.format(img_index, n))): 132 | src_images = n 133 | 134 | print('processing image {:d} \t(src_images={:d})'.format(img_index, src_images)) 135 | 136 | # aggregate and remap the source images into one output 137 | for n in range(1, src_images+1): 138 | img_input = Image.open(os.path.join(input_dir, '{:d}_{:02d}_{:02d}.png'.format(img_index, src_images, n))) 139 | 140 | if img_output is None: 141 | img_output = Image.new('L', (img_input.width, img_input.height)) 142 | 143 | for y in range(img_input.height): 144 | for x in range(img_input.width): 145 | org_label = img_input.getpixel((x,y))[0] 146 | new_label = LABEL_MAP[org_label] 147 | 148 | if new_label != 0: # only overwrite non-background pixels 149 | img_output.putpixel((x,y), new_label) 150 | 151 | #if org_label != 0: 152 | # print('img {:d}_{:02d} ({:d}, {:d}) {:d} -> {:d}'.format(img_index, n, x, y, org_label, new_label)) 153 | 154 | img_output.save(output_path) 155 | 156 | 157 | if __name__ == "__main__": 158 | 159 | parser = argparse.ArgumentParser(description='Remap MHP Annotation Images') 160 | parser.add_argument('input', type=str, metavar='IN', help='path to directory of annotated images to remap') 161 | parser.add_argument('output', type=str, metavar='OUT', help='path to directory to save remaped annotation images') 162 | parser.add_argument('--list', type=str, required=True, metavar='LIST', help='path to image list') 163 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers (default: 8)') 164 | args = parser.parse_args() 165 | 166 | if not os.path.exists(args.output): 167 | os.makedirs(args.output) 168 | 169 | img_list = mhp_image_list(args.list) 170 | pool_args = [] 171 | 172 | for img_index in img_list: 173 | pool_args.append( (args.input, args.output, img_index) ) 174 | 175 | with ProcessPool(processes=args.workers) as pool: 176 | pool.map(remap_labels, pool_args) 177 | 178 | -------------------------------------------------------------------------------- /datasets/mhp_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def mhp_image_list(filename): 4 | """ 5 | Read one of the image index lists from LV-MHP-v2/list 6 | 7 | Parameters: 8 | filename (string): path to the image list file 9 | 10 | Returns: 11 | list (int): list of int's that correspond to image names 12 | """ 13 | list_file = open(filename, 'r') 14 | img_list = [] 15 | 16 | while True: 17 | next_line = list_file.readline() 18 | 19 | if not next_line: 20 | break 21 | 22 | img_list.append(int(next_line)) 23 | 
24 | return img_list 25 | 26 | -------------------------------------------------------------------------------- /datasets/nyu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | 5 | from PIL import Image 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | 9 | class NYUDepth(Dataset): 10 | """https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html""" 11 | 12 | def __init__(self, root_dir, image_set='train', transforms=None): 13 | """ 14 | Parameters: 15 | root_dir (string): Root directory of the dumped NYU-Depth dataset. 16 | image_set (string, optional): Select the image_set to use, ``train``, ``val`` 17 | transforms (callable, optional): Optional transform to be applied 18 | on a sample. 19 | """ 20 | self.root_dir = root_dir 21 | self.image_set = image_set 22 | self.transforms = transforms 23 | 24 | self.images = [] 25 | self.targets = [] 26 | 27 | img_list = self.read_image_list(os.path.join(root_dir, '{:s}.txt'.format(image_set))) 28 | 29 | for img_name in img_list: 30 | img_filename = os.path.join(root_dir, 'images/{:s}'.format(img_name)) 31 | target_filename = os.path.join(root_dir, 'depth/{:s}'.format(img_name)) 32 | 33 | if os.path.isfile(img_filename) and os.path.isfile(target_filename): 34 | self.images.append(img_filename) 35 | self.targets.append(target_filename) 36 | 37 | def read_image_list(self, filename): 38 | """ 39 | Read one of the image index lists 40 | 41 | Parameters: 42 | filename (string): path to the image list file 43 | 44 | Returns: 45 | list (int): list of strings that correspond to image names 46 | """ 47 | list_file = open(filename, 'r') 48 | img_list = [] 49 | 50 | while True: 51 | next_line = list_file.readline() 52 | 53 | if not next_line: 54 | break 55 | 56 | img_list.append(next_line.rstrip()) 57 | 58 | return img_list 59 | 60 | def __len__(self): 61 | return len(self.images) 62 | 63 | def __getitem__(self, index): 64 | image = Image.open(self.images[index]).convert('RGB') 65 | target = Image.open(self.targets[index]) 66 | 67 | if self.transforms is not None: 68 | image, target = self.transforms(image, target) 69 | 70 | return image, target 71 | 72 | -------------------------------------------------------------------------------- /datasets/nyu_dump.py: -------------------------------------------------------------------------------- 1 | # 2 | # this script reads the .mat files from the NYU-Depth datasets 3 | # and dumps their contents to individual image files for training 4 | # 5 | import os 6 | import random 7 | import argparse 8 | import h5py 9 | import numpy as np 10 | 11 | from PIL import Image 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Dump NYU-Depth .mat files') 15 | parser.add_argument('input', type=str, nargs='+', metavar='IN', help='paths to input .mat files') 16 | parser.add_argument('--output', type=str, default='dump', metavar='OUT', help='path to directory to save dataset') 17 | parser.add_argument("--images", action="store_true", help="dump RGB images") 18 | parser.add_argument("--labels", action="store_true", help="dump label images") 19 | parser.add_argument("--depth", action="store_true", help="dump depth images") 20 | parser.add_argument("--depth-levels", type=int, default=20, help="number of disparity depth levels (default: 20)") 21 | parser.add_argument("--split", action="store_true", help="dump train/val split files") 22 | parser.add_argument("--split-val", type=float, default=0.15, help="fraction of dataset to split between 
train/val") 23 | 24 | args = parser.parse_args() 25 | 26 | input_images = [] 27 | input_labels = [] 28 | input_depths = [] 29 | 30 | global_depth_min = 1000.0 31 | global_depth_max = -1000.0 32 | 33 | 34 | # 35 | # load arrays from .mat files 36 | # 37 | for filename in args.input: 38 | print('\n==> loading ' + filename) 39 | mat = h5py.File(filename, 'r') 40 | print(list(mat.keys())) 41 | 42 | if args.images or args.split: 43 | print('reading images') 44 | 45 | images = mat['images'] 46 | images = np.array(images) 47 | 48 | print(images.shape) 49 | 50 | pixel_min = images.min() 51 | pixel_max = images.max() 52 | pixel_avg = images.mean() 53 | 54 | print('min pixel: {:f}'.format(pixel_min)) 55 | print('max pixel: {:f}'.format(pixel_max)) 56 | print('avg pixel: {:f}'.format(pixel_avg)) 57 | 58 | input_images.append(images) 59 | 60 | if args.depth: 61 | print('reading depths') 62 | 63 | depths = mat['depths'] 64 | depths = np.array(depths) 65 | 66 | print(depths.shape) 67 | 68 | depth_min = depths.min() 69 | depth_max = depths.max() 70 | depth_avg = depths.mean() 71 | 72 | print('min depth: {:f}'.format(depth_min)) 73 | print('max depth: {:f}'.format(depth_max)) 74 | print('avg depth: {:f}'.format(depth_avg)) 75 | 76 | if depth_min < global_depth_min: 77 | global_depth_min = depth_min 78 | 79 | if depth_max > global_depth_max: 80 | global_depth_max = depth_max 81 | 82 | input_depths.append(depths) 83 | 84 | print('') 85 | 86 | 87 | # 88 | # process source images 89 | # 90 | if args.images: 91 | images_path = os.path.join(args.output, 'images') 92 | 93 | if not os.path.exists(images_path): 94 | os.makedirs(images_path) 95 | 96 | for n in range(len(input_images)): 97 | for m in range(input_images[n].shape[0]): 98 | img_name = 'v{:d}_{:04d}.png'.format(n+1, m) 99 | img_path = os.path.join(images_path, img_name) 100 | 101 | print('processing image ' + img_path) 102 | 103 | img_in = input_images[n][m] 104 | img_in = np.moveaxis(img_in, [0, 1, 2], [2, 1, 0]) 105 | #print(img_in.shape) 106 | 107 | img_out = Image.fromarray(img_in.astype('uint8'), 'RGB') 108 | print(img_out.size) 109 | img_out.save(img_path) 110 | 111 | 112 | # 113 | # process depth images 114 | # 115 | if args.depth: 116 | print('global min depth: {:f}'.format(global_depth_min)) 117 | print('global max depth: {:f}'.format(global_depth_max)) 118 | 119 | depth_path = os.path.join(args.output, 'depth') 120 | 121 | if not os.path.exists(depth_path): 122 | os.makedirs(depth_path) 123 | 124 | for n in range(len(input_depths)): 125 | print('\nprocessing depth/v{:d}'.format(n+1)) 126 | 127 | arr = input_depths[n] 128 | 129 | # rescale the depths to lie between [0, args.depth_levels] 130 | arr = np.subtract(arr, global_depth_min) 131 | arr = np.multiply(arr, 1.0 / (global_depth_max - global_depth_min) * float(args.depth_levels)) 132 | 133 | depth_min = arr.min() 134 | depth_max = arr.max() 135 | depth_avg = arr.mean() 136 | 137 | print('min depth: {:f}'.format(depth_min)) 138 | print('max depth: {:f}'.format(depth_max)) 139 | print('avg depth: {:f}'.format(depth_avg)) 140 | 141 | for m in range(arr.shape[0]): 142 | img_name = 'v{:d}_{:04d}.png'.format(n+1, m) 143 | img_path = os.path.join(depth_path, img_name) 144 | 145 | print('processing depth ' + img_path) 146 | 147 | img_in = arr[m] 148 | img_in = np.moveaxis(img_in, [0, 1], [1, 0]) 149 | #print(img_in.shape) 150 | 151 | img_out = Image.fromarray(img_in.astype('uint8'), 'L') 152 | print(img_out.size) 153 | img_out.save(img_path) 154 | 155 | # 156 | # output train/val splits 157 | # 158 
| if args.split: 159 | print('creating train/val splits') 160 | 161 | train_file = open(os.path.join(args.output, 'train.txt'), 'w') 162 | val_file = open(os.path.join(args.output, 'val.txt'), 'w') 163 | 164 | train_count = 0 165 | val_count = 0 166 | 167 | for n in range(len(input_images)): 168 | for m in range(input_images[n].shape[0]): 169 | img_name = 'v{:d}_{:04d}.png'.format(n+1, m) 170 | rand = random.random() 171 | 172 | if rand < args.split_val: 173 | val_file.write(img_name + '\n') 174 | val_count = val_count + 1 175 | else: 176 | train_file.write(img_name + '\n') 177 | train_count = train_count + 1 178 | 179 | print('total: {:d}'.format(train_count + val_count)) 180 | print('train: {:d}'.format(train_count)) 181 | print('val: {:d}'.format(val_count)) 182 | 183 | train_file.close() 184 | val_file.close() 185 | 186 | -------------------------------------------------------------------------------- /datasets/sun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | 5 | from PIL import Image 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | 9 | class SunRGBDSegmentation(Dataset): 10 | """http://rgbd.cs.princeton.edu/challenge.html""" 11 | """https://github.com/ankurhanda/sunrgbd-meta-data""" 12 | 13 | def __init__(self, root_dir, image_set='train', train_extra=True, transforms=None): 14 | """ 15 | Parameters: 16 | root_dir (string): Root directory of the extracted SUN RGB-D dataset (with 21-class labels). 17 | image_set (string, optional): Select the image_set to use, ``train``, ``val`` 18 | train_extra (bool, optional): If True, use extra images during training 19 | transforms (callable, optional): Optional transform to be applied 20 | on a sample. 21 | """ 22 | self.root_dir = root_dir 23 | self.image_set = image_set 24 | self.transforms = transforms 25 | 26 | self.images = [] 27 | self.targets = [] 28 | 29 | if image_set == 'train': 30 | train_images, train_targets = self.gather_images(os.path.join(root_dir, 'SUNRGBD-train_images'), 31 | os.path.join(root_dir, 'train21labels')) 32 | 33 | self.images.extend(train_images) 34 | self.targets.extend(train_targets) 35 | 36 | if train_extra: 37 | extra_images, extra_targets = self.gather_images(os.path.join(root_dir, 'SUNRGBD-trainextra_images'), 38 | os.path.join(root_dir, 'trainextra21labels')) 39 | 40 | self.images.extend(extra_images) 41 | self.targets.extend(extra_targets) 42 | 43 | elif image_set == 'val': 44 | val_images, val_targets = self.gather_images(os.path.join(root_dir, 'SUNRGBD-test_images'), 45 | os.path.join(root_dir, 'test21labels')) 46 | 47 | self.images.extend(val_images) 48 | self.targets.extend(val_targets) 49 | 50 | def gather_images(self, images_path, labels_path, max_images=5500): 51 | image_files = [] 52 | label_files = [] 53 | 54 | for n in range(max_images): 55 | image_filename = os.path.join(images_path, 'img-{:06d}.jpg'.format(n)) 56 | label_filename = os.path.join(labels_path, 'img-{:06d}.png'.format(n)) 57 | 58 | if os.path.isfile(image_filename) and os.path.isfile(label_filename): 59 | image_files.append(image_filename) 60 | label_files.append(label_filename) 61 | 62 | return image_files, label_files 63 | 64 | def __len__(self): 65 | return len(self.images) 66 | 67 | def __getitem__(self, index): 68 | image = Image.open(self.images[index]).convert('RGB') 69 | target = Image.open(self.targets[index]) 70 | 71 | if self.transforms is not None: 72 | image, target = self.transforms(image, target) 73 | 74 | return image, target 75 | 76 |
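#
# usage sketch (not part of the repository): the dataset wrappers above (DeepSceneSegmentation,
# MHPSegmentation, NYUDepth, SunRGBDSegmentation) all return (image, target) pairs, so they plug
# into the same DataLoader setup that train.py builds. Assumes the snippet is run from the
# repository root and that '/data/sunrgbd' is a placeholder path to the extracted SUN RGB-D
# images and remapped labels.
#
import torch
import utils
import transforms as T
from datasets.sun import SunRGBDSegmentation

transform = T.Compose([
    T.Resize((320, 320)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = SunRGBDSegmentation('/data/sunrgbd', image_set='val', transforms=transform)

# utils.collate_fn batches the (image, target) pairs the same way train.py does
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False,
                                     num_workers=2, collate_fn=utils.collate_fn)

images, targets = next(iter(loader))
print(len(dataset), images.shape, targets.shape)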
-------------------------------------------------------------------------------- /datasets/sun_remap.py: -------------------------------------------------------------------------------- 1 | # 2 | # this script remaps the SUNRGB-D class label ID's (range 0-37) 3 | # to lie within a range that the segmentation networks support (21 classes) 4 | # 5 | # note that this processes the metadata downloaded from here: 6 | # https://github.com/ankurhanda/sunrgbd-meta-data 7 | # 8 | import os 9 | import re 10 | import copy 11 | import argparse 12 | 13 | from PIL import Image 14 | from multiprocessing import Pool as ProcessPool 15 | 16 | 17 | # 18 | # map of existing class label ID's (range 0-37) to new ID's (range 0-21) 19 | # each entry consists of a tuple (new_ID, name, color) 20 | # 21 | CLASS_MAP = [ (0, 'other', (0, 0, 0)), 22 | (1, 'wall', (128, 0, 0)), 23 | (2, 'floor', (0, 128, 0)), 24 | (3, 'cabinet', (128, 128, 0)), 25 | (4, 'bed', (0, 0, 128)), 26 | (5, 'chair', (128, 0, 128)), 27 | (6, 'sofa', (0, 128, 128)), 28 | (7, 'table', (128, 128, 128)), 29 | (8, 'door', (64, 0, 0)), 30 | (9, 'window', (192, 0, 0)), 31 | (3, 'bookshelf', (64, 128, 0)), 32 | (10, 'picture', (192, 128, 0)), 33 | (7, 'counter', (64, 0, 128)), 34 | (11, 'blinds', (192, 0, 128)), 35 | (7, 'desk', (64, 128, 128)), 36 | (3, 'shelves', (192, 128, 128)), 37 | (11, 'curtain', (0, 64, 0)), 38 | (3, 'dresser', (128, 64, 0)), 39 | (4, 'pillow', (0, 192, 0)), 40 | (10, 'mirror', (128, 192, 0)), 41 | (2, 'floor_mat', (0, 64, 128)), 42 | (12, 'clothes', (128, 64, 128)), 43 | (13, 'ceiling', (0, 192, 128)), 44 | (14, 'books', (128, 192, 128)), 45 | (15, 'fridge', (64, 64, 0)), 46 | (10, 'tv', (192, 64, 0)), 47 | (0, 'paper', (64, 192, 0)), 48 | (12, 'towel', (192, 192, 0)), 49 | (20, 'shower_curtain', (64, 64, 128)), 50 | (0, 'box', (192, 64, 128)), 51 | (1, 'whiteboard', (64, 192, 128)), 52 | (16, 'person', (192, 192, 128)), 53 | (7, 'night_stand', (0, 0, 64)), 54 | (17, 'toilet', (128, 0, 64)), 55 | (18, 'sink', (0, 128, 64)), 56 | (19, 'lamp', (128, 128, 64)), 57 | (20, 'bathtub', (0, 0, 192)), 58 | (0, 'bag', (128, 0, 192)) ] 59 | 60 | 61 | # 62 | # new class label names, corresponding to remapped class ID's (range 0-21) 63 | # 64 | """ 65 | other 66 | wall 67 | floor 68 | cabinet/shelves/bookshelf/dresser 69 | bed/pillow 70 | chair 71 | sofa 72 | table 73 | door 74 | window 75 | picture/tv/mirror 76 | blinds/curtain 77 | clothes 78 | ceiling 79 | books 80 | fridge 81 | person 82 | toilet 83 | sink 84 | lamp 85 | bathtub 86 | """ 87 | 88 | 89 | # Generate Color Map in PASCAL VOC format 90 | def generate_color_map(N=38): 91 | """ 92 | https://github.com/meetshah1995/pytorch-semseg/blob/801fb200547caa5b0d91b8dde56b837da029f746/ptsemseg/loader/sunrgbd_loader.py#L108 93 | """ 94 | def bitget(byteval, idx): 95 | return (byteval & (1 << idx)) != 0 96 | 97 | print('') 98 | print('color map: ') 99 | 100 | cmap = [] 101 | 102 | for i in range(N): 103 | r = g = b = 0 104 | c = i 105 | 106 | for j in range(8): 107 | r = r | (bitget(c, 0) << 7 - j) 108 | g = g | (bitget(c, 1) << 7 - j) 109 | b = b | (bitget(c, 2) << 7 - j) 110 | c = c >> 3 111 | 112 | color = (r,g,b) 113 | print(color) 114 | cmap.append(color) 115 | 116 | return cmap 117 | 118 | 119 | def remap_labels(args): 120 | input_path = args[0] 121 | output_path = args[1] 122 | colorized = args[2] 123 | 124 | print('{:s} -> {:s}'.format(input_path, output_path)) 125 | 126 | if os.path.isfile(output_path): 127 | print('skipping image {:s}, already exists'.format(output_path)) 128 | 
return 129 | 130 | img_input = Image.open(input_path) 131 | img_output = Image.new('RGB' if colorized is True else 'L', (img_input.width, img_input.height)) 132 | 133 | for y in range(img_input.height): 134 | for x in range(img_input.width): 135 | org_label = img_input.getpixel((x,y))#[0] 136 | new_label = CLASS_MAP[org_label][2 if colorized else 0] 137 | img_output.putpixel((x,y), new_label) 138 | 139 | img_output.save(output_path) 140 | 141 | 142 | def sorted_alphanumeric(data): 143 | convert = lambda text: int(text) if text.isdigit() else text.lower() 144 | alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 145 | return sorted(data, key=alphanum_key) 146 | 147 | 148 | if __name__ == "__main__": 149 | 150 | parser = argparse.ArgumentParser(description='Remap SUNRGB-D Segmentation Images') 151 | parser.add_argument('input', type=str, metavar='IN', help='path to directory of annotated images to remap') 152 | parser.add_argument('output', type=str, metavar='OUT', help='path to directory to save remaped annotation images') 153 | parser.add_argument('--colorized', action='store_true', help='output colorized segmentation maps (RGB)') 154 | parser.add_argument('--workers', default=8, type=int, metavar='N', help='number of data loading workers (default: 8)') 155 | args = parser.parse_args() 156 | 157 | if not os.path.exists(args.output): 158 | os.makedirs(args.output) 159 | 160 | files = sorted_alphanumeric(os.listdir(args.input)) 161 | worker_args = [] 162 | 163 | for n in range(len(files)): 164 | worker_args.append((os.path.join(args.input, files[n]), os.path.join(args.output, 'img-{:06d}.png'.format(n+1)), args.colorized)) 165 | 166 | #for n in worker_args: 167 | # remap_labels(n) 168 | 169 | with ProcessPool(processes=args.workers) as pool: 170 | pool.map(remap_labels, worker_args) 171 | 172 | -------------------------------------------------------------------------------- /onnx_export.py: -------------------------------------------------------------------------------- 1 | # 2 | # converts a saved PyTorch model to ONNX format 3 | # 4 | import os 5 | import argparse 6 | 7 | import torch 8 | import torchvision.models as models 9 | 10 | 11 | # parse command line 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--input', type=str, default='model_best.pth', help="path to input PyTorch model (default: model_best.pth)") 14 | parser.add_argument('--output', type=str, default='', help="desired path of converted ONNX model (default: .onnx)") 15 | parser.add_argument('--model-dir', type=str, default='', help="directory to look for the input PyTorch model in, and export the converted ONNX model to (if --output doesn't specify a directory)") 16 | 17 | opt = parser.parse_args() 18 | print(opt) 19 | 20 | # format input model path 21 | if opt.model_dir: 22 | opt.model_dir = os.path.expanduser(opt.model_dir) 23 | opt.input = os.path.join(opt.model_dir, opt.input) 24 | 25 | # set the device 26 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 27 | print('running on device ' + str(device)) 28 | 29 | # load the model checkpoint 30 | print('loading checkpoint: ' + opt.input) 31 | checkpoint = torch.load(opt.input) 32 | 33 | arch = checkpoint['arch'] 34 | num_classes = checkpoint['num_classes'] 35 | 36 | print('checkpoint accuracy: {:.3f}% mean IoU, {:.3f}% accuracy'.format(checkpoint['mean_IoU'], checkpoint['accuracy'])) 37 | 38 | # create the model architecture 39 | print('using model: ' + arch) 40 | print('num classes: ' + str(num_classes)) 41 
| 42 | model = models.segmentation.__dict__[arch](num_classes=num_classes, 43 | aux_loss=None, 44 | pretrained=False, 45 | export_onnx=True) 46 | 47 | # load the model weights 48 | model.load_state_dict(checkpoint['model']) 49 | 50 | model.to(device) 51 | model.eval() 52 | 53 | print(model) 54 | print('') 55 | 56 | # create example image data 57 | resolution = checkpoint['resolution'] 58 | input = torch.ones((1, 3, resolution[0], resolution[1])).cuda() 59 | print('input size: {:d}x{:d}'.format(resolution[1], resolution[0])) 60 | 61 | # format output model path 62 | if not opt.output: 63 | opt.output = arch + '.onnx' 64 | 65 | if opt.model_dir and opt.output.find('/') == -1 and opt.output.find('\\') == -1: 66 | opt.output = os.path.join(opt.model_dir, opt.output) 67 | 68 | # export the model 69 | input_names = [ "input_0" ] 70 | output_names = [ "output_0" ] 71 | 72 | print('exporting model to ONNX...') 73 | torch.onnx.export(model, input, opt.output, verbose=True, input_names=input_names, output_names=output_names) 74 | print('model exported to: {:s}'.format(opt.output)) 75 | 76 | 77 | -------------------------------------------------------------------------------- /onnx_validate.py: -------------------------------------------------------------------------------- 1 | # 2 | # Check that an ONNX model is valid and well-formed. 3 | # 4 | # Before running this script, install the following: 5 | # 6 | # $ sudo apt-get install protobuf-compiler libprotoc-dev 7 | # $ pip install onnx 8 | # 9 | import onnx 10 | import argparse 11 | import sys 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('model', type=str, default='resnet18.onnx', help='path to ONNX model to validate') 15 | args = parser.parse_args(sys.argv[1:]) 16 | 17 | # Load the ONNX model 18 | model = onnx.load(args.model) 19 | 20 | # Print a human readable representation of the graph 21 | print('Network Graph:') 22 | print(onnx.helper.printable_graph(model.graph)) 23 | print('') 24 | 25 | # Print model metadata 26 | print('ONNX version: ' + onnx.__version__) 27 | print('IR version: {:d}'.format(model.ir_version)) 28 | print('Producer name: ' + model.producer_name) 29 | print('Producer version: ' + model.producer_version) 30 | print('Model version: {:d}'.format(model.model_version)) 31 | print('') 32 | 33 | # Check that the IR is well formed 34 | print('Checking model IR...') 35 | onnx.checker.check_model(model) 36 | print('The model was checked successfully!') 37 | 38 | 39 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Note -- this training script is tweaked from the original version at: 3 | # 4 | # https://github.com/pytorch/vision/tree/v0.3.0/references/segmentation 5 | # 6 | # It's also meant to be used against this fork of torchvision, which includes 7 | # some patches for exporting to ONNX and adds fcn_resnet18 and fcn_resnet34: 8 | # 9 | # https://github.com/dusty-nv/vision/tree/v0.3.0 10 | # 11 | import argparse 12 | import datetime 13 | import time 14 | import math 15 | import os 16 | import shutil 17 | 18 | import torch 19 | import torch.utils.data 20 | from torch import nn 21 | import torchvision 22 | 23 | from datasets.coco_utils import get_coco 24 | from datasets.cityscapes_utils import get_cityscapes 25 | from datasets.deepscene import DeepSceneSegmentation 26 | 27 | from datasets.mhp import MHPSegmentation 28 | from datasets.nyu import NYUDepth 29 | from datasets.sun 
import SunRGBDSegmentation 30 | 31 | import transforms as T 32 | import utils 33 | 34 | model_names = sorted(name for name in torchvision.models.segmentation.__dict__ 35 | if name.islower() and not name.startswith("__") 36 | and callable(torchvision.models.segmentation.__dict__[name])) 37 | 38 | # 39 | # parse command-line arguments 40 | # 41 | def parse_args(): 42 | parser = argparse.ArgumentParser(description='PyTorch Segmentation Training') 43 | 44 | parser.add_argument('data', metavar='DIR', help='path to dataset') 45 | parser.add_argument('--dataset', default='voc', help='dataset type: voc, voc_aug, coco, cityscapes, deepscene, mhp, nyu, sun (default: voc)') 46 | parser.add_argument('-a', '--arch', metavar='ARCH', default='fcn_resnet18', 47 | choices=model_names, 48 | help='model architecture: ' + 49 | ' | '.join(model_names) + 50 | ' (default: fcn_resnet18)') 51 | parser.add_argument('--aux-loss', action='store_true', help='train with auxilliary loss') 52 | parser.add_argument('--resolution', default=320, type=int, metavar='N', 53 | help='NxN resolution used for scaling the training dataset (default: 320x320) ' 54 | 'to specify a non-square resolution, use the --width and --height options') 55 | parser.add_argument('--width', default=argparse.SUPPRESS, type=int, metavar='X', 56 | help='desired width of the training dataset. if this option is not set, --resolution will be used') 57 | parser.add_argument('--height', default=argparse.SUPPRESS, type=int, metavar='Y', 58 | help='desired height of the training dataset. if this option is not set, --resolution will be used') 59 | parser.add_argument('--device', default='cuda', help='device') 60 | parser.add_argument('-b', '--batch-size', default=4, type=int) 61 | parser.add_argument('--epochs', default=30, type=int, metavar='N', help='number of total epochs to run') 62 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 63 | help='number of data loading workers (default: 16)') 64 | parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate') 65 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 66 | help='momentum') 67 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 68 | metavar='W', help='weight decay (default: 1e-4)', 69 | dest='weight_decay') 70 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency') 71 | parser.add_argument('--model-dir', default='.', help='path where to save output models') 72 | parser.add_argument('--resume', default='', help='resume from checkpoint') 73 | parser.add_argument("--test-only", dest="test_only", help="Only test the model", action="store_true") 74 | parser.add_argument("--pretrained", dest="pretrained", help="Use pre-trained models (only supported for fcn_resnet101)", action="store_true") 75 | 76 | # distributed training parameters 77 | parser.add_argument('--world-size', default=1, type=int, 78 | help='number of distributed processes') 79 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') 80 | 81 | args = parser.parse_args() 82 | return args 83 | 84 | 85 | # 86 | # load desired dataset 87 | # 88 | def get_dataset(name, path, image_set, transform): 89 | def sbd(*args, **kwargs): 90 | return torchvision.datasets.SBDataset(*args, mode='segmentation', **kwargs) 91 | paths = { 92 | "voc": (path, torchvision.datasets.VOCSegmentation, 21), 93 | "voc_aug": (path, sbd, 21), 94 | "coco": (path, get_coco, 21), 95 | "cityscapes": (path, 
get_cityscapes, 21), 96 | "deepscene": (path, DeepSceneSegmentation, 5), 97 | "mhp": (path, MHPSegmentation, 21), 98 | "nyu": (path, NYUDepth, 21), 99 | "sun": (path, SunRGBDSegmentation, 21), 100 | } 101 | p, ds_fn, num_classes = paths[name] 102 | 103 | ds = ds_fn(p, image_set=image_set, transforms=transform) 104 | return ds, num_classes 105 | 106 | 107 | # 108 | # create data transform 109 | # 110 | def get_transform(train, resolution): 111 | transforms = [] 112 | 113 | # if square resolution, perform some aspect cropping 114 | # otherwise, resize to the resolution as specified 115 | if resolution[0] == resolution[1]: 116 | base_size = resolution[0] + 32 #520 117 | crop_size = resolution[0] #480 118 | 119 | min_size = int((0.5 if train else 1.0) * base_size) 120 | max_size = int((2.0 if train else 1.0) * base_size) 121 | 122 | transforms.append(T.RandomResize(min_size, max_size)) 123 | 124 | # during training mode, perform some data randomization 125 | if train: 126 | transforms.append(T.RandomHorizontalFlip(0.5)) 127 | transforms.append(T.RandomCrop(crop_size)) 128 | else: 129 | transforms.append(T.Resize(resolution)) 130 | 131 | if train: 132 | transforms.append(T.RandomHorizontalFlip(0.5)) 133 | 134 | transforms.append(T.ToTensor()) 135 | transforms.append(T.Normalize(mean=[0.485, 0.456, 0.406], 136 | std=[0.229, 0.224, 0.225])) 137 | 138 | return T.Compose(transforms) 139 | 140 | 141 | # 142 | # define the loss functions 143 | # 144 | def criterion(inputs, target): 145 | losses = {} 146 | for name, x in inputs.items(): 147 | losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255) 148 | 149 | if len(losses) == 1: 150 | return losses['out'] 151 | 152 | return losses['out'] + 0.5 * losses['aux'] 153 | 154 | 155 | # 156 | # evaluate model IoU (intersection over union) 157 | # 158 | def evaluate(model, data_loader, device, num_classes): 159 | model.eval() 160 | confmat = utils.ConfusionMatrix(num_classes) 161 | metric_logger = utils.MetricLogger(delimiter=" ") 162 | header = 'Test:' 163 | with torch.no_grad(): 164 | for image, target in metric_logger.log_every(data_loader, 100, header): 165 | image, target = image.to(device), target.to(device) 166 | output = model(image) 167 | output = output['out'] 168 | 169 | confmat.update(target.flatten(), output.argmax(1).flatten()) 170 | 171 | confmat.reduce_from_all_processes() 172 | 173 | return confmat 174 | 175 | 176 | # 177 | # train for one epoch over the dataset 178 | # 179 | def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq): 180 | model.train() 181 | metric_logger = utils.MetricLogger(delimiter=" ") 182 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) 183 | header = 'Epoch: [{}]'.format(epoch) 184 | for image, target in metric_logger.log_every(data_loader, print_freq, header): 185 | image, target = image.to(device), target.to(device) 186 | output = model(image) 187 | loss = criterion(output, target) 188 | 189 | optimizer.zero_grad() 190 | loss.backward() 191 | optimizer.step() 192 | 193 | lr_scheduler.step() 194 | 195 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) 196 | 197 | 198 | # 199 | # main training function 200 | # 201 | def main(args): 202 | if args.model_dir: 203 | utils.mkdir(args.model_dir) 204 | 205 | utils.init_distributed_mode(args) 206 | print(args) 207 | 208 | device = torch.device(args.device) 209 | 210 | # determine the desired resolution 211 | resolution = (args.resolution, args.resolution) 
212 | 213 | if "width" in args and "height" in args: 214 | resolution = (args.height, args.width) 215 | 216 | # load the train and val datasets 217 | dataset, num_classes = get_dataset(args.dataset, args.data, "train", get_transform(train=True, resolution=resolution)) 218 | dataset_test, _ = get_dataset(args.dataset, args.data, "val", get_transform(train=False, resolution=resolution)) 219 | 220 | if args.distributed: 221 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 222 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) 223 | else: 224 | train_sampler = torch.utils.data.RandomSampler(dataset) 225 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 226 | 227 | data_loader = torch.utils.data.DataLoader( 228 | dataset, batch_size=args.batch_size, 229 | sampler=train_sampler, num_workers=args.workers, 230 | collate_fn=utils.collate_fn, drop_last=True) 231 | 232 | data_loader_test = torch.utils.data.DataLoader( 233 | dataset_test, batch_size=1, 234 | sampler=test_sampler, num_workers=args.workers, 235 | collate_fn=utils.collate_fn) 236 | 237 | print("=> training with dataset: '{:s}' (train={:d}, val={:d})".format(args.dataset, len(dataset), len(dataset_test))) 238 | print("=> training with resolution: {:d}x{:d}, {:d} classes".format(resolution[1], resolution[0], num_classes)) 239 | print("=> training with model: {:s}".format(args.arch)) 240 | 241 | # create the segmentation model 242 | model = torchvision.models.segmentation.__dict__[args.arch](num_classes=num_classes, 243 | aux_loss=args.aux_loss, 244 | pretrained=args.pretrained) 245 | model.to(device) 246 | 247 | if args.distributed: 248 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 249 | 250 | if args.resume: 251 | checkpoint = torch.load(args.resume, map_location='cpu') 252 | model.load_state_dict(checkpoint['model']) 253 | 254 | model_without_ddp = model 255 | 256 | if args.distributed: 257 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 258 | model_without_ddp = model.module 259 | 260 | # eval-only mode 261 | if args.test_only: 262 | confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) 263 | print(confmat) 264 | return 265 | 266 | # create the optimizer 267 | params_to_optimize = [ 268 | {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]}, 269 | {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]}, 270 | ] 271 | 272 | if args.aux_loss: 273 | params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad] 274 | params_to_optimize.append({"params": params, "lr": args.lr * 10}) 275 | 276 | optimizer = torch.optim.SGD( 277 | params_to_optimize, 278 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 279 | 280 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR( 281 | optimizer, 282 | lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9) 283 | 284 | # training loop 285 | start_time = time.time() 286 | best_IoU = 0.0 287 | 288 | for epoch in range(args.epochs): 289 | if args.distributed: 290 | train_sampler.set_epoch(epoch) 291 | 292 | # train the model over the next epoc 293 | train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq) 294 | 295 | # test the model on the val dataset 296 | confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) 297 | print(confmat) 298 | 299 | # save model checkpoint 300 | 
checkpoint_path = os.path.join(args.model_dir, 'model_{}.pth'.format(epoch)) 301 | 302 | utils.save_on_master( 303 | { 304 | 'model': model_without_ddp.state_dict(), 305 | 'optimizer': optimizer.state_dict(), 306 | 'epoch': epoch, 307 | 'args': args, 308 | 'arch': args.arch, 309 | 'dataset': args.dataset, 310 | 'num_classes': num_classes, 311 | 'resolution': resolution, 312 | 'accuracy': confmat.acc_global, 313 | 'mean_IoU': confmat.mean_IoU 314 | }, 315 | checkpoint_path) 316 | 317 | print('saved checkpoint to: {:s} ({:.3f}% mean IoU, {:.3f}% accuracy)'.format(checkpoint_path, confmat.mean_IoU, confmat.acc_global)) 318 | 319 | if confmat.mean_IoU > best_IoU: 320 | best_IoU = confmat.mean_IoU 321 | best_path = os.path.join(args.model_dir, 'model_best.pth') 322 | shutil.copyfile(checkpoint_path, best_path) 323 | print('saved best model to: {:s} ({:.3f}% mean IoU, {:.3f}% accuracy)'.format(best_path, best_IoU, confmat.acc_global)) 324 | 325 | total_time = time.time() - start_time 326 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 327 | print('Training time {}'.format(total_time_str)) 328 | 329 | 330 | if __name__ == "__main__": 331 | args = parse_args() 332 | main(args) 333 | 334 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import random 4 | 5 | import torch 6 | from torchvision import transforms as T 7 | from torchvision.transforms import functional as F 8 | 9 | 10 | def pad_if_smaller(img, size, fill=0): 11 | min_size = min(img.size) 12 | if min_size < size: 13 | ow, oh = img.size 14 | padh = size - oh if oh < size else 0 15 | padw = size - ow if ow < size else 0 16 | img = F.pad(img, (0, 0, padw, padh), fill=fill) 17 | return img 18 | 19 | 20 | class Compose(object): 21 | def __init__(self, transforms): 22 | self.transforms = transforms 23 | 24 | def __call__(self, image, target): 25 | for t in self.transforms: 26 | image, target = t(image, target) 27 | return image, target 28 | 29 | 30 | class Resize(object): 31 | def __init__(self, size): 32 | self.size = size 33 | 34 | def __call__(self, image, target): 35 | image = F.resize(image, self.size) 36 | target = F.resize(target, self.size, interpolation=Image.NEAREST) 37 | return image, target 38 | 39 | class RandomResize(object): 40 | def __init__(self, min_size, max_size=None): 41 | self.min_size = min_size 42 | if max_size is None: 43 | max_size = min_size 44 | self.max_size = max_size 45 | 46 | def __call__(self, image, target): 47 | size = random.randint(self.min_size, self.max_size) 48 | image = F.resize(image, size) 49 | target = F.resize(target, size, interpolation=Image.NEAREST) 50 | return image, target 51 | 52 | 53 | class RandomHorizontalFlip(object): 54 | def __init__(self, flip_prob): 55 | self.flip_prob = flip_prob 56 | 57 | def __call__(self, image, target): 58 | if random.random() < self.flip_prob: 59 | image = F.hflip(image) 60 | target = F.hflip(target) 61 | return image, target 62 | 63 | 64 | class RandomCrop(object): 65 | def __init__(self, size): 66 | self.size = size 67 | 68 | def __call__(self, image, target): 69 | image = pad_if_smaller(image, self.size) 70 | target = pad_if_smaller(target, self.size, fill=255) 71 | crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) 72 | image = F.crop(image, *crop_params) 73 | target = F.crop(target, *crop_params) 74 | return image, target 75 | 76 | 77 | class 
CenterCrop(object): 78 | def __init__(self, size): 79 | self.size = size 80 | 81 | def __call__(self, image, target): 82 | image = F.center_crop(image, self.size) 83 | target = F.center_crop(target, self.size) 84 | return image, target 85 | 86 | 87 | class ToTensor(object): 88 | def __call__(self, image, target): 89 | image = F.to_tensor(image) 90 | target = torch.as_tensor(np.asarray(target), dtype=torch.int64) 91 | return image, target 92 | 93 | 94 | class Normalize(object): 95 | def __init__(self, mean, std): 96 | self.mean = mean 97 | self.std = std 98 | 99 | def __call__(self, image, target): 100 | image = F.normalize(image, mean=self.mean, std=self.std) 101 | return image, target 102 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from collections import defaultdict, deque 3 | import datetime 4 | import math 5 | import time 6 | import torch 7 | import torch.distributed as dist 8 | 9 | import errno 10 | import os 11 | 12 | 13 | class SmoothedValue(object): 14 | """Track a series of values and provide access to smoothed values over a 15 | window or the global series average. 16 | """ 17 | 18 | def __init__(self, window_size=20, fmt=None): 19 | if fmt is None: 20 | fmt = "{median:.4f} ({global_avg:.4f})" 21 | self.deque = deque(maxlen=window_size) 22 | self.total = 0.0 23 | self.count = 0 24 | self.fmt = fmt 25 | 26 | def update(self, value, n=1): 27 | self.deque.append(value) 28 | self.count += n 29 | self.total += value * n 30 | 31 | def synchronize_between_processes(self): 32 | """ 33 | Warning: does not synchronize the deque! 34 | """ 35 | if not is_dist_avail_and_initialized(): 36 | return 37 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 38 | dist.barrier() 39 | dist.all_reduce(t) 40 | t = t.tolist() 41 | self.count = int(t[0]) 42 | self.total = t[1] 43 | 44 | @property 45 | def median(self): 46 | d = torch.tensor(list(self.deque)) 47 | return d.median().item() 48 | 49 | @property 50 | def avg(self): 51 | d = torch.tensor(list(self.deque), dtype=torch.float32) 52 | return d.mean().item() 53 | 54 | @property 55 | def global_avg(self): 56 | return self.total / self.count 57 | 58 | @property 59 | def max(self): 60 | return max(self.deque) 61 | 62 | @property 63 | def value(self): 64 | return self.deque[-1] 65 | 66 | def __str__(self): 67 | return self.fmt.format( 68 | median=self.median, 69 | avg=self.avg, 70 | global_avg=self.global_avg, 71 | max=self.max, 72 | value=self.value) 73 | 74 | 75 | class ConfusionMatrix(object): 76 | def __init__(self, num_classes): 77 | self.num_classes = num_classes 78 | self.mat = None 79 | self.acc_global = 0.0 80 | self.mean_IoU = 0.0 81 | 82 | def update(self, a, b): 83 | n = self.num_classes 84 | if self.mat is None: 85 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device) 86 | with torch.no_grad(): 87 | k = (a >= 0) & (a < n) 88 | inds = n * a[k].to(torch.int64) + b[k] 89 | self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) 90 | 91 | def reset(self): 92 | self.mat.zero_() 93 | 94 | def compute(self): 95 | h = self.mat.float() 96 | 97 | # not all object classes may have data from the val category 98 | h_sum1 = h.sum(1) 99 | h_sum1[h_sum1 == 0] = 1 100 | 101 | if not torch.equal(h.sum(1), h_sum1): 102 | print('Test: Warning -- some classes may be missing validation examples') 103 | 104 | acc_global = 
torch.diag(h).sum() / h.sum() 105 | acc = torch.diag(h) / h.sum(1) 106 | iu = torch.diag(h) / (h_sum1 + h.sum(0) - torch.diag(h)) 107 | self.acc_global = acc_global.item() * 100 108 | self.mean_IoU = iu.mean().item() * 100 109 | return acc_global, acc, iu 110 | 111 | def reduce_from_all_processes(self): 112 | if not torch.distributed.is_available(): 113 | return 114 | if not torch.distributed.is_initialized(): 115 | return 116 | torch.distributed.barrier() 117 | torch.distributed.all_reduce(self.mat) 118 | 119 | def __str__(self): 120 | acc_global, acc, iu = self.compute() 121 | return ( 122 | 'global correct: {:.1f}\n' 123 | 'average row correct: {}\n' 124 | 'IoU: {}\n' 125 | 'mean IoU: {:.1f}').format( 126 | acc_global.item() * 100, 127 | ['{:.1f}'.format(i) for i in (acc * 100).tolist()], 128 | ['{:.1f}'.format(i) for i in (iu * 100).tolist()], 129 | iu.mean().item() * 100) 130 | 131 | 132 | class MetricLogger(object): 133 | def __init__(self, delimiter="\t"): 134 | self.meters = defaultdict(SmoothedValue) 135 | self.delimiter = delimiter 136 | 137 | def update(self, **kwargs): 138 | for k, v in kwargs.items(): 139 | if isinstance(v, torch.Tensor): 140 | v = v.item() 141 | assert isinstance(v, (float, int)) 142 | self.meters[k].update(v) 143 | 144 | def __getattr__(self, attr): 145 | if attr in self.meters: 146 | return self.meters[attr] 147 | if attr in self.__dict__: 148 | return self.__dict__[attr] 149 | raise AttributeError("'{}' object has no attribute '{}'".format( 150 | type(self).__name__, attr)) 151 | 152 | def __str__(self): 153 | loss_str = [] 154 | for name, meter in self.meters.items(): 155 | loss_str.append( 156 | "{}: {}".format(name, str(meter)) 157 | ) 158 | return self.delimiter.join(loss_str) 159 | 160 | def synchronize_between_processes(self): 161 | for meter in self.meters.values(): 162 | meter.synchronize_between_processes() 163 | 164 | def add_meter(self, name, meter): 165 | self.meters[name] = meter 166 | 167 | def log_every(self, iterable, print_freq, header=None): 168 | i = 0 169 | if not header: 170 | header = '' 171 | start_time = time.time() 172 | end = time.time() 173 | iter_time = SmoothedValue(fmt='{avg:.4f}') 174 | data_time = SmoothedValue(fmt='{avg:.4f}') 175 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 176 | log_msg = self.delimiter.join([ 177 | header, 178 | '[{0' + space_fmt + '}/{1}]', 179 | 'eta: {eta}', 180 | '{meters}', 181 | 'time: {time}', 182 | 'data: {data}', 183 | 'max mem: {memory:.0f}' 184 | ]) 185 | MB = 1024.0 * 1024.0 186 | for obj in iterable: 187 | data_time.update(time.time() - end) 188 | yield obj 189 | iter_time.update(time.time() - end) 190 | if i % print_freq == 0: 191 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 192 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 193 | print(log_msg.format( 194 | i, len(iterable), eta=eta_string, 195 | meters=str(self), 196 | time=str(iter_time), data=str(data_time), 197 | memory=torch.cuda.max_memory_allocated() / MB)) 198 | i += 1 199 | end = time.time() 200 | total_time = time.time() - start_time 201 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 202 | print('{} Total time: {}'.format(header, total_time_str)) 203 | 204 | 205 | def cat_list(images, fill_value=0): 206 | max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) 207 | batch_shape = (len(images),) + max_size 208 | batched_imgs = images[0].new(*batch_shape).fill_(fill_value) 209 | for img, pad_img in zip(images, batched_imgs): 210 | pad_img[..., 
:img.shape[-2], :img.shape[-1]].copy_(img) 211 | return batched_imgs 212 | 213 | 214 | def collate_fn(batch): 215 | images, targets = list(zip(*batch)) 216 | batched_imgs = cat_list(images, fill_value=0) 217 | batched_targets = cat_list(targets, fill_value=255) 218 | return batched_imgs, batched_targets 219 | 220 | 221 | def mkdir(path): 222 | try: 223 | os.makedirs(path) 224 | except OSError as e: 225 | if e.errno != errno.EEXIST: 226 | raise 227 | 228 | 229 | def setup_for_distributed(is_master): 230 | """ 231 | This function disables printing when not in master process 232 | """ 233 | import builtins as __builtin__ 234 | builtin_print = __builtin__.print 235 | 236 | def print(*args, **kwargs): 237 | force = kwargs.pop('force', False) 238 | if is_master or force: 239 | builtin_print(*args, **kwargs) 240 | 241 | __builtin__.print = print 242 | 243 | 244 | def is_dist_avail_and_initialized(): 245 | if not dist.is_available(): 246 | return False 247 | if not dist.is_initialized(): 248 | return False 249 | return True 250 | 251 | 252 | def get_world_size(): 253 | if not is_dist_avail_and_initialized(): 254 | return 1 255 | return dist.get_world_size() 256 | 257 | 258 | def get_rank(): 259 | if not is_dist_avail_and_initialized(): 260 | return 0 261 | return dist.get_rank() 262 | 263 | 264 | def is_main_process(): 265 | return get_rank() == 0 266 | 267 | 268 | def save_on_master(*args, **kwargs): 269 | if is_main_process(): 270 | torch.save(*args, **kwargs) 271 | 272 | 273 | def init_distributed_mode(args): 274 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 275 | args.rank = int(os.environ["RANK"]) 276 | args.world_size = int(os.environ['WORLD_SIZE']) 277 | args.gpu = int(os.environ['LOCAL_RANK']) 278 | elif 'SLURM_PROCID' in os.environ: 279 | args.rank = int(os.environ['SLURM_PROCID']) 280 | args.gpu = args.rank % torch.cuda.device_count() 281 | elif hasattr(args, "rank"): 282 | pass 283 | else: 284 | print('Not using distributed mode') 285 | args.distributed = False 286 | return 287 | 288 | args.distributed = True 289 | 290 | torch.cuda.set_device(args.gpu) 291 | args.dist_backend = 'nccl' 292 | print('| distributed init (rank {}): {}'.format( 293 | args.rank, args.dist_url), flush=True) 294 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 295 | world_size=args.world_size, rank=args.rank) 296 | setup_for_distributed(args.rank == 0) 297 | --------------------------------------------------------------------------------