├── LICENSE ├── Pipfile ├── README.md ├── data.py ├── data └── dataset_params │ ├── binary_random.yaml │ ├── dense_random.yaml │ └── set_covering.yaml ├── experiments ├── knapsack │ ├── base.yaml │ ├── comboptnet.yaml │ ├── cvxpy.yaml │ └── mlp.yaml └── static_constraints │ ├── base.yaml │ ├── comboptnet.yaml │ ├── cvxpy.yaml │ └── mlp.yaml ├── main.py ├── media ├── arch_overview.png └── arch_overview_600x400.png ├── models ├── comboptnet.py ├── graph_matching.py ├── models.py └── modules.py ├── trainer.py └── utils ├── comboptnet_utils.py ├── constraint_generation.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Autonomous Learning Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | numpy = "*" 10 | torch = "==1.5.0" 11 | ray = "*" 12 | jax = ">=0.2.8" 13 | jaxlib = ">=0.1.58" 14 | cvxpy = "*" 15 | cvxpylayers = "*" 16 | 17 | 18 | [requires] 19 | python_version = "3.6" 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CombOptNet: Fit the Right NP-Hard Problem by Learning Integer Programming Constraints 2 | 3 | ![Architecture overview](media/arch_overview.png) 4 | 5 | 6 | This repository contains PyTorch implementation of the paper 7 | [CombOptNet: Fit the Right NP-Hard Problem by Learning Integer Programming Constraints](https://arxiv.org/abs/2105.02343) 8 | 9 | ## Installation 10 | 1) Run `pipenv install` (at your own risk with `--skip-lock` to save some time). 11 | 2) From within the pipenv environment run `python3 -m pip install -i https://pypi.gurobi.com gurobipy`. 12 | 3) Obtain a [license](https://www.gurobi.com/documentation/9.1/quickstart_mac/obtaining_a_grb_license.html) and download/set it. 13 | 4) Download and extract the [datasets](https://edmond.mpdl.mpg.de/imeji/collection/Z_abYaB4ggQTS_G0?q=). 14 | 15 | ## Usage 16 | For `[experiment] = knapsack` or `[experiment] = static_constraints`: 17 | 1) Set the `base_dataset_path` parameter in `experiments/[experiment]/base.yaml`. 
18 | 2) In case of static constraints: set the `dataset_specification` parameter in `experiments/static_constraints/base.yaml` 19 | 3) Run `python3 main.py experiments/[experiment]/[method].yaml`. 20 | 21 | ## Citation 22 | 23 | ``` 24 | @misc{paulus2021comboptnet, 25 | title={CombOptNet: Fit the Right NP-Hard Problem by Learning Integer Programming Constraints}, 26 | author={Anselm Paulus and Michal Rolínek and Vít Musil and Brandon Amos and Georg Martius}, 27 | year={2021}, 28 | eprint={2105.02343}, 29 | archivePrefix={arXiv}, 30 | primaryClass={cs.LG} 31 | } 32 | ``` -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils import data 7 | 8 | from models.comboptnet import ilp_solver 9 | from utils.constraint_generation import sample_constraints 10 | from utils.utils import compute_normalized_solution, save_pickle, load_pickle, AvgMeters, check_equal_ys, \ 11 | solve_unconstrained, load_with_default_yaml, save_dict_as_one_line_csv 12 | 13 | 14 | def load_dataset(dataset_type, base_dataset_path, **dataset_params): 15 | dataset_path = os.path.join(base_dataset_path, dataset_type) 16 | dataset_loader_dict = dict(static_constraints=static_constraint_dataloader, knapsack=knapsack_dataloader) 17 | return dataset_loader_dict[dataset_type](dataset_path=dataset_path, **dataset_params) 18 | 19 | 20 | def static_constraint_dataloader(dataset_path, dataset_specification, num_gt_variables, num_gt_constraints, 21 | dataset_seed, train_dataset_size, loader_params): 22 | dataset_path = os.path.join(dataset_path, dataset_specification, str(num_gt_variables) + '_dim', 23 | str(num_gt_constraints) + '_const', str(dataset_seed), 'dataset.p') 24 | datasets = load_pickle(dataset_path) 25 | 26 | train_ys = [tuple(y) for c, y in datasets['train'][:train_dataset_size]] 27 | test_ys = [tuple(y) for c, y in datasets['test'][:train_dataset_size]] 28 | 29 | print(f'Successfully loaded Static Constraints dataset.\n' 30 | f'Number of distinct solutions in train set: {len(set(train_ys))}\n' 31 | f'Number of distinct solutions in test set: {len(set(test_ys))}') 32 | 33 | training_set = Dataset(datasets['train'][:train_dataset_size]) 34 | train_iterator = data.DataLoader(training_set, **loader_params) 35 | 36 | test_iterator = data.DataLoader(Dataset(datasets['test']), **loader_params) 37 | 38 | return (train_iterator, test_iterator), datasets['metadata'] 39 | 40 | 41 | def knapsack_dataloader(dataset_path, loader_params): 42 | variable_range = dict(lb=0, ub=1) 43 | num_variables = 10 44 | 45 | train_encodings = np.load(os.path.join(dataset_path, 'train_encodings.npy')) 46 | train_ys = compute_normalized_solution(np.load(os.path.join(dataset_path, 'train_sols.npy')), **variable_range) 47 | train_dataset = list(zip(train_encodings, train_ys)) 48 | training_set = Dataset(train_dataset) 49 | train_iterator = data.DataLoader(training_set, **loader_params) 50 | 51 | test_encodings = np.load(os.path.join(dataset_path, 'test_encodings.npy')) 52 | test_ys = compute_normalized_solution(np.load(os.path.join(dataset_path, 'test_sols.npy')), **variable_range) 53 | test_dataset = list(zip(test_encodings, test_ys)) 54 | test_set = Dataset(test_dataset) 55 | test_iterator = data.DataLoader(test_set, **loader_params) 56 | 57 | distinct_ys_train = len(set([tuple(y) for y in train_ys])) 58 | distinct_ys_test = len(set([tuple(y) for y 
in test_ys])) 59 | print(f'Successfully loaded Knapsack dataset.\n' 60 | f'Number of distinct solutions in train set: {distinct_ys_train},\n' 61 | f'Number of distinct solutions in test set: {distinct_ys_test}') 62 | 63 | metadata = {"variable_range": variable_range, 64 | "num_variables": num_variables} 65 | 66 | return (train_iterator, test_iterator), metadata 67 | 68 | 69 | class Dataset(data.Dataset): 70 | def __init__(self, dataset): 71 | self.dataset = dataset 72 | 73 | def __len__(self): 74 | return len(self.dataset) 75 | 76 | def __getitem__(self, index): 77 | x, y = [torch.from_numpy(_x) for _x in self.dataset[index]] 78 | return x, y 79 | 80 | 81 | def gen_constraints_dataset(train_dataset_size, test_dataset_size, seed, variable_range, num_variables, 82 | num_constraints, positive_costs, constraint_params): 83 | np.random.seed(seed) 84 | constraints = sample_constraints(variable_range=variable_range, 85 | num_variables=num_variables, 86 | num_constraints=num_constraints, 87 | seed=seed, **constraint_params) 88 | metadata = dict(true_constraints=constraints, num_variables=num_variables, num_constraints=num_constraints, 89 | variable_range=variable_range) 90 | 91 | c_l = [] 92 | y_l = [] 93 | dataset = [] 94 | for _ in range(test_dataset_size + train_dataset_size): 95 | cost_vector = 2 * (np.random.rand(constraints.shape[1] - 1) - 0.5) 96 | if positive_costs: 97 | cost_vector = np.abs(cost_vector) 98 | y = ilp_solver(cost_vector=cost_vector, constraints=constraints, **variable_range)[0] 99 | y_norm = compute_normalized_solution(y, **variable_range) 100 | dataset.append((cost_vector, y_norm)) 101 | c_l.append(cost_vector) 102 | y_l.append(y) 103 | cs, ys = np.stack(c_l, axis=0), np.stack(y_l, axis=0) 104 | 105 | num_distinct_ys = len(set([tuple(y) for _, y in dataset])) 106 | ys_uncon = solve_unconstrained(cs, **variable_range) 107 | match_boxconst_solution_acc = check_equal_ys(y_1=ys, y_2=ys_uncon)[1].mean() 108 | metrics = dict(num_distinct_ys=num_distinct_ys, match_boxconst_solution_acc=match_boxconst_solution_acc) 109 | print(f'Num distinct ys: {num_distinct_ys}, Match boxconst acc: {match_boxconst_solution_acc}') 110 | 111 | test_set = dataset[:test_dataset_size] 112 | train_set = dataset[test_dataset_size:] 113 | datasets = dict(metadata=metadata, train=train_set, test=test_set) 114 | return datasets, metrics 115 | 116 | 117 | def main(working_dir, num_seeds, num_constraints, num_variables, data_gen_params): 118 | avg_meter = AvgMeters() 119 | all_metrics = {} 120 | for num_const, num_var in zip(num_constraints, num_variables): 121 | print(f'Generating dataset with {num_var} variables and {num_const} constraints...') 122 | for seed in range(num_seeds): 123 | dir = os.path.join(working_dir, str(num_var) + "_dim", str(num_const) + "_const", str(seed)) 124 | os.makedirs(dir, exist_ok=True) 125 | datasets, metrics = gen_constraints_dataset(seed=seed, num_variables=num_var, 126 | num_constraints=num_const, **data_gen_params) 127 | save_pickle(datasets, os.path.join(dir, 'dataset.p')) 128 | avg_meter.update(metrics) 129 | all_metrics.update( 130 | avg_meter.get_averages(prefix=str(num_var) + "_dim_" + str(num_const) + "_const_")) 131 | avg_meter.reset() 132 | save_dict_as_one_line_csv(all_metrics, filename=os.path.join(working_dir, "metrics.csv")) 133 | return all_metrics 134 | 135 | 136 | if __name__ == "__main__": 137 | param_path = sys.argv[1] 138 | param_dict = load_with_default_yaml(path=param_path) 139 | main(**param_dict) 140 |
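A short note on the on-disk format produced by data.py above (a minimal sketch, assuming a dataset has been generated with one of the dataset_params YAML files listed next; the concrete path is only an example): gen_constraints_dataset pickles a dict with 'metadata', 'train' and 'test' keys, each train/test entry being a (cost_vector, normalized_solution) pair of numpy arrays, and static_constraint_dataloader reads this pickle back from working_dir/<num_variables>_dim/<num_constraints>_const/<seed>/dataset.p.

```
# Sketch only: inspect one generated static-constraints dataset (example path: set_covering config, seed 0).
from utils.utils import load_pickle

datasets = load_pickle('datasets/static_constraints/set_covering/16_dim/8_const/0/dataset.p')
print(datasets['metadata']['num_variables'])            # e.g. 16
print(datasets['metadata']['true_constraints'].shape)   # (num_constraints, num_variables + 1), rows stack [A | b]
cost_vector, y_norm = datasets['train'][0]               # one training example
print(cost_vector.shape, y_norm.shape)                   # both (num_variables,)
```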
-------------------------------------------------------------------------------- /data/dataset_params/binary_random.yaml: -------------------------------------------------------------------------------- 1 | "working_dir": "datasets/static_constraints/binary_random" 2 | "num_variables": [16, 16, 16, 16] 3 | "num_constraints": [1, 2, 4, 8] 4 | "num_seeds": 10 5 | "data_gen_params": 6 | "test_dataset_size": 1000 7 | "train_dataset_size": 1600 8 | "constraint_params": 9 | "constraint_type": "random_const" 10 | "feasible_point": "random_corner" 11 | "positive_costs": false 12 | "variable_range": 13 | "lb": 0.0 14 | "ub": 1.0 -------------------------------------------------------------------------------- /data/dataset_params/dense_random.yaml: -------------------------------------------------------------------------------- 1 | "working_dir": "datasets/static_constraints/dense_random" 2 | "num_variables": [16, 16, 16, 16] 3 | "num_constraints": [1, 2, 4, 8] 4 | "num_seeds": 10 5 | "data_gen_params": 6 | "test_dataset_size": 1000 7 | "train_dataset_size": 1600 8 | "constraint_params": 9 | "constraint_type": "random_const" 10 | "feasible_point": "random_corner" 11 | "positive_costs": false 12 | "variable_range": 13 | "lb": -5.0 14 | "ub": 5.0 -------------------------------------------------------------------------------- /data/dataset_params/set_covering.yaml: -------------------------------------------------------------------------------- 1 | "working_dir": "datasets/static_constraints/set_covering" 2 | "num_variables": [8, 12, 16, 20] 3 | "num_constraints": [4, 6, 8, 10] 4 | "num_seeds": 10 5 | "data_gen_params": 6 | "test_dataset_size": 1000 7 | "train_dataset_size": 1600 8 | "constraint_params": 9 | "max_subset_size": 3 # maximum size of a sampled subset 10 | "constraint_type": "set_covering" 11 | "positive_costs": false 12 | "variable_range": 13 | "lb": 0.0 14 | "ub": 1.0 -------------------------------------------------------------------------------- /experiments/knapsack/base.yaml: -------------------------------------------------------------------------------- 1 | # base settings file for knapsack experiment 2 | 3 | "seed": null 4 | 5 | "working_dir": "results/knapsack/1" 6 | "use_ray": false 7 | "ray_params": 8 | "num_cpus": 20 9 | 10 | "data_params": 11 | "base_dataset_path": ".../datasets" # Add correct dataset path here ".../datasets" 12 | "dataset_type": "knapsack" 13 | "loader_params": 14 | "batch_size": 8 15 | "shuffle": true 16 | 17 | "train_epochs": 100 18 | "eval_every": 10 19 | "trainer_params": 20 | "use_cuda": false 21 | "loss_name": "MSE" 22 | "optimizer_name": "Adam" 23 | "optimizer_params": 24 | "lr": 0.00005 -------------------------------------------------------------------------------- /experiments/knapsack/comboptnet.yaml: -------------------------------------------------------------------------------- 1 | # settings file for knapsack with comboptnet 2 | 3 | "__import__": "experiments/knapsack/base.yaml" 4 | 5 | "trainer_params": 6 | "trainer_name": "KnapsackConstraintLearningTrainer" 7 | 8 | "model_params": 9 | "backbone_module_params": 10 | "hidden_layer_size": 512 11 | 12 | "solver_module_params": 13 | "solver_name": "CombOptNet" 14 | -------------------------------------------------------------------------------- /experiments/knapsack/cvxpy.yaml: -------------------------------------------------------------------------------- 1 | # settings file for knapsack with cvxpy 2 | 3 | "__import__": "experiments/knapsack/base.yaml" 4 | 5 | "trainer_params": 6 |
"trainer_name": "KnapsackConstraintLearningTrainer" 7 | 8 | "model_params": 9 | "backbone_module_params": 10 | "hidden_layer_size": 512 11 | 12 | "solver_module_params": 13 | "solver_name": "Cvxpy" 14 | "use_entropy": true 15 | -------------------------------------------------------------------------------- /experiments/knapsack/mlp.yaml: -------------------------------------------------------------------------------- 1 | # settings file for knapsack with mlp 2 | 3 | "__import__": "experiments/knapsack/base.yaml" 4 | 5 | "trainer_params": 6 | "trainer_name": "MLPTrainer" 7 | 8 | "model_params": 9 | "model_name": "KnapsackMLP" 10 | "reduced_embed_dim": 100 11 | "hidden_layer_size": 100 12 | -------------------------------------------------------------------------------- /experiments/static_constraints/base.yaml: -------------------------------------------------------------------------------- 1 | # base settings file for static constraints experiments 2 | 3 | "seed": null 4 | 5 | "working_dir": "results/static_constraints/1" 6 | "use_ray": false 7 | "ray_params": 8 | "num_cpus": 20 9 | 10 | "data_params": 11 | "base_dataset_path": ".../datasets" # Add correct dataset path here ".../datasets" 12 | "dataset_type": "static_constraints" 13 | "dataset_specification": "set_covering" # "dense_random", "binary_random" or "set_covering" 14 | "dataset_seed": 0 15 | "num_gt_variables": 16 16 | "num_gt_constraints": 8 17 | "train_dataset_size": 1600 18 | "loader_params": 19 | "batch_size": 8 20 | "shuffle": true 21 | 22 | "train_epochs": 100 23 | "eval_every": 10 24 | "trainer_params": 25 | "use_cuda": false 26 | "loss_name": "MSE" 27 | "optimizer_name": "Adam" 28 | "optimizer_params": 29 | "lr": 0.0005 -------------------------------------------------------------------------------- /experiments/static_constraints/comboptnet.yaml: -------------------------------------------------------------------------------- 1 | # settings file for random constraints with comboptnet 2 | 3 | "__import__": "experiments/static_constraints/base.yaml" 4 | 5 | "trainer_params": 6 | "trainer_name": "RandomConstraintLearningTrainer" 7 | 8 | "model_params": 9 | "constraint_module_params": 10 | "num_constraints": 2 11 | "normalize_constraints": true 12 | 13 | "solver_module_params": 14 | "solver_name": "CombOptNet" 15 | "tau": 0.5 -------------------------------------------------------------------------------- /experiments/static_constraints/cvxpy.yaml: -------------------------------------------------------------------------------- 1 | # settings file for random constraints with cvxpy 2 | 3 | "__import__": "experiments/static_constraints/base.yaml" 4 | 5 | "trainer_params": 6 | "trainer_name": "RandomConstraintLearningTrainer" 7 | 8 | "model_params": 9 | "constraint_module_params": 10 | "num_constraints": 2 11 | "normalize_constraints": true 12 | "feasible_point": "random_corner" 13 | 14 | 15 | "solver_module_params": 16 | "solver_name": "Cvxpy" 17 | "use_entropy": false -------------------------------------------------------------------------------- /experiments/static_constraints/mlp.yaml: -------------------------------------------------------------------------------- 1 | # settings file for random constraints with mlp 2 | 3 | "__import__": "experiments/static_constraints/base.yaml" 4 | 5 | "trainer_params": 6 | "trainer_name": "MLPTrainer" 7 | 8 | "model_params": 9 | "model_name": "RandomConstraintsMLP" 10 | "hidden_layer_size": 512 11 | -------------------------------------------------------------------------------- 
/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import ray 5 | 6 | from data import load_dataset 7 | from trainer import get_trainer 8 | from utils.utils import print_eval_acc, print_train_acc, load_with_default_yaml, save_dict_as_one_line_csv 9 | 10 | 11 | def main(working_dir, seed, train_epochs, eval_every, use_ray, ray_params, data_params, trainer_params): 12 | if use_ray: 13 | ray.init(**ray_params) 14 | 15 | (train_iterator, test_iterator), metadata = load_dataset(**data_params) 16 | trainer = get_trainer(seed=seed, train_iterator=train_iterator, test_iterator=test_iterator, metadata=metadata, 17 | **trainer_params) 18 | 19 | eval_metrics = trainer.evaluate() 20 | print_eval_acc(eval_metrics) 21 | 22 | for i in range(train_epochs): 23 | train_metrics = trainer.train_epoch() 24 | print_train_acc(train_metrics, epoch=i) 25 | if eval_every is not None and (i + 1) % eval_every == 0: 26 | eval_metrics = trainer.evaluate() 27 | print_eval_acc(eval_metrics) 28 | 29 | eval_metrics = trainer.evaluate() 30 | print_eval_acc(eval_metrics) 31 | 32 | if use_ray: 33 | ray.shutdown() 34 | metrics = dict(**train_metrics, **eval_metrics) 35 | save_dict_as_one_line_csv(metrics, filename=os.path.join(working_dir, "metrics.csv")) 36 | return metrics 37 | 38 | 39 | if __name__ == "__main__": 40 | param_path = sys.argv[1] 41 | param_dict = load_with_default_yaml(path=param_path) 42 | main(**param_dict) 43 | -------------------------------------------------------------------------------- /media/arch_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/CombOptNet/d563d31a95dce35a365d50b81f932c27531ae09b/media/arch_overview.png -------------------------------------------------------------------------------- /media/arch_overview_600x400.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/CombOptNet/d563d31a95dce35a365d50b81f932c27531ae09b/media/arch_overview_600x400.png -------------------------------------------------------------------------------- /models/comboptnet.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import gurobipy as gp 4 | import jax.numpy as jnp 5 | import numpy as np 6 | import torch 7 | from gurobipy import GRB, quicksum 8 | from jax import grad 9 | 10 | from utils.comboptnet_utils import compute_delta_y, check_point_feasibility, softmin, \ 11 | signed_euclidean_distance_constraint_point, tensor_to_jax 12 | from utils.utils import ParallelProcessing 13 | 14 | 15 | class CombOptNetModule(torch.nn.Module): 16 | def __init__(self, variable_range, tau=None, clip_gradients_to_box=True, use_canonical_basis=False): 17 | super().__init__() 18 | """ 19 | @param variable_range: dict(lb, ub), range of variables in the ILP 20 | @param tau: a float/np.float32/torch.float32, the value of tau for computing the constraint gradient 21 | @param clip_gradients_to_box: boolean flag, if true the gradients are projected into the feasible hypercube 22 | @param use_canonical_basis: boolean flag, if true the canonical basis is used instead of delta basis 23 | """ 24 | self.solver_params = dict(tau=tau, variable_range=variable_range, clip_gradients_to_box=clip_gradients_to_box, 25 | use_canonical_basis=use_canonical_basis, parallel_processing=ParallelProcessing()) 26 | self.solver = DifferentiableILPsolver 
27 | 28 | def forward(self, cost_vector, constraints): 29 | """ 30 | Forward pass of CombOptNet running a differentiable ILP solver 31 | @param cost_vector: torch.Tensor of shape (bs, num_variables) with batch of ILP cost vectors 32 | @param constraints: torch.Tensor of shape (bs, num_const, num_variables + 1) or (num_const, num_variables + 1) 33 | with (potentially batch of) ILP constraints 34 | @return: torch.Tensor of shape (bs, num_variables) with integer values capturing the solution of the ILP 35 | """ 36 | if len(constraints.shape) == 2: 37 | bs = cost_vector.shape[0] 38 | constraints = torch.stack(bs * [constraints]) 39 | y, infeasibility_indicator = self.solver.apply(cost_vector, constraints, self.solver_params) 40 | return y 41 | 42 | 43 | class DifferentiableILPsolver(torch.autograd.Function): 44 | """ 45 | Differentiable ILP solver as a torch.Function 46 | """ 47 | 48 | @staticmethod 49 | def forward(ctx, cost_vector, constraints, params): 50 | """ 51 | Implementation of the forward pass of a batched (potentially parallelized) ILP solver. 52 | @param ctx: context for backpropagation 53 | @param cost_vector: torch.Tensor of shape (bs, num_variables) with batch of ILp cost vectors 54 | @param constraints: torch.Tensor of shape (bs, num_const, num_variables + 1) with batch of ILP constraints 55 | @param params: a dict of additional params. Must contain: 56 | tau: a float/np.float32/torch.float32, the value of tau for computing the constraint gradient 57 | clip_gradients_to_box: boolean flag, if true the gradients are projected into the feasible hypercube 58 | @return: torch.Tensor of shape (bs, num_variables) with integer values capturing the solution of the ILP, 59 | torch.Tensor of shape (bs) with 0/1 values, where 1 corresponds to an infeasible ILP instance 60 | """ 61 | device = constraints.device 62 | maybe_parallelize = params['parallel_processing'].maybe_parallelize 63 | 64 | dynamic_args = [{"cost_vector": cost_vector, "constraints": const} for cost_vector, const in 65 | zip(cost_vector.cpu().detach().numpy(), constraints.cpu().detach().numpy())] 66 | 67 | result = maybe_parallelize(ilp_solver, params['variable_range'], dynamic_args) 68 | y, infeasibility_indicator = [torch.from_numpy(np.array(res)).to(device) for res in zip(*result)] 69 | 70 | ctx.params = params 71 | ctx.save_for_backward(cost_vector, constraints, y, infeasibility_indicator) 72 | return y, infeasibility_indicator 73 | 74 | @staticmethod 75 | def backward(ctx, y_grad, _): 76 | """ 77 | Backward pass computation. 
78 | @param ctx: context from the forward pass 79 | @param y_grad: torch.Tensor of shape (bs, num_variables) describing the incoming gradient dy for y 80 | @return: torch.Tensor of shape (bs, num_variables) gradient dL / cost_vector 81 | torch.Tensor of shape (bs, num_constraints, num_variables + 1) gradient dL / constraints 82 | """ 83 | cost_vector, constraints, y, infeasibility_indicator = ctx.saved_tensors 84 | assert y.shape == y_grad.shape 85 | 86 | grad_mismatch_function = grad(mismatch_function, argnums=[0, 1]) 87 | grad_cost_vector, grad_constraints = grad_mismatch_function(tensor_to_jax(cost_vector), 88 | tensor_to_jax(constraints), 89 | tensor_to_jax(y), 90 | tensor_to_jax(y_grad), 91 | tensor_to_jax(infeasibility_indicator), 92 | variable_range=ctx.params['variable_range'], 93 | clip_gradients_to_box=ctx.params[ 94 | 'clip_gradients_to_box'], 95 | use_canonical_basis=ctx.params[ 96 | 'use_canonical_basis'], 97 | tau=ctx.params['tau']) 98 | 99 | cost_vector_grad = torch.from_numpy(np.array(grad_cost_vector)).to(y_grad.device) 100 | constraints_gradient = torch.from_numpy(np.array(grad_constraints)).to(y_grad.device) 101 | return cost_vector_grad, constraints_gradient, None 102 | 103 | 104 | def ilp_solver(cost_vector, constraints, lb, ub): 105 | """ 106 | ILP solver using Gurobi. Computes the solution of a single integer linear program 107 | y* = argmin_y (c * y) subject to A @ y + b <= 0, y integer, lb <= y <= ub 108 | 109 | @param cost_vector: np.array of shape (num_variables) with cost vector of the ILP 110 | @param constraints: np.array of shape (num_const, num_variables + 1) with constraints of the ILP 111 | @param lb: float, lower bound of variables 112 | @param ub: float, upper bound of variables 113 | @return: np.array of shape (num_variables) with integer values capturing the solution of the ILP, 114 | boolean flag, where true corresponds to an infeasible ILP instance 115 | """ 116 | A, b = constraints[:, :-1], constraints[:, -1] 117 | num_constraints, num_variables = A.shape 118 | 119 | model = gp.Model("mip1") 120 | model.setParam('OutputFlag', 0) 121 | model.setParam("Threads", 1) 122 | 123 | variables = [model.addVar(lb=lb, ub=ub, vtype=GRB.INTEGER, name='v' + str(i)) for i in range(num_variables)] 124 | model.setObjective(quicksum(c * var for c, var in zip(cost_vector, variables)), GRB.MINIMIZE) 125 | for a, _b in zip(A, b): 126 | model.addConstr(quicksum(c * var for c, var in zip(a, variables)) + _b <= 0) 127 | model.optimize() 128 | try: 129 | y = np.array([v.x for v in model.getVars()]) 130 | infeasible = False 131 | except AttributeError: 132 | warnings.warn(f'Infeasible ILP encountered. 
Dummy solution should be handled as special case.') 133 | y = np.zeros_like(cost_vector) 134 | infeasible = True 135 | return y, infeasible 136 | 137 | 138 | def mismatch_function(cost_vector, constraints, y, y_grad, infeasibility_indicator, variable_range, tau, 139 | clip_gradients_to_box, use_canonical_basis, average_solution=True, use_cost_mismatch=True): 140 | """ 141 | Computes the combined mismatch function for cost vectors and constraints P_(dy)(A, b, c) = P_(dy)(A, b) + P_(dy)(c) 142 | P_(dy)(A, b) = sum_k(lambda_k * P_(delta_k)(A, b)), 143 | P_(dy)(c) = sum_k(lambda_k * P_(delta_k)(c)), 144 | where delta_k = y'_k - y 145 | 146 | @param cost_vector: jnp.array of shape (bs, num_variables) with batch of ILP cost vectors 147 | @param constraints: jnp.array of shape (bs, num_const, num_variables + 1) with batch of ILP constraints 148 | @param y: jnp.array of shape (bs, num_variables) with batch of ILP solutions 149 | @param y_grad: jnp.array of shape (bs, num_variables) with batch of incoming gradients for y 150 | @param infeasibility_indicator: jnp.array of shape (bs) with batch of indicators whether ILP has feasible solution 151 | @param variable_range: dict(lb, ub), range of variables in the ILP 152 | @param tau: a float/np.float32/torch.float32, the value of tau for computing the constraint gradient 153 | @param clip_gradients_to_box: boolean flag, if true the gradients are projected into the feasible hypercube 154 | @param use_canonical_basis: boolean flag, if true the canonical basis is used instead of delta basis 155 | 156 | @return: jnp.array scalar with value of mismatch function 157 | """ 158 | num_constraints = constraints.shape[1] 159 | if num_constraints > 1 and tau is None: 160 | raise ValueError('If more than one constraint is used the parameter tau needs to be specified.') 161 | if num_constraints == 1 and tau is not None: 162 | warnings.warn('The specified parameter tau has no influence as only a single constraint is used.') 163 | 164 | delta_y, lambdas = compute_delta_y(y=y, y_grad=y_grad, clip_gradients_to_box=clip_gradients_to_box, 165 | use_canonical_basis=use_canonical_basis, **variable_range) 166 | y = y[:, None, :] 167 | y_prime = y + delta_y 168 | cost_vector = cost_vector[:, None, :] 169 | constraints = constraints[:, None, :, :] 170 | 171 | y_prime_feasible_constraints, y_prime_inside_box = check_point_feasibility(point=y_prime, constraints=constraints, 172 | **variable_range) 173 | feasibility_indicator = 1.0 - infeasibility_indicator 174 | correct_solution_indicator = jnp.all(jnp.isclose(y, y_prime), axis=-1) 175 | # solution can only be correct if we also have a feasible problem 176 | # (fixes case in which infeasible solution dummy matches ground truth) 177 | correct_solution_indicator *= feasibility_indicator[:, None] 178 | incorrect_solution_indicator = 1.0 - correct_solution_indicator 179 | 180 | constraints_mismatch = compute_constraints_mismatch(constraints=constraints, y=y, y_prime=y_prime, 181 | y_prime_feasible_constraints=y_prime_feasible_constraints, 182 | y_prime_inside_box=y_prime_inside_box, tau=tau, 183 | incorrect_solution_indicator=incorrect_solution_indicator) 184 | cost_mismatch = compute_cost_mismatch(cost_vector=cost_vector, y=y, y_prime=y_prime, 185 | y_prime_feasible_constraints=y_prime_feasible_constraints, 186 | y_prime_inside_box=y_prime_inside_box) 187 | total_mismatch = constraints_mismatch 188 | if use_cost_mismatch: 189 | total_mismatch += cost_mismatch 190 | 191 | total_mismatch = jnp.mean(total_mismatch * lambdas, axis=-1) # 
scale mismatch functions of sparse y' with lambda 192 | if average_solution: 193 | total_mismatch = jnp.mean(total_mismatch) 194 | return total_mismatch 195 | 196 | 197 | def compute_cost_mismatch(cost_vector, y, y_prime, y_prime_feasible_constraints, y_prime_inside_box): 198 | """ 199 | Computes the mismatch function for cost vectors P_(delta_k)(c), where delta_k = y'_k - y 200 | """ 201 | c_diff = jnp.sum(cost_vector * y_prime, axis=-1) - jnp.sum(cost_vector * y, axis=-1) 202 | cost_mismatch = jnp.maximum(c_diff, 0.0) 203 | 204 | # case distinction in paper: if y' is (constraint-)infeasible or outside of hypercube cost mismatch function is zero 205 | cost_mismatch = cost_mismatch * y_prime_inside_box * y_prime_feasible_constraints 206 | return cost_mismatch 207 | 208 | 209 | def compute_constraints_mismatch(constraints, y, y_prime, y_prime_inside_box, y_prime_feasible_constraints, 210 | incorrect_solution_indicator, tau): 211 | """ 212 | Computes the mismatch function for constraints P_(delta_k)(A, b), where delta_k = y'_k - y 213 | """ 214 | # case 1 in paper: if y' is (constraint-)feasible, y' is inside the hypercube and y != y' 215 | constraints_mismatch_feasible = compute_constraints_mismatch_feasible(constraints=constraints, y=y, tau=tau) 216 | constraints_mismatch_feasible *= y_prime_feasible_constraints 217 | 218 | # case 2 in paper: if y' is (constraint-)infeasible, y' is inside the hypercube and y != y' 219 | constraints_mismatch_infeasible = compute_constraints_mismatch_infeasible(constraints=constraints, y_prime=y_prime) 220 | constraints_mismatch_infeasible *= (1.0 - y_prime_feasible_constraints) 221 | 222 | constraints_mismatch = constraints_mismatch_feasible + constraints_mismatch_infeasible 223 | 224 | # case 3 in paper: if y prime is outside the hypercube or y = y' constraint mismatch function is zero 225 | constraints_mismatch = constraints_mismatch * y_prime_inside_box * incorrect_solution_indicator 226 | return constraints_mismatch 227 | 228 | 229 | def compute_constraints_mismatch_feasible(constraints, y, tau): 230 | distance_y_const = signed_euclidean_distance_constraint_point(constraints=constraints, point=y) 231 | constraints_mismatch_feasible = jnp.maximum(-distance_y_const, 0.0) 232 | constraints_mismatch_feasible = softmin(constraints_mismatch_feasible, tau=tau, axis=-1) 233 | return constraints_mismatch_feasible 234 | 235 | 236 | def compute_constraints_mismatch_infeasible(constraints, y_prime): 237 | distance_y_prime_const = signed_euclidean_distance_constraint_point(constraints=constraints, point=y_prime) 238 | constraints_mismatch_infeasible = jnp.maximum(distance_y_prime_const, 0.0) 239 | constraints_mismatch_infeasible = jnp.sum(constraints_mismatch_infeasible, axis=-1) 240 | return constraints_mismatch_infeasible 241 | -------------------------------------------------------------------------------- /models/graph_matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models.modules import StaticConstraintModule, CombOptNetModule 4 | 5 | 6 | class StaticConstraintLearnerModule(torch.nn.Module): 7 | """ 8 | Combined module for the graph matching demonstration, containing the constraint and solver module. 9 | Use with repo https://github.com/martius-lab/blackbox-deep-graph-matching. 
10 | """ 11 | 12 | def __init__(self, keypoints_per_image): 13 | """ 14 | @param keypoints_per_image: int, fixed number of keypoints in src and tgt image 15 | """ 16 | super().__init__() 17 | 18 | variable_range = dict(lb=0.0, ub=1.0) 19 | num_variables = keypoints_per_image * keypoints_per_image 20 | num_learned_constraints = 2 * keypoints_per_image 21 | self.static_constraint_module = StaticConstraintModule(variable_range=variable_range, 22 | num_variables=num_variables, 23 | num_constraints=num_learned_constraints, 24 | normalize_constraints=True) 25 | self.solver_module = CombOptNetModule(variable_range=variable_range) 26 | 27 | def forward(self, unary_costs_list): 28 | """ 29 | @param unary_costs_list: list (length bs) of torch.tensors of shape [keypoints_per_image, keypoints_per_image] 30 | 31 | @return: torch.tensor of shape [bs, keypoints_per_image, keypoints_per_image] with 0/1 assignment 32 | """ 33 | constraints = self.static_constraint_module() 34 | 35 | cost = torch.stack(unary_costs_list, dim=0) 36 | old_shape = cost.shape 37 | cost_vector = cost.reshape(old_shape[0], -1) 38 | 39 | y_denorm_flat_batch = self.solver_module(cost_vector=cost_vector, constraints=constraints) 40 | y_denorm_batch = y_denorm_flat_batch.reshape(*old_shape) 41 | return y_denorm_batch 42 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MLP(torch.nn.Module): 6 | def __init__(self, out_features, in_features, hidden_layer_size, output_nonlinearity): 7 | super().__init__() 8 | self.fc1 = nn.Linear(in_features=in_features, out_features=hidden_layer_size) 9 | self.fc2 = nn.Linear(in_features=hidden_layer_size, out_features=out_features) 10 | self.output_nonlinearity_fn = nonlinearity_dict[output_nonlinearity] 11 | 12 | def forward(self, x): 13 | x = torch.relu(self.fc1(x.float())) 14 | x = self.fc2(x) 15 | x = self.output_nonlinearity_fn(x) 16 | return x 17 | 18 | 19 | class KnapsackMLP(MLP): 20 | """ 21 | Predicts normalized solution y (range [-0.5, 0.5]) 22 | """ 23 | 24 | def __init__(self, num_variables, reduced_embed_dim, embed_dim=4096, **kwargs): 25 | super().__init__(in_features=num_variables * reduced_embed_dim, out_features=num_variables, 26 | output_nonlinearity='sigmoid', **kwargs) 27 | self.reduce_embedding_layer = nn.Linear(in_features=embed_dim, out_features=reduced_embed_dim) 28 | 29 | def forward(self, x): 30 | bs = x.shape[0] 31 | x = self.reduce_embedding_layer(x.float()) 32 | x = x.reshape(shape=(bs, -1)) 33 | x = super().forward(x) 34 | y_norm = x - 0.5 35 | return y_norm 36 | 37 | 38 | class RandomConstraintsMLP(MLP): 39 | """ 40 | Predicts normalized solution y (range [-0.5, 0.5]) 41 | """ 42 | 43 | def __init__(self, num_variables, **kwargs): 44 | super().__init__(in_features=num_variables, out_features=num_variables, 45 | output_nonlinearity='sigmoid', **kwargs) 46 | 47 | def forward(self, x): 48 | x = super().forward(x) 49 | y_norm = x - 0.5 50 | return y_norm 51 | 52 | 53 | class KnapsackExtractWeightsCostFromEmbeddingMLP(MLP): 54 | """ 55 | Extracts weights and prices of vector-embedding of Knapsack instance 56 | 57 | @return: torch.Tensor of shape (bs, num_variables) with negative extracted prices, 58 | torch.Tensor of shape (bs, num_constraints, num_variables + 1) with extracted weights and negative knapsack capacity 59 | """ 60 | 61 | def __init__(self, num_constraints=1, 
embed_dim=4096, knapsack_capacity=1.0, weight_min=0.15, weight_max=0.35, 62 | cost_min=0.10, cost_max=0.45, output_nonlinearity='sigmoid', **kwargs): 63 | self.num_constraints = num_constraints 64 | 65 | self.knapsack_capacity = knapsack_capacity 66 | self.weight_min = weight_min 67 | self.weight_range = weight_max - weight_min 68 | self.cost_min = cost_min 69 | self.cost_range = cost_max - cost_min 70 | 71 | super().__init__(in_features=embed_dim, out_features=num_constraints + 1, 72 | output_nonlinearity=output_nonlinearity, **kwargs) 73 | 74 | def forward(self, x): 75 | x = super().forward(x) 76 | batch_size = x.shape[0] 77 | cost, As = x.split([1, self.num_constraints], dim=-1) 78 | cost = -(self.cost_min + self.cost_range * cost[..., 0]) 79 | As = As.transpose(1, 2) 80 | As = self.weight_min + self.weight_range * As 81 | bs = -torch.ones(batch_size, self.num_constraints).to(As.device) * self.knapsack_capacity 82 | constraints = torch.cat([As, bs[..., None]], dim=-1) 83 | return cost, constraints 84 | 85 | 86 | nonlinearity_dict = dict(tanh=torch.tanh, relu=torch.relu, sigmoid=torch.sigmoid, identity=lambda x: x) 87 | baseline_mlp_dict = dict(RandomConstraintsMLP=RandomConstraintsMLP, KnapsackMLP=KnapsackMLP) 88 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import cvxpy as cp 2 | import torch 3 | from cvxpylayers.torch import CvxpyLayer 4 | from diffcp import SolverError 5 | 6 | from models.comboptnet import CombOptNetModule 7 | from utils.constraint_generation import sample_offset_constraints_numpy, compute_constraints_in_base_coordinate_system, \ 8 | compute_normalized_constraints 9 | from utils.utils import torch_parameter_from_numpy 10 | 11 | 12 | class StaticConstraintModule(torch.nn.Module): 13 | """ 14 | Module wrapping parameters of a learnable constraint set. 
15 | """ 16 | 17 | def __init__(self, num_constraints, num_variables, variable_range, normalize_constraints, learn_offsets=True, 18 | learn_bs=True, **constraint_sample_params): 19 | """ 20 | Initializes the learnable constraint ste 21 | @param num_constraints: int, cardinality of the learned constraint set 22 | @param num_variables: int, number of variables in the constraints 23 | @param variable_range: dict(lb, ub), range of variables in the ILP 24 | @param normalize_constraints: boolean flag, if true all constraints are normalized to unit norm 25 | @param learn_offsets: boolean flag, if false the initial origin offsets are not further learned 26 | @param learn_bs: boolean flag, if false the initial constraint bias terms are not further learned 27 | @param constraint_sample_params: dict, additional parameters for the sampling of initial constraint set 28 | """ 29 | super().__init__() 30 | constraints_in_offset_system, offsets = sample_offset_constraints_numpy(variable_range=variable_range, 31 | num_variables=num_variables, 32 | num_constraints=num_constraints, 33 | request_offset_const=True, 34 | **constraint_sample_params) 35 | 36 | offsets = torch_parameter_from_numpy(offsets) 37 | self.A = torch_parameter_from_numpy(constraints_in_offset_system[..., :-1]) 38 | b = torch_parameter_from_numpy(constraints_in_offset_system[..., -1]) 39 | self.b = b if learn_bs else b.detach() 40 | self.offsets = offsets if learn_offsets else offsets.detach() 41 | 42 | self.normalize_constraints = normalize_constraints 43 | 44 | def forward(self): 45 | """ 46 | @return: current set of learned constraints, with representation Ab such that A @ y + b <= 0 47 | """ 48 | constraints_in_offset_system = torch.cat([self.A, self.b[..., None]], dim=-1) 49 | constraints = compute_constraints_in_base_coordinate_system( 50 | constraints_in_offset_system=constraints_in_offset_system, offsets=self.offsets) 51 | 52 | if self.normalize_constraints: 53 | constraints = compute_normalized_constraints(constraints) 54 | return constraints 55 | 56 | 57 | class CvxpyModule(torch.nn.Module): 58 | def __init__(self, variable_range, use_entropy): 59 | """ 60 | @param variable_range: dict(lb, ub), range of variables in the LP 61 | """ 62 | super().__init__() 63 | self.variable_range = variable_range 64 | self.ilp_solver = CombOptNetModule(variable_range=self.variable_range) 65 | self.solver = None 66 | self.use_entropy = use_entropy 67 | 68 | def init_solver(self, num_variables, num_constraints, lb, ub): 69 | _p = cp.Parameter(num_variables) 70 | _A = cp.Parameter([num_constraints, num_variables]) 71 | _b = cp.Parameter(num_constraints) 72 | _x = cp.Variable(num_variables) 73 | if self.use_entropy: 74 | cons = [_A @ _x + _b <= 0] 75 | obj = cp.Minimize(_p @ _x - sum(cp.entr(_x - lb) + cp.entr(ub - _x))) 76 | else: 77 | cons = [_A @ _x + _b <= 0, lb <= _x, _x <= ub] 78 | obj = cp.Minimize(_p @ _x) 79 | prob = cp.Problem(obj, cons) 80 | solver = CvxpyLayer(prob, parameters=[_A, _b, _p], variables=[_x]) 81 | return solver 82 | 83 | def forward(self, cost_vector, constraints): 84 | """ 85 | Forward pass of the CVXPY module running a differentiable LP solver 86 | @param cost_vector: torch.Tensor of shape (bs, num_variables) with batch of ILP cost vectors 87 | @param constraints: torch.Tensor of shape (bs, num_const, num_variables + 1) or (num_const, num_variables + 1) 88 | with (potentially batch of) ILP constraints 89 | @return: torch.Tensor of shape (bs, num_variables) with integer values capturing the solution of the LP 90 | """ 91 | A 
= constraints[..., :-1] 92 | b = constraints[..., -1] 93 | if self.solver is None: 94 | num_constraints, num_variables = A.shape[-2:] 95 | 96 | self.solver = self.init_solver(num_variables=num_variables, num_constraints=num_constraints, 97 | **self.variable_range) 98 | 99 | try: 100 | y, = self.solver(A, b, cost_vector) 101 | except SolverError as e: 102 | print(f'Dummy zero solution should be handled as special case.') 103 | y = torch.zeros_like(cost_vector).to(cost_vector.device) 104 | return y 105 | 106 | 107 | def get_solver_module(**params): 108 | solver_name = params.pop('solver_name') 109 | return solver_module_dict[solver_name](**params) 110 | 111 | 112 | solver_module_dict = dict(CombOptNet=CombOptNetModule, Cvxpy=CvxpyModule) 113 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | from models.models import KnapsackExtractWeightsCostFromEmbeddingMLP, baseline_mlp_dict 6 | from models.modules import get_solver_module, StaticConstraintModule, CvxpyModule, CombOptNetModule 7 | from utils.utils import loss_from_string, optimizer_from_string, set_seed, AvgMeters, compute_metrics, \ 8 | knapsack_round, compute_normalized_solution, compute_denormalized_solution, solve_unconstrained 9 | 10 | 11 | def get_trainer(trainer_name, **trainer_params): 12 | trainer_dict = dict(MLPTrainer=MLPBaselineTrainer, 13 | KnapsackConstraintLearningTrainer=KnapsackConstraintLearningTrainer, 14 | RandomConstraintLearningTrainer=RandomConstraintLearningTrainer) 15 | return trainer_dict[trainer_name](**trainer_params) 16 | 17 | 18 | class BaseTrainer(ABC): 19 | def __init__(self, train_iterator, test_iterator, use_cuda, optimizer_name, loss_name, optimizer_params, metadata, 20 | model_params, seed): 21 | set_seed(seed) 22 | self.use_cuda = use_cuda 23 | self.device = 'cuda' if self.use_cuda else 'cpu' 24 | 25 | self.train_iterator = train_iterator 26 | self.test_iterator = test_iterator 27 | 28 | self.true_variable_range = metadata['variable_range'] 29 | self.num_variables = metadata['num_variables'] 30 | self.variable_range = self.true_variable_range 31 | 32 | model_parameters = self.build_model(**model_params) 33 | self.optimizer = optimizer_from_string(optimizer_name)(model_parameters, **optimizer_params) 34 | self.loss_fn = loss_from_string(loss_name) 35 | 36 | @abstractmethod 37 | def build_model(self, **model_params): 38 | pass 39 | 40 | @abstractmethod 41 | def calculate_loss_metrics(self, **data_params): 42 | pass 43 | 44 | def train_epoch(self): 45 | self.train = True 46 | metrics = AvgMeters() 47 | 48 | for i, data in enumerate(self.train_iterator): 49 | x, y_true_norm = [dat.to(self.device) for dat in data] 50 | loss, metric_dct = self.calculate_loss_metrics(x=x, y_true_norm=y_true_norm) 51 | metrics.update(metric_dct, n=x.size(0)) 52 | 53 | self.optimizer.zero_grad() 54 | loss.backward(retain_graph=True) 55 | self.optimizer.step() 56 | 57 | results = metrics.get_averages(prefix='train_') 58 | return results 59 | 60 | def evaluate(self): 61 | self.train = False 62 | metrics = AvgMeters() 63 | 64 | for i, data in enumerate(self.test_iterator): 65 | x, y_true_norm = [dat.to(self.device) for dat in data] 66 | loss, metric_dct = self.calculate_loss_metrics(x=x, y_true_norm=y_true_norm) 67 | metrics.update(metric_dct, n=x.size(0)) 68 | 69 | results = metrics.get_averages(prefix='eval_') 70 | return results 71 | 72 
| 73 | class MLPBaselineTrainer(BaseTrainer): 74 | def build_model(self, model_name, **model_params): 75 | self.model = baseline_mlp_dict[model_name](num_variables=self.num_variables, **model_params).to( 76 | self.device) 77 | return self.model.parameters() 78 | 79 | def calculate_loss_metrics(self, x, y_true_norm): 80 | y_norm = self.model(x=x) 81 | loss = self.loss_fn(y_norm.double(), y_true_norm) 82 | 83 | metrics = dict(loss=loss.item()) 84 | y_denorm = compute_denormalized_solution(y_norm, **self.variable_range) 85 | y_denorm_rounded = torch.round(y_denorm) 86 | y_true_denorm = compute_denormalized_solution(y_true_norm, **self.true_variable_range) 87 | metrics.update(compute_metrics(y=y_denorm_rounded, y_true=y_true_denorm)) 88 | return loss, metrics 89 | 90 | 91 | class ConstraintLearningTrainerBase(BaseTrainer, ABC): 92 | @abstractmethod 93 | def forward(self, x): 94 | pass 95 | 96 | def calculate_loss_metrics(self, x, y_true_norm): 97 | y_denorm, y_denorm_roudned, solutions_denorm_dict, cost_vector = self.forward(x) 98 | y_norm = compute_normalized_solution(y_denorm, **self.variable_range) 99 | loss = self.loss_fn(y_norm.double(), y_true_norm) 100 | 101 | metrics = dict(loss=loss.item()) 102 | y_uncon_denorm = solve_unconstrained(cost_vector=cost_vector, **self.variable_range) 103 | y_true_denorm = compute_denormalized_solution(y_true_norm, **self.true_variable_range) 104 | metrics.update(compute_metrics(y=y_denorm_roudned, y_true=y_true_denorm, y_uncon=y_uncon_denorm)) 105 | for prefix, solution in solutions_denorm_dict.items(): 106 | metrics.update( 107 | compute_metrics(y=solution, y_true=y_true_denorm, y_uncon=y_uncon_denorm, prefix=prefix + "_")) 108 | return loss, metrics 109 | 110 | 111 | class RandomConstraintLearningTrainer(ConstraintLearningTrainerBase): 112 | def build_model(self, constraint_module_params, solver_module_params): 113 | self.static_constraint_module = StaticConstraintModule(variable_range=self.variable_range, 114 | num_variables=self.num_variables, 115 | **constraint_module_params).to(self.device) 116 | self.solver_module = get_solver_module(variable_range=self.variable_range, 117 | **solver_module_params).to(self.device) 118 | self.ilp_solver_module = CombOptNetModule(variable_range=self.variable_range).to(self.device) 119 | model_parameters = list(self.static_constraint_module.parameters()) + list(self.solver_module.parameters()) 120 | return model_parameters 121 | 122 | def forward(self, x): 123 | cost_vector = x 124 | cost_vector = cost_vector / torch.norm(cost_vector, p=2, dim=-1, keepdim=True) 125 | constraints = self.static_constraint_module() 126 | 127 | y_denorm = self.solver_module(cost_vector=cost_vector, constraints=constraints) 128 | y_denorm_rounded = torch.round(y_denorm) 129 | solutions_dict = {} 130 | 131 | if not self.train and isinstance(self.solver_module, CvxpyModule): 132 | y_denorm_ilp = self.ilp_solver_module(cost_vector=cost_vector, constraints=constraints) 133 | update_dict = dict(ilp_postprocess=y_denorm_ilp) 134 | solutions_dict.update(update_dict) 135 | 136 | return y_denorm, y_denorm_rounded, solutions_dict, cost_vector 137 | 138 | 139 | class KnapsackConstraintLearningTrainer(ConstraintLearningTrainerBase): 140 | def build_model(self, solver_module_params, backbone_module_params): 141 | self.backbone_module = KnapsackExtractWeightsCostFromEmbeddingMLP(**backbone_module_params).to(self.device) 142 | self.solver_module = get_solver_module(variable_range=self.variable_range, 143 | **solver_module_params).to(self.device) 144 | 
model_parameters = list(self.backbone_module.parameters()) + list(self.solver_module.parameters()) 145 | return model_parameters 146 | 147 | def forward(self, x): 148 | cost_vector, constraints = self.backbone_module(x) 149 | cost_vector = cost_vector / torch.norm(cost_vector, p=2, dim=-1, keepdim=True) 150 | 151 | y_denorm = self.solver_module(cost_vector=cost_vector, constraints=constraints) 152 | if isinstance(self.solver_module, CvxpyModule): 153 | y_denorm_rounded = knapsack_round(y_denorm=y_denorm, constraints=constraints, 154 | knapsack_capacity=self.backbone_module.knapsack_capacity) 155 | else: 156 | y_denorm_rounded = y_denorm 157 | return y_denorm, y_denorm_rounded, {}, cost_vector 158 | -------------------------------------------------------------------------------- /utils/comboptnet_utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | epsilon_constant = 1e-8 4 | 5 | 6 | def logsumexp(array, tau, axis): 7 | exp = jnp.exp(array / tau) 8 | sumexp = jnp.sum(exp, axis=axis) 9 | offset = jnp.log(array.shape[axis]) 10 | logsumexp = (jnp.log(sumexp) - offset) * tau 11 | return logsumexp 12 | 13 | 14 | def softmin(array, axis, tau): 15 | if tau is None: 16 | tau = 0.0 17 | assert tau >= 0.0 18 | if jnp.isclose(tau, 0.0): 19 | return jnp.min(array, axis=axis) 20 | else: 21 | return -logsumexp(-array, tau=tau, axis=axis) 22 | 23 | 24 | def compute_delta_y(y, y_grad, lb, ub, clip_gradients_to_box, use_canonical_basis): 25 | """ 26 | Computes integer basis delta_k from dy such that dy = sum_k(lambda_k * delta_k) 27 | with delta_k the signed indicator vector of the first k dominant directions of dy 28 | 29 | @param y: jnp.array of shape (bs, num_variables) with batch of ILP solutions 30 | @param y_grad: jnp.array of shape (bs, num_variables) with batch of incoming gradients dy for y 31 | @param lb: float, lower bound of variables in ILP 32 | @param ub: float, upper bound of variables in ILP 33 | @param clip_gradients_to_box: boolean flag, if true the gradients are projected into the feasible hypercube 34 | @param use_canonical_basis: boolean flag, if true canonical basis delta_k = e_k is used 35 | 36 | @return: jnp.array of shape (bs, num_variables, num_variables) with basis delta_k, 37 | jnp.array of shape (bs, num_variables) with (positive) weightings lambda_k of the delta_k 38 | """ 39 | bs, dim = y.shape 40 | y_grad = -y_grad 41 | 42 | if clip_gradients_to_box: 43 | y_test = y + jnp.sign(y_grad) 44 | inside_box_indicator = jnp.greater_equal(y_test, lb) * jnp.less_equal(y_test, ub) 45 | y_grad = y_grad * inside_box_indicator 46 | 47 | if use_canonical_basis: 48 | delta_y = jnp.sign(y_grad)[:, None, :] * jnp.eye(dim)[None, :, :] 49 | lambdas = jnp.abs(y_grad) 50 | else: 51 | sort_indices = jnp.argsort(jnp.abs(y_grad)) 52 | ranks_of_abs_y_grad = jnp.argsort(sort_indices) 53 | 54 | triangular_matrix = jnp.triu(jnp.ones((dim, dim))) 55 | permuted_triangular_matrix = jnp.take_along_axis(triangular_matrix[None, :, :], 56 | ranks_of_abs_y_grad[:, None, :], 57 | axis=-1) 58 | 59 | delta_y = permuted_triangular_matrix * jnp.sign(y_grad + epsilon_constant)[:, None, :] 60 | 61 | abs_grad_sorted = jnp.take_along_axis(jnp.abs(y_grad), sort_indices, axis=-1) 62 | sorted_grad_zero_append = jnp.concatenate((jnp.zeros((bs, 1)), abs_grad_sorted), axis=-1) 63 | lambdas = sorted_grad_zero_append[:, 1:] - sorted_grad_zero_append[:, :-1] 64 | return delta_y, lambdas 65 | 66 | 67 | def 
signed_euclidean_distance_constraint_point(constraints, point): 68 | constraint_lhs = jnp.sum(constraints[..., :-1] * point[..., None, :], axis=-1) + constraints[..., -1] 69 | distance_point_const = constraint_lhs / (jnp.linalg.norm(constraints[..., :-1], axis=-1) + epsilon_constant) 70 | return distance_point_const 71 | 72 | 73 | def check_point_feasibility(point, constraints, lb, ub): 74 | distance_point_const = signed_euclidean_distance_constraint_point(constraints=constraints, point=point) 75 | feasibility_indicator_constraints = jnp.all(jnp.less_equal(distance_point_const, 0.0), axis=-1) 76 | feasibility_indicator_lb = jnp.all(jnp.greater_equal(point, lb), axis=-1) 77 | feasibility_indicator_ub = jnp.all(jnp.less_equal(point, ub), axis=-1) 78 | y_prime_inside_box = feasibility_indicator_lb * feasibility_indicator_ub 79 | return feasibility_indicator_constraints, y_prime_inside_box 80 | 81 | 82 | def tensor_to_jax(tensor): 83 | return jnp.array(tensor.cpu().detach().numpy()) 84 | -------------------------------------------------------------------------------- /utils/constraint_generation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from utils.utils import general_sum, general_cat, general_l2norm, epsilon_constant, powerset, \ 4 | sample_at_least_one_of_each 5 | 6 | 7 | def sample_constraints(constraint_type, **params): 8 | if constraint_type == 'random_const': 9 | constraints = sample_offset_constraints_numpy(request_offset_const=False, **params) 10 | elif constraint_type == 'set_covering': 11 | constraints = get_set_covering_constraints_numpy(**params) 12 | else: 13 | raise NotImplementedError 14 | return constraints 15 | 16 | 17 | def get_set_covering_constraints_numpy(num_variables, num_constraints, variable_range, max_subset_size, seed): 18 | np.random.seed(seed) 19 | assert variable_range == dict(lb=0.0, ub=1.0) 20 | assert max_subset_size <= num_constraints 21 | 22 | universe = [i for i in range(num_constraints)] 23 | subsets = list(powerset(universe)) # remove empty and full set from powerset 24 | rel_subsets = list(filter(lambda subset: len(subset) <= max_subset_size, subsets[1:])) 25 | print(f'Num relevant subsets: {len(rel_subsets)}') 26 | assert num_variables <= len(rel_subsets) 27 | 28 | chosen_subsets = sample_at_least_one_of_each(subsets=rel_subsets, num_classes=num_constraints, 29 | num_samples=num_variables) 30 | 31 | A = np.zeros((num_constraints, num_variables)) 32 | for subset_idx, subset in enumerate(chosen_subsets): 33 | for elem in subset: 34 | A[elem, subset_idx] = -1 35 | b = np.ones(num_constraints) 36 | constraints = np.concatenate([A, b[:, None]], axis=1) 37 | return constraints 38 | 39 | 40 | def sample_offset_constraints_numpy(num_variables, num_constraints, variable_range, offset_sample_point="random_unif", 41 | request_offset_const=False, feasible_point=None, seed=None): 42 | """ 43 | Sample constraints represented as A and b, but with individual offset origins 44 | """ 45 | np.random.seed(seed) 46 | 47 | constraints_in_offset_system = sample_Ab_constraints_numpy(num_variables=num_variables, 48 | num_constraints=num_constraints) 49 | offsets = get_hypercube_point_numpy(shape=(num_constraints, num_variables), 50 | point=offset_sample_point, **variable_range) 51 | 52 | constraints_in_base_system = compute_constraints_in_base_coordinate_system( 53 | constraints_in_offset_system=constraints_in_offset_system, offsets=offsets) 54 | 55 | if feasible_point is not None: 56 | 
feasibility_indicator = check_feasibility_numpy(constraints=constraints_in_base_system, 57 | feasible_point=feasible_point, 58 | variable_range=variable_range) 59 | constraints_in_offset_system = constraints_in_offset_system * feasibility_indicator[:, None] 60 | 61 | if request_offset_const: 62 | return constraints_in_offset_system, offsets 63 | else: 64 | constraints = compute_constraints_in_base_coordinate_system( 65 | constraints_in_offset_system=constraints_in_offset_system, offsets=offsets) 66 | return constraints 67 | 68 | 69 | def sample_Ab_constraints_numpy(num_variables, num_constraints, b_value=0.2): 70 | b = np.ones(shape=[num_constraints]) * b_value 71 | A = np.random.rand(num_constraints, num_variables) - 0.5 72 | A_normalized = A / (np.linalg.norm(A, axis=1, keepdims=True) + epsilon_constant) 73 | constraints = general_cat([A_normalized, b[..., None]], axis=-1) 74 | return constraints 75 | 76 | 77 | def get_hypercube_point_numpy(point, shape, lb, ub, scale_sample_range=0.5): 78 | if point == 'center_hypercube': 79 | mean = (lb + ub) / 2 80 | return np.ones(shape) * mean 81 | elif point == 'origin': 82 | return np.zeros(shape) 83 | elif point == 'random_corner': 84 | rand_zero_ones = np.random.randint(2, size=shape) 85 | corner = lb * (1 - rand_zero_ones) + ub * rand_zero_ones 86 | return corner 87 | elif point == 'random_unif': 88 | mean = (lb + ub) / 2 89 | diff = ub - lb 90 | lb = mean - scale_sample_range * diff / 2 91 | ub = mean + scale_sample_range * diff / 2 92 | return np.random.rand(*shape) * (ub - lb) + lb 93 | else: 94 | raise NotImplementedError(f'Point {point} not implemented.') 95 | 96 | 97 | def check_feasibility_numpy(constraints, feasible_point, variable_range): 98 | num_variables = constraints.shape[-1] - 1 99 | decision_point = get_hypercube_point_numpy(point=feasible_point, shape=[num_variables], **variable_range) 100 | 101 | # return vector of 1s and -1s, depending on whether the constraint is feasible at the decision point or not 102 | constraint_eq_solution = np.sum(constraints[:, :-1] * decision_point[None, :], axis=1) + constraints[:, -1] 103 | is_feasible = (constraint_eq_solution <= 0.0).astype(float) 104 | feasibility_indicator = 2.0 * is_feasible - 1.0 105 | return feasibility_indicator 106 | 107 | 108 | def compute_normalized_constraints(constraints): 109 | A = constraints[..., :-1] 110 | norm = general_l2norm(A, axis=-1, keepdims=True) 111 | constraints = constraints / (norm + epsilon_constant) 112 | return constraints 113 | 114 | 115 | def compute_constraints_in_base_coordinate_system(constraints_in_offset_system, offsets): 116 | # offset is defined as the offset of the offset constraint coordinate system wrt to the base coordinate system 117 | # origin_offset_const = origin_base + offset 118 | A, b = constraints_in_offset_system[..., :-1], constraints_in_offset_system[..., -1] 119 | b_prime = b - general_sum(A * offsets, axis=-1) 120 | constraints_in_base_system = general_cat([A, b_prime[..., None]], axis=-1) 121 | return constraints_in_base_system 122 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import pickle 4 | import random 5 | import sys 6 | from itertools import chain, combinations 7 | 8 | import numpy as np 9 | import torch 10 | import yaml 11 | from torch import Tensor 12 | from torch.nn.parameter import Parameter 13 | 14 | epsilon_constant = 1e-8 15 | 16 | 17 | class 
AverageMeter(object): 18 | """Computes and stores the average and current value""" 19 | 20 | def __init__(self, name=None, fmt=":f"): 21 | self.name = name 22 | self.fmt = fmt 23 | self.reset() 24 | 25 | def reset(self): 26 | self.val = 0 27 | self.avg = 0 28 | self.sum = 0 29 | self.count = 0 30 | 31 | def update(self, val, n=1): 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | 37 | def __str__(self): 38 | fmtstr = "{name} ({avg" + self.fmt + "})" 39 | return fmtstr.format(**self.__dict__) 40 | 41 | 42 | class AvgMeters(object): 43 | def __init__(self): 44 | self.metrics = {} 45 | 46 | def update_metric(self, name, value, n=1): 47 | if name not in self.metrics.keys(): 48 | self.metrics[name] = AverageMeter(name) 49 | self.metrics[name].update(value, n=n) 50 | 51 | def get_averages(self, prefix=''): 52 | return {prefix + key: avg_meter.avg for key, avg_meter in self.metrics.items()} 53 | 54 | def update(self, dct, n=1): 55 | for key, value in dct.items(): 56 | self.update_metric(name=key, value=value, n=n) 57 | 58 | def reset(self): 59 | self.metrics = {} 60 | 61 | 62 | def check_equal_ys(y_1, y_2, threshold=1e-5): 63 | equal_variable_wise = np.abs(y_1 - y_2) < threshold 64 | equal_instance_wise = equal_variable_wise.all(axis=1) 65 | return equal_variable_wise, equal_instance_wise 66 | 67 | 68 | def compute_metrics(y, y_true, y_uncon=None, prefix='', acc_threshold=1e-5): 69 | y = y.cpu().detach().numpy() 70 | metric_dict = {} 71 | 72 | y_true = y_true.cpu().detach().numpy() 73 | correct, correct_perfect = check_equal_ys(y, y_true, threshold=acc_threshold) 74 | metric_dict.update(dict(acc=correct, perfect_accuracy=correct_perfect)) 75 | 76 | if y_uncon is not None: 77 | y_uncon = y_uncon.cpu().detach().numpy() 78 | match_uncon, match_uncon_perfect = check_equal_ys(y, y_uncon, threshold=acc_threshold) 79 | metric_dict.update(dict(match_uncon_accuracy=match_uncon, match_uncon_perfect_accuracy=match_uncon_perfect)) 80 | 81 | metric_dict = {prefix + key: value.mean() for key, value in metric_dict.items()} 82 | return metric_dict 83 | 84 | 85 | def check_if_zero_or_one(y): 86 | return torch.all(torch.logical_or(torch.isclose(y, torch.ones_like(y)), torch.isclose(y, torch.zeros_like(y)))) 87 | 88 | 89 | class HammingLoss(torch.nn.Module): 90 | def forward(self, suggested, target): 91 | suggested += 0.5 # solutions are normalized to [-0.5, 0.5] 92 | target += 0.5 93 | if not (check_if_zero_or_one(suggested) and check_if_zero_or_one(target)): 94 | print(suggested, target) 95 | raise ValueError( 96 | f'Hamming loss only defined for zero/one predictions and targets. Instead received {suggested}, {target}.') 97 | errors = suggested * (1.0 - target) + (1.0 - suggested) * target 98 | return errors.mean(dim=0).sum() 99 | 100 | 101 | class IOULoss(torch.nn.Module): 102 | def forward(self, suggested, target): 103 | suggested += 0.5 # solutions are normalized to [-0.5, 0.5] 104 | target += 0.5 105 | if not (check_if_zero_or_one(suggested) and check_if_zero_or_one(target)): 106 | raise ValueError( 107 | f'IOU loss only defined for zero/one predictions and targets. 
Instead received {suggested}, {target}.') 108 | matches = suggested * target # // true positives 109 | fps = torch.relu(suggested - target) # // false positives 110 | fns = torch.relu(target - suggested) # // false negatives 111 | iou = (matches.sum()) / (matches.sum() + fps.sum() + fns.sum()) 112 | return 1 - iou # // or (1-iou)*(1-iou) 113 | 114 | 115 | class L0Loss(torch.nn.Module): 116 | def forward(self, suggested, target): 117 | errors = (suggested - target).abs() 118 | return torch.max(errors, dim=-1)[0].mean() 119 | 120 | 121 | class HuberLoss(torch.nn.Module): 122 | def __init__(self, beta=0.3): 123 | self.beta = beta 124 | super(HuberLoss, self).__init__() 125 | 126 | def forward(self, suggested, target): 127 | errors = torch.abs(suggested - target) 128 | mask = errors < self.beta 129 | l2_errors = 0.5 * (errors ** 2) / self.beta 130 | l1_errors = errors - 0.5 * self.beta 131 | combined_errors = mask * l2_errors + ~mask * l1_errors 132 | return combined_errors.mean(dim=0).sum() 133 | 134 | 135 | def loss_from_string(loss_name): 136 | dct = {"Hamming": HammingLoss(), "MSE": torch.nn.MSELoss(), "L1": torch.nn.L1Loss(), "L0": L0Loss(), 137 | "IOU": IOULoss(), "Huber": HuberLoss(beta=0.3)} 138 | return dct[loss_name] 139 | 140 | 141 | def optimizer_from_string(optimizer_name): 142 | dct = {"Adam": torch.optim.Adam, "SGD": torch.optim.SGD, "RMSprop": torch.optim.RMSprop} 143 | return dct[optimizer_name] 144 | 145 | 146 | def save_pickle(data, path): 147 | with open(path, "wb") as fh: 148 | pickle.dump(data, fh) 149 | 150 | 151 | def load_pickle(path): 152 | with open(path, "rb") as fh: 153 | return pickle.load(fh) 154 | 155 | 156 | def set_seed(seed): 157 | if seed is not None: 158 | random.seed(seed) 159 | torch.manual_seed(seed) 160 | torch.cuda.manual_seed(seed) 161 | np.random.seed(seed) 162 | 163 | 164 | class ParallelProcessing: 165 | def __init__(self): 166 | self.remote_fns = dict() 167 | 168 | def ray_available(self): 169 | if 'ray' in sys.modules and sys.modules['ray'].is_initialized(): 170 | self.ray = sys.modules['ray'] 171 | return True 172 | else: 173 | return False 174 | 175 | def get_remote(self, function): 176 | if type(function) is self.ray.remote_function.RemoteFunction: 177 | return function.remote 178 | else: 179 | if function in self.remote_fns: 180 | remote_fn = self.remote_fns[function] 181 | else: 182 | remote_fn = self.ray.remote(function) 183 | self.remote_fns[function] = remote_fn 184 | assert isinstance(remote_fn, self.ray.remote_function.RemoteFunction) 185 | return remote_fn.remote 186 | 187 | def maybe_parallelize(self, function, static_kwargs, dynamic_kwargs): 188 | if self.ray_available(): 189 | remote_fn = self.get_remote(function) 190 | return self.ray.get([remote_fn(**static_kwargs, **kwargs) for kwargs in dynamic_kwargs]) 191 | else: 192 | return [function(**static_kwargs, **kwargs) for kwargs in dynamic_kwargs] 193 | 194 | 195 | def torch_parameter_from_numpy(array): 196 | param = Parameter(Tensor(*array.shape)) 197 | param.data = torch.from_numpy(array) 198 | return param 199 | 200 | 201 | def general_cat(list_of_arrays, axis): 202 | if isinstance(list_of_arrays[0], torch.Tensor): 203 | concatenated_array = torch.cat(list_of_arrays, dim=axis) 204 | elif isinstance(list_of_arrays[0], np.ndarray): 205 | concatenated_array = np.concatenate(list_of_arrays, axis=axis) 206 | else: 207 | raise NotImplementedError 208 | return concatenated_array 209 | 210 | 211 | def general_sum(array, axis): 212 | if isinstance(array, torch.Tensor): 213 | summed_array 
= torch.sum(array, dim=axis) 214 | elif isinstance(array, np.ndarray): 215 | summed_array = np.sum(array, axis=axis) 216 | else: 217 | raise NotImplementedError 218 | return summed_array 219 | 220 | 221 | def general_l2norm(array, axis, keepdims): 222 | if isinstance(array, torch.Tensor): 223 | norm = torch.norm(array, p=2, dim=axis, keepdim=keepdims) 224 | elif isinstance(array, np.ndarray): 225 | norm = np.linalg.norm(array, axis=axis, keepdims=keepdims) 226 | else: 227 | raise NotImplementedError 228 | return norm 229 | 230 | 231 | def compute_normalized_solution(y, lb, ub): 232 | mean = (lb + ub) / 2 233 | size = ub - lb 234 | y_normalized = (y - mean) / size 235 | return y_normalized 236 | 237 | 238 | def compute_denormalized_solution(y_normalized, lb, ub): 239 | mean = (ub + lb) / 2 240 | size = ub - lb 241 | y = y_normalized * size + mean 242 | return y 243 | 244 | 245 | def solve_unconstrained(cost_vector, lb, ub): 246 | mean = (ub + lb) / 2 247 | size = ub - lb 248 | # indicator is -1 if less than, 0 if equal to and 1 if greater than zero 249 | if isinstance(cost_vector, torch.Tensor): 250 | indicator = (cost_vector >= 0).to(torch.float) - (cost_vector <= 0).to(torch.float) 251 | else: 252 | indicator = (cost_vector >= 0).astype(float) - (cost_vector <= 0).astype(float) 253 | # minus because we minimize 254 | y = mean - indicator * size / 2 255 | return y 256 | 257 | 258 | def knapsack_round(y_denorm, constraints, knapsack_capacity): 259 | # for cvxpy knapsack 260 | weights = constraints[:, 0, :-1] 261 | n_batch, n_sol = y_denorm.shape 262 | rounded_sol = torch.zeros_like(y_denorm) 263 | for batch_i in range(n_batch): 264 | # Add indices until we hit capacity 265 | sol_i_sort = y_denorm[batch_i].sort(descending=True).indices 266 | for j in range(n_sol): 267 | candidate = rounded_sol[batch_i].clone() 268 | candidate[sol_i_sort[j]] = 1. 
269 | if sum(weights[batch_i] * candidate) <= knapsack_capacity: 270 | rounded_sol[batch_i] = candidate 271 | else: 272 | break 273 | return rounded_sol 274 | 275 | 276 | def load_with_default_yaml(path): 277 | base_path = os.getcwd() 278 | with open(os.path.join(base_path, path), 'r') as stream: 279 | param_dict = yaml.safe_load(stream) 280 | default_yaml = param_dict.pop('__import__', None) 281 | if default_yaml is not None: 282 | with open(os.path.join(base_path, default_yaml), 'r') as stream: 283 | default_dict = yaml.safe_load(stream) 284 | param_dict = merge(param_dict, default_dict) 285 | return param_dict 286 | 287 | 288 | def merge(source, destination): 289 | for key, value in source.items(): 290 | if isinstance(value, dict): 291 | node = destination.setdefault(key, {}) 292 | merge(value, node) 293 | else: 294 | destination[key] = value 295 | 296 | return destination 297 | 298 | 299 | def powerset(iterable): 300 | "list(powerset([1,2,3])) --> [(), (1,), (2,), (3,), (1,2), (1,3), (2,3), (1,2,3)]" 301 | s = list(iterable) 302 | return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) 303 | 304 | 305 | def sample_at_least_one_of_each(subsets, num_classes, num_samples, num_iterations=1000): 306 | min_number_of_subsets_containing_idx = 1 307 | for _ in range(num_iterations): 308 | good_sample = True 309 | samples = random.sample(subsets, num_samples) 310 | # ensure each class is present in the drawn subsets 311 | for i in range(num_classes): 312 | num_subsets_containing_i = len([s for s in samples if i in s]) 313 | if num_subsets_containing_i <= min_number_of_subsets_containing_idx: 314 | good_sample = False 315 | break 316 | if good_sample: 317 | return samples 318 | raise ValueError( 319 | f'Could not draw {num_samples} samples from {len(subsets)} subsets that contain a representative of each of ' 320 | f'the {num_classes} classes in {num_iterations} iterations.') 321 | 322 | 323 | def print_eval_acc(metrics): 324 | print(f"Evaluation:: Loss: {metrics['eval_loss']:.4f}, " 325 | f"Perfect acc: {metrics['eval_perfect_accuracy']:.4f}") 326 | 327 | 328 | def print_train_acc(metrics, epoch): 329 | print(f"Epoch: {epoch + 1:>2}, Train loss: {metrics['train_loss']:.4f}, " 330 | f"Perfect acc: {metrics['train_perfect_accuracy']:.4f}") 331 | 332 | 333 | def save_dict_as_one_line_csv(dct, filename): 334 | os.makedirs(os.path.dirname(filename), exist_ok=True) 335 | with open(filename, 'w') as f: 336 | writer = csv.DictWriter(f, fieldnames=dct.keys()) 337 | writer.writeheader() 338 | writer.writerow(dct) 339 | --------------------------------------------------------------------------------
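As a quick reference for the constraint convention used above: in `constraint_generation.py` and `comboptnet_utils.py` a constraint set is stored as a matrix `[A | b]`, and a point `y` is feasible when `A·y + b <= 0`. The sketch below (the numbers are illustrative, not taken from any dataset) shows how `compute_constraints_in_base_coordinate_system` shifts a constraint expressed around an offset origin `o` into base coordinates via `b' = b - A·o`.

```python
import numpy as np

from utils.constraint_generation import compute_constraints_in_base_coordinate_system

# One 2-d constraint in its own offset coordinate system: [a1, a2, b], meaning a·y' + b <= 0.
constraints_offset = np.array([[1.0, 0.0, 0.2]])
# Origin of that offset system, expressed in base coordinates.
offsets = np.array([[0.5, 0.5]])

constraints_base = compute_constraints_in_base_coordinate_system(
    constraints_in_offset_system=constraints_offset, offsets=offsets)

# b' = b - a·o = 0.2 - 0.5 = -0.3, i.e. in base coordinates the constraint reads y1 - 0.3 <= 0.
print(constraints_base)  # [[ 1.   0.  -0.3]]
```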
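Similarly, the `__import__` key handled by `load_with_default_yaml` and `merge` is what lets a method-specific config extend a shared base config: the imported defaults are loaded first, and the importing file's values win on conflicts, recursively for nested dictionaries. A minimal sketch of the merge semantics (the dictionary contents are made up; only the precedence behaviour is taken from the code):

```python
from utils.utils import merge

defaults = {'loader_params': {'batch_size': 8, 'shuffle': True}, 'seed': 0}
overrides = {'loader_params': {'batch_size': 16}}

# merge(source, destination) writes source values into destination, descending into nested dicts,
# so the importing file's settings override the imported defaults while untouched defaults survive.
merged = merge(overrides, defaults)
print(merged)  # {'loader_params': {'batch_size': 16, 'shuffle': True}, 'seed': 0}
```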