├── .gitignore
├── 02_generate_dataset.py
├── 03_train_gnn.py
├── 04_evaluate.py
├── LICENSE
├── README.md
├── model
└── model.py
└── utilities.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/02_generate_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import gzip
4 | import argparse
5 | import pickle
6 | import queue
7 | import shutil
8 | import threading
9 | import numpy as np
10 | import ecole
11 | from collections import namedtuple
12 |
13 |
class ExploreThenStrongBranch:
    """
    Observation function that randomly alternates between two score sources.

    On every extraction, strong branching scores are returned with
    probability `expert_probability`; pseudocost scores are returned
    otherwise. The draw uses numpy's global random state.
    """

    def __init__(self, expert_probability):
        self.expert_probability = expert_probability
        self.pseudocosts_function = ecole.observation.Pseudocosts()
        self.strong_branching_function = ecole.observation.StrongBranchingScores()

    def before_reset(self, model):
        """Forward the reset notification to both wrapped observation functions."""
        for wrapped in (self.pseudocosts_function, self.strong_branching_function):
            wrapped.before_reset(model)

    def extract(self, model, done):
        """
        Return a pair ``(scores, scores_are_expert)``.

        `scores_are_expert` is True when the scores come from the strong
        branching function and False when they come from the pseudocosts
        function.
        """
        weights = [1-self.expert_probability, self.expert_probability]
        draw = np.random.choice(np.arange(2), p=weights)
        if not bool(draw):
            return (self.pseudocosts_function.extract(model, done), False)
        return (self.strong_branching_function.extract(model, done), True)
31 |
32 |
def send_orders(orders_queue, instances, seed, query_expert_prob, time_limit, out_dir, stop_flag):
    """
    Dispatcher loop: keep pushing sampling orders to the workers until told
    to stop. Throttling relies on `orders_queue` having a bounded capacity,
    which makes `put` block while the queue is full.

    Parameters
    ----------
    orders_queue : queue.Queue
        Queue receiving the orders.
    instances : list
        Instance file names from which episodes are sampled.
    seed : int
        Seed of the random generator used to pick instances and episode seeds.
    query_expert_prob : float in [0, 1]
        Probability of running the expert strategy and collecting samples.
    time_limit : float in [0, 1e+20]
        Maximum running time for an episode, in seconds.
    out_dir: str
        Output directory in which to write samples.
    stop_flag: threading.Event
        A flag to tell the thread to stop.
    """
    rng = np.random.RandomState(seed)
    episode = 0

    while True:
        if stop_flag.is_set():
            return

        # each order is [episode, instance, seed, query_expert_prob, time_limit, out_dir]
        chosen_instance = rng.choice(instances)
        episode_seed = rng.randint(2**32)
        order = [episode, chosen_instance, episode_seed, query_expert_prob, time_limit, out_dir]
        orders_queue.put(order)
        episode += 1
63 |
64 |
def _log_sampling_error(instance, seed, error):
    """Append a record of a failed solve of `instance` to ``error_log.txt``."""
    with open("error_log.txt", "a") as f:
        f.write(f"Error occurred solving {instance} with seed {seed}\n")
        f.write(f"{error}\n")


def make_samples(in_queue, out_queue, stop_flag, poll_interval=0.5):
    """
    Worker loop: fetch an order, run an episode and record samples.

    For every order, a 'start' message is sent to `out_queue`, followed by
    one 'sample' message per sample file written to disk, and finally a
    'done' message. Solver errors (on reset or on step) end the episode and
    are appended to ``error_log.txt``; the 'done' message is still sent so
    the consumer never waits on a failed episode.

    Parameters
    ----------
    in_queue : queue.Queue
        Input queue from which orders are received.
    out_queue : queue.Queue
        Output queue in which to send samples.
    stop_flag: threading.Event
        A flag to tell the thread to stop.
    poll_interval : float
        Maximum time, in seconds, to wait for an order before checking
        `stop_flag` again.
    """
    sample_counter = 0
    while not stop_flag.is_set():
        # Poll instead of blocking forever: once the dispatcher has stopped
        # the queue may stay empty, and the worker must still notice the
        # stop flag so that joining it does not hang.
        try:
            order = in_queue.get(timeout=poll_interval)
        except queue.Empty:
            continue
        episode, instance, seed, query_expert_prob, time_limit, out_dir = order

        scip_parameters = {'separating/maxrounds': 0, 'presolving/maxrestarts': 0,
                           'limits/time': time_limit, 'timing/clocktype': 2}
        observation_function = {"scores": ExploreThenStrongBranch(expert_probability=query_expert_prob),
                                "node_observation": ecole.observation.NodeBipartite()}
        env = ecole.environment.Branching(observation_function=observation_function,
                                          scip_params=scip_parameters, pseudo_candidates=True)

        # the trailing "\n" with end='' keeps each message in a single write
        print(f"[w {threading.current_thread().name}] episode {episode}, seed {seed}, "
              f"processing instance '{instance}'...\n", end='')
        out_queue.put({
            'type': 'start',
            'episode': episode,
            'instance': instance,
            'seed': seed,
        })

        env.seed(seed)
        try:
            observation, action_set, _, done, _ = env.reset(instance)
        except Exception as e:
            # A failing reset must not kill the worker: the 'start' message
            # was already sent, so the matching 'done' has to follow.
            done = True
            _log_sampling_error(instance, seed, e)

        # Abandon the episode as soon as a stop is requested.
        while not done and not stop_flag.is_set():
            scores, scores_are_expert = observation["scores"]
            node_observation = observation["node_observation"]
            node_observation = (node_observation.row_features,
                                (node_observation.edge_features.indices,
                                 node_observation.edge_features.values),
                                node_observation.variable_features)

            # follow the candidate with the highest score
            action = action_set[scores[action_set].argmax()]

            if scores_are_expert and not stop_flag.is_set():
                data = [node_observation, action, action_set, scores]
                filename = f'{out_dir}/sample_{episode}_{sample_counter}.pkl'

                with gzip.open(filename, 'wb') as f:
                    pickle.dump({
                        'episode': episode,
                        'instance': instance,
                        'seed': seed,
                        'data': data,
                    }, f)
                out_queue.put({
                    'type': 'sample',
                    'episode': episode,
                    'instance': instance,
                    'seed': seed,
                    'filename': filename,
                })
                sample_counter += 1

            try:
                observation, action_set, _, done, _ = env.step(action)
            except Exception as e:
                done = True
                _log_sampling_error(instance, seed, e)

        # NOTE: sample_counter is cumulative over all episodes of this worker
        print(f"[w {threading.current_thread().name}] episode {episode} done, {sample_counter} samples\n", end='')
        out_queue.put({
            'type': 'done',
            'episode': episode,
            'instance': instance,
            'seed': seed,
        })
