├── .gitignore ├── .travis.yml ├── LICENSE ├── Makefile ├── README.md ├── cortex ├── __init__.py ├── _lib │ ├── __init__.py │ ├── config.py │ ├── data │ │ ├── __init__.py │ │ ├── data_handler.py │ │ └── noise.py │ ├── exp.py │ ├── handlers.py │ ├── log_utils.py │ ├── models.py │ ├── optimizer.py │ ├── parsing.py │ ├── reg.py │ ├── train.py │ ├── utils.py │ ├── viz.py │ └── viz_utils.py ├── built_ins │ ├── __init__.py │ ├── datasets │ │ ├── CelebA.py │ │ ├── __init__.py │ │ ├── dSprites.py │ │ ├── imagenet.py │ │ ├── nii_dataload.py │ │ ├── torchvision_datasets.py │ │ ├── toysets.py │ │ └── utils.py │ ├── models │ │ ├── __init__.py │ │ ├── adversarial_autoencoder.py │ │ ├── ae.py │ │ ├── ali.py │ │ ├── classifier.py │ │ ├── discrete_gan.py │ │ ├── gan.py │ │ ├── image_coders.py │ │ ├── mine.py │ │ ├── utils.py │ │ └── vae.py │ ├── networks │ │ ├── SpectralNormLayer.py │ │ ├── __init__.py │ │ ├── ae_network.py │ │ ├── base_network.py │ │ ├── conv_decoders.py │ │ ├── convnets.py │ │ ├── fully_connected.py │ │ ├── modules.py │ │ ├── resnets.py │ │ ├── tv_models_wrapper.py │ │ └── utils.py │ └── transforms │ │ ├── __init__.py │ │ └── sobel.py ├── main.py └── plugins.py ├── demos ├── demo_classifier.py └── demo_custom_ae.py ├── docs.py ├── docs ├── .nojekyll ├── html │ ├── .buildinfo │ ├── .nojekyll │ ├── _modules │ │ ├── cortex │ │ │ ├── built_ins │ │ │ │ ├── datasets │ │ │ │ │ ├── CelebA.html │ │ │ │ │ ├── imagenet.html │ │ │ │ │ ├── nii_dataload.html │ │ │ │ │ ├── torchvision_datasets.html │ │ │ │ │ ├── toysets.html │ │ │ │ │ └── utils.html │ │ │ │ ├── models │ │ │ │ │ ├── adversarial_autoencoder.html │ │ │ │ │ ├── ae.html │ │ │ │ │ ├── ali.html │ │ │ │ │ ├── classifier.html │ │ │ │ │ ├── gan.html │ │ │ │ │ ├── image_coders.html │ │ │ │ │ ├── mine.html │ │ │ │ │ ├── utils.html │ │ │ │ │ └── vae.html │ │ │ │ ├── networks │ │ │ │ │ ├── SpectralNormLayer.html │ │ │ │ │ ├── ae_network.html │ │ │ │ │ ├── base_network.html │ │ │ │ │ ├── conv_decoders.html │ │ │ │ │ ├── convnets.html │ │ │ │ │ ├── fully_connected.html │ │ │ │ │ ├── modules.html │ │ │ │ │ ├── resnets.html │ │ │ │ │ ├── tv_models_wrapper.html │ │ │ │ │ └── utils.html │ │ │ │ └── transforms │ │ │ │ │ └── sobel.html │ │ │ ├── config.html │ │ │ ├── main.html │ │ │ └── plugins.html │ │ └── index.html │ ├── _static │ │ ├── ajax-loader.gif │ │ ├── basic.css │ │ ├── comment-bright.png │ │ ├── comment-close.png │ │ ├── comment.png │ │ ├── css │ │ │ ├── badge_only.css │ │ │ └── theme.css │ │ ├── doctools.js │ │ ├── documentation_options.js │ │ ├── down-pressed.png │ │ ├── down.png │ │ ├── file.png │ │ ├── fonts │ │ │ ├── Lato │ │ │ │ ├── lato-bold.eot │ │ │ │ ├── lato-bold.ttf │ │ │ │ ├── lato-bold.woff │ │ │ │ ├── lato-bold.woff2 │ │ │ │ ├── lato-bolditalic.eot │ │ │ │ ├── lato-bolditalic.ttf │ │ │ │ ├── lato-bolditalic.woff │ │ │ │ ├── lato-bolditalic.woff2 │ │ │ │ ├── lato-italic.eot │ │ │ │ ├── lato-italic.ttf │ │ │ │ ├── lato-italic.woff │ │ │ │ ├── lato-italic.woff2 │ │ │ │ ├── lato-regular.eot │ │ │ │ ├── lato-regular.ttf │ │ │ │ ├── lato-regular.woff │ │ │ │ └── lato-regular.woff2 │ │ │ ├── RobotoSlab │ │ │ │ ├── roboto-slab-v7-bold.eot │ │ │ │ ├── roboto-slab-v7-bold.ttf │ │ │ │ ├── roboto-slab-v7-bold.woff │ │ │ │ ├── roboto-slab-v7-bold.woff2 │ │ │ │ ├── roboto-slab-v7-regular.eot │ │ │ │ ├── roboto-slab-v7-regular.ttf │ │ │ │ ├── roboto-slab-v7-regular.woff │ │ │ │ └── roboto-slab-v7-regular.woff2 │ │ │ ├── fontawesome-webfont.eot │ │ │ ├── fontawesome-webfont.svg │ │ │ ├── fontawesome-webfont.ttf │ │ │ ├── fontawesome-webfont.woff 
│ │ │ └── fontawesome-webfont.woff2 │ │ ├── jquery-3.2.1.js │ │ ├── jquery.js │ │ ├── js │ │ │ ├── modernizr.min.js │ │ │ └── theme.js │ │ ├── minus.png │ │ ├── plus.png │ │ ├── pygments.css │ │ ├── searchtools.js │ │ ├── underscore-1.3.1.js │ │ ├── underscore.js │ │ ├── up-pressed.png │ │ ├── up.png │ │ └── websupport.js │ ├── build.html │ ├── cortex.built_ins.datasets.html │ ├── cortex.built_ins.html │ ├── cortex.built_ins.models.html │ ├── cortex.built_ins.networks.html │ ├── cortex.built_ins.transforms.html │ ├── cortex.html │ ├── develop.html │ ├── genindex.html │ ├── getting_started.html │ ├── index.html │ ├── install.html │ ├── modules.html │ ├── py-modindex.html │ ├── search.html │ └── searchindex.js ├── index.html └── source │ ├── build.rst │ ├── conf.py │ ├── cortex.rst │ ├── develop.rst │ ├── getting_started.rst │ ├── index.rst │ ├── install.rst │ └── modules.rst ├── make.bat ├── scripts └── nearest-neighbor.py ├── setup.py ├── tests ├── built_ins │ ├── models │ │ ├── test_build.py │ │ ├── test_loop.py │ │ └── test_routine.py │ └── networks │ │ ├── test_base_network.py │ │ ├── test_convnets.py │ │ ├── test_fully_connected.py │ │ └── test_network_utils.py ├── conftest.py ├── test_argparsing.py ├── test_handlers.py └── test_optimizer.py ├── tox.ini ├── travis-config.py └── travis-output.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | config.yaml 3 | 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .pytest_cache/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *,cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # IPython Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # dotenv 81 | .env 82 | 83 | # virtualenv 84 | venv/ 85 | ENV/ 86 | 87 | # Spyder project settings 88 | .spyderproject 89 | 90 | # Rope project settings 91 | .ropeproject 92 | .idea/ 93 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | before_install: 5 | - "pip install pyyaml" 6 | - "pip install sphinxcontrib-napoleon" 7 | - "python travis-config.py" 8 | install: 9 | - "bash travis-output.sh" 10 | - "pip install flake8" 11 | script: 12 | - "flake8 --ignore=C901" 13 | - "pytest" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Montreal Institute for Learning Algorithms (MILA) 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = Cortex 8 | SOURCEDIR = docs/source 9 | BUILDDIR = docs 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /cortex/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | '''Init file for cortex. 3 | 4 | ''' 5 | 6 | from cortex.built_ins.datasets import * 7 | from cortex.built_ins.models import * 8 | -------------------------------------------------------------------------------- /cortex/_lib/__init__.py: -------------------------------------------------------------------------------- 1 | '''Cortex setup 2 | 3 | ''' 4 | 5 | import copy 6 | import glob 7 | import logging 8 | import os 9 | import pprint 10 | 11 | from . import config, exp, log_utils, models 12 | from .parsing import default_args, parse_args, update_args 13 | from .viz import init as viz_init 14 | 15 | __author__ = 'R Devon Hjelm' 16 | __author_email__ = 'erroneus@gmail.com' 17 | 18 | logger = logging.getLogger('cortex.init') 19 | 20 | 21 | def setup_cortex(model=None): 22 | '''Sets up cortex 23 | 24 | Finds all the models in cortex, parses the command line, and sets the 25 | logger. 26 | 27 | Returns: 28 | TODO 29 | 30 | ''' 31 | args = parse_args(models.MODEL_PLUGINS, model=model) 32 | 33 | log_utils.set_stream_logger(args.verbosity) 34 | 35 | return args 36 | 37 | 38 | def find_autoreload(out_path, global_out_path, name): 39 | out_path = out_path or global_out_path 40 | out_path = os.path.join(out_path, name) 41 | binary_dir = os.path.join(out_path, 'binaries') 42 | binaries = glob.glob(os.path.join(binary_dir, '*.t7')) 43 | binaries.sort(key=os.path.getmtime) 44 | 45 | if len(binaries) > 0: 46 | return binaries[-1] 47 | else: 48 | logger.warning('No model found to auto-reload') 49 | return None 50 | 51 | 52 | def setup_experiment(args, model=None, testmode=False): 53 | '''Sets up the experiment 54 | 55 | Args: 56 | args: TODO 57 | 58 | ''' 59 | 60 | def update_nested_dicts(from_d, to_d): 61 | for k, v in from_d.items(): 62 | if (k in to_d) and isinstance(to_d[k], dict): 63 | if not isinstance(v, dict): 64 | raise ValueError('Updating dict entry with non-dict.') 65 | update_nested_dicts(v, to_d[k]) 66 | else: 67 | to_d[k] = v 68 | 69 | exp.setup_device(args.device) 70 | 71 | if model is None: 72 | model_name = args.command 73 | model = models.get_model(model_name) 74 | else: 75 | model_name = model.__class__.__name__ 76 | 77 | experiment_args = copy.deepcopy(default_args) 78 | update_args(experiment_args, exp.ARGS) 79 | 80 | if not testmode: 81 | viz_init(config.CONFIG.viz) 82 | 83 | for k, v in vars(args).items(): 84 | if v is not None: 85 | if '.' 
in k: 86 | head, tail = k.split('.') 87 | elif k in model.kwargs: 88 | head = 'model' 89 | tail = k 90 | else: 91 | continue 92 | exp.ARGS[head][tail] = v 93 | 94 | reload_nets = None 95 | 96 | if args.autoreload: 97 | reload_path = find_autoreload(args.out_path, config.CONFIG.out_path, 98 | args.name or model_name) 99 | elif args.reload: 100 | reload_path = args.reload 101 | else: 102 | reload_path = None 103 | 104 | if reload_path: 105 | d = exp.reload_model(reload_path) 106 | exp.INFO.update(**d['info']) 107 | exp.NAME = exp.INFO['name'] 108 | exp.SUMMARY.update(**d['summary']) 109 | update_nested_dicts(d['args'], exp.ARGS) 110 | 111 | if args.name: 112 | exp.INFO['name'] = exp.NAME 113 | if args.out_path or args.name: 114 | exp.setup_out_dir(args.out_path, config.CONFIG.out_path, exp.NAME, 115 | clean=args.clean) 116 | else: 117 | exp.OUT_DIRS.update(**d['out_dirs']) 118 | 119 | reload_nets = d['nets'] 120 | else: 121 | if args.load_networks: 122 | d = exp.reload_model(args.load_networks) 123 | keys = args.networks_to_reload or d['nets'] 124 | for key in keys: 125 | if key not in d['nets']: 126 | raise KeyError('Model {} has no network called {}' 127 | .format(args.load_networks, key)) 128 | reload_nets = dict((k, d['nets'][k]) for k in keys) 129 | 130 | exp.NAME = args.name or model_name 131 | exp.INFO['name'] = exp.NAME 132 | exp.setup_out_dir(args.out_path, config.CONFIG.out_path, exp.NAME, 133 | clean=args.clean) 134 | 135 | update_nested_dicts(exp.ARGS['model'], model.kwargs) 136 | exp.ARGS['model'].update(**model.kwargs) 137 | 138 | exp.configure_from_yaml(config_file=args.config_file) 139 | 140 | for k, v in exp.ARGS.items(): 141 | logger.info('Ultimate {} arguments: \n{}' 142 | .format(k, pprint.pformat(v))) 143 | 144 | return model, reload_nets 145 | -------------------------------------------------------------------------------- /cortex/_lib/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Data module""" 2 | 3 | import logging 4 | 5 | from .data_handler import DataHandler 6 | 7 | __author__ = 'R Devon Hjelm' 8 | __author_email__ = 'erroneus@gmail.com' 9 | 10 | logger = logging.getLogger('cortex.data') 11 | 12 | DATA_HANDLER = DataHandler() 13 | _PLUGINS = {} 14 | 15 | 16 | def setup(source: str=None, batch_size=64, n_workers: int=4, 17 | skip_last_batch: bool=False, inputs=dict(), 18 | copy_to_local: bool=False, data_args={}, shuffle: bool=True): 19 | """ 20 | Dataset entrypoint. 21 | 22 | Args: 23 | source: Dataset source or list of sources. 24 | batch_size: Batch size or dict of batch sizes. 25 | noise_variables: Dict of noise variables. 26 | n_workers: Number of workers for DataLoader class. 27 | skip_last_batch: Whether to skip the last batch if the size 28 | is smaller than batch_size. 29 | inputs: Dictionary of input mappings. 30 | copy_to_local: Copy the data to a local path. 31 | data_args: Arguments for dataset plugin. 32 | shuffle: Shuffle the dataset. 33 | 34 | """ 35 | global DATA_HANDLER 36 | 37 | if source and not isinstance(source, (list, tuple)): 38 | sources = [source] 39 | else: 40 | sources = source 41 | 42 | DATA_HANDLER.set_batch_size(batch_size, skip_last_batch=skip_last_batch) 43 | DATA_HANDLER.set_inputs(**inputs) 44 | 45 | if sources: 46 | for source in sources: 47 | # TODO: Hardcoded for testing purpose. 48 | if not isinstance(source, str): 49 | source = 'CIFAR10' 50 | plugin = _PLUGINS.get(source, None) 51 | if plugin is None: 52 | raise KeyError('Dataset plugin for `{}` not found.' 
53 | ' Available: {}' 54 | .format(source, tuple(_PLUGINS.keys()))) 55 | 56 | plugin.handle(source, copy_to_local=copy_to_local, **data_args) 57 | DATA_HANDLER.add_dataset(source, plugin, n_workers=n_workers, 58 | shuffle=shuffle) 59 | else: 60 | raise ValueError('No source provided. Use `--d.source`') 61 | 62 | 63 | def register(plugin): 64 | global _PLUGINS 65 | plugin = plugin() 66 | 67 | for k in plugin.sources: 68 | if k in _PLUGINS: 69 | raise KeyError('`{}` already registered in a plugin. ' 70 | 'Try using a different name.'.format(k)) 71 | _PLUGINS[k] = plugin 72 | 73 | 74 | class DatasetPluginBase: 75 | def __init__(self): 76 | if len(self.sources) == 0: 77 | raise ValueError('No sources found for dataset entry point.') 78 | 79 | self._datasets = {} 80 | self._dims = {} 81 | self._input_names = None 82 | self._scale = None 83 | self._dataloader_class = None 84 | -------------------------------------------------------------------------------- /cortex/_lib/data/noise.py: -------------------------------------------------------------------------------- 1 | '''Module for handling noise. 2 | 3 | ''' 4 | 5 | import torch 6 | import torch.distributions as tdist 7 | 8 | 9 | _dist_dict = dict( 10 | bernoulli=tdist.bernoulli.Bernoulli, 11 | beta=tdist.beta.Beta, 12 | binomial=tdist.binomial.Binomial, 13 | categorical=tdist.categorical.Categorical, 14 | cauchy=tdist.cauchy.Cauchy, 15 | chi2=tdist.chi2.Chi2, 16 | dirichlet=tdist.dirichlet.Dirichlet, 17 | exponential=tdist.exponential.Exponential, 18 | fishersnedecor=tdist.fishersnedecor.FisherSnedecor, 19 | gamma=tdist.gamma.Gamma, 20 | geometric=tdist.geometric.Geometric, 21 | gumbel=tdist.gumbel.Gumbel, 22 | laplace=tdist.laplace.Laplace, 23 | log_normal=tdist.log_normal.LogNormal, 24 | multinomial=tdist.multinomial.Multinomial, 25 | multivariate_normal=tdist.multivariate_normal.MultivariateNormal, 26 | normal=tdist.normal.Normal, 27 | one_hot_categorical=tdist.one_hot_categorical.OneHotCategorical, 28 | pareto=tdist.pareto.Pareto, 29 | poisson=tdist.poisson.Poisson, 30 | relaxed_bernoulli=tdist.relaxed_bernoulli.RelaxedBernoulli, 31 | relaxed_categorical=tdist.relaxed_categorical.RelaxedOneHotCategorical, 32 | studentT=tdist.studentT.StudentT, 33 | uniform=tdist.uniform.Uniform 34 | ) 35 | 36 | 37 | def get_noise_var(dist, size, **kwargs): 38 | 39 | def expand(*args): 40 | expanded = tuple() 41 | for arg in args: 42 | zeros = torch.zeros(size) 43 | if isinstance(arg, list): 44 | arg_tensor = torch.tensor(arg) 45 | expanded += (zeros + arg_tensor,) 46 | else: 47 | expanded += (zeros + arg,) 48 | return expanded 49 | 50 | Dist = _dist_dict.get(dist) 51 | if Dist is None: 52 | raise NotImplementedError(dist) 53 | 54 | if dist == 'dirichlet': 55 | conc = kwargs.pop('concentration', 1.) 56 | conc, = expand(conc) 57 | var = Dist(conc, **kwargs) 58 | elif dist in ('cauchy', 'gumbel', 'laplace', 'log_normal', 'normal'): 59 | loc = kwargs.pop('loc', 0.) 60 | scale = kwargs.pop('scale', 1.) 61 | loc, scale = expand(loc, scale) 62 | var = Dist(loc, scale, **kwargs) 63 | elif dist == 'uniform': 64 | low = kwargs.pop('low', 0.) 65 | high = kwargs.pop('high', 1.) 66 | low, high = expand(low, high) 67 | var = Dist(low, high, **kwargs) 68 | else: 69 | raise NotImplementedError('`{}` distribution not found'.format(dist)) 70 | 71 | return var 72 | -------------------------------------------------------------------------------- /cortex/_lib/exp.py: -------------------------------------------------------------------------------- 1 | '''Experiment module.
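A minimal usage sketch for `get_noise_var` from noise.py above; the size and distribution parameters here are illustrative, not taken from the repo:

    dist = get_noise_var('normal', (64, 128), loc=0., scale=1.)
    z = dist.sample()  # a (64, 128) tensor of N(0, 1) noise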
2 | 3 | Used for saving, loading, summarizing, etc 4 | 5 | ''' 6 | 7 | import logging 8 | import os 9 | from os import path 10 | from shutil import copyfile, rmtree 11 | import yaml 12 | 13 | import torch 14 | 15 | from .log_utils import set_file_logger 16 | 17 | __author__ = 'R Devon Hjelm' 18 | __author_email__ = 'erroneus@gmail.com' 19 | 20 | logger = logging.getLogger('cortex.exp') 21 | 22 | # Experiment info 23 | NAME = 'X' 24 | SUMMARY = {'train': {}, 'test': {}} 25 | OUT_DIRS = {} 26 | ARGS = dict(data=dict(), model=dict(), optimizer=dict(), train=dict()) 27 | INFO = {'name': NAME, 'epoch': 0} 28 | DEVICE = torch.device('cpu') 29 | 30 | 31 | def _file_string(prefix=''): 32 | if prefix == '': 33 | return NAME 34 | return '{}_{}'.format(NAME, prefix) 35 | 36 | 37 | def configure_from_yaml(config_file=None): 38 | '''Loads arguments from a yaml file. 39 | 40 | ''' 41 | global ARGS 42 | 43 | if config_file is not None: 44 | with open(config_file, 'r') as f: 45 | d = yaml.load(f) 46 | logger.info('Loading config {}'.format(d)) 47 | ARGS['model'].update(**d.get('builds', {})) 48 | ARGS['optimizer'].update(**d.get('optimizer', {})) 49 | ARGS['train'].update(**d.get('train', {})) 50 | ARGS['data'].update(**d.get('data', {})) 51 | 52 | 53 | def reload_model(model_to_reload): 54 | if not path.isfile(model_to_reload): 55 | raise ValueError('Cannot find {}'.format(model_to_reload)) 56 | 57 | logger.info('Reloading from {} and creating backup'.format(model_to_reload)) 58 | copyfile(model_to_reload, model_to_reload + '.bak') 59 | 60 | return torch.load(model_to_reload, map_location='cpu') 61 | 62 | 63 | def save(model, prefix=''): 64 | '''Saves a model. 65 | 66 | Args: 67 | model: Model to save. 68 | prefix: Prefix for the save file. 69 | 70 | ''' 71 | prefix = _file_string(prefix) 72 | binary_dir = OUT_DIRS.get('binary_dir', None) 73 | if binary_dir is None: 74 | return 75 | 76 | def strip_Nones(d): 77 | d_ = {} 78 | for k, v in d.items(): 79 | if isinstance(v, dict): 80 | d_[k] = strip_Nones(v) 81 | elif v is not None: 82 | d_[k] = v 83 | return d_ 84 | 85 | for net in model.nets.values(): 86 | if hasattr(net, 'states'): 87 | net.states.clear() 88 | 89 | state = dict( 90 | nets=dict(model.nets), 91 | info=INFO, 92 | args=ARGS, 93 | out_dirs=OUT_DIRS, 94 | summary=SUMMARY 95 | ) 96 | 97 | file_path = path.join(binary_dir, '{}.t7'.format(prefix)) 98 | logger.info('Saving checkpoint {}'.format(file_path)) 99 | torch.save(state, file_path) 100 | 101 | 102 | def setup_out_dir(out_path, global_out_path, name=None, clean=False): 103 | '''Sets up the output directory of an experiment.
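For reference, the checkpoint that `save` above writes (and that `reload_model` returns) is a plain dict; a sketch, with an illustrative path:

    state = reload_model('/some/out_path/binaries/X.t7')
    sorted(state)            # ['args', 'info', 'nets', 'out_dirs', 'summary']
    state['info']['epoch']   # epoch counter kept in INFO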
104 | 105 | ''' 106 | global OUT_DIRS 107 | 108 | if out_path is None: 109 | if name is None: 110 | raise ValueError('If `out_path` (-o) argument is not set, you ' 111 | 'must set the `name` (-n)') 112 | out_path = global_out_path 113 | if out_path is None: 114 | raise ValueError('If `--out_path` (`-o`) argument is not set, you ' 115 | 'must set both the name argument and configure ' 116 | 'the out_path entry in `config.yaml`') 117 | 118 | if name is not None: 119 | out_path = path.join(out_path, name) 120 | 121 | if not path.isdir(out_path): 122 | logger.info('Creating out path `{}`'.format(out_path)) 123 | os.mkdir(out_path) 124 | 125 | binary_dir = path.join(out_path, 'binaries') 126 | image_dir = path.join(out_path, 'images') 127 | 128 | if clean: 129 | logger.warning('Cleaning directory (cannot be undone)') 130 | if path.isdir(binary_dir): 131 | rmtree(binary_dir) 132 | if path.isdir(image_dir): 133 | rmtree(image_dir) 134 | 135 | if not path.isdir(binary_dir): 136 | os.mkdir(binary_dir) 137 | if not path.isdir(image_dir): 138 | os.mkdir(image_dir) 139 | 140 | logger.info('Setting out path to `{}`'.format(out_path)) 141 | logger.info('Logging to `{}`'.format(path.join(out_path, 'out.log'))) 142 | set_file_logger(path.join(out_path, 'out.log')) 143 | 144 | OUT_DIRS.update(binary_dir=binary_dir, image_dir=image_dir) 145 | 146 | 147 | def setup_device(device): 148 | global DEVICE 149 | if torch.cuda.is_available() and device != 'cpu': 150 | if device < torch.cuda.device_count(): 151 | logger.info('Using GPU {}'.format(device)) 152 | DEVICE = torch.device(device) 153 | else: 154 | logger.info('GPU {} doesn\'t exist. Using CPU'.format(device)) 155 | else: 156 | logger.info('Using CPU') 157 | -------------------------------------------------------------------------------- /cortex/_lib/log_utils.py: -------------------------------------------------------------------------------- 1 | '''Module for logging 2 | 3 | ''' 4 | 5 | import logging 6 | 7 | __author__ = 'R Devon Hjelm' 8 | __author_email__ = 'erroneus@gmail.com' 9 | 10 | logging.basicConfig() 11 | logger = logging.getLogger('cortex') 12 | logger.propagate = False 13 | 14 | file_formatter = logging.Formatter( 15 | '%(asctime)s:%(name)s[%(levelname)s]: %(message)s\n') 16 | stream_formatter = logging.Formatter( 17 | '[%(levelname)s:%(name)s]: %(message)s' + ' ' * 40 + '\n') 18 | 19 | 20 | def set_stream_logger(verbosity): 21 | global logger 22 | 23 | if verbosity == 0: 24 | level = logging.WARNING 25 | lstr = 'WARNING' 26 | elif verbosity == 1: 27 | level = logging.INFO 28 | lstr = 'INFO' 29 | elif verbosity == 2: 30 | level = logging.DEBUG 31 | lstr = 'DEBUG' 32 | else: 33 | level = logging.INFO 34 | lstr = 'INFO' 35 | logger.setLevel(level) 36 | ch = logging.StreamHandler() 37 | ch.terminator = '' 38 | ch.setLevel(level) 39 | ch.setFormatter(stream_formatter) 40 | logger.addHandler(ch) 41 | logger.info('Setting logging to %s' % lstr) 42 | 43 | 44 | def set_file_logger(file_path): 45 | global logger 46 | fh = logging.FileHandler(file_path) 47 | fh.setLevel(logging.DEBUG) 48 | fh.setFormatter(file_formatter) 49 | logger.addHandler(fh) 50 | fh.terminator = '' 51 | -------------------------------------------------------------------------------- /cortex/_lib/optimizer.py: -------------------------------------------------------------------------------- 1 | '''Module for setting up the optimizer.
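A minimal sketch of how the two helpers in log_utils.py above combine; the log-file path is illustrative:

    set_stream_logger(2)               # verbosity 2 -> DEBUG on the console
    set_file_logger('/tmp/out.log')    # DEBUG records also go to the file
    logger.info('experiment started')  # reaches both handlers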
2 | 3 | ''' 4 | 5 | from collections import defaultdict 6 | import logging 7 | 8 | import torch 9 | import torch.optim as optim 10 | import torch.backends.cudnn as cudnn 11 | 12 | from . import exp 13 | 14 | 15 | __author__ = 'R Devon Hjelm' 16 | __author_email__ = 'erroneus@gmail.com' 17 | 18 | logger = logging.getLogger('cortex.optimizer') 19 | OPTIMIZERS = {} 20 | 21 | _optimizer_defaults = dict( 22 | SGD=dict(), 23 | Adam=dict(betas=(0.5, 0.999)) 24 | ) 25 | 26 | 27 | def wrap_optimizer(C): 28 | class Op(C): 29 | def __init__(self, params, clipping=None, **kwargs): 30 | super().__init__(params, **kwargs) 31 | 32 | if clipping is not None and clipping < 0.0: 33 | raise ValueError( 34 | "Invalid clipping value: {}".format(clipping)) 35 | 36 | self.defaults.update(clipping=clipping) 37 | 38 | self.state = defaultdict(dict) 39 | self.param_groups = [] 40 | 41 | param_groups = list(params) 42 | if len(param_groups) == 0: 43 | raise ValueError("optimizer got an empty parameter list") 44 | if not isinstance(param_groups[0], dict): 45 | param_groups = [{'params': param_groups}] 46 | 47 | for param_group in param_groups: 48 | self.add_param_group(param_group) 49 | 50 | def step(self, closure=None): 51 | """Performs a single optimization step. 52 | 53 | Arguments: 54 | closure (callable, optional): A closure that reevaluates the model 55 | and returns the loss. 56 | """ 57 | loss = super().step(closure=closure) 58 | 59 | for group in self.param_groups: 60 | bound = group['clipping'] 61 | if bound: 62 | for p in group['params']: 63 | p.data.clamp_(-bound, bound) 64 | return loss 65 | 66 | return Op 67 | 68 | 69 | def setup(model, optimizer='Adam', learning_rate=1.e-4, 70 | weight_decay={}, clipping={}, optimizer_options={}, 71 | model_optimizer_options={}): 72 | '''Optimizer entrypoint. 73 | 74 | Args: 75 | optimizer: Optimizer type. See `torch.optim` for supported optimizers. 76 | learning_rate: Learning rate. 77 | updates_per_routine: Updates per routine. 78 | clipping: If set, this is the clipping for each model. 79 | weight_decay: If set, this is the weight decay for specified model. 80 | optimizer_options: Optimizer options. 81 | model_optimizer_options: Optimizer options for specified model. 82 | 83 | ''' 84 | 85 | OPTIMIZERS.clear() 86 | model_optimizer_options = model_optimizer_options or {} 87 | weight_decay = weight_decay or {} 88 | clipping = clipping or {} 89 | 90 | # Set the optimizer options 91 | if len(optimizer_options) == 0: 92 | optimizer_options = 'default' 93 | if not isinstance(optimizer, str): 94 | optimizer = 'Adam' 95 | if optimizer_options == 'default'\ 96 | and optimizer in _optimizer_defaults.keys(): 97 | optimizer_options = _optimizer_defaults[optimizer] 98 | elif optimizer_options == 'default': 99 | raise ValueError( 100 | 'Default optimizer options for' 101 | ' `{}` not available.'.format(optimizer)) 102 | 103 | # Set the optimizers 104 | if callable(optimizer): 105 | op = optimizer 106 | elif hasattr(optim, optimizer): 107 | op = getattr(optim, optimizer) 108 | else: 109 | raise NotImplementedError( 110 | 'Optimizer not supported `{}`'.format(optimizer)) 111 | 112 | for network_key, network in model.nets.items(): 113 | # Set model parameters to cpu or gpu 114 | network.to(exp.DEVICE) 115 | # TODO(Devon): is the next line really doing anything? 
116 | if str(exp.DEVICE) == 'cpu': 117 | pass 118 | else: 119 | torch.nn.DataParallel( 120 | network, device_ids=range( 121 | torch.cuda.device_count())) 122 | 123 | model._reset_epoch() 124 | model.data.reset(make_pbar=False, mode='test') 125 | model.train_step(_init=True) 126 | model.visualize(auto_input=True) 127 | 128 | training_nets = model._get_training_nets() 129 | 130 | logger.info('Setting up optimizers for {}'.format(set(training_nets))) 131 | 132 | for network_key in set(training_nets): 133 | logger.debug('Building optimizer for {}'.format(network_key)) 134 | network = model.nets[network_key] 135 | 136 | if isinstance(network, (tuple, list)): 137 | params = [] 138 | for net in network: 139 | params += list(net.parameters()) 140 | else: 141 | params = list(network.parameters()) 142 | 143 | # Needed for reloading. 144 | for p in params: 145 | p.requires_grad = True 146 | 147 | # Learning rates 148 | if isinstance(learning_rate, dict): 149 | eta = learning_rate[network_key] 150 | else: 151 | eta = learning_rate 152 | 153 | # Weight decay 154 | if isinstance(weight_decay, dict): 155 | wd = weight_decay.get(network_key, 0) 156 | else: 157 | wd = weight_decay 158 | 159 | if isinstance(clipping, dict): 160 | cl = clipping.get(network_key, None) 161 | else: 162 | cl = clipping 163 | 164 | # Update the optimizer options 165 | optimizer_options_ = dict((k, v) for k, v in optimizer_options.items()) 166 | optimizer_options_.update(weight_decay=wd, clipping=cl, lr=eta) 167 | 168 | if network_key in model_optimizer_options.keys(): 169 | optimizer_options_.update(**model_optimizer_options[network_key]) 170 | 171 | # Create the optimizer (don't rebind `op`, so repeated iterations don't nest wrappers) 172 | wrapped_op = wrap_optimizer(op) 173 | 174 | optimizer = wrapped_op(params, **optimizer_options_) 175 | OPTIMIZERS[network_key] = optimizer 176 | 177 | logger.debug( 178 | 'Training {} routine with {}'.format( 179 | network_key, optimizer)) 180 | 181 | if not exp.DEVICE == torch.device('cpu'): 182 | cudnn.benchmark = True 183 | -------------------------------------------------------------------------------- /cortex/_lib/reg.py: -------------------------------------------------------------------------------- 1 | from .
import models 2 | 3 | __author__ = 'Bradley Baker' 4 | __author_email__ = 'bbaker@mrn.org' 5 | 6 | ''' CLIPPING is a global dictionary of clipping boundaries, 7 | keyed by model name''' 8 | CLIPPING = {} 9 | 10 | ''' REGULARIZER is a global dictionary of floats, the 11 | lambda scaling factor for L1 regularization 12 | keyed by model name''' 13 | L1_DECAY = {} 14 | 15 | 16 | def init(clipping=None, weight_decay=None): 17 | '''called in setup.py, initialize clipping and 18 | weight_decay dicts''' 19 | global CLIPPING, L1_DECAY 20 | clipping = clipping or {} 21 | weight_decay = weight_decay or {} 22 | CLIPPING.update(**clipping) 23 | L1_DECAY.update(**weight_decay) 24 | 25 | 26 | def clip(key): 27 | ''' 28 | called in train.py, clip weights 29 | ''' 30 | if key not in CLIPPING: 31 | return 32 | bound = CLIPPING[key] 33 | if key in models.MODEL_HANDLER: 34 | model = models.MODEL_HANDLER[key] 35 | if isinstance(model, (list, tuple)): 36 | for net in model: 37 | for p in net.parameters(): 38 | p.data.clamp_(-bound, bound) 39 | else: 40 | for p in model.parameters(): 41 | p.data.clamp_(-bound, bound) 42 | 43 | 44 | def l1_decay(key): 45 | ''' 46 | called in train.py, do L1 regularization (in-place weight decay) 47 | ''' 48 | if key not in L1_DECAY: 49 | return 50 | factor = L1_DECAY[key] 51 | if key in models.MODEL_HANDLER: 52 | model = models.MODEL_HANDLER[key] 53 | if isinstance(model, (list, tuple)): 54 | for net in model: 55 | for p in net.parameters(): 56 | p.data.add_(-factor * (p.data / p.data.norm(1))) 57 | else: 58 | for p in model.parameters(): 59 | p.data.add_(-factor * (p.data / p.data.norm(1))) 60 | -------------------------------------------------------------------------------- /cortex/_lib/utils.py: -------------------------------------------------------------------------------- 1 | '''Utility methods 2 | 3 | ''' 4 | 5 | import logging 6 | import os 7 | 8 | import numpy as np 9 | import torch 10 | 11 | __author__ = 'R Devon Hjelm' 12 | __author_email__ = 'erroneus@gmail.com' 13 | 14 | logger = logging.getLogger('cortex.util') 15 | 16 | try: 17 | _, _columns = os.popen('stty size', 'r').read().split() 18 | _columns = int(_columns) 19 | except ValueError: 20 | _columns = 1 21 | 22 | 23 | def print_section(s): 24 | '''For printing sections to scripts nicely. 25 | Args: 26 | s (str): string of section 27 | ''' 28 | h = s + ('-' * (_columns - len(s))) 29 | print(h) 30 | 31 | 32 | def update_dict_of_lists(d_to_update, **d): 33 | '''Updates a dict of lists with kwargs. 34 | 35 | Args: 36 | d_to_update (dict): dictionary of lists. 37 | **d: keyword arguments to append.
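For example (a sketch of the behavior implemented below):

    d = {}
    update_dict_of_lists(d, loss=0.5, acc=0.90)
    update_dict_of_lists(d, loss=0.4, acc=0.95)
    # d == {'loss': [0.5, 0.4], 'acc': [0.90, 0.95]}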
38 | 39 | ''' 40 | for k, v in d.items(): 41 | if isinstance(v, dict): 42 | if k not in d_to_update.keys(): 43 | d_to_update[k] = {} 44 | update_dict_of_lists(d_to_update[k], **v) 45 | elif k in d_to_update.keys(): 46 | d_to_update[k].append(v) 47 | else: 48 | d_to_update[k] = [v] 49 | 50 | 51 | def bad_values(d): 52 | failed = {} 53 | for k, v in d.items(): 54 | if isinstance(v, dict): 55 | v_ = bad_values(v) 56 | if v_: 57 | failed[k] = v_ 58 | else: 59 | if isinstance(v, (list, tuple)): 60 | v_ = [] 61 | for v__ in v: 62 | if isinstance(v__, torch.Tensor): 63 | v_.append(v__.item()) 64 | else: 65 | v_.append(v__) 66 | v_ = np.array(v_).sum() 67 | elif isinstance(v, torch.Tensor): 68 | v_ = v.item() 69 | else: 70 | v_ = v 71 | if np.isnan(v_) or np.isinf(v_): 72 | failed[k] = v_ 73 | 74 | if len(failed) == 0: 75 | return False 76 | return failed 77 | 78 | 79 | def convert_to_numpy(o): 80 | if isinstance(o, torch.Tensor): 81 | o = o.data.cpu().numpy() 82 | if len(o.shape) == 1 and o.shape[0] == 1: 83 | o = o[0] 84 | elif isinstance(o, (torch.cuda.FloatTensor, torch.cuda.LongTensor)): 85 | o = o.cpu().numpy() 86 | elif isinstance(o, list): 87 | for i in range(len(o)): 88 | o[i] = convert_to_numpy(o[i]) 89 | elif isinstance(o, tuple): 90 | o_ = tuple() 91 | for i in range(len(o)): 92 | o_ = o_ + (convert_to_numpy(o[i]),) 93 | o = o_ 94 | elif isinstance(o, dict): 95 | for k in o.keys(): 96 | o[k] = convert_to_numpy(o[k]) 97 | return o 98 | 99 | 100 | def compute_tsne(X, perplexity=40, n_iter=300, init='pca'): 101 | from sklearn.manifold import TSNE 102 | 103 | tsne = TSNE(2, perplexity=perplexity, n_iter=n_iter, init=init) 104 | points = X.tolist() 105 | return tsne.fit_transform(points) 106 | -------------------------------------------------------------------------------- /cortex/_lib/viz_utils.py: -------------------------------------------------------------------------------- 1 | ''' This file contains different utility functions that are not connected 2 | in anyway to the networks presented in the tutorials, but rather help in 3 | processing the outputs into a more understandable way. 4 | For example ``tile_raster_images`` helps in generating a easy to grasp 5 | image from a set of samples or weights. 6 | ''' 7 | 8 | import numpy 9 | 10 | 11 | def scale_to_unit_interval(ndar, eps=1e-6): 12 | ''' Scales all values in the ndarray ndar to be between 0 and 1 ''' 13 | ndar = ndar.copy() 14 | ndar -= ndar.min() 15 | try: 16 | ndar *= 1.0 / (ndar.max() + eps) 17 | except FloatingPointError: 18 | pass 19 | return ndar 20 | 21 | 22 | def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0), 23 | scale_rows_to_unit_interval=True, 24 | output_pixel_vals=True, 25 | bottom_margin=0, right_margin=0): 26 | ''' 27 | Transform an array with one flattened image per row, into an array in 28 | which images are reshaped and layed out like tiles on a floor. 29 | This function is useful for visualizing datasets whose rows are images, 30 | and also columns of matrices for transforming those rows 31 | (such as the first layer of a neural net). 32 | :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can 33 | be 2-D ndarrays or None; 34 | :param X: a 2-D array in which every row is a flattened image. 35 | :type img_shape: tuple; (height, width) 36 | :param img_shape: the original shape of each image 37 | :type tile_shape: tuple; (rows, cols) 38 | :param tile_shape: the number of images to tile (rows, cols) 39 | :param output_pixel_vals: if output should be pixel values (i.e. 
int8 40 | values) or floats 41 | :param scale_rows_to_unit_interval: if the values need to be scaled before 42 | being plotted to [0,1] or not 43 | :returns: array suitable for viewing as an image. 44 | (See:`PIL.Image.fromarray`.) 45 | :rtype: a 2-d array with same dtype as X. 46 | ''' 47 | 48 | assert len(img_shape) == 2 49 | assert len(tile_shape) == 2 50 | assert len(tile_spacing) == 2 51 | 52 | # The expression below can be re-written in a more C style as 53 | # follows : 54 | # 55 | # out_shape = [0,0] 56 | # out_shape[0] = (img_shape[0]+tile_spacing[0])*tile_shape[0] - 57 | # tile_spacing[0] 58 | # out_shape[1] = (img_shape[1]+tile_spacing[1])*tile_shape[1] - 59 | # tile_spacing[1] 60 | out_shape = [(ishp + tsp) * tshp - tsp for ishp, tshp, tsp 61 | in zip(img_shape, tile_shape, tile_spacing)] 62 | 63 | out_shape = (out_shape[0] + bottom_margin, out_shape[1] + right_margin) 64 | 65 | if isinstance(X, tuple): 66 | assert len(X) == 4 67 | # Create an output numpy ndarray to store the image 68 | if output_pixel_vals: 69 | out_array = numpy.zeros((out_shape[0], out_shape[1], 4), 70 | dtype='uint8') 71 | else: 72 | out_array = numpy.zeros((out_shape[0], 73 | out_shape[1], 4), 74 | dtype=X.dtype) 75 | 76 | # colors default to 0, alpha defaults to 1 (opaque) 77 | if output_pixel_vals: 78 | channel_defaults = [0, 0, 0, 255] 79 | else: 80 | channel_defaults = [0., 0., 0., 1.] 81 | 82 | for i in range(4): 83 | if X[i] is None: 84 | # if channel is None, fill it with zeros of the correct 85 | # dtype 86 | dt = out_array.dtype 87 | if output_pixel_vals: 88 | dt = 'uint8' 89 | out_array[:, :, i] = numpy.zeros( 90 | out_shape, dtype=dt) + channel_defaults[i] 91 | else: 92 | # use a recurrent call to compute the channel and store it 93 | # in the output 94 | out_array[:, :, i] = tile_raster_images( 95 | X[i], img_shape, tile_shape, tile_spacing, 96 | scale_rows_to_unit_interval, output_pixel_vals) 97 | return out_array 98 | 99 | else: 100 | # if we are dealing with only one channel 101 | H, W = img_shape 102 | Hs, Ws = tile_spacing 103 | 104 | # generate a matrix to store the output 105 | dt = X.dtype 106 | if output_pixel_vals: 107 | dt = 'uint8' 108 | out_array = numpy.zeros(out_shape, dtype=dt) 109 | 110 | for tile_row in range(tile_shape[0]): 111 | for tile_col in range(tile_shape[1]): 112 | if tile_row * tile_shape[1] + tile_col < X.shape[0]: 113 | # each row of X is a flattened image: restore its 114 | # 2-D shape (and rescale if requested) before tiling 115 | this_x = X[tile_row * tile_shape[1] + tile_col] 116 | if scale_rows_to_unit_interval: 117 | this_img = scale_to_unit_interval( 118 | this_x.reshape(img_shape)) 119 | else: 120 | this_img = this_x.reshape(img_shape) 121 | c = 255 if output_pixel_vals else 1 122 | 123 | out_array[ 124 | tile_row * (H + Hs): tile_row * (H + Hs) + H, 125 | tile_col * (W + Ws): tile_col * (W + Ws) + W 126 | ] = this_img * c 127 | 128 | return out_array 129 | -------------------------------------------------------------------------------- /cortex/built_ins/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['datasets', 'models'] 2 | -------------------------------------------------------------------------------- /cortex/built_ins/datasets/CelebA.py: -------------------------------------------------------------------------------- 1 | """ 2 | Handler for CelebA.
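A minimal usage sketch for `tile_raster_images` from viz_utils.py above; the random data is illustrative:

    import numpy
    X = numpy.random.rand(100, 28 * 28)  # 100 flattened 28x28 images
    tiles = tile_raster_images(X, img_shape=(28, 28), tile_shape=(10, 10),
                               tile_spacing=(1, 1))
    # tiles is a (289, 289) uint8 array; PIL.Image.fromarray(tiles) renders it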
3 | """ 4 | 5 | import csv 6 | import os 7 | 8 | import numpy as np 9 | import torchvision 10 | from torchvision.transforms import transforms 11 | 12 | from cortex.plugins import DatasetPlugin, register_plugin 13 | from cortex.built_ins.datasets.utils import build_transforms 14 | 15 | 16 | class CelebAPlugin(DatasetPlugin): 17 | sources = ['CelebA'] 18 | 19 | def handle(self, source, copy_to_local=False, normalize=True, 20 | split=None, classification_mode=False, **transform_args): 21 | """ 22 | 23 | Args: 24 | source: 25 | copy_to_local: 26 | normalize: 27 | **transform_args: 28 | 29 | Returns: 30 | 31 | """ 32 | Dataset = self.make_indexing(CelebA) 33 | data_path = self.get_path(source) 34 | 35 | if copy_to_local: 36 | data_path = self.copy_to_local_path(data_path) 37 | 38 | if normalize and isinstance(normalize, bool): 39 | normalize = [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)] 40 | 41 | if classification_mode: 42 | train_transform = transforms.Compose([ 43 | transforms.RandomResizedCrop(64), 44 | transforms.RandomHorizontalFlip(), 45 | transforms.ToTensor(), 46 | transforms.Normalize(*normalize), 47 | ]) 48 | test_transform = transforms.Compose([ 49 | transforms.Resize(64), 50 | transforms.CenterCrop(64), 51 | transforms.ToTensor(), 52 | transforms.Normalize(*normalize), 53 | ]) 54 | else: 55 | train_transform = build_transforms(normalize=normalize, 56 | **transform_args) 57 | test_transform = train_transform 58 | 59 | if split is None: 60 | train_set = Dataset(root=data_path, transform=train_transform, 61 | download=True) 62 | test_set = Dataset(root=data_path, transform=test_transform) 63 | else: 64 | train_set, test_set = self.make_split( 65 | data_path, split, Dataset, train_transform, test_transform) 66 | input_names = ['images', 'labels', 'attributes'] 67 | 68 | dim_c, dim_x, dim_y = train_set[0][0].size() 69 | dim_l = len(train_set.classes) 70 | dim_a = train_set.attributes[0].shape[0] 71 | 72 | dims = dict(x=dim_x, y=dim_y, c=dim_c, labels=dim_l, attributes=dim_a) 73 | self.add_dataset('train', train_set) 74 | self.add_dataset('test', test_set) 75 | self.set_input_names(input_names) 76 | self.set_dims(**dims) 77 | 78 | self.set_scale((-1, 1)) 79 | 80 | def make_split(self, data_path, split, Dataset, train_transform, 81 | test_transform): 82 | train_set = Dataset(root=data_path, transform=train_transform, 83 | download=True, split=split) 84 | test_set = Dataset(root=data_path, transform=test_transform, 85 | split=split - 1) 86 | return train_set, test_set 87 | 88 | 89 | register_plugin(CelebAPlugin) 90 | 91 | 92 | class CelebA(torchvision.datasets.ImageFolder): 93 | 94 | url = ('https://www.dropbox.com/sh/8oqt9vytwxb3s4r/' 95 | 'AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=1') 96 | attr_url = ('https://www.dropbox.com/s/auexdy98c6g7y25/' 97 | 'list_attr_celeba.zip?dl=1') 98 | filename = 'img_align_celeba.zip' 99 | attr_filename = 'list_attr_celeba.zip' 100 | 101 | def __init__( 102 | self, 103 | root, 104 | transform=None, 105 | target_transform=None, 106 | download=False, 107 | split=None): 108 | self.root = os.path.expanduser(root) 109 | self.transform = transform 110 | self.target_transform = target_transform 111 | 112 | if download: 113 | self.download() 114 | 115 | self.attributes = [] 116 | 117 | attr_fpath = os.path.join(root, 'attributes', 'list_attr_celeba.txt') 118 | reader = csv.reader(open(attr_fpath), delimiter=' ', 119 | skipinitialspace=True) 120 | for i, line in enumerate(reader): 121 | if i == 0: 122 | pass 123 | elif i == 1: 124 | self.attribute_names = line 
125 | else: 126 | attr = ((np.array(line[1:]).astype('int8') + 1) / 2).astype('float32') 127 | self.attributes.append(attr) 128 | 129 | super(CelebA, self).__init__(root, transform, target_transform) 130 | if split: 131 | if split > 0: 132 | index = int(split * len(self)) 133 | self.imgs = self.imgs[:index] 134 | self.attributes = self.attributes[:index] 135 | self.samples = self.samples[:index] 136 | else: 137 | index = int(split * len(self)) - 1 138 | self.imgs = self.imgs[index:] 139 | self.attributes = self.attributes[index:] 140 | self.samples = self.samples[index:] 141 | 142 | def __len__(self): 143 | return len(self.imgs) 144 | 145 | def download(self): 146 | """ 147 | 148 | Returns: 149 | 150 | """ 151 | import errno 152 | import zipfile 153 | from six.moves import urllib 154 | 155 | root = self.root 156 | url = self.url 157 | attr_url = self.attr_url 158 | 159 | root = os.path.expanduser(root) 160 | fpath = os.path.join(root, self.filename) 161 | attr_fpath = os.path.join(root, self.attr_filename) 162 | image_dir = os.path.join(root, 'images') 163 | attribute_dir = os.path.join(root, 'attributes') 164 | 165 | def get_data(data_path, zip_path, url): 166 | try: 167 | os.makedirs(data_path) 168 | except OSError as e: 169 | if e.errno == errno.EEXIST: 170 | return 171 | else: 172 | raise 173 | 174 | if os.path.isfile(zip_path): 175 | print('Using downloaded file: {}'.format(zip_path)) 176 | else: 177 | try: 178 | print('Downloading ' + url + ' to ' + zip_path) 179 | urllib.request.urlretrieve(url, zip_path) 180 | except Exception: 181 | if url[:5] == 'https': 182 | url = url.replace('https:', 'http:') 183 | print('Failed download. Trying https -> http instead.' 184 | ' Downloading ' + url + ' to ' + zip_path) 185 | urllib.request.urlretrieve(url, zip_path) 186 | else: 187 | raise 188 | print('Unzipping {}'.format(zip_path)) 189 | 190 | zip_ref = zipfile.ZipFile(zip_path, 'r') 191 | zip_ref.extractall(data_path) 192 | zip_ref.close() 193 | 194 | get_data(image_dir, fpath, url) 195 | get_data(attribute_dir, attr_fpath, attr_url) 196 | 197 | def __getitem__(self, index): 198 | output = super().__getitem__(index) 199 | return output + (self.attributes[index],) 200 | -------------------------------------------------------------------------------- /cortex/built_ins/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['CelebA', 'imagenet', 'torchvision_datasets'] 2 | -------------------------------------------------------------------------------- /cortex/built_ins/datasets/dSprites.py: -------------------------------------------------------------------------------- 1 | '''dShapes dataset 2 | 3 | Taken and adapted from https://github.com/Near32/PYTORCH_VAE 4 | 5 | ''' 6 | 7 | 8 | from os import path 9 | import urllib.request 10 | 11 | from torch.utils.data import Dataset 12 | import numpy as np 13 | from PIL import Image 14 | 15 | from . 
import logger 16 | 17 | 18 | DATASETS = ['dSprites'] 19 | 20 | 21 | class dSprites(Dataset): 22 | _url = ('https://github.com/deepmind/dsprites-dataset/blob/master/' 23 | 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true') 24 | 25 | def __init__(self, root, download=True, transform=None, shuffle=False): 26 | if not root: 27 | raise ValueError('Dataset path not provided') 28 | self.root = root 29 | self.transform = transform 30 | 31 | if download: 32 | if path.isfile(root): 33 | logger.warning('File already in path, ignoring download.') 34 | else: 35 | urllib.request.urlretrieve(self._url, root) 36 | 37 | # Load dataset 38 | dataset_zip = np.load(self.root) 39 | logger.debug('Keys in the dataset:', dataset_zip.keys()) 40 | self.imgs = dataset_zip['imgs'] 41 | self.latents_values = dataset_zip['latents_values'] 42 | self.latents_classes = dataset_zip['latents_classes'] 43 | logger.info('Dataset loaded : OK.') 44 | 45 | if shuffle: 46 | self.idx = np.random.permutation(len(self)) 47 | self.imgs = self.imgs[self.idx] 48 | self.latents_classes = self.latents_classes[self.idx] 49 | self.latents_values = self.latents_values[self.idx] 50 | 51 | def __len__(self): 52 | return len(self.imgs) 53 | 54 | def __getitem__(self, idx): 55 | image = Image.fromarray(self.imgs[idx]) 56 | latent = self.latents_values[idx] 57 | 58 | if self.transform is not None: 59 | image = self.transform(image) 60 | 61 | sample = (image, latent) 62 | 63 | return sample 64 | -------------------------------------------------------------------------------- /cortex/built_ins/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | '''Handler for imagenet datasets. 2 | 3 | ''' 4 | 5 | from os import path 6 | 7 | import torchvision 8 | from torchvision.transforms import transforms 9 | 10 | from cortex.plugins import DatasetPlugin, register_data 11 | from cortex.built_ins.datasets.utils import build_transforms 12 | 13 | 14 | class ImageFolder(DatasetPlugin): 15 | sources = ['tiny-imagenet-200', 'imagenet'] 16 | 17 | def handle(self, source, copy_to_local=False, normalize=True, 18 | tanh_normalization=False, **transform_args): 19 | 20 | Dataset = self.make_indexing(torchvision.datasets.ImageFolder) 21 | data_path = self.get_path(source) 22 | 23 | if isinstance(data_path, dict): 24 | if 'train' not in data_path.keys() and 'valid' in data_path.keys(): 25 | raise ValueError('Imagenet data path must have `train` and ' 26 | '`valid` paths specified') 27 | train_path = data_path['train'] 28 | test_path = data_path['valid'] 29 | else: 30 | train_path = path.join(data_path, 'train') 31 | test_path = path.join(data_path, 'val') 32 | 33 | if copy_to_local: 34 | train_path = self.copy_to_local_path(train_path) 35 | test_path = self.copy_to_local_path(test_path) 36 | 37 | if normalize and isinstance(normalize, bool): 38 | if tanh_normalization: 39 | normalize = transforms.Normalize((0.5, 0.5, 0.5), 40 | (0.5, 0.5, 0.5)) 41 | else: 42 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 43 | std=[0.229, 0.224, 0.225]) 44 | 45 | if source == 'imagenet': 46 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 47 | std=[0.229, 0.224, 0.225]) 48 | train_transform = transforms.Compose([ 49 | transforms.RandomResizedCrop(224), 50 | transforms.RandomHorizontalFlip(), 51 | transforms.ToTensor(), 52 | normalize, 53 | ]) 54 | test_transform = transforms.Compose([ 55 | transforms.Resize(256), 56 | transforms.CenterCrop(224), 57 | transforms.ToTensor(), 58 | normalize, 59 | ]) 60 | 
else: 61 | train_transform = build_transforms( 62 | normalize=normalize, **transform_args) 63 | test_transform = build_transforms(normalize=normalize) 64 | train_set = Dataset(root=train_path, transform=train_transform) 65 | test_set = Dataset(root=test_path, transform=test_transform) 66 | input_names = ['images', 'targets', 'index'] 67 | 68 | dim_c, dim_x, dim_y = train_set[0][0].size() 69 | 70 | print('Computing min / max...') 71 | 72 | img_min = 1000 73 | img_max = -1000 74 | for i in range(1000): 75 | img = train_set[i][0] 76 | img_min = min(img.min(), img_min) 77 | img_max = max(img.max(), img_max) 78 | 79 | dim_l = len(train_set.classes) 80 | 81 | dims = dict(x=dim_x, y=dim_y, c=dim_c, labels=dim_l) 82 | 83 | self.add_dataset('train', train_set) 84 | self.add_dataset('test', test_set) 85 | self.set_input_names(input_names) 86 | self.set_dims(**dims) 87 | 88 | self.set_scale((img_min, img_max)) 89 | print('Finished loading dataset') 90 | 91 | 92 | register_data(ImageFolder) 93 | -------------------------------------------------------------------------------- /cortex/built_ins/datasets/nii_dataload.py: -------------------------------------------------------------------------------- 1 | '''Module for handling neuroimaging data 2 | We build an "ImageFolder" object and we can iterate/index through 3 | it. The class is initialized with a folder location, a loader (the only one 4 | we have now is for nii files), and (optionally) a list of regex patterns. 5 | 6 | The user can also provide a 3D binary mask (same size as data) to vectorize 7 | the space/voxel dimension. Can handle 3D and 3D+time (4D) datasets So, it can 8 | be built one of two ways: 9 | 1: a path to one directory with many images, and the classes are based on 10 | regex patterns. 11 | example 1a: "/home/user/some_data_path" has files *_H_*.nii and *_S_*.nii 12 | files 13 | patterned_images = ImageFolder("/home/user/some_data_path", 14 | patterns=['*_H_*','*_S_*'] , loader=nii_loader) 15 | example 1b: "/home/user/some_data_path" has files *_H_*.nii and *_S_*.nii 16 | files, and user specifies a mask to vectorize space 17 | patterned_images_mask = ImageFolder("/home/user/some_data_path", 18 | patterns=['*_H_*','*_S_*'] , loader=nii_loader, 19 | mask="/home/user/maskImage.nii") 20 | 21 | 2: a path to a top level directory with sub directories denoting the classes. 
22 | example 2a: "/home/user/some_data_path" has subfolders 0 and 1 with nifti 23 | files corresponding to class 0 and class 1 respectively 24 | foldered_images = ImageFolder("/home/user/some_data_path",loader=nii_loader) 25 | example 2b: Same as above but with a mask 26 | foldered_images = ImageFolder("/home/user/some_data_path",loader=nii_loader, 27 | mask="/home/user/maskImage.nii") 28 | 29 | 30 | The final output (when we call __getitem__) is a tuple of: (image,label) 31 | ''' 32 | 33 | import torch.utils.data as data 34 | 35 | import os 36 | import os.path 37 | import numpy as np 38 | import nibabel as nib 39 | from glob import glob 40 | 41 | IMG_EXTENSIONS = ['.nii', '.nii.gz', '.img', '.hdr', '.img.gz', '.hdr.gz'] 42 | 43 | 44 | def make_dataset(dir, patterns=None): 45 | """ 46 | 47 | Args: 48 | dir: 49 | patterns: 50 | 51 | Returns: 52 | 53 | """ 54 | images = [] 55 | 56 | dir = os.path.expanduser(dir) 57 | 58 | file_list = [] 59 | 60 | all_items = [os.path.join(dir, i) for i in os.listdir(dir)] 61 | directories = [os.path.join(dir, d) for d in all_items if os.path.isdir(d)] 62 | if patterns is not None: 63 | for i, pattern in enumerate(patterns): 64 | files = [(f, i) for f in glob(os.path.join(dir, pattern))] 65 | file_list.append(files) 66 | else: 67 | file_list = [[(os.path.join(p, f), i) 68 | for f in os.listdir(p) 69 | if os.path.isfile(os.path.join(p, f))] 70 | for i, p in enumerate(directories)] 71 | 72 | for i, target in enumerate(file_list): 73 | for item in target: 74 | images.append(item) 75 | 76 | return images 77 | 78 | 79 | def nii_loader(path): 80 | """ 81 | 82 | Args: 83 | path: 84 | 85 | Returns: 86 | 87 | """ 88 | img = nib.load(path) 89 | data = img.get_data() 90 | # hdr = img.header 91 | 92 | return data 93 | 94 | 95 | class ImageFolder(data.Dataset): 96 | ''' 97 | Args: 98 | root (string): Root directory path. 99 | patterns (list): list of regex patterns 100 | loader (callable, optional): A function to load an image given its 101 | path. 
102 | 103 | Attributes: 104 | imgs (list): List of (image path, class_index) tuples 105 | ''' 106 | 107 | def __init__(self, root, loader=nii_loader, patterns=None, mask=None): 108 | imgs = make_dataset(root, patterns) 109 | 110 | if len(imgs) == 0: 111 | raise ( 112 | RuntimeError( 113 | "Found 0 images in subfolders of: " + 114 | root + 115 | "\n" 116 | "Supported image extensions are: " + 117 | ",".join(IMG_EXTENSIONS))) 118 | 119 | self.root = root 120 | self.imgs = imgs 121 | 122 | self.loader = loader 123 | self.mask = mask 124 | 125 | def maskData(self, data): 126 | """ 127 | 128 | Args: 129 | data: 130 | 131 | Returns: 132 | 133 | """ 134 | 135 | msk = nib.load(self.mask) 136 | mskD = msk.get_data() 137 | if not np.all(np.bitwise_or(mskD == 0, mskD == 1)): 138 | raise ValueError("Mask has incorrect values.") 139 | # nVox = np.sum(mskD.flatten()) 140 | if data.shape[0:3] != mskD.shape: 141 | raise ValueError((data.shape, mskD.shape)) 142 | 143 | msk_f = mskD.flatten() 144 | msk_idx = np.where(msk_f == 1)[0] 145 | 146 | if len(data.shape) == 3: 147 | data_masked = data.flatten()[msk_idx] 148 | 149 | if len(data.shape) == 4: 150 | data = np.transpose(data, (3, 0, 1, 2)) 151 | data_masked = np.zeros((data.shape[0], int(mskD.sum()))) 152 | for i, x in enumerate(data): 153 | data_masked[i] = x.flatten()[msk_idx] 154 | 155 | img = data_masked 156 | 157 | return np.array(img) 158 | 159 | ''' 160 | Gives us a tuple from the array at (index) of: (image, label) 161 | ''' 162 | 163 | def __getitem__(self, index): 164 | """ 165 | Args: 166 | index (int): Index 167 | 168 | Returns: 169 | tuple: (image, target) where target is class_index of the target 170 | class. 171 | """ 172 | path, label = self.imgs[index] 173 | img = self.loader(path) 174 | if self.mask: 175 | img = self.maskData(img) 176 | 177 | return np.array(img), label 178 | 179 | def __len__(self): 180 | return len(self.imgs) 181 | -------------------------------------------------------------------------------- /cortex/built_ins/datasets/torchvision_datasets.py: -------------------------------------------------------------------------------- 1 | '''Entrypoint for torchvision datasets. 
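A minimal sketch of wiring the nii ImageFolder above into a PyTorch DataLoader; the directory and patterns are hypothetical, following example 1a in its docstring:

    import torch.utils.data as data
    dataset = ImageFolder('/data/nii_scans', loader=nii_loader,
                          patterns=['*_H_*', '*_S_*'])
    loader = data.DataLoader(dataset, batch_size=8, shuffle=True)
    volumes, labels = next(iter(loader))  # batched volumes and class indices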
2 | 3 | ''' 4 | 5 | import os 6 | 7 | import numpy as np 8 | import torchvision 9 | from torchvision.transforms import transforms 10 | 11 | from cortex.plugins import DatasetPlugin, register_plugin 12 | from .utils import build_transforms 13 | 14 | 15 | class TorchvisionDatasetPlugin(DatasetPlugin): 16 | sources = [ 17 | 'CIFAR10', 'CIFAR100', 'CocoCaptions', 'CocoDetection', 'FakeData', 18 | 'FashionMNIST', 'ImageFolder', 'LSUN', 'LSUNClass', 'MNIST', 19 | 'PhotoTour', 'SEMEION', 'STL10', 'SVHN' 20 | ] 21 | 22 | def _handle_LSUN(self, Dataset, data_path, transform=None, **kwargs): 23 | train_set = Dataset( 24 | data_path, classes=['bedroom_train'], transform=transform) 25 | test_set = Dataset( 26 | data_path, classes=['bedroom_test'], transform=transform) 27 | return train_set, test_set 28 | 29 | def _handle_SVHN(self, Dataset, data_path, transform=None, **kwargs): 30 | train_set = Dataset( 31 | data_path, split='train', transform=transform, download=True) 32 | test_set = Dataset( 33 | data_path, split='test', transform=transform, download=True) 34 | return train_set, test_set 35 | 36 | def _handle_STL(self, Dataset, data_path, transform=None, 37 | labeled_only=False, stl_center_crop=False, 38 | stl_resize_only=False, stl_no_resize=False): 39 | normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 40 | 41 | if stl_no_resize: 42 | train_transform = transforms.Compose([ 43 | transforms.RandomHorizontalFlip(), 44 | transforms.ToTensor(), 45 | normalize, 46 | ]) 47 | test_transform = transforms.Compose([ 48 | transforms.ToTensor(), 49 | normalize, 50 | ]) 51 | else: 52 | if stl_center_crop: 53 | tr_trans = transforms.CenterCrop(64) 54 | te_trans = transforms.CenterCrop(64) 55 | elif stl_resize_only: 56 | tr_trans = transforms.Resize(64) 57 | te_trans = transforms.Resize(64) 58 | elif stl_no_resize: 59 | pass 60 | else: 61 | tr_trans = transforms.RandomResizedCrop(64) 62 | te_trans = transforms.Resize(64) 63 | 64 | train_transform = transforms.Compose([ 65 | tr_trans, 66 | transforms.RandomHorizontalFlip(), 67 | transforms.ToTensor(), 68 | normalize, 69 | ]) 70 | test_transform = transforms.Compose([ 71 | te_trans, 72 | transforms.ToTensor(), 73 | normalize, 74 | ]) 75 | if labeled_only: 76 | split = 'train' 77 | else: 78 | split = 'train+unlabeled' 79 | train_set = Dataset( 80 | data_path, split=split, transform=train_transform, download=True) 81 | test_set = Dataset( 82 | data_path, split='test', transform=test_transform, download=True) 83 | return train_set, test_set 84 | 85 | def _handle(self, Dataset, data_path, transform=None, **kwargs): 86 | train_set = Dataset( 87 | data_path, train=True, transform=transform, download=True) 88 | test_set = Dataset( 89 | data_path, train=False, transform=transform, download=True) 90 | return train_set, test_set 91 | 92 | def handle(self, source, copy_to_local=False, normalize=True, 93 | train_samples=None, test_samples=None, 94 | labeled_only=False, stl_center_crop=False, 95 | stl_resize_only=False, stl_no_resize=False, **transform_args): 96 | 97 | Dataset = getattr(torchvision.datasets, source) 98 | Dataset = self.make_indexing(Dataset) 99 | 100 | torchvision_path = self.get_path('torchvision') 101 | if not os.path.isdir(torchvision_path): 102 | os.mkdir(torchvision_path) 103 | 104 | data_path = os.path.join(torchvision_path, source) 105 | 106 | if copy_to_local: 107 | data_path = self.copy_to_local_path(data_path) 108 | 109 | if normalize and isinstance(normalize, bool): 110 | if source in [ 111 | 'MNIST', 'dSprites', 'Fashion-MNIST', 
112 |             ]:
113 |                 normalize = [(0.5,), (0.5,)]
114 |                 scale = (0, 1)
115 |             else:
116 |                 normalize = [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]
117 |                 scale = (-1, 1)
118 |
119 |         else:
120 |             scale = None
121 |
122 |         transform = build_transforms(normalize=normalize, **transform_args)
123 |
124 |         if source == 'LSUN':
125 |             handler = self._handle_LSUN
126 |         elif source == 'SVHN':
127 |             handler = self._handle_SVHN
128 |         elif source == 'STL10':
129 |             handler = self._handle_STL
130 |         else:
131 |             handler = self._handle
132 |
133 |         train_set, test_set = handler(Dataset, data_path, transform=transform,
134 |                                       labeled_only=labeled_only,
135 |                                       stl_center_crop=stl_center_crop,
136 |                                       stl_resize_only=stl_resize_only,
137 |                                       stl_no_resize=stl_no_resize)
138 |         if train_samples is not None:
139 |             train_set.train_data = train_set.train_data[:train_samples]
140 |             train_set.train_labels = train_set.train_labels[:train_samples]
141 |         if test_samples is not None:
142 |             test_set.test_data = test_set.test_data[:test_samples]
143 |             test_set.test_labels = test_set.test_labels[:test_samples]
144 |
145 |         if source in ('SVHN', 'STL10'):
146 |             dim_c, dim_x, dim_y = train_set[0][0].size()
147 |             uniques = np.unique(train_set.labels).tolist()
148 |             try:
149 |                 uniques.remove(-1)  # -1 marks unlabeled samples
150 |             except ValueError:
151 |                 pass
152 |             dim_l = len(uniques)
153 |         else:
154 |             dim_c, dim_x, dim_y = train_set[0][0].size()
155 |
156 |             labels = train_set.train_labels
157 |             if not isinstance(labels, list):
158 |                 labels = labels.numpy()
159 |             dim_l = len(np.unique(labels))
160 |
161 |         dims = dict(x=dim_x, y=dim_y, c=dim_c, labels=dim_l)
162 |         input_names = ['images', 'targets', 'index']
163 |
164 |         self.add_dataset('train', train_set)
165 |         self.add_dataset('test', test_set)
166 |         self.set_input_names(input_names)
167 |         self.set_dims(**dims)
168 |
169 |         if scale is not None:
170 |             self.set_scale(scale)
171 |
172 |
173 | register_plugin(TorchvisionDatasetPlugin)
174 |
-------------------------------------------------------------------------------- /cortex/built_ins/datasets/utils.py: --------------------------------------------------------------------------------
1 | """
2 | Extra functions for built-in datasets
3 | """
4 |
5 | import torchvision.transforms as transforms
6 |
7 |
8 | def build_transforms(normalize=True, center_crop=None, image_size=None,
9 |                      random_crop=None, flip=None, random_resize_crop=None,
10 |                      random_sized_crop=None, use_sobel=False):
11 |     """Composes a torchvision transform pipeline from simple options.
12 |
13 |     Args:
14 |         normalize: Normalization stats (mean, std) or a Normalize instance.
15 |         center_crop: Size for a center crop.
16 |         image_size: Output image size, an int or (height, width) tuple.
17 |         random_crop: Size for a random crop.
18 |         flip: If true, apply a random horizontal flip.
19 |         random_resize_crop: Size for a RandomResizedCrop.
20 |         random_sized_crop: Size for the deprecated RandomSizedCrop.
21 |         use_sobel: Currently unused.
22 |
23 |     Returns:
24 |         A transforms.Compose applying the requested steps in order.
25 |     """
26 |     transform_ = []
27 |
28 |     if random_resize_crop:
29 |         transform_.append(transforms.RandomResizedCrop(random_resize_crop))
30 |     elif random_crop:
31 |         transform_.append(transforms.RandomCrop(random_crop))
32 |     elif center_crop:
33 |         transform_.append(transforms.CenterCrop(center_crop))
34 |     elif random_sized_crop:
35 |         transform_.append(transforms.RandomSizedCrop(random_sized_crop))
36 |
37 |     if image_size:
38 |         if isinstance(image_size, int):
39 |             image_size = (image_size, image_size)
40 |         transform_.append(transforms.Resize(image_size))
41 |
42 |     if flip:
43 |         transform_.append(transforms.RandomHorizontalFlip())
44 |
45 |     transform_.append(transforms.ToTensor())
46 |
47 |     if normalize:
48 |         if isinstance(normalize, transforms.Normalize):
49 |             transform_.append(normalize)
50 |         else:
51 |             transform_.append(transforms.Normalize(*normalize))
52 |     transform = transforms.Compose(transform_)
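    # transform_ now holds, in order: optional crop, optional resize,
    # optional flip, ToTensor, and optional Normalize.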
53 | return transform 54 | -------------------------------------------------------------------------------- /cortex/built_ins/models/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['adversarial_autoencoder', 'ae', 'ali', 'classifier', 'gan', 2 | 'mine', 'vae'] 3 | -------------------------------------------------------------------------------- /cortex/built_ins/models/adversarial_autoencoder.py: -------------------------------------------------------------------------------- 1 | '''Adversarial autoencoder 2 | 3 | ''' 4 | 5 | from cortex.plugins import ModelPlugin, register_model 6 | from cortex.built_ins.models.gan import (SimpleDiscriminator, GradientPenalty, 7 | generator_loss) 8 | from cortex.built_ins.models.vae import ImageDecoder, ImageEncoder 9 | 10 | 11 | class AdversarialAutoencoder(ModelPlugin): 12 | '''Adversarial Autoencoder 13 | 14 | Autoencoder with a GAN loss on the latent space. 15 | 16 | ''' 17 | 18 | defaults = dict( 19 | data=dict(batch_size=dict(train=64, test=640), 20 | inputs=dict(inputs='images')), 21 | optimizer=dict(optimizer='Adam', learning_rate=1e-4), 22 | train=dict(epochs=500, archive_every=10) 23 | ) 24 | 25 | def __init__(self): 26 | super().__init__() 27 | encoder_contract = dict(kwargs=dict(dim_out='dim_z')) 28 | decoder_contract = dict(kwargs=dict(dim_in='dim_z')) 29 | disc_contract = dict(kwargs=dict(dim_in='dim_z')) 30 | penalty_contract = dict(nets=dict(network='discriminator')) 31 | 32 | self.encoder = ImageEncoder(contract=encoder_contract) 33 | self.decoder = ImageDecoder(contract=decoder_contract) 34 | self.discriminator = SimpleDiscriminator(contract=disc_contract) 35 | self.penalty = GradientPenalty(contract=penalty_contract) 36 | 37 | def build(self, noise_type='normal', dim_z=64): 38 | ''' 39 | 40 | Args: 41 | noise_type: Prior noise distribution. 42 | dim_z: Dimensionality of latent space. 43 | 44 | ''' 45 | 46 | self.add_noise('Z', dist=noise_type, size=dim_z) 47 | self.encoder.build() 48 | self.decoder.build() 49 | self.discriminator.build() 50 | 51 | def routine(self, inputs, Z, encoder_loss_type='non-saturating', 52 | measure=None, beta=1.0): 53 | ''' 54 | 55 | Args: 56 | encoder_loss_type: Adversarial loss type for the encoder. 57 | beta: Amount of adversarial loss for the encoder. 58 | 59 | ''' 60 | 61 | Z_Q = self.encoder.encode(inputs) 62 | self.decoder.routine(inputs, Z_Q) 63 | 64 | E_pos, E_neg, P_samples, Q_samples = self.discriminator.score( 65 | Z, Z_Q, measure) 66 | 67 | adversarial_loss = generator_loss( 68 | Q_samples, measure, loss_type=encoder_loss_type) 69 | 70 | self.losses.encoder = self.losses.decoder + beta * adversarial_loss 71 | self.results.adversarial_loss = adversarial_loss.item() 72 | 73 | def train_step(self, n_discriminator_updates=1): 74 | ''' 75 | 76 | Args: 77 | n_discriminator_updates: Number of discriminator updates per step. 
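                Discriminator and penalty updates run this many times before
                the single encoder/decoder update below.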
78 |
79 |         '''
80 |         for _ in range(n_discriminator_updates):
81 |             self.data.next()
82 |             inputs, Z = self.inputs('inputs', 'Z')
83 |             Z_Q = self.encoder.encode(inputs)
84 |             self.discriminator.routine(Z, Z_Q)
85 |             self.optimizer_step()
86 |             self.penalty.routine(Z)
87 |             self.optimizer_step()
88 |
89 |         self.routine(auto_input=True)
90 |         self.optimizer_step()
91 |
92 |     def eval_step(self):
93 |         self.data.next()
94 |         inputs, Z = self.inputs('inputs', 'Z')
95 |         Z_Q = self.encoder.encode(inputs)
96 |         self.discriminator.routine(Z, Z_Q)
97 |         self.penalty.routine(Z)
98 |
99 |         self.routine(auto_input=True)
100 |
101 |     def visualize(self, inputs, Z, targets):
102 |         self.decoder.visualize(Z)
103 |         self.encoder.visualize(inputs, targets)
104 |
105 |         Z_Q = self.encoder.encode(inputs)
106 |         R = self.decoder.decode(Z_Q)
107 |         self.add_image(inputs, name='ground truth')
108 |         self.add_image(R, name='reconstructed')
109 |
110 |
111 | register_model(AdversarialAutoencoder)
112 |
-------------------------------------------------------------------------------- /cortex/built_ins/models/ae.py: --------------------------------------------------------------------------------
1 | '''Module for autoencoder model.
2 |
3 | '''
4 |
5 | import torch.nn.functional as F
6 |
7 | from cortex.plugins import ModelPlugin, register_plugin
8 | from cortex.built_ins.models.image_coders import ImageEncoder, ImageDecoder
9 | from cortex.built_ins.networks.ae_network import AENetwork
10 |
11 |
12 | class Autoencoder(ModelPlugin):
13 |     '''Simple autoencoder model.
14 |
15 |     Trains a noiseless autoencoder of image data.
16 |
17 |     '''
18 |     defaults = dict(
19 |         data=dict(
20 |             batch_size=dict(train=64, test=64), inputs=dict(inputs='images')),
21 |         optimizer=dict(optimizer='Adam', learning_rate=1e-4),
22 |         train=dict(save_on_lowest='losses.ae'))
23 |
24 |     def __init__(self, Encoder=None, Decoder=None):
25 |         super().__init__()
26 |         if Encoder is None:
27 |             Encoder = ImageEncoder
28 |         if Decoder is None:
29 |             Decoder = ImageDecoder
30 |         self.encoder = Encoder()
31 |         self.decoder = Decoder()
32 |
33 |     def build(self, dim_z=64):
34 |         self.encoder.build(dim_out=dim_z)
35 |         self.decoder.build(dim_in=dim_z)
36 |
37 |         encoder = self.nets.encoder
38 |         decoder = self.nets.decoder
39 |         ae = AENetwork(encoder, decoder)
40 |         self.nets.ae = ae
41 |
42 |     def routine(self, inputs, targets, ae_criterion=F.mse_loss):
43 |         '''
44 |
45 |         Args:
46 |             ae_criterion: Criterion for the autoencoder.
47 |
48 |         '''
49 |         ae = self.nets.ae
50 |         outputs = ae(inputs)
51 |         r_loss = ae_criterion(
52 |             outputs, inputs, size_average=False) / inputs.size(0)
53 |         self.losses.ae = r_loss
54 |
55 |     def visualize(self, inputs, targets):
56 |         ae = self.nets.ae
57 |         outputs = ae(inputs)
58 |         self.add_image(outputs, name='reconstruction')
59 |         self.add_image(inputs, name='ground truth')
60 |
61 |
62 | register_plugin(Autoencoder)
63 |
-------------------------------------------------------------------------------- /cortex/built_ins/models/classifier.py: --------------------------------------------------------------------------------
1 | '''Simple classifier model
2 |
3 | '''
4 |
5 |
6 | from cortex.plugins import (register_plugin, ModelPlugin)
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from cortex.built_ins.networks.fully_connected import FullyConnectedNet
12 | from .utils import update_encoder_args
13 |
14 |
15 | class SimpleClassifier(ModelPlugin):
16 |     '''Build a simple feed-forward classifier.
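    Targets equal to -1 are treated as unlabeled and are masked out of the
    loss (see `routine` below).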
17 |
18 |     '''
19 |     defaults = dict(
20 |         data=dict(batch_size=128),
21 |         optimizer=dict(optimizer='Adam', learning_rate=1e-4),
22 |         train=dict(epochs=200, save_on_best='losses.classifier'),
23 |         classifier_args=dict(dropout=0.2))
24 |
25 |     def build(self, dim_in: int = None, classifier_args=dict(dim_h=[200, 200])):
26 |         '''
27 |
28 |         Args:
29 |             dim_in (int): Input size
30 |             classifier_args: Extra arguments for building the classifier
31 |
32 |         '''
33 |         dim_l = self.get_dims('labels')
34 |         classifier = FullyConnectedNet(dim_in, dim_out=dim_l, **classifier_args)
35 |         self.nets.classifier = classifier
36 |
37 |     def routine(self, inputs, targets,
38 |                 criterion=nn.CrossEntropyLoss(reduce=False)):
39 |         '''
40 |
41 |         Args:
42 |             criterion: Classifier criterion.
43 |
44 |         '''
45 |         classifier = self.nets.classifier
46 |
47 |         outputs = classifier(inputs)
48 |         predicted = torch.max(F.log_softmax(outputs, dim=1).data, 1)[1]
49 |
50 |         unlabeled = targets.eq(-1).long()
51 |         losses = criterion(outputs, (1 - unlabeled) * targets)
52 |         labeled = 1. - unlabeled.float()
53 |         loss = (losses * labeled).sum() / labeled.sum()
54 |
55 |         if labeled.sum() > 0:
56 |             correct = 100. * (labeled * predicted.eq(
57 |                 targets.data).float()).cpu().sum() / labeled.cpu().sum()
58 |             self.results.accuracy = correct
59 |             self.losses.classifier = loss
60 |
61 |         self.results.perc_labeled = labeled.mean()
62 |
63 |     def predict(self, inputs):
64 |         classifier = self.nets.classifier
65 |
66 |         outputs = classifier(inputs)
67 |         predicted = torch.max(F.log_softmax(outputs, dim=1).data, 1)[1]
68 |
69 |         return predicted
70 |
71 |     def visualize(self, images, inputs, targets):
72 |         predicted = self.predict(inputs)
73 |         self.add_image(images.data, labels=(targets.data, predicted.data),
74 |                        name='gt_pred')
75 |
76 |
77 | class SimpleAttributeClassifier(SimpleClassifier):
78 |     '''Build a simple feed-forward multi-attribute (binary) classifier.
79 |
80 |     '''
81 |
82 |     defaults = dict(
83 |         data=dict(batch_size=128),
84 |         optimizer=dict(optimizer='Adam', learning_rate=1e-4),
85 |         train=dict(epochs=200, save_on_best='losses.classifier'))
86 |
87 |     def build(self, dim_in: int = None, classifier_args=dict(dim_h=[200, 200])):
88 |         '''
89 |
90 |         Args:
91 |             dim_in (int): Input size
94 |             classifier_args: Extra arguments for building the classifier
95 |
96 |         '''
97 |         dim_a = self.get_dims('attributes')
98 |         classifier = FullyConnectedNet(dim_in, dim_out=dim_a, **classifier_args)
99 |         self.nets.classifier = classifier
100 |
101 |     def routine(self, inputs, attributes):
102 |         classifier = self.nets.classifier
103 |         outputs = classifier(inputs, nonlinearity='sigmoid')
104 |         loss = torch.nn.BCELoss()(outputs, attributes)
105 |
106 |         predicted = (outputs.data >= 0.5).float()
107 |         correct = 100. * predicted.eq(attributes.data).cpu().sum(0) / attributes.size(0)
108 |
109 |         self.losses.classifier = loss
110 |         self.results.accuracy = dict(mean=correct.float().mean(),
111 |                                      max=correct.max(),
112 |                                      min=correct.min())
113 |
114 |     def predict(self, inputs):
115 |         classifier = self.nets.classifier
116 |         outputs = classifier(inputs)
117 |         predicted = (F.sigmoid(outputs).data >= 0.5).float()
118 |
119 |         return predicted
120 |
121 |     def visualize(self, images, inputs):
122 |         self.add_image(images.data, name='gt_pred')
123 |
124 |
125 | class ImageClassification(SimpleClassifier):
126 |     '''Basic image classifier.
127 |
128 |     Classifies images using standard convnets.
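    The underlying network is chosen with `classifier_type`: 'convnet',
    'mnist', 'resnet', 'tv.<model_name>', or 'tv-wrapper.<model_name>'.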
129 | 130 | ''' 131 | 132 | defaults = dict( 133 | data=dict(batch_size=128, inputs=dict(inputs='images')), 134 | optimizer=dict(optimizer='Adam', learning_rate=1e-3), 135 | train=dict(epochs=200, save_on_best='losses.classifier')) 136 | 137 | def build(self, classifier_type='convnet', 138 | classifier_args=dict(dropout=0.2), Encoder=None): 139 | '''Builds a simple image classifier. 140 | 141 | Args: 142 | classifier_type (str): Network type for the classifier. 143 | classifier_args: Classifier arguments. Can include dropout, 144 | batch_norm, layer_norm, etc. 145 | 146 | ''' 147 | classifier_args = classifier_args or {} 148 | 149 | shape = self.get_dims('x', 'y', 'c') 150 | dim_l = self.get_dims('labels') 151 | 152 | Encoder_, args = update_encoder_args( 153 | shape, model_type=classifier_type, encoder_args=classifier_args) 154 | Encoder = Encoder or Encoder_ 155 | 156 | args.update(**classifier_args) 157 | 158 | classifier = Encoder(shape, dim_out=dim_l, **args) 159 | self.nets.classifier = classifier 160 | 161 | 162 | class ImageAttributeClassification(SimpleAttributeClassifier): 163 | '''Basic image classifier. 164 | 165 | Classifies images using standard convnets. 166 | 167 | ''' 168 | 169 | defaults = dict( 170 | data=dict(batch_size=128, inputs=dict(inputs='images')), 171 | optimizer=dict(optimizer='Adam', learning_rate=1e-3), 172 | train=dict(epochs=200, save_on_best='losses.classifier')) 173 | 174 | def build(self, classifier_type='convnet', 175 | classifier_args=dict(dropout=0.2), Encoder=None): 176 | '''Builds a simple image classifier. 177 | 178 | Args: 179 | classifier_type (str): Network type for the classifier. 180 | classifier_args: Classifier arguments. Can include dropout, 181 | batch_norm, layer_norm, etc. 182 | 183 | ''' 184 | 185 | classifier_args = classifier_args or {} 186 | 187 | shape = self.get_dims('x', 'y', 'c') 188 | dim_a = self.get_dims('attributes') 189 | 190 | Encoder_, args = update_encoder_args( 191 | shape, model_type=classifier_type, encoder_args=classifier_args) 192 | Encoder = Encoder or Encoder_ 193 | 194 | args.update(**classifier_args) 195 | 196 | classifier = Encoder(shape, dim_out=dim_a, **args) 197 | self.nets.classifier = classifier 198 | 199 | 200 | register_plugin(ImageClassification) 201 | register_plugin(ImageAttributeClassification) 202 | -------------------------------------------------------------------------------- /cortex/built_ins/models/image_coders.py: -------------------------------------------------------------------------------- 1 | from cortex.plugins import ModelPlugin 2 | from cortex.built_ins.models.utils import update_encoder_args, update_decoder_args, ms_ssim 3 | import torch.nn.functional as F 4 | 5 | 6 | class ImageEncoder(ModelPlugin): 7 | 8 | def build(self, 9 | dim_out=None, 10 | encoder_type: str = 'convnet', 11 | encoder_args=dict(fully_connected_layers=1028), 12 | Encoder=None): 13 | x_shape = self.get_dims('x', 'y', 'c') 14 | Encoder_, encoder_args = update_encoder_args( 15 | x_shape, model_type=encoder_type, encoder_args=encoder_args) 16 | Encoder = Encoder or Encoder_ 17 | encoder = Encoder(x_shape, dim_out=dim_out, **encoder_args) 18 | self.nets.encoder = encoder 19 | 20 | def encode(self, inputs, **kwargs): 21 | return self.nets.encoder(inputs, **kwargs) 22 | 23 | def visualize(self, inputs, targets): 24 | Z = self.encode(inputs) 25 | if targets is not None: 26 | targets = targets.data 27 | self.add_scatter(Z.data, labels=targets, name='latent values') 28 | 29 | 30 | class ImageDecoder(ModelPlugin): 31 | 32 | 
def build(self,
33 |               dim_in=None,
34 |               decoder_type: str = 'convnet',
35 |               decoder_args=dict(output_nonlinearity='tanh'),
36 |               Decoder=None):
37 |         x_shape = self.get_dims('x', 'y', 'c')
38 |         Decoder_, decoder_args = update_decoder_args(
39 |             x_shape, model_type=decoder_type, decoder_args=decoder_args)
40 |         Decoder = Decoder or Decoder_
41 |         decoder = Decoder(x_shape, dim_in=dim_in, **decoder_args)
42 |         self.nets.decoder = decoder
43 |
44 |     def routine(self, inputs, Z, decoder_crit=F.mse_loss):
45 |         X = self.decode(Z)
46 |         self.losses.decoder = decoder_crit(X, inputs, size_average=False) / inputs.size(0)
47 |         msssim = ms_ssim(inputs, X)
48 |         self.results.ms_ssim = msssim.item()
49 |
50 |     def decode(self, Z):
51 |         return self.nets.decoder(Z)
52 |
53 |     def visualize(self, Z):
54 |         gen = self.decode(Z)
55 |         self.add_image(gen, name='generated')
56 |
-------------------------------------------------------------------------------- /cortex/built_ins/models/mine.py: --------------------------------------------------------------------------------
1 | '''Mutual information neural estimation
2 |
3 | '''
4 |
5 |
6 | from cortex.plugins import register_model
7 | from cortex.built_ins.models.ali import ALIDiscriminator
8 | from cortex.built_ins.models.gan import GAN, GradientPenalty
9 | from cortex.built_ins.models.vae import ImageEncoder
10 |
11 |
12 | class MINE(ALIDiscriminator):
13 |     '''Mutual information neural estimation (MINE).
14 |
15 |     Estimates mutual information of two random variables.
16 |
17 |     '''
18 |
19 |     def __init__(self):
20 |         super().__init__(contract=dict(nets=dict(discriminator='mine')))
21 |         contract = dict(nets=dict(network='mine'),
22 |                         kwargs=dict(penalty_type='mine_penalty_type',
23 |                                     penalty_amount='mine_penalty_amount'))
24 |         self.penalty = GradientPenalty(contract=contract)
25 |
26 |     def routine(self, X, X_m, Z, Z_m, mine_measure='JSD'):
27 |         '''
28 |
29 |         Args:
30 |             mine_measure: MINE measure.
31 |                 {GAN, JSD, KL, RKL (reverse KL), X2 (Chi^2), H2
32 |                 (squared Hellinger), DV (Donsker-Varadhan KL), W1 (IPM)}
33 |
34 |         '''
35 |
36 |         super().routine(X, X_m, Z, Z_m, measure=mine_measure)
37 |
38 |
39 | class GAN_MINE(GAN):
40 |     '''GAN + MINE.
41 |
42 |     A generative adversarial network trained with MI maximization.
43 |
44 |     '''
45 |     def __init__(self):
46 |         super().__init__()
47 |         self.mine = MINE()
48 |
49 |         encoder_contract = dict(nets=dict(encoder='x_encoder'),
50 |                                 kwargs=dict(dim_out='dim_int'))
51 |         self.encoder = ImageEncoder(contract=encoder_contract)
52 |
53 |     def build(self, noise_type='normal', dim_z=64):
54 |         super().build(noise_type=noise_type, dim_z=dim_z)
55 |         self.encoder.build()
56 |         self.mine.build()
57 |
58 |     def routine(self, Z, Z_m, mine_measure=None, beta=1.0):
59 |         '''
60 |
61 |         Args:
62 |             beta: Factor for mutual information maximization for generator.
63 |
64 |         '''
65 |         self.generator.routine(Z)
66 |         X = self.generator.generate(Z)
67 |         E_pos, E_neg, _, _ = self.mine.score(X, X, Z, Z_m, mine_measure)
68 |
69 |         self.losses.generator += (E_neg - E_pos)
70 |
71 |     def train_step(self, mine_updates=1, discriminator_updates=1):
72 |         '''
73 |
74 |         Args:
75 |             mine_updates: Number of MINE updates per step.
76 |             discriminator_updates: Number of discriminator updates per step.
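                These run first, followed by the MINE updates and then one
                generator update.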
77 |
78 |         '''
79 |
80 |         for _ in range(discriminator_updates):
81 |             self.data.next()
82 |             inputs, Z = self.inputs('inputs', 'Z')
83 |             generated = self.generator.generate(Z)
84 |             self.discriminator.routine(inputs, generated.detach())
85 |             self.optimizer_step()
86 |             self.penalty.train_step()
87 |
88 |         Z_P = Z
89 |
90 |         for _ in range(mine_updates):
91 |             self.data.next()
92 |             Z = self.inputs('Z')
93 |             generated = self.generator.generate(Z)
94 |             self.mine.routine(generated, Z, generated, Z_P)
95 |             self.optimizer_step()
96 |
97 |         self.routine(Z, Z_P)
98 |         self.optimizer_step()
99 |
100 |     def eval_step(self):
101 |         self.data.next()
102 |         inputs, Z = self.inputs('inputs', 'Z')
103 |         generated = self.generator.generate(Z)
104 |         self.discriminator.routine(inputs, generated.detach())
105 |
106 |         Z_P = Z
107 |
108 |         self.data.next()
109 |         Z = self.inputs('Z')
110 |         generated = self.generator.generate(Z)
111 |         self.mine.routine(generated, Z, generated, Z_P)
112 |
113 |         self.routine(Z, Z_P)
114 |
115 |     def visualize(self, images, Z, targets):
116 |         self.add_image(images, name='ground truth')
117 |         generated = self.generator.generate(Z)
118 |         self.discriminator.visualize(images, generated)
119 |         self.generator.visualize(Z)
120 |         self.data.next()
121 |         Z_N = self.inputs('Z')
122 |
123 |         self.mine.visualize(generated, generated, Z, Z_N, targets)
124 |
125 |
126 | register_model(GAN_MINE)
127 |
-------------------------------------------------------------------------------- /cortex/built_ins/models/utils.py: --------------------------------------------------------------------------------
1 | '''Model misc utilities.
2 |
3 | '''
4 |
5 | import logging
6 | import math
7 |
8 | from sklearn import svm
9 | import torch
10 |
11 |
12 | logger = logging.getLogger('cortex.arch' + __name__)
13 |
14 |
15 | def log_sum_exp(x, axis=None):
16 |     x_max = torch.max(x, axis)[0]
17 |     y = torch.log((torch.exp(x - x_max)).sum(axis)) + x_max
18 |     return y
19 |
20 |
21 | def cross_correlation(X, remove_diagonal=False):
22 |     X_s = X / X.std(0)
23 |     X_m = X_s - X_s.mean(0)
24 |     b, dim = X_m.size()
25 |     correlations = (X_m.unsqueeze(2).expand(b, dim, dim) *
26 |                     X_m.unsqueeze(1).expand(b, dim, dim)).sum(0) / float(b)
27 |     if remove_diagonal:
28 |         Id = torch.eye(dim)
29 |         Id = torch.autograd.Variable(Id.cuda(), requires_grad=False)
30 |         correlations -= Id
31 |
32 |     return correlations
33 |
34 |
35 | def perform_svc(X, Y, clf=None):
36 |     if clf is None:
37 |         clf = svm.LinearSVC()
38 |         clf.fit(X, Y)
39 |
40 |     Y_hat = clf.predict(X)
41 |
42 |     return clf, Y_hat
43 |
44 |
45 | def ms_ssim(X_a, X_b, window_size=11, size_average=True, C1=0.01**2, C2=0.03**2):
46 |     '''
47 |     Taken from Po-Hsun-Su/pytorch-ssim
48 |     '''
49 |
50 |     channel = X_a.size(1)
51 |
52 |     def gaussian(sigma=1.5):
53 |         gauss = torch.Tensor(
54 |             [math.exp(-(x - window_size // 2) **
55 |                       2 / float(2 * sigma ** 2)) for x in range(window_size)])
56 |         return gauss / gauss.sum()
57 |
58 |     def create_window():
59 |         _1D_window = gaussian().unsqueeze(1)  # default sigma=1.5, as upstream
60 |         _2D_window = _1D_window.mm(
61 |             _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
62 |         window = torch.Tensor(
63 |             _2D_window.expand(channel, 1, window_size,
64 |                               window_size).contiguous())
65 |         return window.cuda()
66 |
67 |     window = create_window()
68 |
69 |     mu1 = torch.nn.functional.conv2d(X_a, window,
70 |                                      padding=window_size // 2, groups=channel)
71 |     mu2 = torch.nn.functional.conv2d(X_b, window,
72 |                                      padding=window_size // 2, groups=channel)
73 |
74 |     mu1_sq = mu1.pow(2)
75 |     mu2_sq = mu2.pow(2)
76 |     mu1_mu2 = mu1 * mu2
77 |
78 |     sigma1_sq = torch.nn.functional.conv2d(
79 |         X_a * X_a, window, padding=window_size // 2, groups=channel) - mu1_sq
80 |     sigma2_sq = torch.nn.functional.conv2d(
81 |         X_b * X_b, window, padding=window_size // 2, groups=channel) - mu2_sq
82 |     sigma12 = torch.nn.functional.conv2d(
83 |         X_a * X_b, window, padding=window_size // 2, groups=channel) - mu1_mu2
84 |
85 |     ssim_map = (((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) /
86 |                 ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)))
87 |
88 |     if size_average:
89 |         return ssim_map.mean()
90 |     else:
91 |         return ssim_map.mean(1).mean(1).mean(1)
92 |
93 |
94 | resnet_encoder_args_ = dict(dim_h=64, batch_norm=True, f_size=3, n_steps=3)
95 | mnist_encoder_args_ = dict(dim_h=64, batch_norm=True, f_size=5,
96 |                            pad=2, stride=2, min_dim=7)
97 | convnet_encoder_args_ = dict(dim_h=64, batch_norm=True, n_steps=3)
98 |
99 |
100 | def update_encoder_args(x_shape, model_type='convnet', encoder_args=None):
101 |     encoder_args = encoder_args or {}
102 |     if model_type == 'resnet':
103 |         from cortex.built_ins.networks.resnets import ResEncoder as Encoder
104 |         encoder_args_ = {k: v for k, v in resnet_encoder_args_.items()}
105 |     elif model_type == 'convnet':
106 |         from cortex.built_ins.networks.convnets import SimpleConvEncoder as Encoder
107 |         encoder_args_ = {k: v for k, v in convnet_encoder_args_.items()}
108 |     elif model_type == 'mnist':
109 |         from cortex.built_ins.networks.convnets import SimpleConvEncoder as Encoder
110 |         encoder_args_ = {k: v for k, v in mnist_encoder_args_.items()}
111 |     elif model_type.split('.')[0] == 'tv':
112 |         from torchvision import models  # (was a nonexistent cortex.built_ins.networks.torchvision)
113 |         model_attributes = model_type.split('.')
114 |         if len(model_attributes) != 2:
115 |             raise ValueError('`tv` model type should be in the form `tv.<model_name>`')
116 |         model_key = model_attributes[1]
117 |
118 |         try:
119 |             tv_model = getattr(models, model_key)
120 |         except AttributeError:
121 |             raise NotImplementedError(model_attributes[1])
122 |
123 |         # TODO This lambda function is necessary because Encoder takes shape
124 |         # and dim_out.
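        # (torchvision classifiers take a `num_classes` argument rather than
        # cortex's (shape, dim_out) call signature, so the adapter below maps
        # dim_out to num_classes and ignores shape and n_steps.)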
125 |         Encoder = (lambda shape, dim_out=None, n_steps=None,
126 |                    **kwargs: tv_model(num_classes=dim_out, **kwargs))
127 |         encoder_args_ = {}
128 |     elif model_type.split('.')[0] == 'tv-wrapper':
129 |         from cortex.built_ins.networks import tv_models_wrapper as models
130 |         model_attributes = model_type.split('.')
131 |
132 |         if len(model_attributes) != 2:
133 |             raise ValueError(
134 |                 '`tv-wrapper` model type should be in the form'
135 |                 ' `tv-wrapper.<model_name>`')
136 |         model_key = model_attributes[1]
137 |
138 |         try:
139 |             Encoder = getattr(models, model_key)
140 |         except AttributeError:
141 |             raise NotImplementedError(model_attributes[1])
142 |         encoder_args_ = {}
143 |     else:
144 |         raise NotImplementedError(model_type)
145 |
146 |     encoder_args_.update(**encoder_args)
147 |     if x_shape[0] == 64:
148 |         encoder_args_['n_steps'] = 4
149 |     elif x_shape[0] == 128:
150 |         encoder_args_['n_steps'] = 5
151 |
152 |     return Encoder, encoder_args_
153 |
154 |
155 | resnet_decoder_args_ = dict(dim_h=64, batch_norm=True, f_size=3, n_steps=3)
156 | mnist_decoder_args_ = dict(dim_h=64, batch_norm=True, f_size=4,
157 |                            pad=1, stride=2, n_steps=2)
158 | convnet_decoder_args_ = dict(dim_h=64, batch_norm=True, n_steps=3)
159 |
160 |
161 | def update_decoder_args(x_shape, model_type='convnet', decoder_args=None):
162 |     decoder_args = decoder_args or {}
163 |
164 |     if model_type == 'resnet':
165 |         from cortex.built_ins.networks.resnets import ResDecoder as Decoder
166 |         decoder_args_ = {k: v for k, v in resnet_decoder_args_.items()}
167 |     elif model_type == 'convnet':
168 |         from cortex.built_ins.networks.conv_decoders import (
169 |             SimpleConvDecoder as Decoder)
170 |         decoder_args_ = {k: v for k, v in convnet_decoder_args_.items()}
171 |     elif model_type == 'mnist':
172 |         from cortex.built_ins.networks.conv_decoders import (
173 |             SimpleConvDecoder as Decoder)
174 |         decoder_args_ = {k: v for k, v in mnist_decoder_args_.items()}
175 |     else:
176 |         raise NotImplementedError(model_type)
177 |
178 |     decoder_args_.update(**decoder_args)
179 |     if x_shape[0] == 64:  # (was `>= 64`, which made the 128 branch unreachable)
180 |         decoder_args_['n_steps'] = 4
181 |     elif x_shape[0] == 128:
182 |         decoder_args_['n_steps'] = 5
183 |
184 |     return Decoder, decoder_args_
185 |
186 |
187 | def to_one_hot(y, K):
188 |     y_ = torch.unsqueeze(y, 1).long()
189 |
190 |     one_hot = torch.zeros(y.size(0), K).cuda()
191 |     one_hot.scatter_(1, y_.data.cuda(), 1)
192 |     return torch.tensor(one_hot)
193 |
-------------------------------------------------------------------------------- /cortex/built_ins/models/vae.py: --------------------------------------------------------------------------------
1 | '''Simple Variational Autoencoder model.
2 | '''
3 |
4 | import logging
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.autograd import Variable
10 |
11 | from cortex.built_ins.models.utils import ms_ssim
12 | from cortex.built_ins.models.image_coders import ImageDecoder, ImageEncoder
13 | from cortex.plugins import ModelPlugin, register_plugin
14 |
15 |
16 | __author__ = 'R Devon Hjelm and Samuel Lavoie'
17 | __author_email__ = 'erroneus@gmail.com'
18 |
19 | logger = logging.getLogger('cortex.vae')
20 |
21 |
22 | class VAENetwork(nn.Module):
23 |     '''VAE model.
24 |
25 |     Attributes:
26 |         encoder: Encoder network.
27 |         mu_net: Single layer network for calculating mean.
28 |         logvar_net: Single layer network for calculating log variance.
29 |         decoder: Decoder network.
30 |         mu: The mean after encoding.
31 |         logvar: The log variance after encoding.
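        std: The standard deviation after encoding (the exponential of the
            logvar_net output, set in forward).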
32 |         latent: The latent state (Z).
33 |
34 |     '''
35 |
36 |     def __init__(self, encoder, decoder, dim_out=None, dim_z=None):
37 |         super(VAENetwork, self).__init__()
38 |         self.encoder = encoder
39 |         self.mu_net = nn.Linear(dim_out, dim_z)
40 |         self.logvar_net = nn.Linear(dim_out, dim_z)
41 |         self.decoder = decoder
42 |         self.mu = None
43 |         self.logvar = None
44 |         self.latent = None
45 |
46 |     def encode(self, inputs, **kwargs):
47 |         encoded = self.encoder(inputs, **kwargs)
48 |         encoded = F.relu(encoded)
49 |         return self.mu_net(encoded)
50 |
51 |     def reparametrize(self, mu, std):
52 |         if self.training:
53 |             eps = Variable(
54 |                 std.data.new(std.size()).normal_(), requires_grad=False).cuda()
55 |             return mu + std * eps
56 |         else:
57 |             return mu
58 |
59 |     def forward(self, x, nonlinearity=None):
60 |         encoded = self.encoder(x)
61 |         encoded = F.relu(encoded)
62 |         self.mu = self.mu_net(encoded)
63 |         self.std = self.logvar_net(encoded).exp_()  # exp is used directly as sigma below
64 |         self.latent = self.reparametrize(self.mu, self.std)
65 |         return self.decoder(self.latent, nonlinearity=nonlinearity)
66 |
67 |
68 | class VAE(ModelPlugin):
69 |     '''Variational autoencoder.
70 |
71 |     A generative model trained using the variational lower-bound to the
72 |     log-likelihood.
73 |     See: Kingma, Diederik P., and Max Welling.
74 |     "Auto-encoding variational Bayes." arXiv preprint arXiv:1312.6114 (2013).
75 |
76 |     '''
77 |
78 |     defaults = dict(
79 |         data=dict(
80 |             batch_size=dict(train=64, test=640), inputs=dict(inputs='images')),
81 |         optimizer=dict(optimizer='Adam', learning_rate=1e-4),
82 |         train=dict(save_on_lowest='losses.vae'))
83 |
84 |     def __init__(self):
85 |         super().__init__()
86 |
87 |         self.encoder = ImageEncoder(contract=dict(
88 |             kwargs=dict(dim_out='dim_encoder_out')))
89 |         decoder_contract = dict(kwargs=dict(dim_in='dim_z'))
90 |         self.decoder = ImageDecoder(contract=decoder_contract)
91 |
92 |     def build(self, dim_z=64, dim_encoder_out=1024):
93 |         '''
94 |
95 |         Args:
96 |             dim_z: Latent dimension.
97 |             dim_encoder_out: Dimension of the final layer of the encoder,
98 |                 before mapping to mu and log sigma.
99 |
100 |         '''
101 |         self.encoder.build()
102 |         self.decoder.build()
103 |
104 |         self.add_noise('Z', dist='normal', size=dim_z)
105 |         encoder = self.nets.encoder
106 |         decoder = self.nets.decoder
107 |         vae = VAENetwork(encoder, decoder, dim_out=dim_encoder_out, dim_z=dim_z)
108 |         self.nets.vae = vae
109 |
110 |     def routine(self, inputs, targets, Z, vae_criterion=F.mse_loss,
111 |                 beta_kld=1.):
112 |         '''
113 |
114 |         Args:
115 |             vae_criterion: Reconstruction criterion.
116 |             beta_kld: Beta scaling for KL term in lower-bound.
117 |
118 |         '''
119 |
120 |         vae = self.nets.vae
121 |         outputs = vae(inputs)
122 |
123 |         try:
124 |             r_loss = vae_criterion(
125 |                 outputs, inputs, size_average=False) / inputs.size(0)
126 |         except RuntimeError as e:
127 |             logger.error('Runtime error. This could possibly be due to using '
128 |                          'the wrong encoder / decoder for this dataset. '
129 |                          'If you are using MNIST, for example, use the '
130 |                          'arguments `--encoder_type mnist --decoder_type '
131 |                          'mnist`')
132 |             raise e
133 |
134 |         kl = (0.5 * (vae.std**2 + vae.mu**2 - 2.
* torch.log(vae.std) - 135 | 1.).sum(1).mean()) 136 | 137 | msssim = ms_ssim(inputs, outputs) 138 | 139 | self.losses.vae = (r_loss + beta_kld * kl) 140 | self.results.update(KL_divergence=kl.item(), ms_ssim=msssim.item()) 141 | 142 | def visualize(self, inputs, targets, Z): 143 | vae = self.nets.vae 144 | 145 | outputs = vae(inputs) 146 | 147 | self.add_image(outputs, name='reconstruction') 148 | self.add_image(inputs, name='ground truth') 149 | self.add_scatter(vae.mu.data, labels=targets.data, name='latent values') 150 | self.decoder.visualize(Z) 151 | 152 | 153 | register_plugin(VAE) 154 | -------------------------------------------------------------------------------- /cortex/built_ins/networks/SpectralNormLayer.py: -------------------------------------------------------------------------------- 1 | # Implementation based on original paper: 2 | # https://github.com/pfnet-research/sngan_projection 3 | 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import torch 7 | 8 | 9 | def l2normalize(v, esp=1e-8): 10 | return v / (v.norm() + esp) 11 | 12 | 13 | def sn_weight(weight, u, height, n_power_iterations): 14 | weight.requires_grad_(False) 15 | for _ in range(n_power_iterations): 16 | v = l2normalize(torch.mv(weight.view(height, -1).t(), u)) 17 | u = l2normalize(torch.mv(weight.view(height, -1), v)) 18 | 19 | weight.requires_grad_(True) 20 | sigma = u.dot(weight.view(height, -1).mv(v)) 21 | return torch.div(weight, sigma), u 22 | 23 | 24 | class SNConv2d(nn.Conv2d): 25 | def __init__(self, *args, n_power_iterations=1, **kwargs): 26 | super(SNConv2d, self).__init__(*args, **kwargs) 27 | self.n_power_iterations = n_power_iterations 28 | self.height = self.weight.shape[0] 29 | self.register_buffer( 30 | 'u', l2normalize(self.weight.new_empty(self.height).normal_(0, 1))) 31 | 32 | def forward(self, input): 33 | w_sn, self.u = sn_weight(self.weight, self.u, self.height, 34 | self.n_power_iterations) 35 | return F.conv2d(input, w_sn, self.bias, self.stride, 36 | self.padding, self.dilation, self.groups) 37 | 38 | 39 | class SNLinear(nn.Linear): 40 | def __init__(self, *args, n_power_iterations=1, **kwargs): 41 | super(SNLinear, self).__init__(*args, **kwargs) 42 | self.n_power_iterations = n_power_iterations 43 | self.height = self.weight.shape[0] 44 | self.register_buffer( 45 | 'u', l2normalize(self.weight.new(self.height).normal_(0, 1))) 46 | 47 | def forward(self, input): 48 | w_sn, self.u = sn_weight( 49 | self.weight, self.u, self.height, self.n_power_iterations) 50 | return F.linear(input, w_sn, self.bias) 51 | -------------------------------------------------------------------------------- /cortex/built_ins/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/cortex/built_ins/networks/__init__.py -------------------------------------------------------------------------------- /cortex/built_ins/networks/ae_network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class AENetwork(nn.Module): 5 | def __init__(self, encoder, decoder): 6 | super(AENetwork, self).__init__() 7 | self.encoder = encoder 8 | self.decoder = decoder 9 | 10 | def forward(self, x, nonlinearity=None): 11 | encoded = self.encoder(x) 12 | decoded = self.decoder(encoded) 13 | return decoded 14 | -------------------------------------------------------------------------------- 
/cortex/built_ins/networks/base_network.py: --------------------------------------------------------------------------------
1 | '''A base network for handling common arguments in cortex models.
2 |
3 | This is not necessary to use cortex: these are just convenience networks.
4 |
5 | '''
6 |
7 | import torch.nn as nn
8 | import torch
9 |
10 | from .utils import apply_nonlinearity, get_nonlinearity, finish_layer_1d
11 |
12 |
13 | class BaseNet(nn.Module):
14 |     '''Basic convenience network for cortex.
15 |
16 |     Attributes:
17 |         models: An nn.Sequential holding the network's layers in order.
18 |
19 |     '''
20 |
21 |     def __init__(self, nonlinearity='ReLU', output_nonlinearity=None):
22 |         super(BaseNet, self).__init__()
23 |
24 |         self.models = nn.Sequential()
25 |
26 |         self.output_nonlinearity = output_nonlinearity
27 |         self.layer_nonlinearity = get_nonlinearity(nonlinearity)
28 |
29 |     def forward(self,
30 |                 x: torch.Tensor,
31 |                 nonlinearity: str = None,
32 |                 **nonlinearity_args: dict) -> torch.Tensor:
33 |         self.states = []
34 |         if nonlinearity is None:
35 |             nonlinearity = self.output_nonlinearity
36 |         elif not nonlinearity:
37 |             nonlinearity = None  # a falsy value explicitly disables the output nonlinearity
38 |
39 |         for model in self.models:
40 |             x = model(x)
41 |             self.states.append(x)
42 |         x = apply_nonlinearity(x, nonlinearity, **nonlinearity_args)
43 |         return x
44 |
45 |     def get_h(self, dim_h, n_levels=None):
46 |         if isinstance(dim_h, (list, tuple)):
47 |             pass
48 |         elif n_levels:
49 |             dim_h = [dim_h for _ in range(n_levels)]
50 |         else:
51 |             dim_h = [dim_h]
52 |
53 |         return dim_h
54 |
55 |     def add_linear_layers(self,
56 |                           dim_in,
57 |                           dim_h,
58 |                           dim_ex=None,
59 |                           Linear=None,
60 |                           **layer_args):
61 |         Linear = Linear or nn.Linear
62 |
63 |         if dim_h is None or len(dim_h) == 0:
64 |             return dim_in
65 |
66 |         for dim_out in dim_h:
67 |             name = 'linear_({}/{})'.format(dim_in, dim_out)
68 |             self.models.add_module(name, Linear(dim_in, dim_out))
69 |             finish_layer_1d(
70 |                 self.models,
71 |                 name,
72 |                 dim_out,
73 |                 nonlinearity=self.layer_nonlinearity,
74 |                 **layer_args)
75 |             dim_in = dim_out
76 |             if dim_ex is not None:
77 |                 dim_in += dim_ex
78 |
79 |         return dim_out
80 |
81 |     def add_output_layer(self, dim_in, dim_out, Linear=None):
82 |
83 |         Linear = Linear or nn.Linear
84 |         if dim_out is not None:
85 |             name = 'linear_({}/{})_{}'.format(dim_in, dim_out, 'out')
86 |             self.models.add_module(name, Linear(dim_in, dim_out))
87 |
88 |
89 | def make_subnet(from_network, n_layers):
90 |     '''Makes a subnet out of another net.
91 |
92 |     Shares parameters with original network.
93 |
94 |     Args:
95 |         from_network: Network to derive subnet from.
96 |         n_layers: Number of modules from `from_network.models` to use.
97 |
98 |     Returns:
99 |         A Subnet for the original network.
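        For example (a sketch), make_subnet(net, 2) returns a BaseNet whose
        `models` are the first two modules of `net.models`, sharing parameters.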
100 | 101 | ''' 102 | to_network = BaseNet() 103 | to_network.models = from_network.models[:n_layers] 104 | return to_network 105 | -------------------------------------------------------------------------------- /cortex/built_ins/networks/conv_decoders.py: -------------------------------------------------------------------------------- 1 | '''Convolutional decoders 2 | 3 | ''' 4 | 5 | import logging 6 | 7 | import torch.nn as nn 8 | 9 | from .modules import View 10 | from .base_network import BaseNet 11 | from .utils import finish_layer_2d 12 | 13 | 14 | logger = logging.getLogger('cortex.models' + __name__) 15 | 16 | 17 | def infer_conv_size(w, k, s, p): 18 | x = (w - k + 2 * p) // s + 1 19 | return x 20 | 21 | 22 | class SimpleConvDecoder(BaseNet): 23 | def __init__(self, shape, dim_in=None, initial_layer=None, dim_h=64, 24 | nonlinearity='ReLU', output_nonlinearity=None, 25 | f_size=4, stride=2, pad=1, n_steps=3, **layer_args): 26 | super(SimpleConvDecoder, self).__init__( 27 | nonlinearity=nonlinearity, output_nonlinearity=output_nonlinearity) 28 | 29 | dim_h_ = dim_h 30 | logger.debug('Input shape: {}'.format(shape)) 31 | dim_x_, dim_y_, dim_out_ = shape 32 | 33 | dim_x = dim_x_ 34 | dim_y = dim_y_ 35 | dim_h = dim_h_ 36 | 37 | saved_spatial_dimensions = [(dim_x, dim_y)] 38 | 39 | for n in range(n_steps): 40 | dim_x, dim_y = self.next_size(dim_x, dim_y, f_size, stride, pad) 41 | saved_spatial_dimensions.append((dim_x, dim_y)) 42 | if n < n_steps - 1: 43 | dim_h *= 2 44 | 45 | dim_out = dim_x * dim_y * dim_h 46 | 47 | if initial_layer is not None: 48 | dim_h_ = [initial_layer, dim_out] 49 | else: 50 | dim_h_ = [dim_out] 51 | 52 | self.add_linear_layers(dim_in, dim_h_, **layer_args) 53 | 54 | name = 'reshape' 55 | self.models.add_module(name, View(-1, dim_h, dim_x, dim_y)) 56 | 57 | finish_layer_2d(self.models, name, dim_x, dim_y, dim_h, 58 | nonlinearity=self.layer_nonlinearity, **layer_args) 59 | dim_out = dim_h 60 | 61 | for i in range(n_steps): 62 | dim_in = dim_out 63 | 64 | if i == n_steps - 1: 65 | pass 66 | else: 67 | dim_out = dim_in // 2 68 | 69 | name = 'tconv_({}/{})_{}'.format(dim_in, dim_out, i + 1) 70 | self.models.add_module( 71 | name, nn.ConvTranspose2d(dim_in, dim_out, f_size, stride, pad, 72 | bias=False)) 73 | 74 | finish_layer_2d(self.models, name, dim_x, dim_y, dim_out, 75 | nonlinearity=self.layer_nonlinearity, **layer_args) 76 | 77 | self.models.add_module(name + 'f', nn.Conv2d( 78 | dim_out, dim_out_, 3, 1, 1, bias=False)) 79 | 80 | def next_size(self, dim_x, dim_y, k, s, p): 81 | if isinstance(k, int): 82 | kx, ky = (k, k) 83 | else: 84 | kx, ky = k 85 | 86 | if isinstance(s, int): 87 | sx, sy = (s, s) 88 | else: 89 | sx, sy = s 90 | 91 | if isinstance(p, int): 92 | px, py = (p, p) 93 | else: 94 | px, py = p 95 | 96 | return infer_conv_size( 97 | dim_x, kx, sx, px), infer_conv_size(dim_y, ky, sy, py) 98 | -------------------------------------------------------------------------------- /cortex/built_ins/networks/convnets.py: -------------------------------------------------------------------------------- 1 | '''Convoluational encoders 2 | 3 | ''' 4 | 5 | import logging 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from .SpectralNormLayer import SNConv2d, SNLinear 10 | 11 | from .modules import View 12 | from .base_network import BaseNet 13 | from .utils import finish_layer_2d 14 | 15 | logger = logging.getLogger('cortex.arch' + __name__) 16 | 17 | 18 | def infer_conv_size(w, k, s, p): 19 | x = (w - k + 2 * p) // s + 1 20 | return x 21 | 
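# Worked example: with f_size=4, stride=2, pad=1 (the SimpleConvEncoder
# defaults below), a 28-pixel side maps to (28 - 4 + 2 * 1) // 2 + 1 = 14,
# i.e. each conv step halves the spatial size.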
22 |
23 | class SimpleNet(nn.Module):
24 |
25 |     def __init__(self):
26 |         super(SimpleNet, self).__init__()
27 |         self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
28 |         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
29 |         self.conv2_drop = nn.Dropout2d()
30 |         self.fc1 = nn.Linear(320, 50)
31 |         self.fc2 = nn.Linear(50, 10)
32 |
33 |     def forward(self, x):
34 |         x = F.relu(F.max_pool2d(self.conv1(x), 2))
35 |         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
36 |         x = x.view(-1, 320)
37 |         x = F.relu(self.fc1(x))
38 |         x = F.dropout(x, training=self.training)
39 |         x = self.fc2(x)
40 |         return F.log_softmax(x, dim=1)
41 |
42 |
43 | class SimpleConvEncoder(BaseNet):
44 |     def __init__(self, shape, dim_out=None, dim_h=64,
45 |                  fully_connected_layers=None, nonlinearity='ReLU',
46 |                  output_nonlinearity=None, f_size=4,
47 |                  stride=2, pad=1, min_dim=4, n_steps=None, normalize_input=False,
48 |                  spectral_norm=False, last_conv_nonlinearity=True,
49 |                  last_batchnorm=True, **layer_args):
50 |         super(SimpleConvEncoder, self).__init__(
51 |             nonlinearity=nonlinearity, output_nonlinearity=output_nonlinearity)
52 |
53 |         Conv2d = SNConv2d if spectral_norm else nn.Conv2d
54 |         Linear = SNLinear if spectral_norm else nn.Linear
55 |
56 |         dim_out_ = dim_out
57 |         fully_connected_layers = fully_connected_layers or []
58 |         if isinstance(fully_connected_layers, int):
59 |             fully_connected_layers = [fully_connected_layers]
60 |
61 |         logger.debug('Input shape: {}'.format(shape))
62 |         dim_x, dim_y, dim_in = shape
63 |
64 |         if isinstance(dim_h, list):
65 |             n_steps = len(dim_h)
66 |
67 |         if normalize_input:
68 |             self.models.add_module('initial_bn', nn.BatchNorm2d(dim_in))
69 |
70 |         i = 0
71 |         logger.debug('Input size: {},{}'.format(dim_x, dim_y))
72 |         while ((dim_x >= min_dim and dim_y >= min_dim) and
73 |                (i < n_steps if n_steps else True)):
74 |             if i == 0:
75 |                 if isinstance(dim_h, list):
76 |                     dim_out = dim_h[0]
77 |                 else:
78 |                     dim_out = dim_h
79 |             else:
80 |                 dim_in = dim_out
81 |                 if isinstance(dim_h, list):
82 |                     dim_out = dim_h[i]
83 |                 else:
84 |                     dim_out = dim_in * 2
85 |             conv_args = dict((k, v) for k, v in layer_args.items())
86 |             name = 'conv_({}/{})_{}'.format(dim_in, dim_out, i + 1)
87 |
88 |             self.models.add_module(
89 |                 name, Conv2d(dim_in, dim_out, f_size, stride, pad, bias=False))
90 |             dim_x, dim_y = self.next_size(dim_x, dim_y, f_size, stride, pad)
91 |
92 |             is_last_layer = not((dim_x >= min_dim and dim_y >= min_dim) and
93 |                                 (i + 1 < n_steps if n_steps else True))  # i + 1: true only if the loop will not run again
94 |
95 |             if is_last_layer:
96 |                 if not(last_conv_nonlinearity):
97 |                     nonlinearity = None
98 |                 else:
99 |                     nonlinearity = self.layer_nonlinearity
100 |
101 |                 if not(last_batchnorm):
102 |                     conv_args['batch_norm'] = False
103 |
104 |             finish_layer_2d(
105 |                 self.models, name, dim_x, dim_y, dim_out,
106 |                 nonlinearity=nonlinearity, **conv_args)
107 |             logger.debug('Output size: {},{}'.format(dim_x, dim_y))
108 |             i += 1
109 |
110 |         if len(fully_connected_layers) == 0 and dim_out_ is None:
111 |             return
112 |
113 |         dim_out__ = dim_out
114 |         dim_out = dim_x * dim_y * dim_out
115 |
116 |         self.models.add_module('final_reshape_{}x{}x{}to{}'
117 |                                .format(dim_x, dim_y, dim_out__, dim_out),
118 |                                View(-1, dim_out))
119 |
120 |         dim_out = self.add_linear_layers(dim_out, fully_connected_layers,
121 |                                          Linear=Linear, **layer_args)
122 |         self.add_output_layer(dim_out, dim_out_, Linear=Linear)
123 |
124 |     def next_size(self, dim_x, dim_y, k, s, p):
125 |         if isinstance(k, int):
126 |             kx, ky = (k, k)
127 |         else:
128 |             kx, ky = k
129 |
130 |         if isinstance(s, int):
131 |             sx, sy = (s, s)
132 |
else: 133 | sx, sy = s 134 | 135 | if isinstance(p, int): 136 | px, py = (p, p) 137 | else: 138 | px, py = p 139 | return infer_conv_size( 140 | dim_x, kx, sx, px), infer_conv_size(dim_y, ky, sy, py) 141 | -------------------------------------------------------------------------------- /cortex/built_ins/networks/fully_connected.py: -------------------------------------------------------------------------------- 1 | '''Simple dense network encoders 2 | 3 | ''' 4 | 5 | import logging 6 | 7 | import torch.nn as nn 8 | 9 | from .base_network import BaseNet 10 | 11 | 12 | logger = logging.getLogger('cortex.arch' + __name__) 13 | 14 | 15 | class FullyConnectedNet(BaseNet): 16 | 17 | def __init__(self, dim_in, dim_out=None, dim_h=64, dim_ex=None, 18 | nonlinearity='ReLU', n_levels=None, output_nonlinearity=None, 19 | normalize_input=False, **layer_args): 20 | super(FullyConnectedNet, self).__init__( 21 | nonlinearity=nonlinearity, output_nonlinearity=output_nonlinearity) 22 | 23 | dim_h = self.get_h(dim_h, n_levels=n_levels) 24 | 25 | if normalize_input: 26 | self.models.add_module('initial_bn', nn.BatchNorm1d(dim_in)) 27 | 28 | dim_in = self.add_linear_layers(dim_in, dim_h, dim_ex=dim_ex, 29 | **layer_args) 30 | self.add_output_layer(dim_in, dim_out) 31 | -------------------------------------------------------------------------------- /cortex/built_ins/networks/modules.py: -------------------------------------------------------------------------------- 1 | '''General purpose modules 2 | 3 | ''' 4 | 5 | import torch.nn as nn 6 | 7 | 8 | class View(nn.Module): 9 | def __init__(self, *shape): 10 | super(View, self).__init__() 11 | self.shape = shape 12 | 13 | def forward(self, input): 14 | return input.view(*self.shape) 15 | 16 | 17 | class Pipeline(nn.Module): 18 | def __init__(self, networks): 19 | super(Pipeline, self).__init__() 20 | self.networks = networks 21 | 22 | def forward(self, input): 23 | output = input 24 | for network in self.networks: 25 | output = network(output) 26 | return output 27 | -------------------------------------------------------------------------------- /cortex/built_ins/networks/tv_models_wrapper.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torchvision import models 3 | from .utils import finish_layer_1d, get_nonlinearity 4 | 5 | 6 | class AlexNet(models.AlexNet): 7 | def __init__(self, shape, dim_out=None, fully_connected_layers=None, 8 | nonlinearity='ReLU', n_steps=None, 9 | **layer_args): 10 | super(AlexNet, self).__init__() 11 | fully_connected_layers = fully_connected_layers or [] 12 | self.fc = nn.Sequential() 13 | dim_out_ = (256 * ((shape[0] + 4 - 10) // 32) * 14 | ((shape[1] + 4 - 10) // 32)) 15 | nonlinearity = get_nonlinearity(nonlinearity) 16 | for dim_h in fully_connected_layers: 17 | dim_in = dim_out_ 18 | dim_out_ = dim_h 19 | name = 'linear_%s_%s' % (dim_in, dim_out_) 20 | self.fc.add_module(name, nn.Linear(dim_in, dim_out_)) 21 | finish_layer_1d(self.fc, name, dim_out_, 22 | nonlinearity=nonlinearity, **layer_args) 23 | 24 | if dim_out: 25 | name = 'dim_out' 26 | self.fc.add_module(name, nn.Linear(dim_out_, dim_out)) 27 | 28 | def forward(self, x): 29 | x = self.features(x) 30 | x = x.view(x.size()[0], -1) 31 | return self.fc(x) 32 | -------------------------------------------------------------------------------- /cortex/built_ins/networks/utils.py: -------------------------------------------------------------------------------- 1 | '''Utils for networks 2 | 3 | ''' 4 | 5 | 
import logging
6 |
7 | from torch import nn
8 |
9 | logger = logging.getLogger('cortex.arch.modules' + __name__)
10 |
11 |
12 | def get_nonlinearity(nonlinearity=None):
13 |     if not nonlinearity:
14 |         pass
15 |     elif callable(nonlinearity):
16 |         if nonlinearity == nn.LeakyReLU:
17 |             nonlinearity = nonlinearity(0.02, inplace=True)
18 |     elif hasattr(nn, nonlinearity):
19 |         nonlinearity = getattr(nn, nonlinearity)
20 |         if nonlinearity == nn.LeakyReLU:  # (was compared to the string 'LeakyReLU', which never matched)
21 |             nonlinearity = nonlinearity(0.02, inplace=True)
22 |         else:
23 |             nonlinearity = nonlinearity()
24 |     elif hasattr(nn.functional, nonlinearity):
25 |         nonlinearity = getattr(nn.functional, nonlinearity)
26 |     else:
27 |         raise ValueError(nonlinearity)
28 |     return nonlinearity
29 |
30 |
31 | def finish_layer_2d(models, name, dim_x, dim_y, dim_out,
32 |                     dropout=False, layer_norm=False, batch_norm=False,
33 |                     nonlinearity=None):
34 |     if layer_norm and batch_norm:
35 |         logger.warning('Ignoring layer_norm because batch_norm is True')
36 |
37 |     if dropout:
38 |         models.add_module(name + '_do', nn.Dropout2d(p=dropout))
39 |
40 |     if layer_norm:
41 |         models.add_module(name + '_ln', nn.LayerNorm((dim_out, dim_x, dim_y)))
42 |     elif batch_norm:
43 |         models.add_module(name + '_bn', nn.BatchNorm2d(dim_out))
44 |
45 |     if nonlinearity:
46 |         nonlinearity = get_nonlinearity(nonlinearity)
47 |         models.add_module(
48 |             '{}_{}'.format(name, nonlinearity.__class__.__name__),
49 |             nonlinearity)
50 |
51 |
52 | def finish_layer_1d(models, name, dim_out,
53 |                     dropout=False, layer_norm=False, batch_norm=False,
54 |                     nonlinearity=None):
55 |     if layer_norm and batch_norm:
56 |         logger.warning('Ignoring layer_norm because batch_norm is True')
57 |
58 |     if dropout:
59 |         models.add_module(name + '_do', nn.Dropout(p=dropout))
60 |
61 |     if layer_norm:
62 |         models.add_module(name + '_ln', nn.LayerNorm(dim_out))
63 |     elif batch_norm:
64 |         models.add_module(name + '_bn', nn.BatchNorm1d(dim_out))
65 |
66 |     if nonlinearity:
67 |         nonlinearity = get_nonlinearity(nonlinearity)
68 |         models.add_module(
69 |             '{}_{}'.format(name, nonlinearity.__class__.__name__),
70 |             nonlinearity)
71 |
72 |
73 | def apply_nonlinearity(x, nonlinearity, **nonlinearity_args):
74 |     if nonlinearity:
75 |         if isinstance(nonlinearity, str):
76 |             nonlinearity = get_nonlinearity(nonlinearity)
77 |         if callable(nonlinearity):
78 |             if isinstance(nonlinearity, nn.PReLU):
79 |                 nonlinearity.to(x.device)
80 |             try:
81 |                 x = nonlinearity(x, **nonlinearity_args)
82 |             except RuntimeError:
83 |                 nonlinearity.to('cpu')
84 |                 x = nonlinearity(x, **nonlinearity_args)
85 |         else:
86 |             raise ValueError(nonlinearity, type(nonlinearity))
87 |     return x
88 |
-------------------------------------------------------------------------------- /cortex/built_ins/transforms/__init__.py: --------------------------------------------------------------------------------
1 | '''
2 |
3 | '''
4 |
-------------------------------------------------------------------------------- /cortex/built_ins/transforms/sobel.py: --------------------------------------------------------------------------------
1 | '''Module for Sobel transformation
2 |
3 | '''
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | class Sobel(object):
10 |     def __init__(self):
11 |         self.kernel_g_x = torch.FloatTensor(
12 |             [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).unsqueeze(0).unsqueeze(0)
13 |         self.kernel_g_y = torch.FloatTensor(
14 |             [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).unsqueeze(0).unsqueeze(0)
15 |
16 |     def _apply_sobel(self, channel):
17 |         g_x = F.conv2d(channel, self.kernel_g_x,
stride=1, padding=1) 18 | g_y = F.conv2d(channel, self.kernel_g_y, stride=1, padding=1) 19 | return torch.sqrt(torch.pow(g_x, 2) + torch.pow(g_y, 2)) 20 | 21 | def __call__(self, img): 22 | a = torch.cat([self._apply_sobel( 23 | channel.unsqueeze(0).unsqueeze(0)) for channel in img]).squeeze(1) 24 | return a 25 | 26 | def __repr__(self): 27 | return self.__class__.__name__ + '()' 28 | -------------------------------------------------------------------------------- /cortex/main.py: -------------------------------------------------------------------------------- 1 | '''Main file for running experiments. 2 | 3 | ''' 4 | 5 | 6 | import logging 7 | 8 | from cortex._lib import (config, data, exp, optimizer, setup_cortex, 9 | setup_experiment, train) 10 | from cortex._lib.utils import print_section 11 | 12 | __author__ = 'R Devon Hjelm' 13 | __author_email__ = 'erroneus@gmail.com' 14 | 15 | 16 | logger = logging.getLogger('cortex') 17 | 18 | 19 | def run(model=None): 20 | '''Main function. 21 | 22 | ''' 23 | # Parse the command-line arguments 24 | 25 | try: 26 | args = setup_cortex(model=model) 27 | if args.command == 'setup': 28 | # Performs setup only. 29 | config.setup() 30 | exit(0) 31 | else: 32 | config.set_config() 33 | print_section('EXPERIMENT') 34 | model, reload_nets = setup_experiment(args, model=model) 35 | print_section('DATA') 36 | data.setup(**exp.ARGS['data']) 37 | print_section('MODEL') 38 | model.reload_nets(reload_nets) 39 | model.build() 40 | print_section('OPTIMIZER') 41 | optimizer.setup(model, **exp.ARGS['optimizer']) 42 | 43 | except KeyboardInterrupt: 44 | print('Cancelled') 45 | exit(0) 46 | 47 | print_section('RUNNING') 48 | train.main_loop(model, **exp.ARGS['train']) 49 | -------------------------------------------------------------------------------- /demos/demo_classifier.py: -------------------------------------------------------------------------------- 1 | '''Simple classifier model 2 | 3 | ''' 4 | 5 | from cortex.main import run 6 | from cortex.plugins import ModelPlugin 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from cortex.built_ins.models.utils import update_encoder_args 13 | 14 | 15 | class MyClassifier(ModelPlugin): 16 | '''Basic image classifier. 17 | 18 | Classifies images using standard convnets. 19 | 20 | ''' 21 | 22 | defaults = dict( 23 | data=dict(batch_size=128, inputs=dict(inputs='images')), 24 | optimizer=dict(optimizer='Adam', learning_rate=1e-3), 25 | train=dict(epochs=200, save_on_best='losses.classifier')) 26 | 27 | def build(self, classifier_type='convnet', 28 | classifier_args=dict(dropout=0.2)): 29 | '''Builds a simple image classifier. 30 | 31 | Args: 32 | classifier_type (str): Network type for the classifier. 33 | classifier_args: Classifier arguments. Can include dropout, 34 | batch_norm, layer_norm, etc. 35 | 36 | ''' 37 | 38 | classifier_args = classifier_args or {} 39 | 40 | shape = self.get_dims('x', 'y', 'c') 41 | dim_l = self.get_dims('labels') 42 | 43 | Encoder, args = update_encoder_args( 44 | shape, model_type=classifier_type, encoder_args=classifier_args) 45 | 46 | args.update(**classifier_args) 47 | 48 | classifier = Encoder(shape, dim_out=dim_l, **args) 49 | self.nets.classifier = classifier 50 | 51 | def routine(self, inputs, targets, criterion=nn.CrossEntropyLoss()): 52 | ''' 53 | 54 | Args: 55 | criterion: Classifier criterion. 
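            The loss is recorded as `losses.classifier` and the accuracy as
            `results.accuracy` (see below).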
56 |
57 |         '''
58 |
59 |         classifier = self.nets.classifier
60 |
61 |         outputs = classifier(inputs)
62 |         predicted = torch.max(F.log_softmax(outputs, dim=1).data, 1)[1]
63 |
64 |         loss = criterion(outputs, targets)
65 |         correct = 100. * predicted.eq(
66 |             targets.data).cpu().sum() / targets.size(0)
67 |
68 |         self.losses.classifier = loss
69 |         self.results.accuracy = correct
70 |
71 |     def predict(self, inputs):
72 |         classifier = self.nets.classifier
73 |
74 |         outputs = classifier(inputs)
75 |         predicted = torch.max(F.log_softmax(outputs, dim=1).data, 1)[1]
76 |
77 |         return predicted
78 |
79 |     def visualize(self, images, inputs, targets):
80 |         predicted = self.predict(inputs)
81 |         self.add_image(images.data, labels=(targets.data, predicted.data),
82 |                        name='gt_pred')
83 |
84 |
85 | if __name__ == '__main__':
86 |     classifier = MyClassifier()
87 |     run(model=classifier)
88 |
-------------------------------------------------------------------------------- /demos/demo_custom_ae.py: --------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import nn
3 |
4 | from cortex.plugins import ModelPlugin
5 | from cortex.main import run
6 |
7 |
8 | class Autoencoder(nn.Module):
9 |     """
10 |     Encapsulation of an encoder and a decoder.
11 |     """
12 |     def __init__(self, encoder, decoder):
13 |         super(Autoencoder, self).__init__()
14 |         self.encoder = encoder
15 |         self.decoder = decoder
16 |
17 |     def forward(self, x, nonlinearity=None):
18 |         encoded = self.encoder(x)
19 |         decoded = self.decoder(encoded)
20 |         return decoded
21 |
22 |
23 | class AE(ModelPlugin):
24 |     """
25 |     Autoencoder designed with Pytorch components.
26 |     """
27 |     defaults = dict(
28 |         data=dict(
29 |             batch_size=dict(train=64, test=64), inputs=dict(inputs='images')),
30 |         optimizer=dict(optimizer='Adam', learning_rate=1e-4),
31 |         train=dict(save_on_lowest='losses.ae'))
32 |
33 |     def __init__(self):
34 |         super().__init__()
35 |
36 |     def build(self, dim_z=64, dim_encoder_out=64):
37 |         """
38 |         Build the AE with an encoder and a decoder, and attach
39 |         an Autoencoder to self.nets.
40 |         Args:
41 |             dim_z: int
42 |             dim_encoder_out: int
43 |
44 |         Returns: None
45 |
46 |         """
47 |         encoder = nn.Sequential(
48 |             nn.Linear(28, 256),
49 |             nn.ReLU(True),
50 |             nn.Linear(256, 28),
51 |             nn.ReLU(True))
52 |         decoder = nn.Sequential(
53 |             nn.Linear(28, 256),
54 |             nn.ReLU(True),
55 |             nn.Linear(256, 28),
56 |             nn.Sigmoid())
57 |         self.nets.ae = Autoencoder(encoder, decoder)
58 |
59 |     def routine(self, inputs, ae_criterion=F.mse_loss):
60 |         """
61 |         Training routine and loss computation.
62 |         Args:
63 |             inputs: torch.Tensor
64 |             ae_criterion: function
65 |
66 |         Returns: None
67 |
68 |         """
69 |         encoded = self.nets.ae.encoder(inputs)
70 |         outputs = self.nets.ae.decoder(encoded)
71 |         r_loss = ae_criterion(
72 |             outputs, inputs, size_average=False) / inputs.size(0)
73 |         self.losses.ae = r_loss
74 |
75 |     def visualize(self, inputs):
76 |         """
77 |         Adds reconstructed images and ground-truth images to the
78 |         visualization.
-------------------------------------------------------------------------------- /docs.py: --------------------------------------------------------------------------------
1 | """
2 | Script for documentation generation and cleaning.
3 | """
4 |
5 | import glob
6 | import subprocess
7 | import os
8 |
9 | files_before_generation = glob.glob("./docs/source/*.rst")
10 | subprocess.call(["sphinx-apidoc", "-f", "-o", "docs/source", "./cortex"])
11 | subprocess.call(["make", "html"])
12 | files_after_generation = glob.glob("./docs/source/*.rst")
13 |
14 | for filename in files_before_generation:
15 |     if filename in files_after_generation:
16 |         files_after_generation.remove(filename)
17 |
18 | for filename in files_after_generation:  # remove only the freshly generated .rst files
19 |     subprocess.call(["rm", "-f", filename])
20 |
21 | folders = [
22 |     folder for folder in os.listdir('./docs/html')
23 |     if os.path.isdir(os.path.join('./docs/html', folder))
24 | ]
25 |
26 | folders_to_remove = []
27 |
28 | for folder in folders:
29 |     if folder[0] != "_":  # keep Sphinx's _static, _modules, etc.
30 |         folders_to_remove.append(folder)
31 |
32 | for folder_to_remove in folders_to_remove:
33 |     folder_to_remove = "./docs/html/" + folder_to_remove
34 |     subprocess.call(["rm", "-rf", folder_to_remove])
35 |
36 | doctrees = glob.glob("./docs/doctrees")  # assumes the build above produced these paths
37 | subprocess.call(["rm", "-rf", doctrees[0]])
38 |
39 | sources = glob.glob("./docs/html/_sources")
40 | subprocess.call(["rm", "-rf", sources[0]])
41 |
42 | objects = glob.glob("./docs/html/objects.inv")
43 | subprocess.call(["rm", "-f", objects[0]])
44 |
-------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/.nojekyll -------------------------------------------------------------------------------- /docs/html/.buildinfo: --------------------------------------------------------------------------------
1 | # Sphinx build info version 1
2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
3 | config: 5c7d95945a42b27b9dc9e58391223d61 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/html/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/.nojekyll -------------------------------------------------------------------------------- /docs/html/_static/ajax-loader.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/ajax-loader.gif -------------------------------------------------------------------------------- /docs/html/_static/comment-bright.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/comment-bright.png -------------------------------------------------------------------------------- /docs/html/_static/comment-close.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/comment-close.png -------------------------------------------------------------------------------- /docs/html/_static/comment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/comment.png -------------------------------------------------------------------------------- /docs/html/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:before,.clearfix:after{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-weight:normal;font-style:normal;src:url("../fonts/fontawesome-webfont.eot");src:url("../fonts/fontawesome-webfont.eot?#iefix") format("embedded-opentype"),url("../fonts/fontawesome-webfont.woff") format("woff"),url("../fonts/fontawesome-webfont.ttf") format("truetype"),url("../fonts/fontawesome-webfont.svg#FontAwesome") format("svg")}.fa:before{display:inline-block;font-family:FontAwesome;font-style:normal;font-weight:normal;line-height:1;text-decoration:inherit}a .fa{display:inline-block;text-decoration:inherit}li .fa{display:inline-block}li .fa-large:before,li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-0.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before,ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before{content:""}.icon-book:before{content:""}.fa-caret-down:before{content:""}.icon-caret-down:before{content:""}.fa-caret-up:before{content:""}.icon-caret-up:before{content:""}.fa-caret-left:before{content:""}.icon-caret-left:before{content:""}.fa-caret-right:before{content:""}.icon-caret-right:before{content:""}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:"Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;z-index:400}.rst-versions a{color:#2980B9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions 
.rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27AE60;*zoom:1}.rst-versions .rst-current-version:before,.rst-versions .rst-current-version:after{display:table;content:""}.rst-versions .rst-current-version:after{clear:both}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book{float:left}.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#E74C3C;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#F1C40F;color:#000}.rst-versions.shift-up{height:auto;max-height:100%}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:gray;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:solid 1px #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px}.rst-versions.rst-badge .icon-book{float:none}.rst-versions.rst-badge .fa-book{float:none}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book{float:left}.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge .rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width: 768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} 2 | -------------------------------------------------------------------------------- /docs/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '', 4 | LANGUAGE: 'None', 5 | COLLAPSE_INDEX: false, 6 | FILE_SUFFIX: '.html', 7 | HAS_SOURCE: true, 8 | SOURCELINK_SUFFIX: '.txt' 9 | }; -------------------------------------------------------------------------------- /docs/html/_static/down-pressed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/down-pressed.png -------------------------------------------------------------------------------- /docs/html/_static/down.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/down.png -------------------------------------------------------------------------------- /docs/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/file.png -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-bold.eot 
-------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-bold.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-bold.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bolditalic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-bolditalic.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bolditalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-bolditalic.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bolditalic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-bolditalic.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-bolditalic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-bolditalic.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-italic.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-italic.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-italic.woff -------------------------------------------------------------------------------- 
/docs/html/_static/fonts/Lato/lato-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-italic.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-regular.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-regular.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-regular.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/Lato/lato-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/Lato/lato-regular.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot 
-------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 -------------------------------------------------------------------------------- /docs/html/_static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/html/_static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/html/_static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/html/_static/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdevon/cortex/2837b220f9fb73279df3815bb18b274106412c08/docs/html/_static/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/html/_static/js/theme.js: -------------------------------------------------------------------------------- 1 | /* sphinx_rtd_theme version 0.4.0 | MIT license */ 2 | /* Built 20180606 11:06 */ 3 | require=function n(e,i,t){function o(s,a){if(!i[s]){if(!e[s]){var l="function"==typeof require&&require;if(!a&&l)return l(s,!0);if(r)return r(s,!0);var c=new Error("Cannot find module '"+s+"'");throw c.code="MODULE_NOT_FOUND",c}var u=i[s]={exports:{}};e[s][0].call(u.exports,function(n){var i=e[s][1][n];return o(i||n)},u,u.exports,n,e,i,t)}return i[s].exports}for(var r="function"==typeof require&&require,s=0;s"),n("table.docutils.footnote").wrap("
"),n("table.docutils.citation").wrap("
"),n(".wy-menu-vertical ul").not(".simple").siblings("a").each(function(){var i=n(this);expand=n(''),expand.on("click",function(n){return e.toggleCurrent(i),n.stopPropagation(),!1}),i.prepend(expand)})},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),i=e.find('[href="'+n+'"]');if(0===i.length){var t=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(i=e.find('[href="#'+t.attr("id")+'"]')).length&&(i=e.find('[href="#"]'))}i.length>0&&($(".wy-menu-vertical .current").removeClass("current"),i.addClass("current"),i.closest("li.toctree-l1").addClass("current"),i.closest("li.toctree-l1").parent().addClass("current"),i.closest("li.toctree-l1").addClass("current"),i.closest("li.toctree-l2").addClass("current"),i.closest("li.toctree-l3").addClass("current"),i.closest("li.toctree-l4").addClass("current"))}catch(o){console.log("Error expanding nav for anchor",o)}},onScroll:function(){this.winScroll=!1;var n=this.win.scrollTop(),e=n+this.winHeight,i=this.navBar.scrollTop()+(n-this.winPosition);n<0||e>this.docHeight||(this.navBar.scrollTop(i),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",function(){this.linkScroll=!1})},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current"),e.siblings().find("li.current").removeClass("current"),e.find("> ul li.current").removeClass("current"),e.toggleClass("current")}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:e.exports.ThemeNav,StickyNav:e.exports.ThemeNav}),function(){for(var n=0,e=["ms","moz","webkit","o"],i=0;i 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Welcome to Cortex2.0 — Cortex2.0 documentation 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 |
[generated Sphinx HTML (Read the Docs theme): navigation sidebar, breadcrumbs, search box, and footer markup omitted. The page body renders the "Welcome to Cortex2.0" heading and the User Documentation toctree. © Copyright 2018, MILA. Built with Sphinx using a theme provided by Read the Docs.]
-------------------------------------------------------------------------------- /docs/html/search.html: -------------------------------------------------------------------------------- [generated Sphinx HTML: "Search — Cortex2.0 documentation"; theme navigation, search scaffolding, and footer markup omitted. © Copyright 2018, MILA. Built with Sphinx using a theme provided by Read the Docs.]
-------------------------------------------------------------------------------- /docs/index.html: --------------------------------------------------------------------------------
1 |
-------------------------------------------------------------------------------- /docs/source/build.rst: --------------------------------------------------------------------------------
1 | Custom demos
2 | ~~~~~~~~~~~~
3 |
4 | While cortex has built-in functionality, it is meant to be
5 | used with your own modules. An example of making a model that works with
6 | cortex can be found at:
7 | https://github.com/rdevon/cortex/blob/master/demos/demo_classifier.py
8 | and https://github.com/rdevon/cortex/blob/master/demos/demo_custom_ae.py
9 |
10 | Documentation on the API can be found here:
11 | https://github.com/rdevon/cortex/blob/master/cortex/plugins.py
12 |
13 | For instance, the demo autoencoder can be used as:
14 |
15 | ::
16 |
17 |     python cortex/demos/demo_custom_ae.py --help
18 |
19 | A walkthrough of a custom classifier:
20 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
21 |
22 | Let’s look a little more closely at the autoencoder demo above to see
23 | what’s going on. cortex relies on using and overriding methods of
24 | plugin classes.
25 |
26 | First, let’s look at the methods ``build``, ``routine``, and
27 | ``visualize``. These are special methods of the plugin that can be
28 | overridden to change the behavior of your model for your needs.
29 |
30 | The signatures of these functions look like:
31 |
32 | ::
33 |
34 |     def build(self, dim_z=64, dim_encoder_out=64):
35 |         ...
36 |
37 |     def routine(self, inputs, targets, ae_criterion=F.mse_loss):
38 |         ...
39 |
40 |     def visualize(self, inputs, targets):
41 |         ...
42 |
43 | Each of these functions has arguments and keyword arguments. Note that
44 | the keyword arguments showed up in the help in the above example. This
45 | is part of the functionality of cortex: it manages the hyperparameters
46 | of these functions, organizes them, and provides command line control
47 | automatically. Even the docstrings are used in the command line, so
48 | other users can get the usage docs directly from there.
49 |
50 | The arguments are *data*, which are to be manipulated as needed in those
51 | methods. These are for the most part handled automatically, but all of
52 | these methods can be used as normal functions as well.
53 |
54 | Building models
55 | ^^^^^^^^^^^^^^^
56 |
57 | The ``build`` function takes the hyperparameters and sets networks.
58 |
59 | ::
60 |
61 |
62 |     class Autoencoder(nn.Module):
63 |         def __init__(self, encoder, decoder):
64 |             super(Autoencoder, self).__init__()
65 |             self.encoder = encoder
66 |             self.decoder = decoder
67 |
68 |         def forward(self, x, nonlinearity=None):
69 |             encoded = self.encoder(x)
70 |             decoded = self.decoder(encoded)
71 |             return decoded
72 |
73 |     ...
74 |
75 |     def build(self, dim_z=64, dim_encoder_out=64):
76 |         encoder = nn.Sequential(
77 |             nn.Linear(28, 256),
78 |             nn.ReLU(True),
79 |             nn.Linear(256, 28),
80 |             nn.ReLU(True))
81 |         decoder = nn.Sequential(
82 |             nn.Linear(28, 256),
83 |             nn.ReLU(True),
84 |             nn.Linear(256, 28),
85 |             nn.Sigmoid())
86 |         self.nets.ae = Autoencoder(encoder, decoder)
87 |
88 | All that’s being done here is that the hyperparameters are used to
89 | create an instance of an ``nn.Module`` subclass, which is added to
90 | the set of “nets”. Note that the keyword ``ae`` is very important, as
91 | this is going to be how you retrieve your nets and define their losses
92 | farther down.
93 |
94 | Also note that cortex currently *only* supports ``nn.Module`` subclasses
95 | from Pytorch.
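Since ``build`` registers networks by keyword, the same keyword is how they are retrieved later. A minimal sketch of calling the plugin directly (this mirrors how ``tests/built_ins/models/test_build.py`` below exercises ``build`` outside of the command-line flow, and assumes the default hyperparameters are sufficient):

::

    model = AE()
    model.build()          # registers the network as self.nets.ae
    net = model.nets.ae    # retrieved by the same keyword
    # the keyword also names the loss entry, e.g. 'losses.ae' in save_on_lowest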
96 |
97 | Defining losses and results
98 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~
99 |
100 | Adding losses and results from your model is easy: just compute your
101 | graph given your models and data, then add the losses and results by
102 | setting those members:
103 |
104 | ::
105 |
106 |     def routine(self, inputs, targets, ae_criterion=F.mse_loss):
107 |         encoded = self.nets.ae.encoder(inputs)
108 |         outputs = self.nets.ae.decoder(encoded)
109 |         r_loss = ae_criterion(
110 |             outputs, inputs, size_average=False) / inputs.size(0)
111 |         self.losses.ae = r_loss
112 |
113 | Additional results can be added similarly. For instance, in the demo
114 | classifier:
115 |
116 | ::
117 |
118 |     def routine(self, inputs, targets, criterion=nn.CrossEntropyLoss()):
119 |         ...
120 |         classifier = self.nets.classifier
121 |
122 |         outputs = classifier(inputs)
123 |         predicted = torch.max(F.log_softmax(outputs, dim=1).data, 1)[1]
124 |
125 |         loss = criterion(outputs, targets)
126 |         correct = 100. * predicted.eq(
127 |             targets.data).cpu().sum() / targets.size(0)
128 |
129 |         self.losses.classifier = loss
130 |         self.results.accuracy = correct
131 |
132 | Visualization
133 | ~~~~~~~~~~~~~
134 |
135 | Cortex allows for visualization using visdom, and this can be defined in
136 | a similar way as above:
137 |
138 | ::
139 |
140 |     def visualize(self, images, inputs, targets):
141 |         predicted = self.predict(inputs)
142 |         self.add_image(images.data, labels=(targets.data, predicted.data),
143 |                        name='gt_pred')
144 |
145 | See the ModelPlugin API for more details.
146 |
147 | Putting it together
148 | ~~~~~~~~~~~~~~~~~~~
149 |
150 | Finally, we can specify default arguments:
151 |
152 | ::
153 |
154 |     defaults = dict(
155 |         data=dict(
156 |             batch_size=dict(train=64, test=64), inputs=dict(inputs='images')),
157 |         optimizer=dict(optimizer='Adam', learning_rate=1e-4),
158 |         train=dict(save_on_lowest='losses.ae'))
159 |
160 | and then add ``cortex.main.run`` to ``__main__``:
161 |
162 | ::
163 |
164 |     if __name__ == '__main__':
165 |         autoencoder = AE()
166 |         run(model=autoencoder)
167 |
168 | And that’s it. cortex also allows for lower-level functions to be
169 | overridden (e.g., ``train_step``, ``eval_step``, ``train_loop``) with more
170 | customizability coming soon. For more examples of usage, see the
171 | built-in models:
172 | https://github.com/rdevon/cortex/tree/master/cortex/built_ins/models
-------------------------------------------------------------------------------- /docs/source/conf.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # This file does only contain a selection of the most common options. For a
6 | # full list see the documentation:
7 | # http://www.sphinx-doc.org/en/master/config
8 |
9 | # -- Path setup --------------------------------------------------------------
10 |
11 | # If extensions (or modules to document with autodoc) are in another directory,
12 | # add these directories to sys.path here. If the directory is relative to the
13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", "cortex"))) 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'Cortex2.0' 22 | copyright = '2018, MILA' 23 | author = 'Devon Hjelm' 24 | 25 | # The short X.Y version 26 | version = '' 27 | # The full version, including alpha/beta/rc tags 28 | release = '' 29 | 30 | 31 | # -- General configuration --------------------------------------------------- 32 | 33 | # If your documentation needs a minimal Sphinx version, state it here. 34 | # 35 | # needs_sphinx = '1.0' 36 | 37 | # Add any Sphinx extension module names here, as strings. They can be 38 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 39 | # ones. 40 | extensions = [ 41 | 'sphinx.ext.autodoc', 42 | 'sphinx.ext.doctest', 43 | 'sphinx.ext.intersphinx', 44 | 'sphinx.ext.todo', 45 | 'sphinx.ext.coverage', 46 | 'sphinx.ext.mathjax', 47 | 'sphinx.ext.ifconfig', 48 | 'sphinx.ext.viewcode', 49 | 'sphinx.ext.githubpages', 50 | 'sphinx.ext.napoleon' 51 | ] 52 | 53 | napoleon_google_docstring = True 54 | 55 | # Add any paths that contain templates here, relative to this directory. 56 | templates_path = ['templates'] 57 | 58 | # The suffix(es) of source filenames. 59 | # You can specify multiple suffix as a list of string: 60 | # 61 | # source_suffix = ['.rst', '.md'] 62 | source_suffix = '.rst' 63 | 64 | # The master toctree document. 65 | master_doc = 'index' 66 | 67 | # The language for content autogenerated by Sphinx. Refer to documentation 68 | # for a list of supported languages. 69 | # 70 | # This is also used if you do content translation via gettext catalogs. 71 | # Usually you set "language" from the command line for these cases. 72 | language = None 73 | 74 | # List of patterns, relative to source directory, that match files and 75 | # directories to ignore when looking for source files. 76 | # This pattern also affects html_static_path and html_extra_path . 77 | exclude_patterns = [] 78 | 79 | # The name of the Pygments (syntax highlighting) style to use. 80 | pygments_style = 'sphinx' 81 | 82 | 83 | # -- Options for HTML output ------------------------------------------------- 84 | 85 | # The theme to use for HTML and HTML Help pages. See the documentation for 86 | # a list of builtin themes. 87 | # 88 | html_theme = 'sphinx_rtd_theme' 89 | 90 | # Theme options are theme-specific and customize the look and feel of a theme 91 | # further. For a list of options available for each theme, see the 92 | # documentation. 93 | # 94 | # html_theme_options = {} 95 | 96 | # Add any paths that contain custom static files (such as style sheets) here, 97 | # relative to this directory. They are copied after the builtin static files, 98 | # so a file named "default.css" will overwrite the builtin "default.css". 99 | html_static_path = ['static'] 100 | 101 | # Custom sidebar templates, must be a dictionary that maps document names 102 | # to template names. 103 | # 104 | # The default sidebars (for documents that don't match any pattern) are 105 | # defined by theme itself. Builtin themes are using these templates by 106 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 107 | # 'searchbox.html']``. 108 | # 109 | # html_sidebars = {} 110 | 111 | 112 | # -- Options for HTMLHelp output --------------------------------------------- 113 | 114 | # Output file base name for HTML help builder. 
115 | htmlhelp_basename = 'Cortexdoc' 116 | 117 | 118 | # -- Options for LaTeX output ------------------------------------------------ 119 | 120 | latex_elements = { 121 | # The paper size ('letterpaper' or 'a4paper'). 122 | # 123 | # 'papersize': 'letterpaper', 124 | 125 | # The font size ('10pt', '11pt' or '12pt'). 126 | # 127 | # 'pointsize': '10pt', 128 | 129 | # Additional stuff for the LaTeX preamble. 130 | # 131 | # 'preamble': '', 132 | 133 | # Latex figure (float) alignment 134 | # 135 | # 'figure_align': 'htbp', 136 | } 137 | 138 | # Grouping the document tree into LaTeX files. List of tuples 139 | # (source start file, target name, title, 140 | # author, documentclass [howto, manual, or own class]). 141 | latex_documents = [ 142 | (master_doc, 'Cortex.tex', 'Cortex Documentation', 143 | 'Devon Hjelm', 'manual'), 144 | ] 145 | 146 | 147 | # -- Options for manual page output ------------------------------------------ 148 | 149 | # One entry per manual page. List of tuples 150 | # (source start file, name, description, authors, manual section). 151 | man_pages = [ 152 | (master_doc, 'cortex', 'Cortex Documentation', 153 | [author], 1) 154 | ] 155 | 156 | 157 | # -- Options for Texinfo output ---------------------------------------------- 158 | 159 | # Grouping the document tree into Texinfo files. List of tuples 160 | # (source start file, target name, title, author, 161 | # dir menu entry, description, category) 162 | texinfo_documents = [ 163 | (master_doc, 'Cortex', 'Cortex Documentation', 164 | author, 'Cortex', 'One line description of project.', 165 | 'Miscellaneous'), 166 | ] 167 | 168 | 169 | # -- Options for Epub output ------------------------------------------------- 170 | 171 | # Bibliographic Dublin Core info. 172 | epub_title = project 173 | epub_author = author 174 | epub_publisher = author 175 | epub_copyright = copyright 176 | 177 | # The unique identifier of the text. This can be a ISBN number 178 | # or the project homepage. 179 | # 180 | # epub_identifier = '' 181 | 182 | # A unique identification for the text. 183 | # 184 | # epub_uid = '' 185 | 186 | # A list of files that should not be packed into the epub file. 187 | epub_exclude_files = ['search.html'] 188 | 189 | 190 | # -- Extension configuration ------------------------------------------------- 191 | 192 | # -- Options for intersphinx extension --------------------------------------- 193 | 194 | # Example configuration for intersphinx: refer to the Python standard library. 195 | intersphinx_mapping = {'https://docs.python.org/': None} 196 | 197 | # -- Options for todo extension ---------------------------------------------- 198 | 199 | # If true, `todo` and `todoList` produce output, else they produce nothing. 200 | todo_include_todos = True 201 | 202 | -------------------------------------------------------------------------------- /docs/source/cortex.rst: -------------------------------------------------------------------------------- 1 | cortex package 2 | ============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | cortex.built_ins 10 | 11 | Submodules 12 | ---------- 13 | 14 | cortex.main module 15 | ------------------ 16 | 17 | .. automodule:: cortex.main 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | cortex.plugins module 23 | --------------------- 24 | 25 | .. automodule:: cortex.plugins 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. 
automodule:: cortex
35 |     :members:
36 |     :undoc-members:
37 |     :show-inheritance:
38 |
-------------------------------------------------------------------------------- /docs/source/develop.rst: --------------------------------------------------------------------------------
1 | Develop
2 | ===============
3 |
4 | Documentation
5 | ~~~~~~~~~~~~~
6 | Make sure that the cortex package is installed and configured. This section applies when you are
7 | making changes to the documentation, for example modifications inside docstrings or changes in some .rst
8 | files.
9 |
10 | Building Documentation
11 | ^^^^^^^^^^^^^^^^^^^^^^
12 | To build the documentation, the docs.py script under the root of the project facilitates the process.
13 | Before making a Pull Request to the remote repository, you should run the script.
14 | ::
15 |
16 |    $ python docs.py
17 |
18 | Serving Documentation Locally
19 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
20 | If you want to have a look at your changes before making a Pull Request on GitHub, it is possible to
21 | serve the generated html files locally.
22 | ::
23 |
24 |    $ cd docs/build/html
25 |    $ python -m http.server 8000 --bind 127.0.0.1
26 |
27 |
28 |
-------------------------------------------------------------------------------- /docs/source/getting_started.rst: --------------------------------------------------------------------------------
1 | Getting Started
2 | ===============
3 |
4 | Configuration
5 | ~~~~~~~~~~~~~
6 |
7 | The first thing to do is to set up the config.yaml. This file is
8 | user-specific (it got tracked at some point, so I need to fix this), and
9 | will tell cortex everything user-specific regarding data locations,
10 | visualization, and outputs.
11 |
12 | ::
13 |
14 |     $ rm -rf ~/.cortex.yml
15 |     $ cortex setup
16 |
17 | Configuration File Example
18 | ''''''''''''''''''''''''''
19 |
20 | Located at ``~/.cortex.yml``
21 |
22 | .. code:: yaml
23 |
24 |     torchvision_data_path: /data/milatmp1/hjelmdev/data/
25 |     data_paths: {
26 |         Imagenet-12: /data/lisa/data/ImageNet2012_jpeg, CelebA: /tmp/hjelmdev/CelebA}
27 |     viz: {font: /usr/share/fonts/truetype/liberation/LiberationSerif-Regular.ttf, server: 'http://132.204.26.180'}
28 |     out_path: /data/milatmp1/hjelmdev/outs/
29 |
30 | These are as follows:
31 |
32 | - torchvision\_data\_path: the path to all torchvision-specific
33 |   datasets (details can be found in torchvision.datasets)
34 | - data\_paths: user-specified custom datasets (see the note after this
35 |   list). Currently, only image folders (a la imagenet) are supported, but
36 |   other dataset types (e.g., text) are planned in the near future.
37 | - viz: visdom-specific arguments.
38 | - out\_path: output path for experiment outputs
39 |
40 |
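Presumably (inferring from the usage examples below), the keys under ``data_paths`` become the valid values of the ``--d.source`` command-line option, so the example configuration above would allow an invocation such as:

::

    cortex GAN --d.source CelebA --d.copy_to_local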
41 | Usage
42 | '''''
43 |
44 |     cortex --help
45 |
46 | Built-ins
47 | '''''''''
48 |
49 | :setup:
50 |     Setup cortex configuration.
51 |
52 | :GAN: Generative adversarial network.
53 | :VAE: Variational autoencoder.
54 | :AdversarialAutoencoder: Adversarial Autoencoder.
55 | :ALI: Adversarially learned inference.
56 | :ImageClassification: Basic image classifier.
57 | :GAN_MINE: GAN + MINE.
58 |
59 | Options
60 | '''''''
61 | -h, --help show this help message and exit
62 | -o OUT_PATH, --out_path OUT_PATH Output path directory. All model results will go
63 |    here. If a new directory, a new one will be
64 |    created, as long as the parent exists.
65 | -n NAME, --name NAME Name of the experiment. If given, the base name of the
66 |    output directory will be `--name`. If not given, the
67 |    name will be the base name of the `--out_path`.
68 | -r RELOAD, --reload RELOAD Path to model to reload.
69 |
70 | -M LOAD_MODELS, --load_models LOAD_MODELS Path to model to reload. Does not load args, info,
71 |    etc.
72 |
73 | -m META, --meta META TODO
74 |
75 | -c CONFIG_FILE, --config_file CONFIG_FILE Configuration yaml file. See `exps/` for examples
76 |
77 | -k, --clean Cleans the output directory. This cannot be undone!
78 |
79 | -v VERBOSITY, --verbosity VERBOSITY Verbosity of the logging. (0, 1, 2)
80 |
81 | -d DEVICE, --device DEVICE TODO
82 |
83 |
84 | Usage Example
85 | '''''''''''''
86 |
87 | To run an experiment:
88 |
89 | ::
90 |
91 |     cortex GAN --d.source CIFAR10 --d.copy_to_local
92 |
93 | Custom models
94 | '''''''''''''
95 |
96 | It is possible to run experiments with custom models made with Pytorch under the Cortex framework. To do so, the model has to
97 | be added to the demos folder under the root of the project. You can have a look at the demo autoencoder and classifier already
98 | implemented. The main difference is that, rather than registering the plugins, the run function of main.py has to be called. For example,
99 |
100 | ::
101 |
102 |     if __name__ == '__main__':
103 |         classifier = MyClassifier()
104 |         run(model=classifier)
105 |
106 | To run an experiment with a custom model:
107 |
108 | ::
109 |
110 |     python my_model.py --d.source --d.copy_to_local
111 |
-------------------------------------------------------------------------------- /docs/source/index.rst: --------------------------------------------------------------------------------
1 | Welcome to Cortex2.0
2 | ========================
3 |
4 | .. _user-docs:
5 |
6 | .. toctree::
7 |     :maxdepth: 1
8 |     :caption: User Documentation
9 |
10 |     install
11 |     getting_started
12 |     modules
13 |     develop
14 |     developer_start_guide
15 |     build
16 |
17 |
18 |
19 |
-------------------------------------------------------------------------------- /docs/source/install.rst: --------------------------------------------------------------------------------
1 | Installation
2 | ------------
3 |
4 | Prerequisites
5 | ~~~~~~~~~~~~~
6 |
7 | Visdom
8 | ''''''
9 |
10 | ::
11 |
12 |     $ pip install visdom
13 |
14 | From Source
15 | ~~~~~~~~~~~
16 |
17 | ::
18 |
19 |     $ git clone https://github.com/rdevon/cortex2.0.git
20 |     $ cd cortex2.0
21 |     $ pip install .
22 |
-------------------------------------------------------------------------------- /docs/source/modules.rst: --------------------------------------------------------------------------------
1 | cortex
2 | ======
3 |
4 | .. toctree::
5 |    :maxdepth: 4
6 |
7 |    cortex
8 |
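The make.bat that follows is the standard Windows counterpart of the Sphinx Makefile: it forwards its first argument to ``sphinx-build -M``. A typical invocation (assuming ``sphinx-build`` is on ``PATH``):

::

    make.bat html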
-------------------------------------------------------------------------------- /make.bat: --------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | 	set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 | set SPHINXPROJ=Cortex
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | 	echo.
19 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | 	echo.installed, then set the SPHINXBUILD environment variable to point
21 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | 	echo.may add the Sphinx directory to PATH.
23 | 	echo.
24 | 	echo.If you don't have Sphinx installed, grab it from
25 | 	echo.http://sphinx-doc.org/
26 | 	exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
-------------------------------------------------------------------------------- /scripts/nearest-neighbor.py: --------------------------------------------------------------------------------
1 | from sklearn.neighbors import NearestNeighbors
2 | from PIL import Image
3 | import numpy as np
4 | import os
5 | import sys
6 | from operator import mul
7 | import functools
8 |
9 |
10 | def get_neighbors(samples, dataset, n_neighbors):
11 |     '''Returns, for each sample, its n_neighbors nearest images in dataset (brute-force L2).'''
12 |     size = functools.reduce(mul, dataset[0].shape, 1)
13 |     nbrs = NearestNeighbors(
14 |         n_neighbors=n_neighbors, metric='l2', algorithm='brute').fit(
15 |             dataset.reshape(-1, size))
16 |     _, samples_idxs = nbrs.kneighbors(samples.reshape(-1, size))
17 |     return np.array([[dataset[idx] for idx in idxs] for idxs in samples_idxs])
18 |
19 |
20 | if __name__ == '__main__':
21 |     import torchvision
22 |     import torch
23 |     root = sys.argv[1]  # fixed: was os.argv[1]; the os module has no argv attribute
24 |     images = np.concatenate([
25 |         np.array([np.array(Image.open(os.path.join(
26 |             root, 'n01443537', 'images', file)))
27 |             for file in os.listdir(os.path.join(
28 |                 root, 'n01443537', 'images'))]) / 255.,
29 |         np.array([np.array(Image.open(os.path.join(
30 |             root, 'n09193705', 'images', file)))
31 |             for file in os.listdir(os.path.join(
32 |                 root, 'n09193705', 'images'))]) / 255.,
33 |         np.array([np.array(Image.open(os.path.join(
34 |             root, 'n01742172', 'images', file)))
35 |             for file in os.listdir(os.path.join(
36 |                 root, 'n01742172', 'images'))]) / 255.,
37 |         np.array([np.array(Image.open(os.path.join(
38 |             root, 'n02058221', 'images', file)))
39 |             for file in os.listdir(os.path.join(
40 |                 root, 'n02058221', 'images'))]) / 255.,
41 |         np.array([np.array(Image.open(os.path.join(
42 |             root, 'n02094433', 'images', file)))
43 |             for file in os.listdir(os.path.join(
44 |                 root, 'n02094433', 'images'))]) / 255.],
45 |         axis=0)
46 |     nbhrs = np.concatenate((get_neighbors(images[[0, -1]], images, 5)), axis=0)
47 |     torchvision.utils.save_image(torch.from_numpy(nbhrs.transpose(0, 3, 1, 2)),
48 |                                  'test.png', nrow=5)
49 |
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | packages = [
4 |     'cortex', 'cortex._lib', 'cortex.built_ins', 'cortex._lib.data',
5 |     'cortex.built_ins.datasets', 'cortex.built_ins.models',
6 |     'cortex.built_ins.networks', 'cortex.built_ins.transforms']
7 |
8 |
9 | install_requirements = [  # NOTE: currently unused; setup() below hard-codes install_requires
10 |     'imageio', 'matplotlib', 'progressbar2', 'scipy', 'sklearn', 'visdom',
11 |     'pyyaml', 'pathlib', 'sphinxcontrib-napoleon', 'nibabel', 'torch', 'torchvision'
12 | ]
13 |
14 | setup(name='cortex',
15 |       version='0.11',
16 |       description='A library for wrapping your pytorch code',
17 |       author='R Devon Hjelm',
18 |       author_email='erroneus@gmail.com',
19 |       packages=packages,
20 |       install_requires=[
21 |           'imageio', 'matplotlib', 'progressbar2', 'scipy', 'sklearn',
22 |           'torchvision', 'visdom', 'pyyaml', 'sphinxcontrib-napoleon'],
23 |       entry_points={
24 |           'console_scripts': [
25 |               'cortex=cortex.main:run']
26 |       },
27 |       dependency_links={'git+https://github.com/facebookresearch/visdom.git'},
28 |       zip_safe=False)
29 |
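The ``console_scripts`` entry point in setup.py above is what makes the bare ``cortex`` command available after installation; it dispatches to ``cortex.main.run``, the same function the demos call with a custom model. For example:

::

    $ pip install .
    $ cortex setup
    $ cortex --help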
-------------------------------------------------------------------------------- /tests/built_ins/models/test_build.py: --------------------------------------------------------------------------------
1 | '''Testing for building models.
2 |
3 | '''
4 |
5 | from cortex._lib.models import MODEL_PLUGINS
6 | from cortex.built_ins.networks.fully_connected import FullyConnectedNet
7 | from cortex.plugins import ModelPlugin, register_model
8 |
9 |
10 | def test_class(model_class, arguments):
11 |     '''Tests simple class attributes.
12 |
13 |     Args:
14 |         model_class: ModelPlugin subclass.
15 |         arguments: Arguments for the class.
16 |
17 |     '''
18 |     arg1 = arguments['arg1']
19 |     arg2 = arguments['arg2']
20 |     arg1_help = arguments['arg1_help']
21 |     arg2_help = arguments['arg2_help']
22 |
23 |     assert model_class._help[arg1] == arg1_help, model_class.help[arg1]
24 |     assert model_class._help[arg2] == arg2_help, model_class.help[arg2]
25 |     assert model_class._kwargs[arg1] == 17, model_class.kwargs[arg1]
26 |     assert model_class._kwargs[arg2] == 19, model_class.kwargs[arg2]
27 |
28 |
29 | def test_register(model_class):
30 |     '''Tests registration of a model.
31 |
32 |     Args:
33 |         model_class: ModelPlugin subclass.
34 |
35 |     '''
36 |
37 |     MODEL_PLUGINS.clear()
38 |     register_model(model_class)
39 |     assert isinstance(list(MODEL_PLUGINS.values())[0], model_class)
40 |     MODEL_PLUGINS.clear()
41 |
42 |
43 | def test_build(model_class, arguments):
44 |     '''Tests building the model.
45 |
46 |     Args:
47 |         model_class: ModelPlugin subclass.
48 |         arguments: Arguments for the class.
49 |
50 |     '''
51 |     ModelPlugin._all_nets.clear()
52 |     kwargs = {arguments['arg1']: 11, arguments['arg2']: 13}
53 |
54 |     model = model_class()
55 |     model.kwargs.update(**kwargs)
56 |
57 |     model.build()
58 |
59 |     print('Model networks:', model.nets)
60 |     assert isinstance(model.nets.net, FullyConnectedNet)
61 |
62 |     parameters = list(model.nets.net.parameters())
63 |
64 |     print('Parameter sizes:', [p.size() for p in parameters])
65 |     assert parameters[0].size(1) == 11
66 |     assert parameters[2].size(0) == parameters[3].size(0) == 13
67 |
68 |
69 | def test_subplugin(model_class_with_submodel):
70 |     '''Tests a model with a model inside.
71 |
72 |     Args:
73 |         model_class_with_submodel: ModelPlugin subclass.
74 |
75 |     '''
76 |
77 |     contract = dict(
78 |         kwargs=dict(b='c'),
79 |         nets=dict(net='net')
80 |     )
81 |
82 |     model = model_class_with_submodel(sub_contract=contract)
83 |
84 |     try:
85 |         model.build()
86 |         assert 0
87 |     except KeyError as e:
88 |         print('build failed ({}). This is expected.'.format(e))
89 |
90 |     ModelPlugin._all_nets.clear()
91 |
92 |     sub_contract = dict(
93 |         kwargs=dict(a='c'),
94 |         nets=dict(net='net2')
95 |     )
96 |     model = model_class_with_submodel(sub_contract=sub_contract)
97 |
98 |     model.build()
99 |
-------------------------------------------------------------------------------- /tests/built_ins/models/test_loop.py: --------------------------------------------------------------------------------
1 | '''Tests the training loop.
2 | 3 | ''' 4 | 5 | from cortex._lib import optimizer 6 | 7 | 8 | def test_loop(model_with_submodel): 9 | model = model_with_submodel 10 | model.build() 11 | optimizer.setup(model) 12 | 13 | model.train_loop(0) 14 | 15 | results = model._all_epoch_results 16 | 17 | print(results) 18 | 19 | rlen = len(results['TestModel2_output']) 20 | 21 | assert len(results['TestModel2_output']) == \ 22 | len(results.losses['net']) == \ 23 | len(results.times['TestModel2']) 24 | 25 | model.train_loop(0) 26 | 27 | results = model._all_epoch_results 28 | 29 | print(results) 30 | 31 | assert len(results['TestModel2_output']) == \ 32 | len(results.losses['net']) == \ 33 | len(results.times['TestModel2']) 34 | 35 | assert rlen == len(results['TestModel2_output']) 36 | -------------------------------------------------------------------------------- /tests/built_ins/models/test_routine.py: -------------------------------------------------------------------------------- 1 | '''Module for testing model routines. 2 | 3 | ''' 4 | 5 | import torch.optim as optim 6 | 7 | from cortex.plugins import ModelPlugin 8 | 9 | 10 | def test_routine(model_class, arguments, data_class): 11 | ModelPlugin._reset_class() 12 | 13 | kwargs = {arguments['arg1']: 11, arguments['arg2']: 13} 14 | data = data_class(11) 15 | 16 | model = model_class(contract=dict(inputs=dict(A='test'))) 17 | model._data = data 18 | model.kwargs.update(**kwargs) 19 | 20 | model.build() 21 | 22 | model.eval_step() 23 | print('Training nets: ', model._training_nets) 24 | assert 'net' in list(model._training_nets.values())[0] 25 | 26 | params = list(model.nets.net.parameters()) 27 | op = optim.SGD(params, lr=0.0001) 28 | model._optimizers = dict(net=op) 29 | 30 | A = model.inputs('A') 31 | model.routine(A) 32 | 33 | model._reset_epoch() 34 | 35 | model.train_step() 36 | model.train_step() 37 | model.train_step() 38 | 39 | print('Results:', model._all_epoch_results) 40 | print('Losses:', model._all_epoch_losses) 41 | print('Times:', model._all_epoch_times) 42 | 43 | assert len(list(model._all_epoch_results.values())[0]) == 3 44 | assert len(list(model._all_epoch_losses.values())[0]) == 3 45 | assert len(list(model._all_epoch_times.values())[0]) == 3 46 | 47 | 48 | def test_routine_with_submodels(model_with_submodel): 49 | model = model_with_submodel 50 | model.build() 51 | 52 | params = list(model.nets.net.parameters()) 53 | op = optim.SGD(params, lr=0.0001) 54 | 55 | params2 = list(model.nets.net2.parameters()) 56 | op2 = optim.SGD(params2, lr=0.001) 57 | 58 | model._optimizers = dict(net=op, net2=op2) 59 | model.submodel._optimizers = dict(net=op, net2=op2) 60 | 61 | assert model._get_training_nets() == [] 62 | model.train_step() 63 | assert model._get_training_nets() == ['net', 'net2'] 64 | model.train_step() 65 | model.train_step() 66 | 67 | 68 | def test_routine_with_submodels_2(model_class_with_submodel_2, data_class): 69 | ModelPlugin._reset_class() 70 | 71 | kwargs = {'d': 11, 'c': 13} 72 | data = data_class(11) 73 | 74 | contract = dict(inputs=dict(B='test')) 75 | sub_contract = dict( 76 | kwargs=dict(a='d'), 77 | nets=dict(net='net2'), 78 | inputs=dict(A='test') 79 | ) 80 | 81 | sub_contract2 = dict( 82 | kwargs=dict(a='d'), 83 | nets=dict(net='net3'), 84 | inputs=dict(A='test') 85 | ) 86 | 87 | model = model_class_with_submodel_2(sub_contract1=sub_contract, 88 | sub_contract2=sub_contract2, 89 | contract=contract) 90 | model._data = data 91 | model.submodel1._data = data 92 | model.submodel2._data = data 93 | model.kwargs.update(**kwargs) 94 | 95 | 
model.build()
96 |
97 |     params = list(model.nets.net.parameters())
98 |     op = optim.SGD(params, lr=0.0001)
99 |
100 |     params2 = list(model.nets.net2.parameters())
101 |     op2 = optim.SGD(params2, lr=0.001)
102 |
103 |     params3 = list(model.nets.net3.parameters())
104 |     op3 = optim.SGD(params3, lr=0.001)
105 |
106 |     model._optimizers = dict(net=op, net2=op2, net3=op3)
107 |     model.submodel1._optimizers = dict(net=op, net2=op2, net3=op3)
108 |     model.submodel2._optimizers = dict(net=op, net2=op2, net3=op3)
109 |
110 |     model.train_step()
111 |     model.train_step()
112 |     model.train_step()
113 |
-------------------------------------------------------------------------------- /tests/built_ins/networks/test_base_network.py: --------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | def test_base_net(base_net_model):
5 |     """
6 |
7 |     Args:
8 |         base_net_model(@pytest.fixture): BaseNet
9 |
10 |     Asserts: True if BaseNet has an empty nn.Sequential models attribute
11 |     and a default nn.ReLU activation function.
12 |
13 |     """
14 |     assert isinstance(base_net_model.models, nn.Sequential)
15 |     assert base_net_model.output_nonlinearity is None
16 |     assert isinstance(base_net_model.layer_nonlinearity, nn.ReLU)
17 |
18 |
19 | def test_forward_base_net(base_net_model, simple_tensor):
20 |
21 |     """
22 |
23 |     Args:
24 |         base_net_model(@pytest.fixture): BaseNet
25 |         simple_tensor(@pytest.fixture): torch.Tensor
26 |
27 |     Asserts: True if the dimension of the output equals the dimension of
28 |     the input.
29 |
30 |     """
31 |     base_dimension = simple_tensor.dim()
32 |     output = base_net_model.forward(simple_tensor)
33 |     assert output.dim() == base_dimension
34 |
35 |
36 | def test_add_linear_layers(base_net_model):
37 |     """
38 |
39 |     Args:
40 |         base_net_model(@pytest.fixture): BaseNet
41 |
42 |     Asserts: True if, given no hidden layers, it returns the dimension
43 |     of the input (ImageClassification).
44 |
45 |     """
46 |     # Test settings based on ImageClassification.
47 |     dim_in = 4096
48 |     dim_h = []
49 |     dim_ex = None
50 |     Linear = None
51 |     layer_args = dict(batch_norm=True, dropout=0.2)
52 |     output = base_net_model.add_linear_layers(dim_in, dim_h, dim_ex, Linear, **layer_args)
53 |     assert output == dim_in
54 |
55 |
56 | def test_add_output_layer(base_net_model):
57 |     """
58 |
59 |     Args:
60 |         base_net_model(@pytest.fixture): BaseNet
61 |
62 |     Asserts: True if the model's models attribute contains an output layer
63 |     that is a Linear module.
64 |
65 |     """
66 |     dim_in = 4096
67 |     dim_out = 10
68 |
69 |     expected_name = 'linear_({}/{})_{}'.format(dim_in, dim_out, 'out')
70 |
71 |     base_net_model.add_output_layer(dim_in, dim_out)
72 |     layers = list(base_net_model.models._modules.items())
73 |
74 |     assert layers[0][0] == expected_name
75 |     assert isinstance(layers[0][1], nn.modules.linear.Linear)
76 |     assert layers[0][1].in_features == dim_in and layers[0][1].out_features == dim_out
77 |
-------------------------------------------------------------------------------- /tests/built_ins/networks/test_fully_connected.py: --------------------------------------------------------------------------------
1 | from cortex.built_ins.networks.fully_connected import FullyConnectedNet
2 | from torch import nn
3 |
4 |
5 | def test_fully_connected_build():
6 |     """
7 |
8 |     Asserts: True if the FullyConnectedNet has correct layers and
9 |     attributes.
10 | 11 | """ 12 | dim_in = 4096 13 | dim_out = 10 14 | dim_h = 64 15 | dim_ex = None 16 | nonlinearity = 'ReLU' 17 | n_levels = None 18 | output_nonlinearity = None 19 | layer_args = {} 20 | 21 | expected_name_linear = 'linear_({}/{})'.format(dim_in, dim_h) 22 | expected_name_relu = 'linear_({}/{})_{}'.format(dim_in, dim_h, 'ReLU') 23 | expected_name_out = 'linear_({}/{})_{}'.format(dim_h, dim_out, 'out') 24 | 25 | fully_connected_net = FullyConnectedNet(dim_in, dim_out, dim_h, dim_ex, 26 | nonlinearity, n_levels, 27 | output_nonlinearity, **layer_args) 28 | layers = list(fully_connected_net.models._modules.items()) 29 | 30 | assert layers[0][0] == expected_name_linear 31 | assert layers[1][0] == expected_name_relu 32 | assert layers[2][0] == expected_name_out 33 | assert isinstance(layers[0][1], nn.modules.linear.Linear) 34 | assert isinstance(layers[1][1], nn.modules.activation.ReLU) 35 | assert isinstance(layers[2][1], nn.modules.linear.Linear) 36 | assert layers[0][1].in_features == dim_in 37 | assert layers[0][1].out_features == dim_h 38 | assert layers[2][1].in_features == dim_h 39 | assert layers[2][1].out_features == dim_out 40 | -------------------------------------------------------------------------------- /tests/built_ins/networks/test_network_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from cortex.built_ins.networks.utils import get_nonlinearity, \ 3 | apply_nonlinearity, finish_layer_1d, finish_layer_2d 4 | from torch import nn 5 | import torch 6 | from cortex.built_ins.networks.modules import View 7 | 8 | 9 | def test_get_nonlinearity(nonlinearity): 10 | 11 | """ 12 | 13 | Args: 14 | nonlinearity(@pytest.fixture): dict 15 | 16 | Asserts: True if the right activation function instance is returned. 17 | 18 | """ 19 | 20 | relu = get_nonlinearity(nonlinearity['relu']) 21 | tanh = get_nonlinearity(nonlinearity['tanh']) 22 | leakyrelu = get_nonlinearity(nonlinearity['leakyrelu']) 23 | sigmoid = get_nonlinearity(nonlinearity['sigmoid']) 24 | 25 | assert callable(sigmoid) 26 | assert callable(tanh) 27 | assert isinstance(relu, nn.modules.activation.ReLU) 28 | assert isinstance(leakyrelu, nn.modules.activation.LeakyReLU) 29 | 30 | 31 | def test_apply_nonlinearity(simple_tensor): 32 | 33 | """ 34 | 35 | Args: 36 | simple_tensor(@pytest.fixture): torch.Tensor 37 | 38 | Asserts: True if the right PyTorch function is called. 39 | 40 | """ 41 | 42 | nonlinearity_args = {} 43 | nonlinear = 'tanh' 44 | 45 | expected_output = torch.tanh(simple_tensor) 46 | applied_nonlinearity = apply_nonlinearity(simple_tensor, nonlinear, 47 | **nonlinearity_args) 48 | 49 | assert torch.equal(expected_output, applied_nonlinearity) 50 | 51 | 52 | def test_finish_layer_1d(nonlinearity): 53 | 54 | """ 55 | 56 | Args: 57 | nonlinearity(@pytest.fixture): dict 58 | 59 | Asserts: True if the right GAN layers are added.
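Concretely (mirroring the assertions below), finish_layer_1d is expected to append a BatchNorm1d and the given nonlinearity after the named Linear layer, yielding Linear -> BatchNorm1d -> ReLU.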
60 | 61 | """ 62 | 63 | # Test settings for a GAN 64 | layer_norm = False 65 | batch_norm = True 66 | dropout = False 67 | name = 'linear_(64/4096)' 68 | dim_out = 4096 69 | 70 | model = nn.Sequential( 71 | OrderedDict([('linear_(64/4096)', 72 | nn.Linear(in_features=64, out_features=4096, 73 | bias=True))])) 74 | 75 | finish_layer_1d(model, name, dim_out, dropout, layer_norm, batch_norm, 76 | nonlinearity['relu']) 77 | 78 | assert isinstance(model[0], torch.nn.modules.linear.Linear) 79 | assert isinstance(model[1], torch.nn.modules.batchnorm.BatchNorm1d) 80 | assert isinstance(model[2], torch.nn.modules.activation.ReLU) 81 | assert model[0].in_features == 64 and model[0].out_features == 4096 82 | 83 | 84 | def test_finish_layer_2d(): 85 | 86 | """ 87 | 88 | Asserts: True if the right GAN layers are added. 89 | 90 | """ 91 | 92 | # Test settings for a GAN 93 | dim_x = 4 94 | dim_y = 4 95 | dim_out = 256 96 | dropout = False 97 | layer_norm = False 98 | batch_norm = True 99 | nonlinearity = nn.ReLU() 100 | name = 'reshape' 101 | batch_norm_1d_layer = nn.BatchNorm1d(4096) 102 | nonlinear_relu_layer = nn.ReLU() 103 | view = View() 104 | 105 | model = nn.Sequential( 106 | OrderedDict([('linear_(64/4096)', 107 | nn.Linear(in_features=64, out_features=4096, bias=True)), 108 | ('linear_(64/4096)_bn', batch_norm_1d_layer), 109 | ('linear_(64/4096)_ReLU', 110 | nonlinear_relu_layer), ('reshape', view)])) 111 | 112 | finish_layer_2d(model, name, dim_x, dim_y, dim_out, dropout, layer_norm, 113 | batch_norm, nonlinearity) 114 | 115 | assert isinstance(model[4], torch.nn.modules.batchnorm.BatchNorm2d) 116 | assert isinstance(model[5], torch.nn.modules.activation.ReLU) 117 | -------------------------------------------------------------------------------- /tests/test_argparsing.py: -------------------------------------------------------------------------------- 1 | from cortex._lib import (config, setup_experiment, exp) 2 | from cortex.built_ins.models.classifier import ImageClassification 3 | from cortex._lib.parsing import update_args 4 | 5 | 6 | def update_nested_dicts(from_d, to_d): 7 | """ 8 | Copied from _lib/__init__::setup_experiment 9 | 10 | """ 11 | for k, v in from_d.items(): 12 | if (k in to_d) and isinstance(to_d[k], dict): 13 | if not isinstance(v, dict): 14 | raise ValueError('Updating dict entry with non-dict.') 15 | update_nested_dicts(v, to_d[k]) 16 | else: 17 | to_d[k] = v 18 | 19 | 20 | def test_command_override_static(args): 21 | """ 22 | 23 | Args: 24 | args(@pytest.fixture): Namespace 25 | 26 | Asserts: True if, when a command-line arg is passed, exp.ARGS 27 | replaces the default value from the plugin with the 28 | command-line one. 29 | 30 | """ 31 | expected_type = 'resnet' 32 | args.__dict__['classifier_type'] = expected_type 33 | classifier_defaults = ImageClassification() 34 | config.set_config() 35 | # NOTE: exp.ARGS is being populated inside setup_experiment() call 36 | classifier_defaults = setup_experiment( 37 | args, model=classifier_defaults, testmode=True) 38 | assert exp.ARGS['model']['classifier_type'] == expected_type 39 | 40 | 41 | def test_static_override_parameters(args, classifier_modified): 42 | """ 43 | 44 | Args: 45 | args(@pytest.fixture): Namespace 46 | classifier_modified(@pytest.fixture): ClassifierModified 47 | 48 | Asserts: True if the plugin's defaults attribute overrides 49 | parameter values.
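(Presumably the ClassifierModified fixture's defaults declare classifier_type='convnet'; the assertion below expects that value to win over the parameter default.)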
50 | 51 | """ 52 | expected_type = 'convnet' 53 | config.set_config() 54 | classifier_modified = setup_experiment( 55 | args, model=classifier_modified, testmode=True) 56 | assert exp.ARGS['model']['classifier_type'] == expected_type 57 | 58 | 59 | def test_update_nested_dicts(args, classifier_modified): 60 | """ 61 | 62 | Args: 63 | args(@pytest.fixture): Namespace 64 | classifier_modified(@pytest.fixture): ClassifierModified 65 | 66 | Asserts: True if a dict argument is merged into a 67 | nested dict (updated, not overridden). 68 | 69 | """ 70 | expected_classifier_args_before_update = {'dropout': 0.2} 71 | expected_classifier_args_after_update = {'dropout': 0.2, 'dim_h': 100} 72 | args_for_update = { 73 | 'data': { 74 | 'batch_size': 122 75 | }, 76 | 'model': { 77 | 'classifier_args': { 78 | 'dropout': 0.2, 79 | 'dim_h': 100 80 | } 81 | } 82 | } 83 | config.set_config() 84 | classifier_modified = setup_experiment( 85 | args, model=classifier_modified, testmode=True) 86 | assert exp.ARGS['model'][ 87 | 'classifier_args'] == expected_classifier_args_before_update 88 | update_nested_dicts(args_for_update, exp.ARGS) 89 | assert exp.ARGS['model'][ 90 | 'classifier_args'] == expected_classifier_args_after_update 91 | assert exp.ARGS['data']['batch_size'] == 122 92 | 93 | 94 | def test_update_args(args, classifier_modified): 95 | """ 96 | 97 | Args: 98 | args(@pytest.fixture): Namespace 99 | classifier_modified(@pytest.fixture): ClassifierModified 100 | 101 | Asserts: True if exp.ARGS is updated correctly. 102 | 103 | """ 104 | expected_classifier_args_before_update = {'dropout': 0.2} 105 | expected_classifier_args_after_update = {'dropout': 0.1} 106 | args_for_update = { 107 | 'data': { 108 | 'batch_size': 128 109 | }, 110 | 'model': { 111 | 'classifier_args': { 112 | 'dropout': 0.1 113 | } 114 | } 115 | } 116 | config.set_config() 117 | classifier_modified = setup_experiment( 118 | args, model=classifier_modified, testmode=True) 119 | assert exp.ARGS['model'][ 120 | 'classifier_args'] == expected_classifier_args_before_update 121 | update_args(args_for_update, exp.ARGS) 122 | assert exp.ARGS['model'][ 123 | 'classifier_args'] == expected_classifier_args_after_update 124 | assert exp.ARGS['data']['batch_size'] == 128 125 | -------------------------------------------------------------------------------- /tests/test_handlers.py: -------------------------------------------------------------------------------- 1 | '''Module for testing handlers.
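A Handler is an attribute-accessible dict. A minimal usage sketch, based on the tests below:

    h = Handler()
    h.a = 1   # also readable as h['a']
    h.lock()  # new assignments now raise KeyError
    h2 = Handler(allow_overwrite=False)  # overwriting an entry raises KeyError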
2 | 3 | ''' 4 | 5 | import torch 6 | 7 | from cortex._lib.handlers import (AliasedHandler, PrefixedAliasedHandler, 8 | Handler, NetworkHandler) 9 | from cortex.built_ins.networks.fully_connected import FullyConnectedNet 10 | 11 | 12 | def test_basic_handler(): 13 | h = Handler() 14 | h.a = 1 15 | assert h['a'] == 1 16 | 17 | h.a = 2 18 | 19 | h.lock() 20 | assert h._locked 21 | 22 | try: 23 | h.b = 10 24 | assert 0 25 | except KeyError: 26 | pass 27 | 28 | h = Handler(allow_overwrite=False) 29 | h.a = 1 30 | 31 | try: 32 | h.a = 2 33 | assert 0 34 | except KeyError: 35 | pass 36 | 37 | for kv in h.items(): 38 | pass 39 | 40 | for k in h: 41 | pass 42 | 43 | 44 | def test_network_handler(): 45 | h = NetworkHandler() 46 | 47 | try: 48 | h.a = 1 49 | assert 0 50 | except TypeError: 51 | pass 52 | 53 | h.a = FullyConnectedNet(1, 2) 54 | 55 | assert isinstance(h.a, torch.nn.Module) 56 | 57 | h.b = FullyConnectedNet(2, 3) 58 | 59 | for k, m in h.items(): 60 | assert isinstance(m, torch.nn.Module), k 61 | 62 | 63 | def test_aliased_handler(): 64 | h = Handler() 65 | 66 | aliases = dict(a='A', b='B', c='C', d='D') 67 | 68 | ah = AliasedHandler(h, aliases=aliases) 69 | 70 | ah.a = 13 71 | ah.b = 12 72 | 73 | try: 74 | ah.C = 22 75 | assert 0, 'A name among the alias values cannot be set.' 76 | except KeyError: 77 | pass 78 | 79 | assert ah.a == 13 80 | assert h.A == 13 81 | assert ah.b == 12 82 | assert h.B == 12 83 | 84 | for k, v in ah.items(): 85 | if k == 'a': 86 | pass 87 | elif k == 'b': 88 | pass 89 | else: 90 | assert False, k 91 | 92 | for k in ah: 93 | if k == 'a': 94 | pass 95 | elif k == 'b': 96 | pass 97 | else: 98 | assert False, k 99 | 100 | ah.pop('a') 101 | try: 102 | h.A 103 | assert 0 104 | except AttributeError: 105 | pass 106 | 107 | 108 | def test_prefixed_handler(): 109 | 110 | h = Handler() 111 | ah = PrefixedAliasedHandler(h, prefix='test') 112 | 113 | ah.a = 13 114 | ah.b = 12 115 | 116 | assert ah.a == 13 117 | assert h.test_a == 13 118 | assert ah.b == 12 119 | assert h.test_b == 12 120 | 121 | ah.pop('a') 122 | 123 | try: 124 | h.test_a 125 | assert 0 126 | except AttributeError: 127 | pass 128 | -------------------------------------------------------------------------------- /tests/test_optimizer.py: -------------------------------------------------------------------------------- 1 | '''Tests the optimizer functionality.
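optimizer.setup builds one optimizer per trainable network, keyed by network name in optimizer.OPTIMIZERS. A minimal sketch, based on the tests below:

    optimizer.setup(model, learning_rate=1.0, optimizer='SGD')
    optimizer.OPTIMIZERS['net'].step()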
2 | 3 | ''' 4 | 5 | import copy 6 | 7 | import numpy as np 8 | 9 | from cortex._lib import optimizer 10 | 11 | 12 | def test_optimizer(model_with_submodel): 13 | model = model_with_submodel 14 | model.build() 15 | 16 | optimizer.setup(model) 17 | 18 | assert set(optimizer.OPTIMIZERS.keys()) == set(['net', 'net2']) 19 | 20 | optimizer.OPTIMIZERS['net'].step() 21 | 22 | 23 | def test_clipping(model_with_submodel, clip=0.0001): 24 | model = model_with_submodel 25 | model.build() 26 | 27 | optimizer.setup(model, clipping=clip) 28 | 29 | model.train_step() 30 | 31 | optimizer.OPTIMIZERS['net'].step() 32 | 33 | params = model.nets.net.parameters() 34 | 35 | for p in params: 36 | print(p.min(), p.max()) 37 | assert p.max() <= clip 38 | assert -p.min() <= clip 39 | 40 | 41 | def test_gradient(model_with_submodel): 42 | model = model_with_submodel 43 | model.build() 44 | 45 | optimizer.setup(model, learning_rate=1.0, optimizer='SGD') 46 | 47 | model.routine(auto_input=True) 48 | 49 | net_loss = model.losses['net'] 50 | parameters = copy.deepcopy(list(model.nets.net.parameters())) 51 | print(parameters[0][0]) 52 | print(net_loss) 53 | 54 | print('Stepping') 55 | net_loss.backward() 56 | grad = [p.grad for p in model.nets.net.parameters()] 57 | 58 | print('grad', grad[0][0]) 59 | model._optimizers['net'].step() 60 | 61 | grad = [p.grad for p in model.nets.net.parameters()] 62 | 63 | for p1, p2, g in zip(list(model.nets.net.parameters()), parameters, grad): 64 | print(p1[0], p2[0], g[0]) 65 | print('diff', p1[0] - p2[0]) 66 | p1 = p1.data.numpy() 67 | p2 = p2.data.numpy() 68 | g = g.data.numpy() 69 | assert np.allclose(p1, p2 - g) 70 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | count = True 3 | show-source = True 4 | doctests = True 5 | # select = E, F, W, C90, I, D, B, B902 6 | ignore = 7 | # blank-line after doc summaries (annoying for modules' doc) 8 | D205 9 | # conflicts with D211: No blank lines allowed before class docstring 10 | D203 11 | # do not enforce first-line-period at module docs 12 | D400 13 | # conflicts with E133: closing bracket is missing indentation 14 | E123 15 | exclude = 16 | .tox, 17 | .git, 18 | __pycache__, 19 | docs, 20 | build, 21 | dist, 22 | *.pyc, 23 | *.egg-info, 24 | .cache, 25 | .eggs, 26 | max-line-length = 100 27 | # McCabe complexity checker 28 | max-complexity = 20 29 | # flake8-import-order: style 30 | import-order-style = google 31 | # flake8-import-order: local module name checker 32 | application-import-names = cortex 33 | 34 | [testenv:flake8] 35 | description = Use flake8 linter to impose standards on the project 36 | basepython = python3.6 37 | skip_install = true 38 | deps = 39 | flake8 == 3.5.0 40 | flake8-import-order == 0.15 41 | flake8-docstrings == 1.1.0 42 | flake8-bugbear == 17.4.0 43 | commands = 44 | flake8 cortex setup.py 45 | 46 | [testenv:pylint] 47 | description = Perform static analysis and output code metrics 48 | basepython = python3.6 49 | skip_install = false 50 | deps = 51 | pylint == 1.8.1 52 | commands = 53 | pylint cortex cortex/_lib cortex/built_ins 54 | 55 | [testenv:autopep8] 56 | deps = 57 | autopep8 58 | commands = 59 | autopep8 -a -i -r cortex 60 | 61 | 62 | [testenv:docs] 63 | basepython = python 64 | deps = sphinx 65 | sphinx_rtd_theme 66 | commands = 67 | sphinx-apidoc -f -o docs/source ./cortex 68 | make html 69 | 70 | [doc8] 71 | max-line-length = 100 72 | file-encoding = utf-8 
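# Note: doc8 reads this [doc8] section from tox.ini when the doc8 env
# below runs `doc8 -v docs/source/`.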
73 | 74 | [testenv:doc8] 75 | description = Impose standards on *.rst documentation files 76 | basepython = python3.6 77 | skip_install = true 78 | deps = 79 | -rdocs/requirements.txt 80 | doc8 == 0.8.0 81 | commands = 82 | doc8 -v docs/source/ -------------------------------------------------------------------------------- /travis-config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from os import path 3 | 4 | config_name = '.cortex.yml' 5 | home = path.expanduser('~') 6 | config_file_path = path.join(home, config_name) 7 | 8 | configs = dict( 9 | datapaths=dict(local=home, torchvision=home), 10 | out_path=home, 11 | viz=dict(port=8097, server='http://localhost')) 12 | 13 | with open(config_file_path, 'w') as config_file: 14 | yaml.dump(configs, config_file) 15 | -------------------------------------------------------------------------------- /travis-output.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Abort on Error 3 | set -e 4 | 5 | export PING_SLEEP=30s 6 | export WORKDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 7 | export BUILD_OUTPUT=$WORKDIR/build.out 8 | 9 | touch $BUILD_OUTPUT 10 | 11 | dump_output() { 12 | echo Tailing the last 500 lines of output: 13 | tail -500 $BUILD_OUTPUT 14 | } 15 | error_handler() { 16 | echo ERROR: An error was encountered with the build. 17 | dump_output 18 | exit 1 19 | } 20 | # If an error occurs, run our error handler to output a tail of the build 21 | trap 'error_handler' ERR 22 | 23 | # Set up a repeating loop to send some output to Travis. 24 | 25 | bash -c "while true; do echo \$(date) - building ...; sleep $PING_SLEEP; done" & 26 | PING_LOOP_PID=$! 27 | 28 | # This snippet originally built with Maven, but any build command works here, e.g. 29 | # your_build_command_1 >> $BUILD_OUTPUT 2>&1 30 | # your_build_command_2 >> $BUILD_OUTPUT 2>&1 31 | pip install . >> $BUILD_OUTPUT 2>&1 32 | 33 | # The build finished without returning an error so dump a tail of the output 34 | dump_output 35 | 36 | # nicely terminate the ping output loop 37 | kill $PING_LOOP_PID --------------------------------------------------------------------------------