├── agent
│   ├── __init__.py
│   ├── agent.py
│   ├── utils.py
│   └── ppo.py
├── nets
│   ├── __init__.py
│   ├── critic_network.py
│   ├── actor_network.py
│   └── graph_layers.py
├── problems
│   ├── __init__.py
│   ├── problem_pdtsp.py
│   ├── problem_pdtspl.py
│   └── problem_pdp.py
├── utils
│   ├── __init__.py
│   ├── utils.py
│   └── logger.py
├── .gitignore
├── run.py
├── README.md
└── options.py

/agent/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/problems/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
--------------------------------------------------------------------------------
/agent/agent.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Tuple
2 | from abc import ABC, abstractmethod
3 | import torch
4 | from tensorboard_logger import Logger as TbLogger
5 |
6 | from nets.actor_network import Actor_N2S, Actor_Construct
7 | from nets.critic_network import Critic_N2S, Critic_Construct
8 | from options import Option
9 | from problems.problem_pdp import PDP
10 |
11 |
12 | class Agent(ABC):
13 |     opts: Option
14 |     actor: Actor_N2S
15 |     critic: Critic_N2S
16 |     actor_construct: Actor_Construct
17 |     critic_construct: Critic_Construct
18 |     optimizer: torch.optim.Optimizer
19 |     optimizer_sc: torch.optim.Optimizer
20 |     lr_scheduler: torch.optim.lr_scheduler.ExponentialLR
21 |
22 |     @abstractmethod
23 |     def __init__(self, problem_name: str, size: int, opts: Option) -> None:
24 |         pass
25 |
26 |     @abstractmethod
27 |     def load(self, load_path: str) -> None:
28 |         pass
29 |
30 |     @abstractmethod
31 |     def save(self, epoch: int) -> None:
32 |         pass
33 |
34 |     @abstractmethod
35 |     def eval(self) -> None:
36 |         pass
37 |
38 |     @abstractmethod
39 |     def train(self) -> None:
40 |         pass
41 |
42 |     @abstractmethod
43 |     def rollout(
44 |         self,
45 |         problem: PDP,
46 |         val_m: int,
47 |         batch: Dict[str, torch.Tensor],
48 |         show_bar: bool,
49 |         zoom: bool,
50 |     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
51 |         pass
52 |
53 |     @abstractmethod
54 |     def start_inference(
55 |         self,
56 |         problem: PDP,
57 |         val_dataset: Optional[str],
58 |         tb_logger: Optional[TbLogger],
59 |         load_path: Optional[str],
60 |         zoom: bool,
61 |     ) -> None:
62 |         pass
63 |
64 |     @abstractmethod
65 |     def start_training(
66 |         self,
67 |         problem: PDP,
68 |         val_dataset: Optional[str],
69 |         tb_logger: Optional[TbLogger],
70 |         load_path: Optional[str],
71 |     ) -> None:
72 |         pass
73 |
--------------------------------------------------------------------------------
/nets/critic_network.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Tuple
2 | from torch import nn
3 | import torch
4 |
5 | from .graph_layers import CriticEncoder, CriticDecoder
6 |
7 |
8 | class Critic_N2S(nn.Module):
9 |     def __init__(
10 |         self,
11 |         embedding_dim: int,
12 |         ff_hidden_dim: int,
13 |         n_heads: int,
14 |         n_layers: int,
15 |         normalization: str,
16 |     ) -> None:
17 |
18 |         super().__init__()
19 | 
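        # Descriptive note (added for readability, inferred from forward() below):
        # the critic scores the actor's detached node embeddings h_wave together
        # with the best cost found so far, so embedding_dim is assumed to match
        # the Actor_N2S embedding size.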
self.embedding_dim = embedding_dim 20 | self.ff_hidden_dim = ff_hidden_dim 21 | self.n_heads = n_heads 22 | self.n_layers = n_layers 23 | self.normalization = normalization 24 | self.encoder = nn.Sequential( 25 | *( 26 | CriticEncoder( 27 | self.n_heads, 28 | self.embedding_dim, 29 | self.ff_hidden_dim, 30 | self.normalization, 31 | ) 32 | for _ in range(1) 33 | ) 34 | ) 35 | 36 | self.decoder = CriticDecoder(self.embedding_dim) 37 | 38 | __call__: Callable[..., Tuple[torch.Tensor, torch.Tensor]] 39 | 40 | def forward( 41 | self, h_wave: torch.Tensor, best_cost: torch.Tensor 42 | ) -> Tuple[torch.Tensor, torch.Tensor]: 43 | 44 | y = self.encoder(h_wave.detach()) 45 | baseline_value = self.decoder(y, best_cost) 46 | 47 | return baseline_value.detach().squeeze(), baseline_value.squeeze() 48 | 49 | 50 | class Critic_Construct(nn.Module): 51 | def __init__(self) -> None: 52 | super().__init__() 53 | 54 | self.trust_degree = nn.Parameter(torch.tensor(0.0)) 55 | 56 | __call__: Callable[..., Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] 57 | 58 | def forward( 59 | self, obj_of_n2s: List[torch.Tensor], bl_val_detached_list: List[torch.Tensor] 60 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 61 | bl_construct = torch.stack(obj_of_n2s) - self.trust_degree * torch.stack( 62 | bl_val_detached_list 63 | ) 64 | return ( 65 | bl_construct.mean(0).detach(), 66 | bl_construct.mean(0), 67 | torch.tensor(self.trust_degree), 68 | ) 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | logs/ 3 | outputs/ 4 | datasets/ 5 | local_only/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Iterator, List, Tuple, Union
2 | import torch
3 | from torch import nn
4 | import math
5 | import numpy as np
6 | from torch.nn.parallel import DistributedDataParallel as DDP
7 |
8 |
9 | def get_rotate_mat(theta_f: float) -> torch.Tensor:
10 |     theta = torch.tensor(theta_f)
11 |     return torch.tensor(
12 |         [[torch.cos(theta), -torch.sin(theta)], [torch.sin(theta), torch.cos(theta)]]
13 |     )
14 |
15 |
16 | def rotate_tensor(x: torch.Tensor, d: float) -> torch.Tensor:
17 |     rot_mat = get_rotate_mat(d / 360 * 2 * np.pi).to(x.device)  # d is in degrees
18 |     return torch.matmul(x - 0.5, rot_mat) + 0.5  # rotate around the center (0.5, 0.5)
19 |
20 |
21 | def torch_load_cpu(load_path: str) -> Dict[str, Any]:
22 |     return torch.load(
23 |         load_path, map_location=lambda storage, loc: storage
24 |     )  # Load on CPU
25 |
26 |
27 | def get_inner_model(model: Union[nn.Module, DDP]) -> nn.Module:
28 |     return model.module if isinstance(model, DDP) else model
29 |
30 |
31 | def move_to(var: Any, device: Union[int, torch.device]) -> Any:
32 |     if isinstance(var, dict):
33 |         return {k: move_to(v, device) for k, v in var.items()}
34 |     return var.to(device)
35 |
36 |
37 | def clip_grad_norms(
38 |     param_groups: List[dict], max_norm: float = math.inf
39 | ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
40 |     """
41 |     Clips the gradient norms of all param groups to max_norm and returns the norms before clipping.
42 |     :param param_groups: the optimizer's parameter groups
43 |     :param max_norm: maximum norm per group; values <= 0 disable clipping
44 |     :return: grad_norms, grad_norms_clipped:
45 |         lists with the (clipped) gradient norm per group
46 |     """
47 |     grad_norms = [
48 |         torch.nn.utils.clip_grad_norm_(
49 |             group['params'],
50 |             max_norm
51 |             if max_norm > 0
52 |             else math.inf,  # inf means no clipping, but the norm is still computed
53 |             norm_type=2,
54 |         )
55 |         for group in param_groups
56 |     ]
57 |     grad_norms_clipped = (
58 |         [
59 |             min(g_norm, torch.tensor(max_norm), key=lambda x: x.item())
60 |             for g_norm in grad_norms
61 |         ]
62 |         if max_norm > 0
63 |         else grad_norms
64 |     )
65 |     return grad_norms, grad_norms_clipped
66 |
67 |
68 | def batch_picker(total: int, batch: int) -> Iterator[int]:
69 |     # yields chunk sizes that sum to total, e.g. total=10, batch=4 -> 4, 4, 2
70 |     assert total >= 0 and batch >= 1
71 |
72 |     remain = total
73 |     while (remain := remain - batch) > (-batch):
74 |         if remain >= 0:
75 |             pick_count = batch
76 |         else:
77 |             pick_count = remain + batch
78 |         yield pick_count
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | from typing import Type
2 | import os
3 | import json
4 | import torch
5 | import pprint
6 | import random
7 | from tensorboard_logger import Logger as TbLogger
8 | import warnings
9 |
10 | from options import get_options, Option
11 | from problems.problem_pdp import PDP
12 | from problems.problem_pdtsp import PDTSP 13 | from problems.problem_pdtspl import PDTSPL 14 | from agent.agent import Agent 15 | from agent.ppo import PPO 16 | 17 | 18 | def load_agent(name: str) -> Type[Agent]: 19 | agent = { 20 | 'ppo': PPO, 21 | }.get(name, None) 22 | assert agent is not None, "Currently unsupported agent: {}!".format(name) 23 | return agent 24 | 25 | 26 | def load_problem(name: str) -> Type[PDP]: 27 | d = { 28 | 'pdtsp': PDTSP, 29 | 'pdtspl': PDTSPL, 30 | } 31 | problem = d.get(name, None) 32 | assert problem is not None, "Currently unsupported problem: {}!".format(name) 33 | return problem 34 | 35 | 36 | def run(opts: Option) -> None: 37 | # Pretty print the run args 38 | pprint.pprint(vars(opts)) 39 | 40 | # Set the random seed to initialize neural networks 41 | torch.manual_seed(opts.seed) 42 | random.seed(opts.seed) 43 | 44 | # Optionally configure tensorboard 45 | tb_logger = None 46 | if not opts.no_tb and not opts.distributed: 47 | tb_logger = TbLogger( 48 | os.path.join( 49 | opts.log_dir, 50 | "{}_{}".format(opts.problem, opts.graph_size), 51 | opts.run_name, 52 | ) 53 | ) 54 | if not opts.no_saving and not os.path.exists(opts.save_dir): 55 | os.makedirs(opts.save_dir) 56 | 57 | # Save arguments so exact configuration can always be found 58 | if not opts.no_saving: 59 | with open(os.path.join(opts.save_dir, "args.json"), 'w') as f: 60 | json.dump(vars(opts), f, indent=True) 61 | 62 | # Set the device 63 | opts.device = torch.device("cuda" if opts.use_cuda else "cpu") 64 | 65 | # Figure out what's the problem 66 | problem = load_problem(opts.problem)( 67 | size=opts.graph_size, 68 | init_val_method=opts.init_val_method, 69 | check_feasible=opts.use_assert, 70 | ) 71 | 72 | # Figure out the RL algorithm 73 | agent = load_agent(opts.RL_agent)(problem.name, problem.size, opts) 74 | 75 | # Load data from load_path 76 | assert ( 77 | opts.load_path is None or opts.resume is None 78 | ), "Only one of load path and resume can be given" 79 | load_path = opts.load_path if opts.load_path is not None else opts.resume 80 | 81 | # Do validation only 82 | if opts.eval_only: 83 | # Load the validation datasets 84 | agent.start_inference( 85 | problem, opts.val_dataset, tb_logger, load_path, zoom=opts.zoom 86 | ) 87 | 88 | else: 89 | if opts.resume: 90 | epoch_resume = int( 91 | os.path.splitext(os.path.split(opts.resume)[-1])[0].split("-")[1] 92 | ) 93 | print("Resuming after {}".format(epoch_resume)) 94 | agent.opts.epoch_start = epoch_resume + 1 95 | 96 | # Start the actual training loop 97 | agent.start_training(problem, opts.val_dataset, tb_logger, load_path) 98 | 99 | 100 | if __name__ == "__main__": 101 | warnings.filterwarnings("ignore") 102 | 103 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 104 | torch.backends.cudnn.deterministic = True 105 | torch.backends.cudnn.benchmark = False 106 | 107 | run(get_options()) 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PDP-NCS 2 | 3 | This repo implements our paper: 4 | 5 | Detian Kong, Yining Ma, Zhiguang Cao, Tianshu Yu and Jianhua Xiao, "[Efficient Neural Collaborative Search for Pickup and Delivery Problems](https://www.researchgate.net/publication/383328030_Efficient_Neural_Collaborative_Search_for_Pickup_and_Delivery_Problems)" in the IEEE Transactions on Pattern Analysis and Machine Intelligence. Please cite our paper if the work is useful to you. 
6 |
7 | ```
8 | @article{kong2024efficient,
9 |   title={Efficient Neural Collaborative Search for Pickup and Delivery Problems},
10 |   author={Kong, Detian and Ma, Yining and Cao, Zhiguang and Yu, Tianshu and Xiao, Jianhua},
11 |   journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
12 |   year={2024},
13 |   publisher={IEEE},
14 |   volume={46},
15 |   number={12},
16 |   pages={11019-11034},
17 |   keywords={Computational modeling;Collaboration;Modeling;Biological system modeling;Search problems;Decoding;Training;Attention mechanism;deep reinforcement learning;learning to optimize;neighborhood search;pickup and delivery},
18 |   doi={10.1109/TPAMI.2024.3450850}
19 | }
20 | ```
21 |
22 | ## Dependencies
23 | * Python>=3.8
24 | * PyTorch>=1.7
25 | * tensorboard_logger
26 | * tqdm
27 |
28 | ## Usage
29 |
30 | ### Training
31 |
32 | #### PDTSP examples
33 |
34 | 21 nodes:
35 | ```bash
36 | python run.py --problem pdtsp --graph_size 20 --shared_critic
37 | ```
38 |
39 | 51 nodes:
40 | ```bash
41 | python run.py --problem pdtsp --graph_size 50 --shared_critic
42 | ```
43 |
44 | 101 nodes:
45 | ```bash
46 | python run.py --problem pdtsp --graph_size 100 --shared_critic
47 | ```
48 |
49 | #### PDTSP-LIFO examples
50 |
51 | 21 nodes:
52 | ```bash
53 | python run.py --problem pdtspl --graph_size 20 --shared_critic
54 | ```
55 |
56 | 51 nodes:
57 | ```bash
58 | python run.py --problem pdtspl --graph_size 50 --shared_critic
59 | ```
60 |
61 | 101 nodes:
62 | ```bash
63 | python run.py --problem pdtspl --graph_size 100 --shared_critic
64 | ```
65 |
66 | If you encounter "RuntimeError: CUDA out of memory", please try a smaller batch size by adding the option ```--batch_size xxx``` (default is 600).
67 |
68 | ### Inference
69 |
70 | Load the model and specify the number of iterations T for inference (use --val_m for data augmentation):
71 |
72 | ```bash
73 | --eval_only
74 | --load_path '{add model to load here}'
75 | --T_max 3000
76 | --val_size 2000
77 | --val_batch_size 200
78 | --val_dataset '{add dataset here}'
79 | --val_m 50
80 | ```
81 |
82 | #### Examples
83 |
84 | To run inference on 2,000 PDTSPL instances with 100 nodes without data augmentation (NCS):
85 |
86 | ```bash
87 | python run.py --eval_only --no_saving --no_tb --problem pdtspl --graph_size 100 --val_m 1 --val_dataset './datasets/pdp_100.pkl' --load_path './pre-trained/ncs/pdtspl_100/epoch-198.pt' --val_size 2000 --val_batch_size 2000 --T_max 3000 --shared_critic
88 | ```
89 |
90 | To run inference on 2,000 PDTSPL instances with 100 nodes using data augmentation (NCS-A):
91 |
92 | ```bash
93 | python run.py --eval_only --no_saving --no_tb --problem pdtspl --graph_size 100 --val_m 50 --val_dataset './datasets/pdp_100.pkl' --load_path './pre-trained/ncs/pdtspl_100/epoch-198.pt' --val_size 2000 --val_batch_size 200 --T_max 3000 --shared_critic
94 | ```
95 |
96 | Run ```python run.py -h``` for detailed help on the meaning of each argument.
97 |
98 | Datasets for validation and pre-trained models can be found in the [releases](https://github.com/dtkon/PDP-NCS/releases) of this repo.
99 |
100 | ## Acknowledgements
101 | The code and framework are derived from the repo [yining043/PDP-N2S](https://github.com/yining043/PDP-N2S).
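
In addition, training can be resumed from a saved checkpoint: run.py reads ```opts.resume``` and parses the starting epoch from a checkpoint filename of the form ```epoch-X.pt```. A sketch, assuming the command-line option is named ```--resume``` and using a hypothetical checkpoint path:

```bash
python run.py --problem pdtsp --graph_size 20 --shared_critic --resume './outputs/pdtsp_20/{run_name}/epoch-99.pt'
```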
102 |
--------------------------------------------------------------------------------
/problems/problem_pdtsp.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 |
4 | from .problem_pdp import PDP
5 |
6 |
7 | class PDTSP(PDP):
8 |     def __init__(
9 |         self, size: int, init_val_method: str, check_feasible: bool = False
10 |     ) -> None:
11 |         super().__init__(size, init_val_method, check_feasible)
12 |
13 |         self.name = 'pdtsp'  # Pickup and Delivery TSP
14 |
15 |         print(
16 |             f'PDTSP with {self.size} nodes.',
17 |             ' Do assert:',
18 |             check_feasible,
19 |         )
20 |
21 |     @staticmethod
22 |     def get_swap_mask(
23 |         selected_node: torch.Tensor,
24 |         visit_index: torch.Tensor,
25 |         top2: torch.Tensor,
26 |     ) -> torch.Tensor:
27 |         visited_order_map = PDP._get_visit_order_map(visit_index)
28 |         batch_size, graph_size_plus1, _ = visited_order_map.size()
29 |
30 |         mask = visited_order_map.clone()  # True means unavailable
31 |         arange = torch.arange(batch_size)
32 |         mask[arange, selected_node.view(-1)] = True
33 |         mask[arange, selected_node.view(-1) + graph_size_plus1 // 2] = True
34 |         mask[arange, :, selected_node.view(-1)] = True
35 |         mask[arange, :, selected_node.view(-1) + graph_size_plus1 // 2] = True
36 |
37 |         return mask
38 |
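    # Node-indexing convention (a descriptive note added for readability,
    # inferred from the masking and pairing arithmetic in this file): index 0
    # is the depot, indices 1..size/2 are pickups, and pickup i is paired with
    # delivery i + size/2. get_swap_mask() marks as unavailable (True) every
    # insertion position that would violate the pickup-before-delivery
    # precedence, as well as the positions of the selected pair itself.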
39 |     def get_initial_solutions(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
40 |
41 |         batch_size = batch['coordinates'].size(0)
42 |
43 |         def get_solution(methods: str) -> torch.Tensor:
44 |
45 |             half_size = self.size // 2
46 |
47 |             if methods == 'random':
48 |                 candidates = torch.ones(batch_size, self.size + 1).bool()  # all True
49 |                 candidates[:, half_size + 1 :] = 0  # set to False
50 |                 solution = torch.zeros(batch_size, self.size + 1).long()
51 |                 selected_node = torch.zeros(batch_size, 1).long()
52 |                 candidates.scatter_(1, selected_node, 0)  # set to False
53 |
54 |                 for _ in range(self.size):
55 |                     dists: torch.Tensor = torch.ones(batch_size, self.size + 1)
56 |                     dists[~candidates] = -1e20
57 |                     dists = torch.softmax(dists, -1)
58 |                     next_selected_node = dists.multinomial(1).view(-1, 1)
59 |
60 |                     add_index = (next_selected_node <= half_size).view(-1)
61 |                     pairing = (
62 |                         next_selected_node[next_selected_node <= half_size].view(-1, 1)
63 |                         + half_size
64 |                     )
65 |                     candidates[add_index] = candidates[add_index].scatter_(
66 |                         1, pairing, 1
67 |                     )
68 |
69 |                     solution.scatter_(1, selected_node, next_selected_node)
70 |                     candidates.scatter_(1, next_selected_node, 0)
71 |                     selected_node = next_selected_node
72 |
73 |                 return solution
74 |
75 |             elif methods == 'greedy':
76 |                 candidates = torch.ones(batch_size, self.size + 1).bool()
77 |                 candidates[:, half_size + 1 :] = 0
78 |                 solution = torch.zeros(batch_size, self.size + 1).long()
79 |                 selected_node = torch.zeros(batch_size, 1).long()
80 |                 candidates.scatter_(1, selected_node, 0)
81 |
82 |                 for _ in range(self.size):
83 |                     d1 = (
84 |                         batch['coordinates']
85 |                         .cpu()
86 |                         .gather(
87 |                             1,
88 |                             selected_node.unsqueeze(-1).expand(
89 |                                 batch_size, self.size + 1, 2
90 |                             ),
91 |                         )
92 |                     )
93 |                     d2 = batch['coordinates'].cpu()  # (batch_size, graph_size+1, 2)
94 |
95 |                     dists = (d1 - d2).norm(p=2, dim=2)  # (batch_size, graph_size+1)
96 |                     dists[~candidates] = 1e6
97 |                     next_selected_node = dists.min(-1)[1].view(-1, 1)
98 |
99 |                     add_index = (next_selected_node <= half_size).view(-1)
100 |                     pairing = (
101 |                         next_selected_node[next_selected_node <= half_size].view(-1, 1)
102 |                         + half_size
103 |                     )
104 |                     candidates[add_index] = candidates[add_index].scatter_(
105 |                         1, pairing, 1
106 |                     )
107 |
108 |                     solution.scatter_(1, selected_node, next_selected_node)
109 |                     candidates.scatter_(1, next_selected_node, 0)
110 |                     selected_node = next_selected_node
111 |
112 |                 return solution
113 |
114 |             else:
115 |                 raise NotImplementedError()
116 |
117 |         return (
118 |             get_solution(self.init_val_method).expand(batch_size, self.size + 1).clone()
119 |         )
120 |
121 |     def _check_feasibility(self, solution: torch.Tensor) -> None:
122 |         assert (
123 |             (torch.arange(self.size + 1, out=solution.new()))
124 |             .view(1, -1)
125 |             .expand_as(solution)
126 |             == solution.sort(1)[0]
127 |         ).all(), (
128 |             (
129 |                 (torch.arange(self.size + 1, out=solution.new()))
130 |                 .view(1, -1)
131 |                 .expand_as(solution)
132 |                 == solution.sort(1)[0]
133 |             ),
134 |             "not visiting all nodes",
135 |             solution,
136 |         )
137 |
138 |         # calculate visited time
139 |         batch_size = solution.size(0)
140 |         visited_time = torch.zeros((batch_size, self.size), device=solution.device)
141 |         pre = torch.zeros(batch_size, device=solution.device).long()
142 |         for i in range(self.size):
143 |             visited_time[
144 |                 torch.arange(batch_size), solution[torch.arange(batch_size), pre] - 1
145 |             ] = (i + 1)
146 |             pre = solution[torch.arange(batch_size), pre]
147 |
148 |         assert (
149 |             visited_time[:, 0 : self.size // 2] < visited_time[:, self.size // 2 :]
150 |         ).all(), (
151 |             visited_time[:, 0 : self.size // 2] < visited_time[:, self.size // 2 :],
152 |             "delivering without pickup",
153 |         )
154 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 | import torch
3 | import math
4 | from tensorboard_logger import Logger as TbLogger
5 | from agent.agent import Agent
6 |
7 |
8 | def log_to_screen(
9 |     time_used: torch.Tensor,
10 |     init_value: torch.Tensor,
11 |     best_value: torch.Tensor,
12 |     reward: torch.Tensor,
13 |     costs_history: torch.Tensor,
14 |     search_history: torch.Tensor,
15 |     batch_size: int,
16 |     dataset_size: int,
17 |     T: int,
18 | ) -> None:
19 |     # reward
20 |     print('\n', '-' * 60)
21 |     print(
22 |         'Avg total reward:'.center(35),
23 |         '{:<10f} +- {:<10f}'.format(
24 |             reward.sum(1).mean(), torch.std(reward.sum(1)) / math.sqrt(batch_size)
25 |         ),
26 |     )
27 |     print(
28 |         'Avg step reward:'.center(35),
29 |         '{:<10f} +- {:<10f}'.format(
30 |             reward.mean(), torch.std(reward) / math.sqrt(batch_size)
31 |         ),
32 |     )
33 |
34 |     # cost
35 |     print('-' * 60)
36 |     print(
37 |         'Avg init cost:'.center(35),
38 |         '{:<10f} +- {:<10f}'.format(
39 |             init_value.mean(), torch.std(init_value) / math.sqrt(batch_size)
40 |         ),
41 |     )
42 |     for per in range(500, T, 500):
43 |         cost_ = costs_history[:, per]
44 |         print(
45 |             f'Avg cost after T={per} steps:'.center(35),
46 |             '{:<10f} +- {:<10f}'.format(
47 |                 cost_.mean(), torch.std(cost_) / math.sqrt(batch_size)
48 |             ),
49 |         )
50 |     # best cost
51 |     print('-' * 60)
52 |
53 |     for per in range(500, T, 500):
54 |         cost_ = search_history[:, per]
55 |         print(
56 |             f'Avg best cost after T={per} steps:'.center(35),
57 |             '{:<10f} +- {:<10f}'.format(
58 |                 cost_.mean(), torch.std(cost_) / math.sqrt(batch_size)
59 |             ),
60 |         )
61 |     print(
62 |         f'Avg final best cost:'.center(35),
63 |         '{:<10f} +- {:<10f}'.format(
64 |             best_value.mean(), torch.std(best_value) / math.sqrt(batch_size)
65 |         ),
66 |     )
67 |     print('Final best cost:'.center(35), '{:f}'.format(best_value.min()))
68 |
69 |     # time
70 |     print('-' * 60)
71 |     print('Avg used time:'.center(35),
'{:f}s'.format(time_used.mean() / dataset_size)) 72 | print('-' * 60, '\n') 73 | 74 | 75 | def log_to_tb_val( 76 | tb_logger: TbLogger, 77 | time_used: torch.Tensor, 78 | init_value: torch.Tensor, 79 | best_value: torch.Tensor, 80 | reward: torch.Tensor, 81 | costs_history: torch.Tensor, 82 | search_history: torch.Tensor, 83 | batch_size: int, 84 | val_size: int, 85 | dataset_size: int, 86 | T: int, 87 | epoch: Optional[int], 88 | ) -> None: 89 | 90 | tb_logger.log_value('validation/avg_time', time_used.mean() / dataset_size, epoch) 91 | tb_logger.log_value('validation/avg_total_reward', reward.sum(1).mean(), epoch) 92 | tb_logger.log_value('validation/avg_step_reward', reward.mean(), epoch) 93 | 94 | tb_logger.log_value(f'validation/avg_init_cost', init_value.mean(), epoch) 95 | tb_logger.log_value(f'validation/avg_best_cost', best_value.mean(), epoch) 96 | 97 | for per in range(20, 100, 20): 98 | cost_ = costs_history[:, round(T * per / 100)] 99 | tb_logger.log_value(f'validation/avg_.{per}_cost', cost_.mean(), epoch) 100 | 101 | 102 | def log_to_tb_train( 103 | tb_logger: TbLogger, 104 | batch0_depot: torch.Tensor, 105 | agent: Agent, 106 | Reward: torch.Tensor, 107 | ratios: torch.Tensor, 108 | bl_val_detached: torch.Tensor, 109 | total_cost: torch.Tensor, 110 | grad_norms_tuple: Tuple[List[torch.Tensor], List[torch.Tensor]], 111 | reward: List[torch.Tensor], 112 | entropy: torch.Tensor, 113 | approx_kl_divergence: torch.Tensor, 114 | reinforce_loss: torch.Tensor, 115 | baseline_loss: torch.Tensor, 116 | log_likelihood: torch.Tensor, 117 | initial_cost: torch.Tensor, 118 | mini_step: int, 119 | construct_obj: Optional[torch.Tensor], 120 | reinforce_loss_construct: Optional[torch.Tensor], 121 | bl_construct: Optional[torch.Tensor], 122 | baseline_loss_construct: Optional[torch.Tensor], 123 | trust_degree: Optional[torch.Tensor], 124 | ratios_construct: Optional[torch.Tensor], 125 | imitation_loss: Optional[torch.Tensor], 126 | grad_norms_imi_tuple: Optional[Tuple[List[torch.Tensor], List[torch.Tensor]]], 127 | ) -> None: 128 | 129 | tb_logger.log_value( 130 | 'learnrate_pg', agent.optimizer.param_groups[0]['lr'], mini_step 131 | ) 132 | avg_cost = (total_cost).mean().item() 133 | tb_logger.log_value('train/avg_cost', avg_cost, mini_step) 134 | 135 | tb_logger.log_value('train/batch0_depot_x', batch0_depot[0], mini_step) 136 | 137 | if construct_obj is not None: 138 | tb_logger.log_value( 139 | 'train/avg_construct_cost', construct_obj.mean().item(), mini_step 140 | ) 141 | 142 | tb_logger.log_value('train/Target_Return', Reward.mean().item(), mini_step) 143 | tb_logger.log_value('train/ratios', ratios.mean().item(), mini_step) 144 | 145 | if ratios_construct is not None: 146 | tb_logger.log_value( 147 | 'train/ratios_construct', ratios_construct.mean().item(), mini_step 148 | ) 149 | 150 | if trust_degree is not None: 151 | tb_logger.log_value('train/trust_degree', trust_degree.item(), mini_step) 152 | 153 | avg_reward = torch.stack(reward, 0).sum(0).mean().item() 154 | max_reward = torch.stack(reward, 0).max(0)[0].mean().item() 155 | tb_logger.log_value('train/avg_reward', avg_reward, mini_step) 156 | tb_logger.log_value('train/init_cost', initial_cost.mean(), mini_step) 157 | tb_logger.log_value('train/max_reward', max_reward, mini_step) 158 | grad_norms, grad_norms_clipped = grad_norms_tuple 159 | tb_logger.log_value('loss/actor_loss', reinforce_loss.item(), mini_step) 160 | 161 | if reinforce_loss_construct is not None: 162 | tb_logger.log_value( 163 | 'loss/actor_construct_loss', 
reinforce_loss_construct.item(), mini_step
164 |         )
165 |
166 |     if imitation_loss is not None:
167 |         tb_logger.log_value('loss/imitation_loss', imitation_loss.item(), mini_step)
168 |
169 |     tb_logger.log_value('loss/nll', -log_likelihood.mean().item(), mini_step)
170 |     tb_logger.log_value('train/entropy', entropy.mean().item(), mini_step)
171 |     tb_logger.log_value(
172 |         'train/approx_kl_divergence', approx_kl_divergence.item(), mini_step
173 |     )
174 |     tb_logger.log_value('train/bl_val', bl_val_detached.mean().item(), mini_step)
175 |
176 |     if bl_construct is not None:
177 |         tb_logger.log_value('train/bl_construct', bl_construct.mean().item(), mini_step)
178 |
179 |     tb_logger.log_value('grad/actor', grad_norms[0], mini_step)
180 |     tb_logger.log_value('grad_clipped/actor', grad_norms_clipped[0], mini_step)
181 |     tb_logger.log_value('loss/critic_loss', baseline_loss.item(), mini_step)
182 |
183 |     if baseline_loss_construct is not None:
184 |         tb_logger.log_value(
185 |             'loss/critic_construct_loss', baseline_loss_construct.item(), mini_step
186 |         )
187 |
188 |     tb_logger.log_value(
189 |         'loss/total_loss', (reinforce_loss + baseline_loss).item(), mini_step
190 |     )
191 |
192 |     tb_logger.log_value('grad/critic', grad_norms[1], mini_step)
193 |     tb_logger.log_value('grad_clipped/critic', grad_norms_clipped[1], mini_step)
194 |
195 |     if agent.opts.shared_critic:
196 |         tb_logger.log_value('grad/actor_construct', grad_norms[2], mini_step)
197 |         tb_logger.log_value(
198 |             'grad_clipped/actor_construct', grad_norms_clipped[2], mini_step
199 |         )
200 |         tb_logger.log_value('grad/critic_construct', grad_norms[3], mini_step)
201 |         tb_logger.log_value(
202 |             'grad_clipped/critic_construct', grad_norms_clipped[3], mini_step
203 |         )
204 |         if grad_norms_imi_tuple is not None:
205 |             grad_norms_imi, grad_norms_imi_clipped = grad_norms_imi_tuple
206 |             tb_logger.log_value(
207 |                 'grad/actor_construct_imitation', grad_norms_imi[0], mini_step
208 |             )
209 |             tb_logger.log_value(
210 |                 'grad_clipped/actor_construct_imitation',
211 |                 grad_norms_imi_clipped[0],
212 |                 mini_step,
213 |             )
214 |
--------------------------------------------------------------------------------
/problems/problem_pdtspl.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 |
4 | from .problem_pdp import PDP
5 |
6 |
7 | class PDTSPL(PDP):
8 |     def __init__(
9 |         self, size: int, init_val_method: str, check_feasible: bool = False
10 |     ) -> None:
11 |         super().__init__(size, init_val_method, check_feasible)
12 |
13 |         self.name = 'pdtspl'  # Pickup and Delivery TSP with LIFO constraint
14 |
15 |         print(
16 |             f'PDTSP-LIFO with {self.size} nodes.',
17 |             ' Do assert:',
18 |             check_feasible,
19 |         )
20 |
21 |     @staticmethod
22 |     def get_swap_mask(
23 |         selected_node: torch.Tensor,
24 |         visit_index: torch.Tensor,
25 |         top2: torch.Tensor,
26 |     ) -> torch.Tensor:
27 |         visited_order_map = PDP._get_visit_order_map(visit_index)
28 |         batch_size, graph_size_plus1, _ = visited_order_map.size()
29 |         arange = torch.arange(batch_size)
30 |         top = torch.where(top2[:, :, 0] == selected_node, top2[:, :, 1], top2[:, :, 0])
31 |         mask_pd = top.view(-1, graph_size_plus1, 1) != top.view(-1, 1, graph_size_plus1)
32 |
33 |         mask = visited_order_map.clone()  # True means unavailable
34 |         mask[arange, selected_node.view(-1)] = True
35 |         mask[arange, selected_node.view(-1) + graph_size_plus1 // 2] = True
36 |         mask[arange, :, selected_node.view(-1)] = True
37 |         mask[arange, :, selected_node.view(-1) +
graph_size_plus1 // 2] = True 38 | 39 | return mask | mask_pd 40 | 41 | def get_initial_solutions(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: 42 | 43 | batch_size = batch['coordinates'].size(0) 44 | 45 | def get_solution(methods: str) -> torch.Tensor: 46 | 47 | half_size = self.size // 2 48 | 49 | if methods == 'random': 50 | candidates = torch.ones(batch_size, self.size + 1).bool() 51 | candidates[:, half_size + 1 :] = 0 52 | solution = torch.zeros(batch_size, self.size + 1).long() 53 | selected_node = torch.zeros(batch_size, 1).long() 54 | candidates.scatter_(1, selected_node, 0) 55 | stacks = ( 56 | torch.zeros(batch_size, half_size + 1) - 0.01 57 | ) # fix bug: max is not stable sorting 58 | stacks[:, 0] = 0 # fix bug: max is not stable sorting 59 | for i in range(self.size): 60 | index1 = (selected_node <= half_size) & (selected_node > 0) 61 | if index1.any(): 62 | stacks[index1.view(-1), selected_node[index1]] = i + 1 63 | top = stacks.max(-1)[1] 64 | 65 | dists = torch.ones(batch_size, self.size + 1) 66 | dists[~candidates] = -1e20 67 | dists[top > 0, top[top > 0] + half_size] = 1 68 | dists = torch.softmax(dists, -1) 69 | next_selected_node = dists.multinomial(1).view(-1, 1) 70 | index2 = (next_selected_node > half_size) & (next_selected_node > 0) 71 | if index2.any(): 72 | stacks[ 73 | index2.view(-1), next_selected_node[index2] - half_size 74 | ] = -0.01 75 | 76 | solution.scatter_(1, selected_node, next_selected_node) 77 | candidates.scatter_(1, next_selected_node, 0) 78 | selected_node = next_selected_node 79 | 80 | return solution 81 | 82 | elif methods == 'greedy': 83 | 84 | candidates = torch.ones(batch_size, self.size + 1).bool() 85 | candidates[:, half_size + 1 :] = 0 86 | solution = torch.zeros(batch_size, self.size + 1).long() 87 | selected_node = torch.zeros(batch_size, 1).long() 88 | candidates.scatter_(1, selected_node, 0) 89 | stacks = torch.zeros(batch_size, half_size + 1) - 0.01 90 | stacks[:, 0] = 0 # fix bug: max is not stable sorting 91 | for i in range(self.size): 92 | 93 | index1 = (selected_node <= half_size) & (selected_node > 0) 94 | if index1.any(): 95 | stacks[index1.view(-1), selected_node[index1]] = i + 1 96 | top = stacks.max(-1)[1] 97 | 98 | d1 = ( 99 | batch['coordinates'] 100 | .cpu() 101 | .gather( 102 | 1, 103 | selected_node.unsqueeze(-1).expand( 104 | batch_size, self.size + 1, 2 105 | ), 106 | ) 107 | ) 108 | d2 = batch['coordinates'].cpu() 109 | 110 | dists = (d1 - d2).norm(p=2, dim=2) 111 | dists[~candidates] = 1e6 112 | 113 | dists[:, self.size // 2 + 1 :] += 1e3 # mask all delivery 114 | dists[top > 0, top[top > 0] + half_size] -= 1e3 115 | 116 | next_selected_node = dists.min(-1)[1].view(-1, 1) 117 | 118 | index2 = (next_selected_node > half_size) & (next_selected_node > 0) 119 | if index2.any(): 120 | stacks[ 121 | index2.view(-1), next_selected_node[index2] - half_size 122 | ] = -0.01 123 | 124 | add_index = (next_selected_node <= half_size).view(-1) 125 | pairing = ( 126 | next_selected_node[next_selected_node <= half_size].view(-1, 1) 127 | + half_size 128 | ) 129 | candidates[add_index] = candidates[add_index].scatter_( 130 | 1, pairing, 1 131 | ) 132 | 133 | solution.scatter_(1, selected_node, next_selected_node) 134 | candidates.scatter_(1, next_selected_node, 0) 135 | selected_node = next_selected_node 136 | 137 | return solution 138 | 139 | else: 140 | raise NotImplementedError() 141 | 142 | return ( 143 | get_solution(self.init_val_method).expand(batch_size, self.size + 1).clone() 144 | ) 145 | 146 | def 
_check_feasibility(self, solution: torch.Tensor) -> None:
147 |
148 |         size_p1 = self.size + 1
149 |
150 |         assert (
151 |             (torch.arange(size_p1, out=solution.new())).view(1, -1).expand_as(solution)
152 |             == solution.sort(1)[0]
153 |         ).all(), (
154 |             (
155 |                 (torch.arange(size_p1, out=solution.new()))
156 |                 .view(1, -1)
157 |                 .expand_as(solution)
158 |                 == solution.sort(1)[0]
159 |             ),
160 |             "not visiting all nodes",
161 |             solution,
162 |         )
163 |
164 |         # calculate visited time
165 |         batch_size = solution.size(0)
166 |         visited_time = torch.zeros((batch_size, size_p1), device=solution.device)
167 |         stacks = (
168 |             torch.zeros((batch_size, size_p1 // 2), device=solution.device).long()
169 |             - 0.01
170 |         )
171 |         pre = torch.zeros(batch_size, device=solution.device).long()
172 |         arange = torch.arange(batch_size)
173 |         for i in range(size_p1):
174 |             cur = solution[arange, pre]
175 |             visited_time[arange, cur] = i + 1
176 |             pre = cur
177 |             index1 = (cur <= size_p1 // 2) & (cur > 0)
178 |             index2 = (cur > size_p1 // 2) & (cur > 0)
179 |             if index1.any():
180 |                 stacks[index1, cur[index1] - 1] = i + 1
181 |             assert (
182 |                 stacks.max(-1)[1][index2] == (cur[index2] - 1 - size_p1 // 2)
183 |             ).all(), 'pdtspl LIFO error'
184 |             if (index2).any():
185 |                 stacks[index2, cur[index2] - 1 - size_p1 // 2] = -0.01
186 |
187 |         assert ((stacks == -0.01).all())
188 |         assert (
189 |             visited_time[:, 1 : size_p1 // 2 + 1] < visited_time[:, size_p1 // 2 + 1 :]
190 |         ).all(), (
191 |             visited_time[:, 1 : size_p1 // 2 + 1] < visited_time[:, size_p1 // 2 + 1 :],
192 |             "delivering without pickup",
193 |         )
194 |
--------------------------------------------------------------------------------
/nets/actor_network.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
2 | import math
3 | from torch import nn
4 | import torch
5 |
6 | from problems.problem_pdp import PDP
7 |
8 | from .graph_layers import (
9 |     N2SEncoder,
10 |     N2SDecoder,
11 |     EmbeddingNet,
12 |     MHA_Self_Score_WithoutNorm,
13 |     ConstructEncoder,
14 |     ConstructDecoder,
15 |     HeterEmbedding,
16 | )
17 |
18 | if TYPE_CHECKING:
19 |     from agent.agent import Agent
20 |
21 |
22 | class mySequential(nn.Sequential):
23 |
24 |     __call__: Callable[..., Union[Tuple[torch.Tensor], torch.Tensor]]
25 |
26 |     def forward(
27 |         self, *inputs: Union[Tuple[torch.Tensor], torch.Tensor]
28 |     ) -> Union[Tuple[torch.Tensor], torch.Tensor]:
29 |         for module in self._modules.values():
30 |             if type(inputs) == tuple:
31 |                 inputs = module(*inputs)
32 |             else:
33 |                 inputs = module(inputs)
34 |         return inputs  # type: ignore
35 |
36 |
37 | class Actor_N2S(nn.Module):
38 |     def __init__(
39 |         self,
40 |         problem_name: str,
41 |         embedding_dim: int,
42 |         ff_hidden_dim: int,
43 |         n_heads_actor: int,
44 |         n_layers: int,
45 |         normalization: str,
46 |         v_range: float,
47 |         seq_length: int,
48 |         embedding_type: str,
49 |         removal_type: str,
50 |     ) -> None:
51 |         super().__init__()
52 |
53 |         self.embedding_dim = embedding_dim
54 |         self.ff_hidden_dim = ff_hidden_dim
55 |         self.n_heads_actor = n_heads_actor
56 |         self.n_layers = n_layers
57 |         self.normalization = normalization
58 |         self.v_range = v_range
59 |         self.seq_length = seq_length
60 |         self.calc_stacks = bool(problem_name == 'pdtspl')
61 |         self.node_dim = 2
62 |
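        # Architecture note (added for readability, based on the inline comments
        # below): node-feature embeddings (NFEs) from the embedder are refined by
        # n_layers N2S encoder blocks, guided by an auxiliary attention map over
        # positional-feature embeddings (PFEs); the decoder then proposes a
        # pickup-delivery pair to remove and the positions to reinsert it.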
63 |         # networks
64 |         self.embedder = EmbeddingNet(
65 |             self.node_dim, self.embedding_dim, self.seq_length, embedding_type
66 |         )
67 |
68 |         self.pos_emb_encoder = MHA_Self_Score_WithoutNorm(
69 |             self.n_heads_actor, self.embedding_dim
70 |         )  # for PFEs
71 |
72 |         self.encoder = mySequential(
73 |             *(
74 |                 N2SEncoder(
75 |                     self.n_heads_actor,
76 |                     self.embedding_dim,
77 |                     self.ff_hidden_dim,
78 |                     self.normalization,
79 |                 )
80 |                 for _ in range(self.n_layers)
81 |             )
82 |         )  # for NFEs
83 |
84 |         self.decoder = N2SDecoder(
85 |             self.n_heads_actor, self.embedding_dim, self.v_range, removal_type
86 |         )  # the two proposed decoders
87 |
88 |         print('Actor_N2S:', self.get_parameter_number())
89 |
90 |     def get_parameter_number(self) -> Dict[str, int]:
91 |         total_num = sum(p.numel() for p in self.parameters())
92 |         trainable_num = sum(p.numel() for p in self.parameters() if p.requires_grad)
93 |         return {'Total': total_num, 'Trainable': trainable_num}
94 |
95 |     @staticmethod
96 |     def _get_action_removal_recent(
97 |         action_removal_record: List[torch.Tensor],
98 |     ) -> torch.Tensor:
99 |         action_removal_record_tensor = torch.stack(
100 |             action_removal_record
101 |         )  # (len_action_record, batch_size, graph_size/2)
102 |         return torch.cat(
103 |             (
104 |                 action_removal_record_tensor[-3:].transpose(0, 1),
105 |                 action_removal_record_tensor.mean(0).unsqueeze(1),
106 |             ),
107 |             1,
108 |         )  # (batch_size, 4, graph_size/2)
109 |
110 |     __call__: Callable[
111 |         ..., Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
112 |     ]
113 |
114 |     def forward(
115 |         self,
116 |         problem: PDP,
117 |         x_in: torch.Tensor,
118 |         solution: torch.Tensor,
119 |         pre_action: Optional[torch.Tensor],
120 |         action_removal_record: List[torch.Tensor],
121 |         fixed_action: Optional[torch.Tensor] = None,
122 |         require_entropy: bool = False,
123 |         to_critic: bool = False,
124 |         only_critic: bool = False,
125 |         only_fea: bool = False,
126 |     ):
127 |         # the embedded input x
128 |         # batch_size, graph_size+1, node_dim = x_in.size()
129 |
130 |         if only_fea:
131 |             h_fea = self.embedder(x_in, None, False)[0]
132 |             return h_fea.detach(), None, None, None
133 |
134 |         h_fea, g_pos, visit_index, top2 = self.embedder(
135 |             x_in, solution, self.calc_stacks
136 |         )
137 |
138 |         if h_fea is None:  # share or together
139 |             h_fea = self.agent.actor_construct(x_in, only_fea=True)[0]
140 |
141 |         # pass through encoder
142 |         aux_att = self.pos_emb_encoder(g_pos)
143 |         h_wave = self.encoder(h_fea, aux_att)[0]
144 |
145 |         if only_critic:
146 |             return h_wave, None, None, None
147 |
148 |         # pass through decoder
149 |         action, log_ll, entropy = self.decoder(
150 |             problem=problem,
151 |             h_wave=h_wave,
152 |             solution=solution,
153 |             x_in=x_in,
154 |             top2=top2,
155 |             visit_index=visit_index,
156 |             pre_action=pre_action,
157 |             selection_recent=Actor_N2S._get_action_removal_recent(
158 |                 action_removal_record
159 |             ).to(x_in.device),
160 |             fixed_action=fixed_action,
161 |             require_entropy=require_entropy,
162 |         )
163 |
164 |         return (
165 |             action,
166 |             log_ll.squeeze(),
167 |             h_wave if to_critic else None,
168 |             entropy if require_entropy else None,
169 |         )
170 |
171 |     def hook_agent(self, agent: 'Agent'):
172 |         self.agent = agent
173 |
174 |
175 | class Actor_Construct(nn.Module):  # Kool suggests 8 heads and 128 dim
176 |     def __init__(
177 |         self,
178 |         problem_name: str,
179 |         embedding_dim: int,
180 |         n_heads: int,
181 |         n_layers: int,
182 |         normalization: str,
183 |         type_select: str,
184 |         embedding_type: str,
185 |         attn_type: str,
186 |     ) -> None:
187 |         super().__init__()
188 |
189 |         self.stack_is_lifo = bool(problem_name == 'pdtspl')
190 |
191 |         self.together = embedding_type == 'together'
192 |
193 |         # self.embedder = nn.Linear(2, embedding_dim,
bias=False) 194 | if embedding_type == 'pair' or embedding_type == 'together': 195 | self.embedder = HeterEmbedding(2, embedding_dim) 196 | elif embedding_type == 'share': 197 | self.embedder = None # type: ignore 198 | else: 199 | raise NotImplementedError 200 | 201 | self.encoder = nn.Sequential( 202 | *( 203 | ConstructEncoder(n_heads, embedding_dim, normalization, attn_type) 204 | for _ in range(n_layers) 205 | ) 206 | ) 207 | self.decoder = ConstructDecoder( 208 | n_heads, embedding_dim, self.stack_is_lifo, type_select 209 | ) 210 | 211 | self.init_parameters() 212 | 213 | print('Actor_SC:', self.get_parameter_number()) 214 | 215 | def init_parameters(self) -> None: 216 | for param in self.parameters(): 217 | stdv = 1.0 / math.sqrt(param.size(-1)) 218 | param.data.uniform_(-stdv, stdv) 219 | 220 | __call__: Callable[..., Tuple[torch.Tensor, torch.Tensor]] 221 | 222 | def forward( 223 | self, 224 | x_in: torch.Tensor, 225 | fixed_sol: Optional[torch.Tensor] = None, 226 | temperature: float = 1, 227 | only_fea: bool = False, 228 | ): 229 | 230 | if self.embedder is None: # share 231 | h_fea = self.agent.actor(None, x_in, None, None, None, only_fea=True)[0] 232 | else: 233 | h_fea = self.embedder(x_in) # (batch_size, graph_size+1, embedding_dim) 234 | 235 | if only_fea: 236 | return h_fea if self.together else h_fea.detach(), None 237 | 238 | hN: torch.Tensor = self.encoder(h_fea) 239 | hN_mean = hN.mean(1) 240 | 241 | batch_size, graph_size_plus1, _ = h_fea.size() 242 | 243 | init_sol = ( 244 | torch.arange(graph_size_plus1).repeat((batch_size, 1)).to(h_fea.device) 245 | ) 246 | cur_sol = init_sol.clone() 247 | 248 | stack = ( 249 | torch.zeros((batch_size, graph_size_plus1 // 2 + 1)).to(h_fea.device) - 1 250 | ) 251 | stack[:, 0] = 0 252 | 253 | direct_fixed_sol = ( 254 | PDP.direct_solution(fixed_sol) if fixed_sol is not None else None 255 | ) 256 | 257 | log_ll_list = [] 258 | for step in range(graph_size_plus1 - 1): 259 | cur_sol, log_p = self.decoder( 260 | hN, 261 | hN_mean, 262 | cur_sol, 263 | init_sol, 264 | step, 265 | stack, 266 | direct_fixed_sol, 267 | temperature, 268 | ) 269 | log_ll_list.append(log_p.view(-1)) 270 | 271 | log_ll = torch.stack(log_ll_list, 1).sum(1) # (batch_size,) 272 | 273 | return cur_sol, log_ll 274 | 275 | def get_parameter_number(self) -> Dict[str, int]: 276 | total_num = sum(p.numel() for p in self.parameters()) 277 | trainable_num = sum(p.numel() for p in self.parameters() if p.requires_grad) 278 | return {'Total': total_num, 'Trainable': trainable_num} 279 | 280 | def hook_agent(self, agent: 'Agent'): 281 | self.agent = agent 282 | -------------------------------------------------------------------------------- /problems/problem_pdp.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterable, List, Optional, Tuple 2 | from abc import ABC, abstractmethod 3 | import os 4 | import pickle 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class PDP(ABC): 10 | name: str 11 | 12 | @abstractmethod 13 | def __init__( 14 | self, size: int, init_val_method: str, check_feasible: bool = False 15 | ) -> None: 16 | self.size = size # the number of nodes in PDTSP 17 | self.check_feasible = check_feasible 18 | self.init_val_method = init_val_method 19 | 20 | @staticmethod 21 | @abstractmethod 22 | def get_swap_mask( 23 | selected_node: torch.Tensor, 24 | visit_index: torch.Tensor, 25 | top2: torch.Tensor, 26 | ) -> torch.Tensor: 27 | pass 28 | 29 | @abstractmethod 30 | def 
get_initial_solutions(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
31 |         pass
32 |
33 |     @abstractmethod
34 |     def _check_feasibility(self, solution: torch.Tensor) -> None:
35 |         pass
36 |
37 |     @staticmethod
38 |     def input_coordinates(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
39 |         return batch['coordinates']
40 |
41 |     @staticmethod
42 |     def _get_visit_order_map(visit_index: torch.Tensor) -> torch.Tensor:
43 |         batch_size, graph_size_plus1 = visit_index.size()
44 |         return visit_index.view(batch_size, graph_size_plus1, 1) > visit_index.view(
45 |             batch_size, 1, graph_size_plus1
46 |         )  # True where the row node is visited later than the column node
47 |
48 |     @staticmethod
49 |     def _insert_star(
50 |         solution: torch.Tensor,  # (batch_size, graph_size+1)
51 |         pair_first: torch.Tensor,  # (batch_size, 1)
52 |         first: torch.Tensor,  # (batch_size, 1)
53 |         second: torch.Tensor,  # (batch_size, 1)
54 |     ) -> torch.Tensor:
55 |         solution = solution.clone()  # linked-list representation: solution=[2,0,1] means 0->2->1->0.
56 |         graph_size_plus1 = solution.size(1)
57 |
58 |         assert (
59 |             (pair_first != first).all()
60 |             and (pair_first != second).all()
61 |             and ((pair_first + graph_size_plus1 // 2) != first).all()
62 |             and ((pair_first + graph_size_plus1 // 2) != second).all()
63 |         )
64 |
65 |         # remove pair node
66 |         pre = solution.argsort()  # pre=[1,2,0]
67 |         pre_pair_first = pre.gather(1, pair_first)  # (batch_size, 1)
68 |         post_pair_first = solution.gather(1, pair_first)  # (batch_size, 1)
69 |
70 |         solution.scatter_(1, pre_pair_first, post_pair_first)  # remove pair first
71 |         solution.scatter_(
72 |             1, pair_first, pair_first
73 |         )  # make pair_first point to itself so the next argsort yields correct predecessors
74 |
75 |         pre = solution.argsort()
76 |
77 |         pre_pair_second = pre.gather(1, pair_first + graph_size_plus1 // 2)
78 |         post_pair_second = solution.gather(1, pair_first + graph_size_plus1 // 2)
79 |
80 |         solution.scatter_(1, pre_pair_second, post_pair_second)  # remove pair second
81 |
82 |         # insert pair node
83 |         post_second = solution.gather(1, second)
84 |         solution.scatter_(
85 |             1, second, pair_first + graph_size_plus1 // 2
86 |         )  # second -> pair_second
87 |         solution.scatter_(1, pair_first + graph_size_plus1 // 2, post_second)
88 |
89 |         post_first = solution.gather(1, first)
90 |         solution.scatter_(1, first, pair_first)  # first -> pair_first
91 |         solution.scatter_(1, pair_first, post_first)
92 |
93 |         return solution
94 |
95 |     @staticmethod
96 |     def make_dataset(
97 |         filename: Optional[str] = None,
98 |         size: int = 20,
99 |         num_samples: int = 10000,
100 |         offset: int = 0,
101 |         silence: bool = False,
102 |     ) -> 'PDPDataset':
103 |         return PDPDataset(filename, size, num_samples, offset, silence)
104 |
105 |     def get_costs(
106 |         self,
107 |         batch_feature: torch.Tensor,
108 |         solution: torch.Tensor,
109 |         zoom: bool = False,
110 |     ) -> torch.Tensor:
111 |         batch_size, graph_size_plus1 = solution.size()
112 |
113 |         # check feasibility
114 |         if self.check_feasible:
115 |             self._check_feasibility(solution)
116 |
117 |         # calculate obj value
118 |         d1 = batch_feature.gather(
119 |             1, solution.long().unsqueeze(-1).expand(batch_size, graph_size_plus1, 2)
120 |         )
121 |         d2 = batch_feature
122 |         if not zoom:
123 |             length = (d1 - d2).norm(p=2, dim=2).sum(1)  # (batch_size,)
124 |         else:
125 |             length = torch.round((d1 - d2).norm(p=2, dim=2)).sum(1)
126 |
127 |         return length
128 |
129 |     def step(
130 |         self,
131 |         batch: Dict[
132 |             str, torch.Tensor
133 |         ],  # ['coordinates']: (batch_size, graph_size+1, 2)
134 |         solution: torch.Tensor,  # (batch_size, graph_size+1)
135 |         action:
torch.Tensor, # (batch_size, 3) 136 | pre_best_obj: torch.Tensor, # (batch_size, 2) or (batch_size,) 137 | action_removal_record: List[torch.Tensor], # len * (batch_size, graph_size/2) 138 | best_sol: Optional[torch.Tensor] = None, # (batch_size, graph_size+1) 139 | zoom: bool = False, 140 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor]]: 141 | batch_size = solution.size(0) 142 | pre_best_obj = pre_best_obj.view(batch_size, -1) 143 | 144 | action_removal_record = action_removal_record.copy() 145 | cur_vec = torch.zeros_like(action_removal_record.pop(0)) # new tensor 146 | cur_vec[torch.arange(batch_size), action[:, 0]] = 1 147 | action_removal_record.append(cur_vec) 148 | 149 | selected_minus1 = action[:, 0].view(batch_size, 1) 150 | first = action[:, 1].view(batch_size, 1) 151 | second = action[:, 2].view(batch_size, 1) 152 | 153 | next_state = PDP._insert_star(solution, selected_minus1 + 1, first, second) 154 | 155 | new_obj = self.get_costs(batch['coordinates'], next_state, zoom) 156 | 157 | now_best_obj, from_which = torch.min( 158 | torch.cat((new_obj[:, None], pre_best_obj[:, -1, None]), -1), -1 159 | ) 160 | 161 | if best_sol is not None: 162 | choose_from = torch.cat((next_state.unsqueeze(1), best_sol.unsqueeze(1)), 1) 163 | best_sol[torch.arange(batch_size), :] = choose_from[ 164 | torch.arange(batch_size), from_which, : 165 | ] 166 | 167 | reward = pre_best_obj[:, -1] - now_best_obj # (batch_size,) 168 | 169 | return ( 170 | next_state, 171 | reward, 172 | torch.cat((new_obj[:, None], now_best_obj[:, None]), -1), 173 | action_removal_record, 174 | ) 175 | 176 | @staticmethod 177 | def direct_solution(solution: torch.Tensor) -> torch.Tensor: 178 | batch_size, seq_length = solution.size() 179 | arange = torch.arange(batch_size) 180 | visit_index = torch.zeros((batch_size, seq_length), device=solution.device) 181 | pre = torch.zeros((batch_size), device=solution.device).long() 182 | 183 | for i in range(seq_length): 184 | current_nodes = solution[arange, pre] # (batch_size,) 185 | visit_index[arange, current_nodes] = i + 1 186 | pre = current_nodes 187 | 188 | visit_index = (visit_index % seq_length).long() 189 | return visit_index.argsort() 190 | 191 | 192 | class PDPDataset(Dataset): 193 | def __init__( 194 | self, 195 | filename: Optional[str], 196 | size: int, 197 | num_samples: int, 198 | offset: int, 199 | silence: bool, 200 | ): 201 | super().__init__() 202 | 203 | self.data: List[Dict[str, torch.Tensor]] = [] 204 | self.size = size 205 | 206 | if filename is not None: 207 | assert os.path.splitext(filename)[1] == '.pkl', 'file name error' 208 | 209 | with open(filename, 'rb') as f: 210 | data = pickle.load(f) 211 | self.data = [ 212 | PDPDataset._make_instance(args) 213 | for args in data[offset : offset + num_samples] 214 | ] 215 | 216 | else: 217 | self.data = [ 218 | { 219 | 'loc': torch.FloatTensor(self.size, 2).uniform_(0, 1), 220 | 'depot': torch.FloatTensor(2).uniform_(0, 1), 221 | } 222 | for _ in range(num_samples) 223 | ] 224 | 225 | self.N = len(self.data) 226 | 227 | for i, instance in enumerate(self.data): 228 | self.data[i]['coordinates'] = torch.cat( 229 | (instance['depot'].reshape(1, 2), instance['loc']), dim=0 230 | ) 231 | del self.data[i]['depot'] 232 | del self.data[i]['loc'] 233 | if not silence: 234 | print(f'{self.N} instances initialized.') 235 | 236 | def __len__(self) -> int: 237 | return self.N 238 | 239 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 240 | return self.data[idx] 241 | 242 | @staticmethod 243 
| def _make_instance(args: Iterable) -> Dict[str, torch.Tensor]: 244 | depot, loc, *args = args 245 | grid_size = 1 246 | if len(args) > 0: 247 | depot_types, customer_types, grid_size = args 248 | return { 249 | 'loc': torch.tensor(loc, dtype=torch.float) / grid_size, 250 | 'depot': torch.tensor(depot, dtype=torch.float) / grid_size, 251 | } 252 | 253 | @staticmethod 254 | def calculate_distance(data: torch.Tensor) -> torch.Tensor: 255 | N_data = data.shape[0] 256 | dists = torch.zeros((N_data, N_data), dtype=torch.float) 257 | d1 = -2 * torch.mm(data, data.T) 258 | d2 = torch.sum(torch.pow(data, 2), dim=1) 259 | d3 = torch.sum(torch.pow(data, 2), dim=1).reshape(1, -1).T 260 | dists = d1 + d2 + d3 261 | dists[dists < 0] = 0 262 | return torch.sqrt(dists) 263 | -------------------------------------------------------------------------------- /agent/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple 2 | import time 3 | import torch 4 | import os 5 | import random 6 | from tqdm import tqdm 7 | import torch.distributed as dist 8 | from torch.utils.data import DataLoader 9 | from tensorboard_logger import Logger as TbLogger 10 | 11 | from problems.problem_pdp import PDP 12 | from utils.logger import log_to_screen, log_to_tb_val 13 | from utils import rotate_tensor, move_to 14 | 15 | from .agent import Agent 16 | 17 | 18 | def gather_tensor_and_concat(tensor: torch.Tensor) -> torch.Tensor: 19 | gather_t = [torch.ones_like(tensor) for _ in range(dist.get_world_size())] 20 | dist.all_gather(gather_t, tensor) 21 | return torch.cat(gather_t) 22 | 23 | 24 | def validate( 25 | rank: int, 26 | problem: PDP, 27 | agent: Agent, 28 | val_dataset_str: Optional[str] = None, 29 | tb_logger: Optional[TbLogger] = None, 30 | distributed: bool = False, 31 | id_: Optional[int] = None, 32 | mem_test: bool = False, 33 | zoom: bool = False, 34 | ) -> None: 35 | # Validate mode 36 | if rank == 0 and not mem_test: 37 | print('\nValidating...', flush=True) 38 | torch.backends.cudnn.deterministic = True 39 | torch.backends.cudnn.benchmark = False 40 | opts = agent.opts 41 | agent.eval() 42 | 43 | random_state_backup = ( 44 | torch.get_rng_state(), 45 | torch.cuda.get_rng_state(), 46 | random.getstate(), 47 | ) 48 | 49 | torch.manual_seed(opts.seed) 50 | random.seed(opts.seed) 51 | 52 | val_dataset = PDP.make_dataset( 53 | size=opts.graph_size, 54 | num_samples=opts.val_size, 55 | filename=val_dataset_str, 56 | silence=mem_test, 57 | ) 58 | 59 | if distributed: 60 | device = torch.device("cuda", rank) 61 | torch.distributed.init_process_group( 62 | backend='nccl', world_size=opts.world_size, rank=rank 63 | ) 64 | torch.cuda.set_device(rank) 65 | agent.actor.to(device) 66 | if opts.normalization == 'batch': 67 | agent.actor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(agent.actor).to( 68 | device 69 | ) # type: ignore 70 | agent.actor = torch.nn.parallel.DistributedDataParallel( 71 | agent.actor, device_ids=[rank] 72 | ) # type: ignore 73 | if opts.shared_critic: 74 | agent.actor_construct.to(device) 75 | if opts.sc_normalization == 'batch': 76 | agent.actor_construct = torch.nn.SyncBatchNorm.convert_sync_batchnorm( 77 | agent.actor_construct 78 | ).to( 79 | device 80 | ) # type: ignore 81 | agent.actor_construct = torch.nn.parallel.DistributedDataParallel( 82 | agent.actor_construct, device_ids=[rank] 83 | ) # type: ignore 84 | if not opts.no_tb and rank == 0 and not mem_test: 85 | tb_logger = TbLogger( 86 | os.path.join( 87 | 
opts.log_dir, 88 | "{}_{}".format(opts.problem, opts.graph_size), 89 | opts.run_name, 90 | ) 91 | ) 92 | 93 | assert opts.val_batch_size % opts.world_size == 0 94 | train_sampler = torch.utils.data.distributed.DistributedSampler( 95 | val_dataset, shuffle=False 96 | ) # type: ignore 97 | val_dataloader = DataLoader( 98 | val_dataset, 99 | batch_size=opts.val_batch_size // opts.world_size, 100 | shuffle=False, 101 | num_workers=0, 102 | pin_memory=True, 103 | sampler=train_sampler, 104 | ) 105 | else: 106 | val_dataloader = DataLoader( 107 | val_dataset, 108 | batch_size=opts.val_batch_size, 109 | shuffle=False, 110 | num_workers=0, 111 | pin_memory=True, 112 | ) 113 | 114 | s_time = time.time() 115 | bv_list = [] 116 | cost_hist_list = [] 117 | best_hist_list = [] 118 | r_list = [] 119 | for batch in tqdm( 120 | val_dataloader, 121 | desc='inference', 122 | bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}', 123 | disable=mem_test or (rank != 0), 124 | ): 125 | bv_, cost_hist_, best_hist_, r_ = agent.rollout( 126 | problem, opts.val_m, batch, show_bar=(rank == 0 and not mem_test), zoom=zoom 127 | ) 128 | bv_list.append(bv_) 129 | cost_hist_list.append(cost_hist_) 130 | best_hist_list.append(best_hist_) 131 | r_list.append(r_) 132 | 133 | if mem_test: 134 | break 135 | bv = torch.cat(bv_list, 0) 136 | cost_hist = torch.cat(cost_hist_list, 0) 137 | best_hist = torch.cat(best_hist_list, 0) 138 | r = torch.cat(r_list, 0) 139 | 140 | if distributed: 141 | dist.barrier() 142 | 143 | initial_cost = gather_tensor_and_concat(cost_hist[:, 0].contiguous()) 144 | time_used = gather_tensor_and_concat( 145 | torch.tensor([time.time() - s_time]).cuda() 146 | ) 147 | bv = gather_tensor_and_concat(bv.contiguous()) 148 | costs_history = gather_tensor_and_concat(cost_hist.contiguous()) 149 | search_history = gather_tensor_and_concat(best_hist.contiguous()) 150 | reward = gather_tensor_and_concat(r.contiguous()) 151 | 152 | dist.barrier() 153 | else: 154 | initial_cost = cost_hist[:, 0] # bs 155 | time_used = torch.tensor([time.time() - s_time]) # bs 156 | bv = bv 157 | costs_history = cost_hist 158 | search_history = best_hist 159 | reward = r 160 | 161 | # save costs_history and search_history 162 | if opts.save_infer_dir: 163 | torch.save( 164 | {'current': costs_history, 'best': search_history, 'option': vars(opts)}, 165 | opts.save_infer_dir + '/infer_' + opts.run_name + '.pt', 166 | ) 167 | 168 | # log to screen 169 | if rank == 0 and not mem_test: 170 | log_to_screen( 171 | time_used, 172 | initial_cost, 173 | bv, 174 | reward, 175 | costs_history, 176 | search_history, 177 | batch_size=opts.val_size, 178 | dataset_size=len(val_dataset), 179 | T=opts.T_max, 180 | ) 181 | 182 | # log to tb 183 | if (not opts.no_tb) and rank == 0 and not mem_test: 184 | log_to_tb_val( 185 | tb_logger, 186 | time_used, 187 | initial_cost, 188 | bv, 189 | reward, 190 | costs_history, 191 | search_history, 192 | batch_size=opts.val_size, 193 | val_size=opts.val_size, 194 | dataset_size=len(val_dataset), 195 | T=opts.T_max, 196 | epoch=id_, 197 | ) 198 | 199 | torch.set_rng_state(random_state_backup[0]) 200 | torch.cuda.set_rng_state(random_state_backup[1]) 201 | random.setstate(random_state_backup[2]) 202 | 203 | if distributed: 204 | dist.barrier() 205 | 206 | 207 | def batch_augments( 208 | val_m: int, 209 | batch: Dict[str, torch.Tensor], 210 | graph_size_plus1: Optional[int] = None, 211 | node_dim: Optional[int] = None, 212 | one_is_keep: bool = True, 213 | ) -> None: 214 | batch['coordinates'] = 
batch['coordinates'].unsqueeze(1).repeat(1, val_m, 1, 1) 215 | augments = ['Rotate', 'Flip_x-y', 'Flip_x_cor', 'Flip_y_cor'] 216 | 217 | if val_m > 1 or (not one_is_keep and val_m == 1): 218 | for i in range(val_m): 219 | random.shuffle(augments) 220 | id_ = torch.rand(4) 221 | for aug in augments: 222 | if aug == 'Rotate': 223 | batch['coordinates'][:, i] = rotate_tensor( 224 | batch['coordinates'][:, i], int(id_[0] * 4 + 1) * 90 225 | ) 226 | elif aug == 'Flip_x-y': 227 | if int(id_[1] * 2 + 1) == 1: 228 | data = batch['coordinates'][:, i].clone() 229 | batch['coordinates'][:, i, :, 0] = data[:, :, 1] 230 | batch['coordinates'][:, i, :, 1] = data[:, :, 0] 231 | elif aug == 'Flip_x_cor': 232 | if int(id_[2] * 2 + 1) == 1: 233 | batch['coordinates'][:, i, :, 0] = ( 234 | 1 - batch['coordinates'][:, i, :, 0] 235 | ) 236 | elif aug == 'Flip_y_cor': 237 | if int(id_[3] * 2 + 1) == 1: 238 | batch['coordinates'][:, i, :, 1] = ( 239 | 1 - batch['coordinates'][:, i, :, 1] 240 | ) 241 | 242 | if graph_size_plus1 is not None and node_dim is not None: 243 | batch['coordinates'] = batch['coordinates'].view(-1, graph_size_plus1, node_dim) 244 | 245 | 246 | def mem_test(agent: Agent, problem: PDP, batch: Dict[str, torch.Tensor]) -> None: 247 | random_state_backup = ( 248 | torch.get_rng_state(), 249 | torch.cuda.get_rng_state(), 250 | random.getstate(), 251 | ) 252 | 253 | opts = agent.opts 254 | 255 | batch = ( 256 | move_to(batch, 0) if opts.distributed else move_to(batch, opts.device) 257 | ) # batch_size, graph_size+1, 2 258 | batch_feature: torch.Tensor = ( 259 | move_to(PDP.input_coordinates(batch), 0) 260 | if opts.distributed 261 | else move_to(PDP.input_coordinates(batch), opts.device) 262 | ) 263 | _, graph_size_plus1, node_dim = batch_feature.size() 264 | 265 | agent.eval() 266 | 267 | if opts.shared_critic and not opts.no_sample_init: 268 | print('testing memory restriction for init construct sample...', end=' ') 269 | train_sample_size = min(opts.max_init_sample_size, opts.max_init_sample_batch) 270 | 271 | ms_batch_feature = batch_feature.unsqueeze(1).repeat(1, train_sample_size, 1, 1) 272 | ms_batch_feature = ms_batch_feature.view(-1, graph_size_plus1, node_dim) 273 | 274 | agent.actor_construct(ms_batch_feature) 275 | 276 | print('pass') 277 | 278 | print('testing memory restriction for validate...', end=' ') 279 | 280 | opts_backup = opts.T_max, opts.inference_sample_size 281 | opts.T_max = 1 282 | opts.inference_sample_size = min( 283 | opts.inference_sample_size, opts.inference_sample_batch 284 | ) 285 | validate(0, problem, agent, mem_test=True) 286 | opts.T_max, opts.inference_sample_size = opts_backup 287 | 288 | print('pass') 289 | 290 | torch.set_rng_state(random_state_backup[0]) 291 | torch.cuda.set_rng_state(random_state_backup[1]) 292 | random.setstate(random_state_backup[2]) 293 | 294 | 295 | def zoom_feature(feature: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 296 | max_c = feature.max(1)[0] 297 | min_c = feature.min(1)[0] 298 | x_gap = max_c[:, 0] - min_c[:, 0] # (10,) 299 | y_gap = max_c[:, 1] - min_c[:, 1] # (10,) 300 | xy_gap = torch.cat([x_gap[None, :], y_gap[None, :]]) # (2,10) 301 | gap = xy_gap.max(0)[0] # (10,) 302 | new_fea = (feature - min_c[:, None, :]) / gap[:, None, None] 303 | return new_fea, gap 304 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | import os 3 | import time 4 | import 
argparse 5 | import math 6 | import torch 7 | 8 | 9 | class Option(argparse.Namespace): 10 | # shared critic 11 | shared_critic: bool 12 | 13 | sc_decoder_select_type: str 14 | sc_rl_train_type: str 15 | sc_map_sample_type: str 16 | sc_normalization: str 17 | no_sample_init: bool 18 | 19 | lr_trust_degree: float 20 | lr_construct: float 21 | max_grad_norm_construct: float 22 | sc_start_train_epoch: int 23 | 24 | inference_sample_size: int 25 | inference_sample_batch: int 26 | inference_temperature: List[float] 27 | 28 | temperature_init: float 29 | temperature_decay: float 30 | max_init_sample_size: int 31 | max_init_sample_batch: int 32 | init_sample_increase: float 33 | max_warm_up: int 34 | 35 | imitation_rate: float 36 | imitation_max_grad_norm: float 37 | imitation_step: int 38 | imitation_max_augment: int 39 | imitation_increase_w: float 40 | imitation_increase_b: float 41 | 42 | sc_attn_type: str 43 | embed_type_n2s: str 44 | embed_type_sc: str 45 | removal_type: str 46 | warm_up_type: str 47 | 48 | load_original_n2s: Optional[str] 49 | save_infer_dir: str 50 | zoom: bool 51 | # dynamic 52 | start_init_sample_epoch: int 53 | sc_map_sample_times: int 54 | cur_temperature: float 55 | cur_init_sample_size: int 56 | cur_imitation_augment: int 57 | start_warm_up_epoch: int 58 | 59 | # overall settings 60 | problem: str 61 | graph_size: int 62 | init_val_method: str 63 | no_cuda: bool 64 | no_tb: bool 65 | no_saving: bool 66 | use_assert: bool 67 | no_DDP: bool 68 | seed: int 69 | DDP_port_offset: int 70 | 71 | # N2S parameters 72 | v_range: float 73 | actor_head_num: int 74 | critic_head_num: int 75 | embedding_dim: int 76 | ff_hidden_dim: int 77 | n_encode_layers: int 78 | normalization: str 79 | 80 | # Training parameters 81 | RL_agent: str 82 | gamma: float 83 | K_epochs: int 84 | eps_clip: float 85 | T_train: int 86 | n_step: int 87 | warm_up: float 88 | batch_size: int 89 | epoch_end: int 90 | epoch_size: int 91 | lr_model: float 92 | lr_critic: float 93 | lr_decay: float 94 | max_grad_norm: float 95 | 96 | # Inference and validation parameters 97 | T_max: int 98 | eval_only: bool 99 | val_size: int 100 | val_batch_size: int 101 | val_dataset: Optional[str] 102 | val_m: int 103 | 104 | # resume and load models 105 | load_path: Optional[str] 106 | resume: Optional[str] 107 | epoch_start: int 108 | 109 | # logs/output settings 110 | no_progress_bar: bool 111 | log_dir: str 112 | log_step: int 113 | output_dir: str 114 | run_name: str 115 | checkpoint_epochs: int 116 | 117 | # add later 118 | world_size: int 119 | distributed: bool 120 | use_cuda: bool 121 | save_dir: str 122 | device: torch.device 123 | 124 | 125 | def get_options(args: Optional[List[str]] = None) -> Option: 126 | parser = argparse.ArgumentParser(description="Neural Neighborhood Search") 127 | 128 | # shared critic 129 | parser.add_argument( 130 | '--shared_critic', action='store_true', help='enable shared critic mechanism' 131 | ) 132 | parser.add_argument( 133 | '--sc_decoder_select_type', 134 | default='sample', 135 | choices=('sample', 'greedy'), 136 | help='next node select type for sc actor', 137 | ) 138 | parser.add_argument( 139 | '--sc_rl_train_type', 140 | default='ppo', 141 | choices=('ppo', 'pg'), 142 | help='RL Training algorithm for sc actor', 143 | ) 144 | parser.add_argument( 145 | '--sc_map_sample_type', 146 | default='augment', 147 | choices=('origin', 'augment'), 148 | help='RL sample type for sc actor', 149 | ) 150 | parser.add_argument( 151 | '--sc_normalization', 152 | default='layer', 153 | 
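# `Option` is a typed argparse.Namespace: get_options() below parses directly
# into an Option instance via parser.parse_args(args, namespace=opts), so every
# option is a statically typed attribute. Minimal sketch of the pattern
# (illustrative):
#
#     opts = Option()
#     parser.parse_args(['--graph_size', '50'], namespace=opts)
#     opts.graph_size  # -> 50, known to type checkers as int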
help="normalization type for sc actor, 'layer' (default) or 'batch'",
154 | )
155 | parser.add_argument(
156 | '--no_sample_init', action='store_true', help='sc use random init'
157 | )
158 | parser.add_argument(
159 | '--lr_trust_degree',
160 | type=float,
161 | default=0.01,
162 | help='learning rate for trust degree',
163 | )
164 | parser.add_argument(
165 | '--lr_construct',
166 | type=float,
167 | default=1e-4,
168 | help='learning rate for actor-construct',
169 | )
170 | parser.add_argument(
171 | '--max_grad_norm_construct',
172 | type=float,
173 | default=1,
174 | help='maximum L2 norm for gradient clipping of actor-construct',
175 | )
176 | parser.add_argument(
177 | '--sc_start_train_epoch',
178 | type=int,
179 | default=0,
180 | help='epoch at which the sc actor starts training',
181 | )
182 | parser.add_argument(
183 | '--inference_sample_size',
184 | type=int,
185 | default=128,
186 | help='actor-construct total sample size during inference',
187 | )
188 | parser.add_argument(
189 | '--inference_sample_batch',
190 | type=int,
191 | default=1,
192 | help='actor-construct sample size per batch during inference',
193 | )
194 | parser.add_argument(
195 | '--inference_temperature',
196 | type=float,
197 | nargs='*',
198 | default=[1],
199 | help='controls the diversity of the sampled tours during inference',
200 | )
201 | parser.add_argument(
202 | '--temperature_init',
203 | type=float,
204 | default=10,
205 | help='part of adaptive init solution generation when training',
206 | )
207 | parser.add_argument(
208 | '--temperature_decay',
209 | type=float,
210 | default=-1, # variable default
211 | help='part of adaptive init solution generation when training',
212 | )
213 | parser.add_argument(
214 | '--max_init_sample_size',
215 | type=int,
216 | default=-1, # variable default
217 | help='part of adaptive init solution generation when training',
218 | )
219 | parser.add_argument(
220 | '--max_init_sample_batch',
221 | type=int,
222 | default=128,
223 | help='part of adaptive init solution generation when training',
224 | )
225 | parser.add_argument(
226 | '--init_sample_increase',
227 | type=float,
228 | default=-1, # variable default
229 | help='part of adaptive init solution generation when training',
230 | )
231 | parser.add_argument(
232 | '--max_warm_up',
233 | type=int,
234 | default=-1, # variable default
235 | help='maximum number of warm-up steps',
236 | )
237 | parser.add_argument(
238 | '--imitation_rate',
239 | type=float,
240 | default=1,
241 | help='tunable parameter for imitation learning',
242 | )
243 | parser.add_argument(
244 | '--imitation_max_grad_norm',
245 | type=float,
246 | default=-1, # variable default
247 | help='maximum L2 norm for gradient clipping of imitation',
248 | )
249 | parser.add_argument(
250 | '--imitation_step',
251 | type=int,
252 | default=250,
253 | help='run the imitation loss when t %% imitation_step == 0; set <= 0 to disable',
254 | )
255 | parser.add_argument(
256 | '--imitation_max_augment',
257 | type=int,
258 | default=-1, # variable default
259 | help='max times of instance augmentation when performing imitation learning',
260 | )
261 | parser.add_argument(
262 | '--imitation_increase_w',
263 | type=float,
264 | default=1,
265 | help='increase speed of instance augmentation when performing imitation learning',
266 | )
267 | parser.add_argument(
268 | '--imitation_increase_b',
269 | type=float,
270 | default=-0.01, # variable default
271 | help='increase speed of instance augmentation when performing imitation learning',
272 | )
273 | parser.add_argument(
274 | '--sc_attn_type',
275 |
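# The two temperature options above drive the adaptive init-solution schedule
# in train() (agent/ppo.py): cur_temperature = max(1, temperature_init *
# temperature_decay ** epoch). E.g. with temperature_init = 10 and the
# resolved decay 0.95 (the graph-size-20 default):
#
#     max(1, 10 * 0.95 ** 0)    # -> 10.0 at epoch 0
#     max(1, 10 * 0.95 ** 20)   # -> ~3.58
#     max(1, 10 * 0.95 ** 45)   # -> 1; init-sample-size growth takes over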
default='typical',
276 | choices=('typical', 'heter'),
277 | help='MHA type for ConstructEncoder',
278 | )
279 | parser.add_argument(
280 | '--embed_type_n2s',
281 | default='origin',
282 | choices=('origin', 'pair', 'share', 'sep'),
283 | help='n2s actor graph embedding type',
284 | )
285 | parser.add_argument(
286 | '--embed_type_sc',
287 | default='pair',
288 | choices=('pair', 'share'),
289 | help='construction actor graph embedding type',
290 | )
291 | parser.add_argument(
292 | '--removal_type',
293 | default='glitch',
294 | choices=(
295 | 'origin',
296 | 'glitch',
297 | 'update1',
298 | 'update2',
299 | ),
300 | help='N2S NodePairRemovalDecoder type',
301 | )
302 | parser.add_argument(
303 | '--warm_up_type',
304 | default='origin', # variable default
305 | choices=('origin', 'update'),
306 | help='N2S curriculum learning type',
307 | )
308 | parser.add_argument('--load_original_n2s', help='path to load original N2S actor')
309 | parser.add_argument(
310 | '--save_infer_dir', help='directory to save costs_history and search_history'
311 | )
312 | parser.add_argument('--zoom', action='store_true', help='rescale (zoom) coordinates into the unit square before the actor')
313 |
314 | # overall settings
315 | parser.add_argument(
316 | '--problem',
317 | default='pdtsp',
318 | choices=['pdtsp', 'pdtspl'],
319 | help="the targeted problem to solve, default 'pdtsp'",
320 | )
321 | parser.add_argument(
322 | '--graph_size',
323 | type=int,
324 | default=20,
325 | help="the number of customers in the targeted problem (graph size)",
326 | )
327 | parser.add_argument(
328 | '--init_val_method',
329 | choices=['greedy', 'random'],
330 | default='random',
331 | help='method to generate initial solutions for inference',
332 | )
333 | parser.add_argument('--no_cuda', action='store_true', help='disable GPUs')
334 | parser.add_argument(
335 | '--no_tb', action='store_true', help='disable Tensorboard logging'
336 | )
337 | parser.add_argument(
338 | '--no_saving', action='store_true', help='disable saving checkpoints'
339 | )
340 | parser.add_argument('--use_assert', action='store_true', help='enable assertions')
341 | parser.add_argument(
342 | '--no_DDP', action='store_true', help='disable distributed parallel'
343 | )
344 | parser.add_argument('--seed', type=int, default=1234, help='random seed to use')
345 | parser.add_argument(
346 | '--DDP_port_offset',
347 | type=int,
348 | default=0,
349 | help="os.environ['MASTER_PORT'] = 4869 + this_arg",
350 | )
351 |
352 | # N2S parameters
353 | parser.add_argument(
354 | '--v_range', type=float, default=6.0, help='to control the entropy'
355 | )
356 | parser.add_argument(
357 | '--actor_head_num', type=int, default=4, help='head number of N2S actor'
358 | )
359 | parser.add_argument(
360 | '--critic_head_num', type=int, default=4, help='head number of N2S critic'
361 | )
362 | parser.add_argument(
363 | '--embedding_dim',
364 | type=int,
365 | default=128,
366 | help='dimension of input embeddings (NEF & PFE)',
367 | )
368 | parser.add_argument(
369 | '--ff_hidden_dim',
370 | type=int,
371 | default=128,
372 | help='dimension of hidden layers in Enc/Dec',
373 | )
374 | parser.add_argument(
375 | '--n_encode_layers',
376 | type=int,
377 | default=3,
378 | help='number of stacked layers in the encoder',
379 | )
380 | parser.add_argument(
381 | '--normalization',
382 | default='layer',
383 | help="normalization type, 'layer' (default) or 'batch'",
384 | )
385 |
386 | # Training parameters
387 | parser.add_argument(
388 | '--RL_agent', default='ppo', choices=['ppo'], help='RL training algorithm'
389 | )
390 | parser.add_argument( 391
| '--gamma',
392 | type=float,
393 | default=0.999,
394 | help='reward discount factor for future rewards',
395 | )
396 | parser.add_argument('--K_epochs', type=int, default=3, help='number of PPO mini-epochs')
397 | parser.add_argument('--eps_clip', type=float, default=0.1, help='PPO clip ratio')
398 | parser.add_argument(
399 | '--T_train', type=int, default=250, help='number of iterations for training'
400 | )
401 | parser.add_argument(
402 | '--n_step', type=int, default=5, help='n_step for return estimation'
403 | )
404 | parser.add_argument(
405 | '--warm_up',
406 | type=float,
407 | default=-1, # variable default
408 | help=r'hyperparameter of CL scalar $\rho^{CL}$',
409 | )
410 | parser.add_argument(
411 | '--batch_size',
412 | type=int,
413 | default=600,
414 | help='number of instances per batch during training',
415 | )
416 | parser.add_argument(
417 | '--epoch_end', type=int, default=200, help='maximum training epoch'
418 | )
419 | parser.add_argument(
420 | '--epoch_size',
421 | type=int,
422 | default=12000,
423 | help='number of instances per epoch during training',
424 | )
425 | parser.add_argument(
426 | '--lr_model',
427 | type=float,
428 | default=8e-5,
429 | help="learning rate for the actor network",
430 | )
431 | parser.add_argument(
432 | '--lr_critic',
433 | type=float,
434 | default=2e-5,
435 | help="learning rate for the critic network",
436 | )
437 | parser.add_argument(
438 | '--lr_decay', type=float, default=-1, help='learning rate decay per epoch'
439 | ) # variable default
440 | parser.add_argument(
441 | '--max_grad_norm',
442 | type=float,
443 | default=-1, # variable default
444 | help='maximum L2 norm for gradient clipping',
445 | )
446 |
447 | # Inference and validation parameters
448 | parser.add_argument(
449 | '--T_max', type=int, default=1500, help='number of steps for inference'
450 | )
451 | parser.add_argument(
452 | '--eval_only', action='store_true', help='switch to inference mode'
453 | )
454 | parser.add_argument(
455 | '--val_size',
456 | type=int,
457 | default=1000,
458 | help='number of instances for validation/inference',
459 | )
460 | parser.add_argument(
461 | '--val_batch_size',
462 | type=int,
463 | default=1000,
464 | help='number of instances per batch for validation/inference',
465 | )
466 | parser.add_argument(
467 | '--val_dataset',
468 | type=str,
469 | default=None, # variable default
470 | help='dataset file path',
471 | )
472 | parser.add_argument(
473 | '--val_m', type=int, default=1, help='number of data augments in Algorithm 2'
474 | )
475 |
476 | # resume and load models
477 | parser.add_argument(
478 | '--load_path',
479 | default=None,
480 | help='path to load model parameters and optimizer state from',
481 | )
482 | parser.add_argument(
483 | '--resume', default=None, help='resume from previous checkpoint file'
484 | )
485 | parser.add_argument(
486 | '--epoch_start',
487 | type=int,
488 | default=0,
489 | help='start at epoch # (relevant for learning rate decay)',
490 | )
491 |
492 | # logs/output settings
493 | parser.add_argument(
494 | '--no_progress_bar', action='store_true', help='disable progress bar'
495 | )
496 | parser.add_argument(
497 | '--log_dir',
498 | default='logs',
499 | help='directory to write TensorBoard information to',
500 | )
501 | parser.add_argument(
502 | '--log_step',
503 | type=int,
504 | default=50,
505 | help='log info every log_step gradient steps',
506 | )
507 | parser.add_argument(
508 | '--output_dir', default='outputs', help='directory to write output models to'
509 | )
510 |
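# With the training defaults above, one epoch runs
# epoch_size // batch_size = 12000 // 600 = 20 outer batches, each unrolled
# into T_train // n_step = 250 // 5 = 50 n-step segments and optimized for
# K_epochs = 3 PPO mini-epochs (this matches the pbar total in agent/ppo.py):
#
#     3 * (12000 // 600) * (250 // 5)  # = 3000 gradient updates per epoch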
parser.add_argument( 511 | '--run_name', default='run_name', help='name to identify the run' 512 | ) 513 | parser.add_argument( 514 | '--checkpoint_epochs', 515 | type=int, 516 | default=1, 517 | help='save checkpoint every n epochs (default 1), 0 to save no checkpoints', 518 | ) 519 | 520 | opts = Option() 521 | parser.parse_args(args, namespace=opts) 522 | 523 | # variable default 524 | if opts.temperature_decay == -1: 525 | if opts.graph_size == 50: 526 | opts.temperature_decay = 0.94 527 | elif opts.graph_size == 100: 528 | opts.temperature_decay = 0.93 529 | else: 530 | opts.temperature_decay = 0.95 531 | if opts.max_init_sample_size == -1: 532 | if opts.graph_size == 50: 533 | opts.max_init_sample_size = 256 534 | elif opts.graph_size == 100: 535 | opts.max_init_sample_size = 512 536 | else: 537 | opts.max_init_sample_size = 128 538 | if opts.init_sample_increase == -1: 539 | if opts.graph_size == 50: 540 | opts.init_sample_increase = 1.2 541 | elif opts.graph_size == 100: 542 | opts.init_sample_increase = 1.3 543 | else: 544 | opts.init_sample_increase = 1.1 545 | if opts.max_warm_up == -1: 546 | if opts.shared_critic and not opts.no_sample_init: 547 | opts.max_warm_up = 25 548 | else: 549 | opts.max_warm_up = 250 550 | if opts.imitation_max_grad_norm == -1: 551 | if opts.graph_size <= 50: 552 | opts.imitation_max_grad_norm = 0.1 553 | else: 554 | opts.imitation_max_grad_norm = 0.01 555 | if opts.imitation_max_augment == -1: 556 | if opts.graph_size >= 50: 557 | opts.imitation_max_augment = 25 558 | else: 559 | opts.imitation_max_augment = 10 560 | # if opts.imitation_increase_w == -1: 561 | # opts.imitation_increase_w = 1 562 | if opts.imitation_increase_b == -0.01: 563 | if opts.graph_size == 100: 564 | opts.imitation_increase_b = -2 565 | else: 566 | opts.imitation_increase_b = 0 567 | if opts.warm_up_type == 'origin': 568 | if opts.shared_critic and not opts.no_sample_init: 569 | opts.warm_up_type = 'update' 570 | if opts.warm_up == -1: 571 | if opts.shared_critic and not opts.no_sample_init: 572 | opts.warm_up = 0 573 | # if opts.graph_size == 100: 574 | # opts.warm_up = 2 575 | elif opts.graph_size == 50: 576 | opts.warm_up = 1.5 577 | elif opts.graph_size == 100: 578 | opts.warm_up = 1 579 | else: 580 | opts.warm_up = 2 581 | if opts.lr_decay == -1: 582 | if opts.shared_critic and not opts.no_sample_init: 583 | opts.lr_decay = 0.99 584 | else: 585 | opts.lr_decay = 0.985 586 | if opts.max_grad_norm == -1: 587 | if opts.graph_size == 50: 588 | opts.max_grad_norm = 0.15 589 | elif opts.graph_size == 100: 590 | opts.max_grad_norm = 0.3 591 | else: 592 | opts.max_grad_norm = 0.05 593 | if opts.val_dataset is None: 594 | if opts.graph_size == 20: 595 | opts.val_dataset = './datasets/pdp_20.pkl' 596 | elif opts.graph_size == 50: 597 | opts.val_dataset = './datasets/pdp_50.pkl' 598 | elif opts.graph_size == 100: 599 | opts.val_dataset = './datasets/pdp_100.pkl' 600 | 601 | ### figure out whether to use distributed training 602 | opts.world_size = torch.cuda.device_count() 603 | opts.use_cuda = torch.cuda.is_available() and not opts.no_cuda 604 | opts.distributed = ( 605 | opts.use_cuda and (torch.cuda.device_count() > 1) and (not opts.no_DDP) 606 | ) 607 | os.environ['MASTER_ADDR'] = '127.0.0.1' 608 | os.environ['MASTER_PORT'] = str(4869 + opts.DDP_port_offset) 609 | 610 | # assert opts.val_m <= opts.graph_size // 2 611 | assert opts.epoch_size % opts.batch_size == 0 612 | if opts.distributed: 613 | assert opts.batch_size % opts.world_size == 0 614 | 615 | opts.run_name = ( 616 | 
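# The variable defaults above resolve per graph size; e.g. for a 50-node run
# with no other flags (following the branches just shown):
#
#     opts = get_options(['--graph_size', '50'])
#     # opts.temperature_decay    -> 0.94
#     # opts.max_init_sample_size -> 256
#     # opts.init_sample_increase -> 1.2
#     # opts.max_grad_norm        -> 0.15
#     # opts.val_dataset          -> './datasets/pdp_50.pkl'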
"{}_{}".format(opts.run_name, time.strftime("%Y%m%dT%H%M%S")) 617 | if not opts.resume 618 | else opts.resume.split('/')[-2] 619 | ) 620 | opts.save_dir = ( 621 | os.path.join( 622 | opts.output_dir, 623 | "{}_{}".format(opts.problem, opts.graph_size), 624 | opts.run_name, 625 | ) 626 | if not opts.no_saving 627 | else 'no_saving' 628 | ) 629 | 630 | assert opts.temperature_init >= 1 and 0 < opts.temperature_decay < 1 631 | opts.start_init_sample_epoch = math.ceil( 632 | math.log(1 / opts.temperature_init, opts.temperature_decay) 633 | ) 634 | if opts.shared_critic and not opts.no_sample_init: 635 | opts.start_warm_up_epoch = opts.start_init_sample_epoch + math.ceil( 636 | math.log(opts.max_init_sample_size, opts.init_sample_increase) 637 | ) 638 | else: 639 | opts.start_warm_up_epoch = 0 640 | 641 | opts.sc_map_sample_times = math.ceil(opts.T_train / opts.n_step) 642 | 643 | if opts.embed_type_n2s == 'share' and opts.embed_type_sc == 'share': 644 | opts.embed_type_n2s = 'together' 645 | opts.embed_type_sc = 'together' 646 | 647 | return opts 648 | -------------------------------------------------------------------------------- /agent/ppo.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | import os 3 | from tqdm import tqdm 4 | import warnings 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import torch.multiprocessing as mp 8 | import torch.distributed as dist 9 | from tensorboard_logger import Logger as TbLogger 10 | import math 11 | import random 12 | 13 | from utils import clip_grad_norms 14 | from nets.actor_network import Actor_N2S, Actor_Construct 15 | from nets.critic_network import Critic_N2S, Critic_Construct 16 | from utils import torch_load_cpu, get_inner_model, move_to, batch_picker 17 | from utils.logger import log_to_tb_train 18 | from problems.problem_pdp import PDP 19 | from options import Option 20 | 21 | from .agent import Agent 22 | from .utils import validate, batch_augments, mem_test, zoom_feature 23 | 24 | 25 | class Memory: 26 | def __init__(self) -> None: 27 | self.actions: List[torch.Tensor] = [] 28 | self.states: List[torch.Tensor] = [] 29 | self.logprobs: List[torch.Tensor] = [] 30 | self.rewards: List[torch.Tensor] = [] 31 | self.best_obj: List[torch.Tensor] = [] 32 | self.action_removal_record: List[List[torch.Tensor]] = [] 33 | 34 | def clear_memory(self) -> None: 35 | del self.actions[:] 36 | del self.states[:] 37 | del self.logprobs[:] 38 | del self.rewards[:] 39 | del self.best_obj[:] 40 | del self.action_removal_record[:] 41 | 42 | 43 | class PPO(Agent): 44 | def __init__(self, problem_name: str, size: int, opts: Option) -> None: 45 | # figure out the options 46 | self.opts = opts 47 | 48 | # figure out the actor 49 | self.actor = Actor_N2S( 50 | problem_name=problem_name, 51 | embedding_dim=opts.embedding_dim, 52 | ff_hidden_dim=opts.ff_hidden_dim, 53 | n_heads_actor=opts.actor_head_num, 54 | n_layers=opts.n_encode_layers, 55 | normalization=opts.normalization, 56 | v_range=opts.v_range, 57 | seq_length=size + 1, 58 | embedding_type=opts.embed_type_n2s, 59 | removal_type=opts.removal_type, 60 | ) 61 | 62 | if opts.shared_critic: 63 | self.actor_construct = Actor_Construct( 64 | problem_name, 65 | opts.embedding_dim, 66 | 8, 67 | 3, 68 | opts.sc_normalization, 69 | opts.sc_decoder_select_type, 70 | opts.embed_type_sc, 71 | opts.sc_attn_type, 72 | ) 73 | self.critic_construct = Critic_Construct() 74 | 75 | self.actor.hook_agent(self) 76 | 
self.actor_construct.hook_agent(self) 77 | 78 | if not opts.eval_only: 79 | # figure out the critic 80 | self.critic = Critic_N2S( 81 | embedding_dim=opts.embedding_dim, 82 | ff_hidden_dim=opts.ff_hidden_dim, 83 | n_heads=opts.critic_head_num, 84 | n_layers=opts.n_encode_layers, 85 | normalization=opts.normalization, 86 | ) 87 | 88 | # figure out the optimizer 89 | self.optimizer = torch.optim.Adam( 90 | [{'params': self.actor.parameters(), 'lr': opts.lr_model}] 91 | + [{'params': self.critic.parameters(), 'lr': opts.lr_critic}] 92 | ) 93 | if opts.shared_critic: 94 | self.optimizer_sc = torch.optim.Adam( 95 | [ 96 | { 97 | 'params': self.actor_construct.parameters(), 98 | 'lr': opts.lr_construct, 99 | } 100 | ] 101 | + [ 102 | { 103 | 'params': self.critic_construct.parameters(), 104 | 'lr': opts.lr_trust_degree, 105 | } 106 | ] 107 | ) 108 | 109 | self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( 110 | self.optimizer, 111 | opts.lr_decay, 112 | last_epoch=-1, 113 | ) 114 | 115 | print(f'Distributed: {opts.distributed}') 116 | if opts.use_cuda and not opts.distributed: 117 | self.actor.to(opts.device) 118 | if opts.shared_critic: 119 | self.actor_construct.to(opts.device) 120 | if not opts.eval_only: 121 | self.critic.to(opts.device) 122 | if opts.shared_critic: 123 | self.critic_construct.to(opts.device) 124 | 125 | def load(self, load_path: str) -> None: 126 | assert load_path is not None 127 | load_data = torch_load_cpu(load_path) 128 | # load data for actor 129 | model_actor = get_inner_model(self.actor) 130 | if self.opts.load_original_n2s: 131 | print( 132 | ' [*] Loading original N2S data from {}'.format( 133 | self.opts.load_original_n2s 134 | ) 135 | ) 136 | n2s_load_data = torch_load_cpu(self.opts.load_original_n2s) 137 | model_actor.load_state_dict(n2s_load_data['actor']) 138 | else: 139 | model_actor.load_state_dict(load_data['actor']) 140 | if self.opts.shared_critic: 141 | model_actor_cons = get_inner_model(self.actor_construct) 142 | model_actor_cons.load_state_dict(load_data['actor_construct']) 143 | if not self.opts.eval_only: 144 | # load data for critic 145 | model_critic = get_inner_model(self.critic) 146 | model_critic.load_state_dict(load_data['critic']) 147 | if self.opts.shared_critic: 148 | model_critic_cons = get_inner_model(self.critic_construct) 149 | model_critic_cons.load_state_dict(load_data['critic_construct']) 150 | # load data for optimizer 151 | self.optimizer.load_state_dict(load_data['optimizer']) 152 | if self.opts.shared_critic: 153 | self.optimizer_sc.load_state_dict(load_data['optimizer_sc']) 154 | # load data for torch and cuda 155 | torch.set_rng_state(load_data['rng_state']) 156 | if self.opts.use_cuda: 157 | if isinstance(load_data['cuda_rng_state'], torch.Tensor): 158 | torch.cuda.set_rng_state(load_data['cuda_rng_state']) 159 | else: 160 | if len(load_data['cuda_rng_state']) > 1: 161 | torch.cuda.set_rng_state(load_data['cuda_rng_state'][1]) 162 | else: 163 | torch.cuda.set_rng_state(load_data['cuda_rng_state'][0]) 164 | try: 165 | random.setstate(load_data['random_state']) 166 | except KeyError: 167 | print('Error type: random_state') 168 | # done 169 | print(' [*] Loading data from {}'.format(load_path)) 170 | 171 | def save(self, epoch: int) -> None: 172 | print('Saving model and state...') 173 | torch.save( 174 | { 175 | 'actor': get_inner_model(self.actor).state_dict(), 176 | 'critic': get_inner_model(self.critic).state_dict(), 177 | 'optimizer': self.optimizer.state_dict(), 178 | 'optimizer_sc': self.optimizer_sc.state_dict() 
179 | if self.opts.shared_critic 180 | else {}, 181 | 'rng_state': torch.get_rng_state(), 182 | 'cuda_rng_state': torch.cuda.get_rng_state(), 183 | 'actor_construct': get_inner_model(self.actor_construct).state_dict() 184 | if self.opts.shared_critic 185 | else {}, 186 | 'critic_construct': get_inner_model(self.critic_construct).state_dict() 187 | if self.opts.shared_critic 188 | else {}, 189 | 'random_state': random.getstate(), 190 | }, 191 | os.path.join(self.opts.save_dir, 'epoch-{}.pt'.format(epoch)), 192 | ) 193 | 194 | def eval(self) -> None: 195 | torch.set_grad_enabled(False) 196 | self.actor.eval() 197 | if self.opts.shared_critic: 198 | self.actor_construct.eval() 199 | if not self.opts.eval_only: 200 | self.critic.eval() 201 | if self.opts.shared_critic: 202 | self.critic_construct.eval() 203 | 204 | def train(self) -> None: 205 | torch.set_grad_enabled(True) 206 | self.actor.train() 207 | if self.opts.shared_critic: 208 | self.actor_construct.train() 209 | if not self.opts.eval_only: 210 | self.critic.train() 211 | if self.opts.shared_critic: 212 | self.critic_construct.train() 213 | 214 | def rollout( 215 | self, 216 | problem: PDP, 217 | val_m: int, 218 | batch: Dict[str, torch.Tensor], 219 | show_bar: bool, 220 | zoom: bool = False, 221 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 222 | batch = move_to(batch, self.opts.device) 223 | batch_size, graph_size_plus1, node_dim = batch['coordinates'].size() 224 | 225 | batch_augments(val_m, batch, graph_size_plus1, node_dim) 226 | 227 | batch_feature = PDP.input_coordinates( 228 | batch 229 | ) # (new_batch_size, graph_size+1, node_dim) 230 | new_batch_size = batch_feature.size(0) 231 | 232 | if not self.opts.shared_critic: 233 | solution = move_to( 234 | problem.get_initial_solutions(batch), self.opts.device 235 | ).long() 236 | 237 | obj = problem.get_costs(batch_feature, solution, zoom) # (new_batch_size,) 238 | else: 239 | solution_list = [] 240 | obj_list = [] 241 | 242 | if self.opts.inference_sample_size >= self.opts.inference_sample_batch: 243 | ms_batch_feature = batch_feature.unsqueeze(1).repeat( 244 | 1, self.opts.inference_sample_batch, 1, 1 245 | ) 246 | ms_batch_feature = ms_batch_feature.view(-1, graph_size_plus1, node_dim) 247 | 248 | pbar = tqdm( 249 | total=math.ceil( 250 | self.opts.inference_sample_size / self.opts.inference_sample_batch 251 | ) 252 | * len(self.opts.inference_temperature), 253 | disable=self.opts.no_progress_bar or not show_bar, 254 | desc='constructing', 255 | bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}', 256 | ) 257 | 258 | for sample_batch in batch_picker( 259 | self.opts.inference_sample_size, self.opts.inference_sample_batch 260 | ): 261 | if sample_batch < self.opts.inference_sample_batch: 262 | ms_batch_feature = batch_feature.unsqueeze(1).repeat( 263 | 1, sample_batch, 1, 1 264 | ) 265 | ms_batch_feature = ms_batch_feature.view( 266 | -1, graph_size_plus1, node_dim 267 | ) 268 | 269 | for temperature in self.opts.inference_temperature: 270 | if zoom: 271 | ms_batch_feature_4actor, _ = zoom_feature(ms_batch_feature) 272 | else: 273 | ms_batch_feature_4actor = ms_batch_feature 274 | solution, _ = self.actor_construct( 275 | ms_batch_feature_4actor, temperature=temperature 276 | ) 277 | obj = problem.get_costs(ms_batch_feature, solution, zoom) 278 | 279 | solution = solution.view(new_batch_size, sample_batch, -1) 280 | obj = obj.view(new_batch_size, sample_batch) 281 | 282 | solution_list.append(solution) 283 | obj_list.append(obj) 284 | 285 | pbar.update(1) 286 | 
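# `batch_picker` (from utils) presumably yields chunk sizes summing to
# inference_sample_size, so construction samples at most
# inference_sample_batch tours at a time; the best tour per instance is then
# kept by the argmin selection below. Toy illustration of that selection:
#
#     obj = torch.tensor([[3., 1., 2.]])   # (batch, samples)
#     idx = obj.argmin(dim=1)              # tensor([1])
#     obj[torch.arange(1), idx]            # tensor([1.])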
# pbar.close()
287 |
288 | solution = torch.cat(solution_list, 1)
289 | obj = torch.cat(obj_list, 1)
290 |
291 | min_sol_index = obj.argmin(dim=1)
292 | obj = obj[torch.arange(new_batch_size), min_sol_index]
293 | solution = solution[torch.arange(new_batch_size), min_sol_index]
294 |
295 | if val_m > 1 and False: # disabled
296 | obj_aug = obj.reshape(batch_size, val_m)
297 | solution_aug = solution.reshape(batch_size, val_m, -1)
298 |
299 | min_sol_index_among_val_m = obj_aug.argmin(dim=1)
300 | obj_val_m = obj_aug[torch.arange(batch_size), min_sol_index_among_val_m]
301 | solution_val_m = solution_aug[
302 | torch.arange(batch_size), min_sol_index_among_val_m
303 | ]
304 |
305 | obj = obj_val_m.unsqueeze(1).repeat(1, val_m).reshape(-1)
306 | solution = (
307 | solution_val_m.unsqueeze(1)
308 | .repeat(1, val_m, 1)
309 | .reshape(new_batch_size, -1)
310 | )
311 |
312 | obj_history = [
313 | torch.cat((obj[:, None], obj[:, None]), -1)
314 | ] # [(new_batch_size, 2)]
315 |
316 | rewards: List[torch.Tensor] = []
317 |
318 | action = None
319 | action_removal_record = [
320 | torch.zeros((batch_feature.size(0), problem.size // 2))
321 | for _ in range(problem.size // 2) # N2S paper section 4.4 last sentence
322 | ]
323 |
324 | for _ in tqdm(
325 | range(self.opts.T_max),
326 | disable=self.opts.no_progress_bar or not show_bar,
327 | desc='rollout',
328 | bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}',
329 | ):
330 | # pass through model
331 | if zoom:
332 | batch_feature_4actor, _ = zoom_feature(batch_feature)
333 | else:
334 | batch_feature_4actor = batch_feature
335 | action = self.actor(
336 | problem, batch_feature_4actor, solution, action, action_removal_record
337 | )[0]
338 |
339 | # new solution
340 | solution, reward, obj, action_removal_record = problem.step(
341 | batch, solution, action, obj, action_removal_record, zoom=zoom
342 | )
343 |
344 | # record information
345 | rewards.append(reward) # [(new_batch_size,), ...]
346 | obj_history.append(obj) # [(new_batch_size, 2), ...]
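# Each of the T_max steps above is one N2S improvement move: the actor picks a
# pickup-delivery pair to remove and reinsert, problem.step applies it, and
# obj carries two columns (current cost and best-so-far cost), which is why
# obj_history seeds both columns with the initial objective.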
347 | 348 | if self.opts.shared_critic: 349 | pbar.close() 350 | 351 | out = ( 352 | obj[:, -1].reshape(batch_size, val_m).min(1)[0], # (batch_size, 1) 353 | torch.stack(obj_history, 1)[:, :, 0] # current obj history 354 | .view(batch_size, val_m, -1) 355 | .min(1)[0], # (batch_size, T_max) 356 | torch.stack(obj_history, 1)[:, :, -1] # current best obj history 357 | .view(batch_size, val_m, -1) 358 | .min(1)[0], # (batch_size, T_max) 359 | torch.stack(rewards, 1) 360 | .view(batch_size, val_m, -1) 361 | .max(1)[0], # (batch_size, T_max) 362 | ) 363 | 364 | return out 365 | 366 | def start_inference( 367 | self, 368 | problem: PDP, 369 | val_dataset: Optional[str], 370 | tb_logger: Optional[TbLogger], 371 | load_path: Optional[str], 372 | zoom: bool = False, 373 | ) -> None: 374 | if load_path is not None: 375 | self.load(load_path) 376 | if self.opts.distributed: 377 | mp.spawn( 378 | validate, 379 | nprocs=self.opts.world_size, 380 | args=( 381 | problem, 382 | self, 383 | val_dataset, 384 | tb_logger, 385 | True, 386 | None, 387 | False, 388 | zoom, 389 | ), 390 | ) 391 | else: 392 | validate( 393 | 0, problem, self, val_dataset, tb_logger, distributed=False, zoom=zoom 394 | ) 395 | 396 | def start_training( 397 | self, 398 | problem: PDP, 399 | val_dataset: Optional[str], 400 | tb_logger: Optional[TbLogger], 401 | load_path: Optional[str], 402 | ) -> None: 403 | if self.opts.distributed: 404 | mp.spawn( 405 | train, 406 | nprocs=self.opts.world_size, 407 | args=(problem, self, val_dataset, tb_logger, load_path), 408 | ) 409 | else: 410 | train(0, problem, self, val_dataset, tb_logger, load_path) 411 | 412 | 413 | def train( 414 | rank: int, 415 | problem: PDP, 416 | agent: Agent, 417 | val_dataset: Optional[str], 418 | tb_logger: Optional[TbLogger], 419 | load_path: Optional[str], 420 | ) -> None: 421 | opts = agent.opts 422 | 423 | warnings.filterwarnings("ignore") 424 | torch.backends.cudnn.deterministic = True 425 | torch.backends.cudnn.benchmark = False 426 | 427 | if opts.distributed: 428 | device = torch.device("cuda", rank) 429 | torch.distributed.init_process_group( 430 | backend='nccl', world_size=opts.world_size, rank=rank 431 | ) 432 | torch.cuda.set_device(rank) 433 | agent.actor.to(device) 434 | agent.critic.to(device) 435 | 436 | if opts.normalization == 'batch': 437 | agent.actor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(agent.actor).to( 438 | device 439 | ) # type: ignore 440 | agent.critic = torch.nn.SyncBatchNorm.convert_sync_batchnorm( 441 | agent.critic 442 | ).to( 443 | device 444 | ) # type: ignore 445 | 446 | if opts.shared_critic: 447 | agent.actor_construct.to(device) 448 | agent.critic_construct.to(device) 449 | 450 | if opts.sc_normalization == 'batch': 451 | agent.actor_construct = torch.nn.SyncBatchNorm.convert_sync_batchnorm( 452 | agent.actor_construct 453 | ).to( 454 | device 455 | ) # type: ignore 456 | 457 | for state in agent.optimizer_sc.state.values(): 458 | for k, v in state.items(): 459 | if torch.is_tensor(v): 460 | state[k] = v.to(device) 461 | 462 | for state in agent.optimizer.state.values(): 463 | for k, v in state.items(): 464 | if torch.is_tensor(v): 465 | state[k] = v.to(device) 466 | 467 | agent.actor = torch.nn.parallel.DistributedDataParallel( 468 | agent.actor, device_ids=[rank] 469 | ) # type: ignore 470 | if opts.shared_critic: 471 | agent.actor_construct = torch.nn.parallel.DistributedDataParallel( 472 | agent.actor_construct, device_ids=[rank] 473 | ) # type: ignore 474 | if not opts.eval_only: 475 | agent.critic = 
torch.nn.parallel.DistributedDataParallel( 476 | agent.critic, device_ids=[rank] 477 | ) # type: ignore 478 | if opts.shared_critic: 479 | agent.critic_construct = torch.nn.parallel.DistributedDataParallel( 480 | agent.critic_construct, device_ids=[rank] 481 | ) # type: ignore 482 | 483 | if not opts.no_tb and rank == 0: 484 | tb_logger = TbLogger( 485 | os.path.join( 486 | opts.log_dir, 487 | "{}_{}".format(opts.problem, opts.graph_size), 488 | opts.run_name, 489 | ) 490 | ) 491 | else: 492 | for state in agent.optimizer.state.values(): 493 | for k, v in state.items(): 494 | if torch.is_tensor(v): 495 | state[k] = v.to(opts.device) 496 | if opts.shared_critic: 497 | for state in agent.optimizer_sc.state.values(): 498 | for k, v in state.items(): 499 | if torch.is_tensor(v): 500 | state[k] = v.to(opts.device) 501 | 502 | # check cuda memory 503 | if opts.use_cuda: 504 | if rank == 0: 505 | training_dataset_test = PDP.make_dataset( 506 | size=opts.graph_size, 507 | num_samples=opts.batch_size // opts.world_size, 508 | silence=True, 509 | ) 510 | training_dataloader_test = DataLoader( 511 | training_dataset_test, 512 | batch_size=opts.batch_size // opts.world_size, 513 | shuffle=False, 514 | num_workers=0, 515 | pin_memory=True, 516 | ) 517 | mem_test(agent, problem, next(iter(training_dataloader_test))) 518 | if opts.distributed: 519 | dist.barrier() 520 | 521 | # set or restore seed 522 | if load_path is None: 523 | torch.manual_seed(opts.seed) 524 | random.seed(opts.seed) 525 | else: 526 | agent.load(load_path) 527 | 528 | if opts.distributed: 529 | dist.barrier() 530 | 531 | # Start the actual training loop 532 | for epoch in range(opts.epoch_start, opts.epoch_end): 533 | agent.lr_scheduler.step(epoch) 534 | 535 | # Training mode 536 | if rank == 0: 537 | print('\n\n') 538 | print("|", format(f" Training epoch {epoch} ", "*^60"), "|") 539 | print( 540 | "Training with actor lr={:.3e} critic lr={:.3e} for run {}".format( 541 | agent.optimizer.param_groups[0]['lr'], 542 | agent.optimizer.param_groups[1]['lr'], 543 | opts.run_name, 544 | ), 545 | flush=True, 546 | ) 547 | # prepare training data 548 | training_dataset = PDP.make_dataset( 549 | size=opts.graph_size, num_samples=opts.epoch_size 550 | ) 551 | if opts.distributed: 552 | train_sampler: Any = torch.utils.data.distributed.DistributedSampler( 553 | training_dataset, shuffle=False 554 | ) 555 | training_dataloader = DataLoader( 556 | training_dataset, 557 | batch_size=opts.batch_size // opts.world_size, 558 | shuffle=False, 559 | num_workers=0, 560 | pin_memory=True, 561 | sampler=train_sampler, 562 | ) 563 | else: 564 | training_dataloader = DataLoader( 565 | training_dataset, 566 | batch_size=opts.batch_size, 567 | shuffle=False, 568 | num_workers=0, 569 | pin_memory=True, 570 | ) 571 | 572 | if opts.distributed: 573 | dist.barrier() 574 | 575 | # start training 576 | step = epoch * (opts.epoch_size // opts.batch_size) 577 | pbar = tqdm( 578 | total=(opts.K_epochs) 579 | * (opts.epoch_size // opts.batch_size) 580 | * (opts.T_train // opts.n_step), 581 | disable=opts.no_progress_bar or rank != 0, 582 | desc='training', 583 | bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}', 584 | ) 585 | if opts.shared_critic: 586 | opts.cur_temperature = max( 587 | 1, opts.temperature_init * (opts.temperature_decay**epoch) 588 | ) 589 | if opts.cur_temperature == 1: 590 | opts.cur_init_sample_size = min( 591 | opts.max_init_sample_size, 592 | int( 593 | opts.init_sample_increase 594 | ** (epoch - opts.start_init_sample_epoch) 595 | ), 596 | ) 597 | 
opts.cur_imitation_augment = max( 598 | 0, 599 | min( 600 | int(epoch * opts.imitation_increase_w + opts.imitation_increase_b), 601 | opts.imitation_max_augment, 602 | ), 603 | ) 604 | for batch in training_dataloader: 605 | train_batch( 606 | rank, 607 | problem, 608 | agent, 609 | epoch, 610 | step, 611 | batch, 612 | tb_logger, 613 | opts, 614 | pbar, 615 | epoch < opts.sc_start_train_epoch, 616 | ) 617 | step += 1 618 | pbar.close() 619 | 620 | # save new model after one epoch 621 | if rank == 0: 622 | if not opts.no_saving and ( 623 | (opts.checkpoint_epochs != 0 and epoch % opts.checkpoint_epochs == 0) 624 | or epoch == opts.epoch_end - 1 625 | ): 626 | agent.save(epoch) 627 | 628 | # validate the new model 629 | validate(rank, problem, agent, val_dataset, tb_logger, id_=epoch) 630 | 631 | # syn 632 | if opts.distributed: 633 | dist.barrier() 634 | 635 | 636 | def train_batch( 637 | rank: int, 638 | problem: PDP, 639 | agent: Agent, 640 | epoch: int, 641 | step: int, 642 | batch: Dict[str, torch.Tensor], 643 | tb_logger: TbLogger, 644 | opts: Option, 645 | pbar: tqdm, 646 | no_sc_train: bool, 647 | ) -> None: 648 | # setup 649 | agent.train() 650 | memory = Memory() 651 | 652 | # prepare the input 653 | batch = ( 654 | move_to(batch, rank) if opts.distributed else move_to(batch, opts.device) 655 | ) # batch_size, graph_size+1, 2 656 | batch_feature: torch.Tensor = ( 657 | move_to(PDP.input_coordinates(batch), rank) 658 | if opts.distributed 659 | else move_to(PDP.input_coordinates(batch), opts.device) 660 | ) 661 | batch_size, graph_size_plus1, node_dim = batch_feature.size() 662 | action = ( 663 | move_to(torch.tensor([-1, -1, -1]).repeat(batch_size, 1), rank) 664 | if opts.distributed 665 | else move_to(torch.tensor([-1, -1, -1]).repeat(batch_size, 1), opts.device) 666 | ) 667 | 668 | action_removal_record = [ 669 | torch.zeros((batch_feature.size(0), problem.size // 2)) 670 | for _ in range(problem.size) # N2S paper section 4.4 last sentence 671 | ] 672 | 673 | # initial solution 674 | if not opts.shared_critic or opts.no_sample_init: 675 | solution: torch.Tensor = ( 676 | move_to(problem.get_initial_solutions(batch), rank) 677 | if opts.distributed 678 | else move_to(problem.get_initial_solutions(batch), opts.device) 679 | ) 680 | obj = problem.get_costs(batch_feature, solution) 681 | best_sol = solution.clone() 682 | else: 683 | agent.eval() 684 | 685 | if opts.cur_temperature > 1: 686 | solution, _ = agent.actor_construct( 687 | batch_feature, temperature=opts.cur_temperature 688 | ) 689 | obj = problem.get_costs(batch_feature, solution) 690 | else: 691 | solution_list = [] 692 | obj_list = [] 693 | 694 | if opts.cur_init_sample_size >= opts.max_init_sample_batch: 695 | ms_batch_feature = batch_feature.unsqueeze(1).repeat( 696 | 1, opts.max_init_sample_batch, 1, 1 697 | ) 698 | ms_batch_feature = ms_batch_feature.view(-1, graph_size_plus1, node_dim) 699 | 700 | for sample_batch in batch_picker( 701 | opts.cur_init_sample_size, opts.max_init_sample_batch 702 | ): 703 | if sample_batch < opts.max_init_sample_batch: 704 | ms_batch_feature = batch_feature.unsqueeze(1).repeat( 705 | 1, sample_batch, 1, 1 706 | ) 707 | ms_batch_feature = ms_batch_feature.view( 708 | -1, graph_size_plus1, node_dim 709 | ) 710 | 711 | solution, _ = agent.actor_construct(ms_batch_feature) 712 | obj = problem.get_costs(ms_batch_feature, solution) 713 | 714 | solution = solution.view(batch_size, sample_batch, -1) 715 | obj = obj.view(batch_size, sample_batch) 716 | 717 | solution_list.append(solution) 718 
| obj_list.append(obj)
719 |
720 | solution = torch.cat(solution_list, 1)
721 | obj = torch.cat(obj_list, 1)
722 |
723 | min_sol_index = obj.argmin(dim=1)
724 | obj = obj[torch.arange(batch_size), min_sol_index]
725 | solution = solution[torch.arange(batch_size), min_sol_index]
726 |
727 | best_sol = solution.clone()
728 |
729 | agent.train()
730 |
731 | # warm_up
732 | if opts.warm_up > 0:
733 | agent.eval()
734 |
735 | for _ in range(
736 | min(
737 | opts.max_warm_up,
738 | int(max(0, (epoch - opts.start_warm_up_epoch) // opts.warm_up)),
739 | )
740 | ):
741 | # get model output
742 | action = agent.actor(
743 | problem, batch_feature, solution, action, action_removal_record
744 | )[0]
745 |
746 | # state transition
747 | solution, rewards, obj, action_removal_record = problem.step(
748 | batch, solution, action, obj, action_removal_record, best_sol
749 | )
750 |
751 | if opts.warm_up_type == 'update':
752 | obj = obj.view(batch_size, -1)[:, -1]
753 | solution = best_sol
754 | else:
755 | obj = problem.get_costs(batch_feature, solution)
756 |
757 | agent.train()
758 |
759 | # params for training
760 | gamma = opts.gamma
761 | n_step = opts.n_step
762 | T = opts.T_train
763 | K_epochs = opts.K_epochs
764 | eps_clip = opts.eps_clip
765 | t = 0
766 | initial_cost = obj
767 | best_sol = solution.clone()
768 | imitation_loss = None
769 | grad_norms_imi = None
770 |
771 | if opts.sc_map_sample_type == 'augment':
772 | batch_for_sample = {'coordinates': batch_feature.clone()}
773 | batch_augments(opts.sc_map_sample_times, batch_for_sample)
774 | batch_feature_for_sample = batch_for_sample[
775 | 'coordinates'
776 | ] # (batch_size, augment, graph_size_plus1, node_dim)
777 | sample_index = 0
778 |
779 | # sample trajectory
780 | while t < T: # t advances by n_step each outer iteration
781 | t_s = t
782 | memory.actions.append(action)
783 |
784 | # data array
785 | total_cost = torch.tensor(0)
786 |
787 | # for first step
788 | entropy_list = []
789 | bl_val_detached_list = []
790 | bl_val_list = []
791 |
792 | if opts.shared_critic:
793 | if opts.sc_map_sample_type == 'augment':
794 | construct_solution, construct_logprobs = agent.actor_construct(
795 | batch_feature_for_sample[:, sample_index, :, :]
796 | )
797 | sample_index += 1
798 | else:
799 | construct_solution, construct_logprobs = agent.actor_construct(
800 | batch_feature
801 | )
802 | construct_obj = problem.get_costs(batch_feature, construct_solution)
803 | old_construct_logprobs = construct_logprobs
804 |
805 | obj_of_n2s = []
806 |
807 | while t - t_s < n_step and not (t == T):
808 | memory.states.append(solution)
809 | memory.action_removal_record.append(action_removal_record)
810 |
811 | # get model output
812 |
813 | action, log_lh, to_critic_, entro_p = agent.actor(
814 | problem,
815 | batch_feature,
816 | solution,
817 | action,
818 | action_removal_record,
819 | require_entropy=True,
820 | to_critic=True,
821 | )
822 |
823 | memory.actions.append(action)
824 | memory.logprobs.append(log_lh)
825 | memory.best_obj.append(obj.view(obj.size(0), -1)[:, -1].unsqueeze(-1))
826 |
827 | if opts.shared_critic:
828 | obj_of_n2s.append(obj.view(obj.size(0), -1)[:, 0])
829 |
830 | entropy_list.append(entro_p.detach().cpu())
831 |
832 | baseline_val_detached, baseline_val = agent.critic(
833 | to_critic_, obj.view(obj.size(0), -1)[:, -1].unsqueeze(-1)
834 | )
835 |
836 | bl_val_detached_list.append(baseline_val_detached)
837 | bl_val_list.append(baseline_val)
838 |
839 | # state transition
840 | solution, rewards, obj, action_removal_record =
problem.step(
841 | batch, solution, action, obj, action_removal_record, best_sol
842 | )
843 | memory.rewards.append(rewards)
844 | # memory.mask_true = memory.mask_true + info['swapped']
845 |
846 | # store info
847 | total_cost = total_cost + obj[:, -1]
848 |
849 | # next
850 | t = t + 1
851 |
852 | # store info
853 | t_time = t - t_s
854 | total_cost = total_cost / t_time
855 |
856 | # begin update =======================
857 |
858 | # convert list to tensor
859 | all_actions = torch.stack(memory.actions)
860 | old_states = torch.stack(memory.states).detach().view(t_time, batch_size, -1)
861 | old_actions = all_actions[1:].view(t_time, -1, 3)
862 | old_logprobs = torch.stack(memory.logprobs).detach().view(-1)
863 | old_pre_actions = all_actions[:-1].view(t_time, -1, 3)
864 | old_action_removal_record = memory.action_removal_record
865 |
866 | old_best_obj = torch.stack(memory.best_obj)
867 |
868 | # Optimize ppo policy for K mini-epochs:
869 | old_value = None
870 | if opts.shared_critic:
871 | old_value_construct = None
872 |
873 | for k_ in range(K_epochs):
874 | if k_ == 0:
875 | logprobs_list = memory.logprobs
876 |
877 | else:
878 | # Evaluating old actions and values:
879 | logprobs_list = []
880 | entropy_list = []
881 | bl_val_detached_list = []
882 | bl_val_list = []
883 |
884 | if opts.shared_critic and opts.sc_rl_train_type == 'ppo':
885 | if opts.sc_map_sample_type == 'augment':
886 | _, construct_logprobs = agent.actor_construct(
887 | batch_feature_for_sample[:, sample_index - 1, :, :],
888 | fixed_sol=construct_solution,
889 | )
890 | else:
891 | _, construct_logprobs = agent.actor_construct(
892 | batch_feature, fixed_sol=construct_solution
893 | )
894 |
895 | for tt in range(t_time):
896 | # get new action_prob
897 | _, log_p, to_critic_, entro_p = agent.actor(
898 | problem,
899 | batch_feature,
900 | old_states[tt],
901 | old_pre_actions[tt],
902 | old_action_removal_record[tt],
903 | fixed_action=old_actions[tt], # take same action
904 | require_entropy=True,
905 | to_critic=True,
906 | )
907 |
908 | logprobs_list.append(log_p)
909 | entropy_list.append(entro_p.detach().cpu())
910 |
911 | baseline_val_detached, baseline_val = agent.critic(
912 | to_critic_, old_best_obj[tt]
913 | )
914 |
915 | bl_val_detached_list.append(baseline_val_detached)
916 | bl_val_list.append(baseline_val)
917 |
918 | logprobs = torch.stack(logprobs_list).view(-1)
919 | entropy = torch.stack(entropy_list).view(-1)
920 | bl_val_detached = torch.stack(bl_val_detached_list).view(-1)
921 | bl_val = torch.stack(bl_val_list).view(-1)
922 |
923 | if opts.shared_critic:
924 | # bl_construct = torch.stack(
925 | # obj_of_n2s
926 | # ) - _baseline_trust_degree * torch.stack(bl_val_detached_list)
927 | # bl_construct = bl_construct.mean(0)
928 | (
929 | bl_construct_detach,
930 | bl_construct,
931 | trust_degree,
932 | ) = agent.critic_construct(obj_of_n2s, bl_val_detached_list)
933 |
934 | # obj_of_n2s = []
935 |
936 | # get target value for critic
937 | Reward_list = []
938 | reward_reversed = memory.rewards[::-1]
939 |
940 | # estimate return
941 | R = agent.critic(
942 | agent.actor(
943 | problem,
944 | batch_feature,
945 | solution,
946 | action,
947 | action_removal_record,
948 | only_critic=True,
949 | )[0],
950 | obj.view(obj.size(0), -1)[:, -1].unsqueeze(-1),
951 | )[0]
952 | for r in range(len(reward_reversed)):
953 | R = R * gamma + reward_reversed[r]
954 | Reward_list.append(R)
955 |
956 | # clip the target:
957 | Reward = torch.stack(Reward_list[::-1], 0) # (n_step, batch_size)
958 |
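# The loop above bootstraps the n-step return backwards from the critic:
# R_T = V(s_T), then R_t = gamma * R_{t+1} + r_t. Toy check with gamma = 0.999,
# bootstrap 5.0 and rewards [1.0, 2.0] (newest last; numbers illustrative):
#
#     R = 5.0
#     for r in [2.0, 1.0]:        # reversed rewards
#         R = R * 0.999 + r
#     # R == 1 + 0.999 * (2 + 0.999 * 5) ~= 7.99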
Reward = Reward.view(-1) 959 | 960 | # Finding the ratio (pi_theta / pi_theta__old): 961 | ratios = torch.exp(logprobs - old_logprobs.detach()) 962 | 963 | # Finding Surrogate Loss: 964 | advantages = Reward - bl_val_detached 965 | 966 | surr1 = ratios * advantages 967 | surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages 968 | reinforce_loss = -torch.min(surr1, surr2).mean() 969 | 970 | # define baseline loss 971 | if old_value is None: 972 | baseline_loss = ((bl_val - Reward) ** 2).mean() 973 | old_value = bl_val.detach() 974 | else: 975 | vpredclipped = old_value + torch.clamp( 976 | bl_val - old_value, -eps_clip, eps_clip 977 | ) 978 | v_max = torch.max( 979 | ((bl_val - Reward) ** 2), ((vpredclipped - Reward) ** 2) 980 | ) 981 | baseline_loss = v_max.mean() 982 | 983 | if opts.shared_critic: 984 | if opts.sc_rl_train_type == 'ppo': 985 | ratios_construct = torch.exp( 986 | construct_logprobs - old_construct_logprobs.detach() 987 | ).view(-1) 988 | advantages_construct = ( 989 | (construct_obj - bl_construct_detach).view(-1).detach() 990 | ) 991 | surr1_construct = ratios_construct * advantages_construct 992 | surr2_construct = ( 993 | torch.clamp(ratios_construct, 1 - eps_clip, 1 + eps_clip) 994 | * advantages_construct 995 | ) 996 | reinforce_loss_construct = torch.max( 997 | surr1_construct, surr2_construct 998 | ).mean() 999 | 1000 | if old_value_construct is None: 1001 | baseline_loss_construct = ( 1002 | (bl_construct - construct_obj.detach()) ** 2 1003 | ).mean() 1004 | old_value_construct = bl_construct.detach() 1005 | else: 1006 | vpredclipped_construct = old_value_construct + torch.clamp( 1007 | bl_construct - old_value_construct, -eps_clip, eps_clip 1008 | ) 1009 | v_max_construct = torch.max( 1010 | ((bl_construct - construct_obj.detach()) ** 2), 1011 | ((vpredclipped_construct - construct_obj.detach()) ** 2), 1012 | ) 1013 | baseline_loss_construct = v_max_construct.mean() 1014 | 1015 | elif opts.sc_rl_train_type == 'pg': 1016 | if k_ == 0: 1017 | reinforce_loss_construct = ( 1018 | construct_logprobs.view(-1) 1019 | * (construct_obj - bl_construct_detach).view(-1).detach() 1020 | ).mean() 1021 | baseline_loss_construct = ( 1022 | (bl_construct - construct_obj.detach()) ** 2 1023 | ).mean() 1024 | else: 1025 | reinforce_loss_construct = reinforce_loss_construct.detach() 1026 | baseline_loss_construct = baseline_loss_construct.detach() 1027 | 1028 | # check K-L divergence 1029 | approx_kl_divergence = ( 1030 | (0.5 * (old_logprobs.detach() - logprobs) ** 2).mean().detach() 1031 | ) 1032 | approx_kl_divergence[torch.isinf(approx_kl_divergence)] = 0 1033 | 1034 | # calculate loss 1035 | if opts.shared_critic: 1036 | loss = ( 1037 | baseline_loss 1038 | + reinforce_loss 1039 | + reinforce_loss_construct 1040 | + baseline_loss_construct 1041 | ) 1042 | else: 1043 | loss = baseline_loss + reinforce_loss # - 1e-5 * entropy.mean() 1044 | 1045 | # update gradient step 1046 | agent.optimizer.zero_grad() 1047 | if opts.shared_critic: 1048 | agent.optimizer_sc.zero_grad() 1049 | 1050 | loss.backward() 1051 | 1052 | # Clip gradient norm and get (clipped) gradient norms for logging 1053 | current_step = int( 1054 | step * T / n_step * K_epochs + (t - 1) // n_step * K_epochs + k_ 1055 | ) 1056 | 1057 | grad_norms = clip_grad_norms( 1058 | agent.optimizer.param_groups, opts.max_grad_norm 1059 | ) 1060 | 1061 | # perform gradient descent 1062 | agent.optimizer.step() 1063 | 1064 | if opts.shared_critic: 1065 | grad_norms_sc_actor = clip_grad_norms( 1066 | 
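# The update above is PPO's pessimistic clipped surrogate; a toy evaluation
# with eps_clip = 0.1 (illustrative numbers):
#
#     ratios = torch.tensor([0.8, 1.3])
#     advantages = torch.tensor([1.0, 1.0])
#     surr1 = ratios * advantages                          # [0.8, 1.3]
#     surr2 = torch.clamp(ratios, 0.9, 1.1) * advantages   # [0.9, 1.1]
#     -torch.min(surr1, surr2).mean()                      # -(0.8 + 1.1) / 2 = -0.95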
agent.optimizer_sc.param_groups[:1], opts.max_grad_norm_construct 1067 | ) 1068 | grad_norms_sc_critic = clip_grad_norms( 1069 | agent.optimizer_sc.param_groups[1:], opts.max_grad_norm 1070 | ) 1071 | grad_norms[0].extend(grad_norms_sc_actor[0] + grad_norms_sc_critic[0]) 1072 | grad_norms[1].extend(grad_norms_sc_actor[1] + grad_norms_sc_critic[1]) 1073 | 1074 | if not no_sc_train: 1075 | agent.optimizer_sc.step() 1076 | 1077 | # imitation learning 1078 | if ( 1079 | opts.shared_critic 1080 | and opts.imitation_step > 0 1081 | and t % opts.imitation_step == 0 1082 | and k_ == K_epochs - 1 1083 | and not no_sc_train 1084 | ): 1085 | is_good = (memory.best_obj[-1].view(-1) < construct_obj).float() 1086 | batch_for_imi = {'coordinates': batch_feature.clone()} 1087 | batch_augments( 1088 | opts.cur_imitation_augment, batch_for_imi, one_is_keep=False 1089 | ) 1090 | batch_feature_for_imi = batch_for_imi[ 1091 | 'coordinates' 1092 | ] # (batch_size, imitation_augment, graph_size_plus1, node_dim) 1093 | 1094 | # imi_adv = (memory.best_obj[-1].view(-1) - construct_obj).detach() 1095 | 1096 | for i in range(opts.cur_imitation_augment): 1097 | _, teaching_logprobs = agent.actor_construct( 1098 | batch_feature_for_imi[:, i, :, :], fixed_sol=best_sol 1099 | ) 1100 | imitation_loss = -( 1101 | (is_good * teaching_logprobs).mean() * opts.imitation_rate 1102 | ) 1103 | 1104 | # update gradient step 1105 | agent.optimizer_sc.zero_grad() 1106 | imitation_loss.backward() 1107 | 1108 | grad_norms_imi = clip_grad_norms( 1109 | agent.optimizer_sc.param_groups[:1], 1110 | opts.imitation_max_grad_norm, 1111 | ) 1112 | 1113 | # perform gradient descent 1114 | agent.optimizer_sc.step() 1115 | 1116 | # Logging to tensorboard 1117 | if (not opts.no_tb) and rank == 0: 1118 | if (current_step + 1) % int(opts.log_step) == 0: 1119 | log_to_tb_train( 1120 | tb_logger, 1121 | batch_feature[0, 0], 1122 | agent, 1123 | Reward, 1124 | ratios, 1125 | bl_val_detached, 1126 | total_cost, 1127 | grad_norms, 1128 | memory.rewards, 1129 | entropy, 1130 | approx_kl_divergence, 1131 | reinforce_loss, 1132 | baseline_loss, 1133 | logprobs, 1134 | initial_cost, 1135 | current_step + 1, 1136 | construct_obj if opts.shared_critic else None, 1137 | reinforce_loss_construct if opts.shared_critic else None, 1138 | bl_construct_detach if opts.shared_critic else None, 1139 | baseline_loss_construct if opts.shared_critic else None, 1140 | trust_degree if opts.shared_critic else None, 1141 | ratios_construct 1142 | if opts.shared_critic and opts.sc_rl_train_type == 'ppo' 1143 | else None, 1144 | imitation_loss if opts.shared_critic else None, 1145 | grad_norms_imi if opts.shared_critic else None, 1146 | ) 1147 | 1148 | if rank == 0: 1149 | pbar.update(1) 1150 | 1151 | # end update 1152 | memory.clear_memory() 1153 | -------------------------------------------------------------------------------- /nets/graph_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Tuple 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.distributions import Categorical 5 | import numpy as np 6 | from torch import nn 7 | import math 8 | 9 | from problems.problem_pdp import PDP 10 | 11 | TYPE_REMOVAL = 'N2S' 12 | # TYPE_REMOVAL = 'random' 13 | # TYPE_REMOVAL = 'greedy' 14 | 15 | TYPE_REINSERTION = 'N2S' 16 | # TYPE_REINSERTION = 'random' 17 | # TYPE_REINSERTION = 'greedy' 18 | 19 | 20 | class SkipConnection(nn.Module): 21 | def __init__(self, module: nn.Module) -> None: 
22 | super().__init__() 23 | self.module = module 24 | 25 | __call__: Callable[['SkipConnection', torch.Tensor], torch.Tensor] 26 | 27 | def forward(self, input: torch.Tensor) -> torch.Tensor: 28 | return input + self.module(input) 29 | 30 | 31 | class MultiHeadAttention(nn.Module): 32 | def __init__( 33 | self, 34 | n_heads: int, 35 | in_query_dim: int, 36 | in_key_dim: int, 37 | in_val_dim: Optional[int], 38 | out_dim: int, 39 | ) -> None: 40 | super().__init__() 41 | 42 | hidden_dim = out_dim // n_heads 43 | 44 | self.n_heads = n_heads 45 | self.out_dim = out_dim 46 | self.hidden_dim = hidden_dim 47 | self.in_query_dim = in_query_dim 48 | self.in_key_dim = in_key_dim 49 | self.in_val_dim = in_val_dim 50 | 51 | self.norm_factor = 1 / math.sqrt(hidden_dim) # See Attention is all you need 52 | 53 | self.W_query = nn.Parameter(torch.Tensor(n_heads, in_query_dim, hidden_dim)) 54 | self.W_key = nn.Parameter(torch.Tensor(n_heads, in_key_dim, hidden_dim)) 55 | if in_val_dim is not None: # else calculate attention score 56 | self.W_val = nn.Parameter(torch.Tensor(n_heads, in_val_dim, hidden_dim)) 57 | self.W_out = nn.Parameter(torch.Tensor(n_heads, hidden_dim, out_dim)) 58 | 59 | self.init_parameters() 60 | 61 | def init_parameters(self) -> None: 62 | 63 | for param in self.parameters(): 64 | stdv = 1.0 / math.sqrt(param.size(-1)) 65 | param.data.uniform_(-stdv, stdv) 66 | 67 | __call__: Callable[..., torch.Tensor] 68 | 69 | def forward( 70 | self, 71 | q: torch.Tensor, 72 | k: torch.Tensor, 73 | v: Optional[torch.Tensor] = None, 74 | with_norm: bool = False, 75 | ) -> torch.Tensor: 76 | 77 | if self.in_val_dim is None: # calculate attention score 78 | assert v is None 79 | 80 | batch_size, n_query, in_que_dim = q.size() 81 | _, n_key, in_key_dim = k.size() 82 | 83 | if v is not None: 84 | in_val_dim = v.size(2) 85 | 86 | qflat = q.contiguous().view( 87 | -1, in_que_dim 88 | ) # (batch_size * n_query, in_que_dim) 89 | kflat = k.contiguous().view(-1, in_key_dim) # (batch_size * n_key, in_key_dim) 90 | if v is not None: 91 | vflat = v.contiguous().view(-1, in_val_dim) 92 | 93 | shp_q = (self.n_heads, batch_size, n_query, self.hidden_dim) 94 | shp_kv = (self.n_heads, batch_size, n_key, self.hidden_dim) 95 | 96 | # Calculate queries, (n_heads, batch_size, n_query, hidden_dim) 97 | Q = torch.matmul(qflat, self.W_query).view(shp_q) 98 | # self.W_que: (n_heads, in_que_dim, hidden_dim) 99 | # Q_before_view: (n_heads, batch_size * n_query, hidden_dim) 100 | 101 | # Calculate keys and values (n_heads, batch_size, n_key, hidden_dim) 102 | K = torch.matmul(kflat, self.W_key).view(shp_kv) 103 | if v is not None: 104 | V = torch.matmul(vflat, self.W_val).view(shp_kv) 105 | 106 | # Calculate compatibility (n_heads, batch_size, n_query, n_key) 107 | compatibility = torch.matmul(Q, K.transpose(2, 3)) 108 | 109 | if v is None and not with_norm: 110 | return compatibility 111 | 112 | compatibility = self.norm_factor * compatibility 113 | 114 | if v is None and with_norm: 115 | return compatibility 116 | 117 | attn = F.softmax(compatibility, dim=-1) 118 | 119 | heads = torch.matmul(attn, V) # (n_heads, batch_size, n_query, hidden_dim) 120 | 121 | out = torch.mm( 122 | heads.permute(1, 2, 0, 3) # (batch_size, n_query, n_heads, hidden_dim) 123 | .contiguous() 124 | .view( 125 | -1, self.n_heads * self.hidden_dim 126 | ), # (batch_size * n_query, n_heads * hidden_dim) 127 | self.W_out.view(-1, self.out_dim), # (n_heads * hidden_dim, out_dim) 128 | ).view(batch_size, n_query, self.out_dim) 129 | 130 | return out 131 | 
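# A minimal usage sketch for MultiHeadAttention above, with assumed dimensions (kept as comments so the module itself is unchanged):
#     mha = MultiHeadAttention(n_heads=8, in_query_dim=128, in_key_dim=128, in_val_dim=128, out_dim=128)
#     q = torch.rand(2, 51, 128)               # (batch_size, n_query, in_query_dim)
#     out = mha(q, q, q)                       # -> (2, 51, 128), i.e. (batch_size, n_query, out_dim)
# With in_val_dim=None the module returns raw attention scores instead of aggregated values:
#     scorer = MultiHeadAttention(8, 128, 128, None, 128)
#     scores = scorer(q, q, with_norm=True)    # -> (8, 2, 51, 51), i.e. (n_heads, batch_size, n_query, n_key)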
132 | 133 | class MultiHeadSelfAttention(nn.Module): 134 | def __init__(self, n_heads: int, input_dim: int) -> None: 135 | super().__init__() 136 | self.MHA = MultiHeadAttention( 137 | n_heads, input_dim, input_dim, input_dim, input_dim 138 | ) 139 | 140 | __call__: Callable[..., torch.Tensor] 141 | 142 | def forward(self, q: torch.Tensor) -> torch.Tensor: 143 | return self.MHA(q, q, q) 144 | 145 | 146 | class MHA_Self_Score_WithoutNorm(nn.Module): 147 | def __init__(self, n_heads: int, input_dim: int) -> None: 148 | super().__init__() 149 | self.MHA = MultiHeadAttention(n_heads, input_dim, input_dim, None, input_dim) 150 | 151 | __call__: Callable[..., torch.Tensor] 152 | 153 | def forward(self, q: torch.Tensor) -> torch.Tensor: 154 | return self.MHA(q, q, with_norm=False) 155 | 156 | 157 | class MLP(nn.Module): 158 | def __init__( 159 | self, 160 | input_dim: int = 128, 161 | feed_forward_dim: int = 64, 162 | embedding_dim: int = 64, 163 | output_dim: int = 1, 164 | p_dropout: float = 0.01, 165 | ) -> None: 166 | super().__init__() 167 | self.fc1 = nn.Linear(input_dim, feed_forward_dim) 168 | self.fc2 = nn.Linear(feed_forward_dim, embedding_dim) 169 | self.fc3 = nn.Linear(embedding_dim, output_dim) 170 | self.dropout = nn.Dropout(p=p_dropout) 171 | self.ReLU = nn.ReLU(inplace=True) 172 | 173 | self.init_parameters() 174 | 175 | def init_parameters(self) -> None: 176 | 177 | for param in self.parameters(): 178 | stdv = 1.0 / math.sqrt(param.size(-1)) 179 | param.data.uniform_(-stdv, stdv) 180 | 181 | __call__: Callable[..., torch.Tensor] 182 | 183 | def forward(self, input: torch.Tensor) -> torch.Tensor: 184 | result = self.ReLU(self.fc1(input)) 185 | result = self.dropout(result) 186 | result = self.ReLU(self.fc2(result)) 187 | result = self.fc3(result).squeeze(-1) 188 | return result 189 | 190 | 191 | class CriticDecoder(nn.Module): 192 | def __init__(self, input_dim: int) -> None: 193 | super().__init__() 194 | self.input_dim = input_dim 195 | 196 | self.project_graph = nn.Linear(self.input_dim, self.input_dim // 2) 197 | 198 | self.project_node = nn.Linear(self.input_dim, self.input_dim // 2) 199 | 200 | self.MLP = MLP(input_dim + 1, input_dim) 201 | 202 | __call__: Callable[..., torch.Tensor] 203 | 204 | def forward(self, y: torch.Tensor, best_cost: torch.Tensor) -> torch.Tensor: 205 | 206 | # h_wave: (batch_size, graph_size+1, input_size) 207 | mean_pooling = y.mean(1) # mean Pooling (batch_size, input_size) 208 | graph_feature: torch.Tensor = self.project_graph(mean_pooling)[ 209 | :, None, : 210 | ] # (batch_size, 1, input_dim/2) 211 | node_feature: torch.Tensor = self.project_node( 212 | y 213 | ) # (batch_size, graph_size+1, input_dim/2) 214 | 215 | # pass through value_head, get estimated value 216 | fusion = node_feature + graph_feature.expand_as( 217 | node_feature 218 | ) # (batch_size, graph_size+1, input_dim/2) 219 | 220 | fusion_feature = torch.cat( 221 | ( 222 | fusion.mean(1), 223 | fusion.max(1)[0], # max_pooling 224 | best_cost.to(y.device), 225 | ), 226 | -1, 227 | ) # (batch_size, input_dim + 1) 228 | 229 | value = self.MLP(fusion_feature) 230 | 231 | return value 232 | 233 | 234 | class NodePairRemovalDecoder(nn.Module): # (12) (13) 235 | def __init__(self, n_heads: int, input_dim: int, type_: str) -> None: 236 | super().__init__() 237 | 238 | # hidden_dim = input_dim // n_heads 239 | self.n_heads = n_heads 240 | self.input_dim = input_dim 241 | self.type_ = type_ 242 | 243 | if self.type_ == 'update2': 244 | hidden_dim = input_dim // n_heads 245 | self.hidden_dim 
= hidden_dim 246 | 247 | self.W_Q = nn.Parameter(torch.Tensor(n_heads, input_dim, hidden_dim)) 248 | self.W_K = nn.Parameter(torch.Tensor(n_heads, input_dim, hidden_dim)) 249 | self.W_Q_2 = nn.Parameter(torch.Tensor(n_heads, input_dim, hidden_dim)) 250 | self.W_K_2 = nn.Parameter(torch.Tensor(n_heads, input_dim, hidden_dim)) 251 | self.W_Q_3 = nn.Parameter(torch.Tensor(n_heads, input_dim, hidden_dim)) 252 | self.W_K_3 = nn.Parameter(torch.Tensor(n_heads, input_dim, hidden_dim)) 253 | self.agg = MLP(6 * n_heads + 4, 64, 32, 1, 0) 254 | elif self.type_ in ('origin', 'glitch', 'update1'): 255 | hidden_dim = input_dim 256 | self.hidden_dim = hidden_dim 257 | 258 | self.W_Q = nn.Parameter(torch.Tensor(n_heads, input_dim, hidden_dim)) 259 | self.W_K = nn.Parameter(torch.Tensor(n_heads, input_dim, hidden_dim)) 260 | self.agg = MLP(2 * n_heads + 4, 32, 32, 1, 0) 261 | else: 262 | raise NotImplementedError 263 | 264 | self.pair_with: Optional[torch.Tensor] = None 265 | 266 | self.init_parameters() 267 | 268 | def init_parameters(self) -> None: 269 | 270 | for param in self.parameters(): 271 | stdv = 1.0 / math.sqrt(param.size(-1)) 272 | param.data.uniform_(-stdv, stdv) 273 | 274 | __call__: Callable[..., torch.Tensor] 275 | 276 | def forward( 277 | self, 278 | h_hat: torch.Tensor, # hidden state from encoder 279 | solution: torch.Tensor, # if solution=[2,0,1], means 0->2->1->0. 280 | selection_recent: torch.Tensor, # (batch_size, 4, graph_size/2) 281 | ) -> torch.Tensor: 282 | 283 | pre = solution.argsort() # pre=[1,2,0] 284 | post = solution # post=[2,0,1] 285 | 286 | if self.type_ == 'glitch': 287 | post = solution.gather(1, solution) # use post-post 288 | 289 | batch_size, graph_size_plus1, input_dim = h_hat.size() 290 | 291 | hflat = h_hat.contiguous().view( 292 | -1, input_dim 293 | ) # (batch_size * graph_size+1, input_dim) 294 | 295 | shp = (self.n_heads, batch_size, graph_size_plus1, self.hidden_dim) 296 | 297 | # Calculate queries, (n_heads, batch_size, graph_size+1, key_size) 298 | hidden_Q = torch.matmul(hflat, self.W_Q).view(shp) 299 | hidden_K = torch.matmul(hflat, self.W_K).view(shp) 300 | 301 | Q_pre = hidden_Q.gather( 302 | 2, pre.view(1, batch_size, graph_size_plus1, 1).expand_as(hidden_Q) 303 | ) 304 | K_post = hidden_K.gather( 305 | 2, post.view(1, batch_size, graph_size_plus1, 1).expand_as(hidden_Q) 306 | ) 307 | 308 | if self.type_ == 'update1': 309 | half_size = graph_size_plus1 // 2 310 | 311 | pre_pre = pre.gather(1, pre) 312 | post_post = solution.gather(1, solution) 313 | 314 | if self.pair_with is None: 315 | self.pair_with = torch.arange(graph_size_plus1, device=solution.device) 316 | self.pair_with[1 : half_size + 1] += half_size 317 | self.pair_with[half_size + 1 :] -= half_size 318 | 319 | pair_with = self.pair_with.expand( 320 | self.n_heads, batch_size, graph_size_plus1 321 | ) 322 | 323 | Q_pre_pre = hidden_Q.gather( 324 | 2, pre_pre.view(1, batch_size, graph_size_plus1, 1).expand_as(hidden_Q) 325 | ) 326 | K_post_post = hidden_K.gather( 327 | 2, 328 | post_post.view(1, batch_size, graph_size_plus1, 1).expand_as(hidden_Q), 329 | ) 330 | 331 | need_post_post = post == pair_with 332 | need_pre_pre = pre == pair_with 333 | 334 | Q_pre[need_pre_pre] = Q_pre_pre[need_pre_pre] 335 | K_post[need_post_post] = K_post_post[need_post_post] 336 | 337 | compatibility = ( 338 | (Q_pre * hidden_K).sum(-1) 339 | + (hidden_Q * K_post).sum(-1) 340 | - (Q_pre * K_post).sum(-1) 341 | )[ 342 | :, :, 1: 343 | ] # (n_heads, batch_size, graph_size) (12) 344 | 345 | if self.type_ == 
'update2': 346 | post_post = solution.gather(1, solution) 347 | 348 | hidden_Q_2 = torch.matmul(hflat, self.W_Q_2).view(shp) 349 | hidden_K_2 = torch.matmul(hflat, self.W_K_2).view(shp) 350 | 351 | Q_pre_2 = hidden_Q_2.gather( 352 | 2, 353 | pre.view(1, batch_size, graph_size_plus1, 1).expand_as(hidden_Q_2), 354 | ) 355 | K_post_post = hidden_K_2.gather( 356 | 2, 357 | post_post.view(1, batch_size, graph_size_plus1, 1).expand_as( 358 | hidden_Q_2 359 | ), 360 | ) 361 | 362 | compatibility_2 = ( 363 | (Q_pre_2 * hidden_K_2).sum(-1) 364 | + (hidden_Q_2 * K_post_post).sum(-1) 365 | - (Q_pre_2 * K_post_post).sum(-1) 366 | )[ 367 | :, :, 1: 368 | ] # (n_heads, batch_size, graph_size) (12 pre-ppost) 369 | 370 | pre_pre = pre.gather(1, pre) 371 | 372 | hidden_Q_3 = torch.matmul(hflat, self.W_Q_3).view(shp) 373 | hidden_K_3 = torch.matmul(hflat, self.W_K_3).view(shp) 374 | 375 | Q_pre_pre = hidden_Q_3.gather( 376 | 2, 377 | pre_pre.view(1, batch_size, graph_size_plus1, 1).expand_as(hidden_Q_3), 378 | ) 379 | K_post_3 = hidden_K_3.gather( 380 | 2, 381 | post.view(1, batch_size, graph_size_plus1, 1).expand_as(hidden_Q_3), 382 | ) 383 | 384 | compatibility_3 = ( 385 | (Q_pre_pre * hidden_K_3).sum(-1) 386 | + (hidden_Q_3 * K_post_3).sum(-1) 387 | - (Q_pre_pre * K_post_3).sum(-1) 388 | )[ 389 | :, :, 1: 390 | ] # (n_heads, batch_size, graph_size) (12 ppre-post) 391 | 392 | compatibility_pairing = torch.cat( 393 | ( 394 | compatibility[:, :, : graph_size_plus1 // 2], 395 | compatibility[:, :, graph_size_plus1 // 2 :], 396 | compatibility_2[:, :, : graph_size_plus1 // 2], 397 | compatibility_2[:, :, graph_size_plus1 // 2 :], 398 | compatibility_3[:, :, : graph_size_plus1 // 2], 399 | compatibility_3[:, :, graph_size_plus1 // 2 :], 400 | ), 401 | 0, 402 | ) # (n_heads*6, batch_size, graph_size/2) 403 | 404 | else: 405 | compatibility_pairing = torch.cat( 406 | ( 407 | compatibility[:, :, : graph_size_plus1 // 2], 408 | compatibility[:, :, graph_size_plus1 // 2 :], 409 | ), 410 | 0, 411 | ) # (n_heads*2, batch_size, graph_size/2) 412 | 413 | compatibility_pairing = self.agg( 414 | torch.cat( 415 | ( 416 | compatibility_pairing.permute(1, 2, 0), 417 | selection_recent.permute(0, 2, 1), 418 | ), 419 | -1, 420 | ) 421 | ).squeeze() # (batch_size, graph_size/2) 422 | 423 | return compatibility_pairing 424 | 425 | 426 | class NodePairReinsertionDecoder(nn.Module): # (14) (15) 427 | def __init__(self, n_heads: int, input_dim: int) -> None: 428 | super().__init__() 429 | 430 | self.n_heads = n_heads 431 | 432 | self.compater_insert1 = MultiHeadAttention( 433 | n_heads, input_dim, input_dim, None, input_dim * n_heads 434 | ) 435 | 436 | self.compater_insert2 = MultiHeadAttention( 437 | n_heads, input_dim, input_dim, None, input_dim * n_heads 438 | ) 439 | 440 | self.agg = MLP(4 * n_heads, 32, 32, 1, 0) 441 | 442 | def init_parameters(self) -> None: 443 | 444 | for param in self.parameters(): 445 | stdv = 1.0 / math.sqrt(param.size(-1)) 446 | param.data.uniform_(-stdv, stdv) 447 | 448 | __call__: Callable[..., torch.Tensor] 449 | 450 | def forward( 451 | self, 452 | h_hat: torch.Tensor, 453 | pos_pickup: torch.Tensor, # (batch_size) 454 | pos_delivery: torch.Tensor, # (batch_size) 455 | solution: torch.Tensor, # (batch, graph_size+1) 456 | ) -> torch.Tensor: 457 | 458 | batch_size, graph_size_plus1, input_dim = h_hat.size() 459 | shp = (batch_size, graph_size_plus1, graph_size_plus1, self.n_heads) 460 | shp_p = (batch_size, -1, 1, self.n_heads) 461 | shp_d = (batch_size, 1, -1, self.n_heads) 462 | 463 | arange = 
torch.arange(batch_size, device=h_hat.device) 464 | h_pickup = h_hat[arange, pos_pickup].unsqueeze(1) # (batch_size, 1, input_dim) 465 | h_delivery = h_hat[arange, pos_delivery].unsqueeze( 466 | 1 467 | ) # (batch_size, 1, input_dim) 468 | h_K_neibour = h_hat.gather( 469 | 1, solution.view(batch_size, graph_size_plus1, 1).expand_as(h_hat) 470 | ) # (batch_size, graph_size+1, input_dim) 471 | 472 | compatibility_pickup_pre = ( 473 | self.compater_insert1( 474 | h_pickup, h_hat 475 | ) # (n_heads, batch_size, 1, graph_size+1) 476 | .permute(1, 2, 3, 0) # (batch_size, 1, graph_size+1, n_heads) 477 | .view(shp_p) # (batch_size, graph_size+1, 1, n_heads) 478 | .expand(shp) # (batch_size, graph_size+1, graph_size+1, n_heads) 479 | ) 480 | compatibility_pickup_post = ( 481 | self.compater_insert2(h_pickup, h_K_neibour) 482 | .permute(1, 2, 3, 0) 483 | .view(shp_p) 484 | .expand(shp) 485 | ) 486 | compatibility_delivery_pre = ( 487 | self.compater_insert1( 488 | h_delivery, h_hat 489 | ) # (n_heads, batch_size, 1, graph_size+1) 490 | .permute(1, 2, 3, 0) # (batch_size, 1, graph_size+1, n_heads) 491 | .view(shp_d) # (batch_size, 1, graph_size+1, n_heads) 492 | .expand(shp) # (batch_size, graph_size+1, graph_size+1, n_heads) 493 | ) 494 | compatibility_delivery_post = ( 495 | self.compater_insert2(h_delivery, h_K_neibour) 496 | .permute(1, 2, 3, 0) 497 | .view(shp_d) 498 | .expand(shp) 499 | ) 500 | 501 | compatibility = self.agg( 502 | torch.cat( 503 | ( 504 | compatibility_pickup_pre, 505 | compatibility_pickup_post, 506 | compatibility_delivery_pre, 507 | compatibility_delivery_post, 508 | ), 509 | -1, 510 | ) 511 | ).squeeze() 512 | return compatibility # (batch_size, graph_size+1, graph_size+1) 513 | 514 | 515 | class N2SDecoder(nn.Module): 516 | def __init__( 517 | self, n_heads: int, input_dim: int, v_range: float, removal_type: str 518 | ) -> None: 519 | super().__init__() 520 | self.input_dim = input_dim 521 | self.v_range = v_range 522 | 523 | if TYPE_REMOVAL == 'N2S': 524 | self.compater_removal = NodePairRemovalDecoder( 525 | n_heads, input_dim, removal_type 526 | ) 527 | if TYPE_REINSERTION == 'N2S': 528 | self.compater_reinsertion = NodePairReinsertionDecoder(n_heads, input_dim) 529 | 530 | self.project_graph = nn.Linear(self.input_dim, self.input_dim, bias=False) 531 | self.project_node = nn.Linear(self.input_dim, self.input_dim, bias=False) 532 | 533 | def init_parameters(self): 534 | 535 | for param in self.parameters(): 536 | stdv = 1.0 / math.sqrt(param.size(-1)) 537 | param.data.uniform_(-stdv, stdv) 538 | 539 | __call__: Callable[..., Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] 540 | 541 | def forward( 542 | self, 543 | problem: PDP, 544 | h_wave: torch.Tensor, 545 | solution: torch.Tensor, 546 | x_in: torch.Tensor, 547 | top2: torch.Tensor, 548 | visit_index: torch.Tensor, 549 | pre_action: Optional[torch.Tensor], 550 | selection_recent: torch.Tensor, 551 | fixed_action: Optional[torch.Tensor], 552 | require_entropy: bool, 553 | ): 554 | 555 | batch_size, graph_size_plus1, input_dim = h_wave.size() 556 | half_pos = (graph_size_plus1 - 1) // 2 557 | 558 | arange = torch.arange(batch_size) 559 | 560 | h_hat: torch.Tensor = self.project_node(h_wave) + self.project_graph( 561 | h_wave.max(1)[0] 562 | )[:, None, :].expand( 563 | batch_size, graph_size_plus1, input_dim 564 | ) # (11) 565 | 566 | ############# action1 removal 567 | if TYPE_REMOVAL == 'N2S': 568 | action_removal_table = ( 569 | torch.tanh( 570 | self.compater_removal(h_hat, solution, selection_recent).squeeze() 
571 | ) 572 | * self.v_range 573 | ) 574 | if pre_action is not None and pre_action[0, 0] > 0: 575 | action_removal_table[arange, pre_action[:, 0]] = -1e20 576 | log_ll_removal = ( 577 | F.log_softmax(action_removal_table, dim=-1) if self.training else None 578 | ) # log-likelihood 579 | probs_removal = F.softmax(action_removal_table, dim=-1) 580 | elif TYPE_REMOVAL == 'random': 581 | probs_removal = torch.rand(batch_size, graph_size_plus1 // 2).to( 582 | h_wave.device 583 | ) 584 | elif TYPE_REMOVAL == 'greedy': 585 | # epsilon-greedy 586 | first_row = ( 587 | torch.arange(graph_size_plus1, device=solution.device) 588 | .long() 589 | .unsqueeze(0) 590 | .expand(batch_size, graph_size_plus1) 591 | ) 592 | d_i = x_in.gather( 593 | 1, first_row.unsqueeze(-1).expand(batch_size, graph_size_plus1, 2) 594 | ) 595 | d_i_next = x_in.gather( 596 | 1, solution.long().unsqueeze(-1).expand(batch_size, graph_size_plus1, 2) 597 | ) 598 | d_i_pre = x_in.gather( 599 | 1, 600 | solution.argsort() 601 | .long() 602 | .unsqueeze(-1) 603 | .expand(batch_size, graph_size_plus1, 2), 604 | ) 605 | cost_ = ( 606 | (d_i_pre - d_i).norm(p=2, dim=2) 607 | + (d_i - d_i_next).norm(p=2, dim=2) 608 | - (d_i_pre - d_i_next).norm(p=2, dim=2) 609 | )[:, 1:] 610 | probs_removal = ( 611 | cost_[:, : graph_size_plus1 // 2] + cost_[:, graph_size_plus1 // 2 :] 612 | ) 613 | probs_removal_random = torch.rand(batch_size, graph_size_plus1 // 2).to( 614 | h_wave.device 615 | ) 616 | else: 617 | assert False 618 | 619 | if fixed_action is not None: 620 | action_removal = fixed_action[:, :1] 621 | else: 622 | if TYPE_REMOVAL == 'greedy': 623 | action_removal_random = probs_removal_random.multinomial(1) 624 | action_removal_greedy = probs_removal.max(-1)[1].unsqueeze(1) 625 | action_removal = torch.where( 626 | torch.rand(batch_size, 1).to(h_wave.device) < 0.1, 627 | action_removal_random, 628 | action_removal_greedy, 629 | ) 630 | elif TYPE_REMOVAL == 'N2S' or TYPE_REMOVAL == 'random': 631 | action_removal = probs_removal.multinomial(1) 632 | else: 633 | assert False 634 | selected_log_ll_action1 = ( 635 | log_ll_removal.gather(1, action_removal) # type: ignore 636 | if self.training and TYPE_REMOVAL == 'N2S' 637 | else torch.tensor(0).to(h_hat.device) 638 | ) 639 | 640 | ############# action2 641 | pos_pickup = (1 + action_removal).view(-1) 642 | pos_delivery = pos_pickup + half_pos 643 | mask_table = ( 644 | problem.get_swap_mask(action_removal + 1, visit_index, top2) 645 | .expand(batch_size, graph_size_plus1, graph_size_plus1) 646 | .cpu() 647 | ) 648 | if TYPE_REINSERTION == 'N2S': 649 | action_reinsertion_table = ( 650 | torch.tanh( 651 | self.compater_reinsertion(h_hat, pos_pickup, pos_delivery, solution) 652 | ) 653 | * self.v_range 654 | ) 655 | elif TYPE_REINSERTION == 'random': 656 | action_reinsertion_table = torch.ones( 657 | batch_size, graph_size_plus1, graph_size_plus1 658 | ).to(h_wave.device) 659 | elif TYPE_REINSERTION == 'greedy': 660 | # epsilon-greedy 661 | pos_pickup = 1 + action_removal 662 | pos_delivery = pos_pickup + half_pos 663 | rec_new = solution.clone() 664 | argsort = rec_new.argsort() 665 | pre_pairfirst = argsort.gather(1, pos_pickup) 666 | post_pairfirst = rec_new.gather(1, pos_pickup) 667 | rec_new.scatter_(1, pre_pairfirst, post_pairfirst) 668 | rec_new.scatter_(1, pos_pickup, pos_pickup) 669 | argsort = rec_new.argsort() 670 | pre_pairsecond = argsort.gather(1, pos_delivery) 671 | post_pairsecond = rec_new.gather(1, pos_delivery) 672 | rec_new.scatter_(1, pre_pairsecond, post_pairsecond) 673 | # perform
calc on new rec_new 674 | first_row = ( 675 | torch.arange(graph_size_plus1, device=solution.device) 676 | .long() 677 | .unsqueeze(0) 678 | .expand(batch_size, graph_size_plus1) 679 | ) 680 | d_i = x_in.gather( 681 | 1, first_row.unsqueeze(-1).expand(batch_size, graph_size_plus1, 2) 682 | ) 683 | d_i_next = x_in.gather( 684 | 1, rec_new.long().unsqueeze(-1).expand(batch_size, graph_size_plus1, 2) 685 | ) 686 | d_pick = x_in.gather( 687 | 1, pos_pickup.unsqueeze(1).expand(batch_size, graph_size_plus1, 2) 688 | ) 689 | d_deli = x_in.gather( 690 | 1, pos_delivery.unsqueeze(1).expand(batch_size, graph_size_plus1, 2) 691 | ) 692 | cost_insert_p = ( 693 | (d_pick - d_i).norm(p=2, dim=2) 694 | + (d_pick - d_i_next).norm(p=2, dim=2) 695 | - (d_i - d_i_next).norm(p=2, dim=2) 696 | ) 697 | cost_insert_d = ( 698 | (d_deli - d_i).norm(p=2, dim=2) 699 | + (d_deli - d_i_next).norm(p=2, dim=2) 700 | - (d_i - d_i_next).norm(p=2, dim=2) 701 | ) 702 | action_reinsertion_table = -( 703 | cost_insert_p.view(batch_size, graph_size_plus1, 1) 704 | + cost_insert_d.view(batch_size, 1, graph_size_plus1) 705 | ) 706 | action_reinsertion_table_random = torch.ones( 707 | batch_size, graph_size_plus1, graph_size_plus1 708 | ).to(h_wave.device) 709 | action_reinsertion_table_random[mask_table] = -1e20 710 | action_reinsertion_table_random = action_reinsertion_table_random.view( 711 | batch_size, -1 712 | ) 713 | probs_reinsertion_random = F.softmax( 714 | action_reinsertion_table_random, dim=-1 715 | ) 716 | else: 717 | assert False 718 | 719 | action_reinsertion_table[mask_table] = -1e20 720 | 721 | del visit_index, mask_table 722 | # reshape action_reinsertion_table 723 | action_reinsertion_table = action_reinsertion_table.view(batch_size, -1) 724 | log_ll_reinsertion = ( 725 | F.log_softmax(action_reinsertion_table, dim=-1) 726 | if self.training and TYPE_REINSERTION == 'N2S' 727 | else None 728 | ) 729 | probs_reinsertion = F.softmax(action_reinsertion_table, dim=-1) 730 | # fixed action 731 | if fixed_action is not None: 732 | p_selected = fixed_action[:, 1] 733 | d_selected = fixed_action[:, 2] 734 | pair_index = p_selected * graph_size_plus1 + d_selected 735 | pair_index = pair_index.view(-1, 1) 736 | action = fixed_action 737 | else: 738 | if TYPE_REINSERTION == 'greedy': 739 | action_reinsertion_random = probs_reinsertion_random.multinomial(1) 740 | action_reinsertion_greedy = probs_reinsertion.max(-1)[1].unsqueeze(1) 741 | pair_index = torch.where( 742 | torch.rand(batch_size, 1).to(h_wave.device) < 0.1, 743 | action_reinsertion_random, 744 | action_reinsertion_greedy, 745 | ) 746 | elif TYPE_REINSERTION == 'N2S' or TYPE_REINSERTION == 'random': 747 | # sample one action 748 | pair_index = probs_reinsertion.multinomial(1) 749 | else: 750 | assert False 751 | 752 | p_selected = pair_index // graph_size_plus1 753 | d_selected = pair_index % graph_size_plus1 754 | action = torch.cat( 755 | (action_removal.view(batch_size, -1), p_selected, d_selected), -1 756 | ) # batch_size, 3 757 | 758 | selected_log_ll_action2 = ( 759 | log_ll_reinsertion.gather(1, pair_index) # type: ignore 760 | if self.training and TYPE_REINSERTION == 'N2S' 761 | else torch.tensor(0).to(h_hat.device) 762 | ) 763 | 764 | log_ll = selected_log_ll_action1 + selected_log_ll_action2 765 | 766 | if require_entropy and self.training: 767 | dist = Categorical(probs_reinsertion, validate_args=False) 768 | entropy = dist.entropy() 769 | else: 770 | entropy = None 771 | 772 | return action, log_ll, entropy 773 | 774 | 775 | class Syn_Att(nn.Module): 
# (6) - (10) 776 | def __init__(self, n_heads: int, input_dim: int) -> None: 777 | super().__init__() 778 | 779 | hidden_dim = input_dim // n_heads 780 | 781 | self.n_heads = n_heads 782 | self.input_dim = input_dim 783 | self.hidden_dim = hidden_dim 784 | 785 | self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, hidden_dim)) 786 | self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, hidden_dim)) 787 | self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, hidden_dim)) 788 | 789 | self.score_aggr = nn.Sequential( 790 | nn.Linear(2 * n_heads, 2 * n_heads), 791 | nn.ReLU(inplace=True), 792 | nn.Linear(2 * n_heads, n_heads), 793 | ) 794 | 795 | self.W_out = nn.Parameter(torch.Tensor(n_heads, hidden_dim, input_dim)) 796 | 797 | self.init_parameters() 798 | 799 | def init_parameters(self) -> None: 800 | 801 | for param in self.parameters(): 802 | stdv = 1.0 / math.sqrt(param.size(-1)) 803 | param.data.uniform_(-stdv, stdv) 804 | 805 | __call__: Callable[..., torch.Tensor] 806 | 807 | def forward( 808 | self, h_fea: torch.Tensor, aux_att_score: torch.Tensor 809 | ) -> Tuple[torch.Tensor, torch.Tensor]: 810 | 811 | # h should be (batch_size, n_query, input_dim) 812 | batch_size, n_query, input_dim = h_fea.size() 813 | 814 | hflat = h_fea.contiguous().view(-1, input_dim) 815 | 816 | shp = (self.n_heads, batch_size, n_query, self.hidden_dim) 817 | 818 | # Calculate queries, (n_heads, batch_size, n_query, hidden_dim) 819 | Q = torch.matmul(hflat, self.W_query).view(shp) 820 | K = torch.matmul(hflat, self.W_key).view(shp) 821 | V = torch.matmul(hflat, self.W_val).view(shp) 822 | 823 | # Calculate compatibility (n_heads, batch_size, n_query, n_key) 824 | compatibility = torch.cat( 825 | (torch.matmul(Q, K.transpose(2, 3)), aux_att_score), 0 826 | ) 827 | 828 | attn_raw = compatibility.permute( 829 | 1, 2, 3, 0 830 | ) # (batch_size, n_query, n_key, n_heads) 831 | attn = self.score_aggr(attn_raw).permute( 832 | 3, 0, 1, 2 833 | ) # (n_heads, batch_size, n_query, n_key) 834 | heads = torch.matmul( 835 | F.softmax(attn, dim=-1), V 836 | ) # (n_heads, batch_size, n_query, hidden_dim) 837 | 838 | h_wave = torch.mm( 839 | heads.permute(1, 2, 0, 3) # (batch_size, n_query, n_heads, hidden_dim) 840 | .contiguous() 841 | .view( 842 | -1, self.n_heads * self.hidden_dim 843 | ), # (batch_size * n_query, n_heads * hidden_dim) 844 | self.W_out.view(-1, self.input_dim), # (n_heads * hidden_dim, input_dim) 845 | ).view(batch_size, n_query, self.input_dim) 846 | 847 | return h_wave, aux_att_score 848 | 849 | 850 | class Normalization(nn.Module): 851 | def __init__(self, input_dim: int, normalization: str) -> None: 852 | super().__init__() 853 | 854 | self.normalization = normalization 855 | 856 | if self.normalization != 'layer': 857 | normalizer_class = {'batch': nn.BatchNorm1d, 'instance': nn.InstanceNorm1d}[ 858 | normalization 859 | ] 860 | self.normalizer = normalizer_class(input_dim, affine=True) 861 | 862 | # Normalization by default initializes affine parameters with bias 0 and weight unif(0,1) which is too large! 
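# Note on the 'layer' option handled in forward() below: it normalizes each (n_query, input_dim) slab with a single mean/variance and applies no affine transform, which is not the same as nn.LayerNorm(input_dim); a rough comparison sketch (assumed equivalence, not used by the model) would be
#     F.layer_norm(input, input.shape[1:], eps=1e-5)
# up to the unbiased variance estimator used here.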
863 | # self.init_parameters() 864 | 865 | def init_parameters(self) -> None: 866 | 867 | for name, param in self.named_parameters(): 868 | stdv = 1.0 / math.sqrt(param.size(-1)) 869 | param.data.uniform_(-stdv, stdv) 870 | 871 | __call__: Callable[..., torch.Tensor] 872 | 873 | def forward(self, input: torch.Tensor) -> torch.Tensor: 874 | if self.normalization == 'layer': 875 | return (input - input.mean((1, 2)).view(-1, 1, 1)) / torch.sqrt( 876 | input.var((1, 2)).view(-1, 1, 1) + 1e-05 877 | ) 878 | elif self.normalization == 'batch': 879 | return self.normalizer(input.view(-1, input.size(-1))).view(*input.size()) 880 | elif self.normalization == 'instance': 881 | return self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1) 882 | else: 883 | assert False, "Unknown normalizer type" 884 | 885 | 886 | class SynAttNormSubLayer(nn.Module): 887 | def __init__(self, n_heads: int, input_dim: int, normalization: str) -> None: 888 | super().__init__() 889 | 890 | self.SynAtt = Syn_Att(n_heads, input_dim) 891 | 892 | self.Norm = Normalization(input_dim, normalization) 893 | 894 | __call__: Callable[..., Tuple[torch.Tensor, torch.Tensor]] 895 | 896 | def forward( 897 | self, h_fea: torch.Tensor, aux_att_score: torch.Tensor 898 | ) -> Tuple[torch.Tensor, torch.Tensor]: 899 | # Attention and Residual connection 900 | h_wave, aux_att_score = self.SynAtt(h_fea, aux_att_score) 901 | 902 | # Normalization 903 | return self.Norm(h_wave + h_fea), aux_att_score 904 | 905 | 906 | class FFNormSubLayer(nn.Module): 907 | def __init__( 908 | self, input_dim: int, feed_forward_hidden: int, normalization: str 909 | ) -> None: 910 | super().__init__() 911 | 912 | self.FF = ( 913 | nn.Sequential( 914 | nn.Linear(input_dim, feed_forward_hidden, bias=False), 915 | nn.ReLU(inplace=True), 916 | nn.Linear(feed_forward_hidden, input_dim, bias=False), 917 | ) 918 | if feed_forward_hidden > 0 919 | else nn.Linear(input_dim, input_dim, bias=False) 920 | ) 921 | 922 | self.Norm = Normalization(input_dim, normalization) 923 | 924 | __call__: Callable[..., torch.Tensor] 925 | 926 | def forward(self, input: torch.Tensor) -> torch.Tensor: 927 | 928 | # FF and Residual connection 929 | out = self.FF(input) 930 | # Normalization 931 | return self.Norm(out + input) 932 | 933 | 934 | class N2SEncoder(nn.Module): 935 | def __init__( 936 | self, n_heads: int, input_dim: int, feed_forward_hidden: int, normalization: str 937 | ) -> None: 938 | super().__init__() 939 | 940 | self.SynAttNorm_sublayer = SynAttNormSubLayer(n_heads, input_dim, normalization) 941 | 942 | self.FFNorm_sublayer = FFNormSubLayer( 943 | input_dim, feed_forward_hidden, normalization 944 | ) 945 | 946 | __call__: Callable[..., Tuple[torch.Tensor, torch.Tensor]] 947 | 948 | def forward( 949 | self, h_fea: torch.Tensor, aux_att_score: torch.Tensor 950 | ) -> Tuple[torch.Tensor, torch.Tensor]: 951 | h_wave, aux_att_score = self.SynAttNorm_sublayer(h_fea, aux_att_score) 952 | return self.FFNorm_sublayer(h_wave), aux_att_score 953 | 954 | 955 | class EmbeddingNet(nn.Module): 956 | def __init__( 957 | self, node_dim: int, embedding_dim: int, seq_length: int, embedding_type: str 958 | ) -> None: 959 | super().__init__() 960 | self.node_dim = node_dim 961 | self.embedding_dim = embedding_dim 962 | 963 | if embedding_type == 'origin': 964 | self.feature_embedder = nn.Linear(node_dim, embedding_dim, bias=False) 965 | elif embedding_type == 'pair': 966 | self.feature_embedder = HeterEmbedding(node_dim, embedding_dim) # type: ignore 967 | elif embedding_type == 'share' or 
embedding_type == 'together': 968 | self.feature_embedder = None # type: ignore 969 | elif embedding_type == 'sep': 970 | self.feature_embedder = SepEmbedding(node_dim, embedding_dim) # type: ignore 971 | else: 972 | raise NotImplementedError 973 | 974 | self.pattern = self._cyclic_position_embedding_pattern( 975 | seq_length, embedding_dim 976 | ) 977 | 978 | self.init_parameters() 979 | 980 | def init_parameters(self) -> None: 981 | 982 | for param in self.parameters(): 983 | stdv = 1.0 / math.sqrt(param.size(-1)) 984 | param.data.uniform_(-stdv, stdv) 985 | 986 | def _base_sin(self, x: np.ndarray, omiga: float, fai: float = 0) -> np.ndarray: 987 | T = 2 * np.pi / omiga 988 | return np.sin(omiga * np.abs(np.mod(x, 2 * T) - T) + fai) 989 | 990 | def _base_cos(self, x: np.ndarray, omiga: float, fai: float = 0) -> np.ndarray: 991 | T = 2 * np.pi / omiga 992 | return np.cos(omiga * np.abs(np.mod(x, 2 * T) - T) + fai) 993 | 994 | def _cyclic_position_embedding_pattern( 995 | self, seq_length: int, embedding_dim: int, mean_pooling: bool = True 996 | ) -> torch.Tensor: 997 | 998 | Td_base = np.power(seq_length, 1 / (embedding_dim // 2)) 999 | Td_set = np.linspace(Td_base, seq_length, embedding_dim // 2, dtype='int') 1000 | g = np.zeros((seq_length, embedding_dim)) 1001 | 1002 | for d in range(embedding_dim): 1003 | Td = ( 1004 | Td_set[d // 3 * 3 + 1] 1005 | if (d // 3 * 3 + 1) < (embedding_dim // 2) 1006 | else Td_set[-1] 1007 | ) # (4) 1008 | 1009 | # get z(i) in the paper (via longer_pattern) 1010 | longer_pattern = np.arange(0, np.ceil(seq_length / Td) * Td, 0.01) 1011 | 1012 | num = len(longer_pattern) 1013 | omiga = 2 * np.pi / Td 1014 | fai = ( 1015 | 0 1016 | if d <= (embedding_dim // 2) 1017 | else 2 * np.pi * ((-d + (embedding_dim // 2)) / (embedding_dim // 2)) 1018 | ) 1019 | 1020 | # Eq. 
(2) in the paper 1021 | if d % 2 == 1: 1022 | g[:, d] = self._base_cos(longer_pattern, omiga, fai)[ 1023 | np.linspace(0, num, seq_length, dtype='int', endpoint=False) 1024 | ] 1025 | else: 1026 | g[:, d] = self._base_sin(longer_pattern, omiga, fai)[ 1027 | np.linspace(0, num, seq_length, dtype='int', endpoint=False) 1028 | ] 1029 | 1030 | pattern = torch.from_numpy(g).float() 1031 | pattern_sum = torch.zeros_like(pattern) 1032 | 1033 | # averaging the adjacent embeddings if needed (optional, almost the same performance) 1034 | arange = torch.arange(seq_length) 1035 | pooling = [0] if not mean_pooling else [-2, -1, 0, 1, 2] 1036 | time = 0 1037 | for d in pooling: 1038 | time += 1 1039 | index = (arange + d + seq_length) % seq_length 1040 | pattern_sum += pattern.gather(0, index.view(-1, 1).expand_as(pattern)) 1041 | pattern = 1.0 / time * pattern_sum - pattern.mean(0) 1042 | #### ---- 1043 | 1044 | return pattern # (seq_length, embedding_dim) 1045 | 1046 | def _position_embedding( 1047 | self, solution: torch.Tensor, embedding_dim: int, calc_stacks: bool 1048 | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 1049 | batch_size, seq_length = solution.size() 1050 | half_size = seq_length // 2 1051 | 1052 | # expand for every batch 1053 | position_emb_new = ( 1054 | self.pattern.expand(batch_size, seq_length, embedding_dim) 1055 | .clone() 1056 | .to(solution.device) 1057 | ) 1058 | 1059 | # get index according to the solutions 1060 | visit_index = torch.zeros((batch_size, seq_length), device=solution.device) 1061 | 1062 | pre = torch.zeros((batch_size), device=solution.device).long() 1063 | 1064 | arange = torch.arange(batch_size) 1065 | if calc_stacks: 1066 | stacks = ( 1067 | torch.zeros(batch_size, half_size + 1, device=solution.device) - 0.01 1068 | ) # fix bug: topk is not stable sorting 1069 | top2 = torch.zeros(batch_size, seq_length, 2, device=solution.device).long() 1070 | stacks[arange, pre] = 0 # fix bug: topk is not stable sorting 1071 | 1072 | for i in range(seq_length): 1073 | current_nodes = solution[arange, pre] # (batch_size,) 1074 | visit_index[arange, current_nodes] = i + 1 1075 | pre = current_nodes 1076 | 1077 | if calc_stacks: 1078 | index1 = (current_nodes <= half_size) & (current_nodes > 0) 1079 | index2 = (current_nodes > half_size) & (current_nodes > 0) 1080 | if index1.any(): 1081 | stacks[index1, current_nodes[index1]] = i + 1 1082 | if index2.any(): 1083 | stacks[ 1084 | index2, current_nodes[index2] - half_size 1085 | ] = -0.01 # fix bug: topk is not stable sorting 1086 | top2[arange, current_nodes] = stacks.topk(2)[1] 1087 | # stack top after visit 1088 | # node+, (current_stack_top, last_stack_top_or_0) 1089 | # node-, (current_stack_top, last_stack_top_or_0) or (0, 1_meaningless) 1090 | 1091 | index = ( 1092 | (visit_index % seq_length) 1093 | .long() 1094 | .unsqueeze(-1) 1095 | .expand(batch_size, seq_length, embedding_dim) 1096 | ) 1097 | 1098 | return ( 1099 | torch.gather(position_emb_new, 1, index), 1100 | (visit_index % seq_length).long(), 1101 | top2 if calc_stacks else None, 1102 | ) 1103 | 1104 | __call__: Callable[ 1105 | ..., Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] 1106 | ] 1107 | 1108 | def forward( 1109 | self, x: torch.Tensor, solution: Optional[torch.Tensor], calc_stacks: bool 1110 | ): 1111 | if self.feature_embedder is None: 1112 | fea_emb = None 1113 | else: 1114 | fea_emb = self.feature_embedder(x) 1115 | 1116 | if solution is None: 1117 | return fea_emb, None, None, None 1118 | 1119 | pos_emb,
visit_index, top2 = self._position_embedding( 1120 | solution, self.embedding_dim, calc_stacks 1121 | ) 1122 | return fea_emb, pos_emb, visit_index, top2 1123 | 1124 | 1125 | class CriticEncoder(nn.Sequential): 1126 | def __init__( 1127 | self, n_heads: int, input_dim: int, feed_forward_hidden: int, normalization: str 1128 | ) -> None: 1129 | super().__init__( 1130 | SkipConnection(MultiHeadSelfAttention(n_heads, input_dim)), 1131 | Normalization(input_dim, normalization), 1132 | SkipConnection( 1133 | nn.Sequential( 1134 | nn.Linear(input_dim, feed_forward_hidden), 1135 | nn.ReLU(inplace=True), 1136 | nn.Linear( 1137 | feed_forward_hidden, 1138 | input_dim, 1139 | ), 1140 | ) 1141 | if feed_forward_hidden > 0 1142 | else nn.Linear(input_dim, input_dim) 1143 | ), 1144 | Normalization(input_dim, normalization), 1145 | ) 1146 | 1147 | 1148 | class ConstructEncoder(nn.Module): 1149 | def __init__( 1150 | self, 1151 | n_heads: int, 1152 | input_dim: int, 1153 | normalization: str, 1154 | attn_type: str, 1155 | ) -> None: 1156 | super().__init__() 1157 | 1158 | if attn_type == 'typical': 1159 | self.MHA = SkipConnection(MultiHeadSelfAttention(n_heads, input_dim)) 1160 | elif attn_type == 'heter': 1161 | self.MHA = SkipConnection(HeterAttention(n_heads, input_dim)) 1162 | else: 1163 | raise NotImplementedError 1164 | self.norm = Normalization(input_dim, normalization) 1165 | self.FFnorm = FFNormSubLayer(input_dim, 512, normalization) 1166 | 1167 | __call__: Callable[..., torch.Tensor] 1168 | 1169 | def forward(self, h_fea: torch.Tensor) -> torch.Tensor: 1170 | return self.FFnorm(self.norm(self.MHA(h_fea))) 1171 | 1172 | 1173 | class ConstructDecoder(nn.Module): 1174 | def __init__( 1175 | self, n_heads: int, input_dim: int, stack_is_lifo: bool, type_select: str 1176 | ) -> None: 1177 | super().__init__() 1178 | 1179 | self.C = 10 1180 | self.stack_is_lifo = stack_is_lifo 1181 | self.type_select = type_select 1182 | 1183 | self.first_MHA = MultiHeadAttention( 1184 | n_heads, 2 * input_dim, input_dim, input_dim, input_dim 1185 | ) 1186 | self.second_SHA_score = MultiHeadAttention( 1187 | 1, input_dim, input_dim, None, input_dim 1188 | ) 1189 | 1190 | __call__: Callable[..., Tuple[torch.Tensor, torch.Tensor]] 1191 | 1192 | def forward( 1193 | self, 1194 | h_fea: torch.Tensor, 1195 | h_mean: torch.Tensor, 1196 | part_sol: torch.Tensor, 1197 | init_sol: torch.Tensor, 1198 | step: int, 1199 | stack: torch.Tensor, 1200 | direct_fixed_sol: Optional[torch.Tensor], 1201 | temperature: float, 1202 | ) -> Tuple[torch.Tensor, torch.Tensor]: 1203 | batch_size, graph_size_plus1, _ = h_fea.size() 1204 | half_size = graph_size_plus1 // 2 1205 | arange = torch.arange(batch_size) 1206 | 1207 | last_step = torch.argwhere(part_sol == 0)[:, 1] 1208 | context_emb = torch.cat((h_mean, h_fea[arange, last_step, :]), -1).unsqueeze(1) 1209 | 1210 | hc = self.first_MHA(context_emb, h_fea, h_fea) 1211 | uc = ( 1212 | torch.tanh(self.second_SHA_score(hc, h_fea, with_norm=True)) * self.C 1213 | ).view(batch_size, -1) 1214 | uc /= temperature 1215 | 1216 | mask = self._get_mask(part_sol, init_sol, stack) 1217 | uc[mask] = -1e20 1218 | 1219 | prob = F.softmax(uc, dim=-1) 1220 | log_p = F.log_softmax(uc, dim=-1) 1221 | 1222 | if direct_fixed_sol is None: 1223 | if self.type_select == 'greedy': 1224 | next_node = prob.max(1)[1] # (batch_size,) 1225 | elif self.type_select == 'sample': 1226 | next_node = prob.multinomial(1).view( 1227 | -1 1228 | ) # (batch_size, 1) -> (batch_size,) 1229 | else: 1230 | raise NotImplementedError 
1231 | else: 1232 | next_node = direct_fixed_sol[:, step + 1] 1233 | 1234 | part_sol[arange, last_step] = next_node 1235 | part_sol[arange, next_node] = 0 1236 | 1237 | sel_log_p = log_p[arange, next_node] 1238 | 1239 | index_in = (next_node <= half_size) & (next_node > 0) 1240 | index_out = (next_node > half_size) & (next_node > 0) 1241 | if index_in.any(): 1242 | stack[index_in, next_node[index_in]] = step + 1 1243 | if index_out.any(): 1244 | stack[index_out, next_node[index_out] - half_size] = -1 1245 | 1246 | return part_sol, sel_log_p 1247 | 1248 | def _get_mask( 1249 | self, 1250 | part_sol: torch.Tensor, 1251 | init_sol: torch.Tensor, 1252 | stack: torch.Tensor, 1253 | ) -> torch.Tensor: 1254 | arange = torch.arange(stack.size(0)).to(stack.device) 1255 | half_size = stack.size(1) - 1 1256 | 1257 | mask = (part_sol == 0) | (part_sol != init_sol) 1258 | 1259 | if not self.stack_is_lifo: 1260 | wait_to_pick = torch.argwhere(stack < 0) 1261 | wait_to_pick[:, 1] += half_size 1262 | mask[wait_to_pick[:, 0], wait_to_pick[:, 1]] = True 1263 | else: 1264 | stack_top = stack.max(1)[1] 1265 | stack_top[stack_top != 0] += half_size 1266 | mask[:, half_size + 1 :] = True 1267 | mask[arange[stack_top != 0], stack_top[stack_top != 0]] = False 1268 | 1269 | return mask 1270 | 1271 | 1272 | class SepEmbedding(nn.Module): 1273 | def __init__(self, node_dim: int, embedding_dim: int) -> None: 1274 | super().__init__() 1275 | 1276 | self.embed_depot = nn.Linear(node_dim, embedding_dim) 1277 | self.embed_pickup = nn.Linear(node_dim, embedding_dim) 1278 | self.embed_delivery = nn.Linear(node_dim, embedding_dim) 1279 | 1280 | __call__: Callable[..., torch.Tensor] 1281 | 1282 | def forward(self, x_in: torch.Tensor) -> torch.Tensor: 1283 | graph_size_plus1 = x_in.size(1) 1284 | half_size = graph_size_plus1 // 2 1285 | depot_x = x_in[:, 0:1, :] 1286 | delivery_x = x_in[:, half_size + 1 :, :] 1287 | pickup_x = x_in[:, 1 : half_size + 1, :] 1288 | 1289 | return torch.cat( 1290 | [ 1291 | self.embed_depot(depot_x), 1292 | self.embed_pickup(pickup_x), 1293 | self.embed_delivery(delivery_x), 1294 | ], 1295 | 1, 1296 | ) 1297 | 1298 | 1299 | class HeterEmbedding(nn.Module): 1300 | def __init__(self, node_dim: int, embedding_dim: int) -> None: 1301 | super().__init__() 1302 | 1303 | self.embed_depot = nn.Linear(node_dim, embedding_dim) 1304 | self.embed_pickup = nn.Linear(node_dim * 2, embedding_dim) 1305 | self.embed_delivery = nn.Linear(node_dim, embedding_dim) 1306 | 1307 | __call__: Callable[..., torch.Tensor] 1308 | 1309 | def forward(self, x_in: torch.Tensor) -> torch.Tensor: 1310 | graph_size_plus1 = x_in.size(1) 1311 | half_size = graph_size_plus1 // 2 1312 | depot_x = x_in[:, 0:1, :] 1313 | delivery_x = x_in[:, half_size + 1 :, :] 1314 | pickup_pair = torch.cat([x_in[:, 1 : half_size + 1, :], delivery_x], 2) 1315 | 1316 | return torch.cat( 1317 | [ 1318 | self.embed_depot(depot_x), 1319 | self.embed_pickup(pickup_pair), 1320 | self.embed_delivery(delivery_x), 1321 | ], 1322 | 1, 1323 | ) 1324 | 1325 | 1326 | class HeterAttention(nn.Module): 1327 | # https://github.com/Demon0312/Heterogeneous-Attentions-PDP-DRL/blob/main/nets/graph_encoder.py 1328 | # without refactor 1329 | def __init__(self, n_heads: int, input_dim: int) -> None: 1330 | super().__init__() 1331 | 1332 | # start 1333 | embed_dim = input_dim 1334 | val_dim = None 1335 | key_dim = None 1336 | # end 1337 | 1338 | if val_dim is None: 1339 | assert embed_dim is not None, "Provide either embed_dim or val_dim" 1340 | val_dim = embed_dim // n_heads 
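# The six query projections defined below (W1_query .. W6_query) realize the six heterogeneous relations used in forward(): pickup -> its paired delivery (W1), pickup -> all pickups (W2), pickup -> all deliveries (W3), delivery -> its paired pickup (W4), delivery -> all deliveries (W5), delivery -> all pickups (W6); keys and values for all of them reuse the shared W_key / W_val.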
1341 | if key_dim is None: 1342 | key_dim = val_dim 1343 | 1344 | self.n_heads = n_heads 1345 | self.input_dim = input_dim 1346 | self.embed_dim = embed_dim 1347 | self.val_dim = val_dim 1348 | self.key_dim = key_dim 1349 | 1350 | self.norm_factor = 1 / math.sqrt(key_dim) # See Attention is all you need 1351 | 1352 | self.W_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1353 | self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1354 | self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim)) 1355 | 1356 | # pickup 1357 | self.W1_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1358 | # self.W1_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1359 | # self.W1_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim)) 1360 | 1361 | self.W2_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1362 | # self.W2_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1363 | # self.W2_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim)) 1364 | 1365 | self.W3_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1366 | # self.W3_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1367 | # self.W3_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim)) 1368 | 1369 | # delivery 1370 | self.W4_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1371 | # self.W4_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1372 | # self.W4_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim)) 1373 | 1374 | self.W5_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1375 | # self.W5_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1376 | # self.W5_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim)) 1377 | 1378 | self.W6_query = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1379 | # self.W6_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim)) 1380 | # self.W6_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim)) 1381 | 1382 | if embed_dim is not None: 1383 | self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim)) 1384 | 1385 | self.init_parameters() 1386 | 1387 | def init_parameters(self) -> None: 1388 | 1389 | for param in self.parameters(): 1390 | stdv = 1.0 / math.sqrt(param.size(-1)) 1391 | param.data.uniform_(-stdv, stdv) 1392 | 1393 | __call__: Callable[..., torch.Tensor] 1394 | 1395 | def forward( 1396 | self, 1397 | q: torch.Tensor, 1398 | h: Optional[torch.Tensor] = None, 1399 | mask: Optional[torch.Tensor] = None, 1400 | ) -> torch.Tensor: 1401 | """ 1402 | :param q: queries (batch_size, n_query, input_dim) 1403 | :param h: data (batch_size, graph_size, input_dim) 1404 | :param mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1) 1405 | Mask should contain 1 if attention is not possible (i.e. 
mask is negative adjacency) 1406 | :return: 1407 | """ 1408 | if h is None: 1409 | h = q # compute self-attention 1410 | 1411 | # h should be (batch_size, graph_size, input_dim) 1412 | batch_size, graph_size, input_dim = h.size() 1413 | n_query = q.size(1) 1414 | assert q.size(0) == batch_size 1415 | assert q.size(2) == input_dim 1416 | assert input_dim == self.input_dim, "Wrong embedding dimension of input" 1417 | 1418 | hflat = h.contiguous().view( 1419 | -1, input_dim 1420 | ) # [batch_size * graph_size, embed_dim] 1421 | qflat = q.contiguous().view(-1, input_dim) # [batch_size * n_query, embed_dim] 1422 | 1423 | # last dimension can be different for keys and values 1424 | shp = (self.n_heads, batch_size, graph_size, -1) 1425 | shp_q = (self.n_heads, batch_size, n_query, -1) 1426 | 1427 | # pickup -> its delivery attention 1428 | n_pick = (graph_size - 1) // 2 1429 | shp_delivery = (self.n_heads, batch_size, n_pick, -1) 1430 | shp_q_pick = (self.n_heads, batch_size, n_pick, -1) 1431 | 1432 | # pickup -> all pickups attention 1433 | shp_allpick = (self.n_heads, batch_size, n_pick, -1) 1434 | shp_q_allpick = (self.n_heads, batch_size, n_pick, -1) 1435 | 1436 | # pickup -> all deliveries attention 1437 | shp_alldelivery = (self.n_heads, batch_size, n_pick, -1) 1438 | shp_q_alldelivery = (self.n_heads, batch_size, n_pick, -1) 1439 | 1440 | # Calculate queries, (n_heads, batch_size, n_query, key/val_size) 1441 | Q = torch.matmul(qflat, self.W_query).view(shp_q) 1442 | # Calculate keys and values (n_heads, batch_size, graph_size, key/val_size) 1443 | K = torch.matmul(hflat, self.W_key).view(shp) 1444 | V = torch.matmul(hflat, self.W_val).view(shp) 1445 | 1446 | # pickup -> its delivery 1447 | pick_flat = ( 1448 | h[:, 1 : n_pick + 1, :].contiguous().view(-1, input_dim) 1449 | ) # [batch_size * n_pick, embed_dim] 1450 | delivery_flat = ( 1451 | h[:, n_pick + 1 :, :].contiguous().view(-1, input_dim) 1452 | ) # [batch_size * n_pick, embed_dim] 1453 | 1454 | # pickup -> its delivery attention 1455 | Q_pick = torch.matmul(pick_flat, self.W1_query).view( 1456 | shp_q_pick 1457 | ) # (self.n_heads, batch_size, n_pick, key_size) 1458 | K_delivery = torch.matmul(delivery_flat, self.W_key).view( 1459 | shp_delivery 1460 | ) # (self.n_heads, batch_size, n_pick, -1) 1461 | V_delivery = torch.matmul(delivery_flat, self.W_val).view( 1462 | shp_delivery 1463 | ) # (n_heads, batch_size, n_pick, key/val_size) 1464 | 1465 | # pickup -> all pickups attention 1466 | Q_pick_allpick = torch.matmul(pick_flat, self.W2_query).view( 1467 | shp_q_allpick 1468 | ) # (self.n_heads, batch_size, n_pick, -1) 1469 | K_allpick = torch.matmul(pick_flat, self.W_key).view( 1470 | shp_allpick 1471 | ) # [self.n_heads, batch_size, n_pick, key_size] 1472 | V_allpick = torch.matmul(pick_flat, self.W_val).view( 1473 | shp_allpick 1474 | ) # [self.n_heads, batch_size, n_pick, key_size] 1475 | 1476 | # pickup -> all delivery 1477 | Q_pick_alldelivery = torch.matmul(pick_flat, self.W3_query).view( 1478 | shp_q_alldelivery 1479 | ) # (self.n_heads, batch_size, n_pick, key_size) 1480 | K_alldelivery = torch.matmul(delivery_flat, self.W_key).view( 1481 | shp_alldelivery 1482 | ) # (self.n_heads, batch_size, n_pick, -1) 1483 | V_alldelivery = torch.matmul(delivery_flat, self.W_val).view( 1484 | shp_alldelivery 1485 | ) # (n_heads, batch_size, n_pick, key/val_size) 1486 | 1487 | # pickup -> its delivery 1488 | V_additional_delivery = torch.cat( 1489 | [ # [n_heads, batch_size, graph_size, key_size] 1490 | torch.zeros( 1491 | self.n_heads, 1492
| batch_size, 1493 | 1, 1494 | self.input_dim // self.n_heads, 1495 | dtype=V.dtype, 1496 | device=V.device, 1497 | ), 1498 | V_delivery, # [n_heads, batch_size, n_pick, key/val_size] 1499 | torch.zeros( 1500 | self.n_heads, 1501 | batch_size, 1502 | n_pick, 1503 | self.input_dim // self.n_heads, 1504 | dtype=V.dtype, 1505 | device=V.device, 1506 | ), 1507 | ], 1508 | 2, 1509 | ) 1510 | 1511 | # delivery -> its pickup attention 1512 | Q_delivery = torch.matmul(delivery_flat, self.W4_query).view( 1513 | shp_delivery 1514 | ) # (self.n_heads, batch_size, n_pick, key_size) 1515 | K_pick = torch.matmul(pick_flat, self.W_key).view( 1516 | shp_q_pick 1517 | ) # (self.n_heads, batch_size, n_pick, -1) 1518 | V_pick = torch.matmul(pick_flat, self.W_val).view( 1519 | shp_q_pick 1520 | ) # (n_heads, batch_size, n_pick, key/val_size) 1521 | 1522 | # delivery -> all delivery attention 1523 | Q_delivery_alldelivery = torch.matmul(delivery_flat, self.W5_query).view( 1524 | shp_alldelivery 1525 | ) # (self.n_heads, batch_size, n_pick, -1) 1526 | K_alldelivery2 = torch.matmul(delivery_flat, self.W_key).view( 1527 | shp_alldelivery 1528 | ) # [self.n_heads, batch_size, n_pick, key_size] 1529 | V_alldelivery2 = torch.matmul(delivery_flat, self.W_val).view( 1530 | shp_alldelivery 1531 | ) # [self.n_heads, batch_size, n_pick, key_size] 1532 | 1533 | # delivery -> all pickup 1534 | Q_delivery_allpickup = torch.matmul(delivery_flat, self.W6_query).view( 1535 | shp_alldelivery 1536 | ) # (self.n_heads, batch_size, n_pick, key_size) 1537 | K_allpickup2 = torch.matmul(pick_flat, self.W_key).view( 1538 | shp_q_alldelivery 1539 | ) # (self.n_heads, batch_size, n_pick, -1) 1540 | V_allpickup2 = torch.matmul(pick_flat, self.W_val).view( 1541 | shp_q_alldelivery 1542 | ) # (n_heads, batch_size, n_pick, key/val_size) 1543 | 1544 | # delivery -> its pick up 1545 | # V_additional_pick = torch.cat([ # [n_heads, batch_size, graph_size, key_size] 1546 | # torch.zeros(self.n_heads, batch_size, 1, self.input_dim // self.n_heads, dtype=V.dtype, device=V.device), 1547 | # V_delivery2, # [n_heads, batch_size, n_pick, key/val_size] 1548 | # torch.zeros(self.n_heads, batch_size, n_pick, self.input_dim // self.n_heads, dtype=V.dtype, device=V.device) 1549 | # ], 2) 1550 | V_additional_pick = torch.cat( 1551 | [ # [n_heads, batch_size, graph_size, key_size] 1552 | torch.zeros( 1553 | self.n_heads, 1554 | batch_size, 1555 | 1, 1556 | self.input_dim // self.n_heads, 1557 | dtype=V.dtype, 1558 | device=V.device, 1559 | ), 1560 | torch.zeros( 1561 | self.n_heads, 1562 | batch_size, 1563 | n_pick, 1564 | self.input_dim // self.n_heads, 1565 | dtype=V.dtype, 1566 | device=V.device, 1567 | ), 1568 | V_pick, # [n_heads, batch_size, n_pick, key/val_size] 1569 | ], 1570 | 2, 1571 | ) 1572 | 1573 | # Calculate compatibility (n_heads, batch_size, n_query, graph_size) 1574 | compatibility = self.norm_factor * torch.matmul(Q, K.transpose(2, 3)) 1575 | 1576 | ##Pick up 1577 | # attention between each pickup and its paired delivery
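# Column layout of the widened score matrix assembled below (derived from this forward pass, where h defaults to q so n_query == graph_size):
#     [0 : graph_size]                                  ordinary node-to-node attention
#     [graph_size]                                      pickup -> its paired delivery
#     [graph_size+1 : graph_size+1+n_pick]              pickup -> all pickups
#     [graph_size+1+n_pick : graph_size+1+2*n_pick]     pickup -> all deliveries
#     [graph_size+1+2*n_pick]                           delivery -> its paired pickup
#     [graph_size+2+2*n_pick : graph_size+2+3*n_pick]   delivery -> all deliveries
#     [graph_size+2+3*n_pick :]                         delivery -> all pickups
# giving graph_size + 2 + 4*n_pick columns in total; entries where a relation does not apply are filled with -np.inf before the softmax.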
1578 | compatibility_pick_delivery = self.norm_factor * torch.sum( 1579 | Q_pick * K_delivery, -1 1580 | ) # element_wise, [n_heads, batch_size, n_pick] 1581 | # [n_heads, batch_size, n_pick, n_pick] 1582 | compatibility_pick_allpick = self.norm_factor * torch.matmul( 1583 | Q_pick_allpick, K_allpick.transpose(2, 3) 1584 | ) # [n_heads, batch_size, n_pick, n_pick] 1585 | 1586 | compatibility_pick_alldelivery = self.norm_factor * torch.matmul( 1587 | Q_pick_alldelivery, K_alldelivery.transpose(2, 3) 1588 | ) # [n_heads, batch_size, n_pick, n_pick] 1589 | 1590 | ##Delivery 1591 | compatibility_delivery_pick = self.norm_factor * torch.sum( 1592 | Q_delivery * K_pick, -1 1593 | ) # element_wise, [n_heads, batch_size, n_pick] 1594 | 1595 | compatibility_delivery_alldelivery = self.norm_factor * torch.matmul( 1596 | Q_delivery_alldelivery, K_alldelivery2.transpose(2, 3) 1597 | ) # [n_heads, batch_size, n_pick, n_pick] 1598 | 1599 | compatibility_delivery_allpick = self.norm_factor * torch.matmul( 1600 | Q_delivery_allpickup, K_allpickup2.transpose(2, 3) 1601 | ) # [n_heads, batch_size, n_pick, n_pick] 1602 | 1603 | ##Pick up-> 1604 | # compatibility_additional_delivery: each pickup's attention to its own delivery (one extra score per node); only rows 1:n_pick+1 (the pickups) get a finite score, the depot and delivery rows are masked with -inf 1605 | compatibility_additional_delivery = torch.cat( 1606 | [ # [n_heads, batch_size, graph_size, 1] 1607 | -np.inf 1608 | * torch.ones( 1609 | self.n_heads, 1610 | batch_size, 1611 | 1, 1612 | dtype=compatibility.dtype, 1613 | device=compatibility.device, 1614 | ), 1615 | compatibility_pick_delivery, # [n_heads, batch_size, n_pick] 1616 | -np.inf 1617 | * torch.ones( 1618 | self.n_heads, 1619 | batch_size, 1620 | n_pick, 1621 | dtype=compatibility.dtype, 1622 | device=compatibility.device, 1623 | ), 1624 | ], 1625 | -1, 1626 | ).view(self.n_heads, batch_size, graph_size, 1) 1627 | 1628 | compatibility_additional_allpick = torch.cat( 1629 | [ # [n_heads, batch_size, graph_size, n_pick] 1630 | -np.inf 1631 | * torch.ones( 1632 | self.n_heads, 1633 | batch_size, 1634 | 1, 1635 | n_pick, 1636 | dtype=compatibility.dtype, 1637 | device=compatibility.device, 1638 | ), 1639 | compatibility_pick_allpick, # [n_heads, batch_size, n_pick, n_pick] 1640 | -np.inf 1641 | * torch.ones( 1642 | self.n_heads, 1643 | batch_size, 1644 | n_pick, 1645 | n_pick, 1646 | dtype=compatibility.dtype, 1647 | device=compatibility.device, 1648 | ), 1649 | ], 1650 | 2, 1651 | ).view(self.n_heads, batch_size, graph_size, n_pick) 1652 | 1653 | compatibility_additional_alldelivery = torch.cat( 1654 | [ # [n_heads, batch_size, graph_size, n_pick] 1655 | -np.inf 1656 | * torch.ones( 1657 | self.n_heads, 1658 | batch_size, 1659 | 1, 1660 | n_pick, 1661 | dtype=compatibility.dtype, 1662 | device=compatibility.device, 1663 | ), 1664 | compatibility_pick_alldelivery, # [n_heads, batch_size, n_pick, n_pick] 1665 | -np.inf 1666 | * torch.ones( 1667 | self.n_heads, 1668 | batch_size, 1669 | n_pick, 1670 | n_pick, 1671 | dtype=compatibility.dtype, 1672 | device=compatibility.device, 1673 | ), 1674 | ], 1675 | 2, 1676 | ).view(self.n_heads, batch_size, graph_size, n_pick) 1677 | # [n_heads, batch_size, n_query, graph_size+1+n_pick+n_pick] 1678 | 1679 | ##Delivery-> 1680 | compatibility_additional_pick = torch.cat( 1681 | [ # [n_heads, batch_size, graph_size, 1] 1682 | -np.inf 1683 | * torch.ones( 1684 | self.n_heads, 1685 | batch_size, 1686 | 1, 1687 | dtype=compatibility.dtype, 1688 | device=compatibility.device, 1689 | ), 1690 | -np.inf 1691 | * torch.ones( 1692 | self.n_heads, 1693 | batch_size, 1694 |
        # Delivery ->
        # Each delivery node attends to its paired pickup node (one score per
        # delivery): the depot row and all pickup rows are masked with -inf.
        compatibility_additional_pick = torch.cat(
            [  # [n_heads, batch_size, graph_size, 1]
                -np.inf
                * torch.ones(
                    self.n_heads,
                    batch_size,
                    1,
                    dtype=compatibility.dtype,
                    device=compatibility.device,
                ),
                -np.inf
                * torch.ones(
                    self.n_heads,
                    batch_size,
                    n_pick,
                    dtype=compatibility.dtype,
                    device=compatibility.device,
                ),
                compatibility_delivery_pick,  # [n_heads, batch_size, n_pick]
            ],
            -1,
        ).view(self.n_heads, batch_size, graph_size, 1)

        compatibility_additional_alldelivery2 = torch.cat(
            [  # [n_heads, batch_size, graph_size, n_pick]
                -np.inf
                * torch.ones(
                    self.n_heads,
                    batch_size,
                    1,
                    n_pick,
                    dtype=compatibility.dtype,
                    device=compatibility.device,
                ),
                -np.inf
                * torch.ones(
                    self.n_heads,
                    batch_size,
                    n_pick,
                    n_pick,
                    dtype=compatibility.dtype,
                    device=compatibility.device,
                ),
                compatibility_delivery_alldelivery,  # [n_heads, batch_size, n_pick, n_pick]
            ],
            2,
        ).view(self.n_heads, batch_size, graph_size, n_pick)

        compatibility_additional_allpick2 = torch.cat(
            [  # [n_heads, batch_size, graph_size, n_pick]
                -np.inf
                * torch.ones(
                    self.n_heads,
                    batch_size,
                    1,
                    n_pick,
                    dtype=compatibility.dtype,
                    device=compatibility.device,
                ),
                -np.inf
                * torch.ones(
                    self.n_heads,
                    batch_size,
                    n_pick,
                    n_pick,
                    dtype=compatibility.dtype,
                    device=compatibility.device,
                ),
                compatibility_delivery_allpick,  # [n_heads, batch_size, n_pick, n_pick]
            ],
            2,
        ).view(self.n_heads, batch_size, graph_size, n_pick)

        compatibility = torch.cat(
            [
                compatibility,
                compatibility_additional_delivery,
                compatibility_additional_allpick,
                compatibility_additional_alldelivery,
                compatibility_additional_pick,
                compatibility_additional_alldelivery2,
                compatibility_additional_allpick2,
            ],
            dim=-1,
        )

        # Optionally apply mask to prevent attention. The mask covers only the
        # first graph_size columns, so it is padded with False over the extra
        # heterogeneous columns before being broadcast across heads (the extra
        # columns are assumed to need no masking of their own, since invalid
        # pairs in them are already filled with -inf).
        if mask is not None:
            mask = mask.view(1, batch_size, n_query, graph_size)
            mask = torch.cat(
                [
                    mask,
                    mask.new_zeros(
                        1, batch_size, n_query, compatibility.size(-1) - graph_size
                    ),
                ],
                dim=-1,
            ).expand_as(compatibility)
            compatibility[mask] = -np.inf

        # The -inf entries become exactly zero attention weight after the
        # softmax, so masked rows contribute nothing to the terms below.
        attn = torch.softmax(
            compatibility, dim=-1
        )  # [n_heads, batch_size, n_query, graph_size + 2 + 4 * n_pick] (graph_size includes the depot)

        # If there are nodes with no neighbours then softmax returns nan, so fix them to 0
        if mask is not None:
            attnc = attn.clone()
            attnc[mask] = 0
            attn = attnc

        # heads: [n_heads, batch_size, n_query, val_size]; attn holds both the
        # pickup-side and delivery-side attention columns
        heads = torch.matmul(
            attn[:, :, :, :graph_size], V
        )  # V: [n_heads, batch_size, graph_size, val_size]

        # pickup -> its paired delivery
        heads = (
            heads
            + attn[:, :, :, graph_size].view(self.n_heads, batch_size, graph_size, 1)
            * V_additional_delivery
        )  # V_additional_delivery: [n_heads, batch_size, graph_size, key_size]

        # pickup -> all other pickups; V_allpick: [n_heads, batch_size, n_pick, key_size]
        # heads: [n_heads, batch_size, graph_size, key_size]
        heads = heads + torch.matmul(
            attn[:, :, :, graph_size + 1 : graph_size + 1 + n_pick].view(
                self.n_heads, batch_size, graph_size, n_pick
            ),
            V_allpick,
        )

        # pickup -> all deliveries; V_alldelivery: [n_heads, batch_size, n_pick, key/val_size]
        heads = heads + torch.matmul(
            attn[:, :, :, graph_size + 1 + n_pick : graph_size + 1 + 2 * n_pick].view(
                self.n_heads, batch_size, graph_size, n_pick
            ),
            V_alldelivery,
        )
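        # The delivery-side columns are read out below with the same offsets
        # used when they were concatenated: the single delivery -> paired-pickup
        # score sits at column graph_size + 1 + 2 * n_pick, followed by the
        # n_pick-wide delivery -> all-deliveries and delivery -> all-pickups
        # blocks.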
        # delivery -> its paired pickup
        heads = (
            heads
            + attn[:, :, :, graph_size + 1 + 2 * n_pick].view(
                self.n_heads, batch_size, graph_size, 1
            )
            * V_additional_pick
        )

        # delivery -> all deliveries
        heads = heads + torch.matmul(
            attn[
                :,
                :,
                :,
                graph_size + 1 + 2 * n_pick + 1 : graph_size + 1 + 3 * n_pick + 1,
            ].view(self.n_heads, batch_size, graph_size, n_pick),
            V_alldelivery2,
        )

        # delivery -> all pickups
        heads = heads + torch.matmul(
            attn[:, :, :, graph_size + 1 + 3 * n_pick + 1 :].view(
                self.n_heads, batch_size, graph_size, n_pick
            ),
            V_allpickup2,
        )

        # Project the concatenated heads back to the embedding dimension
        out = torch.mm(
            heads.permute(1, 2, 0, 3)
            .contiguous()
            .view(-1, self.n_heads * self.val_dim),
            self.W_out.view(-1, self.embed_dim),
        ).view(batch_size, n_query, self.embed_dim)

        return out
--------------------------------------------------------------------------------