144 |
145 |
def collect_samples(instances, out_dir, rng, n_samples, n_jobs,
                    query_expert_prob, time_limit):
    """
    Runs branch-and-bound episodes on the given set of instances, and collects
    randomly (state, action) pairs from the 'vanilla-fullstrong' expert
    brancher.

    Samples are first written by the workers to ``{out_dir}/tmp`` and then
    renamed, in episode order, to ``{out_dir}/sample_{k}.pkl`` with k
    running from 1 to `n_samples`. The temporary directory is removed at
    the end.

    Parameters
    ----------
    instances : list
        Instance files from which to collect samples.
    out_dir : str
        Directory in which to write samples.
    rng : numpy.random.RandomState
        A random number generator for reproducibility.
    n_samples : int
        Number of samples to collect.
    n_jobs : int
        Number of jobs for parallel sampling.
    query_expert_prob : float in [0, 1]
        Probability of using the expert policy and recording a (state, action)
        pair.
    time_limit : float in [0, 1e+20]
        Maximum running time for an episode, in seconds.

    Raises
    ------
    ValueError
        If `instances` is empty.
    """
    # Without instances the dispatcher thread would die on its first draw
    # and this function would wait forever for answers.
    if len(instances) == 0:
        raise ValueError("no instances to sample from")

    os.makedirs(out_dir, exist_ok=True)

    # bounded orders queue: the dispatcher blocks when workers are busy
    orders_queue = queue.Queue(maxsize=2*n_jobs)
    answers_queue = queue.SimpleQueue()

    tmp_samples_dir = f'{out_dir}/tmp'
    os.makedirs(tmp_samples_dir, exist_ok=True)

    # start dispatcher
    dispatcher_stop_flag = threading.Event()
    dispatcher = threading.Thread(
        target=send_orders,
        args=(orders_queue, instances, rng.randint(2**32), query_expert_prob,
              time_limit, tmp_samples_dir, dispatcher_stop_flag),
        daemon=True)
    dispatcher.start()

    # start workers
    workers = []
    workers_stop_flag = threading.Event()
    for _ in range(n_jobs):
        p = threading.Thread(
            target=make_samples,
            args=(orders_queue, answers_queue, workers_stop_flag),
            daemon=True)
        workers.append(p)
        p.start()

    # record answers and write samples
    buffer = {}          # episode -> messages received but not yet handled
    current_episode = 0  # episode whose samples are currently being written
    i = 0                # number of samples written so far
    in_buffer = 0        # number of samples received but not yet written
    while i < n_samples:
        sample = answers_queue.get()

        # add received sample to buffer
        if sample['type'] == 'start':
            buffer[sample['episode']] = []
        else:
            buffer[sample['episode']].append(sample)
            if sample['type'] == 'sample':
                in_buffer += 1

        # if any, write samples from current episode
        while current_episode in buffer and buffer[current_episode]:
            samples_to_write = buffer[current_episode]
            buffer[current_episode] = []

            for sample in samples_to_write:

                # if no more samples here, move to next episode
                if sample['type'] == 'done':
                    del buffer[current_episode]
                    current_episode += 1

                # else write sample
                else:
                    os.rename(sample['filename'], f'{out_dir}/sample_{i+1}.pkl')
                    in_buffer -= 1
                    i += 1
                    print(f"[m {threading.current_thread().name}] {i} / {n_samples} samples written, "
                          f"ep {sample['episode']} ({in_buffer} in buffer).\n", end='')

                    # early stop dispatcher
                    if in_buffer + i >= n_samples and dispatcher.is_alive():
                        dispatcher_stop_flag.set()
                        print(f"[m {threading.current_thread().name}] dispatcher stopped...\n", end='')

                    # as soon as enough samples are collected, stop
                    # (emptying the buffer also ends the enclosing loops)
                    if i == n_samples:
                        buffer = {}
                        break

    # stop all workers
    workers_stop_flag.set()
    for p in workers:
        p.join()

    print(f"Done collecting samples for {out_dir}")
    shutil.rmtree(tmp_samples_dir, ignore_errors=True)
251 |
252 |
if __name__ == '__main__':
    # Command-line entry point: collect the train / valid / test sample sets
    # for one problem type.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        choices=['setcover', 'cauctions', 'facilities', 'indset', 'mknapsack'],
    )
    parser.add_argument(
        '-s', '--seed',
        help='Random generator seed.',
        type=int,
        default=0,
    )
    parser.add_argument(
        '-j', '--njobs',
        help='Number of parallel jobs.',
        type=int,
        default=1,
    )
    args = parser.parse_args()

    print(f"seed {args.seed}")

    train_size = 100000
    valid_size = 20000
    test_size = 20000
    node_record_prob = 0.05

    # problem -> (instance size tag, per-episode time limit in seconds)
    problem_settings = {
        'setcover': ('500r_1000c_0.05d', 3600),
        'cauctions': ('100_500', 3600),
        'indset': ('500_4', 3600),
        'facilities': ('100_100_5', 600),
        'mknapsack': ('100_6', 60),
    }
    if args.problem not in problem_settings:
        raise NotImplementedError
    size_tag, time_limit = problem_settings[args.problem]

    instances_train = glob.glob(f'data/instances/{args.problem}/train_{size_tag}/*.lp')
    instances_valid = glob.glob(f'data/instances/{args.problem}/valid_{size_tag}/*.lp')
    instances_test = glob.glob(f'data/instances/{args.problem}/test_{size_tag}/*.lp')
    out_dir = f'data/samples/{args.problem}/{size_tag}'

    print(f"{len(instances_train)} train instances for {train_size} samples")
    print(f"{len(instances_valid)} validation instances for {valid_size} samples")
    print(f"{len(instances_test)} test instances for {test_size} samples")

    # create output directory (an already existing directory is reused)
    os.makedirs(out_dir, exist_ok=True)

    # one independent random stream per data split
    rng = np.random.RandomState(args.seed)
    collect_samples(instances_train, out_dir + '/train', rng, train_size,
                    args.njobs, query_expert_prob=node_record_prob,
                    time_limit=time_limit)

    rng = np.random.RandomState(args.seed + 1)
    collect_samples(instances_valid, out_dir + '/valid', rng, valid_size,
                    args.njobs, query_expert_prob=node_record_prob,
                    time_limit=time_limit)

    rng = np.random.RandomState(args.seed + 2)
    collect_samples(instances_test, out_dir + '/test', rng, test_size,
                    args.njobs, query_expert_prob=node_record_prob,
                    time_limit=time_limit)
