├── .gitignore ├── README.md ├── datasets ├── __init__.py ├── cityscapes_remap.py ├── cityscapes_utils.py ├── coco_utils.py ├── corrupt_images.py ├── custom_dataset.py ├── deepscene.py ├── deepscene_remap.py ├── mhp.py ├── mhp_remap.py ├── mhp_utils.py ├── nyu.py ├── nyu_dump.py ├── sun.py └── sun_remap.py ├── labelme2voc.py ├── models ├── __init__.py ├── _utils.py ├── resnet.py ├── segmentation │ ├── __init__.py │ ├── _utils.py │ ├── deeplabv3.py │ ├── fcn.py │ └── segmentation.py └── utils.py ├── onnx_export.py ├── onnx_validate.py ├── requirements.txt ├── split_custom.py ├── train.py ├── transforms.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | #models 107 | *.pth 108 | *.onnx 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-segmentation 2 | Training of semantic segmentation networks with PyTorch 3 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | print("pytorch-segmentation/datasets/__init__.py") 2 | 3 | #from .mhp import * 4 | -------------------------------------------------------------------------------- /datasets/cityscapes_remap.py: -------------------------------------------------------------------------------- 1 | # 2 | # this script remaps the original Cityscapes class label ID's (range 0-33) 
# to lie within a range that the segmentation networks support (21 classes)
#
# see below for the mapping of the original class ID's to the new class ID's,
# along with the class descriptions, some of which are combined from the originals.
#
# this script will overwrite the *_labelIds.png files from the gtCoarse/gtFine sets.
# to run it, launch these example commands for the desired train/train_extra/val sets:
#
#   $ python3 cityscapes_remap.py <CITYSCAPES_ROOT>/gtCoarse/train
#   $ python3 cityscapes_remap.py <CITYSCAPES_ROOT>/gtCoarse/val
#
# DIR is expected to contain one subdirectory per city; every file whose name
# contains "labelIds.png" inside those subdirectories is remapped in place.
#
import os
import copy
import argparse

from PIL import Image
from multiprocessing import Pool as ProcessPool


#
# map of existing class label ID's (range 0-33) to new ID's (range 0-20, i.e. 21 classes);
# the list index is the original ID and the value is the new ID
#
LABEL_MAP = [0,    # unlabeled
             1,    # ego vehicle
             0,    # rectification border
             0,    # out of roi
             2,    # static
             2,    # dynamic
             2,    # ground
             3,    # road
             4,    # sidewalk
             3,    # parking
             3,    # rail track
             5,    # building
             6,    # wall
             7,    # fence
             7,    # guard rail
             3,    # bridge
             3,    # tunnel
             8,    # pole
             8,    # polegroup
             9,    # traffic light
             10,   # traffic sign
             11,   # vegetation
             12,   # terrain
             13,   # sky
             14,   # person
             14,   # rider
             15,   # car
             16,   # truck
             17,   # bus
             16,   # caravan
             16,   # trailer
             18,   # train
             19,   # motorcycle
             20]   # bicycle

#
# new class label names, corresponding to remapped class ID's (range 0-20)
#
"""
void
ego_vehicle
ground
road
sidewalk
building
wall
fence
pole
traffic_light
traffic_sign
vegetation
terrain
sky
person
car
truck
bus
train
motorcycle
bicycle
"""


def remap_labels(filename):
    """
    Remap the class IDs of one *_labelIds.png image in place, using LABEL_MAP.

    Parameters:
        filename (string): path to the label image; it is overwritten with the result

    Raises:
        IndexError: if the image contains a label value that is not covered by LABEL_MAP
    """
    print(filename)

    with Image.open(filename) as img:
        # a pixel access object is much faster than per-pixel getpixel()/putpixel() calls
        pixels = img.load()

        for y in range(img.height):
            for x in range(img.width):
                pixels[x, y] = LABEL_MAP[pixels[x, y]]

        # save while the image is still open -- close() releases the pixel data
        img.save(filename)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Remap Cityscapes Segmenation Labels')
    parser.add_argument('dir', metavar='DIR', help='path to data labels (e.g. gtCoarse/train, ect.)')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    args = parser.parse_args()

    img_list = []

    for city in os.listdir(args.dir):
        img_dir = os.path.join(args.dir, city)

        # skip stray files next to the per-city directories (os.listdir() would fail on them)
        if not os.path.isdir(img_dir):
            continue

        img_list.extend(os.path.join(img_dir, file_name)
                        for file_name in os.listdir(img_dir)
                        if "labelIds.png" in file_name)

    with ProcessPool(processes=args.workers) as pool:
        pool.map(remap_labels, img_list)


# ------------------------------------------------------------------------------
# /datasets/cityscapes_utils.py
# ------------------------------------------------------------------------------
import os
import copy
import torch
import torch.utils.data
import torchvision

from PIL import Image


#
# note: it is slow to remap the label categories at runtime,
# so use cityscapes_remap.py to do it in advance.
#
class FilterAndRemapCityscapesCategories(object):
    """
    Joint (image, anno) transform that sets every annotation pixel whose label is
    not in ``categories`` to 0. Kept labels are left unchanged (no remapping is done).

    Parameters:
        categories (iterable of int): label values to keep
        classes: class descriptions; only printed for information
    """

    def __init__(self, categories, classes):
        self.categories = categories
        self.classes = classes
        print(self.classes)

    def __call__(self, image, anno):
        # work on a copy so the caller's annotation image is not modified
        anno = copy.deepcopy(anno)

        # hoisted out of the pixel loop for O(1) membership tests
        keep = set(self.categories)
        pixels = anno.load()

        for y in range(anno.height):
            for x in range(anno.width):
                if pixels[x, y] not in keep:
                    pixels[x, y] = 0

        return image, anno


def get_cityscapes(root, image_set, transforms):
    """
    Create the torchvision Cityscapes dataset (fine annotations, semantic targets).

    Parameters:
        root (string): root directory of the Cityscapes dataset
        image_set (string): split to load, passed as ``split`` (e.g. 'train', 'val')
        transforms (callable, optional): joint transform, called as
            ``transforms(image, target)`` and returning the new ``(image, target)``

    Returns:
        torchvision.datasets.Cityscapes
    """
    # labels are expected to have been remapped offline with cityscapes_remap.py,
    # so FilterAndRemapCityscapesCategories is not applied here.
    #
    # the transforms operate on (image, target) pairs (as in get_coco and the other
    # datasets), so they are passed as the joint ``transforms`` argument -- passing
    # them as ``transform``/``target_transform`` would call them with one argument.
    dataset = torchvision.datasets.Cityscapes(root, split=image_set, mode='fine', target_type='semantic',
                                              transforms=transforms)

    return dataset


# ------------------------------------------------------------------------------
# /datasets/coco_utils.py
# ------------------------------------------------------------------------------
import os
import copy
import torch
import torch.utils.data
import torchvision

from PIL import Image
from transforms import Compose
from pycocotools import mask as coco_mask


class FilterAndRemapCocoCategories(object):
    """
    Joint (image, anno) transform that keeps only the COCO annotations whose
    ``category_id`` is in ``categories``. If ``remap`` is True, each kept
    ``category_id`` is replaced by its index within ``categories``.
    """

    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, anno):
        anno = [obj for obj in anno if obj["category_id"] in self.categories]
        if not self.remap:
            return image, anno
        # copy so the caller's annotation dicts are not modified in place
        anno = copy.deepcopy(anno)
        for obj in anno:
            obj["category_id"] = self.categories.index(obj["category_id"])
        return image, anno


def convert_coco_poly_to_mask(segmentations, height, width):
    """
    Rasterize COCO segmentations into per-instance masks.

    Parameters:
        segmentations (list): one COCO segmentation (polygons/RLE) per instance
        height (int): image height
        width (int): image width

    Returns:
        torch.Tensor: (num_instances, height, width) masks, or an empty
                      (0, height, width) uint8 tensor if there are no instances
    """
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        # an instance may consist of several polygons: merge them into one mask
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask(object):
    """
    Joint (image, anno) transform that turns a list of COCO annotations into a
    single-channel PIL label image (category IDs, 255 where instances overlap).
    """

    def __call__(self, image, anno):
        w, h = image.size
        segmentations = [obj["segmentation"] for obj in anno]
        cats = [obj["category_id"] for obj in anno]
        if segmentations:
            masks = convert_coco_poly_to_mask(segmentations, h, w)
            cats = torch.as_tensor(cats, dtype=masks.dtype)
            # merge all instance masks into a single segmentation map
            # with its corresponding categories
            target, _ = (masks * cats[:, None, None]).max(dim=0)
            # discard overlapping instances
            target[masks.sum(0) > 1] = 255
        else:
            target = torch.zeros((h, w), dtype=torch.uint8)
        target = Image.fromarray(target.numpy())
        return image, target


def _coco_remove_images_without_annotations(dataset, cat_list=None):
    """
    Return a Subset of ``dataset`` without images whose (category-filtered)
    annotations are empty or cover at most 1000 pixels in total.
    """
    def _has_valid_annotation(anno):
        # if it's empty, there is no annotation
        if len(anno) == 0:
            return False
        # if more than 1k pixels occupied in the image
        return sum(obj["area"] for obj in anno) > 1000

    assert isinstance(dataset, torchvision.datasets.CocoDetection)
    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if cat_list:
            anno = [obj for obj in anno if obj["category_id"] in cat_list]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    dataset = torch.utils.data.Subset(dataset, ids)
    return dataset


def get_coco(root, image_set, transforms):
    """
    Create a COCO 2017 segmentation dataset restricted to CAT_LIST.

    Parameters:
        root (string): COCO root containing train2017/, val2017/ and annotations/
        image_set (string): 'train' or 'val'
        transforms (callable): joint (image, target) transform applied after the
                               category filtering and mask conversion

    Returns:
        a torchvision CocoDetection dataset (wrapped in a Subset for 'train')
    """
    PATHS = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
        # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
    }
    # COCO category IDs that are kept; their list index becomes the training label.
    # NOTE(review): presumably background + the 20 PASCAL VOC classes, as in
    # torchvision's segmentation reference -- confirm.
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
                1, 64, 20, 63, 7, 72]

    transforms = Compose([
        FilterAndRemapCocoCategories(CAT_LIST, remap=True),
        ConvertCocoPolysToMask(),
        transforms
    ])

    img_folder, ann_file = PATHS[image_set]
    img_folder = os.path.join(root, img_folder)
    ann_file = os.path.join(root, ann_file)

    dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)

    return dataset


# ------------------------------------------------------------------------------
# /datasets/corrupt_images.py
# ------------------------------------------------------------------------------
#
# this script detects corrupted images in a directory,
# and optionally moves them to a specified directory
#
import argparse
import warnings
import shutil

from os import listdir
from os import makedirs
from os import remove
from os.path import join

from PIL import Image


def _check_image(file_path):
    """
    Fully decode one image and check its dimensions.

    Parameters:
        file_path (string): path to the image file

    Returns:
        None if the image is fine, otherwise the message describing the problem
    """
    try:
        # the file is closed when this block exits, before the caller may move it
        with Image.open(file_path) as img:
            img.load()           # decode all of the pixel data, not just the header
            img.convert('RGB')   # also fail on images that cannot be converted

            if img.width < 16 or img.height < 16:
                return 'Bad image dimensions ({:d}x{:d}): {:s}'.format(img.width, img.height, file_path)

    except (IOError, SyntaxError, UserWarning, RuntimeWarning):
        return 'Bad image: {:s}'.format(file_path)

    return None


def main():
    """Scan DIR for corrupt/too-small images, report them and optionally move them."""
    parser = argparse.ArgumentParser(description='corrupt image remover')
    parser.add_argument('dir', metavar='DIR', help='path to directory')
    parser.add_argument('--move', type=str, default=None, help='optional path to directory that corrupt images are moved to')
    args = parser.parse_args()

    # if the destination directory did not exist, shutil.move() would rename each bad
    # image to that path, overwriting the image moved before it
    if args.move is not None:
        makedirs(args.move, exist_ok=True)

    # promote decoder warnings (e.g. truncated files) to exceptions so they get caught
    warnings.filterwarnings("error")

    num_bad = 0

    for filename in listdir(args.dir):
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
            continue

        file_path = join(args.dir, filename)
        message = _check_image(file_path)

        if message is None:
            continue

        print(message)  # print out the names of corrupt files

        if args.move is not None:
            shutil.move(file_path, args.move)

        num_bad = num_bad + 1

    print('Detected {:d} corrupted images from {:s} '.format(num_bad, args.dir))


if __name__ == "__main__":
    main()


# ------------------------------------------------------------------------------
# /datasets/custom_dataset.py
# ------------------------------------------------------------------------------
import os
import re
import math
import torch

from PIL import Image
from torch.utils.data import Dataset, DataLoader


