├── trainer ├── VERSION ├── utils │ ├── __init__.py │ ├── cpu_memory.py │ ├── distributed.py │ └── cuda_memory.py ├── logger.py ├── __init__.py ├── analytics.py ├── TODO.txt ├── logging │ ├── __init__.py │ ├── dummy_logger.py │ ├── clearml_logger.py │ ├── base_dash_logger.py │ ├── tensorboard_logger.py │ ├── wandb_logger.py │ ├── console_logger.py │ ├── mlflow_logger.py │ └── aim_logger.py ├── distribute.py ├── torch.py ├── generic_utils.py ├── trainer_utils.py ├── model.py ├── callbacks.py └── io.py ├── requirements.test.txt ├── requirements.txt ├── requirements.dev.txt ├── setup.cfg ├── tests ├── __init__.py ├── test_train_mnist.py ├── test_train_batch_size_finder.py ├── utils │ ├── train_mnist.py │ └── mnist.py ├── test_continue_train.py ├── test_lr_schedulers.py ├── test_num_gpus.py └── test_train_gan.py ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature_request.md │ └── bug_report.yaml └── workflows │ ├── style_check.yml │ ├── tests.yml │ └── pypi-release.yml ├── MANIFEST.in ├── pyproject.toml ├── Makefile ├── bin └── collect_env_info.py ├── .gitignore ├── examples ├── train_mnist.py └── train_simple_gan.py ├── CONTRIBUTING.md ├── setup.py ├── CODE_OF_CONDUCT.md ├── README.md └── .pylintrc /trainer/VERSION: -------------------------------------------------------------------------------- 1 | v0.0.36 2 | -------------------------------------------------------------------------------- /trainer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.test.txt: -------------------------------------------------------------------------------- 1 | torchvision -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7 2 | coqpit 3 | psutil 4 | fsspec 5 | tensorboard 6 | soundfile 
def run_cli(command):
    """Run ``command`` through the system shell and fail on a non-zero exit status.

    Args:
        command: Command line string handed verbatim to ``os.system``.

    Raises:
        AssertionError: If ``os.system`` reports a status other than ``0``.
    """
    exit_status = os.system(command)
    # Raise explicitly rather than via ``assert``: assert statements are removed
    # when Python runs with ``-O`` / ``PYTHONOPTIMIZE``, which would silently
    # turn every failing command into a pass.
    if exit_status != 0:
        raise AssertionError(f" [!] command `{command}` failed.")
def ping_training_run():
    """Send a best-effort telemetry ping announcing a training run.

    Nothing is sent when the module-level ``telemetry`` value (read from the
    ``TRAINER_TELEMETRY`` environment variable) equals ``"0"``.

    Network failures are swallowed on purpose: the ping carries no information
    the training run depends on, so an offline or firewalled machine must not
    abort training because of it.
    """
    if telemetry == "0":
        return
    URL = "https://coqui.gateway.scarf.sh/trainer/training_run"
    try:
        _ = requests.get(URL, timeout=5)
    except requests.RequestException:
        # Covers connection errors, timeouts and other request-level failures.
        pass