338 |
--------------------------------------------------------------------------------
/03_train_gnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import pathlib
5 | import numpy as np
6 |
7 |
def pretrain(policy, pretrain_loader):
    """
    Run the pre-training procedure of `policy` over `pretrain_loader`.

    After `policy.pre_train_init()`, the loader is traversed repeatedly.
    Each traversal feeds batches to `policy.pre_train(...)` until that call
    returns a falsy value or the loader is exhausted, then
    `policy.pre_train_next()` is called. The procedure ends when
    `pre_train_next()` returns None.

    Relies on the module-level `device` set by the script entry point.

    Returns
    -------
    int
        Number of times `pre_train_next()` returned a non-None value
        (reported by the caller as the number of pre-trained layers).
    """
    policy.pre_train_init()
    i = 0
    while True:
        for batch in pretrain_loader:
            # rebind the result, as `process` does, rather than relying on
            # `to` moving the batch in place
            batch = batch.to(device)
            if not policy.pre_train(batch.constraint_features, batch.edge_index, batch.edge_attr, batch.variable_features):
                break

        if policy.pre_train_next() is None:
            break
        i += 1
    return i
21 |
22 |
def process(policy, data_loader, top_k=(1, 3, 5, 10), optimizer=None):
    """
    Run `policy` over every batch of `data_loader`, optionally training it.

    When `optimizer` is given, gradients are enabled and one optimizer step
    is taken per batch on ``cross_entropy - entropy_bonus * entropy``;
    otherwise the pass is evaluation only. Relies on the module-level
    `device` and `entropy_bonus` set by the script entry point.

    Parameters
    ----------
    policy : callable
        Model called with the constraint features, edge index, edge
        attributes and variable features of a batch.
    data_loader : iterable
        Batches to process.
    top_k : sequence of int
        Values of k for which the top-k accuracy is computed.
    optimizer : optimizer or None
        Optimizer used for training, or None for evaluation.

    Returns
    -------
    tuple
        ``(mean_loss, mean_kacc, mean_entropy)``, each averaged over the
        processed graphs; `mean_loss` is the cross-entropy loss only and
        `mean_kacc` holds one entry per value in `top_k`.
    """
    mean_loss = 0
    mean_kacc = np.zeros(len(top_k))
    mean_entropy = 0

    n_samples_processed = 0
    with torch.set_grad_enabled(optimizer is not None):
        for batch in data_loader:
            batch = batch.to(device)
            logits = policy(batch.constraint_features, batch.edge_index, batch.edge_attr, batch.variable_features)
            # keep only the candidate variables, one padded row per graph
            logits = pad_tensor(logits[batch.candidates], batch.nb_candidates)
            cross_entropy_loss = F.cross_entropy(logits, batch.candidate_choices, reduction='mean')
            entropy = (-F.softmax(logits, dim=-1)*F.log_softmax(logits, dim=-1)).sum(-1).mean()
            loss = cross_entropy_loss - entropy_bonus*entropy

            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            true_scores = pad_tensor(batch.candidate_scores, batch.nb_candidates)
            true_bestscore = true_scores.max(dim=-1, keepdims=True).values

            kacc = []
            for k in top_k:
                # fewer candidates than k: counted as a hit
                if logits.size()[-1] < k:
                    kacc.append(1.0)
                    continue
                pred_top_k = logits.topk(k).indices
                pred_top_k_true_scores = true_scores.gather(-1, pred_top_k)
                # a hit when any of the k predictions attains the best true score
                accuracy = (pred_top_k_true_scores == true_bestscore).any(dim=-1).float().mean().item()
                kacc.append(accuracy)
            kacc = np.asarray(kacc)

            # weight the per-batch means by the number of graphs in the batch
            mean_loss += cross_entropy_loss.item() * batch.num_graphs
            mean_entropy += entropy.item() * batch.num_graphs
            mean_kacc += kacc * batch.num_graphs
            n_samples_processed += batch.num_graphs

    mean_loss /= n_samples_processed
    mean_kacc /= n_samples_processed
    mean_entropy /= n_samples_processed
    return mean_loss, mean_kacc, mean_entropy