class CustomSegmentation(Dataset):
    """
    Based on ADE20K dataset format

        images/
            training/
            validation/

        annotations/
            training/
            validation/

    Images and annotations are paired by their natural-sorted order within each directory.
    """

    def __init__(self, root_dir, image_set='train', transforms=None):
        """
        Parameters:
            root_dir (string): root directory laid out as described above
            image_set (string, optional): 'train' or 'val'; any other value gives an empty dataset
            transforms (callable, optional): joint transform called as transforms(image, target)
        """
        self.images = []
        self.targets = []
        self.transforms = transforms

        if image_set == 'train':
            train_images, train_targets = self.gather_images(os.path.join(root_dir, 'images/training'),
                                                             os.path.join(root_dir, 'annotations/training'))

            self.images.extend(train_images)
            self.targets.extend(train_targets)

        elif image_set == 'val':
            val_images, val_targets = self.gather_images(os.path.join(root_dir, 'images/validation'),
                                                         os.path.join(root_dir, 'annotations/validation'))

            self.images.extend(val_images)
            self.targets.extend(val_targets)

    def gather_images(self, images_path, labels_path):
        """
        List the image and label files of a split and pair them by natural-sorted order.

        Parameters:
            images_path (string): directory containing the input images
            labels_path (string): directory containing the annotation images

        Returns:
            (list, list): full paths of the images and of their labels. If the two
                          directories hold different numbers of files, a warning is
                          printed and only the first min(#images, #labels) pairs are kept.
        """
        def sorted_alphanumeric(data):
            # natural sort, so that e.g. 'img2' comes before 'img10'
            convert = lambda text: int(text) if text.isdigit() else text.lower()
            alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
            return sorted(data, key=alphanum_key)

        image_files = sorted_alphanumeric(os.listdir(images_path))
        label_files = sorted_alphanumeric(os.listdir(labels_path))

        if len(image_files) != len(label_files):
            print('warning: images path has a different number of files than labels path')
            print(' ({:d} files) - {:s}'.format(len(image_files), images_path))
            print(' ({:d} files) - {:s}'.format(len(label_files), labels_path))

        # zip() stops at the shorter list, so a count mismatch can neither raise
        # IndexError nor leave unpaired (non-joined) label names behind
        pairs = list(zip(image_files, label_files))

        image_files = [os.path.join(images_path, image_name) for image_name, _ in pairs]
        label_files = [os.path.join(labels_path, label_name) for _, label_name in pairs]

        return image_files, label_files

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target


# ------------------------------------------------------------------------------
# /datasets/deepscene.py
# ------------------------------------------------------------------------------
import os
import re
import math
import torch

from PIL import Image
from torch.utils.data import Dataset, DataLoader


