├── teaser.mp4
├── architecture.png
├── requirements.txt
├── LICENSE.md
├── README.md
├── graph_models.py
├── graph_dataset.py
├── models.py
├── graph_train.py
├── dataset.py
└── train.py
/teaser.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yilundu/irem_code_release/HEAD/teaser.mp4
--------------------------------------------------------------------------------
/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yilundu/irem_code_release/HEAD/architecture.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | easydict==1.9
2 | imageio==2.8.0
3 | ipython==8.2.0
4 | matplotlib==3.2.2
5 | numpy==1.21.2
6 | opencv_python==4.5.3.56
7 | Pillow==9.1.0
8 | scikit_image==0.16.2
9 | scipy==1.7.3
10 | seaborn==0.11.0
11 | torch==1.2.0
12 | torch_geometric==2.0.4
13 | torchvision==0.2.2
14 | tqdm==4.60.0
15 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Iterative Reasoning through Energy Minimization Authors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Iterative Reasoning through Energy Minimization
2 |
3 | Yilun Du, Shuang Li, Joshua B. Tenenbaum, Igor Mordatch
4 |
5 | A link to our paper can be found on [arXiv](https://arxiv.org/abs/2206.15448).
6 |
7 | ## Overview
8 |
9 | Official codebase for [Learning Iterative Reasoning through Energy Minimization](https://arxiv.org/abs/2206.15448).
10 | Contains scripts to reproduce experiments.
11 |
12 |
13 |
14 |
15 |
16 | ## Instructions
17 |
18 | Please install the dependencies listed in `requirements.txt`.
19 |
20 | ```
21 | pip install -r requirements.txt
22 | ```
23 |
24 | We provide code for running continuous graph reasoning experiments in `graph_train.py` and code for running continuous matrix experiments in `train.py`.
25 |
26 |
27 | To run continuous matrix addition experiments, you may use the following command:
28 |
29 | ```
30 | python train.py --exp=addition_experiment --train --num_steps=10 --dataset=addition --cuda --infinite
31 | ```
32 |
33 | To evaluate the final performance of the model after training, you may use the command:
34 |
35 | ```
36 | python train.py --exp=addition_experiment --num_steps=10 --dataset=addition --cuda --infinite --resume_iter=10000
37 | ```
38 |
39 | To evaluate on the out-of-distribution (OOD) test set, use the following command:
40 |
41 | ```
42 | python train.py --exp=addition_experiment --num_steps=10 --dataset=addition --cuda --infinite --resume_iter=10000 --ood
43 | ```
44 |
45 | You may set the `--dataset` flag to other values such as `inverse` or `lowrank` (additional datasets are defined in `dataset.py`).
46 |
47 |
48 |
49 | To run discrete graph reasoning experiments you may utilize the following command:
50 |
51 | ```
52 | python graph_train.py --exp=identity_experiment --train --num_steps=10 --dataset=identity --cuda --infinite
53 | ```
54 |
55 | You may set the `--dataset` flag to other datasets such as `shortestpath` or `connected` (additional datasets are defined in `graph_dataset.py`).
56 |
57 |
58 | To evaluate the model, you may then utilize the following command:
59 |
60 | ```
61 | python graph_train.py --exp=identity_experiment --num_steps=10 --dataset=identity --cuda --infinite --resume_iter=10000
62 | ```
63 |
64 | ## Citation
65 |
66 | Please cite our paper as:
67 |
68 | ```
69 | @inproceedings{du2022irem,
70 | title={Learning Iterative Reasoning through Energy Minimization},
71 | author={Yilun Du and Shuang Li and Joshua B. Tenenbaum and Igor Mordatch},
72 | booktitle={Proceedings of the 39th International Conference on Machine
73 | Learning (ICML-22)},
74 | year={2022}
75 | }
76 | ```
77 |
78 | Note: this is not an official Google product.
79 |
80 | ## License
81 |
82 | MIT
83 |
--------------------------------------------------------------------------------
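The training, evaluation, and OOD commands in the README differ only in a few flags. As a minimal sketch, the hypothetical helper `build_command` below (it is not part of the repository) assembles those command lines using only the flags that appear in the README. The default `num_steps=10` and the example `resume_iter=10000` simply copy the README examples and are not recommendations.

```python
def build_command(script, exp, dataset, train=False, resume_iter=None,
                  ood=False, num_steps=10):
    """Assemble a README-style command line (hypothetical helper)."""
    args = ["python", script, f"--exp={exp}"]
    if train:
        args.append("--train")
    args += [f"--num_steps={num_steps}", f"--dataset={dataset}",
             "--cuda", "--infinite"]
    if resume_iter is not None:
        # Evaluation resumes from a saved checkpoint iteration.
        args.append(f"--resume_iter={resume_iter}")
    if ood:
        # Evaluate on the out-of-distribution test set.
        args.append("--ood")
    return " ".join(args)


train_cmd = build_command("train.py", "addition_experiment", "addition", train=True)
eval_cmd = build_command("train.py", "addition_experiment", "addition",
                         resume_iter=10000)
ood_cmd = build_command("train.py", "addition_experiment", "addition",
                        resume_iter=10000, ood=True)
print(train_cmd)
print(eval_cmd)
print(ood_cmd)
```

Passing `graph_train.py` with a graph dataset name such as `identity` produces the corresponding graph commands in the same way.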
/graph_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_geometric.nn import GINEConv, global_max_pool
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class GraphEBM(nn.Module):
8 | def __init__(self, inp_dim, out_dim, mem):
9 | super(GraphEBM, self).__init__()
10 | h = 128
11 |
12 | self.edge_map = nn.Linear(1, h // 2)
13 | self.edge_map_opt = nn.Linear(1, h // 2)
14 |
15 | self.conv1 = GINEConv(nn.Sequential(nn.Linear(1, 128)), edge_dim=h)
16 | self.conv2 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h)
17 | self.conv3 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h)
18 |
19 | self.fc1 = nn.Linear(128, 128)
20 | self.fc2 = nn.Linear(128, 1)
21 |
22 | def forward(self, inp, opt_edge):
23 |
24 | edge_embed = self.edge_map(inp.edge_attr)
25 | opt_edge_embed = self.edge_map_opt(opt_edge)
26 |
27 | edge_embed = torch.cat([edge_embed, opt_edge_embed], dim=-1)
28 |
29 | h = self.conv1(inp.x, inp.edge_index, edge_attr=edge_embed)
30 | h = self.conv2(h, inp.edge_index, edge_attr=edge_embed)
31 | h = self.conv3(h, inp.edge_index, edge_attr=edge_embed)
32 |
33 |         pooled_feat = global_max_pool(h, inp.batch)  # max-pool node features per graph
34 |         energy = self.fc2(F.relu(self.fc1(pooled_feat)))
35 |
36 | return energy
37 |
38 |
39 | class GraphIterative(nn.Module):
40 | def __init__(self, inp_dim, out_dim, mem):
41 | super(GraphIterative, self).__init__()
42 | h = 128
43 |
44 | self.edge_map = nn.Linear(1, h // 2)
45 | self.edge_map_opt = nn.Linear(1, h // 2)
46 |
47 | self.conv1 = GINEConv(nn.Sequential(nn.Linear(1, 128)), edge_dim=h)
48 | self.conv2 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h)
49 | self.conv3 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h)
50 |
51 | self.decode = nn.Linear(256, 1)
52 |
53 | def forward(self, inp, opt_edge):
54 |
55 | edge_embed = self.edge_map(inp.edge_attr)
56 | opt_edge_embed = self.edge_map_opt(opt_edge)
57 |
58 | edge_embed = torch.cat([edge_embed, opt_edge_embed], dim=-1)
59 |
60 | h = self.conv1(inp.x, inp.edge_index, edge_attr=edge_embed)
61 | h = self.conv2(h, inp.edge_index, edge_attr=edge_embed)
62 | h = self.conv3(h, inp.edge_index, edge_attr=edge_embed)
63 |
64 | edge_index = inp.edge_index
65 | hidden = h
66 |
67 | h1 = hidden[edge_index[0]]
68 | h2 = hidden[edge_index[1]]
69 |
70 | h = torch.cat([h1, h2], dim=-1)
71 | output = self.decode(h)
72 |
73 | return output
74 |
75 |
76 | class GraphRecurrent(nn.Module):
77 | def __init__(self, inp_dim, out_dim):
78 | super(GraphRecurrent, self).__init__()
79 | h = 128
80 |
81 | self.edge_map = nn.Linear(1, h)
82 |
83 | self.conv1 = GINEConv(nn.Sequential(nn.Linear(1, 128)), edge_dim=h)
84 |
85 | self.lstm = nn.LSTM(h, h)
86 | # self.conv2 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h)
87 | self.conv3 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h)
88 |
89 | self.decode = nn.Linear(256, 1)
90 |
91 | def forward(self, inp, state=None):
92 |
93 | edge_embed = self.edge_map(inp.edge_attr)
94 |
95 | h = self.conv1(inp.x, inp.edge_index, edge_attr=edge_embed)
96 |
97 | h, state = self.lstm(h[None], state)
98 | h = self.conv3(h[0], inp.edge_index, edge_attr=edge_embed)
99 |
100 | edge_index = inp.edge_index
101 | hidden = h
102 |
103 | h1 = hidden[edge_index[0]]
104 | h2 = hidden[edge_index[1]]
105 |
106 | h = torch.cat([h1, h2], dim=-1)
107 | output = self.decode(h)
108 |
109 | return output, state
110 |
111 |
112 | class GraphPonder(nn.Module):
113 | def __init__(self, inp_dim, out_dim):
114 | super(GraphPonder, self).__init__()
115 | h = 128
116 |
117 | self.edge_map = nn.Linear(1, h)
118 |
119 | self.conv1 = GINEConv(nn.Sequential(nn.Linear(1, 128)), edge_dim=h)
120 | self.conv2 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h)
121 | self.conv3 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h)
122 |
123 | self.decode = nn.Linear(256, 1)
124 |
125 | def forward(self, inp, iters=1):
126 |
127 | edge_embed = self.edge_map(inp.edge_attr)
128 |
129 | h = self.conv1(inp.x, inp.edge_index, edge_attr=edge_embed)
130 | outputs = []
131 |
132 | for i in range(iters):
133 | h = F.relu(self.conv2(h, inp.edge_index, edge_attr=edge_embed))
134 |
135 | output = self.conv3(h, inp.edge_index, edge_attr=edge_embed)
136 |
137 | edge_index = inp.edge_index
138 | hidden = output
139 |
140 | h1 = hidden[edge_index[0]]
141 | h2 = hidden[edge_index[1]]
142 |
143 | output = torch.cat([h1, h2], dim=-1)
144 | output = self.decode(output)
145 |
146 | outputs.append(output)
147 |
148 | return outputs
149 |
150 |
151 |
152 | class GraphFC(nn.Module):
153 | def __init__(self, inp_dim, out_dim):
154 | super(GraphFC, self).__init__()
155 | h = 128
156 |
157 | self.edge_map = nn.Linear(1, h)
158 | self.conv1 = GINEConv(nn.Sequential(nn.Linear(1, 128)), edge_dim=h)
159 | self.conv2 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h)
160 | self.conv3 = GINEConv(nn.Sequential(nn.Linear(128, 128)), edge_dim=h)
161 |
162 | self.decode = nn.Linear(256, 1)
163 |
164 |
165 | def forward(self, inp):
166 |
167 | edge_embed = self.edge_map(inp.edge_attr)
168 |
169 | h = self.conv1(inp.x, inp.edge_index, edge_attr=edge_embed)
170 | h = self.conv2(h, inp.edge_index, edge_attr=edge_embed)
171 | h = self.conv3(h, inp.edge_index, edge_attr=edge_embed)
172 |
173 | edge_index = inp.edge_index
174 | hidden = h
175 |
176 | h1 = hidden[edge_index[0]]
177 | h2 = hidden[edge_index[1]]
178 |
179 | h = torch.cat([h1, h2], dim=-1)
180 | output = self.decode(h)
181 |
182 | return output
183 |
--------------------------------------------------------------------------------
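`GraphEBM` maps a graph and a candidate edge assignment `opt_edge` to a scalar energy. The training and inference loops live in `graph_train.py`, which is not shown in this section, so the following is only a toy sketch of the general idea of refining an answer by gradient descent on an energy. The quadratic `energy` function, the step size, and the step count are illustrative assumptions, not the repository's learned model or settings.

```python
def energy(x, y):
    """Toy energy: squared error between candidate y and the sum of the inputs."""
    return (y - sum(x)) ** 2


def energy_grad(x, y):
    """Analytic gradient of the toy energy with respect to the candidate y."""
    return 2.0 * (y - sum(x))


def minimize_energy(x, y_init=0.0, step_size=0.25, num_steps=10):
    """Refine y with plain gradient descent; return every intermediate candidate."""
    y = y_init
    trajectory = [y]
    for _ in range(num_steps):
        # Move the candidate answer downhill on the energy landscape.
        y = y - step_size * energy_grad(x, y)
        trajectory.append(y)
    return trajectory


inputs = [1.5, 2.0]
trajectory = minimize_energy(inputs)
energies = [energy(inputs, y) for y in trajectory]
print(trajectory[-1], energies[-1])
```

In this toy setting each step halves the remaining error, so running more steps yields a lower energy and a candidate closer to the minimizer.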
/graph_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from torch_geometric.data import Data
3 | import torch
4 | import numpy as np
5 | from scipy.sparse.csgraph import connected_components
6 | from scipy.sparse import csr_matrix
7 | from scipy.sparse.csgraph import shortest_path
8 | import random
9 |
10 |
11 | def extract(a, t, x_shape):
12 | b, *_ = t.shape
13 | out = a.gather(-1, t)
14 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
15 |
16 |
17 | def cosine_beta_schedule(timesteps, s = 0.008):
18 | """
19 | cosine schedule
20 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
21 | """
22 | steps = timesteps + 1
23 | x = np.linspace(0, steps, steps)
24 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
25 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
26 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
27 | return np.clip(betas, a_min = 0, a_max = 0.999)
28 |
29 |
30 | class NoisyWrapper:
31 |
32 | def __init__(self, dataset, timesteps):
33 |
34 | self.dataset = dataset
35 | self.timesteps = timesteps
36 | betas = cosine_beta_schedule(timesteps)
37 | self.inp_dim = dataset.inp_dim
38 | self.out_dim = dataset.out_dim
39 |
40 | alphas = 1. - betas
41 | alphas_cumprod = np.cumprod(alphas, axis=0)
42 |
43 | # self.sqrt_alphas_cumprod = torch.tensor(np.sqrt(alphas_cumprod))
44 | # self.sqrt_one_minus_alphas_cumprod = torch.tensor(np.sqrt(1. - alphas_cumprod))
45 |
46 | alphas_cumprod = np.linspace(1, 0, timesteps)
47 | self.sqrt_alphas_cumprod = torch.tensor(np.sqrt(alphas_cumprod))
48 | self.sqrt_one_minus_alphas_cumprod = torch.tensor(np.sqrt(1. - alphas_cumprod))
49 | self.extract = extract
50 |
51 | def __len__(self):
52 | return len(self.dataset)
53 |
54 | def __getitem__(self, *args, **kwargs):
55 | data = self.dataset.__getitem__(*args, **kwargs)
56 | y = data['y']
57 |
58 | t = torch.randint(1, self.timesteps, (1,)).long()
59 | t_next = t - 1
60 | noise = torch.randn_like(y)
61 |
62 | sample = (
63 | self.extract(self.sqrt_alphas_cumprod, t, y.shape) * y +
64 | self.extract(self.sqrt_one_minus_alphas_cumprod, t, y.shape) * noise
65 | )
66 |
67 | sample_next = (
68 | self.extract(self.sqrt_alphas_cumprod, t_next, y.shape) * y +
69 | self.extract(self.sqrt_one_minus_alphas_cumprod, t_next, y.shape) * noise
70 | )
71 |
72 | data['y_prev'] = sample.float()
73 | data['y'] = sample_next.float()
74 |
75 | return data
76 |
77 |
78 | class Identity(data.Dataset):
79 |     """Random fully connected graphs whose target edge values equal the input edge values."""
80 |
81 |     def __init__(self, split, rank, vary=False):
82 |         """Initialize the dataset.
83 |
84 |         split (str) -- split name; stored, but not used by the sampling code
85 |         rank (int) -- number of nodes (the maximum number when vary is True)
86 |         vary (bool) -- if True, sample the node count from [2, rank] per item"""
87 | self.h = rank
88 | self.w = rank
89 | self.vary = vary
90 | self.rank = rank
91 |
92 | self.split = split
93 | self.inp_dim = self.h * self.w
94 | self.out_dim = self.h * self.w
95 |
96 | def __getitem__(self, index):
97 |         """Sample a random fully connected graph for the identity task.
98 |
99 |         Parameters:
100 |             index -- ignored; every call draws a fresh random sample
101 |
102 |         Returns a torch_geometric Data object with
103 |             edge_attr -- input edge values drawn uniformly from [-1, 1)
104 |             y -- target edge values, identical to edge_attr
105 |         """
106 |
107 | if self.vary:
108 | rank = random.randint(2, self.rank)
109 | else:
110 | rank = self.rank
111 |
112 | R = np.random.uniform(-1, 1, (rank, rank))
113 | R_corrupt = R
114 |
115 | repeat = 128 // self.w + 1
116 |
117 | R_tile = np.tile(R, (1, repeat))
118 |
119 | node_features = torch.Tensor(np.zeros((rank, 1)))
120 | edge_index = torch.LongTensor(np.array(np.mgrid[:rank, :rank]).reshape((2, -1)))
121 | edge_features_context = torch.Tensor(np.tile(R_corrupt.reshape((-1, 1)), (1, 1)))
122 | edge_features_label = torch.Tensor(R.reshape((-1, 1)))
123 | noise = torch.Tensor(edge_features_label.reshape((-1, 1)))
124 |
125 | data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features_context, y=edge_features_label, noise=noise)
126 |
127 | return data
128 |
129 | def __len__(self):
130 |         """Return the nominal dataset size (samples are generated on the fly)."""
131 | # Dataset is always randomly generated
132 | return int(1e6)
133 |
134 |
135 |
136 | class ConnectedComponents(data.Dataset):
137 |     """Random graphs labeled with whether each pair of nodes lies in the same connected component."""
138 |
139 |     def __init__(self, split, rank, vary=False):
140 |         """Initialize the dataset.
141 |
142 |         split (str) -- split name; stored, but not used by the sampling code
143 |         rank (int) -- number of nodes (the maximum number when vary is True)
144 |         vary (bool) -- if True, sample the node count from [2, rank] per item"""
145 | self.h = rank
146 | self.w = rank
147 | self.rank = rank
148 | self.vary = vary
149 |
150 | self.split = split
151 | self.inp_dim = self.h * self.w
152 | self.out_dim = self.h * self.w
153 |
154 | def __getitem__(self, index):
155 |         """Sample a random graph and label which node pairs are connected.
156 |
157 |         Parameters:
158 |             index -- ignored; every call draws a fresh random sample
159 |
160 |         Returns a torch_geometric Data object with
161 |             edge_attr -- 1 where an edge exists (uniform draw > 0.95), else 0
162 |             y -- 1 where both endpoints lie in the same connected component, else 0
163 |         """
164 |
165 | if self.vary:
166 | rank = random.randint(2, self.rank)
167 | else:
168 | rank = self.rank
169 |
170 | R = np.random.uniform(0, 1, (rank, rank))
171 | connections = R > 0.95
172 | components, component_arr = connected_components(connections, directed=False)
173 |
174 | label = np.zeros((rank, rank))
175 |
176 | for i in range(components):
177 | component_i = np.arange(rank)[component_arr == i]
178 |
179 | for j in range(component_i.shape[0]):
180 | for k in range(component_i.shape[0]):
181 | label[component_i[j], component_i[k]] = 1
182 |
183 | node_features = torch.Tensor(np.zeros((rank, 1)))
184 | edge_index = torch.LongTensor(np.array(np.mgrid[:rank, :rank]).reshape((2, -1)))
185 |
186 | edge_features_context = torch.Tensor(np.tile(connections.reshape((-1, 1)), (1, 1)))
187 | edge_features_label = torch.Tensor(label.reshape((-1, 1)))
188 | noise = torch.Tensor(label.reshape((-1, 1)))
189 |
190 | data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features_context, y=edge_features_label, noise=noise)
191 |
192 | return data
193 |
194 | def __len__(self):
195 |         """Return the nominal dataset size (samples are generated on the fly)."""
196 | # Dataset is always randomly generated
197 | return int(1e6)
198 |
199 |
200 | class ShortestPath(data.Dataset):
201 |     """Random weighted graphs labeled with all-pairs shortest-path distances."""
202 |
203 |     def __init__(self, split, rank, vary=False):
204 |         """Initialize the dataset.
205 |
206 |         split (str) -- split name; stored, but not used by the sampling code
207 |         rank (int) -- number of nodes (the maximum number when vary is True)
208 |         vary (bool) -- if True, sample the node count from [2, rank] per item"""
209 | self.h = rank
210 | self.w = rank
211 | self.vary = vary
212 | self.rank = rank
213 |
214 | self.split = split
215 | self.inp_dim = self.h * self.w
216 | self.out_dim = self.h * self.w
217 |
218 | def __getitem__(self, index):
219 |         """Sample a random weighted graph and its shortest-path distances.
220 |
221 |         Parameters:
222 |             index -- ignored; every call draws a fresh random sample
223 |
224 |         Returns a torch_geometric Data object with
225 |             edge_attr -- symmetric edge weights in [0, 2) with a zero diagonal
226 |             y -- all-pairs shortest-path distances computed with scipy
227 |         """
228 | if self.vary:
229 | rank = random.randint(2, self.rank)
230 | else:
231 | rank = self.rank
232 |
233 | graph = np.random.uniform(0, 1, size=[rank, rank])
234 | graph = graph + graph.transpose()
235 | np.fill_diagonal(graph, 0) # can move to self
236 |
237 | graph_dist, graph_predecessors = shortest_path(csgraph=csr_matrix(graph), unweighted=False, directed=False, return_predecessors=True)
238 |
239 | node_features = torch.Tensor(np.zeros((rank, 1)))
240 | edge_index = torch.LongTensor(np.array(np.mgrid[:rank, :rank]).reshape((2, -1)))
241 |
242 | edge_features_context = torch.Tensor(np.tile(graph.reshape((-1, 1)), (1, 1)))
243 | edge_features_label = torch.Tensor(graph_dist.reshape((-1, 1)))
244 | noise = torch.Tensor(edge_features_label.reshape((-1, 1)))
245 |
246 | data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features_context, y=edge_features_label, noise=noise)
247 |
248 | return data
249 |
250 | def __len__(self):
251 |         """Return the nominal dataset size (samples are generated on the fly)."""
252 | # Dataset is always randomly generated
253 | return int(1e6)
254 |
255 |
256 | if __name__ == "__main__":
257 | dataset = ConnectedComponents('train', 10)
258 | data = dataset[0]
259 | import pdb
260 | pdb.set_trace()
261 | print(data)
262 |
--------------------------------------------------------------------------------
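`NoisyWrapper` returns two noised copies of the label: `y_prev` at a random step `t` and `y` at step `t - 1`, both built from the same noise draw and a linear `alphas_cumprod` schedule that runs from 1 to 0. The sketch below reproduces that arithmetic in pure Python as an illustration; the helper names `linear_alphas_cumprod` and `noisy_pair` are hypothetical, and plain lists stand in for the NumPy arrays and torch tensors used in the file.

```python
import math
import random


def linear_alphas_cumprod(timesteps):
    """Evenly spaced values from 1 down to 0, like np.linspace(1, 0, timesteps)."""
    return [1.0 - i / (timesteps - 1) for i in range(timesteps)]


def noisy_pair(y, t, alphas_cumprod, rng):
    """Return the label noised at step t and at step t - 1, sharing one noise draw."""
    noise = [rng.gauss(0.0, 1.0) for _ in y]

    def mix(step):
        a = alphas_cumprod[step]
        # sqrt(a) scales the clean label, sqrt(1 - a) scales the noise.
        return [math.sqrt(a) * v + math.sqrt(1.0 - a) * n
                for v, n in zip(y, noise)]

    return mix(t), mix(t - 1)


rng = random.Random(0)
alphas = linear_alphas_cumprod(10)
label = [0.5, -1.0, 2.0]
y_prev, y_next = noisy_pair(label, 1, alphas, rng)
print(y_prev, y_next)
```

With `t = 1` the less noisy sample uses `alphas_cumprod[0] = 1`, so it equals the clean label, while at the largest step the schedule reaches 0 and the sample is pure noise.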
/models.py:
--------------------------------------------------------------------------------
1 | from torch.nn import ModuleList
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import torch
5 | from easydict import EasyDict
6 | import numpy as np
7 | from torch.nn import TransformerEncoderLayer
8 |
9 |
10 | def swish(x):
11 | return x * torch.sigmoid(x)
12 |
13 |
14 | class View(nn.Module):
15 | def __init__(self, size):
16 | super(View, self).__init__()
17 | self.size = size
18 |
19 | def forward(self, tensor):
20 | return tensor.view(self.size)
21 |
22 |
23 | def kaiming_init(m):
24 | if isinstance(m, (nn.Linear, nn.Conv2d)):
25 |         torch.nn.init.kaiming_normal_(m.weight)
26 | if m.bias is not None:
27 | m.bias.data.fill_(0)
28 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
29 | m.weight.data.fill_(1)
30 | if m.bias is not None:
31 | m.bias.data.fill_(0)
32 |
33 |
34 | def reparametrize(mu, logvar):
35 | std = logvar.div(2).exp()
36 |     eps = torch.randn_like(std)
37 | return mu + std * eps
38 |
39 |
40 | class EBM(nn.Module):
41 | def __init__(self, inp_dim, out_dim, mem):
42 | super(EBM, self).__init__()
43 | h = 512
44 |
45 | if mem:
46 | self.fc1 = nn.Linear(inp_dim + out_dim + h, h)
47 | else:
48 | self.fc1 = nn.Linear(inp_dim + out_dim, h)
49 |
50 | self.mem = mem
51 | self.fc2 = nn.Linear(h, h)
52 | self.fc3 = nn.Linear(h, h)
53 | self.fc4 = nn.Linear(h, 1)
54 |
55 | self.inp_dim = inp_dim
56 | self.out_dim = out_dim
57 |
58 | def forward(self, x):
59 | h = swish(self.fc1(x))
60 | h = swish(self.fc2(h))
61 | h = swish(self.fc3(h))
62 |
63 | output = self.fc4(h)
64 |
65 | return output
66 |
67 |
68 | class EBMTwin(nn.Module):
69 | def __init__(self, inp_dim, out_dim, mem):
70 | super(EBMTwin, self).__init__()
71 | h = 512
72 | self.inp_dim = inp_dim
73 | self.out_dim = out_dim
74 |
75 | self.first_fc1 = nn.Linear(inp_dim, h)
76 | self.second_fc1 = nn.Linear(out_dim, h)
77 |
78 | self.first_fc2 = nn.Linear(h, h)
79 | self.second_fc2 = nn.Linear(h, h)
80 |
81 | self.first_fc3 = nn.Linear(h, h)
82 | self.second_fc3 = nn.Linear(h, h)
83 |
84 | self.prod_h1 = nn.Linear(2 * h, h)
85 | self.prod_h2 = nn.Linear(2 * h, h)
86 |
87 | self.scale_parameter = torch.nn.Parameter(torch.ones(1))
88 |
89 | def forward(self, x):
90 | x_out = x[..., :self.out_dim]
91 | x = x[..., self.out_dim:]
92 |
93 | first_h1 = self.first_fc1(x)
94 | second_h1 = self.second_fc1(x_out)
95 |
96 | first_h2 = self.first_fc2(swish(first_h1))
97 | second_h2 = self.second_fc2(swish(second_h1))
98 |
99 | first_h3 = self.first_fc3(swish(first_h2))
100 | second_h3 = self.second_fc3(swish(second_h2))
101 |
102 | h = torch.cat([first_h2, second_h2], dim=-1)
103 |
104 | energy_3 = torch.abs(first_h3 - second_h3).mean(dim=-1)
105 |
106 | energy = energy_3
107 |
108 | return energy
109 |
110 |
111 | class FC(nn.Module):
112 | def __init__(self, inp_dim, out_dim):
113 | super(FC, self).__init__()
114 | h = 512
115 |
116 | self.fc1 = nn.Linear(inp_dim, h)
117 |
118 | self.fc2 = nn.Linear(h, h)
119 | self.fc3 = nn.Linear(h, h)
120 | self.fc4 = nn.Linear(h, out_dim)
121 |
122 | self.inp_dim = inp_dim
123 | self.out_dim = out_dim
124 |
125 | def forward(self, x):
126 | h = F.relu(self.fc1(x))
127 | h = F.relu(self.fc2(h))
128 | h = F.relu(self.fc3(h))
129 | output = self.fc4(h)
130 |
131 | return output
132 |
133 |
134 | class RecurrentFC(nn.Module):
135 | def __init__(self, inp_dim, out_dim):
136 | super(RecurrentFC, self).__init__()
137 | h = 196
138 |
139 | self.inp_map = nn.Linear(inp_dim, h)
140 | self.inp_map2 = nn.Linear(inp_dim, h)
141 | self.fc1 = nn.Linear(inp_dim, h)
142 | self.lstm = nn.LSTM(h, h)
143 | self.fc3 = nn.Linear(h, h)
144 | self.fc4 = nn.Linear(h, out_dim)
145 | self.scale_parameter = torch.nn.Parameter(torch.ones(1))
146 |
147 | self.inp_dim = inp_dim
148 | self.out_dim = out_dim
149 |
150 | def forward(self, x, state=None):
151 |
152 | h = F.relu(self.fc1(x))
153 | h, state = self.lstm(h[None], state)
154 | output = self.fc4(h[0])
155 |
156 | return output, state
157 |
158 |
159 | class Embedder:
160 | def __init__(self, **kwargs):
161 | self.kwargs = kwargs
162 | self.create_embedding_fn()
163 |
164 | def create_embedding_fn(self):
165 | embed_fns = []
166 | d = self.kwargs['input_dims']
167 | out_dim = 0
168 | if self.kwargs['include_input']:
169 | embed_fns.append(lambda x: x)
170 | out_dim += d
171 |
172 | max_freq = self.kwargs['max_freq_log2']
173 | N_freqs = self.kwargs['num_freqs']
174 |
175 | if self.kwargs['random_feature']:
176 | freq_bands = 2.0**torch.linspace(0., max_freq, steps=N_freqs)
177 | elif self.kwargs['log_sampling']:
178 | freq_bands = 2.0**torch.linspace(0., max_freq, steps=N_freqs)
179 | else:
180 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
181 |
182 | for freq in freq_bands:
183 | for p_fn in self.kwargs['periodic_fns']:
184 | embed_fns.append(
185 | lambda x,
186 | p_fn=p_fn,
187 | freq=freq: p_fn(
188 | x * freq))
189 | out_dim += d
190 |
191 | self.embed_fns = embed_fns
192 | self.out_dim = out_dim
193 |
194 | def embed(self, inputs):
195 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
196 |
197 |
198 | class IterativeFC(nn.Module):
199 | def __init__(self, inp_dim, out_dim, *args):
200 | super(IterativeFC, self).__init__()
201 | h = 512
202 |
203 | input_dims = 1
204 | multires = 10
205 |
206 | self.fc1 = nn.Linear(inp_dim + out_dim, h)
207 | self.fc2 = nn.Linear(h, h)
208 | self.fc3 = nn.Linear(h, h)
209 | self.fc4 = nn.Linear(h, out_dim)
210 | self.scale_parameter = torch.nn.Parameter(torch.ones(1))
211 | self.fc_map = nn.Linear(out_dim, inp_dim)
212 |
213 | self.out_dim = out_dim
214 | self.inp_dim = inp_dim
215 |
216 | def forward(self, x):
217 | x_out = x[..., :self.out_dim]
218 | x = x[..., self.out_dim:]
219 |
220 | h = self.fc_map(x_out)
221 |
222 | prod = x * h
223 | x = torch.cat([x, x_out], dim=-1)
224 |
225 | h = h_first = swish(self.fc1(x))
226 | h = swish(self.fc2(h))
227 | h = swish(self.fc3(h))
228 | output = self.fc4(h)
229 |
230 | return output
231 |
232 |
233 | class PonderFC(nn.Module):
234 | def __init__(self, inp_dim, out_dim, num_step):
235 | super(PonderFC, self).__init__()
236 | h = 512
237 |
238 | input_dims = 1
239 | multires = 10
240 |
241 | self.fc1 = nn.Linear(inp_dim + out_dim, h)
242 | self.fc2 = nn.Linear(h, h)
243 | self.fc3 = nn.Linear(h, h)
244 | self.fc4 = nn.Linear(h, h)
245 | self.decode_logit = nn.Linear(h, 1)
246 |
247 | self.fc4_decode = nn.Linear(h, out_dim)
248 | self.scale_parameter = torch.nn.Parameter(torch.ones(1))
249 | self.fc_map = nn.Linear(out_dim, inp_dim)
250 |
251 | self.out_dim = out_dim
252 | self.inp_dim = inp_dim
253 | self.ponder_prob = 1 - 1 / num_step
254 |
255 | def forward(self, x, eval=False, iters=1):
256 | logits = []
257 | outputs = []
258 |
259 | x_out = x[..., :self.out_dim]
260 | x = x[..., self.out_dim:]
261 |
262 | joint = torch.cat([x_out, x], dim=-1)
263 |
264 | for i in range(iters):
265 | h = F.relu(self.fc1(joint))
266 | h = F.relu(self.fc2(h))
267 | logit = self.decode_logit(h)
268 | output = self.fc4_decode(h)
269 |
270 | joint = torch.cat([output, x], dim=-1)
271 |
272 | logits.append(logit)
273 | outputs.append(output)
274 |
275 | logits = torch.stack(logits, dim=1)
276 |
277 | return outputs, logits
278 |
279 |
280 | class IterativeFCAttention(nn.Module):
281 | def __init__(self, inp_dim, out_dim, rank):
282 | super(IterativeFCAttention, self).__init__()
283 | h = 512
284 |
285 | input_dims = 1
286 | multires = 10
287 | self.rank = rank
288 |
289 | self.inp_fc1 = nn.Linear(inp_dim, h)
290 | self.inp_fc2 = nn.Linear(h, h)
291 |
292 | self.out_fc1 = nn.Linear(out_dim, h)
293 | self.out_fc2 = nn.Linear(h, h)
294 |
295 | self.fc1 = nn.Linear(3 * h, h)
296 | self.fc2 = nn.Linear(h, h)
297 | self.fc3 = nn.Linear(h, out_dim)
298 |
299 | self.out_dim = out_dim
300 | self.inp_dim = inp_dim
301 | self.scale_parameter = torch.nn.Parameter(torch.ones(1))
302 |
303 | def forward(self, x):
304 |         # Expects an input x that concatenates the partially optimized output
305 |         # (first out_dim entries) with the problem input (remaining entries)
306 | x_out = x[..., :self.out_dim]
307 | x = x[..., self.out_dim:]
308 |
309 | inp_h = F.relu(self.inp_fc2(F.relu(self.inp_fc1(x))))
310 | out_h = F.relu(self.out_fc2(F.relu(self.out_fc1(x_out))))
311 |
312 | prod_h = out_h * inp_h
313 | h = torch.cat([inp_h, out_h, prod_h], dim=-1)
314 |
315 | out = self.fc3(F.relu(self.fc2(F.relu(self.fc1(h)))))
316 |
317 | return out
318 |
319 |
320 | class IterativeTransformer(nn.Module):
321 | def __init__(self, inp_dim, out_dim, rank):
322 | super(IterativeTransformer, self).__init__()
323 | h = 512
324 | input_dims = 1
325 | multires = 10
326 | self.rank = rank
327 |
328 | embed_kwargs = {
329 | 'include_input': True,
330 | 'random_feature': True,
331 | 'input_dims': input_dims,
332 | 'max_freq_log2': multires - 1,
333 | 'num_freqs': multires,
334 | 'log_sampling': True,
335 | 'periodic_fns': [torch.sin, torch.cos],
336 | }
337 |
338 | self.inp_embed = Embedder(**embed_kwargs)
339 | self.out_embed = Embedder(**embed_kwargs)
340 | self.inp_linear = nn.Linear(1, 63)
341 | self.out_linear = nn.Linear(1, 63)
342 |
343 | coord = np.mgrid[:inp_dim] / inp_dim
344 | coord = torch.Tensor(coord)[None, :, None].cuda()
345 |
346 | self.coord = coord
347 |
348 | out_coord = np.mgrid[:out_dim] / out_dim
349 | out_coord = torch.Tensor(out_coord)[None, :, None].cuda()
350 |
351 | self.out_coord = out_coord
352 |
353 | self.encode1 = TransformerEncoderLayer(84, 4, dim_feedforward=256)
354 | self.encode2 = TransformerEncoderLayer(84, 4, dim_feedforward=256)
355 |
356 | self.fc1 = nn.Linear(84, 128)
357 | self.fc2 = nn.Linear(128, 1)
358 |
359 | self.out_dim = out_dim
360 | self.inp_dim = inp_dim
361 |
362 | def forward(self, x):
363 | x_out = x[..., :self.out_dim].view(-1, self.out_dim, 1)
364 | x = x[..., self.out_dim:].view(-1, self.inp_dim, 1)
365 |
366 | x_linear = self.inp_linear(x)
367 | x_out_linear = self.out_linear(x_out)
368 | x_pos_embed = self.inp_embed.embed(self.coord)
369 | x_out_pos_embed = self.out_embed.embed(self.out_coord)
370 |
371 | x_pos_embed = x_pos_embed.expand(x_linear.size(0), -1, -1)
372 | x_out_pos_embed = x_out_pos_embed.expand(x_linear.size(0), -1, -1)
373 | pos_embed = torch.cat([x_linear, x_pos_embed], dim=-1)
374 | out_embed = torch.cat([x_out_linear, x_out_pos_embed], dim=-1)
375 |
376 | s = pos_embed.size()
377 | pos_embed = pos_embed.view(s[0], -1, s[-1])
378 | out_embed = out_embed.view(s[0], -1, s[-1])
379 |
380 | embed = torch.cat([pos_embed, out_embed], dim=1)
381 | embed = embed.permute(1, 0, 2).contiguous()
382 |
383 | embed = self.encode1(embed)
384 | embed = self.encode2(embed)
385 |
386 | embed = self.fc2(F.relu(self.fc1(embed.permute(1, 0, 2))))[:, :, 0]
387 | output = embed[:, :self.out_dim]
388 |
389 | return output
390 |
391 |
392 | class IterativeAttention(nn.Module):
393 | def __init__(self, inp_dim, out_dim, rank):
394 | super(IterativeAttention, self).__init__()
395 | h = 64
396 |
397 | input_dims = 1
398 | multires = 10
399 | self.rank = rank
400 | self.out_dim = out_dim
401 |
402 | embed_kwargs = {
403 | 'include_input': True,
404 | 'random_feature': True,
405 | 'input_dims': input_dims,
406 | 'max_freq_log2': multires - 1,
407 | 'num_freqs': multires,
408 | 'log_sampling': True,
409 | 'periodic_fns': [torch.sin, torch.cos],
410 | }
411 |
412 | self.inp_embed = Embedder(**embed_kwargs)
413 | self.out_embed = Embedder(**embed_kwargs)
414 | self.inp_linear = nn.Linear(1, 42)
415 | self.out_linear = nn.Linear(1, 42)
416 |
417 | coord = np.mgrid[:rank, :rank] / rank
418 | coord = torch.Tensor(coord).permute(1, 2, 0)[None, :, :, :].cuda()
419 |
420 | self.coord = coord
421 |
422 | out_coord = np.mgrid[:out_dim] / rank
423 | out_coord = torch.Tensor(out_coord)[None, :, None].cuda()
424 |
425 | self.out_coord = out_coord
426 |
427 | self.inp_fc1 = nn.Linear(84, h)
428 | self.inp_fc2 = nn.Linear(h, h)
429 | self.inp_fc3 = nn.Linear(h, h)
430 |
431 | self.out_fc1 = nn.Linear(63, h)
432 | self.out_fc2 = nn.Linear(h, h)
433 | self.out_fc3 = nn.Linear(h, h)
434 |
435 | self.at_fc1 = nn.Linear(63, h)
436 | self.at_fc2 = nn.Linear(h, h)
437 |
438 | self.fc1 = nn.Linear(3 * h, h)
439 | self.fc2 = nn.Linear(h, h)
440 | self.fc3 = nn.Linear(h, 1)
441 | self.scale_parameter = torch.nn.Parameter(torch.ones(1))
442 |
443 | self.out_dim = out_dim
444 | self.inp_dim = inp_dim
445 |
446 | def forward(self, x, idx):
447 | # Expects an input x whose first out_dim entries are the partially
448 | # optimized result, followed by the flattened rank x rank input
449 | output = x[..., :self.out_dim]
450 | x_out = x[..., :self.out_dim].view(-1, self.out_dim, 1)
451 | x = x[..., self.out_dim:].view(-1, self.rank, self.rank, 1)
452 |
453 | x_linear = self.inp_linear(x)
454 | x_out_linear = self.out_linear(x_out)
455 | x_pos_embed = self.inp_embed.embed(self.coord)
456 | x_out_pos_embed = self.out_embed.embed(self.out_coord)
457 |
458 | x_pos_embed = x_pos_embed.expand(x_linear.size(0), -1, -1, -1)
459 | x_out_pos_embed = x_out_pos_embed.expand(x_linear.size(0), -1, -1)
460 |
461 | pos_embed = torch.cat([x_linear, x_pos_embed], dim=-1)
462 | out_embed = torch.cat([x_out_linear, x_out_pos_embed], dim=-1)
463 |
464 | s = pos_embed.size()
465 | pos_embed = pos_embed.view(s[0], -1, s[-1])
466 | s = out_embed.size()
467 | out_embed = out_embed.view(s[0], -1, s[-1])
468 |
469 | idx = idx[:, :, None].expand(-1, -1, out_embed.size(-1))
470 | at_embed = torch.gather(pos_embed, 1, idx)
471 |
472 | at_embed = self.at_fc2(F.relu(self.at_fc1(at_embed)))
473 | out_embed = self.out_fc2(F.relu(self.out_fc1(out_embed)))
474 | pos_embed = self.inp_fc2(F.relu(self.inp_fc1(pos_embed)))
475 |
476 | out_val = self.out_fc3(F.relu(out_embed))
477 | pos_val = self.inp_fc3(F.relu(pos_embed))
478 |
479 | at_out_wt = torch.einsum('bij,bkj->bik', at_embed, out_embed)
480 | at_out_wt = torch.softmax(at_out_wt, dim=-1)
481 | out_embed = (at_out_wt[:, :, :, None] *
482 | out_val[:, None, :, :]).sum(dim=2)
483 |
484 | at_pos_wt = torch.einsum('bij,bkj->bik', at_embed, pos_embed)
485 | at_pos_wt = torch.softmax(at_pos_wt, dim=-1)
486 | pos_embed = (at_pos_wt[:, :, :, None] *
487 | pos_val[:, None, :, :]).sum(dim=2)
488 |
489 | embed = torch.cat([at_embed, pos_embed, out_embed], dim=-1)
490 | out = self.fc3(F.relu(self.fc2(F.relu(self.fc1(embed)))))[:, :, 0]
491 |
492 | output.requires_grad_()
493 | idx = idx[:, :, 0]
494 | output = torch.scatter(output, 1, idx, out)
495 |
496 | return output
497 |
498 |
499 | class AttentionEBM(nn.Module):
500 | def __init__(self, inp_dim, out_dim, rank):
501 | super(AttentionEBM, self).__init__()
502 | h = 64
503 |
504 | input_dims = 1
505 | multires = 10
506 | self.rank = rank
507 | self.out_dim = out_dim
508 |
509 | embed_kwargs = {
510 | 'include_input': True,
511 | 'random_feature': True,
512 | 'input_dims': input_dims,
513 | 'max_freq_log2': multires - 1,
514 | 'num_freqs': multires,
515 | 'log_sampling': True,
516 | 'periodic_fns': [torch.sin, torch.cos],
517 | }
518 |
519 | self.inp_embed = Embedder(**embed_kwargs)
520 | self.out_embed = Embedder(**embed_kwargs)
521 | self.inp_linear = nn.Linear(1, 42)
522 | self.out_linear = nn.Linear(1, 42)
523 |
524 | coord = np.mgrid[:rank, :rank] / rank
525 | coord = torch.Tensor(coord).permute(1, 2, 0)[None, :, :, :].cuda()
526 |
527 | self.coord = coord
528 |
529 | out_coord = np.mgrid[:out_dim] / rank
530 | out_coord = torch.Tensor(out_coord)[None, :, None].cuda()
531 |
532 | self.out_coord = out_coord
533 |
534 | self.inp_fc1 = nn.Linear(84, h)
535 | self.inp_fc2 = nn.Linear(h, h)
536 | self.inp_fc3 = nn.Linear(h, h)
537 |
538 | self.out_fc1 = nn.Linear(63, h)
539 | self.out_fc2 = nn.Linear(h, h)
540 | self.out_fc3 = nn.Linear(h, h)
541 |
542 | self.at_fc1 = nn.Linear(63, h)
543 | self.at_fc2 = nn.Linear(h, h)
544 |
545 | self.fc1 = nn.Linear(3 * h, h)
546 | self.fc2 = nn.Linear(h, h)
547 | self.fc3 = nn.Linear(h, 1)
548 | self.scale_parameter = torch.nn.Parameter(torch.ones(1))
549 |
550 | self.out_dim = out_dim
551 | self.inp_dim = inp_dim
552 |
553 | def forward(self, x, idx):
554 | # Expects an input x whose first out_dim entries are the partially
555 | # optimized result, followed by the flattened rank x rank input
556 |
557 | act_fn = swish
558 |
559 | output = x[..., :self.out_dim]
560 | x_out = x[..., :self.out_dim].view(-1, self.out_dim, 1)
561 | x = x[..., self.out_dim:].view(-1, self.rank, self.rank, 1)
562 |
563 | x_linear = self.inp_linear(x)
564 | x_out_linear = self.out_linear(x_out)
565 | x_pos_embed = self.inp_embed.embed(self.coord)
566 | x_out_pos_embed = self.out_embed.embed(self.out_coord)
567 |
568 | x_pos_embed = x_pos_embed.expand(x_linear.size(0), -1, -1, -1)
569 | x_out_pos_embed = x_out_pos_embed.expand(x_linear.size(0), -1, -1)
570 |
571 | pos_embed = torch.cat([x_linear, x_pos_embed], dim=-1)
572 | out_embed = torch.cat([x_out_linear, x_out_pos_embed], dim=-1)
573 |
574 | s = pos_embed.size()
575 | pos_embed = pos_embed.view(s[0], -1, s[-1])
576 | s = out_embed.size()
577 | out_embed = out_embed.view(s[0], -1, s[-1])
578 |
579 | idx = idx[:, :, None].expand(-1, -1, out_embed.size(-1))
580 | at_embed = torch.gather(pos_embed, 1, idx)
581 |
582 | at_embed = self.at_fc2(act_fn(self.at_fc1(at_embed)))
583 | out_embed = self.out_fc2(act_fn(self.out_fc1(out_embed)))
584 | pos_embed = self.inp_fc2(act_fn(self.inp_fc1(pos_embed)))
585 |
586 | out_val = self.out_fc3(act_fn(out_embed))
587 | pos_val = self.inp_fc3(act_fn(pos_embed))
588 |
589 | at_out_wt = torch.einsum('bij,bkj->bik', at_embed, out_embed)
590 | at_out_wt = torch.softmax(at_out_wt, dim=-1)
591 | out_embed = (at_out_wt[:, :, :, None] *
592 | out_val[:, None, :, :]).sum(dim=2)
593 |
594 | at_pos_wt = torch.einsum('bij,bkj->bik', at_embed, pos_embed)
595 | at_pos_wt = torch.softmax(at_pos_wt, dim=-1)
596 | pos_embed = (at_pos_wt[:, :, :, None] *
597 | pos_val[:, None, :, :]).sum(dim=2)
598 |
599 | embed = torch.cat([at_embed, pos_embed, out_embed], dim=-1)
600 | out = self.fc3(act_fn(self.fc2(act_fn(self.fc1(embed)))))[:, :, 0]
601 |
602 | return out
603 |
604 |
605 | if __name__ == "__main__":
606 | # Unit test the model
607 | args = EasyDict()
608 | args.filter_dim = 64
609 | args.latent_dim = 64
610 | args.im_size = 256
611 |
612 | model = LatentEBM(args).cuda()
613 | x = torch.zeros(1, 3, 256, 256).cuda()
614 | latent = torch.zeros(1, 64).cuda()
615 | model(x, latent)
616 |
--------------------------------------------------------------------------------
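The attention step shared by `IterativeAttention` and `AttentionEBM` in models.py (scores from `einsum('bij,bkj->bik', ...)`, a softmax over the keys, then a weighted sum of the values) can be sketched without PyTorch. The following is a minimal illustration on plain Python lists for a single batch element. The helper names `softmax` and `attend` are hypothetical and are not part of the repository.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(queries, keys, values):
    """Dot-product attention for one batch element.

    queries: list of i vectors of length j
    keys:    list of k vectors of length j
    values:  list of k vectors of length d
    Returns (weights, outputs): weights is i x k, outputs is i x d.
    """
    weights, outputs = [], []
    for q in queries:
        # Score of this query against every key (the 'ij,kj->ik' contraction).
        scores = [sum(qa * ka for qa, ka in zip(q, k)) for k in keys]
        w = softmax(scores)
        # Weighted sum of the value vectors, one output coordinate at a time.
        out = [sum(wk * v[d] for wk, v in zip(w, values))
               for d in range(len(values[0]))]
        weights.append(w)
        outputs.append(out)
    return weights, outputs


# Hypothetical toy data: each query lines up with exactly one key.
queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [[10.0, 0.0], [0.0, 10.0], [0.0, 0.0]]
values = [[1.0], [2.0], [3.0]]
weights, outputs = attend(queries, keys, values)
```

Each row of `weights` sums to one, and because every query matches one key far more strongly than the others, each output is dominated by the value stored under that key.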
/graph_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from graph_models import GraphEBM, GraphFC, GraphPonder, GraphRecurrent
3 | import torch.nn.functional as F
4 | import os
5 | import pdb
6 | from graph_dataset import Identity, ConnectedComponents, ShortestPath
7 | import matplotlib.pyplot as plt
8 | import torch.multiprocessing as mp
9 | import torch.distributed as dist
10 | from torch.optim import Adam, SparseAdam
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch_geometric.loader import DataLoader
13 | from torch_geometric.data.batch import Batch
14 | from easydict import EasyDict
15 | import os.path as osp
16 | from torch.nn.utils import clip_grad_norm
17 | import numpy as np
18 | from imageio import imwrite
19 | import argparse
20 | from torchvision.datasets import ImageFolder
21 | import torchvision.transforms as transforms
22 | import torch.nn as nn
23 | import torch.optim as optim
24 | import torch.backends.cudnn as cudnn
25 | import random
26 | from torchvision.utils import make_grid
27 | import seaborn as sns
28 | from torch_geometric.nn import global_mean_pool
29 |
30 |
31 | def worker_init_fn(worker_id):
32 | np.random.seed(int(torch.utils.data.get_worker_info().seed)%(2**32-1))
33 |
34 | class ReplayBuffer(object):
35 | def __init__(self, size):
36 | """Create Replay buffer.
37 | Parameters
38 | ----------
39 | size: int
40 | Max number of transitions to store in the buffer. When the buffer
41 | overflows the old memories are dropped.
42 | """
43 | self._storage = []
44 | self._maxsize = size
45 | self._next_idx = 0
46 |
47 | def __len__(self):
48 | return len(self._storage)
49 |
50 | def add(self, inputs):
51 | batch_size = len(inputs)
52 | if self._next_idx >= len(self._storage):
53 | self._storage.extend(inputs)
54 | else:
55 | if batch_size + self._next_idx < self._maxsize:
56 | self._storage[self._next_idx:self._next_idx +
57 | batch_size] = inputs
58 | else:
59 | split_idx = self._maxsize - self._next_idx
60 | self._storage[self._next_idx:] = inputs[:split_idx]
61 | self._storage[:batch_size - split_idx] = inputs[split_idx:]
62 | self._next_idx = (self._next_idx + batch_size) % self._maxsize
63 |
64 | def _encode_sample(self, idxes):
65 | datas = []
66 | for i in idxes:
67 | data = self._storage[i]
68 | datas.append(data)
69 |
70 | return datas
71 |
72 | def sample(self, batch_size, no_transform=False, downsample=False):
73 | """Sample a batch of stored elements uniformly at random.
74 | Parameters
75 | ----------
76 | batch_size: int
77 | How many elements to sample (drawn with replacement).
78 | no_transform: bool
79 | Unused by this implementation.
80 | downsample: bool
81 | Unused by this implementation.
82 | Returns
83 | -------
84 | datas: list
85 | the sampled elements, in the order they were drawn
86 | idxes: torch.Tensor
87 | float tensor of shape (batch_size,) holding the buffer
88 | index of each sampled element, so that callers can
89 | write updated elements back to the same slots, for
90 | example with set_elms
91 | """
92 | idxes = [random.randint(0, len(self._storage) - 1)
93 | for _ in range(batch_size)]
94 | return self._encode_sample(idxes), torch.Tensor(idxes)
95 |
96 | def set_elms(self, data, idxes):
97 | if len(self._storage) < self._maxsize:
98 | self.add(data)
99 | else:
100 | for i, ix in enumerate(idxes):
101 | self._storage[ix] = data[i]
102 |
103 |
104 | """Parse input arguments"""
105 | parser = argparse.ArgumentParser(description='Train EBM model')
106 |
107 | parser.add_argument('--train', action='store_true', help='whether or not to train')
108 | parser.add_argument('--cuda', action='store_true', help='whether to use cuda or not')
109 | parser.add_argument('--vary', action='store_true', help='vary size of graph')
110 | parser.add_argument('--no_replay_buffer', action='store_true', help='do not utilize a replay buffer')
111 |
112 | parser.add_argument('--dataset', default='identity', type=str, help='dataset to evaluate')
113 | parser.add_argument('--logdir', default='cachedir', type=str, help='location where log of experiments will be stored')
114 | parser.add_argument('--exp', default='default', type=str, help='name of experiments')
115 |
116 | # training
117 | parser.add_argument('--resume_iter', default=0, type=int, help='iteration to resume training')
118 | parser.add_argument('--batch_size', default=512, type=int, help='size of batch of input to use')
119 | parser.add_argument('--num_epoch', default=10000, type=int, help='number of epochs of training to run')
120 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate for training')
121 | parser.add_argument('--log_interval', default=10, type=int, help='log outputs every so many batches')
122 | parser.add_argument('--save_interval', default=1000, type=int, help='save outputs every so many batches')
123 |
124 | # data
125 | parser.add_argument('--data_workers', default=4, type=int, help='Number of different data workers to load data in parallel')
126 |
127 | # EBM specific settings
128 | parser.add_argument('--filter_dim', default=64, type=int, help='number of filters to use')
129 | parser.add_argument('--rank', default=10, type=int, help='rank of matrix to use')
130 | parser.add_argument('--num_steps', default=5, type=int, help='Steps of gradient descent for training')
131 | parser.add_argument('--step_lr', default=100.0, type=float, help='step size of latents')
132 | parser.add_argument('--latent_dim', default=64, type=int, help='dimension of the latent')
133 | parser.add_argument('--decoder', action='store_true', help='utilize a decoder to output prediction')
134 | parser.add_argument('--gen', action='store_true', help='evaluate generalization')
135 | parser.add_argument('--gen_rank', default=5, type=int, help='Add additional rank for generalization')
136 | parser.add_argument('--recurrent', action='store_true', help='utilize a recurrent model to output prediction')
137 | parser.add_argument('--ponder', action='store_true', help='utilize a ponder model to output prediction')
138 | parser.add_argument('--no_truncate_grad', action='store_true', help='not truncate gradient')
139 | parser.add_argument('--iterative_decoder', action='store_true', help='utilize an iterative decoder to output prediction')
140 | parser.add_argument('--mem', action='store_true', help='add external memory to compute answers')
141 |
142 | # Distributed training hyperparameters
143 | parser.add_argument('--nodes', default=1, type=int, help='number of nodes for training')
144 | parser.add_argument('--gpus', default=1, type=int, help='number of gpus per nodes')
145 | parser.add_argument('--node_rank', default=0, type=int, help='rank of node')
146 | parser.add_argument('--capacity', default=50000, type=int, help='number of elements to generate')
147 | parser.add_argument('--infinite', action='store_true', help='makes the dataset have an infinite number of elements')
148 |
149 | best_test_error = 10.0
150 | best_gen_test_error = 10.0
151 |
152 |
153 | def average_gradients(model):
154 | size = float(dist.get_world_size())
155 |
156 | for name, param in model.named_parameters():
157 | if param.grad is None:
158 | continue
159 |
160 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
161 | param.grad.data /= size
162 |
163 |
164 | def gen_answer(inp, FLAGS, model, pred, scratch, num_steps, create_graph=True):
165 | preds = []
166 | im_grads = []
167 | energies = []
168 | im_sups = []
169 |
170 | if FLAGS.decoder:
171 | pred = model.forward(inp)
172 | preds = [pred]
173 | im_grad = torch.zeros(1)
174 | im_grads = [im_grad]
175 | energies = [torch.zeros(1)]
176 | elif FLAGS.recurrent:
177 | preds = []
178 | im_grad = torch.zeros(1)
179 | im_grads = [im_grad]
180 | energies = [torch.zeros(1)]
181 | state = None
182 | for i in range(num_steps):
183 | pred, state = model.forward(inp, state)
184 | preds.append(pred)
185 | elif FLAGS.ponder:
186 | preds = model.forward(inp, iters=num_steps)
187 | pred = preds[-1]
188 | im_grad = torch.zeros(1)
189 | im_grads = [im_grad]
190 | energies = [torch.zeros(1)]
191 | state = None
192 | elif FLAGS.iterative_decoder:
193 | for i in range(num_steps):
194 | energy = torch.zeros(1)
195 |
196 | out_dim = model.out_dim
197 |
198 | im_merge = torch.cat([pred, inp], dim=-1)
199 | pred = model.forward(im_merge) + pred
200 |
201 | energies.append(torch.zeros(1))
202 | im_grads.append(torch.zeros(1))
203 | else:
204 | with torch.enable_grad():
205 | pred.requires_grad_(requires_grad=True)
206 | s = inp.size()
207 | scratch.requires_grad_(requires_grad=True)
208 |
209 | for i in range(num_steps):
210 |
211 | energy = model.forward(inp, pred)
212 |
213 | if FLAGS.no_truncate_grad:
214 | im_grad, = torch.autograd.grad([energy.sum()], [pred], create_graph=create_graph)
215 | else:
216 | if i == (num_steps - 1):
217 | im_grad, = torch.autograd.grad([energy.sum()], [pred], create_graph=True)
218 | else:
219 | im_grad, = torch.autograd.grad([energy.sum()], [pred], create_graph=False)
220 | pred = pred - FLAGS.step_lr * im_grad
221 |
222 | preds.append(pred)
223 | energies.append(energy)
224 | im_grads.append(im_grad)
225 |
226 | return pred, preds, im_grads, energies, scratch
227 |
228 |
229 | def ema_model(model, model_ema, mu=0.999):
230 | for (model, model_ema) in zip(model, model_ema):
231 | for param, param_ema in zip(model.parameters(), model_ema.parameters()):
232 | param_ema.data[:] = mu * param_ema.data + (1 - mu) * param.data
233 |
234 |
235 | def sync_model(model):
236 | size = float(dist.get_world_size())
237 |
238 | for model in model:
239 | for param in model.parameters():
240 | dist.broadcast(param.data, 0)
241 |
242 |
243 | def init_model(FLAGS, device, dataset):
244 | if FLAGS.decoder:
245 | model = GraphFC(dataset.inp_dim, dataset.out_dim)
246 | elif FLAGS.ponder:
247 | model = GraphPonder(dataset.inp_dim, dataset.out_dim)
248 | elif FLAGS.recurrent:
249 | model = GraphRecurrent(dataset.inp_dim, dataset.out_dim)
250 | elif FLAGS.iterative_decoder:
251 | model = IterativeFC(dataset.inp_dim, dataset.out_dim, FLAGS.mem)
252 | else:
253 | model = GraphEBM(dataset.inp_dim, dataset.out_dim, FLAGS.mem)
254 |
255 | model.to(device)
256 | optimizer = Adam(model.parameters(), lr=1e-4)
257 |
258 | return model, optimizer
259 |
260 |
261 | def test(train_dataloader, model, FLAGS, step=0, gen=False):
262 | global best_test_error, best_gen_test_error
263 | if FLAGS.cuda:
264 | dev = torch.device("cuda")
265 | else:
266 | dev = torch.device("cpu")
267 |
268 | replay_buffer = None
269 | dist_list = []
270 | energy_list = []
271 | min_dist_energy_list = []
272 |
273 | model.eval()
274 | counter = 0
275 | with torch.no_grad():
276 | for data in train_dataloader:
277 | data = data.to(dev)
278 | im = data['y']
279 |
280 | pred = (torch.rand_like(data.edge_attr) - 0.5) * 2.
281 |
282 | scratch = torch.zeros_like(pred)
283 |
284 | pred_init = pred
285 | pred, preds, im_grad, energies, scratch = gen_answer(data, FLAGS, model, pred, scratch, 40, create_graph=False)
286 | preds = torch.stack(preds, dim=0)
287 | energies = torch.stack(energies, dim=0).mean(dim=-1).mean(dim=-1)
288 |
289 | dist = (preds - im[None, :])
290 | dist = torch.pow(dist, 2).mean(dim=-1)
291 | n = dist.size(1)
292 |
293 | dist = dist.mean(dim=-1)
294 | min_idx = energies.argmin(dim=0)
295 | dist_min = dist[min_idx]
296 | min_dist_energy_list.append(dist_min)
297 |
298 | dist_list.append(dist.detach())
299 | energy_list.append(energies.detach())
300 |
301 | counter = counter + 1
302 |
303 | if counter > 10:
304 | dist = torch.stack(dist_list, dim=0).mean(dim=0)
305 | energies = torch.stack(energy_list, dim=0).mean(dim=0)
306 | min_dist_energy = torch.stack(min_dist_energy_list, dim=0).mean()
307 | print("Testing..................")
308 | print("last step error: ", dist)
309 | print("energy values: ", energies)
310 | print('test at step %d done!' % step)
311 | break
312 |
313 | if gen:
314 | best_gen_test_error = min(best_gen_test_error, min_dist_energy.item())
315 | print("best gen test error: ", best_gen_test_error)
316 | else:
317 | best_test_error = min(best_test_error, min_dist_energy.item())
318 | print("best test error: ", best_test_error)
319 | model.train()
320 |
321 |
322 | def train(train_dataloader, test_dataloader, gen_dataloader, logger, model, optimizer, FLAGS, logdir, rank_idx):
323 | it = FLAGS.resume_iter
324 | optimizer.zero_grad()
325 |
326 | dev = torch.device("cuda")
327 | replay_buffer = ReplayBuffer(10000)
328 |
329 | for epoch in range(FLAGS.num_epoch):
330 | for data in train_dataloader:
331 | pred = (torch.rand_like(data.edge_attr) - 0.5) * 2
332 | data['noise'] = pred
333 | scratch = torch.zeros_like(pred)
334 | nreplay = FLAGS.batch_size
335 |
336 | if FLAGS.replay_buffer and len(replay_buffer) >= FLAGS.batch_size:
337 | data_list, levels = replay_buffer.sample(32)
338 | new_data_list = data.to_data_list()
339 |
340 | ix = int(FLAGS.batch_size * 0.5)
341 |
342 | nreplay = len(data_list) + 40
343 | data_list = data_list + new_data_list
344 | data = Batch.from_data_list(data_list)
345 | pred = data['noise']
346 | else:
347 | levels = np.zeros(FLAGS.batch_size)
348 | replay_mask = (
349 | np.random.uniform(
350 | 0,
351 | 1,
352 | FLAGS.batch_size) > 1.0)
353 |
354 | data = data.to(dev)
355 | pred = data['noise']
356 | num_steps = FLAGS.num_steps
357 | pred, preds, im_grads, energies, scratch = gen_answer(data, FLAGS, model, pred, scratch, num_steps)
358 | energies = torch.stack(energies, dim=0)
359 |
360 | # Choose energies to be consistent at the last step
361 | preds = torch.stack(preds, dim=1)
362 |
363 | im_grads = torch.stack(im_grads, dim=1)
364 | im = data['y']
365 |
366 | if FLAGS.decoder:
367 | im_loss = torch.pow(preds[:, -1:] - im[:, None, :], 2).mean(dim=-1).mean(dim=-1)
368 | elif FLAGS.ponder:
369 | im_loss = torch.pow(preds[:, :] - im[:, None, :], 2).mean(dim=-1).mean(dim=-1)
370 | else:
371 | im_loss = torch.pow(preds[:, -1:] - im[:, None, :], 2).mean(dim=-1).mean(dim=-1)
372 |
373 | loss = im_loss.mean()
374 | loss.backward()
375 |
376 | if FLAGS.replay_buffer:
377 | data['noise'] = preds[:, -1].detach()
378 | data_list = data.cpu().to_data_list()
379 |
380 | replay_buffer.add(data_list[:nreplay])
381 |
382 | if FLAGS.gpus > 1:
383 | average_gradients(model)
384 |
385 | optimizer.step()
386 | optimizer.zero_grad()
387 |
388 | if it > 10000:
389 | assert False
390 |
391 | if it % FLAGS.log_interval == 0 and rank_idx == 0:
392 | loss = loss.item()
393 | kvs = {}
394 |
395 | kvs['im_loss'] = im_loss.mean().item()
396 |
397 | string = "Iteration {} ".format(it)
398 |
399 | for k, v in kvs.items():
400 | string += "%s: %.6f " % (k,v)
401 | logger.add_scalar(k, v, it)
402 |
403 | print(string)
404 |
405 | if it % FLAGS.save_interval == 0 and rank_idx == 0:
406 | model_path = osp.join(logdir, "model_latest.pth")
407 | ckpt = {'FLAGS': FLAGS}
408 |
409 | ckpt['model_state_dict'] = model.state_dict()
410 | ckpt['optimizer_state_dict'] = optimizer.state_dict()
411 |
412 | torch.save(ckpt, model_path)
413 |
414 | print("Testing performance .......................")
415 | test(test_dataloader, model, FLAGS, step=it, gen=False)
416 |
417 | print("Generalization performance .......................")
418 | test(gen_dataloader, model, FLAGS, step=it, gen=True)
419 |
420 | it += 1
421 |
422 |
423 |
424 | def main_single(rank, FLAGS):
425 | rank_idx = FLAGS.node_rank * FLAGS.gpus + rank
426 | world_size = FLAGS.nodes * FLAGS.gpus
427 |
428 |
429 | if not os.path.exists('result/%s' % FLAGS.exp):
430 | try:
431 | os.makedirs('result/%s' % FLAGS.exp)
432 | except:
433 | pass
434 |
435 | if FLAGS.dataset == 'identity':
436 | dataset = Identity('train', FLAGS.rank, vary=FLAGS.vary)
437 | test_dataset = Identity('test', FLAGS.rank)
438 | gen_dataset = Identity('test', FLAGS.rank+FLAGS.gen_rank)
439 | elif FLAGS.dataset == 'connected':
440 | dataset = ConnectedComponents('train', FLAGS.rank, vary=FLAGS.vary)
441 | test_dataset = ConnectedComponents('test', FLAGS.rank)
442 | gen_dataset = ConnectedComponents('test', FLAGS.rank+FLAGS.gen_rank)
443 | elif FLAGS.dataset == 'shortestpath':
444 | dataset = ShortestPath('train', FLAGS.rank, vary=FLAGS.vary)
445 | test_dataset = ShortestPath('test', FLAGS.rank)
446 | gen_dataset = ShortestPath('test', FLAGS.rank+FLAGS.gen_rank)
447 |
448 | if FLAGS.gen:
449 | test_dataset = gen_dataset
450 |
451 | shuffle=True
452 | sampler = None
453 |
454 | if world_size > 1:
455 | group = dist.init_process_group(backend='nccl', init_method='tcp://localhost:8113', world_size=world_size, rank=rank_idx, group_name="default")
456 |
457 | torch.cuda.set_device(rank)
458 | device = torch.device('cuda')
459 |
460 | logdir = osp.join(FLAGS.logdir, FLAGS.exp)
461 | FLAGS_OLD = FLAGS
462 |
463 | if FLAGS.resume_iter != 0:
464 | model_path = osp.join(logdir, "model_latest.pth")
465 |
466 | checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
467 | FLAGS = checkpoint['FLAGS']
468 |
469 | FLAGS.resume_iter = FLAGS_OLD.resume_iter
470 | FLAGS.save_interval = FLAGS_OLD.save_interval
471 | FLAGS.nodes = FLAGS_OLD.nodes
472 | FLAGS.gpus = FLAGS_OLD.gpus
473 | FLAGS.node_rank = FLAGS_OLD.node_rank
474 | FLAGS.train = FLAGS_OLD.train
475 | FLAGS.batch_size = FLAGS_OLD.batch_size
476 | FLAGS.num_steps = FLAGS_OLD.num_steps
477 | FLAGS.exp = FLAGS_OLD.exp
478 | FLAGS.vis = getattr(FLAGS_OLD, 'vis', False)  # no --vis argument is defined above
479 | FLAGS.ponder = FLAGS_OLD.ponder
480 | FLAGS.recurrent = FLAGS_OLD.recurrent
481 | FLAGS.no_truncate_grad = FLAGS_OLD.no_truncate_grad
482 | FLAGS.gen_rank = FLAGS_OLD.gen_rank
483 | FLAGS.step_lr = FLAGS_OLD.step_lr
484 |
485 | model, optimizer = init_model(FLAGS, device, dataset)
486 | state_dict = model.state_dict()
487 |
488 | model.load_state_dict(checkpoint['model_state_dict'], strict=False)
489 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
490 | else:
491 | model, optimizer = init_model(FLAGS, device, dataset)
492 |
493 | if FLAGS.gpus > 1:
494 | sync_model(model)
495 |
496 | train_dataloader = DataLoader(dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.batch_size, shuffle=shuffle, worker_init_fn=worker_init_fn)
497 | test_dataloader = DataLoader(test_dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.batch_size, shuffle=True, pin_memory=False, drop_last=True, worker_init_fn=worker_init_fn)
498 | gen_dataloader = DataLoader(gen_dataset, num_workers=FLAGS.data_workers, batch_size=FLAGS.batch_size // 2, shuffle=True, pin_memory=False, drop_last=True, worker_init_fn=worker_init_fn)
499 |
500 | logger = SummaryWriter(logdir)
501 | it = FLAGS.resume_iter
502 |
503 | if FLAGS.train:
504 | model.train()
505 | else:
506 | model.eval()
507 |
508 | if FLAGS.train:
509 | train(train_dataloader, test_dataloader, gen_dataloader, logger, model, optimizer, FLAGS, logdir, rank_idx)
510 | else:
511 | test(test_dataloader, model, FLAGS, step=FLAGS.resume_iter)
512 |
513 |
514 | def main():
515 | FLAGS = parser.parse_args()
516 | FLAGS.replay_buffer = not FLAGS.no_replay_buffer
517 | FLAGS.vary = True
518 | logdir = osp.join(FLAGS.logdir, FLAGS.exp)
519 |
520 | if not osp.exists(logdir):
521 | os.makedirs(logdir)
522 |
523 | if FLAGS.gpus > 1:
524 | mp.spawn(main_single, nprocs=FLAGS.gpus, args=(FLAGS,))
525 | else:
526 | main_single(0, FLAGS)
527 |
528 |
529 | if __name__ == "__main__":
530 | try:
531 | torch.multiprocessing.set_start_method('spawn')
532 | except:
533 | pass
534 | main()
535 |
--------------------------------------------------------------------------------
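The default branch of `gen_answer` in graph_train.py refines a prediction by repeated gradient steps on the energy, `pred = pred - FLAGS.step_lr * im_grad`. The loop can be sketched in plain Python as below. This is a toy under stated assumptions: a hand-written quadratic energy stands in for the learned `GraphEBM`, its gradient is written out analytically instead of calling `torch.autograd.grad`, and the names `quadratic_energy`, `quadratic_energy_grad` and `refine` are hypothetical.

```python
def quadratic_energy(pred, target):
    """Toy energy: half the squared distance to a fixed target."""
    return 0.5 * sum((p - t) ** 2 for p, t in zip(pred, target))


def quadratic_energy_grad(pred, target):
    """Analytic gradient of quadratic_energy with respect to pred."""
    return [p - t for p, t in zip(pred, target)]


def refine(pred, target, num_steps, step_lr):
    """Run num_steps of gradient descent on the toy energy.

    Returns the final prediction, the prediction after every step, and
    the energy evaluated before every update, similar to the lists that
    gen_answer accumulates.
    """
    preds, energies = [], []
    for _ in range(num_steps):
        energy = quadratic_energy(pred, target)
        grad = quadratic_energy_grad(pred, target)
        # One descent step on the energy landscape.
        pred = [p - step_lr * g for p, g in zip(pred, grad)]
        preds.append(pred)
        energies.append(energy)
    return pred, preds, energies


# Hypothetical toy problem: start at zero and move toward the target.
target = [1.0, -2.0, 0.5]
start = [0.0, 0.0, 0.0]
pred, preds, energies = refine(start, target, num_steps=5, step_lr=0.5)
```

The step size of 0.5 is chosen for this toy energy only. The script's default `--step_lr` of 100.0 would make this quadratic diverge, so the two values are not comparable.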
/dataset.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import os
3 | from tqdm import tqdm
4 | import os.path as osp
5 | import numpy as np
6 | import json
7 | import torchvision.transforms.functional as TF
8 | import random
9 | from torchvision.datasets import ImageFolder
10 |
11 | from PIL import Image
12 | import torch.utils.data as data
13 | import torch
14 | import cv2
15 | from torchvision import transforms
16 | import glob
17 |
18 | from glob import glob
19 | from imageio import imread
20 | from skimage.transform import resize as imresize
21 |
22 | from torchvision.datasets import CIFAR100
23 | import numpy as np
24 |
25 | from scipy.sparse import csr_matrix
26 | from scipy.sparse.csgraph import shortest_path
27 | from copy import deepcopy
28 | import time
29 | from scipy.linalg import lu
30 |
31 |
32 | def extract(a, t, x_shape):
33 | b, *_ = t.shape
34 | out = a.gather(-1, t)
35 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
36 |
37 |
38 | def cosine_beta_schedule(timesteps, s=0.008):
39 | """
40 | cosine schedule
41 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
42 | """
43 | steps = timesteps + 1
44 | x = np.linspace(0, steps, steps)
45 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
46 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
47 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
48 | return np.clip(betas, a_min=0, a_max=0.999)
49 |
50 |
51 | class NoisyWrapper:
52 |
53 | def __init__(self, dataset, timesteps):
54 |
55 | self.dataset = dataset
56 | self.timesteps = timesteps
57 | betas = cosine_beta_schedule(timesteps)
58 | self.inp_dim = dataset.inp_dim
59 | self.out_dim = dataset.out_dim
60 |
61 | alphas = 1. - betas
62 | alphas_cumprod = np.cumprod(alphas, axis=0)
63 |
64 | alphas_cumprod = np.linspace(1, 0, timesteps)
65 | self.sqrt_alphas_cumprod = torch.tensor(np.sqrt(alphas_cumprod))
66 | self.sqrt_one_minus_alphas_cumprod = torch.tensor(
67 | np.sqrt(1. - alphas_cumprod))
68 | self.extract = extract
69 |
70 | def __len__(self):
71 | return len(self.dataset)
72 |
73 | def __getitem__(self, *args, **kwargs):
74 | x, y = self.dataset.__getitem__(*args, **kwargs)
75 | x = torch.tensor(x)
76 | y = torch.tensor(y)
77 |
78 | t = torch.randint(1, self.timesteps, (1,)).long()
79 | t_next = t - 1
80 | noise = torch.randn_like(y)
81 |
82 | sample = (
83 | self.extract(
84 | self.sqrt_alphas_cumprod,
85 | t,
86 | y.shape) *
87 | y +
88 | self.extract(
89 | self.sqrt_one_minus_alphas_cumprod,
90 | t,
91 | y.shape) *
92 | noise)
93 |
94 | sample_next = (
95 | self.extract(
96 | self.sqrt_alphas_cumprod,
97 | t_next,
98 | y.shape) *
99 | y +
100 | self.extract(
101 | self.sqrt_one_minus_alphas_cumprod,
102 | t_next,
103 | y.shape) *
104 | noise)
105 | return x, sample, sample_next
106 |
107 |
108 | def conjgrad(A, b, x, num_steps=20):
109 | """
110 | A function to solve [A]{x} = {b} linear equation system with the
111 | conjugate gradient method.
112 | More at: http://en.wikipedia.org/wiki/Conjugate_gradient_method
113 | ========== Parameters ==========
114 | A : matrix
115 | A real symmetric positive definite matrix.
116 | b : vector
117 | The right hand side (RHS) vector of the system.
118 | x : vector
119 | The starting guess for the solution.
120 | """
121 | r = b - np.dot(A, x)
122 | p = r
123 | rsold = np.dot(np.transpose(r), r)
124 |
125 | sols = [x.flatten()]
126 |
127 | for i in range(num_steps):
128 | Ap = np.dot(A, p)
129 | alpha = rsold / np.dot(np.transpose(p), Ap)
130 | x = x + alpha[0, 0] * p
131 | r = r - alpha[0, 0] * Ap
132 | rsnew = np.dot(np.transpose(r), r)
133 | # if np.sqrt(rsnew) < 1e-8:
134 | # break
135 | p = r + (rsnew / rsold) * p
136 | rsold = rsnew
137 | sols.append(x.flatten())
138 |
139 | sols = np.stack(sols, axis=0)
140 | return sols
141 |
142 |
143 | class FiniteWrapper(data.Dataset):
144 | """Caches a fixed number of samples from a wrapped dataset on disk and serves them by index"""
145 |
146 | def __init__(self, dataset, string, capacity, rank, num_steps):
147 | """Initialize this dataset class.
148 |
149 | Parameters:
150 | dataset -- wrapped dataset to draw samples from; capacity (int) -- number of samples to cache; string, rank, num_steps -- used only to name the cache file
151 | """
152 |
153 | cache_file = "cache/{}_{}_{}_{}.npz".format(
154 | capacity, string, rank, num_steps)
155 | self.inp_dim = dataset.inp_dim
156 | self.out_dim = dataset.out_dim
157 |
158 | if osp.exists(cache_file):
159 | data = np.load(cache_file)
160 | inps = data['inp']
161 | outs = data['out']
162 | else:
163 | print("Generating static dataset for training.....")
164 | inps = []
165 | outs = []
166 |
167 | for i in tqdm(range(capacity)):
168 | inp, out, trace = dataset[i]
169 | inps.append(inp)
170 | outs.append(out)
171 |
172 | inps = np.array(inps)
173 | outs = np.array(outs)
174 |
175 | np.savez(cache_file, inp=inps, out=outs)
176 |
177 | self.inp = inps
178 | self.out = outs
179 |
180 | def __getitem__(self, index):
181 | """Return a cached (input, target) pair.
182 |
183 | Parameters:
184 | index -- integer index into the cached arrays
185 |
186 | Returns a tuple (inp, out)
187 | inp(np.ndarray) -- cached input at position index
188 | out(np.ndarray) -- cached target at position index
189 | """
190 | inp = self.inp[index]
191 | out = self.out[index]
192 |
193 | return inp, out
194 |
195 | def __len__(self):
196 | """Return the number of cached samples."""
197 | # The cache holds a fixed number of samples
198 | return self.inp.shape[0]
199 |
200 |
201 | class Identity(data.Dataset):
202 | """Generates random square matrices whose target equals the input (identity task)"""
203 |
204 | def __init__(self, split, rank, num_steps=5):
205 | """Initialize this dataset class.
206 |
207 | Parameters:
208 | split (str) -- name of the split; rank (int) -- side length of each square matrix; num_steps (int) -- number of repeated target rows built in __getitem__
209 | """
210 | self.h = rank
211 | self.w = rank
212 | self.num_steps = num_steps
213 |
214 | self.split = split
215 | self.inp_dim = self.h * self.w
216 | self.out_dim = self.h * self.w
217 |
218 | def __getitem__(self, index):
219 | """Return a randomly generated (input, target) pair.
220 |
221 | Parameters:
222 | index -- ignored; every call draws a fresh random matrix
223 |
224 | Returns a tuple (inp, out)
225 | inp(np.ndarray) -- flattened h * w matrix with entries uniform in [0, 1)
226 | out(np.ndarray) -- the same flattened matrix (identity task)
227 | """
228 |
229 | R = np.random.uniform(0, 1, (self.h, self.w))
230 | R_corrupt = R
231 |
232 | R_list = np.tile(R.flatten()[None, :], (self.num_steps, 1))
233 | R_list = np.concatenate([R_corrupt.flatten()[None, :], R_list], axis=0)
234 |
235 | return R_corrupt.flatten(), R.flatten()
236 |
237 | def __len__(self):
238 | """Return the total number of images in the dataset."""
239 | # Dataset is always randomly generated
240 | return int(1e6)
241 |
242 |
243 | class LowRankDataset(data.Dataset):
244 |     """Matrix completion task: recover a noisy low-rank matrix from a randomly masked copy."""
245 |
246 | def __init__(self, split, rank, ood):
247 | """Initialize this dataset class.
248 |
249 | Parameters:
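250 |             split -- split name (stored, not used for generation); rank -- side length of the square matrix (the factor rank is fixed at 10); ood -- if True, scale the low-rank term by 1/5 instead of 1/20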
251 | """
252 | self.h = rank
253 | self.w = rank
254 |
255 | self.split = split
256 | self.inp_dim = self.h * self.w
257 | self.out_dim = self.h * self.w
258 | self.ood = ood
259 |
260 | def __getitem__(self, index):
261 |         """Return one freshly generated (input, target) pair.
262 | 
263 |         Parameters:
264 |             index -- ignored; a new random sample is drawn on every call
265 | 
266 |         Returns a tuple (inp, out) of flattened NumPy arrays
267 |             inp -- the target with about half of its entries zeroed by a random binary mask
268 |             out -- noisy low-rank matrix 0.1 * noise + U V^T / 20 (U V^T / 5 when ood is set)
269 | """
270 |
271 | U = np.random.randn(self.h, 10)
272 | V = np.random.randn(self.w, 10)
273 |
274 | if self.ood:
275 | R = 0.1 * np.random.randn(self.h, self.w) + np.dot(U, V.T) / 5
276 | else:
277 | R = 0.1 * np.random.randn(self.h, self.w) + np.dot(U, V.T) / 20
278 |
279 | mask = np.round(np.random.rand(self.h, self.w))
280 |
281 | R_corrupt = R * mask
282 | return R_corrupt.flatten(), R.flatten()
283 |
284 | def __len__(self):
285 | """Return the total number of images in the dataset."""
286 | return int(1e6)
287 |
288 |
289 | class Negate(data.Dataset):
290 |     """Negation task: the input is the element-wise negation of the random target matrix."""
291 |
292 | def __init__(self, split, rank):
293 | """Initialize this dataset class.
294 |
295 | Parameters:
296 |             split -- split name (stored, not used for generation); rank -- side length of the square matrix
297 | """
298 | self.h = rank
299 | self.w = rank
300 |
301 | self.split = split
302 | self.inp_dim = self.h * self.w
303 | self.out_dim = self.h * self.w
304 |
305 | def __getitem__(self, index):
306 |         """Return one freshly generated (input, target) pair.
307 | 
308 |         Parameters:
309 |             index -- ignored; a new random sample is drawn on every call
310 | 
311 |         Returns a tuple (inp, out) of flattened NumPy arrays
312 |             inp -- the negated matrix -R, shape (rank * rank,)
313 |             out -- the matrix R with entries in [0, 1), shape (rank * rank,)
314 | """
315 |
316 | R = np.random.uniform(0, 1, (self.h, self.w))
317 | R_corrupt = -1 * R
318 |
319 | return R_corrupt.flatten(), R.flatten()
320 |
321 | def __len__(self):
322 | """Return the total number of images in the dataset."""
323 | # Dataset is always randomly generated
324 | return int(1e6)
325 |
326 |
327 | class Addition(data.Dataset):
328 |     """Addition task: the input concatenates two random arrays and the target is their element-wise sum."""
329 |
330 | def __init__(self, split, rank, ood):
331 | """Initialize this dataset class.
332 |
333 | Parameters:
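334 |             split -- split name (stored, not used for generation); rank -- each operand has shape (rank, rank), or (2, 1) when rank == 2; ood -- if True, sample operands from [-2.5, 2.5) instead of [-1, 1)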
335 | """
336 | self.h = rank
337 |
338 | if self.h == 2:
339 | self.w = 1
340 | else:
341 | self.w = rank
342 | self.ood = ood
343 |
344 | self.split = split
345 | self.inp_dim = 2 * self.h * self.w
346 | self.out_dim = self.h * self.w
347 |
348 | def __getitem__(self, index):
349 |         """Return one freshly generated (input, target) pair.
350 | 
351 |         Parameters:
352 |             index -- ignored; a new random sample is drawn on every call
353 | 
354 |         Returns a tuple (inp, out) of flattened NumPy arrays
355 |             inp -- both operands concatenated, shape (2 * h * w,)
356 |             out -- their element-wise sum, shape (h * w,)
357 | """
358 |
359 | if self.ood:
360 | scale = 2.5
361 | else:
362 | scale = 1.0
363 |
364 | R_one = np.random.uniform(-scale, scale, (self.h, self.w)).flatten()
365 | R_two = np.random.uniform(-scale, scale, (self.h, self.w)).flatten()
366 | R_corrupt = np.concatenate([R_one, R_two], axis=0)
367 | R = R_one + R_two
368 |
369 | return R_corrupt.flatten(), R.flatten()
370 |
371 | def __len__(self):
372 | """Return the total number of images in the dataset."""
373 | # Dataset is always randomly generated
374 | return int(1e6)
375 |
376 |
377 | class Square(data.Dataset):
378 |     """Matrix squaring task: the target is the matrix product of the input with itself, divided by 5."""
379 |
380 | def __init__(self, split, rank, num_steps=10):
381 | """Initialize this dataset class.
382 |
383 | Parameters:
384 |             split -- split name (stored, not used for generation); rank -- side length of the square matrix; num_steps -- number of reasoning steps
385 | """
386 | self.h = rank
387 | self.w = rank
388 | self.num_steps = num_steps
389 |
390 | self.split = split
391 | self.inp_dim = self.h * self.w
392 | self.out_dim = self.h * self.w
393 |
394 | def __getitem__(self, index):
395 |         """Return one freshly generated (input, target) pair.
396 | 
397 |         Parameters:
398 |             index -- ignored; a new random sample is drawn on every call
399 | 
400 |         Returns a tuple (inp, out) of flattened NumPy arrays
401 |             inp -- random matrix with entries in [-1, 1), shape (rank * rank,)
402 |             out -- matmul(inp, inp) / 5, shape (rank * rank,)
403 | """
404 |
405 | R_corrupt = np.random.uniform(-1, 1, (self.h, self.w))
406 | R = np.matmul(R_corrupt, R_corrupt).flatten() / 5.
407 | # R = R_corrupt * R_corrupt
408 |
409 | R_list = np.tile(R.flatten()[None, :], (self.num_steps, 1))
410 | R_list = np.concatenate([R_corrupt.flatten()[None, :], R_list], axis=0)
411 |
412 | return R_corrupt.flatten(), R.flatten()
413 |
414 | def __len__(self):
415 | """Return the total number of images in the dataset."""
416 | # Dataset is always randomly generated
417 | return int(1e6)
418 |
419 |
420 | class Inverse(data.Dataset):
421 |     """Matrix inversion task: the target is the inverse of a random symmetric input matrix."""
422 |
423 | def __init__(self, split, rank, ood):
424 | """Initialize this dataset class.
425 |
426 | Parameters:
427 |             split -- split name (stored, not used for generation); rank -- side length of the square matrix; ood -- if True, add 0.1 * I instead of 0.5 * I before inverting
428 | """
429 | self.h = rank
430 | self.w = rank
431 | self.ood = ood
432 |
433 | self.split = split
434 | self.inp_dim = self.h * self.w
435 | self.out_dim = self.h * self.w
436 |
437 | def __getitem__(self, index):
438 |         """Return one freshly generated (input, target) pair.
439 | 
440 |         Parameters:
441 |             index -- ignored; a new random sample is drawn on every call
442 | 
443 |         Returns a tuple (inp, out) of flattened NumPy arrays
444 |             inp -- symmetric matrix 2 * X X^T + c * I with c = 0.5 (0.1 when ood is set)
445 |             out -- the inverse of that matrix
446 | """
447 |
448 | R_corrupt = np.random.uniform(-1, 1, (self.h, self.w))
449 | R_corrupt = R_corrupt.dot(R_corrupt.transpose())
450 | # R_corrupt = R_corrupt + 0.5 * np.eye(self.h)
451 |
452 | if self.ood:
453 | R_corrupt = R_corrupt + R_corrupt.transpose() + 0.1 * np.eye(self.h)
454 | else:
455 | R_corrupt = R_corrupt + R_corrupt.transpose() + 0.5 * np.eye(self.h)
456 |
457 | R = np.linalg.inv(R_corrupt)
458 |
459 | return R_corrupt.flatten(), R.flatten()
460 |
461 | def __len__(self):
462 | """Return the total number of images in the dataset."""
463 | # Dataset is always randomly generated
464 | return int(1e6)
465 |
466 |
467 | class Equation(data.Dataset):
468 |     """Linear system task: given A and b = Ax, the target is the solution vector x."""
469 |
470 | def __init__(self, split, rank, num_steps=10):
471 | """Initialize this dataset class.
472 |
473 | Parameters:
474 |             split -- split name (stored, not used for generation); rank -- number of unknowns (A is rank x rank); num_steps -- number of reasoning steps
475 | """
476 | self.h = rank
477 | self.w = rank
478 |
479 | self.split = split
480 | self.inp_dim = rank * rank + rank
481 | self.out_dim = rank
482 | self.num_steps = num_steps
483 |
484 | def __getitem__(self, index):
485 |         """Return one freshly generated (input, target) pair.
486 | 
487 |         Parameters:
488 |             index -- ignored; a new random sample is drawn on every call
489 | 
490 |         Returns a tuple (inp, sol) of flattened NumPy arrays
491 |             inp -- A and b = A x concatenated, shape (rank * rank + rank,)
492 |             sol -- the solution x, shape (rank,)
493 | """
494 |
495 | A = np.random.uniform(-1, 1, (self.h, self.w))
496 | A = A + A.transpose()
497 | A = np.matmul(A, A.transpose())
498 |
499 | x = np.random.uniform(-1, 1, (self.h, 1))
500 | b = np.matmul(A, x)
501 |
502 | inp = np.concatenate([A.flatten(), b.flatten()], axis=0)
503 | sol = x.flatten()
504 |
505 | x_guess = np.random.uniform(-1, 1, (self.h, 1))
506 |
507 | return inp, sol
508 |
509 | def __len__(self):
510 | """Return the total number of images in the dataset."""
511 | # Dataset is always randomly generated
512 | return int(1e6)
513 |
514 |
515 | class LU(data.Dataset):
516 |     """LU decomposition task: the target concatenates the L and U factors of a random input matrix."""
517 |
518 | def __init__(self, split, rank):
519 | """Initialize this dataset class.
520 |
521 | Parameters:
522 |             split -- split name (stored, not used for generation); rank -- side length of the square matrix
523 | """
524 | self.h = rank
525 | self.w = rank
526 |
527 | self.split = split
528 | self.inp_dim = self.h * self.w
529 | self.out_dim = 2 * self.h * self.w
530 |
531 | def __getitem__(self, index):
532 |         """Return one freshly generated (input, target) pair.
533 | 
534 |         Parameters:
535 |             index -- ignored; a new random sample is drawn on every call
536 | 
537 |         Returns a tuple (inp, out) of flattened NumPy arrays
538 |             inp -- random matrix with entries in [-1, 1), shape (rank * rank,)
539 |             out -- L and U factors concatenated, shape (2 * rank * rank,); the permutation p is discarded
540 | """
541 |
542 | R_corrupt = np.random.uniform(-1, 1, (self.h, self.w))
543 | p, l, u = lu(R_corrupt)
544 | R = np.concatenate([l.flatten(), u.flatten()], axis=0)
545 |
546 | return R_corrupt.flatten(), R
547 |
548 | def __len__(self):
549 | """Return the total number of images in the dataset."""
550 | # Dataset is always randomly generated
551 | return int(1e6)
552 |
553 |
554 | class Det(data.Dataset):
555 |     """Determinant task: the target is 10 times the determinant of a random input matrix."""
556 |
557 | def __init__(self, split, rank):
558 | """Initialize this dataset class.
559 |
560 | Parameters:
561 |             split -- split name (stored, not used for generation); rank -- side length of the square matrix
562 | """
563 | self.h = rank
564 | self.w = rank
565 |
566 | self.split = split
567 | self.inp_dim = self.h * self.w
568 | self.out_dim = 1
569 |
570 | def __getitem__(self, index):
571 |         """Return one freshly generated (input, target) pair.
572 | 
573 |         Parameters:
574 |             index -- ignored; a new random sample is drawn on every call
575 | 
576 |         Returns a tuple (inp, out) of NumPy arrays
577 |             inp -- random matrix with entries in [0, 1), flattened to shape (rank * rank,)
578 |             out -- array of shape (1,) holding 10 * det(inp)
579 | """
580 |
581 | R_corrupt = np.random.uniform(0, 1, (self.h, self.w))
582 | R = np.linalg.det(R_corrupt) * 10
583 |
584 | return R_corrupt.flatten(), np.array([R])
585 |
586 | def __len__(self):
587 | """Return the total number of images in the dataset."""
588 | # Dataset is always randomly generated
589 | return int(1e6)
590 |
591 |
592 | class ShortestPath(data.Dataset):
593 |     """Shortest path task: the target holds all-pairs shortest-path distances of a random weighted undirected graph."""
594 |
595 | def __init__(self, split, rank=20, num_steps=10, vary=False):
596 | """Initialize this dataset class.
597 |
598 | Parameters:
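599 |             split -- split name (stored, not used for generation); rank -- number of graph nodes; num_steps -- number of reasoning steps; vary -- if True, draw the node count from [2, rank] for every sample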
600 | """
601 | self.h = rank
602 | self.w = rank
603 | self.rank = rank
604 | self.edge_prob = 0.3
605 | self.vary = vary
606 |
607 | self.inp_dim = self.h * self.w
608 | self.out_dim = self.h * self.w
609 | self.split = split
610 | self.num_steps = num_steps
611 |
612 | def __getitem__(self, index):
613 |         """Return one freshly generated (input, target) pair.
614 | 
615 |         Parameters:
616 |             index -- ignored; a new random sample is drawn on every call
617 | 
618 |         Returns a tuple (inp, out) of flattened NumPy arrays
619 |             inp -- symmetric edge-weight matrix with a zero diagonal
620 |             out -- matrix of all-pairs shortest-path distances
621 | """
622 |
623 | # if self.split == "train":
624 | # np.random.seed(index % 10000)
625 |
626 | if self.vary:
627 | rank = random.randint(2, self.rank)
628 | else:
629 | rank = self.rank
630 |
631 | graph = np.random.uniform(0, 1, size=[rank, rank])
632 | graph = graph + graph.transpose()
633 | np.fill_diagonal(graph, 0) # can move to self
634 |
635 | graph_dist, graph_predecessors = shortest_path(csgraph=csr_matrix(
636 | graph), unweighted=False, directed=False, return_predecessors=True)
637 |
638 | return graph.flatten(), graph_dist.flatten()
639 |
640 | def __len__(self):
641 | """Return the total number of images in the dataset."""
642 | # Dataset is always randomly generated
643 | return int(1e6)
644 |
645 |
646 | class Sort(data.Dataset):
647 |     """Sorting task: the target is the input vector sorted in ascending order."""
648 |
649 | def __init__(self, split, rank=20, num_steps=10):
650 | """Initialize this dataset class.
651 |
652 | Parameters:
653 |             split -- split name (stored, not used for generation); rank -- the vector to sort has rank * rank entries; num_steps -- number of reasoning steps
654 | """
655 | self.h = rank
656 | self.w = rank
657 | self.edge_prob = 0.3
658 | self.num_steps = num_steps
659 |
660 | self.inp_dim = self.h * self.w
661 | self.out_dim = self.h * self.w
662 | self.split = split
663 |
664 | def __getitem__(self, index):
665 |         """Return one freshly generated (input, target) pair.
666 | 
667 |         Parameters:
668 |             index -- ignored; a new random sample is drawn on every call
669 | 
670 |         Returns a tuple (x, x_sort) of NumPy arrays
671 |             x -- random vector with entries in [-1, 1), shape (rank * rank,)
672 |             x_sort -- the same entries sorted in ascending order
673 | """
674 |
675 | x = np.random.uniform(-1, 1, self.inp_dim)
676 | rix = np.random.permutation(x.shape[0])
677 | x = x[rix]
678 | x_sort = np.sort(x)
679 |
680 | return x, x_sort
681 |
682 | def __len__(self):
683 | """Return the total number of images in the dataset."""
684 | # Dataset is always randomly generated
685 | return int(1e6)
686 |
687 |
688 | class Eigen(data.Dataset):
689 |     """Eigenvector task: the target is the dominant eigenvector of a random symmetric positive semi-definite matrix."""
690 |
691 | def __init__(self, split, rank=20, num_steps=10):
692 | """Initialize this dataset class.
693 |
694 | Parameters:
695 |             split -- split name (stored, not used for generation); rank -- side length of the square matrix; num_steps -- number of reasoning steps
696 | """
697 | self.h = rank
698 | self.w = rank
699 | self.edge_prob = 0.3
700 | self.num_steps = num_steps
701 |
702 | self.inp_dim = self.h * self.w
703 | self.out_dim = self.h
704 | self.split = split
705 |
706 | def __getitem__(self, index):
707 |         """Return one freshly generated (input, target) pair.
708 | 
709 |         Parameters:
710 |             index -- ignored; a new random sample is drawn on every call
711 | 
712 |         Returns a tuple (inp, eig) of NumPy arrays
713 |             inp -- symmetric positive semi-definite matrix, flattened to shape (rank * rank,)
714 |             eig -- eigenvector of the largest-magnitude eigenvalue, sign-flipped so its first entry is non-negative
715 | """
716 |
717 | x = np.random.uniform(-1, 1, (self.h, self.w))
718 | x = (x + x.transpose()) / 2
719 | x = np.matmul(x, x.transpose())
720 | val, eig = np.linalg.eig(x)
721 | rix = np.argsort(np.abs(val))[::-1]
722 |
723 | val = val[rix]
724 | eig = eig[:, rix]
725 |
726 | vecs = []
727 | vec = np.random.uniform(-1, 1, (self.h, 1))
728 | vecs.append(vec)
729 | eig = eig[:, 0]
730 | sign = np.sign(eig[0])
731 | eig = eig * sign
732 |
733 | return x.flatten(), eig
734 |
735 | def __len__(self):
736 | """Return the total number of images in the dataset."""
737 | # Dataset is always randomly generated
738 | return int(1e6)
739 |
740 |
741 | class QR(data.Dataset):
742 |     """QR decomposition task: the target concatenates the Q and R factors of a random input matrix."""
743 |
744 | def __init__(self, split, rank=20, num_steps=10):
745 | """Initialize this dataset class.
746 |
747 | Parameters:
748 |             split -- split name (stored, not used for generation); rank -- side length of the square matrix; num_steps -- number of reasoning steps
749 | """
750 | self.h = rank
751 | self.w = rank
752 | self.edge_prob = 0.3
753 |
754 | self.inp_dim = self.h * self.w
755 | self.out_dim = self.h * self.w * 2
756 | self.split = split
757 | self.num_steps = num_steps
758 |
759 | def __getitem__(self, index):
760 |         """Return one freshly generated (input, target) pair.
761 | 
762 |         Parameters:
763 |             index -- ignored; a new random sample is drawn on every call
764 | 
765 |         Returns a tuple (inp, flat) of flattened NumPy arrays
766 |             inp -- random matrix with entries in [-1, 1), shape (rank * rank,)
767 |             flat -- Q and R factors concatenated, shape (2 * rank * rank,)
768 | """
769 |
770 | x = np.random.uniform(-1, 1, (self.h, self.w))
771 | q, r = np.linalg.qr(x)
772 | flat = np.concatenate([q.flatten(), r.flatten()])
773 |
774 | flat_list = np.tile(flat[None, :], (self.num_steps, 1))
775 |         flat_start = np.random.uniform(-1, 1, flat.shape)
776 | flat_list = np.concatenate([flat_start[None, :], flat_list], axis=0)
777 |
778 | return x.flatten(), flat # , flat_list
779 |
780 | def __len__(self):
781 | """Return the total number of images in the dataset."""
782 | # Dataset is always randomly generated
783 | return int(1e6)
784 |
785 |
786 | class Parity(data.Dataset):
787 |     """Parity task: the target is the parity (sum modulo 2) of a random binary input vector."""
788 |
789 | def __init__(self, split, rank=20, num_steps=10):
790 | """Initialize this dataset class.
791 |
792 | Parameters:
793 |             split -- split name (stored, not used for generation); rank -- length of the binary input vector; num_steps -- number of reasoning steps
794 | """
795 | self.h = rank
796 | self.w = rank
797 | self.edge_prob = 0.3
798 |
799 | self.inp_dim = rank
800 | self.out_dim = 1
801 | self.split = split
802 | self.num_steps = num_steps
803 |
804 | def __getitem__(self, index):
805 |         """Return one freshly generated (input, target) pair.
806 | 
807 |         Parameters:
808 |             index -- ignored; a new random sample is drawn on every call
809 | 
810 |         Returns a tuple (mask, parity) of NumPy arrays
811 |             mask -- random 0/1 vector (float32), shape (rank,)
812 |             parity -- array of shape (1,) holding sum(mask) modulo 2
813 | """
814 |
815 | val = np.random.uniform(0, 1, size=[self.inp_dim])
816 | mask = (val > 0.5).astype(np.float32).flatten()
817 | parity = mask.sum() % 2
818 | parity = np.array([parity])
819 |
820 | rand_val = np.random.uniform(0, 1) > 0.5
821 | parity_noise = np.array([rand_val])
822 |
823 | return mask, parity
824 |
825 | def __len__(self):
826 | """Return the total number of images in the dataset."""
827 | # Dataset is always randomly generated
828 | return int(1e6)
829 |
830 |
831 | if __name__ == "__main__":
832 | A = np.random.uniform(-1, 1, (5, 5))
833 | A = 0.5 * (A + A.transpose())
834 | A = np.matmul(A, A.transpose())
835 | x = np.random.uniform(-1, 1, (5, 1))
836 | b = np.random.uniform(-1, 1, (5, 1))
837 | sol = conjgrad(A, b, x)
838 | import pdb
839 | pdb.set_trace()
840 | print(A)
841 | print(sol)
842 | print(b)
843 |
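The `ShortestPath` dataset above takes its targets from a sparse-graph shortest-path routine. As a minimal sketch, assuming a small dense symmetric weight matrix, such targets could be cross-checked with a plain Floyd-Warshall pass. The helpers `floyd_warshall` and `random_symmetric_graph` are hypothetical names for illustration and are not part of the repository.

```python
import random


def floyd_warshall(weights):
    """All-pairs shortest-path distances for a dense weight matrix (list of lists)."""
    n = len(weights)
    dist = [row[:] for row in weights]  # copy so the input is left untouched
    for k in range(n):
        for i in range(n):
            for j in range(n):
                # Relax the path i -> k -> j if it is shorter than the best known i -> j
                if dist[i][k] + dist[k][j] < dist[i][j]:
                    dist[i][j] = dist[i][k] + dist[k][j]
    return dist


def random_symmetric_graph(n, rng):
    """Symmetric weights with a zero diagonal, mirroring ShortestPath.__getitem__."""
    g = [[rng.uniform(0, 1) for _ in range(n)] for _ in range(n)]
    sym = [[g[i][j] + g[j][i] for j in range(n)] for i in range(n)]
    for i in range(n):
        sym[i][i] = 0.0
    return sym


graph = random_symmetric_graph(5, random.Random(0))
dist = floyd_warshall(graph)
```

Every entry of `dist` is at most the corresponding direct edge weight in `graph`, which is the property the learned model has to reproduce.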
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import EBM, FC, IterativeFC, IterativeAttention, IterativeFCAttention, \
3 | IterativeTransformer, EBMTwin, RecurrentFC, PonderFC
4 | import torch.nn.functional as F
5 | import os
6 | import pdb
7 | from dataset import LowRankDataset, ShortestPath, Negate, Inverse, Square, Identity, \
8 | Det, LU, Sort, Eigen, QR, Equation, FiniteWrapper, Parity, Addition
9 | import matplotlib.pyplot as plt
10 | import torch.multiprocessing as mp
11 | import torch.distributed as dist
12 | from torch.optim import Adam
13 | from torch.utils.tensorboard import SummaryWriter
14 | from torch.utils.data import DataLoader
15 | import os.path as osp
16 | import numpy as np
17 | from imageio import imwrite
18 | import argparse
19 | import torchvision.transforms as transforms
20 | import torch.nn as nn
21 | import torch.optim as optim
22 | import torch.backends.cudnn as cudnn
23 | import random
24 | from torchvision.utils import make_grid
25 | import seaborn as sns
26 |
27 |
28 | def worker_init_fn(worker_id):
29 | np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2**32 - 1))
30 |
31 |
32 | class ReplayBuffer(object):
33 | def __init__(self, size):
34 | """Create Replay buffer.
35 | Parameters
36 | ----------
37 | size: int
38 | Max number of transitions to store in the buffer. When the buffer
39 | overflows the old memories are dropped.
40 | """
41 | self._storage = []
42 | self._maxsize = size
43 | self._next_idx = 0
44 |
45 | def __len__(self):
46 | return len(self._storage)
47 |
48 | def add(self, inputs):
49 | batch_size = len(inputs)
50 | if self._next_idx >= len(self._storage):
51 | self._storage.extend(inputs)
52 | else:
53 | if batch_size + self._next_idx < self._maxsize:
54 | self._storage[self._next_idx:self._next_idx +
55 | batch_size] = inputs
56 | else:
57 | split_idx = self._maxsize - self._next_idx
58 | self._storage[self._next_idx:] = inputs[:split_idx]
59 | self._storage[:batch_size - split_idx] = inputs[split_idx:]
60 | self._next_idx = (self._next_idx + batch_size) % self._maxsize
61 |
62 | def _encode_sample(self, idxes):
63 | inps = []
64 | opts = []
65 | targets = []
66 | scratchs = []
67 |
68 | # Store in the intermediate state of optimization problem
69 | for i in idxes:
70 | inp, opt, target, scratch = self._storage[i]
71 | opt = opt
72 | inps.append(inp)
73 | opts.append(opt)
74 | targets.append(target)
75 | scratchs.append(scratch)
76 |
77 | inps = np.array(inps)
78 | opts = np.array(opts)
79 | targets = np.array(targets)
80 | scratchs = np.array(scratchs)
81 |
82 | return inps, opts, targets, scratchs
83 |
84 | def sample(self, batch_size):
85 |         """Sample a batch of stored optimization states.
86 |         Parameters
87 |         ----------
88 |         batch_size: int
89 |             How many entries to sample (uniformly, with replacement).
90 |         Returns
91 |         -------
92 |         inps: np.array
93 |             batch of problem inputs
94 |         opts: np.array
95 |             batch of stored optimization states (intermediate predictions)
96 |         targets: np.array
97 |             batch of ground-truth targets
98 |         scratchs: np.array
99 |             batch of scratchpad states
100 |         idxes: torch.Tensor
101 |             storage indices of the sampled entries; the four arrays above are
102 |             returned together as one tuple, followed by idxes
103 |         """
104 | idxes = [random.randint(0, len(self._storage) - 1)
105 | for _ in range(batch_size)]
106 | return self._encode_sample(idxes), torch.Tensor(idxes)
107 |
108 | def set_elms(self, data, idxes):
109 | if len(self._storage) < self._maxsize:
110 | self.add(data)
111 | else:
112 | for i, ix in enumerate(idxes):
113 | self._storage[ix] = data[i]
114 |
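The `ReplayBuffer` above behaves like a ring buffer: it grows until it is full and then overwrites the oldest entries. Here is a simplified per-item sketch of that idea. It is not the batched slice logic used in `add`, and `ring_add` is a hypothetical helper name.

```python
def ring_add(storage, next_idx, maxsize, items):
    """Append items until storage is full, then overwrite the oldest slots.

    Returns the updated write position.
    """
    for item in items:
        if len(storage) < maxsize:
            storage.append(item)
        else:
            storage[next_idx] = item
        next_idx = (next_idx + 1) % maxsize
    return next_idx


storage = []
idx = ring_add(storage, 0, 4, [1, 2, 3])
idx = ring_add(storage, idx, 4, [4, 5, 6])  # wraps around and overwrites 1 and 2
```

After the second call the buffer holds `[5, 6, 3, 4]` and the next write position is 2.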
115 |
116 | # Parse input arguments
117 | parser = argparse.ArgumentParser(description='Train EBM model')
118 |
119 | parser.add_argument('--train', action='store_true',
120 | help='whether or not to train')
121 | parser.add_argument('--cuda', action='store_true',
122 | help='whether to use cuda or not')
123 | parser.add_argument('--no_replay_buffer', action='store_true',
124 | help='do not use a replay buffer to train models')
125 | parser.add_argument('--dataset', default='negate', type=str,
126 | help='dataset to evaluate')
127 | parser.add_argument('--logdir', default='cachedir', type=str,
128 | help='location where log of experiments will be stored')
129 | parser.add_argument('--exp', default='default', type=str,
130 | help='name of experiments')
131 |
132 | # training
133 | parser.add_argument('--resume_iter', default=0, type=int,
134 | help='iteration to resume training')
135 | parser.add_argument('--batch_size', default=512, type=int,
136 | help='size of batch of input to use')
137 | parser.add_argument('--num_epoch', default=10000, type=int,
138 | help='number of epochs of training to run')
139 | parser.add_argument('--lr', default=1e-4, type=float,
140 | help='learning rate for training')
141 | parser.add_argument('--log_interval', default=10, type=int,
142 | help='log outputs every so many batches')
143 | parser.add_argument('--save_interval', default=1000, type=int,
144 | help='save outputs every so many batches')
145 |
146 | # data
147 | parser.add_argument('--data_workers', default=4, type=int,
148 | help='Number of different data workers to load data in parallel')
149 |
150 | # Model specific settings
151 | parser.add_argument('--rank', default=20, type=int,
152 | help='rank of matrix to use')
153 | parser.add_argument('--num_steps', default=10, type=int,
154 | help='Steps of gradient descent for training')
155 | parser.add_argument('--step_lr', default=100.0, type=float,
156 | help='step size of latents')
157 | parser.add_argument('--ood', action='store_true',
158 | help='test on the harder ood dataset')
159 | parser.add_argument('--recurrent', action='store_true',
160 | help='utilize a recurrent model to output prediction')
161 | parser.add_argument('--ponder', action='store_true',
162 | help='utilize a ponder network model to output prediction')
163 | parser.add_argument('--decoder', action='store_true',
164 | help='utilize a decoder network to output prediction')
165 | parser.add_argument('--iterative_decoder', action='store_true',
166 |                     help='utilize an iterative decoder to output prediction')
167 | parser.add_argument('--mem', action='store_true',
168 | help='add external memory to compute answers')
169 | parser.add_argument('--no_truncate', action='store_true',
170 |                     help="don't truncate gradient backprop")
171 |
172 | # Distributed training hyperparameters
173 | parser.add_argument('--gpus', default=1, type=int,
174 | help='number of gpus to train with')
175 | parser.add_argument('--node_rank', default=0, type=int, help='rank of node')
176 | parser.add_argument('--capacity', default=50000, type=int,
177 | help='number of elements to generate')
178 | parser.add_argument('--infinite', action='store_true',
179 | help='makes the dataset have an infinite number of elements')
180 |
181 |
182 | best_test_error_10 = 10.0
183 | best_test_error_20 = 10.0
184 | best_test_error_40 = 10.0
185 | best_test_error_80 = 10.0
186 | best_test_error = 10.0
187 |
188 |
189 | def average_gradients(model):
190 | size = float(dist.get_world_size())
191 |
192 | for name, param in model.named_parameters():
193 | if param.grad is None:
194 | continue
195 |
196 |         dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
197 | param.grad.data /= size
198 |
199 |
200 | def gen_answer(inp, FLAGS, model, pred, scratchpad, num_steps, create_graph=True):
201 | """
202 | Implement iterative reasoning to obtain the answer to a problem
203 | """
204 |
205 | # List of intermediate predictions
206 | preds = []
207 | im_grads = []
208 | energies = []
209 | logits = []
210 |
211 | if FLAGS.decoder:
212 | pred = model.forward(inp)
213 | preds = [pred]
214 | im_grad = torch.zeros(1)
215 | im_grads = [im_grad]
216 | energies = [torch.zeros(1)]
217 | elif FLAGS.recurrent:
218 | preds = []
219 | im_grad = torch.zeros(1)
220 | im_grads = [im_grad]
221 | energies = [torch.zeros(1)]
222 | state = None
223 | for i in range(num_steps):
224 | pred, state = model.forward(inp, state)
225 | preds.append(pred)
226 | elif FLAGS.ponder:
227 | im_merge = torch.cat([pred, inp], dim=-1)
228 |
229 | preds, logits = model.forward(im_merge, iters=num_steps)
230 | pred = preds[-1]
231 | im_grad = torch.zeros(1)
232 | im_grads = [im_grad]
233 | energies = [torch.zeros(1)]
234 | state = None
235 |
236 | elif FLAGS.iterative_decoder:
237 | for i in range(num_steps):
238 | energy = torch.zeros(1)
239 |
240 | noise_add = (torch.rand_like(pred) - 0.5)
241 | out_dim = model.out_dim
242 |
243 | im_merge = torch.cat([pred, inp], dim=-1)
244 | pred = model.forward(im_merge) + pred
245 |
246 | preds.append(pred)
247 | energies.append(torch.zeros(1))
248 | im_grads.append(torch.zeros(1))
249 | else:
250 | with torch.enable_grad():
251 | pred.requires_grad_(requires_grad=True)
252 | s = inp.size()
253 | scratchpad.requires_grad_(requires_grad=True)
254 | preds.append(pred)
255 |
256 | for i in range(num_steps):
257 | noise = torch.rand_like(pred) - 0.5
258 |
259 | if FLAGS.mem:
260 | im_merge = torch.cat([pred, inp, scratchpad], dim=-1)
261 | else:
262 | im_merge = torch.cat([pred, inp], dim=-1)
263 |
264 | energy = model.forward(im_merge)
265 |
266 | if FLAGS.mem:
267 | im_grad, scratchpad_grad = torch.autograd.grad(
268 | [energy.sum()], [pred, scratchpad], create_graph=create_graph)
269 | else:
270 |
271 | if FLAGS.no_truncate:
272 | im_grad, = torch.autograd.grad(
273 | [energy.sum()], [pred], create_graph=create_graph)
274 | else:
275 | if i != (num_steps - 1):
276 | im_grad, = torch.autograd.grad(
277 | [energy.sum()], [pred], create_graph=False)
278 | else:
279 | im_grad, = torch.autograd.grad(
280 | [energy.sum()], [pred], create_graph=create_graph)
281 |
282 | pred = pred - FLAGS.step_lr * im_grad
283 |
284 | if FLAGS.mem:
285 |                     scratchpad = scratchpad - FLAGS.step_lr * scratchpad_grad
286 | scratchpad = torch.clamp(scratchpad, -1, 1)
287 |
288 | preds.append(pred)
289 | energies.append(energy)
290 | im_grads.append(im_grad)
291 |
292 | return pred, preds, im_grads, energies, scratchpad, logits
293 |
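The default branch of `gen_answer` above refines a prediction by repeated gradient steps on an energy. The toy sketch below shows only that update loop, under stated assumptions: a hand-written quadratic `energy` and its analytic gradient stand in for the learned network and for `torch.autograd.grad`. All names in it are hypothetical.

```python
def energy(x, y):
    """Toy energy (an assumption): minimized when y equals 2 * x."""
    return 0.5 * (y - 2.0 * x) ** 2


def energy_grad(x, y):
    """Analytic dE/dy, standing in for an autograd call."""
    return y - 2.0 * x


def refine(x, y0, step_lr=0.5, num_steps=10):
    """Run num_steps of y <- y - step_lr * dE/dy and record every iterate."""
    y = y0
    preds = [y]
    energies = []
    for _ in range(num_steps):
        energies.append(energy(x, y))
        y = y - step_lr * energy_grad(x, y)
        preds.append(y)
    return y, preds, energies


final, preds, energies = refine(x=1.5, y0=0.0)
```

As in `gen_answer`, the list of predictions has one more entry than the list of energies because the initial guess is recorded before the first step.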
294 |
295 | def ema_model(model, model_ema, mu=0.999):
296 | for (model, model_ema) in zip(model, model_ema):
297 | for param, param_ema in zip(
298 | model.parameters(), model_ema.parameters()):
299 | param_ema.data[:] = mu * param_ema.data + (1 - mu) * param.data
300 |
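The update in `ema_model` above is an exponential moving average of parameters. A minimal sketch on plain Python floats, with `ema_update` as a hypothetical helper, shows how slowly the average tracks the live values at `mu=0.999`.

```python
def ema_update(ema_params, params, mu=0.999):
    """Blend each EMA value toward the live value: mu * ema + (1 - mu) * param."""
    return [mu * e + (1.0 - mu) * p for e, p in zip(ema_params, params)]


ema = [0.0, 0.0]
live = [1.0, -2.0]
for _ in range(5000):
    ema = ema_update(ema, live)
```

Even after 5000 updates the averaged values are still slightly short of the live ones, which is why a value of `mu` this close to 1 only makes sense for long training runs.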
301 |
302 | def sync_model(model):
303 | size = float(dist.get_world_size())
304 |
305 | for param in model.parameters():
306 | dist.broadcast(param.data, 0)
307 |
308 |
309 | def init_model(FLAGS, device, dataset):
310 | if FLAGS.decoder:
311 | model = FC(dataset.inp_dim, dataset.out_dim)
312 | elif FLAGS.recurrent:
313 | model = RecurrentFC(dataset.inp_dim, dataset.out_dim)
314 | elif FLAGS.ponder:
315 | model = PonderFC(dataset.inp_dim, dataset.out_dim, FLAGS.num_steps)
316 | elif FLAGS.iterative_decoder:
317 | model = IterativeFC(dataset.inp_dim, dataset.out_dim, FLAGS.mem)
318 | else:
319 | model = EBM(dataset.inp_dim, dataset.out_dim, FLAGS.mem)
320 |
321 | model.to(device)
322 |     optimizer = Adam(model.parameters(), lr=FLAGS.lr)
323 |
324 | return model, optimizer
325 |
326 |
327 | def safe_cumprod(t, eps=1e-10, dim=-1):
328 | t = torch.clip(t, min=eps, max=1.)
329 | return torch.exp(torch.cumsum(torch.log(t), dim=dim))
330 |
331 |
332 | def exclusive_cumprod(t, dim=-1):
333 | cum_prod = safe_cumprod(t, dim=dim)
334 | return pad_to(cum_prod, (1, -1), value=1., dim=dim)
335 |
336 |
337 | def calc_geometric(l, dim=-1):
338 | return exclusive_cumprod(1 - l, dim=dim) * l
339 |
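`calc_geometric` above turns per-step halting probabilities into a distribution over the step at which computation stops. The pure-Python sketch below assumes that `pad_to` shifts the cumulative product right by one position and fills the first slot with 1, which is not confirmed by the code shown. `halting_distribution` is a hypothetical name.

```python
def exclusive_cumprod(values):
    """out[n] = product of values[:n]; out[0] is the empty product 1.0."""
    out, running = [], 1.0
    for v in values:
        out.append(running)
        running *= v
    return out


def halting_distribution(lambdas):
    """p[n] = lambdas[n] * prod_{j<n} (1 - lambdas[j])."""
    keep_going = exclusive_cumprod([1.0 - lam for lam in lambdas])
    return [k * lam for k, lam in zip(keep_going, lambdas)]


probs = halting_distribution([0.5, 0.5, 1.0])
```

When the last per-step probability is 1, the result sums to 1, so it can be sampled from directly, as `test` does with `torch.searchsorted` on the cumulative sum.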
340 |
341 | def test(train_dataloader, model, FLAGS, step=0):
342 | global best_test_error_10, best_test_error_20, best_test_error_40, best_test_error_80, best_test_error
343 | if FLAGS.cuda:
344 | dev = torch.device("cuda")
345 | else:
346 | dev = torch.device("cpu")
347 |
348 | replay_buffer = None
349 | dist_list = []
350 | energy_list = []
351 | min_dist_list = []
352 |
353 | model.eval()
354 | counter = 0
355 |
356 | with torch.no_grad():
357 | for inp, im in train_dataloader:
358 | im = im.float().to(dev)
359 | inp = inp.float().to(dev)
360 |
361 | # Initialize prediction from random guess
362 | pred = (torch.rand_like(im) - 0.5) * 2
363 | scratch = torch.zeros_like(inp)
364 |
365 | pred_init = pred
366 | pred, preds, im_grad, energies, scratch, logits = gen_answer(
367 | inp, FLAGS, model, pred, scratch, 80)
368 | preds = torch.stack(preds, dim=0)
369 |
370 | if FLAGS.ponder:
371 | halting_probs = calc_geometric(logits.sigmoid(), dim=1)[..., 0]
372 | cum_halting_probs = torch.cumsum(halting_probs, dim=-1)
373 | rand_val = torch.rand(
374 | halting_probs.size(0)).to(
375 | cum_halting_probs.device)
376 | sort_id = torch.searchsorted(
377 | cum_halting_probs, rand_val[:, None])
378 | sort_id = torch.clamp(sort_id, 0, sort_id.size(1) - 1)
379 | sort_id = sort_id[:, :, None].expand(-1, -1, pred.size(-1))
380 |
381 | energies = torch.stack(energies, dim=0)
382 |
383 | dist = (preds - im[None, :])
384 | dist = torch.pow(dist, 2)
385 |
386 | dist = dist.mean(dim=-1)
387 | n = dist.size(1)
388 |
389 | dist_energies = dist[1:, :]
390 | min_idx = energies[:, :, 0].argmin(dim=0)[None, :]
391 | dist_min_energy = torch.gather(dist_energies, 0, min_idx)
392 | min_dist_list.append(dist_min_energy.detach())
393 |
394 | dist = dist.mean(dim=-1)
395 |
396 | energies = energies.mean(dim=-1).mean(dim=-1)
397 | dist_list.append(dist.detach())
398 | energy_list.append(energies.detach())
399 |
400 | counter = counter + 1
401 |
402 | if counter > 10:
403 | dist_list = torch.stack(dist_list, dim=0)
404 | dist = dist_list.mean(dim=0)
405 | energy_list = torch.stack(energy_list, dim=0)
406 | energies = energy_list.mean(dim=0)
407 | min_dist = torch.stack(min_dist_list, dim=0).mean()
408 |
409 | print("Testing..................")
410 | print("step errors: ", dist[:20])
411 | print("energy values: ", energies)
412 | print('test at step %d done!' % step)
413 | break
414 |
415 | if FLAGS.decoder or FLAGS.ponder:
416 | best_test_error_10 = min(best_test_error_10, dist[0].item())
417 | best_test_error_20 = min(best_test_error_20, dist[0].item())
418 | best_test_error_40 = min(best_test_error_40, dist[0].item())
419 | best_test_error_80 = min(best_test_error_80, dist[0].item())
420 | best_test_error = min(best_test_error, dist[0].item())
421 | else:
422 | best_test_error_10 = min(best_test_error_10, dist[9].item())
423 | best_test_error_20 = min(best_test_error_20, dist[19].item())
424 | best_test_error_40 = min(best_test_error_40, dist[39].item())
425 | best_test_error_80 = min(best_test_error_80, dist[79].item())
426 | best_test_error = min(best_test_error, min_dist.item())
427 |
428 | print("best test error (10, 20, 40, 80, min_energy): {} {} {} {} {}".format(
429 | best_test_error_10, best_test_error_20, best_test_error_40,
430 | best_test_error_80, best_test_error))
431 |
432 | model.train()
433 |
434 |
435 | def train(train_dataloader, test_dataloader, logger, model,
436 | optimizer, FLAGS, logdir, rank_idx):
437 |
438 | it = FLAGS.resume_iter
439 | optimizer.zero_grad()
440 | dev = torch.device("cuda")
441 |
442 |     # Initialize a replay buffer of solutions
443 | replay_buffer = ReplayBuffer(10000)
444 |
445 | for epoch in range(FLAGS.num_epoch):
446 | for inp, im in train_dataloader:
447 | im = im.float().to(dev)
448 | inp = inp.float().to(dev)
449 |
450 |             # Initialize a solution at random
451 | pred = (torch.rand_like(im) - 0.5) * 2
452 |
453 | # Sample a proportion of samples from past optimization results
454 | if FLAGS.replay_buffer and len(replay_buffer) >= FLAGS.batch_size:
455 | replay_batch, _ = replay_buffer.sample(im.size(0))
456 | inp_replay, opt_replay, gt_replay, scratch_replay = replay_batch
457 |
458 |                 replay_mask = np.concatenate([np.ones(im.size(0)), np.zeros(im.size(0))]).astype(bool)
459 | inp = torch.cat([torch.Tensor(inp_replay).cuda(), inp], dim=0)
460 | pred = torch.cat([torch.Tensor(opt_replay).cuda(), pred], dim=0)
461 | im = torch.cat([torch.Tensor(gt_replay).cuda(), im], dim=0)
462 | else:
463 | replay_mask = (
464 | np.random.uniform(
465 | 0,
466 | 1,
467 | im.size(0)) > 1.0)
468 |
469 | scratch = torch.zeros_like(inp)
470 |
471 | num_steps = FLAGS.num_steps
472 | pred, preds, im_grads, energies, scratch, logits = gen_answer(
473 | inp, FLAGS, model, pred, scratch, num_steps)
474 | energies = torch.stack(energies, dim=0)
475 | preds = torch.stack(preds, dim=1)
476 |
477 | im_grads = torch.stack(im_grads, dim=1)
478 |
479 | if FLAGS.ponder:
480 | geometric_dist = calc_geometric(torch.full(
481 | (FLAGS.num_steps,), 1 / FLAGS.num_steps, device=dev))
482 | halting_probs = calc_geometric(logits.sigmoid(), dim=1)[..., 0]
483 |
484 | if FLAGS.decoder:
485 | im_loss = torch.pow(
486 | preds[:, -1:] - im[:, None, :], 2).mean(dim=-1).mean(dim=-1)
487 | elif FLAGS.ponder:
488 | halting_probs = halting_probs / \
489 | halting_probs.sum(dim=1)[:, None]
490 | im_loss = (torch.pow(
491 | preds[:, :] - im[:, None, :], 2)).mean(dim=-1).mean(dim=-1)
492 | else:
493 | im_loss = torch.pow(
494 | preds[:, -1:] - im[:, None, :], 2).mean(dim=-1).mean(dim=-1)
495 |
496 | loss = im_loss.mean()
497 |
498 | if FLAGS.ponder:
499 |                 ponder_loss = 0.01 * F.kl_div(
500 |                     torch.log(geometric_dist[None, :] + 1e-10), halting_probs, reduction='batchmean')
501 | loss = loss + ponder_loss
502 |
503 | loss.backward()
504 |
505 | if FLAGS.replay_buffer:
506 | inp_replay = inp.cpu().detach().numpy()
507 | pred_replay = pred.cpu().detach().numpy()
508 | im_replay = im.cpu().detach().numpy()
509 | scratch = scratch.cpu().detach().numpy()
510 | encode_tuple = list(zip(list(inp_replay), list(
511 | pred_replay), list(im_replay), list(scratch)))
512 |
513 | replay_buffer.add(encode_tuple)
514 |
515 | if FLAGS.gpus > 1:
516 | average_gradients(model)
517 |
518 | optimizer.step()
519 | optimizer.zero_grad()
520 |
521 | if it > 10000:
522 | assert False
523 |
524 | if it % FLAGS.log_interval == 0 and rank_idx == 0:
525 | loss = loss.item()
526 | kvs = {}
527 | kvs['im_loss'] = im_loss.mean().item()
528 |
529 | if it > 10:
530 |                     # Split the batch into replayed and freshly initialized samples
531 |                     no_replay_mask = ~replay_mask
532 | kvs['no_replay_loss'] = im_loss[no_replay_mask].mean().item()
533 | kvs['replay_loss'] = im_loss[replay_mask].mean().item()
534 |
535 |                 if FLAGS.ponder:
536 |                     kvs['ponder_loss'] = ponder_loss.item()
537 |
538 | if (not FLAGS.iterative_decoder) and (not FLAGS.decoder) and (
539 | not FLAGS.recurrent) and (not FLAGS.ponder):
540 | kvs['energy_no_replay'] = energies[-1,
541 | no_replay_mask].mean().item()
542 | kvs['energy_replay'] = energies[-1,
543 | replay_mask].mean().item()
544 |
545 | kvs['energy_start_no_replay'] = energies[0,
546 | no_replay_mask].mean().item()
547 | kvs['energy_start_replay'] = energies[0,
548 | replay_mask].mean().item()
549 |
550 | mean_last_dist = torch.abs(pred - im).mean()
551 | kvs['mean_last_dist'] = mean_last_dist.item()
552 |
553 | string = "Iteration {} ".format(it)
554 |
555 | for k, v in kvs.items():
556 | string += "%s: %.6f " % (k, v)
557 | logger.add_scalar(k, v, it)
558 |
559 | print(string)
560 |
561 | if it % FLAGS.save_interval == 0 and rank_idx == 0:
562 |                 model_path = osp.join(logdir, "model_latest.pth")
563 | ckpt = {'FLAGS': FLAGS}
564 |
565 | ckpt['model_state_dict'] = model.state_dict()
566 | ckpt['optimizer_state_dict'] = optimizer.state_dict()
567 |
568 | torch.save(ckpt, model_path)
569 |
570 | test(test_dataloader, model, FLAGS, step=it)
571 |
572 | it += 1
573 |
574 |
575 | def main_single(rank, FLAGS):
576 | rank_idx = rank
577 | world_size = FLAGS.gpus
578 | logdir = osp.join(FLAGS.logdir, FLAGS.exp)
579 |
580 | if not os.path.exists('result/%s' % FLAGS.exp):
581 | try:
582 | os.makedirs('result/%s' % FLAGS.exp)
583 | except BaseException:
584 | pass
585 |
586 | if not os.path.exists(logdir):
587 | try:
588 |             os.makedirs(logdir)
589 | except BaseException:
590 | pass
591 |
592 | # Load Dataset
593 | if FLAGS.dataset == 'lowrank':
594 | dataset = LowRankDataset('train', FLAGS.rank, FLAGS.ood)
595 | test_dataset = LowRankDataset('test', FLAGS.rank, FLAGS.ood)
596 | elif FLAGS.dataset == 'shortestpath':
597 | dataset = ShortestPath('train', FLAGS.rank, FLAGS.num_steps)
598 | test_dataset = ShortestPath('test', FLAGS.rank, FLAGS.num_steps)
599 | elif FLAGS.dataset == 'negate':
600 | dataset = Negate('train', FLAGS.rank)
601 | test_dataset = Negate('test', FLAGS.rank)
602 | elif FLAGS.dataset == 'addition':
603 | dataset = Addition('train', FLAGS.rank, FLAGS.ood)
604 | test_dataset = Addition('test', FLAGS.rank, FLAGS.ood)
605 | elif FLAGS.dataset == 'inverse':
606 | dataset = Inverse('train', FLAGS.rank, FLAGS.ood)
607 | test_dataset = Inverse('test', FLAGS.rank, FLAGS.ood)
608 | elif FLAGS.dataset == 'square':
609 | dataset = Square('train', FLAGS.rank, FLAGS.num_steps)
610 | test_dataset = Square('test', FLAGS.rank, FLAGS.num_steps)
611 | elif FLAGS.dataset == 'identity':
612 | dataset = Identity('train', FLAGS.rank, FLAGS.num_steps)
613 | test_dataset = Identity('test', FLAGS.rank, FLAGS.num_steps)
614 | elif FLAGS.dataset == 'det':
615 | dataset = Det('train', FLAGS.rank)
616 | test_dataset = Det('test', FLAGS.rank)
617 | elif FLAGS.dataset == 'lu':
618 | dataset = LU('train', FLAGS.rank)
619 | test_dataset = LU('test', FLAGS.rank)
620 | elif FLAGS.dataset == 'sort':
621 | dataset = Sort('train', FLAGS.rank, FLAGS.num_steps)
622 | test_dataset = Sort('test', FLAGS.rank, FLAGS.num_steps)
623 | elif FLAGS.dataset == 'eigen':
624 | dataset = Eigen('train', FLAGS.rank, FLAGS.num_steps)
625 | test_dataset = Eigen('test', FLAGS.rank, FLAGS.num_steps)
626 | elif FLAGS.dataset == 'equation':
627 | dataset = Equation('train', FLAGS.rank, FLAGS.num_steps)
628 | test_dataset = Equation('test', FLAGS.rank, FLAGS.num_steps)
629 | elif FLAGS.dataset == 'qr':
630 | dataset = QR('train', FLAGS.rank, FLAGS.num_steps)
631 | test_dataset = QR('test', FLAGS.rank, FLAGS.num_steps)
632 | elif FLAGS.dataset == 'parity':
633 | dataset = Parity('train', FLAGS.rank, FLAGS.num_steps)
634 | test_dataset = Parity('test', FLAGS.rank, FLAGS.num_steps)
635 |
636 | if not FLAGS.infinite:
637 | dataset = FiniteWrapper(
638 | dataset,
639 | FLAGS.dataset,
640 | FLAGS.capacity,
641 | FLAGS.rank,
642 | FLAGS.num_steps)
643 |
644 | shuffle = True
645 | sampler = None
646 |
647 | if world_size > 1:
648 | group = dist.init_process_group(
649 | backend='nccl',
650 | init_method='tcp://localhost:8113',
651 | world_size=world_size,
652 | rank=rank_idx,
653 | group_name="default")
654 |
655 | torch.cuda.set_device(rank)
656 | device = torch.device('cuda')
657 |
658 | FLAGS_OLD = FLAGS
659 |
660 | # Load model and key arguments
661 | if FLAGS.resume_iter != 0:
662 | model_path = osp.join(
663 | logdir, "model_latest.pth".format(
664 | FLAGS.resume_iter))
665 |
666 | checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
667 | FLAGS = checkpoint['FLAGS']
668 |
669 | FLAGS.resume_iter = FLAGS_OLD.resume_iter
670 | FLAGS.save_interval = FLAGS_OLD.save_interval
671 | FLAGS.gpus = FLAGS_OLD.gpus
672 | FLAGS.train = FLAGS_OLD.train
673 | FLAGS.batch_size = FLAGS_OLD.batch_size
674 | FLAGS.step_lr = FLAGS_OLD.step_lr
675 | FLAGS.num_steps = FLAGS_OLD.num_steps
676 | FLAGS.exp = FLAGS_OLD.exp
677 | FLAGS.ponder = FLAGS_OLD.ponder
678 | FLAGS.heatmap = FLAGS_OLD.heatmap
679 |
680 | model, optimizer = init_model(FLAGS, device, dataset)
681 | state_dict = model.state_dict()
682 |
683 | model.load_state_dict(checkpoint['model_state_dict'], strict=False)
684 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
685 | else:
686 | model, optimizer = init_model(FLAGS, device, dataset)
687 |
688 | if FLAGS.gpus > 1:
689 | sync_model(model)
690 |
691 | print("num_parameters: ", sum([p.numel() for p in model.parameters()]))
692 |
693 | train_dataloader = DataLoader(
694 | dataset,
695 | num_workers=FLAGS.data_workers,
696 | batch_size=FLAGS.batch_size,
697 | shuffle=shuffle,
698 | pin_memory=False,
699 | worker_init_fn=worker_init_fn)
700 | test_dataloader = DataLoader(
701 | test_dataset,
702 | num_workers=FLAGS.data_workers,
703 | batch_size=FLAGS.batch_size,
704 | shuffle=True,
705 | pin_memory=False,
706 | drop_last=True,
707 | worker_init_fn=worker_init_fn)
708 |
709 | logger = SummaryWriter(logdir)
710 | it = FLAGS.resume_iter
711 |
712 | if FLAGS.train:
713 | model.train()
714 | else:
715 | model.eval()
716 |
717 | if FLAGS.train:
718 | train(
719 | train_dataloader,
720 | test_dataloader,
721 | logger,
722 | model,
723 | optimizer,
724 | FLAGS,
725 | logdir,
726 | rank_idx)
727 | else:
728 | test(test_dataloader, model, FLAGS, step=FLAGS.resume_iter)
729 |
730 |
731 | def main():
732 | FLAGS = parser.parse_args()
733 |     logdir = osp.join(FLAGS.logdir, FLAGS.exp)
734 |
735 |     # The recurrent and decoder baselines never use the replay buffer.
736 |     if FLAGS.recurrent or FLAGS.decoder:
737 |         FLAGS.no_replay_buffer = True
738 |
739 |     # Derive replay_buffer only after the overrides above have been applied.
740 |     FLAGS.replay_buffer = not FLAGS.no_replay_buffer
741 |
742 | if not osp.exists(logdir):
743 | os.makedirs(logdir)
744 |
745 | if FLAGS.gpus > 1:
746 | mp.spawn(main_single, nprocs=FLAGS.gpus, args=(FLAGS,))
747 | else:
748 | main_single(0, FLAGS)
749 |
750 |
751 | if __name__ == "__main__":
752 | try:
753 | torch.multiprocessing.set_start_method('spawn')
754 | except BaseException:
755 | pass
756 |
757 | main()
758 |
--------------------------------------------------------------------------------
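Note on `calc_geometric` in train.py: the helper turns per-step halting probabilities into a distribution over halting steps, where the probability of halting at step t is the probability of not halting at any earlier step times the halting probability at t. The sketch below reproduces that arithmetic in plain Python so it can be checked by hand. It is an illustration only: `halting_distribution` and `sample_halting_step` are hypothetical names that do not appear in the repository, and the sketch assumes that `pad_to(cum_prod, (1, -1), value=1.)` shifts the cumulative product right by one position and fills the first slot with 1.

```python
import bisect
import math
import random


def halting_distribution(lam, eps=1e-10):
    """Map per-step halting probabilities to P(halt exactly at step t).

    Hypothetical helper mirroring calc_geometric for a single sequence.
    """
    probs = []
    log_survive = 0.0  # log of prod_{k<t} (1 - lam_k); the empty product is 1
    for l in lam:
        probs.append(math.exp(log_survive) * l)
        # Accumulate in log space with clipping, as safe_cumprod does.
        log_survive += math.log(min(max(1.0 - l, eps), 1.0))
    return probs


def sample_halting_step(probs, u=None):
    """Sample a halting step by inverting the cumulative distribution."""
    cum, total = [], 0.0
    for p in probs:
        total += p
        cum.append(total)
    if u is None:
        u = random.random()
    # Clamp to the last step, since the probabilities may sum to less than 1.
    return min(bisect.bisect_left(cum, u), len(probs) - 1)


p = halting_distribution([0.5, 0.5, 0.5, 0.5])
print(p)  # approximately [0.5, 0.25, 0.125, 0.0625]
print(sample_halting_step(p, u=0.6))  # falls in the second bin, index 1
```

With a finite number of steps the probabilities sum to less than 1 (0.9375 in this example), which is consistent with `train` renormalizing `halting_probs` and with `test` clamping the sampled index to the last step.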