├── .gitignore
├── 02_generate_dataset.py
├── 03_train_gnn.py
├── 04_evaluate.py
├── LICENSE
├── README.md
├── model
│   └── model.py
└── utilities.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /02_generate_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import gzip 4 | import argparse 5 | import pickle 6 | import queue 7 | import shutil 8 | import threading 9 | import numpy as np 10 | import ecole 11 | from collections import namedtuple 12 | 13 | 14 | class ExploreThenStrongBranch: 15 | def __init__(self, expert_probability): 16 | self.expert_probability = expert_probability 17 | self.pseudocosts_function = ecole.observation.Pseudocosts() 18 | self.strong_branching_function = ecole.observation.StrongBranchingScores() 19 | 20 | def before_reset(self, model): 21 | self.pseudocosts_function.before_reset(model) 22 | self.strong_branching_function.before_reset(model) 23 | 24 | def extract(self, model, done): 25 | probabilities = [1-self.expert_probability, self.expert_probability] 26 | expert_chosen = bool(np.random.choice(np.arange(2), p=probabilities)) 27 | if expert_chosen: 28 | return (self.strong_branching_function.extract(model,done), True) 29 | else: 30 | return (self.pseudocosts_function.extract(model,done), False) 31 | 32 | 33 | def send_orders(orders_queue, instances, seed, query_expert_prob, time_limit, out_dir, stop_flag): 34 | """ 35 | Continuously send sampling orders to workers (relies on limited 36 | queue capacity). 37 | 38 | Parameters 39 | ---------- 40 | orders_queue : queue.Queue 41 | Queue to which to send orders. 42 | instances : list 43 | Instance file names from which to sample episodes. 44 | seed : int 45 | Random seed for reproducibility. 46 | query_expert_prob : float in [0, 1] 47 | Probability of running the expert strategy and collecting samples. 48 | time_limit : float in [0, 1e+20] 49 | Maximum running time for an episode, in seconds. 50 | out_dir: str 51 | Output directory in which to write samples. 52 | stop_flag: threading.Event 53 | A flag to tell the thread to stop. 54 | """ 55 | rng = np.random.RandomState(seed) 56 | 57 | episode = 0 58 | while not stop_flag.is_set(): 59 | instance = rng.choice(instances) 60 | seed = rng.randint(2**32) 61 | orders_queue.put([episode, instance, seed, query_expert_prob, time_limit, out_dir]) 62 | episode += 1 63 | 64 | 65 | def make_samples(in_queue, out_queue, stop_flag): 66 | """ 67 | Worker loop: fetch an instance, run an episode and record samples. 68 | Parameters 69 | ---------- 70 | in_queue : queue.Queue 71 | Input queue from which orders are received. 72 | out_queue : queue.Queue 73 | Output queue in which to send samples. 74 | stop_flag: threading.Event 75 | A flag to tell the thread to stop. 
76 |     """
77 |     sample_counter = 0
78 |     while not stop_flag.is_set():
79 |         episode, instance, seed, query_expert_prob, time_limit, out_dir = in_queue.get()
80 | 
81 |         scip_parameters = {'separating/maxrounds': 0, 'presolving/maxrestarts': 0,
82 |                            'limits/time': time_limit, 'timing/clocktype': 2}
83 |         observation_function = { "scores": ExploreThenStrongBranch(expert_probability=query_expert_prob),
84 |                                  "node_observation": ecole.observation.NodeBipartite() }
85 |         env = ecole.environment.Branching(observation_function=observation_function,
86 |                                           scip_params=scip_parameters, pseudo_candidates=True)
87 | 
88 |         print(f"[w {threading.current_thread().name}] episode {episode}, seed {seed}, "
89 |               f"processing instance '{instance}'...\n", end='')
90 |         out_queue.put({
91 |             'type': 'start',
92 |             'episode': episode,
93 |             'instance': instance,
94 |             'seed': seed,
95 |         })
96 | 
97 |         env.seed(seed)
98 |         observation, action_set, _, done, _ = env.reset(instance)
99 |         while not done:
100 |             scores, scores_are_expert = observation["scores"]
101 |             node_observation = observation["node_observation"]
102 |             node_observation = (node_observation.row_features,
103 |                                 (node_observation.edge_features.indices,
104 |                                  node_observation.edge_features.values),
105 |                                 node_observation.variable_features)
106 | 
107 |             action = action_set[scores[action_set].argmax()]
108 | 
109 |             if scores_are_expert and not stop_flag.is_set():
110 |                 data = [node_observation, action, action_set, scores]
111 |                 filename = f'{out_dir}/sample_{episode}_{sample_counter}.pkl'
112 | 
113 |                 with gzip.open(filename, 'wb') as f:
114 |                     pickle.dump({
115 |                         'episode': episode,
116 |                         'instance': instance,
117 |                         'seed': seed,
118 |                         'data': data,
119 |                     }, f)
120 |                 out_queue.put({
121 |                     'type': 'sample',
122 |                     'episode': episode,
123 |                     'instance': instance,
124 |                     'seed': seed,
125 |                     'filename': filename,
126 |                 })
127 |                 sample_counter += 1
128 | 
129 |             try:
130 |                 observation, action_set, _, done, _ = env.step(action)
131 |             except Exception as e:
132 |                 done = True
133 |                 with open("error_log.txt", "a") as f:
134 |                     f.write(f"Error occurred solving {instance} with seed {seed}\n")
135 |                     f.write(f"{e}\n")
136 | 
137 |         print(f"[w {threading.current_thread().name}] episode {episode} done, {sample_counter} samples\n", end='')
138 |         out_queue.put({
139 |             'type': 'done',
140 |             'episode': episode,
141 |             'instance': instance,
142 |             'seed': seed,
143 |         })
144 | 
145 | 
146 | def collect_samples(instances, out_dir, rng, n_samples, n_jobs,
147 |                     query_expert_prob, time_limit):
148 |     """
149 |     Runs branch-and-bound episodes on the given set of instances, and randomly
150 |     collects (state, action) pairs labelled by the strong branching expert
151 |     along the way.
152 |     Parameters
153 |     ----------
154 |     instances : list
155 |         Instance files from which to collect samples.
156 |     out_dir : str
157 |         Directory in which to write samples.
158 |     rng : numpy.random.RandomState
159 |         A random number generator for reproducibility.
160 |     n_samples : int
161 |         Number of samples to collect.
162 |     n_jobs : int
163 |         Number of jobs for parallel sampling.
164 |     query_expert_prob : float in [0, 1]
165 |         Probability of using the expert policy and recording a (state, action)
166 |         pair.
167 |     time_limit : float in [0, 1e+20]
168 |         Maximum running time for an episode, in seconds.
169 |     """
170 |     os.makedirs(out_dir, exist_ok=True)
171 | 
172 |     # start workers
173 |     orders_queue = queue.Queue(maxsize=2*n_jobs)
174 |     answers_queue = queue.SimpleQueue()
175 | 
176 |     tmp_samples_dir = f'{out_dir}/tmp'
177 |     os.makedirs(tmp_samples_dir, exist_ok=True)
178 | 
179 |     # start dispatcher
180 |     dispatcher_stop_flag = threading.Event()
181 |     dispatcher = threading.Thread(
182 |             target=send_orders,
183 |             args=(orders_queue, instances, rng.randint(2**32), query_expert_prob,
184 |                   time_limit, tmp_samples_dir, dispatcher_stop_flag),
185 |             daemon=True)
186 |     dispatcher.start()
187 | 
188 |     workers = []
189 |     workers_stop_flag = threading.Event()
190 |     for i in range(n_jobs):
191 |         p = threading.Thread(
192 |                 target=make_samples,
193 |                 args=(orders_queue, answers_queue, workers_stop_flag),
194 |                 daemon=True)
195 |         workers.append(p)
196 |         p.start()
197 | 
198 |     # record answers and write samples
199 |     buffer = {}
200 |     current_episode = 0
201 |     i = 0
202 |     in_buffer = 0
203 |     while i < n_samples:
204 |         sample = answers_queue.get()
205 | 
206 |         # add received sample to buffer
207 |         if sample['type'] == 'start':
208 |             buffer[sample['episode']] = []
209 |         else:
210 |             buffer[sample['episode']].append(sample)
211 |             if sample['type'] == 'sample':
212 |                 in_buffer += 1
213 | 
214 |         # if any, write samples from current episode
215 |         while current_episode in buffer and buffer[current_episode]:
216 |             samples_to_write = buffer[current_episode]
217 |             buffer[current_episode] = []
218 | 
219 |             for sample in samples_to_write:
220 | 
221 |                 # if no more samples here, move to next episode
222 |                 if sample['type'] == 'done':
223 |                     del buffer[current_episode]
224 |                     current_episode += 1
225 | 
226 |                 # else write sample
227 |                 else:
228 |                     os.rename(sample['filename'], f'{out_dir}/sample_{i+1}.pkl')
229 |                     in_buffer -= 1
230 |                     i += 1
231 |                     print(f"[m {threading.current_thread().name}] {i} / {n_samples} samples written, "
232 |                           f"ep {sample['episode']} ({in_buffer} in buffer).\n", end='')
233 | 
234 |                     # early stop dispatcher
235 |                     if in_buffer + i >= n_samples and dispatcher.is_alive():
236 |                         dispatcher_stop_flag.set()
237 |                         print(f"[m {threading.current_thread().name}] dispatcher stopped...\n", end='')
238 | 
239 |                     # as soon as enough samples are collected, stop
240 |                     if i == n_samples:
241 |                         buffer = {}
242 |                         break
243 | 
244 |     # stop all workers
245 |     workers_stop_flag.set()
246 |     for p in workers:
247 |         p.join()
248 | 
249 |     print(f"Done collecting samples for {out_dir}")
250 |     shutil.rmtree(tmp_samples_dir, ignore_errors=True)
251 | 
252 | 
253 | if __name__ == '__main__':
254 |     parser = argparse.ArgumentParser()
255 |     parser.add_argument(
256 |         'problem',
257 |         help='MILP instance type to process.',
258 |         choices=['setcover', 'cauctions', 'facilities', 'indset', 'mknapsack'],
259 |     )
260 |     parser.add_argument(
261 |         '-s', '--seed',
262 |         help='Random generator seed.',
263 |         type=int,
264 |         default=0,
265 |     )
266 |     parser.add_argument(
267 |         '-j', '--njobs',
268 |         help='Number of parallel jobs.',
269 |         type=int,
270 |         default=1,
271 |     )
272 |     args = parser.parse_args()
273 | 
274 |     print(f"seed {args.seed}")
275 | 
276 |     train_size = 100000
277 |     valid_size = 20000
278 |     test_size = 20000
279 |     node_record_prob = 0.05
280 |     time_limit = 3600
281 | 
282 |     if args.problem == 'setcover':
283 |         instances_train = glob.glob('data/instances/setcover/train_500r_1000c_0.05d/*.lp')
284 |         instances_valid = glob.glob('data/instances/setcover/valid_500r_1000c_0.05d/*.lp')
285 |         instances_test = glob.glob('data/instances/setcover/test_500r_1000c_0.05d/*.lp')
286 |         out_dir = 'data/samples/setcover/500r_1000c_0.05d'
287 | 
288 |     elif args.problem == 'cauctions':
289 |         instances_train = glob.glob('data/instances/cauctions/train_100_500/*.lp')
290 |         instances_valid = glob.glob('data/instances/cauctions/valid_100_500/*.lp')
291 |         instances_test = glob.glob('data/instances/cauctions/test_100_500/*.lp')
292 |         out_dir = 'data/samples/cauctions/100_500'
293 | 
294 |     elif args.problem == 'indset':
295 |         instances_train = glob.glob('data/instances/indset/train_500_4/*.lp')
296 |         instances_valid = glob.glob('data/instances/indset/valid_500_4/*.lp')
297 |         instances_test = glob.glob('data/instances/indset/test_500_4/*.lp')
298 |         out_dir = 'data/samples/indset/500_4'
299 | 
300 |     elif args.problem == 'facilities':
301 |         instances_train = glob.glob('data/instances/facilities/train_100_100_5/*.lp')
302 |         instances_valid = glob.glob('data/instances/facilities/valid_100_100_5/*.lp')
303 |         instances_test = glob.glob('data/instances/facilities/test_100_100_5/*.lp')
304 |         out_dir = 'data/samples/facilities/100_100_5'
305 |         time_limit = 600
306 | 
307 |     elif args.problem == 'mknapsack':
308 |         instances_train = glob.glob('data/instances/mknapsack/train_100_6/*.lp')
309 |         instances_valid = glob.glob('data/instances/mknapsack/valid_100_6/*.lp')
310 |         instances_test = glob.glob('data/instances/mknapsack/test_100_6/*.lp')
311 |         out_dir = 'data/samples/mknapsack/100_6'
312 |         time_limit = 60
313 | 
314 |     else:
315 |         raise NotImplementedError
316 | 
317 |     print(f"{len(instances_train)} train instances for {train_size} samples")
318 |     print(f"{len(instances_valid)} validation instances for {valid_size} samples")
319 |     print(f"{len(instances_test)} test instances for {test_size} samples")
320 | 
321 |     # create the output directory if needed (exist_ok avoids errors on re-runs)
322 |     os.makedirs(out_dir, exist_ok=True)
323 | 
324 |     rng = np.random.RandomState(args.seed)
325 |     collect_samples(instances_train, out_dir + '/train', rng, train_size,
326 |                     args.njobs, query_expert_prob=node_record_prob,
327 |                     time_limit=time_limit)
328 | 
329 |     rng = np.random.RandomState(args.seed + 1)
330 |     collect_samples(instances_valid, out_dir + '/valid', rng, valid_size,
331 |                     args.njobs, query_expert_prob=node_record_prob,
332 |                     time_limit=time_limit)
333 | 
334 |     rng = np.random.RandomState(args.seed + 2)
335 |     collect_samples(instances_test, out_dir + '/test', rng, test_size,
336 |                     args.njobs, query_expert_prob=node_record_prob,
337 |                     time_limit=time_limit)
338 | 
--------------------------------------------------------------------------------
/03_train_gnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import pathlib
5 | import numpy as np
6 | 
7 | 
8 | def pretrain(policy, pretrain_loader):  # calibrates the PreNormLayer shift/scale buffers, one layer per outer iteration
9 |     policy.pre_train_init()
10 |     i = 0
11 |     while True:
12 |         for batch in pretrain_loader:
13 |             batch.to(device)
14 |             if not policy.pre_train(batch.constraint_features, batch.edge_index, batch.edge_attr, batch.variable_features):
15 |                 break
16 | 
17 |         if policy.pre_train_next() is None:
18 |             break
19 |         i += 1
20 |     return i
21 | 
22 | 
23 | def process(policy, data_loader, top_k=[1, 3, 5, 10], optimizer=None):  # one pass over the data; trains when an optimizer is given
24 |     mean_loss = 0
25 |     mean_kacc = np.zeros(len(top_k))
26 |     mean_entropy = 0
27 | 
28 |     n_samples_processed = 0
29 |     with torch.set_grad_enabled(optimizer is not None):
30 |         for batch in data_loader:
31 |             batch = batch.to(device)
32 |             logits = 
policy(batch.constraint_features, batch.edge_index, batch.edge_attr, batch.variable_features) 33 | logits = pad_tensor(logits[batch.candidates], batch.nb_candidates) 34 | cross_entropy_loss = F.cross_entropy(logits, batch.candidate_choices, reduction='mean') 35 | entropy = (-F.softmax(logits, dim=-1)*F.log_softmax(logits, dim=-1)).sum(-1).mean() 36 | loss = cross_entropy_loss - entropy_bonus*entropy 37 | 38 | if optimizer is not None: 39 | optimizer.zero_grad() 40 | loss.backward() 41 | optimizer.step() 42 | 43 | true_scores = pad_tensor(batch.candidate_scores, batch.nb_candidates) 44 | true_bestscore = true_scores.max(dim=-1, keepdims=True).values 45 | 46 | kacc = [] 47 | for k in top_k: 48 | if logits.size()[-1] < k: 49 | kacc.append(1.0) 50 | continue 51 | pred_top_k = logits.topk(k).indices 52 | pred_top_k_true_scores = true_scores.gather(-1, pred_top_k) 53 | accuracy = (pred_top_k_true_scores == true_bestscore).any(dim=-1).float().mean().item() 54 | kacc.append(accuracy) 55 | kacc = np.asarray(kacc) 56 | mean_loss += cross_entropy_loss.item() * batch.num_graphs 57 | mean_entropy += entropy.item() * batch.num_graphs 58 | mean_kacc += kacc * batch.num_graphs 59 | n_samples_processed += batch.num_graphs 60 | 61 | mean_loss /= n_samples_processed 62 | mean_kacc /= n_samples_processed 63 | mean_entropy /= n_samples_processed 64 | return mean_loss, mean_kacc, mean_entropy 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument( 70 | 'problem', 71 | help='MILP instance type to process.', 72 | choices=['setcover', 'cauctions', 'facilities', 'indset', 'mknapsack'], 73 | ) 74 | parser.add_argument( 75 | '-s', '--seed', 76 | help='Random generator seed.', 77 | type=int, 78 | default=0, 79 | ) 80 | parser.add_argument( 81 | '-g', '--gpu', 82 | help='CUDA GPU id (-1 for CPU).', 83 | type=int, 84 | default=0, 85 | ) 86 | args = parser.parse_args() 87 | 88 | ### HYPER PARAMETERS ### 89 | max_epochs = 1000 90 | batch_size = 32 91 | pretrain_batch_size = 128 92 | valid_batch_size = 128 93 | lr = 1e-3 94 | entropy_bonus = 0.0 95 | top_k = [1, 3, 5, 10] 96 | 97 | problem_folders = { 98 | 'setcover': 'setcover/500r_1000c_0.05d', 99 | 'cauctions': 'cauctions/100_500', 100 | 'facilities': 'facilities/100_100_5', 101 | 'indset': 'indset/500_4', 102 | 'mknapsack': 'mknapsack/100_6', 103 | } 104 | problem_folder = problem_folders[args.problem] 105 | running_dir = f"model/{args.problem}/{args.seed}" 106 | os.makedirs(running_dir, exist_ok=True) 107 | 108 | ### PYTORCH SETUP ### 109 | if args.gpu == -1: 110 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 111 | device = "cpu" 112 | else: 113 | os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpu}' 114 | device = f"cuda:0" 115 | import torch 116 | import torch.nn.functional as F 117 | import torch_geometric 118 | from utilities import log, pad_tensor, GraphDataset, Scheduler 119 | sys.path.insert(0, os.path.abspath(f'model')) 120 | from model import GNNPolicy 121 | 122 | rng = np.random.RandomState(args.seed) 123 | torch.manual_seed(args.seed) 124 | 125 | ### LOG ### 126 | logfile = os.path.join(running_dir, 'train_log.txt') 127 | if os.path.exists(logfile): 128 | os.remove(logfile) 129 | 130 | log(f"max_epochs: {max_epochs}", logfile) 131 | log(f"batch_size: {batch_size}", logfile) 132 | log(f"pretrain_batch_size: {pretrain_batch_size}", logfile) 133 | log(f"valid_batch_size : {valid_batch_size }", logfile) 134 | log(f"lr: {lr}", logfile) 135 | log(f"entropy bonus: {entropy_bonus}", logfile) 136 | log(f"top_k: {top_k}", 
logfile)
137 |     log(f"problem: {args.problem}", logfile)
138 |     log(f"gpu: {args.gpu}", logfile)
139 |     log(f"seed: {args.seed}", logfile)
140 | 
141 | 
142 |     policy = GNNPolicy().to(device)
143 |     optimizer = torch.optim.Adam(policy.parameters(), lr=lr)  # use the lr hyperparameter defined above
144 |     scheduler = Scheduler(optimizer, mode='min', patience=10, factor=0.2, verbose=True)
145 | 
146 |     train_files = [str(file) for file in (pathlib.Path('data/samples')/problem_folder/'train').glob('sample_*.pkl')]
147 |     pretrain_files = [f for i, f in enumerate(train_files) if i % 10 == 0]
148 |     valid_files = [str(file) for file in (pathlib.Path('data/samples')/problem_folder/'valid').glob('sample_*.pkl')]
149 | 
150 |     pretrain_data = GraphDataset(pretrain_files)
151 |     pretrain_loader = torch_geometric.loader.DataLoader(pretrain_data, pretrain_batch_size, shuffle=False)
152 |     valid_data = GraphDataset(valid_files)
153 |     valid_loader = torch_geometric.loader.DataLoader(valid_data, valid_batch_size, shuffle=False)
154 | 
155 |     for epoch in range(max_epochs + 1):
156 |         log(f"EPOCH {epoch}...", logfile)
157 |         if epoch == 0:
158 |             n = pretrain(policy, pretrain_loader)
159 |             log(f"PRETRAINED {n} LAYERS", logfile)
160 |         else:
161 |             epoch_train_files = rng.choice(train_files, int(np.floor(10000/batch_size))*batch_size, replace=True)
162 |             train_data = GraphDataset(epoch_train_files)
163 |             train_loader = torch_geometric.loader.DataLoader(train_data, batch_size, shuffle=True)
164 |             train_loss, train_kacc, entropy = process(policy, train_loader, top_k, optimizer)
165 |             log(f"TRAIN LOSS: {train_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, train_kacc)]), logfile)
166 | 
167 |         # VALIDATION
168 |         valid_loss, valid_kacc, entropy = process(policy, valid_loader, top_k, None)
169 |         log(f"VALID LOSS: {valid_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, valid_kacc)]), logfile)
170 | 
171 |         scheduler.step(valid_loss)
172 |         if scheduler.num_bad_epochs == 0:
173 |             torch.save(policy.state_dict(), pathlib.Path(running_dir)/'train_params.pkl')
174 |             log(f" best model so far", logfile)
175 |         elif scheduler.num_bad_epochs == 10:
176 |             log(f" 10 epochs without improvement, decreasing learning rate", logfile)
177 |         elif scheduler.num_bad_epochs == 20:
178 |             log(f" 20 epochs without improvement, early stopping", logfile)
179 |             break
180 | 
181 |     policy.load_state_dict(torch.load(pathlib.Path(running_dir)/'train_params.pkl'))
182 |     valid_loss, valid_kacc, entropy = process(policy, valid_loader, top_k, None)
183 |     log(f"BEST VALID LOSS: {valid_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, valid_kacc)]), logfile)
184 | 
--------------------------------------------------------------------------------
/04_evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import importlib
4 | import argparse
5 | import csv
6 | import numpy as np
7 | import time
8 | import pickle
9 | 
10 | import ecole
11 | import pyscipopt
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     parser = argparse.ArgumentParser()
16 |     parser.add_argument(
17 |         'problem',
18 |         help='MILP instance type to process.',
19 |         choices=['setcover', 'cauctions', 'facilities', 'indset'],
20 |     )
21 |     parser.add_argument(
22 |         '-g', '--gpu',
23 |         help='CUDA GPU id (-1 for CPU).',
24 |         type=int,
25 |         default=0,
26 |     )
27 |     args = parser.parse_args()
28 | 
29 |     result_file = f"{args.problem}_{time.strftime('%Y%m%d-%H%M%S')}.csv"
30 |     instances = []
31 |     seeds = [0, 1, 2, 3, 4]
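    # Every policy below is crossed with the five seeds above: 'relpscost' is
    # SCIP's default reliability pseudocost brancher (promoted via its priority
    # parameter further down), and 'supervised' is the GNN trained by
    # 03_train_gnn.py, loaded from model/<problem>/<seed>/train_params.pkl.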
32 |     internal_branchers = ['relpscost']
33 |     gnn_models = ['supervised']  # only the supervised (imitation-learned) GNN is implemented
34 |     time_limit = 3600
35 | 
36 |     if args.problem == 'setcover':
37 |         instances += [{'type': 'small', 'path': f"data/instances/setcover/transfer_500r_1000c_0.05d/instance_{i+1}.lp"} for i in range(20)]
38 |         instances += [{'type': 'medium', 'path': f"data/instances/setcover/transfer_1000r_1000c_0.05d/instance_{i+1}.lp"} for i in range(20)]
39 |         instances += [{'type': 'big', 'path': f"data/instances/setcover/transfer_2000r_1000c_0.05d/instance_{i+1}.lp"} for i in range(20)]
40 | 
41 |     elif args.problem == 'cauctions':
42 |         instances += [{'type': 'small', 'path': f"data/instances/cauctions/transfer_100_500/instance_{i+1}.lp"} for i in range(20)]
43 |         instances += [{'type': 'medium', 'path': f"data/instances/cauctions/transfer_200_1000/instance_{i+1}.lp"} for i in range(20)]
44 |         instances += [{'type': 'big', 'path': f"data/instances/cauctions/transfer_300_1500/instance_{i+1}.lp"} for i in range(20)]
45 | 
46 |     elif args.problem == 'facilities':
47 |         instances += [{'type': 'small', 'path': f"data/instances/facilities/transfer_100_100_5/instance_{i+1}.lp"} for i in range(20)]
48 |         instances += [{'type': 'medium', 'path': f"data/instances/facilities/transfer_200_100_5/instance_{i+1}.lp"} for i in range(20)]
49 |         instances += [{'type': 'big', 'path': f"data/instances/facilities/transfer_400_100_5/instance_{i+1}.lp"} for i in range(20)]
50 | 
51 |     elif args.problem == 'indset':
52 |         instances += [{'type': 'small', 'path': f"data/instances/indset/transfer_500_4/instance_{i+1}.lp"} for i in range(20)]
53 |         instances += [{'type': 'medium', 'path': f"data/instances/indset/transfer_1000_4/instance_{i+1}.lp"} for i in range(20)]
54 |         instances += [{'type': 'big', 'path': f"data/instances/indset/transfer_1500_4/instance_{i+1}.lp"} for i in range(20)]
55 | 
56 |     else:
57 |         raise NotImplementedError
58 | 
59 |     branching_policies = []
60 | 
61 |     # SCIP internal brancher baselines
62 |     for brancher in internal_branchers:
63 |         for seed in seeds:
64 |             branching_policies.append({
65 |                 'type': 'internal',
66 |                 'name': brancher,
67 |                 'seed': seed,
68 |             })
69 |     # GNN models
70 |     for model in gnn_models:
71 |         for seed in seeds:
72 |             branching_policies.append({
73 |                 'type': 'gnn',
74 |                 'name': model,
75 |                 'seed': seed,
76 |             })
77 | 
78 |     print(f"problem: {args.problem}")
79 |     print(f"gpu: {args.gpu}")
80 |     print(f"time limit: {time_limit} s")
81 | 
82 |     ### PYTORCH SETUP ###
83 |     if args.gpu == -1:
84 |         os.environ['CUDA_VISIBLE_DEVICES'] = ''
85 |         device = 'cpu'
86 |     else:
87 |         os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpu}'
88 |         device = "cuda:0"
89 | 
90 |     import torch
91 |     from model.model import GNNPolicy
92 | 
93 |     # load the trained pytorch models and assign them to the policies (one model per (name, seed) pair, shared across instances)
94 |     loaded_models = {}
95 |     loaded_calls = {}
96 |     for policy in branching_policies:
97 |         if policy['type'] == 'gnn':
98 |             if (policy['name'], policy['seed']) not in loaded_models:
99 |                 ### MODEL LOADING ###  (parameters differ per seed, so cache per name and seed)
100 |                 model = GNNPolicy().to(device)
101 |                 if policy['name'] == 'supervised':
102 |                     model.load_state_dict(torch.load(f"model/{args.problem}/{policy['seed']}/train_params.pkl"))
103 |                 else:
104 |                     raise Exception(f"Unrecognized GNN policy {policy['name']}")
105 |                 loaded_models[(policy['name'], policy['seed'])] = model
106 | 
107 |             policy['model'] = loaded_models[(policy['name'], policy['seed'])]
108 | 
109 |     print("running SCIP...")
110 | 
111 |     fieldnames = [
112 |         'policy',
113 |         'seed',
114 |         'type',
115 |         'instance',
116 |         'nnodes',
117 |         'nlps',
118 |         'stime',
119 |         'gap',
120 | 
'status', 121 | 'walltime', 122 | 'proctime', 123 | ] 124 | os.makedirs('results', exist_ok=True) 125 | scip_parameters = {'separating/maxrounds': 0, 'presolving/maxrestarts': 0, 'limits/time': time_limit, 126 | 'timing/clocktype': 1, 'branching/vanillafullstrong/idempotent': True} 127 | 128 | with open(f"results/{result_file}", 'w', newline='') as csvfile: 129 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 130 | writer.writeheader() 131 | 132 | for instance in instances: 133 | print(f"{instance['type']}: {instance['path']}...") 134 | 135 | for policy in branching_policies: 136 | if policy['type'] == 'internal': 137 | # Run SCIP's default brancher 138 | env = ecole.environment.Configuring(scip_params={**scip_parameters, 139 | f"branching/{policy['name']}/priority": 9999999}) 140 | env.seed(policy['seed']) 141 | 142 | walltime = time.perf_counter() 143 | proctime = time.process_time() 144 | 145 | env.reset(instance['path']) 146 | _, _, _, _, _ = env.step({}) 147 | 148 | walltime = time.perf_counter() - walltime 149 | proctime = time.process_time() - proctime 150 | 151 | elif policy['type'] == 'gnn': 152 | # Run the GNN policy 153 | env = ecole.environment.Branching(observation_function=ecole.observation.NodeBipartite(), 154 | scip_params=scip_parameters) 155 | env.seed(policy['seed']) 156 | torch.manual_seed(policy['seed']) 157 | 158 | walltime = time.perf_counter() 159 | proctime = time.process_time() 160 | 161 | observation, action_set, _, done, _ = env.reset(instance['path']) 162 | while not done: 163 | with torch.no_grad(): 164 | observation = (torch.from_numpy(observation.row_features.astype(np.float32)).to(device), 165 | torch.from_numpy(observation.edge_features.indices.astype(np.int64)).to(device), 166 | torch.from_numpy(observation.edge_features.values.astype(np.float32)).view(-1, 1).to(device), 167 | torch.from_numpy(observation.variable_features.astype(np.float32)).to(device)) 168 | 169 | logits = policy['model'](*observation) 170 | action = action_set[logits[action_set.astype(np.int64)].argmax()] 171 | observation, action_set, _, done, _ = env.step(action) 172 | 173 | walltime = time.perf_counter() - walltime 174 | proctime = time.process_time() - proctime 175 | 176 | scip_model = env.model.as_pyscipopt() 177 | stime = scip_model.getSolvingTime() 178 | nnodes = scip_model.getNNodes() 179 | nlps = scip_model.getNLPs() 180 | gap = scip_model.getGap() 181 | status = scip_model.getStatus() 182 | 183 | writer.writerow({ 184 | 'policy': f"{policy['type']}:{policy['name']}", 185 | 'seed': policy['seed'], 186 | 'type': instance['type'], 187 | 'instance': instance['path'], 188 | 'nnodes': nnodes, 189 | 'nlps': nlps, 190 | 'stime': stime, 191 | 'gap': gap, 192 | 'status': status, 193 | 'walltime': walltime, 194 | 'proctime': proctime, 195 | }) 196 | csvfile.flush() 197 | 198 | print(f" {policy['type']}:{policy['name']} {policy['seed']} - {nnodes} nodes {nlps} lps {stime:.2f} ({walltime:.2f} wall {proctime:.2f} proc) s. 
{status}")
199 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 CERC Data Science For Decision Making
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Exact Combinatorial Optimization with Graph Convolutional Neural Networks (Ecole+Pytorch+Pytorch Geometric reimplementation)
2 | 
3 | This is the official reimplementation of the GNN model proposed in the [NeurIPS 2019 paper "Exact Combinatorial Optimization with Graph Convolutional Neural Networks"](https://arxiv.org/abs/1906.01629), using the [Ecole library](https://github.com/ds4dm/ecole). This reimplementation also makes use of [Pytorch](https://github.com/pytorch/pytorch) instead of Tensorflow, and of [Pytorch Geometric](https://github.com/pyg-team/pytorch_geometric) for handling the GNN. As a consequence, much of the code is now simplified. Slight discrepancies in results from the original implementation are to be expected.
4 | 
5 | As mentioned, this repo only implements the GNN model. For comparisons with the other ML competitors (ExtraTrees, LambdaMART and SVMRank), please see the original implementation [here](https://github.com/ds4dm/learn2branch).
6 | 
7 | 
8 | 
9 | 
10 | 
11 | 
12 | 
13 | 
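The overall pipeline is: generate MILP instances (`01_generate_instances.py`), collect strong branching samples with Ecole (`02_generate_dataset.py`), train the GNN by imitation learning (`03_train_gnn.py`), and run the trained model as a branching policy inside SCIP (`04_evaluate.py`). The following is a condensed sketch of that last step, adapted from `04_evaluate.py`; the model path assumes a `setcover` model trained with seed 0, and `path/to/instance.lp` is a placeholder:

```
import ecole
import numpy as np
import torch

from model.model import GNNPolicy

# load one trained model, as saved by 03_train_gnn.py
policy = GNNPolicy()
policy.load_state_dict(torch.load("model/setcover/0/train_params.pkl", map_location="cpu"))

# let the GNN take the branching decisions on a single instance
env = ecole.environment.Branching(observation_function=ecole.observation.NodeBipartite())
observation, action_set, _, done, _ = env.reset("path/to/instance.lp")
while not done:
    with torch.no_grad():
        logits = policy(torch.from_numpy(observation.row_features.astype(np.float32)),
                        torch.from_numpy(observation.edge_features.indices.astype(np.int64)),
                        torch.from_numpy(observation.edge_features.values.astype(np.float32)).view(-1, 1),
                        torch.from_numpy(observation.variable_features.astype(np.float32)))
    # branch on the candidate with the highest predicted strong branching score
    action = action_set[logits[action_set.astype(np.int64)].argmax()]
    observation, action_set, _, done, _ = env.step(action)
```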
14 | 
15 | ## Authors
16 | 
17 | Maxime Gasse, Didier Chételat, Nicola Ferroni, Laurent Charlin and Andrea Lodi.
18 | 
19 | ## Installation
20 | 
21 | Our recommended installation uses the [Conda package manager](https://docs.conda.io/en/latest/miniconda.html). The previous implementation required you to compile a patched version of SCIP and PySCIPOpt using Cython. This is no longer required, as Conda packages are now available and are installed as dependencies of the Ecole conda package itself.
22 | 
23 | __Instructions:__ Install Ecole, Pytorch and Pytorch Geometric using conda. At the time of writing these installation instructions, this can be accomplished by running:
24 | 
25 | ```
26 | conda install ecole
27 | conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
28 | conda install pyg -c pyg -c conda-forge
29 | ```
30 | 
31 | Please refer to the most up-to-date installation instructions for [Ecole](https://github.com/ds4dm/ecole#installation), [Pytorch](https://pytorch.org/get-started/locally) and [Pytorch Geometric](https://github.com/pyg-team/pytorch_geometric#installation) if you encounter any errors.
32 | 
33 | ## Benchmarks
34 | 
35 | For every benchmark in the paper, we give the commands for running the experiments and report the results next to those of the original implementation.
36 | 
37 | ### Set Covering
38 | 
39 | ```
40 | # Generate MILP instances
41 | python 01_generate_instances.py setcover
42 | # Generate supervised learning datasets
43 | python 02_generate_dataset.py setcover -j 4  # number of available CPUs
44 | # Training
45 | for i in {0..4}
46 | do
47 |     python 03_train_gnn.py setcover -s $i
48 | done
49 | # Evaluation
50 | python 04_evaluate.py setcover
51 | ```
52 | 
<table>
  <thead>
    <tr>
      <th rowspan="2"></th>
      <th colspan="2">Easy</th>
      <th colspan="2">Medium</th>
      <th colspan="2">Hard</th>
    </tr>
    <tr>
      <th>Time</th><th>Nodes</th>
      <th>Time</th><th>Nodes</th>
      <th>Time</th><th>Nodes</th>
    </tr>
  </thead>
  <tbody>
    <tr><td>SCIP default</td><td></td><td></td><td></td><td></td><td></td><td></td></tr>
    <tr><td>GNN (original)</td><td></td><td></td><td></td><td></td><td></td><td></td></tr>
    <tr><td>GNN (reimplementation)</td><td></td><td></td><td></td><td></td><td></td><td></td></tr>
  </tbody>
</table>
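The empty cells above can be filled from the CSV files that `04_evaluate.py` writes to `results/` (one row per policy, seed and instance). A minimal aggregation sketch, assuming the 1-shifted geometric means reported in the paper and a hypothetical results file name; the `small`/`medium`/`big` instance types correspond to the Easy/Medium/Hard columns:

```
import csv
import numpy as np
from collections import defaultdict

def shifted_geometric_mean(values, shift=1.0):
    # exp(mean(log(x + shift))) - shift
    values = np.asarray(values, dtype=float)
    return float(np.exp(np.log(values + shift).mean()) - shift)

stats = defaultdict(lambda: {'stime': [], 'nnodes': []})
with open('results/setcover_20220101-000000.csv', newline='') as f:  # hypothetical file name
    for row in csv.DictReader(f):
        key = (row['policy'], row['type'])  # e.g. ('gnn:supervised', 'small')
        stats[key]['stime'].append(float(row['stime']))
        stats[key]['nnodes'].append(float(row['nnodes']))

for (policy, size), values in sorted(stats.items()):
    print(f"{policy:20s} {size:8s}"
          f" time {shifted_geometric_mean(values['stime']):8.2f}s"
          f" nodes {shifted_geometric_mean(values['nnodes']):10.1f}")
```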
97 | 98 | ### Combinatorial Auction 99 | ``` 100 | # Generate MILP instances 101 | python 01_generate_instances.py cauctions 102 | # Generate supervised learning datasets 103 | python 02_generate_dataset.py cauctions -j 4 # number of available CPUs 104 | # Training 105 | for i in {0..4} 106 | do 107 | python 03_train_gnn.py cauctions -s $i 108 | done 109 | # Evaluation 110 | python 04_evaluate.py cauctions 111 | ``` 112 | 113 | ### Capacitated Facility Location 114 | ``` 115 | # Generate MILP instances 116 | python 01_generate_instances.py facilities 117 | # Generate supervised learning datasets 118 | python 02_generate_dataset.py facilities -j 4 # number of available CPUs 119 | # Training 120 | for i in {0..4} 121 | do 122 | python 03_train_gnn.py facilities -s $i 123 | done 124 | # Evaluation 125 | python 04_evaluate.py facilities 126 | ``` 127 | 128 | ### Maximum Independent Set 129 | ``` 130 | # Generate MILP instances 131 | python 01_generate_instances.py indset 132 | # Generate supervised learning datasets 133 | python 02_generate_dataset.py indset -j 4 # number of available CPUs 134 | # Training 135 | for i in {0..4} 136 | do 137 | python 03_train_gnn.py indset -s $i 138 | done 139 | # Evaluation 140 | python 04_evaluate.py indset 141 | ``` 142 | 143 | ## Citation 144 | Please cite our paper if you use this code in your work. 145 | ``` 146 | @inproceedings{conf/nips/GasseCFCL19, 147 | title={Exact Combinatorial Optimization with Graph Convolutional Neural Networks}, 148 | author={Gasse, Maxime and Chételat, Didier and Ferroni, Nicola and Charlin, Laurent and Lodi, Andrea}, 149 | booktitle={Advances in Neural Information Processing Systems 32}, 150 | year={2019} 151 | } 152 | ``` 153 | 154 | ## Questions / Bugs 155 | Please feel free to submit a Github issue if you have any questions or find any bugs. We do not guarantee any support, but will do our best if we can help. 156 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch_geometric 4 | import numpy as np 5 | 6 | 7 | class PreNormException(Exception): 8 | pass 9 | 10 | 11 | class PreNormLayer(torch.nn.Module): 12 | def __init__(self, n_units, shift=True, scale=True, name=None): 13 | super().__init__() 14 | assert shift or scale 15 | self.register_buffer('shift', torch.zeros(n_units) if shift else None) 16 | self.register_buffer('scale', torch.ones(n_units) if scale else None) 17 | self.n_units = n_units 18 | self.waiting_updates = False 19 | self.received_updates = False 20 | 21 | def forward(self, input_): 22 | if self.waiting_updates: 23 | self.update_stats(input_) 24 | self.received_updates = True 25 | raise PreNormException 26 | 27 | if self.shift is not None: 28 | input_ = input_ + self.shift 29 | 30 | if self.scale is not None: 31 | input_ = input_ * self.scale 32 | 33 | return input_ 34 | 35 | def start_updates(self): 36 | self.avg = 0 37 | self.var = 0 38 | self.m2 = 0 39 | self.count = 0 40 | self.waiting_updates = True 41 | self.received_updates = False 42 | 43 | def update_stats(self, input_): 44 | """ 45 | Online mean and variance estimation. See: Chan et al. (1979) Updating 46 | Formulae and a Pairwise Algorithm for Computing Sample Variances. 
47 |         https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
48 |         """
49 |         assert self.n_units == 1 or input_.shape[-1] == self.n_units, f"Expected input dimension of size {self.n_units}, got {input_.shape[-1]}."
50 | 
51 |         input_ = input_.reshape(-1, self.n_units)
52 |         sample_avg = input_.mean(dim=0)
53 |         sample_var = (input_ - sample_avg).pow(2).mean(dim=0)
54 |         sample_count = np.prod(input_.size())/self.n_units
55 | 
56 |         delta = sample_avg - self.avg
57 | 
58 |         self.m2 = self.var * self.count + sample_var * sample_count + delta ** 2 * self.count * sample_count / (
59 |             self.count + sample_count)
60 | 
61 |         self.count += sample_count
62 |         self.avg += delta * sample_count / self.count
63 |         self.var = self.m2 / self.count if self.count > 0 else 1
64 | 
65 |     def stop_updates(self):
66 |         """
67 |         Ends pre-training for that layer, and fixes the layer's parameters.
68 |         """
69 |         assert self.count > 0
70 |         if self.shift is not None:
71 |             self.shift = -self.avg
72 | 
73 |         if self.scale is not None:
74 |             self.var[self.var < 1e-8] = 1
75 |             self.scale = 1 / torch.sqrt(self.var)
76 | 
77 |         del self.avg, self.var, self.m2, self.count
78 |         self.waiting_updates = False
79 |         self.trainable = False
80 | 
81 | 
82 | 
83 | class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):
84 |     def __init__(self):
85 |         super().__init__('add')
86 |         emb_size = 64
87 | 
88 |         self.feature_module_left = torch.nn.Sequential(
89 |             torch.nn.Linear(emb_size, emb_size)
90 |         )
91 |         self.feature_module_edge = torch.nn.Sequential(
92 |             torch.nn.Linear(1, emb_size, bias=False)
93 |         )
94 |         self.feature_module_right = torch.nn.Sequential(
95 |             torch.nn.Linear(emb_size, emb_size, bias=False)
96 |         )
97 |         self.feature_module_final = torch.nn.Sequential(
98 |             PreNormLayer(1, shift=False),
99 |             torch.nn.ReLU(),
100 |             torch.nn.Linear(emb_size, emb_size)
101 |         )
102 | 
103 |         self.post_conv_module = torch.nn.Sequential(
104 |             PreNormLayer(1, shift=False)
105 |         )
106 | 
107 |         # output_layers
108 |         self.output_module = torch.nn.Sequential(
109 |             torch.nn.Linear(2*emb_size, emb_size),
110 |             torch.nn.ReLU(),
111 |             torch.nn.Linear(emb_size, emb_size),
112 |         )
113 | 
114 |     def forward(self, left_features, edge_indices, edge_features, right_features):
115 |         output = self.propagate(edge_indices, size=(left_features.shape[0], right_features.shape[0]),
116 |                                 node_features=(left_features, right_features), edge_features=edge_features)
117 |         return self.output_module(torch.cat([self.post_conv_module(output), right_features], dim=-1))
118 | 
119 |     def message(self, node_features_i, node_features_j, edge_features):
120 |         output = self.feature_module_final(self.feature_module_left(node_features_i)
121 |                                            + self.feature_module_edge(edge_features)
122 |                                            + self.feature_module_right(node_features_j))
123 |         return output
124 | 
125 | 
126 | class BaseModel(torch.nn.Module):
127 |     """
128 |     Our base model class, which implements pre-training methods.
129 | """ 130 | 131 | def pre_train_init(self): 132 | for module in self.modules(): 133 | if isinstance(module, PreNormLayer): 134 | module.start_updates() 135 | 136 | def pre_train_next(self): 137 | for module in self.modules(): 138 | if isinstance(module, PreNormLayer) and module.waiting_updates and module.received_updates: 139 | module.stop_updates() 140 | return module 141 | return None 142 | 143 | def pre_train(self, *args, **kwargs): 144 | try: 145 | with torch.no_grad(): 146 | self.forward(*args, **kwargs) 147 | return False 148 | except PreNormException: 149 | return True 150 | 151 | 152 | class GNNPolicy(BaseModel): 153 | def __init__(self): 154 | super().__init__() 155 | emb_size = 64 156 | cons_nfeats = 5 157 | edge_nfeats = 1 158 | var_nfeats = 19 159 | 160 | # CONSTRAINT EMBEDDING 161 | self.cons_embedding = torch.nn.Sequential( 162 | PreNormLayer(cons_nfeats), 163 | torch.nn.Linear(cons_nfeats, emb_size), 164 | torch.nn.ReLU(), 165 | torch.nn.Linear(emb_size, emb_size), 166 | torch.nn.ReLU(), 167 | ) 168 | 169 | # EDGE EMBEDDING 170 | self.edge_embedding = torch.nn.Sequential( 171 | PreNormLayer(edge_nfeats), 172 | ) 173 | 174 | # VARIABLE EMBEDDING 175 | self.var_embedding = torch.nn.Sequential( 176 | PreNormLayer(var_nfeats), 177 | torch.nn.Linear(var_nfeats, emb_size), 178 | torch.nn.ReLU(), 179 | torch.nn.Linear(emb_size, emb_size), 180 | torch.nn.ReLU(), 181 | ) 182 | 183 | self.conv_v_to_c = BipartiteGraphConvolution() 184 | self.conv_c_to_v = BipartiteGraphConvolution() 185 | 186 | self.output_module = torch.nn.Sequential( 187 | torch.nn.Linear(emb_size, emb_size), 188 | torch.nn.ReLU(), 189 | torch.nn.Linear(emb_size, 1, bias=False), 190 | ) 191 | 192 | def forward(self, constraint_features, edge_indices, edge_features, variable_features): 193 | reversed_edge_indices = torch.stack([edge_indices[1], edge_indices[0]], dim=0) 194 | 195 | constraint_features = self.cons_embedding(constraint_features) 196 | edge_features = self.edge_embedding(edge_features) 197 | variable_features = self.var_embedding(variable_features) 198 | 199 | constraint_features = self.conv_v_to_c(variable_features, reversed_edge_indices, edge_features, constraint_features) 200 | variable_features = self.conv_c_to_v(constraint_features, edge_indices, edge_features, variable_features) 201 | 202 | output = self.output_module(variable_features).squeeze(-1) 203 | return output 204 | -------------------------------------------------------------------------------- /utilities.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import pickle 3 | import datetime 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import torch_geometric 9 | 10 | def log(str, logfile=None): 11 | str = f'[{datetime.datetime.now()}] {str}' 12 | print(str) 13 | if logfile is not None: 14 | with open(logfile, mode='a') as f: 15 | print(str, file=f) 16 | 17 | 18 | def pad_tensor(input_, pad_sizes, pad_value=-1e8): 19 | max_pad_size = pad_sizes.max() 20 | output = input_.split(pad_sizes.cpu().numpy().tolist()) 21 | output = torch.stack([F.pad(slice_, (0, max_pad_size-slice_.size(0)), 'constant', pad_value) 22 | for slice_ in output], dim=0) 23 | return output 24 | 25 | 26 | class BipartiteNodeData(torch_geometric.data.Data): 27 | def __init__(self, constraint_features, edge_indices, edge_features, variable_features, 28 | candidates, nb_candidates, candidate_choice, candidate_scores): 29 | super().__init__() 30 | self.constraint_features = 
constraint_features
31 |         self.edge_index = edge_indices
32 |         self.edge_attr = edge_features
33 |         self.variable_features = variable_features
34 |         self.candidates = candidates
35 |         self.nb_candidates = nb_candidates
36 |         self.candidate_choices = candidate_choice
37 |         self.candidate_scores = candidate_scores
38 | 
39 |     def __inc__(self, key, value, store, *args, **kwargs):
40 |         if key == 'edge_index':
41 |             return torch.tensor([[self.constraint_features.size(0)], [self.variable_features.size(0)]])
42 |         elif key == 'candidates':
43 |             return self.variable_features.size(0)
44 |         else:
45 |             return super().__inc__(key, value, *args, **kwargs)
46 | 
47 | 
48 | class GraphDataset(torch_geometric.data.Dataset):
49 |     def __init__(self, sample_files):
50 |         super().__init__(root=None, transform=None, pre_transform=None)
51 |         self.sample_files = sample_files
52 | 
53 |     def len(self):
54 |         return len(self.sample_files)
55 | 
56 |     def get(self, index):
57 |         with gzip.open(self.sample_files[index], 'rb') as f:
58 |             sample = pickle.load(f)
59 | 
60 |         sample_observation, sample_action, sample_action_set, sample_scores = sample['data']
61 | 
62 |         constraint_features, (edge_indices, edge_features), variable_features = sample_observation
63 |         constraint_features = torch.FloatTensor(constraint_features)
64 |         edge_indices = torch.LongTensor(edge_indices.astype(np.int32))
65 |         edge_features = torch.FloatTensor(np.expand_dims(edge_features, axis=-1))
66 |         variable_features = torch.FloatTensor(variable_features)
67 | 
68 |         candidates = torch.LongTensor(np.array(sample_action_set, dtype=np.int32))
69 |         candidate_choice = torch.where(candidates == sample_action)[0][0]  # action index relative to candidates
70 |         candidate_scores = torch.FloatTensor([sample_scores[j] for j in candidates])
71 | 
72 |         graph = BipartiteNodeData(constraint_features, edge_indices, edge_features, variable_features,
73 |                                   candidates, len(candidates), candidate_choice, candidate_scores)
74 |         graph.num_nodes = constraint_features.shape[0]+variable_features.shape[0]
75 |         return graph
76 | 
77 | 
78 | class Scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
79 |     def __init__(self, optimizer, **kwargs):
80 |         super().__init__(optimizer, **kwargs)
81 | 
82 |     def step(self, metrics):  # custom step: num_bad_epochs is never reset, so 03_train_gnn.py can early-stop at 20
83 |         # convert `metrics` to float, in case it's a zero-dim Tensor
84 |         current = float(metrics)
85 |         self.last_epoch += 1
86 | 
87 |         if self.is_better(current, self.best):
88 |             self.best = current
89 |             self.num_bad_epochs = 0
90 |         else:
91 |             self.num_bad_epochs += 1
92 | 
93 |         if self.num_bad_epochs == self.patience:  # reduce the lr exactly once, when patience is first exceeded
94 |             self._reduce_lr(self.last_epoch)
95 | 
96 |         self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
97 | 
--------------------------------------------------------------------------------
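As a final illustration (a standalone sketch, not part of the repository), here is how `pad_tensor` from `utilities.py` behaves. `03_train_gnn.py` uses it to turn the flat, concatenated candidate logits of a batch back into a rectangular tensor, padding with a large negative value so that padded positions never win an argmax and receive essentially zero softmax mass:

```
import torch
import torch.nn.functional as F

from utilities import pad_tensor

# a batch of two samples: 3 branching candidates in the first, 2 in the second
logits = torch.tensor([0.5, 1.2, -0.3, 2.0, 0.1])
nb_candidates = torch.tensor([3, 2])

padded = pad_tensor(logits, nb_candidates)
# -> [[0.5, 1.2, -0.3],
#     [2.0, 0.1, -1e8]]   (pad_value defaults to -1e8)

# padded entries get ~zero probability, so per-sample cross-entropy is unaffected
probabilities = F.softmax(padded, dim=-1)
```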