├── .flake8 ├── .github └── workflows │ └── main.yml ├── .gitignore ├── .gitlab-ci.yml ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── LICENSE.txt ├── README.md ├── deepstruct ├── __init__.py ├── __version__.py ├── constants.py ├── dataset.py ├── flexible_transform.py ├── graph.py ├── models.py ├── node_map_strategies.py ├── pruning │ ├── __init__.py │ ├── engine.py │ ├── strategy.py │ └── util.py ├── recurrent.py ├── scalable.py ├── sparse.py ├── transform.py ├── traverse_strategies.py └── util.py ├── development.md ├── docs ├── artificial-landscape-approximation.png ├── index.md ├── logo-wide.png ├── logo.png ├── masked-deep-cell-dan.png ├── masked-deep-dan.png ├── masked-deep-ffn.png ├── methods-pruning-growing.graphml ├── methods-pruning-growing.png ├── sparse-network.graphml └── sparse-network.png ├── mkdocs.yml ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── graph ├── __init__.py ├── test_CachedLayeredGraph.py ├── test_LabeledDAG.py └── test_LayerIndex.py ├── large ├── __init__.py ├── test_GraphTransform.py └── test_flexible_transform.py ├── models.py ├── models ├── __init__.py └── test_DeepGraphModule.py ├── pruning ├── __init__.py └── test_engine_basic.py ├── scalable └── test_scalable_dan.py ├── sparse ├── __init__.py ├── test_DeepCellDAN.py ├── test_MaskedDeepDAN.py ├── test_MaskedDeepFFN.py └── test_MaskedLinearLayer.py ├── test_MaskableModule.py ├── test_flexible_transform.py ├── test_low_level_rep.py ├── test_recurrent.py ├── training ├── __init__.py └── test_MaskedLinearLayer.py ├── transform ├── __init__.py ├── test_Conv2dLayerFunctor.py ├── test_GraphTransform.py └── test_LinearLayerFunctor.py └── utils.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | ignore = E203, E501, W503 4 | per-file-ignores = __init__.py:F401 5 | exclude = 6 | .git 7 | __pycache__ 8 | setup.py 9 | build 10 | dist 11 | releases 12 | .venv 13 | .tox 14 | .mypy_cache 15 | 
.pytest_cache 16 | .vscode 17 | .github 18 | res 19 | tests/cache/ 20 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | Linting: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - uses: actions/checkout@v3 11 | - name: Set up Python 3.9 12 | uses: actions/setup-python@v4 13 | with: 14 | python-version: 3.9 15 | - name: Get full Python version 16 | id: full-python-version 17 | run: echo version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))") >> $GITHUB_OUTPUT 18 | - name: Linting 19 | run: | 20 | pip install pre-commit 21 | pre-commit run --all-files 22 | 23 | Tests: 24 | needs: Linting 25 | name: ${{ matrix.os }} (${{ matrix.python_version }}) 26 | runs-on: ${{ matrix.os }}-latest 27 | strategy: 28 | matrix: 29 | os: [Ubuntu, MacOS] 30 | python_version: ["3.9", "3.10", "3.11"] 31 | steps: 32 | - uses: actions/checkout@v3 33 | 34 | - name: Set up Python ${{ matrix.python_version }} 35 | uses: actions/setup-python@v4 36 | with: 37 | python-version: ${{ matrix.python_version }} 38 | 39 | - name: Get full python version 40 | id: full-python-version 41 | shell: bash 42 | run: echo ::set-output name=version::$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))") 43 | 44 | - name: Python Poetry Action 45 | uses: abatilo/actions-poetry@v2 46 | 47 | - name: Configure poetry 48 | shell: bash 49 | run: poetry config virtualenvs.in-project true 50 | 51 | - name: Set up cache 52 | uses: actions/cache@v2 53 | id: cache 54 | with: 55 | path: .venv 56 | key: venv-${{ runner.os }}-${{ steps.full-python-version.outputs.version }}-${{ hashFiles('**/poetry.lock') }} 57 | 58 | - name: Ensure cache is healthy 59 | if: steps.cache.outputs.cache-hit == 'true' 60 | shell: bash 61 | run: timeout 5s poetry run pip --version || rm -rf 
.venv 62 | 63 | - name: Install dependencies 64 | shell: bash 65 | run: poetry install 66 | 67 | - name: Run pytest 68 | shell: bash 69 | run: poetry run pytest -q tests --ignore=tests/large/ 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary files 2 | __pycache__/ 3 | *~ 4 | tests/cache/ 5 | 6 | # Build files and directories 7 | .venv/ 8 | src/ 9 | build/ 10 | *.c 11 | *.so 12 | dist/ 13 | deepstruct.egg* 14 | 15 | # Log-directories 16 | tensorboard/ 17 | 18 | # IDE-specific files 19 | .idea/ 20 | *.iml 21 | .ipynb_checkpoints/ 22 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | cache: 2 | key: "deepstruct-${CI_JOB_NAME}" 3 | paths: 4 | - .cache/pip 5 | - .venv 6 | 7 | stages: 8 | - stage-quality 9 | - stage-tests 10 | 11 | .install-deps-template: &install-deps 12 | before_script: 13 | - pip install poetry 14 | - poetry --version 15 | - poetry config virtualenvs.in-project true 16 | - poetry install -vv 17 | 18 | .test-template: &test 19 | <<: *install-deps 20 | stage: stage-tests 21 | coverage: '/TOTAL.*\s(\d+\.\d+\%)/' 22 | script: 23 | - poetry run pytest -q tests --ignore=tests/large/ 24 | artifacts: 25 | paths: 26 | - tests/logs 27 | when: always 28 | expire_in: 1 week 29 | 30 | # Test Jobs 31 | test-python3.8: 32 | <<: *test 33 | image: python:3.8 34 | 35 | test-python3.9: 36 | <<: *test 37 | image: python:3.9 38 | 39 | test-python3.10: 40 | <<: *test 41 | image: python:3.10 42 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 24.1.1 4 | hooks: 5 | - id: black 6 | 7 | - repo: 
https://github.com/pycqa/flake8 8 | rev: 7.0.0 9 | hooks: 10 | - id: flake8 11 | 12 | - repo: https://github.com/timothycrosley/isort 13 | rev: 5.13.2 14 | hooks: 15 | - id: isort 16 | additional_dependencies: [toml] 17 | 18 | - repo: https://github.com/pre-commit/pre-commit-hooks 19 | rev: v4.5.0 20 | hooks: 21 | - id: trailing-whitespace 22 | exclude: ^tests/.*/fixtures/.* 23 | - id: end-of-file-fixer 24 | exclude: ^tests/.*/fixtures/.* 25 | - id: debug-statements 26 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog for deepstruct 2 | 3 | ## 0.10 4 | * bug-fix for 'RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation' caused by in-place operation in model (thanks to Mohammad Alahmad) 5 | * new methods for scalable models from an underlying computational theme structure 6 | * more explicit methods to store and re-load a cached layered graph 7 | * some documentation on base module *MaskedLinearLayer* 8 | * dependency updates 9 | 10 | ## 0.9 11 | * re-introduced saliency as an optional additional property on MaskedLinearLayers for communicating saliency measures on weight-level to decide on further pruning 12 | * fixed some of the simpler pruning functions such as prune_network_by_saliency() and prune_layer_by_saliency() from deepstruct.pruning 13 | * masks up to now do not consider bias vectors which might be unexpected behaviour 14 | 15 | ## 0.8 16 | * deprecation of learning utilities 17 | * integrated additional normalization layers 18 | * masks on maskable layers are parameterizable to investigate on structural regularization ideas 19 | * functional dataset can now be easily stored in a pickle file 20 | 21 | ## 0.7 22 | * new minimal version requirement is python 3.7 23 | * introduced interface for "functors" which transform an nn.Module into a directed acyclic 
graph 24 | * created a first functor for Linear and MaskedLinear layers 25 | * a graph transform class passes a random input through a generic module and can transform it into a graph given that it consists of linear or conv2d layers (first tests added) 26 | * added mkdocs to provide an initial documentation skeleton 27 | 28 | ## 0.6 29 | * introduced *BaseRecurrentLayer*, *MaskedRecurrentLayer*, *MaskedGRULayer*, *MaskedLSTMLayer* 30 | * introduced *deepstruct.recurrent.MaskedDeepRNN* for sparse recurrent models 31 | 32 | ## 0.5 33 | * new feature: concept of scalable families which is a first notion of *graph themes* analysis 34 | * file restructuring for better semantics 35 | * pypaddle will be renamed to deepstruct 36 | 37 | ## 0.4 38 | * switched to poetry for dependency and build management 39 | * added integration tests 40 | * switched to pytest instead of unittest 41 | 42 | ## 0.3 43 | * added support to define input shape for MaskedDeepFFN and MaskedDeepDAN 44 | * changed parameter for recompute_mask(epsilon) to recompute_mask(theta) as it should denote a threshold 45 | * implemented a first running version of a randomly wired cell network, more general than RandWireNN and in spirit of analysing graph theoretic properties 46 | * bugfixes on generating structures from masks 47 | * added/modified data loader utilities for mnist/cifar (probably no official part and concern of this library tools) 48 | * fixed PyPi setup and tested installation routine 49 | * defined networkx and torch as dependencies in setup.py. 
Next will be to check if it can be shadowed by pytorch packages from conda channels 50 | * added a DeepCellDAN() which builds directed, acyclic networks with customized cells given a certain structure 51 | 52 | ## 0.2 53 | * introduced LayeredGraph as a wrapper for directed graphs which provides access to its layered ordering 54 | * central provided modules are MaskedLinearLayer, MaskedDeepFFN and MaskedDeepDAN 55 | * provided first functionality to generate structures from masked modules 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deepstruct - neural network structure tool [![PyPI version](https://badge.fury.io/py/deepstruct.svg)](https://badge.fury.io/py/deepstruct) ![Tests](https://github.com/innvariant/deepstruct/workflows/Tests/badge.svg) [![Documentation Status](https://readthedocs.org/projects/deepstruct/badge/?version=latest)](https://deepstruct.readthedocs.io/en/latest/?badge=latest) [![Downloads](https://pepy.tech/badge/deepstruct)](https://pepy.tech/project/deepstruct) [![Python 3.8](https://img.shields.io/badge/python-3.8-blue.svg)](https://www.python.org/downloads/release/python-380/) 2 | ![deepstruct neural network structure tool](docs/logo-wide.png) 3 | 4 | Create deep neural networks based on very different kinds of graphs or use *deepstruct* to extract the structure of your deep neural network. 5 | Deepstruct can automatically create a deep neural network models based on graphs and for purposes of visualization, analysis or transformations it also supports graph extraction from a given model. 6 | 7 | Interested in neural network visualizations, pruning, neural architecture search or neural structure in general? 8 | 9 | See [examples](#examples) below or [read the docs](https://deepstruct.readthedocs.io). 
10 | 11 | 12 | ## Installation 13 | - With **pip** from PyPi: ``pip install deepstruct`` 14 | - With **poetry** (recommended for *projects*) using PyPi: ``poetry add deepstruct`` 15 | - With **conda** in your *environment.yml* (recommended for reproducible experiments): 16 | ```yaml 17 | name: exp01 18 | channels: 19 | - defaults 20 | dependencies: 21 | - pip>=20 22 | - pip: 23 | - deepstruct 24 | ``` 25 | - From public GitHub: ``pip install --upgrade git+ssh://git@github.com/innvariant/deepstruct.git`` 26 | 27 | 28 | ## Quick usages 29 | *deepstruct* provides two major tool approaches to pytorch models: 30 | 31 | 1. Based on a graph structure, it allows you to automatically **construct** a deep neural network model. 32 | 2. Given a (pre-trained) model, *deepstruct* provides options to **extract** different notions of a graph from it as to further visualize or analyze it. 33 | 34 | 35 | ### Constructing Models 36 | 37 | #### Multi-layered feed-forward neural network on MNIST 38 | The simplest implementation is one which provides multiple layers with binary masks for each weight matrix. 39 | It doesn't consider any skip-layer connections. 40 | Each layer is then connected to only the following one. 41 | ```python 42 | import deepstruct.sparse 43 | 44 | mnist_model = deepstruct.sparse.MaskedDeepFFN((1, 28, 28), 10, [100]*10, use_layer_norm=True) 45 | ``` 46 | This is a ready-to-use pytorch module which has ten layers of one hundred neurons each and applies layer normalization before each activation. 47 | Training it on any dataset will work out of the box like every other pytorch module. 48 | Have a look at [pytorch ignite](https://pytorch.org/ignite/) or [pytorch lightning](https://github.com/Lightning-AI/lightning/) for designing your training loops. 
49 | You can set masks on the model via 50 | ```python 51 | import deepstruct.sparse 52 | for layer in deepstruct.sparse.maskable_layers(mnist_model): 53 | layer.mask[:, :] = True 54 | ``` 55 | and if you disable some of these mask elements you have defined your first sparse model. 56 | 57 | 58 | #### Random Graphs as Structural Priors 59 | Specify structures by prior design, e.g. random social networks transformed into directed acyclic graphs: 60 | ```python 61 | import networkx as nx 62 | import deepstruct.sparse 63 | 64 | # Use networkx to generate a random graph based on the Watts-Strogatz model 65 | random_graph = nx.newman_watts_strogatz_graph(100, 4, 0.5) 66 | structure = deepstruct.graph.CachedLayeredGraph() 67 | structure.add_edges_from(random_graph.edges) 68 | structure.add_nodes_from(random_graph.nodes) 69 | 70 | # Build a neural network classifier with 784 input and 10 output neurons and the given structure 71 | model = deepstruct.sparse.MaskedDeepDAN(784, 10, structure) 72 | model.apply_mask() # Apply the mask on the weights (hard, not undoable) 73 | model.recompute_mask() # Use weight magnitude to recompute the mask from the network 74 | pruned_structure = model.generate_structure() # Get the structure -- a networkx graph -- based on the current mask 75 | 76 | new_model = deepstruct.sparse.MaskedDeepDAN(784, 10, pruned_structure) 77 | ``` 78 | 79 | 80 | #### Recurrent Neural Networks with sparsity 81 | ```python 82 | import torch 83 | import deepstruct.recurrent 84 | import numpy as np 85 | 86 | # A sequence of size 15 with one-dimensional elements which could e.g. be labelled 87 | # BatchSize x [(1,), (2,), (3,), (4,), (5,), (0,), (0,), (0,)] --> [ label1, label2, ..] 
88 | batch_size = 100 89 | seq_size = 15 90 | input_size = 1 91 | model = deepstruct.recurrent.MaskedDeepRNN( 92 | input_size, 93 | hidden_layers=[100, 100, 1], 94 | batch_first=True, 95 | build_recurrent_layer=deepstruct.recurrent.MaskedLSTMLayer, 96 | ) 97 | random_input = torch.tensor( 98 | np.random.random((batch_size, seq_size, input_size)), 99 | dtype=torch.float32, 100 | requires_grad=False, 101 | ) 102 | model.forward(random_input) 103 | ``` 104 | 105 | 106 | 107 | ### Graph Extraction 108 | Define a feed-forward neural network (with no skip-layer connections) and obtain its structure as a graph: 109 | ```python 110 | import torch 111 | import deepstruct.sparse 112 | from deepstruct.transform import GraphTransform 113 | 114 | model = deepstruct.sparse.MaskedDeepFFN(784, 10, [100, 100]) 115 | # .. train the model or load a pre-trained 116 | 117 | # Define a random input tensor which is required to analyse the structure 118 | input_random = torch.randn((1, 20)) 119 | functor = GraphTransform(input_random) 120 | result = functor.transform(model) 121 | ``` 122 | 123 | 124 | 125 | 126 | 127 | ## Sparse Neural Network implementations 128 | ![Sparse Network Connectivity on zeroth order with a masked deep feed-forward neural network](docs/masked-deep-ffn.png) 129 | ![Sparse Network Connectivity on zeroth order with a masked deep neural network with skip-layer connections](docs/masked-deep-dan.png) 130 | ![Sparse Network Connectivity on second order with a masked deep cell-based neural network](docs/masked-deep-cell-dan.png) 131 | 132 | **What's contained in deepstruct?** 133 | - ready-to-use models in pytorch for learning instances on common (supervised/unsupervised) datasets from which a structural analysis is possible 134 | - model-to-graph transformations for studying models from a graph-theoretic perspective 135 | 136 | **Models:** 137 | - *deepstruct.sparse.MaskableModule*: pytorch modules that contain explicit masks to enforce (mostly zero-ordered) structure 
138 | - *deepstruct.sparse.MaskedLinearLayer*: pytorch module with a simple linear layer extended with masking capability. 139 | Suitable if you want to have linear-layers on which to enforce masks which could be obtained through pruning, regularization or other search techniques. 140 | - *deepstruct.sparse.MaskedDeepFFN*: feed-forward neural network with any width and depth and easy-to-use masks. 141 | Suitable for simple and canonical pruning research on zero-ordered structure 142 | - *deepstruct.sparse.MaskedDeepDAN*: feed-forward neural network with skip-layer connections based on any directed acyclic network. 143 | Suitable for arbitrary structures on zero-order and on that level most flexible but also computationally expensive. 144 | - *deepstruct.sparse.DeepCellDAN*: complex module based on a directed acyclic network and custom cells on third-order structures. 145 | Suitable for large-scale neural architecture search 146 | - *deepstruct.recurrent.MaskedDeepRNN*: multi-layered network with recurrent layers which can be masked 147 | 148 | ## What are the orders of structure? 149 | - zero-th order: weight-level 150 | - first order: kernel-level (filter, channel, blocks, cells) 151 | - second order: layers 152 | 153 | There is various evidence across empirical machine learning studies that the way artificial neural networks are structurally connected has a (minor?) influence on performance metrics such as the accuracy or probably even on more complex concepts such as adversarial robustness. 154 | What do we mean by "structure"? 155 | We define structure over graph theoretic properties given a computational graph with very restricted non-linearities. 156 | This includes all major neural network definitions and lets us study them from the perspective of their *representation* and their *structure*. 
157 | In a probabilistic sense, one can interpret structure as a prior to the model and even though single-layered wide networks are universal function approximators, we follow the hypothesis that given certain structural priors we can find models with better properties. 158 | 159 | Before considering implementations, one should have a look at possible representations of Sparse Neural Networks. 160 | In case of feed-forward neural networks (FFNs) the network can be represented as a list of weight matrices. 161 | Each weight matrix represents the connections from one layer to the next. 162 | Having a network without some connections then means setting entries in those matrices to zero. 163 | Removing a particular neuron means setting all entries representing its incoming connections to zero. 164 | 165 | However, sparsity can be employed on various levels of a general artificial neural network. 166 | Zero order sparsity would remove single weights (representing connections) from the network. 167 | First order sparsity removes groups of weights within one dimension of a matrix from the network. 168 | Sparsity can be employed on connection-, weight-, block-, channel-, cell-level and so on. 169 | Implementations respecting the areas for sparsification can have drastic differences. 170 | Thus there are various ways for implementing Sparse Neural Networks. 171 | 172 | 173 | # Artificial PyTorch Datasets 174 | ![A custom artificial landscape Stier2020B for testing function approximation](docs/artificial-landscape-approximation.png) 175 | We provide some simple utilities for artificial function approximation. 176 | Like polynomials, neural networks are universal function approximators on bounded intervals of compact spaces. 177 | To test, you can easily define a function of any finite dimension, e.g. 
$f: \mathbb{R}^2\rightarrow\mathbb{R}, (x,y)\mapsto 20 + x - 1.8*(y-5) + 3 * np.sin(x + 2 * y) * y + (x / 4) ** 4 + (y / 4) ** 4$: 178 | 179 | ```python 180 | import numpy as np 181 | import torch.utils.data 182 | from dataset import FuncDataset 183 | from deepstruct.sparse import MaskedDeepFFN 184 | 185 | # Our artificial landscape: f: R^2 -> R 186 | # Have a look at https://github.com/innvariant/eddy for some visual examples 187 | # You could easily define arbitrary functions from R^a to R^b 188 | stier2020B1d = lambda x, y: 20 + x - 1.8 * (y - 5) + 3 * np.sin(x + 2 * y) * y + (x / 4) ** 4 + (y / 4) ** 4 189 | ds_input_shape = ( 190 | 2,) # specify the number of input dimensions (usually a one-sized tensor if no further structures are used) 191 | # Explicitly define the target function for the dataset which returns a numpy array of our above function 192 | # By above definition x is two-dimensional, so you have access to x[0] and x[1] 193 | fn_target = lambda x: np.array([stier2020B1d(x[0], x[1])]) 194 | # Define a sampling strategy for the dataset, e.g. uniform sampling the space 195 | fn_sampler = lambda: np.random.uniform(-2, 2, size=ds_input_shape) 196 | # Define the dataset given the target function and your sampling strategy 197 | # This simply wraps your function into a pytorch dataset and provides you with discrete observations 198 | # Your model will later only know those observations to come up with an approximate solution of your target 199 | ds_train = FuncDataset(fn_target, shape_input=ds_input_shape, size=500) 200 | 201 | # Calculate the output shape given our target function .. usually simply a (1,)-dimensional output 202 | ds_output_shape = fn_target(fn_sampler()).shape 203 | 204 | # As usual in pytorch, you can simply wrap your dataset with a loading strategy .. 205 | # This ensures e.g. 
that you do not iterate over your observations in the exact same manner 206 | # In case you sample first 100 examples of a binary classification dataset with label 1 and then another 207 | # 100 with label 2 it might impact your training .. so this ensures you have an e.g. random sampling strategy over the dataset 208 | batch_size = 100 209 | train_sampler = torch.utils.data.SubsetRandomSampler(np.arange(len(ds_train), dtype=np.int64)) 210 | train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, sampler=train_sampler, num_workers=2) 211 | 212 | # Define a model for which we can later extract its structure or impose sparsity constraints 213 | model = MaskedDeepFFN(2, 1, [50, 20]) 214 | 215 | # Iterate over your training set 216 | for feat, target in train_loader: 217 | print(feat, target) 218 | 219 | # feed it into a model to learn 220 | prediction = model.forward(feat) 221 | 222 | # compute a loss based on the expected target and the models prediction 223 | # .. 224 | ``` 225 | 226 | 227 | # References 228 | Have a look into our positioning paper on [arxiv](https://arxiv.org/abs/2111.06679) and related articles. 
229 | We're glad if you cite our work 230 | ```bibtex 231 | @article{stier2022deepstruct, 232 | title={deepstruct -- linking deep learning and graph theory}, 233 | author={Stier, Julian and Granitzer, Michael}, 234 | journal={Software Impacts}, 235 | volume={11}, 236 | year={2022}, 237 | publisher={Elsevier} 238 | } 239 | @article{stier2019structural, 240 | title={Structural analysis of sparse neural networks}, 241 | author={Stier, Julian and Granitzer, Michael}, 242 | journal={Procedia Computer Science}, 243 | volume={159}, 244 | pages={107--116}, 245 | year={2019}, 246 | publisher={Elsevier} 247 | } 248 | @phdthesis{stier2024structures, 249 | author = "Julian Stier", 250 | title = "Structures of Artificial Neural Networks - Empirical Investigations", 251 | school = "University of Passau", 252 | year = "2024" 253 | } 254 | ``` 255 | -------------------------------------------------------------------------------- /deepstruct/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innvariant/deepstruct/0ffbaf73425c063485cb173f5c0fc3677368b522/deepstruct/__init__.py -------------------------------------------------------------------------------- /deepstruct/__version__.py: -------------------------------------------------------------------------------- 1 | from importlib_metadata import version 2 | 3 | 4 | __version__ = version("deepstruct") 5 | -------------------------------------------------------------------------------- /deepstruct/constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | DEFAULT_OPERATIONS = [ 5 | (torch, "add"), 6 | (torch.Tensor, "add"), 7 | (torch, "cos"), 8 | (torch.nn.modules.conv.Conv2d, "forward"), 9 | (torch.nn.modules.Linear, "forward"), 10 | (torch.nn.modules.MaxPool2d, "forward"), 11 | (torch.nn.modules.Flatten, "forward"), 12 | (torch.nn.modules.BatchNorm2d, "forward"), 13 | (torch.nn.functional, 
"relu"), 14 | (torch.Tensor, "view"), 15 | (torch.Tensor, "size"), 16 | ] 17 | -------------------------------------------------------------------------------- /deepstruct/dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import deprecated 4 | import numpy as np 5 | import torch 6 | 7 | from torch.utils.data import Dataset 8 | from tqdm import tqdm 9 | 10 | 11 | class FunctionalDataset(Dataset): 12 | def __init__(self, fn, size: int, shape_input: tuple = None, sampler=None): 13 | assert fn is not None 14 | assert callable(fn) 15 | assert shape_input is not None 16 | assert size is not None and size > 0 17 | assert callable(sampler) or sampler is None 18 | 19 | self._size = size 20 | self._sampler = sampler 21 | self._fn = fn 22 | self._shape_input = shape_input 23 | self._data = None 24 | 25 | def generate(self, without_progress: bool = True): 26 | self._data = [] 27 | has_length = hasattr(self._fn, "__len__") 28 | fn_name = str(self._fn) 29 | with tqdm( 30 | total=self._size, desc="Generating samples", disable=without_progress 31 | ) as pbar: 32 | for ix in range(self._size): 33 | self._data.append(self.sample()) 34 | if has_length: 35 | pbar.set_postfix(**{fn_name: len(self._fn)}) 36 | pbar.update() 37 | 38 | @property 39 | def sampler(self): 40 | return self._sampler 41 | 42 | @sampler.setter 43 | def sampler(self, sampler): 44 | assert sampler is not None and callable(sampler) 45 | self._sampler = sampler 46 | 47 | def sample(self): 48 | assert self._sampler is not None 49 | preimage_sample = torch.tensor(self._sampler()).view(self._shape_input) 50 | # print(preimage_sample.shape) 51 | codomain = self._fn(preimage_sample) 52 | return preimage_sample, codomain 53 | 54 | def save(self, path_file): 55 | with open(path_file, "wb") as handle: 56 | pickle.dump( 57 | { 58 | "size": self._size, 59 | "shape_input": self._shape_input, 60 | "data": self._data, 61 | }, 62 | handle, 63 | ) 64 | 65 | 
@staticmethod 66 | def load(path_file): 67 | with open(path_file, "rb") as handle: 68 | meta = pickle.load(handle) 69 | dataset = FuncDataset( 70 | lambda: None, meta["size"], meta["shape_input"], lambda: None 71 | ) 72 | dataset._data = meta["data"] 73 | return dataset 74 | 75 | def __len__(self): 76 | return self._size 77 | 78 | def __getitem__(self, idx: int): 79 | if self._data is None: 80 | self.generate() 81 | assert self._data is not None 82 | 83 | idx = idx + int(idx < 0) * len(self) 84 | 85 | if idx > len(self): 86 | raise StopIteration("Given index <%s> exceeds dataset." % idx) 87 | 88 | return self._data[idx] 89 | 90 | 91 | @deprecated.deprecated(reason="FuncDataset gets a more explicit name", version="0.11.0") 92 | def FuncDataset(*args, **kwargs): 93 | return FunctionalDataset(*args, **kwargs) 94 | 95 | 96 | if __name__ == "__main__": 97 | import matplotlib.pyplot as plt 98 | 99 | from deepstruct.sparse import MaskedDeepFFN 100 | 101 | batch_size = 100 102 | ds_input_shape = (2,) 103 | # fn_target = lambda x: np.array([4+x[0]**2-3*x[1]]) 104 | # fn_target = lambda x: np.array([4 + x[0] ** 2 - 3 * x[1]]) 105 | stier2020B1d = ( 106 | lambda x, y: 20 107 | + x 108 | - 1.8 * (y - 5) 109 | + 3 * np.sin(x + 2 * y) * y 110 | + (x / 4) ** 4 111 | + (y / 4) ** 4 112 | ) 113 | 114 | def fn_target(x): 115 | return np.array([stier2020B1d(x[0], x[1])]) 116 | 117 | # Training 118 | ds_train = FuncDataset(fn_target, shape_input=ds_input_shape, size=500) 119 | ds_train.sampler = lambda: np.random.uniform(-2, 2, size=ds_input_shape) 120 | 121 | ds_output_shape = fn_target(ds_train.sampler()).shape 122 | print("f: R^(%s) --> R^(%s)" % (ds_input_shape, ds_output_shape)) 123 | 124 | train_sampler = torch.utils.data.SubsetRandomSampler( 125 | np.arange(len(ds_train), dtype=np.int64) 126 | ) 127 | train_loader = torch.utils.data.DataLoader( 128 | ds_train, batch_size=batch_size, sampler=train_sampler, num_workers=2 129 | ) 130 | 131 | model = MaskedDeepFFN(ds_input_shape, 
1, [100, 100]) 132 | 133 | fn_loss = torch.nn.MSELoss() 134 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) 135 | 136 | model.train() 137 | errors_train = [] 138 | for epoch in range(100): 139 | for feat, target in train_loader: 140 | target = target.reshape(-1, 1) 141 | optimizer.zero_grad() 142 | pred = model(feat) 143 | error = fn_loss(target, pred) 144 | errors_train.append(error.detach().numpy()) 145 | error.backward() 146 | optimizer.step() 147 | if epoch % 10 == 0: 148 | print("Cur Err [-1]:", errors_train[-1]) 149 | print("Avg Err [-5]:", np.mean(errors_train[-5:])) 150 | 151 | # Testing 152 | ds_test = FuncDataset(fn_target, shape_input=ds_input_shape, size=5000) 153 | ds_test.sampler = lambda: np.random.uniform(-3, 3, size=ds_input_shape) 154 | 155 | test_sampler = torch.utils.data.SubsetRandomSampler( 156 | np.arange(len(ds_test), dtype=np.int64) 157 | ) 158 | test_loader = torch.utils.data.DataLoader( 159 | ds_test, batch_size=batch_size, sampler=test_sampler, num_workers=2 160 | ) 161 | 162 | model.eval() 163 | errors_test = [] 164 | xs = np.array([]).reshape((-1,) + ds_input_shape) 165 | ys = np.array([]).reshape((-1,) + ds_output_shape) 166 | ms = np.array([]).reshape((-1,) + ds_output_shape) 167 | for feat, target in test_loader: 168 | target = target.reshape(-1, 1) 169 | pred = model(feat) 170 | error = fn_loss(target, pred) 171 | errors_test.append(error.detach().numpy()) 172 | 173 | xs = np.vstack([xs, feat.detach().numpy()]) 174 | ys = np.vstack([ys, target.detach().numpy()]) 175 | ms = np.vstack([ms, pred.detach().numpy()]) 176 | 177 | print(errors_test) 178 | print(np.mean(errors_test)) 179 | 180 | # print(xs.shape) 181 | # print(ys.shape) 182 | # print(ms.shape) 183 | 184 | fig = plt.figure() 185 | ax = fig.add_subplot(111, projection="3d") 186 | ax.scatter(xs[:, 0], xs[:, 1], ys, marker=".", color="blue") 187 | ax.scatter(xs[:, 0], xs[:, 1], ms, marker=".", color="orange") 188 | # plt.plot(xs, ys) 189 | # 
plt.plot(xs, ms) 190 | ax.set_zlim(0, 50) 191 | plt.show() 192 | -------------------------------------------------------------------------------- /deepstruct/flexible_transform.py: -------------------------------------------------------------------------------- 1 | import networkx 2 | import torch 3 | 4 | from deepstruct.node_map_strategies import CustomNodeMap 5 | from deepstruct.node_map_strategies import HighLevelNodeMap 6 | from deepstruct.traverse_strategies import FXTraversal 7 | from deepstruct.traverse_strategies import TraversalStrategy 8 | 9 | 10 | class GraphTransform: 11 | def __init__( 12 | self, 13 | random_input, 14 | traversal_strategy: TraversalStrategy = FXTraversal(), 15 | node_map_strategy: CustomNodeMap = HighLevelNodeMap(), 16 | ): 17 | self.random_input = random_input 18 | self.traversal_strategy = traversal_strategy 19 | self.node_map_strategy = node_map_strategy 20 | 21 | def transform(self, model: torch.nn.Module): 22 | try: 23 | self.traversal_strategy.init(self.node_map_strategy) 24 | self.traversal_strategy.traverse(self.random_input, model) 25 | finally: 26 | self.traversal_strategy.restore_traversal() 27 | 28 | def get_graph(self) -> networkx.DiGraph: 29 | return self.traversal_strategy.get_graph() 30 | -------------------------------------------------------------------------------- /deepstruct/graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import List 4 | 5 | import networkx 6 | import networkx as nx 7 | import numpy as np 8 | 9 | 10 | class LayeredGraph(nx.DiGraph): 11 | def save(self, path): 12 | struct = nx.relabel_nodes( 13 | self, {name: ix for ix, name in enumerate(self.nodes)} 14 | ) 15 | nx.write_graphml(struct, path) 16 | 17 | @staticmethod 18 | def load(path): 19 | graph = nx.read_graphml(path) 20 | return LayeredGraph.load_from(graph) 21 | 22 | @staticmethod 23 | def load_from(graph: nx.DiGraph) -> LayeredGraph: 24 | 
structure = CachedLayeredGraph() 25 | map_old2new = {name: ix for ix, name in enumerate(graph.nodes)} 26 | structure.add_nodes_from([map_old2new[v] for v in graph.nodes]) 27 | structure.add_edges_from( 28 | [(map_old2new[s], map_old2new[t]) for (s, t) in graph.edges] 29 | ) 30 | return structure 31 | 32 | @property 33 | def first_layer(self): 34 | """ 35 | :rtype: int 36 | """ 37 | return NotImplementedError() 38 | 39 | @property 40 | def last_layer(self): 41 | """ 42 | :rtype: int 43 | """ 44 | return NotImplementedError() 45 | 46 | @property 47 | def num_layers(self): 48 | """ 49 | :rtype: int 50 | """ 51 | return NotImplementedError() 52 | 53 | @property 54 | def first_layer_size(self): 55 | """ 56 | :rtype: int 57 | """ 58 | return NotImplementedError() 59 | 60 | @property 61 | def last_layer_size(self): 62 | """ 63 | :rtype: int 64 | """ 65 | return NotImplementedError() 66 | 67 | @property 68 | def layers(self): 69 | """ 70 | :rtype: list[int] 71 | """ 72 | raise NotImplementedError() 73 | 74 | def get_layer(self, vertex: int): 75 | """ 76 | :rtype: int 77 | """ 78 | raise NotImplementedError() 79 | 80 | def get_vertices(self, layer: int): 81 | """ 82 | :rtype: list[int] 83 | """ 84 | raise NotImplementedError() 85 | 86 | def get_layer_size(self, layer: int): 87 | """ 88 | :rtype: int 89 | """ 90 | raise NotImplementedError() 91 | 92 | def layer_connected(self, layer_index1: int, layer_index2: int): 93 | """ 94 | :rtype: bool 95 | """ 96 | raise NotImplementedError() 97 | 98 | def layer_connection_size(self, layer_index1: int, layer_index2: int): 99 | """ 100 | :rtype: int 101 | """ 102 | raise NotImplementedError() 103 | 104 | 105 | class CachedLayeredGraph(LayeredGraph): 106 | def __init__(self, **attr): 107 | super(CachedLayeredGraph, self).__init__(**attr) 108 | self._has_changed = True 109 | self._layer_index = None 110 | self._vertex_by_layer = None 111 | 112 | def add_cycle(self, nodes, **attr): 113 | super(LayeredGraph, self).add_cycle(nodes, **attr) 
114 | self._has_changed = True 115 | 116 | def add_edge(self, u_of_edge, v_of_edge, **attr): 117 | super(LayeredGraph, self).add_edge(u_of_edge, v_of_edge, **attr) 118 | self._has_changed = True 119 | 120 | def add_edges_from(self, ebunch_to_add, **attr): 121 | super(LayeredGraph, self).add_edges_from(ebunch_to_add, **attr) 122 | self._has_changed = True 123 | 124 | def add_node(self, node_for_adding, **attr): 125 | super(LayeredGraph, self).add_node(node_for_adding, **attr) 126 | self._has_changed = True 127 | 128 | def add_nodes_from(self, nodes_for_adding, **attr): 129 | super(LayeredGraph, self).add_nodes_from(nodes_for_adding, **attr) 130 | self._has_changed = True 131 | 132 | def add_path(self, nodes, **attr): 133 | super(LayeredGraph, self).add_path(nodes, **attr) 134 | self._has_changed = True 135 | 136 | def add_star(self, nodes, **attr): 137 | super(LayeredGraph, self).add_star(nodes, **attr) 138 | self._has_changed = True 139 | 140 | def add_weighted_edges_from(self, ebunch_to_add, weight="weight", **attr): 141 | super(LayeredGraph, self).add_weighted_edges_from( 142 | ebunch_to_add, weight="weight", **attr 143 | ) 144 | self._has_changed = True 145 | 146 | def _get_layer_index(self): 147 | if ( 148 | self._has_changed 149 | or self._layer_index is None 150 | or self._vertex_by_layer is None 151 | ): 152 | self._build_layer_index() 153 | self._has_changed = False 154 | 155 | return self._layer_index, self._vertex_by_layer 156 | 157 | def _layer_by_vertex(self, vertex: int): 158 | return self._get_layer_index()[0][vertex] 159 | 160 | def _vertices_by_layer(self, layer: int): 161 | return self._get_layer_index()[1][layer] 162 | 163 | def _build_layer_index(self): 164 | self._layer_index, self._vertex_by_layer = build_layer_index(self) 165 | 166 | @property 167 | def first_layer(self): 168 | """ 169 | :rtype: int 170 | """ 171 | return self.layers[0] 172 | 173 | @property 174 | def last_layer(self): 175 | """ 176 | :rtype: int 177 | """ 178 | return 
self.layers[-1] 179 | 180 | @property 181 | def num_layers(self): 182 | return len(self.layers) 183 | 184 | @property 185 | def first_layer_size(self): 186 | return self.get_layer_size(self.layers[0]) 187 | 188 | @property 189 | def last_layer_size(self): 190 | return self.get_layer_size(self.layers[-1]) 191 | 192 | @property 193 | def layers(self): 194 | return [layer for layer in self._get_layer_index()[1]] 195 | 196 | def get_layer(self, vertex: int): 197 | return self._layer_by_vertex(vertex) 198 | 199 | def get_vertices(self, layer: int): 200 | return self._vertices_by_layer(layer) 201 | 202 | def get_layer_size(self, layer: int): 203 | return len(self._vertices_by_layer(layer)) 204 | 205 | def layer_connected(self, layer_index1: int, layer_index2: int): 206 | """ 207 | :rtype: bool 208 | """ 209 | if layer_index1 is layer_index2: 210 | raise ValueError( 211 | "Same layer does not have interconnections, it would be split up." 212 | ) 213 | if layer_index1 > layer_index2: 214 | tmp = layer_index2 215 | layer_index2 = layer_index1 216 | layer_index1 = tmp 217 | 218 | for source_vertex in self.get_vertices(layer_index1): 219 | for target_vertex in self.get_vertices(layer_index2): 220 | if self.has_edge(source_vertex, target_vertex): 221 | return True 222 | return False 223 | 224 | def layer_connection_size(self, layer_index1: int, layer_index2: int): 225 | """ 226 | :rtype: int 227 | """ 228 | if layer_index1 is layer_index2: 229 | raise ValueError( 230 | "Same layer does not have interconnections, it would be split up." 
231 | ) 232 | if layer_index1 > layer_index2: 233 | tmp = layer_index2 234 | layer_index2 = layer_index1 235 | layer_index1 = tmp 236 | 237 | size = 0 238 | for source_vertex in self.get_vertices(layer_index1): 239 | for target_vertex in self.get_vertices(layer_index2): 240 | if self.has_edge(source_vertex, target_vertex): 241 | size += 1 242 | return size 243 | 244 | 245 | class LabeledDAG(LayeredGraph): 246 | """ 247 | Directed acyclic graph in which the order of vertices matters as they are enumerated. 248 | The implementation makes sure you add no cycles. 249 | """ 250 | 251 | def __init__(self, **attr): 252 | super(LabeledDAG, self).__init__(**attr) 253 | self._layer_index = {} 254 | self._vertex_by_layer = {} 255 | self._update_indices() 256 | self._has_changed = False 257 | 258 | def _update_indices(self): 259 | self._layer_index, self._vertex_by_layer = build_layer_index( 260 | self, self._layer_index 261 | ) 262 | 263 | def _get_layer_index(self): 264 | if ( 265 | self._has_changed 266 | or self._layer_index is None 267 | or self._vertex_by_layer is None 268 | ): 269 | self._update_indices() 270 | self._has_changed = False 271 | 272 | return self._layer_index, self._vertex_by_layer 273 | 274 | def _layer_by_vertex(self, vertex: int): 275 | return self._get_layer_index()[0][vertex] 276 | 277 | def _vertices_by_layer(self, layer: int): 278 | return self._get_layer_index()[1][layer] 279 | 280 | def index_in_layer(self, vertex): 281 | layer_vertices = self.get_vertices(self.get_layer(vertex)) 282 | return layer_vertices.index(vertex) 283 | """match = np.where(self.get_vertices(self.get_layer(vertex)) == vertex) 284 | return match[0][0] if len(match) > 0 else None""" 285 | 286 | def add_vertex(self, layer: int = 0, **kwargs): 287 | assert layer >= 0 288 | 289 | new_node = len(self.nodes) 290 | if layer not in self._vertex_by_layer: 291 | self._vertex_by_layer[layer] = [] 292 | self._vertex_by_layer[layer].append(new_node) 293 | self._layer_index[new_node] = 
layer 294 | super().add_node(new_node, **kwargs) 295 | return new_node 296 | 297 | def add_vertices(self, num_vertices: int, layer: int = 0): 298 | assert num_vertices > 0 299 | assert layer >= 0 300 | 301 | new_nodes = np.arange(len(self), len(self) + num_vertices) 302 | if layer not in self._vertex_by_layer: 303 | self._vertex_by_layer[layer] = [] 304 | self._vertex_by_layer[layer].extend(new_nodes) 305 | self._layer_index.update(dict.fromkeys(new_nodes, layer)) 306 | super().add_nodes_from(new_nodes) 307 | return new_nodes 308 | 309 | def append(self, other: LabeledDAG): 310 | assert self.last_layer is not None 311 | assert other is not None 312 | assert other.first_layer is not None 313 | assert other.first_layer != other.last_layer 314 | assert self.last_layer_size == other.first_layer_size 315 | 316 | offset_layer = self.last_layer 317 | 318 | for layer in other.layers[1:]: 319 | own_layer_target = offset_layer + layer 320 | self.add_vertices(other.get_layer_size(layer), own_layer_target) 321 | for oth_v_idx, oth_v in enumerate(other.get_vertices(layer)): 322 | own_v = self.get_vertices(own_layer_target)[oth_v_idx] 323 | for oth_u, _ in other.in_edges(oth_v): 324 | oth_u_idx = other.index_in_layer(oth_u) 325 | own_layer_source = offset_layer + other.get_layer(oth_u) 326 | own_u = self.get_vertices(own_layer_source)[oth_u_idx] 327 | self.add_edge(own_u, own_v) 328 | 329 | def add_edge(self, u_of_edge, v_of_edge, **attr): 330 | new_layer_source = ( 331 | 0 if "source_layer" not in attr else int(attr["source_layer"]) 332 | ) 333 | source = ( 334 | u_of_edge 335 | if u_of_edge in self.nodes 336 | else self.add_vertex(layer=new_layer_source) 337 | ) 338 | layer_source = self.get_layer(source) 339 | new_layer_target = ( 340 | max(1, layer_source + 1) 341 | if "target_layer" not in attr 342 | else int(attr["target_layer"]) 343 | ) 344 | target = ( 345 | v_of_edge 346 | if v_of_edge in self.nodes 347 | else self.add_vertex(layer=new_layer_target) 348 | ) 349 | 
layer_target = self.get_layer(target) 350 | assert ( 351 | layer_source < layer_target 352 | ), "Can only add edges from lower layers numbers to higher layer numbers. We found L({source})={slayer} >= L({target})={tlayer}".format( 353 | source=source, slayer=layer_source, target=target, tlayer=layer_target 354 | ) 355 | super().add_edge(source, target, **attr) 356 | 357 | def add_nodes_from(self, nodes_for_adding, **attr): 358 | return self.add_vertices(len(nodes_for_adding)) 359 | 360 | def add_node(self, node_for_adding, **attr): 361 | return self.add_vertex(layer=0 if "layer" not in attr else attr["layer"]) 362 | 363 | def _get_next_layer_or_param(self, layer_source: int, **attr): 364 | return ( 365 | max(1, layer_source + 1) 366 | if "target_layer" not in attr 367 | else int(attr["target_layer"]) 368 | ) 369 | 370 | def add_edges_from(self, ebunch_to_add, **attr): 371 | new_layer_source = ( 372 | 0 if "source_layer" not in attr else int(attr["source_layer"]) 373 | ) 374 | source_map = {} 375 | target_map = {} 376 | edges = [] 377 | for s, t in ebunch_to_add: 378 | if s not in source_map: 379 | source_map[s] = ( 380 | s if s in self.nodes else self.add_vertex(new_layer_source) 381 | ) 382 | if t not in target_map: 383 | target_map[t] = ( 384 | t 385 | if t in self.nodes 386 | else self.add_vertex( 387 | self._get_next_layer_or_param( 388 | self.get_layer(source_map[s]), **attr 389 | ) 390 | ) 391 | ) 392 | edges.append((source_map[s], target_map[t])) 393 | super().add_edges_from(edges) 394 | 395 | @property 396 | def first_layer(self): 397 | """ 398 | :rtype: int 399 | """ 400 | return self.layers[0] if self.num_layers > 0 else None 401 | 402 | @property 403 | def last_layer(self): 404 | """ 405 | :rtype: int 406 | """ 407 | return self.layers[-1] if self.num_layers > 0 else None 408 | 409 | @property 410 | def num_layers(self): 411 | return len(self.layers) 412 | 413 | @property 414 | def first_layer_size(self): 415 | return self.get_layer_size(self.layers[0]) 
416 | 417 | @property 418 | def last_layer_size(self): 419 | return self.get_layer_size(self.layers[-1]) 420 | 421 | @property 422 | def layers(self): 423 | return [layer for layer in self._get_layer_index()[1]] 424 | 425 | def get_layer(self, vertex: int): 426 | return self._layer_by_vertex(vertex) 427 | 428 | def get_vertices(self, layer: int): 429 | return self._vertices_by_layer(layer) 430 | 431 | def get_layer_size(self, layer: int): 432 | return len(self._vertices_by_layer(layer)) 433 | 434 | def layer_connected(self, layer_index1: int, layer_index2: int): 435 | """ 436 | :rtype: bool 437 | """ 438 | if layer_index1 is layer_index2: 439 | raise ValueError( 440 | "Same layer does not have interconnections, it would be split up." 441 | ) 442 | if layer_index1 > layer_index2: 443 | tmp = layer_index2 444 | layer_index2 = layer_index1 445 | layer_index1 = tmp 446 | 447 | for source_vertex in self.get_vertices(layer_index1): 448 | for target_vertex in self.get_vertices(layer_index2): 449 | if self.has_edge(source_vertex, target_vertex): 450 | return True 451 | return False 452 | 453 | def layer_connection_size(self, layer_index1: int, layer_index2: int): 454 | """ 455 | :rtype: int 456 | """ 457 | if layer_index1 is layer_index2: 458 | raise ValueError( 459 | "Same layer does not have interconnections, it would be split up." 
460 | ) 461 | if layer_index1 > layer_index2: 462 | tmp = layer_index2 463 | layer_index2 = layer_index1 464 | layer_index1 = tmp 465 | 466 | size = 0 467 | for source_vertex in self.get_vertices(layer_index1): 468 | for target_vertex in self.get_vertices(layer_index2): 469 | if self.has_edge(source_vertex, target_vertex): 470 | size += 1 471 | return size 472 | 473 | 474 | class MarkableDAG(LabeledDAG): 475 | def add_connection(self, mark: str, s_idx: int, t_idx: int): 476 | pass 477 | 478 | 479 | def uniform_proportions(graph: nx.Graph): 480 | """ 481 | Samples weights from a dirichlet distribution for each vertex of a given graph such that all of them add up to one 482 | and are almost uniformly distributed. 483 | 484 | :param graph: 485 | :return: 486 | """ 487 | return { 488 | v: p 489 | for v, p in zip( 490 | graph.nodes, np.random.dirichlet(np.ones(len(graph.nodes)) * 100) 491 | ) 492 | } 493 | 494 | 495 | def build_layer_index(graph: nx.DiGraph, layer_index=None): 496 | """ 497 | 498 | :param graph: 499 | :type graph igraph.Graph 500 | :param layer_index: 501 | :return: 502 | """ 503 | if layer_index is None: 504 | layer_index = {} 505 | 506 | def get_layer_index(vertex, graph: nx.DiGraph): 507 | assert vertex is not None, "Given vertex was none." 
508 | try: 509 | vertex = int(vertex) 510 | except TypeError: 511 | raise ValueError("You have to pass vertex indices to this function.") 512 | 513 | if vertex not in layer_index: 514 | # Recursively calling itself 515 | layer_index[vertex] = ( 516 | max( 517 | [ 518 | get_layer_index(v, graph) 519 | for v in nx.algorithms.dag.ancestors(graph, vertex) 520 | ] 521 | + [-1] 522 | ) 523 | + 1 524 | ) 525 | return layer_index[vertex] 526 | 527 | for v in graph: 528 | get_layer_index(v, graph) 529 | 530 | vertices_by_layer = {} 531 | for v in layer_index: 532 | idx = layer_index[v] 533 | if idx not in vertices_by_layer: 534 | vertices_by_layer[idx] = [] 535 | vertices_by_layer[idx].append(v) 536 | 537 | return layer_index, vertices_by_layer 538 | 539 | 540 | class LayeredFXGraph(networkx.DiGraph): 541 | 542 | def __init__(self, **attr): 543 | super(LayeredFXGraph, self).__init__(**attr) 544 | self._node_name_data = ( 545 | {} 546 | ) # information for a name -> [layer_number, [indices], output_layer_size] 547 | self._mask_for_name = {} 548 | self.edges_for_name = {} 549 | self.ignored_nodes = [] 550 | 551 | def get_next_layer_index(self): 552 | return len(self._node_name_data) 553 | 554 | def get_output_layer_len(self, node_name): 555 | data = self._node_name_data.get(node_name, None) 556 | return data[2] if data is not None else 0 557 | 558 | def get_layer_number(self, node_name): 559 | data = self._node_name_data.get(node_name, None) 560 | return data[0] if data is not None else None 561 | 562 | def get_indices_for_name(self, node_name): 563 | data = self._node_name_data.get(node_name, None) 564 | return data[1] if data is not None else None 565 | 566 | def add_vertices(self, count: int, name, output_layer_size=0, layer=None, **kwargs): 567 | node_data = [] 568 | node_indices = [] 569 | mask = kwargs.pop("mask", None) 570 | if mask is not None: 571 | self._mask_for_name[name] = mask 572 | if layer is None: 573 | layer = self.get_next_layer_index() 574 | 
node_data.append(layer) 575 | for _ in range(count): 576 | node_indices.append(self._add_vertex(name, **kwargs)) 577 | node_data.append(node_indices) 578 | node_data.append(output_layer_size) 579 | self._node_name_data[name] = node_data 580 | return node_indices 581 | 582 | def _add_vertex(self, node_name, **kwargs): 583 | next_node_index = len(self.nodes) 584 | super().add_node(next_node_index, name=node_name, **kwargs) 585 | return next_node_index 586 | 587 | def add_edges(self, source_node_names: List, target_node_name): 588 | target_indices = self.get_indices_for_name(target_node_name) 589 | source_node_names = self._flatten_args(source_node_names) 590 | for source_node_name in source_node_names: 591 | s_n = str(source_node_name) 592 | edges = self.edges_for_name.pop(s_n, None) 593 | source_indices = self._determine_source_indices(s_n) 594 | if ( 595 | source_indices is not None 596 | ): # ignore nodes that were not added to the graph before e.g. constants 597 | self._add_edges( 598 | source_indices, 599 | target_indices, 600 | self._mask_for_name.pop(s_n, None), 601 | edges, 602 | ) 603 | 604 | def _flatten_args(self, nested_list): 605 | flat_list = [] 606 | for element in nested_list: 607 | if isinstance(element, (list, tuple)): 608 | flat_list.extend(self._flatten_args(element)) 609 | else: 610 | flat_list.append(element) 611 | return flat_list 612 | 613 | def _determine_source_indices(self, source_node_name): 614 | if source_node_name in self.ignored_nodes: 615 | values = list(self._node_name_data.values()) 616 | assert len(values) > 1 617 | penultimate = values[-2] 618 | return penultimate[1] # the layer that was added before the current node 619 | else: 620 | return self.get_indices_for_name(source_node_name) 621 | 622 | def _add_edges(self, source_indices, target_indices, mask, edges): 623 | if edges is not None: 624 | super().add_edges_from(edges) 625 | elif mask is None: 626 | for s_i in source_indices: 627 | for t_i in target_indices: 628 | 
super().add_edge(s_i, t_i) 629 | else: 630 | target_counter = 0 631 | for t_i in target_indices: 632 | mask_slice = mask[target_counter] 633 | target_counter += 1 634 | source_counter = 0 635 | for s_i in source_indices: 636 | if mask_slice.numel() > 0 and mask_slice[source_counter]: 637 | super().add_edge(s_i, t_i) 638 | source_counter += 1 639 | -------------------------------------------------------------------------------- /deepstruct/models.py: -------------------------------------------------------------------------------- 1 | import deepstruct.sparse 2 | 3 | 4 | class DeepGraphModule(deepstruct.sparse.MaskableModule): 5 | pass 6 | 7 | 8 | class DeepOperationModule(deepstruct.sparse.MaskableModule): 9 | pass 10 | -------------------------------------------------------------------------------- /deepstruct/node_map_strategies.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import torch.nn 4 | 5 | from deepstruct.graph import LayeredFXGraph 6 | 7 | 8 | class NodeMapper: 9 | 10 | @abstractmethod 11 | def add_node(self, graph, predecessors, **kwargs): 12 | pass 13 | 14 | 15 | class CustomNodeMap: 16 | 17 | def __init__(self, node_mappers: dict, default_mapper: NodeMapper): 18 | self.node_mappers = node_mappers 19 | self.default_mapper = default_mapper 20 | 21 | def map_node(self, graph, module, predecessors, **kwargs): 22 | kwargs["module"] = module 23 | self.node_mappers.get(module, self.default_mapper).add_node( 24 | graph, predecessors, **kwargs 25 | ) 26 | 27 | 28 | class HighLevelNodeMap(CustomNodeMap): 29 | 30 | def __init__(self): 31 | super().__init__({}, All2VertexNodeMapper()) 32 | 33 | 34 | class LowLevelNodeMap(CustomNodeMap): 35 | 36 | def __init__(self, threshold=-1): 37 | super().__init__( 38 | { 39 | torch.nn.Linear: Linear2LayerMapper(threshold), 40 | torch.nn.Conv2d: Conv2LayerMapper(threshold), 41 | }, 42 | All2VertexNodeMapper(), 43 | ) 44 | 45 | 46 | class 
All2VertexNodeMapper(NodeMapper): 47 | 48 | def add_node(self, graph: LayeredFXGraph, predecessors, **kwargs): 49 | node_name = kwargs.pop("name") 50 | output_layer_len = 0 51 | for pred in predecessors: 52 | output_layer_len += graph.get_output_layer_len(str(pred)) 53 | if output_layer_len == 0: 54 | graph.add_vertices(1, node_name, **kwargs) 55 | else: 56 | graph.add_vertices(output_layer_len, node_name, **kwargs) 57 | graph.add_edges(predecessors, node_name) 58 | 59 | 60 | class Linear2LayerMapper(NodeMapper): 61 | 62 | def __init__(self, threshold): 63 | self.threshold = threshold 64 | 65 | def add_node(self, graph: LayeredFXGraph, predecessors, **kwargs): 66 | model = kwargs.get("origin_module") 67 | in_features = model.in_features 68 | out_features = model.out_features 69 | mask = torch.ones((out_features, in_features), dtype=torch.bool) 70 | mask[torch.where(abs(model.weight) < self.threshold)] = False # L1-Pruning 71 | kwargs["mask"] = mask 72 | node_name = kwargs.pop("name") 73 | graph.add_vertices( 74 | in_features, node_name, output_layer_size=out_features, **kwargs 75 | ) 76 | graph.add_edges(predecessors, node_name) 77 | 78 | 79 | class Conv2LayerMapper(NodeMapper): 80 | 81 | def __init__(self, threshold): 82 | self.threshold = threshold 83 | 84 | def add_node(self, graph: LayeredFXGraph, predecessors, **kwargs): 85 | model = kwargs.get("origin_module") 86 | shape = kwargs.get("shape") 87 | node_name = kwargs.pop("name") 88 | channels_in = model.in_channels 89 | channels_out = model.out_channels 90 | size_kernel = model.kernel_size 91 | stride = model.stride 92 | padding = model.padding 93 | 94 | def input_shape(size, dim): 95 | return int((size - 1) * stride[dim] - 2 * padding[dim] + size_kernel[dim]) 96 | 97 | output_width = shape[-1] 98 | output_height = shape[-2] 99 | input_height = input_shape(output_height, 0) 100 | input_width = input_shape(output_width, 1) 101 | input_neurons_count = channels_in * input_width * input_height 102 | 
output_neurons_count = channels_out * output_height * output_width 103 | input_neurons = graph.add_vertices( 104 | input_neurons_count, 105 | node_name, 106 | output_layer_size=output_neurons_count, 107 | **kwargs 108 | ) 109 | output_neurons = [ 110 | input_neurons[-1] + i + 1 for i in range(output_neurons_count) 111 | ] 112 | 113 | def get_input_neuron(channel: int, row: int, col: int): 114 | return int( 115 | input_neurons[ 116 | int( 117 | (col * input_height + row) 118 | + (channel * input_width * input_height) 119 | ) 120 | ] 121 | ) 122 | 123 | def get_output_neuron(channel_out: int, row: int, col: int): 124 | return int( 125 | output_neurons[ 126 | int( 127 | (col * output_height + row) 128 | + (channel_out * output_width * output_height) 129 | ) 130 | ] 131 | ) 132 | 133 | edges = [] 134 | for idx_channel_out in range(channels_out): 135 | for idx_channel_in in range(channels_in): 136 | out_col = 0 137 | offset_height = -padding[0] 138 | while offset_height + size_kernel[0] <= input_height: 139 | out_row = 0 140 | offset_width = -padding[1] 141 | while offset_width + size_kernel[1] <= input_width: 142 | target = get_output_neuron(idx_channel_out, out_row, out_col) 143 | for col in range( 144 | max(0, offset_height), 145 | min(offset_height + size_kernel[0], input_height), 146 | ): 147 | for row in range( 148 | max(0, offset_width), 149 | min(offset_width + size_kernel[1], input_width), 150 | ): 151 | source = get_input_neuron(idx_channel_in, row, col) 152 | edges.append((source, target)) 153 | offset_width += stride[1] 154 | out_row += 1 155 | offset_height += stride[0] 156 | out_col += 1 157 | graph.edges_for_name[node_name] = edges 158 | graph.add_edges(predecessors, node_name) 159 | -------------------------------------------------------------------------------- /deepstruct/pruning/__init__.py: -------------------------------------------------------------------------------- 
# ---- deepstruct/pruning/__init__.py: (remote reference; re-exports engine/strategy/util) ----
# ---- deepstruct/pruning/engine.py ----
from abc import ABC


