├── .gitignore ├── LICENSE ├── README.md ├── evaluate.py ├── images ├── associative.png ├── associative_1.png ├── associative_2.png ├── associative_loss.png ├── copy.png ├── copy_1.png ├── copy_loss.png ├── priority_sort_1.png ├── prioritysort.png ├── prioritysort_loss.png ├── repeat_copy.png ├── repeat_copy_1.png ├── repeat_copy_loss.png ├── repeat_copy_rep.png ├── repeat_copy_rep_1.png ├── repeat_copy_rep_2.png ├── repeat_copy_seq_len.png └── repeat_copy_seq_len_1.png ├── ntm ├── __init__.py ├── args.py ├── datasets │ ├── __init__.py │ ├── associative.py │ ├── copy.py │ ├── ngram.py │ ├── prioritysort.py │ └── repeatcopy.py ├── modules │ ├── __init__.py │ ├── controller.py │ ├── head.py │ └── memory.py ├── ntm.py └── tasks │ ├── __init__.py │ ├── associative.json │ ├── copy.json │ ├── ngram.json │ ├── prioritysort.json │ └── repeatcopy.json ├── saved_models ├── saved_model_associative_100000.pt ├── saved_model_copy_500000.pt ├── saved_model_prioritysort_100000.pt └── saved_model_repeatcopy_100000.pt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | .vector_cache/ 12 | env/ 13 | build/ 14 | data/ 15 | logs/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Translations 41 | *.mo 42 | *.pot 43 | 44 | # PyBuilder 45 | target/ 46 | 47 | # IPython Notebook 48 | .ipynb_checkpoints 49 | 50 | # pyenv 51 | .python-version 52 | 53 | # celery beat schedule file 54 | celerybeat-schedule 55 | 56 | # dotenv 57 | .env 58 | 59 | # virtualenv 60 | venv/ 61 | ENV/ 62 | 63 | # Spyder project settings 64 | .spyderproject 65 | 66 | # Rope project settings 67 | .ropeproject 68 | 69 | # OS X files 70 | .DS_Store 71 | 72 | .idea 73 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright © 2018 Karan Desai 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the “Software”), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Neural Turing Machines (Pytorch) 2 | ================================= 3 | Code for the paper 4 | 5 | **[Neural Turing Machines][1]** 6 | 7 | Alex Graves, Greg Wayne, Ivo Danihelka 8 | 9 | [1]: https://arxiv.org/abs/1410.5401 10 | Neural Turing Machines (NTMs) contain a recurrent network coupled with an external memory resource, which it can interact with by attentional processes. Therefore NTMs can be called Memory Augmented Neural Networks. They are end-to-end differentiable and thus are hypothesised at being able to learn simple algorithms. They outperform LSTMs in learning several algorithmic tasks due to the presence of external memory without an increase in parameters and computation. 11 | 12 | This repository is a stable Pytorch implementation of a Neural Turing Machine and contains the code for training, evaluating and visualizing results for the Copy, Repeat Copy, Associative Recall and Priority Sort tasks. The code has been tested for all 4 tasks and the results obtained are in accordance with the results mentioned in the paper. The training and evaluation code for N-Gram task has been provided however the results would be uploaded after testing. 13 | 14 |
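For orientation, the snippet below is a minimal usage sketch of this implementation (it simply mirrors what ```train.py``` and ```evaluate.py``` do; the task file and checkpoint path are the ones shipped in this repository, and a freshly constructed model is untrained unless a checkpoint is loaded):

```
import json, torch
from ntm import NTM
from ntm.datasets import CopyDataset

task_params = json.load(open('ntm/tasks/copy.json'))
ntm = NTM(input_size=task_params['seq_width'] + 2,
          output_size=task_params['seq_width'],
          controller_size=task_params['controller_size'],
          memory_units=task_params['memory_units'],
          memory_unit_size=task_params['memory_unit_size'],
          num_heads=task_params['num_heads'])
# optionally load one of the provided checkpoints:
# ntm.load_state_dict(torch.load('saved_models/saved_model_copy_500000.pt'))

sample = CopyDataset(task_params)[0]            # one random copy-task example
ntm.reset()
for step in sample['input']:                    # feed the input sequence step by step
    ntm(step.unsqueeze(0))
zeros = torch.zeros(1, sample['input'].size(1))
out = torch.stack([ntm(zeros)                   # then read the answer out
                   for _ in range(sample['target'].size(0))])
```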

15 | 16 | 17 | 18 | *Neural Turing Machine architecture* 19 |

20 | 21 | Setup 22 | ================================= 23 | Our code is implemented in Pytorch 0.4.0 and Python >= 3.5. To set up, proceed as follows: 24 | 25 | To install Pytorch, head over to ```https://pytorch.org/``` or install it with the miniconda or anaconda package by running 26 | ```conda install -c soumith pytorch```. 27 | 28 | Clone this repository: 29 | 30 | ``` 31 | git clone https://www.github.com/kdexd/ntm-pytorch 32 | ``` 33 | 34 | The other Python libraries you'll need to run the code: 35 | ``` 36 | pip install numpy 37 | pip install tensorboard_logger 38 | pip install matplotlib 39 | pip install tqdm 40 | pip install Pillow 41 | ``` 42 | 43 | Training 44 | ================================ 45 | Training with default arguments is started by: 46 | ``` 47 | python train.py 48 | ``` 49 | The script runs with all arguments set to their default values. If you wish to change any of them, run the script with ```-h``` to see the available arguments and set them as needed. 50 | ``` 51 | usage : train.py [-h] [-task_json TASK_JSON] [-batch_size BATCH_SIZE] 52 | [-num_iters NUM_ITERS] [-lr LR] [-momentum MOMENTUM] 53 | [-alpha ALPHA] [-saved_model SAVED_MODEL] [-beta1 BETA1] [-beta2 BETA2] 54 | ``` 55 | Both RMSprop and Adam optimizers have been provided. ```-momentum``` and ```-alpha``` are parameters for RMSprop, while ```-beta1``` and ```-beta2``` are parameters for Adam. All these arguments are initialized to their default values. 56 | 57 | The smoothing factor for all curves is ```0.6```. 58 | - Training for the copy task is carried out with sequence lengths ranging from 1-20. The curve of bits-per-sequence error vs. iterations for this task is shown below: 59 | ![Alt text](https://github.com/vlgiitr/ntm-pytorch/blob/master/images/copy_loss.png) 60 | 61 | - Training for the repeat copy task is carried out with sequence lengths ranging from 1-10 and repeat numbers in the range 1-10. The curve of bits-per-sequence error vs. iterations for this task is shown below: 62 | ![Alt text](https://github.com/vlgiitr/ntm-pytorch/blob/master/images/repeat_copy_loss.png) 63 | 64 | - Training for the associative recall task is carried out with the number of items ranging from 2-6. The curve of bits-per-sequence error vs. iterations for this task is shown below: 65 | ![Alt text](https://github.com/vlgiitr/ntm-pytorch/blob/master/images/associative_loss.png) 66 | 67 | - Training for the priority sort task is carried out with an input sequence length of 20 and a target sequence length of 16. The curve of bits-per-sequence error vs. iterations for this task is shown below: 68 | ![Alt text](https://github.com/vlgiitr/ntm-pytorch/blob/master/images/prioritysort_loss.png) 69 | 70 | 71 | Evaluation 72 | =============================== 73 | The models were trained and evaluated as described in the paper, and the results are in accordance with it. Saved models for all the tasks are available in the ```saved_models``` folder. The model for the copy task has been trained up to 500k iterations, and those for repeat copy, associative recall and priority sort have been trained up to 100k iterations. The code for saving and loading the model has been incorporated in ```train.py``` and ```evaluate.py``` respectively. 74 | 75 | The evaluation parameters for all tasks have been included in ```evaluate.py```. 76 | 77 | Evaluation can be done as follows: 78 | ``` 79 | python evaluate.py 80 | ``` 81 | - Results for the copy task show that the NTM generalizes well to sequence lengths up to 120. 
The target and output for copy task is shown below : 82 | 83 | ![Alt text](https://github.com/vlgiitr/ntm-pytorch/blob/master/images/copy_1.png) 84 | 85 | - Results for the repeat copy task shows that the NTM generalizes well for maximum sequence length of 20 and repeat number upto 20. The target and output for repeat copy task is shown below : 86 | 87 | ![Alt text](https://github.com/vlgiitr/ntm-pytorch/blob/master/images/repeat_copy_seq_len_1.png) 88 | 89 | - Results for associative recall task shows that the NTM generalizes well for number of items upto 20. The target and output for associative recall task is shown below : 90 | 91 | ![Alt text](https://github.com/vlgiitr/ntm-pytorch/blob/master/images/associative_2.png) 92 | 93 | - Results for the priority sort task also show the better generalization capability of the NTM. The target and output for priority sort task is shown below : 94 | 95 | ![Alt text](https://github.com/vlgiitr/ntm-pytorch/blob/master/images/priority_sort_1.png) 96 | 97 | 98 | Visualization 99 | =============================== 100 | We have integrated Tensorboard_logger to visualize training and evaluation loss and bits per sequence error. To install tensorboard logger use : 101 | ``` 102 | pip install tensorboard_logger 103 | ``` 104 | 105 | Sample outputs and bits per sequence error curves have been provided in the ```images``` folder. 106 | 107 | Acknowledgements 108 | =============================== 109 | - We have used the following codebase as a reference for our implementation : **[loudinthecloud/pytorch-ntm][2]** 110 | 111 | [2]:https://github.com/loudinthecloud/pytorch-ntm 112 | 113 | LICENSE 114 | =============================== 115 | MIT 116 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import matplotlib.pyplot as plt 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from ntm import NTM 9 | from ntm.datasets import CopyDataset, RepeatCopyDataset, AssociativeDataset, NGram, PrioritySort 10 | from ntm.args import get_parser 11 | 12 | args = get_parser().parse_args() 13 | 14 | args.task_json = 'ntm/tasks/copy.json' 15 | ''' 16 | args.task_json = 'ntm/tasks/repeatcopy.json' 17 | args.task_json = 'ntm/tasks/associative.json' 18 | args.task_json = 'ntm/tasks/ngram.json' 19 | args.task_json = 'ntm/tasks/prioritysort.json' 20 | ''' 21 | 22 | task_params = json.load(open(args.task_json)) 23 | criterion = nn.BCELoss() 24 | 25 | # ---Evaluation parameters for Copy task--- 26 | task_params['min_seq_len'] = 20 27 | task_params['max_seq_len'] = 120 28 | 29 | ''' 30 | # ---Evaluation parameters for RepeatCopy task--- 31 | # (Sequence length generalisation) 32 | task_params['min_seq_len'] = 10 33 | task_params['max_seq_len'] = 20 34 | # (Number of repetition generalisation) 35 | task_params['min_repeat'] = 10 36 | task_params['max_repeat'] = 20 37 | 38 | # ---Evaluation parameters for AssociativeRecall task--- 39 | task_params['min_item'] = 6 40 | task_params['max_item'] = 20 41 | 42 | # For NGram and Priority sort task parameters need not be changed. 
43 | ''' 44 | 45 | dataset = CopyDataset(task_params) 46 | ''' 47 | dataset = RepeatCopyDataset(task_params) 48 | dataset = AssociativeDataset(task_params) 49 | dataset = NGram(task_params) 50 | dataset = PrioritySort(task_params) 51 | ''' 52 | 53 | args.saved_model = 'saved_model_copy.pt' 54 | ''' 55 | args.saved_model = 'saved_model_repeatcopy.pt' 56 | args.saved_model = 'saved_model_associative.pt' 57 | args.saved_model = 'saved_model_ngram.pt' 58 | args.saved_model = 'saved_model_prioritysort.pt' 59 | ''' 60 | 61 | cur_dir = os.getcwd() 62 | PATH = os.path.join(cur_dir, args.saved_model) 63 | # PATH = os.path.join(cur_dir, 'saved_models/saved_model_copy_500000.pt') 64 | # ntm = torch.load(PATH) 65 | 66 | """ 67 | For the Copy task, input_size: seq_width + 2, output_size: seq_width 68 | For the RepeatCopy task, input_size: seq_width + 2, output_size: seq_width + 1 69 | For the Associative task, input_size: seq_width + 2, output_size: seq_width 70 | For the NGram task, input_size: 1, output_size: 1 71 | For the Priority Sort task, input_size: seq_width + 1, output_size: seq_width 72 | """ 73 | 74 | ntm = NTM(input_size=task_params['seq_width'] + 2, 75 | output_size=task_params['seq_width'], 76 | controller_size=task_params['controller_size'], 77 | memory_units=task_params['memory_units'], 78 | memory_unit_size=task_params['memory_unit_size'], 79 | num_heads=task_params['num_heads']) 80 | 81 | ntm.load_state_dict(torch.load(PATH)) 82 | 83 | # ----------------------------------------------------------------------------- 84 | # --- evaluation 85 | # ----------------------------------------------------------------------------- 86 | ntm.reset() 87 | data = dataset[0] # 0 is a dummy index 88 | input, target = data['input'], data['target'] 89 | out = torch.zeros(target.size()) 90 | 91 | # ----------------------------------------------------------------------------- 92 | # loop for other tasks 93 | # ----------------------------------------------------------------------------- 94 | for i in range(input.size()[0]): 95 | # to maintain consistency in dimensions as torch.cat was throwing error 96 | in_data = torch.unsqueeze(input[i], 0) 97 | ntm(in_data) 98 | 99 | # passing zero vector as the input while generating target sequence 100 | in_data = torch.unsqueeze(torch.zeros(input.size()[1]), 0) 101 | for i in range(target.size()[0]): 102 | out[i] = ntm(in_data) 103 | # ----------------------------------------------------------------------------- 104 | # ----------------------------------------------------------------------------- 105 | # loop for NGram task 106 | # ----------------------------------------------------------------------------- 107 | ''' 108 | for i in range(task_params['seq_len'] - 1): 109 | in_data = input[i].view(1, -1) 110 | ntm(in_data) 111 | target_data = torch.zeros([1]).view(1, -1) 112 | out[i] = ntm(target_data) 113 | ''' 114 | # ----------------------------------------------------------------------------- 115 | 116 | loss = criterion(out, target) 117 | 118 | binary_output = out.clone() 119 | binary_output = binary_output.detach().apply_(lambda x: 0 if x < 0.5 else 1) 120 | 121 | # sequence prediction error is calculted in bits per sequence 122 | error = torch.sum(torch.abs(binary_output - target)) 123 | 124 | # ---logging--- 125 | print('Loss: %.2f\tError in bits per sequence: %.2f' % (loss, error)) 126 | 127 | # ---saving results--- 128 | result = {'output': binary_output, 'target': target} 129 | 
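130 | # ---visualising the result (an optional sketch, not part of the original script)--- 
131 | # It assumes the `binary_output` and `target` tensors computed above and uses 
132 | # the matplotlib import at the top of this file; the output filename is illustrative. 
133 | fig, axes = plt.subplots(2, 1, figsize=(8, 4)) 
134 | axes[0].imshow(target.numpy().T, cmap='binary', aspect='auto') 
135 | axes[0].set_title('Target') 
136 | axes[1].imshow(binary_output.numpy().T, cmap='binary', aspect='auto') 
137 | axes[1].set_title('Binarised output') 
138 | plt.tight_layout() 
139 | plt.savefig('evaluation_result.png') 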
-------------------------------------------------------------------------------- /images/associative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/associative.png -------------------------------------------------------------------------------- /images/associative_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/associative_1.png -------------------------------------------------------------------------------- /images/associative_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/associative_2.png -------------------------------------------------------------------------------- /images/associative_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/associative_loss.png -------------------------------------------------------------------------------- /images/copy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/copy.png -------------------------------------------------------------------------------- /images/copy_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/copy_1.png -------------------------------------------------------------------------------- /images/copy_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/copy_loss.png -------------------------------------------------------------------------------- /images/priority_sort_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/priority_sort_1.png -------------------------------------------------------------------------------- /images/prioritysort.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/prioritysort.png -------------------------------------------------------------------------------- /images/prioritysort_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/prioritysort_loss.png -------------------------------------------------------------------------------- /images/repeat_copy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/repeat_copy.png -------------------------------------------------------------------------------- /images/repeat_copy_1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/repeat_copy_1.png -------------------------------------------------------------------------------- /images/repeat_copy_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/repeat_copy_loss.png -------------------------------------------------------------------------------- /images/repeat_copy_rep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/repeat_copy_rep.png -------------------------------------------------------------------------------- /images/repeat_copy_rep_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/repeat_copy_rep_1.png -------------------------------------------------------------------------------- /images/repeat_copy_rep_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/repeat_copy_rep_2.png -------------------------------------------------------------------------------- /images/repeat_copy_seq_len.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/repeat_copy_seq_len.png -------------------------------------------------------------------------------- /images/repeat_copy_seq_len_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/images/repeat_copy_seq_len_1.png -------------------------------------------------------------------------------- /ntm/__init__.py: -------------------------------------------------------------------------------- 1 | from .ntm import NTM 2 | -------------------------------------------------------------------------------- /ntm/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_parser(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('-task_json', type=str, default='ntm/tasks/copy.json', 7 | help='path to json file with task specific parameters') 8 | parser.add_argument('-saved_model', default='saved_model_copy.pt', 9 | help='path to file with final model parameters') 10 | parser.add_argument('-batch_size', type=int, default=1, 11 | help='batch size of input sequence during training') 12 | parser.add_argument('-num_iters', type=int, default=100000, 13 | help='number of iterations for training') 14 | 15 | # todo: only rmsprop optimizer supported yet, support adam too 16 | parser.add_argument('-lr', type=float, default=1e-4, 17 | help='learning rate for rmsprop optimizer') 18 | parser.add_argument('-momentum', type=float, default=0.9, 19 | help='momentum for rmsprop optimizer') 20 | parser.add_argument('-alpha', type=float, default=0.95, 21 | help='alpha for rmsprop optimizer') 22 | parser.add_argument('-beta1', type=float, default=0.9, 23 | help='beta1 constant 
for adam optimizer') 24 | parser.add_argument('-beta2', type=float, default=0.999, 25 | help='beta2 constant for adam optimizer') 26 | return parser 27 | -------------------------------------------------------------------------------- /ntm/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .copy import CopyDataset 2 | from .repeatcopy import RepeatCopyDataset 3 | from .associative import AssociativeDataset 4 | from .ngram import NGram 5 | from .prioritysort import PrioritySort -------------------------------------------------------------------------------- /ntm/datasets/associative.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.distributions.binomial import Binomial 4 | 5 | 6 | class AssociativeDataset(Dataset): 7 | """A Dataset class to generate random examples for associative recall task. 8 | 9 | Each input consists of a list of items with the number of itmes between 10 | `min_item` and `max_item`. An item is a sequence of binary vectors bounded 11 | on left and right by delimiter symbols. The list is followed by query item 12 | selected randomly from input items. It too is bounded by query delimiters. 13 | 14 | Target returns the item next to the query item. 15 | """ 16 | 17 | def __init__(self, task_params): 18 | """Initialize a dataset instance for Associative Recall task. 19 | 20 | Arguments 21 | --------- 22 | task_params : dict 23 | A dict containing parameters relevant to associative recall task. 24 | """ 25 | self.seq_width = task_params["seq_width"] 26 | self.seq_len = task_params["seq_len"] 27 | self.min_item = task_params["min_item"] 28 | self.max_item = task_params["max_item"] 29 | 30 | def __len__(self): 31 | # sequences are generated randomly so this does not matter 32 | # set a sufficiently large size for data loader to sample mini-batches 33 | return 65536 34 | 35 | def __getitem__(self, idx): 36 | # idx only acts as a counter while generating batches. 37 | num_item = torch.randint( 38 | self.min_item, self.max_item, (1,), dtype=torch.long).item() 39 | prob = 0.5 * \ 40 | torch.ones([self.seq_len, self.seq_width], dtype=torch.float64) 41 | seq = Binomial(1, prob) 42 | 43 | # fill in input two bit wider than target to account for delimiter 44 | # flags. 
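        # (illustrative note documenting the layout built below, not from the paper:
        #  each item occupies seq_len + 1 rows, i.e. one delimiter row with channel
        #  seq_width set to 1 followed by seq_len rows of binary data; the query
        #  item at the end is wrapped between two rows with channel seq_width + 1
        #  set to 1, giving (seq_len + 1) * (num_item + 1) + 1 rows in total.)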
45 | input_items = torch.zeros( 46 | [(self.seq_len + 1) * (num_item + 1) + 1, self.seq_width + 2]) 47 | for i in range(num_item): 48 | input_items[(self.seq_len + 1) * i, self.seq_width] = 1.0 49 | input_items[(self.seq_len + 1) * i + 1:(self.seq_len + 1) 50 | * (i + 1), :self.seq_width] = seq.sample() 51 | 52 | # generate query item randomly 53 | # in case of only one item, torch.randint throws error as num_item-1=0 54 | query_item = 0 55 | if num_item != 1: 56 | query_item = torch.randint( 57 | 0, num_item - 1, (1,), dtype=torch.long).item() 58 | query_seq = input_items[(self.seq_len + 1) * query_item + 59 | 1:(self.seq_len + 1) * (query_item + 1), :self.seq_width] 60 | input_items[(self.seq_len + 1) * num_item, 61 | self.seq_width + 1] = 1.0 # query delimiter 62 | input_items[(self.seq_len + 1) * num_item + 1:(self.seq_len + 1) 63 | * (num_item + 1), :self.seq_width] = query_seq 64 | input_items[(self.seq_len + 1) * (num_item + 1), 65 | self.seq_width + 1] = 1.0 # query delimiter 66 | 67 | # generate target sequences(item next to query in the input list) 68 | target_item = torch.zeros([self.seq_len, self.seq_width]) 69 | # in case of last item, target sequence is zero 70 | if query_item != num_item - 1: 71 | target_item[:self.seq_len, :self.seq_width] = input_items[ 72 | (self.seq_len + 1) * (query_item + 1) + 1:(self.seq_len + 1) * (query_item + 2), :self.seq_width] 73 | 74 | return {'input': input_items, 'target': target_item} 75 | -------------------------------------------------------------------------------- /ntm/datasets/copy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.distributions.binomial import Binomial 4 | 5 | 6 | class CopyDataset(Dataset): 7 | """A Dataset class to generate random examples for the copy task. Each 8 | sequence has a random length between `min_seq_len` and `max_seq_len`. 9 | Each vector in the sequence has a fixed length of `seq_width`. The vectors 10 | are bounded by start and end delimiter flags. 11 | 12 | To account for the delimiter flags, the input sequence length as well 13 | width is two more than the target sequence. 14 | """ 15 | 16 | def __init__(self, task_params): 17 | """Initialize a dataset instance for copy task. 18 | 19 | Arguments 20 | --------- 21 | task_params : dict 22 | A dict containing parameters relevant to copy task. 23 | """ 24 | self.seq_width = task_params['seq_width'] 25 | self.min_seq_len = task_params['min_seq_len'] 26 | self.max_seq_len = task_params['max_seq_len'] 27 | 28 | def __len__(self): 29 | # sequences are generated randomly so this does not matter 30 | # set a sufficiently large size for data loader to sample mini-batches 31 | return 65536 32 | 33 | def __getitem__(self, idx): 34 | # idx only acts as a counter while generating batches. 
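        # (illustrative note documenting the example built below, not from the paper:
        #  input row 0 carries the start delimiter on channel seq_width, rows 1 to
        #  seq_len carry the random binary sequence, and row seq_len + 1 carries
        #  the end delimiter on channel seq_width + 1; the target is just the bare
        #  seq_len x seq_width binary sequence.)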
35 | seq_len = torch.randint( 36 | self.min_seq_len, self.max_seq_len, (1,), dtype=torch.long).item() 37 | prob = 0.5 * torch.ones([seq_len, self.seq_width], dtype=torch.float64) 38 | seq = Binomial(1, prob).sample() 39 | 40 | # fill in input sequence, two bit longer and wider than target 41 | input_seq = torch.zeros([seq_len + 2, self.seq_width + 2]) 42 | input_seq[0, self.seq_width] = 1.0 # start delimiter 43 | input_seq[1:seq_len + 1, :self.seq_width] = seq 44 | input_seq[seq_len + 1, self.seq_width + 1] = 1.0 # end delimiter 45 | 46 | target_seq = torch.zeros([seq_len, self.seq_width]) 47 | target_seq[:seq_len, :self.seq_width] = seq 48 | return {'input': input_seq, 'target': target_seq} 49 | -------------------------------------------------------------------------------- /ntm/datasets/ngram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.distributions.beta import Beta 4 | from torch.distributions.bernoulli import Bernoulli 5 | 6 | 7 | class NGram(Dataset): 8 | """A Dataset class to generate random examples for the N-gram task. 9 | 10 | Each sequence is generated using a lookup table for n-Gram distribution 11 | probabilities. The lookup table contains 2**(n-1) numbers specifying the 12 | probability that the next bit will be one. The numbers represent all 13 | possible (n-1) length binary histories. The probabilities are independently 14 | drawn from Beta(0.5,0.5) distribution. 15 | 16 | The first 5 bits, for which insuffient context exists to sample from the 17 | table, are drawn i.i.d. from a Bernoulli distribution with p=0.5. The 18 | subsequent bits are drawn using probabilities from the table. 19 | """ 20 | 21 | def __init__(self, task_params): 22 | """ Initialize a dataset instance for N-gram task. 23 | 24 | Arguments 25 | --------- 26 | task_params : dict 27 | A dict containing parameters relevant to N-gram task. 28 | """ 29 | self.seq_len = task_params["seq_len"] 30 | self.n = task_params["N"] 31 | 32 | def __len__(self): 33 | # sequences are generated randomly so this does not matter 34 | # set a sufficiently large size for data loader to sample mini-batches 35 | return 65536 36 | 37 | def __getitem__(self, idx): 38 | # idx only acts as a counter while generating batches. 39 | beta_prob = Beta(torch.tensor([0.5]), torch.tensor([0.5])) 40 | lookup_table = {} 41 | 42 | # generate probabilities for the lookup table. The key represents the 43 | # possible binary sequences. 44 | for i in range(2**(self.n - 1)): 45 | lookup_table[bin(i)[2:].rjust(self.n - 1, '0')] = beta_prob.sample() 46 | 47 | # generate input sequence 48 | input_seq = torch.zeros([self.seq_len]) 49 | prob = Bernoulli(torch.tensor([0.5])) 50 | for i in range(self.n): 51 | input_seq[i] = prob.sample() 52 | for i in range(self.n - 1, self.seq_len): 53 | prev = input_seq[i - self.n + 1:i] 54 | prev = ''.join(map(str, map(int, prev))) 55 | prob = lookup_table[prev] 56 | input_seq[i] = Bernoulli(prob).sample() 57 | 58 | # As the task is to predict the next bit, the target sequence is a bit 59 | # shorter than the input. 
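        # (illustrative: for input bits b0 b1 ... b_{L-1}, the target built below is
        #  b1 ... b_{L-1}, i.e. the input shifted left by one position.)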
60 | target_seq = input_seq[1:self.seq_len] 61 | 62 | return {'input': input_seq, 'target': target_seq} 63 | -------------------------------------------------------------------------------- /ntm/datasets/prioritysort.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.distributions.uniform import Uniform 4 | from torch.distributions.binomial import Binomial 5 | 6 | 7 | class PrioritySort(Dataset): 8 | """A Dataset class to generate random examples for priority sort task. 9 | 10 | In the input sequence, each vector is generated randomly along with a 11 | scalar priority rating. The priority is drawn uniformly from the range 12 | [-1,1) and is provided on a separate input channel. 13 | 14 | The target contains the binary vectors sorted according to their priorities 15 | """ 16 | 17 | def __init__(self, task_params): 18 | """ Initialize a dataset instance for the priority sort task. 19 | 20 | Arguments 21 | --------- 22 | task_params : dict 23 | A dict containing parameters relevant to priority sort task. 24 | """ 25 | self.seq_width = task_params["seq_width"] 26 | self.input_seq_len = task_params["input_seq_len"] 27 | self.target_seq_len = task_params["target_seq_len"] 28 | 29 | def __len__(self): 30 | # sequences are generated randomly so this does not matter 31 | # set a sufficiently large size for data loader to sample mini-batches 32 | return 65536 33 | 34 | def __getitem__(self, idx): 35 | # idx only acts as a counter while generating batches. 36 | prob = 0.5 * torch.ones([self.input_seq_len, 37 | self.seq_width], dtype=torch.float64) 38 | seq = Binomial(1, prob).sample() 39 | # Extra input channel for providing priority value 40 | input_seq = torch.zeros([self.input_seq_len, self.seq_width + 1]) 41 | input_seq[:self.input_seq_len, :self.seq_width] = seq 42 | 43 | # torch's Uniform function draws samples from the half-open interval 44 | # [low, high) but in the paper the priorities are drawn from [-1,1]. 45 | # This minor difference is being ignored here as supposedly it doesn't 46 | # affects the task. 47 | priority = Uniform(torch.tensor([-1.0]), torch.tensor([1.0])) 48 | for i in range(self.input_seq_len): 49 | input_seq[i, self.seq_width] = priority.sample() 50 | 51 | sorted, _ = torch.sort(input_seq, 0, descending=True) 52 | target_seq = sorted[:self.target_seq_len, :self.seq_width] 53 | 54 | return {'input': input_seq, 'target': target_seq} 55 | -------------------------------------------------------------------------------- /ntm/datasets/repeatcopy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from torch.distributions.binomial import Binomial 5 | 6 | 7 | class RepeatCopyDataset(Dataset): 8 | """A Dataset class to generate random examples for the repeat copy task. 9 | Each sequence has a random length between `min_seq_len` and `max_seq_len`. 10 | Each vector in the sequence has a fixed length of `seq_width`. The input 11 | sequence is prefixed by a start delimiter. 12 | 13 | Along with the delimiter flag, the input sequence also contains a channel 14 | for number of repetitions. The input representing the repeat number is 15 | normalised to have mean zero and variance one. 16 | 17 | For the target sequence, each sequence is repeated given number of times 18 | followed by a delimiter flag marking end of the target sequence. 
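    For example (an illustrative note, not from the paper): with seq_len = 2 and
    rep = 3, the target holds the two input vectors repeated three times (six rows)
    followed by one row whose extra channel marks the end of the sequence.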
19 | """ 20 | 21 | def __init__(self, task_params): 22 | """Initialize a dataset instance for repeat copy task. 23 | 24 | Arguments 25 | --------- 26 | task_params : dict 27 | A dict containing parameters relevant to repeat copy task. 28 | """ 29 | self.seq_width = task_params["seq_width"] 30 | self.min_seq_len = task_params["min_seq_len"] 31 | self.max_seq_len = task_params["max_seq_len"] 32 | self.min_repeat = task_params["min_repeat"] 33 | self.max_repeat = task_params["max_repeat"] 34 | 35 | def normalise(self, rep): 36 | rep_mean = (self.max_repeat - self.min_repeat) / 2 37 | rep_var = (((self.max_repeat - self.min_repeat + 1) ** 2) - 1) / 12 38 | rep_std = np.sqrt(rep_var) 39 | return (rep - rep_mean) / rep_std 40 | 41 | def __len__(self): 42 | # sequences are generated randomly so this does not matter 43 | # set a sufficiently large size for data loader to sample mini-batches 44 | return 65536 45 | 46 | def __getitem__(self, idx): 47 | # idx only acts as a counter while generating batches. 48 | seq_len = torch.randint( 49 | self.min_seq_len, self.max_seq_len, (1,), dtype=torch.long).item() 50 | rep = torch.randint( 51 | self.min_repeat, self.max_repeat, (1,), dtype=torch.long).item() 52 | prob = 0.5 * torch.ones([seq_len, self.seq_width], dtype=torch.float64) 53 | seq = Binomial(1, prob).sample() 54 | 55 | # fill in input sequence, two bit longer and wider than target 56 | input_seq = torch.zeros([seq_len + 2, self.seq_width + 2]) 57 | input_seq[0, self.seq_width] = 1.0 # delimiter 58 | input_seq[1:seq_len + 1, :self.seq_width] = seq 59 | input_seq[seq_len + 1, self.seq_width + 1] = self.normalise(rep) 60 | 61 | target_seq = torch.zeros( 62 | [seq_len * rep + 1, self.seq_width + 1]) 63 | target_seq[:seq_len * rep, :self.seq_width] = seq.repeat(rep, 1) 64 | target_seq[seq_len * rep, self.seq_width] = 1.0 # delimiter 65 | 66 | return {'input': input_seq, 'target': target_seq} 67 | -------------------------------------------------------------------------------- /ntm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .controller import NTMController 2 | from .head import NTMHead 3 | from .memory import NTMMemory 4 | -------------------------------------------------------------------------------- /ntm/modules/controller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class NTMController(nn.Module): 7 | 8 | def __init__(self, input_size, controller_size, output_size, read_data_size): 9 | super().__init__() 10 | self.input_size = input_size 11 | self.controller_size = controller_size 12 | self.output_size = output_size 13 | self.read_data_size = read_data_size 14 | 15 | self.controller_net = nn.LSTMCell(input_size, controller_size) 16 | self.out_net = nn.Linear(read_data_size, output_size) 17 | # nn.init.xavier_uniform_(self.out_net.weight) 18 | nn.init.kaiming_uniform_(self.out_net.weight) 19 | self.h_state = torch.zeros([1, controller_size]) 20 | self.c_state = torch.zeros([1, controller_size]) 21 | # layers to learn bias values for controller state reset 22 | self.h_bias_fc = nn.Linear(1, controller_size) 23 | # nn.init.kaiming_uniform_(self.h_bias_fc.weight) 24 | self.c_bias_fc = nn.Linear(1, controller_size) 25 | # nn.init.kaiming_uniform_(self.c_bias_fc.weight) 26 | self.reset() 27 | 28 | def forward(self, in_data, prev_reads): 29 | x = torch.cat([in_data] + prev_reads, dim=-1) 30 | self.h_state, 
self.c_state = self.controller_net( 31 | x, (self.h_state, self.c_state)) 32 | return self.h_state, self.c_state 33 | 34 | def output(self, read_data): 35 | complete_state = torch.cat([self.h_state] + read_data, dim=-1) 36 | output = F.sigmoid(self.out_net(complete_state)) 37 | return output 38 | 39 | def reset(self, batch_size=1): 40 | in_data = torch.tensor([[0.]]) # dummy input 41 | h_bias = self.h_bias_fc(in_data) 42 | self.h_state = h_bias.repeat(batch_size, 1) 43 | c_bias = self.c_bias_fc(in_data) 44 | self.c_state = c_bias.repeat(batch_size, 1) 45 | -------------------------------------------------------------------------------- /ntm/modules/head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class NTMHead(nn.Module): 7 | 8 | def __init__(self, mode, controller_size, key_size): 9 | super().__init__() 10 | self.mode = mode 11 | self.key_size = key_size 12 | 13 | # all the fc layers to produce scalars for memory addressing 14 | self.key_fc = nn.Linear(controller_size, key_size) 15 | self.key_strength_fc = nn.Linear(controller_size, 1) 16 | 17 | # these five fc layers cannot be in controller class 18 | # since each head has its own parameters and scalars 19 | self.interpolation_gate_fc = nn.Linear(controller_size, 1) 20 | self.shift_weighting_fc = nn.Linear(controller_size, 3) 21 | self.sharpen_factor_fc = nn.Linear(controller_size, 1) 22 | # --(optional : for separation of add and erase mechanism) 23 | # self.erase_weight_fc = nn.Linear(controller_size, key_size) 24 | 25 | # fc layer to produce write data. data vector length=key_size 26 | self.write_data_fc = nn.Linear(controller_size, key_size) 27 | self.reset() 28 | 29 | def forward(self, controller_state, prev_weights, memory, data=None): 30 | """Accept previous state (weights and memory) and controller state, 31 | produce attention weights for current read or write operation. 32 | Weights are produced by content-based and location-based addressing. 33 | 34 | Refer *Figure 2* in the paper to see how weights are produced. 35 | 36 | The head returns current weights useful for next time step, while 37 | it reads from or writes to ``memory`` based on its mode, using the 38 | ``data`` vector. ``data`` is filled and returned for read mode, 39 | returned as is for write mode. 40 | 41 | Refer *Section 3.1* for read mode and *Section 3.2* for write mode. 42 | 43 | Parameters 44 | ---------- 45 | controller_state : torch.Tensor 46 | Long-term state of the controller. 47 | ``(batch_size, controller_size)`` 48 | 49 | prev_weights : torch.Tensor 50 | Attention weights from previous time step. 51 | ``(batch_size, memory_units)`` 52 | 53 | memory : ntm_modules.NTMMemory 54 | Memory Instance. Read write operations will be performed in place. 55 | 56 | data : torch.Tensor 57 | Depending upon the mode, this data vector will be used by memory. 58 | ``(batch_size, memory_unit_size)`` 59 | 60 | Returns 61 | ------- 62 | current_weights, data : torch.Tensor, torch.Tensor 63 | Current weights and data (filled in read operation else as it is). 
64 | ``(batch_size, memory_units), (batch_size, memory_unit_size)`` 65 | """ 66 | 67 | # all these are marked as "controller outputs" in Figure 2 68 | key = self.key_fc(controller_state) 69 | b = F.softplus(self.key_strength_fc(controller_state)) 70 | g = F.sigmoid(self.interpolation_gate_fc(controller_state)) 71 | s = F.softmax(self.shift_weighting_fc(controller_state)) 72 | # here the sharpening factor is less than 1 whereas as required in the 73 | # paper it should be greater than 1. hence adding 1. 74 | y = 1 + F.softplus(self.sharpen_factor_fc(controller_state)) 75 | # e = F.sigmoid(self.erase_weight_fc(controller_state)) # erase vector 76 | a = self.write_data_fc(controller_state) # add vector 77 | 78 | content_weights = memory.content_addressing(key, b) 79 | # location-based addressing - interpolate, shift, sharpen 80 | interpolated_weights = g * content_weights + (1 - g) * prev_weights 81 | shifted_weights = self._circular_conv1d(interpolated_weights, s) 82 | # the softmax introduces the exp of the argument which isn't there in 83 | # the paper. there it's just a simple normalization of the arguments. 84 | current_weights = shifted_weights ** y 85 | # current_weights = F.softmax(shifted_weights ** y) 86 | current_weights = torch.div(current_weights, torch.sum( 87 | current_weights, dim=1).view(-1, 1) + 1e-16) 88 | 89 | if self.mode == 'r': 90 | data = memory.read(current_weights) 91 | elif self.mode == 'w': 92 | # memory.write(current_weights, a, e) 93 | memory.write(current_weights, a) 94 | else: 95 | raise ValueError("mode must be read ('r') or write('w')") 96 | return current_weights, data 97 | 98 | @staticmethod 99 | def _circular_conv1d(in_tensor, weights): 100 | # pad left with elements from right, and vice-versa 101 | batch_size = weights.size(0) 102 | pad = int((weights.size(1) - 1) / 2) 103 | in_tensor = torch.cat( 104 | [in_tensor[:, -pad:], in_tensor, in_tensor[:, :pad]], dim=1) 105 | out_tensor = F.conv1d(in_tensor.view(batch_size, 1, -1), 106 | weights.view(batch_size, 1, -1)) 107 | out_tensor = out_tensor.view(batch_size, -1) 108 | return out_tensor 109 | 110 | def reset(self): 111 | nn.init.xavier_uniform_(self.key_strength_fc.weight, gain=1.4) 112 | nn.init.xavier_uniform_(self.interpolation_gate_fc.weight, gain=1.4) 113 | nn.init.xavier_uniform_(self.shift_weighting_fc.weight, gain=1.4) 114 | nn.init.xavier_uniform_(self.sharpen_factor_fc.weight, gain=1.4) 115 | nn.init.xavier_uniform_(self.write_data_fc.weight, gain=1.4) 116 | # nn.init.xavier_uniform_(self.erase_weight_fc.weight, gain=1.4) 117 | 118 | # nn.init.kaiming_uniform_(self.key_strength_fc.weight) 119 | # nn.init.kaiming_uniform_(self.interpolation_gate_fc.weight) 120 | # nn.init.kaiming_uniform_(self.shift_weighting_fc.weight) 121 | # nn.init.kaiming_uniform_(self.sharpen_factor_fc.weight) 122 | # nn.init.kaiming_uniform_(self.write_data_fc.weight) 123 | # nn.init.kaiming_uniform_(self.erase_weight_fc.weight) 124 | 125 | nn.init.normal_(self.key_fc.bias, std=0.01) 126 | nn.init.normal_(self.key_strength_fc.bias, std=0.01) 127 | nn.init.normal_(self.interpolation_gate_fc.bias, std=0.01) 128 | nn.init.normal_(self.shift_weighting_fc.bias, std=0.01) 129 | nn.init.normal_(self.sharpen_factor_fc.bias, std=0.01) 130 | nn.init.normal_(self.write_data_fc.bias, std=0.01) 131 | # nn.init.normal_(self.erase_weight_fc.bias, std=0.01) 132 | -------------------------------------------------------------------------------- /ntm/modules/memory.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class NTMMemory(nn.Module): 7 | 8 | def __init__(self, memory_units, memory_unit_size): 9 | super().__init__() 10 | self.n = memory_units 11 | self.m = memory_unit_size 12 | self.memory = torch.zeros([1, self.n, self.m]) 13 | nn.init.kaiming_uniform_(self.memory) 14 | # layer to learn bias values for memory reset 15 | self.memory_bias_fc = nn.Linear(1, self.n * self.m) 16 | self.reset() 17 | 18 | def forward(self, *inputs): 19 | pass 20 | 21 | def content_addressing(self, key, key_strength): 22 | """Perform content based addressing of memory based on key vector. 23 | Calculate cosine similarity between key vector and each unit of memory, 24 | finally obtain a softmax distribution out of it. These are normalized 25 | content weights according to content based addressing. 26 | 27 | Refer *Section 3.3.1* and *Figure 2* in the paper. 28 | 29 | Parameters 30 | ---------- 31 | key : torch.Tensor 32 | The key vector (a.k.a. query vector) emitted by a read/write head. 33 | ``(batch_size, memory_unit_size)`` 34 | 35 | key_strength : torch.Tensor 36 | A scalar weight (a.k.a. beta) multiplied for shaping the softmax 37 | distribution. 38 | ``(zero-dimensional)`` 39 | 40 | Returns 41 | ------- 42 | content_weights : torch.Tensor 43 | Normalized content-addressed weights vector. k-th element of this 44 | vector represents the amount of k-th memory unit to be read from or 45 | written to. 46 | ``(batch_size, memory_units)`` 47 | """ 48 | 49 | # view key with three dimensions as memory, add dummy dimension 50 | key = key.view(-1, 1, self.m) 51 | 52 | # calculate similarity along last dimension (self.m) 53 | similarity = F.cosine_similarity(key, self.memory, dim=-1) 54 | content_weights = F.softmax(key_strength * similarity, dim=1) 55 | return content_weights 56 | 57 | def read(self, weights): 58 | """Read from memory through soft attention over all locations. 59 | 60 | Refer *Section 3.1* in the paper for read mechanism. 61 | 62 | Parameters 63 | ---------- 64 | weights : torch.Tensor 65 | Attention weights emitted by a read head. 66 | ``(batch_size, memory_units)`` 67 | 68 | Returns 69 | ------- 70 | data : torch.Tensor 71 | Data read from memory weighted by attention. 72 | ``(batch_size, memory_unit_size)`` 73 | """ 74 | # expand and perform batch matrix mutliplication 75 | weights = weights.view(-1, 1, self.n) 76 | # (b, 1, self.n) x (b, self.n, self.m) -> (b, 1, self.m) 77 | data = torch.bmm(weights, self.memory).view(-1, self.m) 78 | return data 79 | 80 | def write(self, weights, data, erase=None): 81 | """Write to memory through soft attention over all locations. 82 | 83 | Refer *Section 3.2* in the paper for write mechanism. 84 | 85 | .. note:: 86 | Erase and add mechanisms have been merged here. 87 | ``weights(erase) = (1 - weights(add))`` 88 | 89 | Parameters 90 | ---------- 91 | weights : torch.Tensor 92 | Attention weights emitted by a write head. 93 | ``(batch_size, memory_units)`` 94 | 95 | data : torch.Tensor 96 | Data to be written to memory. 97 | ``(batch_size, memory_unit_size)`` 98 | 99 | erase(optional) : torch.Tensor 100 | Extent of erasure to be performed on the memory unit. 
101 | ``(batch_size, memory_unit_size)`` 102 | """ 103 | 104 | # make weights and write_data sizes same as memory 105 | weights = weights.view(-1, self.n, 1).expand(self.memory.size()) 106 | data = data.view(-1, 1, self.m).expand(self.memory.size()) 107 | self.memory = weights * data + (1 - weights) * self.memory 108 | # --(separate erase and add mechanism) 109 | # erase = erase.view(-1, 1, self.m).expand(self.memory.size()) 110 | # self.memory = (1 - weights * erase) * self.memory 111 | # self.memory = weights * data + self.memory 112 | 113 | def reset(self, batch_size=1): 114 | # self.memory = torch.zeros([batch_size, self.n, self.m]) 115 | # nn.init.kaiming_uniform_(self.memory) 116 | in_data = torch.tensor([[0.]]) # dummy input 117 | memory_bias = F.sigmoid(self.memory_bias_fc(in_data)) 118 | self.memory = memory_bias.view(self.n, self.m).repeat(batch_size, 1, 1) 119 | -------------------------------------------------------------------------------- /ntm/ntm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .modules.controller import NTMController 5 | from .modules.head import NTMHead 6 | from .modules.memory import NTMMemory 7 | 8 | 9 | class NTM(nn.Module): 10 | def __init__(self, 11 | input_size, 12 | output_size, 13 | controller_size, 14 | memory_units, 15 | memory_unit_size, 16 | num_heads): 17 | super().__init__() 18 | self.controller_size = controller_size 19 | self.controller = NTMController( 20 | input_size + num_heads * memory_unit_size, controller_size, output_size, 21 | read_data_size=controller_size + num_heads * memory_unit_size) 22 | 23 | self.memory = NTMMemory(memory_units, memory_unit_size) 24 | self.memory_unit_size = memory_unit_size 25 | self.memory_units = memory_units 26 | self.num_heads = num_heads 27 | self.heads = nn.ModuleList([]) 28 | for head in range(num_heads): 29 | self.heads += [ 30 | NTMHead('r', controller_size, key_size=memory_unit_size), 31 | NTMHead('w', controller_size, key_size=memory_unit_size) 32 | ] 33 | 34 | self.prev_head_weights = [] 35 | self.prev_reads = [] 36 | self.reset() 37 | 38 | def reset(self, batch_size=1): 39 | self.memory.reset(batch_size) 40 | self.controller.reset(batch_size) 41 | self.prev_head_weights = [] 42 | for i in range(len(self.heads)): 43 | prev_weight = torch.zeros([batch_size, self.memory.n]) 44 | self.prev_head_weights.append(prev_weight) 45 | self.prev_reads = [] 46 | for i in range(self.num_heads): 47 | prev_read = torch.zeros([batch_size, self.memory.m]) 48 | # using random initialization for previous reads 49 | nn.init.kaiming_uniform_(prev_read) 50 | self.prev_reads.append(prev_read) 51 | 52 | def forward(self, in_data): 53 | controller_h_state, controller_c_state = self.controller( 54 | in_data, self.prev_reads) 55 | read_data = [] 56 | head_weights = [] 57 | for head, prev_head_weight in zip(self.heads, self.prev_head_weights): 58 | if head.mode == 'r': 59 | head_weight, r = head( 60 | controller_c_state, prev_head_weight, self.memory) 61 | read_data.append(r) 62 | else: 63 | head_weight, _ = head( 64 | controller_c_state, prev_head_weight, self.memory) 65 | head_weights.append(head_weight) 66 | 67 | output = self.controller.output(read_data) 68 | 69 | self.prev_head_weights = head_weights 70 | self.prev_reads = read_data 71 | 72 | return output 73 | -------------------------------------------------------------------------------- /ntm/tasks/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .ntm import NTM 2 | -------------------------------------------------------------------------------- /ntm/tasks/associative.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "associativerecall", 3 | "controller_size": 100, 4 | "memory_units": 128, 5 | "memory_unit_size": 20, 6 | "num_heads": 1, 7 | "seq_width": 6, 8 | "seq_len": 3, 9 | "min_item": 2, 10 | "max_item": 6 11 | } -------------------------------------------------------------------------------- /ntm/tasks/copy.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "copy", 3 | "controller_size": 100, 4 | "memory_units": 128, 5 | "memory_unit_size": 20, 6 | "num_heads": 1, 7 | "seq_width": 8, 8 | "min_seq_len": 1, 9 | "max_seq_len": 20 10 | } 11 | -------------------------------------------------------------------------------- /ntm/tasks/ngram.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "ngram", 3 | "controller_size": 100, 4 | "memory_units": 128, 5 | "memory_unit_size": 20, 6 | "num_heads": 1, 7 | "seq_len": 200, 8 | "N": 6 9 | } -------------------------------------------------------------------------------- /ntm/tasks/prioritysort.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "prioritysort", 3 | "controller_size": 200, 4 | "memory_units": 128, 5 | "memory_unit_size": 20, 6 | "num_heads": 5, 7 | "seq_width": 8, 8 | "input_seq_len": 20, 9 | "target_seq_len": 16 10 | } -------------------------------------------------------------------------------- /ntm/tasks/repeatcopy.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "repeatcopy", 3 | "controller_size": 100, 4 | "memory_units": 128, 5 | "memory_unit_size": 20, 6 | "num_heads": 1, 7 | "seq_width": 8, 8 | "min_seq_len": 1, 9 | "max_seq_len": 10, 10 | "min_repeat": 1, 11 | "max_repeat": 10 12 | } -------------------------------------------------------------------------------- /saved_models/saved_model_associative_100000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/saved_models/saved_model_associative_100000.pt -------------------------------------------------------------------------------- /saved_models/saved_model_copy_500000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/saved_models/saved_model_copy_500000.pt -------------------------------------------------------------------------------- /saved_models/saved_model_prioritysort_100000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/saved_models/saved_model_prioritysort_100000.pt -------------------------------------------------------------------------------- /saved_models/saved_model_repeatcopy_100000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlgiitr/ntm-pytorch/9bea8b1fa262c2879959857cdfe6412904b06ff4/saved_models/saved_model_repeatcopy_100000.pt 
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import numpy as np 4 | import os 5 | 6 | import torch 7 | from torch import nn, optim 8 | from tensorboard_logger import configure, log_value 9 | 10 | from ntm import NTM 11 | from ntm.datasets import CopyDataset, RepeatCopyDataset, AssociativeDataset, NGram, PrioritySort 12 | from ntm.args import get_parser 13 | 14 | 15 | args = get_parser().parse_args() 16 | 17 | configure("runs/") 18 | 19 | # ---------------------------------------------------------------------------- 20 | # -- initialize datasets, model, criterion and optimizer 21 | # ---------------------------------------------------------------------------- 22 | 23 | args.task_json = 'ntm/tasks/copy.json' 24 | ''' 25 | args.task_json = 'ntm/tasks/repeatcopy.json' 26 | args.task_json = 'ntm/tasks/associative.json' 27 | args.task_json = 'ntm/tasks/ngram.json' 28 | args.task_json = 'ntm/tasks/prioritysort.json' 29 | ''' 30 | 31 | task_params = json.load(open(args.task_json)) 32 | 33 | dataset = CopyDataset(task_params) 34 | ''' 35 | dataset = RepeatCopyDataset(task_params) 36 | dataset = AssociativeDataset(task_params) 37 | dataset = NGram(task_params) 38 | dataset = PrioritySort(task_params) 39 | ''' 40 | 41 | """ 42 | For the Copy task, input_size: seq_width + 2, output_size: seq_width 43 | For the RepeatCopy task, input_size: seq_width + 2, output_size: seq_width + 1 44 | For the Associative task, input_size: seq_width + 2, output_size: seq_width 45 | For the NGram task, input_size: 1, output_size: 1 46 | For the Priority Sort task, input_size: seq_width + 1, output_size: seq_width 47 | """ 48 | ntm = NTM(input_size=task_params['seq_width'] + 2, 49 | output_size=task_params['seq_width'], 50 | controller_size=task_params['controller_size'], 51 | memory_units=task_params['memory_units'], 52 | memory_unit_size=task_params['memory_unit_size'], 53 | num_heads=task_params['num_heads']) 54 | 55 | criterion = nn.BCELoss() 56 | # As the learning rate is task specific, the argument can be moved to json file 57 | optimizer = optim.RMSprop(ntm.parameters(), 58 | lr=args.lr, 59 | alpha=args.alpha, 60 | momentum=args.momentum) 61 | ''' 62 | optimizer = optim.Adam(ntm.parameters(), lr=args.lr, 63 | betas=(args.beta1, args.beta2)) 64 | ''' 65 | 66 | args.saved_model = 'saved_model_copy.pt' 67 | ''' 68 | args.saved_model = 'saved_model_repeatcopy.pt' 69 | args.saved_model = 'saved_model_associative.pt' 70 | args.saved_model = 'saved_model_ngram.pt' 71 | args.saved_model = 'saved_model_prioritysort.pt' 72 | ''' 73 | 74 | cur_dir = os.getcwd() 75 | PATH = os.path.join(cur_dir, args.saved_model) 76 | 77 | # ---------------------------------------------------------------------------- 78 | # -- basic training loop 79 | # ---------------------------------------------------------------------------- 80 | losses = [] 81 | errors = [] 82 | for iter in tqdm(range(args.num_iters)): 83 | optimizer.zero_grad() 84 | ntm.reset() 85 | 86 | data = dataset[iter] 87 | input, target = data['input'], data['target'] 88 | out = torch.zeros(target.size()) 89 | 90 | # ------------------------------------------------------------------------- 91 | # loop for other tasks 92 | # ------------------------------------------------------------------------- 93 | for i in range(input.size()[0]): 94 | # to maintain consistency in dimensions as torch.cat was throwing error 
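        # (illustrative shape note: input[i] is a 1-D vector, e.g. seq_width + 2
        #  elements for the copy task, and unsqueeze adds the batch dimension the
        #  controller expects, giving shape (1, seq_width + 2).)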
95 | in_data = torch.unsqueeze(input[i], 0) 96 | ntm(in_data) 97 | 98 | # passing zero vector as input while generating target sequence 99 | in_data = torch.unsqueeze(torch.zeros(input.size()[1]), 0) 100 | for i in range(target.size()[0]): 101 | out[i] = ntm(in_data) 102 | # ------------------------------------------------------------------------- 103 | # ------------------------------------------------------------------------- 104 | # loop for NGram task 105 | # ------------------------------------------------------------------------- 106 | ''' 107 | for i in range(task_params['seq_len'] - 1): 108 | in_data = input[i].view(1, -1) 109 | ntm(in_data) 110 | target_data = torch.zeros([1]).view(1, -1) 111 | out[i] = ntm(target_data) 112 | ''' 113 | # ------------------------------------------------------------------------- 114 | 115 | loss = criterion(out, target) 116 | losses.append(loss.item()) 117 | loss.backward() 118 | # clips gradient in the range [-10,10]. Again there is a slight but 119 | # insignificant deviation from the paper where they are clipped to (-10,10) 120 | nn.utils.clip_grad_value_(ntm.parameters(), 10) 121 | optimizer.step() 122 | 123 | binary_output = out.clone() 124 | binary_output = binary_output.detach().apply_(lambda x: 0 if x < 0.5 else 1) 125 | 126 | # sequence prediction error is calculted in bits per sequence 127 | error = torch.sum(torch.abs(binary_output - target)) 128 | errors.append(error.item()) 129 | 130 | # ---logging--- 131 | if iter % 200 == 0: 132 | print('Iteration: %d\tLoss: %.2f\tError in bits per sequence: %.2f' % 133 | (iter, np.mean(losses), np.mean(errors))) 134 | log_value('train_loss', np.mean(losses), iter) 135 | log_value('bit_error_per_sequence', np.mean(errors), iter) 136 | losses = [] 137 | errors = [] 138 | 139 | # ---saving the model--- 140 | torch.save(ntm.state_dict(), PATH) 141 | # torch.save(ntm, PATH) 142 | --------------------------------------------------------------------------------