class DeepSceneSegmentation(Dataset):
    """http://deepscene.cs.uni-freiburg.de/"""

    def __init__(self, root_dir, image_set='train', train_extra=True, transforms=None):
        """
        Parameters:
            root_dir (string): Root directory of the DeepScene (Freiburg Forest) dataset,
                               containing train/, trainextra/ and test/ with rgb/ and GT_index/.
            image_set (string, optional): Select the image_set to use, ``train``, ``val``
                                          (``val`` reads the test/ directory)
            train_extra (bool, optional): If True, use extra images during training
            transforms (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.image_set = image_set
        self.transforms = transforms

        self.images = []
        self.targets = []

        if image_set == 'train':
            train_images, train_targets = self.gather_images(os.path.join(root_dir, 'train/rgb'),
                                                             os.path.join(root_dir, 'train/GT_index'))

            self.images.extend(train_images)
            self.targets.extend(train_targets)

            if train_extra:
                extra_images, extra_targets = self.gather_images(os.path.join(root_dir, 'trainextra/rgb'),
                                                                 os.path.join(root_dir, 'trainextra/GT_index'))

                self.images.extend(extra_images)
                self.targets.extend(extra_targets)

        elif image_set == 'val':
            val_images, val_targets = self.gather_images(os.path.join(root_dir, 'test/rgb'),
                                                         os.path.join(root_dir, 'test/GT_index'))

            self.images.extend(val_images)
            self.targets.extend(val_targets)

    def gather_images(self, images_path, labels_path):
        """
        List the image and label files of a split and pair them by natural-sorted order.

        Parameters:
            images_path (string): directory containing the RGB images
            labels_path (string): directory containing the GT_index label images

        Returns:
            (list, list): full paths of the images and of their labels. If the two
                          directories hold different numbers of files, a warning is
                          printed and only the first min(#images, #labels) pairs are kept.
        """
        def sorted_alphanumeric(data):
            # natural sort, so that e.g. 'img2' comes before 'img10'
            convert = lambda text: int(text) if text.isdigit() else text.lower()
            alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ]
            return sorted(data, key=alphanum_key)

        image_files = sorted_alphanumeric(os.listdir(images_path))
        label_files = sorted_alphanumeric(os.listdir(labels_path))

        if len(image_files) != len(label_files):
            print('warning: images path has a different number of files than labels path')
            print(' ({:d} files) - {:s}'.format(len(image_files), images_path))
            print(' ({:d} files) - {:s}'.format(len(label_files), labels_path))

        # zip() stops at the shorter list, so a count mismatch can neither raise
        # IndexError nor leave unpaired (non-joined) label names behind
        pairs = list(zip(image_files, label_files))

        image_files = [os.path.join(images_path, image_name) for image_name, _ in pairs]
        label_files = [os.path.join(labels_path, label_name) for _, label_name in pairs]

        return image_files, label_files

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target


# ------------------------------------------------------------------------------
# /datasets/deepscene_remap.py
# ------------------------------------------------------------------------------
#
# this script remaps the original DeepScene freiburg_forest dataset annotations
# from RGB images (with 6 classes) to single-channel index images (with 5 classes).
#
# the original dataset can be downloaded from: http://deepscene.cs.uni-freiburg.de/
#
import os
import copy
import argparse

from PIL import Image
from multiprocessing import Pool as ProcessPool


#
# map of existing class label ID's and colors to new ID's
# each entry consists of a tuple (new_ID, name, color)
#
CLASS_MAP = [ (0, 'trail', (170, 170, 170)),
              (1, 'grass', (0, 255, 0)),
              (2, 'vegetation', (102, 102, 51)),
              (3, 'obstacle', (0, 0, 0)),
              (4, 'sky', (0, 120, 255)),
              (2, 'void', (0, 60, 0)) ]   # 'void' appears to be trees in the dataset, so it is mapped to vegetation


def lookup_class(color):
    """
    Find the new class ID of an annotation color.

    Parameters:
        color (tuple): (R, G, B) pixel value of the original annotation image

    Returns:
        int: the new class ID from CLASS_MAP, or -1 if the color is not listed
    """
    for class_id, _, class_color in CLASS_MAP:
        if color == class_color:
            return class_id

    print('could not find class with color ' + str(color))
    return -1


def remap_labels(args):
    """
    Convert one RGB annotation image into a class-index (or re-colorized) image.

    Parameters:
        args (tuple): (input_path, output_path, colorized), packed into a single
                      tuple so that the function can be used with Pool.map()

    The output is an 'L' image of class IDs, or, if colorized is set, an 'RGB'
    image painted with the CLASS_MAP color of each class ID. Outputs that already
    exist are skipped. Colors missing from CLASS_MAP are treated as 'void'
    (the last CLASS_MAP entry, mapped to vegetation).
    """
    input_path, output_path, colorized = args

    print('{:s} -> {:s}'.format(input_path, output_path))

    if os.path.isfile(output_path):
        print('skipping image {:s}, already exists'.format(output_path))
        return

    # class ID used for colors that are not in CLASS_MAP ('void')
    fallback_id = CLASS_MAP[-1][0]

    # each distinct color is looked up (and warned about) only once per image
    id_cache = {}

    with Image.open(input_path) as img_input:
        # normalize to RGB so RGBA/palette annotations still match the CLASS_MAP colors
        img_rgb = img_input.convert('RGB')

    width, height = img_rgb.size
    img_output = Image.new('RGB' if colorized else 'L', (width, height))

    # pixel access objects are much faster than per-pixel getpixel()/putpixel() calls
    src = img_rgb.load()
    dst = img_output.load()

    for y in range(height):
        for x in range(width):
            color = src[x, y]
            class_id = id_cache.get(color)

            if class_id is None:
                class_id = lookup_class(color)
                if class_id < 0:
                    class_id = fallback_id
                id_cache[color] = class_id

            dst[x, y] = CLASS_MAP[class_id][2] if colorized else class_id

    img_output.save(output_path)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Remap DeepScene Segmentation Images')
    parser.add_argument('input', type=str, metavar='IN', help='path to directory of annotated images to remap')
    parser.add_argument('output', type=str, metavar='OUT', help='path to directory to save remaped annotation images')
    parser.add_argument('--colorized', action='store_true', help='output colorized segmentation maps (RGB)')
    parser.add_argument('--workers', default=8, type=int, metavar='N', help='number of data loading workers (default: 8)')
    args = parser.parse_args()

    if not os.path.exists(args.output):
        os.makedirs(args.output)

    # one (input, output, colorized) work item per annotation file
    worker_args = [(os.path.join(args.input, name), os.path.join(args.output, name), args.colorized)
                   for name in os.listdir(args.input)]

    with ProcessPool(processes=args.workers) as pool:
        pool.map(remap_labels, worker_args)


# ------------------------------------------------------------------------------
# /datasets/mhp.py
# ------------------------------------------------------------------------------
import os
import math
import torch

from PIL import Image
from .mhp_utils import mhp_image_list
from torch.utils.data import Dataset, DataLoader


class MHPSegmentation(Dataset):
    """https://lv-mhp.github.io/dataset"""

    def __init__(self, root_dir, image_set='train', transforms=None):
        """
        Parameters:
            root_dir (string): Root directory of the extracted LV-MHP-V2 dataset.
            image_set (string, optional): Select the image_set to use, ``train``, ``val``
            transforms (callable, optional): Optional transform to be applied
                on a sample.

        Frames listed in list/<image_set>.txt are used only if both
        <image_set>/images/<index>.jpg and <image_set>/parsing_annos/<index>.png exist.
        """
        self.root_dir = root_dir
        self.image_set = image_set
        self.transforms = transforms

        self.images = []
        self.targets = []

        img_list = mhp_image_list(os.path.join(root_dir, 'list/{:s}.txt'.format(image_set)))

        for img_index in img_list:
            img_filename = os.path.join(root_dir, image_set, 'images/{:d}.jpg'.format(img_index))

            # NOTE(review): mhp_remap.py reads per-person '<index>_<NN>_<NN>.png' files from
            # parsing_annos/ and its usage example writes the merged '<index>.png' files to
            # parsing_annos_21/ -- confirm which directory this path should point to.
            target_filename = os.path.join(root_dir, image_set, 'parsing_annos/{:d}.png'.format(img_index))

            if os.path.isfile(img_filename) and os.path.isfile(target_filename):
                self.images.append(img_filename)
                self.targets.append(target_filename)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target


# ------------------------------------------------------------------------------
# /datasets/mhp_remap.py
# ------------------------------------------------------------------------------
#
# this script remaps the original MHP dataset (https://lv-mhp.github.io/)
# class label ID's (range 0-58) to 21 classes that PyTorch FCN-ResNet uses.
#
# see below for the mapping of the original class ID's to the new class ID's,
# along with the class descriptions, some of which are combined from the originals.
#
# this script reads the train/val image list from LV-MHP-V2 and given the path to
# the parsing_annos directory containing the annotations, remaps them to the output.
10 | # 11 | # $ DATA= 12 | # $ python3 mhp_remap.py --list=$DATA/list/train.txt $DATA/train/parsing_annos $DATA/train/parsing_annos_21 13 | # $ python3 mhp_remap.py --list=$DATA/list/val.txt $DATA/val/parsing_annos $DATA/val/parsing_annos_21 14 | # 15 | import os 16 | import copy 17 | import argparse 18 | 19 | from PIL import Image 20 | from mhp_utils import mhp_image_list 21 | from multiprocessing import Pool as ProcessPool 22 | 23 | 24 | # 25 | # map of existing class label ID's (range 0-58) to new ID's (range 0-21) 26 | # 27 | LABEL_MAP = [0, # Background 28 | 1, # Cap/hat 29 | 1, # Helmet 30 | 2, # Face 31 | 3, # Hair 32 | 4, # Left-arm 33 | 4, # Right-arm 34 | 5, # Left-hand 35 | 5, # Right-hand 36 | 19, # Protector 37 | 9, # Bikini/bra 38 | 7, # Jacket/windbreaker/hoodie 39 | 6, # Tee-shirt 40 | 6, # Polo-shirt 41 | 6, # Sweater 42 | 8, # Singlet 43 | 10, # Torso-skin 44 | 11, # Pants 45 | 12, # Shorts/swim-shorts 46 | 12, # Skirt 47 | 13, # Stockings 48 | 13, # Socks 49 | 14, # Left-boot 50 | 14, # Right-boot 51 | 14, # Left-shoe 52 | 14, # Right-shoe 53 | 14, # Left-highheel 54 | 14, # Right-highheel 55 | 14, # Left-sandal 56 | 14, # Right-sandal 57 | 15, # Left-leg 58 | 15, # Right-leg 59 | 16, # Left-foot 60 | 16, # Right-foot 61 | 7, # Coat 62 | 8, # Dress 63 | 8, # Robe 64 | 8, # Jumpsuit 65 | 8, # Other-full-body-clothes 66 | 1, # Headwear 67 | 17, # Backpack 68 | 20, # Ball 69 | 20, # Bats 70 | 19, # Belt 71 | 20, # Bottle 72 | 17, # Carrybag 73 | 17, # Cases 74 | 18, # Sunglasses 75 | 18, # Eyewear 76 | 19, # Glove 77 | 19, # Scarf 78 | 20, # Umbrella 79 | 17, # Wallet/purse 80 | 19, # Watch 81 | 19, # Wristband 82 | 19, # Tie 83 | 19, # Other-accessory 84 | 6, # Other-upper-body-clothes 85 | 11] # Other-lower-body-clothes 86 | 87 | # 88 | # new class label names, corresponding to remapped class ID's (range 0-21) 89 | # 90 | """ 91 | void 92 | hat/helmet/headwear 93 | face 94 | hair 95 | arm 96 | hand 97 | shirt 98 | jacket/coat 99 | dress/robe 100 | 
bikini/bra 101 | torso_skin 102 | pants 103 | shorts 104 | socks/stockings 105 | shoe/boot 106 | leg 107 | foot 108 | backpack/purse/bag 109 | sunglasses/eyewear 110 | other_accessory 111 | other_item 112 | """ 113 | 114 | def remap_labels(args): 115 | input_dir = args[0] 116 | output_dir = args[1] 117 | img_index = args[2] 118 | 119 | src_images = 0 120 | img_output = None 121 | 122 | # check if this image has already been processed (i.e. by a previous run) 123 | output_path = os.path.join(output_dir, '{:d}.png'.format(img_index)) 124 | 125 | if os.path.isfile(output_path): 126 | print('skipping image {:d}, already exists'.format(img_index)) 127 | return 128 | 129 | # determine the number of source images for this frame 130 | for n in range(1,30): 131 | if os.path.isfile(os.path.join(input_dir, '{:d}_{:02d}_01.png'.format(img_index, n))): 132 | src_images = n 133 | 134 | print('processing image {:d} \t(src_images={:d})'.format(img_index, src_images)) 135 | 136 | # aggregate and remap the source images into one output 137 | for n in range(1, src_images+1): 138 | img_input = Image.open(os.path.join(input_dir, '{:d}_{:02d}_{:02d}.png'.format(img_index, src_images, n))) 139 | 140 | if img_output is None: 141 | img_output = Image.new('L', (img_input.width, img_input.height)) 142 | 143 | for y in range(img_input.height): 144 | for x in range(img_input.width): 145 | org_label = img_input.getpixel((x,y))[0] 146 | new_label = LABEL_MAP[org_label] 147 | 148 | if new_label != 0: # only overwrite non-background pixels 149 | img_output.putpixel((x,y), new_label) 150 | 151 | #if org_label != 0: 152 | # print('img {:d}_{:02d} ({:d}, {:d}) {:d} -> {:d}'.format(img_index, n, x, y, org_label, new_label)) 153 | 154 | img_output.save(output_path) 155 | 156 | 157 | if __name__ == "__main__": 158 | 159 | parser = argparse.ArgumentParser(description='Remap MHP Annotation Images') 160 | parser.add_argument('input', type=str, metavar='IN', help='path to directory of annotated images to 
remap') 161 | parser.add_argument('output', type=str, metavar='OUT', help='path to directory to save remaped annotation images') 162 | parser.add_argument('--list', type=str, required=True, metavar='LIST', help='path to image list') 163 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers (default: 8)') 164 | args = parser.parse_args() 165 | 166 | if not os.path.exists(args.output): 167 | os.makedirs(args.output) 168 | 169 | img_list = mhp_image_list(args.list) 170 | pool_args = [] 171 | 172 | for img_index in img_list: 173 | pool_args.append( (args.input, args.output, img_index) ) 174 | 175 | with ProcessPool(processes=args.workers) as pool: 176 | pool.map(remap_labels, pool_args) 177 | 178 | -------------------------------------------------------------------------------- /datasets/mhp_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def mhp_image_list(filename): 4 | """ 5 | Read one of the image index lists from LV-MHP-v2/list 6 | 7 | Parameters: 8 | filename (string): path to the image list file 9 | 10 | Returns: 11 | list (int): list of int's that correspond to image names 12 | """ 13 | list_file = open(filename, 'r') 14 | img_list = [] 15 | 16 | while True: 17 | next_line = list_file.readline() 18 | 19 | if not next_line: 20 | break 21 | 22 | img_list.append(int(next_line)) 23 | 24 | return img_list 25 | 26 | -------------------------------------------------------------------------------- /datasets/nyu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | 5 | from PIL import Image 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | 9 | class NYUDepth(Dataset): 10 | """https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html""" 11 | 12 | def __init__(self, root_dir, image_set='train', transforms=None): 13 | """ 14 | Parameters: 15 | root_dir (string): 
Root directory of the dumped NYU-Depth dataset. 16 | image_set (string, optional): Select the image_set to use, ``train``, ``val`` 17 | transforms (callable, optional): Optional transform to be applied 18 | on a sample. 19 | """ 20 | self.root_dir = root_dir 21 | self.image_set = image_set 22 | self.transforms = transforms 23 | 24 | self.images = [] 25 | self.targets = [] 26 | 27 | img_list = self.read_image_list(os.path.join(root_dir, '{:s}.txt'.format(image_set))) 28 | 29 | for img_name in img_list: 30 | img_filename = os.path.join(root_dir, 'images/{:s}'.format(img_name)) 31 | target_filename = os.path.join(root_dir, 'depth/{:s}'.format(img_name)) 32 | 33 | if os.path.isfile(img_filename) and os.path.isfile(target_filename): 34 | self.images.append(img_filename) 35 | self.targets.append(target_filename) 36 | 37 | def read_image_list(self, filename): 38 | """ 39 | Read one of the image index lists 40 | 41 | Parameters: 42 | filename (string): path to the image list file 43 | 44 | Returns: 45 | list (int): list of strings that correspond to image names 46 | """ 47 | list_file = open(filename, 'r') 48 | img_list = [] 49 | 50 | while True: 51 | next_line = list_file.readline() 52 | 53 | if not next_line: 54 | break 55 | 56 | img_list.append(next_line.rstrip()) 57 | 58 | return img_list 59 | 60 | def __len__(self): 61 | return len(self.images) 62 | 63 | def __getitem__(self, index): 64 | image = Image.open(self.images[index]).convert('RGB') 65 | target = Image.open(self.targets[index]) 66 | 67 | if self.transforms is not None: 68 | image, target = self.transforms(image, target) 69 | 70 | return image, target 71 | 72 | -------------------------------------------------------------------------------- /datasets/nyu_dump.py: -------------------------------------------------------------------------------- 1 | # 2 | # this script reads the .mat files from the NYU-Depth datasets 3 | # and dumps their contents to individual image files for training 4 | # 5 | import os 6 | 
import random 7 | import argparse 8 | import h5py 9 | import numpy as np 10 | 11 | from PIL import Image 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Dump NYU-Depth .mat files') 15 | parser.add_argument('input', type=str, nargs='+', metavar='IN', help='paths to input .mat files') 16 | parser.add_argument('--output', type=str, default='dump', metavar='OUT', help='path to directory to save dataset') 17 | parser.add_argument("--images", action="store_true", help="dump RGB images") 18 | parser.add_argument("--labels", action="store_true", help="dump label images") 19 | parser.add_argument("--depth", action="store_true", help="dump depth images") 20 | parser.add_argument("--depth-levels", type=int, default=20, help="number of disparity depth levels (default: 20)") 21 | parser.add_argument("--split", action="store_true", help="dump train/val split files") 22 | parser.add_argument("--split-val", type=float, default=0.15, help="fraction of dataset to split between train/val") 23 | 24 | args = parser.parse_args() 25 | 26 | input_images = [] 27 | input_labels = [] 28 | input_depths = [] 29 | 30 | global_depth_min = 1000.0 31 | global_depth_max = -1000.0 32 | 33 | 34 | # 35 | # load arrays from .mat files 36 | # 37 | for filename in args.input: 38 | print('\n==> loading ' + filename) 39 | mat = h5py.File(filename, 'r') 40 | print(list(mat.keys())) 41 | 42 | if args.images or args.split: 43 | print('reading images') 44 | 45 | images = mat['images'] 46 | images = np.array(images) 47 | 48 | print(images.shape) 49 | 50 | pixel_min = images.min() 51 | pixel_max = images.max() 52 | pixel_avg = images.mean() 53 | 54 | print('min pixel: {:f}'.format(pixel_min)) 55 | print('max pixel: {:f}'.format(pixel_max)) 56 | print('avg pixel: {:f}'.format(pixel_avg)) 57 | 58 | input_images.append(images) 59 | 60 | if args.depth: 61 | print('reading depths') 62 | 63 | depths = mat['depths'] 64 | depths = np.array(depths) 65 | 66 | print(depths.shape) 67 | 68 | depth_min = 
depths.min() 69 | depth_max = depths.max() 70 | depth_avg = depths.mean() 71 | 72 | print('min depth: {:f}'.format(depth_min)) 73 | print('max depth: {:f}'.format(depth_max)) 74 | print('avg depth: {:f}'.format(depth_avg)) 75 | 76 | if depth_min < global_depth_min: 77 | global_depth_min = depth_min 78 | 79 | if depth_max > global_depth_max: 80 | global_depth_max = depth_max 81 | 82 | input_depths.append(depths) 83 | 84 | print('') 85 | 86 | 87 | # 88 | # process source images 89 | # 90 | if args.images: 91 | images_path = os.path.join(args.output, 'images') 92 | 93 | if not os.path.exists(images_path): 94 | os.makedirs(images_path) 95 | 96 | for n in range(len(input_images)): 97 | for m in range(input_images[n].shape[0]): 98 | img_name = 'v{:d}_{:04d}.png'.format(n+1, m) 99 | img_path = os.path.join(images_path, img_name) 100 | 101 | print('processing image ' + img_path) 102 | 103 | img_in = input_images[n][m] 104 | img_in = np.moveaxis(img_in, [0, 1, 2], [2, 1, 0]) 105 | #print(img_in.shape) 106 | 107 | img_out = Image.fromarray(img_in.astype('uint8'), 'RGB') 108 | print(img_out.size) 109 | img_out.save(img_path) 110 | 111 | 112 | # 113 | # process depth images 114 | # 115 | if args.depth: 116 | print('global min depth: {:f}'.format(global_depth_min)) 117 | print('global max depth: {:f}'.format(global_depth_max)) 118 | 119 | depth_path = os.path.join(args.output, 'depth') 120 | 121 | if not os.path.exists(depth_path): 122 | os.makedirs(depth_path) 123 | 124 | for n in range(len(input_depths)): 125 | print('\nprocessing depth/v{:d}'.format(n+1)) 126 | 127 | arr = input_depths[n] 128 | 129 | # rescale the depths to lie between [0, args.depth_levels] 130 | arr = np.subtract(arr, global_depth_min) 131 | arr = np.multiply(arr, 1.0 / (global_depth_max - global_depth_min) * float(args.depth_levels)) 132 | 133 | depth_min = arr.min() 134 | depth_max = arr.max() 135 | depth_avg = arr.mean() 136 | 137 | print('min depth: {:f}'.format(depth_min)) 138 | print('max depth: 
{:f}'.format(depth_max)) 139 | print('avg depth: {:f}'.format(depth_avg)) 140 | 141 | for m in range(arr.shape[0]): 142 | img_name = 'v{:d}_{:04d}.png'.format(n+1, m) 143 | img_path = os.path.join(depth_path, img_name) 144 | 145 | print('processing depth ' + img_path) 146 | 147 | img_in = arr[m] 148 | img_in = np.moveaxis(img_in, [0, 1], [1, 0]) 149 | #print(img_in.shape) 150 | 151 | img_out = Image.fromarray(img_in.astype('uint8'), 'L') 152 | print(img_out.size) 153 | img_out.save(img_path) 154 | 155 | # 156 | # output train/val splits 157 | # 158 | if args.split: 159 | print('creating train/val splits') 160 | 161 | train_file = open(os.path.join(args.output, 'train.txt'), 'w') 162 | val_file = open(os.path.join(args.output, 'val.txt'), 'w') 163 | 164 | train_count = 0 165 | val_count = 0 166 | 167 | for n in range(len(input_images)): 168 | for m in range(input_images[n].shape[0]): 169 | img_name = 'v{:d}_{:04d}.png'.format(n+1, m) 170 | rand = random.random() 171 | 172 | if rand < args.split_val: 173 | val_file.write(img_name + '\n') 174 | val_count = val_count + 1 175 | else: 176 | train_file.write(img_name + '\n') 177 | train_count = train_count + 1 178 | 179 | print('total: {:d}'.format(train_count + val_count)) 180 | print('train: {:d}'.format(train_count)) 181 | print('val: {:d}'.format(val_count)) 182 | 183 | train_file.close() 184 | val_file.close() 185 | 186 | -------------------------------------------------------------------------------- /datasets/sun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | 5 | from PIL import Image 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | 9 | class SunRGBDSegmentation(Dataset): 10 | """http://rgbd.cs.princeton.edu/challenge.html""" 11 | """https://github.com/ankurhanda/sunrgbd-meta-data""" 12 | 13 | def __init__(self, root_dir, image_set='train', train_extra=True, transforms=None): 14 | """ 15 | Parameters: 16 | root_dir 
(string): Root directory of the dumped NYU-Depth dataset. 17 | image_set (string, optional): Select the image_set to use, ``train``, ``val`` 18 | train_extra (bool, optional): If True, use extra images during training 19 | transforms (callable, optional): Optional transform to be applied 20 | on a sample. 21 | """ 22 | self.root_dir = root_dir 23 | self.image_set = image_set 24 | self.transforms = transforms 25 | 26 | self.images = [] 27 | self.targets = [] 28 | 29 | if image_set == 'train': 30 | train_images, train_targets = self.gather_images(os.path.join(root_dir, 'SUNRGBD-train_images'), 31 | os.path.join(root_dir, 'train21labels')) 32 | 33 | self.images.extend(train_images) 34 | self.targets.extend(train_targets) 35 | 36 | if train_extra: 37 | extra_images, extra_targets = self.gather_images(os.path.join(root_dir, 'SUNRGBD-trainextra_images'), 38 | os.path.join(root_dir, 'trainextra21labels')) 39 | 40 | self.images.extend(extra_images) 41 | self.targets.extend(extra_targets) 42 | 43 | elif image_set == 'val': 44 | val_images, val_targets = self.gather_images(os.path.join(root_dir, 'SUNRGBD-test_images'), 45 | os.path.join(root_dir, 'test21labels')) 46 | 47 | self.images.extend(val_images) 48 | self.targets.extend(val_targets) 49 | 50 | def gather_images(self, images_path, labels_path, max_images=5500): 51 | image_files = [] 52 | label_files = [] 53 | 54 | for n in range(max_images): 55 | image_filename = os.path.join(images_path, 'img-{:06d}.jpg'.format(n)) 56 | label_filename = os.path.join(labels_path, 'img-{:06d}.png'.format(n)) 57 | 58 | if os.path.isfile(image_filename) and os.path.isfile(label_filename): 59 | image_files.append(image_filename) 60 | label_files.append(label_filename) 61 | 62 | return image_files, label_files 63 | 64 | def __len__(self): 65 | return len(self.images) 66 | 67 | def __getitem__(self, index): 68 | image = Image.open(self.images[index]).convert('RGB') 69 | target = Image.open(self.targets[index]) 70 | 71 | if self.transforms is 
not None: 72 | image, target = self.transforms(image, target) 73 | 74 | return image, target 75 | 76 | -------------------------------------------------------------------------------- /datasets/sun_remap.py: -------------------------------------------------------------------------------- 1 | # 2 | # this script remaps the SUNRGB-D class label ID's (range 0-37) 3 | # to lie within a range that the segmentation networks support (21 classes) 4 | # 5 | # note that this processes the metadata downloaded from here: 6 | # https://github.com/ankurhanda/sunrgbd-meta-data 7 | # 8 | import os 9 | import re 10 | import copy 11 | import argparse 12 | 13 | from PIL import Image 14 | from multiprocessing import Pool as ProcessPool 15 | 16 | 17 | # 18 | # map of existing class label ID's (range 0-37) to new ID's (range 0-21) 19 | # each entry consists of a tuple (new_ID, name, color) 20 | # 21 | CLASS_MAP = [ (0, 'other', (0, 0, 0)), 22 | (1, 'wall', (128, 0, 0)), 23 | (2, 'floor', (0, 128, 0)), 24 | (3, 'cabinet', (128, 128, 0)), 25 | (4, 'bed', (0, 0, 128)), 26 | (5, 'chair', (128, 0, 128)), 27 | (6, 'sofa', (0, 128, 128)), 28 | (7, 'table', (128, 128, 128)), 29 | (8, 'door', (64, 0, 0)), 30 | (9, 'window', (192, 0, 0)), 31 | (3, 'bookshelf', (64, 128, 0)), 32 | (10, 'picture', (192, 128, 0)), 33 | (7, 'counter', (64, 0, 128)), 34 | (11, 'blinds', (192, 0, 128)), 35 | (7, 'desk', (64, 128, 128)), 36 | (3, 'shelves', (192, 128, 128)), 37 | (11, 'curtain', (0, 64, 0)), 38 | (3, 'dresser', (128, 64, 0)), 39 | (4, 'pillow', (0, 192, 0)), 40 | (10, 'mirror', (128, 192, 0)), 41 | (2, 'floor_mat', (0, 64, 128)), 42 | (12, 'clothes', (128, 64, 128)), 43 | (13, 'ceiling', (0, 192, 128)), 44 | (14, 'books', (128, 192, 128)), 45 | (15, 'fridge', (64, 64, 0)), 46 | (10, 'tv', (192, 64, 0)), 47 | (0, 'paper', (64, 192, 0)), 48 | (12, 'towel', (192, 192, 0)), 49 | (20, 'shower_curtain', (64, 64, 128)), 50 | (0, 'box', (192, 64, 128)), 51 | (1, 'whiteboard', (64, 192, 128)), 52 | (16, 
'person', (192, 192, 128)), 53 | (7, 'night_stand', (0, 0, 64)), 54 | (17, 'toilet', (128, 0, 64)), 55 | (18, 'sink', (0, 128, 64)), 56 | (19, 'lamp', (128, 128, 64)), 57 | (20, 'bathtub', (0, 0, 192)), 58 | (0, 'bag', (128, 0, 192)) ] 59 | 60 | 61 | # 62 | # new class label names, corresponding to remapped class ID's (range 0-21) 63 | # 64 | """ 65 | other 66 | wall 67 | floor 68 | cabinet/shelves/bookshelf/dresser 69 | bed/pillow 70 | chair 71 | sofa 72 | table 73 | door 74 | window 75 | picture/tv/mirror 76 | blinds/curtain 77 | clothes 78 | ceiling 79 | books 80 | fridge 81 | person 82 | toilet 83 | sink 84 | lamp 85 | bathtub 86 | """ 87 | 88 | 89 | # Generate Color Map in PASCAL VOC format 90 | def generate_color_map(N=38): 91 | """ 92 | https://github.com/meetshah1995/pytorch-semseg/blob/801fb200547caa5b0d91b8dde56b837da029f746/ptsemseg/loader/sunrgbd_loader.py#L108 93 | """ 94 | def bitget(byteval, idx): 95 | return (byteval & (1 << idx)) != 0 96 | 97 | print('') 98 | print('color map: ') 99 | 100 | cmap = [] 101 | 102 | for i in range(N): 103 | r = g = b = 0 104 | c = i 105 | 106 | for j in range(8): 107 | r = r | (bitget(c, 0) << 7 - j) 108 | g = g | (bitget(c, 1) << 7 - j) 109 | b = b | (bitget(c, 2) << 7 - j) 110 | c = c >> 3 111 | 112 | color = (r,g,b) 113 | print(color) 114 | cmap.append(color) 115 | 116 | return cmap 117 | 118 | 119 | def remap_labels(args): 120 | input_path = args[0] 121 | output_path = args[1] 122 | colorized = args[2] 123 | 124 | print('{:s} -> {:s}'.format(input_path, output_path)) 125 | 126 | if os.path.isfile(output_path): 127 | print('skipping image {:s}, already exists'.format(output_path)) 128 | return 129 | 130 | img_input = Image.open(input_path) 131 | img_output = Image.new('RGB' if colorized is True else 'L', (img_input.width, img_input.height)) 132 | 133 | for y in range(img_input.height): 134 | for x in range(img_input.width): 135 | org_label = img_input.getpixel((x,y))#[0] 136 | new_label = CLASS_MAP[org_label][2 if 
colorized else 0] 137 | img_output.putpixel((x,y), new_label) 138 | 139 | img_output.save(output_path) 140 | 141 | 142 | def sorted_alphanumeric(data): 143 | convert = lambda text: int(text) if text.isdigit() else text.lower() 144 | alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 145 | return sorted(data, key=alphanum_key) 146 | 147 | 148 | if __name__ == "__main__": 149 | 150 | parser = argparse.ArgumentParser(description='Remap SUNRGB-D Segmentation Images') 151 | parser.add_argument('input', type=str, metavar='IN', help='path to directory of annotated images to remap') 152 | parser.add_argument('output', type=str, metavar='OUT', help='path to directory to save remaped annotation images') 153 | parser.add_argument('--colorized', action='store_true', help='output colorized segmentation maps (RGB)') 154 | parser.add_argument('--workers', default=8, type=int, metavar='N', help='number of data loading workers (default: 8)') 155 | args = parser.parse_args() 156 | 157 | if not os.path.exists(args.output): 158 | os.makedirs(args.output) 159 | 160 | files = sorted_alphanumeric(os.listdir(args.input)) 161 | worker_args = [] 162 | 163 | for n in range(len(files)): 164 | worker_args.append((os.path.join(args.input, files[n]), os.path.join(args.output, 'img-{:06d}.png'.format(n+1)), args.colorized)) 165 | 166 | #for n in worker_args: 167 | # remap_labels(n) 168 | 169 | with ProcessPool(processes=args.workers) as pool: 170 | pool.map(remap_labels, worker_args) 171 | 172 | -------------------------------------------------------------------------------- /labelme2voc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import glob 7 | import os 8 | import os.path as osp 9 | import sys 10 | 11 | import imgviz 12 | import numpy as np 13 | 14 | import labelme 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser( 
19 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 20 | ) 21 | parser.add_argument("input_dir", help="input annotated directory") 22 | parser.add_argument("output_dir", help="output dataset directory") 23 | parser.add_argument("--labels", help="labels file", required=True) 24 | parser.add_argument( 25 | "--noviz", help="no visualization", action="store_true" 26 | ) 27 | args = parser.parse_args() 28 | 29 | if osp.exists(args.output_dir): 30 | print("Output directory already exists:", args.output_dir) 31 | sys.exit(1) 32 | os.makedirs(args.output_dir) 33 | os.makedirs(osp.join(args.output_dir, "JPEGImages")) 34 | os.makedirs(osp.join(args.output_dir, "SegmentationClass")) 35 | if not args.noviz: 36 | os.makedirs( 37 | osp.join(args.output_dir, "SegmentationClassVisualization") 38 | ) 39 | print("Creating dataset:", args.output_dir) 40 | 41 | class_names = [] 42 | class_name_to_id = {} 43 | for i, line in enumerate(open(args.labels).readlines()): 44 | class_id = i # starts with -1 45 | class_name = line.strip() 46 | class_name_to_id[class_name] = class_id 47 | if class_id == 0: 48 | assert class_name == "background" 49 | class_names.append(class_name) 50 | class_names = tuple(class_names) 51 | print("class_names:", class_names) 52 | 53 | for filename in glob.glob(osp.join(args.input_dir, "*.json")): 54 | print("Generating dataset from:", filename) 55 | 56 | label_file = labelme.LabelFile(filename=filename) 57 | 58 | base = osp.splitext(osp.basename(filename))[0] 59 | out_img_file = osp.join(args.output_dir, "JPEGImages", base + ".jpg") 60 | 61 | out_png_file = osp.join( 62 | args.output_dir, "SegmentationClass", base + ".png" 63 | ) 64 | if not args.noviz: 65 | out_viz_file = osp.join( 66 | args.output_dir, 67 | "SegmentationClassVisualization", 68 | base + ".jpg", 69 | ) 70 | 71 | with open(out_img_file, "wb") as f: 72 | f.write(label_file.imageData) 73 | img = labelme.utils.img_data_to_arr(label_file.imageData) 74 | 75 | lbl, _ = 
labelme.utils.shapes_to_label( 76 | img_shape=img.shape, 77 | shapes=label_file.shapes, 78 | label_name_to_value=class_name_to_id, 79 | ) 80 | labelme.utils.lblsave(out_png_file, lbl) 81 | 82 | if not args.noviz: 83 | viz = imgviz.label2rgb( 84 | label=lbl, 85 | img=imgviz.rgb2gray(img), 86 | font_size=15, 87 | label_names=class_names, 88 | loc="rb", 89 | ) 90 | imgviz.io.imsave(out_viz_file, viz) 91 | 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from . import segmentation 3 | -------------------------------------------------------------------------------- /models/_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class IntermediateLayerGetter(nn.ModuleDict): 8 | """ 9 | Module wrapper that returns intermediate layers from a model 10 | 11 | It has a strong assumption that the modules have been registered 12 | into the model in the same order as they are used. 13 | This means that one should **not** reuse the same nn.Module 14 | twice in the forward if you want this to work. 15 | 16 | Additionally, it is only able to query submodules that are directly 17 | assigned to the model. So if `model` is passed, `model.feature1` can 18 | be returned, but not `model.feature1.layer2`. 19 | 20 | Arguments: 21 | model (nn.Module): model on which we will extract the features 22 | return_layers (Dict[name, new_name]): a dict containing the names 23 | of the modules for which the activations will be returned as 24 | the key of the dict, and the value of the dict is the name 25 | of the returned activation (which the user can specify). 
26 | 27 | Examples:: 28 | 29 | >>> m = torchvision.models.resnet18(pretrained=True) 30 | >>> # extract layer1 and layer3, giving as names `feat1` and feat2` 31 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, 32 | >>> {'layer1': 'feat1', 'layer3': 'feat2'}) 33 | >>> out = new_m(torch.rand(1, 3, 224, 224)) 34 | >>> print([(k, v.shape) for k, v in out.items()]) 35 | >>> [('feat1', torch.Size([1, 64, 56, 56])), 36 | >>> ('feat2', torch.Size([1, 256, 14, 14]))] 37 | """ 38 | def __init__(self, model, return_layers): 39 | if not set(return_layers).issubset([name for name, _ in model.named_children()]): 40 | raise ValueError("return_layers are not present in model") 41 | 42 | orig_return_layers = return_layers 43 | return_layers = {k: v for k, v in return_layers.items()} 44 | layers = OrderedDict() 45 | for name, module in model.named_children(): 46 | layers[name] = module 47 | if name in return_layers: 48 | del return_layers[name] 49 | if not return_layers: 50 | break 51 | 52 | super(IntermediateLayerGetter, self).__init__(layers) 53 | self.return_layers = orig_return_layers 54 | 55 | def forward(self, x): 56 | out = OrderedDict() 57 | for name, module in self.named_children(): 58 | x = module(x) 59 | if name in self.return_layers: 60 | out_name = self.return_layers[name] 61 | out[out_name] = x 62 | return out 63 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .utils import load_state_dict_from_url 3 | 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 
'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=dilation, groups=groups, bias=False, dilation=dilation) 24 | 25 | 26 | def conv1x1(in_planes, out_planes, stride=1): 27 | """1x1 convolution""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 35 | base_width=64, dilation=1, norm_layer=None): 36 | super(BasicBlock, self).__init__() 37 | if norm_layer is None: 38 | norm_layer = nn.BatchNorm2d 39 | if groups != 1 or base_width != 64: 40 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 41 | if dilation > 1: 42 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 43 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 44 | self.conv1 = conv3x3(inplanes, planes, stride) 45 | self.bn1 = norm_layer(planes) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.conv2 = conv3x3(planes, planes) 48 | self.bn2 = norm_layer(planes) 49 | self.downsample = downsample 50 | self.stride = stride 51 | 52 | def forward(self, x): 53 | identity = x 54 | 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv2(out) 60 | out = self.bn2(out) 61 | 62 | if self.downsample is not None: 63 | identity = self.downsample(x) 64 | 65 | out 
+= identity 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 4 73 | 74 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 75 | base_width=64, dilation=1, norm_layer=None): 76 | super(Bottleneck, self).__init__() 77 | if norm_layer is None: 78 | norm_layer = nn.BatchNorm2d 79 | width = int(planes * (base_width / 64.)) * groups 80 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 81 | self.conv1 = conv1x1(inplanes, width) 82 | self.bn1 = norm_layer(width) 83 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 84 | self.bn2 = norm_layer(width) 85 | self.conv3 = conv1x1(width, planes * self.expansion) 86 | self.bn3 = norm_layer(planes * self.expansion) 87 | self.relu = nn.ReLU(inplace=True) 88 | self.downsample = downsample 89 | self.stride = stride 90 | 91 | def forward(self, x): 92 | identity = x 93 | 94 | out = self.conv1(x) 95 | out = self.bn1(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv2(out) 99 | out = self.bn2(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv3(out) 103 | out = self.bn3(out) 104 | 105 | if self.downsample is not None: 106 | identity = self.downsample(x) 107 | 108 | out += identity 109 | out = self.relu(out) 110 | 111 | return out 112 | 113 | 114 | class ResNet(nn.Module): 115 | 116 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 117 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 118 | norm_layer=None): 119 | super(ResNet, self).__init__() 120 | if norm_layer is None: 121 | norm_layer = nn.BatchNorm2d 122 | self._norm_layer = norm_layer 123 | 124 | self.inplanes = 64 125 | self.dilation = 1 126 | if replace_stride_with_dilation is None: 127 | # each element in the tuple indicates if we should replace 128 | # the 2x2 stride with a dilated convolution instead 129 | replace_stride_with_dilation = [False, False, False] 130 | if 
len(replace_stride_with_dilation) != 3: 131 | raise ValueError("replace_stride_with_dilation should be None " 132 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 133 | self.groups = groups 134 | self.base_width = width_per_group 135 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 136 | bias=False) 137 | self.bn1 = norm_layer(self.inplanes) 138 | self.relu = nn.ReLU(inplace=True) 139 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 140 | self.layer1 = self._make_layer(block, 64, layers[0]) 141 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 142 | dilate=replace_stride_with_dilation[0]) 143 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 144 | dilate=replace_stride_with_dilation[1]) 145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 146 | dilate=replace_stride_with_dilation[2]) 147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 148 | self.fc = nn.Linear(512 * block.expansion, num_classes) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 154 | nn.init.constant_(m.weight, 1) 155 | nn.init.constant_(m.bias, 0) 156 | 157 | # Zero-initialize the last BN in each residual branch, 158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
159 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 160 | if zero_init_residual: 161 | for m in self.modules(): 162 | if isinstance(m, Bottleneck): 163 | nn.init.constant_(m.bn3.weight, 0) 164 | elif isinstance(m, BasicBlock): 165 | nn.init.constant_(m.bn2.weight, 0) 166 | 167 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 168 | norm_layer = self._norm_layer 169 | downsample = None 170 | previous_dilation = self.dilation 171 | if dilate: 172 | self.dilation *= stride 173 | stride = 1 174 | if stride != 1 or self.inplanes != planes * block.expansion: 175 | downsample = nn.Sequential( 176 | conv1x1(self.inplanes, planes * block.expansion, stride), 177 | norm_layer(planes * block.expansion), 178 | ) 179 | 180 | layers = [] 181 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 182 | self.base_width, previous_dilation, norm_layer)) 183 | self.inplanes = planes * block.expansion 184 | for _ in range(1, blocks): 185 | layers.append(block(self.inplanes, planes, groups=self.groups, 186 | base_width=self.base_width, dilation=self.dilation, 187 | norm_layer=norm_layer)) 188 | 189 | return nn.Sequential(*layers) 190 | 191 | def forward(self, x): 192 | x = self.conv1(x) 193 | x = self.bn1(x) 194 | x = self.relu(x) 195 | x = self.maxpool(x) 196 | 197 | x = self.layer1(x) 198 | x = self.layer2(x) 199 | x = self.layer3(x) 200 | x = self.layer4(x) 201 | 202 | x = self.avgpool(x) 203 | x = x.flatten(1) #x.reshape(x.size(0), -1) 204 | x = self.fc(x) 205 | 206 | return x 207 | 208 | 209 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 210 | model = ResNet(block, layers, **kwargs) 211 | if pretrained: 212 | state_dict = load_state_dict_from_url(model_urls[arch], 213 | progress=progress) 214 | model.load_state_dict(state_dict) 215 | return model 216 | 217 | 218 | def resnet18(pretrained=False, progress=True, **kwargs): 219 | """Constructs a ResNet-18 model. 
220 | 221 | Args: 222 | pretrained (bool): If True, returns a model pre-trained on ImageNet 223 | progress (bool): If True, displays a progress bar of the download to stderr 224 | """ 225 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 226 | **kwargs) 227 | 228 | 229 | def resnet34(pretrained=False, progress=True, **kwargs): 230 | """Constructs a ResNet-34 model. 231 | 232 | Args: 233 | pretrained (bool): If True, returns a model pre-trained on ImageNet 234 | progress (bool): If True, displays a progress bar of the download to stderr 235 | """ 236 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 237 | **kwargs) 238 | 239 | 240 | def resnet50(pretrained=False, progress=True, **kwargs): 241 | """Constructs a ResNet-50 model. 242 | 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | progress (bool): If True, displays a progress bar of the download to stderr 246 | """ 247 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 248 | **kwargs) 249 | 250 | 251 | def resnet101(pretrained=False, progress=True, **kwargs): 252 | """Constructs a ResNet-101 model. 253 | 254 | Args: 255 | pretrained (bool): If True, returns a model pre-trained on ImageNet 256 | progress (bool): If True, displays a progress bar of the download to stderr 257 | """ 258 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 259 | **kwargs) 260 | 261 | 262 | def resnet152(pretrained=False, progress=True, **kwargs): 263 | """Constructs a ResNet-152 model. 264 | 265 | Args: 266 | pretrained (bool): If True, returns a model pre-trained on ImageNet 267 | progress (bool): If True, displays a progress bar of the download to stderr 268 | """ 269 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 270 | **kwargs) 271 | 272 | 273 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 274 | """Constructs a ResNeXt-50 32x4d model. 
275 | 276 | Args: 277 | pretrained (bool): If True, returns a model pre-trained on ImageNet 278 | progress (bool): If True, displays a progress bar of the download to stderr 279 | """ 280 | kwargs['groups'] = 32 281 | kwargs['width_per_group'] = 4 282 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 283 | pretrained, progress, **kwargs) 284 | 285 | 286 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 287 | """Constructs a ResNeXt-101 32x8d model. 288 | 289 | Args: 290 | pretrained (bool): If True, returns a model pre-trained on ImageNet 291 | progress (bool): If True, displays a progress bar of the download to stderr 292 | """ 293 | kwargs['groups'] = 32 294 | kwargs['width_per_group'] = 8 295 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 296 | pretrained, progress, **kwargs) 297 | -------------------------------------------------------------------------------- /models/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .segmentation import * 2 | from .fcn import * 3 | from .deeplabv3 import * 4 | -------------------------------------------------------------------------------- /models/segmentation/_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class _SimpleSegmentationModel(nn.Module): 9 | def __init__(self, backbone, classifier, aux_classifier=None, export_onnx=False): 10 | super(_SimpleSegmentationModel, self).__init__() 11 | self.backbone = backbone 12 | self.classifier = classifier 13 | self.aux_classifier = aux_classifier 14 | self.export_onnx = export_onnx 15 | 16 | print('torchvision.models.segmentation.FCN() => configuring model for ' + ('ONNX export' if export_onnx else 'training')) 17 | 18 | 19 | def forward(self, x): 20 | input_shape = x.shape[-2:] 21 | 22 | # 
contract: features is a dict of tensors 23 | features = self.backbone(x) 24 | x = features["out"] 25 | x = self.classifier(x) 26 | 27 | # TensorRT doesn't support bilinear upsample, so when exporting to ONNX, 28 | # use nearest-neighbor upsampling, and also return a tensor (not an OrderedDict) 29 | if self.export_onnx: 30 | print('FCN configured for export to ONNX') 31 | print('FCN model input size = ' + str(input_shape)) 32 | print('FCN classifier output size = ' + str(x.size())) 33 | 34 | #x = F.interpolate(x, size=(int(input_shape[0]), int(input_shape[1])), mode='nearest') 35 | 36 | print('FCN upsample() output size = ' + str(x.size())) 37 | print('FCN => returning tensor instead of OrderedDict') 38 | return x 39 | 40 | # non-ONNX training/eval path 41 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 42 | 43 | result = OrderedDict() 44 | result["out"] = x 45 | 46 | if self.aux_classifier is not None: 47 | x = features["aux"] 48 | x = self.aux_classifier(x) 49 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 50 | result["aux"] = x 51 | 52 | return result 53 | -------------------------------------------------------------------------------- /models/segmentation/deeplabv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from ._utils import _SimpleSegmentationModel 6 | 7 | 8 | __all__ = ["DeepLabV3"] 9 | 10 | 11 | class DeepLabV3(_SimpleSegmentationModel): 12 | """ 13 | Implements DeepLabV3 model from 14 | `"Rethinking Atrous Convolution for Semantic Image Segmentation" 15 | `_. 16 | 17 | Arguments: 18 | backbone (nn.Module): the network used to compute the features for the model. 19 | The backbone should return an OrderedDict[Tensor], with the key being 20 | "out" for the last feature map used, and "aux" if an auxiliary classifier 21 | is used. 
22 | classifier (nn.Module): module that takes the "out" element returned from 23 | the backbone and returns a dense prediction. 24 | aux_classifier (nn.Module, optional): auxiliary classifier used during training 25 | """ 26 | pass 27 | 28 | 29 | class DeepLabHead(nn.Sequential): 30 | def __init__(self, in_channels, num_classes): 31 | super(DeepLabHead, self).__init__( 32 | ASPP(in_channels, [12, 24, 36]), 33 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 34 | nn.BatchNorm2d(256), 35 | nn.ReLU(), 36 | nn.Conv2d(256, num_classes, 1) 37 | ) 38 | 39 | 40 | class ASPPConv(nn.Sequential): 41 | def __init__(self, in_channels, out_channels, dilation): 42 | modules = [ 43 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 44 | nn.BatchNorm2d(out_channels), 45 | nn.ReLU() 46 | ] 47 | super(ASPPConv, self).__init__(*modules) 48 | 49 | 50 | class ASPPPooling(nn.Sequential): 51 | def __init__(self, in_channels, out_channels): 52 | super(ASPPPooling, self).__init__( 53 | nn.AdaptiveAvgPool2d(1), 54 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 55 | nn.BatchNorm2d(out_channels), 56 | nn.ReLU()) 57 | 58 | def forward(self, x): 59 | size = x.shape[-2:] 60 | x = super(ASPPPooling, self).forward(x) 61 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 62 | 63 | 64 | class ASPP(nn.Module): 65 | def __init__(self, in_channels, atrous_rates): 66 | super(ASPP, self).__init__() 67 | out_channels = 256 68 | modules = [] 69 | modules.append(nn.Sequential( 70 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 71 | nn.BatchNorm2d(out_channels), 72 | nn.ReLU())) 73 | 74 | rate1, rate2, rate3 = tuple(atrous_rates) 75 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 76 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 77 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 78 | modules.append(ASPPPooling(in_channels, out_channels)) 79 | 80 | self.convs = nn.ModuleList(modules) 81 | 82 
| self.project = nn.Sequential( 83 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 84 | nn.BatchNorm2d(out_channels), 85 | nn.ReLU(), 86 | nn.Dropout(0.5)) 87 | 88 | def forward(self, x): 89 | res = [] 90 | for conv in self.convs: 91 | res.append(conv(x)) 92 | res = torch.cat(res, dim=1) 93 | return self.project(res) 94 | -------------------------------------------------------------------------------- /models/segmentation/fcn.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from ._utils import _SimpleSegmentationModel 4 | 5 | 6 | __all__ = ["FCN"] 7 | 8 | 9 | class FCN(_SimpleSegmentationModel): 10 | """ 11 | Implements a Fully-Convolutional Network for semantic segmentation. 12 | 13 | Arguments: 14 | backbone (nn.Module): the network used to compute the features for the model. 15 | The backbone should return an OrderedDict[Tensor], with the key being 16 | "out" for the last feature map used, and "aux" if an auxiliary classifier 17 | is used. 18 | classifier (nn.Module): module that takes the "out" element returned from 19 | the backbone and returns a dense prediction. 20 | aux_classifier (nn.Module, optional): auxiliary classifier used during training 21 | """ 22 | pass 23 | 24 | 25 | class FCNHead(nn.Sequential): 26 | def __init__(self, in_channels, channels): 27 | inter_channels = in_channels // 4 28 | layers = [ 29 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 30 | nn.BatchNorm2d(inter_channels), 31 | nn.ReLU(), 32 | nn.Dropout(0.1), 33 | nn.Conv2d(inter_channels, channels, 1) 34 | ] 35 | 36 | super(FCNHead, self).__init__(*layers) 37 | -------------------------------------------------------------------------------- /models/segmentation/segmentation.py: -------------------------------------------------------------------------------- 1 | from .._utils import IntermediateLayerGetter 2 | from ..utils import load_state_dict_from_url 3 | from .. 
import resnet 4 | from .deeplabv3 import DeepLabHead, DeepLabV3 5 | from .fcn import FCN, FCNHead 6 | 7 | 8 | __all__ = ['fcn_resnet18', 'fcn_resnet34', 'fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101'] 9 | 10 | 11 | model_urls = { 12 | 'fcn_resnet18_coco': None, 13 | 'fcn_resnet34_coco': None, 14 | 'fcn_resnet50_coco': None, 15 | 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth', 16 | 'deeplabv3_resnet50_coco': None, 17 | 'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth', 18 | } 19 | 20 | 21 | def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True, export_onnx=False): 22 | 23 | if backbone_name == "resnet18" or backbone_name == "resnet34": 24 | replace_stride_with_dilation=[False, False, False] 25 | inplanes_scale_factor = 4 26 | else: 27 | replace_stride_with_dilation=[False, True, True] 28 | inplanes_scale_factor = 1 29 | 30 | backbone = resnet.__dict__[backbone_name]( 31 | pretrained=pretrained_backbone, 32 | replace_stride_with_dilation=replace_stride_with_dilation) 33 | 34 | return_layers = {'layer4': 'out'} 35 | if aux: 36 | return_layers['layer3'] = 'aux' 37 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 38 | 39 | aux_classifier = None 40 | if aux: 41 | inplanes = 1024 / inplanes_scale_factor 42 | aux_classifier = FCNHead(inplanes, num_classes) 43 | 44 | model_map = { 45 | 'deeplab': (DeepLabHead, DeepLabV3), 46 | 'fcn': (FCNHead, FCN), 47 | } 48 | 49 | inplanes = 2048 / inplanes_scale_factor 50 | classifier = model_map[name][0](int(inplanes), int(num_classes)) 51 | base_model = model_map[name][1] 52 | 53 | model = base_model(backbone, classifier, aux_classifier, export_onnx) 54 | return model 55 | 56 | 57 | def fcn_resnet18(pretrained=False, progress=True, 58 | num_classes=21, aux_loss=None, **kwargs): 59 | """Constructs a Fully-Convolutional Network model with a ResNet-18 
backbone. 60 | 61 | Args: 62 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 which 63 | contains the same classes as Pascal VOC 64 | progress (bool): If True, displays a progress bar of the download to stderr 65 | """ 66 | print('torchvision.models.segmentation.fcn_resnet18()') 67 | 68 | if pretrained: 69 | aux_loss = True 70 | model = _segm_resnet("fcn", "resnet18", num_classes, aux_loss, **kwargs) 71 | if pretrained: 72 | arch = 'fcn_resnet18_coco' 73 | model_url = model_urls[arch] 74 | if model_url is None: 75 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 76 | else: 77 | state_dict = load_state_dict_from_url(model_url, progress=progress) 78 | model.load_state_dict(state_dict) 79 | return model 80 | 81 | def fcn_resnet34(pretrained=False, progress=True, 82 | num_classes=21, aux_loss=None, **kwargs): 83 | """Constructs a Fully-Convolutional Network model with a ResNet-34 backbone. 84 | 85 | Args: 86 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 which 87 | contains the same classes as Pascal VOC 88 | progress (bool): If True, displays a progress bar of the download to stderr 89 | """ 90 | print('torchvision.models.segmentation.fcn_resnet34()') 91 | 92 | if pretrained: 93 | aux_loss = True 94 | model = _segm_resnet("fcn", "resnet34", num_classes, aux_loss, **kwargs) 95 | if pretrained: 96 | arch = 'fcn_resnet34_coco' 97 | model_url = model_urls[arch] 98 | if model_url is None: 99 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 100 | else: 101 | state_dict = load_state_dict_from_url(model_url, progress=progress) 102 | model.load_state_dict(state_dict) 103 | return model 104 | 105 | def fcn_resnet50(pretrained=False, progress=True, 106 | num_classes=21, aux_loss=None, **kwargs): 107 | """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. 
108 | 109 | Args: 110 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 which 111 | contains the same classes as Pascal VOC 112 | progress (bool): If True, displays a progress bar of the download to stderr 113 | """ 114 | print('torchvision.models.segmentation.fcn_resnet50()') 115 | 116 | if pretrained: 117 | aux_loss = True 118 | model = _segm_resnet("fcn", "resnet50", num_classes, aux_loss, **kwargs) 119 | if pretrained: 120 | arch = 'fcn_resnet50_coco' 121 | model_url = model_urls[arch] 122 | if model_url is None: 123 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 124 | else: 125 | state_dict = load_state_dict_from_url(model_url, progress=progress) 126 | model.load_state_dict(state_dict) 127 | return model 128 | 129 | 130 | def fcn_resnet101(pretrained=False, progress=True, 131 | num_classes=21, aux_loss=None, **kwargs): 132 | """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. 133 | 134 | Args: 135 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 which 136 | contains the same classes as Pascal VOC 137 | progress (bool): If True, displays a progress bar of the download to stderr 138 | """ 139 | if pretrained: 140 | aux_loss = True 141 | model = _segm_resnet("fcn", "resnet101", num_classes, aux_loss, **kwargs) 142 | if pretrained: 143 | arch = 'fcn_resnet101_coco' 144 | model_url = model_urls[arch] 145 | if model_url is None: 146 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 147 | else: 148 | state_dict = load_state_dict_from_url(model_url, progress=progress) 149 | model.load_state_dict(state_dict) 150 | return model 151 | 152 | 153 | def deeplabv3_resnet50(pretrained=False, progress=True, 154 | num_classes=21, aux_loss=None, **kwargs): 155 | """Constructs a DeepLabV3 model with a ResNet-50 backbone. 
156 | 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 which 159 | contains the same classes as Pascal VOC 160 | progress (bool): If True, displays a progress bar of the download to stderr 161 | """ 162 | if pretrained: 163 | aux_loss = True 164 | model = _segm_resnet("deeplab", "resnet50", num_classes, aux_loss, **kwargs) 165 | if pretrained: 166 | arch = 'deeplabv3_resnet50_coco' 167 | model_url = model_urls[arch] 168 | if model_url is None: 169 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 170 | else: 171 | state_dict = load_state_dict_from_url(model_url, progress=progress) 172 | model.load_state_dict(state_dict) 173 | return model 174 | 175 | 176 | def deeplabv3_resnet101(pretrained=False, progress=True, 177 | num_classes=21, aux_loss=None, **kwargs): 178 | """Constructs a DeepLabV3 model with a ResNet-101 backbone. 179 | 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on COCO train2017 which 182 | contains the same classes as Pascal VOC 183 | progress (bool): If True, displays a progress bar of the download to stderr 184 | """ 185 | if pretrained: 186 | aux_loss = True 187 | model = _segm_resnet("deeplab", "resnet101", num_classes, aux_loss, **kwargs) 188 | if pretrained: 189 | arch = 'deeplabv3_resnet101_coco' 190 | model_url = model_urls[arch] 191 | if model_url is None: 192 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 193 | else: 194 | state_dict = load_state_dict_from_url(model_url, progress=progress) 195 | model.load_state_dict(state_dict) 196 | return model 197 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch.hub import load_state_dict_from_url 3 | except ImportError: 4 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 5 | 
-------------------------------------------------------------------------------- /onnx_export.py: -------------------------------------------------------------------------------- 1 | # 2 | # converts a saved PyTorch model to ONNX format 3 | # 4 | import os 5 | import argparse 6 | 7 | import torch 8 | from models import segmentation 9 | 10 | 11 | # parse command line 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--input', type=str, default='model_best.pth', 14 | help="path to input PyTorch model (default: model_best.pth)") 15 | parser.add_argument('--output', type=str, default='', 16 | help="desired path of converted ONNX model (default: .onnx)") 17 | parser.add_argument('--model-dir', type=str, default='', 18 | help="directory to look for the input PyTorch model in, and export the converted ONNX model to (if --output doesn't specify a directory)") 19 | 20 | opt = parser.parse_args() 21 | print(opt) 22 | 23 | # format input model path 24 | if opt.model_dir: 25 | opt.model_dir = os.path.expanduser(opt.model_dir) 26 | opt.input = os.path.join(opt.model_dir, opt.input) 27 | 28 | # set the device 29 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 30 | print('running on device ' + str(device)) 31 | 32 | # load the model checkpoint 33 | print('loading checkpoint: ' + opt.input) 34 | checkpoint = torch.load(opt.input) 35 | 36 | arch = checkpoint['arch'] 37 | num_classes = checkpoint['num_classes'] 38 | 39 | print('checkpoint accuracy: {:.3f}% mean IoU, {:.3f}% accuracy'.format( 40 | checkpoint['mean_IoU'], checkpoint['accuracy'])) 41 | 42 | # create the model architecture 43 | print('using model: ' + arch) 44 | print('num classes: ' + str(num_classes)) 45 | 46 | model = segmentation.__dict__[arch](num_classes=num_classes, 47 | aux_loss=None, 48 | pretrained=False, 49 | export_onnx=True) 50 | 51 | # load the model weights 52 | model.load_state_dict(checkpoint['model']) 53 | 54 | model.to(device) 55 | model.eval() 56 | 57 | 
print(model) 58 | print('') 59 | 60 | # create example image data 61 | resolution = checkpoint['resolution'] 62 | input = torch.ones((1, 3, resolution[0], resolution[1])).cuda() 63 | print('input size: {:d}x{:d}'.format(resolution[1], resolution[0])) 64 | 65 | # format output model path 66 | if not opt.output: 67 | opt.output = arch + '.onnx' 68 | 69 | if opt.model_dir and opt.output.find('/') == -1 and opt.output.find('\\') == -1: 70 | opt.output = os.path.join(opt.model_dir, opt.output) 71 | 72 | # export the model 73 | input_names = ["input_0"] 74 | output_names = ["output_0"] 75 | 76 | print('exporting model to ONNX...') 77 | torch.onnx.export(model, input, opt.output, verbose=True, 78 | input_names=input_names, output_names=output_names) 79 | print('model exported to: {:s}'.format(opt.output)) 80 | -------------------------------------------------------------------------------- /onnx_validate.py: -------------------------------------------------------------------------------- 1 | # 2 | # Check that an ONNX model is valid and well-formed. 
3 | # 4 | # Before running this script, install the following: 5 | # 6 | # $ sudo apt-get install protobuf-compiler libprotoc-dev 7 | # $ pip install onnx 8 | # 9 | import onnx 10 | import argparse 11 | import sys 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('model', type=str, default='resnet18.onnx', help='path to ONNX model to validate') 15 | args = parser.parse_args(sys.argv[1:]) 16 | 17 | # Load the ONNX model 18 | model = onnx.load(args.model) 19 | 20 | # Print a human readable representation of the graph 21 | print('Network Graph:') 22 | print(onnx.helper.printable_graph(model.graph)) 23 | print('') 24 | 25 | # Print model metadata 26 | print('ONNX version: ' + onnx.__version__) 27 | print('IR version: {:d}'.format(model.ir_version)) 28 | print('Producer name: ' + model.producer_name) 29 | print('Producer version: ' + model.producer_version) 30 | print('Model version: {:d}'.format(model.model_version)) 31 | print('') 32 | 33 | # Check that the IR is well formed 34 | print('Checking model IR...') 35 | onnx.checker.check_model(model) 36 | print('The model was checked successfully!') 37 | 38 | 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | certifi=2020.6.20=py36h9f0ad1d_0 4 | cffi=1.14.0=py36ha419a9e_0 5 | cudatoolkit=10.0.130=0 6 | cython=0.29.20=py36h003fed8_0 7 | intel-openmp=2020.1=216 8 | libblas=3.8.0=16_mkl 9 | libcblas=3.8.0=16_mkl 10 | liblapack=3.8.0=16_mkl 11 | libprotobuf=3.11.3=h1a1b453_0 12 | mkl=2020.1=216 13 | ninja=1.10.0=h1ad3211_0 14 | numpy=1.18.5=py36h4d86e3b_0 15 | pip=20.1.1=py_1 16 | protobuf=3.5.1=py36_vc14_3 17 | pycparser=2.20=pyh9f0ad1d_2 18 | python=3.6.10=he025d50_1009_cpython 19 | python_abi=3.6=1_cp36m 20 | setuptools=49.1.0=py36h9f0ad1d_0 21 | vc=14.1=h869be7e_1 22 | 
vs2015_runtime=14.16.27012=h30e32a0_2 23 | tqdm=4.48.2=pyh9f0ad1d_0 24 | wheel=0.34.2=py_1 25 | wincertstore=0.2=py36_1003 26 | zlib=1.2.11=h2fa13f4_1006 27 | -------------------------------------------------------------------------------- /split_custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import re 4 | from PIL import Image, ImageFile 5 | from tqdm import tqdm 6 | import numpy as np 7 | import argparse 8 | 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | 11 | ap = argparse.ArgumentParser() 12 | ap.add_argument("-i", "--images", type=str, required=True, 13 | help="path to images") 14 | 15 | ap.add_argument("-m", "--masks", type=str, required=True, 16 | help="path to your masks") 17 | 18 | ap.add_argument("-o", "--output", type=str, required=True, 19 | help="path to where the split dataset should be stored") 20 | 21 | ap.add_argument("--image-format", dest="image_format", type=str, default="jpg", 22 | help="image format, defaults to jpg") 23 | 24 | ap.add_argument("--mask-format", dest="mask_format", type=str, default="png", 25 | help="mask format, defaults to png") 26 | 27 | ap.add_argument("--keep-original", dest="keep_original", action="store_true", 28 | help="keep the original images after storing them into corresponding folders") 29 | 30 | 31 | args = vars(ap.parse_args()) 32 | 33 | 34 | # Variables (change if needed) 35 | 36 | INPUT_IMAGE_PATH = args["images"] 37 | INPUT_MASK_PATH = args["masks"] 38 | 39 | OUTPUT_DATA_PATH = args["output"] 40 | OUTPUT_IMAGE_PATH = OUTPUT_DATA_PATH+'/images' 41 | OUTPUT_MASK_PATH = OUTPUT_DATA_PATH+'/annotations' 42 | IMAGE_FORMAT = '.' + args["image_format"] 43 | MASK_FORMAT = '.' + args["mask_format"] 44 | 45 | # Remove duplicates after adding them to train/val folders 46 | keep_old_images = args["keep_original"] 47 | 48 | # Get all images and masks, sort them and shuffle them to generate data sets. 
49 | 50 | all_masks = [os.path.splitext(x)[0] for x in os.listdir( 51 | INPUT_MASK_PATH) if MASK_FORMAT in os.path.splitext(x)[1]] 52 | 53 | 54 | all_images = [os.path.splitext(x)[0] for x in os.listdir( 55 | INPUT_IMAGE_PATH) if IMAGE_FORMAT in os.path.splitext(x)[1] and os.path.splitext(x)[0] in all_masks] 56 | 57 | 58 | all_images.sort(key=lambda var: [int(x) if x.isdigit() else x 59 | for x in re.findall(r'[^0-9]|[0-9]+', var)]) 60 | all_masks.sort(key=lambda var: [int(x) if x.isdigit() else x 61 | for x in re.findall(r'[^0-9]|[0-9]+', var)]) 62 | 63 | 64 | random.seed(230) 65 | random.shuffle(all_images) 66 | 67 | 68 | # Split images to train and val sets (80% : 20% ratio) 69 | 70 | train_split = int(0.8*len(all_images)) 71 | 72 | train_images = all_images[:train_split] 73 | val_images = all_images[train_split:] 74 | 75 | 76 | print( 77 | f'SPLIT: {len(train_images)} train and {len(val_images)} validation images!') 78 | print('-------------------------------------------------------------------------------') 79 | 80 | 81 | # Generate corresponding mask lists for masks 82 | 83 | 84 | train_masks = [f for f in all_masks if f in train_images] 85 | val_masks = [f for f in all_masks if f in val_images] 86 | 87 | 88 | # Generate required folders 89 | 90 | train_folder = 'training' 91 | val_folder = 'validation' 92 | 93 | folders = [train_folder, val_folder] 94 | 95 | for folder in folders: 96 | os.makedirs(os.path.join(OUTPUT_IMAGE_PATH, folder), exist_ok=True) 97 | os.makedirs(os.path.join(OUTPUT_MASK_PATH, folder), exist_ok=True) 98 | 99 | 100 | # Add train and val images and their masks to corresponding folders 101 | 102 | 103 | def add_images(dir_name, image): 104 | 105 | full_image_path = INPUT_IMAGE_PATH+'/'+image+IMAGE_FORMAT 106 | img = Image.open(full_image_path) 107 | img = img.convert("RGB") 108 | img.save(OUTPUT_IMAGE_PATH+'/{}'.format(dir_name)+'/'+image+IMAGE_FORMAT) 109 | 110 | if not keep_old_images: 111 | os.remove(full_image_path) 112 | 113 | 114 | 
def add_masks(dir_name, image): 115 | 116 | full_mask_path = INPUT_MASK_PATH+'/'+image+MASK_FORMAT 117 | img = Image.open(full_mask_path) 118 | 119 | img.save(OUTPUT_MASK_PATH+'/{}'.format(dir_name)+'/'+image+MASK_FORMAT) 120 | 121 | if not keep_old_images: 122 | os.remove(full_mask_path) 123 | 124 | 125 | image_folders = [(train_images, train_folder), (val_images, val_folder)] 126 | 127 | mask_folders = [(train_masks, train_folder), (val_masks, val_folder)] 128 | 129 | print( 130 | f'Writing images to the {image_folders[0][1]} and {image_folders[1][1]} folders...') 131 | 132 | for folder in image_folders: 133 | 134 | array = folder[0] 135 | name = [folder[1]] * len(array) 136 | 137 | list(map(add_images, tqdm(name), array)) 138 | 139 | print( 140 | f'Writing masks to the {mask_folders[0][1]} and {mask_folders[1][1]} folders...') 141 | 142 | for folder in mask_folders: 143 | 144 | array = folder[0] 145 | name = [folder[1]] * len(array) 146 | 147 | list(map(add_masks, tqdm(name), array)) 148 | 149 | print('Done!') 150 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # 2 | # Note -- this training script is tweaked from the original version at: 3 | # 4 | # https://github.com/pytorch/vision/tree/v0.3.0/references/segmentation 5 | # 6 | # 7 | import argparse 8 | import datetime 9 | import time 10 | import math 11 | import os 12 | import shutil 13 | 14 | import torch 15 | import torch.utils.data 16 | from torch import nn 17 | import torchvision 18 | from models import segmentation 19 | 20 | from datasets.coco_utils import get_coco 21 | from datasets.cityscapes_utils import get_cityscapes 22 | from datasets.deepscene import DeepSceneSegmentation 23 | from datasets.custom_dataset import CustomSegmentation 24 | from datasets.mhp import MHPSegmentation 25 | from datasets.nyu import NYUDepth 26 | from datasets.sun import SunRGBDSegmentation 27 | 28 
| import transforms as T 29 | import utils 30 | 31 | model_names = sorted(name for name in segmentation.__dict__ 32 | if name.islower() and not name.startswith("__") 33 | and callable(segmentation.__dict__[name])) 34 | 35 | # 36 | # parse command-line arguments 37 | # 38 | def parse_args(): 39 | parser = argparse.ArgumentParser(description='PyTorch Segmentation Training') 40 | 41 | parser.add_argument('data', metavar='DIR', help='path to dataset') 42 | parser.add_argument('--dataset', default='voc', help='dataset type: voc, voc_aug, coco, cityscapes, deepscene, mhp, nyu, sun, custom (default: voc)') 43 | parser.add_argument('-a', '--arch', metavar='ARCH', default='fcn_resnet18', 44 | choices=model_names, 45 | help='model architecture: ' + 46 | ' | '.join(model_names) + 47 | ' (default: fcn_resnet18)') 48 | parser.add_argument('--classes', default=21, type=int, metavar='C', help='number of classes in your dataset (outputs)') 49 | parser.add_argument('--aux-loss', action='store_true', help='train with auxilliary loss') 50 | parser.add_argument('--resolution', default=320, type=int, metavar='N', 51 | help='NxN resolution used for scaling the training dataset (default: 320x320) ' 52 | 'to specify a non-square resolution, use the --width and --height options') 53 | parser.add_argument('--width', default=argparse.SUPPRESS, type=int, metavar='X', 54 | help='desired width of the training dataset. if this option is not set, --resolution will be used') 55 | parser.add_argument('--height', default=argparse.SUPPRESS, type=int, metavar='Y', 56 | help='desired height of the training dataset. 
if this option is not set, --resolution will be used') 57 | parser.add_argument('--device', default='cuda', help='device') 58 | parser.add_argument('-b', '--batch-size', default=4, type=int) 59 | parser.add_argument('--epochs', default=30, type=int, metavar='N', help='number of total epochs to run') 60 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 61 | help='number of data loading workers (default: 16)') 62 | parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate') 63 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 64 | help='momentum') 65 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 66 | metavar='W', help='weight decay (default: 1e-4)', 67 | dest='weight_decay') 68 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency') 69 | parser.add_argument('--model-dir', default='.', help='path where to save output models') 70 | parser.add_argument('--resume', default='', help='resume from checkpoint') 71 | parser.add_argument("--test-only", dest="test_only", help="Only test the model", action="store_true") 72 | parser.add_argument("--pretrained", dest="pretrained", help="Use pre-trained models (only supported for fcn_resnet101)", action="store_true") 73 | 74 | # distributed training parameters 75 | parser.add_argument('--world-size', default=1, type=int, 76 | help='number of distributed processes') 77 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') 78 | 79 | args = parser.parse_args() 80 | return args 81 | 82 | 83 | # 84 | # load desired dataset 85 | # 86 | def get_dataset(name, path, image_set, transform, num_classes): 87 | def sbd(*args, **kwargs): 88 | return torchvision.datasets.SBDataset(*args, mode='segmentation', **kwargs) 89 | paths = { 90 | "voc": (path, torchvision.datasets.VOCSegmentation, num_classes), 91 | "voc_aug": (path, sbd, num_classes), 92 | "coco": (path, get_coco, 
num_classes), 93 | "cityscapes": (path, get_cityscapes, num_classes), 94 | "deepscene": (path, DeepSceneSegmentation, 5), 95 | "mhp": (path, MHPSegmentation, num_classes), 96 | "nyu": (path, NYUDepth, num_classes), 97 | "sun": (path, SunRGBDSegmentation, num_classes), 98 | "custom": (path, CustomSegmentation, num_classes) 99 | } 100 | p, ds_fn, num_classes = paths[name] 101 | 102 | ds = ds_fn(p, image_set=image_set, transforms=transform) 103 | return ds, num_classes 104 | 105 | 106 | # 107 | # create data transform 108 | # 109 | def get_transform(train, resolution): 110 | transforms = [] 111 | 112 | # if square resolution, perform some aspect cropping 113 | # otherwise, resize to the resolution as specified 114 | if resolution[0] == resolution[1]: 115 | base_size = resolution[0] + 32 #520 116 | crop_size = resolution[0] #480 117 | 118 | min_size = int((0.5 if train else 1.0) * base_size) 119 | max_size = int((2.0 if train else 1.0) * base_size) 120 | 121 | transforms.append(T.RandomResize(min_size, max_size)) 122 | 123 | # during training mode, perform some data randomization 124 | if train: 125 | transforms.append(T.RandomHorizontalFlip(0.5)) 126 | transforms.append(T.RandomCrop(crop_size)) 127 | else: 128 | transforms.append(T.Resize(resolution)) 129 | 130 | if train: 131 | transforms.append(T.RandomHorizontalFlip(0.5)) 132 | 133 | transforms.append(T.ToTensor()) 134 | transforms.append(T.Normalize(mean=[0.485, 0.456, 0.406], 135 | std=[0.229, 0.224, 0.225])) 136 | 137 | return T.Compose(transforms) 138 | 139 | 140 | # 141 | # define the loss functions 142 | # 143 | def criterion(inputs, target): 144 | losses = {} 145 | for name, x in inputs.items(): 146 | losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255) 147 | 148 | if len(losses) == 1: 149 | return losses['out'] 150 | 151 | return losses['out'] + 0.5 * losses['aux'] 152 | 153 | 154 | # 155 | # evaluate model IoU (intersection over union) 156 | # 157 | def evaluate(model, data_loader, 
device, num_classes): 158 | model.eval() 159 | confmat = utils.ConfusionMatrix(num_classes) 160 | metric_logger = utils.MetricLogger(delimiter=" ") 161 | header = 'Test:' 162 | with torch.no_grad(): 163 | for image, target in metric_logger.log_every(data_loader, 100, header): 164 | image, target = image.to(device), target.to(device) 165 | output = model(image) 166 | output = output['out'] 167 | 168 | confmat.update(target.flatten(), output.argmax(1).flatten()) 169 | 170 | confmat.reduce_from_all_processes() 171 | 172 | return confmat 173 | 174 | 175 | # 176 | # train for one epoch over the dataset 177 | # 178 | def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq): 179 | model.train() 180 | metric_logger = utils.MetricLogger(delimiter=" ") 181 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) 182 | header = 'Epoch: [{}]'.format(epoch) 183 | for image, target in metric_logger.log_every(data_loader, print_freq, header): 184 | image, target = image.to(device), target.to(device) 185 | output = model(image) 186 | loss = criterion(output, target) 187 | 188 | optimizer.zero_grad() 189 | loss.backward() 190 | optimizer.step() 191 | 192 | lr_scheduler.step() 193 | 194 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) 195 | 196 | 197 | # 198 | # main training function 199 | # 200 | def main(args): 201 | if args.model_dir: 202 | utils.mkdir(args.model_dir) 203 | 204 | utils.init_distributed_mode(args) 205 | print(args) 206 | 207 | device = torch.device(args.device) 208 | 209 | # determine the desired resolution 210 | resolution = (args.resolution, args.resolution) 211 | 212 | if "width" in args and "height" in args: 213 | resolution = (args.height, args.width) 214 | 215 | # load the train and val datasets 216 | dataset, num_classes = get_dataset(args.dataset, args.data, "train", get_transform(train=True, resolution=resolution), args.classes) 217 | dataset_test, _ = 
get_dataset(args.dataset, args.data, "val", get_transform(train=False, resolution=resolution), args.classes) 218 | 219 | if args.distributed: 220 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 221 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) 222 | else: 223 | train_sampler = torch.utils.data.RandomSampler(dataset) 224 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 225 | 226 | data_loader = torch.utils.data.DataLoader( 227 | dataset, batch_size=args.batch_size, 228 | sampler=train_sampler, num_workers=args.workers, 229 | collate_fn=utils.collate_fn, drop_last=True) 230 | 231 | data_loader_test = torch.utils.data.DataLoader( 232 | dataset_test, batch_size=1, 233 | sampler=test_sampler, num_workers=args.workers, 234 | collate_fn=utils.collate_fn) 235 | 236 | print("=> training with dataset: '{:s}' (train={:d}, val={:d})".format(args.dataset, len(dataset), len(dataset_test))) 237 | print("=> training with resolution: {:d}x{:d}, {:d} classes".format(resolution[1], resolution[0], num_classes)) 238 | print("=> training with model: {:s}".format(args.arch)) 239 | 240 | # create the segmentation model 241 | model = segmentation.__dict__[args.arch](num_classes=num_classes, 242 | aux_loss=args.aux_loss, 243 | pretrained=args.pretrained) 244 | model.to(device) 245 | 246 | if args.distributed: 247 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 248 | 249 | if args.resume: 250 | checkpoint = torch.load(args.resume, map_location='cpu') 251 | model.load_state_dict(checkpoint['model']) 252 | 253 | model_without_ddp = model 254 | 255 | if args.distributed: 256 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 257 | model_without_ddp = model.module 258 | 259 | # eval-only mode 260 | if args.test_only: 261 | confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) 262 | print(confmat) 263 | return 264 | 265 | # create the optimizer 
266 | params_to_optimize = [ 267 | {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]}, 268 | {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]}, 269 | ] 270 | 271 | if args.aux_loss: 272 | params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad] 273 | params_to_optimize.append({"params": params, "lr": args.lr * 10}) 274 | 275 | optimizer = torch.optim.SGD( 276 | params_to_optimize, 277 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 278 | 279 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR( 280 | optimizer, 281 | lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9) 282 | 283 | # training loop 284 | start_time = time.time() 285 | best_IoU = 0.0 286 | 287 | for epoch in range(args.epochs): 288 | if args.distributed: 289 | train_sampler.set_epoch(epoch) 290 | 291 | # train the model over the next epoc 292 | train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq) 293 | 294 | # test the model on the val dataset 295 | confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) 296 | print(confmat) 297 | 298 | # save model checkpoint 299 | checkpoint_path = os.path.join(args.model_dir, 'model_{}.pth'.format(epoch)) 300 | 301 | utils.save_on_master( 302 | { 303 | 'model': model_without_ddp.state_dict(), 304 | 'optimizer': optimizer.state_dict(), 305 | 'epoch': epoch, 306 | 'args': args, 307 | 'arch': args.arch, 308 | 'dataset': args.dataset, 309 | 'num_classes': num_classes, 310 | 'resolution': resolution, 311 | 'accuracy': confmat.acc_global, 312 | 'mean_IoU': confmat.mean_IoU 313 | }, 314 | checkpoint_path) 315 | 316 | print('saved checkpoint to: {:s} ({:.3f}% mean IoU, {:.3f}% accuracy)'.format(checkpoint_path, confmat.mean_IoU, confmat.acc_global)) 317 | 318 | if confmat.mean_IoU > best_IoU: 319 | best_IoU = confmat.mean_IoU 320 | best_path = 
os.path.join(args.model_dir, 'model_best.pth') 321 | shutil.copyfile(checkpoint_path, best_path) 322 | print('saved best model to: {:s} ({:.3f}% mean IoU, {:.3f}% accuracy)'.format(best_path, best_IoU, confmat.acc_global)) 323 | 324 | total_time = time.time() - start_time 325 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 326 | print('Training time {}'.format(total_time_str)) 327 | 328 | 329 | if __name__ == "__main__": 330 | args = parse_args() 331 | main(args) 332 | 333 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import random 4 | 5 | import torch 6 | from torchvision import transforms as T 7 | from torchvision.transforms import functional as F 8 | 9 | 10 | def pad_if_smaller(img, size, fill=0): 11 | min_size = min(img.size) 12 | if min_size < size: 13 | ow, oh = img.size 14 | padh = size - oh if oh < size else 0 15 | padw = size - ow if ow < size else 0 16 | img = F.pad(img, (0, 0, padw, padh), fill=fill) 17 | return img 18 | 19 | 20 | class Compose(object): 21 | def __init__(self, transforms): 22 | self.transforms = transforms 23 | 24 | def __call__(self, image, target): 25 | for t in self.transforms: 26 | image, target = t(image, target) 27 | return image, target 28 | 29 | 30 | class Resize(object): 31 | def __init__(self, size): 32 | self.size = size 33 | 34 | def __call__(self, image, target): 35 | image = F.resize(image, self.size) 36 | target = F.resize(target, self.size, interpolation=Image.NEAREST) 37 | return image, target 38 | 39 | 40 | class RandomResize(object): 41 | def __init__(self, min_size, max_size=None): 42 | self.min_size = min_size 43 | if max_size is None: 44 | max_size = min_size 45 | self.max_size = max_size 46 | 47 | def __call__(self, image, target): 48 | size = random.randint(self.min_size, self.max_size) 49 | image = 
F.resize(image, size) 50 | target = F.resize(target, size, interpolation=Image.NEAREST) 51 | return image, target 52 | 53 | 54 | class RandomHorizontalFlip(object): 55 | def __init__(self, flip_prob): 56 | self.flip_prob = flip_prob 57 | 58 | def __call__(self, image, target): 59 | if random.random() < self.flip_prob: 60 | image = F.hflip(image) 61 | target = F.hflip(target) 62 | return image, target 63 | 64 | 65 | class RandomCrop(object): 66 | def __init__(self, size): 67 | self.size = size 68 | 69 | def __call__(self, image, target): 70 | image = pad_if_smaller(image, self.size) 71 | target = pad_if_smaller(target, self.size, fill=255) 72 | crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) 73 | image = F.crop(image, *crop_params) 74 | target = F.crop(target, *crop_params) 75 | return image, target 76 | 77 | 78 | class CenterCrop(object): 79 | def __init__(self, size): 80 | self.size = size 81 | 82 | def __call__(self, image, target): 83 | image = F.center_crop(image, self.size) 84 | target = F.center_crop(target, self.size) 85 | return image, target 86 | 87 | 88 | class ToTensor(object): 89 | def __call__(self, image, target): 90 | image = F.to_tensor(image) 91 | target = torch.as_tensor(np.asarray(target), dtype=torch.int64) 92 | return image, target 93 | 94 | 95 | class Normalize(object): 96 | def __init__(self, mean, std): 97 | self.mean = mean 98 | self.std = std 99 | 100 | def __call__(self, image, target): 101 | image = F.normalize(image, mean=self.mean, std=self.std) 102 | return image, target 103 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from collections import defaultdict, deque 3 | import datetime 4 | import math 5 | import time 6 | import torch 7 | import torch.distributed as dist 8 | 9 | import errno 10 | import os 11 | 12 | 13 | class 
SmoothedValue(object): 14 | """Track a series of values and provide access to smoothed values over a 15 | window or the global series average. 16 | """ 17 | 18 | def __init__(self, window_size=20, fmt=None): 19 | if fmt is None: 20 | fmt = "{median:.4f} ({global_avg:.4f})" 21 | self.deque = deque(maxlen=window_size) 22 | self.total = 0.0 23 | self.count = 0 24 | self.fmt = fmt 25 | 26 | def update(self, value, n=1): 27 | self.deque.append(value) 28 | self.count += n 29 | self.total += value * n 30 | 31 | def synchronize_between_processes(self): 32 | """ 33 | Warning: does not synchronize the deque! 34 | """ 35 | if not is_dist_avail_and_initialized(): 36 | return 37 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 38 | dist.barrier() 39 | dist.all_reduce(t) 40 | t = t.tolist() 41 | self.count = int(t[0]) 42 | self.total = t[1] 43 | 44 | @property 45 | def median(self): 46 | d = torch.tensor(list(self.deque)) 47 | return d.median().item() 48 | 49 | @property 50 | def avg(self): 51 | d = torch.tensor(list(self.deque), dtype=torch.float32) 52 | return d.mean().item() 53 | 54 | @property 55 | def global_avg(self): 56 | return self.total / self.count 57 | 58 | @property 59 | def max(self): 60 | return max(self.deque) 61 | 62 | @property 63 | def value(self): 64 | return self.deque[-1] 65 | 66 | def __str__(self): 67 | return self.fmt.format( 68 | median=self.median, 69 | avg=self.avg, 70 | global_avg=self.global_avg, 71 | max=self.max, 72 | value=self.value) 73 | 74 | 75 | class ConfusionMatrix(object): 76 | def __init__(self, num_classes): 77 | self.num_classes = num_classes 78 | self.mat = None 79 | self.acc_global = 0.0 80 | self.mean_IoU = 0.0 81 | 82 | def update(self, a, b): 83 | n = self.num_classes 84 | if self.mat is None: 85 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device) 86 | with torch.no_grad(): 87 | k = (a >= 0) & (a < n) 88 | inds = n * a[k].to(torch.int64) + b[k] 89 | self.mat += torch.bincount(inds, 
minlength=n**2).reshape(n, n) 90 | 91 | def reset(self): 92 | self.mat.zero_() 93 | 94 | def compute(self): 95 | h = self.mat.float() 96 | 97 | # not all object classes may have data from the val category 98 | h_sum1 = h.sum(1) 99 | h_sum1[h_sum1 == 0] = 1 100 | 101 | if not torch.equal(h.sum(1), h_sum1): 102 | print('Test: Warning -- some classes may be missing validation examples') 103 | 104 | acc_global = torch.diag(h).sum() / h.sum() 105 | acc = torch.diag(h) / h.sum(1) 106 | iu = torch.diag(h) / (h_sum1 + h.sum(0) - torch.diag(h)) 107 | self.acc_global = acc_global.item() * 100 108 | self.mean_IoU = iu.mean().item() * 100 109 | return acc_global, acc, iu 110 | 111 | def reduce_from_all_processes(self): 112 | if not torch.distributed.is_available(): 113 | return 114 | if not torch.distributed.is_initialized(): 115 | return 116 | torch.distributed.barrier() 117 | torch.distributed.all_reduce(self.mat) 118 | 119 | def __str__(self): 120 | acc_global, acc, iu = self.compute() 121 | return ( 122 | 'global correct: {:.1f}\n' 123 | 'average row correct: {}\n' 124 | 'IoU: {}\n' 125 | 'mean IoU: {:.1f}').format( 126 | acc_global.item() * 100, 127 | ['{:.1f}'.format(i) for i in (acc * 100).tolist()], 128 | ['{:.1f}'.format(i) for i in (iu * 100).tolist()], 129 | iu.mean().item() * 100) 130 | 131 | 132 | class MetricLogger(object): 133 | def __init__(self, delimiter="\t"): 134 | self.meters = defaultdict(SmoothedValue) 135 | self.delimiter = delimiter 136 | 137 | def update(self, **kwargs): 138 | for k, v in kwargs.items(): 139 | if isinstance(v, torch.Tensor): 140 | v = v.item() 141 | assert isinstance(v, (float, int)) 142 | self.meters[k].update(v) 143 | 144 | def __getattr__(self, attr): 145 | if attr in self.meters: 146 | return self.meters[attr] 147 | if attr in self.__dict__: 148 | return self.__dict__[attr] 149 | raise AttributeError("'{}' object has no attribute '{}'".format( 150 | type(self).__name__, attr)) 151 | 152 | def __str__(self): 153 | loss_str = [] 154 
| for name, meter in self.meters.items(): 155 | loss_str.append( 156 | "{}: {}".format(name, str(meter)) 157 | ) 158 | return self.delimiter.join(loss_str) 159 | 160 | def synchronize_between_processes(self): 161 | for meter in self.meters.values(): 162 | meter.synchronize_between_processes() 163 | 164 | def add_meter(self, name, meter): 165 | self.meters[name] = meter 166 | 167 | def log_every(self, iterable, print_freq, header=None): 168 | i = 0 169 | if not header: 170 | header = '' 171 | start_time = time.time() 172 | end = time.time() 173 | iter_time = SmoothedValue(fmt='{avg:.4f}') 174 | data_time = SmoothedValue(fmt='{avg:.4f}') 175 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 176 | log_msg = self.delimiter.join([ 177 | header, 178 | '[{0' + space_fmt + '}/{1}]', 179 | 'eta: {eta}', 180 | '{meters}', 181 | 'time: {time}', 182 | 'data: {data}', 183 | 'max mem: {memory:.0f}' 184 | ]) 185 | MB = 1024.0 * 1024.0 186 | for obj in iterable: 187 | data_time.update(time.time() - end) 188 | yield obj 189 | iter_time.update(time.time() - end) 190 | if i % print_freq == 0: 191 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 192 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 193 | print(log_msg.format( 194 | i, len(iterable), eta=eta_string, 195 | meters=str(self), 196 | time=str(iter_time), data=str(data_time), 197 | memory=torch.cuda.max_memory_allocated() / MB)) 198 | i += 1 199 | end = time.time() 200 | total_time = time.time() - start_time 201 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 202 | print('{} Total time: {}'.format(header, total_time_str)) 203 | 204 | 205 | def cat_list(images, fill_value=0): 206 | max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) 207 | batch_shape = (len(images),) + max_size 208 | batched_imgs = images[0].new(*batch_shape).fill_(fill_value) 209 | for img, pad_img in zip(images, batched_imgs): 210 | pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img) 
211 | return batched_imgs 212 | 213 | 214 | def collate_fn(batch): 215 | images, targets = list(zip(*batch)) 216 | batched_imgs = cat_list(images, fill_value=0) 217 | batched_targets = cat_list(targets, fill_value=255) 218 | return batched_imgs, batched_targets 219 | 220 | 221 | def mkdir(path): 222 | try: 223 | os.makedirs(path) 224 | except OSError as e: 225 | if e.errno != errno.EEXIST: 226 | raise 227 | 228 | 229 | def setup_for_distributed(is_master): 230 | """ 231 | This function disables printing when not in master process 232 | """ 233 | import builtins as __builtin__ 234 | builtin_print = __builtin__.print 235 | 236 | def print(*args, **kwargs): 237 | force = kwargs.pop('force', False) 238 | if is_master or force: 239 | builtin_print(*args, **kwargs) 240 | 241 | __builtin__.print = print 242 | 243 | 244 | def is_dist_avail_and_initialized(): 245 | if not dist.is_available(): 246 | return False 247 | if not dist.is_initialized(): 248 | return False 249 | return True 250 | 251 | 252 | def get_world_size(): 253 | if not is_dist_avail_and_initialized(): 254 | return 1 255 | return dist.get_world_size() 256 | 257 | 258 | def get_rank(): 259 | if not is_dist_avail_and_initialized(): 260 | return 0 261 | return dist.get_rank() 262 | 263 | 264 | def is_main_process(): 265 | return get_rank() == 0 266 | 267 | 268 | def save_on_master(*args, **kwargs): 269 | if is_main_process(): 270 | torch.save(*args, **kwargs) 271 | 272 | 273 | def init_distributed_mode(args): 274 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 275 | args.rank = int(os.environ["RANK"]) 276 | args.world_size = int(os.environ['WORLD_SIZE']) 277 | args.gpu = int(os.environ['LOCAL_RANK']) 278 | elif 'SLURM_PROCID' in os.environ: 279 | args.rank = int(os.environ['SLURM_PROCID']) 280 | args.gpu = args.rank % torch.cuda.device_count() 281 | elif hasattr(args, "rank"): 282 | pass 283 | else: 284 | print('Not using distributed mode') 285 | args.distributed = False 286 | return 287 | 288 | 
args.distributed = True 289 | 290 | torch.cuda.set_device(args.gpu) 291 | args.dist_backend = 'nccl' 292 | print('| distributed init (rank {}): {}'.format( 293 | args.rank, args.dist_url), flush=True) 294 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 295 | world_size=args.world_size, rank=args.rank) 296 | setup_for_distributed(args.rank == 0) 297 | --------------------------------------------------------------------------------