├── datasets ├── __init__.py ├── reverse.py └── copy.py ├── requirements.txt ├── tasks ├── copy.json └── reverse.json ├── README.md ├── LICENSE ├── baselines ├── nvm │ ├── controller.py │ ├── lstm_baseline.py │ ├── ntm.py │ ├── ntm_mem.py │ ├── util.py │ ├── con_att.py │ ├── ntm_warper.py │ └── head.py ├── transformer.py ├── dnc │ ├── util.py │ ├── memory.py │ ├── sp_memory.py │ ├── dnc.py │ └── sdnc.py └── panm.py ├── args.py ├── .gitignore └── run_algo_task.py /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .copy import CopyDataset 2 | from .reverse import ReverseDataset 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.2 2 | tensorboard_logger==0.1.0 3 | torch==1.12.1 4 | tqdm==4.62.3 5 | -------------------------------------------------------------------------------- /tasks/copy.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "copy", 3 | "controller_size": 256, 4 | "memory_units": 128, 5 | "memory_unit_size": 32, 6 | "num_heads": 1, 7 | "num_slot": 8, 8 | "slot_size": 32, 9 | "rel_size": 32, 10 | "seq_width": 10, 11 | "min_seq_len": 1, 12 | "max_seq_len": 10, 13 | "iter": 50000 14 | } 15 | -------------------------------------------------------------------------------- /tasks/reverse.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "reverse", 3 | "controller_size": 256, 4 | "memory_units": 128, 5 | "memory_unit_size": 32, 6 | "num_heads": 1, 7 | "num_slot": 8, 8 | "slot_size": 96, 9 | "rel_size": 96, 10 | "seq_width": 10, 11 | "min_seq_len": 1, 12 | "max_seq_len": 10, 13 | "iter": 50000 14 | } 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PANM 2 | Source code for the paper "Plug, Play, and Generalize: Length Extrapolation with Pointer-Augmented Neural Memory" 3 | - Paper: https://openreview.net/forum?id=dyQ9vFbF6D 4 | - Blog: https://hungleai.substack.com/p/extending-neural-networks-to-new 5 | - Code reference: https://github.com/thaihungle/SAM 6 | 7 | # Setup 8 | Python 3.8 9 | ``` 10 | pip install -r requirements.txt 11 | mkdir logs 12 | mkdir saved_models 13 | ``` 14 | 15 | # Algorithm tasks 16 | Example training commands for the Copy task: 17 | ``` 18 | LSTM baseline: python run_algo_task.py -task_json=./tasks/copy.json -model_name=lstm -mode=train 19 | PANM: python run_algo_task.py -task_json=./tasks/copy.json -model_name=panm -mode=train 20 | ``` 21 | 22 | Example testing commands for the Copy task (2x training length): 23 | ``` 24 | LSTM baseline: python run_algo_task.py -task_json=./tasks/copy.json -model_name=lstm -mode=test -genlen=2 25 | PANM: python run_algo_task.py -task_json=./tasks/copy.json -model_name=panm -mode=test -genlen=2 26 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 HLe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /baselines/nvm/controller.py: -------------------------------------------------------------------------------- 1 | """LSTM Controller.""" 2 | import torch 3 | from torch import nn 4 | from torch.nn import Parameter 5 | import numpy as np 6 | 7 | 8 | class LSTMController(nn.Module): 9 | """An NTM controller based on LSTM.""" 10 | def __init__(self, num_inputs, num_outputs, num_layers): 11 | super(LSTMController, self).__init__() 12 | 13 | self.num_inputs = num_inputs 14 | self.num_outputs = num_outputs 15 | self.num_layers = num_layers 16 | 17 | self.lstm = nn.LSTM(input_size=num_inputs, 18 | hidden_size=num_outputs, 19 | num_layers=num_layers) 20 | 21 | # The hidden state is a learned parameter 22 | if torch.cuda.is_available(): 23 | self.lstm_h_bias = Parameter(torch.randn(self.num_layers, 1, self.num_outputs).cuda() * 0.05) 24 | self.lstm_c_bias = Parameter(torch.randn(self.num_layers, 1, self.num_outputs).cuda() * 0.05) 25 | else: 26 | self.lstm_h_bias = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0.05) 27 | self.lstm_c_bias = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0.05) 28 | 29 | self.reset_parameters() 30 | 31 | def create_new_state(self, batch_size): 32 | # Dimension: (num_layers * num_directions, batch, hidden_size) 33 | lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1) 34 | lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1) 35 | return lstm_h, lstm_c 36 | 37 | def reset_parameters(self): 38 | for p in self.lstm.parameters(): 39 | if p.dim() == 1: 40 | nn.init.constant_(p, 0) 41 | else: 42 | stdev = 5 / (np.sqrt(self.num_inputs + self.num_outputs)) 43 | nn.init.uniform_(p, -stdev, stdev) 44 | 45 | def size(self): 46 | return self.num_inputs, self.num_outputs 47 | 48 | def forward(self, x, prev_state): 49 | x = x.unsqueeze(0) 50 | outp, state = self.lstm(x, prev_state) 51 | return outp.squeeze(0), state -------------------------------------------------------------------------------- /datasets/reverse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.distributions.binomial import Binomial 4 | 5 | 6 | class ReverseDataset(Dataset): 7 | """A Dataset class to generate random examples for the copy task. Each 8 | sequence has a random length between `min_seq_len` and `max_seq_len`. 9 | Each vector in the sequence has a fixed length of `seq_width`. The vectors 10 | are bounded by start and end delimiter flags. 11 | 12 | To account for the delimiter flags, the input sequence length as well 13 | width is two more than the target sequence. 
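    The target sequence is the input sequence reversed along the time dimension
    (the sample generator applies torch.flip on the time axis).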
14 | """ 15 | 16 | def __init__(self, task_params): 17 | """Initialize a dataset instance for copy task. 18 | 19 | Arguments 20 | --------- 21 | task_params : dict 22 | A dict containing parameters relevant to copy task. 23 | """ 24 | self.seq_width = task_params['seq_width'] 25 | self.min_seq_len = task_params['min_seq_len'] 26 | self.max_seq_len = task_params['max_seq_len'] 27 | self.in_dim = task_params['seq_width'] + 2 28 | self.out_dim = task_params['seq_width'] 29 | self.prob = task_params['prob'] 30 | self.mode = task_params['mode'] 31 | 32 | 33 | 34 | def __len__(self): 35 | # sequences are generated randomly so this does not matter 36 | # set a sufficiently large size for data loader to sample mini-batches 37 | return 65536 38 | 39 | def get_sample_wlen(self, seq_len, bs=1): 40 | # idx only acts as a counter while generating batches. 41 | if self.mode == "onehot": 42 | seq = torch.nn.functional.one_hot(torch.randint(self.seq_width, (seq_len, bs)),num_classes=self.seq_width) 43 | else: 44 | prob = self.prob * torch.ones([seq_len, bs, self.seq_width], dtype=torch.float64) 45 | seq = Binomial(1, prob).sample() 46 | 47 | # fill in input sequence, two bit longer and wider than target 48 | input_seq = torch.zeros([seq_len + 2, bs,self.seq_width + 2]) 49 | input_seq[0, :,self.seq_width] = 1.0 # start delimiter 50 | input_seq[1:seq_len + 1,:, :self.seq_width] = seq 51 | input_seq[seq_len + 1, :, self.seq_width + 1] = 1.0 # end delimiter 52 | 53 | target_seq = torch.zeros([seq_len, bs, self.seq_width]) 54 | target_seq[:seq_len,:, :self.seq_width] = torch.flip(seq, [0]) 55 | return {'input': input_seq, 'target': target_seq} 56 | -------------------------------------------------------------------------------- /datasets/copy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.distributions.binomial import Binomial 4 | import torch.nn.functional as F 5 | 6 | class CopyDataset(Dataset): 7 | """A Dataset class to generate random examples for the copy task. Each 8 | sequence has a random length between `min_seq_len` and `max_seq_len`. 9 | Each vector in the sequence has a fixed length of `seq_width`. The vectors 10 | are bounded by start and end delimiter flags. 11 | 12 | To account for the delimiter flags, the input sequence length as well 13 | width is two more than the target sequence. 14 | """ 15 | 16 | def __init__(self, task_params): 17 | """Initialize a dataset instance for copy task. 18 | 19 | Arguments 20 | --------- 21 | task_params : dict 22 | A dict containing parameters relevant to copy task. 23 | """ 24 | self.seq_width = task_params['seq_width'] 25 | self.min_seq_len = task_params['min_seq_len'] 26 | self.max_seq_len = task_params['max_seq_len'] 27 | self.in_dim = task_params['seq_width'] + 2 28 | self.out_dim = task_params['seq_width'] 29 | self.prob = task_params['prob'] 30 | self.mode = task_params['mode'] 31 | 32 | def __len__(self): 33 | # sequences are generated randomly so this does not matter 34 | # set a sufficiently large size for data loader to sample mini-batches 35 | return 65536 36 | 37 | 38 | def get_sample_wlen(self, seq_len, bs=1): 39 | # idx only acts as a counter while generating batches. 
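        # Returns a dict with 'input' of shape (seq_len + 2, bs, seq_width + 2), where the
        # extra time steps and channels carry the start/end delimiter bits, and 'target' of
        # shape (seq_len, bs, seq_width), which for the copy task is the generated sequence itself.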
40 | 41 | if self.mode == "onehot": 42 | seq = F.one_hot(torch.randint(self.seq_width, (seq_len, bs)),num_classes=self.seq_width) 43 | else: 44 | prob = self.prob * torch.ones([seq_len, bs, self.seq_width], dtype=torch.float64) 45 | seq = Binomial(1, prob).sample() 46 | 47 | # fill in input sequence, two bit longer and wider than target 48 | input_seq = torch.zeros([seq_len + 2, bs,self.seq_width + 2]) 49 | input_seq[0, :,self.seq_width] = 1.0 # start delimiter 50 | input_seq[1:seq_len + 1,:, :self.seq_width] = seq 51 | input_seq[seq_len + 1, :, self.seq_width + 1] = 1.0 # end delimiter 52 | 53 | target_seq = torch.zeros([seq_len, bs, self.seq_width]) 54 | target_seq[:seq_len,:, :self.seq_width] = seq 55 | return {'input': input_seq, 'target': target_seq} 56 | 57 | -------------------------------------------------------------------------------- /baselines/nvm/lstm_baseline.py: -------------------------------------------------------------------------------- 1 | """LSTM Controller.""" 2 | import torch 3 | from torch import nn 4 | from torch.nn import Parameter 5 | import numpy as np 6 | 7 | 8 | class LSTMBaseline(nn.Module): 9 | """An NTM controller based on LSTM.""" 10 | def __init__(self, num_inputs, num_hidden, num_outputs, num_layers): 11 | super(LSTMBaseline, self).__init__() 12 | 13 | self.num_inputs = num_inputs 14 | self.num_hidden = num_hidden 15 | self.num_layers = num_layers 16 | 17 | self.lstm = nn.LSTM(input_size=num_inputs, 18 | hidden_size=num_hidden, 19 | num_layers=num_layers) 20 | 21 | self.out = nn.Linear(num_hidden, num_outputs) 22 | 23 | # The hidden state is a learned parameter 24 | if torch.cuda.is_available(): 25 | self.lstm_h_bias = Parameter(torch.randn(self.num_layers, 1, self.num_hidden).cuda() * 0.05) 26 | self.lstm_c_bias = Parameter(torch.randn(self.num_layers, 1, self.num_hidden).cuda() * 0.05) 27 | else: 28 | self.lstm_h_bias = Parameter(torch.randn(self.num_layers, 1, self.num_hidden) * 0.05) 29 | self.lstm_c_bias = Parameter(torch.randn(self.num_layers, 1, self.num_hidden) * 0.05) 30 | 31 | self.reset_parameters() 32 | 33 | def create_new_state(self, batch_size): 34 | # Dimension: (num_layers * num_directions, batch, hidden_size) 35 | lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1) 36 | lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1) 37 | return lstm_h, lstm_c 38 | 39 | def reset_parameters(self): 40 | for p in self.lstm.parameters(): 41 | if p.dim() == 1: 42 | nn.init.constant_(p, 0) 43 | else: 44 | stdev = 5 / (np.sqrt(self.num_inputs + self.num_hidden)) 45 | nn.init.uniform_(p, -stdev, stdev) 46 | 47 | def init_sequence(self, batch_size): 48 | """Initializing the state.""" 49 | self.previous_state = self.create_new_state(batch_size) 50 | 51 | def size(self): 52 | return self.num_inputs, self.num_hidden 53 | 54 | def forward(self, x): 55 | x = x.unsqueeze(0) 56 | outp, self.previous_state = self.lstm(x, self.previous_state) 57 | outp = self.out(outp) 58 | return outp.squeeze(0), self.previous_state 59 | 60 | def calculate_num_params(self): 61 | """Returns the total number of parameters.""" 62 | num_params = 0 63 | for p in self.parameters(): 64 | num_params += p.data.view(-1).size(0) 65 | return num_params 66 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_parser(): 5 | parser = argparse.ArgumentParser() 6 | 7 | parser.add_argument('-model_name', default='lstm', 8 | 
help='the name of the model') 9 | parser.add_argument('-task_json', type=str, default='tasks/copy.json', 10 | help='path to json file with task specific parameters') 11 | parser.add_argument('-log_dir', default='logs/', 12 | help='path to log metrics') 13 | parser.add_argument('-save_dir', default='saved_models/', 14 | help='path to file with final model parameters') 15 | parser.add_argument('-batch_size', type=int, default=128, 16 | help='batch size of input sequence during training') 17 | parser.add_argument('-genlen', type=float, default=2, 18 | help='scale of generalisation') 19 | parser.add_argument('-task_prob', type=float, default=0.5, 20 | help='scale of generalisation') 21 | parser.add_argument('-dropout', type=float, default=0, 22 | help='dropout') 23 | parser.add_argument('-layers', type=int, default=1, 24 | help='layers of computation') 25 | parser.add_argument('-nheads', type=int, default=1, 26 | help='number of heads') 27 | parser.add_argument('-clip_grad', type=int, default=10, 28 | help='clip gradient') 29 | parser.add_argument('-num_iters', type=int, default=50000, 30 | help='number of iterations for training') 31 | parser.add_argument('-max_len', type=int, default=300, 32 | help='max_len') 33 | parser.add_argument('-mode_toy', type=str, default="onehot", 34 | help='logit or onehot') 35 | parser.add_argument('-freq_val', type=int, default=200, 36 | help='validation frequence') 37 | parser.add_argument('-num_eval', type=int, default=10, 38 | help='number of evaluation') 39 | parser.add_argument('-mode', type=str, default="train", 40 | help='train or test') 41 | parser.add_argument('-seed', default=42,type=int, 42 | help='random seed') 43 | # todo: only rmsprop optimizer supported yet, support adam too 44 | parser.add_argument('-lr', type=float, default=1e-4, 45 | help='learning rate for rmsprop optimizer') 46 | parser.add_argument('-momentum', type=float, default=0.9, 47 | help='momentum for rmsprop optimizer') 48 | parser.add_argument('-alpha', type=float, default=0.95, 49 | help='alpha for rmsprop optimizer') 50 | parser.add_argument('-beta1', type=float, default=0.9, 51 | help='beta1 constant for adam optimizer') 52 | parser.add_argument('-beta2', type=float, default=0.999, 53 | help='beta2 constant for adam optimizer') 54 | return parser 55 | -------------------------------------------------------------------------------- /baselines/nvm/ntm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class NTM(nn.Module): 8 | """A Neural Turing Machine.""" 9 | def __init__(self, num_inputs, num_outputs, controller, memory, heads): 10 | """Initialize the NTM. 11 | :param num_inputs: External input size. 12 | :param num_outputs: External output size. 
13 | :param controller: :class:`LSTMController` 14 | :param memory: :class:`NTMMemory` 15 | :param heads: list of :class:`NTMReadHead` or :class:`NTMWriteHead` 16 | Note: This design allows the flexibility of using any number of read and 17 | write heads independently, also, the order by which the heads are 18 | called in controlled by the user (order in list) 19 | """ 20 | super(NTM, self).__init__() 21 | 22 | # Save arguments 23 | self.num_inputs = num_inputs 24 | self.num_outputs = num_outputs 25 | self.controller = controller 26 | self.memory = memory 27 | self.heads = heads 28 | 29 | self.N, self.M = memory.size() 30 | _, self.controller_size = controller.size() 31 | 32 | # Initialize the initial previous read values to random biases 33 | self.num_read_heads = 0 34 | self.init_r = [] 35 | for head in heads: 36 | if head.is_read_head(): 37 | if torch.cuda.is_available(): 38 | init_r_bias = torch.randn(1, self.M).cuda() * 0.01 39 | else: 40 | init_r_bias = torch.randn(1, self.M) * 0.01 41 | self.register_buffer("read{}_bias".format(self.num_read_heads), init_r_bias.data) 42 | self.init_r += [init_r_bias] 43 | self.num_read_heads += 1 44 | 45 | assert self.num_read_heads > 0, "heads list must contain at least a single read head" 46 | 47 | # Initialize a fully connected layer to produce the actual output: 48 | # [controller_output; previous_reads ] -> output 49 | self.fc = nn.Linear(self.controller_size + self.num_read_heads * self.M, num_outputs) 50 | self.reset_parameters() 51 | 52 | def create_new_state(self, batch_size): 53 | init_r = [r.clone().repeat(batch_size, 1) for r in self.init_r] 54 | controller_state = self.controller.create_new_state(batch_size) 55 | heads_state = [head.create_new_state(batch_size) for head in self.heads] 56 | 57 | return init_r, controller_state, heads_state 58 | 59 | def reset_parameters(self): 60 | # Initialize the linear layer 61 | nn.init.xavier_uniform_(self.fc.weight, gain=1) 62 | nn.init.normal_(self.fc.bias, std=0.01) 63 | 64 | def forward(self, x, prev_state): 65 | """NTM forward function. 
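        Runs a single time step: the input is concatenated with the previous read
        vectors, fed to the controller, each head is applied in order, and the output
        is computed from [controller_output; read_vectors].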
66 | :param x: input vector (batch_size x num_inputs) 67 | :param prev_state: The previous state of the NTM 68 | """ 69 | # Unpack the previous state 70 | prev_reads, prev_controller_state, prev_heads_states = prev_state 71 | 72 | # Use the controller to get an embeddings 73 | inp = torch.cat([x] + prev_reads, dim=1) 74 | controller_outp, controller_state = self.controller(inp, prev_controller_state) 75 | 76 | # Read/Write from the list of heads 77 | reads = [] 78 | heads_states = [] 79 | for head, prev_head_state in zip(self.heads, prev_heads_states): 80 | if head.is_read_head(): 81 | r, head_state = head(controller_outp, prev_head_state) 82 | reads += [r] 83 | else: 84 | head_state = head(controller_outp, prev_head_state) 85 | heads_states += [head_state] 86 | 87 | # Generate Output 88 | inp2 = torch.cat([controller_outp] + reads, dim=1) 89 | o = self.fc(inp2) 90 | 91 | # Pack the current state 92 | state = (reads, controller_state, heads_states) 93 | 94 | return o, state -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | logs/ 165 | saved_models/ -------------------------------------------------------------------------------- /baselines/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | import torch.nn.functional as F 4 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 5 | from torch.utils.data import dataset 6 | import math 7 | 8 | class TransformerModel(nn.Module): 9 | 10 | def __init__(self, output_dim: int, d_model: int, nhead: int, d_hid: int, 11 | nlayers: int, dropout: float = 0.5, encoder=None): 12 | super().__init__() 13 | self.model_type = 'Transformer' 14 | self.pos_encoder = PositionalEncoding(d_model, dropout) 15 | encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout) 16 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 17 | if encoder is None: 18 | self.emb = nn.Embedding(ntoken, d_model) 19 | else: 20 | self.emb = encoder 21 | self.d_model = d_model 22 | self.out = nn.Linear(d_model, output_dim) 23 | 24 | self.init_weights() 25 | 26 | def init_weights(self) -> None: 27 | initrange = 0.1 28 | self.emb.weight.data.uniform_(-initrange, initrange) 29 | self.out.bias.data.zero_() 30 | self.out.weight.data.uniform_(-initrange, initrange) 31 | 32 | def forward(self, src: Tensor, target_length: int) -> Tensor: 33 | """ 34 | Args: 35 | src: Tensor, shape [seq_len, batch_size] 36 | src_mask: Tensor, shape [seq_len, seq_len] 37 | 38 | Returns: 39 | output Tensor of shape [seq_len, batch_size, ntoken] 40 | """ 41 | 42 | decoder_input = torch.zeros(target_length, src.shape[1], src.shape[2], device=src.device) 43 | inputsall = torch.cat([src, decoder_input], dim=0) 44 | inputsall = self.emb(inputsall) 45 | inputsall = self.pos_encoder(inputsall) 46 | # src_mask = 
generate_square_subsequent_mask(target_length) 47 | # decoder_mask = generate_square_decoder_mask(len(inputsall), target_length).to(inputsall.device) 48 | # print(inputsall.shape) 49 | output = self.transformer_encoder(inputsall) 50 | # print(output.shape) 51 | 52 | output = self.out(output) 53 | return output[len(src):], None 54 | 55 | def init_sequence(self, batch_size): 56 | pass 57 | 58 | def calculate_num_params(self): 59 | """Returns the total number of parameters.""" 60 | num_params = 0 61 | for p in self.parameters(): 62 | num_params += p.data.view(-1).size(0) 63 | return num_params 64 | 65 | 66 | def generate_square_subsequent_mask(sz: int) -> Tensor: 67 | """Generates an upper-triangular matrix of -inf, with zeros on diag.""" 68 | return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1) 69 | 70 | def generate_square_decoder_mask(seq_len: int, decoder_len: int) -> Tensor: 71 | """Generates an upper-triangular matrix of -inf, with zeros on diag.""" 72 | mask1 = torch.ones(seq_len, decoder_len) * float('-inf') 73 | mask2 = torch.zeros(seq_len, seq_len-decoder_len) * float('-inf') 74 | return torch.cat([mask1,mask2],dim=-1) 75 | 76 | class PositionalEncoding(nn.Module): 77 | 78 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 79 | super().__init__() 80 | self.dropout = nn.Dropout(p=dropout) 81 | 82 | position = torch.arange(max_len).unsqueeze(1) 83 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 84 | pe = torch.zeros(max_len, 1, d_model) 85 | pe[:, 0, 0::2] = torch.sin(position * div_term) 86 | pe[:, 0, 1::2] = torch.cos(position * div_term) 87 | self.register_buffer('pe', pe) 88 | 89 | def forward(self, x: Tensor) -> Tensor: 90 | """ 91 | Args: 92 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 93 | """ 94 | x = x + self.pe[:x.size(0)] 95 | return self.dropout(x) 96 | -------------------------------------------------------------------------------- /baselines/nvm/ntm_mem.py: -------------------------------------------------------------------------------- 1 | """An NTM's memory implementation.""" 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | import numpy as np 6 | 7 | 8 | def _convolve(w, s, b): 9 | """Circular convolution implementation.""" 10 | t = torch.cat([w[:,-1:], w, w[:,:1]], 1) 11 | c = F.conv1d(t.view(b, 1, -1), s.view(b, 1, -1))[range(b),range(b), :] 12 | return c 13 | 14 | def _convolve_slow(w, s): 15 | """Circular convolution implementation.""" 16 | assert s.size(0) == 3 17 | t = torch.cat([w[-1:], w, w[:1]]) 18 | c = F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1) 19 | return c 20 | 21 | class NTMMemory(nn.Module): 22 | """Memory bank for NTM.""" 23 | def __init__(self, N, M): 24 | """Initialize the NTM Memory matrix. 25 | The memory's dimensions are (batch_size x N x M). 26 | Each batch has it's own memory matrix. 27 | :param N: Number of rows in the memory. 28 | :param M: Number of columns/features in the memory. 
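        The initial memory content comes from the uniformly initialised buffer
        `mem_bias` of shape (N, M); `reset` replicates it across the batch at the
        start of each sequence.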
29 | """ 30 | super(NTMMemory, self).__init__() 31 | 32 | self.N = N 33 | self.M = M 34 | 35 | # The memory bias allows the heads to learn how to initially address 36 | # memory locations by content 37 | if torch.cuda.is_available(): 38 | self.register_buffer('mem_bias', torch.Tensor(N, M).cuda()) 39 | else: 40 | self.register_buffer('mem_bias', torch.Tensor(N, M)) 41 | 42 | # Initialize memory bias 43 | stdev = 1 / (np.sqrt(N + M)) 44 | nn.init.uniform_(self.mem_bias, -stdev, stdev) 45 | 46 | def reset(self, batch_size): 47 | """Initialize memory from bias, for start-of-sequence.""" 48 | self.batch_size = batch_size 49 | self.memory = self.mem_bias.clone().repeat(batch_size, 1, 1) 50 | 51 | def size(self): 52 | return self.N, self.M 53 | 54 | def read(self, w): 55 | """Read from memory (according to section 3.1).""" 56 | return torch.matmul(w.unsqueeze(1), self.memory).squeeze(1) 57 | 58 | def write(self, w, e, a): 59 | """write to memory (according to section 3.2).""" 60 | self.prev_mem = self.memory 61 | if torch.cuda.is_available(): 62 | self.memory = torch.Tensor(self.batch_size, self.N, self.M).cuda() 63 | else: 64 | self.memory = torch.Tensor(self.batch_size, self.N, self.M) 65 | erase = torch.matmul(w.unsqueeze(-1), e.unsqueeze(1)) 66 | add = torch.matmul(w.unsqueeze(-1), a.unsqueeze(1)) 67 | self.memory = self.prev_mem * (1 - erase) + add 68 | 69 | def address(self, k, β, g, s, γ, w_prev): 70 | """NTM Addressing (according to section 3.3). 71 | Returns a softmax weighting over the rows of the memory matrix. 72 | :param k: The key vector. 73 | :param β: The key strength (focus). 74 | :param g: Scalar interpolation gate (with previous weighting). 75 | :param s: Shift weighting. 76 | :param γ: Sharpen weighting scalar. 77 | :param w_prev: The weighting produced in the previous time step. 
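        The weighting is computed as content addressing followed by location addressing:
        w = sharpen(shift(interpolate(w_prev, content(k, β), g), s), γ).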
78 | """ 79 | # Content focus 80 | wc = self._similarity(k, β) 81 | 82 | # Location focus 83 | wg = self._interpolate(w_prev, wc, g) 84 | ŵ = self._shift(wg, s) 85 | w = self._sharpen(ŵ, γ) 86 | 87 | return w 88 | 89 | def _similarity(self, k, β): 90 | k = k.view(self.batch_size, 1, -1) 91 | w = F.softmax(β * F.cosine_similarity(self.memory + 1e-16, k + 1e-16, dim=-1), dim=1) 92 | return w 93 | 94 | def _interpolate(self, w_prev, wc, g): 95 | return g * wc + (1 - g) * w_prev 96 | 97 | def _shift(self, wg, s): 98 | result = _convolve(wg, s, wg.size(0)) 99 | return result 100 | 101 | def _shift_slow(self, wg, s): 102 | if torch.cuda.is_available(): 103 | result = torch.zeros(wg.size()).cuda() 104 | else: 105 | result = torch.zeros(wg.size()) 106 | for b in range(wg.size(0)): 107 | result[b] = _convolve_slow(wg[b], s[b]) 108 | return result 109 | 110 | def _sharpen(self, ŵ, γ): 111 | w = ŵ ** γ 112 | w = torch.div(w, torch.sum(w, dim=1).view(-1, 1) + 1e-16) 113 | return w 114 | -------------------------------------------------------------------------------- /baselines/nvm/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch.nn as nn 5 | import torch as T 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable as var 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable 11 | import re 12 | import string 13 | 14 | 15 | def recursiveTrace(obj): 16 | print(type(obj)) 17 | if hasattr(obj, 'grad_fn'): 18 | print(obj.grad_fn) 19 | recursiveTrace(obj.grad_fn) 20 | elif hasattr(obj, 'saved_variables'): 21 | print(obj.requires_grad, len(obj.saved_tensors), len(obj.saved_variables)) 22 | [print(v) for v in obj.saved_variables] 23 | [recursiveTrace(v.grad_fn) for v in obj.saved_variables] 24 | 25 | 26 | def cuda(x, grad=False, gpu_id=-1): 27 | if gpu_id == -1: 28 | return var(x, requires_grad=grad) 29 | else: 30 | return var(x.pin_memory(), requires_grad=grad).cuda(gpu_id) 31 | 32 | 33 | def cudavec(x, grad=False, gpu_id=-1): 34 | if gpu_id == -1: 35 | return var(T.from_numpy(x), requires_grad=grad) 36 | else: 37 | return var(T.from_numpy(x).pin_memory(), requires_grad=grad).cuda(gpu_id) 38 | 39 | 40 | def cudalong(x, grad=False, gpu_id=-1): 41 | if gpu_id == -1: 42 | return var(T.from_numpy(x.astype(np.long)), requires_grad=grad) 43 | else: 44 | return var(T.from_numpy(x.astype(np.long)).pin_memory(), requires_grad=grad).cuda(gpu_id) 45 | 46 | 47 | def θ(a, b, dimA=2, dimB=2, normBy=2): 48 | """Batchwise Cosine distance 49 | 50 | Cosine distance 51 | 52 | Arguments: 53 | a {Tensor} -- A 3D Tensor (b * m * w) 54 | b {Tensor} -- A 3D Tensor (b * r * w) 55 | 56 | Keyword Arguments: 57 | dimA {number} -- exponent value of the norm for `a` (default: {2}) 58 | dimB {number} -- exponent value of the norm for `b` (default: {1}) 59 | 60 | Returns: 61 | Tensor -- Batchwise cosine distance (b * r * m) 62 | """ 63 | a_norm = T.norm(a, normBy, dimA, keepdim=True).expand_as(a) + δ 64 | b_norm = T.norm(b, normBy, dimB, keepdim=True).expand_as(b) + δ 65 | 66 | x = T.bmm(a, b.transpose(1, 2)).transpose(1, 2) / ( 67 | T.bmm(a_norm, b_norm.transpose(1, 2)).transpose(1, 2) + δ) 68 | # apply_dict(locals()) 69 | return x 70 | 71 | 72 | def σ(input, axis=1): 73 | """Softmax on an axis 74 | 75 | Softmax on an axis 76 | 77 | Arguments: 78 | input {Tensor} -- input Tensor 79 | 80 | Keyword Arguments: 81 | axis {number} -- axis on which to take softmax on (default: {1}) 82 
| 83 | Returns: 84 | Tensor -- Softmax output Tensor 85 | """ 86 | input_size = input.size() 87 | 88 | trans_input = input.transpose(axis, len(input_size) - 1) 89 | trans_size = trans_input.size() 90 | 91 | input_2d = trans_input.contiguous().view(-1, trans_size[-1]) 92 | if '0.3' in T.__version__: 93 | soft_max_2d = F.softmax(input_2d, -1) 94 | else: 95 | soft_max_2d = F.softmax(input_2d) 96 | soft_max_nd = soft_max_2d.view(*trans_size) 97 | return soft_max_nd.transpose(axis, len(input_size) - 1) 98 | 99 | δ = 1e-6 100 | 101 | 102 | def register_nan_checks(model): 103 | def check_grad(module, grad_input, grad_output): 104 | # print(module) you can add this to see that the hook is called 105 | # print('hook called for ' + str(type(module))) 106 | if any(np.all(np.isnan(gi.data.cpu().numpy())) for gi in grad_input if gi is not None): 107 | print('NaN gradient in grad_input ' + type(module).__name__) 108 | 109 | model.apply(lambda module: module.register_backward_hook(check_grad)) 110 | 111 | 112 | def apply_dict(dic): 113 | for k, v in dic.items(): 114 | apply_var(v, k) 115 | if isinstance(v, nn.Module): 116 | key_list = [a for a in dir(v) if not a.startswith('__')] 117 | for key in key_list: 118 | apply_var(getattr(v, key), key) 119 | for pk, pv in v._parameters.items(): 120 | apply_var(pv, pk) 121 | 122 | 123 | def apply_var(v, k): 124 | if isinstance(v, Variable) and v.requires_grad: 125 | v.register_hook(check_nan_gradient(k)) 126 | 127 | 128 | def check_nan_gradient(name=''): 129 | def f(tensor): 130 | if np.isnan(T.mean(tensor).data.cpu().numpy()): 131 | print('\nnan gradient of {} :'.format(name)) 132 | # print(tensor) 133 | # assert 0, 'nan gradient' 134 | return tensor 135 | return f 136 | 137 | def ptr(tensor): 138 | if T.is_tensor(tensor): 139 | return tensor.storage().data_ptr() 140 | elif hasattr(tensor, 'data'): 141 | return tensor.clone().data.storage().data_ptr() 142 | else: 143 | return tensor 144 | 145 | # TODO: EWW change this shit 146 | def ensure_gpu(tensor, gpu_id): 147 | if "cuda" in str(type(tensor)) and gpu_id != -1: 148 | return tensor.cuda(gpu_id) 149 | elif "cuda" in str(type(tensor)): 150 | return tensor.cpu() 151 | elif "Tensor" in str(type(tensor)) and gpu_id != -1: 152 | return tensor.cuda(gpu_id) 153 | elif "Tensor" in str(type(tensor)): 154 | return tensor 155 | elif type(tensor) is np.ndarray: 156 | return cudavec(tensor, gpu_id=gpu_id).data 157 | else: 158 | return tensor 159 | 160 | 161 | def print_gradient(x, name): 162 | s = "Gradient of " + name + " ----------------------------------" 163 | x.register_hook(lambda y: print(s, y.squeeze())) 164 | -------------------------------------------------------------------------------- /baselines/dnc/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch.nn as nn 5 | import torch as T 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable as var 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable 11 | import re 12 | import string 13 | 14 | 15 | def recursiveTrace(obj): 16 | print(type(obj)) 17 | if hasattr(obj, 'grad_fn'): 18 | print(obj.grad_fn) 19 | recursiveTrace(obj.grad_fn) 20 | elif hasattr(obj, 'saved_variables'): 21 | print(obj.requires_grad, len(obj.saved_tensors), len(obj.saved_variables)) 22 | [print(v) for v in obj.saved_variables] 23 | [recursiveTrace(v.grad_fn) for v in obj.saved_variables] 24 | 25 | 26 | def cuda(x, grad=False, gpu_id=-1): 
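    """Wrap a tensor in an autograd Variable; gpu_id == -1 keeps it on the CPU,
    otherwise the tensor is pinned and moved to the given GPU."""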
27 | if gpu_id == -1: 28 | return var(x, requires_grad=grad) 29 | else: 30 | return var(x.pin_memory(), requires_grad=grad).cuda(gpu_id) 31 | 32 | 33 | def cudavec(x, grad=False, gpu_id=-1): 34 | if gpu_id == -1: 35 | return var(T.from_numpy(x), requires_grad=grad) 36 | else: 37 | return var(T.from_numpy(x).pin_memory(), requires_grad=grad).cuda(gpu_id) 38 | 39 | 40 | def cudalong(x, grad=False, gpu_id=-1): 41 | if gpu_id == -1: 42 | return var(T.from_numpy(x.astype(np.long)), requires_grad=grad) 43 | else: 44 | return var(T.from_numpy(x.astype(np.long)).pin_memory(), requires_grad=grad).cuda(gpu_id) 45 | 46 | 47 | def θ(a, b, dimA=2, dimB=2, normBy=2): 48 | """Batchwise Cosine distance 49 | 50 | Cosine distance 51 | 52 | Arguments: 53 | a {Tensor} -- A 3D Tensor (b * m * w) 54 | b {Tensor} -- A 3D Tensor (b * r * w) 55 | 56 | Keyword Arguments: 57 | dimA {number} -- exponent value of the norm for `a` (default: {2}) 58 | dimB {number} -- exponent value of the norm for `b` (default: {1}) 59 | 60 | Returns: 61 | Tensor -- Batchwise cosine distance (b * r * m) 62 | """ 63 | a_norm = T.norm(a, normBy, dimA, keepdim=True).expand_as(a) + δ 64 | b_norm = T.norm(b, normBy, dimB, keepdim=True).expand_as(b) + δ 65 | 66 | x = T.bmm(a, b.transpose(1, 2)).transpose(1, 2) / ( 67 | T.bmm(a_norm, b_norm.transpose(1, 2)).transpose(1, 2) + δ) 68 | # apply_dict(locals()) 69 | return x 70 | 71 | 72 | def σ(input, axis=1): 73 | """Softmax on an axis 74 | 75 | Softmax on an axis 76 | 77 | Arguments: 78 | input {Tensor} -- input Tensor 79 | 80 | Keyword Arguments: 81 | axis {number} -- axis on which to take softmax on (default: {1}) 82 | 83 | Returns: 84 | Tensor -- Softmax output Tensor 85 | """ 86 | input_size = input.size() 87 | 88 | trans_input = input.transpose(axis, len(input_size) - 1) 89 | trans_size = trans_input.size() 90 | 91 | input_2d = trans_input.contiguous().view(-1, trans_size[-1]) 92 | if '0.3' in T.__version__: 93 | soft_max_2d = F.softmax(input_2d, -1) 94 | else: 95 | soft_max_2d = F.softmax(input_2d) 96 | soft_max_nd = soft_max_2d.view(*trans_size) 97 | return soft_max_nd.transpose(axis, len(input_size) - 1) 98 | 99 | δ = 1e-6 100 | 101 | 102 | def register_nan_checks(model): 103 | def check_grad(module, grad_input, grad_output): 104 | # print(module) you can add this to see that the hook is called 105 | # print('hook called for ' + str(type(module))) 106 | if any(np.all(np.isnan(gi.data.cpu().numpy())) for gi in grad_input if gi is not None): 107 | print('NaN gradient in grad_input ' + type(module).__name__) 108 | 109 | model.apply(lambda module: module.register_backward_hook(check_grad)) 110 | 111 | 112 | def apply_dict(dic): 113 | for k, v in dic.items(): 114 | apply_var(v, k) 115 | if isinstance(v, nn.Module): 116 | key_list = [a for a in dir(v) if not a.startswith('__')] 117 | for key in key_list: 118 | apply_var(getattr(v, key), key) 119 | for pk, pv in v._parameters.items(): 120 | apply_var(pv, pk) 121 | 122 | 123 | def apply_var(v, k): 124 | if isinstance(v, Variable) and v.requires_grad: 125 | v.register_hook(check_nan_gradient(k)) 126 | 127 | 128 | def check_nan_gradient(name=''): 129 | def f(tensor): 130 | if np.isnan(T.mean(tensor).data.cpu().numpy()): 131 | print('\nnan gradient of {} :'.format(name)) 132 | # print(tensor) 133 | # assert 0, 'nan gradient' 134 | return tensor 135 | return f 136 | 137 | def ptr(tensor): 138 | if T.is_tensor(tensor): 139 | return tensor.storage().data_ptr() 140 | elif hasattr(tensor, 'data'): 141 | return tensor.clone().data.storage().data_ptr() 
142 | else: 143 | return tensor 144 | 145 | # TODO: EWW change this shit 146 | def ensure_gpu(tensor, gpu_id): 147 | if "cuda" in str(type(tensor)) and gpu_id != -1: 148 | return tensor.cuda(gpu_id) 149 | elif "cuda" in str(type(tensor)): 150 | return tensor.cpu() 151 | elif "Tensor" in str(type(tensor)) and gpu_id != -1: 152 | return tensor.cuda(gpu_id) 153 | elif "Tensor" in str(type(tensor)): 154 | return tensor 155 | elif type(tensor) is np.ndarray: 156 | return cudavec(tensor, gpu_id=gpu_id).data 157 | else: 158 | return tensor 159 | 160 | 161 | def print_gradient(x, name): 162 | s = "Gradient of " + name + " ----------------------------------" 163 | x.register_hook(lambda y: print(s, y.squeeze())) 164 | -------------------------------------------------------------------------------- /baselines/nvm/con_att.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | 8 | 9 | class Encoder(nn.Module): 10 | def __init__(self, input_size, embed_size, hidden_size, 11 | n_layers=1, dropout=0.5, embedded=False): 12 | super(Encoder, self).__init__() 13 | self.input_size = input_size 14 | self.hidden_size = hidden_size 15 | self.embed_size = embed_size 16 | self.embedded = embedded 17 | self.embed = nn.Embedding(input_size, embed_size) 18 | self.gru = nn.GRU(embed_size, hidden_size, n_layers, 19 | dropout=dropout, bidirectional=True) 20 | 21 | def forward(self, src, hidden=None): 22 | if not self.embedded: 23 | embedded = self.embed(src) 24 | else: 25 | embedded = src 26 | outputs, hidden = self.gru(embedded, hidden) 27 | # sum bidirectional outputs 28 | outputs = (outputs[:, :, :self.hidden_size] + 29 | outputs[:, :, self.hidden_size:]) 30 | return outputs, hidden 31 | 32 | 33 | class Attention(nn.Module): 34 | def __init__(self, hidden_size): 35 | super(Attention, self).__init__() 36 | self.hidden_size = hidden_size 37 | self.attn = nn.Linear(self.hidden_size * 2, hidden_size) 38 | self.v = nn.Parameter(torch.rand(hidden_size)) 39 | stdv = 1. 
/ math.sqrt(self.v.size(0)) 40 | self.v.data.uniform_(-stdv, stdv) 41 | 42 | def forward(self, hidden, encoder_outputs): 43 | timestep = encoder_outputs.size(0) 44 | h = hidden.repeat(timestep, 1, 1).transpose(0, 1) 45 | encoder_outputs = encoder_outputs.transpose(0, 1) # [B*T*H] 46 | attn_energies = self.score(h, encoder_outputs) 47 | return F.softmax(attn_energies, dim=1).unsqueeze(1) 48 | 49 | def score(self, hidden, encoder_outputs): 50 | # [B*T*2H]->[B*T*H] 51 | energy = F.relu(self.attn(torch.cat([hidden, encoder_outputs], 2))) 52 | energy = energy.transpose(1, 2) # [B*H*T] 53 | v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1) # [B*1*H] 54 | energy = torch.bmm(v, energy) # [B*1*T] 55 | return energy.squeeze(1) # [B*T] 56 | 57 | 58 | class Decoder(nn.Module): 59 | def __init__(self, embed_size, hidden_size, output_size, 60 | n_layers=1, dropout=0.2, embedded=False): 61 | super(Decoder, self).__init__() 62 | self.embed_size = embed_size 63 | self.hidden_size = hidden_size 64 | self.output_size = output_size 65 | self.n_layers = n_layers 66 | self.embedded = embedded 67 | self.embed = nn.Embedding(output_size, embed_size) 68 | self.dropout = nn.Dropout(dropout, inplace=True) 69 | self.attention = Attention(hidden_size) 70 | self.gru = nn.GRU(hidden_size + embed_size, hidden_size, 71 | n_layers, dropout=dropout) 72 | self.out = nn.Linear(hidden_size * 2, output_size) 73 | 74 | def forward(self, input, last_hidden, encoder_outputs): 75 | # Get the embedding of the current input word (last output word) 76 | if not self.embedded: 77 | embedded = self.embed(input).unsqueeze(0) # (1,B,N) 78 | embedded = self.dropout(embedded) 79 | else: 80 | embedded = input 81 | # Calculate attention weights and apply to encoder outputs 82 | attn_weights = self.attention(last_hidden[-1], encoder_outputs) 83 | context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # (B,1,N) 84 | context = context.transpose(0, 1) # (1,B,N) 85 | # Combine embedded input word and attended context, run through RNN 86 | rnn_input = torch.cat([embedded, context], 2) 87 | output, hidden = self.gru(rnn_input, last_hidden) 88 | output = output.squeeze(0) # (1,B,N) -> (B,N) 89 | context = context.squeeze(0) 90 | output = self.out(torch.cat([output, context], 1)) 91 | return output, hidden, attn_weights 92 | 93 | 94 | class Seq2Seq(nn.Module): 95 | def __init__(self, in_size, out_size, embed_size, hidden_size, embedded=False, emb_layer=None): 96 | super(Seq2Seq, self).__init__() 97 | encoder = Encoder(in_size, embed_size, hidden_size, 98 | n_layers=1, dropout=0.5, embedded=embedded) 99 | decoder = Decoder(embed_size, hidden_size, out_size, 100 | n_layers=1, dropout=0.5, embedded=embedded) 101 | self.embedded = embedded 102 | self.embed_size = embed_size 103 | self.encoder = encoder 104 | self.decoder = decoder 105 | self.emb_layer = emb_layer 106 | 107 | def forward(self, src, target_length, trg=None, teacher_forcing_ratio=-1): 108 | if self.emb_layer: 109 | src = self.emb_layer(src) 110 | max_len = target_length 111 | batch_size = src.size(1) 112 | if trg is not None: 113 | max_len = trg.size(0) 114 | else: 115 | if self.embedded: 116 | trg = Variable(torch.zeros(max_len, batch_size, self.embed_size)).to(src.device) 117 | else: 118 | trg = Variable(torch.zeros(max_len, batch_size).long()).to(src.device) 119 | outputs = [] 120 | encoder_output, hidden = self.encoder(src) 121 | hidden = hidden[:self.decoder.n_layers] 122 | output = Variable(trg.data[0, :]) # sos 123 | for t in range(0, max_len): 124 | if self.embedded: 125 
| output = output.unsqueeze(0) 126 | output, hidden, attn_weights = self.decoder( 127 | output, hidden, encoder_output) 128 | outputs.append(output) 129 | if teacher_forcing_ratio!=-1: 130 | is_teacher = random.random() < teacher_forcing_ratio 131 | else: 132 | is_teacher = True 133 | top1 = output.data.max(1)[1] 134 | if not is_teacher and self.embedded: 135 | top1 = self.tgt_tok_emb(top1) 136 | 137 | output = Variable(trg.data[t] if is_teacher else top1).to(src.device) 138 | return torch.stack(outputs), None 139 | 140 | def init_sequence(self, batch_size): 141 | pass 142 | 143 | def calculate_num_params(self): 144 | """Returns the total number of parameters.""" 145 | num_params = 0 146 | for p in self.parameters(): 147 | num_params += p.data.view(-1).size(0) 148 | return num_params -------------------------------------------------------------------------------- /baselines/nvm/ntm_warper.py: -------------------------------------------------------------------------------- 1 | """All in one NTM. Encapsulation of all components.""" 2 | import torch 3 | from torch import nn 4 | from .ntm import NTM 5 | from .controller import LSTMController 6 | from .head import NTMReadHead, NTMWriteHead 7 | from .ntm_mem import NTMMemory 8 | import numpy as np 9 | import torch.nn.functional as F 10 | 11 | 12 | class EncapsulatedNTM(nn.Module): 13 | 14 | def __init__(self, num_inputs, num_outputs, 15 | controller_size, controller_layers, num_heads, N, M, 16 | program_size=0, pkey_dim=0): 17 | """Initialize an EncapsulatedNTM. 18 | :param num_inputs: External number of inputs. 19 | :param num_outputs: External number of outputs. 20 | :param controller_size: The size of the internal representation. 21 | :param controller_layers: Controller number of layers. 22 | :param num_heads: Number of heads. 23 | :param N: Number of rows in the memory bank. 24 | :param M: Number of cols/features in the memory bank. 
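        :param program_size: Number of program (instruction) slots passed to each head;
            0 disables the program losses and program metadata.
        :param pkey_dim: Dimension of the key part of each instruction vector
            (the slice used by the program losses).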
25 | """ 26 | super(EncapsulatedNTM, self).__init__() 27 | 28 | # Save args 29 | self.num_inputs = num_inputs 30 | self.num_outputs = num_outputs 31 | self.controller_size = controller_size 32 | self.controller_layers = controller_layers 33 | self.num_heads = num_heads 34 | self.N = N 35 | self.M = M 36 | self.program_size = program_size 37 | self.pkey_dim = pkey_dim 38 | self.emb = None 39 | 40 | # Create the NTM components 41 | memory = NTMMemory(N, M) 42 | controller = LSTMController(num_inputs + M*num_heads, controller_size, controller_layers) 43 | self.heads = nn.ModuleList([]) 44 | for i in range(num_heads): 45 | self.heads += [ 46 | NTMReadHead(memory, controller_size, self.program_size, self.pkey_dim), 47 | NTMWriteHead(memory, controller_size, self.program_size, self.pkey_dim) 48 | ] 49 | 50 | self.ntm = NTM(num_inputs, num_outputs, controller, memory, self.heads) 51 | self.memory = memory 52 | 53 | def init_sequence(self, batch_size): 54 | """Initializing the state.""" 55 | self.batch_size = batch_size 56 | self.memory.reset(batch_size) 57 | self.previous_state = self.ntm.create_new_state(batch_size) 58 | 59 | def forward(self, x=None): 60 | if self.emb is not None: 61 | x = self.emb(x.long()).squeeze(1) 62 | 63 | if x is None: 64 | if torch.cuda.is_available(): 65 | x = torch.zeros(self.batch_size, self.num_inputs).cuda() 66 | else: 67 | x = torch.zeros(self.batch_size, self.num_inputs) 68 | o, self.previous_state = self.ntm(x, self.previous_state) 69 | return o, self.previous_state 70 | 71 | def program_loss_pl1(self): 72 | ploss = 0 73 | count = 0 74 | for head in self.heads: 75 | for i in range(self.program_size): 76 | for j in range(i + 1, self.program_size): 77 | ploss += F.cosine_similarity \ 78 | (head.instruction_weight[i, :self.pkey_dim], 79 | head.instruction_weight[j, :self.pkey_dim], 80 | dim=0) 81 | count += 1 82 | return ploss / count 83 | 84 | def set_program_mask(self, pm): 85 | for head in self.heads: 86 | head.program_mask=pm 87 | 88 | def set_att_mode(self, mode="kv"): 89 | for head in self.heads: 90 | print("set att mode to: {}".format(mode)) 91 | head.att_mode=mode 92 | 93 | def program_loss_pl2(self): 94 | ploss = 0 95 | count = 0 96 | if torch.cuda.is_available(): 97 | I = torch.eye(self.program_size).cuda() 98 | else: 99 | I = torch.eye(self.program_size) 100 | 101 | for head in self.heads: 102 | W = head.instruction_weight[:, :self.pkey_dim] 103 | ploss += torch.norm(torch.matmul(W, torch.t(W))-I) 104 | count+=1 105 | return ploss / count 106 | 107 | def get_read_meta_info(self): 108 | meta={"read_program_weights":[], 109 | "read_query_keys":[], 110 | "read_program_keys":[], 111 | "write_program_weights": [], 112 | "write_query_keys": [], 113 | "write_program_keys": [], 114 | "read_data_weights":[], 115 | "write_data_weights":[], 116 | "css":[] 117 | } 118 | for head in self.heads: 119 | meta["css"].append(head.cs) 120 | if head.is_read_head(): 121 | if self.program_size>0: 122 | meta["read_program_weights"].append(head.program_weights) 123 | meta["read_program_keys"].append(head.instruction_weight[:, :self.pkey_dim]) 124 | meta["read_query_keys"].append(head.query_keys) 125 | 126 | meta["read_data_weights"].append(head.data_weights) 127 | else: 128 | if self.program_size>0: 129 | meta["write_program_weights"].append(head.program_weights) 130 | meta["write_program_keys"].append(head.instruction_weight[:, :self.pkey_dim]) 131 | meta["write_query_keys"].append(head.query_keys) 132 | 133 | meta["write_data_weights"].append(head.data_weights) 134 | 135 | for 
k,vv in meta.items(): 136 | for i1, v in enumerate(vv): 137 | if isinstance(v,list): 138 | for i2, v2 in enumerate(v): 139 | meta[k][i1][i2] = np.asarray(v2.detach().cpu()) 140 | else: 141 | meta[k][i1] = np.asarray(vv[i1].detach().cpu()) 142 | 143 | return meta 144 | 145 | def calculate_num_params(self): 146 | """Returns the total number of parameters.""" 147 | num_params = 0 148 | for p in self.parameters(): 149 | num_params += p.data.view(-1).size(0) 150 | return num_params -------------------------------------------------------------------------------- /run_algo_task.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import numpy as np 4 | import random 5 | import os 6 | import torch 7 | from torch import nn, optim 8 | from tensorboard_logger import configure, log_value 9 | import torch.nn.functional as F 10 | 11 | from datasets import CopyDataset, ReverseDataset 12 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | 15 | from args import get_parser 16 | args = get_parser().parse_args() 17 | 18 | 19 | def get_data(mode): 20 | if "copy" == args.task_name or "reverse" == args.task_name: 21 | if mode=="train": 22 | random_length = np.random.randint(task_params['min_seq_len'], 23 | task_params['max_seq_len'] + 1) 24 | else: 25 | random_length = int(task_params['max_seq_len']*args.genlen) + 1 26 | 27 | data = dataset.get_sample_wlen(random_length, bs=args.batch_size) 28 | return data 29 | 30 | def model_compute(data, is_train=True): 31 | if torch.cuda.is_available(): 32 | input, target = data['input'].cuda(), data['target'].cuda() 33 | out = torch.zeros(target.size()).cuda() 34 | else: 35 | input, target = data['input'], data['target'] 36 | out = torch.zeros(target.size()) 37 | 38 | 39 | # ------------------------------------------------------------------------- 40 | # loop for other tasks 41 | # ------------------------------------------------------------------------- 42 | if "lstm" in args.model_name or "ntm" in args.model_name: 43 | for i in range(input.size()[0]): 44 | in_data = input[i] 45 | sout, _ = model(in_data) 46 | if torch.cuda.is_available(): 47 | in_data = torch.zeros(input.size()).cuda() 48 | else: 49 | in_data = torch.zeros(input.size()) 50 | for i in range(target.size()[0]): 51 | sout, _ = model(in_data[-1]) 52 | out[i] = sout 53 | elif "dnc" in args.model_name or "att" in args.model_name or "transformer" in args.model_name \ 54 | or "panm" in args.model_name: 55 | model.is_train = is_train 56 | out, _, = model(input, target_length=target.shape[0]) 57 | 58 | return out, target, input 59 | 60 | 61 | def get_err(out, target): 62 | binary_output = out.clone() 63 | if torch.cuda.is_available(): 64 | binary_output = binary_output.detach().cpu().apply_(lambda x: 0 if x < 0.5 else 1).cuda() 65 | else: 66 | binary_output = binary_output.detach().apply_(lambda x: 0 if x < 0.5 else 1) 67 | if "nfar" in args.task_name or "onehot" in args.mode_toy: 68 | binary_output = torch.nn.functional.one_hot(torch.argmax(out, dim=-1)) 69 | error = torch.sum(torch.argmax(out, dim=-1) != torch.argmax(target, dim=-1)).float()/(target.shape[1]) 70 | else: 71 | # sequence prediction error is calculted in bits per sequence 72 | error = torch.sum(torch.abs(binary_output - target))/args.batch_size 73 | return error, binary_output 74 | 75 | # ---------------------------------------------------------------------------- 76 | # -- initialize datasets, model, criterion and optimizer 77 | # 
---------------------------------------------------------------------------- 78 | 79 | 80 | task_params = json.load(open(args.task_json)) 81 | args.task_name = task_params['task'] 82 | if 'iter' in task_params: 83 | args.num_iters = task_params['iter'] 84 | log_dir = os.path.join(args.log_dir,args.task_name+"-"+args.mode_toy) 85 | if not os.path.isdir(log_dir): 86 | os.mkdir(log_dir) 87 | log_dir = os.path.join(log_dir, args.model_name+f"-s{args.seed}") 88 | if not os.path.isdir(log_dir): 89 | os.mkdir(log_dir) 90 | 91 | 92 | 93 | save_dir = os.path.join(args.save_dir,args.task_name+args.model_name+"-"+args.mode_toy) 94 | if not os.path.isdir(save_dir): 95 | os.mkdir(save_dir) 96 | 97 | save_dirbest = os.path.join(save_dir,"{}-{}-best.pt".format(args.model_name, args.seed)) 98 | save_dirlast = os.path.join(save_dir,"{}-{}-last.pt".format(args.model_name, args.seed)) 99 | if args.mode == "test": 100 | save_dirlast = os.path.join(save_dir,"{}-{}-best.pt".format(args.model_name, args.seed)) 101 | 102 | 103 | task_params["prob"]=args.task_prob 104 | task_params["mode"]=args.mode_toy 105 | 106 | if "copy" == args.task_name: 107 | dataset = CopyDataset(task_params) 108 | if "reverse" == args.task_name: 109 | dataset = ReverseDataset(task_params) 110 | 111 | in_dim = dataset.in_dim 112 | out_dim = dataset.out_dim 113 | 114 | 115 | 116 | 117 | 118 | 119 | if 'lstm' in args.model_name: 120 | from baselines.nvm.lstm_baseline import LSTMBaseline 121 | hidden_dim = task_params['controller_size']*2 122 | model = LSTMBaseline(in_dim, hidden_dim, out_dim, 1) 123 | elif 'att' in args.model_name: 124 | from baselines.nvm.con_att import Seq2Seq 125 | hidden_dim = task_params['controller_size']*2 126 | emb_layer = torch.nn.Linear(in_dim, hidden_dim) 127 | model = Seq2Seq(hidden_dim, out_dim, hidden_dim, hidden_dim, embedded=True, emb_layer=emb_layer) 128 | elif 'dnc' in args.model_name: 129 | from baselines.dnc.dnc import DNC 130 | gpu_id=-1 131 | if torch.cuda.is_available(): 132 | gpu_id=0 133 | model = DNC( 134 | input_size=in_dim, 135 | final_output_size=out_dim, 136 | hidden_size=task_params['controller_size']*2, 137 | read_heads = task_params['num_heads'], 138 | nr_cells=task_params['memory_units'], 139 | cell_size=task_params['memory_unit_size'], 140 | gpu_id=gpu_id) 141 | elif 'ntm' in args.model_name: 142 | from baselines.nvm.ntm_warper import EncapsulatedNTM 143 | model = EncapsulatedNTM( 144 | num_inputs=in_dim, 145 | num_outputs=out_dim, 146 | controller_size=task_params['controller_size']*2, 147 | controller_layers =1, 148 | num_heads = task_params['num_heads'], 149 | N=task_params['memory_units'], 150 | M=task_params['memory_unit_size']) 151 | elif 'transformer' in args.model_name: 152 | from baselines.transformer import TransformerModel 153 | emsize = task_params['controller_size'] # embedding dimension 154 | d_hid = 512 # dimension of the feedforward network model in nn.TransformerEncoder 155 | nlayers = 2 # number of nn.TransformerEncoderLayer in nn.TransformerEncoder 156 | nhead = 8 # number of heads in nn.MultiheadAttention 157 | dropout = 0.2 # dropout probability 158 | mlp_encoder = nn.Linear(in_dim, emsize) 159 | model = TransformerModel(out_dim, emsize, nhead, d_hid, nlayers, dropout, encoder=mlp_encoder) 160 | elif 'panm' in args.model_name: 161 | from baselines.panm import PANM 162 | model = PANM( 163 | in_size = in_dim, 164 | out_size = out_dim, 165 | embed_size = in_dim, 166 | controller_size = task_params['controller_size']//2, 167 | hidden_dim = task_params['controller_size']//2, 
168 | ) 169 | 170 | 171 | 172 | print(model) 173 | if torch.cuda.is_available(): 174 | model.cuda() 175 | 176 | print("====num params=====") 177 | 178 | print(model.calculate_num_params()) 179 | 180 | print("========") 181 | 182 | criterion = nn.CrossEntropyLoss() 183 | 184 | 185 | # As the learning rate is task specific, the argument can be moved to json file 186 | optimizer = optim.RMSprop(model.parameters(), 187 | lr=args.lr, 188 | alpha=args.alpha, 189 | momentum=args.momentum) 190 | 191 | 192 | 193 | 194 | 195 | cur_dir = os.getcwd() 196 | 197 | 198 | # ---------------------------------------------------------------------------- 199 | # -- basic training loop 200 | # ---------------------------------------------------------------------------- 201 | losses = [] 202 | errors = [] 203 | rel_errors = [] 204 | loss_pls = [] 205 | 206 | best_loss = 10000 207 | 208 | print(args) 209 | 210 | torch.manual_seed(args.seed) 211 | torch.cuda.manual_seed(args.seed) 212 | np.random.seed(args.seed) 213 | random.seed(args.seed) 214 | if args.mode=="train": 215 | model.train() 216 | num_iter = args.num_iters 217 | print("===training===") 218 | configure(log_dir) 219 | elif args.mode=="test": 220 | num_iter = args.num_eval 221 | print("===testing===") 222 | print(DEVICE) 223 | model.load_state_dict(torch.load(save_dirlast,map_location=DEVICE)) 224 | model.eval() 225 | print(f"load weight {save_dirlast}") 226 | 227 | 228 | for iter in tqdm(range(num_iter)): 229 | annelr = iter*1.0/num_iter 230 | model.annelr = annelr 231 | optimizer.zero_grad() 232 | model.init_sequence(batch_size=args.batch_size) 233 | 234 | data = get_data(args.mode) 235 | out, target, input = model_compute(data) 236 | 237 | # ------------------------------------------------------------------------- 238 | loss = criterion(torch.reshape(out, [-1, dataset.out_dim]), 239 | torch.argmax(torch.reshape(target, [-1, dataset.out_dim]), -1)) 240 | loss = torch.mean(loss) 241 | 242 | 243 | losses.append(loss.item()) 244 | 245 | if args.mode=="train": 246 | 247 | 248 | loss.backward() 249 | if args.clip_grad > 0: 250 | nn.utils.clip_grad_value_(model.parameters(), args.clip_grad) 251 | optimizer.step() 252 | 253 | error, binary_output = get_err(out, target) 254 | errors.append(error.item()) 255 | rel_errors.append((error/target.shape[0]).item()) 256 | 257 | 258 | # ---logging--- 259 | if args.mode=="train" and iter % args.freq_val == 0: 260 | print('Iteration: %d\tLoss: %.2f\tError in bits per sequence: %.2f' % 261 | (iter, np.mean(losses), np.mean(errors))) 262 | mloss = np.mean(losses) 263 | 264 | 265 | log_value('Train/loss', mloss, iter) 266 | log_value('Train/bit_error_per_sequence', np.mean(errors), iter) 267 | log_value('Train/percentage_error', np.mean(rel_errors), iter) 268 | 269 | losses = [] 270 | rel_errors = [] 271 | errors = [] 272 | loss_pls = [] 273 | 274 | print("EVAL ...") 275 | model.eval() 276 | eval_err = [] 277 | eval_err2 = [] 278 | for i in range(args.num_eval): 279 | data = get_data("eval") 280 | out, target, input = model_compute(data, is_train=False) 281 | error, binary_output = get_err(out, target) 282 | eval_err.append((error/target.shape[0]).item()) 283 | eval_err2.append(error.item()) 284 | 285 | merror = np.mean(eval_err) 286 | if merror <=best_loss: 287 | # ---saving the model--- 288 | print("SAVE MODEL BEST TO:\n", save_dirbest) 289 | torch.save(model.state_dict(), save_dirbest) 290 | best_loss = merror 291 | log_value('Test/bit_error_per_sequence', np.mean(eval_err2), iter) 292 | 
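                # percentage_error below is the bit error normalized by the target sequence length
                # (error / target.shape[0]), averaged over the args.num_eval evaluation batches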
log_value('Test/percentage_error', merror, iter) 293 | model.train() 294 | 295 | if args.mode=="train": 296 | # ---saving the model--- 297 | print("SAVE LAST MODEL TO:\n", save_dirlast) 298 | torch.save(model.state_dict(), save_dirlast) 299 | 300 | if args.mode=="test": 301 | print('test_loss', np.mean(losses)) 302 | print(f'bit_error_per_sequence {np.mean(errors)} -->{np.mean(rel_errors)} over {len(errors)} samples') 303 | 304 | -------------------------------------------------------------------------------- /baselines/panm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from torch.nn import MultiheadAttention 8 | 9 | 10 | class PoiterUnit(nn.Module): 11 | def __init__(self, in_dim, hid_dim, emb_size, nheads=1, dropout=0, add_space=10): 12 | super(PoiterUnit, self).__init__() 13 | self.hid_dim = hid_dim 14 | self.emb_size = emb_size 15 | self.controller = nn.GRU(in_dim, hid_dim, batch_first=False) 16 | self.add_att = MultiheadAttention(embed_dim=emb_size, num_heads=nheads, dropout=dropout, kdim=add_space, vdim=emb_size) 17 | self.q = nn.Linear(hid_dim, emb_size) 18 | self.is_train = True 19 | self.A_map = nn.Linear(add_space, add_space) 20 | 21 | def forward(self, inputs, cur_dp, hidden, A): 22 | output, hidden = self.controller(cur_dp, hidden) 23 | cur_val, aweights = self.add_att(self.q(hidden),self.A_map(A), inputs) 24 | cur_add = torch.matmul(aweights, A.view(A.shape[1],A.shape[0],-1)).squeeze(1) 25 | return cur_val, cur_add, hidden, aweights 26 | 27 | 28 | 29 | class Encoder(nn.Module): 30 | def __init__(self, input_size, embed_size, hidden_size, 31 | n_layers=1, dropout=0.5, embedded=False): 32 | super(Encoder, self).__init__() 33 | self.input_size = input_size 34 | self.hidden_size = hidden_size 35 | self.embed_size = embed_size 36 | self.embedded = embedded 37 | self.embed = nn.Embedding(input_size, embed_size) 38 | self.gru = nn.GRU(embed_size, hidden_size, n_layers, 39 | dropout=dropout, bidirectional=True) 40 | 41 | def forward(self, src, hidden=None): 42 | if not self.embedded: 43 | embedded = self.embed(src) 44 | else: 45 | embedded = src 46 | outputs, hidden = self.gru(embedded, hidden) 47 | outputs = (outputs[:, :, :self.hidden_size] + 48 | outputs[:, :, self.hidden_size:]) 49 | return outputs, hidden 50 | 51 | 52 | class Attention(nn.Module): 53 | def __init__(self, hidden_size, memory_size): 54 | super(Attention, self).__init__() 55 | self.hidden_size = hidden_size 56 | self.attn = nn.Linear(self.hidden_size + memory_size, memory_size) 57 | self.v = nn.Parameter(torch.rand(memory_size)) 58 | stdv = 1. 
/ math.sqrt(self.v.size(0)) 59 | self.v.data.uniform_(-stdv, stdv) 60 | 61 | def forward(self, hidden, encoder_outputs): 62 | timestep = encoder_outputs.size(0) 63 | h = hidden.repeat(timestep, 1, 1).transpose(0, 1) 64 | encoder_outputs = encoder_outputs.transpose(0, 1) # [B*T*H] 65 | attn_energies = self.score(h, encoder_outputs) 66 | return F.softmax(attn_energies, dim=1).unsqueeze(1) 67 | 68 | def score(self, hidden, encoder_outputs): 69 | # [B*T*2H]->[B*T*H] 70 | energy = F.relu(self.attn(torch.cat([hidden, encoder_outputs], 2))) 71 | energy = energy.transpose(1, 2) # [B*H*T] 72 | v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1) # [B*1*H] 73 | energy = torch.bmm(v, energy) # [B*1*T] 74 | return energy.squeeze(1) # [B*T] 75 | 76 | 77 | class Decoder(nn.Module): 78 | def __init__(self, embed_size, hidden_size, memory_size, output_size, 79 | n_layers=1, dropout=0.2, embedded=False, dis_att=False): 80 | super(Decoder, self).__init__() 81 | self.embed_size = embed_size 82 | self.hidden_size = hidden_size 83 | self.output_size = output_size 84 | self.n_layers = n_layers 85 | self.embedded = embedded 86 | self.embed = nn.Embedding(output_size, embed_size) 87 | self.dropout = nn.Dropout(dropout, inplace=True) 88 | self.dis_att = dis_att 89 | if not self.dis_att: 90 | self.attention = Attention(hidden_size, memory_size) 91 | else: 92 | memory_size = 0 93 | self.out = nn.Linear(hidden_size + memory_size , output_size) 94 | 95 | self.gru = nn.GRU(memory_size + embed_size , hidden_size, 96 | n_layers, dropout=dropout) 97 | 98 | def forward(self, input, last_hidden, encoder_outputs): 99 | # Get the embedding of the current input word (last output word) 100 | if not self.embedded: 101 | embedded = self.embed(input).unsqueeze(0) # (1,B,N) 102 | embedded = self.dropout(embedded) 103 | else: 104 | embedded = input 105 | # Calculate attention weights and apply to encoder outputs 106 | if not self.dis_att: 107 | attn_weights = self.attention(last_hidden[-1], encoder_outputs) 108 | context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # (B,1,N) 109 | context = context.transpose(0, 1) # (1,B,N) 110 | # Combine embedded input word and attended context, run through RNN 111 | rnn_input = torch.cat([embedded, context], 2) 112 | output, hidden = self.gru(rnn_input, last_hidden) 113 | output = output.squeeze(0) # (1,B,N) -> (B,N) 114 | context = context.squeeze(0) 115 | output = self.out(torch.cat([output, context], 1)) 116 | else: 117 | output, hidden = self.gru(embedded, last_hidden) 118 | output = self.out(output).squeeze(0) 119 | attn_weights = None 120 | 121 | return output, hidden, attn_weights 122 | 123 | 124 | class PANM(nn.Module): 125 | def __init__(self, in_size, out_size, embed_size, controller_size = 512, 126 | layers = 1, nheads = 1,hidden_dim=512,dropout=0, 127 | embedded=True, add_space=8, Hc_num_pointers=1, Ha_num_pointers=2): 128 | super(PANM, self).__init__() 129 | hidden_size = hidden_dim 130 | self.Ha_num_pointers=Ha_num_pointers 131 | self.Hc_num_pointers=Hc_num_pointers 132 | self.hidden_size = hidden_dim 133 | self.embedded = embedded 134 | self.embed_size = embed_size 135 | self.encoder = Encoder(in_size, embed_size, hidden_size, 136 | n_layers=1, dropout=0.5, embedded=embedded) 137 | decoder = Decoder(embed_size+(hidden_size)*(self.Ha_num_pointers), controller_size, hidden_size, hidden_size, 138 | n_layers=1, dropout=0, embedded=embedded) 139 | self.decoder = decoder 140 | self.address_space = add_space 141 | self.num_address = 2**self.address_space 142 | self.A, _ = 
self.gen_address(self.num_address, 1,is_random=False) 143 | self.A=self.A.squeeze(1) 144 | self.dropout = nn.Dropout(dropout) 145 | self.out = nn.Linear(hidden_size*(self.Ha_num_pointers+1+self.Hc_num_pointers), out_size) 146 | punits = [] 147 | for _ in range(self.Ha_num_pointers): 148 | addgen = PoiterUnit(add_space, controller_size, hidden_size, add_space=add_space) 149 | punits.append(addgen) 150 | self.punits = nn.ModuleList(punits) 151 | self.add_att = MultiheadAttention(embed_dim=hidden_size, num_heads=1,dropout=0.5, kdim=hidden_size, vdim=hidden_size) 152 | tohs = [] 153 | for _ in range(self.Hc_num_pointers): 154 | toh = nn.Linear(hidden_size*self.Ha_num_pointers, hidden_size) 155 | tohs.append(toh) 156 | self.tohs = nn.ModuleList(tohs) 157 | 158 | 159 | def binary(self, x): 160 | mask = 2**torch.arange(self.address_space) 161 | return torch.tensor(x).unsqueeze(-1).bitwise_and(mask).ne(0).byte() 162 | 163 | def gen_address(self, seq_len, bs, is_random=True): 164 | A = torch.zeros(seq_len, bs, self.address_space) 165 | start_a=torch.zeros(bs, self.address_space) 166 | end_a=torch.zeros(bs, self.address_space) 167 | content_a=torch.zeros(bs, self.address_space) 168 | for b in range(bs): 169 | start_p=0 170 | if is_random: 171 | start_p=random.randint(0, self.num_address) 172 | for i in range(seq_len): 173 | j=(i+start_p)%self.num_address 174 | A[i,b,:]=self.binary(j) 175 | if i==0: 176 | start_a[b] = A[i,b,:] 177 | if i==seq_len-1: 178 | end_a[b] = A[i,b,:] 179 | return A, torch.cat([start_a, end_a, content_a],dim=-1) 180 | 181 | def fast_gen_address(self, seq_len, bs, is_fix=False): 182 | A = torch.zeros(seq_len, bs, self.address_space) 183 | start_a=torch.zeros(bs, self.address_space) 184 | end_a=torch.zeros(bs, self.address_space) 185 | content_a=torch.zeros(bs, self.address_space) 186 | for b in range(bs): 187 | if not is_fix: 188 | start_p = random.randint(0, self.num_address) 189 | else: 190 | start_p = 0 191 | A[:,b,:] = torch.roll(self.A, start_p, dims=0)[:seq_len] 192 | start_a[b] = A[0,b,:] 193 | end_a[b] = A[seq_len-1,b,:] 194 | return A, start_a, end_a, content_a 195 | 196 | 197 | def forward(self, src, target_length, tgt=None): 198 | trg=tgt 199 | encoder_output, hidden = self.encoder(src) 200 | 201 | max_len = target_length 202 | batch_size = src.size(1) 203 | if trg is not None: 204 | max_len = trg.size(0) 205 | else: 206 | trg = torch.zeros(max_len, batch_size, self.decoder.output_size).to(src.device) 207 | if self.embedded: 208 | trg = Variable(torch.zeros(max_len, batch_size, self.embed_size)).to(src.device) 209 | else: 210 | trg = Variable(torch.zeros(max_len, batch_size).long()).to(src.device) 211 | outputs = [] 212 | 213 | 214 | hidden = hidden[:self.decoder.n_layers] 215 | output = Variable(trg.data[0, :]) # sos 216 | A, cur_ptr1, cur_ptr2, cur_ptr3 = self.fast_gen_address(encoder_output.shape[0], batch_size) 217 | A = A.to(encoder_output.device) 218 | cur_ptr1 = cur_ptr1.to(encoder_output.device) 219 | cur_ptr2 = cur_ptr2.to(encoder_output.device) 220 | cur_ptr3 = cur_ptr3.to(encoder_output.device) 221 | hiddena = [] 222 | 223 | 224 | 225 | cur_ptrs_mode1 = [] 226 | for i in range(self.Ha_num_pointers): 227 | if i%2==0: 228 | cur_ptrs_mode1.append(cur_ptr1) 229 | else: 230 | cur_ptrs_mode1.append(cur_ptr2) 231 | hiddena.append(hidden.clone().detach().zero_()) 232 | 233 | cur_vals_mode2 = [] 234 | cur_ptrs_mode2 = [] 235 | 236 | for _ in range(self.Hc_num_pointers): 237 | cur_vals_mode2.append(torch.zeros(1, batch_size, 
self.hidden_size).to(encoder_output.device)) 238 | cur_ptrs_mode2.append(cur_ptr3.clone()) 239 | cur_val_mode2 = torch.cat(cur_vals_mode2, dim=-1) 240 | cur_ptr_mode2 = torch.cat(cur_ptrs_mode2, dim=-1) 241 | 242 | for t in range(0, max_len): 243 | cur_vals_mode1 = [] 244 | 245 | for i in range(self.Ha_num_pointers): 246 | cur_val_mode1, cur_add, hiddena[i], aweights1 = self.punits[i](encoder_output, cur_ptrs_mode1[i].unsqueeze(0), hiddena[i], A) 247 | cur_vals_mode1.append(cur_val_mode1) 248 | cur_ptrs_mode1[i] = cur_add 249 | 250 | cur_val_mode1 = torch.cat(cur_vals_mode1,dim=-1) 251 | 252 | 253 | output, hidden, attn_weights = self.decoder( 254 | torch.cat([output.unsqueeze(0), cur_val_mode1], dim=-1), 255 | hidden, encoder_output) 256 | 257 | 258 | cur_vals_mode2 = [] 259 | cur_ptrs_mode2 = [] 260 | for i in range(self.Hc_num_pointers): 261 | 262 | cur_val_mode2, aweights2 = self.add_att(self.tohs[i](cur_val_mode1.squeeze(0)).unsqueeze(0), encoder_output, encoder_output) 263 | cur_ptr_mode2 = torch.matmul(aweights2, A.view(A.shape[1],A.shape[0],-1)).squeeze(1) 264 | cur_vals_mode2.append(cur_val_mode2) 265 | cur_ptrs_mode2.append(cur_ptr_mode2) 266 | 267 | 268 | cur_val_mode2 = torch.cat(cur_vals_mode2, dim=-1) 269 | cur_ptr_mode2 = torch.cat(cur_ptrs_mode2, dim=-1) 270 | 271 | fout = torch.cat([output, cur_val_mode1.squeeze(0), cur_val_mode2.squeeze(0)],dim=-1) 272 | output = self.out(self.dropout(fout)) 273 | 274 | 275 | outputs.append(output) 276 | output = trg[t] 277 | 278 | return torch.stack(outputs), None 279 | 280 | 281 | def calculate_num_params(self): 282 | """Returns the total number of parameters.""" 283 | num_params = 0 284 | for p in self.parameters(): 285 | num_params += p.data.view(-1).size(0) 286 | return num_params 287 | 288 | def init_sequence(self, batch_size): 289 | pass -------------------------------------------------------------------------------- /baselines/nvm/head.py: -------------------------------------------------------------------------------- 1 | """NTM Read and Write Heads.""" 2 | from baselines.nvm.util import * 3 | 4 | def sample_gumbel(shape, eps=1e-20): 5 | U = torch.rand(shape).cuda() 6 | return -Variable(torch.log(-torch.log(U + eps) + eps)) 7 | 8 | def gumbel_softmax_sample(logits, temperature): 9 | y = logits + sample_gumbel(logits.size()) 10 | return F.softmax(y / temperature, dim=-1) 11 | 12 | def gumbel_softmax(logits, temperature): 13 | """ 14 | ST-gumple-softmax 15 | input: [*, n_class] 16 | return: flatten --> [*, n_class] an one-hot vector 17 | """ 18 | y = gumbel_softmax_sample(logits, temperature) 19 | shape = y.size() 20 | _, ind = y.max(dim=-1) 21 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 22 | y_hard.scatter_(1, ind.view(-1, 1), 1) 23 | y_hard = y_hard.view(*shape) 24 | y_hard = (y_hard - y).detach() + y 25 | return y_hard 26 | 27 | def _split_cols(mat, lengths): 28 | """Split a 2D matrix to variable length columns.""" 29 | assert mat.size()[1] == sum(lengths), "Lengths must be summed to num columns" 30 | l = np.cumsum([0] + lengths) 31 | results = [] 32 | for s, e in zip(l[:-1], l[1:]): 33 | results += [mat[:, s:e]] 34 | return results 35 | 36 | 37 | class NTMHeadBase(nn.Module): 38 | """An NTM Read/Write Head.""" 39 | 40 | def __init__(self, memory, controller_size): 41 | """Initilize the read/write head. 42 | :param memory: The :class:`NTMMemory` to be addressed by the head. 43 | :param controller_size: The size of the internal representation. 
44 | """ 45 | super(NTMHeadBase, self).__init__() 46 | 47 | self.memory = memory 48 | self.N, self.M = memory.size() 49 | self.controller_size = controller_size 50 | self.cs = [] 51 | self.program_weights=[] 52 | self.query_keys=[] 53 | self.query_strengths=[] 54 | self.data_weights=[] 55 | self.att_mode = "kv" 56 | 57 | 58 | def create_new_state(self, batch_size): 59 | raise NotImplementedError 60 | 61 | def register_parameters(self): 62 | raise NotImplementedError 63 | 64 | def is_read_head(self): 65 | return NotImplementedError 66 | 67 | def _address_memory(self, k, β, g, s, γ, w_prev): 68 | # Handle Activations 69 | k = k.clone() 70 | β = F.softplus(β) 71 | g = F.sigmoid(g) 72 | s = F.softmax(s, dim=1) 73 | γ = 1 + F.softplus(γ) 74 | 75 | w = self.memory.address(k, β, g, s, γ, w_prev) 76 | 77 | return w 78 | 79 | def read_mem(self, memory, read_weights, key_size): 80 | return torch.bmm(read_weights, memory[:,:,key_size:]) 81 | 82 | def content_weightings(self, memory, keys, strengths, key_size, program_mask=None): 83 | if key_size>0: 84 | if self.att_mode=="kv": 85 | d = θ(F.tanh(memory[:,:,:key_size]), F.tanh(keys[:,:,:key_size])) 86 | else: 87 | d = keys 88 | else: 89 | d = θ(memory, keys) 90 | # print(memory[:,:,:key_size]) 91 | d = σ(d * strengths.unsqueeze(2), 2) 92 | if program_mask is not None: 93 | # # d = torch.abs(d*program_mask) 94 | # # print(d) 95 | # # d2 = torch.zeros(d.shape).cuda() 96 | # # _, di = d.max(-1) 97 | # # d2[:,:,di]=1 98 | # # d=d2 99 | d = gumbel_softmax(d, 10) 100 | # # print(d) 101 | return d 102 | 103 | class NTMReadHead(NTMHeadBase): 104 | def __init__(self, memory, controller_size, program_size=2, pkey_dim=2): 105 | super(NTMReadHead, self).__init__(memory, controller_size) 106 | self.program_size = program_size 107 | self.program_mask = None 108 | self.pkey_dim = pkey_dim 109 | # Corresponding to k, β, g, s, γ sizes from the paper 110 | self.read_lengths = [self.M, 1, 1, 3, 1] 111 | self.layernorm = nn.GroupNorm(1, sum(self.read_lengths)) 112 | 113 | if self.program_size>0: 114 | self.program_key = nn.Linear(controller_size, self.pkey_dim) 115 | self.program_strength = nn.Linear(controller_size, 1) 116 | 117 | self.instruction_weight = nn.Parameter(torch.zeros(self.program_size, 118 | self.pkey_dim + 119 | (self.controller_size+1)*sum(self.read_lengths), 120 | requires_grad=True)) 121 | else: 122 | self.fc_read = nn.Linear(controller_size, sum(self.read_lengths)) 123 | 124 | self.reset_parameters() 125 | 126 | def create_new_state(self, batch_size): 127 | self.cs = [] 128 | self.program_weights = [] 129 | self.query_keys = [] 130 | self.query_strengths=[] 131 | self.data_weights = [] 132 | # The state holds the previous time step address weightings 133 | if torch.cuda.is_available(): 134 | return torch.zeros(batch_size, self.N).cuda() 135 | else: 136 | return torch.zeros(batch_size, self.N) 137 | 138 | 139 | def reset_parameters(self): 140 | # Initialize the linear layers 141 | if self.program_size>0: 142 | nn.init.xavier_uniform_(self.instruction_weight, gain=1.4) 143 | nn.init.xavier_uniform_(self.program_key.weight, gain=1.4) 144 | nn.init.normal_(self.program_key.bias, std=0.01) 145 | nn.init.xavier_uniform_(self.program_strength.weight, gain=1.4) 146 | nn.init.normal_(self.program_strength.bias, std=0.01) 147 | else: 148 | nn.init.xavier_uniform_(self.fc_read.weight, gain=1.4) 149 | nn.init.normal_(self.fc_read.bias, std=0.01) 150 | 151 | 152 | def is_read_head(self): 153 | return True 154 | 155 | def forward(self, embeddings, w_prev): 156 | 
"""NTMReadHead forward function. 157 | :param embeddings: input representation of the controller. 158 | :param w_prev: previous step state 159 | """ 160 | self.cs.append(embeddings[0]) 161 | if self.program_size>0: 162 | # if len(self.query_keys)>0: 163 | # read_keys = self.query_keys[0] 164 | # read_strengths = self.query_strengths[0] 165 | # else: 166 | read_keys = self.program_key(embeddings) 167 | read_strengths = F.softplus(self.program_strength(embeddings)) 168 | content_weights = self.content_weightings(self.instruction_weight.unsqueeze(0).repeat(read_keys.shape[0],1, 1), 169 | read_keys.unsqueeze(1), 170 | read_strengths, self.pkey_dim, self.program_mask) 171 | instruction = self.read_mem(self.instruction_weight.unsqueeze(0).repeat(read_keys.shape[0],1, 1), 172 | content_weights, self.pkey_dim) 173 | i_w = instruction[:,:,:self.controller_size*sum(self.read_lengths)].view(-1, self.controller_size, sum(self.read_lengths)) 174 | i_b = instruction[:,:,self.controller_size*sum(self.read_lengths):].view(-1, 1, sum(self.read_lengths)) 175 | 176 | o = (torch.matmul(embeddings.unsqueeze(1), i_w)+i_b).squeeze(1) 177 | self.program_weights.append(content_weights) 178 | # print(content_weights) 179 | self.query_keys.append(read_keys) 180 | self.query_strengths.append(read_strengths) 181 | else: 182 | o = self.fc_read(embeddings) 183 | # o = self.layernorm(o) 184 | k, β, g, s, γ = _split_cols(o, self.read_lengths) 185 | 186 | # Read from memory 187 | w = self._address_memory(k, β, g, s, γ, w_prev) 188 | r = self.memory.read(w) 189 | self.data_weights.append(w) 190 | return r, w 191 | 192 | 193 | class NTMWriteHead(NTMHeadBase): 194 | def __init__(self, memory, controller_size, program_size=2, pkey_dim=2): 195 | super(NTMWriteHead, self).__init__(memory, controller_size) 196 | 197 | # Corresponding to k, β, g, s, γ, e, a sizes from the paper 198 | self.write_lengths = [self.M, 1, 1, 3, 1, self.M, self.M] 199 | self.program_size = program_size 200 | self.program_mask = None 201 | self.pkey_dim = pkey_dim 202 | self.layernorm = nn.GroupNorm(1, sum(self.write_lengths)) 203 | 204 | if self.program_size>0: 205 | self.program_key = nn.Linear(controller_size, self.pkey_dim) 206 | self.program_strength = nn.Linear(controller_size, 1) 207 | 208 | self.instruction_weight = nn.Parameter(torch.zeros(self.program_size, 209 | self.pkey_dim + 210 | (self.controller_size+1)*sum(self.write_lengths), 211 | requires_grad=True)) 212 | else: 213 | self.fc_write = nn.Linear(controller_size, sum(self.write_lengths)) 214 | self.reset_parameters() 215 | 216 | def create_new_state(self, batch_size): 217 | self.cs = [] 218 | self.program_weights = [] 219 | self.query_keys = [] 220 | self.query_strengths = [] 221 | self.data_weights = [] 222 | if torch.cuda.is_available(): 223 | return torch.zeros(batch_size, self.N).cuda() 224 | else: 225 | return torch.zeros(batch_size, self.N) 226 | 227 | def reset_parameters(self): 228 | # Initialize the linear layers 229 | if self.program_size>0: 230 | nn.init.xavier_uniform_(self.instruction_weight, gain=1.4) 231 | nn.init.xavier_uniform_(self.program_key.weight, gain=1.4) 232 | nn.init.normal_(self.program_key.bias, std=0.01) 233 | nn.init.xavier_uniform_(self.program_strength.weight, gain=1.4) 234 | nn.init.normal_(self.program_strength.bias, std=0.01) 235 | else: 236 | nn.init.xavier_uniform_(self.fc_write.weight, gain=1.4) 237 | nn.init.normal_(self.fc_write.bias, std=0.01) 238 | 239 | 240 | def is_read_head(self): 241 | return False 242 | 243 | def forward(self, embeddings, 
w_prev): 244 | """NTMWriteHead forward function. 245 | :param embeddings: input representation of the controller. 246 | :param w_prev: previous step state 247 | """ 248 | self.cs.append(embeddings[0]) 249 | if self.program_size>0: 250 | # if len(self.query_keys)>0: 251 | # read_keys = self.query_keys[0] 252 | # read_strengths = self.query_strengths[0] 253 | # else: 254 | read_keys = self.program_key(embeddings) 255 | read_strengths = F.softplus(self.program_strength(embeddings)) 256 | content_weights = self.content_weightings(self.instruction_weight.unsqueeze(0).repeat(read_keys.shape[0],1, 1), 257 | read_keys.unsqueeze(1), 258 | read_strengths, self.pkey_dim, self.program_mask) 259 | instruction = self.read_mem(self.instruction_weight.unsqueeze(0).repeat(read_keys.shape[0], 1, 1), 260 | content_weights, self.pkey_dim) 261 | 262 | i_w = instruction[:, :, :self.controller_size * sum(self.write_lengths)].view(-1, self.controller_size, 263 | sum(self.write_lengths)) 264 | i_b = instruction[:, :, self.controller_size * sum(self.write_lengths):].view(-1, 1, sum(self.write_lengths)) 265 | 266 | o = (torch.matmul(embeddings.unsqueeze(1), i_w) + i_b).squeeze(1) 267 | self.program_weights.append(content_weights) 268 | # print(content_weights) 269 | self.query_keys.append(read_keys) 270 | self.query_strengths.append(read_strengths) 271 | # u, s, d = torch.svd(instruction[0]) 272 | # print(s[0]) 273 | else: 274 | o = self.fc_write(embeddings) 275 | # u, s, d = torch.svd(self.fc_write.weight) 276 | # print(s[0]) 277 | 278 | # o = self.layernorm(o) 279 | k, β, g, s, γ, e, a = _split_cols(o, self.write_lengths) 280 | 281 | # e should be in [0, 1] 282 | e = F.sigmoid(e) 283 | 284 | # Write to memory 285 | w = self._address_memory(k, β, g, s, γ, w_prev) 286 | self.memory.write(w, e, a) 287 | self.data_weights.append(w) 288 | return w -------------------------------------------------------------------------------- /baselines/dnc/memory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch.nn as nn 5 | import torch as T 6 | from torch.autograd import Variable as var 7 | import torch.nn.functional as F 8 | import numpy as np 9 | 10 | from .util import * 11 | 12 | 13 | class Memory(nn.Module): 14 | 15 | def __init__(self, input_size, mem_size=512, cell_size=32, read_heads=4, gpu_id=-1, independent_linears=True): 16 | super(Memory, self).__init__() 17 | 18 | self.mem_size = mem_size 19 | self.cell_size = cell_size 20 | self.read_heads = read_heads 21 | self.gpu_id = gpu_id 22 | self.input_size = input_size 23 | self.independent_linears = independent_linears 24 | 25 | m = self.mem_size 26 | w = self.cell_size 27 | r = self.read_heads 28 | 29 | if self.independent_linears: 30 | self.read_keys_transform = nn.Linear(self.input_size, w * r) 31 | self.read_strengths_transform = nn.Linear(self.input_size, r) 32 | self.write_key_transform = nn.Linear(self.input_size, w) 33 | self.write_strength_transform = nn.Linear(self.input_size, 1) 34 | self.erase_vector_transform = nn.Linear(self.input_size, w) 35 | self.write_vector_transform = nn.Linear(self.input_size, w) 36 | self.free_gates_transform = nn.Linear(self.input_size, r) 37 | self.allocation_gate_transform = nn.Linear(self.input_size, 1) 38 | self.write_gate_transform = nn.Linear(self.input_size, 1) 39 | self.read_modes_transform = nn.Linear(self.input_size, 3 * r) 40 | else: 41 | self.interface_size = (w * r) + (3 * w) + (5 * r) + 3 42 | 
self.interface_weights = nn.Linear(self.input_size, self.interface_size) 43 | 44 | self.I = cuda(1 - T.eye(m).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n) 45 | 46 | def reset(self, batch_size=1, hidden=None, erase=True): 47 | m = self.mem_size 48 | w = self.cell_size 49 | r = self.read_heads 50 | b = batch_size 51 | 52 | if hidden is None: 53 | return { 54 | 'memory': cuda(T.zeros(b, m, w).fill_(0), gpu_id=self.gpu_id), 55 | 'link_matrix': cuda(T.zeros(b, 1, m, m), gpu_id=self.gpu_id), 56 | 'precedence': cuda(T.zeros(b, 1, m), gpu_id=self.gpu_id), 57 | 'read_weights': cuda(T.zeros(b, r, m).fill_(0), gpu_id=self.gpu_id), 58 | 'write_weights': cuda(T.zeros(b, 1, m).fill_(0), gpu_id=self.gpu_id), 59 | 'usage_vector': cuda(T.zeros(b, m), gpu_id=self.gpu_id) 60 | } 61 | else: 62 | hidden['memory'] = hidden['memory'].clone() 63 | hidden['link_matrix'] = hidden['link_matrix'].clone() 64 | hidden['precedence'] = hidden['precedence'].clone() 65 | hidden['read_weights'] = hidden['read_weights'].clone() 66 | hidden['write_weights'] = hidden['write_weights'].clone() 67 | hidden['usage_vector'] = hidden['usage_vector'].clone() 68 | 69 | if erase: 70 | hidden['memory'].data.fill_(0) 71 | hidden['link_matrix'].data.zero_() 72 | hidden['precedence'].data.zero_() 73 | hidden['read_weights'].data.fill_(0) 74 | hidden['write_weights'].data.fill_(0) 75 | hidden['usage_vector'].data.zero_() 76 | return hidden 77 | 78 | def get_usage_vector(self, usage, free_gates, read_weights, write_weights): 79 | # write_weights = write_weights.detach() # detach from the computation graph 80 | usage = usage + (1 - usage) * (1 - T.prod(1 - write_weights, 1)) 81 | ψ = T.prod(1 - free_gates.unsqueeze(2) * read_weights, 1) 82 | return usage * ψ 83 | 84 | def allocate(self, usage, write_gate): 85 | # ensure values are not too small prior to cumprod. 
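        # δ is a small positive constant brought in by `from .util import *`; this maps usage into
        # [δ, 1] so every factor of the cumulative product computed below stays nonzero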
86 | usage = δ + (1 - δ) * usage 87 | batch_size = usage.size(0) 88 | # free list 89 | sorted_usage, φ = T.topk(usage, self.mem_size, dim=1, largest=False) 90 | 91 | # cumprod with exclusive=True 92 | # https://discuss.pytorch.org/t/cumprod-exclusive-true-equivalences/2614/8 93 | v = var(sorted_usage.data.new(batch_size, 1).fill_(1)) 94 | cat_sorted_usage = T.cat((v, sorted_usage), 1) 95 | prod_sorted_usage = T.cumprod(cat_sorted_usage, 1)[:, :-1] 96 | 97 | sorted_allocation_weights = (1 - sorted_usage) * prod_sorted_usage.squeeze() 98 | 99 | # construct the reverse sorting index https://stackoverflow.com/questions/2483696/undo-or-reverse-argsort-python 100 | _, φ_rev = T.topk(φ, k=self.mem_size, dim=1, largest=False) 101 | allocation_weights = sorted_allocation_weights.gather(1, φ_rev.long()) 102 | 103 | return allocation_weights.unsqueeze(1), usage 104 | 105 | def write_weighting(self, memory, write_content_weights, allocation_weights, write_gate, allocation_gate): 106 | ag = allocation_gate.unsqueeze(-1) 107 | wg = write_gate.unsqueeze(-1) 108 | 109 | return wg * (ag * allocation_weights + (1 - ag) * write_content_weights) 110 | 111 | def get_link_matrix(self, link_matrix, write_weights, precedence): 112 | precedence = precedence.unsqueeze(2) 113 | write_weights_i = write_weights.unsqueeze(3) 114 | write_weights_j = write_weights.unsqueeze(2) 115 | 116 | prev_scale = 1 - write_weights_i - write_weights_j 117 | new_link_matrix = write_weights_i * precedence 118 | 119 | link_matrix = prev_scale * link_matrix + new_link_matrix 120 | # trick to delete diag elems 121 | return self.I.expand_as(link_matrix) * link_matrix 122 | 123 | def update_precedence(self, precedence, write_weights): 124 | return (1 - T.sum(write_weights, 2, keepdim=True)) * precedence + write_weights 125 | 126 | def write(self, write_key, write_vector, erase_vector, free_gates, read_strengths, write_strength, write_gate, 127 | allocation_gate, hidden): 128 | # get current usage 129 | hidden['usage_vector'] = self.get_usage_vector( 130 | hidden['usage_vector'], 131 | free_gates, 132 | hidden['read_weights'], 133 | hidden['write_weights'] 134 | ) 135 | 136 | # lookup memory with write_key and write_strength 137 | write_content_weights = self.content_weightings(hidden['memory'], write_key, write_strength) 138 | 139 | # get memory allocation 140 | alloc, _ = self.allocate( 141 | hidden['usage_vector'], 142 | allocation_gate * write_gate 143 | ) 144 | 145 | # get write weightings 146 | hidden['write_weights'] = self.write_weighting( 147 | hidden['memory'], 148 | write_content_weights, 149 | alloc, 150 | write_gate, 151 | allocation_gate 152 | ) 153 | 154 | weighted_resets = hidden['write_weights'].unsqueeze(3) * erase_vector.unsqueeze(2) 155 | reset_gate = T.prod(1 - weighted_resets, 1) 156 | # Update memory 157 | hidden['memory'] = hidden['memory'] * reset_gate 158 | 159 | hidden['memory'] = hidden['memory'] + \ 160 | T.bmm(hidden['write_weights'].transpose(1, 2), write_vector) 161 | 162 | # update link_matrix 163 | hidden['link_matrix'] = self.get_link_matrix( 164 | hidden['link_matrix'], 165 | hidden['write_weights'], 166 | hidden['precedence'] 167 | ) 168 | hidden['precedence'] = self.update_precedence(hidden['precedence'], hidden['write_weights']) 169 | 170 | return hidden 171 | 172 | def content_weightings(self, memory, keys, strengths): 173 | d = θ(memory, keys) 174 | return σ(d * strengths.unsqueeze(2), 2) 175 | 176 | def directional_weightings(self, link_matrix, read_weights): 177 | rw = read_weights.unsqueeze(1) 178 
| 179 | f = T.matmul(link_matrix, rw.transpose(2, 3)).transpose(2, 3) 180 | b = T.matmul(rw, link_matrix) 181 | return f.transpose(1, 2), b.transpose(1, 2) 182 | 183 | def read_weightings(self, memory, content_weights, link_matrix, read_modes, read_weights): 184 | forward_weight, backward_weight = self.directional_weightings(link_matrix, read_weights) 185 | 186 | content_mode = read_modes[:, :, 2].contiguous().unsqueeze(2) * content_weights 187 | backward_mode = T.sum(read_modes[:, :, 0:1].contiguous().unsqueeze(3) * backward_weight, 2) 188 | forward_mode = T.sum(read_modes[:, :, 1:2].contiguous().unsqueeze(3) * forward_weight, 2) 189 | 190 | return backward_mode + content_mode + forward_mode 191 | 192 | def read_vectors(self, memory, read_weights): 193 | return T.bmm(read_weights, memory) 194 | 195 | def read(self, read_keys, read_strengths, read_modes, hidden): 196 | content_weights = self.content_weightings(hidden['memory'], read_keys, read_strengths) 197 | 198 | hidden['read_weights'] = self.read_weightings( 199 | hidden['memory'], 200 | content_weights, 201 | hidden['link_matrix'], 202 | read_modes, 203 | hidden['read_weights'] 204 | ) 205 | read_vectors = self.read_vectors(hidden['memory'], hidden['read_weights']) 206 | return read_vectors, hidden 207 | 208 | def forward(self, ξ, hidden): 209 | 210 | # ξ = ξ.detach() 211 | m = self.mem_size 212 | w = self.cell_size 213 | r = self.read_heads 214 | b = ξ.size()[0] 215 | 216 | if self.independent_linears: 217 | # r read keys (b * r * w) 218 | read_keys = F.tanh(self.read_keys_transform(ξ).view(b, r, w)) 219 | # r read strengths (b * r) 220 | read_strengths = F.softplus(self.read_strengths_transform(ξ).view(b, r)) 221 | # write key (b * 1 * w) 222 | write_key = F.tanh(self.write_key_transform(ξ).view(b, 1, w)) 223 | # write strength (b * 1) 224 | write_strength = F.softplus(self.write_strength_transform(ξ).view(b, 1)) 225 | # erase vector (b * 1 * w) 226 | erase_vector = F.sigmoid(self.erase_vector_transform(ξ).view(b, 1, w)) 227 | # write vector (b * 1 * w) 228 | write_vector = F.tanh(self.write_vector_transform(ξ).view(b, 1, w)) 229 | # r free gates (b * r) 230 | free_gates = F.sigmoid(self.free_gates_transform(ξ).view(b, r)) 231 | # allocation gate (b * 1) 232 | allocation_gate = F.sigmoid(self.allocation_gate_transform(ξ).view(b, 1)) 233 | # write gate (b * 1) 234 | write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1)) 235 | # read modes (b * r * 3) 236 | read_modes = σ(self.read_modes_transform(ξ).view(b, r, 3), 1) 237 | else: 238 | ξ = self.interface_weights(ξ) 239 | # r read keys (b * w * r) 240 | read_keys = F.tanh(ξ[:, :r * w].contiguous().view(b, r, w)) 241 | # r read strengths (b * r) 242 | read_strengths = F.softplus(ξ[:, r * w:r * w + r].contiguous().view(b, r)) 243 | # write key (b * w * 1) 244 | write_key = F.tanh(ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w)) 245 | # write strength (b * 1) 246 | write_strength = F.softplus(ξ[:, r * w + r + w].contiguous().view(b, 1)) 247 | # erase vector (b * w) 248 | erase_vector = F.sigmoid(ξ[:, r * w + r + w + 1: r * w + r + 2 * w + 1].contiguous().view(b, 1, w)) 249 | # write vector (b * w) 250 | write_vector = F.tanh(ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w)) 251 | # r free gates (b * r) 252 | free_gates = F.sigmoid(ξ[:, r * w + r + 3 * w + 1: r * w + 2 * r + 3 * w + 1].contiguous().view(b, r)) 253 | # allocation gate (b * 1) 254 | allocation_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 1].contiguous().unsqueeze(1).view(b, 1)) 
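            # the slice offsets above and below follow the interface layout fixed in __init__:
            # interface_size = (w*r) + 3*w + 5*r + 3 = r*w read keys + r read strengths + w write key
            # + 1 write strength + w erase vector + w write vector + r free gates + 1 allocation gate
            # + 1 write gate + 3*r read modes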
255 | # write gate (b * 1) 256 | write_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1) 257 | # read modes (b * 3*r) 258 | read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 3: r * w + 5 * r + 3 * w + 3].contiguous().view(b, r, 3), 1) 259 | 260 | hidden = self.write(write_key, write_vector, erase_vector, free_gates, 261 | read_strengths, write_strength, write_gate, allocation_gate, hidden) 262 | return self.read(read_keys, read_strengths, read_modes, hidden) 263 | -------------------------------------------------------------------------------- /baselines/dnc/sp_memory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch.nn as nn 5 | import torch as T 6 | from torch.autograd import Variable as var 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import math 10 | from .util import * 11 | 12 | 13 | class Memory(nn.Module): 14 | 15 | def __init__(self, input_size, mem_size=512, cell_size=32, 16 | key_size=0, read_heads=4, gpu_id=-1, key_program_size=0, 17 | program_size=0, 18 | deallocate=False): 19 | super(Memory, self).__init__() 20 | 21 | self.mem_size = mem_size 22 | self.cell_size = cell_size 23 | self.read_heads = read_heads 24 | self.gpu_id = gpu_id 25 | self.input_size = input_size 26 | self.deallocate=deallocate 27 | self.key_size = key_size 28 | self.key_program_size = key_program_size 29 | self.program_size = program_size 30 | 31 | 32 | 33 | if self.key_size==0: 34 | key_size=self.cell_size 35 | 36 | m = self.mem_size 37 | w = self.cell_size 38 | r = self.read_heads 39 | 40 | self.interface_size = (key_size * r) + key_size + 2 * (key_size + w) + (2 * r) + 5 41 | if self.program_size==0: 42 | self.interface_weights = nn.Linear(self.input_size, self.interface_size, bias=False) 43 | else: 44 | self.interface_size += key_program_size + 1 45 | self.instruction_weight = nn.Parameter(cuda(T.zeros(self.program_size, 46 | self.key_program_size+self.input_size*self.interface_size) 47 | .fill_(0), gpu_id=self.gpu_id),requires_grad = True) 48 | stdv = 1. 
/ math.sqrt(self.input_size) 49 | nn.init.uniform_(self.instruction_weight,-stdv, stdv) 50 | # self.I = cuda(1 - T.eye(m).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n) 51 | self.layernorm = nn.GroupNorm(1, self.interface_size) 52 | 53 | def reset(self, batch_size=1, hidden=None, erase=True): 54 | m = self.mem_size 55 | w = self.cell_size 56 | r = self.read_heads 57 | b = batch_size 58 | 59 | if hidden is None: 60 | hidden = { 61 | 'memory': cuda(T.zeros(b, m, self.key_size + w).fill_(0), gpu_id=self.gpu_id), 62 | 'read_weights': cuda(T.zeros(b, r, m).fill_(0), gpu_id=self.gpu_id), 63 | 'write_weights': cuda(T.zeros(b, 1, m).fill_(0), gpu_id=self.gpu_id), 64 | 'usage_vector': cuda(T.zeros(b, m), gpu_id=self.gpu_id), 65 | 'read_keys': [], 66 | 'write_keys': [], 67 | # 'write_key': cuda(T.zeros(b, 1, self.key_size).fill_(0), gpu_id=self.gpu_id), 68 | } 69 | if self.program_size>0: 70 | hidden['read_weights_program']=cuda(T.zeros(b, 1, self.program_size).fill_(0), gpu_id=self.gpu_id) 71 | hidden['pmemory']=self.instruction_weight[0,self.key_program_size:].repeat(b,1,1).\ 72 | view(b, self.input_size, self.interface_size) 73 | 74 | else: 75 | for k,v in hidden.items(): 76 | if isinstance(v, list): 77 | for i2,v2 in enumerate(v): 78 | hidden[k][i2]=v2.clone() 79 | else: 80 | hidden[k] = v.clone() 81 | # hidden['memory'] = hidden['memory'].clone() 82 | # hidden['read_weights'] = hidden['read_weights'].clone() 83 | # hidden['write_weights'] = hidden['write_weights'].clone() 84 | # hidden['usage_vector'] = hidden['usage_vector'].clone() 85 | 86 | if erase: 87 | for k, v in hidden.items(): 88 | if isinstance(v, list): 89 | hidden[k] = [] 90 | else: 91 | v.data.fill_(0) 92 | # hidden['memory'].data.fill_(0) 93 | # hidden['read_weights'].data.fill_(0) 94 | # hidden['write_weights'].data.fill_(0) 95 | # hidden['usage_vector'].data.zero_() 96 | return hidden 97 | 98 | def get_usage_vector(self, usage, free_gates, read_weights, write_weights): 99 | # write_weights = write_weights.detach() # detach from the computation graph 100 | usage = usage + (1 - usage) * (1 - T.prod(1 - write_weights, 1)) 101 | ψ = T.prod(1 - free_gates.unsqueeze(2) * read_weights, 1) 102 | return usage * ψ, ψ 103 | 104 | def allocate(self, usage): 105 | # ensure values are not too small prior to cumprod. 
106 | usage = δ + (1 - δ) * usage 107 | batch_size = usage.size(0) 108 | # free list 109 | sorted_usage, φ = T.topk(usage, self.mem_size, dim=1, largest=False) 110 | 111 | # cumprod with exclusive=True 112 | # https://discuss.pytorch.org/t/cumprod-exclusive-true-equivalences/2614/8 113 | v = var(sorted_usage.data.new(batch_size, 1).fill_(1)) 114 | cat_sorted_usage = T.cat((v, sorted_usage), 1) 115 | prod_sorted_usage = T.cumprod(cat_sorted_usage, 1)[:, :-1] 116 | 117 | sorted_allocation_weights = (1 - sorted_usage) * prod_sorted_usage.squeeze() 118 | 119 | # construct the reverse sorting index https://stackoverflow.com/questions/2483696/undo-or-reverse-argsort-python 120 | _, φ_rev = T.topk(φ, k=self.mem_size, dim=1, largest=False) 121 | allocation_weights = sorted_allocation_weights.gather(1, φ_rev.long()) 122 | 123 | return allocation_weights.unsqueeze(1), usage 124 | 125 | def write_weighting(self, memory, write_content_weights, allocation_weights, write_gate, allocation_gate, last_read_weights): 126 | lastrw = allocation_gate[:, 0].unsqueeze(-1) *T.mean(last_read_weights, dim=1) 127 | nallow = allocation_gate[:, 1].unsqueeze(-1) * allocation_weights.squeeze(1) 128 | conw = allocation_gate[:, 2].unsqueeze(-1) * write_content_weights.squeeze(1) 129 | 130 | fw = lastrw + nallow + conw 131 | 132 | return (write_gate * fw).unsqueeze(1) 133 | 134 | def write(self, write_key, write_vector, erase_vector, free_gates, read_strengths, write_strength, write_gate, 135 | allocation_gate, hidden): 136 | # get current usage 137 | last_read_ws = hidden['read_weights'] 138 | hidden['usage_vector'], ψ = self.get_usage_vector( 139 | hidden['usage_vector'], 140 | free_gates, 141 | last_read_ws, 142 | hidden['write_weights'] 143 | ) 144 | 145 | # lookup memory with write_key and write_strength 146 | write_content_weights = self.content_weightings(hidden['memory'], write_key, 147 | write_strength, self.key_size) 148 | 149 | # get memory allocation 150 | alloc, _ = self.allocate( 151 | hidden['usage_vector'], 152 | ) 153 | 154 | # get write weightings 155 | 156 | hidden['write_weights'] = self.write_weighting( 157 | hidden['memory'], 158 | write_content_weights, 159 | alloc, 160 | write_gate, 161 | allocation_gate, 162 | last_read_ws 163 | ) 164 | 165 | weighted_resets = hidden['write_weights'].unsqueeze(3) * erase_vector.unsqueeze(2) 166 | reset_gate = T.prod(1 - weighted_resets, 1) 167 | 168 | # Update memory 169 | if self.deallocate: 170 | hidden['memory'][:,:,self.key_size:] = hidden['memory'][:,:,self.key_size:] * ψ 171 | hidden['memory'] = hidden['memory'] * reset_gate 172 | 173 | hidden['memory'] = self.write_mem(hidden['memory'], hidden['write_weights'], write_vector) 174 | 175 | return hidden 176 | 177 | def content_weightings(self, memory, keys, strengths, key_size): 178 | if key_size>0: 179 | d = θ(memory[:,:,:key_size], keys[:,:,:key_size]) 180 | else: 181 | d = θ(memory[:,:,:], keys) 182 | return σ(d * strengths.unsqueeze(2), 2) 183 | 184 | def read_mem(self, memory, read_weights, key_size): 185 | return T.bmm(read_weights, memory[:,:,key_size:]) 186 | 187 | def write_mem(self, memory, write_weights, write_vector): 188 | return memory + \ 189 | T.bmm(write_weights.transpose(1, 2), write_vector) 190 | 191 | def read(self, read_keys, read_strengths, hidden): 192 | content_weights = self.content_weightings(hidden['memory'], read_keys, read_strengths, self.key_size) 193 | hidden['read_weights'] = content_weights 194 | read_vectors = self.read_mem(hidden['memory'], hidden['read_weights'], 
self.key_size) 195 | return read_vectors, hidden 196 | 197 | def read_program(self, read_keys, read_strengths, hidden): 198 | content_weights = self.content_weightings(self.instruction_weight.unsqueeze(0).repeat(read_keys.shape[0],1,1), 199 | read_keys, read_strengths, self.key_program_size) 200 | 201 | hidden['read_weights_program'] = content_weights 202 | read_vectors = self.read_mem(self.instruction_weight.unsqueeze(0).repeat(read_keys.shape[0],1,1), 203 | hidden['read_weights_program'], self.key_program_size) 204 | return read_vectors, hidden 205 | 206 | def forward(self, ξ, hidden): 207 | 208 | # ξ = ξ.detach() 209 | k = self.cell_size 210 | if self.key_size>0: 211 | k = self.key_size 212 | m = self.mem_size 213 | w = self.cell_size 214 | r = self.read_heads 215 | b = ξ.size()[0] 216 | 217 | if self.program_size>0: 218 | ξ = T.matmul(ξ.unsqueeze(1), hidden['pmemory']).squeeze(1) 219 | # print(hidden['pmemory']) 220 | else: 221 | ξ = self.interface_weights(ξ) 222 | # print(self.interface_weights.weight) 223 | ξ = self.layernorm(ξ) 224 | counter = 0 225 | # r read keys (b * w * r) 226 | read_keys = F.tanh(ξ[:, counter:counter + r * k].contiguous().view(b, r, k)) 227 | counter += r * k 228 | # r read strengths (b * r) 229 | read_strengths = F.softplus(ξ[:, counter:counter + r].contiguous().view(b, r)) 230 | counter += r 231 | # write key (b * k * 1) 232 | write_key = F.tanh(ξ[:, counter:counter + k].contiguous().view(b, 1, k)) 233 | counter += k 234 | # write strength (b * 1) 235 | write_strength = F.softplus(ξ[:, counter].contiguous().view(b, 1)) 236 | counter += 1 237 | # erase vector (b * w) 238 | erase_vector = F.sigmoid(ξ[:, counter: counter + self.key_size + w].contiguous().view(b, 1, self.key_size + w)) 239 | counter += self.key_size + w 240 | # write vector (b * w) 241 | write_vector = F.tanh(ξ[:, counter: counter + self.key_size + w].contiguous().view(b, 1, self.key_size + w)) 242 | counter += self.key_size + w 243 | # r free gates (b * r) 244 | free_gates = F.sigmoid(ξ[:, counter: counter + r].contiguous().view(b, r)) 245 | counter += r 246 | # allocation gate (b * 3) 247 | allocation_gate = σ(ξ[:, counter: counter + 3].contiguous().view(b, 3), 1) 248 | counter += 3 249 | # write gate (b * 1) 250 | write_gate = F.sigmoid(ξ[:, counter].contiguous()).unsqueeze(1).view(b, 1) 251 | if self.program_size>0: 252 | counter+=1 253 | # r read keys program (b * w * 1) 254 | read_keys_program = F.tanh(ξ[:, counter:counter+1*self.key_program_size].contiguous().view(b, 1, self.key_program_size)) 255 | counter+=1*self.key_program_size 256 | # r read strengths program (b * 1) 257 | read_strengths_program = F.softplus(ξ[:, counter:counter+1].contiguous().view(b, 1)) 258 | 259 | 260 | hidden = self.write(write_key, write_vector, erase_vector, free_gates, 261 | read_strengths, write_strength, write_gate, allocation_gate, hidden) 262 | if self.program_size>0: 263 | read_vectors_program, hidden = self.read_program(read_keys_program, read_strengths_program, hidden) 264 | hidden['pmemory'] = read_vectors_program[:,0,:].view(read_vectors_program.shape[0],self.input_size, self.interface_size) 265 | hidden['pkey'] = self.instruction_weight[:,:self.key_program_size] 266 | hidden['ploss']=0 267 | count=0 268 | for i in range(self.program_size): 269 | for j in range(i+1, self.program_size): 270 | hidden['ploss']+= F.cosine_similarity\ 271 | (self.instruction_weight[i,:self.key_program_size], 272 | self.instruction_weight[j, :self.key_program_size], 273 | dim = 0 ) 274 | count+=1 275 | 
hidden['ploss']/=count 276 | read_vectors, hidden = self.read(read_keys, read_strengths, hidden) 277 | #read_vectors = T.cat((read_vectors, read_keys), 2) 278 | hidden['read_keys'].append(read_keys) 279 | hidden['write_keys'].append(write_key) 280 | #hidden['write_key']=write_key 281 | return read_vectors, hidden 282 | -------------------------------------------------------------------------------- /baselines/dnc/dnc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch.nn as nn 5 | import torch as T 6 | from torch.autograd import Variable as var 7 | import numpy as np 8 | 9 | from torch.nn.utils.rnn import pad_packed_sequence as pad 10 | from torch.nn.utils.rnn import pack_padded_sequence as pack 11 | from torch.nn.utils.rnn import PackedSequence 12 | 13 | from .util import * 14 | # from .memory import * 15 | from .memory import * 16 | 17 | from torch.nn.init import orthogonal, xavier_uniform 18 | 19 | class DNC(nn.Module): 20 | 21 | def __init__( 22 | self, 23 | input_size, 24 | final_output_size, 25 | hidden_size, 26 | rnn_type='lstm', 27 | num_layers=1, 28 | num_hidden_layers=1, 29 | bias=True, 30 | batch_first=False, 31 | dropout=0, 32 | bidirectional=False, 33 | nr_cells=5, 34 | read_heads=2, 35 | cell_size=10, 36 | nonlinearity='tanh', 37 | gpu_id=-1, 38 | independent_linears=False, 39 | share_memory=True, 40 | debug=False, 41 | clip=0, 42 | pass_through_memory=True 43 | ): 44 | super(DNC, self).__init__() 45 | # todo: separate weights and RNNs for the interface and output vectors 46 | 47 | self.input_size = input_size 48 | self.hidden_size = hidden_size 49 | self.rnn_type = rnn_type 50 | self.num_layers = num_layers 51 | self.num_hidden_layers = num_hidden_layers 52 | self.bias = bias 53 | self.batch_first = batch_first 54 | self.dropout = dropout 55 | self.bidirectional = bidirectional 56 | self.nr_cells = nr_cells 57 | self.read_heads = read_heads 58 | self.cell_size = cell_size 59 | self.nonlinearity = nonlinearity 60 | self.gpu_id = gpu_id 61 | self.independent_linears = independent_linears 62 | self.share_memory = share_memory 63 | self.debug = debug 64 | self.clip = clip 65 | self.pass_through_memory = pass_through_memory 66 | 67 | self.w = self.cell_size 68 | self.r = self.read_heads 69 | 70 | self.read_vectors_size = self.r * self.w 71 | 72 | self.nn_input_size = self.input_size 73 | if self.pass_through_memory: 74 | self.nn_input_size += self.read_vectors_size 75 | self.nn_output_size = self.hidden_size 76 | 77 | if self.pass_through_memory: 78 | self.nn_output_size += self.read_vectors_size 79 | 80 | self.emb_size = self.hidden_size 81 | 82 | 83 | if self.bidirectional: 84 | self.nn_output_size+=self.hidden_size 85 | 86 | 87 | 88 | self.final_output_size = final_output_size 89 | 90 | self.rnns = [] 91 | self.memories = [] 92 | 93 | mem_input_size = self.hidden_size 94 | if self.bidirectional: 95 | mem_input_size += self.hidden_size 96 | self.brnns = [] 97 | 98 | self.embs = [] 99 | for layer in range(self.num_layers): 100 | if layer