├── README.md ├── agent.py ├── config.yaml ├── data.py ├── main.py ├── module.py ├── requirements.txt └── trainer.py /README.md: -------------------------------------------------------------------------------- 1 |

encode-attend-navigate-pytorch

2 |

3 | Pytorch implementation of encode-attend-navigate, a Deep Reinforcement Learning based TSP solver. 4 |

5 | 6 | ## Get started 7 | 8 | ### Run on Colab 9 | 10 | You can leverage the free GPU on Colab to train this model. Just run this notebook : 11 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Tggr-QIQSyt7jnjZRuBp5wBt6eDoC1-c?usp=sharing) 12 | 13 | ### Run locally 14 | 15 | Clone the repository : 16 | 17 | ```console 18 | git clone https://github.com/astariul/encode-attend-navigate-pytorch.git 19 | cd encode-attend-navigate-pytorch 20 | ``` 21 | 22 | --- 23 | 24 | Install dependencies : 25 | 26 | ```console 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | --- 31 | 32 | Run the code : 33 | 34 | ```console 35 | python main.py 36 | ``` 37 | 38 | --- 39 | 40 | You can specify your own configuration file : 41 | 42 | ```console 43 | python main.py config=my_conf.yaml 44 | ``` 45 | 46 | --- 47 | 48 | Or directly modify parameters from the command line : 49 | 50 | ```console 51 | python main.py lr=0.002 max_len=100 batch_size=64 52 | ``` 53 | 54 | ### Expected results 55 | 56 | I ran the code with the following command line : 57 | 58 | ```console 59 | python main.py enc_stacks=1 lr=0.0002 p_dropout=0.1 60 | ``` 61 | 62 | On Colab, with a `Tesla T4` GPU, it tooks 1h 46m for the training to complete. 63 | 64 | Here is the training curves : 65 | 66 | 67 | 68 | 69 | 70 | --- 71 | 72 | After training, here is a few example of path generated : 73 | 74 | 75 | 76 | 77 | 78 | ## Implementation 79 | 80 | This code is a direct translation of the [official TF 1.x implementation](https://framagit.org/MichelDeudon/encode-attend-navigate), by @MichelDeudon. 81 | 82 | Please refer to their README for additional details. 83 | 84 | --- 85 | 86 | To ensure the Pytorch implementation produces the same results as the original implementation, I compared the outputs of each layer given the same inputs and check if they are the same. 87 | 88 | You can find (and run) these tests on this Colab notebook : [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1HChapUUC_3cZoZsG1A3WJLwclQRsyuR2?usp=sharing) 89 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from module import Embedding, Encoder, Decoder, Critic 5 | 6 | 7 | class Agent(nn.Module): 8 | def __init__(self, space_dim=2, embed_hidden=128, enc_stacks=3, ff_hidden=512, enc_heads=16, query_hidden=360, att_hidden=256, crit_hidden=256, n_history=3, p_dropout=0.1): 9 | """Agent, made of an encoder + decoder for the actor part, and a critic 10 | part. 11 | 12 | Args: 13 | space_dim (int, optional): Dimension for the cities coordinates. 14 | Defaults to 2. 15 | embed_hidden (int, optional): Embeddings hidden size. Defaults to 128. 16 | enc_stacks (int, optional): Number of encoder layers. Defaults to 3. 17 | ff_hidden (int, optional): Hidden size for the FF part of the encoder. 18 | Defaults to 512. 19 | enc_heads (int, optional): Number of attention heads for the encoder. 20 | Defaults to 16. 21 | query_hidden (int, optional): Query hidden size. Defaults to 360. 22 | att_hidden (int, optional): Attention hidden size. Defaults to 256. 23 | crit_hidden (int, optional): Critic hidden size. Defaults to 256. 24 | n_history (int, optional): Size of history (memory size of the 25 | decoder). Defaults to 3. 26 | p_dropout (float, optional): Dropout rate. Defaults to 0.1. 27 | """ 28 | super().__init__() 29 | 30 | # Actor 31 | self.embedding = Embedding(in_dim=space_dim, out_dim=embed_hidden) 32 | self.encoder = Encoder(num_layers=enc_stacks, n_hidden=embed_hidden, ff_hidden=ff_hidden, num_heads=enc_heads, p_dropout=p_dropout) 33 | self.decoder = Decoder(n_hidden=embed_hidden, att_dim=att_hidden, query_dim=query_hidden, n_history=n_history) 34 | 35 | # Critic 36 | self.critic = Critic(n_hidden=embed_hidden, att_hidden=att_hidden, crit_hidden=crit_hidden) 37 | 38 | def forward(self, inputs, c=10, temp=1): 39 | embed_inp = self.embedding(inputs) 40 | encoder_hidden = self.encoder(embed_inp) 41 | 42 | tour, log_probs, entropies = self.decoder(encoder_hidden) 43 | 44 | critique = self.critic(encoder_hidden) 45 | 46 | return tour, critique, log_probs, entropies -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # Network architecture 2 | embed_hidden: 128 3 | enc_stacks: 3 4 | ff_hidden: 512 5 | enc_heads: 16 6 | query_hidden: 360 7 | att_hidden: 256 8 | crit_hidden: 256 9 | n_history: 3 10 | p_dropout: 0.1 11 | 12 | # Logging 13 | proj_name: "encode_attend_navigate" 14 | log_interval: 200 15 | 16 | # Training 17 | steps: 20000 18 | batch_size: 256 19 | max_len: 50 20 | dimension: 2 21 | model_path: "model.pt" 22 | test: True 23 | test_steps: 10 24 | device: "cuda:0" 25 | 26 | # Optimizer 27 | lr: 0.001 28 | lr_decay_rate: 0.96 29 | lr_decay_steps: 5000 30 | grad_clip: 1 -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | ### Code entirely copied from the original repo 2 | # https://github.com/MichelDeudon/encode-attend-navigate/blob/master/code/data_generator.py 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import math 7 | from sklearn.decomposition import PCA 8 | 9 | 10 | # Compute a sequence's reward 11 | def reward(tsp_sequence): 12 | tour = np.concatenate( 13 | (tsp_sequence, np.expand_dims(tsp_sequence[0], 0)) 14 | ) # sequence to tour (end=start) 15 | inter_city_distances = np.sqrt( 16 | np.sum(np.square(tour[:-1, :2] - tour[1:, :2]), axis=1) 17 | ) # tour length 18 | return np.sum(inter_city_distances) # reward 19 | 20 | 21 | # Swap city[i] with city[j] in sequence 22 | def swap2opt(tsp_sequence, i, j): 23 | new_tsp_sequence = np.copy(tsp_sequence) 24 | new_tsp_sequence[i : j + 1] = np.flip( 25 | tsp_sequence[i : j + 1], axis=0 26 | ) # flip or swap ? 27 | return new_tsp_sequence 28 | 29 | 30 | # One step of 2opt = one double loop and return first improved sequence 31 | def step2opt(tsp_sequence): 32 | seq_length = tsp_sequence.shape[0] 33 | distance = reward(tsp_sequence) 34 | for i in range(1, seq_length - 1): 35 | for j in range(i + 1, seq_length): 36 | new_tsp_sequence = swap2opt(tsp_sequence, i, j) 37 | new_distance = reward(new_tsp_sequence) 38 | if new_distance < distance: 39 | return new_tsp_sequence, new_distance 40 | return tsp_sequence, distance 41 | 42 | 43 | class DataGenerator(object): 44 | def __init__(self): 45 | pass 46 | 47 | def gen_instance( 48 | self, max_length, dimension, seed=0 49 | ): # Generate random TSP instance 50 | if seed != 0: 51 | np.random.seed(seed) 52 | sequence = np.random.rand( 53 | max_length, dimension 54 | ) # (max_length) cities with (dimension) coordinates in [0,1] 55 | pca = PCA(n_components=dimension) # center & rotate coordinates 56 | sequence = pca.fit_transform(sequence) 57 | return sequence 58 | 59 | def train_batch( 60 | self, batch_size, max_length, dimension 61 | ): # Generate random batch for training procedure 62 | input_batch = [] 63 | for _ in range(batch_size): 64 | input_ = self.gen_instance( 65 | max_length, dimension 66 | ) # Generate random TSP instance 67 | input_batch.append(input_) # Store batch 68 | return input_batch 69 | 70 | def test_batch( 71 | self, batch_size, max_length, dimension, seed=0, shuffle=False 72 | ): # Generate random batch for testing procedure 73 | input_batch = [] 74 | input_ = self.gen_instance( 75 | max_length, dimension, seed=seed 76 | ) # Generate random TSP instance 77 | for _ in range(batch_size): 78 | sequence = np.copy(input_) 79 | if shuffle == True: 80 | np.random.shuffle(sequence) # Shuffle sequence 81 | input_batch.append(sequence) # Store batch 82 | return input_batch 83 | 84 | def loop2opt( 85 | self, tsp_sequence, max_iter=2000 86 | ): # Iterate step2opt max_iter times (2-opt local search) 87 | best_reward = reward(tsp_sequence) 88 | new_tsp_sequence = np.copy(tsp_sequence) 89 | for _ in range(max_iter): 90 | new_tsp_sequence, new_reward = step2opt(new_tsp_sequence) 91 | if new_reward < best_reward: 92 | best_reward = new_reward 93 | else: 94 | break 95 | return new_tsp_sequence, best_reward 96 | 97 | def visualize_2D_trip(self, trip): # Plot tour 98 | plt.figure(1) 99 | colors = ["red"] # First city red 100 | for i in range(len(trip) - 1): 101 | colors.append("blue") 102 | 103 | plt.scatter(trip[:, 0], trip[:, 1], color=colors) # Plot cities 104 | tour = np.array(list(range(len(trip))) + [0]) # Plot tour 105 | X = trip[tour, 0] 106 | Y = trip[tour, 1] 107 | plt.plot(X, Y, "--") 108 | 109 | plt.xlim(-0.75, 0.75) 110 | plt.ylim(-0.75, 0.75) 111 | plt.xlabel("X") 112 | plt.ylabel("Y") 113 | plt.show() 114 | 115 | def visualize_sampling( 116 | self, permutations 117 | ): # Heatmap of permutations (x=cities; y=steps) 118 | max_length = len(permutations[0]) 119 | grid = np.zeros([max_length, max_length]) # initialize heatmap grid to 0 120 | 121 | transposed_permutations = np.transpose(permutations) 122 | for t, cities_t in enumerate( 123 | transposed_permutations 124 | ): # step t, cities chosen at step t 125 | city_indices, counts = np.unique(cities_t, return_counts=True, axis=0) 126 | for u, v in zip(city_indices, counts): 127 | grid[t][ 128 | u 129 | ] += v # update grid with counts from the batch of permutations 130 | 131 | fig = plt.figure(1) # plot heatmap 132 | ax = fig.add_subplot(1, 1, 1) 133 | ax.set_aspect("equal") 134 | plt.imshow(grid, interpolation="nearest", cmap="gray") 135 | plt.colorbar() 136 | plt.title("Sampled permutations") 137 | plt.ylabel("Time t") 138 | plt.xlabel("City i") 139 | plt.show() 140 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import wandb 4 | from omegaconf import OmegaConf as omg 5 | import torch 6 | 7 | from agent import Agent 8 | from trainer import Trainer, reward_fn 9 | from data import DataGenerator 10 | 11 | 12 | def load_conf(): 13 | """Quick method to load configuration (using OmegaConf). By default, 14 | configuration is loaded from the default config file (config.yaml). 15 | Another config file can be specific through command line. 16 | Also, configuration can be over-written by command line. 17 | 18 | Returns: 19 | OmegaConf.DictConfig: OmegaConf object representing the configuration. 20 | """ 21 | default_conf = omg.create({"config" : "config.yaml"}) 22 | 23 | sys.argv = [a.strip("-") for a in sys.argv] 24 | cli_conf = omg.from_cli() 25 | 26 | yaml_file = omg.merge(default_conf, cli_conf).config 27 | 28 | yaml_conf = omg.load(yaml_file) 29 | 30 | return omg.merge(default_conf, yaml_conf, cli_conf) 31 | 32 | 33 | def main(): 34 | conf = load_conf() 35 | wandb.init(project=conf.proj_name, config=dict(conf)) 36 | 37 | agent = Agent(embed_hidden=conf.embed_hidden, enc_stacks=conf.enc_stacks, ff_hidden=conf.ff_hidden, enc_heads=conf.enc_heads, query_hidden=conf.query_hidden, att_hidden=conf.att_hidden, crit_hidden=conf.crit_hidden, n_history=conf.n_history, p_dropout=conf.p_dropout) 38 | wandb.watch(agent) 39 | 40 | dataset = DataGenerator() 41 | 42 | trainer = Trainer(conf, agent, dataset) 43 | trainer.run() 44 | 45 | # Save trained agent 46 | torch.save(agent.state_dict(), conf.model_path) 47 | 48 | if conf.test: 49 | device = torch.device(conf.device) 50 | # Load trained agent 51 | agent.load_state_dict(torch.load(conf.model_path)) 52 | agent.eval() 53 | agent = agent.to(device) 54 | 55 | running_reward = 0 56 | for _ in range(conf.test_steps): 57 | input_batch = dataset.test_batch(conf.batch_size, conf.max_len, conf.dimension, shuffle=False) 58 | input_batch = torch.Tensor(input_batch).to(device) 59 | 60 | tour, *_ = agent(input_batch) 61 | 62 | reward = reward_fn(input_batch, tour) 63 | 64 | # Find best solution 65 | j = reward.argmin() 66 | best_tour = tour[j][:-1].tolist() 67 | 68 | # Log 69 | running_reward += reward[j] 70 | 71 | # Display 72 | print('Reward (before 2 opt)', reward[j]) 73 | opt_tour, opt_length = dataset.loop2opt(input_batch.cpu()[0][best_tour]) 74 | print('Reward (with 2 opt)', opt_length) 75 | dataset.visualize_2D_trip(opt_tour) 76 | 77 | wandb.run.summary["test_reward"] = running_reward / conf.test_steps 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /module.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.distributions as distrib 7 | 8 | 9 | LARGE_NUMBER = 100_000_000 10 | 11 | 12 | class Embedding(nn.Module): 13 | def __init__(self, in_dim, out_dim, batch_norm=True): 14 | """ Module for embedding the input features. 15 | 16 | Args: 17 | in_dim (int): Input dimension. 18 | out_dim (int): Output dimension. 19 | batch_norm (bool, optional): Wether to apply batch normalization or 20 | not. Defaults to True. 21 | """ 22 | super().__init__() 23 | self.dense = nn.Linear(in_dim, out_dim, bias=False) 24 | self.batch_norm = ( 25 | nn.BatchNorm1d(out_dim, eps=0.001, momentum=0.01) 26 | if batch_norm 27 | else nn.Identity() 28 | ) # (Use TF default parameter, for consistency with original code) 29 | 30 | def forward(self, x): 31 | # x : [bs, seq, in_dim] 32 | emb_x = self.dense(x) # [bs, out_dim, seq] 33 | return self.batch_norm(emb_x.transpose(1, 2)).transpose(1, 2) 34 | 35 | 36 | class MultiHeadAttention(nn.Module): 37 | def __init__(self, n_hidden=512, num_heads=16, p_dropout=0.1): 38 | """ Module applying a single block of multi-head attention. 39 | 40 | Args: 41 | n_hidden (int, optional): Hidden size. Should be a multiple of 42 | `num_heads`. Defaults to 126. 43 | num_heads (int, optional): Number of heads. Defaults to 16. 44 | p_dropout (float, optional): Dropout rate. Defaults to 0.1. 45 | 46 | Raises: 47 | ValueError: Error raised if `n_hidden` is not a multiple of 48 | `num_heads`. 49 | """ 50 | super().__init__() 51 | if n_hidden % num_heads != 0: 52 | raise ValueError( 53 | "`n_hidden` ({}) should be a multiple of `num_heads` ({})".format( 54 | n_hidden, num_heads 55 | ) 56 | ) 57 | 58 | self.q = nn.Linear(n_hidden, n_hidden) 59 | self.k = nn.Linear(n_hidden, n_hidden) 60 | self.v = nn.Linear(n_hidden, n_hidden) 61 | 62 | self.dropout = nn.Dropout(p_dropout) 63 | self.batch_norm = nn.BatchNorm1d(n_hidden, eps=0.001, momentum=0.01) 64 | # (Use TF default parameter, for consistency with original code) 65 | 66 | self.num_heads = num_heads 67 | 68 | def forward(self, inputs): 69 | # inputs : [bs, seq, n_hidden] 70 | bs, _, n_hidden = inputs.size() 71 | 72 | # Linear projections 73 | q = F.relu(self.q(inputs)) 74 | k = F.relu(self.k(inputs)) 75 | v = F.relu(self.v(inputs)) 76 | 77 | # Split and concat 78 | q_ = torch.cat(torch.split(q, n_hidden // self.num_heads, dim=2)) 79 | k_ = torch.cat(torch.split(k, n_hidden // self.num_heads, dim=2)) 80 | v_ = torch.cat(torch.split(v, n_hidden // self.num_heads, dim=2)) 81 | 82 | # Multiplication 83 | outputs = torch.matmul(q_, torch.transpose(k_, 2, 1)) 84 | 85 | # Scale 86 | outputs = outputs / (k_.size(-1) ** 0.5) 87 | 88 | # Activation 89 | outputs = F.softmax(outputs, dim=-1) 90 | 91 | # Dropout 92 | outputs = self.dropout(outputs) 93 | 94 | # Weighted sum 95 | outputs = torch.matmul(outputs, v_) 96 | 97 | # Restore shape 98 | outputs = torch.cat(torch.split(outputs, bs), dim=2) 99 | 100 | # Residual connection 101 | outputs += inputs 102 | 103 | # Normalize 104 | outputs = self.batch_norm(outputs.transpose(1, 2)).transpose(1, 2) 105 | 106 | return outputs 107 | 108 | 109 | class FeedForward(nn.Module): 110 | def __init__(self, layers_size=[512, 2048, 512]): 111 | """ Feed Forward network. 112 | 113 | Args: 114 | layers_size (list, optional): List describing the internal sizes of 115 | the FF network. Defaults to [512, 2048, 512]. 116 | """ 117 | super().__init__() 118 | 119 | self.layers = nn.ModuleList([nn.Linear(in_size, out_size) for in_size, out_size in zip(layers_size[:-1], layers_size[1:])]) 120 | 121 | self.batch_norm = nn.BatchNorm1d(layers_size[-1], eps=0.001, momentum=0.01) 122 | # (Use TF default parameter, for consistency with original code) 123 | 124 | def forward(self, inputs): 125 | # inputs : [bs, seq, n_hidden] 126 | outputs = inputs 127 | for i, layer in enumerate(self.layers): 128 | outputs = layer(outputs) 129 | 130 | if i < len(self.layers) - 1: 131 | outputs = F.relu(outputs) 132 | 133 | # Residual connection 134 | outputs += inputs 135 | 136 | # Normalize 137 | outputs = self.batch_norm(outputs.transpose(1, 2)).transpose(1, 2) 138 | 139 | return outputs 140 | 141 | 142 | class Encoder(nn.Module): 143 | def __init__( 144 | self, num_layers=3, n_hidden=512, ff_hidden=2048, num_heads=16, p_dropout=0.1 145 | ): 146 | """ Encoder layer 147 | 148 | Args: 149 | num_layers (int, optional): Number of layer in the Encoder. Defaults to 3. 150 | ff_hidden (int, optional): Size for the hidden layer of FF. Defaults to 2048. 151 | n_hidden (int, optional): Hidden size. Defaults to 512. 152 | num_heads (int, optional): Number of Attention heads. Defaults to 16. 153 | p_dropout (float, optional): Dropout rate. Defaults to 0.1. 154 | """ 155 | super().__init__() 156 | self.multihead_attention = nn.ModuleList([MultiHeadAttention(n_hidden, num_heads, p_dropout) for _ in range(num_layers)]) 157 | self.ff = nn.ModuleList([FeedForward(layers_size=[n_hidden, ff_hidden, n_hidden]) for _ in range(num_layers)]) 158 | 159 | def forward(self, input_seq): 160 | for att, ff in zip(self.multihead_attention, self.ff): 161 | input_seq = ff(att(input_seq)) 162 | return input_seq 163 | 164 | 165 | class Pointer(nn.Module): 166 | def __init__(self, query_dim=256, n_hidden=512): 167 | """ Pointer network. 168 | 169 | Args: 170 | query_dim (int, optional): Dimension of the query. Defaults to 256. 171 | n_hidden (int, optional): Hidden size. Defaults to 512. 172 | """ 173 | super().__init__() 174 | self.w_q = nn.Linear(query_dim, n_hidden, bias=False) 175 | self.v = nn.Parameter(torch.Tensor(n_hidden)) 176 | self.reset_parameters() 177 | 178 | def reset_parameters(self): 179 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.w_q.weight) 180 | bound = 1 / math.sqrt(fan_in) 181 | nn.init.uniform_(self.v, -bound, bound) # Similar to a bias of Linear layer 182 | 183 | def forward(self, encoded_ref, query, mask, c=10, temp=1): 184 | encoded_query = self.w_q(query).unsqueeze(1) 185 | scores = torch.sum(self.v * torch.tanh(encoded_ref + encoded_query), dim=-1) 186 | scores = c * torch.tanh(scores / temp) 187 | masked_scores = torch.clip( 188 | scores - LARGE_NUMBER * mask, -LARGE_NUMBER, LARGE_NUMBER 189 | ) 190 | return masked_scores 191 | 192 | 193 | class FullGlimpse(nn.Module): 194 | def __init__(self, in_dim=128, out_dim=256): 195 | """ Full Glimpse for the Critic. 196 | 197 | Args: 198 | in_dim (int): Input dimension. 199 | out_dim (int): Output dimension. 200 | """ 201 | super().__init__() 202 | self.dense = nn.Linear(in_dim, out_dim, bias=False) 203 | self.v = nn.Parameter(torch.Tensor(out_dim)) 204 | self.reset_parameters() 205 | 206 | def reset_parameters(self): 207 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.dense.weight) 208 | bound = 1 / math.sqrt(fan_in) 209 | nn.init.uniform_(self.v, -bound, bound) # Similar to a bias of Linear layer 210 | 211 | def forward(self, ref): 212 | # Attention 213 | encoded_ref = self.dense(ref) 214 | scores = torch.sum(self.v * torch.tanh(encoded_ref), dim=-1) 215 | attention = F.softmax(scores, dim=-1) 216 | 217 | # Glimpse : Linear combination of reference vectors (define new query vector) 218 | glimpse = ref * attention.unsqueeze(-1) 219 | glimpse = torch.sum(glimpse, dim=1) 220 | return glimpse 221 | 222 | 223 | class Decoder(nn.Module): 224 | def __init__(self, n_hidden=512, att_dim=256, query_dim=360, n_history=3): 225 | """ Decoder with a Pointer network and a memory of size `n_history`. 226 | 227 | Args: 228 | n_hidden (int, optional): Encoder hidden size. Defaults to 512. 229 | att_dim (int, optional): Attention dimension size. Defaults to 256. 230 | query_dim (int, optional): Dimension of the query. Defaults to 360. 231 | n_history (int, optional): Size of history. Defaults to 3. 232 | """ 233 | super().__init__() 234 | self.dense = nn.Linear(n_hidden, att_dim, bias=False) 235 | self.n_history = n_history 236 | self.queriers = nn.ModuleList([ 237 | nn.Linear(n_hidden, query_dim, bias=False) for _ in range(n_history) 238 | ]) 239 | self.pointer = Pointer(query_dim, att_dim) 240 | 241 | def forward(self, inputs, c=10, temp=1): 242 | batch_size, seq_len, hidden = inputs.size() 243 | 244 | idx_list, log_probs, entropies = [], [], [] # Tours index, log_probs, entropies 245 | mask = torch.zeros([batch_size, seq_len], device=inputs.device) # Mask for actions 246 | 247 | encoded_input = self.dense(inputs) 248 | 249 | prev_actions = [torch.zeros([batch_size, hidden], device=inputs.device) for _ in range(self.n_history)] 250 | 251 | for _ in range(seq_len): 252 | query = F.relu( 253 | torch.stack( 254 | [ 255 | querier(prev_action) 256 | for prev_action, querier in zip(prev_actions, self.queriers) 257 | ] 258 | ).sum(dim=0) 259 | ) 260 | logits = self.pointer(encoded_input, query, mask, c=c, temp=temp) 261 | 262 | probs = distrib.Categorical(logits=logits) 263 | idx = probs.sample() 264 | 265 | idx_list.append(idx) # Tour index 266 | log_probs.append(probs.log_prob(idx)) 267 | entropies.append(probs.entropy()) 268 | mask = mask + torch.zeros([batch_size, seq_len], device=inputs.device).scatter_(1, idx.unsqueeze(1), 1) 269 | 270 | action_rep = inputs[torch.arange(batch_size), idx] 271 | prev_actions.pop(0) 272 | prev_actions.append(action_rep) 273 | 274 | idx_list.append(idx_list[0]) # Return to start 275 | tour = torch.stack(idx_list, dim=1) # Permutations 276 | log_probs = sum(log_probs) # log-probs for backprop (Reinforce) 277 | entropies = sum(entropies) 278 | 279 | return tour, log_probs, entropies 280 | 281 | 282 | class Critic(nn.Module): 283 | def __init__(self, n_hidden=128, att_hidden=256, crit_hidden=256): 284 | """ Critic module, estimating the minimum length of the tour from the 285 | encoded inputs. 286 | 287 | Args: 288 | n_hidden (int, optional): Size of the encoded input. Defaults to 128. 289 | att_hidden (int, optional): Attention hidden size. Defaults to 256. 290 | crit_hidden (int, optional): Critic hidden size. Defaults to 256. 291 | """ 292 | super().__init__() 293 | self.glimpse = FullGlimpse(n_hidden, att_hidden) 294 | self.hidden = nn.Linear(n_hidden, crit_hidden) 295 | self.output = nn.Linear(crit_hidden, 1) 296 | 297 | def forward(self, inputs): 298 | frame = self.glimpse(inputs) 299 | hidden_out = F.relu(self.hidden(frame)) 300 | preds = self.output(hidden_out).squeeze() 301 | return preds 302 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wandb==0.10.20 2 | torch==1.8 3 | omegaconf==2.0.6 -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def reward_fn(coords, tour): 8 | """Reward function. Compute the total distance for a tour, given the 9 | coordinates of each city and the tour indexes. 10 | 11 | Args: 12 | coords (torch.Tensor): Tensor of size [batch_size, seq_len, dim], 13 | representing each city's coordinates. 14 | tour (torch.Tensor): Tensor of size [batch_size, seq_len + 1], 15 | representing the tour's indexes (comes back to the first city). 16 | 17 | Returns: 18 | float: Reward for this tour. 19 | """ 20 | dim = coords.size(-1) 21 | 22 | ordered_coords = torch.gather(coords, 1, tour.long().unsqueeze(-1).repeat(1, 1, dim)) 23 | ordered_coords = ordered_coords.transpose(0, 2) # [dim, seq_len, batch_size] 24 | 25 | # For each dimension (x, y), compute the squared difference between each city 26 | delta2 = [torch.square(d[1:] - d[:-1]).transpose(0, 1) for d in ordered_coords] 27 | 28 | # Euclidian distance between each city 29 | inter_city_distances = torch.sqrt(sum(delta2)) 30 | distance = inter_city_distances.sum(dim=-1) 31 | 32 | return distance.float() 33 | 34 | 35 | class Trainer(): 36 | def __init__(self, conf, agent, dataset): 37 | """Trainer class, taking care of training the agent. 38 | 39 | Args: 40 | conf (OmegaConf.DictConf): Configuration. 41 | agent (torch.nn.Module): Agent network to train. 42 | dataset (data.DataGenerator): Data generator. 43 | """ 44 | super().__init__() 45 | 46 | self.conf = conf 47 | self.agent = agent 48 | self.dataset = dataset 49 | 50 | self.device = torch.device(self.conf.device) 51 | self.agent = self.agent.to(self.device) 52 | 53 | self.optim = torch.optim.Adam(params=self.agent.parameters(), lr=self.conf.lr) 54 | gamma = 1 - self.conf.lr_decay_rate / self.conf.lr_decay_steps # To have same behavior as Tensorflow implementation 55 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optim, gamma=gamma) 56 | 57 | 58 | def train_step(self, data): 59 | self.optim.zero_grad() 60 | 61 | # Forward pass 62 | tour, critique, log_probs, _ = self.agent(data) 63 | 64 | # Compute reward 65 | reward = reward_fn(data, tour) 66 | 67 | # Compute losses for both actor (reinforce) and critic 68 | loss1 = ((reward - critique).detach() * log_probs).mean() 69 | loss2 = F.mse_loss(reward, critique) 70 | 71 | # Backward pass 72 | (loss1 + loss2).backward() 73 | 74 | # Clip gradients 75 | nn.utils.clip_grad_norm_(self.agent.parameters(), self.conf.grad_clip) 76 | 77 | # Optimize 78 | self.optim.step() 79 | 80 | # Update LR 81 | self.scheduler.step() 82 | 83 | return reward.mean(), [loss1, loss2] 84 | 85 | def run(self): 86 | self.agent.train() 87 | running_reward, running_losses = 0, [0, 0] 88 | for step in range(self.conf.steps): 89 | input_batch = self.dataset.train_batch(self.conf.batch_size, self.conf.max_len, self.conf.dimension) 90 | input_batch = torch.Tensor(input_batch).to(self.device) 91 | 92 | reward, losses = self.train_step(input_batch) 93 | 94 | running_reward += reward 95 | running_losses[0] += losses[0] 96 | running_losses[1] += losses[1] 97 | 98 | if step % self.conf.log_interval == 0 and step != 0: 99 | # Log stuff 100 | wandb.log({ 101 | 'reward': running_reward / self.conf.log_interval, 102 | 'actor_loss': running_losses[0] / self.conf.log_interval, 103 | 'critic_loss': running_losses[1] / self.conf.log_interval, 104 | 'learning_rate': self.scheduler.get_last_lr()[0], 105 | 'step': step 106 | }) 107 | 108 | # Reset running reward/loss 109 | running_reward, running_losses = 0, [0, 0] 110 | 111 | 112 | 113 | --------------------------------------------------------------------------------