65 |
66 |
if __name__ == "__main__":
    # Command-line entry point: train the GNN policy on the collected
    # samples of one problem type and keep the best parameters.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        choices=['setcover', 'cauctions', 'facilities', 'indset', 'mknapsack'],
    )
    parser.add_argument(
        '-s', '--seed',
        help='Random generator seed.',
        type=int,
        default=0,
    )
    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=0,
    )
    args = parser.parse_args()

    ### HYPER PARAMETERS ###
    max_epochs = 1000
    batch_size = 32
    pretrain_batch_size = 128
    valid_batch_size = 128
    lr = 1e-3
    entropy_bonus = 0.0
    top_k = [1, 3, 5, 10]

    problem_folders = {
        'setcover': 'setcover/500r_1000c_0.05d',
        'cauctions': 'cauctions/100_500',
        'facilities': 'facilities/100_100_5',
        'indset': 'indset/500_4',
        'mknapsack': 'mknapsack/100_6',
    }
    problem_folder = problem_folders[args.problem]
    running_dir = f"model/{args.problem}/{args.seed}"
    os.makedirs(running_dir, exist_ok=True)

    ### PYTORCH SETUP ###
    # CUDA_VISIBLE_DEVICES is set before torch is imported, so the selected
    # GPU is always exposed to torch as "cuda:0".
    if args.gpu == -1:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        device = "cpu"
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpu}'
        device = f"cuda:0"
    import torch
    import torch.nn.functional as F
    import torch_geometric
    from utilities import log, pad_tensor, GraphDataset, Scheduler
    sys.path.insert(0, os.path.abspath(f'model'))
    from model import GNNPolicy

    rng = np.random.RandomState(args.seed)
    torch.manual_seed(args.seed)

    ### LOG ###
    logfile = os.path.join(running_dir, 'train_log.txt')
    if os.path.exists(logfile):
        os.remove(logfile)

    log(f"max_epochs: {max_epochs}", logfile)
    log(f"batch_size: {batch_size}", logfile)
    log(f"pretrain_batch_size: {pretrain_batch_size}", logfile)
    log(f"valid_batch_size : {valid_batch_size }", logfile)
    log(f"lr: {lr}", logfile)
    log(f"entropy bonus: {entropy_bonus}", logfile)
    log(f"top_k: {top_k}", logfile)
    log(f"problem: {args.problem}", logfile)
    log(f"gpu: {args.gpu}", logfile)
    log(f"seed {args.seed}", logfile)

    policy = GNNPolicy().to(device)
    # use the logged hyper-parameter rather than a duplicated literal
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    scheduler = Scheduler(optimizer, mode='min', patience=10, factor=0.2, verbose=True)

    train_files = [str(file) for file in (pathlib.Path(f'data/samples')/problem_folder/'train').glob('sample_*.pkl')]
    # every tenth training file is used for pre-training
    pretrain_files = [f for i, f in enumerate(train_files) if i % 10 == 0]
    valid_files = [str(file) for file in (pathlib.Path(f'data/samples')/problem_folder/'valid').glob('sample_*.pkl')]

    pretrain_data = GraphDataset(pretrain_files)
    pretrain_loader = torch_geometric.loader.DataLoader(pretrain_data, pretrain_batch_size, shuffle=False)
    valid_data = GraphDataset(valid_files)
    valid_loader = torch_geometric.loader.DataLoader(valid_data, valid_batch_size, shuffle=False)

    # epoch 0 is the pre-training pass; epochs 1..max_epochs are training passes
    for epoch in range(max_epochs + 1):
        log(f"EPOCH {epoch}...", logfile)
        if epoch == 0:
            n = pretrain(policy, pretrain_loader)
            log(f"PRETRAINED {n} LAYERS", logfile)
        else:
            # each epoch draws (with replacement) a whole number of batches
            epoch_train_files = rng.choice(train_files, int(np.floor(10000/batch_size))*batch_size, replace=True)
            train_data = GraphDataset(epoch_train_files)
            # same loader class as the pre-training and validation loaders
            train_loader = torch_geometric.loader.DataLoader(train_data, batch_size, shuffle=True)
            train_loss, train_kacc, entropy = process(policy, train_loader, top_k, optimizer)
            log(f"TRAIN LOSS: {train_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, train_kacc)]), logfile)

        # TEST
        valid_loss, valid_kacc, entropy = process(policy, valid_loader, top_k, None)
        log(f"VALID LOSS: {valid_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, valid_kacc)]), logfile)

        scheduler.step(valid_loss)
        if scheduler.num_bad_epochs == 0:
            torch.save(policy.state_dict(), pathlib.Path(running_dir)/'train_params.pkl')
            log(f" best model so far", logfile)
        elif scheduler.num_bad_epochs == 10:
            log(f" 10 epochs without improvement, decreasing learning rate", logfile)
        elif scheduler.num_bad_epochs == 20:
            log(f" 20 epochs without improvement, early stopping", logfile)
            break

    # reload the best parameters and report their validation metrics
    policy.load_state_dict(torch.load(pathlib.Path(running_dir)/'train_params.pkl'))
    valid_loss, valid_kacc, entropy = process(policy, valid_loader, top_k, None)
    log(f"BEST VALID LOSS: {valid_loss:0.3f} " + "".join([f" acc@{k}: {acc:0.3f}" for k, acc in zip(top_k, valid_kacc)]), logfile)
184 |
--------------------------------------------------------------------------------
/04_evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import importlib
4 | import argparse
5 | import csv
6 | import numpy as np
7 | import time
8 | import pickle
9 |
10 | import ecole
11 | import pyscipopt
12 |
13 |
if __name__ == "__main__":
    # Command-line entry point: solve the transfer instances of one problem
    # type with SCIP's internal brancher and with the trained GNN policies,
    # and write one CSV row per (instance, policy, seed).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'problem',
        help='MILP instance type to process.',
        choices=['setcover', 'cauctions', 'facilities', 'indset'],
    )
    parser.add_argument(
        '-g', '--gpu',
        help='CUDA GPU id (-1 for CPU).',
        type=int,
        default=0,
    )
    args = parser.parse_args()

    result_file = f"{args.problem}_{time.strftime('%Y%m%d-%H%M%S')}.csv"
    instances = []
    seeds = [0, 1, 2, 3, 4]
    internal_branchers = ['relpscost']
    gnn_models = ['supervised']  # Can be supervised
    time_limit = 3600

    # problem -> instance difficulty -> folder holding the 20 transfer instances
    transfer_folders = {
        'setcover': {
            'small': 'transfer_500r_1000c_0.05d',
            'medium': 'transfer_1000r_1000c_0.05d',
            'big': 'transfer_2000r_1000c_0.05d',
        },
        'cauctions': {
            'small': 'transfer_100_500',
            'medium': 'transfer_200_1000',
            'big': 'transfer_300_1500',
        },
        'facilities': {
            'small': 'transfer_100_100_5',
            'medium': 'transfer_200_100_5',
            'big': 'transfer_400_100_5',
        },
        'indset': {
            'small': 'transfer_500_4',
            'medium': 'transfer_1000_4',
            'big': 'transfer_1500_4',
        },
    }
    if args.problem not in transfer_folders:
        raise NotImplementedError
    for instance_type, folder in transfer_folders[args.problem].items():
        instances += [{'type': instance_type,
                       'path': f"data/instances/{args.problem}/{folder}/instance_{i+1}.lp"}
                      for i in range(20)]

    branching_policies = []

    # SCIP internal brancher baselines
    for brancher in internal_branchers:
        for seed in seeds:
            branching_policies.append({
                'type': 'internal',
                'name': brancher,
                'seed': seed,
            })
    # GNN models
    for model in gnn_models:
        for seed in seeds:
            branching_policies.append({
                'type': 'gnn',
                'name': model,
                'seed': seed,
            })

    print(f"problem: {args.problem}")
    print(f"gpu: {args.gpu}")
    print(f"time limit: {time_limit} s")

    ### PYTORCH SETUP ###
    # CUDA_VISIBLE_DEVICES is set before torch is imported, so the selected
    # GPU is always exposed to torch as "cuda:0".
    if args.gpu == -1:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        device = 'cpu'
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpu}'
        device = f"cuda:0"

    import torch
    from model.model import GNNPolicy

    # Load one model per (policy name, seed). Each seed has its own
    # checkpoint, so the cache must not be keyed by the policy name alone.
    loaded_models = {}
    for policy in branching_policies:
        if policy['type'] == 'gnn':
            model_key = (policy['name'], policy['seed'])
            if model_key not in loaded_models:
                ### MODEL LOADING ###
                model = GNNPolicy().to(device)
                if policy['name'] == 'supervised':
                    # map_location lets a GPU-trained checkpoint load on CPU
                    model.load_state_dict(torch.load(
                        f"model/{args.problem}/{policy['seed']}/train_params.pkl",
                        map_location=device))
                else:
                    raise Exception(f"Unrecognized GNN policy {policy['name']}")
                loaded_models[model_key] = model

            policy['model'] = loaded_models[model_key]

    print("running SCIP...")

    fieldnames = [
        'policy',
        'seed',
        'type',
        'instance',
        'nnodes',
        'nlps',
        'stime',
        'gap',
        'status',
        'walltime',
        'proctime',
    ]
    os.makedirs('results', exist_ok=True)
    scip_parameters = {'separating/maxrounds': 0, 'presolving/maxrestarts': 0, 'limits/time': time_limit,
                       'timing/clocktype': 1, 'branching/vanillafullstrong/idempotent': True}

    with open(f"results/{result_file}", 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        for instance in instances:
            print(f"{instance['type']}: {instance['path']}...")

            for policy in branching_policies:
                if policy['type'] == 'internal':
                    # Run SCIP's default brancher
                    env = ecole.environment.Configuring(scip_params={**scip_parameters,
                                                        f"branching/{policy['name']}/priority": 9999999})
                    env.seed(policy['seed'])

                    walltime = time.perf_counter()
                    proctime = time.process_time()

                    env.reset(instance['path'])
                    _, _, _, _, _ = env.step({})

                    walltime = time.perf_counter() - walltime
                    proctime = time.process_time() - proctime

                elif policy['type'] == 'gnn':
                    # Run the GNN policy
                    env = ecole.environment.Branching(observation_function=ecole.observation.NodeBipartite(),
                                                      scip_params=scip_parameters)
                    env.seed(policy['seed'])
                    torch.manual_seed(policy['seed'])

                    walltime = time.perf_counter()
                    proctime = time.process_time()

                    observation, action_set, _, done, _ = env.reset(instance['path'])
                    while not done:
                        with torch.no_grad():
                            observation = (torch.from_numpy(observation.row_features.astype(np.float32)).to(device),
                                           torch.from_numpy(observation.edge_features.indices.astype(np.int64)).to(device),
                                           torch.from_numpy(observation.edge_features.values.astype(np.float32)).view(-1, 1).to(device),
                                           torch.from_numpy(observation.variable_features.astype(np.float32)).to(device))

                            logits = policy['model'](*observation)
                            # branch on the candidate with the highest logit
                            action = action_set[logits[action_set.astype(np.int64)].argmax()]
                            observation, action_set, _, done, _ = env.step(action)

                    walltime = time.perf_counter() - walltime
                    proctime = time.process_time() - proctime

                # solving statistics are read back from the underlying SCIP model
                scip_model = env.model.as_pyscipopt()
                stime = scip_model.getSolvingTime()
                nnodes = scip_model.getNNodes()
                nlps = scip_model.getNLPs()
                gap = scip_model.getGap()
                status = scip_model.getStatus()

                writer.writerow({
                    'policy': f"{policy['type']}:{policy['name']}",
                    'seed': policy['seed'],
                    'type': instance['type'],
                    'instance': instance['path'],
                    'nnodes': nnodes,
                    'nlps': nlps,
                    'stime': stime,
                    'gap': gap,
                    'status': status,
                    'walltime': walltime,
                    'proctime': proctime,
                })
                # flush after every row so partial results survive an interruption
                csvfile.flush()

                print(f" {policy['type']}:{policy['name']} {policy['seed']} - {nnodes} nodes {nlps} lps {stime:.2f} ({walltime:.2f} wall {proctime:.2f} proc) s. {status}")
199 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 CERC Data Science For Decision Making
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Exact Combinatorial Optimization with Graph Convolutional Neural Networks (Ecole+Pytorch+Pytorch Geometric reimplementation)
2 |
3 | This is the official reimplementation of the proposed GNN model from the paper "Exact Combinatorial Optimization with Graph Convolutional Neural Networks" [NeurIPS 2019 paper](https://arxiv.org/abs/1906.01629) using the [Ecole library](https://github.com/ds4dm/ecole). This reimplementation also makes use of [Pytorch](https://github.com/pytorch/pytorch) instead of Tensorflow, and of [Pytorch Geometric](https://github.com/pyg-team/pytorch_geometric) for handling the GNN. As a consequence, much of the code is now simplified. Slight discrepancies in results from the original implementation are to be expected.
4 |
5 | As mentioned, this repo only implements the GNN model. For comparisons with the other ML competitors (ExtraTrees, LambdaMART and SVMRank), please see the original implementation [here](https://github.com/ds4dm/learn2branch).
6 |
7 |
8 |
9 |  |
10 |  |
11 |  |
12 |
13 |
14 |
15 | ## Authors
16 |
17 | Maxime Gasse, Didier Chételat, Nicola Ferroni, Laurent Charlin and Andrea Lodi.
18 |
19 | ## Installation
20 |
21 | Our recommended installation uses the [Conda package manager](https://docs.conda.io/en/latest/miniconda.html). The previous implementation required you to compile a patched version of SCIP and PySCIPOpt using Cython. This is not required anymore, as Conda packages are now available, which are dependencies of the Ecole conda package itself.
22 |
23 | __Instructions:__ Install Ecole, Pytorch and Pytorch Geometric using conda. At the time of writing these installation instructions, this can be accomplished by running:
24 |
25 | ```
26 | conda install ecole
27 | conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
28 | conda install pyg -c pyg -c conda-forge
29 | ```
30 |
31 | Please refer to the most up to date installation instructions for [Ecole](https://github.com/ds4dm/ecole#installation), [Pytorch](https://pytorch.org/get-started/locally) and [Pytorch Geometric](https://github.com/pyg-team/pytorch_geometric#installation) if you encounter any errors.
32 |
33 | ## Benchmarks
34 |
35 | For every benchmark in the paper, we describe the code for running the experiments, and the results compared to the original implementation.
36 |
37 | ### Set Covering
38 |
39 | ```
40 | # Generate MILP instances
41 | python 01_generate_instances.py setcover
42 | # Generate supervised learning datasets
43 | python 02_generate_dataset.py setcover -j 4 # number of available CPUs
44 | # Training
45 | for i in {0..4}
46 | do
47 | python 03_train_gnn.py setcover -s $i
48 | done
49 | # Evaluation
50 | python 04_evaluate.py setcover
51 | ```
52 |
53 |
54 |
55 | |
56 | Easy |
57 | Medium |
58 | Hard |
59 |
60 |
61 | |
62 | Time |
63 | Nodes |
64 | Time |
65 | Nodes |
66 | Time |
67 | Nodes |
68 |
69 |
70 | SCIP default |
71 | |
72 | |
73 | |
74 | |
75 | |
76 | |
77 |
78 |
79 | GNN (original) |
80 | |
81 | |
82 | |
83 | |
84 | |
85 | |
86 |
87 |
88 | GNN (reimplementation) |
89 | |
90 | |
91 | |
92 | |
93 | |
94 | |
95 |
96 |
97 |
98 | ### Combinatorial Auction
99 | ```
100 | # Generate MILP instances
101 | python 01_generate_instances.py cauctions
102 | # Generate supervised learning datasets
103 | python 02_generate_dataset.py cauctions -j 4 # number of available CPUs
104 | # Training
105 | for i in {0..4}
106 | do
107 | python 03_train_gnn.py cauctions -s $i
108 | done
109 | # Evaluation
110 | python 04_evaluate.py cauctions
111 | ```
112 |
113 | ### Capacitated Facility Location
114 | ```
115 | # Generate MILP instances
116 | python 01_generate_instances.py facilities
117 | # Generate supervised learning datasets
118 | python 02_generate_dataset.py facilities -j 4 # number of available CPUs
119 | # Training
120 | for i in {0..4}
121 | do
122 | python 03_train_gnn.py facilities -s $i
123 | done
124 | # Evaluation
125 | python 04_evaluate.py facilities
126 | ```
127 |
128 | ### Maximum Independent Set
129 | ```
130 | # Generate MILP instances
131 | python 01_generate_instances.py indset
132 | # Generate supervised learning datasets
133 | python 02_generate_dataset.py indset -j 4 # number of available CPUs
134 | # Training
135 | for i in {0..4}
136 | do
137 | python 03_train_gnn.py indset -s $i
138 | done
139 | # Evaluation
140 | python 04_evaluate.py indset
141 | ```
142 |
143 | ## Citation
144 | Please cite our paper if you use this code in your work.
145 | ```
146 | @inproceedings{conf/nips/GasseCFCL19,
147 | title={Exact Combinatorial Optimization with Graph Convolutional Neural Networks},
148 | author={Gasse, Maxime and Chételat, Didier and Ferroni, Nicola and Charlin, Laurent and Lodi, Andrea},
149 | booktitle={Advances in Neural Information Processing Systems 32},
150 | year={2019}
151 | }
152 | ```
153 |
154 | ## Questions / Bugs
155 | Please feel free to submit a GitHub issue if you have any questions or find any bugs. We do not guarantee any support, but will do our best if we can help.
156 |
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch_geometric
4 | import numpy as np
5 |
6 |
class PreNormException(Exception):
    """
    Control-flow signal used during pre-training of the normalization layers.

    `PreNormLayer.forward` raises it right after feeding a batch to its
    running statistics, and `BaseModel.pre_train` catches it to learn that a
    layer consumed the batch.
    """
9 |
10 |
class PreNormLayer(torch.nn.Module):
    """
    Fixed (non-learned) affine normalization: `(input + shift) * scale`.

    `shift` and `scale` are registered buffers, initialised to zeros and ones
    (or `None` when disabled). They are estimated from data in a pre-training
    phase: `start_updates` switches the layer to collecting mode, every
    `forward` call then feeds `update_stats` and raises `PreNormException`,
    and `stop_updates` turns the collected mean/variance into the buffers.
    """

    def __init__(self, n_units, shift=True, scale=True, name=None):
        # `name` is accepted but not used by this class.
        super().__init__()
        assert shift or scale
        shift_init = torch.zeros(n_units) if shift else None
        scale_init = torch.ones(n_units) if scale else None
        self.register_buffer('shift', shift_init)
        self.register_buffer('scale', scale_init)
        self.n_units = n_units
        self.waiting_updates = False
        self.received_updates = False

    def forward(self, input_):
        # Collecting mode: consume the batch, flag it, and abort the forward
        # pass so the caller knows this layer still needs statistics.
        if self.waiting_updates:
            self.update_stats(input_)
            self.received_updates = True
            raise PreNormException

        output = input_
        if self.shift is not None:
            output = output + self.shift
        if self.scale is not None:
            output = output * self.scale
        return output

    def start_updates(self):
        """
        Resets the running statistics and puts the layer in collecting mode.
        """
        self.avg = self.var = self.m2 = self.count = 0
        self.waiting_updates = True
        self.received_updates = False

    def update_stats(self, input_):
        """
        Merges one batch into the running mean and variance (online
        algorithm). See: Chan et al. (1979) Updating Formulae and a Pairwise
        Algorithm for Computing Sample Variances.
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
        """
        assert self.n_units == 1 or input_.shape[-1] == self.n_units, f"Expected input dimension of size {self.n_units}, got {input_.shape[-1]}."

        # With n_units == 1 every element of the batch becomes its own sample.
        flat = input_.reshape(-1, self.n_units)
        batch_avg = flat.mean(dim=0)
        batch_var = (flat - batch_avg).pow(2).mean(dim=0)
        batch_count = np.prod(flat.size()) / self.n_units

        seen_before = self.count
        seen_total = seen_before + batch_count
        gap = batch_avg - self.avg

        # Combined sum of squared deviations of the old and the new samples.
        self.m2 = self.var * seen_before + batch_var * batch_count + gap ** 2 * seen_before * batch_count / seen_total

        self.count = seen_total
        self.avg += gap * batch_count / seen_total
        self.var = self.m2 / seen_total if seen_total > 0 else 1

    def stop_updates(self):
        """
        Ends pre-training for that layer and freezes its buffers: `shift`
        becomes minus the collected mean, `scale` the inverse standard
        deviation.
        """
        assert self.count > 0
        if self.shift is not None:
            self.shift = -self.avg

        if self.scale is not None:
            # Near-constant features keep a unit scale instead of exploding.
            tiny = self.var < 1e-8
            self.var[tiny] = 1
            self.scale = 1 / torch.sqrt(self.var)

        for stat in ('avg', 'var', 'm2', 'count'):
            delattr(self, stat)
        self.waiting_updates = False
        self.trainable = False
80 |
81 |
82 |
class BipartiteGraphConvolution(torch_geometric.nn.MessagePassing):
    """
    Message-passing layer between the two sides of a bipartite graph, with
    sum ('add') aggregation of the per-edge messages.

    Parameters
    ----------
    emb_size : int, optional
        Width of the node embeddings on both sides and of the produced
        embeddings. Defaults to 64, the value that used to be hard-coded.
    """

    def __init__(self, emb_size=64):
        super().__init__('add')

        self.feature_module_left = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size)
        )
        # Edge features are a single scalar per edge.
        self.feature_module_edge = torch.nn.Sequential(
            torch.nn.Linear(1, emb_size, bias=False)
        )
        self.feature_module_right = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size, bias=False)
        )
        self.feature_module_final = torch.nn.Sequential(
            PreNormLayer(1, shift=False),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size)
        )

        self.post_conv_module = torch.nn.Sequential(
            PreNormLayer(1, shift=False)
        )

        # output_layers
        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(2*emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, emb_size),
        )

    def forward(self, left_features, edge_indices, edge_features, right_features):
        """
        Propagates messages along `edge_indices` and returns new embeddings
        with one row per right-side node: the aggregated messages are
        rescaled by `post_conv_module`, concatenated with `right_features`
        and passed through `output_module`.
        """
        output = self.propagate(edge_indices, size=(left_features.shape[0], right_features.shape[0]),
                                node_features=(left_features, right_features), edge_features=edge_features)
        return self.output_module(torch.cat([self.post_conv_module(output), right_features], dim=-1))

    def message(self, node_features_i, node_features_j, edge_features):
        """
        Builds the per-edge message from both endpoint embeddings and the
        edge feature.
        """
        # NOTE(review): under PyG's default source_to_target flow `_i` should
        # be the receiving (right) side and `_j` the sending (left) side, so
        # `feature_module_left` appears to act on right-side embeddings.
        # This looks like a naming quirk only - confirm before renaming.
        output = self.feature_module_final(self.feature_module_left(node_features_i)
                                           + self.feature_module_edge(edge_features)
                                           + self.feature_module_right(node_features_j))
        return output
