├── .gitignore ├── README.md ├── assets └── figure.png ├── configs ├── ddep.yaml └── dep.yaml ├── license ├── requirements.txt ├── scripts └── extract_model_weights.py ├── src ├── data.py ├── model.py └── pl_utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | output/ 3 | weights/ 4 | notes.txt 5 | 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 
96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 166 | #.idea/ 167 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Decoder Denoising Pretraining for Semantic Segmentation 2 | 3 | PyTorch reimplementation of ["Decoder Denoising Pretraining for Semantic Segmentation"](https://arxiv.org/abs/2205.11423). 4 | 5 |

6 | 7 |

8 | 9 | ## Requirements 10 | - Python 3.8+ 11 | - `pip install -r requirements.txt` 12 | 13 | ## Usage 14 | To perform decoder denoising pretraining on a U-Net with a ResNet-50 encoder, run: 15 | ``` 16 | python train.py --gpus 1 --max_epochs 100 --data.root path/to/data/ --model.arch unet --model.encoder resnet50 17 | ``` 18 | 19 | - `--model.arch` can be one of `unet, unetplusplus, manet, linknet, fpn, pspnet, deeplabv3, deeplabv3plus, pan`. 20 | - `--model.encoder` can be any encoder from the list [here](https://smp.readthedocs.io/en/latest/encoders.html). 21 | - `configs/` contains example configuration files which can be run with `python train.py --config path/to/config`. 22 | - Run `python train.py --help` to get descriptions for all the options. 23 | 24 | ### Using a Pretrained Model 25 | Model weights can be extracted from a pretraining checkpoint file by running: 26 | ``` 27 | python scripts/extract_model_weights.py -c path/to/checkpoint/file 28 | ``` 29 | You can then initialize a segmentation model with these weights as follows (example for U-Net with ResNet-50 encoder): 30 | ```python 31 | import segmentation_models_pytorch as smp 32 | import torch 33 | import torch.nn as nn 34 | 35 | weights = torch.load("weights.pt") 36 | 37 | model = smp.create_model( 38 | "unet", 39 | encoder_name="resnet50", 40 | in_channels=3, 41 | classes=3, # Must match the value used during pretraining (equal to in_channels) 42 | encoder_weights=None, 43 | ) 44 | 45 | model.load_state_dict(weights, strict=True) 46 | 47 | # Replace segmentation head for fine-tuning 48 | in_channels = model.segmentation_head[0].in_channels 49 | num_classes = 10 50 | model.segmentation_head[0] = nn.Conv2d(in_channels, num_classes, kernel_size=3, padding=1) 51 | ``` 52 | 53 | ## Citation 54 | ```bibtex 55 | @inproceedings{brempong2022denoising, 56 | title={Denoising Pretraining for Semantic Segmentation}, 57 | author={Brempong, Emmanuel Asiedu and Kornblith, Simon and Chen, Ting and Parmar, Niki and Minderer, Matthias and
Norouzi, Mohammad}, 58 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 59 | pages={4175--4186}, 60 | year={2022} 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /assets/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bwconrad/decoder-denoising/e3b0d68512d1ec68cd06b1151d6e8edb2ad0cf3b/assets/figure.png -------------------------------------------------------------------------------- /configs/ddep.yaml: -------------------------------------------------------------------------------- 1 | name: ddep 2 | gpus: 1 3 | max_epochs: 100 4 | precision: 16 5 | data: 6 | root: path/to/data/ 7 | size: 256 8 | crop: 224 9 | num_val: 1000 10 | batch_size: 256 11 | workers: 8 12 | model: 13 | lr: 0.0001 14 | optimizer: adam 15 | betas: 16 | - 0.9 17 | - 0.999 18 | weight_decay: 0.0 19 | momentum: 0.9 20 | arch: unet 21 | encoder: resnet50 22 | in_channels: 3 23 | mode: decoder 24 | noise_type: scaled 25 | noise_std: 0.22 26 | loss_type: l2 27 | channel_last: true 28 | -------------------------------------------------------------------------------- /configs/dep.yaml: -------------------------------------------------------------------------------- 1 | name: dep 2 | gpus: 1 3 | max_epochs: 100 4 | precision: 16 5 | data: 6 | root: path/to/data/ 7 | size: 256 8 | crop: 224 9 | num_val: 1000 10 | batch_size: 256 11 | workers: 8 12 | model: 13 | lr: 0.0001 14 | optimizer: adam 15 | betas: 16 | - 0.9 17 | - 0.999 18 | weight_decay: 0.0 19 | momentum: 0.9 20 | arch: unet 21 | encoder: resnet50 22 | in_channels: 3 23 | mode: encoder+decoder 24 | noise_type: scaled 25 | noise_std: 0.22 26 | loss_type: l2 27 | channel_last: true 28 | -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | MIT 
License 2 | 3 | Copyright (c) 2022 Ben Conrad 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.5 2 | Pillow==9.2.0 3 | pytorch_lightning[extra]==1.6.4 4 | segmentation_models_pytorch==0.2.1 5 | torch==1.12.0 6 | torchvision==0.13.0 7 | wandb==0.12.15 8 | -------------------------------------------------------------------------------- /scripts/extract_model_weights.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to extract the network's state_dict from a checkpoint file 3 | """ 4 | 5 | from argparse import ArgumentParser 6 | 7 | import torch 8 | 9 | if __name__ == "__main__": 10 | parser = ArgumentParser() 11 | parser.add_argument("--checkpoint", "-c", type=str, required=True) 12 | parser.add_argument("--output", "-o", type=str, default="weights.pt") 13 | parser.add_argument("--prefix", "-p", type=str, default="net") 14 | 15 | args = parser.parse_args() 16 | 17 | checkpoint = torch.load(args.checkpoint, map_location="cpu") 18 | checkpoint = checkpoint["state_dict"] 19 | 20 | newmodel = {} 21 | for k, v in checkpoint.items(): 22 | if not k.startswith(args.prefix): 23 | continue 24 | 25 | k = k.replace(args.prefix + ".", "") 26 | newmodel[k] = v 27 | 28 | with open(args.output, "wb") as f: 29 | torch.save(newmodel, f) 30 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import Callable 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | import torch.utils.data as data 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | from torchvision.transforms import (Compose, Lambda, RandomCrop, 11 | RandomHorizontalFlip, Resize, ToTensor) 12 | 13 | 14 | class SimpleDataModule(pl.LightningDataModule): 15 | def 
__init__( 16 | self, 17 | root: str, 18 | size: int = 256, 19 | crop: int = 224, 20 | num_val: int = 1000, 21 | batch_size: int = 32, 22 | workers: int = 4, 23 | ): 24 | """Basic data module 25 | 26 | Args: 27 | root: Path to image directory 28 | size: Size of resized image 29 | crop: Size of image crop 30 | num_val: Number of validation samples 31 | batch_size: Number of batch samples 32 | workers: Number of data workers 33 | """ 34 | super().__init__() 35 | self.root = root 36 | self.num_val = num_val 37 | self.batch_size = batch_size 38 | self.workers = workers 39 | 40 | self.transforms = Compose( 41 | [ 42 | Resize(size), 43 | RandomCrop(crop), 44 | RandomHorizontalFlip(), 45 | ToTensor(), 46 | Lambda(lambda t: (t * 2) - 1), # Scale to [-1, 1] 47 | ] 48 | ) 49 | 50 | def setup(self, stage="fit"): 51 | if stage == "fit": 52 | dataset = SimpleDataset(self.root, self.transforms) 53 | 54 | self.train_dataset, self.val_dataset = data.random_split( 55 | dataset, 56 | [len(dataset) - self.num_val, self.num_val], 57 | generator=torch.Generator().manual_seed(42), 58 | ) 59 | 60 | def train_dataloader(self): 61 | return DataLoader( 62 | self.train_dataset, 63 | batch_size=self.batch_size, 64 | shuffle=True, 65 | num_workers=self.workers, 66 | pin_memory=True, 67 | drop_last=True, 68 | ) 69 | 70 | def val_dataloader(self): 71 | return DataLoader( 72 | self.val_dataset, 73 | batch_size=self.batch_size, 74 | shuffle=False, 75 | num_workers=self.workers, 76 | pin_memory=True, 77 | drop_last=False, 78 | ) 79 | 80 | 81 | class SimpleDataset(data.Dataset): 82 | def __init__(self, root: str, transforms: Callable): 83 | """Image dataset from directory 84 | 85 | Args: 86 | root: Path to directory 87 | transforms: Image augmentations 88 | """ 89 | super().__init__() 90 | self.root = root 91 | self.paths = [ 92 | f for f in glob(f"{root}/**/*", recursive=True) if os.path.isfile(f) 93 | ] 94 | self.transforms = transforms 95 | 96 | print(f"Loaded {len(self.paths)} images from 
{root}") 97 | 98 | def __getitem__(self, index): 99 | img = Image.open(self.paths[index]).convert("RGB") 100 | img = self.transforms(img) 101 | return img 102 | 103 | def __len__(self): 104 | return len(self.paths) 105 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import segmentation_models_pytorch as smp 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.optim import SGD, Adam, AdamW 6 | from torch.optim.lr_scheduler import CosineAnnealingLR 7 | 8 | 9 | class DecoderDenoisingModel(pl.LightningModule): 10 | def __init__( 11 | self, 12 | lr: float = 1e-4, 13 | optimizer: str = "adam", 14 | betas: tuple[float, float] = (0.9, 0.999), 15 | weight_decay: float = 0.0, 16 | momentum: float = 0.9, 17 | arch: str = "unet", 18 | encoder: str = "resnet18", 19 | in_channels: int = 3, 20 | mode: str = "decoder", 21 | noise_type: str = "scaled", 22 | noise_std: float = 0.22, 23 | loss_type: str = "l2", 24 | channel_last: bool = False, 25 | ): 26 | """Decoder Denoising Pretraining Model 27 | 28 | Args: 29 | lr: Learning rate 30 | optimizer: Name of optimizer (adam | adamw | sgd) 31 | betas: Adam beta parameters 32 | weight_decay: Optimizer weight decay 33 | momentum: SGD momentum parameter 34 | arch: Segmentation model architecture 35 | encoder: Segmentation model encoder architecture 36 | in_channels: Number of channels of input image 37 | mode: Denoising pretraining mode (encoder | encoder+decoder) 38 | noise_type: Type of noising process (scaled | simple) 39 | noise_std: Standard deviation/magnitude of gaussian noise 40 | loss_type: Loss function type (l1 | l2 | huber) 41 | channel_last: Change to channel last memory format for possible training speed up 42 | """ 43 | super().__init__() 44 | self.save_hyperparameters() 45 | self.lr = lr 46 | self.optimizer = optimizer 47 | self.betas = betas 
48 | self.weight_decay = weight_decay 49 | self.momentum = momentum 50 | self.noise_type = noise_type 51 | self.noise_std = noise_std 52 | self.channel_last = channel_last 53 | 54 | # Initialize loss function 55 | self.loss_fn = self.get_loss_fn(loss_type) 56 | 57 | # Initalize network 58 | self.net = smp.create_model( 59 | arch, 60 | encoder_name=encoder, 61 | in_channels=in_channels, 62 | classes=in_channels, 63 | encoder_weights="imagenet" if mode == "decoder" else None, 64 | ) 65 | 66 | # Freeze encoder when doing decoder only pretraining 67 | if mode == "decoder": 68 | for child in self.net.encoder.children(): # type:ignore 69 | for param in child.parameters(): 70 | param.requires_grad = False 71 | elif mode != "encoder+decoder": 72 | raise ValueError( 73 | f"{mode} is not an available training mode. Should be one of ['decoder', 'encoder+decoder']" 74 | ) 75 | 76 | # Change to channel last memory format 77 | # https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html 78 | if self.channel_last: 79 | self = self.to(memory_format=torch.channels_last) 80 | 81 | @staticmethod 82 | def get_loss_fn(loss_type: str): 83 | if loss_type == "l1": 84 | return F.l1_loss 85 | elif loss_type == "l2": 86 | return F.mse_loss 87 | elif loss_type == "huber": 88 | return F.smooth_l1_loss 89 | else: 90 | raise ValueError( 91 | f"{loss_type} is not an available loss function. Should be one of ['l1', 'l2', 'huber']" 92 | ) 93 | 94 | @torch.no_grad() 95 | def add_noise(self, x): 96 | # Sample noise 97 | noise = torch.randn_like(x) 98 | 99 | # Add noise to x 100 | if self.noise_type == "simple": 101 | x_noise = x + noise * self.noise_std 102 | elif self.noise_type == "scaled": 103 | x_noise = ((1 + self.noise_std**2) ** -0.5) * (x + noise * self.noise_std) 104 | else: 105 | raise ValueError( 106 | f"{self.noise_type} is not an available noise type. 
Should be one of ['simple', 'scaled']" 107 | ) 108 | 109 | return x_noise, noise 110 | 111 | def denoise_step(self, x, mode="train"): 112 | if self.channel_last: 113 | x = x.to(memory_format=torch.channels_last) 114 | 115 | # Add noise to x 116 | x_noise, noise = self.add_noise(x) 117 | 118 | # Predict noise 119 | pred_noise = self.net(x_noise) 120 | 121 | # Calculate loss 122 | loss = self.loss_fn(pred_noise, noise) 123 | 124 | # Log 125 | self.log(f"{mode}_loss", loss) 126 | 127 | return loss 128 | 129 | def training_step(self, x, _): 130 | self.log( 131 | "lr", 132 | self.trainer.optimizers[0].param_groups[0]["lr"], # type:ignore 133 | prog_bar=True, 134 | ) 135 | return self.denoise_step(x, mode="train") 136 | 137 | def validation_step(self, x, _): 138 | return self.denoise_step(x, mode="val") 139 | 140 | def configure_optimizers(self): 141 | if self.optimizer == "adam": 142 | optimizer = Adam( 143 | self.net.parameters(), 144 | lr=self.lr, 145 | betas=self.betas, 146 | weight_decay=self.weight_decay, 147 | ) 148 | elif self.optimizer == "adamw": 149 | optimizer = AdamW( 150 | self.net.parameters(), 151 | lr=self.lr, 152 | betas=self.betas, 153 | weight_decay=self.weight_decay, 154 | ) 155 | elif self.optimizer == "sgd": 156 | optimizer = SGD( 157 | self.net.parameters(), 158 | lr=self.lr, 159 | momentum=self.momentum, 160 | weight_decay=self.weight_decay, 161 | ) 162 | else: 163 | raise ValueError( 164 | f"{self.optimizer} is not an available optimizer. 
Should be one of ['adam', 'adamw', 'sgd']" 165 | ) 166 | 167 | scheduler = CosineAnnealingLR( 168 | optimizer, T_max=self.trainer.estimated_stepping_batches # type:ignore 169 | ) 170 | 171 | return { 172 | "optimizer": optimizer, 173 | "lr_scheduler": { 174 | "scheduler": scheduler, 175 | "interval": "step", 176 | }, 177 | } 178 | -------------------------------------------------------------------------------- /src/pl_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pytorch_lightning.loggers import (CSVLogger, LightningLoggerBase, 4 | TensorBoardLogger, WandbLogger) 5 | from pytorch_lightning.utilities.cli import LightningArgumentParser 6 | 7 | 8 | class MyLightningArgumentParser(LightningArgumentParser): 9 | def __init__(self, *args: Any, **kwargs: Any) -> None: 10 | super().__init__(*args, **kwargs) 11 | self.add_logger_args() 12 | 13 | def add_logger_args(self) -> None: 14 | # Common args 15 | self.add_argument( 16 | "--logger_type", 17 | type=str, 18 | help="Name of logger", 19 | default="csv", 20 | choices=["csv", "wandb", "tensorboard"], 21 | ) 22 | self.add_argument( 23 | "--save_path", 24 | type=str, 25 | help="Save path of outputs", 26 | default="output/", 27 | ) 28 | self.add_argument("--name", type=str, help="name of run", default="default") 29 | 30 | # Wandb 31 | self.add_argument( 32 | "--project", type=str, help="Name of wandb project", default="default" 33 | ) 34 | 35 | 36 | def init_logger(args: dict) -> LightningLoggerBase: # type:ignore 37 | """Initialize logger from arguments 38 | 39 | Args: 40 | args: parsed argument dictionary 41 | 42 | Returns: 43 | logger: initialized logger object 44 | """ 45 | if args["logger_type"] == "wandb": 46 | return WandbLogger( 47 | project=args["project"], 48 | name=args["name"], 49 | save_dir=args["save_path"], 50 | ) 51 | elif args["logger_type"] == "tensorboard": 52 | return TensorBoardLogger(name=args["name"], 
save_dir=args["save_path"]) 53 | elif args["logger_type"] == "csv": 54 | return CSVLogger(name=args["name"], save_dir=args["save_path"]) 55 | else: 56 | ValueError( 57 | f"{args['logger_type']} is not an available logger. Should be one of ['cvs', 'wandb', 'tensorboard']" 58 | ) 59 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | import pytorch_lightning as pl 4 | 5 | from src.data import SimpleDataModule 6 | from src.model import DecoderDenoisingModel 7 | from src.pl_utils import MyLightningArgumentParser, init_logger 8 | 9 | model_class = DecoderDenoisingModel 10 | dm_class = SimpleDataModule 11 | 12 | # Parse arguments 13 | parser = MyLightningArgumentParser() 14 | parser.add_lightning_class_args(pl.Trainer, None) # type:ignore 15 | parser.add_lightning_class_args(dm_class, "data") 16 | parser.add_lightning_class_args(model_class, "model") 17 | 18 | args = parser.parse_args() 19 | 20 | # Setup trainer 21 | logger = init_logger(args) 22 | dm = dm_class(**args["data"]) 23 | model = model_class(**args["model"]) 24 | trainer = pl.Trainer.from_argparse_args(Namespace(**args), logger=logger) 25 | 26 | # Train 27 | trainer.tune(model, dm) 28 | trainer.fit(model, dm) 29 | --------------------------------------------------------------------------------