├── .gitignore
├── README.md
├── config.yaml
├── env.py
├── eval.py
├── groubi_tsp.py
├── models.py
├── requirements.txt
├── revtorch
│   ├── __init__.py
│   └── revtorch.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | tmp/
2 | data/
3 | outputs/
4 | logs/
5 | results/
6 | result_ckpt/
7 | submit.yaml
8 | .vscode/
9 | local_attn_ablations.yaml
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | pip-wheel-metadata/
34 | share/python-wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pointerformer
2 |
3 | We provide the code, data, and trained models.
4 |
5 | ## Requirements
6 |
7 | ```
8 | conda create -n pointerformer python=3.7  # any environment name works
9 | conda activate pointerformer
10 | pip install -r requirements.txt
11 | ```
12 |
13 | ## Trained Model and Dataset
14 |
15 | You can download our trained models and datasets from [here](https://drive.google.com/file/d/1Sx5hkXTzYSZ98Iqf1UWJdQfVhCbJGgpz/view?usp=sharing "trained_model_and_data.zip"), then unzip the files to the root directory of this project.
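16 |
17 | To sanity-check the unzipped data, you can load a test file with the reader in `env.py` (a minimal sketch; the file name assumes the TSP100 test set from the archive):
18 |
19 | ```python
20 | from env import readDataFile
21 |
22 | # readDataFile parses one instance per line: flattened x/y coordinates,
23 | # followed by an "output <tour>" annotation that the reader strips off.
24 | data = readDataFile("./data/tsp100_test_concorde.txt")
25 | print(data.shape)  # expected: torch.Size([num_instances, 100, 2])
26 | ```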
27 |
28 | ## Train
29 |
30 | Train a model on TSP100:
31 |
32 | `python train.py graph_size=100 group_size=100`
33 |
34 | ## Evaluate
35 |
36 | Evaluate on the random test set (tsp_random):
37 |
38 | `python eval.py val_data_path=./data/tsp100_test_concorde.txt load_checkpoint_path=./result_ckpt/model100.ckpt`
39 |
40 | Evaluate on the real-world test set (tsp_partner):
41 |
42 | `python eval.py val_data_path=./data/partner_100.txt load_checkpoint_path=./result_ckpt/model100.ckpt real_data=true`
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | graph_size: 100
2 | node_dim: 2
3 | encoder_type: revmha #mha
4 | decoder_type: DecoderForLarge #Decoder
5 | decoder_dim: 128
6 | embedding_dim: 128
7 | n_layers: 6
8 | group_size: 100
9 | add_init_projection: true # only for mha_encoder
10 | n_heads: 8 # only for mha_encoder
11 | tanh_clipping: 50
12 | total_epoch: 500
13 | epoch_size: 100000
14 | val_size: 10000
15 | train_batch_size: 64
16 | val_batch_size: 128
17 | learning_rate: 1e-4
18 | weight_decay: 1e-6
19 | data_distribution: uniform # or normal
20 | seed: 1234
21 | gpus: [0]
22 | precision: 16
23 | val_type: x8Aug_nTraj
24 | default_run_name: TSP-N${graph_size}G${group_size}-${encoder_type}${n_layers}E${embedding_dim}KD-s${seed}-${now:%m%dT%H%M}
25 |
26 | wandb: false
27 | wandb_project: tsp_eAM
28 | tensorboard: true
29 |
30 | run_name: null
31 | save_dir: ./outputs/
32 | val_data_path: ./data/tsp${graph_size}_test_concorde.txt
33 |
34 | data_augment: true
35 | multi_pointer: 8
36 | divide_std: true
37 | add_more_query: true
38 | multi_pointer_level: 1
39 |
40 | load_path: null
41 | load_checkpoint_path: ./result_ckpt/model100.ckpt
42 |
43 | real_data: false
44 | integer: false
45 | sample: false
46 |
--------------------------------------------------------------------------------
/env.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import DataLoader, Dataset
4 |
5 | import os
6 |
7 |
8 | class GroupState:
9 |     def __init__(self, group_size, x):
10 |         # x.shape = [B, N, 2]
11 |         self.batch_size = x.size(0)
12 |         self.group_size = group_size
13 |         self.device = x.device
14 |
15 |         self.selected_count = 0
16 |         # current_node.shape = [B, G]
17 |         self.current_node = None
18 |         # selected_node_list.shape = [B, G, selected_count]
19 |         self.selected_node_list = torch.zeros(
20 |             x.size(0), group_size, 0, device=x.device
21 |         ).long()
22 |         # ninf_mask.shape = [B, G, N]
23 |         self.ninf_mask = torch.zeros(x.size(0), group_size, x.size(1), device=x.device)
24 |
25 |     def move_to(self, selected_idx_mat):
26 |         # selected_idx_mat.shape = [B, G]
27 |         self.selected_count += 1
28 |         self.current_node = selected_idx_mat
29 |         self.selected_node_list = torch.cat(
30 |             (self.selected_node_list, selected_idx_mat[:, :, None]), dim=2
31 |         )
32 |         self.ninf_mask.scatter_(  # float("-inf") instead of torch.inf, which the pinned torch==1.10.0 lacks
33 |             dim=-1, index=selected_idx_mat[:, :, None], value=float("-inf")
34 |         )
35 |
36 |
37 | class MultiTrajectoryTSP:
38 |     def __init__(self, x, x_raw=None, integer=False):
39 |         self.integer = integer
40 |         if x_raw is None:
41 |             self.x_raw = x.clone()
42 |         else:
43 |             self.x_raw = x_raw
44 |         self.x = x
45 |         self.batch_size = self.B = x.size(0)
46 |         self.graph_size = self.N = x.size(1)
47 |         self.node_dim = self.C = x.size(2)
48 |         self.group_size = self.G = None
49 |         self.group_state = None
50 |
51 |     def reset(self, group_size):
52 |         self.group_size = group_size
53 |         self.group_state = GroupState(group_size=group_size, x=self.x)
54 |         reward = None
55 |         done = False
56 |         return self.group_state, reward, done
57 |
58 |     def step(self, selected_idx_mat):
59 |         # move state
60 |         self.group_state.move_to(selected_idx_mat)
61 |
62 |         # returning values
63 |         done = self.group_state.selected_count == self.graph_size
64 |         if done:
65 |             reward = -self._get_group_travel_distance()  # note the minus sign!
66 |         else:
67 |             reward = None
68 |         return self.group_state, reward, done
69 |
70 |     def _get_group_travel_distance(self):
71 |         # ordered_seq.shape = [B, G, N, C]
72 |         shp = (self.B, self.group_size, self.N, self.C)
73 |         gathering_index = self.group_state.selected_node_list.unsqueeze(3).expand(*shp)
74 |         seq_expanded = self.x_raw[:, None, :, :].expand(*shp)
75 |         ordered_seq = seq_expanded.gather(dim=2, index=gathering_index)
76 |         rolled_seq = ordered_seq.roll(dims=2, shifts=-1)
77 |         # segment_lengths.size = [B, G, N]
78 |         segment_lengths = ((ordered_seq - rolled_seq) ** 2).sum(3).sqrt()
79 |         if self.integer:
80 |             group_travel_distances = segment_lengths.round().sum(2)
81 |         else:
82 |             group_travel_distances = segment_lengths.sum(2)
83 |         return group_travel_distances
84 |
85 |
86 | def readDataFile(filePath):
87 |     """
88 |     read validation dataset from "https://github.com/Spider-scnu/TSP"
89 |     """
90 |     res = []
91 |     with open(filePath, "r") as fp:
92 |         datas = fp.readlines()
93 |         for data in datas:
94 |             data = [float(i) for i in data.split("o")[0].split()]  # keep only the coordinates before the "output <tour>" annotation
95 |             loc_x = torch.FloatTensor(data[::2])
96 |             loc_y = torch.FloatTensor(data[1::2])
97 |             data = torch.stack([loc_x, loc_y], dim=1)
98 |             res.append(data)
99 |     res = torch.stack(res, dim=0)
100 |     return res
101 |
102 |
103 | def readTSPLib(filePath):
104 |     """
105 |     read TSPLib
106 |     """
107 |     data_trans, data_raw = [], []
108 |     with open(filePath, "r") as fp:
109 |         loc_x = []
110 |         loc_y = []
111 |         datas = fp.readlines()
112 |         for data in datas:
113 |             if ":" in data or "EOF" in data or "NODE_COORD_SECTION" in data:
114 |                 continue
115 |             data = [float(i) for i in data.split()]
116 |             if len(data) == 3:
117 |                 loc_x.append(data[1])
118 |                 loc_y.append(data[2])
119 |         loc_x = torch.FloatTensor(loc_x)
120 |         loc_y = torch.FloatTensor(loc_y)
121 |
122 |         data = torch.stack([loc_x, loc_y], dim=1)
123 |         data_raw.append(data)
124 |
125 |         mx = loc_x.max() - loc_x.min()
126 |         my = loc_y.max() - loc_y.min()
127 |         data = torch.stack([loc_x - loc_x.min(), loc_y - loc_y.min()], dim=1)
128 |         data = data / max(mx, my)
129 |         data_trans.append(data)
130 |
131 |     data_trans = torch.stack(data_trans, dim=0)
132 |     data_raw = torch.stack(data_raw, dim=0)
133 |     return data_trans, data_raw
134 |
135 |
136 | def readTSPLibOpt(opt_path):
137 |     with open(opt_path, "r") as fp:
138 |         datas = fp.readlines()
139 |         tours = []
140 |         for data in datas:
141 |             if ":" in data or "-1" in data or "TOUR_SECTION" in data or "EOF" in data:
142 |                 continue
143 |             tours.extend([int(i) - 1 for i in data.split()])
144 |         tours = np.array(tours, dtype=np.int64)  # np.int is deprecated in recent NumPy
145 |     return tours
146 |
147 |
148 | class TSPDataset(Dataset):
149 |     def __init__(
150 |         self,
151 |         size=50,
152 |         node_dim=2,
153 |         num_samples=100000,
154 |         data_distribution="uniform",
155 |         data_path=None,
156 |     ):
157 |         super(TSPDataset, self).__init__()
158 |         if data_distribution == "uniform":
159 |             self.data = torch.rand(num_samples, size, node_dim)
160 |         elif data_distribution == "normal":
161 |             self.data = torch.randn(num_samples, size, node_dim)
162 |         self.size = num_samples
163 |         if data_path is not None:
164 |             if data_path.split(".")[-1] == "tsp":
165 |                 self.data, data_raw = readTSPLib(data_path)
166 |                 opt_path = data_path.replace(".tsp", ".opt.tour")
167 |                 print(opt_path)
168 |                 if os.path.exists(opt_path):
169 |                     self.opt_route = readTSPLibOpt(opt_path)
170 |                     tmp = np.roll(self.opt_route, -1)
171 |                     d = data_raw[0, self.opt_route] - data_raw[0, tmp]
172 |                     self.opt = np.linalg.norm(d, axis=-1).sum()
173 |                 else:
174 |                     self.opt = -1
175 |                 self.data = data_raw
176 |
177 |             else:
178 |                 self.data = readDataFile(data_path)
179 |             self.size = self.data.shape[0]
180 |
181 |     def __len__(self):
182 |         return self.size
183 |
184 |     def __getitem__(self, idx):
185 |         return self.data[idx]
186 |
187 |
188 | if __name__ == "__main__":
189 |     TSPDataset(data_path="./data/ALL_tsp/att48.tsp")
190 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import time
4 |
5 | import hydra
6 | import numpy as np
7 | import pytorch_lightning as pl
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | from omegaconf import DictConfig, open_dict
13 | from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
14 | from pytorch_lightning.callbacks import ModelCheckpoint
15 | from torch.utils.data import DataLoader, Dataset
16 | import models
17 | import utils
18 | from torch.distributions.categorical import Categorical
19 |
20 | from env import MultiTrajectoryTSP, TSPDataset
21 |
22 |
23 | class TSPModel(pl.LightningModule):
24 |     def __init__(self, cfg: DictConfig):
25 |         super().__init__()
26 |         if cfg.node_dim > 2:
27 |             assert (
28 |                 "noAug" in cfg.val_type
29 |             ), "High-dimension TSP doesn't support augmentation"
30 |
31 |         ## Encoder model
32 |         if cfg.encoder_type == "mha":
33 |             self.encoder = models.MHAEncoder(
34 |                 n_layers=cfg.n_layers,
35 |                 n_heads=cfg.n_heads,
36 |                 embedding_dim=cfg.embedding_dim,
37 |                 input_dim=24 if cfg.data_augment else 2,
38 |                 add_init_projection=cfg.add_init_projection,
39 |             )
40 |         elif cfg.encoder_type == "revmha":
41 |             self.encoder = models.RevMHAEncoder(
42 |                 n_layers=cfg.n_layers,
43 |                 n_heads=cfg.n_heads,
44 |                 embedding_dim=cfg.embedding_dim,
45 |                 input_dim=24 if cfg.data_augment else 2,
46 |                 intermediate_dim=cfg.embedding_dim * 4,
47 |                 add_init_projection=cfg.add_init_projection,
48 |             )
49 |
50 |         ## Decoder model
51 |         if cfg.decoder_type == "DecoderForLarge":
52 |             self.decoder = models.DecoderForLarge(
53 |                 embedding_dim=cfg.embedding_dim,
54 |                 n_heads=cfg.n_heads,
55 |                 tanh_clipping=cfg.tanh_clipping,
56 |                 multi_pointer=8,
57 |                 multi_pointer_level=1,
58 |                 add_more_query=True,
59 |             )
60 |         else:
61 |             self.decoder = models.Decoder(
62 |                 embedding_dim=cfg.embedding_dim,
63 |                 n_heads=cfg.n_heads,
64 |                 tanh_clipping=cfg.tanh_clipping,
65 |             )
66 |
67 |         self.cfg = cfg
68 |         self.save_hyperparameters(cfg)
69 |
70 |     def configure_optimizers(self):
71 |         optimizer = torch.optim.Adam(
72 |             self.parameters(),
73 |             lr=self.cfg.learning_rate * len(self.cfg.gpus),
74 |             weight_decay=self.cfg.weight_decay,
75 |         )
76 |         lr_scheduler = torch.optim.lr_scheduler.StepLR(
77 |             optimizer, step_size=1, gamma=1.0
78 |         )
79 |         return [optimizer], [{"scheduler": lr_scheduler, "interval": "epoch"}]
80 |
81 |     def train_dataloader(self):
82 |         self.group_size = self.cfg.group_size
83 |         dataset = TSPDataset(
84 |             size=self.cfg.graph_size,
85 |             node_dim=self.cfg.node_dim,
86 |             num_samples=self.cfg.epoch_size,
87 |             data_distribution=self.cfg.data_distribution,
88 |             data_path=self.cfg.val_data_path,
89 |         )
90 |         return DataLoader(
91 |             dataset,
92 |             num_workers=os.cpu_count(),
93 |             batch_size=self.cfg.train_batch_size,
94 |             pin_memory=True,
95 |         )
96 |
97 |     def val_dataloader(self):
98 |         dataset = TSPDataset(
99 |             size=self.cfg.graph_size,
100 |             node_dim=self.cfg.node_dim,
101 |             num_samples=self.cfg.val_size,
102 |             data_distribution=self.cfg.data_distribution,
103 |             data_path=self.cfg.val_data_path,
104 |         )
105 |         return DataLoader(
106 |             dataset,
107 |             batch_size=self.cfg.val_batch_size,
108 |             num_workers=os.cpu_count(),
109 |             pin_memory=True,
110 |         )
111 |
112 |     def training_step(self, batch, _, TSPLIB=True):
113 |         if TSPLIB:
114 |             scale = (
115 |                 (batch.max(1, keepdim=True).values - batch.min(1, keepdim=True).values)
116 |                 .max(2, keepdim=True)
117 |                 .values
118 |             )
119 |             batch_trans = (
120 |                 (batch - batch.min(1, keepdim=True).values) / scale
121 |             ) * 0.8 + 0.1
122 |
123 |         B, N, _ = batch.shape
124 |         G = self.group_size
125 |
126 |         batch_idx_range = torch.arange(B)[:, None].expand(B, G)
127 |         group_idx_range = torch.arange(G)[None, :].expand(B, G)
128 |
129 |         env = MultiTrajectoryTSP(batch_trans, x_raw=batch, integer=self.cfg.integer)  # MultiTrajectoryTSP has no TSPLIB argument
130 |         s, r, d = env.reset(group_size=G)  # reset env
131 |
132 |         # Data augmentation
133 |         if self.cfg.data_augment:
134 |             batch = utils.data_augment(batch_trans)
135 |
136 |         # Encode
137 |         embeddings = self.encoder(batch)
138 |         self.decoder.reset(batch, embeddings, G)  # decoder reset
139 |
140 |         entropy_list = []
141 |         log_prob = torch.zeros(B, G, device=self.device)
142 |         while not d:
143 |             if s.current_node is None:
144 |                 first_action = torch.randperm(N, device=self.device)[None, :G].expand(
145 |                     B, G
146 |                 )
147 |                 s, r, d = env.step(first_action)
148 |                 continue
149 |             else:
150 |                 last_node = s.current_node
151 |
152 |             action_probs = self.decoder(last_node, s.ninf_mask, s.selected_count)
153 |             m = Categorical(action_probs.reshape(B * G, -1))
154 |             entropy_list.append(m.entropy().mean().item())
155 |             action = m.sample().view(B, G)
156 |             chosen_action_prob = (
157 |                 action_probs[batch_idx_range, group_idx_range, action].reshape(B, G)
158 |                 + 1e-8
159 |             )
160 |             log_prob += chosen_action_prob.log()
161 |             s, r, d = env.step(action)
162 |
163 |         # Note that when G == 1, we can only use the PG without baseline so far
164 |         r_trans = r.to(self.device)  # -1/torch.exp(r)
165 |         if self.cfg.divide_std:
166 |             advantage = (
167 |                 (r_trans - r_trans.mean(dim=1, keepdim=True))
168 |                 / (r_trans.std(dim=1, unbiased=False, keepdim=True) + 1e-8)
169 |                 if G != 1
170 |                 else r_trans
171 |             )
172 |         else:
173 |             advantage = (
174 |                 (r_trans - r_trans.mean(dim=1, keepdim=True)) if G != 1 else r_trans
175 |             )
176 |         loss = (-advantage * log_prob).mean()
177 |
178 |         length_max = -r.max(dim=1)[0].mean().clone().detach().item()
179 |         length_mean = -r.mean(1).mean().clone().detach().item()
180 |         entropy_mean = sum(entropy_list) / len(entropy_list)
181 |
182 |         self.log(
183 |             name="length_max", value=length_max, prog_bar=True,
184 |         )  # sync_dist=True
185 |         self.log(
186 |             name="length_mean", value=length_mean, prog_bar=True,
187 |         )
188 |         self.log(
189 |             name="entropy", value=entropy_mean, prog_bar=True,
190 |         )
191 |
192 |         embed_max = embeddings.abs().max()
193 |         self.log(
194 |             name="embed_max", value=embed_max, prog_bar=True,
195 |         )
196 |
197 |         assert torch.isnan(loss).sum() == 0, "loss is nan!"
198 |
199 |         return {"loss": loss, "length": length_max, "entropy": entropy_mean}
200 |
201 |     def 
training_epoch_end(self, outputs): 202 | outputs = torch.as_tensor([item["length"] for item in outputs]) 203 | self.train_length_mean = outputs.mean().item() 204 | self.train_length_std = outputs.std().item() 205 | 206 | def validate_all( 207 | self, batch, val_type="None", return_pi=False, real_data=False, sample=False 208 | ): 209 | batch_raw = batch 210 | if real_data: 211 | scale = ( 212 | (batch.max(1, keepdim=True).values - batch.min(1, keepdim=True).values) 213 | .max(2, keepdim=True) 214 | .values 215 | ) 216 | batch_trans = ( 217 | (batch - batch.min(1, keepdim=True).values) / scale 218 | ) * 0.9 + 0.05 219 | 220 | if self.cfg.val_type == "x8Aug_nTraj": 221 | if real_data: 222 | batch_raw = batch_raw.repeat(8, 1, 1) 223 | batch = utils.augment_xy_data_by_8_fold(batch_trans) 224 | else: 225 | batch_raw = batch_raw.repeat(8, 1, 1) 226 | batch = utils.augment_xy_data_by_8_fold(batch) 227 | B, N, _ = batch.shape 228 | G = N 229 | elif self.cfg.val_type == "nTraj": 230 | B, N, _ = batch.shape 231 | G = N 232 | else: 233 | B, N, _ = batch.shape 234 | G = 1 235 | 236 | env = MultiTrajectoryTSP(batch, x_raw=batch_raw, integer=self.cfg.integer) 237 | 238 | s, r, d = env.reset(group_size=G) 239 | 240 | if self.cfg.data_augment: 241 | batch = utils.data_augment(batch) 242 | 243 | embeddings = self.encoder(batch) 244 | self.decoder.reset(batch, embeddings, G, trainging=False) 245 | 246 | first_action = torch.randperm(N)[None, :G].expand(B, G).to(self.device) 247 | pi = first_action[..., None] 248 | s, r, d = env.step(first_action) 249 | 250 | while not d: 251 | action_probs = self.decoder( 252 | s.current_node.to(self.device), 253 | s.ninf_mask.to(self.device), 254 | s.selected_count, 255 | ) 256 | 257 | if sample: 258 | m = Categorical(action_probs.reshape(B * G, -1)) 259 | action = m.sample().view(B, G) 260 | else: 261 | action = action_probs.argmax(dim=2) 262 | 263 | # pi = torch.cat([pi, action[..., None]], dim=-1) 264 | s, r, d = env.step(action) 265 | 266 | if self.cfg.val_type == "x8Aug_nTraj": 267 | B = round(B / 8) 268 | reward = r.reshape(8, B, G) 269 | # pi = pi.reshape(8, B, G, N) 270 | 271 | reward_greedy = reward[0, :, 0] 272 | reward_ntraj, idx_dim_ntraj = reward[0, :, :].max(dim=-1) 273 | max_reward_aug_ntraj, idx_dim_2 = reward.max(dim=2) 274 | max_reward_aug_ntraj, idx_dim_0 = max_reward_aug_ntraj.max(dim=0) 275 | 276 | # best_pi_greedy = pi[0,:,0] 277 | # idx_dim_ntraj = idx_dim_ntraj.reshape(B,1,1) 278 | # best_pi_n_traj = pi[0,:,:].gather(1,idx_dim_ntraj.repeat(1,1,N)) #(B,G,N) 279 | # idx_dim_0 = idx_dim_0.reshape(1, B, 1, 1) 280 | # idx_dim_2 = idx_dim_2.reshape(8, B, 1, 1).gather(0, idx_dim_0) 281 | # best_pi_aug_ntraj = pi.gather(0, idx_dim_0.repeat(1, 1, G, N)) 282 | # best_pi_aug_ntraj = best_pi_aug_ntraj.gather(2, idx_dim_2.repeat(1, 1, 1, N)) 283 | # print(torch.sort(best_pi_aug_ntraj[0][0][0])) 284 | 285 | return { 286 | "max_reward_aug_ntraj": -max_reward_aug_ntraj, 287 | "reward_greedy": -reward_greedy, 288 | "reward_ntraj": -reward_ntraj, 289 | } 290 | elif self.cfg.val_type == "nTraj": 291 | reward = r 292 | reward_greedy = reward[:, 0] 293 | reward_ntraj = reward.max(dim=-1).values 294 | 295 | return { 296 | "max_reward_aug_ntraj": torch.zeros(B), 297 | "reward_greedy": -reward_greedy, 298 | "reward_ntraj": -reward_ntraj, 299 | } 300 | elif self.cfg.val_type == "1Traj": 301 | reward = r 302 | reward_greedy = reward[:, 0] 303 | return { 304 | "max_reward_aug_ntraj": torch.zeros(B), 305 | "reward_greedy": -reward_greedy, 306 | "reward_ntraj": torch.zeros(B), 
307 |             }
308 |         else:
309 |             print("unknown val_type: {}".format(self.cfg.val_type))
310 |
311 |     def validation_step(self, batch, batch_idx):
312 |         # epoch timing is initialized in on_validation_epoch_start
313 |         with torch.no_grad():
314 |             outputs = self.validate_all(
315 |                 batch, real_data=self.cfg.real_data, sample=self.cfg.sample
316 |             )
317 |         return outputs
318 |
319 |     def validation_step_end(self, batch_parts):
320 |         return batch_parts
321 |
322 |     def on_validation_epoch_start(self) -> None:
323 |         self.validation_start_time = time.time()
324 |
325 |     def validation_epoch_end(self, outputs):
326 |         self.validation_time = time.time() - self.validation_start_time
327 |         max_reward_aug_ntraj = [item["max_reward_aug_ntraj"] for item in outputs]
328 |         reward_greedy = [item["reward_greedy"] for item in outputs]
329 |         reward_ntraj = [item["reward_ntraj"] for item in outputs]
330 |
331 |         self.max_reward_aug_ntraj = torch.cat(max_reward_aug_ntraj).mean().item()
332 |         self.max_reward_aug_ntraj_std = torch.cat(max_reward_aug_ntraj).std().item()
333 |         self.reward_greedy = torch.cat(reward_greedy).mean().item()
334 |         self.reward_greedy_std = torch.cat(reward_greedy).std().item()
335 |         self.reward_ntraj = torch.cat(reward_ntraj).mean().item()
336 |         self.reward_ntraj_std = torch.cat(reward_ntraj).std().item()
337 |
338 |     def on_validation_epoch_end(self):
339 |
340 |         self.log_dict(
341 |             {
342 |                 # "train_graph_size": self.train_graph_size,
343 |                 # "train_length": self.train_length_mean,
344 |                 "max_reward_aug_ntraj": self.max_reward_aug_ntraj,
345 |                 "reward_greedy": self.reward_greedy,
346 |                 "reward_ntraj": self.reward_ntraj,
347 |                 "validation_time": self.validation_time,
348 |             },
349 |             on_epoch=True,
350 |             on_step=False,
351 |         )
352 |
353 |         # self.
354 |         print(
355 |             f"\nEpoch {self.current_epoch}: ",
356 |             # f'train_graph_size={self.train_graph_size}, ',
357 |             # 'train_performance={:.03f}±{:.03f}, '.format(self.train_length_mean, self.train_length_std),
358 |             "max_reward_aug_ntraj={:.03f}±{:.03f}, ".format(
359 |                 self.max_reward_aug_ntraj, self.max_reward_aug_ntraj_std
360 |             ),
361 |             "reward_greedy={:.03f}±{:.3f}, ".format(
362 |                 self.reward_greedy, self.reward_greedy_std
363 |             ),
364 |             "reward_ntraj={:.03f}±{:.03f}, ".format(
365 |                 self.reward_ntraj, self.reward_ntraj_std
366 |             ),
367 |             "validation time={:.03f}".format(self.validation_time),
368 |         )
369 |
370 |
371 | @hydra.main(config_name="config")
372 | def run(cfg: DictConfig) -> None:
373 |     pl.seed_everything(cfg.seed)
374 |     cfg.run_name = cfg.run_name or cfg.default_run_name
375 |     if cfg.save_dir is None:
376 |         root_dir = os.getcwd()
377 |     elif os.path.isabs(cfg.save_dir):
378 |         root_dir = cfg.save_dir
379 |     else:
380 |         root_dir = os.path.join(hydra.utils.get_original_cwd(), cfg.save_dir)
381 |     root_dir = os.path.join(root_dir, f"{cfg.run_name}")
382 |     with open_dict(cfg):
383 |         cfg.root_dir = root_dir
384 |
385 |     cfg.load_checkpoint_path = os.path.join(
386 |         hydra.utils.get_original_cwd(), cfg.load_checkpoint_path
387 |     )
388 |     cfg.val_data_path = os.path.join(hydra.utils.get_original_cwd(), cfg.val_data_path)
389 |
390 |     # build TSPModel
391 |     tsp_model = TSPModel(cfg)
392 |     tsp_model = tsp_model.load_from_checkpoint(cfg.load_checkpoint_path)
393 |     tsp_model.cfg = cfg
394 |
395 |     # build trainer
396 |     trainer = pl.Trainer(
397 |         default_root_dir=root_dir,
398 |         gpus=cfg.gpus,
399 |         accelerator="dp",
400 |         precision=cfg.precision,
401 |         max_epochs=cfg.total_epoch,
402 |         reload_dataloaders_every_epoch=True,
403 |         num_sanity_val_steps=0,
404 |         callbacks=[],
405 |     )
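    # Note: the validation run below reports the three tour-length metrics computed
    # in TSPModel.validate_all -- "reward_greedy" (a single rollout), "reward_ntraj"
    # (best of N rollouts with different start nodes), and "max_reward_aug_ntraj"
    # (best over N start nodes x 8 coordinate augmentations, the strongest setting).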
406 | 407 | dataset = TSPDataset( 408 | size=cfg.graph_size, 409 | node_dim=cfg.node_dim, 410 | num_samples=cfg.val_size, 411 | data_distribution=cfg.data_distribution, 412 | data_path=cfg.val_data_path, 413 | ) 414 | 415 | val_dataloader = DataLoader( 416 | dataset, 417 | batch_size=cfg.val_batch_size, 418 | num_workers=os.cpu_count(), 419 | pin_memory=True, 420 | ) 421 | 422 | trainer.validate(tsp_model, val_dataloaders=val_dataloader) 423 | 424 | 425 | if __name__ == "__main__": 426 | run() 427 | -------------------------------------------------------------------------------- /groubi_tsp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import random 4 | from itertools import combinations 5 | import gurobipy as gp 6 | from gurobipy import GRB 7 | from env import readDataFile 8 | import time 9 | 10 | n = None 11 | 12 | 13 | # Parse argument 14 | 15 | # if len(sys.argv) < 2: 16 | # print('Usage: tsp.py npoints') 17 | # sys.exit(1) 18 | # n = int(sys.argv[1]) 19 | 20 | # Create n random points 21 | 22 | # random.seed(1) 23 | # points = [(random.randint(0, 100), random.randint(0, 100)) for i in range(n)] 24 | 25 | # Dictionary of Euclidean distance between each pair of points 26 | 27 | 28 | def solver(data_path): 29 | # Callback - use lazy constraints to eliminate sub-tours 30 | def subtourelim(model, where): 31 | if where == GRB.Callback.MIPSOL: 32 | vals = model.cbGetSolution(model._vars) 33 | # find the shortest cycle in the selected edge list 34 | tour = subtour(vals) 35 | if len(tour) < n: 36 | # add subtour elimination constr. for every pair of cities in tour 37 | model.cbLazy( 38 | gp.quicksum(model._vars[i, j] for i, j in combinations(tour, 2)) 39 | <= len(tour) - 1 40 | ) 41 | 42 | # Given a tuplelist of edges, find the shortest subtour 43 | 44 | def subtour(vals): 45 | # make a list of edges selected in the solution 46 | edges = gp.tuplelist((i, j) for i, j in vals.keys() if vals[i, j] > 0.5) 47 | unvisited = list(range(n)) 48 | cycle = range(n + 1) # initial length has 1 more city 49 | while unvisited: # true if list is non-empty 50 | thiscycle = [] 51 | neighbors = unvisited 52 | while neighbors: 53 | current = neighbors[0] 54 | thiscycle.append(current) 55 | unvisited.remove(current) 56 | neighbors = [j for i, j in edges.select(current, "*") if j in unvisited] 57 | if len(cycle) > len(thiscycle): 58 | cycle = thiscycle 59 | return cycle 60 | 61 | data = readDataFile(data_path) 62 | n = data.shape[1] 63 | all_res = [] 64 | start = time.time() 65 | for points in data: 66 | 67 | dist = { 68 | (i, j): math.sqrt(sum((points[i][k] - points[j][k]) ** 2 for k in range(2))) 69 | for i in range(n) 70 | for j in range(i) 71 | } 72 | 73 | m = gp.Model() 74 | 75 | # Create variables 76 | 77 | vars = m.addVars(dist.keys(), obj=dist, vtype=GRB.BINARY, name="e") 78 | for i, j in vars.keys(): 79 | vars[j, i] = vars[i, j] # edge in opposite direction 80 | 81 | # You could use Python looping constructs and m.addVar() to create 82 | # these decision variables instead. The following would be equivalent 83 | # to the preceding m.addVars() call... 84 | # 85 | # vars = tupledict() 86 | # for i,j in dist.keys(): 87 | # vars[i,j] = m.addVar(obj=dist[i,j], vtype=GRB.BINARY, 88 | # name='e[%d,%d]'%(i,j)) 89 | 90 | # Add degree-2 constraint 91 | 92 | m.addConstrs(vars.sum(i, "*") == 2 for i in range(n)) 93 | 94 | # Using Python looping constructs, the preceding would be... 
95 | # 96 | # for i in range(n): 97 | # m.addConstr(sum(vars[i,j] for j in range(n)) == 2) 98 | 99 | # Optimize model 100 | 101 | m._vars = vars 102 | m.Params.LazyConstraints = 1 103 | m.optimize(subtourelim) 104 | 105 | vals = m.getAttr("X", vars) 106 | tour = subtour(vals) 107 | 108 | assert len(tour) == n 109 | 110 | # print('') 111 | # print('Optimal tour: %s' % str(tour)) 112 | # print('Optimal cost: %g' % m.ObjVal) 113 | # print('') 114 | 115 | all_res.append(m.ObjVal) 116 | end = time.time() 117 | mean_len = sum(all_res) / len(all_res) 118 | total_time = end - start 119 | with open("./data/res_partner_{}".format(n), "w") as fp: 120 | fp.write(str(mean_len) + "\n" + str(total_time)) 121 | 122 | print(mean_len, total_time) 123 | 124 | 125 | if __name__ == "__main__": 126 | ## 127 | for n in [500, 1000]: 128 | solver("./data/partner_{}.txt".format(n)) 129 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor, nn 6 | import torch.utils.checkpoint 7 | import utils 8 | import revtorch as rv 9 | 10 | 11 | class MHAEncoderLayer(torch.nn.Module): 12 | def __init__(self, embedding_dim, n_heads=8): 13 | super().__init__() 14 | self.n_heads = n_heads 15 | self.Wq = torch.nn.Linear(embedding_dim, embedding_dim, bias=False) 16 | self.Wk = torch.nn.Linear(embedding_dim, embedding_dim, bias=False) 17 | self.Wv = torch.nn.Linear(embedding_dim, embedding_dim, bias=False) 18 | self.multi_head_combine = torch.nn.Linear(embedding_dim, embedding_dim) 19 | self.feed_forward = torch.nn.Sequential( 20 | torch.nn.Linear(embedding_dim, embedding_dim * 4), 21 | torch.nn.ReLU(), 22 | torch.nn.Linear(embedding_dim * 4, embedding_dim), 23 | ) 24 | self.norm1 = torch.nn.BatchNorm1d(embedding_dim) 25 | self.norm2 = torch.nn.BatchNorm1d(embedding_dim) 26 | 27 | def forward(self, x, mask=None): 28 | q = utils.make_heads(self.Wq(x), self.n_heads) 29 | k = utils.make_heads(self.Wk(x), self.n_heads) 30 | v = utils.make_heads(self.Wv(x), self.n_heads) 31 | x = x + self.multi_head_combine(utils.multi_head_attention(q, k, v, mask)) 32 | x = self.norm1(x.view(-1, x.size(-1))).view(*x.size()) 33 | x = x + self.feed_forward(x) 34 | x = self.norm2(x.view(-1, x.size(-1))).view(*x.size()) 35 | return x 36 | 37 | 38 | class MHAEncoder(torch.nn.Module): 39 | def __init__( 40 | self, n_layers, n_heads, embedding_dim, input_dim, add_init_projection=True 41 | ): 42 | super().__init__() 43 | if add_init_projection or input_dim != embedding_dim: 44 | self.init_projection_layer = torch.nn.Linear(input_dim, embedding_dim) 45 | self.attn_layers = torch.nn.ModuleList( 46 | [ 47 | MHAEncoderLayer(embedding_dim=embedding_dim, n_heads=n_heads) 48 | for _ in range(n_layers) 49 | ] 50 | ) 51 | 52 | def forward(self, x, mask=None): 53 | if hasattr(self, "init_projection_layer"): 54 | x = self.init_projection_layer(x) 55 | for idx, layer in enumerate(self.attn_layers): 56 | x = layer(x, mask) 57 | return x 58 | 59 | 60 | class Decoder(torch.nn.Module): 61 | def __init__( 62 | self, embedding_dim, n_heads=8, tanh_clipping=10.0, 63 | ): 64 | super().__init__() 65 | self.embedding_dim = embedding_dim 66 | self.n_heads = n_heads 67 | self.tanh_clipping = tanh_clipping 68 | 69 | self.Wq_graph = torch.nn.Linear(embedding_dim, embedding_dim, bias=False) 70 | self.Wq_first = torch.nn.Linear(embedding_dim, 
embedding_dim, bias=False)
71 |         self.Wq_last = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
72 |
73 |         self.Wk = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
74 |         self.Wv = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
75 |
76 |         self.multi_head_combine = torch.nn.Linear(embedding_dim, embedding_dim)
77 |         # self.query_fc = torch.nn.Linear(embedding_dim*2, embedding_dim)
78 |
79 |         self.q_graph = None  # saved q1, for multi-head attention
80 |         self.q_first = None  # saved q2, for multi-head attention
81 |         self.glimpse_k = None  # saved key, for multi-head attention
82 |         self.glimpse_v = None  # saved value, for multi-head attention
83 |         self.logit_k = None  # saved, for single-head attention
84 |         self.group_ninf_mask = None  # reference to ninf_mask owned by state
85 |
86 |     def reset(self, coordinates, embeddings, G, trainging=False):
87 |         # embeddings.shape = [B, N, H]
88 |         # graph_embedding.shape = [B, 1, H]
89 |         # q_graph.shape = [B, n_heads, 1, key_dim]
90 |         # glimpse_k.shape = glimpse_v.shape = [B, n_heads, N, key_dim]
91 |         # logit_k.shape = [B, H, N]
92 |         # group_ninf_mask.shape = [B, G, N]
93 |         self.coordinates = coordinates
94 |         self.embeddings = embeddings
95 |         graph_embedding = self.embeddings.mean(dim=1, keepdim=True)
96 |
97 |         self.q_graph = self.Wq_graph(graph_embedding)
98 |         self.q_first = None
99 |         self.glimpse_k = utils.make_heads(self.Wk(embeddings), self.n_heads)
100 |         self.glimpse_v = utils.make_heads(self.Wv(embeddings), self.n_heads)
101 |         self.logit_k = embeddings.transpose(1, 2)
102 |         # self.group_ninf_mask = group_ninf_mask
103 |
104 |     def forward(self, last_node, group_ninf_mask, step):
105 |         self.group_ninf_mask = group_ninf_mask
106 |         B, N, H = self.embeddings.shape
107 |         G = self.group_ninf_mask.size(1)
108 |
109 |         last_node_index = last_node.view(B, G, 1).expand(-1, -1, H)
110 |         last_node_embedding = self.embeddings.gather(1, last_node_index)
111 |
112 |         # q_graph.shape = [B, n_heads, 1, key_dim]
113 |         # q_first.shape = q_last.shape = [B, n_heads, G, key_dim]
114 |         if self.q_first is None:
115 |             # self.q_first = utils.make_heads(self.Wq_first(last_node_embedding),
116 |             #                                 self.n_heads)
117 |             self.q_first = self.Wq_first(last_node_embedding)
118 |
119 |         q_last = self.Wq_last(last_node_embedding)
120 |         # glimpse_q.shape = [B, n_heads, G, key_dim]
121 |         glimpse_q = self.q_first + q_last + self.q_graph
122 |
123 |         glimpse_q = utils.make_heads(glimpse_q, self.n_heads)
124 |         if getattr(self, "n_decoding_neighbors", None) is not None:  # attribute is never set in __init__; guard avoids an AttributeError
125 |             D = self.coordinates.size(-1)
126 |             K = torch.count_nonzero(self.group_ninf_mask[0, 0] == 0.0).item()
127 |             K = min(self.n_decoding_neighbors, K)
128 |             last_node_coordinate = self.coordinates.gather(
129 |                 dim=1, index=last_node.unsqueeze(-1).expand(B, G, D)
130 |             )
131 |             distances = torch.cdist(last_node_coordinate, self.coordinates)
132 |             distances[self.group_ninf_mask == -np.inf] = np.inf
133 |             indices = distances.topk(k=K, dim=-1, largest=False).indices
134 |             glimpse_mask = torch.ones_like(self.group_ninf_mask) * (-np.inf)
135 |             glimpse_mask.scatter_(
136 |                 dim=-1, index=indices, src=torch.zeros_like(glimpse_mask)
137 |             )
138 |         else:
139 |             glimpse_mask = self.group_ninf_mask
140 |
141 |         # q_last_nhead=utils.make_heads(q_last,self.n_heads)
142 |         attn_out = utils.multi_head_attention(
143 |             q=glimpse_q, k=self.glimpse_k, v=self.glimpse_v, mask=glimpse_mask,
144 |         )
145 |
146 |         # mha_out.shape = [B, G, H]
147 |         # score.shape = [B, G, N]
148 |         final_q = self.multi_head_combine(attn_out)
149 |         score = torch.matmul(final_q, 
self.logit_k) 150 | 151 | score_clipped = torch.tanh(score) * self.tanh_clipping 152 | score_masked = score_clipped + self.group_ninf_mask 153 | 154 | probs = F.softmax(score_masked, dim=2) 155 | 156 | assert (probs == probs).all(), "Probs should not contain any nans!" 157 | return probs 158 | 159 | 160 | """ 161 | RevMHAEncoder 162 | """ 163 | 164 | 165 | class MHABlock(nn.Module): 166 | def __init__(self, hidden_size: int, num_heads: int): 167 | super().__init__() 168 | self.mixing_layer_norm = nn.BatchNorm1d(hidden_size) 169 | self.mha = nn.MultiheadAttention(hidden_size, num_heads, bias=False) 170 | 171 | def forward(self, hidden_states: Tensor): 172 | 173 | assert hidden_states.dim() == 3 174 | hidden_states = self.mixing_layer_norm(hidden_states.transpose(1, 2)).transpose( 175 | 1, 2 176 | ) 177 | hidden_states_t = hidden_states.transpose(0, 1) 178 | mha_output = self.mha(hidden_states_t, hidden_states_t, hidden_states_t)[ 179 | 0 180 | ].transpose(0, 1) 181 | 182 | return mha_output 183 | 184 | 185 | class FFBlock(nn.Module): 186 | def __init__(self, hidden_size: int, intermediate_size: int): 187 | super().__init__() 188 | self.feed_forward = nn.Linear(hidden_size, intermediate_size) 189 | self.output_dense = nn.Linear(intermediate_size, hidden_size) 190 | self.output_layer_norm = nn.BatchNorm1d(hidden_size) 191 | self.activation = nn.GELU() 192 | 193 | def forward(self, hidden_states: Tensor): 194 | hidden_states = ( 195 | self.output_layer_norm(hidden_states.transpose(1, 2)) 196 | .transpose(1, 2) 197 | .contiguous() 198 | ) 199 | intermediate_output = self.feed_forward(hidden_states) 200 | intermediate_output = self.activation(intermediate_output) 201 | output = self.output_dense(intermediate_output) 202 | 203 | return output 204 | 205 | 206 | class RevMHAEncoder(nn.Module): 207 | def __init__( 208 | self, 209 | n_layers: int, 210 | n_heads: int, 211 | embedding_dim: int, 212 | input_dim: int, 213 | intermediate_dim: int, 214 | add_init_projection=True, 215 | ): 216 | super().__init__() 217 | if add_init_projection or input_dim != embedding_dim: 218 | self.init_projection_layer = torch.nn.Linear(input_dim, embedding_dim) 219 | self.num_hidden_layers = n_layers 220 | blocks = [] 221 | for _ in range(n_layers): 222 | f_func = MHABlock(embedding_dim, n_heads) 223 | g_func = FFBlock(embedding_dim, intermediate_dim) 224 | # we construct a reversible block with our F and G functions 225 | blocks.append(rv.ReversibleBlock(f_func, g_func, split_along_dim=-1)) 226 | 227 | self.sequence = rv.ReversibleSequence(nn.ModuleList(blocks)) 228 | 229 | def forward(self, x: Tensor, mask=None): 230 | if hasattr(self, "init_projection_layer"): 231 | x = self.init_projection_layer(x) 232 | x = torch.cat([x, x], dim=-1) 233 | out = self.sequence(x) 234 | return torch.stack(out.chunk(2, dim=-1))[-1] 235 | 236 | 237 | class DecoderForLarge(torch.nn.Module): 238 | def __init__( 239 | self, 240 | embedding_dim, 241 | n_heads=8, 242 | tanh_clipping=10.0, 243 | multi_pointer=1, 244 | multi_pointer_level=1, 245 | add_more_query=True, 246 | ): 247 | super().__init__() 248 | self.embedding_dim = embedding_dim 249 | self.n_heads = n_heads 250 | self.tanh_clipping = tanh_clipping 251 | 252 | self.Wq_graph = torch.nn.Linear(embedding_dim, embedding_dim, bias=False) 253 | self.Wq_first = torch.nn.Linear(embedding_dim, embedding_dim, bias=False) 254 | 255 | self.Wq_last = torch.nn.Linear(embedding_dim, embedding_dim, bias=False) 256 | self.wq = torch.nn.Linear(embedding_dim, embedding_dim, bias=False) 257 | 258 | 
self.W_visited = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
259 |
260 |         self.Wk = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
261 |         # self.Wv = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
262 |         # self.multi_head_combine = torch.nn.Linear(embedding_dim, embedding_dim)
263 |
264 |         self.q_graph = None  # saved q1, for multi-head attention
265 |         self.q_first = None  # saved q2, for multi-head attention
266 |         self.glimpse_k = None  # saved key, for multi-head attention
267 |         self.glimpse_v = None  # saved value, for multi-head attention
268 |         self.logit_k = None  # saved, for single-head attention
269 |         self.group_ninf_mask = None  # reference to ninf_mask owned by state
270 |         self.multi_pointer = multi_pointer
271 |         self.multi_pointer_level = multi_pointer_level
272 |         self.add_more_query = add_more_query
273 |
274 |     def reset(self, coordinates, embeddings, G, trainging=True):
275 |         # embeddings.shape = [B, N, H]
276 |         # graph_embedding.shape = [B, 1, H]
277 |         # q_graph.shape = [B, n_heads, 1, key_dim]
278 |         # glimpse_k.shape = glimpse_v.shape = [B, n_heads, N, key_dim]
279 |         # logit_k.shape = [B, H, N]
280 |         # group_ninf_mask.shape = [B, G, N]
281 |
282 |         B, N, H = embeddings.shape
283 |         # G = group_ninf_mask.size(1)
284 |
285 |         self.coordinates = coordinates  # [:,:2]
286 |         self.embeddings = embeddings
287 |         self.embeddings_group = self.embeddings.unsqueeze(1).expand(B, G, N, H)
288 |         graph_embedding = self.embeddings.mean(dim=1, keepdim=True)
289 |
290 |         self.q_graph = self.Wq_graph(graph_embedding)
291 |         self.q_first = None
292 |         self.logit_k = embeddings.transpose(1, 2)
293 |         if self.multi_pointer > 1:
294 |             self.logit_k = utils.make_heads(
295 |                 self.Wk(embeddings), self.multi_pointer
296 |             ).transpose(
297 |                 2, 3
298 |             )  # [B, n_heads, key_dim, N]
299 |
300 |     def forward(self, last_node, group_ninf_mask, S):
301 |         B, N, H = self.embeddings.shape
302 |         G = group_ninf_mask.size(1)
303 |
304 |         # Get last node embedding
305 |         last_node_index = last_node.view(B, G, 1).expand(-1, -1, H)
306 |         last_node_embedding = self.embeddings.gather(1, last_node_index)
307 |         q_last = self.Wq_last(last_node_embedding)
308 |
309 |         # Get first node embedding
310 |         if self.q_first is None:
311 |             self.q_first = self.Wq_first(last_node_embedding)
312 |         group_ninf_mask = group_ninf_mask.detach()
313 |
314 |         mask_visited = group_ninf_mask.clone()
315 |         mask_visited[mask_visited == -np.inf] = 1.0
316 |         q_visited = self.W_visited(torch.bmm(mask_visited, self.embeddings) / N)
317 |         D = self.coordinates.size(-1)
318 |         last_node_coordinate = self.coordinates.gather(
319 |             dim=1, index=last_node.unsqueeze(-1).expand(B, G, D)
320 |         )
321 |         distances = torch.cdist(last_node_coordinate, self.coordinates)
322 |
323 |         if self.add_more_query:
324 |             final_q = q_last + self.q_first + self.q_graph + q_visited
325 |         else:
326 |             final_q = q_last + self.q_first + self.q_graph
327 |
328 |         if self.multi_pointer > 1:
329 |             final_q = utils.make_heads(
330 |                 self.wq(final_q), self.n_heads
331 |             )  # (B,n_head,G,H) (B,n_head,H,N)
332 |             score = (torch.matmul(final_q, self.logit_k) / math.sqrt(H)) - (
333 |                 distances / math.sqrt(2)
334 |             ).unsqueeze(
335 |                 1
336 |             )  # (B,n_head,G,N)
337 |             if self.multi_pointer_level == 1:
338 |                 score_clipped = self.tanh_clipping * torch.tanh(score.mean(1))
339 |             elif self.multi_pointer_level == 2:
340 |                 score_clipped = (self.tanh_clipping * torch.tanh(score)).mean(1)
341 |             else:
342 |                 # add mask
343 |                 score_clipped = self.tanh_clipping * torch.tanh(score)
344 |                 mask_prob = group_ninf_mask.detach().clone()
345 |                 mask_prob[mask_prob == -np.inf] = -1e8
346 |
347 |                 score_masked = score_clipped + mask_prob.unsqueeze(1)
348 |                 probs = F.softmax(score_masked, dim=-1).mean(1)
349 |                 return probs
350 |         else:
351 |             score = torch.matmul(final_q, self.logit_k) / math.sqrt(
352 |                 H
353 |             ) - distances / math.sqrt(2)
354 |             score_clipped = self.tanh_clipping * torch.tanh(score)
355 |
356 |         # add mask
357 |         mask_prob = group_ninf_mask.detach().clone()
358 |         mask_prob[mask_prob == -np.inf] = -1e8
359 |         score_masked = score_clipped + mask_prob
360 |         probs = F.softmax(score_masked, dim=2)
361 |
362 |         return probs
363 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.3.0
2 | revtorch==0.2.4
3 | wandb==0.12.10
4 | hydra-core==1.1.1
5 | tensorboardX==2.4.1
6 | scikit-learn==1.0.2
7 | torch==1.10.0
--------------------------------------------------------------------------------
/revtorch/__init__.py:
--------------------------------------------------------------------------------
1 | from revtorch.revtorch import ReversibleBlock, ReversibleSequence
--------------------------------------------------------------------------------
/revtorch/revtorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | # import torch.autograd.function as func
5 | import sys
6 | import random
7 |
8 |
9 | class ReversibleBlock(nn.Module):
10 |     """
11 |     Elementary building block for building (partially) reversible architectures
12 |
13 |     Implementation of the Reversible block described in the RevNet paper
14 |     (https://arxiv.org/abs/1707.04585). Must be used inside a :class:`revtorch.ReversibleSequence`
15 |     for autograd support.
16 |
17 |     Arguments:
18 |         f_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
19 |         g_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
20 |         split_along_dim (integer): dimension along which the tensor is split into the two parts required for the reversible block
21 |         fix_random_seed (boolean): Use the same random seed for the forward and backward pass if set to true
22 |     """
23 |
24 |     def __init__(self, f_block, g_block, split_along_dim=1, fix_random_seed=False):
25 |         super(ReversibleBlock, self).__init__()
26 |         self.f_block = f_block
27 |         self.g_block = g_block
28 |         self.split_along_dim = split_along_dim
29 |         self.fix_random_seed = fix_random_seed
30 |         self.random_seeds = {}
31 |
32 |     def _init_seed(self, namespace):
33 |         if self.fix_random_seed:
34 |             self.random_seeds[namespace] = random.randint(0, sys.maxsize)
35 |             self._set_seed(namespace)
36 |
37 |     def _set_seed(self, namespace):
38 |         if self.fix_random_seed:
39 |             torch.manual_seed(self.random_seeds[namespace])
40 |
41 |     def forward(self, x):
42 |         """
43 |         Performs the forward pass of the reversible block. Does not record any gradients.
44 |         :param x: Input tensor. Must be splittable along dimension `split_along_dim`.
45 |         :return: Output tensor of the same shape as the input tensor
46 |         """
47 |         x1, x2 = torch.chunk(x, 2, dim=self.split_along_dim)
48 |         y1, y2 = None, None
49 |         with torch.no_grad():
50 |             self._init_seed("f")
51 |             y1 = x1 + self.f_block(x2)
52 |             self._init_seed("g")
53 |             y2 = x2 + self.g_block(y1)
54 |
55 |         return torch.cat([y1, y2], dim=self.split_along_dim)
56 |
57 |     def backward_pass(self, y, dy, retain_graph):
58 |         """
59 |         Performs the backward pass of the reversible block.
60 |
61 |         Calculates the derivatives of the block's parameters in f_block and g_block, as well as the inputs of the
62 |         forward pass and its gradients.
63 |
64 |         :param y: Outputs of the reversible block
65 |         :param dy: Derivatives of the outputs
66 |         :param retain_graph: Whether to retain the graph on intercepted backwards
67 |         :return: A tuple of (block input, block input derivatives). The block inputs are the same shape as the block outputs.
68 |         """
69 |
70 |         # Split the arguments channel-wise
71 |         if y.dtype == torch.half:
72 |             y = y.float()
73 |         y1, y2 = torch.chunk(y, 2, dim=self.split_along_dim)
74 |         del y
75 |         assert not y1.requires_grad, "y1 must already be detached"
76 |         assert not y2.requires_grad, "y2 must already be detached"
77 |         dy1, dy2 = torch.chunk(dy, 2, dim=self.split_along_dim)
78 |         del dy
79 |         assert not dy1.requires_grad, "dy1 must not require grad"
80 |         assert not dy2.requires_grad, "dy2 must not require grad"
81 |
82 |         # Enable autograd for y1 and y2. This ensures that PyTorch
83 |         # keeps track of ops. that use y1 and y2 as inputs in a DAG
84 |         y1.requires_grad = True
85 |         y2.requires_grad = True
86 |
87 |         # Ensures that PyTorch tracks the operations in a DAG
88 |         with torch.enable_grad():
89 |             self._set_seed("g")
90 |             gy1 = self.g_block(y1)
91 |
92 |             # Use autograd framework to differentiate the calculation. The
93 |             # derivatives of the parameters of G are set as a side effect
94 |             gy1.backward(dy2, retain_graph=retain_graph)
95 |
96 |         with torch.no_grad():
97 |             x2 = y2 - gy1  # Restore second input of forward()
98 |             del y2, gy1
99 |
100 |             # The gradient of x1 is the sum of the gradient of the output
101 |             # y1 as well as the gradient that flows back through G
102 |             # (The gradient that flows back through G is stored in y1.grad)
103 |             dx1 = dy1 + y1.grad
104 |             del dy1
105 |             y1.grad = None
106 |
107 |         with torch.enable_grad():
108 |             x2.requires_grad = True
109 |             self._set_seed("f")
110 |             fx2 = self.f_block(x2)
111 |
112 |             # Use autograd framework to differentiate the calculation. The
113 |             # derivatives of the parameters of F are set as a side effect
114 |             fx2.backward(dx1, retain_graph=retain_graph)
115 |
116 |         with torch.no_grad():
117 |             x1 = y1 - fx2  # Restore first input of forward()
118 |             del y1, fx2
119 |
120 |             # The gradient of x2 is the sum of the gradient of the output
121 |             # y2 as well as the gradient that flows back through F
122 |             # (The gradient that flows back through F is stored in x2.grad)
123 |             dx2 = dy2 + x2.grad
124 |             del dy2
125 |             x2.grad = None
126 |
127 |         # Undo the channelwise split
128 |         x = torch.cat([x1, x2.detach()], dim=self.split_along_dim)
129 |         dx = torch.cat([dx1, dx2], dim=self.split_along_dim)
130 |
131 |         return x, dx
132 |
133 |
134 | class _ReversibleModuleFunction(torch.autograd.function.Function):
135 |     """
136 |     Integrates the reversible sequence into the autograd framework
137 |     """
138 |
139 |     @staticmethod
140 |     def forward(ctx, x, reversible_blocks, eagerly_discard_variables):
141 |         """
142 |         Performs the forward pass of a reversible sequence within the autograd framework
143 |         :param ctx: autograd context
144 |         :param x: input tensor
145 |         :param reversible_blocks: nn.ModuleList of reversible blocks
146 |         :return: output tensor
147 |         """
148 |         assert isinstance(reversible_blocks, nn.ModuleList)
149 |         for block in reversible_blocks:
150 |             assert isinstance(block, ReversibleBlock)
151 |             x = block(x)
152 |         ctx.y = (
153 |             x.detach()
154 |         )  # not using ctx.save_for_backward(x) saves us memory by being able to free ctx.y earlier in the backward pass
155 |         ctx.reversible_blocks = reversible_blocks
156 |         ctx.eagerly_discard_variables = eagerly_discard_variables
157 |         return x
158 |
159 |     @staticmethod
160 |     def backward(ctx, dy):
161 |         """
162 |         Performs the backward pass of a reversible sequence within the autograd framework
163 |         :param ctx: autograd context
164 |         :param dy: derivatives of the outputs
165 |         :return: derivatives of the inputs
166 |         """
167 |         y = ctx.y
168 |         if ctx.eagerly_discard_variables:
169 |             del ctx.y
170 |         for i in range(len(ctx.reversible_blocks) - 1, -1, -1):
171 |             y, dy = ctx.reversible_blocks[i].backward_pass(
172 |                 y, dy, not ctx.eagerly_discard_variables
173 |             )
174 |         if ctx.eagerly_discard_variables:
175 |             del ctx.reversible_blocks
176 |         return dy, None, None
177 |
178 |
179 | class ReversibleSequence(nn.Module):
180 |     """
181 |     Basic building element for (partially) reversible networks
182 |
183 |     A reversible sequence is a sequence of arbitrarily many reversible blocks. The entire sequence is reversible.
184 |     The activations are only saved at the end of the sequence. Backpropagation leverages the reversible nature of
185 |     the reversible sequence to save memory.
186 |
187 |     Arguments:
188 |         reversible_blocks (nn.ModuleList): A ModuleList that exclusively contains instances of ReversibleBlock
189 |             which are to be used in the reversible sequence.
190 |         eagerly_discard_variables (bool): If set to true backward() discards the variables required for
191 |             calculating the gradient and therefore saves memory. Disable if you call backward() multiple times.
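    A minimal usage sketch (illustrative only; the two linear layers stand in
    for real F/G sub-networks such as attention or feed-forward blocks, and the
    input must require grad so autograd reaches the custom backward):

        >>> import torch
        >>> import torch.nn as nn
        >>> block = ReversibleBlock(nn.Linear(64, 64), nn.Linear(64, 64), split_along_dim=-1)
        >>> seq = ReversibleSequence(nn.ModuleList([block]))
        >>> x = torch.randn(8, 10, 128, requires_grad=True)  # last dim is split into two halves of 64
        >>> y = seq(x)
        >>> y.sum().backward()  # inputs are reconstructed block-by-block instead of being stored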
192 | """ 193 | 194 | def __init__(self, reversible_blocks, eagerly_discard_variables=True): 195 | super(ReversibleSequence, self).__init__() 196 | assert isinstance(reversible_blocks, nn.ModuleList) 197 | for block in reversible_blocks: 198 | assert isinstance(block, ReversibleBlock) 199 | 200 | self.reversible_blocks = reversible_blocks 201 | self.eagerly_discard_variables = eagerly_discard_variables 202 | 203 | def forward(self, x): 204 | """ 205 | Forward pass of a reversible sequence 206 | :param x: Input tensor 207 | :return: Output tensor 208 | """ 209 | x = _ReversibleModuleFunction.apply( 210 | x, self.reversible_blocks, self.eagerly_discard_variables 211 | ) 212 | return x 213 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import time 4 | 5 | import hydra 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from omegaconf import DictConfig, open_dict 13 | from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger 14 | from pytorch_lightning.callbacks import ModelCheckpoint 15 | from torch.utils.data import DataLoader, Dataset 16 | import models 17 | import utils 18 | from torch.distributions.categorical import Categorical 19 | 20 | from env import MultiTrajectoryTSP, TSPDataset 21 | 22 | 23 | class TSPModel(pl.LightningModule): 24 | def __init__(self, cfg: DictConfig): 25 | super().__init__() 26 | if cfg.node_dim > 2: 27 | assert ( 28 | "noAug" in cfg.val_type 29 | ), "High-dimension TSP doesn't support augmentation" 30 | 31 | ## Encoder model 32 | if cfg.encoder_type == "mha": 33 | self.encoder = models.MHAEncoder( 34 | n_layers=cfg.n_layers, 35 | n_heads=cfg.n_heads, 36 | embedding_dim=cfg.embedding_dim, 37 | input_dim=24 if cfg.data_augment else 2, 38 | add_init_projection=cfg.add_init_projection, 39 | ) 40 | elif cfg.encoder_type == "revmha": 41 | self.encoder = models.RevMHAEncoder( 42 | n_layers=cfg.n_layers, 43 | n_heads=cfg.n_heads, 44 | embedding_dim=cfg.embedding_dim, 45 | input_dim=24 if cfg.data_augment else 2, 46 | intermediate_dim=cfg.embedding_dim * 4, 47 | add_init_projection=cfg.add_init_projection, 48 | ) 49 | 50 | ## Decoder model 51 | if cfg.decoder_type == "DecoderForLarge": 52 | self.decoder = models.DecoderForLarge( 53 | embedding_dim=cfg.embedding_dim, 54 | n_heads=cfg.n_heads, 55 | tanh_clipping=cfg.tanh_clipping, 56 | multi_pointer=cfg.multi_pointer, 57 | multi_pointer_level=cfg.multi_pointer_level, 58 | add_more_query=cfg.add_more_query, 59 | ) 60 | else: 61 | self.decoder = models.Decoder( 62 | embedding_dim=cfg.embedding_dim, 63 | n_heads=cfg.n_heads, 64 | tanh_clipping=cfg.tanh_clipping, 65 | ) 66 | 67 | self.cfg = cfg 68 | self.save_hyperparameters(cfg) 69 | 70 | def configure_optimizers(self): 71 | optimizer = torch.optim.Adam( 72 | self.parameters(), 73 | lr=self.cfg.learning_rate * len(self.cfg.gpus), 74 | weight_decay=self.cfg.weight_decay, 75 | ) 76 | lr_scheduler = torch.optim.lr_scheduler.StepLR( 77 | optimizer, step_size=1, gamma=1.0 78 | ) 79 | return [optimizer], [{"scheduler": lr_scheduler, "interval": "epoch"}] 80 | 81 | def train_dataloader(self): 82 | self.group_size = self.cfg.group_size 83 | dataset = TSPDataset( 84 | size=self.cfg.graph_size, 85 | node_dim=self.cfg.node_dim, 86 | num_samples=self.cfg.epoch_size, 87 | 
data_distribution=self.cfg.data_distribution,
88 |         )
89 |         return DataLoader(
90 |             dataset,
91 |             num_workers=os.cpu_count(),
92 |             batch_size=self.cfg.train_batch_size,
93 |             pin_memory=True,
94 |         )
95 |
96 |     def val_dataloader(self):
97 |         dataset = TSPDataset(
98 |             size=self.cfg.graph_size,
99 |             node_dim=self.cfg.node_dim,
100 |             num_samples=self.cfg.val_size,
101 |             data_distribution=self.cfg.data_distribution,
102 |             data_path=self.cfg.val_data_path,
103 |         )
104 |         return DataLoader(
105 |             dataset,
106 |             batch_size=self.cfg.val_batch_size,
107 |             num_workers=os.cpu_count(),
108 |             pin_memory=True,
109 |         )
110 |
111 |     def training_step(self, batch, _):
112 |         B, N, _ = batch.shape
113 |         G = self.group_size
114 |
115 |         batch_idx_range = torch.arange(B)[:, None].expand(B, G)
116 |         group_idx_range = torch.arange(G)[None, :].expand(B, G)
117 |
118 |         env = MultiTrajectoryTSP(batch.cpu())  # to cpu for env
119 |         s, r, d = env.reset(group_size=G)  # reset env
120 |
121 |         # Data augmentation
122 |         if self.cfg.data_augment:
123 |             batch = utils.data_augment(batch)
124 |
125 |         # Encode
126 |         embeddings = self.encoder(batch)
127 |         self.decoder.reset(batch, embeddings, G)  # decoder reset
128 |
129 |         entropy_list = []
130 |         log_prob = torch.zeros(B, G, device=self.device)
131 |         while not d:
132 |             if s.current_node is None:
133 |                 first_action = torch.randperm(N)[None, :G].expand(B, G)
134 |                 s, r, d = env.step(first_action)
135 |                 continue
136 |             else:
137 |                 last_node = s.current_node
138 |
139 |             action_probs = self.decoder(
140 |                 last_node.to(self.device), s.ninf_mask.to(self.device), s.selected_count
141 |             )
142 |             m = Categorical(action_probs.reshape(B * G, -1))
143 |             entropy_list.append(m.entropy().mean().item())
144 |             action = m.sample().view(B, G)
145 |             chosen_action_prob = (
146 |                 action_probs[batch_idx_range, group_idx_range, action].reshape(B, G)
147 |                 + 1e-8
148 |             )
149 |             log_prob += chosen_action_prob.log()
150 |             s, r, d = env.step(action.cpu())
151 |
152 |         # Note that when G == 1, we can only use the PG without baseline so far
153 |         r_trans = r.to(self.device)  # -1/torch.exp(r)
154 |         if self.cfg.divide_std:
155 |             advantage = (
156 |                 (r_trans - r_trans.mean(dim=1, keepdim=True))
157 |                 / (r_trans.std(dim=1, unbiased=False, keepdim=True) + 1e-8)
158 |                 if G != 1
159 |                 else r_trans
160 |             )
161 |         else:
162 |             advantage = (
163 |                 (r_trans - r_trans.mean(dim=1, keepdim=True)) if G != 1 else r_trans
164 |             )
165 |         loss = (-advantage * log_prob).mean()
166 |
167 |         length_max = -r.max(dim=1)[0].mean().clone().detach().item()
168 |         length_mean = -r.mean(1).mean().clone().detach().item()
169 |         entropy_mean = sum(entropy_list) / len(entropy_list)
170 |         adv = advantage.abs().mean().clone().detach().item()
171 |
172 |         self.log(
173 |             name="length_max", value=length_max, prog_bar=True,
174 |         )  # sync_dist=True
175 |         self.log(
176 |             name="length_mean", value=length_mean, prog_bar=True,
177 |         )
178 |         self.log(
179 |             name="entropy", value=entropy_mean, prog_bar=True,
180 |         )
181 |         self.log(name="adv", value=adv, prog_bar=True)
182 |
183 |         embed_max = embeddings.abs().max()
184 |         self.log(
185 |             name="embed_max", value=embed_max, prog_bar=True,
186 |         )
187 |
188 |         assert torch.isnan(loss).sum() == 0, "loss is nan!"
189 |
190 |         return {"loss": loss, "length": length_max, "entropy": entropy_mean}
191 |
192 |     def training_epoch_end(self, outputs):
193 |         outputs = torch.as_tensor([item["length"] for item in outputs])
194 |         self.train_length_mean = outputs.mean().item()
195 |         self.train_length_std = outputs.std().item()
196 |
197 
197 |     def validate_all(self, batch, val_type=None, return_pi=False):
198 |         batch = utils.augment_xy_data_by_8_fold(batch)
199 |         B, N, _ = batch.shape
200 |         G = N
201 | 
202 |         env = MultiTrajectoryTSP(batch)
203 |         s, r, d = env.reset(group_size=G)
204 | 
205 |         if self.cfg.data_augment:
206 |             batch = utils.data_augment(batch)
207 | 
208 |         embeddings = self.encoder(batch)
209 |         self.decoder.reset(batch, embeddings, G, training=False)
210 | 
211 |         first_action = torch.randperm(N)[None, :G].expand(B, G).to(self.device)
212 |         pi = first_action[..., None]
213 |         s, r, d = env.step(first_action)
214 | 
215 |         while not d:
216 |             action_probs = self.decoder(
217 |                 s.current_node.to(self.device),
218 |                 s.ninf_mask.to(self.device),
219 |                 s.selected_count,
220 |             )
221 |             action = action_probs.argmax(dim=2)
222 |             pi = torch.cat([pi, action[..., None]], dim=-1)
223 |             s, r, d = env.step(action)
224 | 
225 |         B = B // 8  # undo the 8-fold batch expansion
226 |         reward = r.reshape(8, B, G)
227 |         pi = pi.reshape(8, B, G, N)
228 | 
229 |         reward_greedy = reward[0, :, 0]
230 |         reward_ntraj, idx_dim_ntraj = reward[0, :, :].max(dim=-1)
231 |         max_reward_aug_ntraj, idx_dim_2 = reward.max(dim=2)
232 |         max_reward_aug_ntraj, idx_dim_0 = max_reward_aug_ntraj.max(dim=0)
233 | 
234 |         if return_pi:
235 |             best_pi_greedy = pi[0, :, 0]
236 |             idx_dim_ntraj = idx_dim_ntraj.reshape(B, 1, 1)
237 |             best_pi_n_traj = pi[0, :, :].gather(
238 |                 1, idx_dim_ntraj.repeat(1, 1, N)
239 |             )  # (B, G, N)
240 |             idx_dim_0 = idx_dim_0.reshape(1, B, 1, 1)
241 |             idx_dim_2 = idx_dim_2.reshape(8, B, 1, 1).gather(0, idx_dim_0)
242 |             best_pi_aug_ntraj = pi.gather(0, idx_dim_0.repeat(1, 1, G, N))
243 |             best_pi_aug_ntraj = best_pi_aug_ntraj.gather(
244 |                 2, idx_dim_2.repeat(1, 1, 1, N)
245 |             )
246 |             return -max_reward_aug_ntraj, best_pi_aug_ntraj.squeeze()
247 |         return {
248 |             "max_reward_aug_ntraj": -max_reward_aug_ntraj,
249 |             "reward_greedy": -reward_greedy,
250 |             "reward_ntraj": -reward_ntraj,
251 |         }
252 | 
253 |     def validation_step(self, batch, batch_idx):
254 |         with torch.no_grad():
255 |             outputs = self.validate_all(batch)
256 |         return outputs
257 | 
258 |     def validation_step_end(self, batch_parts):
259 |         return batch_parts
260 | 
261 |     def on_validation_epoch_start(self) -> None:
262 |         self.validation_start_time = time.time()
263 | 
264 |     def validation_epoch_end(self, outputs):
265 |         max_reward_aug_ntraj = [item["max_reward_aug_ntraj"] for item in outputs]
266 |         reward_greedy = [item["reward_greedy"] for item in outputs]
267 |         reward_ntraj = [item["reward_ntraj"] for item in outputs]
268 | 
269 |         self.max_reward_aug_ntraj = torch.cat(max_reward_aug_ntraj).mean().item()
270 |         self.max_reward_aug_ntraj_std = torch.cat(max_reward_aug_ntraj).std().item()
271 |         self.reward_greedy = torch.cat(reward_greedy).mean().item()
272 |         self.reward_greedy_std = torch.cat(reward_greedy).std().item()
273 |         self.reward_ntraj = torch.cat(reward_ntraj).mean().item()
274 |         self.reward_ntraj_std = torch.cat(reward_ntraj).std().item()
275 | 
276 |     def on_validation_epoch_end(self):
277 |         validation_time = time.time() - self.validation_start_time
278 | 
279 |         self.log_dict(
280 |             {
281 |                 # "train_graph_size": self.train_graph_size,
282 |                 # "train_length": self.train_length_mean,
283 |                 "max_reward_aug_ntraj": self.max_reward_aug_ntraj,
284 |                 "reward_greedy": self.reward_greedy,
285 |                 "reward_ntraj": self.reward_ntraj,
286 |             },
287 |             on_epoch=True,
288 |             on_step=False,
289 |         )
290 | 
291 |         self.print(
292 |             f"\nEpoch {self.current_epoch}: ",
293 |             # f'train_graph_size={self.train_graph_size}, ',
294 |             # 'train_performance={:.03f}±{:.03f}, '.format(self.train_length_mean, self.train_length_std),
295 |             "max_reward_aug_ntraj={:.03f}±{:.03f}, ".format(
296 |                 self.max_reward_aug_ntraj, self.max_reward_aug_ntraj_std
297 |             ),
298 |             "reward_greedy={:.03f}±{:.03f}, ".format(
299 |                 self.reward_greedy, self.reward_greedy_std
300 |             ),
301 |             "reward_ntraj={:.03f}±{:.03f}, ".format(
302 |                 self.reward_ntraj, self.reward_ntraj_std
303 |             ),
304 |             "validation time={:.03f}".format(validation_time),
305 |         )
306 | 
307 | 
308 | @hydra.main(config_name="config")
309 | def run(cfg: DictConfig) -> None:
310 |     pl.seed_everything(cfg.seed)
311 |     cfg.run_name = cfg.run_name or cfg.default_run_name
312 |     if cfg.save_dir is None:
313 |         root_dir = os.getcwd()
314 |     elif os.path.isabs(cfg.save_dir):
315 |         root_dir = cfg.save_dir
316 |     else:
317 |         root_dir = os.path.join(hydra.utils.get_original_cwd(), cfg.save_dir)
318 |     root_dir = os.path.join(root_dir, f"{cfg.run_name}")
319 |     with open_dict(cfg):
320 |         cfg.root_dir = root_dir
321 | 
322 |     cfg.val_data_path = os.path.join(hydra.utils.get_original_cwd(), cfg.val_data_path)
323 | 
324 |     # build TSPModel
325 |     tsp_model = TSPModel(cfg)
326 |     checkpoint_callback = ModelCheckpoint(
327 |         monitor="reward_ntraj",
328 |         dirpath=os.path.join(root_dir, "checkpoints"),
329 |         filename=cfg.encoder_type + "{epoch:02d}-{reward_ntraj:.2f}",  # name checkpoints by the monitored metric
330 |         save_top_k=3,
331 |         save_last=True,
332 |         mode="min",
333 |         every_n_val_epochs=2,
334 |     )
335 | 
336 |     # set up loggers
337 |     loggers = []
338 |     if cfg.tensorboard:
339 |         tb_logger = TensorBoardLogger("logs")
340 |         loggers.append(tb_logger)
341 |     # wandb logger
342 |     if cfg.wandb:
343 |         os.makedirs(os.path.join(os.path.abspath(root_dir), "wandb"), exist_ok=True)
344 |         wandb_logger = WandbLogger(
345 |             name=cfg.run_name,
346 |             save_dir=root_dir,
347 |             project=cfg.wandb_project,
348 |             log_model=True,
349 |             save_code=True,
350 |             group=time.strftime("%Y%m%d", time.localtime()),
351 |             tags=cfg.default_run_name.split("-")[:-1],
352 |         )
353 |         wandb_logger.log_hyperparams(cfg)
354 |         wandb_logger.watch(tsp_model)
355 |         loggers.append(wandb_logger)
356 | 
357 |     # auto-resume training from the last checkpoint if one exists
358 |     last_ckpt_path = os.path.join(checkpoint_callback.dirpath, "last.ckpt")
359 |     if os.path.exists(last_ckpt_path):
360 |         resume = last_ckpt_path
361 |     elif cfg.load_path:
362 |         resume = cfg.load_path
363 |     else:
364 |         resume = None
365 | 
366 |     # build trainer
367 |     trainer = pl.Trainer(
368 |         default_root_dir=root_dir,
369 |         gpus=cfg.gpus,
370 |         accelerator="dp",
371 |         precision=cfg.precision,
372 |         max_epochs=cfg.total_epoch,
373 |         reload_dataloaders_every_epoch=True,
374 |         num_sanity_val_steps=0,
375 |         resume_from_checkpoint=resume,
376 |         logger=loggers,
377 |         callbacks=[checkpoint_callback],
378 |     )
379 | 
380 |     # train, then save a final checkpoint
381 |     trainer.fit(tsp_model)
382 |     trainer.save_checkpoint(os.path.join(root_dir, "checkpoint.ckpt"))
383 | 
384 | 
385 | if __name__ == "__main__":
386 |     run()
387 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | utils.py
3 | """
4 | 
5 | import math
6 | import torch
7 | import torch.nn.functional as F
8 | 
9 | 
10 | def multi_head_attention(q, k, v, mask=None):
11 |     # q shape = (B, n_heads, n, key_dim) : n can be either 1 or N
12 |     # k,v shape = (B, n_heads, N, key_dim)
13 |     # mask.shape = (B, group, N)
14 | 
15 |     B, n_heads, n, key_dim = q.shape
16 | 
17 |     # score.shape = (B, n_heads, n, N)
18 |     score = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1))
19 | 
20 |     if mask is not None:
21 |         score += mask[:, None, :, :].expand_as(score)
22 | 
23 |     shp = [q.size(0), q.size(-2), q.size(1) * q.size(-1)]
24 |     attn = torch.matmul(F.softmax(score, dim=3), v).transpose(1, 2)
25 |     return attn.reshape(*shp)
26 | 
27 | 
28 | def make_heads(qkv, n_heads):
29 |     shp = (qkv.size(0), qkv.size(1), n_heads, -1)
30 |     return qkv.reshape(*shp).transpose(1, 2)
31 | 
32 | 
33 | def augment_xy_data_by_8_fold(xy_data, training=False):
34 |     # xy_data.shape = [B, N, 2]
35 |     # x,y shape = [B, N, 1]
36 | 
37 |     x = xy_data[:, :, [0]]
38 |     y = xy_data[:, :, [1]]
39 | 
40 |     dat1 = torch.cat((x, y), dim=2)
41 |     dat2 = torch.cat((1 - x, y), dim=2)
42 |     dat3 = torch.cat((x, 1 - y), dim=2)
43 |     dat4 = torch.cat((1 - x, 1 - y), dim=2)
44 | 
45 |     dat5 = torch.cat((y, x), dim=2)
46 |     dat6 = torch.cat((1 - y, x), dim=2)
47 |     dat7 = torch.cat((y, 1 - x), dim=2)
48 |     dat8 = torch.cat((1 - y, 1 - x), dim=2)
49 | 
50 |     # data_augmented.shape = [B, N, 16]
51 |     if training:
52 |         data_augmented = torch.cat(
53 |             (dat1, dat2, dat3, dat4, dat5, dat6, dat7, dat8), dim=2
54 |         )
55 |         return data_augmented
56 | 
57 |     # data_augmented.shape = [8*B, N, 2]
58 |     data_augmented = torch.cat((dat1, dat2, dat3, dat4, dat5, dat6, dat7, dat8), dim=0)
59 |     return data_augmented
60 | 
61 | 
62 | def data_augment(batch):
63 |     batch = augment_xy_data_by_8_fold(batch, training=True)  # [B, N, 16]: 8 reflected copies along the feature dim
64 |     theta = []
65 |     for i in range(8):
66 |         theta.append(
67 |             torch.atan(batch[:, :, i * 2 + 1] / batch[:, :, i * 2]).unsqueeze(-1)  # polar angle of each reflected copy
68 |         )
69 |     theta.append(batch)
70 |     batch = torch.cat(theta, dim=2)
71 |     return batch
72 | 
--------------------------------------------------------------------------------
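One property of `augment_xy_data_by_8_fold` worth making explicit: the eight transforms are reflections and an x/y swap of the unit square, i.e. isometries, so pairwise distances and hence tour lengths are identical across all eight copies. This is what justifies taking the maximum reward over the augmentation dimension in `validate_all`. A minimal sanity check (a sketch, not part of the repository; it assumes the project root is on the import path, and the sizes `B=4`, `N=100` are arbitrary):

```
import torch

from utils import augment_xy_data_by_8_fold

xy = torch.rand(4, 100, 2)               # [B, N, 2] random TSP instances
aug = augment_xy_data_by_8_fold(xy)      # [8*B, N, 2], copies stacked along the batch dim

d_orig = torch.cdist(xy, xy)             # [B, N, N] pairwise node distances
d_aug = torch.cdist(aug, aug).reshape(8, 4, 100, 100)

# reflections and the x/y swap are isometries, so every copy keeps the distances
assert torch.allclose(d_aug, d_orig.unsqueeze(0), atol=1e-5)
```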