├── .python-version ├── .gitignore ├── example.png ├── example_pred.png ├── requirements.txt ├── LICENSE ├── README.md ├── train_lit_model.py ├── retriever.py └── LitModel.py /.python-version: -------------------------------------------------------------------------------- 1 | 3.6.9 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | scrap* 2 | __pycache__/ 3 | .ipynb_checkpoints/ -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YassineYousfi/comma10k-baseline/HEAD/example.png -------------------------------------------------------------------------------- /example_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YassineYousfi/comma10k-baseline/HEAD/example_pred.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | tensorboard==2.2.0 3 | pytorch-lightning==0.9.0 4 | albumentations==0.5.2 5 | segmentation_models_pytorch==0.1.2 6 | pandas==1.1.5 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yassine Yousfi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🚗 comma10k-baseline 2 | 3 | A semantic segmentation baseline using [@comma.ai](https://github.com/commaai)'s [comma10k dataset](https://github.com/commaai/comma10k). 4 | 5 | Using a U-Net with an EfficientNet encoder, this baseline reaches a 0.044 validation loss.
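The network is built with [segmentation_models_pytorch](https://github.com/qubvel/segmentation_models.pytorch). A minimal sketch of the equivalent model construction (mirroring `LitModel.__build_model`, shown here with the `efficientnet-b4` backbone used in the commands below):
```
import segmentation_models_pytorch as smp

# U-Net decoder on an ImageNet-pretrained EfficientNet encoder; 6 output
# channels = 5 comma10k class values + 1 padding class (see LitModel.class_values).
net = smp.Unet('efficientnet-b4', classes=6, activation=None, encoder_weights='imagenet')

# Matching input normalization for the chosen encoder.
preprocess_fn = smp.encoders.get_preprocessing_fn('efficientnet-b4', pretrained='imagenet')
```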
6 | 7 | ## Visualize 8 | Here is an example (picked at random from the validation set, no cherry-picking): 9 | #### Ground truth 10 | ![Ground truth](example.png) 11 | #### Predicted 12 | ![Prediction](example_pred.png) 13 | 14 | ## Info 15 | 16 | The comma10k dataset is currently being labeled; stay tuned for: 17 | - A retrained model when the dataset is released 18 | - More features to use the model 19 | 20 | 21 | ## How to use 22 | This baseline is trained in two stages: (i) at 437x582, then (ii) at 874x1164 (full resolution), seeding the second stage from the first stage's last checkpoint: 23 | ``` 24 | python3 train_lit_model.py --backbone efficientnet-b4 --version first-stage --gpus 2 --batch-size 28 --epochs 100 --height 437 --width 582 25 | python3 train_lit_model.py --backbone efficientnet-b4 --version second-stage --gpus 2 --batch-size 7 --learning-rate 5e-5 --epochs 30 --height 874 --width 1164 --augmentation-level hard --seed-from-checkpoint .../efficientnet-b4/first-stage/checkpoints/last.ckpt 26 | ``` 27 | 28 | ## WIP and ideas for contributions! 29 | - Update to PyTorch Lightning 1.0 30 | - Try more image augmentations 31 | - Pretrain on a larger driving dataset (make sure the license is permissive) 32 | - Try oversampling images with small or far objects 33 | 34 | 35 | ## Dependencies 36 | Python 3.5+, PyTorch 1.6+, and the dependencies listed in requirements.txt. -------------------------------------------------------------------------------- /train_lit_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Runs a model on a single node across multiple GPUs. 3 | """ 4 | import warnings 5 | warnings.simplefilter(action='ignore', category=FutureWarning) 6 | import os 7 | from pathlib import Path 8 | from argparse import ArgumentParser 9 | from LitModel import * 10 | import torch 11 | from pytorch_lightning import Trainer, seed_everything 12 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 13 | from pytorch_lightning.loggers import TensorBoardLogger 14 | from pytorch_lightning.callbacks import LearningRateLogger 15 | from pytorch_lightning.utilities.distributed import rank_zero_only 16 | from pytorch_lightning.callbacks import Callback 17 | 18 | seed_everything(1994) 19 | 20 | def setup_callbacks_loggers(args): 21 | 22 | log_path = Path('/home/yyousfi1/LogFiles/comma/') 23 | name = args.backbone 24 | version = args.version 25 | tb_logger = TensorBoardLogger(log_path, name=name, version=version) 26 | lr_logger = LearningRateLogger(logging_interval='epoch') 27 | ckpt_callback = ModelCheckpoint(filepath=Path(tb_logger.log_dir)/'checkpoints/{epoch:02d}_{val_loss:.4f}', 28 | save_top_k=10, save_last=True) 29 | 30 | return ckpt_callback, tb_logger, lr_logger 31 | 32 | 33 | def main(args): 34 | """ Main training routine for this project.
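    Builds the LitModel (optionally warm-starting its weights from --seed-from-checkpoint),
    sets up the checkpoint, TensorBoard and LR-logging callbacks, and fits a 16-bit
    precision DDP Trainer.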
""" 35 | 36 | if args.seed_from_checkpoint: 37 | print('model seeded') 38 | model = LitModel.load_from_checkpoint(args.seed_from_checkpoint, **vars(args)) 39 | else: 40 | model = LitModel(**vars(args)) 41 | 42 | ckpt_callback, tb_logger, lr_logger = setup_callbacks_loggers(args) 43 | 44 | trainer = Trainer(checkpoint_callback=ckpt_callback, 45 | logger=tb_logger, 46 | callbacks=[lr_logger], 47 | gpus=args.gpus, 48 | min_epochs=args.epochs, 49 | max_epochs=args.epochs, 50 | precision=16, 51 | amp_backend='native', 52 | row_log_interval=100, 53 | log_save_interval=100, 54 | distributed_backend='ddp', 55 | benchmark=True, 56 | sync_batchnorm=True, 57 | resume_from_checkpoint=args.resume_from_checkpoint) 58 | 59 | 60 | trainer.logger.log_hyperparams(model.hparams) 61 | 62 | trainer.fit(model) 63 | 64 | 65 | def run_cli(): 66 | root_dir = os.path.dirname(os.path.realpath(__file__)) 67 | 68 | parent_parser = ArgumentParser(add_help=False) 69 | 70 | parser = LitModel.add_model_specific_args(parent_parser) 71 | 72 | parser.add_argument('--version', 73 | default=None, 74 | type=str, 75 | metavar='V', 76 | help='version or id of the net') 77 | parser.add_argument('--resume-from-checkpoint', 78 | default=None, 79 | type=str, 80 | metavar='RFC', 81 | help='path to checkpoint') 82 | parser.add_argument('--seed-from-checkpoint', 83 | default=None, 84 | type=str, 85 | metavar='SFC', 86 | help='path to checkpoint seed') 87 | 88 | args = parser.parse_args() 89 | 90 | main(args) 91 | 92 | 93 | if __name__ == '__main__': 94 | run_cli() -------------------------------------------------------------------------------- /retriever.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.simplefilter(action='ignore', category=FutureWarning) 3 | import pandas as pd 4 | import numpy as np 5 | import pickle 6 | import cv2 7 | import albumentations as A 8 | from albumentations.core.composition import Compose 9 | from typing import Callable, List 10 | from pathlib import Path 11 | import os 12 | from torch.utils.data import Dataset 13 | import torch 14 | import sys 15 | 16 | def pad_to_multiple(x, k=32): 17 | return int(k*(np.ceil(x/k))) 18 | 19 | def get_train_transforms(height: int = 437, 20 | width: int = 582, 21 | level: str = 'hard'): 22 | if level == 'light': 23 | return A.Compose([ 24 | A.HorizontalFlip(p=0.5), 25 | A.IAAAdditiveGaussianNoise(p=0.2), 26 | A.OneOf( 27 | [A.CLAHE(p=1.0), 28 | A.RandomBrightness(p=1.0), 29 | A.RandomGamma(p=1.0), 30 | ],p=0.5), 31 | A.OneOf( 32 | [A.IAASharpen(p=1.0), 33 | A.Blur(blur_limit=3, p=1.0), 34 | A.MotionBlur(blur_limit=3, p=1.0), 35 | ],p=0.5), 36 | A.OneOf( 37 | [A.RandomContrast(p=1.0), 38 | A.HueSaturationValue(p=1.0), 39 | ],p=0.5), 40 | A.Resize(height=height, width=width, p=1.0), 41 | A.PadIfNeeded(pad_to_multiple(height), 42 | pad_to_multiple(width), 43 | border_mode=cv2.BORDER_CONSTANT, 44 | value=0, 45 | mask_value=0) 46 | ], p=1.0) 47 | 48 | elif level == 'hard': 49 | return A.Compose([ 50 | A.HorizontalFlip(p=0.5), 51 | A.IAAAdditiveGaussianNoise(p=0.2), 52 | A.OneOf( 53 | [A.GridDistortion(border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=1.0), 54 | A.ElasticTransform(alpha_affine=10, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=1.0), 55 | A.ShiftScaleRotate( 56 | shift_limit=0, 57 | scale_limit=0, 58 | rotate_limit=10, 59 | border_mode=cv2.BORDER_CONSTANT, 60 | value=0, 61 | mask_value=0, 62 | p=1.0 63 | ), 64 | A.OpticalDistortion(border_mode=cv2.BORDER_CONSTANT, value=0, 
mask_value=0, p=1.0), 65 | ],p=0.5), 66 | A.OneOf( 67 | [A.CLAHE(p=1.0), 68 | A.RandomBrightness(p=1.0), 69 | A.RandomGamma(p=1.0), 70 | A.ISONoise(p=1.0) 71 | ],p=0.5), 72 | A.OneOf( 73 | [A.IAASharpen(p=1.0), 74 | A.Blur(blur_limit=3, p=1.0), 75 | A.MotionBlur(blur_limit=3, p=1.0), 76 | ],p=0.5), 77 | A.OneOf( 78 | [A.RandomContrast(p=1.0), 79 | A.HueSaturationValue(p=1.0), 80 | ],p=0.5), 81 | A.Resize(height=height, width=width, p=1.0), 82 | A.Cutout(p=0.3), 83 | A.PadIfNeeded(pad_to_multiple(height), 84 | pad_to_multiple(width), 85 | border_mode=cv2.BORDER_CONSTANT, 86 | value=0, 87 | mask_value=0) 88 | ], p=1.0) 89 | elif level == 'hard_weather': 90 | return A.Compose([ 91 | A.HorizontalFlip(p=0.5), 92 | A.IAAAdditiveGaussianNoise(p=0.2), 93 | A.OneOf( 94 | [A.GridDistortion(border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=1.0), 95 | A.ElasticTransform(alpha_affine=10, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=1.0), 96 | A.ShiftScaleRotate( 97 | shift_limit=0, 98 | scale_limit=0, 99 | rotate_limit=10, 100 | border_mode=cv2.BORDER_CONSTANT, 101 | value=0, 102 | mask_value=0, 103 | p=1.0 104 | ), 105 | A.OpticalDistortion(border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=1.0), 106 | ],p=0.5), 107 | A.OneOf( 108 | [A.CLAHE(p=1.0), 109 | A.RandomBrightness(p=1.0), 110 | A.RandomGamma(p=1.0), 111 | A.ISONoise(p=1.0) 112 | ],p=0.5), 113 | A.OneOf( 114 | [A.IAASharpen(p=1.0), 115 | A.Blur(blur_limit=3, p=1.0), 116 | A.MotionBlur(blur_limit=3, p=1.0), 117 | ],p=0.5), 118 | A.OneOf( 119 | [A.RandomContrast(p=1.0), 120 | A.HueSaturationValue(p=1.0), 121 | ],p=0.5), 122 | A.OneOf( 123 | [A.RandomFog(fog_coef_upper=0.8, p=1.0), 124 | A.RandomRain(p=1.0), 125 | A.RandomSnow(p=1.0), 126 | A.RandomSunFlare(src_radius=100, p=1.0) 127 | ],p=0.4), 128 | A.Resize(height=height, width=width, p=1.0), 129 | A.Cutout(p=0.3), 130 | A.PadIfNeeded(pad_to_multiple(height), 131 | pad_to_multiple(width), 132 | border_mode=cv2.BORDER_CONSTANT, 133 | value=0, 134 | mask_value=0) 135 | ], p=1.0) 136 | 137 | def get_valid_transforms(height: int = 437, 138 | width: int = 582): 139 | return A.Compose([ 140 | A.Resize(height=height, width=width, p=1.0), 141 | A.PadIfNeeded(pad_to_multiple(height), 142 | pad_to_multiple(width), 143 | border_mode=cv2.BORDER_CONSTANT, 144 | value=0, 145 | mask_value=0) 146 | ], p=1.0) 147 | 148 | def to_tensor(x, **kwargs): 149 | return x.transpose(2, 0, 1).astype('float32') 150 | 151 | def get_preprocessing(preprocessing_fn: Callable): 152 | _transform = [ 153 | A.Lambda(image=preprocessing_fn), 154 | A.Lambda(image=to_tensor, mask=to_tensor), 155 | ] 156 | return A.Compose(_transform) 157 | 158 | class TrainRetriever(Dataset): 159 | 160 | def __init__(self, 161 | data_path: Path, 162 | image_names: List[str], 163 | preprocess_fn: Callable, 164 | transforms: Compose, 165 | class_values: List[int]): 166 | super().__init__() 167 | 168 | self.data_path = data_path 169 | self.image_names = image_names 170 | self.transforms = transforms 171 | self.preprocess = get_preprocessing(preprocess_fn) 172 | self.class_values = class_values 173 | self.images_folder = 'imgs' 174 | self.masks_folder = 'masks' 175 | 176 | def __getitem__(self, index: int): 177 | 178 | image_name = self.image_names[index] 179 | 180 | image = cv2.imread(str(self.data_path/self.images_folder/image_name)) 181 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 182 | 183 | mask = cv2.imread(str(self.data_path/self.masks_folder/image_name), 0).astype('uint8') 184 | 185 | if self.transforms: 186 | 
sample = self.transforms(image=image, mask=mask) 187 | image = sample['image'] 188 | mask = sample['mask'] 189 | 190 | mask = np.stack([(mask == v) for v in self.class_values], axis=-1).astype('uint8')  # one-hot encode: one channel per class value 191 | 192 | if self.preprocess: 193 | sample = self.preprocess(image=image, mask=mask) 194 | image = sample['image'] 195 | mask = sample['mask'] 196 | 197 | return image, mask 198 | 199 | def __len__(self) -> int: 200 | return len(self.image_names) 201 | 202 | 203 | -------------------------------------------------------------------------------- /LitModel.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.simplefilter(action='ignore', category=FutureWarning) 3 | import pandas as pd 4 | import numpy as np 5 | import pickle 6 | import argparse 7 | from collections import OrderedDict 8 | from pathlib import Path 9 | from tempfile import TemporaryDirectory 10 | from typing import Optional, Generator, Union 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import optim 14 | from torch.nn import Module 15 | from torch.utils.data import DataLoader 16 | import pytorch_lightning as pl 17 | from pytorch_lightning import _logger as log 18 | import random 19 | from retriever import * 20 | from pytorch_lightning.metrics.converters import _sync_ddp_if_available 21 | import segmentation_models_pytorch as smp 22 | 23 | class LitModel(pl.LightningModule): 24 | """Transfer learning: U-Net segmentation with an ImageNet-pretrained encoder. 25 | """ 26 | def __init__(self, 27 | data_path: Union[str, Path], 28 | backbone: str = 'efficientnet-b0', 29 | augmentation_level: str = 'light', 30 | batch_size: int = 32, 31 | lr: float = 1e-4, 32 | eps: float = 1e-7, 33 | height: int = 14*32, 34 | width: int = 18*32, 35 | num_workers: int = 6, 36 | epochs: int = 50, 37 | gpus: int = 1, 38 | weight_decay: float = 1e-3, 39 | class_values: List[int] = [41, 76, 90, 124, 161, 0] # 0 added for padding 40 | ,**kwargs) -> None: 41 | 42 | super().__init__() 43 | self.data_path = Path(data_path) 44 | self.epochs = epochs 45 | self.backbone = backbone 46 | self.batch_size = batch_size 47 | self.lr = lr 48 | self.height = height 49 | self.width = width 50 | self.num_workers = num_workers 51 | self.gpus = gpus 52 | self.weight_decay = weight_decay 53 | self.eps = eps 54 | self.class_values = class_values 55 | self.augmentation_level = augmentation_level 56 | 57 | self.save_hyperparameters() 58 | 59 | self.train_custom_metrics = {'train_acc': smp.utils.metrics.Accuracy(activation='softmax2d')} 60 | self.validation_custom_metrics = {'val_acc': smp.utils.metrics.Accuracy(activation='softmax2d')} 61 | 62 | self.preprocess_fn = smp.encoders.get_preprocessing_fn(self.backbone, pretrained='imagenet') 63 | 64 | self.__build_model() 65 | 66 | def __build_model(self): 67 | """Define model layers & loss.""" 68 | 69 | # 1. net: 70 | 71 | self.net = smp.Unet(self.backbone, classes=len(self.class_values), 72 | activation=None, encoder_weights='imagenet') 73 | 74 | # 2. Loss: 75 | self.loss_func = lambda x, y: torch.nn.CrossEntropyLoss()(x, torch.argmax(y,axis=1))  # masks are one-hot, CrossEntropyLoss expects class indices 76 | 77 | def forward(self, x): 78 | """Forward pass. Returns logits.""" 79 | 80 | x = self.net(x) 81 | 82 | return x 83 | 84 | def loss(self, logits, labels): 85 | """Use the loss_func""" 86 | return self.loss_func(logits, labels) 87 | 88 | def training_step(self, batch, batch_idx): 89 | 90 | # 1. Forward pass: 91 | x, y = batch 92 | y_logits = self.forward(x) 93 | 94 | # 2.
Compute loss & accuracy: 95 | train_loss = self.loss(y_logits, y) 96 | 97 | metrics = {} 98 | for metric_name in self.train_custom_metrics.keys(): 99 | metrics[metric_name] = self.train_custom_metrics[metric_name](y_logits, y) 100 | 101 | # 3. Outputs: 102 | output = OrderedDict({'loss': train_loss, 103 | 'log': metrics, 104 | 'progress_bar': metrics}) 105 | 106 | return output 107 | 108 | def validation_step(self, batch, batch_idx): 109 | 110 | # 1. Forward pass: 111 | x, y = batch 112 | y_logits = self.forward(x) 113 | 114 | # 2. Compute loss & accuracy: 115 | val_loss = self.loss(y_logits, y) 116 | 117 | metrics = {'val_loss': val_loss} 118 | 119 | for metric_name in self.validation_custom_metrics.keys(): 120 | metrics[metric_name] = self.validation_custom_metrics[metric_name](y_logits, y) 121 | 122 | return metrics 123 | 124 | def validation_epoch_end(self, outputs): 125 | """Compute and log validation loss and accuracy at the epoch level. 126 | Average statistics across GPUs in case of DDP. 127 | """ 128 | keys = outputs[0].keys() 129 | metrics = {} 130 | for metric_name in keys: 131 | metrics[metric_name] = _sync_ddp_if_available(torch.stack([output[metric_name] for output in outputs]).mean(), reduce_op='avg') 132 | 133 | metrics['step'] = self.current_epoch 134 | 135 | return {'log': metrics} 136 | 137 | 138 | def configure_optimizers(self): 139 | 140 | optimizer = torch.optim.Adam 141 | optimizer_kwargs = {'eps': self.eps} 142 | 143 | optimizer = optimizer(self.parameters(), 144 | lr=self.lr, 145 | weight_decay=self.weight_decay, 146 | **optimizer_kwargs) 147 | 148 | scheduler_kwargs = {'T_max': self.epochs*len(self.train_dataset)//self.gpus//self.batch_size,  # one cosine cycle over all training steps 149 | 'eta_min':self.lr/50} 150 | 151 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR 152 | interval = 'step' 153 | scheduler = scheduler(optimizer, **scheduler_kwargs) 154 | 155 | return [optimizer], [{'scheduler':scheduler, 'interval': interval, 'name': 'lr'}] 156 | 157 | 158 | def prepare_data(self): 159 | """Data download is not part of this script 160 | Get the data from https://github.com/commaai/comma10k 161 | """ 162 | assert (self.data_path/'imgs').is_dir(), 'Images not found' 163 | assert (self.data_path/'masks').is_dir(), 'Masks not found' 164 | assert (self.data_path/'files_trainable').exists(), 'files_trainable not found' 165 | 166 | print('data ready') 167 | 168 | def setup(self, stage: str): 169 | 170 | image_names = np.loadtxt(self.data_path/'files_trainable', dtype='str').tolist() 171 | 172 | random.shuffle(image_names) 173 | 174 | self.train_dataset = TrainRetriever( 175 | data_path=self.data_path, 176 | image_names=[x.split('masks/')[-1] for x in image_names if not x.endswith('9.png')],  # names ending in 9 are held out for validation 177 | preprocess_fn=self.preprocess_fn, 178 | transforms=get_train_transforms(self.height, self.width, self.augmentation_level), 179 | class_values=self.class_values 180 | ) 181 | 182 | self.valid_dataset = TrainRetriever( 183 | data_path=self.data_path, 184 | image_names=[x.split('masks/')[-1] for x in image_names if x.endswith('9.png')], 185 | preprocess_fn=self.preprocess_fn, 186 | transforms=get_valid_transforms(self.height, self.width), 187 | class_values=self.class_values 188 | ) 189 | 190 | 191 | def __dataloader(self, train): 192 | """Train/validation loaders.""" 193 | 194 | _dataset = self.train_dataset if train else self.valid_dataset 195 | loader = DataLoader(dataset=_dataset, 196 | batch_size=self.batch_size, 197 | num_workers=self.num_workers, 198 | shuffle=True if train else False) 199 | 200 | return
loader 201 | 202 | def train_dataloader(self): 203 | log.info('Training data loaded.') 204 | return self.__dataloader(train=True) 205 | 206 | def val_dataloader(self): 207 | log.info('Validation data loaded.') 208 | return self.__dataloader(train=False) 209 | 210 | 211 | @staticmethod 212 | def add_model_specific_args(parent_parser): 213 | parser = argparse.ArgumentParser(parents=[parent_parser]) 214 | parser.add_argument('--backbone', 215 | default='efficientnet-b0', 216 | type=str, 217 | metavar='BK', 218 | help='Name as in segmentation_models_pytorch') 219 | parser.add_argument('--augmentation-level', 220 | default='light', 221 | type=str, 222 | help='Training augmentation level, cf. retriever.py') 223 | parser.add_argument('--data-path', 224 | default='/home/yyousfi1/commaai/comma10k', 225 | type=str, 226 | metavar='dp', 227 | help='data_path') 228 | parser.add_argument('--epochs', 229 | default=30, 230 | type=int, 231 | metavar='N', 232 | help='total number of epochs') 233 | parser.add_argument('--batch-size', 234 | default=32, 235 | type=int, 236 | metavar='B', 237 | help='batch size', 238 | dest='batch_size') 239 | parser.add_argument('--gpus', 240 | type=int, 241 | default=1, 242 | help='number of gpus to use') 243 | parser.add_argument('--lr', 244 | '--learning-rate', 245 | default=1e-4, 246 | type=float, 247 | metavar='LR', 248 | help='initial learning rate', 249 | dest='lr') 250 | parser.add_argument('--eps', 251 | default=1e-7, 252 | type=float, 253 | help='eps for adaptive optimizers', 254 | dest='eps') 255 | parser.add_argument('--height', 256 | default=14*32, 257 | type=int, 258 | help='image height') 259 | parser.add_argument('--width', 260 | default=18*32, 261 | type=int, 262 | help='image width') 263 | parser.add_argument('--num-workers', 264 | default=6, 265 | type=int, 266 | metavar='W', 267 | help='number of CPU workers', 268 | dest='num_workers') 269 | parser.add_argument('--weight-decay', 270 | default=1e-3, 271 | type=float, 272 | metavar='wd', 273 | help='Optimizer weight decay') 274 | 275 | return parser 276 | --------------------------------------------------------------------------------