├── data
│   ├── __init__.py
│   ├── gMission
│   │   ├── reduced_workers.txt
│   │   ├── reduced_tasks.txt
│   │   └── workers.txt
│   ├── MovieLense
│   │   ├── movies.txt
│   │   └── users.txt
│   └── data_utils.py
├── figures
│   ├── table1.png
│   └── histogram.png
├── .gitignore
├── requirements.txt
├── LICENSE
├── INSTALL.md
├── utils
│   ├── visualize.py
│   ├── log_utils.py
│   ├── functions.py
│   └── reinforce_baselines.py
├── policy
│   ├── balance.py
│   ├── msvv.py
│   ├── greedy_theshold.py
│   ├── greedy_sc.py
│   ├── greedy.py
│   ├── simple_greedy.py
│   ├── greedy_rt.py
│   ├── inv_ff_history.py
│   ├── greedy_matching.py
│   ├── ff_model_invariant.py
│   ├── ff_model_hist.py
│   ├── ff_model.py
│   ├── ff_supervised.py
│   ├── gnn_simp_hist.py
│   ├── gnn.py
│   ├── inv_ff_history_switch.py
│   └── gnn_hist.py
├── problem_state
│   ├── osbm_dataset.py
│   ├── adwords_dataset.py
│   ├── edge_obm_dataset.py
│   ├── obm_dataset.py
│   ├── obm_env.py
│   ├── edge_obm_env.py
│   └── adwords_env.py
├── encoder
│   └── graph_encoder.py
├── README.md
├── TUTORIAL.md
├── IPsolvers
│   └── IPsolver.py
└── run_lite.py
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figures/table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ren-Research/LOMAR/HEAD/figures/table1.png
--------------------------------------------------------------------------------
/figures/histogram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ren-Research/LOMAR/HEAD/figures/histogram.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | dataset/**
3 | dataset
4 | data/__pycache__/**
5 | policy/__pycache__/**
6 |
7 | # Output folders
8 | logs_dataset/**
9 | saved_models/**
10 | saved_models
11 |
12 |
13 | **.pyc
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=3.2.1
2 | networkx>=2.4
3 | numpy>=1.19.0
4 | pylint>=2.5.3
5 | scipy>=1.4.0
6 | tensorboard>=2.2.2
7 | tensorboard-plugin-wit>=1.7.0
8 | torch>=1.11.0
9 | torchvision>=0.8.1
10 | tqdm>=4.47.0
11 | seaborn
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ren-Research
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation Guide
2 |
3 | ## Install Online Bipartite Matching code base
4 | ```bash
5 |
6 | conda create --name renlab_obm python==3.7.0
7 | conda activate renlab_obm
8 | # Install pytorch
9 | pip install torch==1.11.0 torchvision
10 | # Install Pytorch Geometric
11 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cu102.html
12 | # CPU-only version
13 | # pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.11.0+cpu.html
14 |
15 | # Install gurobi
16 | pip install gurobipy
17 | # Download the separate license verification tool.
18 | wget https://packages.gurobi.com/lictools/licensetools9.5.1_linux64.tar.gz
19 | # Please obtain a Gurobi key before proceeding, and replace `your_gurobi_key` below with your key
20 | grbgetkey your_gurobi_key
21 |
22 | # Install other requirements
23 | pip install -r requirements.txt
24 |
25 | pip install wandb
26 |
27 | pip install -v -e .
28 | ```
29 |
30 | ## Request a GPU node in a Slurm cluster
31 |
32 | ```shell
33 | srun -p gpu --gres=gpu:1 --mem=100g --time=48:00:00 --pty bash -l
34 | ```
35 |
36 |
37 |
--------------------------------------------------------------------------------
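A quick post-install sanity check (a minimal sketch; it only assumes the packages installed by the commands in INSTALL.md, and importing `gurobipy` does not yet require an activated license):

```python
# check_install.py -- minimal sanity check for the environment set up in INSTALL.md
import torch
import torch_geometric
import gurobipy  # importing works before a license key is activated

print("torch:", torch.__version__)                    # expect >= 1.11.0
print("torch_geometric:", torch_geometric.__version__)
print("CUDA available:", torch.cuda.is_available())   # False with the CPU-only wheels
```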
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | from train import rollout
5 | import torch
6 |
7 | def visualize_reward_process(reward_array, fig_path, ylim = None):
8 | fig, ax = plt.subplots(1)
9 | x_array = np.arange(reward_array.shape[1])
10 | for k in range(reward_array.shape[0]):
11 | ax.plot(x_array, reward_array[k, :], label = "sample_{}".format(k))
12 |
13 | if ylim is not None:
14 | ax.set_ylim(ylim)
15 | plt.legend()
16 | fig.show()
17 | plt.savefig(fig_path)
18 | plt.close()
19 |
20 | def validate_histogram(models, model_names, dataset, opts):
21 | assert len(models) == len(model_names)
22 | num_models = len(models)
23 |
24 | cost_list = []
25 | cr_list = []
26 | for i in range(num_models):
27 | model = models[i]
28 | model_name = model_names[i]
29 | cost, cr, _ = rollout(model, dataset, opts)
30 | print("{} Validation overall avg_cost: {} +- {}".format(model_name,
31 | cost.mean(), torch.std(cost)))
32 |
33 | cost_list.append(cost)
34 | cr_list.append(cr)
35 |
36 | bins = 30
37 | fig, ax = plt.subplots(1)
38 | for i in range(num_models):
39 | model_name = model_names[i]
40 | cost = cost_list[i]
41 | cost = -cost.view(-1).cpu().numpy()
42 | ax.hist(cost, bins = bins, density=True, alpha = 0.4, label = model_name)
43 |
44 | plt.legend()
45 | fig.show()
46 | plt.savefig("temp_hist_cost.png")
47 | plt.close()
48 |
49 |
--------------------------------------------------------------------------------
/utils/log_utils.py:
--------------------------------------------------------------------------------
1 | def log_values(
2 | cost,
3 | epoch,
4 | batch_id,
5 | step,
6 | log_likelihood,
7 | tb_logger,
8 | opts,
9 | grad_norms=None,
10 | batch_loss=None,
11 | reinforce_loss=None,
12 | bl_loss=None,
13 | ):
14 | avg_cost = cost.mean().item()
15 | if grad_norms is not None:
16 | grad_norms, grad_norms_clipped = grad_norms
17 | print("grad_norm: {}, clipped: {}".format(grad_norms[0], grad_norms_clipped[0]))
18 | # Log values to screen
19 | print(
20 | "epoch: {}, train_batch_id: {}, avg_cost: {}".format(epoch, batch_id, avg_cost)
21 | )
22 |
23 | # Log values to tensorboard
24 | if not opts.no_tensorboard:
25 | tb_logger.add_scalar("avg_cost", avg_cost, step)
26 |
27 | if opts.model == "ff-supervised" and batch_loss is not None:
28 | tb_logger.add_scalar("batch loss", batch_loss, step)
29 | else:
30 | if reinforce_loss is not None:
31 | tb_logger.add_scalar("actor_loss", reinforce_loss.item(), step)
32 | # tb_logger.add_scalar("nll", -sum(log_likelihood)/len(log_likelihood).item(), step)
33 |
34 | if grad_norms is not None:
35 | tb_logger.add_scalar("grad_norm", grad_norms[0], step)
36 | tb_logger.add_scalar("grad_norm_clipped", grad_norms_clipped[0], step)
37 |
38 | if opts.baseline == "critic":
39 | tb_logger.add_scalar("critic_loss", bl_loss.item(), step)
40 | tb_logger.add_scalar("critic_grad_norm", grad_norms[1], step)
41 | tb_logger.add_scalar(
42 | "critic_grad_norm_clipped", grad_norms_clipped[1], step
43 | )
44 |
--------------------------------------------------------------------------------
/policy/balance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from utils.functions import random_max
5 |
6 |
7 | class Balance(nn.Module):
8 | def __init__(
9 | self,
10 | embedding_dim,
11 | hidden_dim,
12 | problem,
13 | opts,
14 | tanh_clipping=None,
15 | mask_inner=None,
16 | mask_logits=None,
17 | n_encode_layers=None,
18 | normalization=None,
19 | checkpoint_encoder=False,
20 | shrink_size=None,
21 | num_actions=None,
22 | n_heads=None,
23 | encoder=None,
24 | ):
25 | super(Balance, self).__init__()
26 | self.decode_type = None
27 | self.problem = problem
28 | self.model_name = "balance"
29 |
30 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
31 | assert opts.problem == "adwords"
32 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts)
33 | sequences = []
34 | while not (state.all_finished()):
35 | mask = state.get_mask()
36 | frac_budget = state.curr_budget / state.orig_budget
37 | frac_budget[mask.bool()] = -1e6
38 | frac_budget[:, 0] = -1e5
39 | selected = random_max(frac_budget)
40 |
41 | state = state.update(selected)
42 | sequences.append(selected.squeeze(1))
43 | if return_pi:
44 | return -state.size, None, torch.stack(sequences, 1), None
45 | return -state.size, torch.stack(sequences, 1), None
46 |
47 | def set_decode_type(self, decode_type, temp=None):
48 | self.decode_type = decode_type
49 | if temp is not None: # Do not change temperature if not provided
50 | self.temp = temp
51 |
--------------------------------------------------------------------------------
/policy/msvv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from utils.functions import random_max
5 |
6 |
7 | class MSVV(nn.Module):
8 | def __init__(
9 | self,
10 | embedding_dim,
11 | hidden_dim,
12 | problem,
13 | opts,
14 | tanh_clipping=None,
15 | mask_inner=None,
16 | mask_logits=None,
17 | n_encode_layers=None,
18 | normalization=None,
19 | checkpoint_encoder=False,
20 | shrink_size=None,
21 | num_actions=None,
22 | n_heads=None,
23 | encoder=None,
24 | ):
25 | super(MSVV, self).__init__()
26 | self.decode_type = None
27 | self.problem = problem
28 | self.model_name = "msvv"
29 |
30 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
31 | assert opts.problem == "adwords"
32 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts)
33 | sequences = []
34 | while not (state.all_finished()):
35 | mask = state.get_mask()
36 | w = state.get_current_weights(mask).clone()
37 | scaled_w = w * (1 - torch.exp(-(state.curr_budget / state.orig_budget)))
38 | scaled_w[mask.bool()] = -1e6
39 | scaled_w[:, 0] = -1e5
40 | selected = random_max(scaled_w)
41 |
42 | state = state.update(selected)
43 | sequences.append(selected.squeeze(1))
44 | if return_pi:
45 | return -state.size, None, torch.stack(sequences, 1), None
46 | return -state.size, torch.stack(sequences, 1), None
47 |
48 | def set_decode_type(self, decode_type, temp=None):
49 | self.decode_type = decode_type
50 | if temp is not None: # Do not change temperature if not provided
51 | self.temp = temp
52 |
--------------------------------------------------------------------------------
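As a standalone illustration of the bid-scaling rule in `MSVV.forward` above (the numbers below are made up): each bid is discounted by `1 - exp(-remaining_budget_fraction)`, so an advertiser that has spent most of its budget is selected only if its bid is substantially higher.

```python
import torch

bids = torch.tensor([[0.9, 0.8, 0.5]])       # current bids of three advertisers (made-up values)
remaining = torch.tensor([[0.1, 0.6, 1.0]])  # curr_budget / orig_budget per advertiser
scaled = bids * (1 - torch.exp(-remaining))  # same scaling as scaled_w in MSVV.forward
print(scaled)                                # roughly [[0.086, 0.361, 0.316]]
# The middle advertiser now has the highest score despite a lower bid,
# because it has more of its budget remaining.
```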
/problem_state/osbm_dataset.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.data import Dataset
2 | import torch
3 | import os
4 | import pickle
5 | from problem_state.osbm_env import StateOSBM
6 | from data.generate_data import generate_osbm_data_geometric
7 |
8 |
9 | class OSBM(object):
10 |
11 | NAME = "osbm"
12 |
13 | @staticmethod
14 | def make_dataset(*args, **kwargs):
15 | return OSBMDataset(*args, **kwargs)
16 |
17 | @staticmethod
18 | def make_state(*args, **kwargs):
19 | return StateOSBM.initialize(*args, **kwargs)
20 |
21 |
22 | class OSBMDataset(Dataset):
23 | def __init__(
24 | self, dataset, size, problem, seed, opts, transform=None, pre_transform=None
25 | ):
26 | super(OSBMDataset, self).__init__(None, transform, pre_transform)
27 | # self.data_set = dataset
28 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(self.data_set))
29 | self.problem = problem
30 | if dataset is not None:
31 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(dataset))
32 | self.data_set = dataset
33 | else:
34 | # If no filename is specified, generate data for the osbm problem
35 | D, optimal_size = generate_osbm_data_geometric(
36 | opts.u_size,
37 | opts.v_size,
38 | opts.weight_distribution,
39 | opts.weight_distribution_param,
40 | opts.graph_family_parameter,
41 | seed,
42 | opts.graph_family,
43 | None,
44 | size,
45 | False,
46 | )
47 | self.optimal_size = optimal_size
48 | self.data_set = D
49 |
50 | self.size = size
51 |
52 | def len(self):
53 | return self.size
54 |
55 | def get(self, idx):
56 | if type(self.data_set) == str:
57 | data = torch.load(self.data_set + "/data_{}.pt".format(idx))
58 | else:
59 | data = self.data_set[idx]
60 | return data
61 |
--------------------------------------------------------------------------------
/problem_state/adwords_dataset.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.data import Dataset
2 | import torch
3 | from problem_state.adwords_env import StateAdwordsBipartite
4 | from data.generate_data import generate_adwords_data_geometric
5 |
6 |
7 | class AdwordsBipartite(object):
8 |
9 | NAME = "adwords"
10 |
11 | @staticmethod
12 | def make_dataset(*args, **kwargs):
13 | return AdwordsBipartiteDataset(*args, **kwargs)
14 |
15 | @staticmethod
16 | def make_state(*args, **kwargs):
17 | return StateAdwordsBipartite.initialize(*args, **kwargs)
18 |
19 |
20 | class AdwordsBipartiteDataset(Dataset):
21 | def __init__(
22 | self, dataset, size, problem, seed, opts, transform=None, pre_transform=None
23 | ):
24 | super(AdwordsBipartiteDataset, self).__init__(None, transform, pre_transform)
25 | # self.data_set = dataset
26 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(self.data_set))
27 | self.problem = problem
28 | if dataset is not None:
29 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(dataset))
30 | self.data_set = dataset
31 | else:
32 | # If no filename is specified, generate data for the adwords problem
33 | D, optimal_size = generate_adwords_data_geometric(
34 | opts.u_size,
35 | opts.v_size,
36 | opts.weight_distribution,
37 | opts.weight_distribution_param,
38 | opts.graph_family_parameter,
39 | seed,
40 | opts.graph_family,
41 | None,
42 | size,
43 | False,
44 | )
45 | self.optimal_size = optimal_size
46 | self.data_set = D
47 |
48 | self.size = size
49 |
50 | def len(self):
51 | return self.size
52 |
53 | def get(self, idx):
54 | if type(self.data_set) == str:
55 | data = torch.load(self.data_set + "/data_{}.pt".format(idx))
56 | else:
57 | data = self.data_set[idx]
58 | return data
59 |
--------------------------------------------------------------------------------
/problem_state/edge_obm_dataset.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.data import Dataset
2 | import torch
3 | import os
4 | import pickle
5 | from problem_state.edge_obm_env import StateEdgeBipartite
6 | from data.generate_data import generate_edge_obm_data_geometric
7 |
8 |
9 | class EdgeBipartite(object):
10 |
11 | NAME = "e-obm"
12 |
13 | @staticmethod
14 | def make_dataset(*args, **kwargs):
15 | return EdgeBipartiteDataset(*args, **kwargs)
16 |
17 | @staticmethod
18 | def make_state(*args, **kwargs):
19 | return StateEdgeBipartite.initialize(*args, **kwargs)
20 |
21 |
22 | class EdgeBipartiteDataset(Dataset):
23 | def __init__(
24 | self, dataset, size, problem, seed, opts, transform=None, pre_transform=None
25 | ):
26 | super(EdgeBipartiteDataset, self).__init__(None, transform, pre_transform)
27 | # self.data_set = dataset
28 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(self.data_set))
29 | self.problem = problem
30 | if dataset is not None:
31 | # self.optimal_size = torch.load("{}/optimal_match.pt".format(dataset))
32 | self.data_set = dataset
33 | else:
34 | # If no filename is specified, generate data for the edge-obm problem
35 | D, optimal_size = generate_edge_obm_data_geometric(
36 | opts.u_size,
37 | opts.v_size,
38 | opts.weight_distribution,
39 | opts.weight_distribution_param,
40 | opts.graph_family_parameter,
41 | seed,
42 | opts.graph_family,
43 | None,
44 | size,
45 | False,
46 | )
47 | self.optimal_size = optimal_size
48 | self.data_set = D
49 |
50 | self.size = size
51 |
52 | def len(self):
53 | return self.size
54 |
55 | def get(self, idx):
56 | if type(self.data_set) == str:
57 | data = torch.load(self.data_set + "/data_{}.pt".format(idx))
58 | else:
59 | data = self.data_set[idx]
60 | return data
61 |
--------------------------------------------------------------------------------
/policy/greedy_theshold.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from utils.functions import get_best_t
4 |
5 |
6 | class GreedyThresh(nn.Module):
7 | def __init__(
8 | self,
9 | embedding_dim,
10 | hidden_dim,
11 | problem,
12 | opts,
13 | tanh_clipping=None,
14 | mask_inner=None,
15 | mask_logits=None,
16 | n_encode_layers=None,
17 | normalization=None,
18 | checkpoint_encoder=False,
19 | shrink_size=None,
20 | num_actions=None,
21 | n_heads=None,
22 | encoder=None,
23 | ):
24 | super(GreedyThresh, self).__init__()
25 | self.decode_type = None
26 | self.problem = problem
27 | self.model_name = "greedy-t"
28 | self.best_threshold = get_best_t(self.model_name, opts)
29 |
30 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
31 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts)
32 | t = self.best_threshold
33 | sequences = []
34 | while not (state.all_finished()):
35 | mask = state.get_mask()
36 | w = state.get_current_weights(mask).clone()
37 | mask = state.get_mask()
38 | w[mask.bool()] = 0.0
39 | temp = w.clone()
40 | # w[temp >= t] = 1.0
41 | w[temp < t] = 0.0
42 | w[w.sum(1) == 0, 0] = 1.0
43 | # if self.decode_type == "greedy":
44 | selected = torch.argmax(w, dim=1)
45 | # elif self.decode_type == "sampling":
46 | # selected = (w / torch.sum(w, dim=1)[:, None]).multinomial(1)
47 | state = state.update(selected[:, None])
48 |
49 | sequences.append(selected)
50 | if return_pi:
51 | return -state.size, None, torch.stack(sequences, 1), None
52 | return -state.size, torch.stack(sequences, 1), None
53 |
54 | def set_decode_type(self, decode_type, temp=None):
55 | self.decode_type = decode_type
56 | if temp is not None: # Do not change temperature if not provided
57 | self.temp = temp
58 |
--------------------------------------------------------------------------------
/utils/functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from problem_state.obm_dataset import Bipartite
3 | from problem_state.edge_obm_dataset import EdgeBipartite
4 | from problem_state.osbm_dataset import OSBM
5 | from problem_state.adwords_dataset import AdwordsBipartite
6 | import csv
7 |
8 |
9 | def load_problem(name):
10 |
11 | problem = {
12 | "obm": Bipartite,
13 | "e-obm": EdgeBipartite,
14 | "osbm": OSBM,
15 | "adwords": AdwordsBipartite,
16 | }.get(name, None)
17 | assert problem is not None, "Currently unsupported problem: {}!".format(name)
18 | return problem
19 |
20 |
21 | def torch_load_cpu(load_path):
22 | return torch.load(
23 | load_path, map_location=lambda storage, loc: storage
24 | ) # Load on CPU
25 |
26 |
27 | def move_to(var, device):
28 | if isinstance(var, dict):
29 | return {k: move_to(v, device) for k, v in var.items()}
30 | elif isinstance(var, list):
31 | return list(move_to(v, device) for v in var)
32 | return var.to(device)
33 |
34 |
35 | def random_max(input):
36 | """
37 | Return max element with random tie breaking
38 | """
39 | max_w, _ = torch.max(input, dim=1)
40 | max_filter = (input == max_w[:, None]).float()
41 | selected = (max_filter / (max_filter.sum(dim=1)[:, None])).multinomial(1)
42 |
43 | return selected
44 |
45 |
46 | def random_min(input):
47 | """
48 | Return min element with random tie breaking
49 | """
50 | min_w, _ = torch.min(input, dim=1)
51 | max_filter = (input == min_w[:, None]).float()
52 | selected = (max_filter / (max_filter.sum(dim=1)[:, None])).multinomial(1)
53 |
54 | return selected
55 |
56 |
57 | def get_best_t(model, opts):
58 | best_params = None
59 | best_r = 0
60 | graph_family = (
61 | opts.graph_family if opts.graph_family != "gmission-perm" else "gmission"
62 | )
63 | with open(
64 | f"val_rewards_{model}_{opts.u_size}_{opts.v_size}_{graph_family}_{opts.graph_family_parameter}.csv"
65 | ) as csv_file:
66 | csv_reader = csv.reader(csv_file, delimiter=",")
67 | for line in csv_reader:
68 | if abs(float(line[-1])) > best_r:
69 | best_params = float(line[0])
70 | best_r = abs(float(line[-1]))
71 | return best_params
72 |
--------------------------------------------------------------------------------
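A small usage sketch for `random_max` above: ties for the row-wise maximum are broken uniformly at random, which is what the greedy-style policies in `policy/` rely on when several edges have equal weight.

```python
import torch
from utils.functions import random_max

w = torch.tensor([[0.2, 0.9, 0.9, 0.1],
                  [0.5, 0.5, 0.5, 0.5]])
# Returns one index per row with shape [batch_size, 1]; for row 0 it is
# column 1 or 2 with equal probability, for row 1 any of the four columns.
print(random_max(w))
```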
/data/gMission/reduced_workers.txt:
--------------------------------------------------------------------------------
1 | 321.0
2 | 81.0
3 | 105.0
4 | 329.0
5 | 497.0
6 | 358.0
7 | 3.0
8 | 300.0
9 | 357.0
10 | 381.0
11 | 161.0
12 | 510.0
13 | 155.0
14 | 239.0
15 | 511.0
16 | 72.0
17 | 143.0
18 | 172.0
19 | 370.0
20 | 447.0
21 | 17.0
22 | 106.0
23 | 113.0
24 | 129.0
25 | 353.0
26 | 387.0
27 | 407.0
28 | 412.0
29 | 490.0
30 | 515.0
31 | 523.0
32 | 12.0
33 | 37.0
34 | 193.0
35 | 284.0
36 | 323.0
37 | 343.0
38 | 344.0
39 | 404.0
40 | 438.0
41 | 460.0
42 | 479.0
43 | 18.0
44 | 29.0
45 | 49.0
46 | 71.0
47 | 229.0
48 | 301.0
49 | 339.0
50 | 342.0
51 | 413.0
52 | 445.0
53 | 526.0
54 | 5.0
55 | 8.0
56 | 30.0
57 | 38.0
58 | 145.0
59 | 195.0
60 | 214.0
61 | 236.0
62 | 309.0
63 | 315.0
64 | 324.0
65 | 327.0
66 | 385.0
67 | 394.0
68 | 494.0
69 | 74.0
70 | 82.0
71 | 111.0
72 | 142.0
73 | 182.0
74 | 203.0
75 | 207.0
76 | 217.0
77 | 293.0
78 | 430.0
79 | 441.0
80 | 459.0
81 | 487.0
82 | 493.0
83 | 1.0
84 | 9.0
85 | 39.0
86 | 75.0
87 | 93.0
88 | 158.0
89 | 160.0
90 | 222.0
91 | 227.0
92 | 276.0
93 | 294.0
94 | 308.0
95 | 328.0
96 | 365.0
97 | 406.0
98 | 429.0
99 | 462.0
100 | 468.0
101 | 480.0
102 | 484.0
103 | 486.0
104 | 491.0
105 | 518.0
106 | 529.0
107 | 60.0
108 | 126.0
109 | 151.0
110 | 167.0
111 | 256.0
112 | 291.0
113 | 337.0
114 | 372.0
115 | 488.0
116 | 525.0
117 | 62.0
118 | 80.0
119 | 148.0
120 | 149.0
121 | 157.0
122 | 173.0
123 | 212.0
124 | 254.0
125 | 318.0
126 | 422.0
127 | 435.0
128 | 496.0
129 | 52.0
130 | 61.0
131 | 76.0
132 | 204.0
133 | 221.0
134 | 275.0
135 | 283.0
136 | 297.0
137 | 319.0
138 | 347.0
139 | 351.0
140 | 359.0
141 | 366.0
142 | 410.0
143 | 419.0
144 | 442.0
145 | 453.0
146 | 456.0
147 | 465.0
148 | 475.0
149 | 489.0
150 | 507.0
151 | 6.0
152 | 90.0
153 | 115.0
154 | 136.0
155 | 166.0
156 | 271.0
157 | 331.0
158 | 332.0
159 | 388.0
160 | 405.0
161 | 416.0
162 | 471.0
163 | 512.0
164 | 516.0
165 | 4.0
166 | 56.0
167 | 122.0
168 | 188.0
169 | 216.0
170 | 224.0
171 | 245.0
172 | 296.0
173 | 322.0
174 | 401.0
175 | 451.0
176 | 492.0
177 | 19.0
178 | 23.0
179 | 50.0
180 | 175.0
181 | 230.0
182 | 257.0
183 | 260.0
184 | 278.0
185 | 362.0
186 | 398.0
187 | 400.0
188 | 426.0
189 | 482.0
190 | 48.0
191 | 92.0
192 | 132.0
193 | 152.0
194 | 210.0
195 | 306.0
196 | 314.0
197 | 355.0
198 | 377.0
199 | 378.0
200 | 402.0
201 |
--------------------------------------------------------------------------------
/policy/greedy_sc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from utils.functions import random_max
4 | import math
5 |
6 | class GreedySC(nn.Module):
7 | def __init__(
8 | self,
9 | embedding_dim,
10 | hidden_dim,
11 | problem,
12 | opts,
13 | tanh_clipping=None,
14 | mask_inner=None,
15 | mask_logits=None,
16 | n_encode_layers=None,
17 | normalization=None,
18 | checkpoint_encoder=False,
19 | shrink_size=None,
20 | num_actions=None,
21 | n_heads=None,
22 | encoder=None,
23 | ):
24 | super(GreedySC, self).__init__()
25 | self.decode_type = None
26 | self.problem = problem
27 | self.model_name = "greedy-sc"
28 |
29 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
30 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts)
31 | sequences = []
32 |
33 | # Array for single step rewards
34 | i = 0
35 | reward_array = torch.zeros([opts.v_size, opts.batch_size])
36 | prev_reward = 0
37 |
38 | graph_v_size = opts.v_size
39 | i = 0
40 |
41 | while not (state.all_finished()):
42 | mask = state.get_mask()
43 | w = state.get_current_weights(mask).clone()
44 | w[mask.bool()] = -1.0
45 |
46 | # Select action based on current info
47 | if(i > graph_v_size/math.e - 1):
48 | selected = random_max(w)
49 | else:
50 | selected = torch.zeros([w.shape[0], 1], device=w.device, dtype=torch.int64)
51 |
52 | state = state.update(selected)
53 | sequences.append(selected.squeeze(1))
54 |
55 | # Collect the single step reward
56 | reward_array[i,:] = state.size.view(-1).cpu() - prev_reward
57 | prev_reward = state.size.view(-1).cpu()
58 | i+= 1
59 |
60 | if return_pi:
61 | return -state.size, None, torch.stack(sequences, 1), None
62 |
63 | return -state.size, torch.stack(sequences, 1), None
64 |
65 | def set_decode_type(self, decode_type, temp=None):
66 | self.decode_type = decode_type
67 | if temp is not None: # Do not change temperature if not provided
68 | self.temp = temp
69 |
--------------------------------------------------------------------------------
/policy/greedy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from utils.functions import random_max
4 |
5 |
6 | class Greedy(nn.Module):
7 | def __init__(
8 | self,
9 | embedding_dim,
10 | hidden_dim,
11 | problem,
12 | opts,
13 | tanh_clipping=None,
14 | mask_inner=None,
15 | mask_logits=None,
16 | n_encode_layers=None,
17 | normalization=None,
18 | checkpoint_encoder=False,
19 | shrink_size=None,
20 | num_actions=None,
21 | n_heads=None,
22 | encoder=None,
23 | ):
24 | super(Greedy, self).__init__()
25 | self.decode_type = None
26 | self.problem = problem
27 | self.model_name = "greedy"
28 |
29 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
30 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts)
31 | sequences = []
32 |
33 | # Array for single step rewards
34 | i = 0
35 | reward_array = torch.zeros([opts.v_size, opts.batch_size])
36 | prev_reward = 0
37 |
38 | while not (state.all_finished()):
39 | mask = state.get_mask()
40 | w = state.get_current_weights(mask).clone()
41 | w[mask.bool()] = -1.0
42 | selected = random_max(w)
43 | state = state.update(selected)
44 | sequences.append(selected.squeeze(1))
45 |
46 | # Collect the single step reward
47 | reward_array[i,:] = state.size.view(-1).cpu() - prev_reward
48 | prev_reward = state.size.view(-1).cpu()
49 | i+= 1
50 |
51 | if return_pi:
52 | return -state.size, None, torch.stack(sequences, 1), None
53 |
54 | # import os
55 | # base_dir = "reward_save"
56 | # file_list = os.listdir(base_dir)
57 | # if file_list.__len__() != 0:
58 | # file_list.sort()
59 | # curr_index = int(file_list[-1][-6:-3]) + 1
60 | # else:
61 | # curr_index = 0
62 | # torch.save(reward_array, base_dir+"/reward_{:0>3d}.pt".format(curr_index))
63 |
64 | return -state.size, torch.stack(sequences, 1), None
65 |
66 | def set_decode_type(self, decode_type, temp=None):
67 | self.decode_type = decode_type
68 | if temp is not None: # Do not change temperature if not provided
69 | self.temp = temp
70 |
--------------------------------------------------------------------------------
/encoder/graph_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch_geometric.nn import NNConv
4 | import torch.nn.functional as F
5 |
6 |
7 | class MPNN(nn.Module):
8 | def __init__(
9 | self,
10 | n_heads,
11 | embed_dim,
12 | n_layers,
13 | problem,
14 | opts,
15 | dropout=0.1,
16 | alpha=0.01,
17 | node_dim_u=1,
18 | node_dim_v=1,
19 | normalization="batch",
20 | feed_forward_hidden=512,
21 | ):
22 | super(MPNN, self).__init__()
23 | self.l1 = nn.Linear(1, embed_dim * embed_dim) # TODO: Change back
24 | self.node_embed_u = nn.Linear(node_dim_u, embed_dim)
25 | if node_dim_u != node_dim_v:
26 | self.node_embed_v = nn.Linear(node_dim_v, embed_dim)
27 | else:
28 | self.node_embed_v = self.node_embed_u
29 |
30 | self.conv1 = NNConv(
31 | embed_dim, embed_dim, self.l1, aggr="mean"
32 | ) # TODO: Change back
33 | if n_layers > 1:
34 | self.l2 = nn.Linear(1, embed_dim ** 2)
35 | self.conv2 = NNConv(embed_dim, embed_dim, self.l2, aggr="mean")
36 | self.n_layers = n_layers
37 | self.problem = opts.problem
38 | self.u_size = opts.u_size
39 | self.node_dim_u = node_dim_u
40 | self.node_dim_v = node_dim_v
41 | self.batch_size = opts.batch_size
42 |
43 | def forward(self, x, edge_index, edge_attribute, i, dummy, opts):
44 | i = i.item()
45 | graph_size = opts.u_size + 1 + i
46 | if i < self.n_layers:
47 | n_encode_layers = i + 1
48 | else:
49 | n_encode_layers = self.n_layers
50 | x_u = x[:, : self.node_dim_u * (opts.u_size + 1)].reshape(
51 | self.batch_size, opts.u_size + 1, self.node_dim_u
52 | )
53 | x_v = x[:, self.node_dim_u * (opts.u_size + 1) :].reshape(
54 | self.batch_size, i, self.node_dim_v
55 | )
56 | x_u = self.node_embed_u(x_u)
57 | x_v = self.node_embed_v(x_v)
58 | x = torch.cat((x_u, x_v), dim=1).reshape(self.batch_size * graph_size, -1)
59 |
60 | for j in range(n_encode_layers):
61 | # x = F.relu(x) # TODO: Change back
62 | if j == 0:
63 | x = self.conv1(x, edge_index, edge_attribute.float())
64 | if n_encode_layers > 1:
65 | x = F.relu(x)
66 | else:
67 | x = self.conv2(x, edge_index, edge_attribute.float())
68 |
69 | return x
70 |
--------------------------------------------------------------------------------
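A self-contained sketch of the `NNConv` layer used by the encoder above (shapes and values are illustrative): the small linear network maps each scalar edge weight to an `embed_dim x embed_dim` matrix that transforms the neighbor's embedding before aggregation, which is how the edge weights enter the message passing.

```python
import torch
from torch import nn
from torch_geometric.nn import NNConv

embed_dim = 8
edge_net = nn.Linear(1, embed_dim * embed_dim)             # like self.l1 in MPNN
conv = NNConv(embed_dim, embed_dim, edge_net, aggr="mean")

x = torch.randn(5, embed_dim)                              # 5 node embeddings
edge_index = torch.tensor([[0, 1, 2, 3], [4, 4, 4, 4]])    # 4 directed edges into node 4
edge_attr = torch.rand(4, 1)                               # one scalar weight per edge
out = conv(x, edge_index, edge_attr)
print(out.shape)                                           # torch.Size([5, 8])
```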
/policy/simple_greedy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 |
6 | class SimpleGreedy(nn.Module):
7 | def __init__(
8 | self,
9 | embedding_dim,
10 | hidden_dim,
11 | problem,
12 | opts,
13 | tanh_clipping=None,
14 | mask_inner=None,
15 | mask_logits=None,
16 | n_encode_layers=None,
17 | normalization=None,
18 | checkpoint_encoder=False,
19 | shrink_size=None,
20 | num_actions=None,
21 | n_heads=None,
22 | encoder=None,
23 | ):
24 | super(SimpleGreedy, self).__init__()
25 | self.decode_type = None
26 | self.allow_partial = problem.NAME == "sdvrp"
27 | self.is_vrp = problem.NAME == "cvrp" or problem.NAME == "sdvrp"
28 | self.is_orienteering = problem.NAME == "op"
29 | self.is_pctsp = problem.NAME == "pctsp"
30 | self.is_bipartite = problem.NAME == "bipartite"
31 | self.is_tsp = problem.NAME == "tsp"
32 | self.problem = problem
33 | self.rank = 0
34 |
35 | def forward(self, x, opts):
36 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts)
37 |
38 | self.rank = self.permute_uniform(
39 | torch.arange(1, state.u_size.item() + 2, device=x.device)
40 | .unsqueeze(0)
41 | .expand(state.batch_size.item(), state.u_size.item() + 1)
42 | )
43 | sequences = []
44 | self.rank[:, 0] = state.u_size.item() * 2
45 | while not (state.all_finished()):
46 | mask = state.get_mask().bool()
47 | r = self.rank.clone()
48 | r[mask] = 10e6
49 | selected = torch.argmin(r, dim=1)
50 |
51 | state = state.update(selected[:, None])
52 |
53 | sequences.append(selected)
54 |
55 | return -state.size, torch.stack(sequences, 1)
56 |
57 | def set_decode_type(self, decode_type, temp=None):
58 | self.decode_type = decode_type
59 | if temp is not None: # Do not change temperature if not provided
60 | self.temp = temp
61 |
62 | def permute_uniform(self, x):
63 | """
64 | Permutes a batch of lists uniformly at random using the Fisher Yates algorithm.
65 | """
66 | y = x.clone()
67 | n = x.size(1)
68 | batch_size = x.size(0)
69 | for i in range(0, n - 1):
70 | j = torch.tensor(np.random.randint(i, n, (batch_size, 1)), device=x.device)
71 | temp = y[:, i].clone()
72 | y[:, i] = torch.gather(y, 1, j).squeeze(1)
73 | y = y.scatter_(1, j, temp.unsqueeze(1))
74 | return y
75 |
--------------------------------------------------------------------------------
/policy/greedy_rt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 |
6 | class GreedyRt(nn.Module):
7 | def __init__(
8 | self,
9 | embedding_dim,
10 | hidden_dim,
11 | problem,
12 | opts,
13 | tanh_clipping=None,
14 | mask_inner=None,
15 | mask_logits=None,
16 | n_encode_layers=None,
17 | normalization=None,
18 | checkpoint_encoder=False,
19 | shrink_size=None,
20 | num_actions=None,
21 | n_heads=None,
22 | encoder=None,
23 | ):
24 | super(GreedyRt, self).__init__()
25 | self.decode_type = None
26 | self.problem = problem
27 | self.model_name = "greedy-rt"
28 | max_weight_dict = {
29 | "gmission-var": 18.8736,
30 | "gmission": 18.8736,
31 | "er": 10 ** 8,
32 | "ba": float(
33 | opts.graph_family_parameter
34 | ) # Make sure to set this properly before running!
35 | + float(opts.weight_distribution_param[1]),
36 | }
37 | norm_weight = {
38 | "gmission-var": 18.8736,
39 | "gmission": 18.8736,
40 | "er": 10 ** 8,
41 | "ba": 100.0,
42 | }
43 | if opts.graph_family == "gmission-perm":
44 | graph_family = "gmission"
45 | else:
46 | graph_family = opts.graph_family
47 | self.max_weight = max_weight_dict[graph_family]
48 | self.norm_weights = norm_weight[graph_family]
49 |
50 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
51 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts)
52 | t = torch.tensor(
53 | np.e
54 | ** np.random.randint(
55 | 1, np.ceil(np.log(1 + self.max_weight)), (opts.batch_size, 1)
56 | ),
57 | device=opts.device,
58 | )
59 | sequences = []
60 | while not (state.all_finished()):
61 | mask = state.get_mask()
62 | w = state.get_current_weights(mask).clone()
63 | mask = state.get_mask()
64 | w[mask.bool()] = 0.0
65 | temp = (
66 | w * self.norm_weights
67 | ) # re-normalize the weights since they are mostly between 0 and 1.
68 | temp[
69 | temp > 0.0
70 | ] += 1.0 # To make sure all weights are at least 1 (needed for greedy-rt to work).
71 | w[temp >= t] = 1.0
72 | w[temp < t] = 0.0
73 | w[w.sum(1) == 0, 0] = 1.0
74 | selected = (w / torch.sum(w, dim=1)[:, None]).multinomial(1)
75 | state = state.update(selected)
76 |
77 | sequences.append(selected.squeeze(1))
78 | if return_pi:
79 | return -state.size, None, torch.stack(sequences, 1), None
80 | return -state.size, torch.stack(sequences, 1), None
81 |
82 | def set_decode_type(self, decode_type, temp=None):
83 | self.decode_type = decode_type
84 | if temp is not None: # Do not change temperature if not provided
85 | self.temp = temp
86 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning for Edge-Weighted Online Bipartite Matching with Robustness Guarantees
2 |
3 | [](LICENSE)
4 |
5 | [Pengfei Li](https://www.cs.ucr.edu/~pli081/), [Jianyi Yang](https://jyang-ai.github.io/) and [Shaolei Ren](https://intra.ece.ucr.edu/~sren/)
6 |
7 | **Note**
8 |
9 | This is the official implementation of the ICML 2023 paper *Learning for Edge-Weighted Online Bipartite Matching with Robustness Guarantees*.
10 |
11 | ## Requirements
12 |
13 | * python>=3.6
14 |
15 | ## Installation
16 | * Clone this repo:
17 | ```bash
18 | git clone git@github.com:Ren-Research/LOMAR.git
19 | cd LOMAR
20 | ```
21 | Then please refer to [the install guide](INSTALL.md) for more details about installation
22 |
23 | ## Usage
24 | To apply our algorithm (LOMAR) to online bipartite matching, you need three main steps:
25 |
26 | 1. Generate graph dataset
27 | 2. Train the RL model
28 | 3. Evaluate the policy
29 |
30 | A script example for each step can be found in our brief [tutorial](TUTORIAL.md).
31 |
32 | ## Evaluation
33 |
34 | In our experiments, we set $u_0 = 10$ and $v_0 = 60$ to generate the training and testing datasets, which contain 20000 and 1000 graph instances, respectively. For the sake of reproducibility and fair comparison, our settings follow the same setup as our [baseline](https://github.com/lyeskhalil/CORL).
35 |
36 | |  |
37 | |:--:|
38 | | Table 1: Comparison under different $\rho$. In the top row, LOMAR ($\rho = x$) means LOMAR is trained with $\rho = x$. The average reward and competitive ratio are denoted by AVG and CR, respectively (the higher, the better). The highest value in each testing setup is highlighted in bold. The AVG and CR for DRL are 12.909 and 0.544, respectively. The average reward for OPT is 13.209. |
39 |
40 | The histograms of the bi-competitive ratios are visualized below. When $\rho = 0$, the ratio of DRL-OS / DRL is always 1, unsurprisingly. With a large $\rho$ (e.g., 0.8) for testing, the reward ratios of DRL-OS / Greedy are around 1 for most samples, but the flexibility of DRL-OS is limited, so it is less able to exploit the good average performance of DRL.
41 |
42 | |  |
43 | |:--:|
44 | | Figure 1: Histogram of bi-competitive reward ratios of DRL-OS against Greedy and DRL under different $\rho$. DRL-OS uses the same online switching algorithm as LOMAR, but its RL model is trained with $\rho=0$. |
45 |
46 | ## Citation
47 | ```BibTex
48 | @inproceedings{Li2023LOMAR,
49 | title={Learning for Edge-Weighted Online Bipartite Matching with Robustness Guarantees},
50 | author={Li, Pengfei and Yang, Jianyi and Ren, Shaolei},
51 | booktitle={International Conference on Machine Learning},
52 | year={2023},
53 | organization={PMLR}
54 | }
55 | ```
56 |
57 |
58 | ## Codebase
59 | Thanks to Mohammad Ali Alomrani, Reza Moravej, and Elias B. Khalil for the code base. Their public repository is available at [https://github.com/lyeskhalil/CORL](https://github.com/lyeskhalil/CORL)
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
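The Usage section of the README above lists three steps; below is a minimal sketch of step 1 (dataset generation) using the classes shipped in this repository. The `opts` field names mirror the attributes accessed in `problem_state/edge_obm_dataset.py`, but the concrete values, and any additional options expected by the training/evaluation scripts, are assumptions here; see TUTORIAL.md for the actual commands.

```python
from argparse import Namespace
from utils.functions import load_problem

# Hypothetical options for a small e-obm instance family (u_0 = 10, v_0 = 60 as in the paper).
opts = Namespace(
    u_size=10,
    v_size=60,
    graph_family="er",
    graph_family_parameter=0.5,
    weight_distribution="uniform",
    weight_distribution_param=[0, 1],
)

problem = load_problem("e-obm")  # -> EdgeBipartite
# Passing dataset=None generates instances on the fly via generate_edge_obm_data_geometric.
dataset = problem.make_dataset(dataset=None, size=100, problem=problem, seed=0, opts=opts)
print(len(dataset), dataset.get(0))
```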
/data/gMission/reduced_tasks.txt:
--------------------------------------------------------------------------------
1 | 347.0
2 | 414.0
3 | 40.0
4 | 134.0
5 | 462.0
6 | 467.0
7 | 72.0
8 | 88.0
9 | 344.0
10 | 489.0
11 | 2.0
12 | 68.0
13 | 196.0
14 | 360.0
15 | 507.0
16 | 646.0
17 | 52.0
18 | 113.0
19 | 201.0
20 | 171.0
21 | 326.0
22 | 442.0
23 | 563.0
24 | 395.0
25 | 434.0
26 | 91.0
27 | 384.0
28 | 401.0
29 | 520.0
30 | 433.0
31 | 427.0
32 | 553.0
33 | 47.0
34 | 133.0
35 | 693.0
36 | 105.0
37 | 198.0
38 | 516.0
39 | 673.0
40 | 37.0
41 | 612.0
42 | 164.0
43 | 237.0
44 | 464.0
45 | 545.0
46 | 131.0
47 | 316.0
48 | 614.0
49 | 66.0
50 | 160.0
51 | 43.0
52 | 387.0
53 | 4.0
54 | 223.0
55 | 296.0
56 | 340.0
57 | 499.0
58 | 681.0
59 | 106.0
60 | 32.0
61 | 38.0
62 | 366.0
63 | 295.0
64 | 528.0
65 | 576.0
66 | 116.0
67 | 455.0
68 | 573.0
69 | 92.0
70 | 431.0
71 | 99.0
72 | 307.0
73 | 588.0
74 | 638.0
75 | 583.0
76 | 103.0
77 | 204.0
78 | 532.0
79 | 44.0
80 | 45.0
81 | 330.0
82 | 365.0
83 | 8.0
84 | 146.0
85 | 212.0
86 | 415.0
87 | 490.0
88 | 707.0
89 | 621.0
90 | 582.0
91 | 680.0
92 | 64.0
93 | 154.0
94 | 193.0
95 | 221.0
96 | 256.0
97 | 290.0
98 | 374.0
99 | 3.0
100 | 18.0
101 | 468.0
102 | 593.0
103 | 598.0
104 | 61.0
105 | 175.0
106 | 623.0
107 | 39.0
108 | 342.0
109 | 376.0
110 | 544.0
111 | 54.0
112 | 451.0
113 | 225.0
114 | 606.0
115 | 27.0
116 | 112.0
117 | 127.0
118 | 487.0
119 | 617.0
120 | 666.0
121 | 71.0
122 | 130.0
123 | 277.0
124 | 495.0
125 | 603.0
126 | 644.0
127 | 85.0
128 | 187.0
129 | 341.0
130 | 444.0
131 | 209.0
132 | 692.0
133 | 310.0
134 | 354.0
135 | 525.0
136 | 280.0
137 | 373.0
138 | 328.0
139 | 677.0
140 | 189.0
141 | 278.0
142 | 299.0
143 | 314.0
144 | 579.0
145 | 383.0
146 | 244.0
147 | 70.0
148 | 177.0
149 | 411.0
150 | 428.0
151 | 479.0
152 | 712.0
153 | 417.0
154 | 230.0
155 | 599.0
156 | 126.0
157 | 357.0
158 | 641.0
159 | 167.0
160 | 447.0
161 | 708.0
162 | 46.0
163 | 370.0
164 | 385.0
165 | 643.0
166 | 97.0
167 | 561.0
168 | 60.0
169 | 79.0
170 | 122.0
171 | 363.0
172 | 697.0
173 | 482.0
174 | 31.0
175 | 125.0
176 | 626.0
177 | 152.0
178 | 233.0
179 | 378.0
180 | 409.0
181 | 422.0
182 | 477.0
183 | 560.0
184 | 580.0
185 | 660.0
186 | 683.0
187 | 309.0
188 | 343.0
189 | 597.0
190 | 498.0
191 | 524.0
192 | 194.0
193 | 217.0
194 | 250.0
195 | 264.0
196 | 289.0
197 | 453.0
198 | 595.0
199 | 636.0
200 | 197.0
201 | 293.0
202 | 124.0
203 | 195.0
204 | 75.0
205 | 120.0
206 | 279.0
207 | 425.0
208 | 459.0
209 | 486.0
210 | 577.0
211 | 100.0
212 | 29.0
213 | 627.0
214 | 671.0
215 | 210.0
216 | 234.0
217 | 624.0
218 | 674.0
219 | 220.0
220 | 543.0
221 | 111.0
222 | 143.0
223 | 153.0
224 | 320.0
225 | 179.0
226 | 191.0
227 | 581.0
228 | 709.0
229 | 36.0
230 | 206.0
231 | 405.0
232 | 461.0
233 | 33.0
234 | 213.0
235 | 254.0
236 | 261.0
237 | 322.0
238 | 372.0
239 | 616.0
240 | 701.0
241 | 156.0
242 | 379.0
243 | 424.0
244 | 559.0
245 | 630.0
246 | 655.0
247 | 679.0
248 | 240.0
249 | 438.0
250 | 555.0
251 | 98.0
252 | 205.0
253 | 485.0
254 | 227.0
255 | 496.0
256 | 668.0
257 | 190.0
258 | 445.0
259 | 199.0
260 | 180.0
261 | 421.0
262 | 550.0
263 | 440.0
264 | 172.0
265 | 276.0
266 | 403.0
267 | 432.0
268 | 637.0
269 | 665.0
270 | 128.0
271 | 186.0
272 | 255.0
273 | 284.0
274 | 327.0
275 | 361.0
276 | 408.0
277 | 556.0
278 | 219.0
279 | 446.0
280 | 699.0
281 | 248.0
282 | 539.0
283 | 169.0
284 | 62.0
285 | 215.0
286 | 457.0
287 | 94.0
288 | 339.0
289 | 349.0
290 | 22.0
291 | 548.0
292 | 590.0
293 | 656.0
294 | 17.0
295 | 19.0
296 | 123.0
297 | 139.0
298 | 200.0
299 | 241.0
300 | 491.0
301 |
--------------------------------------------------------------------------------
/problem_state/obm_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 | import os
4 | import pickle
5 | from problem_state.obm_env import StateBipartite
6 |
7 |
8 | class Bipartite(object):
9 |
10 | NAME = "obm"
11 |
12 | # @staticmethod
13 | # def get_costs(dataset, pi):
14 | # # TODO: MODIFY CODE SO IT WORKS WITH BIPARTITE INSTEAD OF TSP
15 | # # Check that tours are valid, i.e. contain 0 to n -1
16 | # assert (
17 | # torch.arange(pi.size(1), out=pi.data.new()).view(1, -1).expand_as(pi)
18 | # == pi.data.sort(1)[0]
19 | # ).all(), "Invalid tour"
20 | # # Gather dataset in order of tour
21 | # d = dataset.gather(1, pi.unsqueeze(-1).expand_as(dataset))
22 | # # Length is distance (L2-norm of difference) from each next location from its prev and of last from first
23 | # return (
24 | # (d[:, 1:] - d[:, :-1]).norm(p=2, dim=2).sum(1)
25 | # + (d[:, 0] - d[:, -1]).norm(p=2, dim=1),
26 | # None,
27 | # )
28 |
29 | @staticmethod
30 | def make_dataset(*args, **kwargs):
31 | return BipartiteDataset(*args, **kwargs)
32 |
33 | @staticmethod
34 | def make_state(*args, **kwargs):
35 | return StateBipartite.initialize(*args, **kwargs)
36 |
37 | # @staticmethod
38 | # def beam_search(
39 | # input,
40 | # beam_size,
41 | # expand_size=None,
42 | # compress_mask=False,
43 | # model=None,
44 | # max_calc_batch_size=4096,
45 | # ):
46 |
47 | # assert model is not None, "Provide model"
48 |
49 | # fixed = model.precompute_fixed(input)
50 |
51 | # def propose_expansions(beam):
52 | # return model.propose_expansions(
53 | # beam,
54 | # fixed,
55 | # expand_size,
56 | # normalize=True,
57 | # max_calc_batch_size=max_calc_batch_size,
58 | # )
59 |
60 | # state = Bipartite.make_state(
61 | # input, visited_dtype=torch.int64 if compress_mask else torch.uint8
62 | # )
63 |
64 | # return beam_search(state, beam_size, propose_expansions)
65 |
66 |
67 | class BipartiteDataset(Dataset):
68 | def __init__(self, dataset, size, problem, opts):
69 | super(BipartiteDataset, self).__init__()
70 |
71 | self.data_set = dataset
72 | self.optimal_size = torch.load("{}/optimal_match.pt".format(self.data_set))
73 | self.problem = problem
74 | # if opts.train_dataset is not None:
75 | # assert os.path.splitext(opts.train_dataset)[1] == ".pkl"
76 |
77 | # with open(opts.train_dataset, "rb") as f:
78 | # data = pickle.load(f)
79 | # self.data = data
80 | # else:
81 | # ### TODO: Should use generate function in generate_data.py
82 | # # If no filename is specified generated data for normal obm probelm
83 | # self.data = generate_obm_data(opts)
84 |
85 | self.size = size
86 |
87 | def __getitem__(self, index):
88 |
89 | # Load data and get label
90 | X = torch.load("{}/graphs/{}.pt".format(self.data_set, index))
91 | Y = self.optimal_size[index]
92 | return X, Y
93 |
94 | def __len__(self):
95 | return self.size
96 |
97 | # def __getitem__(self, idx):
98 | # return tuple(d[idx] for d in self.data)
99 |
100 |
101 | # train_loader = torch.utils.data.DataLoader(
102 | # ConcatDataset(
103 | # datasets.ImageFolder(traindir_A),
104 | # datasets.ImageFolder(traindir_B)
105 | # ),
106 | # batch_size=args.batch_size, shuffle=True,
107 | # num_workers=args.workers, pin_memory=True)
108 |
--------------------------------------------------------------------------------
/data/MovieLense/movies.txt:
--------------------------------------------------------------------------------
1 | 36::Dead Man Walking (1995)::Drama
2 | 68::French Twist (Gazon maudit) (1995)::Comedy|Romance
3 | 124::Star Maker, The (Uomo delle stelle, L') (1995)::Drama
4 | 170::Hackers (1995)::Action|Crime|Thriller
5 | 180::Mallrats (1995)::Comedy
6 | 192::Show, The (1995)::Documentary
7 | 251::Hunted, The (1995)::Action
8 | 252::I.Q. (1994)::Comedy|Romance
9 | 273::Mary Shelley's Frankenstein (1994)::Drama|Horror
10 | 398::Frank and Ollie (1995)::Documentary
11 | 404::Brother Minister: The Assassination of Malcolm X (1994)::Documentary
12 | 410::Addams Family Values (1993)::Comedy
13 | 523::Ruby in Paradise (1993)::Drama
14 | 538::Six Degrees of Separation (1993)::Drama
15 | 550::Threesome (1994)::Comedy|Romance
16 | 582::Metisse (Café au Lait) (1993)::Comedy
17 | 587::Ghost (1990)::Comedy|Romance|Thriller
18 | 608::Fargo (1996)::Crime|Drama|Thriller
19 | 635::Family Thing, A (1996)::Comedy|Drama
20 | 651::Superweib, Das (1996)::Comedy
21 | 657::Yankee Zulu (1994)::Comedy|Drama
22 | 725::Great White Hype, The (1996)::Comedy
23 | 757::Ashes of Time (1994)::Drama
24 | 778::Trainspotting (1996)::Drama
25 | 829::Joe's Apartment (1996)::Comedy|Musical
26 | 843::Lotto Land (1995)::Drama
27 | 848::Spitfire Grill, The (1996)::Drama
28 | 852::Tin Cup (1996)::Comedy|Romance
29 | 1127::Abyss, The (1989)::Action|Adventure|Sci-Fi|Thriller
30 | 1128::Fog, The (1980)::Horror
31 | 1131::Jean de Florette (1986)::Drama
32 | 1206::Clockwork Orange, A (1971)::Sci-Fi
33 | 1208::Apocalypse Now (1979)::Drama|War
34 | 1276::Cool Hand Luke (1967)::Comedy|Drama
35 | 1303::Man Who Would Be King, The (1975)::Adventure
36 | 1358::Sling Blade (1996)::Drama|Thriller
37 | 1362::Garden of Finzi-Contini, The (Giardino dei Finzi-Contini, Il) (1970)::Drama
38 | 1414::Mother (1996)::Comedy
39 | 1419::Walkabout (1971)::Drama
40 | 1434::Stranger, The (1994)::Action
41 | 1436::Falling in Love Again (1980)::Comedy
42 | 1487::Selena (1997)::Drama|Musical
43 | 1489::Cats Don't Dance (1997)::Animation|Children's|Musical
44 | 1567::Last Time I Committed Suicide, The (1997)::Drama
45 | 1582::Wild America (1997)::Adventure|Children's
46 | 1622::Kicked in the Head (1997)::Comedy|Drama
47 | 1681::Mortal Kombat: Annihilation (1997)::Action|Adventure
48 | 1780::Ayn Rand: A Sense of Life (1997)::Documentary
49 | 1801::Man in the Iron Mask, The (1998)::Action|Drama|Romance
50 | 1811::Niagara, Niagara (1997)::Drama
51 | 1829::Chinese Box (1997)::Drama|Romance
52 | 1835::City of Angels (1998)::Romance
53 | 1883::Bulworth (1998)::Comedy
54 | 1911::Doctor Dolittle (1998)::Comedy
55 | 1960::Last Emperor, The (1987)::Drama|War
56 | 1992::Child's Play 2 (1990)::Horror
57 | 2011::Back to the Future Part II (1989)::Comedy|Sci-Fi
58 | 2019::Seven Samurai (The Magnificent Seven) (Shichinin no samurai) (1954)::Action|Drama
59 | 2035::Blackbeard's Ghost (1968)::Children's|Comedy
60 | 2042::D2: The Mighty Ducks (1994)::Children's|Comedy
61 | 2214::Number Seventeen (1932)::Thriller
62 | 2236::Simon Birch (1998)::Drama
63 | 2318::Happiness (1998)::Comedy
64 | 2344::Runaway Train (1985)::Action|Adventure|Drama|Thriller
65 | 2415::Violets Are Blue... (1986)::Drama|Romance
66 | 2448::Virus (1999)::Horror|Sci-Fi
67 | 2473::Soul Man (1986)::Comedy
68 | 2565::King and I, The (1956)::Musical
69 | 2594::Open Your Eyes (Abre los ojos) (1997)::Drama|Romance|Sci-Fi
70 | 2833::Lucie Aubrac (1997)::Romance|War
71 | 2850::Public Access (1993)::Drama|Thriller
72 | 2863::Hard Day's Night, A (1964)::Comedy|Musical
73 | 2885::Guinevere (1999)::Drama|Romance
74 | 2906::Random Hearts (1999)::Drama|Romance
75 | 2918::Ferris Bueller's Day Off (1986)::Comedy
76 | 2940::Gilda (1946)::Film-Noir
77 | 2970::Fitzcarraldo (1982)::Adventure|Drama
78 | 3051::Anywhere But Here (1999)::Drama
79 | 3197::Presidio, The (1988)::Action
80 | 3221::Draughtsman's Contract, The (1982)::Drama
81 | 3247::Sister Act (1992)::Comedy|Crime
82 | 3270::Cutting Edge, The (1992)::Drama
83 | 3299::Hanging Up (2000)::Comedy|Drama
84 | 3413::Impact (1949)::Crime|Drama
85 | 3429::Creature Comforts (1990)::Animation|Comedy
86 | 3442::Band of the Hand (1986)::Action
87 | 3464::Solar Crisis (1993)::Sci-Fi|Thriller
88 | 3484::Skulls, The (2000)::Thriller
89 | 3500::Mr. Saturday Night (1992)::Comedy|Drama
90 | 3643::Fighting Seabees, The (1944)::Action|Drama|War
91 | 3665::Curse of the Puppet Master (1998)::Horror|Sci-Fi|Thriller
92 | 3691::Private School (1983)::Comedy
93 | 3786::But I'm a Cheerleader (1999)::Comedy
94 | 3943::Bamboozled (2000)::Comedy
95 |
--------------------------------------------------------------------------------
/policy/inv_ff_history.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class InvariantFFHist(nn.Module):
6 | def __init__(
7 | self,
8 | embedding_dim,
9 | hidden_dim,
10 | problem,
11 | opts,
12 | tanh_clipping=None,
13 | mask_inner=None,
14 | mask_logits=None,
15 | n_encode_layers=None,
16 | normalization="batch",
17 | checkpoint_encoder=False,
18 | shrink_size=None,
19 | num_actions=4,
20 | n_heads=None,
21 | encoder=None,
22 | ):
23 | super(InvariantFFHist, self).__init__()
24 |
25 | self.embedding_dim = embedding_dim
26 | self.decode_type = None
27 | self.num_actions = 16 if opts.problem != "adwords" else 19
28 | self.problem = problem
29 | self.model_name = "inv-ff-hist"
30 | self.ff = nn.Sequential(
31 | nn.Linear(self.num_actions, 100),
32 | nn.ReLU(),
33 | nn.Linear(100, 100),
34 | nn.ReLU(),
35 | nn.Linear(100, 1),
36 | )
37 |
38 | def forward(self, x, opts, optimizer, baseline, return_pi=True):
39 |
40 | _log_p, pi, cost = self._inner(x, opts)
41 |
42 | ll, e = self._calc_log_likelihood(_log_p, pi, None)
43 | if return_pi:
44 | return -cost, ll, pi, e
45 |
46 | return -cost, ll, e
47 |
48 | def _calc_log_likelihood(self, _log_p, a, mask):
49 |
50 | # Get log_p corresponding to selected actions
51 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean()
52 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)
53 |
54 | # Optional: mask out actions irrelevant to objective so they do not get reinforced
55 | if mask is not None:
56 | log_p[mask] = 0
57 | if not (log_p > -1e8).data.all():
58 | print(log_p)
59 | assert (
60 | log_p > -1e8
61 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!"
62 | # Calculate log_likelihood
63 | return log_p.sum(1), entropy
64 |
65 | def _inner(self, input, opts):
66 |
67 | outputs = []
68 | sequences = []
69 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts)
70 |
71 | # Perform decoding steps
72 | i = 1
73 | while not (state.all_finished()):
74 | mask = state.get_mask()
75 | state.get_current_weights(mask)
76 | s, mask = state.get_curr_state(self.model_name)
77 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1)
78 | # Select the indices of the next nodes in the sequences, result (batch_size) long
79 | selected, p = self._select_node(
80 | pi, mask.bool()
81 | ) # Squeeze out steps dimension
82 | # entropy += torch.sum(p * (p.log()), dim=1)
83 | state = state.update((selected)[:, None])
84 | outputs.append(p)
85 | sequences.append(selected)
86 | i += 1
87 | # Collected lists, return Tensor
88 | return (
89 | torch.stack(outputs, 1),
90 | torch.stack(sequences, 1),
91 | state.size,
92 | )
93 |
94 | def _select_node(self, probs, mask):
95 | assert (probs == probs).all(), "Probs should not contain any nans"
96 | probs[mask] = -1e8
97 | p = torch.log_softmax(probs, dim=1)
98 | # print(p)
99 | if self.decode_type == "greedy":
100 | _, selected = p.max(1)
101 | # assert not mask.gather(
102 | # 1, selected.unsqueeze(-1)
103 | # ).data.any(), "Decode greedy: infeasible action has maximum probability"
104 |
105 | elif self.decode_type == "sampling":
106 | selected = p.exp().multinomial(1).squeeze(1)
107 | # Check if sampling went OK, can go wrong due to bug on GPU
108 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232
109 | # while mask.gather(1, selected.unsqueeze(-1)).data.any():
110 | # print("Sampled bad values, resampling!")
111 | # selected = p.exp().multinomial(1).squeeze(1)
112 |
113 | else:
114 | assert False, "Unknown decode type"
115 | return selected, p
116 |
117 | def set_decode_type(self, decode_type, temp=None):
118 | self.decode_type = decode_type
119 | if temp is not None: # Do not change temperature if not provided
120 | self.temp = temp
121 |
--------------------------------------------------------------------------------
/data/MovieLense/users.txt:
--------------------------------------------------------------------------------
1 | 98,F,35,7,33547
2 | 4628,F,56,6,23225
3 | 250,M,35,16,11229
4 | 3530,M,25,0,77058
5 | 4749,M,45,0,23462
6 | 3021,F,45,15,78750
7 | 4192,M,35,17,77521
8 | 2530,M,18,0,42420
9 | 932,F,35,6,97838
10 | 5380,M,18,4,01125
11 | 2884,F,50,3,53222
12 | 3273,M,35,7,11201
13 | 217,M,18,4,22903
14 | 5012,M,50,15,48103
15 | 4419,F,45,12,15801
16 | 5828,M,25,17,90048
17 | 311,M,18,4,31201
18 | 2292,M,50,15,13039
19 | 4068,M,56,3,89185
20 | 1664,F,18,4,76707
21 | 1406,M,25,4,91911
22 | 572,M,18,4,61801
23 | 3222,M,35,7,10583
24 | 4398,F,18,4,43612
25 | 4393,M,18,4,43402
26 | 5533,F,50,2,10025
27 | 3291,F,25,7,27516
28 | 1993,M,35,17,99508
29 | 2061,M,56,1,44107
30 | 5725,M,18,2,10018
31 | 1351,F,25,17,98034
32 | 2111,M,56,13,32806
33 | 665,M,18,2,13317
34 | 2584,F,56,1,94563
35 | 761,M,18,7,99945
36 | 4332,F,25,3,60606
37 | 5168,M,35,7,01720
38 | 6038,F,56,1,14706
39 | 1452,F,56,13,90732
40 | 708,M,25,0,37042
41 | 703,F,56,1,44074
42 | 5314,F,25,0,94107
43 | 5145,M,35,7,77565-2332
44 | 3376,M,35,16,10803
45 | 2673,M,35,7,77040
46 | 2037,M,56,14,02468
47 | 2696,M,25,7,24210
48 | 160,M,35,7,13021
49 | 2160,F,18,19,46556
50 | 4775,M,25,17,23462
51 | 2714,M,18,4,96815
52 | 5258,M,35,4,58369
53 | 2133,F,1,10,01607
54 | 2759,M,25,7,36102
55 | 5216,M,18,2,81291
56 | 5215,F,56,6,91941
57 | 5027,F,25,14,05345
58 | 3969,M,56,7,12345
59 | 3633,M,35,18,60441
60 | 4463,M,50,7,07405
61 | 947,M,18,0,21015
62 | 2128,F,45,1,46737
63 | 3407,F,45,5,60050
64 | 2819,M,45,14,44122
65 | 4244,M,35,17,48146
66 | 5525,F,1,10,55311
67 | 341,F,56,13,92119
68 | 4383,F,18,4,93940
69 | 4527,M,18,0,38111
70 | 4525,M,25,14,10021
71 | 421,F,45,3,55125
72 | 3787,F,25,0,94703
73 | 4991,F,56,14,91791
74 | 2488,F,45,2,31804
75 | 4178,F,50,16,08033
76 | 5904,F,45,12,954025
77 | 4365,F,50,16,46350
78 | 2381,F,35,9,13066
79 | 1844,F,18,4,53711
80 | 1493,F,25,14,60657
81 | 4547,M,18,12,02115
82 | 3552,F,35,2,02879
83 | 5309,M,25,0,02478
84 | 3883,M,50,16,78411
85 | 2930,M,50,17,45420
86 | 5159,M,35,7,37027
87 | 2686,M,25,0,85283
88 | 94,M,25,17,28601
89 | 1549,M,35,6,20901
90 | 5052,M,35,16,02536
91 | 5863,F,25,14,89511
92 | 5804,F,35,7,94306
93 | 4880,F,45,9,95030
94 | 3642,F,25,7,98188
95 | 3068,F,35,9,66204
96 | 1305,M,45,20,94121
97 | 1307,M,56,13,75703
98 | 653,M,56,13,92660
99 | 4230,M,56,13,34235
100 | 318,F,56,13,55104
101 | 317,M,35,7,38555
102 | 4065,F,45,5,48341
103 | 951,M,45,2,10009
104 | 3388,F,25,1,92646
105 | 833,M,35,7,46825
106 | 574,M,25,19,90214
107 | 356,M,56,20,55101
108 | 3225,M,18,0,15237
109 | 559,F,25,7,60422
110 | 4259,M,35,17,45380
111 | 5899,M,35,17,30024
112 | 2422,F,50,9,02865
113 | 3324,M,35,17,48083
114 | 535,M,35,6,95370
115 | 1431,M,56,6,91011
116 | 1488,M,25,7,94538
117 | 1226,M,18,20,07087
118 | 2234,M,56,1,85718
119 | 3495,M,18,4,94608
120 | 4558,M,35,14,48217
121 | 4554,F,25,9,48306
122 | 2373,M,45,7,94109
123 | 4943,F,50,0,55401
124 | 3133,M,18,4,33720
125 | 3897,M,56,17,92069
126 | 6034,M,25,14,94117
127 | 2908,M,18,2,74403
128 | 1459,F,25,0,80027
129 | 2923,M,35,12,44256
130 | 3176,M,25,17,95129
131 | 1713,M,25,0,47421
132 | 1751,M,35,20,77024
133 | 1282,M,50,1,98541
134 | 1363,M,50,7,01915
135 | 383,F,25,7,78757
136 | 226,M,35,1,94518
137 | 2502,M,56,13,49423
138 | 1250,M,50,0,10573
139 | 1126,M,56,16,13031
140 | 5810,M,50,16,95070
141 | 4614,F,45,1,10469
142 | 821,M,18,2,14456
143 | 2527,M,35,15,36109
144 | 2893,M,25,1,94025
145 | 4,M,45,7,02460
146 | 906,M,1,10,71106
147 | 907,F,45,6,92117
148 | 1318,M,18,16,94014
149 | 3616,M,45,0,33759
150 | 1534,M,50,12,12184
151 | 3234,F,1,0,90012
152 | 1615,F,45,7,10022
153 | 640,M,18,4,47406
154 | 4464,F,50,1,62052
155 | 5409,M,45,1,60091
156 | 4264,M,45,1,60102
157 | 320,M,35,6,99516
158 | 4871,M,18,12,04096
159 | 3188,M,56,13,55344
160 | 5529,M,18,6,05273
161 | 2207,M,18,4,07508
162 | 2336,M,35,18,13146
163 | 4992,F,25,2,55106
164 | 3275,M,35,7,10028
165 | 3488,M,45,17,80228
166 | 1801,F,25,6,54729
167 | 4548,M,18,2,27514
168 | 1070,F,45,2,20036
169 | 2977,F,18,4,48104
170 | 5117,F,25,0,70118
171 | 5116,M,35,6,95521
172 | 1783,M,18,4,10027
173 | 4110,F,45,7,41722
174 | 2640,M,35,12,78230
175 | 2644,M,18,15,45215
176 | 5174,M,45,1,45208
177 | 89,F,56,9,85749
178 | 277,F,35,1,98126
179 | 4665,M,35,2,03431
180 | 2516,F,45,6,37909
181 | 1157,M,45,7,30253
182 | 4741,M,35,7,15203
183 | 158,M,56,7,28754
184 | 1708,M,25,0,64106
185 | 1927,M,35,7,94806
186 | 4881,M,45,16,55406
187 | 1582,M,18,0,10705
188 | 213,F,18,4,01609
189 | 197,M,18,14,10023
190 | 315,F,56,1,55105
191 | 2295,M,45,11,72316
192 | 5207,M,25,0,98102
193 | 3623,M,35,11,43209
194 | 2467,M,1,10,44224
195 | 5439,M,56,13,19119
196 | 1621,M,25,4,11423
197 | 417,F,25,0,50613
198 | 1384,M,18,10,98058
199 | 1385,F,25,5,90262
200 | 5465,F,25,0,10019
201 |
--------------------------------------------------------------------------------
/policy/greedy_matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch_geometric.utils import subgraph, to_networkx
4 | from networkx.algorithms.matching import max_weight_matching
5 | from torch_geometric.data import Data
6 |
7 |
8 | class GreedyMatching(nn.Module):
9 | def __init__(
10 | self,
11 | embedding_dim,
12 | hidden_dim,
13 | problem,
14 | opts,
15 | tanh_clipping=None,
16 | mask_inner=None,
17 | mask_logits=None,
18 | n_encode_layers=None,
19 | normalization=None,
20 | checkpoint_encoder=False,
21 | shrink_size=None,
22 | num_actions=None,
23 | n_heads=None,
24 | encoder=None,
25 | ):
26 | super(GreedyMatching, self).__init__()
27 | self.decode_type = None
28 | self.problem = problem
29 | self.model_name = "greedy-m"
30 |
31 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
32 | state = self.problem.make_state(x, opts.u_size, opts.v_size, opts)
33 | t = opts.threshold
34 | sequences = []
35 | batch_size = opts.batch_size
36 | graph_size = opts.u_size + opts.v_size + 1
37 | i = 1
38 | while not (state.all_finished()):
39 | step_size = state.i + 1
40 | mask = state.get_mask()
41 | state.get_current_weights(mask).clone()
42 | mask = state.get_mask()
43 | if i <= int(
44 | t * opts.v_size
45 | ): # Skip matching if less than t fraction of V nodes arrived
46 | selected = torch.zeros(
47 | batch_size, dtype=torch.int64, device=opts.device
48 | )
49 | state = state.update(selected[:, None])
50 | sequences.append(selected)
51 | i += 1
52 | continue
53 | nodes = torch.cat(
54 | (
55 | torch.arange(1, opts.u_size + 1, device=opts.device),
56 | state.idx[:i] + opts.u_size + 1,
57 | )
58 | )
59 | subgraphs = (
60 | (nodes.unsqueeze(0).expand(batch_size, step_size - 1))
61 | + torch.arange(
62 | 0, batch_size * graph_size, graph_size, device=opts.device
63 | ).unsqueeze(1)
64 | ).flatten() # The nodes of the current subgraphs
65 | graph_weights = state.get_graph_weights()
66 | edge_i, weights = subgraph(
67 | subgraphs,
68 | state.graphs.edge_index,
69 | graph_weights.unsqueeze(1),
70 | relabel_nodes=True,
71 | )
72 | match_sol = torch.tensor(
73 | list(
74 | max_weight_matching(
75 | to_networkx(
76 | Data(subgraphs, edge_index=edge_i, edge_attr=weights),
77 | to_undirected=True,
78 | )
79 | )
80 | ),
81 | device=opts.device,
82 | )
83 |
84 | edges_sol = torch.cat(
85 | (
86 | torch.arange(0, opts.u_size, device=opts.device)
87 | .unsqueeze(0)
88 | .expand(batch_size, -1)
89 | .unsqueeze(2),
90 | torch.ones(batch_size, opts.u_size, 1, device=opts.device) * (opts.u_size + i - 1),
91 | ),
92 | dim=2,
93 | )
94 | match_sol = match_sol.sort(dim=-1)[0]
95 | offset = torch.arange(0, batch_size * (opts.u_size + i), opts.u_size + i, device=opts.device)[
96 | :, None, None
97 | ]
98 | edges_sol = edges_sol + offset
99 | in_sol = edges_sol.unsqueeze(1) == match_sol[None, :, :, None].expand(
100 | batch_size, -1, -1, opts.u_size
101 | ).transpose(-1, -2)
102 | in_sol = in_sol * (1 - mask[:, None, 1:, None])
103 | in_sol = (
104 | in_sol.prod(-1).sum(-1).float().unsqueeze(1)
105 | * match_sol[None, :, :].transpose(1, 2)
106 | ).sum(-1)
107 | skip_sol = in_sol.sum(-1)
108 | in_sol = in_sol[:, 0] - offset.reshape(-1)
109 | in_sol[skip_sol == 0] = -1
110 | selected = (in_sol + 1).type(torch.int64)
111 | # print(selected, match_sol)
112 | state = state.update(selected[:, None])
113 |
114 | sequences.append(selected)
115 | i += 1
116 | if return_pi:
117 | return -state.size, None, torch.stack(sequences, 1), None
118 | return -state.size, torch.stack(sequences, 1), None
119 |
120 | def set_decode_type(self, decode_type, temp=None):
121 | self.decode_type = decode_type
122 | if temp is not None: # Do not change temperature if not provided
123 | self.temp = temp
124 |
--------------------------------------------------------------------------------
/policy/ff_model_invariant.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class InvariantFF(nn.Module):
6 | def __init__(
7 | self,
8 | embedding_dim,
9 | hidden_dim,
10 | problem,
11 | opts,
12 | tanh_clipping=None,
13 | mask_inner=None,
14 | mask_logits=None,
15 | n_encode_layers=None,
16 | normalization="batch",
17 | checkpoint_encoder=False,
18 | shrink_size=None,
19 | num_actions=4,
20 | n_heads=None,
21 | encoder=None,
22 | ):
23 |
24 | super(InvariantFF, self).__init__()
25 |
26 | self.embedding_dim = embedding_dim
27 | self.decode_type = None
28 | self.num_actions = 3 if opts.problem != "adwords" else 5
29 | self.is_bipartite = problem.NAME == "bipartite"
30 | self.problem = problem
31 | self.shrink_size = None
32 | self.ff = nn.Sequential(
33 | nn.Linear(self.num_actions, 100),
34 | nn.ReLU(),
35 | nn.Linear(100, 100),
36 | nn.ReLU(),
37 | nn.Linear(100, 1),
38 | )
39 | self.model_name = "inv-ff"
40 |
41 | # def init_weights(m):
42 | # if type(m) == nn.Linear:
43 | # torch.nn.init.xavier_uniform_(m.weight)
44 | # m.bias.data.fill_(0.0001)
45 |
46 | # self.ff.apply(init_weights)
47 |
48 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
49 |
50 | _log_p, pi, cost = self._inner(x, opts)
51 |
52 | # cost, mask = self.problem.get_costs(input, pi)
53 | # Log likelihood is calculated within the model since returning it per action does not work well with
54 | # DataParallel, as sequences can be of different lengths
55 | ll, e = self._calc_log_likelihood(_log_p, pi, None)
56 | if return_pi:
57 | return -cost, ll, pi, e
58 | # print(ll)
59 | return -cost, ll, e
60 |
61 | def _calc_log_likelihood(self, _log_p, a, mask):
62 |
63 | # Get log_p corresponding to selected actions
64 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)
65 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean()
66 | # Optional: mask out actions irrelevant to objective so they do not get reinforced
67 | if mask is not None:
68 | log_p[mask] = 0
69 | if not (log_p > -10000).data.all():
70 | print(log_p)
71 | assert (
72 | log_p > -10000
73 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!"
74 |
75 | # Calculate log_likelihood
76 | return log_p.sum(1), entropy
77 |
78 | def _inner(self, input, opts):
79 |
80 | outputs = []
81 | sequences = []
82 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts)
83 |
84 | i = 1
85 | while not (state.all_finished()):
86 | mask = state.get_mask()
87 | state.get_current_weights(mask)
88 | s, mask = state.get_curr_state(self.model_name)
89 |
90 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1)
91 | # Select the indices of the next nodes in the sequences, result (batch_size) long
92 | selected, p = self._select_node(pi, mask.bool())
93 |
94 | state = state.update((selected)[:, None])
95 | outputs.append(p)
96 | sequences.append(selected)
97 | i += 1
98 | # Collected lists, return Tensor
99 | return (
100 | torch.stack(outputs, 1),
101 | torch.stack(sequences, 1),
102 | state.size,
103 | )
104 |
105 | def _select_node(self, probs, mask):
106 | assert (probs == probs).all(), "Probs should not contain any nans"
107 | probs[mask] = -1e6
108 | p = torch.log_softmax(probs, dim=1)
109 | # print(p)
110 | if self.decode_type == "greedy":
111 | _, selected = p.max(1)
112 | # assert not mask.gather(
113 | # 1, selected.unsqueeze(-1)
114 | # ).data.any(), "Decode greedy: infeasible action has maximum probability"
115 |
116 | elif self.decode_type == "sampling":
117 | selected = p.exp().multinomial(1).squeeze(1)
118 | # Check if sampling went OK, can go wrong due to bug on GPU
119 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232
120 | # while mask.gather(1, selected.unsqueeze(-1)).data.any():
121 | # print("Sampled bad values, resampling!")
122 | # selected = probs.multinomial(1).squeeze(1)
123 |
124 | else:
125 | assert False, "Unknown decode type"
126 | return selected, p
127 |
128 | def set_decode_type(self, decode_type, temp=None):
129 | self.decode_type = decode_type
130 | if temp is not None: # Do not change temperature if not provided
131 | self.temp = temp
132 |
--------------------------------------------------------------------------------
/TUTORIAL.md:
--------------------------------------------------------------------------------
1 | # Tutorial
2 |
3 | Here we provide a brief tutorial for our [codebase](https://github.com/lyeskhalil/CORL). If you are already familiar with the code and the framework, feel free to skip this tutorial.
4 |
5 | ## 1. Bipartite Data Generation Code
6 |
7 |
8 | ### Supported Args
9 | ```
10 | problem: "adwords", "e-obm", "osbm"
11 | graph_family: "er", "ba", "triangular", "thick-z", "movielense", "gmission"
12 | weight_distribution: "normal", "power", "degree", "node-normal", "fixed-normal"
13 | weight_distribution_param: ...
14 | graph_family_parameter: ...
15 | ```
16 | Caution: we only list the feasible parameters here; some combinations may not be allowed (e.g., adwords + gmission). Please check before using these arguments.
17 |
18 | ## 2. NetworkX Guide
19 | ### Print information for a graph
20 | ```python
21 | print(G)
22 | #Graph named 'triangular_graph(3,9,-1.0)' with 12 nodes and 18 edges
23 | print(G.graph)
24 | #{'name': 'triangular_graph(3,9,-1.0)'}
25 | print(G.nodes)
26 | #[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
27 | print(G.nodes[1]) #{'bipartite': 0}
28 | print(G.edges)
29 | #[(0, 3), (0, 4), (0, 5), (1, 3), ..., (2, 6), (2, 7), (2, 8), (2, 9), (2, 10), (2, 11)]
30 | print(G.edges[0,3])
31 | #{'weight': 0.3958830486584797}
32 | ```
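
For convenience, here is a small sketch (not part of the project code) of how to read all edge weights at once with NetworkX; `G` is assumed to be the same weighted bipartite graph printed above.

```python
# Iterate over all edges together with their 'weight' attribute.
for u, v, w in G.edges(data="weight"):
    print(u, v, w)

# Collect the weights of the edges incident to node 0.
weights_of_node_0 = [d["weight"] for _, _, d in G.edges(0, data=True)]
print(weights_of_node_0)
```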
33 |
34 |
35 | ### Convert data type
36 | Convert the networkx graph to torch
37 | ```python
38 | def from_networkx(G):
39 | r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
40 | :class:`torch_geometric.data.Data` instance"""
41 | pass
42 | ```
43 | Inspect useful fields of the `torch_geometric.data.Data` object
44 | ```python
45 | print(data.edge_index) # Edges between nodes; the graph is stored as directed (two directed edges per undirected edge)
46 | print(data.weight) # Edge weights
47 | print(data.bipartite) # Bipartite partition of each node
48 | print(data.x) # Capacity of each resource
49 | print(data.y) # Optimal solution, (total_gain, action_array)
50 | print(data.edge_attr) # None for this project
51 | ```
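
Below is a minimal, self-contained sketch of this conversion, assuming only that `networkx` and `torch_geometric` are installed; the toy graph and attribute values are illustrative and not taken from the project's data pipeline. It uses `torch_geometric.utils.from_networkx`, which collects the shared node/edge attributes (here `bipartite` and `weight`) into fields of the resulting `Data` object.

```python
import networkx as nx
from torch_geometric.utils import from_networkx

# Toy weighted bipartite graph with the same attribute names used above.
G = nx.Graph(name="toy_graph")
G.add_nodes_from([0, 1], bipartite=0)  # offline (U) side
G.add_nodes_from([2, 3], bipartite=1)  # online (V) side
G.add_edge(0, 2, weight=0.3)
G.add_edge(1, 3, weight=0.7)

data = from_networkx(G)
print(data.edge_index)  # each undirected edge appears as two directed edges
print(data.weight)      # per-edge weights
print(data.bipartite)   # per-node partition indicator
```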
52 |
53 | ## 3. Important Arguments for the OBM codebase
54 | ```python
55 | 'load_path': 'path_to_pretrained_model', # Model path
56 | 'checkpoint_epochs': 0, # The frequency of routine checkpoint saving
57 | 'model': 'inv-ff-hist', # Model type
58 | ```
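
If you want to peek inside a saved model file before passing it via `load_path`, a plain `torch.load` call is enough. This is only a hypothetical inspection sketch: the path comes from the evaluation example later in this tutorial, and the structure of the loaded object depends on how training saved it, so we only print it rather than assume any keys.

```python
import torch

# Hypothetical inspection of a checkpoint produced by training.
ckpt = torch.load(
    "saved_models/inv-ff-hist/run_20220501T195416/best-model.pt",
    map_location="cpu",
)
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))
```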
59 |
60 |
61 | ## 4. RL Model For OBM
62 | ### Request a GPU node on a SLURM cluster (skip this step on your own server)
63 | ``` shell
64 | srun -p gpu --gres=gpu:1 --mem=100g --nodelist=gpu02 --time=48:00:00 --pty bash -l
65 | ```
66 |
67 | ### Quick Start - Data Generation
68 | ```shell
69 | python data/generate_data.py --problem adwords --dataset_size 1000 \
70 | --dataset_folder dataset/train/adwords_triangular_uniform_0.10.4_10by100/parameter_-1 \
71 | --u_size 10 --v_size 60 --graph_family triangular --weight_distribution uniform \
72 | --weight_distribution_param 0.1 0.4 --graph_family_parameter -1
73 |
74 | python data/generate_data.py --problem adwords --dataset_size 1 \
75 | --dataset_folder dataset/train/adwords_triangular_triangular_0.10.4_10by100/parameter_-1 \
76 | --u_size 3 --v_size 9 --graph_family ba --weight_distribution uniform \
77 | --weight_distribution_param 0 1 --graph_family_parameter 2
78 | ```
79 |
80 | ### Quick Start - Training
81 | ```shell
82 | python run.py --encoder mpnn --model inv-ff-hist --problem adwords --batch_size 100 --embedding_dim 30 --n_heads 1 --u_size 10 --v_size 60 \
83 | --n_epochs 20 --train_dataset dataset/train/adwords_triangular_uniform_0.10.4_10by60/parameter_-1 \
84 | --val_dataset dataset/val/adwords_triangular_uniform_0.10.4_10by60/parameter_-1 \
85 | --dataset_size 1000 --val_size 100 --checkpoint_epochs 0 --baseline exponential --lr_model 0.006 --lr_decay 0.97 \
86 | --output_dir saved_models --log_dir logs_dataset --n_encode_layers 1 \
87 | --save_dir saved_models/adwords_triangular_uniform_0.10.4_10by60/parameter_-1 \
88 | --graph_family_parameter -1 --exp_beta 0.8 --ent_rate 0.0006
89 | ```
90 |
91 | ### Quick Start - Evaluation
92 | * Note: Please train the model first. If you already have a trained model, you can use the following command to validate its performance. Don't forget to replace `load_path` and the other arguments accordingly.
93 | ```shell
94 | python3 run_lite.py --encoder mpnn --model inv-ff-hist-switch --problem osbm --batch_size 100 --embedding_dim 30 --n_heads 1 --u_size 10 --v_size 60 --n_epochs 300 --train_dataset dataset/train/osbm_movielense_default_-1_10by60/parameter_-1 --val_dataset dataset/val/osbm_movielense_default_-1_10by60/parameter_-1 --dataset_size 20000 --val_size 1000 --checkpoint_epochs 0 --baseline exponential --lr_model 0.006 --lr_decay 0.97 --output_dir saved_models --log_dir logs_dataset --n_encode_layers 1 --save_dir saved_models/osbm_movielense_default_-1_10by60/parameter_-1 --graph_family_parameter -1 --exp_beta 0.8 --ent_rate 0.0006 --eval_only --no_tensorboard --load_path saved_models/inv-ff-hist/run_20220501T195416/best-model.pt --switch_lambda 0.0 --slackness 0.0 --max_reward 5.0
95 | ```
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
/policy/ff_model_hist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 |
5 |
6 | class FeedForwardModelHist(nn.Module):
7 | def __init__(
8 | self,
9 | embedding_dim,
10 | hidden_dim,
11 | problem,
12 | opts,
13 | tanh_clipping=None,
14 | mask_inner=None,
15 | mask_logits=None,
16 | n_encode_layers=None,
17 | normalization="batch",
18 | checkpoint_encoder=False,
19 | shrink_size=None,
20 | num_actions=4,
21 | n_heads=None,
22 | encoder=None,
23 | ):
24 |
25 | super(FeedForwardModelHist, self).__init__()
26 |
27 | self.embedding_dim = embedding_dim
28 | self.decode_type = None
29 | self.num_actions = (
30 | 5 * (opts.u_size + 1) + 8
31 | if opts.problem != "adwords"
32 | else 7 * (opts.u_size + 1) + 8
33 | )
34 | self.problem = problem
35 | self.model_name = "ff-hist"
36 | hidden_size = 100
37 | self.ff = nn.Sequential(
38 | nn.Linear(self.num_actions, hidden_size),
39 | nn.ReLU(),
40 | nn.Linear(hidden_size, hidden_size),
41 | nn.ReLU(),
42 | nn.Linear(hidden_size, hidden_size),
43 | nn.ReLU(),
44 | nn.Linear(hidden_size, opts.u_size + 1),
45 | )
46 |
47 | # self.ff.apply(init_weights)
48 | # self.init_parameters()
49 |
50 | def init_parameters(self):
51 | for name, param in self.named_parameters():
52 | stdv = 1.0 / math.sqrt(param.size(-1))
53 | param.data.uniform_(-stdv, stdv)
54 |
55 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
56 |
57 | _log_p, pi, cost = self._inner(x, opts)
58 |
59 | # cost, mask = self.problem.get_costs(input, pi)
60 | # Log likelihood is calculated within the model since returning it per action does not work well with
61 | # DataParallel, as sequences can be of different lengths
62 | ll, e = self._calc_log_likelihood(_log_p, pi, None)
63 | if return_pi:
64 | return -cost, ll, pi, e
65 | # print(ll)
66 | return -cost, ll, e
67 |
68 | def _calc_log_likelihood(self, _log_p, a, mask):
69 |
70 | # Get log_p corresponding to selected actions
71 | # print(a[0, :])
72 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean()
73 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)
74 |
75 | # Optional: mask out actions irrelevant to objective so they do not get reinforced
76 | if mask is not None:
77 | log_p[mask] = 0
78 | if not (log_p > -10000).data.all():
79 | print(log_p)
80 | assert (
81 | log_p > -10000
82 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!"
83 |
84 | # Calculate log_likelihood
85 | # print(_log_p)
86 | return log_p.sum(1), entropy
87 |
88 | def _inner(self, input, opts):
89 |
90 | outputs = []
91 | sequences = []
92 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts)
93 |
94 | i = 1.0
95 | while not (state.all_finished()):
96 | mask = state.get_mask()
97 | state.get_current_weights(mask)
98 |
99 | s, mask = state.get_curr_state(self.model_name)
100 | pi = self.ff(s)
101 | # Select the indices of the next nodes in the sequences, result (batch_size) long
102 | selected, p = self._select_node(pi, mask.bool())
103 | state = state.update((selected)[:, None])
104 | outputs.append(p)
105 | sequences.append(selected)
106 | i += 1.0
107 | # Collected lists, return Tensor
108 | return (
109 | torch.stack(outputs, 1),
110 | torch.stack(sequences, 1),
111 | state.size,
112 | )
113 |
114 | def _select_node(self, probs, mask):
115 | assert (probs == probs).all(), "Probs should not contain any nans"
116 |
117 | probs[mask] = -1e8
118 | p = torch.log_softmax(probs, dim=1)
119 | if self.decode_type == "greedy":
120 | _, selected = p.max(1)
121 | # assert not mask.gather(
122 | # 1, selected.unsqueeze(-1)
123 | # ).data.any(), "Decode greedy: infeasible action has maximum probability"
124 |
125 | elif self.decode_type == "sampling":
126 | selected = p.exp().multinomial(1).squeeze(1)
127 | # Check if sampling went OK, can go wrong due to bug on GPU
128 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232
129 | while mask.gather(1, selected.unsqueeze(-1)).data.any():
130 | print("Sampled bad values, resampling!")
131 | selected = p.exp().multinomial(1).squeeze(1)
132 |
133 | else:
134 | assert False, "Unknown decode type"
135 | return selected, p
136 |
137 | def set_decode_type(self, decode_type, temp=None):
138 | self.decode_type = decode_type
139 | if temp is not None: # Do not change temperature if not provided
140 | self.temp = temp
141 |
--------------------------------------------------------------------------------
/policy/ff_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class FeedForwardModel(nn.Module):
6 | def __init__(
7 | self,
8 | embedding_dim,
9 | hidden_dim,
10 | problem,
11 | opts,
12 | tanh_clipping=None,
13 | mask_inner=None,
14 | mask_logits=None,
15 | n_encode_layers=None,
16 | normalization="batch",
17 | checkpoint_encoder=False,
18 | shrink_size=None,
19 | num_actions=4,
20 | n_heads=None,
21 | encoder=None,
22 | ):
23 |
24 | super(FeedForwardModel, self).__init__()
25 |
26 | self.embedding_dim = embedding_dim
27 | self.decode_type = None
28 | self.num_actions = (
29 | 2 * (opts.u_size + 1)
30 | if opts.problem != "adwords"
31 | else 3 * (opts.u_size + 1)
32 | )
33 | self.is_bipartite = problem.NAME == "bipartite"
34 | self.problem = problem
35 | self.shrink_size = None
36 | hidden_size = 100
37 | self.ff = nn.Sequential(
38 | nn.Linear(self.num_actions, hidden_size),
39 | nn.ReLU(),
40 | nn.Linear(hidden_size, hidden_size),
41 | nn.ReLU(),
42 | nn.Linear(hidden_size, hidden_size),
43 | nn.ReLU(),
44 | nn.Linear(hidden_size, opts.u_size + 1),
45 | )
46 | self.model_name = "ff"
47 |
48 | def init_weights(m):
49 | if type(m) == nn.Linear:
50 | torch.nn.init.xavier_uniform_(m.weight)
51 | m.bias.data.fill_(0.0001)
52 |
53 | # self.ff.apply(init_weights)
54 |
55 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
56 |
57 | _log_p, pi, cost = self._inner(x, opts)
58 |
59 | # cost, mask = self.problem.get_costs(input, pi)
60 | # Log likelihood is calculated within the model since returning it per action does not work well with
61 | # DataParallel, as sequences can be of different lengths
62 | ll, e = self._calc_log_likelihood(_log_p, pi, None)
63 | if return_pi:
64 | return -cost, ll, pi, e
65 | # print(ll)
66 | return -cost, ll, e
67 |
68 | def _calc_log_likelihood(self, _log_p, a, mask):
69 |
70 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean()
71 | # Get log_p corresponding to selected actions
72 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)
73 |
74 | # Optional: mask out actions irrelevant to objective so they do not get reinforced
75 | if mask is not None:
76 | log_p[mask] = 0
77 | if not (log_p > -10000).data.all():
78 | print(log_p.nonzero())
79 | assert (
80 | log_p > -10000
81 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!"
82 |
83 | # Calculate log_likelihood
84 | # print(log_p.sum(1))
85 | return log_p.sum(1), entropy
86 |
87 | def _inner(self, input, opts):
88 |
89 | outputs = []
90 | sequences = []
91 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts)
92 |
93 | # step_context = 0
94 | # batch_size = state.ids.size(0)
95 | # Perform decoding steps
96 | i = 1
97 | # entropy = 0
98 | while not (state.all_finished()):
99 | mask = state.get_mask()
100 | state.get_current_weights(mask)
101 | s, mask = state.get_curr_state(self.model_name)
102 | # s = w
103 | pi = self.ff(s)
104 | # Select the indices of the next nodes in the sequences, result (batch_size) long
105 | selected, p = self._select_node(
106 | pi, mask.bool()
107 | ) # Squeeze out steps dimension
108 | # entropy += torch.sum(p * (p.log()), dim=1)
109 | state = state.update((selected)[:, None])
110 | outputs.append(p)
111 | sequences.append(selected)
112 | i += 1
113 | # Collected lists, return Tensor
114 | return (
115 | torch.stack(outputs, 1),
116 | torch.stack(sequences, 1),
117 | state.size,
118 | )
119 |
120 | def _select_node(self, probs, mask):
121 | assert (probs == probs).all(), "Probs should not contain any nans"
122 | probs[mask] = -1e6
123 | p = torch.log_softmax(probs, dim=1)
124 | if self.decode_type == "greedy":
125 | _, selected = p.max(1)
126 | # assert not mask.gather(
127 | # 1, selected.unsqueeze(-1)
128 | # ).data.any(), "Decode greedy: infeasible action has maximum probability"
129 |
130 | elif self.decode_type == "sampling":
131 | selected = p.exp().multinomial(1).squeeze(1)
132 | # Check if sampling went OK, can go wrong due to bug on GPU
133 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232
134 | # while mask.gather(1, selected.unsqueeze(-1)).data.any():
135 | # print("Sampled bad values, resampling!")
136 | # selected = probs.multinomial(1).squeeze(1)
137 |
138 | else:
139 | assert False, "Unknown decode type"
140 | return selected, p
141 |
142 | def set_decode_type(self, decode_type, temp=None):
143 | self.decode_type = decode_type
144 | if temp is not None: # Do not change temperature if not provided
145 | self.temp = temp
146 |
--------------------------------------------------------------------------------
/problem_state/obm_env.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import NamedTuple
3 |
4 | # from utils.boolmask import mask_long2bool, mask_long_scatter
5 |
6 |
7 | class StateBipartite(NamedTuple):
8 | # Fixed input
9 | graphs: torch.Tensor # full adjacency matrix of all graphs in a batch
10 | # adj: torch.Tensor # full adjacency matrix of all graphs in a batch
11 | weights: torch.Tensor
12 | u_size: torch.Tensor
13 | v_size: torch.Tensor
14 | batch_size: torch.Tensor
15 | # If this state contains multiple copies (i.e. beam search) of the same instance, then for memory efficiency
16 | # the fixed input tensors are not kept multiple times, so we need the ids to index the correct rows.
17 | ids: torch.Tensor # Keeps track of original fixed data index of rows
18 |
19 | # State
20 | # curr_edge: torch.Tensor # current edge number
21 | matched_nodes: torch.Tensor # Keeps track of nodes that have been matched
22 | # picked_edges: torch.Tensor
23 | size: torch.Tensor # size of current matching
24 | i: torch.Tensor # Keeps track of step
25 | # mask: torch.Tensor # mask for each step
26 |
27 | @property
28 | def visited(self):
29 | if self.matched_nodes.dtype == torch.uint8:  # visited mask is stored as uint8
30 | return self.matched_nodes
31 |
32 | def __getitem__(self, key):
33 | if torch.is_tensor(key) or isinstance(
34 | key, slice
35 | ): # If tensor, idx all tensors by this tensor:
36 | return self._replace(
37 | ids=self.ids[key],
38 | graphs=self.graphs[key],
39 | matched_nodes=self.matched_nodes[key],
40 | size=self.size[key],
41 | u_size=self.u_size[key],
42 | v_size=self.v_size[key],
43 | )
44 | # Fall back to normal tuple indexing for non-tensor keys (avoids infinite recursion)
45 | return super(StateBipartite, self).__getitem__(key)
46 |
47 | @staticmethod
48 | def initialize(
49 | input, u_size, v_size, num_edges, visited_dtype=torch.uint8,
50 | ):
51 |
52 | batch_size = len(input)
53 | # size = torch.zeros(batch_size, 1, dtype=torch.long, device=graphs.device)
54 | return StateBipartite(
55 | graphs=input,
56 | u_size=torch.tensor([u_size], device=input.device),
57 | v_size=torch.tensor([v_size], device=input.device),
58 | weights=None,
59 | batch_size=torch.tensor([batch_size], device=input.device),
60 | ids=torch.arange(batch_size, dtype=torch.int64, device=input.device)[
61 | :, None
62 | ], # Add steps dimension
63 | # Keep an entry for index 0 (the "no match" node) so we can scatter efficiently
64 | matched_nodes=( # Storing visited nodes as a mask is easier to understand and more memory efficient
65 | torch.zeros(
66 | batch_size, 1, u_size + 1, dtype=torch.uint8, device=input.device
67 | )
68 | ),
69 | size=torch.zeros(batch_size, 1, device=input.device),
70 | i=torch.ones(1, dtype=torch.int64, device=input.device)
71 | * (u_size + 1), # Vector with length num_steps
72 | )
73 |
74 | def get_final_cost(self):
75 |
76 | assert self.all_finished()
77 | # assert self.visited_.
78 |
79 | return self.size
80 |
81 | def update(self, selected):
82 | # Update the state
83 | nodes = self.matched_nodes.squeeze(1).scatter_(-1, selected, 1)
84 |
85 | total_weights = self.size + (selected != 0).long()
86 |
87 | return self._replace(matched_nodes=nodes, size=total_weights, i=self.i + 1,)
88 |
89 | def all_finished(self):
90 | # Exactly n steps
91 | return (self.i.item() - (self.u_size.item() + 1)) >= self.v_size
92 |
93 | def get_current_node(self):
94 | return self.i.item()
95 |
96 | def get_mask(self):
97 | """
98 | Returns a mask vector which includes only nodes in U that can be matched.
99 | That is, neighbors of the incoming node that have not been matched already.
100 | """
101 | mask = self.graphs[:, self.i.item(), : self.u_size.item() + 1]
102 |
103 | self.matched_nodes[
104 | :, 0
105 | ] = 0 # node that represents not being matched to anything can be matched to more than once
106 | return (
107 | self.matched_nodes.squeeze(1) + mask > 0
108 | ).long() # Hacky way to return bool or uint8 depending on pytorch version
109 |
110 | # def get_nn(self, k=None):
111 | # # Insert step dimension
112 | # # Nodes already visited get inf so they do not make it
113 | # if k is None:
114 | # k = self.loc.size(-2) - self.i.item() # Number of remaining
115 | # return (
116 | # self.dist[self.ids, :, :] + self.visited.float()[:, :, None, :] * 1e6
117 | # ).topk(k, dim=-1, largest=False)[1]
118 |
119 | # def get_nn_current(self, k=None):
120 | # assert (
121 | # False
122 | # ), "Currently not implemented, look into which neighbours to use in step 0?"
123 | # # Note: if this is called in step 0, it will have k nearest neighbours to node 0, which may not be desired
124 | # # so it is probably better to use k = None in the first iteration
125 | # if k is None:
126 | # k = self.loc.size(-2)
127 | # k = min(k, self.loc.size(-2) - self.i.item()) # Number of remaining
128 | # return (self.dist[self.ids, self.prev_a] + self.visited.float() * 1e6).topk(
129 | # k, dim=-1, largest=False
130 | # )[1]
131 |
132 | # def construct_solutions(self, actions):
133 | # return actions
134 |
--------------------------------------------------------------------------------
/IPsolvers/IPsolver.py:
--------------------------------------------------------------------------------
1 | import gurobipy as gp
2 | from gurobipy import GRB
3 | import itertools
4 | import numpy as np
5 |
6 |
7 | def get_data_adwords(u_size, v_size, adjacency_matrix):
8 | """
9 | pre-process the data for groubi for the adwords problem
10 | Reads data from the specfied file and writes the graph tensor into multu dict of the following form:
11 | combinations, ms= gp.multidict({
12 | ('u1','v1'):10,
13 | ('u1','v2'):13,
14 | ('u2','v1'):9,
15 | ('u2','v2'):3
16 | })
17 | """
18 |
19 | adj_dic = {}
20 |
21 | for v, u in itertools.product(range(v_size), range(u_size)):
22 | adj_dic[(v, u)] = adjacency_matrix[u, v]
23 |
24 | return gp.multidict(adj_dic)
25 |
26 |
27 | def get_data_osbm(u_size, v_size, adjacency_matrix, preferences):
28 | """
29 | Pre-process the data for Gurobi for the osbm problem.
30 | """
31 |
32 | adj_dic = {}
33 | w = {}
34 |
35 | for v, u in itertools.product(range(v_size), range(u_size)):
36 | adj_dic[(v, u)] = adjacency_matrix[v][u]
37 | _, dic = gp.multidict(adj_dic)
38 |
39 | for i, j in itertools.product(
40 | range(preferences.shape[0]), range(preferences.shape[1])
41 | ):
42 | w[(j, i)] = preferences[i][j]
43 | return dic, w
44 |
45 |
46 | def solve_adwords(u_size, v_size, adjacency_matrix, budgets):
47 | try:
48 | m = gp.Model("adwords")
49 | # m.Params.LogToConsole = 0
50 | m.Params.timeLimit = 30
51 |
52 | _, dic = get_data_adwords(u_size, v_size, adjacency_matrix)
53 |
54 | # add variable
55 | x = m.addVars(v_size, u_size, vtype="B", name="(u,v) pairs")
56 |
57 | # set constraints
58 | m.addConstrs((x.sum(v, "*") <= 1 for v in range(v_size)), "V")
59 | m.addConstrs((x.prod(dic, "*", u) <= budgets[u] for u in range(u_size)), "U")
60 |
61 | # set the objective
62 | m.setObjective(x.prod(dic), GRB.MAXIMIZE)
63 | m.optimize()
64 |
65 | solution = np.zeros(v_size).tolist()
66 | for v in range(v_size):
67 | u = 0
68 | for nghbr_v in x.select(v, "*"):
69 | if nghbr_v.getAttr("x") == 1:
70 | solution[v] = u + 1
71 | break
72 | u += 1
73 | return m.objVal, solution
74 |
75 | except gp.GurobiError as e:
76 | print("Error code " + str(e.errno) + ": " + str(e))
77 | except AttributeError:
78 | print("Encountered an attribute error")
79 |
80 |
81 | def solve_submodular_matching(
82 | u_size, v_size, adjacency_matrix, r_v, movie_features, preferences, num_incoming
83 | ):
84 | try:
85 | m = gp.Model("submatching")
86 | m.Params.LogToConsole = 0
87 | # 15 is the fixed number of genres from the movielens dataset
88 | genres = 15
89 | dic, weight_dic = get_data_osbm(u_size, v_size, adjacency_matrix, preferences)
90 |
91 | # add variable for each edge (u,v), where v is the user and u is the movie
92 | x = m.addVars(v_size, u_size, vtype="B", name="(u,v) pairs")
93 |
94 | # set the variable to zero for edges that do not exist in the graph
95 | for key in x:
96 | x[key] = 0 if dic[key] == 0 else x[key]
97 |
98 | # create variable for each (genre, user) pair
99 | gamma = m.addVars(genres, v_size, vtype="B", name="gamma")
100 |
101 | # A is a |genres| by |V| matrix containing, at index (z, v), the number of selected edges from user v to movies of genre z
102 | A = m.addVars(genres, v_size)
103 |
104 | # first set all variables in A to zero
105 | for key in A:
106 | A[key] = 0
107 |
108 | for z, v in itertools.product(range(genres), range(v_size)):
109 | for u in range(u_size):
110 | if movie_features[u][z] == 1.0: # if u belongs to genre z
111 | A[(z, v)] += x[(v, u)]
112 |
113 | # set constraints
114 | r_v1 = {}
115 | for i, n in enumerate(r_v):
116 | r_v1[i] = r_v[n]
117 | m.addConstrs((x.sum(v, "*") <= len(r_v1[v]) for v in r_v1), "const1")
118 | m.addConstrs((x.sum("*", u) <= 1 for u in range(u_size)), "const2")
119 | m.addConstrs(
120 | (
121 | gamma[(z, v)] <= A[(z, v)]
122 | for z, v in itertools.product(range(genres), range(v_size))
123 | ),
124 | "const3",
125 | )
126 | m.addConstrs(
127 | (
128 | gamma[(z, v)] <= 1
129 | for z, v in itertools.product(range(genres), range(v_size))
130 | ),
131 | "const4",
132 | )
133 |
134 | # give each gamma variable a weight based on the user preferences and maximize the weighted sum
135 | m.setObjective(gamma.prod(weight_dic), GRB.MAXIMIZE)
136 | m.optimize()
137 | solution = np.zeros(num_incoming).tolist()
138 | sol_dict = dict(x)
139 | s = sorted(sol_dict.keys(), key=(lambda k: k[0]))
140 | matched = 0
141 | for i, p in enumerate(r_v1):
142 | matched = 0
143 | inds = r_v1[p]
144 | idx = i * u_size
145 | nodes = s[idx : idx + u_size]
146 | for j, u in enumerate(nodes):
147 | if (type(x[u]) is not int) and x[u].x == 1:
148 | solution[inds[matched]] = u[1] + 1
149 | matched += 1
150 |
151 | for j in range(len(inds) - matched):
152 | solution[inds[j + matched]] = 0
153 |
154 | return m.objVal, solution
155 |
156 | except gp.GurobiError as e:
157 | print("Error code " + str(e.errno) + ": " + str(e))
158 | except AttributeError:
159 | print("Encountered an attribute error")
160 |
161 |
162 | if __name__ == "__main__":
163 |
164 | # osbm example:
165 |
166 | # {v_id : freq}
167 | # r_v = {0: [1, 2], 1: [0]}
168 |
169 | # # 3 genres (each column), 3 movies (each row)
170 | # movie_features = [
171 | # [0.0, 0.0, 1.0],
172 | # [0.0, 0.0, 1.0],
173 | # [0.0, 1.0, 1.0]
174 | # ]
175 |
176 | # # user preferences V by |genres|
177 | # preferences = np.array([
178 | # [0.999, 0.4, 0.222],
179 | # [1, 1, 1]
180 | # ])
181 |
182 | # adjacency_matrix = np.array([
183 | # [1, 2, 0],
184 | # [0, 1, 0],
185 | # [4, 0, 0]
186 | # ])
187 |
188 | # print(solve_submodular_matching(3, 2, adjacency_matrix, r_v, movie_features, preferences, 3))
189 |
190 | # adwords example:
191 |
192 | # V by U matrix
193 | adjacency_matrix = np.array([[1, 2, 0], [0, 1, 0], [4, 0, 0]])
194 |
195 | budgets = [3, 1, 4]
196 | print(solve_adwords(3, 3, adjacency_matrix, budgets))
197 |
--------------------------------------------------------------------------------
/policy/ff_supervised.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn import DataParallel
5 |
6 |
7 | def set_decode_type(model, decode_type):
8 | if isinstance(model, DataParallel):
9 | model = model.module
10 | model.set_decode_type(decode_type)
11 |
12 |
13 | def get_loss(log_p, y, optimizers, w, opts):
14 | # The cross-entropy loss of v_1, ..., v_{t-1} (this is a batch_size by 1 vector)
15 | # total_loss = torch.zeros(y.shape)
16 |
17 | # Calculate loss of v_t
18 | # print(log_p, y)
19 | loss_t = F.cross_entropy(log_p, y.long(), weight=w)
20 |
21 | # Update the loss for the whole graph
22 | # total_loss += loss_t
23 | loss = loss_t
24 |
25 | # Perform backward pass and optimization step
26 | # optimizers[0].zero_grad()
27 | # loss.backward()
28 | # optimizers[0].step()
29 | return loss
30 |
31 |
32 | class SupervisedFFModel(nn.Module):
33 | def __init__(
34 | self,
35 | embedding_dim,
36 | hidden_dim,
37 | problem,
38 | opts,
39 | tanh_clipping=None,
40 | mask_inner=None,
41 | mask_logits=None,
42 | n_encode_layers=None,
43 | normalization="batch",
44 | checkpoint_encoder=False,
45 | shrink_size=None,
46 | num_actions=4,
47 | n_heads=None,
48 | encoder=None,
49 | ):
50 | super(SupervisedFFModel, self).__init__()
51 |
52 | self.embedding_dim = embedding_dim
53 | self.decode_type = None
54 | self.num_actions = (
55 | 5 * (opts.u_size + 1) + 8
56 | if opts.problem != "adwords"
57 | else 7 * (opts.u_size + 1) + 8
58 | )
59 | self.is_bipartite = problem.NAME == "bipartite"
60 | self.problem = problem
61 | self.shrink_size = None
62 | self.model_name = "ff-supervised"
63 | self.ff = nn.Sequential(
64 | nn.Linear(self.num_actions, 100),
65 | nn.ReLU(),
66 | nn.Linear(100, 100),
67 | nn.ReLU(),
68 | nn.Linear(100, 100),
69 | nn.ReLU(),
70 | nn.Linear(100, opts.u_size + 1),
71 | )
72 |
73 | def forward(
74 | self,
75 | input,
76 | opt_match,
77 | opts,
78 | optimizer,
79 | training=False,
80 | baseline=None,
81 | return_pi=False,
82 | ):
83 | """
84 | :param input: (batch_size, graph_size, node_dim) input node features or dictionary with multiple tensors
85 | :param return_pi: whether to return the output sequences; this is optional as it is not compatible with
86 | using DataParallel, since the results may be of different lengths on different GPUs
87 | :param opt_match: (batch_size, U_size, V_size), the optimal matching of the graphs in the batch
88 | :return:
89 | """
90 | _log_p, pi, cost, batch_loss = self._inner(
91 | input, opt_match, opts, optimizer, training
92 | )
93 |
94 | ll = self._calc_log_likelihood(_log_p, pi, None)
95 |
96 | return -cost, ll, pi, batch_loss
97 |
98 | def _calc_log_likelihood(self, _log_p, a, mask):
99 |
100 | # Get log_p corresponding to selected actions
101 | # print(a[0, :])
102 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean()
103 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)
104 |
105 | # Optional: mask out actions irrelevant to objective so they do not get reinforced
106 | # if mask is not None:
107 | # log_p[mask] = 0
108 | # if not (log_p > -10000).data.all():
109 | # print(log_p)
110 | # assert (
111 | # log_p > -10000
112 | # ).data.all(), "Logprobs should not be -inf, check sampling procedure!"
113 |
114 | # Calculate log_likelihood
115 | # print(_log_p)
116 | return log_p.sum(1), entropy
117 |
118 | def _inner(self, input, opt_match, opts, optimizer, training):
119 |
120 | outputs = []
121 | sequences = []
122 | # losses = []
123 |
124 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts)
125 | i = 1
126 | total_loss = 0
127 | while not (state.all_finished()):
128 | mask = state.get_mask()
129 | w = state.get_current_weights(mask)
130 | s, mask = state.get_curr_state(self.model_name)
131 | # s = w
132 | pi = self.ff(s)
133 | # Select the indices of the next nodes in the sequences, result (batch_size) long
134 | # if training:
135 | # mask = torch.zeros(mask.shape, device=opts.device)
136 | selected, p = self._select_node(
137 | pi,
138 | mask.bool(),
139 | ) # Squeeze out steps dimension
140 | # entropy += torch.sum(p * (p.log()), dim=1)
141 | state = state.update((selected)[:, None])
142 | outputs.append(p)
143 | sequences.append(selected)
144 |
145 | # do backprop if in training mode
146 | if opts.problem != "adwords":
147 | none_node_w = torch.tensor(
148 | [1.0 / (opts.v_size / opts.u_size)], device=opts.device
149 | ).float()
150 | else:
151 | none_node_w = torch.tensor([1.0], device=opts.device).float()
152 | w = torch.cat(
153 | [none_node_w, torch.ones(opts.u_size, device=opts.device).float()],
154 | dim=0,
155 | )
156 | # supervised learning
157 | y = opt_match[:, i - 1]
158 | # print('y: ', y)
159 | # print('selected: ', selected)
160 | loss = get_loss(pi, y, optimizer, w, opts)
161 | # print("Loss: ", loss)
162 | # keep track for logging
163 | total_loss += loss
164 | i += 1
165 | # Collected lists, return Tensor
166 | batch_loss = total_loss / state.v_size
167 | # print(batch_loss)
168 | if optimizer is not None and training:
169 | # print('epoch {} batch loss {}'.format(i, batch_loss))
170 | # print('outputs: ', outputs)
171 | # print('sequences: ', sequences)
172 | # print('optimal solution: ', opt_match)
173 | optimizer[0].zero_grad()
174 | batch_loss.backward()
175 | optimizer[0].step()
176 | return (
177 | torch.stack(outputs, 1),
178 | torch.stack(sequences, 1),
179 | state.size,
180 | batch_loss,
181 | )
182 |
183 | def _select_node(self, probs, mask):
184 | assert (probs == probs).all(), "Probs should not contain any nans"
185 | mask[:, 0] = False
186 | p = probs.clone()
187 | p[
188 | mask
189 | ] = (
190 | -1e6
191 | ) # Masking doesn't really make sense with supervised learning since input samples are independent; we should only mask during testing.
192 | _, selected = p.max(1)
193 | return selected, p
194 |
195 | def set_decode_type(self, decode_type, temp=None):
196 | self.decode_type = decode_type
197 | if temp is not None: # Do not change temperature if not provided
198 | self.temp = temp
199 |
--------------------------------------------------------------------------------
/policy/gnn_simp_hist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.utils.checkpoint import checkpoint
4 | import math
5 | from typing import NamedTuple
6 |
7 | import torch.nn.functional as F
8 |
9 | # from utils.tensor_functions import compute_in_batches
10 |
11 | from encoder.graph_encoder_v2 import GraphAttentionEncoder
12 | from train import clip_grad_norms
13 |
14 | from encoder.graph_encoder import MPNN
15 | from torch.nn import DataParallel
16 | from torch_geometric.utils import subgraph
17 |
18 | # from utils.functions import sample_many
19 |
20 | import time
21 |
22 |
23 | def set_decode_type(model, decode_type):
24 | if isinstance(model, DataParallel):
25 | model = model.module
26 | model.set_decode_type(decode_type)
27 |
28 |
29 | class GNNSimpHist(nn.Module):
30 | def __init__(
31 | self,
32 | embedding_dim,
33 | hidden_dim,
34 | problem,
35 | opts,
36 | n_encode_layers=1,
37 | tanh_clipping=10.0,
38 | mask_inner=True,
39 | mask_logits=True,
40 | normalization="batch",
41 | n_heads=8,
42 | checkpoint_encoder=False,
43 | shrink_size=None,
44 | num_actions=None,
45 | encoder="mpnn",
46 | ):
47 | super(GNNSimpHist, self).__init__()
48 |
49 | self.embedding_dim = embedding_dim
50 | self.hidden_dim = hidden_dim
51 | self.n_encode_layers = n_encode_layers
52 | self.decode_type = None
53 | self.temp = 1.0
54 | self.problem = problem
55 | self.opts = opts
56 | # Problem specific context parameters (placeholder and step context dimension)
57 |
58 | encoder_class = {"attention": GraphAttentionEncoder, "mpnn": MPNN}.get(
59 | encoder, None
60 | )
61 | if opts.problem == "osbm":
62 | node_dim_u = 16
63 | node_dim_v = 18
64 | else:
65 | node_dim_u, node_dim_v = 1, 1
66 |
67 | self.embedder = encoder_class(
68 | n_heads=n_heads,
69 | embed_dim=embedding_dim,
70 | n_layers=self.n_encode_layers,
71 | normalization=normalization,
72 | problem=self.problem,
73 | opts=self.opts,
74 | node_dim_v=node_dim_v,
75 | node_dim_u=node_dim_u,
76 | )
77 |
78 | self.ff = nn.Sequential(
79 | nn.Linear(16 + opts.embedding_dim, 200),
80 | nn.ReLU(),
81 | nn.Linear(200, 200),
82 | nn.ReLU(),
83 | nn.Linear(200, 1),
84 | )
85 |
86 | assert embedding_dim % n_heads == 0
87 | self.step_context_transf = nn.Linear(2 * opts.embedding_dim, opts.embedding_dim)
88 | self.initial_stepcontext = nn.Parameter(torch.Tensor(1, 1, embedding_dim))
89 | self.initial_stepcontext.data.uniform_(-1, 1)
90 | self.dummy = torch.ones(1, dtype=torch.float32, requires_grad=True)
91 | self.model_name = "gnn-simp-hist"
92 |
93 | def init_parameters(self):
94 | for name, param in self.named_parameters():
95 | stdv = 1.0 / math.sqrt(param.size(-1))
96 | param.data.uniform_(-stdv, stdv)
97 |
98 | def set_decode_type(self, decode_type, temp=None):
99 | self.decode_type = decode_type
100 | if temp is not None: # Do not change temperature if not provided
101 | self.temp = temp
102 |
103 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
104 |
105 | _log_p, pi, cost = self._inner(x, opts)
106 |
107 | # cost, mask = self.problem.get_costs(input, pi)
108 | # Log likelihood is calculated within the model since returning it per action does not work well with
109 | # DataParallel, as sequences can be of different lengths
110 | ll, e = self._calc_log_likelihood(_log_p, pi, None)
111 | if return_pi:
112 | return -cost, ll, pi, e
113 | # print(ll)
114 | return -cost, ll, e
115 |
116 | def _calc_log_likelihood(self, _log_p, a, mask):
117 |
118 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean()
119 | # Get log_p corresponding to selected actions
120 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)
121 |
122 | # Optional: mask out actions irrelevant to objective so they do not get reinforced
123 | if mask is not None:
124 | log_p[mask] = 0
125 | if not (log_p > -10000).data.all():
126 | print(log_p.nonzero())
127 | assert (
128 | log_p > -10000
129 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!"
130 |
131 | # Calculate log_likelihood
132 | # print(log_p.sum(1))
133 |
134 | return log_p.sum(1), entropy
135 |
136 | def _inner(self, input, opts):
137 |
138 | outputs = []
139 | sequences = []
140 |
141 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts)
142 |
143 | batch_size = state.batch_size
144 | graph_size = state.u_size + state.v_size + 1
145 | i = 1
146 | while not (state.all_finished()):
147 | step_size = state.i + 1
148 | mask = state.get_mask()
149 | w = state.get_current_weights(mask)
150 | s, mask = state.get_curr_state(self.model_name)
151 | # Pass the graph to the Encoder
152 | node_features = state.get_node_features()
153 | nodes = torch.cat(
154 | (
155 | torch.arange(0, opts.u_size + 1, device=opts.device),
156 | state.idx[:i] + opts.u_size + 1,
157 | )
158 | )
159 | subgraphs = (
160 | (nodes.unsqueeze(0).expand(batch_size, step_size))
161 | + torch.arange(
162 | 0, batch_size * graph_size, graph_size, device=opts.device
163 | ).unsqueeze(1)
164 | ).flatten() # The nodes of the current subgraphs
165 | graph_weights = state.get_graph_weights()
166 | edge_i, weights = subgraph(
167 | subgraphs,
168 | state.graphs.edge_index,
169 | graph_weights.unsqueeze(1),
170 | relabel_nodes=True,
171 | )
172 | embeddings = checkpoint(
173 | self.embedder,
174 | node_features,
175 | edge_i,
176 | weights.float(),
177 | torch.tensor(i),
178 | self.dummy,
179 | ).reshape(batch_size, step_size, -1)
180 | pos = torch.argsort(state.idx[:i])[-1]
181 | incoming_node_embeddings = embeddings[
182 | :, pos + state.u_size + 1, :
183 | ].unsqueeze(1)
184 | # print(incoming_node_embeddings)
185 | w = (state.adj[:, state.get_current_node(), :]).float()
186 | # mean_w = w.mean(1)[:, None, None].repeat(1, state.u_size + 1, 1)
187 | w = w.reshape(state.batch_size, state.u_size + 1, 1)
188 | s = torch.cat(
189 | (s, incoming_node_embeddings.repeat(1, state.u_size + 1, 1),), dim=2,
190 | )
191 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1)
192 | # Select the indices of the next nodes in the sequences, result (batch_size) long
193 | selected, p = self._select_node(
194 | pi, mask.bool()
195 | ) # Squeeze out steps dimension
196 | state = state.update((selected)[:, None])
197 | outputs.append(p)
198 | sequences.append(selected)
199 | i += 1
200 | # Collected lists, return Tensor
201 | return (
202 | torch.stack(outputs, 1),
203 | torch.stack(sequences, 1),
204 | state.size,
205 | )
206 |
207 | def _select_node(self, probs, mask):
208 | assert (probs == probs).all(), "Probs should not contain any nans"
209 | probs[mask] = -1e6
210 | p = torch.log_softmax(probs, dim=1)
211 | # print(p)
212 | if self.decode_type == "greedy":
213 | _, selected = p.max(1)
214 | # assert not mask.gather(
215 | # 1, selected.unsqueeze(-1)
216 | # ).data.any(), "Decode greedy: infeasible action has maximum probability"
217 |
218 | elif self.decode_type == "sampling":
219 | selected = p.exp().multinomial(1).squeeze(1)
220 | # Check if sampling went OK, can go wrong due to bug on GPU
221 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232
222 | # while mask.gather(1, selected.unsqueeze(-1)).data.any():
223 | # print("Sampled bad values, resampling!")
224 | # selected = probs.multinomial(1).squeeze(1)
225 |
226 | else:
227 | assert False, "Unknown decode type"
228 | return selected, p
229 |
--------------------------------------------------------------------------------
/policy/gnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.utils.checkpoint import checkpoint
4 | import math
5 | from typing import NamedTuple
6 |
7 | import torch.nn.functional as F
8 |
9 | # from utils.tensor_functions import compute_in_batches
10 |
11 | from encoder.graph_encoder_v2 import GraphAttentionEncoder
12 | from train import clip_grad_norms
13 |
14 | from encoder.graph_encoder import MPNN
15 | from torch.nn import DataParallel
16 | from torch_geometric.utils import subgraph
17 |
18 | # from utils.functions import sample_many
19 |
20 | import time
21 |
22 |
23 | def set_decode_type(model, decode_type):
24 | if isinstance(model, DataParallel):
25 | model = model.module
26 | model.set_decode_type(decode_type)
27 |
28 |
29 | class GNN(nn.Module):
30 | def __init__(
31 | self,
32 | embedding_dim,
33 | hidden_dim,
34 | problem,
35 | opts,
36 | n_encode_layers=1,
37 | tanh_clipping=10.0,
38 | mask_inner=True,
39 | mask_logits=True,
40 | normalization="batch",
41 | n_heads=8,
42 | checkpoint_encoder=False,
43 | shrink_size=None,
44 | num_actions=None,
45 | encoder="mpnn",
46 | ):
47 | super(GNN, self).__init__()
48 |
49 | self.embedding_dim = embedding_dim
50 | self.hidden_dim = hidden_dim
51 | self.n_encode_layers = n_encode_layers
52 | self.decode_type = None
53 | self.temp = 1.0
54 | self.problem = problem
55 | self.opts = opts
56 | # Problem specific context parameters (placeholder and step context dimension)
57 |
58 | encoder_class = {"attention": GraphAttentionEncoder, "mpnn": MPNN}.get(
59 | encoder, None
60 | )
61 | if opts.problem == "osbm":
62 | node_dim_u = 16
63 | node_dim_v = 18
64 | else:
65 | node_dim_u, node_dim_v = 1, 1
66 |
67 | self.embedder = encoder_class(
68 | n_heads=n_heads,
69 | embed_dim=embedding_dim,
70 | n_layers=self.n_encode_layers,
71 | normalization=normalization,
72 | problem=self.problem,
73 | opts=self.opts,
74 | node_dim_v=node_dim_v,
75 | node_dim_u=node_dim_u,
76 | )
77 |
78 | self.ff = nn.Sequential(
79 | nn.Linear(2 + opts.embedding_dim, 200), nn.ReLU(), nn.Linear(200, 1),
80 | )
81 |
82 | assert embedding_dim % n_heads == 0
83 | self.step_context_transf = nn.Linear(2 * opts.embedding_dim, opts.embedding_dim)
84 | self.initial_stepcontext = nn.Parameter(torch.Tensor(1, 1, embedding_dim))
85 | self.initial_stepcontext.data.uniform_(-1, 1)
86 | self.dummy = torch.ones(1, dtype=torch.float32, requires_grad=True)
87 | self.model_name = "gnn"
88 |
89 | def init_parameters(self):
90 | for name, param in self.named_parameters():
91 | stdv = 1.0 / math.sqrt(param.size(-1))
92 | param.data.uniform_(-stdv, stdv)
93 |
94 | def set_decode_type(self, decode_type, temp=None):
95 | self.decode_type = decode_type
96 | if temp is not None: # Do not change temperature if not provided
97 | self.temp = temp
98 |
99 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
100 |
101 | _log_p, pi, cost = self._inner(x, opts)
102 |
103 | # cost, mask = self.problem.get_costs(input, pi)
104 | # Log likelihood is calculated within the model since returning it per action does not work well with
105 | # DataParallel, as sequences can be of different lengths
106 | ll, e = self._calc_log_likelihood(_log_p, pi, None)
107 | if return_pi:
108 | return -cost, ll, pi, e
109 | # print(ll)
110 | return -cost, ll, e
111 |
112 | def _calc_log_likelihood(self, _log_p, a, mask):
113 |
114 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean()
115 | # Get log_p corresponding to selected actions
116 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)
117 |
118 | # Optional: mask out actions irrelevant to objective so they do not get reinforced
119 | if mask is not None:
120 | log_p[mask] = 0
121 | if not (log_p > -10000).data.all():
122 | print(log_p.nonzero())
123 | assert (
124 | log_p > -10000
125 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!"
126 |
127 | # Calculate log_likelihood
128 | # print(log_p.sum(1))
129 |
130 | return log_p.sum(1), entropy
131 |
132 | def _inner(self, input, opts):
133 |
134 | outputs = []
135 | sequences = []
136 |
137 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts)
138 |
139 | batch_size = state.batch_size
140 | graph_size = state.u_size + state.v_size + 1
141 | i = 1
142 | while not (state.all_finished()):
143 | step_size = state.i + 1
144 | mask = state.get_mask()
145 | w = state.get_current_weights(mask)
146 | # Pass the graph to the Encoder
147 | node_features = state.get_node_features()
148 | nodes = torch.cat(
149 | (
150 | torch.arange(0, opts.u_size + 1, device=opts.device),
151 | state.idx[:i] + opts.u_size + 1,
152 | )
153 | )
154 | subgraphs = (
155 | (nodes.unsqueeze(0).expand(batch_size, step_size))
156 | + torch.arange(
157 | 0, batch_size * graph_size, graph_size, device=opts.device
158 | ).unsqueeze(1)
159 | ).flatten() # The nodes of the current subgraphs
160 | graph_weights = state.get_graph_weights()
161 | edge_i, weights = subgraph(
162 | subgraphs,
163 | state.graphs.edge_index,
164 | graph_weights.unsqueeze(1),
165 | relabel_nodes=True,
166 | )
167 | embeddings = checkpoint(
168 | self.embedder,
169 | node_features,
170 | edge_i,
171 | weights.float(),
172 | torch.tensor(i),
173 | self.dummy,
174 | ).reshape(batch_size, step_size, -1)
175 | pos = torch.argsort(state.idx[:i])[-1]
176 | incoming_node_embeddings = embeddings[
177 | :, pos + state.u_size + 1, :
178 | ].unsqueeze(1)
179 | # print(incoming_node_embeddings)
180 | w = (state.adj[:, state.get_current_node(), :]).float()
181 | # mean_w = w.mean(1)[:, None, None].repeat(1, state.u_size + 1, 1)
182 | s = w.reshape(state.batch_size, state.u_size + 1, 1)
183 | idx = (
184 | torch.ones(state.batch_size, 1, 1, device=opts.device)
185 | * i
186 | / state.v_size
187 | )
188 | s = torch.cat(
189 | (
190 | s,
191 | idx.repeat(1, state.u_size + 1, 1),
192 | incoming_node_embeddings.repeat(1, state.u_size + 1, 1),
193 | ),
194 | dim=2,
195 | )
196 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1)
197 | # Select the indices of the next nodes in the sequences, result (batch_size) long
198 | selected, p = self._select_node(
199 | pi, mask.bool()
200 | ) # Squeeze out steps dimension
201 | # entropy += torch.sum(p * (p.log()), dim=1)
202 | state = state.update((selected)[:, None])
203 | outputs.append(p)
204 | sequences.append(selected)
205 | i += 1
206 | # Collected lists, return Tensor
207 | return (
208 | torch.stack(outputs, 1),
209 | torch.stack(sequences, 1),
210 | state.size,
211 | )
212 |
213 | def _select_node(self, probs, mask):
214 | assert (probs == probs).all(), "Probs should not contain any nans"
215 | probs[mask] = -1e6
216 | p = torch.log_softmax(probs, dim=1)
217 | # print(p)
218 | if self.decode_type == "greedy":
219 | _, selected = p.max(1)
220 | # assert not mask.gather(
221 | # 1, selected.unsqueeze(-1)
222 | # ).data.any(), "Decode greedy: infeasible action has maximum probability"
223 |
224 | elif self.decode_type == "sampling":
225 | selected = p.exp().multinomial(1).squeeze(1)
226 | # Check if sampling went OK, can go wrong due to bug on GPU
227 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232
228 | # while mask.gather(1, selected.unsqueeze(-1)).data.any():
229 | # print("Sampled bad values, resampling!")
230 | # selected = probs.multinomial(1).squeeze(1)
231 |
232 | else:
233 | assert False, "Unknown decode type"
234 | return selected, p
235 |
--------------------------------------------------------------------------------
/policy/inv_ff_history_switch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from policy.inv_ff_history import InvariantFFHist
4 | from utils.functions import random_max
5 | import math
6 |
7 | # from utils.visualize import visualize_reward_process
8 | from copy import deepcopy
9 |
10 | class InvariantFFHistSwitch(InvariantFFHist):
11 | def __init__(
12 | self,
13 | embedding_dim,
14 | hidden_dim,
15 | problem,
16 | opts,
17 | tanh_clipping=None,
18 | mask_inner=None,
19 | mask_logits=None,
20 | n_encode_layers=None,
21 | normalization="batch",
22 | checkpoint_encoder=False,
23 | shrink_size=None,
24 | num_actions=4,
25 | n_heads=None,
26 | encoder=None,
27 | ):
28 | super(InvariantFFHistSwitch, self).__init__(embedding_dim,
29 | hidden_dim,
30 | problem,
31 | opts,
32 | tanh_clipping=tanh_clipping,
33 | mask_inner=mask_inner,
34 | mask_logits=mask_logits,
35 | n_encode_layers=n_encode_layers,
36 | normalization=normalization,
37 | checkpoint_encoder=checkpoint_encoder,
38 | shrink_size=shrink_size,
39 | num_actions=num_actions,
40 | n_heads=n_heads,
41 | encoder=encoder)
42 |
43 | self.switch_lambda = opts.switch_lambda
44 | self.slackness = opts.slackness
45 | self.max_reward = opts.max_reward
46 | self.softmax_temp = opts.softmax_temp
47 |
48 |
49 | def _inner_pure_ml(self, input, opts, show_rewards=False):
50 |
51 | outputs = []
52 | sequences = []
53 | state = self.problem.make_state(deepcopy(input), opts.u_size, opts.v_size, opts)
54 |
55 | # Perform decoding steps
56 | rewards = []
57 | i = 1
58 | while not (state.all_finished()):
59 | mask = state.get_mask()
60 | state.get_current_weights(mask)
61 | s, mask = state.get_curr_state(self.model_name)
62 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1)
63 | # Select the indices of the next nodes in the sequences, result (batch_size) long
64 | selected, p = self._select_node(
65 | pi, mask.bool()
66 | ) # Squeeze out steps dimension
67 | # entropy += torch.sum(p * (p.log()), dim=1)
68 | state = state.update((selected)[:, None])
69 | outputs.append(p)
70 | sequences.append(selected)
71 | i += 1
72 |
73 | rewards.append(state.size.view(-1))
74 |
75 | res_list = [torch.stack(outputs, 1),
76 | torch.stack(sequences, 1),
77 | state.size,]
78 | if show_rewards:
79 | # visualize_reward_process(torch.stack(rewards, 1).cpu().numpy(), "temp_ml_movie.png",ylim = [-1, 14])
80 | res_list.append(torch.stack(rewards, 1))
81 |
82 | # Collected lists, return Tensor
83 | return res_list
84 |
85 | def _inner_switch(self, input, opts, show_rewards=False, osm_baseline=False):
86 |
87 | # torch.autograd.set_detect_anomaly(True)
88 |
89 | if opts.problem not in ["osbm", "e-obm"]:
90 | raise NotImplementedError
91 |
92 | batch_size = int(input.batch.size(0) / (opts.u_size + opts.v_size + 1) )
93 |
94 | ## Greedy algorithm state
95 | state_greedy = self.problem.make_state(deepcopy(input), opts.u_size, opts.v_size, opts)
96 |
97 | ## ML algorithm state
98 | state_ml = self.problem.make_state(deepcopy(input), opts.u_size, opts.v_size, opts)
99 | sequences_ml, outputs_ml, rewards_ml = [], [], []
100 |
101 | graph_v_size = opts.v_size
102 | i = 1
103 |
104 | while not (state_ml.all_finished()):
106 | # Extract the current info of the greedy (expert) state
106 | mask_gd = state_greedy.get_mask()
107 | weight_gd = state_greedy.get_current_weights(mask_gd).clone()
108 | weight_gd[mask_gd.bool()] = -1.0
109 |
110 | if osm_baseline:
111 | # Baseline variant: skip the first v_size / e arrivals, then pick the max-weight available neighbor
112 | if i > graph_v_size / math.e:
113 | selected_gd = random_max(weight_gd)
114 | else:
115 | selected_gd = torch.zeros([weight_gd.shape[0], 1], device=weight_gd.device, dtype=torch.int64)
116 | else:
117 | selected_gd = random_max(weight_gd)
118 |
119 | state_greedy = state_greedy.update(selected_gd)
120 |
121 | # Generate a one-hot probability distribution from the greedy action
122 | probs_gd = -1e8 * torch.ones(mask_gd.shape, device = mask_gd.device)
123 | probs_gd[(torch.arange(batch_size) , selected_gd.view(-1) )] = 1
124 | p_gd = torch.log_softmax(probs_gd, dim=1)
125 |
126 | reward_gd = state_greedy.size
127 |
128 | # Extract the current info of the ML state
129 | mask_ml = state_ml.get_mask()
130 | state_ml.get_current_weights(mask_ml)
131 | info_s, mask_ml = state_ml.get_curr_state(self.model_name)
132 | pi = self.ff(info_s).reshape(state_ml.batch_size, state_ml.u_size + 1)
133 |
134 | # Select the indices of the next nodes in the sequences, result (batch_size) long
135 | selected_ml, p_ml = self._select_node(pi, mask_ml.bool())
136 | selected_ml = (selected_ml)[:, None]
137 |
138 | # Compute the reward of this hypothetical (not yet committed) ML action
139 | reward_hyp, matched_node_hyp = state_ml.try_hypothesis_action(selected_ml)
140 |
141 | # Compare the ML with Greedy actions
142 | matched_node_gd = state_greedy.matched_nodes
143 | action_diff = torch.relu(matched_node_hyp - matched_node_gd)[:,1:]
144 | action_diff = action_diff.sum(1, keepdim=True)
145 |
146 | # Calculate the gap between the ML reward and the (reserved) greedy reward
147 | reward_gd_reserve = reward_gd + action_diff * self.max_reward
148 | reward_diff = reward_hyp - (self.switch_lambda * reward_gd_reserve - self.slackness)
149 |
150 | # Determine the probability of the final action
151 | probs_policy = torch.zeros([batch_size, 2], device=mask_gd.device)
152 | probs_policy[:,0] = reward_diff.view(-1)/self.softmax_temp # Normalize the selection cost
153 | p_policy_exp = torch.log_softmax(probs_policy, dim=1).exp()
154 |
155 | p_total_exp = p_policy_exp[:,0].view(-1,1) * (p_ml.exp()) + p_policy_exp[:,1].view(-1,1) * (p_gd.exp()) + 1e-8
156 | p_total = p_total_exp.log()
157 |
158 | # Determine the action based on reward_diff (expert vs. ML)
159 | if (self.ff.training):
160 | # During training, use a probability-based (soft) switch
161 | projected_ml, p_final = self._select_node(p_total, mask_ml.bool())
162 | projected_ml = (projected_ml)[:, None]
163 | else:
164 | # During evaluation, use a hard switch between the ML and expert actions
165 | reward_sign = (reward_diff >= 0)
166 | available_sign = (matched_node_hyp.gather(1, selected_gd) < 1e-10)
167 | expert_sign = torch.logical_not(reward_sign) * available_sign
168 | projected_ml = selected_ml * reward_sign + selected_gd * expert_sign
169 | p_final = torch.log_softmax(p_total, dim=1)
170 |
171 | # Update the current state
172 | state_ml = state_ml.update(projected_ml)
173 | outputs_ml.append(p_final)
174 | sequences_ml.append(projected_ml.squeeze(1))
175 | i += 1
176 |
177 | rewards_ml.append(state_ml.size.view(-1))
178 |
179 | res_list = [torch.stack(outputs_ml, 1),
180 | torch.stack(sequences_ml, 1),
181 | state_ml.size,]
182 |
183 | if show_rewards:
184 | # visualize_reward_process(torch.stack(rewards, 1).cpu().numpy(), "temp_ml_movie.png",ylim = [-1, 14])
185 | res_list.append(torch.stack(rewards_ml, 1))
186 |
187 | # Collected lists, return Tensor
188 | return res_list
189 |
190 | def _inner_greedy(self, input, opts, show_rewards=False):
191 | state = self.problem.make_state(deepcopy(input), opts.u_size, opts.v_size, opts)
192 | sequences = []
193 | outputs = []
194 | batch_size = int(input.batch.size(0) / (opts.u_size + opts.v_size + 1) )
195 |
196 | rewards = []
197 | i = 1
198 | while not (state.all_finished()):
199 | mask = state.get_mask()
200 | w = state.get_current_weights(mask).clone()
201 | w[mask.bool()] = -1.0
202 | selected = random_max(w)
203 | state = state.update(selected)
204 | sequences.append(selected.squeeze(1))
205 |
206 | # Generate a one-hot probability distribution from the greedy action
207 | probs = -1e8 * torch.ones(mask.shape, device = mask.device)
208 | probs[(torch.arange(batch_size) , selected.view(-1) )] = 1
209 | p = torch.log_softmax(probs, dim=1)
210 | outputs.append(p)
211 | i += 1
212 |
213 | rewards.append(state.size.view(-1))
214 |
215 | res_list = [torch.stack(outputs, 1),
216 | torch.stack(sequences, 1),
217 | state.size,]
218 |
219 | if show_rewards:
220 | # visualize_reward_process(torch.stack(rewards, 1).cpu().numpy(), "temp_ml_movie.png",ylim = [-1, 14])
221 | res_list.append(torch.stack(rewards, 1))
222 |
223 | # Collected lists, return Tensor
224 | return res_list
225 |
226 | def _inner(self, input, opts):
227 |
228 | # visualize_reward_process(torch.stack(rewards, 1).cpu().numpy(), "temp_ml_movie.png",ylim = [-1, 14])
229 |
230 | # return self._inner_greedy(input, opts)
231 | # return self._inner_pure_ml(input, opts)
232 | return self._inner_switch(input, opts)
233 |
234 |
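235 | # A minimal numeric sketch of the switching rule in _inner_switch (hypothetical values,
236 | # not taken from any experiment): with switch_lambda = 0.8, slackness = 0.1 and a
237 | # hypothetical ML reward of 3.0 against a reserved greedy reward of 4.0,
238 | # reward_diff = 3.0 - (0.8 * 4.0 - 0.1) = -0.1 < 0, so at evaluation time the hard switch
239 | # falls back to the greedy action whenever that node is still available to the ML state;
240 | # during training the same negative reward_diff only lowers the softmax weight placed on the ML policy.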
--------------------------------------------------------------------------------
/run_lite.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 | import pprint as pp
5 |
6 | import torch
7 | import torch.optim as optim
8 | from itertools import product
9 | import wandb
10 |
11 | # from tensorboard_logger import Logger as TbLogger
12 | from torch.utils.tensorboard import SummaryWriter
13 | from torch_geometric.loader import DataLoader as geoDataloader
14 |
15 | # from nets.critic_network import CriticNetwork
16 | from options import get_options
17 | from train import train_epoch, validate, get_inner_model
18 | from utils.reinforce_baselines import (
19 | NoBaseline,
20 | ExponentialBaseline,
21 | RolloutBaseline,
22 | WarmupBaseline,
23 | GreedyBaseline,
24 | )
25 |
26 | from policy.attention_model import AttentionModel as AttentionModelgeo
27 | from policy.ff_model import FeedForwardModel
28 | from policy.ff_model_invariant import InvariantFF
29 | from policy.ff_model_hist import FeedForwardModelHist
30 | from policy.inv_ff_history import InvariantFFHist
31 | from policy.inv_ff_history_switch import InvariantFFHistSwitch
32 | from policy.greedy import Greedy
33 | from policy.greedy_rt import GreedyRt
34 | from policy.greedy_sc import GreedySC
35 | from policy.greedy_theshold import GreedyThresh
36 | from policy.greedy_matching import GreedyMatching
37 | from policy.simple_greedy import SimpleGreedy
38 | from policy.supervised import SupervisedModel
39 | from policy.ff_supervised import SupervisedFFModel
40 | from policy.gnn_hist import GNNHist
41 | from policy.gnn_simp_hist import GNNSimpHist
42 | from policy.gnn import GNN
43 |
44 | # from nets.pointer_network import PointerNetwork, CriticNetworkLSTM
45 | from utils.functions import torch_load_cpu, load_problem
46 | from utils.visualize import validate_histogram
47 |
48 |
49 | def run(opts):
50 |
51 | # Pretty print the run args
52 | # pp.pprint(vars(opts))
53 |
54 | # Set the random seed
55 | torch.manual_seed(opts.seed)
56 | # torch.backends.cudnn.benchmark = True
57 | # torch.autograd.set_detect_anomaly(True)
58 | # Optionally configure tensorboard
59 | tb_logger = None
60 | if not opts.no_tensorboard:
61 | tb_logger = SummaryWriter(
62 | os.path.join(
63 | opts.log_dir,
64 | opts.model,
65 | opts.run_name,
66 | )
67 | )
68 | if not opts.eval_only and not os.path.exists(opts.save_dir):
69 | os.makedirs(opts.save_dir)
70 | # Save arguments so exact configuration can always be found
71 | with open(os.path.join(opts.save_dir, "args.json"), "w") as f:
72 | json.dump(vars(opts), f, indent=True)
73 |
74 | # Set the device
75 | opts.device = torch.device("cuda:0" if opts.use_cuda else "cpu")
76 |
77 | # Figure out which problem we are solving
78 | problem = load_problem(opts.problem)
79 |
80 | # Load data from load_path
81 | load_data = {}
82 | assert (
83 | opts.load_path is None or opts.resume is None
84 | ), "Only one of load path and resume can be given"
85 | load_path = opts.load_path if opts.load_path is not None else opts.resume
86 | if load_path is not None:
87 | print(" [*] Loading data from {}".format(load_path))
88 | load_data = torch_load_cpu(load_path)
89 |
90 | assert (opts.tune_wandb == False
91 | and opts.tune == False
92 | and opts.tune_baseline == False), "Unsupported Mode, please use \"run.py\" instead of \"run_lite.py\" "
93 |
94 | # assert(opts.model == "inv-ff-hist"), "We only support inv-ff-hist model up to now..."
95 | # Initialize model
96 | model_class = {
97 | "attention": AttentionModelgeo,
98 | "ff": FeedForwardModel,
99 | "greedy": Greedy,
100 | "greedy-rt": GreedyRt,
101 | "greedy-sc": GreedySC,
102 | "greedy-t": GreedyThresh,
103 | "greedy-m": GreedyMatching,
104 | "simple-greedy": SimpleGreedy,
105 | "inv-ff": InvariantFF,
106 | "inv-ff-hist": InvariantFFHist,
107 | "ff-hist": FeedForwardModelHist,
108 | "supervised": SupervisedModel,
109 | "ff-supervised": SupervisedFFModel,
110 | "gnn-hist": GNNHist,
111 | "gnn-simp-hist": GNNSimpHist,
112 | "gnn": GNN,
113 | "inv-ff-hist-switch":InvariantFFHistSwitch
114 | }.get(opts.model, None)
115 | assert model_class is not None, "Unknown model: {}".format(opts.model)
116 | # if not opts.tune:
117 | model, lr_schedulers, optimizers, val_dataloader, baseline = setup_training_env(
118 | opts, model_class, problem, load_data, tb_logger
119 | )
120 |
121 | training_dataset = problem.make_dataset(
122 | opts.train_dataset, opts.dataset_size, opts.problem, seed=None, opts=opts
123 | )
124 | if opts.eval_only:
125 |
126 | print("Evaluate Greedy algorithm as Baseline algorithm")
127 | greedy_model = Greedy(opts.embedding_dim, opts.hidden_dim, problem=problem, opts=opts).to(opts.device)
128 | validate(greedy_model, val_dataloader, opts)
129 |
130 | print("Evaluate ML model from checkpoint")
131 | validate(model, val_dataloader, opts)
132 | switch_lambda_array = np.arange(0.0, 1.01, 0.1)
133 |
134 |
135 |
136 |
137 | else:
138 | best_avg_cr = 0.0
139 | for epoch in range(opts.epoch_start, opts.epoch_start + opts.n_epochs):
140 | training_dataloader = geoDataloader(
141 | baseline.wrap_dataset(training_dataset),
142 | batch_size=opts.batch_size,
143 | num_workers=0,
144 | shuffle=True,
145 | )
146 | avg_reward, min_cr, avg_cr, loss = train_epoch(
147 | model,
148 | optimizers,
149 | baseline,
150 | lr_schedulers,
151 | epoch,
152 | val_dataloader,
153 | training_dataloader,
154 | problem,
155 | tb_logger,
156 | opts,
157 | best_avg_cr,
158 | )
159 | # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
160 | best_avg_cr = max(best_avg_cr, avg_cr)
161 |
162 |
163 | def setup_training_env(opts, model_class, problem, load_data, tb_logger):
164 | model = model_class(
165 | opts.embedding_dim,
166 | opts.hidden_dim,
167 | problem=problem,
168 | n_encode_layers=opts.n_encode_layers,
169 | mask_inner=True,
170 | mask_logits=True,
171 | normalization=opts.normalization,
172 | tanh_clipping=opts.tanh_clipping,
173 | checkpoint_encoder=opts.checkpoint_encoder,
174 | shrink_size=opts.shrink_size,
175 | num_actions=opts.u_size + 1,
176 | n_heads=opts.n_heads,
177 | encoder=opts.encoder,
178 | opts=opts,
179 | ).to(opts.device)
180 |
181 | if opts.use_cuda and torch.cuda.device_count() > 1:
182 | model = torch.nn.DataParallel(model)
183 |
184 | # Overwrite model parameters by parameters to load
185 | model_ = get_inner_model(model)
186 | model_.load_state_dict({**model_.state_dict(), **load_data.get("model", {})})
187 |
188 | # Initialize baseline
189 | if opts.baseline == "exponential":
190 | baseline = ExponentialBaseline(opts.exp_beta)
191 | elif opts.baseline == "greedy":
192 | baseline_class = {"e-obm": Greedy, "obm": SimpleGreedy}.get(opts.problem, None)
193 |
194 | greedybaseline = baseline_class(
195 | opts.embedding_dim,
196 | opts.hidden_dim,
197 | problem=problem,
198 | n_encode_layers=opts.n_encode_layers,
199 | mask_inner=True,
200 | mask_logits=True,
201 | normalization=opts.normalization,
202 | tanh_clipping=opts.tanh_clipping,
203 | checkpoint_encoder=opts.checkpoint_encoder,
204 | shrink_size=opts.shrink_size,
205 | num_actions=opts.u_size + 1,
206 | # n_heads=opts.n_heads,
207 | )
208 | baseline = GreedyBaseline(greedybaseline, opts)
209 | elif opts.baseline == "rollout":
210 | baseline = RolloutBaseline(model, problem, opts)
211 | else:
212 | assert opts.baseline is None, "Unknown baseline: {}".format(opts.baseline)
213 | baseline = NoBaseline()
214 |
215 | if opts.bl_warmup_epochs > 0:
216 | baseline = WarmupBaseline(
217 | baseline, opts.bl_warmup_epochs, warmup_exp_beta=opts.exp_beta
218 | )
219 |
220 | # Load baseline from data, make sure script is called with same type of baseline
221 | if "baseline" in load_data:
222 | baseline.load_state_dict(load_data["baseline"])
223 |
224 | # Initialize optimizer
225 | optimizer = optim.Adam(
226 | [{"params": model.parameters(), "lr": opts.lr_model}]
227 | + (
228 | [{"params": baseline.get_learnable_parameters(), "lr": opts.lr_critic}]
229 | if len(baseline.get_learnable_parameters()) > 0
230 | else []
231 | )
232 | )
233 |
234 | # Load optimizer state
235 | if "optimizer" in load_data:
236 | optimizer.load_state_dict(load_data["optimizer"])
237 | for state in optimizer.state.values():
238 | for k, v in state.items():
239 | # if isinstance(v, torch.Tensor):
240 | if torch.is_tensor(v):
241 | state[k] = v.to(opts.device)
242 |
243 | # Initialize learning rate scheduler, decay by lr_decay once per epoch!
244 | lr_scheduler = optim.lr_scheduler.LambdaLR(
245 | optimizer, lambda epoch: opts.lr_decay ** epoch
246 | )
247 |
248 | # Build the validation dataset and dataloader
249 | val_dataset = problem.make_dataset(
250 | opts.val_dataset, opts.val_size, opts.problem, seed=None, opts=opts
251 | )
252 | val_dataloader = geoDataloader(
253 | val_dataset, batch_size=opts.batch_size, num_workers=1
254 | )
255 | assert (opts.resume is None), "Resume not supported, please use \"run.py\" instead of \"run_lite.py\""
256 |
257 | return (
258 | model,
259 | [lr_scheduler],
260 | [optimizer],
261 | val_dataloader,
262 | baseline,
263 | )
264 |
265 |
266 | if __name__ == "__main__":
267 | run(get_options())
268 |
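269 | # Example invocation (a sketch only; the exact flag names come from options.get_options(),
270 | # which is not shown here, so treat them as assumptions):
271 | #   python run_lite.py --problem e-obm --model inv-ff-hist-switch --u_size 10 --v_size 30 \
272 | #       --batch_size 100 --embedding_dim 30 --baseline exponential
273 | # run_lite.py intentionally rejects the tuning and resume modes; use run.py for those.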
--------------------------------------------------------------------------------
/utils/reinforce_baselines.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.utils.data import Dataset
4 | from scipy.stats import ttest_rel
5 | import copy
6 | from train import rollout, get_inner_model
7 | from torch_geometric.data import DataLoader
8 |
9 |
10 | class Baseline(object):
11 | def wrap_dataset(self, dataset):
12 | return dataset
13 |
14 | def unwrap_batch(self, batch):
15 | return batch, None
16 |
17 | def eval(self, x, c):
18 | raise NotImplementedError("Override this method")
19 |
20 | def get_learnable_parameters(self):
21 | return []
22 |
23 | def epoch_callback(self, model, epoch):
24 | pass
25 |
26 | def state_dict(self):
27 | return {}
28 |
29 | def load_state_dict(self, state_dict):
30 | pass
31 |
32 |
33 | class WarmupBaseline(Baseline):
34 | def __init__(
35 | self, baseline, n_epochs=1, warmup_exp_beta=0.8,
36 | ):
37 | super(Baseline, self).__init__()
38 |
39 | self.baseline = baseline
40 | assert n_epochs > 0, "n_epochs to warmup must be positive"
41 | self.warmup_baseline = ExponentialBaseline(warmup_exp_beta)
42 | self.alpha = 0
43 | self.n_epochs = n_epochs
44 |
45 | def wrap_dataset(self, dataset):
46 | if self.alpha > 0:
47 | return self.baseline.wrap_dataset(dataset)
48 | return self.warmup_baseline.wrap_dataset(dataset)
49 |
50 | def unwrap_batch(self, batch):
51 | if self.alpha > 0:
52 | return self.baseline.unwrap_batch(batch)
53 | # return batch
54 | return self.warmup_baseline.unwrap_batch(batch)
55 |
56 | def eval(self, x, c):
57 |
58 | if self.alpha == 1:
59 | return self.baseline.eval(x, c)
60 | if self.alpha == 0:
61 | return self.warmup_baseline.eval(x, c)
62 | v, t = self.baseline.eval(x, c)
63 | vw, lw = self.warmup_baseline.eval(x, c)
64 | # Return convex combination of baseline and of loss
65 | return (
66 | self.alpha * v + (1 - self.alpha) * vw,
67 | self.alpha * t + (1 - self.alpha) * lw,
68 | )
69 |
70 | def epoch_callback(self, model, epoch):
71 | # Need to call the epoch callback of the inner baseline (also after the first epoch, even if we have not used it yet)
72 | self.baseline.epoch_callback(model, epoch)
73 | self.alpha = (epoch + 1) / float(self.n_epochs)
74 | if epoch < self.n_epochs:
75 | print("Set warmup alpha = {}".format(self.alpha))
76 |
77 | def state_dict(self):
78 | # Checkpointing within warmup stage makes no sense, only save inner baseline
79 | return self.baseline.state_dict()
80 |
81 | def load_state_dict(self, state_dict):
82 | # Checkpointing within warmup stage makes no sense, only load inner baseline
83 | self.baseline.load_state_dict(state_dict)
84 |
85 |
86 | class NoBaseline(Baseline):
87 | def eval(self, x, c):
88 | return 0, 0 # No baseline, no loss
89 |
90 |
91 | class ExponentialBaseline(Baseline):
92 | def __init__(self, beta):
93 | super(Baseline, self).__init__()
94 |
95 | self.beta = beta
96 | self.v = None
97 |
98 | def eval(self, x, c):
99 |
100 | if self.v is None:
101 | v = c.mean()
102 | else:
103 | v = self.beta * self.v + (1.0 - self.beta) * c.mean()
104 |
105 | self.v = v.detach() # Detach since we never want to backprop
106 | return self.v, 0 # No loss
107 |
108 | def state_dict(self):
109 | return {"v": self.v}
110 |
111 | def load_state_dict(self, state_dict):
112 | self.v = state_dict["v"]
113 |
114 |
115 | class CriticBaseline(Baseline):
116 | def __init__(self, critic):
117 | super(Baseline, self).__init__()
118 |
119 | self.critic = critic
120 |
121 | def eval(self, x, c):
122 | v = self.critic(x)
123 | # Detach v since actor should not backprop through baseline, only for loss
124 | return v.detach(), F.mse_loss(v, c.detach())
125 |
126 | def get_learnable_parameters(self):
127 | return list(self.critic.parameters())
128 |
129 | def epoch_callback(self, model, epoch):
130 | pass
131 |
132 | def state_dict(self):
133 | return {"critic": self.critic.state_dict()}
134 |
135 | def load_state_dict(self, state_dict):
136 | critic_state_dict = state_dict.get("critic", {})
137 | if not isinstance(critic_state_dict, dict): # backwards compatibility
138 | critic_state_dict = critic_state_dict.state_dict()
139 | self.critic.load_state_dict({**self.critic.state_dict(), **critic_state_dict})
140 |
141 |
142 | class GreedyBaseline(Baseline):
143 | def __init__(self, greedymodel, opts):
144 | super(Baseline, self).__init__()
145 |
146 | self.baseline = greedymodel
147 | self.opts = opts
148 |
149 | def eval(self, x, c):
150 |
151 | return self.baseline(x, opts=self.opts)[0].detach(), 0 # No loss
152 |
153 |
154 | class RolloutBaseline(Baseline):
155 | def __init__(self, model, problem, opts, epoch=0):
156 | super(Baseline, self).__init__()
157 |
158 | self.problem = problem
159 | self.opts = opts
160 |
161 | self._update_model(model, epoch)
162 |
163 | def _update_model(self, model, epoch, dataset=None):
164 | self.model = copy.deepcopy(model)
165 | # Always generate baseline dataset when updating model to prevent overfitting to the baseline dataset
166 |
167 | if dataset is not None:
168 | if len(dataset) != self.opts.val_size:
169 | print(
170 | "Warning: not using saved baseline dataset since val_size does not match"
171 | )
172 | dataset = None
173 | elif (dataset[0] if self.problem.NAME == "tsp" else dataset[0]["loc"]).size(
174 | 0
175 | ) != self.opts.graph_size:
176 | print(
177 | "Warning: not using saved baseline dataset since graph_size does not match"
178 | )
179 | dataset = None
180 |
181 | if dataset is None:
182 | self.dataset = DataLoader(
183 | self.problem.make_dataset(
184 | None,
185 | self.opts.val_size,
186 | self.opts.problem,
187 | seed=epoch * 1000,
188 | opts=self.opts,
189 | ),
190 | batch_size=self.opts.eval_batch_size,
191 | num_workers=1,
192 | )
193 | else:
194 | self.dataset = dataset
195 | print("Evaluating baseline model on evaluation dataset")
196 | self.bl_vals = (
197 | rollout(self.model, self.dataset, self.opts)[0].cpu().numpy() / 100.0
198 | )
199 | self.mean = self.bl_vals.mean()
200 | self.epoch = epoch
201 |
202 | def wrap_dataset(self, dataset):
203 | print("Evaluating baseline on dataset...")
204 | # Need to convert baseline to 2D to prevent converting to double, see
205 | # https://discuss.pytorch.org/t/dataloader-gives-double-instead-of-float/717/3
206 | dataloader = DataLoader(
207 | dataset, batch_size=self.opts.eval_batch_size, num_workers=1
208 | )
209 | return BaselineDataset(
210 | dataset, rollout(self.model, dataloader, self.opts)[0].view(-1, 1)
211 | )
212 | # return dataset
213 |
214 | def unwrap_batch(self, batch):
215 | return (
216 | batch["data"],
217 | None,
218 | ) # The wrapped baseline values are not used here; only the data is returned
219 |
220 | def eval(self, x, c):
221 | # Use no_grad for efficient inference (single batch, so we do not use the rollout function)
222 | with torch.no_grad():
223 | v, _ = self.model(x, self.opts, None, None)
224 |
225 | # There is no loss
226 | return v, 0
227 |
228 | def epoch_callback(self, model, epoch):
229 | """
230 | Challenges the current baseline with the candidate model and replaces the baseline model if the candidate is significantly better (one-sided paired t-test).
231 | :param model: The model to challenge the baseline with
232 | :param epoch: The current epoch
233 | """
234 | print("Evaluating candidate model on evaluation dataset")
235 | candidate_vals = (
236 | rollout(model, self.dataset, self.opts)[0].cpu().numpy() / 100.0
237 | )
238 |
239 | candidate_mean = candidate_vals.mean()
240 |
241 | print(
242 | "Epoch {} candidate mean {}, baseline epoch {} mean {}, difference {}".format(
243 | epoch, candidate_mean, self.epoch, self.mean, candidate_mean - self.mean
244 | )
245 | )
246 | if candidate_mean - self.mean < 0:
247 | # Calc p value
248 | t, p = ttest_rel(candidate_vals, self.bl_vals)
249 |
250 | p_val = p / 2 # one-sided
251 | assert t < 0, "T-statistic should be negative"
252 | print("p-value: {}".format(p_val))
253 | if p_val < self.opts.bl_alpha:
254 | print("Update baseline")
255 | self._update_model(model, epoch)
256 |
257 | def state_dict(self):
258 | return {"model": self.model, "dataset": self.dataset, "epoch": self.epoch}
259 |
260 | def load_state_dict(self, state_dict):
261 | # We make it such that it works whether model was saved as data parallel or not
262 | load_model = copy.deepcopy(self.model)
263 | get_inner_model(load_model).load_state_dict(
264 | get_inner_model(state_dict["model"]).state_dict()
265 | )
266 | self._update_model(load_model, state_dict["epoch"], state_dict["dataset"])
267 |
268 |
269 | class BaselineDataset(Dataset):
270 | def __init__(self, dataset=None, baseline=None):
271 | super(BaselineDataset, self).__init__()
272 |
273 | self.dataset = dataset
274 | self.baseline = baseline
275 | assert len(self.dataset) == len(self.baseline)
276 |
277 | def __getitem__(self, item):
278 | return {"data": self.dataset[item], "baseline": self.baseline[item]}
279 |
280 | def __len__(self):
281 | return len(self.dataset)
282 |
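283 | # A small worked example of ExponentialBaseline (hypothetical numbers): with beta = 0.8,
284 | # the first batch sets v to the batch mean cost, e.g. 10.0; a second batch with mean cost
285 | # 20.0 then updates it to 0.8 * 10.0 + 0.2 * 20.0 = 12.0. The returned loss is always 0
286 | # because the exponential baseline has no learnable parameters.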
--------------------------------------------------------------------------------
/policy/gnn_hist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.utils.checkpoint import checkpoint
4 | import math
5 |
6 | from encoder.graph_encoder_v2 import GraphAttentionEncoder
7 |
8 | from encoder.graph_encoder import MPNN
9 | from torch.nn import DataParallel
10 | from torch_geometric.utils import subgraph
11 |
12 |
13 | def set_decode_type(model, decode_type):
14 | if isinstance(model, DataParallel):
15 | model = model.module
16 | model.set_decode_type(decode_type)
17 |
18 |
19 | class GNNHist(nn.Module):
20 | def __init__(
21 | self,
22 | embedding_dim,
23 | hidden_dim,
24 | problem,
25 | opts,
26 | n_encode_layers=1,
27 | tanh_clipping=10.0,
28 | mask_inner=True,
29 | mask_logits=True,
30 | normalization="batch",
31 | n_heads=8,
32 | checkpoint_encoder=False,
33 | shrink_size=None,
34 | num_actions=None,
35 | encoder="mpnn",
36 | ):
37 | super(GNNHist, self).__init__()
38 |
39 | self.embedding_dim = embedding_dim
40 | self.hidden_dim = hidden_dim
41 | self.n_encode_layers = n_encode_layers
42 | self.decode_type = None
43 | self.temp = 1.0
44 | self.problem = problem
45 | self.opts = opts
46 |
47 | encoder_class = {"attention": GraphAttentionEncoder, "mpnn": MPNN}.get(
48 | encoder, None
49 | )
50 | if opts.problem == "osbm":
51 | node_dim_u = 16
52 | node_dim_v = 3
53 | else:
54 | node_dim_u, node_dim_v = 1, 1
55 |
56 | self.embedder = encoder_class(
57 | n_heads=n_heads,
58 | embed_dim=embedding_dim,
59 | n_layers=self.n_encode_layers,
60 | normalization=normalization,
61 | problem=self.problem,
62 | opts=self.opts,
63 | node_dim_v=node_dim_v,
64 | node_dim_u=node_dim_u,
65 | )
66 |
67 | self.ff = nn.Sequential(
68 | nn.Linear(5 + 4 * opts.embedding_dim, 200),
69 | nn.ReLU(),
70 | nn.Linear(200, 200),
71 | nn.ReLU(),
72 | nn.Linear(200, 1),
73 | )
74 |
75 | assert embedding_dim % n_heads == 0
76 | self.step_context_transf = nn.Linear(2 * opts.embedding_dim, opts.embedding_dim)
77 | self.initial_stepcontext = nn.Parameter(torch.Tensor(1, 1, embedding_dim))
78 | self.initial_stepcontext.data.uniform_(-1, 1)
79 | self.dummy = torch.ones(1, dtype=torch.float32, requires_grad=True)
80 | self.model_name = "gnn-hist"
81 |
82 | def init_parameters(self):
83 | for name, param in self.named_parameters():
84 | stdv = 1.0 / math.sqrt(param.size(-1))
85 | param.data.uniform_(-stdv, stdv)
86 |
87 | def set_decode_type(self, decode_type, temp=None):
88 | self.decode_type = decode_type
89 | if temp is not None: # Do not change temperature if not provided
90 | self.temp = temp
91 |
92 | def forward(self, x, opts, optimizer, baseline, return_pi=False):
93 |
94 | _log_p, pi, cost = self._inner(x, opts)
95 |
96 | # cost, mask = self.problem.get_costs(input, pi)
97 | # Log likelihood is calculated within the model since returning it per action does not work well with
98 | # DataParallel, as sequences can be of different lengths
99 | ll, e = self._calc_log_likelihood(_log_p, pi, None)
100 | if return_pi:
101 | return -cost, ll, pi, e
102 | # print(ll)
103 | return -cost, ll, e
104 |
105 | def _calc_log_likelihood(self, _log_p, a, mask):
106 |
107 | entropy = -(_log_p * _log_p.exp()).sum(2).sum(1).mean()
108 | # Get log_p corresponding to selected actions
109 | log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)
110 |
111 | # Optional: mask out actions irrelevant to objective so they do not get reinforced
112 | if mask is not None:
113 | log_p[mask] = 0
114 | if not (log_p > -10000).data.all():
115 | print(log_p.nonzero())
116 | assert (
117 | log_p > -10000
118 | ).data.all(), "Logprobs should not be -inf, check sampling procedure!"
119 |
120 | # Calculate log_likelihood
121 | # print(log_p.sum(1))
122 |
123 | return log_p.sum(1), entropy
124 |
125 | def _inner(self, input, opts):
126 |
127 | outputs = []
128 | sequences = []
129 |
130 | state = self.problem.make_state(input, opts.u_size, opts.v_size, opts)
131 |
132 | batch_size = state.batch_size
133 | graph_size = state.u_size + state.v_size + 1
134 | i = 1
135 | step_context = 0.0
136 | while not (state.all_finished()):
137 | step_size = state.i + 1
138 | mask = state.get_mask()
139 | w = state.get_current_weights(mask)
140 | # Pass the graph to the Encoder
141 | node_features = state.get_node_features()
142 | nodes = torch.cat(
143 | (
144 | torch.arange(0, opts.u_size + 1, device=opts.device),
145 | state.idx[:i] + opts.u_size + 1,
146 | )
147 | )
148 | subgraphs = (
149 | (nodes.unsqueeze(0).expand(batch_size, step_size))
150 | + torch.arange(
151 | 0, batch_size * graph_size, graph_size, device=opts.device
152 | ).unsqueeze(1)
153 | ).flatten() # The nodes of the current subgraphs
154 | graph_weights = state.get_graph_weights()
155 | edge_i, weights = subgraph(
156 | subgraphs,
157 | state.graphs.edge_index,
158 | graph_weights.unsqueeze(1),
159 | relabel_nodes=True,
160 | )
161 | embeddings = checkpoint(
162 | self.embedder,
163 | node_features,
164 | edge_i,
165 | weights.float(),
166 | torch.tensor(i),
167 | self.dummy,
168 | opts,
169 | ).reshape(batch_size, step_size, -1)
170 | pos = torch.argsort(state.idx[:i])[-1]
171 | incoming_node_embeddings = embeddings[
172 | :, pos + state.u_size + 1, :
173 | ].unsqueeze(1)
174 | # print(incoming_node_embeddings)
175 | w = (state.adj[:, state.get_current_node(), :]).float()
176 | # mean_w = w.mean(1)[:, None, None].repeat(1, state.u_size + 1, 1)
177 | s = w.reshape(state.batch_size, state.u_size + 1, 1)
178 | idx = (
179 | torch.ones(state.batch_size, 1, 1, device=opts.device)
180 | * (i - 1.0)
181 | / state.v_size
182 | )
183 |
184 | if i != 1:
185 | past_sol = (
186 | torch.stack(sequences, 1)
187 | + torch.arange(
188 | 0, batch_size * (i - 1), i - 1, device=opts.device
189 | ).unsqueeze(1)
190 | ).flatten()
191 |
192 | selected_nodes = torch.index_select(
193 | embeddings.reshape(-1, opts.embedding_dim),
194 | 0,
195 | past_sol.to(opts.device),
196 | ).reshape(batch_size, i - 1, opts.embedding_dim)
197 | step_context = (
198 | self.step_context_transf(
199 | torch.cat(
200 | (
201 | selected_nodes,
202 | embeddings[:, state.u_size + 1 : state.u_size + i, :],
203 | ),
204 | dim=2,
205 | )
206 | )
207 | .mean(1)
208 | .unsqueeze(1)
209 | )
210 | else:
211 | step_context = self.initial_stepcontext.repeat(batch_size, 1, 1)
212 | u_embeddings = embeddings[:, : opts.u_size + 1, :]
213 | fixed_node_identity = torch.zeros(
214 | state.batch_size, state.u_size + 1, 1, device=opts.device
215 | )
216 | fixed_node_identity[:, 0, :] = 1.0
217 | norm_size = (
218 | state.orig_budget.sum(-1)[:, None, None]
219 | if opts.problem == "adwords"
220 | else state.u_size
221 | )
222 | s = torch.cat(
223 | (
224 | s,
225 | idx.repeat(1, state.u_size + 1, 1),
226 | state.size.unsqueeze(2).repeat(1, state.u_size + 1, 1) / norm_size,
227 | fixed_node_identity,
228 | mask.unsqueeze(2),
229 | incoming_node_embeddings.repeat(1, state.u_size + 1, 1),
230 | embeddings[:, : opts.u_size + 1, :],
231 | step_context.repeat(1, state.u_size + 1, 1),
232 | u_embeddings.mean(1).unsqueeze(1).repeat(1, state.u_size + 1, 1),
233 | ),
234 | dim=2,
235 | ).float()
236 | pi = self.ff(s).reshape(state.batch_size, state.u_size + 1)
237 | # Select the indices of the next nodes in the sequences, result (batch_size) long
238 | selected, p = self._select_node(
239 | pi, mask.bool()
240 | ) # Squeeze out steps dimension
241 | # entropy += torch.sum(p * (p.log()), dim=1)
242 | state = state.update((selected)[:, None])
243 | outputs.append(p)
244 | sequences.append(selected)
245 | i += 1
246 | # Collected lists, return Tensor
247 | return (
248 | torch.stack(outputs, 1),
249 | torch.stack(sequences, 1),
250 | state.size,
251 | )
252 |
253 | def _select_node(self, probs, mask):
254 | assert (probs == probs).all(), "Probs should not contain any nans"
255 | probs[mask] = -1e6
256 | p = torch.log_softmax(probs, dim=1)
257 | # print(p)
258 | if self.decode_type == "greedy":
259 | _, selected = p.max(1)
260 | # assert not mask.gather(
261 | # 1, selected.unsqueeze(-1)
262 | # ).data.any(), "Decode greedy: infeasible action has maximum probability"
263 |
264 | elif self.decode_type == "sampling":
265 | selected = p.exp().multinomial(1).squeeze(1)
266 | # Check if sampling went OK, can go wrong due to bug on GPU
267 | # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232
268 | # while mask.gather(1, selected.unsqueeze(-1)).data.any():
269 | # print("Sampled bad values, resampling!")
270 | # selected = probs.multinomial(1).squeeze(1)
271 |
272 | else:
273 | assert False, "Unknown decode type"
274 | return selected, p
275 |
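276 | # Note on _select_node: masked entries are pushed to a logit of -1e6 before log_softmax,
277 | # so infeasible nodes receive (numerically) zero probability. "greedy" decoding takes the
278 | # argmax of the log-probabilities, while "sampling" draws one node per graph from the
279 | # corresponding categorical distribution via multinomial.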
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | import networkx as nx
5 | import torch
6 | from scipy.stats import powerlaw
7 | import torch_geometric
8 |
9 |
10 | # gMission files
11 | gMission_edges = "data/gMission/edges.txt"
12 | gMission_tasks = "data/gMission/tasks.txt"
13 | gMission_reduced_tasks = "data/gMission/reduced_tasks.txt"
14 | gMission_reduced_workers = "data/gMission/reduced_workers.txt"
15 |
16 | # MovieLense files
17 | movie_lense_movies = "data/MovieLense/movies.txt"
18 | movie_lense_users = "data/MovieLense/users.txt"
19 | movie_lense_edges = "data/MovieLense/edges.txt"
20 | movie_lense_ratings = "data/MovieLense/ratings.txt"
21 | movie_lense_feature_weights = "data/MovieLense/feature_weights.txt"
22 |
23 |
24 | def add_nodes_with_bipartite_label(G, lena, lenb):
25 | """
26 | Helper for generate_ba_graph that initializes an empty graph with bipartite-labeled nodes.
27 | """
28 | G.add_nodes_from(range(0, lena + lenb))
29 | b = dict(zip(range(0, lena), [0] * lena))
30 | b.update(dict(zip(range(lena, lena + lenb), [1] * lenb)))
31 | nx.set_node_attributes(G, b, "bipartite")
32 | return G
33 |
34 |
35 | def get_solution(row_ind, col_in, weights, v_size):
36 | """
37 | returns a np vector where the index at i is the the node in u that v_i connect to. If index is zero, then v[i]
38 | is connected to no node in U.
39 | """
40 | new_col_in = []
41 | # row_ind.sort()
42 | col_in = col_in + 1
43 | new_col_in += [0] * (row_ind[0])
44 |
45 | for i in range(0, len(row_ind) - 1):
46 | if weights[row_ind[i], col_in[i] - 1] != 0.0:
47 | new_col_in.append(col_in[i])
48 | else:
49 | new_col_in.append(0.0)
50 | new_col_in += [0.0] * (row_ind[i + 1] - row_ind[i] - 1)
51 |
52 | if weights[row_ind[-1], col_in[-1] - 1] != 0.0:
53 | new_col_in.append(col_in[-1])
54 | else:
55 | new_col_in.append(0.0)
56 | new_col_in += [0] * (v_size - row_ind[-1] - 1)
57 | return new_col_in
58 |
59 |
60 | def check_extension(filename):
61 | if os.path.splitext(filename)[1] != ".pkl":
62 | return filename + ".pkl"
63 | return filename
64 |
65 |
66 | def save_dataset(dataset, filename):
67 |
68 | with open(check_extension(filename), "wb") as f:
69 | pickle.dump(dataset, f, pickle.HIGHEST_PROTOCOL)
70 |
71 |
72 | def load_dataset(filename):
73 |
74 | with open(check_extension(filename), "rb") as f:
75 | return pickle.load(f)
76 |
77 |
78 | def from_networkx(G):
79 | r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
80 | :class:`torch_geometric.data.Data` instance.
81 |
82 | Args:
83 | G (networkx.Graph or networkx.DiGraph): A networkx graph.
84 | """
85 |
86 | G = nx.convert_node_labels_to_integers(G, ordering="sorted")
87 | G = G.to_directed() if not nx.is_directed(G) else G
88 | edge_index = torch.LongTensor(list(G.edges)).t().contiguous()
89 |
90 | data = {}
91 |
92 | for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
93 | for key, value in feat_dict.items():
94 | data[str(key)] = [value] if i == 0 else data[str(key)] + [value]
95 |
96 | for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
97 | for key, value in feat_dict.items():
98 | data[str(key)] = [value] if i == 0 else data[str(key)] + [value]
99 |
100 | for key, item in data.items():
101 | try:
102 | data[key] = torch.tensor(item)
103 | except ValueError:
104 | pass
105 |
106 | data["edge_index"] = edge_index.view(2, -1)
107 | data = torch_geometric.data.Data.from_dict(data)
108 | data.num_nodes = G.number_of_nodes()
109 |
110 | return data
111 |
112 |
113 | def parse_gmission_dataset():
114 | f_edges = open(gMission_edges, "r")
115 | f_tasks = open(gMission_tasks, "r")
116 | f_reduced_tasks = open(gMission_reduced_tasks, "r")
117 | f_reduced_workers = open(gMission_reduced_workers, "r")
118 | edgeWeights = dict()
119 | edgeNumber = dict()
120 | count = 0
121 | for line in f_edges:
122 | vals = line.split(",")
123 | edgeWeights[vals[0]] = vals[1].split("\n")[0]
124 | edgeNumber[vals[0]] = count
125 | count += 1
126 |
127 | tasks = list()
128 | reduced_tasks = []
129 | reduced_workers = []
130 | tasks_x = dict()
131 | tasks_y = dict()
132 |
133 | for line in f_tasks:
134 | vals = line.split(",")
135 | tasks.append(vals[0])
136 | tasks_x[vals[0]] = float(vals[2])
137 | tasks_y[vals[0]] = float(vals[3])
138 |
139 | for t in f_reduced_tasks:
140 | reduced_tasks.append(t)
141 |
142 | for w in f_reduced_workers:
143 | reduced_workers.append(w)
144 |
145 | return edgeWeights, tasks, reduced_tasks, reduced_workers
146 |
147 |
148 | def parse_movie_lense_dataset():
149 | f_edges = open(movie_lense_edges, "r")
150 | f_movies = open(movie_lense_movies, "r")
151 | f_users = open(movie_lense_users, "r")
152 | f_feature_weights = open(movie_lense_feature_weights, "r")
153 | num_genres = 15
154 | gender_map = {"M": 0, "F": 1}
155 | age_map = {"1": 0, "18": 1, "25": 2, "35": 3, "45": 4, "50": 5, "56": 6}
156 | genre_map = {
157 | "Action": 0,
158 | "Adventure": 1,
159 | "Animation": 2,
160 | "Children's": 3,
161 | "Comedy": 4,
162 | "Crime": 5,
163 | "Documentary": 6,
164 | "Drama": 7,
165 | "Film-Noir": 8,
166 | "Horror": 9,
167 | "Musical": 10,
168 | "Romance": 11,
169 | "Sci-Fi": 12,
170 | "Thriller": 13,
171 | "War": 14,
172 | }
173 | users = {}
174 | movies = {}
175 | edges = {}
176 | feature_weights = {}
177 | user_ids = []
178 | popularity = {}
179 | for u in f_users:
180 | info = u.split(",")[:4]
181 | info[1] = float(gender_map[info[1]])
182 | info[2] = float(age_map[info[2]]) / 6.0
183 | info[3] = float(info[3]) / 21.0
184 | users[info[0]] = info[1:4]
185 | user_ids.append(int(info[0]))
186 | user_ids.sort()
187 | for i, u in enumerate(user_ids):
188 | users[str(u)].append(i)
189 |
190 | for m in f_movies:
191 | info = m.split("::")
192 | genres = info[2].split("|")
193 | genres[-1] = genres[-1].split("\n")[0] # remove "\n" character
194 | genres_id = np.array(list(map(lambda g: genre_map[g], genres)))
195 | one_hot_encoding = np.zeros(num_genres)
196 | one_hot_encoding[genres_id] = 1.0
197 | movies[info[0]] = list(one_hot_encoding)
198 | popularity[info[0]] = 0
199 |
200 | for e in f_edges:
201 | info = e.split(",")
202 | genres = info[3].split("|")
203 | genres[-1] = genres[-1].split("\n")[0] # remove "\n" character
204 | edges[(info[2], info[1])] = list(map(lambda g: genre_map[g], genres))
205 | popularity[info[2]] += 1
206 |
207 | for w in f_feature_weights:
208 | feature = w.split(",")
209 | feature[-1] = feature[-1].split("\n")[0] # remove "\n" character
210 | if feature[1] not in feature_weights:
211 | feature_weights[feature[1]] = [0.0] * num_genres
212 | feature_weights[feature[1]][genre_map[feature[0]]] = float(feature[2]) / 5.0
213 | return users, movies, edges, feature_weights, popularity
214 |
215 |
216 | def find_best_tasks(tasks, edges):
217 | task_total = {}
218 | f_open = open("data/gMission/reduced_tasks.txt", "a")
219 | for e in edges.items():
220 | task = e[0].split(";")[1]
221 | if task not in task_total:
222 | task_total[task] = 1
223 | else:
224 | task_total[task] += 1
225 | top_tasks = sorted(task_total.items(), key=lambda k: k[1], reverse=True)[:300]
226 | for t in top_tasks:
227 | f_open.write(t[0] + "\n")
228 | return top_tasks
229 |
230 |
231 | def find_best_workers(tasks, edges):
232 | task_total = {}
233 | f_open = open("data/gMission/reduced_workers.txt", "a")
234 | for e in edges.items():
235 | task = e[0].split(";")[0]
236 | if task not in task_total:
237 | task_total[task] = 1
238 | else:
239 | task_total[task] += 1
240 | top_tasks = sorted(task_total.items(), key=lambda k: k[1], reverse=True)[:200]
241 | for t in top_tasks:
242 | f_open.write(t[0] + "\n")
243 | return top_tasks
244 |
245 |
246 | def generate_weights_geometric(distribution, u_size, v_size, parameters, g1, seed):
247 | weights, w = 0, 0
248 | np.random.seed(seed)
249 | if distribution == "uniform":
250 | weights = nx.bipartite.biadjacency_matrix(
251 | g1, range(0, u_size), range(u_size, u_size + v_size)
252 | ).toarray() * np.random.uniform(
253 | int(parameters[0]), int(parameters[1]), (u_size, v_size)
254 | )
255 | w = torch.cat(
256 | (torch.zeros(v_size, 1).float(), torch.tensor(weights).T.float()), 1
257 | )
258 | elif distribution == "normal":
259 | weights = nx.bipartite.biadjacency_matrix(
260 | g1, range(0, u_size), range(u_size, u_size + v_size)
261 | ).toarray() * (
262 | np.abs(
263 | np.random.normal(
264 | int(parameters[0]), int(parameters[1]), (u_size, v_size)
265 | )
266 | )
267 | + 5
268 | ) # to make sure no edge has weight zero
269 | w = torch.cat(
270 | (torch.zeros(v_size, 1).float(), torch.tensor(weights).T.float()), 1
271 | )
272 | elif distribution == "power":
273 | weights = nx.bipartite.biadjacency_matrix(
274 | g1, range(0, u_size), range(u_size, u_size + v_size)
275 | ).toarray() * (
276 | powerlaw.rvs(
277 | int(parameters[0]),
278 | int(parameters[1]),
279 | int(parameters[2]),
280 | (u_size, v_size),
281 | )
282 | + 5
283 | ) # to make sure no edge has weight zero
284 | w = torch.cat(
285 | (torch.zeros(v_size, 1).float(), torch.tensor(weights).T.float()), 1
286 | )
287 | elif distribution == "degree":
288 | weights = nx.bipartite.biadjacency_matrix(
289 | g1, range(0, u_size), range(u_size, u_size + v_size)
290 | ).toarray()
291 | graph = weights * weights.sum(axis=1).reshape(-1, 1)
292 | noise = np.abs(
293 | np.random.normal(
294 | float(parameters[0]), float(parameters[1]), (u_size, v_size)
295 | )
296 | )
297 | weights = np.where(graph, (graph + noise) / v_size, graph)
298 | w = torch.cat(
299 | (torch.zeros(v_size, 1).float(), torch.tensor(weights).T.float()), 1
300 | )
301 | elif distribution == "node-normal":
302 | adj = nx.bipartite.biadjacency_matrix(
303 | g1, range(0, u_size), range(u_size, u_size + v_size)
304 | ).toarray()
305 | mean = np.random.randint(
306 | float(parameters[0]), float(parameters[1]), (u_size, 1)
307 | )
308 | variance = np.sqrt(
309 | np.random.randint(float(parameters[0]), float(parameters[1]), (u_size, 1))
310 | )
311 | weights = (
312 | np.abs(np.random.normal(0.0, 1.0, (u_size, v_size)) * variance + mean) + 5
313 | ) * adj
314 | elif distribution == "fixed-normal":
315 | adj = nx.bipartite.biadjacency_matrix(
316 | g1, range(0, u_size), range(u_size, u_size + v_size)
317 | ).toarray()
318 | mean = np.random.choice(np.arange(0, 100, 15), size=(u_size, 1))
319 | variance = np.sqrt(np.random.choice(np.arange(0, 100, 20), (u_size, 1)))
320 | weights = (
321 | np.abs(np.random.normal(0.0, 1.0, (u_size, v_size)) * variance + mean) + 5
322 | ) * adj
323 |
324 | w = np.delete(weights.flatten(), weights.flatten() == 0)
325 | return weights, w
326 |
327 | def visualize_biparite(G, u, v, fig_path = "temp.png"):
328 |
329 | import matplotlib
330 | import matplotlib.pyplot as plt
331 | # nx.draw(G)
332 | G_left = G.subgraph([i for i in range(u)])
333 | G_right = G.subgraph([i for i in range(u,u+v)])
334 | edges = G.edges()
335 |
336 | pos = dict()
337 | pos.update( (n, (1, i)) for i, n in enumerate(G_left) ) # put nodes from X at x=1
338 | pos.update( (n, (2, i)) for i, n in enumerate(G_right) ) # put nodes from Y at x=2
339 | nx.draw_networkx(G, pos=pos)
340 |
341 | plt.savefig(fig_path)
342 |
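343 | # Usage sketch for visualize_biparite (hypothetical sizes and file name):
344 | # G = nx.bipartite.random_graph(5, 8, 0.4, seed=0)
345 | # visualize_biparite(G, 5, 8, fig_path="example_bipartite.png")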
--------------------------------------------------------------------------------
/problem_state/edge_obm_env.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import NamedTuple
3 | from torch_geometric.utils import to_dense_adj
4 |
5 | # from utils.boolmask import mask_long2bool, mask_long_scatter
6 |
7 |
8 | class StateEdgeBipartite(NamedTuple):
9 | # Fixed input
10 | graphs: torch.Tensor # graphs objects in a batch
11 | u_size: int
12 | v_size: int
13 | batch_size: torch.Tensor
14 | hist_sum: torch.Tensor
15 | hist_sum_sq: torch.Tensor
16 | hist_deg: torch.Tensor
17 | min_sol: torch.Tensor
18 | max_sol: torch.Tensor
19 | sum_sol_sq: torch.Tensor
20 | num_skip: torch.Tensor
21 |
22 | adj: torch.Tensor
23 | # State
24 | matched_nodes: torch.Tensor # Keeps track of nodes that have been matched
25 | size: torch.Tensor # size of current matching
26 | i: int # Keeps track of step
27 | opts: dict
28 | idx: torch.Tensor
29 |
30 | @staticmethod
31 | def initialize(
32 | input,
33 | u_size,
34 | v_size,
35 | opts,
36 | ):
37 | graph_size = u_size + v_size + 1
38 | batch_size = int(input.batch.size(0) / graph_size)
39 | # print(batch_size, input.batch.size(0), graph_size)
40 | adj = to_dense_adj(
41 | input.edge_index, input.batch, input.weight.unsqueeze(1)
42 | ).squeeze(-1)
43 | adj = adj[:, u_size + 1 :, : u_size + 1]
44 |
45 | # permute the nodes for data
46 | idx = torch.arange(adj.shape[1], device=opts.device)
47 | # if "supervised" not in opts.model and not opts.eval_only:
48 | # idx = torch.randperm(adj.shape[1], device=opts.device)
49 | # adj = adj[:, idx, :].view(adj.size())
50 |
51 | return StateEdgeBipartite(
52 | graphs=input,
53 | adj=adj,
54 | u_size=u_size,
55 | v_size=v_size,
56 | batch_size=batch_size,
57 | # Keep the "skip" node (index 0) in the mask so we can scatter efficiently
58 | matched_nodes=( # Mask over U marking nodes that have already been matched
59 | torch.zeros(
60 | batch_size,
61 | u_size + 1,
62 | device=input.batch.device,
63 | )
64 | ),
65 | hist_sum=( # Running sum of incoming edge weights per node in U
66 | torch.zeros(
67 | batch_size,
68 | 1,
69 | u_size + 1,
70 | device=input.batch.device,
71 | )
72 | ),
73 | hist_deg=( # Running count of incoming edges (degree) per node in U
74 | torch.zeros(
75 | batch_size,
76 | 1,
77 | u_size + 1,
78 | device=input.batch.device,
79 | )
80 | ),
81 | hist_sum_sq=( # Running sum of squared incoming edge weights per node in U
82 | torch.zeros(
83 | batch_size,
84 | 1,
85 | u_size + 1,
86 | device=input.batch.device,
87 | )
88 | ),
89 | min_sol=torch.zeros(batch_size, 1, device=input.batch.device),
90 | max_sol=torch.zeros(batch_size, 1, device=input.batch.device),
91 | sum_sol_sq=torch.zeros(batch_size, 1, device=input.batch.device),
92 | num_skip=torch.zeros(batch_size, 1, device=input.batch.device),
93 | size=torch.zeros(batch_size, 1, device=input.batch.device),
94 | i=u_size + 1,
95 | opts=opts,
96 | idx=idx,
97 | )
98 |
99 | def get_final_cost(self):
100 |
101 | assert self.all_finished()
102 | # assert self.visited_.
103 |
104 | return self.size
105 |
106 | def get_current_weights(self, mask):
107 | return self.adj[:, 0, :].float()
108 |
109 | def get_graph_weights(self):
110 | return self.graphs.weight
111 |
112 | def update(self, selected):
113 | # Update the state
114 | nodes = self.matched_nodes.scatter_(-1, selected, 1)
115 | nodes[
116 | :, 0
117 | ] = 0 # node that represents not being matched to anything can be matched to more than once
118 | w = self.adj[:, 0, :].clone()
119 | selected_weights = w.gather(1, selected).to(self.adj.device)
120 | skip = (selected == 0).float()
121 | num_skip = self.num_skip + skip
122 | if self.i == self.u_size + 1:
123 | min_sol = selected_weights
124 | else:
125 | m = self.min_sol.clone()
126 | m[m == 0.0] = 2.0
127 | selected_weights[skip.bool()] = 2.0
128 | min_sol = torch.minimum(m, selected_weights)
129 | selected_weights[selected_weights == 2.0] = 0.0
130 | min_sol[min_sol == 2.0] = 0.0
131 |
132 | max_sol = torch.maximum(self.max_sol, selected_weights)
133 | total_weights = self.size + selected_weights
134 | sum_sol_sq = self.sum_sol_sq + selected_weights ** 2
135 |
136 | hist_sum = self.hist_sum + w.unsqueeze(1)
137 | hist_sum_sq = self.hist_sum_sq + w.unsqueeze(1) ** 2
138 | hist_deg = self.hist_deg + (w.unsqueeze(1) != 0).float()
139 | hist_deg[:, :, 0] = float(self.i - self.u_size)
140 | return self._replace(
141 | matched_nodes=nodes,
142 | size=total_weights,
143 | i=self.i + 1,
144 | adj=self.adj[:, 1:, :],
145 | hist_sum=hist_sum,
146 | hist_sum_sq=hist_sum_sq,
147 | hist_deg=hist_deg,
148 | num_skip=num_skip,
149 | max_sol=max_sol,
150 | min_sol=min_sol,
151 | sum_sol_sq=sum_sol_sq,
152 | )
153 |
154 | def all_finished(self):
155 | # Exactly v_size steps
156 | return (self.i - (self.u_size + 1)) >= self.v_size
157 |
158 | def get_current_node(self):
159 | return 0
160 |
161 | def get_curr_state(self, model):
162 | mask = self.get_mask()
163 | opts = self.opts
164 | w = self.adj[:, 0, :].float().clone()
165 | s = None
166 | if model == "ff":
167 | s = torch.cat((w, mask.float()), dim=1)
168 |
169 | elif model == "inv-ff":
170 | deg = (w != 0).float().sum(1)
171 | deg[deg == 0.0] = 1.0
172 | mean_w = w.sum(1) / deg
173 | mean_w = mean_w[:, None, None].repeat(1, self.u_size + 1, 1)
174 | fixed_node_identity = torch.zeros(
175 | self.batch_size, self.u_size + 1, 1, device=opts.device
176 | ).float()
177 | fixed_node_identity[:, 0, :] = 1.0
178 | s = w.reshape(self.batch_size, self.u_size + 1, 1)
179 | s = torch.cat(
180 | (
181 | fixed_node_identity,
182 | s,
183 | mean_w,
184 | ),
185 | dim=2,
186 | )
187 |
188 | elif model == "ff-hist" or model == "ff-supervised":
189 | (
190 | h_mean,
191 | h_var,
192 | h_mean_degree,
193 | ind,
194 | matched_ratio,
195 | var_sol,
196 | mean_sol,
197 | n_skip,
198 | ) = self.get_hist_features()
199 | s = torch.cat(
200 | (
201 | w,
202 | mask.float(),
203 | h_mean.squeeze(1),
204 | h_var.squeeze(1),
205 | h_mean_degree.squeeze(1),
206 | self.size / self.u_size,
207 | ind.float(),
208 | mean_sol,
209 | var_sol,
210 | n_skip,
211 | self.max_sol,
212 | self.min_sol,
213 | matched_ratio,
214 | ),
215 | dim=1,
216 | ).float()
217 |
218 | elif model == "inv-ff-hist" or model == "gnn-simp-hist":
219 | deg = (w != 0).float().sum(1)
220 | deg[deg == 0.0] = 1.0
221 | mean_w = w.sum(1) / deg
222 | mean_w = mean_w[:, None, None].repeat(1, self.u_size + 1, 1)
223 | s = w.reshape(self.batch_size, self.u_size + 1, 1)
224 | (
225 | h_mean,
226 | h_var,
227 | h_mean_degree,
228 | ind,
229 | matched_ratio,
230 | var_sol,
231 | mean_sol,
232 | n_skip,
233 | ) = self.get_hist_features()
234 | available_ratio = (deg.unsqueeze(1)) / (self.u_size)
235 | fixed_node_identity = torch.zeros(
236 | self.batch_size, self.u_size + 1, 1, device=opts.device
237 | ).float()
238 | fixed_node_identity[:, 0, :] = 1.0
239 | s = torch.cat(
240 | (
241 | s,
242 | mask.reshape(-1, self.u_size + 1, 1).float(),
243 | mean_w,
244 | h_mean.transpose(1, 2),
245 | h_var.transpose(1, 2),
246 | h_mean_degree.transpose(1, 2),
247 | ind.unsqueeze(2).repeat(1, self.u_size + 1, 1),
248 | self.size.unsqueeze(2).repeat(1, self.u_size + 1, 1) / self.u_size,
249 | mean_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
250 | var_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
251 | n_skip.unsqueeze(2).repeat(1, self.u_size + 1, 1),
252 | self.max_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
253 | self.min_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
254 | matched_ratio.unsqueeze(2).repeat(1, self.u_size + 1, 1),
255 | available_ratio.unsqueeze(2).repeat(1, self.u_size + 1, 1),
256 | fixed_node_identity,
257 | ),
258 | dim=2,
259 | ).float()
260 |
261 | return s, mask
262 |
263 | def get_node_features(self):
264 | step_size = self.i + 1
265 | batch_size = self.batch_size
266 | incoming_node_features = (
267 | torch.cat(
268 | (torch.ones(step_size - self.u_size - 1, device=self.adj.device) * 2,)
269 | )
270 | .unsqueeze(0)
271 | .expand(batch_size, step_size - self.u_size - 1)
272 | ).float() # Collecting node features up until the ith incoming node
273 |
274 | future_node_feature = torch.ones(batch_size, 1, device=self.adj.device) * -1.0
275 | fixed_node_feature = self.matched_nodes[:, 1:]
276 | node_features = torch.cat(
277 | (future_node_feature, fixed_node_feature, incoming_node_features), dim=1
278 | ).reshape(batch_size, step_size)
279 |
280 | return node_features
281 |
282 | def get_hist_features(self):
283 | i = self.i - (self.u_size + 1)
284 | if i != 0:
285 | deg = self.hist_deg.clone()
286 | deg[deg == 0] = 1.0
287 | h_mean = self.hist_sum / deg
288 | h_var = (self.hist_sum_sq - ((self.hist_sum ** 2) / deg)) / deg
289 | h_mean_degree = self.hist_deg / i
290 | ind = (
291 | torch.ones(self.batch_size, 1, device=self.opts.device)
292 | * i
293 | / self.v_size
294 | )
295 | curr_sol_size = i - self.num_skip
296 | var_sol = (
297 | self.sum_sol_sq - ((self.size ** 2) / curr_sol_size)
298 | ) / curr_sol_size
299 | mean_sol = self.size / curr_sol_size
300 | var_sol[curr_sol_size == 0.0] = 0.0
301 | mean_sol[curr_sol_size == 0.0] = 0.0
302 | matched_ratio = self.matched_nodes.sum(1).unsqueeze(1) / self.u_size
303 | n_skip = self.num_skip / i
304 | else:
305 | (
306 | h_mean,
307 | h_var,
308 | h_mean_degree,
309 | ind,
310 | matched_ratio,
311 | var_sol,
312 | mean_sol,
313 | n_skip,
314 | ) = (
315 | self.hist_sum * 0.0,
316 | self.hist_sum * 0.0,
317 | self.hist_sum * 0.0,
318 | self.size * 0.0,
319 | self.num_skip * 0.0,
320 | self.size * 0.0,
321 | self.size * 0.0,
322 | self.size * 0.0,
323 | )
324 |
325 | return (
326 | h_mean,
327 | h_var,
328 | h_mean_degree,
329 | ind,
330 | matched_ratio,
331 | var_sol,
332 | mean_sol,
333 | n_skip,
334 | )
335 |
336 | def get_mask(self):
337 | """
338 |         Returns a mask over the nodes in U; an entry of 1 marks a node that cannot be
339 |         matched, i.e. a non-neighbor of the incoming node or an already matched node.
340 | """
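        # Illustrative example (hypothetical values): with u_size = 3,
        # adj[:, 0, :] = [[0.0, 0.5, 0.7, 0.0]] and fixed node 1 already matched, the
        # returned mask is [[0, 1, 0, 1]]. A 1 marks a node that cannot be selected
        # (already matched, or not a neighbor of the incoming node); index 0, the
        # "skip" option, is never masked.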
341 |
342 | mask = (self.adj[:, 0, :] == 0).float()
343 | mask[:, 0] = 0
344 | self.matched_nodes[
345 | :, 0
346 | ] = 0 # node that represents not being matched to anything can be matched to more than once
347 | return (
348 | self.matched_nodes + mask > 0
349 | ).long() # Hacky way to return bool or uint8 depending on pytorch version
350 |
351 |     ##### Works fine in most cases, but this module has not been unit tested; please
352 |     ##### pay extra attention to the reward reservation.
353 |
354 | def try_hypothesis_action(self, selected):
355 |
356 | w = self.adj[:, 0, :].clone()
357 | selected_weights = w.gather(1, selected).to(self.adj.device)
358 | skip = (selected == 0).float()
359 | if self.i != self.u_size + 1:
360 | selected_weights[skip.bool()] = 2.0
361 | selected_weights[selected_weights == 2.0] = 0.0
362 |
363 | total_weights = self.size + selected_weights
364 |
365 | matched_nodes = self.matched_nodes.clone().scatter_(-1, selected, 1)
366 | matched_nodes[:, 0] = 0 # node that represents not being matched to anything can be matched to more than once
367 |
368 | return total_weights, matched_nodes
369 |
--------------------------------------------------------------------------------
/data/gMission/workers.txt:
--------------------------------------------------------------------------------
1 | 1,1.6,1.36,1,0.971
2 | 2,0.36,3.32,1,0.641
3 | 3,2.2,3.41,1,0.904
4 | 4,2.45,4.32,1,0.749
5 | 5,3.75,1.1,1,0.944
6 | 6,2.56,2.67,1,0.914
7 | 7,4.38,0.07,1,0.642
8 | 8,2.04,3.25,1,0.926
9 | 9,2.27,1.42,1,0.665
10 | 10,0.23,0.04,1,0.677
11 | 11,4.68,3.38,1,0.916
12 | 12,3.77,0.94,1,0.947
13 | 13,0.48,1.37,1,0.787
14 | 14,1.97,2.45,1,0.733
15 | 15,0.87,3.77,1,0.909
16 | 16,0.85,3.86,1,0.872
17 | 17,2.53,3.52,1,0.911
18 | 18,4.05,3.16,1,0.846
19 | 19,2.13,1.77,1,0.801
20 | 20,0.12,4.23,1,0.763
21 | 21,0.91,0.3,1,0.926
22 | 22,4.87,1.99,1,0.772
23 | 23,1.5,1.6,1,0.753
24 | 24,0.88,3.67,1,0.95
25 | 25,0.97,0.5,1,0.886
26 | 26,3.71,0.15,1,0.791
27 | 27,2.5,4.83,1,0.659
28 | 28,4.9,1.89,1,0.904
29 | 29,3.56,2.76,1,0.764
30 | 30,1.6,3.79,1,0.978
31 | 31,0.45,2.11,1,0.868
32 | 32,4.51,4.79,1,0.909
33 | 33,4.59,4.94,1,0.758
34 | 34,2.51,4.33,1,0.863
35 | 35,2.93,4.66,1,0.981
36 | 36,3.65,4.35,1,0.698
37 | 37,2.74,3.13,1,0.64
38 | 38,4.17,1.63,1,0.928
39 | 39,2.54,2.96,1,0.904
40 | 40,4.46,0.87,1,0.717
41 | 41,4.72,4.69,1,0.946
42 | 42,0.5,3.0,1,0.76
43 | 43,3.92,0.38,1,0.951
44 | 44,0.74,0.62,1,0.976
45 | 45,4.45,1.26,1,0.972
46 | 46,0.2,0.87,1,0.263
47 | 47,1.55,0.1,1,0.687
48 | 48,1.53,2.63,1,0.764
49 | 49,1.62,1.33,1,0.862
50 | 50,3.31,3.18,1,0.775
51 | 51,4.78,4.31,1,0.853
52 | 52,3.52,1.86,1,0.676
53 | 53,0.2,0.55,1,0.82
54 | 54,3.96,0.27,1,0.862
55 | 55,0.99,3.18,1,0.929
56 | 56,2.16,2.32,1,0.979
57 | 57,0.9,3.14,1,0.718
58 | 58,4.73,4.38,1,0.717
59 | 59,0.28,4.66,1,0.894
60 | 60,1.03,1.66,1,0.632
61 | 61,0.93,1.65,1,0.834
62 | 62,2.19,0.82,1,0.775
63 | 63,4.81,0.55,1,0.804
64 | 64,4.6,3.36,1,0.933
65 | 65,1.25,3.89,1,0.71
66 | 66,4.73,4.73,1,0.83
67 | 67,4.58,3.17,1,0.699
68 | 68,4.58,0.24,1,0.989
69 | 69,4.44,1.92,1,0.849
70 | 70,1.48,2.43,1,0.95
71 | 71,2.42,3.01,1,0.693
72 | 72,3.93,2.63,1,0.942
73 | 73,1.85,4.95,1,0.979
74 | 74,3.78,1.09,1,0.871
75 | 75,2.76,2.82,1,0.875
76 | 76,1.84,3.39,1,0.73
77 | 77,4.59,0.94,1,0.825
78 | 78,4.44,3.39,1,0.807
79 | 79,4.92,2.78,1,0.857
80 | 80,1.54,1.42,1,0.877
81 | 81,4.04,1.0,1,0.658
82 | 82,1.61,3.45,1,0.712
83 | 83,1.38,2.3,1,0.757
84 | 84,4.34,4.73,1,0.813
85 | 85,3.77,0.01,1,0.875
86 | 86,0.01,2.22,1,0.849
87 | 87,1.46,4.51,1,0.783
88 | 88,1.89,0.12,1,0.957
89 | 89,0.24,1.28,1,0.774
90 | 90,4.3,1.67,1,0.96
91 | 91,1.28,0.37,1,0.772
92 | 92,4.38,1.61,1,0.657
93 | 93,2.82,1.5,1,0.765
94 | 94,2.81,2.24,1,0.713
95 | 95,4.69,0.19,1,0.695
96 | 96,0.5,1.88,1,0.961
97 | 97,4.54,3.21,1,0.81
98 | 98,0.01,3.24,1,0.951
99 | 99,0.09,2.67,1,0.765
100 | 100,3.46,0.24,1,0.932
101 | 101,0.28,3.26,1,0.68
102 | 102,4.64,4.44,1,0.644
103 | 103,2.76,1.97,1,0.815
104 | 104,4.68,4.61,1,0.83
105 | 105,1.85,3.87,1,0.692
106 | 106,2.37,1.39,1,0.869
107 | 107,0.94,3.29,1,0.771
108 | 108,3.99,4.98,1,0.715
109 | 109,3.85,4.44,1,0.893
110 | 110,4.3,0.03,1,0.78
111 | 111,4.16,2.04,1,0.93
112 | 112,3.45,0.26,1,0.88
113 | 113,1.81,1.11,1,0.96
114 | 114,3.6,4.66,1,0.733
115 | 115,2.61,0.81,1,0.961
116 | 116,0.14,3.22,1,0.666
117 | 117,4.01,4.43,1,0.884
118 | 118,2.11,0.03,1,0.903
119 | 119,0.71,0.4,1,0.801
120 | 120,0.53,0.16,1,0.639
121 | 121,0.78,3.43,1,0.638
122 | 122,0.96,4.02,1,0.911
123 | 123,3.44,4.96,1,0.798
124 | 124,0.7,4.24,1,0.828
125 | 125,4.92,0.28,1,0.39
126 | 126,4.11,0.82,1,0.706
127 | 127,4.69,4.26,1,0.676
128 | 128,4.38,3.57,1,0.738
129 | 129,2.59,3.34,1,0.84
130 | 130,3.2,1.98,1,0.735
131 | 131,0.55,3.34,1,0.768
132 | 132,4.32,1.77,1,0.816
133 | 133,4.53,3.53,1,0.857
134 | 134,3.45,0.79,1,0.833
135 | 135,4.55,0.74,1,0.918
136 | 136,3.25,2.86,1,0.808
137 | 137,3.35,4.37,1,0.769
138 | 138,3.19,2.21,1,0.915
139 | 139,0.56,4.86,1,0.72
140 | 140,4.6,2.58,1,0.685
141 | 141,0.7,4.14,1,0.825
142 | 142,3.08,1.2,1,0.836
143 | 143,4.17,1.05,1,0.706
144 | 144,1.84,2.73,1,0.284
145 | 145,1.44,3.46,1,0.927
146 | 146,0.02,2.58,1,0.297
147 | 147,0.64,0.91,1,0.873
148 | 148,3.02,0.77,1,0.915
149 | 149,3.86,1.49,1,0.716
150 | 150,2.07,0.09,1,0.794
151 | 151,3.78,3.57,1,0.91
152 | 152,3.3,2.2,1,0.929
153 | 153,0.72,3.61,1,0.978
154 | 154,0.39,0.83,1,0.778
155 | 155,3.07,4.11,1,0.788
156 | 156,0.92,3.49,1,0.858
157 | 157,1.07,3.08,1,0.658
158 | 158,3.69,1.06,1,0.811
159 | 159,4.08,3.73,1,0.939
160 | 160,2.47,1.63,1,0.971
161 | 161,2.65,3.39,1,0.761
162 | 162,0.49,0.97,1,0.891
163 | 163,2.83,0.5,1,0.846
164 | 164,1.01,4.84,1,0.834
165 | 165,1.16,4.73,1,0.732
166 | 166,1.55,3.2,1,0.671
167 | 167,2.7,2.88,1,0.917
168 | 168,0.81,3.46,1,0.816
169 | 169,1.01,3.43,1,0.958
170 | 170,4.95,2.63,1,0.947
171 | 171,0.23,3.85,1,0.872
172 | 172,3.95,2.15,1,0.731
173 | 173,2.35,1.75,1,0.635
174 | 174,0.22,2.65,1,0.781
175 | 175,4.36,0.95,1,0.752
176 | 176,1.9,4.79,1,0.894
177 | 177,3.39,4.82,1,0.907
178 | 178,2.52,4.94,1,0.988
179 | 179,3.93,0.75,1,0.937
180 | 180,0.06,3.72,1,0.67
181 | 181,3.95,4.39,1,0.228
182 | 182,4.04,1.89,1,0.809
183 | 183,0.03,3.69,1,0.951
184 | 184,4.82,4.35,1,0.867
185 | 185,4.51,2.25,1,0.906
186 | 186,2.08,0.34,1,0.664
187 | 187,0.55,4.08,1,0.828
188 | 188,3.3,2.34,1,0.979
189 | 189,4.79,1.92,1,0.633
190 | 190,2.66,1.96,1,0.657
191 | 191,4.31,3.51,1,0.932
192 | 192,0.1,1.52,1,0.923
193 | 193,1.29,1.43,1,0.808
194 | 194,0.46,1.95,1,0.88
195 | 195,2.8,0.95,1,0.979
196 | 196,1.56,4.59,1,0.653
197 | 197,4.56,1.97,1,0.786
198 | 198,1.07,3.59,1,0.645
199 | 199,0.79,2.23,1,0.986
200 | 200,0.48,1.39,1,0.865
201 | 201,4.65,2.43,1,0.933
202 | 202,4.58,4.89,1,0.668
203 | 203,2.01,4.24,1,0.841
204 | 204,4.2,3.1,1,0.762
205 | 205,3.18,0.78,1,0.633
206 | 206,4.76,4.46,1,0.972
207 | 207,1.21,2.83,1,0.652
208 | 208,2.45,4.48,1,0.832
209 | 209,3.65,0.55,1,0.318
210 | 210,1.39,2.57,1,0.71
211 | 211,4.52,2.94,1,0.806
212 | 212,3.77,3.03,1,0.903
213 | 213,0.53,2.4,1,0.916
214 | 214,3.03,1.06,1,0.983
215 | 215,0.83,3.95,1,0.887
216 | 216,2.24,2.15,1,0.917
217 | 217,3.46,3.96,1,0.973
218 | 218,3.76,4.13,1,0.946
219 | 219,4.64,3.31,1,0.762
220 | 220,1.6,2.45,1,0.837
221 | 221,1.6,0.87,1,0.219
222 | 222,2.81,0.88,1,0.823
223 | 223,4.95,3.38,1,0.93
224 | 224,0.91,3.1,1,0.685
225 | 225,0.08,1.17,1,0.782
226 | 226,4.9,4.36,1,0.762
227 | 227,1.28,0.99,1,0.883
228 | 228,4.53,1.92,1,0.824
229 | 229,2.44,4.15,1,0.787
230 | 230,1.14,1.2,1,0.844
231 | 231,4.23,0.79,1,0.678
232 | 232,0.02,4.2,1,0.321
233 | 233,1.29,4.25,1,0.916
234 | 234,0.56,4.06,1,0.893
235 | 235,0.83,4.08,1,0.648
236 | 236,2.96,0.98,1,0.731
237 | 237,3.73,0.43,1,0.77
238 | 238,4.84,0.77,1,0.694
239 | 239,1.82,1.15,1,0.392
240 | 240,2.61,0.12,1,0.645
241 | 241,1.07,0.59,1,0.983
242 | 242,3.04,4.4,1,0.978
243 | 243,0.81,0.48,1,0.784
244 | 244,2.42,0.33,1,0.657
245 | 245,0.93,2.21,1,0.89
246 | 246,0.64,0.48,1,0.931
247 | 247,4.89,4.1,1,0.821
248 | 248,0.21,2.17,1,0.744
249 | 249,2.64,4.36,1,0.803
250 | 250,0.72,1.69,1,0.835
251 | 251,1.45,4.15,1,0.404
252 | 252,2.32,4.5,1,0.65
253 | 253,0.39,3.32,1,0.928
254 | 254,0.98,1.67,1,0.931
255 | 255,0.99,3.37,1,0.89
256 | 256,3.76,2.18,1,0.872
257 | 257,1.19,1.0,1,0.908
258 | 258,3.59,4.56,1,0.831
259 | 259,0.61,3.1,1,0.637
260 | 260,1.02,2.64,1,0.974
261 | 261,0.35,3.97,1,0.953
262 | 262,0.23,4.49,1,0.858
263 | 263,4.62,2.16,1,0.844
264 | 264,5.0,2.38,1,0.813
265 | 265,0.07,2.84,1,0.635
266 | 266,2.22,4.42,1,0.883
267 | 267,3.94,0.01,1,0.92
268 | 268,1.02,3.38,1,0.858
269 | 269,0.88,0.33,1,0.963
270 | 270,1.11,0.82,1,0.806
271 | 271,3.06,2.5,1,0.827
272 | 272,0.51,4.13,1,0.756
273 | 273,1.94,2.41,1,0.632
274 | 274,0.4,2.25,1,0.778
275 | 275,3.08,3.22,1,0.721
276 | 276,3.36,2.71,1,0.87
277 | 277,4.54,2.61,1,0.723
278 | 278,2.71,2.44,1,0.952
279 | 279,0.32,0.78,1,0.702
280 | 280,2.59,2.19,1,0.933
281 | 281,2.76,4.42,1,0.726
282 | 282,1.33,4.01,1,0.745
283 | 283,1.35,3.66,1,0.911
284 | 284,3.97,2.05,1,0.839
285 | 285,0.95,0.42,1,0.75
286 | 286,4.69,2.73,1,0.669
287 | 287,2.15,0.51,1,0.334
288 | 288,4.56,3.53,1,0.989
289 | 289,0.57,2.84,1,0.852
290 | 290,1.34,4.68,1,0.809
291 | 291,1.86,4.38,1,0.637
292 | 292,3.12,4.38,1,0.654
293 | 293,2.21,3.09,1,0.79
294 | 294,1.27,2.01,1,0.882
295 | 295,3.24,0.98,1,0.907
296 | 296,1.19,3.28,1,0.659
297 | 297,1.15,1.61,1,0.695
298 | 298,0.16,1.8,1,0.758
299 | 299,3.5,4.56,1,0.724
300 | 300,3.97,2.79,1,0.639
301 | 301,4.13,2.09,1,0.84
302 | 302,4.74,3.71,1,0.705
303 | 303,0.58,4.19,1,0.84
304 | 304,3.34,0.76,1,0.861
305 | 305,4.95,1.01,1,0.643
306 | 306,3.04,4.31,1,0.679
307 | 307,4.89,3.43,1,0.868
308 | 308,4.37,1.31,1,0.989
309 | 309,1.76,3.61,1,0.727
310 | 310,1.25,4.81,1,0.882
311 | 311,0.13,2.76,1,0.842
312 | 312,1.55,2.49,1,0.673
313 | 313,0.24,3.0,1,0.853
314 | 314,3.07,0.85,1,0.918
315 | 315,4.04,1.54,1,0.814
316 | 316,0.06,3.68,1,0.963
317 | 317,1.93,0.35,1,0.96
318 | 318,2.38,1.76,1,0.662
319 | 319,1.11,3.09,1,0.661
320 | 320,2.17,4.74,1,0.714
321 | 321,2.05,3.88,1,0.845
322 | 322,1.33,2.54,1,0.961
323 | 323,3.58,2.6,1,0.703
324 | 324,2.6,1.27,1,0.796
325 | 325,0.5,3.44,1,0.638
326 | 326,1.08,4.43,1,0.635
327 | 327,4.09,2.15,1,0.755
328 | 328,3.85,3.64,1,0.747
329 | 329,2.11,3.89,1,0.634
330 | 330,3.88,0.24,1,0.788
331 | 331,1.47,3.26,1,0.778
332 | 332,2.68,2.23,1,0.831
333 | 333,3.01,4.34,1,0.92
334 | 334,0.65,4.99,1,0.765
335 | 335,0.43,2.64,1,0.741
336 | 336,0.51,3.04,1,0.912
337 | 337,2.96,0.93,1,0.848
338 | 338,1.09,2.72,1,0.841
339 | 339,4.17,2.36,1,0.867
340 | 340,2.92,2.24,1,0.945
341 | 341,0.74,3.42,1,0.918
342 | 342,1.75,1.05,1,0.832
343 | 343,4.2,1.38,1,0.807
344 | 344,1.78,1.33,1,0.98
345 | 345,2.99,4.55,1,0.723
346 | 346,1.94,0.39,1,0.73
347 | 347,3.66,3.55,1,0.673
348 | 348,1.13,3.74,1,0.875
349 | 349,4.37,3.41,1,0.81
350 | 350,2.17,1.52,1,0.93
351 | 351,3.76,3.99,1,0.773
352 | 352,2.85,2.48,1,0.722
353 | 353,3.4,2.59,1,0.652
354 | 354,2.3,4.82,1,0.74
355 | 355,3.18,1.49,1,0.656
356 | 356,4.68,2.36,1,0.724
357 | 357,3.12,3.55,1,0.967
358 | 358,4.03,1.33,1,0.815
359 | 359,2.21,4.2,1,0.756
360 | 360,4.45,3.11,1,0.821
361 | 361,0.09,2.69,1,0.843
362 | 362,3.24,1.1,1,0.774
363 | 363,1.76,2.4,1,0.896
364 | 364,4.18,3.3,1,0.877
365 | 365,2.81,2.65,1,0.695
366 | 366,3.17,2.79,1,0.972
367 | 367,3.4,0.29,1,0.986
368 | 368,0.52,4.51,1,0.631
369 | 369,2.54,0.2,1,0.95
370 | 370,2.75,3.68,1,0.924
371 | 371,0.19,0.76,1,0.81
372 | 372,3.99,3.2,1,0.788
373 | 373,4.02,4.81,1,0.778
374 | 374,0.44,1.57,1,0.335
375 | 375,0.86,1.4,1,0.967
376 | 376,1.46,4.67,1,0.834
377 | 377,2.74,0.82,1,0.902
378 | 378,2.18,4.36,1,0.851
379 | 379,4.89,0.34,1,0.728
380 | 380,4.5,3.27,1,0.862
381 | 381,1.74,3.87,1,0.78
382 | 382,4.58,0.19,1,0.884
383 | 383,3.49,0.31,1,0.714
384 | 384,4.76,3.79,1,0.965
385 | 385,3.48,3.75,1,0.942
386 | 386,2.4,4.65,1,0.95
387 | 387,1.63,1.21,1,0.642
388 | 388,1.32,2.25,1,0.959
389 | 389,1.49,1.88,1,0.856
390 | 390,0.78,3.36,1,0.945
391 | 391,2.49,0.03,1,0.693
392 | 392,1.49,4.11,1,0.981
393 | 393,0.26,3.6,1,0.964
394 | 394,3.42,2.65,1,0.852
395 | 395,2.18,0.02,1,0.874
396 | 396,0.5,2.14,1,0.654
397 | 397,4.3,3.76,1,0.815
398 | 398,1.21,1.84,1,0.711
399 | 399,3.32,4.61,1,0.637
400 | 400,2.34,2.81,1,0.8
401 | 401,1.92,3.23,1,0.885
402 | 402,3.31,1.86,1,0.789
403 | 403,0.03,4.11,1,0.929
404 | 404,4.23,2.72,1,0.711
405 | 405,1.05,2.19,1,0.742
406 | 406,3.12,3.04,1,0.987
407 | 407,4.06,2.09,1,0.762
408 | 408,3.98,0.56,1,0.956
409 | 409,0.5,3.15,1,0.769
410 | 410,1.12,3.22,1,0.711
411 | 411,3.51,4.82,1,0.728
412 | 412,1.63,3.95,1,0.956
413 | 413,1.35,1.04,1,0.946
414 | 414,4.51,1.16,1,0.387
415 | 415,3.64,4.39,1,0.747
416 | 416,2.5,4.27,1,0.883
417 | 417,2.31,4.59,1,0.841
418 | 418,2.11,1.79,1,0.695
419 | 419,2.55,2.68,1,0.98
420 | 420,2.14,4.69,1,0.925
421 | 421,0.14,4.77,1,0.969
422 | 422,2.59,2.72,1,0.852
423 | 423,2.74,4.7,1,0.984
424 | 424,4.2,0.16,1,0.966
425 | 425,4.83,0.31,1,0.754
426 | 426,0.83,1.68,1,0.904
427 | 427,1.84,2.26,1,0.873
428 | 428,4.52,3.15,1,0.764
429 | 429,3.41,3.99,1,0.851
430 | 430,3.64,3.69,1,0.652
431 | 431,3.59,4.16,1,0.906
432 | 432,1.17,0.54,1,0.643
433 | 433,0.54,0.99,1,0.864
434 | 434,4.24,0.36,1,0.933
435 | 435,2.56,2.71,1,0.836
436 | 436,4.84,4.96,1,0.703
437 | 437,2.28,4.76,1,0.901
438 | 438,2.11,1.19,1,0.934
439 | 439,0.64,2.75,1,0.856
440 | 440,4.26,4.16,1,0.87
441 | 441,3.45,3.92,1,0.645
442 | 442,3.7,3.22,1,0.936
443 | 443,4.54,4.31,1,0.982
444 | 444,4.58,3.52,1,0.739
445 | 445,1.58,1.18,1,0.949
446 | 446,4.65,2.24,1,0.838
447 | 447,2.75,3.97,1,0.9
448 | 448,0.78,3.0,1,0.889
449 | 449,4.7,4.25,1,0.756
450 | 450,4.36,1.93,1,0.876
451 | 451,1.4,3.68,1,0.663
452 | 452,3.2,4.84,1,0.835
453 | 453,1.18,2.08,1,0.739
454 | 454,0.74,3.55,1,0.813
455 | 455,2.56,2.15,1,0.37
456 | 456,3.59,3.41,1,0.883
457 | 457,2.98,4.94,1,0.73
458 | 458,0.12,4.85,1,0.824
459 | 459,3.52,3.83,1,0.967
460 | 460,2.72,1.31,1,0.874
461 | 461,4.33,4.18,1,0.237
462 | 462,3.04,1.15,1,0.796
463 | 463,3.8,4.79,1,0.8
464 | 464,2.71,1.98,1,0.633
465 | 465,3.39,1.49,1,0.909
466 | 466,3.32,0.26,1,0.691
467 | 467,3.8,0.74,1,0.839
468 | 468,1.5,1.33,1,0.654
469 | 469,2.04,4.56,1,0.969
470 | 470,3.15,0.14,1,0.688
471 | 471,1.07,2.0,1,0.802
472 | 472,3.25,1.95,1,0.984
473 | 473,3.95,4.34,1,0.892
474 | 474,1.4,3.76,1,0.917
475 | 475,2.08,1.74,1,0.648
476 | 476,0.85,4.82,1,0.7
477 | 477,0.4,1.58,1,0.67
478 | 478,3.57,0.53,1,0.777
479 | 479,4.03,1.46,1,0.321
480 | 480,4.09,3.06,1,0.976
481 | 481,4.93,4.05,1,0.813
482 | 482,1.18,2.31,1,0.827
483 | 483,4.43,2.05,1,0.731
484 | 484,2.87,1.01,1,0.694
485 | 485,2.49,0.4,1,0.856
486 | 486,2.84,1.24,1,0.69
487 | 487,3.94,1.52,1,0.647
488 | 488,4.28,2.37,1,0.796
489 | 489,1.24,2.65,1,0.981
490 | 490,1.35,1.05,1,0.742
491 | 491,3.33,2.46,1,0.983
492 | 492,2.2,2.25,1,0.711
493 | 493,3.71,2.04,1,0.81
494 | 494,2.48,3.01,1,0.74
495 | 495,4.45,0.01,1,0.739
496 | 496,3.0,1.35,1,0.822
497 | 497,1.92,3.89,1,0.823
498 | 498,1.46,2.31,1,0.658
499 | 499,4.46,2.56,1,0.943
500 | 500,0.38,3.51,1,0.736
501 | 501,1.77,2.08,1,0.638
502 | 502,4.11,0.58,1,0.796
503 | 503,0.91,3.24,1,0.984
504 | 504,3.29,0.27,1,0.881
505 | 505,2.93,4.44,1,0.68
506 | 506,4.99,0.98,1,0.834
507 | 507,1.19,1.49,1,0.657
508 | 508,0.65,4.84,1,0.922
509 | 509,4.37,3.66,1,0.248
510 | 510,2.28,3.29,1,0.696
511 | 511,3.07,4.13,1,0.734
512 | 512,2.8,1.64,1,0.893
513 | 513,3.55,4.69,1,0.835
514 | 514,4.75,2.87,1,0.952
515 | 515,3.88,1.21,1,0.858
516 | 516,2.46,1.66,1,0.856
517 | 517,2.96,2.22,1,0.808
518 | 518,2.8,2.86,1,0.849
519 | 519,0.62,4.39,1,0.679
520 | 520,1.98,0.18,1,0.85
521 | 521,2.13,2.36,1,0.831
522 | 522,0.78,0.97,1,0.855
523 | 523,1.63,3.61,1,0.689
524 | 524,0.56,3.98,1,0.959
525 | 525,3.92,1.55,1,0.83
526 | 526,2.4,1.1,1,0.746
527 | 527,0.12,4.63,1,0.89
528 | 528,0.6,2.4,1,0.909
529 | 529,3.99,1.64,1,0.95
530 | 530,1.58,1.84,1,0.713
531 | 531,1.1,0.9,1,0.905
532 | 532,0.46,2.42,1,0.769
533 |
--------------------------------------------------------------------------------
/problem_state/adwords_env.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import NamedTuple
3 | from torch_geometric.utils import to_dense_adj
4 | import torch.nn.functional as F
5 |
6 | # from utils.boolmask import mask_long2bool, mask_long_scatter
7 |
8 |
9 | class StateAdwordsBipartite(NamedTuple):
10 | # Fixed input
11 | graphs: torch.Tensor # graphs objects in a batch
12 | u_size: int
13 | v_size: int
14 | batch_size: torch.Tensor
15 | hist_sum: torch.Tensor
16 | hist_sum_sq: torch.Tensor
17 | hist_deg: torch.Tensor
18 | min_sol: torch.Tensor
19 | max_sol: torch.Tensor
20 | sum_sol_sq: torch.Tensor
21 | num_skip: torch.Tensor
22 | orig_budget: torch.Tensor
23 | curr_budget: torch.Tensor
24 |
25 | adj: torch.Tensor
26 | # State
27 | size: torch.Tensor # size of current matching
28 | i: int # Keeps track of step
29 | opts: dict
30 | idx: torch.Tensor
31 |
32 | @staticmethod
33 | def initialize(
34 | input,
35 | u_size,
36 | v_size,
37 | opts,
38 | ):
39 | graph_size = u_size + v_size + 1
40 | batch_size = int(input.batch.size(0) / graph_size)
41 | # print(batch_size, input.batch.size(0), graph_size)
42 | adj = to_dense_adj(
43 | input.edge_index,
44 | input.batch,
45 | input.weight.unsqueeze(1),
46 | ).squeeze(-1)
47 | adj = adj[:, u_size + 1 :, : u_size + 1]
48 | budgets = torch.cat(
49 | (torch.zeros(batch_size, 1).to(opts.device), input.x.reshape(batch_size, -1)), dim=1
50 | )
51 | # print(adj)
52 | # print(budgets)
53 | # permute the nodes for data
54 | idx = torch.arange(adj.shape[1], device=opts.device)
55 | # if "supervised" not in opts.model and not opts.eval_only:
56 | # idx = torch.randperm(adj.shape[1], device=opts.device)
57 | # adj = adj[:, idx, :].view(adj.size())
58 |
59 | return StateAdwordsBipartite(
60 | graphs=input,
61 | adj=adj.float(),
62 | u_size=u_size,
63 | v_size=v_size,
64 | batch_size=batch_size,
65 |             # Budgets are kept per fixed node, with index 0 (the "skip" node) fixed to 0
66 | orig_budget=budgets,
67 | curr_budget=budgets.clone(),
68 | hist_sum=(
69 | torch.zeros(
70 | batch_size,
71 | 1,
72 | u_size + 1,
73 | device=input.batch.device,
74 | )
75 | ),
76 | hist_deg=(
77 | torch.zeros(
78 | batch_size,
79 | 1,
80 | u_size + 1,
81 | device=input.batch.device,
82 | )
83 | ),
84 |             hist_sum_sq=(  # running sum of squared incoming edge weights, used for the variance features
85 | torch.zeros(
86 | batch_size,
87 | 1,
88 | u_size + 1,
89 | device=input.batch.device,
90 | )
91 | ),
92 | min_sol=torch.zeros(batch_size, 1, device=input.batch.device),
93 | max_sol=torch.zeros(batch_size, 1, device=input.batch.device),
94 | sum_sol_sq=torch.zeros(batch_size, 1, device=input.batch.device),
95 | num_skip=torch.zeros(batch_size, 1, device=input.batch.device),
96 | size=torch.zeros(batch_size, 1, device=input.batch.device),
97 | i=u_size + 1,
98 | opts=opts,
99 | idx=idx,
100 | )
101 |
102 | def get_final_cost(self):
103 |
104 | assert self.all_finished()
105 | # assert self.visited_.
106 |
107 | return self.size
108 |
109 | def get_current_weights(self, mask):
110 | return self.adj[:, 0, :].float()
111 |
112 | def get_graph_weights(self):
113 | return self.graphs.weight
114 |
115 | def update(self, selected):
116 | # Update the state
117 | w = self.adj[:, 0, :].clone()
118 | selected_weights = w.gather(1, selected).to(self.adj.device)
119 | one_hot_w = (
120 | F.one_hot(selected, num_classes=self.u_size + 1)
121 | .to(self.adj.device)
122 | .squeeze(1)
123 | .float()
124 | * selected_weights
125 | )
126 | curr_budget = self.curr_budget - one_hot_w
127 | curr_budget[curr_budget < 0.0] = 0.0
128 | skip = (selected == 0).float()
129 | num_skip = self.num_skip + skip
130 | if self.i == self.u_size + 1:
131 | min_sol = selected_weights
132 | else:
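            # 2.0 acts as a sentinel value (edge weights are assumed to stay below 2.0):
            # entries that are still 0.0 (no minimum recorded yet, or a skipped step) are
            # temporarily pushed above every real weight so they never win torch.minimum,
            # and are mapped back to 0.0 afterwards.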
133 | m = self.min_sol.clone()
134 | m[m == 0.0] = 2.0
135 | selected_weights[skip.bool()] = 2.0
136 | min_sol = torch.minimum(m, selected_weights)
137 | selected_weights[selected_weights == 2.0] = 0.0
138 | min_sol[min_sol == 2.0] = 0.0
139 |
140 | max_sol = torch.maximum(self.max_sol, selected_weights)
141 | total_weights = self.size + selected_weights
142 | sum_sol_sq = self.sum_sol_sq + selected_weights ** 2
143 |
144 | hist_sum = self.hist_sum + w.unsqueeze(1)
145 | hist_sum_sq = self.hist_sum_sq + w.unsqueeze(1) ** 2
146 | hist_deg = self.hist_deg + (w.unsqueeze(1) != 0).float()
147 | hist_deg[:, :, 0] = float(self.i - self.u_size)
148 | return self._replace(
149 | curr_budget=curr_budget,
150 | size=total_weights,
151 | i=self.i + 1,
152 | adj=self.adj[:, 1:, :],
153 | hist_sum=hist_sum,
154 | hist_sum_sq=hist_sum_sq,
155 | hist_deg=hist_deg,
156 | num_skip=num_skip,
157 | max_sol=max_sol,
158 | min_sol=min_sol,
159 | sum_sol_sq=sum_sol_sq,
160 | )
161 |
162 | def all_finished(self):
163 | # Exactly v_size steps
164 | return (self.i - (self.u_size + 1)) >= self.v_size
165 |
166 | def get_current_node(self):
167 | return 0
168 |
169 | def get_curr_state(self, model):
170 | mask = self.get_mask()
171 | opts = self.opts
172 | w = self.adj[:, 0, :].float().clone()
173 | s = None
174 | if model == "ff":
175 | s = torch.cat((w, self.curr_budget, mask.float()), dim=1).float()
176 | elif model == "inv-ff":
177 | deg = (w != 0).float().sum(1)
178 | deg[deg == 0.0] = 1.0
179 | mean_w = w.sum(1) / deg
180 | mean_budget = self.curr_budget.sum(1) / self.u_size
181 | mean_budget = mean_budget[:, None, None].repeat(1, self.u_size + 1, 1)
182 | mean_w = mean_w[:, None, None].repeat(1, self.u_size + 1, 1)
183 | fixed_node_identity = torch.zeros(
184 | self.batch_size, self.u_size + 1, 1, device=opts.device
185 | ).float()
186 | fixed_node_identity[:, 0, :] = 1.0
187 | s = w.reshape(self.batch_size, self.u_size + 1, 1)
188 | s = torch.cat(
189 | (
190 | fixed_node_identity,
191 | self.curr_budget.reshape(self.batch_size, self.u_size + 1, 1),
192 | s,
193 | mean_w,
194 | mean_budget,
195 | ),
196 | dim=2,
197 | ).float()
198 |
199 | elif model == "ff-hist" or model == "ff-supervised":
200 | (
201 | h_mean,
202 | h_var,
203 | h_mean_degree,
204 | ind,
205 | matched_ratio,
206 | var_sol,
207 | mean_sol,
208 | n_skip,
209 | ) = self.get_hist_features()
210 | s = torch.cat(
211 | (
212 | w,
213 | mask.float(),
214 | self.orig_budget,
215 | self.curr_budget,
216 | h_mean.squeeze(1),
217 | h_var.squeeze(1),
218 | h_mean_degree.squeeze(1),
219 | self.size / self.orig_budget.sum(-1).unsqueeze(1),
220 | ind.float(),
221 | mean_sol,
222 | var_sol,
223 | n_skip,
224 | self.max_sol,
225 | self.min_sol,
226 | matched_ratio,
227 | ),
228 | dim=1,
229 | ).float()
230 |
231 | elif model == "inv-ff-hist" or model == "gnn-simp-hist":
232 | deg = (w != 0).float().sum(1)
233 | deg[deg == 0.0] = 1.0
234 | mean_w = w.sum(1) / deg
235 | mean_budget = self.curr_budget.sum(1) / self.u_size
236 | mean_w = mean_w[:, None, None].repeat(1, self.u_size + 1, 1)
237 | mean_budget = mean_budget[:, None, None].repeat(1, self.u_size + 1, 1)
238 | s = w.reshape(self.batch_size, self.u_size + 1, 1)
239 | (
240 | h_mean,
241 | h_var,
242 | h_mean_degree,
243 | ind,
244 | matched_ratio,
245 | var_sol,
246 | mean_sol,
247 | n_skip,
248 | ) = self.get_hist_features()
249 | available_ratio = (deg.unsqueeze(1)) / (self.u_size)
250 | fixed_node_identity = torch.zeros(
251 | self.batch_size, self.u_size + 1, 1, device=opts.device
252 | ).float()
253 | fixed_node_identity[:, 0, :] = 1.0
254 | s = torch.cat(
255 | (
256 | s,
257 | mask.reshape(-1, self.u_size + 1, 1).float(),
258 | self.orig_budget.reshape(self.batch_size, self.u_size + 1, 1),
259 | self.curr_budget.reshape(self.batch_size, self.u_size + 1, 1),
260 | mean_w,
261 | mean_budget,
262 | h_mean.transpose(1, 2),
263 | h_var.transpose(1, 2),
264 | h_mean_degree.transpose(1, 2),
265 | ind.unsqueeze(2).repeat(1, self.u_size + 1, 1),
266 | self.size.unsqueeze(2).repeat(1, self.u_size + 1, 1)
267 | / self.orig_budget.sum(-1)[:, None, None],
268 | mean_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
269 | var_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
270 | n_skip.unsqueeze(2).repeat(1, self.u_size + 1, 1),
271 | self.max_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
272 | self.min_sol.unsqueeze(2).repeat(1, self.u_size + 1, 1),
273 | matched_ratio.unsqueeze(2).repeat(1, self.u_size + 1, 1),
274 | available_ratio.unsqueeze(2).repeat(1, self.u_size + 1, 1),
275 | fixed_node_identity,
276 | ),
277 | dim=2,
278 | ).float()
279 |
280 | return s, mask
281 |
282 | def get_node_features(self):
283 | step_size = self.i + 1
284 | batch_size = self.batch_size
285 | incoming_node_features = (
286 | torch.cat(
287 | (
288 | torch.ones(step_size - self.u_size - 1, device=self.adj.device)
289 | * -2.0,
290 | )
291 | )
292 | .unsqueeze(0)
293 | .expand(batch_size, step_size - self.u_size - 1)
294 | ).float() # Collecting node features up until the ith incoming node
295 |
296 | future_node_feature = torch.ones(batch_size, 1, device=self.adj.device) * -1.0
297 | fixed_node_feature = self.curr_budget[:, 1:]
298 | node_features = torch.cat(
299 | (future_node_feature, fixed_node_feature, incoming_node_features), dim=1
300 | ).reshape(batch_size, step_size)
301 |
302 | return node_features.float()
303 |
304 | def get_hist_features(self):
305 | i = self.i - (self.u_size + 1)
306 | if i != 0:
307 | deg = self.hist_deg.clone()
308 | deg[deg == 0] = 1.0
309 | h_mean = self.hist_sum / deg
310 | h_var = (self.hist_sum_sq - ((self.hist_sum ** 2) / deg)) / deg
311 | h_mean_degree = self.hist_deg / i
312 | ind = (
313 | torch.ones(self.batch_size, 1, device=self.opts.device)
314 | * i
315 | / self.v_size
316 | )
317 | curr_sol_size = i - self.num_skip
318 | var_sol = (
319 | self.sum_sol_sq - ((self.size ** 2) / curr_sol_size)
320 | ) / curr_sol_size
321 | mean_sol = self.size / curr_sol_size
322 | var_sol[curr_sol_size == 0.0] = 0.0
323 | mean_sol[curr_sol_size == 0.0] = 0.0
324 | avg_budget = self.curr_budget.sum(1).unsqueeze(1) / self.u_size
325 | n_skip = self.num_skip / i
326 | else:
327 | (
328 | h_mean,
329 | h_var,
330 | h_mean_degree,
331 | ind,
332 | avg_budget,
333 | var_sol,
334 | mean_sol,
335 | n_skip,
336 | ) = (
337 | self.hist_sum * 0.0,
338 | self.hist_sum * 0.0,
339 | self.hist_sum * 0.0,
340 | self.size * 0.0,
341 | self.curr_budget.sum(1).unsqueeze(1) / self.u_size,
342 | self.size * 0.0,
343 | self.size * 0.0,
344 | self.size * 0.0,
345 | )
346 |
347 | return (
348 | h_mean,
349 | h_var,
350 | h_mean_degree,
351 | ind,
352 | avg_budget,
353 | var_sol,
354 | mean_sol,
355 | n_skip,
356 | )
357 |
358 | def get_mask(self):
359 | """
360 |         Returns a mask over the nodes in U; an entry of 1 marks a node that cannot be
361 |         selected, i.e. a non-neighbor or a node whose remaining budget is below the bid.
362 | """
363 |
364 | mask = (self.adj[:, 0, :] == 0.0).float()
365 | mask[:, 0] = 0.0
366 | budget_mask = ((self.adj[:, 0, :] - self.curr_budget) > 1e-5).float()
367 | return (
368 | budget_mask + mask > 0.0
369 | ).long() # Hacky way to return bool or uint8 depending on pytorch version
370 |
--------------------------------------------------------------------------------
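For orientation, a minimal sketch of how a state object such as StateAdwordsBipartite is driven one arriving node at a time, assuming a hypothetical policy network (policy below) that scores the u_size + 1 actions and a PyTorch Geometric batch prepared elsewhere:

    state = StateAdwordsBipartite.initialize(batch, u_size, v_size, opts)
    while not state.all_finished():
        s, mask = state.get_curr_state("ff")           # features and infeasibility mask
        scores = policy(s)                             # placeholder net -> (batch, u_size + 1)
        scores[mask.bool()] = -float("inf")            # a 1 in the mask means "cannot select"
        selected = scores.argmax(dim=1, keepdim=True)  # index 0 means "skip this arrival"
        state = state.update(selected)
    reward = state.get_final_cost()                    # total weight collected per instance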