├── MNIST ├── crl │ ├── __init__.py │ ├── dataloader │ │ ├── __init__.py │ │ ├── generative_recognition_mapping.py │ │ ├── base_dataloader.py │ │ ├── datautils.py │ │ ├── mnist_dataset.py │ │ ├── arithmetic.py │ │ ├── modulo_datagen.py │ │ ├── languages.py │ │ ├── pretrain.py │ │ ├── utils.py │ │ └── datagen.py │ ├── requirements.txt │ ├── dataset.py │ └── utils.py ├── logbook │ ├── __init__.py │ ├── util.py │ └── logbook.py ├── utils │ ├── __init__.py │ ├── test_sample.py │ ├── misc.py │ ├── util.py │ ├── rms.py │ ├── device.py │ ├── sample.py │ ├── cli.py │ └── log.py ├── prelude.py ├── utilities │ ├── GroupLinearLayer.py │ ├── rule_stats.py │ ├── sparse_grad_attn.py │ ├── set_transformer.py │ ├── sparse_attn.py │ ├── invariant_modules.py │ ├── slot_attention_old.py │ ├── SharedGroupLinearLayer.py │ ├── layer_conn_attention.py │ ├── slot_attention_custom.py │ ├── slot_attention_custom_2.py │ ├── BlockGRU.py │ └── BlockLSTM.py ├── run.sh ├── init.py ├── bootstrap.py ├── block_wrapper.py ├── dataset.py └── argument_parser.py ├── synthetic ├── README.md ├── utilities │ ├── GroupLinearLayer.py │ ├── rule_stats.py │ ├── sparse_grad_attn.py │ ├── set_transformer.py │ ├── sparse_attn.py │ ├── invariant_modules.py │ ├── slot_attention_old.py │ ├── SharedGroupLinearLayer.py │ ├── layer_conn_attention.py │ ├── BlockLSTM.py │ ├── slot_attention_custom_2.py │ ├── slot_attention_custom.py │ └── BlockGRU.py ├── eval_runner.sh ├── runner.sh ├── model.py └── data.py ├── README.md └── requirements.txt /MNIST/crl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MNIST/logbook/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MNIST/crl/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /synthetic/README.md: -------------------------------------------------------------------------------- 1 | # synthetic 2 | -------------------------------------------------------------------------------- /MNIST/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .device import Device 2 | from .log import Logger 3 | from .sample import sample_indices 4 | -------------------------------------------------------------------------------- /MNIST/crl/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | tqdm 3 | matplotlib 4 | numpy 5 | git+https://github.com//ncullen93/torchsample#0.1.3 6 | pandas 7 | nibabel 8 | -------------------------------------------------------------------------------- /MNIST/utils/test_sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .sample import sample_indices 3 | 4 | 5 | def test_sample_large(): 6 | arr = sample_indices(100, 90) 7 | assert np.unique(arr).__len__() == 90 8 | 9 | 10 | def test_sample_small(): 11 | arr = sample_indices(100, 30) 12 | assert np.unique(arr).__len__() == 30 13 | -------------------------------------------------------------------------------- /MNIST/utils/misc.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import 
operator 3 | from torch import Tensor 4 | from typing import Any, Iterable 5 | 6 | 7 | def iter_prod(it: Iterable[Any]) -> Any: 8 | return reduce(operator.mul, it) 9 | 10 | 11 | def normalize_(t: Tensor, eps: float) -> None: 12 | mean = t.mean() 13 | std = t.std() 14 | t.sub_(mean).div_(std + eps) 15 | -------------------------------------------------------------------------------- /MNIST/prelude.py: -------------------------------------------------------------------------------- 1 | from torch import nn, Tensor 2 | from typing import Any, Callable, Iterable, Sequence, Tuple, TypeVar, Union 3 | from utils.device import Device 4 | 5 | 6 | try: 7 | from typing import GenericMeta, NamedTupleMeta # type: ignore 8 | 9 | class GenericNamedMeta(NamedTupleMeta, GenericMeta): 10 | pass 11 | except ImportError: 12 | from typing import NamedTupleMeta # type: ignore 13 | GenericNamedMeta = NamedTupleMeta # type: ignore 14 | 15 | T = TypeVar('T') 16 | Self = Any 17 | 18 | 19 | class Array(Sequence[T]): 20 | @property 21 | def shape(self) -> tuple: 22 | ... 23 | 24 | def squeeze(self) -> Self: 25 | ... 26 | 27 | def transpose(self, *args) -> Self: 28 | ... 29 | 30 | def __rsub__(self, value: Any) -> Self: 31 | ... 32 | 33 | def __truediv__(self, rvalue: Any) -> Self: 34 | ... 35 | 36 | 37 | Action = TypeVar('Action', int, Array) 38 | State = TypeVar('State') 39 | NetFn = Callable[[Tuple[int, ...], int, Device], nn.Module] 40 | Params = Iterable[Union[Tensor, dict]] 41 | -------------------------------------------------------------------------------- /MNIST/utils/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions""" 2 | import collections 3 | import os 4 | import random 5 | import pathlib 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def set_seed(seed): 12 | """Set seed""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | if torch.cuda.is_available(): 17 | torch.cuda.manual_seed_all(seed) 18 | os.environ['PYTHONHASHSEED'] = str(seed) 19 | 20 | 21 | def flatten_dict(d, parent_key='', sep='#'): 22 | """Method to flatten a given dict using the given seperator. 23 | Taken from https://stackoverflow.com/a/6027615/1353861 24 | """ 25 | items = [] 26 | for k, v in d.items(): 27 | new_key = parent_key + sep + k if parent_key else k 28 | if isinstance(v, collections.MutableMapping): 29 | items.extend(flatten_dict(v, new_key, sep=sep).items()) 30 | else: 31 | items.append((new_key, v)) 32 | return dict(items) 33 | 34 | 35 | def make_dir(path): 36 | """Make dir""" 37 | pathlib.Path(path).mkdir(parents=True, exist_ok=True) 38 | -------------------------------------------------------------------------------- /MNIST/utilities/GroupLinearLayer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | class GroupLinearLayer(nn.Module): 6 | def __init__(self, din, dout, num_blocks, bias=True, a = None): 7 | super(GroupLinearLayer, self).__init__() 8 | self.nb = num_blocks 9 | #din = din // num_blocks 10 | #dout = dout // num_blocks 11 | self.dout = dout 12 | if a is None: 13 | a = 1. 
/ math.sqrt(dout) 14 | self.weight = nn.Parameter(torch.FloatTensor(num_blocks,din,dout).uniform_(-a,a)) 15 | self.bias = bias 16 | if bias is True: 17 | self.bias = nn.Parameter(torch.FloatTensor(num_blocks,dout).uniform_(-a,a)) 18 | #self.bias = nn.Parameter(torch.zeros(dout*num_blocks)) 19 | else: 20 | self.bias = None 21 | def forward(self,x): 22 | ts,bs,m = x.shape 23 | #x = x.reshape((ts*bs, self.nb, m//self.nb)) 24 | x = x.permute(1,0,2) 25 | x = torch.bmm(x,self.weight) 26 | x = x.permute(1,0,2) 27 | if not self.bias is None: 28 | x = x + self.bias 29 | #x = x.reshape((ts, bs, self.dout*self.nb)) 30 | return x 31 | 32 | -------------------------------------------------------------------------------- /synthetic/utilities/GroupLinearLayer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | class GroupLinearLayer(nn.Module): 6 | def __init__(self, din, dout, num_blocks, bias=True, a = None): 7 | super(GroupLinearLayer, self).__init__() 8 | self.nb = num_blocks 9 | #din = din // num_blocks 10 | #dout = dout // num_blocks 11 | self.dout = dout 12 | if a is None: 13 | a = 1. / math.sqrt(dout) 14 | self.weight = nn.Parameter(torch.FloatTensor(num_blocks,din,dout).uniform_(-a,a)) 15 | self.bias = bias 16 | if bias is True: 17 | self.bias = nn.Parameter(torch.FloatTensor(num_blocks,dout).uniform_(-a,a)) 18 | #self.bias = nn.Parameter(torch.zeros(dout*num_blocks)) 19 | else: 20 | self.bias = None 21 | def forward(self,x): 22 | ts,bs,m = x.shape 23 | #x = x.reshape((ts*bs, self.nb, m//self.nb)) 24 | x = x.permute(1,0,2) 25 | x = torch.bmm(x,self.weight) 26 | x = x.permute(1,0,2) 27 | if not self.bias is None: 28 | x = x + self.bias 29 | #x = x.reshape((ts, bs, self.dout*self.nb)) 30 | return x 31 | 32 | -------------------------------------------------------------------------------- /MNIST/utilities/rule_stats.py: -------------------------------------------------------------------------------- 1 | def get_stats(rule_selections, variable_selections, variable_rules, application_option, num_rules = 5, num_blocks = 1): 2 | if isinstance(application_option, str): 3 | application_option = int(application_option.split('.')[0]) 4 | for b in range(rule_selections[0].shape[0]): 5 | for w in range(len(rule_selections)): 6 | if application_option == 0 or application_option == 3: 7 | try: 8 | tup = (rule_selections[w][b][0], variable_selections[w][b][0]) 9 | except: 10 | tup = (rule_selections[w][b], variable_selections[w][b]) 11 | elif application_option == 1: 12 | y = rule_selections[w][b] 13 | 14 | r1 = y[0] % num_rules 15 | v1 = y[0] % num_blocks 16 | r2 = y[1] % num_rules 17 | v2 = y[1] % num_blocks 18 | tup = (r1, v1, r2, v2) 19 | if tup not in variable_rules: 20 | variable_rules[tup] = 1 21 | else: 22 | variable_rules[tup] += 1 23 | return variable_rules -------------------------------------------------------------------------------- /synthetic/utilities/rule_stats.py: -------------------------------------------------------------------------------- 1 | def get_stats(rule_selections, variable_selections, variable_rules, application_option, num_rules = 5, num_blocks = 1): 2 | if isinstance(application_option, str): 3 | application_option = int(application_option.split('.')[0]) 4 | for b in range(rule_selections[0].shape[0]): 5 | for w in range(len(rule_selections)): 6 | if application_option == 0 or application_option == 3: 7 | try: 8 | tup = (rule_selections[w][b][0], variable_selections[w][b][0]) 9 | 
except: 10 | tup = (rule_selections[w][b], variable_selections[w][b]) 11 | elif application_option == 1: 12 | y = rule_selections[w][b] 13 | 14 | r1 = y[0] % num_rules 15 | v1 = y[0] % num_blocks 16 | r2 = y[1] % num_rules 17 | v2 = y[1] % num_blocks 18 | tup = (r1, v1, r2, v2) 19 | if tup not in variable_rules: 20 | variable_rules[tup] = 1 21 | else: 22 | variable_rules[tup] += 1 23 | return variable_rules 24 | -------------------------------------------------------------------------------- /synthetic/eval_runner.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | rule_time_steps=1 4 | num_rules=$1 5 | rule_emb_dim=$2 6 | batch_size=50 7 | split=mcd1 8 | epochs=100 9 | algo=lstm 10 | lr=0.0001 11 | perm_inv=False 12 | num_blocks=4 13 | hidden_dim=$3 14 | n_layers=1 15 | n_templates=2 16 | application_option=3.0.-1 17 | seed=$4 18 | comm=False 19 | anneal_rate=0.01 20 | use_entropy=False 21 | use_biases=False 22 | dir_name=$num_rules-$rule_emb_dim-$hidden_dim-$seed-False 23 | 24 | 25 | python eval.py --use_rules --comm $comm --grad no --transformer yes --application_option $application_option --seed $seed \ 26 | --use_attention --alternate_training no --split $split --n_templates $n_templates\ 27 | --algo $algo --use_entropy $use_entropy \ 28 | --save_dir $dir_name \ 29 | --lr $lr --drop 0.5 --nhid $hidden_dim --num_blocks $num_blocks --topk $num_blocks \ 30 | --nlayers $n_layers --cuda --cudnn --emsize 300 --log-interval 50 --perm_inv $perm_inv \ 31 | --epochs $epochs --train_len 50 --test_len 200 --gumble_anneal_rate $anneal_rate \ 32 | --rule_time_steps $rule_time_steps --num_rules $num_rules --rule_emb_dim $rule_emb_dim --batch_size $batch_size --use_biases $use_biases --timesteps 10 | tee -a "$dir_name/eval.log" 33 | 34 | 35 | -------------------------------------------------------------------------------- /MNIST/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | training_dataset='crl' 5 | dim1=100 6 | block1=1 7 | topk1=1 8 | lr=0.0001 9 | encoder=1 10 | version=2 11 | att_out=64 12 | application_option=3.0.-1 13 | num_rules=4 14 | rule_time_steps=1 15 | num_transforms=4 16 | transform_length=4 17 | algo=RIM 18 | rule_dim=6 19 | seed=${1} 20 | share_key_value=False 21 | color=False 22 | inp_heads=1 23 | templates=0 24 | drop=0 25 | comm=False 26 | name="CRL_${algo}-"$dim1"_"$block1"_num_rules_"$num_rules"_rule_time_steps_"$rule_time_steps-$seed-${share_key_value}-${color} 27 | OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 python run_crl.py --algo $algo --train_dataset $training_dataset --hidden_size $dim1 --color $color --should_save_csv False --lr $lr --id $name --num_blocks $block1 --topk $topk1 --batch_frequency_to_log_heatmaps -1 --num_modules_read_input 2 --inp_heads $inp_heads --do_rel False --share_inp True --share_comm True --share_key_value $share_key_value --n_templates $templates --num_encoders $encoder --version $version --do_comm $comm --num_rules $num_rules --rule_time_steps $rule_time_steps --version $version --attention_out $att_out --dropout $drop --application_option=$application_option --num_transforms $num_transforms --transform_length $transform_length --rule_dim $rule_dim --seed $seed 28 | 29 | -------------------------------------------------------------------------------- /synthetic/runner.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | rule_time_steps=1 4 | 
num_rules=$1 5 | rule_emb_dim=$2 6 | hidden_dim=$3 7 | application_option=3.0.-1 8 | seed=$4 9 | generalize=False 10 | dir_name=$num_rules-$rule_emb_dim-$hidden_dim-$seed-$generalize 11 | 12 | 13 | batch_size=50 14 | epochs=100 15 | lr=0.0001 16 | perm_inv=False 17 | num_blocks=4 18 | anneal_rate=0.01 19 | use_entropy=False 20 | use_biases=False 21 | comm=False 22 | algo=lstm 23 | n_layers=1 24 | n_templates=2 25 | mkdir $dir_name 26 | 27 | 28 | python main.py --use_rules --comm $comm --grad no --transformer yes --application_option $application_option --seed $seed \ 29 | --use_attention --alternate_training no --n_templates $n_templates\ 30 | --algo $algo --use_entropy $use_entropy \ 31 | --save_dir $dir_name \ 32 | --lr $lr --drop 0.5 --nhid $hidden_dim --num_blocks $num_blocks --topk $num_blocks \ 33 | --nlayers $n_layers --cuda --cudnn --emsize 300 --log-interval 50 --perm_inv $perm_inv \ 34 | --epochs $epochs --train_len 50 --test_len 200 --gumble_anneal_rate $anneal_rate --generalize $generalize \ 35 | --rule_time_steps $rule_time_steps --num_rules $num_rules --rule_emb_dim $rule_emb_dim --batch_size $batch_size --use_biases $use_biases | tee -a "$dir_name/train.log" 36 | 37 | #./runnner.sh 3 32 64 0 False^C 38 | -------------------------------------------------------------------------------- /MNIST/utils/rms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Tuple 3 | from ..prelude import Array 4 | 5 | 6 | class RunningMeanStd(object): 7 | """From https://github.com/openai/baselines/blob/master/baselines/common/running_mean_std.py 8 | """ 9 | def __init__(self, epsilon: float = 1.0e-4, shape: Tuple[int, ...] = ()) -> None: 10 | self.mean = np.zeros(shape, 'float64') 11 | self.var = np.ones(shape, 'float64') 12 | self.count = epsilon 13 | 14 | def update(self, x: Array[float]) -> None: 15 | self.mean, self.var, self.count = _update_mean_var_count_from_moments( 16 | self.mean, 17 | self.var, 18 | self.count, 19 | np.mean(x, axis=0), 20 | np.var(x, axis=0), 21 | x.shape[0] 22 | ) 23 | 24 | def std(self, eps: float = 1.0e-8) -> Array[float]: 25 | return np.sqrt(self.var + eps) 26 | 27 | 28 | def _update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count): 29 | delta = batch_mean - mean 30 | tot_count = count + batch_count 31 | new_mean = mean + delta * batch_count / tot_count 32 | m_a = var * count 33 | m_b = batch_var * batch_count 34 | M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count 35 | new_var = M2 / tot_count 36 | new_count = tot_count 37 | return new_mean, new_var, new_count 38 | -------------------------------------------------------------------------------- /MNIST/crl/dataloader/generative_recognition_mapping.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | class GR_Map(object): 4 | def __init__(self, tc): 5 | self.tc = tc 6 | self.generative_recognition_map = {} 7 | 8 | def get_gr_map(self): 9 | return copy.deepcopy(self.generative_recognition_map) 10 | 11 | class GR_Map_full(GR_Map): 12 | def __init__(self, tc): 13 | super(GR_Map_full, self).__init__(tc) 14 | stn_ids = { 15 | 'rotate': 0, 16 | 'scale': 1, 17 | 18 | 'translate_up_small': 2, 19 | 'translate_down_small': 3, 20 | 'translate_left_small': 4, 21 | 'translate_right_small': 5, 22 | 23 | 'translate_up_normal': 6, 24 | 'translate_down_normal': 7, 25 | 'translate_left_normal': 8, 26 | 'translate_right_normal': 9, 27 | 28 | 
'translate_up_big': 10, 29 | 'translate_down_big': 11, 30 | 'translate_left_big': 12, 31 | 'translate_right_big': 13, 32 | 33 | 'identity': 14} 34 | 35 | # generative ids 36 | rotate_ids = [x.id for x in self.tc.rotate] 37 | scale_ids = [x.id for x in self.tc.scale] 38 | identity_ids = [self.tc.identity.id] 39 | 40 | for i in rotate_ids: 41 | self.generative_recognition_map[i] = stn_ids['rotate'] 42 | for i in scale_ids: 43 | self.generative_recognition_map[i] = stn_ids['scale'] 44 | 45 | for k, v in filter(lambda k_v: 'translate' in k_v[0], stn_ids.items()): 46 | self.generative_recognition_map[v+2] = stn_ids[k] 47 | 48 | for i in identity_ids: 49 | self.generative_recognition_map[i] = stn_ids['identity'] -------------------------------------------------------------------------------- /MNIST/utilities/sparse_grad_attn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Giving an N x M attention matrix, returns the same matrix, 3 | but performs masking to determine where to block gradients. 4 | ''' 5 | 6 | import numpy 7 | import torch 8 | from torch.autograd import Variable 9 | 10 | from .sparse_attn import Sparse_attention 11 | 12 | 13 | class blocked_grad(torch.autograd.Function): 14 | 15 | @staticmethod 16 | def forward(ctx, x, mask): 17 | ctx.save_for_backward(x, mask) 18 | return x 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | x, mask = ctx.saved_tensors 23 | return grad_output * mask, mask * 0.0 24 | 25 | 26 | class Sparse_grad_attention(torch.autograd.Function): 27 | # def __init__(self, top_k): 28 | # super(Sparse_grad_attention,self).__init__() 29 | # 30 | # self.sa = Sparse_attention(top_k=top_k) 31 | 32 | @staticmethod 33 | def forward(ctx, inp, sa): 34 | sparsified = sa(inp) 35 | ctx.save_for_backward(inp, sparsified) 36 | 37 | return inp 38 | 39 | @staticmethod 40 | def backward(ctx, grad_output): 41 | inp, sparsified = ctx.saved_tensors 42 | # print('sparsified', sparsified) 43 | return (grad_output) * (sparsified > 0.0).float() 44 | 45 | 46 | if __name__ == "__main__": 47 | k = 2 48 | sga = Sparse_grad_attention(k) 49 | sa = Sparse_attention(k) 50 | 51 | x = torch.from_numpy(numpy.array([[[0.1, 0.0, 0.3, 0.2, 0.4], 52 | [0.5, 0.4, 0.1, 0.0, 0.0]]])) 53 | x = x.reshape((2, 5)) 54 | 55 | x = Variable(x, requires_grad=True) 56 | 57 | print(x) 58 | print('output', sga(x)) 59 | 60 | (sga(x).sum()).backward() 61 | 62 | print('sparse grad', x.grad) 63 | 64 | x = Variable(x.data, requires_grad=True) 65 | 66 | (sa(x).sum()).backward() 67 | 68 | print('normal grad', x.grad) 69 | -------------------------------------------------------------------------------- /synthetic/utilities/sparse_grad_attn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Giving an N x M attention matrix, returns the same matrix, 3 | but performs masking to determine where to block gradients. 
4 | ''' 5 | 6 | import numpy 7 | import torch 8 | from torch.autograd import Variable 9 | 10 | from .sparse_attn import Sparse_attention 11 | 12 | 13 | class blocked_grad(torch.autograd.Function): 14 | 15 | @staticmethod 16 | def forward(ctx, x, mask): 17 | ctx.save_for_backward(x, mask) 18 | return x 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | x, mask = ctx.saved_tensors 23 | return grad_output * mask, mask * 0.0 24 | 25 | 26 | class Sparse_grad_attention(torch.autograd.Function): 27 | # def __init__(self, top_k): 28 | # super(Sparse_grad_attention,self).__init__() 29 | # 30 | # self.sa = Sparse_attention(top_k=top_k) 31 | 32 | @staticmethod 33 | def forward(ctx, inp, sa): 34 | sparsified = sa(inp) 35 | ctx.save_for_backward(inp, sparsified) 36 | 37 | return inp 38 | 39 | @staticmethod 40 | def backward(ctx, grad_output): 41 | inp, sparsified = ctx.saved_tensors 42 | # print('sparsified', sparsified) 43 | return (grad_output) * (sparsified > 0.0).float() 44 | 45 | 46 | if __name__ == "__main__": 47 | k = 2 48 | sga = Sparse_grad_attention(k) 49 | sa = Sparse_attention(k) 50 | 51 | x = torch.from_numpy(numpy.array([[[0.1, 0.0, 0.3, 0.2, 0.4], 52 | [0.5, 0.4, 0.1, 0.0, 0.0]]])) 53 | x = x.reshape((2, 5)) 54 | 55 | x = Variable(x, requires_grad=True) 56 | 57 | print(x) 58 | print('output', sga(x)) 59 | 60 | (sga(x).sum()).backward() 61 | 62 | print('sparse grad', x.grad) 63 | 64 | x = Variable(x.data, requires_grad=True) 65 | 66 | (sa(x).sum()).backward() 67 | 68 | print('normal grad', x.grad) 69 | -------------------------------------------------------------------------------- /MNIST/utilities/set_transformer.py: -------------------------------------------------------------------------------- 1 | from .invariant_modules import * 2 | 3 | class DeepSet(nn.Module): 4 | def __init__(self, dim_input, num_outputs, dim_output, dim_hidden=128): 5 | super(DeepSet, self).__init__() 6 | self.num_outputs = num_outputs 7 | self.dim_output = dim_output 8 | self.enc = nn.Sequential( 9 | nn.Linear(dim_input, dim_hidden), 10 | nn.ReLU(), 11 | nn.Linear(dim_hidden, dim_hidden), 12 | nn.ReLU(), 13 | nn.Linear(dim_hidden, dim_hidden), 14 | nn.ReLU(), 15 | nn.Linear(dim_hidden, dim_hidden)) 16 | self.dec = nn.Sequential( 17 | nn.Linear(dim_hidden, dim_hidden), 18 | nn.ReLU(), 19 | nn.Linear(dim_hidden, dim_hidden), 20 | nn.ReLU(), 21 | nn.Linear(dim_hidden, dim_hidden), 22 | nn.ReLU(), 23 | nn.Linear(dim_hidden, num_outputs*dim_output)) 24 | 25 | def forward(self, X): 26 | X = self.enc(X).mean(-2) 27 | X = self.dec(X).reshape(-1, self.num_outputs, self.dim_output) 28 | return X 29 | 30 | class SetTransformer(nn.Module): 31 | def __init__(self, dim_input, num_outputs, dim_output, 32 | num_inds=32, dim_hidden=128, num_heads=4, ln=False): 33 | super(SetTransformer, self).__init__() 34 | self.enc = nn.Sequential( 35 | ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), 36 | ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln)) 37 | self.dec = nn.Sequential( 38 | PMA(dim_hidden, num_heads, num_outputs, ln=ln), 39 | SAB(dim_hidden, dim_hidden, num_heads, ln=ln), 40 | SAB(dim_hidden, dim_hidden, num_heads, ln=ln), 41 | nn.Linear(dim_hidden, dim_output)) 42 | 43 | def forward(self, X): 44 | return self.dec(self.enc(X)) -------------------------------------------------------------------------------- /synthetic/utilities/set_transformer.py: -------------------------------------------------------------------------------- 1 | from .invariant_modules import * 2 | 3 | class DeepSet(nn.Module): 
4 | def __init__(self, dim_input, num_outputs, dim_output, dim_hidden=128): 5 | super(DeepSet, self).__init__() 6 | self.num_outputs = num_outputs 7 | self.dim_output = dim_output 8 | self.enc = nn.Sequential( 9 | nn.Linear(dim_input, dim_hidden), 10 | nn.ReLU(), 11 | nn.Linear(dim_hidden, dim_hidden), 12 | nn.ReLU(), 13 | nn.Linear(dim_hidden, dim_hidden), 14 | nn.ReLU(), 15 | nn.Linear(dim_hidden, dim_hidden)) 16 | self.dec = nn.Sequential( 17 | nn.Linear(dim_hidden, dim_hidden), 18 | nn.ReLU(), 19 | nn.Linear(dim_hidden, dim_hidden), 20 | nn.ReLU(), 21 | nn.Linear(dim_hidden, dim_hidden), 22 | nn.ReLU(), 23 | nn.Linear(dim_hidden, num_outputs*dim_output)) 24 | 25 | def forward(self, X): 26 | X = self.enc(X).mean(-2) 27 | X = self.dec(X).reshape(-1, self.num_outputs, self.dim_output) 28 | return X 29 | 30 | class SetTransformer(nn.Module): 31 | def __init__(self, dim_input, num_outputs, dim_output, 32 | num_inds=32, dim_hidden=128, num_heads=4, ln=False): 33 | super(SetTransformer, self).__init__() 34 | self.enc = nn.Sequential( 35 | ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), 36 | ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln)) 37 | self.dec = nn.Sequential( 38 | PMA(dim_hidden, num_heads, num_outputs, ln=ln), 39 | SAB(dim_hidden, dim_hidden, num_heads, ln=ln), 40 | SAB(dim_hidden, dim_hidden, num_heads, ln=ln), 41 | nn.Linear(dim_hidden, dim_output)) 42 | 43 | def forward(self, X): 44 | return self.dec(self.enc(X)) -------------------------------------------------------------------------------- /MNIST/crl/dataloader/base_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def mkdirp(logdir): 4 | if not os.path.exists(logdir): 5 | os.mkdir(logdir) 6 | 7 | class DataLoader(object): 8 | ############################################################################ 9 | # Initialization 10 | def __init__(self): 11 | self.printer = self.initialize_printer(None, None) 12 | 13 | def initialize_printer(self, logger, args): 14 | if logger is None: 15 | assert args is None 16 | def printer(x): 17 | print(x) 18 | else: 19 | printer = lambda x: printf(logger, args, x) 20 | self.printer = printer 21 | 22 | def initialize_data(self): 23 | raise NotImplementedError 24 | 25 | def create_dataset_name(self): 26 | raise NotImplementedError 27 | 28 | ############################################################################ 29 | # Data Generation 30 | def generate_unique_dataset(self): 31 | pass 32 | 33 | def insufficient_coverage(self): 34 | pass 35 | 36 | def save_dataset(self): 37 | pass 38 | 39 | def save_datasets(self): 40 | pass 41 | 42 | ############################################################################ 43 | # Loading data 44 | def load_dataset(self): 45 | pass 46 | 47 | def preprocess(self, problem): 48 | raise NotImplementedError 49 | 50 | def record_state(self, state): 51 | raise NotImplementedError 52 | 53 | def load_problem(self, mode): 54 | raise NotImplementedError 55 | 56 | def reset(self, mode): 57 | raise NotImplementedError 58 | 59 | ############################################################################ 60 | # Curriculum 61 | def add_dataset(self, mode): 62 | raise NotImplementedError 63 | 64 | def update_curriculum(self): 65 | raise NotImplementedError 66 | 67 | ############################################################################ 68 | # Visualization 69 | def get_trace(self): 70 | raise NotImplementedError 71 | 
-------------------------------------------------------------------------------- /MNIST/utils/device.py: -------------------------------------------------------------------------------- 1 | from torch import nn, Tensor, torch 2 | from typing import List, Union 3 | from numpy import ndarray 4 | 5 | 6 | class Device: 7 | """Utilities for handling devices 8 | """ 9 | def __init__(self, use_cpu: bool = False, gpu_indices: List[int] = []) -> None: 10 | """ 11 | :param gpu_limits: list of gpus you allow PyTorch to use 12 | """ 13 | super().__init__() 14 | if use_cpu or not torch.cuda.is_available(): 15 | self.__use_cpu() 16 | else: 17 | self.gpu_indices = gpu_indices if gpu_indices else self.__all_gpu() 18 | self.device = torch.device('cuda:{}'.format(self.gpu_indices[0])) 19 | 20 | @property 21 | def unwrapped(self) -> torch.device: 22 | return self.device 23 | 24 | def tensor(self, arr: Union[ndarray, List[ndarray], Tensor], dtype=torch.float32) -> Tensor: 25 | """Convert numpy array or Tensor into Tensor on main_device 26 | :param x: ndarray or Tensor you want to convert 27 | :return: Tensor 28 | """ 29 | t = type(arr) 30 | if t is Tensor: 31 | return arr.to(device=self.device) # type: ignore 32 | elif t is ndarray or t is list: 33 | return torch.tensor(arr, device=self.device, dtype=dtype) 34 | else: 35 | raise ValueError('arr must be ndarray or list or tensor') 36 | 37 | def zeros(self, size: Union[int, tuple]) -> Tensor: 38 | return torch.zeros(size, device=self.device) 39 | 40 | def ones(self, size: Union[int, tuple]) -> Tensor: 41 | return torch.ones(size, device=self.device) 42 | 43 | def data_parallel(self, module: nn.Module) -> nn.DataParallel: 44 | return nn.DataParallel(module, device_ids=self.gpu_indices) 45 | 46 | def is_multi_gpu(self) -> bool: 47 | return len(self.gpu_indices) > 1 48 | 49 | def __all_gpu(self) -> List[int]: 50 | return list(range(torch.cuda.device_count())) 51 | 52 | def __use_cpu(self) -> None: 53 | self.gpu_indices = [] 54 | self.device = torch.device('cpu') 55 | 56 | def __repr__(self) -> str: 57 | return str(self.device) 58 | -------------------------------------------------------------------------------- /MNIST/utilities/sparse_attn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy 5 | 6 | class Sparse_attention(nn.Module): 7 | def __init__(self, top_k = 5): 8 | super(Sparse_attention,self).__init__() 9 | top_k += 1 10 | self.top_k = top_k 11 | 12 | def forward(self, attn_s): 13 | 14 | # normalize the attention weights using piece-wise Linear function 15 | # only top k should 16 | attn_plot = [] 17 | # torch.max() returns both value and location 18 | #attn_s_max = torch.max(attn_s, dim = 1)[0] 19 | #attn_w = torch.clamp(attn_s_max, min = 0, max = attn_s_max) 20 | eps = 10e-8 21 | time_step = attn_s.size()[1] 22 | if time_step <= self.top_k: 23 | # just make everything greater than 0, and return it 24 | #delta = torch.min(attn_s, dim = 1)[0] 25 | return attn_s 26 | else: 27 | # get top k and return it 28 | # bottom_k = attn_s.size()[1] - self.top_k 29 | # value of the top k elements 30 | #delta = torch.kthvalue(attn_s, bottm_k, dim= 1 )[0] 31 | delta = torch.topk(attn_s, self.top_k, dim= 1)[0][:,-1] + eps 32 | #delta = attn_s_max - torch.topk(attn_s, self.top_k, dim= 1)[0][:,-1] + eps 33 | # normalize 34 | delta = delta.reshape((delta.shape[0],1)) 35 | 36 | 37 | attn_w = attn_s - delta.repeat(1, time_step) 38 | attn_w = torch.clamp(attn_w, min = 0) 39 | 
attn_w_sum = torch.sum(attn_w, dim = 1, keepdim=True) 40 | attn_w_sum = attn_w_sum + eps 41 | attn_w_normalize = attn_w / attn_w_sum.repeat(1, time_step) 42 | 43 | #print('attn', attn_w_normalize) 44 | 45 | return attn_w_normalize 46 | 47 | 48 | if __name__ == "__main__": 49 | k = 1 50 | print('take top k', k) 51 | sa = Sparse_attention(top_k=k) 52 | 53 | #batch x time 54 | 55 | x = torch.from_numpy(numpy.array([[[0.1, 0.0, 0.3, 0.2, 0.4],[0.5,0.4,0.1,0.0,0.0]]])) 56 | 57 | x = x.reshape((2,5)) 58 | 59 | print('x shape', x.shape) 60 | print('x', x) 61 | 62 | o = sa(x) 63 | 64 | 65 | print('o', o) 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /synthetic/utilities/sparse_attn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy 5 | 6 | class Sparse_attention(nn.Module): 7 | def __init__(self, top_k = 5): 8 | super(Sparse_attention,self).__init__() 9 | top_k += 1 10 | self.top_k = top_k 11 | 12 | def forward(self, attn_s): 13 | 14 | # normalize the attention weights using piece-wise Linear function 15 | # only top k should 16 | attn_plot = [] 17 | # torch.max() returns both value and location 18 | #attn_s_max = torch.max(attn_s, dim = 1)[0] 19 | #attn_w = torch.clamp(attn_s_max, min = 0, max = attn_s_max) 20 | eps = 10e-8 21 | time_step = attn_s.size()[1] 22 | if time_step <= self.top_k: 23 | # just make everything greater than 0, and return it 24 | #delta = torch.min(attn_s, dim = 1)[0] 25 | return attn_s 26 | else: 27 | # get top k and return it 28 | # bottom_k = attn_s.size()[1] - self.top_k 29 | # value of the top k elements 30 | #delta = torch.kthvalue(attn_s, bottm_k, dim= 1 )[0] 31 | delta = torch.topk(attn_s, self.top_k, dim= 1)[0][:,-1] + eps 32 | #delta = attn_s_max - torch.topk(attn_s, self.top_k, dim= 1)[0][:,-1] + eps 33 | # normalize 34 | delta = delta.reshape((delta.shape[0],1)) 35 | 36 | 37 | attn_w = attn_s - delta.repeat(1, time_step) 38 | attn_w = torch.clamp(attn_w, min = 0) 39 | attn_w_sum = torch.sum(attn_w, dim = 1, keepdim=True) 40 | attn_w_sum = attn_w_sum + eps 41 | attn_w_normalize = attn_w / attn_w_sum.repeat(1, time_step) 42 | 43 | #print('attn', attn_w_normalize) 44 | 45 | return attn_w_normalize 46 | 47 | 48 | if __name__ == "__main__": 49 | k = 1 50 | print('take top k', k) 51 | sa = Sparse_attention(top_k=k) 52 | 53 | #batch x time 54 | 55 | x = torch.from_numpy(numpy.array([[[0.1, 0.0, 0.3, 0.2, 0.4],[0.5,0.4,0.1,0.0,0.0]]])) 56 | 57 | x = x.reshape((2,5)) 58 | 59 | print('x shape', x.shape) 60 | print('x', x) 61 | 62 | o = sa(x) 63 | 64 | 65 | print('o', o) 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /MNIST/crl/dataloader/datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import chain, combinations 3 | 4 | def sample_term_in_range(interval): 5 | return np.random.randint(*interval) 6 | 7 | def compute_prime_factors(n): 8 | i = 2 9 | factors = [] 10 | while i * i <= n: 11 | if n % i: 12 | i += 1 13 | else: 14 | n //= i 15 | factors.append(i) 16 | if n > 1: 17 | factors.append(n) 18 | return factors 19 | 20 | def get_prime_factors(n): 21 | prime_factors = compute_prime_factors(np.abs(n)) 22 | combos = chain.from_iterable(combinations(prime_factors, r) for r in range(len(prime_factors)+1)) 23 | factors = list(set([int(np.prod(c)) for c in combos])) 24 | return factors 25 | 
26 | def extract_ops(new_exp_str, operator_matcher, operators): 27 | # find locations of operators and terms 28 | new_ops = [] 29 | term_locations = [] 30 | lo = 0 31 | for m in operator_matcher.finditer(new_exp_str): 32 | # op loc 33 | op_loc = m.span() 34 | # op 35 | new_op_str = new_exp_str[op_loc[0]:op_loc[1]] 36 | new_op = operators[new_op_str]().operator 37 | new_ops.append(new_op) 38 | # termloc 39 | hi = op_loc[0] 40 | term_loc = (lo, hi) 41 | term_locations.append(term_loc) 42 | lo = op_loc[1] 43 | term_locations.append((lo, len(new_exp_str))) 44 | return new_ops, term_locations 45 | 46 | def extract_terms(new_exp_str, term_locations): 47 | # get terms 48 | new_terms = [] 49 | for termloc in term_locations: 50 | new_term_str = new_exp_str[termloc[0]:termloc[1]] 51 | new_term = int(new_term_str) 52 | new_terms.append(new_term) 53 | return new_terms 54 | 55 | def extract_terms_ops(new_exp_str, operator_matcher, operators): 56 | new_ops, term_locations = extract_ops(new_exp_str, operator_matcher, operators) 57 | new_terms = extract_terms(new_exp_str, term_locations) 58 | return new_terms, new_ops 59 | 60 | def build_expression_string(terms, ops, operator_dict): 61 | exp_str = str(terms[0]) 62 | for i in range(len(ops)): 63 | exp_str += operator_dict[ops[i]] + str(terms[i+1]) 64 | return exp_str 65 | 66 | def num2onehot(num, size): 67 | v = np.zeros(size) 68 | v[num] = 1 69 | return v -------------------------------------------------------------------------------- /MNIST/utilities/invariant_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class MAB(nn.Module): 7 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 8 | super(MAB, self).__init__() 9 | self.dim_V = dim_V 10 | self.num_heads = num_heads 11 | self.fc_q = nn.Linear(dim_Q, dim_V) 12 | self.fc_k = nn.Linear(dim_K, dim_V) 13 | self.fc_v = nn.Linear(dim_K, dim_V) 14 | if ln: 15 | self.ln0 = nn.LayerNorm(dim_V) 16 | self.ln1 = nn.LayerNorm(dim_V) 17 | self.fc_o = nn.Linear(dim_V, dim_V) 18 | 19 | def forward(self, Q, K): 20 | Q = self.fc_q(Q) 21 | K, V = self.fc_k(K), self.fc_v(K) 22 | 23 | dim_split = self.dim_V // self.num_heads 24 | Q_ = torch.cat(Q.split(dim_split, 2), 0) 25 | K_ = torch.cat(K.split(dim_split, 2), 0) 26 | V_ = torch.cat(V.split(dim_split, 2), 0) 27 | 28 | A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2) 29 | O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 30 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 31 | O = O + F.relu(self.fc_o(O)) 32 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 33 | return O 34 | 35 | class SAB(nn.Module): 36 | def __init__(self, dim_in, dim_out, num_heads, ln=False): 37 | super(SAB, self).__init__() 38 | self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln) 39 | 40 | def forward(self, X): 41 | return self.mab(X, X) 42 | 43 | class ISAB(nn.Module): 44 | def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False): 45 | super(ISAB, self).__init__() 46 | self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) 47 | nn.init.xavier_uniform_(self.I) 48 | self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln) 49 | self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln) 50 | 51 | def forward(self, X): 52 | H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) 53 | return self.mab1(X, H) 54 | 55 | class PMA(nn.Module): 56 | def __init__(self, dim, num_heads, 
num_seeds, ln=False): 57 | super(PMA, self).__init__() 58 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 59 | nn.init.xavier_uniform_(self.S) 60 | self.mab = MAB(dim, dim, dim, num_heads, ln=ln) 61 | 62 | def forward(self, X): 63 | return self.mab(self.S.repeat(X.size(0), 1, 1), X) -------------------------------------------------------------------------------- /synthetic/utilities/invariant_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class MAB(nn.Module): 7 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 8 | super(MAB, self).__init__() 9 | self.dim_V = dim_V 10 | self.num_heads = num_heads 11 | self.fc_q = nn.Linear(dim_Q, dim_V) 12 | self.fc_k = nn.Linear(dim_K, dim_V) 13 | self.fc_v = nn.Linear(dim_K, dim_V) 14 | if ln: 15 | self.ln0 = nn.LayerNorm(dim_V) 16 | self.ln1 = nn.LayerNorm(dim_V) 17 | self.fc_o = nn.Linear(dim_V, dim_V) 18 | 19 | def forward(self, Q, K): 20 | Q = self.fc_q(Q) 21 | K, V = self.fc_k(K), self.fc_v(K) 22 | 23 | dim_split = self.dim_V // self.num_heads 24 | Q_ = torch.cat(Q.split(dim_split, 2), 0) 25 | K_ = torch.cat(K.split(dim_split, 2), 0) 26 | V_ = torch.cat(V.split(dim_split, 2), 0) 27 | 28 | A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2) 29 | O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 30 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 31 | O = O + F.relu(self.fc_o(O)) 32 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 33 | return O 34 | 35 | class SAB(nn.Module): 36 | def __init__(self, dim_in, dim_out, num_heads, ln=False): 37 | super(SAB, self).__init__() 38 | self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln) 39 | 40 | def forward(self, X): 41 | return self.mab(X, X) 42 | 43 | class ISAB(nn.Module): 44 | def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False): 45 | super(ISAB, self).__init__() 46 | self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) 47 | nn.init.xavier_uniform_(self.I) 48 | self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln) 49 | self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln) 50 | 51 | def forward(self, X): 52 | H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) 53 | return self.mab1(X, H) 54 | 55 | class PMA(nn.Module): 56 | def __init__(self, dim, num_heads, num_seeds, ln=False): 57 | super(PMA, self).__init__() 58 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) 59 | nn.init.xavier_uniform_(self.S) 60 | self.mab = MAB(dim, dim, dim, num_heads, ln=ln) 61 | 62 | def forward(self, X): 63 | return self.mab(self.S.repeat(X.size(0), 1, 1), X) -------------------------------------------------------------------------------- /MNIST/utilities/slot_attention_old.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | class MLP(nn.Module): 5 | def __init__(self, in_dim, out_dim): 6 | super().__init__() 7 | self.mlp = nn.Sequential(nn.Linear(in_dim, 64), 8 | nn.ReLU(), 9 | nn.Linear(64, out_dim)) 10 | def forward(self, x): 11 | return self.mlp(x) 12 | 13 | class SlotAttention(nn.Module): 14 | def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128, out_dim = 64, in_dim = 64): 15 | super().__init__() 16 | self.num_slots = num_slots 17 | self.iters = iters 18 | self.eps = eps 19 | self.scale = dim ** -0.5 20 | 21 | self.slots_mu = nn.Parameter(torch.randn(1, 1, 
dim)) 22 | self.slots_sigma = nn.Parameter(torch.randn(1, 1, dim)) 23 | 24 | self.to_q = nn.Linear(dim, dim) 25 | self.to_k = nn.Linear(in_dim, dim) 26 | self.to_v = nn.Linear(in_dim, dim) 27 | 28 | self.gru = nn.GRUCell(dim, dim) 29 | 30 | hidden_dim = max(dim, hidden_dim) 31 | 32 | self.mlp = nn.Sequential( 33 | nn.Linear(dim, hidden_dim), 34 | nn.ReLU(inplace = True), 35 | nn.Linear(hidden_dim, dim) 36 | ) 37 | 38 | 39 | self.mlp_cast = MLP(dim, out_dim) 40 | self.norm_input = nn.LayerNorm(in_dim) 41 | self.norm_slots = nn.LayerNorm(dim) 42 | self.norm_pre_ff = nn.LayerNorm(dim) 43 | 44 | def forward(self, inputs, slots, num_slots = None): 45 | 46 | b, n, d_ = inputs.shape 47 | n_s = num_slots if num_slots is not None else self.num_slots 48 | 49 | #mu = self.slots_mu.expand(b, n_s, -1) 50 | #sigma = self.slots_sigma.expand(b, n_s, -1) 51 | d = slots.size(-1) 52 | 53 | inputs = self.norm_input(inputs) 54 | k, v = self.to_k(inputs), self.to_v(inputs) 55 | 56 | for _ in range(self.iters): 57 | slots_prev = slots 58 | 59 | slots = self.norm_slots(slots) 60 | q = self.to_q(slots) 61 | 62 | dots = torch.einsum('bid,bjd->bij', q, k) * self.scale 63 | attn = dots.softmax(dim=1) + self.eps 64 | attn = attn / attn.sum(dim=-1, keepdim=True) 65 | 66 | updates = torch.einsum('bjd,bij->bid', v, attn) 67 | 68 | slots = self.gru( 69 | updates.reshape(-1, d), 70 | slots_prev.reshape(-1, d) 71 | ) 72 | 73 | slots = slots.reshape(b, -1, d) 74 | slots = slots + self.mlp(self.norm_pre_ff(slots)) 75 | 76 | slots = self.mlp_cast(slots) 77 | 78 | return slots 79 | -------------------------------------------------------------------------------- /synthetic/utilities/slot_attention_old.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | class MLP(nn.Module): 5 | def __init__(self, in_dim, out_dim): 6 | super().__init__() 7 | self.mlp = nn.Sequential(nn.Linear(in_dim, 64), 8 | nn.ReLU(), 9 | nn.Linear(64, out_dim)) 10 | def forward(self, x): 11 | return self.mlp(x) 12 | 13 | class SlotAttention(nn.Module): 14 | def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128, out_dim = 64, in_dim = 64): 15 | super().__init__() 16 | self.num_slots = num_slots 17 | self.iters = iters 18 | self.eps = eps 19 | self.scale = dim ** -0.5 20 | 21 | self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) 22 | self.slots_sigma = nn.Parameter(torch.randn(1, 1, dim)) 23 | 24 | self.to_q = nn.Linear(dim, dim) 25 | self.to_k = nn.Linear(in_dim, dim) 26 | self.to_v = nn.Linear(in_dim, dim) 27 | 28 | self.gru = nn.GRUCell(dim, dim) 29 | 30 | hidden_dim = max(dim, hidden_dim) 31 | 32 | self.mlp = nn.Sequential( 33 | nn.Linear(dim, hidden_dim), 34 | nn.ReLU(inplace = True), 35 | nn.Linear(hidden_dim, dim) 36 | ) 37 | 38 | 39 | self.mlp_cast = MLP(dim, out_dim) 40 | self.norm_input = nn.LayerNorm(in_dim) 41 | self.norm_slots = nn.LayerNorm(dim) 42 | self.norm_pre_ff = nn.LayerNorm(dim) 43 | 44 | def forward(self, inputs, slots, num_slots = None): 45 | 46 | b, n, d_ = inputs.shape 47 | n_s = num_slots if num_slots is not None else self.num_slots 48 | 49 | #mu = self.slots_mu.expand(b, n_s, -1) 50 | #sigma = self.slots_sigma.expand(b, n_s, -1) 51 | d = slots.size(-1) 52 | 53 | inputs = self.norm_input(inputs) 54 | k, v = self.to_k(inputs), self.to_v(inputs) 55 | 56 | for _ in range(self.iters): 57 | slots_prev = slots 58 | 59 | slots = self.norm_slots(slots) 60 | q = self.to_q(slots) 61 | 62 | dots = torch.einsum('bid,bjd->bij', q, k) * 
self.scale 63 | attn = dots.softmax(dim=1) + self.eps 64 | attn = attn / attn.sum(dim=-1, keepdim=True) 65 | 66 | updates = torch.einsum('bjd,bij->bid', v, attn) 67 | 68 | slots = self.gru( 69 | updates.reshape(-1, d), 70 | slots_prev.reshape(-1, d) 71 | ) 72 | 73 | slots = slots.reshape(b, -1, d) 74 | slots = slots + self.mlp(self.norm_pre_ff(slots)) 75 | 76 | slots = self.mlp_cast(slots) 77 | 78 | return slots 79 | -------------------------------------------------------------------------------- /synthetic/utilities/SharedGroupLinearLayer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from .GroupLinearLayer import GroupLinearLayer 5 | 6 | class SharedGroupLinearLayer(nn.Module): 7 | """All the parameters are shared using soft attention this layer is used for sharing Q,K,V parameters of MHA""" 8 | 9 | def __init__(self, din, dout, n_templates): 10 | super(SharedGroupLinearLayer, self).__init__() 11 | 12 | self.w = nn.ModuleList([nn.Linear(din, dout, bias = False) for _ in range(0,n_templates)]) 13 | self.gll_write = GroupLinearLayer(dout,16, n_templates) 14 | self.gll_read = GroupLinearLayer(din,16,1) 15 | #self.register_buffer(self.w) 16 | 17 | def forward(self,x): 18 | #input size (bs,num_blocks,din), required matching num_blocks vs n_templates 19 | bs_size = x.shape[0] 20 | k = x.shape[1] 21 | x= x.reshape(k*bs_size,-1) 22 | x_read = self.gll_read((x*1.0).reshape((x.shape[0], 1, x.shape[1]))) 23 | x_next = [] 24 | for mod in self.w: 25 | x_next_l = mod(x) 26 | x_next.append(x_next_l) 27 | x_next = torch.stack(x_next,1) #(k*bs,n_templates,dout) 28 | 29 | x_write = self.gll_write(x_next) 30 | sm = nn.Softmax(2) 31 | att = sm(torch.bmm(x_read, x_write.permute(0, 2, 1))) 32 | 33 | x_next = torch.bmm(att, x_next) 34 | 35 | x_next = x_next.mean(dim=1).reshape(bs_size,k,-1) 36 | 37 | return x_next 38 | 39 | class NonSharedGroupLinearLayer(nn.Module): 40 | """All the parameters are shared using soft attention this layer is used for sharing Q,K,V parameters of MHA""" 41 | 42 | def __init__(self, din, dout, n_templates): 43 | super(NonSharedGroupLinearLayer, self).__init__() 44 | 45 | self.w = nn.ModuleList([nn.Linear(din, dout) for _ in range(0,n_templates)]) 46 | #self.gll_write = GroupLinearLayer(dout, 16, n_templates) 47 | #self.gll_read = GroupLinearLayer(din, 16, num_blocks) 48 | #self.register_buffer(self.w) 49 | 50 | def forward(self, x, att): 51 | #input size (bs,num_blocks,din), required matching num_blocks vs n_templates 52 | b, k, d = x.size() 53 | 54 | x = x.reshape(b * k, -1) 55 | x_next = [] 56 | for mod in self.w: 57 | x_next_l = mod(x) 58 | x_next.append(x_next_l) 59 | 60 | 61 | x_next = torch.stack(x_next,1) #(k*bs,n_templates,dout) 62 | x_next = x_next.reshape(b, k, len(self.w), -1) 63 | x_next = x_next * att.unsqueeze(-1) 64 | x_next = torch.sum(x_next, dim = 2) 65 | 66 | return x_next 67 | 68 | 69 | 70 | 71 | if __name__ == "__main__": 72 | 73 | GLN = SharedGroupLinearLayer(25,22,6) 74 | 75 | x = torch.randn(64,12,25) 76 | 77 | print(GLN(x).shape) 78 | 79 | for p in GLN.parameters(): 80 | print(p.shape) 81 | 82 | -------------------------------------------------------------------------------- /MNIST/utilities/SharedGroupLinearLayer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from .GroupLinearLayer import GroupLinearLayer 5 | 6 | class SharedGroupLinearLayer(nn.Module): 7 | """All the parameters are 
shared using soft attention this layer is used for sharing Q,K,V parameters of MHA""" 8 | 9 | def __init__(self, din, dout, n_templates): 10 | super(SharedGroupLinearLayer, self).__init__() 11 | 12 | self.w = nn.ModuleList([nn.Linear(din, dout, bias = False) for _ in range(0,n_templates)]) 13 | self.gll_write = GroupLinearLayer(dout,16, n_templates) 14 | self.gll_read = GroupLinearLayer(din,16,1) 15 | #self.register_buffer(self.w) 16 | 17 | def forward(self,x): 18 | #input size (bs,num_blocks,din), required matching num_blocks vs n_templates 19 | bs_size = x.shape[0] 20 | k = x.shape[1] 21 | x= x.reshape(k*bs_size,-1) 22 | x_read = self.gll_read((x*1.0).reshape((x.shape[0], 1, x.shape[1]))) 23 | x_next = [] 24 | for mod in self.w: 25 | x_next_l = mod(x) 26 | x_next.append(x_next_l) 27 | x_next = torch.stack(x_next,1) #(k*bs,n_templates,dout) 28 | 29 | x_write = self.gll_write(x_next) 30 | sm = nn.Softmax(2) 31 | att = sm(torch.bmm(x_read, x_write.permute(0, 2, 1))) 32 | 33 | x_next = torch.bmm(att, x_next) 34 | 35 | x_next = x_next.mean(dim=1).reshape(bs_size,k,-1) 36 | 37 | return x_next 38 | 39 | class NonSharedGroupLinearLayer(nn.Module): 40 | """All the parameters are shared using soft attention this layer is used for sharing Q,K,V parameters of MHA""" 41 | 42 | def __init__(self, din, dout, n_templates): 43 | super(NonSharedGroupLinearLayer, self).__init__() 44 | 45 | self.w = nn.ModuleList([nn.Linear(din, dout) for _ in range(0,n_templates)]) 46 | #self.gll_write = GroupLinearLayer(dout, 16, n_templates) 47 | #self.gll_read = GroupLinearLayer(din, 16, num_blocks) 48 | #self.register_buffer(self.w) 49 | 50 | def forward(self, x, att): 51 | #input size (bs,num_blocks,din), required matching num_blocks vs n_templates 52 | b, k, d = x.size() 53 | 54 | x = x.reshape(b * k, -1) 55 | x_next = [] 56 | for mod in self.w: 57 | x_next_l = mod(x) 58 | x_next.append(x_next_l) 59 | 60 | 61 | x_next = torch.stack(x_next,1) #(k*bs,n_templates,dout) 62 | x_next = x_next.reshape(b, k, len(self.w), -1) 63 | x_next = x_next * att.unsqueeze(-1) 64 | x_next = torch.sum(x_next, dim = 2) 65 | 66 | return x_next 67 | 68 | 69 | 70 | if __name__ == "__main__": 71 | 72 | GLN = NonSharedGroupLinearLayer(25,22,6, 4) 73 | 74 | x = torch.randn(64,4,25) 75 | 76 | print(GLN(x).shape) 77 | 78 | for p in GLN.parameters(): 79 | print(p.shape) 80 | 81 | -------------------------------------------------------------------------------- /MNIST/crl/dataloader/mnist_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | 4 | from .base_dataloader import DataLoader 5 | 6 | class MNIST(DataLoader): 7 | def __init__(self, args): 8 | super(MNIST, self).__init__() 9 | self.train_loader = torch.utils.data.DataLoader( 10 | datasets.MNIST('../data/mnist', train=True, download=True, 11 | transform=transforms.Compose([ 12 | transforms.ToTensor(), 13 | transforms.Normalize((0.1307,), (0.3081,)) 14 | ])), 15 | batch_size=args.bsize, shuffle=True) 16 | self.curr = False 17 | 18 | def initialize_data(self, splits): 19 | pass 20 | 21 | def reset(self, mode='train', z=None): 22 | for i, (data, target) in enumerate(self.train_loader): 23 | return data.view(-1), target 24 | 25 | def get_trace(self): 26 | return '' 27 | 28 | def change_mt(self): 29 | pass 30 | 31 | def load_mnist_datasets(root, normalize=True, extrap=False): 32 | 33 | if normalize: 34 | train_dataset = datasets.MNIST(root, train=True, download=True, 35 | 
transform=transforms.Compose([ 36 | transforms.ToTensor(), 37 | transforms.Normalize((0.1307,), (0.3081,)) 38 | ])) 39 | valtest_dataset = datasets.MNIST(root, train=False, transform=transforms.Compose([ 40 | transforms.ToTensor(), 41 | transforms.Normalize((0.1307,), (0.3081,)) 42 | ])) 43 | else: 44 | train_dataset = datasets.MNIST(root, train=True, download=True, 45 | transform=transforms.Compose([ 46 | transforms.ToTensor(), 47 | ])) 48 | valtest_dataset = datasets.MNIST(root, train=False, transform=transforms.Compose([ 49 | transforms.ToTensor(), 50 | ])) 51 | 52 | train_data, train_labels = zip(*train_dataset) 53 | train_data = torch.stack(train_data) 54 | train_labels = torch.LongTensor(train_labels).unsqueeze(1) 55 | 56 | # now you should divide into groups 57 | 58 | numtest = int(len(valtest_dataset) / 2) 59 | valtest_data, valtest_labels = zip(*valtest_dataset) 60 | valtest_data = torch.stack(valtest_data) 61 | valtest_labels = torch.LongTensor(valtest_labels).unsqueeze(1) 62 | 63 | val_data = valtest_data[:numtest] 64 | val_labels = valtest_labels[:numtest] 65 | 66 | test_data = valtest_data[numtest:] 67 | test_labels = valtest_labels[numtest:] 68 | 69 | 70 | mnist_datasets = { 71 | 'train': (train_data, train_labels), 72 | 'val': (val_data, val_labels), 73 | 'test': (test_data, test_labels), 74 | } 75 | if extrap: 76 | mnist_datasets['extrapval'] = (test_data, test_labels) 77 | return mnist_datasets -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NPS_ICML 2 | 3 | ## Installation 4 | Our code was tested with Python 3.6. 5 | Use `pip install -r requirements.txt` to install all dependencies. 6 | 7 | 8 | ## Arithmetic Task 9 | Run the following command from the `synthetic` folder. 10 | ``` 11 | sh runner.sh num_rules rule_emb_dim embed_dim seed 12 | 13 | num_rules: Number of rules to use. Should always be 3 as there are 3 operations: {addition, subtraction, multiplication}. 14 | rule_emb_dim: Rule embedding dimension. 15 | embed_dim: Dimension to which the numbers are encoded. 16 | seed: Random seed. 17 | ``` 18 | 19 | To reproduce the experiments in the paper: 20 | ``` 21 | sh runner.sh 3 32 64 3 22 | ``` 23 | 24 | Expected output: 25 | 26 | ``` 27 | 0 : {'addition': 16640, 'subtraction': 0, 'multiplication': 0} 28 | 2 : {'addition': 0, 'subtraction': 0, 'multiplication': 16501} 29 | 1 : {'addition': 0, 'subtraction': 16856, 'multiplication': 0} 30 | ``` 31 | Here we can see that the rules have been completely segregated: rule number 0 is solely used for the addition operation, rule number 2 is solely used for the multiplication operation, and rule number 1 is solely used for the subtraction operation. You should also see a best eval mse of about: 32 | ``` 33 | best_eval_mse:tensor(0.0005, device='cuda:0') 34 | ``` 35 | Depending on the seed as well as the environment in which the code is run, the best eval mse may vary but should be somewhere around the above number. Below we show the best eval mse for 3 different seeds: 36 | ``` 37 | (1) best_eval_mse:tensor(0.0005, device='cuda:0') 38 | (2) best_eval_mse:tensor(0.0004, device='cuda:0') 39 | (3) best_eval_mse:tensor(0.0006, device='cuda:0') 40 | ``` 41 | 42 | To evaluate the model on multiple sequence lengths: 43 | ``` 44 | sh eval_runner.sh num_rules rule_emb_dim embed_dim seed 45 | ``` 46 | This will run the saved model corresponding to the args provided. 
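As a point of reference (based on how `runner.sh` and `eval_runner.sh` are set up): both scripts use a run directory named `<num_rules>-<rule_emb_dim>-<embed_dim>-<seed>-False`, which is passed as `--save_dir` and which also collects `train.log` / `eval.log`. A quick way to inspect a finished run, sketched here for the example arguments used in this README:
```
# run directory created by `sh runner.sh 3 32 64 3`
ls 3-32-64-3-False
# training and evaluation output are appended to these logs by the scripts
tail 3-32-64-3-False/train.log
tail 3-32-64-3-False/eval.log
```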
47 | 48 | To train and evaluate the model presented in the paper run: 49 | ``` 50 | sh runner.sh 3 32 64 3 51 | sh eval_runner.sh 3 32 64 3 52 | ``` 53 | 54 | The output of eval_runner.sh should be as follows: 55 | ``` 56 | FINAL RESULTS ACROSS VARIOUS SEQUENCE LENGTHS 57 | SEQUENCE LENGTH | MSE 58 | 10 | 0.0003 59 | 20 | 0.0005 60 | 30 | 0.0008 61 | 40 | 0.0014 62 | 50 | 0.0018 63 | ``` 64 | 65 | ## MNIST Transformation Task 66 | 67 | Run the following command from the `MNIST` folder. 68 | 69 | ``` 70 | sh run.sh seed 71 | ``` 72 | Expected output: 73 | After a few epochs a complete segregation of rules should be observed in the following manner: 74 | ``` 75 | rotate_left : {0: 0, 1: 0, 2: 0, 3: 4981} 76 | translate_up : {0: 0, 1: 4950, 2: 0, 3: 0} 77 | rotate_right : {0: 0, 1: 0, 2: 5030, 3: 0} 78 | translate_down : {0: 5039, 1: 0, 2: 0, 3: 0} 79 | 80 | ``` 81 | 82 | The above snippet indicates that rule number 3 is solely being used for the rotate left transformation, rule number 1 is being used for translate up, rule number 2 is being used for rotate right, and rule number 0 is being used for translate down. Hence a complete segregation is observed. 83 | 84 | -------------------------------------------------------------------------------- /MNIST/utils/sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler 3 | from typing import Iterator 4 | #from ..prelude import Array 5 | 6 | from typing import Any, Callable, Iterable, Sequence, Tuple, TypeVar, Union 7 | from utils.device import Device 8 | 9 | 10 | try: 11 | from typing import GenericMeta, NamedTupleMeta # type: ignore 12 | 13 | class GenericNamedMeta(NamedTupleMeta, GenericMeta): 14 | pass 15 | except ImportError: 16 | from typing import NamedTupleMeta # type: ignore 17 | GenericNamedMeta = NamedTupleMeta # type: ignore 18 | 19 | T = TypeVar('T') 20 | Self = Any 21 | 22 | 23 | class Array(Sequence[T]): 24 | @property 25 | def shape(self) -> tuple: 26 | ... 27 | 28 | def squeeze(self) -> Self: 29 | ... 30 | 31 | def transpose(self, *args) -> Self: 32 | ... 33 | 34 | def __rsub__(self, value: Any) -> Self: 35 | ... 36 | 37 | def __truediv__(self, rvalue: Any) -> Self: 38 | ... 
39 | 40 | 41 | 42 | def sample_indices(n: int, k: int) -> np.ndarray: 43 | """Sample k numbers from [0, n) 44 | Based on https://github.com/chainer/chainerrl/blob/master/chainerrl/misc/random.py 45 | """ 46 | if 3 * k >= n: 47 | return np.random.choice(n, k, replace=False) 48 | else: 49 | selected = np.repeat(False, n) 50 | rands = np.random.randint(0, n, size=k * 2) 51 | j = k 52 | for i in range(k): 53 | x = rands[i] 54 | while selected[x]: 55 | if j == 2 * k: 56 | rands[k:] = np.random.randint(0, n, size=k) 57 | j = k 58 | x = rands[i] = rands[j] 59 | j += 1 60 | selected[x] = True 61 | return rands[:k] 62 | 63 | 64 | class FeedForwardBatchSampler(BatchSampler): 65 | def __init__(self, nsteps: int, nworkers: int, batch_size: int) -> None: 66 | super().__init__( 67 | SubsetRandomSampler(range(nsteps * nworkers)), 68 | batch_size=batch_size, 69 | drop_last=True 70 | ) 71 | 72 | 73 | class RecurrentBatchSampler(Sampler): 74 | def __init__(self, nsteps: int, nworkers: int, batch_size: int) -> None: 75 | if batch_size % nsteps > 0: 76 | raise ValueError('batch_size must be a multiple of nsteps') 77 | self.nsteps = nsteps 78 | self.nworkers = nworkers 79 | self.batch_size = batch_size 80 | 81 | def __iter__(self) -> Iterator[Array[int]]: 82 | env_num = self.batch_size // self.nsteps 83 | total, step = self.nsteps * self.nworkers, self.nworkers 84 | perm = np.random.permutation(self.nworkers) 85 | for end in np.arange(env_num, self.nworkers + 1, env_num): 86 | workers = perm[end - env_num: end] 87 | batches = np.stack([np.arange(w, total, step) for w in workers], axis=1) 88 | yield batches.flatten() 89 | 90 | def __len__(self) -> int: 91 | return (self.nsteps * self.nworkers) // self.batch_size 92 | 93 | -------------------------------------------------------------------------------- /MNIST/utilities/layer_conn_attention.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .attention import ScaledDotProductAttention 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | class LayerConnAttention(nn.Module): 10 | ''' Multi-Head Attention module ''' 11 | 12 | def __init__(self, n_head, d_model, d_k, d_v, d_out, dropout=0.1): 13 | super().__init__() 14 | 15 | self.n_head = n_head 16 | self.d_k = d_k 17 | self.d_v = d_v 18 | 19 | self.w_qs = nn.Linear(d_model, n_head * d_k) 20 | self.w_ks = nn.Linear(d_model, n_head * d_k) 21 | self.w_vs = nn.Linear(d_model, n_head * d_v) 22 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 23 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 24 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 25 | 26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 27 | self.layer_norm = nn.LayerNorm(d_model) 28 | 29 | self.gate_fc = nn.Linear(n_head * d_v, d_out) 30 | self.fc = nn.Linear(n_head * d_v, d_out) 31 | nn.init.xavier_normal_(self.fc.weight) 32 | 33 | self.dropout = nn.Dropout(dropout) 34 | 35 | 36 | def forward(self, q, k, v, mask=None): 37 | 38 | #print('attn input shape', q.shape) 39 | 40 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 41 | 42 | sz_b, len_q, _ = q.size() 43 | sz_b, len_k, _ = k.size() 44 | sz_b, len_v, _ = v.size() 45 | 46 | residual = q 47 | 48 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 49 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 50 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 51 | #v = v.view(sz_b, 
len_v, n_head, d_v) 52 | 53 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 54 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 55 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 56 | 57 | #mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 58 | output, attn, extra_loss = self.attention(q, k, v, mask=None) 59 | 60 | output = output.view(n_head, sz_b, len_q, d_v) 61 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 62 | 63 | #print('output shape before fc', output.shape) 64 | 65 | #TODO: probably shouldn't just apply residual layer in the forward pass. 66 | 67 | output_init = output*1.0 68 | 69 | #output = self.dropout(self.fc(output_init)) 70 | output = self.dropout(output_init) 71 | 72 | #gate = F.sigmoid(self.gate_fc(output_init)) 73 | 74 | #output = self.layer_norm(gate * output + (1 - gate) * residual) 75 | #output = gate * output + (1 - gate) * residual 76 | 77 | #output = residual + gate * F.tanh(output) 78 | 79 | #output 80 | 81 | #print('attn', attn[0]) 82 | #print('output input diff', output - residual) 83 | 84 | return output, attn, extra_loss 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /synthetic/utilities/layer_conn_attention.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from .attention import ScaledDotProductAttention 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | class LayerConnAttention(nn.Module): 10 | ''' Multi-Head Attention module ''' 11 | 12 | def __init__(self, n_head, d_model, d_k, d_v, d_out, dropout=0.1): 13 | super().__init__() 14 | 15 | self.n_head = n_head 16 | self.d_k = d_k 17 | self.d_v = d_v 18 | 19 | self.w_qs = nn.Linear(d_model, n_head * d_k) 20 | self.w_ks = nn.Linear(d_model, n_head * d_k) 21 | self.w_vs = nn.Linear(d_model, n_head * d_v) 22 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 23 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 24 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 25 | 26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 27 | self.layer_norm = nn.LayerNorm(d_model) 28 | 29 | self.gate_fc = nn.Linear(n_head * d_v, d_out) 30 | self.fc = nn.Linear(n_head * d_v, d_out) 31 | nn.init.xavier_normal_(self.fc.weight) 32 | 33 | self.dropout = nn.Dropout(dropout) 34 | 35 | 36 | def forward(self, q, k, v, mask=None): 37 | 38 | #print('attn input shape', q.shape) 39 | 40 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 41 | 42 | sz_b, len_q, _ = q.size() 43 | sz_b, len_k, _ = k.size() 44 | sz_b, len_v, _ = v.size() 45 | 46 | residual = q 47 | 48 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 49 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 50 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 51 | #v = v.view(sz_b, len_v, n_head, d_v) 52 | 53 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 54 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 55 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 56 | 57 | #mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
58 | output, attn, extra_loss = self.attention(q, k, v, mask=None) 59 | 60 | output = output.view(n_head, sz_b, len_q, d_v) 61 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 62 | 63 | #print('output shape before fc', output.shape) 64 | 65 | #TODO: probably shouldn't just apply residual layer in the forward pass. 66 | 67 | output_init = output*1.0 68 | 69 | #output = self.dropout(self.fc(output_init)) 70 | output = self.dropout(output_init) 71 | 72 | #gate = F.sigmoid(self.gate_fc(output_init)) 73 | 74 | #output = self.layer_norm(gate * output + (1 - gate) * residual) 75 | #output = gate * output + (1 - gate) * residual 76 | 77 | #output = residual + gate * F.tanh(output) 78 | 79 | #output 80 | 81 | #print('attn', attn[0]) 82 | #print('output input diff', output - residual) 83 | 84 | return output, attn, extra_loss 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /MNIST/crl/dataloader/arithmetic.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import regex as re 3 | from collections import OrderedDict 4 | import datautils as du 5 | import numpy as np 6 | 7 | class Operator(object): 8 | def __init__(self): 9 | super(Operator, self).__init__() 10 | self.op_string = None 11 | self.matcher = None 12 | self.operator = None 13 | 14 | def transform(self, x, idx): 15 | raise NotImplementedError 16 | 17 | class AttentionOperator(Operator): 18 | def __init__(self): 19 | super(AttentionOperator, self).__init__() 20 | # inherited: self.op_string, self.matcher, self.operator 21 | self.all_op_string = '[\+\-\*\/]' 22 | self.all_operator_matcher = re.compile(self.all_op_string) 23 | self.all_operators = OrderedDict([ 24 | ('+', Plus), 25 | ('*', Multiply), 26 | ('-', Minus), 27 | ('/', Divide) 28 | ]) 29 | self.static_operator_dict = OrderedDict([ 30 | (operator.add, '+'), 31 | (operator.mul, '*'), 32 | (operator.sub, '-'), 33 | (operator.div, '/') 34 | ]) 35 | 36 | def evaluate_subexp(self, x, idx=-1): 37 | terms, ops = du.extract_terms_ops(x, self.all_operator_matcher, self.all_operators) # can be affected by np.inf 38 | if idx==-1 and self.operator in ops: 39 | idx = ops.index(self.operator) # naive default 40 | if ops and idx < len(ops) and ops[idx] == self.operator: 41 | evaluated_subexp = self.operator(terms[idx], terms[idx+1]) 42 | new_terms = terms[:idx] + [evaluated_subexp] + terms[idx+2:] 43 | new_ops = ops[:idx] + ops[idx+1:] 44 | return evaluated_subexp, new_terms, new_ops 45 | else: 46 | return None, None, None 47 | 48 | def get_new_exp_str(self, x, evaluated_subexp, new_terms, new_ops): 49 | if evaluated_subexp is not None: # if evaluated_subexp actually does not work here 50 | new_exp_str = du.build_expression_string(new_terms, new_ops, self.static_operator_dict) 51 | return new_exp_str, evaluated_subexp 52 | else: 53 | return x, None 54 | 55 | def transform(self, x, idx=-1): 56 | evaluated_subexp, new_terms, new_ops = self.evaluate_subexp(x, idx) 57 | new_exp_str, evaluated_subexp = self.get_new_exp_str(x, evaluated_subexp, new_terms, new_ops) 58 | # evaluated_subexp is not actually modified 59 | return new_exp_str, evaluated_subexp 60 | 61 | class Plus(AttentionOperator): 62 | def __init__(self): 63 | super(Plus, self).__init__() 64 | self.op_string = '+' 65 | self.operator = operator.add 66 | 67 | class Minus(AttentionOperator): 68 | def __init__(self): 69 | super(Minus, self).__init__() 70 | self.op_string = '-' 71 | self.operator = 
operator.sub 72 | 73 | class Multiply(AttentionOperator): 74 | def __init__(self): 75 | super(Multiply, self).__init__() 76 | self.op_string = '*' 77 | self.operator = operator.mul 78 | 79 | class Divide(AttentionOperator): 80 | def __init__(self): 81 | super(Divide, self).__init__() 82 | self.op_string = '/' 83 | self.operator = operator.div -------------------------------------------------------------------------------- /MNIST/crl/dataloader/modulo_datagen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import operator 4 | import matplotlib 5 | matplotlib.use('agg') 6 | import matplotlib.pyplot as plt 7 | import pprint 8 | import copy 9 | 10 | from collections import OrderedDict 11 | 12 | from arithmetic import Plus, Minus, Multiply 13 | import utils 14 | import datautils as du 15 | 16 | from datagen import ArithmeticDataGenerator 17 | 18 | np.random.seed(0) 19 | 20 | class ModuloDataGenerator(ArithmeticDataGenerator): 21 | def __init__(self, ops, numrange): 22 | super(ModuloDataGenerator, self).__init__(ops, numrange) 23 | self.range_length = max(self.range) 24 | self.modulus = self.range_length 25 | self.encoding_length = self.range_length + self.op_length 26 | 27 | def sample_next_term(self, op, first): 28 | if self.operator_dict[op] in '+-*': 29 | next_term = du.sample_term_in_range(self.range) 30 | else: 31 | assert False 32 | return next_term 33 | 34 | def create_problem(self, max_terms): 35 | num_ops = max_terms - 1 36 | ops = [self._sample_operator() for i in xrange(num_ops)] 37 | terms = [du.sample_term_in_range(self.range)] 38 | for i in xrange(num_ops): 39 | terms.append(self.sample_next_term(ops[i], terms[i-1])) 40 | exp_val = self.evaluate_expression(terms, ops) 41 | exp_str = du.build_expression_string(terms, ops, self.operator_dict) 42 | if self.verbose: print 'Final Expression: {} = {}'.format(exp_str, exp_val) 43 | return exp_str, exp_val, terms, ops 44 | 45 | def evaluate_expression(self, terms, ops): 46 | terms = copy.deepcopy(terms) 47 | ops = copy.deepcopy(ops) 48 | # first find all the ops that are * and / 49 | multiplicative_ops_indices = filter(lambda x: self.operator_dict[ops[x]] in '*/', range(len(ops))) 50 | 51 | while len(multiplicative_ops_indices) > 0: 52 | # evaluate multiplication 53 | multiplicative_op_index = multiplicative_ops_indices[0] # get first multiplicative index 54 | surrounding_terms = terms[multiplicative_op_index], terms[multiplicative_op_index+1] 55 | multiplicative_val = ops[multiplicative_op_index](surrounding_terms[0], surrounding_terms[1]) % self.modulus # modulo arithmetic 56 | # update ops and terms 57 | ops = ops[:multiplicative_op_index] + ops[multiplicative_op_index+1:] 58 | terms = terms[:multiplicative_op_index] + [multiplicative_val] + terms[multiplicative_op_index+2:] 59 | # check if there are more multiplicative terms 60 | multiplicative_ops_indices = filter(lambda x: self.operator_dict[ops[x]] in '*/', range(len(ops))) 61 | 62 | # at this point there are no multiplicative ops 63 | result = self._fold_left_ops_terms_eval(ops, terms)[-1] % self.modulus 64 | return result 65 | 66 | def test_generate_data(): 67 | dg = ModuloDataGenerator('+-*', [0,100]) 68 | for i in range(1000): 69 | print '\n*******' 70 | exp_str, exp_val, all_terms, all_ops = dg.create_problem(3) 71 | assert exp_val == dg.evaluate_expression(all_terms, all_ops) 72 | print '{}={}'.format(exp_str, exp_val) 73 | 74 | if __name__ == '__main__': 75 | test_generate_data() 76 | 77 | 
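# Worked example of evaluate_expression above (a sketch, assuming the inherited
# _fold_left_ops_terms_eval from datagen.py, not shown here, returns the running
# left-fold of the remaining additive ops): with modulus 100 and the expression
# 7 + 5 * 30 - 2, the single multiplicative op is reduced first,
# 5 * 30 = 150 -> 150 % 100 = 50, leaving terms [7, 50, 2] and ops [+, -];
# folding left gives 7 + 50 = 57 and 57 - 2 = 55, so the result is 55 % 100 = 55.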
-------------------------------------------------------------------------------- /MNIST/utils/cli.py: -------------------------------------------------------------------------------- 1 | import click 2 | from typing import Callable, Optional, Tuple 3 | 4 | from .log import ExperimentLog 5 | from ..agents import Agent 6 | from ..config import Config 7 | from ..run import eval_agent, train_agent, random_agent, SAVE_FILE_DEFAULT 8 | 9 | 10 | @click.group() 11 | @click.option('--gpu', required=False, type=int) 12 | @click.option('--seed', type=int, default=None) 13 | @click.pass_context 14 | def rainy_cli(ctx: dict, gpu: Tuple[int], seed: Optional[int]) -> None: 15 | ctx.obj['gpu'] = gpu 16 | ctx.obj['config'].seed = seed 17 | 18 | 19 | @rainy_cli.command() 20 | @click.pass_context 21 | @click.option('--comment', type=str, default=None) 22 | @click.option('--prefix', type=str, default='') 23 | def train(ctx: dict, comment: Optional[str], prefix: str) -> None: 24 | c = ctx.obj['config'] 25 | scr = ctx.obj['script_path'] 26 | if scr: 27 | c.logger.set_dir_from_script_path(scr, comment=comment, prefix=prefix) 28 | c.logger.set_stderr() 29 | ag = ctx.obj['make_agent'](c) 30 | train_agent(ag) 31 | print("random play: {}, trained: {}".format(ag.random_episode(), ag.eval_episode())) 32 | 33 | 34 | @rainy_cli.command() 35 | @click.option('--save', is_flag=True) 36 | @click.option('--render', is_flag=True) 37 | @click.option('--replay', is_flag=True) 38 | @click.option('--action-file', type=str, default='random-actions.json') 39 | @click.pass_context 40 | def random(ctx: dict, save: bool, render: bool, replay: bool, action_file: str) -> None: 41 | c = ctx.obj['config'] 42 | ag = ctx.obj['make_agent'](c) 43 | action_file = fname if save else None 44 | random_agent(ag, render=render, replay=replay, action_file=action_file) 45 | 46 | 47 | @rainy_cli.command() 48 | @click.argument('logdir', required=True, type=str) 49 | @click.option('--model', type=str, default=SAVE_FILE_DEFAULT) 50 | @click.option('--render', is_flag=True) 51 | @click.option('--replay', is_flag=True) 52 | @click.option('--action-file', type=str, default='best-actions.json') 53 | @click.pass_context 54 | def eval(ctx: dict, logdir: str, model: str, render: bool, replay: bool, action_file: str) -> None: 55 | c = ctx.obj['config'] 56 | ag = ctx.obj['make_agent'](c) 57 | eval_agent( 58 | ag, 59 | logdir, 60 | load_file_name=model, 61 | render=render, 62 | replay=replay, 63 | action_file=action_file 64 | ) 65 | 66 | 67 | @rainy_cli.command() 68 | @click.option('--log-dir', type=str) 69 | @click.option('--vi-mode', is_flag=True) 70 | @click.pass_context 71 | def ipython(ctx: dict, log_dir: Optional[str], vi_mode: bool) -> None: 72 | config, make_agent = ctx.obj['config'], ctx.obj['make_agent'] # noqa 73 | if log_dir is not None: 74 | log = ExperimentLog(log_dir) # noqa 75 | else: 76 | open_log = ExperimentLog # noqa 77 | try: 78 | from ptpython.ipython import embed 79 | del ctx, log_dir 80 | import rainy # noqa 81 | embed(vi_mode=vi_mode) 82 | except ImportError: 83 | print("To use ipython mode, install ipython and ptpython first.") 84 | 85 | 86 | def run_cli( 87 | config: Config, 88 | agent_gen: Callable[[Config], Agent], 89 | script_path: Optional[str] = None 90 | ) -> rainy_cli: 91 | obj = { 92 | 'config': config, 93 | 'make_agent': agent_gen, 94 | 'script_path': script_path 95 | } 96 | rainy_cli(obj=obj) 97 | 98 | -------------------------------------------------------------------------------- /synthetic/model.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import random 4 | import time 5 | #from modularity import RIM, SCOFF, RIMv2, SCOFFv2 6 | from RuleNetwork import RuleNetwork 7 | 8 | class MLP(nn.Module): 9 | def __init__(self, in_dim, out_dim, intermediate_dim = 32): 10 | super().__init__() 11 | self.mlp = nn.Sequential( 12 | nn.Linear(in_dim, intermediate_dim), 13 | nn.ReLU(), 14 | #nn.Linear(intermediate_dim, intermediate_dim), 15 | #nn.ReLU(), 16 | #nn.Linear(intermediate_dim, intermediate_dim), 17 | #nn.ReLU(), 18 | nn.Linear(intermediate_dim, out_dim) 19 | ) 20 | 21 | def forward(self, x): 22 | return self.mlp(x) 23 | 24 | 25 | 26 | class ArithmeticModel(nn.Module): 27 | def __init__(self, args, n_tokens = 10): 28 | super().__init__() 29 | self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 30 | self.algo = args.algo 31 | self.hidden_dim = args.nhid 32 | self.num_blocks = args.num_blocks 33 | self.num_layers = args.nlayers 34 | 35 | 36 | self.encoder = MLP(2, self.hidden_dim, intermediate_dim = 64) 37 | self.encoder_operation = MLP(3, self.hidden_dim, intermediate_dim = 64) 38 | 39 | 40 | 41 | self.application_option = args.application_option 42 | self.num_rules = args.num_rules 43 | self.design_config = {'comm': True, 'grad': False, 44 | 'transformer': True, 'application_option': '3.0.1.0', 'selection': 'gumble'} 45 | 46 | if self.num_rules > 0: 47 | self.rule_network = RuleNetwork(self.hidden_dim, 3, num_transforms = 3, num_rules = args.num_rules, rule_dim = args.rule_emb_dim, query_dim = 32, value_dim = 64, key_dim = 32, num_heads = 4, dropout = 0.1, design_config = self.design_config) 48 | self.decoder = MLP(3 * self.hidden_dim, 1, intermediate_dim = 64) 49 | else: 50 | self.operation_encoder = nn.Sequential(nn.Linear(3, self.hidden_dim), 51 | nn.ReLU(), 52 | nn.Linear(self.hidden_dim, self.hidden_dim)) 53 | self.decoder = MLP(2 * self.hidden_dim + 3, 1, intermediate_dim = 64) 54 | 55 | 56 | def forward(self, x_prev, x_cur, operation): 57 | x_prev = self.encoder(torch.cat([x_prev, torch.zeros([x_prev.shape[0], 1]).cuda()], dim=1)) 58 | x_cur = self.encoder(torch.cat([x_cur, torch.ones([x_cur.shape[0], 1]).cuda()], dim=1)) 59 | operation_rep = self.encoder_operation(operation) 60 | #import ipdb 61 | #ipdb.set_trace() 62 | 63 | 64 | 65 | if self.num_rules > 0: 66 | rule_out, _ = self.rule_network(torch.cat([x_prev.unsqueeze(1), x_cur.unsqueeze(1), operation_rep.unsqueeze(1)],dim=1), message_to_rule_network = None) #operation_rep) 67 | rule_out = rule_out.squeeze(1) 68 | out = rule_out 69 | #intermediate_rep = torch.cat((x_prev, x_cur, rule_out), dim = 1) 70 | else: 71 | #operation = self.operation_encoder(operation) 72 | intermediate_rep = torch.cat((x_prev, x_cur, operation), dim = 1) 73 | out = self.decoder(intermediate_rep) 74 | 75 | 76 | 77 | return out 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /MNIST/init.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Iterable, Optional, Union 3 | import torch 4 | from torch import nn, Tensor 5 | 6 | InitFn = partial 7 | 8 | 9 | def uniform(mean: float = 0., var: float = 1.) -> InitFn: 10 | return partial(nn.init.uniform_, a=mean, b=var) 11 | 12 | 13 | def orthogonal(gain: float = 1.) 
-> InitFn: 14 | return partial(nn.init.orthogonal_, gain=gain) 15 | 16 | 17 | def kaiming_normal(nonlinearity: str = 'relu') -> InitFn: 18 | return partial(nn.init.kaiming_normal_, nonlinearity=nonlinearity) 19 | 20 | 21 | def kaiming_uniform(nonlinearity: str = 'relu') -> InitFn: 22 | return partial(nn.init.kaiming_uniform_, nonlinearity=nonlinearity) 23 | 24 | 25 | def constant(val: float) -> InitFn: 26 | return partial(nn.init.constant_, val=val) 27 | 28 | 29 | def zero() -> InitFn: 30 | return partial(nn.init.constant_, val=0) 31 | 32 | 33 | def lstm_bias(forget: float = 1.0, other: float = 0.0) -> InitFn: 34 | """Set forget bias and others separately. 35 | """ 36 | def __set_bias(t: Tensor) -> None: 37 | with torch.no_grad(): 38 | i = len(t) // 4 39 | t.fill_(other) 40 | t[i:2 * i].fill_(forget) 41 | return partial(__set_bias) 42 | 43 | 44 | class Initializer: 45 | """Utility Class to initialize weight parameters of NN 46 | """ 47 | def __init__( 48 | self, 49 | nonlinearity: Optional[str] = None, 50 | weight_init: InitFn = orthogonal(), 51 | bias_init: InitFn = zero(), 52 | scale: float = 1., 53 | ) -> None: 54 | """If nonlinearity is specified, use orthogonal 55 | with calucurated gain by torch.init.calculate_gain. 56 | """ 57 | self.weight_init = weight_init 58 | if nonlinearity is not None: 59 | if 'gain' in self.weight_init.keywords: 60 | self.weight_init.keywords['gain'] = nn.init.calculate_gain(nonlinearity) 61 | elif 'nonlinearity' in self.weight_init.keywords: 62 | self.weight_init.keywords['nonlinearity'] = nonlinearity 63 | else: 64 | raise ValueError('{} doesn\'t have gain', self.weight_init) 65 | self.bias_init = bias_init 66 | self.scale = scale 67 | 68 | def __call__(self, mod: Union[nn.Module, nn.Sequential, Iterable[nn.Module]]) -> nn.Module: 69 | return self.__init_dispatch(mod) 70 | 71 | def make_list(self, *args) -> nn.ModuleList: 72 | return nn.ModuleList([self.__init_dispatch(mod) for mod in args]) 73 | 74 | def make_seq(self, *args) -> nn.Sequential: 75 | return nn.Sequential(*map(lambda mod: self.__init_dispatch(mod), args)) 76 | 77 | def __init_dispatch(self, mod: nn.Module) -> nn.Module: 78 | if isinstance(mod, nn.Sequential) or isinstance(mod, nn.ModuleList): 79 | for child in mod.children(): 80 | self.__init_dispatch(child) 81 | else: 82 | self.__init_mod(mod) 83 | return mod 84 | 85 | def __init_mod(self, mod: nn.Module) -> nn.Module: 86 | if isinstance(mod, nn.BatchNorm2d): 87 | return self.__init_batch_norm(mod) 88 | for name, param in mod.named_parameters(): 89 | if 'weight' in name: 90 | self.weight_init(param) 91 | elif 'bias' in name: 92 | self.bias_init(param) 93 | return mod 94 | 95 | def __init_batch_norm(self, mod: nn.BatchNorm2d) -> nn.Module: 96 | mod.weight.data.fill_(1) 97 | mod.bias.data.zero_() 98 | return mod 99 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | ap==0.1.4 3 | astor==0.8.1 4 | astunparse==1.6.3 5 | atari-py==0.2.6 6 | attrs==20.2.0 7 | pytube3==9.6.4 8 | av==8.0.2 9 | backcall==0.2.0 10 | botocore==1.19.13 11 | cachetools==4.1.1 12 | certifi==2020.6.20 13 | cffi==1.14.4 14 | chardet==3.0.4 15 | click==7.1.2 16 | cloudpickle==1.3.0 17 | cmake==3.18.2.post1 18 | configparser==5.0.1 19 | contextlib2==0.6.0.post1 20 | cycler==0.10.0 21 | Cython==0.29.21 22 | dataclasses==0.7 23 | decorator==4.4.2 24 | dill==0.3.3 25 | docker-pycreds==0.4.0 26 | dominate==2.6.0 27 | 
easydict==1.9 28 | einops==0.3.0 29 | enum34==1.1.10 30 | flake8==3.8.4 31 | flake8-import-order==0.18.1 32 | flatbuffers==1.12 33 | flax==0.3.0 34 | future==0.18.2 35 | gast==0.3.3 36 | gitdb==4.0.5 37 | GitPython==3.1.9 38 | google-auth==1.22.1 39 | google-auth-oauthlib==0.4.1 40 | google-pasta==0.2.0 41 | googleapis-common-protos==1.52.0 42 | grpcio==1.32.0 43 | gym==0.15.7 44 | h5py==2.10.0 45 | idna==2.10 46 | imageio==2.9.0 47 | imageio-ffmpeg==0.4.2 48 | importlib-metadata==2.0.0 49 | importlib-resources==4.1.1 50 | iniconfig==1.1.1 51 | ipdb==0.13.4 52 | ipython==7.16.1 53 | ipython-genutils==0.2.0 54 | jax==0.2.7 55 | jaxlib==0.1.57 56 | jedi==0.17.2 57 | jmespath==0.10.0 58 | joblib==0.17.0 59 | jsonpatch==1.26 60 | jsonpointer==2.0 61 | Keras-Applications==1.0.8 62 | Keras-Preprocessing==1.1.2 63 | kiwisolver==1.2.0 64 | kornia==0.4.1 65 | lmdb==1.0.0 66 | Markdown==3.3.3 67 | matplotlib==3.3.2 68 | mccabe==0.6.1 69 | ml-collections==0.1.0 70 | motmetrics==1.2.0 71 | moviepy==1.0.3 72 | msgpack==1.0.2 73 | networkx==2.5 74 | nibabel==3.1.1 75 | numpy==1.19.2 76 | oauthlib==3.1.0 77 | onnx==1.8.0 78 | opencv-python==4.4.0.44 79 | opt-einsum==3.3.0 80 | packaging==20.4 81 | pandas==1.1.2 82 | parso==0.7.1 83 | pathtools==0.1.2 84 | pexpect==4.8.0 85 | pickleshare==0.7.5 86 | Pillow==7.2.0 87 | pluggy==0.13.1 88 | portalocker==2.0.0 89 | proglog==0.1.9 90 | promise==2.3 91 | prompt-toolkit==3.0.8 92 | protobuf==3.13.0 93 | psutil==5.7.2 94 | ptyprocess==0.6.0 95 | py==1.9.0 96 | py-cpuinfo==7.0.0 97 | pyasn1==0.4.8 98 | pyasn1-modules==0.2.8 99 | pycodestyle==2.6.0 100 | pycparser==2.20 101 | pyflakes==2.2.0 102 | pyglet==1.5.0 103 | Pygments==2.7.1 104 | pyparsing==2.4.7 105 | PyPrind==2.11.2 106 | pytest==6.1.2 107 | pytest-benchmark==3.2.3 108 | python-box==5.1.1 109 | python-dateutil==2.8.1 110 | pytz==2020.1 111 | PyWavelets==1.1.1 112 | PyYAML==5.3.1 113 | pyzmq==19.0.2 114 | recordclass==0.14.3 115 | regex==2020.11.13 116 | requests==2.24.0 117 | requests-oauthlib==1.3.0 118 | rsa==4.6 119 | sacrebleu==1.4.14 120 | sacremoses==0.0.43 121 | scikit-image==0.15.0 122 | scikit-learn==0.23.2 123 | scipy==1.1.0 124 | seaborn==0.11.0 125 | sentry-sdk==0.19.1 126 | shortuuid==1.0.1 127 | six==1.15.0 128 | slot-attention==1.0.1 129 | smmap==3.0.4 130 | subprocess32==3.5.4 131 | tensorboard==2.4.0 132 | tensorboard-plugin-wit==1.7.0 133 | tensorboardX==2.1 134 | tensorflow-datasets==4.1.0 135 | tensorflow-estimator==2.4.0 136 | tensorflow-metadata==0.26.0 137 | termcolor==1.1.0 138 | threadpoolctl==2.1.0 139 | tifffile==2020.9.3 140 | toml==0.10.1 141 | torch==1.6.0 142 | torchfile==0.1.0 143 | torchnet==0.0.4 144 | -e git+https://github.com/ncullen93/torchsample.git@ea4d1b3975f68be0521941e733887ed667a1b46e#egg=torchsample 145 | torchvision==0.7.0 146 | tornado==6.0.4 147 | tqdm==4.50.0 148 | traitlets==4.3.3 149 | typing-extensions==3.7.4.3 150 | urllib3==1.25.11 151 | visdom==0.1.8.9 152 | wandb==0.10.7 153 | watchdog==0.10.3 154 | wcwidth==0.2.5 155 | websocket-client==0.57.0 156 | Werkzeug==1.0.1 157 | wrapt==1.12.1 158 | xmltodict==0.12.0 159 | yacs==0.1.8 160 | zipp==3.4.0 161 | -------------------------------------------------------------------------------- /MNIST/logbook/util.py: -------------------------------------------------------------------------------- 1 | """Logging functions to write to disk""" 2 | import json 3 | import logging 4 | import time 5 | 6 | from utils.util import flatten_dict 7 | 8 | 9 | def _format_log(log): 10 | """format logs""" 11 | log = 
_add_time_to_log(log) 12 | return json.dumps(log) 13 | 14 | 15 | def write_log(log): 16 | """This is the default method to write a log. 17 | It is assumed that the log has already been processed 18 | before feeding to this method""" 19 | get_logger().info(log) 20 | 21 | 22 | def _add_time_to_log(log): 23 | log["timestamp"] = time.strftime('%I:%M%p %Z %b %d, %Y') 24 | return log 25 | 26 | 27 | def read_log(log): 28 | """This is the single point to read any log message from the file. 29 | All the log messages are persisted as jsons""" 30 | try: 31 | data = json.loads(log) 32 | except json.JSONDecodeError as _: 33 | data = { 34 | } 35 | if data["type"] == "print": 36 | data = { 37 | 38 | } 39 | return data 40 | 41 | 42 | def _format_custom_logs(keys, raw_log, _type="metric"): 43 | """Method to format the custom logs""" 44 | log = {} 45 | if keys: 46 | for key in keys: 47 | if key in raw_log: 48 | log[key] = raw_log[key] 49 | else: 50 | log = raw_log 51 | log["type"] = _type 52 | return _format_log(log), log 53 | 54 | 55 | def write_message_logs(message, experiment_id=0): 56 | """"Write message logs""" 57 | kwargs = {"messgae": message, "experiment_id": experiment_id} 58 | log, _ = _format_custom_logs(keys=[], raw_log=kwargs, _type="print") 59 | write_log(log) 60 | 61 | 62 | def write_trajectory_logs(trajectory, experiment_id=0): 63 | """"Write message logs""" 64 | kwargs = {"message": trajectory, "experiment_id": experiment_id} 65 | log, _ = _format_custom_logs(keys=[], raw_log=kwargs, _type="trajectory") 66 | write_log(log) 67 | 68 | 69 | def write_config_log(config): 70 | """Write config logs""" 71 | config_to_write = json.loads(config.to_json()) 72 | log, _ = _format_custom_logs(keys=[], raw_log=config_to_write, _type="config") 73 | write_log(log) 74 | 75 | 76 | def write_metric_logs(metric): 77 | """Write metric logs""" 78 | keys = [] 79 | log, _ = _format_custom_logs(keys=keys, raw_log=flatten_dict(metric), _type="metric") 80 | write_log(log) 81 | 82 | 83 | def write_metadata_logs(**kwargs): 84 | """Write metadata logs""" 85 | log, _ = _format_custom_logs(keys=["best_epoch_index"], raw_log=kwargs, _type="metadata") 86 | write_log(log) 87 | 88 | 89 | def pprint(config): 90 | """pretty print""" 91 | print(json.dumps(config, indent=4)) 92 | 93 | 94 | def set_logger(config): 95 | """Modified from 96 | https://docs.python.org/3/howto/logging-cookbook.html""" 97 | logger = logging.getLogger("default_logger") 98 | logger.setLevel(logging.INFO) 99 | # create file handler which logs all the messages 100 | logger_file_path = "{}/{}".format(config.folder_log, "log.txt") 101 | file_handler = logging.FileHandler(logger_file_path) 102 | file_handler.setLevel(logging.INFO) 103 | # create console handler with a higher log level 104 | stream_handler = logging.StreamHandler() 105 | stream_handler.setLevel(logging.INFO) 106 | # create formatter and add it to the handlers 107 | formatter = logging.Formatter('%(message)s') 108 | file_handler.setFormatter(formatter) 109 | stream_handler.setFormatter(formatter) 110 | # add the handlers to the logger 111 | logger.addHandler(file_handler) 112 | logger.addHandler(stream_handler) 113 | return logger 114 | 115 | 116 | def get_logger(): 117 | """get logger""" 118 | return logging.getLogger("default_logger") -------------------------------------------------------------------------------- /MNIST/crl/dataloader/languages.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | class 
Math_to_Language(object): 4 | def __init__(self): 5 | self.vocabulary = None 6 | self.reverse_vocabulary = None 7 | 8 | def translate(self, math_expression): 9 | translation = [] 10 | for e in math_expression: 11 | translation.append(self.vocabulary[e]) 12 | return translation 13 | 14 | class Math_to_English(Math_to_Language): 15 | def __init__(self): 16 | self.vocabulary = OrderedDict([ 17 | ("0", "zero"), 18 | ("1", "one"), 19 | ("2", "two"), 20 | ("3", "three"), 21 | ("4", "four"), 22 | ("5", "five"), 23 | ("6", "six"), 24 | ("7", "seven"), 25 | ("8", "eight"), 26 | ("9", "nine"), 27 | ("+", "plus"), 28 | ("*", "times"), 29 | ("-", "minus"), 30 | ]) 31 | self.reverse_vocabulary = {v: k for k,v in self.vocabulary.iteritems()} 32 | 33 | class Math_to_Spanish(Math_to_Language): 34 | def __init__(self): 35 | self.vocabulary = OrderedDict([ 36 | ("0", "cero"), 37 | ("1", "uno"), 38 | ("2", "dos"), 39 | ("3", "trs"), 40 | ("4", "cuatro"), 41 | ("5", "cinco"), 42 | ("6", "seis"), 43 | ("7", "siete"), 44 | ("8", "ocho"), 45 | ("9", "nueve"), 46 | ("+", "mas"), 47 | ("*", "por"), 48 | ("-", "menos"), 49 | ]) 50 | self.reverse_vocabulary = {v: k for k,v in self.vocabulary.iteritems()} 51 | 52 | class Language_to_PigLatin(object): 53 | def __init__(self): 54 | self.alphabet = [chr(letter) for letter in range(97,123)] 55 | self.vowels = ['a', 'e', 'i', 'o', 'u'] 56 | self.consonants = [letter for letter in self.alphabet if letter not in self.vowels] 57 | 58 | self.vocabulary = None 59 | 60 | def translate(self, language_expression): 61 | translation = [] 62 | for e in language_expression: 63 | if e[0] in self.vowels: 64 | translated = e + 'yay' 65 | else: 66 | # find the first vowel 67 | index_of_first_vowel = -1 68 | for i in range(len(e)): 69 | if e[i] in self.vowels: 70 | index_of_first_vowel = i 71 | break 72 | translated = e[index_of_first_vowel:] + e[:index_of_first_vowel] + 'ay' 73 | translation.append(translated) 74 | return translation 75 | 76 | class English_to_PigLatin(Language_to_PigLatin): 77 | def __init__(self): 78 | super(English_to_PigLatin, self).__init__() 79 | self.vocabulary = OrderedDict() 80 | for word in Math_to_English().vocabulary.values(): 81 | self.vocabulary[word] = self.translate([word])[0] 82 | 83 | class Spanish_to_PigSpanish(Language_to_PigLatin): 84 | def __init__(self): 85 | super(Spanish_to_PigSpanish, self).__init__() 86 | self.vocabulary = OrderedDict() 87 | for word in Math_to_Spanish().vocabulary.values(): 88 | self.vocabulary[word] = self.translate([word])[0] 89 | 90 | class Language_to_ReverseLanguage(object): 91 | def __init__(self): 92 | self.vocabulary = None 93 | 94 | def translate(self, language_expression): 95 | translation = [] 96 | for e in language_expression: 97 | translation.append(e[::-1]) 98 | return translation 99 | 100 | class English_to_ReverseEnglish(Language_to_ReverseLanguage): 101 | def __init__(self): 102 | super(English_to_ReverseEnglish, self).__init__() 103 | self.vocabulary = OrderedDict() 104 | for word in Math_to_English().vocabulary.values(): 105 | self.vocabulary[word] = self.translate([word])[0] 106 | 107 | class Spanish_to_ReverseSpanish(Language_to_ReverseLanguage): 108 | def __init__(self): 109 | super(Spanish_to_ReverseSpanish, self).__init__() 110 | self.vocabulary = OrderedDict() 111 | for word in Math_to_Spanish().vocabulary.values(): 112 | self.vocabulary[word] = self.translate([word])[0] -------------------------------------------------------------------------------- /synthetic/data.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from torch.utils.data import Dataset, DataLoader 4 | import numpy as np 5 | 6 | class ArithmeticData(Dataset): 7 | def __init__(self, num_examples, example_length): 8 | self.num_examples = num_examples 9 | self.example_length = example_length 10 | 11 | self.digits = [] 12 | self.targets = [] 13 | self.operations = [] 14 | 15 | for i in range(self.num_examples): 16 | target = 0 17 | digs = [] 18 | ops = [] 19 | tar = [] 20 | for i in range(self.example_length): 21 | if random.uniform(0, 1) > 0: 22 | digs.append(np.array([round(random.uniform(0, 1), 2)])) 23 | operation_index = random.randint(0, 2) 24 | if operation_index == 0: 25 | target += digs[-1][0] 26 | elif operation_index == 1: 27 | target -= digs[-1][0] 28 | else: 29 | target *= digs[-1][0] 30 | op = np.zeros((3)) 31 | op[operation_index] = 1 32 | ops.append(op) 33 | tar.append(np.array([target])) 34 | else: 35 | digs.append(np.array([0])) 36 | ops.append(np.array([-1,-1])) 37 | tar.append(np.array([target])) 38 | 39 | self.targets.append(tar) 40 | self.digits.append(digs) 41 | self.operations.append(ops) 42 | 43 | def __len__(self): 44 | return self.num_examples 45 | 46 | def __getitem__(self, i): 47 | dig = self.digits[i] 48 | op = self.operations[i] 49 | dig = np.stack(dig, axis = 1) 50 | op = np.stack(op, axis = 1) 51 | tar = np.stack(self.targets[i], axis = 1) 52 | inp = np.concatenate((dig, op, tar), axis = 0).astype(np.float) 53 | return inp 54 | 55 | class ArithmeticDataSpec(Dataset): 56 | def __init__(self, num_examples, example_length): 57 | self.num_examples = num_examples 58 | self.example_length = example_length 59 | 60 | self.digit_1 = [] 61 | self.digit_2 = [] 62 | self.operation = [] 63 | self.target = [] 64 | 65 | for i in range(self.num_examples): 66 | range_index = random.randint(0, 2) 67 | if range_index == 0: 68 | self.digit_1.append(np.array([random.uniform(0, 0.33)])) 69 | self.digit_2.append(np.array([random.uniform(0, 0.33)])) 70 | self.operation.append(np.array([1.,0.,0.])) 71 | self.target.append(self.digit_1[-1] + self.digit_2[-1]) 72 | elif range_index == 1: 73 | self.digit_1.append(np.array([random.uniform(0.33, 0.66)])) 74 | self.digit_2.append(np.array([random.uniform(0.33, 0.66)])) 75 | self.operation.append(np.array([0.,1.,0.])) 76 | self.target.append(self.digit_1[-1] - self.digit_2[-1]) 77 | else: 78 | self.digit_1.append(np.array([random.uniform(0.66, 1.0)])) 79 | self.digit_2.append(np.array([random.uniform(0.66, 1.0)])) 80 | self.operation.append(np.array([0.,0.,1.])) 81 | self.target.append(self.digit_1[-1] * self.digit_2[-1]) 82 | 83 | """for i in range(self.num_examples): 84 | target = 0 85 | digs = [] 86 | ops = [] 87 | tar = [] 88 | for i in range(self.example_length): 89 | if random.uniform(0, 1) > 0: 90 | range_index = random.randint(0, 2) 91 | if range_index == 0: 92 | digs.append(np.array([round(random.uniform(0, 0.33), 2)])) 93 | op = np.zeros((3)) 94 | op[0] = 1 95 | ops.append(op) 96 | target += digs[-1][0] 97 | tar.append(np.array([target])) 98 | elif range_index == 1: 99 | digs.append(np.array([round(random.uniform(0.33, 0.66), 2)])) 100 | op = np.zeros((3)) 101 | op[1] = 1 102 | ops.append(op) 103 | target -= digs[-1][0] 104 | tar.append(np.array([target])) 105 | else: 106 | digs.append(np.array([round(random.uniform(0.66, 1), 2)])) 107 | op = np.zeros((3)) 108 | op[2] = 1 109 | ops.append(op) 110 | target *= digs[-1][0] 111 | tar.append(np.array([target])) 112 
| else: 113 | digs.append(np.array([0])) 114 | ops.append(np.array([-1,-1])) 115 | tar.append(np.array([target])) 116 | self.targets.append(tar) 117 | self.digits.append(digs) 118 | self.operations.append(ops)""" 119 | 120 | def __len__(self): 121 | return self.num_examples 122 | 123 | def __getitem__(self, i): 124 | 125 | return self.digit_1[i], self.digit_2[i], self.operation[i], self.target[i] 126 | 127 | 128 | 129 | if __name__ == '__main__': 130 | data = ArithmeticData(500, 10) 131 | for i in range(len(data)): 132 | print(data[i]) 133 | print('----------------') 134 | 135 | -------------------------------------------------------------------------------- /MNIST/utilities/slot_attention_custom.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.distributions.multivariate_normal import MultivariateNormal 4 | 5 | class MLP(nn.Module): 6 | def __init__(self, in_dim, out_dim): 7 | super().__init__() 8 | self.mlp = nn.Sequential(nn.Linear(in_dim, 64), 9 | nn.ReLU(), 10 | nn.Linear(64, out_dim)) 11 | def forward(self, x): 12 | return self.mlp(x) 13 | 14 | class SlotAttention(nn.Module): 15 | def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128, out_dim = 64, in_dim = 64): 16 | super().__init__() 17 | self.num_slots = num_slots 18 | self.iters = iters 19 | self.eps = eps 20 | self.scale = dim ** -0.5 21 | self.dim = dim 22 | #self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) 23 | #self.slots_sigma = nn.Parameter(torch.randn(1, 1, dim)) 24 | self.slots_rep = nn.Linear(dim, 2*dim) 25 | self.slots_rep_pr = nn.Linear(dim, 2*dim) 26 | self.to_q = nn.Linear(dim, dim) 27 | self.to_k = nn.Linear(in_dim, dim) 28 | self.to_v = nn.Linear(in_dim, dim) 29 | self.gru = nn.GRUCell(dim, dim) 30 | hidden_dim = max(dim, hidden_dim) 31 | self.mlp = nn.Sequential( 32 | nn.Linear(dim, hidden_dim), 33 | nn.ReLU(inplace = True), 34 | nn.Linear(hidden_dim, dim) 35 | ) 36 | self.mlp_cast = MLP(dim, out_dim) 37 | self.norm_input = nn.LayerNorm(in_dim) 38 | self.norm_slots = nn.LayerNorm(dim) 39 | self.norm_pre_ff = nn.LayerNorm(dim) 40 | def reparameterize(self, mu, logvar, eps = None, deterministic=False): 41 | if not deterministic: 42 | std = logvar.mul(0.5).exp_() 43 | if eps is None: 44 | eps = std.data.new(std.size()).normal_() 45 | return eps.mul(std).add_(mu), eps 46 | else: 47 | return mu, eps 48 | 49 | def get_kl_loss(self, prior, post): 50 | prior = prior.reshape(-1, 2 * self.dim) 51 | post = post.reshape(-1, 2 * self.dim) 52 | 53 | mean_pr, std_pr = prior[:, : self.dim], prior[:, self.dim : 2 * self.dim] 54 | mean_po, std_po = post[:, : self.dim], post[:, self.dim : 2 * self.dim] 55 | 56 | std_pr = std_pr.mul(0.5).exp_() 57 | std_po = std_po.mul(0.5).exp_() 58 | 59 | q1 = MultivariateNormal(loc=mean_pr, scale_tril=torch.diag_embed(std_pr)) 60 | 61 | 62 | q2 = MultivariateNormal(loc=mean_po, scale_tril=torch.diag_embed(std_po)) 63 | 64 | kl = torch.distributions.kl.kl_divergence(q2, q1) 65 | 66 | return kl 67 | 68 | def forward(self, inputs, slots, num_slots = None): 69 | b, n, d_ = inputs.shape 70 | n_s = num_slots if num_slots is not None else self.num_slots 71 | 72 | #slots = torch.normal(mu, sigma) 73 | # hx.reshape((hx.shape[0], self.num_blocks_out, self.block_size_out))) 74 | #mu = self.slots_mu.expand(b, n_s, -1) 75 | #sigma = self.slots_sigma.expand(b, n_s, -1) 76 | d = slots.size(-1) 77 | inputs = self.norm_input(inputs) 78 | k, v = self.to_k(inputs), self.to_v(inputs) 79 | 80 | 
prior_mu_sigma = self.slots_rep_pr(slots.reshape((slots.shape[0] * n_s, self.dim))) 81 | prior = prior_mu_sigma.reshape(-1, 2 * self.dim) 82 | 83 | slots_mu_sigma = self.slots_rep(slots.reshape((slots.shape[0] * n_s, self.dim))) 84 | 85 | kl_loss = 0 86 | for t in range(self.iters): 87 | slots_mu_sigma = slots_mu_sigma.reshape((b, n_s, 2*self.dim)) 88 | 89 | mu = slots_mu_sigma[:, :, :self.dim] 90 | sigma = slots_mu_sigma[:, :, self.dim: 2*self.dim] 91 | slots, eps = self.reparameterize(mu, sigma) 92 | 93 | slots_prev = slots 94 | slots = self.norm_slots(slots) 95 | q = self.to_q(slots) 96 | dots = torch.einsum('bid,bjd->bij', q, k) * self.scale 97 | attn = dots.softmax(dim=1) + self.eps 98 | attn = attn / attn.sum(dim=-1, keepdim=True) 99 | updates = torch.einsum('bjd,bij->bid', v, attn) 100 | slots = self.gru( 101 | updates.reshape(-1, d), 102 | slots_prev.reshape(-1, d) 103 | ) 104 | slots = slots.reshape(b, -1, d) 105 | slots = slots + self.mlp(self.norm_pre_ff(slots)) 106 | slots_mu_sigma = self.slots_rep(slots.reshape((slots.shape[0] * n_s, self.dim))) 107 | slots_mu_sigma = slots_mu_sigma.reshape((b, n_s, 2*self.dim)) 108 | kl_loss += self.get_kl_loss(prior, slots_mu_sigma) 109 | 110 | #slots = self.mlp_cast(slots) 111 | return slots, kl_loss -------------------------------------------------------------------------------- /MNIST/logbook/logbook.py: -------------------------------------------------------------------------------- 1 | """Wrapper over wandb api""" 2 | 3 | #import wandb 4 | from box import Box 5 | 6 | from logbook import util as log_func 7 | from utils.util import flatten_dict, make_dir 8 | 9 | 10 | class LogBook(): 11 | """Wrapper over comet_ml api""" 12 | 13 | def __init__(self, config): 14 | self.metrics_to_record = [ 15 | "mode", 16 | "batch_idx", 17 | "epoch", 18 | "time", 19 | "correct", 20 | "loss", 21 | "kl_loss", 22 | "time_taken", 23 | "stove_pixel_loss_10" 24 | ] #+ [f"activation_frequency_{idx}" for idx in range(config.num_blocks)] 25 | self._experiment_id = 1 26 | config_dict = vars(config) 27 | 28 | flattened_config = flatten_dict(config_dict, sep="_") 29 | 30 | self.config = Box(config_dict) 31 | # logger_file_dir = "{}/{}".format(self.config.log_folder, self.config.id) 32 | if not config.should_resume: 33 | make_dir(self.config.folder_log) 34 | #make_dir(self.config.folder_log) 35 | log_func.set_logger(self.config) 36 | 37 | self.should_use_remote_logger = False 38 | 39 | if self.should_use_remote_logger: 40 | wandb.init(config=flattened_config, 41 | project="ool", 42 | name=self.config.id, 43 | dir=self.config.folder_log, 44 | entity="kappa") 45 | 46 | self.tensorboard_writer = None 47 | self.should_use_tb = False 48 | 49 | log_func.write_config_log(self.config) 50 | 51 | def log_metrics(self, dic, prefix, step): 52 | """Method to log metric""" 53 | formatted_dict = {} 54 | for key, val in dic.items(): 55 | formatted_dict[prefix + "_" + key] = val 56 | if self.should_use_remote_logger: 57 | wandb.log(formatted_dict, step=step) 58 | 59 | def write_config_log(self, config): 60 | """Write config""" 61 | log_func.write_config_log(config) 62 | flatten_config = flatten_dict(config, sep="_") 63 | flatten_config['experiment_id'] = self._experiment_id 64 | 65 | def write_metric_logs(self, metrics): 66 | """Write Metric""" 67 | metrics['experiment_id'] = self._experiment_id 68 | log_func.write_metric_logs(metrics) 69 | flattened_metrics = flatten_dict(metrics, sep="_") 70 | #key = "activation_frequency" 71 | #activation_frequency = flattened_metrics.pop(key) 72 | 
#for idx, freq in enumerate(activation_frequency): 73 | # flattened_metrics[f"{key}_{idx}"] = freq 74 | metric_dict = { 75 | key: flattened_metrics[key] 76 | for key in self.metrics_to_record if key in flattened_metrics 77 | } 78 | prefix = metric_dict.pop("mode") 79 | batch_idx = metric_dict["batch_idx"] 80 | self.log_metrics(dic=metric_dict, 81 | prefix=prefix, 82 | step=batch_idx) 83 | 84 | # if self.should_use_tb: 85 | # 86 | # timestep_key = "num_timesteps" 87 | # for key in set(list(metrics.keys())) - set([timestep_key]): 88 | # self.tensorboard_writer.add_scalar(tag=key, 89 | # scalar_value=metrics[key], 90 | # global_step=metrics[timestep_key]) 91 | 92 | def write_image(self, img, mode, step, caption): 93 | if self.should_use_remote_logger: 94 | return wandb.log({f"{mode}_{caption}": 95 | [wandb.Image(img, caption = str(step))]}, step=step) 96 | 97 | def write_compute_logs(self, **kwargs): 98 | """Write Compute Logs""" 99 | kwargs['experiment_id'] = self._experiment_id 100 | log_func.write_metric_logs(**kwargs) 101 | metric_dict = flatten_dict(kwargs, sep="_") 102 | 103 | num_timesteps = metric_dict.pop("num_timesteps") 104 | self.log_metrics(dic=metric_dict, 105 | step=num_timesteps, 106 | prefix="compute") 107 | 108 | def write_message_logs(self, message): 109 | """Write message""" 110 | log_func.write_message_logs(message, experiment_id=self._experiment_id) 111 | 112 | def write_metadata_logs(self, **kwargs): 113 | """Write metadata""" 114 | log_func.write_metadata_logs(**kwargs) 115 | # self.log_other(key="best_epoch_index", value=kwargs["best_epoch_index"]) 116 | 117 | def watch_model(self, model): 118 | """Method to track the gradients of the model""" 119 | if self.should_use_remote_logger: 120 | wandb.watch(models=model, log="all") 121 | -------------------------------------------------------------------------------- /MNIST/utilities/slot_attention_custom_2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import math 4 | from torch.distributions.multivariate_normal import MultivariateNormal 5 | 6 | class MLP(nn.Module): 7 | def __init__(self, in_dim, out_dim): 8 | super().__init__() 9 | self.mlp = nn.Sequential(nn.Linear(in_dim, 64), 10 | nn.ReLU(), 11 | nn.Linear(64, out_dim)) 12 | def forward(self, x): 13 | return self.mlp(x) 14 | 15 | class SlotAttentionV2(nn.Module): 16 | def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128, out_dim = 64, in_dim = 64): 17 | super().__init__() 18 | self.num_slots = num_slots 19 | self.iters = iters 20 | self.eps = eps 21 | self.scale = dim ** -0.5 22 | self.dim = dim 23 | #self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) 24 | #self.slots_sigma = nn.Parameter(torch.randn(1, 1, dim)) 25 | self.attn_images = [] 26 | self.cur_image = None 27 | self.colors = [0,0,255] 28 | self.slots_rep = nn.Linear(dim, 2*dim) 29 | self.slots_rep_pr = nn.Linear(dim, 2*dim) 30 | self.to_q = nn.Linear(dim, dim) 31 | self.to_k = nn.Linear(in_dim, dim) 32 | self.to_v = nn.Linear(in_dim, dim) 33 | self.gru = nn.GRUCell(dim, dim) 34 | hidden_dim = max(dim, hidden_dim) 35 | self.mlp = nn.Sequential( 36 | nn.Linear(dim, hidden_dim), 37 | nn.ReLU(inplace = True), 38 | nn.Linear(hidden_dim, dim) 39 | ) 40 | self.mlp_cast = MLP(dim, out_dim) 41 | self.norm_input = nn.LayerNorm(in_dim) 42 | self.norm_slots = nn.LayerNorm(dim) 43 | self.norm_pre_ff = nn.LayerNorm(dim) 44 | def reparameterize(self, mu, logvar, eps = None, deterministic=False): 45 | if not 
deterministic: 46 | std = logvar.mul(0.5).exp_() 47 | if eps is None: 48 | eps = std.data.new(std.size()).normal_() 49 | return eps.mul(std).add_(mu), eps 50 | else: 51 | return mu, eps 52 | 53 | def get_kl_loss(self, prior, post): 54 | prior = prior.reshape(-1, 2 * self.dim) 55 | post = post.reshape(-1, 2 * self.dim) 56 | 57 | mean_pr, std_pr = prior[:, : self.dim], prior[:, self.dim : 2 * self.dim] 58 | mean_po, std_po = post[:, : self.dim], post[:, self.dim : 2 * self.dim] 59 | 60 | std_pr = std_pr.mul(0.5).exp_() 61 | std_po = std_po.mul(0.5).exp_() 62 | 63 | q1 = MultivariateNormal(loc=mean_pr, scale_tril=torch.diag_embed(std_pr)) 64 | 65 | 66 | q2 = MultivariateNormal(loc=mean_po, scale_tril=torch.diag_embed(std_po)) 67 | 68 | kl = torch.distributions.kl.kl_divergence(q2, q1) 69 | 70 | return kl 71 | 72 | def forward(self, inputs, slots, prior_slots, num_slots = None): 73 | b, n, d_ = inputs.shape 74 | n_s = num_slots if num_slots is not None else self.num_slots 75 | d = slots.size(-1) 76 | inputs = self.norm_input(inputs) 77 | k, v = self.to_k(inputs), self.to_v(inputs) 78 | slots_mu_sigma = self.slots_rep(slots.reshape((slots.shape[0] * n_s, self.dim))) 79 | prior_mu_sigma = self.slots_rep_pr(prior_slots.reshape((slots.shape[0] * n_s, self.dim)).detach()) 80 | prior = prior_mu_sigma.reshape(-1, 2 * self.dim) 81 | slots_mu_sigma = slots_mu_sigma.reshape((b, n_s, 2*self.dim)) 82 | mu = slots_mu_sigma[:, :, :self.dim] 83 | sigma = slots_mu_sigma[:, :, self.dim: 2*self.dim] 84 | slots, eps = self.reparameterize(mu, sigma) 85 | kl_loss = 0 86 | for t in range(1): 87 | slots_prev = slots 88 | slots = self.norm_slots(slots) 89 | q = self.to_q(slots) 90 | dots = torch.einsum('bid,bjd->bij', q, k) * self.scale 91 | attn = dots.softmax(dim=1) + self.eps 92 | 93 | attn_image_ = torch.zeros(self.num_slots, 3, int(math.sqrt(attn.size(2))), int(math.sqrt(attn.size(2)))).to(attn.device) 94 | attn_image_ = attn_image_.view((self.num_slots, 3, -1)) 95 | attn_image_[:,2,:] = 255 96 | for k in range(self.num_slots): 97 | attn_image_[k, 2] = attn_image_[k, 2] * attn.detach()[0, k] 98 | attn_image_ = attn_image_.view(self.num_slots, 3, int(math.sqrt(attn.size(2))), int(math.sqrt(attn.size(2)))) 99 | self.cur_image = attn_image_ 100 | 101 | attn = attn / attn.sum(dim=-1, keepdim=True) 102 | 103 | 104 | 105 | 106 | updates = torch.einsum('bjd,bij->bid', v, attn) 107 | slots = self.gru( 108 | updates.reshape(-1, d), 109 | slots_prev.reshape(-1, d) 110 | ) 111 | slots = slots.reshape(b, -1, d) 112 | slots = slots + self.mlp(self.norm_pre_ff(slots)) 113 | slots_mu_sigma = self.slots_rep(slots.reshape((slots.shape[0] * n_s, self.dim))) 114 | slots_mu_sigma = slots_mu_sigma.reshape((b, n_s, 2*self.dim)) 115 | kl_loss += self.get_kl_loss(prior, slots_mu_sigma) 116 | return slots, kl_loss 117 | 118 | def add_image(self): 119 | self.attn_images.append(self.cur_image) 120 | 121 | def reset_images(self): 122 | self.attn_images = [] -------------------------------------------------------------------------------- /MNIST/utilities/BlockGRU.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ''' 6 | Goal1: a GRU where the weight matrices have a block structure so that information flow is constrained 7 | 8 | Data is assumed to come in [block1, block2, ..., block_n]. 
9 | 10 | Goal2: Dynamic parameter sharing between blocks (RIMs) 11 | 12 | ''' 13 | 14 | import torch 15 | import torch.nn as nn 16 | from .GroupLinearLayer import GroupLinearLayer 17 | from .sparse_attn import Sparse_attention 18 | 19 | 20 | ''' 21 | Given an N x N matrix, and a grouping of size, set all elements off the block diagonal to 0.0 22 | ''' 23 | def zero_matrix_elements(matrix, k): 24 | assert matrix.shape[0] % k == 0 25 | assert matrix.shape[1] % k == 0 26 | g1 = matrix.shape[0] // k 27 | g2 = matrix.shape[1] // k 28 | new_mat = torch.zeros_like(matrix) 29 | for b in range(0,k): 30 | new_mat[b*g1 : (b+1)*g1, b*g2 : (b+1)*g2] += matrix[b*g1 : (b+1)*g1, b*g2 : (b+1)*g2] 31 | 32 | matrix *= 0.0 33 | matrix += new_mat 34 | 35 | 36 | class BlockGRU(nn.Module): 37 | """Container module with an encoder, a recurrent module, and a decoder.""" 38 | 39 | def __init__(self, ninp, nhid, k): 40 | super(BlockGRU, self).__init__() 41 | 42 | assert ninp % k == 0 43 | assert nhid % k == 0 44 | 45 | self.k = k 46 | self.gru = nn.GRUCell(ninp, nhid) 47 | self.nhid = nhid 48 | self.ninp = ninp 49 | 50 | def blockify_params(self): 51 | pl = self.gru.parameters() 52 | 53 | for p in pl: 54 | p = p.data 55 | if p.shape == torch.Size([self.nhid*3]): 56 | pass 57 | '''biases, don't need to change anything here''' 58 | if p.shape == torch.Size([self.nhid*3, self.nhid]) or p.shape == torch.Size([self.nhid*3, self.ninp]): 59 | for e in range(0,4): 60 | zero_matrix_elements(p[self.nhid*e : self.nhid*(e+1)], k=self.k) 61 | 62 | def forward(self, input, h): 63 | 64 | #self.blockify_params() 65 | 66 | hnext = self.gru(input, h) 67 | 68 | return hnext, None 69 | 70 | class Identity(torch.autograd.Function): 71 | @staticmethod 72 | def forward(ctx, input): 73 | return input * 1.0 74 | def backward(ctx, grad_output): 75 | print(torch.sqrt(torch.sum(torch.pow(grad_output,2)))) 76 | print('-----------') 77 | return grad_output * 1.0 78 | 79 | class SharedBlockGRU(nn.Module): 80 | """Dynamic sharing of parameters between blocks(RIM's)""" 81 | def __init__(self, ninp, nhid, k, n_templates): 82 | super(SharedBlockGRU, self).__init__() 83 | assert ninp % k == 0 84 | assert nhid % k == 0 85 | self.k = k 86 | self.m = nhid // self.k 87 | self.n_templates = n_templates 88 | print("input to template is ", ninp%k) 89 | self.templates = nn.ModuleList([nn.GRUCell(ninp//k,self.m) for _ in range(0,self.n_templates)]) 90 | self.nhid = nhid 91 | self.ninp = ninp 92 | self.gll_write = GroupLinearLayer(self.m,16, self.n_templates) 93 | self.gll_read = GroupLinearLayer(self.m,16,1) 94 | self.sa = Sparse_attention(1) 95 | print("Using Gumble sparsity") 96 | def blockify_params(self): 97 | return 98 | def forward(self, input, h): 99 | #self.blockify_params() 100 | bs = h.shape[0] 101 | h = h.reshape((h.shape[0], self.k, self.m)).reshape((h.shape[0]*self.k, self.m)) 102 | input = input.reshape(input.shape[0], 1, input.shape[1]) 103 | #input = input.repeat(1,self.k,1) 104 | #input = input.reshape(input.shape[0]*self.k, input.shape[2]) 105 | input = input.reshape(input.shape[0]*self.k, -1) 106 | #print("input shape is", input.shape) 107 | h_read = self.gll_read((h*1.0).reshape((h.shape[0], 1, h.shape[1]))) 108 | hnext_stack = [] 109 | for template in self.templates: 110 | hnext_l = template(input, h) 111 | hnext_l = hnext_l.reshape((hnext_l.shape[0], 1, hnext_l.shape[1])) 112 | hnext_stack.append(hnext_l) 113 | hnext = torch.cat(hnext_stack, 1) 114 | write_key = self.gll_write(hnext) 115 | ''' 116 | sm = nn.Softmax(2) 117 | att = 
sm(torch.bmm(h_read, write_key.permute(0, 2, 1))).squeeze(1) 118 | att = self.sa(att).unsqueeze(1) 119 | ''' 120 | att = torch.nn.functional.gumbel_softmax(torch.bmm(h_read, write_key.permute(0, 2, 1)), tau=1, hard=True) 121 | #att = att*0.0 + 0.25 122 | hnext = torch.bmm(att, hnext) 123 | hnext = hnext.mean(dim=1) 124 | hnext = hnext.reshape((bs, self.k, self.m)).reshape((bs, self.k*self.m)) 125 | #print('shapes', hnext.shape, cnext.shape) 126 | return hnext, att.data.reshape(bs,self.k,self.n_templates) 127 | 128 | 129 | 130 | if __name__ == "__main__": 131 | 132 | Blocks = BlockGRU(2, 6, k=2) 133 | opt = torch.optim.Adam(Blocks.parameters()) 134 | 135 | pl = Blocks.gru.parameters() 136 | 137 | inp = torch.randn(100,2) 138 | h = torch.randn(100,6) 139 | 140 | h2 = Blocks(inp,h) 141 | 142 | L = h2.sum()**2 143 | 144 | #L.backward() 145 | #opt.step() 146 | #opt.zero_grad() 147 | 148 | 149 | pl = Blocks.gru.parameters() 150 | for p in pl: 151 | print(p.shape) 152 | #print(torch.Size([Blocks.nhid*4])) 153 | if p.shape == torch.Size([Blocks.nhid*3]): 154 | print(p.shape, 'a') 155 | #print(p) 156 | '''biases, don't need to change anything here''' 157 | if p.shape == torch.Size([Blocks.nhid*3, Blocks.nhid]) or p.shape == torch.Size([Blocks.nhid*3, Blocks.ninp]): 158 | print(p.shape, 'b') 159 | for e in range(0,4): 160 | print(p[Blocks.nhid*e : Blocks.nhid*(e+1)]) 161 | -------------------------------------------------------------------------------- /MNIST/bootstrap.py: -------------------------------------------------------------------------------- 1 | """Main entry point of the code""" 2 | from __future__ import print_function 3 | 4 | from os import listdir 5 | 6 | import torch 7 | import random 8 | import numpy as np 9 | import os 10 | from EncoderDecoder import MNISTModel 11 | from argument_parser import argument_parser 12 | from crl.dataset import get_dataloaders as crl_dataloaders 13 | from logbook.logbook import LogBook 14 | from utils.util import make_dir 15 | 16 | 17 | def repackage_hidden(ten_): 18 | """Wraps hidden states in new Tensors, to detach them from their history.""" 19 | if isinstance(ten_, torch.Tensor): 20 | return ten_.detach() 21 | else: 22 | return tuple(repackage_hidden(v) for v in ten_) 23 | 24 | 25 | def main(train_loop, test_loop): 26 | """Function to run the experiment""" 27 | args = argument_parser() 28 | logbook = LogBook(config=args) 29 | 30 | 31 | 32 | if not args.should_resume: 33 | # New Experiment 34 | make_dir(f"{args.folder_log}/model") 35 | logbook.write_message_logs(message=f"Saving args to {args.folder_log}/model/args") 36 | torch.save({"args": vars(args)}, f"{args.folder_log}/model/args") 37 | 38 | use_cuda = torch.cuda.is_available() 39 | args.device = torch.device("cuda" if use_cuda else "cpu") 40 | model = setup_model(args=args, logbook=logbook) 41 | if args.num_rules > 0: 42 | model.rule_network.share_key_value = args.share_key_value 43 | 44 | # logbook.watch_model(model) 45 | 46 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 47 | 48 | args.directory = '.' 
49 | 50 | train_loader, test_loader, transfer_loader = crl_dataloaders(num_transforms = args.num_transforms, transform_length = args.transform_length, batch_size = args.batch_size, color = args.color) 51 | 52 | train_batch_idx = 0 53 | 54 | start_epoch = 1 55 | if args.should_resume: 56 | start_epoch = args.checkpoint["epoch"] + 1 57 | logbook.write_message_logs(message=f"Resuming experiment id: {args.id}, from epoch: {start_epoch}") 58 | 59 | for epoch in range(start_epoch, args.epochs + 1): 60 | test_loop(model=model, 61 | test_loader=test_loader, 62 | epoch=epoch, 63 | transfer_loader=transfer_loader, 64 | logbook=logbook, 65 | train_batch_idx=train_batch_idx, 66 | args=args) 67 | train_batch_idx = train_loop(model=model, 68 | train_loader=train_loader, 69 | optimizer=optimizer, 70 | epoch=epoch, 71 | logbook=logbook, 72 | train_batch_idx=train_batch_idx, 73 | args=args) 74 | print("Epoch number", epoch) 75 | 76 | if args.model_persist_frequency > 0 and epoch % args.model_persist_frequency == 0: 77 | logbook.write_message_logs(message=f"Saving model to {args.folder_log}/model/{epoch}") 78 | torch.save(model.state_dict(), f"{args.folder_log}/model/{epoch}") 79 | 80 | 81 | def setup_model(args, logbook): 82 | """Method to setup the model""" 83 | print('Setting seed to ' + str(args.seed)) 84 | 85 | #seed = args.seed 86 | #random.seed(seed) 87 | #np.random.seed(seed) 88 | #torch.manual_seed(seed) 89 | #if torch.cuda.is_available(): 90 | # torch.cuda.manual_seed_all(seed) 91 | #os.environ['PYTHONHASHSEED'] = str(seed) 92 | model = MNISTModel(args) 93 | 94 | 95 | if args.should_resume: 96 | # Find the last checkpointed model and resume from that 97 | model_dir = f"{args.folder_log}/model" 98 | latest_model_idx = max( 99 | [int(model_idx) for model_idx in listdir(model_dir) 100 | if model_idx != "args"] 101 | ) 102 | args.path_to_load_model = f"{model_dir}/{latest_model_idx}" 103 | args.checkpoint = {"epoch": latest_model_idx} 104 | 105 | if args.path_to_load_model != "": 106 | 107 | shape_offset = {} 108 | for path_to_load_model in args.path_to_load_model.split(","): 109 | logbook.write_message_logs(message=f"Loading model from {path_to_load_model}") 110 | _, shape_offset = model.load_state_dict(torch.load(path_to_load_model.strip()), 111 | shape_offset) 112 | 113 | if not args.should_resume: 114 | components_to_load = set(args.components_to_load.split("_")) 115 | total_components = set(["encoders", "decoders", "rules", "blocks"]) 116 | components_to_reinit = [component for component in total_components 117 | if component not in components_to_load] 118 | for component in components_to_reinit: 119 | if component == "blocks": 120 | logbook.write_message_logs(message="Reinit Blocks") 121 | model.rnn_.gru.myrnn.init_blocks() 122 | elif component == "rules": 123 | logbook.write_message_logs(message="Reinit Rules") 124 | model.rnn_.gru.myrnn.init_rules() 125 | elif component == "encoders": 126 | logbook.write_message_logs(message="Reinit Encoders") 127 | model.init_encoders() 128 | elif component == "decoders": 129 | logbook.write_message_logs(message="Reinit Decoders") 130 | model.init_decoders() 131 | 132 | 133 | else: 134 | model = MNISTModel(args) 135 | 136 | model = model.to(args.device) 137 | 138 | return model 139 | 140 | -------------------------------------------------------------------------------- /MNIST/utilities/BlockLSTM.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | Goal1: an LSTM where the weight matrices have a block structure 
so that information flow is constrained 4 | 5 | Data is assumed to come in [block1, block2, ..., block_n]. 6 | 7 | Goal2: Dynamic parameter sharing between blocks (RIMs) 8 | 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | from .GroupLinearLayer import GroupLinearLayer 14 | ''' 15 | Given an N x N matrix, and a grouping of size, set all elements off the block diagonal to 0.0 16 | ''' 17 | def zero_matrix_elements(matrix, k): 18 | assert matrix.shape[0] % k == 0 19 | assert matrix.shape[1] % k == 0 20 | g1 = matrix.shape[0] // k 21 | g2 = matrix.shape[1] // k 22 | new_mat = torch.zeros_like(matrix) 23 | for b in range(0,k): 24 | new_mat[b*g1 : (b+1)*g1, b*g2 : (b+1)*g2] += matrix[b*g1 : (b+1)*g1, b*g2 : (b+1)*g2] 25 | 26 | matrix *= 0.0 27 | matrix += new_mat 28 | 29 | 30 | class BlockLSTM(nn.Module): 31 | """Container module with an encoder, a recurrent module, and a decoder.""" 32 | 33 | def __init__(self, ninp, nhid, k): 34 | super(BlockLSTM, self).__init__() 35 | 36 | assert ninp % k == 0 37 | assert nhid % k == 0 38 | 39 | self.k = k 40 | self.lstm = nn.LSTMCell(ninp, nhid) 41 | self.nhid = nhid 42 | self.ninp = ninp 43 | 44 | def blockify_params(self): 45 | pl = self.lstm.parameters() 46 | 47 | for p in pl: 48 | p = p.data 49 | if p.shape == torch.Size([self.nhid*4]): 50 | pass 51 | '''biases, don't need to change anything here''' 52 | if p.shape == torch.Size([self.nhid*4, self.nhid]) or p.shape == torch.Size([self.nhid*4, self.ninp]): 53 | for e in range(0,4): 54 | zero_matrix_elements(p[self.nhid*e : self.nhid*(e+1)], k=self.k) 55 | 56 | def forward(self, input, h, c): 57 | 58 | #self.blockify_params() 59 | 60 | hnext, cnext = self.lstm(input, (h, c)) 61 | 62 | return hnext, cnext, None 63 | 64 | class SharedBlockLSTM(nn.Module): 65 | """Dynamic sharing of parameters between blocks(RIM's)""" 66 | 67 | def __init__(self, ninp, nhid, k , n_templates): 68 | super(SharedBlockLSTM, self).__init__() 69 | 70 | assert ninp % k == 0 71 | assert nhid % k == 0 72 | 73 | self.k = k 74 | self.m = nhid // self.k 75 | self.n_templates = n_templates 76 | self.templates = nn.ModuleList([nn.LSTMCell(ninp,self.m) for _ in range(0,self.n_templates)]) 77 | self.nhid = nhid 78 | 79 | self.ninp = ninp 80 | 81 | self.gll_write = GroupLinearLayer(self.m,16, self.n_templates) 82 | self.gll_read = GroupLinearLayer(self.m,16,1) 83 | 84 | def blockify_params(self): 85 | 86 | return 87 | 88 | def forward(self, input, h, c): 89 | 90 | #self.blockify_params() 91 | bs = h.shape[0] 92 | h = h.reshape((h.shape[0], self.k, self.m)).reshape((h.shape[0]*self.k, self.m)) 93 | c = c.reshape((c.shape[0], self.k, self.m)).reshape((c.shape[0]*self.k, self.m)) 94 | 95 | 96 | input = input.reshape(input.shape[0], 1, input.shape[1]) 97 | input = input.repeat(1,self.k,1) 98 | input = input.reshape(input.shape[0]*self.k, input.shape[2]) 99 | 100 | h_read = self.gll_read((h*1.0).reshape((h.shape[0], 1, h.shape[1]))) 101 | 102 | 103 | hnext_stack = [] 104 | cnext_stack = [] 105 | 106 | 107 | for template in self.templates:#[self.lstm1, self.lstm2, self.lstm3, self.lstm4]: 108 | hnext_l, cnext_l = template(input, (h, c)) 109 | 110 | hnext_l = hnext_l.reshape((hnext_l.shape[0], 1, hnext_l.shape[1])) 111 | cnext_l = cnext_l.reshape((cnext_l.shape[0], 1, cnext_l.shape[1])) 112 | 113 | hnext_stack.append(hnext_l) 114 | cnext_stack.append(cnext_l) 115 | 116 | hnext = torch.cat(hnext_stack, 1) 117 | cnext = torch.cat(cnext_stack, 1) 118 | 119 | 120 | write_key = self.gll_write(hnext) 121 | 122 | sm = nn.Softmax(2) 123 | att 
= sm(torch.bmm(h_read, write_key.permute(0, 2, 1))) 124 | 125 | #att = att*0.0 + 0.25 126 | 127 | #print('hnext shape before att', hnext.shape) 128 | hnext = torch.bmm(att, hnext) 129 | cnext = torch.bmm(att, cnext) 130 | 131 | hnext = hnext.mean(dim=1) 132 | cnext = cnext.mean(dim=1) 133 | 134 | hnext = hnext.reshape((bs, self.k, self.m)).reshape((bs, self.k*self.m)) 135 | cnext = cnext.reshape((bs, self.k, self.m)).reshape((bs, self.k*self.m)) 136 | 137 | #print('shapes', hnext.shape, cnext.shape) 138 | 139 | return hnext, cnext, att.data.reshape(bs,self.k,self.n_templates) 140 | 141 | 142 | 143 | if __name__ == "__main__": 144 | 145 | Blocks = BlockLSTM(2, 6, k=2) 146 | opt = torch.optim.Adam(Blocks.parameters()) 147 | 148 | pl = Blocks.lstm.parameters() 149 | 150 | inp = torch.randn(10,100,2) 151 | h = torch.randn(1,100,3*2) 152 | c = torch.randn(1,100,3*2) 153 | 154 | h2, c2 = Blocks(inp,h,c) 155 | 156 | L = h2.sum()**2 157 | 158 | L.backward() 159 | opt.step() 160 | opt.zero_grad() 161 | 162 | 163 | pl = Blocks.lstm.parameters() 164 | for p in pl: 165 | #print(p.shape) 166 | #print(torch.Size([Blocks.nhid*4])) 167 | if p.shape == torch.Size([Blocks.nhid*4]): 168 | print(p.shape, 'a') 169 | #print(p) 170 | '''biases, don't need to change anything here''' 171 | if p.shape == torch.Size([Blocks.nhid*4, Blocks.nhid]) or p.shape == torch.Size([Blocks.nhid*4, Blocks.ninp]): 172 | print(p.shape, 'b') 173 | for e in range(0,4): 174 | print(p[Blocks.nhid*e : Blocks.nhid*(e+1)]) 175 | -------------------------------------------------------------------------------- /synthetic/utilities/BlockLSTM.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | Goal1: an LSTM where the weight matrices have a block structure so that information flow is constrained 4 | 5 | Data is assumed to come in [block1, block2, ..., block_n]. 
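Concretely, blockify_params() zeroes every entry of the LSTM weight matrices that lies
off the block diagonal (via zero_matrix_elements), so each of the k blocks only mixes
information within its own slice of the input and hidden state.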
6 | 7 | Goal2: Dynamic parameter sharing between blocks (RIMs) 8 | 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | from .GroupLinearLayer import GroupLinearLayer 14 | ''' 15 | Given an N x N matrix, and a grouping of size, set all elements off the block diagonal to 0.0 16 | ''' 17 | def zero_matrix_elements(matrix, k): 18 | assert matrix.shape[0] % k == 0 19 | assert matrix.shape[1] % k == 0 20 | g1 = matrix.shape[0] // k 21 | g2 = matrix.shape[1] // k 22 | new_mat = torch.zeros_like(matrix) 23 | for b in range(0,k): 24 | new_mat[b*g1 : (b+1)*g1, b*g2 : (b+1)*g2] += matrix[b*g1 : (b+1)*g1, b*g2 : (b+1)*g2] 25 | 26 | matrix *= 0.0 27 | matrix += new_mat 28 | 29 | 30 | class BlockLSTM(nn.Module): 31 | """Container module with an encoder, a recurrent module, and a decoder.""" 32 | 33 | def __init__(self, ninp, nhid, k): 34 | super(BlockLSTM, self).__init__() 35 | 36 | assert ninp % k == 0 37 | assert nhid % k == 0 38 | 39 | self.k = k 40 | self.lstm = nn.LSTMCell(ninp, nhid) 41 | self.nhid = nhid 42 | self.ninp = ninp 43 | 44 | def blockify_params(self): 45 | pl = self.lstm.parameters() 46 | 47 | for p in pl: 48 | p = p.data 49 | if p.shape == torch.Size([self.nhid*4]): 50 | pass 51 | '''biases, don't need to change anything here''' 52 | if p.shape == torch.Size([self.nhid*4, self.nhid]) or p.shape == torch.Size([self.nhid*4, self.ninp]): 53 | for e in range(0,4): 54 | zero_matrix_elements(p[self.nhid*e : self.nhid*(e+1)], k=self.k) 55 | 56 | def forward(self, input, h, c): 57 | 58 | #self.blockify_params() 59 | 60 | hnext, cnext = self.lstm(input, (h, c)) 61 | 62 | return hnext, cnext, None 63 | 64 | class SharedBlockLSTM(nn.Module): 65 | """Dynamic sharing of parameters between blocks(RIM's)""" 66 | 67 | def __init__(self, ninp, nhid, k , n_templates): 68 | super(SharedBlockLSTM, self).__init__() 69 | 70 | assert ninp % k == 0 71 | assert nhid % k == 0 72 | 73 | self.k = k 74 | self.m = nhid // self.k 75 | self.n_templates = n_templates 76 | self.templates = nn.ModuleList([nn.LSTMCell(ninp,self.m) for _ in range(0,self.n_templates)]) 77 | self.nhid = nhid 78 | 79 | self.ninp = ninp 80 | 81 | self.gll_write = GroupLinearLayer(self.m,16, self.n_templates) 82 | self.gll_read = GroupLinearLayer(self.m,16,1) 83 | 84 | def blockify_params(self): 85 | 86 | return 87 | 88 | def forward(self, input, h, c): 89 | 90 | #self.blockify_params() 91 | bs = h.shape[0] 92 | h = h.reshape((h.shape[0], self.k, self.m)).reshape((h.shape[0]*self.k, self.m)) 93 | c = c.reshape((c.shape[0], self.k, self.m)).reshape((c.shape[0]*self.k, self.m)) 94 | 95 | 96 | input = input.reshape(input.shape[0], 1, input.shape[1]) 97 | input = input.repeat(1,self.k,1) 98 | input = input.reshape(input.shape[0]*self.k, input.shape[2]) 99 | 100 | h_read = self.gll_read((h*1.0).reshape((h.shape[0], 1, h.shape[1]))) 101 | 102 | 103 | hnext_stack = [] 104 | cnext_stack = [] 105 | 106 | 107 | for template in self.templates:#[self.lstm1, self.lstm2, self.lstm3, self.lstm4]: 108 | hnext_l, cnext_l = template(input, (h, c)) 109 | 110 | hnext_l = hnext_l.reshape((hnext_l.shape[0], 1, hnext_l.shape[1])) 111 | cnext_l = cnext_l.reshape((cnext_l.shape[0], 1, cnext_l.shape[1])) 112 | 113 | hnext_stack.append(hnext_l) 114 | cnext_stack.append(cnext_l) 115 | 116 | hnext = torch.cat(hnext_stack, 1) 117 | cnext = torch.cat(cnext_stack, 1) 118 | 119 | 120 | write_key = self.gll_write(hnext) 121 | 122 | sm = nn.Softmax(2) 123 | att = sm(torch.bmm(h_read, write_key.permute(0, 2, 1))) 124 | 125 | #att = att*0.0 + 0.25 126 | 127 | 
#print('hnext shape before att', hnext.shape) 128 | hnext = torch.bmm(att, hnext) 129 | cnext = torch.bmm(att, cnext) 130 | 131 | hnext = hnext.mean(dim=1) 132 | cnext = cnext.mean(dim=1) 133 | 134 | hnext = hnext.reshape((bs, self.k, self.m)).reshape((bs, self.k*self.m)) 135 | cnext = cnext.reshape((bs, self.k, self.m)).reshape((bs, self.k*self.m)) 136 | 137 | #print('shapes', hnext.shape, cnext.shape) 138 | 139 | return hnext, cnext, att.data.reshape(bs,self.k,self.n_templates) 140 | 141 | 142 | 143 | if __name__ == "__main__": 144 | 145 | Blocks = BlockLSTM(2, 6, k=2) 146 | opt = torch.optim.Adam(Blocks.parameters()) 147 | 148 | pl = Blocks.lstm.parameters() 149 | 150 | inp = torch.randn(10,100,2) 151 | h = torch.randn(1,100,3*2) 152 | c = torch.randn(1,100,3*2) 153 | 154 | h2, c2 = Blocks(inp,h,c) 155 | 156 | L = h2.sum()**2 157 | 158 | L.backward() 159 | opt.step() 160 | opt.zero_grad() 161 | 162 | 163 | pl = Blocks.lstm.parameters() 164 | for p in pl: 165 | #print(p.shape) 166 | #print(torch.Size([Blocks.nhid*4])) 167 | if p.shape == torch.Size([Blocks.nhid*4]): 168 | print(p.shape, 'a') 169 | #print(p) 170 | '''biases, don't need to change anything here''' 171 | if p.shape == torch.Size([Blocks.nhid*4, Blocks.nhid]) or p.shape == torch.Size([Blocks.nhid*4, Blocks.ninp]): 172 | print(p.shape, 'b') 173 | for e in range(0,4): 174 | print(p[Blocks.nhid*e : Blocks.nhid*(e+1)]) 175 | -------------------------------------------------------------------------------- /synthetic/utilities/slot_attention_custom_2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import math 4 | 5 | from torch.distributions.multivariate_normal import MultivariateNormal 6 | 7 | class MLP(nn.Module): 8 | def __init__(self, in_dim, out_dim): 9 | super().__init__() 10 | self.mlp = nn.Sequential(nn.Linear(in_dim, 64), 11 | nn.ReLU(), 12 | nn.Linear(64, out_dim)) 13 | def forward(self, x): 14 | return self.mlp(x) 15 | 16 | class SlotAttentionV2(nn.Module): 17 | def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128, out_dim = 64, in_dim = 64): 18 | super().__init__() 19 | self.num_slots = num_slots 20 | self.iters = iters 21 | self.eps = eps 22 | self.scale = dim ** -0.5 23 | self.dim = dim 24 | 25 | #self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) 26 | #self.slots_sigma = nn.Parameter(torch.randn(1, 1, dim)) 27 | #self.slots_rep = nn.Linear(dim, 2*dim) 28 | #self.slots_rep_pr = nn.Linear(dim, 2*dim) 29 | print("using iterations", self.iters) 30 | 31 | 32 | #self.slots_rep = nn.Linear(dim, 2*dim) 33 | 34 | self.slots_rep = nn.Sequential( 35 | nn.Linear(dim, hidden_dim), 36 | nn.ReLU(inplace = True), 37 | nn.Linear(hidden_dim, 2*dim) 38 | ) 39 | self.attn_images = [] 40 | self.cur_image = None 41 | self.colors = [0,0,255] 42 | #self.slots_rep_pr = nn.Linear(dim, 2*dim) 43 | self.slots_rep_pr = nn.Sequential( 44 | nn.Linear(dim, hidden_dim), 45 | nn.ReLU(inplace = True), 46 | nn.Linear(hidden_dim, 2*dim) 47 | ) 48 | 49 | self.to_q = nn.Linear(dim, dim) 50 | self.to_k = nn.Linear(in_dim, dim) 51 | self.to_v = nn.Linear(in_dim, dim) 52 | self.gru = nn.GRUCell(dim, dim) 53 | hidden_dim = max(dim, hidden_dim) 54 | self.mlp = nn.Sequential( 55 | nn.Linear(dim, hidden_dim), 56 | nn.ReLU(inplace = True), 57 | nn.Linear(hidden_dim, dim) 58 | ) 59 | self.mlp_cast = MLP(dim, out_dim) 60 | self.norm_input = nn.LayerNorm(in_dim) 61 | self.norm_slots = nn.LayerNorm(dim) 62 | self.norm_pre_ff = nn.LayerNorm(dim) 63 | 64 | 65 | 
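    # reparameterize() applies the VAE reparameterization trick,
    # z = mu + eps * exp(0.5 * logvar), so sampling stays differentiable
    # w.r.t. mu and logvar; get_kl_loss() models prior and posterior as
    # diagonal Gaussians and returns KL(posterior || prior) via
    # torch.distributions.kl.kl_divergence.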
def reparameterize(self, mu, logvar, eps = None, deterministic=False): 66 | if not deterministic: 67 | std = logvar.mul(0.5).exp_() 68 | if eps is None: 69 | eps = std.data.new(std.size()).normal_() 70 | return eps.mul(std).add_(mu), eps 71 | else: 72 | return mu, eps 73 | 74 | def get_kl_loss(self, prior, post): 75 | prior = prior.reshape(-1, 2 * self.dim) 76 | post = post.reshape(-1, 2 * self.dim) 77 | mean_pr, std_pr = prior[:, : self.dim], prior[:, self.dim : 2 * self.dim] 78 | mean_po, std_po = post[:, : self.dim], post[:, self.dim : 2 * self.dim] 79 | log_var_po = std_po 80 | 81 | 82 | std_pr = std_pr.mul(0.5).exp_() 83 | std_po = std_po.mul(0.5).exp_() 84 | q1 = MultivariateNormal(loc=mean_pr, scale_tril=torch.diag_embed(std_pr)) 85 | q2 = MultivariateNormal(loc=mean_po, scale_tril=torch.diag_embed(std_po)) 86 | kl = torch.distributions.kl.kl_divergence(q2, q1) 87 | 88 | 89 | #mu, log_var = mean_po, log_var_po 90 | #kl = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 91 | 92 | return kl 93 | 94 | def add_image(self): 95 | self.attn_images.append(self.cur_image) 96 | 97 | def reset_images(self): 98 | self.attn_images = [] 99 | 100 | 101 | def forward(self, inputs, slots, prior_slots, num_slots = None): 102 | b, n, d_ = inputs.shape 103 | n_s = num_slots if num_slots is not None else self.num_slots 104 | d = slots.size(-1) 105 | inputs = self.norm_input(inputs) 106 | k, v = self.to_k(inputs), self.to_v(inputs) 107 | slots_mu_sigma = self.slots_rep(slots.reshape((slots.shape[0] * n_s, self.dim))) 108 | prior_mu_sigma = self.slots_rep_pr(prior_slots.reshape((slots.shape[0] * n_s, self.dim)).detach()) 109 | prior = prior_mu_sigma.reshape(-1, 2 * self.dim) 110 | slots_mu_sigma = slots_mu_sigma.reshape((b, n_s, 2*self.dim)) 111 | mu = slots_mu_sigma[:, :, :self.dim] 112 | sigma = slots_mu_sigma[:, :, self.dim: 2*self.dim] 113 | slots, eps = self.reparameterize(mu, sigma) 114 | kl_loss = 0 115 | for t in range(self.iters): 116 | slots_prev = slots 117 | #slots = self.norm_slots(slots) 118 | q = self.to_q(slots) 119 | dots = torch.einsum('bid,bjd->bij', q, k) * self.scale 120 | attn = dots.softmax(dim=1) + self.eps 121 | self.attn = attn 122 | attn_value = attn.clone().detach() 123 | attn_image_ = torch.zeros(self.num_slots, 3, int(math.sqrt(attn_value.size(2))), int(math.sqrt(attn_value.size(2)))).to(attn_value.device) 124 | attn_image_ = attn_image_.view((self.num_slots, 3, -1)) 125 | attn_image_[:,2,:] = 255 126 | for num_ in range(self.num_slots): 127 | attn_image_[num_, 2] = attn_image_[num_, 2] * attn_value[0, num_] 128 | attn_image_ = attn_image_.view(self.num_slots, 3, int(math.sqrt(attn_value.size(2))), int(math.sqrt(attn_value.size(2)))) 129 | self.cur_image = attn_image_ 130 | 131 | attn = attn / attn.sum(dim=-1, keepdim=True) 132 | updates = torch.einsum('bjd,bij->bid', v, attn) 133 | slots = self.gru( 134 | updates.reshape(-1, d), 135 | slots_prev.reshape(-1, d) 136 | ) 137 | slots = slots.reshape(b, -1, d) 138 | slots = slots + self.mlp(self.norm_pre_ff(slots)) 139 | slots_mu_sigma = self.slots_rep(slots.reshape((slots.shape[0] * n_s, self.dim))) 140 | slots_mu_sigma = slots_mu_sigma.reshape((b, n_s, 2*self.dim)) 141 | kl_loss += self.get_kl_loss(prior, slots_mu_sigma) 142 | return slots, kl_loss 143 | -------------------------------------------------------------------------------- /MNIST/crl/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, 
TensorDataset 2 | 3 | from .dataloader.mnist_dataset import load_mnist_datasets 4 | from .dataloader.mnist_transforms import shrink_mnist_dataset, cuda_mnist_dataset, \ 5 | TransformationCombinationDataLoader, BasicTransformDataset, TrainTransformDataset, ColorTransformDataset 6 | from .dataloader.transformation_combination import ConcatTransformationCombiner, TransformationCombiner, \ 7 | SpatialImageTransformations 8 | from .dataloader.image_transforms import * 9 | import cv2 10 | 11 | 12 | def show_trajectory(traj, transforms): 13 | images = [] 14 | transforms = transforms + ['nothing'] 15 | traj = torch.split(traj, 1, dim = 0) 16 | for i, t in enumerate(traj): 17 | t_ = convert_image_np(t.squeeze(0)) 18 | print(t_.shape) 19 | write = np.zeros((30, t_.shape[1],t_.shape[2])) 20 | 21 | write = cv2.putText(write, transforms[i], (10, 15), cv2.FONT_HERSHEY_SIMPLEX , 22 | 0.3, (255,255,255)) 23 | t_ = np.concatenate((t_, write), axis = 0) 24 | zeros = np.zeros((t_.shape[0], 50, t_.shape[2])) 25 | images.append(t_) 26 | images.append(zeros) 27 | 28 | images = np.concatenate(images[:-1], axis = 1) 29 | cv2.imshow('img', images) 30 | cv2.waitKey(0) 31 | cv2.destroyAllWindows() 32 | 33 | def batchify(batch): 34 | images = [] 35 | transforms = [] 36 | transform_vector = [] 37 | for b in batch: 38 | images.append(b[0]) 39 | transforms.append(b[1]) 40 | transform_vector.append(b[2]) 41 | return torch.stack(images, dim = 0), transforms, torch.stack(transform_vector, dim = 0) 42 | 43 | def get_dataloaders(num_transforms = 4, transform_length = 3, batch_size = 50, color = True, shuffle = True): 44 | mnist_orig = load_mnist_datasets('../data', normalize=False) 45 | train_data = mnist_orig['train'] 46 | val_data = mnist_orig['val'] 47 | test_data = mnist_orig['test'] 48 | if not color: 49 | train_dataset = BasicTransformDataset(train_data[0], train_data[1], num_transforms = num_transforms, transform_length = transform_length, gen = False, state = "Train") 50 | 51 | val_dataset = BasicTransformDataset(val_data[0], val_data[1], num_transforms = num_transforms, transform_length = transform_length, gen = False, state = "Val") 52 | 53 | test_dataset = BasicTransformDataset(test_data[0], test_data[1], num_transforms = num_transforms, transform_length = transform_length, gen = False, state = "Val") 54 | else: 55 | 56 | train_dataset = BasicTransformDataset(train_data[0], train_data[1], num_transforms = num_transforms, transform_length = transform_length, gen = True, state = "Train") 57 | 58 | val_dataset = BasicTransformDataset(val_data[0], val_data[1], num_transforms = num_transforms, transform_length = transform_length, gen = True, state = "Val") 59 | 60 | test_dataset = BasicTransformDataset(test_data[0], test_data[1], num_transforms = num_transforms, transform_length = transform_length, gen = True, state = "Val") 61 | 62 | 63 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = shuffle, num_workers = 4, collate_fn = batchify) 64 | val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, shuffle = shuffle, num_workers = 4, collate_fn = batchify) 65 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size = batch_size, shuffle = shuffle, num_workers = 4, collate_fn = batchify) 66 | 67 | 68 | 69 | """mnist_shrunk = shrink_mnist_dataset(mnist_orig, (64, 64)) 70 | if use_cuda: mnist_shrunk = cuda_mnist_dataset(mnist_shrunk) 71 | 72 | kwargs = {} 73 | 74 | transform_config = lambda: SpatialImageTransformations(cuda=use_cuda, 
**kwargs) 75 | 76 | if mix_in_normal: 77 | train_combiner = ConcatTransformationCombiner(transformation_combiners=[ 78 | TransformationCombiner(transform_config(), name='3c2_RT', mode='train', cuda=use_cuda), 79 | TransformationCombiner(transform_config(), name='identity', mode='train', cuda=use_cuda)]) 80 | else: 81 | train_combiner = TransformationCombiner(transform_config(), name='3c2_RT', mode='train', cuda=use_cuda) 82 | 83 | transformation_combinations = { 84 | 'train': train_combiner, 85 | 'val': TransformationCombiner(transform_config(), name='3c2_RT', mode='val', cuda=use_cuda), 86 | 'test': TransformationCombiner(transform_config(), name='3c3_SRT', mode='test', cuda=use_cuda), 87 | } 88 | 89 | dataloader = TransformationCombinationDataLoader( 90 | dataset=mnist_shrunk, 91 | transformation_combinations=transformation_combinations, 92 | transform_config=transform_config(), 93 | cuda=use_cuda) # although we can imagine not doing this""" 94 | return train_dataloader, val_dataloader, test_dataloader 95 | 96 | 97 | 98 | 99 | """def get_dataloaders(args, use_cuda, should_shuffle=True, mix_in_normal=False): 100 | Method to return the dataloaders 101 | data = load_image_xforms_env(use_cuda=use_cuda, 102 | mix_in_normal=mix_in_normal) 103 | 104 | modes = ["train", "val", "test"] 105 | shuffle_list = [should_shuffle, False, False] 106 | 107 | def _collate_fn(batch): 108 | return batch 109 | 110 | def _get_dataloader(dataset, shuffle): 111 | print('dataset.py:line 51' + str(dataset[0].size())) 112 | return DataLoader(dataset=TensorDataset(dataset[0], 113 | dataset[1]), 114 | batch_size=args.batch_size, 115 | shuffle=shuffle, 116 | num_workers=0,) 117 | # collate_fn=_collate_fn) 118 | print('dataset.py:line 58:'+ str(data.dataloaders['test'])) 119 | import torch 120 | return [torch.utils.data.DataLoader(data.dataloaders[mode], batch_size = 4) for 121 | mode, shuffle in zip(modes, shuffle_list)]""" 122 | 123 | # get_dataloaders(use_cuda=False, mix_in_normal=False) 124 | 125 | if __name__ == '__main__': 126 | get_dataloaders() -------------------------------------------------------------------------------- /MNIST/block_wrapper.py: -------------------------------------------------------------------------------- 1 | #from rnn_models_bb import RNNModel 2 | #import baseline_lstm_model 3 | import torch 4 | import torch.nn as nn 5 | from modularity import RIMv2, SCOFFv2 6 | #from rnn_models_wiki import RNNModel as RNNModelRules 7 | 8 | class Identity(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, input): 11 | return input * 1.0 12 | def backward(ctx, grad_output): 13 | print(grad_output) 14 | return grad_output * 1.0 15 | 16 | class BlockWrapper(nn.Module): 17 | 18 | def __init__(self, ntokens, nhid, n_out, device = None, **kwargs): 19 | super(BlockWrapper, self).__init__() 20 | num_blocks = kwargs['num_blocks'] 21 | update_topk = kwargs['topk'] 22 | memory_topk = kwargs['memorytopk'] 23 | num_modules_read_input = kwargs['num_modules_read_input'] 24 | inp_heads=kwargs['inp_heads'] 25 | n_templates = kwargs['n_templates'] 26 | algo = kwargs['algo'] 27 | dropout = kwargs['dropout'] 28 | do_rel= kwargs['do_rel'] 29 | memory_slots= kwargs['memory_slots'] 30 | num_memory_heads= kwargs['num_memory_heads'] 31 | memory_head_size= kwargs['memory_head_size'] 32 | share_inp= kwargs["share_inp"] 33 | share_comm= kwargs["share_comm"] 34 | memory_mlp= kwargs["memory_mlp"] 35 | attention_out=kwargs["attention_out"] 36 | version=kwargs["version"] 37 | step_att=kwargs["step_att"] 38 | num_rules = 
kwargs['num_rules'] 39 | rule_time_steps = kwargs['rule_time_steps'] 40 | rule_selection = kwargs['rule_selection'] 41 | 42 | application_option = kwargs['application_option'] 43 | rule_dim = kwargs['rule_dim'] 44 | print("Number of blocks %s, Updating top K %s, number of input modules %s, number of input heads", num_blocks, update_topk, num_modules_read_input, inp_heads) 45 | rule_config = {'rule_time_steps': rule_time_steps, 'num_rules': num_rules, 46 | 'rule_emb_dim': 64, 'rule_query_dim':32, 'rule_value_dim':64, 'rule_key_dim': 32, 47 | 'rule_heads':4,'rule_dropout': 0.5} 48 | design_config = {'comm': True , 'grad': False, 'transformer': True, 'application_option': 3} 49 | if algo == 'SCOFF': 50 | print('ntoken:' + str(ntokens)) 51 | print('n_hid:' + str(nhid)) 52 | print('n_templates:' + str(n_templates)) 53 | print('dropout:' + str(dropout)) 54 | print('num_blocks:' + str(num_blocks)) 55 | print('update_topk:' + str(update_topk)) 56 | print('num_modules_read_input:'+str(num_modules_read_input)) 57 | print('inp_heads:' + str(inp_heads)) 58 | print('device:' + str(device)) 59 | print('share_comm:' + str(share_comm)) 60 | print('share_inp:' + str(share_inp)) 61 | print('attention_out:' + str(attention_out)) 62 | print('version:' + str(version)) 63 | print('step_att:' + str(step_att)) 64 | 65 | self.myrnn = SCOFFv2('cuda', ntokens, nhid, num_blocks, update_topk, num_templates = n_templates, rnn_cell = 'GRU', 66 | n_layers = 1, bidirectional = False, num_rules = num_rules, rule_time_steps = rule_time_steps, perm_inv = True, application_option = application_option, 67 | version=version, attention_out=attention_out, rule_dim = rule_dim, step_att=step_att, dropout=dropout, rule_selection = rule_selection) 68 | 69 | elif algo in ['GRU','LSTM']: 70 | print("Using Baseline RNN") 71 | self.myrnn = nn.GRU(ntokens, nhid) 72 | 73 | elif algo == 'RIM': 74 | self.myrnn = RIMv2('cuda', ntokens, nhid, num_blocks, update_topk, rnn_cell = 'GRU', n_layers = 1, 75 | bidirectional = False, num_rules = num_rules, rule_time_steps = rule_time_steps, application_option = application_option, 76 | version=version, attention_out=attention_out, step_att=step_att, rule_dim = rule_dim, dropout=dropout, rule_selection = rule_selection) 77 | 78 | #elif algo == 'GWT': 79 | # #self.lstm = GWT('cuda', num_inputs, 500, num_units, num_units, memorytopk=memorytopk, memory_slots=memory_slots, 80 | # # num_memory_heads=num_memory_heads, memory_head_size=memory_head_size, rnn_cell = 'GRU', 81 | # # version=2, attention_out=att_out, step_att=False) 82 | 83 | # self.myrnn = GWT('cuda', ntokens, nhid, num_blocks, update_topk, rnn_cell = 'GRU', n_layers = 1, 84 | # bidirectional = False, version=version, attention_out=attention_out, step_att=step_att, dropout=dropout, 85 | # memorytopk=memory_topk, memory_slots=memory_slots, num_memory_heads=num_memory_heads, memory_head_size=memory_head_size) 86 | # #num_rules = num_rules, rule_time_steps = rule_time_steps, application_option = application_option, 87 | # #version=version, attention_out=attention_out, step_att=step_att, dropout=dropout, rule_selection = rule_selection) 88 | 89 | 90 | else: 91 | raise ValueError('Algo format {} not recognized.'.format(algo)) 92 | 93 | self.nhid = nhid 94 | self.algo = algo 95 | 96 | def forward(self, inp, h): 97 | assert len(h.shape) == 3 98 | assert len(inp.shape) == 3 99 | hidden = (h, h) 100 | entropy = 0 101 | if self.algo in ['GRU', 'LSTM']: 102 | ob, hb = self.myrnn(inp, hidden[0]) 103 | else: 104 | ob, hidden, entropy = self.myrnn(inp, 
hidden) 105 | hb = hidden[0] 106 | return ob,hb, None, None, None, entropy 107 | 108 | 109 | if __name__ == "__main__": 110 | nhid = 600 111 | ntokens = 10 112 | 113 | blocks = BlockWrapper(ntokens, nhid, n_out=nhid) 114 | gru = torch.nn.GRU(ntokens, nhid).cuda() 115 | 116 | x = torch.randn(50, 64, 10).cuda() 117 | 118 | h0 = torch.randn(1, 64, nhid).cuda() 119 | h0_blocks = torch.randn(1, 64, nhid*2).cuda() 120 | 121 | og, hg = gru(x, h0) 122 | print('gru of x: o,h', og.shape, hg.shape) 123 | 124 | ob, hb = blocks(x, h0_blocks) 125 | print('block res: o,h', ob.shape, hb.shape) 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /synthetic/utilities/slot_attention_custom.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.distributions.multivariate_normal import MultivariateNormal 4 | 5 | 6 | 7 | class MLP(nn.Module): 8 | def __init__(self, in_dim, out_dim): 9 | super().__init__() 10 | self.mlp = nn.Sequential(nn.Linear(in_dim, 64), 11 | nn.ReLU(), 12 | nn.Linear(64, out_dim)) 13 | def forward(self, x): 14 | return self.mlp(x) 15 | 16 | class SlotAttention(nn.Module): 17 | def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128, out_dim = 64, in_dim = 64): 18 | super().__init__() 19 | self.num_slots = num_slots 20 | self.iters = iters 21 | self.eps = eps 22 | self.scale = dim ** -0.5 23 | self.dim = dim 24 | 25 | #self.slots_mu = nn.Parameter(torch.randn(1, 1, dim)) 26 | #self.slots_sigma = nn.Parameter(torch.randn(1, 1, dim)) 27 | 28 | #self.slots_rep = nn.Linear(dim, 2*dim) 29 | #self.slots_rep_pr = nn.Linear(dim, 2*dim) 30 | 31 | self.slots_rep = nn.Sequential( 32 | nn.Linear(dim, hidden_dim), 33 | nn.ReLU(inplace = True), 34 | nn.Linear(hidden_dim, 2*dim) 35 | ) 36 | self.slots_rep_pr = nn.Sequential( 37 | nn.Linear(dim, hidden_dim), 38 | nn.ReLU(inplace = True), 39 | nn.Linear(hidden_dim, 2*dim) 40 | ) 41 | 42 | self.to_q = nn.Linear(dim, dim) 43 | self.to_k = nn.Linear(in_dim, dim) 44 | self.to_v = nn.Linear(in_dim, dim) 45 | 46 | self.gru = nn.GRUCell(dim, dim) 47 | 48 | hidden_dim = max(dim, hidden_dim) 49 | 50 | self.mlp = nn.Sequential( 51 | nn.Linear(dim, hidden_dim), 52 | nn.ReLU(inplace = True), 53 | nn.Linear(hidden_dim, dim) 54 | ) 55 | 56 | #self.mlp_cast = MLP(dim, out_dim) 57 | 58 | self.norm_input = nn.LayerNorm(in_dim) 59 | self.norm_slots = nn.LayerNorm(dim) 60 | self.norm_pre_ff = nn.LayerNorm(dim) 61 | 62 | def reparameterize(self, mu, logvar, eps = None, deterministic=False): 63 | if not deterministic: 64 | std = logvar.mul(0.5).exp_() 65 | if eps is None: 66 | eps = std.data.new(std.size()).normal_() 67 | return eps.mul(std).add_(mu), eps 68 | else: 69 | return mu, eps 70 | 71 | def get_kl_loss(self, prior, post): 72 | prior = prior.reshape(-1, 2 * self.dim) 73 | post = post.reshape(-1, 2 * self.dim) 74 | 75 | mean_pr, std_pr = prior[:, : self.dim], prior[:, self.dim : 2 * self.dim] 76 | mean_po, std_po = post[:, : self.dim], post[:, self.dim : 2 * self.dim] 77 | 78 | std_pr = std_pr.mul(0.5).exp_() 79 | std_po = std_po.mul(0.5).exp_() 80 | 81 | q1 = MultivariateNormal(loc=mean_pr, scale_tril=torch.diag_embed(std_pr)) 82 | q2 = MultivariateNormal(loc=mean_po, scale_tril=torch.diag_embed(std_po)) 83 | kl = torch.distributions.kl.kl_divergence(q2, q1) 84 | return kl 85 | 86 | def forward(self, inputs, slots, num_slots = None): 87 | b, n, d_ = inputs.shape 88 | n_s = num_slots if num_slots is not None else 
self.num_slots 89 | 90 | d = slots.size(-1) 91 | inputs = self.norm_input(inputs) 92 | k, v = self.to_k(inputs), self.to_v(inputs) 93 | slots_mu_sigma = self.slots_rep(slots.reshape((slots.shape[0] * n_s, self.dim))) 94 | prior_mu_sigma = self.slots_rep_pr(slots.reshape((slots.shape[0] * n_s, self.dim)).detach()) 95 | prior = prior_mu_sigma.reshape(-1, 2 * self.dim) 96 | 97 | slots_mu_sigma = slots_mu_sigma.reshape((b, n_s, 2*self.dim)) 98 | 99 | mu = slots_mu_sigma[:, :, :self.dim] 100 | sigma = slots_mu_sigma[:, :, self.dim: 2*self.dim] 101 | slots, eps = self.reparameterize(mu, sigma) 102 | 103 | kl_loss = 0 104 | for t in range(self.iters): 105 | slots_prev = slots 106 | slots = self.norm_slots(slots) 107 | q = self.to_q(slots) 108 | dots = torch.einsum('bid,bjd->bij', q, k) * self.scale 109 | attn = dots.softmax(dim=1) + self.eps 110 | attn = attn / attn.sum(dim=-1, keepdim=True) 111 | updates = torch.einsum('bjd,bij->bid', v, attn) 112 | slots = self.gru( 113 | updates.reshape(-1, d), 114 | slots_prev.reshape(-1, d) 115 | ) 116 | slots = slots.reshape(b, -1, d) 117 | slots = slots + self.mlp(self.norm_pre_ff(slots)) 118 | slots_mu_sigma = self.slots_rep(slots.reshape((slots.shape[0] * n_s, self.dim))) 119 | slots_mu_sigma = slots_mu_sigma.reshape((b, n_s, 2*self.dim)) 120 | kl_loss += ((t+1)/self.iters) * self.get_kl_loss(prior, slots_mu_sigma) 121 | return slots, kl_loss 122 | 123 | def old_forward(self, inputs, slots, num_slots = None): 124 | 125 | b, n, d_ = inputs.shape 126 | n_s = num_slots if num_slots is not None else self.num_slots 127 | 128 | 129 | self.slots_mu_sigma = self.slots_rep(slots.reshape((slots.shape[0] * n_s, self.dim))) 130 | self.slots_mu_sigma = self.slots_mu_sigma.reshape((b, n_s, 2*self.dim)) 131 | mu = self.slots_mu_sigma[:, :, :self.dim] 132 | sigma = self.slots_mu_sigma[:, :, self.dim: 2*self.dim] 133 | 134 | #slots = torch.normal(mu, sigma) 135 | slots, eps = self.reparameterize(mu, sigma) 136 | 137 | d = slots.size(-1) 138 | 139 | inputs = self.norm_input(inputs) 140 | k, v = self.to_k(inputs), self.to_v(inputs) 141 | 142 | slots_mu_sigma = self.slots_rep(slots.reshape((slots.shape[0] * n_s, self.dim))) 143 | prior = slots_mu_sigma.reshape(-1, 2 * self.dim) 144 | 145 | kl_loss = 0 146 | 147 | for _ in range(self.iters): 148 | slots_prev = slots 149 | 150 | slots = self.norm_slots(slots) 151 | q = self.to_q(slots) 152 | 153 | dots = torch.einsum('bid,bjd->bij', q, k) * self.scale 154 | attn = dots.softmax(dim=1) + self.eps 155 | attn = attn / attn.sum(dim=-1, keepdim=True) 156 | 157 | updates = torch.einsum('bjd,bij->bid', v, attn) 158 | 159 | slots = self.gru( 160 | updates.reshape(-1, d), 161 | slots_prev.reshape(-1, d) 162 | ) 163 | 164 | slots = slots.reshape(b, -1, d) 165 | slots = slots + self.mlp(self.norm_pre_ff(slots)) 166 | 167 | #slots = self.mlp_cast(slots) 168 | 169 | return slots 170 | -------------------------------------------------------------------------------- /MNIST/crl/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import groupby 2 | from operator import itemgetter 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.distributions import Categorical 6 | 7 | def printf(logger, args, string): 8 | if args.printf: 9 | f = open(logger.logdir+'.txt', 'a') 10 | print >>f, string 11 | else: 12 | print(string) 13 | 14 | def create_exp_string(args, relevant_arg_names, prefix, suffix): 15 | string = prefix + '_' 16 | d = vars(args) 17 | for key in 
sorted(set(relevant_arg_names)): 18 | val = d[key] 19 | to_append = key if isinstance(val, bool) else key + '_' + str(val) 20 | string += to_append + '_' 21 | string += suffix 22 | return string 23 | 24 | def inrange(value, interval): 25 | """ 26 | Outputs whether value > interval[0] 27 | and < interval[1], inclusive 28 | """ 29 | # return value >= interval[0] and value <= interval[1] # NOTE I will change this to exclusive!! 30 | return value >= interval[0] and value < interval[1] # NOTE I will change this to exclusive!! 31 | 32 | def group_consecutive(list_of_numbers): 33 | groups = [] 34 | for k, g in groupby(enumerate(list_of_numbers), lambda i_x: i_x[0]-i_x[1]): 35 | mg = map(itemgetter(1), g) 36 | groups.append(tuple(mg)) 37 | return groups 38 | 39 | def cuda_if_needed(x, args): 40 | if args.cuda: 41 | return x.cuda() 42 | else: 43 | return x 44 | 45 | def group_by_element(list_of_numbers): 46 | """ 47 | m = [3,3,3,1,2,2,4,4,4,4,5,5,5,5,5] 48 | idx, vals = group(m) 49 | --> 50 | idx = [[0, 1, 2], [3], [4, 5], [6, 7, 8, 9], [10, 11, 12, 13, 14]] 51 | vals = [[3, 3, 3], [1], [2, 2], [4, 4, 4, 4], [5, 5, 5, 5, 5]] 52 | """ 53 | vals = [list(v) for k,v in groupby(list_of_numbers)] 54 | idx = [] 55 | a = range(len(list_of_numbers)) 56 | i = 0 57 | for sublist in vals: 58 | j = i + len(sublist) 59 | idx.append(a[i:j]) 60 | i = j 61 | return idx, vals 62 | 63 | 64 | def permute(list_of_numbers, indices): 65 | return [list_of_numbers[i] for i in indices] 66 | 67 | def group_by_indices(list_of_numbers, idx_groupings): 68 | return [[list_of_numbers[i] for i in g] for g in idx_groupings] 69 | 70 | def invert_permutation(indices): 71 | return [i for i, j in sorted(enumerate(indices), key=lambda k_j: k_j[0])] 72 | 73 | def sort_group_perm(lengths): 74 | perm_idx, sorted_lengths = sort_decr(lengths) 75 | group_idx, group_lengths = group_by_element(sorted_lengths) 76 | inverse_perm_idx = invert_permutation(perm_idx) 77 | return perm_idx, group_idx, inverse_perm_idx 78 | 79 | def sort_decr(lengths): 80 | perm_idx, sorted_lengths = zip(*[(c, d) for c, d in sorted(enumerate(lengths), key=lambda x: x[1], reverse=True)]) 81 | return perm_idx, sorted_lengths 82 | 83 | def var_length_in_batch_wrapper(fn, inputs, inputs_xform, input_to_group_by, args): 84 | """ 85 | 1. permutes by length 86 | 2. groups by length 87 | 3. applies neural network 88 | 4. unpermutes output of neural network 89 | 90 | args: 91 | fn: neural network function 92 | """ 93 | # sort by length 94 | lengths = [len(e) for e in input_to_group_by] 95 | perm_idx, sorted_lengths = sort_decr(lengths) 96 | inputs_p = map(lambda x: permute(x, perm_idx), inputs) 97 | # group by sorted length 98 | group_idx, group_lengths = group_by_element(sorted_lengths) 99 | inputs_grp = map(lambda x: group_by_indices(x, group_idx), inputs_p) 100 | # convert every group in inputs_grp to torch tensor 101 | inputs_grp_th = map(lambda f_y: map(f_y[0], f_y[1]), zip(inputs_xform, inputs_grp)) 102 | def execute_fn_on_grouped_inputs(fn, grouped_inputs): 103 | outputs = [] 104 | for inp in zip(*grouped_inputs): 105 | out = fn(*inp) 106 | outputs.append(out) # does not mess up the Variable. 
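        # Concatenate the per-group results back into one batch along dim 0.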
107 | outputs = torch.cat(outputs) 108 | return outputs 109 | # run network 110 | outputs_p = execute_fn_on_grouped_inputs(fn, inputs_grp_th) # Variable 111 | # unpermute 112 | inverse_perm_idx = invert_permutation(perm_idx) 113 | inverse_perm_idx_th = cuda_if_needed(torch.LongTensor(inverse_perm_idx), args) 114 | outputs = outputs_p[inverse_perm_idx_th] # hopefully this doesn't mess up the gradient computation... 115 | return outputs 116 | 117 | def var_length_var_dim_in_batch_wrapper(fn, inputs, inputs_xform, input_to_group_by, args): 118 | """ 119 | 1. permutes by length 120 | 2. groups by length 121 | 3. applies neural network 122 | 4. unpermutes output of neural network 123 | 124 | args: 125 | fn: neural network function 126 | """ 127 | # sort by length 128 | lengths = [len(e) for e in input_to_group_by] 129 | perm_idx, sorted_lengths = sort_decr(lengths) 130 | inputs_p = map(lambda x: permute(x, perm_idx), inputs) 131 | # group by sorted length 132 | group_idx, group_lengths = group_by_element(sorted_lengths) 133 | inputs_grp = map(lambda x: group_by_indices(x, group_idx), inputs_p) 134 | # convert every group in inputs_grp to torch tensor 135 | inputs_grp_th = map(lambda f, y: map(f, y), zip(inputs_xform, inputs_grp)) 136 | def execute_fn_on_grouped_inputs(fn, grouped_inputs): 137 | outputs = [] 138 | for inp in zip(*grouped_inputs): 139 | out = fn(*inp) 140 | outputs.append(out) # does not mess up the Variable. 141 | outputs = torch.cat(outputs) 142 | return outputs 143 | # run network 144 | outputs_p = execute_fn_on_grouped_inputs(fn, inputs_grp_th) # Variable 145 | # unpermute 146 | inverse_perm_idx = invert_permutation(perm_idx) 147 | inverse_perm_idx_th = cuda_if_needed(torch.LongTensor(inverse_perm_idx), args) 148 | outputs = outputs_p[inverse_perm_idx_th] # hopefully this doesn't mess up the gradient computation... 
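    # Advanced indexing with a LongTensor is differentiable w.r.t. outputs_p,
    # so restoring the original batch order here does not cut the gradient
    # path back through fn.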
149 | return outputs 150 | 151 | def reverse(x, dim): 152 | idx = torch.LongTensor([i for i in range(x.size(dim)-1, -1, -1)]) 153 | if isinstance(x, torch.autograd.variable.Variable): 154 | idx = Variable(idx) 155 | if 'cuda' in x.data.type(): 156 | idx = idx.cuda() 157 | else: 158 | assert 'Tensor' in x.type() 159 | if 'cuda' in x.type(): 160 | idx = idx.cuda() 161 | return x.index_select(dim, idx) 162 | 163 | 164 | def entropy(dist, eps=1e-20): 165 | dist_eps = dist + eps 166 | log_action_dist = torch.log(dist_eps) 167 | h = -torch.sum(log_action_dist * dist_eps) 168 | return h 169 | 170 | def sample_from_categorical_dist(dist): 171 | m = Categorical(dist) 172 | s = m.sample() 173 | return s 174 | 175 | def logprob_categorical_dist(dist, s): 176 | m = Categorical(dist) 177 | lp = m.log_prob(s) 178 | return lp 179 | 180 | -------------------------------------------------------------------------------- /MNIST/crl/dataloader/pretrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from multilingual_dataset import ArithmeticLanguageWordEncoding, ArithmeticLanguageTranslation 5 | 6 | class PreTrainDataset(object): 7 | def __init__(self, max_terms, num_range, ops, samplefrom, episodecap, root, nlang): 8 | super(PreTrainDataset, self).__init__() 9 | 10 | envbuilder = lambda m, c, p: ArithmeticLanguageTranslation( 11 | max_terms=m, 12 | num_range=num_range, 13 | ops=ops, 14 | samplefrom=samplefrom, 15 | episodecap=episodecap, 16 | cheat=c, 17 | root=root, 18 | curr=True, 19 | pair=p, 20 | nlang=nlang) 21 | 22 | self.pairs = self.create_pairs(nlang, 'ed') 23 | self.datasets = self.create_datasets(envbuilder, self.pairs, nlang) 24 | 25 | assert all(self.datasets[0].vocabulary == d.vocabulary for d in self.datasets) 26 | assert all(self.datasets[0].langsize == d.langsize for d in self.datasets) 27 | assert all(self.datasets[0].zsize == d.zsize for d in self.datasets) 28 | 29 | self.langsize = self.datasets[0].langsize 30 | self.zsize = self.datasets[0].zsize 31 | self.vocabulary = self.datasets[0].vocabulary 32 | self.vocabsize = self.datasets[0].vocabsize 33 | self.current_exp_strs = [] 34 | 35 | def initialize_data(self, splits): 36 | map(lambda x: x.initialize_data(splits), self.datasets) 37 | 38 | def initialize_printer(self, logger, args): 39 | map(lambda x: x.initialize_printer(logger, args), self.datasets) 40 | 41 | def create_datasets(self, envbuilder, pairs, nlang): 42 | datasets = [] 43 | # add translator dataset encoder-decoder 44 | print('ADDING TRANSLATOR DATASETS') 45 | for pair in pairs: 46 | print('PAIR', pair) 47 | datasets.append(envbuilder(m=[2,2,2], c=False, p=pair)) 48 | # add reducer dataset 49 | print('ADDING REDUCER DATASETS') 50 | datasets.extend([ 51 | envbuilder(m=[3,3,3], c=True, p='mm')]) 52 | return datasets 53 | 54 | def create_pairs(self, nlang, mode): 55 | if mode == 'ed': 56 | pairs = ['em', 'me', 'pm', 'mp', 'rm', 'mr'] 57 | if nlang >= 5: 58 | pairs += ['sm', 'ms'] 59 | if nlang >= 6: 60 | pairs += ['gm', 'mg'] 61 | if nlang >= 7: 62 | pairs += ['vm', 'mv'] 63 | if nlang > 7: assert False 64 | elif mode == 'ring': 65 | # random permutation for ring: 4,1,3,2,5 66 | # r,m,p,e,s 67 | # m,p,e,s,r 68 | if nlang == 5: 69 | pairs = ['rm', 'mp', 'pe', 'es', 'sr'] 70 | else: 71 | assert False 72 | pass 73 | else: 74 | assert False 75 | return pairs 76 | 77 | def change_dataset(self): 78 | self.d_index = np.random.randint(len(self.datasets)) 79 | 80 | def reset(self, mode, z): 81 | return 
self.datasets[self.d_index].reset(mode, z, whole_expr=1) 82 | 83 | def encode_tokens(self, tokens): 84 | return self.datasets[self.d_index].encode_tokens(tokens) 85 | 86 | def decode_tokens(self, tokens): 87 | return self.datasets[self.d_index].decode_tokens(tokens) 88 | 89 | def get_exp_str(self, encoded_expression): 90 | return self.datasets[self.d_index].get_exp_str(encoded_expression) 91 | 92 | def add_exp_str(self, exp_str): 93 | self.datasets[self.d_index].add_exp_str(exp_str) 94 | 95 | def get_problem_trace(self): 96 | return self.datasets[self.d_index].get_problem_trace() 97 | 98 | def get_trace(self): 99 | return self.datasets[self.d_index].get_trace() 100 | 101 | def change_mt(self): 102 | self.datasets[self.d_index].change_mt() 103 | 104 | class Pretrain_Multilingual_Dataset(object): 105 | def __init__(self, max_terms, num_range, ops, samplefrom, episodecap, root, nlang): 106 | super(Pretrain_Multilingual_Dataset, self).__init__() 107 | assert nlang == 5 108 | # just need to make sure these specs are the same as what was pretrained 109 | self.pretrain_env = PreTrainDataset( 110 | max_terms=max_terms, 111 | num_range=num_range, 112 | ops=ops, 113 | samplefrom=samplefrom, 114 | episodecap=episodecap, 115 | root=root, 116 | curr=True, 117 | nlang=nlang) 118 | 119 | self.multilingual_env = ArithmeticLanguageWordEncoding( 120 | max_terms=max_terms, 121 | num_range=num_range, 122 | ops=ops, 123 | samplefrom=samplefrom, 124 | episodecap=episodecap, 125 | root=root, 126 | curr=True, 127 | nlang=nlang 128 | ) 129 | 130 | self.datasets = [self.pretrain_env, self.multilingual_env] 131 | 132 | assert all(self.datasets[0].vocabulary == d.vocabulary for d in self.datasets) 133 | assert all(self.datasets[0].langsize == d.langsize for d in self.datasets) 134 | assert all(self.datasets[0].zsize == d.zsize for d in self.datasets) 135 | 136 | self.langsize = self.datasets[0].langsize 137 | self.zsize = self.datasets[0].zsize 138 | self.vocabulary = self.datasets[0].vocabulary 139 | self.vocabsize = self.datasets[0].vocabsize 140 | self.current_exp_strs = [] 141 | 142 | def initialize_data(self, splits): 143 | map(lambda x: x.initialize_data(splits), self.datasets) 144 | 145 | def initialize_printer(self, logger, args): 146 | map(lambda x: x.initialize_printer(logger, args), self.datasets) 147 | 148 | def change_dataset(self): 149 | self.d_index = np.random.randint(len(self.datasets)) 150 | if self.d_index == 0: 151 | self.pretrain_env.change_dataset() 152 | 153 | def reset(self, mode, z): 154 | return self.datasets[self.d_index].reset(mode, z) 155 | 156 | def encode_tokens(self, tokens): 157 | return self.datasets[self.d_index].encode_tokens(tokens) 158 | 159 | def decode_tokens(self, tokens): 160 | return self.datasets[self.d_index].decode_tokens(tokens) 161 | 162 | def get_exp_str(self, encoded_expression): 163 | return self.datasets[self.d_index].get_exp_str(encoded_expression) 164 | 165 | def add_exp_str(self, exp_str): 166 | self.datasets[self.d_index].add_exp_str(exp_str) 167 | 168 | def get_problem_trace(self): 169 | return self.datasets[self.d_index].get_problem_trace() 170 | 171 | def get_trace(self): 172 | return self.datasets[self.d_index].get_trace() 173 | 174 | def change_mt(self): 175 | self.datasets[self.d_index].change_mt() 176 | 177 | def update_curriculum(self): 178 | self.datasets[1].update_curriculum() 179 | 180 | -------------------------------------------------------------------------------- /MNIST/crl/dataloader/utils.py: 
-------------------------------------------------------------------------------- 1 | from itertools import groupby 2 | from operator import itemgetter 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.distributions import Categorical 6 | 7 | 8 | 9 | def printf(logger, args, string): 10 | if args.printf: 11 | f = open(logger.logdir+'.txt', 'a') 12 | print >>f, string 13 | else: 14 | print string 15 | 16 | 17 | def create_exp_string(args, relevant_arg_names, prefix, suffix): 18 | string = prefix + '_' 19 | d = vars(args) 20 | for key in sorted(set(relevant_arg_names)): 21 | val = d[key] 22 | to_append = key if isinstance(val, bool) else key + '_' + str(val) 23 | string += to_append + '_' 24 | string += suffix 25 | return string 26 | 27 | def inrange(value, interval): 28 | """ 29 | Outputs whether value > interval[0] 30 | and < interval[1], inclusive 31 | """ 32 | # return value >= interval[0] and value <= interval[1] # NOTE I will change this to exclusive!! 33 | return value >= interval[0] and value < interval[1] # NOTE I will change this to exclusive!! 34 | 35 | 36 | def group_consecutive(list_of_numbers): 37 | groups = [] 38 | for k, g in groupby(enumerate(list_of_numbers), lambda (i, x): i-x): 39 | mg = map(itemgetter(1), g) 40 | groups.append(tuple(mg)) 41 | return groups 42 | 43 | def cuda_if_needed(x, args): 44 | if args.cuda: 45 | return x.cuda() 46 | else: 47 | return x 48 | 49 | def group_by_element(list_of_numbers): 50 | """ 51 | m = [3,3,3,1,2,2,4,4,4,4,5,5,5,5,5] 52 | idx, vals = group(m) 53 | --> 54 | idx = [[0, 1, 2], [3], [4, 5], [6, 7, 8, 9], [10, 11, 12, 13, 14]] 55 | vals = [[3, 3, 3], [1], [2, 2], [4, 4, 4, 4], [5, 5, 5, 5, 5]] 56 | """ 57 | vals = [list(v) for k,v in groupby(list_of_numbers)] 58 | idx = [] 59 | a = range(len(list_of_numbers)) 60 | i = 0 61 | for sublist in vals: 62 | j = i + len(sublist) 63 | idx.append(a[i:j]) 64 | i = j 65 | return idx, vals 66 | 67 | 68 | def permute(list_of_numbers, indices): 69 | return [list_of_numbers[i] for i in indices] 70 | 71 | def group_by_indices(list_of_numbers, idx_groupings): 72 | return [[list_of_numbers[i] for i in g] for g in idx_groupings] 73 | 74 | def invert_permutation(indices): 75 | return [i for i, j in sorted(enumerate(indices), key=lambda (_, j): j)] 76 | 77 | def sort_group_perm(lengths): 78 | perm_idx, sorted_lengths = sort_decr(lengths) 79 | group_idx, group_lengths = group_by_element(sorted_lengths) 80 | inverse_perm_idx = invert_permutation(perm_idx) 81 | return perm_idx, group_idx, inverse_perm_idx 82 | 83 | def sort_decr(lengths): 84 | perm_idx, sorted_lengths = zip(*[(c, d) for c, d in sorted(enumerate(lengths), key=lambda x: x[1], reverse=True)]) 85 | return perm_idx, sorted_lengths 86 | 87 | def var_length_in_batch_wrapper(fn, inputs, inputs_xform, input_to_group_by, args): 88 | """ 89 | 1. permutes by length 90 | 2. groups by length 91 | 3. applies neural network 92 | 4. 
unpermutes output of neural network 93 | 94 | args: 95 | fn: neural network function 96 | 97 | TODO: 98 | make sure that outputs.extend does not mess up the Variable 99 | """ 100 | # sort by length 101 | lengths = [len(e) for e in input_to_group_by] 102 | perm_idx, sorted_lengths = sort_decr(lengths) 103 | inputs_p = map(lambda x: permute(x, perm_idx), inputs) 104 | # group by sorted length 105 | group_idx, group_lengths = group_by_element(sorted_lengths) 106 | inputs_grp = map(lambda x: group_by_indices(x, group_idx), inputs_p) 107 | # convert every group in inputs_grp to torch tensor 108 | inputs_grp_th = map(lambda (f, y): map(f, y), zip(inputs_xform, inputs_grp)) 109 | def execute_fn_on_grouped_inputs(fn, grouped_inputs): 110 | outputs = [] 111 | for inp in zip(*grouped_inputs): 112 | out = fn(*inp) 113 | outputs.append(out) # does not mess up the Variable. 114 | outputs = torch.cat(outputs) 115 | return outputs 116 | # run network 117 | outputs_p = execute_fn_on_grouped_inputs(fn, inputs_grp_th) # Variable 118 | # unpermute 119 | inverse_perm_idx = invert_permutation(perm_idx) 120 | inverse_perm_idx_th = cuda_if_needed(torch.LongTensor(inverse_perm_idx), args) 121 | outputs = outputs_p[inverse_perm_idx_th] # hopefully this doesn't mess up the gradient computation... 122 | return outputs 123 | 124 | # TODO: does this retain previous activations? 125 | # because you are running the network multiple times before the backward pass 126 | # yes, it still works because look at the "backward both losses together" 127 | # https://discuss.pytorch.org/t/how-to-use-the-backward-functions-for-multiple-losses/1826/7?u=simonw 128 | 129 | def var_length_var_dim_in_batch_wrapper(fn, inputs, inputs_xform, input_to_group_by, args): 130 | """ 131 | 1. permutes by length 132 | 2. groups by length 133 | 3. applies neural network 134 | 4. unpermutes output of neural network 135 | 136 | args: 137 | fn: neural network function 138 | 139 | TODO: 140 | make sure that outputs.extend does not mess up the Variable 141 | """ 142 | # sort by length 143 | lengths = [len(e) for e in input_to_group_by] 144 | perm_idx, sorted_lengths = sort_decr(lengths) 145 | inputs_p = map(lambda x: permute(x, perm_idx), inputs) 146 | # group by sorted length 147 | group_idx, group_lengths = group_by_element(sorted_lengths) 148 | inputs_grp = map(lambda x: group_by_indices(x, group_idx), inputs_p) 149 | # convert every group in inputs_grp to torch tensor 150 | inputs_grp_th = map(lambda (f, y): map(f, y), zip(inputs_xform, inputs_grp)) 151 | def execute_fn_on_grouped_inputs(fn, grouped_inputs): 152 | outputs = [] 153 | for inp in zip(*grouped_inputs): 154 | out = fn(*inp) 155 | outputs.append(out) # does not mess up the Variable. 156 | outputs = torch.cat(outputs) 157 | return outputs 158 | # run network 159 | outputs_p = execute_fn_on_grouped_inputs(fn, inputs_grp_th) # Variable 160 | # unpermute 161 | inverse_perm_idx = invert_permutation(perm_idx) 162 | inverse_perm_idx_th = cuda_if_needed(torch.LongTensor(inverse_perm_idx), args) 163 | outputs = outputs_p[inverse_perm_idx_th] # hopefully this doesn't mess up the gradient computation... 
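    # Restore the original (pre-sort) batch order before returning so the
    # outputs are aligned with the order of the caller's inputs.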
164 | return outputs 165 | 166 | 167 | def reverse(x, dim): 168 | idx = torch.LongTensor([i for i in range(x.size(dim)-1, -1, -1)]) 169 | if isinstance(x, torch.autograd.variable.Variable): 170 | idx = Variable(idx) 171 | if 'cuda' in x.data.type(): 172 | idx = idx.cuda() 173 | else: 174 | assert 'Tensor' in x.type() 175 | if 'cuda' in x.type(): 176 | idx = idx.cuda() 177 | return x.index_select(dim, idx) 178 | 179 | 180 | def entropy(dist, eps=1e-20): 181 | dist_eps = dist + eps 182 | log_action_dist = torch.log(dist_eps) 183 | h = -torch.sum(log_action_dist * dist_eps) 184 | return h 185 | 186 | def sample_from_categorical_dist(dist): 187 | m = Categorical(dist) 188 | s = m.sample() 189 | return s 190 | 191 | def logprob_categorical_dist(dist, s): 192 | m = Categorical(dist) 193 | lp = m.log_prob(s) 194 | return lp 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | -------------------------------------------------------------------------------- /MNIST/utils/log.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from datetime import datetime 3 | import json 4 | import logging 5 | import numpy as np 6 | #import git 7 | from pathlib import Path 8 | import sys 9 | from typing import Any, Dict, List, Optional, Set 10 | 11 | NORMAL_FORMATTER = logging.Formatter('%(levelname)s %(asctime)s: %(name)s: %(message)s') 12 | JSON_FORMATTER = logging.Formatter('%(levelname)s::%(message)s') 13 | FINGERPRINT = 'fingerprint.txt' 14 | LOGFILE = 'log.txt' 15 | EXP = 5 16 | logging.addLevelName(EXP, 'EXP') 17 | 18 | 19 | class Logger(logging.Logger): 20 | def __init__(self) -> None: 21 | # set log level to debug 22 | super().__init__('rainy', EXP) 23 | self._log_dir: Optional[Path] = None 24 | self.exp_start = datetime.now() 25 | 26 | def set_dir_from_script_path( 27 | self, 28 | script_path_: str, 29 | comment: Optional[str] = None, 30 | prefix: str = '', 31 | ) -> None: 32 | script_path = Path(script_path_) 33 | log_dir = script_path.stem + '-' + self.exp_start.strftime("%y%m%d-%H%M%S") 34 | if prefix: 35 | log_dir = prefix + '/' + log_dir 36 | try: 37 | repo = git.Repo(script_path, search_parent_directories=True) 38 | head = repo.head.commit 39 | log_dir += '-' + head.hexsha[:8] 40 | finally: 41 | pass 42 | log_dir_path = Path(log_dir) 43 | if not log_dir_path.exists(): 44 | log_dir_path.mkdir() 45 | self.set_dir(log_dir_path, comment=comment) 46 | 47 | def set_dir(self, log_dir: Path, comment: Optional[str] = None) -> None: 48 | self._log_dir = log_dir 49 | 50 | def make_handler(log_path: Path, level: int) -> logging.Handler: 51 | if not log_path.exists(): 52 | log_path.touch() 53 | handler = logging.FileHandler(log_path.as_posix()) 54 | handler.setFormatter(JSON_FORMATTER) 55 | handler.setLevel(level) 56 | return handler 57 | finger = log_dir.joinpath(FINGERPRINT) 58 | with open(finger.as_posix(), 'w') as f: 59 | f.write('{}\n'.format(self.exp_start)) 60 | if comment: 61 | f.write(comment) 62 | handler = make_handler(Path(log_dir).joinpath(LOGFILE), EXP) 63 | self.addHandler(handler) 64 | 65 | def set_stderr(self, level: int = EXP) -> None: 66 | handler = logging.StreamHandler(stream=sys.stderr) 67 | handler.setFormatter(NORMAL_FORMATTER) 68 | handler.setLevel(level) 69 | self.addHandler(handler) 70 | 71 | @property 72 | def log_dir(self) -> Optional[Path]: 73 | return self._log_dir 74 | 75 | def exp(self, name: str, msg: dict, *args, **kwargs) -> None: 76 | """ 77 | For experiment logging. 
Only dict is enabled as argument 78 | """ 79 | if not self.isEnabledFor(EXP): 80 | return 81 | delta = datetime.now() - self.exp_start 82 | msg['elapsed-time'] = delta.total_seconds() 83 | msg['name'] = name 84 | self._log(EXP, json.dumps(msg, sort_keys=True), args, **kwargs) # type: ignore 85 | 86 | 87 | def _load_log_file(file_path: Path) -> List[Dict[str, Any]]: 88 | with open(file_path.as_posix()) as f: 89 | lines = f.readlines() 90 | log = [] 91 | for line in filter(lambda s: s.startswith('EXP::'), lines): 92 | log.append(json.loads(line[5:])) 93 | return log 94 | 95 | 96 | class LogWrapper: 97 | """Wrapper of filterd log. 98 | """ 99 | def __init__( 100 | self, 101 | name: str, 102 | inner: List[Dict[str, Any]], 103 | path: Optional[Path] = None, 104 | ) -> None: 105 | self.name = name 106 | self.inner = inner 107 | self._available_keys: Set[str] = set() 108 | self._path = path 109 | 110 | @property 111 | def unwrapped(self) -> List[Dict[str, Any]]: 112 | return self.inner 113 | 114 | def keys(self) -> Set[str]: 115 | if not self._available_keys: 116 | for log in self.inner: 117 | for key in log: 118 | self._available_keys.add(key) 119 | return self._available_keys 120 | 121 | def get(self, key: str) -> List[Any]: 122 | if key not in self.inner[0]: 123 | raise KeyError( 124 | 'LogWrapper({}) doesn\'t have the logging key {}. Available keys: {}' 125 | .format(self.name, key, self.keys()) 126 | ) 127 | return list(map(lambda d: d[key], self.inner)) 128 | 129 | def is_empty(self) -> bool: 130 | return len(self.inner) == 0 131 | 132 | def __repr__(self) -> str: 133 | return 'LogWrapper({}, {})'.format(self._path, self.name) 134 | 135 | def __getitem__(self, key: str) -> List[Any]: 136 | return self.get(key) 137 | 138 | 139 | class ExperimentLog: 140 | """Structured log file. 141 | Used to get graphs or else from rainy log files. 142 | """ 143 | def __init__(self, file_or_dir_name: str) -> None: 144 | path = Path(file_or_dir_name) 145 | if path.is_dir(): 146 | log_path = path.joinpath(LOGFILE) 147 | self.fingerprint = path.joinpath(FINGERPRINT).read_text() 148 | else: 149 | log_path = path 150 | self.fingerprint = '' 151 | self.log = _load_log_file(log_path) 152 | self._available_keys: Set[str] = set() 153 | self.log_path = log_path 154 | 155 | def keys(self) -> Set[str]: 156 | if not self._available_keys: 157 | for log in self.log: 158 | self._available_keys.add(log['name']) 159 | return self._available_keys 160 | 161 | def get(self, key: str) -> LogWrapper: 162 | log = LogWrapper( 163 | key, 164 | list(filter(lambda log: log['name'] == key, self.log)), 165 | self.log_path 166 | ) 167 | if log.is_empty(): 168 | raise KeyError( 169 | '{} doesn\'t have the key {}. 
Available keys: {}' 170 | .format(self, key, self.keys()) 171 | ) 172 | return log 173 | 174 | def plot_reward(self, batch_size: int, max_steps: int = int(2e7), title: str = '') -> None: 175 | try: 176 | import matplotlib.pyplot as plt 177 | except ModuleNotFoundError as e: 178 | print('plot_reward need matplotlib installed') 179 | raise e 180 | tlog = self.get('train') 181 | x, y = 'update-steps', 'reward-mean' 182 | plt.plot(np.array(tlog[x]) * batch_size, tlog[y]) 183 | tick_fractions = np.array([0.1, 0.2, 0.5, 1.0]) 184 | ticks = tick_fractions * max_steps 185 | MILLION = int(1e6) 186 | if max_steps >= MILLION: 187 | tick_names = ["{}M".format(int(tick / 1e6)) for tick in ticks] 188 | else: 189 | tick_names = ["{}".format(int(tick)) for tick in ticks] 190 | plt.xticks(ticks, tick_names) 191 | plt.title(title) 192 | plt.xlabel('Frames used for training') 193 | plt.ylabel(y) 194 | plt.show() 195 | 196 | def __getitem__(self, key: str) -> LogWrapper: 197 | return self.get(key) 198 | 199 | def __repr__(self) -> str: 200 | return 'ExperimentLog({})'.format(self.log_path.as_posix()) 201 | 202 | 203 | class ExpStats: 204 | """Statictics of loss 205 | """ 206 | def __init__(self) -> None: 207 | self.inner: Dict[str, List[float]] = defaultdict(list) 208 | 209 | def update(self, d: Dict[str, float]) -> None: 210 | for key in d.keys(): 211 | self.inner[key].append(d[key]) 212 | 213 | def report_and_reset(self) -> Dict[str, float]: 214 | res = {} 215 | for k, v in self.inner.items(): 216 | res[k] = np.array(v).mean() 217 | v.clear() 218 | return res 219 | -------------------------------------------------------------------------------- /MNIST/dataset.py: -------------------------------------------------------------------------------- 1 | """Logic to interface with the dataset""" 2 | from __future__ import print_function 3 | 4 | import h5py 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data.dataset import Dataset 9 | from crl.dataset import get_dataloaders as get_crl_dataloaders 10 | 11 | 12 | class LoadDataset(Dataset): 13 | """Dataset class""" 14 | 15 | def __init__(self, mode, length=51, directory='/Volumes/Fred/Data', 16 | dataset="balls4mass64.h5"): 17 | self.length = length 18 | self.mode = mode 19 | self.directory = directory 20 | # datasets = ['/atari.h5', '/balls3curtain64.h5', '/balls4mass64.h5', 21 | # '/balls678mass64.h5'] 22 | self.dataset = dataset 23 | print(dataset) 24 | if dataset != '2Balls': 25 | hdf5_file = h5py.File(self.directory + "/" + dataset, 'r') 26 | if mode == "transfer": 27 | self.input_data = hdf5_file['test'] 28 | else: 29 | self.input_data = hdf5_file[self.mode] 30 | self.input_data = np.array(self.input_data['features']) 31 | else: 32 | if mode == 'train': 33 | npz = np.load('data_generation/train/gravity_x_0.00_y_0.00.npz') 34 | else: 35 | npz = np.load('data_generation/val/gravity_x_0.00_y_0.00.npz') 36 | 37 | self.images = npz['images'].astype(np.float32) 38 | self.images = np.split(self.images, self.images.shape[0]) 39 | 40 | 41 | 42 | 43 | def __getitem__(self, index, out_list=('features', 'groups')): 44 | # ['collisions', 'events', 'features', 'groups', 'positions', 'velocities'] 45 | # Currently (51 ,64, 64, 1) 46 | if self.dataset != '2Balls': 47 | features = 1.0 * self.input_data[:self.length, index, :, :, :] 48 | # True, False label, conert to int 49 | # Convert to tensors 50 | L = self.input_data.shape[0] 51 | features = torch.tensor(features.reshape(L, 1, 64, 64)) 52 | else: 53 | video = 
self.images[index] 54 | features = torch.tensor(video).view(50, 1, 64, 64) 55 | return features.float() 56 | 57 | def __len__(self): 58 | if self.dataset != '2Balls': 59 | return int(np.shape(self.input_data)[1]) 60 | else: 61 | return len(self.images) 62 | 63 | 64 | 65 | 66 | 67 | class LoadRGBDataset(Dataset): 68 | def __init__(self, mode, length=51, directory='/Volumes/Fred/Data', dataset="balls4mass64.h5"): 69 | self.length = length 70 | self.mode = mode 71 | self.directory = directory 72 | hdf5_file = h5py.File(self.directory + "/" + dataset, 'r') 73 | 74 | #if mode == "transfer": 75 | # self.input_data = hdf5_file['test'] 76 | #else: 77 | # self.input_data = hdf5_file[self.mode] 78 | #self.input_data = np.array(self.input_data['features']) 79 | 80 | #datasets = ['/atari.h5','/balls3curtain64.h5','/balls4mass64.h5','/balls678mass64.h5'] 81 | hdf5_file = h5py.File(self.directory + "/" + dataset, 'r') 82 | #if mode != 'transfer': 83 | # print('READING IN 4 DATASET') 84 | # hdf5_file = h5py.File(self.directory+'/balls4mass64.h5', 'r') 85 | #else: 86 | # print('READING IN 6-7-8 DATASET') 87 | # hdf5_file = h5py.File(self.directory+'/balls678mass64.h5', 'r') 88 | 89 | if mode != 'transfer': 90 | self.input_data = hdf5_file[self.mode] 91 | else: 92 | self.input_data = hdf5_file['test'] 93 | print(self.input_data) 94 | self.data_to_use = np.array(self.input_data['groups']) 95 | print("Done with RGB Convert") 96 | 97 | def __getitem__(self, index, out_list=('features', 'groups')): 98 | # ['collisions', 'events', 'features', 'groups', 'positions', 'velocities'] 99 | # Currently (51 ,64, 64, 1) 100 | # print("In get item") 101 | # print('data_to_use shape: ',self.data_to_use.shape) 102 | # print('index is ',index) 103 | features = 1.0*self.data_to_use[:,index,:,:,:] # True, False label, conert to int 104 | #print(features.shape) 105 | 106 | colors = np.array([[228,26,28],[55,126,184],[77,175,74],[152,78,163],[255,127,0],[255,255,51]])/255. 
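        # colors holds six RGB triples scaled to [0, 1]; below, every non-background
        # ball id in the groups mask is mapped to one of the first four palette
        # entries at random per sample, giving a (Time, 3, 64, 64) colourised clip.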
107 | (Time, X_dim, Y_dim, Channels) = features.shape 108 | 109 | #print(self.data_to_use.shape) 110 | uniques = np.unique(features)[1:] 111 | uniques = uniques.astype(int) 112 | rc = [np.random.choice([0,1,2,3]) for _ in range(len(uniques))] 113 | 114 | self.data_to_use2 = np.zeros((Time, 3, X_dim, Y_dim)) 115 | for t in range(Time): 116 | 117 | r_channel = np.zeros((64,64)) 118 | g_channel = np.zeros((64,64)) 119 | b_channel = np.zeros((64,64)) 120 | # use four colours 121 | for ball in uniques: 122 | self.data_to_use2[t,0,:,:] += ((features[t,:,:,0]==ball)*1.0)*colors[rc[ball-1]][0] 123 | self.data_to_use2[t,1,:,:] += ((features[t,:,:,0]==ball)*1.0)*colors[rc[ball-1]][1] 124 | self.data_to_use2[t,2,:,:] += ((features[t,:,:,0]==ball)*1.0)*colors[rc[ball-1]][2] 125 | features = self.data_to_use2 126 | features_no_noise = np.copy(features) 127 | features = torch.tensor(features) 128 | features_no_noise = torch.tensor(features_no_noise) 129 | #exit() 130 | #print(features.float().shape) 131 | return (features.float(),features_no_noise.float()) 132 | 133 | def __len__(self): 134 | return int(np.shape(self.data_to_use)[1]) 135 | 136 | def get_dataloaders(args): 137 | """Method to return the train, test and transfer dataloaders""" 138 | 139 | modes = ["training", "test", "transfer"] 140 | dataset_names = [args.train_dataset, args.train_dataset, args.train_dataset] 141 | shuffle_list = [True, False, False] 142 | 143 | def _get_dataloader(mode, dataset_name, shuffle): 144 | dataset = LoadDataset(mode=mode, 145 | length=args.sequence_length, 146 | directory=args.directory, 147 | dataset=dataset_name) 148 | return DataLoader(dataset, batch_size=args.batch_size, 149 | shuffle=shuffle, num_workers=0) 150 | 151 | return [_get_dataloader(mode, dataset_name, shuffle) for 152 | mode, dataset_name, shuffle in zip(modes, dataset_names, shuffle_list)] 153 | 154 | def get_rgb_dataloaders(args): 155 | """Method to return the train, test and transfer dataloaders""" 156 | 157 | modes = ["training", "test", "transfer"] 158 | dataset_names = [args.train_dataset, args.train_dataset, args.transfer_dataset] 159 | shuffle_list = [True, False, False] 160 | 161 | def _get_dataloader(mode, dataset_name, shuffle): 162 | dataset = LoadRGBDataset(mode=mode, 163 | length=args.sequence_length, 164 | directory=args.directory, 165 | dataset=dataset_name) 166 | return DataLoader(dataset, batch_size=args.batch_size, 167 | shuffle=shuffle, num_workers=0) 168 | 169 | return [_get_dataloader(mode, dataset_name, shuffle) for 170 | mode, dataset_name, shuffle in zip(modes, dataset_names, shuffle_list)] 171 | 172 | def crl_dataloaders(args, should_shuffle=True): 173 | """Method to return the train, test and transfer dataloaders""" 174 | 175 | if args.train_dataset == "crl": 176 | use_cuda = True 177 | if str(args.device) == "cpu": 178 | use_cuda = False 179 | return get_crl_dataloaders(args = args, 180 | use_cuda=use_cuda, should_shuffle=should_shuffle) 181 | else: 182 | return get_rnem_dataloaders() 183 | 184 | -------------------------------------------------------------------------------- /synthetic/utilities/BlockGRU.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ''' 6 | Goal1: a GRU where the weight matrices have a block structure so that information flow is constrained 7 | 8 | Data is assumed to come in [block1, block2, ..., block_n]. 
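For example, with k = 2 blocks and nhid = 6 the hidden state is treated as two
3-dim block states, and blockify_params below keeps only the block-diagonal part
of each GRU weight matrix, so no block reads another block's slice of the input
or hidden state.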
9 | 10 | Goal2: Dynamic parameter sharing between blocks (RIMs) 11 | 12 | ''' 13 | 14 | import torch 15 | import torch.nn as nn 16 | from .GroupLinearLayer import GroupLinearLayer 17 | from .sparse_attn import Sparse_attention 18 | 19 | 20 | ''' 21 | Given an N x N matrix, and a grouping of size, set all elements off the block diagonal to 0.0 22 | ''' 23 | def zero_matrix_elements(matrix, k): 24 | assert matrix.shape[0] % k == 0 25 | assert matrix.shape[1] % k == 0 26 | g1 = matrix.shape[0] // k 27 | g2 = matrix.shape[1] // k 28 | new_mat = torch.zeros_like(matrix) 29 | for b in range(0,k): 30 | new_mat[b*g1 : (b+1)*g1, b*g2 : (b+1)*g2] += matrix[b*g1 : (b+1)*g1, b*g2 : (b+1)*g2] 31 | 32 | matrix *= 0.0 33 | matrix += new_mat 34 | 35 | 36 | class BlockGRU(nn.Module): 37 | """Container module with an encoder, a recurrent module, and a decoder.""" 38 | 39 | def __init__(self, ninp, nhid, k): 40 | super(BlockGRU, self).__init__() 41 | 42 | assert ninp % k == 0 43 | assert nhid % k == 0 44 | 45 | self.k = k 46 | self.gru = nn.GRUCell(ninp, nhid) 47 | self.nhid = nhid 48 | self.ninp = ninp 49 | 50 | def blockify_params(self): 51 | pl = self.gru.parameters() 52 | 53 | for p in pl: 54 | p = p.data 55 | if p.shape == torch.Size([self.nhid*3]): 56 | pass 57 | '''biases, don't need to change anything here''' 58 | if p.shape == torch.Size([self.nhid*3, self.nhid]) or p.shape == torch.Size([self.nhid*3, self.ninp]): 59 | for e in range(0,4): 60 | zero_matrix_elements(p[self.nhid*e : self.nhid*(e+1)], k=self.k) 61 | 62 | def forward(self, input, h): 63 | 64 | #self.blockify_params() 65 | 66 | hnext = self.gru(input, h) 67 | 68 | return hnext, None 69 | 70 | class Identity(torch.autograd.Function): 71 | @staticmethod 72 | def forward(ctx, input): 73 | return input * 1.0 74 | def backward(ctx, grad_output): 75 | print(torch.sqrt(torch.sum(torch.pow(grad_output,2)))) 76 | print('-----------') 77 | return grad_output * 1.0 78 | 79 | class OldSharedBlockGRU(nn.Module): 80 | """Dynamic sharing of parameters between blocks(RIM's)""" 81 | 82 | def __init__(self, ninp, nhid, k, n_templates): 83 | super(SharedBlockGRU, self).__init__() 84 | 85 | assert ninp % k == 0 86 | assert nhid % k == 0 87 | 88 | self.k = k 89 | self.m = nhid // self.k 90 | 91 | self.n_templates = n_templates 92 | self.templates = nn.ModuleList([nn.GRUCell(ninp,self.m) for _ in range(0,self.n_templates)]) 93 | self.nhid = nhid 94 | 95 | self.ninp = ninp 96 | 97 | self.gll_write = GroupLinearLayer(self.m,16, self.n_templates) 98 | self.gll_read = GroupLinearLayer(self.m,16,1) 99 | self.sa = Sparse_attention(1) 100 | print("Using Gumble sparsity") 101 | 102 | def blockify_params(self): 103 | 104 | return 105 | 106 | def forward(self, input, h): 107 | 108 | #self.blockify_params() 109 | bs = h.shape[0] 110 | h = h.reshape((h.shape[0], self.k, self.m)).reshape((h.shape[0]*self.k, self.m)) 111 | 112 | input = input.reshape(input.shape[0], 1, input.shape[1]) 113 | input = input.repeat(1,self.k,1) 114 | input = input.reshape(input.shape[0]*self.k, input.shape[2]) 115 | 116 | h_read = self.gll_read((h*1.0).reshape((h.shape[0], 1, h.shape[1]))) 117 | 118 | 119 | hnext_stack = [] 120 | 121 | 122 | for template in self.templates: 123 | hnext_l = template(input, h) 124 | hnext_l = hnext_l.reshape((hnext_l.shape[0], 1, hnext_l.shape[1])) 125 | hnext_stack.append(hnext_l) 126 | 127 | hnext = torch.cat(hnext_stack, 1) 128 | 129 | write_key = self.gll_write(hnext) 130 | 131 | ''' 132 | sm = nn.Softmax(2) 133 | att = sm(torch.bmm(h_read, 
write_key.permute(0, 2, 1))).squeeze(1) 134 | att = self.sa(att).unsqueeze(1) 135 | ''' 136 | 137 | att = torch.nn.functional.gumbel_softmax(torch.bmm(h_read, write_key.permute(0, 2, 1)), tau=1, hard=True) 138 | #att = att*0.0 + 0.25 139 | 140 | hnext = torch.bmm(att, hnext) 141 | 142 | hnext = hnext.mean(dim=1) 143 | hnext = hnext.reshape((bs, self.k, self.m)).reshape((bs, self.k*self.m)) 144 | #print('shapes', hnext.shape, cnext.shape) 145 | 146 | return hnext, att.data.reshape(bs,self.k,self.n_templates) 147 | 148 | 149 | class SharedBlockGRU(nn.Module): 150 | """Dynamic sharing of parameters between blocks(RIM's)""" 151 | 152 | def __init__(self, ninp, nhid, k, n_templates): 153 | super(SharedBlockGRU, self).__init__() 154 | 155 | assert ninp % k == 0 156 | assert nhid % k == 0 157 | 158 | self.k = k 159 | self.m = nhid // self.k 160 | 161 | self.n_templates = n_templates 162 | print("input to template is ", ninp//k) 163 | self.templates = nn.ModuleList([nn.GRUCell(ninp//k,self.m) for _ in range(0,self.n_templates)]) 164 | self.nhid = nhid 165 | 166 | self.ninp = ninp 167 | 168 | #self.gll_write = GroupLinearLayer(self.m,16, self.n_templates) 169 | self.gll_write = nn.Linear(self.m, 16) 170 | self.gll_read = GroupLinearLayer(self.m,16,1) 171 | 172 | 173 | self.sa = Sparse_attention(1) 174 | print("Using Gumble sparsity") 175 | 176 | def blockify_params(self): 177 | 178 | return 179 | 180 | def forward(self, input, h): 181 | 182 | #self.blockify_params() 183 | bs = h.shape[0] 184 | h = h.reshape((h.shape[0], self.k, self.m)).reshape((h.shape[0]*self.k, self.m)) 185 | 186 | input = input.reshape(input.shape[0], 1, input.shape[1]) 187 | #input = input.repeat(1,self.k,1) 188 | #input = input.reshape(input.shape[0]*self.k, input.shape[2]) 189 | input = input.reshape(input.shape[0]*self.k, -1) 190 | 191 | #print("input shape is", input.shape) 192 | 193 | h_read = self.gll_read((h*1.0).reshape((h.shape[0], 1, h.shape[1]))) 194 | 195 | 196 | hnext_stack = [] 197 | 198 | 199 | for template in self.templates: 200 | hnext_l = template(input, h) 201 | hnext_l = hnext_l.reshape((hnext_l.shape[0], 1, hnext_l.shape[1])) 202 | hnext_stack.append(hnext_l) 203 | 204 | hnext = torch.cat(hnext_stack, 1) 205 | 206 | write_key = self.gll_write(hnext) 207 | 208 | ''' 209 | sm = nn.Softmax(2) 210 | att = sm(torch.bmm(h_read, write_key.permute(0, 2, 1))).squeeze(1) 211 | att = self.sa(att).unsqueeze(1) 212 | ''' 213 | 214 | att = torch.nn.functional.gumbel_softmax(torch.bmm(h_read, write_key.permute(0, 2, 1)), tau=1, hard=True) 215 | #att = att*0.0 + 0.25 216 | 217 | hnext = torch.bmm(att, hnext) 218 | 219 | hnext = hnext.mean(dim=1) 220 | hnext = hnext.reshape((bs, self.k, self.m)).reshape((bs, self.k*self.m)) 221 | #print('shapes', hnext.shape, cnext.shape) 222 | 223 | return hnext, att.data.reshape(bs,self.k,self.n_templates) 224 | 225 | 226 | 227 | if __name__ == "__main__": 228 | 229 | Blocks = BlockGRU(2, 6, k=2) 230 | opt = torch.optim.Adam(Blocks.parameters()) 231 | 232 | pl = Blocks.gru.parameters() 233 | 234 | inp = torch.randn(100,2) 235 | h = torch.randn(100,6) 236 | 237 | h2 = Blocks(inp,h) 238 | 239 | L = h2.sum()**2 240 | 241 | #L.backward() 242 | #opt.step() 243 | #opt.zero_grad() 244 | 245 | 246 | pl = Blocks.gru.parameters() 247 | for p in pl: 248 | print(p.shape) 249 | #print(torch.Size([Blocks.nhid*4])) 250 | if p.shape == torch.Size([Blocks.nhid*3]): 251 | print(p.shape, 'a') 252 | #print(p) 253 | '''biases, don't need to change anything here''' 254 | if p.shape == 
torch.Size([Blocks.nhid*3, Blocks.nhid]) or p.shape == torch.Size([Blocks.nhid*3, Blocks.ninp]): 255 | print(p.shape, 'b') 256 | for e in range(0,4): 257 | print(p[Blocks.nhid*e : Blocks.nhid*(e+1)]) 258 | -------------------------------------------------------------------------------- /MNIST/crl/dataloader/datagen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import operator 4 | import matplotlib 5 | matplotlib.use('agg') 6 | import matplotlib.pyplot as plt 7 | import pprint 8 | 9 | from collections import OrderedDict 10 | 11 | from arithmetic import Plus, Minus, Multiply, Divide 12 | import utils 13 | import datautils as du 14 | 15 | np.random.seed(0) 16 | 17 | class ArithmeticDataGenerator(object): 18 | def __init__(self, ops, numrange): 19 | super(ArithmeticDataGenerator, self).__init__() 20 | """ 21 | should set self.encoding_length 22 | """ 23 | self.verbose = False 24 | 25 | # Mappings 26 | self.static_operator_dict = OrderedDict([ 27 | (operator.add, '+'), 28 | (operator.mul, '*'), 29 | (operator.sub, '-'), 30 | (operator.div, '/') 31 | ]) 32 | self.reverse_static_operator_dict = {v:k for (k, v) in self.static_operator_dict.items()} 33 | self.operator_dict = OrderedDict() # this needs to be orderedict because of onehot 34 | for op in ops: 35 | if op in self.reverse_static_operator_dict: 36 | self.operator_dict[self.reverse_static_operator_dict[op]] = op 37 | self.operator_names = self.operator_dict.keys() 38 | self.operators = self.operator_dict.values() 39 | 40 | # Parameters 41 | self.range = numrange 42 | assert self.range[0] == min(self.range) == 0 43 | 44 | # Stats 45 | self.op_length = len(self.operator_dict) 46 | self.range_length = None 47 | self.encoding_length = None 48 | 49 | ######################### Internal Methods ################################# 50 | 51 | def _sample_operator(self): 52 | return np.random.choice(self.operator_names) 53 | 54 | def _fold_left_ops_terms_sample(self, ops, terms): 55 | result_so_far = terms[0] 56 | for o in ops: 57 | next_term = self._sample_second_term(o, result_so_far) 58 | result_so_far = o(result_so_far, next_term) 59 | terms.append(next_term) 60 | return ops, terms, result_so_far 61 | 62 | def _fold_left_ops_terms_eval(self, ops, terms): 63 | results_so_far = [terms[0]] 64 | for i in range(len(ops)): 65 | o = ops[i] 66 | next_term = terms[i+1] 67 | result_so_far = o(results_so_far[-1], next_term) 68 | results_so_far.append(result_so_far) 69 | return results_so_far 70 | 71 | def _extract_multiplicative_groups(self, ops): 72 | # find the indices of the ops 73 | multiplicative_ops_indices = filter(lambda x: self.operator_dict[ops[x]] in '*/', range(len(ops))) 74 | multiplicative_groups = utils.group_consecutive(multiplicative_ops_indices) 75 | return multiplicative_groups 76 | 77 | def _combine_ops_terms(self, ops, terms): 78 | expression = [terms[0]] 79 | for i in range(len(ops)): 80 | expression.append(ops[i]) 81 | expression.append(terms[i+1]) 82 | return expression 83 | 84 | def _split_expression(self, expression): 85 | ops = expression[1::2] 86 | terms = expression[::2] 87 | return ops, terms 88 | 89 | ###################### Helper External Methods ############################# 90 | 91 | def get_additive_expression(self, ops, additive_ops_indices): 92 | additive_ops = [ops[i] for i in additive_ops_indices] 93 | additive_terms = [du.sample_term_in_range(self.range)] 94 | additive_ops, additive_terms, exp_val = 
self._fold_left_ops_terms_sample(additive_ops, additive_terms) 95 | if self.verbose: 96 | print 'Additive expression {} = {}'.format( 97 | du.build_expression_string(additive_terms, additive_ops, self.operator_dict), 98 | exp_val) 99 | return additive_ops, additive_terms, exp_val 100 | 101 | def create_multiplicative_expression(self): 102 | raise NotImplementedError 103 | 104 | def match_multiplicative_groups_to_additive_term(self, multiplicative_groups): 105 | multiplicative_groups_to_additive_term = {} 106 | for i, mg in enumerate(multiplicative_groups): 107 | lo_index = mg[0] 108 | if lo_index > 0: 109 | # we are matching terms that are not the first term 110 | # in this case, the term we match corresponds the 111 | # index of the additive term that corresponds to this 112 | # multiplicative group 113 | if multiplicative_groups[0][0] == 0: 114 | additive_term_to_match_index = i 115 | else: 116 | additive_term_to_match_index = i+1 117 | else: 118 | # we are matching the first term 119 | additive_term_to_match_index = 0 120 | multiplicative_groups_to_additive_term[mg] = additive_term_to_match_index 121 | return multiplicative_groups_to_additive_term 122 | 123 | def interleave_additive_multiplicative(self, additive_terms, additive_ops, 124 | additive_term_to_multiplicative_group): 125 | all_terms = [] 126 | all_ops = [] 127 | for j in xrange(len(additive_terms)): 128 | if j in additive_term_to_multiplicative_group: 129 | all_terms.extend(additive_term_to_multiplicative_group[j][1]) 130 | all_ops.extend(additive_term_to_multiplicative_group[j][0]) 131 | else: 132 | all_terms.append(additive_terms[j]) 133 | if len(additive_ops) > 0 and j < len(additive_terms)-1: 134 | all_ops.append(additive_ops[j]) 135 | return all_ops, all_terms 136 | 137 | ######################### External Methods ################################# 138 | 139 | def evaluate_expression(self, terms, ops): 140 | assert len(ops) == len(terms)-1 141 | expression = self._combine_ops_terms(ops, terms) 142 | while self._extract_multiplicative_groups(ops) != []: 143 | mg = self._extract_multiplicative_groups(ops)[0] 144 | mg = tuple([1 + 2*i for i in mg]) 145 | lo, hi = mg[0], mg[-1] 146 | multiplicative_ops = [expression[i] for i in mg] 147 | multiplicative_terms = [expression[i] for i in range(lo-1, hi+2) if i not in mg] # check this 148 | multiplicative_val = self._fold_left_ops_terms_eval(multiplicative_ops, multiplicative_terms)[-1] 149 | # replace everything from [lo-1, hi+2) with multiplicative_val 150 | expression = expression[:lo-1] + [multiplicative_val] + expression[hi+2:] 151 | ops, terms = self._split_expression(expression) 152 | # at this point, it should just be an expression of additive terms 153 | additive_ops, additive_terms = self._split_expression(expression) 154 | assert all(self.operator_dict[ao] in '+-' for ao in additive_ops) 155 | result = self._fold_left_ops_terms_eval(additive_ops, additive_terms)[-1] 156 | return result 157 | 158 | def create_problem(self, max_terms): 159 | """ 160 | all the input, intermediate, and output terms should be 161 | in self.range, inclusive 162 | 163 | expressions contain up to self.max_terms-1 operators 164 | 165 | divisions are always divisible, never divisible by 0 166 | 167 | only positive integers 168 | """ 169 | num_ops = max_terms - 1 170 | ops = [self._sample_operator() for i in xrange(num_ops)] 171 | 172 | # focus on the additive_ops_first 173 | additive_ops_indices = filter(lambda x: self.operator_dict[ops[x]] in '+-', range(len(ops))) 174 | additive_ops, 
additive_terms, exp_val = self.get_additive_expression(ops, additive_ops_indices) 175 | 176 | # match a multiplicative expression for each additive term 177 | multiplicative_groups = self._extract_multiplicative_groups(ops) 178 | additive_term_to_multiplicative_group = self.create_multiplicative_expression( 179 | ops, multiplicative_groups, additive_terms) 180 | 181 | # put everything together 182 | all_ops, all_terms = self.interleave_additive_multiplicative(additive_terms, additive_ops, additive_term_to_multiplicative_group) 183 | exp_str = du.build_expression_string(all_terms, all_ops, self.operator_dict) 184 | if self.verbose: print 'Final Expression: {} = {}'.format(exp_str, exp_val) 185 | return exp_str, exp_val, all_terms, all_ops 186 | 187 | def encode_problem(self, exp_str, exp_val, terms, ops): 188 | raise NotImplementedError -------------------------------------------------------------------------------- /MNIST/argument_parser.py: -------------------------------------------------------------------------------- 1 | """Script to parse all the command-line arguments""" 2 | import argparse 3 | 4 | 5 | def str2bool(v): 6 | """Method to map string to bool for argument parser""" 7 | if isinstance(v, bool): 8 | return v 9 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 10 | return True 11 | if v.lower() in ('no', 'false', 'f', 'n', '0'): 12 | return False 13 | raise argparse.ArgumentTypeError('Boolean value expected.') 14 | 15 | 16 | def argument_parser(): 17 | """Function to parse all the arguments""" 18 | 19 | parser = argparse.ArgumentParser(description='Block Model') 20 | parser.add_argument('--batch_size', type=int, default=50, metavar='N', 21 | help='ADD') 22 | parser.add_argument('--epochs', type=int, default=100, metavar='E', 23 | help='ADD') 24 | parser.add_argument('--inp_heads', type=int, default=1, metavar='E', help='num of heads in input attention') 25 | parser.add_argument('--sequence_length', type=int, default=51, metavar='S', 26 | help='ADD') 27 | parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', help='ADD') 28 | parser.add_argument('--dropout', type=float, default=0.5, metavar='dropout', help='dropout') 29 | parser.add_argument('--kl_coeff', type=float, default=0.0, metavar='KL_coeff', 30 | help='KL_coeff') 31 | parser.add_argument('--num_blocks', type=int, default=6, metavar='num_blocks', 32 | help='Number_of_blocks') 33 | parser.add_argument('--num_encoders', type=int, default=1, metavar='num_encoders', 34 | help='Number of encoders ') 35 | parser.add_argument('--topk', type=int, default=4, metavar='topk', 36 | help='Number_of_topk_blocks') 37 | parser.add_argument('--memorytopk', type=int, default=4, metavar='memtopk', 38 | help='Number_of_topk_blocks') 39 | 40 | parser.add_argument('--hidden_size', type=int, default=600, metavar='hsize', 41 | help='hidden_size') 42 | parser.add_argument('--n_templates', type=int, default=0, metavar='shared_blocks', 43 | help='num_templates') 44 | parser.add_argument('--num_modules_read_input', type=int, default=4, metavar='sort of proxy to inp heads') 45 | parser.add_argument('--share_inp', type=str2bool, default=False, metavar='share inp rims parameters') 46 | parser.add_argument('--share_comm', type=str2bool, default=False, metavar='share comm rims parameters') 47 | 48 | 49 | parser.add_argument('--do_rel', type=str2bool, default=False, metavar='use relational memory or not?') 50 | parser.add_argument('--memory_slots', type=int, default=4, metavar='memory slots for rel memory') 51 | 
parser.add_argument('--memory_mlp', type=int, default=4, metavar='no of memory mlp for rel memory') 52 | parser.add_argument('--num_memory_heads', type=int, default=4, metavar='memory heads for rel memory') 53 | parser.add_argument('--memory_head_size', type=int, default=16, metavar='memory head size for rel memory') 54 | 55 | parser.add_argument('--attention_out', type=int, default=340, help='ADD') 56 | 57 | parser.add_argument('--id', type=str, default='default', 58 | metavar='id of the experiment', help='id of the experiment') 59 | parser.add_argument('--algo', type=str, default='Rules', 60 | metavar='algorithm of the experiment', help='one of LSTM,GRU or RIM, SCOFF') 61 | parser.add_argument('--num_rules', type = int, default = 0) 62 | parser.add_argument('--rule_time_steps', type = int, default = 0) 63 | parser.add_argument('--model_persist_frequency', type=int, default=10, 64 | metavar='Frequency at which the model is persisted', 65 | help='Number of training epochs after which model is to ' 66 | 'be persisted. -1 means that the model is not' 67 | 'persisted') 68 | parser.add_argument('--batch_frequency_to_log_heatmaps', type=int, default=-1, 69 | metavar='Frequency at which the heatmaps are persisted', 70 | help='Number of training batches after which we will persit the ' 71 | 'heatmaps. -1 means that the heatmap will not be' 72 | 'persisted') 73 | parser.add_argument('--path_to_load_model', type=str, default="", 74 | metavar='Relative Path to load the model', 75 | help='Relative Path to load the model. If this is empty, no model' 76 | 'is loaded.') 77 | parser.add_argument('--components_to_load', type=str, default="", 78 | metavar='_ seperated list of model components that are ' 79 | 'to be loaded.', 80 | help='_ (underscore) seperated list of model components ' 81 | 'that are to be loaded. Possible components ' 82 | 'are blocks, encoders, decoders and rules - ' 83 | 'eg blocks, blocks_rules, rules_blocks, rules,' 84 | 'rules_encoders or encoders etc', ) 85 | 86 | parser.add_argument('--train_dataset', type=str, default="balls4mass64.h5", 87 | metavar='path to dataset on which the model should be ' 88 | 'trained', 89 | help='path to dataset on which the model should be ' 90 | 'trained') 91 | parser.add_argument('--test_dataset', type=str, 92 | metavar='path to dataset on which the model should be ' 93 | 'tested for stove', 94 | help='path to dataset on which the model should be ' 95 | 'tested for stove') 96 | 97 | parser.add_argument('--transfer_dataset', type=str, default="balls678mass64.h5", 98 | metavar='path to dataset on which the model should be ' 99 | 'transfered', 100 | help='path to dataset on which the model should be ' 101 | 'transfered') 102 | 103 | parser.add_argument('--should_save_csv', type=str2bool, nargs='?', 104 | const=True, default=True, 105 | metavar='Flag to decide if the csv logs should be created. ' 106 | 'It is useful as creating csv logs makes a lot of' 107 | 'files.', 108 | help='Flag to decide if the csv logs should be created. ' 109 | 'It is useful as creating csv logs makes a lot of' 110 | 'files.') 111 | 112 | parser.add_argument('--should_resume', type=str2bool, nargs='?', 113 | const=True, default=False, 114 | metavar='Flag to decide if the previous experiment should be ' 115 | 'resumd. If this flag is set, the last saved model ' 116 | '(corresponding to the given id is fetched)', 117 | help='Flag to decide if the previous experiment should be ' 118 | 'resumd. 
If this flag is set, the last saved model ' 119 | '(corresponding to the given id is fetched)',) 120 | 121 | parser.add_argument('--something', type=str, default='4Balls') 122 | parser.add_argument('--version', type=int, default=1) 123 | parser.add_argument('--do_comm', type=str2bool, default=True) 124 | parser.add_argument('--perm_inv', type=str2bool, default=True) 125 | parser.add_argument('--rule_selection', type=str, default = 'gumble') 126 | parser.add_argument('--application_option', type = str, default = '3.0.1.1') 127 | parser.add_argument('--use_entropy', type = str2bool, default = True) 128 | parser.add_argument('--num_transforms', type = int, default = 8) 129 | parser.add_argument('--transform_length', type = int, default = 3) 130 | parser.add_argument('--rule_dim', type = int , default = 32) 131 | parser.add_argument('--seed', type = int, default = 0) 132 | parser.add_argument('--share_key_value', type = str2bool, default = False) 133 | parser.add_argument('--color', type = str2bool, default = False) 134 | 135 | args = parser.parse_args() 136 | 137 | args.frame_frequency_to_log_heatmaps = 5 138 | 139 | args.folder_log = f"logs/{args.id}" 140 | 141 | #if args.num_encoders != 1: 142 | # args.num_encoders = args.num_blocks 143 | 144 | return args 145 | --------------------------------------------------------------------------------