├── teaser.png ├── LMSCNet ├── common │ ├── seed.py │ ├── time.py │ ├── io_tools.py │ ├── logger.py │ ├── dataset.py │ ├── optimizer.py │ ├── model.py │ ├── checkpoint.py │ ├── config.py │ └── metrics.py ├── test.py ├── data │ ├── labels_downscale.py │ ├── semantic-kitti.yaml │ ├── io_data.py │ └── SemanticKITTI.py ├── validate.py ├── models │ ├── SSCNet_full.py │ ├── SSCNet.py │ ├── LMSCNet_SS.py │ └── LMSCNet.py └── train.py ├── SSC_configs ├── examples │ ├── SSCNet.yaml │ ├── LMSCNet.yaml │ ├── SSCNet_full.yaml │ └── LMSCNet_SS.yaml └── config_routine.py ├── .gitignore ├── README.md └── LICENSE /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astra-vision/LMSCNet/HEAD/teaser.png -------------------------------------------------------------------------------- /LMSCNet/common/seed.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | import os 5 | 6 | 7 | def seed_all(seed): 8 | ''' 9 | Set seeds for training reproducibility 10 | ''' 11 | random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | np.random.seed(seed) 15 | os.environ['PYTHONHASHSEED'] = str(seed) -------------------------------------------------------------------------------- /LMSCNet/common/time.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | 4 | def get_date_sting(): 5 | ''' 6 | To retrieve time in nice format for string printing and naming 7 | :return: 8 | ''' 9 | _now = datetime.datetime.now() 10 | _date = ('%.2i' % _now.month) + ('%.2i' % _now.day) # ('%.4i' % _now.year) + 11 | _time = ('%.2i' % _now.hour) + ('%.2i' % _now.minute) + ('%.2i' % _now.second) 12 | return (_date + '_' + _time) -------------------------------------------------------------------------------- /SSC_configs/examples/SSCNet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | NUM_WORKERS: 4 3 | DATASET: 4 | AUGMENTATION: 5 | FLIPS: false 6 | MODALITIES: 7 | 3D_LABEL: true 8 | 3D_OCCLUDED: true 9 | 3D_OCCUPANCY: true 10 | ROOT_DIR: /datasets_local/datasets_lroldaoj/semantic_kitti_v1.0/ 11 | TYPE: SemanticKITTI 12 | MODEL: 13 | TYPE: SSCNet 14 | OPTIMIZER: 15 | BASE_LR: 0.001 16 | BETA1: 0.9 17 | BETA2: 0.999 18 | MOMENTUM: NA 19 | TYPE: Adam 20 | WEIGHT_DECAY: NA 21 | OUTPUT: 22 | OUT_ROOT: ../SSC_out/ 23 | SCHEDULER: 24 | FREQUENCY: epoch 25 | LR_POWER: NA 26 | TYPE: constant 27 | STATUS: 28 | RESUME: false 29 | TRAIN: 30 | BATCH_SIZE: 4 31 | CHECKPOINT_PERIOD: 15 32 | EPOCHS: 80 33 | SUMMARY_PERIOD: 50 34 | VAL: 35 | BATCH_SIZE: 8 36 | SUMMARY_PERIOD: 20 37 | -------------------------------------------------------------------------------- /SSC_configs/examples/LMSCNet.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | NUM_WORKERS: 4 3 | DATASET: 4 | AUGMENTATION: 5 | FLIPS: true 6 | MODALITIES: 7 | 3D_LABEL: true 8 | 3D_OCCLUDED: true 9 | 3D_OCCUPANCY: true 10 | ROOT_DIR: /datasets_local/datasets_lroldaoj/semantic_kitti_v1.0/ 11 | TYPE: SemanticKITTI 12 | MODEL: 13 | TYPE: LMSCNet 14 | OPTIMIZER: 15 | BASE_LR: 0.001 16 | BETA1: 0.9 17 | BETA2: 0.999 18 | MOMENTUM: NA 19 | TYPE: Adam 20 | WEIGHT_DECAY: NA 21 | OUTPUT: 22 | OUT_ROOT: ../SSC_out/ 23 | SCHEDULER: 24 | FREQUENCY: epoch 25 | LR_POWER: 0.98 26 | TYPE: power_iteration 27 | STATUS: 28 | RESUME: false 29 | TRAIN: 30 
| BATCH_SIZE: 4 31 | CHECKPOINT_PERIOD: 15 32 | EPOCHS: 80 33 | SUMMARY_PERIOD: 50 34 | VAL: 35 | BATCH_SIZE: 8 36 | SUMMARY_PERIOD: 20 37 | -------------------------------------------------------------------------------- /SSC_configs/examples/SSCNet_full.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | NUM_WORKERS: 4 3 | DATASET: 4 | AUGMENTATION: 5 | FLIPS: false 6 | MODALITIES: 7 | 3D_LABEL: true 8 | 3D_OCCLUDED: true 9 | 3D_OCCUPANCY: true 10 | ROOT_DIR: /datasets_local/datasets_lroldaoj/semantic_kitti_v1.0/ 11 | TYPE: SemanticKITTI 12 | MODEL: 13 | TYPE: SSCNet_full 14 | OPTIMIZER: 15 | BASE_LR: 0.001 16 | BETA1: 0.9 17 | BETA2: 0.999 18 | MOMENTUM: NA 19 | TYPE: Adam 20 | WEIGHT_DECAY: NA 21 | OUTPUT: 22 | OUT_ROOT: ../SSC_out/ 23 | SCHEDULER: 24 | FREQUENCY: epoch 25 | LR_POWER: NA 26 | TYPE: constant 27 | STATUS: 28 | RESUME: false 29 | TRAIN: 30 | BATCH_SIZE: 4 31 | CHECKPOINT_PERIOD: 15 32 | EPOCHS: 80 33 | SUMMARY_PERIOD: 50 34 | VAL: 35 | BATCH_SIZE: 8 36 | SUMMARY_PERIOD: 20 37 | -------------------------------------------------------------------------------- /SSC_configs/examples/LMSCNet_SS.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | NUM_WORKERS: 4 3 | DATASET: 4 | AUGMENTATION: 5 | FLIPS: true 6 | MODALITIES: 7 | 3D_LABEL: true 8 | 3D_OCCLUDED: true 9 | 3D_OCCUPANCY: true 10 | ROOT_DIR: /datasets_local/datasets_lroldaoj/semantic_kitti_v1.0/ 11 | TYPE: SemanticKITTI 12 | MODEL: 13 | TYPE: LMSCNet_SS 14 | OPTIMIZER: 15 | BASE_LR: 0.001 16 | BETA1: 0.9 17 | BETA2: 0.999 18 | MOMENTUM: NA 19 | TYPE: Adam 20 | WEIGHT_DECAY: NA 21 | OUTPUT: 22 | OUT_ROOT: ../SSC_out/ 23 | SCHEDULER: 24 | FREQUENCY: epoch 25 | LR_POWER: 0.98 26 | TYPE: power_iteration 27 | STATUS: 28 | RESUME: false 29 | TRAIN: 30 | BATCH_SIZE: 4 31 | CHECKPOINT_PERIOD: 15 32 | EPOCHS: 80 33 | SUMMARY_PERIOD: 50 34 | VAL: 35 | BATCH_SIZE: 8 36 | SUMMARY_PERIOD: 20 37 | -------------------------------------------------------------------------------- /LMSCNet/common/io_tools.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | 5 | def get_md5(filename): 6 | ''' 7 | 8 | ''' 9 | hash_obj = hashlib.md5() 10 | with open(filename, 'rb') as f: 11 | hash_obj.update(f.read()) 12 | return hash_obj.hexdigest() 13 | 14 | 15 | def dict_to(_dict, device, dtype): 16 | ''' 17 | 18 | ''' 19 | for key, value in _dict.items(): 20 | if type(_dict[key]) is dict: 21 | _dict[key] = dict_to(_dict[key], device, dtype) 22 | else: 23 | _dict[key] = _dict[key].to(device=device, dtype=dtype) 24 | 25 | return _dict 26 | 27 | 28 | def _remove_recursively(folder_path): 29 | ''' 30 | Remove directory recursively 31 | ''' 32 | if os.path.isdir(folder_path): 33 | filelist = [f for f in os.listdir(folder_path)] 34 | for f in filelist: 35 | os.remove(os.path.join(folder_path, f)) 36 | return 37 | 38 | 39 | def _create_directory(directory): 40 | ''' 41 | Create directory if doesn't exists 42 | ''' 43 | if not os.path.exists(directory): 44 | os.makedirs(directory) 45 | return -------------------------------------------------------------------------------- /LMSCNet/common/logger.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | import logging 4 | 5 | 6 | def get_logger(path, filename): 7 | 8 | # Create the folder where the training information is to be saved if it doesn't exist 9 | if not 
os.path.exists(path): 10 | try: 11 | os.makedirs(path) 12 | except OSError as exc: # Guard against race condition 13 | if exc.errno != errno.EEXIST: 14 | raise 15 | 16 | # Create the logger 17 | logger = logging.getLogger() 18 | logger.setLevel(logging.INFO) # In order to store logs of level INFO and above 19 | # create file handler which logs even debug messages into logs file 20 | fh = logging.FileHandler(os.path.join(path, filename)) 21 | fh.setLevel(logging.INFO) 22 | # create console handler 23 | ch = logging.StreamHandler() 24 | ch.setLevel(logging.INFO) 25 | # create formatter and add it to the handlers 26 | formatter = logging.Formatter('%(asctime)s -- %(message)s') 27 | fh.setFormatter(formatter) 28 | ch.setFormatter(formatter) 29 | # add the handlers to the logger 30 | logger.addHandler(fh) 31 | logger.addHandler(ch) 32 | 33 | return logger -------------------------------------------------------------------------------- /LMSCNet/common/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from LMSCNet.data.SemanticKITTI import SemanticKITTI_dataloader 4 | 5 | 6 | def get_dataset(_cfg): 7 | 8 | if _cfg._dict['DATASET']['TYPE'] == 'SemanticKITTI': 9 | ds_train = SemanticKITTI_dataloader(_cfg._dict['DATASET'], 'train') 10 | ds_val = SemanticKITTI_dataloader(_cfg._dict['DATASET'], 'val') 11 | ds_test = SemanticKITTI_dataloader(_cfg._dict['DATASET'], 'test') 12 | 13 | _cfg._dict['DATASET']['SPLIT'] = {'TRAIN': len(ds_train), 'VAL': len(ds_val), 'TEST': len(ds_test)} 14 | 15 | dataset = {} 16 | 17 | train_batch_size = _cfg._dict['TRAIN']['BATCH_SIZE'] 18 | val_batch_size = _cfg._dict['VAL']['BATCH_SIZE'] 19 | num_workers = _cfg._dict['DATALOADER']['NUM_WORKERS'] 20 | 21 | dataset['train'] = DataLoader(ds_train, batch_size=train_batch_size, num_workers=num_workers, shuffle=True) 22 | dataset['val'] = DataLoader(ds_val, batch_size=val_batch_size, num_workers=num_workers, shuffle=False) 23 | dataset['test'] = DataLoader(ds_test, batch_size=val_batch_size, num_workers=num_workers, shuffle=False) 24 | 25 | return dataset -------------------------------------------------------------------------------- /LMSCNet/common/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | def build_optimizer(_cfg, model): 5 | 6 | opt = _cfg._dict['OPTIMIZER']['TYPE'] 7 | lr = _cfg._dict['OPTIMIZER']['BASE_LR'] 8 | if 'MOMENTUM' in _cfg._dict['OPTIMIZER']: momentum = _cfg._dict['OPTIMIZER']['MOMENTUM'] 9 | if 'WEIGHT_DECAY' in _cfg._dict['OPTIMIZER']: weight_decay = _cfg._dict['OPTIMIZER']['WEIGHT_DECAY'] 10 | 11 | if opt == 'Adam': optimizer = optim.Adam(model.get_parameters(), 12 | lr=lr, 13 | betas=(0.9, 0.999)) 14 | 15 | elif opt == 'SGD': optimizer = optim.SGD(model.get_parameters(), 16 | lr=lr, 17 | momentum=momentum, 18 | weight_decay=weight_decay) 19 | 20 | return optimizer 21 | 22 | 23 | def build_scheduler(_cfg, optimizer): 24 | 25 | # Constant learning rate 26 | if _cfg._dict['SCHEDULER']['TYPE'] == 'constant': 27 | lambda1 = lambda epoch: 1 28 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 29 | 30 | # Learning rate scaled by 0.98^(epoch) 31 | if _cfg._dict['SCHEDULER']['TYPE'] == 'power_iteration': 32 | lambda1 = lambda epoch: (0.98) ** (epoch) 33 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 34 | 35 | 36 | return scheduler 
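37 | 38 | # Usage sketch (assuming a parsed _cfg): with BASE_LR = 0.001 and SCHEDULER TYPE 'power_iteration', 39 | # stepping the scheduler once per epoch (FREQUENCY: epoch) gives lr(epoch) = 0.001 * 0.98**epoch, 40 | # e.g. lr(10) = 0.001 * 0.98**10 ≈ 8.2e-4. 41 | # Note that the 0.98 decay and the Adam betas (0.9, 0.999) are hardcoded above; the LR_POWER, BETA1 42 | # and BETA2 values from the config files are not read here.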
-------------------------------------------------------------------------------- /LMSCNet/common/model.py: -------------------------------------------------------------------------------- 1 | from LMSCNet.models.LMSCNet import LMSCNet 2 | from LMSCNet.models.LMSCNet_SS import LMSCNet_SS 3 | from LMSCNet.models.SSCNet_full import SSCNet_full 4 | from LMSCNet.models.SSCNet import SSCNet 5 | 6 | 7 | def get_model(_cfg, dataset): 8 | 9 | nbr_classes = dataset.nbr_classes 10 | grid_dimensions = dataset.grid_dimensions 11 | class_frequencies = dataset.class_frequencies 12 | 13 | selected_model = _cfg._dict['MODEL']['TYPE'] 14 | 15 | # LMSCNet ---------------------------------------------------------------------------------------------------------- 16 | if selected_model == 'LMSCNet': 17 | model = LMSCNet(class_num=nbr_classes, input_dimensions=grid_dimensions, class_frequencies=class_frequencies) 18 | # ------------------------------------------------------------------------------------------------------------------ 19 | 20 | # LMSCNet_SS ------------------------------------------------------------------------------------------------------- 21 | elif selected_model == 'LMSCNet_SS': 22 | model = LMSCNet_SS(class_num=nbr_classes, input_dimensions=grid_dimensions, class_frequencies=class_frequencies) 23 | # ------------------------------------------------------------------------------------------------------------------ 24 | 25 | # SSCNet_full ------------------------------------------------------------------------------------------------------ 26 | elif selected_model == 'SSCNet_full': 27 | model = SSCNet_full(class_num=nbr_classes) 28 | # ------------------------------------------------------------------------------------------------------------------ 29 | 30 | # SSCNet ----------------------------------------------------------------------------------------------------------- 31 | elif selected_model == 'SSCNet': 32 | model = SSCNet(class_num=nbr_classes) 33 | # ------------------------------------------------------------------------------------------------------------------ 34 | 35 | else: 36 | assert False, 'Wrong model selected' 37 | 38 | return model -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | # Project Folders 127 | LMSCNet/scratch 128 | SSC_configs/routines 129 | SSC_out/ 130 | weights/ 131 | 132 | # Python 133 | .idea/ 134 | -------------------------------------------------------------------------------- /LMSCNet/common/checkpoint.py: -------------------------------------------------------------------------------- 1 | from torch.nn.parallel import DataParallel, DistributedDataParallel 2 | import torch 3 | import os 4 | from glob import glob 5 | 6 | from LMSCNet.common.io_tools import _remove_recursively, _create_directory 7 | 8 | 9 | def load(model, optimizer, scheduler, resume, path, logger): 10 | ''' 11 | Load checkpoint file 12 | ''' 13 | 14 | # If not resume, initialize model and return everything as it is 15 | if not resume: 16 | logger.info('=> No checkpoint. Initializing model from scratch') 17 | model.weights_init() 18 | epoch = 1 19 | return model, optimizer, scheduler, epoch 20 | 21 | # If resume, check that path exists and load everything to return 22 | else: 23 | file_path = glob(os.path.join(path, '*.pth'))[0] 24 | assert os.path.isfile(file_path), '=> No checkpoint found at {}'.format(path) 25 | checkpoint = torch.load(file_path) 26 | epoch = checkpoint.pop('startEpoch') 27 | if isinstance(model, (DataParallel, DistributedDataParallel)): 28 | model.module.load_state_dict(checkpoint.pop('model')) 29 | else: 30 | model.load_state_dict(checkpoint.pop('model')) 31 | optimizer.load_state_dict(checkpoint.pop('optimizer')) 32 | scheduler.load_state_dict(checkpoint.pop('scheduler')) 33 | logger.info('=> Continuing training routine. 
Checkpoint loaded at {}'.format(file_path)) 34 | return model, optimizer, scheduler, epoch 35 | 36 | 37 | def load_model(model, filepath, logger): 38 | ''' 39 | Load checkpoint file 40 | ''' 41 | 42 | # check that path exists and load everything to return 43 | assert os.path.isfile(filepath), '=> No file found at {}'.format(filepath) 44 | checkpoint = torch.load(filepath) 45 | 46 | if isinstance(model, (DataParallel, DistributedDataParallel)): 47 | model.module.load_state_dict(checkpoint.pop('model')) 48 | else: 49 | model.load_state_dict(checkpoint.pop('model')) 50 | logger.info('=> Model loaded at {}'.format(filepath)) 51 | return model 52 | 53 | 54 | def save(path, model, optimizer, scheduler, epoch, config): 55 | ''' 56 | Save checkpoint file 57 | ''' 58 | 59 | # Remove recursively if epoch_last folder exists and create new one 60 | _remove_recursively(path) 61 | _create_directory(path) 62 | 63 | weights_fpath = os.path.join(path, 'weights_epoch_{}.pth'.format(str(epoch).zfill(3))) 64 | 65 | torch.save({ 66 | 'startEpoch': epoch+1, # To start on next epoch when loading the dict... 67 | 'model': model.state_dict(), 68 | 'optimizer': optimizer.state_dict(), 69 | 'scheduler': scheduler.state_dict(), 70 | 'config_dict': config 71 | }, weights_fpath) 72 | 73 | return weights_fpath -------------------------------------------------------------------------------- /LMSCNet/common/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | 4 | from LMSCNet.common.time import get_date_sting 5 | 6 | 7 | class CFG: 8 | 9 | def __init__(self): 10 | ''' 11 | Class constructor 12 | :param config_path: 13 | ''' 14 | 15 | # Initializing dict... 16 | self._dict = {} 17 | return 18 | 19 | def from_config_yaml(self, config_path): 20 | ''' 21 | Build the configuration from a YAML file 22 | :param config_path: 23 | ''' 24 | 25 | # Reading config file 26 | self._dict = yaml.load(open(config_path, 'r'), Loader=yaml.FullLoader) 27 | 28 | self._dict['STATUS']['CONFIG'] = config_path 29 | 30 | if not 'OUTPUT_PATH' in self._dict['OUTPUT'].keys(): 31 | self.set_output_filename() 32 | self.init_stats() 33 | self.update_config() 34 | 35 | return 36 | 37 | def from_dict(self, config_dict): 38 | ''' 39 | Build the configuration from an existing dict 40 | :param config_dict: 41 | ''' 42 | 43 | # Copying config dict 44 | self._dict = config_dict 45 | return 46 | 47 | def set_output_filename(self): 48 | ''' 49 | Set output path in the form Model_Dataset_MMDD_HHMMSS 50 | ''' 51 | datetime = get_date_sting() 52 | model = self._dict['MODEL']['TYPE'] 53 | dataset = self._dict['DATASET']['TYPE'] 54 | OUT_PATH = os.path.join(self._dict['OUTPUT']['OUT_ROOT'], model + '_' + dataset + '_' + datetime) 55 | self._dict['OUTPUT']['OUTPUT_PATH'] = OUT_PATH 56 | return 57 | 58 | def update_config(self, resume=False): 59 | ''' 60 | Save config file 61 | ''' 62 | if resume: 63 | self.set_resume() 64 | yaml.dump(self._dict, open(self._dict['STATUS']['CONFIG'], 'w')) 65 | return 66 | 67 | def init_stats(self): 68 | ''' 69 | Initialize training stats (i.e.
epoch mean time, best loss, best metrics) 70 | ''' 71 | self._dict['OUTPUT']['BEST_LOSS'] = 999999999999 72 | self._dict['OUTPUT']['BEST_METRIC'] = -999999999999 73 | self._dict['STATUS']['LAST'] = '' 74 | return 75 | 76 | def set_resume(self): 77 | ''' 78 | Update resume status dict file 79 | ''' 80 | if not self._dict['STATUS']['RESUME']: 81 | self._dict['STATUS']['RESUME'] = True 82 | return 83 | 84 | def finish_config(self): 85 | self.move_config(os.path.join(self._dict['OUTPUT']['OUTPUT_PATH'], 'config.yaml')) 86 | return 87 | 88 | def move_config(self, path): 89 | # Remove from original path 90 | os.remove(self._dict['STATUS']['CONFIG']) 91 | # Change ['STATUS']['CONFIG'] to new path 92 | self._dict['STATUS']['CONFIG'] = path 93 | # Save to routine output folder 94 | yaml.dump(self._dict, open(path, 'w')) 95 | 96 | return 97 | -------------------------------------------------------------------------------- /SSC_configs/config_routine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import sys 4 | 5 | # Append root directory to system path for imports 6 | repo_path, _ = os.path.split(os.path.realpath(__file__)) 7 | repo_path, _ = os.path.split(repo_path) 8 | sys.path.append(repo_path) 9 | 10 | from LMSCNet.common.io_tools import _create_directory 11 | 12 | config_dict = {} 13 | 14 | output_root = '' 15 | output_folder = 'routines' 16 | output_filename = 'LMSCNet.yaml' 17 | out_path = os.path.join(output_root, output_folder, output_filename) 18 | 19 | # ------------------------------------------------------------- 20 | config_dict['DATALOADER'] = {} 21 | config_dict['DATALOADER']['NUM_WORKERS'] = 4 22 | # ------------------------------------------------------------- 23 | 24 | # ------------------------------------------------------------- 25 | config_dict['DATASET'] = {} 26 | config_dict['DATASET']['TYPE'] = 'SemanticKITTI' # SemanticKITTI, other datasets might be added... 
27 | config_dict['DATASET']['MODALITIES'] = {} 28 | # More modalities might be added 29 | config_dict['DATASET']['MODALITIES']['3D_LABEL'] = True 30 | config_dict['DATASET']['MODALITIES']['3D_OCCUPANCY'] = True 31 | config_dict['DATASET']['MODALITIES']['3D_OCCLUDED'] = True 32 | config_dict['DATASET']['ROOT_DIR'] = '/datasets_local/datasets_lroldaoj/semantic_kitti_v1.0/' 33 | config_dict['DATASET']['AUGMENTATION'] = {} 34 | config_dict['DATASET']['AUGMENTATION']['FLIPS'] = True # More data augmentation can be added in dataloader 35 | # ------------------------------------------------------------- 36 | 37 | # ------------------------------------------------------------- 38 | config_dict['MODEL'] = {} 39 | config_dict['MODEL']['TYPE'] = 'LMSCNet' # [LMSCNet, LMSCNet_SS, SSCNet, SSCNet_full] 40 | # ------------------------------------------------------------- 41 | 42 | # ------------------------------------------------------------- 43 | config_dict['OPTIMIZER'] = {} 44 | config_dict['OPTIMIZER']['BASE_LR'] = 0.001 45 | config_dict['OPTIMIZER']['TYPE'] = 'Adam' # [SGD, Adam] 46 | # For SGD Optimizer 47 | config_dict['OPTIMIZER']['MOMENTUM'] = 'NA' 48 | config_dict['OPTIMIZER']['WEIGHT_DECAY'] = 'NA' 49 | # For Adam Optimizer 50 | config_dict['OPTIMIZER']['BETA1'] = 0.9 51 | config_dict['OPTIMIZER']['BETA2'] = 0.999 52 | # ------------------------------------------------------------- 53 | 54 | # ------------------------------------------------------------- 55 | config_dict['OUTPUT'] = {} 56 | config_dict['OUTPUT']['OUT_ROOT'] = '../SSC_out/' 57 | # ------------------------------------------------------------- 58 | 59 | # ------------------------------------------------------------- 60 | config_dict['SCHEDULER'] = {} 61 | config_dict['SCHEDULER']['TYPE'] = 'power_iteration' # ['constant', 'power_iteration'] 62 | config_dict['SCHEDULER']['FREQUENCY'] = 'epoch' 63 | config_dict['SCHEDULER']['LR_POWER'] = 0.98 # ['NA', 0.98] 64 | # ------------------------------------------------------------- 65 | 66 | # ------------------------------------------------------------- 67 | config_dict['STATUS'] = {} 68 | config_dict['STATUS']['RESUME'] = False 69 | # ------------------------------------------------------------- 70 | 71 | # ------------------------------------------------------------- 72 | config_dict['TRAIN'] = {} 73 | config_dict['TRAIN']['BATCH_SIZE'] = 4 74 | config_dict['TRAIN']['CHECKPOINT_PERIOD'] = 15 75 | config_dict['TRAIN']['EPOCHS'] = 80 76 | config_dict['TRAIN']['SUMMARY_PERIOD'] = 50 77 | # ------------------------------------------------------------- 78 | 79 | # ------------------------------------------------------------- 80 | config_dict['VAL'] = {} 81 | config_dict['VAL']['BATCH_SIZE'] = 8 82 | config_dict['VAL']['SUMMARY_PERIOD'] = 20 83 | # ------------------------------------------------------------- 84 | 85 | _create_directory(os.path.dirname(out_path)) 86 | yaml.dump(config_dict, open(out_path, 'w')) 87 | 88 | print('Config routine file {} saved...'.format(out_path)) 89 | -------------------------------------------------------------------------------- /LMSCNet/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | import numpy as np 7 | 8 | # Append root directory to system path for imports 9 | repo_path, _ = os.path.split(os.path.realpath(__file__)) 10 | repo_path, _ = os.path.split(repo_path) 11 | sys.path.append(repo_path) 12 | 13 | from 
LMSCNet.common.seed import seed_all 14 | from LMSCNet.common.config import CFG 15 | from LMSCNet.common.dataset import get_dataset 16 | from LMSCNet.common.model import get_model 17 | from LMSCNet.common.logger import get_logger 18 | from LMSCNet.common.io_tools import dict_to, _create_directory 19 | import LMSCNet.common.checkpoint as checkpoint 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description='LMSCNet testing') 24 | parser.add_argument( 25 | '--weights', 26 | dest='weights_file', 27 | default='', 28 | metavar='FILE', 29 | help='path to the model .pth weights file', 30 | type=str, 31 | ) 32 | parser.add_argument( 33 | '--dset_root', 34 | dest='dataset_root', 35 | default='', 36 | metavar='DATASET', 37 | help='path to dataset root folder', 38 | type=str, 39 | ) 40 | parser.add_argument( 41 | '--out_path', 42 | dest='output_path', 43 | default='', 44 | metavar='OUT_PATH', 45 | help='path to folder where predictions will be saved', 46 | type=str, 47 | ) 48 | args = parser.parse_args() 49 | return args 50 | 51 | 52 | def test(model, dset, _cfg, logger, out_path_root): 53 | 54 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 55 | dtype = torch.float32 # Tensor type to be used 56 | # Moving model to used device 57 | model = model.to(device=device) 58 | logger.info('=> Passing the network on the test set...') 59 | model.eval() 60 | inv_remap_lut = dset.dataset.get_inv_remap_lut() 61 | 62 | with torch.no_grad(): 63 | 64 | for t, (data, indices) in enumerate(dset): 65 | 66 | data = dict_to(data, device, dtype) 67 | scores = model(data) 68 | for key in scores: 69 | scores[key] = torch.argmax(scores[key], dim=1).data.cpu().numpy() 70 | 71 | curr_index = 0 72 | for score in scores['pred_semantic_1_1']: 73 | score = np.moveaxis(score, [0, 1, 2], [0, 2, 1]).reshape(-1).astype(np.uint16) 74 | score = inv_remap_lut[score].astype(np.uint16) 75 | input_filename = dset.dataset.filepaths['3D_OCCUPANCY'][indices[curr_index]] 76 | filename, extension = os.path.splitext(os.path.basename(input_filename)) 77 | sequence = os.path.dirname(input_filename).split('/')[-2] 78 | out_filename = os.path.join(out_path_root, 'sequences', sequence, 'predictions', filename + '.label') 79 | _create_directory(os.path.dirname(out_filename)) 80 | score.tofile(out_filename) 81 | logger.info('=> Sequence {} - File {} saved'.format(sequence, os.path.basename(out_filename))) 82 | curr_index += 1 83 | 84 | return 85 | 86 | 87 | def main(): 88 | 89 | # https://github.com/pytorch/pytorch/issues/27588 90 | torch.backends.cudnn.enabled = False 91 | 92 | seed_all(0) 93 | 94 | args = parse_args() 95 | 96 | weights_f = args.weights_file 97 | dataset_f = args.dataset_root 98 | out_path_root = args.output_path 99 | 100 | assert os.path.isfile(weights_f), '=> No file found at {}'.format(weights_f) 101 | 102 | checkpoint_dict = torch.load(weights_f) 103 | config_dict = checkpoint_dict.pop('config_dict') 104 | config_dict['DATASET']['ROOT_DIR'] = dataset_f 105 | 106 | # Read train configuration file 107 | _cfg = CFG() 108 | _cfg.from_dict(config_dict) 109 | # Setting the logger to print statements and also save them into logs file 110 | logger = get_logger(out_path_root, 'logs_test.log') 111 | 112 | logger.info('============ Test weights: "%s" ============\n' % weights_f) 113 | dataset = get_dataset(_cfg)['test'] 114 | 115 | logger.info('=> Loading network architecture...') 116 | model = get_model(_cfg, dataset.dataset) 117 | if torch.cuda.device_count() > 1: 118 | model =
nn.DataParallel(model) 119 | model = model.module 120 | 121 | logger.info('=> Loading network weights...') 122 | model = checkpoint.load_model(model, weights_f, logger) 123 | 124 | test(model, dataset, _cfg, logger, out_path_root) 125 | 126 | logger.info('=> ============ Network Test Done ============') 127 | 128 | exit() 129 | 130 | 131 | if __name__ == '__main__': 132 | main() -------------------------------------------------------------------------------- /LMSCNet/data/labels_downscale.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | import numpy as np 4 | import yaml 5 | import time 6 | import argparse 7 | import sys 8 | 9 | # Append root directory to system path for imports 10 | repo_path, _ = os.path.split(os.path.realpath(__file__)) 11 | repo_path, _ = os.path.split(repo_path) 12 | repo_path, _ = os.path.split(repo_path) 13 | sys.path.append(repo_path) 14 | 15 | import LMSCNet.data.io_data as SemanticKittiIO 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description='LMSCNet labels lower scales creation') 20 | parser.add_argument( 21 | '--dset_root', 22 | dest='dataset_root', 23 | default='', 24 | metavar='DATASET', 25 | help='path to dataset root folder', 26 | type=str, 27 | ) 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | def majority_pooling(grid, k_size=2): 33 | result = np.zeros((grid.shape[0] // k_size, grid.shape[1] // k_size, grid.shape[2] // k_size)) 34 | for xx in range(0, int(np.floor(grid.shape[0]/k_size))): 35 | for yy in range(0, int(np.floor(grid.shape[1]/k_size))): 36 | for zz in range(0, int(np.floor(grid.shape[2]/k_size))): 37 | 38 | sub_m = grid[(xx*k_size):(xx*k_size)+k_size, (yy*k_size):(yy*k_size)+k_size, (zz*k_size):(zz*k_size)+k_size] 39 | unique, counts = np.unique(sub_m, return_counts=True) 40 | if True in ((unique != 0) & (unique != 255)): 41 | # Remove counts with 0 and 255 42 | counts = counts[((unique != 0) & (unique != 255))] 43 | unique = unique[((unique != 0) & (unique != 255))] 44 | else: 45 | if True in (unique == 0): 46 | counts = counts[(unique != 255)] 47 | unique = unique[(unique != 255)] 48 | value = unique[np.argmax(counts)] 49 | result[xx, yy, zz] = value 50 | return result 51 | 52 | 53 | def downscale_data(LABEL, downscaling): 54 | # Majority pooling labels downscaled in 3D 55 | LABEL = majority_pooling(LABEL, k_size=downscaling) 56 | # Reshape to 1D 57 | LABEL = np.moveaxis(LABEL, [0, 1, 2], [0, 2, 1]).reshape(-1) 58 | # Invalid file downscaled 59 | INVALID = np.zeros_like(LABEL) 60 | INVALID[np.isclose(LABEL, 255)] = 1 61 | return LABEL, INVALID 62 | 63 | 64 | def main(): 65 | 66 | args = parse_args() 67 | 68 | dset_root = args.dataset_root 69 | yaml_path, _ = os.path.split(os.path.realpath(__file__)) 70 | remap_lut = SemanticKittiIO.get_remap_lut(os.path.join(yaml_path, 'semantic-kitti.yaml')) 71 | dataset_config = yaml.safe_load(open(os.path.join(yaml_path, 'semantic-kitti.yaml'), 'r')) 72 | sequences = sorted(glob(os.path.join(dset_root, 'dataset', 'sequences', '*'))) 73 | # Selecting training/validation set sequences only (labels unavailable for test set) 74 | sequences = sequences[:11] 75 | grid_dimensions = dataset_config['grid_dims'] # [W, H, D] 76 | 77 | assert len(sequences) > 0, 'Error, no sequences on selected dataset root path' 78 | 79 | for sequence in sequences: 80 | 81 | label_paths = sorted(glob(os.path.join(sequence, 'voxels', '*.label'))) 82 | invalid_paths = sorted(glob(os.path.join(sequence, 
'voxels', '*.invalid'))) 83 | out_dir = os.path.join(sequence, 'voxels') 84 | downscaling = {'1_2': 2, '1_4': 4, '1_8': 8} 85 | 86 | for i in range(len(label_paths)): 87 | 88 | filename, extension = os.path.splitext(os.path.basename(label_paths[i])) 89 | 90 | LABEL = SemanticKittiIO._read_label_SemKITTI(label_paths[i]) 91 | INVALID = SemanticKittiIO._read_invalid_SemKITTI(invalid_paths[i]) 92 | LABEL = remap_lut[LABEL.astype(np.uint16)].astype(np.float32) # Remap 20 classes semanticKITTI SSC 93 | LABEL[np.isclose(INVALID, 1)] = 255 # Setting to unknown all voxels marked on invalid mask... 94 | LABEL = np.moveaxis(LABEL.reshape([grid_dimensions[0], grid_dimensions[2], grid_dimensions[1]]), 95 | [0, 1, 2], [0, 2, 1]) # [256, 32, 256] 96 | 97 | for scale in downscaling: 98 | 99 | label_filename = os.path.join(out_dir, filename + '.label_' + scale) 100 | invalid_filename = os.path.join(out_dir, filename + '.invalid_' + scale) 101 | # If files have not been created... 102 | if not (os.path.isfile(label_filename) & os.path.isfile(invalid_filename)): 103 | LABEL_ds, INVALID_ds = downscale_data(LABEL, downscaling[scale]) 104 | SemanticKittiIO.pack(INVALID_ds.astype(dtype=np.uint8)).tofile(invalid_filename) 105 | print(time.strftime('%x %X') + ' -- => File {} - Sequence {} saved...'.format(filename + '.invalid_' + scale, os.path.basename(sequence))) 106 | LABEL_ds.astype(np.uint16).tofile(label_filename) 107 | print(time.strftime('%x %X') + ' -- => File {} - Sequence {} saved...'.format(filename + '.label_' + scale, os.path.basename(sequence))) 108 | 109 | print(time.strftime('%x %X') + ' -- => All files saved for Sequence {}'.format(os.path.basename(sequence))) 110 | 111 | print(time.strftime('%x %X') + ' -- => All files saved') 112 | 113 | exit() 114 | 115 | 116 | if __name__ == '__main__': 117 | main() -------------------------------------------------------------------------------- /LMSCNet/validate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import sys 6 | 7 | # Append root directory to system path for imports 8 | repo_path, _ = os.path.split(os.path.realpath(__file__)) 9 | repo_path, _ = os.path.split(repo_path) 10 | sys.path.append(repo_path) 11 | 12 | from LMSCNet.common.seed import seed_all 13 | from LMSCNet.common.config import CFG 14 | from LMSCNet.common.dataset import get_dataset 15 | from LMSCNet.common.model import get_model 16 | from LMSCNet.common.logger import get_logger 17 | from LMSCNet.common.io_tools import dict_to 18 | from LMSCNet.common.metrics import Metrics 19 | import LMSCNet.common.checkpoint as checkpoint 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description='LMSCNet validating') 24 | parser.add_argument( 25 | '--weights', 26 | dest='weights_file', 27 | default='', 28 | metavar='FILE', 29 | help='path to the model .pth weights file', 30 | type=str, 31 | ) 32 | parser.add_argument( 33 | '--dset_root', 34 | dest='dataset_root', 35 | default='', 36 | metavar='DATASET', 37 | help='path to dataset root folder', 38 | type=str, 39 | ) 40 | args = parser.parse_args() 41 | return args 42 | 43 | 44 | def validate(model, dset, _cfg, logger, metrics): 45 | 46 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 47 | dtype = torch.float32 # Tensor type to be used 48 | 49 | # Moving model to used device 50 | model = model.to(device=device) 51 | 52 | logger.info('=> Passing the network
on the validation set...') 53 | 54 | model.eval() 55 | 56 | with torch.no_grad(): 57 | 58 | for t, (data, indices) in enumerate(dset): 59 | 60 | data = dict_to(data, device, dtype) 61 | 62 | scores = model(data) 63 | 64 | loss = model.compute_loss(scores, data) 65 | 66 | # Updating batch losses to later compute the epoch mean loss 67 | metrics.losses_track.update_validaiton_losses(loss) 68 | 69 | if (t + 1) % _cfg._dict['VAL']['SUMMARY_PERIOD'] == 0: 70 | loss_print = '=> Iteration [{}/{}], Validation Losses: '.format(t+1, len(dset)) 71 | for key in loss.keys(): loss_print += '{} = {:.6f}, '.format(key, loss[key]) 72 | logger.info(loss_print[:-2]) 73 | 74 | metrics.add_batch(prediction=scores, target=model.get_target(data)) 75 | 76 | epoch_loss = metrics.losses_track.validation_losses['total']/metrics.losses_track.validation_iteration_counts 77 | 78 | logger.info('=> [Total Validation Loss = {}]'.format(epoch_loss)) 79 | for scale in metrics.evaluator.keys(): 80 | loss_scale = metrics.losses_track.validation_losses['semantic_{}'.format(scale)].item()/metrics.losses_track.validation_iteration_counts 81 | logger.info('=> [Scale {}: Loss = {:.6f} - mIoU = {:.6f} - IoU = {:.6f} ' 82 | '- P = {:.6f} - R = {:.6f} - F1 = {:.6f}]' 83 | .format(scale, loss_scale, 84 | metrics.get_semantics_mIoU(scale).item(), 85 | metrics.get_occupancy_IoU(scale).item(), 86 | metrics.get_occupancy_Precision(scale).item(), 87 | metrics.get_occupancy_Recall(scale).item(), 88 | metrics.get_occupancy_F1(scale).item())) 89 | 90 | logger.info('=> Validation set class-wise IoU:') 91 | for i in range(1, metrics.nbr_classes): 92 | class_name = dset.dataset.dataset_config['labels'][dset.dataset.dataset_config['learning_map_inv'][i]] 93 | class_score = metrics.evaluator['1_1'].getIoU()[1][i] 94 | logger.info(' => IoU {}: {:.6f}'.format(class_name, class_score)) 95 | 96 | return 97 | 98 | 99 | def main(): 100 | 101 | # https://github.com/pytorch/pytorch/issues/27588 102 | torch.backends.cudnn.enabled = False 103 | 104 | seed_all(0) 105 | 106 | args = parse_args() 107 | 108 | weights_f = args.weights_file 109 | dataset_f = args.dataset_root 110 | 111 | assert os.path.isfile(weights_f), '=> No file found at {}'.format(weights_f) 112 | 113 | checkpoint_dict = torch.load(weights_f) 114 | config_dict = checkpoint_dict.pop('config_dict') 115 | config_dict['DATASET']['ROOT_DIR'] = dataset_f 116 | 117 | # Read train configuration file 118 | _cfg = CFG() 119 | _cfg.from_dict(config_dict) 120 | # Setting the logger to print statements and also save them into logs file 121 | logger = get_logger(_cfg._dict['OUTPUT']['OUTPUT_PATH'], 'logs_val.log') 122 | 123 | logger.info('============ Validation weights: "%s" ============\n' % weights_f) 124 | dataset = get_dataset(_cfg) 125 | 126 | logger.info('=> Loading network architecture...') 127 | model = get_model(_cfg, dataset['train'].dataset) 128 | if torch.cuda.device_count() > 1: 129 | model = nn.DataParallel(model) 130 | model = model.module 131 | 132 | logger.info('=> Loading network weights...') 133 | model = checkpoint.load_model(model, weights_f, logger) 134 | 135 | nbr_iterations = len(dataset['val']) 136 | metrics = Metrics(dataset['val'].dataset.nbr_classes, nbr_iterations, model.get_scales()) 137 | metrics.reset_evaluator() 138 | metrics.losses_track.set_validation_losses(model.get_validation_loss_keys()) 139 | metrics.losses_track.set_train_losses(model.get_train_loss_keys()) 140 | 141 | validate(model, dataset['val'], _cfg, logger, metrics) 142 | 143 | logger.info('=> ============ Network Validation Done
============') 144 | 145 | exit() 146 | 147 | 148 | if __name__ == '__main__': 149 | main() -------------------------------------------------------------------------------- /LMSCNet/data/semantic-kitti.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 2 | nbr_classes: 20 3 | grid_dims: [256, 32, 256] # (W, H, D) 4 | labels: 5 | 0 : "unlabeled" 6 | 1 : "outlier" 7 | 10: "car" 8 | 11: "bicycle" 9 | 13: "bus" 10 | 15: "motorcycle" 11 | 16: "on-rails" 12 | 18: "truck" 13 | 20: "other-vehicle" 14 | 30: "person" 15 | 31: "bicyclist" 16 | 32: "motorcyclist" 17 | 40: "road" 18 | 44: "parking" 19 | 48: "sidewalk" 20 | 49: "other-ground" 21 | 50: "building" 22 | 51: "fence" 23 | 52: "other-structure" 24 | 60: "lane-marking" 25 | 70: "vegetation" 26 | 71: "trunk" 27 | 72: "terrain" 28 | 80: "pole" 29 | 81: "traffic-sign" 30 | 99: "other-object" 31 | 252: "moving-car" 32 | 253: "moving-bicyclist" 33 | 254: "moving-person" 34 | 255: "moving-motorcyclist" 35 | 256: "moving-on-rails" 36 | 257: "moving-bus" 37 | 258: "moving-truck" 38 | 259: "moving-other-vehicle" 39 | color_map: # bgr 40 | 0 : [0, 0, 0] 41 | 1 : [0, 0, 255] 42 | 10: [245, 150, 100] 43 | 11: [245, 230, 100] 44 | 13: [250, 80, 100] 45 | 15: [150, 60, 30] 46 | 16: [255, 0, 0] 47 | 18: [180, 30, 80] 48 | 20: [255, 0, 0] 49 | 30: [30, 30, 255] 50 | 31: [200, 40, 255] 51 | 32: [90, 30, 150] 52 | 40: [255, 0, 255] 53 | 44: [255, 150, 255] 54 | 48: [75, 0, 75] 55 | 49: [75, 0, 175] 56 | 50: [0, 200, 255] 57 | 51: [50, 120, 255] 58 | 52: [0, 150, 255] 59 | 60: [170, 255, 150] 60 | 70: [0, 175, 0] 61 | 71: [0, 60, 135] 62 | 72: [80, 240, 150] 63 | 80: [150, 240, 255] 64 | 81: [0, 0, 255] 65 | 99: [255, 255, 50] 66 | 252: [245, 150, 100] 67 | 256: [255, 0, 0] 68 | 253: [200, 40, 255] 69 | 254: [30, 30, 255] 70 | 255: [90, 30, 150] 71 | 257: [250, 80, 100] 72 | 258: [180, 30, 80] 73 | 259: [255, 0, 0] 74 | content: # as a ratio with the total number of points 75 | 0: 0.018889854628292943 76 | 1: 0.0002937197336781505 77 | 10: 0.040818519255974316 78 | 11: 0.00016609538710764618 79 | 13: 2.7879693665067774e-05 80 | 15: 0.00039838616015114444 81 | 16: 0.0 82 | 18: 0.0020633612104619787 83 | 20: 0.0016218197275284021 84 | 30: 0.00017698551338515307 85 | 31: 1.1065903904919655e-08 86 | 32: 5.532951952459828e-09 87 | 40: 0.1987493871255525 88 | 44: 0.014717169549888214 89 | 48: 0.14392298360372 90 | 49: 0.0039048553037472045 91 | 50: 0.1326861944777486 92 | 51: 0.0723592229456223 93 | 52: 0.002395131480328884 94 | 60: 4.7084144280367186e-05 95 | 70: 0.26681502148037506 96 | 71: 0.006035012012626033 97 | 72: 0.07814222006271769 98 | 80: 0.002855498193863172 99 | 81: 0.0006155958086189918 100 | 99: 0.009923127583046915 101 | 252: 0.001789309418528068 102 | 253: 0.00012709999297008662 103 | 254: 0.00016059776092534436 104 | 255: 3.745553104802113e-05 105 | 256: 0.0 106 | 257: 0.00011351574470342043 107 | 258: 0.00010157861367183268 108 | 259: 4.3840131989471124e-05 109 | # classes that are indistinguishable from single scan or inconsistent in 110 | # ground truth are mapped to their closest equivalent 111 | learning_map: 112 | 0 : 0 # "unlabeled" 113 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 114 | 10: 1 # "car" 115 | 11: 2 # "bicycle" 116 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 117 | 15: 3 # "motorcycle" 118 | 16: 5 # "on-rails" mapped to "other-vehicle" 
---------------------mapped 119 | 18: 4 # "truck" 120 | 20: 5 # "other-vehicle" 121 | 30: 6 # "person" 122 | 31: 7 # "bicyclist" 123 | 32: 8 # "motorcyclist" 124 | 40: 9 # "road" 125 | 44: 10 # "parking" 126 | 48: 11 # "sidewalk" 127 | 49: 12 # "other-ground" 128 | 50: 13 # "building" 129 | 51: 14 # "fence" 130 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 131 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped 132 | 70: 15 # "vegetation" 133 | 71: 16 # "trunk" 134 | 72: 17 # "terrain" 135 | 80: 18 # "pole" 136 | 81: 19 # "traffic-sign" 137 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 138 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 139 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 140 | 254: 6 # "moving-person" to "person" ------------------------------mapped 141 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 142 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 143 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 144 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 145 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 146 | learning_map_inv: # inverse of previous map 147 | 0: 0 # "unlabeled", and others ignored 148 | 1: 10 # "car" 149 | 2: 11 # "bicycle" 150 | 3: 15 # "motorcycle" 151 | 4: 18 # "truck" 152 | 5: 20 # "other-vehicle" 153 | 6: 30 # "person" 154 | 7: 31 # "bicyclist" 155 | 8: 32 # "motorcyclist" 156 | 9: 40 # "road" 157 | 10: 44 # "parking" 158 | 11: 48 # "sidewalk" 159 | 12: 49 # "other-ground" 160 | 13: 50 # "building" 161 | 14: 51 # "fence" 162 | 15: 70 # "vegetation" 163 | 16: 71 # "trunk" 164 | 17: 72 # "terrain" 165 | 18: 80 # "pole" 166 | 19: 81 # "traffic-sign" 167 | learning_ignore: # Ignore classes 168 | 0: True # "unlabeled", and others ignored 169 | 1: False # "car" 170 | 2: False # "bicycle" 171 | 3: False # "motorcycle" 172 | 4: False # "truck" 173 | 5: False # "other-vehicle" 174 | 6: False # "person" 175 | 7: False # "bicyclist" 176 | 8: False # "motorcyclist" 177 | 9: False # "road" 178 | 10: False # "parking" 179 | 11: False # "sidewalk" 180 | 12: False # "other-ground" 181 | 13: False # "building" 182 | 14: False # "fence" 183 | 15: False # "vegetation" 184 | 16: False # "trunk" 185 | 17: False # "terrain" 186 | 18: False # "pole" 187 | 19: False # "traffic-sign" 188 | split: # sequence numbers 189 | train: 190 | - 0 191 | - 1 192 | - 2 193 | - 3 194 | - 4 195 | - 5 196 | - 6 197 | - 7 198 | - 9 199 | - 10 200 | valid: 201 | - 8 202 | test: 203 | - 11 204 | - 12 205 | - 13 206 | - 14 207 | - 15 208 | - 16 209 | - 17 210 | - 18 211 | - 19 212 | - 20 213 | - 21 214 | -------------------------------------------------------------------------------- /LMSCNet/common/metrics.py: -------------------------------------------------------------------------------- 1 | # Some sections of this code reused code from SemanticKITTI development kit 2 | # https://github.com/PRBonn/semantic-kitti-api 3 | 4 | import numpy as np 5 | import torch 6 | import copy 7 | 8 | 9 | class iouEval: 10 | def __init__(self, n_classes, ignore=None): 11 | # classes 12 | self.n_classes = n_classes 13 | 14 | # What to include and ignore from the means 15 | self.ignore = np.array(ignore, dtype=np.int64) 16 | self.include = np.array( 17 | [n for n in range(self.n_classes) if n not in self.ignore], dtype=np.int64) 18 | 19 | # 
reset the class counters 20 | self.reset() 21 | 22 | def num_classes(self): 23 | return self.n_classes 24 | 25 | def reset(self): 26 | self.conf_matrix = np.zeros((self.n_classes, 27 | self.n_classes), 28 | dtype=np.int64) 29 | 30 | def addBatch(self, x, y): # x=preds, y=targets 31 | 32 | assert x.shape == y.shape 33 | 34 | # sizes should be matching 35 | x_row = x.reshape(-1) # de-batchify 36 | y_row = y.reshape(-1) # de-batchify 37 | 38 | # check that flattened shapes still match 39 | assert (x_row.shape == y_row.shape) 40 | 41 | # create indexes 42 | idxs = tuple(np.stack((x_row, y_row), axis=0)) 43 | 44 | # make confusion matrix (cols = gt, rows = pred) 45 | np.add.at(self.conf_matrix, idxs, 1) 46 | 47 | def getStats(self): 48 | # remove fp from confusion on the ignore classes cols 49 | conf = self.conf_matrix.copy() 50 | conf[:, self.ignore] = 0 51 | 52 | # get the clean stats 53 | tp = np.diag(conf) 54 | fp = conf.sum(axis=1) - tp 55 | fn = conf.sum(axis=0) - tp 56 | return tp, fp, fn 57 | 58 | def getIoU(self): 59 | tp, fp, fn = self.getStats() 60 | intersection = tp 61 | union = tp + fp + fn + 1e-15 62 | iou = intersection / union 63 | iou_mean = (intersection[self.include] / union[self.include]).mean() 64 | return iou_mean, iou # returns "iou mean", "iou per class" ALL CLASSES 65 | 66 | def getacc(self): 67 | tp, fp, fn = self.getStats() 68 | total_tp = tp.sum() 69 | total = tp[self.include].sum() + fp[self.include].sum() + 1e-15 70 | acc_mean = total_tp / total 71 | return acc_mean # returns "acc mean" 72 | 73 | def get_confusion(self): 74 | return self.conf_matrix.copy() 75 | 76 | 77 | class LossesTrackEpoch: 78 | def __init__(self, num_iterations): 79 | # classes 80 | self.num_iterations = num_iterations 81 | self.validation_losses = {} 82 | self.train_losses = {} 83 | self.train_iteration_counts = 0 84 | self.validation_iteration_counts = 0 85 | 86 | def set_validation_losses(self, keys): 87 | for key in keys: 88 | self.validation_losses[key] = 0 89 | return 90 | 91 | def set_train_losses(self, keys): 92 | for key in keys: 93 | self.train_losses[key] = 0 94 | return 95 | 96 | def update_train_losses(self, loss): 97 | for key in loss: 98 | self.train_losses[key] += loss[key] 99 | self.train_iteration_counts += 1 100 | return 101 | 102 | def update_validaiton_losses(self, loss): 103 | for key in loss: 104 | self.validation_losses[key] += loss[key] 105 | self.validation_iteration_counts += 1 106 | return 107 | 108 | def restart_train_losses(self): 109 | for key in self.train_losses.keys(): 110 | self.train_losses[key] = 0 111 | self.train_iteration_counts = 0 112 | return 113 | 114 | def restart_validation_losses(self): 115 | for key in self.validation_losses.keys(): 116 | self.validation_losses[key] = 0 117 | self.validation_iteration_counts = 0 118 | return 119 | 120 | 121 | class Metrics: 122 | 123 | def __init__(self, nbr_classes, num_iterations_epoch, scales): 124 | 125 | self.nbr_classes = nbr_classes 126 | self.evaluator = {} 127 | for scale in scales: 128 | self.evaluator[scale] = iouEval(self.nbr_classes, []) 129 | # self.evaluator = iouEval(self.nbr_classes, []) 130 | self.losses_track = LossesTrackEpoch(num_iterations_epoch) 131 | self.best_metric_record = {'mIoU': 0, 'IoU': 0, 'epoch': 0, 'loss': 99999999} 132 | 133 | return 134 | 135 | def add_batch(self, prediction, target): 136 | 137 | # passing to cpu 138 | for key in prediction: 139 | prediction[key] = torch.argmax(prediction[key], dim=1).data.cpu().numpy() 140 | for key in target: 141 | target[key] = target[key].data.cpu().numpy() 142 | 143 | for
key in target: 144 | prediction['pred_semantic_' + key] = prediction['pred_semantic_' + key].reshape(-1).astype('int64') 145 | target[key] = target[key].reshape(-1).astype('int64') 146 | lidar_mask = self.get_eval_mask_Lidar(target[key]) 147 | self.evaluator[key].addBatch(prediction['pred_semantic_' + key][lidar_mask], target[key][lidar_mask]) 148 | 149 | return 150 | 151 | def get_eval_mask_Lidar(self, target): 152 | ''' 153 | eval_mask_lidar only ignores unknown voxels in the ground truth 154 | ''' 155 | mask = (target != 255) 156 | return mask 157 | 158 | def get_occupancy_IoU(self, scale): 159 | conf = self.evaluator[scale].get_confusion() 160 | tp_occupancy = np.sum(conf[1:, 1:]) 161 | fp_occupancy = np.sum(conf[1:, 0]) 162 | fn_occupancy = np.sum(conf[0, 1:]) 163 | intersection = tp_occupancy 164 | union = tp_occupancy + fp_occupancy + fn_occupancy + 1e-15 165 | iou_occupancy = intersection / union 166 | return iou_occupancy # returns iou occupancy 167 | 168 | def get_occupancy_Precision(self, scale): 169 | conf = self.evaluator[scale].get_confusion() 170 | tp_occupancy = np.sum(conf[1:, 1:]) 171 | fp_occupancy = np.sum(conf[1:, 0]) 172 | precision = tp_occupancy / (tp_occupancy + fp_occupancy + 1e-15) 173 | return precision # returns precision occupancy 174 | 175 | def get_occupancy_Recall(self, scale): 176 | conf = self.evaluator[scale].get_confusion() 177 | tp_occupancy = np.sum(conf[1:, 1:]) 178 | fn_occupancy = np.sum(conf[0, 1:]) 179 | recall = tp_occupancy/(tp_occupancy + fn_occupancy + 1e-15) 180 | return recall # returns recall occupancy 181 | 182 | def get_occupancy_F1(self, scale): 183 | conf = self.evaluator[scale].get_confusion() 184 | tp_occupancy = np.sum(conf[1:, 1:]) 185 | fn_occupancy = np.sum(conf[0, 1:]) 186 | fp_occupancy = np.sum(conf[1:, 0]) 187 | precision = tp_occupancy/(tp_occupancy + fp_occupancy + 1e-15) 188 | recall = tp_occupancy/(tp_occupancy + fn_occupancy + 1e-15) 189 | F1 = 2 * (precision * recall) / (precision + recall + 1e-15) 190 | return F1 # returns F1 occupancy 191 | 192 | def get_semantics_mIoU(self, scale): 193 | _, class_jaccard = self.evaluator[scale].getIoU() 194 | mIoU_semantics = class_jaccard[1:].mean() # Ignore on free voxels (0 excluded) 195 | return mIoU_semantics # returns mIoU semantics 196 | 197 | def reset_evaluator(self): 198 | for key in self.evaluator: 199 | self.evaluator[key].reset() 200 | 201 | def update_best_metric_record(self, mIoU, IoU, loss, epoch): 202 | self.best_metric_record['mIoU'] = mIoU 203 | self.best_metric_record['IoU'] = IoU 204 | self.best_metric_record['loss'] = loss 205 | self.best_metric_record['epoch'] = epoch 206 | return 207 | 208 | -------------------------------------------------------------------------------- /LMSCNet/models/SSCNet_full.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | class SSCNet_full(nn.Module): 7 | ''' 8 | # Class coded from caffe model https://github.com/shurans/sscnet/blob/master/test/demo.txt 9 | ''' 10 | 11 | def __init__(self, class_num): 12 | ''' 13 | SSCNet architecture 14 | :param class_num: number of classes to be predicted (i.e.
12 for NYUv2) 15 | ''' 16 | super().__init__() 17 | 18 | self.nbr_classes = class_num 19 | 20 | self.conv1_1 = nn.Conv3d(1, 16, kernel_size=7, padding=3, stride=2, dilation=1) # conv(16, 7, 2, 1) 21 | 22 | self.reduction2_1 = nn.Conv3d(16, 32, kernel_size=1, padding=0, stride=1, dilation=1) # conv(32, 1, 1, 1) 23 | 24 | self.conv2_1 = nn.Conv3d(16, 32, kernel_size=3, padding=1, stride=1, dilation=1) # conv(32, 3, 1, 1) 25 | self.conv2_2 = nn.Conv3d(32, 32, kernel_size=3, padding=1, stride=1, dilation=1) # conv(32, 3, 1, 1) 26 | 27 | self.pool2 = nn.MaxPool3d(2) # pooling 28 | 29 | self.reduction3_1 = nn.Conv3d(64, 64, kernel_size=1, padding=0, stride=1, dilation=1) # conv(64, 1, 1, 1) 30 | 31 | self.conv3_1 = nn.Conv3d(32, 64, kernel_size=3, padding=1, stride=1, dilation=1) # conv(64, 3, 1, 1) 32 | self.conv3_2 = nn.Conv3d(64, 64, kernel_size=3, padding=1, stride=1, dilation=1) # conv(64, 3, 1, 1) 33 | 34 | self.conv3_3 = nn.Conv3d(64, 64, kernel_size=3, padding=1, stride=1, dilation=1) # conv(64, 3, 1, 1) 35 | self.conv3_4 = nn.Conv3d(64, 64, kernel_size=3, padding=1, stride=1, dilation=1) # conv(64, 3, 1, 1) 36 | 37 | self.conv3_5 = nn.Conv3d(64, 64, kernel_size=3, padding=2, stride=1, dilation=2) # dilated(64, 3, 1, 2) 38 | self.conv3_6 = nn.Conv3d(64, 64, kernel_size=3, padding=2, stride=1, dilation=2) # dilated(64, 3, 1, 2) 39 | 40 | self.conv3_7 = nn.Conv3d(64, 64, kernel_size=3, padding=2, stride=1, dilation=2) # dilated(64, 3, 1, 2) 41 | self.conv3_8 = nn.Conv3d(64, 64, kernel_size=3, padding=2, stride=1, dilation=2) # dilated(64, 3, 1, 2) 42 | 43 | self.conv4_1 = nn.Conv3d(192, 128, kernel_size=1, padding=0, stride=1, dilation=1) # conv(128, 1, 1, 1) 44 | self.conv4_2 = nn.Conv3d(128, 128, kernel_size=1, padding=0, stride=1, dilation=1) # conv(128, 1, 1, 1) 45 | 46 | self.deconv_classes = nn.ConvTranspose3d(128, self.nbr_classes, kernel_size=4, padding=0, stride=4) 47 | 48 | return 49 | 50 | def forward(self, x): 51 | 52 | input = x['3D_OCCUPANCY'].permute(0, 1, 3, 2, 4) # Reshaping [bs, H, W, D] 53 | 54 | out = F.relu(self.conv1_1(input)) 55 | out_add_1 = self.reduction2_1(out) 56 | out = F.relu(self.conv2_1(out)) 57 | out = F.relu(out_add_1 + self.conv2_2(out)) 58 | 59 | out = self.pool2(out) 60 | 61 | out = F.relu(self.conv3_1(out)) 62 | out_add_2 = self.reduction3_1(out) 63 | out = F.relu(out_add_2 + self.conv3_2(out)) 64 | 65 | out_add_3 = self.conv3_3(out) 66 | out = self.conv3_4(F.relu(out_add_3)) 67 | out_res_1 = F.relu(out_add_3 + out) 68 | 69 | out_add_4 = self.conv3_5(out_res_1) 70 | out = self.conv3_6(F.relu(out_add_4)) 71 | out_res_2 = F.relu(out_add_4 + out) 72 | 73 | out_add_5 = self.conv3_7(out_res_2) 74 | out = self.conv3_8(F.relu(out_add_5)) 75 | out_res_3 = F.relu(out_add_5 + out) 76 | 77 | out = torch.cat((out_res_3, out_res_2, out_res_1), 1) 78 | 79 | out = F.relu(self.conv4_1(out)) 80 | out = F.relu(self.conv4_2(out)) 81 | 82 | out = self.deconv_classes(out) 83 | 84 | out = out.permute(0, 1, 3, 2, 4) # [bs, C, H, W, D] -> [bs, C, W, H, D] 85 | 86 | scores = {'pred_semantic_1_1': out} 87 | 88 | return scores 89 | 90 | def weights_initializer(self, m): 91 | if isinstance(m, (nn.Conv2d, nn.Conv3d)): # include Conv3d, all convolutions in this model are 3D 92 | nn.init.kaiming_uniform_(m.weight) 93 | nn.init.zeros_(m.bias) 94 | 95 | def weights_init(self): 96 | self.apply(self.weights_initializer) 97 | 98 | def get_parameters(self): 99 | return self.parameters() 100 | 101 | def compute_loss(self, scores, data): 102 | ''' 103 | :param scores: dict with the predicted tensor 'pred_semantic_1_1', shaped [BS, C, W, H, D] 104 | ''' 105 | 106 | target =
data['3D_LABEL']['1_1'] # [bs, C, W, H, D] 107 | device, dtype = target.device, target.dtype 108 | class_weights = torch.ones(self.nbr_classes).to(device=device, dtype=dtype) 109 | 110 | criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=255, reduction='none').to(device=device) 111 | 112 | # Reduction is none to be able to apply the 2N data balancing after. The mean will be calculated then... 113 | loss_1_1 = criterion(scores['pred_semantic_1_1'], data['3D_LABEL']['1_1'].long()) 114 | # F.cross_entropy(prediction, target.long(), weight=class_weights, ignore_index=255, reduction='none') 115 | 116 | # For SSCNet all classes have same weight and their weight is done by their 2N Data Balancing 117 | weight_db = self.get_data_balance_2N(data) 118 | # Calculate loss weighted by 2N data balancing 119 | # Remember target == 255 is ignored for the loss, this has to be considered for the mean..! 120 | # Also we are considering loss on only 2N free/occluded voxels, which is given by weight_db mask. 121 | # We do not consider 2N in occluded voxels only since is Lidar data, all scene needs to be completed. 122 | # Including outside FoV 123 | loss_1_1 = torch.sum(loss_1_1*weight_db) / torch.sum((weight_db != 1) & (target != 255)) 124 | 125 | loss = {'total': loss_1_1, 'semantic_1_1': loss_1_1} 126 | 127 | return loss 128 | 129 | def get_data_balance_2N(self, data): 130 | ''' 131 | Get a weight tensor for the loss computing. The weight tensor will ignore unknown voxels on target tensor 132 | (label==255). A random under sampling on free voxels with a relation 2:1 between free:occupied is obtained. 133 | The subsampling is done by considering only free occluded voxels. Explanation in SSCNet article 134 | (https://arxiv.org/abs/1611.08974) 135 | 136 | There is a discrepancy between data balancing explained on article and data balancing implemented on code 137 | https://github.com/shurans/sscnet/issues/33 138 | 139 | The subsampling will be done in all free voxels.. Not occluded only.. As Martin Gabarde did on TS3D.. There is 140 | a problem on what is explained for data balancing on SSCNet 141 | ''' 142 | 143 | batch_target = data['3D_LABEL']['1_1'] 144 | weight = torch.zeros_like(batch_target) 145 | for i, target in enumerate(batch_target): 146 | nbr_occupied = torch.sum((target > 0) & (target < 255)) 147 | nbr_free = torch.sum(target == 0) 148 | free_indices = torch.where(target == 0) # Indices of free voxels on target 149 | subsampling = torch.randint(nbr_free, (2 * nbr_occupied,)) # Random subsampling 2*nbr_occupied in range nbr_free 150 | mask = (free_indices[0][subsampling], free_indices[1][subsampling], free_indices[2][subsampling]) # New mask 151 | weight[i][mask] = 1 # Subsampled free voxels to be considered (2N) 152 | weight[i][(target > 0) & (target < 255)] = 1 # Occupied voxels 153 | 154 | # Returning weight that has N occupied voxels and 2N free voxels... 
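# (Strictly up to 2N: torch.randint draws indices with replacement, so duplicated draws can leave slightly fewer distinct free voxels selected.)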
155 | return weight 156 | 157 | def get_target(self, data): 158 | ''' 159 | Return the target to use for evaluation of the model 160 | ''' 161 | return {'1_1': data['3D_LABEL']['1_1']} 162 | 163 | def get_scales(self): 164 | ''' 165 | Return scales needed to train the model 166 | ''' 167 | scales = ['1_1'] 168 | return scales 169 | 170 | def get_validation_loss_keys(self): 171 | return ['total', 'semantic_1_1'] 172 | 173 | def get_train_loss_keys(self): 174 | return ['total', 'semantic_1_1'] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LMSCNet: Lightweight Multiscale 3D Semantic Completion 2 | Official repository. 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/lmscnet-lightweight-multiscale-3d-semantic/3d-semantic-scene-completion-on-semantickitti)](https://paperswithcode.com/sota/3d-semantic-scene-completion-on-semantickitti?p=lmscnet-lightweight-multiscale-3d-semantic) 5 | 6 | ## Paper 7 | ![alt text](teaser.png "LMSCNet") 8 | 9 | [LMSCNet: Lightweight Multiscale 3D Semantic Completion](https://arxiv.org/abs/2008.10559) \ 10 | [Luis Roldão](https://team.inria.fr/rits/membres/luis-roldao-jimenez/), [Raoul de Charette](https://team.inria.fr/rits/membres/raoul-de-charette/), [Anne Verroust-Blondet](https://team.inria.fr/rits/membres/anne-verroust/) 11 | Inria, Akka Research. 3DV 2020 (oral) \ 12 | [[Demo Video]](https://www.youtube.com/watch?v=J6dYoWx4Xqw&feature=youtu.be) 13 | 14 | 15 | If you find our work useful, please cite: 16 | ``` 17 | @inproceedings{roldao2020lmscnet, 18 | title={LMSCNet: Lightweight Multiscale 3D Semantic Completion}, 19 | author={Rold{\~a}o, Luis and de Charette, Raoul and Verroust-Blondet, Anne}, 20 | booktitle={International Conference on 3D Vision (3DV)}, 21 | year={2020} 22 | } 23 | ``` 24 | 25 | ## Preparation 26 | ### Prerequisites 27 | Tested with 28 | * PyTorch 1.3.1 29 | * CUDA 10.2 30 | * Python 3.7.5 31 | * Numpy 1.17.4 32 | 33 | ### Setup 34 | We advise creating a new conda environment for the installation. 35 | 36 | ``` 37 | $ conda create --name lmscnet_ssc python=3.7.5 numpy tqdm scipy scikit-learn pyyaml imageio tensorboard -y 38 | $ conda activate lmscnet_ssc 39 | $ conda install pytorch torchvision cudatoolkit=10.0 -c pytorch 40 | ``` 41 | 42 | Then clone this repository to the desired location: 43 | ``` 44 | $ git clone https://github.com/cv-rits/LMSCNet 45 | ``` 46 | 47 | ### Dataset 48 | 49 | Please download the Full Semantic Scene Completion dataset (v1.1) from the [SemanticKITTI website](http://www.semantic-kitti.org/dataset.html) and extract it. 50 | 51 | You first need to preprocess the data to generate the lower-scale labels used by LMSCNet. 52 | The preprocessing performs majority pooling over the high-resolution original-scale label 53 | grids (1:1) in order to obtain ground-truth data at lower resolutions (1:2, 1:4 and 1:8). 54 | It also generates validity masks at these resolutions so that, as for the original-scale data, 55 | the loss is computed on known voxels only. All information is stored in the same format and at the same location 56 | as the SemanticKITTI-provided data, with new file extensions (`file.label_1_X` and `file.invalid_1_X`). 57 | 58 | If you are using **v1.1** of the dataset, you can download the data directly from [here](https://www.rocq.inria.fr/rits_files/download.php?file=computer-vision/lmscnet/semanticKITTI_v1.1_dscale.zip).
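Whether downloaded or generated (see below), the downscaled files sit next to the original SemanticKITTI ones. An illustrative listing of one frame's `voxels` folder (file patterns as loaded by `LMSCNet/data/SemanticKITTI.py`):
```
dataset/sequences/00/voxels/
├── 000000.bin            # input occupancy grid
├── 000000.label          # 1:1 semantic labels
├── 000000.invalid        # 1:1 invalid-voxel mask
├── 000000.occluded       # occlusion mask
├── 000000.label_1_2      # downscaled labels (also .label_1_4, .label_1_8)
└── 000000.invalid_1_2    # downscaled invalid masks (also .invalid_1_4, .invalid_1_8)
```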
59 | Please extract the data into the SemanticKITTI root folder. 60 | 61 | Otherwise, you can generate the data yourself by running `LMSCNet/data/labels_downscale.py` as follows: 62 | ``` 63 | $ cd 64 | $ python LMSCNet/data/labels_downscale.py --dset_root 65 | ``` 66 | * `dset_root` should point to the root directory of the SemanticKITTI dataset (containing the `dataset` folder) 67 | 68 | ## Training 69 | 70 | All training settings can be edited with the yaml file generator in `SSC_configs/config_routine.py`. We provide training 71 | routine examples in the `SSC_configs/examples/` folder. Make sure to change the dataset path in those files to your extracted dataset location if you 72 | want to use them for training. You can also change the folder where performance and stats are stored; it is `SSC_out` by default. 73 | * `config_dict['DATASET']['ROOT_DIR']` should be changed to the root directory of the SemanticKITTI dataset (containing the `dataset` folder) 74 | * `config_dict['OUTPUT']['OUT_ROOT']` should be changed to the desired output folder. 75 | 76 | ### LMSCNet & LMSCNet-SS 77 | You can run the training with 78 | ``` 79 | $ cd 80 | $ python LMSCNet/train.py --cfg SSC_configs/examples/LMSCNet.yaml --dset_root 81 | ``` 82 | 83 | We also provide a single-scale version of our network, which can achieve slightly better performance at 84 | the cost of losing the multiscale capability: 85 | ``` 86 | $ cd 87 | $ python LMSCNet/train.py --cfg SSC_configs/examples/LMSCNet_SS.yaml --dset_root 88 | ``` 89 | 90 | ### Baselines 91 | Train the coded baselines with: 92 | ``` 93 | $ python LMSCNet/train.py --cfg SSC_configs/examples/SSCNet.yaml --dset_root 94 | $ python LMSCNet/train.py --cfg SSC_configs/examples/SSCNet_full.yaml --dset_root 95 | ``` 96 | 97 | In all the examples above you need to provide the path to your dataset folder; if not provided, the path 98 | set in the `.yaml` file is used by default. 99 | 100 | ## Validating & Testing 101 | 102 | Validation passes are performed during the training routine. An additional pass over the validation set with a saved model 103 | can be done with `LMSCNet/validate.py`. You need to provide the path to the saved model and the 104 | dataset root directory. 105 | 106 | ``` 107 | $ cd 108 | $ python LMSCNet/validate.py --weights --dset_root 109 | ``` 110 | 111 | Since SemanticKITTI contains a hidden test set, we provide a test routine that saves the predicted output in the same 112 | format as SemanticKITTI, which can be compressed and uploaded to the [SemanticKITTI Semantic Scene Completion Benchmark](http://www.semantic-kitti.org/tasks.html#semseg). 113 | We recommend passing the compressed data through the official checking script provided in the [SemanticKITTI Development Kit](http://www.semantic-kitti.org/resources.html#devkit) to avoid any issues. 114 | You can choose which checkpoint to use for testing; we used the ones that performed best on the validation set during training. 115 | You need to provide the path to the saved model, the 116 | dataset root directory and the output path where the predictions will be stored. 117 | 118 | ``` 119 | $ cd 120 | $ python LMSCNet/test.py --weights --dset_root --out_path 121 | ``` 122 | 123 | ## Ablation 124 | 125 | We test the robustness of our network against sparsity by retrieving the original 64-layer KITTI scans used in SemanticKITTI 126 | and subsampling them to emulate 8-, 16- and 32-layer LiDARs.
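A minimal sketch of this kind of layer subsampling, assuming scans in the KITTI `velodyne/*.bin` format (recovering a beam index by binning the elevation angle into 64 rings is an assumption here, not the exact script used for the ablation):
```
import numpy as np

def subsample_beams(scan_path, keep_layers=32, total_layers=64):
    '''Keep every (total_layers // keep_layers)-th LiDAR ring of a KITTI scan.'''
    pts = np.fromfile(scan_path, dtype=np.float32).reshape(-1, 4)  # x, y, z, remission
    # Approximate each point's ring index by binning its elevation angle.
    elev = np.arctan2(pts[:, 2], np.linalg.norm(pts[:, :2], axis=1))
    bins = np.linspace(elev.min(), elev.max(), total_layers + 1)
    ring = np.clip(np.digitize(elev, bins) - 1, 0, total_layers - 1)
    keep = (ring % (total_layers // keep_layers)) == 0
    return pts[keep]

# e.g. emulate a 16-layer sensor:
# sparse_scan = subsample_beams('dataset/sequences/08/velodyne/000000.bin', keep_layers=16)
```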
127 | 128 | _Full ablation code and results coming up soon..._ 129 | 130 | 133 | 134 | 135 | ## Model Zoo 136 | 137 | You can download the models with the scores below from this [Google drive folder](https://drive.google.com/drive/folders/12A46LE3BO6tQ8Y5OFbR4ImP7nzM8_3wb?usp=sharing). 138 | 139 | | Method | SC IoU | SSC mIoU | 140 | | ------------------------- | -------------------- | -------------------- | 141 | | SSCNet-full | 49.98* | 16.14* | 142 | | LMSCNet | 55.32* | 17.01* | 143 | | LMSCNet-SS | 56.72* | 17.62* | 144 | 145 | 146 | * Reported results correspond to the SemanticKITTI hidden test set v1.0. 147 | The SemanticKITTI benchmark has recently changed 148 | to v1.1 due to a grid shift issue ([link](https://github.com/PRBonn/semantic-kitti-api/issues/49)). This may also bring slight differences if the models are re-uploaded to the test benchmark. 149 | 150 | ## License 151 | LMSCNet is released under the [Apache 2.0 license](./LICENSE). 152 | -------------------------------------------------------------------------------- /LMSCNet/models/SSCNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | class SSCNet(nn.Module): 7 | ''' 8 | # Class coded from the Caffe model https://github.com/shurans/sscnet/blob/master/test/demo.txt 9 | ''' 10 | 11 | def __init__(self, class_num): 12 | ''' 13 | SSCNet architecture 14 | :param class_num: number of classes to be predicted (i.e. 12 for NYUv2) 15 | ''' 16 | super().__init__() 17 | 18 | self.nbr_classes = class_num 19 | 20 | self.conv1_1 = nn.Conv3d(1, 16, kernel_size=7, padding=3, stride=2, dilation=1) # conv(16, 7, 2, 1) 21 | 22 | self.reduction2_1 = nn.Conv3d(16, 32, kernel_size=1, padding=0, stride=1, dilation=1) # conv(32, 1, 1, 1) 23 | 24 | self.conv2_1 = nn.Conv3d(16, 32, kernel_size=3, padding=1, stride=1, dilation=1) # conv(32, 3, 1, 1) 25 | self.conv2_2 = nn.Conv3d(32, 32, kernel_size=3, padding=1, stride=1, dilation=1) # conv(32, 3, 1, 1) 26 | 27 | self.pool2 = nn.MaxPool3d(2) # pooling 28 | 29 | self.reduction3_1 = nn.Conv3d(64, 64, kernel_size=1, padding=0, stride=1, dilation=1) # conv(64, 1, 1, 1) 30 | 31 | self.conv3_1 = nn.Conv3d(32, 64, kernel_size=3, padding=1, stride=1, dilation=1) # conv(64, 3, 1, 1) 32 | self.conv3_2 = nn.Conv3d(64, 64, kernel_size=3, padding=1, stride=1, dilation=1) # conv(64, 3, 1, 1) 33 | 34 | self.conv3_3 = nn.Conv3d(64, 64, kernel_size=3, padding=1, stride=1, dilation=1) # conv(64, 3, 1, 1) 35 | self.conv3_4 = nn.Conv3d(64, 64, kernel_size=3, padding=1, stride=1, dilation=1) # conv(64, 3, 1, 1) 36 | 37 | self.conv3_5 = nn.Conv3d(64, 64, kernel_size=3, padding=2, stride=1, dilation=2) # dilated(64, 3, 1, 2) 38 | self.conv3_6 = nn.Conv3d(64, 64, kernel_size=3, padding=2, stride=1, dilation=2) # dilated(64, 3, 1, 2) 39 | 40 | self.conv3_7 = nn.Conv3d(64, 64, kernel_size=3, padding=2, stride=1, dilation=2) # dilated(64, 3, 1, 2) 41 | self.conv3_8 = nn.Conv3d(64, 64, kernel_size=3, padding=2, stride=1, dilation=2) # dilated(64, 3, 1, 2) 42 | 43 | self.conv4_1 = nn.Conv3d(192, 128, kernel_size=1, padding=0, stride=1, dilation=1) # conv(128, 1, 1, 1) 44 | self.conv4_2 = nn.Conv3d(128, 128, kernel_size=1, padding=0, stride=1, dilation=1) # conv(128, 1, 1, 1) 45 | self.conv_classes = nn.Conv3d(128, self.nbr_classes, kernel_size=1, padding=0, stride=1, dilation=1) # conv(nbr_classes, 1, 1, 1) 46 | 47 | self.four_upsample = nn.Upsample(scale_factor=4, mode='nearest') 48 | 49 | return 50 | 51 | def forward(self, x): 52 | 53 | input = x['3D_OCCUPANCY'].permute(0,
1, 3, 2, 4) # Reshaping [bs, H, W, D] 54 | 55 | out = F.relu(self.conv1_1(input)) 56 | out_add_1 = self.reduction2_1(out) 57 | out = F.relu((self.conv2_1(out))) 58 | out = F.relu(out_add_1 + self.conv2_2(out)) 59 | 60 | out = self.pool2(out) 61 | 62 | out = F.relu(self.conv3_1(out)) 63 | out_add_2 = self.reduction3_1(out) 64 | out = F.relu(out_add_2 + self.conv3_2(out)) 65 | 66 | out_add_3 = self.conv3_3(out) 67 | out = self.conv3_4(F.relu(out_add_3)) 68 | out_res_1 = F.relu(out_add_3 + out) 69 | 70 | out_add_4 = self.conv3_5(out_res_1) 71 | out = self.conv3_6(F.relu(out_add_4)) 72 | out_res_2 = F.relu(out_add_4 + out) 73 | 74 | out_add_5 = self.conv3_7(out_res_2) 75 | out = self.conv3_8(F.relu(out_add_5)) 76 | out_res_3 = F.relu(out_add_5 + out) 77 | 78 | out = torch.cat((out_res_3, out_res_2, out_res_1), 1) 79 | 80 | out = F.relu(self.conv4_1(out)) 81 | out = F.relu(self.conv4_2(out)) 82 | out = self.four_upsample(self.conv_classes(out)) 83 | 84 | out = out.permute(0, 1, 3, 2, 4) # [bs, C, H, W, D] -> [bs, C, W, H, D] 85 | 86 | scores = {'pred_semantic_1_1': out} 87 | 88 | return scores 89 | 90 | def weights_initializer(self, m): 91 | if isinstance(m, nn.Conv2d): 92 | nn.init.kaiming_uniform_(m.weight) 93 | nn.init.zeros_(m.bias) 94 | 95 | def weights_init(self): 96 | self.apply(self.weights_initializer) 97 | 98 | def get_parameters(self): 99 | return self.parameters() 100 | 101 | def compute_loss(self, scores, data): 102 | ''' 103 | :param: prediction: the predicted tensor, must be [BS, C, W, H, D] 104 | ''' 105 | 106 | target = data['3D_LABEL']['1_1'] # [bs, C, W, H, D] 107 | device, dtype = target.device, target.dtype 108 | class_weights = torch.ones(self.nbr_classes).to(device=device, dtype=dtype) 109 | 110 | criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=255, reduction='none').to(device=device) 111 | 112 | # Reduction is none to be able to apply the 2N data balancing after. The mean will be calculated then... 113 | loss_1_1 = criterion(scores['pred_semantic_1_1'], data['3D_LABEL']['1_1'].long()) 114 | # F.cross_entropy(prediction, target.long(), weight=class_weights, ignore_index=255, reduction='none') 115 | 116 | # For SSCNet all classes have same weight and their weight is done by their 2N Data Balancing 117 | weight_db = self.get_data_balance_2N(data) 118 | # Calculate loss weighted by 2N data balancing 119 | # Remember target == 255 is ignored for the loss, this has to be considered for the mean..! 120 | # Also we are considering loss on only 2N free/occluded voxels, which is given by weight_db mask. 121 | # We do not consider 2N in occluded voxels only since is Lidar data, all scene needs to be completed. 122 | # Including outside FoV 123 | loss_1_1 = torch.sum(loss_1_1*weight_db) / torch.sum((weight_db != 1) & (target != 255)) 124 | 125 | loss = {'total': loss_1_1, 'semantic_1_1': loss_1_1} 126 | 127 | return loss 128 | 129 | def get_data_balance_2N(self, data): 130 | ''' 131 | Get a weight tensor for the loss computing. The weight tensor will ignore unknown voxels on target tensor 132 | (label==255). A random under sampling on free voxels with a relation 2:1 between free:occupied is obtained. 133 | The subsampling is done by considering only free occluded voxels. 
Explanation in SSCNet article 134 | (https://arxiv.org/abs/1611.08974) 135 | 136 | There is a discrepancy between data balancing explained on article and data balancing implemented on code 137 | https://github.com/shurans/sscnet/issues/33 138 | 139 | The subsampling will be done in all free voxels.. Not occluded only.. As Martin Gabarde did on TS3D.. There is 140 | a problem on what is explained for data balancing on SSCNet 141 | ''' 142 | 143 | batch_target = data['3D_LABEL']['1_1'] 144 | weight = torch.zeros_like(batch_target) 145 | for i, target in enumerate(batch_target): 146 | nbr_occupied = torch.sum((target > 0) & (target < 255)) 147 | nbr_free = torch.sum(target == 0) 148 | free_indices = torch.where(target == 0) # Indices of free voxels on target 149 | subsampling = torch.randint(nbr_free, (2 * nbr_occupied,)) # Random subsampling 2*nbr_occupied in range nbr_free 150 | mask = (free_indices[0][subsampling], free_indices[1][subsampling], free_indices[2][subsampling]) # New mask 151 | weight[i][mask] = 1 # Subsampled free voxels to be considered (2N) 152 | weight[i][(target > 0) & (target < 255)] = 1 # Occupied voxels 153 | 154 | # Returning weight that has N occupied voxels and 2N free voxels... 155 | return weight 156 | 157 | def get_target(self, data): 158 | ''' 159 | Return the target to use for evaluation of the model 160 | ''' 161 | return {'1_1': data['3D_LABEL']['1_1']} 162 | 163 | def get_scales(self): 164 | ''' 165 | Return scales needed to train the model 166 | ''' 167 | scales = ['1_1'] 168 | return scales 169 | 170 | def get_validation_loss_keys(self): 171 | return ['total', 'semantic_1_1'] 172 | 173 | def get_train_loss_keys(self): 174 | return ['total', 'semantic_1_1'] -------------------------------------------------------------------------------- /LMSCNet/data/io_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import yaml 3 | import imageio 4 | 5 | 6 | def unpack(compressed): 7 | ''' given a bit encoded voxel grid, make a normal voxel grid out of it. ''' 8 | uncompressed = np.zeros(compressed.shape[0] * 8, dtype=np.uint8) 9 | uncompressed[::8] = compressed[:] >> 7 & 1 10 | uncompressed[1::8] = compressed[:] >> 6 & 1 11 | uncompressed[2::8] = compressed[:] >> 5 & 1 12 | uncompressed[3::8] = compressed[:] >> 4 & 1 13 | uncompressed[4::8] = compressed[:] >> 3 & 1 14 | uncompressed[5::8] = compressed[:] >> 2 & 1 15 | uncompressed[6::8] = compressed[:] >> 1 & 1 16 | uncompressed[7::8] = compressed[:] & 1 17 | 18 | return uncompressed 19 | 20 | 21 | def img_normalize(img, mean, std): 22 | img = img.astype(np.float32) / 255.0 23 | img = img - mean 24 | img = img / std 25 | 26 | return img 27 | 28 | 29 | def pack(array): 30 | """ convert a boolean array into a bitwise array. """ 31 | array = array.reshape((-1)) 32 | 33 | #compressing bit flags. 34 | # yapf: disable 35 | compressed = array[::8] << 7 | array[1::8] << 6 | array[2::8] << 5 | array[3::8] << 4 | array[4::8] << 3 | array[5::8] << 2 | array[6::8] << 1 | array[7::8] 36 | # yapf: enable 37 | 38 | return np.array(compressed, dtype=np.uint8) 39 | 40 | 41 | def get_grid_coords(dims, resolution): 42 | ''' 43 | :param dims: the dimensions of the grid [x, y, z] (i.e. 
[256, 256, 32]) 44 | :return coords_grid: is the center coords of voxels in the grid 45 | ''' 46 | 47 | # The sensor in centered in X (we go to dims/2 + 1 for the histogramdd) 48 | g_xx = np.arange(-dims[0]/2, dims[0]/2 + 1) 49 | # The sensor is in Y=0 (we go to dims + 1 for the histogramdd) 50 | g_yy = np.arange(0, dims[1] + 1) 51 | # The sensor is in Z=1.73. I observed that the ground was to voxel levels above the grid bottom, so Z pose is at 10 52 | # if bottom voxel is 0. If we want the sensor to be at (0, 0, 0), then the bottom in z is -10, top is 22 53 | # (we go to 22 + 1 for the histogramdd) 54 | # ATTENTION.. Is 11 for old grids.. 10 for new grids (v1.1) (https://github.com/PRBonn/semantic-kitti-api/issues/49) 55 | sensor_pose = 10 56 | g_zz = np.arange(0 - sensor_pose, dims[2] - sensor_pose + 1) 57 | 58 | # Obtaining the grid with coords... 59 | xx, yy, zz = np.meshgrid(g_xx[:-1], g_yy[:-1], g_zz[:-1]) 60 | coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T 61 | coords_grid = coords_grid.astype(np.float) 62 | 63 | coords_grid = (coords_grid * resolution) + resolution/2 64 | 65 | temp = np.copy(coords_grid) 66 | temp[:, 0] = coords_grid[:, 1] 67 | temp[:, 1] = coords_grid[:, 0] 68 | coords_grid = np.copy(temp) 69 | 70 | return coords_grid, g_xx, g_yy, g_zz 71 | 72 | 73 | def _get_remap_lut(config_path): 74 | ''' 75 | remap_lut to remap classes of semantic kitti for training... 76 | :return: 77 | ''' 78 | 79 | dataset_config = yaml.safe_load(open(config_path, 'r')) 80 | # make lookup table for mapping 81 | maxkey = max(dataset_config['learning_map'].keys()) 82 | 83 | # +100 hack making lut bigger just in case there are unknown labels 84 | remap_lut = np.zeros((maxkey + 100), dtype=np.int32) 85 | remap_lut[list(dataset_config['learning_map'].keys())] = list(dataset_config['learning_map'].values()) 86 | 87 | # in completion we have to distinguish empty and invalid voxels. 88 | # Important: For voxels 0 corresponds to "empty" and not "unlabeled". 89 | remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid' 90 | remap_lut[0] = 0 # only 'empty' stays 'empty'. 91 | 92 | return remap_lut 93 | 94 | 95 | def _read_SemKITTI(path, dtype, do_unpack): 96 | bin = np.fromfile(path, dtype=dtype) # Flattened array 97 | if do_unpack: 98 | bin = unpack(bin) 99 | return bin 100 | 101 | 102 | def _read_label_SemKITTI(path): 103 | label = _read_SemKITTI(path, dtype=np.uint16, do_unpack=False).astype(np.float32) 104 | return label 105 | 106 | 107 | def _read_invalid_SemKITTI(path): 108 | invalid = _read_SemKITTI(path, dtype=np.uint8, do_unpack=True) 109 | return invalid 110 | 111 | 112 | def _read_occluded_SemKITTI(path): 113 | occluded = _read_SemKITTI(path, dtype=np.uint8, do_unpack=True) 114 | return occluded 115 | 116 | 117 | def _read_occupancy_SemKITTI(path): 118 | occupancy = _read_SemKITTI(path, dtype=np.uint8, do_unpack=True).astype(np.float32) 119 | return occupancy 120 | 121 | 122 | def _read_rgb_SemKITTI(path): 123 | rgb = np.asarray(imageio.imread(path)) 124 | return rgb 125 | 126 | 127 | def _read_pointcloud_SemKITTI(path): 128 | 'Return pointcloud semantic kitti with remissions (x, y, z, intensity)' 129 | pointcloud = _read_SemKITTI(path, dtype=np.float32, do_unpack=False) 130 | pointcloud = pointcloud.reshape((-1, 4)) 131 | return pointcloud 132 | 133 | 134 | def _read_calib_SemKITTI(calib_path): 135 | """ 136 | :param calib_path: Path to a calibration text file. 137 | :return: dict with calibration matrices. 
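The dict holds 'P2' (3x4 projection matrix of the left color camera) and 'Tr' (the KITTI velodyne-to-camera transform, padded to a 4x4 matrix).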
138 | """ 139 | calib_all = {} 140 | with open(calib_path, 'r') as f: 141 | for line in f.readlines(): 142 | if line == '\n': 143 | break 144 | key, value = line.split(':', 1) 145 | calib_all[key] = np.array([float(x) for x in value.split()]) 146 | 147 | # reshape matrices 148 | calib_out = {} 149 | calib_out['P2'] = calib_all['P2'].reshape(3, 4) # 3x4 projection matrix for left camera 150 | calib_out['Tr'] = np.identity(4) # 4x4 matrix 151 | calib_out['Tr'][:3, :4] = calib_all['Tr'].reshape(3, 4) 152 | return calib_out 153 | 154 | 155 | def get_remap_lut(path): 156 | ''' 157 | remap_lut to remap classes of semantic kitti for training... 158 | :return: 159 | ''' 160 | 161 | dataset_config = yaml.safe_load(open(path, 'r')) 162 | 163 | # make lookup table for mapping 164 | maxkey = max(dataset_config['learning_map'].keys()) 165 | 166 | # +100 hack making lut bigger just in case there are unknown labels 167 | remap_lut = np.zeros((maxkey + 100), dtype=np.int32) 168 | remap_lut[list(dataset_config['learning_map'].keys())] = list(dataset_config['learning_map'].values()) 169 | 170 | # in completion we have to distinguish empty and invalid voxels. 171 | # Important: For voxels 0 corresponds to "empty" and not "unlabeled". 172 | remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid' 173 | remap_lut[0] = 0 # only 'empty' stays 'empty'. 174 | 175 | return remap_lut 176 | 177 | 178 | def data_augmentation_3Dflips(flip, data): 179 | # The .copy() is done to avoid negative strides of the numpy array caused by the way numpy manages the data 180 | # into memory. This gives errors when trying to pass the array to torch sensors.. Solution seen in: 181 | # https://discuss.pytorch.org/t/torch-from-numpy-not-support-negative-strides/3663 182 | # Dims -> {XZY} 183 | # Flipping around the X axis... 184 | if np.isclose(flip, 1): 185 | data = np.flip(data, axis=0).copy() 186 | 187 | # Flipping around the Y axis... 188 | if np.isclose(flip, 2): 189 | data = np.flip(data, 2).copy() 190 | 191 | # Flipping around the X and the Y axis... 192 | if np.isclose(flip, 3): 193 | data = np.flip(np.flip(data, axis=0), axis=2).copy() 194 | 195 | return data 196 | 197 | 198 | def get_cmap_semanticKITTI20(): 199 | colors = np.array([ 200 | # [0 , 0 , 0, 255], 201 | [100, 150, 245, 255], 202 | [100, 230, 245, 255], 203 | [30, 60, 150, 255], 204 | [80, 30, 180, 255], 205 | [100, 80, 250, 255], 206 | [255, 30, 30, 255], 207 | [255, 40, 200, 255], 208 | [150, 30, 90, 255], 209 | [255, 0, 255, 255], 210 | [255, 150, 255, 255], 211 | [75, 0, 75, 255], 212 | [175, 0, 75, 255], 213 | [255, 200, 0, 255], 214 | [255, 120, 50, 255], 215 | [0, 175, 0, 255], 216 | [135, 60, 0, 255], 217 | [150, 240, 80, 255], 218 | [255, 240, 150, 255], 219 | [255, 0, 0, 255]]).astype(np.uint8) 220 | 221 | return colors -------------------------------------------------------------------------------- /LMSCNet/models/LMSCNet_SS.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class SegmentationHead(nn.Module): 8 | ''' 9 | 3D Segmentation heads to retrieve semantic segmentation at each scale. 10 | Formed by Dim expansion, Conv3D, ASPP block, Conv3D. 
11 | ''' 12 | def __init__(self, inplanes, planes, nbr_classes, dilations_conv_list): 13 | super().__init__() 14 | 15 | # First convolution 16 | self.conv0 = nn.Conv3d(inplanes, planes, kernel_size=3, padding=1, stride=1) 17 | 18 | # ASPP Block 19 | self.conv_list = dilations_conv_list 20 | self.conv1 = nn.ModuleList( 21 | [nn.Conv3d(planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False) for dil in dilations_conv_list]) 22 | self.bn1 = nn.ModuleList([nn.BatchNorm3d(planes) for dil in dilations_conv_list]) 23 | self.conv2 = nn.ModuleList( 24 | [nn.Conv3d(planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False) for dil in dilations_conv_list]) 25 | self.bn2 = nn.ModuleList([nn.BatchNorm3d(planes) for dil in dilations_conv_list]) 26 | self.relu = nn.ReLU(inplace=True) 27 | 28 | # Convolution for output 29 | self.conv_classes = nn.Conv3d(planes, nbr_classes, kernel_size=3, padding=1, stride=1) 30 | 31 | def forward(self, x_in): 32 | 33 | # Dimension exapension 34 | x_in = x_in[:, None, :, :, :] 35 | 36 | # Convolution to go from inplanes to planes features... 37 | x_in = self.relu(self.conv0(x_in)) 38 | 39 | y = self.bn2[0](self.conv2[0](self.relu(self.bn1[0](self.conv1[0](x_in))))) 40 | for i in range(1, len(self.conv_list)): 41 | y += self.bn2[i](self.conv2[i](self.relu(self.bn1[i](self.conv1[i](x_in))))) 42 | x_in = self.relu(y + x_in) # modified 43 | 44 | x_in = self.conv_classes(x_in) 45 | 46 | return x_in 47 | 48 | 49 | class LMSCNet_SS(nn.Module): 50 | 51 | def __init__(self, class_num, input_dimensions, class_frequencies): 52 | ''' 53 | SSCNet architecture 54 | :param N: number of classes to be predicted (i.e. 12 for NYUv2) 55 | ''' 56 | 57 | super().__init__() 58 | self.nbr_classes = class_num 59 | self.input_dimensions = input_dimensions # Grid dimensions should be (W, H, D).. 
z or height being axis 1 60 | self.class_frequencies = class_frequencies 61 | f = self.input_dimensions[1] 62 | 63 | self.pool = nn.MaxPool2d(2) # [F=2; S=2; P=0; D=1] 64 | 65 | self.Encoder_block1 = nn.Sequential( 66 | nn.Conv2d(f, f, kernel_size=3, padding=1, stride=1), 67 | nn.ReLU(), 68 | nn.Conv2d(f, f, kernel_size=3, padding=1, stride=1), 69 | nn.ReLU() 70 | ) 71 | 72 | self.Encoder_block2 = nn.Sequential( 73 | nn.MaxPool2d(2), 74 | nn.Conv2d(f, int(f*1.5), kernel_size=3, padding=1, stride=1), 75 | nn.ReLU(), 76 | nn.Conv2d(int(f*1.5), int(f*1.5), kernel_size=3, padding=1, stride=1), 77 | nn.ReLU() 78 | ) 79 | 80 | self.Encoder_block3 = nn.Sequential( 81 | nn.MaxPool2d(2), 82 | nn.Conv2d(int(f*1.5), int(f*2), kernel_size=3, padding=1, stride=1), 83 | nn.ReLU(), 84 | nn.Conv2d(int(f*2), int(f*2), kernel_size=3, padding=1, stride=1), 85 | nn.ReLU() 86 | ) 87 | 88 | self.Encoder_block4 = nn.Sequential( 89 | nn.MaxPool2d(2), 90 | nn.Conv2d(int(f*2), int(f*2.5), kernel_size=3, padding=1, stride=1), 91 | nn.ReLU(), 92 | nn.Conv2d(int(f*2.5), int(f*2.5), kernel_size=3, padding=1, stride=1), 93 | nn.ReLU() 94 | ) 95 | 96 | # Treatment output 1:8 97 | self.conv_out_scale_1_8 = nn.Conv2d(int(f*2.5), int(f/8), kernel_size=3, padding=1, stride=1) 98 | self.deconv_1_8__1_2 = nn.ConvTranspose2d(int(f/8), int(f/8), kernel_size=4, padding=0, stride=4) 99 | self.deconv_1_8__1_1 = nn.ConvTranspose2d(int(f/8), int(f/8), kernel_size=8, padding=0, stride=8) 100 | 101 | # Treatment output 1:4 102 | self.deconv1_8 = nn.ConvTranspose2d(int(f/8), int(f/8), kernel_size=6, padding=2, stride=2) 103 | self.conv1_4 = nn.Conv2d(int(f*2) + int(f/8), int(f*2), kernel_size=3, padding=1, stride=1) 104 | self.conv_out_scale_1_4 = nn.Conv2d(int(f*2), int(f/4), kernel_size=3, padding=1, stride=1) 105 | self.deconv_1_4__1_1 = nn.ConvTranspose2d(int(f/4), int(f/4), kernel_size=4, padding=0, stride=4) 106 | 107 | # Treatment output 1:2 108 | self.deconv1_4 = nn.ConvTranspose2d(int(f/4), int(f/4), kernel_size=6, padding=2, stride=2) 109 | self.conv1_2 = nn.Conv2d(int(f*1.5) + int(f/4) + int(f/8), int(f*1.5), kernel_size=3, padding=1, stride=1) 110 | self.conv_out_scale_1_2 = nn.Conv2d(int(f*1.5), int(f/2), kernel_size=3, padding=1, stride=1) 111 | 112 | # Treatment output 1:1 113 | self.deconv1_2 = nn.ConvTranspose2d(int(f/2), int(f/2), kernel_size=6, padding=2, stride=2) 114 | self.conv1_1 = nn.Conv2d(int(f/8) + int(f/4) + int(f/2) + int(f), f, kernel_size=3, padding=1, stride=1) 115 | self.seg_head_1_1 = SegmentationHead(1, 8, self.nbr_classes, [1, 2, 3]) 116 | 117 | def forward(self, x): 118 | 119 | input = x['3D_OCCUPANCY'] # Input to LMSCNet model is 3D occupancy big scale (1:1) [bs, 1, W, H, D] 120 | input = torch.squeeze(input, dim=1).permute(0, 2, 1, 3) # Reshaping to the right way for 2D convs [bs, H, W, D] 121 | 122 | # Encoder block 123 | _skip_1_1 = self.Encoder_block1(input) 124 | _skip_1_2 = self.Encoder_block2(_skip_1_1) 125 | _skip_1_4 = self.Encoder_block3(_skip_1_2) 126 | _skip_1_8 = self.Encoder_block4(_skip_1_4) 127 | 128 | # Out 1_8 129 | out_scale_1_8__2D = self.conv_out_scale_1_8(_skip_1_8) 130 | 131 | # Out 1_4 132 | out = self.deconv1_8(out_scale_1_8__2D) 133 | out = torch.cat((out, _skip_1_4), 1) 134 | out = F.relu(self.conv1_4(out)) 135 | out_scale_1_4__2D = self.conv_out_scale_1_4(out) 136 | 137 | # Out 1_2 138 | out = self.deconv1_4(out_scale_1_4__2D) 139 | out = torch.cat((out, _skip_1_2, self.deconv_1_8__1_2(out_scale_1_8__2D)), 1) 140 | out = F.relu(self.conv1_2(out)) 141 | out_scale_1_2__2D = 
self.conv_out_scale_1_2(out) 142 | 143 | # Out 1_1 144 | out = self.deconv1_2(out_scale_1_2__2D) 145 | out = torch.cat((out, _skip_1_1, self.deconv_1_4__1_1(out_scale_1_4__2D), self.deconv_1_8__1_1(out_scale_1_8__2D)), 1) 146 | out_scale_1_1__2D = F.relu(self.conv1_1(out)) 147 | out_scale_1_1__3D = self.seg_head_1_1(out_scale_1_1__2D) 148 | 149 | # Take back to [W, H, D] axis order 150 | out_scale_1_1__3D = out_scale_1_1__3D.permute(0, 1, 3, 2, 4) # [bs, C, H, W, D] -> [bs, C, W, H, D] 151 | 152 | scores = {'pred_semantic_1_1': out_scale_1_1__3D} 153 | 154 | return scores 155 | 156 | def weights_initializer(self, m): 157 | if isinstance(m, nn.Conv2d): 158 | nn.init.kaiming_uniform_(m.weight) 159 | nn.init.zeros_(m.bias) 160 | 161 | def weights_init(self): 162 | self.apply(self.weights_initializer) 163 | 164 | def get_parameters(self): 165 | return self.parameters() 166 | 167 | def compute_loss(self, scores, data): 168 | ''' 169 | :param: prediction: the predicted tensor, must be [BS, C, H, W, D] 170 | ''' 171 | 172 | target = data['3D_LABEL']['1_1'] 173 | device, dtype = target.device, target.dtype 174 | class_weights = self.get_class_weights().to(device=target.device, dtype=target.dtype) 175 | 176 | criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=255, reduction='mean').to(device=device) 177 | 178 | loss_1_1 = criterion(scores['pred_semantic_1_1'], data['3D_LABEL']['1_1'].long()) 179 | 180 | loss = {'total': loss_1_1, 'semantic_1_1': loss_1_1} 181 | 182 | return loss 183 | 184 | def get_class_weights(self): 185 | ''' 186 | Cless weights being 1/log(fc) (https://arxiv.org/pdf/2008.10559.pdf) 187 | ''' 188 | epsilon_w = 0.001 # eps to avoid zero division 189 | weights = torch.from_numpy(1 / np.log(self.class_frequencies + epsilon_w)) 190 | 191 | return weights 192 | 193 | def get_target(self, data): 194 | ''' 195 | Return the target to use for evaluation of the model 196 | ''' 197 | return {'1_1': data['3D_LABEL']['1_1']} 198 | # return data['3D_LABEL']['1_1'] #.permute(0, 2, 1, 3) 199 | 200 | def get_scales(self): 201 | ''' 202 | Return scales needed to train the model 203 | ''' 204 | scales = ['1_1'] 205 | return scales 206 | 207 | def get_validation_loss_keys(self): 208 | return ['total', 'semantic_1_1'] 209 | 210 | def get_train_loss_keys(self): 211 | return ['total', 'semantic_1_1'] -------------------------------------------------------------------------------- /LMSCNet/models/LMSCNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class SegmentationHead(nn.Module): 8 | ''' 9 | 3D Segmentation heads to retrieve semantic segmentation at each scale. 10 | Formed by Dim expansion, Conv3D, ASPP block, Conv3D. 
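Takes a [bs, D1, D2, D3] tensor, expands it to a single-channel 3D volume and returns [bs, nbr_classes, D1, D2, D3] class logits.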
11 | ''' 12 | def __init__(self, inplanes, planes, nbr_classes, dilations_conv_list): 13 | super().__init__() 14 | 15 | # First convolution 16 | self.conv0 = nn.Conv3d(inplanes, planes, kernel_size=3, padding=1, stride=1) 17 | 18 | # ASPP Block 19 | self.conv_list = dilations_conv_list 20 | self.conv1 = nn.ModuleList( 21 | [nn.Conv3d(planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False) for dil in dilations_conv_list]) 22 | self.bn1 = nn.ModuleList([nn.BatchNorm3d(planes) for dil in dilations_conv_list]) 23 | self.conv2 = nn.ModuleList( 24 | [nn.Conv3d(planes, planes, kernel_size=3, padding=dil, dilation=dil, bias=False) for dil in dilations_conv_list]) 25 | self.bn2 = nn.ModuleList([nn.BatchNorm3d(planes) for dil in dilations_conv_list]) 26 | self.relu = nn.ReLU(inplace=True) 27 | 28 | # Convolution for output 29 | self.conv_classes = nn.Conv3d(planes, nbr_classes, kernel_size=3, padding=1, stride=1) 30 | 31 | def forward(self, x_in): 32 | 33 | # Dimension exapension 34 | x_in = x_in[:, None, :, :, :] 35 | 36 | # Convolution to go from inplanes to planes features... 37 | x_in = self.relu(self.conv0(x_in)) 38 | 39 | y = self.bn2[0](self.conv2[0](self.relu(self.bn1[0](self.conv1[0](x_in))))) 40 | for i in range(1, len(self.conv_list)): 41 | y += self.bn2[i](self.conv2[i](self.relu(self.bn1[i](self.conv1[i](x_in))))) 42 | x_in = self.relu(y + x_in) # modified 43 | 44 | x_in = self.conv_classes(x_in) 45 | 46 | return x_in 47 | 48 | 49 | class LMSCNet(nn.Module): 50 | 51 | def __init__(self, class_num, input_dimensions, class_frequencies): 52 | ''' 53 | SSCNet architecture 54 | :param N: number of classes to be predicted (i.e. 12 for NYUv2) 55 | ''' 56 | 57 | super().__init__() 58 | self.nbr_classes = class_num 59 | self.input_dimensions = input_dimensions # Grid dimensions should be (W, H, D).. 
z or height being axis 1 60 | self.class_frequencies = class_frequencies 61 | f = self.input_dimensions[1] 62 | 63 | self.pool = nn.MaxPool2d(2) # [F=2; S=2; P=0; D=1] 64 | 65 | self.Encoder_block1 = nn.Sequential( 66 | nn.Conv2d(f, f, kernel_size=3, padding=1, stride=1), 67 | nn.ReLU(), 68 | nn.Conv2d(f, f, kernel_size=3, padding=1, stride=1), 69 | nn.ReLU() 70 | ) 71 | 72 | self.Encoder_block2 = nn.Sequential( 73 | nn.MaxPool2d(2), 74 | nn.Conv2d(f, int(f*1.5), kernel_size=3, padding=1, stride=1), 75 | nn.ReLU(), 76 | nn.Conv2d(int(f*1.5), int(f*1.5), kernel_size=3, padding=1, stride=1), 77 | nn.ReLU() 78 | ) 79 | 80 | self.Encoder_block3 = nn.Sequential( 81 | nn.MaxPool2d(2), 82 | nn.Conv2d(int(f*1.5), int(f*2), kernel_size=3, padding=1, stride=1), 83 | nn.ReLU(), 84 | nn.Conv2d(int(f*2), int(f*2), kernel_size=3, padding=1, stride=1), 85 | nn.ReLU() 86 | ) 87 | 88 | self.Encoder_block4 = nn.Sequential( 89 | nn.MaxPool2d(2), 90 | nn.Conv2d(int(f*2), int(f*2.5), kernel_size=3, padding=1, stride=1), 91 | nn.ReLU(), 92 | nn.Conv2d(int(f*2.5), int(f*2.5), kernel_size=3, padding=1, stride=1), 93 | nn.ReLU() 94 | ) 95 | 96 | # Treatment output 1:8 97 | self.conv_out_scale_1_8 = nn.Conv2d(int(f*2.5), int(f/8), kernel_size=3, padding=1, stride=1) 98 | self.seg_head_1_8 = SegmentationHead(1, 8, self.nbr_classes, [1, 2, 3]) 99 | self.deconv_1_8__1_2 = nn.ConvTranspose2d(int(f/8), int(f/8), kernel_size=4, padding=0, stride=4) 100 | self.deconv_1_8__1_1 = nn.ConvTranspose2d(int(f/8), int(f/8), kernel_size=8, padding=0, stride=8) 101 | 102 | # Treatment output 1:4 103 | self.deconv1_8 = nn.ConvTranspose2d(int(f/8), int(f/8), kernel_size=6, padding=2, stride=2) 104 | self.conv1_4 = nn.Conv2d(int(f*2) + int(f/8), int(f*2), kernel_size=3, padding=1, stride=1) 105 | self.conv_out_scale_1_4 = nn.Conv2d(int(f*2), int(f/4), kernel_size=3, padding=1, stride=1) 106 | self.seg_head_1_4 = SegmentationHead(1, 8, self.nbr_classes, [1, 2, 3]) 107 | self.deconv_1_4__1_1 = nn.ConvTranspose2d(int(f/4), int(f/4), kernel_size=4, padding=0, stride=4) 108 | 109 | # Treatment output 1:2 110 | self.deconv1_4 = nn.ConvTranspose2d(int(f/4), int(f/4), kernel_size=6, padding=2, stride=2) 111 | self.conv1_2 = nn.Conv2d(int(f*1.5) + int(f/4) + int(f/8), int(f*1.5), kernel_size=3, padding=1, stride=1) 112 | self.conv_out_scale_1_2 = nn.Conv2d(int(f*1.5), int(f/2), kernel_size=3, padding=1, stride=1) 113 | self.seg_head_1_2 = SegmentationHead(1, 8, self.nbr_classes, [1, 2, 3]) 114 | 115 | # Treatment output 1:1 116 | self.deconv1_2 = nn.ConvTranspose2d(int(f/2), int(f/2), kernel_size=6, padding=2, stride=2) 117 | self.conv1_1 = nn.Conv2d(int(f/8) + int(f/4) + int(f/2) + int(f), f, kernel_size=3, padding=1, stride=1) 118 | self.seg_head_1_1 = SegmentationHead(1, 8, self.nbr_classes, [1, 2, 3]) 119 | 120 | def forward(self, x): 121 | 122 | input = x['3D_OCCUPANCY'] # Input to LMSCNet model is 3D occupancy big scale (1:1) [bs, 1, W, H, D] 123 | input = torch.squeeze(input, dim=1).permute(0, 2, 1, 3) # Reshaping to the right way for 2D convs [bs, H, W, D] 124 | 125 | # Encoder block 126 | _skip_1_1 = self.Encoder_block1(input) 127 | _skip_1_2 = self.Encoder_block2(_skip_1_1) 128 | _skip_1_4 = self.Encoder_block3(_skip_1_2) 129 | _skip_1_8 = self.Encoder_block4(_skip_1_4) 130 | 131 | # Out 1_8 132 | out_scale_1_8__2D = self.conv_out_scale_1_8(_skip_1_8) 133 | out_scale_1_8__3D = self.seg_head_1_8(out_scale_1_8__2D) 134 | 135 | # Out 1_4 136 | out = self.deconv1_8(out_scale_1_8__2D) 137 | out = torch.cat((out, _skip_1_4), 1) 138 | out = 
F.relu(self.conv1_4(out)) 139 | out_scale_1_4__2D = self.conv_out_scale_1_4(out) 140 | out_scale_1_4__3D = self.seg_head_1_4(out_scale_1_4__2D) 141 | 142 | # Out 1_2 143 | out = self.deconv1_4(out_scale_1_4__2D) 144 | out = torch.cat((out, _skip_1_2, self.deconv_1_8__1_2(out_scale_1_8__2D)), 1) 145 | out = F.relu(self.conv1_2(out)) 146 | out_scale_1_2__2D = self.conv_out_scale_1_2(out) 147 | out_scale_1_2__3D = self.seg_head_1_2(out_scale_1_2__2D) 148 | 149 | # Out 1_1 150 | out = self.deconv1_2(out_scale_1_2__2D) 151 | out = torch.cat((out, _skip_1_1, self.deconv_1_4__1_1(out_scale_1_4__2D), self.deconv_1_8__1_1(out_scale_1_8__2D)), 1) 152 | out_scale_1_1__2D = F.relu(self.conv1_1(out)) 153 | out_scale_1_1__3D = self.seg_head_1_1(out_scale_1_1__2D) 154 | 155 | # Take back to [W, H, D] axis order 156 | out_scale_1_8__3D = out_scale_1_8__3D.permute(0, 1, 3, 2, 4) # [bs, C, H, W, D] -> [bs, C, W, H, D] 157 | out_scale_1_4__3D = out_scale_1_4__3D.permute(0, 1, 3, 2, 4) # [bs, C, H, W, D] -> [bs, C, W, H, D] 158 | out_scale_1_2__3D = out_scale_1_2__3D.permute(0, 1, 3, 2, 4) # [bs, C, H, W, D] -> [bs, C, W, H, D] 159 | out_scale_1_1__3D = out_scale_1_1__3D.permute(0, 1, 3, 2, 4) # [bs, C, H, W, D] -> [bs, C, W, H, D] 160 | 161 | scores = {'pred_semantic_1_1': out_scale_1_1__3D, 'pred_semantic_1_2': out_scale_1_2__3D, 162 | 'pred_semantic_1_4': out_scale_1_4__3D, 'pred_semantic_1_8': out_scale_1_8__3D} 163 | 164 | return scores 165 | 166 | def weights_initializer(self, m): 167 | if isinstance(m, nn.Conv2d): 168 | nn.init.kaiming_uniform_(m.weight) 169 | nn.init.zeros_(m.bias) 170 | 171 | def weights_init(self): 172 | self.apply(self.weights_initializer) 173 | 174 | def get_parameters(self): 175 | return self.parameters() 176 | 177 | def compute_loss(self, scores, data): 178 | ''' 179 | :param: prediction: the predicted tensor, must be [BS, C, H, W, D] 180 | ''' 181 | 182 | target = data['3D_LABEL']['1_1'] 183 | device, dtype = target.device, target.dtype 184 | class_weights = self.get_class_weights().to(device=target.device, dtype=target.dtype) 185 | 186 | criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=255, reduction='mean').to(device=device) 187 | 188 | loss_1_1 = criterion(scores['pred_semantic_1_1'], data['3D_LABEL']['1_1'].long()) 189 | loss_1_2 = criterion(scores['pred_semantic_1_2'], data['3D_LABEL']['1_2'].long()) 190 | loss_1_4 = criterion(scores['pred_semantic_1_4'], data['3D_LABEL']['1_4'].long()) 191 | loss_1_8 = criterion(scores['pred_semantic_1_8'], data['3D_LABEL']['1_8'].long()) 192 | 193 | loss_total = (loss_1_1 + loss_1_2 + loss_1_4 + loss_1_8) / 4 194 | 195 | loss = {'total': loss_total, 'semantic_1_1': loss_1_1, 'semantic_1_2': loss_1_2, 'semantic_1_4': loss_1_4, 196 | 'semantic_1_8': loss_1_8} 197 | 198 | return loss 199 | 200 | def get_class_weights(self): 201 | ''' 202 | Cless weights being 1/log(fc) (https://arxiv.org/pdf/2008.10559.pdf) 203 | ''' 204 | epsilon_w = 0.001 # eps to avoid zero division 205 | weights = torch.from_numpy(1 / np.log(self.class_frequencies + epsilon_w)) 206 | 207 | return weights 208 | 209 | def get_target(self, data): 210 | ''' 211 | Return the target to use for evaluation of the model 212 | ''' 213 | return {'1_1': data['3D_LABEL']['1_1'], '1_2': data['3D_LABEL']['1_2'], 214 | '1_4': data['3D_LABEL']['1_4'], '1_8': data['3D_LABEL']['1_8']} 215 | # return data['3D_LABEL']['1_1'] #.permute(0, 2, 1, 3) 216 | 217 | def get_scales(self): 218 | ''' 219 | Return scales needed to train the model 220 | ''' 221 | scales = ['1_1', 
'1_2', '1_4', '1_8'] 222 | return scales 223 | 224 | def get_validation_loss_keys(self): 225 | return ['total', 'semantic_1_1','semantic_1_2', 'semantic_1_4', 'semantic_1_8'] 226 | 227 | def get_train_loss_keys(self): 228 | return ['total', 'semantic_1_1','semantic_1_2', 'semantic_1_4', 'semantic_1_8'] -------------------------------------------------------------------------------- /LMSCNet/data/SemanticKITTI.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from glob import glob 3 | import os 4 | import numpy as np 5 | import yaml 6 | import random 7 | import sys 8 | 9 | import LMSCNet.data.io_data as SemanticKittiIO 10 | 11 | 12 | class SemanticKITTI_dataloader(Dataset): 13 | 14 | def __init__(self, dataset, phase): 15 | ''' 16 | 17 | :param dataset: The dataset configuration (data augmentation, input encoding, etc) 18 | :param phase_tag: To differentiate between training, validation and test phase 19 | ''' 20 | 21 | yaml_path, _ = os.path.split(os.path.realpath(__file__)) 22 | self.dataset_config = yaml.safe_load(open(os.path.join(yaml_path, 'semantic-kitti.yaml'), 'r')) 23 | self.nbr_classes = self.dataset_config['nbr_classes'] 24 | self.grid_dimensions = self.dataset_config['grid_dims'] # [W, H, D] 25 | self.remap_lut = self.get_remap_lut() 26 | self.rgb_mean = np.array([0.34749558, 0.36745213, 0.36123651]) # images mean: [88.61137282 93.70029365 92.11530949] 27 | self.rgb_std = np.array([0.30599035, 0.3129534 , 0.31933814]) # images std: [78.02753826 79.80311686 81.43122464] 28 | self.root_dir = dataset['ROOT_DIR'] 29 | self.modalities = dataset['MODALITIES'] 30 | self.extensions = {'3D_OCCUPANCY': '.bin', '3D_LABEL': '.label', '3D_OCCLUDED': '.occluded', 31 | '3D_INVALID': '.invalid'} 32 | self.data_augmentation = {'FLIPS': dataset['AUGMENTATION']['FLIPS']} 33 | 34 | self.filepaths = {} 35 | self.phase = phase 36 | self.class_frequencies = np.array([5.41773033e+09, 1.57835390e+07, 1.25136000e+05, 1.18809000e+05, 37 | 6.46799000e+05, 8.21951000e+05, 2.62978000e+05, 2.83696000e+05, 38 | 2.04750000e+05, 6.16887030e+07, 4.50296100e+06, 4.48836500e+07, 39 | 2.26992300e+06, 5.68402180e+07, 1.57196520e+07, 1.58442623e+08, 40 | 2.06162300e+06, 3.69705220e+07, 1.15198800e+06, 3.34146000e+05]) 41 | 42 | self.split = {'train': [0, 1, 2, 3, 4, 5, 6, 7, 9, 10], 'val': [8], 43 | 'test': [11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]} 44 | 45 | for modality in self.modalities: 46 | if self.modalities[modality]: 47 | self.get_filepaths(modality) 48 | 49 | # if self.phase != 'test': 50 | # self.check_same_nbr_files() 51 | 52 | self.nbr_files = len(self.filepaths['3D_OCCUPANCY']) # TODO: Pass to something generic 53 | 54 | return 55 | 56 | def get_filepaths(self, modality): 57 | ''' 58 | Set modality filepaths with split according to phase (train, val, test) 59 | ''' 60 | 61 | sequences = list(sorted(glob(os.path.join(self.root_dir, 'dataset', 'sequences', '*')))[i] for i in self.split[self.phase]) 62 | 63 | if self.phase != 'test': 64 | 65 | if modality == '3D_LABEL': 66 | self.filepaths['3D_LABEL'] = {'1_1': [], '1_2': [], '1_4': [], '1_8': []} 67 | self.filepaths['3D_INVALID'] = {'1_1': [], '1_2': [], '1_4': [], '1_8': []} 68 | for sequence in sequences: 69 | assert len(os.listdir(sequence)) > 0, 'Error, No files in sequence: {}'.format(sequence) 70 | # Scale 1:1 71 | self.filepaths['3D_LABEL']['1_1'] += sorted(glob(os.path.join(sequence, 'voxels', '*.label'))) 72 | self.filepaths['3D_INVALID']['1_1'] += 
sorted(glob(os.path.join(sequence, 'voxels', '*.invalid'))) 73 | # Scale 1:2 74 | self.filepaths['3D_LABEL']['1_2'] += sorted(glob(os.path.join(sequence, 'voxels', '*.label_1_2'))) 75 | self.filepaths['3D_INVALID']['1_2'] += sorted(glob(os.path.join(sequence, 'voxels', '*.invalid_1_2'))) 76 | # Scale 1:4 77 | self.filepaths['3D_LABEL']['1_4'] += sorted(glob(os.path.join(sequence, 'voxels', '*.label_1_4'))) 78 | self.filepaths['3D_INVALID']['1_4'] += sorted(glob(os.path.join(sequence, 'voxels', '*.invalid_1_4'))) 79 | # Scale 1:8 80 | self.filepaths['3D_LABEL']['1_8'] += sorted(glob(os.path.join(sequence, 'voxels', '*.label_1_8'))) 81 | self.filepaths['3D_INVALID']['1_8'] += sorted(glob(os.path.join(sequence, 'voxels', '*.invalid_1_8'))) 82 | 83 | if modality == '3D_OCCLUDED': 84 | self.filepaths['3D_OCCLUDED'] = [] 85 | for sequence in sequences: 86 | assert len(os.listdir(sequence)) > 0, 'Error, No files in sequence: {}'.format(sequence) 87 | self.filepaths['3D_OCCLUDED'] += sorted(glob(os.path.join(sequence, 'voxels', '*.occluded'))) 88 | 89 | if modality == '3D_OCCUPANCY': 90 | self.filepaths['3D_OCCUPANCY'] = [] 91 | for sequence in sequences: 92 | assert len(os.listdir(sequence)) > 0, 'Error, No files in sequence: {}'.format(sequence) 93 | self.filepaths['3D_OCCUPANCY'] += sorted(glob(os.path.join(sequence, 'voxels', '*.bin'))) 94 | 95 | # if modality == '2D_RGB': 96 | # self.filepaths['2D_RGB'] = [] 97 | # for sequence in sequences: 98 | # assert len(os.listdir(sequence)) > 0, 'Error, No files in sequence: {}'.format(sequence) 99 | # self.filepaths['2D_RGB'] += sorted(glob(os.path.join(sequence, 'image_2', '*.png')))[::5] 100 | 101 | return 102 | 103 | def check_same_nbr_files(self): 104 | ''' 105 | Check that all loaded modalities contain the same number of files 106 | ''' 107 | 108 | # TODO: Modify for nested dictionaries...
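# (The pairwise count check below assumes flat lists; nested dicts such as the 3D_LABEL scales would be compared by their number of keys, hence the TODO above.)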
109 | for i in range(len(self.filepaths.keys()) - 1): 110 | length1 = len(self.filepaths[list(self.filepaths.keys())[i]]) 111 | length2 = len(self.filepaths[list(self.filepaths.keys())[i+1]]) 112 | assert length1 == length2, 'Error: {} and {} not same number of files'.format(list(self.filepaths.keys())[i], 113 | list(self.filepaths.keys())[i+1]) 114 | return 115 | 116 | def __getitem__(self, idx): 117 | ''' 118 | 119 | ''' 120 | 121 | data = {} 122 | 123 | do_flip = 0 124 | if self.data_augmentation['FLIPS'] and self.phase == 'train': 125 | do_flip = random.randint(0, 3) 126 | 127 | for modality in self.modalities: 128 | if (self.modalities[modality]) and (modality in self.filepaths): 129 | data[modality] = self.get_data_modality(modality, idx, do_flip) 130 | 131 | return data, idx 132 | 133 | def get_data_modality(self, modality, idx, flip): 134 | 135 | if modality == '3D_OCCUPANCY': 136 | OCCUPANCY = SemanticKittiIO._read_occupancy_SemKITTI(self.filepaths[modality][idx]) 137 | OCCUPANCY = np.moveaxis(OCCUPANCY.reshape([self.grid_dimensions[0], 138 | self.grid_dimensions[2], 139 | self.grid_dimensions[1]]), [0, 1, 2], [0, 2, 1]) 140 | OCCUPANCY = SemanticKittiIO.data_augmentation_3Dflips(flip, OCCUPANCY) 141 | return OCCUPANCY[None, :, :, :] 142 | 143 | elif modality == '3D_LABEL': 144 | LABEL_1_1 = SemanticKittiIO.data_augmentation_3Dflips(flip, self.get_label_at_scale('1_1', idx)) 145 | LABEL_1_2 = SemanticKittiIO.data_augmentation_3Dflips(flip, self.get_label_at_scale('1_2', idx)) 146 | LABEL_1_4 = SemanticKittiIO.data_augmentation_3Dflips(flip, self.get_label_at_scale('1_4', idx)) 147 | LABEL_1_8 = SemanticKittiIO.data_augmentation_3Dflips(flip, self.get_label_at_scale('1_8', idx)) 148 | return {'1_1': LABEL_1_1, '1_2': LABEL_1_2, '1_4': LABEL_1_4, '1_8': LABEL_1_8} 149 | 150 | elif modality == '3D_OCCLUDED': 151 | OCCLUDED = SemanticKittiIO._read_occluded_SemKITTI(self.filepaths[modality][idx]) 152 | OCCLUDED = np.moveaxis(OCCLUDED.reshape([self.grid_dimensions[0], 153 | self.grid_dimensions[2], 154 | self.grid_dimensions[1]]), [0, 1, 2], [0, 2, 1]) 155 | OCCLUDED = SemanticKittiIO.data_augmentation_3Dflips(flip, OCCLUDED) 156 | return OCCLUDED 157 | 158 | # elif modality == '2D_RGB': 159 | # RGB = SemanticKittiIO._read_rgb_SemKITTI(self.filepaths[modality][idx]) 160 | # # TODO Standarize, Normalize 161 | # RGB = SemanticKittiIO.img_normalize(RGB, self.rgb_mean, self.rgb_std) 162 | # RGB = np.moveaxis(RGB, (0, 1, 2), (1, 2, 0)).astype(dtype='float32') # reshaping [3xHxW] 163 | # # There is a problem on the RGB images.. They are not all the same size and I used those to calculate the mapping 164 | # # for the sketch... I need images all te same size.. 165 | # return RGB 166 | 167 | else: 168 | assert False, 'Specified modality not found' 169 | 170 | def get_label_at_scale(self, scale, idx): 171 | 172 | scale_divide = int(scale[-1]) 173 | INVALID = SemanticKittiIO._read_invalid_SemKITTI(self.filepaths['3D_INVALID'][scale][idx]) 174 | LABEL = SemanticKittiIO._read_label_SemKITTI(self.filepaths['3D_LABEL'][scale][idx]) 175 | if scale == '1_1': 176 | LABEL = self.remap_lut[LABEL.astype(np.uint16)].astype(np.float32) # Remap 20 classes semanticKITTI SSC 177 | LABEL[np.isclose(INVALID, 1)] = 255 # Setting to unknown all voxels marked on invalid mask... 
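# The grid is stored on disk in [W, D, H] order: reshape to that shape, then swap axes 1 and 2 to recover [W, H, D].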
178 | LABEL = np.moveaxis(LABEL.reshape([int(self.grid_dimensions[0] / scale_divide), 179 | int(self.grid_dimensions[2] / scale_divide), 180 | int(self.grid_dimensions[1] / scale_divide)]), [0, 1, 2], [0, 2, 1]) 181 | 182 | return LABEL 183 | 184 | def read_semantics_config(self, data_path): 185 | 186 | # get number of interest classes, and the label mappings 187 | DATA = yaml.safe_load(open(data_path, 'r')) 188 | self.class_strings = DATA["labels"] 189 | self.class_remap = DATA["learning_map"] 190 | self.class_inv_remap = DATA["learning_map_inv"] 191 | self.class_ignore = DATA["learning_ignore"] 192 | self.n_classes = len(self.class_inv_remap) 193 | 194 | return 195 | 196 | def get_inv_remap_lut(self): 197 | ''' 198 | remap_lut to remap classes of semantic kitti for training... 199 | :return: 200 | ''' 201 | 202 | # make lookup table for mapping 203 | maxkey = max(self.dataset_config['learning_map_inv'].keys()) 204 | 205 | # +100 hack making lut bigger just in case there are unknown labels 206 | remap_lut = np.zeros((maxkey + 1), dtype=np.int32) 207 | remap_lut[list(self.dataset_config['learning_map_inv'].keys())] = list(self.dataset_config['learning_map_inv'].values()) 208 | 209 | return remap_lut 210 | 211 | def get_remap_lut(self): 212 | ''' 213 | remap_lut to remap classes of semantic kitti for training... 214 | :return: 215 | ''' 216 | 217 | # make lookup table for mapping 218 | maxkey = max(self.dataset_config['learning_map'].keys()) 219 | 220 | # +100 hack making lut bigger just in case there are unknown labels 221 | remap_lut = np.zeros((maxkey + 100), dtype=np.int32) 222 | remap_lut[list(self.dataset_config['learning_map'].keys())] = list(self.dataset_config['learning_map'].values()) 223 | 224 | # in completion we have to distinguish empty and invalid voxels. 225 | # Important: For voxels 0 corresponds to "empty" and not "unlabeled". 226 | remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid' 227 | remap_lut[0] = 0 # only 'empty' stays 'empty'. 228 | 229 | return remap_lut 230 | 231 | def __len__(self): 232 | """ 233 | Returns the length of the dataset 234 | """ 235 | # Return the number of elements in the dataset 236 | return self.nbr_files 237 | 238 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | LMSCNet: Lightweight Multiscale 3D Semantic Completion (L. Roldão et al., 3DV 2020) 2 | 3 | Copyright 2020 Inria and AKKA Technologies 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | https://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | 17 | 18 | 19 | Apache License 20 | Version 2.0, January 2004 21 | https://www.apache.org/licenses/ 22 | 23 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 24 | 25 | 1. Definitions. 26 | 27 | "License" shall mean the terms and conditions for use, reproduction, 28 | and distribution as defined by Sections 1 through 9 of this document. 
29 | 30 | "Licensor" shall mean the copyright owner or entity authorized by 31 | the copyright owner that is granting the License. 32 | 33 | "Legal Entity" shall mean the union of the acting entity and all 34 | other entities that control, are controlled by, or are under common 35 | control with that entity. For the purposes of this definition, 36 | "control" means (i) the power, direct or indirect, to cause the 37 | direction or management of such entity, whether by contract or 38 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 39 | outstanding shares, or (iii) beneficial ownership of such entity. 40 | 41 | "You" (or "Your") shall mean an individual or Legal Entity 42 | exercising permissions granted by this License. 43 | 44 | "Source" form shall mean the preferred form for making modifications, 45 | including but not limited to software source code, documentation 46 | source, and configuration files. 47 | 48 | "Object" form shall mean any form resulting from mechanical 49 | transformation or translation of a Source form, including but 50 | not limited to compiled object code, generated documentation, 51 | and conversions to other media types. 52 | 53 | "Work" shall mean the work of authorship, whether in Source or 54 | Object form, made available under the License, as indicated by a 55 | copyright notice that is included in or attached to the work 56 | (an example is provided in the Appendix below). 57 | 58 | "Derivative Works" shall mean any work, whether in Source or Object 59 | form, that is based on (or derived from) the Work and for which the 60 | editorial revisions, annotations, elaborations, or other modifications 61 | represent, as a whole, an original work of authorship. For the purposes 62 | of this License, Derivative Works shall not include works that remain 63 | separable from, or merely link (or bind by name) to the interfaces of, 64 | the Work and Derivative Works thereof. 65 | 66 | "Contribution" shall mean any work of authorship, including 67 | the original version of the Work and any modifications or additions 68 | to that Work or Derivative Works thereof, that is intentionally 69 | submitted to Licensor for inclusion in the Work by the copyright owner 70 | or by an individual or Legal Entity authorized to submit on behalf of 71 | the copyright owner. For the purposes of this definition, "submitted" 72 | means any form of electronic, verbal, or written communication sent 73 | to the Licensor or its representatives, including but not limited to 74 | communication on electronic mailing lists, source code control systems, 75 | and issue tracking systems that are managed by, or on behalf of, the 76 | Licensor for the purpose of discussing and improving the Work, but 77 | excluding communication that is conspicuously marked or otherwise 78 | designated in writing by the copyright owner as "Not a Contribution." 79 | 80 | "Contributor" shall mean Licensor and any individual or Legal Entity 81 | on behalf of whom a Contribution has been received by Licensor and 82 | subsequently incorporated within the Work. 83 | 84 | 2. Grant of Copyright License. Subject to the terms and conditions of 85 | this License, each Contributor hereby grants to You a perpetual, 86 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 87 | copyright license to reproduce, prepare Derivative Works of, 88 | publicly display, publicly perform, sublicense, and distribute the 89 | Work and such Derivative Works in Source or Object form. 90 | 91 | 3. Grant of Patent License. 
Subject to the terms and conditions of 92 | this License, each Contributor hereby grants to You a perpetual, 93 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 94 | (except as stated in this section) patent license to make, have made, 95 | use, offer to sell, sell, import, and otherwise transfer the Work, 96 | where such license applies only to those patent claims licensable 97 | by such Contributor that are necessarily infringed by their 98 | Contribution(s) alone or by combination of their Contribution(s) 99 | with the Work to which such Contribution(s) was submitted. If You 100 | institute patent litigation against any entity (including a 101 | cross-claim or counterclaim in a lawsuit) alleging that the Work 102 | or a Contribution incorporated within the Work constitutes direct 103 | or contributory patent infringement, then any patent licenses 104 | granted to You under this License for that Work shall terminate 105 | as of the date such litigation is filed. 106 | 107 | 4. Redistribution. You may reproduce and distribute copies of the 108 | Work or Derivative Works thereof in any medium, with or without 109 | modifications, and in Source or Object form, provided that You 110 | meet the following conditions: 111 | 112 | (a) You must give any other recipients of the Work or 113 | Derivative Works a copy of this License; and 114 | 115 | (b) You must cause any modified files to carry prominent notices 116 | stating that You changed the files; and 117 | 118 | (c) You must retain, in the Source form of any Derivative Works 119 | that You distribute, all copyright, patent, trademark, and 120 | attribution notices from the Source form of the Work, 121 | excluding those notices that do not pertain to any part of 122 | the Derivative Works; and 123 | 124 | (d) If the Work includes a "NOTICE" text file as part of its 125 | distribution, then any Derivative Works that You distribute must 126 | include a readable copy of the attribution notices contained 127 | within such NOTICE file, excluding those notices that do not 128 | pertain to any part of the Derivative Works, in at least one 129 | of the following places: within a NOTICE text file distributed 130 | as part of the Derivative Works; within the Source form or 131 | documentation, if provided along with the Derivative Works; or, 132 | within a display generated by the Derivative Works, if and 133 | wherever such third-party notices normally appear. The contents 134 | of the NOTICE file are for informational purposes only and 135 | do not modify the License. You may add Your own attribution 136 | notices within Derivative Works that You distribute, alongside 137 | or as an addendum to the NOTICE text from the Work, provided 138 | that such additional attribution notices cannot be construed 139 | as modifying the License. 140 | 141 | You may add Your own copyright statement to Your modifications and 142 | may provide additional or different license terms and conditions 143 | for use, reproduction, or distribution of Your modifications, or 144 | for any such Derivative Works as a whole, provided Your use, 145 | reproduction, and distribution of the Work otherwise complies with 146 | the conditions stated in this License. 147 | 148 | 5. Submission of Contributions. Unless You explicitly state otherwise, 149 | any Contribution intentionally submitted for inclusion in the Work 150 | by You to the Licensor shall be under the terms and conditions of 151 | this License, without any additional terms or conditions. 
152 | Notwithstanding the above, nothing herein shall supersede or modify 153 | the terms of any separate license agreement you may have executed 154 | with Licensor regarding such Contributions. 155 | 156 | 6. Trademarks. This License does not grant permission to use the trade 157 | names, trademarks, service marks, or product names of the Licensor, 158 | except as required for reasonable and customary use in describing the 159 | origin of the Work and reproducing the content of the NOTICE file. 160 | 161 | 7. Disclaimer of Warranty. Unless required by applicable law or 162 | agreed to in writing, Licensor provides the Work (and each 163 | Contributor provides its Contributions) on an "AS IS" BASIS, 164 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 165 | implied, including, without limitation, any warranties or conditions 166 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 167 | PARTICULAR PURPOSE. You are solely responsible for determining the 168 | appropriateness of using or redistributing the Work and assume any 169 | risks associated with Your exercise of permissions under this License. 170 | 171 | 8. Limitation of Liability. In no event and under no legal theory, 172 | whether in tort (including negligence), contract, or otherwise, 173 | unless required by applicable law (such as deliberate and grossly 174 | negligent acts) or agreed to in writing, shall any Contributor be 175 | liable to You for damages, including any direct, indirect, special, 176 | incidental, or consequential damages of any character arising as a 177 | result of this License or out of the use or inability to use the 178 | Work (including but not limited to damages for loss of goodwill, 179 | work stoppage, computer failure or malfunction, or any and all 180 | other commercial damages or losses), even if such Contributor 181 | has been advised of the possibility of such damages. 182 | 183 | 9. Accepting Warranty or Additional Liability. While redistributing 184 | the Work or Derivative Works thereof, You may choose to offer, 185 | and charge a fee for, acceptance of support, warranty, indemnity, 186 | or other liability obligations and/or rights consistent with this 187 | License. However, in accepting such obligations, You may act only 188 | on Your own behalf and on Your sole responsibility, not on behalf 189 | of any other Contributor, and only if You agree to indemnify, 190 | defend, and hold each Contributor harmless for any liability 191 | incurred by, or claims asserted against, such Contributor by reason 192 | of your accepting any such warranty or additional liability. 
193 | 
194 |    END OF TERMS AND CONDITIONS
195 | 
--------------------------------------------------------------------------------
/LMSCNet/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.tensorboard import SummaryWriter
6 | import sys
7 | 
8 | # Append root directory to system path for imports
9 | repo_path, _ = os.path.split(os.path.realpath(__file__))
10 | repo_path, _ = os.path.split(repo_path)
11 | sys.path.append(repo_path)
12 | 
13 | from LMSCNet.common.seed import seed_all
14 | from LMSCNet.common.config import CFG
15 | from LMSCNet.common.dataset import get_dataset
16 | from LMSCNet.common.model import get_model
17 | from LMSCNet.common.logger import get_logger
18 | from LMSCNet.common.optimizer import build_optimizer, build_scheduler
19 | from LMSCNet.common.io_tools import dict_to
20 | from LMSCNet.common.metrics import Metrics
21 | import LMSCNet.common.checkpoint as checkpoint
22 | 
23 | 
24 | def parse_args():
25 |   parser = argparse.ArgumentParser(description='LMSCNet training')
26 |   parser.add_argument(
27 |     '--cfg',
28 |     dest='config_file',
29 |     default='',
30 |     metavar='FILE',
31 |     help='path to config file',
32 |     type=str,
33 |   )
34 |   parser.add_argument(
35 |     '--dset_root',
36 |     dest='dataset_root',
37 |     default=None,
38 |     metavar='DATASET',
39 |     help='path to dataset root folder',
40 |     type=str,
41 |   )
42 |   args = parser.parse_args()
43 |   return args
44 | 
45 | 
46 | def train(model, optimizer, scheduler, dataset, _cfg, start_epoch, logger, tbwriter):
47 |   """
48 |   Train a model using the PyTorch Module API.
49 |   Inputs:
50 |   - model: A PyTorch Module giving the model to train.
51 |   - optimizer: The Optimizer object used to train the model
52 |   - scheduler: Scheduler for learning rate decay, if used
53 |   - dataset: Dictionary holding the 'train' and 'val' dataloaders
54 |   - _cfg: The configuration dictionary read from the config file
55 |   - start_epoch: The epoch at which to start training (e.g. when resuming from a checkpoint)
56 |   - logger: The logger used to save info
57 |   - tbwriter: The TensorBoard writer used to save plots
58 |   Returns: The best metric record (best validation mIoU/IoU and the epoch at which it occurred).
59 |   """
60 | 
61 |   device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
62 |   dtype = torch.float32  # Tensor type to be used
63 | 
64 |   # Move the model and the optimizer state to the selected device
65 |   model = model.to(device=device)
66 |   for state in optimizer.state.values():
67 |     for k, v in state.items():
68 |       if isinstance(v, torch.Tensor):
69 |         state[k] = v.to(device)
70 | 
71 |   dset = dataset['train']
72 | 
73 |   nbr_epochs = _cfg._dict['TRAIN']['EPOCHS']
74 |   nbr_iterations = len(dset)  # number of iterations depends on batch size
75 | 
76 |   # Define the metrics tracker and initialize it..
77 |   metrics = Metrics(dset.dataset.nbr_classes, nbr_iterations, model.get_scales())
78 |   metrics.reset_evaluator()
79 |   metrics.losses_track.set_validation_losses(model.get_validation_loss_keys())
80 |   metrics.losses_track.set_train_losses(model.get_train_loss_keys())
81 | 
82 |   for epoch in range(start_epoch, nbr_epochs+1):
83 | 
84 |     logger.info('=> =============== Epoch [{}/{}] ==============='.format(epoch, nbr_epochs))
85 |     logger.info('=> Reminder - Output of routine written to {}'.format(_cfg._dict['OUTPUT']['OUTPUT_PATH']))
86 | 
87 |     # Print learning rate
88 |     # for param_group in optimizer.param_groups:
89 |     logger.info('=> Learning rate: {}'.format(scheduler.get_lr()[0]))
90 | 
91 |     model.train()  # put model to training mode
92 | 
93 |     # for t, (data, indices) in enumerate(dataset['train']):
94 |     for t, (data, indices) in enumerate(dset):
95 | 
96 |       data = dict_to(data, device, dtype)
97 | 
98 |       scores = model(data)
99 | 
100 |       loss = model.compute_loss(scores, data)
101 | 
102 |       # Zero out the gradients.
103 |       optimizer.zero_grad()
104 |       # Backward pass: gradient of the loss w.r.t. each model parameter.
105 |       loss['total'].backward()
106 |       # Update the model parameters using the computed gradients.
107 |       optimizer.step()
108 | 
109 |       if _cfg._dict['SCHEDULER']['FREQUENCY'] == 'iteration':
110 |         scheduler.step()
111 | 
112 |       for l_key in loss:
113 |         tbwriter.add_scalar('train_loss_batch/{}'.format(l_key), loss[l_key].item(), len(dset) * (epoch-1) + t)
114 |       # Update batch losses, used later to compute the mean epoch loss
115 |       metrics.losses_track.update_train_losses(loss)
116 | 
117 |       if (t + 1) % _cfg._dict['TRAIN']['SUMMARY_PERIOD'] == 0:
118 |         loss_print = '=> Epoch [{}/{}], Iteration [{}/{}], Learning Rate: {}, Train Losses: '\
119 |                      .format(epoch, nbr_epochs, t+1, len(dset), scheduler.get_lr()[0])
120 |         for key in loss.keys(): loss_print += '{} = {:.6f}, '.format(key, loss[key])
121 |         logger.info(loss_print[:-3])
122 | 
123 |       metrics.add_batch(prediction=scores, target=model.get_target(data))
124 | 
125 |     for l_key in metrics.losses_track.train_losses:
126 |       tbwriter.add_scalar('train_loss_epoch/{}'.format(l_key),
127 |                           metrics.losses_track.train_losses[l_key].item()/metrics.losses_track.train_iteration_counts,
128 |                           epoch - 1)
129 |     tbwriter.add_scalar('lr/lr', scheduler.get_lr()[0], epoch - 1)
130 | 
131 |     epoch_loss = metrics.losses_track.train_losses['total']/metrics.losses_track.train_iteration_counts
132 | 
133 |     for scale in metrics.evaluator.keys():
134 |       tbwriter.add_scalar('train_performance/{}/mIoU'.format(scale), metrics.get_semantics_mIoU(scale).item(), epoch-1)
135 |       tbwriter.add_scalar('train_performance/{}/IoU'.format(scale), metrics.get_occupancy_IoU(scale).item(), epoch-1)
136 |       # tbwriter.add_scalar('train_performance/{}/Precision'.format(scale), metrics.get_occupancy_Precision(scale).item(), epoch-1)
137 |       # tbwriter.add_scalar('train_performance/{}/Recall'.format(scale), metrics.get_occupancy_Recall(scale).item(), epoch-1)
138 |       # tbwriter.add_scalar('train_performance/{}/F1'.format(scale), metrics.get_occupancy_F1(scale).item(), epoch-1)
139 | 
140 |     logger.info('=> [Epoch {} - Total Train Loss = {}]'.format(epoch, epoch_loss))
141 |     for scale in metrics.evaluator.keys():
142 |       loss_scale = metrics.losses_track.train_losses['semantic_{}'.format(scale)].item()/metrics.losses_track.train_iteration_counts
143 |       logger.info('=> [Epoch {} - Scale {}: Loss = {:.6f} - mIoU = {:.6f} - IoU = {:.6f} '
144 |                   '- P = {:.6f} - R = {:.6f} - F1 = {:.6f}]'
145 |                   .format(epoch, scale, loss_scale,
146 |                           metrics.get_semantics_mIoU(scale).item(),
147 |                           metrics.get_occupancy_IoU(scale).item(),
148 |                           metrics.get_occupancy_Precision(scale).item(),
149 |                           metrics.get_occupancy_Recall(scale).item(),
150 |                           metrics.get_occupancy_F1(scale).item()))
151 | 
152 |     logger.info('=> Epoch {} - Training set class-wise IoU:'.format(epoch))
153 |     for i in range(1, metrics.nbr_classes):
154 |       class_name = dset.dataset.dataset_config['labels'][dset.dataset.dataset_config['learning_map_inv'][i]]
155 |       class_score = metrics.evaluator['1_1'].getIoU()[1][i]
156 |       logger.info('    => IoU {}: {:.6f}'.format(class_name, class_score))
157 | 
158 |     # Reset evaluator for validation...
159 |     metrics.reset_evaluator()
160 | 
161 |     checkpoint_info = validate(model, dataset['val'], _cfg, epoch, logger, tbwriter, metrics)
162 | 
163 |     # Reset evaluator and losses for next epoch...
164 |     metrics.reset_evaluator()
165 |     metrics.losses_track.restart_train_losses()
166 |     metrics.losses_track.restart_validation_losses()
167 | 
168 |     if _cfg._dict['SCHEDULER']['FREQUENCY'] == 'epoch':
169 |       scheduler.step()
170 | 
171 |     # Save checkpoints
172 |     for k in checkpoint_info.keys():
173 |       checkpoint_path = os.path.join(_cfg._dict['OUTPUT']['OUTPUT_PATH'], 'chkpt', k)
174 |       _cfg._dict['STATUS'][checkpoint_info[k]] = checkpoint_path
175 |       checkpoint.save(checkpoint_path, model, optimizer, scheduler, epoch, _cfg._dict)
176 | 
177 |     # Save checkpoint if current epoch matches checkpoint period
178 |     if epoch % _cfg._dict['TRAIN']['CHECKPOINT_PERIOD'] == 0:
179 |       checkpoint_path = os.path.join(_cfg._dict['OUTPUT']['OUTPUT_PATH'], 'chkpt', str(epoch).zfill(2))
180 |       checkpoint.save(checkpoint_path, model, optimizer, scheduler, epoch, _cfg._dict)
181 | 
182 |     # Update config file
183 |     _cfg.update_config(resume=True)
184 | 
185 |   return metrics.best_metric_record
186 | 
187 | 
188 | def validate(model, dset, _cfg, epoch, logger, tbwriter, metrics):
189 | 
190 |   device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
191 |   dtype = torch.float32  # Tensor type to be used
192 | 
193 |   nbr_epochs = _cfg._dict['TRAIN']['EPOCHS']
194 | 
195 |   logger.info('=> Running the network on the validation set...')
196 | 
197 |   model.eval()
198 | 
199 |   with torch.no_grad():
200 | 
201 |     for t, (data, indices) in enumerate(dset):
202 | 
203 |       data = dict_to(data, device, dtype)
204 | 
205 |       scores = model(data)
206 | 
207 |       loss = model.compute_loss(scores, data)
208 | 
209 |       for l_key in loss:
210 |         tbwriter.add_scalar('validation_loss_batch/{}'.format(l_key), loss[l_key].item(), len(dset) * (epoch-1) + t)
211 |       # Update batch losses, used later to compute the mean epoch loss
212 |       metrics.losses_track.update_validaiton_losses(loss)
213 | 
214 |       if (t + 1) % _cfg._dict['VAL']['SUMMARY_PERIOD'] == 0:
215 |         loss_print = '=> Epoch [{}/{}], Iteration [{}/{}], Validation Losses: '.format(epoch, nbr_epochs, t+1, len(dset))
216 |         for key in loss.keys(): loss_print += '{} = {:.6f}, '.format(key, loss[key])
217 |         logger.info(loss_print[:-3])
218 | 
219 |       metrics.add_batch(prediction=scores, target=model.get_target(data))
220 | 
221 |     for l_key in metrics.losses_track.validation_losses:
222 |       tbwriter.add_scalar('validation_loss_epoch/{}'.format(l_key),
223 |                           metrics.losses_track.validation_losses[l_key].item()/metrics.losses_track.validation_iteration_counts,
224 |                           epoch - 1)
225 | 
226 |     epoch_loss = metrics.losses_track.validation_losses['total']/metrics.losses_track.validation_iteration_counts
227 | 
228 |     for scale in metrics.evaluator.keys():
229 |       tbwriter.add_scalar('validation_performance/{}/mIoU'.format(scale), metrics.get_semantics_mIoU(scale).item(), epoch-1)
230 |       tbwriter.add_scalar('validation_performance/{}/IoU'.format(scale), metrics.get_occupancy_IoU(scale).item(), epoch-1)
231 |       # tbwriter.add_scalar('validation_performance/{}/Precision'.format(scale), metrics.get_occupancy_Precision(scale).item(), epoch-1)
232 |       # tbwriter.add_scalar('validation_performance/{}/Recall'.format(scale), metrics.get_occupancy_Recall(scale).item(), epoch-1)
233 |       # tbwriter.add_scalar('validation_performance/{}/F1'.format(scale), metrics.get_occupancy_F1(scale).item(), epoch-1)
234 | 
235 |     logger.info('=> [Epoch {} - Total Validation Loss = {}]'.format(epoch, epoch_loss))
236 |     for scale in metrics.evaluator.keys():
237 |       loss_scale = metrics.losses_track.validation_losses['semantic_{}'.format(scale)].item()/metrics.losses_track.validation_iteration_counts
238 |       logger.info('=> [Epoch {} - Scale {}: Loss = {:.6f} - mIoU = {:.6f} - IoU = {:.6f} '
239 |                   '- P = {:.6f} - R = {:.6f} - F1 = {:.6f}]'
240 |                   .format(epoch, scale, loss_scale,
241 |                           metrics.get_semantics_mIoU(scale).item(),
242 |                           metrics.get_occupancy_IoU(scale).item(),
243 |                           metrics.get_occupancy_Precision(scale).item(),
244 |                           metrics.get_occupancy_Recall(scale).item(),
245 |                           metrics.get_occupancy_F1(scale).item()))
246 | 
247 |     logger.info('=> Epoch {} - Validation set class-wise IoU:'.format(epoch))
248 |     for i in range(1, metrics.nbr_classes):
249 |       class_name = dset.dataset.dataset_config['labels'][dset.dataset.dataset_config['learning_map_inv'][i]]
250 |       class_score = metrics.evaluator['1_1'].getIoU()[1][i]
251 |       logger.info('    => {}: {:.6f}'.format(class_name, class_score))
252 | 
253 |   checkpoint_info = {}
254 | 
255 |   if epoch_loss < _cfg._dict['OUTPUT']['BEST_LOSS']:
256 |     logger.info('=> Best loss on validation set encountered: ({} < {})'.
257 |                 format(epoch_loss, _cfg._dict['OUTPUT']['BEST_LOSS']))
258 |     _cfg._dict['OUTPUT']['BEST_LOSS'] = epoch_loss.item()
259 |     checkpoint_info['best-loss'] = 'BEST_LOSS'
260 | 
261 |   mIoU_1_1 = metrics.get_semantics_mIoU('1_1')
262 |   IoU_1_1 = metrics.get_occupancy_IoU('1_1')
263 |   if mIoU_1_1 > _cfg._dict['OUTPUT']['BEST_METRIC']:
264 |     logger.info('=> Best metric on validation set encountered: ({} > {})'.
265 |                 format(mIoU_1_1, _cfg._dict['OUTPUT']['BEST_METRIC']))
266 |     _cfg._dict['OUTPUT']['BEST_METRIC'] = mIoU_1_1.item()
267 |     checkpoint_info['best-metric'] = 'BEST_METRIC'
268 |     metrics.update_best_metric_record(mIoU_1_1, IoU_1_1, epoch_loss.item(), epoch)
269 | 
270 |   checkpoint_info['last'] = 'LAST'
271 | 
272 |   return checkpoint_info
273 | 
274 | 
275 | def main():
276 | 
277 |   # https://github.com/pytorch/pytorch/issues/27588
278 |   torch.backends.cudnn.enabled = False
279 | 
280 |   seed_all(0)
281 | 
282 |   args = parse_args()
283 | 
284 |   train_f = args.config_file
285 |   dataset_f = args.dataset_root
286 | 
287 |   # Read train configuration file
288 |   _cfg = CFG()
289 |   _cfg.from_config_yaml(train_f)
290 | 
291 |   # Replace dataset path in config file with the one passed as argument
292 |   if dataset_f is not None:
293 |     _cfg._dict['DATASET']['ROOT_DIR'] = dataset_f
294 | 
295 |   # Create writer for TensorBoard
296 |   tbwriter = SummaryWriter(log_dir=os.path.join(_cfg._dict['OUTPUT']['OUTPUT_PATH'], 'metrics'))
297 | 
298 |   # Set up the logger to print statements and also save them to a log file
299 |   logger = get_logger(_cfg._dict['OUTPUT']['OUTPUT_PATH'], 'logs_train.log')
300 | 
301 |   logger.info('============ Training routine: "%s" ============\n' % train_f)
302 |   dataset = get_dataset(_cfg)
303 | 
304 |   logger.info('=> Loading network architecture...')
305 |   model = get_model(_cfg, dataset['train'].dataset)
306 |   if torch.cuda.device_count() > 1:
307 |     model = nn.DataParallel(model)
308 |     model = model.module  # unwrap so the model's custom methods (compute_loss, get_scales, ...) stay accessible
309 | 
310 |   logger.info('=> Loading optimizer...')
311 |   optimizer = build_optimizer(_cfg, model)
312 |   scheduler = build_scheduler(_cfg, optimizer)
313 | 
314 |   model, optimizer, scheduler, epoch = checkpoint.load(model, optimizer, scheduler, _cfg._dict['STATUS']['RESUME'],
315 |                                                        _cfg._dict['STATUS']['LAST'], logger)
316 | 
317 |   best_record = train(model, optimizer, scheduler, dataset, _cfg, epoch, logger, tbwriter)
318 | 
319 |   logger.info('=> ============ Network trained - all epochs passed... ============')
320 | 
321 |   logger.info('=> [Best performance: Epoch {} - mIoU = {} - IoU = {}]'.format(best_record['epoch'], best_record['mIoU'], best_record['IoU']))
322 | 
323 |   logger.info('=> Writing config file to output folder - deleting it from the config files folder')
324 |   _cfg.finish_config()
325 |   logger.info('=> Training routine completed...')
326 | 
327 |   exit()
328 | 
329 | 
330 | if __name__ == '__main__':
331 |   main()
--------------------------------------------------------------------------------
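Usage note: a minimal launch sketch for the training script above, using only the two flags defined in parse_args(); both paths below are placeholders, not files shipped with the repository:

    python LMSCNet/train.py --cfg path/to/train_config.yaml --dset_root path/to/SemanticKITTI

If --dset_root is omitted, main() keeps the DATASET.ROOT_DIR value already stored in the YAML config; checkpoints, TensorBoard metrics, and logs_train.log are written under OUTPUT.OUTPUT_PATH.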