class Engine(ABC):
    """Skeleton of a pruning engine that applies a pruning function to a model."""

    def __init__(self, fn_prune):
        # NOTE(review): fn_prune is currently unused; presumably the pruning
        # callable concrete engines should store -- confirm intended contract.
        pass

    def run(self):
        """Execute the pruning procedure (no-op in this skeleton)."""
        pass


# ---- deepstruct/pruning/strategy.py ----
class PruningStrategy(ABC):
    """Skeleton base class for strategies deciding which weights to prune."""

    def __init__(self, fn_prune):
        # NOTE(review): fn_prune is currently unused -- confirm intended contract.
        pass

    def run(self):
        pass


class AbsoluteThresholdStrategy(PruningStrategy):
    """Prune weights below a fixed absolute threshold (not implemented yet)."""

    def run(self):
        pass


class RelativeThresholdStrategy(PruningStrategy):
    """Prune weights below a threshold relative to the weight distribution (not implemented yet)."""

    def run(self):
        pass


class BucketFillStrategy(PruningStrategy):
    """Prune by filling a fixed 'bucket' of saliency mass (not implemented yet)."""

    def run(self):
        pass


# ---- deepstruct/pruning/util.py ----
import torch

import deepstruct.sparse


def set_random_saliency(model: torch.nn.Module):
    """Assign a random saliency (restricted to unmasked entries) to every maskable layer."""
    for layer in deepstruct.sparse.maskable_layers(model):
        layer.saliency = torch.rand_like(layer.weight) * layer.mask


