├── README.md
├── agent.py
├── config.yaml
├── data.py
├── main.py
├── module.py
├── requirements.txt
└── trainer.py
/README.md:
--------------------------------------------------------------------------------
1 | # encode-attend-navigate-pytorch
2 |
3 | PyTorch implementation of encode-attend-navigate, a Deep Reinforcement Learning-based TSP solver.
4 |
5 |
6 | ## Get started
7 |
8 | ### Run on Colab
9 |
10 | You can leverage the free GPU on Colab to train this model. Just run this notebook :
11 | [Open in Colab](https://colab.research.google.com/drive/1Tggr-QIQSyt7jnjZRuBp5wBt6eDoC1-c?usp=sharing)
12 |
13 | ### Run locally
14 |
15 | Clone the repository :
16 |
17 | ```console
18 | git clone https://github.com/astariul/encode-attend-navigate-pytorch.git
19 | cd encode-attend-navigate-pytorch
20 | ```
21 |
22 | ---
23 |
24 | Install dependencies :
25 |
26 | ```console
27 | pip install -r requirements.txt
28 | ```
29 |
30 | ---
31 |
32 | Run the code :
33 |
34 | ```console
35 | python main.py
36 | ```
37 |
38 | ---
39 |
40 | You can specify your own configuration file :
41 |
42 | ```console
43 | python main.py config=my_conf.yaml
44 | ```
45 |
46 | ---
47 |
48 | Or directly modify parameters from the command line :
49 |
50 | ```console
51 | python main.py lr=0.002 max_len=100 batch_size=64
52 | ```
53 |
54 | ### Expected results
55 |
56 | I ran the code with the following command line :
57 |
58 | ```console
59 | python main.py enc_stacks=1 lr=0.0002 p_dropout=0.1
60 | ```
61 |
62 | On Colab, with a `Tesla T4` GPU, it took 1h 46m for the training to complete.
63 |
64 | Here are the training curves :
65 |
66 |
67 |
68 |
69 |
70 | ---
71 |
72 | After training, here are a few examples of generated paths :
73 |
74 |
75 |
76 |
77 |
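To reproduce this kind of plot, here is a minimal inference sketch. It simply mirrors the test loop of `main.py` and assumes the agent was trained with the default `config.yaml` (checkpoint saved as `model.pt`) :

```python
import torch

from agent import Agent
from data import DataGenerator
from trainer import reward_fn

# Rebuild the agent with the hyper-parameters used for training (default config.yaml values)
agent = Agent(embed_hidden=128, enc_stacks=3, ff_hidden=512, enc_heads=16,
              query_hidden=360, att_hidden=256, crit_hidden=256, n_history=3, p_dropout=0.1)
agent.load_state_dict(torch.load("model.pt", map_location="cpu"))
agent.eval()

# Sample a batch of identical 50-city instances and let the (stochastic) decoder propose tours
dataset = DataGenerator()
batch = torch.Tensor(dataset.test_batch(batch_size=64, max_length=50, dimension=2, seed=42))
with torch.no_grad():
    tour, *_ = agent(batch)

# Keep the shortest sampled tour, refine it with 2-opt and plot it
lengths = reward_fn(batch, tour)
j = lengths.argmin()
best_tour = tour[j][:-1].tolist()
opt_tour, opt_length = dataset.loop2opt(batch[j][best_tour].numpy())
print("Tour length (after 2-opt):", opt_length)
dataset.visualize_2D_trip(opt_tour)
```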
78 | ## Implementation
79 |
80 | This code is a direct translation of the [official TF 1.x implementation](https://framagit.org/MichelDeudon/encode-attend-navigate), by @MichelDeudon.
81 |
82 | Please refer to their README for additional details.
83 |
84 | ---
85 |
86 | To ensure the PyTorch implementation produces the same results as the original implementation, I compared the outputs of each layer given the same inputs and checked that they are the same.
87 |
88 | You can find (and run) these tests on this Colab notebook : [Open in Colab](https://colab.research.google.com/drive/1HChapUUC_3cZoZsG1A3WJLwclQRsyuR2?usp=sharing)
89 |
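In practice, the comparison boils down to feeding both implementations the same fixed input and checking the outputs numerically, layer by layer. Here is a minimal sketch of that idea for a single layer (the `.npy` file names are hypothetical placeholders for arrays dumped from the TF implementation, and the layer weights are assumed to have been copied over beforehand) :

```python
import numpy as np
import torch

from module import Embedding

# Hypothetical dumps: the shared input batch and the output of the original TF Embedding layer
inputs = torch.from_numpy(np.load("tf_input_batch.npy")).float()  # [batch, seq, 2]
expected = np.load("tf_embedding_output.npy")                     # [batch, seq, 128]

embedding = Embedding(in_dim=2, out_dim=128)
# (the TF weights and BatchNorm statistics would have to be copied into `embedding` here)
embedding.eval()
with torch.no_grad():
    actual = embedding(inputs).numpy()

print("Embedding layer matches TF output:", np.allclose(actual, expected, atol=1e-5))
```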
--------------------------------------------------------------------------------
/agent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from module import Embedding, Encoder, Decoder, Critic
5 |
6 |
7 | class Agent(nn.Module):
8 | def __init__(self, space_dim=2, embed_hidden=128, enc_stacks=3, ff_hidden=512, enc_heads=16, query_hidden=360, att_hidden=256, crit_hidden=256, n_history=3, p_dropout=0.1):
9 | """Agent, made of an encoder + decoder for the actor part, and a critic
10 | part.
11 |
12 | Args:
13 | space_dim (int, optional): Dimension for the cities coordinates.
14 | Defaults to 2.
15 | embed_hidden (int, optional): Embeddings hidden size. Defaults to 128.
16 | enc_stacks (int, optional): Number of encoder layers. Defaults to 3.
17 | ff_hidden (int, optional): Hidden size for the FF part of the encoder.
18 | Defaults to 512.
19 | enc_heads (int, optional): Number of attention heads for the encoder.
20 | Defaults to 16.
21 | query_hidden (int, optional): Query hidden size. Defaults to 360.
22 | att_hidden (int, optional): Attention hidden size. Defaults to 256.
23 | crit_hidden (int, optional): Critic hidden size. Defaults to 256.
24 | n_history (int, optional): Size of history (memory size of the
25 | decoder). Defaults to 3.
26 | p_dropout (float, optional): Dropout rate. Defaults to 0.1.
27 | """
28 | super().__init__()
29 |
30 | # Actor
31 | self.embedding = Embedding(in_dim=space_dim, out_dim=embed_hidden)
32 | self.encoder = Encoder(num_layers=enc_stacks, n_hidden=embed_hidden, ff_hidden=ff_hidden, num_heads=enc_heads, p_dropout=p_dropout)
33 | self.decoder = Decoder(n_hidden=embed_hidden, att_dim=att_hidden, query_dim=query_hidden, n_history=n_history)
34 |
35 | # Critic
36 | self.critic = Critic(n_hidden=embed_hidden, att_hidden=att_hidden, crit_hidden=crit_hidden)
37 |
38 | def forward(self, inputs, c=10, temp=1):
39 | embed_inp = self.embedding(inputs)
40 | encoder_hidden = self.encoder(embed_inp)
41 |
42 |         tour, log_probs, entropies = self.decoder(encoder_hidden, c=c, temp=temp)
43 |
44 | critique = self.critic(encoder_hidden)
45 |
46 | return tour, critique, log_probs, entropies
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | # Network architecture
2 | embed_hidden: 128
3 | enc_stacks: 3
4 | ff_hidden: 512
5 | enc_heads: 16
6 | query_hidden: 360
7 | att_hidden: 256
8 | crit_hidden: 256
9 | n_history: 3
10 | p_dropout: 0.1
11 |
12 | # Logging
13 | proj_name: "encode_attend_navigate"
14 | log_interval: 200
15 |
16 | # Training
17 | steps: 20000
18 | batch_size: 256
19 | max_len: 50
20 | dimension: 2
21 | model_path: "model.pt"
22 | test: True
23 | test_steps: 10
24 | device: "cuda:0"
25 |
26 | # Optimizer
27 | lr: 0.001
28 | lr_decay_rate: 0.96
29 | lr_decay_steps: 5000
30 | grad_clip: 1
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | ### Code entirely copied from the original repo
2 | # https://github.com/MichelDeudon/encode-attend-navigate/blob/master/code/data_generator.py
3 |
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import math
7 | from sklearn.decomposition import PCA
8 |
9 |
10 | # Compute a sequence's reward
11 | def reward(tsp_sequence):
12 | tour = np.concatenate(
13 | (tsp_sequence, np.expand_dims(tsp_sequence[0], 0))
14 | ) # sequence to tour (end=start)
15 | inter_city_distances = np.sqrt(
16 | np.sum(np.square(tour[:-1, :2] - tour[1:, :2]), axis=1)
17 | ) # tour length
18 | return np.sum(inter_city_distances) # reward
19 |
20 |
21 | # Swap city[i] with city[j] in sequence
22 | def swap2opt(tsp_sequence, i, j):
23 | new_tsp_sequence = np.copy(tsp_sequence)
24 | new_tsp_sequence[i : j + 1] = np.flip(
25 | tsp_sequence[i : j + 1], axis=0
26 | ) # flip or swap ?
27 | return new_tsp_sequence
28 |
29 |
30 | # One step of 2opt = one double loop and return first improved sequence
31 | def step2opt(tsp_sequence):
32 | seq_length = tsp_sequence.shape[0]
33 | distance = reward(tsp_sequence)
34 | for i in range(1, seq_length - 1):
35 | for j in range(i + 1, seq_length):
36 | new_tsp_sequence = swap2opt(tsp_sequence, i, j)
37 | new_distance = reward(new_tsp_sequence)
38 | if new_distance < distance:
39 | return new_tsp_sequence, new_distance
40 | return tsp_sequence, distance
41 |
42 |
43 | class DataGenerator(object):
44 | def __init__(self):
45 | pass
46 |
47 | def gen_instance(
48 | self, max_length, dimension, seed=0
49 | ): # Generate random TSP instance
50 | if seed != 0:
51 | np.random.seed(seed)
52 | sequence = np.random.rand(
53 | max_length, dimension
54 | ) # (max_length) cities with (dimension) coordinates in [0,1]
55 | pca = PCA(n_components=dimension) # center & rotate coordinates
56 | sequence = pca.fit_transform(sequence)
57 | return sequence
58 |
59 | def train_batch(
60 | self, batch_size, max_length, dimension
61 | ): # Generate random batch for training procedure
62 | input_batch = []
63 | for _ in range(batch_size):
64 | input_ = self.gen_instance(
65 | max_length, dimension
66 | ) # Generate random TSP instance
67 | input_batch.append(input_) # Store batch
68 | return input_batch
69 |
70 | def test_batch(
71 | self, batch_size, max_length, dimension, seed=0, shuffle=False
72 | ): # Generate random batch for testing procedure
73 | input_batch = []
74 | input_ = self.gen_instance(
75 | max_length, dimension, seed=seed
76 | ) # Generate random TSP instance
77 | for _ in range(batch_size):
78 | sequence = np.copy(input_)
79 | if shuffle == True:
80 | np.random.shuffle(sequence) # Shuffle sequence
81 | input_batch.append(sequence) # Store batch
82 | return input_batch
83 |
84 | def loop2opt(
85 | self, tsp_sequence, max_iter=2000
86 | ): # Iterate step2opt max_iter times (2-opt local search)
87 | best_reward = reward(tsp_sequence)
88 | new_tsp_sequence = np.copy(tsp_sequence)
89 | for _ in range(max_iter):
90 | new_tsp_sequence, new_reward = step2opt(new_tsp_sequence)
91 | if new_reward < best_reward:
92 | best_reward = new_reward
93 | else:
94 | break
95 | return new_tsp_sequence, best_reward
96 |
97 | def visualize_2D_trip(self, trip): # Plot tour
98 | plt.figure(1)
99 | colors = ["red"] # First city red
100 | for i in range(len(trip) - 1):
101 | colors.append("blue")
102 |
103 | plt.scatter(trip[:, 0], trip[:, 1], color=colors) # Plot cities
104 | tour = np.array(list(range(len(trip))) + [0]) # Plot tour
105 | X = trip[tour, 0]
106 | Y = trip[tour, 1]
107 | plt.plot(X, Y, "--")
108 |
109 | plt.xlim(-0.75, 0.75)
110 | plt.ylim(-0.75, 0.75)
111 | plt.xlabel("X")
112 | plt.ylabel("Y")
113 | plt.show()
114 |
115 | def visualize_sampling(
116 | self, permutations
117 | ): # Heatmap of permutations (x=cities; y=steps)
118 | max_length = len(permutations[0])
119 | grid = np.zeros([max_length, max_length]) # initialize heatmap grid to 0
120 |
121 | transposed_permutations = np.transpose(permutations)
122 | for t, cities_t in enumerate(
123 | transposed_permutations
124 | ): # step t, cities chosen at step t
125 | city_indices, counts = np.unique(cities_t, return_counts=True, axis=0)
126 | for u, v in zip(city_indices, counts):
127 | grid[t][
128 | u
129 | ] += v # update grid with counts from the batch of permutations
130 |
131 | fig = plt.figure(1) # plot heatmap
132 | ax = fig.add_subplot(1, 1, 1)
133 | ax.set_aspect("equal")
134 | plt.imshow(grid, interpolation="nearest", cmap="gray")
135 | plt.colorbar()
136 | plt.title("Sampled permutations")
137 | plt.ylabel("Time t")
138 | plt.xlabel("City i")
139 | plt.show()
140 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import wandb
4 | from omegaconf import OmegaConf as omg
5 | import torch
6 |
7 | from agent import Agent
8 | from trainer import Trainer, reward_fn
9 | from data import DataGenerator
10 |
11 |
12 | def load_conf():
13 | """Quick method to load configuration (using OmegaConf). By default,
14 | configuration is loaded from the default config file (config.yaml).
15 |     Another config file can be specified through the command line.
16 |     Also, the configuration can be overridden from the command line.
17 |
18 | Returns:
19 | OmegaConf.DictConfig: OmegaConf object representing the configuration.
20 | """
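    # Example (hypothetical values):
    #   python main.py config=my_conf.yaml lr=0.002
    #   -> my_conf.yaml is loaded instead of the default config.yaml, and the
    #      CLI value lr=0.002 overrides the value defined in the file.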
21 | default_conf = omg.create({"config" : "config.yaml"})
22 |
23 | sys.argv = [a.strip("-") for a in sys.argv]
24 | cli_conf = omg.from_cli()
25 |
26 | yaml_file = omg.merge(default_conf, cli_conf).config
27 |
28 | yaml_conf = omg.load(yaml_file)
29 |
30 | return omg.merge(default_conf, yaml_conf, cli_conf)
31 |
32 |
33 | def main():
34 | conf = load_conf()
35 | wandb.init(project=conf.proj_name, config=dict(conf))
36 |
37 | agent = Agent(embed_hidden=conf.embed_hidden, enc_stacks=conf.enc_stacks, ff_hidden=conf.ff_hidden, enc_heads=conf.enc_heads, query_hidden=conf.query_hidden, att_hidden=conf.att_hidden, crit_hidden=conf.crit_hidden, n_history=conf.n_history, p_dropout=conf.p_dropout)
38 | wandb.watch(agent)
39 |
40 | dataset = DataGenerator()
41 |
42 | trainer = Trainer(conf, agent, dataset)
43 | trainer.run()
44 |
45 | # Save trained agent
46 | torch.save(agent.state_dict(), conf.model_path)
47 |
48 | if conf.test:
49 | device = torch.device(conf.device)
50 | # Load trained agent
51 | agent.load_state_dict(torch.load(conf.model_path))
52 | agent.eval()
53 | agent = agent.to(device)
54 |
55 | running_reward = 0
56 | for _ in range(conf.test_steps):
57 | input_batch = dataset.test_batch(conf.batch_size, conf.max_len, conf.dimension, shuffle=False)
58 | input_batch = torch.Tensor(input_batch).to(device)
59 |
60 | tour, *_ = agent(input_batch)
61 |
62 | reward = reward_fn(input_batch, tour)
63 |
64 | # Find best solution
65 | j = reward.argmin()
66 | best_tour = tour[j][:-1].tolist()
67 |
68 | # Log
69 | running_reward += reward[j]
70 |
71 | # Display
72 | print('Reward (before 2 opt)', reward[j])
73 | opt_tour, opt_length = dataset.loop2opt(input_batch.cpu()[0][best_tour])
74 | print('Reward (with 2 opt)', opt_length)
75 | dataset.visualize_2D_trip(opt_tour)
76 |
77 | wandb.run.summary["test_reward"] = running_reward / conf.test_steps
78 |
79 |
80 | if __name__ == "__main__":
81 | main()
82 |
--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.distributions as distrib
7 |
8 |
9 | LARGE_NUMBER = 100_000_000
10 |
11 |
12 | class Embedding(nn.Module):
13 | def __init__(self, in_dim, out_dim, batch_norm=True):
14 | """ Module for embedding the input features.
15 |
16 | Args:
17 | in_dim (int): Input dimension.
18 | out_dim (int): Output dimension.
19 |             batch_norm (bool, optional): Whether to apply batch normalization or
20 | not. Defaults to True.
21 | """
22 | super().__init__()
23 | self.dense = nn.Linear(in_dim, out_dim, bias=False)
24 | self.batch_norm = (
25 | nn.BatchNorm1d(out_dim, eps=0.001, momentum=0.01)
26 | if batch_norm
27 | else nn.Identity()
28 | ) # (Use TF default parameter, for consistency with original code)
29 |
30 | def forward(self, x):
31 | # x : [bs, seq, in_dim]
32 |         emb_x = self.dense(x)  # [bs, seq, out_dim]
33 | return self.batch_norm(emb_x.transpose(1, 2)).transpose(1, 2)
34 |
35 |
36 | class MultiHeadAttention(nn.Module):
37 | def __init__(self, n_hidden=512, num_heads=16, p_dropout=0.1):
38 | """ Module applying a single block of multi-head attention.
39 |
40 | Args:
41 | n_hidden (int, optional): Hidden size. Should be a multiple of
42 |                 `num_heads`. Defaults to 512.
43 | num_heads (int, optional): Number of heads. Defaults to 16.
44 | p_dropout (float, optional): Dropout rate. Defaults to 0.1.
45 |
46 | Raises:
47 | ValueError: Error raised if `n_hidden` is not a multiple of
48 | `num_heads`.
49 | """
50 | super().__init__()
51 | if n_hidden % num_heads != 0:
52 | raise ValueError(
53 | "`n_hidden` ({}) should be a multiple of `num_heads` ({})".format(
54 | n_hidden, num_heads
55 | )
56 | )
57 |
58 | self.q = nn.Linear(n_hidden, n_hidden)
59 | self.k = nn.Linear(n_hidden, n_hidden)
60 | self.v = nn.Linear(n_hidden, n_hidden)
61 |
62 | self.dropout = nn.Dropout(p_dropout)
63 | self.batch_norm = nn.BatchNorm1d(n_hidden, eps=0.001, momentum=0.01)
64 | # (Use TF default parameter, for consistency with original code)
65 |
66 | self.num_heads = num_heads
67 |
68 | def forward(self, inputs):
69 | # inputs : [bs, seq, n_hidden]
70 | bs, _, n_hidden = inputs.size()
71 |
72 | # Linear projections
73 | q = F.relu(self.q(inputs))
74 | k = F.relu(self.k(inputs))
75 | v = F.relu(self.v(inputs))
76 |
77 | # Split and concat
78 | q_ = torch.cat(torch.split(q, n_hidden // self.num_heads, dim=2))
79 | k_ = torch.cat(torch.split(k, n_hidden // self.num_heads, dim=2))
80 | v_ = torch.cat(torch.split(v, n_hidden // self.num_heads, dim=2))
81 |
82 | # Multiplication
83 | outputs = torch.matmul(q_, torch.transpose(k_, 2, 1))
84 |
85 | # Scale
86 | outputs = outputs / (k_.size(-1) ** 0.5)
87 |
88 | # Activation
89 | outputs = F.softmax(outputs, dim=-1)
90 |
91 | # Dropout
92 | outputs = self.dropout(outputs)
93 |
94 | # Weighted sum
95 | outputs = torch.matmul(outputs, v_)
96 |
97 | # Restore shape
98 | outputs = torch.cat(torch.split(outputs, bs), dim=2)
99 |
100 | # Residual connection
101 | outputs += inputs
102 |
103 | # Normalize
104 | outputs = self.batch_norm(outputs.transpose(1, 2)).transpose(1, 2)
105 |
106 | return outputs
107 |
108 |
109 | class FeedForward(nn.Module):
110 | def __init__(self, layers_size=[512, 2048, 512]):
111 | """ Feed Forward network.
112 |
113 | Args:
114 | layers_size (list, optional): List describing the internal sizes of
115 | the FF network. Defaults to [512, 2048, 512].
116 | """
117 | super().__init__()
118 |
119 | self.layers = nn.ModuleList([nn.Linear(in_size, out_size) for in_size, out_size in zip(layers_size[:-1], layers_size[1:])])
120 |
121 | self.batch_norm = nn.BatchNorm1d(layers_size[-1], eps=0.001, momentum=0.01)
122 | # (Use TF default parameter, for consistency with original code)
123 |
124 | def forward(self, inputs):
125 | # inputs : [bs, seq, n_hidden]
126 | outputs = inputs
127 | for i, layer in enumerate(self.layers):
128 | outputs = layer(outputs)
129 |
130 | if i < len(self.layers) - 1:
131 | outputs = F.relu(outputs)
132 |
133 | # Residual connection
134 | outputs += inputs
135 |
136 | # Normalize
137 | outputs = self.batch_norm(outputs.transpose(1, 2)).transpose(1, 2)
138 |
139 | return outputs
140 |
141 |
142 | class Encoder(nn.Module):
143 | def __init__(
144 | self, num_layers=3, n_hidden=512, ff_hidden=2048, num_heads=16, p_dropout=0.1
145 | ):
146 | """ Encoder layer
147 |
148 | Args:
149 |             num_layers (int, optional): Number of layers in the Encoder. Defaults to 3.
150 | ff_hidden (int, optional): Size for the hidden layer of FF. Defaults to 2048.
151 | n_hidden (int, optional): Hidden size. Defaults to 512.
152 | num_heads (int, optional): Number of Attention heads. Defaults to 16.
153 | p_dropout (float, optional): Dropout rate. Defaults to 0.1.
154 | """
155 | super().__init__()
156 | self.multihead_attention = nn.ModuleList([MultiHeadAttention(n_hidden, num_heads, p_dropout) for _ in range(num_layers)])
157 | self.ff = nn.ModuleList([FeedForward(layers_size=[n_hidden, ff_hidden, n_hidden]) for _ in range(num_layers)])
158 |
159 | def forward(self, input_seq):
160 | for att, ff in zip(self.multihead_attention, self.ff):
161 | input_seq = ff(att(input_seq))
162 | return input_seq
163 |
164 |
165 | class Pointer(nn.Module):
166 | def __init__(self, query_dim=256, n_hidden=512):
167 | """ Pointer network.
168 |
169 | Args:
170 | query_dim (int, optional): Dimension of the query. Defaults to 256.
171 | n_hidden (int, optional): Hidden size. Defaults to 512.
172 | """
173 | super().__init__()
174 | self.w_q = nn.Linear(query_dim, n_hidden, bias=False)
175 | self.v = nn.Parameter(torch.Tensor(n_hidden))
176 | self.reset_parameters()
177 |
178 | def reset_parameters(self):
179 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.w_q.weight)
180 | bound = 1 / math.sqrt(fan_in)
181 | nn.init.uniform_(self.v, -bound, bound) # Similar to a bias of Linear layer
182 |
183 | def forward(self, encoded_ref, query, mask, c=10, temp=1):
184 | encoded_query = self.w_q(query).unsqueeze(1)
185 | scores = torch.sum(self.v * torch.tanh(encoded_ref + encoded_query), dim=-1)
186 | scores = c * torch.tanh(scores / temp)
187 | masked_scores = torch.clip(
188 | scores - LARGE_NUMBER * mask, -LARGE_NUMBER, LARGE_NUMBER
189 | )
190 | return masked_scores
191 |
192 |
193 | class FullGlimpse(nn.Module):
194 | def __init__(self, in_dim=128, out_dim=256):
195 | """ Full Glimpse for the Critic.
196 |
197 | Args:
198 | in_dim (int): Input dimension.
199 | out_dim (int): Output dimension.
200 | """
201 | super().__init__()
202 | self.dense = nn.Linear(in_dim, out_dim, bias=False)
203 | self.v = nn.Parameter(torch.Tensor(out_dim))
204 | self.reset_parameters()
205 |
206 | def reset_parameters(self):
207 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.dense.weight)
208 | bound = 1 / math.sqrt(fan_in)
209 | nn.init.uniform_(self.v, -bound, bound) # Similar to a bias of Linear layer
210 |
211 | def forward(self, ref):
212 | # Attention
213 | encoded_ref = self.dense(ref)
214 | scores = torch.sum(self.v * torch.tanh(encoded_ref), dim=-1)
215 | attention = F.softmax(scores, dim=-1)
216 |
217 | # Glimpse : Linear combination of reference vectors (define new query vector)
218 | glimpse = ref * attention.unsqueeze(-1)
219 | glimpse = torch.sum(glimpse, dim=1)
220 | return glimpse
221 |
222 |
223 | class Decoder(nn.Module):
224 | def __init__(self, n_hidden=512, att_dim=256, query_dim=360, n_history=3):
225 | """ Decoder with a Pointer network and a memory of size `n_history`.
226 |
227 | Args:
228 | n_hidden (int, optional): Encoder hidden size. Defaults to 512.
229 | att_dim (int, optional): Attention dimension size. Defaults to 256.
230 | query_dim (int, optional): Dimension of the query. Defaults to 360.
231 | n_history (int, optional): Size of history. Defaults to 3.
232 | """
233 | super().__init__()
234 | self.dense = nn.Linear(n_hidden, att_dim, bias=False)
235 | self.n_history = n_history
236 | self.queriers = nn.ModuleList([
237 | nn.Linear(n_hidden, query_dim, bias=False) for _ in range(n_history)
238 | ])
239 | self.pointer = Pointer(query_dim, att_dim)
240 |
241 | def forward(self, inputs, c=10, temp=1):
242 | batch_size, seq_len, hidden = inputs.size()
243 |
244 | idx_list, log_probs, entropies = [], [], [] # Tours index, log_probs, entropies
245 | mask = torch.zeros([batch_size, seq_len], device=inputs.device) # Mask for actions
246 |
247 | encoded_input = self.dense(inputs)
248 |
249 | prev_actions = [torch.zeros([batch_size, hidden], device=inputs.device) for _ in range(self.n_history)]
250 |
251 | for _ in range(seq_len):
252 | query = F.relu(
253 | torch.stack(
254 | [
255 | querier(prev_action)
256 | for prev_action, querier in zip(prev_actions, self.queriers)
257 | ]
258 | ).sum(dim=0)
259 | )
260 | logits = self.pointer(encoded_input, query, mask, c=c, temp=temp)
261 |
262 | probs = distrib.Categorical(logits=logits)
263 | idx = probs.sample()
264 |
265 | idx_list.append(idx) # Tour index
266 | log_probs.append(probs.log_prob(idx))
267 | entropies.append(probs.entropy())
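            # Mark the chosen city as visited so the Pointer masks it out at later steps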
268 | mask = mask + torch.zeros([batch_size, seq_len], device=inputs.device).scatter_(1, idx.unsqueeze(1), 1)
269 |
270 | action_rep = inputs[torch.arange(batch_size), idx]
271 | prev_actions.pop(0)
272 | prev_actions.append(action_rep)
273 |
274 | idx_list.append(idx_list[0]) # Return to start
275 | tour = torch.stack(idx_list, dim=1) # Permutations
276 | log_probs = sum(log_probs) # log-probs for backprop (Reinforce)
277 | entropies = sum(entropies)
278 |
279 | return tour, log_probs, entropies
280 |
281 |
282 | class Critic(nn.Module):
283 | def __init__(self, n_hidden=128, att_hidden=256, crit_hidden=256):
284 | """ Critic module, estimating the minimum length of the tour from the
285 | encoded inputs.
286 |
287 | Args:
288 | n_hidden (int, optional): Size of the encoded input. Defaults to 128.
289 | att_hidden (int, optional): Attention hidden size. Defaults to 256.
290 | crit_hidden (int, optional): Critic hidden size. Defaults to 256.
291 | """
292 | super().__init__()
293 | self.glimpse = FullGlimpse(n_hidden, att_hidden)
294 | self.hidden = nn.Linear(n_hidden, crit_hidden)
295 | self.output = nn.Linear(crit_hidden, 1)
296 |
297 | def forward(self, inputs):
298 | frame = self.glimpse(inputs)
299 | hidden_out = F.relu(self.hidden(frame))
300 |         preds = self.output(hidden_out).squeeze(-1)
301 | return preds
302 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | wandb==0.10.20
2 | torch==1.8
3 | omegaconf==2.0.6
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def reward_fn(coords, tour):
8 | """Reward function. Compute the total distance for a tour, given the
9 | coordinates of each city and the tour indexes.
10 |
11 | Args:
12 | coords (torch.Tensor): Tensor of size [batch_size, seq_len, dim],
13 | representing each city's coordinates.
14 | tour (torch.Tensor): Tensor of size [batch_size, seq_len + 1],
15 | representing the tour's indexes (comes back to the first city).
16 |
17 | Returns:
18 |         torch.Tensor: Tensor of size [batch_size], the total length of each tour.
19 | """
20 | dim = coords.size(-1)
21 |
22 | ordered_coords = torch.gather(coords, 1, tour.long().unsqueeze(-1).repeat(1, 1, dim))
23 | ordered_coords = ordered_coords.transpose(0, 2) # [dim, seq_len, batch_size]
24 |
25 | # For each dimension (x, y), compute the squared difference between each city
26 | delta2 = [torch.square(d[1:] - d[:-1]).transpose(0, 1) for d in ordered_coords]
27 |
28 |     # Euclidean distance between each city
29 | inter_city_distances = torch.sqrt(sum(delta2))
30 | distance = inter_city_distances.sum(dim=-1)
31 |
32 | return distance.float()
33 |
34 |
35 | class Trainer():
36 | def __init__(self, conf, agent, dataset):
37 | """Trainer class, taking care of training the agent.
38 |
39 | Args:
40 |         conf (OmegaConf.DictConfig): Configuration.
41 | agent (torch.nn.Module): Agent network to train.
42 | dataset (data.DataGenerator): Data generator.
43 | """
44 | super().__init__()
45 |
46 | self.conf = conf
47 | self.agent = agent
48 | self.dataset = dataset
49 |
50 | self.device = torch.device(self.conf.device)
51 | self.agent = self.agent.to(self.device)
52 |
53 | self.optim = torch.optim.Adam(params=self.agent.parameters(), lr=self.conf.lr)
54 | gamma = 1 - self.conf.lr_decay_rate / self.conf.lr_decay_steps # To have same behavior as Tensorflow implementation
55 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optim, gamma=gamma)
56 |
57 |
58 | def train_step(self, data):
59 | self.optim.zero_grad()
60 |
61 | # Forward pass
62 | tour, critique, log_probs, _ = self.agent(data)
63 |
64 | # Compute reward
65 | reward = reward_fn(data, tour)
66 |
67 | # Compute losses for both actor (reinforce) and critic
68 | loss1 = ((reward - critique).detach() * log_probs).mean()
69 | loss2 = F.mse_loss(reward, critique)
70 |
71 | # Backward pass
72 | (loss1 + loss2).backward()
73 |
74 | # Clip gradients
75 | nn.utils.clip_grad_norm_(self.agent.parameters(), self.conf.grad_clip)
76 |
77 | # Optimize
78 | self.optim.step()
79 |
80 | # Update LR
81 | self.scheduler.step()
82 |
83 | return reward.mean(), [loss1, loss2]
84 |
85 | def run(self):
86 | self.agent.train()
87 | running_reward, running_losses = 0, [0, 0]
88 | for step in range(self.conf.steps):
89 | input_batch = self.dataset.train_batch(self.conf.batch_size, self.conf.max_len, self.conf.dimension)
90 | input_batch = torch.Tensor(input_batch).to(self.device)
91 |
92 | reward, losses = self.train_step(input_batch)
93 |
94 | running_reward += reward
95 | running_losses[0] += losses[0]
96 | running_losses[1] += losses[1]
97 |
98 | if step % self.conf.log_interval == 0 and step != 0:
99 | # Log stuff
100 | wandb.log({
101 | 'reward': running_reward / self.conf.log_interval,
102 | 'actor_loss': running_losses[0] / self.conf.log_interval,
103 | 'critic_loss': running_losses[1] / self.conf.log_interval,
104 | 'learning_rate': self.scheduler.get_last_lr()[0],
105 | 'step': step
106 | })
107 |
108 | # Reset running reward/loss
109 | running_reward, running_losses = 0, [0, 0]
110 |
111 |
112 |
113 |
--------------------------------------------------------------------------------