├── vulcanai
│   ├── tests
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── test_basenetwork.py
│   │   │   ├── test_data
│   │   │   │   └── sensitivity_analysis_test_truth.csv
│   │   │   ├── test_device.py
│   │   │   ├── test_utils.py
│   │   │   ├── test_ensemble.py
│   │   │   ├── conftest.py
│   │   │   ├── test_cnn.py
│   │   │   ├── test_layers.py
│   │   │   └── test_dnn.py
│   │   ├── datasets
│   │   │   ├── test_data
│   │   │   │   ├── birthweight_reduced2.csv
│   │   │   │   ├── birthweight_reduced_merged.csv
│   │   │   │   └── birthweight_reduced.csv
│   │   │   └── test_tabulardataset.py
│   │   └── plotters
│   │       └── test_visualization.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── multidataset.py
│   │   └── fashion.py
│   ├── plotters
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   └── visualization.py
│   ├── logging.conf
│   ├── models
│   │   ├── __init__.py
│   │   ├── dnn.py
│   │   ├── ensemble.py
│   │   ├── layers.py
│   │   ├── cnn.py
│   │   └── utils.py
│   └── __init__.py
├── readthedocs.yml
├── docs
│   ├── source
│   │   ├── modules.rst
│   │   ├── vulcanai.rst
│   │   ├── vulcanai.plotters.rst
│   │   ├── vulcanai.datasets.rst
│   │   ├── vulcanai.models.rst
│   │   ├── index.rst
│   │   └── conf.py
│   ├── requirements.txt
│   ├── Makefile
│   └── make.bat
├── MANIFEST.in
├── requirements.txt
├── .gitignore
├── environment.yml
├── .travis.yml
├── examples
│   ├── sensitivity_analysis_example.py
│   ├── predict_single_value_example.py
│   ├── fashion_conv_dense.py
│   └── fashion_multi_input_network.py
├── setup.py
├── README.md
└── CODE_OF_CONDUCT.md
/vulcanai/tests/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | build:
2 |   image: latest
3 | 
4 | python:
5 |   version: 3.6
--------------------------------------------------------------------------------
/docs/source/modules.rst:
--------------------------------------------------------------------------------
1 | vulcanai
2 | ========
3 | 
4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | vulcanai 8 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include vulcanai *.py 2 | include README.md 3 | include LICENSE 4 | include requirements.txt 5 | include vulcanai/logging.conf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.0 2 | setuptools>=40.6.3 3 | scipy>=1.1.0 4 | matplotlib>=2.0.2 5 | scikit-learn>=0.18.0 6 | sphinx>=1.8.3 7 | pandas>=0.23.4 8 | pydash>=4.7.4 9 | tqdm>=4.25.0 10 | seaborn>=0.9.0 11 | torch>=1.0.0 12 | torchvision>=0.2.1 13 | pytest>=3.8.0 14 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.16.0 2 | setuptools==40.6.3 3 | scipy==1.1.0 4 | matplotlib==2.0.2 5 | scikit-learn==0.18.0 6 | sphinx==1.8.3 7 | pandas==0.23.4 8 | pydash==4.7.4 9 | tqdm==4.25.0 10 | seaborn==0.9.0 11 | https://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl 12 | torchvision==0.2.1 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb_checkpoints 2 | *.pyc 3 | *.DS_Store 4 | *.__pycache__ 5 | *.pytest_cache 6 | data 7 | runs 8 | examples/saved_models 9 | 10 | /*.egg-info 11 | *.png 12 | *.jpg 13 | *.pdf 14 | *.csv 15 | *.npz 16 | *.pickle 17 | *.gz 18 | *.npy 19 | *.network 20 | *.json 21 | *.zip 22 | *.cache 23 | *.log 24 | docs/build 25 | -------------------------------------------------------------------------------- /docs/source/vulcanai.rst: -------------------------------------------------------------------------------- 1 | vulcanai package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | vulcanai.datasets 10 | vulcanai.models 11 | vulcanai.plotters 12 | 13 | Module contents 14 | --------------- 15 | 16 | .. automodule:: vulcanai 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: Vulcan 2 | channels: 3 | - defaults 4 | dependencies: 5 | - libgfortran=3.0.0=1 6 | - mkl 7 | - numpy=1.13.1=py36_0 8 | - openssl=1.0.2l=0 9 | - pip=10.0.1 10 | - python>=3.6 11 | - readline=6.2=2 12 | - scikit-learn>=0.19.0 13 | - scipy>=0.19.1 14 | - setuptools=27.2.0=py36_0 15 | - sqlite=3.13.0=0 16 | - tk=8.5.18=0 17 | - wheel=0.32.2 18 | - xz=5.2.2=1 19 | - zlib=1.2.8=3 20 | - tqdm=4.14.0 21 | - pip: 22 | - tabulate==0.7.7 -------------------------------------------------------------------------------- /docs/source/vulcanai.plotters.rst: -------------------------------------------------------------------------------- 1 | vulcanai.plotters package 2 | ========================= 3 | 4 | vulcanai.plotters.utils module 5 | ------------------------------ 6 | 7 | .. automodule:: vulcanai.plotters.utils 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | vulcanai.plotters.visualization module 13 | -------------------------------------- 14 | 15 | .. 
automodule:: vulcanai.plotters.visualization
16 |     :members:
17 |     :undoc-members:
18 |     :show-inheritance:
19 | 
20 | 
--------------------------------------------------------------------------------
/vulcanai/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """ Imports dataset classes so they can be used directly
3 | 
4 | Submodules
5 | ==========
6 | 
7 | .. autosummary::
8 |     :toctree: _autosummary
9 | 
10 |     fashion
11 |     tabulardataset
12 |     multidataset
13 | """
14 | 
15 | from .fashion import FashionData
16 | from .multidataset import MultiDataset
17 | 
18 | __all__ = [
19 |     'fashion',
20 |     'tabulardataset',
21 |     'utils',
22 |     'FashionData',
23 |     'MultiDataset'
24 | ]
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line.
5 | SPHINXOPTS = 
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = source
8 | BUILDDIR = build
9 | 
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 | 
14 | .PHONY: help Makefile
15 | 
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/vulcanai/plotters/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import matplotlib
3 | matplotlib.use('agg')
4 | import matplotlib.pyplot as plt
5 | 
6 | from .visualization import (
7 |     compute_saliency_map,
8 |     display_saliency_overlay,
9 |     display_pca,
10 |     display_tsne,
11 |     display_confusion_matrix,
12 |     display_record,
13 |     display_receptive_fields
14 | )
15 | 
16 | __all__ = [
17 |     'utils',
18 |     'visualization',
19 |     'compute_saliency_map',
20 |     'display_saliency_overlay',
21 |     'display_pca',
22 |     'display_tsne',
23 |     'display_confusion_matrix',
24 |     'display_record',
25 |     'display_receptive_fields',
26 | ]
27 | 
--------------------------------------------------------------------------------
/vulcanai/logging.conf:
--------------------------------------------------------------------------------
1 | [loggers]
2 | keys=root
3 | 
4 | [handlers]
5 | keys=consoleHandler,fileHandler
6 | 
7 | [formatters]
8 | keys=fileFormatter,consoleFormatter
9 | 
10 | [logger_root]
11 | level=DEBUG
12 | handlers=consoleHandler,fileHandler
13 | 
14 | [handler_consoleHandler]
15 | class=StreamHandler
16 | level=WARNING
17 | formatter=consoleFormatter
18 | args=(sys.stdout,)
19 | 
20 | [handler_fileHandler]
21 | class=FileHandler
22 | level=DEBUG
23 | formatter=fileFormatter
24 | args=('%(logfilename)s',)
25 | 
26 | [formatter_fileFormatter]
27 | format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
28 | datefmt=%Y/%m/%d %H:%M:%S
29 | 
30 | [formatter_consoleFormatter]
31 | format=%(name)s - %(levelname)s - %(message)s
32 | datefmt=%Y/%m/%d %H:%M:%S
33 | 
--------------------------------------------------------------------------------
/docs/source/vulcanai.datasets.rst:
--------------------------------------------------------------------------------
1 | vulcanai.datasets package
2 | 
========================= 3 | 4 | vulcanai.datasets.fashion module 5 | -------------------------------- 6 | 7 | .. automodule:: vulcanai.datasets.fashion 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | vulcanai.datasets.multidataset module 13 | ------------------------------------- 14 | 15 | .. automodule:: vulcanai.datasets.multidataset 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | vulcanai.datasets.tabulardataset module 21 | --------------------------------------- 22 | 23 | .. automodule:: vulcanai.datasets.tabulardataset 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | vulcanai.datasets.utils module 29 | ------------------------------ 30 | 31 | .. automodule:: vulcanai.datasets.utils 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - '3.6' 4 | install: 5 | - pip install --upgrade pip 6 | - pip install --progress-bar off -r requirements.txt 7 | - pip install -e . 
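# (editable install so the pytest run below imports the checked-out sources)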
8 | script: 9 | - pytest 10 | branches: 11 | only: 12 | - master 13 | - /^\d+\.\d+\.\d+$/ 14 | deploy: 15 | provider: pypi 16 | user: "rfratila" 17 | password: 18 | secure: i1SMEd951TYKtVhAAOFduJaRraqrhyyresjOK2V19xRy0c1AiMcoUSZlKS3xr1ET4p27zDkZdGDiu+bQono22UDHFKg039EXz8zGk5TybmLkL/v+KsStWSU9Y7wUm/GGPPyACR0UPYJoMjGVyJFaGhJWGNX3k2PGDZdaAenGaeJUdsutMlHquBUTNgJg3Ib0z34bDTf6KSssXlbMSwx0fvcztbKRCbst0QrorZ/osp9udMWyAcyXXJEl+rBmSd1zTkcfbku/O6y+y9HQWBWEuQS9iXUW8Gc8HUm8zlVQafJkvEyvSR8KBMUl99sxG8UAkdsrP3CkiRnS0OzJkHQR/aq8ClfN1Kz6gid5SMt06px8k/J1zZ8E7EUCLK6ra+QpfngdT6z391ZVga7dLOri3kP964LJ6gSdeAx5XFAKY7fN3xSUfZ3je39AZ2dkiFmogOmnzAkIDBhbWdLG1KruLmq9XRZri/fJ/WZZ1Cxw7RcT+7SCaFEhiJdhzUWuoor30WwglzX4lOUgnBJw+9uH7FrH9WrLD6RRwAN8m3P868crSLERKf0Vja73ZkUOS9zVcHkw1gc55VqsOu+I3RU8FpPNU5IpaKOwAvEfCIliPLAE4Lnh+3r4b+ZBqk5V+TD/hgtyLQuupFdhTCXRZM4xo4GlA9V1hiWz68G0VjJkNTA= 19 | on: 20 | tags: true 21 | -------------------------------------------------------------------------------- /vulcanai/models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Imports network classes so they can be used directly""" 3 | from .utils import ( 4 | round_list, 5 | get_one_hot, 6 | pad, 7 | network_summary, 8 | print_model_structure, 9 | selu_weight_init_, 10 | selu_bias_init_, 11 | set_tensor_device, 12 | master_device_setter 13 | ) 14 | 15 | from .basenetwork import BaseNetwork 16 | from .cnn import ConvNet, ConvNetConfig 17 | from .dnn import DenseNet, DenseNetConfig 18 | from .ensemble import SnapshotNet 19 | from .metrics import Metrics 20 | from .layers import BaseUnit, ConvUnit, DenseUnit, FlattenUnit 21 | 22 | __all__ = [ 23 | 'basenetwork', 24 | 'cnn', 25 | 'dnn', 26 | 'layers', 27 | 'ensemble', 28 | 'metrics', 29 | 'utils', 30 | 'BaseNetwork', 31 | 'ConvNet', 32 | 'ConvNetConfig', 33 | 'DenseNet', 34 | 'DenseNetConfig', 35 | 'SnapshotNet', 36 | 'Metrics', 37 | 'ConvUnit', 38 | 'DenseUnit', 39 | 'FlattenUnit', 40 | 'BaseUnit' 41 | ] 42 | -------------------------------------------------------------------------------- /vulcanai/__init__.py: -------------------------------------------------------------------------------- 1 | import logging.config 2 | import os 3 | import torch 4 | import random 5 | import numpy as np 6 | 7 | 8 | DEFAULT_RANDOM_SEED = 42 9 | 10 | log_config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 11 | 'logging.conf') 12 | 13 | if not os.path.isfile(log_config_path): 14 | raise IOError 15 | 16 | log_output_path = "logfile.log" 17 | 18 | logging.config.fileConfig(log_config_path, 19 | disable_existing_loggers=False, 20 | defaults={'logfilename': log_output_path}) 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def set_global_seed(seed_value): 26 | """ 27 | Sets all the random seeds, including for torch, GPU, numpy and python. 28 | 29 | Parameters: 30 | seed_value : int 31 | The random seed value. 
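Note: the module itself calls set_global_seed(DEFAULT_RANDOM_SEED) once at import time; call this function again to reseed a run.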
32 | 
33 |     """
34 |     random.seed(seed_value)
35 |     np.random.seed(seed_value)  # cpu vars
36 |     torch.manual_seed(seed_value)  # cpu vars
37 |     if torch.cuda.is_available():
38 |         torch.cuda.manual_seed_all(seed_value)
39 | 
40 | 
41 | set_global_seed(DEFAULT_RANDOM_SEED)
42 | 
--------------------------------------------------------------------------------
/examples/sensitivity_analysis_example.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader, Subset, TensorDataset
2 | from vulcanai.models import ConvNet, DenseNet
3 | import sys
4 | sys.path.append('.')
5 | from vulcanai.models.metrics import Metrics
6 | import torch
7 | 
8 | def dataloader():
9 |     """Create a DataLoader of random data standing in for the test csv."""
10 |     test_input = torch.rand(size=[13, 15])
11 |     test_dataloader = DataLoader(
12 |         TensorDataset(test_input, torch.tensor([0, 0, 0, 1, 2, 0, 2, 0, 0, 0, 0,
13 |                                                 0, 0])))
14 |     return test_dataloader
15 | 
16 | def dnn():
17 |     """DenseNet with prediction layer."""
18 |     return DenseNet(
19 |         name='dnn_class',
20 |         in_dim=(15,),
21 |         config={
22 |             'dense_units': [100, 50],
23 |             'dropout': 0.5,
24 |         },
25 |         num_classes=4
26 |     )
27 | 
28 | if __name__ == '__main__':
29 |     net = dnn()
30 |     dl = dataloader()
31 | 
32 |     m = Metrics()
33 | 
34 |     # with col names given
35 |     m.conduct_sensitivity_analysis(net, dl, 'test_sensitivity_analysis_1',
36 |                                    ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
37 |                                     'i', 'j', 'k', 'l'])
38 | 
39 |     # without col names given
40 |     m.conduct_sensitivity_analysis(net, dl, 'test_sensitivity_analysis_2')
--------------------------------------------------------------------------------
/docs/source/vulcanai.models.rst:
--------------------------------------------------------------------------------
1 | vulcanai.models package
2 | =======================
3 | 
4 | vulcanai.models.basenetwork module
5 | ----------------------------------
6 | 
7 | .. automodule:: vulcanai.models.basenetwork
8 |     :members:
9 |     :undoc-members:
10 |     :show-inheritance:
11 | 
12 | vulcanai.models.cnn module
13 | --------------------------
14 | 
15 | .. automodule:: vulcanai.models.cnn
16 |     :members:
17 |     :undoc-members:
18 |     :show-inheritance:
19 | 
20 | vulcanai.models.dnn module
21 | --------------------------
22 | 
23 | .. automodule:: vulcanai.models.dnn
24 |     :members:
25 |     :undoc-members:
26 |     :show-inheritance:
27 | 
28 | vulcanai.models.ensemble module
29 | -------------------------------
30 | 
31 | .. automodule:: vulcanai.models.ensemble
32 |     :members:
33 |     :undoc-members:
34 |     :show-inheritance:
35 | 
36 | vulcanai.models.layers module
37 | -----------------------------
38 | 
39 | .. automodule:: vulcanai.models.layers
40 |     :members:
41 |     :undoc-members:
42 |     :show-inheritance:
43 | 
44 | vulcanai.models.metrics module
45 | ------------------------------
46 | 
47 | .. automodule:: vulcanai.models.metrics
48 |     :members:
49 |     :undoc-members:
50 |     :show-inheritance:
51 | 
52 | vulcanai.models.utils module
53 | ----------------------------
54 | 
55 | .. 
automodule:: vulcanai.models.utils
56 |     :members:
57 |     :undoc-members:
58 |     :show-inheritance:
59 | 
60 | 
--------------------------------------------------------------------------------
/vulcanai/tests/models/test_basenetwork.py:
--------------------------------------------------------------------------------
1 | """Test BaseNetwork functionality."""
2 | import pytest
3 | from vulcanai.models.basenetwork import BaseNetwork
4 | 
5 | import torch
6 | 
7 | 
8 | class TestBaseNetwork:
9 |     """Define BaseNetwork test class."""
10 | 
11 |     @pytest.fixture
12 |     def basenet(self):
13 |         """Create a test BaseNetwork."""
14 |         return BaseNetwork(
15 |             name='Test_BaseNet',
16 |             in_dim=(None, 10),
17 |             config={}
18 |         )
19 | 
20 |     def test_init(self, basenet):
21 |         """Initialization Test of a BaseNetwork object."""
22 |         assert isinstance(basenet, BaseNetwork)
23 |         assert isinstance(basenet, torch.nn.Module)
24 |         assert hasattr(basenet, 'network')
25 |         assert hasattr(basenet, 'in_dim')
26 |         assert hasattr(basenet, 'record')
27 | 
28 |     def test_name(self, basenet):
29 |         """Test changing names."""
30 |         basenet.name = 'New_Name'
31 |         assert basenet.name == 'New_Name'
32 | 
33 |     def test_learning_rate(self, basenet):
34 |         """Test learning rate change."""
35 |         basenet.learning_rate = 0.1
36 |         assert basenet.learning_rate == 0.1
37 | 
38 |     def test_default_criter_spec(self, basenet):
39 |         """Test default value behaviour for criter spec."""
40 |         assert isinstance(basenet.criter_spec, torch.nn.CrossEntropyLoss)
41 |         assert isinstance(basenet._final_transform, torch.nn.Softmax)
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Vulcan documentation master file, created by
2 |    sphinx-quickstart on Thu Jan 17 14:46:09 2019.
3 |    You can adapt this file completely to your liking, but it should at least
4 |    contain the root `toctree` directive.
5 | 
6 | Welcome to Vulcan's documentation!
7 | ==================================
8 | 
9 | Vulcan is Aifred Health's framework for rapid deep learning model prototyping and analysis.
10 | 
11 | Vulcan provides the tools for:
12 | 
13 | 1. Rapid-yet-flexible data preprocessing
14 | 
15 | 2. Rapid creation of modular neural networks. Among the usual we also include capability for:
16 | 
17 |    - snapshot ensembles
18 |    - multi-modal networks with complex architectures
19 |    - state-of-the-art activations
20 |    - training and saving models across multiple machines
21 | 
22 | 3. Model Evaluation
23 | 
24 | 4. Visualization for data and network interpretability. Among the usual we also include:
25 | 
26 |    - t-SNE
27 |    - Saliency maps using guided backpropagation
28 | 
29 | 
30 | Vulcan is built on Pytorch. We think Pytorch is great, so our framework was built with the goal of facilitating but not impeding access to all of Pytorch. Want to do things the easy way? Great, create a network using our simple configuration dict. Need something a little more complicated? Extend our classes or write your own Pytorch module to use within the rest of our framework.
31 | 
32 | Contents
33 | ==================
34 | .. 
toctree:: 35 | :maxdepth: 2 36 | 37 | vulcanai.datasets 38 | vulcanai.models 39 | vulcanai.plotters 40 | 41 | 42 | Indices and tables 43 | ================== 44 | 45 | * :ref:`genindex` 46 | * :ref:`modindex` 47 | * :ref:`search` 48 | -------------------------------------------------------------------------------- /examples/predict_single_value_example.py: -------------------------------------------------------------------------------- 1 | from vulcanai import datasets 2 | from vulcanai.models import ConvNet, DenseNet 3 | import torch 4 | from torch.utils.data import DataLoader, Subset, TensorDataset 5 | import numpy as np 6 | 7 | conv_2D_config = { 8 | 'conv_units': [ 9 | dict( 10 | in_channels=1, 11 | out_channels=16, 12 | kernel_size=(5, 5), 13 | stride=2, 14 | dropout=0.1 15 | ), 16 | dict( 17 | in_channels=16, 18 | out_channels=32, 19 | kernel_size=(5, 5), 20 | dropout=0.1 21 | ), 22 | dict( 23 | in_channels=32, 24 | out_channels=64, 25 | kernel_size=(5, 5), 26 | pool_size=2, 27 | dropout=0.1 28 | ) 29 | ], 30 | } 31 | 32 | 33 | conv_2D = ConvNet( 34 | name='conv_2D', 35 | in_dim=(1, 28, 28), 36 | config=conv_2D_config, 37 | num_classes=1, 38 | criter_spec=torch.nn.MSELoss() 39 | ) 40 | 41 | test_input = torch.rand(size=[10, 1, 28, 28]) 42 | test_output = torch.rand(size=[10, 1]) 43 | 44 | test_dataloader = DataLoader(TensorDataset(test_input, test_output)) 45 | 46 | conv_2D.fit( 47 | test_dataloader, 48 | test_dataloader, 49 | epochs=3, 50 | plot=False, 51 | save_path="." 52 | ) 53 | 54 | conv_2D.run_test(test_dataloader) 55 | res = conv_2D.forward_pass(test_dataloader, convert_to_class=True, transform_callable=np.round, decimals=3) 56 | 57 | print(res) 58 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup script for uploading package to PyPI servers.""" 2 | from setuptools import setup, find_packages 3 | 4 | tests_require = [ 5 | 'pytest', 6 | 'numpydoc' 7 | ] 8 | 9 | docs_require = [ 10 | 'Sphinx', # TODO: maybe numpydoc? 
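#     (numpydoc is already listed in tests_require above)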
11 | ] 12 | 13 | with open('requirements.txt') as f: 14 | install_requires = [l.strip() for l in f] 15 | 16 | 17 | setup( 18 | name='vulcanai', 19 | version='1.0.8', 20 | description='A high-level framework built on top of Pytorch' 21 | ' using added functionality from Scikit-learn to provide ' 22 | 'all of the tools needed for visualizing and processing ' 23 | 'high-dimensional data, modular neural networks, ' 24 | 'and model evaluation', 25 | author='Robert Fratila, Priyatharsan Rajasekar, Caitrin Armstrong, ' 26 | 'Joseph Mehltretter, Sneha Desai', 27 | author_email='robertfratila10@gmail.com', 28 | url='https://github.com/Aifred-Health/Vulcan', 29 | install_requires=install_requires, 30 | extras_require={ 31 | 'testing': tests_require, 32 | 'docs': docs_require, 33 | }, 34 | packages=find_packages(), 35 | include_package_data=True, 36 | classifiers=['Development Status :: 3 - Alpha', 37 | 'Intended Audience :: Developers', 38 | 'Intended Audience :: Science/Research', 39 | 'Intended Audience :: Education', 40 | 'Topic :: Software Development :: Build Tools', 41 | 'Programming Language :: Python :: 3.6', 42 | 'Operating System :: Unix', 43 | 'Operating System :: POSIX :: Linux', 44 | 'Topic :: Scientific/Engineering :: Artificial Intelligence'], 45 | keywords='deep learning machine learning development framework' 46 | ) 47 | -------------------------------------------------------------------------------- /vulcanai/tests/datasets/test_data/birthweight_reduced2.csv: -------------------------------------------------------------------------------- 1 | id,headcirumference,length,Birthweight,Gestation,smoker,motherage,mnocig,mheight,mppwt,fage,fedyrs 2 | 1313,12,17,5.8,33,0,24,0,58,99,26,16 3 | 431,12,19,4.2,33,1,20,7,63,109,20,10 4 | 808,13,19,6.4,34,0,26,0,65,140,25,12 5 | 300,12,18,4.5,35,1,41,7,65,125,37,14 6 | 516,13,18,5.8,35,1,20,35,67,125,23,12 7 | 321,13,19,6.8,37,0,28,0,62,118,39,10 8 | 1363,12,19,5.2,37,1,20,7,64,104,20,10 9 | 575,12,19,6.1,37,1,19,7,65,132,20,14 10 | 822,13,19,7.5,38,0,20,0,62,103,22,14 11 | 1081,14,21,8,38,0,18,0,67,109,20,12 12 | 1636,14,20,8.6,38,0,29,0,64,135,31,16 13 | 1107,14,20,7.1,38,0,31,0,64,125,35,16 14 | 1023,13,20,6.6,38,1,30,12,64,140,38,14 15 | 1369,13,19,7,38,1,31,25,63,124,32,16 16 | 697,13,19,6.6,39,0,27,0,63,135,27,14 17 | 1600,13,21,6.3,39,0,19,0,64,125,23,14 18 | 57,14,20,7.3,39,1,23,17,62,104,32,12 19 | 272,14,20,8.5,39,1,30,25,67,170,40,16 20 | 569,13,19,5.5,39,1,22,7,62,115,23,14 21 | 619,13,20,7.5,39,1,23,25,71,152,23,16 22 | 1522,13,19,6,39,1,21,17,61,115,24,12 23 | 820,13,20,8.3,40,0,24,0,62,110,31,16 24 | 1016,14,21,9.5,40,0,19,0,67,135,19,12 25 | 1058,13,20,6.9,40,0,29,0,65,130,30,16 26 | 1088,14,20,7.2,40,0,24,0,66,117,29,16 27 | 365,14,20,7.7,40,1,26,25,67,137,30,10 28 | 532,13,21,7.9,40,1,31,12,64,107,41,12 29 | 752,14,19,7.3,40,1,27,12,60,105,37,12 30 | 792,14,21,8,40,1,20,2,66,130,24,12 31 | 1272,12,20,6,40,1,37,50,66,135,31,16 32 | 462,15,22,9,41,0,35,0,67,127,31,16 33 | 755,13,21,7,41,0,21,0,61,120,25,14 34 | 1683,13,21,7.3,41,0,27,0,64,135,37,14 35 | 27,14,20,7.8,41,1,37,25,63,145,46,16 36 | 1262,13,21,7,41,1,27,35,64,110,31,16 37 | 1388,13,20,6.9,41,1,22,7,63,117,24,16 38 | 1764,15,22,10,41,1,32,12,68,154,38,14 39 | 553,14,21,8.6,42,0,24,0,69,143,30,12 40 | 1191,13,21,8,42,0,21,0,65,132,21,10 41 | 1360,13,22,10,44,0,20,0,63,125,23,10 42 | 223,13,19,8.5,45,1,28,25,64,118,30,16 43 | 1187,14,20,8.9,44,0,20,0,68,150,26,14 44 | -------------------------------------------------------------------------------- 
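Not a file in the repository — a minimal usage sketch of how a fixture CSV like the one above could be loaded and split with the helpers from vulcanai/datasets/utils.py (shown later in this listing); the file path and the 70/20/10 ratios are illustrative assumptions.

import pandas as pd

from vulcanai.datasets.utils import check_split_ratio, rationed_split

# Hypothetical path into a local checkout; adjust as needed.
df = pd.read_csv("vulcanai/tests/datasets/test_data/birthweight_reduced2.csv")

# check_split_ratio validates/normalizes the ratios and returns
# (train_ratio, test_ratio, validation_ratio).
train_ratio, test_ratio, valid_ratio = check_split_ratio([0.7, 0.2, 0.1])

# rationed_split shuffles df.index and returns (train, test, validation)
# index arrays, which can then be used with df.loc.
train_idx, test_idx, valid_idx = rationed_split(
    df, train_ratio, test_ratio, valid_ratio)
train_df, test_df, valid_df = df.loc[train_idx], df.loc[test_idx], df.loc[valid_idx]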
/vulcanai/tests/datasets/test_data/birthweight_reduced_merged.csv: -------------------------------------------------------------------------------- 1 | id,headcirumference,length,Birthweight,Gestation,smoker,motherage,mnocig,mheight,mppwt,fage,fedyrs,LowBirthWeight 2 | 1313,12,17,5.8,33,0,24,0,58,99,26,16,Low 3 | 431,12,19,4.2,33,1,20,7,63,109,20,10,Low 4 | 808,13,19,6.4,34,0,26,0,65,140,25,12,Normal 5 | 300,12,18,4.5,35,1,41,7,65,125,37,14,Low 6 | 516,13,18,5.8,35,1,20,35,67,125,23,12,Low 7 | 321,13,19,6.8,37,0,28,0,62,118,39,10,Normal 8 | 1363,12,19,5.2,37,1,20,7,64,104,20,10,Low 9 | 575,12,19,6.1,37,1,19,7,65,132,20,14,Normal 10 | 822,13,19,7.5,38,0,20,0,62,103,22,14,Normal 11 | 1081,14,21,8,38,0,18,0,67,109,20,12,Normal 12 | 1636,14,20,8.6,38,0,29,0,64,135,31,16,Normal 13 | 1107,14,20,7.1,38,0,31,0,64,125,35,16,Normal 14 | 1023,13,20,6.6,38,1,30,12,64,140,38,14,Normal 15 | 1369,13,19,7,38,1,31,25,63,124,32,16,Normal 16 | 697,13,19,6.6,39,0,27,0,63,135,27,14,Normal 17 | 1600,13,21,6.3,39,0,19,0,64,125,23,14,Normal 18 | 57,14,20,7.3,39,1,23,17,62,104,32,12,Normal 19 | 272,14,20,8.5,39,1,30,25,67,170,40,16,Normal 20 | 569,13,19,5.5,39,1,22,7,62,115,23,14,Low 21 | 619,13,20,7.5,39,1,23,25,71,152,23,16,Normal 22 | 1522,13,19,6,39,1,21,17,61,115,24,12,Normal 23 | 820,13,20,8.3,40,0,24,0,62,110,31,16,Normal 24 | 1016,14,21,9.5,40,0,19,0,67,135,19,12,Normal 25 | 1058,13,20,6.9,40,0,29,0,65,130,30,16,Normal 26 | 1088,14,20,7.2,40,0,24,0,66,117,29,16,Normal 27 | 365,14,20,7.7,40,1,26,25,67,137,30,10,Normal 28 | 532,13,21,7.9,40,1,31,12,64,107,41,12,Normal 29 | 752,14,19,7.3,40,1,27,12,60,105,37,12,Normal 30 | 792,14,21,8,40,1,20,2,66,130,24,12,Normal 31 | 1272,12,20,6,40,1,37,50,66,135,31,16,Normal 32 | 462,15,22,9,41,0,35,0,67,127,31,16,Normal 33 | 755,13,21,7,41,0,21,0,61,120,25,14,Normal 34 | 1683,13,21,7.3,41,0,27,0,64,135,37,14,Normal 35 | 27,14,20,7.8,41,1,37,25,63,145,46,16,Normal 36 | 1262,13,21,7,41,1,27,35,64,110,31,16,Normal 37 | 1388,13,20,6.9,41,1,22,7,63,117,24,16,Normal 38 | 1764,15,22,10,41,1,32,12,68,154,38,14,Normal 39 | 553,14,21,8.6,42,0,24,0,69,143,30,12,Normal 40 | 1191,13,21,8,42,0,21,0,65,132,21,10,Normal 41 | 1360,13,22,10,44,0,20,0,63,125,23,10,Normal 42 | 223,13,19,8.5,45,1,28,25,64,118,30,16,Normal 43 | 1187,14,20,8.9,44,0,20,0,68,150,26,14,Normal 44 | -------------------------------------------------------------------------------- /vulcanai/tests/datasets/test_data/birthweight_reduced.csv: -------------------------------------------------------------------------------- 1 | id,headcirumference,length,Birthweight,Gestation,smoker,motherage,mnocig,mheight,mppwt,fage,fedyrs,LowBirthWeight 2 | 1313,Nan,17,5.8,33,0,24,0,58,99,26,16,Low 3 | 431,Nan,19,4.2,33,1,20,7,63,109,20,10,Low 4 | 808,13,19,6.4,34,0,26,0,65,140,25,Nan,Normal 5 | 300,Nan,18,4.5,35,1,41,7,65,125,37,14,Low 6 | 516,13,18,5.8,35,1,20,35,67,125,23,Nan,Low 7 | 321,13,19,6.8,37,0,28,0,62,118,39,10,Normal 8 | 1363,Nan,19,5.2,37,1,20,7,64,104,20,10,Low 9 | 575,Nan,19,6.1,37,1,19,7,65,132,20,14,Normal 10 | 822,13,19,7.5,38,0,20,0,62,103,22,14,Normal 11 | 1081,14,21,8,38,0,18,0,67,109,20,Nan,Normal 12 | 1636,14,20,8.6,38,0,29,0,64,135,31,16,Normal 13 | 1107,14,20,7.1,38,0,31,0,64,125,35,16,Normal 14 | 1023,13,20,6.6,38,1,30,Nan,64,140,38,14,Normal 15 | 1369,13,19,7,38,1,31,25,63,124,32,16,Normal 16 | 697,13,19,6.6,39,0,27,0,63,135,27,14,Normal 17 | 1600,13,21,6.3,39,0,19,0,64,125,23,14,Normal 18 | 57,14,20,7.3,39,1,23,17,62,104,32,Nan,Normal 19 | 272,14,20,8.5,39,1,30,25,67,170,40,16,Normal 20 | 
569,13,19,5.5,39,1,22,7,62,115,23,14,Low
21 | 619,13,20,7.5,39,1,23,25,71,152,23,16,Normal
22 | 1522,13,19,6,39,1,21,17,61,115,24,Nan,Normal
23 | 820,13,20,8.3,40,0,24,0,62,110,31,16,Normal
24 | 1016,14,21,9.5,40,0,19,0,67,135,19,Nan,Normal
25 | 1058,13,20,6.9,40,0,29,0,65,130,30,16,Normal
26 | 1088,14,20,7.2,40,0,24,0,66,117,29,16,Normal
27 | 365,14,20,7.7,40,1,26,25,67,137,30,10,Normal
28 | 532,13,21,7.9,40,1,31,Nan,64,107,41,Nan,Normal
29 | 752,14,19,7.3,40,1,27,Nan,60,105,37,Nan,Normal
30 | 792,14,21,8,40,1,20,2,66,130,24,Nan,Normal
31 | 1272,Nan,20,6,40,1,37,50,66,135,31,16,Normal
32 | 462,15,22,9,41,0,35,0,67,127,31,16,Normal
33 | 755,13,21,7,41,0,21,0,61,120,25,14,Normal
34 | 1683,13,21,7.3,41,0,27,0,64,135,37,14,Normal
35 | 27,14,20,7.8,41,1,37,25,63,145,46,16,Normal
36 | 1262,13,21,7,41,1,27,35,64,110,31,16,Normal
37 | 1388,13,20,6.9,41,1,22,7,63,117,24,16,Normal
38 | 1764,15,22,10,41,1,32,Nan,68,154,38,14,Normal
39 | 553,14,21,8.6,42,0,24,0,69,143,30,Nan,Normal
40 | 1191,13,21,8,42,0,21,0,65,132,21,10,Normal
41 | 1360,13,22,10,44,0,20,0,63,125,23,10,Normal
42 | 223,13,19,8.5,45,1,28,25,64,118,30,16,Normal
43 | 1187,14,20,8.9,44,0,20,0,68,150,26,14,Normal
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Vulcan
2 | [![Build Status](https://travis-ci.com/Aifred-Health/Vulcan.svg?branch=master)](https://travis-ci.com/Aifred-Health/Vulcan)
3 | [![Documentation Status](https://readthedocs.org/projects/vulcanai/badge/?version=latest)](https://vulcanai.readthedocs.io/en/latest/?badge=latest)
4 | 
5 | 
6 | Vulcan is Aifred Health's framework for rapid deep learning model prototyping and analysis.
7 | 
8 | Vulcan provides the tools for:
9 | 1. Rapid-yet-flexible data preprocessing
10 | 2. Rapid creation of modular neural networks. Among the usual we also include capability for:
11 |    * snapshot ensembles
12 |    * multi-modal networks with complex architectures
13 |    * state-of-the-art activations
14 |    * training and saving models across multiple machines
15 | 3. Model Evaluation
16 | 4. Visualization for data and network interpretability. Among the usual we also include:
17 |    * t-SNE
18 |    * [Saliency maps using guided backpropagation](https://arxiv.org/abs/1412.6806)
19 | 
20 | Vulcan is built on Pytorch. We think Pytorch is great, so our framework was built with the goal of facilitating but not impeding access to all of Pytorch. Want to do things the easy way? Great, create a network using our simple configuration dict. Need something a little more complicated? Extend our classes or write your own Pytorch module to use within the rest of our framework.
21 | 
22 | For a more detailed run-through on how to use the tools, please look at the [documentation](https://vulcanai.readthedocs.io/en/latest/).
23 | 
24 | ## Installation
25 | [Pytorch](https://pytorch.org) must be installed separately as per your device's requirements (e.g. GPU/CPU). Afterwards, Vulcan can be installed using PyPI:
26 | ```
27 | pip install vulcanai
28 | ```
29 | or you can install from source after cloning the repository:
30 | ```
31 | git clone https://github.com/Aifred-Health/Vulcan.git
32 | cd Vulcan
33 | pip install -e vulcanai
34 | ```
35 | 
36 | ## Releases
37 | The current stable release is 1.0.8
38 | 
39 | ## Contributions
40 | We welcome contributions, particularly to tabular_data_utils, additional processing methods, and to generalized and generalizable network architectures.
Please create an issue before embarking on a solution, however, as we may already have something similar in the works! 41 | -------------------------------------------------------------------------------- /vulcanai/tests/models/test_data/sensitivity_analysis_test_truth.csv: -------------------------------------------------------------------------------- 1 | Feature,Value,Number of examples classified as class 0,Number of examples classified as class 1,Number of examples classified as class 2 2 | a,0.5349225401878357,0.0,2.0,3.0 3 | a,0.28683900833129883,0.0,2.0,3.0 4 | a,0.7368624210357666,0.0,1.0,4.0 5 | a,0.5209300518035889,0.0,2.0,3.0 6 | a,0.38874197006225586,0.0,2.0,3.0 7 | b,0.19880318641662598,0.0,1.0,4.0 8 | b,0.2063193917274475,0.0,1.0,4.0 9 | b,0.03311049938201904,0.0,2.0,3.0 10 | b,0.593231201171875,0.0,1.0,4.0 11 | b,0.2214464545249939,0.0,1.0,4.0 12 | c,0.6592116951942444,0.0,2.0,3.0 13 | c,0.4450901746749878,0.0,2.0,3.0 14 | c,0.09137779474258423,0.0,1.0,4.0 15 | c,0.8797041773796082,0.0,2.0,3.0 16 | c,0.37420207262039185,0.0,2.0,3.0 17 | d,0.6568902730941772,0.0,2.0,3.0 18 | d,0.35928595066070557,0.0,0.0,5.0 19 | d,0.899403989315033,0.0,3.0,2.0 20 | d,0.6285890936851501,0.0,2.0,3.0 21 | d,0.19525814056396484,0.0,0.0,5.0 22 | e,0.23276156187057495,0.0,2.0,3.0 23 | e,0.720380425453186,0.0,2.0,3.0 24 | e,0.9936200976371765,0.0,1.0,4.0 25 | e,0.7652736902236938,0.0,2.0,3.0 26 | e,0.7405239939689636,0.0,2.0,3.0 27 | f,0.4250614047050476,0.0,1.0,4.0 28 | f,0.07305192947387695,0.0,1.0,4.0 29 | f,0.4702875018119812,0.0,1.0,4.0 30 | f,0.11322057247161865,0.0,1.0,4.0 31 | f,0.25287991762161255,0.0,1.0,4.0 32 | g,0.20708602666854858,0.0,2.0,3.0 33 | g,0.9699196219444275,0.0,2.0,3.0 34 | g,0.10492372512817383,0.0,1.0,4.0 35 | g,0.8559429049491882,0.0,2.0,3.0 36 | g,0.23315048217773438,0.0,2.0,3.0 37 | h,0.6297363638877869,0.0,1.0,4.0 38 | h,0.10778653621673584,0.0,1.0,4.0 39 | h,0.5136615633964539,0.0,1.0,4.0 40 | h,0.6720845699310303,0.0,1.0,4.0 41 | h,0.9314137697219849,0.0,2.0,3.0 42 | i,0.36531615257263184,0.0,2.0,3.0 43 | i,0.8828775882720947,0.0,1.0,4.0 44 | i,0.26739025115966797,0.0,2.0,3.0 45 | i,0.6266505718231201,0.0,1.0,4.0 46 | i,0.9575381278991699,0.0,1.0,4.0 47 | j,0.8512683510780334,0.0,1.0,4.0 48 | j,0.41317272186279297,0.0,1.0,4.0 49 | j,0.4990387558937073,0.0,1.0,4.0 50 | j,0.5690726637840271,0.0,1.0,4.0 51 | j,0.5575070977210999,0.0,1.0,4.0 52 | k,0.8549431562423706,0.0,2.0,3.0 53 | k,0.757194459438324,0.0,1.0,4.0 54 | k,0.7447319030761719,0.0,1.0,4.0 55 | k,0.743732750415802,0.0,1.0,4.0 56 | k,0.41341865062713623,0.0,1.0,4.0 57 | l,0.5509352087974548,0.0,2.0,3.0 58 | l,0.6948475241661072,0.0,1.0,4.0 59 | l,0.7213373184204102,0.0,1.0,4.0 60 | l,0.9592165946960449,0.0,1.0,4.0 61 | l,0.43545758724212646,0.0,2.0,3.0 62 | -------------------------------------------------------------------------------- /examples/fashion_conv_dense.py: -------------------------------------------------------------------------------- 1 | """Simple Convolution and fully connected blocks example.""" 2 | from vulcanai import datasets 3 | from vulcanai.models import ConvNet, DenseNet 4 | 5 | import torchvision.transforms as transforms 6 | from torch.utils.data import DataLoader 7 | 8 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 9 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 10 | 11 | transform = transforms.Compose([transforms.ToTensor(), 12 | normalize]) 13 | 14 | 15 | data_path = "../data" 16 | train_dataset = datasets.FashionData(root=data_path, 17 | 
train=True, 18 | transform=transform, 19 | download=True) 20 | 21 | val_dataset = datasets.FashionData(root=data_path, 22 | train=False, 23 | transform=transform) 24 | 25 | 26 | batch_size = 100 27 | 28 | train_loader = DataLoader(dataset=train_dataset, 29 | batch_size=batch_size, 30 | shuffle=True) 31 | 32 | val_loader = DataLoader(dataset=val_dataset, 33 | batch_size=batch_size, 34 | shuffle=False) 35 | 36 | conv_2D_config = { 37 | 'conv_units': [ 38 | dict( 39 | in_channels=1, 40 | out_channels=16, 41 | kernel_size=(5, 5), 42 | stride=2, 43 | dropout=0.1 44 | ), 45 | dict( 46 | in_channels=16, 47 | out_channels=32, 48 | kernel_size=(5, 5), 49 | dropout=0.1 50 | ), 51 | dict( 52 | in_channels=32, 53 | out_channels=64, 54 | kernel_size=(5, 5), 55 | pool_size=2, 56 | dropout=0.1 57 | ) 58 | ], 59 | } 60 | 61 | dense_config = { 62 | 'dense_units': [100, 50], 63 | 'dropout': 0.5, # Single value or List 64 | } 65 | 66 | conv_2D = ConvNet( 67 | name='conv_2D', 68 | in_dim=(1, 28, 28), 69 | config=conv_2D_config 70 | ) 71 | 72 | dense_model = DenseNet( 73 | name='dense_model', 74 | input_networks=conv_2D, 75 | config=dense_config, 76 | num_classes=10, 77 | early_stopping="best_validation_error", 78 | early_stopping_patience=2 79 | ) 80 | 81 | dense_model.fit( 82 | train_loader, 83 | val_loader, 84 | epochs=40, 85 | plot=True, 86 | save_path=".", 87 | valid_interv=1 88 | ) 89 | dense_model.run_test(val_loader, plot=True, save_path=".") 90 | dense_model.save_model() 91 | -------------------------------------------------------------------------------- /vulcanai/tests/models/test_device.py: -------------------------------------------------------------------------------- 1 | """Test device switching for networks.""" 2 | import pytest 3 | 4 | import torch 5 | from torch.utils.data import DataLoader, Subset 6 | 7 | from vulcanai.models import ConvNet, DenseNet 8 | from vulcanai.models.utils import master_device_setter 9 | 10 | TEST_CUDA = torch.cuda.is_available() 11 | TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2 12 | DEVICE_COUNT = 0 13 | 14 | if TEST_CUDA: 15 | DEVICE_COUNT = torch.cuda.device_count() 16 | 17 | 18 | class TestDevice: 19 | """Test multi-input GPU device switching.""" 20 | 21 | @pytest.mark.skipif(not TEST_CUDA, reason="No CUDA" 22 | " supported devices available") 23 | def test_master_net_device_set_to_cuda(self, multi_input_cnn): 24 | """Test if the network as whole gets switched to cuda.""" 25 | assert hasattr(multi_input_cnn, 'device') 26 | master_device_setter(multi_input_cnn, 'cuda:0') 27 | assert multi_input_cnn.device == torch.device(type='cuda', index=0) 28 | assert multi_input_cnn.input_networks['conv3D_net']\ 29 | .device == torch.device(type='cuda', index=0) 30 | assert multi_input_cnn.input_networks['multi_input_dnn']\ 31 | .device == torch.device(type='cuda', index=0) 32 | assert multi_input_cnn.input_networks['multi_input_dnn'].\ 33 | input_networks['conv1D_net'].\ 34 | device == torch.device(type='cuda', index=0) 35 | assert multi_input_cnn.input_networks['multi_input_dnn'].\ 36 | input_networks['conv2D_net'].\ 37 | device == torch.device(type='cuda', index=0) 38 | 39 | @pytest.mark.skipif(not TEST_CUDA, reason="No CUDA" 40 | " supported devices available") 41 | def test_fail_mixed_devices(self, multi_input_cnn, conv3D_net, 42 | multi_input_dnn, conv1D_net, 43 | multi_input_dnn_data, 44 | multi_input_cnn_data): 45 | """Test training throws ValueError when network has mixed devices.""" 46 | assert hasattr(conv1D_net, 'device') 47 | assert 
hasattr(conv3D_net, 'device') 48 | assert hasattr(multi_input_dnn, 'device') 49 | assert hasattr(multi_input_cnn, 'device') 50 | 51 | master_device_setter(multi_input_cnn, 'cuda:0') 52 | assert conv3D_net == multi_input_cnn.input_networks['conv3D_net'] 53 | assert multi_input_dnn == multi_input_cnn.input_networks['multi_input_dnn'] 54 | 55 | data_len = len(multi_input_cnn_data) 56 | train_loader = DataLoader( 57 | Subset(multi_input_cnn_data, range(data_len//2))) 58 | valid_loader = DataLoader( 59 | Subset(multi_input_cnn_data, range(data_len//2, data_len))) 60 | 61 | multi_input_cnn.fit( 62 | train_loader=train_loader, 63 | val_loader=valid_loader, 64 | epochs=1, 65 | plot=False) 66 | 67 | with pytest.raises(ValueError) as e_info: 68 | multi_input_cnn.input_networks['conv3D_net'].device = 'cpu' 69 | multi_input_cnn.fit( 70 | train_loader=train_loader, 71 | val_loader=valid_loader, 72 | epochs=1, 73 | plot=False) 74 | 75 | assert str(e_info.value).endswith("{'conv3D_net': device(type='cpu')}") 76 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. 
Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 | 
55 | ## Enforcement
56 | 
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at info@aifredhealth.com. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 | 
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 | 
68 | ## Attribution
69 | 
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 | 
73 | [homepage]: https://www.contributor-covenant.org
74 | 
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 | 
--------------------------------------------------------------------------------
/vulcanai/datasets/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | This file contains utility methods that may be useful to several dataset
4 | classes.
5 | check_split_ratio, stratify, rationed_split, randomshuffler
6 | were all copy-pasted from torchtext because torchtext is not yet packaged
7 | for anaconda and is therefore not yet a reasonable dependency.
8 | See https://github.com/pytorch/text/blob/master/torchtext/data/dataset.py
9 | """
10 | import pandas as pd
11 | import logging
12 | import copy
13 | import numpy as np
14 | 
15 | logger = logging.getLogger(__name__)
16 | 
17 | 
18 | # TODO: implement
19 | def clean_dataframe(df):
20 |     """
21 |     Goes through and ensures that all nonsensical values are encoded as NaNs
22 |     :param df:
23 |     :return:
24 |     """
25 | 
26 |     return df
27 | 
28 | 
29 | def check_split_ratio(split_ratio):
30 |     """
31 |     Check that the split ratio argument is not malformed
32 | 
33 |     Parameters:
34 | 
35 |         split_ratio: desired split ratio, either a list of length 2 or 3
36 |             depending if the validation set is desired.
37 | 
38 |     Returns:
39 |         split ratio as tuple
40 | 
41 |     """
42 |     valid_ratio = 0.
43 |     if isinstance(split_ratio, float):
44 |         # Only the train set relative ratio is provided
45 |         # Assert in bounds, validation size is zero
46 |         assert 0. < split_ratio < 1., (
47 |             "Split ratio {} not between 0 and 1".format(split_ratio))
48 | 
49 |         test_ratio = 1. - split_ratio
50 |         return split_ratio, test_ratio, valid_ratio
51 |     elif isinstance(split_ratio, list):
52 |         # A list of relative ratios is provided
53 |         length = len(split_ratio)
54 |         assert length == 2 or length == 3, (
55 |             "Length of split ratio list should be 2 or 3, got {}".format(
56 |                 split_ratio))
57 | 
58 |         # Normalize if necessary
59 |         ratio_sum = sum(split_ratio)
60 |         if not ratio_sum == 1.:
61 |             split_ratio = [float(ratio) / ratio_sum for ratio in split_ratio]
62 | 
63 |         if length == 2:
64 |             return tuple(split_ratio + [valid_ratio])
65 |         return tuple(split_ratio)
66 |     else:
67 |         raise ValueError('Split ratio must be float or a list, got {}'
68 |                          .format(type(split_ratio)))
69 | 
70 | 
71 | def rationed_split(df, train_ratio, test_ratio, validation_ratio):
72 |     """
73 |     Function to split a dataset given ratios. Assumes the ratios given
74 |     are valid (checked using check_split_ratio).
75 | 
76 |     Parameters:
77 |         df: Dataframe
78 |             The dataframe you want to split
79 |         train_ratio: float
80 |             proportion of the dataset that will go to the train split.
81 |             between 0 and 1
82 |         test_ratio: float
83 |             proportion of the dataset that will go to the test split.
84 |             between 0 and 1
85 |         validation_ratio: float
86 |             proportion of the dataset that will go to the val split.
87 |             between 0 and 1
88 | 
89 |     Returns:
90 |         indices: tuple of list of indices.
91 |     """
92 |     n = len(df.index)
93 |     perm = np.random.permutation(df.index)
94 |     train_len = int(round(train_ratio * n))
95 | 
96 |     # Due to possible rounding problems
97 |     if not validation_ratio:
98 |         test_len = n - train_len
99 |     else:
100 |         test_len = int(round(test_ratio * n))
101 | 
102 |     indices = (perm[:train_len],  # Train
103 |                perm[train_len:train_len + test_len],  # Test
104 |                perm[train_len + test_len:])  # Validation
105 | 
106 |     return indices
107 | 
108 | 
109 | 
110 | 
--------------------------------------------------------------------------------
/vulcanai/datasets/multidataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """ Defines the MultiDataset Class"""
3 | from torch.utils.data import Dataset
4 | import logging
5 | 
6 | logger = logging.getLogger(__name__)
7 | 
8 | 
9 | class MultiDataset(Dataset):
10 |     """
11 |     Define a dataset for multi input networks.
12 | 
13 |     Takes in a list of datasets, and whether or not their input_data
14 |     and target data should be output.
15 | 
16 |     Parameters:
17 |         dataset_tuples : list of tuples
18 |             Each tuple being (Dataset, use_data_boolean, use_target_boolean).
19 |             A list of tuples, wherein each tuple should have the Dataset in the
20 |             zero index, a boolean of whether to include the input_data in
21 |             the first index, and a boolean of whether to include the target
22 |             data in the second index. You can only specify one target at a
23 |             time throughout all incoming datasets.
24 | 
25 |     Returns:
26 |         multi_dataset : torch.utils.data.Dataset
27 | 
28 |     """
29 | 
30 |     def __init__(self, dataset_tuples):
31 |         """Initialize a dataset for multi input networks."""
32 |         def _get_total_targets(multi_datasets):
33 |             num_targets = 0
34 |             for tup in multi_datasets:
35 |                 if isinstance(tup, MultiDataset):
36 |                     num_targets += _get_total_targets(tup._dataset_tuples)
37 |                 else:
38 |                     num_targets += int(tup[2])
39 |             return num_targets
40 | 
41 |         self._dataset_tuples = dataset_tuples
42 |         # must always have exactly one target.
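        # (Illustrative only: a hypothetical
        #  MultiDataset([(image_ds, True, False), (tabular_ds, True, True)])
        #  would draw inputs from both datasets and take the single target
        #  from tabular_ds alone.)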
43 |         total_num_targets = _get_total_targets(self._dataset_tuples)
44 |         if total_num_targets > 1:
45 |             raise ValueError(
46 |                 "You may specify at most one target."
47 |                 " {} specified".format(total_num_targets))
48 | 
49 |     def __len__(self):
50 |         """
51 |         Denotes the total number of samples.
52 | 
53 |         Will look for the dataset with the smallest number of samples and
54 |         default the length to that so as to avoid getting a sample that doesn't
55 |         exist in another dataset.
56 | 
57 |         Returns:
58 |             length : int
59 | 
60 |         """
61 |         logger.warning("Defaulting to the length of the smallest dataset")
62 | 
63 |         def _get_min_length(multi_datasets):
64 |             min_length = float('inf')
65 |             for tup in multi_datasets:
66 |                 if isinstance(tup, MultiDataset):
67 |                     length = _get_min_length(tup._dataset_tuples)
68 |                 else:
69 |                     length = len(tup[0])
70 |                 if length < min_length:
71 |                     min_length = length
72 |             return min_length
73 | 
74 |         return _get_min_length(self._dataset_tuples)
75 | 
76 |     def __getitem__(self, idx):
77 |         """
78 |         Override getitem as used by DataLoader and required by torch Dataset.
79 | 
80 |         Parameters:
81 |             idx : index
82 |                 Index of sample to extract
83 | 
84 |         Returns:
85 |             (input_data, targets) : (torch.Tensor, torch.Tensor)
86 |                 Tuple of input_data and target at that index as specified in
87 |                 config. Target data is always last.
88 | 
89 |         """
90 |         input_data_items = []
91 |         target_item = None
92 | 
93 |         for tup in self._dataset_tuples:
94 | 
95 |             include_data = tup[1]
96 |             include_target = tup[2]
97 | 
98 |             if isinstance(tup, MultiDataset):
99 |                 ds = tup
100 |             else:
101 |                 ds = tup[0]
102 | 
103 |             if include_data:
104 |                 input_data_items.append(ds.__getitem__(idx)[0])
105 | 
106 |             if include_target:
107 |                 target_item = ds.__getitem__(idx)[1]
108 | 
109 |         return input_data_items, target_item
--------------------------------------------------------------------------------
/vulcanai/tests/models/test_utils.py:
--------------------------------------------------------------------------------
1 | """Define the model utility tests."""
2 | import pytest
3 | import numpy as np
4 | import torch
5 | from vulcanai.models.utils import (round_list,
6 |                                    get_one_hot,
7 |                                    pad,
8 |                                    set_tensor_device,
9 |                                    master_device_setter)
10 | 
11 | TEST_CUDA = torch.cuda.is_available()
12 | 
13 | 
14 | def test_round_list():
15 |     """Test if the list is rounded to the desired decimals."""
16 |     test_list = [0.83754245, 0.13249807]
17 |     out = round_list(test_list, decimals=3)
18 |     assert len(str(out[0]).split(".")[1]) == 3
19 | 
20 | 
21 | def test_get_one_hot():
22 |     """Test for get_one_hot."""
23 |     test_inp = np.array([0, 1, 2, 3, 4, 3, 2, 1, 0])
24 |     assert np.all(get_one_hot(test_inp) == np.array([
25 |         [1., 0., 0., 0., 0.],
26 |         [0., 1., 0., 0., 0.],
27 |         [0., 0., 1., 0., 0.],
28 |         [0., 0., 0., 1., 0.],
29 |         [0., 0., 0., 0., 1.],
30 |         [0., 0., 0., 1., 0.],
31 |         [0., 0., 1., 0., 0.],
32 |         [0., 1., 0., 0., 0.],
33 |         [1., 0., 0., 0., 0.]
34 |     ]))
35 | 
36 | 
37 | def test_pad():
38 |     """Test if input_tensor is padded to the desired padded shape."""
39 |     n_channels = 2
40 |     n_features = 50
41 |     pad_shape = 250
42 |     batch = 1
43 |     test_inp = torch.randn([batch, n_channels, n_features])
44 |     padded_tensor = pad(tensor=test_inp, target_shape=[pad_shape])
45 |     assert padded_tensor.shape == (batch, n_channels, pad_shape)
46 |     assert padded_tensor.nonzero().size(0) == n_channels * n_features
47 |     # 2D padding
48 |     pad_shape1 = 100
49 |     test_inp = torch.randn([batch, n_channels, n_features, n_features])
50 |     padded_tensor = pad(
51 |         tensor=test_inp,
52 |         target_shape=[pad_shape, pad_shape1])
53 |     assert padded_tensor.shape == (batch, n_channels, pad_shape, pad_shape1)
54 |     assert padded_tensor.nonzero().size(0) == \
55 |         n_channels * n_features * n_features
56 |     # 3D padding
57 |     pad_shape2 = 50
58 |     test_inp = torch.randn([batch, n_channels, n_features, n_features,
59 |                             n_features])
60 |     padded_tensor = pad(
61 |         tensor=test_inp,
62 |         target_shape=[pad_shape, pad_shape1, pad_shape2])
63 |     assert padded_tensor.shape == \
64 |         (batch, n_channels, pad_shape, pad_shape1, pad_shape2)
65 |     assert padded_tensor.nonzero().size(0) == \
66 |         n_channels * n_features * n_features * n_features
67 | 
68 | 
69 | @pytest.mark.skipif(not TEST_CUDA, reason="No CUDA"
70 |                     " supported devices available")
71 | def test_set_tensor_device():
72 |     """If CUDA is available, test set_tensor_device usage."""
73 |     # Tensor
74 |     test_tensor_cpu = torch.randn([5, 5]).cpu()
75 |     test_tensor_cuda = set_tensor_device(test_tensor_cpu, device='cuda:0')
76 |     assert str(test_tensor_cuda.device) == 'cuda:0'
77 | 
78 |     # List of Tensors
79 |     test_tensorlist_cpu = [torch.randn([1, 2]).cpu(),
80 |                            torch.randn([2, 3]).cpu(),
81 |                            torch.randn([3, 4]).cpu()]
82 |     test_tensorlist_cuda = set_tensor_device(
83 |         test_tensorlist_cpu, device='cuda:0')
84 |     for t in test_tensorlist_cuda:
85 |         assert str(t.device) == 'cuda:0'
86 | 
87 | 
88 | @pytest.mark.skipif(not TEST_CUDA, reason="No CUDA"
89 |                     " supported devices available")
90 | def test_master_device_setter(multi_input_cnn):
91 |     """If CUDA is available, test master_device_setter usage."""
92 |     # Make sure the network is in cpu first
93 |     assert str(multi_input_cnn.device) == 'cpu'
94 |     master_device_setter(multi_input_cnn, device='cuda:0')
95 |     assert str(multi_input_cnn.device) == 'cuda:0'
96 |     assert str(list(multi_input_cnn.input_networks.values())[0].device) == 'cuda:0'
97 |     assert str(list(list(multi_input_cnn.input_networks.
98 |                         values())[2].input_networks.values())[0].device) == 'cuda:0'
99 | 
--------------------------------------------------------------------------------
/vulcanai/tests/models/test_ensemble.py:
--------------------------------------------------------------------------------
1 | """Define tests for ensemble networks such as Snapshot ensembling."""
2 | import pytest
3 | import numpy as np
4 | import torch
5 | from vulcanai.models.cnn import ConvNet
6 | from vulcanai.models.dnn import DenseNet
7 | from vulcanai.models.ensemble import SnapshotNet
8 | from torch.utils.data import TensorDataset, DataLoader
9 | 
10 | torch.manual_seed(1234)
11 | 
12 | 
13 | class TestSnapshotNet:
14 |     """Test SnapshotNet functionality."""
15 | 
16 |     @pytest.fixture
17 |     def cnn_noclass(self):
18 |         """Create intermediate conv module."""
19 |         return ConvNet(
20 |             name='Test_ConvNet_noclass',
21 |             in_dim=(1, 28, 28),
22 |             config={
23 |                 'conv_units': [
24 |                     {
25 |                         "in_channels": 1,
26 |                         "out_channels": 16,
27 |                         "kernel_size": (5, 5),
28 |                         "stride": 2
29 |                     },
30 |                     {
31 |                         "in_channels": 16,
32 |                         "out_channels": 1,
33 |                         "kernel_size": (5, 5),
34 |                         "stride": 1,
35 |                         "padding": 2
36 |                     }]
37 |             },
38 |         )
39 | 
40 |     @pytest.fixture
41 |     def dnn_class(self, cnn_noclass):
42 |         """Create dnn module prediction leaf node."""
43 |         return DenseNet(
44 |             name='Test_DenseNet_class',
45 |             input_networks=cnn_noclass,
46 |             config={
47 |                 'dense_units': [100, 50],
48 |                 'initializer': None,
49 |                 'bias_init': None,
50 |                 'norm': None,
51 |                 'dropout': 0.5,  # Single value or List
52 |             },
53 |             num_classes=3
54 |         )
55 | 
56 |     @pytest.fixture
57 |     def dnn_class_two(self, cnn_noclass):
58 |         """Create dnn module prediction leaf node with an explicit optimizer spec."""
59 |         return DenseNet(
60 |             name='Test_DenseNet_class',
61 |             input_networks=cnn_noclass,
62 |             optim_spec={'name': 'Adam', 'lr': 0.001},
63 |             config={
64 |                 'dense_units': [100, 50],
65 |                 'initializer': None,
66 |                 'bias_init': None,
67 |                 'norm': None,
68 |                 'dropout': 0.5,  # Single value or List
69 |             },
70 |             num_classes=3
71 |         )
72 | 
73 |     def test_snapshot_structure(self, cnn_noclass, dnn_class):
74 |         """Confirm Snapshot structure is generated properly."""
75 |         test_input = torch.randint(0, 10, size=[3, *cnn_noclass.in_dim]).float()
76 |         test_target = torch.LongTensor([0, 2, 1])
77 |         test_dataloader = DataLoader(TensorDataset(test_input, test_target))
78 |         test_snap = SnapshotNet(
79 |             name='test_snap',
80 |             template_network=dnn_class,
81 |             n_snapshots=3
82 |         )
83 |         test_snap.fit(
84 |             train_loader=test_dataloader,
85 |             val_loader=test_dataloader,
86 |             epochs=3,
87 |             plot=False
88 |         )
89 |         assert isinstance(
90 |             test_snap.template_network.lr_scheduler,
91 |             torch.optim.lr_scheduler.CosineAnnealingLR)
92 |         # Check correct number of generated snapshots
93 |         assert len(test_snap.network) == 3
94 |         # Check snapshots are not identical
95 |         assert test_snap.network[0] is not \
96 |             test_snap.network[1] is not \
97 |             test_snap.network[2]
98 |         output = test_snap.forward_pass(
99 |             data_loader=test_dataloader,
100 |             transform_outputs=False)
101 |         assert output.shape == (3, test_snap.num_classes)
102 |         assert np.any(~np.isnan(output))
103 | 
104 |     def test_snapshot_lr(self, cnn_noclass, dnn_class_two):
105 |         """Confirm the learning rate is annealed over snapshot training."""
106 |         test_input = torch.randint(0, 10, size=[3, *cnn_noclass.in_dim]).float()
107 |         test_target = torch.LongTensor([0, 2, 1])
108 |         test_dataloader = DataLoader(TensorDataset(test_input, test_target))
109 |         test_snap = SnapshotNet(
110 |             name='test_snap',
111 |             template_network=dnn_class_two,
104 | def test_snapshot_lr(self, cnn_noclass, dnn_class_two): 105 | """Confirm the cosine-annealed learning rate decays during snapshot training.""" 106 | test_input = torch.randint(0, 10, size=[3, *cnn_noclass.in_dim]).float() 107 | test_target = torch.LongTensor([0, 2, 1]) 108 | test_dataloader = DataLoader(TensorDataset(test_input, test_target)) 109 | test_snap = SnapshotNet( 110 | name='test_snap', 111 | template_network=dnn_class_two, 112 | n_snapshots=3, 113 | ) 114 | test_snap.fit( 115 | train_loader=test_dataloader, 116 | val_loader=test_dataloader, 117 | epochs=15, 118 | plot=False 119 | ) 120 | assert test_snap.template_network.lr_scheduler.get_lr()[0] < 0.001 121 | -------------------------------------------------------------------------------- /vulcanai/plotters/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """Contains all visualization utilities.""" 3 | import torch 4 | from torch.nn import ReLU, SELU, ModuleList 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def get_notable_indices(feature_importances, top_k=5): 11 | """ 12 | Return a dict of the top k most and bottom k least important feature indices. 13 | 14 | Parameters: 15 | feature_importances: numpy.ndarray 16 | 1D numpy array from which to extract the top or bottom indices. 17 | top_k : int 18 | How many features from top and bottom to extract. 19 | Defaults to 5. 20 | 21 | Returns: 22 | notable_indices : dict 23 | Indices of the top most important features. 24 | Indices of the bottom most unimportant features. 25 | 26 | """ 27 | important_features = feature_importances.argsort()[-top_k:][::-1] 28 | unimportant_features = feature_importances.argsort()[:top_k] 29 | return {'important_indices': important_features, 30 | 'unimportant_indices': unimportant_features} 31 | 32 |
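# Example (added for illustration): for feature_importances = [0.1, 0.9, 0.3, 0.7, 0.2] and top_k=2, argsort() gives [0, 4, 2, 3, 1], so the result is {'important_indices': array([1, 3]), 'unimportant_indices': array([0, 4])}.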
33 | class GuidedBackprop(object): 34 | """ 35 | Generate gradients with guided back propagation w.r.t given input. 36 | 37 | Modified from https://github.com/utkuozbulak/pytorch-cnn-visualizations 38 | Insert backward hooks for activations to propagate positive gradients. 39 | 40 | Parameters: 41 | network : BaseNetwork 42 | Network to conduct guided backprop on. 43 | 44 | Returns: 45 | gradients : list of numpy.ndarray 46 | Gradients of top most layer w.r.t the input sample. 47 | 48 | """ 49 | 50 | def __init__(self, network): 51 | """Set up hooks for activations and gradient retrieval.""" 52 | if network.__class__.__bases__[0].__name__ != "BaseNetwork": 53 | raise ValueError("Network type must be a subclass of BaseNetwork") 54 | self.network = network 55 | self.gradients = [] 56 | self.hooks = [] 57 | # Put model in evaluation mode 58 | self.network.eval() 59 | self._crop_negative_gradients() 60 | 61 | def _crop_negative_gradients(self): 62 | """Update relu/selu activations to return positive gradients.""" 63 | def activation_hook_function(module, grad_in, grad_out): 64 | """If there is a negative gradient, change it to zero.""" 65 | if isinstance(module, ReLU) or isinstance(module, SELU): 66 | return (torch.clamp(grad_in[0], min=0.0),) 67 | else: 68 | raise NotImplementedError("Only ReLU and SELU supported.") 69 | 70 | # noinspection PyProtectedMember 71 | def _hook_all_networks(network): 72 | self.hooks.append( 73 | network.network[0]._activation. 74 | register_backward_hook(activation_hook_function)) 75 | logger.info("Cropping gradients in {}.".format(network.name)) 76 | if network.input_networks: 77 | for in_net in network.input_networks.values(): 78 | _hook_all_networks(in_net) 79 | # For Snapshot Networks 80 | if isinstance(self.network.network, ModuleList): 81 | for net in self.network.network: 82 | _hook_all_networks(net) 83 | else: 84 | _hook_all_networks(self.network) 85 | 86 | def _remove_hooks(self): 87 | """Remove all previously placed hooks from model.""" 88 | for h in self.hooks: 89 | h.remove() 90 | 91 | # noinspection PyProtectedMember 92 | def generate_gradients(self, input_data, targets): 93 | """ 94 | Compute guided backprop gradients and return top layer gradients. 95 | 96 | Parameters: 97 | input_data : numpy.ndarray or torch.Tensor 98 | 2D for DenseNet, 4D (for 2D images) or 5D (for 3D images) 99 | Tensor. 100 | targets : numpy.ndarray or torch.LongTensor 101 | 1D list of class labels 102 | 103 | Returns: 104 | gradients : list of numpy.ndarray 105 | Gradient list of numpy arrays with the same shapes as the inputs. 106 | 107 | """ 108 | assert isinstance(targets, torch.LongTensor) 109 | 110 | def _requires_grad_multidataset(data_list): 111 | for d in data_list: 112 | if isinstance(d, list): 113 | _requires_grad_multidataset(d) 114 | else: 115 | assert isinstance(d, torch.Tensor) 116 | d.requires_grad_() 117 | 118 | if isinstance(input_data, list): 119 | _requires_grad_multidataset(input_data) 120 | else: 121 | assert isinstance(input_data, torch.Tensor) 122 | input_data.requires_grad_() 123 | 124 | # Forward pass 125 | network_output = self.network(input_data) 126 | # Zero gradients 127 | self.network.zero_grad() 128 | # Target for backprop 129 | one_hot_zeros = torch.zeros( 130 | network_output.size()[0], 131 | self.network.num_classes) 132 | one_hot_output = one_hot_zeros.scatter_(1, targets.unsqueeze(dim=1), 1) 133 | one_hot_output = one_hot_output.to(self.network.device) 134 | # Backward pass 135 | network_output.backward(gradient=one_hot_output) 136 | 137 | # noinspection PyShadowingNames 138 | def _extract_input_gradients_multidataset(input_data): 139 | for data in input_data: 140 | if isinstance(data, list): 141 | _extract_input_gradients_multidataset(data) 142 | else: 143 | self.gradients.append(data.grad.detach().cpu().numpy()) 144 | 145 | if isinstance(input_data, list): 146 | _extract_input_gradients_multidataset(input_data) 147 | else: 148 | # noinspection PyUnresolvedReferences 149 | self.gradients.append(input_data.grad.detach().cpu().numpy()) 150 | 151 | self._remove_hooks() 152 | return self.gradients 153 | -------------------------------------------------------------------------------- /examples/fashion_multi_input_network.py: -------------------------------------------------------------------------------- 1 | """Example for a complex multi-modality (i.e.
multi-input) network.""" 2 | from vulcanai import datasets 3 | from vulcanai.models import ConvNet, DenseNet 4 | from vulcanai.datasets import MultiDataset 5 | 6 | import torch 7 | 8 | import torchvision.transforms as transforms 9 | from torch.utils.data import DataLoader, TensorDataset 10 | 11 | normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 12 | std=[x/255.0 for x in [63.0, 62.1, 66.7]]) 13 | 14 | transform = transforms.Compose([transforms.ToTensor(), 15 | normalize]) 16 | 17 | 18 | data_path = "../data" 19 | train_dataset = datasets.FashionData(root=data_path, 20 | train=True, 21 | transform=transform, 22 | download=True) 23 | 24 | train_dataset = torch.utils.data.Subset(train_dataset, range(0, 1000)) 25 | 26 | val_dataset = datasets.FashionData(root=data_path, 27 | train=False, 28 | transform=transform) 29 | 30 | 31 | batch_size = 100 32 | 33 | train_loader = DataLoader(dataset=train_dataset, 34 | batch_size=batch_size, 35 | shuffle=True) 36 | 37 | val_loader = DataLoader(dataset=val_dataset, 38 | batch_size=batch_size, 39 | shuffle=False) 40 | 41 | 42 | conv_1D_config = { 43 | 'conv_units': [ 44 | dict( 45 | in_channels=1, 46 | out_channels=16, 47 | kernel_size=(5), 48 | stride=2, 49 | dropout=0.1 50 | ), 51 | dict( 52 | in_channels=16, 53 | out_channels=32, 54 | kernel_size=(5), 55 | padding=0, 56 | dropout=0.1 57 | ), 58 | dict( 59 | in_channels=32, 60 | out_channels=64, 61 | kernel_size=(5), 62 | pool_size=2, 63 | dropout=0.1 64 | ) 65 | ], 66 | } 67 | conv_2D_config = { 68 | 'conv_units': [ 69 | dict( 70 | in_channels=1, 71 | out_channels=16, 72 | kernel_size=(5, 5), 73 | stride=2, 74 | dropout=0.1 75 | ), 76 | dict( 77 | in_channels=16, 78 | out_channels=32, 79 | kernel_size=(5, 5), 80 | dropout=0.1 81 | ), 82 | dict( 83 | in_channels=32, 84 | out_channels=64, 85 | kernel_size=(5, 5), 86 | pool_size=2, 87 | dropout=0.1 88 | ) 89 | ], 90 | } 91 | conv_3D_config = { 92 | 'conv_units': [ 93 | dict( 94 | in_channels=1, 95 | out_channels=16, 96 | kernel_size=(5, 5, 5), 97 | stride=2, 98 | dropout=0.1 99 | ), 100 | dict( 101 | in_channels=16, 102 | out_channels=16, 103 | kernel_size=(5, 5, 5), 104 | stride=1, 105 | dropout=0.1 106 | ), 107 | dict( 108 | in_channels=16, 109 | out_channels=64, 110 | kernel_size=(5, 5, 5), 111 | dropout=0.1 112 | ), 113 | ], 114 | } 115 | 116 | multi_input_conv_3D_config = { 117 | 'conv_units': [ 118 | dict( 119 | in_channels=1, 120 | out_channels=16, 121 | kernel_size=(3, 3, 3), 122 | stride=2, 123 | dropout=0.1 124 | ), 125 | ], 126 | } 127 | dense_config = { 128 | 'dense_units': [100, 50], 129 | 'weight_init': None, 130 | 'bias_init': None, 131 | 'norm': None, 132 | 'dropout': 0.5, # Single value or List 133 | } 134 | 135 | conv_1D = ConvNet( 136 | name='conv_1D', 137 | input_networks=None, 138 | in_dim=(1, 28), 139 | config=conv_1D_config 140 | ) 141 | conv_2D = ConvNet( 142 | name='conv_2D', 143 | input_networks=None, 144 | in_dim=(1, 28, 28), 145 | config=conv_2D_config 146 | ) 147 | conv_3D = ConvNet( 148 | name='conv_3D', 149 | input_networks=None, 150 | in_dim=(1, 28, 28, 28), 151 | config=conv_3D_config 152 | ) 153 | 154 | dense_model = DenseNet( 155 | name='dense_model', 156 | input_networks=[conv_2D, conv_1D], 157 | config=dense_config 158 | ) 159 | 160 | multi_input_conv_3D = ConvNet( 161 | name='multi_input_conv_3D', 162 | input_networks=[conv_1D, dense_model, conv_2D, conv_3D], 163 | config=multi_input_conv_3D_config, 164 | num_classes=10, 165 | device="cuda:0" 166 | ) 167 | 168 | 169 | multi_dense = [ 170 | 
(val_loader.dataset, True, False), 171 | (TensorDataset(torch.ones([10000, *conv_1D.in_dim])), True, False) 172 | ] 173 | 174 | m = MultiDataset(multi_dense) 175 | 176 | x = [ 177 | (TensorDataset(torch.ones([10000, *conv_1D.in_dim])), True, False), 178 | m, 179 | (val_loader.dataset, True, True), 180 | (TensorDataset(torch.ones([10000, *conv_3D.in_dim])), True, False), 181 | ] 182 | 183 | multi_dataset = MultiDataset(x) 184 | 185 | train_multi = torch.utils.data.Subset( 186 | multi_dataset, range(len(multi_dataset)//2)) 187 | val_multi = torch.utils.data.Subset( 188 | multi_dataset, range(len(multi_dataset)//2, len(multi_dataset))) 189 | 190 | train_loader_multi = DataLoader(train_multi, batch_size=100) 191 | val_loader_multi = DataLoader(val_multi, batch_size=100) 192 | 193 | multi_input_conv_3D.fit( 194 | train_loader_multi, 195 | val_loader_multi, 196 | epochs=3, 197 | plot=True, 198 | save_path="." 199 | ) 200 | multi_input_conv_3D.run_test(val_loader_multi, plot=True, save_path=".") 201 | multi_input_conv_3D.save_model() 202 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | 18 | sys.path.insert(0, os.path.abspath('../../')) 19 | sys.path.insert(0, os.path.abspath('../../vulcanai/')) 20 | sys.path.insert(0, os.path.abspath('../../vulcanai/datasets/')) 21 | sys.path.insert(0, os.path.abspath('../../vulcanai/models/')) 22 | sys.path.insert(0, os.path.abspath('../../vulcanai/plotters/')) 23 | 24 | import torch 25 | import torchvision 26 | # -- Project information ----------------------------------------------------- 27 | 28 | project = 'Vulcan' 29 | copyright = '2019, Robert Fratila, Priyatharsan Rajasekar, Caitrin Armstrong, Joseph Mehltretter' 30 | author = 'Robert Fratila, Priyatharsan Rajasekar, Caitrin Armstrong, Joseph Mehltretter' 31 | 32 | # The short X.Y version 33 | version = '' 34 | # The full version, including alpha/beta/rc tags 35 | release = '1.0' 36 | 37 | 38 | # -- General configuration --------------------------------------------------- 39 | 40 | # If your documentation needs a minimal Sphinx version, state it here. 41 | # 42 | # needs_sphinx = '1.0' 43 | 44 | # Add any Sphinx extension module names here, as strings. They can be 45 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 46 | # ones. 47 | extensions = [ 48 | 'sphinx.ext.autodoc', 49 | 'sphinx.ext.ifconfig', 50 | ] 51 | 52 | autodoc_default_options = { 53 | 'members': None, 54 | 'member-order': 'bysource', 55 | 'special-members': '__init__', 56 | 'undoc-members': None, 57 | 'exclude-members': None 58 | } 59 | # Add any paths that contain templates here, relative to this directory. 60 | templates_path = ['_templates'] 61 | 62 | # The suffix(es) of source filenames. 
63 | # You can specify multiple suffix as a list of string: 64 | # 65 | # source_suffix = ['.rst', '.md'] 66 | source_suffix = '.rst' 67 | 68 | # The master toctree document. 69 | master_doc = 'index' 70 | 71 | # The language for content autogenerated by Sphinx. Refer to documentation 72 | # for a list of supported languages. 73 | # 74 | # This is also used if you do content translation via gettext catalogs. 75 | # Usually you set "language" from the command line for these cases. 76 | language = None 77 | 78 | # List of patterns, relative to source directory, that match files and 79 | # directories to ignore when looking for source files. 80 | # This pattern also affects html_static_path and html_extra_path. 81 | exclude_patterns = [] 82 | 83 | # The name of the Pygments (syntax highlighting) style to use. 84 | pygments_style = None 85 | 86 | 87 | # -- Options for HTML output ------------------------------------------------- 88 | 89 | # The theme to use for HTML and HTML Help pages. See the documentation for 90 | # a list of builtin themes. 91 | # 92 | html_theme = 'classic' 93 | html_theme_options = { 94 | "stickysidebar": True, 95 | "sidebarbgcolor": "#344054", 96 | "sidebartextcolor": "#5df6c6", 97 | "sidebarlinkcolor": "white", 98 | "headtextcolor": "#344054", 99 | "textcolor": "#344054", 100 | "linkcolor": "#344054" 101 | } 102 | 103 | # Theme options are theme-specific and customize the look and feel of a theme 104 | # further. For a list of options available for each theme, see the 105 | # documentation. 106 | # 107 | # html_theme_options = {} 108 | 109 | # Add any paths that contain custom static files (such as style sheets) here, 110 | # relative to this directory. They are copied after the builtin static files, 111 | # so a file named "default.css" will overwrite the builtin "default.css". 112 | html_static_path = [] 113 | 114 | # Custom sidebar templates, must be a dictionary that maps document names 115 | # to template names. 116 | # 117 | # The default sidebars (for documents that don't match any pattern) are 118 | # defined by theme itself. Builtin themes are using these templates by 119 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 120 | # 'searchbox.html']``. 121 | # 122 | html_sidebars = { '**': ['globaltoc.html', 'relations.html', 'sourcelink.html', 'searchbox.html'] } 123 | 124 | 125 | # -- Options for HTMLHelp output --------------------------------------------- 126 | 127 | # Output file base name for HTML help builder. 128 | htmlhelp_basename = 'Vulcandoc' 129 | 130 | 131 | # -- Options for LaTeX output ------------------------------------------------ 132 | 133 | latex_elements = { 134 | # The paper size ('letterpaper' or 'a4paper'). 135 | # 136 | # 'papersize': 'letterpaper', 137 | 138 | # The font size ('10pt', '11pt' or '12pt'). 139 | # 140 | # 'pointsize': '10pt', 141 | 142 | # Additional stuff for the LaTeX preamble. 143 | # 144 | # 'preamble': '', 145 | 146 | # Latex figure (float) alignment 147 | # 148 | # 'figure_align': 'htbp', 149 | } 150 | 151 | # Grouping the document tree into LaTeX files. List of tuples 152 | # (source start file, target name, title, 153 | # author, documentclass [howto, manual, or own class]). 154 | latex_documents = [ 155 | (master_doc, 'Vulcan.tex', 'Vulcan Documentation', 156 | 'Robert Fratila, Priyatharsan Rajasekar, Caitrin Armstrong, Joseph Mehltretter', 'manual'), 157 | ] 158 | 159 | 160 | # -- Options for manual page output ------------------------------------------ 161 | 162 | # One entry per manual page. 
List of tuples 163 | # (source start file, name, description, authors, manual section). 164 | man_pages = [ 165 | (master_doc, 'vulcan', 'Vulcan Documentation', 166 | [author], 1) 167 | ] 168 | 169 | 170 | # -- Options for Texinfo output ---------------------------------------------- 171 | 172 | # Grouping the document tree into Texinfo files. List of tuples 173 | # (source start file, target name, title, author, 174 | # dir menu entry, description, category) 175 | texinfo_documents = [ 176 | (master_doc, 'Vulcan', 'Vulcan Documentation', 177 | author, 'Vulcan', 'One line description of project.', 178 | 'Miscellaneous'), 179 | ] 180 | 181 | 182 | # -- Options for Epub output ------------------------------------------------- 183 | 184 | # Bibliographic Dublin Core info. 185 | epub_title = project 186 | 187 | # The unique identifier of the text. This can be a ISBN number 188 | # or the project homepage. 189 | # 190 | # epub_identifier = '' 191 | 192 | # A unique identification for the text. 193 | # 194 | # epub_uid = '' 195 | 196 | # A list of files that should not be packed into the epub file. 197 | epub_exclude_files = ['search.html'] 198 | 199 | 200 | # -- Extension configuration ------------------------------------------------- 201 | -------------------------------------------------------------------------------- /vulcanai/models/dnn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """Defines the DenseNet class.""" 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .basenetwork import BaseNetwork 7 | from .layers import DenseUnit, FlattenUnit 8 | 9 | import logging 10 | from inspect import getfullargspec 11 | 12 | from collections import OrderedDict 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class DenseNetConfig: 18 | """Defines the necessary configuration for a DenseNet.""" 19 | 20 | def __init__(self, raw_config): 21 | """ 22 | Take in user config dict and clean it up. 
23 | 24 | Cleaned units are stored in self.units 25 | 26 | Parameters: 27 | raw_config : dict of dict 28 | User specified dict 29 | 30 | """ 31 | if 'dense_units' not in raw_config: 32 | raise KeyError("dense_units must be specified.") 33 | 34 | if not isinstance(raw_config['dense_units'], list): 35 | raise ValueError("dense_units must be of type list.") 36 | 37 | dense_unit_arg_spec = getfullargspec(DenseUnit) 38 | dense_unit_arg_spec.args.remove('self') 39 | 40 | # Find the index for where the defaulted values begin 41 | default_arg_start_index = len(dense_unit_arg_spec.args) - \ 42 | len(dense_unit_arg_spec.defaults) 43 | default_args = dense_unit_arg_spec.args[default_arg_start_index:] 44 | 45 | # Only look at args that were specified to be overwritten 46 | override_args = set(default_args).intersection(set(raw_config.keys())) 47 | # If arguments are specified in lists, 48 | # check that length corresponds to dense_units 49 | for arg in override_args: 50 | if isinstance(raw_config[arg], list): 51 | if len(raw_config[arg]) != len(raw_config['dense_units']): 52 | raise ValueError( 53 | "{} list not same length as dense_units.".format(arg)) 54 | else: 55 | # If a single value is specified, apply to all layers 56 | raw_config[arg] = [raw_config[arg]] * \ 57 | len(raw_config['dense_units']) 58 | 59 | # TODO: Think about moving dimension to config file 60 | _units_per_layer = list([None] + raw_config['dense_units']) 61 | _unit_pairs = list(zip(_units_per_layer[:-1], _units_per_layer[1:])) 62 | 63 | self.units = [] 64 | for i, (in_feature, out_feature) in enumerate(_unit_pairs): 65 | temp_unit = { 66 | 'in_features': in_feature, 67 | 'out_features': out_feature 68 | } 69 | for arg in override_args: 70 | temp_unit[arg] = raw_config[arg][i] 71 | self.units.append(temp_unit) 72 | 73 | 74 | # noinspection PyDefaultArgument,PyTypeChecker 75 | class DenseNet(BaseNetwork): 76 | """ 77 | Subclass of BaseNetwork defining a DenseNet. 78 | 79 | Parameters: 80 | name : str 81 | The name of the network. Used when saving the file. 82 | config : dict 83 | The configuration of the network module, as a dict. 84 | in_dim : tuple 85 | The input dimensions of the network. Not required to specify when 86 | the network has input_networks. 87 | save_path : str 88 | The name of the file to which you would like to save this network. 89 | input_networks : list of BaseNetwork 90 | A network object provided as input. 91 | num_classes : int or None 92 | The number of classes to predict. 93 | activation : torch.nn.Module 94 | The desired activation function for use in the network. 95 | pred_activation : torch.nn.Module 96 | The desired activation function for use in the prediction layer. 97 | optim_spec : dict 98 | A dictionary of parameters for the desired optimizer. 99 | lr_scheduler : torch.optim.lr_scheduler 100 | A callable torch.optim.lr_scheduler 101 | early_stopping : str or None 102 | So far just 'best_validation_error' is implemented. 103 | early_stopping_patience : int 104 | Number of validation iterations with non-improving loss 105 | (note - not necessarily every epoch!) 106 | before early stopping is applied. 107 | early_stopping_metric : str 108 | Either "loss" or "accuracy" are implemented. 109 | criter_spec : dict 110 | criterion specification with name and all its parameters. 111 | 112 | Returns: 113 | network : DenseNet 114 | A network of type BaseNetwork.
115 | 116 | """ 117 | 118 | def __init__(self, name, config, in_dim=None, save_path=None, 119 | input_networks=None, num_classes=None, 120 | activation=nn.ReLU(), pred_activation=None, 121 | optim_spec={'name': 'Adam', 'lr': 0.001}, 122 | lr_scheduler=None, early_stopping=None, 123 | early_stopping_patience=None, 124 | early_stopping_metric="accuracy", 125 | criter_spec=nn.CrossEntropyLoss(), 126 | device="cuda:0"): 127 | """Define the DenseNet object.""" 128 | super(DenseNet, self).__init__( 129 | name, DenseNetConfig(config), in_dim, save_path, input_networks, 130 | num_classes, activation, pred_activation, optim_spec, 131 | lr_scheduler, early_stopping, early_stopping_patience, 132 | early_stopping_metric, 133 | criter_spec, device) 134 | 135 | def _create_network(self, **kwargs): 136 | """ 137 | Build the layers of the network into a nn.Sequential object. 138 | 139 | Parameters: 140 | dense_hid_layers : DenseNetConfig.units (list of dict) 141 | The hidden layers specification 142 | activation : torch.nn.Module 143 | the non-linear activation to apply to each layer 144 | 145 | Returns: 146 | output : torch.nn.Sequential 147 | the dense network as a nn.Sequential object 148 | 149 | """ 150 | dense_hid_layers = self._config.units 151 | 152 | if self.input_networks: 153 | self.in_dim = self._get_in_dim() 154 | 155 | # Build network 156 | # Specify incoming feature size for the first dense hidden layer 157 | dense_hid_layers[0]['in_features'] = self.in_dim[0] 158 | dense_layers = OrderedDict() 159 | for idx, dense_layer_config in enumerate(dense_hid_layers): 160 | dense_layer_config['activation'] = kwargs['activation'] 161 | layer_name = 'dense_{}'.format(idx) 162 | dense_layers[layer_name] = DenseUnit(**dense_layer_config) 163 | self.network = nn.Sequential(dense_layers) 164 | 165 | if self.num_classes: 166 | self.network.add_module( 167 | 'classify', DenseUnit( 168 | in_features=self._get_out_dim()[0], 169 | out_features=self.num_classes, 170 | activation=kwargs['pred_activation'])) 171 | 172 | def _merge_input_network_outputs(self, tensors): 173 | output_tensors = [FlattenUnit()(t) for t in tensors] 174 | return torch.cat(output_tensors, dim=1) 175 | 176 | def __str__(self): 177 | if self.optim is not None: 178 | return super(DenseNet, self).__str__() + '\noptim: {}'\ 179 | .format(self.optim) 180 | else: 181 | return super(DenseNet, self).__str__() 182 | -------------------------------------------------------------------------------- /vulcanai/models/ensemble.py: -------------------------------------------------------------------------------- 1 | """Contains all ensemble models.""" 2 | from copy import deepcopy 3 | import logging 4 | from datetime import datetime 5 | import pickle 6 | import os 7 | 8 | import torch 9 | from torch import nn 10 | from torch.optim.lr_scheduler import CosineAnnealingLR 11 | 12 | from .basenetwork import BaseNetwork 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class SnapshotNet(BaseNetwork): 18 | """ 19 | Initialize snapshot ensemble given a template network. 20 | 21 | A wrapper class for any Network inheriting from BaseNetwork to 22 | train the template network using Snapshot Ensembling. 23 | 24 | Parameters: 25 | name : str 26 | String of snapshot ensemble name. 27 | template_network : BaseNetwork 28 | Network object which you want to ensemble. 29 | n_snapshots : int 30 | Number of snapshots in ensemble. 
31 | 32 | Returns: 33 | network : SnapshotNet 34 | 35 | """ 36 | 37 | # noinspection PyProtectedMember 38 | def __init__(self, name, template_network, n_snapshots=3): 39 | """Use Network to build model snapshots.""" 40 | # TODO: Should these be defaulted to the values of template_network? 41 | super(SnapshotNet, self).__init__( 42 | name=name, 43 | config=None, # template_network._config 44 | in_dim=template_network.in_dim, 45 | save_path=None, # template_network.save_path 46 | input_networks=None, # template_network.input_networks 47 | num_classes=template_network.num_classes, 48 | activation=None, # template_network.network[0]._activation 49 | pred_activation=None, # pred_activation 50 | optim_spec=None, # template_network._optim_spec 51 | lr_scheduler=None, # template_network.lr_scheduler 52 | early_stopping=None, # template_network.early_stopping 53 | criter_spec=None # template_network._criter_spec 54 | ) 55 | 56 | if not isinstance(template_network, BaseNetwork): 57 | raise ValueError( 58 | "template_network type must inherit from BaseNetwork.") 59 | 60 | self.template_network = deepcopy(template_network) 61 | self.network = nn.ModuleList() 62 | self.out_dim = self.template_network.out_dim 63 | if n_snapshots <= 0: 64 | raise ValueError("n_snapshots must be >=1.") 65 | self.n_snapshots = n_snapshots 66 | 67 | def fit(self, train_loader, val_loader, epochs, 68 | retain_graph=None, valid_interv=4, plot=False): 69 | """ 70 | Train each model for T/M epochs and control the network learning rate. 71 | 72 | Collects each model in a class variable self.network 73 | 74 | Parameters: 75 | train_loader : DataLoader 76 | Input data and targets to train against 77 | val_loader : DataLoader 78 | Input data and targets to validate against 79 | epochs : int 80 | Total number of epochs (evenly distributed between snapshots) 81 | 82 | Returns: 83 | None 84 | 85 | """ 86 | # There must be at least one train epoch for each snapshot 87 | if epochs < self.n_snapshots: 88 | logger.warning( 89 | 'Number of epochs too small for number of Snapshots. ' 90 | 'Setting epochs to {}.'.format(self.n_snapshots)) 91 | epochs = self.n_snapshots 92 | 93 | T = epochs 94 | # How many epochs each singular network should train for 95 | network_epochs = T // self.n_snapshots 96 | 97 | # Temporary but check if it first has an optimizer, 98 | # if not it will make one 99 | if self.template_network.optim is None: 100 | self.template_network._init_trainer() 101 | 102 | self.template_network.lr_scheduler = CosineAnnealingLR( 103 | optimizer=self.template_network.optim, 104 | T_max=network_epochs 105 | ) 106 | 107 | for index in range(self.n_snapshots): 108 | self.template_network.fit( 109 | train_loader=train_loader, 110 | val_loader=val_loader, 111 | epochs=network_epochs, 112 | valid_interv=valid_interv, 113 | plot=plot 114 | ) 115 | # Save instance of snapshot in a nn.ModuleList 116 | temp_network = deepcopy(self.template_network) 117 | self._update_network_name_stack( 118 | network=temp_network, 119 | append_str=index) 120 | self.network.append(temp_network) 121 | 122 | def _update_network_name_stack(self, network, append_str): 123 | """ 124 | Given a network, append a string to the name of all networks in stack. 125 | 126 | Recursively traverse each input network to update names 127 | with the appended string. 128 | 129 | Parameters: 130 | network : BaseNetwork 131 | Network stack to update names of with new append_str. 132 | append_str : int, str 133 | The characters to append at the end of BaseNetwork stack of 134 | names. 135 | 136 | """ 137 | if network.input_networks: 138 | for in_net in network.input_networks.values(): 139 | self._update_network_name_stack(in_net, append_str) 140 | network.name = "{}_{}".format(network.name, append_str) 141 |
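    # Example (added for illustration): with append_str=0, a template named 'dnn' whose input network is named 'conv' is saved as 'dnn_0' and 'conv_0' for the first snapshot, keeping every snapshot's network stack uniquely named.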
142 | def forward(self, inputs, **kwargs): 143 | """ 144 | Snapshot forward function. 145 | 146 | Collect outputs of all internal networks and average outputs. 147 | 148 | Parameters: 149 | inputs : torch.Tensor 150 | Input tensor to pass through self. 151 | 152 | Returns: 153 | output : torch.Tensor 154 | 155 | """ 156 | if len(self.network) == 0: 157 | raise ValueError("SnapshotNet needs to be trained.") 158 | 159 | pred_collector = [] 160 | for net in self.network: 161 | pred_collector.append(net(inputs)) 162 | # Stack outputs along a new 0 dimension to be averaged 163 | # Can take a list despite type checker complaining? 164 | # noinspection PyTypeChecker 165 | pred_collector = torch.stack(pred_collector) 166 | 167 | return torch.mean(input=pred_collector, dim=0) 168 | 169 | def save_model(self, save_path=None): 170 | """ 171 | Save all ensembled networks in a folder with the ensemble name. 172 | 173 | Parameters: 174 | save_path : str 175 | The folder path to save models in. 176 | 177 | Returns: 178 | None 179 | 180 | """ 181 | if not save_path: 182 | save_path = r"saved_models/" 183 | 184 | if not save_path.endswith("/"): 185 | save_path = save_path + "/" 186 | 187 | save_path = save_path + "{}_{}/".format( 188 | self.name, datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) 189 | logger.info("Saving ensemble to {}".format(save_path)) 190 | 191 | for network in self.network: 192 | logger.info("Saving network {}".format(network.name)) 193 | network.save_model(save_path=save_path) 194 | 195 | if not os.path.exists(save_path): 196 | os.makedirs(save_path) 197 | 198 | model_save_path = save_path + "model.pkl" 199 | self.save_path = save_path 200 | pickle.dump(self, open(model_save_path, "wb"), 2) 201 | -------------------------------------------------------------------------------- /vulcanai/datasets/fashion.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function 4 | import torch.utils.data as data 5 | from PIL import Image 6 | import os 7 | import os.path 8 | import errno 9 | import torch 10 | import codecs 11 | import torchvision.transforms as transforms 12 | import urllib.request 13 | 14 | 15 | # FROM https://raw.githubusercontent.com/mayurbhangale/ 16 | # fashion-mnist-pytorch/master/fashion.py 17 | class FashionData(data.Dataset): 18 | """Fashion-MNIST Dataset. 19 | 20 | Parameters: 21 | root (string): Root directory of dataset where 22 | ``processed/training.pt`` and ``processed/test.pt`` exist. 23 | train (bool, optional): If True, creates dataset from ``training.pt``, 24 | otherwise from ``test.pt``. 25 | download (bool, optional): If true, downloads the dataset from the 26 | internet and puts it in root directory. If dataset is already 27 | downloaded, it is not downloaded again. 28 | transform (callable, optional): A function/transform that takes in a 29 | PIL image and returns a transformed version. E.g., 30 | ``transforms.RandomCrop`` 31 | target_transform (callable, optional): A function/transform that 32 | takes in the target and transforms it.
33 | """ 34 | urls = [ 35 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz', 36 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz', 37 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz', 38 | 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz', 39 | ] 40 | raw_folder = 'raw' 41 | processed_folder = 'processed' 42 | training_file = 'training.pt' 43 | test_file = 'test.pt' 44 | 45 | def __init__(self, root, train=True, transform=None, target_transform=None, 46 | download=False): 47 | self.root = os.path.expanduser(root) 48 | self.target_transform = target_transform 49 | self.train = train # training set or test set 50 | 51 | # Only fall back to the default normalization when no transform is given. 52 | if transform is None: 53 | transform = transforms.Compose([transforms.ToTensor(), 54 | transforms.Normalize((0.1307,), (0.3081,))]) 55 | self.transform = transform 56 | if download: 57 | self.download() 58 | 59 | if not self._check_exists(): 60 | print("downloading") 61 | self.download() 62 | 63 | if self.train: 64 | self.train_data, self.train_labels = torch.load( 65 | os.path.join(root, self.processed_folder, self.training_file)) 66 | else: 67 | self.test_data, self.test_labels = torch.load(os.path.join( 68 | root, self.processed_folder, self.test_file)) 69 | 70 | def __getitem__(self, index): 71 | """ 72 | Args: 73 | index (int): Index 74 | 75 | Returns: 76 | tuple: (image, target) where target is index of the target class. 77 | """ 78 | if self.train: 79 | img, target = self.train_data[index], self.train_labels[index] 80 | else: 81 | img, target = self.test_data[index], self.test_labels[index] 82 | 83 | # doing this so that it is consistent with all other datasets 84 | # to return a PIL Image 85 | img = Image.fromarray(img.numpy(), mode='L') 86 | 87 | if self.transform is not None: 88 | img = self.transform(img) 89 | 90 | if self.target_transform is not None: 91 | target = self.target_transform(target) 92 | 93 | return img, target 94 | 95 | def __len__(self): 96 | if self.train: 97 | return len(self.train_data) 98 | else: 99 | return len(self.test_data) 100 | 101 | def _check_exists(self): 102 | return os.path.exists(os.path.join(self.root, self.processed_folder, 103 | self.training_file)) and \ 104 | os.path.exists(os.path.join(self.root, self.processed_folder, 105 | self.test_file)) 106 | 107 | # TODO: Need to fix.
File not found error before it downloads 108 | def download(self): 109 | """Download the MNIST data if it doesn't exist in processed_folder 110 | already.""" 111 | import gzip 112 | 113 | if self._check_exists(): 114 | return 115 | 116 | # download files 117 | try: 118 | os.makedirs(os.path.join(self.root, self.raw_folder)) 119 | os.makedirs(os.path.join(self.root, self.processed_folder)) 120 | except OSError as e: 121 | if e.errno == errno.EEXIST: 122 | pass 123 | else: 124 | raise 125 | 126 | for url in self.urls: 127 | print('Downloading ' + url) 128 | response = urllib.request.urlopen(url) 129 | #compressed_file = io.BytesIO(response.read()) 130 | filename = url.rpartition('/')[2] 131 | file_path = os.path.join(self.root, self.raw_folder, filename) 132 | with open(file_path.replace('.gz', ''), 'wb') as out_f, \ 133 | gzip.GzipFile(fileobj=response) as zip_f: 134 | out_f.write(zip_f.read()) 135 | #os.unlink(file_path) 136 | 137 | # process and save as torch files 138 | print('Processing...') 139 | 140 | training_set = ( 141 | read_image_file(os.path.join(self.root, self.raw_folder, 142 | 'train-images-idx3-ubyte')), 143 | read_label_file(os.path.join(self.root, self.raw_folder, 144 | 'train-labels-idx1-ubyte')) 145 | ) 146 | test_set = ( 147 | read_image_file(os.path.join(self.root, self.raw_folder, 148 | 't10k-images-idx3-ubyte')), 149 | read_label_file(os.path.join(self.root, self.raw_folder, 150 | 't10k-labels-idx1-ubyte')) 151 | ) 152 | with open(os.path.join(self.root, self.processed_folder, 153 | self.training_file), 'wb') as f: 154 | torch.save(training_set, f) 155 | with open(os.path.join(self.root, self.processed_folder, 156 | self.test_file), 'wb') as f: 157 | torch.save(test_set, f) 158 | 159 | print('Done!') 160 | 161 | def get_int(b): 162 | return int(codecs.encode(b, 'hex'), 16) 163 | 164 | def parse_byte(b): 165 | if isinstance(b, str): 166 | return ord(b) 167 | return b 168 | 169 | def read_label_file(path): 170 | with open(path, 'rb') as f: 171 | data_in = f.read() 172 | assert get_int(data_in[:4]) == 2049 173 | length = get_int(data_in[4:8]) 174 | labels = [parse_byte(b) for b in data_in[8:]] 175 | assert len(labels) == length 176 | return torch.LongTensor(labels) 177 | 178 | def read_image_file(path): 179 | with open(path, 'rb') as f: 180 | data_in = f.read() 181 | assert get_int(data_in[:4]) == 2051 182 | length = get_int(data_in[4:8]) 183 | num_rows = get_int(data_in[8:12]) 184 | num_cols = get_int(data_in[12:16]) 185 | images = [] 186 | idx = 16 187 | for l in range(length): 188 | img = [] 189 | images.append(img) 190 | for r in range(num_rows): 191 | row = [] 192 | img.append(row) 193 | for c in range(num_cols): 194 | row.append(parse_byte(data_in[idx])) 195 | idx += 1 196 | assert len(images) == length 197 | return torch.ByteTensor(images).view(-1, 28, 28) 198 | -------------------------------------------------------------------------------- /vulcanai/tests/plotters/test_visualization.py: -------------------------------------------------------------------------------- 1 | """Will test visualization functions.""" 2 | import pytest 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from torch.utils.data import TensorDataset, DataLoader 7 | from vulcanai.datasets import tabular_data_utils 8 | import os 9 | from copy import deepcopy 10 | 11 | from sklearn.datasets import load_digits 12 | from sklearn.metrics import confusion_matrix 13 | from vulcanai.plotters.visualization import compute_saliency_map, \ 14 | display_pca, \ 15 | display_tsne, \ 16 | 
display_receptive_fields, \ 17 | display_saliency_overlay, \ 18 | display_confusion_matrix 19 | 20 | 21 | class TestVisualization: 22 | """Test all visualization functionality.""" 23 | 24 | @pytest.fixture 25 | def cnn_class(self): 26 | """Create ConvNet with classes fixture.""" 27 | from vulcanai.models.cnn import ConvNet 28 | return ConvNet( 29 | name='Test_ConvNet_class', 30 | in_dim=(1, 28, 28), 31 | config={ 32 | 'conv_units': [ 33 | { 34 | "in_channels": 1, 35 | "out_channels": 16, 36 | "kernel_size": (5, 5), 37 | "stride": 1, 38 | "padding": 2 39 | }, 40 | { 41 | "in_channels": 16, 42 | "out_channels": 1, 43 | "kernel_size": (5, 5), 44 | "stride": 1, 45 | "padding": 2 46 | }] 47 | }, 48 | num_classes=3 49 | ) 50 | 51 | @pytest.fixture 52 | def dnn_class(self): 53 | """Create DenseNet with classes fixture.""" 54 | from vulcanai.models.dnn import DenseNet 55 | return DenseNet( 56 | name='Test_DenseNet_class', 57 | in_dim=(200), 58 | config={ 59 | 'dense_units': [100, 50], 60 | 'dropouts': 0.3, 61 | }, 62 | num_classes=3 63 | ) 64 | 65 | @pytest.fixture 66 | def dnn_class_two(self): 67 | """Create DenseNet with no prediction layer.""" 68 | from vulcanai.models.dnn import DenseNet 69 | return DenseNet( 70 | name='Test_DenseNet_class', 71 | in_dim=(200), 72 | activation=torch.nn.SELU(), 73 | num_classes=2, 74 | input_networks=None, 75 | config={ 76 | 'dense_units': [100], 77 | 'dropout': [0.3], 78 | 'initializer': None, 79 | 'bias_init': None, 80 | 'norm': None 81 | }, 82 | optim_spec={'name': 'Adam', 'lr': 0.001} 83 | ) 84 | 85 | def test_compute_saliency_map_cnn(self, cnn_class): 86 | """Confirm hooks are removed, and gradient shape.""" 87 | test_input_1B = torch.ones([1, *cnn_class.in_dim]) 88 | test_input_5B = torch.ones([5, *cnn_class.in_dim]) 89 | 90 | model_copy = deepcopy(cnn_class) 91 | # Test shape conservation 92 | cnn_class.freeze(apply_inputs=False) 93 | sal_map_1B = compute_saliency_map( 94 | cnn_class, 95 | test_input_1B, torch.LongTensor([2])) 96 | for sal_map, test_input in zip(sal_map_1B, [test_input_1B]): 97 | assert sal_map.shape == test_input.shape 98 | 99 | cnn_class.unfreeze(apply_inputs=False) 100 | sal_map_5B = compute_saliency_map( 101 | cnn_class, test_input_5B, 102 | torch.LongTensor([0, 2, 1, 1, 0])) 103 | for sal_map, test_input in zip(sal_map_5B, [test_input_5B]): 104 | assert sal_map.shape == test_input.shape 105 | 106 | # Check that all gradients are not 0 107 | assert ~np.all(sal_map_5B == 0.) 108 | 109 | # Test hook removal 110 | assert cnn_class._backward_hooks == model_copy._backward_hooks 111 | 112 | def test_compute_saliency_map_dnn(self, dnn_class): 113 | """Confirm hooks are removed, and gradient shape.""" 114 | test_input_1B = torch.ones([1, *dnn_class.in_dim]) 115 | test_input_5B = torch.ones([5, *dnn_class.in_dim]) 116 | model_copy = deepcopy(dnn_class) 117 | # Test shape conservation 118 | dnn_class.freeze(apply_inputs=False) 119 | sal_map_1B = compute_saliency_map( 120 | dnn_class, 121 | test_input_1B, torch.LongTensor([2])) 122 | for sal_map, test_input in zip(sal_map_1B, [test_input_1B]): 123 | assert sal_map.shape == test_input.shape 124 | 125 | dnn_class.unfreeze(apply_inputs=False) 126 | sal_map_5B = compute_saliency_map( 127 | dnn_class, test_input_5B, 128 | torch.LongTensor([0, 2, 1, 1, 0])) 129 | for sal_map, test_input in zip(sal_map_5B, [test_input_5B]): 130 | assert sal_map.shape == test_input.shape 131 | 132 | # Check that all gradients are not 0 133 | assert ~np.all(sal_map_5B == 0.) 
134 | 135 | # Test hook removal 136 | assert dnn_class._backward_hooks == model_copy._backward_hooks 137 | 138 | def test_display_saliency_overlay(self, dnn_class): 139 | """Test saliency overlay displays and saves.""" 140 | curr_path = str(os.path.dirname(__file__)) + '/' 141 | img = np.zeros((25,25)) 142 | test_input_1B = torch.ones([1, *dnn_class.in_dim]) 143 | sal_map_1B = np.array(compute_saliency_map( 144 | dnn_class, 145 | test_input_1B, torch.LongTensor([2]))) 146 | display_saliency_overlay(img, sal_map_1B, shape=(25,25), save_path=curr_path) 147 | file_created = False 148 | for file in os.listdir(curr_path): 149 | if file.startswith('saliency'): 150 | file_created = True 151 | file_path = curr_path + file 152 | os.remove(file_path) 153 | assert file_created 154 | 155 | def test_display_pca(self, dnn_class): 156 | """Test PCA displays and saves.""" 157 | curr_path = str(os.path.dirname(__file__)) + '/' 158 | digits = load_digits() 159 | display_pca( 160 | digits.data[0:10], digits.target[0:10], save_path=curr_path) 161 | file_created = False 162 | for file in os.listdir(curr_path): 163 | if file.startswith('PCA'): 164 | file_created = True 165 | file_path = curr_path + file 166 | os.remove(file_path) 167 | assert file_created 168 | 169 | def test_display_tsne(self, dnn_class): 170 | """Test t-SNE displays and saves.""" 171 | curr_path = str(os.path.dirname(__file__)) + '/' 172 | digits = load_digits() 173 | display_tsne( 174 | digits.data[0:10], digits.target[0:10], save_path=curr_path) 175 | file_created = False 176 | for file in os.listdir(curr_path): 177 | if file.startswith('t-SNE'): 178 | file_created = True 179 | file_path = curr_path + file 180 | os.remove(file_path) 181 | assert file_created 182 | 183 | def test_display_confusion_matrix(self): 184 | """Test confusion matrix displays and saves.""" 185 | curr_path = str(os.path.dirname(__file__)) + '/' 186 | cm = confusion_matrix(y_true=[0, 1, 1, 0, 0], y_pred=[1, 1, 0, 1, 0]) 187 | display_confusion_matrix(cm, class_list=[0, 1], save_path=curr_path) 188 | file_created = False 189 | for file in os.listdir(curr_path): 190 | if file.startswith('confusion'): 191 | file_created = True 192 | file_path = curr_path + file 193 | os.remove(file_path) 194 | assert file_created 195 | 196 | def test_receptive_field(self, dnn_class_two): 197 | """Test receptive field visualization gets created and can be saved.""" 198 | curr_path = str(os.path.dirname(__file__)) + '/' 199 | test_input = torch.ones([10, *dnn_class_two.in_dim]).float() 200 | test_target = torch.LongTensor([0, 1, 1, 0, 0, 1, 0, 1, 0, 1]) 201 | test_dataloader = DataLoader(TensorDataset(test_input, test_target)) 202 | 203 | dnn_class_two.fit(test_dataloader, test_dataloader, 5) 204 | display_receptive_fields(dnn_class_two, save_path=curr_path) 205 | file_created = False 206 | for file in os.listdir(curr_path): 207 | if file.startswith('feature_importance'): 208 | file_created = True 209 | file_path = curr_path + file 210 | os.remove(file_path) 211 | assert file_created 212 | -------------------------------------------------------------------------------- /vulcanai/tests/models/conftest.py: -------------------------------------------------------------------------------- 1 | """Specify dummy networks to test vulcan functionality.""" 2 | import pytest 3 | 4 | import torch 5 | from torch.utils.data import TensorDataset 6 | from vulcanai.datasets import MultiDataset 7 | from vulcanai.models import ConvNet, DenseNet 8 | from torch.utils.data import DataLoader, Subset 9 | 10 | 11 | 
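# Note (added for clarity): scope="module" fixtures below are built once per test module and shared across its tests, so tests should avoid mutating these networks in place.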
@pytest.fixture(scope="module") 12 | def conv1D_net(): 13 | """conv1D fixture.""" 14 | return ConvNet( 15 | name='conv1D_net', 16 | in_dim=(1, 28), 17 | config={ 18 | 'conv_units': [ 19 | dict( 20 | in_channels=1, 21 | out_channels=24, 22 | kernel_size=(5), 23 | stride=2, 24 | pool_size=2, 25 | dropout=0.1 26 | ), 27 | dict( 28 | in_channels=24, 29 | out_channels=64, 30 | kernel_size=(5), 31 | pool_size=2, 32 | dropout=0.1 33 | ) 34 | ], 35 | }, 36 | device='cpu' 37 | ) 38 | 39 | 40 | @pytest.fixture(scope="module") 41 | def conv2D_net(): 42 | """conv2D fixture.""" 43 | return ConvNet( 44 | name='conv2D_net', 45 | in_dim=(1, 28, 28), 46 | config={ 47 | 'conv_units': [ 48 | dict( 49 | in_channels=1, 50 | out_channels=24, 51 | kernel_size=(5, 5), 52 | stride=2, 53 | pool_size=2, 54 | dropout=0.1 55 | ), 56 | dict( 57 | in_channels=24, 58 | out_channels=64, 59 | kernel_size=(5, 5), 60 | pool_size=2, 61 | dropout=0.1 62 | ) 63 | ], 64 | }, 65 | device='cpu' 66 | ) 67 | 68 | 69 | @pytest.fixture(scope="module") 70 | def conv3D_net(): 71 | """conv3D fixture.""" 72 | return ConvNet( 73 | name='conv3D_net', 74 | in_dim=(1, 28, 28, 28), 75 | config={ 76 | 'conv_units': [ 77 | dict( 78 | in_channels=1, 79 | out_channels=16, 80 | kernel_size=(5, 5, 5), 81 | stride=2, 82 | dropout=0.1 83 | ), 84 | dict( 85 | in_channels=16, 86 | out_channels=64, 87 | kernel_size=(5, 5, 5), 88 | dropout=0.1 89 | ) 90 | ], 91 | }, 92 | device='cpu' 93 | ) 94 | 95 | 96 | @pytest.fixture(scope="module") 97 | def conv3D_net_class(): 98 | """conv3D fixture.""" 99 | return ConvNet( 100 | name='conv3D_net', 101 | in_dim=(1, 28, 28, 28), 102 | num_classes=10, 103 | config={ 104 | 'conv_units': [ 105 | dict( 106 | in_channels=1, 107 | out_channels=16, 108 | kernel_size=(5, 5, 5), 109 | stride=2, 110 | dropout=0.1 111 | ), 112 | dict( 113 | in_channels=16, 114 | out_channels=64, 115 | kernel_size=(5, 5, 5), 116 | dropout=0.1 117 | ) 118 | ], 119 | }, 120 | device='cpu' 121 | ) 122 | 123 | @pytest.fixture(scope="module") 124 | def conv3D_net_class_early_stopping(): 125 | """conv3D fixture.""" 126 | return ConvNet( 127 | name='conv3D_net', 128 | in_dim=(1, 28, 28, 28), 129 | num_classes=10, 130 | early_stopping="best_validation_error", 131 | early_stopping_patience=2, 132 | config={ 133 | 'conv_units': [ 134 | dict( 135 | in_channels=1, 136 | out_channels=16, 137 | kernel_size=(5, 5, 5), 138 | stride=2, 139 | dropout=0.1 140 | ), 141 | dict( 142 | in_channels=16, 143 | out_channels=64, 144 | kernel_size=(5, 5, 5), 145 | dropout=0.1 146 | ) 147 | ], 148 | }, 149 | device='cpu' 150 | ) 151 | 152 | @pytest.fixture(scope="module") 153 | def conv3D_net_class_single_value(): 154 | """conv3D fixture.""" 155 | return ConvNet( 156 | name='conv3D_net', 157 | in_dim=(1, 28, 28, 28), 158 | num_classes=1, 159 | config={ 160 | 'conv_units': [ 161 | dict( 162 | in_channels=1, 163 | out_channels=16, 164 | kernel_size=(5, 5, 5), 165 | stride=2, 166 | dropout=0.1 167 | ), 168 | dict( 169 | in_channels=16, 170 | out_channels=64, 171 | kernel_size=(5, 5, 5), 172 | dropout=0.1 173 | ) 174 | ], 175 | }, 176 | device='cpu' 177 | ) 178 | 179 | 180 | @pytest.fixture(scope="module") 181 | def dnn_noclass(): 182 | """DenseNet fixture.""" 183 | return DenseNet( 184 | name='dnn_noclass', 185 | in_dim=(200), 186 | config={ 187 | 'dense_units': [100, 50], 188 | 'dropout': [0.3, 0.5], 189 | } 190 | ) 191 | 192 | 193 | @pytest.fixture(scope="module") 194 | def dnn_class(): 195 | """DenseNet with prediction layer.""" 196 | return DenseNet( 197 | name='dnn_class', 
198 | in_dim=(200), 199 | config={ 200 | 'dense_units': [100, 50], 201 | 'dropout': 0.5, 202 | }, 203 | num_classes=3 204 | ) 205 | 206 | 207 | @pytest.fixture(scope="module") 208 | def dnn_class_early_stopping(): 209 | """DenseNet with prediction layer.""" 210 | return DenseNet( 211 | name='dnn_class', 212 | in_dim=(200), 213 | early_stopping="best_validation_error", 214 | early_stopping_patience=2, 215 | config={ 216 | 'dense_units': [100, 50], 217 | 'dropout': 0.5, 218 | }, 219 | num_classes=3 220 | ) 221 | 222 | 223 | @pytest.fixture(scope="module") 224 | def dnn_class_single_value(): 225 | """DenseNet with prediction layer.""" 226 | return DenseNet( 227 | name='dnn_class', 228 | in_dim=(200), 229 | config={ 230 | 'dense_units': [100, 50], 231 | 'dropout': 0.5, 232 | }, 233 | num_classes=1 234 | ) 235 | 236 | @pytest.fixture(scope="module") 237 | def multi_input_dnn(conv1D_net, conv2D_net): 238 | """Dense network fixture with two inputs.""" 239 | return DenseNet( 240 | name='multi_input_dnn', 241 | input_networks=[conv1D_net, conv2D_net], 242 | config={ 243 | 'dense_units': [100, 50], 244 | 'initializer': None, 245 | 'bias_init': None, 246 | 'norm': None, 247 | 'dropout': 0.5, # Single value or List 248 | }, 249 | device='cpu' 250 | ) 251 | 252 | 253 | @pytest.fixture(scope="module") 254 | def multi_input_cnn(conv2D_net, conv3D_net, multi_input_dnn): 255 | """Bottom multi-input network fixture.""" 256 | return ConvNet( 257 | name='multi_input_cnn', 258 | input_networks=[conv2D_net, conv3D_net, multi_input_dnn], 259 | num_classes=10, 260 | config={ 261 | 'conv_units': [ 262 | dict( 263 | in_channels=1, 264 | out_channels=16, 265 | kernel_size=(3, 3, 3), 266 | stride=2, 267 | dropout=0.1 268 | ), 269 | ], 270 | }, 271 | device='cpu' 272 | ) 273 | 274 | 275 | @pytest.fixture(scope="module") 276 | def multi_input_dnn_class(conv1D_net, conv2D_net): 277 | """Dense network fixture with two inputs.""" 278 | return DenseNet( 279 | name='multi_input_dnn_class', 280 | input_networks=[conv1D_net, conv2D_net], 281 | num_classes=10, 282 | config={ 283 | 'dense_units': [100, 50], 284 | 'initializer': None, 285 | 'bias_init': None, 286 | 'norm': None, 287 | 'dropout': 0.5, # Single value or List 288 | }, 289 | device='cpu' 290 | ) 291 | 292 | 293 | @pytest.fixture(scope="module") 294 | def multi_input_dnn_data(conv1D_net, conv2D_net, 295 | multi_input_dnn): 296 | return MultiDataset([ 297 | ( 298 | TensorDataset( 299 | torch.rand(size=[10, *conv1D_net.in_dim]), 300 | torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).long()), 301 | True, True), 302 | ( 303 | TensorDataset( 304 | torch.rand( 305 | size=[ 306 | 10, 307 | *multi_input_dnn.input_networks['conv2D_net']. 
308 | in_dim])), 309 | True, False) 310 | ]) 311 | 312 | 313 | @pytest.fixture(scope="module") 314 | def multi_input_cnn_data(conv2D_net, conv3D_net, multi_input_dnn_data): 315 | return MultiDataset([ 316 | (TensorDataset(torch.rand(size=[10, *conv2D_net.in_dim])), True, False), 317 | (TensorDataset(torch.rand(size=[10, *conv3D_net.in_dim])), True, False), 318 | multi_input_dnn_data 319 | ]) 320 | 321 | 322 | @pytest.fixture(scope="module") 323 | def dnn_class_multi_value(): 324 | """DenseNet with prediction layer and multiple classes.""" 325 | return DenseNet( 326 | name='dnn_class', 327 | in_dim=(12), 328 | config={ 329 | 'dense_units': [100, 50], 330 | 'dropout': 0.5, 331 | }, 332 | num_classes=3 333 | ) 334 | 335 | 336 | @pytest.fixture(scope="module") 337 | def sensitivity_data_loader(): 338 | torch.manual_seed(7) 339 | test_input = torch.rand(size=[5, 12]) 340 | test_dataloader = DataLoader(TensorDataset(test_input, torch.tensor( 341 | [0, 1, 2, 0, 1]))) 342 | return test_dataloader 343 | -------------------------------------------------------------------------------- /vulcanai/models/layers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """Define the ConvUnit and DenseUnit.""" 3 | import torch 4 | import torch.nn as nn 5 | import logging 6 | from .utils import selu_bias_init_, selu_weight_init_ 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class BaseUnit(nn.Sequential): 11 | """ 12 | The base class for all layers. 13 | 14 | Parameters: 15 | weight_init : torch.nn.init 16 | Torch initialization function for weights. 17 | bias_init : int or float 18 | A constant int or float to initialize biases with. 19 | norm : str 20 | 'batch' for batch norm or 'instance' for instance norm. 21 | dropout : float between 0-1 22 | The probability of dropping out a feature during training. 23 | 24 | Returns: 25 | base_unit : torch.nn.Sequential 26 | A base layer module. 27 | 28 | """ 29 | 30 | def __init__(self, weight_init=None, bias_init=None, 31 | norm=None, dropout=None): 32 | """Initialize a base unit.""" 33 | super(BaseUnit, self).__init__() 34 | 35 | self.weight_init = weight_init 36 | self.bias_init = bias_init 37 | self.norm = norm 38 | self.dropout = dropout 39 | 40 | self.in_shape = None # [self.batch_size, *in_shape] 41 | self.out_shape = None # [self.batch_size, *out_shape] 42 | self.in_bound_layers = [] 43 | self.out_bound_layers = [] 44 | 45 | self._kernel = None 46 | 47 | def _init_weights(self): 48 | """ 49 | Initialize the weights. 50 | 51 | If self.weight_init is None, then pytorch default weight 52 | will be assigned to the kernel. 53 | """ 54 | if self.weight_init: 55 | self.weight_init(self._kernel.weight) 56 | 57 | def _init_bias(self): 58 | """ 59 | Initialize the bias. 60 | 61 | If self.bias_init is None, then pytorch default bias 62 | will be assigned to the kernel. 63 | """ 64 | if self.bias_init: 65 | self.bias_init(self._kernel.bias) 66 | 67 | 68 | class FlattenUnit(BaseUnit): 69 | """ 70 | Layer to flatten the input. 71 | 72 | Returns: 73 | flatten_unit : torch.Sequential 74 | A flatten layer. 75 | 76 | """ 77 | 78 | def __init__(self): 79 | """Initialize flatten layer.""" 80 | super(FlattenUnit, self).__init__() 81 | 82 | def forward(self, x): 83 | """Maintain batch size but flatten all remaining dimensions.""" 84 | return x.view(x.shape[0], -1) 85 | 86 |
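# Example (added for illustration): FlattenUnit keeps the batch dimension and collapses the rest, so an input of shape (batch, 16, 4, 4) becomes (batch, 256).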
87 | # noinspection PyUnresolvedReferences 88 | class DenseUnit(BaseUnit): 89 | """ 90 | Define the DenseUnit object. 91 | 92 | Parameters: 93 | in_features : int 94 | The incoming feature size of a sample. 95 | out_features : int 96 | The number of hidden Linear units for this layer. 97 | weight_init : torch.nn.init 98 | Torch initialization function. 99 | bias_init : int or float 100 | A constant int or float to initialize biases with. 101 | norm : str 102 | 'batch' for batch norm or 'instance' for instance norm. 103 | activation : torch.nn.Module 104 | An activation function to apply after the Linear unit. 105 | dropout : float between 0-1 106 | The probability of dropping out a feature during training. 107 | 108 | Returns: 109 | dense_unit : torch.nn.Sequential 110 | A single fully connected layer. 111 | 112 | """ 113 | 114 | def __init__(self, in_features, out_features, 115 | weight_init=None, bias_init=None, 116 | norm=None, activation=None, dropout=None): 117 | """Initialize a single DenseUnit (i.e. a dense layer).""" 118 | super(DenseUnit, self).__init__(weight_init, bias_init, 119 | norm, dropout) 120 | self.in_features = in_features 121 | self.out_features = out_features 122 | # Main layer 123 | self._kernel = nn.Linear( 124 | in_features=self.in_features, 125 | out_features=self.out_features) 126 | self.add_module('_kernel', self._kernel) 127 | 128 | # Norm 129 | if self.norm is not None: 130 | if self.norm == 'batch': 131 | self.add_module( 132 | '_norm', 133 | torch.nn.BatchNorm1d(self.out_features)) 134 | elif self.norm == 'instance': 135 | self.add_module( 136 | '_norm', 137 | torch.nn.InstanceNorm1d(self.out_features)) 138 | 139 | # Activation/Non-Linearity 140 | if activation is not None: 141 | self.add_module('_activation', activation) 142 | if isinstance(activation, nn.SELU): 143 | self.weight_init = selu_weight_init_ 144 | self.bias_init = selu_bias_init_ 145 | 146 | # Dropout 147 | if self.dropout is not None: 148 | if isinstance(activation, nn.SELU): 149 | self.add_module( 150 | '_dropout', nn.AlphaDropout(self.dropout)) 151 | else: 152 | self.add_module( 153 | '_dropout', nn.Dropout(self.dropout)) 154 | self._init_weights() 155 | self._init_bias() 156 | 157 | 158 | # TODO: Automatically calculate padding to be the same as input shape. 159 | class ConvUnit(BaseUnit): 160 | """ 161 | Define the ConvUnit object. 162 | 163 | Parameters: 164 | conv_dim : int 165 | 1, 2, or 3 representing spatial dimensional inputs. 166 | in_channels : int 167 | The number of incoming channels. 168 | out_channels : int 169 | The number of convolution kernels for this layer. 170 | kernel_size : int or tuple 171 | The size of the 1, 2, or 3 dimensional convolution kernel. 172 | weight_init : torch.nn.init 173 | Torch initialization function. 174 | bias_init : int or float 175 | A constant int or float to initialize biases with. 176 | stride : int or tuple 177 | The stride of the 1, 2, or 3 dimensional convolution kernel. 178 | padding : int 179 | Amount of zero-padding on both sides per dimension. 180 | norm : str 181 | 'batch' for batch norm or 'instance' for instance norm. 182 | activation : torch.nn.Module 183 | An activation function to apply after the convolution unit. 184 | pool_size : int 185 | Max pooling by a factor of pool_size in each dimension. 186 | dropout : float between 0-1 187 | The probability of dropping out a feature during training. 188 | 189 | Returns: 190 | conv_unit : torch.nn.Sequential 191 | A single convolution layer.
192 | 193 | """ 194 | 195 | def __init__(self, conv_dim, in_channels, out_channels, kernel_size, 196 | weight_init=None, bias_init=None, 197 | stride=1, padding=0, norm=None, 198 | activation=None, pool_size=None, dropout=None): 199 | """Initialize a single ConvUnit (i.e. a conv layer).""" 200 | super(ConvUnit, self).__init__(weight_init, bias_init, 201 | norm, dropout) 202 | self.conv_dim = conv_dim 203 | self._init_layers() 204 | 205 | self.in_channels = in_channels 206 | self.out_channels = out_channels 207 | self.kernel_size = kernel_size 208 | 209 | # Main layer 210 | self._kernel = self.conv_layer( 211 | in_channels=self.in_channels, 212 | kernel_size=self.kernel_size, 213 | out_channels=self.out_channels, 214 | stride=stride, 215 | padding=padding, 216 | bias=True 217 | ) 218 | self.add_module('_kernel', self._kernel) 219 | 220 | # Norm 221 | if self.norm: 222 | if self.norm == 'batch': 223 | self.add_module( 224 | '_norm', 225 | self.batch_norm(num_features=self.out_channels)) 226 | elif self.norm == 'instance': 227 | self.add_module( 228 | '_norm', 229 | self.instance_norm(num_features=self.out_channels)) 230 | 231 | # Activation/Non-Linearity 232 | if activation is not None: 233 | self.add_module('_activation', activation) 234 | if isinstance(activation, nn.SELU): 235 | self.weight_init = selu_weight_init_ 236 | self.bias_init = selu_bias_init_ 237 | # Pool 238 | if pool_size is not None: 239 | self.add_module( 240 | '_pool', self.pool_layer(kernel_size=pool_size)) 241 | 242 | # Dropout 243 | if self.dropout is not None: 244 | if isinstance(activation, nn.SELU): 245 | self.add_module( 246 | '_dropout', nn.AlphaDropout(self.dropout)) 247 | else: 248 | self.add_module( 249 | '_dropout', self.dropout_layer(self.dropout)) 250 | self._init_weights() 251 | self._init_bias() 252 | 253 | def _init_layers(self): 254 | if self.conv_dim == 1: 255 | self.conv_layer = nn.Conv1d 256 | self.batch_norm = nn.BatchNorm1d 257 | self.instance_norm = nn.InstanceNorm1d 258 | self.pool_layer = nn.MaxPool1d 259 | self.dropout_layer = nn.Dropout 260 | elif self.conv_dim == 2: 261 | self.conv_layer = nn.Conv2d 262 | self.batch_norm = nn.BatchNorm2d 263 | self.instance_norm = nn.InstanceNorm2d 264 | self.pool_layer = nn.MaxPool2d 265 | self.dropout_layer = nn.Dropout2d 266 | elif self.conv_dim == 3: 267 | self.conv_layer = nn.Conv3d 268 | self.batch_norm = nn.BatchNorm3d 269 | self.instance_norm = nn.InstanceNorm3d 270 | self.pool_layer = nn.MaxPool3d 271 | self.dropout_layer = nn.Dropout3d 272 | else: 273 | self.conv_layer = None 274 | self.batch_norm = None 275 | self.instance_norm = None 276 | self.pool_layer = None 277 | self.dropout_layer = None 278 | raise ValueError( 279 | "Convolution is only supported for" 280 | " one of the first three dimensions.") 281 | 282 | def get_conv_output_size(self): 283 | """Calculate the size of the flattened features after conv.""" 284 | with torch.no_grad(): 285 | x = torch.ones(1, *self.in_dim) 286 | x = self.conv_model(x) 287 | return x.numel() -------------------------------------------------------------------------------- /vulcanai/tests/datasets/test_tabulardataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ Defines test cases for tabular dataset """ 3 | import pytest 4 | from vulcanai.datasets import tabular_data_utils 5 | import os 6 | import pandas as pd 7 | import torch 8 | import numpy as np 9 | 10 | # noinspection PyMissingOrEmptyDocstring 11 | class TestTabularDataset: 12 | 
@pytest.fixture 13 | def split_data(self): 14 | pass 15 | 16 | @pytest.fixture 17 | def my_test_dataset(self): 18 | """Create a dataset by importing from the test csv""" 19 | fpath = str(os.path.dirname(__file__)) + \ 20 | "/test_data/birthweight_reduced.csv" 21 | # Nan just an artifact of this dataset. 22 | return pd.read_csv(fpath, na_values='Nan') 23 | 24 | def test_convert_to_tensor_dataset(self, my_test_dataset): 25 | 26 | only_numeric = my_test_dataset.drop("LowBirthWeight", axis=1) 27 | res1 = tabular_data_utils.convert_to_tensor_datasets(only_numeric, 28 | target_vars="mnocig" 29 | ) 30 | res2 = tabular_data_utils.convert_to_tensor_datasets(only_numeric, 31 | target_vars="mnocig", 32 | continuous_target=True) 33 | 34 | assert isinstance(res1, torch.utils.data.TensorDataset) 35 | assert isinstance(res2, torch.utils.data.TensorDataset) 36 | 37 | def test_create_label_encoding(self, my_test_dataset): 38 | tabular_data_utils.create_label_encoding(my_test_dataset, 39 | "LowBirthWeight", 40 | {"Low": 0, "Normal": 1}) 41 | assert set(list(my_test_dataset["LowBirthWeight"].unique())) \ 42 | == {0, 1} 43 | 44 | def test_create_one_hot_encoding(self, my_test_dataset): 45 | res = tabular_data_utils.create_one_hot_encoding(my_test_dataset, 46 | "LowBirthWeight") 47 | assert "LowBirthWeight@Low" in set(list(res.columns)) 48 | 49 | def test_reverse_create_all_one_hot_encodings(self, my_test_dataset): 50 | tabular_data_utils.create_one_hot_encoding(my_test_dataset, 51 | "LowBirthWeight") 52 | res = tabular_data_utils.reverse_create_one_hot_encoding(my_test_dataset, 53 | prefix_sep="@") 54 | assert "LowBirthWeight@Low" not in set(list(res.columns)) 55 | 56 | def test_identify_null(self, my_test_dataset): 57 | num_threshold = 0.2 58 | res = tabular_data_utils.identify_null(my_test_dataset, num_threshold) 59 | assert {'fedyrs'} == set(res) 60 | 61 | def test_identify_unique(self, my_test_dataset): 62 | res = tabular_data_utils.identify_unique(my_test_dataset, 5) 63 | assert set(res) == {'headcirumference', 'smoker', 'fedyrs', 64 | 'LowBirthWeight'} 65 | 66 | def test_identify_unbalanced_columns(self, my_test_dataset): 67 | res = tabular_data_utils.identify_unbalanced_columns(my_test_dataset, 68 | 0.5) 69 | assert set(res) == {'headcirumference', 'smoker', 'mnocig', 70 | 'LowBirthWeight'} 71 | 72 | def test_identify_highly_correlated(self, my_test_dataset): 73 | res = tabular_data_utils.identify_highly_correlated(my_test_dataset, 74 | 0.2) 75 | assert (('fage', 'motherage'), 0.8065844173531495) in res 76 | 77 | def test_identify_low_variance(self, my_test_dataset): 78 | res = tabular_data_utils.identify_low_variance(my_test_dataset, 0.05 79 | ) 80 | assert 'Gestation' in res 81 | 82 | 83 | # This is breaking lint for now because it aids in the clarity of the data 84 | # TODO: refactor 85 | class TestStitchDataset: 86 | @pytest.fixture 87 | def my_test_dataset_one(self): 88 | dct_dfs = {'df_test_one': pd.DataFrame({'A': ['A0', 'A1', 'A2', 'A3'], 'B': ['B0', 'B1', 'B2', 'B3'], 89 | 'C': ['C0', 'C1', 'C2', 'C3'], 'D': ['D0', 'D1', 'D2', 'D3']}, 90 | index=[0, 1, 2, 3]), 91 | 'df_test_two': pd.DataFrame({'A': ['A4', 'A5', 'A6', 'A7'], 'B': ['B4', 'B5', 'B6', 'B7'], 92 | 'C': ['C4', 'C5', 'C6', 'C7'], 'D': ['D4', 'D5', 'D6', 'D7']}, 93 | index=[4, 5, 6, 7])} 94 | return dct_dfs 95 | 96 | @pytest.fixture 97 | def my_test_dataset_two(self): 98 | dct_dfs = {'df_test_one': pd.DataFrame({'name': ['Jane', 'John', 'Jesse', 'Jane', 'John', 'Jesse'], 99 | 'age': [23, 25, 26, np.nan, np.nan, np.nan]}, 100 | 
index=[0, 1, 2, 3, 4, 5]), 101 | 'df_test_two': pd.DataFrame({'name': ['Jane', 'John', 'Jesse'], 102 | 'state': ['CA', 'WA', 'OR']}, 103 | index=[0, 1, 2])} 104 | return dct_dfs 105 | 106 | @pytest.fixture 107 | def my_test_dataset_three(self): 108 | dct_dfs = {'df_test_one': pd.DataFrame({'name': ['Jane', 'John', 'Jesse', 'Jane', 'John', 'Jesse', 'Jane'], 109 | 'age': [23, 25, 26, np.nan, np.nan, np.nan, 23], 110 | 'dob': ['09-18-1995', '10-18-1993', '06-18-1992', np.nan, np.nan, 111 | np.nan, 112 | '05-23-1995']}, 113 | index=[0, 1, 2, 3, 4, 5, 6]), 114 | 'df_test_two': pd.DataFrame({'name': ['Jane', 'John', 'Jesse', 'Jane'], 115 | 'dob': ['09-18-1995', '10-18-1993', '06-18-1992', '05-23-1995'], 116 | 'state': ['CA', 'WA', 'OR', 'AZ']}, 117 | index=[0, 1, 2, 3])} 118 | return dct_dfs 119 | 120 | @pytest.fixture 121 | def my_test_dataset_four(self): 122 | dct_dfs = {'df_test_one': pd.DataFrame({'name': ['Jane', 'John', 'Jesse', 'Jane'], 123 | 'age': [23, 25, 26, 23], 124 | 'dob': ['09-18-1995', '10-18-1993', '06-18-1992', '05-23-1995'], 125 | 'visit_date': ['11/29/2018', '11/29/2018', '11/29/2018', '11/29/2018'], 126 | 'visit_location': ['HI', 'HI', 'HI', 'HI']}, 127 | index=[0, 1, 2, 3]), 128 | 'df_test_two': pd.DataFrame({'name': ['John', 'Jesse'], 129 | 'age': [25, 26], 130 | 'dob': ['10-18-1993', '06-18-1992'], 131 | 'visit_date': ['09/12/2018', '12/20/2017'], 132 | 'visit_location': ['CA', 'AZ']}, 133 | index=[0, 1])} 134 | return dct_dfs 135 | 136 | def test_no_merge_on_columns(self, my_test_dataset_one): 137 | # MOC (merge on columns) 138 | df_no_moc_results = pd.DataFrame({'A': ['A0', 'A1', 'A2', 'A3', 'A4', 'A5', 'A6', 'A7'], 139 | 'B': ['B0', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7'], 140 | 'C': ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7'], 141 | 'D': ['D0', 'D1', 'D2', 'D3', 'D4', 'D5', 'D6', 'D7']}, 142 | index=[0, 1, 2, 3, 4, 5, 6, 7]) 143 | 144 | stitch_dataset_results = tabular_data_utils.stitch_datasets(merge_on_columns=None, **my_test_dataset_one) 145 | pd.testing.assert_frame_equal(stitch_dataset_results, df_no_moc_results) 146 | 147 | def test_single_merge_on_columns(self, my_test_dataset_two): 148 | df_single_moc_results = pd.DataFrame({'name': ['Jane', 'John', 'Jesse'], 149 | 'age': [23, 25, 26], 150 | 'state': ['CA', 'WA', 'OR']}, 151 | index=[0, 1, 2]) 152 | stitch_dataset_results = tabular_data_utils.stitch_datasets(merge_on_columns=['name'], **my_test_dataset_two) 153 | 154 | # Assert_frame_equal checks order of columns; therefore, sort_index by columns when checking. If dataframes are 155 | # same, they should sort the same way. 156 | pd.testing.assert_frame_equal(stitch_dataset_results.sort_index(axis=1), df_single_moc_results.sort_index(axis=1), 157 | check_dtype=False) 158 | 159 | def test_two_merge_on_columns(self, my_test_dataset_three): 160 | df_two_moc_results = pd.DataFrame({'name': ['Jane', 'John', 'Jesse', 'Jane'], 161 | 'age': [23, 25, 26, 23], 162 | 'dob': ['09-18-1995', '10-18-1993', '06-18-1992', '05-23-1995'], 163 | 'state': ['CA', 'WA', 'OR', 'AZ']}, 164 | index=[0, 1, 2, 6]) 165 | stitch_dataset_results = tabular_data_utils.stitch_datasets(merge_on_columns=['name', 'dob'], **my_test_dataset_three) 166 | 167 | # Assert_frame_equal checks order of columns; therefore, sort_index by columns when checking. If dataframes are 168 | # same, they should sort the same way. 
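        # Illustrative note (added; my reading of the fixtures above, not
        # part of the original tests): with merge_on_columns=['name', 'dob'],
        # rows from both frames sharing the same ('name', 'dob') pair are
        # collapsed into one row, e.g. ('Jane', '09-18-1995') takes age=23
        # from df_test_one and state='CA' from df_test_two.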
169 | pd.testing.assert_frame_equal(stitch_dataset_results.sort_index(axis=1), df_two_moc_results.sort_index(axis=1), 170 | check_dtype=False) 171 | 172 | def test_three_merge_on_columns(self, my_test_dataset_four): 173 | df_three_moc_results = pd.DataFrame({'name': ['Jane', 'John', 'Jesse', 'Jane', 'John', 'Jesse'], 174 | 'age': [23, 25, 26, 23, 25, 26], 175 | 'dob': ['09-18-1995', '10-18-1993', '06-18-1992', '05-23-1995', 176 | '10-18-1993', '06-18-1992'], 177 | 'visit_date': ['11/29/2018', '11/29/2018', '11/29/2018', '11/29/2018', 178 | '09/12/2018', '12/20/2017'], 179 | 'visit_location': ['HI', 'HI', 'HI', 'HI', 'CA', 'AZ']}, 180 | index=[0, 1, 2, 3, 4, 5]) 181 | 182 | stitch_dataset_results = tabular_data_utils.stitch_datasets(merge_on_columns=['name', 'dob', 'visit_date'], **my_test_dataset_four) 183 | 184 | # Assert_frame_equal checks order of columns; therefore, sort_index by columns when checking. If dataframes are 185 | # same, they should sort the same way. 186 | pd.testing.assert_frame_equal(stitch_dataset_results.sort_index(axis=1), df_three_moc_results.sort_index(axis=1), check_dtype=False) 187 | -------------------------------------------------------------------------------- /vulcanai/models/cnn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """Defines the ConvNet class.""" 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | from .basenetwork import BaseNetwork 8 | from .layers import DenseUnit, ConvUnit, FlattenUnit 9 | from .utils import pad 10 | 11 | import logging 12 | from inspect import getfullargspec 13 | from math import ceil 14 | 15 | from collections import OrderedDict 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | # TODO: use setters to enforce types/formats/values! 21 | # TODO: make this a base class? 22 | # TODO: add additional constraints in the future 23 | class ConvNetConfig: 24 | """Define the necessary configuration for a ConvNet.""" 25 | 26 | def __init__(self, raw_config): 27 | """ 28 | Take in user config dict and clean it up. 29 | 30 | Cleaned units is stored in self.units 31 | 32 | Parameters: 33 | raw_config : dict of dict 34 | User specified dict 35 | 36 | """ 37 | if 'conv_units' not in raw_config: 38 | raise KeyError("conv_units must be specified") 39 | 40 | # Confirm all passed units conform to Unit required arguments 41 | conv_unit_arg_spec = getfullargspec(ConvUnit) 42 | conv_unit_arg_spec.args.remove('self') 43 | # Deal with dim inference when cleaning unit 44 | conv_unit_arg_spec.args.remove('conv_dim') 45 | 46 | # Find the index for where the defaulted values begin 47 | default_arg_start_index = len(conv_unit_arg_spec.args) - \ 48 | len(conv_unit_arg_spec.defaults) 49 | self.required_args = conv_unit_arg_spec.args[:default_arg_start_index] 50 | self.units = [] 51 | for u in raw_config['conv_units']: 52 | unit = self._clean_unit(raw_unit=u) 53 | self.units.append(unit) 54 | 55 | def _clean_unit(self, raw_unit): 56 | """ 57 | Use this to catch mistakes in each user-specified unit. 58 | 59 | Infer dimension of Conv using the kernel shape. 60 | 61 | Parameters: 62 | raw_unit : dict 63 | 64 | Returns: 65 | unit : dict 66 | Cleaned unit config. 
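            Example (illustrative, not from the original docstring): a raw
            unit with kernel_size=5 is normalized to kernel_size=(5,) and
            inferred as conv_dim=1.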
67 | 
68 |         """
69 |         unit = raw_unit
70 |         for key in self.required_args:
71 |             if key not in unit.keys():
72 |                 raise ValueError(
73 |                     "{} needs to be specified in your config.".format(key))
74 |         if not isinstance(unit['kernel_size'], tuple):
75 |             if isinstance(unit['kernel_size'], int):
76 |                 unit['kernel_size'] = (unit['kernel_size'],)
77 |             unit['kernel_size'] = tuple(unit['kernel_size'])
78 |         unit['conv_dim'] = len(unit['kernel_size'])
79 |         return unit
80 | 
81 | 
82 | # noinspection PyDefaultArgument,PyTypeChecker
83 | class ConvNet(BaseNetwork):
84 |     """
85 |     Subclass of BaseNetwork defining a ConvNet.
86 | 
87 |     Parameters:
88 |         name : str
89 |             The name of the network. Used when saving the file.
90 |         config : dict
91 |             The configuration of the network module, as a dict.
92 |         in_dim : tuple
93 |             The input dimensions of the network. Not required to specify when
94 |             the network has input_networks.
95 |         save_path : str
96 |             The name of the file to which you would like to save this network.
97 |         input_networks : list of BaseNetwork
98 |             A network object provided as input.
99 |         num_classes : int or None
100 |             The number of classes to predict.
101 |         activation : torch.nn.Module
102 |             The desired activation function for use in the network.
103 |         pred_activation : torch.nn.Module
104 |             The desired activation function for use in the prediction layer.
105 |         optim_spec : dict
106 |             A dictionary of parameters for the desired optimizer.
107 |         lr_scheduler : torch.optim.lr_scheduler
108 |             A callable torch.optim.lr_scheduler
109 |         early_stopping : str or None
110 |             So far just 'best_validation_error' is implemented.
111 |         early_stopping_patience : int
112 |             Number of validation iterations of non-decreasing loss
113 |             (note: not necessarily every epoch)
114 |             before early stopping is applied.
115 |         early_stopping_metric : str
116 |             Either "loss" or "accuracy" is implemented.
117 |         criter_spec : dict
118 |             Criterion specification with its name and all its parameters.
119 | 
120 |     Returns:
121 |         network : ConvNet
122 |             A network of type BaseNetwork.
123 | 
124 |     """
125 | 
126 |     def __init__(self, name, config, in_dim=None, save_path=None,
127 |                  input_networks=None, num_classes=None,
128 |                  activation=nn.ReLU(), pred_activation=None,
129 |                  optim_spec={'name': 'Adam', 'lr': 0.001},
130 |                  lr_scheduler=None, early_stopping=None,
131 |                  early_stopping_patience=None,
132 |                  early_stopping_metric="accuracy",
133 |                  criter_spec=nn.CrossEntropyLoss(),
134 |                  device="cuda:0"):
135 |         """Define the ConvNet object."""
136 |         super(ConvNet, self).__init__(
137 |             name, ConvNetConfig(config), in_dim, save_path, input_networks,
138 |             num_classes, activation, pred_activation, optim_spec,
139 |             lr_scheduler, early_stopping, early_stopping_patience,
140 |             early_stopping_metric,
141 |             criter_spec, device)
142 | 
143 |     def _create_network(self, **kwargs):
144 |         """
145 |         Build the layers of the network into a nn.Sequential object.
146 | 147 | Parameters: 148 | conv_hid_layers : ConvNetConfig.units (list of dict) 149 | The hidden layers specification 150 | activation : torch.nn.Module 151 | the non-linear activation to apply to each layer 152 | 153 | Returns: 154 | output : torch.nn.Sequential 155 | the conv network as a nn.Sequential object 156 | 157 | """ 158 | conv_hid_layers = self._config.units 159 | 160 | if self.input_networks: 161 | self.in_dim = self._get_in_dim() 162 | 163 | conv_hid_layers[0]['in_channels'] = self.in_dim[0] 164 | conv_layers = OrderedDict() 165 | for idx, conv_layer_config in enumerate(conv_hid_layers): 166 | conv_layer_config['activation'] = kwargs['activation'] 167 | layer_name = 'conv_{}'.format(idx) 168 | conv_layers[layer_name] = ConvUnit(**conv_layer_config) 169 | self.network = nn.Sequential(conv_layers) 170 | 171 | if self.num_classes: 172 | self.network.add_module( 173 | 'flatten', FlattenUnit()) 174 | self.network.add_module( 175 | 'classify', DenseUnit( 176 | in_features=self._get_out_dim()[0], 177 | out_features=self.num_classes, 178 | activation=kwargs['pred_activation'])) 179 | 180 | def _merge_input_network_outputs(self, tensors): 181 | """Calculate converged in_dim for the MultiInput ConvNet.""" 182 | reshaped_tensors = [] 183 | # Determine what shape to cast to without losing any information. 184 | max_conv_tensor_size = self._get_max_incoming_spatial_dims() 185 | for t in tensors: 186 | if t.dim() == 2: 187 | # Cast Linear output to largest Conv output shape 188 | t = self._cast_linear_to_shape( 189 | tensor=t, 190 | cast_shape=max_conv_tensor_size) 191 | elif t.dim() > 2: 192 | # Cast Conv output to largest Conv output shape 193 | t = self._cast_conv_to_shape( 194 | tensor=t, 195 | cast_shape=max_conv_tensor_size) 196 | reshaped_tensors.append(t) 197 | return torch.cat(reshaped_tensors, dim=1) 198 | 199 | def _get_max_incoming_spatial_dims(self): 200 | """Return the max spatial dimensions of the input networks.""" 201 | # Ignoring the channels 202 | spatial_inputs = [] 203 | for in_net in self.input_networks.values(): 204 | if isinstance(in_net, ConvNet): 205 | spatial_inputs.append(list(in_net.out_dim[1:])) 206 | max_spatial_dim = len(max(spatial_inputs, key=len)) 207 | 208 | # Fill with zeros in missing dim to compare 209 | # max size later for each dim. 210 | for in_spatial_dim in spatial_inputs: 211 | while len(in_spatial_dim) < max_spatial_dim: 212 | in_spatial_dim.insert(0, 0) 213 | 214 | # All spatial dimensions 215 | # Take the max size in each dimension. 216 | max_conv_tensor_size = np.array(spatial_inputs).transpose().max(axis=1) 217 | return np.array(max_conv_tensor_size) 218 | 219 | @staticmethod 220 | def _cast_linear_to_shape(tensor, cast_shape): 221 | """ 222 | Convert Linear outputs into Conv outputs. 223 | 224 | Parameters: 225 | tensor : torch.Tensor 226 | The Linear tensor to reshape of shape [out_features] 227 | cast_shape : numpy.ndarray 228 | The Conv shape to cast linear to of shape 229 | [batch, num_channels, *spatial_dimensions]. 230 | 231 | Returns: 232 | tensor : torch.Tensor 233 | Tensor of shape [batch, num_channels, *spatial_dimensions] 234 | 235 | """ 236 | # Equivalent to calculating tensor.numel() in pytorch. 237 | sequence_length = cast_shape.prod() 238 | # How many channels to spread the information into 239 | # Ignore batch from linear 240 | n_channels = ceil(tensor[-1].numel() / sequence_length) 241 | # How much pad to add to either sides to reshape the linear tensor 242 | # into cast_shape spatial dimensions. 
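        # Worked example (hypothetical values, added for clarity): a linear
        # output of shape [batch, 100] cast to spatial dims (8, 8) gives
        # sequence_length = 64 and n_channels = ceil(100 / 64) = 2, so the
        # tensor is zero-padded to 2 * 64 = 128 features and viewed as
        # [batch, 2, 8, 8].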
243 |         pad_shape = sequence_length * n_channels
244 |         tensor = pad(tensor=tensor, target_shape=[pad_shape])
245 |         return tensor.view(-1, n_channels, *cast_shape)
246 | 
247 |     @staticmethod
248 |     def _cast_conv_to_shape(tensor, cast_shape):
249 |         """
250 |         Cast a Conv output to a target Conv output shape.
251 | 
252 |         Parameters:
253 |             tensor : torch.Tensor
254 |                 The Conv tensor to reshape of shape
255 |                 [batch, num_channels, *spatial_dimensions]
256 |             cast_shape : numpy.ndarray
257 |                 The Conv shape to cast the incoming Conv output to, of shape
258 |                 [batch, num_channels, *spatial_dimensions].
259 | 
260 |         Returns:
261 |             tensor : torch.Tensor
262 |                 Tensor of shape [batch, num_channels, *spatial_dimensions]
263 | 
264 |         """
265 |         # Extract only the spatial dimensions by ignoring the batch and channel
266 |         spatial_dim_idx_start = 2
267 |         if len(tensor.shape[spatial_dim_idx_start:]) < len(cast_shape):
268 |             # TODO: https://github.com/pytorch/pytorch/issues/9410
269 |             # Ignore batch for incoming tensor
270 |             # For each missing dim, add dims until it
271 |             # is equivalent to the max dim
272 |             n_unsqueezes = len(cast_shape) - \
273 |                 len(tensor.shape[spatial_dim_idx_start:])
274 |             for _ in range(n_unsqueezes):
275 |                 tensor = tensor.unsqueeze(dim=spatial_dim_idx_start)
276 |         return pad(tensor=tensor, target_shape=cast_shape)
277 | 
278 |     def __str__(self):
279 |         """Specify how to print network."""
280 |         if self.optim is not None:
281 |             return super(ConvNet, self).__str__() + '\noptim: {}'\
282 |                 .format(self.optim)
283 |         else:
284 |             return super(ConvNet, self).__str__()
285 | 
-------------------------------------------------------------------------------- /vulcanai/tests/models/test_cnn.py: --------------------------------------------------------------------------------
1 | """Test all ConvNet capabilities."""
2 | import pytest
3 | import numpy as np
4 | import copy
5 | import logging
6 | import os
7 | import shutil
8 | 
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils.data import DataLoader, Subset, TensorDataset
12 | 
13 | import vulcanai
14 | from vulcanai.models import BaseNetwork
15 | from vulcanai.models.cnn import ConvNet, ConvNetConfig
16 | from vulcanai.models.utils import master_device_setter
17 | 
18 | logger = logging.getLogger(__name__)
19 | 
20 | 
21 | class TestConvNet:
22 |     """Define ConvNet test class."""
23 | 
24 |     @pytest.fixture
25 |     def multi_input_cnn_train_loader(self, multi_input_cnn_data):
26 |         """Synthetic train data pytorch dataloader object."""
27 |         test_train = Subset(multi_input_cnn_data,
28 |                             range(len(multi_input_cnn_data)//2))
29 |         return DataLoader(test_train, batch_size=2)
30 | 
31 |     @pytest.fixture
32 |     def multi_input_cnn_test_loader(self, multi_input_cnn_data):
33 |         """Synthetic test data pytorch dataloader object."""
34 |         test_val = Subset(multi_input_cnn_data,
35 |                           range(len(multi_input_cnn_data)//2,
36 |                                 len(multi_input_cnn_data)))
37 |         return DataLoader(test_val, batch_size=2)
38 | 
39 |     def test_init(self, conv1D_net):
40 |         """Initialization Test of a ConvNet object."""
41 |         assert isinstance(conv1D_net, BaseNetwork)
42 |         assert isinstance(conv1D_net, nn.Module)
43 |         assert hasattr(conv1D_net, 'network')
44 |         assert hasattr(conv1D_net, 'in_dim')
45 |         assert hasattr(conv1D_net, 'record')
46 |         assert hasattr(conv1D_net, 'device')
47 | 
48 |         assert conv1D_net._name is not None
49 |         assert isinstance(conv1D_net._config, ConvNetConfig)
50 | 
51 |         assert conv1D_net.input_networks is None
52 |         assert conv1D_net.epoch == 0
53 |         assert conv1D_net.optim is None
54 |         assert
conv1D_net.criterion is None 55 | 56 | assert not hasattr(conv1D_net, 'metrics') 57 | 58 | def test_function_multi_input(self, conv1D_net, multi_input_cnn): 59 | """Test functions wrt multi_input_cnn.""" 60 | tensor_size = [10, 1, 28, 28] 61 | assert isinstance(multi_input_cnn.input_networks, nn.ModuleDict) 62 | assert len(list(multi_input_cnn.input_networks)) == 3 63 | assert all(multi_input_cnn._get_max_incoming_spatial_dims() == 64 | (8, 8, 8)) 65 | # TODO: make more elegant 66 | assert multi_input_cnn._merge_input_network_outputs([ 67 | torch.rand(size=tensor_size), 68 | torch.rand(size=tensor_size + [28]), 69 | torch.rand(size=[10] + 70 | list(multi_input_cnn.input_networks 71 | ['multi_input_dnn'].out_dim)) 72 | ]).shape == (10, 3, 8, 8, 8) 73 | test_net = copy.deepcopy(multi_input_cnn) 74 | test_net._add_input_network(conv1D_net) 75 | assert len(list(test_net.input_networks)) == 4 76 | assert 'conv1D_net' in test_net.input_networks 77 | assert isinstance(test_net.input_networks['conv1D_net'], ConvNet) 78 | 79 | def test_forward(self, conv1D_net): 80 | """Test Forward of ConvNet.""" 81 | out = conv1D_net(torch.rand(size=[10, *conv1D_net.in_dim])) 82 | assert out.shape == (10, 64, 1) 83 | 84 | def test_forward_multi_input(self, multi_input_cnn): 85 | """Test Forward of Multi Input ConvNet.""" 86 | master_device_setter(multi_input_cnn, 'cpu') 87 | input_tensor = [ 88 | torch.rand(size=[10, 1, 28, 28]), 89 | torch.rand(size=[10, 1, 28, 28, 28]), 90 | [torch.rand(size=[10, 1, 28]), 91 | torch.rand(size=[10, 1, 28, 28])] 92 | ] 93 | out = multi_input_cnn(input_tensor) 94 | assert out.shape == (10, 10) 95 | 96 | def test_forward_pass_not_nan(self, conv3D_net): 97 | """Confirm out is non nan.""" 98 | test_input = torch.rand(size=[1, *conv3D_net.in_dim]) 99 | test_dataloader = DataLoader(TensorDataset(test_input, test_input)) 100 | output = conv3D_net.forward_pass( 101 | data_loader=test_dataloader, 102 | transform_outputs=False) 103 | assert np.any(~np.isnan(output)) 104 | 105 | def test_forward_pass_class_not_nan(self, conv3D_net_class): 106 | """Confirm out is non nan.""" 107 | test_input = torch.rand(size=[1, *conv3D_net_class.in_dim]) 108 | test_dataloader = DataLoader(TensorDataset(test_input, test_input)) 109 | raw_output = conv3D_net_class.forward_pass( 110 | data_loader=test_dataloader, 111 | transform_outputs=False) 112 | class_output = conv3D_net_class.metrics.transform_outputs( 113 | in_matrix=raw_output) 114 | assert np.any(~np.isnan(class_output)) 115 | assert np.any(~np.isnan(raw_output)) 116 | 117 | def test_freeze_class(self, conv3D_net_class): 118 | """Test class network freezing.""" 119 | conv3D_net_class.freeze(apply_inputs=False) 120 | for params in conv3D_net_class.network.parameters(): 121 | assert params.requires_grad is False 122 | 123 | def test_unfreeze_class(self, conv3D_net_class): 124 | """Test class network unfreezing.""" 125 | conv3D_net_class.freeze(apply_inputs=False) 126 | conv3D_net_class.unfreeze(apply_inputs=False) 127 | for params in conv3D_net_class.network.parameters(): 128 | assert params.requires_grad is True 129 | 130 | def test_freeze_noclass(self, conv3D_net): 131 | """Test intermediate network freezing.""" 132 | conv3D_net.freeze(apply_inputs=False) 133 | for params in conv3D_net.network.parameters(): 134 | assert params.requires_grad is False 135 | 136 | def test_unfreeze_noclass(self, conv3D_net): 137 | """Test intermediate network unfreezing.""" 138 | conv3D_net.freeze(apply_inputs=False) 139 | conv3D_net.unfreeze(apply_inputs=False) 140 | 
for params in conv3D_net.network.parameters():
141 |             assert params.requires_grad is True
142 | 
143 |     def test_fit_multi_input(self, multi_input_cnn,
144 |                              multi_input_cnn_train_loader,
145 |                              multi_input_cnn_test_loader):
146 |         """Test for fit function."""
147 |         init_weights = copy.deepcopy(multi_input_cnn.network[0].
148 |                                      _kernel.weight.detach())
149 |         multi_input_cnn_no_fit = copy.deepcopy(multi_input_cnn)
150 |         parameters1 = multi_input_cnn_no_fit.parameters()
151 |         try:
152 |             multi_input_cnn.fit(
153 |                 multi_input_cnn_train_loader,
154 |                 multi_input_cnn_test_loader,
155 |                 2)
156 |         except RuntimeError:
157 |             logger.error("The network multi_input_cnn failed to train.")
158 |         finally:
159 |             parameters2 = multi_input_cnn.parameters()
160 |             trained_weights = multi_input_cnn.network[0]._kernel.weight.detach()
161 | 
162 |             # Sanity check if the network parameters are training
163 |             assert not (torch.equal(init_weights.cpu(), trained_weights.cpu()))
164 |             compare_params = [not torch.allclose(param1, param2)
165 |                               for param1, param2 in zip(parameters1,
166 |                                                         parameters2)]
167 |             assert all(compare_params)
168 | 
169 |     def test_params_multi_input(self, multi_input_cnn,
170 |                                 multi_input_cnn_train_loader,
171 |                                 multi_input_cnn_test_loader):
172 |         """Test for change in network params/specifications."""
173 |         test_net = copy.deepcopy(multi_input_cnn)
174 |         # Check the parameters are copying properly
175 |         copy_params = [torch.allclose(param1, param2)
176 |                        for param1, param2 in zip(multi_input_cnn.parameters(),
177 |                                                  test_net.parameters())]
178 |         assert all(copy_params)
179 | 
180 |         # Check the parameters change after copy and fit
181 |         test_net.fit(
182 |             multi_input_cnn_train_loader,
183 |             multi_input_cnn_test_loader,
184 |             2)
185 |         close_params = [not torch.allclose(param1, param2)
186 |                         for param1, param2 in zip(multi_input_cnn.parameters(),
187 |                                                   test_net.parameters())]
188 |         assert all(close_params)
189 | 
190 |         # Check the network params and optimizer params point to
191 |         # the same memory
192 |         if test_net.optim:
193 |             assert isinstance(test_net.optim, torch.optim.Adam)
194 |             assert isinstance(test_net.criterion, torch.nn.CrossEntropyLoss)
195 |             for param, opt_param in zip(test_net.parameters(),
196 |                                         test_net.optim.param_groups[0]['params']):
197 |                 assert param is opt_param
198 | 
199 |         # Check the params after saving and loading
200 |         test_net.save_model()
201 |         save_path = test_net.save_path
202 |         abs_save_path = os.path.dirname(os.path.abspath(save_path))
203 |         loaded_test_net = BaseNetwork.load_model(load_path=save_path)
204 |         load_params = [torch.allclose(param1, param2)
205 |                        for param1, param2 in zip(test_net.parameters(),
206 |                                                  loaded_test_net.parameters())]
207 |         shutil.rmtree(abs_save_path)
208 |         assert all(load_params)
209 | 
210 |     def test_forward_pass_class_not_nan_single_value(self,
211 |                                                      conv3D_net_class_single_value):
212 |         """Confirm out is non nan."""
213 |         test_input = torch.rand(size=[10, *conv3D_net_class_single_value.in_dim])
214 |         test_output = torch.rand(size=[10, 1])
215 |         test_dataloader = DataLoader(TensorDataset(test_input, test_output))
216 |         raw_output = conv3D_net_class_single_value.forward_pass(
217 |             data_loader=test_dataloader,
218 |             transform_outputs=False)
219 |         assert np.any(~np.isnan(raw_output))
220 | 
221 |     def test_early_stopping(self, conv3D_net_class_early_stopping,
222 |                             conv3D_net_class):
223 |         """ Test that their final params are different: aka
224 |         that the early stopping did something.
225 | Constant seed resetting due to worker init function of dataloader 226 | https://github.com/pytorch/pytorch/issues/7068 """ 227 | 228 | vulcanai.set_global_seed(42) 229 | 230 | ds = DataLoader(TensorDataset( 231 | torch.rand(size=[10, *conv3D_net_class_early_stopping.in_dim]), 232 | torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).long()), 233 | num_workers=0) 234 | 235 | vulcanai.set_global_seed(42) 236 | 237 | ds2 = copy.deepcopy(ds) 238 | 239 | vulcanai.set_global_seed(42) 240 | 241 | ds3 = copy.deepcopy(ds) 242 | 243 | conv3D_net_class_copy = copy.deepcopy(conv3D_net_class) 244 | 245 | conv3D_net_class_copy.fit( 246 | train_loader=ds2, 247 | val_loader=ds2, 248 | epochs=5) 249 | vulcanai.set_global_seed(42) 250 | 251 | conv3D_net_class.fit( 252 | train_loader=ds, 253 | val_loader=ds, 254 | epochs=5) 255 | 256 | vulcanai.set_global_seed(42) 257 | 258 | conv3D_net_class_early_stopping.fit( 259 | train_loader=ds3, 260 | val_loader=ds3, 261 | epochs=5) 262 | 263 | stopping_params = list(conv3D_net_class_early_stopping.parameters())[0][0][0].data 264 | non_stopping_params = list(conv3D_net_class.parameters())[0][0][0][0][0].data 265 | non_stopping_params_copy = list(conv3D_net_class_copy.parameters())[0][0][0][0][0].data 266 | 267 | assert not torch.eq(stopping_params, non_stopping_params).all() 268 | assert torch.eq(non_stopping_params, non_stopping_params_copy).all() 269 | 270 | -------------------------------------------------------------------------------- /vulcanai/tests/models/test_layers.py: -------------------------------------------------------------------------------- 1 | """The script to test all layers and SELU activation properties hold.""" 2 | import pytest 3 | import numpy as np 4 | import math 5 | import copy 6 | from functools import reduce 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import modules 10 | from vulcanai.models.layers import BaseUnit, ConvUnit, DenseUnit 11 | from vulcanai.models.utils import selu_weight_init_, selu_bias_init_ 12 | from vulcanai.models.dnn import DenseNet 13 | from vulcanai.models.cnn import ConvNet 14 | from torch.utils.data import TensorDataset, DataLoader 15 | 16 | torch.manual_seed(1234) 17 | 18 | class TestBaseUnit: 19 | """To test BaseUnit layer.""" 20 | 21 | @pytest.fixture 22 | def baseunit(self): 23 | """Make Base unit fixture.""" 24 | return BaseUnit() 25 | 26 | def test_init(self, baseunit): 27 | """Initialization Test of a BaseUnit object.""" 28 | assert isinstance(baseunit, nn.Sequential) 29 | assert hasattr(baseunit, 'weight_init') 30 | assert hasattr(baseunit, 'bias_init') 31 | assert hasattr(baseunit, 'norm') 32 | assert hasattr(baseunit, 'dropout') 33 | assert hasattr(baseunit, 'in_shape') 34 | assert hasattr(baseunit, 'out_shape') 35 | assert hasattr(baseunit, '_kernel') 36 | 37 | 38 | class TestDenseUnit: 39 | """To test DenseUnit layer.""" 40 | 41 | @pytest.fixture 42 | def denseunit(self): 43 | """Dense unit fixture.""" 44 | return DenseUnit( 45 | in_features=10, 46 | out_features=10, 47 | weight_init=nn.init.xavier_uniform_, 48 | bias_init=nn.init.zeros_, 49 | norm='batch', 50 | activation=nn.ReLU(), 51 | dropout=0.5 52 | ) 53 | 54 | def test_init(self, denseunit): 55 | """Initialization Test of a DenseUnit object.""" 56 | assert hasattr(denseunit, 'in_features') 57 | assert hasattr(denseunit, 'out_features') 58 | for unit in denseunit.named_children(): 59 | assert any(unit[0] == i for i in ['_kernel', '_norm', 60 | '_activation', '_dropout']) 61 | assert isinstance(unit[1], (nn.Linear, 62 | 
nn.modules.batchnorm._BatchNorm, 63 | nn.ReLU, modules.dropout._DropoutNd)) 64 | assert callable(unit[1]) 65 | 66 | @pytest.fixture 67 | def test_denseunit_parameters(self): 68 | """Create a dictionary with incorrect DenseUnit parameters.""" 69 | return dict( 70 | in_channels=10, 71 | out_features=10, 72 | weight_init=nn.init.xavier_uniform_, 73 | bias_init=nn.init.zeros_, 74 | norm='batch', 75 | activation=nn.ReLU(), 76 | dropout=0.5 77 | ) 78 | 79 | def test_create_denseunit(self, test_denseunit_parameters): 80 | """Check if passing wrong parameters raises TypeError.""" 81 | with pytest.raises(TypeError) as e: 82 | DenseUnit(**test_denseunit_parameters) 83 | assert 'in_channels' in str(e.value) 84 | 85 | def test_forward(self, denseunit): 86 | """Confirm size is expected after forward.""" 87 | test_input = torch.rand(size=[10, denseunit.in_features]) 88 | output = denseunit.forward(test_input) 89 | assert output.shape == torch.rand(size=[10, denseunit.out_features]).shape 90 | 91 | 92 | class TestConvUnit: 93 | """To test ConvUnit layers.""" 94 | 95 | @pytest.fixture 96 | def convunit(self): 97 | """Create ConvUnit fixture.""" 98 | return ConvUnit( 99 | conv_dim=2, 100 | in_channels=10, 101 | out_channels=10, 102 | kernel_size=(5, 5), 103 | weight_init=nn.init.xavier_uniform_, 104 | bias_init=nn.init.zeros_, 105 | norm='batch', 106 | activation=nn.ReLU(), 107 | dropout=0.5 108 | ) 109 | 110 | @pytest.fixture 111 | def test_convunit_parameters(self): 112 | """Create dictionary with incorrect ConvUnit parameters.""" 113 | return dict( 114 | conv_dim=2, 115 | in_channels=10, 116 | out_features=10, 117 | kernel_size=(5, 5), 118 | weight_init=nn.init.xavier_uniform_, 119 | bias_init=nn.init.zeros_, 120 | norm='batch', 121 | activation=nn.ReLU(), 122 | dropout=0.5 123 | ) 124 | 125 | def test_create_convunit(self, test_convunit_parameters): 126 | """Check if passing wrong parameters raises TypeError.""" 127 | with pytest.raises(TypeError) as e: 128 | ConvUnit(**test_convunit_parameters) 129 | assert 'out_features' in str(e.value) 130 | 131 | def test_init(self, convunit): 132 | """Initialization Test of a ConvUnit object.""" 133 | assert hasattr(convunit, 'in_channels') 134 | assert hasattr(convunit, 'out_channels') 135 | assert hasattr(convunit, 'kernel_size') 136 | for unit in convunit.named_children(): 137 | assert any(unit[0] == i for i in 138 | ['_kernel', '_norm', '_activation', '_dropout']) 139 | assert isinstance(unit[1], (modules.conv._ConvNd, 140 | nn.modules.batchnorm._BatchNorm, 141 | nn.ReLU, modules.dropout._DropoutNd)) 142 | assert callable(unit[1]) 143 | 144 | def test_forward(self, convunit): 145 | """Confirm size is expected after forward.""" 146 | test_input = torch.rand(size=[10, convunit.in_channels, 28, 28]) 147 | output = convunit.forward(test_input) 148 | # No padding with 2 5x5 kernels leads from 28x28 -> 24x24 149 | assert output.shape == \ 150 | torch.rand(size=[10, convunit.out_channels, 24, 24]).shape 151 | 152 | 153 | class TestSeluInit: 154 | """To test selu initialized layer properties.""" 155 | 156 | @pytest.fixture 157 | def dense_unit(self): 158 | """Create a dense unit fixture.""" 159 | return DenseUnit( 160 | in_features=10, 161 | out_features=10 162 | ) 163 | 164 | @pytest.fixture 165 | def conv_unit(self): 166 | """Create a conv unit fixture.""" 167 | return ConvUnit( 168 | conv_dim=2, 169 | in_channels=10, 170 | out_channels=10, 171 | kernel_size=(5, 5) 172 | ) 173 | 174 | def test_dense_selu_weight_change(self, dense_unit): 175 | """Confirm SELU weight 
init properties hold for dense net.""" 176 | starting_weight = copy.deepcopy(dense_unit._kernel.weight) 177 | fan_in = dense_unit.in_features 178 | std = round(math.sqrt(1. / fan_in), 1) 179 | 180 | dense_unit.weight_init = selu_weight_init_ 181 | dense_unit._init_weights() 182 | new_weight = dense_unit._kernel.weight 183 | assert (torch.equal(starting_weight, new_weight) is False) 184 | assert (math.isclose(new_weight.std().item(), math.sqrt(1. / fan_in), rel_tol=std) is True) 185 | assert (int(new_weight.mean().item()) == 0.0) 186 | 187 | def test_conv_selu_weight_change(self, conv_unit): 188 | """Confirm SELU weight init properties hold for conv net.""" 189 | starting_weight = copy.deepcopy(conv_unit._kernel.weight) 190 | fan_in = conv_unit._kernel.in_channels * \ 191 | reduce(lambda k1, k2: k1 * k2, conv_unit._kernel.kernel_size) 192 | std = math.sqrt(1. / fan_in) 193 | conv_unit.weight_init = selu_weight_init_ 194 | conv_unit._init_weights() 195 | new_weight = conv_unit._kernel.weight 196 | assert (torch.equal(starting_weight, new_weight) is False) 197 | assert (math.isclose(new_weight.std().item(), math.sqrt(1./fan_in), rel_tol=std) is True) 198 | assert (int(new_weight.mean().item()) == 0) 199 | 200 | def test_dense_selu_bias_change(self, dense_unit): 201 | """Confirm SELU bias init properties hold for dense net.""" 202 | starting_bias = copy.deepcopy(dense_unit._kernel.bias) 203 | 204 | dense_unit.bias_init = selu_bias_init_ 205 | dense_unit._init_bias() 206 | new_bias = dense_unit._kernel.bias 207 | assert (torch.equal(starting_bias, new_bias) is False) 208 | assert (round(new_bias.std().item(), 1) == 0.0) 209 | assert (int(new_bias.mean().item()) == 0) 210 | 211 | def test_conv_selu_bias_change(self, conv_unit): 212 | """Confirm SELU bias init properties hold for conv net.""" 213 | starting_bias = copy.deepcopy(conv_unit._kernel.bias) 214 | 215 | conv_unit.bias_init = selu_bias_init_ 216 | conv_unit._init_bias() 217 | new_bias = conv_unit._kernel.bias 218 | assert (torch.equal(starting_bias, new_bias) is False) 219 | assert (round(new_bias.std().item(), 1) == 0.0) 220 | assert (int(new_bias.mean().item()) == 0) 221 | 222 | 223 | class TestSeluInitTrain: 224 | """To test selu initialization properties hold during training.""" 225 | 226 | @pytest.fixture 227 | def dnn_class(self): 228 | """Create DenseNet with no prediction layer.""" 229 | return DenseNet( 230 | name='Test_DenseNet_class', 231 | in_dim=(200), 232 | activation=torch.nn.SELU(), 233 | num_classes=10, 234 | config={ 235 | 'dense_units': [100], 236 | 'dropout': [0.3], 237 | }, 238 | optim_spec={'name': 'Adam', 'lr': 0.001} 239 | ) 240 | 241 | @pytest.fixture 242 | def cnn_class(self): 243 | """Create ConvNet with prediction layer.""" 244 | return ConvNet( 245 | name='Test_ConvNet_class', 246 | in_dim=(1, 28, 28), 247 | activation=torch.nn.SELU(), 248 | config={ 249 | 'conv_units': [ 250 | { 251 | "in_channels": 1, 252 | "out_channels": 16, 253 | "kernel_size": (5, 5), 254 | "stride": 2 255 | }, 256 | { 257 | "in_channels": 16, 258 | "out_channels": 1, 259 | "kernel_size": (5, 5), 260 | "stride": 1, 261 | "padding": 2 262 | }] 263 | }, 264 | num_classes=10 265 | ) 266 | 267 | def test_selu_trained_dense(self, dnn_class): 268 | """Confirm SELU weight and bias properties hold for a dense net.""" 269 | fan_in = dnn_class.in_dim[0] 270 | std = round(math.sqrt(1. 
/ fan_in), 1) 271 | test_input = torch.rand(size=[10, *dnn_class.in_dim]).float() 272 | test_target = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) 273 | test_dataloader = DataLoader(TensorDataset(test_input, test_target)) 274 | dnn_class.fit(test_dataloader, test_dataloader, 5) 275 | trained = dnn_class.network.dense_0._kernel 276 | assert (round(trained.weight.std().item(), 1) == std) 277 | assert (int(trained.weight.mean().item()) == 0.0) 278 | assert (round(trained.bias.std().item(), 1) == 0.0) 279 | assert (int(trained.bias.mean().item()) == 0.0) 280 | 281 | def test_selu_trained_conv(self, cnn_class): 282 | """Confirm SELU weight and bias properties hold for a conv net.""" 283 | fan_in = cnn_class.network[0].in_channels * \ 284 | reduce(lambda k1, k2: k1 * k2, cnn_class.network[0].kernel_size) 285 | std = round(math.sqrt(1. / fan_in), 1) 286 | test_input = torch.rand(size=[10, *cnn_class.in_dim]).float() 287 | test_target = torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) 288 | test_dataloader = DataLoader(TensorDataset(test_input, test_target)) 289 | cnn_class.fit(test_dataloader, test_dataloader, 5) 290 | trained = cnn_class.network.conv_0._kernel 291 | assert (round(trained.weight.std().item(), 1) == std) 292 | assert (int(trained.weight.mean().item()) == 0.0) 293 | assert (round(trained.bias.std().item(), 1) == 0.0) 294 | assert (int(trained.bias.mean().item()) == 0.0) 295 | -------------------------------------------------------------------------------- /vulcanai/tests/models/test_dnn.py: -------------------------------------------------------------------------------- 1 | """Test all DenseNet capabilities.""" 2 | import pytest 3 | import numpy as np 4 | import copy 5 | import logging 6 | import os 7 | import shutil 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.data import DataLoader, Subset, TensorDataset 12 | 13 | import vulcanai 14 | from vulcanai.models import BaseNetwork 15 | from vulcanai.models.dnn import DenseNet, DenseNetConfig 16 | from vulcanai.models.utils import master_device_setter 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class TestDenseNet: 22 | """Define DenseNet test class.""" 23 | 24 | @pytest.fixture 25 | def multi_input_dnn_train_loader(self, multi_input_dnn_data): 26 | """Synthetic train data pytorch dataloader object.""" 27 | test_train = Subset(multi_input_dnn_data, 28 | range(len(multi_input_dnn_data)//2)) 29 | return DataLoader(test_train, batch_size=2) 30 | 31 | @pytest.fixture 32 | def multi_input_dnn_test_loader(self, multi_input_dnn_data): 33 | """Synthetic test data pytorch dataloader object.""" 34 | test_val = Subset(multi_input_dnn_data, 35 | range(len(multi_input_dnn_data)//2, 36 | len(multi_input_dnn_data))) 37 | return DataLoader(test_val, batch_size=2) 38 | 39 | def test_init(self, dnn_noclass): 40 | """Initialization Test of a DenseNet object.""" 41 | assert isinstance(dnn_noclass, BaseNetwork) 42 | assert isinstance(dnn_noclass, nn.Module) 43 | assert hasattr(dnn_noclass, 'network') 44 | assert hasattr(dnn_noclass, 'in_dim') 45 | assert hasattr(dnn_noclass, 'record') 46 | assert hasattr(dnn_noclass, 'device') 47 | 48 | assert dnn_noclass._name is not None 49 | assert isinstance(dnn_noclass._config, DenseNetConfig) 50 | 51 | assert dnn_noclass.input_networks is None 52 | assert dnn_noclass.epoch == 0 53 | assert dnn_noclass.optim is None 54 | assert dnn_noclass.criterion is None 55 | 56 | assert not hasattr(dnn_noclass, 'metrics') 57 | 58 | def test_function_multi_input(self, dnn_noclass, 59 | 
multi_input_dnn):
60 |         """Test functions wrt multi_input_dnn."""
61 |         assert isinstance(multi_input_dnn.input_networks, nn.ModuleDict)
62 |         assert len(list(multi_input_dnn.input_networks)) == 2
63 |         assert multi_input_dnn._merge_input_network_outputs([
64 |             torch.rand(size=[10, 1, 28]),
65 |             torch.rand(size=[10, 1, 28, 28])
66 |         ]).shape == (10, 812)
67 |         test_net = copy.deepcopy(multi_input_dnn)
68 |         test_net._add_input_network(dnn_noclass)
69 |         assert len(list(test_net.input_networks)) == 3
70 |         assert 'dnn_noclass' in test_net.input_networks
71 |         assert isinstance(test_net.input_networks['dnn_noclass'], DenseNet)
72 | 
73 |     def test_forward(self, dnn_class):
74 |         """Test Forward of DenseNet."""
75 |         out = dnn_class(torch.rand(size=[10, *dnn_class.in_dim]))
76 |         assert out.shape == (10, 3)
77 | 
78 |     def test_forward_multi_input(self, multi_input_dnn):
79 |         """Test Forward of Multi Input DenseNet."""
80 |         master_device_setter(multi_input_dnn, 'cpu')
81 |         input_tensor = [torch.rand(size=[10, 1, 28]),
82 |                         torch.rand(size=[10, 1, 28, 28])]
83 |         out = multi_input_dnn(input_tensor)
84 |         assert out.shape == (10, 50)
85 | 
86 |     def test_forward_pass_not_nan(self, dnn_noclass):
87 |         """Confirm out is non nan."""
88 |         test_input = torch.rand(size=[5, *dnn_noclass.in_dim])
89 |         test_dataloader = DataLoader(TensorDataset(test_input, test_input))
90 |         output = dnn_noclass.forward_pass(
91 |             data_loader=test_dataloader,
92 |             transform_outputs=False)
93 |         assert np.any(~np.isnan(output))
94 | 
95 |     def test_forward_pass_class_not_nan(self, dnn_class):
96 |         """Confirm out is non nan."""
97 |         test_input = torch.rand(size=[5, *dnn_class.in_dim])
98 |         test_dataloader = DataLoader(TensorDataset(test_input, test_input))
99 |         raw_output = dnn_class.forward_pass(
100 |             data_loader=test_dataloader,
101 |             transform_outputs=False)
102 |         class_output = dnn_class.metrics.transform_outputs(
103 |             in_matrix=raw_output)
104 |         assert np.any(~np.isnan(raw_output))
105 |         assert np.any(~np.isnan(class_output))
106 | 
107 |     def test_freeze_class(self, dnn_class):
108 |         """Test class network freezing."""
109 |         dnn_class.freeze(apply_inputs=False)
110 |         for params in dnn_class.network.parameters():
111 |             assert params.requires_grad is False
112 | 
113 |     def test_unfreeze_class(self, dnn_class):
114 |         """Test class network unfreezing."""
115 |         dnn_class.freeze(apply_inputs=False)
116 |         dnn_class.unfreeze(apply_inputs=False)
117 |         for params in dnn_class.network.parameters():
118 |             assert params.requires_grad is True
119 | 
120 |     def test_freeze_noclass(self, dnn_noclass):
121 |         """Test intermediate network freezing."""
122 |         dnn_noclass.freeze(apply_inputs=False)
123 |         for params in dnn_noclass.network.parameters():
124 |             assert params.requires_grad is False
125 | 
126 |     def test_unfreeze_noclass(self, dnn_noclass):
127 |         """Test intermediate network unfreezing."""
128 |         dnn_noclass.freeze(apply_inputs=False)
129 |         dnn_noclass.unfreeze(apply_inputs=False)
130 |         for params in dnn_noclass.network.parameters():
131 |             assert params.requires_grad is True
132 | 
133 |     def test_fit_multi_input(self, multi_input_dnn_class,
134 |                              multi_input_dnn_train_loader,
135 |                              multi_input_dnn_test_loader):
136 |         """Test for fit function."""
137 |         init_weights = copy.deepcopy(multi_input_dnn_class.network[0].
138 |                                      _kernel.weight.detach())
139 |         multi_input_dnn_class_no_fit = copy.deepcopy(multi_input_dnn_class)
140 |         parameters1 = multi_input_dnn_class_no_fit.parameters()
141 |         try:
142 |             multi_input_dnn_class.fit(
143 |                 multi_input_dnn_train_loader,
144 |                 multi_input_dnn_test_loader,
145 |                 2)
146 |         except RuntimeError:
147 |             logger.error("The network multi_input_dnn_class failed to train.")
148 |         finally:
149 |             parameters2 = multi_input_dnn_class.parameters()
150 |             trained_weights = multi_input_dnn_class.network[0]._kernel.weight.detach()
151 | 
152 |             # Sanity check if the network parameters are training.
153 |             # We want to be sure the weights are different.
154 |             # Hacked so that we can be sure we're properly comparing floats.
155 |             # There is no negation of np.testing.assert_almost_equal,
156 |             # so we raise an error if checking equality does not raise one.
157 |             weights_same = True
158 |             try:
159 |                 np.testing.assert_almost_equal(init_weights.cpu().numpy(),
160 |                                                trained_weights.cpu().numpy())
161 |             except AssertionError:
162 |                 weights_same = False
163 | 
164 |             if weights_same:
165 |                 raise AssertionError
166 | 
167 |             #assert not (torch.equal(init_weights.cpu(), trained_weights.cpu()))
168 |             compare_params = [not torch.allclose(param1, param2)
169 |                               for param1, param2 in zip(parameters1,
170 |                                                         parameters2)]
171 |             assert all(compare_params)
172 | 
173 |     def test_params_multi_input(self, multi_input_dnn_class,
174 |                                 multi_input_dnn_train_loader,
175 |                                 multi_input_dnn_test_loader):
176 |         """Test for change in network params/specifications."""
177 |         test_net = copy.deepcopy(multi_input_dnn_class)
178 |         # Check the parameters are copying properly
179 |         copy_params = [torch.allclose(param1, param2)
180 |                        for param1, param2 in zip(multi_input_dnn_class.parameters(),
181 |                                                  test_net.parameters())]
182 |         assert all(copy_params)
183 | 
184 |         # Check the parameters change after copy and fit
185 |         test_net.fit(
186 |             multi_input_dnn_train_loader,
187 |             multi_input_dnn_test_loader,
188 |             2)
189 |         close_params = [not torch.allclose(param1, param2)
190 |                         for param1, param2 in zip(multi_input_dnn_class.parameters(),
191 |                                                   test_net.parameters())]
192 |         assert all(close_params)
193 | 
194 |         # Check the network params and optimizer params point to
195 |         # the same memory
196 |         if test_net.optim:
197 |             assert isinstance(test_net.optim, torch.optim.Adam)
198 |             assert isinstance(test_net.criterion, torch.nn.CrossEntropyLoss)
199 |             for param, opt_param in zip(test_net.parameters(),
200 |                                         test_net.optim.param_groups[0]['params']):
201 |                 assert param is opt_param
202 | 
203 |         # Check the params after saving and loading
204 |         test_net.save_model()
205 |         save_path = test_net.save_path
206 |         abs_save_path = os.path.dirname(os.path.abspath(save_path))
207 |         loaded_test_net = BaseNetwork.load_model(load_path=save_path)
208 |         load_params = [torch.allclose(param1, param2)
209 |                        for param1, param2 in zip(test_net.parameters(),
210 |                                                  loaded_test_net.parameters())]
211 |         shutil.rmtree(abs_save_path)
212 |         assert all(load_params)
213 | 
214 | 
215 |     def test_forward_pass_class_not_nan_single_value(self,
216 |                                                      dnn_class_single_value):
217 |         """Confirm out is non nan."""
218 |         test_input = torch.rand(size=[10, *dnn_class_single_value.in_dim])
219 |         test_output = torch.rand(size=[10, 1])
220 |         test_dataloader = DataLoader(TensorDataset(test_input, test_output))
221 |         raw_output = dnn_class_single_value.forward_pass(
222 |             data_loader=test_dataloader,
223 |             transform_outputs=False)
224 |         assert np.any(~np.isnan(raw_output))
225 | 
226 |     def
test_early_stopping(self, dnn_class_early_stopping,
227 |                             dnn_class):
228 |         """ Test that their final params are different: aka
229 |         that the early stopping did something.
230 |         Constant seed resetting due to worker init function of dataloader
231 |         https://github.com/pytorch/pytorch/issues/7068 """
232 | 
233 |         ds = DataLoader(TensorDataset(
234 |             torch.rand(size=[3, *dnn_class_early_stopping.in_dim]),
235 |             torch.tensor([0, 1, 2]).long()))
236 | 
237 |         vulcanai.set_global_seed(42)
238 | 
239 |         ds2 = copy.deepcopy(ds)
240 | 
241 |         vulcanai.set_global_seed(42)
242 | 
243 |         ds3 = copy.deepcopy(ds)
244 | 
245 |         dnn_class_copy = copy.deepcopy(dnn_class)
246 | 
247 |         dnn_class_copy.fit(
248 |             train_loader=ds2,
249 |             val_loader=ds2,
250 |             epochs=5)
251 | 
252 |         vulcanai.set_global_seed(42)
253 | 
254 |         dnn_class.fit(
255 |             train_loader=ds,
256 |             val_loader=ds,
257 |             epochs=5)
258 | 
259 |         vulcanai.set_global_seed(42)
260 | 
261 |         dnn_class_early_stopping.fit(
262 |             train_loader=ds3,
263 |             val_loader=ds3,
264 |             epochs=5)
265 | 
266 |         stopping_params = list(dnn_class_early_stopping.parameters())[0][0].data
267 |         non_stopping_params = list(dnn_class.parameters())[0][0].data
268 |         non_stopping_params_copy = list(dnn_class_copy.parameters())[0][0].data
269 | 
270 |         assert not torch.eq(stopping_params, non_stopping_params).all()
271 |         assert torch.eq(non_stopping_params, non_stopping_params_copy).all()
272 | 
273 | 
-------------------------------------------------------------------------------- /vulcanai/models/utils.py: --------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """Define utilities for all networks."""
3 | from math import ceil, floor
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as f
7 | from torch.utils.data import DataLoader, TensorDataset
8 | 
9 | import numpy as np
10 | import math
11 | from sklearn.preprocessing import LabelBinarizer
12 | from collections import OrderedDict
13 | from collections import defaultdict
14 | 
15 | 
16 | def _get_probs(network, loader, index_to_iter, ls_feat_vals):
17 |     """Return the probability for each object within loader based on output
18 |     from a trained neural network.
19 | 
20 |     Parameters:
21 |         network : vulcan.model
22 |             trained vulcan network
23 |         loader : torch.dataloader
24 |             dataloader containing validation set
25 |         index_to_iter : int
26 |             index of the feature whose values will be iterated through
27 |         ls_feat_vals : list
28 |             values to iterate through for the feature at index_to_iter
29 | 
30 |     Returns:
31 |         dct_scores : dictionary
32 |             dictionary of scores
33 |     """
34 | 
35 |     dct_scores = defaultdict()
36 |     for index in range(len(loader)):
37 |         dct_scores[index] = {}
38 |     for index in range(len(loader)):
39 |         # Extract specific index from loader and create a DataLoader instance
40 |         # to send to forward_pass
41 |         input_loader = DataLoader(TensorDataset(loader.dataset[index][0]
42 |                                                 .unsqueeze(0),
43 |                                                 loader.dataset[index][1]
44 |                                                 .unsqueeze(0)))
45 | 
46 |         subj_prob = network.forward_pass(data_loader=input_loader,
47 |                                          transform_outputs=False)
48 |         # Standardize probability of positive label
49 |         subj_prob = subj_prob[0][1] * 100
50 |         subj_prob = round(subj_prob, 2)
51 |         # Add the probability to the scores dictionary, keyed first by the
52 |         # sample index and then by the feature value it belongs to.
53 |         dct_scores[index][loader.dataset[index][0][index_to_iter].item()] = \
54 |             subj_prob
55 |         # Iterate through other possible values and find probability of
56 |         # positive label and add to dictionary
57 |         # for current index.
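        # Illustrative example (hypothetical values, added for clarity): if
        # index_to_iter points at a binary smoker feature, ls_feat_vals might
        # be [0, 1], and each sample is re-scored once for every value it
        # does not already hold.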
58 |         for new_val in ls_feat_vals:
59 |             if new_val != loader.dataset[index][0][index_to_iter].item():
60 |                 loader.dataset[index][0][index_to_iter] = new_val
61 |                 input_loader = DataLoader(TensorDataset(loader
62 |                                                         .dataset[index][0]
63 |                                                         .unsqueeze(0),
64 |                                                         loader
65 |                                                         .dataset[index][1]
66 |                                                         .unsqueeze(0)))
67 |                 subj_prob = network.forward_pass(data_loader=input_loader,
68 |                                                  transform_outputs=False)
69 |                 subj_prob = subj_prob[0][1] * 100
70 |                 subj_prob = round(subj_prob, 2)
71 |                 dct_scores[index][new_val] = subj_prob
72 |     return dct_scores
73 | 
74 | 
75 | def _filter_matched_subj(dct_scores, loader, index_to_iter):
76 |     """
77 |     Return a dictionary filtered to the entries whose truly assigned value
78 |     in the actual set is among their highest-probability predicted values.
79 | 
80 |     Parameters:
81 |         dct_scores : dictionary
82 |             dictionary of probability scores produced in get_scores function
83 |         loader : torch.dataloader
84 |             dataloader containing validation set
85 |         index_to_iter : int
86 |             index of the feature to iterate through, adjusting values
87 |     Returns:
88 |         dct_filtered : dictionary
89 |             dictionary of filtered entities
90 |     """
91 |     dct_filtered = {}
92 |     for subj in list(dct_scores):
93 |         highest_prob = max(dct_scores[subj].values())
94 |         highest_val = [val for val, prob in dct_scores[subj].items()
95 |                        if prob == highest_prob]
96 |         if loader.dataset.dataset[subj][0][index_to_iter].item() \
97 |                 in highest_val:
98 |             dct_filtered[subj] = highest_prob
99 |     return dct_filtered
100 | 
101 | 
102 | def round_list(raw_list, decimals=4):
103 |     """
104 |     Return the same list with each item rounded off.
105 | 
106 |     Parameters:
107 |         raw_list : float list
108 |             float list to round.
109 |         decimals : int
110 |             How many decimal points to round to.
111 | 
112 |     Returns:
113 |         rounded_list : float list
114 |             The rounded list in the same shape as raw_list.
115 | 
116 |     """
117 |     return [round(item, decimals) for item in raw_list]
118 | 
119 | 
120 | def get_one_hot(in_matrix):
121 |     """
122 |     Reformat truth matrix to same size as the output of the dense network.
123 | 
124 |     Parameters:
125 |         in_matrix : numpy.ndarray
126 |             The categorized 1D matrix
127 | 
128 |     Returns:
129 |         one_hot : numpy.ndarray
130 |             A one-hot matrix representing the categorized matrix
131 | 
132 |     """
133 |     if in_matrix.dtype.name == 'category':
134 |         custom_array = in_matrix.cat.codes
135 | 
136 |     elif isinstance(in_matrix, np.ndarray):
137 |         custom_array = in_matrix
138 | 
139 |     else:
140 |         raise ValueError("Input matrix cannot be converted.")
141 | 
142 |     lb = LabelBinarizer()
143 |     return np.array(lb.fit_transform(custom_array), dtype='float32')
144 | 
145 | 
146 | def pad(tensor, target_shape):
147 |     """
148 |     Pad incoming tensor to the size of target_shape.
149 | 
150 |     tensor must have the same number of spatial dimensions as target_shape.
151 |     Useful for combining various conv dimension outputs and to implement
152 |     'same' padding for conv operations.
153 | 
154 |     Parameters:
155 |         tensor : torch.Tensor
156 |             Tensor to be padded
157 |         target_shape : np.array
158 |             Final padded tensor shape [*spatial_dimensions]
159 | 
160 |     Returns:
161 |         tensor : torch.Tensor
162 |             zero padded tensor with spatial dimension as target_shape
163 | 
164 |     """
165 |     # Ignore channels and batch and focus on spatial dimensions
166 |     # from incoming tensor
167 |     if not isinstance(target_shape, np.ndarray):
168 |         target_shape = np.array(target_shape)
169 |     n_dim = len(target_shape)
170 |     # Calculate, element-wise, how much needs to be padded for each dim.
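    # Worked example (hypothetical shapes, added for clarity): target_shape
    # (8, 8) against trailing spatial dims (5, 6) gives dims_size_diff
    # (3, 2); iterating in reverse builds padding_needed = [1, 1, 2, 1] in
    # F.pad's last-dim-first order, padding the last dim by 1 + 1 and the
    # first spatial dim by 2 + 1.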
182 | # noinspection PyTypeChecker
183 | def network_summary(network, input_size=None):
184 |     """
185 |     Return a summary of the shapes of all layers in the network.
186 |     Returns an OrderedDict with the shape of each layer in the network.
187 |     """
188 |     if not input_size:
189 |         input_size = []
190 |         for net in network.input_networks.values():
191 |             input_size.append(net.in_dim)
192 |     # input_size must be a list
193 |     if isinstance(input_size, (tuple, int)):
194 |         input_size = [input_size]
195 | 
196 |     def get_size(summary_dict, output):
197 |         """
198 |         Helper function for the BaseNetwork's get_output_shapes.
199 |         """
200 |         if isinstance(output, tuple):
201 |             for i in range(len(output)):
202 |                 summary_dict[i] = OrderedDict()
203 |                 summary_dict[i] = get_size(summary_dict[i], output[i])
204 |         else:
205 |             summary_dict['output_shape'] = tuple(output.size())
206 |         return summary_dict
207 | 
208 |     # noinspection PyUnresolvedReferences,PyUnresolvedReferences
209 |     def register_hook(module):
210 |         """
211 |         Register a forward hook on the module to record its shapes.
212 |         For more info: https://pytorch.org/docs/stable/_modules/torch/
213 |         tensor.html#Tensor.register_hook
214 | 
215 |         """
216 | 
217 |         # noinspection PyShadowingNames, PyUnresolvedReferences
218 |         def hook(module, input, output):
219 |             """
220 |             https://github.com/pytorch/tutorials/blob/8afce8a213cb3712aa7de
221 |             1e1cf158da765f029a7/beginner_source/former_torchies/nn_tutorial
222 |             .py#L146
223 |             """
224 |             class_name = str(module.__class__).split('.')[-1].split("'")[0]
225 |             module_idx = len(summary)
226 |             # Build a unique, readable key for this module.
227 |             m_key = '%s-%i' % (class_name, module_idx + 1)
228 |             summary[m_key] = OrderedDict()
229 |             summary[m_key]['input_shape'] = tuple(input[0].size())
230 |             summary[m_key] = get_size(summary[m_key], output)
231 |             # Count weight and bias parameters for this module.
232 |             params = 0
233 |             if hasattr(module, 'weight'):
234 |                 params += torch.prod(torch.LongTensor(tuple(
235 |                     module.weight.size())))
236 |                 if module.weight.requires_grad:
237 |                     summary[m_key]['trainable'] = True
238 |                 else:
239 |                     summary[m_key]['trainable'] = False
240 |             if hasattr(module, 'bias'):
241 |                 params += torch.prod(torch.LongTensor(tuple(
242 |                     module.bias.size())))
243 |             # Record the total parameter count.
244 |             summary[m_key]['nb_params'] = params
245 |         if not isinstance(module, torch.nn.Sequential) and \
246 |                 not isinstance(module, torch.nn.ModuleList) and \
247 |                 not (module == network):
248 |             hooks.append(module.register_forward_hook(hook))
249 | 
250 |     x = []
251 |     for in_size in input_size:
252 |         x.append(torch.empty(1, *in_size))
253 | 
254 |     if len(x) == 1:
255 |         x = torch.cat(x, dim=1)
256 | 
257 |     # create properties
258 |     summary = OrderedDict()
259 |     hooks = []
260 | 
261 |     # register hook
262 |     network.apply(register_hook)
263 |     # make a forward pass
264 |     network.cpu()(x)
265 | 
266 |     # remove these hooks
267 |     for h in hooks:
268 |         h.remove()
269 | 
270 |     return summary
271 | 
272 | 
273 | def print_model_structure(network):
274 |     """Print the entire model structure."""
275 |     shapes = network_summary(network)
276 |     for k, v in shapes.items():
277 |         print('{}:'.format(k))
278 |         if isinstance(v, OrderedDict):
279 |             for k2, v2 in v.items():
280 |                 print('\t{}: {}'.format(k2, v2))
281 | 
282 | 
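# --- Editor's note: hypothetical usage sketch, not part of utils.py ----------
# network_summary() registers a forward hook on every leaf module, pushes one
# dummy batch through the network on the CPU, and collects per-layer
# input/output shapes plus parameter counts. Assuming `net` is a built vulcan
# network with in_dim (10,) (the name is a stand-in):
#
#     summary = network_summary(net, input_size=[(10,)])
#     for layer_name, info in summary.items():
#         print(layer_name, info.get('output_shape'), info.get('nb_params'))
#
#     print_model_structure(net)   # the same information, pretty-printed
# ------------------------------------------------------------------------------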
283 | # noinspection PyProtectedMember
284 | def selu_weight_init_(tensor, mean=0.0):
285 |     """
286 |     SELU layer weight initialization function.
287 | 
288 |     Function assigned to a variable that will be called within the
289 |     _init_weights function to assign weights for SELU layers.
290 | 
291 |     Parameters:
292 |         tensor : torch.Tensor
293 |             Weight tensor to be adjusted
294 |         mean : float
295 |             Mean value for the normal distribution
296 | 
297 |     Returns:
298 |         torch.Tensor
299 |             weight tensor drawn from a normal distribution
300 | 
301 |     """
302 |     with torch.no_grad():
303 |         fan_in, _ = nn.init._calculate_fan_in_and_fan_out(tensor)
304 |         std = math.sqrt(1. / fan_in)
305 |         return nn.init.normal_(tensor, mean, std)
306 | 
307 | 
308 | def selu_bias_init_(tensor, const=0.0):
309 |     """
310 |     SELU layer bias initialization function.
311 | 
312 |     Function assigned to a variable that will be called within the
313 |     _init_bias function to assign biases for SELU layers.
314 | 
315 |     Parameters:
316 |         tensor : torch.Tensor
317 |             Bias tensor to be adjusted
318 |         const : float
319 |             Constant value to be assigned to the tensor.
320 | 
321 |     Returns:
322 |         torch.Tensor
323 |             bias tensor with constant values.
324 | 
325 |     """
326 |     with torch.no_grad():
327 |         return nn.init.constant_(tensor, const)
328 | 
329 | 
330 | def set_tensor_device(data, device=None):
331 |     """
332 |     Convert a tensor or list of data tensors to the specified device.
333 | 
334 |     Parameters:
335 |         data : torch.Tensor or list
336 |             data to be converted to the specified device.
337 |         device : str or torch.device
338 |             the desired device
339 | 
340 |     Returns:
341 |         data : torch.Tensor or list
342 |             data converted to the specified device
343 | 
344 |     """
345 |     if not isinstance(data, (list, tuple)):
346 |         data = data.to(device=device, non_blocking=True)
347 | 
348 |     else:
349 |         for idx, d in enumerate(data):
350 |             data[idx] = set_tensor_device(d, device=device)
351 |     return data
352 | 
353 | 
354 | def master_device_setter(network, device=None):
355 |     """
356 |     Convert a network and its input_networks to the specified device.
357 | 
358 |     Parameters:
359 |         network : BaseNetwork
360 |             network to be converted to the specified device.
361 |         device : str or torch.device
362 |             the desired device
363 | 
364 |     """
365 |     network.device = device
366 |     if network.input_networks:
367 |         for net in network.input_networks.values():
368 |             master_device_setter(net, device)
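# --- Editor's note: hypothetical usage sketch, not part of utils.py ----------
# set_tensor_device() recurses through (nested) lists of tensors, which is how
# multi-input batches can be moved in a single call; master_device_setter()
# walks network.input_networks so the device assignment reaches every
# sub-network of a multi-input model. The tensors and `net` are stand-ins.
#
#     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#     batch = [torch.rand(4, 10), [torch.rand(4, 3), torch.rand(4, 5)]]
#     batch = set_tensor_device(batch, device=device)
#     master_device_setter(net, device)   # `net` is a hypothetical BaseNetwork
# ------------------------------------------------------------------------------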
369 | 
--------------------------------------------------------------------------------
/vulcanai/plotters/visualization.py:
--------------------------------------------------------------------------------
1 | """Contains all visualization methods."""
2 | import os
3 | 
4 | import numpy as np
5 | 
6 | from math import sqrt, ceil, floor
7 | import pickle
8 | from datetime import datetime
9 | from .utils import GuidedBackprop, get_notable_indices
10 | 
11 | from sklearn.manifold import TSNE
12 | from sklearn.decomposition import PCA
13 | 
14 | import matplotlib
15 | from mpl_toolkits.axes_grid1 import make_axes_locatable
16 | import matplotlib.pyplot as plt
17 | import seaborn as sns
18 | 
19 | import itertools
20 | import logging
21 | logger = logging.getLogger(__name__)
22 | 
23 | DISPLAY_AVAILABLE = bool(os.environ.get("DISPLAY"))
24 | 
25 | 
26 | def save_visualization(plot, path=None):
27 |     """
28 |     Save a plot at the designated path.
29 | 
30 |     Parameters:
31 |         plot : matplotlib
32 |             Matplotlib module or figure with savefig ability.
33 |         path : string
34 |             Path to save the given figure to.
35 | 
36 |     Returns:
37 |         None
38 | 
39 |     """
40 |     plot.savefig(path)
41 |     logger.info("Saved visualization at %s", path)
42 | 
43 | 
44 | def get_save_path(path, vis_type):
45 |     """Return a save_path string."""
46 |     path = "{}{}_{date:%Y-%m-%d_%H-%M-%S}.png".format(
47 |         path, vis_type, date=datetime.now())
48 |     return path
49 | 
50 | 
51 | def display_record(record=None, save_path=None, interactive=True):
52 |     """
53 |     Display the training curve for a network training session.
54 | 
55 |     Parameters:
56 |         record : dict
57 |             the network record dictionary for dynamic graphs during training.
58 |         save_path : string
59 |             Path to save the produced figure to.
60 |             save_path must be a proper path that ends with a filename with
61 |             an image filetype.
62 |         interactive : boolean
63 |             To display during training or afterwards.
64 | 
65 |     Returns:
66 |         None
67 | 
68 |     """
69 |     title = 'Training curve'
70 |     if record is None or not isinstance(record, dict):
71 |         raise ValueError('No record exists and cannot be displayed.')
72 | 
73 |     plt.subplot(1, 2, 1)
74 |     plt.title("{}: Error".format(title))
75 |     train_error, = plt.plot(
76 |         record['epoch'],
77 |         record['train_error'],
78 |         '-mo',
79 |         label='Train Error'
80 |     )
81 |     # val_error = \
82 |     #     [i if ~np.isnan(i) else None for i in record['validation_error']]
83 |     validation_error, = plt.plot(
84 |         record['epoch'],
85 |         record['validation_error'],
86 |         '-ro',
87 |         label='Validation Error'
88 |     )
89 |     plt.xlabel("Epoch")
90 |     plt.ylabel("Cross entropy error")
91 |     plt.legend(handles=[train_error,
92 |                         validation_error],
93 |                loc=0)
94 | 
95 |     plt.subplot(1, 2, 2)
96 |     plt.title("{}: Accuracy".format(title))
97 |     train_accuracy, = plt.plot(
98 |         record['epoch'],
99 |         record['train_accuracy'],
100 |         '-go',
101 |         label='Train Accuracy'
102 |     )
103 |     validation_accuracy, = plt.plot(
104 |         record['epoch'],
105 |         record['validation_accuracy'],
106 |         '-bo',
107 |         label='Validation Accuracy'
108 |     )
109 |     plt.xlabel("Epoch")
110 |     plt.ylabel("Accuracy")
111 |     plt.ylim(0, 1)
112 | 
113 |     plt.legend(handles=[train_accuracy,
114 |                         validation_accuracy],
115 |                loc=0)
116 | 
117 |     if save_path:
118 |         save_visualization(plt, save_path)
119 | 
120 |     if not DISPLAY_AVAILABLE and save_path is None:
121 |         raise RuntimeError(
122 |             "No display environment found. "
123 |             "Display environment needed to plot, "
124 |             "or set save_path=path/to/dir")
125 | 
126 |     elif interactive is True:
127 |         plt.draw()
128 |         plt.pause(1e-17)
129 | 
130 | 
131 | def display_pca(input_data, targets, label_map=None, save_path=None):
132 |     """
133 |     Calculate a PCA reduction and plot it.
134 | 
135 |     Parameters:
136 |         input_data : numpy.ndarray
137 |             Input data to reduce in dimensions.
138 |         targets : numpy.ndarray
139 |             size (batch, labels) for samples.
140 |         label_map : dict
141 |             labelled {str(int), string} key, value pairs.
142 |         save_path : string
143 |             Path to save the produced figure to.
144 | 
145 |     """
146 |     pca = PCA(n_components=2, random_state=0)
147 |     x_transform = pca.fit_transform(input_data)
148 |     _plot_reduction(
149 |         x_transform,
150 |         targets,
151 |         label_map=label_map,
152 |         title='PCA Visualization',
153 |         save_path=save_path)
154 | 
155 | 
156 | def display_tsne(input_data, targets, label_map=None, save_path=None):
157 |     """
158 |     t-distributed Stochastic Neighbor Embedding (t-SNE) visualization [1].
159 | 
160 |     [1]: Maaten, L., Hinton, G. (2008). Visualizing Data using t-SNE.
161 |         JMLR 9(Nov):2579--2605.
162 | 
163 |     Parameters:
164 |         input_data : numpy.ndarray
165 |             Input data to reduce in dimensions.
166 |         targets : numpy.ndarray
167 |             size (batch, labels) for samples.
168 |         label_map : dict
169 |             labelled {str(int), string} key, value pairs.
170 |         save_path : string
171 |             Path to save the produced figure to.
172 | 
173 |     """
174 |     tsne = TSNE(n_components=2, random_state=0)
175 |     x_transform = tsne.fit_transform(input_data)
176 |     _plot_reduction(
177 |         x_transform,
178 |         targets,
179 |         label_map=label_map,
180 |         title='t-SNE Visualization',
181 |         save_path=save_path)
182 | 
183 | 
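# --- Editor's note: hypothetical usage sketch, not part of visualization.py --
# Both projections reduce the input to two components and hand off to
# _plot_reduction (defined below). A minimal sketch on random data; every name
# here is a stand-in, and the plots/ directory is assumed to exist:
#
#     x = np.random.rand(100, 20)               # 100 samples, 20 features
#     y = np.random.randint(0, 3, size=100)     # 3 integer classes
#     label_map = {'0': 'cat', '1': 'dog', '2': 'bird'}
#     display_pca(x, y, label_map=label_map, save_path='plots/')
#     display_tsne(x, y, label_map=label_map, save_path='plots/')
# ------------------------------------------------------------------------------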
184 | def _plot_reduction(x_transform, targets, label_map, title, save_path=None):
185 |     """Plot a 2D reduction once PCA or t-SNE has been calculated."""
186 |     y_unique = np.unique(targets)
187 |     if label_map is None:
188 |         label_map = {str(i): str(i) for i in y_unique}
189 |     elif not isinstance(label_map, dict):
190 |         raise ValueError('label_map must be a dict mapping each key'
191 |                          ' to its true label')
192 |     colours = np.array(sns.color_palette("hls", len(y_unique)))
193 |     plt.figure()
194 |     for index, cl in enumerate(y_unique):
195 |         plt.scatter(x=x_transform[targets == cl, 0],
196 |                     y=x_transform[targets == cl, 1],
197 |                     s=100,
198 |                     c=colours[index],
199 |                     alpha=0.5,
200 |                     marker='o',
201 |                     edgecolors='none',
202 |                     label=label_map[str(cl)])
203 |     plt.xlabel('X')
204 |     plt.ylabel('Y')
205 |     plt.legend(loc='upper right')
206 |     plt.title(title)
207 | 
208 |     if save_path:
209 |         save_path = get_save_path(save_path, vis_type=title)
210 |         save_visualization(plt, save_path)
211 | 
212 |     if not DISPLAY_AVAILABLE and save_path is None:
213 |         raise RuntimeError(
214 |             "No display environment found. "
215 |             "Display environment needed to plot, "
216 |             "or set save_path=path/to/dir")
217 |     else:
218 |         plt.show(block=False)
219 | 
220 | 
221 | def display_confusion_matrix(cm, class_list=None, save_path=None):
222 |     """
223 |     Print and plot the confusion matrix.
224 | 
225 |     Inspired by: https://github.com/zaidalyafeai/Machine-Learning
226 | 
227 |     Parameters:
228 |         cm : numpy.ndarray
229 |             2D confusion_matrix obtained using utils.get_confusion_matrix
230 |         class_list : list
231 |             Actual class labels (e.g.: MNIST - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
232 |         save_path : string
233 |             Path to save the produced figure to.
234 | 
235 |     """
236 |     if class_list is None:
237 |         class_list = list(range(cm.shape[0]))
238 |     if not isinstance(class_list, list):
239 |         raise ValueError("class_list must be of type list.")
240 |     plt.figure()
241 |     ax = plt.gca()
242 |     im = ax.imshow(cm, interpolation='nearest', cmap='Blues', origin='lower')
243 | 
244 |     plt.title('Confusion matrix')
245 |     tick_marks = np.arange(len(class_list))
246 |     plt.xticks(tick_marks, class_list, rotation=45)
247 |     plt.yticks(tick_marks, class_list)
248 |     # Plot number overlay
249 |     thresh = cm.max() / 2.
250 |     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
251 |         plt.text(j, i, format(cm[i, j], 'd'),
252 |                  horizontalalignment="center",
253 |                  color="white" if cm[i, j] > thresh else "black")
254 |     # Plot labels
255 |     plt.tight_layout()
256 |     plt.ylabel('True label')
257 |     plt.xlabel('Predicted label')
258 |     divider = make_axes_locatable(ax)
259 |     cax = divider.append_axes("right", size="4%", pad=0.03)
260 |     plt.colorbar(im, cax=cax)
261 | 
262 |     if save_path:
263 |         save_path = get_save_path(save_path, vis_type='confusion_matrix')
264 |         save_visualization(plt, save_path)
265 | 
266 |     if not DISPLAY_AVAILABLE and save_path is None:
267 |         raise RuntimeError(
268 |             "No display environment found. "
269 |             "Display environment needed to plot, "
270 |             "or set save_path=path/to/dir")
271 |     else:
272 |         plt.show(block=False)
273 | 
274 | 
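# --- Editor's note: hypothetical usage sketch, not part of visualization.py --
# display_confusion_matrix expects a square integer matrix such as the one
# produced by scikit-learn; the labels below are made up:
#
#     from sklearn.metrics import confusion_matrix
#     y_true = [0, 1, 2, 2, 1, 0]
#     y_pred = [0, 2, 2, 2, 1, 0]
#     cm = confusion_matrix(y_true, y_pred)
#     display_confusion_matrix(cm, class_list=[0, 1, 2], save_path='plots/')
# ------------------------------------------------------------------------------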
275 | def compute_saliency_map(network, input_data, targets):
276 |     """
277 |     Return the saliency map using the guided backpropagation method [1].
278 | 
279 |     [1]: Springenberg, J.T., Dosovitskiy, A., Brox, T., Riedmiller, M.
280 |         (2015). Striving for Simplicity: The All Convolutional Net. ICLR 2015
281 |         (https://arxiv.org/pdf/1412.6806.pdf)
282 | 
283 |     Parameters:
284 |         network : BaseNetwork
285 |             A network to get saliency maps on.
286 |         input_data : numpy.ndarray
287 |             Input array of shape (batch, channel, width, height) or
288 |             (batch, channel, width, height, depth).
289 |         targets : numpy.ndarray
290 |             1D array with class targets of size [batch].
291 | 
292 |     Returns:
293 |         saliency_map : list of numpy.ndarray
294 |             Top layer gradients of the same shape as input data.
295 | 
296 |     """
297 |     guided_backprop = GuidedBackprop(network)
298 |     saliency_map = guided_backprop.generate_gradients(input_data, targets)
299 |     return saliency_map
300 | 
301 | 
302 | def display_saliency_overlay(image, saliency_map, shape=(28, 28), save_path=None):
303 |     """
304 |     Plot the saliency map overlaid on the image.
305 | 
306 |     Parameters:
307 |         image : numpy.ndarray
308 |             (1D, 2D, 3D) for single image or linear output.
309 |         saliency_map : numpy.ndarray
310 |             (1D, 2D, 3D) for single image or linear output.
311 |         shape : tuple, list
312 |             The dimensions of the image. Defaults to MNIST (28, 28).
313 |         save_path : string
314 |             Path to save the produced figure to.
315 | 
316 |     """
317 |     # Handle different colour channels and shapes for image input
318 |     if len(image.shape) == 3:
319 |         if image.shape[0] == 1:
320 |             # For 1 colour channel, remove it
321 |             image = image[0]
322 |         elif image.shape[0] == 3 or image.shape[0] == 4:
323 |             # For 3 or 4 colour channels, move to end for plotting
324 |             image = np.moveaxis(image, 0, -1)
325 |         else:
326 |             raise ValueError("Invalid number of colour channels in input.")
327 |     elif len(image.shape) == 1:
328 |         image = np.reshape(image, shape)
329 | 
330 |     # Handle different colour channels and shapes for saliency map
331 |     if len(saliency_map.shape) == 3:
332 |         if saliency_map.shape[0] == 1:
333 |             # For 1 colour channel, remove it
334 |             saliency_map = saliency_map[0]
335 |         elif saliency_map.shape[0] == 3 or saliency_map.shape[0] == 4:
336 |             # For 3 or 4 colour channels, move to end for plotting
337 |             saliency_map = np.moveaxis(saliency_map, 0, -1)
338 |         else:
339 |             raise ValueError("Invalid number of channels in saliency map.")
340 |     elif len(saliency_map.shape) == 1:
341 |         saliency_map = np.reshape(saliency_map, shape)
342 | 
343 |     fig = plt.figure()
344 |     fig.suptitle("Saliency Map")
345 |     # Plot original image
346 |     fig.add_subplot(1, 2, 1)
347 |     plt.imshow(image, cmap='gray')
348 |     # Plot original with saliency overlay
349 |     fig.add_subplot(1, 2, 2)
350 |     plt.imshow(image, cmap='binary')
351 | 
352 |     ax = plt.gca()
353 |     im = ax.imshow(saliency_map, cmap='Blues', alpha=0.7)
354 |     divider = make_axes_locatable(ax)
355 |     cax = divider.append_axes("right", size="4%", pad=0.03)
356 |     plt.colorbar(im, cax=cax, format='%.0e')
357 | 
358 |     if save_path:
359 |         save_path = get_save_path(save_path, vis_type='saliency_map')
360 |         save_visualization(plt, save_path)
361 | 
362 |     if not DISPLAY_AVAILABLE and save_path is None:
363 |         raise RuntimeError(
364 |             "No display environment found. "
365 |             "Display environment needed to plot, "
366 |             "or set save_path=path/to/dir")
367 |     else:
368 |         plt.show(block=False)
369 | 
370 | 
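# --- Editor's note: hypothetical usage sketch, not part of visualization.py --
# A typical saliency workflow pairs the two functions above: compute the
# guided-backprop gradients for a batch, then overlay one map on its input.
# Assuming `net` is a trained vulcan ConvNet on MNIST-like data (the names
# and data are stand-ins):
#
#     images = np.random.rand(4, 1, 28, 28).astype('float32')
#     labels = np.array([3, 1, 4, 1])
#     maps = compute_saliency_map(net, images, labels)
#     display_saliency_overlay(images[0], maps[0], shape=(28, 28),
#                              save_path='plots/')
# ------------------------------------------------------------------------------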
371 | def display_receptive_fields(network, top_k=5, save_path=None):
372 |     """
373 |     Display receptive fields of layers from a network [1].
374 | 
375 |     [1]: Luo, W., Li, Y., Urtasun, R., Zemel, R. (2016).
376 |         Understanding the Effective Receptive Field in Deep
377 |         Convolutional Neural Networks. Advances in Neural Information
378 |         Processing Systems, 29 (NIPS 2016)
379 | 
380 |     Parameters:
381 |         network : BaseNetwork
382 |             Network to get receptive fields of.
383 |         top_k : int
384 |             To return the most and least k important features from field
385 |         save_path : string
386 |             Path to save the produced figure to.
387 | 
388 |     Returns:
389 |         k_features : dict
390 |             A dict of the top k and bottom k important features.
391 | 
392 |     """
393 |     if type(network).__name__ == "ConvNet":
394 |         raise NotImplementedError
395 |     elif '_input_network' in network._modules:
396 |         if type(network._modules['_input_network']).__name__ == "ConvNet":
397 |             raise NotImplementedError
398 | 
399 |     feature_importance = {}
400 |     fig = plt.figure()
401 |     fig.suptitle("Feature importance")
402 |     num_layers = len(network._modules['network'])
403 |     for i, layer in enumerate(network._modules['network']):
404 |         raw_field = layer._kernel._parameters['weight'].cpu().detach().numpy()
405 |         field = np.average(raw_field, axis=0)  # average all outgoing
406 |         field_shape = [
407 |             floor(sqrt(field.shape[0])),
408 |             ceil(sqrt(field.shape[0]))
409 |         ]
410 |         fig.add_subplot(
411 |             floor(sqrt(num_layers)),
412 |             ceil(sqrt(num_layers)),
413 |             i + 1)
414 |         field = abs(field)
415 |         feats = get_notable_indices(field, top_k=top_k)
416 |         unit_type = type(layer).__name__
417 |         layer_name = '{}_{}'.format(unit_type, i)
418 |         feature_importance.update({layer_name: feats})
419 |         plt.title(layer_name)
420 |         plt.imshow(np.resize(field, field_shape), cmap='Blues')
421 |         plt.colorbar()
422 | 
423 |     if save_path:
424 |         save_path = get_save_path(save_path, vis_type='feature_importance')
425 |         save_visualization(plt, save_path)
426 | 
427 |     if not DISPLAY_AVAILABLE and save_path is None:
428 |         raise RuntimeError(
429 |             "No display environment found. "
430 |             "Display environment needed to plot, "
431 |             "or set save_path=path/to/dir")
432 |     else:
433 |         plt.show(block=False)
434 | 
435 |     return feature_importance
436 | 
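# --- Editor's note: hypothetical usage sketch, not part of visualization.py --
# display_receptive_fields averages the absolute incoming weights of each
# dense layer and reports the top_k most and least important input features.
# Assuming `net` is a trained vulcan DenseNet (the name is a stand-in):
#
#     k_features = display_receptive_fields(net, top_k=5, save_path='plots/')
#     # k_features maps layer names such as 'DenseUnit_0' to the notable
#     # feature indices returned by get_notable_indices.
# ------------------------------------------------------------------------------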
--------------------------------------------------------------------------------