├── CNAME ├── unittests ├── __init__.py ├── test_featureVisualization.py ├── test_loger.py ├── test_optimizer.py ├── test_instences.py ├── test_supParallelTrainer.py └── test_model.py ├── _config.yml ├── docs ├── _config.yml ├── requirements.txt ├── source │ ├── model.rst │ ├── assessment.rst │ ├── optimizer.rst │ ├── parallel.rst │ ├── dataset.rst │ ├── trainer.rst │ ├── index.rst │ ├── quick start.rst │ ├── conf.py │ └── Build your own trainer.rst ├── Makefile └── make.bat ├── jdit ├── .pytest_cache │ └── v │ │ └── cache │ │ ├── lastfailed │ │ └── nodeids ├── assessment │ ├── __init__.py │ ├── inception.py │ └── fid.py ├── parallel │ └── __init__.py ├── __init__.py ├── trainer │ ├── gan │ │ ├── __init__.py │ │ ├── generate.py │ │ ├── pix2pix.py │ │ └── sup_gan.py │ ├── single │ │ ├── __init__.py │ │ ├── classification.py │ │ ├── autoencoder.py │ │ └── sup_single.py │ ├── __init__.py │ ├── instances │ │ ├── __init__.py │ │ ├── fashionAutoencoder.py │ │ ├── fashionClassification.py │ │ ├── fashionClassParallelTrainer.py │ │ ├── fashionGenerateGan.py │ │ └── cifarPix2pixGan.py │ └── abandon_classification.py ├── perspective.py ├── optimizer.py └── dataset.py ├── MANIFEST.in ├── resources ├── logo.png ├── class_log.jpg ├── class_net.png ├── class_opt.png ├── class_train.png ├── class_valid.png ├── tb_graphs.png ├── tb_scalars.png ├── class_dataset.png └── tb_projector.png ├── readthedocs.yml ├── requirements.txt ├── .gitignore ├── .travis.yml ├── .github └── workflows │ └── python-package.yml ├── setup.py ├── README.md └── LICENSE /CNAME: -------------------------------------------------------------------------------- 1 | dingguanglei.com -------------------------------------------------------------------------------- /unittests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /jdit/.pytest_cache/v/cache/lastfailed: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /jdit/assessment/__init__.py: -------------------------------------------------------------------------------- 1 | from .fid import FID_score -------------------------------------------------------------------------------- /jdit/.pytest_cache/v/cache/nodeids: -------------------------------------------------------------------------------- 1 | [ 2 | "model.py::test" 3 | ] -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include docs/source/index.rst 2 | include requirements.txt 3 | include LICENSE -------------------------------------------------------------------------------- /resources/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingguanglei/jdit/HEAD/resources/logo.png -------------------------------------------------------------------------------- 
/resources/class_log.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingguanglei/jdit/HEAD/resources/class_log.jpg -------------------------------------------------------------------------------- /resources/class_net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingguanglei/jdit/HEAD/resources/class_net.png -------------------------------------------------------------------------------- /resources/class_opt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingguanglei/jdit/HEAD/resources/class_opt.png -------------------------------------------------------------------------------- /resources/class_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingguanglei/jdit/HEAD/resources/class_train.png -------------------------------------------------------------------------------- /resources/class_valid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingguanglei/jdit/HEAD/resources/class_valid.png -------------------------------------------------------------------------------- /resources/tb_graphs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingguanglei/jdit/HEAD/resources/tb_graphs.png -------------------------------------------------------------------------------- /resources/tb_scalars.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingguanglei/jdit/HEAD/resources/tb_scalars.png -------------------------------------------------------------------------------- /resources/class_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingguanglei/jdit/HEAD/resources/class_dataset.png -------------------------------------------------------------------------------- /resources/tb_projector.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingguanglei/jdit/HEAD/resources/tb_projector.png -------------------------------------------------------------------------------- /jdit/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .parallel_trainer import SupParallelTrainer 2 | 3 | __all__ = ["SupParallelTrainer"] -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | 3 | build: 4 | image: latest 5 | 6 | python: 7 | version: 3.6 8 | setup_py_install: true -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | pip 2 | mock 3 | future 4 | numpy 5 | tqdm 6 | pyyaml 7 | setuptools 8 | six 9 | typing 10 | psutil 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /jdit/__init__.py: -------------------------------------------------------------------------------- 1 | from jdit.model import Model 2 | from jdit.optimizer import Optimizer 3 | 4 | import jdit.dataset 5 | import 
jdit.trainer 6 | import jdit.assessment 7 | import jdit.parallel 8 | 9 | 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | tensorflow 4 | tensorboard 5 | tqdm 6 | imageio 7 | Pillow 8 | psutil 9 | numpy 10 | scipy 11 | setuptools 12 | mock 13 | nvidia_ml_py3 14 | sphinx_rtd_theme 15 | 16 | -------------------------------------------------------------------------------- /jdit/trainer/gan/__init__.py: -------------------------------------------------------------------------------- 1 | from .sup_gan import SupGanTrainer 2 | from .pix2pix import Pix2pixGanTrainer 3 | from .generate import GenerateGanTrainer 4 | 5 | __all__ = ['SupGanTrainer', 'Pix2pixGanTrainer','GenerateGanTrainer'] -------------------------------------------------------------------------------- /docs/source/model.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | jdit.model 5 | ========== 6 | 7 | .. automodule:: jdit 8 | .. currentmodule:: jdit 9 | 10 | Model 11 | ----- 12 | 13 | .. autoclass:: jdit.Model 14 | :members: 15 | 16 | -------------------------------------------------------------------------------- /jdit/trainer/single/__init__.py: -------------------------------------------------------------------------------- 1 | from .sup_single import SupSingleModelTrainer 2 | from .classification import ClassificationTrainer 3 | from .autoencoder import AutoEncoderTrainer 4 | 5 | __all__ = ['SupSingleModelTrainer', 'ClassificationTrainer', "AutoEncoderTrainer"] 6 | -------------------------------------------------------------------------------- /docs/source/assessment.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | jdit.assessment 5 | =============== 6 | 7 | .. automodule:: jdit.assessment 8 | .. currentmodule:: jdit.assessment 9 | 10 | FID 11 | --- 12 | .. autofunction:: FID_score 13 | 14 | 15 | -------------------------------------------------------------------------------- /docs/source/optimizer.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | jdit.optimizer 5 | ============== 6 | 7 | .. automodule:: jdit 8 | .. currentmodule:: jdit 9 | 10 | Optimizer 11 | --------- 12 | 13 | .. autoclass:: jdit.Optimizer 14 | :members: 15 | 16 | -------------------------------------------------------------------------------- /docs/source/parallel.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | jdit.parallel 5 | ============= 6 | 7 | .. automodule:: jdit.parallel 8 | .. currentmodule:: jdit.parallel 9 | 10 | SupParallelTrainer 11 | -------------------- 12 | 13 | .. 
autoclass:: SupParallelTrainer 14 | :members: 15 | -------------------------------------------------------------------------------- /unittests/test_featureVisualization.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | 4 | class TestFeatureVisualization(TestCase): 5 | def test__hook(self): 6 | pass 7 | 8 | def test__register_forward_hook(self): 9 | pass 10 | 11 | def test_trace_activation(self): 12 | pass 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | #jdit/metric/Cifar10_M.csv 2 | #jdit/metric/Cifar10_S.csv 3 | mypackage 4 | 5 | log 6 | jdit/data 7 | .idea 8 | datasets 9 | 10 | # Windows: 11 | Thumbs.db 12 | ehthumbs.db 13 | Desktop.ini 14 | 15 | # Python: 16 | *.py[cod] 17 | *.so 18 | *.egg 19 | *.egg-info 20 | *.pypirc 21 | dist 22 | build 23 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | - "3.7-dev" # 3.7 development branch 5 | # command to install dependencies 6 | install: 7 | - pip install -r docs/requirements.txt 8 | # command to run tests 9 | script: 10 | - ls 11 | 12 | #cache: bundler 13 | cache: 14 | directories: 15 | - $HOME/.m2 -------------------------------------------------------------------------------- /jdit/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .super import SupTrainer 2 | from .gan import * 3 | from .single import * 4 | from jdit.trainer import instances 5 | 6 | # from .instances import * 7 | 8 | # from jdit.trainer import single 9 | # 10 | # __all__ = ['SupTrainer', 11 | # 'ClassificationTrainer', 12 | # 'instances', 13 | # 'single', 14 | # 'SupGanTrainer', 15 | # 'Pix2pixGanTrainer', 16 | # 'GenerateGanTrainer', 17 | # 'SupSingleModelTrainer', 18 | # 'ClassificationTrainer', 19 | # 'AutoEncoderTrainer'] 20 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = jdit 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
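# For example, `make html` simply expands this catch-all rule to
#   sphinx-build -M html source build
# and any other Sphinx builder name (latexpdf, linkcheck, clean, ...) is dispatched the same way.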
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /jdit/trainer/instances/__init__.py: -------------------------------------------------------------------------------- 1 | from .fashionClassification import FashionClassTrainer, start_fashionClassTrainer 2 | from .fashionGenerateGan import FashionGenerateGanTrainer, start_fashionGenerateGanTrainer 3 | from .cifarPix2pixGan import start_cifarPix2pixGanTrainer 4 | from .fashionClassParallelTrainer import start_fashionClassPrarallelTrainer 5 | from .fashionAutoencoder import FashionAutoEncoderTrainer, start_fashionAutoencoderTrainer 6 | __all__ = ['FashionClassTrainer', 'start_fashionClassTrainer', 7 | 'FashionGenerateGanTrainer', 'start_fashionGenerateGanTrainer', 8 | 'cifarPix2pixGan', 'start_cifarPix2pixGanTrainer', 'start_fashionClassPrarallelTrainer', 9 | 'start_fashionAutoencoderTrainer', 'FashionAutoEncoderTrainer'] 10 | -------------------------------------------------------------------------------- /unittests/test_loger.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import torch 3 | from jdit.optimizer import Optimizer 4 | from jdit.trainer.super import Loger 5 | from jdit.model import Model 6 | 7 | 8 | class TestLoger(TestCase): 9 | def test_regist_config(self): 10 | log = Loger() 11 | param = torch.nn.Linear(10, 1).parameters() 12 | opt = Optimizer(param, "RMSprop", lr_decay=0.5,decay_position= 2, position_type="step", lr=0.999) 13 | log.regist_config(opt) 14 | self.assertEqual(log.regist_dict["Optimizer"], 15 | {'opt_name': 'RMSprop', 'lr': 0.999, 'momentum': 0, 'alpha': 0.99, 'eps': 1e-08, 16 | 'centered': False, 'weight_decay': 0, 'lr_decay': '0.5', 17 | 'decay_decay_typeposition': 'step', 'decay_position': '2'}) 18 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=jdit 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/source/dataset.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | jdit.dataset 5 | ============ 6 | 7 | .. automodule:: jdit.dataset 8 | .. currentmodule:: jdit.dataset 9 | 10 | Dataloaders_factory 11 | ------------------- 12 | 13 | .. autoclass:: DataLoadersFactory 14 | :members: 15 | 16 | HandMNIST 17 | ---------- 18 | 19 | .. 
autoclass:: HandMNIST 20 | :members: 21 | 22 | FashionMNIST 23 | ------------- 24 | 25 | .. autoclass:: FashionMNIST 26 | :members: 27 | 28 | Cifar10 29 | ------- 30 | 31 | .. autoclass:: Cifar10 32 | :members: 33 | 34 | Lsun 35 | ---- 36 | 37 | .. autoclass:: Lsun 38 | :members: 39 | 40 | get_mnist_dataloaders 41 | --------------------- 42 | .. autofunction:: get_mnist_dataloaders 43 | 44 | get_fashion_mnist_dataloaders 45 | ----------------------------- 46 | .. autofunction:: get_fashion_mnist_dataloaders 47 | 48 | get_lsun_dataloader 49 | ------------------- 50 | .. autofunction:: get_lsun_dataloader -------------------------------------------------------------------------------- /unittests/test_optimizer.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import torch 3 | from jdit import Optimizer 4 | 5 | 6 | class TestOptimizer(TestCase): 7 | def setUp(self): 8 | param = torch.nn.Linear(10, 1).parameters() 9 | self.opt = Optimizer(param, "RMSprop", lr_decay=0.5, decay_position=3, position_type="step", lr=2) 10 | 11 | def test_do_lr_decay(self): 12 | self.opt.do_lr_decay(reset_lr=2, reset_lr_decay=0.3) 13 | self.assertEqual(self.opt.lr, 2) 14 | self.assertEqual(self.opt.lr_decay, 0.3) 15 | self.opt.do_lr_decay() 16 | self.assertEqual(self.opt.lr, 0.6) 17 | 18 | def test_is_lrdecay(self): 19 | self.assertTrue(not self.opt.is_decay_lr(2)) 20 | self.assertTrue(self.opt.is_decay_lr(3)) 21 | 22 | def test_configure(self): 23 | self.opt.do_lr_decay(reset_lr=2, reset_lr_decay=0.3) 24 | self.assertEqual(self.opt.configure['lr'], 2) 25 | self.assertEqual(self.opt.configure['lr_decay'], '0.3') 26 | -------------------------------------------------------------------------------- /unittests/test_instences.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | from jdit.trainer.instances import start_fashionClassTrainer, start_fashionGenerateGanTrainer, \ 3 | start_cifarPix2pixGanTrainer, start_fashionAutoencoderTrainer, start_fashionClassPrarallelTrainer 4 | import shutil 5 | import os 6 | 7 | 8 | class TestInstances(TestCase): 9 | def test_start_fashionClassTrainer(self): 10 | start_fashionClassTrainer(run_type="debug") 11 | 12 | def test_start_fashionGenerateGanTrainer(self): 13 | start_fashionGenerateGanTrainer(run_type="debug") 14 | 15 | def test_start_cifarPix2pixGanTrainer(self): 16 | start_cifarPix2pixGanTrainer(run_type="debug") 17 | 18 | def test_start_fashionAutoencoderTrainer(self): 19 | start_fashionAutoencoderTrainer(run_type="debug") 20 | 21 | def test_start_fashionClassPrarallelTrainer(self): 22 | start_fashionClassPrarallelTrainer() 23 | 24 | def setUp(self): 25 | dir = "log_debug/" 26 | if os.path.exists(dir): 27 | shutil.rmtree(dir) 28 | 29 | def tearDown(self): 30 | dir = "log_debug/" 31 | if os.path.exists(dir): 32 | shutil.rmtree(dir) 33 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 
strategy: 17 | matrix: 18 | python-version: [3.7, 3.8, 3.9] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install flake8 pytest 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Lint with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 36 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 37 | - name: Test with pytest 38 | run: | 39 | pytest 40 | -------------------------------------------------------------------------------- /docs/source/trainer.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | jdit.trainer 5 | ============ 6 | 7 | .. automodule:: jdit.trainer 8 | .. currentmodule:: jdit.trainer 9 | 10 | SupTrainer 11 | ---------- 12 | 13 | .. autoclass:: SupTrainer 14 | :members: 15 | 16 | Single Model Trainer 17 | -------------------- 18 | 19 | .. automodule:: jdit.trainer 20 | .. currentmodule:: jdit.trainer 21 | 22 | SupSingleModelTrainer 23 | >>>>>>>>>>>>>>>>>>>>> 24 | 25 | .. autoclass:: SupSingleModelTrainer 26 | :members: 27 | 28 | ClassificationTrainer 29 | >>>>>>>>>>>>>>>>>>>>> 30 | 31 | .. autoclass:: ClassificationTrainer 32 | :members: 33 | 34 | AutoEncoderTrainer 35 | >>>>>>>>>>>>>>>>>>>>> 36 | 37 | .. autoclass:: AutoEncoderTrainer 38 | :members: 39 | 40 | Generative Adversarial Networks Trainer 41 | --------------------------------------- 42 | 43 | .. automodule:: jdit.trainer 44 | .. currentmodule:: jdit.trainer 45 | 46 | SupGanTrainer 47 | >>>>>>>>>>>>>>>>>>>>>>>>>>>>> 48 | 49 | .. autoclass:: SupGanTrainer 50 | :members: 51 | 52 | Pix2pixGanTrainer 53 | >>>>>>>>>>>>>>>>>>>>>>>>>>>>> 54 | 55 | .. autoclass:: Pix2pixGanTrainer 56 | :members: 57 | 58 | GenerateGanTrainer 59 | >>>>>>>>>>>>>>>>>>>>>>>>>>>>> 60 | 61 | .. autoclass:: GenerateGanTrainer 62 | :members: 63 | 64 | instances 65 | --------- 66 | 67 | .. automodule:: jdit.trainer.instances 68 | .. currentmodule:: jdit.trainer.instances 69 | 70 | instances.FashionClassTrainer 71 | >>>>>>>>>>>>>>>>>>>>>>>>>>>>> 72 | 73 | .. 
autofunction:: start_fashionClassTrainer 74 | 75 | -------------------------------------------------------------------------------- /unittests/test_supParallelTrainer.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | from jdit.parallel import SupParallelTrainer 3 | import os 4 | import shutil 5 | 6 | class test_print(): 7 | def __init__(self, params): 8 | self.params = params 9 | def train(self, *args, **kwargs): 10 | print(self.params) 11 | 12 | 13 | class TestSupParallelTrainer(TestCase): 14 | 15 | def tearDown(self): 16 | logdir = "log_debug/" 17 | if os.path.exists(logdir): 18 | shutil.rmtree(logdir) 19 | 20 | def setUp(self): 21 | logdir = "log_debug/" 22 | if os.path.exists(logdir): 23 | shutil.rmtree(logdir) 24 | 25 | unfixed_params = [{'task_id': 2, 'depth': 1, 'gpu_ids_abs': []}, 26 | {'task_id': 1, 'depth': 2, 'gpu_ids_abs': [1, 2]}, 27 | {'task_id': 1, 'depth': 3, 'gpu_ids_abs': [1, 2]}, 28 | {'task_id': 2, 'depth': 4, 'gpu_ids_abs': [3, 4]} 29 | ] 30 | 31 | 32 | def test_train(self): 33 | unfixed_params = [{'task_id': 2, 'depth': 1, 'gpu_ids_abs': []}, 34 | {'task_id': 1, 'depth': 2, 'gpu_ids_abs': []}, ] 35 | pt = SupParallelTrainer(unfixed_params, test_print) 36 | pt.train() 37 | 38 | def test__add_logdirs_to_unfixed_params(self): 39 | unfixed_params = [ 40 | {'depth': 1, 'gpu_ids_abs': []}, 41 | {'depth': 2, 'gpu_ids_abs': [1, 2]} 42 | ] 43 | final_unfixed_params = [ 44 | {'depth': 1, 'gpu_ids_abs': [], 'logdir': 'plog/depth=1,gpu=[]'}, 45 | {'depth': 2, 'gpu_ids_abs': [1, 2], 'logdir': 'plog/depth=2,gpu=[1, 2]'} 46 | ] 47 | pt = SupParallelTrainer([{'task_id': 1, "logdir": "log"}], test_print) 48 | test_final_unfixed_params_list = pt._add_logdirs_to_unfixed_params(unfixed_params) 49 | self.assertEqual(final_unfixed_params, test_final_unfixed_params_list) 50 | 51 | def test__convert_to_dirname(self): 52 | pt = SupParallelTrainer([{'task_id': 1, 'logdir': "log"}], test_print) 53 | self.assertEqual(pt._convert_to_dirname("abc"), "abc") 54 | self.assertEqual(pt._convert_to_dirname("123_abc_abc****"), "123_abc_abc") 55 | self.assertEqual(pt._convert_to_dirname("*<>,/\\:?|abc"), "smallergreater,__%$-abc") 56 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | from setuptools import setup 3 | 4 | setup( 5 | name="jdit", # name used on PyPI, by pip/easy_install, and for the generated egg file 6 | version="0.1.5", 7 | author="Guanglei Ding", 8 | author_email="dingguanglei.bupt@qq.com", 9 | maintainer='Guanglei Ding', 10 | maintainer_email='dingguanglei.bupt@qq.com', 11 | description=("Make it easy to do research on pytorch"), 12 | # long_description=open('docs/source/index.rst').read(), 13 | long_description=open('README.md').read(), 14 | long_description_content_type="text/markdown", 15 | license="Apache License 2.0", 16 | keywords="pytorch research framework", 17 | platforms=["all"], 18 | url="https://github.com/dingguanglei/jdit", 19 | packages=['jdit', 20 | 'jdit/trainer', 21 | 'jdit/trainer/gan', 22 | 'jdit/trainer/single', 23 | 'jdit/trainer/instances', 24 | 'jdit/assessment', 25 | 'jdit/parallel'], 26 | # 'mypackage','mypackage/model','mypackage/metric','mypackage/model/shared'], # list of package directories to bundle 27 | 28 | # dependencies to be installed 29 | install_requires=[ 30 | "nvidia_ml_py3>=7.352.0", 31 | 'imageio', 32 | ], 33 | 34 | # # With this option, an .exe is generated under the scripts directory of the Python install on Windows 35 | # # Note: a colon separates the module from the function: 36 | #
entry_points={'console_scripts': [ 37 | # 'redis_run = RedisRun.redis_run:main', 38 | # ]}, 39 | 40 | # long_description=read('README.md'), 41 | classifiers=[ # list of classifiers the program belongs to 42 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 43 | "License :: OSI Approved :: Apache Software License", 44 | # 'Development Status :: 2 - Pre-Alpha', 45 | # 'Development Status :: 3 - Alpha', 46 | 'Development Status :: 4 - Beta', 47 | # 'Development Status :: 5 - Production/Stable', 48 | 'Operating System :: OS Independent', 49 | 'Intended Audience :: Developers', 50 | "Programming Language :: Python :: 3 :: Only", 51 | 'Programming Language :: Python :: Implementation', 52 | 'Programming Language :: Python :: 3.6', 53 | 'Programming Language :: Python :: 3.7', 54 | 'Programming Language :: Python :: Implementation :: CPython', 55 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 56 | 'Topic :: Scientific/Engineering :: Visualization ' 57 | ], 58 | # required; otherwise uninstalling reports a Windows error 59 | zip_safe=False 60 | ) 61 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to jdit documentation! 2 | ================================ 3 | .. toctree:: 4 | :glob: 5 | :maxdepth: 1 6 | :caption: Notes 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | :caption: User Guide 11 | 12 | quick start 13 | Build your own trainer 14 | 15 | .. toctree:: 16 | :maxdepth: 1 17 | :caption: Package Reference 18 | 19 | dataset 20 | model 21 | optimizer 22 | trainer 23 | assessment 24 | parallel 25 | 26 | 27 | 28 | **Jdit** is a research-process-oriented framework based on pytorch, built so that you only have to care about your ideas. 29 | You don't need to write long, boring code just to run a deep learning project and verify an idea. 30 | 31 | You only need to implement your ideas; you don't have to deal with 32 | the training framework, multiple GPUs, checkpoints, process visualization, performance evaluation and so on. 33 | 34 | Quick start 35 | ----------- 36 | After building and installing the jdit package, you can make a new directory for a quick test. 37 | Assuming that you have a new directory `example`, 38 | run this code in `ipython`. (Creating a `main.py` file is also acceptable.) 39 | 40 | .. code-block:: python 41 | 42 | from jdit.trainer.instances.fashionClassification import start_fashionClassTrainer 43 | start_fashionClassTrainer() 44 | 45 | 46 | Then you will see something like the following. 47 | 48 | .. code-block:: 49 | 50 | ===> Build dataset 51 | use 8 thread 52 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz 53 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz 54 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz 55 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz 56 | Processing... 57 | Done 58 | ===> Building model 59 | ResNet Total number of parameters: 2776522 60 | ResNet model use CPU 61 | apply kaiming weight init 62 | ===> Building optimizer 63 | ===> Training 64 | using `tensorboard --logdir=log` to see learning curves and net structure. 65 | training and valid_epoch data, configures info and checkpoint were saved in `log` directory. 66 | 0%| | 0/10 [00:00<.., ..epoch/s] 67 | 0step [00:00, step/s] 68 | 69 | * It will fetch the fashion mnist dataset.
70 | * Then build a resnet18 for classification. 71 | * During training, you can watch the learning curves in `tensorboard`. 72 | * It will create a `log` directory in `example/`, which saves the training process data and configurations. 73 | 74 | Although it is just an example, you can still build your own project easily by using the jdit framework. 75 | The jdit framework can deal with: 76 | * Data visualization (learning curves, images produced during training). 77 | * CPU, GPU or multiple GPUs (training your model on specified devices). 78 | * Intermediate data storage (saving training data into a csv file). 79 | * Automatic model checkpoints. 80 | * Flexible templates that can be integrated and overridden. 81 | So, let's see what **jdit** is. 82 | 83 | 84 | Indices and tables 85 | ================== 86 | 87 | * :ref:`genindex` 88 | * :ref:`modindex` 89 | * :ref:`search` -------------------------------------------------------------------------------- /jdit/trainer/instances/fashionAutoencoder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from jdit.trainer.single.autoencoder import AutoEncoderTrainer 6 | from jdit import Model 7 | from jdit.optimizer import Optimizer 8 | from jdit.dataset import FashionMNIST 9 | 10 | 11 | class SimpleModel(nn.Module): 12 | def __init__(self, depth=32): 13 | super(SimpleModel, self).__init__() 14 | self.layer1 = nn.Conv2d(1, depth, 3, 1, 1) 15 | self.layer2 = nn.Conv2d(depth, depth * 2, 4, 2, 1) 16 | self.layer3 = nn.Conv2d(depth * 2, depth * 4, 4, 2, 1) 17 | self.layer4 = nn.ConvTranspose2d(depth * 4, depth * 2, 4, 2, 1) 18 | self.layer5 = nn.ConvTranspose2d(depth * 2, 1, 4, 2, 1) 19 | 20 | def forward(self, x): 21 | out = F.relu(self.layer1(x)) 22 | out = F.relu(self.layer2(out)) 23 | out = F.relu(self.layer3(out)) 24 | out = F.relu(self.layer4(out)) 25 | out = self.layer5(out) 26 | return out 27 | 28 | 29 | class FashionAutoEncoderTrainer(AutoEncoderTrainer): 30 | def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets): 31 | super(FashionAutoEncoderTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, datasets) 32 | data, label = self.datasets.samples_train 33 | self.watcher.embedding(data, data, label, 1) 34 | 35 | def get_data_from_batch(self, batch_data, device): 36 | input = ground_truth = batch_data[0] 37 | return input, ground_truth 38 | 39 | def compute_loss(self): 40 | var_dic = {} 41 | var_dic["MSE"] = loss = nn.MSELoss(reduction="mean")(self.output, self.ground_truth) 42 | return loss, var_dic 43 | 44 | def compute_valid(self): 45 | _, var_dic = self.compute_loss() 46 | return var_dic 47 | 48 | 49 | def start_fashionAutoencoderTrainer(gpus=(), nepochs=10, run_type="train"): 50 | """An example of a fashion-mnist autoencoder 51 | 52 | """ 53 | depth = 32 54 | gpus = gpus 55 | batch_size = 16 56 | nepochs = nepochs 57 | opt_hpm = {"optimizer": "Adam", 58 | "lr_decay": 0.94, 59 | "decay_position": 10, 60 | "position_type": "epoch", 61 | "lr_reset": {2: 5e-4, 3: 1e-3}, 62 | "lr": 1e-3, 63 | "weight_decay": 2e-5, 64 | "betas": (0.9, 0.99)} 65 | 66 | print('===> Build dataset') 67 | mnist = FashionMNIST(batch_size=batch_size) 68 | torch.backends.cudnn.benchmark = True 69 | print('===> Building model') 70 | net = Model(SimpleModel(depth=depth), gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=1) 71 | print('===> Building optimizer') 72 | opt = Optimizer(net.parameters(), **opt_hpm) 73 | print('===>
Training') 74 | print("using `tensorboard --logdir=log` to see learning curves and net structure." 75 | "training and valid_epoch data, configures info and checkpoint were saved in `log` directory.") 76 | Trainer = FashionAutoEncoderTrainer("log/fashion_autoencoder", nepochs, gpus, net, opt, mnist) 77 | if run_type == "train": 78 | Trainer.train() 79 | elif run_type == "debug": 80 | Trainer.debug() 81 | 82 | 83 | if __name__ == '__main__': 84 | start_fashionAutoencoderTrainer() 85 | -------------------------------------------------------------------------------- /docs/source/quick start.rst: -------------------------------------------------------------------------------- 1 | Quick Start 2 | =========== 3 | You can get a quick start by following these steps. 4 | After building and installing the jdit package, you can make a new directory for a quick test. 5 | Assuming that you have a new directory ``example``, 6 | run this code in ``ipython``. (Creating a ``main.py`` file is also acceptable.) 7 | 8 | 9 | Fashion-mnist Classification 10 | ---------------------------- 11 | To start a simple classification task: 12 | 13 | .. code:: python 14 | 15 | from jdit.trainer.instances.fashionClassification import start_fashionClassTrainer 16 | start_fashionClassTrainer() 17 | 18 | Then you will see something like the following. 19 | 20 | .. code:: 21 | 22 | ===> Build dataset 23 | use 8 thread! 24 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz 25 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz 26 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz 27 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz 28 | Processing... 29 | Done! 30 | ===> Building model 31 | SimpleModel Total number of parameters: 2776522 32 | ResNet model use CPU! 33 | apply kaiming weight init! 34 | ===> Building optimizer 35 | ===> Training 36 | using `tensorboard --logdir=log` to see learning curves and net structure. 37 | training and valid_epoch data, configures info and checkpoint were saved in `log` directory. 38 | 0%| | 0/10 [00:00 Build dataset 63 | use 2 thread! 64 | ===> Building model 65 | Discriminator Total number of parameters: 100865 66 | Discriminator model use GPU(0)! 67 | apply kaiming weight init! 68 | Generator Total number of parameters: 951361 69 | Generator model use GPU(0)! 70 | apply kaiming weight init!
71 | ===> Building optimizer 72 | ===> Training 73 | 0%| | 0/200 [00:001 0010=>2 47 | total = predict.size(0) 48 | correct = predict.eq(labels).cpu().sum().float() 49 | acc = correct / total 50 | var_dic["ACC"] = acc 51 | return var_dic 52 | 53 | 54 | def start_fashionClassTrainer(gpus=(), nepochs=10, run_type="train"): 55 | """" An example of fashion-mnist classification 56 | 57 | """ 58 | num_class = 10 59 | depth = 32 60 | gpus = gpus 61 | batch_size = 4 62 | nepochs = nepochs 63 | opt_hpm = {"optimizer": "Adam", 64 | "lr_decay": 0.94, 65 | "decay_position": 10, 66 | "position_type": "epoch", 67 | "lr_reset": {2: 5e-4, 3: 1e-3}, 68 | "lr": 1e-4, 69 | "weight_decay": 2e-5, 70 | "betas": (0.9, 0.99)} 71 | 72 | print('===> Build dataset') 73 | mnist = FashionMNIST(batch_size=batch_size) 74 | # mnist.dataset_train = mnist.dataset_test 75 | torch.backends.cudnn.benchmark = True 76 | print('===> Building model') 77 | net = Model(SimpleModel(depth=depth), gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=1) 78 | print('===> Building optimizer') 79 | opt = Optimizer(net.parameters(), **opt_hpm) 80 | print('===> Training') 81 | print("using `tensorboard --logdir=log` to see learning curves and net structure." 82 | "training and valid_epoch data, configures info and checkpoint were save in `log` directory.") 83 | Trainer = FashionClassTrainer("log/fashion_classify", nepochs, gpus, net, opt, mnist, num_class) 84 | if run_type == "train": 85 | Trainer.train() 86 | elif run_type == "debug": 87 | Trainer.debug() 88 | 89 | 90 | 91 | if __name__ == '__main__': 92 | start_fashionClassTrainer() 93 | -------------------------------------------------------------------------------- /jdit/trainer/instances/fashionClassParallelTrainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from jdit.trainer.single.classification import ClassificationTrainer 6 | from jdit.model import Model 7 | from jdit.optimizer import Optimizer 8 | from jdit.dataset import FashionMNIST 9 | from jdit.parallel import SupParallelTrainer 10 | 11 | 12 | class SimpleModel(nn.Module): 13 | def __init__(self, depth=64, num_class=10): 14 | super(SimpleModel, self).__init__() 15 | self.num_class = num_class 16 | self.layer1 = nn.Conv2d(1, depth, 3, 1, 1) 17 | self.layer2 = nn.Conv2d(depth, depth * 2, 4, 2, 1) 18 | self.layer3 = nn.Conv2d(depth * 2, depth * 4, 4, 2, 1) 19 | self.layer4 = nn.Conv2d(depth * 4, depth * 8, 4, 2, 1) 20 | self.layer5 = nn.Conv2d(depth * 8, num_class, 4, 1, 0) 21 | 22 | def forward(self, x): 23 | out = F.relu(self.layer1(x)) 24 | out = F.relu(self.layer2(out)) 25 | out = F.relu(self.layer3(out)) 26 | out = F.relu(self.layer4(out)) 27 | out = self.layer5(out) 28 | out = out.view(-1, self.num_class) 29 | return out 30 | 31 | 32 | class FashionClassTrainer(ClassificationTrainer): 33 | def __init__(self, logdir, nepochs, gpu_ids, net, opt, dataset, num_class): 34 | super(FashionClassTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, dataset, num_class) 35 | 36 | def compute_loss(self): 37 | var_dic = {} 38 | var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.ground_truth.squeeze().long()) 39 | 40 | _, predict = torch.max(self.output.detach(), 1) # 0100=>1 0010=>2 41 | total = predict.size(0) * 1.0 42 | labels = self.ground_truth.squeeze().long() 43 | correct = predict.eq(labels).cpu().sum().float() 44 | acc = correct / total 45 | var_dic["ACC"] = acc 46 | 
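# Note: only `loss` (the first value returned below) is used for backward propagation
# by the trainer; the entries in `var_dic` ("CEP", "ACC") are only logged to tensorboard/csv.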
return loss, var_dic 47 | 48 | def compute_valid(self): 49 | _,var_dic = self.compute_loss() 50 | return var_dic 51 | 52 | 53 | def build_task_trainer(unfixed_params): 54 | """build a task just like FashionClassTrainer. 55 | 56 | :param unfixed_params: 57 | :return: 58 | """ 59 | logdir = unfixed_params['logdir'] 60 | gpu_ids_abs = unfixed_params["gpu_ids_abs"] 61 | depth = unfixed_params["depth"] 62 | lr = unfixed_params["lr"] 63 | 64 | batch_size = 32 65 | opt_name = "RMSprop" 66 | lr_decay = 0.94 67 | decay_position= 1 68 | position_type = "epoch" 69 | weight_decay = 2e-5 70 | momentum = 0 71 | nepochs = 100 72 | num_class = 10 73 | torch.backends.cudnn.benchmark = True 74 | mnist = FashionMNIST(root="datasets/fashion_data", batch_size=batch_size, num_workers=2) 75 | net = Model(SimpleModel(depth), gpu_ids_abs=gpu_ids_abs, init_method="kaiming", verbose=False) 76 | opt = Optimizer(net.parameters(), opt_name, lr_decay, decay_position, position_type=position_type, 77 | lr=lr, weight_decay=weight_decay, momentum=momentum) 78 | Trainer = FashionClassTrainer(logdir, nepochs, gpu_ids_abs, net, opt, mnist, num_class) 79 | return Trainer 80 | 81 | 82 | def trainerParallel(): 83 | unfixed_params = [ 84 | {'task_id': 1, 'gpu_ids_abs': [], 85 | 'depth': 4, 'lr': 1e-3, 86 | }, 87 | {'task_id': 1, 'gpu_ids_abs': [], 88 | 'depth': 8, 'lr': 1e-2, 89 | }, 90 | 91 | {'task_id': 2, 'gpu_ids_abs': [], 92 | 'depth': 4, 'lr': 1e-2, 93 | }, 94 | {'task_id': 2, 'gpu_ids_abs': [], 95 | 'depth': 8, 'lr': 1e-3, 96 | }, 97 | ] 98 | tp = SupParallelTrainer(unfixed_params, build_task_trainer) 99 | return tp 100 | 101 | 102 | def start_fashionClassPrarallelTrainer(run_type="debug"): 103 | tp = trainerParallel() 104 | tp.train() 105 | 106 | if __name__ == '__main__': 107 | start_fashionClassPrarallelTrainer() 108 | -------------------------------------------------------------------------------- /unittests/test_model.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import torch 3 | from torch.nn import init, Conv2d, Sequential 4 | from jdit.model import Model 5 | import shutil 6 | 7 | 8 | class TestModel(TestCase): 9 | def setUp(self): 10 | self.mode = Sequential(Conv2d(10, 1, 3, 1, 0)) 11 | self.epoch = 32 12 | 13 | def test_define(self): 14 | net = Model(self.mode, [], "kaiming", show_structure=False) 15 | if torch.cuda.device_count() == 0: 16 | net = Model(self.mode, [], "kaiming", show_structure=False) 17 | elif torch.cuda.device_count() == 1: 18 | net = Model(self.mode, [0], "kaiming", show_structure=False) 19 | else : 20 | net = Model(self.mode, [0, 1], "kaiming", show_structure=False) 21 | 22 | def test_save_load_weights(self): 23 | print(self.mode) 24 | if torch.cuda.device_count() == 0: 25 | net = Model(self.mode, [], "kaiming", show_structure=False) 26 | elif torch.cuda.device_count() == 1: 27 | net = Model(self.mode, [0], "kaiming", show_structure=False) 28 | else : 29 | net = Model(self.mode, [0, 1], "kaiming", show_structure=False) 30 | net.check_point("tm", self.epoch, "test_model") 31 | net.load_weights("test_model/checkpoint/Weights_tm_%d.pth" % self.epoch) 32 | dir = "test_model/" 33 | shutil.rmtree(dir) 34 | 35 | def test_load_point(self): 36 | if torch.cuda.device_count() == 0: 37 | net = Model(self.mode, [], "kaiming", show_structure=False) 38 | elif torch.cuda.device_count() == 1: 39 | net = Model(self.mode, [0], "kaiming", show_structure=False) 40 | else : 41 | net = Model(self.mode, [0, 1], "kaiming", show_structure=False) 
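# check_point() saves the weights and load_point() restores them; both resolve the same
# file path from (name, epoch, logdir), e.g. test_model/checkpoint/Weights_tm_32.pth here.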
42 | 43 | net.check_point("tm", self.epoch, "test_model") 44 | net.load_point("tm", self.epoch, "test_model") 45 | dir = "test_model/" 46 | shutil.rmtree(dir) 47 | 48 | def test_check_point(self): 49 | if torch.cuda.is_available(): 50 | net = Model(self.mode, [0], "kaiming", show_structure=False) 51 | elif torch.cuda.device_count() > 1: 52 | net = Model(self.mode, [0, 1], "kaiming", show_structure=False) 53 | elif torch.cuda.device_count() > 2: 54 | net = Model(self.mode, [2, 3], "kaiming", show_structure=False) 55 | else: 56 | net = Model(self.mode, [], "kaiming", show_structure=False) 57 | net.check_point("tm", self.epoch, "test_model") 58 | dir = "test_model/" 59 | shutil.rmtree(dir) 60 | 61 | def test__weight_init(self): 62 | if torch.cuda.is_available(): 63 | net = Model(self.mode, [0], "kaiming", show_structure=False) 64 | elif torch.cuda.device_count() > 1: 65 | net = Model(self.mode, [0, 1], "kaiming", show_structure=False) 66 | elif torch.cuda.device_count() > 2: 67 | net = Model(self.mode, [2, 3], "kaiming", show_structure=False) 68 | else: 69 | net = Model(self.mode, [], "kaiming", show_structure=False) 70 | net.init_fc = init.kaiming_normal_ 71 | self.mode.apply(net._weight_init) 72 | 73 | def test__fix_weights(self): 74 | pass 75 | 76 | def test__set_device(self): 77 | pass 78 | 79 | def test_configure(self): 80 | if torch.cuda.device_count() == 0: 81 | net = Model(self.mode, [], "kaiming", show_structure=False) 82 | elif torch.cuda.device_count() == 1: 83 | net = Model(self.mode, [0], "kaiming", show_structure=False) 84 | else : 85 | net = Model(self.mode, [0, 1], "kaiming", show_structure=False) 86 | self.assertEqual(net.configure, 87 | {'model_name': 'Sequential', 'init_method': 'kaiming', 'total_params': 91, 88 | 'structure': 'Sequential(\n (0): Conv2d(10, 1, kernel_size=(3, 3), ' 89 | 'stride=(1, 1))\n)'}) 90 | print(net.configure) 91 | for k, v in net.configure.items(): 92 | assert k is not None 93 | assert v is not None 94 | -------------------------------------------------------------------------------- /jdit/trainer/single/classification.py: -------------------------------------------------------------------------------- 1 | # from torch.nn import CrossEntropyLoss 2 | from .sup_single import * 3 | from abc import abstractmethod 4 | 5 | 6 | 7 | class ClassificationTrainer(SupSingleModelTrainer): 8 | """this is a classification trainer. 9 | 10 | """ 11 | 12 | def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets, num_class): 13 | super(ClassificationTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, datasets) 14 | self.net = net 15 | self.opt = opt 16 | self.datasets = datasets 17 | self.num_class = num_class 18 | 19 | 20 | @abstractmethod 21 | def compute_loss(self): 22 | """Compute the main loss and observed values. 23 | 24 | Compute the loss and other values shown in tensorboard scalars visualization. 25 | You should return a main loss for doing backward propagation. 26 | 27 | So, if you want some values visualized. Make a ``dict()`` with key name is the variable's name. 28 | The training logic is : 29 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 30 | self.output = self.net(self.input) 31 | self._train_iteration(self.opt, self.compute_loss, csv_filename="Train") 32 | So, you have `self.net`, `self.input`, `self.output`, `self.ground_truth` to compute your own loss here. 33 | 34 | .. note:: 35 | 36 | Only the main loss will do backward propagation, which is the first returned variable. 
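(For example, with a hypothetical joint objective, you would return one summed variable such as ``loss_ce + 0.5 * loss_aux`` in that first position.)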
37 | If you have a joint loss, please add the terms up and return one main loss. 38 | 39 | .. note:: 40 | 41 | All of your variables in the returned ``dict()`` will never do backward propagation with ``model.train()``. 42 | However, it still computes grads, without using ``with torch.autograd.no_grad()``. 43 | So, you can compute any grad variables for visualization. 44 | 45 | Example:: 46 | 47 | var_dic = {} 48 | labels = self.ground_truth.squeeze().long() 49 | var_dic["MSE"] = loss = nn.MSELoss()(self.output, labels) 50 | return loss, var_dic 51 | 52 | """ 53 | 54 | @abstractmethod 55 | def compute_valid(self): 56 | """Compute the valid_epoch variables for visualization. 57 | 58 | Compute the validations. 59 | They will only be used in tensorboard scalars visualization. 60 | So, if you want some variables visualized, make a ``dict()`` whose keys are the variables' names. 61 | You have `self.net`, `self.input`, `self.output`, `self.ground_truth` to compute your own validations here. 62 | 63 | .. note:: 64 | 65 | All of your variables in the returned ``dict()`` will never do backward propagation with ``model.eval()``. 66 | However, it still computes grads, without using ``with torch.autograd.no_grad()``. 67 | So, you can compute some grad variables for visualization. 68 | 69 | Example:: 70 | var_dic = {} 71 | labels = self.ground_truth.squeeze().long() 72 | var_dic["CEP"] = nn.CrossEntropyLoss()(self.output, labels) 73 | return var_dic 74 | 75 | """ 76 | 77 | def valid_epoch(self): 78 | avg_dic = dict() 79 | self.net.eval() 80 | for iteration, batch in enumerate(self.datasets.loader_valid, 1): 81 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 82 | self.output = self.net(self.input).detach() 83 | dic = self.compute_valid() 84 | if avg_dic == {}: 85 | avg_dic = dic 86 | else: 87 | # sum up 88 | for key in dic.keys(): 89 | avg_dic[key] += dic[key] 90 | 91 | for key in avg_dic.keys(): 92 | avg_dic[key] = avg_dic[key] / self.datasets.nsteps_valid 93 | 94 | self.watcher.scalars(var_dict=avg_dic, global_step=self.step, tag="Valid") 95 | self.loger.write(self.step, self.current_epoch, avg_dic, "Valid", header=self.current_epoch <= 1) 96 | self.net.train() 97 | 98 | def get_data_from_batch(self, batch_data, device): 99 | """If you have different behavior, you need to rewrite this method and the method `self.train_epoch()`. 100 | 101 | :param batch_data: A Tensor loaded from the dataset 102 | :param device: compute device 103 | :return: Tensors 104 | """ 105 | input_tensor, labels_tensor = batch_data[0], batch_data[1] 106 | return input_tensor, labels_tensor 107 | 108 | def _watch_images(self, tag: str, grid_size: tuple = (3, 3), shuffle=False, save_file=True): 109 | pass 110 | 111 | @property 112 | def configure(self): 113 | config_dic = super(ClassificationTrainer, self).configure 114 | config_dic["num_class"] = self.num_class 115 | return config_dic 116 | -------------------------------------------------------------------------------- /jdit/trainer/single/autoencoder.py: -------------------------------------------------------------------------------- 1 | from .sup_single import * 2 | from abc import abstractmethod 3 | 4 | 5 | class AutoEncoderTrainer(SupSingleModelTrainer): 6 | """this is an autoencoder-decoder trainer.
Image to Image 7 | 8 | """ 9 | 10 | def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets): 11 | super(AutoEncoderTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, datasets) 12 | self.net = net 13 | self.opt = opt 14 | self.datasets = datasets 15 | 16 | @abstractmethod 17 | def compute_loss(self): 18 | """Compute the main loss and observed values. 19 | 20 | Compute the loss and other values shown in tensorboard scalars visualization. 21 | You should return a main loss for doing backward propagation. 22 | 23 | So, if you want some values visualized, make a ``dict()`` whose keys are the variables' names. 24 | The training logic is: 25 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 26 | self.output = self.net(self.input) 27 | self._train_iteration(self.opt, self.compute_loss, csv_filename="Train") 28 | So, you have `self.net`, `self.input`, `self.output`, `self.ground_truth` to compute your own loss here. 29 | 30 | .. note:: 31 | 32 | Only the main loss will do backward propagation, which is the first returned variable. 33 | If you have a joint loss, please add the terms up and return one main loss. 34 | 35 | .. note:: 36 | 37 | All of your variables in the returned ``dict()`` will never do backward propagation with ``model.train()``. 38 | However, it still computes grads, without using ``with torch.autograd.no_grad()``. 39 | So, you can compute any grad variables for visualization. 40 | 41 | Example:: 42 | 43 | var_dic = {} 44 | var_dic["MSE"] = loss = nn.MSELoss(reduction="mean")(self.output, self.ground_truth) 45 | return loss, var_dic 46 | 47 | """ 48 | 49 | @abstractmethod 50 | def compute_valid(self): 51 | """Compute the valid_epoch variables for visualization. 52 | 53 | Compute the variables you care about. 54 | They will only be used in tensorboard scalars visualization. 55 | So, if you want some variables visualized, make a ``dict()`` whose keys are the variables' names. 56 | 57 | .. note:: 58 | 59 | All of your variables in the returned ``dict()`` will never do backward propagation with ``model.eval()``. 60 | However, it still computes grads, without using ``with torch.autograd.no_grad()``. 61 | So, you can compute some grad variables for visualization. 62 | 63 | Example:: 64 | var_dic = {} 65 | var_dic["MSE"] = loss = nn.MSELoss(reduction="mean")(self.output, self.ground_truth) 66 | return var_dic 67 | 68 | """ 69 | 70 | def valid_epoch(self): 71 | avg_dic = dict() 72 | self.net.eval() 73 | for iteration, batch in enumerate(self.datasets.loader_valid, 1): 74 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 75 | self.output = self.net(self.input).detach() 76 | dic = self.compute_valid() 77 | if avg_dic == {}: 78 | avg_dic = dic 79 | else: 80 | for key in dic.keys(): 81 | avg_dic[key] += dic[key] 82 | 83 | for key in avg_dic.keys(): 84 | avg_dic[key] = avg_dic[key] / self.datasets.nsteps_valid 85 | 86 | self.watcher.scalars(var_dict=avg_dic, global_step=self.step, tag="Valid") 87 | self.loger.write(self.step, self.current_epoch, avg_dic, "Valid", header=self.current_epoch <= 1) 88 | self.net.train() 89 | 90 | def get_data_from_batch(self, batch_data, device): 91 | """If you have different behavior,
you need to rewrite this method and the method `self.train_epoch()`. 92 | 93 | :param batch_data: A Tensor loaded from the dataset 94 | :param device: compute device 95 | :return: Tensors 96 | """ 97 | input_tensor, ground_truth_tensor = batch_data[0], batch_data[1] 98 | return input_tensor, ground_truth_tensor 99 | 100 | def _watch_images(self, tag: str, grid_size: tuple = (3, 3), shuffle=False, save_file=True): 101 | self.watcher.image(self.input, 102 | self.current_epoch, 103 | tag="%s/input" % tag, 104 | grid_size=grid_size, 105 | shuffle=shuffle, 106 | save_file=save_file) 107 | self.watcher.image(self.output, 108 | self.current_epoch, 109 | tag="%s/output" % tag, 110 | grid_size=grid_size, 111 | shuffle=shuffle, 112 | save_file=save_file) 113 | self.watcher.image(self.ground_truth, 114 | self.current_epoch, 115 | tag="%s/ground_truth" % tag, 116 | grid_size=grid_size, 117 | shuffle=shuffle, 118 | save_file=save_file) 119 | -------------------------------------------------------------------------------- /jdit/trainer/gan/generate.py: -------------------------------------------------------------------------------- 1 | from .sup_gan import SupGanTrainer 2 | from abc import abstractmethod 3 | from torch.autograd import Variable 4 | import torch 5 | 6 | 7 | class GenerateGanTrainer(SupGanTrainer): 8 | d_turn = 1 9 | """The number of Discriminator training steps for each Generator training step. 10 | """ 11 | 12 | def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets, latent_shape): 13 | """ a gan super class 14 | 15 | :param logdir: Path of the log 16 | :param nepochs: Number of epochs. 17 | :param gpu_ids_abs: The ids of the gpus to be used. If using CPU, set ``[]``. 18 | :param netG: Generator model. 19 | :param netD: Discriminator model 20 | :param optG: Optimizer of the Generator. 21 | :param optD: Optimizer of the Discriminator. 22 | :param datasets: Datasets. 23 | :param latent_shape: The shape of the input noise. 24 | """ 25 | super(GenerateGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets) 26 | self.latent_shape = latent_shape 27 | self.fixed_input = torch.randn((self.datasets.batch_size, *self.latent_shape)).to(self.device) 28 | # self.metric = FID(self.gpu_ids) 29 | 30 | def get_data_from_batch(self, batch_data: list, device: torch.device): 31 | ground_truth_tensor = batch_data[0] 32 | input_tensor = Variable(torch.randn((len(ground_truth_tensor), *self.latent_shape))) 33 | return input_tensor, ground_truth_tensor 34 | 35 | def valid_epoch(self): 36 | super(GenerateGanTrainer, self).valid_epoch() 37 | self.netG.eval() 38 | # watching the variation during training by a fixed input 39 | with torch.no_grad(): 40 | fake = self.netG(self.fixed_input).detach() 41 | self.watcher.image(fake, self.current_epoch, tag="Valid/Fixed_fake", grid_size=(4, 4), shuffle=False) 42 | # saving training processes to build a .gif. 43 | self.watcher.set_training_progress_images(fake, grid_size=(4, 4)) 44 | self.netG.train() 45 | 46 | @abstractmethod 47 | def compute_d_loss(self): 48 | """ Rewrite this method to compute your own Discriminator loss. 49 | 50 | You should return a **loss** for the first position.
51 | You can return a ``dict`` of loss that you want to visualize on the second position.like 52 | The train logic is : 53 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 54 | self.fake = self.netG(self.input) 55 | self._train_iteration(self.optD, self.compute_d_loss, csv_filename="Train_D") 56 | if (self.step % self.d_turn) == 0: 57 | self._train_iteration(self.optG, self.compute_g_loss, csv_filename="Train_G") 58 | So, you use `self.input` , `self.ground_truth`, `self.fake`, `self.netG`, `self.optD` to compute loss. 59 | Example:: 60 | 61 | d_fake = self.netD(self.fake.detach()) 62 | d_real = self.netD(self.ground_truth) 63 | var_dic = {} 64 | var_dic["LS_LOSSD"] = loss_d = 0.5 * (torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2)) 65 | return loss_d, var_dic 66 | 67 | """ 68 | loss_d = None 69 | var_dic = {} 70 | 71 | return loss_d, var_dic 72 | 73 | @abstractmethod 74 | def compute_g_loss(self): 75 | """Rewrite this method to compute your own loss of Generator. 76 | 77 | You should return a **loss** for the first position. 78 | You can return a ``dict`` of loss that you want to visualize on the second position.like 79 | The train logic is : 80 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 81 | self.fake = self.netG(self.input) 82 | self._train_iteration(self.optD, self.compute_d_loss, csv_filename="Train_D") 83 | if (self.step % self.d_turn) == 0: 84 | self._train_iteration(self.optG, self.compute_g_loss, csv_filename="Train_G") 85 | So, you use `self.input` , `self.ground_truth`, `self.fake`, `self.netG`, `self.optD` to compute loss. 86 | Example:: 87 | 88 | d_fake = self.netD(self.fake, self.input) 89 | var_dic = {} 90 | var_dic["LS_LOSSG"] = loss_g = 0.5 * torch.mean((d_fake - 1) ** 2) 91 | return loss_g, var_dic 92 | 93 | """ 94 | loss_g = None 95 | var_dic = {} 96 | return loss_g, var_dic 97 | 98 | @abstractmethod 99 | def compute_valid(self): 100 | """ 101 | The train logic is : 102 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 103 | self.fake = self.netG(self.input) 104 | self._train_iteration(self.optD, self.compute_d_loss, csv_filename="Train_D") 105 | if (self.step % self.d_turn) == 0: 106 | self._train_iteration(self.optG, self.compute_g_loss, csv_filename="Train_G") 107 | So, you use `self.input` , `self.ground_truth`, `self.fake`, `self.netG`, `self.optD` to compute validations. 
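Example (a minimal sketch that just merges the dicts from the two loss hooks)::

    _, d_var_dic = self.compute_d_loss()
    _, g_var_dic = self.compute_g_loss()
    return dict(d_var_dic, **g_var_dic)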
108 | 109 | :return: A ``dict`` of validation values that you want to visualize. 110 | """ 111 | _, g_var_dic = self.compute_g_loss() 112 | _, d_var_dic = self.compute_d_loss() 113 | var_dic = dict(d_var_dic, **g_var_dic) 114 | return var_dic 115 | 116 | def test(self): 117 | self.input = Variable(torch.randn((self.datasets.batch_size, *self.latent_shape))).to(self.device) 118 | self.netG.eval() 119 | with torch.no_grad(): 120 | fake = self.netG(self.input).detach() 121 | self.watcher.image(fake, self.current_epoch, tag="Test/fake", grid_size=(4, 4), shuffle=False) 122 | self.netG.train() 123 | 124 | @property 125 | def configure(self): 126 | config_dic = super(GenerateGanTrainer, self).configure 127 | config_dic["latent_shape"] = str(self.latent_shape) 128 | return config_dic 129 | -------------------------------------------------------------------------------- /jdit/trainer/instances/fashionGenerateGan.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | from jdit.trainer import GenerateGanTrainer 5 | from jdit.model import Model 6 | from jdit.optimizer import Optimizer 7 | from jdit.dataset import FashionMNIST 8 | 9 | 10 | class Discriminator(nn.Module): 11 | def __init__(self, input_nc=1, output_nc=1, depth=64): 12 | super(Discriminator, self).__init__() 13 | # 32 x 32 14 | self.layer1 = nn.Sequential( 15 | nn.utils.spectral_norm(nn.Conv2d(input_nc, depth * 1, kernel_size=3, stride=1, padding=1)), 16 | nn.LeakyReLU(0.1)) 17 | # 32 x 32 18 | self.layer2 = nn.Sequential( 19 | nn.utils.spectral_norm(nn.Conv2d(depth * 1, depth * 2, kernel_size=3, stride=1, padding=1)), 20 | nn.LeakyReLU(0.1), 21 | nn.MaxPool2d(2, 2)) 22 | # 16 x 16 23 | self.layer3 = nn.Sequential( 24 | nn.utils.spectral_norm(nn.Conv2d(depth * 2, depth * 4, kernel_size=3, stride=1, padding=1)), 25 | nn.LeakyReLU(0.1), 26 | nn.MaxPool2d(2, 2)) 27 | # 8 x 8 28 | self.layer4 = nn.Sequential(nn.Conv2d(depth * 4, output_nc, kernel_size=8, stride=1, padding=0)) 29 | # 1 x 1 30 | 31 | def forward(self, x): 32 | out = self.layer1(x) 33 | out = self.layer2(out) 34 | out = self.layer3(out) 35 | out = self.layer4(out) 36 | return out 37 | 38 | 39 | class Generator(nn.Module): 40 | def __init__(self, input_nc=256, output_nc=1, depth=64): 41 | super(Generator, self).__init__() 42 | self.encoder = nn.Sequential( 43 | nn.ConvTranspose2d(input_nc, 4 * depth, 4, 1, 0), # 256,1,1 => 256,4,4 44 | nn.ReLU()) 45 | self.decoder = nn.Sequential( 46 | nn.ConvTranspose2d(4 * depth, 4 * depth, 4, 2, 1), # 256,4,4 => 256,8,8 47 | nn.ReLU(), 48 | nn.BatchNorm2d(4 * depth), 49 | nn.ConvTranspose2d(4 * depth, 2 * depth, 4, 2, 1), # 256,8,8 => 128,16,16 50 | nn.ReLU(), 51 | nn.BatchNorm2d(2 * depth), 52 | nn.ConvTranspose2d(2 * depth, depth, 4, 2, 1), # 128,16,16 => 64,32,32 53 | nn.ReLU(), 54 | nn.BatchNorm2d(depth), 55 | nn.ConvTranspose2d(depth, output_nc, 3, 1, 1), # 64,32,32 => 1,32,32 56 | ) 57 | 58 | def forward(self, x): 59 | out = self.encoder(x) 60 | out = self.decoder(out) 61 | return out 62 | 63 | 64 | class FashionGenerateGanTrainer(GenerateGanTrainer): 65 | d_turn = 1 66 | 67 | def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset, latent_shape): 68 | super(FashionGenerateGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, 69 | dataset, 70 | latent_shape=latent_shape) 71 | 72 | data, label = self.datasets.samples_train 73 | self.watcher.embedding(data, data, label, global_step=1) 74 | 75 | def compute_d_loss(self):
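        # Least-squares GAN (LSGAN) discriminator objective: push scores of real samples toward 1 and scores of fake samples toward 0.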
76 | d_fake = self.netD(self.fake.detach()) 77 | d_real = self.netD(self.ground_truth) 78 | var_dic = dict() 79 | var_dic["LS_LOSSD"] = loss_d = 0.5 * (torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2)) 80 | return loss_d, var_dic 81 | 82 | def compute_g_loss(self): 83 | d_fake = self.netD(self.fake) 84 | var_dic = dict() 85 | var_dic["LS_LOSSG"] = loss_g = 0.5 * torch.mean((d_fake - 1) ** 2) 86 | return loss_g, var_dic 87 | 88 | def compute_valid(self): 89 | _, g_var_dic = self.compute_g_loss() 90 | _, d_var_dic = self.compute_d_loss() 91 | var_dic = dict(d_var_dic, **g_var_dic) 92 | return var_dic 93 | 94 | 95 | def start_fashionGenerateGanTrainer(gpus=(), nepochs=50, lr=1e-3, depth_G=32, depth_D=32, latent_shape=(256, 1, 1), 96 | run_type="train"): 97 | gpus = gpus # set `gpus = []` to use cpu 98 | batch_size = 64 99 | image_channel = 1 100 | nepochs = nepochs 101 | 102 | depth_G = depth_G 103 | depth_D = depth_D 104 | 105 | G_hprams = {"optimizer": "Adam", "lr_decay": 0.94, 106 | "decay_position": 2, "position_type": "epoch", 107 | "lr": lr, "weight_decay": 2e-5, 108 | "betas": (0.9, 0.99) 109 | } 110 | D_hprams = {"optimizer": "RMSprop", "lr_decay": 0.94, 111 | "decay_position": 2, "position_type": "epoch", 112 | "lr": lr, "weight_decay": 2e-5, 113 | "momentum": 0 114 | } 115 | 116 | # the input shape of the Generator 117 | latent_shape = latent_shape 118 | print('===> Build dataset') 119 | mnist = FashionMNIST(batch_size=batch_size) 120 | torch.backends.cudnn.benchmark = True 121 | print('===> Building model') 122 | D_net = Discriminator(input_nc=image_channel, depth=depth_D) 123 | D = Model(D_net, gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=10) 124 | # ----------------------------------- 125 | G_net = Generator(input_nc=latent_shape[0], output_nc=image_channel, depth=depth_G) 126 | G = Model(G_net, gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=10) 127 | print('===> Building optimizer') 128 | opt_D = Optimizer(D.parameters(), **D_hprams) 129 | opt_G = Optimizer(G.parameters(), **G_hprams) 130 | print('===> Training') 131 | print("Use `tensorboard --logdir=log` to see the learning curves and the net structure. " 132 | "Training and validation data, configuration info and checkpoints are saved in the `log` directory.") 133 | Trainer = FashionGenerateGanTrainer("log/fashion_generate", nepochs, gpus, G, D, opt_G, opt_D, mnist, 134 | latent_shape) 135 | if run_type == "train": 136 | Trainer.train() 137 | elif run_type == "debug": 138 | Trainer.debug() 139 | 140 | 141 | if __name__ == '__main__': 142 | start_fashionGenerateGanTrainer() 143 | -------------------------------------------------------------------------------- /jdit/trainer/abandon_classification.py: -------------------------------------------------------------------------------- 1 | # from torch.nn import CrossEntropyLoss 2 | from .super import * 3 | from abc import abstractmethod 4 | from tqdm import * 5 | 6 | 7 | class ClassificationTrainer(SupTrainer): 8 | """This is a classification trainer.
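    A hypothetical usage sketch (``net``, ``opt`` and ``datasets`` are assumed to be built as in the instance modules, and ``MyClassificationTrainer`` is a subclass implementing ``compute_loss()`` and ``compute_valid()``)::

        trainer = MyClassificationTrainer("log/classify", nepochs, gpus, net, opt, datasets, num_class=10)
        trainer.train()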
9 | 10 | """ 11 | 12 | def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets, num_class): 13 | super(ClassificationTrainer, self).__init__(nepochs, logdir, gpu_ids) 14 | self.net = net 15 | self.opt = opt 16 | self.datasets = datasets 17 | self.num_class = num_class 18 | self.labels = None 19 | self.output = None 20 | 21 | def train_epoch(self, subbar_disable=False): 22 | for iteration, batch in tqdm(enumerate(self.datasets.loader_train, 1), unit="step", disable=subbar_disable): 23 | self.step += 1 24 | 25 | self.input, self.labels = self.get_data_from_batch(batch, self.device) 26 | self.output = self.net(self.input) 27 | self._train_iteration(self.opt, self.compute_loss, csv_filename="Train") 28 | 29 | if iteration == 1: 30 | self._watch_images(show_imgs_num=3, tag="Train") 31 | 32 | @abstractmethod 33 | def compute_loss(self): 34 | """Compute the main loss and observed variables. 35 | 36 | Compute the loss and the other variables that you care about. 37 | You should return a main loss for doing backward propagation. 38 | 39 | The other variables will only be used for tensorboard scalar visualization. 40 | So, if you want some variables to be visualized, make a ``dict()`` whose keys are the variables' names. 41 | 42 | 43 | .. note:: 44 | 45 | Only the main loss, which is the first returned variable, will do backward propagation. 46 | If you have joint losses, please add them up and return one main loss. 47 | 48 | .. note:: 49 | 50 | None of the variables in the returned ``dict()`` will do backward propagation under ``model.train()``. 51 | However, gradients are still computed, since ``with torch.autograd.no_grad()`` is not applied. 52 | So, you can compute any gradient-based variables for visualization. 53 | 54 | Example:: 55 | 56 | var_dic = {} 57 | # visualize the value of CrossEntropyLoss. 58 | var_dic["CEP"] = loss = CrossEntropyLoss()(self.output, self.labels.squeeze().long()) 59 | 60 | _, predict = torch.max(self.output.detach(), 1) # 0100=>1 0010=>2 61 | total = predict.size(0) * 1.0 62 | labels = self.labels.squeeze().long() 63 | correct = predict.eq(labels).cpu().sum().float() 64 | acc = correct / total 65 | # visualize the value of accuracy. 66 | var_dic["ACC"] = acc 67 | # use CrossEntropyLoss as the main loss for backward, and return the ``dict`` to be visualized 68 | return loss, var_dic 69 | 70 | """ 71 | 72 | @abstractmethod 73 | def compute_valid(self): 74 | """Compute the valid_epoch variables for visualization. 75 | 76 | Compute the variables that you care about. 77 | They will only be used for tensorboard scalar visualization. 78 | So, if you want some variables to be visualized, make a ``dict()`` whose keys are the variables' names. 79 | 80 | .. note:: 81 | 82 | None of the variables in the returned ``dict()`` will do backward propagation under ``model.eval()``. 83 | However, gradients are still computed, since ``with torch.autograd.no_grad()`` is not applied. 84 | So, you can compute gradient-based variables for visualization.
85 | 86 | Example:: 87 | 88 | var_dic = {} 89 | # visualize the valid_epoch curve of CrossEntropyLoss 90 | var_dic["CEP"] = loss = CrossEntropyLoss()(self.output, self.labels.squeeze().long()) 91 | 92 | _, predict = torch.max(self.output.detach(), 1) # 0100=>1 0010=>2 93 | total = predict.size(0) * 1.0 94 | labels = self.labels.squeeze().long() 95 | correct = predict.eq(labels).cpu().sum().float() 96 | acc = correct / total 97 | # visualize the valid_epoch curve of accuracy 98 | var_dic["ACC"] = acc 99 | return var_dic 100 | 101 | """ 102 | 103 | def valid_epoch(self): 104 | avg_dic = dict() 105 | self.net.eval() 106 | for iteration, batch in enumerate(self.datasets.loader_valid, 1): 107 | self.input, self.labels = self.get_data_from_batch(batch, self.device) 108 | self.output = self.net(self.input).detach() 109 | dic = self.compute_valid() 110 | if avg_dic == {}: 111 | avg_dic = dic 112 | else: 113 | # sum up the values of each key 114 | for key in dic.keys(): 115 | avg_dic[key] += dic[key] 116 | 117 | for key in avg_dic.keys(): 118 | avg_dic[key] = avg_dic[key] / self.datasets.nsteps_valid 119 | 120 | self.watcher.scalars(var_dict=avg_dic, global_step=self.step, tag="Valid") 121 | self.loger.write(self.step, self.current_epoch, avg_dic, "Valid", header=self.current_epoch <= 1) 122 | self.net.train() 123 | 124 | def get_data_from_batch(self, batch_data, device): 125 | input_tensor, labels_tensor = batch_data[0], batch_data[1] 126 | # if use_onehot: 127 | # # label => onehot 128 | # y_onehot = torch.zeros(labels.size(0), self.num_class) 129 | # if labels.size() != (labels.size(0), self.num_class): 130 | # labels = labels.reshape((labels.size(0), 1)) 131 | # ground_truth_tensor = y_onehot.scatter_(1, labels, 1).long() # labels => [[],[],[]] batchsize, 132 | # num_class 133 | # labels_tensor = labels 134 | # else: 135 | # # onehot => label 136 | # ground_truth_tensor = labels 137 | # labels_tensor = torch.max(self.labels.detach(), 1) 138 | 139 | return input_tensor, labels_tensor 140 | 141 | def _watch_images(self, show_imgs_num=4, tag="Train"): 142 | pass 143 | 144 | @property 145 | def configure(self): 146 | config_dic = super(ClassificationTrainer, self).configure 147 | config_dic["num_class"] = self.num_class 148 | return config_dic 149 | -------------------------------------------------------------------------------- /jdit/trainer/instances/cifarPix2pixGan.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | from jdit.trainer import Pix2pixGanTrainer 5 | from jdit.model import Model 6 | from jdit.optimizer import Optimizer 7 | from jdit.dataset import Cifar10 8 | 9 | 10 | class Discriminator(nn.Module): 11 | def __init__(self, input_nc=3, output_nc=1, depth=64): 12 | super(Discriminator, self).__init__() 13 | # 32 x 32 14 | self.layer1 = nn.Sequential( 15 | nn.utils.spectral_norm(nn.Conv2d(input_nc, depth * 1, kernel_size=3, stride=1, padding=1)), 16 | nn.LeakyReLU(0.1)) 17 | # 32 x 32 18 | self.layer2 = nn.Sequential( 19 | nn.utils.spectral_norm(nn.Conv2d(depth * 1, depth * 2, kernel_size=3, stride=1, padding=1)), 20 | nn.LeakyReLU(0.1), 21 | nn.MaxPool2d(2, 2)) 22 | # 16 x 16 23 | self.layer3 = nn.Sequential( 24 | nn.utils.spectral_norm(nn.Conv2d(depth * 2, depth * 4, kernel_size=3, stride=1, padding=1)), 25 | nn.LeakyReLU(0.1), 26 | nn.MaxPool2d(2, 2)) 27 | # 8 x 8 28 | self.layer4 = nn.Sequential(nn.Conv2d(depth * 4, output_nc, kernel_size=8, stride=1, padding=0)) 29 | # 1 x 1 30 | 31 | def forward(self, x): 32 |
out = self.layer1(x) 33 | out = self.layer2(out) 34 | out = self.layer3(out) 35 | out = self.layer4(out) 36 | return out 37 | 38 | 39 | class Generator(nn.Module): 40 | def __init__(self, input_nc=1, output_nc=3, depth=32): 41 | super(Generator, self).__init__() 42 | 43 | self.latent_to_features = nn.Sequential( 44 | nn.Conv2d(input_nc, 1 * depth, 3, 1, 1), # 1,32,32 => d,32,32 45 | nn.ReLU(), 46 | nn.BatchNorm2d(1 * depth), 47 | nn.Conv2d(1 * depth, 2 * depth, 4, 2, 1), # d,32,32 => 2d,16,16 48 | nn.ReLU(), 49 | nn.BatchNorm2d(2 * depth), 50 | nn.Conv2d(2 * depth, 4 * depth, 4, 2, 1), # 2d,16,16 => 4d,8,8 51 | nn.ReLU(), 52 | nn.BatchNorm2d(4 * depth), 53 | nn.Conv2d(4 * depth, 4 * depth, 4, 2, 1), # 4d,8,8 => 4d,4,4 54 | nn.ReLU(), 55 | nn.BatchNorm2d(4 * depth) 56 | ) 57 | self.features_to_image = nn.Sequential( 58 | nn.ConvTranspose2d(4 * depth, 4 * depth, 4, 2, 1), # 4d,4,4 => 4d,8,8 59 | nn.ReLU(), 60 | nn.BatchNorm2d(4 * depth), 61 | nn.ConvTranspose2d(4 * depth, 2 * depth, 4, 2, 1), # 4d,8,8 => 2d,16,16 62 | nn.ReLU(), 63 | nn.BatchNorm2d(2 * depth), 64 | nn.ConvTranspose2d(2 * depth, depth, 4, 2, 1), # 2d,16,16 => d,32,32 65 | nn.ReLU(), 66 | nn.BatchNorm2d(depth), 67 | nn.ConvTranspose2d(depth, output_nc, 3, 1, 1), # d,32,32 => 3,32,32 68 | ) 69 | 70 | def forward(self, x): 71 | out = self.latent_to_features(x) 72 | out = self.features_to_image(out) 73 | return out 74 | 75 | 76 | class CifarPix2pixGanTrainer(Pix2pixGanTrainer): 77 | d_turn = 5 78 | 79 | def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets): 80 | super(CifarPix2pixGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, 81 | datasets) 82 | 83 | def get_data_from_batch(self, batch_data, device): 84 | ground_truth_cpu, _ = batch_data[0], batch_data[1] 85 | input_cpu = ground_truth_cpu[:, 0, :, :].unsqueeze(1) # only use one channel [?,3,32,32] =>[?,1,32,32] 86 | return input_cpu, ground_truth_cpu 87 | 88 | def compute_d_loss(self): 89 | d_fake = self.netD(self.fake.detach()) 90 | d_real = self.netD(self.ground_truth) 91 | var_dic = {} 92 | var_dic["LS_LOSSD"] = loss_d = 0.5 * (torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2)) 93 | return loss_d, var_dic 94 | 95 | def compute_g_loss(self): 96 | d_fake = self.netD(self.fake) 97 | var_dic = {} 98 | var_dic["LS_LOSSG"] = loss_g = 0.5 * torch.mean((d_fake - 1) ** 2) 99 | 100 | return loss_g, var_dic 101 | 102 | def compute_valid(self): 103 | g_loss, _ = self.compute_g_loss() 104 | d_loss, _ = self.compute_d_loss() 105 | mse = ((self.fake.detach() - self.ground_truth) ** 2).mean() 106 | var_dic = {"LOSS_D": d_loss, "LOSS_G": g_loss, "MSE": mse} 107 | return var_dic 108 | 109 | 110 | def start_cifarPix2pixGanTrainer(gpus=(), nepochs=200, lr=1e-3, depth_G=32, depth_D=32, run_type="train"): 111 | gpus = gpus # set `gpus = []` to use cpu 112 | batch_size = 32 113 | image_channel = 3 114 | nepochs = nepochs 115 | depth_G = depth_G 116 | depth_D = depth_D 117 | 118 | G_hprams = {"optimizer": "Adam", "lr_decay": 0.9, 119 | "decay_position": 10, "position_type": "epoch", 120 | "lr": lr, "weight_decay": 2e-5, 121 | "betas": (0.9, 0.99) 122 | } 123 | D_hprams = {"optimizer": "RMSprop", "lr_decay": 0.9, 124 | "decay_position": 10, "position_type": "epoch", 125 | "lr": lr, "weight_decay": 2e-5, 126 | "momentum": 0 127 | } 128 | 129 | print('===> Build dataset') 130 | cifar10 = Cifar10(root="datasets/cifar10", batch_size=batch_size) 131 | torch.backends.cudnn.benchmark = True 132 | print('===> Building model') 133 | D_net = 
Discriminator(input_nc=image_channel, depth=depth_D) 134 | D = Model(D_net, gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=50) 135 | # ----------------------------------- 136 | G_net = Generator(input_nc=1, output_nc=image_channel, depth=depth_G) 137 | G = Model(G_net, gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=50) 138 | print('===> Building optimizer') 139 | opt_D = Optimizer(D.parameters(), **D_hprams) 140 | opt_G = Optimizer(G.parameters(), **G_hprams) 141 | print('===> Training') 142 | Trainer = CifarPix2pixGanTrainer("log/cifar_p2p", nepochs, gpus, G, D, opt_G, opt_D, cifar10) 143 | if run_type == "train": 144 | Trainer.train() 145 | elif run_type == "debug": 146 | Trainer.debug() 147 | 148 | 149 | if __name__ == '__main__': 150 | start_cifarPix2pixGanTrainer() 151 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | 16 | import os 17 | import sys 18 | import mock 19 | 20 | MOCK_MODULES = ['numpy', 'scipy', 21 | 'pandas', 22 | 'matplotlib', 'matplotlib.pyplot', 23 | 'scipy.interpolate', 24 | # 'tensorboardX','tensorboardX.SummaryWriter', 25 | 'tensorboard', 26 | 'torch', 27 | 'torch.optim', 28 | 'torch.nn', 29 | 'torch.nn.functional', 'torch.nn.utils', "torch.nn.parallel", 30 | 'torch.utils', 31 | 'torch.utils.model_zoo', 'torch.utils.tensorboard', 32 | 'torch.utils.data', "torch.utils.data.distributed", 33 | 'torch.autograd', 'torch.autograd.Variable', 34 | 'torchvision', 35 | 'torchvision.transforms', 'torchvision.utils', 36 | 'torchvision.transforms.Resize', 'torchvision.transforms.ToTensor', 37 | 'torchvision.transforms.Normalize', 38 | 'torchvision.utils.make_grid', 'torchvision.datasets', 39 | 'torchvision.datasets.MNIST', 'torchvision.datasets.FashionMNIST', 40 | 'torchvision.datasets.CIFAR10', 'torchvision.datasets.LSUN', 41 | ] 42 | 43 | for mod_name in MOCK_MODULES: 44 | sys.modules[mod_name] = mock.Mock() 45 | 46 | sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) 47 | # sys.path.insert(0, os.path.abspath('../../')) 48 | # import jdit 49 | 50 | 51 | # -- Project information ----------------------------------------------------- 52 | project = 'jdit' 53 | copyright = '2018, dingguanglei' 54 | author = 'dingguanglei' 55 | 56 | # The short X.Y version 57 | version = '' 58 | # The full version, including alpha/beta/rc tags 59 | release = '0.1.5' 60 | 61 | # -- General configuration --------------------------------------------------- 62 | 63 | # If your documentation needs a minimal Sphinx version, state it here. 64 | # 65 | # needs_sphinx = '1.0' 66 | 67 | # Add any Sphinx extension module names here, as strings. They can be 68 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 69 | # ones. 
70 | extensions = [ 71 | 'sphinx.ext.autodoc', 72 | 'sphinx.ext.autosummary', 73 | 'sphinx.ext.doctest', 74 | 'sphinx.ext.intersphinx', 75 | 'sphinx.ext.viewcode', 76 | ] 77 | 78 | # Add any paths that contain templates here, relative to this directory. 79 | templates_path = ['ntemplates'] 80 | 81 | # The suffix(es) of source filenames. 82 | # You can specify multiple suffix as a list of string: 83 | # 84 | # source_suffix = ['.rst', '.md'] 85 | source_suffix = '.rst' 86 | 87 | # The master toctree document. 88 | master_doc = 'index' 89 | 90 | # The language for content autogenerated by Sphinx. Refer to documentation 91 | # for a list of supported languages. 92 | # 93 | # This is also used if you do content translation via gettext catalogs. 94 | # Usually you set "language" from the command line for these cases. 95 | language = None 96 | 97 | # List of patterns, relative to source directory, that match files and 98 | # directories to ignore when looking for source files. 99 | # This pattern also affects html_static_path and html_extra_path . 100 | exclude_patterns = [] 101 | 102 | # The name of the Pygments (syntax highlighting) style to use. 103 | pygments_style = 'sphinx' 104 | 105 | # -- Options for HTML output ------------------------------------------------- 106 | 107 | # The theme to use for HTML and HTML Help pages. See the documentation for 108 | # a list of builtin themes. 109 | # sphinx_rtd_theme,alabaster 110 | import sphinx_rtd_theme 111 | 112 | html_theme = 'sphinx_rtd_theme' 113 | 114 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 115 | # Theme options are theme-specific and customize the look and feel of a theme 116 | # further. For a list of options available for each theme, see the 117 | # documentation. 118 | # 119 | html_theme_options = {} 120 | 121 | # Add any paths that contain custom static files (such as style sheets) here, 122 | # relative to this directory. They are copied after the builtin static files, 123 | # so a file named "default.css" will overwrite the builtin "default.css". 124 | # sphinx_rtd_theme ,nstatic 125 | html_static_path = ['nstatic'] 126 | 127 | # Custom sidebar templates, must be a dictionary that maps document names 128 | # to template names. 129 | # 130 | # The default sidebars (for documents that don't match any pattern) are 131 | # defined by theme itself. Builtin themes are using these templates by 132 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 133 | # 'searchbox.html']``. 134 | # 135 | # html_sidebars = {} 136 | 137 | 138 | # -- Options for HTMLHelp output --------------------------------------------- 139 | 140 | # Output file base name for HTML help builder. 141 | htmlhelp_basename = 'jditdoc' 142 | 143 | # -- Options for LaTeX output ------------------------------------------------ 144 | 145 | latex_elements = { 146 | # The paper size ('letterpaper' or 'a4paper'). 147 | # 148 | # 'papersize': 'letterpaper', 149 | 150 | # The font size ('10pt', '11pt' or '12pt'). 151 | # 152 | # 'pointsize': '10pt', 153 | 154 | # Additional stuff for the LaTeX preamble. 155 | # 156 | # 'preamble': '', 157 | 158 | # Latex figure (float) alignment 159 | # 160 | # 'figure_align': 'htbp', 161 | } 162 | 163 | # Grouping the document tree into LaTeX files. List of tuples 164 | # (source start file, target name, title, 165 | # author, documentclass [howto, manual, or own class]). 
166 | latex_documents = [ 167 | (master_doc, 'jdit.tex', 'jdit Documentation', 168 | 'dingguanglei', 'manual'), 169 | ] 170 | 171 | # -- Options for manual page output ------------------------------------------ 172 | 173 | # One entry per manual page. List of tuples 174 | # (source start file, name, description, authors, manual section). 175 | man_pages = [ 176 | (master_doc, 'jdit', 'jdit Documentation', 177 | [author], 1) 178 | ] 179 | 180 | # -- Options for Texinfo output ---------------------------------------------- 181 | 182 | # Grouping the document tree into Texinfo files. List of tuples 183 | # (source start file, target name, title, author, 184 | # dir menu entry, description, category) 185 | texinfo_documents = [ 186 | (master_doc, 'jdit', 'jdit Documentation', 187 | author, 'jdit', 'One line description of project.', 188 | 'Miscellaneous'), 189 | ] 190 | 191 | # -- Extension configuration ------------------------------------------------- 192 | -------------------------------------------------------------------------------- /jdit/perspective.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union, Dict, Tuple 3 | from jdit.model import Model 4 | from jdit.trainer.super import Watcher 5 | from torch.nn import Module 6 | import torch 7 | from torch import Tensor 8 | 9 | 10 | class FeatureVisualization(object): 11 | """ Visualize the activations of a model 12 | 13 | This class can visualize the activations of each layer of a model. 14 | :param model: model network. It can be ``jdit.Model`` or ``torch.nn.Module`` 15 | :param logdir: The log directory. Point tensorboard's ``--logdir=`` at this directory to view the results. 16 | """ 17 | 18 | def __init__(self, model: Union[Model, Module], logdir: str): 19 | model.eval() 20 | self.model = model 21 | self.activations: Dict[str, Tensor] = {} 22 | self.watcher = Watcher(logdir=logdir) 23 | 24 | def _hook(self, name: str): 25 | def hook(model, input, output): 26 | self.activations[name] = output.detach() 27 | 28 | return hook 29 | 30 | @property 31 | def layers_names(self): 32 | """Get the names of the model's layers. 33 | 34 | :return: 35 | """ 36 | layer_names = [] 37 | for layer in self.model.named_modules(): 38 | layer_names.append(layer[0]) 39 | return layer_names 40 | 41 | def _register_forward_hook(self, block_name: str): 42 | """Set a forward hook in the model. 43 | 44 | :param block_name: block name of the model 45 | :return: 46 | """ 47 | if not hasattr(self.model, block_name): 48 | raise AttributeError( 49 | "model doesn't have `%s` layer. These layers are available: %s" % ( 50 | block_name, str.join("\n", self.layers_names))) 51 | 52 | getattr(self.model, block_name).register_forward_hook(self._hook(block_name)) 53 | 54 | def trace_activation(self, input_tensor: Union[Tensor, Tuple[Tensor]], block_names: list, 55 | use_gpu=False) -> \ 56 | Dict[str, Tensor]: 57 | """Trace the activations of the model during a forward propagation. 58 | 59 | :param input_tensor: Input of the model. Tensor or tuple of Tensors 60 | :param block_names: A list of layer key names which you want to visualize. 61 | :param use_gpu: If use GPU. Default: False 62 | :return: The activations dict with block_names as the keys.
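        A minimal sketch, mirroring the ``CL`` demo model at the bottom of this file::

            fv = FeatureVisualization(CL(), "incep_test")
            activations = fv.trace_activation(torch.randn(1, 1, 10, 10), ["unet_block_1"])
            print(activations["unet_block_1"].shape)  # torch.Size([1, 2, 10, 10])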
63 | """ 64 | if len(input_tensor.shape) == 4 and input_tensor.shape[0] != 1: 65 | raise ValueError( 66 | "You can only pass one sample to do feature visualization, but %d was given" % 67 | input_tensor.shape[0]) 68 | 69 | for name in block_names: 70 | self._register_forward_hook(name) 71 | with torch.no_grad(): 72 | if use_gpu: 73 | self.model = self.model.cuda() 74 | if isinstance(input_tensor, Tensor): 75 | input_tensor = input_tensor.cuda() 76 | self.model(input_tensor) 77 | elif isinstance(input_tensor, (tuple, list)): 78 | input_tensor = [item.cuda() for item in input_tensor] 79 | self.model(*input_tensor) 80 | else: 81 | raise TypeError( 82 | "`input_tensor` should be a Tensor or a tuple of Tensors, but %s was given" % type( 83 | input_tensor)) 84 | else: 85 | if isinstance(input_tensor, Tensor): 86 | self.model(input_tensor) 87 | elif isinstance(input_tensor, (tuple, list)): 88 | self.model(*input_tensor) 89 | else: 90 | raise TypeError( 91 | "`input_tensor` should be a Tensor or a tuple of Tensors, but %s was given" % type( 92 | input_tensor)) 93 | 94 | return self.activations 95 | 96 | def show_feature_trace(self, input_tensor: Union[Tensor, Tuple[Tensor]], block_names: list, global_step, 97 | grid_size=(4, 4), show_struct=True, use_gpu=False, shuffle=False): 98 | """Trace the activations of the model during a forward propagation and show them in tensorboard. 99 | 100 | :param input_tensor: Input of the model. Tensor or tuple of Tensors 101 | :param block_names: A list of layer key names which you want to visualize. 102 | :param global_step: step flag. 103 | :param grid_size: Amount and size of the activation images which you want to show in tensorboard. 104 | :param show_struct: If show the structure of the model. Default: True 105 | :param use_gpu: If use GPU. Default: False 106 | :param shuffle: If shuffle the activation channels. Default: False 107 | :return: 108 | 109 | Example: 110 | 111 | .. 
code-block:: python 112 | 113 | from torch.nn import Conv2d 114 | class CL(Module): 115 | def __init__(self): 116 | super(CL, self).__init__() 117 | self.layer_1 = Conv2d(1, 2, 3, 1, 1) 118 | self.layer_2 = Conv2d(2, 4, 3, 1, 1) 119 | self.layer_3 = Conv2d(4, 8, 3, 1, 1) 120 | self.layer_4 = Conv2d(8, 16, 3, 1, 1) 121 | def forward(self, input): 122 | out = self.layer_1(input) 123 | out = self.layer_2(out) 124 | out = self.layer_3(out) 125 | out = self.layer_4(out) 126 | return out 127 | 128 | model = CL() 129 | fv = FeatureVisualization(model, "incep_test") 130 | input = torch.randn(1, 1, 10, 10) 131 | fv.show_feature_trace(input, ["layer_1", "layer_2", "layer_3"], 1) 132 | fv.watcher.close() 133 | 134 | """ 135 | self.trace_activation(input_tensor, block_names, use_gpu) 136 | if show_struct: 137 | self.watcher.graph(self.model, self.model.__class__.__name__, use_gpu, input_tensor.size()) 138 | 139 | for layer_name, tensor in self.activations.items(): 140 | if len(tensor.shape) == 4: 141 | tensor = tensor[0] # 1,channel,H,W =>channel,H,W 142 | tensor = tensor.unsqueeze(1) # channel,1,H,W 143 | 144 | self.watcher.image(tensor, grid_size=grid_size, tag=layer_name, global_step=global_step, shuffle=shuffle) 145 | 146 | @staticmethod 147 | def _build_dir(dirs: str): 148 | if not os.path.exists(dirs): 149 | os.makedirs(dirs) 150 | 151 | 152 | if __name__ == '__main__': 153 | from torch.nn import Conv2d 154 | 155 | 156 | class CL(Module): 157 | def __init__(self): 158 | super(CL, self).__init__() 159 | self.unet_block_1 = Conv2d(1, 2, 3, 1, 1) 160 | self.unet_block_2 = Conv2d(2, 4, 3, 1, 1) 161 | self.unet_block_3 = Conv2d(4, 8, 3, 1, 1) 162 | self.unet_block_4 = Conv2d(8, 16, 3, 1, 1) 163 | 164 | def forward(self, input): 165 | out = self.unet_block_1(input) 166 | out = self.unet_block_2(out) 167 | out = self.unet_block_3(out) 168 | out = self.unet_block_4(out) 169 | return out 170 | 171 | 172 | _model = CL() 173 | fv = FeatureVisualization(_model, "incep_test") 174 | _input = torch.randn(1, 1, 10, 10) 175 | fv.show_feature_trace(_input, ["unet_block_1", "unet_block_2", "unet_block_3"], 1) 176 | fv.watcher.close() 177 | -------------------------------------------------------------------------------- /jdit/trainer/gan/pix2pix.py: -------------------------------------------------------------------------------- 1 | from .sup_gan import SupGanTrainer 2 | from abc import abstractmethod 3 | import torch 4 | 5 | 6 | class Pix2pixGanTrainer(SupGanTrainer): 7 | d_turn = 1 8 | 9 | def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets): 10 | """ A pixel-to-pixel GAN trainer 11 | 12 | :param logdir: Path of the log directory 13 | :param nepochs: Number of epochs. 14 | :param gpu_ids_abs: The ids of the GPUs to be used. To use the CPU, set ``[]``. 15 | :param netG: Generator model. 16 | :param netD: Discriminator model. 17 | :param optG: Optimizer of the Generator. 18 | :param optD: Optimizer of the Discriminator. 19 | :param datasets: Datasets.
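        Example (a minimal sketch; ``G``, ``D``, ``opt_G``, ``opt_D`` and ``cifar10`` are assumed to be built as in ``jdit/trainer/instances/cifarPix2pixGan.py``)::

            trainer = CifarPix2pixGanTrainer("log/cifar_p2p", nepochs, gpus, G, D, opt_G, opt_D, cifar10)
            trainer.train()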
20 | """ 21 | super(Pix2pixGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets) 22 | 23 | def get_data_from_batch(self, batch_data: list, device: torch.device): 24 | input_tensor, ground_truth_tensor = batch_data[0], batch_data[1] 25 | return input_tensor, ground_truth_tensor 26 | 27 | def _watch_images(self, tag, grid_size=(3, 3), shuffle=False, save_file=True): 28 | self.watcher.image(self.input, 29 | self.current_epoch, 30 | tag="%s/input" % tag, 31 | grid_size=grid_size, 32 | shuffle=shuffle, 33 | save_file=save_file) 34 | self.watcher.image(self.fake, 35 | self.current_epoch, 36 | tag="%s/fake" % tag, 37 | grid_size=grid_size, 38 | shuffle=shuffle, 39 | save_file=save_file) 40 | self.watcher.image(self.ground_truth, 41 | self.current_epoch, 42 | tag="%s/real" % tag, 43 | grid_size=grid_size, 44 | shuffle=shuffle, 45 | save_file=save_file) 46 | 47 | @abstractmethod 48 | def compute_d_loss(self): 49 | """ Rewrite this method to compute your own loss for the Discriminator. 50 | 51 | You should return a **loss** in the first position. 52 | You can return a ``dict`` of losses that you want to visualize in the second position, like the example below. 53 | The training logic is:: 54 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 55 | self.fake = self.netG(self.input) 56 | self._train_iteration(self.optD, self.compute_d_loss, csv_filename="Train_D") 57 | if (self.step % self.d_turn) == 0: 58 | self._train_iteration(self.optG, self.compute_g_loss, csv_filename="Train_G") 59 | So, you can use ``self.input``, ``self.ground_truth``, ``self.fake``, ``self.netD`` and ``self.optD`` to compute the loss. 60 | Example:: 61 | 62 | d_fake = self.netD(self.fake.detach()) 63 | d_real = self.netD(self.ground_truth) 64 | var_dic = {} 65 | var_dic["LS_LOSSD"] = loss_d = 0.5 * (torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2)) 66 | return loss_d, var_dic 67 | 68 | """ 69 | loss_d = None 70 | var_dic = {} 71 | 72 | return loss_d, var_dic 73 | 74 | @abstractmethod 75 | def compute_g_loss(self): 76 | """Rewrite this method to compute your own loss for the Generator. 77 | 78 | You should return a **loss** in the first position. 79 | You can return a ``dict`` of losses that you want to visualize in the second position, like the example below. 80 | The training logic is:: 81 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 82 | self.fake = self.netG(self.input) 83 | self._train_iteration(self.optD, self.compute_d_loss, csv_filename="Train_D") 84 | if (self.step % self.d_turn) == 0: 85 | self._train_iteration(self.optG, self.compute_g_loss, csv_filename="Train_G") 86 | So, you can use ``self.input``, ``self.ground_truth``, ``self.fake``, ``self.netD`` and ``self.optG`` to compute the loss. 87 | Example:: 88 | 89 | d_fake = self.netD(self.fake, self.input) 90 | var_dic = {} 91 | var_dic["LS_LOSSG"] = loss_g = 0.5 * torch.mean((d_fake - 1) ** 2) 92 | return loss_g, var_dic 93 | 94 | """ 95 | loss_g = None 96 | var_dic = {} 97 | return loss_g, var_dic 98 | 99 | @abstractmethod 100 | def compute_valid(self): 101 | """Rewrite this method to compute valid_epoch values. 102 | 103 | You can return a ``dict`` of values that you want to visualize. 104 | 105 | .. note:: 106 | 107 | This method runs under ``torch.no_grad():``. So, it will never compute grads. 108 | If you want to compute grads, please use ``torch.enable_grad():`` to wrap your operations.
109 | 110 | Example:: 111 | 112 | d_fake = self.netD(self.fake.detach()) 113 | d_real = self.netD(self.ground_truth) 114 | var_dic = {} 115 | var_dic["WD"] = w_distance = (d_real.mean() - d_fake.mean()).detach() 116 | return var_dic 117 | 118 | """ 119 | _, g_var_dic = self.compute_g_loss() 120 | _, d_var_dic = self.compute_d_loss() 121 | var_dic = dict(d_var_dic, **g_var_dic) 122 | return var_dic 123 | 124 | def valid_epoch(self): 125 | super(Pix2pixGanTrainer, self).valid_epoch() 126 | self.netG.eval() 127 | self.netD.eval() 128 | if self.fixed_input is None: 129 | for batch in self.datasets.loader_test: 130 | if isinstance(batch, (list, tuple)): 131 | self.fixed_input, fixed_ground_truth = self.get_data_from_batch(batch, self.device) 132 | self.watcher.image(fixed_ground_truth, self.current_epoch, tag="Fixed/groundtruth", 133 | grid_size=(6, 6), 134 | shuffle=False) 135 | else: 136 | self.fixed_input = batch.to(self.device) 137 | self.watcher.image(self.fixed_input, self.current_epoch, tag="Fixed/input", 138 | grid_size=(6, 6), 139 | shuffle=False) 140 | break 141 | 142 | # watching the variation during training by a fixed input 143 | with torch.no_grad(): 144 | fake = self.netG(self.fixed_input).detach() 145 | self.watcher.image(fake, self.current_epoch, tag="Fixed/fake", grid_size=(6, 6), shuffle=False) 146 | 147 | # saving training processes to build a .gif. 148 | self.watcher.set_training_progress_images(fake, grid_size=(6, 6)) 149 | 150 | self.netG.train() 151 | self.netD.train() 152 | 153 | def test(self): 154 | """ Test your model when you finish all epochs. 155 | 156 | This method will be called when all epochs finish. 157 | 158 | Example:: 159 | 160 | for index, batch in enumerate(self.datasets.loader_test, 1): 161 | # For test, there is only an input without a ground truth 162 | input = batch.to(self.device) 163 | self.netG.eval() 164 | with torch.no_grad(): 165 | fake = self.netG(input) 166 | self.watcher.image(fake, self.current_epoch, tag="Test/fake", grid_size=(4, 4), shuffle=False) 167 | self.netG.train() 168 | """ 169 | for batch in self.datasets.loader_test: 170 | self.input, _ = self.get_data_from_batch(batch, self.device) 171 | self.netG.eval() 172 | with torch.no_grad(): 173 | fake = self.netG(self.input).detach() 174 | self.watcher.image(fake, self.current_epoch, tag="Test/fake", grid_size=(7, 7), shuffle=False) 175 | self.netG.train() 176 | -------------------------------------------------------------------------------- /jdit/trainer/single/sup_single.py: -------------------------------------------------------------------------------- 1 | from ..super import SupTrainer 2 | from tqdm import tqdm 3 | import torch 4 | from jdit.optimizer import Optimizer 5 | from jdit.model import Model 6 | from jdit.dataset import DataLoadersFactory 7 | 8 | 9 | class SupSingleModelTrainer(SupTrainer): 10 | """ This is a Single Model Trainer. 11 | It means you only have one model, whose data flow is::
12 | input, ground_truth 13 | output = model(input) 14 | loss(output, ground_truth) 15 | 16 | """ 17 | 18 | def __init__(self, logdir, nepochs, gpu_ids_abs, net: Model, opt: Optimizer, datasets: DataLoadersFactory): 19 | 20 | super(SupSingleModelTrainer, self).__init__(nepochs, logdir, gpu_ids_abs=gpu_ids_abs) 21 | self.net = net 22 | self.opt = opt 23 | self.datasets = datasets 24 | self.fixed_input = None 25 | self.input = None 26 | self.output = None 27 | self.ground_truth = None 28 | 29 | def train_epoch(self, subbar_disable=False): 30 | for iteration, batch in tqdm(enumerate(self.datasets.loader_train, 1), unit="step", disable=subbar_disable): 31 | self.step += 1 32 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 33 | self.output = self.net(self.input) 34 | self._train_iteration(self.opt, self.compute_loss, csv_filename="Train") 35 | if iteration == 1: 36 | self._watch_images("Train") 37 | 38 | def get_data_from_batch(self, batch_data: list, device: torch.device): 39 | """ Load and wrap data from the data loader. 40 | 41 | Split one batch of data into specific variables. 42 | 43 | Example:: 44 | 45 | # batch_data like this [input_Data, ground_truth_Data] 46 | input_cpu, ground_truth_cpu = batch_data[0], batch_data[1] 47 | # then move them to device and return them 48 | return input_cpu.to(self.device), ground_truth_cpu.to(self.device) 49 | 50 | :param batch_data: one batch of data loaded from ``DataLoader`` 51 | :param device: A device variable. ``torch.device`` 52 | :return: input Tensor, ground_truth Tensor 53 | """ 54 | input_tensor, ground_truth_tensor = batch_data[0], batch_data[1] 55 | return input_tensor, ground_truth_tensor 56 | 57 | def _watch_images(self, tag: str, grid_size: tuple = (3, 3), shuffle=False, save_file=True): 58 | """ Show images in tensorboard 59 | 60 | To show images in tensorboard. If you want to show a fixed input and its output, 61 | please use ``shuffle=False`` to fix the visualized data. 62 | Otherwise, it will sample and visualize the data randomly. 63 | 64 | Example:: 65 | 66 | # show the output 67 | self.watcher.image(self.output, 68 | self.current_epoch, 69 | tag="%s/output" % tag, 70 | grid_size=grid_size, 71 | shuffle=shuffle, 72 | save_file=save_file) 73 | 74 | # show the ground_truth 75 | self.watcher.image(self.ground_truth, 76 | self.current_epoch, 77 | tag="%s/ground_truth" % tag, 78 | grid_size=grid_size, 79 | shuffle=shuffle, 80 | save_file=save_file) 81 | 82 | # show the input 83 | self.watcher.image(self.input, 84 | self.current_epoch, 85 | tag="%s/input" % tag, 86 | grid_size=grid_size, 87 | shuffle=shuffle, 88 | save_file=save_file) 89 | 90 | 91 | :param tag: tensorboard tag 92 | :param grid_size: A tuple for the grid size of the data you want to visualize 93 | :param shuffle: Whether to shuffle the data. 94 | :param save_file: Whether to save these images. 95 | :return: 96 | """ 97 | self.watcher.image(self.output, 98 | self.current_epoch, 99 | tag="%s/output" % tag, 100 | grid_size=grid_size, 101 | shuffle=shuffle, 102 | save_file=save_file) 103 | self.watcher.image(self.ground_truth, 104 | self.current_epoch, 105 | tag="%s/ground_truth" % tag, 106 | grid_size=grid_size, 107 | shuffle=shuffle, 108 | save_file=save_file) 109 | 110 | def compute_loss(self) -> (torch.Tensor, dict): 111 | """ Rewrite this method to compute your own loss. 112 | Use self.input, self.output and self.ground_truth to compute the loss. 113 | You should return a **loss** in the first position.
114 | You can return a ``dict`` of losses that you want to visualize in the second position, like the example below. 115 | 116 | Example:: 117 | 118 | var_dic = {} 119 | var_dic["LOSS"] = loss = (self.output ** 2 - self.ground_truth ** 2) ** 0.5 120 | return loss, var_dic 121 | 122 | """ 123 | loss: torch.Tensor 124 | var_dic = {} 125 | 126 | return loss, var_dic 127 | 128 | def compute_valid(self) -> dict: 129 | """ Rewrite this method to compute your validation values. 130 | Use self.input, self.output and self.ground_truth to compute the validation values. 131 | You can return a ``dict`` of validation values that you want to visualize. 132 | 133 | Example:: 134 | 135 | # It will do the same thing as ``compute_loss()`` 136 | _, var_dic = self.compute_loss() 137 | return var_dic 138 | 139 | """ 140 | # It will do the same thing as ``compute_loss()`` 141 | _, var_dic = self.compute_loss() 142 | return var_dic 143 | 144 | def valid_epoch(self): 145 | """Validate the model each epoch. 146 | 147 | It will be called at the end of each epoch, when training finishes. 148 | So, do your verification here. 149 | 150 | Example:: 151 | 152 | avg_dic: dict = {} 153 | self.net.eval() 154 | for iteration, batch in enumerate(self.datasets.loader_valid, 1): 155 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 156 | with torch.no_grad(): 157 | self.output = self.net(self.input) 158 | dic: dict = self.compute_valid() 159 | if avg_dic == {}: 160 | avg_dic: dict = dic 161 | else: 162 | for key in dic.keys(): 163 | avg_dic[key] += dic[key] 164 | 165 | for key in avg_dic.keys(): 166 | avg_dic[key] = avg_dic[key] / self.datasets.nsteps_valid 167 | 168 | self.watcher.scalars(avg_dic, self.step, tag="Valid") 169 | self.loger.write(self.step, self.current_epoch, avg_dic, "Valid", header=self.step <= 1) 170 | self._watch_images(tag="Valid") 171 | self.net.train() 172 | 173 | """ 174 | avg_dic: dict = {} 175 | self.net.eval() 176 | for iteration, batch in enumerate(self.datasets.loader_valid, 1): 177 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 178 | with torch.no_grad(): 179 | self.output = self.net(self.input) 180 | dic: dict = self.compute_valid() 181 | if avg_dic == {}: 182 | avg_dic: dict = dic 183 | else: 184 | # sum up the values of each key 185 | for key in dic.keys(): 186 | avg_dic[key] += dic[key] 187 | 188 | for key in avg_dic.keys(): 189 | avg_dic[key] = avg_dic[key] / self.datasets.nsteps_valid 190 | 191 | self.watcher.scalars(avg_dic, self.step, tag="Valid") 192 | self.loger.write(self.step, self.current_epoch, avg_dic, "Valid", header=self.current_epoch <= 1) 193 | self._watch_images(tag="Valid") 194 | self.net.train() 195 | 196 | def test(self): 197 | pass 198 | -------------------------------------------------------------------------------- /jdit/trainer/gan/sup_gan.py: -------------------------------------------------------------------------------- 1 | from ..super import SupTrainer 2 | from tqdm import tqdm 3 | import torch 4 | from jdit.optimizer import Optimizer 5 | from jdit.model import Model 6 | from jdit.dataset import DataLoadersFactory 7 | 8 | 9 | class SupGanTrainer(SupTrainer): 10 | d_turn = 1 11 | """The number of times the Discriminator is trained for each Generator training step.
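    For example, with ``d_turn = 5`` (as in ``CifarPix2pixGanTrainer`` above), the Discriminator is trained at every step, while the Generator is only trained when ``self.step % self.d_turn == 0`` (see ``train_epoch`` below)::

        class MyGanTrainer(SupGanTrainer):
            d_turn = 5  # 5 Discriminator updates per Generator update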
12 | """ 13 | 14 | def __init__(self, logdir, nepochs, gpu_ids_abs, netG: Model, netD: Model, optG: Optimizer, optD: Optimizer, 15 | datasets: DataLoadersFactory): 16 | 17 | super(SupGanTrainer, self).__init__(nepochs, logdir, gpu_ids_abs=gpu_ids_abs) 18 | self.netG = netG 19 | self.netD = netD 20 | self.optG = optG 21 | self.optD = optD 22 | self.datasets = datasets 23 | self.fake = None 24 | self.fixed_input = None 25 | 26 | def train_epoch(self, subbar_disable=False): 27 | for iteration, batch in tqdm(enumerate(self.datasets.loader_train, 1), unit="step", disable=subbar_disable): 28 | self.step += 1 29 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 30 | self.fake = self.netG(self.input) 31 | self._train_iteration(self.optD, self.compute_d_loss, csv_filename="Train_D") 32 | if (self.step % self.d_turn) == 0: 33 | self._train_iteration(self.optG, self.compute_g_loss, csv_filename="Train_G") 34 | if iteration == 1: 35 | self._watch_images("Train") 36 | 37 | def get_data_from_batch(self, batch_data: list, device: torch.device): 38 | """ Load and wrap data from the data loader. 39 | 40 | Split one batch of data into specific variables. 41 | 42 | Example:: 43 | 44 | # batch_data like this [input_Data, ground_truth_Data] 45 | input_cpu, ground_truth_cpu = batch_data[0], batch_data[1] 46 | # then move them to device and return them 47 | return input_cpu.to(self.device), ground_truth_cpu.to(self.device) 48 | 49 | :param batch_data: one batch of data loaded from ``DataLoader`` 50 | :param device: A device variable. ``torch.device`` 51 | :return: input Tensor, ground_truth Tensor 52 | """ 53 | input_tensor, ground_truth_tensor = batch_data[0], batch_data[1] 54 | return input_tensor, ground_truth_tensor 55 | 56 | def _watch_images(self, tag: str, grid_size: tuple = (3, 3), shuffle=False, save_file=True): 57 | """ Show images in tensorboard 58 | 59 | To show images in tensorboard. If you want to show a fixed input and its output, 60 | please use ``shuffle=False`` to fix the visualized data. 61 | Otherwise, it will sample and visualize the data randomly. 62 | 63 | Example:: 64 | 65 | # show fake data 66 | self.watcher.image(self.fake, 67 | self.current_epoch, 68 | tag="%s/fake" % tag, 69 | grid_size=grid_size, 70 | shuffle=shuffle, 71 | save_file=save_file) 72 | 73 | # show ground_truth 74 | self.watcher.image(self.ground_truth, 75 | self.current_epoch, 76 | tag="%s/real" % tag, 77 | grid_size=grid_size, 78 | shuffle=shuffle, 79 | save_file=save_file) 80 | 81 | # show input 82 | self.watcher.image(self.input, 83 | self.current_epoch, 84 | tag="%s/input" % tag, 85 | grid_size=grid_size, 86 | shuffle=shuffle, 87 | save_file=save_file) 88 | 89 | 90 | :param tag: tensorboard tag 91 | :param grid_size: A tuple for the grid size of the data you want to visualize 92 | :param shuffle: Whether to shuffle the data. 93 | :param save_file: Whether to save these images. 94 | :return: 95 | """ 96 | self.watcher.image(self.fake, 97 | self.current_epoch, 98 | tag="%s/fake" % tag, 99 | grid_size=grid_size, 100 | shuffle=shuffle, 101 | save_file=save_file) 102 | self.watcher.image(self.ground_truth, 103 | self.current_epoch, 104 | tag="%s/real" % tag, 105 | grid_size=grid_size, 106 | shuffle=shuffle, 107 | save_file=save_file) 108 | 109 | def compute_d_loss(self) -> (torch.Tensor, dict): 110 | """ Rewrite this method to compute your own loss for the Discriminator. 111 | 112 | You should return a **loss** in the first position.
113 | You can return a ``dict`` of losses that you want to visualize in the second position, like the example below. 114 | 115 | Example:: 116 | 117 | d_fake = self.netD(self.fake.detach()) 118 | d_real = self.netD(self.ground_truth) 119 | var_dic = {} 120 | var_dic["GP"] = gp = gradPenalty(self.netD, self.ground_truth, self.fake, input=self.input, 121 | use_gpu=self.use_gpu) 122 | var_dic["WD"] = w_distance = (d_real.mean() - d_fake.mean()).detach() 123 | var_dic["LOSS_D"] = loss_d = d_fake.mean() - d_real.mean() + gp 124 | return loss_d, var_dic 125 | 126 | """ 127 | loss_d: torch.Tensor 128 | var_dic = {} 129 | 130 | return loss_d, var_dic 131 | 132 | def compute_g_loss(self) -> (torch.Tensor, dict): 133 | """Rewrite this method to compute your own loss for the Generator. 134 | 135 | You should return a **loss** in the first position. 136 | You can return a ``dict`` of losses that you want to visualize in the second position, like the example below. 137 | 138 | Example:: 139 | 140 | d_fake = self.netD(self.fake) 141 | var_dic = {} 142 | var_dic["JC"] = jc = jcbClamp(self.netG, self.input, use_gpu=self.use_gpu) 143 | var_dic["LOSS_G"] = loss_g = -d_fake.mean() + jc 144 | return loss_g, var_dic 145 | 146 | """ 147 | loss_g: torch.Tensor 148 | var_dic = {} 149 | return loss_g, var_dic 150 | 151 | def compute_valid(self) -> dict: 152 | """ Rewrite this method to compute your validation values. 153 | 154 | You can return a ``dict`` of validation values that you want to visualize. 155 | 156 | Example:: 157 | 158 | # It will do the same thing as ``compute_g_loss()`` and ``self.compute_d_loss()`` 159 | g_loss, _ = self.compute_g_loss() 160 | d_loss, _ = self.compute_d_loss() 161 | var_dic = {"LOSS_D": d_loss, "LOSS_G": g_loss} 162 | return var_dic 163 | 164 | """ 165 | g_loss, _ = self.compute_g_loss() 166 | d_loss, _ = self.compute_d_loss() 167 | var_dic = {"LOSS_D": d_loss, "LOSS_G": g_loss} 168 | return var_dic 169 | 170 | def valid_epoch(self): 171 | """Validate the model each epoch. 172 | 173 | It will be called at the end of each epoch, when training finishes. 174 | So, do your verification here. 175 | 176 | Example:: 177 | 178 | avg_dic: dict = {} 179 | self.netG.eval() 180 | self.netD.eval() 181 | # Load data from loader_valid. 182 | for iteration, batch in enumerate(self.datasets.loader_valid, 1): 183 | self.input, self.ground_truth = self.get_data_from_batch(batch) 184 | with torch.no_grad(): 185 | self.fake = self.netG(self.input) 186 | # You can write this function to apply your computation.
187 | dic: dict = self.compute_valid() 188 | if avg_dic == {}: 189 | avg_dic: dict = dic 190 | else: 191 | for key in dic.keys(): 192 | avg_dic[key] += dic[key] 193 | 194 | for key in avg_dic.keys(): 195 | avg_dic[key] = avg_dic[key] / self.datasets.nsteps_valid 196 | 197 | self.watcher.scalars(avg_dic, self.step, tag="Valid") 198 | self._watch_images(tag="Valid") 199 | self.netG.train() 200 | self.netD.train() 201 | 202 | """ 203 | avg_dic: dict = {} 204 | self.netG.eval() 205 | self.netD.eval() 206 | for iteration, batch in enumerate(self.datasets.loader_valid, 1): 207 | self.input, self.ground_truth = self.get_data_from_batch(batch, self.device) 208 | with torch.no_grad(): 209 | self.fake = self.netG(self.input) 210 | dic: dict = self.compute_valid() 211 | if avg_dic == {}: 212 | avg_dic: dict = dic 213 | else: 214 | # sum up the values of each key 215 | for key in dic.keys(): 216 | avg_dic[key] += dic[key] 217 | 218 | for key in avg_dic.keys(): 219 | avg_dic[key] = avg_dic[key] / self.datasets.nsteps_valid 220 | 221 | self.watcher.scalars(avg_dic, self.step, tag="Valid") 222 | self.loger.write(self.step, self.current_epoch, avg_dic, "Valid", header=self.current_epoch <= 1) 223 | self._watch_images(tag="Valid") 224 | self.netG.train() 225 | self.netD.train() 226 | 227 | def test(self): 228 | pass 229 | 230 | @property 231 | def configure(self): 232 | config_dic = super(SupGanTrainer, self).configure 233 | config_dic["d_turn"] = str(self.d_turn) 234 | return config_dic 235 | -------------------------------------------------------------------------------- /jdit/optimizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from typing import Optional, Union, Dict 3 | import torch.optim as optim 4 | from inspect import signature 5 | 6 | 7 | class Optimizer(object): 8 | """This is a wrapper of the ``optimizer`` class in pytorch. 9 | 10 | We add some new features in order to finely control the optimizer. 11 | 12 | * :attr:`params` is the parameters of the model which need to be updated. 13 | It will use a filter to get all the parameters that require grad automatically. 14 | Like this 15 | 16 | ``filter(lambda p: p.requires_grad, params)`` 17 | 18 | So, you can pass ``model.all_params()`` without any filters. 19 | 20 | * :attr:`learning rate decay` When calling ``do_lr_decay()``, 21 | it will do a learning rate decay, like: 22 | 23 | .. math:: 24 | 25 | lr = lr * decay 26 | 27 | * :attr:`learning rate reset` . Reset the learning rate; it can change the learning rate and decay directly. 28 | 29 | 30 | :param params: parameters of the model, which need to be updated. 31 | :param optimizer: An optimizer class in pytorch, such as ``torch.optim.Adam``. 32 | :param lr_decay: learning rate decay. Default: 0.92. 33 | :param decay_at_epoch: The epoch position of applying lr decay. Default: None. 34 | :param decay_at_step: The step position of applying lr decay. Default: None 35 | :param kwargs: pass hyper-parameters to the optimizer, such as ``lr`` , ``betas`` , ``weight_decay`` . 36 | :return: 37 | 38 | Args: 39 | 40 | params (dict): parameters of the model, which need to be updated. 41 | 42 | optimizer (torch.optim.Optimizer): An optimizer class in pytorch, such as ``torch.optim.Adam`` 43 | 44 | lr_decay (float, optional): learning rate decay. Default: 0.92 45 | 46 | decay_position (int, list, optional): The decay position of the lr. Default: None 47 | 48 | lr_reset (Dict[position(int), lr(float)] ): Reset the learning rate at certain positions. Default: None 49 | 50 | position_type ('epoch','step'): Position type.
Default: None 51 | 52 | **kwargs : pass hyper-parameters to the optimizer, such as ``lr``, ``betas``, ``weight_decay`` . 53 | 54 | Example:: 55 | 56 | >>> from torch.nn import Sequential, Conv3d 57 | >>> from torch.optim import Adam 58 | >>> module = Sequential(Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))) 59 | >>> opt = Optimizer(module.parameters(), "Adam", 0.5, 10, {4: 0.99}, "epoch", lr=1.0, betas=(0.9, 0.999), 60 | weight_decay=1e-5) 61 | >>> print(opt) 62 | (Adam ( 63 | Parameter Group 0 64 | amsgrad: False 65 | betas: (0.9, 0.999) 66 | eps: 1e-08 67 | lr: 1.0 68 | weight_decay: 1e-05 69 | ) 70 | lr_decay:0.5 71 | decay_position:10 72 | lr_reset:{4: 0.99} 73 | position_type:epoch 74 | )) 75 | >>> opt.lr 76 | 1.0 77 | >>> opt.lr_decay 78 | 0.5 79 | >>> opt.do_lr_decay() 80 | >>> opt.lr 81 | 0.5 82 | >>> opt.do_lr_decay(reset_lr=1) 83 | >>> opt.lr 84 | 1 85 | >>> opt.opt 86 | Adam ( 87 | Parameter Group 0 88 | amsgrad: False 89 | betas: (0.9, 0.999) 90 | eps: 1e-08 91 | lr: 1 92 | weight_decay: 1e-05 93 | ) 94 | >>> opt.is_decay_lr(1) 95 | False 96 | >>> opt.is_decay_lr(10) 97 | True 98 | >>> opt.is_decay_lr(20) 99 | True 100 | >>> opt.is_reset_lr(4) 101 | 0.99 102 | >>> opt.is_reset_lr(5) 103 | False 104 | 105 | """ 106 | 107 | def __init__(self, params: "parameters of model", 108 | optimizer: "[Adam,RMSprop,SGD...]", 109 | lr_decay: float = 1.0, 110 | decay_position: Union[int, tuple, list] = -1, 111 | lr_reset: Dict[int, float] = None, 112 | position_type: "('epoch','step')" = "epoch", 113 | **kwargs): 114 | if not isinstance(decay_position, (int, tuple, list)): 115 | raise TypeError("`decay_position` should be int or tuple/list, got %s instead" % type( 116 | decay_position)) 117 | if position_type not in ('epoch', 'step'): 118 | raise AttributeError("You need to set `position_type` to 'step' or 'epoch', got %s instead" % position_type) 119 | if lr_reset and any(lr <= 0 for lr in lr_reset.values()): 120 | raise AttributeError("The learning rates in `lr_reset={position: lr, }` should be greater than 0!") 121 | self.lr_decay = lr_decay 122 | self.decay_position = decay_position 123 | self.position_type = position_type 124 | self.lr_reset = lr_reset 125 | self.opt_name = optimizer 126 | 127 | try: 128 | Optim = getattr(optim, optimizer) 129 | self.opt = Optim(filter(lambda p: p.requires_grad, params), **kwargs) 130 | except TypeError as e: 131 | raise TypeError( 132 | "%s\n`%s` parameters are:\n %s\n Got %s instead." % (e, optimizer, signature(self.opt), kwargs)) 133 | except AttributeError as e: 134 | opts = [i for i in dir(optim) if not i.endswith("__") and i not in ['lr_scheduler', 'Optimizer']] 135 | raise AttributeError( 136 | "%s\n`%s` is not an optimizer in torch.optim. Available optimizers are:\n%s" % (e, optimizer, opts)) 137 | 138 | for param_group in self.opt.param_groups: 139 | self.lr = param_group["lr"] 140 | 141 | def __repr__(self): 142 | string = "(" + str(self.opt) + "\n %s:%s\n" % ("lr_decay", self.lr_decay) 143 | string = string + " %s:%s\n" % ("decay_position", self.decay_position) 144 | string = string + " %s:%s\n" % ("lr_reset", self.lr_reset) 145 | string = string + " %s:%s\n)" % ("position_type", self.position_type) + ")" 146 | return string 147 | 148 | def __getattr__(self, name): 149 | return getattr(self.opt, name) 150 | 151 | def is_decay_lr(self, position: Optional[int]) -> bool: 152 | """Judge whether to apply learning rate decay at this position. 153 | 154 | :param position: (int) A position of step or epoch.
155 |         :return: bool
156 |         """
157 |         if not self.decay_position:
158 |             return False
159 | 
160 |         if isinstance(self.decay_position, int):
161 |             is_change_lr = position > 0 and (position % self.decay_position) == 0
162 |         else:
163 |             is_change_lr = position in self.decay_position
164 |         return is_change_lr
165 | 
166 |     def is_reset_lr(self, position: Optional[int]) -> Union[bool, float]:
167 |         """Judge whether to reset the learning rate at this position.
168 | 
169 |         :param position: (int) A position of step or epoch.
170 |         :return: ``False``, or the new learning rate to reset to.
171 |         """
172 |         if not self.lr_reset:
173 |             return False
174 |         return self.lr_reset.get(position, False)
175 | 
176 |     def do_lr_decay(self, reset_lr_decay: float = None, reset_lr: float = None):
177 |         """Do a learning rate decay, or reset the learning rate and decay factor.
178 | 
179 |         Passing both parameters as None:
180 |             Do a learning rate decay by ``self.lr = self.lr * self.lr_decay`` .
181 | 
182 |         Passing `reset_lr_decay` or `reset_lr`:
183 |             Reset the learning rate or the decay factor by
184 |             ``self.lr = reset_lr``
185 |             ``self.lr_decay = reset_lr_decay``
186 | 
187 |         :param reset_lr_decay: if not None, use this value to reset `self.lr_decay`. Default: None.
188 |         :param reset_lr: if not None, use this value to reset `self.lr`. Default: None.
189 |         :return:
190 |         """
191 | 
192 |         self.lr = self.lr * self.lr_decay
193 |         if reset_lr_decay is not None:
194 |             self.lr_decay = reset_lr_decay
195 |         if reset_lr is not None:
196 |             self.lr = reset_lr
197 |         for param_group in self.opt.param_groups:
198 |             param_group["lr"] = self.lr
199 | 
200 |     @property
201 |     def configure(self):
202 |         config_dic = dict()
203 |         config_dic["opt_name"] = self.opt_name
204 |         opt_config = dict(self.opt.param_groups[0])
205 |         opt_config.pop("params")
206 |         config_dic.update(opt_config)
207 |         config_dic["lr_decay"] = str(self.lr_decay)
208 |         config_dic["decay_position"] = str(self.decay_position)
209 |         config_dic["position_type"] = self.position_type
210 |         return config_dic
211 | 
212 | 
213 | if __name__ == '__main__':
214 |     import torch
215 |     from torch.optim import Adam, RMSprop, SGD
216 | 
217 |     adam, rmsprop, sgd = Adam, RMSprop, SGD
218 |     param = torch.nn.Linear(10, 1).parameters()
219 |     opt = Optimizer(param, "Adam", 0.1, 10, {2: 0.01, 4: 0.1}, "step", lr=0.9, betas=(0.9, 0.999), weight_decay=1e-5)
220 |     print(opt)
221 |     print(opt.configure['lr'])
222 |     opt.do_lr_decay()
223 |     print(opt.configure['lr'])
224 |     opt.do_lr_decay(reset_lr=0.232, reset_lr_decay=0.3)
225 |     print(opt.configure['lr_decay'])
226 |     opt.do_lr_decay(reset_lr=0.2)
227 |     print(opt.configure)
228 |     print(opt.is_decay_lr(1))
229 |     print(opt.is_decay_lr(2))
230 |     print(opt.is_decay_lr(40))
231 |     print(opt.is_decay_lr(10))
232 |     print(opt.is_reset_lr(2))
233 |     print(opt.is_reset_lr(3))
234 |     print(opt.is_reset_lr(4))
235 |     param = torch.nn.Linear(10, 1).parameters()
236 |     hpd = {"optimizer": "Adam", "lr_decay": 0.1, "decay_position": [1, 3, 5], "position_type": "epoch",
237 |            "lr": 0.9, "betas": (0.9, 0.999), "weight_decay": 1e-5}
238 |     opt = Optimizer(param, **hpd)
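    # A minimal sketch of driving the hooks above from an epoch loop
    # (illustrative only; inside jdit the trainer calls these for you):
    for epoch in range(1, 6):
        reset_to = opt.is_reset_lr(epoch)
        if reset_to:
            opt.do_lr_decay(reset_lr=reset_to)
        elif opt.is_decay_lr(epoch):
            opt.do_lr_decay()
        print("epoch %d, lr %s" % (epoch, opt.lr))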
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | ![Jdit](https://raw.githubusercontent.com/dingguanglei/jdit/master/resources/logo.png)
2 | 
3 | ---
4 | 
5 | [![](http://img.shields.io/travis/dingguanglei/jdit.svg)](https://github.com/dingguanglei/jdit)
6 | [![Documentation Status](https://readthedocs.org/projects/jdit/badge/?version=latest)](https://jdit.readthedocs.io/en/latest/?badge=latest)
7 | [![Codacy Badge](https://api.codacy.com/project/badge/Grade/1a5aa0721e2f44ebb61b0b73627bbc90)](https://www.codacy.com/app/dingguanglei/jdit?utm_source=github.com&utm_medium=referral&utm_content=dingguanglei/jdit&utm_campaign=Badge_Grade)
8 | ![Packagist](https://img.shields.io/hexpm/l/plug.svg)
9 | ![pypi download](https://img.shields.io/pypi/dm/jdit.svg)
10 | 
11 | ### [Chinese guide is here!](https://dingguanglei.com/jdit/)
12 | **Jdit** is a research-process-oriented framework based on pytorch, built so
13 | that you only have to care about your ideas. You don't need to write long,
14 | boring boilerplate code just to run a deep learning project and verify an idea.
15 | 
16 | You only need to implement your ideas; the training loop, multi-GPU support,
17 | checkpoints, process visualization, performance evaluation and so on are
18 | handled for you.
19 | 
20 | Guide: [https://dingguanglei.com/tag/jdit](https://dingguanglei.com/tag/jdit)
21 | 
22 | Docs: [https://jdit.readthedocs.io/en/latest/index.html](https://jdit.readthedocs.io/en/latest/index.html)
23 | 
24 | If you run into problems or find bugs, you can contact the author.
25 | 
26 | E-mail: dingguanglei.bupt@qq.com
27 | 
28 | ## Install
29 | Requires:
30 | ``` {.sourceCode .bash}
31 | tensorboard >= 1.14.0
32 | pytorch >= 1.1.0
33 | ```
34 | Install the requirements.
35 | ``` {.sourceCode .bash}
36 | pip install -r requirements.txt
37 | ```
38 | 
39 | ### From pip
40 | ``` {.sourceCode .bash}
41 | pip install jdit
42 | ```
43 | 
44 | ### From source
45 | This method is recommended, because it keeps you on the newest version.
46 | 1. Clone from github
47 | ``` {.sourceCode .bash}
48 | git clone https://github.com/dingguanglei/jdit
49 | ```
50 | 
51 | 2. Build
52 | Use `setup.py` to build a wheel package.
53 | ``` {.sourceCode .bash}
54 | python setup.py bdist_wheel
55 | ```
56 | 
57 | 3. Install
58 | You will find the package in `jdit/dist/`. Use pip to install it.
59 | ``` {.sourceCode .bash}
60 | pip install dist/jdit-x.y.z-py3-none-any.whl
61 | ```
62 | 
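As a quick smoke test, check that the package imports (this only imports `jdit.Model`, nothing else is assumed):
``` {.sourceCode .bash}
python -c "import jdit; print(jdit.Model)"
```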
63 | ## Quick start
64 | 
65 | After building and installing the jdit package, you can make a new directory
66 | for a quick test. Assuming you created a new directory `example`, run
67 | this code in an ipython shell (creating a `main.py` file also works):
68 | 
69 | ```python
70 | from jdit.trainer.instances.fashionClassification import start_fashionClassTrainer
71 | if __name__ == '__main__':
72 |     start_fashionClassTrainer()
73 | ```
74 | The following is the implementation of ``start_fashionClassTrainer()``:
75 | 
76 | ```python
77 | # coding=utf-8
78 | import torch
79 | import torch.nn as nn
80 | import torch.nn.functional as F
81 | from jdit.trainer.single.classification import ClassificationTrainer
82 | from jdit import Model
83 | from jdit.optimizer import Optimizer
84 | from jdit.dataset import FashionMNIST
85 | 
86 | class SimpleModel(nn.Module):
87 |     def __init__(self, depth=64, num_class=10):
88 |         super(SimpleModel, self).__init__()
89 |         self.num_class = num_class
90 |         self.layer1 = nn.Conv2d(1, depth, 3, 1, 1)
91 |         self.layer2 = nn.Conv2d(depth, depth * 2, 4, 2, 1)
92 |         self.layer3 = nn.Conv2d(depth * 2, depth * 4, 4, 2, 1)
93 |         self.layer4 = nn.Conv2d(depth * 4, depth * 8, 4, 2, 1)
94 |         self.layer5 = nn.Conv2d(depth * 8, num_class, 4, 1, 0)
95 | 
96 |     def forward(self, input):
97 |         out = F.relu(self.layer1(input))
98 |         out = F.relu(self.layer2(out))
99 |         out = F.relu(self.layer3(out))
100 |         out = F.relu(self.layer4(out))
101 |         out = self.layer5(out)
102 |         out = out.view(-1, self.num_class)
103 |         return out
104 | 
105 | 
106 | class FashionClassTrainer(ClassificationTrainer):
107 |     def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets, num_class):
108 |         super(FashionClassTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, datasets, num_class)
109 |         data, label = self.datasets.samples_train
110 |         self.watcher.embedding(data, data, label, 1)
111 | 
112 |     def compute_loss(self):
113 |         var_dic = {}
114 |         labels = self.ground_truth.squeeze().long()
115 |         var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, labels)
116 |         return loss, var_dic
117 | 
118 |     def compute_valid(self):
119 |         _, var_dic = self.compute_loss()
120 |         labels = self.ground_truth.squeeze().long()
121 |         _, predict = torch.max(self.output.detach(), 1)  # 0100=>1 0010=>2
122 |         total = predict.size(0)
123 |         correct = predict.eq(labels).cpu().sum().float()
124 |         acc = correct / total
125 |         var_dic["ACC"] = acc
126 |         return var_dic
127 | 
128 | def start_fashionClassTrainer(gpus=(), nepochs=10, run_type="train"):
129 |     """An example of fashion-mnist classification
130 | 
131 |     """
132 |     num_class = 10
133 |     depth = 32
134 |     gpus = gpus
135 |     batch_size = 4
136 |     nepochs = nepochs
137 |     opt_hpm = {"optimizer": "Adam",
138 |                "lr_decay": 0.94,
139 |                "decay_position": 10,
140 |                "position_type": "epoch",
141 |                "lr_reset": {2: 5e-4, 3: 1e-3},
142 |                "lr": 1e-4,
143 |                "weight_decay": 2e-5,
144 |                "betas": (0.9, 0.99)}
145 | 
146 |     print('===> Build dataset')
147 |     mnist = FashionMNIST(batch_size=batch_size)
148 |     # mnist.dataset_train = mnist.dataset_test
149 |     torch.backends.cudnn.benchmark = True
150 |     print('===> Building model')
151 |     net = Model(SimpleModel(depth=depth), gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=1)
152 |     print('===> Building optimizer')
153 |     opt = Optimizer(net.parameters(), **opt_hpm)
154 |     print('===> Training')
155 |     print("using `tensorboard --logdir=log` to see learning curves and net structure. "
156 |           "Training and validation data, config info and checkpoints are saved in the `log` directory.")
157 |     Trainer = FashionClassTrainer("log/fashion_classify", nepochs, gpus, net, opt, mnist, num_class)
158 |     if run_type == "train":
159 |         Trainer.train()
160 |     elif run_type == "debug":
161 |         Trainer.debug()
162 | 
163 | if __name__ == '__main__':
164 |     start_fashionClassTrainer()
165 | 
166 | ```
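Besides `run_type="train"`, the example accepts `run_type="debug"`, which calls `Trainer.debug()` instead of `Trainer.train()` and is handy for checking that the model, optimizer and dataset are wired together correctly:

```python
from jdit.trainer.instances.fashionClassification import start_fashionClassTrainer

if __name__ == '__main__':
    start_fashionClassTrainer(run_type="debug")
```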
156 | "training and valid_epoch data, configures info and checkpoint were save in `log` directory.") 157 | Trainer = FashionClassTrainer("log/fashion_classify", nepochs, gpus, net, opt, mnist, num_class) 158 | if run_type == "train": 159 | Trainer.train() 160 | elif run_type == "debug": 161 | Trainer.debug() 162 | 163 | if __name__ == '__main__': 164 | start_fashionClassTrainer() 165 | 166 | ``` 167 | 168 | Then you will see something like this as following. 169 | 170 | ``` 171 | ===> Build dataset 172 | use 8 thread 173 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz 174 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz 175 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz 176 | Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz 177 | Processing... 178 | Done 179 | ===> Building model 180 | ResNet Total number of parameters: 2776522 181 | ResNet model use CPU 182 | apply kaiming weight init 183 | ===> Building optimizer 184 | ===> Training 185 | using `tensorboard --logdir=log` to see learning curves and net structure. 186 | training and valid_epoch data, configures info and checkpoint were save in `log` directory. 187 | 0%| | 0/10 [00:00= 1: 78 | # block1 = [ 79 | # inception.Conv2d_3b_1x1, 80 | # inception.Conv2d_4a_3x3, 81 | # nn.MaxPool2d(kernel_size=3, stride=2) 82 | # ] 83 | # self.blocks.append(nn.Sequential(*block1)) 84 | # 85 | # # Block 2: maxpool2 to aux classifier 86 | # if self.last_needed_block >= 2: 87 | # block2 = [ 88 | # inception.Mixed_5b, 89 | # inception.Mixed_5c, 90 | # inception.Mixed_5d, 91 | # inception.Mixed_6a, 92 | # inception.Mixed_6b, 93 | # inception.Mixed_6c, 94 | # inception.Mixed_6d, 95 | # inception.Mixed_6e, 96 | # ] 97 | # self.blocks.append(nn.Sequential(*block2)) 98 | # 99 | # # Block 3: aux classifier to final avgpool 100 | # if self.last_needed_block >= 3: 101 | # block3 = [ 102 | # inception.Mixed_7a, 103 | # inception.Mixed_7b, 104 | # inception.Mixed_7c, 105 | # nn.AdaptiveAvgPool2d(output_size=(1, 1)) 106 | # ] 107 | # self.blocks.append(nn.Sequential(*block3)) 108 | # 109 | # for param in self.parameters(): 110 | # param.requires_grad = requires_grad 111 | # 112 | # def forward(self, inp): 113 | # """Get Inception feature maps 114 | # 115 | # Parameters 116 | # ---------- 117 | # inp : torch.autograd.Variable 118 | # Input tensor of shape Bx3xHxW. 
Values are expected to be in 119 | # range (0, 1) 120 | # 121 | # Returns 122 | # ------- 123 | # List of torch.autograd.Variable, corresponding to the selected output 124 | # block, sorted ascending by index 125 | # """ 126 | # outp = [] 127 | # x = inp 128 | # 129 | # if self.resize_input: 130 | # # x = F.upsample(x, size=(299, 299), mode='bilinear') 131 | # x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False) 132 | # if self.normalize_input: 133 | # x = x.clone() 134 | # x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 135 | # x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 136 | # x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 137 | # 138 | # for idx, block in enumerate(self.blocks): 139 | # x = block(x) 140 | # if idx in self.output_blocks: 141 | # outp.append(x) 142 | # 143 | # if idx == self.last_needed_block: 144 | # break 145 | # 146 | # return outp 147 | 148 | 149 | class FID(object): 150 | def __init__(self, gpu_ids, dim=2048, ): 151 | self.inception = None 152 | self.dim = dim 153 | self.gpu_ids = gpu_ids 154 | 155 | self.mu = None 156 | self.sigma = None 157 | 158 | def _get_cifar10_mu_sigma(self): 159 | if self.mu is None: 160 | self.mu = pd.read_csv(os.path.join(os.getcwd(), "jdit/metric/Cifar10_M.csv"), header=None).values 161 | self.mu = self.mu.reshape(-1) 162 | if self.sigma is None: 163 | self.sigma = pd.read_csv(os.path.join(os.getcwd(), "jdit/metric/Cifar10_S.csv"), header=None).values 164 | return self.mu, self.sigma 165 | 166 | def _getInception(self): 167 | if self.inception is None: 168 | self.inception = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[self.dim]]) 169 | self.inception = self.inception.cuda() if len(self.gpu_ids) > 0 else self.inception 170 | return self.inception 171 | 172 | # def cifar10_FID_fromloader(self, dataloader): 173 | # 174 | # fake_mu, fake_sigma = self.compute_act_statistics_from_loader(dataloader, self._getInception(), self.gpu_ids) 175 | # cifar10_mu, cifar10_sigma = self._get_cifar10_mu_sigma() 176 | # fid = self.FID(cifar10_mu, cifar10_sigma, fake_mu, fake_sigma) 177 | # return fid 178 | # 179 | # def compute_act_statistics_from_loader(self, dataloader, model, gpu_ids): 180 | # model.eval() 181 | # pred_arr = None 182 | # placeholder = Variable().cuda() if len(gpu_ids) > 0 else Variable() 183 | # model = model.cuda() if len(gpu_ids) > 0 else model 184 | # for iteration, batch in enumerate(dataloader, 1): 185 | # pred = self.compute_act_batch(placeholder, batch) 186 | # if pred_arr is None: 187 | # pred_arr = pred 188 | # else: 189 | # pred_arr = torch.cat((pred_arr, pred)) # [?, 2048, 1, 1] 190 | # 191 | # pred_arr = pred_arr.cpu().numpy().reshape(pred_arr.size()[0], -1) # [?, 2048] 192 | # mu = np.mean(pred_arr, axis=0) 193 | # sigma = np.cov(pred_arr, rowvar=False) 194 | # return mu, sigma 195 | 196 | def compute_act_batch(self, batch): 197 | with torch.autograd.no_grad(): 198 | pred = self.inception(batch)[0] # [batchsize, 1024,1,1] 199 | if pred.shape[2] != 1 or pred.shape[3] != 1: 200 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 201 | return pred 202 | 203 | def evaluate_model_fid(self, model, noise_shape=(64, 3, 32, 32), amount=10000, dataset="cifar10"): 204 | model.eval() 205 | noise = Variable().cuda() if len(self.gpu_ids) > 0 else Variable() 206 | # iteration = amount // noise_shape[0] 207 | iteration = amount 208 | pred_arr = None 209 | self._getInception() 210 | with torch.autograd.no_grad(): 211 | for i in range(iteration): 212 | noise_cpu = 
Variable(torch.randn(noise_shape)) 213 | noise.data.resize_(noise_cpu.size()).copy_(noise_cpu) 214 | fake = model(noise) 215 | pred = self.inception(fake)[0] # [batchsize, 1024,1,1] 216 | if pred.shape[2] != 1 or pred.shape[3] != 1: 217 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 218 | if pred_arr is None: 219 | pred_arr = pred 220 | else: 221 | pred_arr = torch.cat((pred_arr, pred)) # [?, 2048, 1, 1] 222 | pred_arr = pred_arr.cpu().numpy().reshape(pred_arr.size()[0], -1) # [?, 2048] 223 | fake_mu = np.mean(pred_arr, axis=0) 224 | fake_sigma = np.cov(pred_arr, rowvar=False) 225 | if dataset == "cifar10": 226 | self._get_cifar10_mu_sigma() 227 | # print(fake_mu.shape, fake_sigma.shape) 228 | # print(self.mu.shape, self.sigma.shape) 229 | fid_score = self.FID(fake_mu, fake_sigma, self.mu, self.sigma) 230 | model.train() 231 | return fid_score 232 | 233 | @staticmethod 234 | def FID(mu1, sigma1, mu2, sigma2, eps=1e-6): 235 | """Numpy implementation of the Frechet Distance. 236 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 237 | and X_2 ~ N(mu_2, C_2) is 238 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 239 | 240 | Stable version by Dougal J. Sutherland. 241 | 242 | Params: 243 | -- mu1 : Numpy array containing the activations of a layer of the 244 | inception net (like returned by the function 'get_predictions') 245 | for generated samples. 246 | -- mu2 : The sample mean over activations, precalculated on an 247 | representive data set. 248 | -- sigma1: The covariance matrix over activations for generated samples. 249 | -- sigma2: The covariance matrix over activations, precalculated on an 250 | representive data set. 251 | 252 | Returns: 253 | -- : The Frechet Distance. 254 | """ 255 | 256 | mu1 = np.atleast_1d(mu1) 257 | mu2 = np.atleast_1d(mu2) 258 | 259 | sigma1 = np.atleast_2d(sigma1) 260 | sigma2 = np.atleast_2d(sigma2) 261 | 262 | if mu1.shape != mu2.shape: 263 | raise ValueError("Training and test mean vectors have different lengths") 264 | if sigma1.shape != sigma2.shape: 265 | raise ValueError("Training and test covariances have different dimensions") 266 | 267 | diff = mu1 - mu2 268 | 269 | # Product might be almost singular 270 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 271 | if not np.isfinite(covmean).all(): 272 | msg = ('fid calculation produces singular product; ' 273 | 'adding %s to diagonal of cov estimates') % eps 274 | print(msg) 275 | offset = np.eye(sigma1.shape[0]) * eps 276 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 277 | 278 | # Numerical error might give slight imaginary component 279 | if np.iscomplexobj(covmean): 280 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 281 | m = np.max(np.abs(covmean.imag)) 282 | raise ValueError('Imaginary component {}'.format(m)) 283 | covmean = covmean.real 284 | 285 | tr_covmean = np.trace(covmean) 286 | 287 | return (diff.dot(diff) + np.trace(sigma1) + 288 | np.trace(sigma2) - 2 * tr_covmean) 289 | -------------------------------------------------------------------------------- /docs/source/Build your own trainer.rst: -------------------------------------------------------------------------------- 1 | Build your own trainer 2 | ====================== 3 | 4 | To build your own trainer, you need prepare these sections: 5 | 6 | * ``dataset`` This is the datasets which you want to use. 7 | * ``Model`` This is a wrapper of your own pytorch ``module`` . 8 | * ``Optimizer`` This is a wrapper of pytorch ``opt`` . 
9 | * ``trainer`` This is a training pipeline which assembles the sections above.
10 | 
11 | jdit.dataset
12 | ------------
13 | 
14 | In this section, you build the dataset that you want to use later.
15 | 
16 | Common dataset
17 | >>>>>>>>>>>>>>
18 | 
19 | Many public datasets are commonly used, so you can easily build a standard one,
20 | such as:
21 | 
22 | * Fashion mnist
23 | * Cifar10
24 | * Lsun
25 | 
26 | The only parameter you need to set is ``batch_size`` .
27 | For these common datasets, nothing else is required.
28 | 
29 | .. code-block:: python
30 | 
31 |     >>> from jdit.dataset import FashionMNIST
32 |     >>> fashion_data = FashionMNIST(batch_size=64) # now you get a ``dataset``
33 | 
34 | Custom dataset
35 | >>>>>>>>>>>>>>
36 | 
37 | If you want to build a dataset from your own data, you need to inherit the class
38 | 
39 | ``jdit.dataset.DataLoadersFactory``
40 | 
41 | and rewrite its ``build_transforms()`` and ``build_datasets()``
42 | (if you are happy with the default transforms, rewriting ``build_transforms()`` is not necessary).
43 | 
44 | Follow these steps:
45 | 
46 | * Register your own transforms in ``self.train_transform_list`` and ``self.valid_transform_list``. (Not necessary)
47 | * Register your training dataset to ``self.dataset_train`` by using ``self.train_transform_list``
48 | * Register your valid_epoch dataset to ``self.dataset_valid`` by using ``self.valid_transform_list``
49 | 
50 | Example::
51 | 
52 |     class FashionMNIST(DataLoadersFactory):
53 |         def __init__(self, root=r'.\datasets\fashion_data', batch_size=128, num_workers=-1):
54 |             super(FashionMNIST, self).__init__(root, batch_size, num_workers)
55 | 
56 |         def build_transforms(self, resize=32):
57 |             # This is a default set, you can rewrite it.
58 |             self.train_transform_list = self.valid_transform_list = [
59 |                 transforms.Resize(resize),
60 |                 transforms.ToTensor(),
61 |                 transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
62 | 
63 |         def build_datasets(self):
64 |             self.dataset_train = datasets.FashionMNIST(self.root, train=True, download=True,
65 |                 transform=transforms.Compose(self.train_transform_list))
66 |             self.dataset_valid = datasets.FashionMNIST(self.root, train=False, download=True,
67 |                 transform=transforms.Compose(self.valid_transform_list))
68 | 
69 | Now you have your own dataset.
70 | 
71 | Model
72 | -----
73 | 
74 | In this section, you build your own network.
75 | 
76 | First, you need to build a pytorch ``module`` like this:
77 | 
78 | .. code-block:: python
79 | 
80 |     >>> class SimpleModel(nn.Module):
81 |     ...     def __init__(self):
82 |     ...         super(SimpleModel, self).__init__()
83 |     ...         self.layer1 = nn.Linear(32, 64)
84 |     ...         self.layer2 = nn.Linear(64, 1)
85 |     ...
86 |     ...     def forward(self, input):
87 |     ...         out = self.layer1(input)
88 |     ...         out = self.layer2(out)
89 |     ...         return out
90 |     >>> network = SimpleModel()
91 | 
92 | .. note::
93 | 
94 |     You don't need to move it to GPU or wrap it in data parallel.
95 |     ``jdit.Model`` will do this for you.
96 | 
97 | Second, wrap your model by using ``jdit.Model`` .
98 | Set which gpus you want to use and the weight init method.
99 | 
100 | .. note::
101 | 
102 |     The gpu ids in pytorch always start from 0, regardless of which physical
103 |     gpus you use; ``jdit.Model`` handles this mapping for you.
104 |     If you have gpus ``[0,1,2,3]`` and you only want to use 2 and 3,
105 |     just set ``gpu_ids_abs=[2, 3]`` .
106 | 
107 | .. code-block:: python
108 | 
109 |     >>> from jdit import Model
110 |     >>> network = SimpleModel()
111 |     >>> jdit_model = Model(network, gpu_ids_abs=[2,3], init_method="kaiming")
112 |     SimpleModel Total number of parameters: 2177
113 |     SimpleModel dataParallel use GPUs[2, 3]!
114 |     apply kaiming weight init!
115 | 
116 | Now you have your own model.
117 | 
118 | Optimizer
119 | ---------
120 | In this section, you build your own optimizer.
121 | 
122 | Compared with the optimizer in pytorch, this wrapper adds convenient
123 | learning rate decay and reset functions.
124 | 
125 | ``do_lr_decay()`` is called automatically at the end of every epoch
126 | (or at the positions you specify).
127 | You don't need to do anything to apply learning rate decay.
128 | If you don't want any decay, just set ``lr_decay = 1.`` or set a decay position larger than the number of training epochs.
129 | The following shows how it works, so that you can implement special strategies of your own.
130 | 
131 | .. code-block:: python
132 | 
133 |     >>> from jdit import Optimizer
134 |     >>> from torch.nn import Linear
135 |     >>> network = Linear(10, 1)
136 |     >>> # set params
137 |     >>> # `optimizer` is equal to a pytorch class name (torch.optim.RMSprop).
138 |     >>> hparams = {
139 |     ...     "optimizer": "RMSprop",
140 |     ...     "lr": 0.001,
141 |     ...     "lr_decay": 0.5,
142 |     ...     "weight_decay": 2e-5,
143 |     ...     "momentum": 0}
144 |     >>> # define optimizer
145 |     >>> opt = Optimizer(network.parameters(), **hparams)
146 |     >>> opt.lr
147 |     0.001
148 |     >>> opt.do_lr_decay()
149 |     >>> opt.lr
150 |     0.0005
151 |     >>> opt.do_lr_decay(reset_lr = 1)
152 |     >>> opt.lr
153 |     1
154 | 
155 | You can pass an optimizer name such as ``"Adam"`` , ``"RMSprop"`` or ``"SGD"`` .
156 | 
157 | .. note::
158 | 
159 |     The optimizer automatically filters the parameters and keeps only those
160 |     that require gradients (useful, e.g., with spectral normalization).
161 |     So you don't need to write something like
162 |     ``filter(lambda p: p.requires_grad, params)`` ;
163 |     merely passing ``model.parameters()`` is enough.
164 | 
165 | 
166 | Now you have an Optimizer.
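For example, a simple hand-written schedule (a sketch only; in jdit the trainer drives these hooks for you, and the epoch loop below is purely illustrative) decays the learning rate at fixed epochs:

.. code-block:: python

    nepochs = 60
    opt = Optimizer(network.parameters(), "SGD", lr_decay=0.5,
                    decay_position=[20, 40], position_type="epoch",
                    lr=0.01, momentum=0.9)
    for epoch in range(1, nepochs + 1):
        # ... train one epoch ...
        if opt.is_decay_lr(epoch):
            opt.do_lr_decay()  # lr: 0.01 -> 0.005 at epoch 20, -> 0.0025 at epoch 40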
167 | 
168 | trainer
169 | -------
170 | 
171 | The final section is a little more complex.
172 | It supplies some templates such as ``SupTrainer`` , ``GanTrainer`` , ``ClassificationTrainer`` and ``instances`` .
173 | 
174 | The inheritance hierarchy is as follows:
175 | 
176 | | ``SupTrainer``
177 | 
178 | |     ``ClassificationTrainer``
179 | 
180 | |         ``instances.FashionClassTrainer``
181 | 
182 | |     ``SupGanTrainer``
183 | 
184 | |         ``Pix2pixGanTrainer``
185 | 
186 | |             ``instances.CifarPix2pixGanTrainer``
187 | 
188 | |         ``GenerateGanTrainer``
189 | 
190 | |             ``instances.FashionGenerateGanTrainer``
191 | 
192 | Top level ``SupTrainer``
193 | >>>>>>>>>>>>>>>>>>>>>>>>
194 | ``SupTrainer`` is the top class of these templates.
195 | 
196 | It defines some tools to record the log, visualize data and so on.
197 | Besides, it contains the main epoch loop,
198 | which the second-level templates inherit
199 | to fill in the per-epoch training logic.
200 | 
201 | Something like this::
202 | 
203 |     def train():
204 |         for epoch in range(nepochs):
205 |             self._record_configs() # record info
206 |             self.train_epoch()
207 |             self.valid_epoch()
208 |             # do learning rate decay
209 |             self._change_lr()
210 |             # save model check point
211 |             self._check_point()
212 |         self.test()
213 | 
214 | Each of these methods is overridden by the second-level templates; ``SupTrainer`` only defines the rough framework. A sketch of what the learning-rate hook could look like is given below.
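The real implementation of ``_change_lr()`` lives in ``SupTrainer``; the sketch below only illustrates the mechanics, using the ``Optimizer`` hooks from the previous section (``self.current_epoch`` is the trainer's epoch counter):

.. code-block:: python

    def _change_lr(self):
        new_lr = self.opt.is_reset_lr(self.current_epoch)
        if new_lr:
            self.opt.do_lr_decay(reset_lr=new_lr)
        elif self.opt.is_decay_lr(self.current_epoch):
            self.opt.do_lr_decay()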
215 | 
216 | Second level ``ClassificationTrainer``
217 | >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
218 | At this level the task becomes concrete: classification.
219 | We have one ``model``, one ``optimizer`` and one ``dataset``,
220 | and the data are images and labels.
221 | So, to init a ``ClassificationTrainer``:
222 | 
223 | .. code-block:: python
224 | 
225 |     class ClassificationTrainer(SupTrainer):
226 |         def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets, num_class):
227 |             super(ClassificationTrainer, self).__init__(nepochs, logdir, gpu_ids_abs)
228 |             self.net = net
229 |             self.opt = opt
230 |             self.datasets = datasets
231 |             self.num_class = num_class
232 |             self.labels = None
233 |             self.output = None
234 | 
235 | Next, build the training loop for one epoch.
236 | You must use ``self.step`` to record the training step.
237 | 
238 | .. code-block:: python
239 | 
240 |     def train_epoch(self, subbar_disable=False):
241 |         # display training images every epoch
242 |         self._watch_images(show_imgs_num=3, tag="Train")
243 |         for iteration, batch in tqdm(enumerate(self.datasets.loader_train, 1), unit="step", disable=subbar_disable):
244 |             self.step += 1  # necessary!
245 |             # unzip data from one batch and move to certain device
246 |             self.input, self.ground_truth, self.labels = self.get_data_from_batch(batch, self.device)
247 |             self.output = self.net(self.input)
248 |             # this is defined in SupTrainer.
249 |             # using `self.compute_loss` and `self.opt` to do a backward.
250 |             self._train_iteration(self.opt, self.compute_loss, tag="Train")
251 | 
252 |     @abstractmethod
253 |     def compute_loss(self):
254 |         """Compute the main loss and observed variables.
255 |         Rewritten by the next templates.
256 |         """
257 | 
258 |     @abstractmethod
259 |     def compute_valid(self):
260 |         """Compute the valid_epoch variables for visualization.
261 |         Rewritten by the next templates.
262 |         """
263 | 
264 | ``compute_loss()`` and ``compute_valid()`` must be rewritten in the next template.
265 | 
266 | Third level ``FashionClassTrainer``
267 | >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
268 | 
269 | At this level everything is concrete. So, inherit ``ClassificationTrainer``
270 | and fill in the specific methods.
271 | 
272 | .. code-block:: python
273 | 
274 |     class FashionClassTrainer(ClassificationTrainer):
275 |         def __init__(self, logdir, nepochs, gpu_ids, net, opt, dataset):
276 |             super(FashionClassTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, dataset)
277 |             data, label = self.datasets.samples_train
278 |             # show dataset in tensorboard
279 |             self.watcher.embedding(data, data, label, 1)
280 | 
281 |         def compute_loss(self):
282 |             var_dic = {}
283 |             var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())
284 |             return loss, var_dic
285 | 
286 |         def compute_valid(self):
287 |             var_dic = {}
288 |             var_dic["CEP"] = cep = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())
289 | 
290 |             _, predict = torch.max(self.output.detach(), 1)  # 0100=>1 0010=>2
291 |             total = predict.size(0) * 1.0
292 |             labels = self.labels.squeeze().long()
293 |             correct = predict.eq(labels).cpu().sum().float()
294 |             acc = correct / total
295 |             var_dic["ACC"] = acc
296 |             return var_dic
297 | 
298 | ``compute_loss()`` is called at every training step to drive the backward pass. It returns two values.
299 | 
300 | * The first one, ``loss`` , is the **main loss**, on which ``loss.backward()`` is called to update the model weights.
301 | 
302 | * The second one, ``var_dic`` , is a **value dictionary** which will be visualized on tensorboard and depicted as a curve.
303 | 
304 | In this example, ``compute_loss()`` uses ``loss = nn.CrossEntropyLoss()``
305 | to do the backward propagation and visualizes the value on tensorboard under the tag ``"CEP"``.
306 | 
307 | ``compute_valid()`` is called at every validation step. It returns one value.
308 | 
309 | * The ``var_dic`` is the same kind of value dictionary as in ``compute_loss()`` .
310 | 
311 | .. note::
312 | 
313 |     ``compute_valid()`` is called under ``torch.no_grad()`` .
314 |     So, grads will not be computed in this method. But if you need grads,
315 |     use ``torch.enable_grad()`` to make grad computation available, as sketched below.
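For instance, a validation metric that needs gradients could be computed like this (a sketch only; ``INP_GRAD`` is just an illustrative tag, and ``self.net`` / ``self.input`` are the trainer attributes used above):

.. code-block:: python

    def compute_valid(self):
        _, var_dic = self.compute_loss()
        with torch.enable_grad():
            # gradients are available again inside this block
            inp = self.input.detach().requires_grad_(True)
            self.net(inp).sum().backward()
            var_dic["INP_GRAD"] = inp.grad.norm()
        return var_dic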
316 | 
317 | Finally, you get a trainer.
318 | 
319 | You have got everything. Put them together and train it!
320 | 
321 | .. code-block:: python
322 | 
323 |     >>> mnist = FashionMNIST(batch_size)
324 |     >>> net = Model(SimpleModel(depth=depth), gpu_ids_abs=gpus, init_method="kaiming")
325 |     >>> opt = Optimizer(net.parameters(), **hparams)
326 |     >>> Trainer = FashionClassTrainer("log", nepochs, gpus, net, opt, mnist, 10)
327 |     >>> Trainer.train()
328 | 
329 | 
-------------------------------------------------------------------------------- /jdit/assessment/fid.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs
3 | 
4 | The FID metric calculates the distance between two distributions of images.
5 | Typically, we have summary statistics (mean & covariance matrix) of one
6 | of these distributions, while the 2nd distribution is given by a GAN.
7 | 
8 | When run as a stand-alone program, it compares the distribution of
9 | images that are stored as PNG/JPEG at a specified location with a
10 | distribution given by summary statistics (in pickle format).
11 | 
12 | The FID is calculated by assuming that X_1 and X_2 are the activations of
13 | the pool_3 layer of the inception net for generated samples and real world
14 | samples respectively.
15 | 
16 | See --help to see further details.
17 | 
18 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
19 | of Tensorflow
20 | 
21 | Copyright 2018 Institute of Bioinformatics, JKU Linz
22 | 
23 | Licensed under the Apache License, Version 2.0 (the "License");
24 | you may not use this file except in compliance with the License.
25 | You may obtain a copy of the License at
26 | 
27 |    http://www.apache.org/licenses/LICENSE-2.0
28 | 
29 | Unless required by applicable law or agreed to in writing, software
30 | distributed under the License is distributed on an "AS IS" BASIS,
31 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
32 | See the License for the specific language governing permissions and
33 | limitations under the License.
34 | """ 35 | 36 | from tqdm import * 37 | import math 38 | from torch import Tensor 39 | from torch.utils.data import DataLoader 40 | import torch.nn as nn 41 | import torch.nn.functional as F 42 | from torchvision import models 43 | import torch 44 | import numpy as np 45 | from scipy import linalg 46 | from torch.autograd import Variable 47 | from torch.nn.functional import adaptive_avg_pool2d 48 | 49 | 50 | class InceptionV3(nn.Module): 51 | """Pretrained _InceptionV3 network returning feature maps""" 52 | 53 | # Index of default block of inception to return, 54 | # corresponds to output of final average pooling 55 | DEFAULT_BLOCK_INDEX = 3 56 | 57 | # Maps feature dimensionality to their output blocks indices 58 | BLOCK_INDEX_BY_DIM = { 59 | 64: 0, # First max pooling features 60 | 192: 1, # Second max pooling featurs 61 | 768: 2, # Pre-aux classifier features 62 | 2048: 3 # Final average pooling features 63 | } 64 | 65 | def __init__(self, 66 | output_blocks=(DEFAULT_BLOCK_INDEX,), 67 | resize_input=True, 68 | normalize_input=True, 69 | requires_grad=False): 70 | """Build pretrained _InceptionV3 71 | 72 | Parameters 73 | ---------- 74 | output_blocks : list of int 75 | Indices of blocks to return features of. Possible values are: 76 | - 0: corresponds to output of first max pooling 77 | - 1: corresponds to output of second max pooling 78 | - 2: corresponds to output which is fed to aux classifier 79 | - 3: corresponds to output of final average pooling 80 | resize_input : bool 81 | If true, bilinearly resizes input to width and height 299 before 82 | feeding input to model. As the network without fully connected 83 | layers is fully convolutional, it should be able to handle inputs 84 | of arbitrary size, so resizing might not be strictly needed 85 | normalize_input : bool 86 | If true, normalizes the input to the statistics the pretrained 87 | Inception network expects 88 | requires_grad : bool 89 | If true, parameters of the model require gradient. 
Possibly useful 90 | for finetuning the network 91 | """ 92 | super(InceptionV3, self).__init__() 93 | 94 | self.resize_input = resize_input 95 | self.normalize_input = normalize_input 96 | self.output_blocks = sorted(output_blocks) 97 | self.last_needed_block = max(output_blocks) 98 | 99 | assert self.last_needed_block <= 3, \ 100 | 'Last possible output block index is 3' 101 | 102 | self.blocks = nn.ModuleList() 103 | 104 | inception = models.inception_v3(pretrained=True) 105 | 106 | # Block 0: input to maxpool1 107 | block0 = [ 108 | inception.Conv2d_1a_3x3, 109 | inception.Conv2d_2a_3x3, 110 | inception.Conv2d_2b_3x3, 111 | nn.MaxPool2d(kernel_size=3, stride=2) 112 | ] 113 | self.blocks.append(nn.Sequential(*block0)) 114 | 115 | # Block 1: maxpool1 to maxpool2 116 | if self.last_needed_block >= 1: 117 | block1 = [ 118 | inception.Conv2d_3b_1x1, 119 | inception.Conv2d_4a_3x3, 120 | nn.MaxPool2d(kernel_size=3, stride=2) 121 | ] 122 | self.blocks.append(nn.Sequential(*block1)) 123 | 124 | # Block 2: maxpool2 to aux classifier 125 | if self.last_needed_block >= 2: 126 | block2 = [ 127 | inception.Mixed_5b, 128 | inception.Mixed_5c, 129 | inception.Mixed_5d, 130 | inception.Mixed_6a, 131 | inception.Mixed_6b, 132 | inception.Mixed_6c, 133 | inception.Mixed_6d, 134 | inception.Mixed_6e, 135 | ] 136 | self.blocks.append(nn.Sequential(*block2)) 137 | 138 | # Block 3: aux classifier to final avgpool 139 | if self.last_needed_block >= 3: 140 | block3 = [ 141 | inception.Mixed_7a, 142 | inception.Mixed_7b, 143 | inception.Mixed_7c, 144 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 145 | ] 146 | self.blocks.append(nn.Sequential(*block3)) 147 | 148 | for param in self.parameters(): 149 | param.requires_grad = requires_grad 150 | 151 | def forward(self, inp): 152 | """Get Inception feature maps 153 | 154 | Parameters 155 | ---------- 156 | inp : torch.autograd.Variable 157 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 158 | range (0, 1) 159 | 160 | Returns 161 | ------- 162 | List of torch.autograd.Variable, corresponding to the selected output 163 | block, sorted ascending by index 164 | """ 165 | outp = [] 166 | x = inp 167 | 168 | if self.resize_input: 169 | # x = F.upsample(x, size=(299, 299), mode='bilinear') 170 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False) 171 | if self.normalize_input: 172 | x = x.clone() 173 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 174 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 175 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 176 | 177 | for idx, block in enumerate(self.blocks): 178 | x = block(x) 179 | if idx in self.output_blocks: 180 | outp.append(x) 181 | 182 | if idx == self.last_needed_block: 183 | break 184 | 185 | return outp 186 | 187 | 188 | # ______________________________________________________________ 189 | 190 | def compute_act_statistics_from_loader(dataloader: DataLoader, model, gpu_ids): 191 | """ 192 | 193 | :param dataloader: 194 | :param model: 195 | :param gpu_ids: 196 | :return: 197 | """ 198 | model.eval() 199 | pred_arr = None 200 | image = Variable().cuda() if len(gpu_ids) > 0 else Variable() 201 | model = model.cuda() if len(gpu_ids) > 0 else model 202 | for iteration, batch in tqdm(enumerate(dataloader, 1)): 203 | 204 | image.data.resize_(batch[0].size()).copy_(batch[0]) 205 | with torch.autograd.no_grad(): 206 | pred = model(image)[0] # [batchsize, 1024,1,1] 207 | if pred.shape[2] != 1 or pred.shape[3] != 1: 208 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 209 | if pred_arr is None: 210 | pred_arr = pred 211 | else: 212 | pred_arr = torch.cat((pred_arr, pred)) # [?, 2048, 1, 1] 213 | 214 | pred_arr = pred_arr.cpu().numpy().reshape(pred_arr.size()[0], -1) # [?, 2048] 215 | mu = np.mean(pred_arr, axis=0) 216 | sigma = np.cov(pred_arr, rowvar=False) 217 | return mu, sigma 218 | 219 | 220 | def compute_act_statistics_from_torch(tensor, model, gpu_ids): 221 | model.eval() 222 | model = model.cuda() if len(gpu_ids) > 0 else model 223 | 224 | image = Variable().cuda() if len(gpu_ids) > 0 else Variable() 225 | image.data.resize_(tensor.size()).copy_(tensor) # [Total, C, H, W] 226 | with torch.autograd.no_grad(): 227 | pred = model(image)[0] # [batchsize, 1024,1,1] 228 | if pred.shape[2] != 1 or pred.shape[3] != 1: 229 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 230 | 231 | pred_np = pred.cpu().numpy().reshape(len(pred), -1) # [?, 2048] 232 | mu = np.mean(pred_np, axis=0) 233 | sigma = np.cov(pred_np, rowvar=False) 234 | return mu, sigma 235 | 236 | 237 | def compute_act_statistics_from_path(path, model, gpu_ids, dims, batch_size=32, verbose=True): 238 | if path.endswith('.npz'): 239 | f = np.load(path) 240 | mu, sigma = f['mu'][:], f['sigma'][:] 241 | f.close() 242 | else: 243 | import pathlib 244 | from scipy.misc import imread 245 | path = pathlib.Path(path) 246 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 247 | 248 | images = np.array([imread(str(fn)).astype(np.float32) for fn in files]) 249 | # Bring images to shape (B, 3, H, W) 250 | images = images.transpose((0, 3, 1, 2)) 251 | # Rescale images to be between 0 and 1 252 | images /= 255 253 | 254 | # act = get_activations(images, model, batch_size, dims, cuda, verbose) 255 | model.eval() 256 | model = model.cuda() if len(gpu_ids) > 0 else model 257 | d0 = images.shape[0] 258 | if batch_size > d0: 259 | print(('Warning: batch size is bigger than the data size. 
' 260 | 'Setting batch size to data size')) 261 | batch_size = d0 262 | n_batches = d0 // batch_size 263 | n_used_imgs = n_batches * batch_size 264 | pred_arr = np.empty((n_used_imgs, dims)) 265 | for i in range(n_batches): 266 | if verbose: 267 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 268 | end='', flush=True) 269 | start = i * batch_size 270 | end = start + batch_size 271 | 272 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor) 273 | batch = Variable(batch, volatile=True) 274 | if len(gpu_ids) > 0: 275 | batch = batch.cuda() 276 | pred = model(batch)[0] 277 | # If model output is not scalar, apply global spatial average pooling. 278 | # This happens if you choose a dimensionality not equal 2048. 279 | if pred.shape[2] != 1 or pred.shape[3] != 1: 280 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 281 | 282 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) 283 | 284 | mu = np.mean(pred_arr, axis=0) 285 | sigma = np.cov(pred_arr, rowvar=False) 286 | return mu, sigma 287 | 288 | 289 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 290 | """Numpy implementation of the Frechet Distance. 291 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 292 | and X_2 ~ N(mu_2, C_2) is 293 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 294 | 295 | Stable version by Dougal J. Sutherland. 296 | 297 | Params: 298 | -- mu1 : Numpy array containing the activations of a layer of the 299 | inception net (like returned by the function 'get_predictions') 300 | for generated samples. 301 | -- mu2 : The sample mean over activations, precalculated on an 302 | representive data set. 303 | -- sigma1: The covariance matrix over activations for generated samples. 304 | -- sigma2: The covariance matrix over activations, precalculated on an 305 | representive data set. 306 | 307 | Returns: 308 | -- : The Frechet Distance. 
309 | """ 310 | 311 | mu1 = np.atleast_1d(mu1) 312 | mu2 = np.atleast_1d(mu2) 313 | 314 | sigma1 = np.atleast_2d(sigma1) 315 | sigma2 = np.atleast_2d(sigma2) 316 | 317 | assert mu1.shape == mu2.shape, \ 318 | 'Training and test mean vectors have different lengths' 319 | assert sigma1.shape == sigma2.shape, \ 320 | 'Training and test covariances have different dimensions' 321 | 322 | diff = mu1 - mu2 323 | 324 | # Product might be almost singular 325 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 326 | if not np.isfinite(covmean).all(): 327 | msg = ('fid calculation produces singular product; ' 328 | 'adding %s to diagonal of cov estimates') % eps 329 | print(msg) 330 | offset = np.eye(sigma1.shape[0]) * eps 331 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 332 | 333 | # Numerical error might give slight imaginary component 334 | if np.iscomplexobj(covmean): 335 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 336 | m = np.max(np.abs(covmean.imag)) 337 | raise ValueError('Imaginary component {}'.format(m)) 338 | covmean = covmean.real 339 | 340 | tr_covmean = np.trace(covmean) 341 | 342 | return (diff.dot(diff) + np.trace(sigma1) + 343 | np.trace(sigma2) - 2 * tr_covmean) 344 | 345 | 346 | BLOCK_INDEX_BY_DIM = { 347 | 64: 0, # First max pooling features 348 | 192: 1, # Second max pooling featurs 349 | 768: 2, # Pre-aux classifier features 350 | 2048: 3 # Final average pooling features 351 | } 352 | 353 | 354 | def FID_score(source, target, sample_prop=1.0, gpu_ids=(), dim=2048, batchsize=128, verbose=True): 355 | """ Compute FID score from ``Tensor``, ``DataLoader`` or a directory``path``. 356 | 357 | 358 | :param source: source data. 359 | :param target: target data. 360 | :param sample_prop: If passing a ``Tensor`` source, set this rate to sample a part of data from source. 361 | :param gpu_ids: gpu ids. 362 | :param dim: The number of features. Three options available. 363 | 364 | * 64: The first max pooling features of Inception. 365 | 366 | * 192: The Second max pooling features of Inception. 367 | 368 | * 768: The Pre-aux classifier features of Inception. 369 | 370 | * 2048: The Final average pooling features of Inception. 371 | 372 | Default: 2048. 373 | :param batchsize: Only using for passing paths of source and target. 374 | :param verbose: If show processing log. 375 | :return: fid score 376 | 377 | .. attention :: 378 | 379 | If you are passing ``Tensor`` as source and target. 380 | Make sure you have enough memory to load these data in _InceptionV3. 381 | Otherwise, please passing ``path`` of ``DataLoader`` to compute them step by step. 
382 | 
383 |     Example::
384 | 
385 |         >>> from jdit.dataset import Cifar10
386 |         >>> loader = Cifar10(root=r"../../datasets/cifar10", batch_size=32)
387 |         >>> target_tensor = loader.samples_train[0]
388 |         >>> source_tensor = loader.samples_valid[0]
389 |         >>> # using Tensor to compute FID score
390 |         >>> fid_value = FID_score(source_tensor, target_tensor, sample_prop=0.01, dim=768)
391 |         >>> print('FID: ', fid_value)
392 |         >>> # using DataLoader to compute FID score
393 |         >>> fid_value = FID_score(loader.loader_test, loader.loader_valid, dim=768)
394 |         >>> print('FID: ', fid_value)
395 | 
396 |     """
397 |     assert 0 < sample_prop <= 1, "sample_prop must be between 0 and 1, but got %s" % sample_prop
398 |     model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dim]])
399 |     if isinstance(source, Tensor) and isinstance(target, Tensor):
400 |         # source[?,C,H,W] target[?,C,H,W]
401 |         s_length = len(source)
402 |         t_length = len(target)
403 |         nums_sample = math.floor(min(s_length, t_length) * sample_prop)
404 |         s_mu, s_sigma = compute_act_statistics_from_torch(source[0:nums_sample], model, gpu_ids)
405 |         t_mu, t_sigma = compute_act_statistics_from_torch(target[0:nums_sample], model, gpu_ids)
406 | 
407 |     elif isinstance(source, DataLoader) and isinstance(target, DataLoader):
408 |         s_mu, s_sigma = compute_act_statistics_from_loader(source, model, gpu_ids)
409 |         t_mu, t_sigma = compute_act_statistics_from_loader(target, model, gpu_ids)
410 |     elif isinstance(source, str) and isinstance(target, str):
411 |         s_mu, s_sigma = compute_act_statistics_from_path(source, model, gpu_ids, dim, batchsize, verbose)
412 |         t_mu, t_sigma = compute_act_statistics_from_path(target, model, gpu_ids, dim, batchsize, verbose)
413 |     else:
414 |         raise TypeError("source and target must both be Tensor, DataLoader or path str, "
415 |                         "got %s and %s instead" % (type(source), type(target)))
416 |     fid_value = calculate_frechet_distance(s_mu, s_sigma, t_mu, t_sigma)
417 |     return fid_value
-------------------------------------------------------------------------------- /jdit/dataset.py: --------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, random_split
2 | from torchvision import datasets, transforms
3 | import psutil
4 | # from typing import Union
5 | from abc import ABCMeta, abstractmethod
6 | from torch.utils.data.distributed import DistributedSampler
7 | 
8 | 
9 | class DataLoadersFactory(metaclass=ABCMeta):
10 |     """This is a super class of dataloader.
11 | 
12 |     It defines some basic attributes and methods.
13 | 
14 |     * For training data: ``train_dataset``, ``loader_train``, ``nsteps_train`` .
15 |       Others such as ``valid_epoch`` and ``test`` have the same naming format.
16 |     * For transform, you can define your own transforms.
17 |     * If you don't have a test set, it will be replaced by the valid_epoch dataset.
18 | 
19 |     It builds the dataset in the following steps:
20 | 
21 |     #. ``build_transforms()`` To build transforms for the training dataset and valid_epoch.
22 |        You can rewrite this method for your own transform. It will be used in ``build_datasets()``
23 |     #. ``build_datasets()`` You must rewrite this method to load your own dataset
24 |        by passing datasets to ``self.dataset_train`` and ``self.dataset_valid`` .
25 |        ``self.dataset_test`` is optional. If you don't pass a test dataset,
26 |        it will be replaced by ``self.dataset_valid`` .
27 | 28 | Example:: 29 | 30 | def build_transforms(self, resize=32): 31 | self.train_transform_list = self.valid_transform_list = [ 32 | transforms.Resize(resize), 33 | transforms.ToTensor(), 34 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])] 35 | # Inherit this class and write this method. 36 | def build_datasets(self): 37 | self.dataset_train = datasets.CIFAR10(root, train=True, download=True, 38 | transform=transforms.Compose(self.train_transform_list)) 39 | self.dataset_valid = datasets.CIFAR10(root, train=False, download=True, 40 | transform=transforms.Compose(self.valid_transform_list)) 41 | 42 | #. ``build_loaders()`` It will use dataset, and passed parameters to 43 | build dataloaders for ``self.loader_train``, ``self.loader_valid`` and ``self.loader_test``. 44 | 45 | 46 | * :attr:`root` is the root path of datasets. 47 | 48 | * :attr:`batch_shape` is the size of data loader. shape is ``(Batchsize, Channel, Height, Width)`` 49 | 50 | * :attr:`num_workers` is the number of threads, using to load data. 51 | If you pass -1, it will use the max number of threads, according to your cpu. Default: -1 52 | 53 | * :attr:`shuffle` is whether shuffle the data. Default: ``True`` 54 | 55 | """ 56 | 57 | def __init__(self, root: str, batch_size: int, num_workers=-1, shuffle=True, subdata_size=1): 58 | """ Build data loaders. 59 | 60 | :param root: root path of datasets. 61 | :param batch_size: shape of data. ``(Batchsize, Channel, Height, Width)`` 62 | :param num_workers: the number of threads. Default: -1 63 | :param shuffle: whether shuffle the data. Default: ``True`` 64 | """ 65 | self.batch_size = batch_size 66 | self.shuffle = shuffle 67 | self.root = root 68 | if num_workers == -1: 69 | print("use %d thread!" % psutil.cpu_count()) 70 | self.num_workers = psutil.cpu_count() 71 | else: 72 | self.num_workers = num_workers 73 | 74 | self.dataset_train: datasets = None 75 | self.dataset_valid: datasets = None 76 | self.dataset_test: datasets = None 77 | 78 | self.loader_train: DataLoader = None 79 | self.loader_valid: DataLoader = None 80 | self.loader_test: DataLoader = None 81 | 82 | self.nsteps_train: int = None 83 | self.nsteps_valid: int = None 84 | self.nsteps_test: int = None 85 | 86 | self.sample_dataset_size = subdata_size 87 | 88 | self.build_transforms() 89 | self.build_datasets() 90 | self.build_loaders() 91 | 92 | def build_transforms(self, resize: int = 32): 93 | """ This will build transforms for training and valid_epoch. 94 | 95 | You can rewrite this method to build your own transforms. 96 | Don't forget to register your transforms to ``self.train_transform_list`` and ``self.valid_transform_list`` 97 | 98 | The following is the default set. 99 | 100 | .. code:: 101 | 102 | self.train_transform_list = self.valid_transform_list = [ 103 | transforms.Resize(resize), 104 | transforms.ToTensor(), 105 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])] 106 | 107 | """ 108 | self.train_transform_list = self.valid_transform_list = [ 109 | transforms.Resize(resize), 110 | transforms.ToTensor(), 111 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])] 112 | 113 | @abstractmethod 114 | def build_datasets(self): 115 | """ You must to rewrite this method to load your own datasets. 116 | 117 | * :attr:`self.dataset_train` . Assign a training ``dataset`` to this. 118 | * :attr:`self.dataset_valid` . Assign a valid_epoch ``dataset`` to this. 119 | * :attr:`self.dataset_test` is optional. Assign a test ``dataset`` to this. 
120 |           If not, it will be replaced by ``self.dataset_valid`` .
121 | 
122 |         Example::
123 | 
124 |             self.dataset_train = datasets.CIFAR10(root, train=True, download=True,
125 |                                                   transform=transforms.Compose(self.train_transform_list))
126 |             self.dataset_valid = datasets.CIFAR10(root, train=False, download=True,
127 |                                                   transform=transforms.Compose(self.valid_transform_list))
128 |         """
129 |         pass
130 | 
131 |     def build_loaders(self):
132 |         r""" Build dataloaders.
133 |         The previous method ``self.build_datasets()`` has created the datasets;
134 |         use them to build their dataloaders.
135 |         """
136 | 
137 |         if self.dataset_train is None:
138 |             raise ValueError(
139 |                 "`self.dataset_train` can't be `None`! Rewrite the `build_datasets` method and pass your own "
140 |                 "dataset to "
141 |                 "`self.dataset_train`")
142 |         if self.dataset_valid is None:
143 |             raise ValueError(
144 |                 "`self.dataset_valid` can't be `None`! Rewrite the `build_datasets` method and pass your own "
145 |                 "dataset to "
146 |                 "`self.dataset_valid`")
147 | 
148 |         # Create dataloaders
149 |         self.loader_train = DataLoader(self.dataset_train, batch_size=self.batch_size, shuffle=self.shuffle)
150 |         self.loader_valid = DataLoader(self.dataset_valid, batch_size=self.batch_size, shuffle=self.shuffle)
151 |         self.nsteps_train = len(self.loader_train)
152 |         self.nsteps_valid = len(self.loader_valid)
153 | 
154 |         if self.dataset_test is None:
155 |             self.dataset_test = self.dataset_valid
156 | 
157 |         self.loader_test = DataLoader(self.dataset_test, batch_size=self.batch_size, shuffle=self.shuffle)
158 |         self.nsteps_test = len(self.loader_test)
159 | 
160 |     def convert_to_distributed(self, which_dataset=None, num_replicas=None, rank=None):
161 |         samplers = {}
162 |         if which_dataset is None:
163 |             samplers["train"] = DistributedSampler(self.dataset_train, num_replicas=num_replicas, rank=rank)
164 |             self.loader_train = DataLoader(self.dataset_train, self.batch_size, False, sampler=samplers["train"])
165 | 
166 |         else:
167 |             if which_dataset == "train":
168 |                 samplers["train"] = DistributedSampler(self.dataset_train, num_replicas=num_replicas, rank=rank)
169 |                 self.loader_train = DataLoader(self.dataset_train, self.batch_size, False,
170 |                                                sampler=samplers["train"])
171 |             elif which_dataset == "valid":
172 |                 samplers["valid"] = DistributedSampler(self.dataset_valid, num_replicas=num_replicas, rank=rank)
173 |                 self.loader_valid = DataLoader(self.dataset_valid, self.batch_size, False,
174 |                                                sampler=samplers["valid"])
175 |             elif which_dataset == "test":
176 |                 samplers["test"] = DistributedSampler(self.dataset_test, num_replicas=num_replicas, rank=rank)
177 |                 self.loader_test = DataLoader(self.dataset_test, self.batch_size, False,
178 |                                               sampler=samplers["test"])
179 |             else:
180 |                 raise ValueError(
181 |                     "param `which_dataset` can only be 'train', 'valid' or 'test'. Got %s instead" % which_dataset)
182 |         return samplers
183 | 
184 |     @property
185 |     def samples_train(self):
186 |         return self._get_samples(self.dataset_train, self.sample_dataset_size)
187 | 
188 |     @property
189 |     def samples_valid(self):
190 |         return self._get_samples(self.dataset_valid, self.sample_dataset_size)
191 | 
192 |     @property
193 |     def samples_test(self):
194 |         return self._get_samples(self.dataset_test, self.sample_dataset_size)
195 | 
196 |     @staticmethod
197 |     def _get_samples(dataset, sample_dataset_size=1):
198 |         import math
199 |         if int(len(dataset) * sample_dataset_size) <= 0:
200 |             raise ValueError(
201 |                 "Dataset is %d too small. 
`sample_dataset_size` is %f" % (len(dataset), sample_dataset_size)) 202 | size_is_prop = isinstance(sample_dataset_size, float) 203 | size_is_amount = isinstance(sample_dataset_size, int) 204 | if size_is_prop: 205 | if not (0 < sample_dataset_size <= 1): 206 | raise ValueError("sample_dataset_size proportion should between 0. and 1.") 207 | subdata_size = math.floor(sample_dataset_size * len(dataset)) 208 | elif size_is_amount: 209 | if not (sample_dataset_size < len(dataset)): 210 | raise ValueError("sample_dataset_size amount should be smaller than length of dataset") 211 | subdata_size = sample_dataset_size 212 | else: 213 | raise Exception("sample_dataset_size should be float or int." 214 | "%s was given" % str(sample_dataset_size)) 215 | sample_dataset, _ = random_split(dataset, [subdata_size, len(dataset) - subdata_size]) 216 | sample_loader = DataLoader(sample_dataset, batch_size=subdata_size, shuffle=True) 217 | [samples_data] = list(sample_loader) 218 | return samples_data 219 | 220 | @property 221 | def configure(self): 222 | configs = dict() 223 | configs["dataset_name"] = str(self.dataset_train.__class__.__name__) 224 | configs["batch_size"] = str(self.batch_size) 225 | configs["shuffle"] = str(self.shuffle) 226 | configs["root"] = str(self.root) 227 | configs["num_workers"] = str(self.num_workers) 228 | configs["sample_dataset_size"] = str(self.sample_dataset_size) 229 | configs["nsteps_train"] = str(self.nsteps_train) 230 | configs["nsteps_valid"] = str(self.nsteps_valid) 231 | configs["nsteps_test"] = str(self.nsteps_test) 232 | configs["dataset_train"] = str(self.dataset_train) 233 | configs["dataset_valid"] = str(self.dataset_valid) 234 | configs["dataset_test"] = str(self.dataset_test) 235 | return configs 236 | 237 | 238 | class HandMNIST(DataLoadersFactory): 239 | """ Hand writing mnist dataset. 240 | 241 | Example:: 242 | 243 | >>> data = HandMNIST(r"../datasets/mnist") 244 | use 8 thread! 245 | Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz 246 | Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz 247 | Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz 248 | Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz 249 | Processing... 250 | Done! 251 | >>> data.dataset_train 252 | Dataset MNIST 253 | Number of datapoints: 60000 254 | Split: train 255 | Root Location: data 256 | Transforms (if any): Compose( 257 | Resize(size=32, interpolation=PIL.Image.BILINEAR) 258 | ToTensor() 259 | Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 260 | ) 261 | Target Transforms (if any): None 262 | >>> # We don't set test dataset, so they are the same. 263 | >>> data.dataset_valid is data.dataset_test 264 | True 265 | >>> # Number of steps at batch size 128. 266 | >>> data.nsteps_train 267 | 469 268 | >>> # Total samples of training datset. 269 | >>> len(data.dataset_train) 270 | 60000 271 | >>> # The batch size of sample load is 1. So, we get length of loader is equal to samples amount. 


class HandMNIST(DataLoadersFactory):
    """ Handwritten MNIST dataset.

    Example::

        >>> data = HandMNIST(r"../datasets/mnist")
        use 8 thread!
        Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
        Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
        Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
        Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
        Processing...
        Done!
        >>> data.dataset_train
        Dataset MNIST
            Number of datapoints: 60000
            Split: train
            Root Location: data
            Transforms (if any): Compose(
                                     Resize(size=32, interpolation=PIL.Image.BILINEAR)
                                     ToTensor()
                                     Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
                                 )
            Target Transforms (if any): None
        >>> # We don't set a test dataset, so it is the same as the validation dataset.
        >>> data.dataset_valid is data.dataset_test
        True
        >>> # Number of steps at batch size 128.
        >>> data.nsteps_train
        469
        >>> # Total samples of the training dataset.
        >>> len(data.dataset_train)
        60000
        >>> # The sample subset is loaded in a single batch, so its length equals the sample amount.
        >>> len(data.samples_train)
        6000

    """

    def __init__(self, root="datasets/hand_data", batch_size=64, num_workers=-1):
        super(HandMNIST, self).__init__(root, batch_size, num_workers)

    def build_datasets(self):
        """Build datasets by using ``datasets.MNIST`` in pytorch."""
        self.dataset_train = datasets.MNIST(self.root, train=True, download=True,
                                            transform=transforms.Compose(self.train_transform_list))
        self.dataset_valid = datasets.MNIST(self.root, train=False, download=True,
                                            transform=transforms.Compose(self.valid_transform_list))

    def build_transforms(self, resize: int = 32):
        self.train_transform_list = self.valid_transform_list = [
            transforms.Resize(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])]


class FashionMNIST(DataLoadersFactory):
    def __init__(self, root="datasets/fashion_data", batch_size=64, num_workers=-1):
        super(FashionMNIST, self).__init__(root, batch_size, num_workers)

    def build_datasets(self):
        self.dataset_train = datasets.FashionMNIST(self.root, train=True, download=True,
                                                   transform=transforms.Compose(self.train_transform_list))
        self.dataset_valid = datasets.FashionMNIST(self.root, train=False, download=True,
                                                   transform=transforms.Compose(self.valid_transform_list))

    def build_transforms(self, resize: int = 32):
        self.train_transform_list = self.valid_transform_list = [
            transforms.Resize(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])]


class Cifar10(DataLoadersFactory):
    def __init__(self, root="datasets/cifar10", batch_size=32, num_workers=-1):
        super(Cifar10, self).__init__(root, batch_size, num_workers)

    def build_datasets(self):
        self.dataset_train = datasets.CIFAR10(self.root, train=True, download=True,
                                              transform=transforms.Compose(self.train_transform_list))
        self.dataset_valid = datasets.CIFAR10(self.root, train=False, download=True,
                                              transform=transforms.Compose(self.valid_transform_list))


class Lsun(DataLoadersFactory):
    def __init__(self, root, batch_size=32, num_workers=-1):
        super(Lsun, self).__init__(root, batch_size, num_workers)

    def build_datasets(self):
        # ``datasets.LSUN`` can not download data automatically,
        # so the lmdb files must already exist under ``self.root``.
        self.dataset_train = datasets.LSUN(self.root, classes="train",
                                           transform=transforms.Compose(self.train_transform_list))
        self.dataset_valid = datasets.LSUN(self.root, classes="val",
                                           transform=transforms.Compose(self.valid_transform_list))

    def build_transforms(self, resize: int = 32):
        super(Lsun, self).build_transforms(resize)
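
# A minimal end-to-end sketch with one of the ready-made factories above (the path is
# hypothetical; instantiating a factory downloads the data, when supported, and builds
# the train/valid/test loaders)::
#
#     data = Cifar10(root=r"./datasets/cifar10", batch_size=32)
#     print(data.nsteps_train)             # training steps per epoch
#     inputs, targets = next(iter(data.loader_train))
#     print(inputs.shape, targets.shape)   # e.g. torch.Size([32, 3, 32, 32]) torch.Size([32])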


def get_mnist_dataloaders(root=r'..\data', batch_size=128):
    """MNIST dataloader with (32, 32) sized images."""
    # Resize images so they are a power of 2
    all_transforms = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        # transforms.Normalize([0.5], [0.5])
    ])
    # Get train and test data
    train_data = datasets.MNIST(root, train=True, download=True,
                                transform=all_transforms)
    test_data = datasets.MNIST(root, train=False,
                               transform=all_transforms)
    # Create dataloaders
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader


def get_fashion_mnist_dataloaders(root=r'.\dataset\fashion_data', batch_size=128, resize=32, transform_list=None,
                                  num_workers=-1):
    """Fashion MNIST dataloader with (32, 32) sized images."""
    # Use all cpu cores when `num_workers` is left at -1.
    if num_workers == -1:
        print("use %d thread!" % psutil.cpu_count())
        num_workers = psutil.cpu_count()
    # Resize images so they are a power of 2
    if transform_list is None:
        transform_list = [
            transforms.Resize(resize),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]
    all_transforms = transforms.Compose(transform_list)
    # Get train and test data
    train_data = datasets.FashionMNIST(root, train=True, download=True,
                                       transform=all_transforms)
    test_data = datasets.FashionMNIST(root, train=False,
                                      transform=all_transforms)
    # Create dataloaders
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True,
                              num_workers=num_workers)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True, drop_last=True,
                             num_workers=num_workers)
    return train_loader, test_loader


def get_lsun_dataloader(path_to_data='/data/dgl/LSUN', dataset='bedroom_train',
                        batch_size=64):
    """LSUN dataloader with (128, 128) sized images.

    path_to_data : str
        Path to the LSUN lmdb files.
    dataset : str
        One of 'bedroom_val' or 'bedroom_train'.
    """
    # Compose transforms
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor()
    ])

    # Get dataset
    lsun_dset = datasets.LSUN(root=path_to_data, classes=[dataset],
                              transform=transform)

    # Create dataloader
    return DataLoader(lsun_dset, batch_size=batch_size, shuffle=True)
--------------------------------------------------------------------------------