def set_random_masks(module: torch.nn.Module):
    """Randomly zero out roughly half of the mask entries of a MaskedLinearLayer."""
    if isinstance(module, deepstruct.sparse.MaskedLinearLayer):
        module.mask = torch.round(torch.rand_like(module.weight))


def set_distributed_saliency(module: torch.nn.Module):
    """Set per-layer saliency to |weight| / std(weight).

    Normalizing by each layer's standard deviation distributes pruning across
    layers instead of concentrating it in layers with small weight scales.
    """
    for layer in deepstruct.sparse.maskable_layers(module):
        w = layer.weight.data
        inv_std = 1 / w.std()
        layer.saliency = inv_std * w.abs()


def reset_pruned_network(module: torch.nn.Module):
    """Re-initialize all maskable layers while keeping their pruning masks."""
    for layer in deepstruct.sparse.maskable_layers(module):
        layer.reset_parameters(keep_mask=True)


def keep_input_layerwise(module: torch.nn.Module):
    """Instruct each maskable layer to retain its input during forward passes."""
    for layer in deepstruct.sparse.maskable_layers(module):
        layer.keep_layer_input = True


def get_network_weight_count(module: torch.nn.Module) -> int:
    """Return the total number of (unpruned) weights across all maskable layers."""
    return sum(
        layer.get_weight_count() for layer in deepstruct.sparse.maskable_layers(module)
    )


# ---- deepstruct/recurrent.py ----
import math

import torch
import torch.nn as nn

from torch.autograd import Variable
from torch.nn import Parameter
from torch.nn import init


