├── README.rst
├── tasks
│   ├── __init__.py
│   ├── adding_task.py
│   ├── argmax_task.py
│   └── copy_task.py
├── docs
│   ├── dnc.png
│   ├── dnc0.png
│   ├── dnc1.png
│   ├── dnc2.png
│   ├── dnc-mem-debug.png
│   └── dnc.xml
├── release.sh
├── .gitignore
├── .vscode
│   └── settings.json
├── dnc
│   ├── __init__.py
│   ├── faiss_index.py
│   ├── memory.py
│   ├── sam.py
│   ├── sdnc.py
│   ├── util.py
│   ├── sparse_memory.py
│   ├── dnc.py
│   └── sparse_temporal_memory.py
├── .travis.yml
├── setup.cfg
├── LICENSE
├── requirements.txt
├── test
│   ├── test_utils.py
│   ├── test_rnn.py
│   ├── test_lstm.py
│   ├── test_gru.py
│   ├── test_sam_gru.py
│   ├── test_sam_rnn.py
│   ├── test_sam_lstm.py
│   ├── test_sdnc_rnn.py
│   ├── test_sdnc_gru.py
│   └── test_sdnc_lstm.py
├── .pre-commit-config.yaml
├── setup.py
├── CHANGELOG.md
└── scripts
    └── build_faiss.sh
/README.rst:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tasks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/dnc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixaxaar/pytorch-dnc/HEAD/docs/dnc.png
--------------------------------------------------------------------------------
/docs/dnc0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixaxaar/pytorch-dnc/HEAD/docs/dnc0.png
--------------------------------------------------------------------------------
/docs/dnc1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixaxaar/pytorch-dnc/HEAD/docs/dnc1.png
--------------------------------------------------------------------------------
/docs/dnc2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixaxaar/pytorch-dnc/HEAD/docs/dnc2.png
--------------------------------------------------------------------------------
/docs/dnc-mem-debug.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ixaxaar/pytorch-dnc/HEAD/docs/dnc-mem-debug.png
--------------------------------------------------------------------------------
/release.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | rm -rf dist/
4 | python3 setup.py sdist
5 | python3 setup.py bdist_wheel
6 | twine upload dist/*
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pt
2 | __pycache__/
3 | .pypirc
4 | pred.txt
5 | multi-bleu.perl
6 | *.pyc
7 | #.*
8 | .idea
9 | *.sublime-*
10 | .DS_Store
11 | data/
12 | build/
13 | venv/
14 | *.lang
15 | *.log
16 | .cache/
17 | dist/
18 | dnc.egg-info/
19 | tasks/checkpoints/
20 | faiss/
21 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "editor.tokenColorCustomizations": {
3 | "comments": "",
4 | "textMateRules": []
5 | },
6 | "python.formatting.provider": "black",
7 | "python.linting.enabled": true,
8 | "python.linting.flake8Enabled": false,
9 | "python.linting.mypyEnabled": true
10 | }
11 |
--------------------------------------------------------------------------------
/dnc/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | from .dnc import DNC
5 | from .memory import Memory
6 | from .sam import SAM
7 | from .sdnc import SDNC
8 | from .sparse_memory import SparseMemory
9 | from .sparse_temporal_memory import SparseTemporalMemory
10 |
11 | __all__ = [
12 | "DNC",
13 | "Memory",
14 | "SAM",
15 | "SDNC",
16 | "SparseMemory",
17 | "SparseTemporalMemory",
18 | ]
19 |
--------------------------------------------------------------------------------
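
Usage sketch (not part of the repository): a minimal forward pass through the exported DNC class, mirroring the constructor keywords and the return convention used by the tests below. The sizes are illustrative; with debug=True the model also returns a dict of memory snapshots as a third value.

    import torch
    from dnc import DNC

    rnn = DNC(
        input_size=8,
        hidden_size=32,
        rnn_type="lstm",
        num_layers=1,
        nr_cells=16,
        cell_size=6,
        read_heads=1,
        debug=True,
    )

    x = torch.randn(4, 5, 8)  # (batch, seq, input_size); batch_first defaults to True
    # Passing None as the hidden state lets the model initialize it; the
    # returned (chx, mhx, rv) tuple can be fed back in on the next call.
    output, (chx, mhx, rv), debug_dict = rnn(x, None)
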
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "3.6"
4 | # command to install dependencies
5 | before_install:
6 | - sudo apt-get -qq update
7 | - sudo apt-get install -yqq software-properties-common git
8 | - sudo apt-get install -yqq libopenblas-dev liblapack3 python3-numpy python3-dev swig
9 | - sudo ln -s /usr/lib/libopenblas.so /usr/lib/libopenblas.so.3
10 | install:
11 | - pip install -r ./requirements.txt
12 | # command to run tests
13 | script:
14 | - pytest ./test
15 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | universal=0
3 |
4 | [flake8]
5 | exclude = */__init__.py,migrations/*
6 | ignore = D401, E741, DAR003, D100, B010, E203, W503, F405, F403, E126, E128, E501, F841, W505, D202, DAR201, DAR401, F401, D107, DAR101, D400, D102, D205, D101
7 |
8 | max-line-length = 99
9 | max-doc-length = 99
10 | show-source = true
11 |
12 | [yapf]
13 | based_on_style = pep8
14 | allow_multiline_lambdas = True
15 | indent_dictionary_value = True
16 | allow_split_before_dict_value = False
17 | blank_lines_around_top_level_definition = 2
18 | column_limit = 99
19 | continuation_indent_width = 4
20 | indent_width = 4
21 | use_tabs = False
22 | split_before_arithmetic_operator = True
23 |
24 | [isort]
25 | line_length = 99
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Russi Chatterjee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | autoflake==2.3.1
2 | certifi==2025.1.31
3 | charset-normalizer==3.4.1
4 | exceptiongroup==1.2.2
5 | filelock==3.17.0
6 | fsspec==2025.2.0
7 | idna==3.10
8 | iniconfig==2.0.0
9 | Jinja2==3.1.6
10 | jsonpatch==1.33
11 | jsonpointer==3.0.0
12 | MarkupSafe==3.0.2
13 | mpmath==1.3.0
14 | networkx==3.4.2
15 | numpy==2.2.3
16 | nvidia-cublas-cu12==12.4.5.8
17 | nvidia-cuda-cupti-cu12==12.4.127
18 | nvidia-cuda-nvrtc-cu12==12.4.127
19 | nvidia-cuda-runtime-cu12==12.4.127
20 | nvidia-cudnn-cu12==9.1.0.70
21 | nvidia-cufft-cu12==11.2.1.3
22 | nvidia-curand-cu12==10.3.5.147
23 | nvidia-cusolver-cu12==11.6.1.9
24 | nvidia-cusparse-cu12==12.3.1.170
25 | nvidia-cusparselt-cu12==0.6.2
26 | nvidia-nccl-cu12==2.21.5
27 | nvidia-nvjitlink-cu12==12.4.127
28 | nvidia-nvtx-cu12==12.4.127
29 | packaging==24.2
30 | pillow==11.1.0
31 | pluggy==1.5.0
32 | pyflakes==3.2.0
33 | pytest==8.3.4
34 | requests==2.32.4
35 | scipy==1.15.2
36 | six==1.17.0
37 | sympy==1.13.1
38 | tomli==2.2.1
39 | torch==2.6.0
40 | torchaudio==2.6.0
41 | torchvision==0.21.0
42 | tornado==6.5.1
43 | triton==3.2.0
44 | typing_extensions==4.12.2
45 | urllib3==2.6.0
46 | visdom==0.2.4
47 | websocket-client==1.8.0
--------------------------------------------------------------------------------
/test/test_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 |
9 | def generate_data(
10 | batch_size: int, length: int, size: int, device: torch.device = torch.device("cpu")
11 | ) -> tuple[torch.Tensor, torch.Tensor]:
12 | """Generates data for the copy task, directly on the specified device."""
13 | input_data = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32)
14 | target_output = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32)
15 | sequence = np.random.binomial(1, 0.5, (batch_size, length, size - 1))
16 |
17 | input_data[:, :length, : size - 1] = sequence
18 | input_data[:, length, -1] = 1 # the end symbol
19 | target_output[:, length + 1 :, : size - 1] = sequence
20 |
21 | return (torch.tensor(input_data, device=device), torch.tensor(target_output, device=device))
22 |
23 |
24 | def criterion(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
25 | """Calculates the binary cross-entropy loss with logits."""
26 | return F.binary_cross_entropy_with_logits(predictions, targets)
27 |
28 |
29 | def get_device(cuda_id: int) -> torch.device:
30 | """Gets the torch device based on CUDA availability and ID."""
31 | if cuda_id >= 0 and torch.cuda.is_available():
32 | return torch.device(f"cuda:{cuda_id}")
33 | return torch.device("cpu")
34 |
--------------------------------------------------------------------------------
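
For reference, a worked example of the copy-task layout that generate_data produces (a sketch, run from the test directory): the first `length` steps carry a random binary sequence on the first size-1 channels, step `length` raises the end marker on the last channel, and the target repeats the sequence in the second half.

    import torch
    from test_utils import generate_data

    inp, tgt = generate_data(batch_size=2, length=3, size=5)
    assert inp.shape == (2, 7, 5)  # 2 * length + 1 = 7 time steps
    assert tgt.shape == (2, 7, 5)
    assert inp[0, 3, -1] == 1.0  # end-of-sequence marker
    assert torch.equal(inp[:, :3, :4], tgt[:, 4:, :4])  # target echoes the input
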
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: no-commit-to-branch
6 | args: [--branch, master]
7 | - id: check-merge-conflict
8 | - id: trailing-whitespace
9 | - id: end-of-file-fixer
10 | - id: check-added-large-files
11 | - id: check-yaml
12 | args: [--allow-multiple-documents]
13 | - id: check-json
14 | - id: pretty-format-json
15 | args: [--autofix, --no-sort-keys]
16 | - id: check-xml
17 | - id: debug-statements
18 | - id: check-case-conflict
19 | - id: detect-private-key
20 | - id: requirements-txt-fixer
21 |
22 | - repo: https://github.com/psf/black
23 | rev: 25.1.0
24 | hooks:
25 | - id: black
26 | language: python
27 | args: [--line-length=120]
28 |
29 | - repo: https://github.com/PyCQA/flake8
30 | rev: 7.1.2
31 | hooks:
32 | - id: flake8
33 | additional_dependencies:
34 | ["flake8-bugbear", "flake8-docstrings", "darglint"]
35 | args: [--max-line-length=120]
36 | files: ^dnc/
37 |
38 | - repo: https://github.com/pre-commit/mirrors-mypy
39 | rev: "v1.15.0"
40 | hooks:
41 | - id: mypy
42 | additional_dependencies: ["torch", "numpy", "types-PyYAML"]
43 | args: [--ignore-missing-imports]
44 | files: ^dnc/
45 |
46 | default_language_version:
47 | python: python3
48 |
49 | fail_fast: true
50 | exclude: ^migrations/
51 |
--------------------------------------------------------------------------------
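
For reference, the hooks configured above are run with the standard pre-commit CLI:

    pip install pre-commit
    pre-commit install            # run the hooks on every git commit
    pre-commit run --all-files    # run all hooks against the whole tree
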
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """A setuptools based setup module.
4 |
5 | See:
6 | https://packaging.python.org/en/latest/distributing.html
7 | https://github.com/pypa/sampleproject
8 | """
9 |
10 | # Always prefer setuptools over distutils
11 | from setuptools import setup, find_packages
12 |
13 | # To use a consistent encoding
14 | from codecs import open
15 | from os import path
16 |
17 | here = path.abspath(path.dirname(__file__))
18 |
19 | # Get the long description from the README file
20 | with open(path.join(here, "README.rst"), encoding="utf-8") as f:
21 | long_description = f.read()
22 |
23 | setup(
24 | name="dnc",
25 | version="2.0.0b1",
26 |     description="Differentiable Neural Computer, for PyTorch",
27 |     long_description=long_description,
28 |     # The project's main homepage.
29 |     url="https://github.com/ixaxaar/pytorch-dnc",
30 | # Author details
31 | author="Russi Chatterjee",
32 | author_email="root@ixaxaar.in",
33 | # Choose your license
34 | license="MIT",
35 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
36 | classifiers=[
37 | "Development Status :: 3 - Alpha",
38 | "Intended Audience :: Science/Research",
39 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
40 | "License :: OSI Approved :: MIT License",
41 | "Programming Language :: Python :: 3",
42 | "Programming Language :: Python :: 3.3",
43 | "Programming Language :: Python :: 3.4",
44 | "Programming Language :: Python :: 3.5",
45 | "Programming Language :: Python :: 3.6",
46 | ],
47 | keywords="differentiable neural computer dnc memory network",
48 |     packages=find_packages(exclude=["contrib", "docs", "test", "tasks", "scripts"]),
49 | install_requires=["torch", "numpy", "flann"],
50 | extras_require={
51 | "dev": ["check-manifest"],
52 | "test": ["coverage"],
53 | },
54 |     python_requires=">=3.10",
55 | )
56 |
--------------------------------------------------------------------------------
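
A note on installation (illustrative): with this setup.py the package installs under the name dnc, either from a local checkout or from PyPI.

    pip install -e .    # editable install from a checkout
    pip install dnc     # released package
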
/dnc/faiss_index.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import faiss
5 | from faiss import cast_integer_to_float_ptr as cast_float
6 | from faiss import cast_integer_to_idx_t_ptr as cast_long
7 |
8 | import torch
9 |
10 | from .util import ptr, ensure_gpu
11 |
12 |
13 | class FAISSIndex(object):
14 | """FAISS Index for approximate nearest neighbor search."""
15 |
16 | def __init__(
17 | self,
18 | cell_size: int = 20,
19 | nr_cells: int = 1024,
20 | K: int = 4,
21 | num_lists: int = 32,
22 | probes: int = 32,
23 |         res: "faiss.GpuResources | None" = None,
24 | train: torch.Tensor | None = None,
25 | device: torch.device | None = None,
26 | ):
27 | """Initialize FAISSIndex.
28 |
29 | Args:
30 | cell_size: Size of each memory cell.
31 | nr_cells: Number of memory cells.
32 | K: Number of nearest neighbors to retrieve.
33 | num_lists: Number of lists for the index.
34 | probes: Number of probes for searching.
35 | res: FAISS GpuResources object.
36 | train: Training data.
37 | device: PyTorch device
38 |
39 | """
40 | super(FAISSIndex, self).__init__()
41 | self.cell_size = cell_size
42 | self.nr_cells = nr_cells
43 | self.probes = probes
44 | self.K = K
45 | self.num_lists = num_lists
46 | self.device = device
47 |
48 | # BEWARE: if this variable gets deallocated, FAISS crashes
49 |         self.res = res if res else (faiss.StandardGpuResources() if device is not None and device.type == "cuda" else None)
50 | train_tensor = train if train is not None else torch.randn(self.nr_cells * 100, self.cell_size)
51 |
52 | # Configure FAISS resources for GPU if needed
53 | if self.device is not None and self.device.type == "cuda":
54 | self.res.setTempMemoryFraction(0.01)
55 | self.res.initializeForDevice(self.device.index if self.device.index is not None else 0)
56 | # Create GPU index with a quantizer
57 | quantizer = faiss.IndexFlatL2(self.cell_size)
58 | self.index = faiss.GpuIndexIVFFlat(self.res, quantizer, self.cell_size, self.num_lists, faiss.METRIC_L2)
59 | else:
60 | # Create CPU index for both None device and explicit CPU device
61 | # First create a quantizer
62 | quantizer = faiss.IndexFlatL2(self.cell_size)
63 | self.index = faiss.IndexIVFFlat(quantizer, self.cell_size, self.num_lists, faiss.METRIC_L2)
64 |
65 | # set number of probes
66 |         self.index.nprobe = self.probes
67 | self.train(train_tensor)
68 |
69 | def train(self, train: torch.Tensor) -> None:
70 | """Trains the index.
71 |
72 | Args:
73 | train: Training data.
74 | """
75 | train = ensure_gpu(train, self.device)
76 |
77 | # Only synchronize if using CUDA
78 | if self.device is not None and self.device.type == "cuda":
79 | torch.cuda.synchronize(self.device)
80 |
81 |         self.index.train_c(train.size(0), cast_float(ptr(train)))
82 |
83 | # Only synchronize if using CUDA
84 | if self.device is not None and self.device.type == "cuda":
85 | torch.cuda.synchronize(self.device)
86 |
87 | def reset(self) -> None:
88 | """Resets the index."""
89 | if self.device is not None and self.device.type == "cuda":
90 | torch.cuda.synchronize(self.device)
91 | self.index.reset()
92 | if self.device is not None and self.device.type == "cuda":
93 | torch.cuda.synchronize(self.device)
94 |
95 | def add(self, other: torch.Tensor, positions: torch.Tensor | None = None, last: int | None = None) -> None:
96 | """Adds vectors to the index.
97 |
98 | Args:
99 | other: Vectors to add.
100 | positions: Positions of the vectors.
101 | last: Index of the last vector to add.
102 | """
103 | other = ensure_gpu(other, self.device)
104 |
105 | if self.device is not None and self.device.type == "cuda":
106 | torch.cuda.synchronize(self.device)
107 |
108 | if positions is not None:
109 | positions = ensure_gpu(positions, self.device).long()
110 | assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
111 | self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1)))
112 | else:
113 | other = other[:last, :] if last is not None else other
114 | self.index.add_c(other.size(0), cast_float(ptr(other)))
115 |
116 | if self.device is not None and self.device.type == "cuda":
117 | torch.cuda.synchronize(self.device)
118 |
119 | def search(self, query: torch.Tensor, k: int | None = None) -> tuple[torch.Tensor, torch.Tensor]:
120 | """Searches the index for nearest neighbors.
121 |
122 | Args:
123 | query: Query vectors.
124 | k: Number of nearest neighbors to retrieve.
125 |
126 | Returns:
127 | Tuple: Distances and labels of the nearest neighbors.
128 | """
129 |
130 | query = ensure_gpu(query, self.device)
131 |
132 | k = k if k else self.K
133 | (b, _) = query.size()
134 |
135 | distances = torch.empty(b, k, device=self.device, dtype=torch.float32)
136 | labels = torch.empty(b, k, device=self.device, dtype=torch.int64)
137 |
138 | if self.device is not None and self.device.type == "cuda":
139 | torch.cuda.synchronize(self.device)
140 |
141 | self.index.search_c(b, cast_float(ptr(query)), k, cast_float(ptr(distances)), cast_long(ptr(labels)))
142 |
143 | if self.device is not None and self.device.type == "cuda":
144 | torch.cuda.synchronize(self.device)
145 |
146 |         return (distances, (labels - 1))  # undo the +1 id offset applied in add()
147 |
--------------------------------------------------------------------------------
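
Usage sketch (not part of the repository, assuming a working faiss install): the index stores ids shifted by +1 and search() shifts them back, so positions round-trip as 0-based cell indices. The constructor trains on random vectors unless a train tensor is supplied.

    import torch
    from dnc.faiss_index import FAISSIndex

    index = FAISSIndex(cell_size=20, nr_cells=1024, K=4, num_lists=32)

    memory = torch.randn(1024, 20)  # one vector per memory cell
    index.add(memory, positions=torch.arange(1024))
    distances, positions = index.search(torch.randn(2, 20), k=4)
    # distances: (2, 4) float32; positions: (2, 4) int64 cell indices
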
/test/test_rnn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import torch as T
6 | import torch.optim as optim
7 |
8 | import sys
9 |
10 | sys.path.insert(0, ".")
11 |
12 | import functools
13 |
14 | from dnc import DNC
15 | from test_utils import generate_data, criterion
16 |
17 |
18 | def test_rnn_1():
19 | T.manual_seed(1111)
20 |
21 | input_size = 100
22 | hidden_size = 100
23 | rnn_type = "rnn"
24 | num_layers = 1
25 | num_hidden_layers = 1
26 | dropout = 0
27 | nr_cells = 1
28 | cell_size = 1
29 | read_heads = 1
30 | device = None
31 | debug = True
32 | lr = 0.001
33 | sequence_max_length = 10
34 | batch_size = 10
35 | cuda = device
36 | clip = 10
37 | length = 10
38 |
39 | rnn = DNC(
40 | input_size=input_size,
41 | hidden_size=hidden_size,
42 | rnn_type=rnn_type,
43 | num_layers=num_layers,
44 | num_hidden_layers=num_hidden_layers,
45 | dropout=dropout,
46 | nr_cells=nr_cells,
47 | cell_size=cell_size,
48 | read_heads=read_heads,
49 | device=device,
50 | debug=debug,
51 | )
52 |
53 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
54 | optimizer.zero_grad()
55 |
56 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
57 |
58 | output, (chx, mhx, rv), v = rnn(input_data, None)
59 |
60 | # Make output and target compatible for loss calculation
61 | # target: [batch, seq, features] -> [seq, batch, features]
62 | target_output = target_output.permute(1, 0, 2).contiguous()
63 |
64 | loss = criterion(output, target_output)
65 | loss.backward()
66 |
67 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
68 | optimizer.step()
69 |
70 | assert target_output.size() == T.Size([21, 10, 100])
71 | assert chx[0][0].size() == T.Size([10, 100])
72 | assert mhx[0]["memory"].size() == T.Size([10, 1, 1])
73 | assert rv.size() == T.Size([10, 1])
74 |
75 |
76 | def test_rnn_n():
77 | T.manual_seed(1111)
78 |
79 | input_size = 100
80 | hidden_size = 100
81 | rnn_type = "rnn"
82 | num_layers = 3
83 | num_hidden_layers = 5
84 | dropout = 0.2
85 | nr_cells = 12
86 | cell_size = 17
87 | read_heads = 3
88 | device = None
89 | debug = True
90 | lr = 0.001
91 | sequence_max_length = 10
92 | batch_size = 10
93 | cuda = device
94 | clip = 20
95 | length = 13
96 |
97 | rnn = DNC(
98 | input_size=input_size,
99 | hidden_size=hidden_size,
100 | rnn_type=rnn_type,
101 | num_layers=num_layers,
102 | num_hidden_layers=num_hidden_layers,
103 | dropout=dropout,
104 | nr_cells=nr_cells,
105 | cell_size=cell_size,
106 | read_heads=read_heads,
107 | device=device,
108 | debug=debug,
109 | )
110 |
111 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
112 | optimizer.zero_grad()
113 |
114 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
115 |
116 | output, (chx, mhx, rv), v = rnn(input_data, None)
117 |
118 | # Make output and target compatible for loss calculation
119 | # target: [batch, seq, features] -> [seq, batch, features]
120 | target_output = target_output.permute(1, 0, 2).contiguous()
121 |
122 | loss = criterion(output, target_output)
123 | loss.backward()
124 |
125 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
126 | optimizer.step()
127 |
128 | assert target_output.size() == T.Size([27, 10, 100])
129 | assert chx[1].size() == T.Size([num_hidden_layers, 10, 100])
130 | assert mhx[0]["memory"].size() == T.Size([10, 12, 17])
131 | assert rv.size() == T.Size([10, 51])
132 |
133 |
134 | def test_rnn_no_memory_pass():
135 | T.manual_seed(1111)
136 |
137 | input_size = 100
138 | hidden_size = 100
139 | rnn_type = "rnn"
140 | num_layers = 3
141 | num_hidden_layers = 5
142 | dropout = 0.2
143 | nr_cells = 12
144 | cell_size = 17
145 | read_heads = 3
146 | device = None
147 | debug = True
148 | lr = 0.001
149 | sequence_max_length = 10
150 | batch_size = 10
151 | cuda = device
152 | clip = 20
153 | length = 13
154 |
155 | rnn = DNC(
156 | input_size=input_size,
157 | hidden_size=hidden_size,
158 | rnn_type=rnn_type,
159 | num_layers=num_layers,
160 | num_hidden_layers=num_hidden_layers,
161 | dropout=dropout,
162 | nr_cells=nr_cells,
163 | cell_size=cell_size,
164 | read_heads=read_heads,
165 | device=device,
166 | debug=debug,
167 | )
168 |
169 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
170 | optimizer.zero_grad()
171 |
172 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
173 |
174 | # Transform target to match expected output shape
175 | target_output = target_output.permute(1, 0, 2).contiguous()
176 |
177 | # Initialize hidden state explicitly
178 | controller_hidden = None
179 | memory_hidden = None
180 | last_read = None
181 | outputs = []
182 |
183 | for x in range(6):
184 | output, (controller_hidden, memory_hidden, last_read), v = rnn(
185 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False
186 | )
187 | outputs.append(output)
188 |
189 | # Sum outputs for all iterations
190 | output = functools.reduce(lambda x, y: x + y, outputs)
191 | loss = criterion(output, target_output)
192 | loss.backward()
193 |
194 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
195 | optimizer.step()
196 |
197 | assert target_output.size() == T.Size([27, 10, 100])
198 | assert controller_hidden[1].size() == T.Size([num_hidden_layers, 10, 100])
199 | assert memory_hidden[0]["memory"].size() == T.Size([10, 12, 17])
200 | assert last_read is not None
201 |
--------------------------------------------------------------------------------
/test/test_lstm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import torch as T
6 | import torch.optim as optim
7 |
8 | import sys
9 | import functools
10 |
11 | sys.path.insert(0, ".")
12 |
13 | from dnc import DNC
14 | from test_utils import generate_data, criterion
15 |
16 |
17 | def test_rnn_1():
18 | T.manual_seed(1111)
19 |
20 | input_size = 100
21 | hidden_size = 100
22 | rnn_type = "lstm"
23 | num_layers = 1
24 | num_hidden_layers = 1
25 | dropout = 0
26 | nr_cells = 1
27 | cell_size = 1
28 | read_heads = 1
29 | device = None
30 | debug = True
31 | lr = 0.001
32 | sequence_max_length = 10
33 | batch_size = 10
34 | cuda = device
35 | clip = 10
36 | length = 10
37 |
38 | rnn = DNC(
39 | input_size=input_size,
40 | hidden_size=hidden_size,
41 | rnn_type=rnn_type,
42 | num_layers=num_layers,
43 | num_hidden_layers=num_hidden_layers,
44 | dropout=dropout,
45 | nr_cells=nr_cells,
46 | cell_size=cell_size,
47 | read_heads=read_heads,
48 | device=device,
49 | debug=debug,
50 | )
51 |
52 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
53 | optimizer.zero_grad()
54 |
55 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
56 |
57 | output, (chx, mhx, rv), v = rnn(input_data, None)
58 |
59 | # Make output and target compatible for loss calculation
60 | # target: [batch, seq, features] -> [seq, batch, features]
61 | target_output = target_output.permute(1, 0, 2).contiguous()
62 |
63 | loss = criterion(output, target_output)
64 | loss.backward()
65 |
66 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
67 | optimizer.step()
68 |
69 | assert target_output.size() == T.Size([21, 10, 100])
70 | assert chx[0][0][0].size() == T.Size([10, 100])
71 | assert mhx[0]["memory"].size() == T.Size([10, 1, 1])
72 | assert rv.size() == T.Size([10, 1])
73 |
74 |
75 | def test_rnn_n():
76 | T.manual_seed(1111)
77 |
78 | input_size = 100
79 | hidden_size = 100
80 | rnn_type = "lstm"
81 | num_layers = 3
82 | num_hidden_layers = 5
83 | dropout = 0.2
84 | nr_cells = 12
85 | cell_size = 17
86 | read_heads = 3
87 | device = None
88 | debug = True
89 | lr = 0.001
90 | sequence_max_length = 10
91 | batch_size = 10
92 | cuda = device
93 | clip = 20
94 | length = 13
95 |
96 | rnn = DNC(
97 | input_size=input_size,
98 | hidden_size=hidden_size,
99 | rnn_type=rnn_type,
100 | num_layers=num_layers,
101 | num_hidden_layers=num_hidden_layers,
102 | dropout=dropout,
103 | nr_cells=nr_cells,
104 | cell_size=cell_size,
105 | read_heads=read_heads,
106 | device=device,
107 | debug=debug,
108 | )
109 |
110 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
111 | optimizer.zero_grad()
112 |
113 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
114 |
115 | output, (chx, mhx, rv), v = rnn(input_data, None)
116 |
117 | # Make output and target compatible for loss calculation
118 | # target: [batch, seq, features] -> [seq, batch, features]
119 | target_output = target_output.permute(1, 0, 2).contiguous()
120 |
121 | loss = criterion(output, target_output)
122 | loss.backward()
123 |
124 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
125 | optimizer.step()
126 |
127 | assert target_output.size() == T.Size([27, 10, 100])
128 | assert chx[0][0].size() == T.Size([num_hidden_layers, 10, 100])
129 | assert mhx[0]["memory"].size() == T.Size([10, 12, 17])
130 | assert rv.size() == T.Size([10, 51])
131 |
132 |
133 | def test_rnn_no_memory_pass():
134 | T.manual_seed(1111)
135 |
136 | input_size = 100
137 | hidden_size = 100
138 | rnn_type = "lstm"
139 | num_layers = 3
140 | num_hidden_layers = 5
141 | dropout = 0.2
142 | nr_cells = 12
143 | cell_size = 17
144 | read_heads = 3
145 | device = None
146 | debug = True
147 | lr = 0.001
148 | sequence_max_length = 10
149 | batch_size = 10
150 | cuda = device
151 | clip = 20
152 | length = 13
153 |
154 | rnn = DNC(
155 | input_size=input_size,
156 | hidden_size=hidden_size,
157 | rnn_type=rnn_type,
158 | num_layers=num_layers,
159 | num_hidden_layers=num_hidden_layers,
160 | dropout=dropout,
161 | nr_cells=nr_cells,
162 | cell_size=cell_size,
163 | read_heads=read_heads,
164 | device=device,
165 | debug=debug,
166 | )
167 |
168 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
169 | optimizer.zero_grad()
170 |
171 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
172 |
173 | # Transform target to match expected output shape
174 | target_output = target_output.permute(1, 0, 2).contiguous()
175 |
176 | # Initialize hidden state explicitly
177 | controller_hidden = None
178 | memory_hidden = None
179 | last_read = None
180 | outputs = []
181 |
182 | for x in range(6):
183 | output, (controller_hidden, memory_hidden, last_read), v = rnn(
184 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False
185 | )
186 | outputs.append(output)
187 |
188 | # Sum outputs for all iterations
189 | output = functools.reduce(lambda x, y: x + y, outputs)
190 | loss = criterion(output, target_output)
191 | loss.backward()
192 |
193 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
194 | optimizer.step()
195 |
196 | assert target_output.size() == T.Size([27, 10, 100])
197 | assert controller_hidden[0][0].size() == T.Size([num_hidden_layers, 10, 100])
198 | assert memory_hidden[0]["memory"].size() == T.Size([10, 12, 17])
199 | assert last_read is not None
200 |
--------------------------------------------------------------------------------
/test/test_gru.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import torch as T
6 | import torch.optim as optim
7 |
8 | import sys
9 |
10 | sys.path.insert(0, ".")
11 |
12 | import functools
13 |
14 | from dnc import DNC
15 | from test_utils import generate_data, criterion
16 |
17 |
18 | def test_rnn_1():
19 | T.manual_seed(1111)
20 |
21 | input_size = 100
22 | hidden_size = 100
23 | rnn_type = "gru"
24 | num_layers = 1
25 | num_hidden_layers = 1
26 | dropout = 0
27 | nr_cells = 1
28 | cell_size = 1
29 | read_heads = 1
30 | device = None
31 | debug = True
32 | lr = 0.001
33 | sequence_max_length = 10
34 | batch_size = 10
35 | cuda = device
36 | clip = 10
37 | length = 10
38 |
39 | rnn = DNC(
40 | input_size=input_size,
41 | hidden_size=hidden_size,
42 | rnn_type=rnn_type,
43 | num_layers=num_layers,
44 | num_hidden_layers=num_hidden_layers,
45 | dropout=dropout,
46 | nr_cells=nr_cells,
47 | cell_size=cell_size,
48 | read_heads=read_heads,
49 | device=device,
50 | debug=debug,
51 | )
52 |
53 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
54 | optimizer.zero_grad()
55 |
56 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
57 |
58 | output, (chx, mhx, rv), v = rnn(input_data, None)
59 |
60 | # Make output and target compatible for loss calculation
61 | # target: [batch, seq, features] -> [seq, batch, features]
62 | target_output = target_output.permute(1, 0, 2).contiguous()
63 |
64 | loss = criterion(output, target_output)
65 | loss.backward()
66 |
67 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
68 | optimizer.step()
69 |
70 | assert target_output.size() == T.Size([21, 10, 100])
71 | assert chx[0][0].size() == T.Size([10, 100])
72 | assert mhx[0]["memory"].size() == T.Size([10, 1, 1])
73 | assert rv.size() == T.Size([10, 1])
74 |
75 |
76 | def test_rnn_n():
77 | T.manual_seed(1111)
78 |
79 | input_size = 100
80 | hidden_size = 100
81 | rnn_type = "gru"
82 | num_layers = 3
83 | num_hidden_layers = 5
84 | dropout = 0.2
85 | nr_cells = 12
86 | cell_size = 17
87 | read_heads = 3
88 | device = None
89 | debug = True
90 | lr = 0.001
91 | sequence_max_length = 10
92 | batch_size = 10
93 | cuda = device
94 | clip = 20
95 | length = 13
96 |
97 | rnn = DNC(
98 | input_size=input_size,
99 | hidden_size=hidden_size,
100 | rnn_type=rnn_type,
101 | num_layers=num_layers,
102 | num_hidden_layers=num_hidden_layers,
103 | dropout=dropout,
104 | nr_cells=nr_cells,
105 | cell_size=cell_size,
106 | read_heads=read_heads,
107 | device=device,
108 | debug=debug,
109 | )
110 |
111 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
112 | optimizer.zero_grad()
113 |
114 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
115 |
116 | output, (chx, mhx, rv), v = rnn(input_data, None)
117 |
118 | # Make output and target compatible for loss calculation
119 | # target: [batch, seq, features] -> [seq, batch, features]
120 | target_output = target_output.permute(1, 0, 2).contiguous()
121 |
122 | loss = criterion(output, target_output)
123 | loss.backward()
124 |
125 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
126 | optimizer.step()
127 |
128 | assert target_output.size() == T.Size([27, 10, 100])
129 | assert chx[1].size() == T.Size([num_hidden_layers, 10, 100])
130 | assert mhx[0]["memory"].size() == T.Size([10, 12, 17])
131 | assert rv.size() == T.Size([10, 51])
132 |
133 |
134 | def test_rnn_no_memory_pass():
135 | T.manual_seed(1111)
136 |
137 | input_size = 100
138 | hidden_size = 100
139 | rnn_type = "gru"
140 | num_layers = 3
141 | num_hidden_layers = 5
142 | dropout = 0.2
143 | nr_cells = 12
144 | cell_size = 17
145 | read_heads = 3
146 | device = None
147 | debug = True
148 | lr = 0.001
149 | sequence_max_length = 10
150 | batch_size = 10
151 | cuda = device
152 | clip = 20
153 | length = 13
154 |
155 | rnn = DNC(
156 | input_size=input_size,
157 | hidden_size=hidden_size,
158 | rnn_type=rnn_type,
159 | num_layers=num_layers,
160 | num_hidden_layers=num_hidden_layers,
161 | dropout=dropout,
162 | nr_cells=nr_cells,
163 | cell_size=cell_size,
164 | read_heads=read_heads,
165 | device=device,
166 | debug=debug,
167 | )
168 |
169 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
170 | optimizer.zero_grad()
171 |
172 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
173 |
174 | # Transform target to match expected output shape
175 | target_output = target_output.permute(1, 0, 2).contiguous()
176 |
177 | # Initialize hidden state explicitly
178 | controller_hidden = None
179 | memory_hidden = None
180 | last_read = None
181 | outputs = []
182 |
183 | for x in range(6):
184 | output, (controller_hidden, memory_hidden, last_read), v = rnn(
185 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False
186 | )
187 | outputs.append(output)
188 |
189 | # Sum outputs for all iterations
190 | output = functools.reduce(lambda x, y: x + y, outputs)
191 | loss = criterion(output, target_output)
192 | loss.backward()
193 |
194 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
195 | optimizer.step()
196 |
197 | assert target_output.size() == T.Size([27, 10, 100])
198 | assert controller_hidden[0].size() == T.Size([num_hidden_layers, 10, 100])
199 | assert memory_hidden[0]["memory"].size() == T.Size([10, 12, 17])
200 |     # last_read is still populated even when pass_through_memory=False
201 | assert last_read is not None
202 |
--------------------------------------------------------------------------------
/dnc/sam.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import torch
4 |
5 | from .dnc import DNC
6 | from .sparse_memory import SparseMemory
7 |
8 |
9 | class SAM(DNC):
10 | """Sparse Access Memory (SAM) module."""
11 |
12 | def __init__(
13 | self,
14 | input_size: int,
15 | hidden_size: int,
16 | rnn_type: str = "lstm",
17 | num_layers: int = 1,
18 | num_hidden_layers: int = 2,
19 | bias: bool = True,
20 | batch_first: bool = True,
21 | dropout: float = 0.0,
22 | bidirectional: bool = False,
23 | nr_cells: int = 5000,
24 | sparse_reads: int = 4,
25 | read_heads: int = 4,
26 | cell_size: int = 10,
27 | nonlinearity: str = "tanh",
28 | independent_linears: bool = False,
29 | share_memory: bool = True,
30 | debug: bool = False,
31 | clip: float = 20,
32 | device: torch.device | None = None,
33 | ):
34 | """Initialize SAM.
35 |
36 | Args:
37 | input_size: Input size.
38 | hidden_size: Hidden size.
39 | rnn_type: Type of RNN cell (lstm, gru, rnn).
40 | num_layers: Number of RNN layers.
41 | num_hidden_layers: Number of hidden layers in each RNN.
42 | bias: Whether to use bias in the RNN.
43 | batch_first: Whether the input is batch-first.
44 | dropout: Dropout rate.
45 | bidirectional: Whether the RNN is bidirectional.
46 | nr_cells: Number of memory cells.
47 | sparse_reads: Number of sparse reads.
48 | read_heads: Number of read heads.
49 | cell_size: Size of each memory cell.
50 | nonlinearity: Nonlinearity for RNN ('tanh' or 'relu').
51 |             independent_linears: Whether to use independent linear layers in memory.
52 |             share_memory: Whether to share memory across layers.
53 |             debug: Whether to enable debug mode.
54 |             clip: Value to clip controller output.
55 |             device: PyTorch device.
56 |
57 | """
58 | super(SAM, self).__init__(
59 | input_size=input_size,
60 | hidden_size=hidden_size,
61 | rnn_type=rnn_type,
62 | num_layers=num_layers,
63 | num_hidden_layers=num_hidden_layers,
64 | bias=bias,
65 | batch_first=batch_first,
66 | dropout=dropout,
67 | nr_cells=nr_cells,
68 | read_heads=read_heads,
69 | cell_size=cell_size,
70 | nonlinearity=nonlinearity,
71 | independent_linears=independent_linears,
72 | share_memory_between_layers=share_memory,
73 | debug=debug,
74 | clip=clip,
75 | device=device,
76 | )
77 | self.sparse_reads = sparse_reads
78 | self.device = device
79 |         # override the base DNC memories with sparse memory
80 | self.memories = []
81 |
82 | for layer in range(self.num_layers):
83 | # memories for each layer
84 | if not self.share_memory_between_layers:
85 | self.memories.append(
86 | SparseMemory(
87 | input_size=self.output_size,
88 | mem_size=self.nr_cells,
89 | cell_size=self.w,
90 | sparse_reads=self.sparse_reads,
91 | read_heads=self.read_heads,
92 | device=self.device,
93 | independent_linears=self.independent_linears,
94 | )
95 | )
96 | setattr(self, "rnn_layer_memory_" + str(layer), self.memories[layer])
97 |
98 | # only one memory shared by all layers
99 | if self.share_memory_between_layers:
100 | self.memories.append(
101 | SparseMemory(
102 | input_size=self.output_size,
103 | mem_size=self.nr_cells,
104 | cell_size=self.w,
105 | sparse_reads=self.sparse_reads,
106 | read_heads=self.read_heads,
107 | device=self.device,
108 | independent_linears=self.independent_linears,
109 | )
110 | )
111 | setattr(self, "rnn_layer_memory_shared", self.memories[0])
112 |
113 | def _debug(self, mhx: dict, debug_obj: dict | None) -> dict | None:
114 | """Debug function to collect memory information.
115 | Args:
116 | mhx: Memory hidden state.
117 | debug_obj: Debug object to store information.
118 |
119 | Returns:
120 | Updated debug object or None.
121 | """
122 | if not debug_obj:
123 | debug_obj = {
124 | "memory": [],
125 | "visible_memory": [],
126 | "read_weights": [],
127 | "write_weights": [],
128 | "read_vectors": [],
129 | "least_used_mem": [],
130 | "usage": [],
131 | "read_positions": [],
132 | }
133 |
134 | debug_obj["memory"].append(mhx["memory"][0].detach().cpu().numpy())
135 | debug_obj["visible_memory"].append(mhx["visible_memory"][0].detach().cpu().numpy())
136 | debug_obj["read_weights"].append(mhx["read_weights"][0].unsqueeze(0).detach().cpu().numpy())
137 | debug_obj["write_weights"].append(mhx["write_weights"][0].unsqueeze(0).detach().cpu().numpy())
138 | debug_obj["read_vectors"].append(mhx["read_vectors"][0].detach().cpu().numpy())
139 | debug_obj["least_used_mem"].append(mhx["least_used_mem"][0].unsqueeze(0).detach().cpu().numpy())
140 | debug_obj["usage"].append(mhx["usage"][0].unsqueeze(0).detach().cpu().numpy())
141 | debug_obj["read_positions"].append(mhx["read_positions"][0].unsqueeze(0).detach().cpu().numpy())
142 |
143 | return debug_obj
144 |
--------------------------------------------------------------------------------
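
Usage sketch (not part of the repository): constructing a SAM with the same keywords the SAM tests below exercise. Because reads are sparse, nr_cells can be much larger than in the dense DNC; the sizes here are illustrative.

    import torch
    from dnc import SAM

    rnn = SAM(
        input_size=64,
        hidden_size=128,
        rnn_type="lstm",
        num_layers=1,
        nr_cells=10_000,
        cell_size=32,
        read_heads=4,
        sparse_reads=4,
        debug=True,
    )

    x = torch.randn(8, 20, 64)  # (batch, seq, input_size)
    output, (chx, mhx, rv), debug_dict = rnn(x, None)
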
/test/test_sam_gru.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 |
6 | import torch as T
7 | import torch.optim as optim
8 |
9 | import sys
10 | import functools
11 |
12 | sys.path.insert(0, ".")
13 |
14 | from dnc import SAM
15 | from test_utils import generate_data, criterion
16 |
17 |
18 | def test_rnn_1():
19 | T.manual_seed(1111)
20 |
21 | input_size = 100
22 | hidden_size = 100
23 | rnn_type = "gru"
24 | num_layers = 1
25 | num_hidden_layers = 1
26 | dropout = 0
27 | nr_cells = 100
28 | cell_size = 10
29 | read_heads = 1
30 | sparse_reads = 2
31 | device = None
32 | debug = True
33 | lr = 0.001
34 | sequence_max_length = 10
35 | batch_size = 10
36 | cuda = device
37 | clip = 10
38 | length = 10
39 |
40 | rnn = SAM(
41 | input_size=input_size,
42 | hidden_size=hidden_size,
43 | rnn_type=rnn_type,
44 | num_layers=num_layers,
45 | num_hidden_layers=num_hidden_layers,
46 | dropout=dropout,
47 | nr_cells=nr_cells,
48 | cell_size=cell_size,
49 | read_heads=read_heads,
50 | sparse_reads=sparse_reads,
51 | device=device,
52 | debug=debug,
53 | )
54 |
55 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
56 | optimizer.zero_grad()
57 |
58 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
59 |
60 | output, (chx, mhx, rv), v = rnn(input_data, None)
61 |
62 | # Make output and target compatible for loss calculation
63 | # target: [batch, seq, features] -> [seq, batch, features]
64 | target_output = target_output.permute(1, 0, 2).contiguous()
65 |
66 | loss = criterion(output, target_output)
67 | loss.backward()
68 |
69 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
70 | optimizer.step()
71 |
72 | assert target_output.size() == T.Size([21, 10, 100])
73 | assert chx[0][0].size() == T.Size([10, 100])
74 | # assert mhx['memory'].size() == T.Size([10,1,1])
75 | assert rv.size() == T.Size([10, 10])
76 |
77 |
78 | def test_rnn_n():
79 | T.manual_seed(1111)
80 |
81 | input_size = 100
82 | hidden_size = 100
83 | rnn_type = "gru"
84 | num_layers = 3
85 | num_hidden_layers = 5
86 | dropout = 0.2
87 | nr_cells = 200
88 | cell_size = 17
89 | read_heads = 2
90 | sparse_reads = 4
91 | device = None
92 | debug = True
93 | lr = 0.001
94 | sequence_max_length = 10
95 | batch_size = 10
96 | cuda = device
97 | clip = 20
98 | length = 13
99 |
100 | rnn = SAM(
101 | input_size=input_size,
102 | hidden_size=hidden_size,
103 | rnn_type=rnn_type,
104 | num_layers=num_layers,
105 | num_hidden_layers=num_hidden_layers,
106 | dropout=dropout,
107 | nr_cells=nr_cells,
108 | cell_size=cell_size,
109 | read_heads=read_heads,
110 | sparse_reads=sparse_reads,
111 | device=device,
112 | debug=debug,
113 | )
114 |
115 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
116 | optimizer.zero_grad()
117 |
118 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
119 |
120 | output, (chx, mhx, rv), v = rnn(input_data, None)
121 |
122 | # Make output and target compatible for loss calculation
123 | # target: [batch, seq, features] -> [seq, batch, features]
124 | target_output = target_output.permute(1, 0, 2).contiguous()
125 |
126 | loss = criterion(output, target_output)
127 | loss.backward()
128 |
129 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
130 | optimizer.step()
131 |
132 | assert target_output.size() == T.Size([27, 10, 100])
133 | assert chx[0].size() == T.Size([num_hidden_layers, 10, 100])
134 | # assert mhx['memory'].size() == T.Size([10,12,17])
135 | assert rv.size() == T.Size([10, 34])
136 |
137 |
138 | def test_rnn_no_memory_pass():
139 | T.manual_seed(1111)
140 |
141 | input_size = 100
142 | hidden_size = 100
143 | rnn_type = "gru"
144 | num_layers = 3
145 | num_hidden_layers = 5
146 | dropout = 0.2
147 | nr_cells = 5000
148 | cell_size = 17
149 | sparse_reads = 3
150 | device = None
151 | debug = True
152 | lr = 0.001
153 | sequence_max_length = 10
154 | batch_size = 10
155 | cuda = device
156 | clip = 20
157 | length = 13
158 |
159 | rnn = SAM(
160 | input_size=input_size,
161 | hidden_size=hidden_size,
162 | rnn_type=rnn_type,
163 | num_layers=num_layers,
164 | num_hidden_layers=num_hidden_layers,
165 | dropout=dropout,
166 | nr_cells=nr_cells,
167 | cell_size=cell_size,
168 | sparse_reads=sparse_reads,
169 | device=device,
170 | debug=debug,
171 | )
172 |
173 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
174 | optimizer.zero_grad()
175 |
176 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
177 |
178 | # Make output and target compatible for loss calculation
179 | # target: [batch, seq, features] -> [seq, batch, features]
180 | target_output = target_output.permute(1, 0, 2).contiguous()
181 |
182 | # Initialize hidden state explicitly
183 | controller_hidden = None
184 | memory_hidden = None
185 | last_read = None
186 | outputs = []
187 |
188 | for x in range(6):
189 | output, (controller_hidden, memory_hidden, last_read), v = rnn(
190 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False
191 | )
192 | outputs.append(output)
193 |
194 | # Sum outputs for all iterations
195 | output = functools.reduce(lambda x, y: x + y, outputs)
196 | loss = criterion(output, target_output)
197 | loss.backward()
198 |
199 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
200 | optimizer.step()
201 |
202 | assert target_output.size() == T.Size([27, 10, 100])
203 | assert controller_hidden[0].size() == T.Size([num_hidden_layers, 10, 100])
204 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17])
205 | assert last_read is not None
206 |
--------------------------------------------------------------------------------
/test/test_sam_rnn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 |
6 | import torch as T
7 | import torch.optim as optim
8 |
9 | import sys
10 | import functools
11 |
12 | sys.path.insert(0, ".")
13 |
14 | from dnc import SAM
15 | from test_utils import generate_data, criterion
16 |
17 |
18 | def test_rnn_1():
19 | T.manual_seed(1111)
20 |
21 | input_size = 100
22 | hidden_size = 100
23 | rnn_type = "rnn"
24 | num_layers = 1
25 | num_hidden_layers = 1
26 | dropout = 0
27 | nr_cells = 100
28 | cell_size = 10
29 | read_heads = 1
30 | sparse_reads = 2
31 | device = None
32 | debug = True
33 | lr = 0.001
34 | sequence_max_length = 10
35 | batch_size = 10
36 | cuda = device
37 | clip = 10
38 | length = 10
39 |
40 | rnn = SAM(
41 | input_size=input_size,
42 | hidden_size=hidden_size,
43 | rnn_type=rnn_type,
44 | num_layers=num_layers,
45 | num_hidden_layers=num_hidden_layers,
46 | dropout=dropout,
47 | nr_cells=nr_cells,
48 | cell_size=cell_size,
49 | read_heads=read_heads,
50 | sparse_reads=sparse_reads,
51 | device=device,
52 | debug=debug,
53 | )
54 |
55 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
56 | optimizer.zero_grad()
57 |
58 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
59 |
60 | output, (chx, mhx, rv), v = rnn(input_data, None)
61 |
62 | # Make output and target compatible for loss calculation
63 | # target: [batch, seq, features] -> [seq, batch, features]
64 | target_output = target_output.permute(1, 0, 2).contiguous()
65 |
66 | loss = criterion(output, target_output)
67 | loss.backward()
68 |
69 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
70 | optimizer.step()
71 |
72 | assert target_output.size() == T.Size([21, 10, 100])
73 | assert chx[0][0].size() == T.Size([10, 100])
74 | # assert mhx['memory'].size() == T.Size([10,1,1])
75 | assert rv.size() == T.Size([10, 10])
76 |
77 |
78 | def test_rnn_n():
79 | T.manual_seed(1111)
80 |
81 | input_size = 100
82 | hidden_size = 100
83 | rnn_type = "rnn"
84 | num_layers = 3
85 | num_hidden_layers = 5
86 | dropout = 0.2
87 | nr_cells = 200
88 | cell_size = 17
89 | read_heads = 2
90 | sparse_reads = 4
91 | device = None
92 | debug = True
93 | lr = 0.001
94 | sequence_max_length = 10
95 | batch_size = 10
96 | cuda = device
97 | clip = 20
98 | length = 13
99 |
100 | rnn = SAM(
101 | input_size=input_size,
102 | hidden_size=hidden_size,
103 | rnn_type=rnn_type,
104 | num_layers=num_layers,
105 | num_hidden_layers=num_hidden_layers,
106 | dropout=dropout,
107 | nr_cells=nr_cells,
108 | cell_size=cell_size,
109 | read_heads=read_heads,
110 | sparse_reads=sparse_reads,
111 | device=device,
112 | debug=debug,
113 | )
114 |
115 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
116 | optimizer.zero_grad()
117 |
118 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
119 |
120 | output, (chx, mhx, rv), v = rnn(input_data, None)
121 |
122 | # Make output and target compatible for loss calculation
123 | # target: [batch, seq, features] -> [seq, batch, features]
124 | target_output = target_output.permute(1, 0, 2).contiguous()
125 |
126 | loss = criterion(output, target_output)
127 | loss.backward()
128 |
129 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
130 | optimizer.step()
131 |
132 | assert target_output.size() == T.Size([27, 10, 100])
133 | assert chx[0].size() == T.Size([num_hidden_layers, 10, 100])
134 | # assert mhx['memory'].size() == T.Size([10,12,17])
135 | assert rv.size() == T.Size([10, 34])
136 |
137 |
138 | def test_rnn_no_memory_pass():
139 | T.manual_seed(1111)
140 |
141 | input_size = 100
142 | hidden_size = 100
143 | rnn_type = "rnn"
144 | num_layers = 3
145 | num_hidden_layers = 5
146 | dropout = 0.2
147 | nr_cells = 5000
148 | cell_size = 17
149 | sparse_reads = 3
150 | device = None
151 | debug = True
152 | lr = 0.001
153 | sequence_max_length = 10
154 | batch_size = 10
155 | cuda = device
156 | clip = 20
157 | length = 13
158 |
159 | rnn = SAM(
160 | input_size=input_size,
161 | hidden_size=hidden_size,
162 | rnn_type=rnn_type,
163 | num_layers=num_layers,
164 | num_hidden_layers=num_hidden_layers,
165 | dropout=dropout,
166 | nr_cells=nr_cells,
167 | cell_size=cell_size,
168 | sparse_reads=sparse_reads,
169 | device=device,
170 | debug=debug,
171 | )
172 |
173 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
174 | optimizer.zero_grad()
175 |
176 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
177 |
178 | # Make output and target compatible for loss calculation
179 | # target: [batch, seq, features] -> [seq, batch, features]
180 | target_output = target_output.permute(1, 0, 2).contiguous()
181 |
182 | # Initialize hidden state explicitly
183 | controller_hidden = None
184 | memory_hidden = None
185 | last_read = None
186 | outputs = []
187 |
188 | for x in range(6):
189 | output, (controller_hidden, memory_hidden, last_read), v = rnn(
190 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False
191 | )
192 | outputs.append(output)
193 |
194 | # Sum outputs for all iterations
195 | output = functools.reduce(lambda x, y: x + y, outputs)
196 | loss = criterion(output, target_output)
197 | loss.backward()
198 |
199 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
200 | optimizer.step()
201 |
202 | assert target_output.size() == T.Size([27, 10, 100])
203 | assert controller_hidden[0].size() == T.Size([num_hidden_layers, 10, 100])
204 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17])
205 | assert last_read is not None
206 |
--------------------------------------------------------------------------------
/test/test_sam_lstm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 |
6 | import torch as T
7 | import torch.optim as optim
8 |
9 | import sys
10 | import functools
11 |
12 | sys.path.insert(0, ".")
13 |
14 | from dnc import SAM
15 | from test_utils import generate_data, criterion
16 |
17 |
18 | def test_rnn_1():
19 | T.manual_seed(1111)
20 |
21 | input_size = 100
22 | hidden_size = 100
23 | rnn_type = "lstm"
24 | num_layers = 1
25 | num_hidden_layers = 1
26 | dropout = 0
27 | nr_cells = 100
28 | cell_size = 10
29 | read_heads = 1
30 | sparse_reads = 2
31 | device = None
32 | debug = True
33 | lr = 0.001
34 | sequence_max_length = 10
35 | batch_size = 10
36 | cuda = device
37 | clip = 10
38 | length = 10
39 |
40 | rnn = SAM(
41 | input_size=input_size,
42 | hidden_size=hidden_size,
43 | rnn_type=rnn_type,
44 | num_layers=num_layers,
45 | num_hidden_layers=num_hidden_layers,
46 | dropout=dropout,
47 | nr_cells=nr_cells,
48 | cell_size=cell_size,
49 | read_heads=read_heads,
50 | sparse_reads=sparse_reads,
51 | device=device,
52 | debug=debug,
53 | )
54 |
55 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
56 | optimizer.zero_grad()
57 |
58 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
59 |
60 | output, (chx, mhx, rv), v = rnn(input_data, None)
61 |
62 | # Make output and target compatible for loss calculation
63 | # target: [batch, seq, features] -> [seq, batch, features]
64 | target_output = target_output.permute(1, 0, 2).contiguous()
65 |
66 | loss = criterion(output, target_output)
67 | loss.backward()
68 |
69 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
70 | optimizer.step()
71 |
72 | assert target_output.size() == T.Size([21, 10, 100])
73 | assert chx[0][0][0].size() == T.Size([10, 100])
74 | # assert mhx['memory'].size() == T.Size([10,1,1])
75 | assert rv.size() == T.Size([10, 10])
76 |
77 |
78 | def test_rnn_n():
79 | T.manual_seed(1111)
80 |
81 | input_size = 100
82 | hidden_size = 100
83 | rnn_type = "lstm"
84 | num_layers = 3
85 | num_hidden_layers = 5
86 | dropout = 0.2
87 | nr_cells = 200
88 | cell_size = 17
89 | read_heads = 2
90 | sparse_reads = 4
91 | device = None
92 | debug = True
93 | lr = 0.001
94 | sequence_max_length = 10
95 | batch_size = 10
96 | cuda = device
97 | clip = 20
98 | length = 13
99 |
100 | rnn = SAM(
101 | input_size=input_size,
102 | hidden_size=hidden_size,
103 | rnn_type=rnn_type,
104 | num_layers=num_layers,
105 | num_hidden_layers=num_hidden_layers,
106 | dropout=dropout,
107 | nr_cells=nr_cells,
108 | cell_size=cell_size,
109 | read_heads=read_heads,
110 | sparse_reads=sparse_reads,
111 | device=device,
112 | debug=debug,
113 | )
114 |
115 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
116 | optimizer.zero_grad()
117 |
118 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
119 |
120 | output, (chx, mhx, rv), v = rnn(input_data, None)
121 |
122 | # Make output and target compatible for loss calculation
123 | # target: [batch, seq, features] -> [seq, batch, features]
124 | target_output = target_output.permute(1, 0, 2).contiguous()
125 |
126 | loss = criterion(output, target_output)
127 | loss.backward()
128 |
129 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
130 | optimizer.step()
131 |
132 | assert target_output.size() == T.Size([27, 10, 100])
133 | assert chx[0][0].size() == T.Size([num_hidden_layers, 10, 100])
134 | # assert mhx['memory'].size() == T.Size([10,12,17])
135 | assert rv.size() == T.Size([10, 34])
136 |
137 |
138 | def test_rnn_no_memory_pass():
139 | T.manual_seed(1111)
140 |
141 | input_size = 100
142 | hidden_size = 100
143 | rnn_type = "lstm"
144 | num_layers = 3
145 | num_hidden_layers = 5
146 | dropout = 0.2
147 | nr_cells = 5000
148 | cell_size = 17
149 | sparse_reads = 3
150 | device = None
151 | debug = True
152 | lr = 0.001
153 | sequence_max_length = 10
154 | batch_size = 10
155 | cuda = device
156 | clip = 20
157 | length = 13
158 |
159 | rnn = SAM(
160 | input_size=input_size,
161 | hidden_size=hidden_size,
162 | rnn_type=rnn_type,
163 | num_layers=num_layers,
164 | num_hidden_layers=num_hidden_layers,
165 | dropout=dropout,
166 | nr_cells=nr_cells,
167 | cell_size=cell_size,
168 | sparse_reads=sparse_reads,
169 | device=device,
170 | debug=debug,
171 | )
172 |
173 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
174 | optimizer.zero_grad()
175 |
176 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
177 |
178 | # Make output and target compatible for loss calculation
179 | # target: [batch, seq, features] -> [seq, batch, features]
180 | target_output = target_output.permute(1, 0, 2).contiguous()
181 |
182 | # Initialize hidden state explicitly
183 | controller_hidden = None
184 | memory_hidden = None
185 | last_read = None
186 | outputs = []
187 |
188 | for x in range(6):
189 | output, (controller_hidden, memory_hidden, last_read), v = rnn(
190 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False
191 | )
192 | outputs.append(output)
193 |
194 | # Sum outputs for all iterations
195 | output = functools.reduce(lambda x, y: x + y, outputs)
196 | loss = criterion(output, target_output)
197 | loss.backward()
198 |
199 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
200 | optimizer.step()
201 |
202 | assert target_output.size() == T.Size([27, 10, 100])
203 | assert controller_hidden[0][0].size() == T.Size([num_hidden_layers, 10, 100])
204 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17])
205 | assert last_read is not None
206 |
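207 | 
208 | # These are plain pytest-style test functions; assuming pytest is installed, run
209 | # this file from the repository root with:  pytest test/test_sam_lstm.py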
--------------------------------------------------------------------------------
/test/test_sdnc_rnn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 |
5 |
6 | import torch as T
7 | import torch.optim as optim
8 |
9 | import sys
10 | import functools
11 |
12 | sys.path.insert(0, ".")
13 |
14 | from dnc import SDNC
15 | from test_utils import generate_data, criterion
16 |
17 |
18 | def test_rnn_1():
19 | T.manual_seed(1111)
20 |
21 | input_size = 100
22 | hidden_size = 100
23 | rnn_type = "rnn"
24 | num_layers = 1
25 | num_hidden_layers = 1
26 | dropout = 0
27 | nr_cells = 100
28 | cell_size = 10
29 | read_heads = 1
30 | sparse_reads = 2
31 | temporal_reads = 1
32 | device = None
33 | debug = True
34 | lr = 0.001
35 | sequence_max_length = 10
36 | batch_size = 10
37 | cuda = device
38 | clip = 10
39 | length = 10
40 |
41 | rnn = SDNC(
42 | input_size=input_size,
43 | hidden_size=hidden_size,
44 | rnn_type=rnn_type,
45 | num_layers=num_layers,
46 | num_hidden_layers=num_hidden_layers,
47 | dropout=dropout,
48 | nr_cells=nr_cells,
49 | cell_size=cell_size,
50 | read_heads=read_heads,
51 | sparse_reads=sparse_reads,
52 | temporal_reads=temporal_reads,
53 | device=device,
54 | debug=debug,
55 | )
56 |
57 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
58 | optimizer.zero_grad()
59 |
60 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
61 |
62 | output, (chx, mhx, rv), v = rnn(input_data, None)
63 |
64 | # Make output and target compatible for loss calculation
65 | # target: [batch, seq, features] -> [seq, batch, features]
66 | target_output = target_output.permute(1, 0, 2).contiguous()
67 |
68 | loss = criterion(output, target_output)
69 | loss.backward()
70 |
71 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
72 | optimizer.step()
73 |
74 | assert target_output.size() == T.Size([21, 10, 100])
75 | assert chx[0][0].size() == T.Size([10, 100])
76 | # assert mhx[0]['memory'].size() == T.Size([10,1,1])
77 | assert rv.size() == T.Size([10, 10])
78 |
79 |
80 | def test_rnn_n():
81 | T.manual_seed(1111)
82 |
83 | input_size = 100
84 | hidden_size = 100
85 | rnn_type = "rnn"
86 | num_layers = 3
87 | num_hidden_layers = 5
88 | dropout = 0.2
89 | nr_cells = 200
90 | cell_size = 17
91 | read_heads = 2
92 | sparse_reads = 4
93 | temporal_reads = 3
94 | device = None
95 | debug = True
96 | lr = 0.001
97 | sequence_max_length = 10
98 | batch_size = 10
99 | cuda = device
100 | clip = 20
101 | length = 13
102 |
103 | rnn = SDNC(
104 | input_size=input_size,
105 | hidden_size=hidden_size,
106 | rnn_type=rnn_type,
107 | num_layers=num_layers,
108 | num_hidden_layers=num_hidden_layers,
109 | dropout=dropout,
110 | nr_cells=nr_cells,
111 | cell_size=cell_size,
112 | read_heads=read_heads,
113 | sparse_reads=sparse_reads,
114 | temporal_reads=temporal_reads,
115 | device=device,
116 | debug=debug,
117 | )
118 |
119 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
120 | optimizer.zero_grad()
121 |
122 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
123 |
124 | output, (chx, mhx, rv), v = rnn(input_data, None)
125 |
126 | # Make output and target compatible for loss calculation
127 | # target: [batch, seq, features] -> [seq, batch, features]
128 | target_output = target_output.permute(1, 0, 2).contiguous()
129 |
130 | loss = criterion(output, target_output)
131 | loss.backward()
132 |
133 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
134 | optimizer.step()
135 |
136 | assert target_output.size() == T.Size([27, 10, 100])
137 | assert chx[0].size() == T.Size([num_hidden_layers, 10, 100])
138 | # assert mhx[0]['memory'].size() == T.Size([10,12,17])
139 | assert rv.size() == T.Size([10, 34])
140 |
141 |
142 | def test_rnn_no_memory_pass():
143 | T.manual_seed(1111)
144 |
145 | input_size = 100
146 | hidden_size = 100
147 | rnn_type = "rnn"
148 | num_layers = 3
149 | num_hidden_layers = 5
150 | dropout = 0.2
151 | nr_cells = 5000
152 | cell_size = 17
153 | sparse_reads = 3
154 | temporal_reads = 4
155 | device = None
156 | debug = True
157 | lr = 0.001
158 | sequence_max_length = 10
159 | batch_size = 10
160 | cuda = device
161 | clip = 20
162 | length = 13
163 |
164 | rnn = SDNC(
165 | input_size=input_size,
166 | hidden_size=hidden_size,
167 | rnn_type=rnn_type,
168 | num_layers=num_layers,
169 | num_hidden_layers=num_hidden_layers,
170 | dropout=dropout,
171 | nr_cells=nr_cells,
172 | cell_size=cell_size,
173 | sparse_reads=sparse_reads,
174 | temporal_reads=temporal_reads,
175 | device=device,
176 | debug=debug,
177 | )
178 |
179 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
180 | optimizer.zero_grad()
181 |
182 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
183 |
184 | # Transform target to match expected output shape
185 | target_output = target_output.permute(1, 0, 2).contiguous()
186 |
187 | # Initialize hidden state explicitly
188 | controller_hidden = None
189 | memory_hidden = None
190 | last_read = None
191 | outputs = []
192 |
193 | for x in range(6):
194 | output, (controller_hidden, memory_hidden, last_read), v = rnn(
195 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False
196 | )
197 | outputs.append(output)
198 |
199 | # Sum outputs for all iterations
200 | output = functools.reduce(lambda x, y: x + y, outputs)
201 | loss = criterion(output, target_output)
202 | loss.backward()
203 |
204 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
205 | optimizer.step()
206 |
207 | assert target_output.size() == T.Size([27, 10, 100])
208 | assert controller_hidden[0].size() == T.Size([num_hidden_layers, 10, 100])
209 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17])
210 | assert last_read is not None
211 |
--------------------------------------------------------------------------------
/test/test_sdnc_gru.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 |
5 |
6 | import torch as T
7 | import torch.optim as optim
8 |
9 | import sys
10 | import functools
11 |
12 | sys.path.insert(0, ".")
13 |
14 | from dnc import SDNC
15 | from test_utils import generate_data, criterion
16 |
17 |
18 | def test_rnn_1():
19 | T.manual_seed(1111)
20 |
21 | input_size = 100
22 | hidden_size = 100
23 | rnn_type = "gru"
24 | num_layers = 1
25 | num_hidden_layers = 1
26 | dropout = 0
27 | nr_cells = 100
28 | cell_size = 10
29 | read_heads = 1
30 | sparse_reads = 2
31 | temporal_reads = 1
32 | device = None
33 | debug = True
34 | lr = 0.001
35 | sequence_max_length = 10
36 | batch_size = 10
37 | cuda = device
38 | clip = 10
39 | length = 10
40 |
41 | rnn = SDNC(
42 | input_size=input_size,
43 | hidden_size=hidden_size,
44 | rnn_type=rnn_type,
45 | num_layers=num_layers,
46 | num_hidden_layers=num_hidden_layers,
47 | dropout=dropout,
48 | nr_cells=nr_cells,
49 | cell_size=cell_size,
50 | read_heads=read_heads,
51 | sparse_reads=sparse_reads,
52 | temporal_reads=temporal_reads,
53 | device=device,
54 | debug=debug,
55 | )
56 |
57 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
58 | optimizer.zero_grad()
59 |
60 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
61 |
62 | output, (chx, mhx, rv), v = rnn(input_data, None)
63 |
64 | # Make output and target compatible for loss calculation
65 | # target: [batch, seq, features] -> [seq, batch, features]
66 | target_output = target_output.permute(1, 0, 2).contiguous()
67 |
68 | loss = criterion(output, target_output)
69 | loss.backward()
70 |
71 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
72 | optimizer.step()
73 |
74 | assert target_output.size() == T.Size([21, 10, 100])
75 | assert chx[0][0].size() == T.Size([10, 100])
76 | # assert mhx['memory'].size() == T.Size([10,1,1])
77 | assert rv.size() == T.Size([10, 10])
78 |
79 |
80 | def test_rnn_n():
81 | T.manual_seed(1111)
82 |
83 | input_size = 100
84 | hidden_size = 100
85 | rnn_type = "gru"
86 | num_layers = 3
87 | num_hidden_layers = 5
88 | dropout = 0.2
89 | nr_cells = 200
90 | cell_size = 17
91 | read_heads = 2
92 | sparse_reads = 4
93 | temporal_reads = 3
94 | device = None
95 | debug = True
96 | lr = 0.001
97 | sequence_max_length = 10
98 | batch_size = 10
99 | cuda = device
100 | clip = 20
101 | length = 13
102 |
103 | rnn = SDNC(
104 | input_size=input_size,
105 | hidden_size=hidden_size,
106 | rnn_type=rnn_type,
107 | num_layers=num_layers,
108 | num_hidden_layers=num_hidden_layers,
109 | dropout=dropout,
110 | nr_cells=nr_cells,
111 | cell_size=cell_size,
112 | read_heads=read_heads,
113 | sparse_reads=sparse_reads,
114 | temporal_reads=temporal_reads,
115 | device=device,
116 | debug=debug,
117 | )
118 |
119 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
120 | optimizer.zero_grad()
121 |
122 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
123 |
124 | output, (chx, mhx, rv), v = rnn(input_data, None)
125 |
126 | # Make output and target compatible for loss calculation
127 | # target: [batch, seq, features] -> [seq, batch, features]
128 | target_output = target_output.permute(1, 0, 2).contiguous()
129 |
130 | loss = criterion(output, target_output)
131 | loss.backward()
132 |
133 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
134 | optimizer.step()
135 |
136 | assert target_output.size() == T.Size([27, 10, 100])
137 | assert chx[0].size() == T.Size([num_hidden_layers, 10, 100])
138 | # assert mhx['memory'].size() == T.Size([10,12,17])
139 | assert rv.size() == T.Size([10, 34])
140 |
141 |
142 | def test_rnn_no_memory_pass():
143 | T.manual_seed(1111)
144 |
145 | input_size = 100
146 | hidden_size = 100
147 | rnn_type = "gru"
148 | num_layers = 3
149 | num_hidden_layers = 5
150 | dropout = 0.2
151 | nr_cells = 5000
152 | cell_size = 17
153 | sparse_reads = 3
154 | temporal_reads = 4
155 | device = None
156 | debug = True
157 | lr = 0.001
158 | sequence_max_length = 10
159 | batch_size = 10
160 | cuda = device
161 | clip = 20
162 | length = 13
163 |
164 | rnn = SDNC(
165 | input_size=input_size,
166 | hidden_size=hidden_size,
167 | rnn_type=rnn_type,
168 | num_layers=num_layers,
169 | num_hidden_layers=num_hidden_layers,
170 | dropout=dropout,
171 | nr_cells=nr_cells,
172 | cell_size=cell_size,
173 | sparse_reads=sparse_reads,
174 | temporal_reads=temporal_reads,
175 | device=device,
176 | debug=debug,
177 | )
178 |
179 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
180 | optimizer.zero_grad()
181 |
182 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
183 |
184 | # Make output and target compatible for loss calculation
185 | # target: [batch, seq, features] -> [seq, batch, features]
186 | target_output = target_output.permute(1, 0, 2).contiguous()
187 |
188 | # Initialize hidden state explicitly
189 | controller_hidden = None
190 | memory_hidden = None
191 | last_read = None
192 | outputs = []
193 |
194 | for x in range(6):
195 | output, (controller_hidden, memory_hidden, last_read), v = rnn(
196 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False
197 | )
198 | outputs.append(output)
199 |
200 | # Sum outputs for all iterations
201 | output = functools.reduce(lambda x, y: x + y, outputs)
202 | loss = criterion(output, target_output)
203 | loss.backward()
204 |
205 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
206 | optimizer.step()
207 |
208 | assert target_output.size() == T.Size([27, 10, 100])
209 | assert controller_hidden[0].size() == T.Size([num_hidden_layers, 10, 100])
210 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17])
211 | assert last_read is not None
212 |
--------------------------------------------------------------------------------
/test/test_sdnc_lstm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 |
5 |
6 | import torch as T
7 | import torch.optim as optim
8 |
9 | import sys
10 | import functools
11 |
12 | sys.path.insert(0, ".")
13 |
14 | from dnc import SDNC
15 | from test_utils import generate_data, criterion
16 |
17 |
18 | def test_rnn_1():
19 | T.manual_seed(1111)
20 |
21 | input_size = 100
22 | hidden_size = 100
23 | rnn_type = "lstm"
24 | num_layers = 1
25 | num_hidden_layers = 1
26 | dropout = 0
27 | nr_cells = 100
28 | cell_size = 10
29 | read_heads = 1
30 | sparse_reads = 2
31 | temporal_reads = 1
32 | device = None
33 | debug = True
34 | lr = 0.001
35 | sequence_max_length = 10
36 | batch_size = 10
37 | cuda = device
38 | clip = 10
39 | length = 10
40 |
41 | rnn = SDNC(
42 | input_size=input_size,
43 | hidden_size=hidden_size,
44 | rnn_type=rnn_type,
45 | num_layers=num_layers,
46 | num_hidden_layers=num_hidden_layers,
47 | dropout=dropout,
48 | nr_cells=nr_cells,
49 | cell_size=cell_size,
50 | read_heads=read_heads,
51 | sparse_reads=sparse_reads,
52 | temporal_reads=temporal_reads,
53 | device=device,
54 | debug=debug,
55 | )
56 |
57 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
58 | optimizer.zero_grad()
59 |
60 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
61 |
62 | output, (chx, mhx, rv), v = rnn(input_data, None)
63 |
64 | # Make output and target compatible for loss calculation
65 | # target: [batch, seq, features] -> [seq, batch, features]
66 | target_output = target_output.permute(1, 0, 2).contiguous()
67 |
68 | loss = criterion(output, target_output)
69 | loss.backward()
70 |
71 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
72 | optimizer.step()
73 |
74 | assert target_output.size() == T.Size([21, 10, 100])
75 | assert chx[0][0][0].size() == T.Size([10, 100])
76 | # assert mhx['memory'].size() == T.Size([10,1,1])
77 | assert rv.size() == T.Size([10, 10])
78 |
79 |
80 | def test_rnn_n():
81 | T.manual_seed(1111)
82 |
83 | input_size = 100
84 | hidden_size = 100
85 | rnn_type = "lstm"
86 | num_layers = 3
87 | num_hidden_layers = 5
88 | dropout = 0.2
89 | nr_cells = 200
90 | cell_size = 17
91 | read_heads = 2
92 | sparse_reads = 4
93 | temporal_reads = 3
94 | device = None
95 | debug = True
96 | lr = 0.001
97 | sequence_max_length = 10
98 | batch_size = 10
99 | cuda = device
100 | clip = 20
101 | length = 13
102 |
103 | rnn = SDNC(
104 | input_size=input_size,
105 | hidden_size=hidden_size,
106 | rnn_type=rnn_type,
107 | num_layers=num_layers,
108 | num_hidden_layers=num_hidden_layers,
109 | dropout=dropout,
110 | nr_cells=nr_cells,
111 | cell_size=cell_size,
112 | read_heads=read_heads,
113 | sparse_reads=sparse_reads,
114 | temporal_reads=temporal_reads,
115 | device=device,
116 | debug=debug,
117 | )
118 |
119 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
120 | optimizer.zero_grad()
121 |
122 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
123 |
124 | output, (chx, mhx, rv), v = rnn(input_data, None)
125 |
126 | # Make output and target compatible for loss calculation
127 | # target: [batch, seq, features] -> [seq, batch, features]
128 | target_output = target_output.permute(1, 0, 2).contiguous()
129 |
130 | loss = criterion(output, target_output)
131 | loss.backward()
132 |
133 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
134 | optimizer.step()
135 |
136 | assert target_output.size() == T.Size([27, 10, 100])
137 | assert chx[0][0].size() == T.Size([num_hidden_layers, 10, 100])
138 | # assert mhx['memory'].size() == T.Size([10,12,17])
139 | assert rv.size() == T.Size([10, 34])
140 |
141 |
142 | def test_rnn_no_memory_pass():
143 | T.manual_seed(1111)
144 |
145 | input_size = 100
146 | hidden_size = 100
147 | rnn_type = "lstm"
148 | num_layers = 3
149 | num_hidden_layers = 5
150 | dropout = 0.2
151 | nr_cells = 5000
152 | cell_size = 17
153 | sparse_reads = 3
154 | temporal_reads = 4
155 | device = None
156 | debug = True
157 | lr = 0.001
158 | sequence_max_length = 10
159 | batch_size = 10
160 | cuda = device
161 | clip = 20
162 | length = 13
163 |
164 | rnn = SDNC(
165 | input_size=input_size,
166 | hidden_size=hidden_size,
167 | rnn_type=rnn_type,
168 | num_layers=num_layers,
169 | num_hidden_layers=num_hidden_layers,
170 | dropout=dropout,
171 | nr_cells=nr_cells,
172 | cell_size=cell_size,
173 | sparse_reads=sparse_reads,
174 | temporal_reads=temporal_reads,
175 | device=device,
176 | debug=debug,
177 | )
178 |
179 | optimizer = optim.Adam(rnn.parameters(), lr=lr)
180 | optimizer.zero_grad()
181 |
182 | input_data, target_output = generate_data(batch_size, length, input_size, cuda)
183 |
184 | # Make output and target compatible for loss calculation
185 | # target: [batch, seq, features] -> [seq, batch, features]
186 | target_output = target_output.permute(1, 0, 2).contiguous()
187 |
188 | # Initialize hidden state explicitly
189 | controller_hidden = None
190 | memory_hidden = None
191 | last_read = None
192 | outputs = []
193 |
194 | for x in range(6):
195 | output, (controller_hidden, memory_hidden, last_read), v = rnn(
196 | input_data, (controller_hidden, memory_hidden, last_read), pass_through_memory=False
197 | )
198 | outputs.append(output)
199 |
200 | # Sum outputs for all iterations
201 | output = functools.reduce(lambda x, y: x + y, outputs)
202 | loss = criterion(output, target_output)
203 | loss.backward()
204 |
205 | T.nn.utils.clip_grad_norm_(rnn.parameters(), clip)
206 | optimizer.step()
207 |
208 | assert target_output.size() == T.Size([27, 10, 100])
209 | assert controller_hidden[0][0].size() == T.Size([num_hidden_layers, 10, 100])
210 | # assert memory_hidden[0]['memory'].size() == T.Size([10,12,17])
211 | assert last_read is not None
212 |
--------------------------------------------------------------------------------
/dnc/sdnc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import torch
5 |
6 | from .dnc import DNC
7 | from .sparse_temporal_memory import SparseTemporalMemory
8 |
9 |
10 | class SDNC(DNC):
11 | """Sparse Differentiable Neural Computer (SDNC) module."""
12 |
13 | def __init__(
14 | self,
15 | input_size: int,
16 | hidden_size: int,
17 | rnn_type: str = "lstm",
18 | num_layers: int = 1,
19 | num_hidden_layers: int = 2,
20 | bias: bool = True,
21 | batch_first: bool = True,
22 | dropout: float = 0,
23 | bidirectional: bool = False,
24 | nr_cells: int = 5000,
25 | sparse_reads: int = 4,
26 | temporal_reads: int = 4,
27 | read_heads: int = 4,
28 | cell_size: int = 10,
29 | nonlinearity: str = "tanh",
30 | independent_linears: bool = False,
31 | share_memory: bool = True,
32 | debug: bool = False,
33 | clip: float = 20,
34 | device: torch.device | None = None,
35 | ):
36 |         """Initialize the SDNC.
37 | 
38 | Args:
39 | input_size: Input size.
40 | hidden_size: Hidden size.
41 | rnn_type: Type of RNN cell (lstm, gru, rnn).
42 | num_layers: Number of RNN layers.
43 | num_hidden_layers: Number of hidden layers in each RNN.
44 | bias: Whether to use bias in the RNN.
45 | batch_first: Whether the input is batch-first.
46 | dropout: Dropout rate.
47 | bidirectional: Whether the RNN is bidirectional.
48 | nr_cells: Number of memory cells.
49 | sparse_reads: Number of sparse reads.
50 | temporal_reads: Number of temporal reads.
51 | read_heads: Number of read heads.
52 | cell_size: Size of each memory cell.
53 | nonlinearity: Nonlinearity for RNN ('tanh' or 'relu').
54 | independent_linears: Whether to use independent linear layers in memory.
55 | share_memory: Whether to share memory across layers.
56 | debug: Whether to enable debug mode.
57 | clip: Value to clip controller output.
58 |             device: The device to use.
59 | """
60 | super(SDNC, self).__init__(
61 | input_size=input_size,
62 | hidden_size=hidden_size,
63 | rnn_type=rnn_type,
64 | num_layers=num_layers,
65 | num_hidden_layers=num_hidden_layers,
66 | bias=bias,
67 | batch_first=batch_first,
68 | dropout=dropout,
69 | nr_cells=nr_cells,
70 | read_heads=read_heads,
71 | cell_size=cell_size,
72 | nonlinearity=nonlinearity,
73 | independent_linears=independent_linears,
74 | share_memory_between_layers=share_memory,
75 | debug=debug,
76 | clip=clip,
77 | device=device,
78 | )
79 |
80 | self.sparse_reads = sparse_reads
81 | self.temporal_reads = temporal_reads
82 | self.device = device
83 |
84 | self.memories = []
85 |
86 | for layer in range(self.num_layers):
87 | # memories for each layer
88 | if not self.share_memory_between_layers:
89 | self.memories.append(
90 | SparseTemporalMemory(
91 | input_size=self.output_size,
92 | mem_size=self.nr_cells,
93 | cell_size=self.w,
94 | sparse_reads=self.sparse_reads,
95 | read_heads=self.read_heads,
96 | temporal_reads=self.temporal_reads,
97 | device=self.device,
98 | independent_linears=self.independent_linears,
99 | )
100 | )
101 | setattr(self, "rnn_layer_memory_" + str(layer), self.memories[layer])
102 |
103 | # only one memory shared by all layers
104 | if self.share_memory_between_layers:
105 | self.memories.append(
106 | SparseTemporalMemory(
107 | input_size=self.output_size,
108 | mem_size=self.nr_cells,
109 | cell_size=self.w,
110 | sparse_reads=self.sparse_reads,
111 | read_heads=self.read_heads,
112 | temporal_reads=self.temporal_reads,
113 | device=self.device,
114 | independent_linears=self.independent_linears,
115 | )
116 | )
117 | setattr(self, "rnn_layer_memory_shared", self.memories[0])
118 |
119 | def _debug(self, mhx: dict, debug_obj: dict | None) -> dict | None:
120 | """Debug function to collect memory information.
121 |
122 | Args:
123 | mhx: Memory hidden state.
124 | debug_obj: Debug object to store information.
125 |
126 | Returns:
127 | Updated debug object or None.
128 | """
129 | if not debug_obj:
130 | debug_obj = {
131 | "memory": [],
132 | "visible_memory": [],
133 | "link_matrix": [],
134 | "rev_link_matrix": [],
135 | "precedence": [],
136 | "read_weights": [],
137 | "write_weights": [],
138 | "read_vectors": [],
139 | "least_used_mem": [],
140 | "usage": [],
141 | "read_positions": [],
142 | }
143 |
144 | debug_obj["memory"].append(mhx["memory"][0].detach().cpu().numpy())
145 | debug_obj["visible_memory"].append(mhx["visible_memory"][0].detach().cpu().numpy())
146 | debug_obj["link_matrix"].append(mhx["link_matrix"][0].detach().cpu().numpy())
147 | debug_obj["rev_link_matrix"].append(mhx["rev_link_matrix"][0].detach().cpu().numpy())
148 | debug_obj["precedence"].append(mhx["precedence"][0].unsqueeze(0).detach().cpu().numpy())
149 | debug_obj["read_weights"].append(mhx["read_weights"][0].unsqueeze(0).detach().cpu().numpy())
150 | debug_obj["write_weights"].append(mhx["write_weights"][0].unsqueeze(0).detach().cpu().numpy())
151 | debug_obj["read_vectors"].append(mhx["read_vectors"][0].detach().cpu().numpy())
152 | debug_obj["least_used_mem"].append(mhx["least_used_mem"][0].unsqueeze(0).detach().cpu().numpy())
153 | debug_obj["usage"].append(mhx["usage"][0].unsqueeze(0).detach().cpu().numpy())
154 | debug_obj["read_positions"].append(mhx["read_positions"][0].unsqueeze(0).detach().cpu().numpy())
155 |
156 | return debug_obj
157 |
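158 | 
159 | # Minimal smoke-test sketch: it mirrors the constructor arguments and forward
160 | # signature exercised in test/test_sdnc_rnn.py, and assumes the same environment
161 | # the tests run in (e.g. the sparse-index backend such as FAISS is available).
162 | if __name__ == "__main__":
163 |     sdnc = SDNC(
164 |         input_size=16, hidden_size=16, nr_cells=50, cell_size=8,
165 |         read_heads=1, sparse_reads=2, temporal_reads=1, debug=True,
166 |     )
167 |     x = torch.randn(4, 7, 16)  # (batch, seq, features); batch_first defaults to True
168 |     output, (chx, mhx, rv), v = sdnc(x, None)  # debug=True also returns the debug dict
169 |     print(output.shape, rv.shape)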
--------------------------------------------------------------------------------
/dnc/util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from typing import Callable
8 |
9 | δ = 1e-6
10 |
11 |
12 | def recursiveTrace(obj: torch.Tensor | torch.nn.Module | None) -> None:
13 | """Recursively traces the computational graph of a tensor or module.
14 |
15 | Args:
16 | obj: The tensor or module to trace.
17 | """
18 | if obj is None:
19 | return
20 |
21 | print(type(obj))
22 | if hasattr(obj, "grad_fn"):
23 | print(obj.grad_fn)
24 | recursiveTrace(obj.grad_fn) # type: ignore
25 | elif hasattr(obj, "next_functions"):
26 | print(obj.requires_grad, len(obj.next_functions)) # type: ignore
27 | for f, _ in obj.next_functions: # type: ignore
28 | recursiveTrace(f)
29 |
30 |
31 | def cuda(x: torch.Tensor, requires_grad: bool = False, device: torch.device | None = None) -> torch.Tensor:
32 | """Moves a tensor to the specified device (CPU or GPU).
33 |
34 | Args:
35 | x: The tensor to move.
36 | requires_grad: Whether the tensor should require gradients.
37 | device: The device to move the tensor to. Defaults to CPU.
38 |
39 | Returns:
40 | The tensor on the specified device.
41 | """
42 | if device is None:
43 | return x.float().requires_grad_(requires_grad)
44 | else:
45 | return x.float().to(device).requires_grad_(requires_grad)
46 |
47 |
48 | def cudavec(x: np.ndarray, requires_grad: bool = False, device: torch.device | None = None) -> torch.Tensor:
49 | """Creates a tensor from a NumPy array and moves it to the specified device.
50 |
51 | Args:
52 | x: The NumPy array.
53 | requires_grad: Whether the tensor should require gradients.
54 |         device: The device. Defaults to CPU.
55 |
56 | Returns:
57 | The tensor on the specified device.
58 | """
59 | return cuda(torch.Tensor(x), requires_grad, device)
60 |
61 |
62 | def cudalong(x: np.ndarray, requires_grad: bool = False, device: torch.device | None = None) -> torch.Tensor:
63 | """Creates a LongTensor from a NumPy array and moves it to the specified device.
64 |
65 | Args:
66 | x: The NumPy array.
67 | requires_grad: Whether the tensor should require gradients.
68 |         device: The device. Defaults to CPU.
69 |
70 | Returns:
71 | The LongTensor on the specified device.
72 | """
73 | return cuda(torch.LongTensor(x.astype(np.int64)), requires_grad, device)
74 |
75 |
76 | def θ(a: torch.Tensor, b: torch.Tensor, norm_by: int = 2) -> torch.Tensor:
77 | """Calculates the batchwise cosine similarity between two tensors.
78 |
79 | Args:
80 | a: A 3D tensor (b * m * w).
81 | b: A 3D tensor (b * r * w).
82 | norm_by: The norm to use for normalization.
83 |
84 | Returns:
85 | The batchwise cosine similarity (b * r * m).
86 | """
87 | dot = torch.bmm(a, b.transpose(1, 2))
88 | a_norm = torch.norm(a, p=norm_by, dim=2).unsqueeze(2)
89 | b_norm = torch.norm(b, p=norm_by, dim=2).unsqueeze(1)
90 | cos = dot / (a_norm * b_norm + δ)
91 | return cos.transpose(1, 2).contiguous()
92 |
93 |
94 | def σ(input: torch.Tensor, axis: int = 1) -> torch.Tensor: # NOQA
95 | """Applies the softmax function along a specified axis.
96 |
97 | Args:
98 | input: The input tensor.
99 | axis: The axis along which to apply softmax.
100 |
101 | Returns:
102 | The softmax output.
103 | """
104 | return F.softmax(input, dim=axis)
105 |
106 |
107 | def register_nan_checks(model: nn.Module) -> None:
108 | """Registers backward hooks to check for NaN gradients.
109 |
110 | Args:
111 | model: The model to register hooks on.
112 | """
113 |
114 | def check_grad(
115 | module: nn.Module, grad_input: tuple[torch.Tensor | None, ...], grad_output: tuple[torch.Tensor | None, ...]
116 | ) -> None:
117 | if any(torch.isnan(gi).any() for gi in grad_input if gi is not None):
118 | print(f"NaN gradient in grad_input of {type(module).__name__}")
119 |
120 | for module in model.modules():
121 | module.register_full_backward_hook(check_grad) # type: ignore
122 |
123 |
124 | def apply_dict(dic: dict) -> None:
125 | """Applies gradient NaN checks to a dictionary of variables.
126 |
127 | Args:
128 | dic: The dictionary.
129 | """
130 | for k, v in dic.items():
131 | apply_var(v, k)
132 | if isinstance(v, nn.Module):
133 | key_list = [a for a in dir(v) if not a.startswith("__")]
134 | for key in key_list:
135 | apply_var(getattr(v, key), key)
136 | for pk, pv in v.named_parameters():
137 | apply_var(pv, pk)
138 |
139 |
140 | def apply_var(v: torch.Tensor | nn.Module | None, k: str) -> None:
141 | """Applies gradient NaN checks to a variable.
142 |
143 | Args:
144 | v: The variable.
145 | k: The name of the variable.
146 | """
147 | if isinstance(v, torch.Tensor) and v.requires_grad:
148 | v.register_hook(check_nan_gradient(k))
149 |
150 |
151 | def check_nan_gradient(name: str = "") -> Callable[[torch.Tensor], torch.Tensor | None]:
152 | """Creates a hook to check for NaN gradients.
153 |
154 | Args:
155 | name: The name of the variable.
156 |
157 | Returns:
158 | The hook function.
159 | """
160 |
161 | def f(tensor: torch.Tensor) -> torch.Tensor | None:
162 | if torch.isnan(tensor).any():
163 | print(f"\nnan gradient of {name}:")
164 | return tensor
165 | return None
166 |
167 | return f
168 |
169 |
170 | def ptr(tensor: torch.Tensor) -> int:
171 | """Returns the memory address of a tensor.
172 |
173 | Args:
174 | tensor: The tensor.
175 |
176 | Returns:
177 | The memory address.
178 | """
179 | return tensor.data_ptr()
180 |
181 |
182 | def ensure_gpu(tensor: torch.Tensor | np.ndarray, device: torch.device | None) -> torch.Tensor:
183 |     """Ensures a tensor is on the specified device, converting NumPy arrays to tensors.
184 | 
185 |     Args:
186 |         tensor: The tensor or NumPy array.
187 |         device: The device; if None, the input is returned unchanged (arrays become CPU tensors).
188 | 
189 |     Returns:
190 |         The tensor on the specified device.
191 | """
192 | if isinstance(tensor, torch.Tensor) and device is not None:
193 | return tensor.to(device)
194 | elif isinstance(tensor, np.ndarray) and device is not None:
195 | return torch.tensor(tensor, device=device)
196 | elif isinstance(tensor, np.ndarray):
197 | return torch.Tensor(tensor)
198 | else:
199 | return tensor
200 |
201 |
202 | def print_gradient(x: torch.Tensor, name: str) -> None:
203 |     """Registers a hook that prints a tensor's gradient when it is computed.
204 |
205 | Args:
206 | x: The tensor.
207 |         name: The name of the tensor.
208 | """
209 | s = "Gradient of " + name + " ----------------------------------"
210 | x.register_hook(lambda y: print(s, y.squeeze()))
211 |
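212 | 
213 | # Shape sketch for θ, the batchwise cosine similarity used for content addressing;
214 | # shapes follow the docstring above: (b, m, w) x (b, r, w) -> (b, r, m).
215 | if __name__ == "__main__":
216 |     a = torch.randn(2, 5, 8)  # (b, m, w): 5 memory rows of width 8
217 |     b = torch.randn(2, 3, 8)  # (b, r, w): 3 read keys of width 8
218 |     print(θ(a, b).shape)  # torch.Size([2, 3, 5]): each key scored against each row
219 | 
220 |     # register_nan_checks attaches full backward hooks that warn on NaN gradients:
221 |     net = nn.Linear(8, 8)
222 |     register_nan_checks(net)
223 |     net(a.reshape(-1, 8)).sum().backward()  # prints nothing unless a NaN gradient appears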
--------------------------------------------------------------------------------
/docs/dnc.xml:
--------------------------------------------------------------------------------
1 | 7T1Ze6LKtr8mj2d/MhY8MghomByiwMv5BAoEFZQZfv2tUtOddLLP6Xt2x6TTsdMKi1UFtea1qoA7Sjp0arE5bo08hPs7chR2d5R8R5IUyVHoB0P6C4QYcdwFEhdJeIGNvgMWyQAfEa/QOglh+QyxyvN9lRyvQOICDPIsg0H1DLYpirx9jhbl+/AZ4LiJ4bPeMWARbPbwBdo6CavtBcqR4Dtcg0m8fTwzwfKXI/4m2MVFXmfX892RVHT+XA4fNo99Xc9bbjdh3j4BUeM7SiryvLpsHToJ7jFxn5NN+Zuj3667gFn1Mw2Y63U0m30NHy/5fGFV/0iMcFNuIcYf3VHitjrs0SaBNtG1HzHKoYuxGPy1aUvqLxiQ/w7yw7GuIPrNqk2SweLfaFxilOz3Ur7Pi3OvlMJwDEUjOGocJuhyH49leYZ6FVEnSfB4pqrId/BJ49H5gztFp7iKD8Fe95/gEQL+h+DXUcKigt3fkor4xgAk2TA/wKroEcq1wb8Ay1zaXKWaBexlv/0uIuToStDtM/G4AjdXsYy/df6dNWjjyp3XOUXzX5z6aU6BHzjF3pRT7Ben/ledokc35RTzxan/Vaco/pac+uYvvzMGhshlX3fzotrmcZ5t9uPvUDGoi+bMOkzFs1P+xsin9H3CVNgllYNR/mKue+6TIzYsEnTlsLh2gkhX9E/Q8a779NiPDb7JEu4xhVXVX7m3qascgb4PQ8/z49+xnxckICrfjjwGKMz/WyAwBf+zOCCC53URXLGuDK82RQwfsfjXpaaA+02VNM+7/ycCQFEvVHVdJBW8qVgQT4Ri9BfJ/Fe5eCYVlwafUAb+xnL8chn48qs/b625H6w1d0trTb1kzJta62dqyTzTwL/AH2GaH7n2VC/ZW5lm8EIvkboVEOe78JBfO3tH1/0f3PQnZPvNzDHxE7WDR7ubHM4VF3FTHi/1mijpMN3F6wE53FSbO0q47JLKMYvvSClZida8Hd2rcS6gj7l42I4fYrQVG+jLGEuCi+HJaeH5eEPUTGmxmk0kIZ5EwnaXnIH70Xy1HT2Q/CHUwm1weBA8kmn89UMVkGbvrVd1QG6b0OLkZqBRgx4sHubiSksCEHpiMDPFsbAUNXedq/dVNEuCmZGXnWQ3iXZ8mOb3VmSu0CVHxTByyHA0GERkCi0avUJre0/yumEoeyYUOeGOFHVbo2aBEGv1VJhJk3FrCGMpUAVJtGRBFDIxFmI57tBldy4aKV3O0EAXnw13f9+6wmTMPSIIljwWBE1BDJMThCB0k50oTOKl0Aqz6RmX/ry4aqed6bP9UDx6D9xyFWOi0IgoAm0glGA9EQRLRwSUx8lsLCSi8a27ZfxIwC/cL9wv3D8G1zxMhJn80WzX7XGniwr7F/t38HFviiu76dl3th+PR1+4X7jvjauLHrar8Hew7V+4N8CVWm2J7er+N7Dtb4tbKjusR8xvo8tvgEvp7R2p2FRDSTM5nmaNpbOirEJRWE6pApdQhPFeWe4W9ewgSb+oQE3xzwvUgH9RnyYo5mV9mhyRv6BiSb+oXM3hJnzHojV4VrUmnpUpnxz8bFVr4h2r1txX1foDsZ0gb1W1Jt6tah1x6EtnLlVrdKHiTBqh719WuRbH9PfKdWr3pEPNgqgDIKOG03K5VMPkgK7yJLvHSinmymK9e3g4oSsQ1bEG2aCuI7pkTYQSU3UNDL8ip8r8Ptlv9mo7F8fL02LWjfq5Oc0OEBLtuPRnZrhlrAyTv4Y+DxPQoP4iOaMyG/IeAmelhr6pqdV48WFJCsmpAQxB8BTNNLpqeyNai83joK0N1oE8F3gxopLKOpgwppDasTyboMGNJOMzQGSD9Q9NI69A2/voICKNE9UDw7an6IGG7ThiLDuyJsPe1holLMBwWGpq4O14oQf3dDR2Ow4OK74PIjy/gPy4jU8hCinmmjCbfEH/IVQc1C3HDDzbJs66LPuw3CDzIWakFy2pTd0ByyLVtWWUiz4Loo0Z95BMDw65FUtARZS8IsOGNJbYzIjAzgrpyDg947BFU9F0pzfYUIgzFdsrdPYwUqmocbCcwNLbnaEiXUVqGDpHfjdU9ZiS9XDXzJnUMQgYOCZu334cif4YkMHKODg98O68KYAw9kneZPaFWnBEz8Nu0HnYrlSi0KITsJ0SiFsegIqj9lGtrpGpURreBiDgUQiopMQ8Ivn5NKVCswubMbKH4sDZDfUAtzhOLpWwQoGp4tv6esIBdHRuCBU54wsx4kHTS0cUX5u99mEk+nNAxQE2hR4NJrKcHsTeltGxdrLDHLQTf03OLJKPpIjqEu3E74qqsfN6Fqa6bjq+aoTuEGZ0YFe86fZpjZobTqYMmPW7DyXJHwUiG5uIWiikh6zZnD0URLxm68ib85uL6aIYkTSYgWC7LLJCI/A6Pj7RxiTsRth3RzoVs4ImLGnkzw04sQVZS+RW6P6I/U2k4fn/aI4J5ZDxPVWBnQwCZ0tpQcXwKEDXVcPDsRfn21NtyJy1gbJw0YHWfeg01Lbhc6/SdUPjJto6WpHw2LFhhzAUFre29Rz7l7qU2gVJwyzmh5hr+JHapsgiKbxgSPbSbsXJ2OeEhYjdjTCO582fC2GaRbcELQM54+BkDcdb7ZQkjpnMp5aReURma3DJAK2sQTa3ms4yvT22PFBcA9UyNqxDLqoa2DMdRbiK2pRDWNKncwiH/0q9lDyU0BwCpC8Mk9rYlQhwGLOYHwSUBo1c1GQJGNqjhyFN6Bk18QfX7Si9wS6dC/NIUPWx1opJJ3KCW87+gP2phZIBatTwBJ9LJoA7rufTFfT8RithD6ytwKJ93tX8QN5C5BhEO2RRXOTXw6hPG8rDmYfbafcMjrckbY38gUJ2GdrxgG376cB3XAjJAivFIlyDVt5KIcJVONunyfsIZjKVdtXQfwwv9zmgyFdnIot99EZdbgbdyhqpO/qeMGxIzJpqqHF4RBRsNVgThnB6KGmgKThS5S7qxHR83voWNp+HEnOrx3qLNmpgMaRnk6YPs/IBWOigyOugJeUy4tHx6ZZhO9SJYoX+VcsmVJOhaMAuyAdgZtQ0dd6fQp8Sqpks5hCbkvOSLBdnTSyWXOjtkQ/cjgmeqTOnq/cj22lUb0/ztlzXAzCdVl15aR12SOkJyG3MzAHMBPdlUoVkNj4lk2bFhM28YJqDRnWMTflZgMUBgCii2tpReMtgoTCP9tAKOLkJan44tBU3EwSbs7kZFFR7rMVi2om04H7ufckdDUFk40h58K2yDUt/OuLdsa2T5opxsgIcHC1rYOMZgzc4R5RqznHWKcMQV79EmyMOGZmifBYRGI5AxvQor7J1astAayjs8lzkUVrZTeZRFGlusGxqZmZFg7uKIhPEJsNFWrk513AQos4Fx0AW5CD5cBL7m0JFlLeIzRFHmFpBbg9EMMzZtozsUnD6WUWwrM14+2xJpBYR6apcsnLVq8xyUGGphZZDEZ2HY/RmPV4esT3e8/ncBJsdvp/mbDlHAfLDc2RQRTJlhxR1sG1VDjhu1697QHSN7X4Qab/h/jRdNl6YLfl
UhSNsnZxS0NaHAHmjDb/imJQMVY9qUi8GnMjdz3CWKe5xWJ/v0eb6gHOpWZFhP6VZEXDqE9S4ZaFmJ2DFFA7o1RwWOE7EbtC3KZEqEBfDiR/vcc7VNu4YHZgKM+wjReuDZT032Vcz7L6L0OcALCe4YhYhWhwUcm6R4Yo56CieT3O677bx3TlRdcoGx41mn6VEavQXm4T+qp52mpEdS7DTKtg5jTSAaHnUsVfJgED6a5sv8LmUqDi4iHMsmyoUL2hpm9A4BchCmwBihsu4QTOoSpTIWCp2hgTYGm6sDyW7N/M8OFhjg6vvnrYcdMi5TQtaDYwECz6zstN6ymQDn+KaQzPvdHRsYYtpVfSUYy9lxIXeufDoAfPBsgdSXOHUVyMn/LBn2zxaJ4AZR5wV+wM9yOfqHpaKlbfoM3s12taCkxmeh7KKxOf4PzGz0rGFAJJizgVWcaX6RvP41OM91f99Iv9XTOK+vCXwayL/PWZ0yVdmdG91azBNviIEJboMNI7uiAgNs+A97xP+TxLxCe4IfpX3o3/I+3NToSg2/ROEY55kVfmkZxsDnhgjgvxhUdF1bl/52QYEN/pB8i7X8F0Ovw3m5xYbvBTNWy02CM6LDXBR7LzcQNar46W08uuWG8jz78sNEr44+M4siLZVwDNGxwCgpaWgeBNpsktnwhSfWorjWYJC2cWDEnLiKhDGEz9HqY34MFaUmfvwIMCpvBPihdnP55Nd0GmI3+lADMPJjccugWuq90ul0OXeNnWc/4wagGMxCCMH50TpJWwYaXJJd0NKWkKGdmW+5wK9Ytv90vFxnbDKUtnk4Ny11+oJh9q6wp/MPOsDVZjSC262EBZyK07EqS1M0vG4FZLfFcba94yNUwsalsZDivSfFLXThUxMZhXeCNEp1VQe5qtousBxF6k2FEM0jXMqMZ2yZejoTFpwbgOcaTWk8pi0Agb6rSBAoXj3Ed6EipoO7VXYHCzbsgK261eZqlqGazHEwWICsCQ8+2DlwZ5/WB7Tg03MmgNZ53xpKgAssRayhnaSV3ZgF+3gF6ADVu5xTJdqYs2LS5ftt9s0xfaEp8OrhK72FzaxCVtmS9uGCYDFdNQR/DoZZHrhN0JJ1n0ucBpv28I4Hu84ITFmS0HSOzSEyWeC1ceh0diSGmyN3uTdwLvJdCg0By59rPimRrDxfgDOoqE2NDTMvYUV3XVJZsAV5/vRmZbi1KjSRqfrpYZNaN1essFRErloayCctaSjrF3scOEZ55V2E4KZj1NCyU4jn1svibjILJxF5ko03fL4pk04/kO0wNY7XPnrKWJeH6C9POKJK3t7JJjQXlvQpbsOESjJq2aF73Xe4BIuWGunPjR4KGd2wbBtunBCIkCa1Eyb8EGfptVaZlI5YmjkDXaZMF4m0kyaQ2EqTKefESbPGeg5dsGzbTejdII0OdIqVhyyw0iel0wdDqmuQsvQdbZHkrxmRpHD0PcrJPOLY3qsUn42wzI63tiygJQfTtApFFrYCbssFmbJIhDm1meDiSOIV7wpdlQwjH0oAqTlLVZa89TyZM3XFjGzDrVPDDVVSCslakCk4vmjk5LuqKbiSrehCJJ33cJngYb74hVAqTbacIG9tm0zLrpoZNkeA1UGNvsd725OBV7fwwx+4KkpR6ir9kHIYvDu1vBtYXIe1NwSsAAaxnrJkScuimyD3WbcQUu5vqRO0srCixEU5yRFdhmuM7V2CmFuBF1XZNNkwWIz0eDln0VFt92qtLPtyG6ak0cskaMMAtrfQ7SxRRGYVIt4ogH/r/PpA0M21ArR3ivSGEfj4kmlRtgm+0HOVKk85SthISyXrTTpREOYNGN0+TGHYIvPBNN3fJ3yD/6xiHkUetErHLi5rNiBbB1QgBzImrIybH+J05EzS+9k6gGVta1vFeZo6KjaADTYhs36PgnoLs27qjrS4QFFb/UlrBg45DGXdpRNowWYBZA/IPvCUMWpdFLZtgIXe0gWML5QCgcnRsZsAQVRmk5bcUzHSEk/E0ybM3SkQzYmlo4yilDykCF5X2pG4DLUocxKRau1YstBLKFUWmaFZFpRxcarAUTm+pJZKcCP1vfWArs9rrOaIkSy78PG2QIcS5hJtsLT1S6TDpqlMuuC2eYlz4k8OCcxIYodsyJzCMZDeQnEa4cabixuU99eIawBNEJO1kmLTP5i8iGo9sawEVw5GbUvkNFGhKGGQd46KZOWa5mI84qsBxjVEc5YiO39HmpYQ84ZCx9YYXYc+OxIoNQFjCi82KCiUlmh3JDCRej8Eg1mfrkK0O/Z4+Io0QEUMcJ2v9H0S840Ge0BrFejs/Y14zPaasDtI2FZ7MT4PpQ+juV+o9i7Gaqa4d1mrd4XA8EG7uATi/Jg4eS67ZDsijsfMFjIlfW4Xm5YggS1ny1cw9+wZlqHTDR1ZIpnMScd3KrBjWyXNfQynFqmEWw6WIo7h8LznA8o3+FEpQY1RXAKdu/UkXRtu/ow0cgbwqbhxTzzbrYisdTR9RiEHKTXPpFkNTZIfYGfNyTOtQzg3AfZDaY4oghcgshXD9gSReRaJwWZ0piOD87CX8iaHKaOySN6mxXfDA1Tao5D5Ejai5TauRRlr1Svuy5Bl1Q+oSmLLozq40XQbxGRg+LUR0awYk0OLjxi0dBBpJU8yayXTHoC6wjXUOYEH2kStjEOgLm4Y8hC5LoORipBzG2UEQ2Bel3ZD2TaHpaZuqo4FICO4EHkWc5/yM65UCRmfJPtYbTK8zbMUk3FTAKDE8ZZpio1n8VazM7FD5H1vS1MSnrsLnlqHzo+kyZphKhJXWjYYKrUiDw4GnHx4td8FU3mOGRvWBQknm3ysWqa4ojzSgpn/NFJcswqPHahI4NhiwOkElt+xqUAjVxIyVYZ8r5WEPa1hnquw6yJbRwScdBHbW1+QJwdE2SjpZIHbWbCVbv1R7DCbx3zcyfsGfnmEj/y9doLHKdBvjYtclzw6wuFGwiAiyE0b6E4qKPsgqDbYbSxh/k2TS27Q/EQdzAqqmn0exxLNdqFk5v9gNeSpOw3nQjzJjtcq7Ucjvs7IqKsyEPREzpPR1AWyrqcTT+CpSdMaevjSOwb1VZ0ToQJh4U8Meq59HCer108rKz5PSO5kwmeLPhVU7bEs0kPlnj5bFCC5F6ZsiV+xf2X9Ae+sfZTzs1S73eTNfHyqc2T7FhXL0TgO08x8dptUsHFcXMeQltsjs95+4whP/ns3ecP9JUoQZRv/6z70auTjc/07vGFHE/17vHtBf/oAerMe6vdKw/K/t01i3+pWSR9I9Ui+RszlPxTOcrdiKPUy7dRfD3i/oMIwc0eT0HdWq3BP33TxSfk9u3io6+XOv3PrzRgmZu+gOblI6OsuvoAoexYYEWWvTUvfghlefKVFPLNQtm/W/SLBhVUeVG+N0v48+fGLGGY5+rBv/LGjzdjCXXj1zPd/e362h/ikc+64PqxPPMeST358slpEhpcke/3iLxIuJIwhNlZGjavRK83VkYWCCKv3FgZ+fdURvrlezmM159oh0ZYPad3Actk2PhnBEzM65
JrhM2IdwyuQGL9KK9kxLv7JM7Q9h5GuCtMtiTY7IUruMK6I5aIv0kWL8+K9C/6xsygWPIZMwjqJTPAK7wgf0HYQL98N0521klk/M4DeKIu+00PX3FefwaTAHgeUTCvMIng34pLLy3aF5dej8Gfcwnw4M24hBePfXsP8uXei+9vm6bG/wc=
2 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Table of Contents
4 |
5 | - [Change Log](#change-log)
6 | - [Unreleased](#unreleased)
7 | - [1.0.1 (2019-04-05)](#101-2019-04-05)
8 | - [1.0.0 (2019-04-05)](#100-2019-04-05)
9 | - [0.0.9 (2018-04-23)](#009-2018-04-23)
10 | - [0.0.7 (2017-12-20)](#007-2017-12-20)
11 | - [0.0.6 (2017-11-12)](#006-2017-11-12)
12 | - [0.5.0 (2017-11-01)](#050-2017-11-01)
13 | - [0.0.3 (2017-10-27)](#003-2017-10-27)
14 | - [0.0.2 (2017-10-26)](#002-2017-10-26)
15 | - [v0.0.1 (2017-10-26)](#v001-2017-10-26)
16 |
17 |
18 |
19 | # Change Log
20 |
21 | ## [Unreleased](https://github.com/ixaxaar/pytorch-dnc/tree/HEAD)
22 |
23 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/1.0.1...HEAD)
24 |
25 | **Merged pull requests:**
26 |
27 | - Fixes for \#43 [\#44](https://github.com/ixaxaar/pytorch-dnc/pull/44) ([ixaxaar](https://github.com/ixaxaar))
28 |
29 | ## [1.0.1](https://github.com/ixaxaar/pytorch-dnc/tree/1.0.1) (2019-04-05)
30 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/1.0.0...1.0.1)
31 |
32 | **Closed issues:**
33 |
34 | - When running adding task -- ModuleNotFoundError: No module named 'index' [\#39](https://github.com/ixaxaar/pytorch-dnc/issues/39)
35 | - SyntaxError [\#36](https://github.com/ixaxaar/pytorch-dnc/issues/36)
36 | - PySide dependency error [\#33](https://github.com/ixaxaar/pytorch-dnc/issues/33)
37 | - Issues when using pytorch 0.4 [\#31](https://github.com/ixaxaar/pytorch-dnc/issues/31)
38 | - TypeError: cat received an invalid combination of arguments - got \(list, int\), but expected one of: [\#29](https://github.com/ixaxaar/pytorch-dnc/issues/29)
39 |
40 | **Merged pull requests:**
41 |
42 | - Fixes \#36 and \#39 [\#42](https://github.com/ixaxaar/pytorch-dnc/pull/42) ([ixaxaar](https://github.com/ixaxaar))
43 |
44 | ## [1.0.0](https://github.com/ixaxaar/pytorch-dnc/tree/1.0.0) (2019-04-05)
45 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.9...1.0.0)
46 |
47 | **Closed issues:**
48 |
49 | - Question about the running speed of Pyflann and Faiss for the SAM model [\#40](https://github.com/ixaxaar/pytorch-dnc/issues/40)
50 | - SyntaxError [\#37](https://github.com/ixaxaar/pytorch-dnc/issues/37)
51 | - Values in hidden become nan [\#35](https://github.com/ixaxaar/pytorch-dnc/issues/35)
52 | - faiss error [\#32](https://github.com/ixaxaar/pytorch-dnc/issues/32)
53 |
54 | **Merged pull requests:**
55 |
56 | - Port to pytorch 1.x [\#41](https://github.com/ixaxaar/pytorch-dnc/pull/41) ([ixaxaar](https://github.com/ixaxaar))
57 | - fix parens in example usage and gpu usage for SAM [\#30](https://github.com/ixaxaar/pytorch-dnc/pull/30) ([kierkegaard13](https://github.com/kierkegaard13))
58 |
59 | ## [0.0.9](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.9) (2018-04-23)
60 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.7...0.0.9)
61 |
62 | **Fixed bugs:**
63 |
64 | - Use usage vector to determine least recently used memory [\#26](https://github.com/ixaxaar/pytorch-dnc/issues/26)
65 | - Store entire memory after memory limit is reached [\#24](https://github.com/ixaxaar/pytorch-dnc/issues/24)
66 |
67 | **Merged pull requests:**
68 |
69 | - memory.py: fix indexing for read\_modes transform [\#28](https://github.com/ixaxaar/pytorch-dnc/pull/28) ([jbinas](https://github.com/jbinas))
70 | - Bugfixes [\#27](https://github.com/ixaxaar/pytorch-dnc/pull/27) ([ixaxaar](https://github.com/ixaxaar))
71 |
72 | ## [0.0.7](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.7) (2017-12-20)
73 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.6...0.0.7)
74 |
75 | **Implemented enhancements:**
76 |
77 | - GPU kNNs [\#21](https://github.com/ixaxaar/pytorch-dnc/issues/21)
78 | - Implement temporal addressing for SDNCs [\#18](https://github.com/ixaxaar/pytorch-dnc/issues/18)
79 | - Feature: Sparse Access Memory [\#4](https://github.com/ixaxaar/pytorch-dnc/issues/4)
80 | - SAMs [\#22](https://github.com/ixaxaar/pytorch-dnc/pull/22) ([ixaxaar](https://github.com/ixaxaar))
81 | - Temporal links for SDNC [\#19](https://github.com/ixaxaar/pytorch-dnc/pull/19) ([ixaxaar](https://github.com/ixaxaar))
82 | - SDNC [\#16](https://github.com/ixaxaar/pytorch-dnc/pull/16) ([ixaxaar](https://github.com/ixaxaar))
83 |
84 | **Merged pull requests:**
85 |
86 | - Add more tasks [\#23](https://github.com/ixaxaar/pytorch-dnc/pull/23) ([ixaxaar](https://github.com/ixaxaar))
87 | - Scale interface vectors, dynamic memory pass [\#17](https://github.com/ixaxaar/pytorch-dnc/pull/17) ([ixaxaar](https://github.com/ixaxaar))
88 | - Update README.md [\#14](https://github.com/ixaxaar/pytorch-dnc/pull/14) ([MaxwellRebo](https://github.com/MaxwellRebo))
89 |
90 | ## [0.0.6](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.6) (2017-11-12)
91 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.5.0...0.0.6)
92 |
93 | **Implemented enhancements:**
94 |
95 | - Re-write allocation vector code, use pytorch's cumprod [\#13](https://github.com/ixaxaar/pytorch-dnc/issues/13)
96 |
97 | **Fixed bugs:**
98 |
99 | - Stacked DNCs forward pass wrong [\#12](https://github.com/ixaxaar/pytorch-dnc/issues/12)
100 | - Temporal debugging of memory [\#11](https://github.com/ixaxaar/pytorch-dnc/pull/11) ([ixaxaar](https://github.com/ixaxaar))
101 |
102 | ## [0.5.0](https://github.com/ixaxaar/pytorch-dnc/tree/0.5.0) (2017-11-01)
103 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.3...0.5.0)
104 |
105 | **Implemented enhancements:**
106 |
107 | - Multiple hidden layers per controller layer [\#7](https://github.com/ixaxaar/pytorch-dnc/issues/7)
108 | - Vizdom integration and fix cumprod bug \#5 [\#6](https://github.com/ixaxaar/pytorch-dnc/pull/6) ([ixaxaar](https://github.com/ixaxaar))
109 |
110 | **Fixed bugs:**
111 |
112 | - Use shifted cumprods, emulate tensorflow's cumprod with exclusive=True [\#5](https://github.com/ixaxaar/pytorch-dnc/issues/5)
113 | - Vizdom integration and fix cumprod bug \\#5 [\#6](https://github.com/ixaxaar/pytorch-dnc/pull/6) ([ixaxaar](https://github.com/ixaxaar))
114 |
115 | **Closed issues:**
116 |
117 | - Write unit tests [\#8](https://github.com/ixaxaar/pytorch-dnc/issues/8)
118 | - broken links [\#3](https://github.com/ixaxaar/pytorch-dnc/issues/3)
119 |
120 | **Merged pull requests:**
121 |
122 | - Test travis build [\#10](https://github.com/ixaxaar/pytorch-dnc/pull/10) ([ixaxaar](https://github.com/ixaxaar))
123 | - Implement Hidden layers, small enhancements, cleanups [\#9](https://github.com/ixaxaar/pytorch-dnc/pull/9) ([ixaxaar](https://github.com/ixaxaar))
124 |
125 | ## [0.0.3](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.3) (2017-10-27)
126 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/0.0.2...0.0.3)
127 |
128 | **Implemented enhancements:**
129 |
130 | - Implementation of Dropout for controller [\#2](https://github.com/ixaxaar/pytorch-dnc/pull/2) ([ixaxaar](https://github.com/ixaxaar))
131 | - Fix size issue for GRU and vanilla RNN [\#1](https://github.com/ixaxaar/pytorch-dnc/pull/1) ([ixaxaar](https://github.com/ixaxaar))
132 |
133 | ## [0.0.2](https://github.com/ixaxaar/pytorch-dnc/tree/0.0.2) (2017-10-26)
134 | [Full Changelog](https://github.com/ixaxaar/pytorch-dnc/compare/v0.0.1...0.0.2)
135 |
136 | ## [v0.0.1](https://github.com/ixaxaar/pytorch-dnc/tree/v0.0.1) (2017-10-26)
137 |
138 |
139 | \* *This Change Log was automatically generated by [github_changelog_generator](https://github.com/skywinder/Github-Changelog-Generator)*
140 |
--------------------------------------------------------------------------------
/tasks/adding_task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 | import os
6 | import sys
7 |
8 | import numpy as np
9 | import torch
10 | from typing import Any
11 | import torch.optim as optim
12 | from torch.nn.utils import clip_grad_norm_
13 | from visdom import Visdom
14 |
15 | sys.path.insert(0, os.path.join("..", ".."))
16 | from dnc import DNC, SDNC, SAM
17 |
18 |
19 | def get_device(cuda_id: int) -> torch.device:
20 | if cuda_id >= 0 and torch.cuda.is_available():
21 | return torch.device(f"cuda:{cuda_id}")
22 | return torch.device("cpu")
23 |
24 |
25 | def onehot(x: int, n: int) -> np.ndarray:
26 | ret = np.zeros(n, dtype=np.float32)
27 | ret[x] = 1.0
28 | return ret
29 |
30 |
31 | def generate_data(length: int, size: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor, str]:
32 | content = np.random.randint(0, size - 1, length)
33 | seqlen = length + 1
34 |     x_seq = [onehot(int(val), size) for val in content] + [onehot(size - 1, size)]  # values + end marker, seqlen rows total
35 | x_seq = np.array(x_seq, dtype=np.float32).reshape(1, seqlen, size) # type: ignore
36 | sums = np.array(np.sum(content), dtype=np.float32).reshape(1, 1, 1)
37 | sums_text = " + ".join(str(val) for val in content)
38 |
39 | return (torch.tensor(x_seq, device=device), torch.tensor(sums, device=device), sums_text)
40 |
41 |
42 | def mse_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
43 |     return ((prediction - target) ** 2).mean()  # scalar MSE, so loss.backward() and loss.item() work
44 |
45 |
46 | def main() -> None:
47 | parser = argparse.ArgumentParser(description="PyTorch Differentiable Neural Computer Adding Task")
48 |     parser.add_argument("--input_size", type=int, default=6, help="Dimension of input feature")
49 | parser.add_argument("--rnn_type", type=str, default="lstm", help="Type of recurrent cells (lstm, gru, rnn)")
50 | parser.add_argument("--nhid", type=int, default=64, help="Number of hidden units in the controller")
51 | parser.add_argument("--dropout", type=float, default=0, help="Controller dropout rate")
52 | parser.add_argument("--memory_type", type=str, default="dnc", help="Memory type (dnc, sdnc, sam)")
53 | parser.add_argument("--nlayer", type=int, default=1, help="Number of memory layers")
54 | parser.add_argument("--nhlayer", type=int, default=2, help="Number of hidden layers in each RNN")
55 | parser.add_argument("--lr", type=float, default=1e-4, help="Initial learning rate")
56 |     parser.add_argument("--optim", type=str, default="adam", help="Optimizer (adam, adamax, rmsprop, sgd, adagrad, adadelta)")
57 | parser.add_argument("--clip", type=float, default=50, help="Gradient clipping value")
58 | parser.add_argument("--batch_size", type=int, default=100, help="Batch size")
59 | parser.add_argument("--mem_size", type=int, default=20, help="Memory cell size")
60 | parser.add_argument("--mem_slot", type=int, default=16, help="Number of memory slots")
61 | parser.add_argument("--read_heads", type=int, default=4, help="Number of read heads")
62 | parser.add_argument("--sparse_reads", type=int, default=10, help="Number of sparse reads per head (sdnc/sam)")
63 | parser.add_argument("--temporal_reads", type=int, default=2, help="Number of temporal reads (sdnc)")
64 | parser.add_argument("--sequence_max_length", type=int, default=1000, help="Maximum sequence length")
65 | parser.add_argument("--cuda", type=int, default=-1, help="CUDA GPU ID (-1 for CPU)")
66 | parser.add_argument("--iterations", type=int, default=2000, help="Total number of iterations")
67 | parser.add_argument("--summarize_freq", type=int, default=100, help="Summarize frequency")
68 | parser.add_argument("--check_freq", type=int, default=100, help="Checkpoint frequency")
69 | parser.add_argument("--visdom", action="store_true", help="Use Visdom for visualization")
70 | args = parser.parse_args()
71 | print(args)
72 |
73 | device = get_device(args.cuda)
74 |
75 | if args.visdom:
76 | viz = Visdom()
77 | if not viz.check_connection():
78 | print("Visdom server not running. Disabling Visdom.")
79 | args.visdom = False
80 |
81 | if args.memory_type == "dnc":
82 | rnn = DNC(
83 | input_size=args.input_size,
84 | hidden_size=args.nhid,
85 | rnn_type=args.rnn_type,
86 | num_layers=args.nlayer,
87 | num_hidden_layers=args.nhlayer,
88 | dropout=args.dropout,
89 | nr_cells=args.mem_slot,
90 | cell_size=args.mem_size,
91 | read_heads=args.read_heads,
92 | device=device,
93 | debug=args.visdom,
94 | batch_first=True,
95 | independent_linears=True,
96 | )
97 | elif args.memory_type == "sdnc":
98 | rnn = SDNC(
99 | input_size=args.input_size,
100 | hidden_size=args.nhid,
101 | rnn_type=args.rnn_type,
102 | num_layers=args.nlayer,
103 | num_hidden_layers=args.nhlayer,
104 | dropout=args.dropout,
105 | nr_cells=args.mem_slot,
106 | cell_size=args.mem_size,
107 | sparse_reads=args.sparse_reads,
108 | temporal_reads=args.temporal_reads,
109 | read_heads=args.read_heads,
110 | device=device,
111 | debug=args.visdom,
112 | batch_first=True,
113 | independent_linears=False,
114 | )
115 | elif args.memory_type == "sam":
116 | rnn = SAM(
117 | input_size=args.input_size,
118 | hidden_size=args.nhid,
119 | rnn_type=args.rnn_type,
120 | num_layers=args.nlayer,
121 | num_hidden_layers=args.nhlayer,
122 | dropout=args.dropout,
123 | nr_cells=args.mem_slot,
124 | cell_size=args.mem_size,
125 | sparse_reads=args.sparse_reads,
126 | read_heads=args.read_heads,
127 | device=device,
128 | debug=args.visdom,
129 | batch_first=True,
130 | independent_linears=False,
131 | )
132 | else:
133 | raise ValueError('Invalid memory_type. Choose "dnc", "sdnc", or "sam".')
134 |
135 | rnn = rnn.to(device)
136 | print(rnn)
137 | optimizer: Any
138 |
139 | if args.optim == "adam":
140 | optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98))
141 | elif args.optim == "adamax":
142 | optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98))
143 | elif args.optim == "rmsprop":
144 | optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, momentum=0.9, eps=1e-10)
145 | elif args.optim == "sgd":
146 | optimizer = optim.SGD(rnn.parameters(), lr=args.lr)
147 | elif args.optim == "adagrad":
148 | optimizer = optim.Adagrad(rnn.parameters(), lr=args.lr)
149 | elif args.optim == "adadelta":
150 | optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr)
151 | else:
152 | raise ValueError(f"Unsupported optimizer: {args.optim}")
153 |
154 | last_100_losses = []
155 |
156 | for epoch in range(args.iterations + 1):
157 | print(f"\rIteration {epoch}/{args.iterations}", end="")
158 | optimizer.zero_grad()
159 |
160 | random_length = np.random.randint(2, args.sequence_max_length + 1)
161 | input_data, target_output, sums_text = generate_data(random_length, args.input_size, device)
162 | input_data = input_data.repeat(args.batch_size, 1, 1)
163 | target_output = target_output.repeat(args.batch_size, 1, 1)
164 |
165 |         output, (chx, mhx, rv), *_ = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True)  # debug mode appends a debug dict
166 |
167 | output = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True)
168 |         loss = mse_loss(output, target_output)
169 | loss.backward()
170 |
171 | clip_grad_norm_(rnn.parameters(), args.clip)
172 | optimizer.step()
173 | loss_value = loss.item()
174 |
175 | # Detach memory from graph
176 | if mhx is not None:
177 | mhx = {k: (v.detach() if isinstance(v, torch.Tensor) else v) for k, v in mhx.items()}
178 |
179 | last_100_losses.append(loss_value)
180 |
181 | if epoch % args.summarize_freq == 0:
182 | print(f"\rIteration {epoch}/{args.iterations}")
183 | print(f"Avg. Loss: {np.mean(last_100_losses):.4f}")
184 |             output_value = output[0].detach().cpu().item()  # first batch element; all rows are identical repeats
185 |             print(f"Real value: {int(target_output[0].item())}")
186 |             print(f"Predicted: {int(output_value // 1)} [{output_value}]")
187 | last_100_losses = []
188 |
189 | print("\nTesting generalization...")
190 | rnn.eval() # Switch to evaluation mode
191 |
192 | with torch.no_grad(): # Disable gradient calculations during testing
193 | for i in range(int((args.iterations + 1) / 10)):
194 | print(f"\nIteration {i}/{args.iterations // 10}")
195 | random_length = np.random.randint(2, int(args.sequence_max_length) * 10 + 1)
196 | input_data, target_output, _ = generate_data(random_length, args.input_size, device)
197 | input_data = input_data.repeat(args.batch_size, 1, 1)
198 | target_output = target_output.repeat(args.batch_size, 1, 1)
199 |
200 | output, *_ = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True)
201 |
202 |             output_value = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True)[0].item()
203 |             print(f"Real value: {int(target_output[0].item())}")
204 |             print(f"Predicted: {int(output_value // 1)} [{output_value}]")
205 |
206 |
207 | if __name__ == "__main__":
208 | main()
209 |
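210 | # Example invocation (illustrative flag values; all flags are defined above):
211 | #   python3 tasks/adding_task.py --memory_type sam --nhid 64 --sequence_max_length 100 --iterations 2000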
--------------------------------------------------------------------------------
/tasks/argmax_task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 | import os
6 | import sys
7 |
8 | import numpy as np
9 | import torch
10 | from typing import Any
11 | import torch.optim as optim
12 | from torch.nn.utils import clip_grad_norm_
13 | from visdom import Visdom
14 |
15 | # Add the parent directory to sys.path to allow imports from dnc
16 | sys.path.insert(
17 |     0, os.path.join("..", "..")
18 | )
19 | from dnc import DNC, SDNC, SAM
20 |
21 |
22 | def get_device(cuda_id: int) -> torch.device:
23 | """Gets the torch device based on CUDA availability and ID."""
24 | if cuda_id >= 0 and torch.cuda.is_available():
25 | return torch.device(f"cuda:{cuda_id}")
26 | else:
27 | return torch.device("cpu")
28 |
29 |
30 | def onehot(x: int, n: int) -> np.ndarray:
31 | """Creates a one-hot encoded vector."""
32 | ret = np.zeros(n, dtype=np.float32)
33 | ret[x] = 1.0
34 | return ret
35 |
36 |
37 | def generate_data(length: int, size: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
38 | """Generates data for the argmax task."""
39 | content = np.random.randint(0, size - 1, length)
40 |
41 | seqlen = length + 1
42 |     x_seq = [onehot(int(val), size) for val in content]
43 |     x_seq.append(onehot(size - 1, size))  # end-of-sequence marker, giving seqlen rows
44 | x_seq = np.array(x_seq, dtype=np.float32).reshape(1, seqlen, size) # type: ignore
45 |
46 | max_ind = np.argmax(content)
47 | target_output = np.zeros((1, 1, 1), dtype=np.float32)
48 | target_output[:, 0, 0] = max_ind
49 |
50 | weights_vec = np.zeros((1, 1, 1), dtype=np.float32)
51 | weights_vec[:, 0, 0] = 1.0
52 |
53 | return (
54 | torch.tensor(x_seq, device=device),
55 | torch.tensor(target_output, device=device),
56 | torch.tensor(weights_vec, device=device),
57 | )
58 |
59 |
60 | def main() -> None:
61 | """Main function for the argmax task."""
62 | parser = argparse.ArgumentParser(description="PyTorch Differentiable Neural Computer Argmax Task")
63 | parser.add_argument("--input_size", type=int, default=6, help="Dimension of input feature")
64 | parser.add_argument("--rnn_type", type=str, default="lstm", help="Type of recurrent cells (lstm, gru, rnn)")
65 | parser.add_argument("--nhid", type=int, default=100, help="Number of hidden units in the controller")
66 | parser.add_argument("--dropout", type=float, default=0, help="Controller dropout rate")
67 | parser.add_argument("--memory_type", type=str, default="dnc", help="Memory type (dnc, sdnc, sam)")
68 | parser.add_argument("--nlayer", type=int, default=1, help="Number of memory layers")
69 | parser.add_argument("--nhlayer", type=int, default=2, help="Number of hidden layers in each RNN")
70 | parser.add_argument("--lr", type=float, default=1e-4, help="Initial learning rate")
71 | parser.add_argument("--optim", type=str, default="adam", help="Optimizer (adam, rmsprop)")
72 | parser.add_argument("--clip", type=float, default=50, help="Gradient clipping value")
73 | parser.add_argument("--batch_size", type=int, default=100, help="Batch size")
74 | parser.add_argument("--mem_size", type=int, default=20, help="Memory cell size")
75 | parser.add_argument("--mem_slot", type=int, default=16, help="Number of memory slots")
76 | parser.add_argument("--read_heads", type=int, default=4, help="Number of read heads")
77 | parser.add_argument(
78 | "--sparse_reads", type=int, default=10, help="Number of sparse reads per read head (for sdnc and sam)"
79 | )
80 | parser.add_argument("--temporal_reads", type=int, default=2, help="Number of temporal reads (for sdnc)")
81 | parser.add_argument("--sequence_max_length", type=int, default=4, help="Maximum sequence length")
82 | parser.add_argument("--cuda", type=int, default=-1, help="CUDA GPU ID (-1 for CPU)")
83 | parser.add_argument("--iterations", type=int, default=2000, help="Total number of iterations")
84 | parser.add_argument("--summarize_freq", type=int, default=100, help="Summarize frequency")
85 | parser.add_argument("--check_freq", type=int, default=100, help="Checkpoint frequency")
86 | parser.add_argument("--visdom", action="store_true", help="Use Visdom for visualization")
87 |
88 | args = parser.parse_args()
89 | print(args)
90 |
91 | device = get_device(args.cuda)
92 |
93 | if args.visdom:
94 | viz = Visdom()
95 | if not viz.check_connection():
96 | print("Visdom server not running. Disabling Visdom.")
97 | args.visdom = False
98 |
99 | if args.memory_type == "dnc":
100 | rnn = DNC(
101 | input_size=args.input_size,
102 | hidden_size=args.nhid,
103 | rnn_type=args.rnn_type,
104 | num_layers=args.nlayer,
105 | num_hidden_layers=args.nhlayer,
106 | dropout=args.dropout,
107 | nr_cells=args.mem_slot,
108 | cell_size=args.mem_size,
109 | read_heads=args.read_heads,
110 | device=device,
111 | debug=args.visdom,
112 | batch_first=True,
113 | independent_linears=False,
114 | )
115 | elif args.memory_type == "sdnc":
116 | rnn = SDNC(
117 | input_size=args.input_size,
118 | hidden_size=args.nhid,
119 | rnn_type=args.rnn_type,
120 | num_layers=args.nlayer,
121 | num_hidden_layers=args.nhlayer,
122 | dropout=args.dropout,
123 | nr_cells=args.mem_slot,
124 | cell_size=args.mem_size,
125 | sparse_reads=args.sparse_reads,
126 | temporal_reads=args.temporal_reads,
127 | read_heads=args.read_heads,
128 | device=device,
129 | debug=args.visdom,
130 | batch_first=True,
131 | independent_linears=False,
132 | )
133 | elif args.memory_type == "sam":
134 | rnn = SAM(
135 | input_size=args.input_size,
136 | hidden_size=args.nhid,
137 | rnn_type=args.rnn_type,
138 | num_layers=args.nlayer,
139 | num_hidden_layers=args.nhlayer,
140 | dropout=args.dropout,
141 | nr_cells=args.mem_slot,
142 | cell_size=args.mem_size,
143 | sparse_reads=args.sparse_reads,
144 | read_heads=args.read_heads,
145 | device=device,
146 | debug=args.visdom,
147 | batch_first=True,
148 | independent_linears=False,
149 | )
150 | else:
151 | raise ValueError('Invalid memory_type. Choose "dnc", "sdnc", or "sam".')
152 |
153 | rnn = rnn.to(device)
154 | print(rnn)
155 | optimizer: Any
156 |
157 | if args.optim == "adam":
158 | optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98))
159 | elif args.optim == "adamax":
160 | optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98))
161 | elif args.optim == "rmsprop":
162 | optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, momentum=0.9, eps=1e-10)
163 | elif args.optim == "sgd":
164 | optimizer = optim.SGD(rnn.parameters(), lr=args.lr)
165 | elif args.optim == "adagrad":
166 | optimizer = optim.Adagrad(rnn.parameters(), lr=args.lr)
167 | elif args.optim == "adadelta":
168 | optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr)
169 | else:
170 | raise ValueError(f"Invalid optimizer: {args.optim}")
171 |
172 | last_100_losses = []
173 |
174 | for epoch in range(args.iterations + 1):
175 | print(f"\rIteration {epoch}/{args.iterations}", end="")
176 | optimizer.zero_grad()
177 |
178 | random_length = np.random.randint(2, args.sequence_max_length + 1)
179 | input_data, target_output, loss_weights = generate_data(random_length, args.input_size, device)
180 | input_data = input_data.repeat(args.batch_size, 1, 1)
181 | target_output = target_output.repeat(args.batch_size, 1, 1)
182 | loss_weights = loss_weights.repeat(args.batch_size, 1, 1)
183 |
184 |         output, (chx, mhx, rv), *_ = rnn(
185 |             input_data, (None, None, None), reset_experience=True, pass_through_memory=True
186 |         )  # any trailing debug dict (returned when debug=True) is discarded
187 |
188 | loss = torch.mean(((loss_weights * output).sum(-1, keepdim=True) - target_output) ** 2)
189 | loss.backward()
190 |
191 | clip_grad_norm_(rnn.parameters(), args.clip)
192 | optimizer.step()
193 | loss_value = loss.item()
194 |
195 | # Detach memory from graph
196 | if mhx is not None:
197 | mhx = {k: (v.detach() if isinstance(v, torch.Tensor) else v) for k, v in mhx.items()}
198 |
199 | last_100_losses.append(loss_value)
200 |
201 | if epoch % args.summarize_freq == 0:
202 | output_value = (loss_weights * output).sum().item()
203 | target_value = target_output.sum().item()
204 |
205 | print(f"\rIteration {epoch}/{args.iterations}")
206 | print(f"Avg. Loss: {np.mean(last_100_losses):.4f}")
207 | print(f"Real value: = {int(target_value)}")
208 | print(f"Predicted: = {int(output_value // 1)} [{output_value}]")
209 | last_100_losses = []
210 |
211 | print("\nTesting generalization...")
212 | rnn.eval() # Switch to evaluation mode
213 |
214 | with torch.no_grad(): # Disable gradient calculations during testing
215 |         for i in range((args.iterations + 1) // 10):
216 | print(f"\nIteration {i}/{args.iterations // 10}")
217 | random_length = np.random.randint(2, args.sequence_max_length * 2 + 1)
218 | input_data, target_output, loss_weights = generate_data(random_length, args.input_size, device)
219 | input_data = input_data.repeat(args.batch_size, 1, 1)
220 | target_output = target_output.repeat(args.batch_size, 1, 1)
221 | loss_weights = loss_weights.repeat(args.batch_size, 1, 1)
222 |
223 | output, *_ = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True)
224 |
225 | output_value = output[:, -1, :].sum().item()
226 | target_value = target_output.sum().item()
227 |
228 | print(f"Real value: = {int(target_value)}")
229 | print(f"Predicted: = {int(output_value // 1)} [{output_value}]")
230 |
231 |
232 | if __name__ == "__main__":
233 | main()
234 |
--------------------------------------------------------------------------------
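For reference, a small self-contained sketch (not a repository file) of the encoding generate_data in tasks/argmax_task.py produces: length + 1 one-hot rows whose final row is the end-of-sequence marker at index size - 1, plus a scalar target holding the argmax position. The sizes are illustrative:

    import numpy as np

    size, length = 6, 4
    content = np.random.randint(0, size - 1, length)           # values in [0, size - 2]
    eye = np.eye(size, dtype=np.float32)
    x_seq = np.vstack([eye[content], eye[size - 1][None, :]])  # inputs + end marker
    assert x_seq.shape == (length + 1, size)
    target = float(np.argmax(content))                         # position of the maximum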
/scripts/build_faiss.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # Script to compile FAISS with CUDA and cuBLAS support directly into current virtual environment
5 | # This script will use the currently active Python environment
6 |
7 | # Check if a virtual environment is active
8 | if [ -z "$VIRTUAL_ENV" ]; then
9 | echo "Error: No active Python virtual environment detected."
10 | echo "Please activate a virtual environment before running this script."
11 | exit 1
12 | fi
13 |
14 | echo "Using Python virtual environment at: $VIRTUAL_ENV"
15 |
16 | # Configuration
17 | FAISS_VERSION="1.10.0"  # FAISS release to build
18 | BUILD_DIR="$(pwd)/faiss_build"
19 | INSTALL_DIR="$VIRTUAL_ENV" # Install directly to virtual environment
20 | CUDA_ARCHS="75" # For compute capability 7.5, use comma-separated list for multiple archs
21 | PYTHON_BINDINGS="ON" # Set to OFF to disable Python bindings
22 | USE_MKL="OFF" # Set to ON to use Intel MKL (recommended for performance)
23 | USE_CUVS="OFF" # Set to ON to enable NVIDIA cuVS implementations
24 | OPT_LEVEL="avx2" # Options: generic, avx2, avx512, avx512_spr (x86-64) or generic, sve (aarch64)
25 | ENABLE_GPU="ON" # Set to OFF for CPU-only build
26 | BUILD_TYPE="Release" # Options: Release, Debug
27 | ENABLE_TESTING="OFF" # Set to ON to build tests
28 | BUILD_SHARED="ON" # Set to ON for shared libraries, OFF for static
29 | PARALLEL_JOBS=$(nproc) # Number of parallel build jobs
30 |
31 | # Colors for output
32 | GREEN='\033[0;32m'
33 | RED='\033[0;31m'
34 | YELLOW='\033[0;33m'
35 | NC='\033[0m' # No Color
36 |
37 | # Print configuration
38 | echo -e "${GREEN}FAISS Build Configuration:${NC}"
39 | echo -e " FAISS version: ${YELLOW}${FAISS_VERSION}${NC}"
40 | echo -e " Build directory: ${YELLOW}${BUILD_DIR}${NC}"
41 | echo -e " Install directory: ${YELLOW}${INSTALL_DIR}${NC}"
42 | echo -e " CUDA architectures: ${YELLOW}${CUDA_ARCHS}${NC}"
43 | echo -e " Python bindings: ${YELLOW}${PYTHON_BINDINGS}${NC}"
44 | echo -e " Use Intel MKL: ${YELLOW}${USE_MKL}${NC}"
45 | echo -e " Use NVIDIA cuVS: ${YELLOW}${USE_CUVS}${NC}"
46 | echo -e " Optimization level: ${YELLOW}${OPT_LEVEL}${NC}"
47 | echo -e " Enable GPU: ${YELLOW}${ENABLE_GPU}${NC}"
48 | echo -e " Build type: ${YELLOW}${BUILD_TYPE}${NC}"
49 | echo -e " Enable testing: ${YELLOW}${ENABLE_TESTING}${NC}"
50 | echo -e " Build shared libraries: ${YELLOW}${BUILD_SHARED}${NC}"
51 | echo -e " Parallel jobs: ${YELLOW}${PARALLEL_JOBS}${NC}"
52 |
53 | # Function to check for required tools
54 | check_requirements() {
55 | echo -e "${GREEN}Checking for required tools...${NC}"
56 | local missing_tools=()
57 |
58 | # Check for C++ compiler
59 | if ! command -v g++ &>/dev/null && ! command -v clang++ &>/dev/null; then
60 | missing_tools+=("C++17 compiler (g++ or clang++)")
61 | else
62 | echo -e " C++ compiler: ${YELLOW}$(command -v g++ 2>/dev/null || command -v clang++)${NC}"
63 | fi
64 |
65 | # Check for CMake
66 | if ! command -v cmake &>/dev/null; then
67 | missing_tools+=("CMake")
68 | else
69 | echo -e " CMake: ${YELLOW}$(cmake --version | head -n1)${NC}"
70 | fi
71 |
72 | # Check for Git
73 | if ! command -v git &>/dev/null; then
74 | missing_tools+=("Git")
75 | else
76 | echo -e " Git: ${YELLOW}$(git --version)${NC}"
77 | fi
78 |
79 | # Check for CUDA if GPU is enabled
80 | if [ "$ENABLE_GPU" = "ON" ]; then
81 | if ! command -v nvcc &>/dev/null; then
82 | missing_tools+=("CUDA toolkit (nvcc)")
83 | else
84 | echo -e " CUDA: ${YELLOW}$(nvcc --version | grep release)${NC}"
85 | fi
86 | fi
87 |
88 | # Check for Python and NumPy if Python bindings are enabled
89 | if [ "$PYTHON_BINDINGS" = "ON" ]; then
90 | # Check for NumPy
91 | if ! python -c "import numpy" &>/dev/null; then
92 | missing_tools+=("NumPy (Python package)")
93 | else
94 | echo -e " NumPy: ${YELLOW}$(python -c "import numpy; print(numpy.__version__)")${NC}"
95 | fi
96 |
97 | # Check for SWIG
98 | if ! command -v swig &>/dev/null; then
99 | missing_tools+=("SWIG")
100 | else
101 | echo -e " SWIG: ${YELLOW}$(swig -version | head -n2 | tail -n1)${NC}"
102 | fi
103 | fi
104 |
105 | # Report missing tools
106 | if [ ${#missing_tools[@]} -gt 0 ]; then
107 | echo -e "${RED}Missing required tools:${NC}"
108 | for tool in "${missing_tools[@]}"; do
109 | echo -e " - ${RED}${tool}${NC}"
110 | done
111 | echo -e "${RED}Please install the missing tools and try again.${NC}"
112 | exit 1
113 | fi
114 |
115 | echo -e "${GREEN}All required tools are installed.${NC}"
116 | }
117 |
118 | # Clone or update FAISS repository
119 | clone_or_update_faiss() {
120 | echo -e "${GREEN}Cloning or updating FAISS repository...${NC}"
121 |
122 | mkdir -p "$BUILD_DIR"
123 |
124 | if [ ! -d "$BUILD_DIR/faiss" ]; then
125 | cd "$BUILD_DIR"
126 | echo -e "${GREEN}Cloning FAISS repository...${NC}"
127 | git clone https://github.com/facebookresearch/faiss.git
128 | cd faiss
129 | if [ -n "$FAISS_VERSION" ]; then
130 | echo -e "${GREEN}Checking out version ${FAISS_VERSION}...${NC}"
131 | git checkout "v${FAISS_VERSION}" || git checkout "${FAISS_VERSION}"
132 | fi
133 | else
134 | echo -e "${GREEN}FAISS repository already exists, updating...${NC}"
135 | cd "$BUILD_DIR/faiss"
136 | git fetch
137 | if [ -n "$FAISS_VERSION" ]; then
138 | echo -e "${GREEN}Checking out version ${FAISS_VERSION}...${NC}"
139 | git checkout "v${FAISS_VERSION}" || git checkout "${FAISS_VERSION}"
140 | fi
141 | fi
142 | }
143 |
144 | # Configure build with CMake
145 | configure_build() {
146 | echo -e "${GREEN}Configuring build with CMake...${NC}"
147 |
148 | mkdir -p "$BUILD_DIR/faiss/build"
149 | cd "$BUILD_DIR/faiss/build"
150 |
151 | # Get current Python executable and site-packages directory
152 | PYTHON_EXECUTABLE=$(which python)
153 | PYTHON_SITE_PACKAGES=$(python -c "import site; print(site.getsitepackages()[0])")
154 |
155 | echo -e "${GREEN}Using Python: ${YELLOW}${PYTHON_EXECUTABLE}${NC}"
156 | echo -e "${GREEN}Python site-packages: ${YELLOW}${PYTHON_SITE_PACKAGES}${NC}"
157 |
158 | # Prepare CMake arguments
159 | CMAKE_ARGS=(
160 | "-DCMAKE_BUILD_TYPE=${BUILD_TYPE}"
161 | "-DFAISS_ENABLE_GPU=${ENABLE_GPU}"
162 | "-DFAISS_ENABLE_PYTHON=${PYTHON_BINDINGS}"
163 | "-DBUILD_TESTING=${ENABLE_TESTING}"
164 | "-DBUILD_SHARED_LIBS=${BUILD_SHARED}"
165 | "-DCMAKE_INSTALL_PREFIX=${INSTALL_DIR}"
166 | "-DFAISS_OPT_LEVEL=${OPT_LEVEL}"
167 | "-DPython_EXECUTABLE=${PYTHON_EXECUTABLE}"
168 | )
169 |
170 | # Add CUDA architectures if GPU is enabled
171 | if [ "$ENABLE_GPU" = "ON" ]; then
172 | CMAKE_ARGS+=("-DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS}")
173 | fi
174 |
175 | # Add cuVS support if enabled
176 | if [ "$USE_CUVS" = "ON" ]; then
177 | CMAKE_ARGS+=("-DFAISS_ENABLE_CUVS=ON")
178 | fi
179 |
180 | # Add MKL support if enabled
181 | if [ "$USE_MKL" = "ON" ]; then
182 | CMAKE_ARGS+=("-DBLA_VENDOR=Intel10_64_dyn")
183 | fi
184 |
185 | # Run CMake
186 | cmake "${CMAKE_ARGS[@]}" ..
187 | }
188 |
189 | # Build FAISS
190 | build_faiss() {
191 | echo -e "${GREEN}Building FAISS...${NC}"
192 | cd "$BUILD_DIR/faiss/build"
193 |
194 | # Build C++ library
195 | echo -e "${GREEN}Building C++ library...${NC}"
196 | make -j"${PARALLEL_JOBS}" faiss
197 |
198 | # Build optimized versions if needed
199 | if [ "$OPT_LEVEL" = "avx2" ]; then
200 | echo -e "${GREEN}Building AVX2 optimized version...${NC}"
201 | make -j"${PARALLEL_JOBS}" faiss_avx2
202 | elif [ "$OPT_LEVEL" = "avx512" ]; then
203 | echo -e "${GREEN}Building AVX512 optimized version...${NC}"
204 | make -j"${PARALLEL_JOBS}" faiss_avx512
205 | elif [ "$OPT_LEVEL" = "avx512_spr" ]; then
206 | echo -e "${GREEN}Building AVX512 Sapphire Rapids optimized version...${NC}"
207 | make -j"${PARALLEL_JOBS}" faiss_avx512_spr
208 | fi
209 |
210 | # Build Python bindings if enabled
211 | if [ "$PYTHON_BINDINGS" = "ON" ]; then
212 | echo -e "${GREEN}Building Python bindings...${NC}"
213 | make -j"${PARALLEL_JOBS}" swigfaiss
214 | fi
215 | }
216 |
217 | # Install FAISS
218 | install_faiss() {
219 | echo -e "${GREEN}Installing FAISS...${NC}"
220 | cd "$BUILD_DIR/faiss/build"
221 |
222 | # Install C++ library and headers
223 | echo -e "${GREEN}Installing C++ library and headers...${NC}"
224 | if [ "$BUILD_SHARED" = "ON" ]; then
225 | # First copy shared libraries to the virtual environment's lib directory
226 | mkdir -p "$VIRTUAL_ENV/lib"
227 | find . -name "*.so*" -not -path "*python*" -type f -exec cp -v {} "$VIRTUAL_ENV/lib/" \;
228 | fi
229 |
230 | # Install Python bindings if enabled
231 | if [ "$PYTHON_BINDINGS" = "ON" ]; then
232 | echo -e "${GREEN}Installing Python bindings...${NC}"
233 | cd "$BUILD_DIR/faiss/build/faiss/python"
234 | python setup.py install --prefix="$VIRTUAL_ENV"
235 | fi
236 | }
237 |
238 | # Run tests if enabled
239 | run_tests() {
240 | if [ "$ENABLE_TESTING" = "ON" ]; then
241 | echo -e "${GREEN}Running tests...${NC}"
242 | cd "$BUILD_DIR/faiss/build"
243 | make test
244 | fi
245 | }
246 |
247 | # Create a test script
248 | create_test_script() {
249 | echo -e "${GREEN}Creating test script...${NC}"
250 |
251 | mkdir -p "$BUILD_DIR"
252 |
253 | # Create Python test script
254 | cat >"$BUILD_DIR/test_faiss.py" < torch.device:
22 | """Gets the torch device based on CUDA availability and ID."""
23 | if cuda_id >= 0 and torch.cuda.is_available():
24 | return torch.device(f"cuda:{cuda_id}")
25 | else:
26 | return torch.device("cpu")
27 |
28 |
29 | def generate_data(batch_size: int, length: int, size: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
30 | """Generates data for the copy task."""
31 | sequence = np.random.binomial(1, 0.5, (batch_size, length, size - 1)).astype(np.float32)
32 | input_data = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32)
33 | target_output = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32)
34 |
35 | input_data[:, :length, : size - 1] = sequence
36 | input_data[:, length, -1] = 1 # Add the end-of-sequence marker
37 | target_output[:, length + 1 :, : size - 1] = sequence
38 |
39 | return (torch.tensor(input_data, device=device), torch.tensor(target_output, device=device))
40 |
41 |
42 | def criterion(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
43 | """Calculates the binary cross-entropy loss."""
44 | # Use F.binary_cross_entropy_with_logits for numerical stability
45 | return F.binary_cross_entropy_with_logits(predictions, targets)
46 |
47 |
48 | def main() -> None:
49 | """Main function for the copy task."""
50 | parser = argparse.ArgumentParser(description="PyTorch Differentiable Neural Computer Copy Task")
51 | parser.add_argument("-input_size", type=int, default=6, help="Dimension of input feature")
52 | parser.add_argument("--rnn_type", type=str, default="lstm", help="Type of recurrent cells (lstm, gru, rnn)")
53 | parser.add_argument("--nhid", type=int, default=64, help="Number of hidden units in the controller")
54 | parser.add_argument("--dropout", type=float, default=0, help="Controller dropout rate")
55 | parser.add_argument("--memory_type", type=str, default="dnc", help="Memory type (dnc, sdnc, sam)")
56 | parser.add_argument("--nlayer", type=int, default=1, help="Number of memory layers")
57 | parser.add_argument("--nhlayer", type=int, default=2, help="Number of hidden layers in each RNN")
58 | parser.add_argument("--lr", type=float, default=1e-4, help="Initial learning rate")
59 | parser.add_argument("--optim", type=str, default="adam", help="Optimizer (adam, rmsprop)")
60 | parser.add_argument("--clip", type=float, default=50, help="Gradient clipping value")
61 | parser.add_argument("--batch_size", type=int, default=100, help="Batch size")
62 | parser.add_argument("--mem_size", type=int, default=20, help="Memory cell size")
63 | parser.add_argument("--mem_slot", type=int, default=16, help="Number of memory slots")
64 | parser.add_argument("--read_heads", type=int, default=4, help="Number of read heads")
65 | parser.add_argument("--sparse_reads", type=int, default=10, help="Number of sparse reads per head (sdnc/sam)")
66 | parser.add_argument("--temporal_reads", type=int, default=2, help="Number of temporal reads (sdnc)")
67 | parser.add_argument("--sequence_max_length", type=int, default=4, help="Maximum sequence length")
68 | parser.add_argument("--curriculum_increment", type=int, default=0, help="Sequence length increment per freq")
69 | parser.add_argument("--curriculum_freq", type=int, default=1000, help="Frequency of curriculum increment")
70 | parser.add_argument("--cuda", type=int, default=-1, help="CUDA GPU ID (-1 for CPU)")
71 | parser.add_argument("--iterations", type=int, default=100000, help="Total number of iterations")
72 | parser.add_argument("--summarize_freq", type=int, default=100, help="Summarize frequency")
73 | parser.add_argument("--check_freq", type=int, default=100, help="Checkpoint frequency")
74 | parser.add_argument("--visdom", action="store_true", help="Use Visdom for visualization")
75 | args = parser.parse_args()
76 | print(args)
77 |
78 | device = get_device(args.cuda)
79 |
80 | if args.visdom:
81 | viz = Visdom()
82 | if not viz.check_connection():
83 | print("Visdom server not running. Disabling Visdom.")
84 | args.visdom = False
85 |
86 | if args.memory_type == "dnc":
87 | rnn = DNC(
88 | input_size=args.input_size,
89 | hidden_size=args.nhid,
90 | rnn_type=args.rnn_type,
91 | num_layers=args.nlayer,
92 | num_hidden_layers=args.nhlayer,
93 | dropout=args.dropout,
94 | nr_cells=args.mem_slot,
95 | cell_size=args.mem_size,
96 | read_heads=args.read_heads,
97 | device=device,
98 | debug=args.visdom,
99 | batch_first=True,
100 | independent_linears=True,
101 | )
102 | elif args.memory_type == "sdnc":
103 | rnn = SDNC(
104 | input_size=args.input_size,
105 | hidden_size=args.nhid,
106 | rnn_type=args.rnn_type,
107 | num_layers=args.nlayer,
108 | num_hidden_layers=args.nhlayer,
109 | dropout=args.dropout,
110 | nr_cells=args.mem_slot,
111 | cell_size=args.mem_size,
112 | sparse_reads=args.sparse_reads,
113 | temporal_reads=args.temporal_reads,
114 | read_heads=args.read_heads,
115 | device=device,
116 | debug=args.visdom,
117 | batch_first=True,
118 | independent_linears=False,
119 | )
120 | elif args.memory_type == "sam":
121 | rnn = SAM(
122 | input_size=args.input_size,
123 | hidden_size=args.nhid,
124 | rnn_type=args.rnn_type,
125 | num_layers=args.nlayer,
126 | num_hidden_layers=args.nhlayer,
127 | dropout=args.dropout,
128 | nr_cells=args.mem_slot,
129 | cell_size=args.mem_size,
130 | sparse_reads=args.sparse_reads,
131 | read_heads=args.read_heads,
132 | device=device,
133 | debug=args.visdom,
134 | batch_first=True,
135 | independent_linears=False,
136 | )
137 | else:
138 | raise ValueError('Invalid memory_type. Choose "dnc", "sdnc", or "sam".')
139 |
140 | rnn = rnn.to(device)
141 | print(rnn)
142 | optimizer: Any
143 |
144 | if args.optim == "adam":
145 | optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98))
146 | elif args.optim == "adamax":
147 | optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=(0.9, 0.98))
148 | elif args.optim == "rmsprop":
149 | optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, momentum=0.9, eps=1e-10)
150 | elif args.optim == "sgd":
151 | optimizer = optim.SGD(rnn.parameters(), lr=args.lr)
152 | elif args.optim == "adagrad":
153 | optimizer = optim.Adagrad(rnn.parameters(), lr=args.lr)
154 | elif args.optim == "adadelta":
155 | optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr)
156 | else:
157 | raise ValueError(f"Unsupported optimizer: {args.optim}")
158 |
159 | last_losses = []
160 |
161 | for epoch in range(args.iterations + 1):
162 | print(f"\rIteration {epoch}/{args.iterations}", end="")
163 | optimizer.zero_grad()
164 |
165 | random_length = np.random.randint(1, args.sequence_max_length + 1)
166 | input_data, target_output = generate_data(args.batch_size, random_length, args.input_size, device)
167 |
168 |         output, (chx, mhx, rv), *_ = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True)
169 |
170 | loss = criterion(output, target_output)
171 | loss.backward()
172 |
173 | clip_grad_norm_(rnn.parameters(), args.clip)
174 | optimizer.step()
175 | loss_value = loss.item()
176 |
177 | # Detach memory from graph
178 | if mhx is not None:
179 | mhx = {k: (v.detach() if isinstance(v, torch.Tensor) else v) for k, v in mhx.items()}
180 |
181 | last_losses.append(loss_value)
182 |
183 | if epoch % args.summarize_freq == 0:
184 | avg_loss = np.mean(last_losses)
185 | print(f"\n\tAvg. Loss: {avg_loss:.4f}")
186 | last_losses = []
187 | if np.isnan(avg_loss):
188 | raise ValueError("NaN loss. Experiment failed.")
189 |
190 |         if args.visdom and rnn.debug and epoch % args.summarize_freq == 0:
191 |             # reuse avg_loss from the summary branch above; recomputing it from the
192 |             # just-emptied last_losses list would average an empty list and yield NaN
193 | if args.memory_type == "dnc":
194 |
195 | memory = rnn._debug(mhx, None)["memory"] # type: ignore
196 | if memory is not None and len(memory) > 0:
197 | viz.heatmap(
198 | np.array(memory[-1]),
199 | opts=dict(
200 | xtickstep=10,
201 | ytickstep=2,
202 | title=f"Memory, t: {epoch}, loss: {avg_loss:.4f}",
203 | ylabel="layer * time",
204 | xlabel="mem_slot * mem_size",
205 | ),
206 | )
207 |
208 | link_matrix = rnn._debug(mhx, None)["link_matrix"] # type: ignore
209 | if link_matrix is not None and len(link_matrix) > 0:
210 | viz.heatmap(
211 | np.array(link_matrix[-1]).reshape(args.mem_slot, args.mem_slot),
212 | opts=dict(
213 | xtickstep=10,
214 | ytickstep=2,
215 | title=f"Link Matrix, t: {epoch}, loss: {avg_loss:.4f}",
216 | ylabel="mem_slot",
217 | xlabel="mem_slot",
218 | ),
219 | )
220 |
221 | precedence = rnn._debug(mhx, None)["precedence"] # type: ignore
222 | if precedence is not None and len(precedence) > 0:
223 | viz.heatmap(
224 | np.array(precedence[-1]),
225 | opts=dict(
226 | xtickstep=10,
227 | ytickstep=2,
228 | title=f"Precedence, t: {epoch}, loss: {avg_loss:.4f}",
229 | ylabel="layer * time",
230 | xlabel="mem_slot",
231 | ),
232 | )
233 |
234 | if args.memory_type == "sdnc":
235 | link_matrix = rnn._debug(mhx, None)["link_matrix"] # type: ignore
236 | if link_matrix is not None and len(link_matrix) > 0:
237 | viz.heatmap(
238 | np.array(link_matrix[-1]).reshape(args.mem_slot, -1),
239 | opts=dict(
240 | xtickstep=10,
241 | ytickstep=2,
242 | title=f"Link Matrix, t: {epoch}, loss: {avg_loss:.4f}",
243 | ylabel="mem_slot",
244 | xlabel="mem_slot",
245 | ),
246 | )
247 |
248 | rev_link_matrix = rnn._debug(mhx, None)["rev_link_matrix"] # type: ignore
249 | if rev_link_matrix is not None and len(rev_link_matrix) > 0:
250 | viz.heatmap(
251 | np.array(rev_link_matrix[-1]).reshape(args.mem_slot, -1),
252 | opts=dict(
253 | xtickstep=10,
254 | ytickstep=2,
255 | title=f"Reverse Link Matrix, t: {epoch}, loss: {avg_loss:.4f}",
256 | ylabel="mem_slot",
257 | xlabel="mem_slot",
258 | ),
259 | )
260 | read_positions = rnn._debug(mhx, None)["read_positions"] # type: ignore
261 | if read_positions is not None and len(read_positions) > 0:
262 | viz.heatmap(
263 | np.array(read_positions[-1]),
264 | opts=dict(
265 | xtickstep=10,
266 | ytickstep=2,
267 | title=f"Read Positions, t: {epoch}, loss: {avg_loss:.4f}",
268 | ylabel="layer * time",
269 | xlabel="mem_slot",
270 | ),
271 | )
272 |
273 | read_weights = rnn._debug(mhx, None)["read_weights"] # type: ignore
274 | if read_weights is not None and len(read_weights) > 0:
275 | viz.heatmap(
276 | np.array(read_weights[-1]),
277 | opts=dict(
278 | xtickstep=10,
279 | ytickstep=2,
280 | title=f"Read Weights, t: {epoch}, loss: {avg_loss:.4f}",
281 | ylabel="layer * time",
282 | xlabel="nr_read_heads * mem_slot",
283 | ),
284 | )
285 |
286 | write_weights = rnn._debug(mhx, None)["write_weights"] # type: ignore
287 | if write_weights is not None and len(write_weights) > 0:
288 | viz.heatmap(
289 | np.array(write_weights[-1]),
290 | opts=dict(
291 | xtickstep=10,
292 | ytickstep=2,
293 | title=f"Write Weights, t: {epoch}, loss: {avg_loss:.4f}",
294 | ylabel="layer * time",
295 | xlabel="mem_slot",
296 | ),
297 | )
298 |
299 | if args.memory_type == "dnc":
300 | usage_vector = rnn._debug(mhx, None)["usage_vector"] # type: ignore
301 | else:
302 | usage_vector = rnn._debug(mhx, None)["usage"] # type: ignore
303 |
304 | if usage_vector is not None and len(usage_vector) > 0:
305 | viz.heatmap(
306 | np.array(usage_vector[-1]),
307 | opts=dict(
308 | xtickstep=10,
309 | ytickstep=2,
310 | title=f"Usage Vector, t: {epoch}, loss: {avg_loss:.4f}",
311 | ylabel="layer * time",
312 | xlabel="mem_slot",
313 | ),
314 | )
315 |
316 | if args.curriculum_increment > 0 and epoch != 0 and epoch % args.curriculum_freq == 0:
317 | args.sequence_max_length += args.curriculum_increment
318 | print(f"Increasing max length to {args.sequence_max_length}")
319 |
320 | if epoch != 0 and epoch % args.check_freq == 0:
321 | print("\nSaving Checkpoint ... ", end="")
322 |             os.makedirs(check_dir := os.path.join("tasks", "checkpoints"), exist_ok=True)  # argparse defines no --checkpoint_dir flag; use the repo's checkpoints path
323 |             torch.save(rnn.state_dict(), os.path.join(check_dir, f"step_{epoch}.pth"))
324 | print("Done!")
325 |
326 | print("\nTesting generalization...")
327 | rnn.eval()
328 |
329 | with torch.no_grad():
330 |         for i in range((args.iterations + 1) // 10):
331 | print(f"\nIteration {i}/{args.iterations // 10}")
332 | random_length = np.random.randint(2, args.sequence_max_length * 10 + 1)
333 |
334 | input_data, target_output = generate_data(args.batch_size, random_length, args.input_size, device)
335 |             output, *_ = rnn(input_data, (None, None, None), reset_experience=True, pass_through_memory=True)
336 | output_value = torch.sigmoid(output).round().detach().cpu().numpy()
337 | target_value = target_output.detach().cpu().numpy()
338 |
339 | num_correct = (output_value == target_value).sum()
340 | total_num = target_output.numel()
341 | accuracy = num_correct / total_num
342 | print(f"Accuracy: {accuracy:.4f}")
343 |
344 |
345 | if __name__ == "__main__":
346 | main()
347 |
--------------------------------------------------------------------------------
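A compact sketch (not a repository file) of the copy-task tensor layout built by generate_data above, and of the thresholding the test loop applies to the network's logits; the shapes are illustrative and the random logits stand in for the rnn output:

    import numpy as np
    import torch

    batch, length, size = 2, 3, 6
    seq = np.random.binomial(1, 0.5, (batch, length, size - 1)).astype(np.float32)
    x = np.zeros((batch, 2 * length + 1, size), dtype=np.float32)
    y = np.zeros((batch, 2 * length + 1, size), dtype=np.float32)
    x[:, :length, : size - 1] = seq       # present the sequence
    x[:, length, -1] = 1.0                # end-of-sequence marker channel
    y[:, length + 1 :, : size - 1] = seq  # echo expected after the marker
    logits = torch.randn(batch, 2 * length + 1, size)  # stand-in for rnn output
    accuracy = (torch.sigmoid(logits).round().numpy() == y).mean()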
/dnc/sparse_memory.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .util import cuda, σ, θ  # device placement, softmax (σ), and cosine-similarity (θ) helpers
9 |
10 |
11 | class SparseMemory(nn.Module):
12 | """Sparse Memory module."""
13 |
14 | def __init__(
15 | self,
16 | input_size: int,
17 | mem_size: int = 512,
18 | cell_size: int = 32,
19 | independent_linears: bool = True,
20 | read_heads: int = 4,
21 | sparse_reads: int = 4,
22 | num_lists: int | None = None,
23 | index_checks: int = 32,
24 | device: torch.device | None = None,
25 | ):
26 | """Initialize SparseMemory.
27 |
28 | Args:
29 | input_size: Input size.
30 | mem_size: Memory size.
31 | cell_size: Size of each memory cell.
32 | independent_linears: Whether to use independent linear layers.
33 | read_heads: Number of read heads.
34 | sparse_reads: Number of sparse reads.
35 | num_lists: Number of lists for indexing.
36 | index_checks: Number of index checks.
37 | device: PyTorch device to use.
38 | """
39 | super(SparseMemory, self).__init__()
40 |
41 | self.mem_size = mem_size
42 | self.cell_size = cell_size
43 | self.device = device
44 | self.input_size = input_size
45 | self.independent_linears = independent_linears
46 | self.K = sparse_reads if self.mem_size > sparse_reads else self.mem_size
47 | self.read_heads = read_heads
48 |         self.num_lists = num_lists if num_lists is not None else max(1, self.mem_size // 100)  # at least one index list, even for small memories
49 | self.index_checks = index_checks
50 |
51 | m = self.mem_size
52 | w = self.cell_size
53 | r = self.read_heads
54 |         # The visible memory size: the K nearest cells for each of the r read
55 |         # heads, plus the least recently used memory cell
56 |         self.c = (r * self.K) + 1
57 |
58 | if self.independent_linears:
59 | self.read_query_transform = nn.Linear(self.input_size, w * r)
60 | self.write_vector_transform = nn.Linear(self.input_size, w)
61 | self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
62 | self.write_gate_transform = nn.Linear(self.input_size, 1)
63 | torch.nn.init.kaiming_uniform_(self.read_query_transform.weight)
64 | torch.nn.init.kaiming_uniform_(self.write_vector_transform.weight)
65 | torch.nn.init.kaiming_uniform_(self.interpolation_gate_transform.weight)
66 | torch.nn.init.kaiming_uniform_(self.write_gate_transform.weight)
67 |
68 | else:
69 | self.interface_size = (r * w) + w + self.c + 1
70 | self.interface_weights = nn.Linear(self.input_size, self.interface_size)
71 | torch.nn.init.kaiming_uniform_(self.interface_weights.weight)
72 |
73 | self.I = cuda(1 - torch.eye(self.c).unsqueeze(0), device=self.device) # (1 * n * n)
74 | self.δ = 0.005 # minimum usage
75 | self.timestep = 0
76 | self.mem_limit_reached = False
77 | if self.device is not None and self.device.type == "cuda":
78 | self.to(self.device)
79 |
80 | def rebuild_indexes(self, hidden: dict, erase: bool = False) -> dict:
81 | """Rebuilds the indexes for sparse memory access.
82 |
83 | Args:
84 | hidden: Hidden state dictionary.
85 | erase: Whether to erase the existing memory content.
86 |
87 | Returns:
88 | Updated hidden state dictionary.
89 | """
90 | b = hidden["memory"].size(0)
91 |
92 | # if indexes already exist, we reset them
93 | if "indexes" in hidden:
94 | for x in hidden["indexes"]:
95 | x.reset()
96 | else:
97 | # create new indexes, try to use FAISS
98 | try:
99 | from .faiss_index import FAISSIndex
100 |
101 | hidden["indexes"] = [
102 | FAISSIndex(
103 | cell_size=self.cell_size,
104 | nr_cells=self.mem_size,
105 | K=self.K,
106 | num_lists=self.num_lists,
107 | probes=self.index_checks,
108 | device=self.device,
109 | )
110 | for _ in range(b)
111 | ]
112 | except ImportError:
113 | print(
114 | "FAISS not found, please install FAISS, consult https://github.com/facebookresearch/faiss/blob/main/INSTALL.md"
115 | )
116 | raise
117 |
118 | # add existing memory into indexes
119 | pos = hidden["read_positions"].squeeze().detach().cpu().numpy()
120 | if not erase:
121 | for n, i in enumerate(hidden["indexes"]):
122 | i.reset()
123 | i.add(hidden["memory"][n], last=pos[n][-1])
124 | else:
125 | self.timestep = 0
126 | self.mem_limit_reached = False
127 |
128 | return hidden
129 |
130 | def reset(self, batch_size: int = 1, hidden: dict | None = None, erase: bool = True) -> dict:
131 | """Resets the memory and hidden state.
132 |
133 | Args:
134 | batch_size: Batch size.
135 | hidden: Hidden state dictionary.
136 | erase: Whether to erase the existing memory content.
137 | Returns:
138 | Reset hidden state dictionary.
139 |
140 | """
141 | m = self.mem_size
142 | w = self.cell_size
143 | b = batch_size
144 | r = self.read_heads
145 | c = self.c
146 |
147 | if hidden is None:
148 | hidden = {
149 | # warning can be a huge chunk of contiguous memory
150 | "memory": cuda(torch.zeros(b, m, w).fill_(self.δ), device=self.device),
151 | "visible_memory": cuda(torch.zeros(b, c, w).fill_(self.δ), device=self.device),
152 | "read_weights": cuda(torch.zeros(b, m).fill_(self.δ), device=self.device),
153 | "write_weights": cuda(torch.zeros(b, m).fill_(self.δ), device=self.device),
154 | "read_vectors": cuda(torch.zeros(b, r, w).fill_(self.δ), device=self.device),
155 | "least_used_mem": cuda(torch.zeros(b, 1).fill_(c + 1), device=self.device).long(),
156 | "usage": cuda(torch.zeros(b, m), device=self.device),
157 | "read_positions": cuda(torch.arange(0, c).expand(b, c), device=self.device).long(),
158 | }
159 | hidden = self.rebuild_indexes(hidden, erase=True)
160 | else:
161 |             # clone so the in-place fills below do not mutate tensors still referenced elsewhere (e.g. by the autograd graph)
162 | hidden["memory"] = hidden["memory"].clone()
163 | hidden["visible_memory"] = hidden["visible_memory"].clone()
164 | hidden["read_weights"] = hidden["read_weights"].clone()
165 | hidden["write_weights"] = hidden["write_weights"].clone()
166 | hidden["read_vectors"] = hidden["read_vectors"].clone()
167 | hidden["least_used_mem"] = hidden["least_used_mem"].clone()
168 | hidden["usage"] = hidden["usage"].clone()
169 | hidden["read_positions"] = hidden["read_positions"].clone()
170 | hidden = self.rebuild_indexes(hidden, erase)
171 |
172 | if erase:
173 | hidden["memory"].data.fill_(self.δ)
174 | hidden["visible_memory"].data.fill_(self.δ)
175 | hidden["read_weights"].data.fill_(self.δ)
176 | hidden["write_weights"].data.fill_(self.δ)
177 | hidden["read_vectors"].data.fill_(self.δ)
178 | hidden["least_used_mem"].data.fill_(c + 1)
179 | hidden["usage"].data.fill_(0)
180 | hidden["read_positions"] = cuda(torch.arange(0, c).expand(b, c), device=self.device).long()
181 |
182 | return hidden
183 |
184 | def write_into_sparse_memory(self, hidden: dict) -> dict:
185 | """Writes the visible memory into the sparse memory matrix.
186 |
187 | Args:
188 | hidden: Hidden state dictionary
189 |
190 | Returns:
191 | Updated hidden state dictionary.
192 | """
193 | visible_memory = hidden["visible_memory"]
194 | positions = hidden["read_positions"]
195 |
196 | (b, m, w) = hidden["memory"].size()
197 | # Create a new tensor for memory to avoid inplace operations during backprop
198 | new_memory = hidden["memory"].clone()
199 | # update memory (using non-inplace operation)
200 | new_memory.scatter_(1, positions.unsqueeze(2).expand(b, self.c, w), visible_memory)
201 | hidden["memory"] = new_memory
202 |
203 | # non-differentiable operations
204 | pos = positions.detach().cpu().numpy()
205 | for batch in range(b):
206 | # update indexes
207 | hidden["indexes"][batch].reset()
208 | hidden["indexes"][batch].add(
209 | hidden["memory"][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None)
210 | )
211 |
212 | mem_limit_reached = hidden["least_used_mem"][0].detach().cpu().numpy()[0] >= self.mem_size - 1
213 | self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached
214 |
215 | return hidden
216 |
217 | def write(
218 | self, interpolation_gate: torch.Tensor, write_vector: torch.Tensor, write_gate: torch.Tensor, hidden: dict
219 | ) -> dict:
220 | """Performs the memory write operation.
221 |
222 | Args:
223 | interpolation_gate: Interpolation gate.
224 | write_vector: Write vector.
225 | write_gate: Write gate.
226 | hidden: Hidden state dictionary.
227 |
228 | Returns:
229 | Updated hidden state dictionary.
230 |
231 | """
232 |
233 | read_weights = hidden["read_weights"].gather(1, hidden["read_positions"])
234 | # encourage read and write in the first timestep
235 | if self.timestep == 1:
236 | read_weights = read_weights + 1
237 | write_weights = hidden["write_weights"].gather(1, hidden["read_positions"])
238 |
239 | hidden["usage"], I = self.update_usage(hidden["read_positions"], read_weights, write_weights, hidden["usage"])
240 |
241 | # either we write to previous read locations
242 | x = interpolation_gate * read_weights
243 | # or to a new location
244 | y = (1 - interpolation_gate) * I
245 | write_weights = write_gate * (x + y)
246 |
247 | # store the write weights (avoid inplace operation)
248 | new_write_weights = hidden["write_weights"].clone()
249 | new_write_weights.scatter_(1, hidden["read_positions"], write_weights)
250 | hidden["write_weights"] = new_write_weights
251 |
252 | # erase matrix
253 | erase_matrix = I.unsqueeze(2).expand(hidden["visible_memory"].size())
254 |
255 | # write into memory
256 | hidden["visible_memory"] = hidden["visible_memory"] * (1 - erase_matrix) + torch.bmm(
257 | write_weights.unsqueeze(2), write_vector
258 | )
259 |
260 | hidden = self.write_into_sparse_memory(hidden)
261 |
262 | # update least used memory cell
263 | hidden["least_used_mem"] = torch.topk(hidden["usage"], 1, dim=-1, largest=False)[1]
264 |
265 | return hidden
266 |
267 | def update_usage(
268 | self, read_positions: torch.Tensor, read_weights: torch.Tensor, write_weights: torch.Tensor, usage: torch.Tensor
269 | ) -> tuple[torch.Tensor, torch.Tensor]:
270 | """Updates the usage vector.
271 |
272 | Args:
273 | read_positions: Read positions.
274 | read_weights: Read weights.
275 | write_weights: Write weights.
276 | usage: Usage vector.
277 |
278 | Returns:
279 | Tuple: Updated usage vector and indicator matrix.
280 | """
281 | (b, _) = read_positions.size()
282 | # usage is timesteps since a non-negligible memory access
283 | u = (read_weights + write_weights > self.δ).float()
284 |
285 | # usage before write
286 | relevant_usages = usage.gather(1, read_positions)
287 |
288 | # indicator of words with minimal memory usage
289 | minusage = torch.min(relevant_usages, -1, keepdim=True)[0]
290 | minusage = minusage.expand(relevant_usages.size())
291 | I = (relevant_usages == minusage).float()
292 |
293 | # usage after write
294 | relevant_usages = (self.timestep - relevant_usages) * u + relevant_usages * (1 - u)
295 |
296 | # Replace inplace scatter with clone + scatter + assignment
297 | new_usage = usage.clone()
298 | new_usage.scatter_(1, read_positions, relevant_usages)
299 | usage = new_usage
300 |
301 | return usage, I
302 |
303 | def read_from_sparse_memory(
304 | self, memory: torch.Tensor, indexes: list, keys: torch.Tensor, least_used_mem: torch.Tensor, usage: torch.Tensor
305 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
306 | """Reads from sparse memory using indexes.
307 |
308 | Args:
309 | memory: Memory tensor.
310 | indexes: List of indexes.
311 | keys: Read keys.
312 | least_used_mem: Least used memory locations.
313 | usage: Usage vector.
314 |
315 | Returns:
316 | Tuple: Read vectors, read positions, read weights, and visible memory.
317 | """
318 | b = keys.size(0)
319 | rpos = []
320 |
321 | # we search for k cells per read head
322 | for batch in range(b):
323 | distances, positions = indexes[batch].search(keys[batch])
324 | rpos.append(positions)
325 | read_positions = torch.stack(rpos, 0)
326 |
327 | # add least used mem to read positions
328 | (b, r, k) = read_positions.size()
329 | read_positions = read_positions.squeeze(1).view(b, -1)
330 |
331 | # no gradient here
332 | # temporal reads
333 | (b, m, w) = memory.size()
334 | # Use the memory size as the max length rather than relying on least_used_mem value
335 | # If memory limit is reached, use full memory size minus 1
336 | max_length = (m - 1) if self.mem_limit_reached else min(int(least_used_mem[0, 0].detach().cpu().numpy()), m - 1)
337 |
338 | # differentiable ops
339 | # append forward and backward read positions, might lead to duplicates
340 | read_positions = torch.cat([read_positions, least_used_mem], 1)
341 | read_positions = torch.clamp(read_positions, 0, max_length)
342 |
343 | visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, self.c, w))
344 |
345 | read_weights = σ(θ(visible_memory, keys), 2)
346 | read_vectors = torch.bmm(read_weights, visible_memory)
347 | read_weights = torch.prod(read_weights, 1)
348 |
349 | return read_vectors, read_positions, read_weights, visible_memory
350 |
351 | def read(self, read_query: torch.Tensor, hidden: dict) -> tuple[torch.Tensor, dict]:
352 | """Performs the memory read operation.
353 |
354 | Args:
355 | read_query: Read query.
356 | hidden: Hidden state dictionary.
357 |
358 | Returns:
359 | Tuple: Read vectors and updated hidden state.
360 | """
361 | # sparse read
362 | read_vectors, positions, read_weights, visible_memory = self.read_from_sparse_memory(
363 | hidden["memory"], hidden["indexes"], read_query, hidden["least_used_mem"], hidden["usage"]
364 | )
365 |
366 | hidden["read_positions"] = positions
367 | # Avoid inplace operation
368 | new_read_weights = hidden["read_weights"].clone()
369 | new_read_weights.scatter_(1, positions, read_weights)
370 | hidden["read_weights"] = new_read_weights
371 | hidden["read_vectors"] = read_vectors
372 | hidden["visible_memory"] = visible_memory
373 |
374 | return hidden["read_vectors"], hidden
375 |
376 | def forward(self, ξ: torch.Tensor, hidden: dict) -> tuple[torch.Tensor, dict]:
377 | """Forward pass through the memory.
378 |
379 | Args:
380 | ξ: Input tensor.
381 | hidden: Hidden state dictionary.
382 |
383 | Returns:
384 | Tuple: Read vectors and updated hidden state.
385 | """
386 | m = self.mem_size
387 | w = self.cell_size
388 | r = self.read_heads
389 | c = self.c
390 | b = ξ.size(0)
391 |
392 | if self.independent_linears:
393 | # r read keys (b * r * w)
394 | read_query = self.read_query_transform(ξ).view(b, r, w)
395 |             # write vector (b * 1 * w)
396 | write_vector = self.write_vector_transform(ξ).view(b, 1, w)
397 |             # interpolation gate (b * c)
398 | interpolation_gate = torch.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
399 | # write gate (b * 1)
400 | write_gate = torch.sigmoid(self.write_gate_transform(ξ).view(b, 1))
401 | else:
402 | ξ = self.interface_weights(ξ)
403 | # r read keys (b * r * w)
404 | read_query = ξ[:, : r * w].contiguous().view(b, r, w)
405 |             # write vector (b * 1 * w)
406 | write_vector = ξ[:, r * w : r * w + w].contiguous().view(b, 1, w)
407 |             # interpolation gate (b * c)
408 | interpolation_gate = torch.sigmoid(ξ[:, r * w + w : r * w + w + c]).contiguous().view(b, c)
409 | # write gate (b * 1)
410 | write_gate = torch.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
411 |
412 | self.timestep += 1
413 | hidden = self.write(interpolation_gate, write_vector, write_gate, hidden)
414 | return self.read(read_query, hidden)
415 |
--------------------------------------------------------------------------------
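The sparse read above reduces to content-based addressing over the handful of visible cells. A standalone sketch of that step, assuming (per the util import comment, not verified here) that θ is batched cosine similarity and σ is softmax:

    import torch
    import torch.nn.functional as F

    b, c, w, r = 1, 9, 32, 4                  # batch, visible cells, cell width, read heads
    visible_memory = torch.randn(b, c, w)
    keys = torch.randn(b, r, w)

    # cosine similarity of each key against each visible cell: (b, r, c)
    scores = F.cosine_similarity(keys.unsqueeze(2), visible_memory.unsqueeze(1), dim=-1)
    read_weights = torch.softmax(scores, dim=2)             # the σ(θ(memory, keys), 2) analogue
    read_vectors = torch.bmm(read_weights, visible_memory)  # (b, r, w)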
/dnc/dnc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Any
3 |
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
8 |
9 | from .memory import Memory, MemoryHiddenState
10 | from .sparse_memory import SparseMemory
11 | from .sparse_temporal_memory import SparseTemporalMemory
12 |
13 | from .util import cuda
14 |
15 | # Define controller hidden state type for clarity
16 | ControllerHiddenState = torch.Tensor | tuple[torch.Tensor, torch.Tensor]
17 | DNCHiddenState = tuple[
18 | list[ControllerHiddenState],
19 | list[MemoryHiddenState],
20 | torch.Tensor,
21 | ]
22 | LayerHiddenState = tuple[ControllerHiddenState, MemoryHiddenState, torch.Tensor | None]
23 |
24 |
25 | class DNC(nn.Module):
26 | """Differentiable neural computer."""
27 |
28 | def __init__(
29 | self,
30 | input_size: int,
31 | hidden_size: int,
32 | rnn_type: str = "lstm",
33 | num_layers: int = 1,
34 | num_hidden_layers: int = 2,
35 | bias: bool = True,
36 | batch_first: bool = True,
37 | dropout: float = 0,
38 | nr_cells: int = 5,
39 | read_heads: int = 2,
40 | cell_size: int = 10,
41 | nonlinearity: str = "tanh",
42 | independent_linears: bool = False,
43 | share_memory_between_layers: bool = True,
44 | debug: bool = False,
45 | clip: float = 20,
46 | device: torch.device | None = None,
47 | ):
48 | """Create a DNC network.
49 |
50 | Args:
51 | input_size: Input size.
52 | hidden_size: Size of hidden layers.
53 | rnn_type: Type of recurrent cell, can be `rnn`, `gru` and `lstm`.
54 | num_layers: Number of layers of DNC.
55 | num_hidden_layers: Number of layers of RNNs in each DNC layer.
56 | bias: Whether to use bias.
57 | batch_first: If True, then the input and output tensors are provided as `(batch, seq, feature)`.
58 | dropout: Dropout fraction to be applied to each RNN layer.
59 | nr_cells: Size of memory: number of memory cells.
60 | read_heads: Number of read heads that read from memory.
61 |             cell_size: Size of memory: size of each cell.
62 |             nonlinearity: The non-linearity to use for RNNs, applicable when `rnn_type="rnn"`.
63 |             independent_linears: Use independent linear modules for memory transform operators.
64 | share_memory_between_layers: Share one memory module between all layers.
65 | debug: Run in debug mode.
66 | clip: Clip controller outputs.
67 | device: Device (cpu, cuda, cuda:0, ...)
68 | """
69 | super(DNC, self).__init__()
70 |
71 | self.input_size = input_size
72 | self.hidden_size = hidden_size
73 | self.rnn_type = rnn_type
74 | self.num_layers = num_layers
75 | self.num_hidden_layers = num_hidden_layers
76 | self.bias = bias
77 | self.batch_first = batch_first
78 | self.dropout = dropout
79 | self.nr_cells = nr_cells
80 | self.read_heads = read_heads
81 | self.cell_size = cell_size
82 | self.nonlinearity = nonlinearity
83 | self.independent_linears = independent_linears
84 | self.share_memory_between_layers = share_memory_between_layers
85 | self.debug = debug
86 | self.clip = clip
87 | self.device = device
88 |
89 | self.w = self.cell_size
90 | self.r = self.read_heads
91 |
92 | self.read_vectors_size = self.read_heads * self.cell_size
93 | self.output_size = self.hidden_size
94 |
95 | self.nn_input_size = self.input_size + self.read_vectors_size
96 | self.nn_output_size = self.output_size + self.read_vectors_size
97 |
98 | self.rnns: list[nn.RNN | nn.GRU | nn.LSTM] = []
99 | self.memories: list[Memory | SparseMemory | SparseTemporalMemory] = []
100 |
101 | for layer in range(self.num_layers):
102 | if self.rnn_type.lower() == "rnn":
103 | self.rnns.append(
104 | nn.RNN(
105 | (self.nn_input_size if layer == 0 else self.nn_output_size),
106 | self.output_size,
107 | bias=self.bias,
108 | nonlinearity=self.nonlinearity,
109 | batch_first=True,
110 | dropout=self.dropout,
111 | num_layers=self.num_hidden_layers,
112 | )
113 | )
114 | elif self.rnn_type.lower() == "gru":
115 | self.rnns.append(
116 | nn.GRU(
117 | (self.nn_input_size if layer == 0 else self.nn_output_size),
118 | self.output_size,
119 | bias=self.bias,
120 | batch_first=True,
121 | dropout=self.dropout,
122 | num_layers=self.num_hidden_layers,
123 | )
124 | )
125 | elif self.rnn_type.lower() == "lstm":
126 | self.rnns.append(
127 | nn.LSTM(
128 | (self.nn_input_size if layer == 0 else self.nn_output_size),
129 | self.output_size,
130 | bias=self.bias,
131 | batch_first=True,
132 | dropout=self.dropout,
133 | num_layers=self.num_hidden_layers,
134 | )
135 | )
136 | setattr(self, self.rnn_type.lower() + "_layer_" + str(layer), self.rnns[layer])
137 |
138 | # memories for each layer
139 | if not self.share_memory_between_layers:
140 | self.memories.append(
141 | Memory(
142 | input_size=self.output_size,
143 | nr_cells=self.nr_cells,
144 | cell_size=self.w,
145 | read_heads=self.r,
146 | device=self.device,
147 | independent_linears=self.independent_linears,
148 | )
149 | )
150 | setattr(self, "rnn_layer_memory_" + str(layer), self.memories[layer])
151 |
152 | # only one memory shared by all layers
153 | if self.share_memory_between_layers:
154 | self.memories.append(
155 | Memory(
156 | input_size=self.output_size,
157 | nr_cells=self.nr_cells,
158 | cell_size=self.w,
159 | read_heads=self.r,
160 | device=self.device,
161 | independent_linears=self.independent_linears,
162 | )
163 | )
164 | setattr(self, "rnn_layer_memory_shared", self.memories[0])
165 |
166 | # final output layer
167 | self.output = nn.Linear(self.nn_output_size, self.input_size)
168 | torch.nn.init.kaiming_uniform_(self.output.weight)
169 |
170 | if self.device is not None and self.device.type == "cuda":
171 | self.to(self.device)
172 |
173 | def _init_hidden(self, hx: DNCHiddenState | None, batch_size: int, reset_experience: bool) -> DNCHiddenState:
174 | """Initializes the hidden states.
175 |
176 | Args:
177 | hx: Existing hidden state or None.
178 | batch_size: Batch size.
179 | reset_experience: Whether to reset memory experience.
180 |
181 | Returns:
182 | Initialized hidden state.
183 | """
184 | # Parse hidden state components
185 | if hx is not None:
186 | chx, mhx, last_read = hx
187 | else:
188 | chx, mhx, last_read = None, None, None
189 |
190 | # Initialize controller hidden state if needed
191 | if chx is None:
192 | h: torch.Tensor = cuda(
193 | torch.zeros(self.num_hidden_layers, batch_size, self.output_size),
194 | device=self.device,
195 | )
196 | torch.nn.init.xavier_uniform_(h)
197 | chx = [(h, h) if self.rnn_type.lower() == "lstm" else h for _ in range(self.num_layers)]
198 |
199 | # Initialize last read vectors if needed
200 | if last_read is None:
201 | last_read = cuda(torch.zeros(batch_size, self.w * self.r), device=self.device)
202 |
203 | # Initialize memory states if needed
204 | if mhx is None:
205 | if self.share_memory_between_layers:
206 | mhx = [self.memories[0].reset(batch_size, erase=reset_experience)]
207 | else:
208 | mhx = [m.reset(batch_size, erase=reset_experience) for m in self.memories]
209 | else:
210 | if self.share_memory_between_layers:
211 | if len(mhx) == 0 or mhx[0] is None:
212 | mhx = [self.memories[0].reset(batch_size, erase=reset_experience)]
213 | else:
214 | mhx = [self.memories[0].reset(batch_size, mhx[0], erase=reset_experience)]
215 | else:
216 | if len(mhx) == 0:
217 | mhx = [m.reset(batch_size, erase=reset_experience) for m in self.memories]
218 | else:
219 | new_mhx = []
220 | for i, m in enumerate(self.memories):
221 | if i < len(mhx) and mhx[i] is not None:
222 | new_mhx.append(m.reset(batch_size, mhx[i], erase=reset_experience))
223 | else:
224 | new_mhx.append(m.reset(batch_size, erase=reset_experience))
225 | mhx = new_mhx
226 |
227 | return chx, mhx, last_read
228 |
229 | def _debug(
230 | self, mhx: MemoryHiddenState, debug_obj: dict[str, list[np.ndarray]] | None
231 | ) -> dict[str, list[np.ndarray]] | None:
232 | """Collects debug information. Only returns a debug_obj if self.debug is True.
233 |
234 | Args:
235 | mhx: Memory hidden state.
236 | debug_obj: Debug object containing lists of numpy arrays.
237 |
238 | Returns:
239 | Debug object or None.
240 | """
241 | if not self.debug:
242 | return None
243 |
244 | if not debug_obj:
245 | debug_obj = {
246 | "memory": [],
247 | "link_matrix": [],
248 | "precedence": [],
249 | "read_weights": [],
250 | "write_weights": [],
251 | "usage_vector": [],
252 | }
253 |
254 | debug_obj["memory"].append(mhx["memory"][0].detach().cpu().numpy())
255 | debug_obj["link_matrix"].append(mhx["link_matrix"][0][0].detach().cpu().numpy())
256 | debug_obj["precedence"].append(mhx["precedence"][0].detach().cpu().numpy())
257 | debug_obj["read_weights"].append(mhx["read_weights"][0].detach().cpu().numpy())
258 | debug_obj["write_weights"].append(mhx["write_weights"][0].detach().cpu().numpy())
259 | debug_obj["usage_vector"].append(mhx["usage_vector"][0].unsqueeze(0).detach().cpu().numpy())
260 | return debug_obj
261 |
262 | def _layer_forward(
263 | self,
264 | input: torch.Tensor,
265 | layer: int,
266 | hx: LayerHiddenState,
267 | pass_through_memory: bool = True,
268 | ) -> tuple[torch.Tensor, LayerHiddenState]:
269 | """Performs a forward pass through a single layer.
270 |
271 | Args:
272 | input : Input tensor.
273 | layer: Layer index.
274 | hx: Hidden state for the layer.
275 | pass_through_memory: Whether to pass the input through memory.
276 |
277 | Returns:
278 | Tuple: Output, and updated hidden state.
279 | """
280 | (chx, mhx, _) = hx
281 |
282 | # pass through the controller layer
283 | input, chx = self.rnns[layer](input.unsqueeze(1), chx)
284 | input = input.squeeze(1) # Remove the sequence length dimension (always 1)
285 |
286 | # clip the controller output
287 | if self.clip != 0:
288 | output = torch.clamp(input, -self.clip, self.clip)
289 | else:
290 | output = input
291 |
292 | # the interface vector
293 | ξ = output
294 |
295 | # pass through memory
296 | if pass_through_memory:
297 | if self.share_memory_between_layers:
298 | read_vecs, mhx = self.memories[0](ξ, mhx)
299 | else:
300 | read_vecs, mhx = self.memories[layer](ξ, mhx)
301 | # the read vectors
302 | read_vectors = read_vecs.view(-1, self.w * self.r)
303 | else:
304 | # Initialize read vectors with zeros when not passing through memory
305 | read_vectors = cuda(torch.zeros(ξ.size(0), self.w * self.r), device=self.device)
306 |
307 | return output, (chx, mhx, read_vectors)
308 |
309 | def forward(
310 | self,
311 | input_data: torch.Tensor | PackedSequence,
312 | hx: DNCHiddenState | None,
313 | reset_experience: bool = False,
314 | pass_through_memory: bool = True,
315 | ) -> (
316 | tuple[torch.Tensor | PackedSequence, DNCHiddenState]
317 | | tuple[torch.Tensor | PackedSequence, DNCHiddenState, dict[str, Any]]
318 | ):
319 | """Performs a forward pass through the DNC.
320 |
321 | Args:
322 | input_data: Input tensor or PackedSequence.
323 | hx: Hidden state or None.
324 | reset_experience: Whether to reset memory experience.
325 | pass_through_memory: Whether to pass the input through memory.
326 |
327 | Returns:
328 | Tuple: Output (same type as input_data), updated hidden state, and optionally debug information.
329 |
330 | """
331 | max_length: int
332 | # handle packed data
333 | if isinstance(input_data, PackedSequence):
334 | input, lengths = pad_packed_sequence(input_data, batch_first=self.batch_first)
335 |             max_length, batch_size = int(lengths.max().item()), lengths.size(0)  # batch size from the per-sequence lengths tensor
336 | elif isinstance(input_data, torch.Tensor):
337 | input = input_data
338 | batch_size = input.size(0) if self.batch_first else input.size(1)
339 | max_length = input.size(1) if self.batch_first else input.size(0)
340 | lengths = torch.tensor([max_length] * batch_size, device=input.device)
341 |
342 | else:
343 | raise TypeError("input_data must be a PackedSequence or Tensor")
344 |
345 | if not self.batch_first:
346 | input = input.transpose(0, 1)
347 |             # make the data batch-first internally
348 |
349 | controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)
350 |
351 | # last_read is guaranteed to be initialized by _init_hidden, so no need to check for None
352 |
353 | inputs = [torch.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
354 |
355 | # batched forward pass per element / word / etc
356 |         # debug visualization dict; populated per layer and timestep when self.debug is set
357 |         viz: dict[str, Any] | None = None
358 |
359 | outs: list[torch.Tensor | None] = [None] * max_length
360 | read_vectors: torch.Tensor | None = None
361 |
362 | # pass through time
363 | for time in range(max_length):
364 |             # pass through layers
365 | for layer in range(self.num_layers):
366 | # this layer's hidden states
367 | chx_layer = controller_hidden[layer]
368 | mem_layer = mem_hidden[0] if self.share_memory_between_layers else mem_hidden[layer]
369 |
370 | # pass through controller
371 | outs[time], (
372 | chx_layer_output,
373 | mem_layer_output,
374 | read_vectors,
375 | ) = self._layer_forward(
376 | inputs[time], layer, (chx_layer, mem_layer, read_vectors), pass_through_memory # type: ignore
377 | )
378 |
379 | # debug memory
380 | if self.debug:
381 | viz = self._debug(mem_layer_output, viz)
382 |
383 | # store the memory back (per layer or shared)
384 | if self.share_memory_between_layers:
385 | mem_hidden[0] = mem_layer_output # type: ignore
386 | else:
387 | mem_hidden[layer] = mem_layer_output # type: ignore
388 | controller_hidden[layer] = chx_layer_output
389 |
390 | if read_vectors is not None:
391 | # the controller output + read vectors go into next layer
392 | outs[time] = torch.cat([outs[time], read_vectors], 1) # type: ignore
393 | else:
394 | outs[time] = torch.cat([outs[time], last_read], 1) # type: ignore
395 | inputs[time] = outs[time] # type: ignore
396 |
397 | if self.debug and viz:
398 | viz = {k: [np.array(v) for v in vs] for k, vs in viz.items()}
399 | viz = {k: [v.reshape(v.shape[0], -1) for v in vs] for k, vs in viz.items()}
400 |
401 | # pass through final output layer
402 | inputs_tensor = torch.stack(inputs)
403 | outputs = self.output(inputs_tensor)
404 |
405 |         if self.batch_first:
406 |             outputs = outputs.transpose(0, 1)  # torch.stack stacked along time; restore batch-first order
407 |
408 | if isinstance(input_data, PackedSequence):
409 | outputs = pack_padded_sequence(outputs, lengths.cpu(), batch_first=self.batch_first, enforce_sorted=False)
410 |
411 | if self.debug:
412 | return outputs, (controller_hidden, mem_hidden, read_vectors), viz # type: ignore
413 | else:
414 | return outputs, (controller_hidden, mem_hidden, read_vectors) # type: ignore
415 |
416 | def __repr__(self) -> str:
417 | """Provides a string representation of the DNC module."""
418 |
419 | s = "\n----------------------------------------\n"
420 | s += "{name}({input_size}, {hidden_size}"
421 | if self.rnn_type != "lstm":
422 | s += ", rnn_type={rnn_type}"
423 | if self.num_layers != 1:
424 | s += ", num_layers={num_layers}"
425 | if self.num_hidden_layers != 2:
426 | s += ", num_hidden_layers={num_hidden_layers}"
427 | if not self.bias:
428 | s += ", bias={bias}"
429 | if not self.batch_first:
430 | s += ", batch_first={batch_first}"
431 | if self.dropout != 0:
432 | s += ", dropout={dropout}"
433 | if self.nr_cells != 5:
434 | s += ", nr_cells={nr_cells}"
435 | if self.read_heads != 2:
436 | s += ", read_heads={read_heads}"
437 | if self.cell_size != 10:
438 | s += ", cell_size={cell_size}"
439 | if self.nonlinearity != "tanh":
440 | s += ", nonlinearity={nonlinearity}"
441 | if self.independent_linears:
442 | s += ", independent_linears={independent_linears}"
443 | if not self.share_memory_between_layers:
444 | s += ", share_memory_between_layers={share_memory_between_layers}"
445 | if self.debug:
446 | s += ", debug={debug}"
447 | if self.clip != 20:
448 | s += ", clip={clip}"
449 | if self.device:
450 | s += f", device='{self.device}'"
451 |
452 | s += ")\n" + super(DNC, self).__repr__() + "\n----------------------------------------\n"
453 | return s.format(name=self.__class__.__name__, **self.__dict__)
454 |
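455 | # --- Usage sketch (illustrative only, not part of the library API) ---
456 | # A minimal smoke test for the forward pass defined above. Keyword names
457 | # follow the attributes referenced in __repr__; the sizes are arbitrary.
458 | # hx=None lets _init_hidden build fresh controller and memory state, and
459 | # reset_experience=True wipes the memory contents as well.
460 | if __name__ == "__main__":
461 |     rnn = DNC(input_size=32, hidden_size=64, rnn_type="lstm", num_layers=1, nr_cells=16, cell_size=10, read_heads=2, batch_first=True)
462 |     x = torch.randn(4, 9, 32)  # (batch, time, features) because batch_first=True
463 |     out, (chx, mhx, last_read) = rnn(x, None, reset_experience=True)
464 |     # out: (4, 9, output features); pass (chx, mhx, last_read) back in as hx
465 |     # on the next call to carry memory state across sequences.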
--------------------------------------------------------------------------------
/dnc/sparse_temporal_memory.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .util import cuda, σ, θ
9 |
10 |
11 | class SparseTemporalMemory(nn.Module):
12 | """Sparse Temporal Memory module."""
13 |
14 | def __init__(
15 | self,
16 | input_size: int,
17 | mem_size: int = 512,
18 | cell_size: int = 32,
19 | independent_linears: bool = True,
20 | read_heads: int = 4,
21 | sparse_reads: int = 4,
22 | temporal_reads: int = 4,
23 | num_lists: int | None = None,
24 | index_checks: int = 32,
25 | device: torch.device | None = None,
26 | ):
27 | """Initialize SparseTemporalMemory.
28 |
29 | Args:
30 | input_size: Input size.
31 | mem_size: Memory size.
32 | cell_size: Size of each memory cell.
33 | independent_linears: Whether to use independent linear layers.
34 | read_heads: Number of read heads.
35 | sparse_reads: Number of sparse reads.
36 | temporal_reads: Number of temporal reads.
37 | num_lists: Number of lists for indexing.
38 | index_checks: Number of index checks.
39 | device: PyTorch device
40 |
41 | """
42 | super(SparseTemporalMemory, self).__init__()
43 |
44 | self.mem_size = mem_size
45 | self.cell_size = cell_size
46 | self.device = device
47 | self.input_size = input_size
48 | self.independent_linears = independent_linears
49 | self.K = sparse_reads if self.mem_size > sparse_reads else self.mem_size
50 | self.KL = temporal_reads if self.mem_size > temporal_reads else self.mem_size
51 | self.read_heads = read_heads
52 |         self.num_lists = num_lists if num_lists is not None else max(self.mem_size // 100, 1)  # guard against 0 lists for small memories
53 | self.index_checks = index_checks
54 |
55 | m = self.mem_size
56 | w = self.cell_size
57 | r = self.read_heads
58 |         # number of visible memory cells per step: K per read head (r * K),
59 |         # KL forward and KL backward temporal reads, and one least-used cell
60 |         self.c = (r * self.K) + (self.KL * 2) + 1
61 |
62 | if self.independent_linears:
63 | self.read_query_transform = nn.Linear(self.input_size, w * r)
64 | self.write_vector_transform = nn.Linear(self.input_size, w)
65 | self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
66 | self.write_gate_transform = nn.Linear(self.input_size, 1)
67 | torch.nn.init.kaiming_uniform_(self.read_query_transform.weight)
68 | torch.nn.init.kaiming_uniform_(self.write_vector_transform.weight)
69 | torch.nn.init.kaiming_uniform_(self.interpolation_gate_transform.weight)
70 | torch.nn.init.kaiming_uniform_(self.write_gate_transform.weight)
71 | else:
72 | self.interface_size = (r * w) + w + self.c + 1
73 | self.interface_weights = nn.Linear(self.input_size, self.interface_size)
74 | torch.nn.init.kaiming_uniform_(self.interface_weights.weight)
75 |
76 |         self.I = cuda(1 - torch.eye(self.c).unsqueeze(0), device=self.device)  # (1 * c * c)
77 | self.δ = 0.005 # minimum usage
78 | self.timestep = 0
79 | self.mem_limit_reached = False
80 |
81 | if self.device is not None and self.device.type == "cuda":
82 | self.to(self.device)
83 |
84 | def rebuild_indexes(self, hidden: dict, erase: bool = False) -> dict:
85 | """Rebuilds the indexes for sparse memory access.
86 |
87 | Args:
88 | hidden: Hidden state dictionary.
89 | erase: Whether to erase the existing memory content.
90 |
91 | Returns:
92 | Updated hidden state dictionary.
93 | """
94 | b = hidden["memory"].size(0)
95 |
96 | # if indexes already exist, we reset them
97 | if "indexes" in hidden:
98 | for x in hidden["indexes"]:
99 | x.reset()
100 | else:
101 | # create new indexes
102 | try:
103 | from .faiss_index import FAISSIndex
104 |
105 | hidden["indexes"] = [
106 | FAISSIndex(
107 | cell_size=self.cell_size,
108 | nr_cells=self.mem_size,
109 | K=self.K,
110 | num_lists=self.num_lists,
111 | probes=self.index_checks,
112 | device=self.device,
113 | )
114 | for _ in range(b)
115 | ]
116 | except ImportError:
117 | print(
118 | "FAISS not found, please install FAISS, consult https://github.com/facebookresearch/faiss/blob/main/INSTALL.md"
119 | )
120 | raise
121 |
122 | # add existing memory into indexes
123 | pos = hidden["read_positions"].squeeze().detach().cpu().numpy()
124 | if not erase:
125 | for n, i in enumerate(hidden["indexes"]):
126 | i.reset()
127 | i.add(hidden["memory"][n], last=pos[n][-1])
128 | else:
129 | self.timestep = 0
130 | self.mem_limit_reached = False
131 |
132 | return hidden
133 |
134 | def reset(self, batch_size: int = 1, hidden: dict | None = None, erase: bool = True) -> dict:
135 | """Resets the memory and hidden state.
136 |
137 | Args:
138 | batch_size: Batch size.
139 | hidden: Hidden state dictionary.
140 | erase: Whether to erase the existing memory content
141 |
142 | Returns:
143 | Reset hidden state dictionary.
144 | """
145 | m = self.mem_size
146 | w = self.cell_size
147 | b = batch_size
148 | r = self.read_heads
149 | c = self.c
150 |
151 | if hidden is None:
152 | hidden = {
153 |                 # warning: this can be a huge chunk of contiguous memory
154 | "memory": cuda(torch.zeros(b, m, w).fill_(self.δ), device=self.device),
155 | "visible_memory": cuda(torch.zeros(b, c, w).fill_(self.δ), device=self.device),
156 | "link_matrix": cuda(torch.zeros(b, m, self.KL * 2), device=self.device),
157 | "rev_link_matrix": cuda(torch.zeros(b, m, self.KL * 2), device=self.device),
158 | "precedence": cuda(torch.zeros(b, self.KL * 2).fill_(self.δ), device=self.device),
159 | "read_weights": cuda(torch.zeros(b, m).fill_(self.δ), device=self.device),
160 | "write_weights": cuda(torch.zeros(b, m).fill_(self.δ), device=self.device),
161 | "read_vectors": cuda(torch.zeros(b, r, w).fill_(self.δ), device=self.device),
162 | "least_used_mem": cuda(torch.zeros(b, 1).fill_(c + 1), device=self.device).long(),
163 | "usage": cuda(torch.zeros(b, m), device=self.device),
164 | "read_positions": cuda(torch.arange(0, c).expand(b, c), device=self.device).long(),
165 | }
166 | hidden = self.rebuild_indexes(hidden, erase=True)
167 | else:
168 |             # clone everything so the in-place updates below cannot mutate tensors shared with a previous step
169 |             hidden["memory"] = hidden["memory"].clone()
170 |             hidden["visible_memory"] = hidden["visible_memory"].clone()
171 |             hidden["link_matrix"] = hidden["link_matrix"].clone()
172 |             hidden["rev_link_matrix"] = hidden["rev_link_matrix"].clone()
173 | hidden["precedence"] = hidden["precedence"].clone()
174 | hidden["read_weights"] = hidden["read_weights"].clone()
175 | hidden["write_weights"] = hidden["write_weights"].clone()
176 | hidden["read_vectors"] = hidden["read_vectors"].clone()
177 | hidden["least_used_mem"] = hidden["least_used_mem"].clone()
178 | hidden["usage"] = hidden["usage"].clone()
179 | hidden["read_positions"] = hidden["read_positions"].clone()
180 | hidden = self.rebuild_indexes(hidden, erase)
181 |
182 | if erase:
183 | hidden["memory"].data.fill_(self.δ)
184 | hidden["visible_memory"].data.fill_(self.δ)
185 | hidden["link_matrix"].data.zero_()
186 | hidden["rev_link_matrix"].data.zero_()
187 | hidden["precedence"].data.zero_()
188 | hidden["read_weights"].data.fill_(self.δ)
189 | hidden["write_weights"].data.fill_(self.δ)
190 | hidden["read_vectors"].data.fill_(self.δ)
191 | hidden["least_used_mem"].data.fill_(c + 1)
192 | hidden["usage"].data.fill_(0)
193 | hidden["read_positions"] = cuda(torch.arange(0, c).expand(b, c), device=self.device).long()
194 |
195 | return hidden
196 |
197 | def write_into_sparse_memory(self, hidden: dict) -> dict:
198 | """Writes the visible memory into the sparse memory matrix.
199 |
200 | Args:
201 |             hidden: Hidden state dictionary.
202 |
203 |         Returns:
204 |             Updated hidden state dictionary.
205 | """
206 | visible_memory = hidden["visible_memory"]
207 | positions = hidden["read_positions"]
208 |
209 | (b, m, w) = hidden["memory"].size()
210 | # Create a new tensor for memory to avoid inplace operations during backprop
211 | new_memory = hidden["memory"].clone()
212 | # update memory (using non-inplace operation)
213 | new_memory.scatter_(1, positions.unsqueeze(2).expand(b, self.c, w), visible_memory)
214 | hidden["memory"] = new_memory
215 |
216 | # non-differentiable operations
217 | pos = positions.detach().cpu().numpy()
218 | for batch in range(b):
219 | # update indexes
220 | hidden["indexes"][batch].reset()
221 | hidden["indexes"][batch].add(
222 | hidden["memory"][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None)
223 | )
224 |
225 | mem_limit_reached = hidden["least_used_mem"][0].detach().cpu().numpy()[0] >= self.mem_size - 1
226 | self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached
227 |
228 | return hidden
229 |
230 | def update_link_matrices(
231 | self,
232 | link_matrix: torch.Tensor,
233 | rev_link_matrix: torch.Tensor,
234 | write_weights: torch.Tensor,
235 | precedence: torch.Tensor,
236 | temporal_read_positions: torch.Tensor,
237 | ) -> tuple[torch.Tensor, torch.Tensor]:
238 | """Updates the forward and backward link matrices.
239 |
240 | Args:
241 | link_matrix: Forward link matrix.
242 | rev_link_matrix: Backward link matrix.
243 | write_weights: Write weights.
244 | precedence: Precedence vector.
245 | temporal_read_positions: Temporal read positions.
246 |
247 | Returns:
248 | Tuple: Updated forward and backward link matrices.
249 | """
250 | write_weights_i = write_weights.unsqueeze(2)
251 | precedence_j = precedence.unsqueeze(1)
252 |
253 | (b, m, k) = link_matrix.size()
254 | I = cuda(torch.eye(m, k).unsqueeze(0).expand((b, m, k)), device=self.device)
255 |
256 | # since only KL*2 entries are kept non-zero sparse, create the dense version from the sparse one
257 | precedence_dense = cuda(torch.zeros(b, m), device=self.device)
258 | # Use non-inplace operation - create new tensor and assign
259 | precedence_tmp = precedence_dense.clone()
260 | precedence_tmp.scatter_(1, temporal_read_positions, precedence)
261 | precedence_dense = precedence_tmp
262 | precedence_dense_i = precedence_dense.unsqueeze(2)
263 |
264 | temporal_write_weights_j = write_weights.gather(1, temporal_read_positions).unsqueeze(1)
265 |
266 | link_matrix = (1 - write_weights_i) * link_matrix + write_weights_i * precedence_j
267 |
268 | rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (
269 | temporal_write_weights_j * precedence_dense_i
270 | )
271 |
272 | return link_matrix * I, rev_link_matrix * I
273 |
274 | def update_precedence(self, precedence: torch.Tensor, write_weights: torch.Tensor) -> torch.Tensor:
275 | """Updates the precedence vector.
276 |
277 | Args:
278 | precedence: Precedence vector.
279 | write_weights: Write weights.
280 |
281 | Returns:
282 | Updated precedence vector.
283 |
284 | """
285 | return (1 - torch.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
286 |
287 | def write(
288 | self, interpolation_gate: torch.Tensor, write_vector: torch.Tensor, write_gate: torch.Tensor, hidden: dict
289 | ) -> dict:
290 | """Performs the memory write operation.
291 |
292 | Args:
293 |             interpolation_gate: Interpolation gate.
294 |             write_vector: Write vector.
295 |             write_gate: Write gate.
296 |             hidden: Hidden state dictionary.
297 |
298 | Returns:
299 | Updated hidden state dictionary.
300 |
301 | """
302 |
303 | read_weights = hidden["read_weights"].gather(1, hidden["read_positions"])
304 | # encourage read and write in the first timestep
305 | if self.timestep == 1:
306 | read_weights = read_weights + 1
307 | write_weights = hidden["write_weights"].gather(1, hidden["read_positions"])
308 |
309 | hidden["usage"], I = self.update_usage(hidden["read_positions"], read_weights, write_weights, hidden["usage"])
310 |
311 | # either we write to previous read locations
312 | x = interpolation_gate * read_weights
313 | # or to a new location
314 | y = (1 - interpolation_gate) * I
315 | write_weights = write_gate * (x + y)
316 |
317 | # store the write weights (avoid inplace operations)
318 | new_write_weights = hidden["write_weights"].clone()
319 | new_write_weights.scatter_(1, hidden["read_positions"], write_weights)
320 | hidden["write_weights"] = new_write_weights
321 |
322 | # erase matrix
323 | erase_matrix = I.unsqueeze(2).expand(hidden["visible_memory"].size())
324 |
325 | # write into memory
326 | hidden["visible_memory"] = hidden["visible_memory"] * (1 - erase_matrix) + torch.bmm(
327 | write_weights.unsqueeze(2), write_vector
328 | )
329 | hidden = self.write_into_sparse_memory(hidden)
330 |
331 | # update link_matrix and precedence
332 |         (b, _) = write_weights.size()  # second dim is c (visible cells)
333 |
334 | # update link matrix
335 | temporal_read_positions = hidden["read_positions"][:, self.read_heads * self.K + 1 :]
336 | hidden["link_matrix"], hidden["rev_link_matrix"] = self.update_link_matrices(
337 | hidden["link_matrix"],
338 | hidden["rev_link_matrix"],
339 | hidden["write_weights"],
340 | hidden["precedence"],
341 | temporal_read_positions,
342 | )
343 |
344 | # update precedence vector
345 | read_weights = hidden["read_weights"].gather(1, temporal_read_positions)
346 | hidden["precedence"] = self.update_precedence(hidden["precedence"], read_weights)
347 |
348 | # update least used memory cell
349 | hidden["least_used_mem"] = torch.topk(hidden["usage"], 1, dim=-1, largest=False)[1]
350 |
351 | return hidden
352 |
353 | def update_usage(
354 | self, read_positions: torch.Tensor, read_weights: torch.Tensor, write_weights: torch.Tensor, usage: torch.Tensor
355 | ) -> tuple[torch.Tensor, torch.Tensor]:
356 | """Updates the usage vector.
357 |
358 | Args:
359 | read_positions: Read positions.
360 | read_weights: Read weights.
361 | write_weights: Write weights.
362 | usage: Usage vector.
363 |
364 | Returns:
365 | Tuple: Updated usage vector and indicator matrix.
366 | """
367 |         # flag cells that received a non-negligible read or write this step
368 | u = (read_weights + write_weights > self.δ).float()
369 |
370 | # usage before write
371 | relevant_usages = usage.gather(1, read_positions)
372 |
373 | # indicator of words with minimal memory usage
374 | minusage = torch.min(relevant_usages, -1, keepdim=True)[0]
375 | minusage = minusage.expand(relevant_usages.size())
376 | I = (relevant_usages == minusage).float()
377 |
378 | # usage after write
379 | relevant_usages = (self.timestep - relevant_usages) * u + relevant_usages * (1 - u)
380 |
381 | # Replace inplace scatter with clone + scatter + assignment
382 | new_usage = usage.clone()
383 | new_usage.scatter_(1, read_positions, relevant_usages)
384 | usage = new_usage
385 |
386 | return usage, I
387 |
388 | def directional_weightings(
389 | self, link_matrix: torch.Tensor, rev_link_matrix: torch.Tensor, temporal_read_weights: torch.Tensor
390 | ) -> tuple[torch.Tensor, torch.Tensor]:
391 | """Calculates the forward and backward weightings.
392 |
393 | Args:
394 | link_matrix: Forward link matrix.
395 | rev_link_matrix: Backward link matrix.
396 | temporal_read_weights: Temporal read weights.
397 |
398 | Returns:
399 | Tuple: Forward and backward weightings.
400 | """
401 | f = torch.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
402 | b = torch.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
403 | return f, b
404 |
405 | def read_from_sparse_memory(
406 | self,
407 | memory: torch.Tensor,
408 | indexes: list,
409 | keys: torch.Tensor,
410 | least_used_mem: torch.Tensor,
411 | usage: torch.Tensor,
412 | forward: torch.Tensor,
413 | backward: torch.Tensor,
414 | prev_read_positions: torch.Tensor,
415 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
416 | """Reads from sparse memory using indexes.
417 |
418 | Args:
419 | memory: Memory tensor.
420 | indexes: List of indexes.
421 | keys: Read keys.
422 | least_used_mem: Least used memory locations.
423 | usage: Usage vector.
424 | forward: Forward weightings.
425 | backward: Backward weightings.
426 | prev_read_positions: Previous read positions.
427 |
428 | Returns:
429 | Tuple: Read vectors, read positions, read weights, and visible memory.
430 | """
431 | b = keys.size(0)
432 | rpos = []
433 |
434 | # we search for k cells per read head
435 | for batch in range(b):
436 | distances, positions = indexes[batch].search(keys[batch])
437 | rpos.append(positions)
438 | read_positions = torch.stack(rpos, 0)
439 |
440 | (b, r, k) = read_positions.size()
441 | read_positions = read_positions.squeeze(1).view(b, -1)
442 |
443 | # no gradient here
444 | # temporal reads
445 | (b, m, w) = memory.size()
446 | # get the top KL entries
447 | max_length = int(least_used_mem[0, 0].detach().cpu().numpy()) if not self.mem_limit_reached else (m - 1)
448 |
449 | _, fp = torch.topk(forward, self.KL, largest=True)
450 | _, bp = torch.topk(backward, self.KL, largest=True)
451 |
452 | # differentiable ops
453 | # append forward and backward read positions, might lead to duplicates
454 | read_positions = torch.cat([read_positions, fp, bp], 1)
455 | read_positions = torch.cat([read_positions, least_used_mem], 1)
456 | read_positions = torch.clamp(read_positions, 0, max_length)
457 |
458 | visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, self.c, w))
459 |
460 | read_weights = σ(θ(visible_memory, keys), 2)
461 | read_vectors = torch.bmm(read_weights, visible_memory)
462 | read_weights = torch.prod(read_weights, 1)
463 |
464 | return read_vectors, read_positions, read_weights, visible_memory
465 |
466 | def read(self, read_query: torch.Tensor, hidden: dict) -> tuple[torch.Tensor, dict]:
467 | """Performs the memory read operation.
468 |
469 | Args:
470 | read_query: Read query.
471 | hidden: Hidden state dictionary.
472 |
473 | Returns:
474 | Tuple: Read vectors and updated hidden state.
475 | """
476 | # get forward and backward weights
477 | temporal_read_positions = hidden["read_positions"][:, self.read_heads * self.K + 1 :]
478 | read_weights = hidden["read_weights"].gather(1, temporal_read_positions)
479 | forward, backward = self.directional_weightings(hidden["link_matrix"], hidden["rev_link_matrix"], read_weights)
480 |
481 | # sparse read
482 | read_vectors, positions, read_weights, visible_memory = self.read_from_sparse_memory(
483 | hidden["memory"],
484 | hidden["indexes"],
485 | read_query,
486 | hidden["least_used_mem"],
487 | hidden["usage"],
488 | forward,
489 | backward,
490 | hidden["read_positions"],
491 | )
492 |
493 | hidden["read_positions"] = positions
494 | # Avoid inplace operation
495 | new_read_weights = hidden["read_weights"].clone()
496 | new_read_weights.scatter_(1, positions, read_weights)
497 | hidden["read_weights"] = new_read_weights
498 | hidden["read_vectors"] = read_vectors
499 | hidden["visible_memory"] = visible_memory
500 |
501 | return hidden["read_vectors"], hidden
502 |
503 | def forward(self, ξ: torch.Tensor, hidden: dict) -> tuple[torch.Tensor, dict]:
504 | """Forward pass through the memory.
505 |
506 | Args:
507 | ξ: Input tensor.
508 | hidden: Hidden state dictionary
509 |
510 | Returns:
511 | Tuple: Read vectors and updated hidden state.
512 | """
513 | m = self.mem_size
514 | w = self.cell_size
515 | r = self.read_heads
516 | c = self.c
517 | b = ξ.size(0)
518 |
519 | if self.independent_linears:
520 | # r read keys (b * r * w)
521 | read_query = self.read_query_transform(ξ).view(b, r, w)
522 |             # write vector (b * 1 * w)
523 |             write_vector = self.write_vector_transform(ξ).view(b, 1, w)
524 |             # interpolation gate (b * c)
525 |             interpolation_gate = torch.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
526 | # write gate (b * 1)
527 | write_gate = torch.sigmoid(self.write_gate_transform(ξ).view(b, 1))
528 | else:
529 | ξ = self.interface_weights(ξ)
530 | # r read keys (b * r * w)
531 | read_query = ξ[:, : r * w].contiguous().view(b, r, w)
532 |             # write vector (b * 1 * w)
533 |             write_vector = ξ[:, r * w : r * w + w].contiguous().view(b, 1, w)
534 |             # interpolation gate (b * c)
535 |             interpolation_gate = torch.sigmoid(ξ[:, r * w + w : r * w + w + c]).contiguous().view(b, c)
536 | # write gate (b * 1)
537 | write_gate = torch.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
538 |
539 | self.timestep += 1
540 | hidden = self.write(interpolation_gate, write_vector, write_gate, hidden)
541 | return self.read(read_query, hidden)
542 |
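543 | # --- Usage sketch (illustrative only, not part of the library API) ---
544 | # A minimal smoke test of the write/read cycle implemented above: reset()
545 | # builds fresh memory plus the FAISS indexes (so FAISS must be installed),
546 | # and each forward() call writes, then performs sparse and temporal reads.
547 | # The sizes below are arbitrary.
548 | if __name__ == "__main__":
549 |     mem = SparseTemporalMemory(input_size=64, mem_size=512, cell_size=32, read_heads=4, sparse_reads=4, temporal_reads=4)
550 |     hidden = mem.reset(batch_size=2)  # fresh memory, usage stats, and indexes
551 |     ξ = torch.randn(2, 64)  # controller interface vector
552 |     read_vectors, hidden = mem(ξ, hidden)
553 |     # read_vectors: (2, read_heads, cell_size); thread `hidden` through timesteps.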
--------------------------------------------------------------------------------