├── README.rst
├── tasks
│   ├── __init__.py
│   ├── adding_task.py
│   ├── argmax_task.py
│   └── copy_task.py
├── docs
│   ├── dnc.png
│   ├── dnc0.png
│   ├── dnc1.png
│   ├── dnc2.png
│   ├── dnc-mem-debug.png
│   └── dnc.xml
├── release.sh
├── .gitignore
├── .vscode
│   └── settings.json
├── dnc
│   ├── __init__.py
│   ├── faiss_index.py
│   ├── sam.py
│   ├── sdnc.py
│   ├── util.py
│   ├── sparse_memory.py
│   ├── dnc.py
│   └── sparse_temporal_memory.py
├── .travis.yml
├── setup.cfg
├── LICENSE
├── requirements.txt
├── test
│   ├── test_utils.py
│   ├── test_rnn.py
│   ├── test_lstm.py
│   ├── test_gru.py
│   ├── test_sam_gru.py
│   ├── test_sam_rnn.py
│   ├── test_sam_lstm.py
│   ├── test_sdnc_rnn.py
│   ├── test_sdnc_gru.py
│   └── test_sdnc_lstm.py
├── .pre-commit-config.yaml
├── setup.py
├── CHANGELOG.md
└── scripts
    └── build_faiss.sh

/README.rst:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tasks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/dnc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixaxaar/pytorch-dnc/HEAD/docs/dnc.png
--------------------------------------------------------------------------------
/docs/dnc0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixaxaar/pytorch-dnc/HEAD/docs/dnc0.png
--------------------------------------------------------------------------------
/docs/dnc1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixaxaar/pytorch-dnc/HEAD/docs/dnc1.png
--------------------------------------------------------------------------------
/docs/dnc2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixaxaar/pytorch-dnc/HEAD/docs/dnc2.png
--------------------------------------------------------------------------------
/docs/dnc-mem-debug.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixaxaar/pytorch-dnc/HEAD/docs/dnc-mem-debug.png
--------------------------------------------------------------------------------
/release.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | rm -rf dist/
4 | python3 setup.py sdist
5 | python3 setup.py bdist_wheel
6 | twine upload dist/*
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pt
2 | __pycache__/
3 | .pypirc
4 | pred.txt
5 | multi-bleu.perl
6 | *.pt
7 | *.pyc
8 | #.*
9 | .idea
10 | *.sublime-*
11 | .DS_Store
12 | data/
13 | build/
14 | venv/
15 | __pycache__/
16 | *.lang
17 | *.log
18 | .cache/
19 | dist/
20 | dnc.egg-info/
21 | tasks/checkpoints/
22 | faiss/
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |   "editor.tokenColorCustomizations": {
3 |     "comments": "",
4 |     "textMateRules": []
5 |   },
6 |   "python.formatting.provider": "black",
7 |   "python.linting.enabled": true,
8 |   "python.linting.flake8Enabled": false,
9 |   "python.linting.mypyEnabled": true
10 | }
11 |
--------------------------------------------------------------------------------
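Before the package sources, a quick orientation: the test files near the end of this dump all drive the exported classes through the same pattern. The sketch below is distilled from test/test_rnn.py; the memory sizes are arbitrary illustrative values, and the three-part return shown assumes the module is constructed with debug=True (as every test here does), so that the dictionary of memory internals is returned as well.

import torch

from dnc import DNC

# Build a DNC with a plain-RNN controller; all sizes here are illustrative.
rnn = DNC(
    input_size=100,
    hidden_size=100,
    rnn_type="rnn",
    nr_cells=16,
    cell_size=20,
    read_heads=1,
    debug=True,
)

# Inputs are batch-first, as in the tests: (batch, sequence, features).
x = torch.randn(10, 21, 100)

# Passing None for the hidden state lets the module initialize it. The hidden
# state unpacks into controller state, memory state, and last read vectors.
output, (controller_hidden, memory_hidden, read_vectors), debug_dict = rnn(x, None)

# The tests permute their targets to (sequence, batch, features) before
# computing the loss, since the output comes back time-major.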
/dnc/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from .dnc import DNC
5 | from .memory import Memory
6 | from .sam import SAM
7 | from .sdnc import SDNC
8 | from .sparse_memory import SparseMemory
9 | from .sparse_temporal_memory import SparseTemporalMemory
10 |
11 | __all__ = [
12 |     "DNC",
13 |     "Memory",
14 |     "SAM",
15 |     "SDNC",
16 |     "SparseMemory",
17 |     "SparseTemporalMemory",
18 | ]
19 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 |   - "3.10"
4 | # command to install dependencies
5 | before_install:
6 |   - sudo apt-get -qq update
7 |   - sudo apt-get install -yqq software-properties-common git
8 |   - sudo apt-get install -yqq libopenblas-dev liblapack3 python3-numpy python3-dev swig
9 |   - sudo ln -s /usr/lib/libopenblas.so /usr/lib/libopenblas.so.3
10 | install:
11 |   - pip install -r ./requirements.txt
12 | # command to run tests
13 | script:
14 |   - pytest ./test
15 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | universal=0
3 |
4 | [flake8]
5 | exclude = */__init__.py,migrations/*
6 | ignore = D401, E741, DAR003, D100, B010, E203, W503, F405, F403, E126, E128, E501, F841, W505, D202, DAR201, DAR401, F401, D107, DAR101, D400, D102, D205, D101
7 |
8 | max-line-length = 99
9 | max-doc-length = 99
10 | show-source = true
11 |
12 | [yapf]
13 | based_on_style = pep8
14 | allow_multiline_lambdas = True
15 | indent_dictionary_value = True
16 | allow_split_before_dict_value = False
17 | blank_lines_around_top_level_definition = 2
18 | column_limit = 99
19 | continuation_indent_width = 4
20 | indent_width = 4
21 | use_tabs = False
22 | split_before_arithmetic_operator = True
23 |
24 | [isort]
25 | line_length = 99
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Russi Chatterjee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | autoflake==2.3.1 2 | certifi==2025.1.31 3 | charset-normalizer==3.4.1 4 | exceptiongroup==1.2.2 5 | filelock==3.17.0 6 | fsspec==2025.2.0 7 | idna==3.10 8 | iniconfig==2.0.0 9 | Jinja2==3.1.6 10 | jsonpatch==1.33 11 | jsonpointer==3.0.0 12 | MarkupSafe==3.0.2 13 | mpmath==1.3.0 14 | networkx==3.4.2 15 | numpy==2.2.3 16 | nvidia-cublas-cu12==12.4.5.8 17 | nvidia-cuda-cupti-cu12==12.4.127 18 | nvidia-cuda-nvrtc-cu12==12.4.127 19 | nvidia-cuda-runtime-cu12==12.4.127 20 | nvidia-cudnn-cu12==9.1.0.70 21 | nvidia-cufft-cu12==11.2.1.3 22 | nvidia-curand-cu12==10.3.5.147 23 | nvidia-cusolver-cu12==11.6.1.9 24 | nvidia-cusparse-cu12==12.3.1.170 25 | nvidia-cusparselt-cu12==0.6.2 26 | nvidia-nccl-cu12==2.21.5 27 | nvidia-nvjitlink-cu12==12.4.127 28 | nvidia-nvtx-cu12==12.4.127 29 | packaging==24.2 30 | pillow==11.1.0 31 | pluggy==1.5.0 32 | pyflakes==3.2.0 33 | pytest==8.3.4 34 | requests==2.32.4 35 | scipy==1.15.2 36 | six==1.17.0 37 | sympy==1.13.1 38 | tomli==2.2.1 39 | torch==2.7.1 40 | torchaudio==2.6.0 41 | torchvision==0.21.0 42 | tornado==6.5.1 43 | triton==3.2.0 44 | typing_extensions==4.12.2 45 | urllib3==2.6.0 46 | visdom==0.2.4 47 | websocket-client==1.8.0 -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | 9 | def generate_data( 10 | batch_size: int, length: int, size: int, device: torch.device = torch.device("cpu") 11 | ) -> tuple[torch.Tensor, torch.Tensor]: 12 | """Generates data for the copy task, directly on the specified device.""" 13 | input_data = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32) 14 | target_output = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32) 15 | sequence = np.random.binomial(1, 0.5, (batch_size, length, size - 1)) 16 | 17 | input_data[:, :length, : size - 1] = sequence 18 | input_data[:, length, -1] = 1 # the end symbol 19 | target_output[:, length + 1 :, : size - 1] = sequence 20 | 21 | return (torch.tensor(input_data, device=device), torch.tensor(target_output, device=device)) 22 | 23 | 24 | def criterion(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 25 | """Calculates the binary cross-entropy loss with logits.""" 26 | return F.binary_cross_entropy_with_logits(predictions, targets) 27 | 28 | 29 | def get_device(cuda_id: int) -> torch.device: 30 | """Gets the torch device based on CUDA availability and ID.""" 31 | if cuda_id >= 0 and torch.cuda.is_available(): 32 | return torch.device(f"cuda:{cuda_id}") 33 | return torch.device("cpu") 34 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: no-commit-to-branch 6 | args: [--branch, master] 7 | - id: check-merge-conflict 8 | - id: trailing-whitespace 9 | - id: end-of-file-fixer 10 | - id: check-added-large-files 11 | - id: check-yaml 12 | args: [--allow-multiple-documents] 13 | - id: check-json 14 | - id: pretty-format-json 15 | args: [--autofix, --no-sort-keys] 16 
| - id: check-xml 17 | - id: debug-statements 18 | - id: check-case-conflict 19 | - id: detect-private-key 20 | - id: requirements-txt-fixer 21 | 22 | - repo: https://github.com/psf/black 23 | rev: 25.1.0 24 | hooks: 25 | - id: black 26 | language: python 27 | args: [--line-length=120] 28 | 29 | - repo: https://github.com/PyCQA/flake8 30 | rev: 7.1.2 31 | hooks: 32 | - id: flake8 33 | additional_dependencies: 34 | ["flake8-bugbear", "flake8-docstrings", "darglint"] 35 | args: [--max-line-length=120] 36 | files: ^dnc/ 37 | 38 | - repo: https://github.com/pre-commit/mirrors-mypy 39 | rev: "v1.15.0" 40 | hooks: 41 | - id: mypy 42 | additional_dependencies: ["torch", "numpy", "types-PyYAML"] 43 | args: [--ignore-missing-imports] 44 | files: ^dnc/ 45 | 46 | default_language_version: 47 | python: python3 48 | 49 | fail_fast: true 50 | exclude: ^migrations/ 51 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """A setuptools based setup module. 4 | 5 | See: 6 | https://packaging.python.org/en/latest/distributing.html 7 | https://github.com/pypa/sampleproject 8 | """ 9 | 10 | # Always prefer setuptools over distutils 11 | from setuptools import setup, find_packages 12 | 13 | # To use a consistent encoding 14 | from codecs import open 15 | from os import path 16 | 17 | here = path.abspath(path.dirname(__file__)) 18 | 19 | # Get the long description from the README file 20 | with open(path.join(here, "README.rst"), encoding="utf-8") as f: 21 | long_description = f.read() 22 | 23 | setup( 24 | name="dnc", 25 | version="2.0.0b1", 26 | description="Differentiable Neural Computer, for Pytorch", 27 | long_description=long_description, 28 | # The project's main homepage. 
29 | url="https://github.com/pypa/dnc", 30 | # Author details 31 | author="Russi Chatterjee", 32 | author_email="root@ixaxaar.in", 33 | # Choose your license 34 | license="MIT", 35 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 36 | classifiers=[ 37 | "Development Status :: 3 - Alpha", 38 | "Intended Audience :: Science/Research", 39 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 40 | "License :: OSI Approved :: MIT License", 41 | "Programming Language :: Python :: 3", 42 | "Programming Language :: Python :: 3.3", 43 | "Programming Language :: Python :: 3.4", 44 | "Programming Language :: Python :: 3.5", 45 | "Programming Language :: Python :: 3.6", 46 | ], 47 | keywords="differentiable neural computer dnc memory network", 48 | packages=find_packages(exclude=["contrib", "docs", "tests", "tasks", "scripts"]), 49 | install_requires=["torch", "numpy", "flann"], 50 | extras_require={ 51 | "dev": ["check-manifest"], 52 | "test": ["coverage"], 53 | }, 54 | python_requires=">=3", 55 | ) 56 | -------------------------------------------------------------------------------- /dnc/faiss_index.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import faiss 5 | from faiss import cast_integer_to_float_ptr as cast_float 6 | from faiss import cast_integer_to_idx_t_ptr as cast_long 7 | 8 | import torch 9 | 10 | from .util import ptr, ensure_gpu 11 | 12 | 13 | class FAISSIndex(object): 14 | """FAISS Index for approximate nearest neighbor search.""" 15 | 16 | def __init__( 17 | self, 18 | cell_size: int = 20, 19 | nr_cells: int = 1024, 20 | K: int = 4, 21 | num_lists: int = 32, 22 | probes: int = 32, 23 | res: faiss.GpuResources | None = None, 24 | train: torch.Tensor | None = None, 25 | device: torch.device | None = None, 26 | ): 27 | """Initialize FAISSIndex. 28 | 29 | Args: 30 | cell_size: Size of each memory cell. 31 | nr_cells: Number of memory cells. 32 | K: Number of nearest neighbors to retrieve. 33 | num_lists: Number of lists for the index. 34 | probes: Number of probes for searching. 35 | res: FAISS GpuResources object. 36 | train: Training data. 
37 |             device: PyTorch device.
38 |
39 |         """
40 |         super(FAISSIndex, self).__init__()
41 |         self.cell_size = cell_size
42 |         self.nr_cells = nr_cells
43 |         self.probes = probes
44 |         self.K = K
45 |         self.num_lists = num_lists
46 |         self.device = device
47 |
48 |         # BEWARE: if this variable gets deallocated, FAISS crashes
49 |         self.res = res if res else faiss.StandardGpuResources()
50 |         train_tensor = train if train is not None else torch.randn(self.nr_cells * 100, self.cell_size)
51 |
52 |         # Configure FAISS resources for GPU if needed
53 |         if self.device is not None and self.device.type == "cuda":
54 |             self.res.setTempMemoryFraction(0.01)
55 |             self.res.initializeForDevice(self.device.index if self.device.index is not None else 0)
56 |             # Create GPU index; GpuIndexIVFFlat builds its coarse quantizer
57 |             # internally, unlike the CPU IndexIVFFlat below
58 |             self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_L2)
59 |         else:
60 |             # Create CPU index for both None device and explicit CPU device
61 |             # First create a quantizer
62 |             quantizer = faiss.IndexFlatL2(self.cell_size)
63 |             self.index = faiss.IndexIVFFlat(quantizer, self.cell_size, self.num_lists, faiss.METRIC_L2)
64 |
65 |         # set number of probes (IVF indexes expose this as `nprobe`)
66 |         self.index.nprobe = self.probes
67 |         self.train(train_tensor)
68 |
69 |     def train(self, train: torch.Tensor) -> None:
70 |         """Trains the index.
71 |
72 |         Args:
73 |             train: Training data.
74 |         """
75 |         train = ensure_gpu(train, self.device)
76 |
77 |         # Only synchronize if using CUDA
78 |         if self.device is not None and self.device.type == "cuda":
79 |             torch.cuda.synchronize(self.device)
80 |
81 |         self.index.train_c(train.size(0), cast_float(ptr(train)))  # train on all supplied vectors
82 |
83 |         # Only synchronize if using CUDA
84 |         if self.device is not None and self.device.type == "cuda":
85 |             torch.cuda.synchronize(self.device)
86 |
87 |     def reset(self) -> None:
88 |         """Resets the index."""
89 |         if self.device is not None and self.device.type == "cuda":
90 |             torch.cuda.synchronize(self.device)
91 |         self.index.reset()
92 |         if self.device is not None and self.device.type == "cuda":
93 |             torch.cuda.synchronize(self.device)
94 |
95 |     def add(self, other: torch.Tensor, positions: torch.Tensor | None = None, last: int | None = None) -> None:
96 |         """Adds vectors to the index.
97 |
98 |         Args:
99 |             other: Vectors to add.
100 |             positions: Positions of the vectors.
101 |             last: Index of the last vector to add.
102 |         """
103 |         other = ensure_gpu(other, self.device)
104 |
105 |         if self.device is not None and self.device.type == "cuda":
106 |             torch.cuda.synchronize(self.device)
107 |
108 |         if positions is not None:
109 |             positions = ensure_gpu(positions, self.device).long()
110 |             assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
111 |             self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1)))
112 |         else:
113 |             other = other[:last, :] if last is not None else other
114 |             self.index.add_c(other.size(0), cast_float(ptr(other)))
115 |
116 |         if self.device is not None and self.device.type == "cuda":
117 |             torch.cuda.synchronize(self.device)
118 |
119 |     def search(self, query: torch.Tensor, k: int | None = None) -> tuple[torch.Tensor, torch.Tensor]:
120 |         """Searches the index for nearest neighbors.
121 |
122 |         Args:
123 |             query: Query vectors.
124 |             k: Number of nearest neighbors to retrieve.
125 |
126 |         Returns:
127 |             Tuple: Distances and labels of the nearest neighbors.
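            (The labels come back shifted down by one, undoing the +1 offset
            that add() applies to ids, so callers receive the original positions.)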
128 | """ 129 | 130 | query = ensure_gpu(query, self.device) 131 | 132 | k = k if k else self.K 133 | (b, _) = query.size() 134 | 135 | distances = torch.empty(b, k, device=self.device, dtype=torch.float32) 136 | labels = torch.empty(b, k, device=self.device, dtype=torch.int64) 137 | 138 | if self.device is not None and self.device.type == "cuda": 139 | torch.cuda.synchronize(self.device) 140 | 141 | self.index.search_c(b, cast_float(ptr(query)), k, cast_float(ptr(distances)), cast_long(ptr(labels))) 142 | 143 | if self.device is not None and self.device.type == "cuda": 144 | torch.cuda.synchronize(self.device) 145 | 146 | return (distances, (labels - 1)) 147 | -------------------------------------------------------------------------------- /test/test_rnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import torch as T 6 | import torch.optim as optim 7 | 8 | import sys 9 | 10 | sys.path.insert(0, ".") 11 | 12 | import functools 13 | 14 | from dnc import DNC 15 | from test_utils import generate_data, criterion 16 | 17 | 18 | def test_rnn_1(): 19 | T.manual_seed(1111) 20 | 21 | input_size = 100 22 | hidden_size = 100 23 | rnn_type = "rnn" 24 | num_layers = 1 25 | num_hidden_layers = 1 26 | dropout = 0 27 | nr_cells = 1 28 | cell_size = 1 29 | read_heads = 1 30 | device = None 31 | debug = True 32 | lr = 0.001 33 | sequence_max_length = 10 34 | batch_size = 10 35 | cuda = device 36 | clip = 10 37 | length = 10 38 | 39 | rnn = DNC( 40 | input_size=input_size, 41 | hidden_size=hidden_size, 42 | rnn_type=rnn_type, 43 | num_layers=num_layers, 44 | num_hidden_layers=num_hidden_layers, 45 | dropout=dropout, 46 | nr_cells=nr_cells, 47 | cell_size=cell_size, 48 | read_heads=read_heads, 49 | device=device, 50 | debug=debug, 51 | ) 52 | 53 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 54 | optimizer.zero_grad() 55 | 56 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 57 | 58 | output, (chx, mhx, rv), v = rnn(input_data, None) 59 | 60 | # Make output and target compatible for loss calculation 61 | # target: [batch, seq, features] -> [seq, batch, features] 62 | target_output = target_output.permute(1, 0, 2).contiguous() 63 | 64 | loss = criterion(output, target_output) 65 | loss.backward() 66 | 67 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 68 | optimizer.step() 69 | 70 | assert target_output.size() == T.Size([21, 10, 100]) 71 | assert chx[0][0].size() == T.Size([10, 100]) 72 | assert mhx[0]["memory"].size() == T.Size([10, 1, 1]) 73 | assert rv.size() == T.Size([10, 1]) 74 | 75 | 76 | def test_rnn_n(): 77 | T.manual_seed(1111) 78 | 79 | input_size = 100 80 | hidden_size = 100 81 | rnn_type = "rnn" 82 | num_layers = 3 83 | num_hidden_layers = 5 84 | dropout = 0.2 85 | nr_cells = 12 86 | cell_size = 17 87 | read_heads = 3 88 | device = None 89 | debug = True 90 | lr = 0.001 91 | sequence_max_length = 10 92 | batch_size = 10 93 | cuda = device 94 | clip = 20 95 | length = 13 96 | 97 | rnn = DNC( 98 | input_size=input_size, 99 | hidden_size=hidden_size, 100 | rnn_type=rnn_type, 101 | num_layers=num_layers, 102 | num_hidden_layers=num_hidden_layers, 103 | dropout=dropout, 104 | nr_cells=nr_cells, 105 | cell_size=cell_size, 106 | read_heads=read_heads, 107 | device=device, 108 | debug=debug, 109 | ) 110 | 111 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 112 | optimizer.zero_grad() 113 | 114 | input_data, target_output = generate_data(batch_size, length, 
input_size, cuda) 115 | 116 | output, (chx, mhx, rv), v = rnn(input_data, None) 117 | 118 | # Make output and target compatible for loss calculation 119 | # target: [batch, seq, features] -> [seq, batch, features] 120 | target_output = target_output.permute(1, 0, 2).contiguous() 121 | 122 | loss = criterion(output, target_output) 123 | loss.backward() 124 | 125 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 126 | optimizer.step() 127 | 128 | assert target_output.size() == T.Size([27, 10, 100]) 129 | assert chx[1].size() == T.Size([num_hidden_layers, 10, 100]) 130 | assert mhx[0]["memory"].size() == T.Size([10, 12, 17]) 131 | assert rv.size() == T.Size([10, 51]) 132 | 133 | 134 | def test_rnn_no_memory_pass(): 135 | T.manual_seed(1111) 136 | 137 | input_size = 100 138 | hidden_size = 100 139 | rnn_type = "rnn" 140 | num_layers = 3 141 | num_hidden_layers = 5 142 | dropout = 0.2 143 | nr_cells = 12 144 | cell_size = 17 145 | read_heads = 3 146 | device = None 147 | debug = True 148 | lr = 0.001 149 | sequence_max_length = 10 150 | batch_size = 10 151 | cuda = device 152 | clip = 20 153 | length = 13 154 | 155 | rnn = DNC( 156 | input_size=input_size, 157 | hidden_size=hidden_size, 158 | rnn_type=rnn_type, 159 | num_layers=num_layers, 160 | num_hidden_layers=num_hidden_layers, 161 | dropout=dropout, 162 | nr_cells=nr_cells, 163 | cell_size=cell_size, 164 | read_heads=read_heads, 165 | device=device, 166 | debug=debug, 167 | ) 168 | 169 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 170 | optimizer.zero_grad() 171 | 172 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 173 | 174 | # Transform target to match expected output shape 175 | target_output = target_output.permute(1, 0, 2).contiguous() 176 | 177 | # Initialize hidden state explicitly 178 | controller_hidden = None 179 | memory_hidden = None 180 | last_read = None 181 | outputs = [] 182 | 183 | for x in range(6): 184 | output, (controller_hidden, memory_hidden, last_read), v = rnn( 185 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False 186 | ) 187 | outputs.append(output) 188 | 189 | # Sum outputs for all iterations 190 | output = functools.reduce(lambda x, y: x + y, outputs) 191 | loss = criterion(output, target_output) 192 | loss.backward() 193 | 194 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 195 | optimizer.step() 196 | 197 | assert target_output.size() == T.Size([27, 10, 100]) 198 | assert controller_hidden[1].size() == T.Size([num_hidden_layers, 10, 100]) 199 | assert memory_hidden[0]["memory"].size() == T.Size([10, 12, 17]) 200 | assert last_read is not None 201 | -------------------------------------------------------------------------------- /test/test_lstm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import torch as T 6 | import torch.optim as optim 7 | 8 | import sys 9 | import functools 10 | 11 | sys.path.insert(0, ".") 12 | 13 | from dnc import DNC 14 | from test_utils import generate_data, criterion 15 | 16 | 17 | def test_rnn_1(): 18 | T.manual_seed(1111) 19 | 20 | input_size = 100 21 | hidden_size = 100 22 | rnn_type = "lstm" 23 | num_layers = 1 24 | num_hidden_layers = 1 25 | dropout = 0 26 | nr_cells = 1 27 | cell_size = 1 28 | read_heads = 1 29 | device = None 30 | debug = True 31 | lr = 0.001 32 | sequence_max_length = 10 33 | batch_size = 10 34 | cuda = device 35 | clip = 10 36 | length = 10 37 | 38 | rnn = DNC( 39 | 
input_size=input_size, 40 | hidden_size=hidden_size, 41 | rnn_type=rnn_type, 42 | num_layers=num_layers, 43 | num_hidden_layers=num_hidden_layers, 44 | dropout=dropout, 45 | nr_cells=nr_cells, 46 | cell_size=cell_size, 47 | read_heads=read_heads, 48 | device=device, 49 | debug=debug, 50 | ) 51 | 52 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 53 | optimizer.zero_grad() 54 | 55 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 56 | 57 | output, (chx, mhx, rv), v = rnn(input_data, None) 58 | 59 | # Make output and target compatible for loss calculation 60 | # target: [batch, seq, features] -> [seq, batch, features] 61 | target_output = target_output.permute(1, 0, 2).contiguous() 62 | 63 | loss = criterion(output, target_output) 64 | loss.backward() 65 | 66 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 67 | optimizer.step() 68 | 69 | assert target_output.size() == T.Size([21, 10, 100]) 70 | assert chx[0][0][0].size() == T.Size([10, 100]) 71 | assert mhx[0]["memory"].size() == T.Size([10, 1, 1]) 72 | assert rv.size() == T.Size([10, 1]) 73 | 74 | 75 | def test_rnn_n(): 76 | T.manual_seed(1111) 77 | 78 | input_size = 100 79 | hidden_size = 100 80 | rnn_type = "lstm" 81 | num_layers = 3 82 | num_hidden_layers = 5 83 | dropout = 0.2 84 | nr_cells = 12 85 | cell_size = 17 86 | read_heads = 3 87 | device = None 88 | debug = True 89 | lr = 0.001 90 | sequence_max_length = 10 91 | batch_size = 10 92 | cuda = device 93 | clip = 20 94 | length = 13 95 | 96 | rnn = DNC( 97 | input_size=input_size, 98 | hidden_size=hidden_size, 99 | rnn_type=rnn_type, 100 | num_layers=num_layers, 101 | num_hidden_layers=num_hidden_layers, 102 | dropout=dropout, 103 | nr_cells=nr_cells, 104 | cell_size=cell_size, 105 | read_heads=read_heads, 106 | device=device, 107 | debug=debug, 108 | ) 109 | 110 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 111 | optimizer.zero_grad() 112 | 113 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 114 | 115 | output, (chx, mhx, rv), v = rnn(input_data, None) 116 | 117 | # Make output and target compatible for loss calculation 118 | # target: [batch, seq, features] -> [seq, batch, features] 119 | target_output = target_output.permute(1, 0, 2).contiguous() 120 | 121 | loss = criterion(output, target_output) 122 | loss.backward() 123 | 124 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 125 | optimizer.step() 126 | 127 | assert target_output.size() == T.Size([27, 10, 100]) 128 | assert chx[0][0].size() == T.Size([num_hidden_layers, 10, 100]) 129 | assert mhx[0]["memory"].size() == T.Size([10, 12, 17]) 130 | assert rv.size() == T.Size([10, 51]) 131 | 132 | 133 | def test_rnn_no_memory_pass(): 134 | T.manual_seed(1111) 135 | 136 | input_size = 100 137 | hidden_size = 100 138 | rnn_type = "lstm" 139 | num_layers = 3 140 | num_hidden_layers = 5 141 | dropout = 0.2 142 | nr_cells = 12 143 | cell_size = 17 144 | read_heads = 3 145 | device = None 146 | debug = True 147 | lr = 0.001 148 | sequence_max_length = 10 149 | batch_size = 10 150 | cuda = device 151 | clip = 20 152 | length = 13 153 | 154 | rnn = DNC( 155 | input_size=input_size, 156 | hidden_size=hidden_size, 157 | rnn_type=rnn_type, 158 | num_layers=num_layers, 159 | num_hidden_layers=num_hidden_layers, 160 | dropout=dropout, 161 | nr_cells=nr_cells, 162 | cell_size=cell_size, 163 | read_heads=read_heads, 164 | device=device, 165 | debug=debug, 166 | ) 167 | 168 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 169 | optimizer.zero_grad() 170 | 171 | 
input_data, target_output = generate_data(batch_size, length, input_size, cuda) 172 | 173 | # Transform target to match expected output shape 174 | target_output = target_output.permute(1, 0, 2).contiguous() 175 | 176 | # Initialize hidden state explicitly 177 | controller_hidden = None 178 | memory_hidden = None 179 | last_read = None 180 | outputs = [] 181 | 182 | for x in range(6): 183 | output, (controller_hidden, memory_hidden, last_read), v = rnn( 184 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False 185 | ) 186 | outputs.append(output) 187 | 188 | # Sum outputs for all iterations 189 | output = functools.reduce(lambda x, y: x + y, outputs) 190 | loss = criterion(output, target_output) 191 | loss.backward() 192 | 193 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 194 | optimizer.step() 195 | 196 | assert target_output.size() == T.Size([27, 10, 100]) 197 | assert controller_hidden[0][0].size() == T.Size([num_hidden_layers, 10, 100]) 198 | assert memory_hidden[0]["memory"].size() == T.Size([10, 12, 17]) 199 | assert last_read is not None 200 | -------------------------------------------------------------------------------- /test/test_gru.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import torch as T 6 | import torch.optim as optim 7 | 8 | import sys 9 | 10 | sys.path.insert(0, ".") 11 | 12 | import functools 13 | 14 | from dnc import DNC 15 | from test_utils import generate_data, criterion 16 | 17 | 18 | def test_rnn_1(): 19 | T.manual_seed(1111) 20 | 21 | input_size = 100 22 | hidden_size = 100 23 | rnn_type = "gru" 24 | num_layers = 1 25 | num_hidden_layers = 1 26 | dropout = 0 27 | nr_cells = 1 28 | cell_size = 1 29 | read_heads = 1 30 | device = None 31 | debug = True 32 | lr = 0.001 33 | sequence_max_length = 10 34 | batch_size = 10 35 | cuda = device 36 | clip = 10 37 | length = 10 38 | 39 | rnn = DNC( 40 | input_size=input_size, 41 | hidden_size=hidden_size, 42 | rnn_type=rnn_type, 43 | num_layers=num_layers, 44 | num_hidden_layers=num_hidden_layers, 45 | dropout=dropout, 46 | nr_cells=nr_cells, 47 | cell_size=cell_size, 48 | read_heads=read_heads, 49 | device=device, 50 | debug=debug, 51 | ) 52 | 53 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 54 | optimizer.zero_grad() 55 | 56 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 57 | 58 | output, (chx, mhx, rv), v = rnn(input_data, None) 59 | 60 | # Make output and target compatible for loss calculation 61 | # target: [batch, seq, features] -> [seq, batch, features] 62 | target_output = target_output.permute(1, 0, 2).contiguous() 63 | 64 | loss = criterion(output, target_output) 65 | loss.backward() 66 | 67 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 68 | optimizer.step() 69 | 70 | assert target_output.size() == T.Size([21, 10, 100]) 71 | assert chx[0][0].size() == T.Size([10, 100]) 72 | assert mhx[0]["memory"].size() == T.Size([10, 1, 1]) 73 | assert rv.size() == T.Size([10, 1]) 74 | 75 | 76 | def test_rnn_n(): 77 | T.manual_seed(1111) 78 | 79 | input_size = 100 80 | hidden_size = 100 81 | rnn_type = "gru" 82 | num_layers = 3 83 | num_hidden_layers = 5 84 | dropout = 0.2 85 | nr_cells = 12 86 | cell_size = 17 87 | read_heads = 3 88 | device = None 89 | debug = True 90 | lr = 0.001 91 | sequence_max_length = 10 92 | batch_size = 10 93 | cuda = device 94 | clip = 20 95 | length = 13 96 | 97 | rnn = DNC( 98 | input_size=input_size, 99 | 
hidden_size=hidden_size, 100 | rnn_type=rnn_type, 101 | num_layers=num_layers, 102 | num_hidden_layers=num_hidden_layers, 103 | dropout=dropout, 104 | nr_cells=nr_cells, 105 | cell_size=cell_size, 106 | read_heads=read_heads, 107 | device=device, 108 | debug=debug, 109 | ) 110 | 111 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 112 | optimizer.zero_grad() 113 | 114 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 115 | 116 | output, (chx, mhx, rv), v = rnn(input_data, None) 117 | 118 | # Make output and target compatible for loss calculation 119 | # target: [batch, seq, features] -> [seq, batch, features] 120 | target_output = target_output.permute(1, 0, 2).contiguous() 121 | 122 | loss = criterion(output, target_output) 123 | loss.backward() 124 | 125 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 126 | optimizer.step() 127 | 128 | assert target_output.size() == T.Size([27, 10, 100]) 129 | assert chx[1].size() == T.Size([num_hidden_layers, 10, 100]) 130 | assert mhx[0]["memory"].size() == T.Size([10, 12, 17]) 131 | assert rv.size() == T.Size([10, 51]) 132 | 133 | 134 | def test_rnn_no_memory_pass(): 135 | T.manual_seed(1111) 136 | 137 | input_size = 100 138 | hidden_size = 100 139 | rnn_type = "gru" 140 | num_layers = 3 141 | num_hidden_layers = 5 142 | dropout = 0.2 143 | nr_cells = 12 144 | cell_size = 17 145 | read_heads = 3 146 | device = None 147 | debug = True 148 | lr = 0.001 149 | sequence_max_length = 10 150 | batch_size = 10 151 | cuda = device 152 | clip = 20 153 | length = 13 154 | 155 | rnn = DNC( 156 | input_size=input_size, 157 | hidden_size=hidden_size, 158 | rnn_type=rnn_type, 159 | num_layers=num_layers, 160 | num_hidden_layers=num_hidden_layers, 161 | dropout=dropout, 162 | nr_cells=nr_cells, 163 | cell_size=cell_size, 164 | read_heads=read_heads, 165 | device=device, 166 | debug=debug, 167 | ) 168 | 169 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 170 | optimizer.zero_grad() 171 | 172 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 173 | 174 | # Transform target to match expected output shape 175 | target_output = target_output.permute(1, 0, 2).contiguous() 176 | 177 | # Initialize hidden state explicitly 178 | controller_hidden = None 179 | memory_hidden = None 180 | last_read = None 181 | outputs = [] 182 | 183 | for x in range(6): 184 | output, (controller_hidden, memory_hidden, last_read), v = rnn( 185 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False 186 | ) 187 | outputs.append(output) 188 | 189 | # Sum outputs for all iterations 190 | output = functools.reduce(lambda x, y: x + y, outputs) 191 | loss = criterion(output, target_output) 192 | loss.backward() 193 | 194 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 195 | optimizer.step() 196 | 197 | assert target_output.size() == T.Size([27, 10, 100]) 198 | assert controller_hidden[0].size() == T.Size([num_hidden_layers, 10, 100]) 199 | assert memory_hidden[0]["memory"].size() == T.Size([10, 12, 17]) 200 | # Last read might not be None due to the memory access with pass_through_memory=False 201 | assert last_read is not None 202 | -------------------------------------------------------------------------------- /dnc/sam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import torch 4 | 5 | from .dnc import DNC 6 | from .sparse_memory import SparseMemory 7 | 8 | 9 | class SAM(DNC): 10 | """Sparse Access 
Memory (SAM) module.""" 11 | 12 | def __init__( 13 | self, 14 | input_size: int, 15 | hidden_size: int, 16 | rnn_type: str = "lstm", 17 | num_layers: int = 1, 18 | num_hidden_layers: int = 2, 19 | bias: bool = True, 20 | batch_first: bool = True, 21 | dropout: float = 0.0, 22 | bidirectional: bool = False, 23 | nr_cells: int = 5000, 24 | sparse_reads: int = 4, 25 | read_heads: int = 4, 26 | cell_size: int = 10, 27 | nonlinearity: str = "tanh", 28 | independent_linears: bool = False, 29 | share_memory: bool = True, 30 | debug: bool = False, 31 | clip: float = 20, 32 | device: torch.device | None = None, 33 | ): 34 | """Initialize SAM. 35 | 36 | Args: 37 | input_size: Input size. 38 | hidden_size: Hidden size. 39 | rnn_type: Type of RNN cell (lstm, gru, rnn). 40 | num_layers: Number of RNN layers. 41 | num_hidden_layers: Number of hidden layers in each RNN. 42 | bias: Whether to use bias in the RNN. 43 | batch_first: Whether the input is batch-first. 44 | dropout: Dropout rate. 45 | bidirectional: Whether the RNN is bidirectional. 46 | nr_cells: Number of memory cells. 47 | sparse_reads: Number of sparse reads. 48 | read_heads: Number of read heads. 49 | cell_size: Size of each memory cell. 50 | nonlinearity: Nonlinearity for RNN ('tanh' or 'relu'). 51 | device: GPU ID (deprecated, use device). 52 | independent_linears: Whether to use independent linear layers in memory. 53 | share_memory: Whether to share memory across layers. 54 | debug: Whether to enable debug mode. 55 | clip: Value to clip controller output. 56 | device: PyTorch device. 57 | """ 58 | super(SAM, self).__init__( 59 | input_size=input_size, 60 | hidden_size=hidden_size, 61 | rnn_type=rnn_type, 62 | num_layers=num_layers, 63 | num_hidden_layers=num_hidden_layers, 64 | bias=bias, 65 | batch_first=batch_first, 66 | dropout=dropout, 67 | nr_cells=nr_cells, 68 | read_heads=read_heads, 69 | cell_size=cell_size, 70 | nonlinearity=nonlinearity, 71 | independent_linears=independent_linears, 72 | share_memory_between_layers=share_memory, 73 | debug=debug, 74 | clip=clip, 75 | device=device, 76 | ) 77 | self.sparse_reads = sparse_reads 78 | self.device = device 79 | # override SDNC memories with SAM 80 | self.memories = [] 81 | 82 | for layer in range(self.num_layers): 83 | # memories for each layer 84 | if not self.share_memory_between_layers: 85 | self.memories.append( 86 | SparseMemory( 87 | input_size=self.output_size, 88 | mem_size=self.nr_cells, 89 | cell_size=self.w, 90 | sparse_reads=self.sparse_reads, 91 | read_heads=self.read_heads, 92 | device=self.device, 93 | independent_linears=self.independent_linears, 94 | ) 95 | ) 96 | setattr(self, "rnn_layer_memory_" + str(layer), self.memories[layer]) 97 | 98 | # only one memory shared by all layers 99 | if self.share_memory_between_layers: 100 | self.memories.append( 101 | SparseMemory( 102 | input_size=self.output_size, 103 | mem_size=self.nr_cells, 104 | cell_size=self.w, 105 | sparse_reads=self.sparse_reads, 106 | read_heads=self.read_heads, 107 | device=self.device, 108 | independent_linears=self.independent_linears, 109 | ) 110 | ) 111 | setattr(self, "rnn_layer_memory_shared", self.memories[0]) 112 | 113 | def _debug(self, mhx: dict, debug_obj: dict | None) -> dict | None: 114 | """Debug function to collect memory information. 115 | Args: 116 | mhx: Memory hidden state. 117 | debug_obj: Debug object to store information. 118 | 119 | Returns: 120 | Updated debug object or None. 
121 | """ 122 | if not debug_obj: 123 | debug_obj = { 124 | "memory": [], 125 | "visible_memory": [], 126 | "read_weights": [], 127 | "write_weights": [], 128 | "read_vectors": [], 129 | "least_used_mem": [], 130 | "usage": [], 131 | "read_positions": [], 132 | } 133 | 134 | debug_obj["memory"].append(mhx["memory"][0].detach().cpu().numpy()) 135 | debug_obj["visible_memory"].append(mhx["visible_memory"][0].detach().cpu().numpy()) 136 | debug_obj["read_weights"].append(mhx["read_weights"][0].unsqueeze(0).detach().cpu().numpy()) 137 | debug_obj["write_weights"].append(mhx["write_weights"][0].unsqueeze(0).detach().cpu().numpy()) 138 | debug_obj["read_vectors"].append(mhx["read_vectors"][0].detach().cpu().numpy()) 139 | debug_obj["least_used_mem"].append(mhx["least_used_mem"][0].unsqueeze(0).detach().cpu().numpy()) 140 | debug_obj["usage"].append(mhx["usage"][0].unsqueeze(0).detach().cpu().numpy()) 141 | debug_obj["read_positions"].append(mhx["read_positions"][0].unsqueeze(0).detach().cpu().numpy()) 142 | 143 | return debug_obj 144 | -------------------------------------------------------------------------------- /test/test_sam_gru.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # #!/usr/bin/env python3 3 | # # -*- coding: utf-8 -*- 4 | 5 | 6 | import torch as T 7 | import torch.optim as optim 8 | 9 | import sys 10 | import functools 11 | 12 | sys.path.insert(0, ".") 13 | 14 | from dnc import SAM 15 | from test_utils import generate_data, criterion 16 | 17 | 18 | def test_rnn_1(): 19 | T.manual_seed(1111) 20 | 21 | input_size = 100 22 | hidden_size = 100 23 | rnn_type = "gru" 24 | num_layers = 1 25 | num_hidden_layers = 1 26 | dropout = 0 27 | nr_cells = 100 28 | cell_size = 10 29 | read_heads = 1 30 | sparse_reads = 2 31 | device = None 32 | debug = True 33 | lr = 0.001 34 | sequence_max_length = 10 35 | batch_size = 10 36 | cuda = device 37 | clip = 10 38 | length = 10 39 | 40 | rnn = SAM( 41 | input_size=input_size, 42 | hidden_size=hidden_size, 43 | rnn_type=rnn_type, 44 | num_layers=num_layers, 45 | num_hidden_layers=num_hidden_layers, 46 | dropout=dropout, 47 | nr_cells=nr_cells, 48 | cell_size=cell_size, 49 | read_heads=read_heads, 50 | sparse_reads=sparse_reads, 51 | device=device, 52 | debug=debug, 53 | ) 54 | 55 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 56 | optimizer.zero_grad() 57 | 58 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 59 | 60 | output, (chx, mhx, rv), v = rnn(input_data, None) 61 | 62 | # Make output and target compatible for loss calculation 63 | # target: [batch, seq, features] -> [seq, batch, features] 64 | target_output = target_output.permute(1, 0, 2).contiguous() 65 | 66 | loss = criterion(output, target_output) 67 | loss.backward() 68 | 69 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 70 | optimizer.step() 71 | 72 | assert target_output.size() == T.Size([21, 10, 100]) 73 | assert chx[0][0].size() == T.Size([10, 100]) 74 | # assert mhx['memory'].size() == T.Size([10,1,1]) 75 | assert rv.size() == T.Size([10, 10]) 76 | 77 | 78 | def test_rnn_n(): 79 | T.manual_seed(1111) 80 | 81 | input_size = 100 82 | hidden_size = 100 83 | rnn_type = "gru" 84 | num_layers = 3 85 | num_hidden_layers = 5 86 | dropout = 0.2 87 | nr_cells = 200 88 | cell_size = 17 89 | read_heads = 2 90 | sparse_reads = 4 91 | device = None 92 | debug = True 93 | lr = 0.001 94 | sequence_max_length = 10 95 | batch_size = 10 96 | cuda = device 97 | clip = 20 98 | length = 13 99 | 
100 | rnn = SAM( 101 | input_size=input_size, 102 | hidden_size=hidden_size, 103 | rnn_type=rnn_type, 104 | num_layers=num_layers, 105 | num_hidden_layers=num_hidden_layers, 106 | dropout=dropout, 107 | nr_cells=nr_cells, 108 | cell_size=cell_size, 109 | read_heads=read_heads, 110 | sparse_reads=sparse_reads, 111 | device=device, 112 | debug=debug, 113 | ) 114 | 115 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 116 | optimizer.zero_grad() 117 | 118 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 119 | 120 | output, (chx, mhx, rv), v = rnn(input_data, None) 121 | 122 | # Make output and target compatible for loss calculation 123 | # target: [batch, seq, features] -> [seq, batch, features] 124 | target_output = target_output.permute(1, 0, 2).contiguous() 125 | 126 | loss = criterion(output, target_output) 127 | loss.backward() 128 | 129 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 130 | optimizer.step() 131 | 132 | assert target_output.size() == T.Size([27, 10, 100]) 133 | assert chx[0].size() == T.Size([num_hidden_layers, 10, 100]) 134 | # assert mhx['memory'].size() == T.Size([10,12,17]) 135 | assert rv.size() == T.Size([10, 34]) 136 | 137 | 138 | def test_rnn_no_memory_pass(): 139 | T.manual_seed(1111) 140 | 141 | input_size = 100 142 | hidden_size = 100 143 | rnn_type = "gru" 144 | num_layers = 3 145 | num_hidden_layers = 5 146 | dropout = 0.2 147 | nr_cells = 5000 148 | cell_size = 17 149 | sparse_reads = 3 150 | device = None 151 | debug = True 152 | lr = 0.001 153 | sequence_max_length = 10 154 | batch_size = 10 155 | cuda = device 156 | clip = 20 157 | length = 13 158 | 159 | rnn = SAM( 160 | input_size=input_size, 161 | hidden_size=hidden_size, 162 | rnn_type=rnn_type, 163 | num_layers=num_layers, 164 | num_hidden_layers=num_hidden_layers, 165 | dropout=dropout, 166 | nr_cells=nr_cells, 167 | cell_size=cell_size, 168 | sparse_reads=sparse_reads, 169 | device=device, 170 | debug=debug, 171 | ) 172 | 173 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 174 | optimizer.zero_grad() 175 | 176 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 177 | 178 | # Make output and target compatible for loss calculation 179 | # target: [batch, seq, features] -> [seq, batch, features] 180 | target_output = target_output.permute(1, 0, 2).contiguous() 181 | 182 | # Initialize hidden state explicitly 183 | controller_hidden = None 184 | memory_hidden = None 185 | last_read = None 186 | outputs = [] 187 | 188 | for x in range(6): 189 | output, (controller_hidden, memory_hidden, last_read), v = rnn( 190 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False 191 | ) 192 | outputs.append(output) 193 | 194 | # Sum outputs for all iterations 195 | output = functools.reduce(lambda x, y: x + y, outputs) 196 | loss = criterion(output, target_output) 197 | loss.backward() 198 | 199 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 200 | optimizer.step() 201 | 202 | assert target_output.size() == T.Size([27, 10, 100]) 203 | assert controller_hidden[0].size() == T.Size([num_hidden_layers, 10, 100]) 204 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17]) 205 | assert last_read is not None 206 | -------------------------------------------------------------------------------- /test/test_sam_rnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # #!/usr/bin/env python3 3 | # # -*- coding: utf-8 -*- 4 | 5 | 6 | import torch as 
T 7 | import torch.optim as optim 8 | 9 | import sys 10 | import functools 11 | 12 | sys.path.insert(0, ".") 13 | 14 | from dnc import SAM 15 | from test_utils import generate_data, criterion 16 | 17 | 18 | def test_rnn_1(): 19 | T.manual_seed(1111) 20 | 21 | input_size = 100 22 | hidden_size = 100 23 | rnn_type = "rnn" 24 | num_layers = 1 25 | num_hidden_layers = 1 26 | dropout = 0 27 | nr_cells = 100 28 | cell_size = 10 29 | read_heads = 1 30 | sparse_reads = 2 31 | device = None 32 | debug = True 33 | lr = 0.001 34 | sequence_max_length = 10 35 | batch_size = 10 36 | cuda = device 37 | clip = 10 38 | length = 10 39 | 40 | rnn = SAM( 41 | input_size=input_size, 42 | hidden_size=hidden_size, 43 | rnn_type=rnn_type, 44 | num_layers=num_layers, 45 | num_hidden_layers=num_hidden_layers, 46 | dropout=dropout, 47 | nr_cells=nr_cells, 48 | cell_size=cell_size, 49 | read_heads=read_heads, 50 | sparse_reads=sparse_reads, 51 | device=device, 52 | debug=debug, 53 | ) 54 | 55 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 56 | optimizer.zero_grad() 57 | 58 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 59 | 60 | output, (chx, mhx, rv), v = rnn(input_data, None) 61 | 62 | # Make output and target compatible for loss calculation 63 | # target: [batch, seq, features] -> [seq, batch, features] 64 | target_output = target_output.permute(1, 0, 2).contiguous() 65 | 66 | loss = criterion(output, target_output) 67 | loss.backward() 68 | 69 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 70 | optimizer.step() 71 | 72 | assert target_output.size() == T.Size([21, 10, 100]) 73 | assert chx[0][0].size() == T.Size([10, 100]) 74 | # assert mhx['memory'].size() == T.Size([10,1,1]) 75 | assert rv.size() == T.Size([10, 10]) 76 | 77 | 78 | def test_rnn_n(): 79 | T.manual_seed(1111) 80 | 81 | input_size = 100 82 | hidden_size = 100 83 | rnn_type = "rnn" 84 | num_layers = 3 85 | num_hidden_layers = 5 86 | dropout = 0.2 87 | nr_cells = 200 88 | cell_size = 17 89 | read_heads = 2 90 | sparse_reads = 4 91 | device = None 92 | debug = True 93 | lr = 0.001 94 | sequence_max_length = 10 95 | batch_size = 10 96 | cuda = device 97 | clip = 20 98 | length = 13 99 | 100 | rnn = SAM( 101 | input_size=input_size, 102 | hidden_size=hidden_size, 103 | rnn_type=rnn_type, 104 | num_layers=num_layers, 105 | num_hidden_layers=num_hidden_layers, 106 | dropout=dropout, 107 | nr_cells=nr_cells, 108 | cell_size=cell_size, 109 | read_heads=read_heads, 110 | sparse_reads=sparse_reads, 111 | device=device, 112 | debug=debug, 113 | ) 114 | 115 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 116 | optimizer.zero_grad() 117 | 118 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 119 | 120 | output, (chx, mhx, rv), v = rnn(input_data, None) 121 | 122 | # Make output and target compatible for loss calculation 123 | # target: [batch, seq, features] -> [seq, batch, features] 124 | target_output = target_output.permute(1, 0, 2).contiguous() 125 | 126 | loss = criterion(output, target_output) 127 | loss.backward() 128 | 129 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 130 | optimizer.step() 131 | 132 | assert target_output.size() == T.Size([27, 10, 100]) 133 | assert chx[0].size() == T.Size([num_hidden_layers, 10, 100]) 134 | # assert mhx['memory'].size() == T.Size([10,12,17]) 135 | assert rv.size() == T.Size([10, 34]) 136 | 137 | 138 | def test_rnn_no_memory_pass(): 139 | T.manual_seed(1111) 140 | 141 | input_size = 100 142 | hidden_size = 100 143 | rnn_type = "rnn" 144 
| num_layers = 3 145 | num_hidden_layers = 5 146 | dropout = 0.2 147 | nr_cells = 5000 148 | cell_size = 17 149 | sparse_reads = 3 150 | device = None 151 | debug = True 152 | lr = 0.001 153 | sequence_max_length = 10 154 | batch_size = 10 155 | cuda = device 156 | clip = 20 157 | length = 13 158 | 159 | rnn = SAM( 160 | input_size=input_size, 161 | hidden_size=hidden_size, 162 | rnn_type=rnn_type, 163 | num_layers=num_layers, 164 | num_hidden_layers=num_hidden_layers, 165 | dropout=dropout, 166 | nr_cells=nr_cells, 167 | cell_size=cell_size, 168 | sparse_reads=sparse_reads, 169 | device=device, 170 | debug=debug, 171 | ) 172 | 173 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 174 | optimizer.zero_grad() 175 | 176 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 177 | 178 | # Make output and target compatible for loss calculation 179 | # target: [batch, seq, features] -> [seq, batch, features] 180 | target_output = target_output.permute(1, 0, 2).contiguous() 181 | 182 | # Initialize hidden state explicitly 183 | controller_hidden = None 184 | memory_hidden = None 185 | last_read = None 186 | outputs = [] 187 | 188 | for x in range(6): 189 | output, (controller_hidden, memory_hidden, last_read), v = rnn( 190 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False 191 | ) 192 | outputs.append(output) 193 | 194 | # Sum outputs for all iterations 195 | output = functools.reduce(lambda x, y: x + y, outputs) 196 | loss = criterion(output, target_output) 197 | loss.backward() 198 | 199 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 200 | optimizer.step() 201 | 202 | assert target_output.size() == T.Size([27, 10, 100]) 203 | assert controller_hidden[0].size() == T.Size([num_hidden_layers, 10, 100]) 204 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17]) 205 | assert last_read is not None 206 | -------------------------------------------------------------------------------- /test/test_sam_lstm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # #!/usr/bin/env python3 3 | # # -*- coding: utf-8 -*- 4 | 5 | 6 | import torch as T 7 | import torch.optim as optim 8 | 9 | import sys 10 | import functools 11 | 12 | sys.path.insert(0, ".") 13 | 14 | from dnc import SAM 15 | from test_utils import generate_data, criterion 16 | 17 | 18 | def test_rnn_1(): 19 | T.manual_seed(1111) 20 | 21 | input_size = 100 22 | hidden_size = 100 23 | rnn_type = "lstm" 24 | num_layers = 1 25 | num_hidden_layers = 1 26 | dropout = 0 27 | nr_cells = 100 28 | cell_size = 10 29 | read_heads = 1 30 | sparse_reads = 2 31 | device = None 32 | debug = True 33 | lr = 0.001 34 | sequence_max_length = 10 35 | batch_size = 10 36 | cuda = device 37 | clip = 10 38 | length = 10 39 | 40 | rnn = SAM( 41 | input_size=input_size, 42 | hidden_size=hidden_size, 43 | rnn_type=rnn_type, 44 | num_layers=num_layers, 45 | num_hidden_layers=num_hidden_layers, 46 | dropout=dropout, 47 | nr_cells=nr_cells, 48 | cell_size=cell_size, 49 | read_heads=read_heads, 50 | sparse_reads=sparse_reads, 51 | device=device, 52 | debug=debug, 53 | ) 54 | 55 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 56 | optimizer.zero_grad() 57 | 58 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 59 | 60 | output, (chx, mhx, rv), v = rnn(input_data, None) 61 | 62 | # Make output and target compatible for loss calculation 63 | # target: [batch, seq, features] -> [seq, batch, features] 64 | 
target_output = target_output.permute(1, 0, 2).contiguous() 65 | 66 | loss = criterion(output, target_output) 67 | loss.backward() 68 | 69 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 70 | optimizer.step() 71 | 72 | assert target_output.size() == T.Size([21, 10, 100]) 73 | assert chx[0][0][0].size() == T.Size([10, 100]) 74 | # assert mhx['memory'].size() == T.Size([10,1,1]) 75 | assert rv.size() == T.Size([10, 10]) 76 | 77 | 78 | def test_rnn_n(): 79 | T.manual_seed(1111) 80 | 81 | input_size = 100 82 | hidden_size = 100 83 | rnn_type = "lstm" 84 | num_layers = 3 85 | num_hidden_layers = 5 86 | dropout = 0.2 87 | nr_cells = 200 88 | cell_size = 17 89 | read_heads = 2 90 | sparse_reads = 4 91 | device = None 92 | debug = True 93 | lr = 0.001 94 | sequence_max_length = 10 95 | batch_size = 10 96 | cuda = device 97 | clip = 20 98 | length = 13 99 | 100 | rnn = SAM( 101 | input_size=input_size, 102 | hidden_size=hidden_size, 103 | rnn_type=rnn_type, 104 | num_layers=num_layers, 105 | num_hidden_layers=num_hidden_layers, 106 | dropout=dropout, 107 | nr_cells=nr_cells, 108 | cell_size=cell_size, 109 | read_heads=read_heads, 110 | sparse_reads=sparse_reads, 111 | device=device, 112 | debug=debug, 113 | ) 114 | 115 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 116 | optimizer.zero_grad() 117 | 118 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 119 | 120 | output, (chx, mhx, rv), v = rnn(input_data, None) 121 | 122 | # Make output and target compatible for loss calculation 123 | # target: [batch, seq, features] -> [seq, batch, features] 124 | target_output = target_output.permute(1, 0, 2).contiguous() 125 | 126 | loss = criterion(output, target_output) 127 | loss.backward() 128 | 129 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 130 | optimizer.step() 131 | 132 | assert target_output.size() == T.Size([27, 10, 100]) 133 | assert chx[0][0].size() == T.Size([num_hidden_layers, 10, 100]) 134 | # assert mhx['memory'].size() == T.Size([10,12,17]) 135 | assert rv.size() == T.Size([10, 34]) 136 | 137 | 138 | def test_rnn_no_memory_pass(): 139 | T.manual_seed(1111) 140 | 141 | input_size = 100 142 | hidden_size = 100 143 | rnn_type = "lstm" 144 | num_layers = 3 145 | num_hidden_layers = 5 146 | dropout = 0.2 147 | nr_cells = 5000 148 | cell_size = 17 149 | sparse_reads = 3 150 | device = None 151 | debug = True 152 | lr = 0.001 153 | sequence_max_length = 10 154 | batch_size = 10 155 | cuda = device 156 | clip = 20 157 | length = 13 158 | 159 | rnn = SAM( 160 | input_size=input_size, 161 | hidden_size=hidden_size, 162 | rnn_type=rnn_type, 163 | num_layers=num_layers, 164 | num_hidden_layers=num_hidden_layers, 165 | dropout=dropout, 166 | nr_cells=nr_cells, 167 | cell_size=cell_size, 168 | sparse_reads=sparse_reads, 169 | device=device, 170 | debug=debug, 171 | ) 172 | 173 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 174 | optimizer.zero_grad() 175 | 176 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 177 | 178 | # Make output and target compatible for loss calculation 179 | # target: [batch, seq, features] -> [seq, batch, features] 180 | target_output = target_output.permute(1, 0, 2).contiguous() 181 | 182 | # Initialize hidden state explicitly 183 | controller_hidden = None 184 | memory_hidden = None 185 | last_read = None 186 | outputs = [] 187 | 188 | for x in range(6): 189 | output, (controller_hidden, memory_hidden, last_read), v = rnn( 190 | input_data, (controller_hidden, memory_hidden, last_read), 
pass_through_memory=False 191 | ) 192 | outputs.append(output) 193 | 194 | # Sum outputs for all iterations 195 | output = functools.reduce(lambda x, y: x + y, outputs) 196 | loss = criterion(output, target_output) 197 | loss.backward() 198 | 199 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 200 | optimizer.step() 201 | 202 | assert target_output.size() == T.Size([27, 10, 100]) 203 | assert controller_hidden[0][0].size() == T.Size([num_hidden_layers, 10, 100]) 204 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17]) 205 | assert last_read is not None 206 | -------------------------------------------------------------------------------- /test/test_sdnc_rnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # #!/usr/bin/env python3 3 | # # -*- coding: utf-8 -*- 4 | 5 | 6 | import torch as T 7 | import torch.optim as optim 8 | 9 | import sys 10 | import functools 11 | 12 | sys.path.insert(0, ".") 13 | 14 | from dnc import SDNC 15 | from test_utils import generate_data, criterion 16 | 17 | 18 | def test_rnn_1(): 19 | T.manual_seed(1111) 20 | 21 | input_size = 100 22 | hidden_size = 100 23 | rnn_type = "rnn" 24 | num_layers = 1 25 | num_hidden_layers = 1 26 | dropout = 0 27 | nr_cells = 100 28 | cell_size = 10 29 | read_heads = 1 30 | sparse_reads = 2 31 | temporal_reads = 1 32 | device = None 33 | debug = True 34 | lr = 0.001 35 | sequence_max_length = 10 36 | batch_size = 10 37 | cuda = device 38 | clip = 10 39 | length = 10 40 | 41 | rnn = SDNC( 42 | input_size=input_size, 43 | hidden_size=hidden_size, 44 | rnn_type=rnn_type, 45 | num_layers=num_layers, 46 | num_hidden_layers=num_hidden_layers, 47 | dropout=dropout, 48 | nr_cells=nr_cells, 49 | cell_size=cell_size, 50 | read_heads=read_heads, 51 | sparse_reads=sparse_reads, 52 | temporal_reads=temporal_reads, 53 | device=device, 54 | debug=debug, 55 | ) 56 | 57 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 58 | optimizer.zero_grad() 59 | 60 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 61 | 62 | output, (chx, mhx, rv), v = rnn(input_data, None) 63 | 64 | # Make output and target compatible for loss calculation 65 | # target: [batch, seq, features] -> [seq, batch, features] 66 | target_output = target_output.permute(1, 0, 2).contiguous() 67 | 68 | loss = criterion(output, target_output) 69 | loss.backward() 70 | 71 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 72 | optimizer.step() 73 | 74 | assert target_output.size() == T.Size([21, 10, 100]) 75 | assert chx[0][0].size() == T.Size([10, 100]) 76 | # assert mhx[0]['memory'].size() == T.Size([10,1,1]) 77 | assert rv.size() == T.Size([10, 10]) 78 | 79 | 80 | def test_rnn_n(): 81 | T.manual_seed(1111) 82 | 83 | input_size = 100 84 | hidden_size = 100 85 | rnn_type = "rnn" 86 | num_layers = 3 87 | num_hidden_layers = 5 88 | dropout = 0.2 89 | nr_cells = 200 90 | cell_size = 17 91 | read_heads = 2 92 | sparse_reads = 4 93 | temporal_reads = 3 94 | device = None 95 | debug = True 96 | lr = 0.001 97 | sequence_max_length = 10 98 | batch_size = 10 99 | cuda = device 100 | clip = 20 101 | length = 13 102 | 103 | rnn = SDNC( 104 | input_size=input_size, 105 | hidden_size=hidden_size, 106 | rnn_type=rnn_type, 107 | num_layers=num_layers, 108 | num_hidden_layers=num_hidden_layers, 109 | dropout=dropout, 110 | nr_cells=nr_cells, 111 | cell_size=cell_size, 112 | read_heads=read_heads, 113 | sparse_reads=sparse_reads, 114 | temporal_reads=temporal_reads, 115 | device=device, 
116 | debug=debug, 117 | ) 118 | 119 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 120 | optimizer.zero_grad() 121 | 122 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 123 | 124 | output, (chx, mhx, rv), v = rnn(input_data, None) 125 | 126 | # Make output and target compatible for loss calculation 127 | # target: [batch, seq, features] -> [seq, batch, features] 128 | target_output = target_output.permute(1, 0, 2).contiguous() 129 | 130 | loss = criterion(output, target_output) 131 | loss.backward() 132 | 133 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 134 | optimizer.step() 135 | 136 | assert target_output.size() == T.Size([27, 10, 100]) 137 | assert chx[0].size() == T.Size([num_hidden_layers, 10, 100]) 138 | # assert mhx[0]['memory'].size() == T.Size([10,12,17]) 139 | assert rv.size() == T.Size([10, 34]) 140 | 141 | 142 | def test_rnn_no_memory_pass(): 143 | T.manual_seed(1111) 144 | 145 | input_size = 100 146 | hidden_size = 100 147 | rnn_type = "rnn" 148 | num_layers = 3 149 | num_hidden_layers = 5 150 | dropout = 0.2 151 | nr_cells = 5000 152 | cell_size = 17 153 | sparse_reads = 3 154 | temporal_reads = 4 155 | device = None 156 | debug = True 157 | lr = 0.001 158 | sequence_max_length = 10 159 | batch_size = 10 160 | cuda = device 161 | clip = 20 162 | length = 13 163 | 164 | rnn = SDNC( 165 | input_size=input_size, 166 | hidden_size=hidden_size, 167 | rnn_type=rnn_type, 168 | num_layers=num_layers, 169 | num_hidden_layers=num_hidden_layers, 170 | dropout=dropout, 171 | nr_cells=nr_cells, 172 | cell_size=cell_size, 173 | sparse_reads=sparse_reads, 174 | temporal_reads=temporal_reads, 175 | device=device, 176 | debug=debug, 177 | ) 178 | 179 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 180 | optimizer.zero_grad() 181 | 182 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 183 | 184 | # Transform target to match expected output shape 185 | target_output = target_output.permute(1, 0, 2).contiguous() 186 | 187 | # Initialize hidden state explicitly 188 | controller_hidden = None 189 | memory_hidden = None 190 | last_read = None 191 | outputs = [] 192 | 193 | for x in range(6): 194 | output, (controller_hidden, memory_hidden, last_read), v = rnn( 195 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False 196 | ) 197 | outputs.append(output) 198 | 199 | # Sum outputs for all iterations 200 | output = functools.reduce(lambda x, y: x + y, outputs) 201 | loss = criterion(output, target_output) 202 | loss.backward() 203 | 204 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 205 | optimizer.step() 206 | 207 | assert target_output.size() == T.Size([27, 10, 100]) 208 | assert controller_hidden[0].size() == T.Size([num_hidden_layers, 10, 100]) 209 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17]) 210 | assert last_read is not None 211 | -------------------------------------------------------------------------------- /test/test_sdnc_gru.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # #!/usr/bin/env python3 3 | # # -*- coding: utf-8 -*- 4 | 5 | 6 | import torch as T 7 | import torch.optim as optim 8 | 9 | import sys 10 | import functools 11 | 12 | sys.path.insert(0, ".") 13 | 14 | from dnc import SDNC 15 | from test_utils import generate_data, criterion 16 | 17 | 18 | def test_rnn_1(): 19 | T.manual_seed(1111) 20 | 21 | input_size = 100 22 | hidden_size = 100 23 | rnn_type = "gru" 24 | 
num_layers = 1 25 | num_hidden_layers = 1 26 | dropout = 0 27 | nr_cells = 100 28 | cell_size = 10 29 | read_heads = 1 30 | sparse_reads = 2 31 | temporal_reads = 1 32 | device = None 33 | debug = True 34 | lr = 0.001 35 | sequence_max_length = 10 36 | batch_size = 10 37 | cuda = device 38 | clip = 10 39 | length = 10 40 | 41 | rnn = SDNC( 42 | input_size=input_size, 43 | hidden_size=hidden_size, 44 | rnn_type=rnn_type, 45 | num_layers=num_layers, 46 | num_hidden_layers=num_hidden_layers, 47 | dropout=dropout, 48 | nr_cells=nr_cells, 49 | cell_size=cell_size, 50 | read_heads=read_heads, 51 | sparse_reads=sparse_reads, 52 | temporal_reads=temporal_reads, 53 | device=device, 54 | debug=debug, 55 | ) 56 | 57 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 58 | optimizer.zero_grad() 59 | 60 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 61 | 62 | output, (chx, mhx, rv), v = rnn(input_data, None) 63 | 64 | # Make output and target compatible for loss calculation 65 | # target: [batch, seq, features] -> [seq, batch, features] 66 | target_output = target_output.permute(1, 0, 2).contiguous() 67 | 68 | loss = criterion(output, target_output) 69 | loss.backward() 70 | 71 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 72 | optimizer.step() 73 | 74 | assert target_output.size() == T.Size([21, 10, 100]) 75 | assert chx[0][0].size() == T.Size([10, 100]) 76 | # assert mhx['memory'].size() == T.Size([10,1,1]) 77 | assert rv.size() == T.Size([10, 10]) 78 | 79 | 80 | def test_rnn_n(): 81 | T.manual_seed(1111) 82 | 83 | input_size = 100 84 | hidden_size = 100 85 | rnn_type = "gru" 86 | num_layers = 3 87 | num_hidden_layers = 5 88 | dropout = 0.2 89 | nr_cells = 200 90 | cell_size = 17 91 | read_heads = 2 92 | sparse_reads = 4 93 | temporal_reads = 3 94 | device = None 95 | debug = True 96 | lr = 0.001 97 | sequence_max_length = 10 98 | batch_size = 10 99 | cuda = device 100 | clip = 20 101 | length = 13 102 | 103 | rnn = SDNC( 104 | input_size=input_size, 105 | hidden_size=hidden_size, 106 | rnn_type=rnn_type, 107 | num_layers=num_layers, 108 | num_hidden_layers=num_hidden_layers, 109 | dropout=dropout, 110 | nr_cells=nr_cells, 111 | cell_size=cell_size, 112 | read_heads=read_heads, 113 | sparse_reads=sparse_reads, 114 | temporal_reads=temporal_reads, 115 | device=device, 116 | debug=debug, 117 | ) 118 | 119 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 120 | optimizer.zero_grad() 121 | 122 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 123 | 124 | output, (chx, mhx, rv), v = rnn(input_data, None) 125 | 126 | # Make output and target compatible for loss calculation 127 | # target: [batch, seq, features] -> [seq, batch, features] 128 | target_output = target_output.permute(1, 0, 2).contiguous() 129 | 130 | loss = criterion(output, target_output) 131 | loss.backward() 132 | 133 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 134 | optimizer.step() 135 | 136 | assert target_output.size() == T.Size([27, 10, 100]) 137 | assert chx[0].size() == T.Size([num_hidden_layers, 10, 100]) 138 | # assert mhx['memory'].size() == T.Size([10,12,17]) 139 | assert rv.size() == T.Size([10, 34]) 140 | 141 | 142 | def test_rnn_no_memory_pass(): 143 | T.manual_seed(1111) 144 | 145 | input_size = 100 146 | hidden_size = 100 147 | rnn_type = "gru" 148 | num_layers = 3 149 | num_hidden_layers = 5 150 | dropout = 0.2 151 | nr_cells = 5000 152 | cell_size = 17 153 | sparse_reads = 3 154 | temporal_reads = 4 155 | device = None 156 | debug = True 157 | 
lr = 0.001 158 | sequence_max_length = 10 159 | batch_size = 10 160 | cuda = device 161 | clip = 20 162 | length = 13 163 | 164 | rnn = SDNC( 165 | input_size=input_size, 166 | hidden_size=hidden_size, 167 | rnn_type=rnn_type, 168 | num_layers=num_layers, 169 | num_hidden_layers=num_hidden_layers, 170 | dropout=dropout, 171 | nr_cells=nr_cells, 172 | cell_size=cell_size, 173 | sparse_reads=sparse_reads, 174 | temporal_reads=temporal_reads, 175 | device=device, 176 | debug=debug, 177 | ) 178 | 179 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 180 | optimizer.zero_grad() 181 | 182 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 183 | 184 | # Make output and target compatible for loss calculation 185 | # target: [batch, seq, features] -> [seq, batch, features] 186 | target_output = target_output.permute(1, 0, 2).contiguous() 187 | 188 | # Initialize hidden state explicitly 189 | controller_hidden = None 190 | memory_hidden = None 191 | last_read = None 192 | outputs = [] 193 | 194 | for x in range(6): 195 | output, (controller_hidden, memory_hidden, last_read), v = rnn( 196 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False 197 | ) 198 | outputs.append(output) 199 | 200 | # Sum outputs for all iterations 201 | output = functools.reduce(lambda x, y: x + y, outputs) 202 | loss = criterion(output, target_output) 203 | loss.backward() 204 | 205 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 206 | optimizer.step() 207 | 208 | assert target_output.size() == T.Size([27, 10, 100]) 209 | assert controller_hidden[0].size() == T.Size([num_hidden_layers, 10, 100]) 210 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17]) 211 | assert last_read is not None 212 | -------------------------------------------------------------------------------- /test/test_sdnc_lstm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # #!/usr/bin/env python3 3 | # # -*- coding: utf-8 -*- 4 | 5 | 6 | import torch as T 7 | import torch.optim as optim 8 | 9 | import sys 10 | import functools 11 | 12 | sys.path.insert(0, ".") 13 | 14 | from dnc import SDNC 15 | from test_utils import generate_data, criterion 16 | 17 | 18 | def test_rnn_1(): 19 | T.manual_seed(1111) 20 | 21 | input_size = 100 22 | hidden_size = 100 23 | rnn_type = "lstm" 24 | num_layers = 1 25 | num_hidden_layers = 1 26 | dropout = 0 27 | nr_cells = 100 28 | cell_size = 10 29 | read_heads = 1 30 | sparse_reads = 2 31 | temporal_reads = 1 32 | device = None 33 | debug = True 34 | lr = 0.001 35 | sequence_max_length = 10 36 | batch_size = 10 37 | cuda = device 38 | clip = 10 39 | length = 10 40 | 41 | rnn = SDNC( 42 | input_size=input_size, 43 | hidden_size=hidden_size, 44 | rnn_type=rnn_type, 45 | num_layers=num_layers, 46 | num_hidden_layers=num_hidden_layers, 47 | dropout=dropout, 48 | nr_cells=nr_cells, 49 | cell_size=cell_size, 50 | read_heads=read_heads, 51 | sparse_reads=sparse_reads, 52 | temporal_reads=temporal_reads, 53 | device=device, 54 | debug=debug, 55 | ) 56 | 57 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 58 | optimizer.zero_grad() 59 | 60 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 61 | 62 | output, (chx, mhx, rv), v = rnn(input_data, None) 63 | 64 | # Make output and target compatible for loss calculation 65 | # target: [batch, seq, features] -> [seq, batch, features] 66 | target_output = target_output.permute(1, 0, 2).contiguous() 67 | 68 | loss = 
criterion(output, target_output) 69 | loss.backward() 70 | 71 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 72 | optimizer.step() 73 | 74 | assert target_output.size() == T.Size([21, 10, 100]) 75 | assert chx[0][0][0].size() == T.Size([10, 100]) 76 | # assert mhx['memory'].size() == T.Size([10,1,1]) 77 | assert rv.size() == T.Size([10, 10]) 78 | 79 | 80 | def test_rnn_n(): 81 | T.manual_seed(1111) 82 | 83 | input_size = 100 84 | hidden_size = 100 85 | rnn_type = "lstm" 86 | num_layers = 3 87 | num_hidden_layers = 5 88 | dropout = 0.2 89 | nr_cells = 200 90 | cell_size = 17 91 | read_heads = 2 92 | sparse_reads = 4 93 | temporal_reads = 3 94 | device = None 95 | debug = True 96 | lr = 0.001 97 | sequence_max_length = 10 98 | batch_size = 10 99 | cuda = device 100 | clip = 20 101 | length = 13 102 | 103 | rnn = SDNC( 104 | input_size=input_size, 105 | hidden_size=hidden_size, 106 | rnn_type=rnn_type, 107 | num_layers=num_layers, 108 | num_hidden_layers=num_hidden_layers, 109 | dropout=dropout, 110 | nr_cells=nr_cells, 111 | cell_size=cell_size, 112 | read_heads=read_heads, 113 | sparse_reads=sparse_reads, 114 | temporal_reads=temporal_reads, 115 | device=device, 116 | debug=debug, 117 | ) 118 | 119 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 120 | optimizer.zero_grad() 121 | 122 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 123 | 124 | output, (chx, mhx, rv), v = rnn(input_data, None) 125 | 126 | # Make output and target compatible for loss calculation 127 | # target: [batch, seq, features] -> [seq, batch, features] 128 | target_output = target_output.permute(1, 0, 2).contiguous() 129 | 130 | loss = criterion(output, target_output) 131 | loss.backward() 132 | 133 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 134 | optimizer.step() 135 | 136 | assert target_output.size() == T.Size([27, 10, 100]) 137 | assert chx[0][0].size() == T.Size([num_hidden_layers, 10, 100]) 138 | # assert mhx['memory'].size() == T.Size([10,12,17]) 139 | assert rv.size() == T.Size([10, 34]) 140 | 141 | 142 | def test_rnn_no_memory_pass(): 143 | T.manual_seed(1111) 144 | 145 | input_size = 100 146 | hidden_size = 100 147 | rnn_type = "lstm" 148 | num_layers = 3 149 | num_hidden_layers = 5 150 | dropout = 0.2 151 | nr_cells = 5000 152 | cell_size = 17 153 | sparse_reads = 3 154 | temporal_reads = 4 155 | device = None 156 | debug = True 157 | lr = 0.001 158 | sequence_max_length = 10 159 | batch_size = 10 160 | cuda = device 161 | clip = 20 162 | length = 13 163 | 164 | rnn = SDNC( 165 | input_size=input_size, 166 | hidden_size=hidden_size, 167 | rnn_type=rnn_type, 168 | num_layers=num_layers, 169 | num_hidden_layers=num_hidden_layers, 170 | dropout=dropout, 171 | nr_cells=nr_cells, 172 | cell_size=cell_size, 173 | sparse_reads=sparse_reads, 174 | temporal_reads=temporal_reads, 175 | device=device, 176 | debug=debug, 177 | ) 178 | 179 | optimizer = optim.Adam(rnn.parameters(), lr=lr) 180 | optimizer.zero_grad() 181 | 182 | input_data, target_output = generate_data(batch_size, length, input_size, cuda) 183 | 184 | # Make output and target compatible for loss calculation 185 | # target: [batch, seq, features] -> [seq, batch, features] 186 | target_output = target_output.permute(1, 0, 2).contiguous() 187 | 188 | # Initialize hidden state explicitly 189 | controller_hidden = None 190 | memory_hidden = None 191 | last_read = None 192 | outputs = [] 193 | 194 | for x in range(6): 195 | output, (controller_hidden, memory_hidden, last_read), v = rnn( 196 | input_data, 
(controller_hidden, memory_hidden, last_read), pass_through_memory=False 197 | ) 198 | outputs.append(output) 199 | 200 | # Sum outputs for all iterations 201 | output = functools.reduce(lambda x, y: x + y, outputs) 202 | loss = criterion(output, target_output) 203 | loss.backward() 204 | 205 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip) 206 | optimizer.step() 207 | 208 | assert target_output.size() == T.Size([27, 10, 100]) 209 | assert controller_hidden[0][0].size() == T.Size([num_hidden_layers, 10, 100]) 210 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17]) 211 | assert last_read is not None 212 | -------------------------------------------------------------------------------- /dnc/sdnc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | 6 | from .dnc import DNC 7 | from .sparse_temporal_memory import SparseTemporalMemory 8 | 9 | 10 | class SDNC(DNC): 11 | """Sparse Differentiable Neural Computer (SDNC) module.""" 12 | 13 | def __init__( 14 | self, 15 | input_size: int, 16 | hidden_size: int, 17 | rnn_type: str = "lstm", 18 | num_layers: int = 1, 19 | num_hidden_layers: int = 2, 20 | bias: bool = True, 21 | batch_first: bool = True, 22 | dropout: float = 0, 23 | bidirectional: bool = False, 24 | nr_cells: int = 5000, 25 | sparse_reads: int = 4, 26 | temporal_reads: int = 4, 27 | read_heads: int = 4, 28 | cell_size: int = 10, 29 | nonlinearity: str = "tanh", 30 | independent_linears: bool = False, 31 | share_memory: bool = True, 32 | debug: bool = False, 33 | clip: float = 20, 34 | device: torch.device | None = None, 35 | ): 36 | """ 37 | 38 | Args: 39 | input_size: Input size. 40 | hidden_size: Hidden size. 41 | rnn_type: Type of RNN cell (lstm, gru, rnn). 42 | num_layers: Number of RNN layers. 43 | num_hidden_layers: Number of hidden layers in each RNN. 44 | bias: Whether to use bias in the RNN. 45 | batch_first: Whether the input is batch-first. 46 | dropout: Dropout rate. 47 | bidirectional: Whether the RNN is bidirectional. 48 | nr_cells: Number of memory cells. 49 | sparse_reads: Number of sparse reads. 50 | temporal_reads: Number of temporal reads. 51 | read_heads: Number of read heads. 52 | cell_size: Size of each memory cell. 53 | nonlinearity: Nonlinearity for RNN ('tanh' or 'relu'). 54 | independent_linears: Whether to use independent linear layers in memory. 55 | share_memory: Whether to share memory across layers. 56 | debug: Whether to enable debug mode. 57 | clip: Value to clip controller output. 
58 | device: the device to use 59 | """ 60 | super(SDNC, self).__init__( 61 | input_size=input_size, 62 | hidden_size=hidden_size, 63 | rnn_type=rnn_type, 64 | num_layers=num_layers, 65 | num_hidden_layers=num_hidden_layers, 66 | bias=bias, 67 | batch_first=batch_first, 68 | dropout=dropout, 69 | nr_cells=nr_cells, 70 | read_heads=read_heads, 71 | cell_size=cell_size, 72 | nonlinearity=nonlinearity, 73 | independent_linears=independent_linears, 74 | share_memory_between_layers=share_memory, 75 | debug=debug, 76 | clip=clip, 77 | device=device, 78 | ) 79 | 80 | self.sparse_reads = sparse_reads 81 | self.temporal_reads = temporal_reads 82 | self.device = device 83 | 84 | self.memories = [] 85 | 86 | for layer in range(self.num_layers): 87 | # memories for each layer 88 | if not self.share_memory_between_layers: 89 | self.memories.append( 90 | SparseTemporalMemory( 91 | input_size=self.output_size, 92 | mem_size=self.nr_cells, 93 | cell_size=self.w, 94 | sparse_reads=self.sparse_reads, 95 | read_heads=self.read_heads, 96 | temporal_reads=self.temporal_reads, 97 | device=self.device, 98 | independent_linears=self.independent_linears, 99 | ) 100 | ) 101 | setattr(self, "rnn_layer_memory_" + str(layer), self.memories[layer]) 102 | 103 | # only one memory shared by all layers 104 | if self.share_memory_between_layers: 105 | self.memories.append( 106 | SparseTemporalMemory( 107 | input_size=self.output_size, 108 | mem_size=self.nr_cells, 109 | cell_size=self.w, 110 | sparse_reads=self.sparse_reads, 111 | read_heads=self.read_heads, 112 | temporal_reads=self.temporal_reads, 113 | device=self.device, 114 | independent_linears=self.independent_linears, 115 | ) 116 | ) 117 | setattr(self, "rnn_layer_memory_shared", self.memories[0]) 118 | 119 | def _debug(self, mhx: dict, debug_obj: dict | None) -> dict | None: 120 | """Debug function to collect memory information. 121 | 122 | Args: 123 | mhx: Memory hidden state. 124 | debug_obj: Debug object to store information. 125 | 126 | Returns: 127 | Updated debug object or None. 
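Example (illustrative), after two logged forward steps the object would look like: {"memory": [ndarray, ndarray], "read_weights": [ndarray, ndarray], ...} with one numpy array appended per key per call, matching the keys initialised below.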
128 | """ 129 | if not debug_obj: 130 | debug_obj = { 131 | "memory": [], 132 | "visible_memory": [], 133 | "link_matrix": [], 134 | "rev_link_matrix": [], 135 | "precedence": [], 136 | "read_weights": [], 137 | "write_weights": [], 138 | "read_vectors": [], 139 | "least_used_mem": [], 140 | "usage": [], 141 | "read_positions": [], 142 | } 143 | 144 | debug_obj["memory"].append(mhx["memory"][0].detach().cpu().numpy()) 145 | debug_obj["visible_memory"].append(mhx["visible_memory"][0].detach().cpu().numpy()) 146 | debug_obj["link_matrix"].append(mhx["link_matrix"][0].detach().cpu().numpy()) 147 | debug_obj["rev_link_matrix"].append(mhx["rev_link_matrix"][0].detach().cpu().numpy()) 148 | debug_obj["precedence"].append(mhx["precedence"][0].unsqueeze(0).detach().cpu().numpy()) 149 | debug_obj["read_weights"].append(mhx["read_weights"][0].unsqueeze(0).detach().cpu().numpy()) 150 | debug_obj["write_weights"].append(mhx["write_weights"][0].unsqueeze(0).detach().cpu().numpy()) 151 | debug_obj["read_vectors"].append(mhx["read_vectors"][0].detach().cpu().numpy()) 152 | debug_obj["least_used_mem"].append(mhx["least_used_mem"][0].unsqueeze(0).detach().cpu().numpy()) 153 | debug_obj["usage"].append(mhx["usage"][0].unsqueeze(0).detach().cpu().numpy()) 154 | debug_obj["read_positions"].append(mhx["read_positions"][0].unsqueeze(0).detach().cpu().numpy()) 155 | 156 | return debug_obj 157 | -------------------------------------------------------------------------------- /dnc/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from typing import Callable 8 | 9 | δ = 1e-6 10 | 11 | 12 | def recursiveTrace(obj: torch.Tensor | torch.nn.Module | None) -> None: 13 | """Recursively traces the computational graph of a tensor or module. 14 | 15 | Args: 16 | obj: The tensor or module to trace. 17 | """ 18 | if obj is None: 19 | return 20 | 21 | print(type(obj)) 22 | if hasattr(obj, "grad_fn"): 23 | print(obj.grad_fn) 24 | recursiveTrace(obj.grad_fn) # type: ignore 25 | elif hasattr(obj, "next_functions"): 26 | print(obj.requires_grad, len(obj.next_functions)) # type: ignore 27 | for f, _ in obj.next_functions: # type: ignore 28 | recursiveTrace(f) 29 | 30 | 31 | def cuda(x: torch.Tensor, requires_grad: bool = False, device: torch.device | None = None) -> torch.Tensor: 32 | """Moves a tensor to the specified device (CPU or GPU). 33 | 34 | Args: 35 | x: The tensor to move. 36 | requires_grad: Whether the tensor should require gradients. 37 | device: The device to move the tensor to. Defaults to CPU. 38 | 39 | Returns: 40 | The tensor on the specified device. 41 | """ 42 | if device is None: 43 | return x.float().requires_grad_(requires_grad) 44 | else: 45 | return x.float().to(device).requires_grad_(requires_grad) 46 | 47 | 48 | def cudavec(x: np.ndarray, requires_grad: bool = False, device: torch.device | None = None) -> torch.Tensor: 49 | """Creates a tensor from a NumPy array and moves it to the specified device. 50 | 51 | Args: 52 | x: The NumPy array. 53 | requires_grad: Whether the tensor should require gradients. 54 | device: The device. Defaults to cpu. 55 | 56 | Returns: 57 | The tensor on the specified device. 
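Example (illustrative sketch; the dtype and grad flag follow from cuda() above): >>> import numpy as np >>> t = cudavec(np.zeros((2, 3))) # no device given, stays on CPU >>> t.dtype, t.requires_grad (torch.float32, False)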
58 | """ 59 | return cuda(torch.Tensor(x), requires_grad, device) 60 | 61 | 62 | def cudalong(x: np.ndarray, requires_grad: bool = False, device: torch.device | None = None) -> torch.Tensor: 63 | """Creates a LongTensor from a NumPy array and moves it to the specified device. 64 | 65 | Args: 66 | x: The NumPy array. 67 | requires_grad: Whether the tensor should require gradients. 68 | device: The device. Defaults to CPU 69 | 70 | Returns: 71 | The LongTensor on the specified device. 72 | """ 73 | return cuda(torch.LongTensor(x.astype(np.int64)), requires_grad, device) 74 | 75 | 76 | def θ(a: torch.Tensor, b: torch.Tensor, norm_by: int = 2) -> torch.Tensor: 77 | """Calculates the batchwise cosine similarity between two tensors. 78 | 79 | Args: 80 | a: A 3D tensor (b * m * w). 81 | b: A 3D tensor (b * r * w). 82 | norm_by: The norm to use for normalization. 83 | 84 | Returns: 85 | The batchwise cosine similarity (b * r * m). 86 | """ 87 | dot = torch.bmm(a, b.transpose(1, 2)) 88 | a_norm = torch.norm(a, p=norm_by, dim=2).unsqueeze(2) 89 | b_norm = torch.norm(b, p=norm_by, dim=2).unsqueeze(1) 90 | cos = dot / (a_norm * b_norm + δ) 91 | return cos.transpose(1, 2).contiguous() 92 | 93 | 94 | def σ(input: torch.Tensor, axis: int = 1) -> torch.Tensor: # NOQA 95 | """Applies the softmax function along a specified axis. 96 | 97 | Args: 98 | input: The input tensor. 99 | axis: The axis along which to apply softmax. 100 | 101 | Returns: 102 | The softmax output. 103 | """ 104 | return F.softmax(input, dim=axis) 105 | 106 | 107 | def register_nan_checks(model: nn.Module) -> None: 108 | """Registers backward hooks to check for NaN gradients. 109 | 110 | Args: 111 | model: The model to register hooks on. 112 | """ 113 | 114 | def check_grad( 115 | module: nn.Module, grad_input: tuple[torch.Tensor | None, ...], grad_output: tuple[torch.Tensor | None, ...] 116 | ) -> None: 117 | if any(torch.isnan(gi).any() for gi in grad_input if gi is not None): 118 | print(f"NaN gradient in grad_input of {type(module).__name__}") 119 | 120 | for module in model.modules(): 121 | module.register_full_backward_hook(check_grad) # type: ignore 122 | 123 | 124 | def apply_dict(dic: dict) -> None: 125 | """Applies gradient NaN checks to a dictionary of variables. 126 | 127 | Args: 128 | dic: The dictionary. 129 | """ 130 | for k, v in dic.items(): 131 | apply_var(v, k) 132 | if isinstance(v, nn.Module): 133 | key_list = [a for a in dir(v) if not a.startswith("__")] 134 | for key in key_list: 135 | apply_var(getattr(v, key), key) 136 | for pk, pv in v.named_parameters(): 137 | apply_var(pv, pk) 138 | 139 | 140 | def apply_var(v: torch.Tensor | nn.Module | None, k: str) -> None: 141 | """Applies gradient NaN checks to a variable. 142 | 143 | Args: 144 | v: The variable. 145 | k: The name of the variable. 146 | """ 147 | if isinstance(v, torch.Tensor) and v.requires_grad: 148 | v.register_hook(check_nan_gradient(k)) 149 | 150 | 151 | def check_nan_gradient(name: str = "") -> Callable[[torch.Tensor], torch.Tensor | None]: 152 | """Creates a hook to check for NaN gradients. 153 | 154 | Args: 155 | name: The name of the variable. 156 | 157 | Returns: 158 | The hook function. 159 | """ 160 | 161 | def f(tensor: torch.Tensor) -> torch.Tensor | None: 162 | if torch.isnan(tensor).any(): 163 | print(f"\nnan gradient of {name}:") 164 | return tensor 165 | return None 166 | 167 | return f 168 | 169 | 170 | def ptr(tensor: torch.Tensor) -> int: 171 | """Returns the memory address of a tensor. 172 | 173 | Args: 174 | tensor: The tensor. 
175 | 176 | Returns: 177 | The memory address. 178 | """ 179 | return tensor.data_ptr() 180 | 181 | 182 | def ensure_gpu(tensor: torch.Tensor | np.ndarray, device: torch.device | None) -> torch.Tensor: 183 | """Ensures a tensor is on the specified GPU. 184 | 185 | Args: 186 | tensor: The tensor or NumPy array. 187 | device: The device 188 | 189 | Returns: 190 | The tensor on the specified GPU. 191 | """ 192 | if isinstance(tensor, torch.Tensor) and device is not None: 193 | return tensor.to(device) 194 | elif isinstance(tensor, np.ndarray) and device is not None: 195 | return torch.tensor(tensor, device=device) 196 | elif isinstance(tensor, np.ndarray): 197 | return torch.Tensor(tensor) 198 | else: 199 | return tensor 200 | 201 | 202 | def print_gradient(x: torch.Tensor, name: str) -> None: 203 | """Prints the gradient of a tensor. 204 | 205 | Args: 206 | x: The tensor. 207 | name: name of tensor 208 | """ 209 | s = "Gradient of " + name + " ----------------------------------" 210 | x.register_hook(lambda y: print(s, y.squeeze())) 211 | -------------------------------------------------------------------------------- /docs/dnc.xml: -------------------------------------------------------------------------------- 1 | 7T1Ze6LKtr8mj2d/MhY8MghomByiwMv5BAoEFZQZfv2tUtOddLLP6Xt2x6TTsdMKi1UFtea1qoA7Sjp0arE5bo08hPs7chR2d5R8R5IUyVHoB0P6C4QYcdwFEhdJeIGNvgMWyQAfEa/QOglh+QyxyvN9lRyvQOICDPIsg0H1DLYpirx9jhbl+/AZ4LiJ4bPeMWARbPbwBdo6CavtBcqR4Dtcg0m8fTwzwfKXI/4m2MVFXmfX892RVHT+XA4fNo99Xc9bbjdh3j4BUeM7SiryvLpsHToJ7jFxn5NN+Zuj3667gFn1Mw2Y63U0m30NHy/5fGFV/0iMcFNuIcYf3VHitjrs0SaBNtG1HzHKoYuxGPy1aUvqLxiQ/w7yw7GuIPrNqk2SweLfaFxilOz3Ur7Pi3OvlMJwDEUjOGocJuhyH49leYZ6FVEnSfB4pqrId/BJ49H5gztFp7iKD8Fe95/gEQL+h+DXUcKigt3fkor4xgAk2TA/wKroEcq1wb8Ay1zaXKWaBexlv/0uIuToStDtM/G4AjdXsYy/df6dNWjjyp3XOUXzX5z6aU6BHzjF3pRT7Ben/ledokc35RTzxan/Vaco/pac+uYvvzMGhshlX3fzotrmcZ5t9uPvUDGoi+bMOkzFs1P+xsin9H3CVNgllYNR/mKue+6TIzYsEnTlsLh2gkhX9E/Q8a779NiPDb7JEu4xhVXVX7m3qascgb4PQ8/z49+xnxckICrfjjwGKMz/WyAwBf+zOCCC53URXLGuDK82RQwfsfjXpaaA+02VNM+7/ycCQFEvVHVdJBW8qVgQT4Ri9BfJ/Fe5eCYVlwafUAb+xnL8chn48qs/b625H6w1d0trTb1kzJta62dqyTzTwL/AH2GaH7n2VC/ZW5lm8EIvkboVEOe78JBfO3tH1/0f3PQnZPvNzDHxE7WDR7ubHM4VF3FTHi/1mijpMN3F6wE53FSbO0q47JLKMYvvSClZida8Hd2rcS6gj7l42I4fYrQVG+jLGEuCi+HJaeH5eEPUTGmxmk0kIZ5EwnaXnIH70Xy1HT2Q/CHUwm1weBA8kmn89UMVkGbvrVd1QG6b0OLkZqBRgx4sHubiSksCEHpiMDPFsbAUNXedq/dVNEuCmZGXnWQ3iXZ8mOb3VmSu0CVHxTByyHA0GERkCi0avUJre0/yumEoeyYUOeGOFHVbo2aBEGv1VJhJk3FrCGMpUAVJtGRBFDIxFmI57tBldy4aKV3O0EAXnw13f9+6wmTMPSIIljwWBE1BDJMThCB0k50oTOKl0Aqz6RmX/ry4aqed6bP9UDx6D9xyFWOi0IgoAm0glGA9EQRLRwSUx8lsLCSi8a27ZfxIwC/cL9wv3D8G1zxMhJn80WzX7XGniwr7F/t38HFviiu76dl3th+PR1+4X7jvjauLHrar8Hew7V+4N8CVWm2J7er+N7Dtb4tbKjusR8xvo8tvgEvp7R2p2FRDSTM5nmaNpbOirEJRWE6pApdQhPFeWe4W9ewgSb+oQE3xzwvUgH9RnyYo5mV9mhyRv6BiSb+oXM3hJnzHojV4VrUmnpUpnxz8bFVr4h2r1txX1foDsZ0gb1W1Jt6tah1x6EtnLlVrdKHiTBqh719WuRbH9PfKdWr3pEPNgqgDIKOG03K5VMPkgK7yJLvHSinmymK9e3g4oSsQ1bEG2aCuI7pkTYQSU3UNDL8ip8r8Ptlv9mo7F8fL02LWjfq5Oc0OEBLtuPRnZrhlrAyTv4Y+DxPQoP4iOaMyG/IeAmelhr6pqdV48WFJCsmpAQxB8BTNNLpqeyNai83joK0N1oE8F3gxopLKOpgwppDasTyboMGNJOMzQGSD9Q9NI69A2/voICKNE9UDw7an6IGG7ThiLDuyJsPe1holLMBwWGpq4O14oQf3dDR2Ow4OK74PIjy/gPy4jU8hCinmmjCbfEH/IVQc1C3HDDzbJs66LPuw3CDzIWakFy2pTd0ByyLVtWWUiz4Loo0Z95BMDw65FUtARZS8IsOGNJbYzIjAzgrpyDg947BFU9F0pzfYUIgzFdsrdPYwUqmocbCcwNLbnaEiXUVqGDpHfjdU9ZiS9XDXzJnUMQgYOCZu334cif4YkMHKODg98O68KYAw9kneZPaFWnBEz8Nu0HnYrlSi0KITsJ0SiFsegIqj9lGtrpGpURreBiDgUQiopMQ8Ivn5NKVCswubMbKH4sDZDfUAtzhOLpWwQoGp4tv6esIBdHRuCBU54wsx4kHTS0cUX5u99mEk+nNAxQE2hR4NJrKcHsTeltGxdrLDHLQTf03OLJKPpIjqEu3E74qqsfN6F
qa6bjq+aoTuEGZ0YFe86fZpjZobTqYMmPW7DyXJHwUiG5uIWiikh6zZnD0URLxm68ib85uL6aIYkTSYgWC7LLJCI/A6Pj7RxiTsRth3RzoVs4ImLGnkzw04sQVZS+RW6P6I/U2k4fn/aI4J5ZDxPVWBnQwCZ0tpQcXwKEDXVcPDsRfn21NtyJy1gbJw0YHWfeg01Lbhc6/SdUPjJto6WpHw2LFhhzAUFre29Rz7l7qU2gVJwyzmh5hr+JHapsgiKbxgSPbSbsXJ2OeEhYjdjTCO582fC2GaRbcELQM54+BkDcdb7ZQkjpnMp5aReURma3DJAK2sQTa3ms4yvT22PFBcA9UyNqxDLqoa2DMdRbiK2pRDWNKncwiH/0q9lDyU0BwCpC8Mk9rYlQhwGLOYHwSUBo1c1GQJGNqjhyFN6Bk18QfX7Si9wS6dC/NIUPWx1opJJ3KCW87+gP2phZIBatTwBJ9LJoA7rufTFfT8RithD6ytwKJ93tX8QN5C5BhEO2RRXOTXw6hPG8rDmYfbafcMjrckbY38gUJ2GdrxgG376cB3XAjJAivFIlyDVt5KIcJVONunyfsIZjKVdtXQfwwv9zmgyFdnIot99EZdbgbdyhqpO/qeMGxIzJpqqHF4RBRsNVgThnB6KGmgKThS5S7qxHR83voWNp+HEnOrx3qLNmpgMaRnk6YPs/IBWOigyOugJeUy4tHx6ZZhO9SJYoX+VcsmVJOhaMAuyAdgZtQ0dd6fQp8Sqpks5hCbkvOSLBdnTSyWXOjtkQ/cjgmeqTOnq/cj22lUb0/ztlzXAzCdVl15aR12SOkJyG3MzAHMBPdlUoVkNj4lk2bFhM28YJqDRnWMTflZgMUBgCii2tpReMtgoTCP9tAKOLkJan44tBU3EwSbs7kZFFR7rMVi2om04H7ufckdDUFk40h58K2yDUt/OuLdsa2T5opxsgIcHC1rYOMZgzc4R5RqznHWKcMQV79EmyMOGZmifBYRGI5AxvQor7J1astAayjs8lzkUVrZTeZRFGlusGxqZmZFg7uKIhPEJsNFWrk513AQos4Fx0AW5CD5cBL7m0JFlLeIzRFHmFpBbg9EMMzZtozsUnD6WUWwrM14+2xJpBYR6apcsnLVq8xyUGGphZZDEZ2HY/RmPV4esT3e8/ncBJsdvp/mbDlHAfLDc2RQRTJlhxR1sG1VDjhu1697QHSN7X4Qab/h/jRdNl6YLflUhSNsnZxS0NaHAHmjDb/imJQMVY9qUi8GnMjdz3CWKe5xWJ/v0eb6gHOpWZFhP6VZEXDqE9S4ZaFmJ2DFFA7o1RwWOE7EbtC3KZEqEBfDiR/vcc7VNu4YHZgKM+wjReuDZT032Vcz7L6L0OcALCe4YhYhWhwUcm6R4Yo56CieT3O677bx3TlRdcoGx41mn6VEavQXm4T+qp52mpEdS7DTKtg5jTSAaHnUsVfJgED6a5sv8LmUqDi4iHMsmyoUL2hpm9A4BchCmwBihsu4QTOoSpTIWCp2hgTYGm6sDyW7N/M8OFhjg6vvnrYcdMi5TQtaDYwECz6zstN6ymQDn+KaQzPvdHRsYYtpVfSUYy9lxIXeufDoAfPBsgdSXOHUVyMn/LBn2zxaJ4AZR5wV+wM9yOfqHpaKlbfoM3s12taCkxmeh7KKxOf4PzGz0rGFAJJizgVWcaX6RvP41OM91f99Iv9XTOK+vCXwayL/PWZ0yVdmdG91azBNviIEJboMNI7uiAgNs+A97xP+TxLxCe4IfpX3o3/I+3NToSg2/ROEY55kVfmkZxsDnhgjgvxhUdF1bl/52QYEN/pB8i7X8F0Ovw3m5xYbvBTNWy02CM6LDXBR7LzcQNar46W08uuWG8jz78sNEr44+M4siLZVwDNGxwCgpaWgeBNpsktnwhSfWorjWYJC2cWDEnLiKhDGEz9HqY34MFaUmfvwIMCpvBPihdnP55Nd0GmI3+lADMPJjccugWuq90ul0OXeNnWc/4wagGMxCCMH50TpJWwYaXJJd0NKWkKGdmW+5wK9Ytv90vFxnbDKUtnk4Ny11+oJh9q6wp/MPOsDVZjSC262EBZyK07EqS1M0vG4FZLfFcba94yNUwsalsZDivSfFLXThUxMZhXeCNEp1VQe5qtousBxF6k2FEM0jXMqMZ2yZejoTFpwbgOcaTWk8pi0Agb6rSBAoXj3Ed6EipoO7VXYHCzbsgK261eZqlqGazHEwWICsCQ8+2DlwZ5/WB7Tg03MmgNZ53xpKgAssRayhnaSV3ZgF+3gF6ADVu5xTJdqYs2LS5ftt9s0xfaEp8OrhK72FzaxCVtmS9uGCYDFdNQR/DoZZHrhN0JJ1n0ucBpv28I4Hu84ITFmS0HSOzSEyWeC1ceh0diSGmyN3uTdwLvJdCg0By59rPimRrDxfgDOoqE2NDTMvYUV3XVJZsAV5/vRmZbi1KjSRqfrpYZNaN1essFRErloayCctaSjrF3scOEZ55V2E4KZj1NCyU4jn1svibjILJxF5ko03fL4pk04/kO0wNY7XPnrKWJeH6C9POKJK3t7JJjQXlvQpbsOESjJq2aF73Xe4BIuWGunPjR4KGd2wbBtunBCIkCa1Eyb8EGfptVaZlI5YmjkDXaZMF4m0kyaQ2EqTKefESbPGeg5dsGzbTejdII0OdIqVhyyw0iel0wdDqmuQsvQdbZHkrxmRpHD0PcrJPOLY3qsUn42wzI63tiygJQfTtApFFrYCbssFmbJIhDm1meDiSOIV7wpdlQwjH0oAqTlLVZa89TyZM3XFjGzDrVPDDVVSCslakCk4vmjk5LuqKbiSrehCJJ33cJngYb74hVAqTbacIG9tm0zLrpoZNkeA1UGNvsd725OBV7fwwx+4KkpR6ir9kHIYvDu1vBtYXIe1NwSsAAaxnrJkScuimyD3WbcQUu5vqRO0srCixEU5yRFdhmuM7V2CmFuBF1XZNNkwWIz0eDln0VFt92qtLPtyG6ak0cskaMMAtrfQ7SxRRGYVIt4ogH/r/PpA0M21ArR3ivSGEfj4kmlRtgm+0HOVKk85SthISyXrTTpREOYNGN0+TGHYIvPBNN3fJ3yD/6xiHkUetErHLi5rNiBbB1QgBzImrIybH+J05EzS+9k6gGVta1vFeZo6KjaADTYhs36PgnoLs27qjrS4QFFb/UlrBg45DGXdpRNowWYBZA/IPvCUMWpdFLZtgIXe0gWML5QCgcnRsZsAQVRmk5bcUzHSEk/E0ybM3SkQzYmlo4yilDykCF5X2pG4DLUocxKRau1YstBLKFUWmaFZFpRxcarAUTm+pJZKcCP1vfWArs9rrOaIkSy78PG2QIcS5hJtsLT1S6TDpqlMuuC2eYlz4k8OCcxIYodsyJzCMZDeQnEa4cabixuU99eIawBNEJO1kmLTP5i8iGo9sawEVw5GbUvkNFGhKGGQd46KZOWa5mI84qsBxjVEc5YiO39HmpYQ84ZCx9YYXYc+OxIoNQFjCi82KCiUlmh3JDCRej8Eg1mfrkK0O/Z4+Io0QEUMcJ2v9H0S840Ge0BrFej
s/Y14zPaasDtI2FZ7MT4PpQ+juV+o9i7Gaqa4d1mrd4XA8EG7uATi/Jg4eS67ZDsijsfMFjIlfW4Xm5YggS1ny1cw9+wZlqHTDR1ZIpnMScd3KrBjWyXNfQynFqmEWw6WIo7h8LznA8o3+FEpQY1RXAKdu/UkXRtu/ow0cgbwqbhxTzzbrYisdTR9RiEHKTXPpFkNTZIfYGfNyTOtQzg3AfZDaY4oghcgshXD9gSReRaJwWZ0piOD87CX8iaHKaOySN6mxXfDA1Tao5D5Ejai5TauRRlr1Svuy5Bl1Q+oSmLLozq40XQbxGRg+LUR0awYk0OLjxi0dBBpJU8yayXTHoC6wjXUOYEH2kStjEOgLm4Y8hC5LoORipBzG2UEQ2Bel3ZD2TaHpaZuqo4FICO4EHkWc5/yM65UCRmfJPtYbTK8zbMUk3FTAKDE8ZZpio1n8VazM7FD5H1vS1MSnrsLnlqHzo+kyZphKhJXWjYYKrUiDw4GnHx4td8FU3mOGRvWBQknm3ysWqa4ojzSgpn/NFJcswqPHahI4NhiwOkElt+xqUAjVxIyVYZ8r5WEPa1hnquw6yJbRwScdBHbW1+QJwdE2SjpZIHbWbCVbv1R7DCbx3zcyfsGfnmEj/y9doLHKdBvjYtclzw6wuFGwiAiyE0b6E4qKPsgqDbYbSxh/k2TS27Q/EQdzAqqmn0exxLNdqFk5v9gNeSpOw3nQjzJjtcq7Ucjvs7IqKsyEPREzpPR1AWyrqcTT+CpSdMaevjSOwb1VZ0ToQJh4U8Meq59HCer108rKz5PSO5kwmeLPhVU7bEs0kPlnj5bFCC5F6ZsiV+xf2X9Ae+sfZTzs1S73eTNfHyqc2T7FhXL0TgO08x8dptUsHFcXMeQltsjs95+4whP/ns3ecP9JUoQZRv/6z70auTjc/07vGFHE/17vHtBf/oAerMe6vdKw/K/t01i3+pWSR9I9Ui+RszlPxTOcrdiKPUy7dRfD3i/oMIwc0eT0HdWq3BP33TxSfk9u3io6+XOv3PrzRgmZu+gOblI6OsuvoAoexYYEWWvTUvfghlefKVFPLNQtm/W/SLBhVUeVG+N0v48+fGLGGY5+rBv/LGjzdjCXXj1zPd/e362h/ikc+64PqxPPMeST358slpEhpcke/3iLxIuJIwhNlZGjavRK83VkYWCCKv3FgZ+fdURvrlezmM159oh0ZYPad3Actk2PhnBEzM65JrhM2IdwyuQGL9KK9kxLv7JM7Q9h5GuCtMtiTY7IUruMK6I5aIv0kWL8+K9C/6xsygWPIZMwjqJTPAK7wgf0HYQL98N0521klk/M4DeKIu+00PX3FefwaTAHgeUTCvMIng34pLLy3aF5dej8Gfcwnw4M24hBePfXsP8uXei+9vm6bG/wc= 2 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Table of Contents 4 | 5 | - [Change Log](#change-log) 6 | - [Unreleased](#unreleased) 7 | - [1.0.1 (2019-04-05)](#101-2019-04-05) 8 | - [1.0.0 (2019-04-05)](#100-2019-04-05) 9 | - [0.0.9 (2018-04-23)](#009-2018-04-23) 10 | - [0.0.7 (2017-12-20)](#007-2017-12-20) 11 | - [0.0.6 (2017-11-12)](#006-2017-11-12) 12 | - [0.5.0 (2017-11-01)](#050-2017-11-01) 13 | - [0.0.3 (2017-10-27)](#003-2017-10-27) 14 | - [0.0.2 (2017-10-26)](#002-2017-10-26) 15 | - [v0.0.1 (2017-10-26)](#v001-2017-10-26) 16 | 17 | 18 | 19 | # Change Log 20 | 21 | ## [Unreleased](https://github.com/ixaxaar/pytorch-dnc/tree/HEAD) 22 | 23 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/1.0.1...HEAD) 24 | 25 | **Merged pull requests:** 26 | 27 | - Fixes for \#43 [\#44](https://github.com/ixaxaar/pytorch-dnc/pull/44) ([ixaxaar](https://github.com/ixaxaar)) 28 | 29 | ## [1.0.1](https://github.com/ixaxaar/pytorch-dnc/tree/1.0.1) (2019-04-05) 30 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/1.0.0...1.0.1) 31 | 32 | **Closed issues:** 33 | 34 | - When running adding task -- ModuleNotFoundError: No module named 'index' [\#39](https://github.com/ixaxaar/pytorch-dnc/issues/39) 35 | - SyntaxError [\#36](https://github.com/ixaxaar/pytorch-dnc/issues/36) 36 | - PySide dependency error [\#33](https://github.com/ixaxaar/pytorch-dnc/issues/33) 37 | - Issues when using pytorch 0.4 [\#31](https://github.com/ixaxaar/pytorch-dnc/issues/31) 38 | - TypeError: cat received an invalid combination of arguments - got \(list, int\), but expected one of: [\#29](https://github.com/ixaxaar/pytorch-dnc/issues/29) 39 | 40 | **Merged pull requests:** 41 | 42 | - Fixes \#36 and \#39 [\#42](https://github.com/ixaxaar/pytorch-dnc/pull/42) ([ixaxaar](https://github.com/ixaxaar)) 43 | 44 | ## [1.0.0](https://github.com/ixaxaar/pytorch-dnc/tree/1.0.0) (2019-04-05) 45 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.9...1.0.0) 46 | 47 
| **Closed issues:** 48 | 49 | - Question about the running speed of Pyflann and Faiss for the SAM model [\#40](https://github.com/ixaxaar/pytorch-dnc/issues/40) 50 | - SyntaxError [\#37](https://github.com/ixaxaar/pytorch-dnc/issues/37) 51 | - Values in hidden become nan [\#35](https://github.com/ixaxaar/pytorch-dnc/issues/35) 52 | - faiss error [\#32](https://github.com/ixaxaar/pytorch-dnc/issues/32) 53 | 54 | **Merged pull requests:** 55 | 56 | - Port to pytorch 1.x [\#41](https://github.com/ixaxaar/pytorch-dnc/pull/41) ([ixaxaar](https://github.com/ixaxaar)) 57 | - fix parens in example usage and gpu usage for SAM [\#30](https://github.com/ixaxaar/pytorch-dnc/pull/30) ([kierkegaard13](https://github.com/kierkegaard13)) 58 | 59 | ## [0.0.9](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.9) (2018-04-23) 60 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.7...0.0.9) 61 | 62 | **Fixed bugs:** 63 | 64 | - Use usage vector to determine least recently used memory [\#26](https://github.com/ixaxaar/pytorch-dnc/issues/26) 65 | - Store entire memory after memory limit is reached [\#24](https://github.com/ixaxaar/pytorch-dnc/issues/24) 66 | 67 | **Merged pull requests:** 68 | 69 | - memory.py: fix indexing for read\_modes transform [\#28](https://github.com/ixaxaar/pytorch-dnc/pull/28) ([jbinas](https://github.com/jbinas)) 70 | - Bugfixes [\#27](https://github.com/ixaxaar/pytorch-dnc/pull/27) ([ixaxaar](https://github.com/ixaxaar)) 71 | 72 | ## [0.0.7](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.7) (2017-12-20) 73 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.6...0.0.7) 74 | 75 | **Implemented enhancements:** 76 | 77 | - GPU kNNs [\#21](https://github.com/ixaxaar/pytorch-dnc/issues/21) 78 | - Implement temporal addressing for SDNCs [\#18](https://github.com/ixaxaar/pytorch-dnc/issues/18) 79 | - Feature: Sparse Access Memory [\#4](https://github.com/ixaxaar/pytorch-dnc/issues/4) 80 | - SAMs [\#22](https://github.com/ixaxaar/pytorch-dnc/pull/22) ([ixaxaar](https://github.com/ixaxaar)) 81 | - Temporal links for SDNC [\#19](https://github.com/ixaxaar/pytorch-dnc/pull/19) ([ixaxaar](https://github.com/ixaxaar)) 82 | - SDNC [\#16](https://github.com/ixaxaar/pytorch-dnc/pull/16) ([ixaxaar](https://github.com/ixaxaar)) 83 | 84 | **Merged pull requests:** 85 | 86 | - Add more tasks [\#23](https://github.com/ixaxaar/pytorch-dnc/pull/23) ([ixaxaar](https://github.com/ixaxaar)) 87 | - Scale interface vectors, dynamic memory pass [\#17](https://github.com/ixaxaar/pytorch-dnc/pull/17) ([ixaxaar](https://github.com/ixaxaar)) 88 | - Update README.md [\#14](https://github.com/ixaxaar/pytorch-dnc/pull/14) ([MaxwellRebo](https://github.com/MaxwellRebo)) 89 | 90 | ## [0.0.6](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.6) (2017-11-12) 91 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.5.0...0.0.6) 92 | 93 | **Implemented enhancements:** 94 | 95 | - Re-write allocation vector code, use pytorch's cumprod [\#13](https://github.com/ixaxaar/pytorch-dnc/issues/13) 96 | 97 | **Fixed bugs:** 98 | 99 | - Stacked DNCs forward pass wrong [\#12](https://github.com/ixaxaar/pytorch-dnc/issues/12) 100 | - Temporal debugging of memory [\#11](https://github.com/ixaxaar/pytorch-dnc/pull/11) ([ixaxaar](https://github.com/ixaxaar)) 101 | 102 | ## [0.5.0](https://github.com/ixaxaar/pytorch-dnc/tree/0.5.0) (2017-11-01) 103 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.3...0.5.0) 104 | 105 | **Implemented enhancements:** 106 | 107 | - Multiple 
hidden layers per controller layer [\#7](https://github.com/ixaxaar/pytorch-dnc/issues/7) 108 | - Vizdom integration and fix cumprod bug \#5 [\#6](https://github.com/ixaxaar/pytorch-dnc/pull/6) ([ixaxaar](https://github.com/ixaxaar)) 109 | 110 | **Fixed bugs:** 111 | 112 | - Use shifted cumprods, emulate tensorflow's cumprod with exclusive=True [\#5](https://github.com/ixaxaar/pytorch-dnc/issues/5) 113 | - Vizdom integration and fix cumprod bug \#5 [\#6](https://github.com/ixaxaar/pytorch-dnc/pull/6) ([ixaxaar](https://github.com/ixaxaar)) 114 | 115 | **Closed issues:** 116 | 117 | - Write unit tests [\#8](https://github.com/ixaxaar/pytorch-dnc/issues/8) 118 | - broken links [\#3](https://github.com/ixaxaar/pytorch-dnc/issues/3) 119 | 120 | **Merged pull requests:** 121 | 122 | - Test travis build [\#10](https://github.com/ixaxaar/pytorch-dnc/pull/10) ([ixaxaar](https://github.com/ixaxaar)) 123 | - Implement Hidden layers, small enhancements, cleanups [\#9](https://github.com/ixaxaar/pytorch-dnc/pull/9) ([ixaxaar](https://github.com/ixaxaar)) 124 | 125 | ## [0.0.3](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.3) (2017-10-27) 126 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.2...0.0.3) 127 | 128 | **Implemented enhancements:** 129 | 130 | - Implementation of Dropout for controller [\#2](https://github.com/ixaxaar/pytorch-dnc/pull/2) ([ixaxaar](https://github.com/ixaxaar)) 131 | - Fix size issue for GRU and vanilla RNN [\#1](https://github.com/ixaxaar/pytorch-dnc/pull/1) ([ixaxaar](https://github.com/ixaxaar)) 132 | 133 | ## [0.0.2](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.2) (2017-10-26) 134 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/v0.0.1...0.0.2) 135 | 136 | ## [v0.0.1](https://github.com/ixaxaar/pytorch-dnc/tree/v0.0.1) (2017-10-26) 137 | 138 | 139 | \* *This Change Log was automatically generated by [github_changelog_generator](https://github.com/skywinder/Github-Changelog-Generator)* 140 | -------------------------------------------------------------------------------- /tasks/adding_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import os 6 | import sys 7 | 8 | import numpy as np 9 | import torch 10 | from typing import Any 11 | import torch.optim as optim 12 | from torch.nn.utils import clip_grad_norm_ 13 | from visdom import Visdom 14 | 15 | sys.path.insert(0, os.path.join("..", "..")) 16 | from dnc import DNC, SDNC, SAM 17 | 18 | 19 | def get_device(cuda_id: int) -> torch.device: 20 | if cuda_id >= 0 and torch.cuda.is_available(): 21 | return torch.device(f"cuda:{cuda_id}") 22 | return torch.device("cpu") 23 | 24 | 25 | def onehot(x: int, n: int) -> np.ndarray: 26 | ret = np.zeros(n, dtype=np.float32) 27 | ret[x] = 1.0 28 | return ret 29 | 30 | 31 | def generate_data(length: int, size: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor, str]: 32 | content = np.random.randint(0, size - 1, length) 33 | seqlen = length + 1 34 | x_seq = [onehot(int(val), size) for val in content] + [onehot(size - 1, size)] # digits followed by an end-of-sequence marker, giving seqlen rows 35 | x_seq = np.array(x_seq, dtype=np.float32).reshape(1, seqlen, size) # type: ignore 36 | sums = np.array(np.sum(content), dtype=np.float32).reshape(1, 1, 1) 37 | sums_text = " + ".join(str(val) for val in content) 38 | 39 | return (torch.tensor(x_seq, device=device), torch.tensor(sums, device=device), sums_text) 40 | 41 | 42 | def cross_entropy(prediction: 
torch.Tensor, target: torch.Tensor) -> torch.Tensor: 43 | return ((prediction - target) ** 2).mean() # mean squared error reduced to a scalar for .backward(); not a true cross-entropy 44 | 45 | 46 | def main() -> None: 47 | parser = argparse.ArgumentParser(description="PyTorch Differentiable Neural Computer Adding Task") 48 | parser.add_argument("--input_size", type=int, default=6, help="Dimension of input feature") 49 | parser.add_argument("--rnn_type", type=str, default="lstm", help="Type of recurrent cells (lstm, gru, rnn)") 50 | parser.add_argument("--nhid", type=int, default=64, help="Number of hidden units in the controller") 51 | parser.add_argument("--dropout", type=float, default=0, help="Controller dropout rate") 52 | parser.add_argument("--memory_type", type=str, default="dnc", help="Memory type (dnc, sdnc, sam)") 53 | parser.add_argument("--nlayer", type=int, default=1, help="Number of memory layers") 54 | parser.add_argument("--nhlayer", type=int, default=2, help="Number of hidden layers in each RNN") 55 | parser.add_argument("--lr", type=float, default=1e-4, help="Initial learning rate") 56 | parser.add_argument("--optim", type=str, default="adam", help="Optimizer (adam, adamax, rmsprop, sgd, adagrad, adadelta)") 57 | parser.add_argument("--clip", type=float, default=50, help="Gradient clipping value") 58 | parser.add_argument("--batch_size", type=int, default=100, help="Batch size") 59 | parser.add_argument("--mem_size", type=int, default=20, help="Memory cell size") 60 | parser.add_argument("--mem_slot", type=int, default=16, help="Number of memory slots") 61 | parser.add_argument("--read_heads", type=int, default=4, help="Number of read heads") 62 | parser.add_argument("--sparse_reads", type=int, default=10, help="Number of sparse reads per head (sdnc/sam)") 63 | parser.add_argument("--temporal_reads", type=int, default=2, help="Number of temporal reads (sdnc)") 64 | parser.add_argument("--sequence_max_length", type=int, default=1000, help="Maximum sequence length") 65 | parser.add_argument("--cuda", type=int, default=-1, help="CUDA GPU ID (-1 for CPU)") 66 | parser.add_argument("--iterations", type=int, default=2000, help="Total number of iterations") 67 | parser.add_argument("--summarize_freq", type=int, default=100, help="Summarize frequency") 68 | parser.add_argument("--check_freq", type=int, default=100, help="Checkpoint frequency") 69 | parser.add_argument("--visdom", action="store_true", help="Use Visdom for visualization") 70 | args = parser.parse_args() 71 | print(args) 72 | 73 | device = get_device(args.cuda) 74 | 75 | if args.visdom: 76 | viz = Visdom() 77 | if not viz.check_connection(): 78 | print("Visdom server not running. 
Disabling Visdom.") 79 | args.visdom = False 80 | 81 | if args.memory_type == "dnc": 82 | rnn = DNC( 83 | input_size=args.input_size, 84 | hidden_size=args.nhid, 85 | rnn_type=args.rnn_type, 86 | num_layers=args.nlayer, 87 | num_hidden_layers=args.nhlayer, 88 | dropout=args.dropout, 89 | nr_cells=args.mem_slot, 90 | cell_size=args.mem_size, 91 | read_heads=args.read_heads, 92 | device=device, 93 | debug=args.visdom, 94 | batch_first=True, 95 | independent_linears=True, 96 | ) 97 | elif args.memory_type == "sdnc": 98 | rnn = SDNC( 99 | input_size=args.input_size, 100 | hidden_size=args.nhid, 101 | rnn_type=args.rnn_type, 102 | num_layers=args.nlayer, 103 | num_hidden_layers=args.nhlayer, 104 | dropout=args.dropout, 105 | nr_cells=args.mem_slot, 106 | cell_size=args.mem_size, 107 | sparse_reads=args.sparse_reads, 108 | temporal_reads=args.temporal_reads, 109 | read_heads=args.read_heads, 110 | device=device, 111 | debug=args.visdom, 112 | batch_first=True, 113 | independent_linears=False, 114 | ) 115 | elif args.memory_type == "sam": 116 | rnn = SAM( 117 | input_size=args.input_size, 118 | hidden_size=args.nhid, 119 | rnn_type=args.rnn_type, 120 | num_layers=args.nlayer, 121 | num_hidden_layers=args.nhlayer, 122 | dropout=args.dropout, 123 | nr_cells=args.mem_slot, 124 | cell_size=args.mem_size, 125 | sparse_reads=args.sparse_reads, 126 | read_heads=args.read_heads, 127 | device=device, 128 | debug=args.visdom, 129 | batch_first=True, 130 | independent_linears=False, 131 | ) 132 | else: 133 | raise ValueError('Invalid memory_type. Choose "dnc", "sdnc", or "sam".') 134 | 135 | rnn = rnn.to(device) 136 | print(rnn) 137 | optimizer: Any 138 | 139 | if args.optim == "adam": 140 | optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98)) 141 | elif args.optim == "adamax": 142 | optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98)) 143 | elif args.optim == "rmsprop": 144 | optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, momentum=0.9, eps=1e-10) 145 | elif args.optim == "sgd": 146 | optimizer = optim.SGD(rnn.parameters(), lr=args.lr) 147 | elif args.optim == "adagrad": 148 | optimizer = optim.Adagrad(rnn.parameters(), lr=args.lr) 149 | elif args.optim == "adadelta": 150 | optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr) 151 | else: 152 | raise ValueError(f"Unsupported optimizer: {args.optim}") 153 | 154 | last_100_losses = [] 155 | 156 | for epoch in range(args.iterations + 1): 157 | print(f"\rIteration {epoch}/{args.iterations}", end="") 158 | optimizer.zero_grad() 159 | 160 | random_length = np.random.randint(2, args.sequence_max_length + 1) 161 | input_data, target_output, sums_text = generate_data(random_length, args.input_size, device) 162 | input_data = input_data.repeat(args.batch_size, 1, 1) 163 | target_output = target_output.repeat(args.batch_size, 1, 1) 164 | 165 | output, (chx, mhx, rv) = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True) 166 | 167 | output = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True) 168 | loss = cross_entropy(output, target_output) 169 | loss.backward() 170 | 171 | clip_grad_norm_(rnn.parameters(), args.clip) 172 | optimizer.step() 173 | loss_value = loss.item() 174 | 175 | # Detach memory from graph 176 | if mhx is not None: 177 | mhx = {k: (v.detach() if isinstance(v, torch.Tensor) else v) for k, v in mhx.items()} 178 | 179 | last_100_losses.append(loss_value) 180 | 181 | if epoch % args.summarize_freq == 0: 182 | print(f"\rIteration 
{epoch}/{args.iterations}") 183 | print(f"Avg. Loss: {np.mean(last_100_losses):.4f}") 184 | output_value = output[0].detach().cpu().item() # report one batch element; .item() needs a single value 185 | print(f"Real value: = {int(target_output[0].item())}") 186 | print(f"Predicted: = {int(output_value // 1)} [{output_value}]") 187 | last_100_losses = [] 188 | 189 | print("\nTesting generalization...") 190 | rnn.eval() # Switch to evaluation mode 191 | 192 | with torch.no_grad(): # Disable gradient calculations during testing 193 | for i in range(int((args.iterations + 1) / 10)): 194 | print(f"\nIteration {i}/{args.iterations // 10}") 195 | random_length = np.random.randint(2, int(args.sequence_max_length) * 10 + 1) 196 | input_data, target_output, _ = generate_data(random_length, args.input_size, device) 197 | input_data = input_data.repeat(args.batch_size, 1, 1) 198 | target_output = target_output.repeat(args.batch_size, 1, 1) 199 | 200 | output, *_ = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True) 201 | 202 | output_value = output[0].sum().item() # predicted sum for one batch element; .item() would fail on the whole batch 203 | print(f"Real value: = {int(target_output[0].item())}") 204 | print(f"Predicted: = {int(output_value // 1)} [{output_value}]") 205 | 206 | 207 | if __name__ == "__main__": 208 | main() 209 | -------------------------------------------------------------------------------- /tasks/argmax_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import os 6 | import sys 7 | 8 | import numpy as np 9 | import torch 10 | from typing import Any 11 | import torch.optim as optim 12 | from torch.nn.utils import clip_grad_norm_ 13 | from visdom import Visdom 14 | 15 | # Add the parent directory to sys.path to allow imports from dnc 16 | sys.path.insert( 17 | 0, os.path.join("..", "..") 18 | ) 
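# NOTE: the sys.path tweak above is only needed when running this script from a # source checkout; with the dnc package installed, the import below resolves on its own.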
19 | from dnc import DNC, SDNC, SAM 20 | 21 | 22 | def get_device(cuda_id: int) -> torch.device: 23 | """Gets the torch device based on CUDA availability and ID.""" 24 | if cuda_id >= 0 and torch.cuda.is_available(): 25 | return torch.device(f"cuda:{cuda_id}") 26 | else: 27 | return torch.device("cpu") 28 | 29 | 30 | def onehot(x: int, n: int) -> np.ndarray: 31 | """Creates a one-hot encoded vector.""" 32 | ret = np.zeros(n, dtype=np.float32) 33 | ret[x] = 1.0 34 | return ret 35 | 36 | 37 | def generate_data(length: int, size: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 38 | """Generates data for the argmax task.""" 39 | content = np.random.randint(0, size - 1, length) 40 | 41 | seqlen = length + 1 42 | x_seq = [onehot(int(val), size) for val in content] 43 | x_seq.append(onehot(size - 1, size)) # end-of-sequence marker, giving seqlen rows 44 | x_seq = np.array(x_seq, dtype=np.float32).reshape(1, seqlen, size) # type: ignore 45 | 46 | max_ind = np.argmax(content) 47 | target_output = np.zeros((1, 1, 1), dtype=np.float32) 48 | target_output[:, 0, 0] = max_ind 49 | 50 | weights_vec = np.zeros((1, 1, 1), dtype=np.float32) 51 | weights_vec[:, 0, 0] = 1.0 52 | 53 | return ( 54 | torch.tensor(x_seq, device=device), 55 | torch.tensor(target_output, device=device), 56 | torch.tensor(weights_vec, device=device), 57 | ) 58 | 59 | 60 | def main() -> None: 61 | """Main function for the argmax task.""" 62 | parser = argparse.ArgumentParser(description="PyTorch Differentiable Neural Computer Argmax Task") 63 | parser.add_argument("--input_size", type=int, default=6, help="Dimension of input feature") 64 | parser.add_argument("--rnn_type", type=str, default="lstm", help="Type of recurrent cells (lstm, gru, rnn)") 65 | parser.add_argument("--nhid", type=int, default=100, help="Number of hidden units in the controller") 66 | parser.add_argument("--dropout", type=float, default=0, help="Controller dropout rate") 67 | parser.add_argument("--memory_type", type=str, default="dnc", help="Memory type (dnc, sdnc, sam)") 68 | parser.add_argument("--nlayer", type=int, default=1, help="Number of memory layers") 69 | parser.add_argument("--nhlayer", type=int, default=2, help="Number of hidden layers in each RNN") 70 | parser.add_argument("--lr", type=float, default=1e-4, help="Initial learning rate") 71 | parser.add_argument("--optim", type=str, default="adam", help="Optimizer (adam, adamax, rmsprop, sgd, adagrad, adadelta)") 72 | parser.add_argument("--clip", type=float, default=50, help="Gradient clipping value") 73 | parser.add_argument("--batch_size", type=int, default=100, help="Batch size") 74 | parser.add_argument("--mem_size", type=int, default=20, help="Memory cell size") 75 | parser.add_argument("--mem_slot", type=int, default=16, help="Number of memory slots") 76 | parser.add_argument("--read_heads", type=int, default=4, help="Number of read heads") 77 | parser.add_argument( 78 | "--sparse_reads", type=int, default=10, help="Number of sparse reads per read head (for sdnc and sam)" 79 | ) 80 | parser.add_argument("--temporal_reads", type=int, default=2, help="Number of temporal reads (for sdnc)") 81 | parser.add_argument("--sequence_max_length", type=int, default=4, help="Maximum sequence length") 82 | parser.add_argument("--cuda", type=int, default=-1, help="CUDA GPU ID (-1 for CPU)") 83 | parser.add_argument("--iterations", type=int, default=2000, help="Total number of iterations") 84 | parser.add_argument("--summarize_freq", type=int, default=100, help="Summarize frequency") 85 | 
parser.add_argument("--check_freq", type=int, default=100, help="Checkpoint frequency") 86 | parser.add_argument("--visdom", action="store_true", help="Use Visdom for visualization") 87 | 88 | args = parser.parse_args() 89 | print(args) 90 | 91 | device = get_device(args.cuda) 92 | 93 | if args.visdom: 94 | viz = Visdom() 95 | if not viz.check_connection(): 96 | print("Visdom server not running. Disabling Visdom.") 97 | args.visdom = False 98 | 99 | if args.memory_type == "dnc": 100 | rnn = DNC( 101 | input_size=args.input_size, 102 | hidden_size=args.nhid, 103 | rnn_type=args.rnn_type, 104 | num_layers=args.nlayer, 105 | num_hidden_layers=args.nhlayer, 106 | dropout=args.dropout, 107 | nr_cells=args.mem_slot, 108 | cell_size=args.mem_size, 109 | read_heads=args.read_heads, 110 | device=device, 111 | debug=args.visdom, 112 | batch_first=True, 113 | independent_linears=False, 114 | ) 115 | elif args.memory_type == "sdnc": 116 | rnn = SDNC( 117 | input_size=args.input_size, 118 | hidden_size=args.nhid, 119 | rnn_type=args.rnn_type, 120 | num_layers=args.nlayer, 121 | num_hidden_layers=args.nhlayer, 122 | dropout=args.dropout, 123 | nr_cells=args.mem_slot, 124 | cell_size=args.mem_size, 125 | sparse_reads=args.sparse_reads, 126 | temporal_reads=args.temporal_reads, 127 | read_heads=args.read_heads, 128 | device=device, 129 | debug=args.visdom, 130 | batch_first=True, 131 | independent_linears=False, 132 | ) 133 | elif args.memory_type == "sam": 134 | rnn = SAM( 135 | input_size=args.input_size, 136 | hidden_size=args.nhid, 137 | rnn_type=args.rnn_type, 138 | num_layers=args.nlayer, 139 | num_hidden_layers=args.nhlayer, 140 | dropout=args.dropout, 141 | nr_cells=args.mem_slot, 142 | cell_size=args.mem_size, 143 | sparse_reads=args.sparse_reads, 144 | read_heads=args.read_heads, 145 | device=device, 146 | debug=args.visdom, 147 | batch_first=True, 148 | independent_linears=False, 149 | ) 150 | else: 151 | raise ValueError('Invalid memory_type. 
Choose "dnc", "sdnc", or "sam".') 152 | 153 | rnn = rnn.to(device) 154 | print(rnn) 155 | optimizer: Any 156 | 157 | if args.optim == "adam": 158 | optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98)) 159 | elif args.optim == "adamax": 160 | optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98)) 161 | elif args.optim == "rmsprop": 162 | optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, momentum=0.9, eps=1e-10) 163 | elif args.optim == "sgd": 164 | optimizer = optim.SGD(rnn.parameters(), lr=args.lr) 165 | elif args.optim == "adagrad": 166 | optimizer = optim.Adagrad(rnn.parameters(), lr=args.lr) 167 | elif args.optim == "adadelta": 168 | optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr) 169 | else: 170 | raise ValueError(f"Invalid optimizer: {args.optim}") 171 | 172 | last_100_losses = [] 173 | 174 | for epoch in range(args.iterations + 1): 175 | print(f"\rIteration {epoch}/{args.iterations}", end="") 176 | optimizer.zero_grad() 177 | 178 | random_length = np.random.randint(2, args.sequence_max_length + 1) 179 | input_data, target_output, loss_weights = generate_data(random_length, args.input_size, device) 180 | input_data = input_data.repeat(args.batch_size, 1, 1) 181 | target_output = target_output.repeat(args.batch_size, 1, 1) 182 | loss_weights = loss_weights.repeat(args.batch_size, 1, 1) 183 | 184 | output, (chx, mhx, rv) = rnn( 185 | input_data, (None, None, None), reset_experience=True, pass_through_memory=True 186 | ) # debug removed 187 | 188 | loss = torch.mean(((loss_weights * output).sum(-1, keepdim=True) - target_output) ** 2) 189 | loss.backward() 190 | 191 | clip_grad_norm_(rnn.parameters(), args.clip) 192 | optimizer.step() 193 | loss_value = loss.item() 194 | 195 | # Detach memory from graph 196 | if mhx is not None: 197 | mhx = {k: (v.detach() if isinstance(v, torch.Tensor) else v) for k, v in mhx.items()} 198 | 199 | last_100_losses.append(loss_value) 200 | 201 | if epoch % args.summarize_freq == 0: 202 | output_value = (loss_weights * output).sum().item() 203 | target_value = target_output.sum().item() 204 | 205 | print(f"\rIteration {epoch}/{args.iterations}") 206 | print(f"Avg. 
Loss: {np.mean(last_100_losses):.4f}") 207 | print(f"Real value: = {int(target_value)}") 208 | print(f"Predicted: = {int(output_value // 1)} [{output_value}]") 209 | last_100_losses = [] 210 | 211 | print("\nTesting generalization...") 212 | rnn.eval() # Switch to evaluation mode 213 | 214 | with torch.no_grad(): # Disable gradient calculations during testing 215 | for i in range(int((args.iterations + 1) / 10)): 216 | print(f"\nIteration {i}/{args.iterations // 10}") 217 | random_length = np.random.randint(2, args.sequence_max_length * 2 + 1) 218 | input_data, target_output, loss_weights = generate_data(random_length, args.input_size, device) 219 | input_data = input_data.repeat(args.batch_size, 1, 1) 220 | target_output = target_output.repeat(args.batch_size, 1, 1) 221 | loss_weights = loss_weights.repeat(args.batch_size, 1, 1) 222 | 223 | output, *_ = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True) 224 | 225 | output_value = output[:, -1, :].sum().item() 226 | target_value = target_output.sum().item() 227 | 228 | print(f"Real value: = {int(target_value)}") 229 | print(f"Predicted: = {int(output_value // 1)} [{output_value}]") 230 | 231 | 232 | if __name__ == "__main__": 233 | main() 234 | -------------------------------------------------------------------------------- /scripts/build_faiss.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Script to compile FAISS with CUDA and cuBLAS support directly into current virtual environment 5 | # This script will use the currently active Python environment 6 | 7 | # Check if a virtual environment is active 8 | if [ -z "$VIRTUAL_ENV" ]; then 9 | echo "Error: No active Python virtual environment detected." 10 | echo "Please activate a virtual environment before running this script." 
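    # (A sketch of one common way to do that, assuming python3 ships the venv module:
    #    python3 -m venv .venv && source .venv/bin/activate)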
11 | exit 1 12 | fi 13 | 14 | echo "Using Python virtual environment at: $VIRTUAL_ENV" 15 | 16 | # Configuration 17 | FAISS_VERSION="1.10.0" # Latest stable version 18 | BUILD_DIR="$(pwd)/faiss_build" 19 | INSTALL_DIR="$VIRTUAL_ENV" # Install directly to virtual environment 20 | CUDA_ARCHS="75" # For compute capability 7.5, use comma-separated list for multiple archs 21 | PYTHON_BINDINGS="ON" # Set to OFF to disable Python bindings 22 | USE_MKL="OFF" # Set to ON to use Intel MKL (recommended for performance) 23 | USE_CUVS="OFF" # Set to ON to enable NVIDIA cuVS implementations 24 | OPT_LEVEL="avx2" # Options: generic, avx2, avx512, avx512_spr (x86-64) or generic, sve (aarch64) 25 | ENABLE_GPU="ON" # Set to OFF for CPU-only build 26 | BUILD_TYPE="Release" # Options: Release, Debug 27 | ENABLE_TESTING="OFF" # Set to ON to build tests 28 | BUILD_SHARED="ON" # Set to ON for shared libraries, OFF for static 29 | PARALLEL_JOBS=$(nproc) # Number of parallel build jobs 30 | 31 | # Colors for output 32 | GREEN='\033[0;32m' 33 | RED='\033[0;31m' 34 | YELLOW='\033[0;33m' 35 | NC='\033[0m' # No Color 36 | 37 | # Print configuration 38 | echo -e "${GREEN}FAISS Build Configuration:${NC}" 39 | echo -e " FAISS version: ${YELLOW}${FAISS_VERSION}${NC}" 40 | echo -e " Build directory: ${YELLOW}${BUILD_DIR}${NC}" 41 | echo -e " Install directory: ${YELLOW}${INSTALL_DIR}${NC}" 42 | echo -e " CUDA architectures: ${YELLOW}${CUDA_ARCHS}${NC}" 43 | echo -e " Python bindings: ${YELLOW}${PYTHON_BINDINGS}${NC}" 44 | echo -e " Use Intel MKL: ${YELLOW}${USE_MKL}${NC}" 45 | echo -e " Use NVIDIA cuVS: ${YELLOW}${USE_CUVS}${NC}" 46 | echo -e " Optimization level: ${YELLOW}${OPT_LEVEL}${NC}" 47 | echo -e " Enable GPU: ${YELLOW}${ENABLE_GPU}${NC}" 48 | echo -e " Build type: ${YELLOW}${BUILD_TYPE}${NC}" 49 | echo -e " Enable testing: ${YELLOW}${ENABLE_TESTING}${NC}" 50 | echo -e " Build shared libraries: ${YELLOW}${BUILD_SHARED}${NC}" 51 | echo -e " Parallel jobs: ${YELLOW}${PARALLEL_JOBS}${NC}" 52 | 53 | # Function to check for required tools 54 | check_requirements() { 55 | echo -e "${GREEN}Checking for required tools...${NC}" 56 | local missing_tools=() 57 | 58 | # Check for C++ compiler 59 | if ! command -v g++ &>/dev/null && ! command -v clang++ &>/dev/null; then 60 | missing_tools+=("C++17 compiler (g++ or clang++)") 61 | else 62 | echo -e " C++ compiler: ${YELLOW}$(command -v g++ 2>/dev/null || command -v clang++)${NC}" 63 | fi 64 | 65 | # Check for CMake 66 | if ! command -v cmake &>/dev/null; then 67 | missing_tools+=("CMake") 68 | else 69 | echo -e " CMake: ${YELLOW}$(cmake --version | head -n1)${NC}" 70 | fi 71 | 72 | # Check for Git 73 | if ! command -v git &>/dev/null; then 74 | missing_tools+=("Git") 75 | else 76 | echo -e " Git: ${YELLOW}$(git --version)${NC}" 77 | fi 78 | 79 | # Check for CUDA if GPU is enabled 80 | if [ "$ENABLE_GPU" = "ON" ]; then 81 | if ! command -v nvcc &>/dev/null; then 82 | missing_tools+=("CUDA toolkit (nvcc)") 83 | else 84 | echo -e " CUDA: ${YELLOW}$(nvcc --version | grep release)${NC}" 85 | fi 86 | fi 87 | 88 | # Check for Python and NumPy if Python bindings are enabled 89 | if [ "$PYTHON_BINDINGS" = "ON" ]; then 90 | # Check for NumPy 91 | if ! python -c "import numpy" &>/dev/null; then 92 | missing_tools+=("NumPy (Python package)") 93 | else 94 | echo -e " NumPy: ${YELLOW}$(python -c "import numpy; print(numpy.__version__)")${NC}" 95 | fi 96 | 97 | # Check for SWIG 98 | if ! 
command -v swig &>/dev/null; then 99 | missing_tools+=("SWIG") 100 | else 101 | echo -e " SWIG: ${YELLOW}$(swig -version | head -n2 | tail -n1)${NC}" 102 | fi 103 | fi 104 | 105 | # Report missing tools 106 | if [ ${#missing_tools[@]} -gt 0 ]; then 107 | echo -e "${RED}Missing required tools:${NC}" 108 | for tool in "${missing_tools[@]}"; do 109 | echo -e " - ${RED}${tool}${NC}" 110 | done 111 | echo -e "${RED}Please install the missing tools and try again.${NC}" 112 | exit 1 113 | fi 114 | 115 | echo -e "${GREEN}All required tools are installed.${NC}" 116 | } 117 | 118 | # Clone or update FAISS repository 119 | clone_or_update_faiss() { 120 | echo -e "${GREEN}Cloning or updating FAISS repository...${NC}" 121 | 122 | mkdir -p "$BUILD_DIR" 123 | 124 | if [ ! -d "$BUILD_DIR/faiss" ]; then 125 | cd "$BUILD_DIR" 126 | echo -e "${GREEN}Cloning FAISS repository...${NC}" 127 | git clone https://github.com/facebookresearch/faiss.git 128 | cd faiss 129 | if [ -n "$FAISS_VERSION" ]; then 130 | echo -e "${GREEN}Checking out version ${FAISS_VERSION}...${NC}" 131 | git checkout "v${FAISS_VERSION}" || git checkout "${FAISS_VERSION}" 132 | fi 133 | else 134 | echo -e "${GREEN}FAISS repository already exists, updating...${NC}" 135 | cd "$BUILD_DIR/faiss" 136 | git fetch 137 | if [ -n "$FAISS_VERSION" ]; then 138 | echo -e "${GREEN}Checking out version ${FAISS_VERSION}...${NC}" 139 | git checkout "v${FAISS_VERSION}" || git checkout "${FAISS_VERSION}" 140 | fi 141 | fi 142 | } 143 | 144 | # Configure build with CMake 145 | configure_build() { 146 | echo -e "${GREEN}Configuring build with CMake...${NC}" 147 | 148 | mkdir -p "$BUILD_DIR/faiss/build" 149 | cd "$BUILD_DIR/faiss/build" 150 | 151 | # Get current Python executable and site-packages directory 152 | PYTHON_EXECUTABLE=$(which python) 153 | PYTHON_SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])") 154 | 155 | echo -e "${GREEN}Using Python: ${YELLOW}${PYTHON_EXECUTABLE}${NC}" 156 | echo -e "${GREEN}Python site-packages: ${YELLOW}${PYTHON_SITE_PACKAGES}${NC}" 157 | 158 | # Prepare CMake arguments 159 | CMAKE_ARGS=( 160 | "-DCMAKE_BUILD_TYPE=${BUILD_TYPE}" 161 | "-DFAISS_ENABLE_GPU=${ENABLE_GPU}" 162 | "-DFAISS_ENABLE_PYTHON=${PYTHON_BINDINGS}" 163 | "-DBUILD_TESTING=${ENABLE_TESTING}" 164 | "-DBUILD_SHARED_LIBS=${BUILD_SHARED}" 165 | "-DCMAKE_INSTALL_PREFIX=${INSTALL_DIR}" 166 | "-DFAISS_OPT_LEVEL=${OPT_LEVEL}" 167 | "-DPython_EXECUTABLE=${PYTHON_EXECUTABLE}" 168 | ) 169 | 170 | # Add CUDA architectures if GPU is enabled 171 | if [ "$ENABLE_GPU" = "ON" ]; then 172 | CMAKE_ARGS+=("-DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS}") 173 | fi 174 | 175 | # Add cuVS support if enabled 176 | if [ "$USE_CUVS" = "ON" ]; then 177 | CMAKE_ARGS+=("-DFAISS_ENABLE_CUVS=ON") 178 | fi 179 | 180 | # Add MKL support if enabled 181 | if [ "$USE_MKL" = "ON" ]; then 182 | CMAKE_ARGS+=("-DBLA_VENDOR=Intel10_64_dyn") 183 | fi 184 | 185 | # Run CMake 186 | cmake "${CMAKE_ARGS[@]}" .. 
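    # (For reference, a sketch of the expanded command under the defaults above;
    # the actual flags follow the configuration variables at the top of this script:
    #   cmake -DCMAKE_BUILD_TYPE=Release -DFAISS_ENABLE_GPU=ON -DFAISS_ENABLE_PYTHON=ON \
    #         -DBUILD_SHARED_LIBS=ON -DFAISS_OPT_LEVEL=avx2 -DCMAKE_CUDA_ARCHITECTURES=75 ..)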
187 | } 188 | 189 | # Build FAISS 190 | build_faiss() { 191 | echo -e "${GREEN}Building FAISS...${NC}" 192 | cd "$BUILD_DIR/faiss/build" 193 | 194 | # Build C++ library 195 | echo -e "${GREEN}Building C++ library...${NC}" 196 | make -j"${PARALLEL_JOBS}" faiss 197 | 198 | # Build optimized versions if needed 199 | if [ "$OPT_LEVEL" = "avx2" ]; then 200 | echo -e "${GREEN}Building AVX2 optimized version...${NC}" 201 | make -j"${PARALLEL_JOBS}" faiss_avx2 202 | elif [ "$OPT_LEVEL" = "avx512" ]; then 203 | echo -e "${GREEN}Building AVX512 optimized version...${NC}" 204 | make -j"${PARALLEL_JOBS}" faiss_avx512 205 | elif [ "$OPT_LEVEL" = "avx512_spr" ]; then 206 | echo -e "${GREEN}Building AVX512 Sapphire Rapids optimized version...${NC}" 207 | make -j"${PARALLEL_JOBS}" faiss_avx512_spr 208 | fi 209 | 210 | # Build Python bindings if enabled 211 | if [ "$PYTHON_BINDINGS" = "ON" ]; then 212 | echo -e "${GREEN}Building Python bindings...${NC}" 213 | make -j"${PARALLEL_JOBS}" swigfaiss 214 | fi 215 | } 216 | 217 | # Install FAISS 218 | install_faiss() { 219 | echo -e "${GREEN}Installing FAISS...${NC}" 220 | cd "$BUILD_DIR/faiss/build" 221 | 222 | # Install C++ library and headers 223 | echo -e "${GREEN}Installing C++ library and headers...${NC}" 224 | if [ "$BUILD_SHARED" = "ON" ]; then 225 | # First copy shared libraries to the virtual environment's lib directory 226 | mkdir -p "$VIRTUAL_ENV/lib" 227 | find . -name "*.so*" -not -path "*python*" -type f -exec cp -v {} "$VIRTUAL_ENV/lib/" \; 228 | fi 229 | 230 | # Install Python bindings if enabled 231 | if [ "$PYTHON_BINDINGS" = "ON" ]; then 232 | echo -e "${GREEN}Installing Python bindings...${NC}" 233 | cd "$BUILD_DIR/faiss/build/faiss/python" 234 | python setup.py install --prefix="$VIRTUAL_ENV" 235 | fi 236 | } 237 | 238 | # Run tests if enabled 239 | run_tests() { 240 | if [ "$ENABLE_TESTING" = "ON" ]; then 241 | echo -e "${GREEN}Running tests...${NC}" 242 | cd "$BUILD_DIR/faiss/build" 243 | make test 244 | fi 245 | } 246 | 247 | # Create a test script 248 | create_test_script() { 249 | echo -e "${GREEN}Creating test script...${NC}" 250 | 251 | mkdir -p "$BUILD_DIR" 252 | 253 | # Create Python test script 254 | cat >"$BUILD_DIR/test_faiss.py" < torch.device: 22 | """Gets the torch device based on CUDA availability and ID.""" 23 | if cuda_id >= 0 and torch.cuda.is_available(): 24 | return torch.device(f"cuda:{cuda_id}") 25 | else: 26 | return torch.device("cpu") 27 | 28 | 29 | def generate_data(batch_size: int, length: int, size: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: 30 | """Generates data for the copy task.""" 31 | sequence = np.random.binomial(1, 0.5, (batch_size, length, size - 1)).astype(np.float32) 32 | input_data = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32) 33 | target_output = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32) 34 | 35 | input_data[:, :length, : size - 1] = sequence 36 | input_data[:, length, -1] = 1 # Add the end-of-sequence marker 37 | target_output[:, length + 1 :, : size - 1] = sequence 38 | 39 | return (torch.tensor(input_data, device=device), torch.tensor(target_output, device=device)) 40 | 41 | 42 | def criterion(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 43 | """Calculates the binary cross-entropy loss.""" 44 | # Use F.binary_cross_entropy_with_logits for numerical stability 45 | return F.binary_cross_entropy_with_logits(predictions, targets) 46 | 47 | 48 | def main() -> None: 49 | """Main function for the copy task.""" 50 | parser 
= argparse.ArgumentParser(description="PyTorch Differentiable Neural Computer Copy Task")
51 |     parser.add_argument("--input_size", type=int, default=6, help="Dimension of input feature")
52 |     parser.add_argument("--rnn_type", type=str, default="lstm", help="Type of recurrent cells (lstm, gru, rnn)")
53 |     parser.add_argument("--nhid", type=int, default=64, help="Number of hidden units in the controller")
54 |     parser.add_argument("--dropout", type=float, default=0, help="Controller dropout rate")
55 |     parser.add_argument("--memory_type", type=str, default="dnc", help="Memory type (dnc, sdnc, sam)")
56 |     parser.add_argument("--nlayer", type=int, default=1, help="Number of memory layers")
57 |     parser.add_argument("--nhlayer", type=int, default=2, help="Number of hidden layers in each RNN")
58 |     parser.add_argument("--lr", type=float, default=1e-4, help="Initial learning rate")
59 |     parser.add_argument("--optim", type=str, default="adam", help="Optimizer (adam, rmsprop)")
60 |     parser.add_argument("--clip", type=float, default=50, help="Gradient clipping value")
61 |     parser.add_argument("--batch_size", type=int, default=100, help="Batch size")
62 |     parser.add_argument("--mem_size", type=int, default=20, help="Memory cell size")
63 |     parser.add_argument("--mem_slot", type=int, default=16, help="Number of memory slots")
64 |     parser.add_argument("--read_heads", type=int, default=4, help="Number of read heads")
65 |     parser.add_argument("--sparse_reads", type=int, default=10, help="Number of sparse reads per head (sdnc/sam)")
66 |     parser.add_argument("--temporal_reads", type=int, default=2, help="Number of temporal reads (sdnc)")
67 |     parser.add_argument("--sequence_max_length", type=int, default=4, help="Maximum sequence length")
68 |     parser.add_argument("--curriculum_increment", type=int, default=0, help="Sequence length increment per freq")
69 |     parser.add_argument("--curriculum_freq", type=int, default=1000, help="Frequency of curriculum increment")
70 |     parser.add_argument("--cuda", type=int, default=-1, help="CUDA GPU ID (-1 for CPU)")
71 |     parser.add_argument("--iterations", type=int, default=100000, help="Total number of iterations")
72 |     parser.add_argument("--summarize_freq", type=int, default=100, help="Summarize frequency")
73 |     parser.add_argument("--check_freq", type=int, default=100, help="Checkpoint frequency")
   |     parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Checkpoint directory")
74 |     parser.add_argument("--visdom", action="store_true", help="Use Visdom for visualization")
75 |     args = parser.parse_args()
76 |     print(args)
77 | 
78 |     device = get_device(args.cuda)
79 | 
80 |     if args.visdom:
81 |         viz = Visdom()
82 |         if not viz.check_connection():
83 |             print("Visdom server not running. 
Disabling Visdom.") 84 | args.visdom = False 85 | 86 | if args.memory_type == "dnc": 87 | rnn = DNC( 88 | input_size=args.input_size, 89 | hidden_size=args.nhid, 90 | rnn_type=args.rnn_type, 91 | num_layers=args.nlayer, 92 | num_hidden_layers=args.nhlayer, 93 | dropout=args.dropout, 94 | nr_cells=args.mem_slot, 95 | cell_size=args.mem_size, 96 | read_heads=args.read_heads, 97 | device=device, 98 | debug=args.visdom, 99 | batch_first=True, 100 | independent_linears=True, 101 | ) 102 | elif args.memory_type == "sdnc": 103 | rnn = SDNC( 104 | input_size=args.input_size, 105 | hidden_size=args.nhid, 106 | rnn_type=args.rnn_type, 107 | num_layers=args.nlayer, 108 | num_hidden_layers=args.nhlayer, 109 | dropout=args.dropout, 110 | nr_cells=args.mem_slot, 111 | cell_size=args.mem_size, 112 | sparse_reads=args.sparse_reads, 113 | temporal_reads=args.temporal_reads, 114 | read_heads=args.read_heads, 115 | device=device, 116 | debug=args.visdom, 117 | batch_first=True, 118 | independent_linears=False, 119 | ) 120 | elif args.memory_type == "sam": 121 | rnn = SAM( 122 | input_size=args.input_size, 123 | hidden_size=args.nhid, 124 | rnn_type=args.rnn_type, 125 | num_layers=args.nlayer, 126 | num_hidden_layers=args.nhlayer, 127 | dropout=args.dropout, 128 | nr_cells=args.mem_slot, 129 | cell_size=args.mem_size, 130 | sparse_reads=args.sparse_reads, 131 | read_heads=args.read_heads, 132 | device=device, 133 | debug=args.visdom, 134 | batch_first=True, 135 | independent_linears=False, 136 | ) 137 | else: 138 | raise ValueError('Invalid memory_type. Choose "dnc", "sdnc", or "sam".') 139 | 140 | rnn = rnn.to(device) 141 | print(rnn) 142 | optimizer: Any 143 | 144 | if args.optim == "adam": 145 | optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98)) 146 | elif args.optim == "adamax": 147 | optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98)) 148 | elif args.optim == "rmsprop": 149 | optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, momentum=0.9, eps=1e-10) 150 | elif args.optim == "sgd": 151 | optimizer = optim.SGD(rnn.parameters(), lr=args.lr) 152 | elif args.optim == "adagrad": 153 | optimizer = optim.Adagrad(rnn.parameters(), lr=args.lr) 154 | elif args.optim == "adadelta": 155 | optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr) 156 | else: 157 | raise ValueError(f"Unsupported optimizer: {args.optim}") 158 | 159 | last_losses = [] 160 | 161 | for epoch in range(args.iterations + 1): 162 | print(f"\rIteration {epoch}/{args.iterations}", end="") 163 | optimizer.zero_grad() 164 | 165 | random_length = np.random.randint(1, args.sequence_max_length + 1) 166 | input_data, target_output = generate_data(args.batch_size, random_length, args.input_size, device) 167 | 168 | output, (chx, mhx, rv) = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True) 169 | 170 | loss = criterion(output, target_output) 171 | loss.backward() 172 | 173 | clip_grad_norm_(rnn.parameters(), args.clip) 174 | optimizer.step() 175 | loss_value = loss.item() 176 | 177 | # Detach memory from graph 178 | if mhx is not None: 179 | mhx = {k: (v.detach() if isinstance(v, torch.Tensor) else v) for k, v in mhx.items()} 180 | 181 | last_losses.append(loss_value) 182 | 183 | if epoch % args.summarize_freq == 0: 184 | avg_loss = np.mean(last_losses) 185 | print(f"\n\tAvg. Loss: {avg_loss:.4f}") 186 | last_losses = [] 187 | if np.isnan(avg_loss): 188 | raise ValueError("NaN loss. 
Experiment failed.") 189 | 190 | if args.visdom and rnn.debug: # added rnn.debug 191 | avg_loss = np.mean(last_losses) 192 | last_losses = [] 193 | if args.memory_type == "dnc": 194 | 195 | memory = rnn._debug(mhx, None)["memory"] # type: ignore 196 | if memory is not None and len(memory) > 0: 197 | viz.heatmap( 198 | np.array(memory[-1]), 199 | opts=dict( 200 | xtickstep=10, 201 | ytickstep=2, 202 | title=f"Memory, t: {epoch}, loss: {avg_loss:.4f}", 203 | ylabel="layer * time", 204 | xlabel="mem_slot * mem_size", 205 | ), 206 | ) 207 | 208 | link_matrix = rnn._debug(mhx, None)["link_matrix"] # type: ignore 209 | if link_matrix is not None and len(link_matrix) > 0: 210 | viz.heatmap( 211 | np.array(link_matrix[-1]).reshape(args.mem_slot, args.mem_slot), 212 | opts=dict( 213 | xtickstep=10, 214 | ytickstep=2, 215 | title=f"Link Matrix, t: {epoch}, loss: {avg_loss:.4f}", 216 | ylabel="mem_slot", 217 | xlabel="mem_slot", 218 | ), 219 | ) 220 | 221 | precedence = rnn._debug(mhx, None)["precedence"] # type: ignore 222 | if precedence is not None and len(precedence) > 0: 223 | viz.heatmap( 224 | np.array(precedence[-1]), 225 | opts=dict( 226 | xtickstep=10, 227 | ytickstep=2, 228 | title=f"Precedence, t: {epoch}, loss: {avg_loss:.4f}", 229 | ylabel="layer * time", 230 | xlabel="mem_slot", 231 | ), 232 | ) 233 | 234 | if args.memory_type == "sdnc": 235 | link_matrix = rnn._debug(mhx, None)["link_matrix"] # type: ignore 236 | if link_matrix is not None and len(link_matrix) > 0: 237 | viz.heatmap( 238 | np.array(link_matrix[-1]).reshape(args.mem_slot, -1), 239 | opts=dict( 240 | xtickstep=10, 241 | ytickstep=2, 242 | title=f"Link Matrix, t: {epoch}, loss: {avg_loss:.4f}", 243 | ylabel="mem_slot", 244 | xlabel="mem_slot", 245 | ), 246 | ) 247 | 248 | rev_link_matrix = rnn._debug(mhx, None)["rev_link_matrix"] # type: ignore 249 | if rev_link_matrix is not None and len(rev_link_matrix) > 0: 250 | viz.heatmap( 251 | np.array(rev_link_matrix[-1]).reshape(args.mem_slot, -1), 252 | opts=dict( 253 | xtickstep=10, 254 | ytickstep=2, 255 | title=f"Reverse Link Matrix, t: {epoch}, loss: {avg_loss:.4f}", 256 | ylabel="mem_slot", 257 | xlabel="mem_slot", 258 | ), 259 | ) 260 | read_positions = rnn._debug(mhx, None)["read_positions"] # type: ignore 261 | if read_positions is not None and len(read_positions) > 0: 262 | viz.heatmap( 263 | np.array(read_positions[-1]), 264 | opts=dict( 265 | xtickstep=10, 266 | ytickstep=2, 267 | title=f"Read Positions, t: {epoch}, loss: {avg_loss:.4f}", 268 | ylabel="layer * time", 269 | xlabel="mem_slot", 270 | ), 271 | ) 272 | 273 | read_weights = rnn._debug(mhx, None)["read_weights"] # type: ignore 274 | if read_weights is not None and len(read_weights) > 0: 275 | viz.heatmap( 276 | np.array(read_weights[-1]), 277 | opts=dict( 278 | xtickstep=10, 279 | ytickstep=2, 280 | title=f"Read Weights, t: {epoch}, loss: {avg_loss:.4f}", 281 | ylabel="layer * time", 282 | xlabel="nr_read_heads * mem_slot", 283 | ), 284 | ) 285 | 286 | write_weights = rnn._debug(mhx, None)["write_weights"] # type: ignore 287 | if write_weights is not None and len(write_weights) > 0: 288 | viz.heatmap( 289 | np.array(write_weights[-1]), 290 | opts=dict( 291 | xtickstep=10, 292 | ytickstep=2, 293 | title=f"Write Weights, t: {epoch}, loss: {avg_loss:.4f}", 294 | ylabel="layer * time", 295 | xlabel="mem_slot", 296 | ), 297 | ) 298 | 299 | if args.memory_type == "dnc": 300 | usage_vector = rnn._debug(mhx, None)["usage_vector"] # type: ignore 301 | else: 302 | usage_vector = rnn._debug(mhx, None)["usage"] # 
type: ignore
303 | 
304 |             if usage_vector is not None and len(usage_vector) > 0:
305 |                 viz.heatmap(
306 |                     np.array(usage_vector[-1]),
307 |                     opts=dict(
308 |                         xtickstep=10,
309 |                         ytickstep=2,
310 |                         title=f"Usage Vector, t: {epoch}, loss: {avg_loss:.4f}",
311 |                         ylabel="layer * time",
312 |                         xlabel="mem_slot",
313 |                     ),
314 |                 )
315 | 
316 |         if args.curriculum_increment > 0 and epoch != 0 and epoch % args.curriculum_freq == 0:
317 |             args.sequence_max_length += args.curriculum_increment
318 |             print(f"Increasing max length to {args.sequence_max_length}")
319 | 
320 |         if epoch != 0 and epoch % args.check_freq == 0:
321 |             print("\nSaving Checkpoint ... ", end="")
   |             os.makedirs(args.checkpoint_dir, exist_ok=True)
322 |             check_ptr = os.path.join(args.checkpoint_dir, f"step_{epoch}.pth")
323 |             torch.save(rnn.state_dict(), check_ptr)
324 |             print("Done!")
325 | 
326 |     print("\nTesting generalization...")
327 |     rnn.eval()
328 | 
329 |     with torch.no_grad():
330 |         for i in range(int((args.iterations + 1) / 10)):
331 |             print(f"\nIteration {i}/{args.iterations // 10}")
332 |             random_length = np.random.randint(2, args.sequence_max_length * 10 + 1)
333 | 
334 |             input_data, target_output = generate_data(args.batch_size, random_length, args.input_size, device)
335 |             output, *_ = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True)
336 |             output_value = torch.sigmoid(output).round().detach().cpu().numpy()
337 |             target_value = target_output.detach().cpu().numpy()
338 | 
339 |             num_correct = (output_value == target_value).sum()
340 |             total_num = target_output.numel()
341 |             accuracy = num_correct / total_num
342 |             print(f"Accuracy: {accuracy:.4f}")
343 | 
344 | 
345 | if __name__ == "__main__":
346 |     main()
347 | 
--------------------------------------------------------------------------------
/dnc/sparse_memory.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | 
5 | import torch
6 | import torch.nn as nn
7 | 
8 | from .util import cuda, σ, θ  # Import necessary functions from util
9 | 
10 | 
11 | class SparseMemory(nn.Module):
12 |     """Sparse Memory module."""
13 | 
14 |     def __init__(
15 |         self,
16 |         input_size: int,
17 |         mem_size: int = 512,
18 |         cell_size: int = 32,
19 |         independent_linears: bool = True,
20 |         read_heads: int = 4,
21 |         sparse_reads: int = 4,
22 |         num_lists: int | None = None,
23 |         index_checks: int = 32,
24 |         device: torch.device | None = None,
25 |     ):
26 |         """Initialize SparseMemory.
27 | 
28 |         Args:
29 |             input_size: Input size.
30 |             mem_size: Memory size.
31 |             cell_size: Size of each memory cell.
32 |             independent_linears: Whether to use independent linear layers.
33 |             read_heads: Number of read heads.
34 |             sparse_reads: Number of sparse reads.
35 |             num_lists: Number of lists for indexing.
36 |             index_checks: Number of index checks.
37 |             device: PyTorch device to use. 
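
        Example (a sketch based on the signature above; reset() builds a FAISS
        index, so a working FAISS installation is required):
            >>> mem = SparseMemory(input_size=64, mem_size=512, cell_size=32,
            ...                    read_heads=4, sparse_reads=4)
            >>> hidden = mem.reset(batch_size=1)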
38 | """ 39 | super(SparseMemory, self).__init__() 40 | 41 | self.mem_size = mem_size 42 | self.cell_size = cell_size 43 | self.device = device 44 | self.input_size = input_size 45 | self.independent_linears = independent_linears 46 | self.K = sparse_reads if self.mem_size > sparse_reads else self.mem_size 47 | self.read_heads = read_heads 48 | self.num_lists = num_lists if num_lists is not None else int(self.mem_size / 100) 49 | self.index_checks = index_checks 50 | 51 | m = self.mem_size 52 | w = self.cell_size 53 | r = self.read_heads 54 | # The visible memory size: (K * R read heads, forward and backward 55 | # temporal reads of size KL and least used memory cell) 56 | self.c = (r * self.K) + 1 57 | 58 | if self.independent_linears: 59 | self.read_query_transform = nn.Linear(self.input_size, w * r) 60 | self.write_vector_transform = nn.Linear(self.input_size, w) 61 | self.interpolation_gate_transform = nn.Linear(self.input_size, self.c) 62 | self.write_gate_transform = nn.Linear(self.input_size, 1) 63 | torch.nn.init.kaiming_uniform_(self.read_query_transform.weight) 64 | torch.nn.init.kaiming_uniform_(self.write_vector_transform.weight) 65 | torch.nn.init.kaiming_uniform_(self.interpolation_gate_transform.weight) 66 | torch.nn.init.kaiming_uniform_(self.write_gate_transform.weight) 67 | 68 | else: 69 | self.interface_size = (r * w) + w + self.c + 1 70 | self.interface_weights = nn.Linear(self.input_size, self.interface_size) 71 | torch.nn.init.kaiming_uniform_(self.interface_weights.weight) 72 | 73 | self.I = cuda(1 - torch.eye(self.c).unsqueeze(0), device=self.device) # (1 * n * n) 74 | self.δ = 0.005 # minimum usage 75 | self.timestep = 0 76 | self.mem_limit_reached = False 77 | if self.device is not None and self.device.type == "cuda": 78 | self.to(self.device) 79 | 80 | def rebuild_indexes(self, hidden: dict, erase: bool = False) -> dict: 81 | """Rebuilds the indexes for sparse memory access. 82 | 83 | Args: 84 | hidden: Hidden state dictionary. 85 | erase: Whether to erase the existing memory content. 86 | 87 | Returns: 88 | Updated hidden state dictionary. 89 | """ 90 | b = hidden["memory"].size(0) 91 | 92 | # if indexes already exist, we reset them 93 | if "indexes" in hidden: 94 | for x in hidden["indexes"]: 95 | x.reset() 96 | else: 97 | # create new indexes, try to use FAISS 98 | try: 99 | from .faiss_index import FAISSIndex 100 | 101 | hidden["indexes"] = [ 102 | FAISSIndex( 103 | cell_size=self.cell_size, 104 | nr_cells=self.mem_size, 105 | K=self.K, 106 | num_lists=self.num_lists, 107 | probes=self.index_checks, 108 | device=self.device, 109 | ) 110 | for _ in range(b) 111 | ] 112 | except ImportError: 113 | print( 114 | "FAISS not found, please install FAISS, consult https://github.com/facebookresearch/faiss/blob/main/INSTALL.md" 115 | ) 116 | raise 117 | 118 | # add existing memory into indexes 119 | pos = hidden["read_positions"].squeeze().detach().cpu().numpy() 120 | if not erase: 121 | for n, i in enumerate(hidden["indexes"]): 122 | i.reset() 123 | i.add(hidden["memory"][n], last=pos[n][-1]) 124 | else: 125 | self.timestep = 0 126 | self.mem_limit_reached = False 127 | 128 | return hidden 129 | 130 | def reset(self, batch_size: int = 1, hidden: dict | None = None, erase: bool = True) -> dict: 131 | """Resets the memory and hidden state. 132 | 133 | Args: 134 | batch_size: Batch size. 135 | hidden: Hidden state dictionary. 136 | erase: Whether to erase the existing memory content. 137 | Returns: 138 | Reset hidden state dictionary. 
139 | 140 | """ 141 | m = self.mem_size 142 | w = self.cell_size 143 | b = batch_size 144 | r = self.read_heads 145 | c = self.c 146 | 147 | if hidden is None: 148 | hidden = { 149 | # warning can be a huge chunk of contiguous memory 150 | "memory": cuda(torch.zeros(b, m, w).fill_(self.δ), device=self.device), 151 | "visible_memory": cuda(torch.zeros(b, c, w).fill_(self.δ), device=self.device), 152 | "read_weights": cuda(torch.zeros(b, m).fill_(self.δ), device=self.device), 153 | "write_weights": cuda(torch.zeros(b, m).fill_(self.δ), device=self.device), 154 | "read_vectors": cuda(torch.zeros(b, r, w).fill_(self.δ), device=self.device), 155 | "least_used_mem": cuda(torch.zeros(b, 1).fill_(c + 1), device=self.device).long(), 156 | "usage": cuda(torch.zeros(b, m), device=self.device), 157 | "read_positions": cuda(torch.arange(0, c).expand(b, c), device=self.device).long(), 158 | } 159 | hidden = self.rebuild_indexes(hidden, erase=True) 160 | else: 161 | # duplication is faster than moving tensors between devices (or even cloning) 162 | hidden["memory"] = hidden["memory"].clone() 163 | hidden["visible_memory"] = hidden["visible_memory"].clone() 164 | hidden["read_weights"] = hidden["read_weights"].clone() 165 | hidden["write_weights"] = hidden["write_weights"].clone() 166 | hidden["read_vectors"] = hidden["read_vectors"].clone() 167 | hidden["least_used_mem"] = hidden["least_used_mem"].clone() 168 | hidden["usage"] = hidden["usage"].clone() 169 | hidden["read_positions"] = hidden["read_positions"].clone() 170 | hidden = self.rebuild_indexes(hidden, erase) 171 | 172 | if erase: 173 | hidden["memory"].data.fill_(self.δ) 174 | hidden["visible_memory"].data.fill_(self.δ) 175 | hidden["read_weights"].data.fill_(self.δ) 176 | hidden["write_weights"].data.fill_(self.δ) 177 | hidden["read_vectors"].data.fill_(self.δ) 178 | hidden["least_used_mem"].data.fill_(c + 1) 179 | hidden["usage"].data.fill_(0) 180 | hidden["read_positions"] = cuda(torch.arange(0, c).expand(b, c), device=self.device).long() 181 | 182 | return hidden 183 | 184 | def write_into_sparse_memory(self, hidden: dict) -> dict: 185 | """Writes the visible memory into the sparse memory matrix. 186 | 187 | Args: 188 | hidden: Hidden state dictionary 189 | 190 | Returns: 191 | Updated hidden state dictionary. 192 | """ 193 | visible_memory = hidden["visible_memory"] 194 | positions = hidden["read_positions"] 195 | 196 | (b, m, w) = hidden["memory"].size() 197 | # Create a new tensor for memory to avoid inplace operations during backprop 198 | new_memory = hidden["memory"].clone() 199 | # update memory (using non-inplace operation) 200 | new_memory.scatter_(1, positions.unsqueeze(2).expand(b, self.c, w), visible_memory) 201 | hidden["memory"] = new_memory 202 | 203 | # non-differentiable operations 204 | pos = positions.detach().cpu().numpy() 205 | for batch in range(b): 206 | # update indexes 207 | hidden["indexes"][batch].reset() 208 | hidden["indexes"][batch].add( 209 | hidden["memory"][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None) 210 | ) 211 | 212 | mem_limit_reached = hidden["least_used_mem"][0].detach().cpu().numpy()[0] >= self.mem_size - 1 213 | self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached 214 | 215 | return hidden 216 | 217 | def write( 218 | self, interpolation_gate: torch.Tensor, write_vector: torch.Tensor, write_gate: torch.Tensor, hidden: dict 219 | ) -> dict: 220 | """Performs the memory write operation. 221 | 222 | Args: 223 | interpolation_gate: Interpolation gate. 
224 | write_vector: Write vector. 225 | write_gate: Write gate. 226 | hidden: Hidden state dictionary. 227 | 228 | Returns: 229 | Updated hidden state dictionary. 230 | 231 | """ 232 | 233 | read_weights = hidden["read_weights"].gather(1, hidden["read_positions"]) 234 | # encourage read and write in the first timestep 235 | if self.timestep == 1: 236 | read_weights = read_weights + 1 237 | write_weights = hidden["write_weights"].gather(1, hidden["read_positions"]) 238 | 239 | hidden["usage"], I = self.update_usage(hidden["read_positions"], read_weights, write_weights, hidden["usage"]) 240 | 241 | # either we write to previous read locations 242 | x = interpolation_gate * read_weights 243 | # or to a new location 244 | y = (1 - interpolation_gate) * I 245 | write_weights = write_gate * (x + y) 246 | 247 | # store the write weights (avoid inplace operation) 248 | new_write_weights = hidden["write_weights"].clone() 249 | new_write_weights.scatter_(1, hidden["read_positions"], write_weights) 250 | hidden["write_weights"] = new_write_weights 251 | 252 | # erase matrix 253 | erase_matrix = I.unsqueeze(2).expand(hidden["visible_memory"].size()) 254 | 255 | # write into memory 256 | hidden["visible_memory"] = hidden["visible_memory"] * (1 - erase_matrix) + torch.bmm( 257 | write_weights.unsqueeze(2), write_vector 258 | ) 259 | 260 | hidden = self.write_into_sparse_memory(hidden) 261 | 262 | # update least used memory cell 263 | hidden["least_used_mem"] = torch.topk(hidden["usage"], 1, dim=-1, largest=False)[1] 264 | 265 | return hidden 266 | 267 | def update_usage( 268 | self, read_positions: torch.Tensor, read_weights: torch.Tensor, write_weights: torch.Tensor, usage: torch.Tensor 269 | ) -> tuple[torch.Tensor, torch.Tensor]: 270 | """Updates the usage vector. 271 | 272 | Args: 273 | read_positions: Read positions. 274 | read_weights: Read weights. 275 | write_weights: Write weights. 276 | usage: Usage vector. 277 | 278 | Returns: 279 | Tuple: Updated usage vector and indicator matrix. 280 | """ 281 | (b, _) = read_positions.size() 282 | # usage is timesteps since a non-negligible memory access 283 | u = (read_weights + write_weights > self.δ).float() 284 | 285 | # usage before write 286 | relevant_usages = usage.gather(1, read_positions) 287 | 288 | # indicator of words with minimal memory usage 289 | minusage = torch.min(relevant_usages, -1, keepdim=True)[0] 290 | minusage = minusage.expand(relevant_usages.size()) 291 | I = (relevant_usages == minusage).float() 292 | 293 | # usage after write 294 | relevant_usages = (self.timestep - relevant_usages) * u + relevant_usages * (1 - u) 295 | 296 | # Replace inplace scatter with clone + scatter + assignment 297 | new_usage = usage.clone() 298 | new_usage.scatter_(1, read_positions, relevant_usages) 299 | usage = new_usage 300 | 301 | return usage, I 302 | 303 | def read_from_sparse_memory( 304 | self, memory: torch.Tensor, indexes: list, keys: torch.Tensor, least_used_mem: torch.Tensor, usage: torch.Tensor 305 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 306 | """Reads from sparse memory using indexes. 307 | 308 | Args: 309 | memory: Memory tensor. 310 | indexes: List of indexes. 311 | keys: Read keys. 312 | least_used_mem: Least used memory locations. 313 | usage: Usage vector. 314 | 315 | Returns: 316 | Tuple: Read vectors, read positions, read weights, and visible memory. 
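            (Shape sketch, inferred from the gather/bmm operations below, with
            b = batch, r = read heads, c = visible memory size, w = cell size:
            read_vectors (b, r, w), read_positions (b, c),
            read_weights (b, c), visible_memory (b, c, w).)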
317 | """ 318 | b = keys.size(0) 319 | rpos = [] 320 | 321 | # we search for k cells per read head 322 | for batch in range(b): 323 | distances, positions = indexes[batch].search(keys[batch]) 324 | rpos.append(positions) 325 | read_positions = torch.stack(rpos, 0) 326 | 327 | # add least used mem to read positions 328 | (b, r, k) = read_positions.size() 329 | read_positions = read_positions.squeeze(1).view(b, -1) 330 | 331 | # no gradient here 332 | # temporal reads 333 | (b, m, w) = memory.size() 334 | # Use the memory size as the max length rather than relying on least_used_mem value 335 | # If memory limit is reached, use full memory size minus 1 336 | max_length = (m - 1) if self.mem_limit_reached else min(int(least_used_mem[0, 0].detach().cpu().numpy()), m - 1) 337 | 338 | # differentiable ops 339 | # append forward and backward read positions, might lead to duplicates 340 | read_positions = torch.cat([read_positions, least_used_mem], 1) 341 | read_positions = torch.clamp(read_positions, 0, max_length) 342 | 343 | visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, self.c, w)) 344 | 345 | read_weights = σ(θ(visible_memory, keys), 2) 346 | read_vectors = torch.bmm(read_weights, visible_memory) 347 | read_weights = torch.prod(read_weights, 1) 348 | 349 | return read_vectors, read_positions, read_weights, visible_memory 350 | 351 | def read(self, read_query: torch.Tensor, hidden: dict) -> tuple[torch.Tensor, dict]: 352 | """Performs the memory read operation. 353 | 354 | Args: 355 | read_query: Read query. 356 | hidden: Hidden state dictionary. 357 | 358 | Returns: 359 | Tuple: Read vectors and updated hidden state. 360 | """ 361 | # sparse read 362 | read_vectors, positions, read_weights, visible_memory = self.read_from_sparse_memory( 363 | hidden["memory"], hidden["indexes"], read_query, hidden["least_used_mem"], hidden["usage"] 364 | ) 365 | 366 | hidden["read_positions"] = positions 367 | # Avoid inplace operation 368 | new_read_weights = hidden["read_weights"].clone() 369 | new_read_weights.scatter_(1, positions, read_weights) 370 | hidden["read_weights"] = new_read_weights 371 | hidden["read_vectors"] = read_vectors 372 | hidden["visible_memory"] = visible_memory 373 | 374 | return hidden["read_vectors"], hidden 375 | 376 | def forward(self, ξ: torch.Tensor, hidden: dict) -> tuple[torch.Tensor, dict]: 377 | """Forward pass through the memory. 378 | 379 | Args: 380 | ξ: Input tensor. 381 | hidden: Hidden state dictionary. 382 | 383 | Returns: 384 | Tuple: Read vectors and updated hidden state. 
385 | """ 386 | m = self.mem_size 387 | w = self.cell_size 388 | r = self.read_heads 389 | c = self.c 390 | b = ξ.size(0) 391 | 392 | if self.independent_linears: 393 | # r read keys (b * r * w) 394 | read_query = self.read_query_transform(ξ).view(b, r, w) 395 | # write key (b * 1 * w) 396 | write_vector = self.write_vector_transform(ξ).view(b, 1, w) 397 | # write vector (b * 1 * r) 398 | interpolation_gate = torch.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c) 399 | # write gate (b * 1) 400 | write_gate = torch.sigmoid(self.write_gate_transform(ξ).view(b, 1)) 401 | else: 402 | ξ = self.interface_weights(ξ) 403 | # r read keys (b * r * w) 404 | read_query = ξ[:, : r * w].contiguous().view(b, r, w) 405 | # write key (b * 1 * w) 406 | write_vector = ξ[:, r * w : r * w + w].contiguous().view(b, 1, w) 407 | # write vector (b * 1 * r) 408 | interpolation_gate = torch.sigmoid(ξ[:, r * w + w : r * w + w + c]).contiguous().view(b, c) 409 | # write gate (b * 1) 410 | write_gate = torch.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1) 411 | 412 | self.timestep += 1 413 | hidden = self.write(interpolation_gate, write_vector, write_gate, hidden) 414 | return self.read(read_query, hidden) 415 | -------------------------------------------------------------------------------- /dnc/dnc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence 8 | 9 | from .memory import Memory, MemoryHiddenState 10 | from .sparse_memory import SparseMemory 11 | from .sparse_temporal_memory import SparseTemporalMemory 12 | 13 | from .util import cuda 14 | 15 | # Define controller hidden state type for clarity 16 | ControllerHiddenState = torch.Tensor | tuple[torch.Tensor, torch.Tensor] 17 | DNCHiddenState = tuple[ 18 | list[ControllerHiddenState], 19 | list[MemoryHiddenState], 20 | torch.Tensor, 21 | ] 22 | LayerHiddenState = tuple[ControllerHiddenState, MemoryHiddenState, torch.Tensor | None] 23 | 24 | 25 | class DNC(nn.Module): 26 | """Differentiable neural computer.""" 27 | 28 | def __init__( 29 | self, 30 | input_size: int, 31 | hidden_size: int, 32 | rnn_type: str = "lstm", 33 | num_layers: int = 1, 34 | num_hidden_layers: int = 2, 35 | bias: bool = True, 36 | batch_first: bool = True, 37 | dropout: float = 0, 38 | nr_cells: int = 5, 39 | read_heads: int = 2, 40 | cell_size: int = 10, 41 | nonlinearity: str = "tanh", 42 | independent_linears: bool = False, 43 | share_memory_between_layers: bool = True, 44 | debug: bool = False, 45 | clip: float = 20, 46 | device: torch.device | None = None, 47 | ): 48 | """Create a DNC network. 49 | 50 | Args: 51 | input_size: Input size. 52 | hidden_size: Size of hidden layers. 53 | rnn_type: Type of recurrent cell, can be `rnn`, `gru` and `lstm`. 54 | num_layers: Number of layers of DNC. 55 | num_hidden_layers: Number of layers of RNNs in each DNC layer. 56 | bias: Whether to use bias. 57 | batch_first: If True, then the input and output tensors are provided as `(batch, seq, feature)`. 58 | dropout: Dropout fraction to be applied to each RNN layer. 59 | nr_cells: Size of memory: number of memory cells. 60 | read_heads: Number of read heads that read from memory. 61 | cell_size:Size of memory: size of each cell. 62 | nonlinearity: The non-linearity to use for RNNs, applicable when `rnn_type="rnn"`. 
63 |             independent_linears: Use independent linear modules for memory transform operators.
64 |             share_memory_between_layers: Share one memory module between all layers.
65 |             debug: Run in debug mode.
66 |             clip: Clip controller outputs.
67 |             device: Device (cpu, cuda, cuda:0, ...)
68 |         """
69 |         super(DNC, self).__init__()
70 | 
71 |         self.input_size = input_size
72 |         self.hidden_size = hidden_size
73 |         self.rnn_type = rnn_type
74 |         self.num_layers = num_layers
75 |         self.num_hidden_layers = num_hidden_layers
76 |         self.bias = bias
77 |         self.batch_first = batch_first
78 |         self.dropout = dropout
79 |         self.nr_cells = nr_cells
80 |         self.read_heads = read_heads
81 |         self.cell_size = cell_size
82 |         self.nonlinearity = nonlinearity
83 |         self.independent_linears = independent_linears
84 |         self.share_memory_between_layers = share_memory_between_layers
85 |         self.debug = debug
86 |         self.clip = clip
87 |         self.device = device
88 | 
89 |         self.w = self.cell_size
90 |         self.r = self.read_heads
91 | 
92 |         self.read_vectors_size = self.read_heads * self.cell_size
93 |         self.output_size = self.hidden_size
94 | 
95 |         self.nn_input_size = self.input_size + self.read_vectors_size
96 |         self.nn_output_size = self.output_size + self.read_vectors_size
97 | 
98 |         self.rnns: list[nn.RNN | nn.GRU | nn.LSTM] = []
99 |         self.memories: list[Memory | SparseMemory | SparseTemporalMemory] = []
100 | 
101 |         for layer in range(self.num_layers):
102 |             if self.rnn_type.lower() == "rnn":
103 |                 self.rnns.append(
104 |                     nn.RNN(
105 |                         (self.nn_input_size if layer == 0 else self.nn_output_size),
106 |                         self.output_size,
107 |                         bias=self.bias,
108 |                         nonlinearity=self.nonlinearity,
109 |                         batch_first=True,
110 |                         dropout=self.dropout,
111 |                         num_layers=self.num_hidden_layers,
112 |                     )
113 |                 )
114 |             elif self.rnn_type.lower() == "gru":
115 |                 self.rnns.append(
116 |                     nn.GRU(
117 |                         (self.nn_input_size if layer == 0 else self.nn_output_size),
118 |                         self.output_size,
119 |                         bias=self.bias,
120 |                         batch_first=True,
121 |                         dropout=self.dropout,
122 |                         num_layers=self.num_hidden_layers,
123 |                     )
124 |                 )
125 |             elif self.rnn_type.lower() == "lstm":
126 |                 self.rnns.append(
127 |                     nn.LSTM(
128 |                         (self.nn_input_size if layer == 0 else self.nn_output_size),
129 |                         self.output_size,
130 |                         bias=self.bias,
131 |                         batch_first=True,
132 |                         dropout=self.dropout,
133 |                         num_layers=self.num_hidden_layers,
134 |                     )
135 |                 )
136 |             setattr(self, self.rnn_type.lower() + "_layer_" + str(layer), self.rnns[layer])
137 | 
138 |             # memories for each layer
139 |             if not self.share_memory_between_layers:
140 |                 self.memories.append(
141 |                     Memory(
142 |                         input_size=self.output_size,
143 |                         nr_cells=self.nr_cells,
144 |                         cell_size=self.w,
145 |                         read_heads=self.r,
146 |                         device=self.device,
147 |                         independent_linears=self.independent_linears,
148 |                     )
149 |                 )
150 |                 setattr(self, "rnn_layer_memory_" + str(layer), self.memories[layer])
151 | 
152 |         # only one memory shared by all layers
153 |         if self.share_memory_between_layers:
154 |             self.memories.append(
155 |                 Memory(
156 |                     input_size=self.output_size,
157 |                     nr_cells=self.nr_cells,
158 |                     cell_size=self.w,
159 |                     read_heads=self.r,
160 |                     device=self.device,
161 |                     independent_linears=self.independent_linears,
162 |                 )
163 |             )
164 |             setattr(self, "rnn_layer_memory_shared", self.memories[0])
165 | 
166 |         # final output layer
167 |         self.output = nn.Linear(self.nn_output_size, self.input_size)
168 |         torch.nn.init.kaiming_uniform_(self.output.weight)
169 | 
170 |         if self.device is not None and self.device.type == "cuda":
171 |             self.to(self.device)
172 | 
173 |     def 
_init_hidden(self, hx: DNCHiddenState | None, batch_size: int, reset_experience: bool) -> DNCHiddenState: 174 | """Initializes the hidden states. 175 | 176 | Args: 177 | hx: Existing hidden state or None. 178 | batch_size: Batch size. 179 | reset_experience: Whether to reset memory experience. 180 | 181 | Returns: 182 | Initialized hidden state. 183 | """ 184 | # Parse hidden state components 185 | if hx is not None: 186 | chx, mhx, last_read = hx 187 | else: 188 | chx, mhx, last_read = None, None, None 189 | 190 | # Initialize controller hidden state if needed 191 | if chx is None: 192 | h: torch.Tensor = cuda( 193 | torch.zeros(self.num_hidden_layers, batch_size, self.output_size), 194 | device=self.device, 195 | ) 196 | torch.nn.init.xavier_uniform_(h) 197 | chx = [(h, h) if self.rnn_type.lower() == "lstm" else h for _ in range(self.num_layers)] 198 | 199 | # Initialize last read vectors if needed 200 | if last_read is None: 201 | last_read = cuda(torch.zeros(batch_size, self.w * self.r), device=self.device) 202 | 203 | # Initialize memory states if needed 204 | if mhx is None: 205 | if self.share_memory_between_layers: 206 | mhx = [self.memories[0].reset(batch_size, erase=reset_experience)] 207 | else: 208 | mhx = [m.reset(batch_size, erase=reset_experience) for m in self.memories] 209 | else: 210 | if self.share_memory_between_layers: 211 | if len(mhx) == 0 or mhx[0] is None: 212 | mhx = [self.memories[0].reset(batch_size, erase=reset_experience)] 213 | else: 214 | mhx = [self.memories[0].reset(batch_size, mhx[0], erase=reset_experience)] 215 | else: 216 | if len(mhx) == 0: 217 | mhx = [m.reset(batch_size, erase=reset_experience) for m in self.memories] 218 | else: 219 | new_mhx = [] 220 | for i, m in enumerate(self.memories): 221 | if i < len(mhx) and mhx[i] is not None: 222 | new_mhx.append(m.reset(batch_size, mhx[i], erase=reset_experience)) 223 | else: 224 | new_mhx.append(m.reset(batch_size, erase=reset_experience)) 225 | mhx = new_mhx 226 | 227 | return chx, mhx, last_read 228 | 229 | def _debug( 230 | self, mhx: MemoryHiddenState, debug_obj: dict[str, list[np.ndarray]] | None 231 | ) -> dict[str, list[np.ndarray]] | None: 232 | """Collects debug information. Only returns a debug_obj if self.debug is True. 233 | 234 | Args: 235 | mhx: Memory hidden state. 236 | debug_obj: Debug object containing lists of numpy arrays. 237 | 238 | Returns: 239 | Debug object or None. 240 | """ 241 | if not self.debug: 242 | return None 243 | 244 | if not debug_obj: 245 | debug_obj = { 246 | "memory": [], 247 | "link_matrix": [], 248 | "precedence": [], 249 | "read_weights": [], 250 | "write_weights": [], 251 | "usage_vector": [], 252 | } 253 | 254 | debug_obj["memory"].append(mhx["memory"][0].detach().cpu().numpy()) 255 | debug_obj["link_matrix"].append(mhx["link_matrix"][0][0].detach().cpu().numpy()) 256 | debug_obj["precedence"].append(mhx["precedence"][0].detach().cpu().numpy()) 257 | debug_obj["read_weights"].append(mhx["read_weights"][0].detach().cpu().numpy()) 258 | debug_obj["write_weights"].append(mhx["write_weights"][0].detach().cpu().numpy()) 259 | debug_obj["usage_vector"].append(mhx["usage_vector"][0].unsqueeze(0).detach().cpu().numpy()) 260 | return debug_obj 261 | 262 | def _layer_forward( 263 | self, 264 | input: torch.Tensor, 265 | layer: int, 266 | hx: LayerHiddenState, 267 | pass_through_memory: bool = True, 268 | ) -> tuple[torch.Tensor, LayerHiddenState]: 269 | """Performs a forward pass through a single layer. 270 | 271 | Args: 272 | input : Input tensor. 
273 | layer: Layer index. 274 | hx: Hidden state for the layer. 275 | pass_through_memory: Whether to pass the input through memory. 276 | 277 | Returns: 278 | Tuple: Output, and updated hidden state. 279 | """ 280 | (chx, mhx, _) = hx 281 | 282 | # pass through the controller layer 283 | input, chx = self.rnns[layer](input.unsqueeze(1), chx) 284 | input = input.squeeze(1) # Remove the sequence length dimension (always 1) 285 | 286 | # clip the controller output 287 | if self.clip != 0: 288 | output = torch.clamp(input, -self.clip, self.clip) 289 | else: 290 | output = input 291 | 292 | # the interface vector 293 | ξ = output 294 | 295 | # pass through memory 296 | if pass_through_memory: 297 | if self.share_memory_between_layers: 298 | read_vecs, mhx = self.memories[0](ξ, mhx) 299 | else: 300 | read_vecs, mhx = self.memories[layer](ξ, mhx) 301 | # the read vectors 302 | read_vectors = read_vecs.view(-1, self.w * self.r) 303 | else: 304 | # Initialize read vectors with zeros when not passing through memory 305 | read_vectors = cuda(torch.zeros(ξ.size(0), self.w * self.r), device=self.device) 306 | 307 | return output, (chx, mhx, read_vectors) 308 | 309 | def forward( 310 | self, 311 | input_data: torch.Tensor | PackedSequence, 312 | hx: DNCHiddenState | None, 313 | reset_experience: bool = False, 314 | pass_through_memory: bool = True, 315 | ) -> ( 316 | tuple[torch.Tensor | PackedSequence, DNCHiddenState] 317 | | tuple[torch.Tensor | PackedSequence, DNCHiddenState, dict[str, Any]] 318 | ): 319 | """Performs a forward pass through the DNC. 320 | 321 | Args: 322 | input_data: Input tensor or PackedSequence. 323 | hx: Hidden state or None. 324 | reset_experience: Whether to reset memory experience. 325 | pass_through_memory: Whether to pass the input through memory. 326 | 327 | Returns: 328 | Tuple: Output (same type as input_data), updated hidden state, and optionally debug information. 
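
        Example (a sketch, assuming a batch-first input `x` of shape
        (batch, seq, input_size) and debug=False, so a 2-tuple is returned):
            >>> rnn = DNC(input_size=64, hidden_size=128, nr_cells=16, cell_size=16, read_heads=4)
            >>> output, (chx, mhx, last_read) = rnn(x, (None, None, None), reset_experience=True)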
329 | 
330 |         """
331 |         max_length: int
332 |         # handle packed data
333 |         if isinstance(input_data, PackedSequence):
334 |             input, lengths = pad_packed_sequence(input_data, batch_first=self.batch_first)
   |             batch_size = input.size(0) if self.batch_first else input.size(1)
335 |             max_length = int(lengths.max().item())
336 |         elif isinstance(input_data, torch.Tensor):
337 |             input = input_data
338 |             batch_size = input.size(0) if self.batch_first else input.size(1)
339 |             max_length = input.size(1) if self.batch_first else input.size(0)
340 |             lengths = torch.tensor([max_length] * batch_size, device=input.device)
341 | 
342 |         else:
343 |             raise TypeError("input_data must be a PackedSequence or Tensor")
344 | 
345 |         if not self.batch_first:
346 |             input = input.transpose(0, 1)
347 |             # make the data batch-first for the per-timestep loop below
348 | 
349 |         controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)
350 | 
351 |         # last_read is guaranteed to be initialized by _init_hidden, so no need to check for None
352 | 
353 |         inputs = [torch.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
354 | 
355 |         # batched forward pass per element / word / etc
356 |         if self.debug:
357 |             viz: dict[str, Any] | None = None
358 | 
359 |         outs: list[torch.Tensor | None] = [None] * max_length
360 |         read_vectors: torch.Tensor | None = None
361 | 
362 |         # pass through time
363 |         for time in range(max_length):
364 |             # pass through layers
365 |             for layer in range(self.num_layers):
366 |                 # this layer's hidden states
367 |                 chx_layer = controller_hidden[layer]
368 |                 mem_layer = mem_hidden[0] if self.share_memory_between_layers else mem_hidden[layer]
369 | 
370 |                 # pass through controller
371 |                 outs[time], (
372 |                     chx_layer_output,
373 |                     mem_layer_output,
374 |                     read_vectors,
375 |                 ) = self._layer_forward(
376 |                     inputs[time], layer, (chx_layer, mem_layer, read_vectors), pass_through_memory  # type: ignore
377 |                 )
378 | 
379 |                 # debug memory
380 |                 if self.debug:
381 |                     viz = self._debug(mem_layer_output, viz)
382 | 
383 |                 # store the memory back (per layer or shared)
384 |                 if self.share_memory_between_layers:
385 |                     mem_hidden[0] = mem_layer_output  # type: ignore
386 |                 else:
387 |                     mem_hidden[layer] = mem_layer_output  # type: ignore
388 |                 controller_hidden[layer] = chx_layer_output
389 | 
390 |                 if read_vectors is not None:
391 |                     # the controller output + read vectors go into next layer
392 |                     outs[time] = torch.cat([outs[time], read_vectors], 1)  # type: ignore
393 |                 else:
394 |                     outs[time] = torch.cat([outs[time], last_read], 1)  # type: ignore
395 |                 inputs[time] = outs[time]  # type: ignore
396 | 
397 |         if self.debug and viz:
398 |             viz = {k: [np.array(v) for v in vs] for k, vs in viz.items()}
399 |             viz = {k: [v.reshape(v.shape[0], -1) for v in vs] for k, vs in viz.items()}
400 | 
401 |         # pass through final output layer
402 |         inputs_tensor = torch.stack(inputs, 1)  # stack along the time dimension: (batch, time, features)
403 |         outputs = self.output(inputs_tensor)
404 | 
405 |         if not self.batch_first:
406 |             outputs = outputs.transpose(0, 1)
407 | 
408 |         if isinstance(input_data, PackedSequence):
409 |             outputs = pack_padded_sequence(outputs, lengths.cpu(), batch_first=self.batch_first, enforce_sorted=False)
410 | 
411 |         if self.debug:
412 |             return outputs, (controller_hidden, mem_hidden, read_vectors), viz  # type: ignore
413 |         else:
414 |             return outputs, (controller_hidden, mem_hidden, read_vectors)  # type: ignore
415 | 
416 |     def __repr__(self) -> str:
417 |         """Provides a string representation of the DNC module."""
418 | 
419 |         s = "\n----------------------------------------\n"
420 |         s += "{name}({input_size}, {hidden_size}"
421 |         if self.rnn_type != "lstm":
422 |             s += 
423 |         if self.num_layers != 1:
424 |             s += ", num_layers={num_layers}"
425 |         if self.num_hidden_layers != 2:
426 |             s += ", num_hidden_layers={num_hidden_layers}"
427 |         if not self.bias:
428 |             s += ", bias={bias}"
429 |         if not self.batch_first:
430 |             s += ", batch_first={batch_first}"
431 |         if self.dropout != 0:
432 |             s += ", dropout={dropout}"
433 |         if self.nr_cells != 5:
434 |             s += ", nr_cells={nr_cells}"
435 |         if self.read_heads != 2:
436 |             s += ", read_heads={read_heads}"
437 |         if self.cell_size != 10:
438 |             s += ", cell_size={cell_size}"
439 |         if self.nonlinearity != "tanh":
440 |             s += ", nonlinearity={nonlinearity}"
441 |         if self.independent_linears:
442 |             s += ", independent_linears={independent_linears}"
443 |         if not self.share_memory_between_layers:
444 |             s += ", share_memory_between_layers={share_memory_between_layers}"
445 |         if self.debug:
446 |             s += ", debug={debug}"
447 |         if self.clip != 20:
448 |             s += ", clip={clip}"
449 |         if self.device:
450 |             s += f", device='{self.device}'"
451 | 
452 |         s += ")\n" + super(DNC, self).__repr__() + "\n----------------------------------------\n"
453 |         return s.format(name=self.__class__.__name__, **self.__dict__)
454 | 
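455 | 
456 | if __name__ == "__main__":
457 |     # A minimal, illustrative usage sketch, not part of the library API:
458 |     # it assumes the constructor keywords referenced in __repr__ above and
459 |     # the upstream convention of passing (None, None, None) as the initial
460 |     # hidden state. Sizes are arbitrary.
461 |     rnn = DNC(input_size=64, hidden_size=128, rnn_type="lstm", nr_cells=5, cell_size=10, read_heads=2, batch_first=True)
462 |     x = torch.randn(4, 10, 64)  # (batch, seq, input_size)
463 |     out, (chx, mhx, rv) = rnn(x, (None, None, None), reset_experience=True)
464 |     print(out.shape)  # batch-first output: (batch, seq, features)
465 | 
466 |     # a PackedSequence input round-trips to a PackedSequence output
467 |     packed = pack_padded_sequence(x, torch.tensor([10, 8, 6, 3]), batch_first=True, enforce_sorted=False)
468 |     packed_out, _ = rnn(packed, (None, None, None), reset_experience=True)
469 | 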
--------------------------------------------------------------------------------
/dnc/sparse_temporal_memory.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | # -*- coding: utf-8 -*-
  3 | 
  4 | 
  5 | import torch
  6 | import torch.nn as nn
  7 | 
  8 | from .util import cuda, σ, θ
  9 | 
 10 | 
 11 | class SparseTemporalMemory(nn.Module):
 12 |     """Sparse Temporal Memory module."""
 13 | 
 14 |     def __init__(
 15 |         self,
 16 |         input_size: int,
 17 |         mem_size: int = 512,
 18 |         cell_size: int = 32,
 19 |         independent_linears: bool = True,
 20 |         read_heads: int = 4,
 21 |         sparse_reads: int = 4,
 22 |         temporal_reads: int = 4,
 23 |         num_lists: int | None = None,
 24 |         index_checks: int = 32,
 25 |         device: torch.device | None = None,
 26 |     ):
 27 |         """Initialize SparseTemporalMemory.
 28 | 
 29 |         Args:
 30 |             input_size: Input size.
 31 |             mem_size: Memory size.
 32 |             cell_size: Size of each memory cell.
 33 |             independent_linears: Whether to use independent linear layers.
 34 |             read_heads: Number of read heads.
 35 |             sparse_reads: Number of sparse reads.
 36 |             temporal_reads: Number of temporal reads.
 37 |             num_lists: Number of lists for indexing.
 38 |             index_checks: Number of index checks.
 39 |             device: PyTorch device.
 40 | 
 41 |         """
 42 |         super(SparseTemporalMemory, self).__init__()
 43 | 
 44 |         self.mem_size = mem_size
 45 |         self.cell_size = cell_size
 46 |         self.device = device
 47 |         self.input_size = input_size
 48 |         self.independent_linears = independent_linears
 49 |         self.K = sparse_reads if self.mem_size > sparse_reads else self.mem_size
 50 |         self.KL = temporal_reads if self.mem_size > temporal_reads else self.mem_size
 51 |         self.read_heads = read_heads
 52 |         self.num_lists = num_lists if num_lists is not None else int(self.mem_size / 100)
 53 |         self.index_checks = index_checks
 54 | 
 55 |         m = self.mem_size
 56 |         w = self.cell_size
 57 |         r = self.read_heads
 58 |         # The visible memory size: K cells per read head (r * K), forward and
 59 |         # backward temporal reads (KL * 2), plus the least used memory cell
 60 |         self.c = (r * self.K) + (self.KL * 2) + 1
 61 | 
 62 |         if self.independent_linears:
 63 |             self.read_query_transform = nn.Linear(self.input_size, w * r)
 64 |             self.write_vector_transform = nn.Linear(self.input_size, w)
 65 |             self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
 66 |             self.write_gate_transform = nn.Linear(self.input_size, 1)
 67 |             torch.nn.init.kaiming_uniform_(self.read_query_transform.weight)
 68 |             torch.nn.init.kaiming_uniform_(self.write_vector_transform.weight)
 69 |             torch.nn.init.kaiming_uniform_(self.interpolation_gate_transform.weight)
 70 |             torch.nn.init.kaiming_uniform_(self.write_gate_transform.weight)
 71 |         else:
 72 |             self.interface_size = (r * w) + w + self.c + 1
 73 |             self.interface_weights = nn.Linear(self.input_size, self.interface_size)
 74 |             torch.nn.init.kaiming_uniform_(self.interface_weights.weight)
 75 | 
 76 |         self.I = cuda(1 - torch.eye(self.c).unsqueeze(0), device=self.device)  # (1 * c * c)
 77 |         self.δ = 0.005  # minimum usage
 78 |         self.timestep = 0
 79 |         self.mem_limit_reached = False
 80 | 
 81 |         if self.device is not None and self.device.type == "cuda":
 82 |             self.to(self.device)
 83 | 
 84 |     def rebuild_indexes(self, hidden: dict, erase: bool = False) -> dict:
 85 |         """Rebuilds the indexes for sparse memory access.
 86 | 
 87 |         Args:
 88 |             hidden: Hidden state dictionary.
 89 |             erase: Whether to erase the existing memory content.
 90 | 
 91 |         Returns:
 92 |             Updated hidden state dictionary.
 93 |         """
 94 |         b = hidden["memory"].size(0)
 95 | 
 96 |         # if indexes already exist, we reset them
 97 |         if "indexes" in hidden:
 98 |             for x in hidden["indexes"]:
 99 |                 x.reset()
100 |         else:
101 |             # create new indexes
102 |             try:
103 |                 from .faiss_index import FAISSIndex
104 | 
105 |                 hidden["indexes"] = [
106 |                     FAISSIndex(
107 |                         cell_size=self.cell_size,
108 |                         nr_cells=self.mem_size,
109 |                         K=self.K,
110 |                         num_lists=self.num_lists,
111 |                         probes=self.index_checks,
112 |                         device=self.device,
113 |                     )
114 |                     for _ in range(b)
115 |                 ]
116 |             except ImportError:
117 |                 print(
118 |                     "FAISS not found. Please install FAISS; see https://github.com/facebookresearch/faiss/blob/main/INSTALL.md"
119 |                 )
120 |                 raise
121 | 
122 |         # add existing memory into indexes
123 |         pos = hidden["read_positions"].detach().cpu().numpy()  # (b, c); squeeze() would collapse a batch of size 1
124 |         if not erase:
125 |             for n, i in enumerate(hidden["indexes"]):
126 |                 i.reset()
127 |                 i.add(hidden["memory"][n], last=pos[n][-1])
128 |         else:
129 |             self.timestep = 0
130 |             self.mem_limit_reached = False
131 | 
132 |         return hidden
133 | 
134 |     def reset(self, batch_size: int = 1, hidden: dict | None = None, erase: bool = True) -> dict:
135 |         """Resets the memory and hidden state.
136 | 
137 |         Args:
138 |             batch_size: Batch size.
139 |             hidden: Hidden state dictionary.
140 |             erase: Whether to erase the existing memory content.
141 | 
142 |         Returns:
143 |             Reset hidden state dictionary.
144 |         """
145 |         m = self.mem_size
146 |         w = self.cell_size
147 |         b = batch_size
148 |         r = self.read_heads
149 |         c = self.c
150 | 
151 |         if hidden is None:
152 |             hidden = {
153 |                 # warning: can be a huge chunk of contiguous memory
154 |                 "memory": cuda(torch.zeros(b, m, w).fill_(self.δ), device=self.device),
155 |                 "visible_memory": cuda(torch.zeros(b, c, w).fill_(self.δ), device=self.device),
156 |                 "link_matrix": cuda(torch.zeros(b, m, self.KL * 2), device=self.device),
157 |                 "rev_link_matrix": cuda(torch.zeros(b, m, self.KL * 2), device=self.device),
158 |                 "precedence": cuda(torch.zeros(b, self.KL * 2).fill_(self.δ), device=self.device),
159 |                 "read_weights": cuda(torch.zeros(b, m).fill_(self.δ), device=self.device),
160 |                 "write_weights": cuda(torch.zeros(b, m).fill_(self.δ), device=self.device),
161 |                 "read_vectors": cuda(torch.zeros(b, r, w).fill_(self.δ), device=self.device),
162 |                 "least_used_mem": cuda(torch.zeros(b, 1).fill_(c + 1), device=self.device).long(),
163 |                 "usage": cuda(torch.zeros(b, m), device=self.device),
164 |                 "read_positions": cuda(torch.arange(0, c).expand(b, c), device=self.device).long(),
165 |             }
166 |             hidden = self.rebuild_indexes(hidden, erase=True)
167 |         else:
168 |             # clone so the returned hidden state does not alias the caller's tensors
169 |             hidden["memory"] = hidden["memory"].clone()
170 |             hidden["visible_memory"] = hidden["visible_memory"].clone()
171 |             hidden["link_matrix"] = hidden["link_matrix"].clone()
172 |             hidden["rev_link_matrix"] = hidden["rev_link_matrix"].clone()
173 |             hidden["precedence"] = hidden["precedence"].clone()
174 |             hidden["read_weights"] = hidden["read_weights"].clone()
175 |             hidden["write_weights"] = hidden["write_weights"].clone()
176 |             hidden["read_vectors"] = hidden["read_vectors"].clone()
177 |             hidden["least_used_mem"] = hidden["least_used_mem"].clone()
178 |             hidden["usage"] = hidden["usage"].clone()
179 |             hidden["read_positions"] = hidden["read_positions"].clone()
180 |             hidden = self.rebuild_indexes(hidden, erase)
181 | 
182 |         if erase:
183 |             hidden["memory"].data.fill_(self.δ)
184 |             hidden["visible_memory"].data.fill_(self.δ)
185 |             hidden["link_matrix"].data.zero_()
186 |             hidden["rev_link_matrix"].data.zero_()
187 |             hidden["precedence"].data.zero_()
188 |             hidden["read_weights"].data.fill_(self.δ)
189 |             hidden["write_weights"].data.fill_(self.δ)
190 |             hidden["read_vectors"].data.fill_(self.δ)
191 |             hidden["least_used_mem"].data.fill_(c + 1)
192 |             hidden["usage"].data.fill_(0)
193 |             hidden["read_positions"] = cuda(torch.arange(0, c).expand(b, c), device=self.device).long()
194 | 
195 |         return hidden
196 | 
197 |     def write_into_sparse_memory(self, hidden: dict) -> dict:
198 |         """Writes the visible memory into the sparse memory matrix.
199 | 
200 |         Args:
201 |             hidden: Hidden state dictionary.
202 | 
203 |         Returns:
204 |             Updated hidden state dictionary.
205 |         """
206 |         visible_memory = hidden["visible_memory"]
207 |         positions = hidden["read_positions"]
208 | 
209 |         (b, m, w) = hidden["memory"].size()
210 |         # scatter into a clone rather than into the original tensor, so that
211 |         # backprop never sees an in-place modification of hidden["memory"]
212 |         new_memory = hidden["memory"].clone()
213 |         new_memory.scatter_(1, positions.unsqueeze(2).expand(b, self.c, w), visible_memory)
214 |         hidden["memory"] = new_memory
215 | 
216 |         # non-differentiable operations
217 |         pos = positions.detach().cpu().numpy()
218 |         for batch in range(b):
219 |             # update indexes
220 |             hidden["indexes"][batch].reset()
221 |             hidden["indexes"][batch].add(
222 |                 hidden["memory"][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None)
223 |             )
224 | 
225 |         mem_limit_reached = hidden["least_used_mem"][0].detach().cpu().numpy()[0] >= self.mem_size - 1
226 |         self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached
227 | 
228 |         return hidden
229 | 
230 |     def update_link_matrices(
231 |         self,
232 |         link_matrix: torch.Tensor,
233 |         rev_link_matrix: torch.Tensor,
234 |         write_weights: torch.Tensor,
235 |         precedence: torch.Tensor,
236 |         temporal_read_positions: torch.Tensor,
237 |     ) -> tuple[torch.Tensor, torch.Tensor]:
238 |         """Updates the forward and backward link matrices.
239 | 
240 |         Args:
241 |             link_matrix: Forward link matrix.
242 |             rev_link_matrix: Backward link matrix.
243 |             write_weights: Write weights.
244 |             precedence: Precedence vector.
245 |             temporal_read_positions: Temporal read positions.
246 | 
247 |         Returns:
248 |             Tuple: Updated forward and backward link matrices.
249 |         """
250 |         write_weights_i = write_weights.unsqueeze(2)
251 |         precedence_j = precedence.unsqueeze(1)
252 | 
253 |         (b, m, k) = link_matrix.size()
254 |         I = cuda(torch.eye(m, k).unsqueeze(0).expand((b, m, k)), device=self.device)
255 | 
256 |         # only KL*2 precedence entries are kept (sparse); expand them into a dense vector
257 |         precedence_dense = cuda(torch.zeros(b, m), device=self.device)
258 |         # non-inplace pattern: clone, scatter into the clone, then assign
259 |         precedence_tmp = precedence_dense.clone()
260 |         precedence_tmp.scatter_(1, temporal_read_positions, precedence)
261 |         precedence_dense = precedence_tmp
262 |         precedence_dense_i = precedence_dense.unsqueeze(2)
263 | 
264 |         temporal_write_weights_j = write_weights.gather(1, temporal_read_positions).unsqueeze(1)
265 | 
266 |         link_matrix = (1 - write_weights_i) * link_matrix + write_weights_i * precedence_j
267 | 
268 |         rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (
269 |             temporal_write_weights_j * precedence_dense_i
270 |         )
271 | 
272 |         return link_matrix * I, rev_link_matrix * I
273 | 
274 |     def update_precedence(self, precedence: torch.Tensor, write_weights: torch.Tensor) -> torch.Tensor:
275 |         """Updates the precedence vector.
276 | 
277 |         Args:
278 |             precedence: Precedence vector.
279 |             write_weights: Write weights.
280 | 
281 |         Returns:
282 |             Updated precedence vector.
283 | 
284 |         """
285 |         return (1 - torch.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
286 | 
287 |     def write(
288 |         self, interpolation_gate: torch.Tensor, write_vector: torch.Tensor, write_gate: torch.Tensor, hidden: dict
289 |     ) -> dict:
290 |         """Performs the memory write operation.
291 | 
292 |         Args:
293 |             interpolation_gate: Interpolation gate.
294 |             write_vector: Write vector.
295 |             write_gate: Write gate.
296 |             hidden: Hidden state dictionary.
297 | 
298 |         Returns:
299 |             Updated hidden state dictionary.
300 | 301 | """ 302 | 303 | read_weights = hidden["read_weights"].gather(1, hidden["read_positions"]) 304 | # encourage read and write in the first timestep 305 | if self.timestep == 1: 306 | read_weights = read_weights + 1 307 | write_weights = hidden["write_weights"].gather(1, hidden["read_positions"]) 308 | 309 | hidden["usage"], I = self.update_usage(hidden["read_positions"], read_weights, write_weights, hidden["usage"]) 310 | 311 | # either we write to previous read locations 312 | x = interpolation_gate * read_weights 313 | # or to a new location 314 | y = (1 - interpolation_gate) * I 315 | write_weights = write_gate * (x + y) 316 | 317 | # store the write weights (avoid inplace operations) 318 | new_write_weights = hidden["write_weights"].clone() 319 | new_write_weights.scatter_(1, hidden["read_positions"], write_weights) 320 | hidden["write_weights"] = new_write_weights 321 | 322 | # erase matrix 323 | erase_matrix = I.unsqueeze(2).expand(hidden["visible_memory"].size()) 324 | 325 | # write into memory 326 | hidden["visible_memory"] = hidden["visible_memory"] * (1 - erase_matrix) + torch.bmm( 327 | write_weights.unsqueeze(2), write_vector 328 | ) 329 | hidden = self.write_into_sparse_memory(hidden) 330 | 331 | # update link_matrix and precedence 332 | (b, _) = write_weights.size() # c 333 | 334 | # update link matrix 335 | temporal_read_positions = hidden["read_positions"][:, self.read_heads * self.K + 1 :] 336 | hidden["link_matrix"], hidden["rev_link_matrix"] = self.update_link_matrices( 337 | hidden["link_matrix"], 338 | hidden["rev_link_matrix"], 339 | hidden["write_weights"], 340 | hidden["precedence"], 341 | temporal_read_positions, 342 | ) 343 | 344 | # update precedence vector 345 | read_weights = hidden["read_weights"].gather(1, temporal_read_positions) 346 | hidden["precedence"] = self.update_precedence(hidden["precedence"], read_weights) 347 | 348 | # update least used memory cell 349 | hidden["least_used_mem"] = torch.topk(hidden["usage"], 1, dim=-1, largest=False)[1] 350 | 351 | return hidden 352 | 353 | def update_usage( 354 | self, read_positions: torch.Tensor, read_weights: torch.Tensor, write_weights: torch.Tensor, usage: torch.Tensor 355 | ) -> tuple[torch.Tensor, torch.Tensor]: 356 | """Updates the usage vector. 357 | 358 | Args: 359 | read_positions: Read positions. 360 | read_weights: Read weights. 361 | write_weights: Write weights. 362 | usage: Usage vector. 363 | 364 | Returns: 365 | Tuple: Updated usage vector and indicator matrix. 366 | """ 367 | # usage is timesteps since a non-negligible memory access 368 | u = (read_weights + write_weights > self.δ).float() 369 | 370 | # usage before write 371 | relevant_usages = usage.gather(1, read_positions) 372 | 373 | # indicator of words with minimal memory usage 374 | minusage = torch.min(relevant_usages, -1, keepdim=True)[0] 375 | minusage = minusage.expand(relevant_usages.size()) 376 | I = (relevant_usages == minusage).float() 377 | 378 | # usage after write 379 | relevant_usages = (self.timestep - relevant_usages) * u + relevant_usages * (1 - u) 380 | 381 | # Replace inplace scatter with clone + scatter + assignment 382 | new_usage = usage.clone() 383 | new_usage.scatter_(1, read_positions, relevant_usages) 384 | usage = new_usage 385 | 386 | return usage, I 387 | 388 | def directional_weightings( 389 | self, link_matrix: torch.Tensor, rev_link_matrix: torch.Tensor, temporal_read_weights: torch.Tensor 390 | ) -> tuple[torch.Tensor, torch.Tensor]: 391 | """Calculates the forward and backward weightings. 
392 | 393 | Args: 394 | link_matrix: Forward link matrix. 395 | rev_link_matrix: Backward link matrix. 396 | temporal_read_weights: Temporal read weights. 397 | 398 | Returns: 399 | Tuple: Forward and backward weightings. 400 | """ 401 | f = torch.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2) 402 | b = torch.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2) 403 | return f, b 404 | 405 | def read_from_sparse_memory( 406 | self, 407 | memory: torch.Tensor, 408 | indexes: list, 409 | keys: torch.Tensor, 410 | least_used_mem: torch.Tensor, 411 | usage: torch.Tensor, 412 | forward: torch.Tensor, 413 | backward: torch.Tensor, 414 | prev_read_positions: torch.Tensor, 415 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 416 | """Reads from sparse memory using indexes. 417 | 418 | Args: 419 | memory: Memory tensor. 420 | indexes: List of indexes. 421 | keys: Read keys. 422 | least_used_mem: Least used memory locations. 423 | usage: Usage vector. 424 | forward: Forward weightings. 425 | backward: Backward weightings. 426 | prev_read_positions: Previous read positions. 427 | 428 | Returns: 429 | Tuple: Read vectors, read positions, read weights, and visible memory. 430 | """ 431 | b = keys.size(0) 432 | rpos = [] 433 | 434 | # we search for k cells per read head 435 | for batch in range(b): 436 | distances, positions = indexes[batch].search(keys[batch]) 437 | rpos.append(positions) 438 | read_positions = torch.stack(rpos, 0) 439 | 440 | (b, r, k) = read_positions.size() 441 | read_positions = read_positions.squeeze(1).view(b, -1) 442 | 443 | # no gradient here 444 | # temporal reads 445 | (b, m, w) = memory.size() 446 | # get the top KL entries 447 | max_length = int(least_used_mem[0, 0].detach().cpu().numpy()) if not self.mem_limit_reached else (m - 1) 448 | 449 | _, fp = torch.topk(forward, self.KL, largest=True) 450 | _, bp = torch.topk(backward, self.KL, largest=True) 451 | 452 | # differentiable ops 453 | # append forward and backward read positions, might lead to duplicates 454 | read_positions = torch.cat([read_positions, fp, bp], 1) 455 | read_positions = torch.cat([read_positions, least_used_mem], 1) 456 | read_positions = torch.clamp(read_positions, 0, max_length) 457 | 458 | visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, self.c, w)) 459 | 460 | read_weights = σ(θ(visible_memory, keys), 2) 461 | read_vectors = torch.bmm(read_weights, visible_memory) 462 | read_weights = torch.prod(read_weights, 1) 463 | 464 | return read_vectors, read_positions, read_weights, visible_memory 465 | 466 | def read(self, read_query: torch.Tensor, hidden: dict) -> tuple[torch.Tensor, dict]: 467 | """Performs the memory read operation. 468 | 469 | Args: 470 | read_query: Read query. 471 | hidden: Hidden state dictionary. 472 | 473 | Returns: 474 | Tuple: Read vectors and updated hidden state. 
475 |         """
476 |         # get forward and backward weights
477 |         temporal_read_positions = hidden["read_positions"][:, self.read_heads * self.K + 1 :]
478 |         read_weights = hidden["read_weights"].gather(1, temporal_read_positions)
479 |         forward, backward = self.directional_weightings(hidden["link_matrix"], hidden["rev_link_matrix"], read_weights)
480 | 
481 |         # sparse read
482 |         read_vectors, positions, read_weights, visible_memory = self.read_from_sparse_memory(
483 |             hidden["memory"],
484 |             hidden["indexes"],
485 |             read_query,
486 |             hidden["least_used_mem"],
487 |             hidden["usage"],
488 |             forward,
489 |             backward,
490 |             hidden["read_positions"],
491 |         )
492 | 
493 |         hidden["read_positions"] = positions
494 |         # Avoid inplace operation
495 |         new_read_weights = hidden["read_weights"].clone()
496 |         new_read_weights.scatter_(1, positions, read_weights)
497 |         hidden["read_weights"] = new_read_weights
498 |         hidden["read_vectors"] = read_vectors
499 |         hidden["visible_memory"] = visible_memory
500 | 
501 |         return hidden["read_vectors"], hidden
502 | 
503 |     def forward(self, ξ: torch.Tensor, hidden: dict) -> tuple[torch.Tensor, dict]:
504 |         """Forward pass through the memory.
505 | 
506 |         Args:
507 |             ξ: Input tensor.
508 |             hidden: Hidden state dictionary.
509 | 
510 |         Returns:
511 |             Tuple: Read vectors and updated hidden state.
512 |         """
513 |         m = self.mem_size
514 |         w = self.cell_size
515 |         r = self.read_heads
516 |         c = self.c
517 |         b = ξ.size(0)
518 | 
519 |         if self.independent_linears:
520 |             # r read keys (b * r * w)
521 |             read_query = self.read_query_transform(ξ).view(b, r, w)
522 |             # write vector (b * 1 * w)
523 |             write_vector = self.write_vector_transform(ξ).view(b, 1, w)
524 |             # interpolation gate (b * c)
525 |             interpolation_gate = torch.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
526 |             # write gate (b * 1)
527 |             write_gate = torch.sigmoid(self.write_gate_transform(ξ).view(b, 1))
528 |         else:
529 |             ξ = self.interface_weights(ξ)
530 |             # r read keys (b * r * w)
531 |             read_query = ξ[:, : r * w].contiguous().view(b, r, w)
532 |             # write vector (b * 1 * w)
533 |             write_vector = ξ[:, r * w : r * w + w].contiguous().view(b, 1, w)
534 |             # interpolation gate (b * c)
535 |             interpolation_gate = torch.sigmoid(ξ[:, r * w + w : r * w + w + c]).contiguous().view(b, c)
536 |             # write gate (b * 1)
537 |             write_gate = torch.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
538 | 
539 |         self.timestep += 1
540 |         hidden = self.write(interpolation_gate, write_vector, write_gate, hidden)
541 |         return self.read(read_query, hidden)
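542 | 
543 | 
544 | if __name__ == "__main__":
545 |     # A minimal, illustrative sketch, not part of the library API: drive a
546 |     # single write/read cycle through the memory. It requires FAISS for the
547 |     # sparse index; shapes follow reset() and forward() above, and the
548 |     # sizes here are arbitrary.
549 |     mem = SparseTemporalMemory(input_size=64, mem_size=512, cell_size=32, read_heads=4)
550 |     hidden = mem.reset(batch_size=2)
551 |     read_vectors, hidden = mem(torch.randn(2, 64), hidden)
552 |     print(read_vectors.shape)  # expected: (2, read_heads, cell_size)
553 | 
--------------------------------------------------------------------------------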