class BaseRecurrentLayer(nn.Module):
    """
    Base class for recurrent layers which can be masked and have an additional unfold() operation.

    Args:
        input_size: The number of expected features in the input
        hidden_size: The number of features in the hidden state
        batch_first: If True, then the input and output tensors are provided
            as (batch, seq, feature). Default: False
    """

    def __init__(
        self, input_size: int, hidden_size: int, batch_first: bool = False, **kwargs
    ):
        super().__init__()
        assert input_size > 0
        assert hidden_size > 0

        self._input_size = input_size
        self._hidden_size = hidden_size
        self._batch_first = True if batch_first else False

        self._initialize_parameters()
        self.reset_parameters()

    def _initialize_parameters(self):
        """Create weight/bias parameters and all-ones boolean mask buffers.

        Subclasses (GRU/LSTM) override this to allocate gate-sized parameters.
        """
        input_size = self._input_size
        hidden_size = self._hidden_size

        self._weight_ih = Parameter(torch.randn(hidden_size, input_size))
        self._weight_hh = Parameter(torch.randn(hidden_size, hidden_size))
        self._bias_ih = Parameter(torch.randn(hidden_size))
        self._bias_hh = Parameter(torch.randn(hidden_size))

        # Masks are buffers (not parameters): they move with .to()/.cuda() and
        # are saved in state_dict, but are excluded from self.parameters().
        self.register_buffer(
            "_mask_i2h", torch.ones((hidden_size, input_size), dtype=torch.bool)
        )
        self.register_buffer(
            "_mask_h2h", torch.ones((hidden_size, hidden_size), dtype=torch.bool)
        )

    def reset_parameters(self, keep_mask=False):
        """Uniformly re-initialize all parameters in [-1/sqrt(h), +1/sqrt(h)].

        The masks are registered buffers and therefore never touched here;
        ``keep_mask`` is accepted for API compatibility with other maskable
        layers but has no effect.
        """
        stdv = 1.0 / math.sqrt(self._hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def set_i2h_mask(self, mask):
        # Variable() is an identity wrapper in modern PyTorch (deprecated since 0.4).
        self._mask_i2h = Variable(mask)

    def set_h2h_mask(self, mask):
        self._mask_h2h = Variable(mask)

    def unfold(self, input, hx):
        """Apply this cell along the sequence dimension of *input*, starting from state *hx*.

        Returns the stacked hidden states for all time steps.
        """
        in_dim = 1 if self._batch_first else 0
        n_seq = input.size(in_dim)
        outputs = []

        for i in range(n_seq):
            seq = input[:, i, :] if self._batch_first else input[i]
            hx = self.forward(seq, hx)
            outputs.append(hx.unsqueeze(in_dim))

        return torch.cat(outputs, dim=in_dim)

    def extra_repr(self):
        s = "in_features={_input_size}, out_features={_hidden_size}"

        if self._batch_first:
            s += ", batch_first={_batch_first}"

        return s.format(**self.__dict__)


class MaskedRecurrentLayer(BaseRecurrentLayer):
    """
    Base class for layer initialization for Vanilla RNN.

    Args:
        nonlinearity: Can be a usual torch.nn.ReLU() or torch.nn.Tanh() or torch.nn.LogSigmoid() ..
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        batch_first: bool = False,
        nonlinearity=torch.nn.Tanh(),
    ):
        super().__init__(input_size, hidden_size, batch_first)
        assert callable(nonlinearity)
        self._nonlinearity = nonlinearity

    def forward(self, input, hx):
        # Masks are applied multiplicatively to the weights before the matmul.
        igate = torch.mm(input, (self._weight_ih * self._mask_i2h).t()) + self._bias_ih
        hgate = torch.mm(hx, (self._weight_hh * self._mask_h2h).t()) + self._bias_hh

        return self._nonlinearity(igate + hgate)


class MaskedGRULayer(BaseRecurrentLayer):
    """
    Base class for layer initialization for GRU.
    """

    def _initialize_parameters(self):
        input_size = self._input_size
        hidden_size = self._hidden_size

        # GRU concatenates three gates (reset, input, new) along dim 0.
        gate_size = 3 * hidden_size

        self._weight_ih = Parameter(torch.randn(gate_size, input_size))
        self._weight_hh = Parameter(torch.randn(gate_size, hidden_size))
        self._bias_ih = Parameter(torch.randn(gate_size))
        self._bias_hh = Parameter(torch.randn(gate_size))

        self.register_buffer(
            "_mask_i2h", torch.ones((gate_size, input_size), dtype=torch.bool)
        )
        self.register_buffer(
            "_mask_h2h", torch.ones((gate_size, hidden_size), dtype=torch.bool)
        )

    def forward(self, input, hx):
        igate = torch.mm(input, (self._weight_ih * self._mask_i2h).t()) + self._bias_ih
        hgate = torch.mm(hx, (self._weight_hh * self._mask_h2h).t()) + self._bias_hh

        i_reset, i_input, i_new = igate.chunk(3, 1)
        h_reset, h_input, h_new = hgate.chunk(3, 1)

        reset_gate = torch.sigmoid(i_reset + h_reset)
        input_gate = torch.sigmoid(i_input + h_input)
        new_gate = torch.tanh(i_new + reset_gate * h_new)

        hx = new_gate + input_gate * (hx - new_gate)
        return hx


class MaskedLSTMLayer(BaseRecurrentLayer):
    """
    Base class for layer initialization for LSTM.
    """

    def _initialize_parameters(self):
        input_size = self._input_size
        hidden_size = self._hidden_size

        # LSTM concatenates four gates (input, forget, cell, output) along dim 0.
        gate_size = 4 * hidden_size

        self._weight_ih = Parameter(torch.randn(gate_size, input_size))
        self._weight_hh = Parameter(torch.randn(gate_size, hidden_size))
        self._bias_ih = Parameter(torch.randn(gate_size))
        self._bias_hh = Parameter(torch.randn(gate_size))
        # Fix: the original additionally called self.reset_parameters() here,
        # which was redundant (BaseRecurrentLayer.__init__ calls it right after
        # _initialize_parameters()) and inconsistent with MaskedGRULayer.

        self.register_buffer(
            "_mask_i2h", torch.ones((gate_size, input_size), dtype=torch.bool)
        )
        self.register_buffer(
            "_mask_h2h", torch.ones((gate_size, hidden_size), dtype=torch.bool)
        )

    def forward(self, input, hx):
        hx, cx = hx
        igate = torch.mm(input, (self._weight_ih * self._mask_i2h).t()) + self._bias_ih
        hgate = torch.mm(hx, (self._weight_hh * self._mask_h2h).t()) + self._bias_hh

        gates = igate + hgate

        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

        input_gate = torch.sigmoid(input_gate)
        forget_gate = torch.sigmoid(forget_gate)
        cell_gate = torch.tanh(cell_gate)
        out_gate = torch.sigmoid(out_gate)

        cx = (forget_gate * cx) + (input_gate * cell_gate)
        hx = out_gate * torch.tanh(cx)
        return hx, cx

    def unfold(self, input, hx):
        """LSTM override: carries a (hidden, cell) state pair through the sequence.

        The cell state is initialized as a copy of *hx*; only hidden states
        are returned, matching the base-class contract.
        """
        in_dim = 1 if self._batch_first else 0
        n_seq = input.size(in_dim)
        outputs = []
        cx = hx.clone()

        for i in range(n_seq):
            seq = input[:, i, :] if self._batch_first else input[i]
            hx, cx = self.forward(seq, (hx, cx))
            outputs.append(hx.unsqueeze(in_dim))

        return torch.cat(outputs, dim=in_dim)


class BaseMaskModule(nn.Module):
    def apply_mask(self, threshold=0.0, i2h=False, h2h=False):
        """
        :param threshold: Amount of pruning to apply. Default '0.0'
        :param i2h: If True, then Input-to-Hidden layers will be masked. Default 'False'
        :param h2h: If True, then Hidden-to-Hidden layers will be masked. Default 'False'
        :type threshold: float
        :type i2h: bool
        :type h2h: bool
        """
        if not i2h and not h2h:
            return

        masks = self.__get_masks(threshold, i2h, h2h)
        for lay_idx, layer in enumerate(self._recurrent_layers):
            # When only one mask kind was requested, masks[lay_idx] holds a
            # single entry, so indices 0 and -1 address the same mask.
            if i2h:
                layer.set_i2h_mask(masks[lay_idx][0])
            if h2h:
                layer.set_h2h_mask(masks[lay_idx][-1])

    def __get_masks(self, threshold, i2h, h2h):
        # "" matches every parameter name (both "_weight_ih" and "_weight_hh"),
        # so key="" selects both mask kinds at once.
        key = "" if i2h and h2h else "ih" if i2h else "hh" if h2h else None

        masks = {}
        for lay_idx, layer in enumerate(self._recurrent_layers):
            masks[lay_idx] = []
            for param, data in layer.named_parameters():
                if "bias" not in param and key in param:
                    mask = torch.ones(data.shape, dtype=torch.bool, device=data.device)
                    mask[torch.where(abs(data) < threshold)] = False
                    masks[lay_idx].append(mask)

        return masks


class MaskedDeepRNN(BaseMaskModule):
    """
    A deep Vanilla RNN model with maskable layers to allow for sparsity.

    Args:
        input_size: The number of expected features in the input
        hidden_layers: A list specifying number of expected features in each hidden layer
            (E.g, hidden_layers=[50, 50] specifies model consisting two hidden layers with 50 features each)
        nonlinearity: Can be a usual non-linearity such as torch.nn.ReLU() or torch.nn.Tanh()
        batch_first: If True, then the input and output tensors are provided
            as (batch, seq, feature). Default: False
    """

    def __init__(
        self,
        input_size,
        hidden_layers: list,
        build_recurrent_layer=MaskedRecurrentLayer,
        nonlinearity=torch.nn.Tanh(),
        batch_first=False,
    ):
        super(MaskedDeepRNN, self).__init__()

        assert callable(nonlinearity)

        self._input_size = input_size
        self._hidden_layers = hidden_layers
        self._batch_first = batch_first

        # Chain layers: the first consumes input_size, later ones consume the
        # previous layer's hidden size.
        layer_list = []
        for lay, hidden_size in enumerate(hidden_layers):
            input_size = input_size if lay == 0 else hidden_layers[lay - 1]
            layer_list.append(
                build_recurrent_layer(
                    input_size, hidden_size, batch_first, nonlinearity=nonlinearity
                )
            )
        self._recurrent_layers = nn.ModuleList(layer_list)

    def forward(self, input):
        batch_size = input.size(0) if self._batch_first else input.size(1)

        for layer, hidden_size in zip(self._recurrent_layers, self._hidden_layers):
            # TODO initialization of first hidden states should be configurable from outside
            hx = torch.zeros(
                batch_size, hidden_size, dtype=input.dtype, device=input.device
            )
            input = layer.unfold(input, hx)

        # Return only the last time step of the top layer.
        output = input[:, -1, :] if self._batch_first else input[-1]
        return output


# ---- deepstruct/scalable.py ----
import itertools

from typing import Tuple
from typing import Union

import numpy as np
import torch
import torch.nn as nn

from deepstruct.graph import CachedLayeredGraph
from deepstruct.graph import LayeredGraph
from deepstruct.sparse import MaskedDeepDAN
from deepstruct.sparse import MaskedDeepFFN
from deepstruct.util import kullback_leibler
class ScalableDeepFFN(object):
    """A family of feed-forward networks described by per-layer size proportions.

    The proportions form a probability distribution over hidden layers; a
    concrete network is obtained by multiplying them with an integer scale.
    """

    _precision: int = 2
    _epsilon: float = 0.001
    _entropy_equivalence_epsilon: float = 0.1

    def __init__(self, proportions: np.ndarray):
        assert proportions is not None
        assert len(proportions) > 0
        # Proportions must (approximately) sum to one.
        assert np.abs(np.sum(proportions) - 1) < self._epsilon

        self._proportions = proportions

    @property
    def entropy_similarity_epsilon(self) -> float:
        return self._entropy_equivalence_epsilon

    @entropy_similarity_epsilon.setter
    def entropy_similarity_epsilon(self, epsilon):
        assert epsilon > 0
        self._entropy_equivalence_epsilon = epsilon

    @property
    def precision(self) -> int:
        return self._precision

    @precision.setter
    def precision(self, precision: int):
        assert precision > 0
        self._precision = precision

    @property
    def proportions(self):
        return np.round(self._proportions, self.precision)

    def draw(self, scale: int) -> list:
        """Return concrete layer sizes for a network with roughly *scale* hidden neurons.

        Fix: the original computed ``round(int(...))`` which truncated before a
        no-op round; sizes are now rounded to the nearest integer first and
        clamped to a minimum of one neuron per layer.
        """
        return [int(np.maximum(1, np.round(size))) for size in scale * self.proportions]

    def build(self, input_shape, output_shape, scale: int) -> nn.Module:
        """Construct a MaskedDeepFFN with layer sizes drawn at the given scale.

        Fix: removed a leftover debug print of the drawn layer sizes.
        """
        assert scale > 0

        layers = self.draw(scale)
        return MaskedDeepFFN(input_shape, output_shape, layers)

    def __eq__(self, other):
        if not isinstance(other, ScalableDeepFFN):
            return False

        # Two families are "equal" if their proportion distributions are close
        # in KL divergence, up to the configured epsilon.
        return self.entropy_similarity(other) < self.entropy_similarity_epsilon

    def entropy_similarity(self, other) -> float:
        """Kullback-Leibler divergence between the two (zero-padded) proportion vectors."""
        assert isinstance(other, ScalableDeepFFN)
        max_length = np.maximum(len(self._proportions), len(other._proportions))
        return kullback_leibler(
            np.pad(self._proportions, (0, max_length - len(self._proportions))),
            np.pad(other._proportions, (0, max_length - len(other._proportions))),
        )

    def __str__(self):
        return str(self.draw(np.power(10, self.precision)))


class ScalableDAN(object):
    """Scales a layered base graph into a larger graph by expanding each vertex
    into a block of vertices according to a proportion map, and builds a
    MaskedDeepDAN from the result.
    """

    _cached_scaled_structure: LayeredGraph = None
    _vertex_correspondences: dict = None

    _structure: LayeredGraph
    _proportion_map: dict

    def __init__(self, structure: LayeredGraph, proportion_map: dict = None):
        """
        :param structure: Layered base graph whose vertices are expanded.
        :param proportion_map: Maps each vertex to its share of the total scale;
            defaults to a uniform distribution over all vertices.
            (Fix: the None-handling branch existed but the parameter had no
            default, so it was unreachable without passing None explicitly.)
        """
        if proportion_map is None:
            proportion_map = {v: 1 / len(structure.nodes) for v in structure.nodes}
        self.structure = structure
        self.proportions = proportion_map

    def reset(self):
        """Drop the cached scaled graph and vertex correspondences."""
        self._cached_scaled_structure = None
        self._vertex_correspondences = None

    @property
    def structure(self):
        return self._structure

    @structure.setter
    def structure(self, structure):
        assert structure is not None
        assert len(structure.nodes) > 0

        self._structure = structure

    @property
    def proportions(self):
        return self._proportion_map

    @proportions.setter
    def proportions(self, proportion_map):
        # Renamed from 'map' to avoid shadowing the builtin; invisible to
        # callers since this is a property setter.
        assert len(proportion_map) == len(self.structure.nodes)
        assert np.isclose(sum(proportion_map.values()), 1)
        self._proportion_map = proportion_map

    def grow(self, scale: int) -> LayeredGraph:
        """Expand every vertex of the base graph into round(proportion * scale) vertices."""
        assert scale > 0

        for layer in self.structure.layers:
            for v in self.structure.get_vertices(layer):
                self.scale(v, int(np.round(self.proportions[v] * scale)))

        return self._cached_scaled_structure

    def scale(self, vertex, size: int):
        """Add *size* copies of *vertex* to the cached scaled graph and connect
        them densely to the copies of all of the vertex's in-neighbours.

        Relies on grow() visiting vertices in layer order so that all
        predecessors already have correspondences.
        """
        graph_scaled = (
            self._cached_scaled_structure
            if self._cached_scaled_structure is not None
            else CachedLayeredGraph()
        )
        vertex_correspondences = (
            self._vertex_correspondences
            if self._vertex_correspondences is not None
            else {}
        )
        nodes_offset = len(graph_scaled.nodes)

        if vertex not in vertex_correspondences:
            vertex_correspondences[vertex] = np.array([])

        vertex_correspondences[vertex] = np.concatenate(
            [
                vertex_correspondences[vertex],
                np.arange(nodes_offset, nodes_offset + size),
            ]
        )
        graph_scaled.add_nodes_from(vertex_correspondences[vertex])
        nodes_offset += size

        # Fully connect the blocks of each predecessor to this vertex's block.
        for source_vertex, _ in self.structure.in_edges(vertex):
            graph_scaled.add_edges_from(
                itertools.product(
                    vertex_correspondences[source_vertex],
                    vertex_correspondences[vertex],
                )
            )

        self._cached_scaled_structure = graph_scaled
        self._vertex_correspondences = vertex_correspondences

        return graph_scaled

    def build(
        self,
        input_shape,
        output_shape,
        scale: int,
        use_layer_norm: bool = True,
        return_graph: bool = False,
    ) -> Union[nn.Module, Tuple[nn.Module, LayeredGraph]]:
        """
        Grows the underlying base graph with the given scale parameter and the specified proportions per vertex and
        returns a neural network based on MaskedDeepDAN.

        :param input_shape: Input dimensions for the model to build, e.g. 784 or (28, 28) for MNIST.
        :param output_shape: Output dimensions for the model to build, e.g. 10 for MNIST or CIFAR10.
        :param scale: Number of neurons to distribute across the blocks of the layer graph.
        :param use_layer_norm: Specifies whether the model should use layer normalizations after single blocks
        :param return_graph: Specifies whether to return the scaled graph; introduced in 0.10
        :return: The built model, or a (model, graph) tuple if return_graph is True.
        """
        assert scale > 0
        graph_scaled = self.grow(scale)
        # Clear the cache so subsequent build() calls start from a fresh graph.
        self.reset()

        model = ScalableDAN.model(
            input_shape, output_shape, graph_scaled, use_layer_norm
        )
        return (model, graph_scaled) if return_graph else model

    @staticmethod
    def model(
        input_shape, output_shape, structure: LayeredGraph, use_layer_norm: bool = True
    ):
        """Build a MaskedDeepDAN directly from an (already scaled) layered graph."""
        return MaskedDeepDAN(
            input_shape, output_shape, structure, use_layer_norm=use_layer_norm
        )


if __name__ == "__main__":
    # Small smoke test: build a scaled DAN from a hand-crafted layered graph.
    g1 = CachedLayeredGraph()
    g1.add_edge(0, 3)
    g1.add_edge(0, 5)
    g1.add_edge(1, 3)
    g1.add_edge(1, 4)
    g1.add_edge(1, 5)
    g1.add_edge(2, 3)
    g1.add_edge(2, 4)

    g1.add_edge(3, 5)
    g1.add_edge(3, 6)
    g1.add_edge(4, 5)
    g1.add_edge(4, 7)

    g1.add_edge(5, 7)

    props = {
        v: p
        for v, p in zip(g1.nodes, np.random.dirichlet(np.ones(len(g1.nodes)) * 100))
    }
    print(props)
    fam = ScalableDAN(g1, props)
    model = fam.build(8, 4, 1000)
    print(model)
    for lay in model.layers_main_hidden:
        print(lay)
        print(torch.sum(lay.mask) / np.prod(lay.mask.shape))

    for lay in model.layers_skip_hidden:
        print(lay)
        print(torch.sum(lay.mask) / np.prod(lay.mask.shape))


# ---- deepstruct/transform.py ----
import itertools
import warnings

from functools import partial

import numpy as np
import torch

from deepstruct.graph import LabeledDAG
from deepstruct.graph import LayeredGraph
from deepstruct.sparse import MaskedLinearLayer


def transform_mask_into_graph(graph: LayeredGraph, mask: torch.Tensor):
    """Validate inputs for a mask-to-graph transformation (currently a stub)."""
    assert mask.dtype == torch.bool
    assert graph is not None


class ModuleVisitor:
    """Visitor that walks modules and marks them as visited on the module itself."""

    def __init__(self, graph: LabeledDAG):
        self._graph = graph

    def applies(self, model: torch.nn.Module):
        return model is not None

    def visited(self, model):
        """Return True if this visitor has already visited *model*."""
        if not hasattr(model, "_deepstruct_visitors"):
            return False
        return self in model._deepstruct_visitors

    def mark_visited(self, model):
        # Visitors are tracked on the module instance itself so state survives
        # across separate traversals of the same model.
        if self.visited(model):
            return

        if not hasattr(model, "_deepstruct_visitors"):
            model._deepstruct_visitors = []
        model._deepstruct_visitors.append(self)

    def visit(self, model):
        self.mark_visited(model)


class ForgetfulFunctor:
    """Maps a torch module to a graph, 'forgetting' weights and keeping only structure."""

    def transform(self, model: torch.nn.Module) -> LabeledDAG:
        raise NotImplementedError("Abstract method needs to be implemented")

    def applies(self, model: torch.nn.Module):
        return True


class LinearLayerFunctor(ForgetfulFunctor):
    """Transforms (masked) linear layers into bipartite graphs."""

    def __init__(self, threshold: float = None):
        self._threshold = threshold

    def transform_masked(self, model: MaskedLinearLayer):
        if self._threshold is not None:
            model.recompute_mask()

        return self.transform_mask(model.mask)

    def transform_linear(self, model: torch.nn.Linear):
        assert (
            self._threshold is not None
        ), "For transforming a linear layer you need to specify which threshold to use for pruning edges."

        in_features = model.in_features
        out_features = model.out_features
        mask = torch.ones((out_features, in_features), dtype=torch.bool)
        # L1 magnitude criterion: drop edges whose |weight| is below threshold.
        # TODO maybe also allow for non-L1-pruning methods?
        mask[torch.where(abs(model.weight) < self._threshold)] = False

        return self.transform_mask(mask)

    def transform_mask(self, mask: torch.Tensor):
        """Build a two-layer LabeledDAG with an edge per True entry of *mask*.

        mask[t, s] == True connects input vertex s to output vertex t.
        (Annotation fixed: torch.Tensor, not the torch.tensor factory.)
        """
        assert mask is not None
        assert mask.dtype == torch.bool
        assert len(mask.shape) == 2

        dim_input = mask.shape[1]
        dim_output = mask.shape[0]

        graph = LabeledDAG()

        sources = graph.add_vertices(dim_input, layer=0)
        targets = graph.add_vertices(dim_output, layer=1)
        graph.add_edges_from(
            [
                (sources[s], targets[t])
                for (s, t) in itertools.product(
                    np.arange(dim_input), np.arange(dim_output)
                )
                if mask[t, s]
            ]
        )

        return graph

    def transform(self, model: torch.nn.Module):
        return (
            self.transform_masked(model)
            if isinstance(model, MaskedLinearLayer)
            else self.transform_linear(model)
        )

    def applies(self, model: torch.nn.Module):
        return isinstance(model, torch.nn.Linear) or isinstance(
            model, MaskedLinearLayer
        )


class Conv2dLayerFunctor(ForgetfulFunctor):
    """Transforms a Conv2d into a bipartite graph of its receptive-field connections."""

    def __init__(
        self, input_width: int = 0, input_height: int = 0, threshold: float = None
    ):
        self._input_width = input_width
        self._input_height = input_height
        self._threshold = threshold

    @property
    def width(self):
        return self._input_width

    @property
    def height(self):
        return self._input_height

    @width.setter
    def width(self, width: int):
        assert width > 0
        self._input_width = width

    @height.setter
    def height(self, height: int):
        assert height > 0
        self._input_height = height

    def transform(self, model: torch.nn.Module) -> LabeledDAG:
        # TODO does not respect sparsity in kernel currently
        assert isinstance(model, torch.nn.Conv2d)
        assert model.dilation == (
            1,
            1,
        ), "Currently dilation is not considered in this implementation"
        assert (
            self.width > 0 and self.height > 0
        ), "You need to specify input width and height for this functor."

        channels_in = model.in_channels
        channels_out = model.out_channels
        size_kernel = model.kernel_size
        stride = model.stride
        padding = model.padding

        assert (
            stride[0] > 0 and stride[1] > 0
        ), "Stride must be a natural number and at least be one"

        graph = LabeledDAG()
        # Layer 0: one vertex per input pixel per input channel.
        input_neurons = graph.add_vertices(
            channels_in * self._input_width * self._input_height, layer=0
        )

        # Standard conv output-size formula (dilation fixed at 1, see assert).
        def output_shape(size, dim):
            return int(
                np.floor((size - size_kernel[dim] + 2 * padding[dim]) / stride[dim]) + 1
            )

        output_height = output_shape(self._input_height, 0)
        output_width = output_shape(self._input_width, 1)
        output_neurons = graph.add_vertices(
            channels_out * output_height * output_width,
            layer=1,
        )

        # Map a (channel, row, col) coordinate to a flat vertex id.
        def get_input_neuron(channel: int, row: int, col: int):
            return int(
                input_neurons[
                    int(
                        (col * self._input_height + row)
                        + (channel * self._input_width * self._input_height)
                    )
                ]
            )

        def get_output_neuron(channel_out: int, row: int, col: int):
            return int(
                output_neurons[
                    int(
                        (col * output_height + row)
                        + (channel_out * output_width * output_height)
                    )
                ]
            )

        # Slide the kernel window over each (out-channel, in-channel) pair and
        # connect every covered input pixel to the corresponding output neuron.
        for idx_channel_out in range(channels_out):
            for idx_channel_in in range(channels_in):
                out_col = 0
                offset_height = -padding[0]
                while offset_height + size_kernel[0] <= self._input_height:
                    out_row = 0
                    offset_width = -padding[1]
                    while offset_width + size_kernel[1] <= self._input_width:
                        target = get_output_neuron(idx_channel_out, out_row, out_col)
                        edges = []
                        # Clamp window bounds so padding regions produce no edges.
                        for col in range(
                            max(0, offset_height),
                            min(offset_height + size_kernel[0], self._input_height),
                        ):
                            for row in range(
                                max(0, offset_width),
                                min(offset_width + size_kernel[1], self._input_width),
                            ):
                                source = get_input_neuron(idx_channel_in, row, col)
                                edges.append((source, target))
                        graph.add_edges_from(edges)
                        offset_width += stride[1]
                        out_row += 1
                    offset_height += stride[0]
                    out_col += 1

        return graph

    def applies(self, model: torch.nn.Module):
        return isinstance(model, torch.nn.Conv2d)


def forward_capture_shape(obj, orig_forward, input, **kwargs):
    """Record the input shape on *obj*, then delegate to the original forward.

    NOTE(review): extra **kwargs are accepted but dropped when delegating --
    confirm no wrapped module relies on keyword arguments to forward().
    """
    obj._deepstruct_input_shape = input.shape
    return orig_forward(input)


class GraphTransform(ForgetfulFunctor):
    """
    Standard zeroth-order transformation from neural networks to graphs.
    """

    def __init__(self, random_input: torch.Tensor):
        self.random_input = random_input
        # Modules that keep the structure unchanged (element-wise ops); they
        # are skipped during the structural transformation.
        self._pointwise_ops = [
            torch.nn.Threshold,
            torch.nn.ReLU,
            torch.nn.RReLU,
            torch.nn.Hardtanh,
            torch.nn.ReLU6,
            torch.nn.Sigmoid,
            torch.nn.Tanh,
            torch.nn.ELU,
            torch.nn.CELU,
            torch.nn.SELU,
            torch.nn.GLU,
            torch.nn.GELU,
            torch.nn.Hardshrink,
            torch.nn.LeakyReLU,
            torch.nn.LogSigmoid,
            torch.nn.Softplus,
            torch.nn.Softshrink,
            torch.nn.MultiheadAttention,
            torch.nn.PReLU,
            torch.nn.Softsign,
            torch.nn.Tanhshrink,
            torch.nn.Softmin,
            torch.nn.Softmax,
            torch.nn.Softmax2d,
            torch.nn.LogSoftmax,
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
        ]

    @property
    def random_input(self):
        return self._random_input

    @random_input.setter
    def random_input(self, random_input: torch.Tensor):
        assert random_input is not None
        assert hasattr(random_input, "shape")
        self._random_input = random_input

    def _punch(self, module: torch.nn.Module):
        """Recursively wrap forward() of all submodules to capture input shapes."""
        for child in module.children():
            self._punch(child)
        setattr(
            module, "forward", partial(forward_capture_shape, module, module.forward)
        )
        return module

    def _transform_partial(self, module: torch.nn.Module, graph: LabeledDAG):
        """Append the structural graph of a single module (or container) to *graph*."""
        assert graph is not None
        functor_conv = Conv2dLayerFunctor()
        functor_linear = LinearLayerFunctor(threshold=0.01)

        partial = None
        if isinstance(module, torch.nn.ModuleList) or isinstance(
            module, torch.nn.Sequential
        ):
            # Containers: recurse into children in order.
            for child in module:
                graph = self._transform_partial(child, graph)
        elif isinstance(module, torch.nn.Linear):
            partial = functor_linear.transform(module)
            graph.append(partial)
        elif isinstance(module, torch.nn.Conv2d):
            # Input dimensions were captured by _punch() during the probe pass.
            width = module._deepstruct_input_shape[-1]
            height = module._deepstruct_input_shape[-2]
            functor_conv.width = width
            functor_conv.height = height
            partial = functor_conv.transform(module)
            graph.append(partial)
        elif isinstance(module, torch.nn.Dropout):
            # Dropout behaves structurally like a linear-layer and we ignore the fact for now that some edges
            # are ignored probabilistically
            pass
        elif isinstance(module, torch.nn.AdaptiveAvgPool2d):
            # TODO pooling needs to be transformed; most pooling results in structural singularities
            pass
        elif any(isinstance(module, op) for op in self._pointwise_ops):
            # Point-wise operations (mostly activation functions) do not change the structure
            # except for applying non-linear transformations on the input
            pass
        else:
            warnings.warn(f"Warning: ignoring sub-module of type {type(module)}")

        return graph

    def transform(self, model: torch.nn.Module):
        graph = LabeledDAG()

        # Probe pass: run a random input through the (punched) model so every
        # submodule records its actual input shape.
        self._punch(model)
        model.forward(self.random_input)
        graph.add_vertices(np.prod(self.random_input.shape), layer=0)

        for module in model.children():
            graph = self._transform_partial(module, graph)

        return graph

    def applies(self, model: torch.nn.Module):
        return all(
            isinstance(c, torch.nn.Linear) or isinstance(c, torch.nn.Conv2d)
            for c in model.children()
        )
/deepstruct/traverse_strategies.py: -------------------------------------------------------------------------------- 1 | import math 2 | import operator 3 | 4 | from abc import ABC 5 | from abc import abstractmethod 6 | from types import ModuleType 7 | from typing import Any 8 | from typing import Callable 9 | from typing import Dict 10 | from typing import Optional 11 | from typing import Tuple 12 | from typing import Union 13 | 14 | import numpy as np 15 | import numpy.random 16 | import torch 17 | import torch.fx 18 | 19 | from torch.fx.node import Node 20 | from tqdm import tqdm 21 | 22 | from deepstruct.graph import LayeredFXGraph 23 | from deepstruct.node_map_strategies import CustomNodeMap 24 | 25 | 26 | class TraversalStrategy(ABC): 27 | 28 | @abstractmethod 29 | def init(self, node_map_strategy: CustomNodeMap): 30 | pass 31 | 32 | @abstractmethod 33 | def traverse(self, input_tensor: torch.Tensor, model: torch.nn.Module): 34 | pass 35 | 36 | @abstractmethod 37 | def restore_traversal(self): 38 | pass 39 | 40 | @abstractmethod 41 | def get_graph(self): 42 | pass 43 | 44 | 45 | class FXTraversal(TraversalStrategy): 46 | 47 | def __init__( 48 | self, 49 | distribution_fn=np.random.normal, 50 | include_fn=None, 51 | include_modules=None, 52 | exclude_fn=None, 53 | exclude_modules=None, 54 | fold_modules=None, 55 | unfold_modules=None, 56 | ): 57 | self.traced_model = None 58 | self.distribution_fn = distribution_fn 59 | self.include_fn = include_fn 60 | self.include_modules = include_modules 61 | self.exclude_fn = exclude_fn 62 | self.exclude_modules = exclude_modules 63 | self.node_map_strategy = None 64 | self.layered_graph = LayeredFXGraph() 65 | self.fold_modules = fold_modules 66 | self.unfold_modules = unfold_modules 67 | 68 | def init(self, node_map_strategy: CustomNodeMap): 69 | self.node_map_strategy = node_map_strategy 70 | 71 | def traverse(self, input_tensor: torch.Tensor, model: torch.nn.Module): 72 | dist_fn = self.distribution_fn 73 | unfold = 
self.unfold_modules if self.unfold_modules else [] 74 | fold = self.fold_modules if self.fold_modules else [] 75 | 76 | class CustomTracer(torch.fx.Tracer): 77 | 78 | def __init__( 79 | self, 80 | autowrap_modules: Tuple[ModuleType] = (math,), 81 | autowrap_functions: Tuple[Callable, ...] = (), 82 | param_shapes_constant: bool = False, 83 | ): 84 | super().__init__( 85 | autowrap_modules, autowrap_functions, param_shapes_constant 86 | ) 87 | self.orig_mod = None 88 | 89 | def create_proxy( 90 | self, kind, target, args, kwargs, name=None, type_expr=None, *_, **__ 91 | ): 92 | operators = [ 93 | operator.gt, 94 | operator.ge, 95 | operator.lt, 96 | operator.le, 97 | operator.eq, 98 | operator.ne, 99 | ] 100 | if target and target in operators: 101 | return dist_fn(0, 1) > 0.5 102 | return super().create_proxy(kind, target, args, kwargs, name, type_expr) 103 | 104 | def create_node( 105 | self, 106 | kind: str, 107 | target: Union[str, Callable], 108 | args: Tuple[Any], 109 | kwargs: Dict[str, Any], 110 | name: Optional[str] = None, 111 | type_expr: Optional[Any] = None, 112 | ) -> Node | None: 113 | n = super().create_node(kind, target, args, kwargs, name) 114 | if self.orig_mod is None: 115 | n.orig_mod = target 116 | else: 117 | n.orig_mod = self.orig_mod 118 | self.orig_mod = None 119 | return n 120 | 121 | def call_module( 122 | self, 123 | m: torch.nn.Module, 124 | forward: Callable[..., Any], 125 | args: Tuple[Any, ...], 126 | kwargs: Dict[str, Any], 127 | ) -> Any: 128 | self.orig_mod = m 129 | return super().call_module(m, forward, args, kwargs) 130 | 131 | def is_leaf_module( 132 | self, m: torch.nn.Module, module_qualified_name: str 133 | ) -> bool: 134 | if any(isinstance(m, fm) for fm in fold): 135 | return True 136 | elif any(isinstance(m, um) for um in unfold): 137 | return False 138 | else: 139 | return super().is_leaf_module(m, module_qualified_name) 140 | 141 | traced_graph = CustomTracer().trace(model) 142 | traced = torch.fx.GraphModule(model, 
traced_graph) 143 | traced_modules = dict(traced.named_modules()) 144 | from torch.fx.passes.shape_prop import ShapeProp 145 | 146 | ShapeProp(traced).propagate(input_tensor) 147 | self.traced_model = traced 148 | self._add_nodes_to_graph(traced, traced_modules) 149 | 150 | def _add_nodes_to_graph(self, traced, traced_modules): 151 | 152 | class EmptyShape: 153 | def __init__(self): 154 | self.shape = None 155 | 156 | for node in tqdm(traced.graph.nodes, desc="Tracing Nodes"): 157 | if node.op == "get_attr": 158 | continue 159 | if self._should_be_included(node): 160 | module_instance = traced_modules.get(node.target) 161 | shape = getattr( 162 | node.meta.get("tensor_meta", EmptyShape()), "shape", None 163 | ) 164 | self.node_map_strategy.map_node( 165 | self.layered_graph, 166 | type(module_instance), 167 | node.args, 168 | name=node.name, 169 | shape=shape, 170 | origin_module=node.orig_mod, 171 | ) 172 | else: 173 | self.layered_graph.ignored_nodes.append(node.name) 174 | 175 | def _should_be_included(self, node): 176 | if node.op == "placeholder" or node.op == "output": 177 | return True 178 | else: 179 | return self._is_in_include(node) and not self._is_in_exclude(node) 180 | 181 | def _is_in_include(self, node): 182 | include_fn = self.include_fn if self.include_fn else [] 183 | include_modules = self.include_modules if self.include_modules else [] 184 | if len(include_fn) == 0 and len(include_modules) == 0: 185 | return True 186 | if node.op == "call_module" and len(include_modules) > 0: 187 | return any(isinstance(node.orig_mod, m) for m in include_modules) 188 | elif len(include_fn) > 0: 189 | return any( 190 | node.orig_mod == f or node.orig_mod == getattr(f, "__name__", None) 191 | for f in include_fn 192 | ) 193 | else: 194 | return True 195 | 196 | def _is_in_exclude(self, node): 197 | exclude_fn = self.exclude_fn if self.exclude_fn else [] 198 | exclude_modules = self.exclude_modules if self.exclude_modules else [] 199 | if len(exclude_modules) == 0 
and len(exclude_fn) == 0: 200 | return False 201 | if node.op == "call_module": 202 | return any(isinstance(node.orig_mod, m) for m in exclude_modules) 203 | else: 204 | node_name = getattr(node.orig_mod, "__name__", "name not found in orig") 205 | if node_name == "name not found in orig": 206 | node_name = str(node.orig_mod) 207 | return any( 208 | node.orig_mod == f 209 | or node_name == getattr(f, "__name__", "name not found in fn") 210 | for f in exclude_fn 211 | ) 212 | 213 | def restore_traversal(self): 214 | pass 215 | 216 | def get_graph(self): 217 | return self.layered_graph 218 | -------------------------------------------------------------------------------- /deepstruct/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import numpy as np 5 | 6 | import deepstruct.sparse 7 | 8 | 9 | def generate_hessian_inverse_fc(layer, hessian_inverse_path, layer_input_train_dir): 10 | """ 11 | This function calculate hessian inverse for a fully-connect layer 12 | :param hessian_inverse_path: the hessian inverse path you store 13 | :param layer: the layer weights 14 | :param layer_input_train_dir: layer inputs in the dir 15 | :return: 16 | """ 17 | 18 | w_layer = layer.get_weight().data.numpy().T 19 | n_hidden_1 = w_layer.shape[0] 20 | 21 | # Here we use a recursive way to calculate hessian inverse 22 | hessian_inverse = 1000000 * np.eye(n_hidden_1) 23 | 24 | dataset_size = 0 25 | for input_index, input_file in enumerate(os.listdir(layer_input_train_dir)): 26 | layer2_input_train = np.load(layer_input_train_dir + "/" + input_file) 27 | 28 | if input_index == 0: 29 | dataset_size = layer2_input_train.shape[0] * len( 30 | os.listdir(layer_input_train_dir) 31 | ) 32 | 33 | for i in range(layer2_input_train.shape[0]): 34 | # vect_w_b = np.vstack((np.array([layer2_input_train[i]]).T, np.array([[1.0]]))) 35 | vect_w = np.array([layer2_input_train[i]]).T 36 | denominator = dataset_size + np.dot( 37 | 
def kullback_leibler(a, b):
    """Kullback-Leibler divergence KL(a || b) of two discrete distributions.

    Terms with a_i == 0 contribute zero (lim x->0 of x*log(x) = 0).

    :param a: array-like of probabilities (the "true" distribution)
    :param b: array-like of probabilities; b_i should be > 0 wherever a_i > 0
    :return: float, sum over a_i * log(a_i / b_i)
    """
    # FIX: np.float was deprecated in NumPy 1.20 and removed in 1.24; the
    # builtin float yields the same float64 arrays (pyproject requires
    # numpy ^1.21, so current installs crash on np.float).
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)

    # np.where evaluates a * np.log(a / b) even where a == 0, emitting a
    # harmless runtime warning before the 0-branch is selected; errstate
    # silences it without changing the result.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.sum(np.where(a != 0, a * np.log(a / b), 0))
Consider a reversed naming scheme for variables, i.e. ```parameter_lr``` for a learning rate parameter.
The advantage is that related variables share a common prefix, which groups them together and makes auto-completion and alphabetical sorting more useful.
30 | 31 | ## Publishing 32 | ```bash 33 | poetry build 34 | twine upload dist/* 35 | ``` 36 | - Create wheel files in *dist/*: ``poetry build`` 37 | - Install wheel in current environment with pip: ``pip install path/to/deepstruct/dist/deepstruct-0.1.0-py3-none-any.whl`` 38 | 39 | ## Running CI image locally 40 | Install latest *gitlab-runner* (version 12.3 or up): 41 | ```bash 42 | # For Debian/Ubuntu/Mint 43 | curl -L https://packages.gitlab.com/install/repositories/runner/gitlab-runner/script.deb.sh | sudo bash 44 | 45 | # For RHEL/CentOS/Fedora 46 | curl -L https://packages.gitlab.com/install/repositories/runner/gitlab-runner/script.rpm.sh | sudo bash 47 | 48 | apt-get update 49 | apt-get install gitlab-runner 50 | 51 | $ gitlab-runner -v 52 | Version: 12.3.0 53 | ``` 54 | Execute job *tests*: ``gitlab-runner exec docker test-python3.9`` 55 | 56 | ## Running github action locally 57 | Install *https://github.com/nektos/act*. 58 | Run ``act`` 59 | 60 | ## Running pre-commit checks locally 61 | - Execute pre-commit manually: ``poetry run pre-commit run --all-files`` 62 | - Update pre-commit: ``poetry run pre-commit autoupdate`` 63 | - Add pre-commit to your local git: ``poetry run pre-commit install`` 64 | -------------------------------------------------------------------------------- /docs/artificial-landscape-approximation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innvariant/deepstruct/0ffbaf73425c063485cb173f5c0fc3677368b522/docs/artificial-landscape-approximation.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # deepstruct 2 | Documentation for [deepstruct on github](https://github.com/innvariant/deepstruct). 
3 | 4 | ### Notes 5 | - **2023-01-18** mentioned some additional functionality such as learnable weights for a mask 6 | - **2022-07-05** some pruning functionality was re-activated; good examples can also be found in the testing directory 7 | - **2022-05-05** added some simple examples as currently used in experiments 8 | - **2020-11-22** currently documenting the package 9 | 10 | 11 | ## Introduction 12 | Deepstruct provides tools and models in pytorch to easily work with all kinds of sparsity of neural networks. 13 | The four major approaches in that context are 1) **pruning** neural networks, 2) defining **prior structures** on neural networks, 3) **growing** neural network structures and 4) conducting graph-neural-network **round-trips**. 14 | 15 | ![Visualization of pruning and growing neural nets.](methods-pruning-growing.png) 16 | 17 | ## Sparse Feed-forward Neural Net for MNIST 18 | One-line construction suffices to build a neural network with any number of hidden layers and then prune it in a second step. 19 | The models in *deepstruct* provide an additional property *model.mask* on which a sparsity pattern can be defined per layer. 20 | Pruning boils down to defining zeros in this binary mask. 21 | ```python 22 | import deepstruct.sparse 23 | from deepstruct.pruning import PruningStrategy, prune_network_by_saliency 24 | 25 | input_size = 784 26 | output_size = 10 27 | model = deepstruct.sparse.MaskedDeepFFN(input_size, output_size, [200, 100]) 28 | 29 | # Prune 10% of the model based on its absolute weights 30 | prune_network_by_saliency(model, 10, strategy=PruningStrategy.PERCENTAGE) 31 | ``` 32 | 33 | # Base Layer *MaskedLinearLayer* 34 | The underlying base module for *deepstruct* is a *MaskedLinearLayer* which has a *mask*-property. 
In that case of setting the mask as learnable parameters, the underlying data structure is of shape `(out_features, in_features, 2)` instead of the commonly used binary matrix of shape `(out_features, in_features)`.
69 | The data structure has to be converted into a layered graph form from which the topological sorting can be better used for the underlying implementation. 70 | The model *MaskedDeepDAN* then provides a simple constructor to obtain a model from the given layered directed acyclic graph structure by specifying the input dimensions and the output dimensions of the underlying problem. 71 | Training proceeds as with any other pytorch model. 72 | ```python 73 | import networkx as nx 74 | import deepstruct.sparse 75 | import deepstruct.graph as dsg 76 | 77 | # Create a graph with networkx 78 | graph_btree = nx.balanced_tree(r=2, h=3) 79 | graph_grid = nx.grid_2d_graph(3, 30, periodic=False) 80 | graph_smallworld = nx.watts_strogatz_graph(100, 3, 0.8) 81 | 82 | ds_graph_btree = dsg.LayeredGraph.load_from(graph_btree) 83 | ds_graph_grid = dsg.LayeredGraph.load_from(graph_grid) 84 | ds_graph_smallworld = dsg.LayeredGraph.load_from(graph_smallworld) 85 | 86 | # Define a model based on the structure 87 | input_shape = (5, 5) 88 | output_size = 2 89 | model = deepstruct.sparse.MaskedDeepDAN(input_shape, output_size, ds_graph_btree) 90 | ``` 91 | 92 | 93 | ## Extract graphs from neural nets 94 | As of *2022-05-05* this is currently only implemented on a zero-th order level of a neural network in which neurons correspond to graph vertices. 95 | This is a very expensive transformation as for common models you will transform a model of several megabytes in efficient data storages from pytorch into a networkx graph of hundred thousands to millions of vertices. 96 | We're working on defining other levels of sparsity and you're welcome to support us in it, e.g. write a mail to julian.stier@uni-passau.de ! 97 | ```python 98 | import torch 99 | import deepstruct.transform as dtr 100 | 101 | # Define a transformation object which takes a random input to pass through the model for duck-punching ("analysis") 102 | input_shape = (5, 5) 103 | model = None # take the model e.g. 
from above 104 | functor = dtr.GraphTransform(torch.randn((1,)+input_shape)) 105 | 106 | # Obtain the graph structure from the model as based on your transformation routine 107 | graph = functor.transform(model) 108 | print(graph.nodes) 109 | ``` 110 | 111 | 112 | # Training a deepstruct model with random BTrees 113 | ```python 114 | import torch 115 | import numpy as np 116 | import networkx as nx 117 | import deepstruct.sparse 118 | import deepstruct.graph as dsg 119 | 120 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 121 | 122 | # Arrange 123 | batch_size = 10 124 | input_size = 784 125 | output_size = 10 126 | 127 | graph_btree = nx.balanced_tree(r=2, h=3) 128 | ds_graph_btree = dsg.LayeredGraph.load_from(graph_btree) 129 | model = deepstruct.sparse.MaskedDeepDAN(input_size, output_size, ds_graph_btree) 130 | model.to(device) 131 | 132 | # Prepare training 133 | optimizer = torch.optim.Adam(model.parameters()) 134 | criterion = torch.nn.CrossEntropyLoss() 135 | 136 | # Here you could put your for-loop over a dataloader 137 | for epoch_current in range(10): 138 | random_input = torch.tensor( 139 | np.random.random((batch_size, input_size)), device=device, requires_grad=False 140 | ) 141 | random_target = torch.tensor( 142 | np.random.randint(0, 2, batch_size), device=device, requires_grad=False 143 | ) 144 | 145 | optimizer.zero_grad() 146 | prediction = model(random_input) 147 | loss = criterion(prediction, random_target) 148 | loss.backward() 149 | optimizer.step() 150 | ``` 151 | 152 | 153 | ## Available Models 154 | 155 | 156 | ## Artificial Datasets 157 | -------------------------------------------------------------------------------- /docs/logo-wide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innvariant/deepstruct/0ffbaf73425c063485cb173f5c0fc3677368b522/docs/logo-wide.png -------------------------------------------------------------------------------- 
/docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innvariant/deepstruct/0ffbaf73425c063485cb173f5c0fc3677368b522/docs/logo.png -------------------------------------------------------------------------------- /docs/masked-deep-cell-dan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innvariant/deepstruct/0ffbaf73425c063485cb173f5c0fc3677368b522/docs/masked-deep-cell-dan.png -------------------------------------------------------------------------------- /docs/masked-deep-dan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innvariant/deepstruct/0ffbaf73425c063485cb173f5c0fc3677368b522/docs/masked-deep-dan.png -------------------------------------------------------------------------------- /docs/masked-deep-ffn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innvariant/deepstruct/0ffbaf73425c063485cb173f5c0fc3677368b522/docs/masked-deep-ffn.png -------------------------------------------------------------------------------- /docs/methods-pruning-growing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innvariant/deepstruct/0ffbaf73425c063485cb173f5c0fc3677368b522/docs/methods-pruning-growing.png -------------------------------------------------------------------------------- /docs/sparse-network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innvariant/deepstruct/0ffbaf73425c063485cb173f5c0fc3677368b522/docs/sparse-network.png -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: deepstruct 2 | theme: 
readthedocs 3 | repo_url: https://github.com/innvariant/deepstruct 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "deepstruct" 3 | version = "0.11.0-dev" 4 | description = "" 5 | authors = [ 6 | "Julian Stier " 7 | ] 8 | license = "MIT" 9 | 10 | include = [ 11 | "pyproject.toml", 12 | "README.md" 13 | ] 14 | 15 | readme = "README.md" 16 | 17 | homepage = "https://github.com/innvariant/deepstruct" 18 | repository = "https://github.com/innvariant/deepstruct" 19 | documentation = "https://deepstruct.readthedocs.io" 20 | 21 | keywords = ["neural network", "sparsity", "machine learning", "structure", "graph", "training"] 22 | 23 | [tool.poetry.dependencies] 24 | python = "^3.9,<4.0" 25 | torch = "^2" 26 | networkx = "^2.0" 27 | importlib-metadata = "^4.4" 28 | importlib-resources = "^5.0" 29 | semantic_version = "^2.10" 30 | deprecated = "^1.2.10" 31 | numpy = "^1.21" 32 | tqdm = "^4.66.1" 33 | 34 | 35 | [tool.poetry.dev-dependencies] 36 | black = "^22.3.0" 37 | matplotlib = { version = "^3.3" } 38 | pre-commit = "^2.3.0" 39 | pytest = "^7.1" 40 | pytest-mock = "^3.0" 41 | pyfakefs = "^4.0.2" 42 | torchvision = "^0.17" 43 | mkdocs = "^1.1.2" 44 | 45 | [tool.poetry.group.dev.dependencies] 46 | scipy = "^1.13.0" 47 | 48 | [tool.isort] 49 | profile = "black" 50 | line_length = 88 51 | force_single_line = true 52 | atomic = true 53 | include_trailing_comma = true 54 | lines_after_imports = 2 55 | lines_between_types = 1 56 | multi_line_output = 3 57 | use_parentheses = true 58 | filter_files = true 59 | src_paths = ["deepstruct", "tests"] 60 | skip_glob = ["*/setup.py", "res/"] 61 | known_first_party = "deepstruct" 62 | known_third_party = ["importlib_metadata", "importlib_resources", "pyfakefs", "pytest", "semantic_version", "torch" ] 63 | 64 | [tool.black] 65 | line-length = 88 66 | include = '\.pyi?$' 67 | 
def test_cached_layered_graph_default_success():
    """Modifying a CachedLayeredGraph invalidates its cached layer index."""
    layered_graph = deepstruct.graph.CachedLayeredGraph()

    # NOTE(review): this adds vertices 1..6, yet the edges below also
    # reference vertex 7, which networkx creates implicitly on add_edge --
    # confirm whether 7 (vs. 6) is intended in the (2, 7) and (4, 7) edges.
    layered_graph.add_nodes_from(np.arange(1, 7))

    # First layer
    layered_graph.add_edge(1, 3)
    layered_graph.add_edge(1, 4)
    layered_graph.add_edge(1, 5)
    layered_graph.add_edge(1, 6)
    layered_graph.add_edge(2, 3)
    layered_graph.add_edge(2, 4)
    layered_graph.add_edge(2, 5)
    layered_graph.add_edge(2, 7)

    # Second layer
    layered_graph.add_edge(3, 6)
    layered_graph.add_edge(4, 6)
    layered_graph.add_edge(4, 7)
    layered_graph.add_edge(5, 7)

    first_layer_size_before = layered_graph.get_layer_size(0)
    # Querying the layer size recomputed the cache, clearing the dirty flag.
    assert not layered_graph._has_changed

    # Add vertex 0 and connect it to vertices from layer 2
    layered_graph.add_edge(0, 3)
    layered_graph.add_edge(0, 4)
    assert layered_graph._has_changed

    # Vertex 0 is a new source vertex, so the first layer must have grown.
    first_layer_size_after = layered_graph.get_layer_size(0)
    assert first_layer_size_before != first_layer_size_after
len(dag.edges) == 0 41 | assert dag.num_layers == 1 42 | 43 | 44 | def test_add_two_connected_nodes(): 45 | # Arrange 46 | dag = deepstruct.graph.LabeledDAG() 47 | random_label_1_larger_than_one = np.random.randint(2, 5) 48 | random_label_2_larger_than_one = np.random.randint(2, 5) 49 | 50 | # Act 51 | dag.add_edge(random_label_1_larger_than_one, random_label_2_larger_than_one) 52 | 53 | # Assert 54 | assert len(dag.nodes) == 2 55 | assert 0 in dag.nodes 56 | assert 1 in dag.nodes 57 | assert random_label_1_larger_than_one not in dag.nodes 58 | assert random_label_2_larger_than_one not in dag.nodes 59 | assert len(dag.edges) == 1 60 | assert dag.num_layers == 2 61 | 62 | 63 | def test_cycle(): 64 | # Arrange 65 | dag = deepstruct.graph.LabeledDAG() 66 | dag.add_edge(0, 1) # Add one valid edge and two nodes 67 | dag.add_edge(1, 2) # Add one valid edge and one new node 68 | 69 | # Act 70 | with pytest.raises(AssertionError): 71 | dag.add_edge(2, 0) # This would close the cycle 72 | 73 | # Assert 74 | assert len(dag.nodes) == 3 75 | assert len(dag.edges) == 2 76 | assert dag.num_layers == 3 77 | 78 | 79 | def test_add_two_layers(): 80 | # Arrange 81 | dag = deepstruct.graph.LabeledDAG() 82 | 83 | size_layer0 = np.random.randint(2, 10) 84 | size_layer1 = np.random.randint(2, 10) 85 | 86 | layer0 = dag.add_vertices(size_layer0, layer=0) 87 | layer1 = dag.add_vertices(size_layer1, layer=1) 88 | 89 | # Act 90 | for source in layer0: 91 | dag.add_edges_from((source, t) for t in layer1) 92 | 93 | # Assert 94 | assert len(dag.nodes) == size_layer0 + size_layer1 95 | assert len(dag.edges) == size_layer0 * size_layer1 96 | assert dag.num_layers == 2 97 | 98 | 99 | def test_add_two_layers_crossing(): 100 | # Arrange 101 | dag = deepstruct.graph.LabeledDAG() 102 | 103 | size_layer0 = np.random.randint(2, 10) 104 | size_layer1 = np.random.randint(2, 10) 105 | 106 | # Act 107 | dag.add_edges_from( 108 | itertools.product( 109 | np.arange(size_layer0), np.arange(size_layer0 + 
size_layer1 + 1) 110 | ) 111 | ) 112 | 113 | 114 | def test_multiple_large_layers(): 115 | # Arrange 116 | dag = deepstruct.graph.LabeledDAG() 117 | 118 | num_layers = np.random.randint(15, 21) 119 | size_layer = {} 120 | for layer in range(num_layers): 121 | size_layer[layer] = np.random.randint(50, 101) 122 | dag.add_vertices(size_layer[layer], layer=layer) 123 | 124 | # Act 125 | num_edges = 0 126 | for layer_source in range(num_layers - 1): 127 | for layer_target in np.random.choice( 128 | range(layer_source + 1, num_layers), 129 | np.random.randint(num_layers - layer_source), 130 | replace=False, 131 | ): 132 | v_source = np.random.choice( 133 | dag.get_vertices(layer_source), 134 | dag.get_layer_size(layer_source), 135 | replace=False, 136 | ) 137 | v_target = np.random.choice( 138 | dag.get_vertices(layer_target), 139 | dag.get_layer_size(layer_target), 140 | replace=False, 141 | ) 142 | dag.add_edges_from((s, t) for t in v_target for s in v_source) 143 | num_edges += len(v_source) * len(v_target) 144 | 145 | # Assert 146 | assert len(dag.nodes) == sum(size_layer.values()) 147 | assert len(dag.edges) == num_edges 148 | assert dag.num_layers == num_layers 149 | 150 | 151 | def test_append_simple(): 152 | # Arrange 153 | graph1 = deepstruct.graph.LabeledDAG() 154 | graph2 = deepstruct.graph.LabeledDAG() 155 | 156 | graph1_size_layer0 = 3 # np.random.randint(2, 10) 157 | graph1_size_layer1 = 5 # np.random.randint(2, 10) 158 | graph2_size_layer0 = graph1_size_layer1 159 | graph2_size_layer1 = 4 # np.random.randint(2, 10) 160 | 161 | graph1_layer0 = graph1.add_vertices(graph1_size_layer0, layer=0) 162 | graph1_layer1 = graph1.add_vertices(graph1_size_layer1, layer=1) 163 | graph2_layer0 = graph2.add_vertices(graph2_size_layer0, layer=0) 164 | graph2_layer1 = graph2.add_vertices(graph2_size_layer1, layer=1) 165 | graph1.add_edges_from( 166 | (s, t) for t in graph1_layer1 for s in graph1_layer0 if np.random.randint(2) 167 | ) 168 | graph2.add_edges_from( 169 | (s, 
t) for t in graph2_layer1 for s in graph2_layer0 if np.random.randint(3) 170 | ) 171 | 172 | graph1_num_edges = len(graph1.edges) 173 | graph2_num_edges = len(graph2.edges) 174 | 175 | graph1.append(graph2) 176 | 177 | assert len(graph1) == graph1_size_layer0 + graph1_size_layer1 + graph2_size_layer1 178 | assert len(graph1.edges) == graph1_num_edges + graph2_num_edges 179 | 180 | 181 | def test_append_multiple_large_layers(): 182 | # Arrange 183 | graph1 = deepstruct.graph.LabeledDAG() 184 | graph2 = deepstruct.graph.LabeledDAG() 185 | 186 | graph1_num_layers = np.random.randint(10, 21) 187 | graph1_size_layer = {} 188 | for layer in range(graph1_num_layers): 189 | graph1_size_layer[layer] = np.random.randint(50, 101) 190 | graph1.add_vertices(graph1_size_layer[layer], layer=layer) 191 | 192 | graph2_num_layers = np.random.randint(10, 21) 193 | graph2_size_layer = {0: graph1_size_layer[graph1_num_layers - 1]} 194 | graph2.add_vertices(graph2_size_layer[0], layer=0) 195 | for layer in range(1, graph2_num_layers): 196 | graph2_size_layer[layer] = np.random.randint(50, 101) 197 | graph2.add_vertices(graph2_size_layer[layer], layer=layer) 198 | 199 | num_edges = {} 200 | for graph, num_layers in zip( 201 | [graph1, graph2], [graph1_num_layers, graph2_num_layers] 202 | ): 203 | num_edges[graph] = 0 204 | for layer_source in range(num_layers - 1): 205 | for layer_target in np.random.choice( 206 | range(layer_source + 1, num_layers), 207 | np.random.randint(num_layers - layer_source), 208 | replace=False, 209 | ): 210 | v_source = np.random.choice( 211 | graph.get_vertices(layer_source), 212 | graph.get_layer_size(layer_source), 213 | replace=False, 214 | ) 215 | v_target = np.random.choice( 216 | graph.get_vertices(layer_target), 217 | graph.get_layer_size(layer_target), 218 | replace=False, 219 | ) 220 | graph.add_edges_from((s, t) for t in v_target for s in v_source) 221 | num_edges[graph] += len(v_source) * len(v_target) 222 | 223 | # Pre-Check 224 | for graph, 
def test_dev():
    """build_layer_index must place every vertex in exactly one layer bucket.

    Checks that the forward map (vertex -> layer) and the reverse map
    (layer -> vertices) agree for a small hand-built DAG.
    """
    vertices = [1, 2, 3, 4, 5]

    structure = nx.DiGraph()
    structure.add_nodes_from(vertices)
    structure.add_edges_from(
        [(1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (3, 5), (4, 5)]
    )

    layer_index, vertex_by_layer = deepstruct.graph.build_layer_index(structure)

    for vertex in vertices:
        assert vertex in layer_index
        layer = layer_index[vertex]
        assert layer in vertex_by_layer
        assert vertex in vertex_by_layer[layer]
def test_mnist_large():
    """End-to-end smoke test: prune a random-structure DAN and train it on MNIST.

    Builds a DAN from a small-world graph, hard-applies and recomputes its
    mask, extracts the pruned structure, rebuilds a model from it and runs
    one training epoch plus an evaluation pass. Losses are printed only;
    nothing is asserted beyond "does not raise".
    """
    lr = 0.001
    size_batch = 100

    # Small-world graph as the sparse connectivity template.
    random_graph = nx.newman_watts_strogatz_graph(100, 4, 0.5)
    structure = deepstruct.graph.CachedLayeredGraph()
    structure.add_edges_from(random_graph.edges)
    structure.add_nodes_from(random_graph.nodes)

    # Classifier with 784 inputs / 10 outputs constrained by the structure.
    model = deepstruct.sparse.MaskedDeepDAN(784, 10, structure)
    model.apply_mask()  # hard-apply the mask on the weights (not undoable)
    model.recompute_mask()  # re-derive the mask from weight magnitudes
    # The resulting connectivity as a networkx graph based on the current mask.
    pruned_structure = model.generate_structure()

    model_pruned = deepstruct.sparse.MaskedDeepDAN(784, 10, pruned_structure)

    # Standard MNIST normalization constants.
    normalize = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    # NOTE(review): path_common_mnist is an absolute machine-specific path;
    # consider a tmp_path fixture or an env var.
    data_train = datasets.MNIST(
        path_common_mnist, download=True, train=True, transform=normalize
    )
    loader_train = torch.utils.data.DataLoader(
        data_train, batch_size=size_batch, shuffle=True
    )

    data_test = datasets.MNIST(
        path_common_mnist, download=True, train=False, transform=normalize
    )
    loader_test = torch.utils.data.DataLoader(
        data_test, batch_size=size_batch, shuffle=True
    )

    optimizer = torch.optim.Adam(model_pruned.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    # One pass over the training data.
    for features, labels in loader_train:
        optimizer.zero_grad()
        loss = criterion(model_pruned(features), labels)
        loss.backward()
        optimizer.step()

    # Evaluation pass; losses only printed.
    model_pruned.eval()
    for features, labels in loader_test:
        print(criterion(model_pruned(features), labels))
def test_fx_alexnet():
    """AlexNet via FX traversal: fold known layer types, unfold everything else.

    Folded module classes appear as single graph nodes; `unfold_modules=[object]`
    makes the traversal descend into every other container.
    """
    print(" ")
    # NOTE(review): positional `pretrained=True` is deprecated in newer
    # torchvision in favor of `weights=` -- confirm the pinned version.
    model = torchvision.models.alexnet(True)
    dummy_input = torch.rand(1, 3, 224, 224)
    strategy = FXTraversal(
        unfold_modules=[object],
        fold_modules=[
            torch.nn.Conv2d,
            torch.nn.Linear,
            torch.nn.BatchNorm2d,
            torch.nn.MaxPool2d,
        ],
    )
    transformer = GraphTransform(dummy_input, traversal_strategy=strategy)
    transformer.transform(model)
    plot_graph(transformer.get_graph(), "Transformation alexnet")
class ConvNet(nn.Module):
    """Minimal conv -> pool -> linear fixture for the transform tests.

    Note: the final ``nn.Linear(1, 2)`` acts on the last dimension of the
    pooled feature map, so forward only succeeds when the pooled width is
    exactly 1 (e.g. square inputs of width 2 or 3).
    """

    def __init__(self):
        super().__init__()
        # 1 -> 2 channels; padding 1 with kernel 2 grows each side by one.
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=2, kernel_size=2, stride=1, padding=1
        )
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.fc = nn.Linear(1, 2)

    def forward(self, x):
        pooled = self.pool(self.conv1(x))
        return self.fc(pooled)
def test_load_reload_model(tmp_path):
    """Round-trip a ScalableDAN: persist weights + structure, reload, train one step.

    Builds a model from a tiny layered DAG, saves the state dict and the
    graphml structure into ``tmp_path``, reconstructs the model from the
    reloaded structure and verifies a forward/backward/step cycle runs.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Arrange: tiny DAG with a skip connection (1->3->4->5, 2->4).
    ct1_nx_flat = deepstruct.graph.CachedLayeredGraph()
    ct1_nx_flat.add_nodes_from([1, 2, 3, 4, 5])
    ct1_nx_flat.add_edges_from([(1, 3), (2, 4), (3, 4), (4, 5)])

    size_batch = 50
    size_input = 20
    size_output = 2

    # Init
    architecture = deepstruct.scalable.ScalableDAN(
        ct1_nx_flat, deepstruct.graph.uniform_proportions(ct1_nx_flat)
    )
    fn_model, structure_init = architecture.build(
        size_input, size_output, 50, return_graph=True
    )
    fn_model.to(device)

    # Store weights and the layered structure side by side.
    path_checkpoint = os.path.join(tmp_path, "init.pth")
    path_structure = os.path.join(tmp_path, "init.graphml")
    torch.save({"model_state": fn_model.state_dict()}, path_checkpoint)
    structure_init.save(path_structure)

    # Reload both artifacts and rebuild an equivalent model.
    structure_reloaded = deepstruct.graph.CachedLayeredGraph.load(path_structure)
    checkpoint = torch.load(path_checkpoint)
    fn_reloaded = deepstruct.scalable.ScalableDAN.model(
        size_input, size_output, structure_reloaded, use_layer_norm=True
    )
    fn_reloaded.load_state_dict(checkpoint["model_state"])
    fn_reloaded.to(device)

    # One training step on random data.
    fn_reloaded.train()
    optimizer = torch.optim.Adam(fn_reloaded.parameters())
    criterion = torch.nn.CrossEntropyLoss()
    # Tensors are created directly on ``device``. The original additionally
    # called ``tensor.to(device)`` and discarded the result -- Tensor.to is
    # not in-place, so those were dead no-op statements and were removed.
    features_random = torch.tensor(
        np.random.random((size_batch, size_input)), device=device, requires_grad=False
    )
    targets_random = torch.tensor(
        np.random.randint(0, 2, size_batch), device=device, requires_grad=False
    )
    prediction = fn_reloaded(features_random)
    loss = criterion(prediction, targets_random)
    loss.backward()
    optimizer.step()
class ReductionCell(nn.Module):
    """Conv(stride 2) -> ReLU -> BatchNorm block that halves spatial resolution."""

    def __init__(self, input_channels, output_channels):
        super().__init__()
        # Kernel 5 with padding 2 keeps the map centered; stride 2 reduces it.
        self.conv_reduce = nn.Conv2d(
            input_channels, output_channels, kernel_size=5, padding=2, stride=2
        )
        self.act = nn.ReLU()
        self.batch_norm = nn.BatchNorm2d(output_channels)

    def forward(self, input):
        reduced = self.conv_reduce(input)
        return self.batch_norm(self.act(reduced))
def test_feed_success():
    """A DAN built from a random small-world structure must train without errors."""
    # Arrange
    n_samples = 100
    size_batch = 7
    shape_input = (105, 13)
    size_target = 10
    random_graph = nx.watts_strogatz_graph(200, 3, 0.8)
    structure = deepstruct.graph.CachedLayeredGraph()
    structure.add_edges_from(random_graph.edges)
    structure.add_nodes_from(random_graph.nodes)
    model = deepstruct.sparse.MaskedDeepDAN(shape_input, size_target, structure)

    batches_features = [
        torch.randn((size_batch,) + shape_input) for _ in range(n_samples)
    ]
    batches_targets = [
        torch.randint(size_target, size=(size_batch,)) for _ in range(n_samples)
    ]

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    # Act & Assert: each optimization step must complete without raising.
    for features, targets in zip(batches_features, batches_targets):
        optimizer.zero_grad()
        loss = criterion(model(features), targets)
        loss.backward()
        optimizer.step()
def test_apply_mask_success():
    """apply_mask must actually modify the weights of at least one layer."""
    random_graph = nx.watts_strogatz_graph(200, 3, 0.8)
    structure = deepstruct.graph.CachedLayeredGraph()
    structure.add_edges_from(random_graph.edges)
    structure.add_nodes_from(random_graph.nodes)
    model = deepstruct.sparse.MaskedDeepDAN(784, 10, structure)

    # Snapshot every maskable layer's weights before applying the mask.
    weights_before = [
        np.copy(layer.weight.detach().numpy())
        for layer in deepstruct.sparse.maskable_layers(model)
    ]

    model.apply_mask()

    # Compare layer by layer; at least one must have changed.
    changed = [
        not np.array_equal(before, layer.weight.detach().numpy())
        for layer, before in zip(
            deepstruct.sparse.maskable_layers(model), weights_before
        )
    ]
    assert np.any(changed)
def test_dev():
    """Constructing a MaskedDeepDAN from a hand-crafted block DAG must succeed."""
    structure = deepstruct.graph.CachedLayeredGraph()
    # NOTE(review): these six seed nodes are immediately covered by the block
    # vertices below; looks like a leftover -- kept for behavioral parity.
    structure.add_nodes_from(np.arange(1, 7))

    # Seven vertex blocks; block i covers the 1-based ids (offset, offset+size].
    sizes = [50, 50, 30, 30, 30, 20, 20]
    offsets = np.concatenate(([0], np.cumsum(sizes)))
    blocks = [
        np.arange(offsets[i] + 1, offsets[i + 1] + 1) for i in range(len(sizes))
    ]

    # Fully connect block pairs: blocks 0/1 feed 2-4 (first layer),
    # blocks 2-4 feed 5/6 (second layer).
    connections = [
        (0, 2), (0, 3), (1, 3), (1, 4),  # first layer
        (2, 5), (3, 5), (3, 6), (4, 6),  # second layer
    ]
    for source, target in connections:
        for v in blocks[source]:
            for t in blocks[target]:
                structure.add_edge(v, t)

    deepstruct.sparse.MaskedDeepDAN(784, 10, structure)
def test_prune():
    """Magnitude-based pruning: recompute the mask, apply it, report densities."""
    model = deepstruct.sparse.MaskedDeepFFN(784, 10, [1000, 500, 200, 100])
    model.recompute_mask(theta=0.01)  # keep weights with magnitude above theta
    model.apply_mask()

    # Print each maskable layer's mask shape and surviving density.
    for layer in deepstruct.sparse.maskable_layers(model):
        print(layer.mask.shape)
        density = torch.sum(layer.mask) / float(torch.numel(layer.mask))
        print(density)
def test_random_forward_with_multiple_dimensions_success():
    """MaskedDeepFFN must accept multi-dimensional inputs and predict per sample."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Arrange
    size_batch = 10
    shape_input = (10, 5, 8)
    size_output = 10
    model = deepstruct.sparse.MaskedDeepFFN(shape_input, size_output, [100, 200, 50])
    model.to(device)
    features = torch.tensor(
        np.random.random((size_batch,) + shape_input),
        device=device,
        requires_grad=False,
    )

    # Act
    prediction = model(features)

    # Assert: one output vector per batch element.
    assert prediction.numel() == size_batch * size_output
def test_parameter_reset_success():
    """reset_parameters must replace a randomized mask (back to only ones)."""
    # Arrange: fresh layer with a randomized mask.
    size_in = 5
    size_out = 7
    layer = deepstruct.sparse.MaskedLinearLayer(size_in, size_out)
    layer.apply(dprutil.set_random_masks)
    mask_randomized = np.copy(layer.mask)

    # Act: reset should restore the default (all-ones) mask.
    layer.reset_parameters()

    # Assert: same shape, but the content no longer matches the random mask.
    assert layer.mask.size() == mask_randomized.shape
    assert (np.array(layer.mask) != mask_randomized).any()
def test_nonparamterized_masks_not_contained_in_model_params():
    """With mask_as_params=False the mask must stay out of the parameter list."""
    name_param = "_mask"
    size_in = 5
    size_out = 7
    layer = deepstruct.sparse.MaskedLinearLayer(
        size_in, size_out, mask_as_params=False
    )

    named = dict(layer.named_parameters())

    # Only weight and bias are trainable; the mask is not a parameter.
    assert name_param not in named
    assert len(list(layer.parameters())) == 2
def test_inheritance__success():
    """maskable_children must list masked sub-layers and skip plain nn layers."""
    child_plain = nn.Linear(10, 5)
    child_masked = MaskedLinearLayer(10, 5)

    class InheritedMaskableModule(MaskableModule):
        def __init__(self):
            super().__init__()
            self._linear1 = child_plain
            self._linear2 = child_masked

    model = InheritedMaskableModule()

    assert child_masked in model.maskable_children
    assert child_plain not in model.maskable_children
from deepstruct.flexible_transform import GraphTransform
from deepstruct.traverse_strategies import FXTraversal

from .utils import calculate_network_metrics
from .utils import plot_graph


class FuncConstructorCNN(nn.Module):
    # Tiny net whose forward constructs a tensor via a functional call
    # (torch.ones) — exercises how FX tracing handles tensor constructors.
    def __init__(self):
        super(FuncConstructorCNN, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        ones_tensor = torch.ones_like(x)
        y = torch.ones(10) + ones_tensor
        return self.fc(y)


class ControlFlowCNN(nn.Module):
    # Net with data-dependent and attribute-dependent branches in forward —
    # exercises FX tracing behavior in the presence of control flow.
    def __init__(self):
        super(ControlFlowCNN, self).__init__()
        self.fc = nn.Linear(10, 1)
        # random branch condition fixed at construction time
        self.rand_num = random.randint(0, 10)

    def forward(self, x):
        if x.sum() > 0:
            x = self.fc(x)
        else:
            x = F.relu(self.fc(x))

        if self.rand_num < 5:
            x = F.sigmoid(x)
        return x


def test_fx_transformer_simple_network():
    # Transforming SimpleCNN must yield exactly one graph node per traced
    # operation (input, modules, functions, output).
    net = SimpleCNN()
    input_tensor = torch.randn(1, 1, 6, 6)
    graph_transformer = GraphTransform(input_tensor, traversal_strategy=FXTraversal())
    graph_transformer.transform(net)
    graph = graph_transformer.get_graph()
    plot_graph(graph, "Transformation")
    simplenet_names = [
        "cos",
        "fc",
        "size",
        "view",
        "pool",
        "relu",
        "conv1",
        "x",
        "output",
    ]
    assert len(graph.nodes) == len(simplenet_names)
    graph_names = nx.get_node_attributes(graph, "name").values()
    # Node-name sets must match exactly (both inclusions).
    assert all(name in graph_names for name in simplenet_names)
    assert all(name in simplenet_names for name in graph_names)


def test_fx_functional_constructor():
    # Smoke test: a forward using torch.ones(...) must still be transformable.
    net = FuncConstructorCNN()
    input_tensor = torch.randn(1, 10)
    graph_transformer = GraphTransform(input_tensor, traversal_strategy=FXTraversal())
    graph_transformer.transform(net)
    plot_graph(graph_transformer.get_graph(), "Transformation")


def test_fx_control_flow():
    # Smoke test: a forward with if/else branches must still be transformable.
    net = ControlFlowCNN()
    input_tensor = torch.randn(1, 10)
    graph_transformer = GraphTransform(input_tensor, traversal_strategy=FXTraversal())
    graph_transformer.transform(net)
    plot_graph(graph_transformer.get_graph(), "Transformation")


def test_fx_transformer_simple_network_excludes():
    # Excluded functions and modules must not appear as graph nodes.
    net = SimpleCNN()
    input_tensor = torch.randn(1, 1, 6, 6)
    excluded_functions = [torch.nn.functional.relu, torch.cos]
    excluded_modules = [torch.nn.modules.Linear, torch.nn.modules.MaxPool2d]
    graph_transformer = GraphTransform(
        input_tensor,
        traversal_strategy=FXTraversal(
            exclude_fn=excluded_functions, exclude_modules=excluded_modules
        ),
    )
    graph_transformer.transform(net)
    graph = graph_transformer.get_graph()
    plot_graph(graph, "Transformation")
    simplenet_names = [
        "cos",
        "fc",
        "size",
        "view",
        "pool",
        "relu",
        "conv1",
        "x",
        "output",
    ]
    test_names = ["size", "view", "conv1", "x", "output"]
    # Full node count minus one node per excluded function/module.
    assert len(graph.nodes) == len(simplenet_names) - len(excluded_modules) - len(
        excluded_functions
    )
    graph_names = nx.get_node_attributes(graph, "name").values()
    assert all(name in graph_names for name in test_names)
    assert all(name in test_names for name in graph_names)


def test_fx_transformer_simple_network_includes():
    # With include lists set, only the listed functions/modules (plus input
    # and output nodes — the "+ 2" below) may appear.
    net = SimpleCNN()
    input_tensor = torch.randn(1, 1, 6, 6)
    included_functions = [torch.cos]
    included_modules = [torch.nn.modules.Conv2d, torch.nn.modules.MaxPool2d]
    graph_transformer = GraphTransform(
        input_tensor,
        traversal_strategy=FXTraversal(
            include_fn=included_functions, include_modules=included_modules
        ),
    )
    graph_transformer.transform(net)
    graph = graph_transformer.get_graph()
    plot_graph(graph, "Transformation")
    test_names = ["cos", "pool", "conv1", "x", "output"]
    assert len(graph.nodes) == len(included_functions) + len(included_modules) + 2
    graph_names = nx.get_node_attributes(graph, "name").values()
    assert all(name in graph_names for name in test_names)
    assert all(name in test_names for name in graph_names)


def test_fx_transformer_simple_network_includes_excludes():
    # Combined include/exclude filters; additionally checks that the
    # remaining nodes stay connected along the original dataflow chain.
    net = SimpleCNN()
    input_tensor = torch.randn(1, 1, 6, 6)
    graph_transformer = GraphTransform(
        input_tensor,
        traversal_strategy=FXTraversal(
            exclude_fn=[torch.cos, torch.Tensor.size],
            include_modules=[torch.nn.modules.Linear, torch.nn.modules.Conv2d],
        ),
    )
    graph_transformer.transform(net)
    graph = graph_transformer.get_graph()
    plot_graph(graph, "Transformation")
    test_names = ["fc", "view", "relu", "conv1", "x", "output"]
    graph_names = nx.get_node_attributes(graph, "name").values()
    assert len(graph.nodes) == len(test_names)
    assert all(name in graph_names for name in test_names)
    assert all(name in test_names for name in graph_names)
    # Edges must follow the forward pass: x -> conv1 -> relu -> view -> fc -> output
    assert graph.has_edge(
        graph.get_indices_for_name("x")[0], graph.get_indices_for_name("conv1")[0]
    )
    assert graph.has_edge(
        graph.get_indices_for_name("conv1")[0], graph.get_indices_for_name("relu")[0]
    )
    assert graph.has_edge(
        graph.get_indices_for_name("relu")[0], graph.get_indices_for_name("view")[0]
    )
    assert graph.has_edge(
        graph.get_indices_for_name("view")[0], graph.get_indices_for_name("fc")[0]
    )
    assert graph.has_edge(
        graph.get_indices_for_name("fc")[0], graph.get_indices_for_name("output")[0]
    )


def test_fx_fold_modules():
    # Folding Sequential containers collapses resnet18 to 13 top-level nodes.
    # NOTE(review): positional `pretrained` arg (`resnet18(True)`) is
    # deprecated in torchvision in favor of `weights=` — confirm the pinned
    # torchvision version still accepts it.
    resnet = torchvision.models.resnet18(True)
    input_tensor = torch.rand(1, 3, 224, 224)
    graph_transformer = GraphTransform(
        input_tensor,
        traversal_strategy=FXTraversal(fold_modules=[torch.nn.modules.Sequential]),
    )
    graph_transformer.transform(resnet)
    graph = graph_transformer.get_graph()
    assert len(graph.nodes) == 13
    plot_graph(graph, "Transformation resnet")


def test_fx_unfold_modules():
    # Smoke test: explicitly unfolding Linear modules must not break transform.
    net = SimpleCNN()
    input_tensor = torch.randn(1, 1, 6, 6)
    graph_transformer = GraphTransform(
        input_tensor,
        traversal_strategy=FXTraversal(unfold_modules=[torch.nn.modules.Linear]),
    )
    graph_transformer.transform(net)
    graph = graph_transformer.get_graph()
    plot_graph(graph, "Transformation simplenet")


def test_fx_googlenet():
    # Smoke test on a larger pretrained architecture (googlenet).
    print(" ")
    gnet = torchvision.models.googlenet(True)
    input_tensor = torch.rand(1, 3, 224, 224)
    graph_transformer = GraphTransform(input_tensor, traversal_strategy=FXTraversal())
    graph_transformer.transform(gnet)
    plot_graph(graph_transformer.get_graph(), "Transformation googlenet")


def test_fx_hybridmodel():
    # Smoke test on a CNN+LSTM hybrid with a 5-D (batch, time, C, H, W) input.
    print(" ")
    hybridmodel = CNNtoRNN(num_classes=10, hidden_size=256, num_layers=2)
    input_tensor = torch.rand(4, 10, 3, 224, 224)
    graph_transformer = GraphTransform(input_tensor, traversal_strategy=FXTraversal())
    graph_transformer.transform(hybridmodel)
    plot_graph(graph_transformer.get_graph(), "Transformation hybridmodel")


def test_ged():
    # Compare two transformed nets via graph edit distance; prints metrics
    # for manual inspection (no assertion).
    net1 = SimpleCNN()
    net2 = SmallCNN()
    input_tensor = torch.randn(1, 1, 6, 6)
    graph_transformer = GraphTransform(input_tensor, traversal_strategy=FXTraversal())
    graph_transformer2 = GraphTransform(input_tensor, traversal_strategy=FXTraversal())
    graph_transformer.transform(net1)
    graph_transformer2.transform(net2)
    G1 = graph_transformer.get_graph()
    G2 = graph_transformer2.get_graph()

    metrics_G1 = calculate_network_metrics(G1)
    metrics_G2 = calculate_network_metrics(G2)
    ged = nx.graph_edit_distance(G1, G2)
    print("Kennzahlen von G1:", metrics_G1)
    print("Kennzahlen von G2:", metrics_G2)
    print(ged)


class CNNtoRNN(nn.Module):
    # ResNet50 feature extractor (classifier head stripped) feeding an LSTM;
    # used as a hybrid-architecture fixture for the FX transform tests.
    def __init__(self, hidden_size, num_layers, num_classes):
        super(CNNtoRNN, self).__init__()
        # NOTE(review): `pretrained=True` is deprecated in recent torchvision
        # in favor of `weights=` — confirm against the pinned version.
        self.cnn = models.resnet50(pretrained=True)
        # drop avgpool + fc, keep convolutional backbone
        self.cnn = nn.Sequential(*list(self.cnn.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.rnn = nn.LSTM(
            input_size=2048,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        batch_size, timesteps, C, H, W = x.size()

        # CNN feature extraction: fold time into the batch dimension.
        c_in = x.view(batch_size * timesteps, C, H, W)
        c_out = self.cnn(c_in)
        c_out = self.avgpool(c_out)
        c_out = c_out.view(batch_size, timesteps, -1)
        r_out, (hn, cn) = self.rnn(c_out)
        # classify from the last timestep only
        r_out2 = self.fc(r_out[:, -1, :])
        return r_out2


class SimpleCNN(nn.Module):
    # Conv -> ReLU -> MaxPool -> flatten -> Linear -> cos fixture.
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=6, kernel_size=3, stride=1, padding=1
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc = nn.Linear(54, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = torch.cos(x)
        return x


class SmallCNN(nn.Module):
    # Like SimpleCNN but single-channel conv and two Linear layers.
    def __init__(self):
        super(SmallCNN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc = nn.Linear(9, 6)
        self.fc2 = nn.Linear(6, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.fc2(x)
        x = torch.cos(x)
        return x
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from matplotlib import pyplot as plt

from deepstruct.flexible_transform import GraphTransform
from deepstruct.node_map_strategies import LowLevelNodeMap
from deepstruct.transform import Conv2dLayerFunctor
from deepstruct.traverse_strategies import FXTraversal


def plot_graph(graph, title):
    # Draw the graph with node "name" attributes as labels (visual debugging).
    labels = nx.get_node_attributes(graph, "name")
    fig, ax = plt.subplots(figsize=(10, 10))
    nx.draw(
        graph,
        labels=labels,
        with_labels=True,
        node_size=700,
        node_color="lightblue",
        font_size=8,
        ax=ax,
    )
    plt.title(title)
    plt.show()


def calculate_network_metrics(graph):
    # Summary statistics of a graph: node/edge counts, mean degree,
    # density and average clustering coefficient.
    metrics = {
        "nodes": graph.number_of_nodes(),
        "edges": graph.number_of_edges(),
        "avg degree": sum(dict(graph.degree()).values()) / graph.number_of_nodes(),
        "density": nx.density(graph),
        "Average cluster coefficient": nx.average_clustering(graph),
    }
    return metrics


class SimpleCNN(nn.Module):
    # Conv -> ReLU -> MaxPool -> flatten -> Linear -> cos fixture.
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=6, kernel_size=3, stride=1, padding=1
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc = nn.Linear(54, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = torch.cos(x)
        return x


class SmallCNN(nn.Module):
    # Like SimpleCNN but single output channel and two Linear layers.
    def __init__(self):
        super(SmallCNN, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc = nn.Linear(9, 6)
        self.fc2 = nn.Linear(6, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.fc2(x)
        x = torch.cos(x)
        return x


class ConvNet(nn.Module):
    # Minimal conv/pool/linear fixture for low-level node mapping.
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=1, out_channels=2, kernel_size=2, stride=1, padding=1
        )
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.fc = nn.Linear(1, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.fc(x)
        return x


def test_fx_low_level_linear_fully_connected():
    # With LowLevelNodeMap (no threshold) each layer expands into one node
    # per neuron; degrees must reflect a fully-connected wiring.
    print(" ")
    net = SmallCNN()
    input_tensor = torch.randn(1, 1, 6, 6)
    graph_transformer = GraphTransform(
        input_tensor,
        traversal_strategy=FXTraversal(),
        node_map_strategy=LowLevelNodeMap(),
    )
    graph_transformer.transform(net)
    graph = graph_transformer.get_graph()
    fc_nodes = graph.get_indices_for_name("fc")
    fc2_nodes = graph.get_indices_for_name("fc2")
    cos_nodes = graph.get_indices_for_name("cos")
    # fc: 9 inputs, fc2: 6 inputs, cos applied elementwise to 3 outputs
    assert len(fc_nodes) == 9
    assert len(fc2_nodes) == 6
    assert len(cos_nodes) == 3
    for node in fc_nodes:
        assert graph.in_degree(node) == 1
        assert graph.out_degree(node) == 6
    for node in fc2_nodes:
        assert graph.in_degree(node) == 9
        assert graph.out_degree(node) == 3
    for node in cos_nodes:
        assert graph.in_degree(node) == 6
        assert graph.out_degree(node) == 1

    plot_graph(graph, "Transformation smallnet")


def test_fx_low_level_linear():
    # Smoke test: thresholded node map (weights below 0.15 dropped).
    print(" ")
    net = SmallCNN()
    input_tensor = torch.randn(1, 1, 6, 6)
    graph_transformer = GraphTransform(
        input_tensor,
        traversal_strategy=FXTraversal(),
        node_map_strategy=LowLevelNodeMap(0.15),
    )
    graph_transformer.transform(net)
    graph = graph_transformer.get_graph()
    plot_graph(graph, "Transformation smallnet")


def test_convnet():
    # Low-level expansion of a conv/pool net: checks per-node degrees and
    # that pool in-degrees sum to the number of pool nodes.
    print(" ")
    conv_net = ConvNet()
    input_tensor = torch.rand(1, 1, 2, 2)
    graph_transformer = GraphTransform(
        input_tensor,
        traversal_strategy=FXTraversal(),
        node_map_strategy=LowLevelNodeMap(),
    )
    graph_transformer.transform(conv_net)
    graph = graph_transformer.get_graph()
    plot_graph(graph, "convnet")
    conv_nodes = graph.get_indices_for_name("conv1")
    pool_nodes = graph.get_indices_for_name("pool")
    assert len(conv_nodes) == 4
    assert len(pool_nodes) == 18
    for node in conv_nodes:
        assert graph.in_degree(node) == 1
    in_deg_sum = 0
    nodes_with_indeg = 0
    for node in pool_nodes:
        in_deg_sum += graph.in_degree(node)
        if graph.in_degree(node) != 0:
            nodes_with_indeg += 1
        assert graph.out_degree(node) == 1
    assert in_deg_sum == len(pool_nodes)


def test_conv_simple():
    # Output-node count of the transformed graph must match the model's
    # actual number of output features.
    class ConvModel(nn.Module):
        def __init__(self):
            super(ConvModel, self).__init__()
            channels_in = 3
            kernel_size = (3, 3)
            self.conv1 = nn.Conv2d(
                in_channels=channels_in,
                out_channels=2,
                kernel_size=kernel_size,
                stride=1,
            )

        def forward(self, x):
            return self.conv1(x)

    input_width = 5
    input_height = 5
    channels_in = 3
    model = ConvModel()
    # Inflate weights so none falls below the 0.01 threshold ("pruned").
    model.conv1.weight[:, :].data += 10
    random_input = torch.rand(size=(1, channels_in, input_height, input_width))
    graph_transformer = GraphTransform(
        random_input, node_map_strategy=LowLevelNodeMap(0.01)
    )

    graph_transformer.transform(model)
    graph = graph_transformer.get_graph()
    output = model.forward(random_input)
    number_output_features = np.prod(output.shape)
    plot_graph(graph, "conv simple")
    assert len(graph.get_indices_for_name("output")) == number_output_features


def test_conv_with_add():
    # Exercises a residual-style torch.add joining two branches; mostly a
    # smoke test with degree checks printed for inspection.
    class ConvModel(nn.Module):
        def __init__(self):
            super(ConvModel, self).__init__()
            self.conv1 = nn.Conv2d(
                in_channels=1, out_channels=2, kernel_size=2, stride=1, padding=1
            )
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        def forward(self, x):
            x = self.conv1(x)
            tmp = x
            x = self.pool(x)
            x = torch.cos(x)
            x = torch.add(x, tmp)
            return x

    print(" ")
    conv_model = ConvModel()
    input_tensor = torch.rand(1, 1, 2, 2)
    print(input_tensor)
    graph_transformer = GraphTransform(
        input_tensor,
        traversal_strategy=FXTraversal(),
        node_map_strategy=LowLevelNodeMap(),
    )
    graph_transformer.transform(conv_model)
    graph = graph_transformer.get_graph()
    plot_graph(graph, "convnet")
    conv_nodes = graph.get_indices_for_name("conv1")
    pool_nodes = graph.get_indices_for_name("pool")
    cos_nodes = graph.get_indices_for_name("cos")
    add_nodes = graph.get_indices_for_name("add")
    print(len(conv_nodes), len(pool_nodes), len(cos_nodes), len(add_nodes))
    for node in conv_nodes:
        assert graph.in_degree(node) == 1
    in_deg_sum = 0
    nodes_with_indeg = 0
    pools_with_zero_deg = 0
    for node in pool_nodes:
        in_deg_sum += graph.in_degree(node)
        if graph.in_degree(node) == 0:
            pools_with_zero_deg += 1
        if graph.in_degree(node) != 0:
            nodes_with_indeg += 1
        assert graph.out_degree(node) == 1
    print(pools_with_zero_deg)


def test_realistic_convolution():
    # Same invariant as test_conv_simple but at realistic 100x100 input size
    # (graph gets large; plotting disabled).
    class ConvModel(nn.Module):
        def __init__(self):
            super(ConvModel, self).__init__()
            channels_in = 3
            kernel_size = (5, 5)
            self.conv1 = nn.Conv2d(
                in_channels=channels_in,
                out_channels=2,
                kernel_size=kernel_size,
                stride=1,
            )

        def forward(self, x):
            return self.conv1(x)

    input_width = 100
    input_height = 100
    channels_in = 3
    model = ConvModel()
    model.conv1.weight[:, :].data += 10
    random_input = torch.rand(size=(1, channels_in, input_height, input_width))
    graph_transformer = GraphTransform(
        random_input, node_map_strategy=LowLevelNodeMap(0.01)
    )

    graph_transformer.transform(model)
    graph = graph_transformer.get_graph()
    output = model.forward(random_input)
    number_output_features = np.prod(output.shape)
    # plot_graph(graph, "realistic conv")
    assert len(graph.get_indices_for_name("output")) == number_output_features


def test_realistic_convolution2():
    # Same check via the older Conv2dLayerFunctor API directly.
    # Arrange
    input_width = 100  # 100x100 is already quite a huge graph
    input_height = 100
    channels_in = 3
    kernel_size = (5, 5)
    model = torch.nn.Conv2d(
        in_channels=channels_in, out_channels=2, kernel_size=kernel_size, stride=1
    )
    # Make sure each weight is large enough so none is getting "pruned"
    model.weight[:, :].data += 10
    output = model.forward(torch.rand(size=(1, channels_in, input_height, input_width)))
    number_output_features = np.prod(output.shape)

    functor = Conv2dLayerFunctor(input_width, input_height, threshold=0.01)

    # Act
    result = functor.transform(model)

    # Assert
    assert result.last_layer_size == number_output_features


def test_transposed_convolution():
    # Conv with padding on a tiny 2x2 input; checks last-layer size and
    # draws the resulting graph.
    print(" ")
    # Arrange
    input_width = 2
    input_height = 2
    channels_in = 1
    model = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=2, stride=1, padding=1)
    # Make sure each weight is large enough so none is getting "pruned"
    model.weight[:, :].data += 10
    output = model.forward(torch.rand(size=(1, channels_in, input_height, input_width)))
    number_output_features = np.prod(output.shape)

    functor = Conv2dLayerFunctor(input_width, input_height, threshold=0.01)

    # Act
    result = functor.transform(model)
    print(number_output_features)
    # Assert
    assert result.last_layer_size == number_output_features
    nx.draw(
        result, with_labels=True, node_size=700, node_color="lightblue", font_size=8
    )
    plt.show()
def test_recurrent_simple():
    """A single masked recurrent layer maps (batch, input) to (batch, hidden)."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Arrange
    batch_size = 15
    input_size = 20
    hidden_size = 30
    model = deepstruct.recurrent.MaskedRecurrentLayer(
        input_size, hidden_size=hidden_size
    )
    model.to(device)
    random_input = torch.tensor(
        np.random.random((batch_size, input_size)),
        dtype=torch.float32,
        device=device,
        requires_grad=False,
    )

    # Act
    hidden_state = torch.tensor(
        np.random.random((batch_size, hidden_size)),
        dtype=torch.float32,
        device=device,
        requires_grad=False,
    )
    output = model(random_input, hidden_state)

    # Assert
    assert output.numel() == batch_size * hidden_size


def test_recurrent_unusual_activation():
    """The layer accepts a custom nonlinearity (LogSigmoid) and keeps its shape."""
    # Arrange
    batch_size = 15
    input_size = 20
    hidden_size = 30
    model = deepstruct.recurrent.MaskedRecurrentLayer(
        input_size, hidden_size=hidden_size, nonlinearity=torch.nn.LogSigmoid()
    )
    random_input = torch.tensor(
        np.random.random((batch_size, input_size)),
        dtype=torch.float32,
        requires_grad=False,
    )

    # Act
    hidden_state = torch.tensor(
        np.random.random((batch_size, hidden_size)),
        dtype=torch.float32,
        requires_grad=False,
    )
    output = model(random_input, hidden_state)

    # Assert
    assert output.numel() == batch_size * hidden_size


def test_deep_recurrent_simple():
    """A stacked MaskedDeepRNN returns (batch, output_size) for batch-first input."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Arrange
    batch_size = 15
    input_size = 20
    seq_size = 18
    output_size = 12
    model = deepstruct.recurrent.MaskedDeepRNN(
        input_size, hidden_layers=[100, 100, output_size], batch_first=True
    )
    model.to(device)
    random_input = torch.tensor(
        np.random.random((batch_size, seq_size, input_size)),
        dtype=torch.float32,
        device=device,
        requires_grad=False,
    )
    print(str(model))

    # Act
    result_shape = model(random_input).shape

    # Assert
    assert result_shape.numel() == batch_size * output_size


def test_deep_recurrent_layertypes_simple():
    """All three masked layer types (RNN, GRU, LSTM) yield the same output size."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Arrange
    batch_size = 15
    input_size = 20
    seq_size = 18
    output_size = 12
    layer_constructors = [
        deepstruct.recurrent.MaskedRecurrentLayer,
        deepstruct.recurrent.MaskedGRULayer,
        deepstruct.recurrent.MaskedLSTMLayer,
    ]
    for builder in layer_constructors:
        model = deepstruct.recurrent.MaskedDeepRNN(
            input_size,
            hidden_layers=[100, 100, output_size],
            batch_first=True,
            build_recurrent_layer=builder,
        )
        model.to(device)
        random_input = torch.tensor(
            np.random.random((batch_size, seq_size, input_size)),
            dtype=torch.float32,
            device=device,
            requires_grad=False,
        )

        # Act
        result_shape = model(random_input).shape

        # Assert
        assert result_shape.numel() == batch_size * output_size


def test_learn_start_symbol():
    """Train an LSTM-based MaskedDeepRNN to recover the first symbol of a sequence.

    Samples look like [3, 4, 5, 6, 0, 0, ...] -> 3 (zero-padded ascending runs).
    No accuracy assertion — predictions are only printed for inspection.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Arrange
    n_samples = 200
    batch_size = 15
    input_size = 1
    seq_size = 20
    output_size = 10

    # Build up samples like
    # [1, 2, 3, 4, 5, 0, 0, 0] --> 1
    # [3, 4, 5, 6, 0, 0, 0, 0] --> 3
    samples_x = []
    samples_y = []
    for sample_idx in range(n_samples):
        start = np.random.randint(1, 20) + 1
        length = np.random.randint(2, seq_size + 1)
        x_seq = np.pad(
            np.arange(start, start + length), (0, seq_size - length), "constant"
        ).reshape(seq_size, 1)
        y_result = start
        samples_x.append(x_seq)
        samples_y.append(y_result)

    samples_x = np.array(samples_x)
    samples_y = np.array(samples_y)

    # Define model, loss and optimizer
    model = deepstruct.recurrent.MaskedDeepRNN(
        input_size,
        hidden_layers=[30, output_size],
        batch_first=True,
        build_recurrent_layer=deepstruct.recurrent.MaskedLSTMLayer,
    )
    model = torch.nn.Sequential(model, torch.nn.Linear(output_size, 1))
    model.to(device)

    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Perform training with backpropagation
    for epoch in range(100):
        errors = []
        for batch_idx in range(int(n_samples / batch_size)):
            optimizer.zero_grad()

            # Slice batches by batch_size (was hard-coded as 15 before, which
            # silently broke if batch_size changed).
            batch_lo = batch_idx * batch_size
            batch_hi = (batch_idx + 1) * batch_size
            output = model(
                torch.tensor(
                    samples_x[batch_lo:batch_hi],
                    dtype=torch.float32,
                    device=device,
                )
            )
            expected = torch.tensor(
                samples_y[batch_lo:batch_hi].reshape(batch_size, 1),
                dtype=torch.float32,
                device=device,
            )

            error = loss(output, expected)
            error.backward()
            errors.append(error.detach().cpu().numpy())

            optimizer.step()

    # Manual inspection of a few held-out sequences.
    for start, length in [(1, 3), (5, 10), (8, 5)]:
        x_seq = np.pad(
            np.arange(start, start + length), (0, seq_size - length), "constant"
        ).reshape(1, seq_size, 1)
        prediction = model(torch.tensor(x_seq, dtype=torch.float32, device=device))
        target = start
        print("Test: target=", target, "prediction=", prediction)


def test_learn_summation():
    """Train a MaskedDeepRNN to sum the elements of a zero-padded sequence.

    Target is computed in closed form via the Gauss sum. No accuracy
    assertion — predictions are only printed for inspection.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Arrange
    n_samples = 200
    batch_size = 15
    input_size = 1
    seq_size = 20
    output_size = 10

    # Build up samples like
    # [1, 2, 3, 4, 5, 0, 0, 0] --> 15
    # [3, 4, 5, 6, 0, 0, 0, 0] --> 18
    def gauss(n):
        # Sum of 1..n in closed form.
        return (n * (n + 1)) / 2

    samples_x = []
    samples_y = []
    for sample_idx in range(n_samples):
        start = np.random.randint(1, 20)
        length = np.random.randint(2, seq_size + 1)
        x_seq = np.pad(
            np.arange(start, start + length), (0, seq_size - length), "constant"
        ).reshape(seq_size, 1)
        y_result = gauss(start + length - 1) - gauss(start - 1)
        samples_x.append(x_seq)
        samples_y.append(y_result)

    samples_x = np.array(samples_x)
    samples_y = np.array(samples_y)

    # Define model, loss and optimizer
    model = deepstruct.recurrent.MaskedDeepRNN(
        input_size,
        hidden_layers=[50, output_size],
        batch_first=True,
        nonlinearity=torch.nn.ReLU(),
    )
    model = torch.nn.Sequential(model, torch.nn.Linear(output_size, 1))
    model.to(device)
    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Perform training with backpropagation
    for epoch in range(100):
        errors = []
        for batch_idx in range(int(n_samples / batch_size)):
            optimizer.zero_grad()

            # Slice batches by batch_size (was hard-coded as 15 before).
            batch_lo = batch_idx * batch_size
            batch_hi = (batch_idx + 1) * batch_size
            output = model(
                torch.tensor(
                    samples_x[batch_lo:batch_hi],
                    dtype=torch.float32,
                    device=device,
                )
            )
            expected = torch.tensor(
                samples_y[batch_lo:batch_hi].reshape(batch_size, 1),
                dtype=torch.float32,
                device=device,
            )

            error = loss(output, expected)
            error.backward()
            errors.append(error.detach().cpu().numpy())

            optimizer.step()

    # Manual inspection of a few held-out sequences.
    for start, length in [(1, 5), (5, 7), (8, 4)]:
        x_seq = np.pad(
            np.arange(start, start + length), (0, seq_size - length), "constant"
        ).reshape(1, seq_size, 1)
        prediction = model(torch.tensor(x_seq, dtype=torch.float32, device=device))
        target = gauss(start + length - 1) - gauss(start - 1)
        print("Test: target=", target, "prediction=", prediction)
import torch
import torch.nn.functional as F

import deepstruct.pruning.util as dprutil
import deepstruct.sparse


def test_learn():
    # Trains a three-layer masked MLP on a synthetic two-class problem and
    # asserts better-than-chance test accuracy.
    # Arrange
    input_size = 5
    output_size = 2
    hidden_size = 10
    layer_one = deepstruct.sparse.MaskedLinearLayer(
        input_size, hidden_size, mask_as_params=False
    )
    layer_h2h = deepstruct.sparse.MaskedLinearLayer(
        hidden_size, hidden_size, mask_as_params=False
    )
    layer_out = deepstruct.sparse.MaskedLinearLayer(
        hidden_size, output_size, mask_as_params=False
    )
    layer_one.apply(dprutil.set_random_masks)
    # layer_h2h.mask = torch.ones((hidden_size, hidden_size))
    # NOTE(review): the loop below assigns into the layer object itself
    # (layer_h2h[i, j] = ...), while the commented-out line above suggests
    # the mask was the intended target. Confirm MaskedLinearLayer defines
    # __setitem__ forwarding to its mask; otherwise this should likely be
    # layer_h2h.mask[n_source, n_target].
    for n_source in range(hidden_size):
        for n_target in range(hidden_size):
            layer_h2h[n_source, n_target] = (
                0 if n_source < hidden_size / 2 or n_target < hidden_size / 2 else 1
            )
    print(layer_h2h.weight)
    print(layer_h2h.mask)

    layer_out.apply(dprutil.set_random_masks)

    # Two gaussian clusters centered around -1 (class 0) and +1 (class 1).
    samples_per_class = 5000
    ys = torch.cat([torch.ones(samples_per_class), torch.zeros(samples_per_class)])
    means = (ys * 2) - 1 + torch.randn_like(ys)
    input = torch.stack([torch.normal(means, 1) for _ in range(input_size)], dim=1)
    # Shuffled 80/20 train/test split.
    shuffling = torch.randperm(2 * samples_per_class)
    prop_train = 0.8
    offset_train = int(prop_train * len(shuffling))
    ids_train = shuffling[:offset_train]
    ids_test = shuffling[offset_train:]
    input_train = input[ids_train, :]
    input_test = input[ids_test, :]
    target_train = ys[ids_train].long()
    target_test = ys[ids_test].long()

    optimizer = torch.optim.Adam(
        list(layer_one.parameters())
        + list(layer_h2h.parameters())
        + list(layer_out.parameters()),
        lr=0.02,
        weight_decay=0.1,
    )
    loss = torch.nn.CrossEntropyLoss()

    # Act
    print(layer_one.weight)
    print(layer_h2h.weight)
    h2h_priortraining = torch.clone(layer_h2h.weight)
    for _ in range(100):
        optimizer.zero_grad()
        h = layer_one(input_train)
        h = layer_h2h(torch.tanh(h))
        prediction = layer_out(torch.tanh(h))
        error = loss(prediction, target_train)
        error.backward()
        optimizer.step()
    print(layer_one.weight)
    print(layer_h2h.weight)
    h2h_posttraining = torch.clone(layer_h2h.weight)

    print("Diffs h2h weights")
    print(torch.abs(h2h_priortraining - h2h_posttraining))

    # print(torch.round(torch.where(layer_h2h.mask.bool(), layer_h2h.weight*10**3, torch.zeros_like(layer_h2h.weight))) / (10**3))

    # Evaluate on the test split; the masked vs unmasked forward passes are
    # printed side by side for manual comparison.
    h = layer_one(input_test)
    print("h", h)
    h_nomask = F.linear(torch.tanh(h), layer_h2h.weight, layer_h2h.bias)
    print("h_nomask", h_nomask)
    h = layer_h2h(torch.tanh(h))
    print("h", h)
    prediction = layer_out(torch.tanh(h))
    print("h2h_bias", layer_h2h.bias)

    accuracy = float(torch.sum(torch.argmax(prediction, axis=1) == target_test)) / len(
        target_test
    )
    assert accuracy > 0.5


import numpy as np
import torch

from deepstruct.transform import Conv2dLayerFunctor


def test_conv_simple():
    # The functor-built graph's last layer must have one node per output
    # feature of the conv layer.
    # Arrange
    input_width = 5
    input_height = 5
    channels_in = 3
    kernel_size = (3, 3)
    model = torch.nn.Conv2d(
        in_channels=channels_in, out_channels=2, kernel_size=kernel_size, stride=1
    )
    # Make sure each weight is large enough so none is getting "pruned"
    model.weight[:, :].data += 10

    functor = Conv2dLayerFunctor(input_width, input_height, threshold=0.01)

    # Act
    result = functor.transform(model)
    output = model.forward(torch.rand(size=(1, channels_in, input_height, input_width)))
    number_output_features = np.prod(output.shape)

    # Assert
    assert result.last_layer_size == number_output_features
    # TODO: check connectivity

    """import networkx as nx
    from networkx.drawing.nx_agraph import write_dot, graphviz_layout
    import matplotlib.pyplot as plt
    pos = graphviz_layout(result, prog='dot')
    nx.draw(result, pos, with_labels=True, arrows=True)
    plt.show()"""
def test_conv_nonsquare_kernel():
    """A rectangular kernel must produce a graph matching the conv's output feature count."""
    # Arrange
    width, height = 20, 10
    channels_in = 3
    kernel_size = (6, 3)
    assert kernel_size[0] != kernel_size[1]
    model = torch.nn.Conv2d(
        in_channels=channels_in, out_channels=2, kernel_size=kernel_size, stride=1
    )
    # Make sure each weight is large enough so none is getting "pruned"
    model.weight[:, :].data += 10

    functor = Conv2dLayerFunctor(width, height, threshold=0.01)

    # Act
    graph = functor.transform(model)
    out = model.forward(torch.rand(size=(1, channels_in, height, width)))

    # Assert
    assert graph.last_layer_size == np.prod(out.shape)


def test_conv_multiple_configs():
    """Sweep stride / out-channels / kernel size and check the graph size for each combination."""
    # Arrange
    width = height = 10
    channels_in = 3

    models = []
    for stride in range(1, 3):
        for channels_out in range(1, 5):
            for kernel_dim in range(2, 7):
                conv = torch.nn.Conv2d(
                    in_channels=channels_in,
                    out_channels=channels_out,
                    kernel_size=(kernel_dim, kernel_dim),
                    stride=stride,
                )
                # Make sure each weight is large enough so none is getting "pruned"
                conv.weight[:, :].data += 10
                models.append(conv)

    functor = Conv2dLayerFunctor(width, height, threshold=0.01)

    # Act & Assert
    for conv in models:
        graph = functor.transform(conv)
        out = conv.forward(torch.rand(size=(1, channels_in, height, width)))
        assert graph.last_layer_size == np.prod(out.shape)


def test_realistic_convolution():
    """A realistically sized input (100x100 already yields quite a huge graph)."""
    # Arrange
    width = 100
    height = 100
    channels_in = 3
    model = torch.nn.Conv2d(
        in_channels=channels_in, out_channels=2, kernel_size=(5, 5), stride=1
    )
    # Make sure each weight is large enough so none is getting "pruned"
    model.weight[:, :].data += 10
    out = model.forward(torch.rand(size=(1, channels_in, height, width)))
    number_output_features = np.prod(out.shape)

    functor = Conv2dLayerFunctor(width, height, threshold=0.01)

    # Act
    graph = functor.transform(model)

    # Assert
    assert graph.last_layer_size == number_output_features
import networkx as nx
import torch

import deepstruct.sparse

from deepstruct.transform import GraphTransform


class SimpleNet(torch.nn.Module):
    """Small mixed architecture: two dense layers, one conv over a reshaped map, a dense head."""

    def __init__(self, size_input: int, size_hidden: int):
        super().__init__()
        self._size_hidden = size_hidden
        hidden_sq = size_hidden * size_hidden
        self._linear1 = torch.nn.Linear(size_input, hidden_sq)
        self._linear2 = torch.nn.Linear(hidden_sq, hidden_sq)
        self._conv = torch.nn.Conv2d(1, 1, (3, 3))
        self._linear3 = torch.nn.Linear(1 * 2 * 2, 10)
        self._act = torch.nn.ReLU()

    def forward(self, input):
        hidden = self._act(self._linear1(input))
        hidden = self._act(self._linear2(hidden))
        # Re-interpret the flat hidden vector as a single-channel square image.
        feature_map = hidden.reshape((1, 1, self._size_hidden, self._size_hidden))
        flat = self._conv(feature_map).flatten()
        return self._linear3(flat)


def test_stacked_graph():
    """Transform a model that stacks linear and conv layers into one graph."""
    # Arrange
    size_input = 20
    size_hidden = 4
    model = SimpleNet(size_input, size_hidden)
    # Make sure each weight is large enough so none is getting "pruned"
    model._linear1.weight[:, :].data += 1
    model._linear2.weight[:, :].data += 1
    model._conv.weight[:, :].data += 1

    functor = GraphTransform(torch.randn((1, 20)))

    # Act
    result = functor.transform(model)
    print(result.nodes)
def _assert_ffn_transform(shape_input, layers, output_size):
    """Build a fully connected MaskedDeepFFN, transform it, and verify graph sizes.

    Shared driver for the deep-FFN tests below (they previously duplicated this
    code verbatim).  Asserts the node count equals input + hidden + output units
    and the edge count equals the sum of all bipartite layer connections.
    """
    model = deepstruct.sparse.MaskedDeepFFN(shape_input, output_size, layers)
    for layer in deepstruct.sparse.maskable_layers(model):
        layer.weight[:, :].data += 1  # make sure everything is fully connected

    functor = GraphTransform(torch.randn((1,) + shape_input))

    result = functor.transform(model)

    assert len(result.nodes) == shape_input[0] + sum(layers) + output_size
    assert (
        len(result.edges)
        == shape_input[0] * layers[0]
        + sum(l1 * l2 for l1, l2 in zip(layers[0:-1], layers[1:]))
        + layers[-1] * output_size
    )


def test_deep_ffn():
    """A shallow FFN with mixed layer widths transforms into the expected graph."""
    _assert_ffn_transform((50,), [100, 50, 100], 10)


def test_deep_ffn2():
    """A very deep FFN (100 hidden layers) transforms into the expected graph."""
    _assert_ffn_transform((50,), [100] * 100, 10)


def test_deep_dan():
    """Transform a MaskedDeepDAN built from a random small-world structure."""
    # Arrange
    shape_input = (50,)
    output_size = 10
    random_graph = nx.newman_watts_strogatz_graph(100, 4, 0.5)
    print(len(random_graph.nodes))
    print(len(random_graph.edges))
    # NOTE(review): deepstruct.graph is never imported explicitly in this file;
    # it is presumably pulled in via the deepstruct.sparse import — confirm the
    # package re-exports it, otherwise add `import deepstruct.graph` at the top.
    structure = deepstruct.graph.CachedLayeredGraph()
    structure.add_edges_from(random_graph.edges)
    structure.add_nodes_from(random_graph.nodes)
    model = deepstruct.sparse.MaskedDeepDAN(shape_input, output_size, structure)
    for layer in deepstruct.sparse.maskable_layers(model):
        layer.weight[:, :].data += 1  # make sure everything is fully connected

    functor = GraphTransform(torch.randn((1,) + shape_input))

    # Act
    result = functor.transform(model)

    print(len(result.nodes))
    print(len(result.edges))
import numpy as np
import torch

from deepstruct.transform import LinearLayerFunctor


def test_linear_simple_compression():
    """A dense layer with more inputs than outputs becomes a complete bipartite graph."""
    # Arrange
    n_in, n_out = 10, 5
    assert n_out < n_in
    model = torch.nn.Linear(n_in, n_out)
    # Make sure each weight is large enough so none is getting "pruned"
    model.weight.data += 1
    functor = LinearLayerFunctor(threshold=0.01)

    # Act
    graph = functor.transform(model)

    # Assert
    assert len(graph.nodes) == n_in + n_out
    assert len(graph.edges) == n_in * n_out
    assert graph.num_layers == 2


def test_linear_simple_expansion():
    """Expansion case: fewer inputs than outputs, all weights well above the threshold."""
    # Arrange
    n_in, n_out = 4, 8
    assert n_out > n_in
    model = torch.nn.Linear(n_in, n_out)
    model.weight.data = torch.tensor(
        np.random.uniform(1, 2, size=(n_out, n_in)), dtype=torch.float32
    )
    functor = LinearLayerFunctor(threshold=0.01)

    # Act
    graph = functor.transform(model)

    # Assert
    assert len(graph.nodes) == n_in + n_out
    assert len(graph.edges) == n_in * n_out
    assert graph.num_layers == 2
def test_linear_sparse():
    """A randomly masked dense layer: the graph must contain exactly the surviving edges."""
    # Arrange
    n_in, n_out = 5, 10
    model = torch.nn.Linear(n_in, n_out)
    weights = np.random.uniform(1, 2, size=(n_out, n_in))
    mask = np.random.binomial(1, 0.3, size=(n_out, n_in))
    num_edges = np.sum(mask)
    model.weight.data = torch.tensor(weights * mask, dtype=torch.float32)
    functor = LinearLayerFunctor(threshold=0.01)

    # Act
    graph = functor.transform(model)

    # Assert
    assert len(graph.nodes) == n_in + n_out
    assert len(graph.edges) == num_edges
    assert graph.num_layers == 2


import networkx as nx

from matplotlib import pyplot as plt


def plot_graph(graph, title):
    """Draw *graph* using its 'name' node attribute as labels and show the figure."""
    node_labels = nx.get_node_attributes(graph, "name")
    fig, ax = plt.subplots(figsize=(10, 10))
    nx.draw(
        graph,
        labels=node_labels,
        with_labels=True,
        node_size=700,
        node_color="lightblue",
        font_size=8,
        ax=ax,
    )
    plt.title(title)
    plt.show()


def calculate_network_metrics(graph):
    """Collect basic size and connectivity statistics of *graph* into a dict."""
    n_nodes = graph.number_of_nodes()
    degree_total = sum(dict(graph.degree()).values())
    return {
        "nodes": n_nodes,
        "edges": graph.number_of_edges(),
        "avg degree": degree_total / n_nodes,
        "density": nx.density(graph),
        "Average cluster coefficient": nx.average_clustering(graph),
    }
--------------------------------------------------------------------------------