├── data ├── __init__.py ├── gMission │ ├── reduced_workers.txt │ ├── reduced_tasks.txt │ └── workers.txt ├── MovieLense │ ├── movies.txt │ └── users.txt └── data_utils.py ├── figures ├── table1.png └── histogram.png ├── .gitignore ├── requirements.txt ├── LICENSE ├── INSTALL.md ├── utils ├── visualize.py ├── log_utils.py ├── functions.py └── reinforce_baselines.py ├── policy ├── balance.py ├── msvv.py ├── greedy_theshold.py ├── greedy_sc.py ├── greedy.py ├── simple_greedy.py ├── greedy_rt.py ├── inv_ff_history.py ├── greedy_matching.py ├── ff_model_invariant.py ├── ff_model_hist.py ├── ff_model.py ├── ff_supervised.py ├── gnn_simp_hist.py ├── gnn.py ├── inv_ff_history_switch.py └── gnn_hist.py ├── problem_state ├── osbm_dataset.py ├── adwords_dataset.py ├── edge_obm_dataset.py ├── obm_dataset.py ├── obm_env.py ├── edge_obm_env.py └── adwords_env.py ├── encoder └── graph_encoder.py ├── README.md ├── TUTORIAL.md ├── IPsolvers └── IPsolver.py └── run_lite.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figures/table1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ren-Research/LOMAR/HEAD/figures/table1.png -------------------------------------------------------------------------------- /figures/histogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ren-Research/LOMAR/HEAD/figures/histogram.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | dataset/** 3 | dataset 4 | data/__pycache__/** 5 | policy/__pycache__/** 6 | 7 | # Output folders 8 | logs_dataset/** 9 | saved_models/** 10 | saved_models 11 | 12 | 13 | **.pyc 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.2.1 2 | networkx>=2.4 3 | numpy>=1.19.0 4 | pylint>=2.5.3 5 | scipy>=1.4.0 6 | tensorboard>=2.2.2 7 | tensorboard-plugin-wit>=1.7.0 8 | torch>=1.11.0 9 | torchvision>=0.8.1 10 | tqdm>=4.47.0 11 | seaborn -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ren-Research 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation Guide 2 | 3 | ## Install the Online Bipartite Matching code base 4 | ```bash 5 | 6 | conda create --name renlab_obm python==3.7.0 7 | conda activate renlab_obm 8 | # Install PyTorch 9 | pip install torch==1.11.0 torchvision 10 | # Install PyTorch Geometric 11 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cu102.html 12 | # CPU-only version 13 | # pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cpu.html 14 | 15 | # Install Gurobi 16 | pip install gurobipy 17 | # Download the separate license verification tool. 18 | wget https://packages.gurobi.com/lictools/licensetools9.5.1_linux64.tar.gz 19 | # Please obtain a Gurobi key before proceeding, then replace `your_gurobi_key` with your key 20 | grbgetkey your_gurobi_key 21 | 22 | # Install other requirements 23 | pip install -r requirements.txt 24 | 25 | pip install wandb 26 | 27 | pip install -v -e . 28 | ``` 29 | 30 | ## Request a GPU node on a Slurm cluster 31 | 32 | ```shell 33 | srun -p gpu --gres=gpu:1 --mem=100g --time=48:00:00 --pty bash -l 34 | ``` 35 | 36 | 37 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from train import rollout 5 | import torch 6 | 7 | def visualize_reward_process(reward_array, fig_path, ylim = None): 8 | fig, ax = plt.subplots(1) 9 | x_array = np.arange(reward_array.shape[1]) 10 | for k in range(reward_array.shape[0]): 11 | ax.plot(x_array, reward_array[k, :], label = "sample_{}".format(k)) 12 | 13 | if ylim is not None: 14 | ax.set_ylim(ylim) 15 | plt.legend() 16 | fig.show() 17 | plt.savefig(fig_path) 18 | plt.close() 19 | 20 | def validate_histogram(models, model_names, dataset, opts): 21 | assert len(models) == len(model_names) 22 | num_models = len(models) 23 | 24 | cost_list = [] 25 | cr_list = [] 26 | for i in range(num_models): 27 | model = models[i] 28 | model_name = model_names[i] 29 | cost, cr, _ = rollout(model, dataset, opts) 30 | print("{} Validation overall avg_cost: {} +- {}".format(model_name, 31 | cost.mean(), torch.std(cost))) 32 | 33 | cost_list.append(cost) 34 | cr_list.append(cr) 35 | 36 | bins = 30 37 | fig, ax = plt.subplots(1) 38 | for i in range(num_models): 39 | model_name = model_names[i] 40 | cost = cost_list[i] 41 | cost = -cost.view(-1).cpu().numpy() 42 | ax.hist(cost, bins = bins, density=True, alpha = 0.4, label = model_name) 43 | 44 | plt.legend() 45 | fig.show() 46 | plt.savefig("temp_hist_cost.png") 47 | plt.close() 48 | 49 | -------------------------------------------------------------------------------- /utils/log_utils.py: -------------------------------------------------------------------------------- 1 | def log_values( 2 | cost, 3 | epoch, 4 | batch_id, 5 | step, 6 | log_likelihood, 7 | tb_logger, 8 | opts, 9 | grad_norms=None, 10 |
batch_loss=None, 11 | reinforce_loss=None, 12 | bl_loss=None, 13 | ): 14 | avg_cost = cost.mean().item() 15 | if grad_norms is not None: 16 | grad_norms, grad_norms_clipped = grad_norms 17 | print("grad_norm: {}, clipped: {}".format(grad_norms[0], grad_norms_clipped[0])) 18 | # Log values to screen 19 | print( 20 | "epoch: {}, train_batch_id: {}, avg_cost: {}".format(epoch, batch_id, avg_cost) 21 | ) 22 | 23 | # Log values to tensorboard 24 | if not opts.no_tensorboard: 25 | tb_logger.add_scalar("avg_cost", avg_cost, step) 26 | 27 | if opts.model == "ff-supervised" and batch_loss is not None: 28 | tb_logger.add_scalar("batch loss", batch_loss, step) 29 | else: 30 | if reinforce_loss is not None: 31 | tb_logger.add_scalar("actor_loss", reinforce_loss.item(), step) 32 | # tb_logger.add_scalar("nll", -sum(log_likelihood)/len(log_likelihood).item(), step) 33 | 34 | if grad_norms is not None: 35 | tb_logger.add_scalar("grad_norm", grad_norms[0], step) 36 | tb_logger.add_scalar("grad_norm_clipped", grad_norms_clipped[0], step) 37 | 38 | if opts.baseline == "critic": 39 | tb_logger.add_scalar("critic_loss", bl_loss.item(), step) 40 | tb_logger.add_scalar("critic_grad_norm", grad_norms[1], step) 41 | tb_logger.add_scalar( 42 | "critic_grad_norm_clipped", grad_norms_clipped[1], step 43 | ) 44 | -------------------------------------------------------------------------------- /policy/balance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from utils.functions import random_max 5 | 6 | 7 | class Balance(nn.Module): 8 | def __init__( 9 | self, 10 | embedding_dim, 11 | hidden_dim, 12 | problem, 13 | opts, 14 | tanh_clipping=None, 15 | mask_inner=None, 16 | mask_logits=None, 17 | n_encode_layers=None, 18 | normalization=None, 19 | checkpoint_encoder=False, 20 | shrink_size=None, 21 | num_actions=None, 22 | n_heads=None, 23 | encoder=None, 24 | ): 25 | super(Balance, self).__init__() 26 | self.decode_type = None 27 | self.problem = problem 28 | self.model_name = "balance" 29 | 30 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 31 | assert opts.problem == "adwords" 32 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts) 33 | sequences = [] 34 | while not (state.all_finished()): 35 | mask = state.get_mask() 36 | frac_budget = state.curr_budget / state.orig_budget 37 | frac_budget[mask.bool()] = -1e6 38 | frac_budget[:, 0] = -1e5 39 | selected = random_max(frac_budget) 40 | 41 | state = state.update(selected) 42 | sequences.append(selected.squeeze(1)) 43 | if return_pi: 44 | return -state.size, None, torch.stack(sequences, 1), None 45 | return -state.size, torch.stack(sequences, 1), None 46 | 47 | def set_decode_type(self, decode_type, temp=None): 48 | self.decode_type = decode_type 49 | if temp is not None: # Do not change temperature if not provided 50 | self.temp = temp 51 | -------------------------------------------------------------------------------- /policy/msvv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from utils.functions import random_max 5 | 6 | 7 | class MSVV(nn.Module): 8 | def __init__( 9 | self, 10 | embedding_dim, 11 | hidden_dim, 12 | problem, 13 | opts, 14 | tanh_clipping=None, 15 | mask_inner=None, 16 | mask_logits=None, 17 | n_encode_layers=None, 18 | normalization=None, 19 | checkpoint_encoder=False, 20 | shrink_size=None, 21 | num_actions=None, 22 | n_heads=None, 23 | 
encoder=None, 24 | ): 25 | super(MSVV, self).__init__() 26 | self.decode_type = None 27 | self.problem = problem 28 | self.model_name = "msvv" 29 | 30 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 31 | assert opts.problem == "adwords" 32 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts) 33 | sequences = [] 34 | while not (state.all_finished()): 35 | mask = state.get_mask() 36 | w = state.get_current_weights(mask).clone() 37 | scaled_w = w * (1 - torch.exp(-(state.curr_budget / state.orig_budget))) 38 | scaled_w[mask.bool()] = -1e6 39 | scaled_w[:, 0] = -1e5 40 | selected = random_max(scaled_w) 41 | 42 | state = state.update(selected) 43 | sequences.append(selected.squeeze(1)) 44 | if return_pi: 45 | return -state.size, None, torch.stack(sequences, 1), None 46 | return -state.size, torch.stack(sequences, 1), None 47 | 48 | def set_decode_type(self, decode_type, temp=None): 49 | self.decode_type = decode_type 50 | if temp is not None: # Do not change temperature if not provided 51 | self.temp = temp 52 | -------------------------------------------------------------------------------- /problem_state/osbm_dataset.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Dataset 2 | import torch 3 | import os 4 | import pickle 5 | from problem_state.osbm_env import StateOSBM 6 | from data.generate_data import generate_osbm_data_geometric 7 | 8 | 9 | class OSBM(object): 10 | 11 | NAME = "osbm" 12 | 13 | @staticmethod 14 | def make_dataset(*args, **kwargs): 15 | return OSBMDataset(*args, **kwargs) 16 | 17 | @staticmethod 18 | def make_state(*args, **kwargs): 19 | return StateOSBM.initialize(*args, **kwargs) 20 | 21 | 22 | class OSBMDataset(Dataset): 23 | def __init__( 24 | self, dataset, size, problem, seed, opts, transform=None, pre_transform=None 25 | ): 26 | super(OSBMDataset, self).__init__(None, transform, pre_transform) 27 | # self.data_set = dataset 28 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(self.data_set)) 29 | self.problem = problem 30 | if dataset is not None: 31 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(dataset)) 32 | self.data_set = dataset 33 | else: 34 | # If no filename is specified generated data for edge obm probelm 35 | D, optimal_size = generate_osbm_data_geometric( 36 | opts.u_size, 37 | opts.v_size, 38 | opts.weight_distribution, 39 | opts.weight_distribution_param, 40 | opts.graph_family_parameter, 41 | seed, 42 | opts.graph_family, 43 | None, 44 | size, 45 | False, 46 | ) 47 | self.optimal_size = optimal_size 48 | self.data_set = D 49 | 50 | self.size = size 51 | 52 | def len(self): 53 | return self.size 54 | 55 | def get(self, idx): 56 | if type(self.data_set) == str: 57 | data = torch.load(self.data_set + "/data_{}.pt".format(idx)) 58 | else: 59 | data = self.data_set[idx] 60 | return data 61 | -------------------------------------------------------------------------------- /problem_state/adwords_dataset.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Dataset 2 | import torch 3 | from problem_state.adwords_env import StateAdwordsBipartite 4 | from data.generate_data import generate_adwords_data_geometric 5 | 6 | 7 | class AdwordsBipartite(object): 8 | 9 | NAME = "adwords" 10 | 11 | @staticmethod 12 | def make_dataset(*args, **kwargs): 13 | return AdwordsBipartiteDataset(*args, **kwargs) 14 | 15 | @staticmethod 16 | def make_state(*args, **kwargs): 17 
| return StateAdwordsBipartite.initialize(*args, **kwargs) 18 | 19 | 20 | class AdwordsBipartiteDataset(Dataset): 21 | def __init__( 22 | self, dataset, size, problem, seed, opts, transform=None, pre_transform=None 23 | ): 24 | super(AdwordsBipartiteDataset, self).__init__(None, transform, pre_transform) 25 | # self.data_set = dataset 26 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(self.data_set)) 27 | self.problem = problem 28 | if dataset is not None: 29 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(dataset)) 30 | self.data_set = dataset 31 | else: 32 | # If no filename is specified generated data for edge obm probelm 33 | D, optimal_size = generate_adwords_data_geometric( 34 | opts.u_size, 35 | opts.v_size, 36 | opts.weight_distribution, 37 | opts.weight_distribution_param, 38 | opts.graph_family_parameter, 39 | seed, 40 | opts.graph_family, 41 | None, 42 | size, 43 | False, 44 | ) 45 | self.optimal_size = optimal_size 46 | self.data_set = D 47 | 48 | self.size = size 49 | 50 | def len(self): 51 | return self.size 52 | 53 | def get(self, idx): 54 | if type(self.data_set) == str: 55 | data = torch.load(self.data_set + "/data_{}.pt".format(idx)) 56 | else: 57 | data = self.data_set[idx] 58 | return data 59 | -------------------------------------------------------------------------------- /problem_state/edge_obm_dataset.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.data import Dataset 2 | import torch 3 | import os 4 | import pickle 5 | from problem_state.edge_obm_env import StateEdgeBipartite 6 | from data.generate_data import generate_edge_obm_data_geometric 7 | 8 | 9 | class EdgeBipartite(object): 10 | 11 | NAME = "e-obm" 12 | 13 | @staticmethod 14 | def make_dataset(*args, **kwargs): 15 | return EdgeBipartiteDataset(*args, **kwargs) 16 | 17 | @staticmethod 18 | def make_state(*args, **kwargs): 19 | return StateEdgeBipartite.initialize(*args, **kwargs) 20 | 21 | 22 | class EdgeBipartiteDataset(Dataset): 23 | def __init__( 24 | self, dataset, size, problem, seed, opts, transform=None, pre_transform=None 25 | ): 26 | super(EdgeBipartiteDataset, self).__init__(None, transform, pre_transform) 27 | # self.data_set = dataset 28 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(self.data_set)) 29 | self.problem = problem 30 | if dataset is not None: 31 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(dataset)) 32 | self.data_set = dataset 33 | else: 34 | # If no filename is specified generated data for edge obm probelm 35 | D, optimal_size = generate_edge_obm_data_geometric( 36 | opts.u_size, 37 | opts.v_size, 38 | opts.weight_distribution, 39 | opts.weight_distribution_param, 40 | opts.graph_family_parameter, 41 | seed, 42 | opts.graph_family, 43 | None, 44 | size, 45 | False, 46 | ) 47 | self.optimal_size = optimal_size 48 | self.data_set = D 49 | 50 | self.size = size 51 | 52 | def len(self): 53 | return self.size 54 | 55 | def get(self, idx): 56 | if type(self.data_set) == str: 57 | data = torch.load(self.data_set + "/data_{}.pt".format(idx)) 58 | else: 59 | data = self.data_set[idx] 60 | return data 61 | -------------------------------------------------------------------------------- /policy/greedy_theshold.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils.functions import get_best_t 4 | 5 | 6 | class GreedyThresh(nn.Module): 7 | def __init__( 8 | self, 9 | 
embedding_dim, 10 | hidden_dim, 11 | problem, 12 | opts, 13 | tanh_clipping=None, 14 | mask_inner=None, 15 | mask_logits=None, 16 | n_encode_layers=None, 17 | normalization=None, 18 | checkpoint_encoder=False, 19 | shrink_size=None, 20 | num_actions=None, 21 | n_heads=None, 22 | encoder=None, 23 | ): 24 | super(GreedyThresh, self).__init__() 25 | self.decode_type = None 26 | self.problem = problem 27 | self.model_name = "greedy-t" 28 | self.best_threshold = get_best_t(self.model_name, opts) 29 | 30 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 31 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts) 32 | t = self.best_threshold 33 | sequences = [] 34 | while not (state.all_finished()): 35 | mask = state.get_mask() 36 | w = state.get_current_weights(mask).clone() 37 | mask = state.get_mask() 38 | w[mask.bool()] = 0.0 39 | temp = w.clone() 40 | # w[temp >= t] = 1.0 41 | w[temp < t] = 0.0 42 | w[w.sum(1) == 0, 0] = 1.0 43 | # if self.decode_type == "greedy": 44 | selected = torch.argmax(w, dim=1) 45 | # elif self.decode_type == "sampling": 46 | # selected = (w / torch.sum(w, dim=1)[:, None]).multinomial(1) 47 | state = state.update(selected[:, None]) 48 | 49 | sequences.append(selected) 50 | if return_pi: 51 | return -state.size, None, torch.stack(sequences, 1), None 52 | return -state.size, torch.stack(sequences, 1), None 53 | 54 | def set_decode_type(self, decode_type, temp=None): 55 | self.decode_type = decode_type 56 | if temp is not None: # Do not change temperature if not provided 57 | self.temp = temp 58 | -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from problem_state.obm_dataset import Bipartite 3 | from problem_state.edge_obm_dataset import EdgeBipartite 4 | from problem_state.osbm_dataset import OSBM 5 | from problem_state.adwords_dataset import AdwordsBipartite 6 | import csv 7 | 8 | 9 | def load_problem(name): 10 | 11 | problem = { 12 | "obm": Bipartite, 13 | "e-obm": EdgeBipartite, 14 | "osbm": OSBM, 15 | "adwords": AdwordsBipartite, 16 | }.get(name, None) 17 | assert problem is not None, "Currently unsupported problem: {}!".format(name) 18 | return problem 19 | 20 | 21 | def torch_load_cpu(load_path): 22 | return torch.load( 23 | load_path, map_location=lambda storage, loc: storage 24 | ) # Load on CPU 25 | 26 | 27 | def move_to(var, device): 28 | if isinstance(var, dict): 29 | return {k: move_to(v, device) for k, v in var.items()} 30 | elif isinstance(var, list): 31 | return list(move_to(v, device) for v in var) 32 | return var.to(device) 33 | 34 | 35 | def random_max(input): 36 | """ 37 | Return max element with random tie breaking 38 | """ 39 | max_w, _ = torch.max(input, dim=1) 40 | max_filter = (input == max_w[:, None]).float() 41 | selected = (max_filter / (max_filter.sum(dim=1)[:, None])).multinomial(1) 42 | 43 | return selected 44 | 45 | 46 | def random_min(input): 47 | """ 48 | Return min element with random tie breaking 49 | """ 50 | min_w, _ = torch.min(input, dim=1) 51 | max_filter = (input == min_w[:, None]).float() 52 | selected = (max_filter / (max_filter.sum(dim=1)[:, None])).multinomial(1) 53 | 54 | return selected 55 | 56 | 57 | def get_best_t(model, opts): 58 | best_params = None 59 | best_r = 0 60 | graph_family = ( 61 | opts.graph_family if opts.graph_family != "gmission-perm" else "gmission" 62 | ) 63 | with open( 64 | 
f"val_rewards_{model}_{opts.u_size}_{opts.v_size}_{graph_family}_{opts.graph_family_parameter}.csv" 65 | ) as csv_file: 66 | csv_reader = csv.reader(csv_file, delimiter=",") 67 | for line in csv_reader: 68 | if abs(float(line[-1])) > best_r: 69 | best_params = float(line[0]) 70 | best_r = abs(float(line[-1])) 71 | return best_params 72 | -------------------------------------------------------------------------------- /data/gMission/reduced_workers.txt: -------------------------------------------------------------------------------- 1 | 321.0 2 | 81.0 3 | 105.0 4 | 329.0 5 | 497.0 6 | 358.0 7 | 3.0 8 | 300.0 9 | 357.0 10 | 381.0 11 | 161.0 12 | 510.0 13 | 155.0 14 | 239.0 15 | 511.0 16 | 72.0 17 | 143.0 18 | 172.0 19 | 370.0 20 | 447.0 21 | 17.0 22 | 106.0 23 | 113.0 24 | 129.0 25 | 353.0 26 | 387.0 27 | 407.0 28 | 412.0 29 | 490.0 30 | 515.0 31 | 523.0 32 | 12.0 33 | 37.0 34 | 193.0 35 | 284.0 36 | 323.0 37 | 343.0 38 | 344.0 39 | 404.0 40 | 438.0 41 | 460.0 42 | 479.0 43 | 18.0 44 | 29.0 45 | 49.0 46 | 71.0 47 | 229.0 48 | 301.0 49 | 339.0 50 | 342.0 51 | 413.0 52 | 445.0 53 | 526.0 54 | 5.0 55 | 8.0 56 | 30.0 57 | 38.0 58 | 145.0 59 | 195.0 60 | 214.0 61 | 236.0 62 | 309.0 63 | 315.0 64 | 324.0 65 | 327.0 66 | 385.0 67 | 394.0 68 | 494.0 69 | 74.0 70 | 82.0 71 | 111.0 72 | 142.0 73 | 182.0 74 | 203.0 75 | 207.0 76 | 217.0 77 | 293.0 78 | 430.0 79 | 441.0 80 | 459.0 81 | 487.0 82 | 493.0 83 | 1.0 84 | 9.0 85 | 39.0 86 | 75.0 87 | 93.0 88 | 158.0 89 | 160.0 90 | 222.0 91 | 227.0 92 | 276.0 93 | 294.0 94 | 308.0 95 | 328.0 96 | 365.0 97 | 406.0 98 | 429.0 99 | 462.0 100 | 468.0 101 | 480.0 102 | 484.0 103 | 486.0 104 | 491.0 105 | 518.0 106 | 529.0 107 | 60.0 108 | 126.0 109 | 151.0 110 | 167.0 111 | 256.0 112 | 291.0 113 | 337.0 114 | 372.0 115 | 488.0 116 | 525.0 117 | 62.0 118 | 80.0 119 | 148.0 120 | 149.0 121 | 157.0 122 | 173.0 123 | 212.0 124 | 254.0 125 | 318.0 126 | 422.0 127 | 435.0 128 | 496.0 129 | 52.0 130 | 61.0 131 | 76.0 132 | 204.0 133 | 221.0 134 | 275.0 135 | 283.0 136 | 297.0 137 | 319.0 138 | 347.0 139 | 351.0 140 | 359.0 141 | 366.0 142 | 410.0 143 | 419.0 144 | 442.0 145 | 453.0 146 | 456.0 147 | 465.0 148 | 475.0 149 | 489.0 150 | 507.0 151 | 6.0 152 | 90.0 153 | 115.0 154 | 136.0 155 | 166.0 156 | 271.0 157 | 331.0 158 | 332.0 159 | 388.0 160 | 405.0 161 | 416.0 162 | 471.0 163 | 512.0 164 | 516.0 165 | 4.0 166 | 56.0 167 | 122.0 168 | 188.0 169 | 216.0 170 | 224.0 171 | 245.0 172 | 296.0 173 | 322.0 174 | 401.0 175 | 451.0 176 | 492.0 177 | 19.0 178 | 23.0 179 | 50.0 180 | 175.0 181 | 230.0 182 | 257.0 183 | 260.0 184 | 278.0 185 | 362.0 186 | 398.0 187 | 400.0 188 | 426.0 189 | 482.0 190 | 48.0 191 | 92.0 192 | 132.0 193 | 152.0 194 | 210.0 195 | 306.0 196 | 314.0 197 | 355.0 198 | 377.0 199 | 378.0 200 | 402.0 201 | -------------------------------------------------------------------------------- /policy/greedy_sc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils.functions import random_max 4 | import math 5 | 6 | class GreedySC(nn.Module): 7 | def __init__( 8 | self, 9 | embedding_dim, 10 | hidden_dim, 11 | problem, 12 | opts, 13 | tanh_clipping=None, 14 | mask_inner=None, 15 | mask_logits=None, 16 | n_encode_layers=None, 17 | normalization=None, 18 | checkpoint_encoder=False, 19 | shrink_size=None, 20 | num_actions=None, 21 | n_heads=None, 22 | encoder=None, 23 | ): 24 | super(GreedySC, self).__init__() 25 | self.decode_type = None 26 | self.problem = problem 27 | 
self.model_name = "greedy-sc" 28 | 29 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 30 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts) 31 | sequences = [] 32 | 33 | # Array for single step rewards 34 | i = 0 35 | reward_array = torch.zeros([opts.v_size, opts.batch_size]) 36 | prev_reward = 0 37 | 38 | graph_v_size = opts.v_size 39 | i = 0 40 | 41 | while not (state.all_finished()): 42 | mask = state.get_mask() 43 | w = state.get_current_weights(mask).clone() 44 | w[mask.bool()] = -1.0 45 | 46 | # Select action based on current info 47 | if(i > graph_v_size/math.e - 1): 48 | selected = random_max(w) 49 | else: 50 | selected = torch.zeros([w.shape[0], 1], device=w.device, dtype=torch.int64) 51 | 52 | state = state.update(selected) 53 | sequences.append(selected.squeeze(1)) 54 | 55 | # Collect the signle step reward 56 | reward_array[i,:] = state.size.view(-1).cpu() - prev_reward 57 | prev_reward = state.size.view(-1).cpu() 58 | i+= 1 59 | 60 | if return_pi: 61 | return -state.size, None, torch.stack(sequences, 1), None 62 | 63 | return -state.size, torch.stack(sequences, 1), None 64 | 65 | def set_decode_type(self, decode_type, temp=None): 66 | self.decode_type = decode_type 67 | if temp is not None: # Do not change temperature if not provided 68 | self.temp = temp 69 | -------------------------------------------------------------------------------- /policy/greedy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utils.functions import random_max 4 | 5 | 6 | class Greedy(nn.Module): 7 | def __init__( 8 | self, 9 | embedding_dim, 10 | hidden_dim, 11 | problem, 12 | opts, 13 | tanh_clipping=None, 14 | mask_inner=None, 15 | mask_logits=None, 16 | n_encode_layers=None, 17 | normalization=None, 18 | checkpoint_encoder=False, 19 | shrink_size=None, 20 | num_actions=None, 21 | n_heads=None, 22 | encoder=None, 23 | ): 24 | super(Greedy, self).__init__() 25 | self.decode_type = None 26 | self.problem = problem 27 | self.model_name = "greedy" 28 | 29 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 30 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts) 31 | sequences = [] 32 | 33 | # Array for single step rewards 34 | i = 0 35 | reward_array = torch.zeros([opts.v_size, opts.batch_size]) 36 | prev_reward = 0 37 | 38 | while not (state.all_finished()): 39 | mask = state.get_mask() 40 | w = state.get_current_weights(mask).clone() 41 | w[mask.bool()] = -1.0 42 | selected = random_max(w) 43 | state = state.update(selected) 44 | sequences.append(selected.squeeze(1)) 45 | 46 | # Collect the signle step reward 47 | reward_array[i,:] = state.size.view(-1).cpu() - prev_reward 48 | prev_reward = state.size.view(-1).cpu() 49 | i+= 1 50 | 51 | if return_pi: 52 | return -state.size, None, torch.stack(sequences, 1), None 53 | 54 | # import os 55 | # base_dir = "reward_save" 56 | # file_list = os.listdir(base_dir) 57 | # if file_list.__len__() != 0: 58 | # file_list.sort() 59 | # curr_index = int(file_list[-1][-6:-3]) + 1 60 | # else: 61 | # curr_index = 0 62 | # torch.save(reward_array, base_dir+"/reward_{:0>3d}.pt".format(curr_index)) 63 | 64 | return -state.size, torch.stack(sequences, 1), None 65 | 66 | def set_decode_type(self, decode_type, temp=None): 67 | self.decode_type = decode_type 68 | if temp is not None: # Do not change temperature if not provided 69 | self.temp = temp 70 | 
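# The Greedy policy above boils down, at every V-node arrival, to "pick a
# maximum-weight available edge, breaking ties uniformly at random" via
# random_max from utils/functions.py. Below is a minimal, self-contained sketch
# of that selection step; it assumes only torch, and the toy weights, the mask,
# and the helper name pick_greedy are illustrative rather than part of the repo.
if __name__ == "__main__":
    import torch

    def pick_greedy(weights, mask):
        # Unavailable U-nodes cannot be picked: push them below every real weight.
        w = weights.clone()
        w[mask.bool()] = -1.0
        # Uniform tie breaking over the per-row maxima (same idea as random_max).
        max_w, _ = torch.max(w, dim=1)
        is_max = (w == max_w[:, None]).float()
        return (is_max / is_max.sum(dim=1, keepdim=True)).multinomial(1)

    # Two arrivals over u_size = 3 offline nodes plus the skip node at index 0.
    weights = torch.tensor([[0.0, 0.7, 0.7, 0.2], [0.0, 0.5, 0.3, 0.1]])
    mask = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 1]])
    print(pick_greedy(weights, mask))  # e.g. tensor([[2], [0]]); the second row can only skip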
-------------------------------------------------------------------------------- /encoder/graph_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.nn import NNConv 4 | import torch.nn.functional as F 5 | 6 | 7 | class MPNN(nn.Module): 8 | def __init__( 9 | self, 10 | n_heads, 11 | embed_dim, 12 | n_layers, 13 | problem, 14 | opts, 15 | dropout=0.1, 16 | alpha=0.01, 17 | node_dim_u=1, 18 | node_dim_v=1, 19 | normalization="batch", 20 | feed_forward_hidden=512, 21 | ): 22 | super(MPNN, self).__init__() 23 | self.l1 = nn.Linear(1, embed_dim * embed_dim) # TODO: Change back 24 | self.node_embed_u = nn.Linear(node_dim_u, embed_dim) 25 | if node_dim_u != node_dim_v: 26 | self.node_embed_v = nn.Linear(node_dim_v, embed_dim) 27 | else: 28 | self.node_embed_v = self.node_embed_u 29 | 30 | self.conv1 = NNConv( 31 | embed_dim, embed_dim, self.l1, aggr="mean" 32 | ) # TODO: Change back 33 | if n_layers > 1: 34 | self.l2 = nn.Linear(1, embed_dim ** 2) 35 | self.conv2 = NNConv(embed_dim, embed_dim, self.l2, aggr="mean") 36 | self.n_layers = n_layers 37 | self.problem = opts.problem 38 | self.u_size = opts.u_size 39 | self.node_dim_u = node_dim_u 40 | self.node_dim_v = node_dim_v 41 | self.batch_size = opts.batch_size 42 | 43 | def forward(self, x, edge_index, edge_attribute, i, dummy, opts): 44 | i = i.item() 45 | graph_size = opts.u_size + 1 + i 46 | if i < self.n_layers: 47 | n_encode_layers = i + 1 48 | else: 49 | n_encode_layers = self.n_layers 50 | x_u = x[:, : self.node_dim_u * (opts.u_size + 1)].reshape( 51 | self.batch_size, opts.u_size + 1, self.node_dim_u 52 | ) 53 | x_v = x[:, self.node_dim_u * (opts.u_size + 1) :].reshape( 54 | self.batch_size, i, self.node_dim_v 55 | ) 56 | x_u = self.node_embed_u(x_u) 57 | x_v = self.node_embed_v(x_v) 58 | x = torch.cat((x_u, x_v), dim=1).reshape(self.batch_size * graph_size, -1) 59 | 60 | for j in range(n_encode_layers): 61 | # x = F.relu(x) # TODO: Change back 62 | if j == 0: 63 | x = self.conv1(x, edge_index, edge_attribute.float()) 64 | if n_encode_layers > 1: 65 | x = F.relu(x) 66 | else: 67 | x = self.conv2(x, edge_index, edge_attribute.float()) 68 | 69 | return x 70 | -------------------------------------------------------------------------------- /policy/simple_greedy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | 6 | class SimpleGreedy(nn.Module): 7 | def __init__( 8 | self, 9 | embedding_dim, 10 | hidden_dim, 11 | problem, 12 | opts, 13 | tanh_clipping=None, 14 | mask_inner=None, 15 | mask_logits=None, 16 | n_encode_layers=None, 17 | normalization=None, 18 | checkpoint_encoder=False, 19 | shrink_size=None, 20 | num_actions=None, 21 | n_heads=None, 22 | encoder=None, 23 | ): 24 | super(SimpleGreedy, self).__init__() 25 | self.decode_type = None 26 | self.allow_partial = problem.NAME == "sdvrp" 27 | self.is_vrp = problem.NAME == "cvrp" or problem.NAME == "sdvrp" 28 | self.is_orienteering = problem.NAME == "op" 29 | self.is_pctsp = problem.NAME == "pctsp" 30 | self.is_bipartite = problem.NAME == "bipartite" 31 | self.is_tsp = problem.NAME == "tsp" 32 | self.problem = problem 33 | self.rank = 0 34 | 35 | def forward(self, x, opts): 36 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts) 37 | 38 | self.rank = self.permute_uniform( 39 | torch.arange(1, state.u_size.item() + 2, device=x.device) 40 | .unsqueeze(0) 41 | 
.expand(state.batch_size.item(), state.u_size.item() + 1) 42 | ) 43 | sequences = [] 44 | self.rank[:, 0] = state.u_size.item() * 2 45 | while not (state.all_finished()): 46 | mask = state.get_mask().bool() 47 | r = self.rank.clone() 48 | r[mask] = 10e6 49 | selected = torch.argmin(r, dim=1) 50 | 51 | state = state.update(selected[:, None]) 52 | 53 | sequences.append(selected) 54 | 55 | return -state.size, torch.stack(sequences, 1) 56 | 57 | def set_decode_type(self, decode_type, temp=None): 58 | self.decode_type = decode_type 59 | if temp is not None: # Do not change temperature if not provided 60 | self.temp = temp 61 | 62 | def permute_uniform(self, x): 63 | """ 64 | Permutes a batch of lists uniformly at random using the Fisher Yates algorithm. 65 | """ 66 | y = x.clone() 67 | n = x.size(1) 68 | batch_size = x.size(0) 69 | for i in range(0, n - 1): 70 | j = torch.tensor(np.random.randint(i, n, (batch_size, 1)), device=x.device) 71 | temp = y[:, i].clone() 72 | y[:, i] = torch.gather(y, 1, j).squeeze(1) 73 | y = y.scatter_(1, j, temp.unsqueeze(1)) 74 | return y 75 | -------------------------------------------------------------------------------- /policy/greedy_rt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | 6 | class GreedyRt(nn.Module): 7 | def __init__( 8 | self, 9 | embedding_dim, 10 | hidden_dim, 11 | problem, 12 | opts, 13 | tanh_clipping=None, 14 | mask_inner=None, 15 | mask_logits=None, 16 | n_encode_layers=None, 17 | normalization=None, 18 | checkpoint_encoder=False, 19 | shrink_size=None, 20 | num_actions=None, 21 | n_heads=None, 22 | encoder=None, 23 | ): 24 | super(GreedyRt, self).__init__() 25 | self.decode_type = None 26 | self.problem = problem 27 | self.model_name = "greedy-rt" 28 | max_weight_dict = { 29 | "gmission-var": 18.8736, 30 | "gmission": 18.8736, 31 | "er": 10 ** 8, 32 | "ba": float( 33 | opts.graph_family_parameter 34 | ) # Make sure to set this properly before running! 35 | + float(opts.weight_distribution_param[1]), 36 | } 37 | norm_weight = { 38 | "gmission-var": 18.8736, 39 | "gmission": 18.8736, 40 | "er": 10 ** 8, 41 | "ba": 100.0, 42 | } 43 | if opts.graph_family == "gmission-perm": 44 | graph_family = "gmission" 45 | else: 46 | graph_family = opts.graph_family 47 | self.max_weight = max_weight_dict[graph_family] 48 | self.norm_weights = norm_weight[graph_family] 49 | 50 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 51 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts) 52 | t = torch.tensor( 53 | np.e 54 | ** np.random.randint( 55 | 1, np.ceil(np.log(1 + self.max_weight)), (opts.batch_size, 1) 56 | ), 57 | device=opts.device, 58 | ) 59 | sequences = [] 60 | while not (state.all_finished()): 61 | mask = state.get_mask() 62 | w = state.get_current_weights(mask).clone() 63 | mask = state.get_mask() 64 | w[mask.bool()] = 0.0 65 | temp = ( 66 | w * self.norm_weights 67 | ) # re-normalize the weights since they are mostly between 0 and 1. 68 | temp[ 69 | temp > 0.0 70 | ] += 1.0 # To make sure all weights are at least 1 (needed for greedy-rt to work). 
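# The rescaled weights are now compared against the random threshold t = e^k drawn above:
# edges with weight at least t become eligible (set to 1.0), everything else is zeroed out,
# and any row with no eligible edge falls back to the dummy node at index 0, i.e. the
# arriving V node is left unmatched. The final choice is uniform over the eligible entries.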
71 | w[temp >= t] = 1.0 72 | w[temp < t] = 0.0 73 | w[w.sum(1) == 0, 0] = 1.0 74 | selected = (w / torch.sum(w, dim=1)[:, None]).multinomial(1) 75 | state = state.update(selected) 76 | 77 | sequences.append(selected.squeeze(1)) 78 | if return_pi: 79 | return -state.size, None, torch.stack(sequences, 1), None 80 | return -state.size, torch.stack(sequences, 1), None 81 | 82 | def set_decode_type(self, decode_type, temp=None): 83 | self.decode_type = decode_type 84 | if temp is not None: # Do not change temperature if not provided 85 | self.temp = temp 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning for Edge-Weighted Online Bipartite Matching with Robustness Guarantees 2 | 3 | [![MIT licensed](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE) 4 | 5 | [Pengfei Li](https://www.cs.ucr.edu/~pli081/), [Jianyi Yang](https://jyang-ai.github.io/) and [Shaolei Ren](https://intra.ece.ucr.edu/~sren/) 6 | 7 | **Note** 8 | 9 | This is the official implementation of the ICML 2023 paper. 10 | 11 | ## Requirements 12 | 13 | * python>=3.6 14 | 15 | ## Installation 16 | * Clone this repo: 17 | ```bash 18 | git clone git@github.com:Ren-Research/LOMAR.git 19 | cd LOMAR 20 | ``` 21 | Then please refer to [the install guide](INSTALL.md) for more installation details. 22 | 23 | ## Usage 24 | To apply our algorithm (LOMAR) to online bipartite matching, you need to complete three main steps: 25 | 26 | 1. Generate the graph dataset 27 | 2. Train the RL model 28 | 3. Evaluate the policy 29 | 30 | A script example for each step can be found in our brief [tutorial](TUTORIAL.md). 31 | 32 | ## Evaluation 33 | 34 | In our experiment, we set $u_0 = 10$ and $v_0 = 60$ to generate the training and testing datasets, which contain 20000 and 1000 graph instances, respectively. For the sake of reproducibility and fair comparison, our settings follow the same setup as our [baseline](https://github.com/lyeskhalil/CORL). 35 | 36 | | ![space-1.jpg](https://github.com/Ren-Research/LOMAR/blob/main/figures/table1.png) | 37 | |:--:| 38 | | Table 1: Comparison under different $\rho$. In the top row, LOMAR ($\rho = x$) means LOMAR is trained with the value $\rho = x$. The average reward and competitive ratio are denoted by AVG and CR, respectively; higher is better. The highest value in each testing setup is highlighted in bold. The AVG and CR for DRL are 12.909 and 0.544, respectively. The average reward for OPT is 13.209.| 39 | 40 | The histograms of the bi-competitive ratios are visualized below. When $\rho = 0$, the ratio of DRL-OS / DRL is unsurprisingly always 1. With a large $\rho$ (e.g., 0.8) at testing time, the reward ratios of DRL-OS / Greedy are around 1 for most samples, but the flexibility of DRL-OS is limited and it can exploit the good average performance of DRL to a lesser extent. 41 | 42 | | ![space-1.jpg](https://github.com/Ren-Research/LOMAR/blob/main/figures/histogram.png) | 43 | |:--:| 44 | | Figure 1: Histogram of bi-competitive reward ratios of DRL-OS against Greedy and DRL under different $\rho$. DRL-OS uses the same online switching algorithm as LOMAR, but its RL model is trained with $\rho=0$. 
| 45 | 46 | ## Citation 47 | ```BibTex 48 | @inproceedings{Li2023LOMAR, 49 | title={Learning for Edge-Weighted Online Bipartite Matching with Robustness Guarantees}, 50 | author={Li, Pengfei and Yang, Jianyi and Ren, Shaolei}, 51 | booktitle={International Conference on Machine Learning}, 52 | year={2023}, 53 | organization={PMLR} 54 | } 55 | ``` 56 | 57 | 58 | ## Codebase 59 | Thanks for the code base from Mohammad Ali Alomrani, Reza Moravej, Elias B. Khalil. The public repository of their code is available at [https://github.com/lyeskhalil/CORL](https://github.com/lyeskhalil/CORL) 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /data/gMission/reduced_tasks.txt: -------------------------------------------------------------------------------- 1 | 347.0 2 | 414.0 3 | 40.0 4 | 134.0 5 | 462.0 6 | 467.0 7 | 72.0 8 | 88.0 9 | 344.0 10 | 489.0 11 | 2.0 12 | 68.0 13 | 196.0 14 | 360.0 15 | 507.0 16 | 646.0 17 | 52.0 18 | 113.0 19 | 201.0 20 | 171.0 21 | 326.0 22 | 442.0 23 | 563.0 24 | 395.0 25 | 434.0 26 | 91.0 27 | 384.0 28 | 401.0 29 | 520.0 30 | 433.0 31 | 427.0 32 | 553.0 33 | 47.0 34 | 133.0 35 | 693.0 36 | 105.0 37 | 198.0 38 | 516.0 39 | 673.0 40 | 37.0 41 | 612.0 42 | 164.0 43 | 237.0 44 | 464.0 45 | 545.0 46 | 131.0 47 | 316.0 48 | 614.0 49 | 66.0 50 | 160.0 51 | 43.0 52 | 387.0 53 | 4.0 54 | 223.0 55 | 296.0 56 | 340.0 57 | 499.0 58 | 681.0 59 | 106.0 60 | 32.0 61 | 38.0 62 | 366.0 63 | 295.0 64 | 528.0 65 | 576.0 66 | 116.0 67 | 455.0 68 | 573.0 69 | 92.0 70 | 431.0 71 | 99.0 72 | 307.0 73 | 588.0 74 | 638.0 75 | 583.0 76 | 103.0 77 | 204.0 78 | 532.0 79 | 44.0 80 | 45.0 81 | 330.0 82 | 365.0 83 | 8.0 84 | 146.0 85 | 212.0 86 | 415.0 87 | 490.0 88 | 707.0 89 | 621.0 90 | 582.0 91 | 680.0 92 | 64.0 93 | 154.0 94 | 193.0 95 | 221.0 96 | 256.0 97 | 290.0 98 | 374.0 99 | 3.0 100 | 18.0 101 | 468.0 102 | 593.0 103 | 598.0 104 | 61.0 105 | 175.0 106 | 623.0 107 | 39.0 108 | 342.0 109 | 376.0 110 | 544.0 111 | 54.0 112 | 451.0 113 | 225.0 114 | 606.0 115 | 27.0 116 | 112.0 117 | 127.0 118 | 487.0 119 | 617.0 120 | 666.0 121 | 71.0 122 | 130.0 123 | 277.0 124 | 495.0 125 | 603.0 126 | 644.0 127 | 85.0 128 | 187.0 129 | 341.0 130 | 444.0 131 | 209.0 132 | 692.0 133 | 310.0 134 | 354.0 135 | 525.0 136 | 280.0 137 | 373.0 138 | 328.0 139 | 677.0 140 | 189.0 141 | 278.0 142 | 299.0 143 | 314.0 144 | 579.0 145 | 383.0 146 | 244.0 147 | 70.0 148 | 177.0 149 | 411.0 150 | 428.0 151 | 479.0 152 | 712.0 153 | 417.0 154 | 230.0 155 | 599.0 156 | 126.0 157 | 357.0 158 | 641.0 159 | 167.0 160 | 447.0 161 | 708.0 162 | 46.0 163 | 370.0 164 | 385.0 165 | 643.0 166 | 97.0 167 | 561.0 168 | 60.0 169 | 79.0 170 | 122.0 171 | 363.0 172 | 697.0 173 | 482.0 174 | 31.0 175 | 125.0 176 | 626.0 177 | 152.0 178 | 233.0 179 | 378.0 180 | 409.0 181 | 422.0 182 | 477.0 183 | 560.0 184 | 580.0 185 | 660.0 186 | 683.0 187 | 309.0 188 | 343.0 189 | 597.0 190 | 498.0 191 | 524.0 192 | 194.0 193 | 217.0 194 | 250.0 195 | 264.0 196 | 289.0 197 | 453.0 198 | 595.0 199 | 636.0 200 | 197.0 201 | 293.0 202 | 124.0 203 | 195.0 204 | 75.0 205 | 120.0 206 | 279.0 207 | 425.0 208 | 459.0 209 | 486.0 210 | 577.0 211 | 100.0 212 | 29.0 213 | 627.0 214 | 671.0 215 | 210.0 216 | 234.0 217 | 624.0 218 | 674.0 219 | 220.0 220 | 543.0 221 | 111.0 222 | 143.0 223 | 153.0 224 | 320.0 225 | 179.0 226 | 191.0 227 | 581.0 228 | 709.0 229 | 36.0 230 | 206.0 231 | 405.0 232 | 461.0 233 | 33.0 234 | 213.0 235 | 254.0 236 | 261.0 237 | 322.0 238 | 372.0 239 | 616.0 240 | 
701.0 241 | 156.0 242 | 379.0 243 | 424.0 244 | 559.0 245 | 630.0 246 | 655.0 247 | 679.0 248 | 240.0 249 | 438.0 250 | 555.0 251 | 98.0 252 | 205.0 253 | 485.0 254 | 227.0 255 | 496.0 256 | 668.0 257 | 190.0 258 | 445.0 259 | 199.0 260 | 180.0 261 | 421.0 262 | 550.0 263 | 440.0 264 | 172.0 265 | 276.0 266 | 403.0 267 | 432.0 268 | 637.0 269 | 665.0 270 | 128.0 271 | 186.0 272 | 255.0 273 | 284.0 274 | 327.0 275 | 361.0 276 | 408.0 277 | 556.0 278 | 219.0 279 | 446.0 280 | 699.0 281 | 248.0 282 | 539.0 283 | 169.0 284 | 62.0 285 | 215.0 286 | 457.0 287 | 94.0 288 | 339.0 289 | 349.0 290 | 22.0 291 | 548.0 292 | 590.0 293 | 656.0 294 | 17.0 295 | 19.0 296 | 123.0 297 | 139.0 298 | 200.0 299 | 241.0 300 | 491.0 301 | -------------------------------------------------------------------------------- /problem_state/obm_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import os 4 | import pickle 5 | from problem_state.obm_env import StateBipartite 6 | 7 | 8 | class Bipartite(object): 9 | 10 | NAME = "obm" 11 | 12 | # @staticmethod 13 | # def get_costs(dataset, pi): 14 | # # TODO: MODIFY CODE SO IT WORKS WITH BIPARTITE INSTEAD OF TSP 15 | # # Check that tours are valid, i.e. contain 0 to n -1 16 | # assert ( 17 | # torch.arange(pi.size(1), out=pi.data.new()).view(1, -1).expand_as(pi) 18 | # == pi.data.sort(1)[0] 19 | # ).all(), "Invalid tour" 20 | # # Gather dataset in order of tour 21 | # d = dataset.gather(1, pi.unsqueeze(-1).expand_as(dataset)) 22 | # # Length is distance (L2-norm of difference) from each next location from its prev and of last from first 23 | # return ( 24 | # (d[:, 1:] - d[:, :-1]).norm(p=2, dim=2).sum(1) 25 | # + (d[:, 0] - d[:, -1]).norm(p=2, dim=1), 26 | # None, 27 | # ) 28 | 29 | @staticmethod 30 | def make_dataset(*args, **kwargs): 31 | return BipartiteDataset(*args, **kwargs) 32 | 33 | @staticmethod 34 | def make_state(*args, **kwargs): 35 | return StateBipartite.initialize(*args, **kwargs) 36 | 37 | # @staticmethod 38 | # def beam_search( 39 | # input, 40 | # beam_size, 41 | # expand_size=None, 42 | # compress_mask=False, 43 | # model=None, 44 | # max_calc_batch_size=4096, 45 | # ): 46 | 47 | # assert model is not None, "Provide model" 48 | 49 | # fixed = model.precompute_fixed(input) 50 | 51 | # def propose_expansions(beam): 52 | # return model.propose_expansions( 53 | # beam, 54 | # fixed, 55 | # expand_size, 56 | # normalize=True, 57 | # max_calc_batch_size=max_calc_batch_size, 58 | # ) 59 | 60 | # state = Bipartite.make_state( 61 | # input, visited_dtype=torch.int64 if compress_mask else torch.uint8 62 | # ) 63 | 64 | # return beam_search(state, beam_size, propose_expansions) 65 | 66 | 67 | class BipartiteDataset(Dataset): 68 | def __init__(self, dataset, size, problem, opts): 69 | super(BipartiteDataset, self).__init__() 70 | 71 | self.data_set = dataset 72 | self.optimal_size = torch.load("{}/optimal_match.pt".format(self.data_set)) 73 | self.problem = problem 74 | # if opts.train_dataset is not None: 75 | # assert os.path.splitext(opts.train_dataset)[1] == ".pkl" 76 | 77 | # with open(opts.train_dataset, "rb") as f: 78 | # data = pickle.load(f) 79 | # self.data = data 80 | # else: 81 | # ### TODO: Should use generate function in generate_data.py 82 | # # If no filename is specified generated data for normal obm probelm 83 | # self.data = generate_obm_data(opts) 84 | 85 | self.size = size 86 | 87 | def __getitem__(self, index): 88 | 89 | # Load data and get 
label 90 | X = torch.load("{}/graphs/{}.pt".format(self.data_set, index)) 91 | Y = self.optimal_size[index] 92 | return X, Y 93 | 94 | def __len__(self): 95 | return self.size 96 | 97 | # def __getitem__(self, idx): 98 | # return tuple(d[idx] for d in self.data) 99 | 100 | 101 | # train_loader = torch.utils.data.DataLoader( 102 | # ConcatDataset( 103 | # datasets.ImageFolder(traindir_A), 104 | # datasets.ImageFolder(traindir_B) 105 | # ), 106 | # batch_size=args.batch_size, shuffle=True, 107 | # num_workers=args.workers, pin_memory=True) 108 | -------------------------------------------------------------------------------- /data/MovieLense/movies.txt: -------------------------------------------------------------------------------- 1 | 36::Dead Man Walking (1995)::Drama 2 | 68::French Twist (Gazon maudit) (1995)::Comedy|Romance 3 | 124::Star Maker, The (Uomo delle stelle, L') (1995)::Drama 4 | 170::Hackers (1995)::Action|Crime|Thriller 5 | 180::Mallrats (1995)::Comedy 6 | 192::Show, The (1995)::Documentary 7 | 251::Hunted, The (1995)::Action 8 | 252::I.Q. (1994)::Comedy|Romance 9 | 273::Mary Shelley's Frankenstein (1994)::Drama|Horror 10 | 398::Frank and Ollie (1995)::Documentary 11 | 404::Brother Minister: The Assassination of Malcolm X (1994)::Documentary 12 | 410::Addams Family Values (1993)::Comedy 13 | 523::Ruby in Paradise (1993)::Drama 14 | 538::Six Degrees of Separation (1993)::Drama 15 | 550::Threesome (1994)::Comedy|Romance 16 | 582::Metisse (Caf� au Lait) (1993)::Comedy 17 | 587::Ghost (1990)::Comedy|Romance|Thriller 18 | 608::Fargo (1996)::Crime|Drama|Thriller 19 | 635::Family Thing, A (1996)::Comedy|Drama 20 | 651::Superweib, Das (1996)::Comedy 21 | 657::Yankee Zulu (1994)::Comedy|Drama 22 | 725::Great White Hype, The (1996)::Comedy 23 | 757::Ashes of Time (1994)::Drama 24 | 778::Trainspotting (1996)::Drama 25 | 829::Joe's Apartment (1996)::Comedy|Musical 26 | 843::Lotto Land (1995)::Drama 27 | 848::Spitfire Grill, The (1996)::Drama 28 | 852::Tin Cup (1996)::Comedy|Romance 29 | 1127::Abyss, The (1989)::Action|Adventure|Sci-Fi|Thriller 30 | 1128::Fog, The (1980)::Horror 31 | 1131::Jean de Florette (1986)::Drama 32 | 1206::Clockwork Orange, A (1971)::Sci-Fi 33 | 1208::Apocalypse Now (1979)::Drama|War 34 | 1276::Cool Hand Luke (1967)::Comedy|Drama 35 | 1303::Man Who Would Be King, The (1975)::Adventure 36 | 1358::Sling Blade (1996)::Drama|Thriller 37 | 1362::Garden of Finzi-Contini, The (Giardino dei Finzi-Contini, Il) (1970)::Drama 38 | 1414::Mother (1996)::Comedy 39 | 1419::Walkabout (1971)::Drama 40 | 1434::Stranger, The (1994)::Action 41 | 1436::Falling in Love Again (1980)::Comedy 42 | 1487::Selena (1997)::Drama|Musical 43 | 1489::Cats Don't Dance (1997)::Animation|Children's|Musical 44 | 1567::Last Time I Committed Suicide, The (1997)::Drama 45 | 1582::Wild America (1997)::Adventure|Children's 46 | 1622::Kicked in the Head (1997)::Comedy|Drama 47 | 1681::Mortal Kombat: Annihilation (1997)::Action|Adventure 48 | 1780::Ayn Rand: A Sense of Life (1997)::Documentary 49 | 1801::Man in the Iron Mask, The (1998)::Action|Drama|Romance 50 | 1811::Niagara, Niagara (1997)::Drama 51 | 1829::Chinese Box (1997)::Drama|Romance 52 | 1835::City of Angels (1998)::Romance 53 | 1883::Bulworth (1998)::Comedy 54 | 1911::Doctor Dolittle (1998)::Comedy 55 | 1960::Last Emperor, The (1987)::Drama|War 56 | 1992::Child's Play 2 (1990)::Horror 57 | 2011::Back to the Future Part II (1989)::Comedy|Sci-Fi 58 | 2019::Seven Samurai (The Magnificent Seven) (Shichinin no samurai) (1954)::Action|Drama 59 | 
2035::Blackbeard's Ghost (1968)::Children's|Comedy 60 | 2042::D2: The Mighty Ducks (1994)::Children's|Comedy 61 | 2214::Number Seventeen (1932)::Thriller 62 | 2236::Simon Birch (1998)::Drama 63 | 2318::Happiness (1998)::Comedy 64 | 2344::Runaway Train (1985)::Action|Adventure|Drama|Thriller 65 | 2415::Violets Are Blue... (1986)::Drama|Romance 66 | 2448::Virus (1999)::Horror|Sci-Fi 67 | 2473::Soul Man (1986)::Comedy 68 | 2565::King and I, The (1956)::Musical 69 | 2594::Open Your Eyes (Abre los ojos) (1997)::Drama|Romance|Sci-Fi 70 | 2833::Lucie Aubrac (1997)::Romance|War 71 | 2850::Public Access (1993)::Drama|Thriller 72 | 2863::Hard Day's Night, A (1964)::Comedy|Musical 73 | 2885::Guinevere (1999)::Drama|Romance 74 | 2906::Random Hearts (1999)::Drama|Romance 75 | 2918::Ferris Bueller's Day Off (1986)::Comedy 76 | 2940::Gilda (1946)::Film-Noir 77 | 2970::Fitzcarraldo (1982)::Adventure|Drama 78 | 3051::Anywhere But Here (1999)::Drama 79 | 3197::Presidio, The (1988)::Action 80 | 3221::Draughtsman's Contract, The (1982)::Drama 81 | 3247::Sister Act (1992)::Comedy|Crime 82 | 3270::Cutting Edge, The (1992)::Drama 83 | 3299::Hanging Up (2000)::Comedy|Drama 84 | 3413::Impact (1949)::Crime|Drama 85 | 3429::Creature Comforts (1990)::Animation|Comedy 86 | 3442::Band of the Hand (1986)::Action 87 | 3464::Solar Crisis (1993)::Sci-Fi|Thriller 88 | 3484::Skulls, The (2000)::Thriller 89 | 3500::Mr. Saturday Night (1992)::Comedy|Drama 90 | 3643::Fighting Seabees, The (1944)::Action|Drama|War 91 | 3665::Curse of the Puppet Master (1998)::Horror|Sci-Fi|Thriller 92 | 3691::Private School (1983)::Comedy 93 | 3786::But I'm a Cheerleader (1999)::Comedy 94 | 3943::Bamboozled (2000)::Comedy 95 | -------------------------------------------------------------------------------- /policy/inv_ff_history.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class InvariantFFHist(nn.Module): 6 | def __init__( 7 | self, 8 | embedding_dim, 9 | hidden_dim, 10 | problem, 11 | opts, 12 | tanh_clipping=None, 13 | mask_inner=None, 14 | mask_logits=None, 15 | n_encode_layers=None, 16 | normalization="batch", 17 | checkpoint_encoder=False, 18 | shrink_size=None, 19 | num_actions=4, 20 | n_heads=None, 21 | encoder=None, 22 | ): 23 | super(InvariantFFHist, self).__init__() 24 | 25 | self.embedding_dim = embedding_dim 26 | self.decode_type = None 27 | self.num_actions = 16 if opts.problem != "adwords" else 19 28 | self.problem = problem 29 | self.model_name = "inv-ff-hist" 30 | self.ff = nn.Sequential( 31 | nn.Linear(self.num_actions, 100), 32 | nn.ReLU(), 33 | nn.Linear(100, 100), 34 | nn.ReLU(), 35 | nn.Linear(100, 1), 36 | ) 37 | 38 | def forward(self, x, opts, optimizer, baseline, return_pi=True): 39 | 40 | _log_p, pi, cost = self._inner(x, opts) 41 | 42 | ll, e = self._calc_log_likelihood(_log_p, pi, None) 43 | if return_pi: 44 | return -cost, ll, pi, e 45 | 46 | return -cost, ll, e 47 | 48 | def _calc_log_likelihood(self, _log_p, a, mask): 49 | 50 | # Get log_p corresponding to selected actions 51 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean() 52 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1) 53 | 54 | # Optional: mask out actions irrelevant to objective so they do not get reinforced 55 | if mask is not None: 56 | log_p[mask] = 0 57 | if not (log_p > -1e8).data.all(): 58 | print(log_p) 59 | assert ( 60 | log_p > -1e8 61 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!" 
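# After the gather above, log_p holds one log-probability per V-node arrival for the action
# that was actually taken, so summing over dimension 1 below gives the log-likelihood of the
# whole matching sequence; entropy is the per-step policy entropy summed over actions and
# arrivals and averaged over the batch.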
62 | # Calculate log_likelihood 63 | return log_p.sum(1), entropy 64 | 65 | def _inner(self, input, opts): 66 | 67 | outputs = [] 68 | sequences = [] 69 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts) 70 | 71 | # Perform decoding steps 72 | i = 1 73 | while not (state.all_finished()): 74 | mask = state.get_mask() 75 | state.get_current_weights(mask) 76 | s, mask = state.get_curr_state(self.model_name) 77 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1) 78 | # Select the indices of the next nodes in the sequences, result (batch_size) long 79 | selected, p = self._select_node( 80 | pi, mask.bool() 81 | ) # Squeeze out steps dimension 82 | # entropy += torch.sum(p * (p.log()), dim=1) 83 | state = state.update((selected)[:, None]) 84 | outputs.append(p) 85 | sequences.append(selected) 86 | i += 1 87 | # Collected lists, return Tensor 88 | return ( 89 | torch.stack(outputs, 1), 90 | torch.stack(sequences, 1), 91 | state.size, 92 | ) 93 | 94 | def _select_node(self, probs, mask): 95 | assert (probs == probs).all(), "Probs should not contain any nans" 96 | probs[mask] = -1e8 97 | p = torch.log_softmax(probs, dim=1) 98 | # print(p) 99 | if self.decode_type == "greedy": 100 | _, selected = p.max(1) 101 | # assert not mask.gather( 102 | # 1, selected.unsqueeze(-1) 103 | # ).data.any(), "Decode greedy: infeasible action has maximum probability" 104 | 105 | elif self.decode_type == "sampling": 106 | selected = p.exp().multinomial(1).squeeze(1) 107 | # Check if sampling went OK, can go wrong due to bug on GPU 108 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232 109 | # while mask.gather(1, selected.unsqueeze(-1)).data.any(): 110 | # print("Sampled bad values, resampling!") 111 | # selected = p.exp().multinomial(1).squeeze(1) 112 | 113 | else: 114 | assert False, "Unknown decode type" 115 | return selected, p 116 | 117 | def set_decode_type(self, decode_type, temp=None): 118 | self.decode_type = decode_type 119 | if temp is not None: # Do not change temperature if not provided 120 | self.temp = temp 121 | -------------------------------------------------------------------------------- /data/MovieLense/users.txt: -------------------------------------------------------------------------------- 1 | 98,F,35,7,33547 2 | 4628,F,56,6,23225 3 | 250,M,35,16,11229 4 | 3530,M,25,0,77058 5 | 4749,M,45,0,23462 6 | 3021,F,45,15,78750 7 | 4192,M,35,17,77521 8 | 2530,M,18,0,42420 9 | 932,F,35,6,97838 10 | 5380,M,18,4,01125 11 | 2884,F,50,3,53222 12 | 3273,M,35,7,11201 13 | 217,M,18,4,22903 14 | 5012,M,50,15,48103 15 | 4419,F,45,12,15801 16 | 5828,M,25,17,90048 17 | 311,M,18,4,31201 18 | 2292,M,50,15,13039 19 | 4068,M,56,3,89185 20 | 1664,F,18,4,76707 21 | 1406,M,25,4,91911 22 | 572,M,18,4,61801 23 | 3222,M,35,7,10583 24 | 4398,F,18,4,43612 25 | 4393,M,18,4,43402 26 | 5533,F,50,2,10025 27 | 3291,F,25,7,27516 28 | 1993,M,35,17,99508 29 | 2061,M,56,1,44107 30 | 5725,M,18,2,10018 31 | 1351,F,25,17,98034 32 | 2111,M,56,13,32806 33 | 665,M,18,2,13317 34 | 2584,F,56,1,94563 35 | 761,M,18,7,99945 36 | 4332,F,25,3,60606 37 | 5168,M,35,7,01720 38 | 6038,F,56,1,14706 39 | 1452,F,56,13,90732 40 | 708,M,25,0,37042 41 | 703,F,56,1,44074 42 | 5314,F,25,0,94107 43 | 5145,M,35,7,77565-2332 44 | 3376,M,35,16,10803 45 | 2673,M,35,7,77040 46 | 2037,M,56,14,02468 47 | 2696,M,25,7,24210 48 | 160,M,35,7,13021 49 | 2160,F,18,19,46556 50 | 4775,M,25,17,23462 51 | 2714,M,18,4,96815 52 | 5258,M,35,4,58369 53 | 2133,F,1,10,01607 54 | 2759,M,25,7,36102 55 | 5216,M,18,2,81291 56 
| 5215,F,56,6,91941 57 | 5027,F,25,14,05345 58 | 3969,M,56,7,12345 59 | 3633,M,35,18,60441 60 | 4463,M,50,7,07405 61 | 947,M,18,0,21015 62 | 2128,F,45,1,46737 63 | 3407,F,45,5,60050 64 | 2819,M,45,14,44122 65 | 4244,M,35,17,48146 66 | 5525,F,1,10,55311 67 | 341,F,56,13,92119 68 | 4383,F,18,4,93940 69 | 4527,M,18,0,38111 70 | 4525,M,25,14,10021 71 | 421,F,45,3,55125 72 | 3787,F,25,0,94703 73 | 4991,F,56,14,91791 74 | 2488,F,45,2,31804 75 | 4178,F,50,16,08033 76 | 5904,F,45,12,954025 77 | 4365,F,50,16,46350 78 | 2381,F,35,9,13066 79 | 1844,F,18,4,53711 80 | 1493,F,25,14,60657 81 | 4547,M,18,12,02115 82 | 3552,F,35,2,02879 83 | 5309,M,25,0,02478 84 | 3883,M,50,16,78411 85 | 2930,M,50,17,45420 86 | 5159,M,35,7,37027 87 | 2686,M,25,0,85283 88 | 94,M,25,17,28601 89 | 1549,M,35,6,20901 90 | 5052,M,35,16,02536 91 | 5863,F,25,14,89511 92 | 5804,F,35,7,94306 93 | 4880,F,45,9,95030 94 | 3642,F,25,7,98188 95 | 3068,F,35,9,66204 96 | 1305,M,45,20,94121 97 | 1307,M,56,13,75703 98 | 653,M,56,13,92660 99 | 4230,M,56,13,34235 100 | 318,F,56,13,55104 101 | 317,M,35,7,38555 102 | 4065,F,45,5,48341 103 | 951,M,45,2,10009 104 | 3388,F,25,1,92646 105 | 833,M,35,7,46825 106 | 574,M,25,19,90214 107 | 356,M,56,20,55101 108 | 3225,M,18,0,15237 109 | 559,F,25,7,60422 110 | 4259,M,35,17,45380 111 | 5899,M,35,17,30024 112 | 2422,F,50,9,02865 113 | 3324,M,35,17,48083 114 | 535,M,35,6,95370 115 | 1431,M,56,6,91011 116 | 1488,M,25,7,94538 117 | 1226,M,18,20,07087 118 | 2234,M,56,1,85718 119 | 3495,M,18,4,94608 120 | 4558,M,35,14,48217 121 | 4554,F,25,9,48306 122 | 2373,M,45,7,94109 123 | 4943,F,50,0,55401 124 | 3133,M,18,4,33720 125 | 3897,M,56,17,92069 126 | 6034,M,25,14,94117 127 | 2908,M,18,2,74403 128 | 1459,F,25,0,80027 129 | 2923,M,35,12,44256 130 | 3176,M,25,17,95129 131 | 1713,M,25,0,47421 132 | 1751,M,35,20,77024 133 | 1282,M,50,1,98541 134 | 1363,M,50,7,01915 135 | 383,F,25,7,78757 136 | 226,M,35,1,94518 137 | 2502,M,56,13,49423 138 | 1250,M,50,0,10573 139 | 1126,M,56,16,13031 140 | 5810,M,50,16,95070 141 | 4614,F,45,1,10469 142 | 821,M,18,2,14456 143 | 2527,M,35,15,36109 144 | 2893,M,25,1,94025 145 | 4,M,45,7,02460 146 | 906,M,1,10,71106 147 | 907,F,45,6,92117 148 | 1318,M,18,16,94014 149 | 3616,M,45,0,33759 150 | 1534,M,50,12,12184 151 | 3234,F,1,0,90012 152 | 1615,F,45,7,10022 153 | 640,M,18,4,47406 154 | 4464,F,50,1,62052 155 | 5409,M,45,1,60091 156 | 4264,M,45,1,60102 157 | 320,M,35,6,99516 158 | 4871,M,18,12,04096 159 | 3188,M,56,13,55344 160 | 5529,M,18,6,05273 161 | 2207,M,18,4,07508 162 | 2336,M,35,18,13146 163 | 4992,F,25,2,55106 164 | 3275,M,35,7,10028 165 | 3488,M,45,17,80228 166 | 1801,F,25,6,54729 167 | 4548,M,18,2,27514 168 | 1070,F,45,2,20036 169 | 2977,F,18,4,48104 170 | 5117,F,25,0,70118 171 | 5116,M,35,6,95521 172 | 1783,M,18,4,10027 173 | 4110,F,45,7,41722 174 | 2640,M,35,12,78230 175 | 2644,M,18,15,45215 176 | 5174,M,45,1,45208 177 | 89,F,56,9,85749 178 | 277,F,35,1,98126 179 | 4665,M,35,2,03431 180 | 2516,F,45,6,37909 181 | 1157,M,45,7,30253 182 | 4741,M,35,7,15203 183 | 158,M,56,7,28754 184 | 1708,M,25,0,64106 185 | 1927,M,35,7,94806 186 | 4881,M,45,16,55406 187 | 1582,M,18,0,10705 188 | 213,F,18,4,01609 189 | 197,M,18,14,10023 190 | 315,F,56,1,55105 191 | 2295,M,45,11,72316 192 | 5207,M,25,0,98102 193 | 3623,M,35,11,43209 194 | 2467,M,1,10,44224 195 | 5439,M,56,13,19119 196 | 1621,M,25,4,11423 197 | 417,F,25,0,50613 198 | 1384,M,18,10,98058 199 | 1385,F,25,5,90262 200 | 5465,F,25,0,10019 201 | -------------------------------------------------------------------------------- 
/policy/greedy_matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.utils import subgraph, to_networkx 4 | from networkx.algorithms.matching import max_weight_matching 5 | from torch_geometric.data import Data 6 | 7 | 8 | class GreedyMatching(nn.Module): 9 | def __init__( 10 | self, 11 | embedding_dim, 12 | hidden_dim, 13 | problem, 14 | opts, 15 | tanh_clipping=None, 16 | mask_inner=None, 17 | mask_logits=None, 18 | n_encode_layers=None, 19 | normalization=None, 20 | checkpoint_encoder=False, 21 | shrink_size=None, 22 | num_actions=None, 23 | n_heads=None, 24 | encoder=None, 25 | ): 26 | super(GreedyMatching, self).__init__() 27 | self.decode_type = None 28 | self.problem = problem 29 | self.model_name = "greedy-m" 30 | 31 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 32 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts) 33 | t = opts.threshold 34 | sequences = [] 35 | batch_size = opts.batch_size 36 | graph_size = opts.u_size + opts.v_size + 1 37 | i = 1 38 | while not (state.all_finished()): 39 | step_size = state.i + 1 40 | mask = state.get_mask() 41 | state.get_current_weights(mask).clone() 42 | mask = state.get_mask() 43 | if i <= int( 44 | t * opts.v_size 45 | ): # Skip matching if less than t fraction of V nodes arrived 46 | selected = torch.zeros( 47 | batch_size, dtype=torch.int64, device=opts.device 48 | ) 49 | state = state.update(selected[:, None]) 50 | sequences.append(selected) 51 | i += 1 52 | continue 53 | nodes = torch.cat( 54 | ( 55 | torch.arange(1, opts.u_size + 1, device=opts.device), 56 | state.idx[:i] + opts.u_size + 1, 57 | ) 58 | ) 59 | subgraphs = ( 60 | (nodes.unsqueeze(0).expand(batch_size, step_size - 1)) 61 | + torch.arange( 62 | 0, batch_size * graph_size, graph_size, device=opts.device 63 | ).unsqueeze(1) 64 | ).flatten() # The nodes of the current subgraphs 65 | graph_weights = state.get_graph_weights() 66 | edge_i, weights = subgraph( 67 | subgraphs, 68 | state.graphs.edge_index, 69 | graph_weights.unsqueeze(1), 70 | relabel_nodes=True, 71 | ) 72 | match_sol = torch.tensor( 73 | list( 74 | max_weight_matching( 75 | to_networkx( 76 | Data(subgraphs, edge_index=edge_i, edge_attr=weights), 77 | to_undirected=True, 78 | ) 79 | ) 80 | ), 81 | device=opts.device, 82 | ) 83 | 84 | edges_sol = torch.cat( 85 | ( 86 | torch.arange(0, opts.u_size, device=opts.device) 87 | .unsqueeze(0) 88 | .expand(batch_size, -1) 89 | .unsqueeze(2), 90 | torch.ones(batch_size, opts.u_size, 1) * (opts.u_size + i - 1), 91 | ), 92 | dim=2, 93 | ) 94 | match_sol = match_sol.sort(dim=-1)[0] 95 | offset = torch.arange(0, batch_size * (opts.u_size + i), opts.u_size + i)[ 96 | :, None, None 97 | ] 98 | edges_sol = edges_sol + offset 99 | in_sol = edges_sol.unsqueeze(1) == match_sol[None, :, :, None].expand( 100 | batch_size, -1, -1, opts.u_size 101 | ).transpose(-1, -2) 102 | in_sol = in_sol * (1 - mask[:, None, 1:, None]) 103 | in_sol = ( 104 | in_sol.prod(-1).sum(-1).float().unsqueeze(1) 105 | * match_sol[None, :, :].transpose(1, 2) 106 | ).sum(-1) 107 | skip_sol = in_sol.sum(-1) 108 | in_sol = in_sol[:, 0] - offset.reshape(-1) 109 | in_sol[skip_sol == 0] = -1 110 | selected = (in_sol + 1).type(torch.int64) 111 | # print(selected, match_sol) 112 | state = state.update(selected[:, None]) 113 | 114 | sequences.append(selected) 115 | i += 1 116 | if return_pi: 117 | return -state.size, None, torch.stack(sequences, 1), None 118 | return 
-state.size, torch.stack(sequences, 1), None 119 | 120 | def set_decode_type(self, decode_type, temp=None): 121 | self.decode_type = decode_type 122 | if temp is not None: # Do not change temperature if not provided 123 | self.temp = temp 124 | -------------------------------------------------------------------------------- /policy/ff_model_invariant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class InvariantFF(nn.Module): 6 | def __init__( 7 | self, 8 | embedding_dim, 9 | hidden_dim, 10 | problem, 11 | opts, 12 | tanh_clipping=None, 13 | mask_inner=None, 14 | mask_logits=None, 15 | n_encode_layers=None, 16 | normalization="batch", 17 | checkpoint_encoder=False, 18 | shrink_size=None, 19 | num_actions=4, 20 | n_heads=None, 21 | encoder=None, 22 | ): 23 | 24 | super(InvariantFF, self).__init__() 25 | 26 | self.embedding_dim = embedding_dim 27 | self.decode_type = None 28 | self.num_actions = 3 if opts.problem != "adwords" else 5 29 | self.is_bipartite = problem.NAME == "bipartite" 30 | self.problem = problem 31 | self.shrink_size = None 32 | self.ff = nn.Sequential( 33 | nn.Linear(self.num_actions, 100), 34 | nn.ReLU(), 35 | nn.Linear(100, 100), 36 | nn.ReLU(), 37 | nn.Linear(100, 1), 38 | ) 39 | self.model_name = "inv-ff" 40 | 41 | # def init_weights(m): 42 | # if type(m) == nn.Linear: 43 | # torch.nn.init.xavier_uniform_(m.weight) 44 | # m.bias.data.fill_(0.0001) 45 | 46 | # self.ff.apply(init_weights) 47 | 48 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 49 | 50 | _log_p, pi, cost = self._inner(x, opts) 51 | 52 | # cost, mask = self.problem.get_costs(input, pi) 53 | # Log likelyhood is calculated within the model since returning it per action does not work well with 54 | # DataParallel since sequences can be of different lengths 55 | ll, e = self._calc_log_likelihood(_log_p, pi, None) 56 | if return_pi: 57 | return -cost, ll, pi, e 58 | # print(ll) 59 | return -cost, ll, e 60 | 61 | def _calc_log_likelihood(self, _log_p, a, mask): 62 | 63 | # Get log_p corresponding to selected actions 64 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1) 65 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean() 66 | # Optional: mask out actions irrelevant to objective so they do not get reinforced 67 | if mask is not None: 68 | log_p[mask] = 0 69 | if not (log_p > -10000).data.all(): 70 | print(log_p) 71 | assert ( 72 | log_p > -10000 73 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!" 
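# The first tensor returned just below is the per-instance log-likelihood of the chosen
# action sequence: the log-probability of each selected node is gathered and summed over
# the v_size arrivals. The second value is the policy entropy, -sum_a p(a) * log p(a),
# summed over actions and arrivals and averaged over the batch; the training loop can use
# it as an exploration bonus (presumably scaled by the --ent_rate option from TUTORIAL.md).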
74 | 75 | # Calculate log_likelihood 76 | return log_p.sum(1), entropy 77 | 78 | def _inner(self, input, opts): 79 | 80 | outputs = [] 81 | sequences = [] 82 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts) 83 | 84 | i = 1 85 | while not (state.all_finished()): 86 | mask = state.get_mask() 87 | state.get_current_weights(mask) 88 | s, mask = state.get_curr_state(self.model_name) 89 | 90 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1) 91 | # Select the indices of the next nodes in the sequences, result (batch_size) long 92 | selected, p = self._select_node(pi, mask.bool()) 93 | 94 | state = state.update((selected)[:, None]) 95 | outputs.append(p) 96 | sequences.append(selected) 97 | i += 1 98 | # Collected lists, return Tensor 99 | return ( 100 | torch.stack(outputs, 1), 101 | torch.stack(sequences, 1), 102 | state.size, 103 | ) 104 | 105 | def _select_node(self, probs, mask): 106 | assert (probs == probs).all(), "Probs should not contain any nans" 107 | probs[mask] = -1e6 108 | p = torch.log_softmax(probs, dim=1) 109 | # print(p) 110 | if self.decode_type == "greedy": 111 | _, selected = p.max(1) 112 | # assert not mask.gather( 113 | # 1, selected.unsqueeze(-1) 114 | # ).data.any(), "Decode greedy: infeasible action has maximum probability" 115 | 116 | elif self.decode_type == "sampling": 117 | selected = p.exp().multinomial(1).squeeze(1) 118 | # Check if sampling went OK, can go wrong due to bug on GPU 119 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232 120 | # while mask.gather(1, selected.unsqueeze(-1)).data.any(): 121 | # print("Sampled bad values, resampling!") 122 | # selected = probs.multinomial(1).squeeze(1) 123 | 124 | else: 125 | assert False, "Unknown decode type" 126 | return selected, p 127 | 128 | def set_decode_type(self, decode_type, temp=None): 129 | self.decode_type = decode_type 130 | if temp is not None: # Do not change temperature if not provided 131 | self.temp = temp 132 | -------------------------------------------------------------------------------- /TUTORIAL.md: -------------------------------------------------------------------------------- 1 | # Tutorial 2 | 3 | Here we provide a brief tutorial about our [codebase](https://github.com/lyeskhalil/CORL). If you are already familiar with the codes and the framework, please skip this tutorial. 4 | 5 | ## 1. Bipartite Data Generation Code 6 | 7 | 8 | ### Supported Args 9 | ``` 10 | problem: "adwords", "e-obm", "osbm" 11 | graph_family: "er", "ba", "triangular", "thick-z", "movielense", "gmission" 12 | weight_distribution: "normal", "power", "degree", "node-normal", "fixed-normal" 13 | weight_distribution_param: ... 14 | graph_family_parameter: ... 15 | ``` 16 | Caution: We only list all the feasible parameters here, some combination may not be allowed (adwords + gmission), please check before using these arguments. 17 | 18 | ## 2. 
NetworkX Guide 19 | ### Print information for a graph 20 | ```python 21 | print(G) 22 | #Graph named 'triangular_graph(3,9,-1.0)' with 12 nodes and 18 edges 23 | print(G.graph) 24 | #{'name': 'triangular_graph(3,9,-1.0)'} 25 | print(G.nodes) 26 | #[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] 27 | print(G.nodes[1]) #{'bipartite': 0} 28 | print(G.edges) 29 | #[(0, 3), (0, 4), (0, 5), (1, 3), ..., (2, 6), (2, 7), (2, 8), (2, 9), (2, 10), (2, 11)] 30 | print(G.edges[0,3]) 31 | #{'weight': 0.3958830486584797} 32 | ``` 33 | 34 | 35 | ### Convert data type 36 | Convert the networkx graph to a PyTorch Geometric `Data` object 37 | ```python 38 | def from_networkx(G): 39 | r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a 40 | :class:`torch_geometric.data.Data` instance""" 41 | pass 42 | ``` 43 | Reveal useful information from the `torch_geometric.data.Data` class 44 | ```python 45 | print(data.edge_index) # Print the edges between nodes; the graph is directed (two edges are needed for one pair of nodes) 46 | print(data.weight) # Print the edge weights 47 | print(data.bipartite) # Print the bipartite partition (which side each node belongs to) 48 | print(data.x) # Capacity of each resource 49 | print(data.y) # Optimal solution, (total_gain, action_array) 50 | print(data.edge_attr) # None for this project 51 | ``` 52 | 53 | ## 3. Important Arguments for the OBM codebase 54 | ```shell 55 | 'load_path': 'path_to_pretrained_model', # Model path 56 | 'checkpoint_epochs': 0, # The frequency of routine checkpoint saving 57 | 'model': 'inv-ff-hist', # Model type 58 | ``` 59 | 60 | 61 | ## 4. RL Model For OBM 62 | ### Request a new GPU node in the SLURM cluster (skip this step on your own server) 63 | ```shell 64 | srun -p gpu --gres=gpu:1 --mem=100g --nodelist=gpu02 --time=48:00:00 --pty bash -l 65 | ``` 66 | 67 | ### Quick Start - Data Generation 68 | ```shell 69 | python data/generate_data.py --problem adwords --dataset_size 1000 \ 70 | --dataset_folder dataset/train/adwords_triangular_uniform_0.10.4_10by100/parameter_-1 \ 71 | --u_size 10 --v_size 60 --graph_family triangular --weight_distribution uniform \ 72 | --weight_distribution_param 0.1 0.4 --graph_family_parameter -1 73 | 74 | python data/generate_data.py --problem adwords --dataset_size 1 \ 75 | --dataset_folder dataset/train/adwords_triangular_triangular_0.10.4_10by100/parameter_-1 \ 76 | --u_size 3 --v_size 9 --graph_family ba --weight_distribution uniform \ 77 | --weight_distribution_param 0 1 --graph_family_parameter 2 78 | ``` 79 | 80 | ### Quick Start - Training 81 | ```shell 82 | python run.py --encoder mpnn --model inv-ff-hist --problem adwords --batch_size 100 --embedding_dim 30 --n_heads 1 --u_size 10 --v_size 60 \ 83 | --n_epochs 20 --train_dataset dataset/train/adwords_triangular_uniform_0.10.4_10by60/parameter_-1 \ 84 | --val_dataset dataset/val/adwords_triangular_uniform_0.10.4_10by60/parameter_-1 \ 85 | --dataset_size 1000 --val_size 100 --checkpoint_epochs 0 --baseline exponential --lr_model 0.006 --lr_decay 0.97 \ 86 | --output_dir saved_models --log_dir logs_dataset --n_encode_layers 1 \ 87 | --save_dir saved_models/adwords_triangular_uniform_0.10.4_10by60/parameter_-1 \ 88 | --graph_family_parameter -1 --exp_beta 0.8 --ent_rate 0.0006 89 | ``` 90 | 91 | ### Quick Start - Evaluation 92 | * Note: Please train the model first. If you already have a trained model, you can use the command below to validate its performance. Don't forget to replace `load_path` and the other arguments accordingly.
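The command below evaluates the `inv-ff-hist-switch` model, which wraps the learned `inv-ff-hist` policy with a greedy expert (see `policy/inv_ff_history_switch.py`). The switch-specific flags map to `opts.switch_lambda`, `opts.slackness`, and `opts.max_reward`: roughly, for every offline node the learned policy has used but the expert has not, a reserve of `max_reward` is first added to the expert's reward, and the ML action is then kept only if its hypothetical cumulative reward is at least `switch_lambda` times that reserve-adjusted expert reward, minus `slackness` (a soft, probability-mixture version of this test is used during training; the hard test is used at evaluation time). Below is a scalar restatement of the batched tensor code, with illustrative names — `keep_ml_action` and its arguments are not part of the codebase:

```python
def keep_ml_action(reward_ml_hyp, reward_gd, action_diff,
                   switch_lambda, slackness, max_reward):
    """Hard-switch test of policy/inv_ff_history_switch.py, restated for scalars.

    reward_ml_hyp: cumulative reward if the ML-proposed action were taken now
    reward_gd:     cumulative reward collected by the greedy expert so far
    action_diff:   offline nodes already used by the ML policy but not matched by the expert
    """
    reward_gd_reserved = reward_gd + action_diff * max_reward   # --max_reward
    return reward_ml_hyp >= switch_lambda * reward_gd_reserved - slackness
```

With `--switch_lambda 0.0 --slackness 0.0`, as in the command below, the test reduces to `reward_ml_hyp >= 0`, so the learned policy's action is essentially always kept (assuming non-negative rewards).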
93 | ```shell 94 | python3 run_lite.py --encoder mpnn --model inv-ff-hist-switch --problem osbm --batch_size 100 --embedding_dim 30 --n_heads 1 --u_size 10 --v_size 60 --n_epochs 300 --train_dataset dataset/train/osbm_movielense_default_-1_10by60/parameter_-1 --val_dataset dataset/val/osbm_movielense_default_-1_10by60/parameter_-1 --dataset_size 20000 --val_size 1000 --checkpoint_epochs 0 --baseline exponential --lr_model 0.006 --lr_decay 0.97 --output_dir saved_models --log_dir logs_dataset --n_encode_layers 1 --save_dir saved_models/osbm_movielense_default_-1_10by60/parameter_-1 --graph_family_parameter -1 --exp_beta 0.8 --ent_rate 0.0006 --eval_only --no_tensorboard --load_path saved_models/inv-ff-hist/run_20220501T195416/best-model.pt --switch_lambda 0.0 --slackness 0.0 --max_reward 5.0 95 | ``` 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /policy/ff_model_hist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | 6 | class FeedForwardModelHist(nn.Module): 7 | def __init__( 8 | self, 9 | embedding_dim, 10 | hidden_dim, 11 | problem, 12 | opts, 13 | tanh_clipping=None, 14 | mask_inner=None, 15 | mask_logits=None, 16 | n_encode_layers=None, 17 | normalization="batch", 18 | checkpoint_encoder=False, 19 | shrink_size=None, 20 | num_actions=4, 21 | n_heads=None, 22 | encoder=None, 23 | ): 24 | 25 | super(FeedForwardModelHist, self).__init__() 26 | 27 | self.embedding_dim = embedding_dim 28 | self.decode_type = None 29 | self.num_actions = ( 30 | 5 * (opts.u_size + 1) + 8 31 | if opts.problem != "adwords" 32 | else 7 * (opts.u_size + 1) + 8 33 | ) 34 | self.problem = problem 35 | self.model_name = "ff-hist" 36 | hidden_size = 100 37 | self.ff = nn.Sequential( 38 | nn.Linear(self.num_actions, hidden_size), 39 | nn.ReLU(), 40 | nn.Linear(hidden_size, hidden_size), 41 | nn.ReLU(), 42 | nn.Linear(hidden_size, hidden_size), 43 | nn.ReLU(), 44 | nn.Linear(hidden_size, opts.u_size + 1), 45 | ) 46 | 47 | # self.ff.apply(init_weights) 48 | # self.init_parameters() 49 | 50 | def init_parameters(self): 51 | for name, param in self.named_parameters(): 52 | stdv = 1.0 / math.sqrt(param.size(-1)) 53 | param.data.uniform_(-stdv, stdv) 54 | 55 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 56 | 57 | _log_p, pi, cost = self._inner(x, opts) 58 | 59 | # cost, mask = self.problem.get_costs(input, pi) 60 | # Log likelyhood is calculated within the model since returning it per action does not work well with 61 | # DataParallel since sequences can be of different lengths 62 | ll, e = self._calc_log_likelihood(_log_p, pi, None) 63 | if return_pi: 64 | return -cost, ll, pi, e 65 | # print(ll) 66 | return -cost, ll, e 67 | 68 | def _calc_log_likelihood(self, _log_p, a, mask): 69 | 70 | # Get log_p corresponding to selected actions 71 | # print(a[0, :]) 72 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean() 73 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1) 74 | 75 | # Optional: mask out actions irrelevant to objective so they do not get reinforced 76 | if mask is not None: 77 | log_p[mask] = 0 78 | if not (log_p > -10000).data.all(): 79 | print(log_p) 80 | assert ( 81 | log_p > -10000 82 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!" 
83 | 84 | # Calculate log_likelihood 85 | # print(_log_p) 86 | return log_p.sum(1), entropy 87 | 88 | def _inner(self, input, opts): 89 | 90 | outputs = [] 91 | sequences = [] 92 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts) 93 | 94 | i = 1.0 95 | while not (state.all_finished()): 96 | mask = state.get_mask() 97 | state.get_current_weights(mask) 98 | 99 | s, mask = state.get_curr_state(self.model_name) 100 | pi = self.ff(s) 101 | # Select the indices of the next nodes in the sequences, result (batch_size) long 102 | selected, p = self._select_node(pi, mask.bool()) 103 | state = state.update((selected)[:, None]) 104 | outputs.append(p) 105 | sequences.append(selected) 106 | i += 1.0 107 | # Collected lists, return Tensor 108 | return ( 109 | torch.stack(outputs, 1), 110 | torch.stack(sequences, 1), 111 | state.size, 112 | ) 113 | 114 | def _select_node(self, probs, mask): 115 | assert (probs == probs).all(), "Probs should not contain any nans" 116 | 117 | probs[mask] = -1e8 118 | p = torch.log_softmax(probs, dim=1) 119 | if self.decode_type == "greedy": 120 | _, selected = p.max(1) 121 | # assert not mask.gather( 122 | # 1, selected.unsqueeze(-1) 123 | # ).data.any(), "Decode greedy: infeasible action has maximum probability" 124 | 125 | elif self.decode_type == "sampling": 126 | selected = p.exp().multinomial(1).squeeze(1) 127 | # Check if sampling went OK, can go wrong due to bug on GPU 128 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232 129 | while mask.gather(1, selected.unsqueeze(-1)).data.any(): 130 | print("Sampled bad values, resampling!") 131 | selected = p.exp().multinomial(1).squeeze(1) 132 | 133 | else: 134 | assert False, "Unknown decode type" 135 | return selected, p 136 | 137 | def set_decode_type(self, decode_type, temp=None): 138 | self.decode_type = decode_type 139 | if temp is not None: # Do not change temperature if not provided 140 | self.temp = temp 141 | -------------------------------------------------------------------------------- /policy/ff_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class FeedForwardModel(nn.Module): 6 | def __init__( 7 | self, 8 | embedding_dim, 9 | hidden_dim, 10 | problem, 11 | opts, 12 | tanh_clipping=None, 13 | mask_inner=None, 14 | mask_logits=None, 15 | n_encode_layers=None, 16 | normalization="batch", 17 | checkpoint_encoder=False, 18 | shrink_size=None, 19 | num_actions=4, 20 | n_heads=None, 21 | encoder=None, 22 | ): 23 | 24 | super(FeedForwardModel, self).__init__() 25 | 26 | self.embedding_dim = embedding_dim 27 | self.decode_type = None 28 | self.num_actions = ( 29 | 2 * (opts.u_size + 1) 30 | if opts.problem != "adwords" 31 | else 3 * (opts.u_size + 1) 32 | ) 33 | self.is_bipartite = problem.NAME == "bipartite" 34 | self.problem = problem 35 | self.shrink_size = None 36 | hidden_size = 100 37 | self.ff = nn.Sequential( 38 | nn.Linear(self.num_actions, hidden_size), 39 | nn.ReLU(), 40 | nn.Linear(hidden_size, hidden_size), 41 | nn.ReLU(), 42 | nn.Linear(hidden_size, hidden_size), 43 | nn.ReLU(), 44 | nn.Linear(hidden_size, opts.u_size + 1), 45 | ) 46 | self.model_name = "ff" 47 | 48 | def init_weights(m): 49 | if type(m) == nn.Linear: 50 | torch.nn.init.xavier_uniform_(m.weight) 51 | m.bias.data.fill_(0.0001) 52 | 53 | # self.ff.apply(init_weights) 54 | 55 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 56 | 57 | _log_p, pi, cost = self._inner(x, opts) 58 
| 59 | # cost, mask = self.problem.get_costs(input, pi) 60 | # Log likelyhood is calculated within the model since returning it per action does not work well with 61 | # DataParallel since sequences can be of different lengths 62 | ll, e = self._calc_log_likelihood(_log_p, pi, None) 63 | if return_pi: 64 | return -cost, ll, pi, e 65 | # print(ll) 66 | return -cost, ll, e 67 | 68 | def _calc_log_likelihood(self, _log_p, a, mask): 69 | 70 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean() 71 | # Get log_p corresponding to selected actions 72 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1) 73 | 74 | # Optional: mask out actions irrelevant to objective so they do not get reinforced 75 | if mask is not None: 76 | log_p[mask] = 0 77 | if not (log_p > -10000).data.all(): 78 | print(log_p.nonzero()) 79 | assert ( 80 | log_p > -10000 81 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!" 82 | 83 | # Calculate log_likelihood 84 | # print(log_p.sum(1)) 85 | return log_p.sum(1), entropy 86 | 87 | def _inner(self, input, opts): 88 | 89 | outputs = [] 90 | sequences = [] 91 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts) 92 | 93 | # step_context = 0 94 | # batch_size = state.ids.size(0) 95 | # Perform decoding steps 96 | i = 1 97 | # entropy = 0 98 | while not (state.all_finished()): 99 | mask = state.get_mask() 100 | state.get_current_weights(mask) 101 | s, mask = state.get_curr_state(self.model_name) 102 | # s = w 103 | pi = self.ff(s) 104 | # Select the indices of the next nodes in the sequences, result (batch_size) long 105 | selected, p = self._select_node( 106 | pi, mask.bool() 107 | ) # Squeeze out steps dimension 108 | # entropy += torch.sum(p * (p.log()), dim=1) 109 | state = state.update((selected)[:, None]) 110 | outputs.append(p) 111 | sequences.append(selected) 112 | i += 1 113 | # Collected lists, return Tensor 114 | return ( 115 | torch.stack(outputs, 1), 116 | torch.stack(sequences, 1), 117 | state.size, 118 | ) 119 | 120 | def _select_node(self, probs, mask): 121 | assert (probs == probs).all(), "Probs should not contain any nans" 122 | probs[mask] = -1e6 123 | p = torch.log_softmax(probs, dim=1) 124 | if self.decode_type == "greedy": 125 | _, selected = p.max(1) 126 | # assert not mask.gather( 127 | # 1, selected.unsqueeze(-1) 128 | # ).data.any(), "Decode greedy: infeasible action has maximum probability" 129 | 130 | elif self.decode_type == "sampling": 131 | selected = p.exp().multinomial(1).squeeze(1) 132 | # Check if sampling went OK, can go wrong due to bug on GPU 133 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232 134 | # while mask.gather(1, selected.unsqueeze(-1)).data.any(): 135 | # print("Sampled bad values, resampling!") 136 | # selected = probs.multinomial(1).squeeze(1) 137 | 138 | else: 139 | assert False, "Unknown decode type" 140 | return selected, p 141 | 142 | def set_decode_type(self, decode_type, temp=None): 143 | self.decode_type = decode_type 144 | if temp is not None: # Do not change temperature if not provided 145 | self.temp = temp 146 | -------------------------------------------------------------------------------- /problem_state/obm_env.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import NamedTuple 3 | 4 | # from utils.boolmask import mask_long2bool, mask_long_scatter 5 | 6 | 7 | class StateBipartite(NamedTuple): 8 | # Fixed input 9 | graphs: torch.Tensor # full adjacency matrix of all graphs in 
a batch 10 | # adj: torch.Tensor # full adjacency matrix of all graphs in a batch 11 | weights: torch.Tensor 12 | u_size: torch.Tensor 13 | v_size: torch.Tensor 14 | batch_size: torch.Tensor 15 | # If this state contains multiple copies (i.e. beam search) for the same instance, then for memory efficiency 16 | # the loc and dist tensors are not kept multiple times, so we need to use the ids to index the correct rows. 17 | ids: torch.Tensor # Keeps track of original fixed data index of rows 18 | 19 | # State 20 | # curr_edge: torch.Tensor # current edge number 21 | matched_nodes: torch.Tensor # Keeps track of nodes that have been matched 22 | # picked_edges: torch.Tensor 23 | size: torch.Tensor # size of current matching 24 | i: torch.Tensor # Keeps track of step 25 | # mask: torch.Tensor # mask for each step 26 | 27 | @property 28 | def visited(self): 29 | if self.visited_.dtype == torch.uint8: 30 | return self.matched_nodes 31 | 32 | def __getitem__(self, key): 33 | if torch.is_tensor(key) or isinstance( 34 | key, slice 35 | ): # If tensor, idx all tensors by this tensor: 36 | return self._replace( 37 | ids=self.ids[key], 38 | graphs=self.graphs[key], 39 | matched_nodes=self.matched_nodes[key], 40 | size=self.size[key], 41 | u_size=self.u_size[key], 42 | v_size=self.v_size[key], 43 | ) 44 | # return super(StateBipartite, self).__getitem__(key) 45 | return self[key] 46 | 47 | @staticmethod 48 | def initialize( 49 | input, u_size, v_size, num_edges, visited_dtype=torch.uint8, 50 | ): 51 | 52 | batch_size = len(input) 53 | # size = torch.zeros(batch_size, 1, dtype=torch.long, device=graphs.device) 54 | return StateBipartite( 55 | graphs=input, 56 | u_size=torch.tensor([u_size], device=input.device), 57 | v_size=torch.tensor([v_size], device=input.device), 58 | weights=None, 59 | batch_size=torch.tensor([batch_size], device=input.device), 60 | ids=torch.arange(batch_size, dtype=torch.int64, device=input.device)[ 61 | :, None 62 | ], # Add steps dimension 63 | # Keep visited with depot so we can scatter efficiently (if there is an action for depot) 64 | matched_nodes=( # Visited as mask is easier to understand, as long more memory efficient 65 | torch.zeros( 66 | batch_size, 1, u_size + 1, dtype=torch.uint8, device=input.device 67 | ) 68 | ), 69 | size=torch.zeros(batch_size, 1, device=input.device), 70 | i=torch.ones(1, dtype=torch.int64, device=input.device) 71 | * (u_size + 1), # Vector with length num_steps 72 | ) 73 | 74 | def get_final_cost(self): 75 | 76 | assert self.all_finished() 77 | # assert self.visited_. 78 | 79 | return self.size 80 | 81 | def update(self, selected): 82 | # Update the state 83 | nodes = self.matched_nodes.squeeze(1).scatter_(-1, selected, 1) 84 | 85 | total_weights = self.size + (selected != 0).long() 86 | 87 | return self._replace(matched_nodes=nodes, size=total_weights, i=self.i + 1,) 88 | 89 | def all_finished(self): 90 | # Exactly n steps 91 | return (self.i.item() - (self.u_size.item() + 1)) >= self.v_size 92 | 93 | def get_current_node(self): 94 | return self.i.item() 95 | 96 | def get_mask(self): 97 | """ 98 | Returns a mask vector which includes only nodes in U that can matched. 99 | That is, neighbors of the incoming node that have not been matched already. 
100 | """ 101 | mask = self.graphs[:, self.i.item(), : self.u_size.item() + 1] 102 | 103 | self.matched_nodes[ 104 | :, 0 105 | ] = 0 # node that represents not being matched to anything can be matched to more than once 106 | return ( 107 | self.matched_nodes.squeeze(1) + mask > 0 108 | ).long() # Hacky way to return bool or uint8 depending on pytorch version 109 | 110 | # def get_nn(self, k=None): 111 | # # Insert step dimension 112 | # # Nodes already visited get inf so they do not make it 113 | # if k is None: 114 | # k = self.loc.size(-2) - self.i.item() # Number of remaining 115 | # return ( 116 | # self.dist[self.ids, :, :] + self.visited.float()[:, :, None, :] * 1e6 117 | # ).topk(k, dim=-1, largest=False)[1] 118 | 119 | # def get_nn_current(self, k=None): 120 | # assert ( 121 | # False 122 | # ), "Currently not implemented, look into which neighbours to use in step 0?" 123 | # # Note: if this is called in step 0, it will have k nearest neighbours to node 0, which may not be desired 124 | # # so it is probably better to use k = None in the first iteration 125 | # if k is None: 126 | # k = self.loc.size(-2) 127 | # k = min(k, self.loc.size(-2) - self.i.item()) # Number of remaining 128 | # return (self.dist[self.ids, self.prev_a] + self.visited.float() * 1e6).topk( 129 | # k, dim=-1, largest=False 130 | # )[1] 131 | 132 | # def construct_solutions(self, actions): 133 | # return actions 134 | -------------------------------------------------------------------------------- /IPsolvers/IPsolver.py: -------------------------------------------------------------------------------- 1 | import gurobipy as gp 2 | from gurobipy import GRB 3 | import itertools 4 | import numpy as np 5 | 6 | 7 | def get_data_adwords(u_size, v_size, adjacency_matrix): 8 | """ 9 | pre-process the data for groubi for the adwords problem 10 | Reads data from the specfied file and writes the graph tensor into multu dict of the following form: 11 | combinations, ms= gp.multidict({ 12 | ('u1','v1'):10, 13 | ('u1','v2'):13, 14 | ('u2','v1'):9, 15 | ('u2','v2'):3 16 | }) 17 | """ 18 | 19 | adj_dic = {} 20 | 21 | for v, u in itertools.product(range(v_size), range(u_size)): 22 | adj_dic[(v, u)] = adjacency_matrix[u, v] 23 | 24 | return gp.multidict(adj_dic) 25 | 26 | 27 | def get_data_osbm(u_size, v_size, adjacency_matrix, prefrences): 28 | """ 29 | pre-process the data for groubi for the osbm problem 30 | """ 31 | 32 | adj_dic = {} 33 | w = {} 34 | 35 | for v, u in itertools.product(range(v_size), range(u_size)): 36 | adj_dic[(v, u)] = adjacency_matrix[v][u] 37 | _, dic = gp.multidict(adj_dic) 38 | 39 | for i, j in itertools.product( 40 | range(prefrences.shape[0]), range(prefrences.shape[1]) 41 | ): 42 | w[(j, i)] = prefrences[i][j] 43 | return dic, w 44 | 45 | 46 | def solve_adwords(u_size, v_size, adjacency_matrix, budgets): 47 | try: 48 | m = gp.Model("adwords") 49 | # m.Params.LogToConsole = 0 50 | m.Params.timeLimit = 30 51 | 52 | _, dic = get_data_adwords(u_size, v_size, adjacency_matrix) 53 | 54 | # add variable 55 | x = m.addVars(v_size, u_size, vtype="B", name="(u,v) pairs") 56 | 57 | # set constraints 58 | m.addConstrs((x.sum(v, "*") <= 1 for v in range(v_size)), "V") 59 | m.addConstrs((x.prod(dic, "*", u) <= budgets[u] for u in range(u_size)), "U") 60 | 61 | # set the objective 62 | m.setObjective(x.prod(dic), GRB.MAXIMIZE) 63 | m.optimize() 64 | 65 | solution = np.zeros(v_size).tolist() 66 | for v in range(v_size): 67 | u = 0 68 | for nghbr_v in x.select(v, "*"): 69 | if nghbr_v.getAttr("x") == 1: 70 | 
solution[v] = u + 1 71 | break 72 | u += 1 73 | return m.objVal, solution 74 | 75 | except gp.GurobiError as e: 76 | print("Error code " + str(e.errno) + ": " + str(e)) 77 | except AttributeError: 78 | print("Encountered an attribute error") 79 | 80 | 81 | def solve_submodular_matching( 82 | u_size, v_size, adjacency_matrix, r_v, movie_features, preferences, num_incoming 83 | ): 84 | try: 85 | m = gp.Model("submatching") 86 | m.Params.LogToConsole = 0 87 | # 15 is the fixed number of genres from the movielens dataset 88 | genres = 15 89 | dic, weight_dic = get_data_osbm(u_size, v_size, adjacency_matrix, preferences) 90 | 91 | # add variable for each edge (u,v), where v is the user and u is the movie 92 | x = m.addVars(v_size, u_size, vtype="B", name="(u,v) pairs") 93 | 94 | # set the variable to zero for edges that do not exist in the graph 95 | for key in x: 96 | x[key] = 0 if dic[key] == 0 else x[key] 97 | 98 | # create variable for each (genre, user) pair 99 | gamma = m.addVars(genres, v_size, vtype="B", name="gamma") 100 | 101 | # A is |genres| by |V| matrix containg the total number of edges going from u to genere g at index (g,u) 102 | A = m.addVars(genres, v_size) 103 | 104 | # first set all variables in A to zero 105 | for key in A: 106 | A[key] = 0 107 | 108 | for z, v in itertools.product(range(genres), range(v_size)): 109 | for u in range(u_size): 110 | if movie_features[u][z] == 1.0: # if u belongs to genre z 111 | A[(z, v)] += x[(v, u)] 112 | 113 | # set constraints 114 | r_v1 = {} 115 | for i, n in enumerate(r_v): 116 | r_v1[i] = r_v[n] 117 | m.addConstrs((x.sum(v, "*") <= len(r_v1[v]) for v in r_v1), "const1") 118 | m.addConstrs((x.sum("*", u) <= 1 for u in range(u_size)), "const2") 119 | m.addConstrs( 120 | ( 121 | gamma[(z, v)] <= A[(z, v)] 122 | for z, v in itertools.product(range(genres), range(v_size)) 123 | ), 124 | "const3", 125 | ) 126 | m.addConstrs( 127 | ( 128 | gamma[(z, v)] <= 1 129 | for z, v in itertools.product(range(genres), range(v_size)) 130 | ), 131 | "const4", 132 | ) 133 | 134 | # give each gamma variable a weight based on the user preferences and optimiza the sum 135 | m.setObjective(gamma.prod(weight_dic), GRB.MAXIMIZE) 136 | m.optimize() 137 | solution = np.zeros(num_incoming).tolist() 138 | sol_dict = dict(x) 139 | s = sorted(sol_dict.keys(), key=(lambda k: k[0])) 140 | matched = 0 141 | for i, p in enumerate(r_v1): 142 | matched = 0 143 | inds = r_v1[p] 144 | idx = i * u_size 145 | nodes = s[idx : idx + u_size] 146 | for j, u in enumerate(nodes): 147 | if (type(x[u]) is not int) and x[u].x == 1: 148 | solution[inds[matched]] = u[1] + 1 149 | matched += 1 150 | 151 | for j in range(len(inds) - matched): 152 | solution[inds[j + matched]] = 0 153 | 154 | return m.objVal, solution 155 | 156 | except gp.GurobiError as e: 157 | print("Error code " + str(e.errno) + ": " + str(e)) 158 | except AttributeError: 159 | print("Encountered an attribute error") 160 | 161 | 162 | if __name__ == "__main__": 163 | 164 | # osbm exmaple: 165 | 166 | # {v_id : freq} 167 | # r_v = {0: [1, 2], 1: [0]} 168 | 169 | # # 3 genres (each column), 3 movies (each row) 170 | # movie_features = [ 171 | # [0.0, 0.0, 1.0], 172 | # [0.0, 0.0, 1.0], 173 | # [0.0, 1.0, 1.0] 174 | # ] 175 | 176 | # # user preferences V by |genres| 177 | # preferences = np.array([ 178 | # [0.999, 0.4, 0.222], 179 | # [1, 1, 1] 180 | # ]) 181 | 182 | # adjacency_matrix = np.array([ 183 | # [1, 2, 0], 184 | # [0, 1, 0], 185 | # [4, 0, 0] 186 | # ]) 187 | 188 | # print(solve_submodular_matching(3, 2, 
adjacency_matrix, r_v, movie_features, preferences, 3)) 189 | 190 | # adwords exmaple: 191 | 192 | # V by U matrix 193 | adjacency_matrix = np.array([[1, 2, 0], [0, 1, 0], [4, 0, 0]]) 194 | 195 | budgets = [3, 1, 4] 196 | print(solve_adwords(3, 3, adjacency_matrix, budgets)) 197 | -------------------------------------------------------------------------------- /policy/ff_supervised.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import DataParallel 5 | 6 | 7 | def set_decode_type(model, decode_type): 8 | if isinstance(model, DataParallel): 9 | model = model.module 10 | model.set_decode_type(decode_type) 11 | 12 | 13 | def get_loss(log_p, y, optimizers, w, opts): 14 | # The cross entrophy loss of v_1, ..., v_{t-1} (this is a batch_size by 1 vector) 15 | # total_loss = torch.zeros(y.shape) 16 | 17 | # Calculate loss of v_t 18 | # print(log_p, y) 19 | loss_t = F.cross_entropy(log_p, y.long(), weight=w) 20 | 21 | # Update the loss for the whole graph 22 | # total_loss += loss_t 23 | loss = loss_t 24 | 25 | # Perform backward pass and optimization step 26 | # optimizers[0].zero_grad() 27 | # loss.backward() 28 | # optimizers[0].step() 29 | return loss 30 | 31 | 32 | class SupervisedFFModel(nn.Module): 33 | def __init__( 34 | self, 35 | embedding_dim, 36 | hidden_dim, 37 | problem, 38 | opts, 39 | tanh_clipping=None, 40 | mask_inner=None, 41 | mask_logits=None, 42 | n_encode_layers=None, 43 | normalization="batch", 44 | checkpoint_encoder=False, 45 | shrink_size=None, 46 | num_actions=4, 47 | n_heads=None, 48 | encoder=None, 49 | ): 50 | super(SupervisedFFModel, self).__init__() 51 | 52 | self.embedding_dim = embedding_dim 53 | self.decode_type = None 54 | self.num_actions = self.num_actions = ( 55 | 5 * (opts.u_size + 1) + 8 56 | if opts.problem != "adwords" 57 | else 7 * (opts.u_size + 1) + 8 58 | ) 59 | self.is_bipartite = problem.NAME == "bipartite" 60 | self.problem = problem 61 | self.shrink_size = None 62 | self.model_name = "ff-supervised" 63 | self.ff = nn.Sequential( 64 | nn.Linear(self.num_actions, 100), 65 | nn.ReLU(), 66 | nn.Linear(100, 100), 67 | nn.ReLU(), 68 | nn.Linear(100, 100), 69 | nn.ReLU(), 70 | nn.Linear(100, opts.u_size + 1), 71 | ) 72 | 73 | def forward( 74 | self, 75 | input, 76 | opt_match, 77 | opts, 78 | optimizer, 79 | training=False, 80 | baseline=None, 81 | return_pi=False, 82 | ): 83 | """ 84 | :param input: (batch_size, graph_size, node_dim) input node features or dictionary with multiple tensors 85 | :param return_pi: whether to return the output sequences, this is optional as it is not compatible with 86 | :param opt_match: (batch_size, U_size, V_size), the optimal matching of the graphs in the batch 87 | using DataParallel as the results may be of different lengths on different GPUs 88 | :return: 89 | """ 90 | _log_p, pi, cost, batch_loss = self._inner( 91 | input, opt_match, opts, optimizer, training 92 | ) 93 | 94 | ll = self._calc_log_likelihood(_log_p, pi, None) 95 | 96 | return -cost, ll, pi, batch_loss 97 | 98 | def _calc_log_likelihood(self, _log_p, a, mask): 99 | 100 | # Get log_p corresponding to selected actions 101 | # print(a[0, :]) 102 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean() 103 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1) 104 | 105 | # Optional: mask out actions irrelevant to objective so they do not get reinforced 106 | # if mask is not None: 107 | # log_p[mask] = 0 108 | # if not 
(log_p > -10000).data.all(): 109 | # print(log_p) 110 | # assert ( 111 | # log_p > -10000 112 | # ).data.all(), "Logprobs should not be -inf, check sampling procedure!" 113 | 114 | # Calculate log_likelihood 115 | # print(_log_p) 116 | return log_p.sum(1), entropy 117 | 118 | def _inner(self, input, opt_match, opts, optimizer, training): 119 | 120 | outputs = [] 121 | sequences = [] 122 | # losses = [] 123 | 124 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts) 125 | i = 1 126 | total_loss = 0 127 | while not (state.all_finished()): 128 | mask = state.get_mask() 129 | w = state.get_current_weights(mask) 130 | s, mask = state.get_curr_state(self.model_name) 131 | # s = w 132 | pi = self.ff(s) 133 | # Select the indices of the next nodes in the sequences, result (batch_size) long 134 | # if training: 135 | # mask = torch.zeros(mask.shape, device=opts.device) 136 | selected, p = self._select_node( 137 | pi, 138 | mask.bool(), 139 | ) # Squeeze out steps dimension 140 | # entropy += torch.sum(p * (p.log()), dim=1) 141 | state = state.update((selected)[:, None]) 142 | outputs.append(p) 143 | sequences.append(selected) 144 | 145 | # do backprop if in training mode 146 | if opts.problem != "adwords": 147 | none_node_w = torch.tensor( 148 | [1.0 / (opts.v_size / opts.u_size)], device=opts.device 149 | ).float() 150 | else: 151 | none_node_w = torch.tensor([1.0], device=opts.device).float() 152 | w = torch.cat( 153 | [none_node_w, torch.ones(opts.u_size, device=opts.device).float()], 154 | dim=0, 155 | ) 156 | # supervised learning 157 | y = opt_match[:, i - 1] 158 | # print('y: ', y) 159 | # print('selected: ', selected) 160 | loss = get_loss(pi, y, optimizer, w, opts) 161 | # print("Loss: ", loss) 162 | # keep track for logging 163 | total_loss += loss 164 | i += 1 165 | # Collected lists, return Tensor 166 | batch_loss = total_loss / state.v_size 167 | # print(batch_loss) 168 | if optimizer is not None and training: 169 | # print('epoch {} batch loss {}'.format(i, batch_loss)) 170 | # print('outputs: ', outputs) 171 | # print('sequences: ', sequences) 172 | # print('optimal solution: ', opt_match) 173 | optimizer[0].zero_grad() 174 | batch_loss.backward() 175 | optimizer[0].step() 176 | return ( 177 | torch.stack(outputs, 1), 178 | torch.stack(sequences, 1), 179 | state.size, 180 | batch_loss, 181 | ) 182 | 183 | def _select_node(self, probs, mask): 184 | assert (probs == probs).all(), "Probs should not contain any nans" 185 | mask[:, 0] = False 186 | p = probs.clone() 187 | p[ 188 | mask 189 | ] = ( 190 | -1e6 191 | ) # Masking doesn't really make sense with supervised since input samples are independent, should only masking during testing. 
192 | _, selected = p.max(1) 193 | return selected, p 194 | 195 | def set_decode_type(self, decode_type, temp=None): 196 | self.decode_type = decode_type 197 | if temp is not None: # Do not change temperature if not provided 198 | self.temp = temp 199 | -------------------------------------------------------------------------------- /policy/gnn_simp_hist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.checkpoint import checkpoint 4 | import math 5 | from typing import NamedTuple 6 | 7 | import torch.nn.functional as F 8 | 9 | # from utils.tensor_functions import compute_in_batches 10 | 11 | from encoder.graph_encoder_v2 import GraphAttentionEncoder 12 | from train import clip_grad_norms 13 | 14 | from encoder.graph_encoder import MPNN 15 | from torch.nn import DataParallel 16 | from torch_geometric.utils import subgraph 17 | 18 | # from utils.functions import sample_many 19 | 20 | import time 21 | 22 | 23 | def set_decode_type(model, decode_type): 24 | if isinstance(model, DataParallel): 25 | model = model.module 26 | model.set_decode_type(decode_type) 27 | 28 | 29 | class GNNSimpHist(nn.Module): 30 | def __init__( 31 | self, 32 | embedding_dim, 33 | hidden_dim, 34 | problem, 35 | opts, 36 | n_encode_layers=1, 37 | tanh_clipping=10.0, 38 | mask_inner=True, 39 | mask_logits=True, 40 | normalization="batch", 41 | n_heads=8, 42 | checkpoint_encoder=False, 43 | shrink_size=None, 44 | num_actions=None, 45 | encoder="mpnn", 46 | ): 47 | super(GNNSimpHist, self).__init__() 48 | 49 | self.embedding_dim = embedding_dim 50 | self.hidden_dim = hidden_dim 51 | self.n_encode_layers = n_encode_layers 52 | self.decode_type = None 53 | self.temp = 1.0 54 | self.problem = problem 55 | self.opts = opts 56 | # Problem specific context parameters (placeholder and step context dimension) 57 | 58 | encoder_class = {"attention": GraphAttentionEncoder, "mpnn": MPNN}.get( 59 | encoder, None 60 | ) 61 | if opts.problem == "osbm": 62 | node_dim_u = 16 63 | node_dim_v = 18 64 | else: 65 | node_dim_u, node_dim_v = 1, 1 66 | 67 | self.embedder = encoder_class( 68 | n_heads=n_heads, 69 | embed_dim=embedding_dim, 70 | n_layers=self.n_encode_layers, 71 | normalization=normalization, 72 | problem=self.problem, 73 | opts=self.opts, 74 | node_dim_v=node_dim_v, 75 | node_dim_u=node_dim_u, 76 | ) 77 | 78 | self.ff = nn.Sequential( 79 | nn.Linear(16 + opts.embedding_dim, 200), 80 | nn.ReLU(), 81 | nn.Linear(200, 200), 82 | nn.ReLU(), 83 | nn.Linear(200, 1), 84 | ) 85 | 86 | assert embedding_dim % n_heads == 0 87 | self.step_context_transf = nn.Linear(2 * opts.embedding_dim, opts.embedding_dim) 88 | self.initial_stepcontext = nn.Parameter(torch.Tensor(1, 1, embedding_dim)) 89 | self.initial_stepcontext.data.uniform_(-1, 1) 90 | self.dummy = torch.ones(1, dtype=torch.float32, requires_grad=True) 91 | self.model_name = "gnn-simp-hist" 92 | 93 | def init_parameters(self): 94 | for name, param in self.named_parameters(): 95 | stdv = 1.0 / math.sqrt(param.size(-1)) 96 | param.data.uniform_(-stdv, stdv) 97 | 98 | def set_decode_type(self, decode_type, temp=None): 99 | self.decode_type = decode_type 100 | if temp is not None: # Do not change temperature if not provided 101 | self.temp = temp 102 | 103 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 104 | 105 | _log_p, pi, cost = self._inner(x, opts) 106 | 107 | # cost, mask = self.problem.get_costs(input, pi) 108 | # Log likelyhood is calculated within the model since 
returning it per action does not work well with 109 | # DataParallel since sequences can be of different lengths 110 | ll, e = self._calc_log_likelihood(_log_p, pi, None) 111 | if return_pi: 112 | return -cost, ll, pi, e 113 | # print(ll) 114 | return -cost, ll, e 115 | 116 | def _calc_log_likelihood(self, _log_p, a, mask): 117 | 118 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean() 119 | # Get log_p corresponding to selected actions 120 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1) 121 | 122 | # Optional: mask out actions irrelevant to objective so they do not get reinforced 123 | if mask is not None: 124 | log_p[mask] = 0 125 | if not (log_p > -10000).data.all(): 126 | print(log_p.nonzero()) 127 | assert ( 128 | log_p > -10000 129 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!" 130 | 131 | # Calculate log_likelihood 132 | # print(log_p.sum(1)) 133 | 134 | return log_p.sum(1), entropy 135 | 136 | def _inner(self, input, opts): 137 | 138 | outputs = [] 139 | sequences = [] 140 | 141 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts) 142 | 143 | batch_size = state.batch_size 144 | graph_size = state.u_size + state.v_size + 1 145 | i = 1 146 | while not (state.all_finished()): 147 | step_size = state.i + 1 148 | mask = state.get_mask() 149 | w = state.get_current_weights(mask) 150 | s, mask = state.get_curr_state(self.model_name) 151 | # Pass the graph to the Encoder 152 | node_features = state.get_node_features() 153 | nodes = torch.cat( 154 | ( 155 | torch.arange(0, opts.u_size + 1, device=opts.device), 156 | state.idx[:i] + opts.u_size + 1, 157 | ) 158 | ) 159 | subgraphs = ( 160 | (nodes.unsqueeze(0).expand(batch_size, step_size)) 161 | + torch.arange( 162 | 0, batch_size * graph_size, graph_size, device=opts.device 163 | ).unsqueeze(1) 164 | ).flatten() # The nodes of the current subgraphs 165 | graph_weights = state.get_graph_weights() 166 | edge_i, weights = subgraph( 167 | subgraphs, 168 | state.graphs.edge_index, 169 | graph_weights.unsqueeze(1), 170 | relabel_nodes=True, 171 | ) 172 | embeddings = checkpoint( 173 | self.embedder, 174 | node_features, 175 | edge_i, 176 | weights.float(), 177 | torch.tensor(i), 178 | self.dummy, 179 | ).reshape(batch_size, step_size, -1) 180 | pos = torch.argsort(state.idx[:i])[-1] 181 | incoming_node_embeddings = embeddings[ 182 | :, pos + state.u_size + 1, : 183 | ].unsqueeze(1) 184 | # print(incoming_node_embeddings) 185 | w = (state.adj[:, state.get_current_node(), :]).float() 186 | # mean_w = w.mean(1)[:, None, None].repeat(1, state.u_size + 1, 1) 187 | w = w.reshape(state.batch_size, state.u_size + 1, 1) 188 | s = torch.cat( 189 | (s, incoming_node_embeddings.repeat(1, state.u_size + 1, 1),), dim=2, 190 | ) 191 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1) 192 | # Select the indices of the next nodes in the sequences, result (batch_size) long 193 | selected, p = self._select_node( 194 | pi, mask.bool() 195 | ) # Squeeze out steps dimension 196 | state = state.update((selected)[:, None]) 197 | outputs.append(p) 198 | sequences.append(selected) 199 | i += 1 200 | # Collected lists, return Tensor 201 | return ( 202 | torch.stack(outputs, 1), 203 | torch.stack(sequences, 1), 204 | state.size, 205 | ) 206 | 207 | def _select_node(self, probs, mask): 208 | assert (probs == probs).all(), "Probs should not contain any nans" 209 | probs[mask] = -1e6 210 | p = torch.log_softmax(probs, dim=1) 211 | # print(p) 212 | if self.decode_type == "greedy": 213 | _, selected = p.max(1) 
214 | # assert not mask.gather( 215 | # 1, selected.unsqueeze(-1) 216 | # ).data.any(), "Decode greedy: infeasible action has maximum probability" 217 | 218 | elif self.decode_type == "sampling": 219 | selected = p.exp().multinomial(1).squeeze(1) 220 | # Check if sampling went OK, can go wrong due to bug on GPU 221 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232 222 | # while mask.gather(1, selected.unsqueeze(-1)).data.any(): 223 | # print("Sampled bad values, resampling!") 224 | # selected = probs.multinomial(1).squeeze(1) 225 | 226 | else: 227 | assert False, "Unknown decode type" 228 | return selected, p 229 | -------------------------------------------------------------------------------- /policy/gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.checkpoint import checkpoint 4 | import math 5 | from typing import NamedTuple 6 | 7 | import torch.nn.functional as F 8 | 9 | # from utils.tensor_functions import compute_in_batches 10 | 11 | from encoder.graph_encoder_v2 import GraphAttentionEncoder 12 | from train import clip_grad_norms 13 | 14 | from encoder.graph_encoder import MPNN 15 | from torch.nn import DataParallel 16 | from torch_geometric.utils import subgraph 17 | 18 | # from utils.functions import sample_many 19 | 20 | import time 21 | 22 | 23 | def set_decode_type(model, decode_type): 24 | if isinstance(model, DataParallel): 25 | model = model.module 26 | model.set_decode_type(decode_type) 27 | 28 | 29 | class GNN(nn.Module): 30 | def __init__( 31 | self, 32 | embedding_dim, 33 | hidden_dim, 34 | problem, 35 | opts, 36 | n_encode_layers=1, 37 | tanh_clipping=10.0, 38 | mask_inner=True, 39 | mask_logits=True, 40 | normalization="batch", 41 | n_heads=8, 42 | checkpoint_encoder=False, 43 | shrink_size=None, 44 | num_actions=None, 45 | encoder="mpnn", 46 | ): 47 | super(GNN, self).__init__() 48 | 49 | self.embedding_dim = embedding_dim 50 | self.hidden_dim = hidden_dim 51 | self.n_encode_layers = n_encode_layers 52 | self.decode_type = None 53 | self.temp = 1.0 54 | self.problem = problem 55 | self.opts = opts 56 | # Problem specific context parameters (placeholder and step context dimension) 57 | 58 | encoder_class = {"attention": GraphAttentionEncoder, "mpnn": MPNN}.get( 59 | encoder, None 60 | ) 61 | if opts.problem == "osbm": 62 | node_dim_u = 16 63 | node_dim_v = 18 64 | else: 65 | node_dim_u, node_dim_v = 1, 1 66 | 67 | self.embedder = encoder_class( 68 | n_heads=n_heads, 69 | embed_dim=embedding_dim, 70 | n_layers=self.n_encode_layers, 71 | normalization=normalization, 72 | problem=self.problem, 73 | opts=self.opts, 74 | node_dim_v=node_dim_v, 75 | node_dim_u=node_dim_u, 76 | ) 77 | 78 | self.ff = nn.Sequential( 79 | nn.Linear(2 + opts.embedding_dim, 200), nn.ReLU(), nn.Linear(200, 1), 80 | ) 81 | 82 | assert embedding_dim % n_heads == 0 83 | self.step_context_transf = nn.Linear(2 * opts.embedding_dim, opts.embedding_dim) 84 | self.initial_stepcontext = nn.Parameter(torch.Tensor(1, 1, embedding_dim)) 85 | self.initial_stepcontext.data.uniform_(-1, 1) 86 | self.dummy = torch.ones(1, dtype=torch.float32, requires_grad=True) 87 | self.model_name = "gnn" 88 | 89 | def init_parameters(self): 90 | for name, param in self.named_parameters(): 91 | stdv = 1.0 / math.sqrt(param.size(-1)) 92 | param.data.uniform_(-stdv, stdv) 93 | 94 | def set_decode_type(self, decode_type, temp=None): 95 | self.decode_type = decode_type 96 | if temp is not None: # Do 
not change temperature if not provided 97 | self.temp = temp 98 | 99 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 100 | 101 | _log_p, pi, cost = self._inner(x, opts) 102 | 103 | # cost, mask = self.problem.get_costs(input, pi) 104 | # Log likelyhood is calculated within the model since returning it per action does not work well with 105 | # DataParallel since sequences can be of different lengths 106 | ll, e = self._calc_log_likelihood(_log_p, pi, None) 107 | if return_pi: 108 | return -cost, ll, pi, e 109 | # print(ll) 110 | return -cost, ll, e 111 | 112 | def _calc_log_likelihood(self, _log_p, a, mask): 113 | 114 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean() 115 | # Get log_p corresponding to selected actions 116 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1) 117 | 118 | # Optional: mask out actions irrelevant to objective so they do not get reinforced 119 | if mask is not None: 120 | log_p[mask] = 0 121 | if not (log_p > -10000).data.all(): 122 | print(log_p.nonzero()) 123 | assert ( 124 | log_p > -10000 125 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!" 126 | 127 | # Calculate log_likelihood 128 | # print(log_p.sum(1)) 129 | 130 | return log_p.sum(1), entropy 131 | 132 | def _inner(self, input, opts): 133 | 134 | outputs = [] 135 | sequences = [] 136 | 137 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts) 138 | 139 | batch_size = state.batch_size 140 | graph_size = state.u_size + state.v_size + 1 141 | i = 1 142 | while not (state.all_finished()): 143 | step_size = state.i + 1 144 | mask = state.get_mask() 145 | w = state.get_current_weights(mask) 146 | # Pass the graph to the Encoder 147 | node_features = state.get_node_features() 148 | nodes = torch.cat( 149 | ( 150 | torch.arange(0, opts.u_size + 1, device=opts.device), 151 | state.idx[:i] + opts.u_size + 1, 152 | ) 153 | ) 154 | subgraphs = ( 155 | (nodes.unsqueeze(0).expand(batch_size, step_size)) 156 | + torch.arange( 157 | 0, batch_size * graph_size, graph_size, device=opts.device 158 | ).unsqueeze(1) 159 | ).flatten() # The nodes of the current subgraphs 160 | graph_weights = state.get_graph_weights() 161 | edge_i, weights = subgraph( 162 | subgraphs, 163 | state.graphs.edge_index, 164 | graph_weights.unsqueeze(1), 165 | relabel_nodes=True, 166 | ) 167 | embeddings = checkpoint( 168 | self.embedder, 169 | node_features, 170 | edge_i, 171 | weights.float(), 172 | torch.tensor(i), 173 | self.dummy, 174 | ).reshape(batch_size, step_size, -1) 175 | pos = torch.argsort(state.idx[:i])[-1] 176 | incoming_node_embeddings = embeddings[ 177 | :, pos + state.u_size + 1, : 178 | ].unsqueeze(1) 179 | # print(incoming_node_embeddings) 180 | w = (state.adj[:, state.get_current_node(), :]).float() 181 | # mean_w = w.mean(1)[:, None, None].repeat(1, state.u_size + 1, 1) 182 | s = w.reshape(state.batch_size, state.u_size + 1, 1) 183 | idx = ( 184 | torch.ones(state.batch_size, 1, 1, device=opts.device) 185 | * i 186 | / state.v_size 187 | ) 188 | s = torch.cat( 189 | ( 190 | s, 191 | idx.repeat(1, state.u_size + 1, 1), 192 | incoming_node_embeddings.repeat(1, state.u_size + 1, 1), 193 | ), 194 | dim=2, 195 | ) 196 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1) 197 | # Select the indices of the next nodes in the sequences, result (batch_size) long 198 | selected, p = self._select_node( 199 | pi, mask.bool() 200 | ) # Squeeze out steps dimension 201 | # entropy += torch.sum(p * (p.log()), dim=1) 202 | state = state.update((selected)[:, 
None]) 203 | outputs.append(p) 204 | sequences.append(selected) 205 | i += 1 206 | # Collected lists, return Tensor 207 | return ( 208 | torch.stack(outputs, 1), 209 | torch.stack(sequences, 1), 210 | state.size, 211 | ) 212 | 213 | def _select_node(self, probs, mask): 214 | assert (probs == probs).all(), "Probs should not contain any nans" 215 | probs[mask] = -1e6 216 | p = torch.log_softmax(probs, dim=1) 217 | # print(p) 218 | if self.decode_type == "greedy": 219 | _, selected = p.max(1) 220 | # assert not mask.gather( 221 | # 1, selected.unsqueeze(-1) 222 | # ).data.any(), "Decode greedy: infeasible action has maximum probability" 223 | 224 | elif self.decode_type == "sampling": 225 | selected = p.exp().multinomial(1).squeeze(1) 226 | # Check if sampling went OK, can go wrong due to bug on GPU 227 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232 228 | # while mask.gather(1, selected.unsqueeze(-1)).data.any(): 229 | # print("Sampled bad values, resampling!") 230 | # selected = probs.multinomial(1).squeeze(1) 231 | 232 | else: 233 | assert False, "Unknown decode type" 234 | return selected, p 235 | -------------------------------------------------------------------------------- /policy/inv_ff_history_switch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from policy.inv_ff_history import InvariantFFHist 4 | from utils.functions import random_max 5 | import math 6 | 7 | # from utils.visualize import visualize_reward_process 8 | from copy import deepcopy 9 | 10 | class InvariantFFHistSwitch(InvariantFFHist): 11 | def __init__( 12 | self, 13 | embedding_dim, 14 | hidden_dim, 15 | problem, 16 | opts, 17 | tanh_clipping=None, 18 | mask_inner=None, 19 | mask_logits=None, 20 | n_encode_layers=None, 21 | normalization="batch", 22 | checkpoint_encoder=False, 23 | shrink_size=None, 24 | num_actions=4, 25 | n_heads=None, 26 | encoder=None, 27 | ): 28 | super(InvariantFFHistSwitch, self).__init__(embedding_dim, 29 | hidden_dim, 30 | problem, 31 | opts, 32 | tanh_clipping=tanh_clipping, 33 | mask_inner=mask_inner, 34 | mask_logits=mask_logits, 35 | n_encode_layers=n_encode_layers, 36 | normalization=normalization, 37 | checkpoint_encoder=checkpoint_encoder, 38 | shrink_size=shrink_size, 39 | num_actions=num_actions, 40 | n_heads=n_heads, 41 | encoder=encoder) 42 | 43 | self.switch_lambda = opts.switch_lambda 44 | self.slackness = opts.slackness 45 | self.max_reward = opts.max_reward 46 | self.softmax_temp = opts.softmax_temp 47 | 48 | 49 | def _inner_pure_ml(self, input, opts, show_rewards = False): 50 | 51 | outputs = [] 52 | sequences = [] 53 | state = self.problem.make_state(deepcopy(input), opts.u_size, opts.v_size, opts) 54 | 55 | # Perform decoding steps 56 | rewards = [] 57 | i = 1 58 | while not (state.all_finished()): 59 | mask = state.get_mask() 60 | state.get_current_weights(mask) 61 | s, mask = state.get_curr_state(self.model_name) 62 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1) 63 | # Select the indices of the next nodes in the sequences, result (batch_size) long 64 | selected, p = self._select_node( 65 | pi, mask.bool() 66 | ) # Squeeze out steps dimension 67 | # entropy += torch.sum(p * (p.log()), dim=1) 68 | state = state.update((selected)[:, None]) 69 | outputs.append(p) 70 | sequences.append(selected) 71 | i += 1 72 | 73 | rewards.append(state.size.view(-1)) 74 | 75 | res_list = [torch.stack(outputs, 1), 76 | torch.stack(sequences, 1), 77 | state.size,] 
78 | if show_rewards: 79 | # visualize_reward_process(torch.stack(rewards, 1).cpu().numpy(), "temp_ml_movie.png",ylim = [-1, 14]) 80 | res_list.append(torch.stack(rewards, 1)) 81 | 82 | # Collected lists, return Tensor 83 | return res_list 84 | 85 | def _inner_switch(self, input, opts, show_rewards = False, osm_baseline = False): 86 | 87 | # torch.autograd.set_detect_anomaly(True) 88 | 89 | if opts.problem not in ["osbm", "e-obm"]: 90 | raise NotImplementedError 91 | 92 | batch_size = int(input.batch.size(0) / (opts.u_size + opts.v_size + 1) ) 93 | 94 | ## Greedy algorithm state 95 | state_greedy = self.problem.make_state(deepcopy(input), opts.u_size, opts.v_size, opts) 96 | 97 | ## ML algorithm state 98 | state_ml = self.problem.make_state(deepcopy(input), opts.u_size, opts.v_size, opts) 99 | sequences_ml, outputs_ml, rewards_ml = [], [], [] 100 | 101 | graph_v_size = opts.v_size 102 | i = 1 103 | 104 | while not (state_ml.all_finished()): 105 | # Extract current info 106 | mask_gd = state_greedy.get_mask() 107 | weight_gd = state_greedy.get_current_weights(mask_gd).clone() 108 | weight_gd[mask_gd.bool()] = -1.0 109 | 110 | if osm_baseline: 111 | # Select action based on current info 112 | if(i > graph_v_size/math.e): 113 | selected_gd = random_max(weight_gd) 114 | else: 115 | selected_gd = torch.zeros([weight_gd.shape[0], 1], device=weight_gd.device, dtype=torch.int64) 116 | else: 117 | selected_gd = random_max(weight_gd) 118 | 119 | state_greedy = state_greedy.update(selected_gd) 120 | 121 | # Generate probability from Greedy algorithm 122 | probs_gd = -1e8 * torch.ones(mask_gd.shape, device = mask_gd.device) 123 | probs_gd[(torch.arange(batch_size) , selected_gd.view(-1) )] = 1 124 | p_gd = torch.log_softmax(probs_gd, dim=1) 125 | 126 | reward_gd = state_greedy.size 127 | 128 | # Extract current info 129 | mask_ml = state_ml.get_mask() 130 | state_ml.get_current_weights(mask_ml) 131 | info_s, mask_ml = state_ml.get_curr_state(self.model_name) 132 | pi = self.ff(info_s).reshape(state_ml.batch_size, state_ml.u_size + 1) 133 | 134 | # Select the indices of the next nodes in the sequences, result (batch_size) long 135 | selected_ml, p_ml = self._select_node(pi, mask_ml.bool()) 136 | selected_ml = (selected_ml)[:, None] 137 | 138 | # Collect reward based on these hypothesis actions 139 | reward_hyp, matched_node_hyp = state_ml.try_hypothesis_action(selected_ml) 140 | 141 | # Compare the ML with Greedy actions 142 | matched_node_gd = state_greedy.matched_nodes 143 | action_diff = torch.relu(matched_node_hyp - matched_node_gd)[:,1:] 144 | action_diff = action_diff.sum(1, keepdim=True) 145 | 146 | # Calculate reward 147 | reward_gd_reserve = reward_gd + action_diff * self.max_reward 148 | reward_diff = reward_hyp - (self.switch_lambda * reward_gd_reserve - self.slackness) 149 | 150 | # Determine the probability of the final action 151 | probs_policy = torch.zeros([batch_size, 2], device=mask_gd.device) 152 | probs_policy[:,0] = reward_diff.view(-1)/self.softmax_temp # Normalize the selection cost 153 | p_policy_exp = torch.log_softmax(probs_policy, dim=1).exp() 154 | 155 | p_total_exp = p_policy_exp[:,0].view(-1,1) * (p_ml.exp()) + p_policy_exp[:,1].view(-1,1) * (p_gd.exp()) + 1e-8 156 | p_total = p_total_exp.log() 157 | 158 | # Determine the action based on the reward_diff (expert vs. 
ML) 159 | if (self.ff.training): 160 | # During training, use probability-based method 161 | projected_ml, p_final = self._select_node(p_total, mask_ml.bool()) 162 | projected_ml = (projected_ml)[:, None] 163 | else: 164 | # During evaluation, use hard-switch 165 | reward_sign = (reward_diff >= 0) 166 | available_sign = (matched_node_hyp.gather(1, selected_gd) < 1e-10) 167 | expert_sign = torch.logical_not(reward_sign) * available_sign 168 | projected_ml = selected_ml*reward_sign + selected_gd * expert_sign 169 | p_final = torch.log_softmax(p_total, dim=1) 170 | 171 | # Update the current state 172 | state_ml = state_ml.update(projected_ml) 173 | outputs_ml.append(p_final) 174 | sequences_ml.append(projected_ml.squeeze(1)) 175 | i += 1 176 | 177 | rewards_ml.append(state_ml.size.view(-1)) 178 | 179 | res_list = [torch.stack(outputs_ml, 1), 180 | torch.stack(sequences_ml, 1), 181 | state_ml.size,] 182 | 183 | if show_rewards: 184 | # visualize_reward_process(torch.stack(rewards, 1).cpu().numpy(), "temp_ml_movie.png",ylim = [-1, 14]) 185 | res_list.append(torch.stack(rewards_ml, 1)) 186 | 187 | # Collected lists, return Tensor 188 | return res_list 189 | 190 | def _inner_greedy(self, input, opts, show_rewards = False): 191 | state = self.problem.make_state(deepcopy(input), opts.u_size, opts.v_size, opts) 192 | sequences = [] 193 | outputs = [] 194 | batch_size = int(input.batch.size(0) / (opts.u_size + opts.v_size + 1) ) 195 | 196 | rewards = [] 197 | i = 1 198 | while not (state.all_finished()): 199 | mask = state.get_mask() 200 | w = state.get_current_weights(mask).clone() 201 | w[mask.bool()] = -1.0 202 | selected = random_max(w) 203 | state = state.update(selected) 204 | sequences.append(selected.squeeze(1)) 205 | 206 | # Generate probability from Greedy algorithm 207 | probs = -1e8 * torch.ones(mask.shape, device = mask.device) 208 | probs[(torch.arange(batch_size) , selected.view(-1) )] = 1 209 | p = torch.log_softmax(probs, dim=1) 210 | outputs.append(p) 211 | i += 1 212 | 213 | rewards.append(state.size.view(-1)) 214 | 215 | res_list = [torch.stack(outputs, 1), 216 | torch.stack(sequences, 1), 217 | state.size,] 218 | 219 | if show_rewards: 220 | # visualize_reward_process(torch.stack(rewards, 1).cpu().numpy(), "temp_ml_movie.png",ylim = [-1, 14]) 221 | res_list.append(torch.stack(rewards, 1)) 222 | 223 | # Collected lists, return Tensor 224 | return res_list 225 | 226 | def _inner(self, input, opts): 227 | 228 | # visualize_reward_process(torch.stack(rewards, 1).cpu().numpy(), "temp_ml_movie.png",ylim = [-1, 14]) 229 | 230 | # return self._inner_greedy(input, opts) 231 | # return self._inner_pure_ml(input, opts) 232 | return self._inner_switch(input, opts) 233 | 234 | -------------------------------------------------------------------------------- /run_lite.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import pprint as pp 5 | 6 | import torch 7 | import torch.optim as optim 8 | from itertools import product 9 | import wandb 10 | 11 | # from tensorboard_logger import Logger as TbLogger 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch_geometric.loader import DataLoader as geoDataloader 14 | 15 | # from nets.critic_network import CriticNetwork 16 | from options import get_options 17 | from train import train_epoch, validate, get_inner_model 18 | from utils.reinforce_baselines import ( 19 | NoBaseline, 20 | ExponentialBaseline, 21 | RolloutBaseline, 22 | WarmupBaseline, 23 
| GreedyBaseline, 24 | ) 25 | 26 | from policy.attention_model import AttentionModel as AttentionModelgeo 27 | from policy.ff_model import FeedForwardModel 28 | from policy.ff_model_invariant import InvariantFF 29 | from policy.ff_model_hist import FeedForwardModelHist 30 | from policy.inv_ff_history import InvariantFFHist 31 | from policy.inv_ff_history_switch import InvariantFFHistSwitch 32 | from policy.greedy import Greedy 33 | from policy.greedy_rt import GreedyRt 34 | from policy.greedy_sc import GreedySC 35 | from policy.greedy_theshold import GreedyThresh 36 | from policy.greedy_matching import GreedyMatching 37 | from policy.simple_greedy import SimpleGreedy 38 | from policy.supervised import SupervisedModel 39 | from policy.ff_supervised import SupervisedFFModel 40 | from policy.gnn_hist import GNNHist 41 | from policy.gnn_simp_hist import GNNSimpHist 42 | from policy.gnn import GNN 43 | 44 | # from nets.pointer_network import PointerNetwork, CriticNetworkLSTM 45 | from utils.functions import torch_load_cpu, load_problem 46 | from utils.visualize import validate_histogram 47 | 48 | 49 | def run(opts): 50 | 51 | # Pretty print the run args 52 | # pp.pprint(vars(opts)) 53 | 54 | # Set the random seed 55 | torch.manual_seed(opts.seed) 56 | # torch.backends.cudnn.benchmark = True 57 | # torch.autograd.set_detect_anomaly(True) 58 | # Optionally configure tensorboard 59 | tb_logger = None 60 | if not opts.no_tensorboard: 61 | tb_logger = SummaryWriter( 62 | os.path.join( 63 | opts.log_dir, 64 | opts.model, 65 | opts.run_name, 66 | ) 67 | ) 68 | if not opts.eval_only and not os.path.exists(opts.save_dir): 69 | os.makedirs(opts.save_dir) 70 | # Save arguments so exact configuration can always be found 71 | with open(os.path.join(opts.save_dir, "args.json"), "w") as f: 72 | json.dump(vars(opts), f, indent=True) 73 | 74 | # Set the device 75 | opts.device = torch.device("cuda:0" if opts.use_cuda else "cpu") 76 | 77 | # Figure out what's the problem 78 | problem = load_problem(opts.problem) 79 | 80 | # Load data from load_path 81 | load_data = {} 82 | assert ( 83 | opts.load_path is None or opts.resume is None 84 | ), "Only one of load path and resume can be given" 85 | load_path = opts.load_path if opts.load_path is not None else opts.resume 86 | if load_path is not None: 87 | print(" [*] Loading data from {}".format(load_path)) 88 | load_data = torch_load_cpu(load_path) 89 | 90 | assert (opts.tune_wandb == False 91 | and opts.tune == False 92 | and opts.tune_baseline == False), "Unsupported Mode, please use \"run.py\" instead of \"run_lite.py\" " 93 | 94 | # assert(opts.model == "inv-ff-hist"), "We only support inv-ff-hist model up to now..." 
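# Note (added annotation): run_lite.py is a trimmed-down entry point; the asserts above reject the
# tuning modes (tune, tune_wandb, tune_baseline), and setup_training_env below also rejects --resume,
# so those workflows are only available through the full run.py script.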
95 | # Initialize model 96 | model_class = { 97 | "attention": AttentionModelgeo, 98 | "ff": FeedForwardModel, 99 | "greedy": Greedy, 100 | "greedy-rt": GreedyRt, 101 | "greedy-sc": GreedySC, 102 | "greedy-t": GreedyThresh, 103 | "greedy-m": GreedyMatching, 104 | "simple-greedy": SimpleGreedy, 105 | "inv-ff": InvariantFF, 106 | "inv-ff-hist": InvariantFFHist, 107 | "ff-hist": FeedForwardModelHist, 108 | "supervised": SupervisedModel, 109 | "ff-supervised": SupervisedFFModel, 110 | "gnn-hist": GNNHist, 111 | "gnn-simp-hist": GNNSimpHist, 112 | "gnn": GNN, 113 | "inv-ff-hist-switch":InvariantFFHistSwitch 114 | }.get(opts.model, None) 115 | assert model_class is not None, "Unknown model: {}".format(model_class) 116 | # if not opts.tune: 117 | model, lr_schedulers, optimizers, val_dataloader, baseline = setup_training_env( 118 | opts, model_class, problem, load_data, tb_logger 119 | ) 120 | 121 | training_dataset = problem.make_dataset( 122 | opts.train_dataset, opts.dataset_size, opts.problem, seed=None, opts=opts 123 | ) 124 | if opts.eval_only: 125 | 126 | print("Evaluate Greedy algorithm as Baseline algorithm") 127 | greedy_model = Greedy(opts.embedding_dim, opts.hidden_dim, problem=problem, opts=opts).to(opts.device) 128 | validate(greedy_model, val_dataloader, opts) 129 | 130 | print("Evaluate ML model from checkpoint") 131 | validate(model, val_dataloader, opts) 132 | switch_lambda_array = np.arange(0.0, 1.01, 0.1) 133 | 134 | 135 | 136 | 137 | else: 138 | best_avg_cr = 0.0 139 | for epoch in range(opts.epoch_start, opts.epoch_start + opts.n_epochs): 140 | training_dataloader = geoDataloader( 141 | baseline.wrap_dataset(training_dataset), 142 | batch_size=opts.batch_size, 143 | num_workers=0, 144 | shuffle=True, 145 | ) 146 | avg_reward, min_cr, avg_cr, loss = train_epoch( 147 | model, 148 | optimizers, 149 | baseline, 150 | lr_schedulers, 151 | epoch, 152 | val_dataloader, 153 | training_dataloader, 154 | problem, 155 | tb_logger, 156 | opts, 157 | best_avg_cr, 158 | ) 159 | # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) 160 | best_avg_cr = max(best_avg_cr, avg_cr) 161 | 162 | 163 | def setup_training_env(opts, model_class, problem, load_data, tb_logger): 164 | model = model_class( 165 | opts.embedding_dim, 166 | opts.hidden_dim, 167 | problem=problem, 168 | n_encode_layers=opts.n_encode_layers, 169 | mask_inner=True, 170 | mask_logits=True, 171 | normalization=opts.normalization, 172 | tanh_clipping=opts.tanh_clipping, 173 | checkpoint_encoder=opts.checkpoint_encoder, 174 | shrink_size=opts.shrink_size, 175 | num_actions=opts.u_size + 1, 176 | n_heads=opts.n_heads, 177 | encoder=opts.encoder, 178 | opts=opts, 179 | ).to(opts.device) 180 | 181 | if opts.use_cuda and torch.cuda.device_count() > 1: 182 | model = torch.nn.DataParallel(model) 183 | 184 | # Overwrite model parameters by parameters to load 185 | model_ = get_inner_model(model) 186 | model_.load_state_dict({**model_.state_dict(), **load_data.get("model", {})}) 187 | 188 | # Initialize baseline 189 | if opts.baseline == "exponential": 190 | baseline = ExponentialBaseline(opts.exp_beta) 191 | elif opts.baseline == "greedy": 192 | baseline_class = {"e-obm": Greedy, "obm": SimpleGreedy}.get(opts.problem, None) 193 | 194 | greedybaseline = baseline_class( 195 | opts.embedding_dim, 196 | opts.hidden_dim, 197 | problem=problem, 198 | n_encode_layers=opts.n_encode_layers, 199 | mask_inner=True, 200 | mask_logits=True, 201 | normalization=opts.normalization, 202 | tanh_clipping=opts.tanh_clipping, 203 | 
checkpoint_encoder=opts.checkpoint_encoder, 204 | shrink_size=opts.shrink_size, 205 | num_actions=opts.u_size + 1, 206 | # n_heads=opts.n_heads, 207 | ) 208 | baseline = GreedyBaseline(greedybaseline, opts) 209 | elif opts.baseline == "rollout": 210 | baseline = RolloutBaseline(model, problem, opts) 211 | else: 212 | assert opts.baseline is None, "Unknown baseline: {}".format(opts.baseline) 213 | baseline = NoBaseline() 214 | 215 | if opts.bl_warmup_epochs > 0: 216 | baseline = WarmupBaseline( 217 | baseline, opts.bl_warmup_epochs, warmup_exp_beta=opts.exp_beta 218 | ) 219 | 220 | # Load baseline from data, make sure script is called with same type of baseline 221 | if "baseline" in load_data: 222 | baseline.load_state_dict(load_data["baseline"]) 223 | 224 | # Initialize optimizer 225 | optimizer = optim.Adam( 226 | [{"params": model.parameters(), "lr": opts.lr_model}] 227 | + ( 228 | [{"params": baseline.get_learnable_parameters(), "lr": opts.lr_critic}] 229 | if len(baseline.get_learnable_parameters()) > 0 230 | else [] 231 | ) 232 | ) 233 | 234 | # Load optimizer state 235 | if "optimizer" in load_data: 236 | optimizer.load_state_dict(load_data["optimizer"]) 237 | for state in optimizer.state.values(): 238 | for k, v in state.items(): 239 | # if isinstance(v, torch.Tensor): 240 | if torch.is_tensor(v): 241 | state[k] = v.to(opts.device) 242 | 243 | # Initialize learning rate scheduler, decay by lr_decay once per epoch! 244 | lr_scheduler = optim.lr_scheduler.LambdaLR( 245 | optimizer, lambda epoch: opts.lr_decay ** epoch 246 | ) 247 | 248 | # Start the actual training loop 249 | val_dataset = problem.make_dataset( 250 | opts.val_dataset, opts.val_size, opts.problem, seed=None, opts=opts 251 | ) 252 | val_dataloader = geoDataloader( 253 | val_dataset, batch_size=opts.batch_size, num_workers=1 254 | ) 255 | assert (opts.resume is None), "Resume not supported, please use \"run.py\" instead of \"run_lite.py\"" 256 | 257 | return ( 258 | model, 259 | [lr_scheduler], 260 | [optimizer], 261 | val_dataloader, 262 | baseline, 263 | ) 264 | 265 | 266 | if __name__ == "__main__": 267 | run(get_options()) 268 | -------------------------------------------------------------------------------- /utils/reinforce_baselines.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.utils.data import Dataset 4 | from scipy.stats import ttest_rel 5 | import copy 6 | from train import rollout, get_inner_model 7 | from torch_geometric.data import DataLoader 8 | 9 | 10 | class Baseline(object): 11 | def wrap_dataset(self, dataset): 12 | return dataset 13 | 14 | def unwrap_batch(self, batch): 15 | return batch, None 16 | 17 | def eval(self, x, c): 18 | raise NotImplementedError("Override this method") 19 | 20 | def get_learnable_parameters(self): 21 | return [] 22 | 23 | def epoch_callback(self, model, epoch): 24 | pass 25 | 26 | def state_dict(self): 27 | return {} 28 | 29 | def load_state_dict(self, state_dict): 30 | pass 31 | 32 | 33 | class WarmupBaseline(Baseline): 34 | def __init__( 35 | self, baseline, n_epochs=1, warmup_exp_beta=0.8, 36 | ): 37 | super(Baseline, self).__init__() 38 | 39 | self.baseline = baseline 40 | assert n_epochs > 0, "n_epochs to warmup must be positive" 41 | self.warmup_baseline = ExponentialBaseline(warmup_exp_beta) 42 | self.alpha = 0 43 | self.n_epochs = n_epochs 44 | 45 | def wrap_dataset(self, dataset): 46 | if self.alpha > 0: 47 | return self.baseline.wrap_dataset(dataset) 48 | 
return self.warmup_baseline.wrap_dataset(dataset) 49 | 50 | def unwrap_batch(self, batch): 51 | if self.alpha > 0: 52 | return self.baseline.unwrap_batch(batch) 53 | # return batch 54 | return self.warmup_baseline.unwrap_batch(batch) 55 | 56 | def eval(self, x, c): 57 | 58 | if self.alpha == 1: 59 | return self.baseline.eval(x, c) 60 | if self.alpha == 0: 61 | return self.warmup_baseline.eval(x, c) 62 | v, t = self.baseline.eval(x, c) 63 | vw, lw = self.warmup_baseline.eval(x, c) 64 | # Return convex combination of baseline value and of loss 65 | return ( 66 | self.alpha * v + (1 - self.alpha) * vw, 67 | self.alpha * t + (1 - self.alpha) * lw, 68 | ) 69 | 70 | def epoch_callback(self, model, epoch): 71 | # Need to call epoch callback of inner model (also after first epoch if we have not used it) 72 | self.baseline.epoch_callback(model, epoch) 73 | self.alpha = (epoch + 1) / float(self.n_epochs) 74 | if epoch < self.n_epochs: 75 | print("Set warmup alpha = {}".format(self.alpha)) 76 | 77 | def state_dict(self): 78 | # Checkpointing within warmup stage makes no sense, only save inner baseline 79 | return self.baseline.state_dict() 80 | 81 | def load_state_dict(self, state_dict): 82 | # Checkpointing within warmup stage makes no sense, only load inner baseline 83 | self.baseline.load_state_dict(state_dict) 84 | 85 | 86 | class NoBaseline(Baseline): 87 | def eval(self, x, c): 88 | return 0, 0 # No baseline, no loss 89 | 90 | 91 | class ExponentialBaseline(Baseline): 92 | def __init__(self, beta): 93 | super(Baseline, self).__init__() 94 | 95 | self.beta = beta 96 | self.v = None 97 | 98 | def eval(self, x, c): 99 | 100 | if self.v is None: 101 | v = c.mean() 102 | else: 103 | v = self.beta * self.v + (1.0 - self.beta) * c.mean() 104 | 105 | self.v = v.detach() # Detach since we never want to backprop 106 | return self.v, 0 # No loss 107 | 108 | def state_dict(self): 109 | return {"v": self.v} 110 | 111 | def load_state_dict(self, state_dict): 112 | self.v = state_dict["v"] 113 | 114 | 115 | class CriticBaseline(Baseline): 116 | def __init__(self, critic): 117 | super(Baseline, self).__init__() 118 | 119 | self.critic = critic 120 | 121 | def eval(self, x, c): 122 | v = self.critic(x) 123 | # Detach v since actor should not backprop through baseline, only for loss 124 | return v.detach(), F.mse_loss(v, c.detach()) 125 | 126 | def get_learnable_parameters(self): 127 | return list(self.critic.parameters()) 128 | 129 | def epoch_callback(self, model, epoch): 130 | pass 131 | 132 | def state_dict(self): 133 | return {"critic": self.critic.state_dict()} 134 | 135 | def load_state_dict(self, state_dict): 136 | critic_state_dict = state_dict.get("critic", {}) 137 | if not isinstance(critic_state_dict, dict): # backwards compatibility 138 | critic_state_dict = critic_state_dict.state_dict() 139 | self.critic.load_state_dict({**self.critic.state_dict(), **critic_state_dict}) 140 | 141 | 142 | class GreedyBaseline(Baseline): 143 | def __init__(self, greedymodel, opts): 144 | super(Baseline, self).__init__() 145 | 146 | self.baseline = greedymodel 147 | self.opts = opts 148 | 149 | def eval(self, x, c): 150 | 151 | return self.baseline(x, opts=self.opts)[0].detach(), 0 # No loss 152 | 153 | 154 | class RolloutBaseline(Baseline): 155 | def __init__(self, model, problem, opts, epoch=0): 156 | super(Baseline, self).__init__() 157 | 158 | self.problem = problem 159 | self.opts = opts 160 | 161 | self._update_model(model, epoch) 162 | 163 | def _update_model(self, model, epoch, dataset=None): 164 | self.model = 
copy.deepcopy(model) 165 | # Always generate baseline dataset when updating model to prevent overfitting to the baseline dataset 166 | 167 | if dataset is not None: 168 | if len(dataset) != self.opts.val_size: 169 | print( 170 | "Warning: not using saved baseline dataset since val_size does not match" 171 | ) 172 | dataset = None 173 | elif (dataset[0] if self.problem.NAME == "tsp" else dataset[0]["loc"]).size( 174 | 0 175 | ) != self.opts.graph_size: 176 | print( 177 | "Warning: not using saved baseline dataset since graph_size does not match" 178 | ) 179 | dataset = None 180 | 181 | if dataset is None: 182 | self.dataset = DataLoader( 183 | self.problem.make_dataset( 184 | None, 185 | self.opts.val_size, 186 | self.opts.problem, 187 | seed=epoch * 1000, 188 | opts=self.opts, 189 | ), 190 | batch_size=self.opts.eval_batch_size, 191 | num_workers=1, 192 | ) 193 | else: 194 | self.dataset = dataset 195 | print("Evaluating baseline model on evaluation dataset") 196 | self.bl_vals = ( 197 | rollout(self.model, self.dataset, self.opts)[0].cpu().numpy() / 100.0 198 | ) 199 | self.mean = self.bl_vals.mean() 200 | self.epoch = epoch 201 | 202 | def wrap_dataset(self, dataset): 203 | print("Evaluating baseline on dataset...") 204 | # Need to convert baseline to 2D to prevent converting to double, see 205 | # https://discuss.pytorch.org/t/dataloader-gives-double-instead-of-float/717/3 206 | dataloader = DataLoader( 207 | dataset, batch_size=self.opts.eval_batch_size, num_workers=1 208 | ) 209 | return BaselineDataset( 210 | dataset, rollout(self.model, dataloader, self.opts)[0].view(-1, 1) 211 | ) 212 | # return dataset 213 | 214 | def unwrap_batch(self, batch): 215 | return ( 216 | batch["data"], 217 | None, 218 | ) # Flatten result to undo wrapping as 2D 219 | 220 | def eval(self, x, c): 221 | # Use volatile mode for efficient inference (single batch so we do not use rollout function) 222 | with torch.no_grad(): 223 | v, _ = self.model(x, self.opts, None, None) 224 | 225 | # There is no loss 226 | return v, 0 227 | 228 | def epoch_callback(self, model, epoch): 229 | """ 230 | Challenges the current baseline with the model and replaces the baseline model if it is improved. 
231 | :param model: The model to challenge the baseline by 232 | :param epoch: The current epoch 233 | """ 234 | print("Evaluating candidate model on evaluation dataset") 235 | candidate_vals = ( 236 | rollout(model, self.dataset, self.opts)[0].cpu().numpy() / 100.0 237 | ) 238 | 239 | candidate_mean = candidate_vals.mean() 240 | 241 | print( 242 | "Epoch {} candidate mean {}, baseline epoch {} mean {}, difference {}".format( 243 | epoch, candidate_mean, self.epoch, self.mean, candidate_mean - self.mean 244 | ) 245 | ) 246 | if candidate_mean - self.mean < 0: 247 | # Calc p value 248 | t, p = ttest_rel(candidate_vals, self.bl_vals) 249 | 250 | p_val = p / 2 # one-sided 251 | assert t < 0, "T-statistic should be negative" 252 | print("p-value: {}".format(p_val)) 253 | if p_val < self.opts.bl_alpha: 254 | print("Update baseline") 255 | self._update_model(model, epoch) 256 | 257 | def state_dict(self): 258 | return {"model": self.model, "dataset": self.dataset, "epoch": self.epoch} 259 | 260 | def load_state_dict(self, state_dict): 261 | # We make it such that it works whether model was saved as data parallel or not 262 | load_model = copy.deepcopy(self.model) 263 | get_inner_model(load_model).load_state_dict( 264 | get_inner_model(state_dict["model"]).state_dict() 265 | ) 266 | self._update_model(load_model, state_dict["epoch"], state_dict["dataset"]) 267 | 268 | 269 | class BaselineDataset(Dataset): 270 | def __init__(self, dataset=None, baseline=None): 271 | super(BaselineDataset, self).__init__() 272 | 273 | self.dataset = dataset 274 | self.baseline = baseline 275 | assert len(self.dataset) == len(self.baseline) 276 | 277 | def __getitem__(self, item): 278 | return {"data": self.dataset[item], "baseline": self.baseline[item]} 279 | 280 | def __len__(self): 281 | return len(self.dataset) 282 | -------------------------------------------------------------------------------- /policy/gnn_hist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.checkpoint import checkpoint 4 | import math 5 | 6 | from encoder.graph_encoder_v2 import GraphAttentionEncoder 7 | 8 | from encoder.graph_encoder import MPNN 9 | from torch.nn import DataParallel 10 | from torch_geometric.utils import subgraph 11 | 12 | 13 | def set_decode_type(model, decode_type): 14 | if isinstance(model, DataParallel): 15 | model = model.module 16 | model.set_decode_type(decode_type) 17 | 18 | 19 | class GNNHist(nn.Module): 20 | def __init__( 21 | self, 22 | embedding_dim, 23 | hidden_dim, 24 | problem, 25 | opts, 26 | n_encode_layers=1, 27 | tanh_clipping=10.0, 28 | mask_inner=True, 29 | mask_logits=True, 30 | normalization="batch", 31 | n_heads=8, 32 | checkpoint_encoder=False, 33 | shrink_size=None, 34 | num_actions=None, 35 | encoder="mpnn", 36 | ): 37 | super(GNNHist, self).__init__() 38 | 39 | self.embedding_dim = embedding_dim 40 | self.hidden_dim = hidden_dim 41 | self.n_encode_layers = n_encode_layers 42 | self.decode_type = None 43 | self.temp = 1.0 44 | self.problem = problem 45 | self.opts = opts 46 | 47 | encoder_class = {"attention": GraphAttentionEncoder, "mpnn": MPNN}.get( 48 | encoder, None 49 | ) 50 | if opts.problem == "osbm": 51 | node_dim_u = 16 52 | node_dim_v = 3 53 | else: 54 | node_dim_u, node_dim_v = 1, 1 55 | 56 | self.embedder = encoder_class( 57 | n_heads=n_heads, 58 | embed_dim=embedding_dim, 59 | n_layers=self.n_encode_layers, 60 | normalization=normalization, 61 | problem=self.problem, 62 | opts=self.opts, 
63 | node_dim_v=node_dim_v, 64 | node_dim_u=node_dim_u, 65 | ) 66 | 67 | self.ff = nn.Sequential( 68 | nn.Linear(5 + 4 * opts.embedding_dim, 200), 69 | nn.ReLU(), 70 | nn.Linear(200, 200), 71 | nn.ReLU(), 72 | nn.Linear(200, 1), 73 | ) 74 | 75 | assert embedding_dim % n_heads == 0 76 | self.step_context_transf = nn.Linear(2 * opts.embedding_dim, opts.embedding_dim) 77 | self.initial_stepcontext = nn.Parameter(torch.Tensor(1, 1, embedding_dim)) 78 | self.initial_stepcontext.data.uniform_(-1, 1) 79 | self.dummy = torch.ones(1, dtype=torch.float32, requires_grad=True) 80 | self.model_name = "gnn-hist" 81 | 82 | def init_parameters(self): 83 | for name, param in self.named_parameters(): 84 | stdv = 1.0 / math.sqrt(param.size(-1)) 85 | param.data.uniform_(-stdv, stdv) 86 | 87 | def set_decode_type(self, decode_type, temp=None): 88 | self.decode_type = decode_type 89 | if temp is not None: # Do not change temperature if not provided 90 | self.temp = temp 91 | 92 | def forward(self, x, opts, optimizer, baseline, return_pi=False): 93 | 94 | _log_p, pi, cost = self._inner(x, opts) 95 | 96 | # cost, mask = self.problem.get_costs(input, pi) 97 | # Log likelyhood is calculated within the model since returning it per action does not work well with 98 | # DataParallel since sequences can be of different lengths 99 | ll, e = self._calc_log_likelihood(_log_p, pi, None) 100 | if return_pi: 101 | return -cost, ll, pi, e 102 | # print(ll) 103 | return -cost, ll, e 104 | 105 | def _calc_log_likelihood(self, _log_p, a, mask): 106 | 107 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean() 108 | # Get log_p corresponding to selected actions 109 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1) 110 | 111 | # Optional: mask out actions irrelevant to objective so they do not get reinforced 112 | if mask is not None: 113 | log_p[mask] = 0 114 | if not (log_p > -10000).data.all(): 115 | print(log_p.nonzero()) 116 | assert ( 117 | log_p > -10000 118 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!" 
119 | 120 | # Calculate log_likelihood 121 | # print(log_p.sum(1)) 122 | 123 | return log_p.sum(1), entropy 124 | 125 | def _inner(self, input, opts): 126 | 127 | outputs = [] 128 | sequences = [] 129 | 130 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts) 131 | 132 | batch_size = state.batch_size 133 | graph_size = state.u_size + state.v_size + 1 134 | i = 1 135 | step_context = 0.0 136 | while not (state.all_finished()): 137 | step_size = state.i + 1 138 | mask = state.get_mask() 139 | w = state.get_current_weights(mask) 140 | # Pass the graph to the Encoder 141 | node_features = state.get_node_features() 142 | nodes = torch.cat( 143 | ( 144 | torch.arange(0, opts.u_size + 1, device=opts.device), 145 | state.idx[:i] + opts.u_size + 1, 146 | ) 147 | ) 148 | subgraphs = ( 149 | (nodes.unsqueeze(0).expand(batch_size, step_size)) 150 | + torch.arange( 151 | 0, batch_size * graph_size, graph_size, device=opts.device 152 | ).unsqueeze(1) 153 | ).flatten() # The nodes of the current subgraphs 154 | graph_weights = state.get_graph_weights() 155 | edge_i, weights = subgraph( 156 | subgraphs, 157 | state.graphs.edge_index, 158 | graph_weights.unsqueeze(1), 159 | relabel_nodes=True, 160 | ) 161 | embeddings = checkpoint( 162 | self.embedder, 163 | node_features, 164 | edge_i, 165 | weights.float(), 166 | torch.tensor(i), 167 | self.dummy, 168 | opts, 169 | ).reshape(batch_size, step_size, -1) 170 | pos = torch.argsort(state.idx[:i])[-1] 171 | incoming_node_embeddings = embeddings[ 172 | :, pos + state.u_size + 1, : 173 | ].unsqueeze(1) 174 | # print(incoming_node_embeddings) 175 | w = (state.adj[:, state.get_current_node(), :]).float() 176 | # mean_w = w.mean(1)[:, None, None].repeat(1, state.u_size + 1, 1) 177 | s = w.reshape(state.batch_size, state.u_size + 1, 1) 178 | idx = ( 179 | torch.ones(state.batch_size, 1, 1, device=opts.device) 180 | * (i - 1.0) 181 | / state.v_size 182 | ) 183 | 184 | if i != 1: 185 | past_sol = ( 186 | torch.stack(sequences, 1) 187 | + torch.arange( 188 | 0, batch_size * (i - 1), i - 1, device=opts.device 189 | ).unsqueeze(1) 190 | ).flatten() 191 | 192 | selected_nodes = torch.index_select( 193 | embeddings.reshape(-1, opts.embedding_dim), 194 | 0, 195 | past_sol.to(opts.device), 196 | ).reshape(batch_size, i - 1, opts.embedding_dim) 197 | step_context = ( 198 | self.step_context_transf( 199 | torch.cat( 200 | ( 201 | selected_nodes, 202 | embeddings[:, state.u_size + 1 : state.u_size + i, :], 203 | ), 204 | dim=2, 205 | ) 206 | ) 207 | .mean(1) 208 | .unsqueeze(1) 209 | ) 210 | else: 211 | step_context = self.initial_stepcontext.repeat(batch_size, 1, 1) 212 | u_embeddings = embeddings[:, : opts.u_size + 1, :] 213 | fixed_node_identity = torch.zeros( 214 | state.batch_size, state.u_size + 1, 1, device=opts.device 215 | ) 216 | fixed_node_identity[:, 0, :] = 1.0 217 | norm_size = ( 218 | state.orig_budget.sum(-1)[:, None, None] 219 | if opts.problem == "adwords" 220 | else state.u_size 221 | ) 222 | s = torch.cat( 223 | ( 224 | s, 225 | idx.repeat(1, state.u_size + 1, 1), 226 | state.size.unsqueeze(2).repeat(1, state.u_size + 1, 1) / norm_size, 227 | fixed_node_identity, 228 | mask.unsqueeze(2), 229 | incoming_node_embeddings.repeat(1, state.u_size + 1, 1), 230 | embeddings[:, : opts.u_size + 1, :], 231 | step_context.repeat(1, state.u_size + 1, 1), 232 | u_embeddings.mean(1).unsqueeze(1).repeat(1, state.u_size + 1, 1), 233 | ), 234 | dim=2, 235 | ).float() 236 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1) 237 | # Select the 
indices of the next nodes in the sequences, result (batch_size) long 238 | selected, p = self._select_node( 239 | pi, mask.bool() 240 | ) # Squeeze out steps dimension 241 | # entropy += torch.sum(p * (p.log()), dim=1) 242 | state = state.update((selected)[:, None]) 243 | outputs.append(p) 244 | sequences.append(selected) 245 | i += 1 246 | # Collected lists, return Tensor 247 | return ( 248 | torch.stack(outputs, 1), 249 | torch.stack(sequences, 1), 250 | state.size, 251 | ) 252 | 253 | def _select_node(self, probs, mask): 254 | assert (probs == probs).all(), "Probs should not contain any nans" 255 | probs[mask] = -1e6 256 | p = torch.log_softmax(probs, dim=1) 257 | # print(p) 258 | if self.decode_type == "greedy": 259 | _, selected = p.max(1) 260 | # assert not mask.gather( 261 | # 1, selected.unsqueeze(-1) 262 | # ).data.any(), "Decode greedy: infeasible action has maximum probability" 263 | 264 | elif self.decode_type == "sampling": 265 | selected = p.exp().multinomial(1).squeeze(1) 266 | # Check if sampling went OK, can go wrong due to bug on GPU 267 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232 268 | # while mask.gather(1, selected.unsqueeze(-1)).data.any(): 269 | # print("Sampled bad values, resampling!") 270 | # selected = probs.multinomial(1).squeeze(1) 271 | 272 | else: 273 | assert False, "Unknown decode type" 274 | return selected, p 275 | -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import networkx as nx 5 | import torch 6 | from scipy.stats import powerlaw 7 | import torch_geometric 8 | 9 | 10 | # gMission files 11 | gMission_edges = "data/gMission/edges.txt" 12 | gMission_tasks = "data/gMission/tasks.txt" 13 | gMission_reduced_tasks = "data/gMission/reduced_tasks.txt" 14 | gMission_reduced_workers = "data/gMission/reduced_workers.txt" 15 | 16 | # MovieLense files 17 | movie_lense_movies = "data/MovieLense/movies.txt" 18 | movie_lense_users = "data/MovieLense/users.txt" 19 | movie_lense_edges = "data/MovieLense/edges.txt" 20 | movie_lense_ratings = "data/MovieLense/ratings.txt" 21 | movie_lense_feature_weights = "data/MovieLense/feature_weights.txt" 22 | 23 | 24 | def add_nodes_with_bipartite_label(G, lena, lenb): 25 | """ 26 | Helper for generate_ba_graph that initializes the initial empty graph with nodes 27 | """ 28 | G.add_nodes_from(range(0, lena + lenb)) 29 | b = dict(zip(range(0, lena), [0] * lena)) 30 | b.update(dict(zip(range(lena, lena + lenb), [1] * lenb))) 31 | nx.set_node_attributes(G, b, "bipartite") 32 | return G 33 | 34 | 35 | def get_solution(row_ind, col_in, weights, v_size): 36 | """ 37 | returns a np vector where the index at i is the the node in u that v_i connect to. If index is zero, then v[i] 38 | is connected to no node in U. 
39 | """ 40 | new_col_in = [] 41 | # row_ind.sort() 42 | col_in = col_in + 1 43 | new_col_in += [0] * (row_ind[0]) 44 | 45 | for i in range(0, len(row_ind) - 1): 46 | if weights[row_ind[i], col_in[i] - 1] != 0.0: 47 | new_col_in.append(col_in[i]) 48 | else: 49 | new_col_in.append(0.0) 50 | new_col_in += [0.0] * (row_ind[i + 1] - row_ind[i] - 1) 51 | 52 | if weights[row_ind[-1], col_in[-1] - 1] != 0.0: 53 | new_col_in.append(col_in[-1]) 54 | else: 55 | new_col_in.append(0.0) 56 | new_col_in += [0] * (v_size - row_ind[-1] - 1) 57 | return new_col_in 58 | 59 | 60 | def check_extension(filename): 61 | if os.path.splitext(filename)[1] != ".pkl": 62 | return filename + ".pkl" 63 | return filename 64 | 65 | 66 | def save_dataset(dataset, filename): 67 | 68 | with open(check_extension(filename), "wb") as f: 69 | pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL) 70 | 71 | 72 | def load_dataset(filename): 73 | 74 | with open(check_extension(filename), "rb") as f: 75 | return pickle.load(f) 76 | 77 | 78 | def from_networkx(G): 79 | r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a 80 | :class:`torch_geometric.data.Data` instance. 81 | 82 | Args: 83 | G (networkx.Graph or networkx.DiGraph): A networkx graph. 84 | """ 85 | 86 | G = nx.convert_node_labels_to_integers(G, ordering="sorted") 87 | G = G.to_directed() if not nx.is_directed(G) else G 88 | edge_index = torch.LongTensor(list(G.edges)).t().contiguous() 89 | 90 | data = {} 91 | 92 | for i, (_, feat_dict) in enumerate(G.nodes(data=True)): 93 | for key, value in feat_dict.items(): 94 | data[str(key)] = [value] if i == 0 else data[str(key)] + [value] 95 | 96 | for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): 97 | for key, value in feat_dict.items(): 98 | data[str(key)] = [value] if i == 0 else data[str(key)] + [value] 99 | 100 | for key, item in data.items(): 101 | try: 102 | data[key] = torch.tensor(item) 103 | except ValueError: 104 | pass 105 | 106 | data["edge_index"] = edge_index.view(2, -1) 107 | data = torch_geometric.data.Data.from_dict(data) 108 | data.num_nodes = G.number_of_nodes() 109 | 110 | return data 111 | 112 | 113 | def parse_gmission_dataset(): 114 | f_edges = open(gMission_edges, "r") 115 | f_tasks = open(gMission_tasks, "r") 116 | f_reduced_tasks = open(gMission_reduced_tasks, "r") 117 | f_reduced_workers = open(gMission_reduced_workers, "r") 118 | edgeWeights = dict() 119 | edgeNumber = dict() 120 | count = 0 121 | for line in f_edges: 122 | vals = line.split(",") 123 | edgeWeights[vals[0]] = vals[1].split("\n")[0] 124 | edgeNumber[vals[0]] = count 125 | count += 1 126 | 127 | tasks = list() 128 | reduced_tasks = [] 129 | reduced_workers = [] 130 | tasks_x = dict() 131 | tasks_y = dict() 132 | 133 | for line in f_tasks: 134 | vals = line.split(",") 135 | tasks.append(vals[0]) 136 | tasks_x[vals[0]] = float(vals[2]) 137 | tasks_y[vals[0]] = float(vals[3]) 138 | 139 | for t in f_reduced_tasks: 140 | reduced_tasks.append(t) 141 | 142 | for w in f_reduced_workers: 143 | reduced_workers.append(w) 144 | 145 | return edgeWeights, tasks, reduced_tasks, reduced_workers 146 | 147 | 148 | def parse_movie_lense_dataset(): 149 | f_edges = open(movie_lense_edges, "r") 150 | f_movies = open(movie_lense_movies, "r") 151 | f_users = open(movie_lense_users, "r") 152 | f_feature_weights = open(movie_lense_feature_weights, "r") 153 | num_genres = 15 154 | gender_map = {"M": 0, "F": 1} 155 | age_map = {"1": 0, "18": 1, "25": 2, "35": 3, "45": 4, "50": 5, "56": 6} 156 | genre_map = { 157 | "Action": 0, 158 | "Adventure": 
1, 159 | "Animation": 2, 160 | "Children's": 3, 161 | "Comedy": 4, 162 | "Crime": 5, 163 | "Documentary": 6, 164 | "Drama": 7, 165 | "Film-Noir": 8, 166 | "Horror": 9, 167 | "Musical": 10, 168 | "Romance": 11, 169 | "Sci-Fi": 12, 170 | "Thriller": 13, 171 | "War": 14, 172 | } 173 | users = {} 174 | movies = {} 175 | edges = {} 176 | feature_weights = {} 177 | user_ids = [] 178 | popularity = {} 179 | for u in f_users: 180 | info = u.split(",")[:4] 181 | info[1] = float(gender_map[info[1]]) 182 | info[2] = float(age_map[info[2]]) / 6.0 183 | info[3] = float(info[3]) / 21.0 184 | users[info[0]] = info[1:4] 185 | user_ids.append(int(info[0])) 186 | user_ids.sort() 187 | for i, u in enumerate(user_ids): 188 | users[str(u)].append(i) 189 | 190 | for m in f_movies: 191 | info = m.split("::") 192 | genres = info[2].split("|") 193 | genres[-1] = genres[-1].split("\n")[0] # remove "\n" character 194 | genres_id = np.array(list(map(lambda g: genre_map[g], genres))) 195 | one_hot_encoding = np.zeros(num_genres) 196 | one_hot_encoding[genres_id] = 1.0 197 | movies[info[0]] = list(one_hot_encoding) 198 | popularity[info[0]] = 0 199 | 200 | for e in f_edges: 201 | info = e.split(",") 202 | genres = info[3].split("|") 203 | genres[-1] = genres[-1].split("\n")[0] # remove "\n" character 204 | edges[(info[2], info[1])] = list(map(lambda g: genre_map[g], genres)) 205 | popularity[info[2]] += 1 206 | 207 | for w in f_feature_weights: 208 | feature = w.split(",") 209 | feature[-1] = feature[-1].split("\n")[0] # remove "\n" character 210 | if feature[1] not in feature_weights: 211 | feature_weights[feature[1]] = [0.0] * num_genres 212 | feature_weights[feature[1]][genre_map[feature[0]]] = float(feature[2]) / 5.0 213 | return users, movies, edges, feature_weights, popularity 214 | 215 | 216 | def find_best_tasks(tasks, edges): 217 | task_total = {} 218 | f_open = open("data/gMission/reduced_tasks.txt", "a") 219 | for e in edges.items(): 220 | task = e[0].split(";")[1] 221 | if task not in task_total: 222 | task_total[task] = 1 223 | else: 224 | task_total[task] += 1 225 | top_tasks = sorted(task_total.items(), key=lambda k: k[1], reverse=True)[:300] 226 | for t in top_tasks: 227 | f_open.write(t[0] + "\n") 228 | return top_tasks 229 | 230 | 231 | def find_best_workers(tasks, edges): 232 | task_total = {} 233 | f_open = open("data/gMission/reduced_workers.txt", "a") 234 | for e in edges.items(): 235 | task = e[0].split(";")[0] 236 | if task not in task_total: 237 | task_total[task] = 1 238 | else: 239 | task_total[task] += 1 240 | top_tasks = sorted(task_total.items(), key=lambda k: k[1], reverse=True)[:200] 241 | for t in top_tasks: 242 | f_open.write(t[0] + "\n") 243 | return top_tasks 244 | 245 | 246 | def generate_weights_geometric(distribution, u_size, v_size, parameters, g1, seed): 247 | weights, w = 0, 0 248 | np.random.seed(seed) 249 | if distribution == "uniform": 250 | weights = nx.bipartite.biadjacency_matrix( 251 | g1, range(0, u_size), range(u_size, u_size + v_size) 252 | ).toarray() * np.random.uniform( 253 | int(parameters[0]), int(parameters[1]), (u_size, v_size) 254 | ) 255 | w = torch.cat( 256 | (torch.zeros(v_size, 1).float(), torch.tensor(weights).T.float()), 1 257 | ) 258 | elif distribution == "normal": 259 | weights = nx.bipartite.biadjacency_matrix( 260 | g1, range(0, u_size), range(u_size, u_size + v_size) 261 | ).toarray() * ( 262 | np.abs( 263 | np.random.normal( 264 | int(parameters[0]), int(parameters[1]), (u_size, v_size) 265 | ) 266 | ) 267 | + 5 268 | ) # to make sure no edge has 
weight zero 269 | w = torch.cat( 270 | (torch.zeros(v_size, 1).float(), torch.tensor(weights).T.float()), 1 271 | ) 272 | elif distribution == "power": 273 | weights = nx.bipartite.biadjacency_matrix( 274 | g1, range(0, u_size), range(u_size, u_size + v_size) 275 | ).toarray() * ( 276 | powerlaw.rvs( 277 | int(parameters[0]), 278 | int(parameters[1]), 279 | int(parameters[2]), 280 | (u_size, v_size), 281 | ) 282 | + 5 283 | ) # to make sure no edge has weight zero 284 | w = torch.cat( 285 | (torch.zeros(v_size, 1).float(), torch.tensor(weights).T.float()), 1 286 | ) 287 | elif distribution == "degree": 288 | weights = nx.bipartite.biadjacency_matrix( 289 | g1, range(0, u_size), range(u_size, u_size + v_size) 290 | ).toarray() 291 | graph = weights * weights.sum(axis=1).reshape(-1, 1) 292 | noise = np.abs( 293 | np.random.normal( 294 | float(parameters[0]), float(parameters[1]), (u_size, v_size) 295 | ) 296 | ) 297 | weights = np.where(graph, (graph + noise) / v_size, graph) 298 | w = torch.cat( 299 | (torch.zeros(v_size, 1).float(), torch.tensor(weights).T.float()), 1 300 | ) 301 | elif distribution == "node-normal": 302 | adj = nx.bipartite.biadjacency_matrix( 303 | g1, range(0, u_size), range(u_size, u_size + v_size) 304 | ).toarray() 305 | mean = np.random.randint( 306 | float(parameters[0]), float(parameters[1]), (u_size, 1) 307 | ) 308 | variance = np.sqrt( 309 | np.random.randint(float(parameters[0]), float(parameters[1]), (u_size, 1)) 310 | ) 311 | weights = ( 312 | np.abs(np.random.normal(0.0, 1.0, (u_size, v_size)) * variance + mean) + 5 313 | ) * adj 314 | elif distribution == "fixed-normal": 315 | adj = nx.bipartite.biadjacency_matrix( 316 | g1, range(0, u_size), range(u_size, u_size + v_size) 317 | ).toarray() 318 | mean = np.random.choice(np.arange(0, 100, 15), size=(u_size, 1)) 319 | variance = np.sqrt(np.random.choice(np.arange(0, 100, 20), (u_size, 1))) 320 | weights = ( 321 | np.abs(np.random.normal(0.0, 1.0, (u_size, v_size)) * variance + mean) + 5 322 | ) * adj 323 | 324 | w = np.delete(weights.flatten(), weights.flatten() == 0) 325 | return weights, w 326 | 327 | def visualize_biparite(G, u, v, fig_path = "temp.png"): 328 | 329 | import matplotlib 330 | import matplotlib.pyplot as plt 331 | # nx.draw(G) 332 | G_left = G.subgraph([i for i in range(u)]) 333 | G_right = G.subgraph([i for i in range(u,u+v)]) 334 | edges = G.edges() 335 | 336 | pos = dict() 337 | pos.update( (n, (1, i)) for i, n in enumerate(G_left) ) # put nodes from X at x=1 338 | pos.update( (n, (2, i)) for i, n in enumerate(G_right) ) # put nodes from Y at x=2 339 | nx.draw_networkx(G, pos=pos) 340 | 341 | plt.savefig(fig_path) 342 | -------------------------------------------------------------------------------- /problem_state/edge_obm_env.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import NamedTuple 3 | from torch_geometric.utils import to_dense_adj 4 | 5 | # from utils.boolmask import mask_long2bool, mask_long_scatter 6 | 7 | 8 | class StateEdgeBipartite(NamedTuple): 9 | # Fixed input 10 | graphs: torch.Tensor # graphs objects in a batch 11 | u_size: int 12 | v_size: int 13 | batch_size: torch.Tensor 14 | hist_sum: torch.Tensor 15 | hist_sum_sq: torch.Tensor 16 | hist_deg: torch.Tensor 17 | min_sol: torch.Tensor 18 | max_sol: torch.Tensor 19 | sum_sol_sq: torch.Tensor 20 | num_skip: torch.Tensor 21 | 22 | adj: torch.Tensor 23 | # State 24 | matched_nodes: torch.Tensor # Keeps track of nodes that have been matched 25 | size: 
torch.Tensor # size of current matching 26 | i: int # Keeps track of step 27 | opts: dict 28 | idx: torch.Tensor 29 | 30 | @staticmethod 31 | def initialize( 32 | input, 33 | u_size, 34 | v_size, 35 | opts, 36 | ): 37 | graph_size = u_size + v_size + 1 38 | batch_size = int(input.batch.size(0) / graph_size) 39 | # print(batch_size, input.batch.size(0), graph_size) 40 | adj = to_dense_adj( 41 | input.edge_index, input.batch, input.weight.unsqueeze(1) 42 | ).squeeze(-1) 43 | adj = adj[:, u_size + 1 :, : u_size + 1] 44 | 45 | # permute the nodes for data 46 | idx = torch.arange(adj.shape[1], device=opts.device) 47 | # if "supervised" not in opts.model and not opts.eval_only: 48 | # idx = torch.randperm(adj.shape[1], device=opts.device) 49 | # adj = adj[:, idx, :].view(adj.size()) 50 | 51 | return StateEdgeBipartite( 52 | graphs=input, 53 | adj=adj, 54 | u_size=u_size, 55 | v_size=v_size, 56 | batch_size=batch_size, 57 | # Keep visited with depot so we can scatter efficiently (if there is an action for depot) 58 | matched_nodes=( # Visited as mask is easier to understand, as long more memory efficient 59 | torch.zeros( 60 | batch_size, 61 | u_size + 1, 62 | device=input.batch.device, 63 | ) 64 | ), 65 | hist_sum=( # Visited as mask is easier to understand, as long more memory efficient 66 | torch.zeros( 67 | batch_size, 68 | 1, 69 | u_size + 1, 70 | device=input.batch.device, 71 | ) 72 | ), 73 | hist_deg=( # Visited as mask is easier to understand, as long more memory efficient 74 | torch.zeros( 75 | batch_size, 76 | 1, 77 | u_size + 1, 78 | device=input.batch.device, 79 | ) 80 | ), 81 | hist_sum_sq=( # Visited as mask is easier to understand, as long more memory efficient 82 | torch.zeros( 83 | batch_size, 84 | 1, 85 | u_size + 1, 86 | device=input.batch.device, 87 | ) 88 | ), 89 | min_sol=torch.zeros(batch_size, 1, device=input.batch.device), 90 | max_sol=torch.zeros(batch_size, 1, device=input.batch.device), 91 | sum_sol_sq=torch.zeros(batch_size, 1, device=input.batch.device), 92 | num_skip=torch.zeros(batch_size, 1, device=input.batch.device), 93 | size=torch.zeros(batch_size, 1, device=input.batch.device), 94 | i=u_size + 1, 95 | opts=opts, 96 | idx=idx, 97 | ) 98 | 99 | def get_final_cost(self): 100 | 101 | assert self.all_finished() 102 | # assert self.visited_. 
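# Note (added annotation): 'size' accumulates the total weight of the edges matched so far (see
# update below), so once all V-nodes have arrived it equals the final objective value returned here.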
103 | 104 | return self.size 105 | 106 | def get_current_weights(self, mask): 107 | return self.adj[:, 0, :].float() 108 | 109 | def get_graph_weights(self): 110 | return self.graphs.weight 111 | 112 | def update(self, selected): 113 | # Update the state 114 | nodes = self.matched_nodes.scatter_(-1, selected, 1) 115 | nodes[ 116 | :, 0 117 | ] = 0 # node that represents not being matched to anything can be matched to more than once 118 | w = self.adj[:, 0, :].clone() 119 | selected_weights = w.gather(1, selected).to(self.adj.device) 120 | skip = (selected == 0).float() 121 | num_skip = self.num_skip + skip 122 | if self.i == self.u_size + 1: 123 | min_sol = selected_weights 124 | else: 125 | m = self.min_sol.clone() 126 | m[m == 0.0] = 2.0 127 | selected_weights[skip.bool()] = 2.0 128 | min_sol = torch.minimum(m, selected_weights) 129 | selected_weights[selected_weights == 2.0] = 0.0 130 | min_sol[min_sol == 2.0] = 0.0 131 | 132 | max_sol = torch.maximum(self.max_sol, selected_weights) 133 | total_weights = self.size + selected_weights 134 | sum_sol_sq = self.sum_sol_sq + selected_weights ** 2 135 | 136 | hist_sum = self.hist_sum + w.unsqueeze(1) 137 | hist_sum_sq = self.hist_sum_sq + w.unsqueeze(1) ** 2 138 | hist_deg = self.hist_deg + (w.unsqueeze(1) != 0).float() 139 | hist_deg[:, :, 0] = float(self.i - self.u_size) 140 | return self._replace( 141 | matched_nodes=nodes, 142 | size=total_weights, 143 | i=self.i + 1, 144 | adj=self.adj[:, 1:, :], 145 | hist_sum=hist_sum, 146 | hist_sum_sq=hist_sum_sq, 147 | hist_deg=hist_deg, 148 | num_skip=num_skip, 149 | max_sol=max_sol, 150 | min_sol=min_sol, 151 | sum_sol_sq=sum_sol_sq, 152 | ) 153 | 154 | def all_finished(self): 155 | # Exactly v_size steps 156 | return (self.i - (self.u_size + 1)) >= self.v_size 157 | 158 | def get_current_node(self): 159 | return 0 160 | 161 | def get_curr_state(self, model): 162 | mask = self.get_mask() 163 | opts = self.opts 164 | w = self.adj[:, 0, :].float().clone() 165 | s = None 166 | if model == "ff": 167 | s = torch.cat((w, mask.float()), dim=1) 168 | 169 | elif model == "inv-ff": 170 | deg = (w != 0).float().sum(1) 171 | deg[deg == 0.0] = 1.0 172 | mean_w = w.sum(1) / deg 173 | mean_w = mean_w[:, None, None].repeat(1, self.u_size + 1, 1) 174 | fixed_node_identity = torch.zeros( 175 | self.batch_size, self.u_size + 1, 1, device=opts.device 176 | ).float() 177 | fixed_node_identity[:, 0, :] = 1.0 178 | s = w.reshape(self.batch_size, self.u_size + 1, 1) 179 | s = torch.cat( 180 | ( 181 | fixed_node_identity, 182 | s, 183 | mean_w, 184 | ), 185 | dim=2, 186 | ) 187 | 188 | elif model == "ff-hist" or model == "ff-supervised": 189 | ( 190 | h_mean, 191 | h_var, 192 | h_mean_degree, 193 | ind, 194 | matched_ratio, 195 | var_sol, 196 | mean_sol, 197 | n_skip, 198 | ) = self.get_hist_features() 199 | s = torch.cat( 200 | ( 201 | w, 202 | mask.float(), 203 | h_mean.squeeze(1), 204 | h_var.squeeze(1), 205 | h_mean_degree.squeeze(1), 206 | self.size / self.u_size, 207 | ind.float(), 208 | mean_sol, 209 | var_sol, 210 | n_skip, 211 | self.max_sol, 212 | self.min_sol, 213 | matched_ratio, 214 | ), 215 | dim=1, 216 | ).float() 217 | 218 | elif model == "inv-ff-hist" or model == "gnn-simp-hist": 219 | deg = (w != 0).float().sum(1) 220 | deg[deg == 0.0] = 1.0 221 | mean_w = w.sum(1) / deg 222 | mean_w = mean_w[:, None, None].repeat(1, self.u_size + 1, 1) 223 | s = w.reshape(self.batch_size, self.u_size + 1, 1) 224 | ( 225 | h_mean, 226 | h_var, 227 | h_mean_degree, 228 | ind, 229 | matched_ratio, 230 | var_sol, 231 | 
mean_sol, 232 | n_skip, 233 | ) = self.get_hist_features() 234 | available_ratio = (deg.unsqueeze(1)) / (self.u_size) 235 | fixed_node_identity = torch.zeros( 236 | self.batch_size, self.u_size + 1, 1, device=opts.device 237 | ).float() 238 | fixed_node_identity[:, 0, :] = 1.0 239 | s = torch.cat( 240 | ( 241 | s, 242 | mask.reshape(-1, self.u_size + 1, 1).float(), 243 | mean_w, 244 | h_mean.transpose(1, 2), 245 | h_var.transpose(1, 2), 246 | h_mean_degree.transpose(1, 2), 247 | ind.unsqueeze(2).repeat(1, self.u_size + 1, 1), 248 | self.size.unsqueeze(2).repeat(1, self.u_size + 1, 1) / self.u_size, 249 | mean_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1), 250 | var_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1), 251 | n_skip.unsqueeze(2).repeat(1, self.u_size + 1, 1), 252 | self.max_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1), 253 | self.min_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1), 254 | matched_ratio.unsqueeze(2).repeat(1, self.u_size + 1, 1), 255 | available_ratio.unsqueeze(2).repeat(1, self.u_size + 1, 1), 256 | fixed_node_identity, 257 | ), 258 | dim=2, 259 | ).float() 260 | 261 | return s, mask 262 | 263 | def get_node_features(self): 264 | step_size = self.i + 1 265 | batch_size = self.batch_size 266 | incoming_node_features = ( 267 | torch.cat( 268 | (torch.ones(step_size - self.u_size - 1, device=self.adj.device) * 2,) 269 | ) 270 | .unsqueeze(0) 271 | .expand(batch_size, step_size - self.u_size - 1) 272 | ).float() # Collecting node features up until the ith incoming node 273 | 274 | future_node_feature = torch.ones(batch_size, 1, device=self.adj.device) * -1.0 275 | fixed_node_feature = self.matched_nodes[:, 1:] 276 | node_features = torch.cat( 277 | (future_node_feature, fixed_node_feature, incoming_node_features), dim=1 278 | ).reshape(batch_size, step_size) 279 | 280 | return node_features 281 | 282 | def get_hist_features(self): 283 | i = self.i - (self.u_size + 1) 284 | if i != 0: 285 | deg = self.hist_deg.clone() 286 | deg[deg == 0] = 1.0 287 | h_mean = self.hist_sum / deg 288 | h_var = (self.hist_sum_sq - ((self.hist_sum ** 2) / deg)) / deg 289 | h_mean_degree = self.hist_deg / i 290 | ind = ( 291 | torch.ones(self.batch_size, 1, device=self.opts.device) 292 | * i 293 | / self.v_size 294 | ) 295 | curr_sol_size = i - self.num_skip 296 | var_sol = ( 297 | self.sum_sol_sq - ((self.size ** 2) / curr_sol_size) 298 | ) / curr_sol_size 299 | mean_sol = self.size / curr_sol_size 300 | var_sol[curr_sol_size == 0.0] = 0.0 301 | mean_sol[curr_sol_size == 0.0] = 0.0 302 | matched_ratio = self.matched_nodes.sum(1).unsqueeze(1) / self.u_size 303 | n_skip = self.num_skip / i 304 | else: 305 | ( 306 | h_mean, 307 | h_var, 308 | h_mean_degree, 309 | ind, 310 | matched_ratio, 311 | var_sol, 312 | mean_sol, 313 | n_skip, 314 | ) = ( 315 | self.hist_sum * 0.0, 316 | self.hist_sum * 0.0, 317 | self.hist_sum * 0.0, 318 | self.size * 0.0, 319 | self.num_skip * 0.0, 320 | self.size * 0.0, 321 | self.size * 0.0, 322 | self.size * 0.0, 323 | ) 324 | 325 | return ( 326 | h_mean, 327 | h_var, 328 | h_mean_degree, 329 | ind, 330 | matched_ratio, 331 | var_sol, 332 | mean_sol, 333 | n_skip, 334 | ) 335 | 336 | def get_mask(self): 337 | """ 338 | Returns a mask vector which includes only nodes in U that can matched. 339 | That is, neighbors of the incoming node that have not been matched already. 
340 | """ 341 | 342 | mask = (self.adj[:, 0, :] == 0).float() 343 | mask[:, 0] = 0 344 | self.matched_nodes[ 345 | :, 0 346 | ] = 0 # node that represents not being matched to anything can be matched to more than once 347 | return ( 348 | self.matched_nodes + mask > 0 349 | ).long() # Hacky way to return bool or uint8 depending on pytorch version 350 | 351 | ##### Works fine in most cases, but didn't unitest this module, please pay 352 | ##### extra attention to the rewards reservation.... 353 | 354 | def try_hypothesis_action(self, selected): 355 | 356 | w = self.adj[:, 0, :].clone() 357 | selected_weights = w.gather(1, selected).to(self.adj.device) 358 | skip = (selected == 0).float() 359 | if self.i != self.u_size + 1: 360 | selected_weights[skip.bool()] = 2.0 361 | selected_weights[selected_weights == 2.0] = 0.0 362 | 363 | total_weights = self.size + selected_weights 364 | 365 | matched_nodes = self.matched_nodes.clone().scatter_(-1, selected, 1) 366 | matched_nodes[:, 0] = 0 # node that represents not being matched to anything can be matched to more than once 367 | 368 | return total_weights, matched_nodes 369 | -------------------------------------------------------------------------------- /data/gMission/workers.txt: -------------------------------------------------------------------------------- 1 | 1,1.6,1.36,1,0.971 2 | 2,0.36,3.32,1,0.641 3 | 3,2.2,3.41,1,0.904 4 | 4,2.45,4.32,1,0.749 5 | 5,3.75,1.1,1,0.944 6 | 6,2.56,2.67,1,0.914 7 | 7,4.38,0.07,1,0.642 8 | 8,2.04,3.25,1,0.926 9 | 9,2.27,1.42,1,0.665 10 | 10,0.23,0.04,1,0.677 11 | 11,4.68,3.38,1,0.916 12 | 12,3.77,0.94,1,0.947 13 | 13,0.48,1.37,1,0.787 14 | 14,1.97,2.45,1,0.733 15 | 15,0.87,3.77,1,0.909 16 | 16,0.85,3.86,1,0.872 17 | 17,2.53,3.52,1,0.911 18 | 18,4.05,3.16,1,0.846 19 | 19,2.13,1.77,1,0.801 20 | 20,0.12,4.23,1,0.763 21 | 21,0.91,0.3,1,0.926 22 | 22,4.87,1.99,1,0.772 23 | 23,1.5,1.6,1,0.753 24 | 24,0.88,3.67,1,0.95 25 | 25,0.97,0.5,1,0.886 26 | 26,3.71,0.15,1,0.791 27 | 27,2.5,4.83,1,0.659 28 | 28,4.9,1.89,1,0.904 29 | 29,3.56,2.76,1,0.764 30 | 30,1.6,3.79,1,0.978 31 | 31,0.45,2.11,1,0.868 32 | 32,4.51,4.79,1,0.909 33 | 33,4.59,4.94,1,0.758 34 | 34,2.51,4.33,1,0.863 35 | 35,2.93,4.66,1,0.981 36 | 36,3.65,4.35,1,0.698 37 | 37,2.74,3.13,1,0.64 38 | 38,4.17,1.63,1,0.928 39 | 39,2.54,2.96,1,0.904 40 | 40,4.46,0.87,1,0.717 41 | 41,4.72,4.69,1,0.946 42 | 42,0.5,3.0,1,0.76 43 | 43,3.92,0.38,1,0.951 44 | 44,0.74,0.62,1,0.976 45 | 45,4.45,1.26,1,0.972 46 | 46,0.2,0.87,1,0.263 47 | 47,1.55,0.1,1,0.687 48 | 48,1.53,2.63,1,0.764 49 | 49,1.62,1.33,1,0.862 50 | 50,3.31,3.18,1,0.775 51 | 51,4.78,4.31,1,0.853 52 | 52,3.52,1.86,1,0.676 53 | 53,0.2,0.55,1,0.82 54 | 54,3.96,0.27,1,0.862 55 | 55,0.99,3.18,1,0.929 56 | 56,2.16,2.32,1,0.979 57 | 57,0.9,3.14,1,0.718 58 | 58,4.73,4.38,1,0.717 59 | 59,0.28,4.66,1,0.894 60 | 60,1.03,1.66,1,0.632 61 | 61,0.93,1.65,1,0.834 62 | 62,2.19,0.82,1,0.775 63 | 63,4.81,0.55,1,0.804 64 | 64,4.6,3.36,1,0.933 65 | 65,1.25,3.89,1,0.71 66 | 66,4.73,4.73,1,0.83 67 | 67,4.58,3.17,1,0.699 68 | 68,4.58,0.24,1,0.989 69 | 69,4.44,1.92,1,0.849 70 | 70,1.48,2.43,1,0.95 71 | 71,2.42,3.01,1,0.693 72 | 72,3.93,2.63,1,0.942 73 | 73,1.85,4.95,1,0.979 74 | 74,3.78,1.09,1,0.871 75 | 75,2.76,2.82,1,0.875 76 | 76,1.84,3.39,1,0.73 77 | 77,4.59,0.94,1,0.825 78 | 78,4.44,3.39,1,0.807 79 | 79,4.92,2.78,1,0.857 80 | 80,1.54,1.42,1,0.877 81 | 81,4.04,1.0,1,0.658 82 | 82,1.61,3.45,1,0.712 83 | 83,1.38,2.3,1,0.757 84 | 84,4.34,4.73,1,0.813 85 | 85,3.77,0.01,1,0.875 86 | 86,0.01,2.22,1,0.849 87 | 87,1.46,4.51,1,0.783 88 | 
88,1.89,0.12,1,0.957 89 | 89,0.24,1.28,1,0.774 90 | 90,4.3,1.67,1,0.96 91 | 91,1.28,0.37,1,0.772 92 | 92,4.38,1.61,1,0.657 93 | 93,2.82,1.5,1,0.765 94 | 94,2.81,2.24,1,0.713 95 | 95,4.69,0.19,1,0.695 96 | 96,0.5,1.88,1,0.961 97 | 97,4.54,3.21,1,0.81 98 | 98,0.01,3.24,1,0.951 99 | 99,0.09,2.67,1,0.765 100 | 100,3.46,0.24,1,0.932 101 | 101,0.28,3.26,1,0.68 102 | 102,4.64,4.44,1,0.644 103 | 103,2.76,1.97,1,0.815 104 | 104,4.68,4.61,1,0.83 105 | 105,1.85,3.87,1,0.692 106 | 106,2.37,1.39,1,0.869 107 | 107,0.94,3.29,1,0.771 108 | 108,3.99,4.98,1,0.715 109 | 109,3.85,4.44,1,0.893 110 | 110,4.3,0.03,1,0.78 111 | 111,4.16,2.04,1,0.93 112 | 112,3.45,0.26,1,0.88 113 | 113,1.81,1.11,1,0.96 114 | 114,3.6,4.66,1,0.733 115 | 115,2.61,0.81,1,0.961 116 | 116,0.14,3.22,1,0.666 117 | 117,4.01,4.43,1,0.884 118 | 118,2.11,0.03,1,0.903 119 | 119,0.71,0.4,1,0.801 120 | 120,0.53,0.16,1,0.639 121 | 121,0.78,3.43,1,0.638 122 | 122,0.96,4.02,1,0.911 123 | 123,3.44,4.96,1,0.798 124 | 124,0.7,4.24,1,0.828 125 | 125,4.92,0.28,1,0.39 126 | 126,4.11,0.82,1,0.706 127 | 127,4.69,4.26,1,0.676 128 | 128,4.38,3.57,1,0.738 129 | 129,2.59,3.34,1,0.84 130 | 130,3.2,1.98,1,0.735 131 | 131,0.55,3.34,1,0.768 132 | 132,4.32,1.77,1,0.816 133 | 133,4.53,3.53,1,0.857 134 | 134,3.45,0.79,1,0.833 135 | 135,4.55,0.74,1,0.918 136 | 136,3.25,2.86,1,0.808 137 | 137,3.35,4.37,1,0.769 138 | 138,3.19,2.21,1,0.915 139 | 139,0.56,4.86,1,0.72 140 | 140,4.6,2.58,1,0.685 141 | 141,0.7,4.14,1,0.825 142 | 142,3.08,1.2,1,0.836 143 | 143,4.17,1.05,1,0.706 144 | 144,1.84,2.73,1,0.284 145 | 145,1.44,3.46,1,0.927 146 | 146,0.02,2.58,1,0.297 147 | 147,0.64,0.91,1,0.873 148 | 148,3.02,0.77,1,0.915 149 | 149,3.86,1.49,1,0.716 150 | 150,2.07,0.09,1,0.794 151 | 151,3.78,3.57,1,0.91 152 | 152,3.3,2.2,1,0.929 153 | 153,0.72,3.61,1,0.978 154 | 154,0.39,0.83,1,0.778 155 | 155,3.07,4.11,1,0.788 156 | 156,0.92,3.49,1,0.858 157 | 157,1.07,3.08,1,0.658 158 | 158,3.69,1.06,1,0.811 159 | 159,4.08,3.73,1,0.939 160 | 160,2.47,1.63,1,0.971 161 | 161,2.65,3.39,1,0.761 162 | 162,0.49,0.97,1,0.891 163 | 163,2.83,0.5,1,0.846 164 | 164,1.01,4.84,1,0.834 165 | 165,1.16,4.73,1,0.732 166 | 166,1.55,3.2,1,0.671 167 | 167,2.7,2.88,1,0.917 168 | 168,0.81,3.46,1,0.816 169 | 169,1.01,3.43,1,0.958 170 | 170,4.95,2.63,1,0.947 171 | 171,0.23,3.85,1,0.872 172 | 172,3.95,2.15,1,0.731 173 | 173,2.35,1.75,1,0.635 174 | 174,0.22,2.65,1,0.781 175 | 175,4.36,0.95,1,0.752 176 | 176,1.9,4.79,1,0.894 177 | 177,3.39,4.82,1,0.907 178 | 178,2.52,4.94,1,0.988 179 | 179,3.93,0.75,1,0.937 180 | 180,0.06,3.72,1,0.67 181 | 181,3.95,4.39,1,0.228 182 | 182,4.04,1.89,1,0.809 183 | 183,0.03,3.69,1,0.951 184 | 184,4.82,4.35,1,0.867 185 | 185,4.51,2.25,1,0.906 186 | 186,2.08,0.34,1,0.664 187 | 187,0.55,4.08,1,0.828 188 | 188,3.3,2.34,1,0.979 189 | 189,4.79,1.92,1,0.633 190 | 190,2.66,1.96,1,0.657 191 | 191,4.31,3.51,1,0.932 192 | 192,0.1,1.52,1,0.923 193 | 193,1.29,1.43,1,0.808 194 | 194,0.46,1.95,1,0.88 195 | 195,2.8,0.95,1,0.979 196 | 196,1.56,4.59,1,0.653 197 | 197,4.56,1.97,1,0.786 198 | 198,1.07,3.59,1,0.645 199 | 199,0.79,2.23,1,0.986 200 | 200,0.48,1.39,1,0.865 201 | 201,4.65,2.43,1,0.933 202 | 202,4.58,4.89,1,0.668 203 | 203,2.01,4.24,1,0.841 204 | 204,4.2,3.1,1,0.762 205 | 205,3.18,0.78,1,0.633 206 | 206,4.76,4.46,1,0.972 207 | 207,1.21,2.83,1,0.652 208 | 208,2.45,4.48,1,0.832 209 | 209,3.65,0.55,1,0.318 210 | 210,1.39,2.57,1,0.71 211 | 211,4.52,2.94,1,0.806 212 | 212,3.77,3.03,1,0.903 213 | 213,0.53,2.4,1,0.916 214 | 214,3.03,1.06,1,0.983 215 | 215,0.83,3.95,1,0.887 216 | 216,2.24,2.15,1,0.917 217 | 
217,3.46,3.96,1,0.973 218 | 218,3.76,4.13,1,0.946 219 | 219,4.64,3.31,1,0.762 220 | 220,1.6,2.45,1,0.837 221 | 221,1.6,0.87,1,0.219 222 | 222,2.81,0.88,1,0.823 223 | 223,4.95,3.38,1,0.93 224 | 224,0.91,3.1,1,0.685 225 | 225,0.08,1.17,1,0.782 226 | 226,4.9,4.36,1,0.762 227 | 227,1.28,0.99,1,0.883 228 | 228,4.53,1.92,1,0.824 229 | 229,2.44,4.15,1,0.787 230 | 230,1.14,1.2,1,0.844 231 | 231,4.23,0.79,1,0.678 232 | 232,0.02,4.2,1,0.321 233 | 233,1.29,4.25,1,0.916 234 | 234,0.56,4.06,1,0.893 235 | 235,0.83,4.08,1,0.648 236 | 236,2.96,0.98,1,0.731 237 | 237,3.73,0.43,1,0.77 238 | 238,4.84,0.77,1,0.694 239 | 239,1.82,1.15,1,0.392 240 | 240,2.61,0.12,1,0.645 241 | 241,1.07,0.59,1,0.983 242 | 242,3.04,4.4,1,0.978 243 | 243,0.81,0.48,1,0.784 244 | 244,2.42,0.33,1,0.657 245 | 245,0.93,2.21,1,0.89 246 | 246,0.64,0.48,1,0.931 247 | 247,4.89,4.1,1,0.821 248 | 248,0.21,2.17,1,0.744 249 | 249,2.64,4.36,1,0.803 250 | 250,0.72,1.69,1,0.835 251 | 251,1.45,4.15,1,0.404 252 | 252,2.32,4.5,1,0.65 253 | 253,0.39,3.32,1,0.928 254 | 254,0.98,1.67,1,0.931 255 | 255,0.99,3.37,1,0.89 256 | 256,3.76,2.18,1,0.872 257 | 257,1.19,1.0,1,0.908 258 | 258,3.59,4.56,1,0.831 259 | 259,0.61,3.1,1,0.637 260 | 260,1.02,2.64,1,0.974 261 | 261,0.35,3.97,1,0.953 262 | 262,0.23,4.49,1,0.858 263 | 263,4.62,2.16,1,0.844 264 | 264,5.0,2.38,1,0.813 265 | 265,0.07,2.84,1,0.635 266 | 266,2.22,4.42,1,0.883 267 | 267,3.94,0.01,1,0.92 268 | 268,1.02,3.38,1,0.858 269 | 269,0.88,0.33,1,0.963 270 | 270,1.11,0.82,1,0.806 271 | 271,3.06,2.5,1,0.827 272 | 272,0.51,4.13,1,0.756 273 | 273,1.94,2.41,1,0.632 274 | 274,0.4,2.25,1,0.778 275 | 275,3.08,3.22,1,0.721 276 | 276,3.36,2.71,1,0.87 277 | 277,4.54,2.61,1,0.723 278 | 278,2.71,2.44,1,0.952 279 | 279,0.32,0.78,1,0.702 280 | 280,2.59,2.19,1,0.933 281 | 281,2.76,4.42,1,0.726 282 | 282,1.33,4.01,1,0.745 283 | 283,1.35,3.66,1,0.911 284 | 284,3.97,2.05,1,0.839 285 | 285,0.95,0.42,1,0.75 286 | 286,4.69,2.73,1,0.669 287 | 287,2.15,0.51,1,0.334 288 | 288,4.56,3.53,1,0.989 289 | 289,0.57,2.84,1,0.852 290 | 290,1.34,4.68,1,0.809 291 | 291,1.86,4.38,1,0.637 292 | 292,3.12,4.38,1,0.654 293 | 293,2.21,3.09,1,0.79 294 | 294,1.27,2.01,1,0.882 295 | 295,3.24,0.98,1,0.907 296 | 296,1.19,3.28,1,0.659 297 | 297,1.15,1.61,1,0.695 298 | 298,0.16,1.8,1,0.758 299 | 299,3.5,4.56,1,0.724 300 | 300,3.97,2.79,1,0.639 301 | 301,4.13,2.09,1,0.84 302 | 302,4.74,3.71,1,0.705 303 | 303,0.58,4.19,1,0.84 304 | 304,3.34,0.76,1,0.861 305 | 305,4.95,1.01,1,0.643 306 | 306,3.04,4.31,1,0.679 307 | 307,4.89,3.43,1,0.868 308 | 308,4.37,1.31,1,0.989 309 | 309,1.76,3.61,1,0.727 310 | 310,1.25,4.81,1,0.882 311 | 311,0.13,2.76,1,0.842 312 | 312,1.55,2.49,1,0.673 313 | 313,0.24,3.0,1,0.853 314 | 314,3.07,0.85,1,0.918 315 | 315,4.04,1.54,1,0.814 316 | 316,0.06,3.68,1,0.963 317 | 317,1.93,0.35,1,0.96 318 | 318,2.38,1.76,1,0.662 319 | 319,1.11,3.09,1,0.661 320 | 320,2.17,4.74,1,0.714 321 | 321,2.05,3.88,1,0.845 322 | 322,1.33,2.54,1,0.961 323 | 323,3.58,2.6,1,0.703 324 | 324,2.6,1.27,1,0.796 325 | 325,0.5,3.44,1,0.638 326 | 326,1.08,4.43,1,0.635 327 | 327,4.09,2.15,1,0.755 328 | 328,3.85,3.64,1,0.747 329 | 329,2.11,3.89,1,0.634 330 | 330,3.88,0.24,1,0.788 331 | 331,1.47,3.26,1,0.778 332 | 332,2.68,2.23,1,0.831 333 | 333,3.01,4.34,1,0.92 334 | 334,0.65,4.99,1,0.765 335 | 335,0.43,2.64,1,0.741 336 | 336,0.51,3.04,1,0.912 337 | 337,2.96,0.93,1,0.848 338 | 338,1.09,2.72,1,0.841 339 | 339,4.17,2.36,1,0.867 340 | 340,2.92,2.24,1,0.945 341 | 341,0.74,3.42,1,0.918 342 | 342,1.75,1.05,1,0.832 343 | 343,4.2,1.38,1,0.807 344 | 344,1.78,1.33,1,0.98 345 | 
345,2.99,4.55,1,0.723 346 | 346,1.94,0.39,1,0.73 347 | 347,3.66,3.55,1,0.673 348 | 348,1.13,3.74,1,0.875 349 | 349,4.37,3.41,1,0.81 350 | 350,2.17,1.52,1,0.93 351 | 351,3.76,3.99,1,0.773 352 | 352,2.85,2.48,1,0.722 353 | 353,3.4,2.59,1,0.652 354 | 354,2.3,4.82,1,0.74 355 | 355,3.18,1.49,1,0.656 356 | 356,4.68,2.36,1,0.724 357 | 357,3.12,3.55,1,0.967 358 | 358,4.03,1.33,1,0.815 359 | 359,2.21,4.2,1,0.756 360 | 360,4.45,3.11,1,0.821 361 | 361,0.09,2.69,1,0.843 362 | 362,3.24,1.1,1,0.774 363 | 363,1.76,2.4,1,0.896 364 | 364,4.18,3.3,1,0.877 365 | 365,2.81,2.65,1,0.695 366 | 366,3.17,2.79,1,0.972 367 | 367,3.4,0.29,1,0.986 368 | 368,0.52,4.51,1,0.631 369 | 369,2.54,0.2,1,0.95 370 | 370,2.75,3.68,1,0.924 371 | 371,0.19,0.76,1,0.81 372 | 372,3.99,3.2,1,0.788 373 | 373,4.02,4.81,1,0.778 374 | 374,0.44,1.57,1,0.335 375 | 375,0.86,1.4,1,0.967 376 | 376,1.46,4.67,1,0.834 377 | 377,2.74,0.82,1,0.902 378 | 378,2.18,4.36,1,0.851 379 | 379,4.89,0.34,1,0.728 380 | 380,4.5,3.27,1,0.862 381 | 381,1.74,3.87,1,0.78 382 | 382,4.58,0.19,1,0.884 383 | 383,3.49,0.31,1,0.714 384 | 384,4.76,3.79,1,0.965 385 | 385,3.48,3.75,1,0.942 386 | 386,2.4,4.65,1,0.95 387 | 387,1.63,1.21,1,0.642 388 | 388,1.32,2.25,1,0.959 389 | 389,1.49,1.88,1,0.856 390 | 390,0.78,3.36,1,0.945 391 | 391,2.49,0.03,1,0.693 392 | 392,1.49,4.11,1,0.981 393 | 393,0.26,3.6,1,0.964 394 | 394,3.42,2.65,1,0.852 395 | 395,2.18,0.02,1,0.874 396 | 396,0.5,2.14,1,0.654 397 | 397,4.3,3.76,1,0.815 398 | 398,1.21,1.84,1,0.711 399 | 399,3.32,4.61,1,0.637 400 | 400,2.34,2.81,1,0.8 401 | 401,1.92,3.23,1,0.885 402 | 402,3.31,1.86,1,0.789 403 | 403,0.03,4.11,1,0.929 404 | 404,4.23,2.72,1,0.711 405 | 405,1.05,2.19,1,0.742 406 | 406,3.12,3.04,1,0.987 407 | 407,4.06,2.09,1,0.762 408 | 408,3.98,0.56,1,0.956 409 | 409,0.5,3.15,1,0.769 410 | 410,1.12,3.22,1,0.711 411 | 411,3.51,4.82,1,0.728 412 | 412,1.63,3.95,1,0.956 413 | 413,1.35,1.04,1,0.946 414 | 414,4.51,1.16,1,0.387 415 | 415,3.64,4.39,1,0.747 416 | 416,2.5,4.27,1,0.883 417 | 417,2.31,4.59,1,0.841 418 | 418,2.11,1.79,1,0.695 419 | 419,2.55,2.68,1,0.98 420 | 420,2.14,4.69,1,0.925 421 | 421,0.14,4.77,1,0.969 422 | 422,2.59,2.72,1,0.852 423 | 423,2.74,4.7,1,0.984 424 | 424,4.2,0.16,1,0.966 425 | 425,4.83,0.31,1,0.754 426 | 426,0.83,1.68,1,0.904 427 | 427,1.84,2.26,1,0.873 428 | 428,4.52,3.15,1,0.764 429 | 429,3.41,3.99,1,0.851 430 | 430,3.64,3.69,1,0.652 431 | 431,3.59,4.16,1,0.906 432 | 432,1.17,0.54,1,0.643 433 | 433,0.54,0.99,1,0.864 434 | 434,4.24,0.36,1,0.933 435 | 435,2.56,2.71,1,0.836 436 | 436,4.84,4.96,1,0.703 437 | 437,2.28,4.76,1,0.901 438 | 438,2.11,1.19,1,0.934 439 | 439,0.64,2.75,1,0.856 440 | 440,4.26,4.16,1,0.87 441 | 441,3.45,3.92,1,0.645 442 | 442,3.7,3.22,1,0.936 443 | 443,4.54,4.31,1,0.982 444 | 444,4.58,3.52,1,0.739 445 | 445,1.58,1.18,1,0.949 446 | 446,4.65,2.24,1,0.838 447 | 447,2.75,3.97,1,0.9 448 | 448,0.78,3.0,1,0.889 449 | 449,4.7,4.25,1,0.756 450 | 450,4.36,1.93,1,0.876 451 | 451,1.4,3.68,1,0.663 452 | 452,3.2,4.84,1,0.835 453 | 453,1.18,2.08,1,0.739 454 | 454,0.74,3.55,1,0.813 455 | 455,2.56,2.15,1,0.37 456 | 456,3.59,3.41,1,0.883 457 | 457,2.98,4.94,1,0.73 458 | 458,0.12,4.85,1,0.824 459 | 459,3.52,3.83,1,0.967 460 | 460,2.72,1.31,1,0.874 461 | 461,4.33,4.18,1,0.237 462 | 462,3.04,1.15,1,0.796 463 | 463,3.8,4.79,1,0.8 464 | 464,2.71,1.98,1,0.633 465 | 465,3.39,1.49,1,0.909 466 | 466,3.32,0.26,1,0.691 467 | 467,3.8,0.74,1,0.839 468 | 468,1.5,1.33,1,0.654 469 | 469,2.04,4.56,1,0.969 470 | 470,3.15,0.14,1,0.688 471 | 471,1.07,2.0,1,0.802 472 | 472,3.25,1.95,1,0.984 473 | 
473,3.95,4.34,1,0.892 474 | 474,1.4,3.76,1,0.917 475 | 475,2.08,1.74,1,0.648 476 | 476,0.85,4.82,1,0.7 477 | 477,0.4,1.58,1,0.67 478 | 478,3.57,0.53,1,0.777 479 | 479,4.03,1.46,1,0.321 480 | 480,4.09,3.06,1,0.976 481 | 481,4.93,4.05,1,0.813 482 | 482,1.18,2.31,1,0.827 483 | 483,4.43,2.05,1,0.731 484 | 484,2.87,1.01,1,0.694 485 | 485,2.49,0.4,1,0.856 486 | 486,2.84,1.24,1,0.69 487 | 487,3.94,1.52,1,0.647 488 | 488,4.28,2.37,1,0.796 489 | 489,1.24,2.65,1,0.981 490 | 490,1.35,1.05,1,0.742 491 | 491,3.33,2.46,1,0.983 492 | 492,2.2,2.25,1,0.711 493 | 493,3.71,2.04,1,0.81 494 | 494,2.48,3.01,1,0.74 495 | 495,4.45,0.01,1,0.739 496 | 496,3.0,1.35,1,0.822 497 | 497,1.92,3.89,1,0.823 498 | 498,1.46,2.31,1,0.658 499 | 499,4.46,2.56,1,0.943 500 | 500,0.38,3.51,1,0.736 501 | 501,1.77,2.08,1,0.638 502 | 502,4.11,0.58,1,0.796 503 | 503,0.91,3.24,1,0.984 504 | 504,3.29,0.27,1,0.881 505 | 505,2.93,4.44,1,0.68 506 | 506,4.99,0.98,1,0.834 507 | 507,1.19,1.49,1,0.657 508 | 508,0.65,4.84,1,0.922 509 | 509,4.37,3.66,1,0.248 510 | 510,2.28,3.29,1,0.696 511 | 511,3.07,4.13,1,0.734 512 | 512,2.8,1.64,1,0.893 513 | 513,3.55,4.69,1,0.835 514 | 514,4.75,2.87,1,0.952 515 | 515,3.88,1.21,1,0.858 516 | 516,2.46,1.66,1,0.856 517 | 517,2.96,2.22,1,0.808 518 | 518,2.8,2.86,1,0.849 519 | 519,0.62,4.39,1,0.679 520 | 520,1.98,0.18,1,0.85 521 | 521,2.13,2.36,1,0.831 522 | 522,0.78,0.97,1,0.855 523 | 523,1.63,3.61,1,0.689 524 | 524,0.56,3.98,1,0.959 525 | 525,3.92,1.55,1,0.83 526 | 526,2.4,1.1,1,0.746 527 | 527,0.12,4.63,1,0.89 528 | 528,0.6,2.4,1,0.909 529 | 529,3.99,1.64,1,0.95 530 | 530,1.58,1.84,1,0.713 531 | 531,1.1,0.9,1,0.905 532 | 532,0.46,2.42,1,0.769 533 | -------------------------------------------------------------------------------- /problem_state/adwords_env.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import NamedTuple 3 | from torch_geometric.utils import to_dense_adj 4 | import torch.nn.functional as F 5 | 6 | # from utils.boolmask import mask_long2bool, mask_long_scatter 7 | 8 | 9 | class StateAdwordsBipartite(NamedTuple): 10 | # Fixed input 11 | graphs: torch.Tensor # graphs objects in a batch 12 | u_size: int 13 | v_size: int 14 | batch_size: torch.Tensor 15 | hist_sum: torch.Tensor 16 | hist_sum_sq: torch.Tensor 17 | hist_deg: torch.Tensor 18 | min_sol: torch.Tensor 19 | max_sol: torch.Tensor 20 | sum_sol_sq: torch.Tensor 21 | num_skip: torch.Tensor 22 | orig_budget: torch.Tensor 23 | curr_budget: torch.Tensor 24 | 25 | adj: torch.Tensor 26 | # State 27 | size: torch.Tensor # size of current matching 28 | i: int # Keeps track of step 29 | opts: dict 30 | idx: torch.Tensor 31 | 32 | @staticmethod 33 | def initialize( 34 | input, 35 | u_size, 36 | v_size, 37 | opts, 38 | ): 39 | graph_size = u_size + v_size + 1 40 | batch_size = int(input.batch.size(0) / graph_size) 41 | # print(batch_size, input.batch.size(0), graph_size) 42 | adj = to_dense_adj( 43 | input.edge_index, 44 | input.batch, 45 | input.weight.unsqueeze(1), 46 | ).squeeze(-1) 47 | adj = adj[:, u_size + 1 :, : u_size + 1] 48 | budgets = torch.cat( 49 | (torch.zeros(batch_size, 1).to(opts.device), input.x.reshape(batch_size, -1)), dim=1 50 | ) 51 | # print(adj) 52 | # print(budgets) 53 | # permute the nodes for data 54 | idx = torch.arange(adj.shape[1], device=opts.device) 55 | # if "supervised" not in opts.model and not opts.eval_only: 56 | # idx = torch.randperm(adj.shape[1], device=opts.device) 57 | # adj = adj[:, idx, :].view(adj.size()) 58 | 59 | return 
StateAdwordsBipartite( 60 | graphs=input, 61 | adj=adj.float(), 62 | u_size=u_size, 63 | v_size=v_size, 64 | batch_size=batch_size, 65 | # Keep visited with depot so we can scatter efficiently (if there is an action for depot) 66 | orig_budget=budgets, 67 | curr_budget=budgets.clone(), 68 | hist_sum=( 69 | torch.zeros( 70 | batch_size, 71 | 1, 72 | u_size + 1, 73 | device=input.batch.device, 74 | ) 75 | ), 76 | hist_deg=( 77 | torch.zeros( 78 | batch_size, 79 | 1, 80 | u_size + 1, 81 | device=input.batch.device, 82 | ) 83 | ), 84 | hist_sum_sq=( # Visited as mask is easier to understand, as long more memory efficient 85 | torch.zeros( 86 | batch_size, 87 | 1, 88 | u_size + 1, 89 | device=input.batch.device, 90 | ) 91 | ), 92 | min_sol=torch.zeros(batch_size, 1, device=input.batch.device), 93 | max_sol=torch.zeros(batch_size, 1, device=input.batch.device), 94 | sum_sol_sq=torch.zeros(batch_size, 1, device=input.batch.device), 95 | num_skip=torch.zeros(batch_size, 1, device=input.batch.device), 96 | size=torch.zeros(batch_size, 1, device=input.batch.device), 97 | i=u_size + 1, 98 | opts=opts, 99 | idx=idx, 100 | ) 101 | 102 | def get_final_cost(self): 103 | 104 | assert self.all_finished() 105 | # assert self.visited_. 106 | 107 | return self.size 108 | 109 | def get_current_weights(self, mask): 110 | return self.adj[:, 0, :].float() 111 | 112 | def get_graph_weights(self): 113 | return self.graphs.weight 114 | 115 | def update(self, selected): 116 | # Update the state 117 | w = self.adj[:, 0, :].clone() 118 | selected_weights = w.gather(1, selected).to(self.adj.device) 119 | one_hot_w = ( 120 | F.one_hot(selected, num_classes=self.u_size + 1) 121 | .to(self.adj.device) 122 | .squeeze(1) 123 | .float() 124 | * selected_weights 125 | ) 126 | curr_budget = self.curr_budget - one_hot_w 127 | curr_budget[curr_budget < 0.0] = 0.0 128 | skip = (selected == 0).float() 129 | num_skip = self.num_skip + skip 130 | if self.i == self.u_size + 1: 131 | min_sol = selected_weights 132 | else: 133 | m = self.min_sol.clone() 134 | m[m == 0.0] = 2.0 135 | selected_weights[skip.bool()] = 2.0 136 | min_sol = torch.minimum(m, selected_weights) 137 | selected_weights[selected_weights == 2.0] = 0.0 138 | min_sol[min_sol == 2.0] = 0.0 139 | 140 | max_sol = torch.maximum(self.max_sol, selected_weights) 141 | total_weights = self.size + selected_weights 142 | sum_sol_sq = self.sum_sol_sq + selected_weights ** 2 143 | 144 | hist_sum = self.hist_sum + w.unsqueeze(1) 145 | hist_sum_sq = self.hist_sum_sq + w.unsqueeze(1) ** 2 146 | hist_deg = self.hist_deg + (w.unsqueeze(1) != 0).float() 147 | hist_deg[:, :, 0] = float(self.i - self.u_size) 148 | return self._replace( 149 | curr_budget=curr_budget, 150 | size=total_weights, 151 | i=self.i + 1, 152 | adj=self.adj[:, 1:, :], 153 | hist_sum=hist_sum, 154 | hist_sum_sq=hist_sum_sq, 155 | hist_deg=hist_deg, 156 | num_skip=num_skip, 157 | max_sol=max_sol, 158 | min_sol=min_sol, 159 | sum_sol_sq=sum_sol_sq, 160 | ) 161 | 162 | def all_finished(self): 163 | # Exactly v_size steps 164 | return (self.i - (self.u_size + 1)) >= self.v_size 165 | 166 | def get_current_node(self): 167 | return 0 168 | 169 | def get_curr_state(self, model): 170 | mask = self.get_mask() 171 | opts = self.opts 172 | w = self.adj[:, 0, :].float().clone() 173 | s = None 174 | if model == "ff": 175 | s = torch.cat((w, self.curr_budget, mask.float()), dim=1).float() 176 | elif model == "inv-ff": 177 | deg = (w != 0).float().sum(1) 178 | deg[deg == 0.0] = 1.0 179 | mean_w = w.sum(1) / deg 180 | mean_budget = 
self.curr_budget.sum(1) / self.u_size 181 | mean_budget = mean_budget[:, None, None].repeat(1, self.u_size + 1, 1) 182 | mean_w = mean_w[:, None, None].repeat(1, self.u_size + 1, 1) 183 | fixed_node_identity = torch.zeros( 184 | self.batch_size, self.u_size + 1, 1, device=opts.device 185 | ).float() 186 | fixed_node_identity[:, 0, :] = 1.0 187 | s = w.reshape(self.batch_size, self.u_size + 1, 1) 188 | s = torch.cat( 189 | ( 190 | fixed_node_identity, 191 | self.curr_budget.reshape(self.batch_size, self.u_size + 1, 1), 192 | s, 193 | mean_w, 194 | mean_budget, 195 | ), 196 | dim=2, 197 | ).float() 198 | 199 | elif model == "ff-hist" or model == "ff-supervised": 200 | ( 201 | h_mean, 202 | h_var, 203 | h_mean_degree, 204 | ind, 205 | matched_ratio, 206 | var_sol, 207 | mean_sol, 208 | n_skip, 209 | ) = self.get_hist_features() 210 | s = torch.cat( 211 | ( 212 | w, 213 | mask.float(), 214 | self.orig_budget, 215 | self.curr_budget, 216 | h_mean.squeeze(1), 217 | h_var.squeeze(1), 218 | h_mean_degree.squeeze(1), 219 | self.size / self.orig_budget.sum(-1).unsqueeze(1), 220 | ind.float(), 221 | mean_sol, 222 | var_sol, 223 | n_skip, 224 | self.max_sol, 225 | self.min_sol, 226 | matched_ratio, 227 | ), 228 | dim=1, 229 | ).float() 230 | 231 | elif model == "inv-ff-hist" or model == "gnn-simp-hist": 232 | deg = (w != 0).float().sum(1) 233 | deg[deg == 0.0] = 1.0 234 | mean_w = w.sum(1) / deg 235 | mean_budget = self.curr_budget.sum(1) / self.u_size 236 | mean_w = mean_w[:, None, None].repeat(1, self.u_size + 1, 1) 237 | mean_budget = mean_budget[:, None, None].repeat(1, self.u_size + 1, 1) 238 | s = w.reshape(self.batch_size, self.u_size + 1, 1) 239 | ( 240 | h_mean, 241 | h_var, 242 | h_mean_degree, 243 | ind, 244 | matched_ratio, 245 | var_sol, 246 | mean_sol, 247 | n_skip, 248 | ) = self.get_hist_features() 249 | available_ratio = (deg.unsqueeze(1)) / (self.u_size) 250 | fixed_node_identity = torch.zeros( 251 | self.batch_size, self.u_size + 1, 1, device=opts.device 252 | ).float() 253 | fixed_node_identity[:, 0, :] = 1.0 254 | s = torch.cat( 255 | ( 256 | s, 257 | mask.reshape(-1, self.u_size + 1, 1).float(), 258 | self.orig_budget.reshape(self.batch_size, self.u_size + 1, 1), 259 | self.curr_budget.reshape(self.batch_size, self.u_size + 1, 1), 260 | mean_w, 261 | mean_budget, 262 | h_mean.transpose(1, 2), 263 | h_var.transpose(1, 2), 264 | h_mean_degree.transpose(1, 2), 265 | ind.unsqueeze(2).repeat(1, self.u_size + 1, 1), 266 | self.size.unsqueeze(2).repeat(1, self.u_size + 1, 1) 267 | / self.orig_budget.sum(-1)[:, None, None], 268 | mean_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1), 269 | var_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1), 270 | n_skip.unsqueeze(2).repeat(1, self.u_size + 1, 1), 271 | self.max_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1), 272 | self.min_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1), 273 | matched_ratio.unsqueeze(2).repeat(1, self.u_size + 1, 1), 274 | available_ratio.unsqueeze(2).repeat(1, self.u_size + 1, 1), 275 | fixed_node_identity, 276 | ), 277 | dim=2, 278 | ).float() 279 | 280 | return s, mask 281 | 282 | def get_node_features(self): 283 | step_size = self.i + 1 284 | batch_size = self.batch_size 285 | incoming_node_features = ( 286 | torch.cat( 287 | ( 288 | torch.ones(step_size - self.u_size - 1, device=self.adj.device) 289 | * -2.0, 290 | ) 291 | ) 292 | .unsqueeze(0) 293 | .expand(batch_size, step_size - self.u_size - 1) 294 | ).float() # Collecting node features up until the ith incoming node 295 | 296 | future_node_feature = 
torch.ones(batch_size, 1, device=self.adj.device) * -1.0 297 | fixed_node_feature = self.curr_budget[:, 1:] 298 | node_features = torch.cat( 299 | (future_node_feature, fixed_node_feature, incoming_node_features), dim=1 300 | ).reshape(batch_size, step_size) 301 | 302 | return node_features.float() 303 | 304 | def get_hist_features(self): 305 | i = self.i - (self.u_size + 1) 306 | if i != 0: 307 | deg = self.hist_deg.clone() 308 | deg[deg == 0] = 1.0 309 | h_mean = self.hist_sum / deg 310 | h_var = (self.hist_sum_sq - ((self.hist_sum ** 2) / deg)) / deg # running per-node variance of observed bids: E[w^2] - E[w]^2 311 | h_mean_degree = self.hist_deg / i 312 | ind = ( 313 | torch.ones(self.batch_size, 1, device=self.opts.device) 314 | * i 315 | / self.v_size 316 | ) 317 | curr_sol_size = i - self.num_skip 318 | var_sol = ( 319 | self.sum_sol_sq - ((self.size ** 2) / curr_sol_size) 320 | ) / curr_sol_size 321 | mean_sol = self.size / curr_sol_size 322 | var_sol[curr_sol_size == 0.0] = 0.0 323 | mean_sol[curr_sol_size == 0.0] = 0.0 324 | avg_budget = self.curr_budget.sum(1).unsqueeze(1) / self.u_size 325 | n_skip = self.num_skip / i 326 | else: 327 | ( 328 | h_mean, 329 | h_var, 330 | h_mean_degree, 331 | ind, 332 | avg_budget, 333 | var_sol, 334 | mean_sol, 335 | n_skip, 336 | ) = ( 337 | self.hist_sum * 0.0, 338 | self.hist_sum * 0.0, 339 | self.hist_sum * 0.0, 340 | self.size * 0.0, 341 | self.curr_budget.sum(1).unsqueeze(1) / self.u_size, 342 | self.size * 0.0, 343 | self.size * 0.0, 344 | self.size * 0.0, 345 | ) 346 | 347 | return ( 348 | h_mean, 349 | h_var, 350 | h_mean_degree, 351 | ind, 352 | avg_budget, 353 | var_sol, 354 | mean_sol, 355 | n_skip, 356 | ) 357 | 358 | def get_mask(self): 359 | """ 360 | Returns a mask vector which includes only nodes in U that can be matched. 361 | That is, neighbors of the incoming node whose remaining budget can still cover the incoming bid. 362 | """ 363 | 364 | mask = (self.adj[:, 0, :] == 0.0).float() 365 | mask[:, 0] = 0.0 366 | budget_mask = ((self.adj[:, 0, :] - self.curr_budget) > 1e-5).float() # mask fixed nodes whose remaining budget cannot cover the incoming bid 367 | return ( 368 | budget_mask + mask > 0.0 369 | ).long() # Hacky way to return bool or uint8 depending on pytorch version 370 | --------------------------------------------------------------------------------