124 |
125 |
class BaseModel(torch.nn.Module):
    """
    Base class providing the pre-training protocol for the PreNormLayer
    sub-modules of a model.
    """

    def pre_train_init(self):
        """
        Puts every PreNormLayer of the model in statistics-collecting mode.
        """
        for layer in self.modules():
            if isinstance(layer, PreNormLayer):
                layer.start_updates()

    def pre_train_next(self):
        """
        Finalizes the first PreNormLayer that is still collecting and has
        already received data, and returns it. Returns None when no layer
        qualifies.
        """
        for layer in self.modules():
            if not isinstance(layer, PreNormLayer):
                continue
            if layer.waiting_updates and layer.received_updates:
                layer.stop_updates()
                return layer
        return None

    def pre_train(self, *args, **kwargs):
        """
        Runs a gradient-free forward pass. Returns True when a PreNormLayer
        interrupted it with PreNormException, False when the pass ran to
        completion.
        """
        try:
            with torch.no_grad():
                self.forward(*args, **kwargs)
        except PreNormException:
            return True
        return False
150 |
151 |
class GNNPolicy(BaseModel):
    """
    Policy network: embeds constraint, edge and variable features, runs one
    variable-to-constraint and one constraint-to-variable graph convolution,
    and outputs one scalar per variable.
    """

    def __init__(self):
        super().__init__()
        emb_size = 64
        cons_nfeats = 5
        edge_nfeats = 1
        var_nfeats = 19

        def node_embedding(n_feats):
            # Pre-normalization followed by a two-layer ReLU MLP.
            return torch.nn.Sequential(
                PreNormLayer(n_feats),
                torch.nn.Linear(n_feats, emb_size),
                torch.nn.ReLU(),
                torch.nn.Linear(emb_size, emb_size),
                torch.nn.ReLU(),
            )

        # Construction order is kept (constraints, edges, variables, convs,
        # head) so that seeded parameter initialisation is unchanged.
        self.cons_embedding = node_embedding(cons_nfeats)
        self.edge_embedding = torch.nn.Sequential(PreNormLayer(edge_nfeats))
        self.var_embedding = node_embedding(var_nfeats)

        self.conv_v_to_c = BipartiteGraphConvolution()
        self.conv_c_to_v = BipartiteGraphConvolution()

        self.output_module = torch.nn.Sequential(
            torch.nn.Linear(emb_size, emb_size),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_size, 1, bias=False),
        )

    def forward(self, constraint_features, edge_indices, edge_features, variable_features):
        # Row 0 / row 1 of `edge_indices` are swapped for the first pass,
        # which sends messages from variables to constraints.
        cons_row, var_row = edge_indices[0], edge_indices[1]
        reversed_edge_indices = torch.stack([var_row, cons_row], dim=0)

        cons_emb = self.cons_embedding(constraint_features)
        edge_emb = self.edge_embedding(edge_features)
        var_emb = self.var_embedding(variable_features)

        cons_emb = self.conv_v_to_c(var_emb, reversed_edge_indices, edge_emb, cons_emb)
        var_emb = self.conv_c_to_v(cons_emb, edge_indices, edge_emb, var_emb)

        # Drop the trailing singleton dimension: one score per variable.
        return self.output_module(var_emb).squeeze(-1)
