├── torchtracer ├── utils │ ├── __init__.py │ ├── progress.py │ └── storeman.py ├── data │ ├── __init__.py │ ├── model.py │ └── config.py ├── __init__.py └── tracer.py ├── requirements.txt ├── .github └── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md ├── setup.py ├── .travis.yml ├── LICENSE ├── .gitignore ├── demo.py └── README.md /torchtracer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from torchtracer.utils.storeman import StoreMan 2 | -------------------------------------------------------------------------------- /torchtracer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from torchtracer.data.config import Config 2 | from torchtracer.data.model import Model 3 | -------------------------------------------------------------------------------- /torchtracer/__init__.py: -------------------------------------------------------------------------------- 1 | __name__ = 'torchtracer' 2 | __version__ = '0.2.0' 3 | __author__ = 'OIdiotLin' 4 | __what__ = 'lmmnb!!' 5 | 6 | from torchtracer.tracer import Tracer 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler>=0.10.0 2 | kiwisolver>=1.0.1 3 | matplotlib>=3.0.2 4 | numpy>=1.15.4 5 | Pillow>=5.3.0 6 | tqdm==4.26.0 7 | pyparsing>=2.3.0 8 | python-dateutil>=2.7.5 9 | six>=1.11.0 10 | torch>=0.4.1 11 | torchvision>=0.2.1 12 | -------------------------------------------------------------------------------- /torchtracer/data/model.py: -------------------------------------------------------------------------------- 1 | class Model(object): 2 | 3 | def __init__(self, model) -> None: 4 | super().__init__() 5 | # print(model) 6 | self.state_dict = model.state_dict() 7 | self.name = model.__class__.__name__ 8 | self.architecture = str(model) 9 | 10 | def __str__(self): 11 | return '{0.name}\n{0.architecture}'.format(self) 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | 5 | --- 6 | 7 | Chinese is also welcomed. 8 | 中文亦可。 9 | 10 | --- 11 | 12 | **Is your feature request related to a problem? Please describe.** 13 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 14 | 15 | **Describe the solution you'd like** 16 | A clear and concise description of what you want to happen. 17 | 18 | **Describe alternatives you've considered** 19 | A clear and concise description of any alternative solutions or features you've considered. 20 | 21 | **Additional context** 22 | Add any other context or screenshots about the feature request here. 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | 5 | --- 6 | 7 | Chinese is also welcomed. 8 | 中文亦可。 9 | 10 | --- 11 | 12 | **Describe the bug** 13 | A clear and concise description of what the bug is. 14 | 15 | **To Reproduce** 16 | Steps to reproduce the behavior: 17 | 1. Go to '...' 18 | 2. Click on '....' 19 | 3. Scroll down to '....' 20 | 4. 
See error 21 | 22 | **Expected behavior** 23 | A clear and concise description of what you expected to happen. 24 | 25 | **Screenshots** 26 | If applicable, add screenshots to help explain your problem. 27 | 28 | **Environment:** 29 | - OS: [e.g. ArchLinux x86_64] 30 | - Python Version: [e.g. 3.7.0] 31 | - PyTorch Version: [e.g. 3.7.] 32 | - etc. (you can add something else) 33 | 34 | **Additional context** 35 | Add any other context about the problem here. 36 | -------------------------------------------------------------------------------- /torchtracer/utils/progress.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | 4 | class ProgressBar(tqdm): 5 | def __init__(self, iterable=None, desc=None, total=None, leave=True, file=None, ncols=None, mininterval=0.1, 6 | maxinterval=10.0, miniters=None, ascii=None, disable=False, unit='it', unit_scale=False, 7 | dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0, position=None, postfix=None, 8 | unit_divisor=1000, gui=False, **kwargs): 9 | super().__init__(iterable, desc, total, leave, file, ncols, mininterval, maxinterval, miniters, ascii, disable, 10 | unit, unit_scale, dynamic_ncols, smoothing, bar_format, initial, position, postfix, 11 | unit_divisor, gui, **kwargs) 12 | 13 | def update(self, n=1, **params): 14 | self.set_postfix(**params) 15 | super().update(n) 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import torchtracer 3 | 4 | with open('README.md', 'r') as f: 5 | long_description = f.read() 6 | 7 | setuptools.setup( 8 | name=torchtracer.__name__, 9 | version=torchtracer.__version__, 10 | author='OIdiotLin', 11 | author_email='oidiotlin@gmail.com', 12 | maintainer='OIdiotLin', 13 | maintainer_email='oidiotlin@gmail.com', 14 | description='A python package for visualization and storage management in a pytorch AI task.', 15 | license='MIT License', 16 | long_description=long_description, 17 | long_description_content_type='text/markdown', 18 | url='https://github.com/OIdiotLin/torchtracer', 19 | packages=setuptools.find_packages(), 20 | classifiers=[ 21 | 'Programming Language :: Python :: 3', 22 | 'License :: OSI Approved :: MIT License', 23 | 'Operating System :: OS Independent', 24 | ], 25 | install_requires=[ 26 | 'matplotlib', 27 | 'numpy', 28 | ] 29 | ) 30 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | 5 | sudo: false 6 | dist: trusty 7 | cache: pip 8 | 9 | install: 10 | - pip install -r ./requirements.txt -q 11 | 12 | script: 13 | - mkdir ./checkpoints 14 | - python demo.py 15 | 16 | deploy: 17 | provider: pypi 18 | user: OIdiotLin 19 | password: 20 | secure: 
WtDDpKNpwFQo3uhz9HDO7GpsySEXeeFhZyCwHnGtdSaSM+Ev/YDuPJTuFGGbbLsDZJGF0lFfl53oLVSg8aytAeT8Yu9Uw67TwoGp+IFjTEMEc4rufQMsVoO9BY+C0fy8H9msEzh2ntQsp2HuMjF69r0IMxUcSZ6f+hoCoHcfTOhlD2PiKijF6T7Qh/u7NMdMHN7qlSjUNSdQOFmIf6qQAg6Oq+MoaL9lxJZzdmh5WPd7WAL8ejVVxZWpCspgKWkkrExegvE9tEs91oqhu/DzduS1Gdf7MpF5vSYe26FJbMcfSyXbjDKxxuJAuIGd3JOWZJGI2DzB52YsXjePFJu9FV7kWwZvJTcExBjhTa3jz3M0wyeOUraUi7ncvKKtXSksLs9JLk2K3baHvz2JzJR1Don7euXhKkg7FJQajITDNNoZquknfuYvWctYBNjwiES1QwCrncC7HD8WALrnMNQhz2KVxJ5JD4VAsX58tM6ArPLBsh/qJFCD9DJ/TIMKI3mYmIlO03LQqAHqfte9lc+rSel+4io1Mz5jx1Fk4ZvvwlC9nkN40ZWz5yj/u0cajtCA/hop4tL77XbipLFmCFNN+SERRDQv65hwy69OtMk7uMmnnV1OYQMQOb1t/k0krKqGYPzjbdXqQ6wzqnmstRzLypRrU1paQHm1FvK9TL//i44= 21 | on: 22 | tags: true 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 OIdiotLin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /torchtracer/tracer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | from tqdm import tqdm 5 | 6 | from torchtracer.utils import StoreMan 7 | from torchtracer.utils.progress import ProgressBar 8 | 9 | 10 | class Tracer: 11 | def __init__(self, root): 12 | self.root = root 13 | self.storage = None 14 | self.epoch_bar = None 15 | print('Tracer start at {}'.format(os.path.abspath(self.root))) 16 | 17 | def attach(self, task_id=None): 18 | if task_id is None: 19 | task_id = datetime.now().isoformat(sep='T', timespec='minutes') 20 | self.storage = StoreMan(self.root, task_id) 21 | print('Tracer attached with task: {}'.format(task_id)) 22 | return self 23 | 24 | def detach(self): 25 | self.storage.close() 26 | 27 | def store(self, item, file=None): 28 | if self.storage is None: 29 | raise Exception('You should attach with task id first.') 30 | self.storage.store(item, file) 31 | 32 | def log(self, msg, file=None): 33 | if self.storage is None: 34 | raise Exception('You should attach with task id first.') 35 | self.storage.log(msg, file) 36 | 37 | def epoch_bar_init(self, epoch_num): 38 | self.epoch_bar = ProgressBar(total=epoch_num, desc='Epoch') 39 | 40 | @staticmethod 41 | def print(msg): 42 | tqdm.write(msg) 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Jetbrains 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | # test env 110 | checkpoints/ 111 | -------------------------------------------------------------------------------- /torchtracer/data/config.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import json 3 | 4 | import torch 5 | 6 | 7 | class Config(object): 8 | def __init__(self, cfg) -> None: 9 | super().__init__() 10 | if isinstance(cfg, configparser.ConfigParser): 11 | self.content = Config.from_cfg(cfg) 12 | elif isinstance(cfg, str): 13 | self.content = Config.from_ini(cfg) 14 | elif isinstance(cfg, dict): 15 | self.content = json.dumps(Config.from_dict(cfg), indent=2) 16 | 17 | @staticmethod 18 | def from_ini(ini): 19 | config = configparser.ConfigParser() 20 | config.read_string(ini) 21 | return Config.from_cfg(config) 22 | 23 | @staticmethod 24 | def from_cfg(cfg): 25 | dic = {} 26 | sections = cfg.sections() 27 | for section in sections: 28 | dic_section = {} 29 | options = cfg.options(section) 30 | for option in options: 31 | dic_section[option] = cfg.get(section, option) 32 | dic[section] = dic_section 33 | return dic 34 | 35 | @staticmethod 36 | def from_dict(dic): 37 | res = {} 38 | # only loss function name reserved. 39 | if isinstance(dic, torch.nn.modules.loss._Loss): 40 | return dic._get_name() 41 | # 42 | if isinstance(dic, torch.optim.Optimizer): 43 | sub = dic.param_groups[0].copy() 44 | sub.pop('params') 45 | sub['name'] = dic.__class__.__name__ 46 | return Config.from_dict(sub) 47 | for k in dic.keys(): 48 | if type(dic[k]) in [int, float, bool, str, list]: 49 | res[k] = dic[k] 50 | elif isinstance(dic[k], (torch.optim.Optimizer, 51 | torch.nn.modules.loss._Loss)): 52 | res[k] = Config.from_dict(dic[k]) 53 | return res 54 | 55 | def __str__(self): 56 | return self.content 57 | -------------------------------------------------------------------------------- /torchtracer/utils/storeman.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from matplotlib.figure import Figure 5 | 6 | from torchtracer.data import Config, Model 7 | 8 | 9 | class StoreMan(object): 10 | """ 11 | Storage module. 
    """
    CONFIG_FILENAME = 'config.json'
    IMG_DIR = 'images'
    MODEL_DESCRIPTION_FILENAME = 'model.txt'
    MODEL_PARAMETERS_FILENAME = 'model.pth'
    LOG_FILENAME = 'log'

    def __init__(self, root, task_id) -> None:
        super().__init__()
        self.active_log = None
        self.root = root
        self.regist(task_id)

    def regist(self, task_id) -> None:
        dir_path = os.path.abspath(os.path.join(self.root, task_id))
        if os.path.isdir(dir_path):
            raise FileExistsError('{} exists, you should rename the task id.'.format(task_id))
        self.root = self.mkdir(dir_path)

    def close(self):
        self.active_log.close()

    @staticmethod
    def mkdir(path) -> str:
        # create task directory
        if not os.path.isdir(path):
            os.mkdir(path)
        # create image directory
        img_dir = os.path.join(path, StoreMan.IMG_DIR)
        if not os.path.isdir(img_dir):
            os.mkdir(img_dir)
        return path

    def store(self, item, file):
        if isinstance(item, Config):
            self.store_config(item)
        elif isinstance(item, Model):
            self.store_model(item, file)
        elif isinstance(item, Figure):
            self.store_image(item, file)

    def store_config(self, cfg):
        path = os.path.join(self.root, self.CONFIG_FILENAME)
        with open(path, 'w') as f:
            f.write(str(cfg))

    def store_model(self, model, file=None):
        description = str(model)
        parameters = model.state_dict

        description_file = os.path.join(self.root,
                                        '{}.txt'.format(file) if file else self.MODEL_DESCRIPTION_FILENAME)
        parameters_file = os.path.join(self.root,
                                       '{}.pth'.format(file) if file else self.MODEL_PARAMETERS_FILENAME)

        with open(description_file, 'w', encoding='utf-8') as f:
            f.write(description)
        torch.save(parameters, parameters_file)

    def store_image(self, image, file):
        img_file = os.path.join(self.root, self.IMG_DIR, file)
        image.savefig(img_file)

    def log(self, msg, file=None):
        logfile = os.path.join(self.root,
                               '{}.{}'.format(file, self.LOG_FILENAME) if file else self.LOG_FILENAME)
        if self.active_log:
            # reopen only when a different log file is requested
            if self.active_log.name != logfile:
                self.active_log.close()
                self.active_log = open(logfile, 'a', encoding='utf-8')
        else:
            self.active_log = open(logfile, 'a', encoding='utf-8')
        self.active_log.write(msg + '\n')
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import matplotlib

matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

from torchtracer import Tracer
from torchtracer.data import Config, Model


def f(x):
    return np.sqrt(x)


def make_batch(batch_size, valid=False):
    if not valid:
        x = np.random.uniform(0, 1000, batch_size)
    else:
        x = np.random.uniform(1000, 2000, batch_size)
    x = x.reshape((batch_size, 1))
    x = torch.Tensor(x)
    y = f(x)
    return x, y


def evaluate(model, **kwargs):
    batch_size = kwargs['batch_size']
    criterion = kwargs['criterion']

    loss_sum = 0
    for step in range(30):
        x, y = make_batch(batch_size=batch_size)
        y_ = model(x)
        loss = criterion(y, y_)
        loss_sum += loss.item()

    return loss_sum / (30 * batch_size)


def train(model, tracer=None, **kwargs):
    cfg = Config(kwargs)
    tracer.store(cfg)

    epoch_n = kwargs['epoch_n']
    batch_size = kwargs['batch_size']
    criterion = kwargs['criterion']
    optimizer = kwargs['optimizer']

    train_losses = []
    valid_losses = []

    tracer.epoch_bar_init(epoch_n)

    for epoch in range(epoch_n):
        loss_sum = 0
        for step in range(30):
            x, y = make_batch(batch_size)
            y_ = model(x)
            loss = criterion(y, y_)
            loss_sum += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss = loss_sum / (30 * batch_size)
        valid_loss = evaluate(model, **kwargs)
        tracer.epoch_bar.update(n=1, train_loss=train_loss, valid_loss=valid_loss)

        tracer.log('Epoch #{:03d}\ttrain_loss: {:.4f}\tvalid_loss: {:.4f}'.format(epoch, train_loss, valid_loss))

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
    tracer.epoch_bar.close()
    tracer.store(Model(model))

    plt.plot(train_losses, label='train loss', c='b')
    plt.plot(valid_losses, label='valid loss', c='r')
    plt.title('Demo Learning on SQRT')
    plt.legend()
    tracer.store(plt.gcf(), 'losses.png')


if __name__ == '__main__':
    net = nn.Sequential(nn.Linear(1, 6, True),
                        nn.ReLU(),
                        nn.Linear(6, 12, True),
                        nn.ReLU(),
                        nn.Linear(12, 12, True),
                        nn.ReLU(),
                        nn.Linear(12, 1, True))
    args = {'epoch_n': 120,
            'batch_size': 50,
            'criterion': nn.MSELoss(),
            'optimizer': torch.optim.RMSprop(net.parameters(), lr=1e-3),
            'dataloader': DataLoader(dataset=torchvision.datasets.FakeData())}
    tracer = Tracer('checkpoints').attach('rabbit')
    train(net, tracer, **args)
    tracer.detach()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# torchtracer

[![Build Status](https://travis-ci.com/OIdiotLin/torchtracer.svg?branch=master)](https://travis-ci.com/OIdiotLin/torchtracer)
![](https://img.shields.io/badge/python-3.6-blue.svg)
![](https://img.shields.io/badge/pytorch-0.4.1-orange.svg)

`torchtracer` is a tool package for visualization and storage management in PyTorch AI tasks.

## Getting Started

### PyTorch Required

This tool is developed for PyTorch AI tasks, so PyTorch is required.

### Installing

You can use `pip` to install `torchtracer`.

```bash
pip install torchtracer
```

## How to use?

### Import `torchtracer`

```python
from torchtracer import Tracer
```

### Create an instance of `Tracer`

Assume that the root is `./checkpoints` and the current task id is `lmmnb`.

***To avoid cluttering your working directory, you should create the root directory manually.***

```python
tracer = Tracer('checkpoints').attach('lmmnb')
```

This step will create a directory `lmmnb` inside `./checkpoints/` for the current AI task.

Also, you could call `.attach()` without a task id. **The current datetime will be used as the task id.**

```python
tracer = Tracer('checkpoints').attach()
```
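If the directory for a task id already exists, `torchtracer` refuses to reuse it and raises a `FileExistsError`, so pick a fresh task id for every run. Below is a minimal sketch of one way to handle this, reusing the names from the section above; the datetime fallback is just one option:

```python
try:
    tracer = Tracer('checkpoints').attach('lmmnb')
except FileExistsError:
    # ./checkpoints/lmmnb is left over from an earlier run;
    # fall back to a datetime-based task id instead
    tracer = Tracer('checkpoints').attach()
```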
### Saving config

Raw config should be a `dict` like this:

```python
# `net` is a defined nn.Module
args = {'epoch_n': 120,
        'batch_size': 10,
        'criterion': nn.MSELoss(),
        'optimizer': torch.optim.RMSprop(net.parameters(), lr=1e-3)}
```

The config dict should be wrapped with `torchtracer.data.Config`:

```python
cfg = Config(args)
tracer.store(cfg)
```

This step will create `config.json` in `./checkpoints/lmmnb/`, which contains JSON information like this:

```json
{
  "epoch_n": 120,
  "batch_size": 10,
  "criterion": "MSELoss",
  "optimizer": {
    "lr": 0.001,
    "momentum": 0,
    "alpha": 0.99,
    "eps": 1e-08,
    "centered": false,
    "weight_decay": 0,
    "name": "RMSprop"
  }
}
```

### Logging

During the training iterations, you can log any information you want by using `Tracer.log(msg, file)`.

If `file` is not specified, `msg` will be written to `./checkpoints/lmmnb/log`. Otherwise, it will go to `./checkpoints/lmmnb/something.log`.

```python
tracer.log(msg='Epoch #{:03d}\ttrain_loss: {:.4f}\tvalid_loss: {:.4f}'.format(epoch, train_loss, valid_loss),
           file='losses')
```

This step will create a log file `losses.log` in `./checkpoints/lmmnb/`, which contains logs like:

```text
Epoch #001    train_loss: 18.6356    valid_loss: 21.3882
Epoch #002    train_loss: 19.1731    valid_loss: 17.8482
Epoch #003    train_loss: 19.6756    valid_loss: 19.1418
Epoch #004    train_loss: 20.0638    valid_loss: 18.3875
Epoch #005    train_loss: 18.4679    valid_loss: 19.6304
...
```

### Saving model

The model object should be wrapped with `torchtracer.data.Model`.

If `file` is not specified, the model files will be named `model.txt` and `model.pth`. Otherwise, they will be `somename.txt` and `somename.pth`.

```python
tracer.store(Model(model), file='somename')
```

This step will create 2 files:

- **description**: `somename.txt`

```text
Sequential
Sequential(
  (0): Linear(in_features=1, out_features=6, bias=True)
  (1): ReLU()
  (2): Linear(in_features=6, out_features=12, bias=True)
  (3): ReLU()
  (4): Linear(in_features=12, out_features=12, bias=True)
  (5): ReLU()
  (6): Linear(in_features=12, out_features=1, bias=True)
)
```

- **parameters**: `somename.pth`

### Saving matplotlib images

Use `tracer.store(figure, file)` to save a matplotlib figure in `images/`.

```python
# assume that `train_losses` and `valid_losses` are lists of losses.
# create figure manually.
plt.plot(train_losses, label='train loss', c='b')
plt.plot(valid_losses, label='valid loss', c='r')
plt.title('Demo Learning on SQRT')
plt.legend()
# save figure. remember to call `plt.gcf()`
tracer.store(plt.gcf(), 'losses.png')
```

This step will save a PNG file `losses.png` showing the loss curves.

### Progress bar for epochs

Use `tracer.epoch_bar_init(total)` to initialize a progress bar.

```python
tracer.epoch_bar_init(epoch_n)
```

Use `tracer.epoch_bar.update(n=1, **params)` to update the postfix of the progress bar.

```python
tracer.epoch_bar.update(train_loss=train_loss, valid_loss=valid_loss)
```

```plain
(THIS IS A DEMO)
Tracer start at /home/oidiotlin/projects/torchtracer/checkpoints
Tracer attached with task: rabbit
Epoch: 100%|█████████| 120/120 [00:02<00:00, 41.75it/s, train_loss=0.417, valid_loss=0.417]
```

**DO NOT FORGET TO CALL** `tracer.epoch_bar.close()` to finish the bar.
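### Detaching

The bundled `demo.py` ends its task by calling `tracer.detach()`, which closes the log file held by the tracer's storage backend. A minimal sketch of the full task lifecycle, reusing the names from the demo (`net`, `args`, `train`):

```python
tracer = Tracer('checkpoints').attach('rabbit')  # creates ./checkpoints/rabbit/
train(net, tracer, **args)                       # config, logs, figures and model land in that directory
tracer.detach()                                  # close the open log file once the task is done
```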
## Contribute

If you like this project, you are welcome to open pull requests and create issues.

--------------------------------------------------------------------------------