├── teaser.mp4 ├── architecture.png ├── requirements.txt ├── LICENSE.md ├── README.md ├── graph_models.py ├── graph_dataset.py ├── models.py ├── graph_train.py ├── dataset.py └── train.py /teaser.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yilundu/irem_code_release/HEAD/teaser.mp4 -------------------------------------------------------------------------------- /architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yilundu/irem_code_release/HEAD/architecture.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.9 2 | imageio==2.8.0 3 | ipython==8.2.0 4 | matplotlib==3.2.2 5 | numpy==1.21.2 6 | opencv_python==4.5.3.56 7 | Pillow==9.1.0 8 | scikit_image==0.16.2 9 | scipy==1.7.3 10 | seaborn==0.11.0 11 | torch==1.2.0 12 | torch_geometric==2.0.4 13 | torchvision==0.2.2 14 | tqdm==4.60.0 15 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Iterative Reasoning through Energy Minimization Authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Iterative Reasoning through Energy Minimization 2 | 3 | Yilun Du, Shuang Li, Joshua B. Tenenbaum, Igor Mordatch 4 | 5 | A link to our paper can be found on [arXiv](https://arxiv.org/abs/2206.15448). 6 | 7 | ## Overview 8 | 9 | Official codebase for [Learning Iterative Reasoning through Energy Minimization](). 10 | Contains scripts to reproduce experiments. 11 | 12 |

13 | 14 |

15 | 16 | ## Instructions 17 | 18 | Please install the dependencies listed in `requirements.txt`: 19 | 20 | ``` 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | We provide code for running continuous graph reasoning experiments in `graph_train.py` and code for running continuous matrix experiments in `train.py`. 25 | 26 | 27 | To run continuous matrix addition experiments, you may utilize the following command: 28 | 29 | ``` 30 | python train.py --exp=addition_experiment --train --num_steps=10 --dataset=addition --cuda --infinite 31 | ``` 32 | 33 | To evaluate the final performance of the model after training, you may use the command: 34 | 35 | ``` 36 | python train.py --exp=addition_experiment --num_steps=10 --dataset=addition --cuda --infinite --resume_iter=10000 37 | ``` 38 | 39 | and the following command for the out-of-distribution (OOD) test set: 40 | 41 | ``` 42 | python train.py --exp=addition_experiment --num_steps=10 --dataset=addition --cuda --infinite --resume_iter=10000 --ood 43 | ``` 44 | 45 | You may replace the value of the `--dataset` flag with other keywords such as `inverse` or `lowrank` (as well as additional ones defined in `dataset.py`). 46 | 47 | 48 | 49 | To run discrete graph reasoning experiments, you may utilize the following command: 50 | 51 | ``` 52 | python graph_train.py --exp=identity_experiment --train --num_steps=10 --dataset=identity --cuda --infinite 53 | ``` 54 | 55 | You may replace the value of the `--dataset` flag with other datasets such as `shortestpath` or `connected` (as well as additional ones defined in `graph_dataset.py`). 56 | 57 | 58 | To evaluate the model, you may then utilize the following command: 59 | 60 | ``` 61 | python graph_train.py --exp=identity_experiment --num_steps=10 --dataset=identity --cuda --infinite --resume_iter=10000 62 | ``` 63 | 64 | ## Citation 65 | 66 | Please cite our paper as: 67 | 68 | ``` 69 | @inproceedings{du2022irem, 70 | title={Learning Iterative Reasoning through Energy Minimization}, 71 | author={Yilun Du and Shuang Li and Joshua B. 
Tenenbaum and Igor Mordatch}, 72 | booktitle={Proceedings of the 39th International Conference on Machine 73 | Learning (ICML-22)}, 74 | year={2022} 75 | } 76 | ``` 77 | 78 | Note: this is not an official Google product. 79 | 80 | ## License 81 | 82 | MIT 83 | -------------------------------------------------------------------------------- /graph_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import GINEConv, global_max_pool 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class GraphEBM(nn.Module): 8 | def __init__(self, inp_dim, out_dim, mem): 9 | super(GraphEBM, self).__init__() 10 | h = 128 11 | 12 | self.edge_map = nn.Linear(1, h // 2) 13 | self.edge_map_opt = nn.Linear(1, h // 2) 14 | 15 | self.conv1 = GINEConv(nn.Sequential(nn.Linear(1, 128)), edge_dim=h) 16 | self.conv2 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h) 17 | self.conv3 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h) 18 | 19 | self.fc1 = nn.Linear(128, 128) 20 | self.fc2 = nn.Linear(128, 1) 21 | 22 | def forward(self, inp, opt_edge): 23 | 24 | edge_embed = self.edge_map(inp.edge_attr) 25 | opt_edge_embed = self.edge_map_opt(opt_edge) 26 | 27 | edge_embed = torch.cat([edge_embed, opt_edge_embed], dim=-1) 28 | 29 | h = self.conv1(inp.x, inp.edge_index, edge_attr=edge_embed) 30 | h = self.conv2(h, inp.edge_index, edge_attr=edge_embed) 31 | h = self.conv3(h, inp.edge_index, edge_attr=edge_embed) 32 | 33 | mean_feat = global_max_pool(h, inp.batch) 34 | energy = self.fc2(F.relu(self.fc1(mean_feat))) 35 | 36 | return energy 37 | 38 | 39 | class GraphIterative(nn.Module): 40 | def __init__(self, inp_dim, out_dim, mem): 41 | super(GraphIterative, self).__init__() 42 | h = 128 43 | 44 | self.edge_map = nn.Linear(1, h // 2) 45 | self.edge_map_opt = nn.Linear(1, h // 2) 46 | 47 | self.conv1 = GINEConv(nn.Sequential(nn.Linear(1, 128)), edge_dim=h) 48 | self.conv2 = 
GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h) 49 | self.conv3 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h) 50 | 51 | self.decode = nn.Linear(256, 1) 52 | 53 | def forward(self, inp, opt_edge): 54 | 55 | edge_embed = self.edge_map(inp.edge_attr) 56 | opt_edge_embed = self.edge_map_opt(opt_edge) 57 | 58 | edge_embed = torch.cat([edge_embed, opt_edge_embed], dim=-1) 59 | 60 | h = self.conv1(inp.x, inp.edge_index, edge_attr=edge_embed) 61 | h = self.conv2(h, inp.edge_index, edge_attr=edge_embed) 62 | h = self.conv3(h, inp.edge_index, edge_attr=edge_embed) 63 | 64 | edge_index = inp.edge_index 65 | hidden = h 66 | 67 | h1 = hidden[edge_index[0]] 68 | h2 = hidden[edge_index[1]] 69 | 70 | h = torch.cat([h1, h2], dim=-1) 71 | output = self.decode(h) 72 | 73 | return output 74 | 75 | 76 | class GraphRecurrent(nn.Module): 77 | def __init__(self, inp_dim, out_dim): 78 | super(GraphRecurrent, self).__init__() 79 | h = 128 80 | 81 | self.edge_map = nn.Linear(1, h) 82 | 83 | self.conv1 = GINEConv(nn.Sequential(nn.Linear(1, 128)), edge_dim=h) 84 | 85 | self.lstm = nn.LSTM(h, h) 86 | # self.conv2 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h) 87 | self.conv3 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h) 88 | 89 | self.decode = nn.Linear(256, 1) 90 | 91 | def forward(self, inp, state=None): 92 | 93 | edge_embed = self.edge_map(inp.edge_attr) 94 | 95 | h = self.conv1(inp.x, inp.edge_index, edge_attr=edge_embed) 96 | 97 | h, state = self.lstm(h[None], state) 98 | h = self.conv3(h[0], inp.edge_index, edge_attr=edge_embed) 99 | 100 | edge_index = inp.edge_index 101 | hidden = h 102 | 103 | h1 = hidden[edge_index[0]] 104 | h2 = hidden[edge_index[1]] 105 | 106 | h = torch.cat([h1, h2], dim=-1) 107 | output = self.decode(h) 108 | 109 | return output, state 110 | 111 | 112 | class GraphPonder(nn.Module): 113 | def __init__(self, inp_dim, out_dim): 114 | super(GraphPonder, self).__init__() 115 | h = 128 116 | 117 | self.edge_map = 
nn.Linear(1, h) 118 | 119 | self.conv1 = GINEConv(nn.Sequential(nn.Linear(1, 128)), edge_dim=h) 120 | self.conv2 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h) 121 | self.conv3 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h) 122 | 123 | self.decode = nn.Linear(256, 1) 124 | 125 | def forward(self, inp, iters=1): 126 | 127 | edge_embed = self.edge_map(inp.edge_attr) 128 | 129 | h = self.conv1(inp.x, inp.edge_index, edge_attr=edge_embed) 130 | outputs = [] 131 | 132 | for i in range(iters): 133 | h = F.relu(self.conv2(h, inp.edge_index, edge_attr=edge_embed)) 134 | 135 | output = self.conv3(h, inp.edge_index, edge_attr=edge_embed) 136 | 137 | edge_index = inp.edge_index 138 | hidden = output 139 | 140 | h1 = hidden[edge_index[0]] 141 | h2 = hidden[edge_index[1]] 142 | 143 | output = torch.cat([h1, h2], dim=-1) 144 | output = self.decode(output) 145 | 146 | outputs.append(output) 147 | 148 | return outputs 149 | 150 | 151 | 152 | class GraphFC(nn.Module): 153 | def __init__(self, inp_dim, out_dim): 154 | super(GraphFC, self).__init__() 155 | h = 128 156 | 157 | self.edge_map = nn.Linear(1, h) 158 | self.conv1 = GINEConv(nn.Sequential(nn.Linear(1, 128)), edge_dim=h) 159 | self.conv2 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h) 160 | self.conv3 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h) 161 | 162 | self.decode = nn.Linear(256, 1) 163 | 164 | 165 | def forward(self, inp): 166 | 167 | edge_embed = self.edge_map(inp.edge_attr) 168 | 169 | h = self.conv1(inp.x, inp.edge_index, edge_attr=edge_embed) 170 | h = self.conv2(h, inp.edge_index, edge_attr=edge_embed) 171 | h = self.conv3(h, inp.edge_index, edge_attr=edge_embed) 172 | 173 | edge_index = inp.edge_index 174 | hidden = h 175 | 176 | h1 = hidden[edge_index[0]] 177 | h2 = hidden[edge_index[1]] 178 | 179 | h = torch.cat([h1, h2], dim=-1) 180 | output = self.decode(h) 181 | 182 | return output 183 | 
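The edge decoder shared by `GraphIterative`, `GraphRecurrent`, `GraphPonder`, and `GraphFC` above gathers the two endpoint embeddings of every edge, concatenates them, and applies a linear layer (hence `nn.Linear(256, 1)` for 128-dimensional node embeddings). The following is a minimal, dependency-free sketch of that gather-and-concatenate step. Plain Python lists stand in for tensors, and the helper names `gather_edge_features` and `decode_edges` are hypothetical; they are not part of the repository.

```python
def gather_edge_features(hidden, edge_index):
    """Concatenate source and target node embeddings for every edge.

    hidden: list of per-node embedding lists.
    edge_index: [sources, targets], two equal-length lists of node ids,
    mirroring the (2, num_edges) layout used in graph_models.py.
    """
    sources, targets = edge_index
    # List concatenation plays the role of torch.cat([h1, h2], dim=-1).
    return [hidden[s] + hidden[t] for s, t in zip(sources, targets)]


def decode_edges(hidden, edge_index, weight, bias=0.0):
    """Apply one linear unit to each concatenated edge feature (assumed decoder)."""
    features = gather_edge_features(hidden, edge_index)
    return [sum(w * f for w, f in zip(weight, feat)) + bias for feat in features]


# Toy example: two nodes with 2-dimensional embeddings and two directed edges.
hidden = [[1.0, 2.0], [3.0, 4.0]]
edge_index = [[0, 1], [1, 0]]
weight = [1.0, 0.0, 0.0, 0.5]  # one weight per concatenated feature
edge_outputs = decode_edges(hidden, edge_index, weight)
```

Because the concatenation is ordered (source first, target second), the decoder can produce different values for edges `(i, j)` and `(j, i)`.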
-------------------------------------------------------------------------------- /graph_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from torch_geometric.data import Data 3 | import torch 4 | import numpy as np 5 | from scipy.sparse.csgraph import connected_components 6 | from scipy.sparse import csr_matrix 7 | from scipy.sparse.csgraph import shortest_path 8 | import random 9 | 10 | 11 | def extract(a, t, x_shape): 12 | b, *_ = t.shape 13 | out = a.gather(-1, t) 14 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 15 | 16 | 17 | def cosine_beta_schedule(timesteps, s = 0.008): 18 | """ 19 | cosine schedule 20 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 21 | """ 22 | steps = timesteps + 1 23 | x = np.linspace(0, steps, steps) 24 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 25 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 26 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 27 | return np.clip(betas, a_min = 0, a_max = 0.999) 28 | 29 | 30 | class NoisyWrapper: 31 | 32 | def __init__(self, dataset, timesteps): 33 | 34 | self.dataset = dataset 35 | self.timesteps = timesteps 36 | betas = cosine_beta_schedule(timesteps) 37 | self.inp_dim = dataset.inp_dim 38 | self.out_dim = dataset.out_dim 39 | 40 | alphas = 1. - betas 41 | alphas_cumprod = np.cumprod(alphas, axis=0) 42 | 43 | # self.sqrt_alphas_cumprod = torch.tensor(np.sqrt(alphas_cumprod)) 44 | # self.sqrt_one_minus_alphas_cumprod = torch.tensor(np.sqrt(1. - alphas_cumprod)) 45 | 46 | alphas_cumprod = np.linspace(1, 0, timesteps) 47 | self.sqrt_alphas_cumprod = torch.tensor(np.sqrt(alphas_cumprod)) 48 | self.sqrt_one_minus_alphas_cumprod = torch.tensor(np.sqrt(1. 
- alphas_cumprod)) 49 | self.extract = extract 50 | 51 | def __len__(self): 52 | return len(self.dataset) 53 | 54 | def __getitem__(self, *args, **kwargs): 55 | data = self.dataset.__getitem__(*args, **kwargs) 56 | y = data['y'] 57 | 58 | t = torch.randint(1, self.timesteps, (1,)).long() 59 | t_next = t - 1 60 | noise = torch.randn_like(y) 61 | 62 | sample = ( 63 | self.extract(self.sqrt_alphas_cumprod, t, y.shape) * y + 64 | self.extract(self.sqrt_one_minus_alphas_cumprod, t, y.shape) * noise 65 | ) 66 | 67 | sample_next = ( 68 | self.extract(self.sqrt_alphas_cumprod, t_next, y.shape) * y + 69 | self.extract(self.sqrt_one_minus_alphas_cumprod, t_next, y.shape) * noise 70 | ) 71 | 72 | data['y_prev'] = sample.float() 73 | data['y'] = sample_next.float() 74 | 75 | return data 76 | 77 | 78 | class Identity(data.Dataset): 79 | """Randomly generated identity task on a complete graph with `rank` nodes: the edge labels equal the edge inputs.""" 80 | 81 | def __init__(self, split, rank, vary=False): 82 | """Initialize this dataset class. 83 | 84 | Parameters: 85 | split (str) -- name of the dataset split; rank (int) -- number of graph nodes (the maximum when vary is set); vary (bool) -- if True, sample the number of nodes from [2, rank] per item 86 | """ 87 | self.h = rank 88 | self.w = rank 89 | self.vary = vary 90 | self.rank = rank 91 | 92 | self.split = split 93 | self.inp_dim = self.h * self.w 94 | self.out_dim = self.h * self.w 95 | 96 | def __getitem__(self, index): 97 | """Return a data point and its metadata information. 
98 | 99 | Parameters: 100 | index - - a random integer for data indexing 101 | 102 | Returns a dictionary that contains A and A_paths 103 | A(tensor) - - an image in one domain 104 | A_paths(str) - - the path of the image 105 | """ 106 | 107 | if self.vary: 108 | rank = random.randint(2, self.rank) 109 | else: 110 | rank = self.rank 111 | 112 | R = np.random.uniform(-1, 1, (rank, rank)) 113 | R_corrupt = R 114 | 115 | repeat = 128 // self.w + 1 116 | 117 | R_tile = np.tile(R, (1, repeat)) 118 | 119 | node_features = torch.Tensor(np.zeros((rank, 1))) 120 | edge_index = torch.LongTensor(np.array(np.mgrid[:rank, :rank]).reshape((2, -1))) 121 | edge_features_context = torch.Tensor(np.tile(R_corrupt.reshape((-1, 1)), (1, 1))) 122 | edge_features_label = torch.Tensor(R.reshape((-1, 1))) 123 | noise = torch.Tensor(edge_features_label.reshape((-1, 1))) 124 | 125 | data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features_context, y=edge_features_label, noise=noise) 126 | 127 | return data 128 | 129 | def __len__(self): 130 | """Return the total number of images in the dataset.""" 131 | # Dataset is always randomly generated 132 | return int(1e6) 133 | 134 | 135 | 136 | class ConnectedComponents(data.Dataset): 137 | """Constructs a dataset with N circles, N is the number of components set by flags""" 138 | 139 | def __init__(self, split, rank, vary=False): 140 | """Initialize this dataset class. 141 | 142 | Parameters: 143 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 144 | """ 145 | self.h = rank 146 | self.w = rank 147 | self.rank = rank 148 | self.vary = vary 149 | 150 | self.split = split 151 | self.inp_dim = self.h * self.w 152 | self.out_dim = self.h * self.w 153 | 154 | def __getitem__(self, index): 155 | """Return a data point and its metadata information. 
156 | 157 | Parameters: 158 | index - - a random integer for data indexing 159 | 160 | Returns a dictionary that contains A and A_paths 161 | A(tensor) - - an image in one domain 162 | A_paths(str) - - the path of the image 163 | """ 164 | 165 | if self.vary: 166 | rank = random.randint(2, self.rank) 167 | else: 168 | rank = self.rank 169 | 170 | R = np.random.uniform(0, 1, (rank, rank)) 171 | connections = R > 0.95 172 | components, component_arr = connected_components(connections, directed=False) 173 | 174 | label = np.zeros((rank, rank)) 175 | 176 | for i in range(components): 177 | component_i = np.arange(rank)[component_arr == i] 178 | 179 | for j in range(component_i.shape[0]): 180 | for k in range(component_i.shape[0]): 181 | label[component_i[j], component_i[k]] = 1 182 | 183 | node_features = torch.Tensor(np.zeros((rank, 1))) 184 | edge_index = torch.LongTensor(np.array(np.mgrid[:rank, :rank]).reshape((2, -1))) 185 | 186 | edge_features_context = torch.Tensor(np.tile(connections.reshape((-1, 1)), (1, 1))) 187 | edge_features_label = torch.Tensor(label.reshape((-1, 1))) 188 | noise = torch.Tensor(label.reshape((-1, 1))) 189 | 190 | data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features_context, y=edge_features_label, noise=noise) 191 | 192 | return data 193 | 194 | def __len__(self): 195 | """Return the total number of images in the dataset.""" 196 | # Dataset is always randomly generated 197 | return int(1e6) 198 | 199 | 200 | class ShortestPath(data.Dataset): 201 | """Constructs a dataset with N circles, N is the number of components set by flags""" 202 | 203 | def __init__(self, split, rank, vary=False): 204 | """Initialize this dataset class. 
205 | 206 | Parameters: 207 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 208 | """ 209 | self.h = rank 210 | self.w = rank 211 | self.vary = vary 212 | self.rank = rank 213 | 214 | self.split = split 215 | self.inp_dim = self.h * self.w 216 | self.out_dim = self.h * self.w 217 | 218 | def __getitem__(self, index): 219 | """Return a data point and its metadata information. 220 | 221 | Parameters: 222 | index - - a random integer for data indexing 223 | 224 | Returns a dictionary that contains A and A_paths 225 | A(tensor) - - an image in one domain 226 | A_paths(str) - - the path of the image 227 | """ 228 | if self.vary: 229 | rank = random.randint(2, self.rank) 230 | else: 231 | rank = self.rank 232 | 233 | graph = np.random.uniform(0, 1, size=[rank, rank]) 234 | graph = graph + graph.transpose() 235 | np.fill_diagonal(graph, 0) # can move to self 236 | 237 | graph_dist, graph_predecessors = shortest_path(csgraph=csr_matrix(graph), unweighted=False, directed=False, return_predecessors=True) 238 | 239 | node_features = torch.Tensor(np.zeros((rank, 1))) 240 | edge_index = torch.LongTensor(np.array(np.mgrid[:rank, :rank]).reshape((2, -1))) 241 | 242 | edge_features_context = torch.Tensor(np.tile(graph.reshape((-1, 1)), (1, 1))) 243 | edge_features_label = torch.Tensor(graph_dist.reshape((-1, 1))) 244 | noise = torch.Tensor(edge_features_label.reshape((-1, 1))) 245 | 246 | data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features_context, y=edge_features_label, noise=noise) 247 | 248 | return data 249 | 250 | def __len__(self): 251 | """Return the total number of images in the dataset.""" 252 | # Dataset is always randomly generated 253 | return int(1e6) 254 | 255 | 256 | if __name__ == "__main__": 257 | dataset = ConnectedComponents('train', 10) 258 | data = dataset[0] 259 | import pdb 260 | pdb.set_trace() 261 | print(data) 262 | 
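The datasets in `graph_dataset.py` all use a dense edge index (every ordered node pair, in row-major order, matching `np.mgrid[:rank, :rank].reshape((2, -1))`), and `ConnectedComponents` labels an edge `(i, j)` with 1 when both nodes lie in the same component. The sketch below reproduces both ideas without NumPy or SciPy so it can be read in isolation. It swaps SciPy's `connected_components` for a small union-find, and the function names (`dense_edge_index`, `component_labels`, `flatten_row_major`) are hypothetical, chosen only for illustration.

```python
from itertools import product


def dense_edge_index(rank):
    """All ordered node pairs (i, j) in row-major order, as [sources, targets]."""
    pairs = list(product(range(rank), repeat=2))
    return [[i for i, _ in pairs], [j for _, j in pairs]]


def component_labels(adjacency):
    """Return an n x n matrix with 1.0 where nodes i and j share a component.

    Edges are treated as undirected, as with directed=False in the source.
    A union-find stands in for scipy.sparse.csgraph.connected_components.
    """
    n = len(adjacency)
    parent = list(range(n))

    def find(i):
        # Follow parent pointers to the root, halving the path on the way.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i, j in product(range(n), repeat=2):
        if adjacency[i][j]:
            parent[find(i)] = find(j)

    roots = [find(i) for i in range(n)]
    return [[1.0 if roots[i] == roots[j] else 0.0 for j in range(n)] for i in range(n)]


def flatten_row_major(matrix):
    """Flatten to one single-element feature per edge, like reshape((-1, 1))."""
    return [[value] for row in matrix for value in row]


# Toy graph with 4 nodes and a single connection between nodes 0 and 1.
adjacency = [
    [False, True, False, False],
    [False, False, False, False],
    [False, False, False, False],
    [False, False, False, False],
]
edge_index = dense_edge_index(4)
labels = component_labels(adjacency)
edge_labels = flatten_row_major(labels)
```

With this layout, edge `k` connects node `k // rank` to node `k % rank`, so the flattened labels line up with the columns of `edge_index`.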
-------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from torch.nn import ModuleList 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import torch 5 | from easydict import EasyDict 6 | import numpy as np 7 | from torch.nn import TransformerEncoderLayer 8 | 9 | 10 | def swish(x): 11 | return x * torch.sigmoid(x) 12 | 13 | 14 | class View(nn.Module): 15 | def __init__(self, size): 16 | super(View, self).__init__() 17 | self.size = size 18 | 19 | def forward(self, tensor): 20 | return tensor.view(self.size) 21 | 22 | 23 | def kaiming_init(m): 24 | if isinstance(m, (nn.Linear, nn.Conv2d)): 25 | torch.nn.init.kaiming_normal_(m.weight) # in-place variant; kaiming_normal is deprecated 26 | if m.bias is not None: 27 | m.bias.data.fill_(0) 28 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): 29 | m.weight.data.fill_(1) 30 | if m.bias is not None: 31 | m.bias.data.fill_(0) 32 | 33 | 34 | def reparametrize(mu, logvar): 35 | std = logvar.div(2).exp() 36 | eps = torch.randn_like(std) # standard normal noise with the shape, dtype, and device of std 37 | return mu + std * eps 38 | 39 | 40 | class EBM(nn.Module): 41 | def __init__(self, inp_dim, out_dim, mem): 42 | super(EBM, self).__init__() 43 | h = 512 44 | 45 | if mem: 46 | self.fc1 = nn.Linear(inp_dim + out_dim + h, h) 47 | else: 48 | self.fc1 = nn.Linear(inp_dim + out_dim, h) 49 | 50 | self.mem = mem 51 | self.fc2 = nn.Linear(h, h) 52 | self.fc3 = nn.Linear(h, h) 53 | self.fc4 = nn.Linear(h, 1) 54 | 55 | self.inp_dim = inp_dim 56 | self.out_dim = out_dim 57 | 58 | def forward(self, x): 59 | h = swish(self.fc1(x)) 60 | h = swish(self.fc2(h)) 61 | h = swish(self.fc3(h)) 62 | 63 | output = self.fc4(h) 64 | 65 | return output 66 | 67 | 68 | class EBMTwin(nn.Module): 69 | def __init__(self, inp_dim, out_dim, mem): 70 | super(EBMTwin, self).__init__() 71 | h = 512 72 | self.inp_dim = inp_dim 73 | self.out_dim = out_dim 74 | 75 | self.first_fc1 = 
nn.Linear(inp_dim, h) 76 | self.second_fc1 = nn.Linear(out_dim, h) 77 | 78 | self.first_fc2 = nn.Linear(h, h) 79 | self.second_fc2 = nn.Linear(h, h) 80 | 81 | self.first_fc3 = nn.Linear(h, h) 82 | self.second_fc3 = nn.Linear(h, h) 83 | 84 | self.prod_h1 = nn.Linear(2 * h, h) 85 | self.prod_h2 = nn.Linear(2 * h, h) 86 | 87 | self.scale_parameter = torch.nn.Parameter(torch.ones(1)) 88 | 89 | def forward(self, x): 90 | x_out = x[..., :self.out_dim] 91 | x = x[..., self.out_dim:] 92 | 93 | first_h1 = self.first_fc1(x) 94 | second_h1 = self.second_fc1(x_out) 95 | 96 | first_h2 = self.first_fc2(swish(first_h1)) 97 | second_h2 = self.second_fc2(swish(second_h1)) 98 | 99 | first_h3 = self.first_fc3(swish(first_h2)) 100 | second_h3 = self.second_fc3(swish(second_h2)) 101 | 102 | h = torch.cat([first_h2, second_h2], dim=-1) 103 | 104 | energy_3 = torch.abs(first_h3 - second_h3).mean(dim=-1) 105 | 106 | energy = energy_3 107 | 108 | return energy 109 | 110 | 111 | class FC(nn.Module): 112 | def __init__(self, inp_dim, out_dim): 113 | super(FC, self).__init__() 114 | h = 512 115 | 116 | self.fc1 = nn.Linear(inp_dim, h) 117 | 118 | self.fc2 = nn.Linear(h, h) 119 | self.fc3 = nn.Linear(h, h) 120 | self.fc4 = nn.Linear(h, out_dim) 121 | 122 | self.inp_dim = inp_dim 123 | self.out_dim = out_dim 124 | 125 | def forward(self, x): 126 | h = F.relu(self.fc1(x)) 127 | h = F.relu(self.fc2(h)) 128 | h = F.relu(self.fc3(h)) 129 | output = self.fc4(h) 130 | 131 | return output 132 | 133 | 134 | class RecurrentFC(nn.Module): 135 | def __init__(self, inp_dim, out_dim): 136 | super(RecurrentFC, self).__init__() 137 | h = 196 138 | 139 | self.inp_map = nn.Linear(inp_dim, h) 140 | self.inp_map2 = nn.Linear(inp_dim, h) 141 | self.fc1 = nn.Linear(inp_dim, h) 142 | self.lstm = nn.LSTM(h, h) 143 | self.fc3 = nn.Linear(h, h) 144 | self.fc4 = nn.Linear(h, out_dim) 145 | self.scale_parameter = torch.nn.Parameter(torch.ones(1)) 146 | 147 | self.inp_dim = inp_dim 148 | self.out_dim = out_dim 149 | 150 | 
def forward(self, x, state=None): 151 | 152 | h = F.relu(self.fc1(x)) 153 | h, state = self.lstm(h[None], state) 154 | output = self.fc4(h[0]) 155 | 156 | return output, state 157 | 158 | 159 | class Embedder: 160 | def __init__(self, **kwargs): 161 | self.kwargs = kwargs 162 | self.create_embedding_fn() 163 | 164 | def create_embedding_fn(self): 165 | embed_fns = [] 166 | d = self.kwargs['input_dims'] 167 | out_dim = 0 168 | if self.kwargs['include_input']: 169 | embed_fns.append(lambda x: x) 170 | out_dim += d 171 | 172 | max_freq = self.kwargs['max_freq_log2'] 173 | N_freqs = self.kwargs['num_freqs'] 174 | 175 | if self.kwargs['random_feature']: 176 | freq_bands = 2.0**torch.linspace(0., max_freq, steps=N_freqs) 177 | elif self.kwargs['log_sampling']: 178 | freq_bands = 2.0**torch.linspace(0., max_freq, steps=N_freqs) 179 | else: 180 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 181 | 182 | for freq in freq_bands: 183 | for p_fn in self.kwargs['periodic_fns']: 184 | embed_fns.append( 185 | lambda x, 186 | p_fn=p_fn, 187 | freq=freq: p_fn( 188 | x * freq)) 189 | out_dim += d 190 | 191 | self.embed_fns = embed_fns 192 | self.out_dim = out_dim 193 | 194 | def embed(self, inputs): 195 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 196 | 197 | 198 | class IterativeFC(nn.Module): 199 | def __init__(self, inp_dim, out_dim, *args): 200 | super(IterativeFC, self).__init__() 201 | h = 512 202 | 203 | input_dims = 1 204 | multires = 10 205 | 206 | self.fc1 = nn.Linear(inp_dim + out_dim, h) 207 | self.fc2 = nn.Linear(h, h) 208 | self.fc3 = nn.Linear(h, h) 209 | self.fc4 = nn.Linear(h, out_dim) 210 | self.scale_parameter = torch.nn.Parameter(torch.ones(1)) 211 | self.fc_map = nn.Linear(out_dim, inp_dim) 212 | 213 | self.out_dim = out_dim 214 | self.inp_dim = inp_dim 215 | 216 | def forward(self, x): 217 | x_out = x[..., :self.out_dim] 218 | x = x[..., self.out_dim:] 219 | 220 | h = self.fc_map(x_out) 221 | 222 | prod = x * h 223 | x = 
torch.cat([x, x_out], dim=-1) 224 | 225 | h = h_first = swish(self.fc1(x)) 226 | h = swish(self.fc2(h)) 227 | h = swish(self.fc3(h)) 228 | output = self.fc4(h) 229 | 230 | return output 231 | 232 | 233 | class PonderFC(nn.Module): 234 | def __init__(self, inp_dim, out_dim, num_step): 235 | super(PonderFC, self).__init__() 236 | h = 512 237 | 238 | input_dims = 1 239 | multires = 10 240 | 241 | self.fc1 = nn.Linear(inp_dim + out_dim, h) 242 | self.fc2 = nn.Linear(h, h) 243 | self.fc3 = nn.Linear(h, h) 244 | self.fc4 = nn.Linear(h, h) 245 | self.decode_logit = nn.Linear(h, 1) 246 | 247 | self.fc4_decode = nn.Linear(h, out_dim) 248 | self.scale_parameter = torch.nn.Parameter(torch.ones(1)) 249 | self.fc_map = nn.Linear(out_dim, inp_dim) 250 | 251 | self.out_dim = out_dim 252 | self.inp_dim = inp_dim 253 | self.ponder_prob = 1 - 1 / num_step 254 | 255 | def forward(self, x, eval=False, iters=1): 256 | logits = [] 257 | outputs = [] 258 | 259 | x_out = x[..., :self.out_dim] 260 | x = x[..., self.out_dim:] 261 | 262 | joint = torch.cat([x_out, x], dim=-1) 263 | 264 | for i in range(iters): 265 | h = F.relu(self.fc1(joint)) 266 | h = F.relu(self.fc2(h)) 267 | logit = self.decode_logit(h) 268 | output = self.fc4_decode(h) 269 | 270 | joint = torch.cat([output, x], dim=-1) 271 | 272 | logits.append(logit) 273 | outputs.append(output) 274 | 275 | logits = torch.stack(logits, dim=1) 276 | 277 | return outputs, logits 278 | 279 | 280 | class IterativeFCAttention(nn.Module): 281 | def __init__(self, inp_dim, out_dim, rank): 282 | super(IterativeFCAttention, self).__init__() 283 | h = 512 284 | 285 | input_dims = 1 286 | multires = 10 287 | self.rank = rank 288 | 289 | self.inp_fc1 = nn.Linear(inp_dim, h) 290 | self.inp_fc2 = nn.Linear(h, h) 291 | 292 | self.out_fc1 = nn.Linear(out_dim, h) 293 | self.out_fc2 = nn.Linear(h, h) 294 | 295 | self.fc1 = nn.Linear(3 * h, h) 296 | self.fc2 = nn.Linear(h, h) 297 | self.fc3 = nn.Linear(h, out_dim) 298 | 299 | self.out_dim = out_dim 300 | 
self.inp_dim = inp_dim 301 | self.scale_parameter = torch.nn.Parameter(torch.ones(1)) 302 | 303 | def forward(self, x): 304 | # Expects a input x which consists of input computation as well as 305 | # partially optimized result 306 | x_out = x[..., :self.out_dim] 307 | x = x[..., self.out_dim:] 308 | 309 | inp_h = F.relu(self.inp_fc2(F.relu(self.inp_fc1(x)))) 310 | out_h = F.relu(self.out_fc2(F.relu(self.out_fc1(x_out)))) 311 | 312 | prod_h = out_h * inp_h 313 | h = torch.cat([inp_h, out_h, prod_h], dim=-1) 314 | 315 | out = self.fc3(F.relu(self.fc2(F.relu(self.fc1(h))))) 316 | 317 | return out 318 | 319 | 320 | class IterativeTransformer(nn.Module): 321 | def __init__(self, inp_dim, out_dim, rank): 322 | super(IterativeTransformer, self).__init__() 323 | h = 512 324 | input_dims = 1 325 | multires = 10 326 | self.rank = rank 327 | 328 | embed_kwargs = { 329 | 'include_input': True, 330 | 'random_feature': True, 331 | 'input_dims': input_dims, 332 | 'max_freq_log2': multires - 1, 333 | 'num_freqs': multires, 334 | 'log_sampling': True, 335 | 'periodic_fns': [torch.sin, torch.cos], 336 | } 337 | 338 | self.inp_embed = Embedder(**embed_kwargs) 339 | self.out_embed = Embedder(**embed_kwargs) 340 | self.inp_linear = nn.Linear(1, 63) 341 | self.out_linear = nn.Linear(1, 63) 342 | 343 | coord = np.mgrid[:inp_dim] / inp_dim 344 | coord = torch.Tensor(coord)[None, :, None].cuda() 345 | 346 | self.coord = coord 347 | 348 | out_coord = np.mgrid[:out_dim] / out_dim 349 | out_coord = torch.Tensor(out_coord)[None, :, None].cuda() 350 | 351 | self.out_coord = out_coord 352 | 353 | self.encode1 = TransformerEncoderLayer(84, 4, dim_feedforward=256) 354 | self.encode2 = TransformerEncoderLayer(84, 4, dim_feedforward=256) 355 | 356 | self.fc1 = nn.Linear(84, 128) 357 | self.fc2 = nn.Linear(128, 1) 358 | 359 | self.out_dim = out_dim 360 | self.inp_dim = inp_dim 361 | 362 | def forward(self, x): 363 | x_out = x[..., :self.out_dim].view(-1, self.out_dim, 1) 364 | x = x[..., 
self.out_dim:].view(-1, self.inp_dim, 1) 365 | 366 | x_linear = self.inp_linear(x) 367 | x_out_linear = self.out_linear(x_out) 368 | x_pos_embed = self.inp_embed.embed(self.coord) 369 | x_out_pos_embed = self.out_embed.embed(self.out_coord) 370 | 371 | x_pos_embed = x_pos_embed.expand(x_linear.size(0), -1, -1) 372 | x_out_pos_embed = x_out_pos_embed.expand(x_linear.size(0), -1, -1) 373 | pos_embed = torch.cat([x_linear, x_pos_embed], dim=-1) 374 | out_embed = torch.cat([x_out_linear, x_out_pos_embed], dim=-1) 375 | 376 | s = pos_embed.size() 377 | pos_embed = pos_embed.view(s[0], -1, s[-1]) 378 | out_embed = out_embed.view(s[0], -1, s[-1]) 379 | 380 | embed = torch.cat([pos_embed, out_embed], dim=1) 381 | embed = embed.permute(1, 0, 2).contiguous() 382 | 383 | embed = self.encode1(embed) 384 | embed = self.encode2(embed) 385 | 386 | embed = self.fc2(F.relu(self.fc1(embed.permute(1, 0, 2))))[:, :, 0] 387 | output = embed[:, :self.out_dim] 388 | 389 | return output 390 | 391 | 392 | class IterativeAttention(nn.Module): 393 | def __init__(self, inp_dim, out_dim, rank): 394 | super(IterativeAttention, self).__init__() 395 | h = 64 396 | 397 | input_dims = 1 398 | multires = 10 399 | self.rank = rank 400 | self.out_dim = out_dim 401 | 402 | embed_kwargs = { 403 | 'include_input': True, 404 | 'random_feature': True, 405 | 'input_dims': input_dims, 406 | 'max_freq_log2': multires - 1, 407 | 'num_freqs': multires, 408 | 'log_sampling': True, 409 | 'periodic_fns': [torch.sin, torch.cos], 410 | } 411 | 412 | self.inp_embed = Embedder(**embed_kwargs) 413 | self.out_embed = Embedder(**embed_kwargs) 414 | self.inp_linear = nn.Linear(1, 42) 415 | self.out_linear = nn.Linear(1, 42) 416 | 417 | coord = np.mgrid[:rank, :rank] / rank 418 | coord = torch.Tensor(coord).permute(1, 2, 0)[None, :, :, :].cuda() 419 | 420 | self.coord = coord 421 | 422 | out_coord = np.mgrid[:out_dim] / rank 423 | out_coord = torch.Tensor(out_coord)[None, :, None].cuda() 424 | 425 | self.out_coord = 
out_coord 426 | 427 | self.inp_fc1 = nn.Linear(84, h) 428 | self.inp_fc2 = nn.Linear(h, h) 429 | self.inp_fc3 = nn.Linear(h, h) 430 | 431 | self.out_fc1 = nn.Linear(63, h) 432 | self.out_fc2 = nn.Linear(h, h) 433 | self.out_fc3 = nn.Linear(h, h) 434 | 435 | self.at_fc1 = nn.Linear(63, h) 436 | self.at_fc2 = nn.Linear(h, h) 437 | 438 | self.fc1 = nn.Linear(3 * h, h) 439 | self.fc2 = nn.Linear(h, h) 440 | self.fc3 = nn.Linear(h, 1) 441 | self.scale_parameter = torch.nn.Parameter(torch.ones(1)) 442 | 443 | self.out_dim = out_dim 444 | self.inp_dim = inp_dim 445 | 446 | def forward(self, x, idx): 447 | # Expects a input x which consists of input computation as well as 448 | # partially optimized result 449 | output = x[..., :self.out_dim] 450 | x_out = x[..., :self.out_dim].view(-1, self.out_dim, 1) 451 | x = x[..., self.out_dim:].view(-1, self.rank, self.rank, 1) 452 | 453 | x_linear = self.inp_linear(x) 454 | x_out_linear = self.out_linear(x_out) 455 | x_pos_embed = self.inp_embed.embed(self.coord) 456 | x_out_pos_embed = self.out_embed.embed(self.out_coord) 457 | 458 | x_pos_embed = x_pos_embed.expand(x_linear.size(0), -1, -1, -1) 459 | x_out_pos_embed = x_out_pos_embed.expand(x_linear.size(0), -1, -1) 460 | 461 | pos_embed = torch.cat([x_linear, x_pos_embed], dim=-1) 462 | out_embed = torch.cat([x_out_linear, x_out_pos_embed], dim=-1) 463 | 464 | s = pos_embed.size() 465 | pos_embed = pos_embed.view(s[0], -1, s[-1]) 466 | s = out_embed.size() 467 | out_embed = out_embed.view(s[0], -1, s[-1]) 468 | 469 | idx = idx[:, :, None].expand(-1, -1, out_embed.size(-1)) 470 | at_embed = torch.gather(pos_embed, 1, idx) 471 | 472 | at_embed = self.at_fc2(F.relu(self.at_fc1(at_embed))) 473 | out_embed = self.out_fc2(F.relu(self.out_fc1(out_embed))) 474 | pos_embed = self.inp_fc2(F.relu(self.inp_fc1(pos_embed))) 475 | 476 | out_val = self.out_fc3(F.relu(out_embed)) 477 | pos_val = self.inp_fc3(F.relu(pos_embed)) 478 | 479 | at_out_wt = torch.einsum('bij,bkj->bik', at_embed, 
out_embed) 480 | at_out_wt = torch.softmax(at_out_wt, dim=-1) 481 | out_embed = (at_out_wt[:, :, :, None] * 482 | out_val[:, None, :, :]).sum(dim=2) 483 | 484 | at_pos_wt = torch.einsum('bij,bkj->bik', at_embed, pos_embed) 485 | at_pos_wt = torch.softmax(at_pos_wt, dim=-1) 486 | pos_embed = (at_pos_wt[:, :, :, None] * 487 | pos_val[:, None, :, :]).sum(dim=2) 488 | 489 | embed = torch.cat([at_embed, pos_embed, out_embed], dim=-1) 490 | out = self.fc3(F.relu(self.fc2(F.relu(self.fc1(embed)))))[:, :, 0] 491 | 492 | output.requires_grad_() 493 | idx = idx[:, :, 0] 494 | output = torch.scatter(output, 1, idx, out) 495 | 496 | return output 497 | 498 | 499 | class AttentionEBM(nn.Module): 500 | def __init__(self, inp_dim, out_dim, rank): 501 | super(AttentionEBM, self).__init__() 502 | h = 64 503 | 504 | input_dims = 1 505 | multires = 10 506 | self.rank = rank 507 | self.out_dim = out_dim 508 | 509 | embed_kwargs = { 510 | 'include_input': True, 511 | 'random_feature': True, 512 | 'input_dims': input_dims, 513 | 'max_freq_log2': multires - 1, 514 | 'num_freqs': multires, 515 | 'log_sampling': True, 516 | 'periodic_fns': [torch.sin, torch.cos], 517 | } 518 | 519 | self.inp_embed = Embedder(**embed_kwargs) 520 | self.out_embed = Embedder(**embed_kwargs) 521 | self.inp_linear = nn.Linear(1, 42) 522 | self.out_linear = nn.Linear(1, 42) 523 | 524 | coord = np.mgrid[:rank, :rank] / rank 525 | coord = torch.Tensor(coord).permute(1, 2, 0)[None, :, :, :].cuda() 526 | 527 | self.coord = coord 528 | 529 | out_coord = np.mgrid[:out_dim] / rank 530 | out_coord = torch.Tensor(out_coord)[None, :, None].cuda() 531 | 532 | self.out_coord = out_coord 533 | 534 | self.inp_fc1 = nn.Linear(84, h) 535 | self.inp_fc2 = nn.Linear(h, h) 536 | self.inp_fc3 = nn.Linear(h, h) 537 | 538 | self.out_fc1 = nn.Linear(63, h) 539 | self.out_fc2 = nn.Linear(h, h) 540 | self.out_fc3 = nn.Linear(h, h) 541 | 542 | self.at_fc1 = nn.Linear(63, h) 543 | self.at_fc2 = nn.Linear(h, h) 544 | 545 | self.fc1 = 
nn.Linear(3 * h, h) 546 | self.fc2 = nn.Linear(h, h) 547 | self.fc3 = nn.Linear(h, 1) 548 | self.scale_parameter = torch.nn.Parameter(torch.ones(1)) 549 | 550 | self.out_dim = out_dim 551 | self.inp_dim = inp_dim 552 | 553 | def forward(self, x, idx): 554 | # Expects a input x which consists of input computation as well as 555 | # partially optimized result 556 | 557 | act_fn = swish 558 | 559 | output = x[..., :self.out_dim] 560 | x_out = x[..., :self.out_dim].view(-1, self.out_dim, 1) 561 | x = x[..., self.out_dim:].view(-1, self.rank, self.rank, 1) 562 | 563 | x_linear = self.inp_linear(x) 564 | x_out_linear = self.out_linear(x_out) 565 | x_pos_embed = self.inp_embed.embed(self.coord) 566 | x_out_pos_embed = self.out_embed.embed(self.out_coord) 567 | 568 | x_pos_embed = x_pos_embed.expand(x_linear.size(0), -1, -1, -1) 569 | x_out_pos_embed = x_out_pos_embed.expand(x_linear.size(0), -1, -1) 570 | 571 | pos_embed = torch.cat([x_linear, x_pos_embed], dim=-1) 572 | out_embed = torch.cat([x_out_linear, x_out_pos_embed], dim=-1) 573 | 574 | s = pos_embed.size() 575 | pos_embed = pos_embed.view(s[0], -1, s[-1]) 576 | s = out_embed.size() 577 | out_embed = out_embed.view(s[0], -1, s[-1]) 578 | 579 | idx = idx[:, :, None].expand(-1, -1, out_embed.size(-1)) 580 | at_embed = torch.gather(pos_embed, 1, idx) 581 | 582 | at_embed = self.at_fc2(act_fn(self.at_fc1(at_embed))) 583 | out_embed = self.out_fc2(act_fn(self.out_fc1(out_embed))) 584 | pos_embed = self.inp_fc2(act_fn(self.inp_fc1(pos_embed))) 585 | 586 | out_val = self.out_fc3(act_fn(out_embed)) 587 | pos_val = self.inp_fc3(act_fn(pos_embed)) 588 | 589 | at_out_wt = torch.einsum('bij,bkj->bik', at_embed, out_embed) 590 | at_out_wt = torch.softmax(at_out_wt, dim=-1) 591 | out_embed = (at_out_wt[:, :, :, None] * 592 | out_val[:, None, :, :]).sum(dim=2) 593 | 594 | at_pos_wt = torch.einsum('bij,bkj->bik', at_embed, pos_embed) 595 | at_pos_wt = torch.softmax(at_pos_wt, dim=-1) 596 | pos_embed = (at_pos_wt[:, :, :, None] * 
597 | pos_val[:, None, :, :]).sum(dim=2) 598 | 599 | embed = torch.cat([at_embed, pos_embed, out_embed], dim=-1) 600 | out = self.fc3(act_fn(self.fc2(act_fn(self.fc1(embed)))))[:, :, 0] 601 | 602 | return out 603 | 604 | 605 | if __name__ == "__main__": 606 | # Unit test the model 607 | args = EasyDict() 608 | args.filter_dim = 64 609 | args.latent_dim = 64 610 | args.im_size = 256 611 | 612 | model = LatentEBM(args).cuda() 613 | x = torch.zeros(1, 3, 256, 256).cuda() 614 | latent = torch.zeros(1, 64).cuda() 615 | model(x, latent) 616 | -------------------------------------------------------------------------------- /graph_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from graph_models import GraphEBM, GraphFC, GraphPonder, GraphRecurrent 3 | import torch.nn.functional as F 4 | import os 5 | import pdb 6 | from graph_dataset import Identity, ConnectedComponents, ShortestPath 7 | import matplotlib.pyplot as plt 8 | import torch.multiprocessing as mp 9 | import torch.distributed as dist 10 | from torch.optim import Adam, SparseAdam 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch_geometric.loader import DataLoader 13 | from torch_geometric.data.batch import Batch 14 | from easydict import EasyDict 15 | import os.path as osp 16 | from torch.nn.utils import clip_grad_norm 17 | import numpy as np 18 | from imageio import imwrite 19 | import argparse 20 | from torchvision.datasets import ImageFolder 21 | import torchvision.transforms as transforms 22 | import torch.nn as nn 23 | import torch.optim as optim 24 | import torch.backends.cudnn as cudnn 25 | import random 26 | from torchvision.utils import make_grid 27 | import seaborn as sns 28 | from torch_geometric.nn import global_mean_pool 29 | 30 | 31 | def worker_init_fn(worker_id): 32 | np.random.seed(int(torch.utils.data.get_worker_info().seed)%(2**32-1)) 33 | 34 | class ReplayBuffer(object): 35 | def __init__(self, size): 36 | 
"""Create Replay buffer. 37 | Parameters 38 | ---------- 39 | size: int 40 | Max number of transitions to store in the buffer. When the buffer 41 | overflows the old memories are dropped. 42 | """ 43 | self._storage = [] 44 | self._maxsize = size 45 | self._next_idx = 0 46 | 47 | def __len__(self): 48 | return len(self._storage) 49 | 50 | def add(self, inputs): 51 | batch_size = len(inputs) 52 | if self._next_idx >= len(self._storage): 53 | self._storage.extend(inputs) 54 | else: 55 | if batch_size + self._next_idx < self._maxsize: 56 | self._storage[self._next_idx:self._next_idx + 57 | batch_size] = inputs 58 | else: 59 | split_idx = self._maxsize - self._next_idx 60 | self._storage[self._next_idx:] = inputs[:split_idx] 61 | self._storage[:batch_size - split_idx] = inputs[split_idx:] 62 | self._next_idx = (self._next_idx + batch_size) % self._maxsize 63 | 64 | def _encode_sample(self, idxes): 65 | datas = [] 66 | for i in idxes: 67 | data = self._storage[i] 68 | datas.append(data) 69 | 70 | return datas 71 | 72 | def sample(self, batch_size, no_transform=False, downsample=False): 73 | """Sample a batch of experiences. 74 | Parameters 75 | ---------- 76 | batch_size: int 77 | How many transitions to sample. 78 | Returns 79 | ------- 80 | obs_batch: np.array 81 | batch of observations 82 | act_batch: np.array 83 | batch of actions executed given obs_batch 84 | rew_batch: np.array 85 | rewards received as results of executing act_batch 86 | next_obs_batch: np.array 87 | next set of observations seen after executing act_batch 88 | done_mask: np.array 89 | done_mask[i] = 1 if executing act_batch[i] resulted in 90 | the end of an episode and 0 otherwise. 
91 | """ 92 | idxes = [random.randint(0, len(self._storage) - 1) 93 | for _ in range(batch_size)] 94 | return self._encode_sample(idxes), torch.Tensor(idxes) 95 | 96 | def set_elms(self, data, idxes): 97 | if len(self._storage) < self._maxsize: 98 | self.add(data) 99 | else: 100 | for i, ix in enumerate(idxes): 101 | self._storage[ix] = data[i] 102 | 103 | 104 | """Parse input arguments""" 105 | parser = argparse.ArgumentParser(description='Train EBM model') 106 | 107 | parser.add_argument('--train', action='store_true', help='whether or not to train') 108 | parser.add_argument('--cuda', action='store_true', help='whether to use cuda or not') 109 | parser.add_argument('--vary', action='store_true', help='vary size of graph') 110 | parser.add_argument('--no_replay_buffer', action='store_true', help='disable the replay buffer') 111 | 112 | parser.add_argument('--dataset', default='identity', type=str, help='dataset to use: identity, connected, or shortestpath') 113 | parser.add_argument('--logdir', default='cachedir', type=str, help='location where log of experiments will be stored') 114 | parser.add_argument('--exp', default='default', type=str, help='name of the experiment') 115 | 116 | # training 117 | parser.add_argument('--resume_iter', default=0, type=int, help='iteration to resume training from') 118 | parser.add_argument('--batch_size', default=512, type=int, help='size of batch of input to use') 119 | parser.add_argument('--num_epoch', default=10000, type=int, help='number of epochs of training to run') 120 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate for training') 121 | parser.add_argument('--log_interval', default=10, type=int, help='log outputs every so many batches') 122 | parser.add_argument('--save_interval', default=1000, type=int, help='save outputs every so many batches') 123 | 124 | # data 125 | parser.add_argument('--data_workers', default=4, type=int, help='Number of different data workers to load data in parallel') 126 | 127 | # EBM specific settings 128 | 
parser.add_argument('--filter_dim', default=64, type=int, help='number of filters to use') 129 | parser.add_argument('--rank', default=10, type=int, help='rank of matrix to use') 130 | parser.add_argument('--num_steps', default=5, type=int, help='Steps of gradient descent for training') 131 | parser.add_argument('--step_lr', default=100.0, type=float, help='step size of latents') 132 | parser.add_argument('--latent_dim', default=64, type=int, help='dimension of the latent') 133 | parser.add_argument('--decoder', action='store_true', help='use a feed-forward decoder (GraphFC) to output the prediction in one pass') 134 | parser.add_argument('--gen', action='store_true', help='evaluate generalization') 135 | parser.add_argument('--gen_rank', default=5, type=int, help='additional rank added to --rank for the generalization dataset') 136 | parser.add_argument('--recurrent', action='store_true', help='use a recurrent model (GraphRecurrent) applied for num_steps iterations') 137 | parser.add_argument('--ponder', action='store_true', help='use a ponder model (GraphPonder) run for num_steps iterations') 138 | parser.add_argument('--no_truncate_grad', action='store_true', help='do not truncate gradients; backpropagate through every optimization step') 139 | parser.add_argument('--iterative_decoder', action='store_true', help='use an iterative decoder that refines the prediction with residual updates') 140 | parser.add_argument('--mem', action='store_true', help='add external memory to compute answers') 141 | 142 | # Distributed training hyperparameters 143 | parser.add_argument('--nodes', default=1, type=int, help='number of nodes for training') 144 | parser.add_argument('--gpus', default=1, type=int, help='number of gpus per node') 145 | parser.add_argument('--node_rank', default=0, type=int, help='rank of node') 146 | parser.add_argument('--capacity', default=50000, type=int, help='number of elements to generate') 147 | parser.add_argument('--infinite', action='store_true', help='makes the dataset have an infinite number of elements') 148 | 149 | best_test_error = 10.0 150 | best_gen_test_error = 10.0 151 | 152 | 153 | def average_gradients(model): 154 | size = 
float(dist.get_world_size()) 155 | 156 | for name, param in model.named_parameters(): 157 | if param.grad is None: 158 | continue 159 | 160 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 161 | param.grad.data /= size 162 | 163 | 164 | def gen_answer(inp, FLAGS, model, pred, scratch, num_steps, create_graph=True): 165 | preds = [] 166 | im_grads = [] 167 | energies = [] 168 | im_sups = [] 169 | 170 | if FLAGS.decoder: 171 | pred = model.forward(inp) 172 | preds = [pred] 173 | im_grad = torch.zeros(1) 174 | im_grads = [im_grad] 175 | energies = [torch.zeros(1)] 176 | elif FLAGS.recurrent: 177 | preds = [] 178 | im_grad = torch.zeros(1) 179 | im_grads = [im_grad] 180 | energies = [torch.zeros(1)] 181 | state = None 182 | for i in range(num_steps): 183 | pred, state = model.forward(inp, state) 184 | preds.append(pred) 185 | elif FLAGS.ponder: 186 | preds = model.forward(inp, iters=num_steps) 187 | pred = preds[-1] 188 | im_grad = torch.zeros(1) 189 | im_grads = [im_grad] 190 | energies = [torch.zeros(1)] 191 | state = None 192 | elif FLAGS.iterative_decoder: 193 | for i in range(num_steps): 194 | energy = torch.zeros(1) 195 | 196 | out_dim = model.out_dim 197 | 198 | im_merge = torch.cat([pred, inp], dim=-1) 199 | pred = model.forward(im_merge) + pred 200 | 201 | energies.append(torch.zeros(1)) 202 | im_grads.append(torch.zeros(1)) 203 | else: 204 | with torch.enable_grad(): 205 | pred.requires_grad_(requires_grad=True) 206 | s = inp.size() 207 | scratch.requires_grad_(requires_grad=True) 208 | 209 | for i in range(num_steps): 210 | 211 | energy = model.forward(inp, pred) 212 | 213 | if FLAGS.no_truncate_grad: 214 | im_grad, = torch.autograd.grad([energy.sum()], [pred], create_graph=create_graph) 215 | else: 216 | if i == (num_steps - 1): # only the final step keeps the autograd graph 217 | im_grad, = torch.autograd.grad([energy.sum()], [pred], create_graph=True) 218 | else: 219 | im_grad, = torch.autograd.grad([energy.sum()], [pred], create_graph=False) 220 | pred = pred - FLAGS.step_lr * im_grad # gradient descent step on the energy 221 | 
222 | preds.append(pred) 223 | energies.append(energy) 224 | im_grads.append(im_grad) 225 | 226 | return pred, preds, im_grads, energies, scratch 227 | 228 | 229 | def ema_model(model, model_ema, mu=0.999): 230 | for (model, model_ema) in zip(model, model_ema): 231 | for param, param_ema in zip(model.parameters(), model_ema.parameters()): 232 | param_ema.data[:] = mu * param_ema.data + (1 - mu) * param.data 233 | 234 | 235 | def sync_model(model): 236 | size = float(dist.get_world_size()) 237 | 238 | # model is a single nn.Module here, so iterate over its parameters directly 239 | for param in model.parameters(): 240 | dist.broadcast(param.data, 0) 241 | 242 | 243 | def init_model(FLAGS, device, dataset): 244 | if FLAGS.decoder: 245 | model = GraphFC(dataset.inp_dim, dataset.out_dim) 246 | elif FLAGS.ponder: 247 | model = GraphPonder(dataset.inp_dim, dataset.out_dim) 248 | elif FLAGS.recurrent: 249 | model = GraphRecurrent(dataset.inp_dim, dataset.out_dim) 250 | elif FLAGS.iterative_decoder: 251 | model = IterativeFC(dataset.inp_dim, dataset.out_dim, FLAGS.mem) # NOTE: IterativeFC is not among the names imported at the top of this file 252 | else: 253 | model = GraphEBM(dataset.inp_dim, dataset.out_dim, FLAGS.mem) 254 | 255 | model.to(device) 256 | optimizer = Adam(model.parameters(), lr=FLAGS.lr) 257 | 258 | return model, optimizer 259 | 260 | 261 | def test(train_dataloader, model, FLAGS, step=0, gen=False): 262 | global best_test_error, best_gen_test_error 263 | if FLAGS.cuda: 264 | dev = torch.device("cuda") 265 | else: 266 | dev = torch.device("cpu") 267 | 268 | replay_buffer = None 269 | dist_list = [] 270 | energy_list = [] 271 | min_dist_energy_list = [] 272 | 273 | model.eval() 274 | counter = 0 275 | with torch.no_grad(): 276 | for data in train_dataloader: 277 | data = data.to(dev) 278 | im = data['y'] 279 | 280 | pred = (torch.rand_like(data.edge_attr) - 0.5) * 2. 
281 | 282 | scratch = torch.zeros_like(pred) 283 | 284 | pred_init = pred 285 | pred, preds, im_grad, energies, scratch = gen_answer(data, FLAGS, model, pred, scratch, 40, create_graph=False) 286 | preds = torch.stack(preds, dim=0) 287 | energies = torch.stack(energies, dim=0).mean(dim=-1).mean(dim=-1) 288 | 289 | dist = (preds - im[None, :]) 290 | dist = torch.pow(dist, 2).mean(dim=-1) 291 | n = dist.size(1) 292 | 293 | dist = dist.mean(dim=-1) 294 | min_idx = energies.argmin(dim=0) 295 | dist_min = dist[min_idx] 296 | min_dist_energy_list.append(dist_min) 297 | 298 | dist_list.append(dist.detach()) 299 | energy_list.append(energies.detach()) 300 | 301 | counter = counter + 1 302 | 303 | if counter > 10: 304 | dist = torch.stack(dist_list, dim=0).mean(dim=0) 305 | energies = torch.stack(energy_list, dim=0).mean(dim=0) 306 | min_dist_energy = torch.stack(min_dist_energy_list, dim=0).mean() 307 | print("Testing..................") 308 | print("last step error: ", dist) 309 | print("energy values: ", energies) 310 | print('test at step %d done!' 
% step) 311 | break 312 | 313 | if gen: 314 | best_gen_test_error = min(best_gen_test_error, min_dist_energy.item()) 315 | print("best gen test error: ", best_gen_test_error) 316 | else: 317 | best_test_error = min(best_test_error, min_dist_energy.item()) 318 | print("best test error: ", best_test_error) 319 | model.train() 320 | 321 | 322 | def train(train_dataloader, test_dataloader, gen_dataloader, logger, model, optimizer, FLAGS, logdir, rank_idx): 323 | it = FLAGS.resume_iter 324 | optimizer.zero_grad() 325 | 326 | dev = torch.device("cuda") 327 | replay_buffer = ReplayBuffer(10000) 328 | 329 | for epoch in range(FLAGS.num_epoch): 330 | for data in train_dataloader: 331 | pred = (torch.rand_like(data.edge_attr) - 0.5) * 2 332 | data['noise'] = pred 333 | scratch = torch.zeros_like(pred) 334 | nreplay = FLAGS.batch_size 335 | 336 | if FLAGS.replay_buffer and len(replay_buffer) >= FLAGS.batch_size: 337 | data_list, levels = replay_buffer.sample(32) 338 | new_data_list = data.to_data_list() 339 | 340 | ix = int(FLAGS.batch_size * 0.5) 341 | 342 | nreplay = len(data_list) + 40 343 | data_list = data_list + new_data_list 344 | data = Batch.from_data_list(data_list) 345 | pred = data['noise'] 346 | else: 347 | levels = np.zeros(FLAGS.batch_size) 348 | replay_mask = ( 349 | np.random.uniform( 350 | 0, 351 | 1, 352 | FLAGS.batch_size) > 1.0) 353 | 354 | data = data.to(dev) 355 | pred = data['noise'] 356 | num_steps = FLAGS.num_steps 357 | pred, preds, im_grads, energies, scratch = gen_answer(data, FLAGS, model, pred, scratch, num_steps) 358 | energies = torch.stack(energies, dim=0) 359 | 360 | # Choose energies to be consistent at the last step 361 | preds = torch.stack(preds, dim=1) 362 | 363 | im_grads = torch.stack(im_grads, dim=1) 364 | im = data['y'] 365 | 366 | if FLAGS.decoder: 367 | im_loss = torch.pow(preds[:, -1:] - im[:, None, :], 2).mean(dim=-1).mean(dim=-1) 368 | elif FLAGS.ponder: 369 | im_loss = torch.pow(preds[:, :] - im[:, None, :], 
2).mean(dim=-1).mean(dim=-1) 370 | else: 371 | im_loss = torch.pow(preds[:, -1:] - im[:, None, :], 2).mean(dim=-1).mean(dim=-1) 372 | 373 | loss = im_loss.mean() 374 | loss.backward() 375 | 376 | if FLAGS.replay_buffer: 377 | data['noise'] = preds[:, -1].detach() 378 | data_list = data.cpu().to_data_list() 379 | 380 | replay_buffer.add(data_list[:nreplay]) 381 | 382 | if FLAGS.gpus > 1: 383 | average_gradients(model) 384 | 385 | optimizer.step() 386 | optimizer.zero_grad() 387 | 388 | if it > 10000: 389 | assert False 390 | 391 | if it % FLAGS.log_interval == 0 and rank_idx == 0: 392 | loss = loss.item() 393 | kvs = {} 394 | 395 | kvs['im_loss'] = im_loss.mean().item() 396 | 397 | string = "Iteration {} ".format(it) 398 | 399 | for k, v in kvs.items(): 400 | string += "%s: %.6f " % (k,v) 401 | logger.add_scalar(k, v, it) 402 | 403 | print(string) 404 | 405 | if it % FLAGS.save_interval == 0 and rank_idx == 0: 406 | model_path = osp.join(logdir, "model_latest.pth".format(it)) 407 | ckpt = {'FLAGS': FLAGS} 408 | 409 | ckpt['model_state_dict'] = model.state_dict() 410 | ckpt['optimizer_state_dict'] = optimizer.state_dict() 411 | 412 | torch.save(ckpt, model_path) 413 | 414 | print("Testing performance .......................") 415 | test(test_dataloader, model, FLAGS, step=it, gen=False) 416 | 417 | print("Generalization performance .......................") 418 | test(gen_dataloader, model, FLAGS, step=it, gen=True) 419 | 420 | it += 1 421 | 422 | 423 | 424 | def main_single(rank, FLAGS): 425 | rank_idx = FLAGS.node_rank * FLAGS.gpus + rank 426 | world_size = FLAGS.nodes * FLAGS.gpus 427 | 428 | 429 | if not os.path.exists('result/%s' % FLAGS.exp): 430 | try: 431 | os.makedirs('result/%s' % FLAGS.exp) 432 | except: 433 | pass 434 | 435 | if FLAGS.dataset == 'identity': 436 | dataset = Identity('train', FLAGS.rank, vary=FLAGS.vary) 437 | test_dataset = Identity('test', FLAGS.rank) 438 | gen_dataset = Identity('test', FLAGS.rank+FLAGS.gen_rank) 439 | elif FLAGS.dataset 
== 'connected': 440 | dataset = ConnectedComponents('train', FLAGS.rank, vary=FLAGS.vary) 441 | test_dataset = ConnectedComponents('test', FLAGS.rank) 442 | gen_dataset = ConnectedComponents('test', FLAGS.rank+FLAGS.gen_rank) 443 | elif FLAGS.dataset == 'shortestpath': 444 | dataset = ShortestPath('train', FLAGS.rank, vary=FLAGS.vary) 445 | test_dataset = ShortestPath('test', FLAGS.rank) 446 | gen_dataset = ShortestPath('test', FLAGS.rank+FLAGS.gen_rank) 447 | 448 | if FLAGS.gen: 449 | test_dataset = gen_dataset 450 | 451 | shuffle = True 452 | sampler = None 453 | 454 | if world_size > 1: 455 | group = dist.init_process_group(backend='nccl', init_method='tcp://localhost:8113', world_size=world_size, rank=rank_idx, group_name="default") 456 | 457 | torch.cuda.set_device(rank) 458 | device = torch.device('cuda') 459 | 460 | logdir = osp.join(FLAGS.logdir, FLAGS.exp) 461 | FLAGS_OLD = FLAGS 462 | 463 | if FLAGS.resume_iter != 0: 464 | model_path = osp.join(logdir, "model_latest.pth") 465 | 466 | checkpoint = torch.load(model_path, map_location=torch.device('cpu')) 467 | FLAGS = checkpoint['FLAGS'] 468 | 469 | FLAGS.resume_iter = FLAGS_OLD.resume_iter 470 | FLAGS.save_interval = FLAGS_OLD.save_interval 471 | FLAGS.nodes = FLAGS_OLD.nodes 472 | FLAGS.gpus = FLAGS_OLD.gpus 473 | FLAGS.node_rank = FLAGS_OLD.node_rank 474 | FLAGS.train = FLAGS_OLD.train 475 | FLAGS.batch_size = FLAGS_OLD.batch_size 476 | FLAGS.num_steps = FLAGS_OLD.num_steps 477 | FLAGS.exp = FLAGS_OLD.exp 478 | FLAGS.vis = getattr(FLAGS_OLD, 'vis', False) # the parser above defines no --vis flag 479 | FLAGS.ponder = FLAGS_OLD.ponder 480 | FLAGS.recurrent = FLAGS_OLD.recurrent 481 | FLAGS.no_truncate_grad = FLAGS_OLD.no_truncate_grad 482 | FLAGS.gen_rank = FLAGS_OLD.gen_rank 483 | FLAGS.step_lr = FLAGS_OLD.step_lr 484 | 485 | model, optimizer = init_model(FLAGS, device, dataset) 486 | state_dict = model.state_dict() 487 | 488 | model.load_state_dict(checkpoint['model_state_dict'], strict=False) 489 | 
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 490 | else: 491 | model, optimizer = init_model(FLAGS, device, dataset) 492 | 493 | if FLAGS.gpus > 1: 494 | sync_model(model) 495 | 496 | train_dataloader = DataLoader(dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.batch_size, shuffle=shuffle, worker_init_fn=worker_init_fn) 497 | test_dataloader = DataLoader(test_dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.batch_size, shuffle=True, pin_memory=False, drop_last=True, worker_init_fn=worker_init_fn) 498 | gen_dataloader = DataLoader(gen_dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.batch_size // 2, shuffle=True, pin_memory=False, drop_last=True, worker_init_fn=worker_init_fn) 499 | 500 | logger = SummaryWriter(logdir) 501 | it = FLAGS.resume_iter 502 | 503 | if FLAGS.train: 504 | model.train() 505 | else: 506 | model.eval() 507 | 508 | if FLAGS.train: 509 | train(train_dataloader, test_dataloader, gen_dataloader, logger, model, optimizer, FLAGS, logdir, rank_idx) 510 | else: 511 | test(test_dataloader, model, FLAGS, step=FLAGS.resume_iter) 512 | 513 | 514 | def main(): 515 | FLAGS = parser.parse_args() 516 | FLAGS.replay_buffer = not FLAGS.no_replay_buffer 517 | FLAGS.vary = True 518 | logdir = osp.join(FLAGS.logdir, FLAGS.exp) 519 | 520 | if not osp.exists(logdir): 521 | os.makedirs(logdir) 522 | 523 | if FLAGS.gpus > 1: 524 | mp.spawn(main_single, nprocs=FLAGS.gpus, args=(FLAGS,)) 525 | else: 526 | main_single(0, FLAGS) 527 | 528 | 529 | if __name__ == "__main__": 530 | try: 531 | torch.multiprocessing.set_start_method('spawn') 532 | except: 533 | pass 534 | main() 535 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os 3 | from tqdm import tqdm 4 | import os.path as osp 5 | import numpy as np 6 | import json 7 | import 
torchvision.transforms.functional as TF 8 | import random 9 | from torchvision.datasets import ImageFolder 10 | 11 | from PIL import Image 12 | import torch.utils.data as data 13 | import torch 14 | import cv2 15 | from torchvision import transforms 16 | import glob 17 | 18 | from glob import glob 19 | from imageio import imread 20 | from skimage.transform import resize as imresize 21 | 22 | from torchvision.datasets import CIFAR100 23 | import numpy as np 24 | 25 | from scipy.sparse import csr_matrix 26 | from scipy.sparse.csgraph import shortest_path 27 | from copy import deepcopy 28 | import time 29 | from scipy.linalg import lu 30 | 31 | 32 | def extract(a, t, x_shape): 33 | b, *_ = t.shape 34 | out = a.gather(-1, t) 35 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 36 | 37 | 38 | def cosine_beta_schedule(timesteps, s=0.008): 39 | """ 40 | cosine schedule 41 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 42 | """ 43 | steps = timesteps + 1 44 | x = np.linspace(0, steps, steps) 45 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 46 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 47 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 48 | return np.clip(betas, a_min=0, a_max=0.999) 49 | 50 | 51 | class NoisyWrapper: 52 | 53 | def __init__(self, dataset, timesteps): 54 | 55 | self.dataset = dataset 56 | self.timesteps = timesteps 57 | betas = cosine_beta_schedule(timesteps) 58 | self.inp_dim = dataset.inp_dim 59 | self.out_dim = dataset.out_dim 60 | 61 | alphas = 1. - betas 62 | alphas_cumprod = np.cumprod(alphas, axis=0) 63 | 64 | alphas_cumprod = np.linspace(1, 0, timesteps) 65 | self.sqrt_alphas_cumprod = torch.tensor(np.sqrt(alphas_cumprod)) 66 | self.sqrt_one_minus_alphas_cumprod = torch.tensor( 67 | np.sqrt(1. 
- alphas_cumprod)) 68 | self.extract = extract 69 | 70 | def __len__(self): 71 | return len(self.dataset) 72 | 73 | def __getitem__(self, *args, **kwargs): 74 | x, y = self.dataset.__getitem__(*args, **kwargs) 75 | x = torch.tensor(x) 76 | y = torch.tensor(y) 77 | 78 | t = torch.randint(1, self.timesteps, (1,)).long() 79 | t_next = t - 1 80 | noise = torch.randn_like(y) 81 | 82 | sample = ( 83 | self.extract( 84 | self.sqrt_alphas_cumprod, 85 | t, 86 | y.shape) * 87 | y + 88 | self.extract( 89 | self.sqrt_one_minus_alphas_cumprod, 90 | t, 91 | y.shape) * 92 | noise) 93 | 94 | sample_next = ( 95 | self.extract( 96 | self.sqrt_alphas_cumprod, 97 | t_next, 98 | y.shape) * 99 | y + 100 | self.extract( 101 | self.sqrt_one_minus_alphas_cumprod, 102 | t_next, 103 | y.shape) * 104 | noise) 105 | return x, sample, sample_next 106 | 107 | 108 | def conjgrad(A, b, x, num_steps=20): 109 | """ 110 | A function to solve [A]{x} = {b} linear equation system with the 111 | conjugate gradient method. 112 | More at: http://en.wikipedia.org/wiki/Conjugate_gradient_method 113 | ========== Parameters ========== 114 | A : matrix 115 | A real symmetric positive definite matrix. 116 | b : vector 117 | The right hand side (RHS) vector of the system. 118 | x : vector 119 | The starting guess for the solution. 
120 | """ 121 | r = b - np.dot(A, x) 122 | p = r 123 | rsold = np.dot(np.transpose(r), r) 124 | 125 | sols = [x.flatten()] 126 | 127 | for i in range(num_steps): 128 | Ap = np.dot(A, p) 129 | alpha = rsold / np.dot(np.transpose(p), Ap) 130 | x = x + alpha[0, 0] * p 131 | r = r - alpha[0, 0] * Ap 132 | rsnew = np.dot(np.transpose(r), r) 133 | # if np.sqrt(rsnew) < 1e-8: 134 | # break 135 | p = r + (rsnew / rsold) * p 136 | rsold = rsnew 137 | sols.append(x.flatten()) 138 | 139 | sols = np.stack(sols, axis=0) 140 | return sols 141 | 142 | 143 | class FiniteWrapper(data.Dataset): 144 | """Wraps a dataset and caches a fixed number of its samples on disk, giving a finite, static dataset""" 145 | 146 | def __init__(self, dataset, string, capacity, rank, num_steps): 147 | """Initialize this dataset class. 148 | 149 | Parameters: 150 | dataset -- dataset to draw samples from; capacity -- number of samples to cache; string, rank, num_steps -- used only to name the cache file 151 | """ 152 | 153 | cache_file = "cache/{}_{}_{}_{}.npz".format( 154 | capacity, string, rank, num_steps) 155 | self.inp_dim = dataset.inp_dim 156 | self.out_dim = dataset.out_dim 157 | 158 | if osp.exists(cache_file): 159 | data = np.load(cache_file) 160 | inps = data['inp'] 161 | outs = data['out'] 162 | else: 163 | print("Generating static dataset for training.....") 164 | inps = [] 165 | outs = [] 166 | 167 | for i in tqdm(range(capacity)): 168 | inp, out = dataset[i][:2] # datasets in this file return (inp, out); ignore any extra items 169 | inps.append(inp) 170 | outs.append(out) 171 | 172 | inps = np.array(inps) 173 | outs = np.array(outs) 174 | 175 | np.savez(cache_file, inp=inps, out=outs) 176 | 177 | self.inp = inps 178 | self.out = outs 179 | 180 | def __getitem__(self, index): 181 | """Return a data point and its metadata information. 
182 | 183 | Parameters: 184 | index - - integer index into the cached arrays 185 | 186 | Returns a tuple (inp, out) 187 | inp(np.ndarray) - - the cached input 188 | out(np.ndarray) - - the cached target 189 | """ 190 | inp = self.inp[index] 191 | out = self.out[index] 192 | 193 | return inp, out 194 | 195 | def __len__(self): 196 | """Return the number of cached samples.""" 197 | # Fixed when the cache was built 198 | return self.inp.shape[0] 199 | 200 | 201 | class Identity(data.Dataset): 202 | """Identity task: the target equals the input, a random rank x rank matrix flattened to a vector""" 203 | 204 | def __init__(self, split, rank, num_steps=5): 205 | """Initialize this dataset class. 206 | 207 | Parameters: 208 | split (str) -- dataset split name; rank (int) -- side length of the square matrix; num_steps (int) -- number of target copies tiled in __getitem__ 209 | """ 210 | self.h = rank 211 | self.w = rank 212 | self.num_steps = num_steps 213 | 214 | self.split = split 215 | self.inp_dim = self.h * self.w 216 | self.out_dim = self.h * self.w 217 | 218 | def __getitem__(self, index): 219 | """Return a data point and its metadata information. 
220 | 221 | Parameters: 222 | index - - ignored; a fresh random sample is drawn on every call 223 | 224 | Returns a tuple (inp, out) 225 | inp(np.ndarray) - - flattened random matrix 226 | out(np.ndarray) - - flattened target, identical to the input 227 | """ 228 | 229 | R = np.random.uniform(0, 1, (self.h, self.w)) 230 | R_corrupt = R 231 | 232 | R_list = np.tile(R.flatten()[None, :], (self.num_steps, 1)) 233 | R_list = np.concatenate([R_corrupt.flatten()[None, :], R_list], axis=0) # built but not returned 234 | 235 | return R_corrupt.flatten(), R.flatten() 236 | 237 | def __len__(self): 238 | """Return the nominal dataset length.""" 239 | # Dataset is always randomly generated 240 | return int(1e6) 241 | 242 | 243 | class LowRankDataset(data.Dataset): 244 | """Low-rank matrix completion: the input is a randomly masked copy of a noisy low-rank matrix and the target is the full matrix""" 245 | 246 | def __init__(self, split, rank, ood): 247 | """Initialize this dataset class. 248 | 249 | Parameters: 250 | split (str) -- dataset split name; rank (int) -- side length of the square matrix; ood (bool) -- if True, scale the low-rank term by 1/5 instead of 1/20 251 | """ 252 | self.h = rank 253 | self.w = rank 254 | 255 | self.split = split 256 | self.inp_dim = self.h * self.w 257 | self.out_dim = self.h * self.w 258 | self.ood = ood 259 | 260 | def __getitem__(self, index): 261 | """Return a data point and its metadata information. 
262 | 263 | Parameters: 264 | index - - ignored; a fresh random sample is drawn on every call 265 | 266 | Returns a tuple (inp, out) 267 | inp(np.ndarray) - - flattened matrix with randomly masked (zeroed) entries 268 | out(np.ndarray) - - flattened full matrix 269 | """ 270 | 271 | U = np.random.randn(self.h, 10) 272 | V = np.random.randn(self.w, 10) 273 | 274 | if self.ood: 275 | R = 0.1 * np.random.randn(self.h, self.w) + np.dot(U, V.T) / 5 276 | else: 277 | R = 0.1 * np.random.randn(self.h, self.w) + np.dot(U, V.T) / 20 278 | 279 | mask = np.round(np.random.rand(self.h, self.w)) 280 | 281 | R_corrupt = R * mask 282 | return R_corrupt.flatten(), R.flatten() 283 | 284 | def __len__(self): 285 | """Return the nominal dataset length.""" 286 | return int(1e6) 287 | 288 | 289 | class Negate(data.Dataset): 290 | """Negation task: the input is the elementwise negation of the target matrix""" 291 | 292 | def __init__(self, split, rank): 293 | """Initialize this dataset class. 294 | 295 | Parameters: 296 | split (str) -- dataset split name; rank (int) -- side length of the square matrix 297 | """ 298 | self.h = rank 299 | self.w = rank 300 | 301 | self.split = split 302 | self.inp_dim = self.h * self.w 303 | self.out_dim = self.h * self.w 304 | 305 | def __getitem__(self, index): 306 | """Return a data point and its metadata information. 
307 | 308 | Parameters: 309 | index - - a random integer for data indexing 310 | 311 | Returns a dictionary that contains A and A_paths 312 | A(tensor) - - an image in one domain 313 | A_paths(str) - - the path of the image 314 | """ 315 | 316 | R = np.random.uniform(0, 1, (self.h, self.w)) 317 | R_corrupt = -1 * R 318 | 319 | return R_corrupt.flatten(), R.flatten() 320 | 321 | def __len__(self): 322 | """Return the total number of images in the dataset.""" 323 | # Dataset is always randomly generated 324 | return int(1e6) 325 | 326 | 327 | class Addition(data.Dataset): 328 | """Constructs a dataset with N circles, N is the number of components set by flags""" 329 | 330 | def __init__(self, split, rank, ood): 331 | """Initialize this dataset class. 332 | 333 | Parameters: 334 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 335 | """ 336 | self.h = rank 337 | 338 | if self.h == 2: 339 | self.w = 1 340 | else: 341 | self.w = rank 342 | self.ood = ood 343 | 344 | self.split = split 345 | self.inp_dim = 2 * self.h * self.w 346 | self.out_dim = self.h * self.w 347 | 348 | def __getitem__(self, index): 349 | """Return a data point and its metadata information. 
350 | 351 | Parameters: 352 | index - - a random integer for data indexing 353 | 354 | Returns a dictionary that contains A and A_paths 355 | A(tensor) - - an image in one domain 356 | A_paths(str) - - the path of the image 357 | """ 358 | 359 | if self.ood: 360 | scale = 2.5 361 | else: 362 | scale = 1.0 363 | 364 | R_one = np.random.uniform(-scale, scale, (self.h, self.w)).flatten() 365 | R_two = np.random.uniform(-scale, scale, (self.h, self.w)).flatten() 366 | R_corrupt = np.concatenate([R_one, R_two], axis=0) 367 | R = R_one + R_two 368 | 369 | return R_corrupt.flatten(), R.flatten() 370 | 371 | def __len__(self): 372 | """Return the total number of images in the dataset.""" 373 | # Dataset is always randomly generated 374 | return int(1e6) 375 | 376 | 377 | class Square(data.Dataset): 378 | """Constructs a dataset with N circles, N is the number of components set by flags""" 379 | 380 | def __init__(self, split, rank, num_steps=10): 381 | """Initialize this dataset class. 382 | 383 | Parameters: 384 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 385 | """ 386 | self.h = rank 387 | self.w = rank 388 | self.num_steps = num_steps 389 | 390 | self.split = split 391 | self.inp_dim = self.h * self.w 392 | self.out_dim = self.h * self.w 393 | 394 | def __getitem__(self, index): 395 | """Return a data point and its metadata information. 396 | 397 | Parameters: 398 | index - - a random integer for data indexing 399 | 400 | Returns a dictionary that contains A and A_paths 401 | A(tensor) - - an image in one domain 402 | A_paths(str) - - the path of the image 403 | """ 404 | 405 | R_corrupt = np.random.uniform(-1, 1, (self.h, self.w)) 406 | R = np.matmul(R_corrupt, R_corrupt).flatten() / 5. 
407 | # R = R_corrupt * R_corrupt 408 | 409 | R_list = np.tile(R.flatten()[None, :], (self.num_steps, 1)) 410 | R_list = np.concatenate([R_corrupt.flatten()[None, :], R_list], axis=0) 411 | 412 | return R_corrupt.flatten(), R.flatten() 413 | 414 | def __len__(self): 415 | """Return the nominal number of samples in the dataset.""" 416 | # Dataset is always randomly generated 417 | return int(1e6) 418 | 419 | 420 | class Inverse(data.Dataset): 421 | """Constructs a dataset of random symmetric matrices paired with their matrix inverses""" 422 | 423 | def __init__(self, split, rank, ood): 424 | """Initialize this dataset class. 425 | 426 | Parameters: 427 | split, rank, ood -- dataset split name, matrix side length, and whether to sample the harder out-of-distribution matrices 428 | """ 429 | self.h = rank 430 | self.w = rank 431 | self.ood = ood 432 | 433 | self.split = split 434 | self.inp_dim = self.h * self.w 435 | self.out_dim = self.h * self.w 436 | 437 | def __getitem__(self, index): 438 | """Return a data point and its metadata information. 
439 | 440 | Parameters: 441 | index - - a random integer for data indexing 442 | 443 | Returns a dictionary that contains A and A_paths 444 | A(tensor) - - an image in one domain 445 | A_paths(str) - - the path of the image 446 | """ 447 | 448 | R_corrupt = np.random.uniform(-1, 1, (self.h, self.w)) 449 | R_corrupt = R_corrupt.dot(R_corrupt.transpose()) 450 | # R_corrupt = R_corrupt + 0.5 * np.eye(self.h) 451 | 452 | if self.ood: 453 | R_corrupt = R_corrupt + R_corrupt.transpose() + 0.1 * np.eye(self.h) 454 | else: 455 | R_corrupt = R_corrupt + R_corrupt.transpose() + 0.5 * np.eye(self.h) 456 | 457 | R = np.linalg.inv(R_corrupt) 458 | 459 | return R_corrupt.flatten(), R.flatten() 460 | 461 | def __len__(self): 462 | """Return the total number of images in the dataset.""" 463 | # Dataset is always randomly generated 464 | return int(1e6) 465 | 466 | 467 | class Equation(data.Dataset): 468 | """Constructs a dataset with N circles, N is the number of components set by flags""" 469 | 470 | def __init__(self, split, rank, num_steps=10): 471 | """Initialize this dataset class. 472 | 473 | Parameters: 474 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 475 | """ 476 | self.h = rank 477 | self.w = rank 478 | 479 | self.split = split 480 | self.inp_dim = rank * rank + rank 481 | self.out_dim = rank 482 | self.num_steps = num_steps 483 | 484 | def __getitem__(self, index): 485 | """Return a data point and its metadata information. 
486 | 487 | Parameters: 488 | index - - a random integer for data indexing 489 | 490 | Returns a dictionary that contains A and A_paths 491 | A(tensor) - - an image in one domain 492 | A_paths(str) - - the path of the image 493 | """ 494 | 495 | A = np.random.uniform(-1, 1, (self.h, self.w)) 496 | A = A + A.transpose() 497 | A = np.matmul(A, A.transpose()) 498 | 499 | x = np.random.uniform(-1, 1, (self.h, 1)) 500 | b = np.matmul(A, x) 501 | 502 | inp = np.concatenate([A.flatten(), b.flatten()], axis=0) 503 | sol = x.flatten() 504 | 505 | x_guess = np.random.uniform(-1, 1, (self.h, 1)) 506 | 507 | return inp, sol 508 | 509 | def __len__(self): 510 | """Return the total number of images in the dataset.""" 511 | # Dataset is always randomly generated 512 | return int(1e6) 513 | 514 | 515 | class LU(data.Dataset): 516 | """Constructs a dataset with N circles, N is the number of components set by flags""" 517 | 518 | def __init__(self, split, rank): 519 | """Initialize this dataset class. 520 | 521 | Parameters: 522 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 523 | """ 524 | self.h = rank 525 | self.w = rank 526 | 527 | self.split = split 528 | self.inp_dim = self.h * self.w 529 | self.out_dim = 2 * self.h * self.w 530 | 531 | def __getitem__(self, index): 532 | """Return a data point and its metadata information. 
533 | 534 | Parameters: 535 | index - - a random integer for data indexing 536 | 537 | Returns a dictionary that contains A and A_paths 538 | A(tensor) - - an image in one domain 539 | A_paths(str) - - the path of the image 540 | """ 541 | 542 | R_corrupt = np.random.uniform(-1, 1, (self.h, self.w)) 543 | p, l, u = lu(R_corrupt) 544 | R = np.concatenate([l.flatten(), u.flatten()], axis=0) 545 | 546 | return R_corrupt.flatten(), R 547 | 548 | def __len__(self): 549 | """Return the total number of images in the dataset.""" 550 | # Dataset is always randomly generated 551 | return int(1e6) 552 | 553 | 554 | class Det(data.Dataset): 555 | """Constructs a dataset with N circles, N is the number of components set by flags""" 556 | 557 | def __init__(self, split, rank): 558 | """Initialize this dataset class. 559 | 560 | Parameters: 561 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 562 | """ 563 | self.h = rank 564 | self.w = rank 565 | 566 | self.split = split 567 | self.inp_dim = self.h * self.w 568 | self.out_dim = 1 569 | 570 | def __getitem__(self, index): 571 | """Return a data point and its metadata information. 572 | 573 | Parameters: 574 | index - - a random integer for data indexing 575 | 576 | Returns a dictionary that contains A and A_paths 577 | A(tensor) - - an image in one domain 578 | A_paths(str) - - the path of the image 579 | """ 580 | 581 | R_corrupt = np.random.uniform(0, 1, (self.h, self.w)) 582 | R = np.linalg.det(R_corrupt) * 10 583 | 584 | return R_corrupt.flatten(), np.array([R]) 585 | 586 | def __len__(self): 587 | """Return the total number of images in the dataset.""" 588 | # Dataset is always randomly generated 589 | return int(1e6) 590 | 591 | 592 | class ShortestPath(data.Dataset): 593 | """Constructs a dataset with N circles, N is the number of components set by flags""" 594 | 595 | def __init__(self, split, rank=20, num_steps=10, vary=False): 596 | """Initialize this dataset class. 
597 | 598 | Parameters: 599 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 600 | """ 601 | self.h = rank 602 | self.w = rank 603 | self.rank = rank 604 | self.edge_prob = 0.3 605 | self.vary = vary 606 | 607 | self.inp_dim = self.h * self.w 608 | self.out_dim = self.h * self.w 609 | self.split = split 610 | self.num_steps = num_steps 611 | 612 | def __getitem__(self, index): 613 | """Return a data point and its metadata information. 614 | 615 | Parameters: 616 | index - - a random integer for data indexing 617 | 618 | Returns a dictionary that contains A and A_paths 619 | A(tensor) - - an image in one domain 620 | A_paths(str) - - the path of the image 621 | """ 622 | 623 | # if self.split == "train": 624 | # np.random.seed(index % 10000) 625 | 626 | if self.vary: 627 | rank = random.randint(2, self.rank) 628 | else: 629 | rank = self.rank 630 | 631 | graph = np.random.uniform(0, 1, size=[rank, rank]) 632 | graph = graph + graph.transpose() 633 | np.fill_diagonal(graph, 0) # can move to self 634 | 635 | graph_dist, graph_predecessors = shortest_path(csgraph=csr_matrix( 636 | graph), unweighted=False, directed=False, return_predecessors=True) 637 | 638 | return graph.flatten(), graph_dist.flatten() 639 | 640 | def __len__(self): 641 | """Return the total number of images in the dataset.""" 642 | # Dataset is always randomly generated 643 | return int(1e6) 644 | 645 | 646 | class Sort(data.Dataset): 647 | """Constructs a dataset with N circles, N is the number of components set by flags""" 648 | 649 | def __init__(self, split, rank=20, num_steps=10): 650 | """Initialize this dataset class. 
651 | 652 | Parameters: 653 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 654 | """ 655 | self.h = rank 656 | self.w = rank 657 | self.edge_prob = 0.3 658 | self.num_steps = num_steps 659 | 660 | self.inp_dim = self.h * self.w 661 | self.out_dim = self.h * self.w 662 | self.split = split 663 | 664 | def __getitem__(self, index): 665 | """Return a data point and its metadata information. 666 | 667 | Parameters: 668 | index - - a random integer for data indexing 669 | 670 | Returns a dictionary that contains A and A_paths 671 | A(tensor) - - an image in one domain 672 | A_paths(str) - - the path of the image 673 | """ 674 | 675 | x = np.random.uniform(-1, 1, self.inp_dim) 676 | rix = np.random.permutation(x.shape[0]) 677 | x = x[rix] 678 | x_sort = np.sort(x) 679 | 680 | return x, x_sort 681 | 682 | def __len__(self): 683 | """Return the total number of images in the dataset.""" 684 | # Dataset is always randomly generated 685 | return int(1e6) 686 | 687 | 688 | class Eigen(data.Dataset): 689 | """Constructs a dataset with N circles, N is the number of components set by flags""" 690 | 691 | def __init__(self, split, rank=20, num_steps=10): 692 | """Initialize this dataset class. 693 | 694 | Parameters: 695 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 696 | """ 697 | self.h = rank 698 | self.w = rank 699 | self.edge_prob = 0.3 700 | self.num_steps = num_steps 701 | 702 | self.inp_dim = self.h * self.w 703 | self.out_dim = self.h 704 | self.split = split 705 | 706 | def __getitem__(self, index): 707 | """Return a data point and its metadata information. 
708 | 709 | Parameters: 710 | index - - a random integer for data indexing 711 | 712 | Returns a dictionary that contains A and A_paths 713 | A(tensor) - - an image in one domain 714 | A_paths(str) - - the path of the image 715 | """ 716 | 717 | x = np.random.uniform(-1, 1, (self.h, self.w)) 718 | x = (x + x.transpose()) / 2 719 | x = np.matmul(x, x.transpose()) 720 | val, eig = np.linalg.eig(x) 721 | rix = np.argsort(np.abs(val))[::-1] 722 | 723 | val = val[rix] 724 | eig = eig[:, rix] 725 | 726 | vecs = [] 727 | vec = np.random.uniform(-1, 1, (self.h, 1)) 728 | vecs.append(vec) 729 | eig = eig[:, 0] 730 | sign = np.sign(eig[0]) 731 | eig = eig * sign 732 | 733 | return x.flatten(), eig 734 | 735 | def __len__(self): 736 | """Return the total number of images in the dataset.""" 737 | # Dataset is always randomly generated 738 | return int(1e6) 739 | 740 | 741 | class QR(data.Dataset): 742 | """Constructs a dataset with N circles, N is the number of components set by flags""" 743 | 744 | def __init__(self, split, rank=20, num_steps=10): 745 | """Initialize this dataset class. 746 | 747 | Parameters: 748 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 749 | """ 750 | self.h = rank 751 | self.w = rank 752 | self.edge_prob = 0.3 753 | 754 | self.inp_dim = self.h * self.w 755 | self.out_dim = self.h * self.w * 2 756 | self.split = split 757 | self.num_steps = num_steps 758 | 759 | def __getitem__(self, index): 760 | """Return a data point and its metadata information. 
761 | 762 | Parameters: 763 | index - - a random integer for data indexing 764 | 765 | Returns a dictionary that contains A and A_paths 766 | A(tensor) - - an image in one domain 767 | A_paths(str) - - the path of the image 768 | """ 769 | 770 | x = np.random.uniform(-1, 1, (self.h, self.w)) 771 | q, r = np.linalg.qr(x) 772 | flat = np.concatenate([q.flatten(), r.flatten()]) 773 | 774 | flat_list = np.tile(flat[None, :], (self.num_steps, 1)) 775 | flat_start = np.random.uniform(-1, 1, (*flat.shape)) 776 | flat_list = np.concatenate([flat_start[None, :], flat_list], axis=0) 777 | 778 | return x.flatten(), flat # , flat_list 779 | 780 | def __len__(self): 781 | """Return the total number of images in the dataset.""" 782 | # Dataset is always randomly generated 783 | return int(1e6) 784 | 785 | 786 | class Parity(data.Dataset): 787 | """Constructs a dataset with N circles, N is the number of components set by flags""" 788 | 789 | def __init__(self, split, rank=20, num_steps=10): 790 | """Initialize this dataset class. 791 | 792 | Parameters: 793 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 794 | """ 795 | self.h = rank 796 | self.w = rank 797 | self.edge_prob = 0.3 798 | 799 | self.inp_dim = rank 800 | self.out_dim = 1 801 | self.split = split 802 | self.num_steps = num_steps 803 | 804 | def __getitem__(self, index): 805 | """Return a data point and its metadata information. 
806 | 807 | Parameters: 808 | index - - a random integer for data indexing 809 | 810 | Returns a dictionary that contains A and A_paths 811 | A(tensor) - - an image in one domain 812 | A_paths(str) - - the path of the image 813 | """ 814 | 815 | val = np.random.uniform(0, 1, size=[self.inp_dim]) 816 | mask = (val > 0.5).astype(np.float32).flatten() 817 | parity = mask.sum() % 2 818 | parity = np.array([parity]) 819 | 820 | rand_val = np.random.uniform(0, 1) > 0.5 821 | parity_noise = np.array([rand_val]) 822 | 823 | return mask, parity 824 | 825 | def __len__(self): 826 | """Return the total number of images in the dataset.""" 827 | # Dataset is always randomly generated 828 | return int(1e6) 829 | 830 | 831 | if __name__ == "__main__": 832 | A = np.random.uniform(-1, 1, (5, 5)) 833 | A = 0.5 * (A + A.transpose()) 834 | A = np.matmul(A, A.transpose()) 835 | x = np.random.uniform(-1, 1, (5, 1)) 836 | b = np.random.uniform(-1, 1, (5, 1)) 837 | sol = conjgrad(A, b, x) 838 | import pdb 839 | pdb.set_trace() 840 | print(A) 841 | print(sol) 842 | print(b) 843 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import EBM, FC, IterativeFC, IterativeAttention, IterativeFCAttention, \ 3 | IterativeTransformer, EBMTwin, RecurrentFC, PonderFC 4 | import torch.nn.functional as F 5 | import os 6 | import pdb 7 | from dataset import LowRankDataset, ShortestPath, Negate, Inverse, Square, Identity, \ 8 | Det, LU, Sort, Eigen, QR, Equation, FiniteWrapper, Parity, Addition 9 | import matplotlib.pyplot as plt 10 | import torch.multiprocessing as mp 11 | import torch.distributed as dist 12 | from torch.optim import Adam 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torch.utils.data import DataLoader 15 | import os.path as osp 16 | import numpy as np 17 | from imageio import imwrite 18 | import argparse 19 | 
import torchvision.transforms as transforms 20 | import torch.nn as nn 21 | import torch.optim as optim 22 | import torch.backends.cudnn as cudnn 23 | import random 24 | from torchvision.utils import make_grid 25 | import seaborn as sns 26 | 27 | 28 | def worker_init_fn(worker_id): 29 | np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1)) 30 | 31 | 32 | class ReplayBuffer(object): 33 | def __init__(self, size): 34 | """Create Replay buffer. 35 | Parameters 36 | ---------- 37 | size: int 38 | Max number of transitions to store in the buffer. When the buffer 39 | overflows the old memories are dropped. 40 | """ 41 | self._storage = [] 42 | self._maxsize = size 43 | self._next_idx = 0 44 | 45 | def __len__(self): 46 | return len(self._storage) 47 | 48 | def add(self, inputs): 49 | batch_size = len(inputs) 50 | if self._next_idx >= len(self._storage): 51 | self._storage.extend(inputs) 52 | else: 53 | if batch_size + self._next_idx < self._maxsize: 54 | self._storage[self._next_idx:self._next_idx + 55 | batch_size] = inputs 56 | else: 57 | split_idx = self._maxsize - self._next_idx 58 | self._storage[self._next_idx:] = inputs[:split_idx] 59 | self._storage[:batch_size - split_idx] = inputs[split_idx:] 60 | self._next_idx = (self._next_idx + batch_size) % self._maxsize 61 | 62 | def _encode_sample(self, idxes): 63 | inps = [] 64 | opts = [] 65 | targets = [] 66 | scratchs = [] 67 | 68 | # Store in the intermediate state of optimization problem 69 | for i in idxes: 70 | inp, opt, target, scratch = self._storage[i] 71 | opt = opt 72 | inps.append(inp) 73 | opts.append(opt) 74 | targets.append(target) 75 | scratchs.append(scratch) 76 | 77 | inps = np.array(inps) 78 | opts = np.array(opts) 79 | targets = np.array(targets) 80 | scratchs = np.array(scratchs) 81 | 82 | return inps, opts, targets, scratchs 83 | 84 | def sample(self, batch_size): 85 | """Sample a batch of experiences. 
86 | Parameters 87 | ---------- 88 | batch_size: int 89 | How many stored solutions to sample (with replacement). 90 | Returns 91 | ------- 92 | inps: np.array 93 | batch of problem inputs 94 | opts: np.array 95 | batch of intermediate predictions saved from past optimization 96 | targets: np.array 97 | batch of ground-truth answers 98 | scratchs: np.array 99 | batch of scratchpad states 100 | idxes: torch.Tensor 101 | buffer indices of the sampled entries, returned as a second 102 | element alongside the tuple of the four arrays above 103 | """ 104 | idxes = [random.randint(0, len(self._storage) - 1) 105 | for _ in range(batch_size)] 106 | return self._encode_sample(idxes), torch.Tensor(idxes) 107 | 108 | def set_elms(self, data, idxes): 109 | if len(self._storage) < self._maxsize: 110 | self.add(data) 111 | else: 112 | for i, ix in enumerate(idxes): 113 | self._storage[ix] = data[i] 114 | 115 | 116 | """Parse input arguments""" 117 | parser = argparse.ArgumentParser(description='Train EBM model') 118 | 119 | parser.add_argument('--train', action='store_true', 120 | help='whether or not to train') 121 | parser.add_argument('--cuda', action='store_true', 122 | help='whether to use cuda or not') 123 | parser.add_argument('--no_replay_buffer', action='store_true', 124 | help='do not use a replay buffer to train models') 125 | parser.add_argument('--dataset', default='negate', type=str, 126 | help='dataset to evaluate') 127 | parser.add_argument('--logdir', default='cachedir', type=str, 128 | help='location where log of experiments will be stored') 129 | parser.add_argument('--exp', default='default', type=str, 130 | help='name of experiments') 131 | 132 | # training 133 | parser.add_argument('--resume_iter', default=0, type=int, 134 | help='iteration to resume training') 135 | parser.add_argument('--batch_size', default=512, type=int, 136 | help='size of batch of input to use') 137 | parser.add_argument('--num_epoch', default=10000, type=int, 138 | help='number of 
epochs of training to run') 139 | parser.add_argument('--lr', default=1e-4, type=float, 140 | help='learning rate for training') 141 | parser.add_argument('--log_interval', default=10, type=int, 142 | help='log outputs every so many batches') 143 | parser.add_argument('--save_interval', default=1000, type=int, 144 | help='save outputs every so many batches') 145 | 146 | # data 147 | parser.add_argument('--data_workers', default=4, type=int, 148 | help='Number of different data workers to load data in parallel') 149 | 150 | # Model specific settings 151 | parser.add_argument('--rank', default=20, type=int, 152 | help='rank of matrix to use') 153 | parser.add_argument('--num_steps', default=10, type=int, 154 | help='Steps of gradient descent for training') 155 | parser.add_argument('--step_lr', default=100.0, type=float, 156 | help='step size of latents') 157 | parser.add_argument('--ood', action='store_true', 158 | help='test on the harder ood dataset') 159 | parser.add_argument('--recurrent', action='store_true', 160 | help='utilize a recurrent model to output prediction') 161 | parser.add_argument('--ponder', action='store_true', 162 | help='utilize a ponder network model to output prediction') 163 | parser.add_argument('--decoder', action='store_true', 164 | help='utilize a decoder network to output prediction') 165 | parser.add_argument('--iterative_decoder', action='store_true', 166 | help='utilize an iterative decoder to output prediction') 167 | parser.add_argument('--mem', action='store_true', 168 | help='add external memory to compute answers') 169 | parser.add_argument('--no_truncate', action='store_true', 170 | help="don't truncate gradient backprop") 171 | 172 | # Distributed training hyperparameters 173 | parser.add_argument('--gpus', default=1, type=int, 174 | help='number of gpus to train with') 175 | parser.add_argument('--node_rank', default=0, type=int, help='rank of node') 176 | parser.add_argument('--capacity', default=50000, type=int, 177 | help='number of 
elements to generate') 178 | parser.add_argument('--infinite', action='store_true', 179 | help='makes the dataset have an infinite number of elements') 180 | 181 | 182 | best_test_error_10 = 10.0 183 | best_test_error_20 = 10.0 184 | best_test_error_40 = 10.0 185 | best_test_error_80 = 10.0 186 | best_test_error = 10.0 187 | 188 | 189 | def average_gradients(model): 190 | size = float(dist.get_world_size()) 191 | 192 | for name, param in model.named_parameters(): 193 | if param.grad is None: 194 | continue 195 | 196 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 197 | param.grad.data /= size 198 | 199 | 200 | def gen_answer(inp, FLAGS, model, pred, scratchpad, num_steps, create_graph=True): 201 | """ 202 | Implement iterative reasoning to obtain the answer to a problem 203 | """ 204 | 205 | # List of intermediate predictions 206 | preds = [] 207 | im_grads = [] 208 | energies = [] 209 | logits = [] 210 | 211 | if FLAGS.decoder: 212 | pred = model.forward(inp) 213 | preds = [pred] 214 | im_grad = torch.zeros(1) 215 | im_grads = [im_grad] 216 | energies = [torch.zeros(1)] 217 | elif FLAGS.recurrent: 218 | preds = [] 219 | im_grad = torch.zeros(1) 220 | im_grads = [im_grad] 221 | energies = [torch.zeros(1)] 222 | state = None 223 | for i in range(num_steps): 224 | pred, state = model.forward(inp, state) 225 | preds.append(pred) 226 | elif FLAGS.ponder: 227 | im_merge = torch.cat([pred, inp], dim=-1) 228 | 229 | preds, logits = model.forward(im_merge, iters=num_steps) 230 | pred = preds[-1] 231 | im_grad = torch.zeros(1) 232 | im_grads = [im_grad] 233 | energies = [torch.zeros(1)] 234 | state = None 235 | 236 | elif FLAGS.iterative_decoder: 237 | for i in range(num_steps): 238 | energy = torch.zeros(1) 239 | 240 | noise_add = (torch.rand_like(pred) - 0.5) 241 | out_dim = model.out_dim 242 | 243 | im_merge = torch.cat([pred, inp], dim=-1) 244 | pred = model.forward(im_merge) + pred 245 | 246 | preds.append(pred) 247 | energies.append(torch.zeros(1)) 248 | 
im_grads.append(torch.zeros(1)) 249 | else: 250 | with torch.enable_grad(): 251 | pred.requires_grad_(requires_grad=True) 252 | s = inp.size() 253 | scratchpad.requires_grad_(requires_grad=True) 254 | preds.append(pred) 255 | 256 | for i in range(num_steps): 257 | noise = torch.rand_like(pred) - 0.5 258 | 259 | if FLAGS.mem: 260 | im_merge = torch.cat([pred, inp, scratchpad], dim=-1) 261 | else: 262 | im_merge = torch.cat([pred, inp], dim=-1) 263 | 264 | energy = model.forward(im_merge) 265 | 266 | if FLAGS.mem: 267 | im_grad, scratchpad_grad = torch.autograd.grad( 268 | [energy.sum()], [pred, scratchpad], create_graph=create_graph) 269 | else: 270 | 271 | if FLAGS.no_truncate: 272 | im_grad, = torch.autograd.grad( 273 | [energy.sum()], [pred], create_graph=create_graph) 274 | else: 275 | if i != (num_steps - 1): 276 | im_grad, = torch.autograd.grad( 277 | [energy.sum()], [pred], create_graph=False) 278 | else: 279 | im_grad, = torch.autograd.grad( 280 | [energy.sum()], [pred], create_graph=create_graph) 281 | 282 | pred = pred - FLAGS.step_lr * im_grad 283 | 284 | if FLAGS.mem: 285 | scratchpad = scratchpad - FLAGS.step_lr * scratchpad_grad 286 | scratchpad = torch.clamp(scratchpad, -1, 1) 287 | 288 | preds.append(pred) 289 | energies.append(energy) 290 | im_grads.append(im_grad) 291 | 292 | return pred, preds, im_grads, energies, scratchpad, logits 293 | 294 | 295 | def ema_model(model, model_ema, mu=0.999): 296 | for (model, model_ema) in zip(model, model_ema): 297 | for param, param_ema in zip( 298 | model.parameters(), model_ema.parameters()): 299 | param_ema.data[:] = mu * param_ema.data + (1 - mu) * param.data 300 | 301 | 302 | def sync_model(model): 303 | size = float(dist.get_world_size()) 304 | 305 | for param in model.parameters(): 306 | dist.broadcast(param.data, 0) 307 | 308 | 309 | def init_model(FLAGS, device, dataset): 310 | if FLAGS.decoder: 311 | model = FC(dataset.inp_dim, dataset.out_dim) 312 | elif FLAGS.recurrent: 313 | model = 
RecurrentFC(dataset.inp_dim, dataset.out_dim) 314 | elif FLAGS.ponder: 315 | model = PonderFC(dataset.inp_dim, dataset.out_dim, FLAGS.num_steps) 316 | elif FLAGS.iterative_decoder: 317 | model = IterativeFC(dataset.inp_dim, dataset.out_dim, FLAGS.mem) 318 | else: 319 | model = EBM(dataset.inp_dim, dataset.out_dim, FLAGS.mem) 320 | 321 | model.to(device) 322 | optimizer = Adam(model.parameters(), lr=1e-4) 323 | 324 | return model, optimizer 325 | 326 | 327 | def safe_cumprod(t, eps=1e-10, dim=-1): 328 | t = torch.clip(t, min=eps, max=1.) 329 | return torch.exp(torch.cumsum(torch.log(t), dim=dim)) 330 | 331 | 332 | def exclusive_cumprod(t, dim=-1): 333 | cum_prod = safe_cumprod(t, dim=dim) 334 | return pad_to(cum_prod, (1, -1), value=1., dim=dim) 335 | 336 | 337 | def calc_geometric(l, dim=-1): 338 | return exclusive_cumprod(1 - l, dim=dim) * l 339 | 340 | 341 | def test(train_dataloader, model, FLAGS, step=0): 342 | global best_test_error_10, best_test_error_20, best_test_error_40, best_test_error_80, best_test_error 343 | if FLAGS.cuda: 344 | dev = torch.device("cuda") 345 | else: 346 | dev = torch.device("cpu") 347 | 348 | replay_buffer = None 349 | dist_list = [] 350 | energy_list = [] 351 | min_dist_list = [] 352 | 353 | model.eval() 354 | counter = 0 355 | 356 | with torch.no_grad(): 357 | for inp, im in train_dataloader: 358 | im = im.float().to(dev) 359 | inp = inp.float().to(dev) 360 | 361 | # Initialize prediction from random guess 362 | pred = (torch.rand_like(im) - 0.5) * 2 363 | scratch = torch.zeros_like(inp) 364 | 365 | pred_init = pred 366 | pred, preds, im_grad, energies, scratch, logits = gen_answer( 367 | inp, FLAGS, model, pred, scratch, 80) 368 | preds = torch.stack(preds, dim=0) 369 | 370 | if FLAGS.ponder: 371 | halting_probs = calc_geometric(logits.sigmoid(), dim=1)[..., 0] 372 | cum_halting_probs = torch.cumsum(halting_probs, dim=-1) 373 | rand_val = torch.rand( 374 | halting_probs.size(0)).to( 375 | cum_halting_probs.device) 376 | sort_id = 
torch.searchsorted( 377 | cum_halting_probs, rand_val[:, None]) 378 | sort_id = torch.clamp(sort_id, 0, sort_id.size(1) - 1) 379 | sort_id = sort_id[:, :, None].expand(-1, -1, pred.size(-1)) 380 | 381 | energies = torch.stack(energies, dim=0) 382 | 383 | dist = (preds - im[None, :]) 384 | dist = torch.pow(dist, 2) 385 | 386 | dist = dist.mean(dim=-1) 387 | n = dist.size(1) 388 | 389 | dist_energies = dist[1:, :] 390 | min_idx = energies[:, :, 0].argmin(dim=0)[None, :] 391 | dist_min_energy = torch.gather(dist_energies, 0, min_idx) 392 | min_dist_list.append(dist_min_energy.detach()) 393 | 394 | dist = dist.mean(dim=-1) 395 | 396 | energies = energies.mean(dim=-1).mean(dim=-1) 397 | dist_list.append(dist.detach()) 398 | energy_list.append(energies.detach()) 399 | 400 | counter = counter + 1 401 | 402 | if counter > 10: 403 | dist_list = torch.stack(dist_list, dim=0) 404 | dist = dist_list.mean(dim=0) 405 | energy_list = torch.stack(energy_list, dim=0) 406 | energies = energy_list.mean(dim=0) 407 | min_dist = torch.stack(min_dist_list, dim=0).mean() 408 | 409 | print("Testing..................") 410 | print("step errors: ", dist[:20]) 411 | print("energy values: ", energies) 412 | print('test at step %d done!' 
% step) 413 | break 414 | 415 | if FLAGS.decoder or FLAGS.ponder: 416 | best_test_error_10 = min(best_test_error_10, dist[0].item()) 417 | best_test_error_20 = min(best_test_error_20, dist[0].item()) 418 | best_test_error_40 = min(best_test_error_40, dist[0].item()) 419 | best_test_error_80 = min(best_test_error_80, dist[0].item()) 420 | best_test_error = min(best_test_error, dist[0].item()) 421 | else: 422 | best_test_error_10 = min(best_test_error_10, dist[9].item()) 423 | best_test_error_20 = min(best_test_error_20, dist[19].item()) 424 | best_test_error_40 = min(best_test_error_40, dist[39].item()) 425 | best_test_error_80 = min(best_test_error_80, dist[79].item()) 426 | best_test_error = min(best_test_error, min_dist.item()) 427 | 428 | print("best test error (10, 20, 40, 80, min_energy): {} {} {} {} {}".format( 429 | best_test_error_10, best_test_error_20, best_test_error_40, 430 | best_test_error_80, best_test_error)) 431 | 432 | model.train() 433 | 434 | 435 | def train(train_dataloader, test_dataloader, logger, model, 436 | optimizer, FLAGS, logdir, rank_idx): 437 | 438 | it = FLAGS.resume_iter 439 | optimizer.zero_grad() 440 | dev = torch.device("cuda") 441 | 442 | # initialize a replay buffer of solutions 443 | replay_buffer = ReplayBuffer(10000) 444 | 445 | for epoch in range(FLAGS.num_epoch): 446 | for inp, im in train_dataloader: 447 | im = im.float().to(dev) 448 | inp = inp.float().to(dev) 449 | 450 | # Initialize a solution from a random guess 451 | pred = (torch.rand_like(im) - 0.5) * 2 452 | 453 | # Sample a proportion of samples from past optimization results 454 | if FLAGS.replay_buffer and len(replay_buffer) >= FLAGS.batch_size: 455 | replay_batch, _ = replay_buffer.sample(im.size(0)) 456 | inp_replay, opt_replay, gt_replay, scratch_replay = replay_batch 457 | 458 | replay_mask = np.concatenate([np.ones(im.size(0)), np.zeros(im.size(0))]).astype(bool) 459 | inp = torch.cat([torch.Tensor(inp_replay).cuda(), inp], dim=0) 460 | pred = 
torch.cat([torch.Tensor(opt_replay).cuda(), pred], dim=0) 461 | im = torch.cat([torch.Tensor(gt_replay).cuda(), im], dim=0) 462 | else: 463 | replay_mask = ( 464 | np.random.uniform( 465 | 0, 466 | 1, 467 | im.size(0)) > 1.0) 468 | 469 | scratch = torch.zeros_like(inp) 470 | 471 | num_steps = FLAGS.num_steps 472 | pred, preds, im_grads, energies, scratch, logits = gen_answer( 473 | inp, FLAGS, model, pred, scratch, num_steps) 474 | energies = torch.stack(energies, dim=0) 475 | preds = torch.stack(preds, dim=1) 476 | 477 | im_grads = torch.stack(im_grads, dim=1) 478 | 479 | if FLAGS.ponder: 480 | geometric_dist = calc_geometric(torch.full( 481 | (FLAGS.num_steps,), 1 / FLAGS.num_steps, device=dev)) 482 | halting_probs = calc_geometric(logits.sigmoid(), dim=1)[..., 0] 483 | 484 | if FLAGS.decoder: 485 | im_loss = torch.pow( 486 | preds[:, -1:] - im[:, None, :], 2).mean(dim=-1).mean(dim=-1) 487 | elif FLAGS.ponder: 488 | halting_probs = halting_probs / \ 489 | halting_probs.sum(dim=1)[:, None] 490 | im_loss = (torch.pow( 491 | preds[:, :] - im[:, None, :], 2)).mean(dim=-1).mean(dim=-1) 492 | else: 493 | im_loss = torch.pow( 494 | preds[:, -1:] - im[:, None, :], 2).mean(dim=-1).mean(dim=-1) 495 | 496 | loss = im_loss.mean() 497 | 498 | if FLAGS.ponder: 499 | ponder_loss = 0.01 * \ 500 | F.kl_div(torch.log(geometric_dist[None, :] + 1e-10), halting_probs, None, None, 'batchmean') 501 | loss = loss + ponder_loss 502 | 503 | loss.backward() 504 | 505 | if FLAGS.replay_buffer: 506 | inp_replay = inp.cpu().detach().numpy() 507 | pred_replay = pred.cpu().detach().numpy() 508 | im_replay = im.cpu().detach().numpy() 509 | scratch = scratch.cpu().detach().numpy() 510 | encode_tuple = list(zip(list(inp_replay), list( 511 | pred_replay), list(im_replay), list(scratch))) 512 | 513 | replay_buffer.add(encode_tuple) 514 | 515 | if FLAGS.gpus > 1: 516 | average_gradients(model) 517 | 518 | optimizer.step() 519 | optimizer.zero_grad() 520 | 521 | if it > 10000: 522 | assert False 523 | 
524 | if it % FLAGS.log_interval == 0 and rank_idx == 0: 525 | loss = loss.item() 526 | kvs = {} 527 | kvs['im_loss'] = im_loss.mean().item() 528 | 529 | if it > 10: 530 | replay_mask = replay_mask 531 | no_replay_mask = ~replay_mask 532 | kvs['no_replay_loss'] = im_loss[no_replay_mask].mean().item() 533 | kvs['replay_loss'] = im_loss[replay_mask].mean().item() 534 | 535 | if FLAGS.ponder: 536 | kvs['ponder_loss'] = ponder_loss.item() 537 | 538 | if (not FLAGS.iterative_decoder) and (not FLAGS.decoder) and ( 539 | not FLAGS.recurrent) and (not FLAGS.ponder): 540 | kvs['energy_no_replay'] = energies[-1, 541 | no_replay_mask].mean().item() 542 | kvs['energy_replay'] = energies[-1, 543 | replay_mask].mean().item() 544 | 545 | kvs['energy_start_no_replay'] = energies[0, 546 | no_replay_mask].mean().item() 547 | kvs['energy_start_replay'] = energies[0, 548 | replay_mask].mean().item() 549 | 550 | mean_last_dist = torch.abs(pred - im).mean() 551 | kvs['mean_last_dist'] = mean_last_dist.item() 552 | 553 | string = "Iteration {} ".format(it) 554 | 555 | for k, v in kvs.items(): 556 | string += "%s: %.6f " % (k, v) 557 | logger.add_scalar(k, v, it) 558 | 559 | print(string) 560 | 561 | if it % FLAGS.save_interval == 0 and rank_idx == 0: 562 | model_path = osp.join(logdir, "model_latest.pth") 563 | ckpt = {'FLAGS': FLAGS} 564 | 565 | ckpt['model_state_dict'] = model.state_dict() 566 | ckpt['optimizer_state_dict'] = optimizer.state_dict() 567 | 568 | torch.save(ckpt, model_path) 569 | 570 | test(test_dataloader, model, FLAGS, step=it) 571 | 572 | it += 1 573 | 574 | 575 | def main_single(rank, FLAGS): 576 | rank_idx = rank 577 | world_size = FLAGS.gpus 578 | logdir = osp.join(FLAGS.logdir, FLAGS.exp) 579 | 580 | if not os.path.exists('result/%s' % FLAGS.exp): 581 | try: 582 | os.makedirs('result/%s' % FLAGS.exp) 583 | except BaseException: 584 | pass 585 | 586 | if not os.path.exists(logdir): 587 | try: 588 | os.makedirs(logdir) 589 | except BaseException: 590 | pass
591 | 592 | # Load Dataset 593 | if FLAGS.dataset == 'lowrank': 594 | dataset = LowRankDataset('train', FLAGS.rank, FLAGS.ood) 595 | test_dataset = LowRankDataset('test', FLAGS.rank, FLAGS.ood) 596 | elif FLAGS.dataset == 'shortestpath': 597 | dataset = ShortestPath('train', FLAGS.rank, FLAGS.num_steps) 598 | test_dataset = ShortestPath('test', FLAGS.rank, FLAGS.num_steps) 599 | elif FLAGS.dataset == 'negate': 600 | dataset = Negate('train', FLAGS.rank) 601 | test_dataset = Negate('test', FLAGS.rank) 602 | elif FLAGS.dataset == 'addition': 603 | dataset = Addition('train', FLAGS.rank, FLAGS.ood) 604 | test_dataset = Addition('test', FLAGS.rank, FLAGS.ood) 605 | elif FLAGS.dataset == 'inverse': 606 | dataset = Inverse('train', FLAGS.rank, FLAGS.ood) 607 | test_dataset = Inverse('test', FLAGS.rank, FLAGS.ood) 608 | elif FLAGS.dataset == 'square': 609 | dataset = Square('train', FLAGS.rank, FLAGS.num_steps) 610 | test_dataset = Square('test', FLAGS.rank, FLAGS.num_steps) 611 | elif FLAGS.dataset == 'identity': 612 | dataset = Identity('train', FLAGS.rank, FLAGS.num_steps) 613 | test_dataset = Identity('test', FLAGS.rank, FLAGS.num_steps) 614 | elif FLAGS.dataset == 'det': 615 | dataset = Det('train', FLAGS.rank) 616 | test_dataset = Det('test', FLAGS.rank) 617 | elif FLAGS.dataset == 'lu': 618 | dataset = LU('train', FLAGS.rank) 619 | test_dataset = LU('test', FLAGS.rank) 620 | elif FLAGS.dataset == 'sort': 621 | dataset = Sort('train', FLAGS.rank, FLAGS.num_steps) 622 | test_dataset = Sort('test', FLAGS.rank, FLAGS.num_steps) 623 | elif FLAGS.dataset == 'eigen': 624 | dataset = Eigen('train', FLAGS.rank, FLAGS.num_steps) 625 | test_dataset = Eigen('test', FLAGS.rank, FLAGS.num_steps) 626 | elif FLAGS.dataset == 'equation': 627 | dataset = Equation('train', FLAGS.rank, FLAGS.num_steps) 628 | test_dataset = Equation('test', FLAGS.rank, FLAGS.num_steps) 629 | elif FLAGS.dataset == 'qr': 630 | dataset = QR('train', FLAGS.rank, FLAGS.num_steps) 631 | test_dataset = 
QR('test', FLAGS.rank, FLAGS.num_steps) 632 | elif FLAGS.dataset == 'parity': 633 | dataset = Parity('train', FLAGS.rank, FLAGS.num_steps) 634 | test_dataset = Parity('test', FLAGS.rank, FLAGS.num_steps) 635 | 636 | if not FLAGS.infinite: 637 | dataset = FiniteWrapper( 638 | dataset, 639 | FLAGS.dataset, 640 | FLAGS.capacity, 641 | FLAGS.rank, 642 | FLAGS.num_steps) 643 | 644 | shuffle = True 645 | sampler = None 646 | 647 | if world_size > 1: 648 | group = dist.init_process_group( 649 | backend='nccl', 650 | init_method='tcp://localhost:8113', 651 | world_size=world_size, 652 | rank=rank_idx, 653 | group_name="default") 654 | 655 | torch.cuda.set_device(rank) 656 | device = torch.device('cuda') 657 | 658 | FLAGS_OLD = FLAGS 659 | 660 | # Load model and key arguments 661 | if FLAGS.resume_iter != 0: 662 | model_path = osp.join( 663 | logdir, "model_latest.pth".format( 664 | FLAGS.resume_iter)) 665 | 666 | checkpoint = torch.load(model_path, map_location=torch.device('cpu')) 667 | FLAGS = checkpoint['FLAGS'] 668 | 669 | FLAGS.resume_iter = FLAGS_OLD.resume_iter 670 | FLAGS.save_interval = FLAGS_OLD.save_interval 671 | FLAGS.gpus = FLAGS_OLD.gpus 672 | FLAGS.train = FLAGS_OLD.train 673 | FLAGS.batch_size = FLAGS_OLD.batch_size 674 | FLAGS.step_lr = FLAGS_OLD.step_lr 675 | FLAGS.num_steps = FLAGS_OLD.num_steps 676 | FLAGS.exp = FLAGS_OLD.exp 677 | FLAGS.ponder = FLAGS_OLD.ponder 678 | FLAGS.heatmap = FLAGS_OLD.heatmap 679 | 680 | model, optimizer = init_model(FLAGS, device, dataset) 681 | state_dict = model.state_dict() 682 | 683 | model.load_state_dict(checkpoint['model_state_dict'], strict=False) 684 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 685 | else: 686 | model, optimizer = init_model(FLAGS, device, dataset) 687 | 688 | if FLAGS.gpus > 1: 689 | sync_model(model) 690 | 691 | print("num_parameters: ", sum([p.numel() for p in model.parameters()])) 692 | 693 | train_dataloader = DataLoader( 694 | dataset, 695 | num_workers=FLAGS.data_workers, 
696 | batch_size=FLAGS.batch_size, 697 | shuffle=shuffle, 698 | pin_memory=False, 699 | worker_init_fn=worker_init_fn) 700 | test_dataloader = DataLoader( 701 | test_dataset, 702 | num_workers=FLAGS.data_workers, 703 | batch_size=FLAGS.batch_size, 704 | shuffle=True, 705 | pin_memory=False, 706 | drop_last=True, 707 | worker_init_fn=worker_init_fn) 708 | 709 | logger = SummaryWriter(logdir) 710 | it = FLAGS.resume_iter 711 | 712 | if FLAGS.train: 713 | model.train() 714 | else: 715 | model.eval() 716 | 717 | if FLAGS.train: 718 | train( 719 | train_dataloader, 720 | test_dataloader, 721 | logger, 722 | model, 723 | optimizer, 724 | FLAGS, 725 | logdir, 726 | rank_idx) 727 | else: 728 | test(test_dataloader, model, FLAGS, step=FLAGS.resume_iter) 729 | 730 | 731 | def main(): 732 | FLAGS = parser.parse_args() 733 | FLAGS.replay_buffer = not FLAGS.no_replay_buffer 734 | logdir = osp.join(FLAGS.logdir, FLAGS.exp) 735 | 736 | if FLAGS.recurrent: 737 | FLAGS.no_replay_buffer = True 738 | 739 | if FLAGS.decoder: 740 | FLAGS.no_replay_buffer = True 741 | 742 | if not osp.exists(logdir): 743 | os.makedirs(logdir) 744 | 745 | if FLAGS.gpus > 1: 746 | mp.spawn(main_single, nprocs=FLAGS.gpus, args=(FLAGS,)) 747 | else: 748 | main_single(0, FLAGS) 749 | 750 | 751 | if __name__ == "__main__": 752 | try: 753 | torch.multiprocessing.set_start_method('spawn') 754 | except BaseException: 755 | pass 756 | 757 | main() 758 | --------------------------------------------------------------------------------
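In `main()` above, `FLAGS.replay_buffer` is derived from `FLAGS.no_replay_buffer` before the `recurrent` and `decoder` branches set `no_replay_buffer = True`, so as written those branches do not change `FLAGS.replay_buffer`. The following is a minimal sketch, assuming the intent is for `--recurrent` and `--decoder` to disable the replay buffer. The parser is a hypothetical, trimmed-down stand-in (the real argument definitions are not shown in this part of the file), and `build_parser` and `resolve_replay_buffer` are illustrative names, not functions from the repository.

```python
import argparse


def build_parser():
    # Hypothetical stand-in parser with only the three flags discussed here.
    # Assumes all three are boolean store_true options.
    parser = argparse.ArgumentParser()
    parser.add_argument("--no_replay_buffer", action="store_true")
    parser.add_argument("--recurrent", action="store_true")
    parser.add_argument("--decoder", action="store_true")
    return parser


def resolve_replay_buffer(flags):
    # Apply the mode-specific overrides first ...
    if flags.recurrent or flags.decoder:
        flags.no_replay_buffer = True
    # ... and only then derive the positive flag, so the overrides take effect.
    flags.replay_buffer = not flags.no_replay_buffer
    return flags


if __name__ == "__main__":
    for argv in ([], ["--recurrent"], ["--decoder"], ["--no_replay_buffer"]):
        flags = resolve_replay_buffer(build_parser().parse_args(argv))
        print(argv, "-> replay_buffer =", flags.replay_buffer)
```

Whether the released experiments relied on the original ordering is not something this listing settles, so the reordering is best treated as an option to evaluate rather than a drop-in change.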