204 |
--------------------------------------------------------------------------------
/utilities.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import pickle
3 | import datetime
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | import torch_geometric
9 |
def log(str, logfile=None):
    """
    Prints `str` prefixed with the current timestamp, and appends the same
    line to `logfile` when one is given.
    """
    stamped = f'[{datetime.datetime.now()}] {str}'
    print(stamped)
    if logfile is None:
        return
    with open(logfile, mode='a') as f:
        print(stamped, file=f)
16 |
17 |
def pad_tensor(input_, pad_sizes, pad_value=-1e8):
    """
    Splits `input_` along its first dimension into consecutive chunks of
    lengths `pad_sizes`, right-pads every chunk with `pad_value` up to the
    longest length, and stacks the chunks into one tensor.

    NOTE(review): the pad width is derived from `size(0)` but applied to the
    last dimension, so this presumably expects a 1-D `input_` - confirm.
    """
    widest = pad_sizes.max()
    chunk_lengths = pad_sizes.cpu().numpy().tolist()
    rows = []
    for chunk in input_.split(chunk_lengths):
        missing = widest - chunk.size(0)
        rows.append(F.pad(chunk, (0, missing), 'constant', pad_value))
    return torch.stack(rows, dim=0)
24 |
25 |
class BipartiteNodeData(torch_geometric.data.Data):
    """
    One bipartite-graph sample (constraint nodes, variable nodes, edges)
    together with its branching candidates, stored as a PyG `Data` object.

    Note that the `candidate_choice` argument is stored under the plural
    attribute name `candidate_choices`.
    """

    def __init__(self, constraint_features, edge_indices, edge_features, variable_features,
                 candidates, nb_candidates, candidate_choice, candidate_scores):
        super().__init__()
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        self.candidates = candidates
        self.nb_candidates = nb_candidates
        self.candidate_choices = candidate_choice
        self.candidate_scores = candidate_scores

    def __inc__(self, key, value, *args, **kwargs):
        """
        Offset added to index-valued attributes when samples are merged into
        a batch.

        `edge_index` rows are shifted by the number of constraints (row 0)
        and of variables (row 1) of this sample; `candidates` are shifted by
        the number of variables. Every other key is delegated to the parent
        class, with all extra arguments forwarded unchanged (the previous
        version required a positional `store` and then dropped it).
        """
        if key == 'edge_index':
            return torch.tensor([[self.constraint_features.size(0)], [self.variable_features.size(0)]])
        if key == 'candidates':
            return self.variable_features.size(0)
        return super().__inc__(key, value, *args, **kwargs)