13 | - Stochastic weight averaging 14 | - Deepspeed integration 15 | -------------------------------------------------------------------------------- /tests/test_train_mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from tests.utils.mnist import MnistModel, MnistModelConfig 6 | from trainer import Trainer, TrainerArgs 7 | 8 | is_cuda = torch.cuda.is_available() 9 | 10 | 11 | def test_train_mnist(): 12 | model = MnistModel() 13 | trainer = Trainer( 14 | TrainerArgs(), MnistModelConfig(), model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None 15 | ) 16 | 17 | trainer.fit() 18 | loss1 = trainer.keep_avg_train["avg_loss"] 19 | 20 | trainer.fit() 21 | loss2 = trainer.keep_avg_train["avg_loss"] 22 | 23 | assert loss1 > loss2 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | 4 | [flake8] 5 | max-line-length=120 6 | 7 | [tool.black] 8 | line-length = 120 9 | target-version = ['py38'] 10 | exclude = ''' 11 | 12 | ( 13 | /( 14 | \.eggs # exclude a few common directories in the 15 | | \.git # root of the project 16 | | \.hg 17 | | \.mypy_cache 18 | | \.tox 19 | | \.venv 20 | | _build 21 | | buck-out 22 | | build 23 | | dist 24 | )/ 25 | | foo.py # also separately exclude a file named foo.py in 26 | # the root of the project 27 | ) 28 | ''' 29 | 30 | [tool.isort] 31 | profile = "black" 32 | multi_line_output = 3 -------------------------------------------------------------------------------- /tests/test_train_batch_size_finder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from tests.utils.mnist import MnistModel, MnistModelConfig 6 | from trainer import Trainer, TrainerArgs 7 | 8 | is_cuda = 
torch.cuda.is_available() 9 | 10 | 11 | def test_train_largest_batch_mnist(): 12 | model = MnistModel() 13 | trainer = Trainer( 14 | TrainerArgs(), MnistModelConfig(), model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None 15 | ) 16 | 17 | trainer.fit_with_largest_batch_size(starting_batch_size=2048) 18 | loss1 = trainer.keep_avg_train["avg_loss"] 19 | 20 | trainer.fit_with_largest_batch_size(starting_batch_size=2048) 21 | loss2 = trainer.keep_avg_train["avg_loss"] 22 | 23 | assert loss1 > loss2 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🚀 Feature request 3 | about: Suggest a feature or an idea for this project 4 | title: '[Feature request] ' 5 | labels: feature request 6 | assignees: '' 7 | 8 | --- 9 | 11 | **🚀 Feature Description** 12 | 13 | 14 | 15 | **Solution** 16 | 17 | 18 | 19 | **Alternative Solutions** 20 | 21 | 22 | 23 | **Additional context** 24 | 25 | 26 | -------------------------------------------------------------------------------- /tests/utils/train_mnist.py: -------------------------------------------------------------------------------- 1 | from distutils.command.config import config 2 | 3 | from mnist import MnistModel, MnistModelConfig 4 | 5 | from trainer import Trainer, TrainerArgs 6 | 7 | 8 | def main(): 9 | """Run `MNIST` model training from scratch or from previous checkpoint.""" 10 | # init args and config 11 | train_args = TrainerArgs() 12 | config = MnistModelConfig() 13 | 14 | # init the model from config 15 | model = MnistModel() 16 | 17 | # init the trainer and 🚀 18 | trainer = Trainer( 19 | train_args, 20 | config, 21 | config.output_path, 22 | model=model, 23 | train_samples=model.get_data_loader(config, None, False, None, None, None), 24 | eval_samples=model.get_data_loader(config, None, True, None, None, None), 25 | 
parse_command_line_args=True, 26 | ) 27 | trainer.fit() 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := help 2 | .PHONY: test system-deps dev-deps deps style lint install help docs 3 | 4 | help: 5 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 6 | 7 | target_dirs := tests trainer 8 | 9 | test_all: ## run tests and don't stop on an error. 10 | coverage run -m pytest trainer tests 11 | 12 | test: ## run tests. 13 | coverage run -m pytest -x trainer tests 14 | 15 | test_failed: ## only run tests failed the last time. 16 | coverage run -m pytest --ff trainer tests 17 | 18 | style: ## update code style. 19 | black ${target_dirs} 20 | isort ${target_dirs} 21 | 22 | lint: ## run pylint linter. 23 | pylint ${target_dirs} 24 | 25 | dev-deps: ## install development deps 26 | pip install -r requirements.dev.txt 27 | 28 | doc-deps: ## install docs dependencies 29 | pip install -r docs/requirements.txt 30 | 31 | build-docs: ## build the docs 32 | cd docs && make clean && make build 33 | 34 | deps: ## install 🐸 requirements. 35 | pip install -r requirements.txt 36 | 37 | install: ## install 🐸 Trainer for development. 
38 | pip install -e .[all] 39 | 40 | docs: ## build the docs 41 | $(MAKE) -C docs clean && $(MAKE) -C docs html 42 | -------------------------------------------------------------------------------- /tests/test_continue_train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | 5 | from tests import run_cli 6 | 7 | 8 | def test_continue_train(): 9 | output_path = "output/" 10 | 11 | command_train = "python tests/utils/train_mnist.py" 12 | run_cli(command_train) 13 | 14 | continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) 15 | number_of_checkpoints = len(glob.glob(os.path.join(continue_path, "*.pth"))) 16 | 17 | # Continue training from the best model 18 | command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path} --coqpit.run_eval_steps=1" 19 | run_cli(command_continue) 20 | 21 | assert number_of_checkpoints < len(glob.glob(os.path.join(continue_path, "*.pth"))) 22 | 23 | # Continue training from the last checkpoint 24 | for best in glob.glob(os.path.join(continue_path, "best_model*")): 25 | os.remove(best) 26 | run_cli(command_continue) 27 | 28 | # Continue training from a specific checkpoint 29 | restore_path = os.path.join(continue_path, "checkpoint_5.pth") 30 | command_continue = f"python tests/utils/train_mnist.py --restore_path {restore_path}" 31 | run_cli(command_continue) 32 | shutil.rmtree(continue_path) 33 | -------------------------------------------------------------------------------- /bin/collect_env_info.py: -------------------------------------------------------------------------------- 1 | """Get detailed info about the working environment.""" 2 | import os 3 | import platform 4 | import sys 5 | 6 | import numpy 7 | import torch 8 | 9 | sys.path += [os.path.abspath(".."), os.path.abspath(".")] 10 | import json 11 | 12 | import trainer 13 | 14 | 15 | def system_info(): 16 | return { 17 | "OS": 
platform.system(), 18 | "architecture": platform.architecture(), 19 | "version": platform.version(), 20 | "processor": platform.processor(), 21 | "python": platform.python_version(), 22 | } 23 | 24 | 25 | def cuda_info(): 26 | return { 27 | "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())], 28 | "available": torch.cuda.is_available(), 29 | "version": torch.version.cuda, 30 | } 31 | 32 | 33 | def package_info(): 34 | return { 35 | "numpy": numpy.__version__, 36 | "PyTorch_version": torch.__version__, 37 | "PyTorch_debug": torch.version.debug, 38 | "Trainer": trainer.__version__, 39 | } 40 | 41 | 42 | def main(): 43 | details = {"System": system_info(), "CUDA": cuda_info(), "Packages": package_info()} 44 | print(json.dumps(details, indent=4, sort_keys=True)) 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /tests/test_lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch 5 | 6 | from tests.utils.mnist import MnistModel, MnistModelConfig 7 | from trainer import Trainer, TrainerArgs 8 | from trainer.generic_utils import KeepAverage 9 | 10 | is_cuda = torch.cuda.is_available() 11 | 12 | 13 | def test_train_mnist(): 14 | model = MnistModel() 15 | # Test StepwiseGradualLR 16 | config = MnistModelConfig( 17 | lr_scheduler="StepwiseGradualLR", 18 | lr_scheduler_params={ 19 | "gradual_learning_rates": [ 20 | [0, 1e-3], 21 | [2, 1e-4], 22 | ] 23 | }, 24 | scheduler_after_epoch=False, 25 | ) 26 | trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None) 27 | trainer.train_loader = trainer.get_train_dataloader( 28 | trainer.training_assets, 29 | trainer.train_samples, 30 | verbose=True, 31 | ) 32 | trainer.keep_avg_train = KeepAverage() 33 | 34 | lr_0 = trainer.scheduler.get_lr() 35 | 
trainer.train_step(next(iter(trainer.train_loader)), len(trainer.train_loader), 0, time.time()) 36 | lr_1 = trainer.scheduler.get_lr() 37 | trainer.train_step(next(iter(trainer.train_loader)), len(trainer.train_loader), 1, time.time()) 38 | lr_2 = trainer.scheduler.get_lr() 39 | assert lr_0 == 1e-3 40 | assert lr_1 == 1e-3 41 | assert lr_2 == 1e-4 42 | -------------------------------------------------------------------------------- /.github/workflows/style_check.yml: -------------------------------------------------------------------------------- 1 | name: style-check 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | types: [opened, synchronize, reopened] 9 | jobs: 10 | check_skip: 11 | runs-on: ubuntu-latest 12 | if: "! contains(github.event.head_commit.message, '[ci skip]')" 13 | steps: 14 | - run: echo "${{ github.event.head_commit.message }}" 15 | 16 | test: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | python-version: [3.9] 22 | experimental: [false] 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v4 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | architecture: x64 30 | cache: 'pip' 31 | cache-dependency-path: 'requirements*' 32 | - name: check OS 33 | run: cat /etc/os-release 34 | - name: Install dependencies 35 | run: | 36 | sudo apt-get update 37 | sudo apt-get install -y git make gcc 38 | make system-deps 39 | - name: Install/upgrade Python setup deps 40 | run: python3 -m pip install --upgrade pip setuptools wheel 41 | - name: Install Trainer 42 | run: | 43 | python3 -m pip install .[all] 44 | python3 setup.py egg_info 45 | - name: Lint check 46 | run: | 47 | make lint -------------------------------------------------------------------------------- /trainer/utils/cpu_memory.py: -------------------------------------------------------------------------------- 1 | def get_available_cpu_memory(): 2 | 
def is_out_of_cpu_memory(exception):
    """Tell whether ``exception`` is the CPU allocator's out-of-memory error.

    Args:
        exception: Any exception instance.

    Returns:
        ``True`` only for a ``RuntimeError`` carrying exactly one argument that
        is a string containing ``"DefaultCPUAllocator: can't allocate memory"``;
        ``False`` otherwise.
    """
    if not isinstance(exception, RuntimeError) or len(exception.args) != 1:
        return False
    message = exception.args[0]
    # A RuntimeError may carry a non-string payload (e.g. ``RuntimeError(5)``);
    # running the substring test on it would raise TypeError instead of
    # answering the question, so require a string first.
    return isinstance(message, str) and "DefaultCPUAllocator: can't allocate memory" in message
def get_rank():
    """Return the process rank read from the environment, or ``0`` if none is set.

    The variables ``RANK``, ``LOCAL_RANK``, ``SLURM_PROCID`` and
    ``JSM_NAMESPACE_RANK`` are checked in that order; the first one holding a
    value wins.

    Returns:
        The rank as an ``int``; ``0`` when no variable provides one.

    Raises:
        ValueError: If the selected variable holds a non-integer value.
    """
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        # An exported-but-empty variable carries no rank. Treat it as unset
        # (as is done for an empty CUDA_VISIBLE_DEVICES) instead of crashing
        # in ``int("")``.
        if rank is not None and rank.strip():
            return int(rank)
    return 0
def reduce_tensor(tensor, num_gpus):
    """Sum ``tensor`` across all processes and divide the result by ``num_gpus``.

    The input is cloned first, so the caller's tensor is left untouched.

    Args:
        tensor: Tensor to reduce.
        num_gpus: Divisor applied to the summed tensor.

    Returns:
        A new tensor holding the all-reduced sum divided by ``num_gpus``.
    """
    rt = tensor.clone()
    # ``dist.reduce_op`` is a deprecated alias; ``dist.ReduceOp`` is the
    # supported enum for reduction operations.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= num_gpus
    return rt
62 | torch.cuda.set_device(rank % torch.cuda.device_count()) 63 | 64 | # Initialize distributed communication 65 | dist.init_process_group( 66 | dist_backend, 67 | init_method=dist_url, 68 | world_size=num_gpus, 69 | rank=rank, 70 | group_name=group_name, 71 | ) 72 | -------------------------------------------------------------------------------- /tests/utils/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.utils.data import DataLoader 8 | from torchvision import transforms 9 | from torchvision.datasets import MNIST 10 | 11 | from trainer import TrainerConfig, TrainerModel 12 | 13 | 14 | @dataclass 15 | class MnistModelConfig(TrainerConfig): 16 | optimizer: str = "Adam" 17 | lr: float = 0.001 18 | epochs: int = 1 19 | print_step: int = 1 20 | save_step: int = 5 21 | plot_step: int = 5 22 | dashboard_logger: str = "tensorboard" 23 | 24 | 25 | class MnistModel(TrainerModel): 26 | def __init__(self): 27 | super().__init__() 28 | 29 | # mnist images are (1, 28, 28) (channels, height, width) 30 | self.layer_1 = nn.Linear(28 * 28, 128) 31 | self.layer_2 = nn.Linear(128, 256) 32 | self.layer_3 = nn.Linear(256, 10) 33 | 34 | def forward(self, x): 35 | batch_size, _, _, _ = x.size() 36 | 37 | # (b, 1, 28, 28) -> (b, 1*28*28) 38 | x = x.view(batch_size, -1) 39 | x = self.layer_1(x) 40 | x = F.relu(x) 41 | x = self.layer_2(x) 42 | x = F.relu(x) 43 | x = self.layer_3(x) 44 | 45 | x = F.log_softmax(x, dim=1) 46 | return x 47 | 48 | def train_step(self, batch, criterion): 49 | x, y = batch 50 | logits = self(x) 51 | loss = criterion(logits, y) 52 | return {"model_outputs": logits}, {"loss": loss} 53 | 54 | def eval_step(self, batch, criterion): 55 | x, y = batch 56 | logits = self(x) 57 | loss = criterion(logits, y) 58 | return {"model_outputs": logits}, {"loss": loss} 59 | 60 | 
@staticmethod 61 | def get_criterion(): 62 | return torch.nn.NLLLoss() 63 | 64 | def get_data_loader( 65 | self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 66 | ): # pylint: disable=unused-argument 67 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 68 | dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) 69 | dataset.data = dataset.data[:256] 70 | dataset.targets = dataset.targets[:256] 71 | dataloader = DataLoader(dataset, batch_size=config.batch_size) 72 | return dataloader 73 | -------------------------------------------------------------------------------- /trainer/logging/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from trainer.logging.console_logger import ConsoleLogger 5 | from trainer.logging.dummy_logger import DummyLogger 6 | 7 | # pylint: disable=import-outside-toplevel 8 | 9 | 10 | logger = logging.getLogger("trainer") 11 | 12 | 13 | def get_mlflow_tracking_url(): 14 | if "MLFLOW_TRACKING_URI" in os.environ: 15 | return os.environ["MLFLOW_TRACKING_URI"] 16 | return None 17 | 18 | 19 | def get_ai_repo_url(): 20 | if "AIM_TRACKING_URI" in os.environ: 21 | return os.environ["AIM_TRACKING_URI"] 22 | return None 23 | 24 | 25 | def logger_factory(config, output_path): 26 | run_name = config.run_name 27 | project_name = config.project_name 28 | log_uri = config.logger_uri if config.logger_uri else output_path 29 | 30 | if config.dashboard_logger == "tensorboard": 31 | from trainer.logging.tensorboard_logger import TensorboardLogger 32 | 33 | model_name = f"{project_name}@{run_name}" if project_name else run_name 34 | dashboard_logger = TensorboardLogger(log_uri, model_name=model_name) 35 | 36 | logger.info(" > Start Tensorboard: tensorboard --logdir=%s", log_uri) 37 | 38 | elif config.dashboard_logger == "wandb": 39 | from trainer.logging.wandb_logger import WandbLogger 
40 | 41 | dashboard_logger = WandbLogger( # pylint: disable=abstract-class-instantiated 42 | project=project_name, 43 | name=run_name, 44 | config=config, 45 | entity=config.wandb_entity, 46 | ) 47 | 48 | elif config.dashboard_logger == "mlflow": 49 | from trainer.logging.mlflow_logger import MLFlowLogger 50 | 51 | dashboard_logger = MLFlowLogger(log_uri=log_uri, model_name=project_name) 52 | 53 | elif config.dashboard_logger == "aim": 54 | from trainer.logging.aim_logger import AimLogger 55 | 56 | dashboard_logger = AimLogger(repo=log_uri, model_name=project_name) 57 | 58 | elif config.dashboard_logger == "clearml": 59 | from trainer.logging.clearml_logger import ClearMLLogger 60 | 61 | dashboard_logger = ClearMLLogger( 62 | output_uri=log_uri, local_path=output_path, project_name=project_name, task_name=run_name 63 | ) 64 | 65 | else: 66 | raise ValueError(f"Unknown dashboard logger: {config.dashboard_logger}") 67 | 68 | return dashboard_logger 69 | -------------------------------------------------------------------------------- /trainer/distribute.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import pathlib 6 | import subprocess 7 | import time 8 | 9 | from trainer import TrainerArgs, logger 10 | 11 | 12 | def distribute(): 13 | """ 14 | Call 👟Trainer training script in DDP mode. 15 | """ 16 | parser = TrainerArgs().init_argparse(arg_prefix="") 17 | parser.add_argument("--script", type=str, help="Target training script to distibute.") 18 | parser.add_argument( 19 | "--gpus", 20 | type=str, 21 | help='GPU IDs to be used for distributed training in the format ```"0,1"```. 
def get_gpus(args):
    """Resolve the GPU IDs to use for distributed training.

    A non-empty ``CUDA_VISIBLE_DEVICES`` environment variable takes precedence;
    otherwise ``args.gpus`` is used. The chosen value is split on commas and
    each ID is stripped of surrounding whitespace.

    Args:
        args: Parsed arguments; ``args.gpus`` (if present) is a comma-separated
            string such as ``"0,1"``.

    Returns:
        List of GPU ID strings, e.g. ``["0", "1"]``.

    Raises:
        ValueError: If neither source yields at least one GPU ID.
    """
    # set active gpus from CUDA_VISIBLE_DEVICES or --gpus
    gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    if gpus == "":
        # ``args`` may be a bare namespace without ``gpus``, or ``--gpus`` may
        # have been left at its ``None`` default.
        gpus = getattr(args, "gpus", None)
    # Drop empty entries so a trailing comma ("0,1,") does not produce a
    # process bound to a GPU with an empty ID.
    gpu_ids = [gpu_id.strip() for gpu_id in (gpus or "").split(",") if gpu_id.strip()]
    if not gpu_ids:
        raise ValueError(" [!] No GPU IDs given. Set `CUDA_VISIBLE_DEVICES` or pass `--gpus`.")
    return gpu_ids
from unittest import TestCase, mock 5 | 6 | from trainer import TrainerArgs 7 | from trainer.distribute import get_gpus 8 | 9 | 10 | class TestGpusStringParsingMethods(TestCase): 11 | @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) 12 | def test_parse_gpus_set_in_env_var_and_args(self): 13 | args = Namespace(gpus="0,1") 14 | gpus = get_gpus(args) 15 | expected_value = ["0"] 16 | self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value)) 17 | 18 | @mock.patch.dict(os.environ, {}) 19 | def test_parse_gpus_set_in_args(self): 20 | _old = None 21 | # this is to handle the case when CUDA_VISIBLE_DEVICES is set while running the tests 22 | if "CUDA_VISIBLE_DEVICES" in os.environ: 23 | _old = os.environ["CUDA_VISIBLE_DEVICES"] 24 | del os.environ["CUDA_VISIBLE_DEVICES"] 25 | args = Namespace(gpus="0,1") 26 | gpus = get_gpus(args) 27 | expected_value = ["0", "1"] 28 | self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value)) 29 | if _old is not None: 30 | os.environ["CUDA_VISIBLE_DEVICES"] = _old 31 | 32 | @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) 33 | def test_parse_gpus_set_in_env_var(self): 34 | args = Namespace() 35 | gpus = get_gpus(args) 36 | expected_value = ["0", "1"] 37 | self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value)) 38 | 39 | @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0, 1 "}) 40 | def test_parse_gpus_set_in_env_var_with_spaces(self): 41 | args = Namespace() 42 | gpus = get_gpus(args) 43 | expected_value = ["0", "1"] 44 | self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value)) 45 | 46 | @mock.patch.dict(os.environ, {}) 47 | def test_parse_gpus_set_in_args_with_spaces(self): 48 | args = Namespace(gpus="0, 1, 2, 3 ") 49 | gpus = get_gpus(args) 50 | expected_value = ["0", "1", "2", "3"] 51 | self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value)) 52 | 53 | 54 | def msg_for_test_failure(expected_value): 55 | return 
"GPU Values are expected to be " + str(expected_value) 56 | 57 | 58 | def create_args_parser(): 59 | parser = TrainerArgs().init_argparse(arg_prefix="") 60 | parser.add_argument("--gpus", type=str) 61 | return parser 62 | 63 | 64 | if __name__ == "__main__": 65 | unittest.main() 66 | -------------------------------------------------------------------------------- /trainer/logging/dummy_logger.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | from trainer.logging.base_dash_logger import BaseDashboardLogger 4 | 5 | 6 | class DummyLogger(BaseDashboardLogger): 7 | """DummyLogger that implements the API but does nothing""" 8 | 9 | def add_scalar(self, title: str, value: float, step: int) -> None: 10 | pass 11 | 12 | def add_figure( 13 | self, 14 | title: str, 15 | figure: Union["matplotlib.figure.Figure", "plotly.graph_objects.Figure"], 16 | step: int, 17 | ) -> None: 18 | pass 19 | 20 | def add_config(self, config): 21 | pass 22 | 23 | def add_audio(self, title: str, audio: "np.ndarray", step: int, sample_rate: int) -> None: 24 | pass 25 | 26 | def add_text(self, title: str, text: str, step: int) -> None: 27 | pass 28 | 29 | def add_artifact(self, file_or_dir: str, name: str, artifact_type: str, aliases=None): 30 | pass 31 | 32 | def add_scalars(self, scope_name: str, scalars: Dict, step: int): 33 | pass 34 | 35 | def add_figures(self, scope_name: str, figures: Dict, step: int): 36 | pass 37 | 38 | def add_audios(self, scope_name: str, audios: Dict, step: int, sample_rate: int): 39 | pass 40 | 41 | def flush(self): 42 | pass 43 | 44 | def finish(self): 45 | pass 46 | 47 | def train_step_stats(self, step, stats): 48 | self.add_scalars(scope_name="TrainIterStats", scalars=stats, step=step) 49 | 50 | def train_epoch_stats(self, step, stats): 51 | self.add_scalars(scope_name="TrainEpochStats", scalars=stats, step=step) 52 | 53 | def train_figures(self, step, figures): 54 | 
self.add_figures(scope_name="TrainFigures", figures=figures, step=step) 55 | 56 | def train_audios(self, step, audios, sample_rate): 57 | self.add_audios(scope_name="TrainAudios", audios=audios, step=step, sample_rate=sample_rate) 58 | 59 | def eval_stats(self, step, stats): 60 | self.add_scalars(scope_name="EvalStats", scalars=stats, step=step) 61 | 62 | def eval_figures(self, step, figures): 63 | self.add_figures(scope_name="EvalFigures", figures=figures, step=step) 64 | 65 | def eval_audios(self, step, audios, sample_rate): 66 | self.add_audios(scope_name="EvalAudios", audios=audios, step=step, sample_rate=sample_rate) 67 | 68 | def test_audios(self, step, audios, sample_rate): 69 | self.add_audios(scope_name="TestAudios", audios=audios, step=step, sample_rate=sample_rate) 70 | 71 | def test_figures(self, step, figures): 72 | self.add_figures(scope_name="TestFigures", figures=figures, step=step) 73 | -------------------------------------------------------------------------------- /trainer/logging/clearml_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | import torch 5 | 6 | from trainer.logging.tensorboard_logger import TensorboardLogger 7 | from trainer.trainer_utils import is_clearml_available 8 | from trainer.utils.distributed import rank_zero_only 9 | 10 | if is_clearml_available(): 11 | from clearml import Task # pylint: disable=import-error 12 | else: 13 | raise ImportError("ClearML is not installed. Please install it with `pip install clearml`") 14 | 15 | 16 | class ClearMLLogger(TensorboardLogger): 17 | """ClearML Logger using TensorBoard in the background. 18 | 19 | TODO: 20 | - Add hyperparameter handling 21 | - Use ClearML logger for plots 22 | - Handle continuing training 23 | 24 | Args: 25 | output_uri (str): URI of the ClearML repository. 26 | local_path (str): Path to the local directory where the model is saved. 
27 | project_name (str): Name of the ClearML project. 28 | task_name (str): Name of the ClearML task. 29 | tags (str): Comma separated list of tags to add to the ClearML task. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | output_uri: str, 35 | local_path: str, 36 | project_name: str, 37 | task_name: str, 38 | tags: str = None, 39 | ): 40 | self._context = None 41 | self.local_path = local_path 42 | self.task_name = task_name 43 | self.tags = tags.split(",") if tags else [] 44 | self.run = Task.init(project_name=project_name, task_name=task_name, tags=self.tags, output_uri=output_uri) 45 | 46 | if tags: 47 | for tag in tags.split(","): 48 | self.run.add_tag(tag) 49 | 50 | super().__init__("run", None) 51 | 52 | @rank_zero_only 53 | def add_config(self, config): 54 | """Upload config file(s) to ClearML.""" 55 | self.add_text("run_config", f"{config.to_json()}", 0) 56 | self.run.connect_configuration(name="model_config", configuration=config.to_dict()) 57 | self.run.set_comment(config.run_description) 58 | self.run.upload_artifact("model_config", config.to_dict()) 59 | self.run.upload_artifact("configs", artifact_object=os.path.join(self.local_path, "*.json")) 60 | 61 | @rank_zero_only 62 | def add_artifact(self, file_or_dir, name, **kwargs): # pylint: disable=unused-argument, arguments-differ 63 | """Upload artifact to ClearML.""" 64 | self.run.upload_artifact(name, artifact_object=file_or_dir) 65 | 66 | @staticmethod 67 | @rank_zero_only 68 | def save_model(state: Any, path: str): 69 | torch.save(state, path) 70 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: "🐛 Bug report" 2 | description: Create a bug report to help 👟 improve 3 | title: '[Bug] ' 4 | labels: [ "bug" ] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Welcome to the 👟! 
Thanks for taking the time to fill out this bug report! 10 | 11 | - type: textarea 12 | id: bug-description 13 | attributes: 14 | label: Describe the bug 15 | description: A clear and concise description of what the bug is. If you intend to submit a PR for this issue, tell us in the description. Thanks! 16 | placeholder: Bug description 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: reproduction 22 | attributes: 23 | label: To Reproduce 24 | description: | 25 | Please share your code to reproduce the error. 26 | 27 | Issues are fixed faster if you can provide a working example. 28 | 29 | The best place for sharing code is colab. https://colab.research.google.com/ 30 | So we can directly run your code and reproduce the issue. 31 | 32 | In the worst case, provide steps to reproduce the behavior. 33 | 34 | 1. Run the following command '...' 35 | 2. ... 36 | 3. See error 37 | placeholder: Reproduction 38 | validations: 39 | required: true 40 | 41 | - type: textarea 42 | id: expected-behavior 43 | attributes: 44 | label: Expected behavior 45 | description: "Write down the expected behaviour" 46 | 47 | - type: textarea 48 | id: logs 49 | attributes: 50 | label: Logs 51 | description: "Please include the relevant logs if you can." 52 | render: shell 53 | 54 | - type: textarea 55 | id: system-info 56 | attributes: 57 | label: Environment 58 | description: | 59 | You can either run `trainer/bin/collect_env_info.py` 60 | 61 | ```bash 62 | wget https://raw.githubusercontent.com/coqui-ai/Trainer/main/bin/collect_env_info.py 63 | python collect_env_info.py 64 | ``` 65 | 66 | or fill in the fields below manually.
67 | render: shell 68 | placeholder: | 69 | - 👟 Version (e.g., 1.3.0): 70 | - PyTorch Version (e.g., 1.8) 71 | - Python version: 72 | - OS (e.g., Linux): 73 | - CUDA/cuDNN version: 74 | - GPU models and configuration: 75 | - How you installed PyTorch (`conda`, `pip`, source): 76 | - Any other relevant information: 77 | validations: 78 | required: true 79 | - type: textarea 80 | id: context 81 | attributes: 82 | label: Additional context 83 | description: Add any other context about the problem here. 84 | validations: 85 | required: false 86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .noseids 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ 140 | 141 | # custom list 142 | MNIST/ 143 | tests_local/ 144 | output/ 145 | -------------------------------------------------------------------------------- /.github/workflows/pypi-release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI 2 | on: 3 | release: 4 | types: [published] 5 | defaults: 6 | run: 7 | shell: 8 | bash 9 | jobs: 10 | build-sdist: 11 | runs-on: ubuntu-20.04 12 | steps: 13 | - uses: actions/checkout@v3 14 | - name: Verify tag matches version 15 | run: | 16 | set -ex 17 | version=$(cat trainer/VERSION) 18 | tag="${GITHUB_REF/refs\/tags\/}" 19 | if [[ "$version" != "$tag" ]]; then 20 | exit 1 21 | fi 22 | - uses: actions/checkout@v3 23 | with: 24 | python-version: 3.9 25 | - run: | 26 | python -m pip install -U pip setuptools wheel build 27 | - run: | 28 | python -m build 29 | - run: | 30 | pip install dist/*.tar.gz 31 | - uses: actions/upload-artifact@v2 32 | with: 33 | name: sdist 34 | path: dist/*.tar.gz 35 | build-wheels: 36 | runs-on: ubuntu-20.04 37 | strategy: 38 | matrix: 39 | python-version: ["3.8", "3.9", "3.10", "3.11"] 40 | steps: 41 | - uses: actions/checkout@v3 42 | - uses: actions/setup-python@v4 43 | with: 44 | python-version: ${{ matrix.python-version }} 
45 | - run: | 46 | python -m pip install -U pip setuptools wheel build 47 | - run: | 48 | python -m build 49 | - run: | 50 | python -m pip install dist/*.whl 51 | - uses: actions/upload-artifact@v2 52 | with: 53 | name: wheel-${{ matrix.python-version }} 54 | path: dist/*.whl 55 | publish-artifacts: 56 | runs-on: ubuntu-20.04 57 | needs: [build-sdist, build-wheels] 58 | steps: 59 | - run: | 60 | mkdir dist 61 | - uses: actions/download-artifact@v2 62 | with: 63 | name: "sdist" 64 | path: "dist/" 65 | - uses: actions/download-artifact@v2 66 | with: 67 | name: "wheel-3.8" 68 | path: "dist/" 69 | - uses: actions/download-artifact@v2 70 | with: 71 | name: "wheel-3.9" 72 | path: "dist/" 73 | - uses: actions/download-artifact@v2 74 | with: 75 | name: "wheel-3.10" 76 | path: "dist/" 77 | - uses: actions/download-artifact@v2 78 | with: 79 | name: "wheel-3.11" 80 | path: "dist/" 81 | - run: | 82 | ls -lh dist/ 83 | - name: Setup PyPI config 84 | run: | 85 | cat << EOF > ~/.pypirc 86 | [pypi] 87 | username=__token__ 88 | password=${{ secrets.PYPI_TOKEN }} 89 | EOF 90 | - uses: actions/setup-python@v2 91 | with: 92 | python-version: 3.9 93 | - run: | 94 | python -m pip install twine 95 | - run: | 96 | twine upload --repository pypi dist/* 97 | -------------------------------------------------------------------------------- /trainer/utils/cuda_memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | credit: https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py 3 | 4 | Helper to free Torch cuda memory and determine when a Torch exception might be 5 | because of OOM conditions. 
6 | """ 7 | from __future__ import print_function 8 | 9 | import gc 10 | 11 | import torch 12 | 13 | from trainer.utils.cpu_memory import is_out_of_cpu_memory 14 | 15 | 16 | def gc_cuda(): 17 | """Gargage collect Torch (CUDA) memory.""" 18 | gc.collect() 19 | if torch.cuda.is_available(): 20 | torch.cuda.empty_cache() 21 | 22 | 23 | def get_cuda_total_memory(): 24 | if torch.cuda.is_available(): 25 | return torch.cuda.get_device_properties(0).total_memory 26 | return 0 27 | 28 | 29 | def get_cuda_assumed_available_memory(): 30 | if torch.cuda.is_available(): 31 | return get_cuda_total_memory() - torch.cuda.memory_reserved() 32 | return 0 33 | 34 | 35 | def get_cuda_available_memory(): 36 | # Always allow for 1 GB overhead. 37 | if torch.cuda.is_available(): 38 | return get_cuda_assumed_available_memory() - get_cuda_blocked_memory() 39 | return 0 40 | 41 | 42 | def get_cuda_blocked_memory(): 43 | if not torch.cuda.is_available(): 44 | return 0 45 | 46 | available_memory = get_cuda_assumed_available_memory() 47 | current_block = available_memory - 2**28 # 256 MB steps 48 | while True: 49 | try: 50 | _ = torch.empty((current_block,), dtype=torch.uint8, device="cuda") 51 | break 52 | except RuntimeError as exception: 53 | if is_cuda_out_of_memory(exception): 54 | current_block -= 2**30 55 | if current_block <= 0: 56 | return available_memory 57 | else: 58 | raise 59 | _ = None 60 | gc_cuda() 61 | return available_memory - current_block 62 | 63 | 64 | def is_cuda_out_of_memory(exception): 65 | return ( 66 | isinstance(exception, (RuntimeError, torch.cuda.OutOfMemoryError)) 67 | and len(exception.args) == 1 68 | and "CUDA out of memory." in exception.args[0] 69 | ) 70 | 71 | 72 | def is_cudnn_snafu(exception): 73 | # For/because of https://github.com/pytorch/pytorch/issues/4107 74 | return ( 75 | isinstance(exception, RuntimeError) 76 | and len(exception.args) == 1 77 | and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." 
in exception.args[0] 78 | ) 79 | 80 | 81 | def cuda_meminfo(): 82 | if not torch.cuda.is_available(): 83 | return 84 | 85 | print( 86 | "Total:", torch.cuda.memory_allocated() / 2**30, " GB Cached: ", torch.cuda.memory_reserved() / 2**30, "GB" 87 | ) 88 | print( 89 | "Max Total:", 90 | torch.cuda.max_memory_allocated() / 2**30, 91 | " GB Max Cached: ", 92 | torch.cuda.max_memory_reserved() / 2**30, 93 | "GB", 94 | ) 95 | 96 | 97 | def should_reduce_batch_size(exception): 98 | return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception) 99 | -------------------------------------------------------------------------------- /trainer/logging/base_dash_logger.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Union 3 | 4 | from trainer.io import save_fsspec 5 | from trainer.utils.distributed import rank_zero_only 6 | 7 | 8 | # pylint: disable=too-many-public-methods 9 | class BaseDashboardLogger(ABC): 10 | @abstractmethod 11 | def add_scalar(self, title: str, value: float, step: int) -> None: 12 | pass 13 | 14 | @abstractmethod 15 | def add_figure( 16 | self, 17 | title: str, 18 | figure: Union["matplotlib.figure.Figure", "plotly.graph_objects.Figure"], 19 | step: int, 20 | ) -> None: 21 | pass 22 | 23 | @abstractmethod 24 | def add_config(self, config): 25 | pass 26 | 27 | @abstractmethod 28 | def add_audio(self, title: str, audio: "np.ndarray", step: int, sample_rate: int) -> None: 29 | pass 30 | 31 | @abstractmethod 32 | def add_text(self, title: str, text: str, step: int) -> None: 33 | pass 34 | 35 | @abstractmethod 36 | def add_artifact(self, file_or_dir: str, name: str, artifact_type: str, aliases=None): 37 | pass 38 | 39 | @abstractmethod 40 | def add_scalars(self, scope_name: str, scalars: Dict, step: int): 41 | pass 42 | 43 | @abstractmethod 44 | def add_figures(self, scope_name: str, figures: Dict, step: int): 45 | 
pass 46 | 47 | @abstractmethod 48 | def add_audios(self, scope_name: str, audios: Dict, step: int, sample_rate: int): 49 | pass 50 | 51 | @abstractmethod 52 | def flush(self): 53 | pass 54 | 55 | @abstractmethod 56 | def finish(self): 57 | pass 58 | 59 | @staticmethod 60 | @rank_zero_only 61 | def save_model(state: Dict, path: str): 62 | save_fsspec(state, path) 63 | 64 | def train_step_stats(self, step, stats): 65 | self.add_scalars(scope_name="TrainIterStats", scalars=stats, step=step) 66 | 67 | def train_epoch_stats(self, step, stats): 68 | self.add_scalars(scope_name="TrainEpochStats", scalars=stats, step=step) 69 | 70 | def train_figures(self, step, figures): 71 | self.add_figures(scope_name="TrainFigures", figures=figures, step=step) 72 | 73 | def train_audios(self, step, audios, sample_rate): 74 | self.add_audios(scope_name="TrainAudios", audios=audios, step=step, sample_rate=sample_rate) 75 | 76 | def eval_stats(self, step, stats): 77 | self.add_scalars(scope_name="EvalStats", scalars=stats, step=step) 78 | 79 | def eval_figures(self, step, figures): 80 | self.add_figures(scope_name="EvalFigures", figures=figures, step=step) 81 | 82 | def eval_audios(self, step, audios, sample_rate): 83 | self.add_audios(scope_name="EvalAudios", audios=audios, step=step, sample_rate=sample_rate) 84 | 85 | def test_audios(self, step, audios, sample_rate): 86 | self.add_audios(scope_name="TestAudios", audios=audios, step=step, sample_rate=sample_rate) 87 | 88 | def test_figures(self, step, figures): 89 | self.add_figures(scope_name="TestFigures", figures=figures, step=step) 90 | -------------------------------------------------------------------------------- /trainer/logging/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | from torch.utils.tensorboard import SummaryWriter 4 | 5 | from trainer.logging.base_dash_logger import BaseDashboardLogger 6 | 7 | 8 | class 
TensorboardLogger(BaseDashboardLogger): 9 | def __init__(self, log_dir, model_name): 10 | self.model_name = model_name 11 | self.writer = SummaryWriter(log_dir) 12 | 13 | def model_weights(self, model, step): 14 | layer_num = 1 15 | for name, param in model.named_parameters(): 16 | if param.numel() == 1: 17 | self.writer.add_scalar("layer{}-{}/value".format(layer_num, name), param.max(), step) 18 | else: 19 | self.writer.add_scalar("layer{}-{}/max".format(layer_num, name), param.max(), step) 20 | self.writer.add_scalar("layer{}-{}/min".format(layer_num, name), param.min(), step) 21 | self.writer.add_scalar("layer{}-{}/mean".format(layer_num, name), param.mean(), step) 22 | self.writer.add_scalar("layer{}-{}/std".format(layer_num, name), param.std(), step) 23 | self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step) 24 | self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step) 25 | layer_num += 1 26 | 27 | def add_config(self, config): 28 | self.add_text("model-config", f"
{config.to_json()}", 0)
29 |
def add_scalar(self, title: str, value: float, step: int) -> None:
    """Forward one scalar to the underlying writer.

    Args:
        title: Tag handed to the writer.
        value: Scalar handed to the writer.
        step: Step handed to the writer.
    """
    writer = self.writer
    writer.add_scalar(title, value, step)
32 |
def add_audio(self, title, audio, step, sample_rate):
    """Forward one audio clip to the underlying writer.

    ``title``, ``audio`` and ``step`` are passed positionally; ``sample_rate``
    is passed by keyword.
    """
    writer = self.writer
    writer.add_audio(title, audio, step, sample_rate=sample_rate)
35 |
def add_text(self, title, text, step):
    """Forward one text entry (``title``, ``text``, ``step``) to the underlying writer."""
    writer = self.writer
    writer.add_text(title, text, step)
38 |
def add_figure(self, title, figure, step):
    """Forward one figure (``title``, ``figure``, ``step``) to the underlying writer."""
    writer = self.writer
    writer.add_figure(title, figure, step)
41 |
def add_artifact(self, file_or_dir, name, artifact_type, aliases=None):  # pylint: disable=W0613
    """Accept an artifact request without storing anything.

    This logger does nothing with artifacts, so the call is an explicit no-op.
    A bare ``yield`` in the body would make this a generator function, and
    every call would build a generator object that nobody consumes.

    Args:
        file_or_dir: Ignored.
        name: Ignored.
        artifact_type: Ignored.
        aliases: Ignored.

    Returns:
        None.
    """
    return None
44 |
def add_scalars(self, scope_name, scalars, step):
    """Log every entry of ``scalars`` through ``self.add_scalar``.

    Each value is tagged ``"<scope_name>/<key>"`` and logged at ``step``.
    """
    for stat_name, stat_value in scalars.items():
        tag = "{}/{}".format(scope_name, stat_name)
        self.add_scalar(tag, stat_value, step)
48 |
def add_figures(self, scope_name, figures, step):
    """Send every entry of ``figures`` to the underlying writer.

    Each figure is tagged ``"<scope_name>/<key>"``. The writer is called
    directly rather than through ``self.add_figure``.
    """
    writer = self.writer
    for fig_name, fig in figures.items():
        tag = "{}/{}".format(scope_name, fig_name)
        writer.add_figure(tag, fig, step)
52 |
def add_audios(self, scope_name, audios, step, sample_rate):
    """Log every entry of ``audios`` through ``self.add_audio``.

    Clips whose ``dtype`` compares equal to ``"float16"`` are converted with
    ``astype("float32")`` first. A ``RuntimeError`` raised while logging one
    clip is printed with ``traceback.print_exc()`` and the loop continues.
    """
    for clip_name, clip in audios.items():
        needs_upcast = clip.dtype == "float16"
        if needs_upcast:
            clip = clip.astype("float32")
        try:
            self.add_audio("{}/{}".format(scope_name, clip_name), clip, step, sample_rate=sample_rate)
        except RuntimeError:
            # Best effort: one failing clip must not stop the remaining ones.
            traceback.print_exc()
66 |
def flush(self):
    """Call ``flush()`` on the underlying writer."""
    writer = self.writer
    writer.flush()
69 |
def finish(self):
    """Call ``close()`` on the underlying writer."""
    writer = self.writer
    writer.close()
72 |
--------------------------------------------------------------------------------
/examples/train_mnist.py:
--------------------------------------------------------------------------------
1 | """
2 | This example shows training of a simple Conv model with MNIST dataset using Auto Training mode of 👟.
3 | """
4 |
5 | import os
6 | from dataclasses import dataclass
7 |
8 | import torch
9 | from torch import nn
10 | from torch.nn import functional as F
11 | from torch.utils.data import DataLoader
12 | from torchvision import transforms
13 | from torchvision.datasets import MNIST
14 |
15 | from trainer import TrainerConfig, TrainerModel, Trainer, TrainerArgs
16 |
17 |
@dataclass
class MnistModelConfig(TrainerConfig):
    """Configuration values used by the MNIST example.

    All fields are plain dataclass defaults declared on top of ``TrainerConfig``.
    """

    # Optimizer name and learning rate.
    optimizer: str = "Adam"
    lr: float = 0.001
    # A single epoch keeps the example short.
    epochs: int = 1
    # Step intervals; NOTE(review): presumably counted in training steps, confirm against TrainerConfig.
    print_step: int = 1
    save_step: int = 5
    plot_step: int = 5
    # Dashboard backend selected by name.
    dashboard_logger: str = "tensorboard"
27 |
28 |
class MnistModel(TrainerModel):
    """Fully connected MNIST classifier: 784 -> 128 -> 256 -> 10 with ReLU in between."""

    def __init__(self):
        super().__init__()

        # An MNIST image is (1, 28, 28) = (channels, height, width); it is
        # flattened to 28 * 28 features before the first linear layer.
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return the log-softmax over the 10 outputs for a 4-D input batch ``x``."""
        # Unpacking four sizes also rejects inputs that are not 4-D.
        batch_size, _, _, _ = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        flattened = x.view(batch_size, -1)
        hidden = F.relu(self.layer_1(flattened))
        hidden = F.relu(self.layer_2(hidden))
        return F.log_softmax(self.layer_3(hidden), dim=1)

    def train_step(self, batch, criterion):
        """Run the model on ``batch`` and return ``(outputs_dict, loss_dict)``."""
        inputs, targets = batch
        log_probs = self(inputs)
        loss_value = criterion(log_probs, targets)
        return {"model_outputs": log_probs}, {"loss": loss_value}

    def eval_step(self, batch, criterion):
        """Same computation as ``train_step``, returned in the same format."""
        inputs, targets = batch
        log_probs = self(inputs)
        loss_value = criterion(log_probs, targets)
        return {"model_outputs": log_probs}, {"loss": loss_value}

    @staticmethod
    def get_criterion():
        """Return the loss module: ``NLLLoss``, matching the log-softmax output of ``forward``."""
        return torch.nn.NLLLoss()

    def get_data_loader(
        self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
    ):  # pylint: disable=unused-argument
        """Build a ``DataLoader`` over the first 256 MNIST samples.

        Only ``config.batch_size`` and ``is_eval`` are used. ``is_eval`` selects
        the test split (``train=not is_eval``). The dataset is downloaded into
        the current working directory.
        """
        preprocessing = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        mnist = MNIST(os.getcwd(), train=not is_eval, download=True, transform=preprocessing)
        # Keep only a small slice so the example runs quickly.
        mnist.data = mnist.data[:256]
        mnist.targets = mnist.targets[:256]
        return DataLoader(mnist, batch_size=config.batch_size)
77 |
78 |
def main():
    """Build the MNIST example model and run ``Trainer.fit()`` on it.

    ``parse_command_line_args=True`` is passed to the ``Trainer``.
    """
    # arguments and configuration
    train_args = TrainerArgs()
    config = MnistModelConfig()

    # the model to train
    model = MnistModel()

    # Build the loaders up front (train first, then eval) so the Trainer call
    # below stays short.
    train_loader = model.get_data_loader(config, None, False, None, None, None)
    eval_loader = model.get_data_loader(config, None, True, None, None, None)

    trainer = Trainer(
        train_args,
        config,
        config.output_path,
        model=model,
        train_samples=train_loader,
        eval_samples=eval_loader,
        parse_command_line_args=True,
    )
    trainer.fit()


if __name__ == "__main__":
    main()
103 |
--------------------------------------------------------------------------------
/trainer/logging/wandb_logger.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=W0613
2 |
3 | import traceback
4 | from collections import defaultdict
5 | from pathlib import Path
6 | from typing import Union
7 |
8 | from trainer.logging.base_dash_logger import BaseDashboardLogger
9 | from trainer.trainer_utils import is_wandb_available
10 | from trainer.utils.distributed import rank_zero_only
11 |
12 | if is_wandb_available():
13 | import wandb # pylint: disable=import-error
14 |
15 |
class WandbLogger(BaseDashboardLogger):
    """Dashboard logger that buffers values per step and sends them to Weights & Biases.

    Scalars, figures, audio clips and weight statistics are collected in
    ``self.log_dict`` (a dict of dicts keyed by step). They are only handed to
    ``wandb.log`` when :meth:`flush` is called.
    """

    def __init__(self, **kwargs):
        """Attach to the active wandb run, or start one with ``wandb.init(**kwargs)``.

        Raises:
            RuntimeError: If the ``wandb`` package is not available.
        """
        # ``wandb`` is only bound at module level when the package is importable,
        # so availability is tested through the helper. Evaluating the bare name
        # would raise NameError before the intended error could be reported.
        if not is_wandb_available():
            raise RuntimeError("install wandb using `pip install wandb` to use WandbLogger")

        self.run = wandb.init(**kwargs) if not wandb.run else wandb.run
        self.model_name = self.run.config.model
        # dictionary of dictionaries - record stats per step
        self.log_dict = defaultdict(dict)

    def model_weights(self, model, step):
        """Buffer summary statistics and histograms for every named parameter of ``model``.

        Single-element parameters contribute one ``value`` entry. All others
        contribute max/min/mean/std plus a histogram of the values and, when a
        gradient is present, a histogram of the gradient.
        """
        layer_num = 1
        for name, param in model.named_parameters():
            if param.numel() == 1:
                self.add_scalars("weights", {"layer{}-{}/value".format(layer_num, name): param.max()}, step)
            else:
                self.add_scalars("weights", {"layer{}-{}/max".format(layer_num, name): param.max()}, step)
                self.add_scalars("weights", {"layer{}-{}/min".format(layer_num, name): param.min()}, step)
                self.add_scalars("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()}, step)
                self.add_scalars("weights", {"layer{}-{}/std".format(layer_num, name): param.std()}, step)
                self.log_dict[step]["weights/layer{}-{}/param".format(layer_num, name)] = wandb.Histogram(param)
                # A parameter without a gradient has ``grad is None``; there is
                # nothing to build a histogram from, so skip it.
                if param.grad is not None:
                    self.log_dict[step]["weights/layer{}-{}/grad".format(layer_num, name)] = wandb.Histogram(
                        param.grad
                    )
            layer_num += 1

    def add_scalars(self, scope_name, scalars, step):
        """Buffer each scalar under ``"<scope_name>/<key>"`` for ``step``."""
        for key, value in scalars.items():
            self.log_dict[step]["{}/{}".format(scope_name, key)] = value

    def add_figures(self, scope_name, figures, step):
        """Buffer each figure, wrapped in ``wandb.Image``, under ``"<scope_name>/<key>"``."""
        for key, value in figures.items():
            self.log_dict[step]["{}/{}".format(scope_name, key)] = wandb.Image(value)

    def add_audios(self, scope_name, audios, step, sample_rate):
        """Buffer each clip, wrapped in ``wandb.Audio``, under ``"<scope_name>/<key>"``.

        ``float16`` clips are converted to ``float32`` first. A ``RuntimeError``
        for one clip is printed and the remaining clips are still processed.
        """
        for key, value in audios.items():
            if value.dtype == "float16":
                value = value.astype("float32")
            try:
                self.log_dict[step]["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate)
            except RuntimeError:
                traceback.print_exc()

    def add_text(self, title, text, step):
        """No-op: single text entries are not recorded by this logger."""

    def add_scalar(self, title: str, value: float, step: int) -> None:
        """No-op: single scalars are not recorded; use :meth:`add_scalars`."""

    def add_figure(
        self,
        title: str,
        figure: Union["matplotlib.figure.Figure", "plotly.graph_objects.Figure"],
        step: int,
    ) -> None:
        """No-op: single figures are not recorded; use :meth:`add_figures`."""

    def add_audio(self, title: str, audio: "np.ndarray", step: int, sample_rate: int) -> None:
        """No-op: single clips are not recorded; use :meth:`add_audios`."""

    @rank_zero_only
    def add_config(self, config):
        """No-op: the config is not uploaded by this method."""

    def flush(self):
        """Send all buffered entries to wandb in ascending step order, then clear the buffer."""
        if self.run:
            for step in sorted(self.log_dict.keys()):
                wandb.log(self.log_dict[step], step)
            self.log_dict.clear()

    def finish(self):
        """Finish the wandb run, if there is one."""
        if self.run:
            self.run.finish()

    def add_artifact(self, file_or_dir, name, artifact_type, aliases=None):
        """Log a file or directory as a wandb artifact named ``"<run id>_<name>"``.

        Does nothing when there is no run. A path that is neither a directory
        nor a file adds no content, but the artifact is still logged.
        """
        if not self.run:
            return
        name = "_".join([self.run.id, name])
        artifact = wandb.Artifact(name, type=artifact_type)
        data_path = Path(file_or_dir)
        if data_path.is_dir():
            artifact.add_dir(str(data_path))
        elif data_path.is_file():
            artifact.add_file(str(data_path))

        self.run.log_artifact(artifact, aliases=aliases)
101 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution guidelines
2 |
3 | Welcome to the 👟!
4 |
5 | This repository is governed by [the Contributor Covenant Code of Conduct](https://github.com/coqui-ai/Trainer/blob/main/CODE_OF_CONDUCT.md).
6 |
7 | ## Where to start.
8 | We welcome everyone who likes to contribute to 👟.
9 |
10 | You can contribute not only with code but with bug reports, comments, questions, answers, or just a simple tweet to spread the word.
11 |
12 | If you would like to contribute code or squash a bug but don't know where to start, here are some pointers.
13 |
14 | - [Github Issues Tracker](https://github.com/coqui-ai/Trainer/issues)
15 |
16 | This is a place to find feature requests, bugs.
17 |
18 | Issues with the ```good first issue``` tag are a good place for beginners to start.
19 |
20 | - ✨**PR**✨ [pages](https://github.com/coqui-ai/Trainer/pulls) with the ```🚀new version``` tag.
21 |
22 | We list all the target improvements for the next version. You can pick one of them and start contributing.
23 |
24 | - Also feel free to suggest new features. We're always open for new things.
25 |
26 | ## Sending a ✨**PR**✨
27 |
28 | If you have a new feature or a bug to squash, go ahead and send a ✨**PR**✨.
29 | Please use the following steps for a ✨**PR**✨.
30 | Let us know if you encounter a problem along the way.
31 |
32 | The following steps are tested on an Ubuntu system.
33 |
34 | 1. Fork 👟[https://github.com/coqui-ai/Trainer] by clicking the fork button at the top right corner of the project page.
35 |
36 | 2. Clone 👟 and add the main repo as a new remote named ```upstream```.
37 |
38 | ```bash
39 | $ git clone git@github.com:{config.to_json()}", 0)
59 |
def add_scalar(self, title, value, step):
    """Log one metric value for the current run at the given step.

    Args:
        title: Metric key.
        value: Metric value.
        step: Training step the value belongs to.
    """
    # `step` must be passed by keyword: the fourth positional parameter of
    # MlflowClient.log_metric is `timestamp`, not `step`.
    self.client.log_metric(self.run_id, title, value, step=step)
62 |
def add_text(self, title, text, step):
    """Store ``text`` as the text artifact ``<title>/<step>.txt`` of the current run."""
    artifact_file = "{}/{}.txt".format(title, step)
    self.client.log_text(self.run_id, text, artifact_file)
65 |
def add_figure(self, title, figure, step):
    """Store ``figure`` as the image artifact ``<title>/<step>.png`` of the current run.

    Args:
        title: Artifact folder name.
        figure: Figure object handed to the client.
        step: Training step, used as the file name.
    """
    # The run id has to come first, exactly as in `add_figures`; without it the
    # figure was passed in the run-id slot.
    self.client.log_figure(self.run_id, figure, "{}/{}.png".format(title, step))
68 |
def add_artifact(self, file_or_dir, name, artifact_type, aliases=None):  # pylint: disable=W0613, R0201
    """Upload a local file or directory as artifact(s) of the current run.

    Args:
        file_or_dir: Path to a local file or directory.
        name: Unused; kept for interface compatibility.
        artifact_type: Unused; kept for interface compatibility.
        aliases: Unused; kept for interface compatibility.
    """
    import os  # pylint: disable=import-outside-toplevel

    # `log_artifacts` expects a directory; a single file must go through
    # `log_artifact`, otherwise it is not uploaded.
    if os.path.isdir(file_or_dir):
        self.client.log_artifacts(self.run_id, file_or_dir)
    else:
        self.client.log_artifact(self.run_id, file_or_dir)
71 |
def add_audio(self, title, audio, step, sample_rate):
    """Hand ``audio`` to the client under the name ``<title>/<step>.wav``."""
    # NOTE(review): the stock MlflowClient does not appear to expose `log_audio`;
    # confirm what type `self.client` is before relying on this method.
    artifact_file = "{}/{}.wav".format(title, step)
    self.client.log_audio(self.run_id, audio, artifact_file, sample_rate)
74 |
@rank_zero_only
def add_scalars(self, scope_name, stats, step):
    """Log every entry of ``stats`` as the metric ``<scope_name>-<key>``.

    Args:
        scope_name: Prefix for the metric keys.
        stats: Mapping of metric name to value; tensors are converted with ``.item()``.
        step: Training step the values belong to.
    """
    for key, value in stats.items():
        if torch.is_tensor(value):
            value = value.item()
        # `step` must be passed by keyword: the fourth positional parameter of
        # MlflowClient.log_metric is `timestamp`, not `step`.
        self.client.log_metric(self.run_id, "{}-{}".format(scope_name, key), value, step=step)
81 |
@rank_zero_only
def add_figures(self, scope_name, figures, step):
    """Log every figure in ``figures`` as ``<scope_name>/<key>/<step>.png``."""
    for fig_name, fig in figures.items():
        artifact_file = "{}/{}/{}.png".format(scope_name, fig_name, step)
        self.client.log_figure(self.run_id, fig, artifact_file)
86 |
@rank_zero_only
def add_audios(self, scope_name, audios, step, sample_rate):
    """Write each audio clip to a temporary WAV file and upload it as a run artifact.

    The artifact ends up at ``<scope_name>/<key>/<step>.wav``.

    Args:
        scope_name: Top-level artifact folder.
        audios: Mapping of clip name to sample array.
        step: Training step, used as the file name.
        sample_rate: Sampling rate written into the WAV header.
    """
    import os  # pylint: disable=import-outside-toplevel

    for key, value in audios.items():
        if value.dtype == "float16":
            # float16 samples are up-cast to float32 before being written.
            value = value.astype("float32")
        try:
            # A temporary directory gives us a real path (not a file object) for
            # both the writer and the client, and it is removed on exit even
            # when the upload fails.
            with tempfile.TemporaryDirectory() as tmp_dir:
                tmp_audio_path = os.path.join(tmp_dir, "{}.wav".format(step))
                sf.write(tmp_audio_path, value, sample_rate)
                self.client.log_artifact(
                    self.run_id,
                    tmp_audio_path,
                    "{}/{}".format(scope_name, key),
                )
        except RuntimeError:
            traceback.print_exc()
103 |
def train_step_stats(self, step, stats):
    """Tag the run with ``Mode=training``, then delegate to the parent implementation."""
    self.client.set_tag(self.run_id, "Mode", "training")
    super().train_step_stats(step, stats)
107 |
def train_epoch_stats(self, step, stats):
    """Tag the run with ``Mode=training``, then delegate to the parent implementation."""
    self.client.set_tag(self.run_id, "Mode", "training")
    super().train_epoch_stats(step, stats)
111 |
def train_figures(self, step, figures):
    """Tag the run with ``Mode=training``, then delegate to the parent implementation."""
    self.client.set_tag(self.run_id, "Mode", "training")
    super().train_figures(step, figures)
115 |
def train_audios(self, step, audios, sample_rate):
    """Tag the run with ``Mode=training``, then delegate to the parent implementation."""
    self.client.set_tag(self.run_id, "Mode", "training")
    super().train_audios(step, audios, sample_rate)
119 |
def eval_stats(self, step, stats):
    """Tag the run with ``Mode=evaluation``, then delegate to the parent implementation."""
    self.client.set_tag(self.run_id, "Mode", "evaluation")
    super().eval_stats(step, stats)
123 |
def eval_figures(self, step, figures):
    """Tag the run with ``Mode=evaluation``, then delegate to the parent implementation."""
    self.client.set_tag(self.run_id, "Mode", "evaluation")
    super().eval_figures(step, figures)
127 |
def eval_audios(self, step, audios, sample_rate):
    """Tag the run with ``Mode=evaluation``, then delegate to the parent implementation."""
    self.client.set_tag(self.run_id, "Mode", "evaluation")
    super().eval_audios(step, audios, sample_rate)
131 |
def test_audios(self, step, audios, sample_rate):
    """Tag the run with ``Mode=test``, then delegate to the parent implementation."""
    self.client.set_tag(self.run_id, "Mode", "test")
    super().test_audios(step, audios, sample_rate)
135 |
def test_figures(self, step, figures):
    """Tag the run with ``Mode=test``, then delegate to the parent implementation."""
    self.client.set_tag(self.run_id, "Mode", "test")
    super().test_figures(step, figures)
139 |
def flush(self):
    """Do nothing; this logger keeps no buffer to flush."""
    return None
142 |
@rank_zero_only
def finish(self, status="FINISHED"):
    """Mark the run as terminated.

    Args:
        status: Final run status. ``"success"`` is translated to ``"FINISHED"``;
            any other value is forwarded unchanged. It used to be read without
            being defined, which raised ``UnboundLocalError`` on every call.
    """
    # NOTE(review): the parent class is not visible from here; confirm that it
    # actually provides `finalize`.
    super().finalize(status)
    status = "FINISHED" if status == "success" else status
    if self.client.get_run(self.run_id):
        self.client.set_terminated(self.run_id, status)
149 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
{config.to_json()}", 0)
64 |
def add_scalar(self, title, value, step):
    """Track one scalar under ``title`` at ``step`` using the current context."""
    run = self.run
    run.track(value, name=title, step=step, context=self.context)
67 |
def add_text(self, title, text, step):
    """Track ``text`` (wrapped in ``Text``) under ``title`` at ``step``."""
    wrapped = Text(text)
    self.run.track(wrapped, name=title, step=step, context=self.context)
75 |
def add_figure(self, title, figure, step):
    """Track ``figure`` (wrapped in ``Image`` with ``title`` as caption) under ``title`` at ``step``."""
    wrapped = Image(figure, title)
    self.run.track(wrapped, name=title, step=step, context=self.context)
83 |
def add_artifact(self, file_or_dir, name, artifact_type, aliases=None):  # pylint: disable=W0613
    """Do nothing: AIM does not support artifacts."""
    return None
87 |
def add_audio(self, title, audio, step, sample_rate):
    """Track ``audio`` (wrapped in ``Audio``) under ``title`` at ``step``.

    ``sample_rate`` is accepted but not used.
    """
    wrapped = Audio(audio)
    self.run.track(wrapped, name=title, step=step, context=self.context)
95 |
@rank_zero_only
def add_scalars(self, scope_name, scalars, step):
    """Track every entry of ``scalars`` under the name ``<scope_name>-<key>``.

    Tensor values are converted to Python numbers with ``.item()`` first.
    """
    for metric, raw in scalars.items():
        number = raw.item() if torch.is_tensor(raw) else raw
        metric_name = "{}-{}".format(scope_name, metric)
        self.run.track(number, name=metric_name, step=step, context=self.context)
102 |
@rank_zero_only
def add_figures(self, scope_name, figures, step):
    """Track every figure in ``figures`` under ``<scope_name>/<key>/<step>.png``."""
    for fig_name, fig in figures.items():
        title = "{}/{}/{}.png".format(scope_name, fig_name, step)
        wrapped = Image(fig, title)
        self.run.track(wrapped, name=title, step=step, context=self.context)
113 |
@rank_zero_only
def add_audios(self, scope_name, audios, step, sample_rate):
    """Track every clip in ``audios`` under ``<scope_name>/<key>/<step>.wav``.

    ``sample_rate`` is accepted but not used.
    """
    for clip_name, clip in audios.items():
        title = "{}/{}/{}.wav".format(scope_name, clip_name, step)
        wrapped = Audio(clip)
        self.run.track(wrapped, name=title, step=step, context=self.context)
124 |
def train_step_stats(self, step, stats):
    """Switch the tracking context to the ``train`` subset, then delegate to the parent."""
    self.context = {"subset": "train"}
    super().train_step_stats(step, stats)
128 |
def train_epoch_stats(self, step, stats):
    """Switch the tracking context to the ``train`` subset, then delegate to the parent."""
    self.context = {"subset": "train"}
    super().train_epoch_stats(step, stats)
132 |
def train_figures(self, step, figures):
    """Switch the tracking context to the ``train`` subset, then delegate to the parent."""
    self.context = {"subset": "train"}
    super().train_figures(step, figures)
136 |
def train_audios(self, step, audios, sample_rate):
    """Switch the tracking context to the ``train`` subset, then delegate to the parent."""
    self.context = {"subset": "train"}
    super().train_audios(step, audios, sample_rate)
140 |
def eval_stats(self, step, stats):
    """Switch the tracking context to the ``eval`` subset, then delegate to the parent."""
    self.context = {"subset": "eval"}
    super().eval_stats(step, stats)
144 |
def eval_figures(self, step, figures):
    """Switch the tracking context to the ``eval`` subset, then delegate to the parent."""
    self.context = {"subset": "eval"}
    super().eval_figures(step, figures)
148 |
def eval_audios(self, step, audios, sample_rate):
    """Switch the tracking context to the ``eval`` subset, then delegate to the parent."""
    self.context = {"subset": "eval"}
    super().eval_audios(step, audios, sample_rate)
152 |
def test_audios(self, step, audios, sample_rate):
    """Switch the tracking context to the ``test`` subset, then delegate to the parent."""
    self.context = {"subset": "test"}
    super().test_audios(step, audios, sample_rate)
156 |
def test_figures(self, step, figures):
    """Switch the tracking context to the ``test`` subset, then delegate to the parent."""
    self.context = {"subset": "test"}
    super().test_figures(step, figures)
160 |
def flush(self):
    """Do nothing; this logger keeps no buffer to flush."""
    return None
163 |
@rank_zero_only
def finish(self):
    """Close the underlying run so that tracked data is finalized."""
    # Every other method of this logger works on `self.run`, so that is the
    # object to close.
    # NOTE(review): the parent class is not visible from here; the previous
    # `super().close()` presumably had no `close` to call. Confirm.
    self.run.close()
167 |
--------------------------------------------------------------------------------
/trainer/trainer_utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import random
4 | from typing import Dict, List, Tuple
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from trainer.logger import logger
10 | from trainer.torch import NoamLR, StepwiseGradualLR
11 | from trainer.utils.distributed import rank_zero_logger_info
12 |
13 |
def is_apex_available():
    """Return True when the ``apex`` package can be located (without importing it)."""
    # A bare `import importlib` does not guarantee that the `util` submodule is
    # loaded, so import it explicitly.
    import importlib.util  # pylint: disable=import-outside-toplevel

    return importlib.util.find_spec("apex") is not None
16 |
17 |
def is_mlflow_available():
    """Return True when the ``mlflow`` package can be located (without importing it)."""
    # A bare `import importlib` does not guarantee that the `util` submodule is
    # loaded, so import it explicitly.
    import importlib.util  # pylint: disable=import-outside-toplevel

    return importlib.util.find_spec("mlflow") is not None
20 |
21 |
def is_aim_available():
    """Return True when the ``aim`` package can be located (without importing it)."""
    # A bare `import importlib` does not guarantee that the `util` submodule is
    # loaded, so import it explicitly.
    import importlib.util  # pylint: disable=import-outside-toplevel

    return importlib.util.find_spec("aim") is not None
24 |
25 |
def is_wandb_available():
    """Return True when the ``wandb`` package can be located (without importing it)."""
    # A bare `import importlib` does not guarantee that the `util` submodule is
    # loaded, so import it explicitly.
    import importlib.util  # pylint: disable=import-outside-toplevel

    return importlib.util.find_spec("wandb") is not None
28 |
29 |
def is_clearml_available():
    """Return True when the ``clearml`` package can be located (without importing it)."""
    # A bare `import importlib` does not guarantee that the `util` submodule is
    # loaded, so import it explicitly.
    import importlib.util  # pylint: disable=import-outside-toplevel

    return importlib.util.find_spec("clearml") is not None
32 |
33 |
def print_training_env(args, config):
    """Log a summary of the training environment through the rank-zero logger.

    Args:
        args: Object exposing ``use_accelerate``.
        config: Object exposing ``mixed_precision`` and ``precision``.
    """

    def say(message):
        rank_zero_logger_info(message, logger)

    say(" > Training Environment:")

    say(" | > Backend: Accelerate" if args.use_accelerate else " | > Backend: Torch")

    if not config.mixed_precision:
        say(" | > Mixed precision: False")
        say(" | > Precision: float32")
    else:
        say(" | > Mixed precision: True")
        say(f" | > Precision: {config.precision}")

    # GPU details are only reported when CUDA is usable and at least one device is visible.
    if torch.cuda.is_available() and torch.cuda.device_count() > 0:
        say(f" | > Current device: {torch.cuda.current_device()}")
        say(f" | > Num. of GPUs: {torch.cuda.device_count()}")

    say(f" | > Num. of CPUs: {os.cpu_count()}")
    say(f" | > Num. of Torch Threads: {torch.get_num_threads()}")
    say(f" | > Torch seed: {torch.initial_seed()}")
    say(f" | > Torch CUDNN: {torch.backends.cudnn.enabled}")
    say(f" | > Torch CUDNN deterministic: {torch.backends.cudnn.deterministic}")
    say(f" | > Torch CUDNN benchmark: {torch.backends.cudnn.benchmark}")
    say(f" | > Torch TF32 MatMul: {torch.backends.cuda.matmul.allow_tf32}")
61 |
62 |
def setup_torch_training_env(
    args: "TrainerArgs",
    cudnn_enable: bool,
    cudnn_benchmark: bool,
    cudnn_deterministic: bool,
    use_ddp: bool = False,
    training_seed=54321,
    allow_tf32: bool = False,
    gpu=None,
) -> Tuple[bool, int]:
    """Prepare the PyTorch runtime (devices, seeds, backend flags) for training.

    Args:
        args (TrainerArgs): Trainer arguments; ``use_accelerate`` is read from it.
        cudnn_enable (bool): Request CUDNN to be enabled.
        cudnn_benchmark (bool): Request CUDNN benchmarking. Better left False when the input
            sequence length varies between batches.
        cudnn_deterministic (bool): Request CUDNN deterministic mode.
        use_ddp (bool): True if DDP is enabled, False otherwise.
        training_seed (int): Seed applied to ``random``, ``PYTHONHASHSEED``, NumPy and Torch.
        allow_tf32 (bool): Request TF32 matmul. TF32 is only available on Ampere GPUs.
        gpu: Index of the GPU to select; only honoured when ``CUDA_VISIBLE_DEVICES`` is unset.

    Returns:
        Tuple[bool, int]: Whether CUDA is available and the number of GPUs in use.

    Raises:
        RuntimeError: If more than one GPU is active while neither DDP nor Accelerate is used.
    """
    # Start from an empty CUDA cache.
    torch.cuda.empty_cache()

    # Enumerate devices in PCI bus order so indices are stable.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if gpu is None or "CUDA_VISIBLE_DEVICES" in os.environ:
        num_gpus = torch.cuda.device_count()
    else:
        torch.cuda.set_device(int(gpu))
        num_gpus = 1

    if num_gpus > 1 and not (use_ddp or args.use_accelerate):
        raise RuntimeError(
            f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`."
        )

    # Seed every random number generator with the same value.
    random.seed(training_seed)
    os.environ["PYTHONHASHSEED"] = str(training_seed)
    np.random.seed(training_seed)
    torch.manual_seed(training_seed)
    torch.cuda.manual_seed(training_seed)

    # Backend flags are only ever switched on here: a flag that is already True stays True.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = cudnn_deterministic or cudnn.deterministic
    cudnn.enabled = cudnn_enable or cudnn.enabled
    cudnn.benchmark = cudnn_benchmark or cudnn.benchmark
    matmul = torch.backends.cuda.matmul
    matmul.allow_tf32 = allow_tf32 or matmul.allow_tf32

    return torch.cuda.is_available(), num_gpus
119 |
120 |
def get_scheduler(
    lr_scheduler: str, lr_scheduler_params: Dict, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
    """Look up a learning-rate scheduler by name and instantiate it.

    ``"noamlr"`` and ``"stepwisegraduallr"`` are matched case-insensitively against the
    trainer's own schedulers; any other name is looked up, as given, in
    ``torch.optim.lr_scheduler``.

    Args:
        lr_scheduler (str): Scheduler name, or None for no scheduler.
        lr_scheduler_params (Dict): Keyword arguments for the scheduler constructor.
        optimizer (torch.optim.Optimizer): Optimizer the scheduler is attached to.

    Returns:
        torch.optim.lr_scheduler._LRScheduler: The scheduler instance, or None when
            ``lr_scheduler`` is None.
    """
    if lr_scheduler is None:
        return None

    lowered = lr_scheduler.lower()
    if lowered == "noamlr":
        scheduler_cls = NoamLR
    elif lowered == "stepwisegraduallr":
        scheduler_cls = StepwiseGradualLR
    else:
        scheduler_cls = getattr(torch.optim.lr_scheduler, lr_scheduler)
    return scheduler_cls(optimizer, **lr_scheduler_params)
143 |
144 |
def get_optimizer(
    optimizer_name: str,
    optimizer_params: dict,
    lr: float,
    model: torch.nn.Module = None,
    parameters: List = None,
) -> torch.optim.Optimizer:
    """Find, initialize and return a Torch optimizer.

    Args:
        optimizer_name (str): Optimizer name. ``"radam"`` (case-insensitive) selects RAdam; any
            other name is looked up, as given, in ``torch.optim``.
        optimizer_params (dict): Extra keyword arguments for the optimizer constructor.
        lr (float): Initial learning rate.
        model (torch.nn.Module): Model whose parameters are optimized. Takes precedence over
            ``parameters`` when both are given.
        parameters (List): Parameters to optimize when no model is passed.

    Returns:
        torch.optim.Optimizer: Functional optimizer.

    Raises:
        ValueError: If neither ``model`` nor ``parameters`` is provided.
    """
    if optimizer_name.lower() == "radam":
        try:
            optimizer = getattr(importlib.import_module("TTS.utils.radam"), "RAdam")
        except ImportError:
            # The TTS package is optional here; fall back to the RAdam shipped with torch.
            optimizer = torch.optim.RAdam
    else:
        optimizer = getattr(torch.optim, optimizer_name)

    if model is not None:
        parameters = model.parameters()
    if parameters is None:
        raise ValueError(" [!] Either `model` or `parameters` must be provided to `get_optimizer()`.")
    return optimizer(parameters, lr=lr, **optimizer_params)
171 |
--------------------------------------------------------------------------------
/trainer/model.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Dict, List, Tuple, Union
3 |
4 | import torch
5 | from coqpit import Coqpit
6 | from torch import nn
7 |
8 | from trainer.trainer_utils import is_apex_available
9 |
10 | if is_apex_available():
11 | from apex import amp
12 |
13 |
14 | # pylint: skip-file
15 |
16 |
class TrainerModel(ABC, nn.Module):
    """Abstract base model. Every model driven by the trainer must inherit from this class."""

    @abstractmethod
    def forward(self, input: torch.Tensor, *args, aux_input=None, **kwargs) -> Dict:
        """Forward pass of the model, mainly used in training.

        You can be flexible here and use a different number of arguments and argument names since it is
        intended to be used by `train_step()` without exposing it out of the model.

        Args:
            input (torch.Tensor): Input tensor.
            aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs.
                Defaults to None (instead of a shared mutable ``{}``).

        Returns:
            Dict: Model outputs. Main model output must be named as "model_outputs".
        """
        outputs_dict = {"model_outputs": None}
        return outputs_dict

    def format_batch(self, batch: Dict) -> Dict:
        """Format batch returned by the data loader before sending it to the model.

        If not implemented, model uses the batch as is.
        Can be used for data augmentation, feature extraction, etc.
        """
        return batch

    def format_batch_on_device(self, batch: Dict) -> Dict:
        """Format batch on device before sending it to the model.

        If not implemented, model uses the batch as is.
        Can be used for data augmentation, feature extraction, etc.
        """
        return batch

    def train_step(self, *args: Any, **kwargs: Any) -> Tuple[Dict, Dict]:
        """Perform a single training step. Run the model forward pass and compute losses.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.
            optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        raise NotImplementedError(" [!] `train_step()` is not implemented.")

    def train_log(self, *args: Any, **kwargs: Any) -> None:
        """Create visualizations and waveform examples for training.

        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
        be projected onto Tensorboard.

        Args:
            batch (Dict): Model inputs used at the previous training step.
            outputs (Dict): Model outputs generated at the previous training step.
            logger (Logger): Logger instance to log training plots.
            assets (Dict): Assets to be used for logging from the trainer's closure.
            steps (int): Number of training steps taken so far.

        Returns:
            Tuple[Dict, np.ndarray]: training plots and output waveform.
        """
        raise NotImplementedError(" [!] `train_log()` is not implemented.")

    @torch.no_grad()
    def eval_step(self, *args: Any, **kwargs: Any):
        """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can
        call `train_step()` with no changes.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.
            optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        raise NotImplementedError(" [!] `eval_step()` is not implemented.")

    def eval_log(self, *args: Any, **kwargs: Any) -> None:
        """The same as `train_log()`"""
        raise NotImplementedError(" [!] `eval_log()` is not implemented.")

    @abstractmethod
    def get_data_loader(self, *args: Any, **kwargs: Any) -> torch.utils.data.DataLoader:
        """Get data loader for the model.

        Args:
            config (Coqpit): Configuration object.
            assets (Dict): Additional assets to be used for data loading.
            is_eval (bool): If True, returns evaluation data loader.
            samples (Union[List[Dict], List[List]]): List of samples to be used for data loading.
            verbose (bool): If True, prints data loading information.
            num_gpus (int): Number of GPUs used for training.
            rank (int): Rank of the current GPU.

        Returns:
            torch.utils.data.DataLoader: Data loader for the model.
        """
        raise NotImplementedError(" [!] `get_data_loader()` is not implemented.")

    def init_for_training(self) -> None:
        """Initialize model for training."""

    def optimize(self, *args: Any, **kwargs: Any) -> Tuple[Dict, Dict, float]:
        """Model specific optimization step that must perform the following steps:
            1. Forward pass
            2. Compute loss
            3. Backward pass
            4. Update weights

        Use `self.scaled_backward()` instead of `loss.backward()` to be able to use Mixed Precision Training.

        Args:
            batch (Dict): Input tensors.
            trainer (Trainer): Trainer instance to be able to access the training closure.

        Returns:
            Tuple[Dict, Dict, float]: Model outputs, loss dictionary and grad_norm value.
        """
        raise NotImplementedError(" [!] `optimize()` is not implemented.")

    def scaled_backward(
        self, loss: torch.Tensor, trainer: "Trainer", optimizer: "Optimizer", *args: Any, **kwargs: Any
    ) -> None:
        """Backward pass with gradient scaling for custom `optimize` calls.

        Args:
            loss (torch.Tensor): Loss to be backpropagated.
            trainer (Trainer): Trainer instance to be able to access the training closure.
            optimizer (Optimizer): Optimizer for APEX AMP based scaled `backward` calls.
        """
        if not trainer.use_amp_scaler:
            # plain full-precision backward pass
            loss.backward()
        elif trainer.use_apex:
            # https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # mixed precision backward pass through the trainer's gradient scaler
            trainer.scaler.scale(loss).backward()
171 |
172 | # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
173 | # """Setup an return optimizer or optimizers."""
174 | # ...
175 |
176 | # def get_lr(self) -> Union[float, List[float]]:
177 | # """Return learning rate(s).
178 |
179 | # Returns:
180 | # Union[float, List[float]]: Model's initial learning rates.
181 | # """
182 | # ...
183 |
184 | # def get_scheduler(self, optimizer: torch.optim.Optimizer):
185 | # ...
186 |
187 | # def get_criterion(self):
188 | # ...
189 |
--------------------------------------------------------------------------------
/trainer/callbacks.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict
2 |
3 |
class TrainerCallback:
    """Dispatch trainer events to the model, criterion, optimizer and user callbacks.

    For each event the hook of the same name is invoked, when present, on the model (unwrapped
    from ``.module`` if the model has that attribute), then on ``trainer.criterion``, then on
    ``trainer.optimizer``, and finally every user callback registered for the event is called.
    """

    # Events that accept user callbacks through `parse_callbacks_dict`.
    _EVENTS = (
        "on_init_start",
        "on_init_end",
        "on_epoch_start",
        "on_epoch_end",
        "on_train_epoch_start",
        "on_train_epoch_end",
        "on_train_step_start",
        "on_train_step_end",
        "on_keyboard_interrupt",
    )

    def __init__(self) -> None:
        self.callbacks_on_init_start = []
        self.callbacks_on_init_end = []
        self.callbacks_on_epoch_start = []
        self.callbacks_on_epoch_end = []
        self.callbacks_on_train_epoch_start = []
        self.callbacks_on_train_epoch_end = []
        self.callbacks_on_train_step_start = []
        self.callbacks_on_train_step_end = []
        self.callbacks_on_keyboard_interrupt = []

    def parse_callbacks_dict(self, callbacks_dict: Dict[str, Callable]) -> None:
        """Register user callbacks given as ``{event_name: callable}``.

        Raises:
            ValueError: If a key is not a known event name. Entries that precede the invalid key
                in iteration order are already registered at that point.
        """
        for key, value in callbacks_dict.items():
            if key not in self._EVENTS:
                raise ValueError(f"Invalid callback key: {key}")
            getattr(self, f"callbacks_{key}").append(value)

    @staticmethod
    def _unwrap_model(trainer):
        """Return ``trainer.model.module`` when the model has a ``module`` attribute, else the model."""
        model = trainer.model
        return model.module if hasattr(model, "module") else model

    def _dispatch(self, event: str, trainer) -> None:
        """Run the hook named ``event`` on model, criterion and optimizer, then the user callbacks."""
        model = self._unwrap_model(trainer)
        if hasattr(model, event):
            getattr(model, event)(trainer)

        # Criterion and optimizer are looked up only after the model hook has run, so that a
        # hook replacing them on the trainer is honoured.
        for attr_name in ("criterion", "optimizer"):
            target = getattr(trainer, attr_name)
            if hasattr(target, event):
                getattr(target, event)(trainer)

        for callback in getattr(self, f"callbacks_{event}"):
            callback(trainer)

    def on_init_start(self, trainer) -> None:
        self._dispatch("on_init_start", trainer)

    def on_init_end(self, trainer) -> None:
        self._dispatch("on_init_end", trainer)

    def on_epoch_start(self, trainer) -> None:
        self._dispatch("on_epoch_start", trainer)

    def on_epoch_end(self, trainer) -> None:
        self._dispatch("on_epoch_end", trainer)

    def on_train_epoch_start(self, trainer) -> None:
        self._dispatch("on_train_epoch_start", trainer)

    def on_train_epoch_end(self, trainer) -> None:
        self._dispatch("on_train_epoch_end", trainer)

    @staticmethod
    def before_backward_pass(trainer, loss_dict) -> None:
        """Call the model's ``before_backward_pass(loss_dict, optimizer)`` hook when it defines one."""
        model = TrainerCallback._unwrap_model(trainer)
        if hasattr(model, "before_backward_pass"):
            model.before_backward_pass(loss_dict, trainer.optimizer)

    @staticmethod
    def before_gradient_clipping(trainer) -> None:
        """Call the model's ``before_gradient_clipping()`` hook when it defines one."""
        model = TrainerCallback._unwrap_model(trainer)
        if hasattr(model, "before_gradient_clipping"):
            model.before_gradient_clipping()

    def on_train_step_start(self, trainer) -> None:
        self._dispatch("on_train_step_start", trainer)

    def on_train_step_end(self, trainer) -> None:
        self._dispatch("on_train_step_end", trainer)

    def on_keyboard_interrupt(self, trainer) -> None:
        self._dispatch("on_keyboard_interrupt", trainer)
218 |
--------------------------------------------------------------------------------
/trainer/io.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import os
4 | import re
5 | import sys
6 | from pathlib import Path
7 | from typing import Any, Callable, Dict, List, Tuple, Union
8 | from urllib.parse import urlparse
9 |
10 | import fsspec
11 | import torch
12 | from coqpit import Coqpit
13 |
14 | from trainer.logger import logger
15 |
16 |
def get_user_data_dir(appname):
    """Return the per-user data directory for ``appname`` on the current platform.

    Args:
        appname (str): Name of the sub-directory appended to the platform's data root.

    Returns:
        Path: ``<Local AppData>/<appname>`` on Windows, ``~/Library/Application Support/<appname>``
            on macOS and ``~/.local/share/<appname>`` elsewhere. The directory is not created.
    """
    if sys.platform == "win32":
        import winreg  # pylint: disable=import-outside-toplevel, import-error

        # Use the key as a context manager so the registry handle is closed after the lookup.
        with winreg.OpenKey(
            winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
        ) as key:
            dir_, _ = winreg.QueryValueEx(key, "Local AppData")
        ans = Path(dir_).resolve(strict=False)
    elif sys.platform == "darwin":
        ans = Path("~/Library/Application Support/").expanduser()
    else:
        ans = Path.home().joinpath(".local/share")
    return ans.joinpath(appname)
31 |
32 |
def copy_model_files(config: Coqpit, out_path, new_fields):
    """Write the config, merged with ``new_fields``, to ``<out_path>/config.json``.

    Args:
        config (Coqpit): Coqpit config defining the training run.
        out_path (str): Output folder; may be any location fsspec can open.
        new_fields (dict): Fields to add to, or override in, the written config.
    """
    target_path = os.path.join(out_path, "config.json")
    # Values from `new_fields` win over the ones coming from the config.
    merged = dict(config.to_dict())
    merged.update(new_fields)
    # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
    with fsspec.open(target_path, "w", encoding="utf8") as out_file:
        json.dump(merged, out_file, indent=4)
49 |
50 |
def load_fsspec(
    path: str,
    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Like torch.load but can load from other locations (e.g. s3:// , gs://).

    Args:
        path: Any path or url supported by fsspec.
        map_location: Forwarded to torch.load as ``map_location``.
        cache: If True and ``path`` is not an existing local file or directory,
            open it through fsspec's ``filecache`` so later calls can reuse the
            local copy. The cache lives under ``get_user_data_dir("tts_cache")``.
            Defaults to True.
        **kwargs: Keyword arguments forwarded to torch.load.

    Returns:
        Object stored in path.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version and `weights_only`; only load trusted files.
    is_remote = not (os.path.isdir(path) or os.path.isfile(path))
    if cache and is_remote:
        source = fsspec.open(
            f"filecache::{path}",
            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
            mode="rb",
        )
    else:
        source = fsspec.open(path, "rb")
    with source as f:
        return torch.load(f, map_location=map_location, **kwargs)
77 |
78 |
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
    """Restore model weights from a checkpoint file.

    Args:
        model: Model whose ``load_state_dict`` receives the ``"model"`` entry
            of the checkpoint.
        checkpoint_path: Path or url of the checkpoint, read with ``load_fsspec``.
        use_cuda: If truthy, call ``model.cuda()`` after loading the weights.
        eval: If truthy, call ``model.eval()`` after loading the weights.

    Returns:
        Tuple of the model and the full checkpoint state object.
    """
    # the checkpoint is always mapped to CPU first; the model is moved afterwards
    checkpoint_state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint_state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, checkpoint_state
87 |
88 |
def save_fsspec(state: Any, path: str, **kwargs):
    """Like torch.save but can save to other locations (e.g. s3:// , gs://).

    Args:
        state: State object to save.
        path: Any path or url supported by fsspec; opened in binary write mode.
        **kwargs: Keyword arguments forwarded to torch.save.
    """
    with fsspec.open(path, "wb") as sink:
        torch.save(state, sink, **kwargs)
99 |
100 |
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, save_func, **kwargs):
    """Collect the training state into one dict and write it to ``output_path``.

    Args:
        config: Training config; a Coqpit instance is converted with ``to_dict()``.
        model: Model to serialize. If it has a ``module`` attribute, the state
            dict of ``model.module`` is stored instead.
        optimizer: A single optimizer, a list or dict of optimizers, or None.
        scaler: A single scaler, a list of scalers, or None.
        current_step: Value stored under ``"step"``.
        epoch: Value stored under ``"epoch"``.
        output_path: Destination passed to the save function.
        save_func: Callable ``(state, output_path)`` used for writing; when
            falsy, ``save_fsspec`` is used.
        **kwargs: Extra entries merged into the state dict; they override
            built-in keys of the same name.
    """
    # wrappers that expose the underlying model as `.module` are unwrapped
    source_model = model.module if hasattr(model, "module") else model
    model_state = source_model.state_dict()

    if isinstance(optimizer, list):
        optimizer_state = [optim.state_dict() for optim in optimizer]
    elif isinstance(optimizer, dict):
        optimizer_state = {name: optim.state_dict() for name, optim in optimizer.items()}
    elif optimizer is None:
        optimizer_state = None
    else:
        optimizer_state = optimizer.state_dict()

    if isinstance(scaler, list):
        scaler_state = [single_scaler.state_dict() for single_scaler in scaler]
    elif scaler is None:
        scaler_state = None
    else:
        scaler_state = scaler.state_dict()

    if isinstance(config, Coqpit):
        config = config.to_dict()

    state = {
        "config": config,
        "model": model_state,
        "optimizer": optimizer_state,
        "scaler": scaler_state,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
        **kwargs,
    }
    writer = save_func if save_func else save_fsspec
    writer(state, output_path)
135 |
136 |
def save_checkpoint(
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    output_folder,
    save_n_checkpoints=None,
    save_func=None,
    **kwargs,
):
    """Save ``checkpoint_<current_step>.pth`` in ``output_folder``.

    Args:
        config, model, optimizer, scaler, current_step, epoch: Forwarded to
            ``save_model``.
        output_folder: Folder that receives the checkpoint file.
        save_n_checkpoints: When not None, ``keep_n_checkpoints`` is called
            with this value after saving.
        save_func: Optional custom save callable forwarded to ``save_model``.
        **kwargs: Extra state entries forwarded to ``save_model``.
    """
    checkpoint_path = os.path.join(output_folder, f"checkpoint_{current_step}.pth")
    logger.info("\n > CHECKPOINT : %s", checkpoint_path)

    save_model(config, model, optimizer, scaler, current_step, epoch, checkpoint_path, save_func=save_func, **kwargs)

    if save_n_checkpoints is None:
        return
    keep_n_checkpoints(output_folder, save_n_checkpoints)
166 |
167 |
def save_best_model(
    current_loss,
    best_loss,
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    out_path,
    keep_all_best=False,
    keep_after=0,
    save_func=None,
    **kwargs,
):
    """Save the model as the new best one when the loss improved.

    Args:
        current_loss: Latest loss; either a plain value or a dict with
            ``"train_loss"`` and ``"eval_loss"`` entries.
        best_loss: Best loss so far, in the same form as ``current_loss``.
        config, model, optimizer, scaler, current_step, epoch: Forwarded to
            ``save_model``.
        out_path: Folder that receives ``best_model_<current_step>.pth`` and
            the ``best_model.pth`` copy.
        keep_all_best: When falsy, other ``best_model*.pth`` files in
            ``out_path`` are removed after a successful save.
        keep_after: When it is an int or float, nothing is saved unless
            ``current_step`` is greater than ``int(keep_after)``.
        save_func: Optional custom save callable forwarded to ``save_model``.
        **kwargs: Extra state entries forwarded to ``save_model``.

    Returns:
        ``current_loss`` when a new best model was saved, otherwise ``best_loss``.
    """
    if isinstance(current_loss, dict):
        # compare on the eval loss only when both sides provide one
        use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
        loss_key = "eval_loss" if use_eval_loss else "train_loss"
        is_save_model = current_loss[loss_key] < best_loss[loss_key]
    else:
        is_save_model = current_loss < best_loss

    if isinstance(keep_after, (int, float)):
        keep_after = int(keep_after)
        is_save_model = is_save_model and current_step > keep_after

    if not is_save_model:
        return best_loss

    best_model_name = f"best_model_{current_step}.pth"
    checkpoint_path = os.path.join(out_path, best_model_name)
    logger.info(" > BEST MODEL : %s", checkpoint_path)
    save_model(
        config,
        model,
        optimizer,
        scaler,
        current_step,
        epoch,
        checkpoint_path,
        model_loss=current_loss,
        save_func=save_func,
        **kwargs,
    )

    fs = fsspec.get_mapper(out_path).fs
    # only delete previous if current is saved successfully
    # NOTE(review): with a numeric keep_after, reaching this point implies
    # current_step > keep_after, so the second clause cannot be true here.
    prune_previous = (not keep_all_best) or (current_step < keep_after)
    if prune_previous:
        for model_name in fs.glob(os.path.join(out_path, "best_model*.pth")):
            if os.path.basename(model_name) != best_model_name:
                fs.rm(model_name)

    # create a shortcut which always points to the currently best model
    shortcut_path = os.path.join(out_path, "best_model.pth")
    fs.copy(checkpoint_path, shortcut_path)
    return current_loss
224 |
225 |
def get_last_checkpoint(path: str) -> Tuple[str, str]:
    """Get latest checkpoint or/and best model in path.

    Files are found by globbing for `*.pth`; the step number is taken from the
    RegEx `(checkpoint|best_model)_([0-9]+)` searched in each file name as
    returned by the glob. When no file of a kind carries a step number, the
    file with the latest `os.path.getctime` is used and its step is read from
    the `"step"` entry of the loaded checkpoint.

    Args:
        path: Path to files to be compared.

    Raises:
        ValueError: If no checkpoint or best_model files are found.

    Returns:
        Path to the last checkpoint
        Path to best checkpoint
    """
    fs = fsspec.get_mapper(path).fs
    file_names = fs.glob(os.path.join(path, "*.pth"))
    scheme = urlparse(path).scheme
    if scheme and path.startswith(scheme + "://"):
        # scheme is not preserved in fs.glob, add it
        # back if it exists on the path
        url_prefix = scheme + "://"
        file_names = [url_prefix + file_name for file_name in file_names]

    last_models = {}
    last_model_nums = {}
    for key in ("checkpoint", "best_model"):
        step_pattern = f"{key}_([0-9]+)"
        last_model, last_model_num = None, None
        # scan all files and keep the one with the largest step suffix
        for file_name in file_names:
            match = re.search(step_pattern, file_name)
            if match is None:
                continue
            model_num = int(match.group(1))
            if last_model_num is None or model_num > last_model_num:
                last_model_num, last_model = model_num, file_name

        if last_model is None:
            # nothing with a step suffix: fall back to the most recent file
            # of this kind and read its step from the stored state
            key_file_names = [fn for fn in file_names if key in fn]
            if key_file_names:
                last_model = max(key_file_names, key=os.path.getctime)
                last_model_num = load_fsspec(last_model)["step"]

        if last_model is not None:
            last_models[key] = last_model
            last_model_nums[key] = last_model_num

    # check what models were found
    if not last_models:
        raise ValueError(f"No models found in continue path {path}!")

    last_checkpoint = last_models.get("checkpoint")
    best_model = last_models.get("best_model")
    if last_checkpoint is None:  # no checkpoint just best model
        last_checkpoint = best_model
    elif best_model is None:  # no best model
        # this shouldn't happen, but let's handle it just in case
        best_model = last_checkpoint
    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
        # the best model is more recent than the last checkpoint
        last_checkpoint = best_model

    return last_checkpoint, best_model
289 |
290 |
def keep_n_checkpoints(path: str, n: int) -> None:
    """Keep only the last n checkpoints in path.

    Checkpoints are ordered with ``sort_checkpoints(path, "checkpoint")`` and
    every entry except the last ``n`` is removed through the fsspec filesystem
    of ``path``.

    Args:
        path: Path to files to be compared.
        n: Number of checkpoints to keep.
    """
    fs = fsspec.get_mapper(path).fs
    ordered_checkpoints = sort_checkpoints(path, "checkpoint")
    if len(ordered_checkpoints) <= n:
        return
    # NOTE(review): for n == 0 the slice `[:-0]` is empty, so nothing is removed.
    for stale_checkpoint in ordered_checkpoints[:-n]:
        fs.rm(stale_checkpoint)
303 |
304 |
def sort_checkpoints(output_path: str, checkpoint_prefix: str, use_mtime: bool = False) -> List[str]:
    """Return checkpoint paths sorted in ascending order.

    Candidates are the entries of ``output_path`` matching the glob
    ``<checkpoint_prefix>_*``. By default they are ordered by the step number
    that follows the prefix; entries without a numeric step are dropped.

    Args:
        output_path (str): Path to directory containing checkpoints.
        checkpoint_prefix (str): Prefix of the checkpoint files.
        use_mtime (bool): If True, use modification dates to determine checkpoint order.

    Returns:
        List[str]: Checkpoint paths, oldest / lowest step first.
    """
    candidates = [str(entry) for entry in Path(output_path).glob(f"{checkpoint_prefix}_*")]

    if use_mtime:
        keyed_paths = [(os.path.getmtime(candidate), candidate) for candidate in candidates]
    else:
        step_pattern = f".*{checkpoint_prefix}_([0-9]+)"
        keyed_paths = []
        for candidate in candidates:
            found = re.match(step_pattern, candidate)
            if found is not None:
                keyed_paths.append((int(found.group(1)), candidate))

    # tuples sort by key first, then by path for equal keys
    keyed_paths.sort()
    return [candidate for _, candidate in keyed_paths]
328 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MAIN]
2 |
3 | # Analyse import fallback blocks. This can be used to support both Python 2 and
4 | # 3 compatible code, which means that the block might have code that exists
5 | # only in one or another interpreter, leading to false positives when analysed.
6 | analyse-fallback-blocks=no
7 |
8 | # Clear in-memory caches upon conclusion of linting. Useful if running pylint
9 | # in a server-like mode.
10 | clear-cache-post-run=no
11 |
12 | # Load and enable all available extensions. Use --list-extensions to see a list
13 | # of all available extensions.
14 | #enable-all-extensions=
15 |
16 | # In error mode, messages with a category besides ERROR or FATAL are
17 | # suppressed, and no reports are done by default. Error mode is compatible with
18 | # disabling specific errors.
19 | #errors-only=
20 |
21 | # Always return a 0 (non-error) status code, even if lint errors are found.
22 | # This is primarily useful in continuous integration scripts.
23 | #exit-zero=
24 |
25 | # A comma-separated list of package or module names from where C extensions may
26 | # be loaded. Extensions are loaded into the active Python interpreter and may
27 | # run arbitrary code.
28 | extension-pkg-allow-list=
29 |
30 | # A comma-separated list of package or module names from where C extensions may
31 | # be loaded. Extensions are loaded into the active Python interpreter and may
32 | # run arbitrary code. (This is an alternative name to extension-pkg-allow-list
33 | # for backward compatibility.)
34 | extension-pkg-whitelist=
35 |
36 | # Return non-zero exit code if any of these messages/categories are detected,
37 | # even if score is above --fail-under value. Syntax same as enable. Messages
38 | # specified are enabled, while categories only check already-enabled messages.
39 | fail-on=
40 |
41 | # Specify a score threshold under which the program will exit with error.
42 | fail-under=10
43 |
44 | # Interpret the stdin as a python script, whose filename needs to be passed as
45 | # the module_or_package argument.
46 | #from-stdin=
47 |
48 | # Files or directories to be skipped. They should be base names, not paths.
49 | ignore=CVS
50 |
51 | # Add files or directories matching the regular expressions patterns to the
52 | # ignore-list. The regex matches against paths and can be in Posix or Windows
53 | # format. Because '\\' represents the directory delimiter on Windows systems,
54 | # it can't be used as an escape character.
55 | ignore-paths=
56 |
57 | # Files or directories matching the regular expression patterns are skipped.
58 | # The regex matches against base names, not paths. The default value ignores
59 | # Emacs file locks
60 | ignore-patterns=^\.#
61 |
62 | # List of module names for which member attributes should not be checked
63 | # (useful for modules/projects where namespaces are manipulated during runtime
64 | # and thus existing member attributes cannot be deduced by static analysis). It
65 | # supports qualified module names, as well as Unix pattern matching.
66 | ignored-modules=
67 |
68 | # Python code to execute, usually for sys.path manipulation such as
69 | # pygtk.require().
70 | #init-hook=
71 |
72 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
73 | # number of processors available to use, and will cap the count on Windows to
74 | # avoid hangs.
75 | jobs=1
76 |
77 | # Control the amount of potential inferred values when inferring a single
78 | # object. This can help the performance when dealing with large functions or
79 | # complex, nested conditions.
80 | limit-inference-results=100
81 |
82 | # List of plugins (as comma separated values of python module names) to load,
83 | # usually to register additional checkers.
84 | load-plugins=
85 |
86 | # Pickle collected data for later comparisons.
87 | persistent=yes
88 |
89 | # Minimum Python version to use for version dependent checks. Will default to
90 | # the version used to run pylint.
91 | py-version=3.11
92 |
93 | # Discover python modules and packages in the file system subtree.
94 | recursive=no
95 |
96 | # Add paths to the list of the source roots. Supports globbing patterns. The
97 | # source root is an absolute path or a path relative to the current working
98 | # directory used to determine a package namespace for modules located under the
99 | # source root.
100 | source-roots=
101 |
102 | # When enabled, pylint would attempt to guess common misconfiguration and emit
103 | # user-friendly hints instead of false-positive error messages.
104 | suggestion-mode=yes
105 |
106 | # Allow loading of arbitrary C extensions. Extensions are imported into the
107 | # active Python interpreter and may run arbitrary code.
108 | unsafe-load-any-extension=no
109 |
110 | # In verbose mode, extra non-checker-related info will be displayed.
111 | #verbose=
112 |
113 |
114 | [BASIC]
115 |
116 | # Naming style matching correct argument names.
117 | argument-naming-style=snake_case
118 |
119 | # Regular expression matching correct argument names. Overrides argument-
120 | # naming-style. If left empty, argument names will be checked with the set
121 | # naming style.
122 | #argument-rgx=
123 |
124 | # Naming style matching correct attribute names.
125 | attr-naming-style=snake_case
126 |
127 | # Regular expression matching correct attribute names. Overrides attr-naming-
128 | # style. If left empty, attribute names will be checked with the set naming
129 | # style.
130 | #attr-rgx=
131 |
132 | # Bad variable names which should always be refused, separated by a comma.
133 | bad-names=foo,
134 | bar,
135 | baz,
136 | toto,
137 | tutu,
138 | tata
139 |
140 | # Bad variable names regexes, separated by a comma. If names match any regex,
141 | # they will always be refused
142 | bad-names-rgxs=
143 |
144 | # Naming style matching correct class attribute names.
145 | class-attribute-naming-style=any
146 |
147 | # Regular expression matching correct class attribute names. Overrides class-
148 | # attribute-naming-style. If left empty, class attribute names will be checked
149 | # with the set naming style.
150 | #class-attribute-rgx=
151 |
152 | # Naming style matching correct class constant names.
153 | class-const-naming-style=UPPER_CASE
154 |
155 | # Regular expression matching correct class constant names. Overrides class-
156 | # const-naming-style. If left empty, class constant names will be checked with
157 | # the set naming style.
158 | #class-const-rgx=
159 |
160 | # Naming style matching correct class names.
161 | class-naming-style=PascalCase
162 |
163 | # Regular expression matching correct class names. Overrides class-naming-
164 | # style. If left empty, class names will be checked with the set naming style.
165 | #class-rgx=
166 |
167 | # Naming style matching correct constant names.
168 | const-naming-style=UPPER_CASE
169 |
170 | # Regular expression matching correct constant names. Overrides const-naming-
171 | # style. If left empty, constant names will be checked with the set naming
172 | # style.
173 | #const-rgx=
174 |
175 | # Minimum line length for functions/classes that require docstrings, shorter
176 | # ones are exempt.
177 | docstring-min-length=-1
178 |
179 | # Naming style matching correct function names.
180 | function-naming-style=snake_case
181 |
182 | # Regular expression matching correct function names. Overrides function-
183 | # naming-style. If left empty, function names will be checked with the set
184 | # naming style.
185 | #function-rgx=
186 |
187 | # Good variable names which should always be accepted, separated by a comma.
188 | good-names=i,
189 | j,
190 | k,
191 | ex,
192 | Run,
193 | _
194 |
195 | # Good variable names regexes, separated by a comma. If names match any regex,
196 | # they will always be accepted
197 | good-names-rgxs=
198 |
199 | # Include a hint for the correct naming format with invalid-name.
200 | include-naming-hint=no
201 |
202 | # Naming style matching correct inline iteration names.
203 | inlinevar-naming-style=any
204 |
205 | # Regular expression matching correct inline iteration names. Overrides
206 | # inlinevar-naming-style. If left empty, inline iteration names will be checked
207 | # with the set naming style.
208 | #inlinevar-rgx=
209 |
210 | # Naming style matching correct method names.
211 | method-naming-style=snake_case
212 |
213 | # Regular expression matching correct method names. Overrides method-naming-
214 | # style. If left empty, method names will be checked with the set naming style.
215 | #method-rgx=
216 |
217 | # Naming style matching correct module names.
218 | module-naming-style=snake_case
219 |
220 | # Regular expression matching correct module names. Overrides module-naming-
221 | # style. If left empty, module names will be checked with the set naming style.
222 | #module-rgx=
223 |
224 | # Colon-delimited sets of names that determine each other's naming style when
225 | # the name regexes allow several styles.
226 | name-group=
227 |
228 | # Regular expression which should only match function or class names that do
229 | # not require a docstring.
230 | no-docstring-rgx=^_
231 |
232 | # List of decorators that produce properties, such as abc.abstractproperty. Add
233 | # to this list to register other decorators that produce valid properties.
234 | # These decorators are taken in consideration only for invalid-name.
235 | property-classes=abc.abstractproperty
236 |
237 | # Regular expression matching correct type alias names. If left empty, type
238 | # alias names will be checked with the set naming style.
239 | #typealias-rgx=
240 |
241 | # Regular expression matching correct type variable names. If left empty, type
242 | # variable names will be checked with the set naming style.
243 | #typevar-rgx=
244 |
245 | # Naming style matching correct variable names.
246 | variable-naming-style=snake_case
247 |
248 | # Regular expression matching correct variable names. Overrides variable-
249 | # naming-style. If left empty, variable names will be checked with the set
250 | # naming style.
251 | #variable-rgx=
252 |
253 |
254 | [CLASSES]
255 |
256 | # Warn about protected attribute access inside special methods
257 | check-protected-access-in-special-methods=no
258 |
259 | # List of method names used to declare (i.e. assign) instance attributes.
260 | defining-attr-methods=__init__,
261 | __new__,
262 | setUp,
263 | asyncSetUp,
264 | __post_init__
265 |
266 | # List of member names, which should be excluded from the protected access
267 | # warning.
268 | exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
269 |
270 | # List of valid names for the first argument in a class method.
271 | valid-classmethod-first-arg=cls
272 |
273 | # List of valid names for the first argument in a metaclass class method.
274 | valid-metaclass-classmethod-first-arg=mcs
275 |
276 |
277 | [DESIGN]
278 |
279 | # List of regular expressions of class ancestor names to ignore when counting
280 | # public methods (see R0903)
281 | exclude-too-few-public-methods=
282 |
283 | # List of qualified class names to ignore when counting class parents (see
284 | # R0901)
285 | ignored-parents=
286 |
287 | # Maximum number of arguments for function / method.
288 | max-args=5
289 |
290 | # Maximum number of attributes for a class (see R0902).
291 | max-attributes=7
292 |
293 | # Maximum number of boolean expressions in an if statement (see R0916).
294 | max-bool-expr=5
295 |
296 | # Maximum number of branches for function / method body.
297 | max-branches=12
298 |
299 | # Maximum number of locals for function / method body.
300 | max-locals=15
301 |
302 | # Maximum number of parents for a class (see R0901).
303 | max-parents=7
304 |
305 | # Maximum number of public methods for a class (see R0904).
306 | max-public-methods=20
307 |
308 | # Maximum number of return / yield for function / method body.
309 | max-returns=6
310 |
311 | # Maximum number of statements in function / method body.
312 | max-statements=50
313 |
314 | # Minimum number of public methods for a class (see R0903).
315 | min-public-methods=2
316 |
317 |
318 | [EXCEPTIONS]
319 |
320 | # Exceptions that will emit a warning when caught.
321 | overgeneral-exceptions=builtins.BaseException,builtins.Exception
322 |
323 |
324 | [FORMAT]
325 |
326 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
327 | expected-line-ending-format=
328 |
329 | # Regexp for a line that is allowed to be longer than the limit.
330 | ignore-long-lines=^\s*(# )??$
331 |
332 | # Number of spaces of indent required inside a hanging or continued line.
333 | indent-after-paren=4
334 |
335 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
336 | # tab).
337 | indent-string=' '
338 |
339 | # Maximum number of characters on a single line.
340 | max-line-length=100
341 |
342 | # Maximum number of lines in a module.
343 | max-module-lines=1000
344 |
345 | # Allow the body of a class to be on the same line as the declaration if body
346 | # contains single statement.
347 | single-line-class-stmt=no
348 |
349 | # Allow the body of an if to be on the same line as the test if there is no
350 | # else.
351 | single-line-if-stmt=no
352 |
353 |
354 | [IMPORTS]
355 |
356 | # List of modules that can be imported at any level, not just the top level
357 | # one.
358 | allow-any-import-level=
359 |
360 | # Allow explicit reexports by alias from a package __init__.
361 | allow-reexport-from-package=no
362 |
363 | # Allow wildcard imports from modules that define __all__.
364 | allow-wildcard-with-all=no
365 |
366 | # Deprecated modules which should not be used, separated by a comma.
367 | deprecated-modules=
368 |
369 | # Output a graph (.gv or any supported image format) of external dependencies
370 | # to the given file (report RP0402 must not be disabled).
371 | ext-import-graph=
372 |
373 | # Output a graph (.gv or any supported image format) of all (i.e. internal and
374 | # external) dependencies to the given file (report RP0402 must not be
375 | # disabled).
376 | import-graph=
377 |
378 | # Output a graph (.gv or any supported image format) of internal dependencies
379 | # to the given file (report RP0402 must not be disabled).
380 | int-import-graph=
381 |
382 | # Force import order to recognize a module as part of the standard
383 | # compatibility libraries.
384 | known-standard-library=
385 |
386 | # Force import order to recognize a module as part of a third party library.
387 | known-third-party=enchant
388 |
389 | # Couples of modules and preferred modules, separated by a comma.
390 | preferred-modules=
391 |
392 |
393 | [LOGGING]
394 |
395 | # The type of string formatting that logging methods do. `old` means using %
396 | # formatting, `new` is for `{}` formatting.
397 | logging-format-style=old
398 |
399 | # Logging modules to check that the string format arguments are in logging
400 | # function parameter format.
401 | logging-modules=logging
402 |
403 |
404 | [MESSAGES CONTROL]
405 |
406 | # Only show warnings with the listed confidence levels. Leave empty to show
407 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE,
408 | # UNDEFINED.
409 | confidence=HIGH,
410 | INFERENCE,
411 | INFERENCE_FAILURE,
412 | UNDEFINED
413 |
414 | # Disable the message, report, category or checker with the given id(s). You
415 | # can either give multiple identifiers separated by comma (,) or put this
416 | # option multiple times (only on the command line, not in the configuration
417 | # file where it should appear only once). You can also use "--disable=all" to
418 | # disable everything first and then re-enable specific checks. For example, if
419 | # you want to run only the similarities checker, you can use "--disable=all
420 | # --enable=similarities". If you want to run only the classes checker, but have
421 | # no Warning level messages displayed, use "--disable=all --enable=classes
422 | # --disable=W".
423 | disable=raw-checker-failed,
424 | bad-inline-option,
425 | locally-disabled,
426 | file-ignored,
427 | suppressed-message,
428 | useless-suppression,
429 | deprecated-pragma,
430 | use-symbolic-message-instead,
431 | line-too-long,
432 | missing-function-docstring,
433 | missing-module-docstring,
434 | missing-class-docstring,
435 | invalid-name,
436 | consider-using-f-string,
437 | too-many-instance-attributes,
438 | no-member,
439 | too-many-locals,
440 | too-many-branches,
441 | too-many-arguments,
442 | fixme,
443 | too-many-lines,
444 | too-many-statements,
445 | too-many-public-methods,
446 | duplicate-code,
447 |
448 |
449 | # Enable the message, report, category or checker with the given id(s). You can
450 | # either give multiple identifiers separated by comma (,) or put this option
451 | # multiple times (only on the command line, not in the configuration file where
452 | # it should appear only once). See also the "--disable" option for examples.
453 | enable=c-extension-no-member
454 |
455 |
456 | [METHOD_ARGS]
457 |
458 | # List of qualified names (i.e., library.method) which require a timeout
459 | # parameter e.g. 'requests.api.get,requests.api.post'
460 | timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
461 |
462 |
463 | [MISCELLANEOUS]
464 |
465 | # List of note tags to take in consideration, separated by a comma.
466 | notes=FIXME,
467 | XXX,
468 | TODO
469 |
470 | # Regular expression of note tags to take in consideration.
471 | notes-rgx=
472 |
473 |
474 | [REFACTORING]
475 |
476 | # Maximum number of nested blocks for function / method body
477 | max-nested-blocks=5
478 |
479 | # Complete names of functions that never return. When checking for
480 | # inconsistent-return-statements if a never returning function is called then
481 | # it will be considered as an explicit return statement and no message will be
482 | # printed.
483 | never-returning-functions=sys.exit,argparse.parse_error
484 |
485 |
486 | [REPORTS]
487 |
488 | # Python expression which should return a score less than or equal to 10. You
489 | # have access to the variables 'fatal', 'error', 'warning', 'refactor',
490 | # 'convention', and 'info' which contain the number of messages in each
491 | # category, as well as 'statement' which is the total number of statements
492 | # analyzed. This score is used by the global evaluation report (RP0004).
493 | evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
494 |
495 | # Template used to display messages. This is a python new-style format string
496 | # used to format the message information. See doc for all details.
497 | msg-template=
498 |
499 | # Set the output format. Available formats are text, parseable, colorized, json
500 | # and msvs (visual studio). You can also give a reporter class, e.g.
501 | # mypackage.mymodule.MyReporterClass.
502 | #output-format=
503 |
504 | # Tells whether to display a full report or only the messages.
505 | reports=no
506 |
507 | # Activate the evaluation score.
508 | score=yes
509 |
510 |
511 | [SIMILARITIES]
512 |
513 | # Comments are removed from the similarity computation
514 | ignore-comments=yes
515 |
516 | # Docstrings are removed from the similarity computation
517 | ignore-docstrings=yes
518 |
519 | # Imports are removed from the similarity computation
520 | ignore-imports=yes
521 |
522 | # Signatures are removed from the similarity computation
523 | ignore-signatures=yes
524 |
525 | # Minimum lines number of a similarity.
526 | min-similarity-lines=4
527 |
528 |
529 | [SPELLING]
530 |
531 | # Limits count of emitted suggestions for spelling mistakes.
532 | max-spelling-suggestions=4
533 |
534 | # Spelling dictionary name. No available dictionaries: You need to install
535 | # both the python package and the system dependency for enchant to work.
536 | spelling-dict=
537 |
538 | # List of comma separated words that should be considered directives if they
539 | # appear at the beginning of a comment and should not be checked.
540 | spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
541 |
542 | # List of comma separated words that should not be checked.
543 | spelling-ignore-words=
544 |
545 | # A path to a file that contains the private dictionary; one word per line.
546 | spelling-private-dict-file=
547 |
548 | # Tells whether to store unknown words to the private dictionary (see the
549 | # --spelling-private-dict-file option) instead of raising a message.
550 | spelling-store-unknown-words=no
551 |
552 |
553 | [STRING]
554 |
555 | # This flag controls whether inconsistent-quotes generates a warning when the
556 | # character used as a quote delimiter is used inconsistently within a module.
557 | check-quote-consistency=no
558 |
559 | # This flag controls whether the implicit-str-concat should generate a warning
560 | # on implicit string concatenation in sequences defined over several lines.
561 | check-str-concat-over-line-jumps=no
562 |
563 |
564 | [TYPECHECK]
565 |
566 | # List of decorators that produce context managers, such as
567 | # contextlib.contextmanager. Add to this list to register other decorators that
568 | # produce valid context managers.
569 | contextmanager-decorators=contextlib.contextmanager
570 |
571 | # List of members which are set dynamically and missed by pylint inference
572 | # system, and so shouldn't trigger E1101 when accessed. Python regular
573 | # expressions are accepted.
574 | generated-members=
575 |
576 | # Tells whether to warn about missing members when the owner of the attribute
577 | # is inferred to be None.
578 | ignore-none=yes
579 |
580 | # This flag controls whether pylint should warn about no-member and similar
581 | # checks whenever an opaque object is returned when inferring. The inference
582 | # can return multiple potential results while evaluating a Python object, but
583 | # some branches might not be evaluated, which results in partial inference. In
584 | # that case, it might be useful to still emit no-member and other checks for
585 | # the rest of the inferred objects.
586 | ignore-on-opaque-inference=yes
587 |
588 | # List of symbolic message names to ignore for Mixin members.
589 | ignored-checks-for-mixins=no-member,
590 | not-async-context-manager,
591 | not-context-manager,
592 | attribute-defined-outside-init
593 |
594 | # List of class names for which member attributes should not be checked (useful
595 | # for classes with dynamically set attributes). This supports the use of
596 | # qualified names.
597 | ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
598 |
599 | # Show a hint with possible names when a member name was not found. Candidate
600 | # names are selected by edit distance.
601 | missing-member-hint=yes
602 |
603 | # The minimum edit distance a name should have in order to be considered a
604 | # similar match for a missing member name.
605 | missing-member-hint-distance=1
606 |
607 | # The total number of similar names that should be taken in consideration when
608 | # showing a hint for a missing member.
609 | missing-member-max-choices=1
610 |
611 | # Regex pattern to define which classes are considered mixins.
612 | mixin-class-rgx=.*[Mm]ixin
613 |
614 | # List of decorators that change the signature of a decorated function.
615 | signature-mutators=
616 |
617 |
618 | [VARIABLES]
619 |
620 | # List of additional names supposed to be defined in builtins. Remember that
621 | # you should avoid defining new builtins when possible.
622 | additional-builtins=
623 |
624 | # Tells whether unused global variables should be treated as a violation.
625 | allow-global-unused-variables=yes
626 |
627 | # List of names allowed to shadow builtins.
628 | allowed-redefined-builtins=
629 |
630 | # List of strings which can identify a callback function by name. A callback
631 | # name must start or end with one of those strings.
632 | callbacks=cb_,
633 | _cb
634 |
635 | # A regular expression matching the name of dummy variables (i.e. expected to
636 | # not be used).
637 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
638 |
639 | # Argument names that match this expression will be ignored.
640 | ignored-argument-names=_.*|^ignored_|^unused_
641 |
642 | # Tells whether we should check for unused import in __init__ files.
643 | init-import=no
644 |
645 | # List of qualified module names which can have objects that can redefine
646 | # builtins.
647 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
648 |
--------------------------------------------------------------------------------
/tests/test_train_gan.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import dataclass
3 | from typing import Any, Dict, Tuple
4 |
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms
10 | from torchvision.datasets import MNIST
11 |
12 | from trainer import TrainerConfig, TrainerModel
13 | from trainer.trainer import Trainer, TrainerArgs
14 |
# Whether a CUDA device is visible; the tests below pass ``gpu=0`` to the Trainer when it is, else ``None``.
is_cuda = torch.cuda.is_available()
16 |
17 |
18 | # pylint: skip-file
19 |
20 |
class Generator(nn.Module):
    """Fully connected generator that maps a latent vector to an image tensor.

    Args:
        latent_dim: Number of features in the input noise vector.
        img_shape: Shape of a single generated sample; the network output is
            reshaped to ``(batch, *img_shape)``.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        # (in_features, out_features, normalize) for every hidden stage.
        # Only the first stage skips batch normalization.
        stages = (
            (latent_dim, 128, False),
            (128, 256, True),
            (256, 512, True),
            (512, 1024, True),
        )

        layers = []
        for in_feat, out_feat, normalize in stages:
            layers.append(nn.Linear(in_feat, out_feat))
            if normalize:
                # NOTE(review): the second positional argument of BatchNorm1d is `eps`,
                # so 0.8 acts as epsilon here -- confirm `momentum` was not intended.
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Project to the flattened image size and squash into [-1, 1].
        layers.append(nn.Linear(1024, int(np.prod(img_shape))))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Run the network on ``z`` and reshape the result to ``(batch, *img_shape)``."""
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)
46 |
47 |
class Discriminator(nn.Module):
    """Three-layer MLP that scores flattened images with a sigmoid output.

    Args:
        img_shape: Shape of a single input sample; its element count sets the
            width of the first linear layer.
    """

    def __init__(self, img_shape):
        super().__init__()

        n_inputs = int(np.prod(img_shape))
        layers = [
            nn.Linear(n_inputs, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Flatten every sample in ``img`` and return one score per sample, shape ``(batch, 1)``."""
        batch_size = img.size(0)
        return self.model(img.view(batch_size, -1))
66 |
67 |
def test_overfit_mnist_simple_gan():
    """Overfit a small GAN on 64 MNIST samples through ``train_step``.

    ``train_step`` is keyed by ``optimizer_idx``: 0 computes the discriminator
    loss, 1 the generator loss. The test fits once with ``epochs = 1``, records
    the averaged train losses, fits again with ``epochs = 5`` and asserts that
    both averaged losses went down.
    """

    @dataclass
    class GANModelConfig(TrainerConfig):
        epochs: int = 1
        print_step: int = 2
        training_seed: int = 666

    class GANModel(TrainerModel):
        def __init__(self):
            super().__init__()
            img_shape = (1, 28, 28)
            self.generator = Generator(latent_dim=100, img_shape=img_shape)
            self.discriminator = Discriminator(img_shape=img_shape)

        def forward(self, x):
            ...

        def train_step(self, batch, criterion, optimizer_idx):
            imgs, _ = batch
            n_imgs = imgs.size(0)

            # Latent noise, cast to match the image tensor.
            z = torch.randn(imgs.shape[0], 100).type_as(imgs)

            if optimizer_idx == 0:
                # Discriminator: generated images (detached) should score 0, real images 1.
                fake_logits = self.discriminator(self.generator(z).detach())
                fake_target = torch.zeros(n_imgs, 1).type_as(imgs)
                loss_fake = criterion(fake_logits, fake_target)

                real_target = torch.ones(n_imgs, 1).type_as(imgs)
                real_logits = self.discriminator(imgs)
                loss_real = criterion(real_logits, real_target)

                return {"model_outputs": real_logits}, {"loss": (loss_real + loss_fake) / 2}

            if optimizer_idx == 1:
                # Generator: its samples should be scored as real.
                generated = self.generator(z)
                real_target = torch.ones(n_imgs, 1).type_as(imgs)
                gen_logits = self.discriminator(generated)
                return {"model_outputs": gen_logits}, {"loss": criterion(gen_logits, real_target)}

            # Any other optimizer index yields no outputs, as before.
            return None

        @torch.no_grad()
        def eval_step(self, batch, criterion, optimizer_idx):
            return self.train_step(batch, criterion, optimizer_idx)

        def get_optimizer(self):
            adam_betas = (0.5, 0.999)
            optimizers = [
                torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=adam_betas),
                torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=adam_betas),
            ]
            return optimizers

        def get_criterion(self):
            return nn.BCELoss()

        def get_data_loader(
            self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
        ):  # pylint: disable=unused-argument
            preprocessing = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            )
            mnist = MNIST(os.getcwd(), train=not is_eval, download=True, transform=preprocessing)
            # Keep only the first 64 samples so the model can overfit quickly.
            mnist.data = mnist.data[:64]
            mnist.targets = mnist.targets[:64]
            return DataLoader(mnist, batch_size=config.batch_size, drop_last=True, shuffle=True)

    config = GANModelConfig()
    config.batch_size = 64
    config.grad_clip = None

    model = GANModel()
    trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)

    def averaged_train_losses():
        # (discriminator, generator) averages as currently stored on the trainer.
        averages = trainer.keep_avg_train
        return averages["avg_loss_0"], averages["avg_loss_1"]

    trainer.config.epochs = 1
    trainer.fit()
    loss_d1, loss_g1 = averaged_train_losses()

    trainer.config.epochs = 5
    trainer.fit()
    loss_d2, loss_g2 = averaged_train_losses()

    print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}")
    print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}")
    assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}"
    assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
161 |
162 |
def test_overfit_accelerate_mnist_simple_gan():
    """Overfit the small GAN with ``TrainerArgs(use_accelerate=True)``.

    The baseline losses come from a single ``trainer.eval_epoch()`` call made
    before any training. They are compared against the averaged train losses
    after a ``fit()`` with ``epochs = 5``; both must have decreased.
    """

    @dataclass
    class GANModelConfig(TrainerConfig):
        epochs: int = 1
        print_step: int = 2
        training_seed: int = 666

    class GANModel(TrainerModel):
        def __init__(self):
            super().__init__()
            img_shape = (1, 28, 28)
            self.generator = Generator(latent_dim=100, img_shape=img_shape)
            self.discriminator = Discriminator(img_shape=img_shape)

        def forward(self, x):
            ...

        def train_step(self, batch, criterion, optimizer_idx):
            imgs, _ = batch
            n_imgs = imgs.size(0)

            # Latent noise, cast to match the image tensor.
            z = torch.randn(imgs.shape[0], 100).type_as(imgs)

            if optimizer_idx == 0:
                # Discriminator: generated images (detached) should score 0, real images 1.
                fake_logits = self.discriminator(self.generator(z).detach())
                fake_target = torch.zeros(n_imgs, 1).type_as(imgs)
                loss_fake = criterion(fake_logits, fake_target)

                real_target = torch.ones(n_imgs, 1).type_as(imgs)
                real_logits = self.discriminator(imgs)
                loss_real = criterion(real_logits, real_target)

                return {"model_outputs": real_logits}, {"loss": (loss_real + loss_fake) / 2}

            if optimizer_idx == 1:
                # Generator: its samples should be scored as real.
                generated = self.generator(z)
                real_target = torch.ones(n_imgs, 1).type_as(imgs)
                gen_logits = self.discriminator(generated)
                return {"model_outputs": gen_logits}, {"loss": criterion(gen_logits, real_target)}

            # Any other optimizer index yields no outputs, as before.
            return None

        @torch.no_grad()
        def eval_step(self, batch, criterion, optimizer_idx):
            return self.train_step(batch, criterion, optimizer_idx)

        def get_optimizer(self):
            adam_betas = (0.5, 0.999)
            optimizers = [
                torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=adam_betas),
                torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=adam_betas),
            ]
            return optimizers

        def get_criterion(self):
            return nn.BCELoss()

        def get_data_loader(
            self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
        ):  # pylint: disable=unused-argument
            preprocessing = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            )
            mnist = MNIST(os.getcwd(), train=not is_eval, download=True, transform=preprocessing)
            # Keep only the first 64 samples; this loader does not shuffle.
            mnist.data = mnist.data[:64]
            mnist.targets = mnist.targets[:64]
            return DataLoader(mnist, batch_size=config.batch_size, drop_last=True, shuffle=False)

    config = GANModelConfig()
    config.batch_size = 64
    config.grad_clip = None
    config.training_seed = 333

    model = GANModel()
    trainer_args = TrainerArgs(use_accelerate=True)
    trainer = Trainer(trainer_args, config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)

    # Baseline: evaluation losses of the untrained model.
    trainer.eval_epoch()
    eval_averages = trainer.keep_avg_eval
    loss_d1 = eval_averages["avg_loss_0"]
    loss_g1 = eval_averages["avg_loss_1"]

    trainer.config.epochs = 5
    trainer.fit()
    train_averages = trainer.keep_avg_train
    loss_d2 = train_averages["avg_loss_0"]
    loss_g2 = train_averages["avg_loss_1"]

    print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}")
    print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}")
    assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}"
    assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
258 |
259 |
def test_overfit_manual_optimize_mnist_simple_gan():
    """Overfit the small GAN with a model-defined ``optimize`` step.

    The model drives both optimizers itself inside ``optimize`` (zero grads,
    ``scaled_backward``, step) and reports ``loss_disc`` / ``loss_gen``. The
    test fits with ``epochs = 1``, then ``epochs = 5``, and asserts that both
    averaged train losses decreased.
    """

    @dataclass
    class GANModelConfig(TrainerConfig):
        epochs: int = 1
        print_step: int = 2
        training_seed: int = 666

    class GANModel(TrainerModel):
        def __init__(self):
            super().__init__()
            img_shape = (1, 28, 28)
            self.generator = Generator(latent_dim=100, img_shape=img_shape)
            self.discriminator = Discriminator(img_shape=img_shape)

        def forward(self, x):
            ...

        def optimize(self, batch, trainer):
            imgs, _ = batch
            n_imgs = imgs.size(0)
            criterion = trainer.criterion

            # Latent noise, cast to match the image tensor.
            z = torch.randn(imgs.shape[0], 100).type_as(imgs)

            # --- discriminator loss: fakes (detached) -> 0, reals -> 1 ---
            fake_logits = self.discriminator(self.generator(z).detach())
            fake_target = torch.zeros(n_imgs, 1).type_as(imgs)
            loss_fake = criterion(fake_logits, fake_target)

            real_target = torch.ones(n_imgs, 1).type_as(imgs)
            real_logits = self.discriminator(imgs)
            loss_real = criterion(real_logits, real_target)
            loss_disc = (loss_real + loss_fake) / 2

            # step discriminator
            disc_optimizer = trainer.optimizer[0]
            disc_optimizer.zero_grad()
            self.scaled_backward(loss_disc, trainer, disc_optimizer)
            disc_optimizer.step()

            # --- generator loss: its samples should be scored as real ---
            generated = self.generator(z)
            real_target = torch.ones(n_imgs, 1).type_as(imgs)
            logits = self.discriminator(generated)
            loss_gen = criterion(logits, real_target)

            # step generator
            gen_optimizer = trainer.optimizer[1]
            gen_optimizer.zero_grad()
            self.scaled_backward(loss_gen, trainer, gen_optimizer)
            gen_optimizer.step()

            return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

        @torch.no_grad()
        def eval_step(self, batch, trainer):
            imgs, _ = batch

            # Only the generator loss is evaluated: fresh noise -> generator -> discriminator.
            z = torch.randn(imgs.shape[0], 100).type_as(imgs)
            generated = self.generator(z)
            real_target = torch.ones(imgs.size(0), 1).type_as(imgs)

            logits = self.discriminator(generated)
            return {"model_outputs": logits}, {"loss_gen": trainer.criterion(logits, real_target)}

        def get_optimizer(self):
            adam_betas = (0.5, 0.999)
            optimizers = [
                torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=adam_betas),
                torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=adam_betas),
            ]
            return optimizers

        def get_criterion(self):
            return nn.BCELoss()

        def get_data_loader(
            self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
        ):  # pylint: disable=unused-argument
            preprocessing = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            )
            mnist = MNIST(os.getcwd(), train=not is_eval, download=True, transform=preprocessing)
            # Keep only the first 64 samples so the model can overfit quickly.
            mnist.data = mnist.data[:64]
            mnist.targets = mnist.targets[:64]
            return DataLoader(mnist, batch_size=config.batch_size, drop_last=True, shuffle=True)

    config = GANModelConfig()
    config.batch_size = 64
    config.grad_clip = None

    model = GANModel()
    trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)

    def averaged_train_losses():
        # (discriminator, generator) averages as currently stored on the trainer.
        averages = trainer.keep_avg_train
        return averages["avg_loss_disc"], averages["avg_loss_gen"]

    trainer.config.epochs = 1
    trainer.fit()
    loss_d1, loss_g1 = averaged_train_losses()

    trainer.config.epochs = 5
    trainer.fit()
    loss_d2, loss_g2 = averaged_train_losses()

    print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}")
    print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}")
    assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}"
    assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
372 |
373 |
def test_overfit_manual_optimize_grad_accum_mnist_simple_gan():
    """Overfit the small GAN with a model-defined ``optimize`` step and gradient accumulation.

    Inside ``optimize`` every batch is back-propagated via ``scaled_backward``,
    but an optimizer only steps (and is then zeroed) when
    ``trainer.total_steps_done % trainer.grad_accum_steps == 0``. The test fits
    with ``epochs = 1``, then ``epochs = 5``, and asserts that both averaged
    train losses decreased.
    """

    @dataclass
    class GANModelConfig(TrainerConfig):
        epochs: int = 1
        print_step: int = 2
        training_seed: int = 666

    class GANModel(TrainerModel):
        def __init__(self):
            super().__init__()
            data_shape = (1, 28, 28)
            self.generator = Generator(latent_dim=100, img_shape=data_shape)
            self.discriminator = Discriminator(img_shape=data_shape)

        def forward(self, x):
            ...

        def optimize(self, batch, trainer):
            imgs, _ = batch

            # sample noise
            z = torch.randn(imgs.shape[0], 100)
            z = z.type_as(imgs)

            # train discriminator: fakes (detached) -> 0, reals -> 1
            imgs_gen = self.generator(z)
            logits = self.discriminator(imgs_gen.detach())
            fake = torch.zeros(imgs.size(0), 1)
            fake = fake.type_as(imgs)
            loss_fake = trainer.criterion(logits, fake)

            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)
            logits = self.discriminator(imgs)
            loss_real = trainer.criterion(logits, valid)
            loss_disc = (loss_real + loss_fake) / 2

            # step discriminator (only on accumulation boundaries)
            self.scaled_backward(loss_disc, trainer, trainer.optimizer[0])

            if trainer.total_steps_done % trainer.grad_accum_steps == 0:
                trainer.optimizer[0].step()
                trainer.optimizer[0].zero_grad()

            # train generator: its samples should be scored as real
            imgs_gen = self.generator(z)

            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            logits = self.discriminator(imgs_gen)
            loss_gen = trainer.criterion(logits, valid)

            # step generator (only on accumulation boundaries)
            self.scaled_backward(loss_gen, trainer, trainer.optimizer[1])
            if trainer.total_steps_done % trainer.grad_accum_steps == 0:
                trainer.optimizer[1].step()
                trainer.optimizer[1].zero_grad()
            return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

        @torch.no_grad()
        def eval_step(self, batch, trainer):
            # The second argument is named ``trainer`` and used directly. Previously it was
            # called ``criterion`` and ignored, while the body reached for the enclosing
            # test's ``trainer`` local through the closure.
            imgs, _ = batch

            # sample noise
            z = torch.randn(imgs.shape[0], 100)
            z = z.type_as(imgs)

            imgs_gen = self.generator(z)
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            logits = self.discriminator(imgs_gen)
            loss_gen = trainer.criterion(logits, valid)
            return {"model_outputs": logits}, {"loss_gen": loss_gen}

        def get_optimizer(self):
            discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
            generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999))
            return [discriminator_optimizer, generator_optimizer]

        def get_criterion(self):
            return nn.BCELoss()

        def get_data_loader(
            self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
        ):  # pylint: disable=unused-argument
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
            dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
            # Keep only the first 64 samples so the model can overfit quickly.
            dataset.data = dataset.data[:64]
            dataset.targets = dataset.targets[:64]
            dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True)
            return dataloader

    config = GANModelConfig()
    config.batch_size = 64
    config.grad_clip = None

    model = GANModel()
    trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None)

    trainer.config.epochs = 1
    trainer.fit()
    loss_d1 = trainer.keep_avg_train["avg_loss_disc"]
    loss_g1 = trainer.keep_avg_train["avg_loss_gen"]

    trainer.config.epochs = 5
    trainer.fit()
    loss_d2 = trainer.keep_avg_train["avg_loss_disc"]
    loss_g2 = trainer.keep_avg_train["avg_loss_gen"]

    print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}")
    print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}")
    assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}"
    assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
489 |
490 |
def test_overfit_manual_accelerate_optimize_grad_accum_mnist_simple_gan():
    """Overfit the small GAN with ``use_accelerate=True``, a model-defined ``optimize`` and grad accumulation.

    Inside ``optimize`` every batch is back-propagated via ``scaled_backward``,
    but an optimizer only steps (and is then zeroed) when
    ``trainer.total_steps_done % trainer.grad_accum_steps == 0``. The test fits
    with ``epochs = 1``, then ``epochs = 5``, and asserts that both averaged
    train losses decreased.
    """

    @dataclass
    class GANModelConfig(TrainerConfig):
        epochs: int = 1
        print_step: int = 2
        training_seed: int = 666

    class GANModel(TrainerModel):
        def __init__(self):
            super().__init__()
            data_shape = (1, 28, 28)
            self.generator = Generator(latent_dim=100, img_shape=data_shape)
            self.discriminator = Discriminator(img_shape=data_shape)

        def train_step(self, *args, **kwargs):
            # Placeholder only; training goes through ``optimize``. It now accepts ``self``
            # (and any arguments) so calling it on an instance no longer raises TypeError.
            ...

        def forward(self, x):
            ...

        def optimize(self, batch, trainer):
            imgs, _ = batch

            # sample noise
            z = torch.randn(imgs.shape[0], 100)
            z = z.type_as(imgs)

            # train discriminator: fakes (detached) -> 0, reals -> 1
            imgs_gen = self.generator(z)
            logits = self.discriminator(imgs_gen.detach())
            fake = torch.zeros(imgs.size(0), 1)
            fake = fake.type_as(imgs)
            loss_fake = trainer.criterion(logits, fake)

            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)
            logits = self.discriminator(imgs)
            loss_real = trainer.criterion(logits, valid)
            loss_disc = (loss_real + loss_fake) / 2

            # step discriminator (only on accumulation boundaries)
            self.scaled_backward(loss_disc, trainer, trainer.optimizer[0])

            if trainer.total_steps_done % trainer.grad_accum_steps == 0:
                trainer.optimizer[0].step()
                trainer.optimizer[0].zero_grad()

            # train generator: its samples should be scored as real
            imgs_gen = self.generator(z)

            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            logits = self.discriminator(imgs_gen)
            loss_gen = trainer.criterion(logits, valid)

            # step generator (only on accumulation boundaries)
            self.scaled_backward(loss_gen, trainer, trainer.optimizer[1])
            if trainer.total_steps_done % trainer.grad_accum_steps == 0:
                trainer.optimizer[1].step()
                trainer.optimizer[1].zero_grad()
            return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

        @torch.no_grad()
        def eval_step(self, batch, trainer):
            # The second argument is named ``trainer`` and used directly. Previously it was
            # called ``criterion`` and ignored, while the body reached for the enclosing
            # test's ``trainer`` local through the closure.
            imgs, _ = batch

            # sample noise
            z = torch.randn(imgs.shape[0], 100)
            z = z.type_as(imgs)

            imgs_gen = self.generator(z)
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            logits = self.discriminator(imgs_gen)
            loss_gen = trainer.criterion(logits, valid)
            return {"model_outputs": logits}, {"loss_gen": loss_gen}

        def get_optimizer(self):
            discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
            generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999))
            return [discriminator_optimizer, generator_optimizer]

        def get_criterion(self):
            return nn.BCELoss()

        def get_data_loader(
            self, config, assets, is_eval, samples, verbose, num_gpus, rank=0
        ):  # pylint: disable=unused-argument
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
            dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform)
            # Keep only the first 64 samples so the model can overfit quickly.
            dataset.data = dataset.data[:64]
            dataset.targets = dataset.targets[:64]
            dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True)
            return dataloader

    config = GANModelConfig()
    config.batch_size = 64
    config.grad_clip = None

    model = GANModel()
    trainer = Trainer(
        TrainerArgs(use_accelerate=True), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None
    )

    trainer.config.epochs = 1
    trainer.fit()
    loss_d1 = trainer.keep_avg_train["avg_loss_disc"]
    loss_g1 = trainer.keep_avg_train["avg_loss_gen"]

    trainer.config.epochs = 5
    trainer.fit()
    loss_d2 = trainer.keep_avg_train["avg_loss_disc"]
    loss_g2 = trainer.keep_avg_train["avg_loss_gen"]

    print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}")
    print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}")
    assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}"
    assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"
611 |
612 |
if __name__ == "__main__":
    # Script entry point: run every GAN test in definition order.
    for _gan_test in (
        test_overfit_mnist_simple_gan,
        test_overfit_accelerate_mnist_simple_gan,
        test_overfit_manual_optimize_mnist_simple_gan,
        test_overfit_manual_optimize_grad_accum_mnist_simple_gan,
        test_overfit_manual_accelerate_optimize_grad_accum_mnist_simple_gan,
    ):
        _gan_test()
619 |
--------------------------------------------------------------------------------