├── trainer ├── VERSION ├── utils │ ├── __init__.py │ ├── cpu_memory.py │ ├── distributed.py │ └── cuda_memory.py ├── logger.py ├── __init__.py ├── analytics.py ├── TODO.txt ├── logging │ ├── __init__.py │ ├── dummy_logger.py │ ├── clearml_logger.py │ ├── base_dash_logger.py │ ├── tensorboard_logger.py │ ├── wandb_logger.py │ ├── console_logger.py │ ├── mlflow_logger.py │ └── aim_logger.py ├── distribute.py ├── torch.py ├── generic_utils.py ├── trainer_utils.py ├── model.py ├── callbacks.py └── io.py ├── requirements.test.txt ├── requirements.txt ├── requirements.dev.txt ├── setup.cfg ├── tests ├── __init__.py ├── test_train_mnist.py ├── test_train_batch_size_finder.py ├── utils │ ├── train_mnist.py │ └── mnist.py ├── test_continue_train.py ├── test_lr_schedulers.py ├── test_num_gpus.py └── test_train_gan.py ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── feature_request.md │ └── bug_report.yaml └── workflows │ ├── style_check.yml │ ├── tests.yml │ └── pypi-release.yml ├── MANIFEST.in ├── pyproject.toml ├── Makefile ├── bin └── collect_env_info.py ├── .gitignore ├── examples ├── train_mnist.py └── train_simple_gan.py ├── CONTRIBUTING.md ├── setup.py ├── CODE_OF_CONDUCT.md ├── README.md └── .pylintrc /trainer/VERSION: -------------------------------------------------------------------------------- 1 | v0.0.36 2 | -------------------------------------------------------------------------------- /trainer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.test.txt: -------------------------------------------------------------------------------- 1 | torchvision -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7 2 | coqpit 3 | psutil 4 | fsspec 5 | tensorboard 6 | soundfile 
-------------------------------------------------------------------------------- /requirements.dev.txt: -------------------------------------------------------------------------------- 1 | black 2 | coverage 3 | isort 4 | pytest 5 | pylint 6 | accelerate # for testing 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [build_py] 2 | build-lib=temp_build 3 | 4 | [bdist_wheel] 5 | bdist-dir=temp_build 6 | 7 | [install_lib] 8 | build-dir=temp_build 9 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def run_cli(command): 5 | exit_status = os.system(command) 6 | assert exit_status == 0, f" [!] command `{command}` failed." 7 | -------------------------------------------------------------------------------- /trainer/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | handler = logging.StreamHandler() 4 | handler.setFormatter(logging.Formatter("")) 5 | logger = logging.getLogger("trainer") 6 | logger.addHandler(handler) 7 | logger.setLevel(logging.INFO) 8 | logger.propagate = False 9 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from trainer.model import * 4 | from trainer.trainer import * 5 | 6 | with open(os.path.join(os.path.dirname(__file__), "VERSION"), "r", encoding="utf-8") as f: 7 | version = f.read().strip() 8 | 9 | __version__ = version 10 | -------------------------------------------------------------------------------- /trainer/analytics.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import 
requests 4 | 5 | telemetry = os.environ.get("TRAINER_TELEMETRY") 6 | 7 | 8 | def ping_training_run(): 9 | if telemetry == "0": 10 | return 11 | URL = "https://coqui.gateway.scarf.sh/trainer/training_run" 12 | _ = requests.get(URL, timeout=5) 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: 👟 GitHub Discussions 4 | url: https://github.com/coqui-ai/Trainer/discussions 5 | about: Please ask and answer questions here. 6 | - name: Coqui Security issue disclosure 7 | url: mailto:info@coqui.ai 8 | about: Please report security vulnerabilities here. 9 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE.txt 3 | include requirements.*.txt 4 | include requirements.txt 5 | include trainer/VERSION 6 | recursive-include trainer *.json 7 | recursive-include trainer *.html 8 | recursive-include trainer *.png 9 | recursive-include trainer *.md 10 | recursive-include trainer *.py 11 | recursive-include trainer *.pyx 12 | recursive-include images *.png 13 | -------------------------------------------------------------------------------- /trainer/TODO.txt: -------------------------------------------------------------------------------- 1 | + Accumulate gradients b/w batches. 2 | + Abstract DashLogger 3 | + MLFlow logger 4 | + Profiler integration. 5 | + Moving `training_assets` to the model implementation. 6 | - Wrap model for not calling .module in DDP. 7 | - Overfitting to a batch. 8 | - TPU training 9 | - BaseTrainingConfig 10 | - Add Checkpoint manager 11 | - Use `logging` instead of `print` 12 | - Auto scaling the batch size and find the largest batch size for training. 
13 | - Stochastic weight averaging 14 | - Deepspeed integration 15 | -------------------------------------------------------------------------------- /tests/test_train_mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from tests.utils.mnist import MnistModel, MnistModelConfig 6 | from trainer import Trainer, TrainerArgs 7 | 8 | is_cuda = torch.cuda.is_available() 9 | 10 | 11 | def test_train_mnist(): 12 | model = MnistModel() 13 | trainer = Trainer( 14 | TrainerArgs(), MnistModelConfig(), model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None 15 | ) 16 | 17 | trainer.fit() 18 | loss1 = trainer.keep_avg_train["avg_loss"] 19 | 20 | trainer.fit() 21 | loss2 = trainer.keep_avg_train["avg_loss"] 22 | 23 | assert loss1 > loss2 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | 4 | [flake8] 5 | max-line-length=120 6 | 7 | [tool.black] 8 | line-length = 120 9 | target-version = ['py38'] 10 | exclude = ''' 11 | 12 | ( 13 | /( 14 | \.eggs # exclude a few common directories in the 15 | | \.git # root of the project 16 | | \.hg 17 | | \.mypy_cache 18 | | \.tox 19 | | \.venv 20 | | _build 21 | | buck-out 22 | | build 23 | | dist 24 | )/ 25 | | foo.py # also separately exclude a file named foo.py in 26 | # the root of the project 27 | ) 28 | ''' 29 | 30 | [tool.isort] 31 | profile = "black" 32 | multi_line_output = 3 -------------------------------------------------------------------------------- /tests/test_train_batch_size_finder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from tests.utils.mnist import MnistModel, MnistModelConfig 6 | from trainer import Trainer, TrainerArgs 7 | 8 | is_cuda = 
torch.cuda.is_available() 9 | 10 | 11 | def test_train_largest_batch_mnist(): 12 | model = MnistModel() 13 | trainer = Trainer( 14 | TrainerArgs(), MnistModelConfig(), model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None 15 | ) 16 | 17 | trainer.fit_with_largest_batch_size(starting_batch_size=2048) 18 | loss1 = trainer.keep_avg_train["avg_loss"] 19 | 20 | trainer.fit_with_largest_batch_size(starting_batch_size=2048) 21 | loss2 = trainer.keep_avg_train["avg_loss"] 22 | 23 | assert loss1 > loss2 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🚀 Feature request 3 | about: Suggest a feature or an idea for this project 4 | title: '[Feature request] ' 5 | labels: feature request 6 | assignees: '' 7 | 8 | --- 9 | 11 | **🚀 Feature Description** 12 | 13 | 14 | 15 | **Solution** 16 | 17 | 18 | 19 | **Alternative Solutions** 20 | 21 | 22 | 23 | **Additional context** 24 | 25 | 26 | -------------------------------------------------------------------------------- /tests/utils/train_mnist.py: -------------------------------------------------------------------------------- 1 | from distutils.command.config import config 2 | 3 | from mnist import MnistModel, MnistModelConfig 4 | 5 | from trainer import Trainer, TrainerArgs 6 | 7 | 8 | def main(): 9 | """Run `MNIST` model training from scratch or from previous checkpoint.""" 10 | # init args and config 11 | train_args = TrainerArgs() 12 | config = MnistModelConfig() 13 | 14 | # init the model from config 15 | model = MnistModel() 16 | 17 | # init the trainer and 🚀 18 | trainer = Trainer( 19 | train_args, 20 | config, 21 | config.output_path, 22 | model=model, 23 | train_samples=model.get_data_loader(config, None, False, None, None, None), 24 | eval_samples=model.get_data_loader(config, None, True, None, None, None), 25 | 
parse_command_line_args=True, 26 | ) 27 | trainer.fit() 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := help 2 | .PHONY: test system-deps dev-deps deps style lint install help docs 3 | 4 | help: 5 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 6 | 7 | target_dirs := tests trainer 8 | 9 | test_all: ## run tests and don't stop on an error. 10 | coverage run -m pytest trainer tests 11 | 12 | test: ## run tests. 13 | coverage run -m pytest -x trainer tests 14 | 15 | test_failed: ## only run tests failed the last time. 16 | coverage run -m pytest --ff trainer tests 17 | 18 | style: ## update code style. 19 | black ${target_dirs} 20 | isort ${target_dirs} 21 | 22 | lint: ## run pylint linter. 23 | pylint ${target_dirs} 24 | 25 | dev-deps: ## install development deps 26 | pip install -r requirements.dev.txt 27 | 28 | doc-deps: ## install docs dependencies 29 | pip install -r docs/requirements.txt 30 | 31 | build-docs: ## build the docs 32 | cd docs && make clean && make build 33 | 34 | deps: ## install 🐸 requirements. 35 | pip install -r requirements.txt 36 | 37 | install: ## install 🐸 Trainer for development. 
38 | pip install -e .[all] 39 | 40 | docs: ## build the docs 41 | $(MAKE) -C docs clean && $(MAKE) -C docs html 42 | -------------------------------------------------------------------------------- /tests/test_continue_train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | 5 | from tests import run_cli 6 | 7 | 8 | def test_continue_train(): 9 | output_path = "output/" 10 | 11 | command_train = "python tests/utils/train_mnist.py" 12 | run_cli(command_train) 13 | 14 | continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) 15 | number_of_checkpoints = len(glob.glob(os.path.join(continue_path, "*.pth"))) 16 | 17 | # Continue training from the best model 18 | command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path} --coqpit.run_eval_steps=1" 19 | run_cli(command_continue) 20 | 21 | assert number_of_checkpoints < len(glob.glob(os.path.join(continue_path, "*.pth"))) 22 | 23 | # Continue training from the last checkpoint 24 | for best in glob.glob(os.path.join(continue_path, "best_model*")): 25 | os.remove(best) 26 | run_cli(command_continue) 27 | 28 | # Continue training from a specific checkpoint 29 | restore_path = os.path.join(continue_path, "checkpoint_5.pth") 30 | command_continue = f"python tests/utils/train_mnist.py --restore_path {restore_path}" 31 | run_cli(command_continue) 32 | shutil.rmtree(continue_path) 33 | -------------------------------------------------------------------------------- /bin/collect_env_info.py: -------------------------------------------------------------------------------- 1 | """Get detailed info about the working environment.""" 2 | import os 3 | import platform 4 | import sys 5 | 6 | import numpy 7 | import torch 8 | 9 | sys.path += [os.path.abspath(".."), os.path.abspath(".")] 10 | import json 11 | 12 | import trainer 13 | 14 | 15 | def system_info(): 16 | return { 17 | "OS": 
platform.system(), 18 | "architecture": platform.architecture(), 19 | "version": platform.version(), 20 | "processor": platform.processor(), 21 | "python": platform.python_version(), 22 | } 23 | 24 | 25 | def cuda_info(): 26 | return { 27 | "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())], 28 | "available": torch.cuda.is_available(), 29 | "version": torch.version.cuda, 30 | } 31 | 32 | 33 | def package_info(): 34 | return { 35 | "numpy": numpy.__version__, 36 | "PyTorch_version": torch.__version__, 37 | "PyTorch_debug": torch.version.debug, 38 | "Trainer": trainer.__version__, 39 | } 40 | 41 | 42 | def main(): 43 | details = {"System": system_info(), "CUDA": cuda_info(), "Packages": package_info()} 44 | print(json.dumps(details, indent=4, sort_keys=True)) 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /tests/test_lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch 5 | 6 | from tests.utils.mnist import MnistModel, MnistModelConfig 7 | from trainer import Trainer, TrainerArgs 8 | from trainer.generic_utils import KeepAverage 9 | 10 | is_cuda = torch.cuda.is_available() 11 | 12 | 13 | def test_train_mnist(): 14 | model = MnistModel() 15 | # Test StepwiseGradualLR 16 | config = MnistModelConfig( 17 | lr_scheduler="StepwiseGradualLR", 18 | lr_scheduler_params={ 19 | "gradual_learning_rates": [ 20 | [0, 1e-3], 21 | [2, 1e-4], 22 | ] 23 | }, 24 | scheduler_after_epoch=False, 25 | ) 26 | trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None) 27 | trainer.train_loader = trainer.get_train_dataloader( 28 | trainer.training_assets, 29 | trainer.train_samples, 30 | verbose=True, 31 | ) 32 | trainer.keep_avg_train = KeepAverage() 33 | 34 | lr_0 = trainer.scheduler.get_lr() 35 | 
trainer.train_step(next(iter(trainer.train_loader)), len(trainer.train_loader), 0, time.time()) 36 | lr_1 = trainer.scheduler.get_lr() 37 | trainer.train_step(next(iter(trainer.train_loader)), len(trainer.train_loader), 1, time.time()) 38 | lr_2 = trainer.scheduler.get_lr() 39 | assert lr_0 == 1e-3 40 | assert lr_1 == 1e-3 41 | assert lr_2 == 1e-4 42 | -------------------------------------------------------------------------------- /.github/workflows/style_check.yml: -------------------------------------------------------------------------------- 1 | name: style-check 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | types: [opened, synchronize, reopened] 9 | jobs: 10 | check_skip: 11 | runs-on: ubuntu-latest 12 | if: "! contains(github.event.head_commit.message, '[ci skip]')" 13 | steps: 14 | - run: echo "${{ github.event.head_commit.message }}" 15 | 16 | test: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | python-version: [3.9] 22 | experimental: [false] 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v4 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | architecture: x64 30 | cache: 'pip' 31 | cache-dependency-path: 'requirements*' 32 | - name: check OS 33 | run: cat /etc/os-release 34 | - name: Install dependencies 35 | run: | 36 | sudo apt-get update 37 | sudo apt-get install -y git make gcc 38 | make system-deps 39 | - name: Install/upgrade Python setup deps 40 | run: python3 -m pip install --upgrade pip setuptools wheel 41 | - name: Install Trainer 42 | run: | 43 | python3 -m pip install .[all] 44 | python3 setup.py egg_info 45 | - name: Lint check 46 | run: | 47 | make lint -------------------------------------------------------------------------------- /trainer/utils/cpu_memory.py: -------------------------------------------------------------------------------- 1 | def get_available_cpu_memory(): 2 | 
import psutil # pylint: disable=import-outside-toplevel 3 | 4 | this_process = psutil.Process() 5 | available_memory = psutil.virtual_memory().available 6 | 7 | try: 8 | import resource # pylint: disable=import-outside-toplevel 9 | 10 | _, hard_mem_limit = resource.getrlimit(resource.RLIMIT_AS) # pylint: disable=unused-variable 11 | if hard_mem_limit != resource.RLIM_INFINITY: 12 | used_memory = this_process.memory_info().vms 13 | available_memory = min(hard_mem_limit - used_memory, available_memory) 14 | except ImportError: 15 | pass 16 | 17 | return available_memory 18 | 19 | 20 | def set_cpu_memory_limit(num_gigabytes): 21 | try: 22 | import resource # pylint: disable=import-outside-toplevel 23 | 24 | num_bytes = int(num_gigabytes * 2**30) 25 | _, hard_limit = resource.getrlimit(resource.RLIMIT_AS) 26 | if hard_limit != resource.RLIM_INFINITY: 27 | hard_limit = min(num_bytes, hard_limit) 28 | else: 29 | hard_limit = num_bytes 30 | resource.setrlimit(resource.RLIMIT_AS, (hard_limit, hard_limit)) 31 | except ImportError: 32 | pass 33 | 34 | 35 | def is_out_of_cpu_memory(exception): 36 | return ( 37 | isinstance(exception, RuntimeError) 38 | and len(exception.args) == 1 39 | and "DefaultCPUAllocator: can't allocate memory" in exception.args[0] 40 | ) 41 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | types: [opened, synchronize, reopened] 9 | jobs: 10 | check_skip: 11 | runs-on: ubuntu-latest 12 | if: "! 
contains(github.event.head_commit.message, '[ci skip]')" 13 | steps: 14 | - run: echo "${{ github.event.head_commit.message }}" 15 | 16 | test: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | python-version: [3.8, 3.9, "3.10", "3.11"] 22 | experimental: [false] 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python ${{ matrix.python-version }} 26 | uses: actions/setup-python@v4 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | architecture: x64 30 | cache: 'pip' 31 | cache-dependency-path: 'requirements*' 32 | - name: check OS 33 | run: cat /etc/os-release 34 | - name: Telemetry off 35 | run: | 36 | echo "TRAINER_TELEMETRY=0" >> "$GITHUB_ENV" 37 | - name: Install dependencies 38 | run: | 39 | sudo apt-get update 40 | sudo apt-get install -y --no-install-recommends git make gcc 41 | make system-deps 42 | - name: Install/upgrade Python setup deps 43 | run: python3 -m pip install --upgrade pip setuptools wheel 44 | - name: Install Trainer 45 | run: | 46 | python3 -m pip install .[all] 47 | python3 setup.py egg_info 48 | - name: Unit tests 49 | run: make test_all 50 | -------------------------------------------------------------------------------- /trainer/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py 2 | import os 3 | from functools import wraps 4 | from typing import Any, Callable, Optional 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | 10 | def is_dist_avail_and_initialized(): 11 | if not dist.is_available(): 12 | return False 13 | if not dist.is_initialized(): 14 | return False 15 | return True 16 | 17 | 18 | def get_rank(): 19 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 20 | for key in rank_keys: 21 | rank = os.environ.get(key) 22 | if rank is not None: 23 | return int(rank) 24 | return 0 25 | 26 | 27 | def
is_main_process(): 28 | return get_rank() == 0 29 | 30 | 31 | def rank_zero_only(fn: Callable) -> Callable: 32 | @wraps(fn) 33 | def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: 34 | if is_main_process(): 35 | return fn(*args, **kwargs) 36 | return None 37 | 38 | return wrapped_fn 39 | 40 | 41 | @rank_zero_only 42 | def rank_zero_print(message: str, *args, **kwargs) -> None: # pylint: disable=unused-argument 43 | print(message) 44 | 45 | 46 | @rank_zero_only 47 | def rank_zero_logger_info(message: str, logger: "Logger", *args, **kwargs) -> None: # pylint: disable=unused-argument 48 | logger.info(message) 49 | 50 | 51 | def reduce_tensor(tensor, num_gpus): 52 | rt = tensor.clone() 53 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 54 | rt /= num_gpus 55 | return rt 56 | 57 | 58 | def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): 59 | assert torch.cuda.is_available(), "Distributed mode requires CUDA." 60 | 61 | # Set cuda device so everything is done on the right GPU. 
62 | torch.cuda.set_device(rank % torch.cuda.device_count()) 63 | 64 | # Initialize distributed communication 65 | dist.init_process_group( 66 | dist_backend, 67 | init_method=dist_url, 68 | world_size=num_gpus, 69 | rank=rank, 70 | group_name=group_name, 71 | ) 72 | -------------------------------------------------------------------------------- /tests/utils/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.utils.data import DataLoader 8 | from torchvision import transforms 9 | from torchvision.datasets import MNIST 10 | 11 | from trainer import TrainerConfig, TrainerModel 12 | 13 | 14 | @dataclass 15 | class MnistModelConfig(TrainerConfig): 16 | optimizer: str = "Adam" 17 | lr: float = 0.001 18 | epochs: int = 1 19 | print_step: int = 1 20 | save_step: int = 5 21 | plot_step: int = 5 22 | dashboard_logger: str = "tensorboard" 23 | 24 | 25 | class MnistModel(TrainerModel): 26 | def __init__(self): 27 | super().__init__() 28 | 29 | # mnist images are (1, 28, 28) (channels, height, width) 30 | self.layer_1 = nn.Linear(28 * 28, 128) 31 | self.layer_2 = nn.Linear(128, 256) 32 | self.layer_3 = nn.Linear(256, 10) 33 | 34 | def forward(self, x): 35 | batch_size, _, _, _ = x.size() 36 | 37 | # (b, 1, 28, 28) -> (b, 1*28*28) 38 | x = x.view(batch_size, -1) 39 | x = self.layer_1(x) 40 | x = F.relu(x) 41 | x = self.layer_2(x) 42 | x = F.relu(x) 43 | x = self.layer_3(x) 44 | 45 | x = F.log_softmax(x, dim=1) 46 | return x 47 | 48 | def train_step(self, batch, criterion): 49 | x, y = batch 50 | logits = self(x) 51 | loss = criterion(logits, y) 52 | return {"model_outputs": logits}, {"loss": loss} 53 | 54 | def eval_step(self, batch, criterion): 55 | x, y = batch 56 | logits = self(x) 57 | loss = criterion(logits, y) 58 | return {"model_outputs": logits}, {"loss": loss} 59 | 60 | 
@staticmethod 61 | def get_criterion(): 62 | return torch.nn.NLLLoss() 63 | 64 | def get_data_loader( 65 | self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 66 | ): # pylint: disable=unused-argument 67 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 68 | dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) 69 | dataset.data = dataset.data[:256] 70 | dataset.targets = dataset.targets[:256] 71 | dataloader = DataLoader(dataset, batch_size=config.batch_size) 72 | return dataloader 73 | -------------------------------------------------------------------------------- /trainer/logging/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from trainer.logging.console_logger import ConsoleLogger 5 | from trainer.logging.dummy_logger import DummyLogger 6 | 7 | # pylint: disable=import-outside-toplevel 8 | 9 | 10 | logger = logging.getLogger("trainer") 11 | 12 | 13 | def get_mlflow_tracking_url(): 14 | if "MLFLOW_TRACKING_URI" in os.environ: 15 | return os.environ["MLFLOW_TRACKING_URI"] 16 | return None 17 | 18 | 19 | def get_ai_repo_url(): 20 | if "AIM_TRACKING_URI" in os.environ: 21 | return os.environ["AIM_TRACKING_URI"] 22 | return None 23 | 24 | 25 | def logger_factory(config, output_path): 26 | run_name = config.run_name 27 | project_name = config.project_name 28 | log_uri = config.logger_uri if config.logger_uri else output_path 29 | 30 | if config.dashboard_logger == "tensorboard": 31 | from trainer.logging.tensorboard_logger import TensorboardLogger 32 | 33 | model_name = f"{project_name}@{run_name}" if project_name else run_name 34 | dashboard_logger = TensorboardLogger(log_uri, model_name=model_name) 35 | 36 | logger.info(" > Start Tensorboard: tensorboard --logdir=%s", log_uri) 37 | 38 | elif config.dashboard_logger == "wandb": 39 | from trainer.logging.wandb_logger import WandbLogger 
40 | 41 | dashboard_logger = WandbLogger( # pylint: disable=abstract-class-instantiated 42 | project=project_name, 43 | name=run_name, 44 | config=config, 45 | entity=config.wandb_entity, 46 | ) 47 | 48 | elif config.dashboard_logger == "mlflow": 49 | from trainer.logging.mlflow_logger import MLFlowLogger 50 | 51 | dashboard_logger = MLFlowLogger(log_uri=log_uri, model_name=project_name) 52 | 53 | elif config.dashboard_logger == "aim": 54 | from trainer.logging.aim_logger import AimLogger 55 | 56 | dashboard_logger = AimLogger(repo=log_uri, model_name=project_name) 57 | 58 | elif config.dashboard_logger == "clearml": 59 | from trainer.logging.clearml_logger import ClearMLLogger 60 | 61 | dashboard_logger = ClearMLLogger( 62 | output_uri=log_uri, local_path=output_path, project_name=project_name, task_name=run_name 63 | ) 64 | 65 | else: 66 | raise ValueError(f"Unknown dashboard logger: {config.dashboard_logger}") 67 | 68 | return dashboard_logger 69 | -------------------------------------------------------------------------------- /trainer/distribute.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import pathlib 6 | import subprocess 7 | import time 8 | 9 | from trainer import TrainerArgs, logger 10 | 11 | 12 | def distribute(): 13 | """ 14 | Call 👟Trainer training script in DDP mode. 15 | """ 16 | parser = TrainerArgs().init_argparse(arg_prefix="") 17 | parser.add_argument("--script", type=str, help="Target training script to distibute.") 18 | parser.add_argument( 19 | "--gpus", 20 | type=str, 21 | help='GPU IDs to be used for distributed training in the format ```"0,1"```. 
Used if ```CUDA_VISIBLE_DEVICES``` is not set.', 22 | ) 23 | args, unargs = parser.parse_known_args() 24 | 25 | gpus = get_gpus(args) 26 | 27 | group_id = time.strftime("%Y_%m_%d-%H%M%S") 28 | 29 | # set arguments for train.py 30 | folder_path = pathlib.Path(__file__).parent.absolute() 31 | if os.path.exists(os.path.join(folder_path, args.script)): 32 | command = [os.path.join(folder_path, args.script)] 33 | else: 34 | command = [args.script] 35 | 36 | # Pass all the TrainerArgs fields 37 | command.append(f"--continue_path={args.continue_path}") 38 | command.append(f"--restore_path={args.restore_path}") 39 | command.append(f"--group_id=group_{group_id}") 40 | command.append("--use_ddp=true") 41 | command += unargs 42 | command.append("") 43 | 44 | # run processes 45 | processes = [] 46 | for rank, local_gpu_id in enumerate(gpus): 47 | my_env = os.environ.copy() 48 | my_env["PYTHON_EGG_CACHE"] = f"/tmp/tmp{local_gpu_id}" 49 | my_env["RANK"] = f"{rank}" 50 | my_env["CUDA_VISIBLE_DEVICES"] = f"{','.join(gpus)}" 51 | command[-1] = f"--rank={rank}" 52 | # prevent stdout for processes with rank != 0 53 | stdout = None 54 | p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with 55 | processes.append(p) 56 | logger.info(command) 57 | 58 | for p in processes: 59 | p.wait() 60 | 61 | 62 | def get_gpus(args): 63 | # set active gpus from CUDA_VISIBLE_DEVICES or --gpus 64 | if "CUDA_VISIBLE_DEVICES" in os.environ and os.environ["CUDA_VISIBLE_DEVICES"] != "": 65 | gpus = os.environ["CUDA_VISIBLE_DEVICES"] 66 | else: 67 | gpus = args.gpus 68 | gpus = list(map(str.strip, gpus.split(","))) 69 | return gpus 70 | 71 | 72 | if __name__ == "__main__": 73 | distribute() 74 | -------------------------------------------------------------------------------- /tests/test_num_gpus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from argparse import Namespace 4 | 
from unittest import TestCase, mock 5 | 6 | from trainer import TrainerArgs 7 | from trainer.distribute import get_gpus 8 | 9 | 10 | class TestGpusStringParsingMethods(TestCase): 11 | @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) 12 | def test_parse_gpus_set_in_env_var_and_args(self): 13 | args = Namespace(gpus="0,1") 14 | gpus = get_gpus(args) 15 | expected_value = ["0"] 16 | self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value)) 17 | 18 | @mock.patch.dict(os.environ, {}) 19 | def test_parse_gpus_set_in_args(self): 20 | _old = None 21 | # this is to handle the case when CUDA_VISIBLE_DEVICES is set while running the tests 22 | if "CUDA_VISIBLE_DEVICES" in os.environ: 23 | _old = os.environ["CUDA_VISIBLE_DEVICES"] 24 | del os.environ["CUDA_VISIBLE_DEVICES"] 25 | args = Namespace(gpus="0,1") 26 | gpus = get_gpus(args) 27 | expected_value = ["0", "1"] 28 | self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value)) 29 | if _old is not None: 30 | os.environ["CUDA_VISIBLE_DEVICES"] = _old 31 | 32 | @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) 33 | def test_parse_gpus_set_in_env_var(self): 34 | args = Namespace() 35 | gpus = get_gpus(args) 36 | expected_value = ["0", "1"] 37 | self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value)) 38 | 39 | @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0, 1 "}) 40 | def test_parse_gpus_set_in_env_var_with_spaces(self): 41 | args = Namespace() 42 | gpus = get_gpus(args) 43 | expected_value = ["0", "1"] 44 | self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value)) 45 | 46 | @mock.patch.dict(os.environ, {}) 47 | def test_parse_gpus_set_in_args_with_spaces(self): 48 | args = Namespace(gpus="0, 1, 2, 3 ") 49 | gpus = get_gpus(args) 50 | expected_value = ["0", "1", "2", "3"] 51 | self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value)) 52 | 53 | 54 | def msg_for_test_failure(expected_value): 55 | return 
"GPU Values are expected to be " + str(expected_value) 56 | 57 | 58 | def create_args_parser(): 59 | parser = TrainerArgs().init_argparse(arg_prefix="") 60 | parser.add_argument("--gpus", type=str) 61 | return parser 62 | 63 | 64 | if __name__ == "__main__": 65 | unittest.main() 66 | -------------------------------------------------------------------------------- /trainer/logging/dummy_logger.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | from trainer.logging.base_dash_logger import BaseDashboardLogger 4 | 5 | 6 | class DummyLogger(BaseDashboardLogger): 7 | """DummyLogger that implements the API but does nothing""" 8 | 9 | def add_scalar(self, title: str, value: float, step: int) -> None: 10 | pass 11 | 12 | def add_figure( 13 | self, 14 | title: str, 15 | figure: Union["matplotlib.figure.Figure", "plotly.graph_objects.Figure"], 16 | step: int, 17 | ) -> None: 18 | pass 19 | 20 | def add_config(self, config): 21 | pass 22 | 23 | def add_audio(self, title: str, audio: "np.ndarray", step: int, sample_rate: int) -> None: 24 | pass 25 | 26 | def add_text(self, title: str, text: str, step: int) -> None: 27 | pass 28 | 29 | def add_artifact(self, file_or_dir: str, name: str, artifact_type: str, aliases=None): 30 | pass 31 | 32 | def add_scalars(self, scope_name: str, scalars: Dict, step: int): 33 | pass 34 | 35 | def add_figures(self, scope_name: str, figures: Dict, step: int): 36 | pass 37 | 38 | def add_audios(self, scope_name: str, audios: Dict, step: int, sample_rate: int): 39 | pass 40 | 41 | def flush(self): 42 | pass 43 | 44 | def finish(self): 45 | pass 46 | 47 | def train_step_stats(self, step, stats): 48 | self.add_scalars(scope_name="TrainIterStats", scalars=stats, step=step) 49 | 50 | def train_epoch_stats(self, step, stats): 51 | self.add_scalars(scope_name="TrainEpochStats", scalars=stats, step=step) 52 | 53 | def train_figures(self, step, figures): 54 | 
self.add_figures(scope_name="TrainFigures", figures=figures, step=step) 55 | 56 | def train_audios(self, step, audios, sample_rate): 57 | self.add_audios(scope_name="TrainAudios", audios=audios, step=step, sample_rate=sample_rate) 58 | 59 | def eval_stats(self, step, stats): 60 | self.add_scalars(scope_name="EvalStats", scalars=stats, step=step) 61 | 62 | def eval_figures(self, step, figures): 63 | self.add_figures(scope_name="EvalFigures", figures=figures, step=step) 64 | 65 | def eval_audios(self, step, audios, sample_rate): 66 | self.add_audios(scope_name="EvalAudios", audios=audios, step=step, sample_rate=sample_rate) 67 | 68 | def test_audios(self, step, audios, sample_rate): 69 | self.add_audios(scope_name="TestAudios", audios=audios, step=step, sample_rate=sample_rate) 70 | 71 | def test_figures(self, step, figures): 72 | self.add_figures(scope_name="TestFigures", figures=figures, step=step) 73 | -------------------------------------------------------------------------------- /trainer/logging/clearml_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | import torch 5 | 6 | from trainer.logging.tensorboard_logger import TensorboardLogger 7 | from trainer.trainer_utils import is_clearml_available 8 | from trainer.utils.distributed import rank_zero_only 9 | 10 | if is_clearml_available(): 11 | from clearml import Task # pylint: disable=import-error 12 | else: 13 | raise ImportError("ClearML is not installed. Please install it with `pip install clearml`") 14 | 15 | 16 | class ClearMLLogger(TensorboardLogger): 17 | """ClearML Logger using TensorBoard in the background. 18 | 19 | TODO: 20 | - Add hyperparameter handling 21 | - Use ClearML logger for plots 22 | - Handle continuing training 23 | 24 | Args: 25 | output_uri (str): URI of the ClearML repository. 26 | local_path (str): Path to the local directory where the model is saved. 
27 | project_name (str): Name of the ClearML project. 28 | task_name (str): Name of the ClearML task. 29 | tags (str): Comma separated list of tags to add to the ClearML task. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | output_uri: str, 35 | local_path: str, 36 | project_name: str, 37 | task_name: str, 38 | tags: str = None, 39 | ): 40 | self._context = None 41 | self.local_path = local_path 42 | self.task_name = task_name 43 | self.tags = tags.split(",") if tags else [] 44 | self.run = Task.init(project_name=project_name, task_name=task_name, tags=self.tags, output_uri=output_uri) 45 | 46 | if tags: 47 | for tag in tags.split(","): 48 | self.run.add_tag(tag) 49 | 50 | super().__init__("run", None) 51 | 52 | @rank_zero_only 53 | def add_config(self, config): 54 | """Upload config file(s) to ClearML.""" 55 | self.add_text("run_config", f"{config.to_json()}", 0) 56 | self.run.connect_configuration(name="model_config", configuration=config.to_dict()) 57 | self.run.set_comment(config.run_description) 58 | self.run.upload_artifact("model_config", config.to_dict()) 59 | self.run.upload_artifact("configs", artifact_object=os.path.join(self.local_path, "*.json")) 60 | 61 | @rank_zero_only 62 | def add_artifact(self, file_or_dir, name, **kwargs): # pylint: disable=unused-argument, arguments-differ 63 | """Upload artifact to ClearML.""" 64 | self.run.upload_artifact(name, artifact_object=file_or_dir) 65 | 66 | @staticmethod 67 | @rank_zero_only 68 | def save_model(state: Any, path: str): 69 | torch.save(state, path) 70 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: "🐛 Bug report" 2 | description: Create a bug report to help 👟 improve 3 | title: '[Bug] ' 4 | labels: [ "bug" ] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Welcome to the 👟! 
Thanks for taking the time to fill out this bug report! 10 | 11 | - type: textarea 12 | id: bug-description 13 | attributes: 14 | label: Describe the bug 15 | description: A clear and concise description of what the bug is. If you intend to submit a PR for this issue, tell us in the description. Thanks! 16 | placeholder: Bug description 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | id: reproduction 22 | attributes: 23 | label: To Reproduce 24 | description: | 25 | Please share your code to reproduce the error. 26 | 27 | Issues are fixed faster if you can provide a working example. 28 | 29 | The best place for sharing code is Colab (https://colab.research.google.com/), 30 | so we can directly run your code and reproduce the issue. 31 | 32 | In the worst case, provide steps to reproduce the behavior. 33 | 34 | 1. Run the following command '...' 35 | 2. ... 36 | 3. See error 37 | placeholder: Reproduction 38 | validations: 39 | required: true 40 | 41 | - type: textarea 42 | id: expected-behavior 43 | attributes: 44 | label: Expected behavior 45 | description: "Describe what you expected to happen." 46 | 47 | - type: textarea 48 | id: logs 49 | attributes: 50 | label: Logs 51 | description: "Please include the relevant logs if you can." 52 | render: shell 53 | 54 | - type: textarea 55 | id: system-info 56 | attributes: 57 | label: Environment 58 | description: | 59 | You can either run `bin/collect_env_info.py` 60 | 61 | ```bash 62 | wget https://raw.githubusercontent.com/coqui-ai/Trainer/main/bin/collect_env_info.py 63 | python collect_env_info.py 64 | ``` 65 | 66 | or fill in the fields below manually.
67 | render: shell 68 | placeholder: | 69 | - 👟 Version (e.g., 1.3.0): 70 | - PyTorch Version (e.g., 1.8) 71 | - Python version: 72 | - OS (e.g., Linux): 73 | - CUDA/cuDNN version: 74 | - GPU models and configuration: 75 | - How you installed PyTorch (`conda`, `pip`, source): 76 | - Any other relevant information: 77 | validations: 78 | required: true 79 | - type: textarea 80 | id: context 81 | attributes: 82 | label: Additional context 83 | description: Add any other context about the problem here. 84 | validations: 85 | required: false 86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .noseids 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ 140 | 141 | # custom list 142 | MNIST/ 143 | tests_local/ 144 | output/ 145 | -------------------------------------------------------------------------------- /.github/workflows/pypi-release.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI 2 | on: 3 | release: 4 | types: [published] 5 | defaults: 6 | run: 7 | shell: 8 | bash 9 | jobs: 10 | build-sdist: 11 | runs-on: ubuntu-20.04 12 | steps: 13 | - uses: actions/checkout@v3 14 | - name: Verify tag matches version 15 | run: | 16 | set -ex 17 | version=$(cat trainer/VERSION) 18 | tag="${GITHUB_REF/refs\/tags\/}" 19 | if [[ "$version" != "$tag" ]]; then 20 | exit 1 21 | fi 22 | - uses: actions/setup-python@v4 23 | with: 24 | python-version: 3.9 25 | - run: | 26 | python -m pip install -U pip setuptools wheel build 27 | - run: | 28 | python -m build 29 | - run: | 30 | pip install dist/*.tar.gz 31 | - uses: actions/upload-artifact@v2 32 | with: 33 | name: sdist 34 | path: dist/*.tar.gz 35 | build-wheels: 36 | runs-on: ubuntu-20.04 37 | strategy: 38 | matrix: 39 | python-version: ["3.8", "3.9", "3.10", "3.11"] 40 | steps: 41 | - uses: actions/checkout@v3 42 | - uses: actions/setup-python@v4 43 | with: 44 | python-version: ${{ matrix.python-version }}
45 | - run: | 46 | python -m pip install -U pip setuptools wheel build 47 | - run: | 48 | python -m build 49 | - run: | 50 | python -m pip install dist/*.whl 51 | - uses: actions/upload-artifact@v2 52 | with: 53 | name: wheel-${{ matrix.python-version }} 54 | path: dist/*.whl 55 | publish-artifacts: 56 | runs-on: ubuntu-20.04 57 | needs: [build-sdist, build-wheels] 58 | steps: 59 | - run: | 60 | mkdir dist 61 | - uses: actions/download-artifact@v2 62 | with: 63 | name: "sdist" 64 | path: "dist/" 65 | - uses: actions/download-artifact@v2 66 | with: 67 | name: "wheel-3.8" 68 | path: "dist/" 69 | - uses: actions/download-artifact@v2 70 | with: 71 | name: "wheel-3.9" 72 | path: "dist/" 73 | - uses: actions/download-artifact@v2 74 | with: 75 | name: "wheel-3.10" 76 | path: "dist/" 77 | - uses: actions/download-artifact@v2 78 | with: 79 | name: "wheel-3.11" 80 | path: "dist/" 81 | - run: | 82 | ls -lh dist/ 83 | - name: Setup PyPI config 84 | run: | 85 | cat << EOF > ~/.pypirc 86 | [pypi] 87 | username=__token__ 88 | password=${{ secrets.PYPI_TOKEN }} 89 | EOF 90 | - uses: actions/setup-python@v2 91 | with: 92 | python-version: 3.9 93 | - run: | 94 | python -m pip install twine 95 | - run: | 96 | twine upload --repository pypi dist/* 97 | -------------------------------------------------------------------------------- /trainer/utils/cuda_memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | credit: https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py 3 | 4 | Helper to free Torch cuda memory and determine when a Torch exception might be 5 | because of OOM conditions. 
6 | """ 7 | from __future__ import print_function 8 | 9 | import gc 10 | 11 | import torch 12 | 13 | from trainer.utils.cpu_memory import is_out_of_cpu_memory 14 | 15 | 16 | def gc_cuda(): 17 | """Gargage collect Torch (CUDA) memory.""" 18 | gc.collect() 19 | if torch.cuda.is_available(): 20 | torch.cuda.empty_cache() 21 | 22 | 23 | def get_cuda_total_memory(): 24 | if torch.cuda.is_available(): 25 | return torch.cuda.get_device_properties(0).total_memory 26 | return 0 27 | 28 | 29 | def get_cuda_assumed_available_memory(): 30 | if torch.cuda.is_available(): 31 | return get_cuda_total_memory() - torch.cuda.memory_reserved() 32 | return 0 33 | 34 | 35 | def get_cuda_available_memory(): 36 | # Always allow for 1 GB overhead. 37 | if torch.cuda.is_available(): 38 | return get_cuda_assumed_available_memory() - get_cuda_blocked_memory() 39 | return 0 40 | 41 | 42 | def get_cuda_blocked_memory(): 43 | if not torch.cuda.is_available(): 44 | return 0 45 | 46 | available_memory = get_cuda_assumed_available_memory() 47 | current_block = available_memory - 2**28 # 256 MB steps 48 | while True: 49 | try: 50 | _ = torch.empty((current_block,), dtype=torch.uint8, device="cuda") 51 | break 52 | except RuntimeError as exception: 53 | if is_cuda_out_of_memory(exception): 54 | current_block -= 2**30 55 | if current_block <= 0: 56 | return available_memory 57 | else: 58 | raise 59 | _ = None 60 | gc_cuda() 61 | return available_memory - current_block 62 | 63 | 64 | def is_cuda_out_of_memory(exception): 65 | return ( 66 | isinstance(exception, (RuntimeError, torch.cuda.OutOfMemoryError)) 67 | and len(exception.args) == 1 68 | and "CUDA out of memory." in exception.args[0] 69 | ) 70 | 71 | 72 | def is_cudnn_snafu(exception): 73 | # For/because of https://github.com/pytorch/pytorch/issues/4107 74 | return ( 75 | isinstance(exception, RuntimeError) 76 | and len(exception.args) == 1 77 | and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." 
in exception.args[0] 78 | ) 79 | 80 | 81 | def cuda_meminfo(): 82 | if not torch.cuda.is_available(): 83 | return 84 | 85 | print( 86 | "Total:", torch.cuda.memory_allocated() / 2**30, " GB Cached: ", torch.cuda.memory_reserved() / 2**30, "GB" 87 | ) 88 | print( 89 | "Max Total:", 90 | torch.cuda.max_memory_allocated() / 2**30, 91 | " GB Max Cached: ", 92 | torch.cuda.max_memory_reserved() / 2**30, 93 | "GB", 94 | ) 95 | 96 | 97 | def should_reduce_batch_size(exception): 98 | return is_cuda_out_of_memory(exception) or is_cudnn_snafu(exception) or is_out_of_cpu_memory(exception) 99 | -------------------------------------------------------------------------------- /trainer/logging/base_dash_logger.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Union 3 | 4 | from trainer.io import save_fsspec 5 | from trainer.utils.distributed import rank_zero_only 6 | 7 | 8 | # pylint: disable=too-many-public-methods 9 | class BaseDashboardLogger(ABC): 10 | @abstractmethod 11 | def add_scalar(self, title: str, value: float, step: int) -> None: 12 | pass 13 | 14 | @abstractmethod 15 | def add_figure( 16 | self, 17 | title: str, 18 | figure: Union["matplotlib.figure.Figure", "plotly.graph_objects.Figure"], 19 | step: int, 20 | ) -> None: 21 | pass 22 | 23 | @abstractmethod 24 | def add_config(self, config): 25 | pass 26 | 27 | @abstractmethod 28 | def add_audio(self, title: str, audio: "np.ndarray", step: int, sample_rate: int) -> None: 29 | pass 30 | 31 | @abstractmethod 32 | def add_text(self, title: str, text: str, step: int) -> None: 33 | pass 34 | 35 | @abstractmethod 36 | def add_artifact(self, file_or_dir: str, name: str, artifact_type: str, aliases=None): 37 | pass 38 | 39 | @abstractmethod 40 | def add_scalars(self, scope_name: str, scalars: Dict, step: int): 41 | pass 42 | 43 | @abstractmethod 44 | def add_figures(self, scope_name: str, figures: Dict, step: int): 45 | 
pass 46 | 47 | @abstractmethod 48 | def add_audios(self, scope_name: str, audios: Dict, step: int, sample_rate: int): 49 | pass 50 | 51 | @abstractmethod 52 | def flush(self): 53 | pass 54 | 55 | @abstractmethod 56 | def finish(self): 57 | pass 58 | 59 | @staticmethod 60 | @rank_zero_only 61 | def save_model(state: Dict, path: str): 62 | save_fsspec(state, path) 63 | 64 | def train_step_stats(self, step, stats): 65 | self.add_scalars(scope_name="TrainIterStats", scalars=stats, step=step) 66 | 67 | def train_epoch_stats(self, step, stats): 68 | self.add_scalars(scope_name="TrainEpochStats", scalars=stats, step=step) 69 | 70 | def train_figures(self, step, figures): 71 | self.add_figures(scope_name="TrainFigures", figures=figures, step=step) 72 | 73 | def train_audios(self, step, audios, sample_rate): 74 | self.add_audios(scope_name="TrainAudios", audios=audios, step=step, sample_rate=sample_rate) 75 | 76 | def eval_stats(self, step, stats): 77 | self.add_scalars(scope_name="EvalStats", scalars=stats, step=step) 78 | 79 | def eval_figures(self, step, figures): 80 | self.add_figures(scope_name="EvalFigures", figures=figures, step=step) 81 | 82 | def eval_audios(self, step, audios, sample_rate): 83 | self.add_audios(scope_name="EvalAudios", audios=audios, step=step, sample_rate=sample_rate) 84 | 85 | def test_audios(self, step, audios, sample_rate): 86 | self.add_audios(scope_name="TestAudios", audios=audios, step=step, sample_rate=sample_rate) 87 | 88 | def test_figures(self, step, figures): 89 | self.add_figures(scope_name="TestFigures", figures=figures, step=step) 90 | -------------------------------------------------------------------------------- /trainer/logging/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | from torch.utils.tensorboard import SummaryWriter 4 | 5 | from trainer.logging.base_dash_logger import BaseDashboardLogger 6 | 7 | 8 | class 
TensorboardLogger(BaseDashboardLogger): 9 | def __init__(self, log_dir, model_name): 10 | self.model_name = model_name 11 | self.writer = SummaryWriter(log_dir) 12 | 13 | def model_weights(self, model, step): 14 | layer_num = 1 15 | for name, param in model.named_parameters(): 16 | if param.numel() == 1: 17 | self.writer.add_scalar("layer{}-{}/value".format(layer_num, name), param.max(), step) 18 | else: 19 | self.writer.add_scalar("layer{}-{}/max".format(layer_num, name), param.max(), step) 20 | self.writer.add_scalar("layer{}-{}/min".format(layer_num, name), param.min(), step) 21 | self.writer.add_scalar("layer{}-{}/mean".format(layer_num, name), param.mean(), step) 22 | self.writer.add_scalar("layer{}-{}/std".format(layer_num, name), param.std(), step) 23 | self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step) 24 | self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step) 25 | layer_num += 1 26 | 27 | def add_config(self, config): 28 | self.add_text("model-config", f"
{config.to_json()}
", 0) 29 | 30 | def add_scalar(self, title: str, value: float, step: int) -> None: 31 | self.writer.add_scalar(title, value, step) 32 | 33 | def add_audio(self, title, audio, step, sample_rate): 34 | self.writer.add_audio(title, audio, step, sample_rate=sample_rate) 35 | 36 | def add_text(self, title, text, step): 37 | self.writer.add_text(title, text, step) 38 | 39 | def add_figure(self, title, figure, step): 40 | self.writer.add_figure(title, figure, step) 41 | 42 | def add_artifact(self, file_or_dir, name, artifact_type, aliases=None): # pylint: disable=W0613 43 | yield 44 | 45 | def add_scalars(self, scope_name, scalars, step): 46 | for key, value in scalars.items(): 47 | self.add_scalar("{}/{}".format(scope_name, key), value, step) 48 | 49 | def add_figures(self, scope_name, figures, step): 50 | for key, value in figures.items(): 51 | self.writer.add_figure("{}/{}".format(scope_name, key), value, step) 52 | 53 | def add_audios(self, scope_name, audios, step, sample_rate): 54 | for key, value in audios.items(): 55 | if value.dtype == "float16": 56 | value = value.astype("float32") 57 | try: 58 | self.add_audio( 59 | "{}/{}".format(scope_name, key), 60 | value, 61 | step, 62 | sample_rate=sample_rate, 63 | ) 64 | except RuntimeError: 65 | traceback.print_exc() 66 | 67 | def flush(self): 68 | self.writer.flush() 69 | 70 | def finish(self): 71 | self.writer.close() 72 | -------------------------------------------------------------------------------- /examples/train_mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example shows training of a simple Conv model with MNIST dataset using Auto Training mode of 👟. 
3 | """ 4 | 5 | import os 6 | from dataclasses import dataclass 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torch.utils.data import DataLoader 12 | from torchvision import transforms 13 | from torchvision.datasets import MNIST 14 | 15 | from trainer import TrainerConfig, TrainerModel, Trainer, TrainerArgs 16 | 17 | 18 | @dataclass 19 | class MnistModelConfig(TrainerConfig): 20 | optimizer: str = "Adam" 21 | lr: float = 0.001 22 | epochs: int = 1 23 | print_step: int = 1 24 | save_step: int = 5 25 | plot_step: int = 5 26 | dashboard_logger: str = "tensorboard" 27 | 28 | 29 | class MnistModel(TrainerModel): 30 | def __init__(self): 31 | super().__init__() 32 | 33 | # mnist images are (1, 28, 28) (channels, height, width) 34 | self.layer_1 = nn.Linear(28 * 28, 128) 35 | self.layer_2 = nn.Linear(128, 256) 36 | self.layer_3 = nn.Linear(256, 10) 37 | 38 | def forward(self, x): 39 | batch_size, _, _, _ = x.size() 40 | 41 | # (b, 1, 28, 28) -> (b, 1*28*28) 42 | x = x.view(batch_size, -1) 43 | x = self.layer_1(x) 44 | x = F.relu(x) 45 | x = self.layer_2(x) 46 | x = F.relu(x) 47 | x = self.layer_3(x) 48 | 49 | x = F.log_softmax(x, dim=1) 50 | return x 51 | 52 | def train_step(self, batch, criterion): 53 | x, y = batch 54 | logits = self(x) 55 | loss = criterion(logits, y) 56 | return {"model_outputs": logits}, {"loss": loss} 57 | 58 | def eval_step(self, batch, criterion): 59 | x, y = batch 60 | logits = self(x) 61 | loss = criterion(logits, y) 62 | return {"model_outputs": logits}, {"loss": loss} 63 | 64 | @staticmethod 65 | def get_criterion(): 66 | return torch.nn.NLLLoss() 67 | 68 | def get_data_loader( 69 | self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 70 | ): # pylint: disable=unused-argument 71 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 72 | dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) 73 | 
dataset.data = dataset.data[:256] 74 | dataset.targets = dataset.targets[:256] 75 | dataloader = DataLoader(dataset, batch_size=config.batch_size) 76 | return dataloader 77 | 78 | 79 | def main(): 80 | """Run `MNIST` model training from scratch or from previous checkpoint.""" 81 | # init args and config 82 | train_args = TrainerArgs() 83 | config = MnistModelConfig() 84 | 85 | # init the model from config 86 | model = MnistModel() 87 | 88 | # init the trainer and 🚀 89 | trainer = Trainer( 90 | train_args, 91 | config, 92 | config.output_path, 93 | model=model, 94 | train_samples=model.get_data_loader(config, None, False, None, None, None), 95 | eval_samples=model.get_data_loader(config, None, True, None, None, None), 96 | parse_command_line_args=True, 97 | ) 98 | trainer.fit() 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /trainer/logging/wandb_logger.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=W0613 2 | 3 | import traceback 4 | from collections import defaultdict 5 | from pathlib import Path 6 | from typing import Union 7 | 8 | from trainer.logging.base_dash_logger import BaseDashboardLogger 9 | from trainer.trainer_utils import is_wandb_available 10 | from trainer.utils.distributed import rank_zero_only 11 | 12 | if is_wandb_available(): 13 | import wandb # pylint: disable=import-error 14 | 15 | 16 | class WandbLogger(BaseDashboardLogger): 17 | def __init__(self, **kwargs): 18 | if not wandb: 19 | raise RuntimeError("install wandb using `pip install wandb` to use WandbLogger") 20 | 21 | self.run = None 22 | self.run = wandb.init(**kwargs) if not wandb.run else wandb.run 23 | self.model_name = self.run.config.model 24 | # dictionary of dictionaries - record stats per step 25 | self.log_dict = defaultdict(dict) 26 | 27 | def model_weights(self, model, step): 28 | layer_num = 1 29 | for name, param in 
model.named_parameters(): 30 | if param.numel() == 1: 31 | self.add_scalars("weights", {"layer{}-{}/value".format(layer_num, name): param.max()}, step) 32 | else: 33 | self.add_scalars("weights", {"layer{}-{}/max".format(layer_num, name): param.max()}, step) 34 | self.add_scalars("weights", {"layer{}-{}/min".format(layer_num, name): param.min()}, step) 35 | self.add_scalars("weights", {"layer{}-{}/mean".format(layer_num, name): param.mean()}, step) 36 | self.add_scalars("weights", {"layer{}-{}/std".format(layer_num, name): param.std()}, step) 37 | self.log_dict[step]["weights/layer{}-{}/param".format(layer_num, name)] = wandb.Histogram(param) 38 | self.log_dict[step]["weights/layer{}-{}/grad".format(layer_num, name)] = wandb.Histogram(param.grad) 39 | layer_num += 1 40 | 41 | def add_scalars(self, scope_name, scalars, step): 42 | for key, value in scalars.items(): 43 | self.log_dict[step]["{}/{}".format(scope_name, key)] = value 44 | 45 | def add_figures(self, scope_name, figures, step): 46 | for key, value in figures.items(): 47 | self.log_dict[step]["{}/{}".format(scope_name, key)] = wandb.Image(value) 48 | 49 | def add_audios(self, scope_name, audios, step, sample_rate): 50 | for key, value in audios.items(): 51 | if value.dtype == "float16": 52 | value = value.astype("float32") 53 | try: 54 | self.log_dict[step]["{}/{}".format(scope_name, key)] = wandb.Audio(value, sample_rate=sample_rate) 55 | except RuntimeError: 56 | traceback.print_exc() 57 | 58 | def add_text(self, title, text, step): 59 | pass 60 | 61 | def add_scalar(self, title: str, value: float, step: int) -> None: 62 | pass 63 | 64 | def add_figure( 65 | self, 66 | title: str, 67 | figure: Union["matplotlib.figure.Figure", "plotly.graph_objects.Figure"], 68 | step: int, 69 | ) -> None: 70 | pass 71 | 72 | def add_audio(self, title: str, audio: "np.ndarray", step: int, sample_rate: int) -> None: 73 | pass 74 | 75 | @rank_zero_only 76 | def add_config(self, config): 77 | pass 78 | 79 | def flush(self): 
80 | if self.run: 81 | for step in sorted(self.log_dict.keys()): 82 | wandb.log(self.log_dict[step], step) 83 | self.log_dict.clear() 84 | 85 | def finish(self): 86 | if self.run: 87 | self.run.finish() 88 | 89 | def add_artifact(self, file_or_dir, name, artifact_type, aliases=None): 90 | if not self.run: 91 | return 92 | name = "_".join([self.run.id, name]) 93 | artifact = wandb.Artifact(name, type=artifact_type) 94 | data_path = Path(file_or_dir) 95 | if data_path.is_dir(): 96 | artifact.add_dir(str(data_path)) 97 | elif data_path.is_file(): 98 | artifact.add_file(str(data_path)) 99 | 100 | self.run.log_artifact(artifact, aliases=aliases) 101 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution guidelines 2 | 3 | Welcome to the 👟! 4 | 5 | This repository is governed by [the Contributor Covenant Code of Conduct](https://github.com/coqui-ai/Trainer/blob/main/CODE_OF_CONDUCT.md). 6 | 7 | ## Where to start 8 | We welcome everyone who would like to contribute to 👟. 9 | 10 | You can contribute not only with code but also with bug reports, comments, questions, answers, or just a simple tweet to spread the word. 11 | 12 | If you would like to contribute code or squash a bug but don't know where to start, here are some pointers. 13 | 14 | - [GitHub Issues Tracker](https://github.com/coqui-ai/Trainer/issues) 15 | 16 | This is the place to find feature requests and bugs. 17 | 18 | Issues with the ```good first issue``` tag are a good place for beginners to start. 19 | 20 | - ✨**PR**✨ [pages](https://github.com/coqui-ai/Trainer/pulls) with the ```🚀new version``` tag. 21 | 22 | We list all the target improvements for the next version. You can pick one of them and start contributing. 23 | 24 | - Also feel free to suggest new features. We're always open for new things.
25 | 26 | ## Sending a ✨**PR**✨ 27 | 28 | If you have a new feature or a bug to squash, go ahead and send a ✨**PR**✨. 29 | Please use the following steps for a ✨**PR**✨. 30 | Let us know if you encounter a problem along the way. 31 | 32 | The following steps are tested on an Ubuntu system. 33 | 34 | 1. Fork [👟](https://github.com/coqui-ai/Trainer) by clicking the fork button at the top right corner of the project page. 35 | 36 | 2. Clone 👟 and add the main repo as a new remote named ```upstream```. 37 | 38 | ```bash 39 | $ git clone git@github.com:<your-github-username>/Trainer.git 40 | $ cd Trainer 41 | $ git remote add upstream https://github.com/coqui-ai/Trainer.git 42 | ``` 43 | 44 | 3. Install 👟 for development. 45 | 46 | ```bash 47 | $ make install 48 | ``` 49 | 50 | 4. Create a new branch with an informative name for your goal. 51 | 52 | ```bash 53 | $ git checkout -b an_informative_name_for_my_branch 54 | ``` 55 | 56 | 5. Implement your changes on your new branch. 57 | 58 | 6. Explain your code using [Google Style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings. 59 | 60 | 7. Add your tests to our test suite under the ```tests``` folder. It is important to show that your code works and that edge cases are considered, and to inform others about the intended use. 61 | 62 | 8. Run the tests to see how your updates work with the rest of the project. You can repeat this step multiple times as you implement your changes to make sure you are heading in the right direction. 63 | 64 | ```bash 65 | $ make test # stop at the first error 66 | $ make test_all # run all the tests, report all the errors 67 | ``` 68 | 69 | 9. Format your code. We use ```black``` for code and ```isort``` for ```import``` formatting. 70 | 71 | ```bash 72 | $ make style 73 | ``` 74 | 75 | 10. Run the linter and correct the issues raised. We use ```pylint``` for linting. It helps enforce a coding standard and offers simple refactoring suggestions. 76 | 77 | ```bash 78 | $ make lint 79 | ``` 80 | 81 | 11.
When things are good, add new files and commit your changes. 82 | 83 | ```bash 84 | $ git add my_file1.py my_file2.py ... 85 | $ git commit 86 | ``` 87 | 88 | It's a good practice to regularly sync your local copy of the project with the upstream code to keep up with the recent updates. 89 | 90 | ```bash 91 | $ git fetch upstream 92 | $ git rebase upstream/master 93 | # or for the development version 94 | $ git rebase upstream/dev 95 | ``` 96 | 97 | 12. Send a PR to the ```dev``` branch. 98 | 99 | Push your branch to your fork. 100 | 101 | ```bash 102 | $ git push -u origin an_informative_name_for_my_branch 103 | ``` 104 | 105 | Then go to your fork's GitHub page and click on 'Pull request' to send your ✨**PR**✨. 106 | 107 | Please set ✨**PR**✨'s target branch to ```dev``` as we use ```dev``` to work on the next version. 108 | 109 | 13. Let's discuss until it is perfect. 💪 110 | 111 | We might ask you for certain changes that would appear on the ✨**PR**✨'s page under [👟](https://github.com/coqui-ai/Trainer/pulls). 112 | 113 | 14. Once things look perfect, we merge it to the ```dev``` branch and make it ready for the next version. 114 | 115 | Feel free to ping us through our communication channels at any step where you need help. 116 | 117 | If you are new to GitHub or open-source contribution, these are good resources:
118 | 119 | - [Github Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests) 120 | - [First-Contribution](https://github.com/firstcontributions/first-contributions) 121 | -------------------------------------------------------------------------------- /trainer/logging/console_logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | from dataclasses import dataclass 4 | 5 | from trainer.utils.distributed import rank_zero_only 6 | 7 | logger = logging.getLogger("trainer") 8 | 9 | 10 | @dataclass(frozen=True) 11 | class tcolors: 12 | OKBLUE: str = "\033[94m" 13 | HEADER: str = "\033[95m" 14 | OKGREEN: str = "\033[92m" 15 | WARNING: str = "\033[93m" 16 | FAIL: str = "\033[91m" 17 | ENDC: str = "\033[0m" 18 | BOLD: str = "\033[1m" 19 | UNDERLINE: str = "\033[4m" 20 | 21 | 22 | class ConsoleLogger: 23 | def __init__(self): 24 | # TODO: color code for value changes 25 | # use these to compare values between iterations 26 | self.old_train_loss_dict = None 27 | self.old_epoch_loss_dict = None 28 | self.old_eval_loss_dict = None 29 | 30 | @staticmethod 31 | def log_with_flush(msg: str): 32 | if logger is not None: 33 | logger.info(msg) 34 | for handler in logger.handlers: 35 | handler.flush() 36 | else: 37 | print(msg, flush=True) 38 | 39 | @staticmethod 40 | def get_time(): 41 | now = datetime.datetime.now() 42 | return now.strftime("%Y-%m-%d %H:%M:%S") 43 | 44 | @rank_zero_only 45 | def print_epoch_start(self, epoch, max_epoch, output_path=None): 46 | self.log_with_flush( 47 | "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC), 48 | ) 49 | if output_path is not None: 50 | self.log_with_flush(f" --> {output_path}") 51 | 52 | @rank_zero_only 53 | def print_train_start(self): 54 | self.log_with_flush(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") 55 | 56 
| @rank_zero_only 57 | def print_train_step(self, batch_steps, step, global_step, loss_dict, avg_loss_dict): 58 | indent = " | > " 59 | self.log_with_flush("") 60 | log_text = "{} --> TIME: {} -- STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format( 61 | tcolors.BOLD, self.get_time(), step, batch_steps, global_step, tcolors.ENDC 62 | ) 63 | for key, value in loss_dict.items(): 64 | # print the avg value if given 65 | if f"avg_{key}" in avg_loss_dict.keys(): 66 | log_text += "{}{}: {} ({})\n".format(indent, key, str(value), str(avg_loss_dict[f"avg_{key}"])) 67 | else: 68 | log_text += "{}{}: {} \n".format(indent, key, str(value)) 69 | self.log_with_flush(log_text) 70 | 71 | # pylint: disable=unused-argument 72 | @rank_zero_only 73 | def print_train_epoch_end(self, global_step, epoch, epoch_time, print_dict): 74 | indent = " | > " 75 | log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n" 76 | for key, value in print_dict.items(): 77 | log_text += "{}{}: {}\n".format(indent, key, str(value)) 78 | self.log_with_flush(log_text) 79 | 80 | @rank_zero_only 81 | def print_eval_start(self): 82 | self.log_with_flush(f"\n{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n") 83 | 84 | @rank_zero_only 85 | def print_eval_step(self, step, loss_dict, avg_loss_dict): 86 | indent = " | > " 87 | log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n" 88 | for key, value in loss_dict.items(): 89 | # print the avg value if given 90 | if f"avg_{key}" in avg_loss_dict.keys(): 91 | log_text += "{}{}: {} ({})\n".format(indent, key, str(value), str(avg_loss_dict[f"avg_{key}"])) 92 | else: 93 | log_text += "{}{}: {} \n".format(indent, key, str(value)) 94 | self.log_with_flush(log_text) 95 | 96 | @rank_zero_only 97 | def print_epoch_end(self, epoch, avg_loss_dict): 98 | indent = " | > " 99 | log_text = "\n {}--> EVAL PERFORMANCE{}\n".format(tcolors.BOLD, tcolors.ENDC) 100 | for key, value in avg_loss_dict.items(): 101 | # 
print the avg value if given 102 | color = "" 103 | sign = "+" 104 | diff = 0 105 | if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict: 106 | diff = value - self.old_eval_loss_dict[key] 107 | if diff < 0: 108 | color = tcolors.OKGREEN 109 | sign = "" 110 | elif diff > 0: 111 | color = tcolors.FAIL 112 | sign = "+" 113 | log_text += "{}{}:{} {} {}({}{})\n".format(indent, key, color, str(value), tcolors.ENDC, sign, str(diff)) 114 | self.old_eval_loss_dict = avg_loss_dict 115 | self.log_with_flush(log_text) 116 | -------------------------------------------------------------------------------- /trainer/torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data.distributed import DistributedSampler 4 | 5 | 6 | class DistributedSamplerWrapper(DistributedSampler): 7 | """Wrapper over Sampler for distributed training. It allows you to use any sampler in distributed mode. 8 | It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each 9 | process can pass a torch.utils.data.DistributedSampler instance as a torch.utils.data.DataLoader sampler, 10 | and load a subset of the original dataset that is exclusive to it. 11 | 12 | .. note: 13 | Dataset is assumed to be of constant size. 14 | 15 | Args: 16 | sampler: Sampler used for subsampling. 17 | num_replicas (int, optional): Number of processes participating in distributed training. By default, 18 | world_size is retrieved from the current distributed group. 19 | rank (int, optional): Rank of the current process within num_replicas. By default, rank is retrieved 20 | from the current distributed group. 21 | shuffle (bool, optional): If True, sampler will shuffle the indices. Default: True. 22 | seed (int, optional): random seed used to shuffle the sampler if shuffle=True. This number should be 23 | identical across all processes in the distributed group. 
Default: 0. 24 | 25 | Reference: https://github.com/pytorch/pytorch/issues/23430 26 | 27 | """ 28 | 29 | def __init__( 30 | self, 31 | sampler, 32 | num_replicas: int = None, 33 | rank: int = None, 34 | shuffle: bool = True, 35 | seed: int = 0, 36 | ): 37 | super().__init__( 38 | sampler, 39 | num_replicas=num_replicas, 40 | rank=rank, 41 | shuffle=shuffle, 42 | seed=seed, 43 | ) 44 | 45 | def __iter__(self): 46 | indices = list(self.dataset)[: self.total_size] 47 | 48 | # Add extra samples to make it evenly divisible 49 | indices += indices[: (self.total_size - len(indices))] 50 | assert len(indices) == self.total_size, f"{len(indices)} != {self.total_size}" 51 | 52 | # Subsample 53 | offset = self.num_samples * self.rank 54 | indices = indices[offset : offset + self.num_samples] 55 | assert len(indices) == self.num_samples, f"{len(indices)} != {self.num_samples}" 56 | 57 | return iter(indices) 58 | 59 | def set_epoch(self, epoch): 60 | super().set_epoch(epoch) 61 | if hasattr(self.dataset, "set_epoch"): 62 | self.dataset.set_epoch(epoch) 63 | elif hasattr(self.dataset, "generator"): 64 | self.dataset.generator = torch.Generator().manual_seed(self.seed + epoch) 65 | 66 | def state_dict(self): 67 | return self.dataset.state_dict() 68 | 69 | def load_state_dict(self, state_dict): 70 | self.dataset.load_state_dict(state_dict) 71 | 72 | 73 | # pylint: disable=protected-access 74 | class NoamLR(torch.optim.lr_scheduler._LRScheduler): 75 | def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1): 76 | self.warmup_steps = float(warmup_steps) 77 | super().__init__(optimizer, last_epoch) 78 | 79 | def get_lr(self): 80 | step = max(self.last_epoch, 1) 81 | return [ 82 | base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5) 83 | for base_lr in self.base_lrs 84 | ] 85 | 86 | 87 | # pylint: disable=protected-access 88 | class StepwiseGradualLR(torch.optim.lr_scheduler._LRScheduler): 89 | """Hardcoded step-wise learning rate scheduling. 
90 | Necessary for CapacitronVAE""" 91 | 92 | def __init__(self, optimizer, gradual_learning_rates, last_epoch=-1): 93 | self.gradual_learning_rates = gradual_learning_rates 94 | super().__init__(optimizer, last_epoch) 95 | 96 | def get_lr(self): 97 | step = max(self.last_epoch, 1) 98 | step_thresholds = [] 99 | rates = [] 100 | for values in self.gradual_learning_rates: 101 | step_thresholds.append(values[0]) 102 | rates.append(values[1]) 103 | 104 | boolean_indeces = np.less_equal(step_thresholds, step) 105 | try: 106 | last_true = np.where(boolean_indeces == True)[0][-1] # pylint: disable=singleton-comparison 107 | except IndexError: 108 | # For the steps larger than the last step in the list 109 | pass 110 | lr = rates[np.max(last_true, 0)] 111 | 112 | # Return last lr if step is above the set threshold 113 | lr = rates[-1] if step > step_thresholds[-1] else lr 114 | # Return first lr if step is below the second threshold - first is initial lr 115 | lr = rates[0] if step < step_thresholds[1] else lr 116 | 117 | return np.tile(lr, len(self.base_lrs)) # hack? 118 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # ,*++++++*, ,*++++++*, 3 | # *++. .+++ *++. .++* 4 | # *+* ,++++* *+* *+* ,++++, *+* 5 | # ,+, .++++++++++* ,++,,,,*+, ,++++++++++. *+, 6 | # *+. .++++++++++++..++ *+.,++++++++++++. .+* 7 | # .+* ++++++++++++.*+, .+*.++++++++++++ *+, 8 | # .++ *++++++++* ++, .++.*++++++++* ++, 9 | # ,+++*. . .*++, ,++*. .*+++* 10 | # *+, .,*++**. .**++**. ,+* 11 | # .+* *+, 12 | # *+. Coqui .+* 13 | # *+* +++ Trainer +++ *+* 14 | # .+++*. . . *+++. 15 | # ,+* *+++*... ...*+++* *+, 16 | # .++. .""""+++++++****+++++++"""". ++. 17 | # ,++. **** .++, 18 | # .++* *++. 19 | # *+++, ,+++* 20 | # .,*++++::::::++++*,. 
21 | # 22 | 23 | 24 | import os 25 | import subprocess 26 | import sys 27 | from distutils.version import LooseVersion 28 | 29 | import setuptools.command.build_py 30 | import setuptools.command.develop 31 | from setuptools import find_packages, setup 32 | 33 | if LooseVersion(sys.version) < LooseVersion("3.6") or LooseVersion( 34 | sys.version 35 | ) > LooseVersion("3.12"): 36 | raise RuntimeError( 37 | "Trainer requires python >= 3.6 and <=3.12 " 38 | "but your Python version is {}".format(sys.version) 39 | ) 40 | 41 | 42 | cwd = os.path.dirname(os.path.abspath(__file__)) 43 | 44 | cwd = os.path.dirname(os.path.abspath(__file__)) 45 | with open(os.path.join(cwd, "trainer", "VERSION")) as fin: 46 | version = fin.read().strip() 47 | 48 | 49 | class build_py( 50 | setuptools.command.build_py.build_py 51 | ): # pylint: disable=too-many-ancestors 52 | def run(self): 53 | setuptools.command.build_py.build_py.run(self) 54 | 55 | 56 | class develop(setuptools.command.develop.develop): 57 | def run(self): 58 | setuptools.command.develop.develop.run(self) 59 | 60 | 61 | def pip_install(package_name): 62 | subprocess.call([sys.executable, "-m", "pip", "install", package_name]) 63 | 64 | requirements = open(os.path.join(cwd, "requirements.txt"), "r").readlines() 65 | with open(os.path.join(cwd, "requirements.dev.txt"), "r") as f: 66 | requirements_dev = f.readlines() 67 | with open(os.path.join(cwd, "requirements.test.txt"), "r") as f: 68 | requirements_test = f.readlines() 69 | requirements_all = requirements + requirements_dev + requirements_test 70 | 71 | with open("README.md", "r", encoding="utf-8") as readme_file: 72 | README = readme_file.read() 73 | 74 | setup( 75 | name="trainer", 76 | version=version, 77 | url="https://github.com/coqui-ai/Trainer", 78 | author="Eren Gölge", 79 | author_email="egolge@coqui.ai", 80 | description="General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui.", 81 | long_description=README, 82 | 
long_description_content_type="text/markdown", 83 | license="Apache2", 84 | # package 85 | include_package_data=True, 86 | packages=find_packages(include=["trainer"]), 87 | package_data={ 88 | "trainer": [ 89 | "VERSION", 90 | ] 91 | }, 92 | project_urls={ 93 | "Documentation": "https://github.com/coqui-ai/Trainer/", 94 | "Tracker": "https://github.com/coqui-ai/Trainer/issues", 95 | "Repository": "https://github.com/coqui-ai/Trainer", 96 | "Discussions": "https://github.com/coqui-ai/Trainer/discussions", 97 | }, 98 | cmdclass={ 99 | "build_py": build_py, 100 | "develop": develop, 101 | }, 102 | install_requires=requirements, 103 | extras_require={ 104 | "dev": requirements_dev, 105 | "test": requirements_test, 106 | "all": requirements_all 107 | }, 108 | python_requires=">=3.6.0, <3.12", 109 | classifiers=[ 110 | "Environment :: Console", 111 | "Natural Language :: English", 112 | # How mature is this project? Common values are 113 | # 3 - Alpha, 4 - Beta, 5 - Production/Stable 114 | "Development Status :: 3 - Alpha", 115 | # Indicate who your project is intended for 116 | "Intended Audience :: Developers", 117 | # Pick your license as you wish 118 | "License :: OSI Approved :: Apache Software License", 119 | "Operating System :: OS Independent", 120 | # Specify the Python versions you support here. In particular, ensure 121 | # that you indicate whether you support Python 2, Python 3 or both. 
122 | "Programming Language :: Python :: 3.8", 123 | "Programming Language :: Python :: 3.9", 124 | "Programming Language :: Python :: 3.10", 125 | "Programming Language :: Python :: 3.11", 126 | ], 127 | zip_safe=False, 128 | ) 129 | -------------------------------------------------------------------------------- /trainer/generic_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import datetime 3 | import os 4 | import subprocess 5 | 6 | import fsspec 7 | import torch 8 | 9 | from trainer.logger import logger 10 | 11 | 12 | def isimplemented(obj, method_name): 13 | """Check if a method is implemented in a class.""" 14 | if method_name in dir(obj) and callable(getattr(obj, method_name)): 15 | try: 16 | obj.__getattribute__(method_name)() # pylint: disable=bad-option-value, unnecessary-dunder-call 17 | except NotImplementedError: 18 | return False 19 | except: # pylint: disable=bare-except 20 | return True 21 | return True 22 | return False 23 | 24 | 25 | def to_cuda(x: torch.Tensor) -> torch.Tensor: 26 | if x is None: 27 | return None 28 | if torch.is_tensor(x): 29 | x = x.contiguous() 30 | if torch.cuda.is_available(): 31 | x = x.cuda(non_blocking=True) 32 | return x 33 | 34 | 35 | def get_cuda(): 36 | use_cuda = torch.cuda.is_available() 37 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 38 | return use_cuda, device 39 | 40 | 41 | def get_git_branch(): 42 | try: 43 | out = subprocess.check_output(["git", "branch"]).decode("utf8") 44 | current = next(line for line in out.split("\n") if line.startswith("*")) 45 | current.replace("* ", "") 46 | except subprocess.CalledProcessError: 47 | current = "inside_docker" 48 | except FileNotFoundError: 49 | current = "unknown" 50 | return current 51 | 52 | 53 | def get_commit_hash(): 54 | """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" 55 | try: 56 | commit = 
subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip() 57 | # Not copying .git folder into docker container 58 | except (subprocess.CalledProcessError, FileNotFoundError): 59 | commit = "0000000" 60 | return commit 61 | 62 | 63 | def get_experiment_folder_path(root_path, model_name): 64 | """Get an experiment folder path with the current date and time""" 65 | date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") 66 | commit_hash = get_commit_hash() 67 | output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) 68 | return output_folder 69 | 70 | 71 | def remove_experiment_folder(experiment_path): 72 | """Check folder if there is a checkpoint, otherwise remove the folder""" 73 | fs = fsspec.get_mapper(experiment_path).fs 74 | checkpoint_files = fs.glob(experiment_path + "/*.pth") 75 | if not checkpoint_files: 76 | if fs.exists(experiment_path): 77 | fs.rm(experiment_path, recursive=True) 78 | logger.info(" ! Run is removed from %s", experiment_path) 79 | else: 80 | logger.info(" ! Run is kept in %s", experiment_path) 81 | 82 | 83 | def count_parameters(model): 84 | r"""Count number of trainable parameters in a network""" 85 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 86 | 87 | 88 | def set_partial_state_dict(model_dict, checkpoint_state, c): 89 | # Partial initialization: if there is a mismatch with new and old layer, it is skipped. 90 | for k, v in checkpoint_state.items(): 91 | if k not in model_dict: 92 | logger.info(" | > Layer missing in the model definition: %s", k) 93 | for k in model_dict: 94 | if k not in checkpoint_state: 95 | logger.info(" | > Layer missing in the checkpoint: %s", k) 96 | for k, v in checkpoint_state.items(): 97 | if k in model_dict and v.numel() != model_dict[k].numel(): 98 | logger.info(" | > Layer dimention missmatch between model definition and checkpoint: %s", k) 99 | # 1. 
filter out unnecessary keys 100 | pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} 101 | # 2. filter out different size layers 102 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} 103 | # 3. skip reinit layers 104 | if c.has("reinit_layers") and c.reinit_layers is not None: 105 | for reinit_layer_name in c.reinit_layers: 106 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} 107 | # 4. overwrite entries in the existing state dict 108 | model_dict.update(pretrained_dict) 109 | logger.info(" | > %i / %i layers are restored.", len(pretrained_dict), len(model_dict)) 110 | return model_dict 111 | 112 | 113 | class KeepAverage: 114 | def __init__(self): 115 | self.avg_values = {} 116 | self.iters = {} 117 | 118 | def __getitem__(self, key): 119 | return self.avg_values[key] 120 | 121 | def items(self): 122 | return self.avg_values.items() 123 | 124 | def add_value(self, name, init_val=0, init_iter=0): 125 | self.avg_values[name] = init_val 126 | self.iters[name] = init_iter 127 | 128 | def update_value(self, name, value, weighted_avg=False): 129 | if name not in self.avg_values: 130 | # add value if not exist before 131 | self.add_value(name, init_val=value) 132 | else: 133 | # else update existing value 134 | if weighted_avg: 135 | self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value 136 | self.iters[name] += 1 137 | else: 138 | self.avg_values[name] = self.avg_values[name] * self.iters[name] + value 139 | self.iters[name] += 1 140 | self.avg_values[name] /= self.iters[name] 141 | 142 | def add_values(self, name_dict): 143 | for key, value in name_dict.items(): 144 | self.add_value(key, init_val=value) 145 | 146 | def update_values(self, value_dict): 147 | for key, value in value_dict.items(): 148 | self.update_value(key, value) 149 | -------------------------------------------------------------------------------- 
/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributor Covenant Code of Conduct 3 | 4 | ## Our Pledge 5 | 6 | We as members, contributors, and leaders pledge to make participation in our 7 | community a harassment-free experience for everyone, regardless of age, body 8 | size, visible or invisible disability, ethnicity, sex characteristics, gender 9 | identity and expression, level of experience, education, socio-economic status, 10 | nationality, personal appearance, race, caste, color, religion, or sexual identity 11 | and orientation. 12 | 13 | We pledge to act and interact in ways that contribute to an open, welcoming, 14 | diverse, inclusive, and healthy community. 15 | 16 | ## Our Standards 17 | 18 | Examples of behavior that contributes to a positive environment for our 19 | community include: 20 | 21 | * Demonstrating empathy and kindness toward other people 22 | * Being respectful of differing opinions, viewpoints, and experiences 23 | * Giving and gracefully accepting constructive feedback 24 | * Accepting responsibility and apologizing to those affected by our mistakes, 25 | and learning from the experience 26 | * Focusing on what is best not just for us as individuals, but for the 27 | overall community 28 | 29 | Examples of unacceptable behavior include: 30 | 31 | * The use of sexualized language or imagery, and sexual attention or 32 | advances of any kind 33 | * Trolling, insulting or derogatory comments, and personal or political attacks 34 | * Public or private harassment 35 | * Publishing others' private information, such as a physical or email 36 | address, without their explicit permission 37 | * Other conduct which could reasonably be considered inappropriate in a 38 | professional setting 39 | 40 | ## Enforcement Responsibilities 41 | 42 | Community leaders are responsible for clarifying and enforcing our standards of 43 | acceptable behavior and will take appropriate and 
fair corrective action in 44 | response to any behavior that they deem inappropriate, threatening, offensive, 45 | or harmful. 46 | 47 | Community leaders have the right and responsibility to remove, edit, or reject 48 | comments, commits, code, wiki edits, issues, and other contributions that are 49 | not aligned to this Code of Conduct, and will communicate reasons for moderation 50 | decisions when appropriate. 51 | 52 | ## Scope 53 | 54 | This Code of Conduct applies within all community spaces, and also applies when 55 | an individual is officially representing the community in public spaces. 56 | Examples of representing our community include using an official e-mail address, 57 | posting via an official social media account, or acting as an appointed 58 | representative at an online or offline event. 59 | 60 | ## Enforcement 61 | 62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 63 | reported to the community leaders responsible for enforcement at 64 | coc-report@coqui.ai. 65 | All complaints will be reviewed and investigated promptly and fairly. 66 | 67 | All community leaders are obligated to respect the privacy and security of the 68 | reporter of any incident. 69 | 70 | ## Enforcement Guidelines 71 | 72 | Community leaders will follow these Community Impact Guidelines in determining 73 | the consequences for any action they deem in violation of this Code of Conduct: 74 | 75 | ### 1. Correction 76 | 77 | **Community Impact**: Use of inappropriate language or other behavior deemed 78 | unprofessional or unwelcome in the community. 79 | 80 | **Consequence**: A private, written warning from community leaders, providing 81 | clarity around the nature of the violation and an explanation of why the 82 | behavior was inappropriate. A public apology may be requested. 83 | 84 | ### 2. Warning 85 | 86 | **Community Impact**: A violation through a single incident or series 87 | of actions. 
88 | 89 | **Consequence**: A warning with consequences for continued behavior. No 90 | interaction with the people involved, including unsolicited interaction with 91 | those enforcing the Code of Conduct, for a specified period of time. This 92 | includes avoiding interactions in community spaces as well as external channels 93 | like social media. Violating these terms may lead to a temporary or 94 | permanent ban. 95 | 96 | ### 3. Temporary Ban 97 | 98 | **Community Impact**: A serious violation of community standards, including 99 | sustained inappropriate behavior. 100 | 101 | **Consequence**: A temporary ban from any sort of interaction or public 102 | communication with the community for a specified period of time. No public or 103 | private interaction with the people involved, including unsolicited interaction 104 | with those enforcing the Code of Conduct, is allowed during this period. 105 | Violating these terms may lead to a permanent ban. 106 | 107 | ### 4. Permanent Ban 108 | 109 | **Community Impact**: Demonstrating a pattern of violation of community 110 | standards, including sustained inappropriate behavior, harassment of an 111 | individual, or aggression toward or disparagement of classes of individuals. 112 | 113 | **Consequence**: A permanent ban from any sort of public interaction within 114 | the community. 115 | 116 | ## Attribution 117 | 118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 119 | version 2.0, available at 120 | [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. 121 | 122 | Community Impact Guidelines were inspired by 123 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 124 | 125 | For answers to common questions about this code of conduct, see the FAQ at 126 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available 127 | at [https://www.contributor-covenant.org/translations][translations]. 
128 | 129 | [homepage]: https://www.contributor-covenant.org 130 | [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html 131 | [Mozilla CoC]: https://github.com/mozilla/diversity 132 | [FAQ]: https://www.contributor-covenant.org/faq 133 | [translations]: https://www.contributor-covenant.org/translations 134 | -------------------------------------------------------------------------------- /examples/train_simple_gan.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example shows training of a simple GAN model with MNIST dataset using Gradient Accumulation and Advanced 3 | Optimization where you call optimizer steps manually. 4 | """ 5 | 6 | import os 7 | from dataclasses import dataclass 8 | 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | from torch.utils.data import DataLoader 13 | from torchvision import transforms 14 | from torchvision.datasets import MNIST 15 | 16 | from trainer import Trainer, TrainerConfig, TrainerModel 17 | from trainer.trainer import TrainerArgs 18 | 19 | is_cuda = torch.cuda.is_available() 20 | 21 | 22 | # pylint: skip-file 23 | 24 | 25 | class Generator(nn.Module): 26 | def __init__(self, latent_dim, img_shape): 27 | super().__init__() 28 | self.img_shape = img_shape 29 | 30 | def block(in_feat, out_feat, normalize=True): 31 | layers = [nn.Linear(in_feat, out_feat)] 32 | if normalize: 33 | layers.append(nn.BatchNorm1d(out_feat, 0.8)) 34 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 35 | return layers 36 | 37 | self.model = nn.Sequential( 38 | *block(latent_dim, 128, normalize=False), 39 | *block(128, 256), 40 | *block(256, 512), 41 | *block(512, 1024), 42 | nn.Linear(1024, int(np.prod(img_shape))), 43 | nn.Tanh(), 44 | ) 45 | 46 | def forward(self, z): 47 | img = self.model(z) 48 | img = img.view(img.size(0), *self.img_shape) 49 | return img 50 | 51 | 52 | class Discriminator(nn.Module): 53 | def __init__(self, img_shape): 54 | 
super().__init__() 55 | 56 | self.model = nn.Sequential( 57 | nn.Linear(int(np.prod(img_shape)), 512), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | nn.Linear(512, 256), 60 | nn.LeakyReLU(0.2, inplace=True), 61 | nn.Linear(256, 1), 62 | nn.Sigmoid(), 63 | ) 64 | 65 | def forward(self, img): 66 | img_flat = img.view(img.size(0), -1) 67 | validity = self.model(img_flat) 68 | 69 | return validity 70 | 71 | 72 | @dataclass 73 | class GANModelConfig(TrainerConfig): 74 | epochs: int = 1 75 | print_step: int = 2 76 | training_seed: int = 666 77 | 78 | 79 | class GANModel(TrainerModel): 80 | def __init__(self): 81 | super().__init__() 82 | data_shape = (1, 28, 28) 83 | self.generator = Generator(latent_dim=100, img_shape=data_shape) 84 | self.discriminator = Discriminator(img_shape=data_shape) 85 | 86 | def forward(self, x): 87 | ... 88 | 89 | def optimize(self, batch, trainer): 90 | imgs, _ = batch 91 | 92 | # sample noise 93 | z = torch.randn(imgs.shape[0], 100) 94 | z = z.type_as(imgs) 95 | 96 | # train discriminator 97 | imgs_gen = self.generator(z) 98 | logits = self.discriminator(imgs_gen.detach()) 99 | fake = torch.zeros(imgs.size(0), 1) 100 | fake = fake.type_as(imgs) 101 | loss_fake = trainer.criterion(logits, fake) 102 | 103 | valid = torch.ones(imgs.size(0), 1) 104 | valid = valid.type_as(imgs) 105 | logits = self.discriminator(imgs) 106 | loss_real = trainer.criterion(logits, valid) 107 | loss_disc = (loss_real + loss_fake) / 2 108 | 109 | # step dicriminator 110 | _, _ = self.scaled_backward(loss_disc, None, trainer, trainer.optimizer[0]) 111 | 112 | if trainer.total_steps_done % trainer.grad_accum_steps == 0: 113 | trainer.optimizer[0].step() 114 | trainer.optimizer[0].zero_grad() 115 | 116 | # train generator 117 | imgs_gen = self.generator(z) 118 | 119 | valid = torch.ones(imgs.size(0), 1) 120 | valid = valid.type_as(imgs) 121 | 122 | logits = self.discriminator(imgs_gen) 123 | loss_gen = trainer.criterion(logits, valid) 124 | 125 | # step generator 126 | _, _ 
= self.scaled_backward(loss_gen, None, trainer, trainer.optimizer[1]) 127 | if trainer.total_steps_done % trainer.grad_accum_steps == 0: 128 | trainer.optimizer[1].step() 129 | trainer.optimizer[1].zero_grad() 130 | return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc} 131 | 132 | @torch.no_grad() 133 | def eval_step(self, batch, criterion): 134 | imgs, _ = batch 135 | 136 | # sample noise 137 | z = torch.randn(imgs.shape[0], 100) 138 | z = z.type_as(imgs) 139 | 140 | imgs_gen = self.generator(z) 141 | valid = torch.ones(imgs.size(0), 1) 142 | valid = valid.type_as(imgs) 143 | 144 | logits = self.discriminator(imgs_gen) 145 | loss_gen = trainer.criterion(logits, valid) 146 | return {"model_outputs": logits}, {"loss_gen": loss_gen} 147 | 148 | def get_optimizer(self): 149 | discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) 150 | generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999)) 151 | return [discriminator_optimizer, generator_optimizer] 152 | 153 | def get_criterion(self): 154 | return nn.BCELoss() 155 | 156 | def get_data_loader( 157 | self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 158 | ): # pylint: disable=unused-argument 159 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 160 | dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) 161 | dataset.data = dataset.data[:64] 162 | dataset.targets = dataset.targets[:64] 163 | dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True) 164 | return dataloader 165 | 166 | 167 | if __name__ == "__main__": 168 | 169 | config = GANModelConfig() 170 | config.batch_size = 64 171 | config.grad_clip = None 172 | 173 | model = GANModel() 174 | trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None) 175 | 
trainer.config.epochs = 10 176 | trainer.fit() 177 | -------------------------------------------------------------------------------- /trainer/logging/mlflow_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | import traceback 5 | 6 | import soundfile as sf 7 | import torch 8 | 9 | from trainer.logging.base_dash_logger import BaseDashboardLogger 10 | from trainer.trainer_utils import is_mlflow_available 11 | from trainer.utils.distributed import rank_zero_only 12 | 13 | if is_mlflow_available(): 14 | from mlflow.tracking import MlflowClient 15 | from mlflow.tracking.context.registry import resolve_tags 16 | from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME 17 | 18 | # pylint: skip-file 19 | 20 | 21 | class MLFlowLogger(BaseDashboardLogger): 22 | def __init__( 23 | self, 24 | log_uri: str, 25 | model_name: str, 26 | tags: str = None, 27 | ): 28 | self.model_name = model_name 29 | self.client = MlflowClient(tracking_uri=os.path.join(log_uri)) 30 | 31 | experiment = self.client.get_experiment_by_name(model_name) 32 | if experiment is None: 33 | self.experiment_id = self.client.create_experiment(name=model_name) 34 | else: 35 | self.experiment_id = experiment.experiment_id 36 | 37 | if tags is not None: 38 | self.client.set_experiment_tag(self.experiment_id, MLFLOW_RUN_NAME, tags) 39 | run = self.client.create_run(experiment_id=self.experiment_id, tags=resolve_tags(tags)) 40 | self.run_id = run.info.run_id 41 | 42 | def model_weights(self, model, step): 43 | layer_num = 1 44 | for name, param in model.named_parameters(): 45 | if param.numel() == 1: 46 | self.client.log_metric("layer{}-{}/value".format(layer_num, name), param.max(), step) 47 | else: 48 | self.client.log_metric("layer{}-{}/max".format(layer_num, name), param.max(), step) 49 | self.client.log_metric("layer{}-{}/min".format(layer_num, name), param.min(), step) 50 | 
self.client.log_metric("layer{}-{}/mean".format(layer_num, name), param.mean(), step) 51 | self.client.log_metric("layer{}-{}/std".format(layer_num, name), param.std(), step) 52 | # MlFlow does not support histograms 53 | # self.client.add_histogram("layer{}-{}/param".format(layer_num, name), param, step) 54 | # self.client.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step) 55 | layer_num += 1 56 | 57 | def add_config(self, config): 58 | self.add_text("model-config", f"
{config.to_json()}
", 0) 59 | 60 | def add_scalar(self, title, value, step): 61 | self.client.log_metric(self.run_id, title, value, step) 62 | 63 | def add_text(self, title, text, step): 64 | self.client.log_text(self.run_id, text, "{}/{}.txt".format(title, step)) 65 | 66 | def add_figure(self, title, figure, step): 67 | self.client.log_figure(figure, "{}/{}.png".format(title, step)) 68 | 69 | def add_artifact(self, file_or_dir, name, artifact_type, aliases=None): # pylint: disable=W0613, R0201 70 | self.client.log_artifacts(self.run_id, file_or_dir) 71 | 72 | def add_audio(self, title, audio, step, sample_rate): 73 | self.client.log_audio(self.run_id, audio, "{}/{}.wav".format(title, step), sample_rate) 74 | 75 | @rank_zero_only 76 | def add_scalars(self, scope_name, stats, step): 77 | for key, value in stats.items(): 78 | if torch.is_tensor(value): 79 | value = value.item() 80 | self.client.log_metric(self.run_id, "{}-{}".format(scope_name, key), value, step) 81 | 82 | @rank_zero_only 83 | def add_figures(self, scope_name, figures, step): 84 | for key, value in figures.items(): 85 | self.client.log_figure(self.run_id, value, "{}/{}/{}.png".format(scope_name, key, step)) 86 | 87 | @rank_zero_only 88 | def add_audios(self, scope_name, audios, step, sample_rate): 89 | for key, value in audios.items(): 90 | if value.dtype == "float16": 91 | value = value.astype("float32") 92 | try: 93 | tmp_audio_path = tempfile.NamedTemporaryFile(suffix=".wav") 94 | sf.write(tmp_audio_path, value, sample_rate) 95 | self.client.log_artifact( 96 | self.run_id, 97 | tmp_audio_path, 98 | "{}/{}/{}.wav".format(scope_name, key, step), 99 | ) 100 | shutil.rmtree(tmp_audio_path) 101 | except RuntimeError: 102 | traceback.print_exc() 103 | 104 | def train_step_stats(self, step, stats): 105 | self.client.set_tag(self.run_id, "Mode", "training") 106 | super().train_step_stats(step, stats) 107 | 108 | def train_epoch_stats(self, step, stats): 109 | self.client.set_tag(self.run_id, "Mode", "training") 110 | 
super().train_epoch_stats(step, stats) 111 | 112 | def train_figures(self, step, figures): 113 | self.client.set_tag(self.run_id, "Mode", "training") 114 | super().train_figures(step, figures) 115 | 116 | def train_audios(self, step, audios, sample_rate): 117 | self.client.set_tag(self.run_id, "Mode", "training") 118 | super().train_audios(step, audios, sample_rate) 119 | 120 | def eval_stats(self, step, stats): 121 | self.client.set_tag(self.run_id, "Mode", "evaluation") 122 | super().eval_stats(step, stats) 123 | 124 | def eval_figures(self, step, figures): 125 | self.client.set_tag(self.run_id, "Mode", "evaluation") 126 | super().eval_figures(step, figures) 127 | 128 | def eval_audios(self, step, audios, sample_rate): 129 | self.client.set_tag(self.run_id, "Mode", "evaluation") 130 | super().eval_audios(step, audios, sample_rate) 131 | 132 | def test_audios(self, step, audios, sample_rate): 133 | self.client.set_tag(self.run_id, "Mode", "test") 134 | super().test_audios(step, audios, sample_rate) 135 | 136 | def test_figures(self, step, figures): 137 | self.client.set_tag(self.run_id, "Mode", "test") 138 | super().test_figures(step, figures) 139 | 140 | def flush(self): 141 | pass 142 | 143 | @rank_zero_only 144 | def finish(self): 145 | super().finalize(status) 146 | status = "FINISHED" if status == "success" else status 147 | if self.client.get_run(self.run_id): 148 | self.client.set_terminated(self.run_id, status) 149 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | # 👟 Trainer 4 | An opinionated general-purpose model trainer on PyTorch with a simple code base. 5 | 6 | ## Installation 7 | 8 | From Github: 9 | 10 | ```console 11 | git clone https://github.com/coqui-ai/Trainer 12 | cd Trainer 13 | make install 14 | ``` 15 | 16 | From PyPI: 17 | 18 | ```console 19 | pip install trainer 20 | ``` 21 | 22 | Prefer installing from Github as it is more stable. 23 | 24 | ## Implementing a model 25 | Subclass and override the functions in [```TrainerModel()```](trainer/model.py). 26 | 27 | 28 | ## Training a model with auto-optimization 29 | See the [MNIST example](examples/train_mnist.py). 30 | 31 | 32 | ## Training a model with advanced optimization 33 | With 👟 you can define the whole optimization cycle as you want, as in the GAN example below. It enables more 34 | under-the-hood control and flexibility for more advanced training loops. 35 | 36 | You just have to use the ```scaled_backward()``` function to handle mixed precision training. 37 | 38 | ```python 39 | ...
40 | 41 | def optimize(self, batch, trainer): 42 | imgs, _ = batch 43 | 44 | # sample noise 45 | z = torch.randn(imgs.shape[0], 100) 46 | z = z.type_as(imgs) 47 | 48 | # train discriminator 49 | imgs_gen = self.generator(z) 50 | logits = self.discriminator(imgs_gen.detach()) 51 | fake = torch.zeros(imgs.size(0), 1) 52 | fake = fake.type_as(imgs) 53 | loss_fake = trainer.criterion(logits, fake) 54 | 55 | valid = torch.ones(imgs.size(0), 1) 56 | valid = valid.type_as(imgs) 57 | logits = self.discriminator(imgs) 58 | loss_real = trainer.criterion(logits, valid) 59 | loss_disc = (loss_real + loss_fake) / 2 60 | 61 | # step discriminator 62 | self.scaled_backward(loss_disc, trainer, trainer.optimizer[0]) 63 | 64 | if trainer.total_steps_done % trainer.grad_accum_steps == 0: 65 | trainer.optimizer[0].step() 66 | trainer.optimizer[0].zero_grad() 67 | 68 | # train generator (freeze D so loss_gen leaves no gradients on its weights) 69 | self.discriminator.requires_grad_(False) 70 | imgs_gen = self.generator(z) 71 | valid = torch.ones(imgs.size(0), 1) 72 | valid = valid.type_as(imgs) 73 | logits = self.discriminator(imgs_gen) 74 | loss_gen = trainer.criterion(logits, valid) 75 | 76 | # step generator 77 | self.scaled_backward(loss_gen, trainer, trainer.optimizer[1]) 78 | self.discriminator.requires_grad_(True) 79 | if trainer.total_steps_done % trainer.grad_accum_steps == 0: 80 | trainer.optimizer[1].step() 81 | trainer.optimizer[1].zero_grad() 82 | return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc} 83 | 84 | ... 85 | ``` 86 | 87 | See the [GAN training example](examples/train_simple_gan.py) for a version with gradient accumulation. 88 | 89 | 90 | ## Training with Batch Size Finder 91 | See the test script [here](tests/test_train_batch_size_finder.py) for training with the batch size finder. 92 | 93 | 94 | The batch size finder starts at a default batch size (2048 unless user-defined) and searches for the largest batch size that fits on your hardware. You should expect it to run multiple trainings until it finds it.
To use it, instead of calling ```trainer.fit()``` you'll call ```trainer.fit_with_largest_batch_size(starting_batch_size=2048)```, with ```starting_batch_size``` being the batch size you want to start the search with. This is very useful if you want to use as much GPU memory as possible. 95 | 96 | ## Training with DDP 97 | 98 | ```console 99 | $ python -m trainer.distribute --script path/to/your/train.py --gpus "0,1" 100 | ``` 101 | 102 | We don't use ```.spawn()``` to initiate multi-gpu training since it causes certain limitations. 103 | 104 | - Everything must be picklable. 105 | - ```.spawn()``` trains the model in subprocesses and the model in the main process is not updated. 106 | - DataLoader with N processes gets really slow when N is large. 107 | 108 | ## Training with [Accelerate](https://huggingface.co/docs/accelerate/index) 109 | 110 | Setting `use_accelerate` in `TrainerArgs` to `True` will enable training with Accelerate. 111 | 112 | You can also use it for multi-gpu or distributed training. 113 | 114 | ```console 115 | CUDA_VISIBLE_DEVICES="0,1,2" accelerate launch --multi_gpu --num_processes 3 train_recipe_autoregressive_prompt.py 116 | ``` 117 | 118 | See the [Accelerate docs](https://huggingface.co/docs/accelerate/basic_tutorials/launch). 119 | 120 | ## Adding a callback 121 | 👟 supports callbacks to customize your runs. You can either set callbacks in your model implementations or give them 122 | explicitly to the Trainer. 123 | 124 | Please check `trainer.callbacks` to see the available callback hooks. 125 | 126 | Here is how you provide an explicit callback to a 👟 Trainer object. 127 | 128 | ```python 129 | def my_callback(trainer): 130 | print(" > My callback was called.") 131 | 132 | trainer = Trainer(..., callbacks={"on_init_end": my_callback}) 133 | trainer.fit() 134 | ``` 135 | 136 | ## Profiling example 137 | 138 | - Create the torch profiler as you like and pass it to the trainer.
139 | ```python 140 | import torch 141 | profiler = torch.profiler.profile( 142 | activities=[ 143 | torch.profiler.ProfilerActivity.CPU, 144 | torch.profiler.ProfilerActivity.CUDA, 145 | ], 146 | schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), 147 | on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler/"), 148 | record_shapes=True, 149 | profile_memory=True, 150 | with_stack=True, 151 | ) 152 | prof = trainer.profile_fit(profiler, epochs=1, small_run=64) 153 | # then run TensorBoard (see below) 154 | ``` 155 | - Run TensorBoard. 156 | ```console 157 | tensorboard --logdir="./profiler/" 158 | ``` 159 | 160 | ## Supported Experiment Loggers 161 | - [Tensorboard](https://www.tensorflow.org/tensorboard) - actively maintained 162 | - [ClearML](https://clear.ml/) - actively maintained 163 | - [MLFlow](https://mlflow.org/) 164 | - [Aim](https://aimstack.io/) 165 | - [WandB](https://wandb.ai/) 166 | 167 | To add a new logger, you must subclass [BaseDashboardLogger](trainer/logging/base_dash_logger.py) and override its functions. 168 | 169 | ## Anonymized Telemetry 170 | We constantly seek to improve 🐸 for the community. To understand the community's needs better and address them accordingly, we collect stripped-down anonymized usage stats when you run the trainer. 171 | 172 | Of course, if you don't want this, you can opt out by setting the environment variable `TRAINER_TELEMETRY=0`.
173 | -------------------------------------------------------------------------------- /trainer/logging/aim_logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from trainer.logging.base_dash_logger import BaseDashboardLogger 4 | from trainer.trainer_utils import is_aim_available 5 | from trainer.utils.distributed import rank_zero_only 6 | 7 | if is_aim_available(): 8 | from aim import Audio, Image, Repo, Text # pylint: disable=import-error 9 | from aim.sdk.run import Run # pylint: disable=import-error 10 | 11 | 12 | # pylint: disable=too-many-public-methods 13 | class AimLogger(BaseDashboardLogger): 14 | def __init__( 15 | self, 16 | repo: str, 17 | model_name: str, 18 | tags: str = None, 19 | ): 20 | self._context = None 21 | self.model_name = model_name 22 | self.run = Run(repo=repo, experiment=model_name) 23 | self.repo = Repo(repo) 24 | 25 | # query = f"runs.name == '{model_name}'" 26 | # runs = self.repo.query_runs(query=query) 27 | 28 | if tags: 29 | for tag in tags.split(","): 30 | self.run.add_tag(tag) 31 | 32 | # @staticmethod 33 | # def __fig_to_pil(image): 34 | # """Convert Matplotlib figure to PIL image.""" 35 | # return PIL.Image.frombytes("RGB", image.canvas.get_width_height(), image.canvas.tostring_rgb()) 36 | 37 | @property 38 | def context(self): 39 | return self._context 40 | 41 | @context.setter 42 | def context(self, context): 43 | self._context = context 44 | 45 | def model_weights(self, model, step): 46 | layer_num = 1 47 | for name, param in model.named_parameters(): 48 | if param.numel() == 1: 49 | self.run.log_metric("layer{}-{}/value".format(layer_num, name), param.max(), step) 50 | else: 51 | self.run.log_metric("layer{}-{}/max".format(layer_num, name), param.max(), step) 52 | self.run.log_metric("layer{}-{}/min".format(layer_num, name), param.min(), step) 53 | self.run.log_metric("layer{}-{}/mean".format(layer_num, name), param.mean(), step) 54 | 
self.run.log_metric("layer{}-{}/std".format(layer_num, name), param.std(), step) 55 | # MlFlow does not support histograms 56 | # self.client.addå_histogram("layer{}-{}/param".format(layer_num, name), param, step) 57 | # self.client.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step) 58 | layer_num += 1 59 | 60 | def add_config(self, config): 61 | """TODO: Add config to AIM""" 62 | # self.run['hparams'] = config.to_dict() 63 | self.add_text("model-config", f"
{config.to_json()}
", 0) 64 | 65 | def add_scalar(self, title, value, step): 66 | self.run.track(value, name=title, step=step, context=self.context) 67 | 68 | def add_text(self, title, text, step): 69 | self.run.track( 70 | Text(text), # Pass a string you want to track 71 | name=title, # The name of distributions 72 | step=step, # Step index (optional) 73 | context=self.context, 74 | ) 75 | 76 | def add_figure(self, title, figure, step): 77 | self.run.track( 78 | Image(figure, title), # Pass image data and/or caption 79 | name=title, # The name of image set 80 | step=step, # Step index (optional) 81 | context=self.context, 82 | ) 83 | 84 | def add_artifact(self, file_or_dir, name, artifact_type, aliases=None): # pylint: disable=W0613 85 | # AIM does not support artifacts 86 | ... 87 | 88 | def add_audio(self, title, audio, step, sample_rate): 89 | self.run.track( 90 | Audio(audio), # Pass audio file or numpy array 91 | name=title, # The name of distributions 92 | step=step, # Step index (optional) 93 | context=self.context, 94 | ) 95 | 96 | @rank_zero_only 97 | def add_scalars(self, scope_name, scalars, step): 98 | for key, value in scalars.items(): 99 | if torch.is_tensor(value): 100 | value = value.item() 101 | self.run.track(value, name="{}-{}".format(scope_name, key), step=step, context=self.context) 102 | 103 | @rank_zero_only 104 | def add_figures(self, scope_name, figures, step): 105 | for key, value in figures.items(): 106 | title = "{}/{}/{}.png".format(scope_name, key, step) 107 | self.run.track( 108 | Image(value, title), # Pass image data and/or caption 109 | name=title, # The name of image set 110 | step=step, # Step index (optional) 111 | context=self.context, 112 | ) 113 | 114 | @rank_zero_only 115 | def add_audios(self, scope_name, audios, step, sample_rate): 116 | for key, value in audios.items(): 117 | title = "{}/{}/{}.wav".format(scope_name, key, step) 118 | self.run.track( 119 | Audio(value), # Pass audio file or numpy array 120 | name=title, # The name of 
distributions 121 | step=step, # Step index (optional) 122 | context=self.context, 123 | ) 124 | 125 | def train_step_stats(self, step, stats): 126 | self.context = {"subset": "train"} 127 | super().train_step_stats(step, stats) 128 | 129 | def train_epoch_stats(self, step, stats): 130 | self.context = {"subset": "train"} 131 | super().train_epoch_stats(step, stats) 132 | 133 | def train_figures(self, step, figures): 134 | self.context = {"subset": "train"} 135 | super().train_figures(step, figures) 136 | 137 | def train_audios(self, step, audios, sample_rate): 138 | self.context = {"subset": "train"} 139 | super().train_audios(step, audios, sample_rate) 140 | 141 | def eval_stats(self, step, stats): 142 | self.context = {"subset": "eval"} 143 | super().eval_stats(step, stats) 144 | 145 | def eval_figures(self, step, figures): 146 | self.context = {"subset": "eval"} 147 | super().eval_figures(step, figures) 148 | 149 | def eval_audios(self, step, audios, sample_rate): 150 | self.context = {"subset": "eval"} 151 | super().eval_audios(step, audios, sample_rate) 152 | 153 | def test_audios(self, step, audios, sample_rate): 154 | self.context = {"subset": "test"} 155 | super().test_audios(step, audios, sample_rate) 156 | 157 | def test_figures(self, step, figures): 158 | self.context = {"subset": "test"} 159 | super().test_figures(step, figures) 160 | 161 | def flush(self): 162 | pass 163 | 164 | @rank_zero_only 165 | def finish(self): 166 | super().close() 167 | -------------------------------------------------------------------------------- /trainer/trainer_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import random 4 | from typing import Dict, List, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from trainer.logger import logger 10 | from trainer.torch import NoamLR, StepwiseGradualLR 11 | from trainer.utils.distributed import rank_zero_logger_info 12 | 13 | 14 | def 
is_apex_available(): 15 | return importlib.util.find_spec("apex") is not None 16 | 17 | 18 | def is_mlflow_available(): 19 | return importlib.util.find_spec("mlflow") is not None 20 | 21 | 22 | def is_aim_available(): 23 | return importlib.util.find_spec("aim") is not None 24 | 25 | 26 | def is_wandb_available(): 27 | return importlib.util.find_spec("wandb") is not None 28 | 29 | 30 | def is_clearml_available(): 31 | return importlib.util.find_spec("clearml") is not None 32 | 33 | 34 | def print_training_env(args, config): 35 | """Print training environment.""" 36 | rank_zero_logger_info(" > Training Environment:", logger) 37 | 38 | if args.use_accelerate: 39 | rank_zero_logger_info(" | > Backend: Accelerate", logger) 40 | else: 41 | rank_zero_logger_info(" | > Backend: Torch", logger) 42 | 43 | if config.mixed_precision: 44 | rank_zero_logger_info(" | > Mixed precision: True", logger) 45 | rank_zero_logger_info(f" | > Precision: {config.precision}", logger) 46 | else: 47 | rank_zero_logger_info(" | > Mixed precision: False", logger) 48 | rank_zero_logger_info(" | > Precision: float32", logger) 49 | 50 | if torch.cuda.is_available() and torch.cuda.device_count() > 0: 51 | rank_zero_logger_info(f" | > Current device: {torch.cuda.current_device()}", logger) 52 | rank_zero_logger_info(f" | > Num. of GPUs: {torch.cuda.device_count()}", logger) 53 | 54 | rank_zero_logger_info(f" | > Num. of CPUs: {os.cpu_count()}", logger) 55 | rank_zero_logger_info(f" | > Num. 
of Torch Threads: {torch.get_num_threads()}", logger) 56 | rank_zero_logger_info(f" | > Torch seed: {torch.initial_seed()}", logger) 57 | rank_zero_logger_info(f" | > Torch CUDNN: {torch.backends.cudnn.enabled}", logger) 58 | rank_zero_logger_info(f" | > Torch CUDNN deterministic: {torch.backends.cudnn.deterministic}", logger) 59 | rank_zero_logger_info(f" | > Torch CUDNN benchmark: {torch.backends.cudnn.benchmark}", logger) 60 | rank_zero_logger_info(f" | > Torch TF32 MatMul: {torch.backends.cuda.matmul.allow_tf32}", logger) 61 | 62 | 63 | def setup_torch_training_env( 64 | args: "TrainerArgs", 65 | cudnn_enable: bool, 66 | cudnn_benchmark: bool, 67 | cudnn_deterministic: bool, 68 | use_ddp: bool = False, 69 | training_seed=54321, 70 | allow_tf32: bool = False, 71 | gpu=None, 72 | ) -> Tuple[bool, int]: 73 | """Setup PyTorch environment for training. 74 | 75 | Args: 76 | cudnn_enable (bool): Enable/disable CUDNN. 77 | cudnn_benchmark (bool): Enable/disable CUDNN benchmarking. Better to set to False if input sequence length is 78 | variable between batches. 79 | cudnn_deterministic (bool): Enable/disable CUDNN deterministic mode. 80 | use_ddp (bool): DDP flag. True if DDP is enabled, False otherwise. 81 | allow_tf32 (bool): Enable/disable TF32. TF32 is only available on Ampere GPUs. 82 | torch_seed (int): Seed for torch random number generator. 83 | 84 | Returns: 85 | Tuple[bool, int]: is cuda on or off and number of GPUs in the environment. 86 | """ 87 | # clear cache before training 88 | torch.cuda.empty_cache() 89 | 90 | # set_nvidia_flags 91 | # set the correct cuda visible devices (using pci order) 92 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 93 | if "CUDA_VISIBLE_DEVICES" not in os.environ and gpu is not None: 94 | torch.cuda.set_device(int(gpu)) 95 | num_gpus = 1 96 | else: 97 | num_gpus = torch.cuda.device_count() 98 | 99 | if num_gpus > 1 and (not use_ddp and not args.use_accelerate): 100 | raise RuntimeError( 101 | f" [!] {num_gpus} active GPUs. 
Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`." 102 | ) 103 | 104 | random.seed(training_seed) 105 | os.environ["PYTHONHASHSEED"] = str(training_seed) 106 | np.random.seed(training_seed) 107 | torch.manual_seed(training_seed) 108 | torch.cuda.manual_seed(training_seed) 109 | 110 | # set torch backend flags. 111 | # set them true if they are already set true 112 | torch.backends.cudnn.deterministic = cudnn_deterministic or torch.backends.cudnn.deterministic 113 | torch.backends.cudnn.enabled = cudnn_enable or torch.backends.cudnn.enabled 114 | torch.backends.cudnn.benchmark = cudnn_benchmark or torch.backends.cudnn.benchmark 115 | torch.backends.cuda.matmul.allow_tf32 = allow_tf32 or torch.backends.cuda.matmul.allow_tf32 116 | 117 | use_cuda = torch.cuda.is_available() 118 | return use_cuda, num_gpus 119 | 120 | 121 | def get_scheduler( 122 | lr_scheduler: str, lr_scheduler_params: Dict, optimizer: torch.optim.Optimizer 123 | ) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access 124 | """Find, initialize and return a Torch scheduler. 125 | 126 | Args: 127 | lr_scheduler (str): Scheduler name. 128 | lr_scheduler_params (Dict): Scheduler parameters. 129 | optimizer (torch.optim.Optimizer): Optimizer to pass to the scheduler. 130 | 131 | Returns: 132 | torch.optim.lr_scheduler._LRScheduler: Functional scheduler. 
133 | """ 134 | if lr_scheduler is None: 135 | return None 136 | if lr_scheduler.lower() == "noamlr": 137 | scheduler = NoamLR 138 | elif lr_scheduler.lower() == "stepwisegraduallr": 139 | scheduler = StepwiseGradualLR 140 | else: 141 | scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler) 142 | return scheduler(optimizer, **lr_scheduler_params) 143 | 144 | 145 | def get_optimizer( 146 | optimizer_name: str, 147 | optimizer_params: dict, 148 | lr: float, 149 | model: torch.nn.Module = None, 150 | parameters: List = None, 151 | ) -> torch.optim.Optimizer: 152 | """Find, initialize and return a Torch optimizer. 153 | 154 | Args: 155 | optimizer_name (str): Optimizer name. 156 | optimizer_params (dict): Optimizer parameters. 157 | lr (float): Initial learning rate. 158 | model (torch.nn.Module): Model to pass to the optimizer. 159 | 160 | Returns: 161 | torch.optim.Optimizer: Functional optimizer. 162 | """ 163 | if optimizer_name.lower() == "radam": 164 | module = importlib.import_module("TTS.utils.radam") 165 | optimizer = getattr(module, "RAdam") 166 | else: 167 | optimizer = getattr(torch.optim, optimizer_name) 168 | if model is not None: 169 | parameters = model.parameters() 170 | return optimizer(parameters, lr=lr, **optimizer_params) 171 | -------------------------------------------------------------------------------- /trainer/model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, List, Tuple, Union 3 | 4 | import torch 5 | from coqpit import Coqpit 6 | from torch import nn 7 | 8 | from trainer.trainer_utils import is_apex_available 9 | 10 | if is_apex_available(): 11 | from apex import amp 12 | 13 | 14 | # pylint: skip-file 15 | 16 | 17 | class TrainerModel(ABC, nn.Module): 18 | """Abstract 🐸TTS class. 
Every new 🐸TTS model must inherit this.""" 19 | 20 | @abstractmethod 21 | def forward(self, input: torch.Tensor, *args, aux_input={}, **kwargs) -> Dict: 22 | """Forward ... for the model mainly used in training. 23 | 24 | You can be flexible here and use different number of arguments and argument names since it is intended to be 25 | used by `train_step()` without exposing it out of the model. 26 | 27 | Args: 28 | input (torch.Tensor): Input tensor. 29 | aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs. 30 | 31 | Returns: 32 | Dict: Model outputs. Main model output must be named as "model_outputs". 33 | """ 34 | outputs_dict = {"model_outputs": None} 35 | ... 36 | return outputs_dict 37 | 38 | def format_batch(self, batch: Dict) -> Dict: 39 | """Format batch returned by the data loader before sending it to the model. 40 | 41 | If not implemented, model uses the batch as is. 42 | Can be used for data augmentation, feature ectraction, etc. 43 | """ 44 | return batch 45 | 46 | def format_batch_on_device(self, batch: Dict) -> Dict: 47 | """Format batch on device before sending it to the model. 48 | 49 | If not implemented, model uses the batch as is. 50 | Can be used for data augmentation, feature ectraction, etc.` 51 | """ 52 | return batch 53 | 54 | def train_step(self, *args: Any, **kwargs: Any) -> Tuple[Dict, Dict]: 55 | """Perform a single training step. Run the model forward ... and compute losses. 56 | 57 | Args: 58 | batch (Dict): Input tensors. 59 | criterion (nn.Module): Loss layer designed for the model. 60 | optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks. 61 | 62 | Returns: 63 | Tuple[Dict, Dict]: Model ouputs and computed losses. 64 | """ 65 | ... 66 | raise NotImplementedError(" [!] `train_step()` is not implemented.") 67 | 68 | def train_log(self, *args: Any, **kwargs: Any) -> None: 69 | """Create visualizations and waveform examples for training. 
70 | 71 | For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to 72 | be projected onto Tensorboard. 73 | 74 | Args: 75 | batch (Dict): Model inputs used at the previous training step. 76 | outputs (Dict): Model outputs generated at the previous training step. 77 | logger (Logger): Logger instance to log training plots. 78 | assets (Dict): Assets to be used for logging from the trainer's closure. 79 | steps (int): Number of training steps taken so far. 80 | 81 | Returns: 82 | Tuple[Dict, np.ndarray]: training plots and output waveform. 83 | """ 84 | ... 85 | raise NotImplementedError(" [!] `train_log()` is not implemented.") 86 | 87 | @torch.no_grad() 88 | def eval_step(self, *args: Any, **kwargs: Any): 89 | """Perform a single evaluation step. Run the model forward ... and compute losses. In most cases, you can 90 | call `train_step()` with no changes. 91 | 92 | Args: 93 | batch (Dict): Input tensors. 94 | criterion (nn.Module): Loss layer designed for the model. 95 | optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks. 96 | 97 | Returns: 98 | Tuple[Dict, Dict]: Model outputs and computed losses. 99 | """ 100 | raise NotImplementedError(" [!] `eval_step()` is not implemented.") 101 | 102 | def eval_log(self, *args: Any, **kwargs: Any) -> None: 103 | """The same as `train_log()`""" 104 | ... 105 | raise NotImplementedError(" [!] `eval_log()` is not implemented.") 106 | 107 | @abstractmethod 108 | def get_data_loader(*args: Any, **kwargs: Any) -> torch.utils.data.DataLoader: 109 | """Get data loader for the model. NOTE(review): the signature declares no explicit self, so when called on an instance the instance arrives as args[0]. 110 | 111 | Args: 112 | config (Coqpit): Configuration object. 113 | assets (Dict): Additional assets to be used for data loading. 114 | is_eval (bool): If True, returns evaluation data loader. 115 | samples (Union[List[Dict], List[List]]): List of samples to be used for data loading. 116 | verbose (bool): If True, prints data loading information.
117 | num_gpus (int): Number of GPUs used for training. 118 | rank (int): Rank of the current GPU. 119 | 120 | Returns: 121 | torch.utils.data.DataLoader: Data loader for the model. 122 | """ 123 | 124 | ... 125 | raise NotImplementedError(" [!] `get_data_loader()` is not implemented.") 126 | 127 | def init_for_training(self) -> None: 128 | """Initialize model for training.""" 129 | ... 130 | 131 | def optimize(self, *args: Any, **kwargs: Any) -> Tuple[Dict, Dict, float]: 132 | """Model specific optimization step that must perform the following steps: 133 | 1. Forward pass 134 | 2. Compute loss 135 | 3. Backward pass 136 | 4. Update weights 137 | 138 | Use `self.scaled_backward()` instead of `loss.backward()` to be able to use Mixed Precision Training. 139 | 140 | Args: 141 | batch (Dict): Input tensors. 142 | trainer (Trainer): Trainer instance to be able to access the training closure. 143 | 144 | Returns: 145 | Tuple[Dict, Dict, float]: Model outputs, loss dictionary and grad_norm value. 146 | """ 147 | ... 148 | raise NotImplementedError(" [!] `optimize()` is not implemented.") 149 | 150 | def scaled_backward( 151 | self, loss: torch.Tensor, trainer: "Trainer", optimizer: "Optimizer", *args: Any, **kwargs: Any 152 | ) -> Tuple[float, bool]: 153 | """Backward pass with gradient scaling for custom `optimize` calls. NOTE(review): the body has no return statement, so it returns None despite the Tuple[float, bool] annotation. 154 | 155 | Args: 156 | loss (torch.Tensor): Loss to be backpropagated. 157 | trainer (Trainer): Trainer instance to be able to access the training closure. 158 | optimizer (Optimizer): Optimizer for APEX AMP based scaled `backward` calls.
159 | """ 160 | if trainer.use_amp_scaler: 161 | if trainer.use_apex: 162 | # https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers 163 | with amp.scale_loss(loss, optimizer) as scaled_loss: 164 | scaled_loss.backward() 165 | else: 166 | # model optimizer step in mixed precision mode 167 | trainer.scaler.scale(loss).backward() 168 | else: 169 | # main model optimizer step 170 | loss.backward() 171 | 172 | # def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: 173 | # """Set up and return optimizer or optimizers.""" 174 | # ... 175 | 176 | # def get_lr(self) -> Union[float, List[float]]: 177 | # """Return learning rate(s). 178 | 179 | # Returns: 180 | # Union[float, List[float]]: Model's initial learning rates. 181 | # """ 182 | # ... 183 | 184 | # def get_scheduler(self, optimizer: torch.optim.Optimizer): 185 | # ... 186 | 187 | # def get_criterion(self): 188 | # ... 189 | -------------------------------------------------------------------------------- /trainer/callbacks.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict 2 | 3 | 4 | class TrainerCallback: 5 | def __init__(self) -> None: 6 | self.callbacks_on_init_start = [] 7 | self.callbacks_on_init_end = [] 8 | self.callbacks_on_epoch_start = [] 9 | self.callbacks_on_epoch_end = [] 10 | self.callbacks_on_train_epoch_start = [] 11 | self.callbacks_on_train_epoch_end = [] 12 | self.callbacks_on_train_step_start = [] 13 | self.callbacks_on_train_step_end = [] 14 | self.callbacks_on_keyboard_interrupt = [] 15 | 16 | def parse_callbacks_dict(self, callbacks_dict: Dict[str, Callable]) -> None: 17 | for key, value in callbacks_dict.items(): 18 | if key == "on_init_start": 19 | self.callbacks_on_init_start.append(value) 20 | elif key == "on_init_end": 21 | self.callbacks_on_init_end.append(value) 22 | elif key == "on_epoch_start": 23 | self.callbacks_on_epoch_start.append(value) 24 | elif
key == "on_epoch_end": 25 | self.callbacks_on_epoch_end.append(value) 26 | elif key == "on_train_epoch_start": 27 | self.callbacks_on_train_epoch_start.append(value) 28 | elif key == "on_train_epoch_end": 29 | self.callbacks_on_train_epoch_end.append(value) 30 | elif key == "on_train_step_start": 31 | self.callbacks_on_train_step_start.append(value) 32 | elif key == "on_train_step_end": 33 | self.callbacks_on_train_step_end.append(value) 34 | elif key == "on_keyboard_interrupt": 35 | self.callbacks_on_keyboard_interrupt.append(value) 36 | else: 37 | raise ValueError(f"Invalid callback key: {key}") 38 | 39 | def on_init_start(self, trainer) -> None: 40 | if hasattr(trainer.model, "module"): 41 | if hasattr(trainer.model.module, "on_init_start"): 42 | trainer.model.module.on_init_start(trainer) 43 | else: 44 | if hasattr(trainer.model, "on_init_start"): 45 | trainer.model.on_init_start(trainer) 46 | 47 | if hasattr(trainer.criterion, "on_init_start"): 48 | trainer.criterion.on_init_start(trainer) 49 | 50 | if hasattr(trainer.optimizer, "on_init_start"): 51 | trainer.optimizer.on_init_start(trainer) 52 | 53 | if self.callbacks_on_init_start: 54 | for callback in self.callbacks_on_init_start: 55 | callback(trainer) 56 | 57 | def on_init_end(self, trainer) -> None: 58 | if hasattr(trainer.model, "module"): 59 | if hasattr(trainer.model.module, "on_init_end"): 60 | trainer.model.module.on_init_end(trainer) 61 | else: 62 | if hasattr(trainer.model, "on_init_end"): 63 | trainer.model.on_init_end(trainer) 64 | 65 | if hasattr(trainer.criterion, "on_init_end"): 66 | trainer.criterion.on_init_end(trainer) 67 | 68 | if hasattr(trainer.optimizer, "on_init_end"): 69 | trainer.optimizer.on_init_end(trainer) 70 | 71 | if len(self.callbacks_on_init_end) > 0: 72 | for callback in self.callbacks_on_init_end: 73 | callback(trainer) 74 | 75 | def on_epoch_start(self, trainer) -> None: 76 | if hasattr(trainer.model, "module"): 77 | if hasattr(trainer.model.module, "on_epoch_start"): 78
| trainer.model.module.on_epoch_start(trainer) 79 | else: 80 | if hasattr(trainer.model, "on_epoch_start"): 81 | trainer.model.on_epoch_start(trainer) 82 | 83 | if hasattr(trainer.criterion, "on_epoch_start"): 84 | trainer.criterion.on_epoch_start(trainer) 85 | 86 | if hasattr(trainer.optimizer, "on_epoch_start"): 87 | trainer.optimizer.on_epoch_start(trainer) 88 | 89 | if self.callbacks_on_epoch_start: 90 | for callback in self.callbacks_on_epoch_start: 91 | callback(trainer) 92 | 93 | def on_epoch_end(self, trainer) -> None: 94 | if hasattr(trainer.model, "module"): 95 | if hasattr(trainer.model.module, "on_epoch_end"): 96 | trainer.model.module.on_epoch_end(trainer) 97 | else: 98 | if hasattr(trainer.model, "on_epoch_end"): 99 | trainer.model.on_epoch_end(trainer) 100 | 101 | if hasattr(trainer.criterion, "on_epoch_end"): 102 | trainer.criterion.on_epoch_end(trainer) 103 | 104 | if hasattr(trainer.optimizer, "on_epoch_end"): 105 | trainer.optimizer.on_epoch_end(trainer) 106 | 107 | if self.callbacks_on_epoch_end: 108 | for callback in self.callbacks_on_epoch_end: 109 | callback(trainer) 110 | 111 | def on_train_epoch_start(self, trainer) -> None: 112 | if hasattr(trainer.model, "module"): 113 | if hasattr(trainer.model.module, "on_train_epoch_start"): 114 | trainer.model.module.on_train_epoch_start(trainer) 115 | else: 116 | if hasattr(trainer.model, "on_train_epoch_start"): 117 | trainer.model.on_train_epoch_start(trainer) 118 | 119 | if hasattr(trainer.criterion, "on_train_epoch_start"): 120 | trainer.criterion.on_train_epoch_start(trainer) 121 | 122 | if hasattr(trainer.optimizer, "on_train_epoch_start"): 123 | trainer.optimizer.on_train_epoch_start(trainer) 124 | 125 | if self.callbacks_on_train_epoch_start: 126 | for callback in self.callbacks_on_train_epoch_start: 127 | callback(trainer) 128 | 129 | def on_train_epoch_end(self, trainer) -> None: 130 | if hasattr(trainer.model, "module"): 131 | if hasattr(trainer.model.module, "on_train_epoch_end"): 132 |
trainer.model.module.on_train_epoch_end(trainer) 133 | else: 134 | if hasattr(trainer.model, "on_train_epoch_end"): 135 | trainer.model.on_train_epoch_end(trainer) 136 | 137 | if hasattr(trainer.criterion, "on_train_epoch_end"): 138 | trainer.criterion.on_train_epoch_end(trainer) 139 | 140 | if hasattr(trainer.optimizer, "on_train_epoch_end"): 141 | trainer.optimizer.on_train_epoch_end(trainer) 142 | 143 | if self.callbacks_on_train_epoch_end: 144 | for callback in self.callbacks_on_train_epoch_end: 145 | callback(trainer) 146 | 147 | @staticmethod 148 | def before_backward_pass(trainer, loss_dict) -> None: 149 | if hasattr(trainer.model, "module"): 150 | if hasattr(trainer.model.module, "before_backward_pass"): 151 | trainer.model.module.before_backward_pass(loss_dict, trainer.optimizer) 152 | else: 153 | if hasattr(trainer.model, "before_backward_pass"): 154 | trainer.model.before_backward_pass(loss_dict, trainer.optimizer) 155 | 156 | @staticmethod 157 | def before_gradient_clipping(trainer) -> None: 158 | if hasattr(trainer.model, "module"): 159 | if hasattr(trainer.model.module, "before_gradient_clipping"): 160 | trainer.model.module.before_gradient_clipping() 161 | else: 162 | if hasattr(trainer.model, "before_gradient_clipping"): 163 | trainer.model.before_gradient_clipping() 164 | 165 | def on_train_step_start(self, trainer) -> None: 166 | if hasattr(trainer.model, "module"): 167 | if hasattr(trainer.model.module, "on_train_step_start"): 168 | trainer.model.module.on_train_step_start(trainer) 169 | else: 170 | if hasattr(trainer.model, "on_train_step_start"): 171 | trainer.model.on_train_step_start(trainer) 172 | 173 | if hasattr(trainer.criterion, "on_train_step_start"): 174 | trainer.criterion.on_train_step_start(trainer) 175 | 176 | if hasattr(trainer.optimizer, "on_train_step_start"): 177 | trainer.optimizer.on_train_step_start(trainer) 178 | 179 | if self.callbacks_on_train_step_start: 180 | for callback in self.callbacks_on_train_step_start: 181 |
callback(trainer) 182 | 183 | def on_train_step_end(self, trainer) -> None: 184 | if hasattr(trainer.model, "module"): 185 | if hasattr(trainer.model.module, "on_train_step_end"): 186 | trainer.model.module.on_train_step_end(trainer) 187 | else: 188 | if hasattr(trainer.model, "on_train_step_end"): 189 | trainer.model.on_train_step_end(trainer) 190 | 191 | if hasattr(trainer.criterion, "on_train_step_end"): 192 | trainer.criterion.on_train_step_end(trainer) 193 | 194 | if hasattr(trainer.optimizer, "on_train_step_end"): 195 | trainer.optimizer.on_train_step_end(trainer) 196 | 197 | if self.callbacks_on_train_step_end: 198 | for callback in self.callbacks_on_train_step_end: 199 | callback(trainer) 200 | 201 | def on_keyboard_interrupt(self, trainer) -> None: 202 | if hasattr(trainer.model, "module"): 203 | if hasattr(trainer.model.module, "on_keyboard_interrupt"): 204 | trainer.model.module.on_keyboard_interrupt(trainer) 205 | else: 206 | if hasattr(trainer.model, "on_keyboard_interrupt"): 207 | trainer.model.on_keyboard_interrupt(trainer) 208 | 209 | if hasattr(trainer.criterion, "on_keyboard_interrupt"): 210 | trainer.criterion.on_keyboard_interrupt(trainer) 211 | 212 | if hasattr(trainer.optimizer, "on_keyboard_interrupt"): 213 | trainer.optimizer.on_keyboard_interrupt(trainer) 214 | 215 | if self.callbacks_on_keyboard_interrupt: 216 | for callback in self.callbacks_on_keyboard_interrupt: 217 | callback(trainer) 218 | -------------------------------------------------------------------------------- /trainer/io.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import os 4 | import re 5 | import sys 6 | from pathlib import Path 7 | from typing import Any, Callable, Dict, List, Tuple, Union 8 | from urllib.parse import urlparse 9 | 10 | import fsspec 11 | import torch 12 | from coqpit import Coqpit 13 | 14 | from trainer.logger import logger 15 | 16 | 17 | def get_user_data_dir(appname): 18 |
if sys.platform == "win32": 19 | import winreg # pylint: disable=import-outside-toplevel, import-error 20 | 21 | key = winreg.OpenKey( 22 | winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" 23 | ) 24 | dir_, _ = winreg.QueryValueEx(key, "Local AppData") 25 | ans = Path(dir_).resolve(strict=False) 26 | elif sys.platform == "darwin": 27 | ans = Path("~/Library/Application Support/").expanduser() 28 | else: 29 | ans = Path.home().joinpath(".local/share") 30 | return ans.joinpath(appname) 31 | 32 | 33 | def copy_model_files(config: Coqpit, out_path, new_fields): 34 | """Copy config.json and other model files to training folder and add 35 | new fields. 36 | 37 | Args: 38 | config (Coqpit): Coqpit config defining the training run. 39 | out_path (str): output path to copy the file. 40 | new_fields (dict): new fields to be added or edited 41 | in the config file. 42 | """ 43 | copy_config_path = os.path.join(out_path, "config.json") 44 | # add extra information fields 45 | new_config = {**config.to_dict(), **new_fields} 46 | # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths. 47 | with fsspec.open(copy_config_path, "w", encoding="utf8") as f: 48 | json.dump(new_config, f, indent=4) 49 | 50 | 51 | def load_fsspec( 52 | path: str, 53 | map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, 54 | cache: bool = True, 55 | **kwargs, 56 | ) -> Any: 57 | """Like torch.load but can load from other locations (e.g. s3:// , gs://). 58 | Args: 59 | path: Any path or url supported by fsspec. 60 | map_location: torch.device or str. 61 | cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True. 62 | **kwargs: Keyword arguments forwarded to torch.load. 63 | Returns: 64 | Object stored in path. NOTE(review): depending on the torch version and weights_only setting, torch.load may unpickle arbitrary objects; only load trusted checkpoints.
65 | """ 66 | is_local = os.path.isdir(path) or os.path.isfile(path) 67 | if cache and not is_local: 68 | with fsspec.open( 69 | f"filecache::{path}", 70 | filecache={"cache_storage": str(get_user_data_dir("tts_cache"))}, 71 | mode="rb", 72 | ) as f: 73 | return torch.load(f, map_location=map_location, **kwargs) 74 | else: 75 | with fsspec.open(path, "rb") as f: 76 | return torch.load(f, map_location=map_location, **kwargs) 77 | 78 | 79 | def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin 80 | state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) 81 | model.load_state_dict(state["model"]) 82 | if use_cuda: 83 | model.cuda() 84 | if eval: 85 | model.eval() 86 | return model, state 87 | 88 | 89 | def save_fsspec(state: Any, path: str, **kwargs): 90 | """Like torch.save but can save to other locations (e.g. s3:// , gs://). 91 | 92 | Args: 93 | state: State object to save 94 | path: Any path or url supported by fsspec. 95 | **kwargs: Keyword arguments forwarded to torch.save.
96 | """ 97 | with fsspec.open(path, "wb") as f: 98 | torch.save(state, f, **kwargs) 99 | 100 | 101 | def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, save_func, **kwargs): 102 | if hasattr(model, "module"): 103 | model_state = model.module.state_dict() 104 | else: 105 | model_state = model.state_dict() 106 | if isinstance(optimizer, list): 107 | optimizer_state = [optim.state_dict() for optim in optimizer] 108 | elif isinstance(optimizer, dict): 109 | optimizer_state = {k: v.state_dict() for k, v in optimizer.items()} 110 | else: 111 | optimizer_state = optimizer.state_dict() if optimizer is not None else None 112 | 113 | if isinstance(scaler, list): 114 | scaler_state = [s.state_dict() for s in scaler] 115 | else: 116 | scaler_state = scaler.state_dict() if scaler is not None else None 117 | 118 | if isinstance(config, Coqpit): 119 | config = config.to_dict() 120 | 121 | state = { 122 | "config": config, 123 | "model": model_state, 124 | "optimizer": optimizer_state, 125 | "scaler": scaler_state, 126 | "step": current_step, 127 | "epoch": epoch, 128 | "date": datetime.date.today().strftime("%B %d, %Y"), 129 | } 130 | state.update(kwargs) 131 | if save_func: 132 | save_func(state, output_path) 133 | else: 134 | save_fsspec(state, output_path) 135 | 136 | 137 | def save_checkpoint( 138 | config, 139 | model, 140 | optimizer, 141 | scaler, 142 | current_step, 143 | epoch, 144 | output_folder, 145 | save_n_checkpoints=None, 146 | save_func=None, 147 | **kwargs, 148 | ): 149 | file_name = f"checkpoint_{current_step}.pth" 150 | checkpoint_path = os.path.join(output_folder, file_name) 151 | 152 | logger.info("\n > CHECKPOINT : %s", checkpoint_path) 153 | save_model( 154 | config, 155 | model, 156 | optimizer, 157 | scaler, 158 | current_step, 159 | epoch, 160 | checkpoint_path, 161 | save_func=save_func, 162 | **kwargs, 163 | ) 164 | if save_n_checkpoints is not None: 165 | keep_n_checkpoints(output_folder, save_n_checkpoints) 166 |
167 | 168 | def save_best_model( 169 | current_loss, 170 | best_loss, 171 | config, 172 | model, 173 | optimizer, 174 | scaler, 175 | current_step, 176 | epoch, 177 | out_path, 178 | keep_all_best=False, 179 | keep_after=0, 180 | save_func=None, 181 | **kwargs, 182 | ): 183 | if isinstance(current_loss, dict): 184 | use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None 185 | is_save_model = (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or ( 186 | not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"] 187 | ) 188 | else: 189 | is_save_model = current_loss < best_loss 190 | 191 | if isinstance(keep_after, (int, float)): 192 | keep_after = int(keep_after) 193 | is_save_model = is_save_model and current_step > keep_after 194 | 195 | if is_save_model: 196 | best_model_name = f"best_model_{current_step}.pth" 197 | checkpoint_path = os.path.join(out_path, best_model_name) 198 | logger.info(" > BEST MODEL : %s", checkpoint_path) 199 | save_model( 200 | config, 201 | model, 202 | optimizer, 203 | scaler, 204 | current_step, 205 | epoch, 206 | checkpoint_path, 207 | model_loss=current_loss, 208 | save_func=save_func, 209 | **kwargs, 210 | ) 211 | fs = fsspec.get_mapper(out_path).fs 212 | # only delete previous if current is saved successfully. NOTE(review): for a numeric keep_after, current_step < keep_after is always False here because saving already required current_step > keep_after; for a non-numeric keep_after with keep_all_best=True the comparison may raise TypeError - confirm intent 213 | if not keep_all_best or (current_step < keep_after): 214 | model_names = fs.glob(os.path.join(out_path, "best_model*.pth")) 215 | for model_name in model_names: 216 | if os.path.basename(model_name) != best_model_name: 217 | fs.rm(model_name) 218 | # create a shortcut (a full file copy via fs.copy, not a link) of the currently best model 219 | shortcut_name = "best_model.pth" 220 | shortcut_path = os.path.join(out_path, shortcut_name) 221 | fs.copy(checkpoint_path, shortcut_path) 222 | best_loss = current_loss 223 | return best_loss 224 | 225 | 226 | def get_last_checkpoint(path: str) -> Tuple[str, str]: 227 | """Get latest checkpoint or/and best model in path.
228 | 229 | It is based on globbing for `*.pth` and the RegEx 230 | `(checkpoint|best_model)_([0-9]+)`. 231 | 232 | Args: 233 | path: Path to files to be compared. 234 | 235 | Raises: 236 | ValueError: If no checkpoint or best_model files are found. 237 | 238 | Returns: 239 | Path to the last checkpoint (replaced by the newest best_model file when its step number is higher) 240 | Path to best checkpoint 241 | """ 242 | fs = fsspec.get_mapper(path).fs 243 | file_names = fs.glob(os.path.join(path, "*.pth")) 244 | scheme = urlparse(path).scheme 245 | if scheme and path.startswith(scheme + "://"): 246 | # scheme is not preserved in fs.glob, add it 247 | # back if it exists on the path 248 | file_names = [scheme + "://" + file_name for file_name in file_names] 249 | last_models = {} 250 | last_model_nums = {} 251 | for key in ["checkpoint", "best_model"]: 252 | last_model_num = None 253 | last_model = None 254 | # pass all the checkpoint files and find 255 | # the one with the largest model number suffix. 256 | for file_name in file_names: 257 | match = re.search(f"{key}_([0-9]+)", file_name) 258 | if match is not None: 259 | model_num = int(match.groups()[0]) 260 | if last_model_num is None or model_num > last_model_num: 261 | last_model_num = model_num 262 | last_model = file_name 263 | 264 | # if there is no checkpoint found above 265 | # find the checkpoint with the latest 266 | # ctime via os.path.getctime (not the modification date; works on local paths only).
267 | key_file_names = [fn for fn in file_names if key in fn] 268 | if last_model is None and len(key_file_names) > 0: 269 | last_model = max(key_file_names, key=os.path.getctime) 270 | last_model_num = load_fsspec(last_model)["step"] 271 | 272 | if last_model is not None: 273 | last_models[key] = last_model 274 | last_model_nums[key] = last_model_num 275 | 276 | # check what models were found 277 | if not last_models: 278 | raise ValueError(f"No models found in continue path {path}!") 279 | if "checkpoint" not in last_models: # no checkpoint just best model 280 | last_models["checkpoint"] = last_models["best_model"] 281 | elif "best_model" not in last_models: # no best model 282 | # this shouldn't happen, but let's handle it just in case 283 | last_models["best_model"] = last_models["checkpoint"] 284 | # finally check if last best model is more recent than checkpoint 285 | elif last_model_nums["best_model"] > last_model_nums["checkpoint"]: 286 | last_models["checkpoint"] = last_models["best_model"] 287 | 288 | return last_models["checkpoint"], last_models["best_model"] 289 | 290 | 291 | def keep_n_checkpoints(path: str, n: int) -> None: 292 | """Keep only the last n checkpoints in path. Candidates come from sort_checkpoints (pathlib glob, so remote URLs are not pruned) and the lowest-ordered ones are removed. NOTE(review): n=0 removes nothing, since file_names[:-0] is empty. 293 | 294 | Args: 295 | path: Path to files to be compared. 296 | n: Number of checkpoints to keep. 297 | """ 298 | fs = fsspec.get_mapper(path).fs 299 | file_names = sort_checkpoints(path, "checkpoint") 300 | if len(file_names) > n: 301 | for file_name in file_names[:-n]: 302 | fs.rm(file_name) 303 | 304 | 305 | def sort_checkpoints(output_path: str, checkpoint_prefix: str, use_mtime: bool = False) -> List[str]: 306 | """Sort checkpoint paths based on the checkpoint step number. Returns the paths in ascending order; with use_mtime False, files without a numeric step suffix are skipped. 307 | 308 | Args: 309 | output_path (str): Path to directory containing checkpoints. 310 | checkpoint_prefix (str): Prefix of the checkpoint files. 311 | use_mtime (bool): If True, use modification dates to determine checkpoint order.
312 | """ 313 | ordering_and_checkpoint_path = [] 314 | 315 | glob_checkpoints = [str(x) for x in Path(output_path).glob(f"{checkpoint_prefix}_*")] 316 | 317 | for path in glob_checkpoints: 318 | if use_mtime: 319 | ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) 320 | else: 321 | regex_match = re.match(f".*{checkpoint_prefix}_([0-9]+)", path) 322 | if regex_match is not None and regex_match.groups() is not None: 323 | ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) 324 | 325 | checkpoints_sorted = sorted(ordering_and_checkpoint_path) 326 | checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] 327 | return checkpoints_sorted 328 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | 3 | # Analyse import fallback blocks. This can be used to support both Python 2 and 4 | # 3 compatible code, which means that the block might have code that exists 5 | # only in one or another interpreter, leading to false positives when analysed. 6 | analyse-fallback-blocks=no 7 | 8 | # Clear in-memory caches upon conclusion of linting. Useful if running pylint 9 | # in a server-like mode. 10 | clear-cache-post-run=no 11 | 12 | # Load and enable all available extensions. Use --list-extensions to see a list 13 | # of all available extensions. 14 | #enable-all-extensions= 15 | 16 | # In error mode, messages with a category besides ERROR or FATAL are 17 | # suppressed, and no reports are done by default. Error mode is compatible with 18 | # disabling specific errors. 19 | #errors-only= 20 | 21 | # Always return a 0 (non-error) status code, even if lint errors are found. 22 | # This is primarily useful in continuous integration scripts. 23 | #exit-zero= 24 | 25 | # A comma-separated list of package or module names from where C extensions may 26 | # be loaded.
Extensions are loaded into the active Python interpreter and may 27 | # run arbitrary code. 28 | extension-pkg-allow-list= 29 | 30 | # A comma-separated list of package or module names from where C extensions may 31 | # be loaded. Extensions are loaded into the active Python interpreter and may 32 | # run arbitrary code. (This is an alternative name to extension-pkg-allow-list 33 | # for backward compatibility.) 34 | extension-pkg-whitelist= 35 | 36 | # Return non-zero exit code if any of these messages/categories are detected, 37 | # even if score is above --fail-under value. Syntax same as enable. Messages 38 | # specified are enabled, while categories only check already-enabled messages. 39 | fail-on= 40 | 41 | # Specify a score threshold under which the program will exit with error. 42 | fail-under=10 43 | 44 | # Interpret the stdin as a python script, whose filename needs to be passed as 45 | # the module_or_package argument. 46 | #from-stdin= 47 | 48 | # Files or directories to be skipped. They should be base names, not paths. 49 | ignore=CVS 50 | 51 | # Add files or directories matching the regular expression patterns to the 52 | # ignore-list. The regex matches against paths and can be in Posix or Windows 53 | # format. Because '\\' represents the directory delimiter on Windows systems, 54 | # it can't be used as an escape character. 55 | ignore-paths= 56 | 57 | # Files or directories matching the regular expression patterns are skipped. 58 | # The regex matches against base names, not paths. The default value ignores 59 | # Emacs file locks 60 | ignore-patterns=^\.# 61 | 62 | # List of module names for which member attributes should not be checked 63 | # (useful for modules/projects where namespaces are manipulated during runtime 64 | # and thus existing member attributes cannot be deduced by static analysis). It 65 | # supports qualified module names, as well as Unix pattern matching.
66 | ignored-modules= 67 | 68 | # Python code to execute, usually for sys.path manipulation such as 69 | # pygtk.require(). 70 | #init-hook= 71 | 72 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the 73 | # number of processors available to use, and will cap the count on Windows to 74 | # avoid hangs. 75 | jobs=1 76 | 77 | # Control the amount of potential inferred values when inferring a single 78 | # object. This can help the performance when dealing with large functions or 79 | # complex, nested conditions. 80 | limit-inference-results=100 81 | 82 | # List of plugins (as comma separated values of python module names) to load, 83 | # usually to register additional checkers. 84 | load-plugins= 85 | 86 | # Pickle collected data for later comparisons. 87 | persistent=yes 88 | 89 | # Minimum Python version to use for version dependent checks. Will default to 90 | # the version used to run pylint. 91 | py-version=3.11 92 | 93 | # Discover python modules and packages in the file system subtree. 94 | recursive=no 95 | 96 | # Add paths to the list of the source roots. Supports globbing patterns. The 97 | # source root is an absolute path or a path relative to the current working 98 | # directory used to determine a package namespace for modules located under the 99 | # source root. 100 | source-roots= 101 | 102 | # When enabled, pylint would attempt to guess common misconfiguration and emit 103 | # user-friendly hints instead of false-positive error messages. 104 | suggestion-mode=yes 105 | 106 | # Allow loading of arbitrary C extensions. Extensions are imported into the 107 | # active Python interpreter and may run arbitrary code. 108 | unsafe-load-any-extension=no 109 | 110 | # In verbose mode, extra non-checker-related info will be displayed. 111 | #verbose= 112 | 113 | 114 | [BASIC] 115 | 116 | # Naming style matching correct argument names. 
117 | argument-naming-style=snake_case 118 | 119 | # Regular expression matching correct argument names. Overrides argument- 120 | # naming-style. If left empty, argument names will be checked with the set 121 | # naming style. 122 | #argument-rgx= 123 | 124 | # Naming style matching correct attribute names. 125 | attr-naming-style=snake_case 126 | 127 | # Regular expression matching correct attribute names. Overrides attr-naming- 128 | # style. If left empty, attribute names will be checked with the set naming 129 | # style. 130 | #attr-rgx= 131 | 132 | # Bad variable names which should always be refused, separated by a comma. 133 | bad-names=foo, 134 | bar, 135 | baz, 136 | toto, 137 | tutu, 138 | tata 139 | 140 | # Bad variable names regexes, separated by a comma. If names match any regex, 141 | # they will always be refused 142 | bad-names-rgxs= 143 | 144 | # Naming style matching correct class attribute names. 145 | class-attribute-naming-style=any 146 | 147 | # Regular expression matching correct class attribute names. Overrides class- 148 | # attribute-naming-style. If left empty, class attribute names will be checked 149 | # with the set naming style. 150 | #class-attribute-rgx= 151 | 152 | # Naming style matching correct class constant names. 153 | class-const-naming-style=UPPER_CASE 154 | 155 | # Regular expression matching correct class constant names. Overrides class- 156 | # const-naming-style. If left empty, class constant names will be checked with 157 | # the set naming style. 158 | #class-const-rgx= 159 | 160 | # Naming style matching correct class names. 161 | class-naming-style=PascalCase 162 | 163 | # Regular expression matching correct class names. Overrides class-naming- 164 | # style. If left empty, class names will be checked with the set naming style. 165 | #class-rgx= 166 | 167 | # Naming style matching correct constant names. 168 | const-naming-style=UPPER_CASE 169 | 170 | # Regular expression matching correct constant names. 
Overrides const-naming- 171 | # style. If left empty, constant names will be checked with the set naming 172 | # style. 173 | #const-rgx= 174 | 175 | # Minimum line length for functions/classes that require docstrings, shorter 176 | # ones are exempt. 177 | docstring-min-length=-1 178 | 179 | # Naming style matching correct function names. 180 | function-naming-style=snake_case 181 | 182 | # Regular expression matching correct function names. Overrides function- 183 | # naming-style. If left empty, function names will be checked with the set 184 | # naming style. 185 | #function-rgx= 186 | 187 | # Good variable names which should always be accepted, separated by a comma. 188 | good-names=i, 189 | j, 190 | k, 191 | ex, 192 | Run, 193 | _ 194 | 195 | # Good variable names regexes, separated by a comma. If names match any regex, 196 | # they will always be accepted 197 | good-names-rgxs= 198 | 199 | # Include a hint for the correct naming format with invalid-name. 200 | include-naming-hint=no 201 | 202 | # Naming style matching correct inline iteration names. 203 | inlinevar-naming-style=any 204 | 205 | # Regular expression matching correct inline iteration names. Overrides 206 | # inlinevar-naming-style. If left empty, inline iteration names will be checked 207 | # with the set naming style. 208 | #inlinevar-rgx= 209 | 210 | # Naming style matching correct method names. 211 | method-naming-style=snake_case 212 | 213 | # Regular expression matching correct method names. Overrides method-naming- 214 | # style. If left empty, method names will be checked with the set naming style. 215 | #method-rgx= 216 | 217 | # Naming style matching correct module names. 218 | module-naming-style=snake_case 219 | 220 | # Regular expression matching correct module names. Overrides module-naming- 221 | # style. If left empty, module names will be checked with the set naming style. 
222 | #module-rgx= 223 | 224 | # Colon-delimited sets of names that determine each other's naming style when 225 | # the name regexes allow several styles. 226 | name-group= 227 | 228 | # Regular expression which should only match function or class names that do 229 | # not require a docstring. 230 | no-docstring-rgx=^_ 231 | 232 | # List of decorators that produce properties, such as abc.abstractproperty. Add 233 | # to this list to register other decorators that produce valid properties. 234 | # These decorators are taken in consideration only for invalid-name. 235 | property-classes=abc.abstractproperty 236 | 237 | # Regular expression matching correct type alias names. If left empty, type 238 | # alias names will be checked with the set naming style. 239 | #typealias-rgx= 240 | 241 | # Regular expression matching correct type variable names. If left empty, type 242 | # variable names will be checked with the set naming style. 243 | #typevar-rgx= 244 | 245 | # Naming style matching correct variable names. 246 | variable-naming-style=snake_case 247 | 248 | # Regular expression matching correct variable names. Overrides variable- 249 | # naming-style. If left empty, variable names will be checked with the set 250 | # naming style. 251 | #variable-rgx= 252 | 253 | 254 | [CLASSES] 255 | 256 | # Warn about protected attribute access inside special methods 257 | check-protected-access-in-special-methods=no 258 | 259 | # List of method names used to declare (i.e. assign) instance attributes. 260 | defining-attr-methods=__init__, 261 | __new__, 262 | setUp, 263 | asyncSetUp, 264 | __post_init__ 265 | 266 | # List of member names, which should be excluded from the protected access 267 | # warning. 268 | exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit 269 | 270 | # List of valid names for the first argument in a class method. 271 | valid-classmethod-first-arg=cls 272 | 273 | # List of valid names for the first argument in a metaclass class method. 
274 | valid-metaclass-classmethod-first-arg=mcs 275 | 276 | 277 | [DESIGN] 278 | 279 | # List of regular expressions of class ancestor names to ignore when counting 280 | # public methods (see R0903) 281 | exclude-too-few-public-methods= 282 | 283 | # List of qualified class names to ignore when counting class parents (see 284 | # R0901) 285 | ignored-parents= 286 | 287 | # Maximum number of arguments for function / method. 288 | max-args=5 289 | 290 | # Maximum number of attributes for a class (see R0902). 291 | max-attributes=7 292 | 293 | # Maximum number of boolean expressions in an if statement (see R0916). 294 | max-bool-expr=5 295 | 296 | # Maximum number of branches for function / method body. 297 | max-branches=12 298 | 299 | # Maximum number of locals for function / method body. 300 | max-locals=15 301 | 302 | # Maximum number of parents for a class (see R0901). 303 | max-parents=7 304 | 305 | # Maximum number of public methods for a class (see R0904). 306 | max-public-methods=20 307 | 308 | # Maximum number of return / yield for function / method body. 309 | max-returns=6 310 | 311 | # Maximum number of statements in function / method body. 312 | max-statements=50 313 | 314 | # Minimum number of public methods for a class (see R0903). 315 | min-public-methods=2 316 | 317 | 318 | [EXCEPTIONS] 319 | 320 | # Exceptions that will emit a warning when caught. 321 | overgeneral-exceptions=builtins.BaseException,builtins.Exception 322 | 323 | 324 | [FORMAT] 325 | 326 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 327 | expected-line-ending-format= 328 | 329 | # Regexp for a line that is allowed to be longer than the limit. 330 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$ 331 | 332 | # Number of spaces of indent required inside a hanging or continued line. 333 | indent-after-paren=4 334 | 335 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 336 | # tab).
337 | indent-string=' ' 338 | 339 | # Maximum number of characters on a single line. 340 | max-line-length=100 341 | 342 | # Maximum number of lines in a module. 343 | max-module-lines=1000 344 | 345 | # Allow the body of a class to be on the same line as the declaration if body 346 | # contains single statement. 347 | single-line-class-stmt=no 348 | 349 | # Allow the body of an if to be on the same line as the test if there is no 350 | # else. 351 | single-line-if-stmt=no 352 | 353 | 354 | [IMPORTS] 355 | 356 | # List of modules that can be imported at any level, not just the top level 357 | # one. 358 | allow-any-import-level= 359 | 360 | # Allow explicit reexports by alias from a package __init__. 361 | allow-reexport-from-package=no 362 | 363 | # Allow wildcard imports from modules that define __all__. 364 | allow-wildcard-with-all=no 365 | 366 | # Deprecated modules which should not be used, separated by a comma. 367 | deprecated-modules= 368 | 369 | # Output a graph (.gv or any supported image format) of external dependencies 370 | # to the given file (report RP0402 must not be disabled). 371 | ext-import-graph= 372 | 373 | # Output a graph (.gv or any supported image format) of all (i.e. internal and 374 | # external) dependencies to the given file (report RP0402 must not be 375 | # disabled). 376 | import-graph= 377 | 378 | # Output a graph (.gv or any supported image format) of internal dependencies 379 | # to the given file (report RP0402 must not be disabled). 380 | int-import-graph= 381 | 382 | # Force import order to recognize a module as part of the standard 383 | # compatibility libraries. 384 | known-standard-library= 385 | 386 | # Force import order to recognize a module as part of a third party library. 387 | known-third-party=enchant 388 | 389 | # Couples of modules and preferred modules, separated by a comma. 390 | preferred-modules= 391 | 392 | 393 | [LOGGING] 394 | 395 | # The type of string formatting that logging methods do. 
`old` means using % 396 | # formatting, `new` is for `{}` formatting. 397 | logging-format-style=old 398 | 399 | # Logging modules to check that the string format arguments are in logging 400 | # function parameter format. 401 | logging-modules=logging 402 | 403 | 404 | [MESSAGES CONTROL] 405 | 406 | # Only show warnings with the listed confidence levels. Leave empty to show 407 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, 408 | # UNDEFINED. 409 | confidence=HIGH, 410 | INFERENCE, 411 | INFERENCE_FAILURE, 412 | UNDEFINED 413 | 414 | # Disable the message, report, category or checker with the given id(s). You 415 | # can either give multiple identifiers separated by comma (,) or put this 416 | # option multiple times (only on the command line, not in the configuration 417 | # file where it should appear only once). You can also use "--disable=all" to 418 | # disable everything first and then re-enable specific checks. For example, if 419 | # you want to run only the similarities checker, you can use "--disable=all 420 | # --enable=similarities". If you want to run only the classes checker, but have 421 | # no Warning level messages displayed, use "--disable=all --enable=classes 422 | # --disable=W". 423 | disable=raw-checker-failed, 424 | bad-inline-option, 425 | locally-disabled, 426 | file-ignored, 427 | suppressed-message, 428 | useless-suppression, 429 | deprecated-pragma, 430 | use-symbolic-message-instead, 431 | line-too-long, 432 | missing-function-docstring, 433 | missing-module-docstring, 434 | missing-class-docstring, 435 | invalid-name, 436 | consider-using-f-string, 437 | too-many-instance-attributes, 438 | no-member, 439 | too-many-locals, 440 | too-many-branches, 441 | too-many-arguments, 442 | fixme, 443 | too-many-lines, 444 | too-many-statements, 445 | too-many-public-methods, 446 | duplicate-code, 447 | 448 | 449 | # Enable the message, report, category or checker with the given id(s). 
You can 450 | # either give multiple identifiers separated by comma (,) or put this option 451 | # multiple times (only on the command line, not in the configuration file where 452 | # it should appear only once). See also the "--disable" option for examples. 453 | enable=c-extension-no-member 454 | 455 | 456 | [METHOD_ARGS] 457 | 458 | # List of qualified names (i.e., library.method) which require a timeout 459 | # parameter e.g. 'requests.api.get,requests.api.post' 460 | timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request 461 | 462 | 463 | [MISCELLANEOUS] 464 | 465 | # List of note tags to take into consideration, separated by a comma. 466 | notes=FIXME, 467 | XXX, 468 | TODO 469 | 470 | # Regular expression of note tags to take into consideration. 471 | notes-rgx= 472 | 473 | 474 | [REFACTORING] 475 | 476 | # Maximum number of nested blocks for function / method body 477 | max-nested-blocks=5 478 | 479 | # Complete names of functions that never return. When checking for 480 | # inconsistent-return-statements, if a never-returning function is called then 481 | # it will be considered as an explicit return statement and no message will be 482 | # printed. 483 | never-returning-functions=sys.exit,argparse.parse_error 484 | 485 | 486 | [REPORTS] 487 | 488 | # Python expression which should return a score less than or equal to 10. You 489 | # have access to the variables 'fatal', 'error', 'warning', 'refactor', 490 | # 'convention', and 'info' which contain the number of messages in each 491 | # category, as well as 'statement' which is the total number of statements 492 | # analyzed. This score is used by the global evaluation report (RP0004). 493 | evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) 494 | 495 | # Template used to display messages.
This is a python new-style format string 496 | # used to format the message information. See doc for all details. 497 | msg-template= 498 | 499 | # Set the output format. Available formats are text, parseable, colorized, json 500 | # and msvs (visual studio). You can also give a reporter class, e.g. 501 | # mypackage.mymodule.MyReporterClass. 502 | #output-format= 503 | 504 | # Tells whether to display a full report or only the messages. 505 | reports=no 506 | 507 | # Activate the evaluation score. 508 | score=yes 509 | 510 | 511 | [SIMILARITIES] 512 | 513 | # Comments are removed from the similarity computation 514 | ignore-comments=yes 515 | 516 | # Docstrings are removed from the similarity computation 517 | ignore-docstrings=yes 518 | 519 | # Imports are removed from the similarity computation 520 | ignore-imports=yes 521 | 522 | # Signatures are removed from the similarity computation 523 | ignore-signatures=yes 524 | 525 | # Minimum lines number of a similarity. 526 | min-similarity-lines=4 527 | 528 | 529 | [SPELLING] 530 | 531 | # Limits count of emitted suggestions for spelling mistakes. 532 | max-spelling-suggestions=4 533 | 534 | # Spelling dictionary name. No available dictionaries : You need to install 535 | # both the python package and the system dependency for enchant to work.. 536 | spelling-dict= 537 | 538 | # List of comma separated words that should be considered directives if they 539 | # appear at the beginning of a comment and should not be checked. 540 | spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: 541 | 542 | # List of comma separated words that should not be checked. 543 | spelling-ignore-words= 544 | 545 | # A path to a file that contains the private dictionary; one word per line. 546 | spelling-private-dict-file= 547 | 548 | # Tells whether to store unknown words to the private dictionary (see the 549 | # --spelling-private-dict-file option) instead of raising a message. 
550 | spelling-store-unknown-words=no 551 | 552 | 553 | [STRING] 554 | 555 | # This flag controls whether inconsistent-quotes generates a warning when the 556 | # character used as a quote delimiter is used inconsistently within a module. 557 | check-quote-consistency=no 558 | 559 | # This flag controls whether the implicit-str-concat should generate a warning 560 | # on implicit string concatenation in sequences defined over several lines. 561 | check-str-concat-over-line-jumps=no 562 | 563 | 564 | [TYPECHECK] 565 | 566 | # List of decorators that produce context managers, such as 567 | # contextlib.contextmanager. Add to this list to register other decorators that 568 | # produce valid context managers. 569 | contextmanager-decorators=contextlib.contextmanager 570 | 571 | # List of members which are set dynamically and missed by pylint inference 572 | # system, and so shouldn't trigger E1101 when accessed. Python regular 573 | # expressions are accepted. 574 | generated-members= 575 | 576 | # Tells whether to warn about missing members when the owner of the attribute 577 | # is inferred to be None. 578 | ignore-none=yes 579 | 580 | # This flag controls whether pylint should warn about no-member and similar 581 | # checks whenever an opaque object is returned when inferring. The inference 582 | # can return multiple potential results while evaluating a Python object, but 583 | # some branches might not be evaluated, which results in partial inference. In 584 | # that case, it might be useful to still emit no-member and other checks for 585 | # the rest of the inferred objects. 586 | ignore-on-opaque-inference=yes 587 | 588 | # List of symbolic message names to ignore for Mixin members. 
589 | ignored-checks-for-mixins=no-member, 590 | not-async-context-manager, 591 | not-context-manager, 592 | attribute-defined-outside-init 593 | 594 | # List of class names for which member attributes should not be checked (useful 595 | # for classes with dynamically set attributes). This supports the use of 596 | # qualified names. 597 | ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace 598 | 599 | # Show a hint with possible names when a member name was not found. The aspect 600 | # of finding the hint is based on edit distance. 601 | missing-member-hint=yes 602 | 603 | # The minimum edit distance a name should have in order to be considered a 604 | # similar match for a missing member name. 605 | missing-member-hint-distance=1 606 | 607 | # The total number of similar names that should be taken in consideration when 608 | # showing a hint for a missing member. 609 | missing-member-max-choices=1 610 | 611 | # Regex pattern to define which classes are considered mixins. 612 | mixin-class-rgx=.*[Mm]ixin 613 | 614 | # List of decorators that change the signature of a decorated function. 615 | signature-mutators= 616 | 617 | 618 | [VARIABLES] 619 | 620 | # List of additional names supposed to be defined in builtins. Remember that 621 | # you should avoid defining new builtins when possible. 622 | additional-builtins= 623 | 624 | # Tells whether unused global variables should be treated as a violation. 625 | allow-global-unused-variables=yes 626 | 627 | # List of names allowed to shadow builtins 628 | allowed-redefined-builtins= 629 | 630 | # List of strings which can identify a callback function by name. A callback 631 | # name must start or end with one of those strings. 632 | callbacks=cb_, 633 | _cb 634 | 635 | # A regular expression matching the name of dummy variables (i.e. expected to 636 | # not be used). 
637 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 638 | 639 | # Argument names that match this expression will be ignored. 640 | ignored-argument-names=_.*|^ignored_|^unused_ 641 | 642 | # Tells whether we should check for unused import in __init__ files. 643 | init-import=no 644 | 645 | # List of qualified module names which can have objects that can redefine 646 | # builtins. 647 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io 648 | -------------------------------------------------------------------------------- /tests/test_train_gan.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from torchvision.datasets import MNIST 11 | 12 | from trainer import TrainerConfig, TrainerModel 13 | from trainer.trainer import Trainer, TrainerArgs 14 | 15 | is_cuda = torch.cuda.is_available() 16 | 17 | 18 | # pylint: skip-file 19 | 20 | 21 | class Generator(nn.Module): 22 | def __init__(self, latent_dim, img_shape): 23 | super().__init__() 24 | self.img_shape = img_shape 25 | 26 | def block(in_feat, out_feat, normalize=True): 27 | layers = [nn.Linear(in_feat, out_feat)] 28 | if normalize: 29 | layers.append(nn.BatchNorm1d(out_feat, 0.8)) 30 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 31 | return layers 32 | 33 | self.model = nn.Sequential( 34 | *block(latent_dim, 128, normalize=False), 35 | *block(128, 256), 36 | *block(256, 512), 37 | *block(512, 1024), 38 | nn.Linear(1024, int(np.prod(img_shape))), 39 | nn.Tanh(), 40 | ) 41 | 42 | def forward(self, z): 43 | img = self.model(z) 44 | img = img.view(img.size(0), *self.img_shape) 45 | return img 46 | 47 | 48 | class Discriminator(nn.Module): 49 | def __init__(self, img_shape): 50 
| super().__init__() 51 | 52 | self.model = nn.Sequential( 53 | nn.Linear(int(np.prod(img_shape)), 512), 54 | nn.LeakyReLU(0.2, inplace=True), 55 | nn.Linear(512, 256), 56 | nn.LeakyReLU(0.2, inplace=True), 57 | nn.Linear(256, 1), 58 | nn.Sigmoid(), 59 | ) 60 | 61 | def forward(self, img): 62 | img_flat = img.view(img.size(0), -1) 63 | validity = self.model(img_flat) 64 | 65 | return validity 66 | 67 | 68 | def test_overfit_mnist_simple_gan(): 69 | @dataclass 70 | class GANModelConfig(TrainerConfig): 71 | epochs: int = 1 72 | print_step: int = 2 73 | training_seed: int = 666 74 | 75 | class GANModel(TrainerModel): 76 | def __init__(self): 77 | super().__init__() 78 | data_shape = (1, 28, 28) 79 | self.generator = Generator(latent_dim=100, img_shape=data_shape) 80 | self.discriminator = Discriminator(img_shape=data_shape) 81 | 82 | def forward(self, x): 83 | ... 84 | 85 | def train_step(self, batch, criterion, optimizer_idx): 86 | imgs, _ = batch 87 | 88 | # sample noise 89 | z = torch.randn(imgs.shape[0], 100) 90 | z = z.type_as(imgs) 91 | 92 | # train discriminator 93 | if optimizer_idx == 0: 94 | imgs_gen = self.generator(z) 95 | logits = self.discriminator(imgs_gen.detach()) 96 | fake = torch.zeros(imgs.size(0), 1) 97 | fake = fake.type_as(imgs) 98 | loss_fake = criterion(logits, fake) 99 | 100 | valid = torch.ones(imgs.size(0), 1) 101 | valid = valid.type_as(imgs) 102 | logits = self.discriminator(imgs) 103 | loss_real = loss = criterion(logits, valid) 104 | loss = (loss_real + loss_fake) / 2 105 | return {"model_outputs": logits}, {"loss": loss} 106 | 107 | # train generator 108 | if optimizer_idx == 1: 109 | imgs_gen = self.generator(z) 110 | 111 | valid = torch.ones(imgs.size(0), 1) 112 | valid = valid.type_as(imgs) 113 | 114 | logits = self.discriminator(imgs_gen) 115 | loss_real = criterion(logits, valid) 116 | return {"model_outputs": logits}, {"loss": loss_real} 117 | 118 | @torch.no_grad() 119 | def eval_step(self, batch, criterion, optimizer_idx): 120 
| return self.train_step(batch, criterion, optimizer_idx) 121 | 122 | def get_optimizer(self): 123 | discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) 124 | generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999)) 125 | return [discriminator_optimizer, generator_optimizer] 126 | 127 | def get_criterion(self): 128 | return nn.BCELoss() 129 | 130 | def get_data_loader( 131 | self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 132 | ): # pylint: disable=unused-argument 133 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 134 | dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) 135 | dataset.data = dataset.data[:64] 136 | dataset.targets = dataset.targets[:64] 137 | dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True) 138 | return dataloader 139 | 140 | config = GANModelConfig() 141 | config.batch_size = 64 142 | config.grad_clip = None 143 | 144 | model = GANModel() 145 | trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None) 146 | 147 | trainer.config.epochs = 1 148 | trainer.fit() 149 | loss_d1 = trainer.keep_avg_train["avg_loss_0"] 150 | loss_g1 = trainer.keep_avg_train["avg_loss_1"] 151 | 152 | trainer.config.epochs = 5 153 | trainer.fit() 154 | loss_d2 = trainer.keep_avg_train["avg_loss_0"] 155 | loss_g2 = trainer.keep_avg_train["avg_loss_1"] 156 | 157 | print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}") 158 | print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}") 159 | assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}" 160 | assert loss_g1 > loss_g2, f"Generator loss should decrease. 
{loss_g1} > {loss_g2}" 161 | 162 | 163 | def test_overfit_accelerate_mnist_simple_gan(): 164 | @dataclass 165 | class GANModelConfig(TrainerConfig): 166 | epochs: int = 1 167 | print_step: int = 2 168 | training_seed: int = 666 169 | 170 | class GANModel(TrainerModel): 171 | def __init__(self): 172 | super().__init__() 173 | data_shape = (1, 28, 28) 174 | self.generator = Generator(latent_dim=100, img_shape=data_shape) 175 | self.discriminator = Discriminator(img_shape=data_shape) 176 | 177 | def forward(self, x): 178 | ... 179 | 180 | def train_step(self, batch, criterion, optimizer_idx): 181 | imgs, _ = batch 182 | 183 | # sample noise 184 | z = torch.randn(imgs.shape[0], 100) 185 | z = z.type_as(imgs) 186 | 187 | # train discriminator 188 | if optimizer_idx == 0: 189 | imgs_gen = self.generator(z) 190 | logits = self.discriminator(imgs_gen.detach()) 191 | fake = torch.zeros(imgs.size(0), 1) 192 | fake = fake.type_as(imgs) 193 | loss_fake = criterion(logits, fake) 194 | 195 | valid = torch.ones(imgs.size(0), 1) 196 | valid = valid.type_as(imgs) 197 | logits = self.discriminator(imgs) 198 | loss_real = loss = criterion(logits, valid) 199 | loss = (loss_real + loss_fake) / 2 200 | return {"model_outputs": logits}, {"loss": loss} 201 | 202 | # train generator 203 | if optimizer_idx == 1: 204 | imgs_gen = self.generator(z) 205 | 206 | valid = torch.ones(imgs.size(0), 1) 207 | valid = valid.type_as(imgs) 208 | 209 | logits = self.discriminator(imgs_gen) 210 | loss_real = criterion(logits, valid) 211 | return {"model_outputs": logits}, {"loss": loss_real} 212 | 213 | @torch.no_grad() 214 | def eval_step(self, batch, criterion, optimizer_idx): 215 | return self.train_step(batch, criterion, optimizer_idx) 216 | 217 | def get_optimizer(self): 218 | discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) 219 | generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999)) 220 | return 
[discriminator_optimizer, generator_optimizer] 221 | 222 | def get_criterion(self): 223 | return nn.BCELoss() 224 | 225 | def get_data_loader( 226 | self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 227 | ): # pylint: disable=unused-argument 228 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 229 | dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) 230 | dataset.data = dataset.data[:64] 231 | dataset.targets = dataset.targets[:64] 232 | dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=False) 233 | return dataloader 234 | 235 | config = GANModelConfig() 236 | config.batch_size = 64 237 | config.grad_clip = None 238 | config.training_seed = 333 239 | 240 | model = GANModel() 241 | trainer = Trainer( 242 | TrainerArgs(use_accelerate=True), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None 243 | ) 244 | 245 | trainer.eval_epoch() 246 | loss_d1 = trainer.keep_avg_eval["avg_loss_0"] 247 | loss_g1 = trainer.keep_avg_eval["avg_loss_1"] 248 | 249 | trainer.config.epochs = 5 250 | trainer.fit() 251 | loss_d2 = trainer.keep_avg_train["avg_loss_0"] 252 | loss_g2 = trainer.keep_avg_train["avg_loss_1"] 253 | 254 | print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}") 255 | print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}") 256 | assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}" 257 | assert loss_g1 > loss_g2, f"Generator loss should decrease. 
{loss_g1} > {loss_g2}" 258 | 259 | 260 | def test_overfit_manual_optimize_mnist_simple_gan(): 261 | @dataclass 262 | class GANModelConfig(TrainerConfig): 263 | epochs: int = 1 264 | print_step: int = 2 265 | training_seed: int = 666 266 | 267 | class GANModel(TrainerModel): 268 | def __init__(self): 269 | super().__init__() 270 | data_shape = (1, 28, 28) 271 | self.generator = Generator(latent_dim=100, img_shape=data_shape) 272 | self.discriminator = Discriminator(img_shape=data_shape) 273 | 274 | def forward(self, x): 275 | ... 276 | 277 | def optimize(self, batch, trainer): 278 | imgs, _ = batch 279 | 280 | # sample noise 281 | z = torch.randn(imgs.shape[0], 100) 282 | z = z.type_as(imgs) 283 | 284 | # train discriminator 285 | imgs_gen = self.generator(z) 286 | logits = self.discriminator(imgs_gen.detach()) 287 | fake = torch.zeros(imgs.size(0), 1) 288 | fake = fake.type_as(imgs) 289 | loss_fake = trainer.criterion(logits, fake) 290 | 291 | valid = torch.ones(imgs.size(0), 1) 292 | valid = valid.type_as(imgs) 293 | logits = self.discriminator(imgs) 294 | loss_real = trainer.criterion(logits, valid) 295 | loss_disc = (loss_real + loss_fake) / 2 296 | 297 | # step dicriminator 298 | trainer.optimizer[0].zero_grad() 299 | self.scaled_backward(loss_disc, trainer, trainer.optimizer[0]) 300 | trainer.optimizer[0].step() 301 | 302 | # train generator 303 | imgs_gen = self.generator(z) 304 | 305 | valid = torch.ones(imgs.size(0), 1) 306 | valid = valid.type_as(imgs) 307 | 308 | logits = self.discriminator(imgs_gen) 309 | loss_gen = trainer.criterion(logits, valid) 310 | 311 | # step generator 312 | trainer.optimizer[1].zero_grad() 313 | self.scaled_backward(loss_gen, trainer, trainer.optimizer[1]) 314 | trainer.optimizer[1].step() 315 | return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc} 316 | 317 | @torch.no_grad() 318 | def eval_step(self, batch, trainer): 319 | imgs, _ = batch 320 | 321 | # sample noise 322 | z = 
torch.randn(imgs.shape[0], 100) 323 | z = z.type_as(imgs) 324 | 325 | imgs_gen = self.generator(z) 326 | valid = torch.ones(imgs.size(0), 1) 327 | valid = valid.type_as(imgs) 328 | 329 | logits = self.discriminator(imgs_gen) 330 | loss_gen = trainer.criterion(logits, valid) 331 | return {"model_outputs": logits}, {"loss_gen": loss_gen} 332 | 333 | def get_optimizer(self): 334 | discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) 335 | generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999)) 336 | return [discriminator_optimizer, generator_optimizer] 337 | 338 | def get_criterion(self): 339 | return nn.BCELoss() 340 | 341 | def get_data_loader( 342 | self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 343 | ): # pylint: disable=unused-argument 344 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 345 | dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) 346 | dataset.data = dataset.data[:64] 347 | dataset.targets = dataset.targets[:64] 348 | dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True) 349 | return dataloader 350 | 351 | config = GANModelConfig() 352 | config.batch_size = 64 353 | config.grad_clip = None 354 | 355 | model = GANModel() 356 | trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None) 357 | 358 | trainer.config.epochs = 1 359 | trainer.fit() 360 | loss_d1 = trainer.keep_avg_train["avg_loss_disc"] 361 | loss_g1 = trainer.keep_avg_train["avg_loss_gen"] 362 | 363 | trainer.config.epochs = 5 364 | trainer.fit() 365 | loss_d2 = trainer.keep_avg_train["avg_loss_disc"] 366 | loss_g2 = trainer.keep_avg_train["avg_loss_gen"] 367 | 368 | print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}") 369 | print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}") 370 | assert loss_d1 > loss_d2, 
f"Discriminator loss should decrease. {loss_d1} > {loss_d2}" 371 | assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}" 372 | 373 | 374 | def test_overfit_manual_optimize_grad_accum_mnist_simple_gan(): 375 | @dataclass 376 | class GANModelConfig(TrainerConfig): 377 | epochs: int = 1 378 | print_step: int = 2 379 | training_seed: int = 666 380 | 381 | class GANModel(TrainerModel): 382 | def __init__(self): 383 | super().__init__() 384 | data_shape = (1, 28, 28) 385 | self.generator = Generator(latent_dim=100, img_shape=data_shape) 386 | self.discriminator = Discriminator(img_shape=data_shape) 387 | 388 | def forward(self, x): 389 | ... 390 | 391 | def optimize(self, batch, trainer): 392 | imgs, _ = batch 393 | 394 | # sample noise 395 | z = torch.randn(imgs.shape[0], 100) 396 | z = z.type_as(imgs) 397 | 398 | # train discriminator 399 | imgs_gen = self.generator(z) 400 | logits = self.discriminator(imgs_gen.detach()) 401 | fake = torch.zeros(imgs.size(0), 1) 402 | fake = fake.type_as(imgs) 403 | loss_fake = trainer.criterion(logits, fake) 404 | 405 | valid = torch.ones(imgs.size(0), 1) 406 | valid = valid.type_as(imgs) 407 | logits = self.discriminator(imgs) 408 | loss_real = trainer.criterion(logits, valid) 409 | loss_disc = (loss_real + loss_fake) / 2 410 | 411 | # step dicriminator 412 | self.scaled_backward(loss_disc, trainer, trainer.optimizer[0]) 413 | 414 | if trainer.total_steps_done % trainer.grad_accum_steps == 0: 415 | trainer.optimizer[0].step() 416 | trainer.optimizer[0].zero_grad() 417 | 418 | # train generator 419 | imgs_gen = self.generator(z) 420 | 421 | valid = torch.ones(imgs.size(0), 1) 422 | valid = valid.type_as(imgs) 423 | 424 | logits = self.discriminator(imgs_gen) 425 | loss_gen = trainer.criterion(logits, valid) 426 | 427 | # step generator 428 | self.scaled_backward(loss_gen, trainer, trainer.optimizer[1]) 429 | if trainer.total_steps_done % trainer.grad_accum_steps == 0: 430 | trainer.optimizer[1].step() 431 
| trainer.optimizer[1].zero_grad() 432 | return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc} 433 | 434 | @torch.no_grad() 435 | def eval_step(self, batch, criterion): 436 | imgs, _ = batch 437 | 438 | # sample noise 439 | z = torch.randn(imgs.shape[0], 100) 440 | z = z.type_as(imgs) 441 | 442 | imgs_gen = self.generator(z) 443 | valid = torch.ones(imgs.size(0), 1) 444 | valid = valid.type_as(imgs) 445 | 446 | logits = self.discriminator(imgs_gen) 447 | loss_gen = trainer.criterion(logits, valid) 448 | return {"model_outputs": logits}, {"loss_gen": loss_gen} 449 | 450 | def get_optimizer(self): 451 | discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) 452 | generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999)) 453 | return [discriminator_optimizer, generator_optimizer] 454 | 455 | def get_criterion(self): 456 | return nn.BCELoss() 457 | 458 | def get_data_loader( 459 | self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 460 | ): # pylint: disable=unused-argument 461 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 462 | dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) 463 | dataset.data = dataset.data[:64] 464 | dataset.targets = dataset.targets[:64] 465 | dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True) 466 | return dataloader 467 | 468 | config = GANModelConfig() 469 | config.batch_size = 64 470 | config.grad_clip = None 471 | 472 | model = GANModel() 473 | trainer = Trainer(TrainerArgs(), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None) 474 | 475 | trainer.config.epochs = 1 476 | trainer.fit() 477 | loss_d1 = trainer.keep_avg_train["avg_loss_disc"] 478 | loss_g1 = trainer.keep_avg_train["avg_loss_gen"] 479 | 480 | trainer.config.epochs = 5 481 | trainer.fit() 482 | 
loss_d2 = trainer.keep_avg_train["avg_loss_disc"] 483 | loss_g2 = trainer.keep_avg_train["avg_loss_gen"] 484 | 485 | print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}") 486 | print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}") 487 | assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}" 488 | assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}" 489 | 490 | 491 | def test_overfit_manual_accelerate_optimize_grad_accum_mnist_simple_gan(): 492 | @dataclass 493 | class GANModelConfig(TrainerConfig): 494 | epochs: int = 1 495 | print_step: int = 2 496 | training_seed: int = 666 497 | 498 | class GANModel(TrainerModel): 499 | def __init__(self): 500 | super().__init__() 501 | data_shape = (1, 28, 28) 502 | self.generator = Generator(latent_dim=100, img_shape=data_shape) 503 | self.discriminator = Discriminator(img_shape=data_shape) 504 | 505 | def train_step(): 506 | ... 507 | 508 | def forward(self, x): 509 | ... 510 | 511 | def optimize(self, batch, trainer): 512 | imgs, _ = batch 513 | 514 | # sample noise 515 | z = torch.randn(imgs.shape[0], 100) 516 | z = z.type_as(imgs) 517 | 518 | # train discriminator 519 | imgs_gen = self.generator(z) 520 | logits = self.discriminator(imgs_gen.detach()) 521 | fake = torch.zeros(imgs.size(0), 1) 522 | fake = fake.type_as(imgs) 523 | loss_fake = trainer.criterion(logits, fake) 524 | 525 | valid = torch.ones(imgs.size(0), 1) 526 | valid = valid.type_as(imgs) 527 | logits = self.discriminator(imgs) 528 | loss_real = trainer.criterion(logits, valid) 529 | loss_disc = (loss_real + loss_fake) / 2 530 | 531 | # step dicriminator 532 | self.scaled_backward(loss_disc, trainer, trainer.optimizer[0]) 533 | 534 | if trainer.total_steps_done % trainer.grad_accum_steps == 0: 535 | trainer.optimizer[0].step() 536 | trainer.optimizer[0].zero_grad() 537 | 538 | # train generator 539 | imgs_gen = self.generator(z) 540 | 541 | valid = torch.ones(imgs.size(0), 1) 542 | valid = 
valid.type_as(imgs) 543 | 544 | logits = self.discriminator(imgs_gen) 545 | loss_gen = trainer.criterion(logits, valid) 546 | 547 | # step generator 548 | self.scaled_backward(loss_gen, trainer, trainer.optimizer[1]) 549 | if trainer.total_steps_done % trainer.grad_accum_steps == 0: 550 | trainer.optimizer[1].step() 551 | trainer.optimizer[1].zero_grad() 552 | return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc} 553 | 554 | @torch.no_grad() 555 | def eval_step(self, batch, criterion): 556 | imgs, _ = batch 557 | 558 | # sample noise 559 | z = torch.randn(imgs.shape[0], 100) 560 | z = z.type_as(imgs) 561 | 562 | imgs_gen = self.generator(z) 563 | valid = torch.ones(imgs.size(0), 1) 564 | valid = valid.type_as(imgs) 565 | 566 | logits = self.discriminator(imgs_gen) 567 | loss_gen = trainer.criterion(logits, valid) 568 | return {"model_outputs": logits}, {"loss_gen": loss_gen} 569 | 570 | def get_optimizer(self): 571 | discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999)) 572 | generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.001, betas=(0.5, 0.999)) 573 | return [discriminator_optimizer, generator_optimizer] 574 | 575 | def get_criterion(self): 576 | return nn.BCELoss() 577 | 578 | def get_data_loader( 579 | self, config, assets, is_eval, samples, verbose, num_gpus, rank=0 580 | ): # pylint: disable=unused-argument 581 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 582 | dataset = MNIST(os.getcwd(), train=not is_eval, download=True, transform=transform) 583 | dataset.data = dataset.data[:64] 584 | dataset.targets = dataset.targets[:64] 585 | dataloader = DataLoader(dataset, batch_size=config.batch_size, drop_last=True, shuffle=True) 586 | return dataloader 587 | 588 | config = GANModelConfig() 589 | config.batch_size = 64 590 | config.grad_clip = None 591 | 592 | model = GANModel() 593 | trainer = Trainer( 594 
| TrainerArgs(use_accelerate=True), config, model=model, output_path=os.getcwd(), gpu=0 if is_cuda else None 595 | ) 596 | 597 | trainer.config.epochs = 1 598 | trainer.fit() 599 | loss_d1 = trainer.keep_avg_train["avg_loss_disc"] 600 | loss_g1 = trainer.keep_avg_train["avg_loss_gen"] 601 | 602 | trainer.config.epochs = 5 603 | trainer.fit() 604 | loss_d2 = trainer.keep_avg_train["avg_loss_disc"] 605 | loss_g2 = trainer.keep_avg_train["avg_loss_gen"] 606 | 607 | print(f"loss_d1: {loss_d1}, loss_d2: {loss_d2}") 608 | print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}") 609 | assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}" 610 | assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}" 611 | 612 | 613 | if __name__ == "__main__": 614 | test_overfit_mnist_simple_gan() 615 | test_overfit_accelerate_mnist_simple_gan() 616 | test_overfit_manual_optimize_mnist_simple_gan() 617 | test_overfit_manual_optimize_grad_accum_mnist_simple_gan() 618 | test_overfit_manual_accelerate_optimize_grad_accum_mnist_simple_gan() 619 | --------------------------------------------------------------------------------