46 |
47 |
class GraphDataset(torch_geometric.data.Dataset):
    """
    Dataset over a list of gzip-compressed, pickled sample files; each file
    is turned into one `BipartiteNodeData` graph on access.
    """

    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        self.sample_files = sample_files

    def len(self):
        """Number of sample files."""
        return len(self.sample_files)

    def get(self, index):
        """
        Loads sample file number `index` and converts it to a graph.
        """
        # NOTE: pickle.load can execute arbitrary code from the file; only
        # point this dataset at sample files you generated or trust.
        with gzip.open(self.sample_files[index], 'rb') as f:
            sample = pickle.load(f)

        observation, action, action_set, scores = sample['data']
        cons_feats, (edge_idx, edge_vals), var_feats = observation

        constraint_features = torch.FloatTensor(cons_feats)
        edge_indices = torch.LongTensor(edge_idx.astype(np.int32))
        # Edge values get a trailing axis: one feature per edge.
        edge_features = torch.FloatTensor(np.expand_dims(edge_vals, axis=-1))
        variable_features = torch.FloatTensor(var_feats)

        candidates = torch.LongTensor(np.array(action_set, dtype=np.int32))
        # Position of the chosen action inside the candidate list.
        matches = torch.where(candidates == action)[0]
        candidate_choice = matches[0]
        candidate_scores = torch.FloatTensor([scores[j] for j in candidates])

        graph = BipartiteNodeData(constraint_features, edge_indices, edge_features, variable_features,
                                  candidates, len(candidates), candidate_choice, candidate_scores)
        n_constraints = constraint_features.shape[0]
        n_variables = variable_features.shape[0]
        graph.num_nodes = n_constraints + n_variables
        return graph
76 |
77 |
class Scheduler(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """
    ReduceLROnPlateau variant with a simplified `step`.

    The learning rate is reduced exactly once per plateau, at the step where
    the number of consecutive non-improving steps equals `patience`. The
    counter is not reset after a reduction, so no further reduction happens
    until the metric improves again. This `step` does not look at the parent
    class's cooldown setting.
    """

    def __init__(self, optimizer, **kwargs):
        super().__init__(optimizer, **kwargs)

    def step(self, metrics):
        """
        Records one value of the monitored metric and reduces the learning
        rate if the plateau just reached `patience` steps.

        Parameters
        ----------
        metrics : float or zero-dim Tensor
            Current value of the monitored quantity.
        """
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        # Fixed: this used to read `=+ 1`, which reset the counter to 1 on
        # every call instead of incrementing it.
        self.last_epoch += 1

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        # Equality (not >=) is deliberate: it fires once per plateau.
        if self.num_bad_epochs == self.patience:
            self._reduce_lr(self.last_epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
97 |
--------------------------------------------------------------------------------