├── .gitignore
├── comb_modules
│   ├── losses.py
│   ├── utils.py
│   └── dijkstra.py
├── Pipfile
├── constants.py
├── settings
│   └── warcraft_shortest_path
│       ├── 12x12_baseline.json
│       └── 12x12_combresnet.json
├── warcraft_shortest_path
│   ├── metrics.py
│   ├── data_utils.py
│   ├── visualization.py
│   └── trainers.py
├── requirements.txt
├── README.md
├── decorators.py
├── main.py
├── logger.py
├── models.py
├── data
│   └── utils.py
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /data/globe_tsp/
2 | /data/mnist_matching/
3 | /data/warcraft_maps/
--------------------------------------------------------------------------------
/comb_modules/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | class HammingLoss(torch.nn.Module):
5 |     def forward(self, suggested, target):
6 |         errors = suggested * (1.0 - target) + (1.0 - suggested) * target
7 |         return errors.mean(dim=0).sum()
8 |         # return (torch.mean(suggested*(1.0-target)) + torch.mean((1.0-suggested)*target)) * 25.0
9 | 
10 | 
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | name = "pypi"
3 | url = "https://pypi.org/simple"
4 | verify_ssl = true
5 | 
6 | [dev-packages]
7 | jupyter = "*"
8 | 
9 | [packages]
10 | numpy = "==1.16.1"
11 | torch = "==1.2.0"
12 | matplotlib = "==3.0.2"
13 | pipenv = "*"
14 | gitpython = "==2.1.11"
15 | sklearn = "*"
16 | scipy = "*"
17 | tensorboardx = "*"
18 | torchvision = "==0.4.0"
19 | pathlib2 = "*"
20 | tqdm = "*"
21 | ray = "*"
22 | psutil = "*"
23 | setproctitle = "*"
24 | ipympl = "*"
25 | ipyvolume = "*"
26 | pillow = "*"
27 | basemap = {git = "https://github.com/matplotlib/basemap.git"}
28 | 
29 | [requires]
30 | python_version = "3.6"
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | CLUSTER_PARAM_FILE = 'param_choice.csv'
2 | CLUSTER_METRIC_FILE = 'metrics.csv'
3 | JSON_SETTINGS_FILE = 'settings.json'
4 | JOB_INFO_FILE = 'job_info.csv'
5 | 
6 | STATUS_PICKLE_FILE = 'status.pickle'
7 | FULL_DF_FILE = 'all_data.csv'
8 | REDUCED_DF_FILE = 'reduced_data.csv'
9 | STD_ENDING = '__std'
10 | RESTART_PARAM_NAME = 'restarts'
11 | 
12 | JSON_FILE_KEY = 'default_json'
13 | OBJECT_SEPARATOR = '.'
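# OBJECT_SEPARATOR joins nested parameter names into flat keys,
# e.g. 'trainer_params.optimizer_params.lr' (see flatten_nested_string_dict in utils.py)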
14 | 15 | PARAM_TYPES = (bool, str, int, float, tuple, dict) 16 | 17 | RESERVED_PARAMS = ('model_dir', 'id', 'iteration', RESTART_PARAM_NAME, 'cluster_job_id') 18 | 19 | DISTR_BASE_COLORS = [(0.99, 0.7, 0.18), (0.7, 0.7, 0.9), (0.56, 0.692, 0.195), (0.923, 0.386, 0.209)] 20 | -------------------------------------------------------------------------------- /settings/warcraft_shortest_path/12x12_baseline.json: -------------------------------------------------------------------------------- 1 | { 2 | "evaluate_every": 10, 3 | "loader_params": { 4 | "data_dir": "data/warcraft_shortest_path/12x12", 5 | "evaluate_with_extra": false, 6 | "normalize": true, 7 | "use_local_path": false, 8 | "use_test_set": true 9 | }, 10 | "model_dir": "results/warcraft_shortest_path_baseline", 11 | "num_cpus": 1, 12 | "num_epochs": 150, 13 | "problem_type": "warcraft_shortest_path", 14 | "ray_params": {}, 15 | "save_visualizations": false, 16 | "seed": 1, 17 | "trainer_name": "Baseline", 18 | "trainer_params": { 19 | "batch_size": 70, 20 | "lr_milestone_1": 100, 21 | "lr_milestone_2": 130, 22 | "model_params": { 23 | "arch_params": {}, 24 | "model_name": "ResNet18" 25 | }, 26 | "neighbourhood_fn": "8-grid", 27 | "optimizer_name": "Adam", 28 | "optimizer_params": { 29 | "lr": 0.0005 30 | }, 31 | "preload_batch": true, 32 | "use_cuda": true, 33 | "use_lr_scheduling": false 34 | }, 35 | "use_ray": true 36 | } 37 | -------------------------------------------------------------------------------- /settings/warcraft_shortest_path/12x12_combresnet.json: -------------------------------------------------------------------------------- 1 | { 2 | "evaluate_every": 5, 3 | "loader_params": { 4 | "data_dir": "data/warcraft_shortest_path/12x12", 5 | "evaluate_with_extra": false, 6 | "normalize": true, 7 | "use_local_path": false, 8 | "use_test_set": true 9 | }, 10 | "model_dir": "results/warcraft_shortest_path_combresnet", 11 | "num_cpus": 4, 12 | "num_epochs": 50, 13 | "problem_type": "warcraft_shortest_path", 14 | "ray_params": {}, 15 | "save_visualizations": false, 16 | "seed": 1, 17 | "trainer_name": "DijkstraOnFull", 18 | "trainer_params": { 19 | "batch_size": 70, 20 | "l1_regconst": 0.0, 21 | "lambda_val": 20.0, 22 | "lr_milestone_1": 30, 23 | "lr_milestone_2": 40, 24 | "model_params": { 25 | "arch_params": {}, 26 | "model_name": "CombResnet18" 27 | }, 28 | "neighbourhood_fn": "8-grid", 29 | "optimizer_name": "Adam", 30 | "optimizer_params": { 31 | "lr": 0.0005 32 | }, 33 | "preload_batch": true, 34 | "use_cuda": true, 35 | "use_lr_scheduling": true 36 | }, 37 | "use_ray": true 38 | } 39 | -------------------------------------------------------------------------------- /warcraft_shortest_path/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from decorators import input_to_numpy, none_if_missing_arg 3 | from utils import all_accuracies 4 | from comb_modules.utils import edges_from_grid 5 | from comb_modules.dijkstra import dijkstra 6 | import itertools 7 | 8 | @none_if_missing_arg 9 | def perfect_match_accuracy(true_paths, suggested_paths): 10 | matching_correct = np.sum(np.abs(true_paths - suggested_paths), axis=-1) 11 | avg_matching_correct = (matching_correct < 0.5).mean() 12 | return avg_matching_correct 13 | 14 | 15 | @none_if_missing_arg 16 | def cost_ratio(vertex_costs, true_paths, suggested_paths): 17 | suggested_paths_costs = suggested_paths * vertex_costs 18 | true_paths_costs = true_paths * vertex_costs 19 | return 
(np.sum(suggested_paths_costs, axis=1) / np.sum(true_paths_costs, axis=1)).mean() 20 | 21 | 22 | @input_to_numpy 23 | def compute_metrics(true_paths, suggested_paths, true_vertex_costs): 24 | batch_size = true_vertex_costs.shape[0] 25 | metrics = { 26 | "perfect_match_accuracy": perfect_match_accuracy(true_paths.reshape(batch_size,-1), suggested_paths.reshape(batch_size,-1)), 27 | "cost_ratio_suggested_true": cost_ratio(true_vertex_costs, true_paths, suggested_paths), 28 | **all_accuracies(true_paths, suggested_paths, true_vertex_costs, is_valid_label_fn,6) 29 | } 30 | return metrics 31 | 32 | 33 | def is_valid_label_fn(suggested_path): 34 | inverted_path = 1.-suggested_path 35 | shortest_path, _, _ = dijkstra(inverted_path) 36 | is_valid = (shortest_path * inverted_path).sum() == 0 37 | return is_valid 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | atomicwrites==1.3.0 2 | attrs==19.1.0 3 | backcall==0.1.0 4 | bleach==3.1.0 5 | certifi==2019.9.11 6 | chardet==3.0.4 7 | Click==7.0 8 | colorama==0.4.1 9 | cycler==0.10.0 10 | decorator==4.4.0 11 | defusedxml==0.6.0 12 | entrypoints==0.3 13 | ffmpeg==1.4 14 | filelock==3.0.12 15 | funcsigs==1.0.2 16 | gitdb2==2.0.5 17 | GitPython==2.1.11 18 | idna==2.8 19 | imageio==2.5.0 20 | importlib-metadata==0.23 21 | ipydatawidgets==4.0.1 22 | ipykernel==5.1.2 23 | ipympl==0.3.3 24 | ipython==7.8.0 25 | ipython-genutils==0.2.0 26 | ipyvolume==0.5.2 27 | ipywebrtc==0.5.0 28 | ipywidgets==7.5.1 29 | jedi==0.15.1 30 | Jinja2==2.10.1 31 | joblib==0.13.2 32 | jsonschema==3.0.2 33 | jupyter==1.0.0 34 | jupyter-client==5.3.3 35 | jupyter-console==6.0.0 36 | jupyter-core==4.5.0 37 | kiwisolver==1.1.0 38 | MarkupSafe==1.1.1 39 | matplotlib==3.0.2 40 | mistune==0.8.4 41 | more-itertools==7.2.0 42 | nbconvert==5.6.0 43 | nbformat==4.4.0 44 | notebook==6.0.1 45 | numpy==1.16.1 46 | opencv-python==4.1.1.26 47 | packaging==19.2 48 | pandas==0.24.1 49 | pandocfilters==1.4.2 50 | parso==0.5.1 51 | pathlib2==2.3.4 52 | pexpect==4.7.0 53 | pickleshare==0.7.5 54 | Pillow==6.1.0 55 | pipenv==2018.11.26 56 | pluggy==0.13.0 57 | prometheus-client==0.7.1 58 | prompt-toolkit==2.0.9 59 | protobuf==3.9.2 60 | psutil==5.6.3 61 | ptyprocess==0.6.0 62 | py==1.8.0 63 | Pygments==2.4.2 64 | pyparsing==2.4.2 65 | pyrsistent==0.15.4 66 | pytest==5.1.3 67 | python-dateutil==2.8.0 68 | pythreejs==2.1.1 69 | pytz==2019.2 70 | PyYAML==5.1.2 71 | pyzmq==18.1.0 72 | qtconsole==4.5.5 73 | ray==0.7.5 74 | redis==3.3.8 75 | requests==2.22.0 76 | scikit-learn==0.21.3 77 | scipy==1.3.1 78 | seaborn==0.9.0 79 | Send2Trash==1.5.0 80 | setproctitle==1.1.10 81 | six==1.12.0 82 | sklearn==0.0 83 | smmap2==2.0.5 84 | tensorboardX==1.8 85 | terminado==0.8.2 86 | testpath==0.4.2 87 | torch==1.2.0 88 | torchvision==0.4.0 89 | tornado==6.0.3 90 | tqdm==4.36.1 91 | traitlets==4.3.2 92 | traittypes==0.2.1 93 | urllib3==1.25.6 94 | virtualenv==16.7.5 95 | virtualenv-clone==0.5.3 96 | wcwidth==0.1.7 97 | webencodings==0.5.1 98 | widgetsnbextension==3.5.1 99 | zipp==0.6.0 100 | git+git://github.com/matplotlib/basemap.git -------------------------------------------------------------------------------- /comb_modules/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import functools 3 | import numpy as np 4 | 5 | 6 | def neighbours_8(x, y, x_max, y_max): 7 | deltas_x = (-1, 0, 1) 8 | deltas_y = (-1, 0, 1) 9 | 
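    # e.g. from (0, 0) on a 3x3 grid this yields (0, 1), (1, 0) and (1, 1)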
    for (dx, dy) in itertools.product(deltas_x, deltas_y):
10 |         x_new, y_new = x + dx, y + dy
11 |         if 0 <= x_new < x_max and 0 <= y_new < y_max and (dx, dy) != (0, 0):
12 |             yield x_new, y_new
13 | 
14 | 
15 | def neighbours_4(x, y, x_max, y_max):
16 |     for (dx, dy) in [(1, 0), (0, 1), (0, -1), (-1, 0)]:
17 |         x_new, y_new = x + dx, y + dy
18 |         if 0 <= x_new < x_max and 0 <= y_new < y_max and (dx, dy) != (0, 0):
19 |             yield x_new, y_new
20 | 
21 | 
22 | def get_neighbourhood_func(neighbourhood_fn):
23 |     if neighbourhood_fn == "4-grid":
24 |         return neighbours_4
25 |     elif neighbourhood_fn == "8-grid":
26 |         return neighbours_8
27 |     else:
28 |         raise Exception(f"neighbourhood_fn of {neighbourhood_fn} not possible")
29 | 
30 | 
31 | def edges_from_vertex(x, y, N, neighbourhood_fn):
32 |     v = (x, y)
33 |     neighbours = get_neighbourhood_func(neighbourhood_fn)(*v, x_max=N, y_max=N)
34 |     v_edges = [
35 |         (*v, *vn) for vn in neighbours if vertex_index(v, N) < vertex_index(vn, N)
36 |     ]  # Enforce ordering on vertices
37 |     return v_edges
38 | 
39 | 
40 | def vertex_index(v, dim):
41 |     x, y = v
42 |     return x * dim + y
43 | 
44 | 
45 | @functools.lru_cache(32)
46 | def edges_from_grid(N, neighbourhood_fn):
47 |     all_vertices = itertools.product(range(N), range(N))
48 |     all_edges = [edges_from_vertex(x, y, N, neighbourhood_fn=neighbourhood_fn) for x, y in all_vertices]
49 |     all_edges_flat = sum(all_edges, [])
50 |     all_edges_flat_unique = list(set(all_edges_flat))
51 |     return np.asarray(all_edges_flat_unique)
52 | 
53 | 
54 | @functools.lru_cache(32)
55 | def cached_vertex_grid_to_edges_grid_coords(grid_dim: int):
56 |     edges_grid_idxs = edges_from_grid(grid_dim, neighbourhood_fn="4-grid")
57 |     return edges_grid_idxs[:, 0], edges_grid_idxs[:, 1], edges_grid_idxs[:, 2], edges_grid_idxs[:, 3]
58 | 
59 | 
60 | @functools.lru_cache(32)
61 | def cached_vertex_grid_to_edges(grid_dim: int):
62 |     x, y, xn, yn = cached_vertex_grid_to_edges_grid_coords(grid_dim)
63 |     return np.vstack([vertex_index((x, y), grid_dim), vertex_index((xn, yn), grid_dim)]).T
--------------------------------------------------------------------------------
/warcraft_shortest_path/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | from comb_modules.dijkstra import dijkstra
7 | from decorators import input_to_numpy
8 | from utils import TrainingIterator
9 | 
10 | def load_dataset(data_dir, use_test_set, evaluate_with_extra, normalize, use_local_path):
11 |     train_prefix = "train"
12 |     data_suffix = "maps"
13 |     true_weights_suffix = ""
14 | 
15 |     val_prefix = ("test" if use_test_set else "val") + ("_extra" if evaluate_with_extra else "")
16 |     train_data_path = os.path.join(data_dir, train_prefix + "_" + data_suffix + ".npy")
17 | 
18 |     if os.path.exists(train_data_path):
19 |         train_inputs = np.load(os.path.join(data_dir, train_prefix + "_" + data_suffix + ".npy")).astype(np.float32)
20 |         # noinspection PyTypeChecker
21 |         train_inputs = train_inputs.transpose(0, 3, 1, 2)  # channel first
22 | 
23 |         train_labels = np.load(os.path.join(data_dir, train_prefix + "_shortest_paths.npy"))
24 |         train_true_weights = np.load(os.path.join(data_dir, train_prefix + "_vertex_weights.npy"))
25 |         if normalize:
26 |             mean, std = (
27 |                 np.mean(train_inputs, axis=(0, 2, 3), keepdims=True),
28 |                 np.std(train_inputs, axis=(0, 2, 3), keepdims=True),
29 |             )
30 |             train_inputs -= mean
31 |             train_inputs /= std
32 |         train_iterator = TrainingIterator(dict(images=train_inputs,
labels=train_labels, true_weights=train_true_weights))
33 |     else:
34 |         raise FileNotFoundError(f"Cannot find {train_data_path}")
35 | 
36 |     val_inputs = np.load(os.path.join(data_dir, val_prefix + "_" + data_suffix + ".npy")).astype(np.float32)
37 |     # noinspection PyTypeChecker
38 |     val_inputs = val_inputs.transpose(0, 3, 1, 2)  # channel first
39 | 
40 |     if normalize:
41 |         # noinspection PyUnboundLocalVariable
42 |         val_inputs -= mean
43 |         # noinspection PyUnboundLocalVariable
44 |         val_inputs /= std
45 | 
46 |     val_labels = np.load(os.path.join(data_dir, val_prefix + "_shortest_paths.npy"))
47 |     val_true_weights = np.load(os.path.join(data_dir, val_prefix + "_vertex_weights.npy"))
48 |     val_full_images = np.load(os.path.join(data_dir, val_prefix + "_maps.npy"))
49 |     eval_iterator = TrainingIterator(
50 |         dict(images=val_inputs, labels=val_labels, true_weights=val_true_weights)
51 |     )
52 | 
53 |     @input_to_numpy
54 |     def denormalize(x):
55 |         return (x * std) + mean if normalize else x  # mean/std only exist when normalize=True
56 | 
57 |     metadata = {
58 |         "input_image_size": val_full_images[0].shape[1],
59 |         "output_features": val_true_weights[0].shape[0] * val_true_weights[0].shape[1],
60 |         "num_channels": val_full_images[0].shape[-1],
61 |         "denormalize": denormalize
62 |     }
63 | 
64 |     return train_iterator, eval_iterator, metadata
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Differentiation of Blackbox Combinatorial Solvers
2 | 
3 | This repository provides code for the paper [Differentiation of Blackbox Combinatorial Solvers](http://arxiv.org/abs/1912.02175).
4 | 
5 | By Marin Vlastelica*, Anselm Paulus*, Vít Musil, [Georg Martius](http://georg.playfulmachines.com/) and [Michal Rolínek](https://scholar.google.de/citations?user=DVdSTFQAAAAJ&hl=en).
6 | 
7 | [Autonomous Learning Group](https://al.is.tuebingen.mpg.de/), [Max Planck Institute for Intelligent Systems](https://is.tuebingen.mpg.de/).
8 | 
9 | For a condensed version containing only the modules (along with additional solvers) see [this repository](https://github.com/martius-lab/blackbox-backprop).
10 | 
11 | ## Table of Contents
12 | 0. [Introduction](#introduction)
13 | 1. [Installation](#installation)
14 | 2. [Usage](#usage)
15 | 3. [Notes](#notes)
16 | 
17 | 
18 | 
19 | ## Introduction
20 | 
21 | This repository provides a visualization for all the datasets used in
22 | [Differentiation of Blackbox Combinatorial Solvers](http://arxiv.org/abs/1912.02175).
23 | Additionally, the training code for the Warcraft Shortest Path experiment is provided.
24 | 
25 | *Disclaimer*: This code is a PROTOTYPE. It should work fine but use at your own risk.
26 | 
27 | ## Installation
28 | 
29 | First install the requirements via one of the following options:
30 | 
31 | - Option 1 (requires pipenv):
32 | 
33 |   ``pipenv install`` (use the --skip-lock flag for a speedup, at your own risk)
34 | 
35 |   ``pipenv shell``
36 | 
37 |   If the installation causes problems, the following could be a fix:
38 |   ``sudo apt install python3.6-dev``
39 | 
40 | - Option 2 (requires python 3.6):
41 | 
42 |   ``pip3 install -r requirements.txt``
43 | 
44 | 
45 | Next, download our [datasets](https://edmond.mpdl.mpg.de/imeji/collection/tGU9ok0_m2CVfHI8?q=) and extract each dataset to the data/ directory in this repository.
46 | 
47 | 
48 | ## Usage
49 | 
50 | - Dataset visualization:
51 | 
52 |   For an easy overview of the datasets use the data/data_visualization.ipynb jupyter notebook.
53 | 54 | - Warcraft shortest path experiment: 55 | 56 | Change directory to project root folder. 57 | 58 | Run Warcraft shortest path experiment with gradients through Dijkstra: 59 | 60 | ``python main.py settings/warcraft_shortest_path/12x12_combresnet.json`` 61 | 62 | Run Warcraft shortest path baseline ResNet18 experiment: 63 | 64 | ``python main.py settings/warcraft_shortest_path/12x12_baseline.json`` 65 | 66 | The results are stored in the results directory in the project root folder. 67 | 68 | 69 | ## Notes 70 | 71 | *Contribute*: If you spot a bug or some incompatibility, raise an issue or contribute via a pull request! Thank you! 72 | -------------------------------------------------------------------------------- /decorators.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | import numpy as np 4 | import torch 5 | from abc import ABC, abstractmethod 6 | 7 | from functools import update_wrapper, partial 8 | 9 | 10 | class Decorator(ABC): 11 | def __init__(self, f): 12 | self.func = f 13 | update_wrapper(self, f, updated=[]) # updated=[] so that 'self' attributes are not overwritten 14 | 15 | @abstractmethod 16 | def __call__(self, *args, **kwargs): 17 | pass 18 | 19 | def __get__(self, instance, owner): 20 | new_f = partial(self.__call__, instance) 21 | update_wrapper(new_f, self.func) 22 | return new_f 23 | 24 | 25 | def to_tensor(x): 26 | if isinstance(x, np.ndarray) or np.isscalar(x): 27 | return torch.from_numpy(np.array(x)).float() 28 | else: 29 | return x 30 | 31 | 32 | def to_numpy(x): 33 | if isinstance(x, torch.Tensor): 34 | return x.cpu().detach().numpy() 35 | else: 36 | return x 37 | 38 | 39 | # noinspection PyPep8Naming 40 | class input_to_tensors(Decorator): 41 | def __call__(self, *args, **kwargs): 42 | new_args = [to_tensor(arg) for arg in args] 43 | new_kwargs = {key: to_tensor(value) for key, value in kwargs.items()} 44 | return self.func(*new_args, **new_kwargs) 45 | 46 | 47 | # noinspection PyPep8Naming 48 | class output_to_tensors(Decorator): 49 | def __call__(self, *args, **kwargs): 50 | outputs = self.func(*args, **kwargs) 51 | if isinstance(outputs, np.ndarray): 52 | return to_tensor(outputs) 53 | if isinstance(outputs, tuple): 54 | new_outputs = tuple([to_tensor(item) for item in outputs]) 55 | return new_outputs 56 | return outputs 57 | 58 | 59 | # noinspection PyPep8Naming 60 | class input_to_numpy(Decorator): 61 | def __call__(self, *args, **kwargs): 62 | new_args = [to_numpy(arg) for arg in args] 63 | new_kwargs = {key: to_numpy(value) for key, value in kwargs.items()} 64 | return self.func(*new_args, **new_kwargs) 65 | 66 | 67 | # noinspection PyPep8Naming 68 | class output_to_numpy(Decorator): 69 | def __call__(self, *args, **kwargs): 70 | outputs = self.func(*args, **kwargs) 71 | if isinstance(outputs, torch.Tensor): 72 | return to_numpy(outputs) 73 | if isinstance(outputs, tuple): 74 | new_outputs = tuple([to_numpy(item) for item in outputs]) 75 | return new_outputs 76 | return outputs 77 | 78 | 79 | # noinspection PyPep8Naming 80 | class none_if_missing_arg(Decorator): 81 | def __call__(self, *args, **kwargs): 82 | for arg in chain(args, kwargs.values()): 83 | if arg is None: 84 | return None 85 | 86 | return self.func(*args, **kwargs) 87 | -------------------------------------------------------------------------------- /comb_modules/dijkstra.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import heapq 3 | import 
torch
4 | from functools import partial
5 | from comb_modules.utils import get_neighbourhood_func
6 | from collections import namedtuple
7 | from utils import maybe_parallelize
8 | 
9 | DijkstraOutput = namedtuple("DijkstraOutput", ["shortest_path", "is_unique", "transitions"])
10 | 
11 | 
12 | def dijkstra(matrix, neighbourhood_fn="8-grid", request_transitions=False):
13 | 
14 |     x_max, y_max = matrix.shape
15 |     neighbors_func = partial(get_neighbourhood_func(neighbourhood_fn), x_max=x_max, y_max=y_max)
16 | 
17 |     costs = np.full_like(matrix, 1.0e10)
18 |     costs[0][0] = matrix[0][0]
19 |     num_path = np.zeros_like(matrix)
20 |     num_path[0][0] = 1
21 |     priority_queue = [(matrix[0][0], (0, 0))]
22 |     certain = set()
23 |     transitions = dict()
24 | 
25 |     while priority_queue:
26 |         cur_cost, (cur_x, cur_y) = heapq.heappop(priority_queue)
27 |         if (cur_x, cur_y) in certain:
28 |             continue  # skip already-finalized vertices; re-expanding them would inflate num_path
29 | 
30 |         for x, y in neighbors_func(cur_x, cur_y):
31 |             if (x, y) not in certain:
32 |                 if matrix[x][y] + costs[cur_x][cur_y] < costs[x][y]:
33 |                     costs[x][y] = matrix[x][y] + costs[cur_x][cur_y]
34 |                     heapq.heappush(priority_queue, (costs[x][y], (x, y)))
35 |                     transitions[(x, y)] = (cur_x, cur_y)
36 |                     num_path[x, y] = num_path[cur_x, cur_y]
37 |                 elif matrix[x][y] + costs[cur_x][cur_y] == costs[x][y]:
38 |                     num_path[x, y] += 1
39 | 
40 |         certain.add((cur_x, cur_y))
41 |     # retrieve the path
42 |     cur_x, cur_y = x_max - 1, y_max - 1
43 |     on_path = np.zeros_like(matrix)
44 |     on_path[-1][-1] = 1
45 |     while (cur_x, cur_y) != (0, 0):
46 |         cur_x, cur_y = transitions[(cur_x, cur_y)]
47 |         on_path[cur_x, cur_y] = 1.0
48 | 
49 |     is_unique = num_path[-1, -1] == 1
50 | 
51 |     if request_transitions:
52 |         return DijkstraOutput(shortest_path=on_path, is_unique=is_unique, transitions=transitions)
53 |     else:
54 |         return DijkstraOutput(shortest_path=on_path, is_unique=is_unique, transitions=None)
55 | 
56 | 
57 | def get_solver(neighbourhood_fn):
58 |     def solver(matrix):
59 |         return dijkstra(matrix, neighbourhood_fn).shortest_path
60 | 
61 |     return solver
62 | 
63 | 
64 | class ShortestPath(torch.autograd.Function):
65 |     def __init__(self, lambda_val, neighbourhood_fn="8-grid"):
66 |         self.lambda_val = lambda_val
67 |         self.neighbourhood_fn = neighbourhood_fn
68 |         self.solver = get_solver(neighbourhood_fn)
69 | 
70 |     def forward(self, weights):
71 |         self.weights = weights.detach().cpu().numpy()
72 |         self.suggested_tours = np.asarray(maybe_parallelize(self.solver, arg_list=list(self.weights)))
73 |         return torch.from_numpy(self.suggested_tours).float().to(weights.device)
74 | 
75 |     def backward(self, grad_output):
76 |         assert grad_output.shape == self.suggested_tours.shape
77 |         grad_output_numpy = grad_output.detach().cpu().numpy()
78 |         weights_prime = np.maximum(self.weights + self.lambda_val * grad_output_numpy, 0.0)  # perturbed costs
79 |         better_paths = np.asarray(maybe_parallelize(self.solver, arg_list=list(weights_prime)))
80 |         gradient = -(self.suggested_tours - better_paths) / self.lambda_val  # difference of solver outputs gives the informative gradient
81 |         return torch.from_numpy(gradient).to(grad_output.device)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from logging import WARNING
3 | import warnings
4 | warnings.filterwarnings("ignore")
5 | 
6 | import psutil
7 | import ray
8 | 
9 | from logger import Logger
10 | from utils import set_seed, save_metrics_params, update_params_from_cmdline, save_settings_to_json
11 | 
12 | import warcraft_shortest_path.data_utils as 
warcraft_shortest_path_data 13 | import warcraft_shortest_path.trainers as warcraft_shortest_path_trainers 14 | 15 | dataset_loaders = { 16 | "warcraft_shortest_path": warcraft_shortest_path_data.load_dataset 17 | } 18 | 19 | trainer_loaders = { 20 | "warcraft_shortest_path": warcraft_shortest_path_trainers.get_trainer 21 | } 22 | 23 | required_top_level_params = [ 24 | "model_dir", 25 | "seed", 26 | "loader_params", 27 | "problem_type", 28 | "trainer_name", 29 | "trainer_params", 30 | "num_epochs", 31 | "evaluate_every", 32 | "save_visualizations" 33 | ] 34 | optional_top_level_params = ["num_cpus", "use_ray", "default_json", "id", "fast_mode", "fast_forward_training"] 35 | 36 | def verify_top_level_params(**kwargs): 37 | for kwarg in kwargs: 38 | if kwarg not in required_top_level_params and kwarg not in optional_top_level_params: 39 | raise ValueError("Unknown top_level argument: {}".format(kwarg)) 40 | 41 | for required in required_top_level_params: 42 | if required not in kwargs.keys(): 43 | raise ValueError("Missing required argument: {}".format(required)) 44 | 45 | def main(): 46 | params = update_params_from_cmdline(verbose=True) 47 | os.makedirs(params.model_dir, exist_ok=True) 48 | save_settings_to_json(params, params.model_dir) 49 | 50 | num_cpus = params.get("num_cpus", psutil.cpu_count(logical=True)) 51 | use_ray = params.get("use_ray", False) 52 | fast_forward_training = params.get("fast_forward_training", False) 53 | if use_ray: 54 | ray.init( 55 | num_cpus=num_cpus, 56 | logging_level=WARNING, 57 | ignore_reinit_error=True, 58 | redis_max_memory=10 ** 9, 59 | log_to_driver=False, 60 | **params.get("ray_params", {}) 61 | ) 62 | 63 | set_seed(params.seed) 64 | 65 | Logger.configure(params.model_dir, "tensorboard") 66 | 67 | dataset_loader = dataset_loaders[params.problem_type] 68 | train_iterator, test_iterator, metadata = dataset_loader(**params.loader_params) 69 | 70 | trainer_class = trainer_loaders[params.problem_type](params.trainer_name) 71 | 72 | fast_mode = params.get("fast_mode", False) 73 | trainer = trainer_class( 74 | train_iterator=train_iterator, 75 | test_iterator=test_iterator, 76 | metadata=metadata, 77 | fast_mode=fast_mode, 78 | **params.trainer_params 79 | ) 80 | train_results = {} 81 | for i in range(params.num_epochs): 82 | if i % params.evaluate_every == 0: 83 | eval_results = trainer.evaluate() 84 | print(eval_results) 85 | 86 | train_results = trainer.train_epoch() 87 | if train_results["train_accuracy"] > 0.999 and fast_forward_training: 88 | print(f'Reached train accuracy of {train_results["train_accuracy"]}. 
Fast forwarding.')
89 |             break
90 | 
91 | 
92 |     eval_results = trainer.evaluate()
93 |     print(eval_results)
94 |     train_results = train_results or {}
95 |     save_metrics_params(params=params, metrics={**eval_results, **train_results})
96 | 
97 |     if params.save_visualizations:
98 |         print("Saving visualization images")
99 |         trainer.log_visualization()
100 | 
101 |     if use_ray:
102 |         ray.shutdown()
103 | 
104 | 
105 | if __name__ == "__main__":
106 |     main()
--------------------------------------------------------------------------------
/warcraft_shortest_path/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image, ImageDraw
3 | from decorators import input_to_numpy
4 | from comb_modules.dijkstra import dijkstra
5 | from utils import concat_2d
6 | 
7 | @input_to_numpy
8 | def draw_paths_on_image(image, true_path, suggested_path, scaling_factor):
9 |     transpose = True
10 |     if len(image.shape) == 3 and image.shape[2] == 4:
11 |         transpose = False
12 | 
13 |     image = preprocess_image(image, scaling_factor, transpose=transpose)
14 | 
15 |     true_transitions, is_valid_shortest_path_true = get_transitions_from_path(true_path)
16 |     sug_transitions, is_valid_shortest_path_sug = get_transitions_from_path(suggested_path)
17 | 
18 |     visualized = draw_paths_on_image_as_line(image=image,
19 |                                              transitions=true_transitions,
20 |                                              grid_shape=true_path.shape, sf=scaling_factor,
21 |                                              color="#8fb032")
22 |     visualized = draw_paths_on_image_as_dots(image=visualized,
23 |                                              path=suggested_path,
24 |                                              sf=scaling_factor, color="#e19c24")
25 |     if transpose:
26 |         visualized = postprocess_image(visualized)
27 |     return visualized
28 | 
29 | def get_transitions_from_path(path):
30 |     inverted_path = 1.-path
31 |     shortest_path, _, transitions = dijkstra(inverted_path, request_transitions=True)
32 |     is_valid_shortest_path = np.min(shortest_path == path)
33 |     return transitions, is_valid_shortest_path
34 | 
35 | 
36 | def draw_paths_on_image_as_line(image, transitions, grid_shape, sf, color):
37 |     im_width, im_height = image.size
38 |     draw = ImageDraw.Draw(image)
39 |     grid_x_max, grid_y_max = grid_shape
40 | 
41 |     cur_x, cur_y = grid_x_max - 1, grid_y_max - 1
42 |     while (cur_x, cur_y) != (0, 0):
43 |         next_y, next_x = transitions[(cur_y, cur_x)]
44 |         cur_x_im, cur_y_im, _, _ = grid_to_im_coordinate(
45 |             cur_x, cur_y, grid_x_max, grid_y_max, im_width, im_height
46 |         )
47 |         next_x_im, next_y_im, _, _ = grid_to_im_coordinate(
48 |             next_x, next_y, grid_x_max, grid_y_max, im_width, im_height
49 |         )
50 |         draw.line([(cur_x_im, cur_y_im), (next_x_im, next_y_im)],
51 |                   fill=color, width=sf)
52 |         cur_x, cur_y = next_x, next_y
53 |     return image
54 | 
55 | 
56 | def draw_paths_on_image_as_dots(image, path, sf, color):
57 |     im_width, im_height = image.size
58 |     draw = ImageDraw.Draw(image)
59 |     grid_x_max, grid_y_max = path.shape
60 | 
61 |     for x, y in np.ndindex(path.shape):
62 |         if path[y][x]:
63 |             x_im, y_im, x_spacing, y_spacing = grid_to_im_coordinate(
64 |                 x, y, grid_x_max, grid_y_max, im_width, im_height
65 |             )
66 |             disp = min(x_spacing, y_spacing) // 8
67 |             draw.ellipse([(x_im - disp, y_im - disp), (x_im + disp, y_im + disp)],
68 |                          outline=color, width=sf)
69 |     return image
70 | 
71 | def grid_to_im_coordinate(grid_x, grid_y, grid_x_max, grid_y_max, im_width, im_height):
72 |     x_spacing = im_width // (grid_x_max)
73 |     im_x = x_spacing * grid_x + x_spacing // 2
74 |     y_spacing = im_height // (grid_y_max)
75 |     im_y = y_spacing * grid_y + y_spacing // 2
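    # the spacings are returned as well so that callers
    # (e.g. draw_paths_on_image_as_dots above) can size markers relative to one grid cell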
76 |     return im_x, im_y, x_spacing, y_spacing
77 | 
78 | 
79 | def preprocess_image(image, scaling_factor, transpose):
80 |     if len(image.shape) == 5:  # grid of images
81 |         image = concat_2d(image)
82 |     if transpose:
83 |         image = np.moveaxis(image, 0, 2)
84 |     im = Image.fromarray(image, mode=None)
85 |     # im = im.resize(tuple(scaling_factor * x for x in im.size), resample=Image.NEAREST, box=None)
86 |     return im
87 | 
88 | def postprocess_image(image):
89 |     image = np.moveaxis(np.array(image), 2, 0)
90 |     return image
91 | 
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tensorboardX import SummaryWriter
3 | import numpy as np
4 | from collections import defaultdict
5 | 
6 | from decorators import input_to_numpy
7 | 
8 | 
9 | class Logger:
10 | 
11 |     valid_outputs = ["tensorboard", "stdout", "csv"]
12 | 
13 |     logging_dir = None
14 |     summary_writer = None
15 |     step_for_key = defaultdict(lambda: 0)  # initialize the step to 0 for every key
16 |     default_output = None
17 | 
18 |     def __init__(self, scope, subdir=False, default_output=None):
19 |         self.scope = scope
20 |         self.subdir = subdir
21 |         self.default_output = default_output or Logger.default_output
22 | 
23 |         if self.subdir:
24 |             self.local_logging_dir = os.path.join(self.logging_dir, self.scope)
25 |             self.local_summary_writer = SummaryWriter(os.path.join(self.local_logging_dir, "events"))
26 | 
27 |     @classmethod
28 |     def configure(cls, logging_dir, default_output):
29 |         cls.logging_dir = logging_dir
30 |         cls.summary_writer = SummaryWriter(os.path.join(cls.logging_dir, "events"), flush_secs=30)
31 |         if default_output not in cls.valid_outputs:
32 |             raise NotImplementedError(f"{default_output} is not a valid output")
33 |         else:
34 |             cls.default_output = default_output
35 | 
36 |     def infer_datatype(self, data):
37 |         if np.isscalar(data):
38 |             return "scalar"
39 |         elif isinstance(data, np.ndarray):
40 |             if data.ndim == 0:
41 |                 return "scalar"
42 |             elif data.ndim == 1:
43 |                 if data.size == 1:
44 |                     return "scalar"
45 |                 if data.size > 1:
46 |                     return "histogram"
47 |             elif data.ndim == 2:
48 |                 return "image"
49 |             elif data.ndim == 3:
50 |                 return "image"
51 |             else:
52 |                 raise NotImplementedError("Numpy arrays with more than 3 dimensions are not supported")
53 |         else:
54 |             raise NotImplementedError(f"Data type {type(data)} not understood.")
55 | 
56 |     @input_to_numpy
57 |     def log(self, data, key=None, data_type=None, to_tensorboard=None, to_stdout=None, to_csv=None):
58 |         if data_type is None:
59 |             data_type = self.infer_datatype(data)
60 | 
61 |         output_callables = []
62 |         if to_tensorboard or (to_tensorboard is None and self.default_output == "tensorboard"):
63 |             output_callables.append(self.to_tensorboard)
64 |         if to_stdout or (to_stdout is None and self.default_output == "stdout"):
65 |             output_callables.append(self.to_stdout)
66 |         if to_csv or (to_csv is None and self.default_output == "csv"):
67 |             output_callables.append(self.to_csv)
68 | 
69 |         for output_callable in output_callables:
70 |             output_callable(key, data_type, data)
71 | 
72 |     def to_tensorboard(self, key, data_type, data):
73 |         if key is None:
74 |             raise ValueError("Logging to tensorboard requires a valid key")
75 | 
76 |         if self.subdir:
77 |             summary_writer = self.local_summary_writer
78 |         else:
79 |             summary_writer = self.summary_writer
80 | 
81 |         step = self.step_for_key[key]
82 | 
83 |         self.step_for_key[key] += 1
84 | 
85 |         if data_type == "scalar":
86 |             data_specific_writer_callable = summary_writer.add_scalar
87 |         elif data_type == "histogram":
88 |             data_specific_writer_callable = summary_writer.add_histogram
89 |         elif data_type == "image":
90 |             data_specific_writer_callable = summary_writer.add_image
91 |         else:
92 |             raise NotImplementedError(f"Summary writer does not support type {data_type}")
93 | 
94 |         data_specific_writer_callable(self.scope + "/" + key, data, step)
95 | 
96 |     def to_stdout(self, key, data_type, data):
97 |         # if not data_type == "scalar":
98 |         #     raise NotImplementedError("Only data type 'scalar' supported for stdout output")
99 | 
100 |         print(f"[{self.scope}] {key}: {data}")
101 | 
102 |     def to_csv(self, key, data_type, data):
103 |         raise NotImplementedError("CSV output is not implemented, yet")
104 | 
105 |     def __del__(self, *args, **kwargs):
106 |         if self.summary_writer is not None:
107 |             self.summary_writer.close()
108 |             self.summary_writer = None
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision
7 | 
8 | 
9 | def get_model(model_name, out_features, in_channels, arch_params):
10 |     preloaded_models = {"ResNet18": torchvision.models.resnet18}
11 | 
12 |     own_models = {"ConvNet": ConvNet, "MLP": MLP, "PureConvNet": PureConvNet, "CombResnet18": CombResnet18}
13 | 
14 |     if model_name in preloaded_models:
15 |         model = preloaded_models[model_name](pretrained=False, num_classes=out_features, **arch_params)
16 | 
17 |         # Hacking ResNets to expect 'in_channels' input channels (and not three)
18 |         del model.conv1
19 |         model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
20 |         return model
21 |     elif model_name in own_models:
22 |         return own_models[model_name](out_features=out_features, in_channels=in_channels, **arch_params)
23 |     else:
24 |         raise ValueError(f"Model name {model_name} not recognized!")
25 | 
26 | 
27 | def dim_after_conv2D(input_dim, stride, kernel_size):
28 |     return (input_dim - kernel_size + 2) // stride
29 | 
30 | 
31 | class CombResnet18(nn.Module):
32 | 
33 |     def __init__(self, out_features, in_channels):
34 |         super().__init__()
35 |         self.resnet_model = torchvision.models.resnet18(pretrained=False, num_classes=out_features)
36 |         del self.resnet_model.conv1
37 |         self.resnet_model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
38 |         output_shape = (int(sqrt(out_features)), int(sqrt(out_features)))
39 |         self.pool = nn.AdaptiveMaxPool2d(output_shape)
40 |         # self.last_conv = nn.Conv2d(128, 1, kernel_size=1, stride=1)
41 | 
42 | 
43 |     def forward(self, x):
44 |         x = self.resnet_model.conv1(x)
45 |         x = self.resnet_model.bn1(x)
46 |         x = self.resnet_model.relu(x)
47 |         x = self.resnet_model.maxpool(x)
48 |         x = self.resnet_model.layer1(x)
49 |         # x = self.resnet_model.layer2(x)
50 |         # x = self.resnet_model.layer3(x)
51 |         # x = self.last_conv(x)
52 |         x = self.pool(x)
53 |         x = x.mean(dim=1)  # average over channels -> one cost value per grid vertex
54 |         return x
55 | 
56 | 
57 | class ConvNet(torch.nn.Module):
58 |     def __init__(self, out_features, in_channels, kernel_size, stride, linear_layer_size, channels_1, channels_2):
59 |         super().__init__()
60 |         self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=channels_1, kernel_size=kernel_size, stride=stride)
61 |         self.conv2 = nn.Conv2d(in_channels=channels_1, out_channels=channels_2, kernel_size=kernel_size, stride=stride)
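        # the adaptive pooling below fixes the feature map to 4x4,
        # so fc1's input size is independent of the input image resolution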
62 | 63 | output_shape = (4, 4) 64 | self.pool = nn.AdaptiveAvgPool2d(output_shape) 65 | 66 | self.fc1 = nn.Linear(in_features=output_shape[0] * output_shape[1] * channels_2, out_features=linear_layer_size) 67 | self.fc2 = nn.Linear(in_features=linear_layer_size, out_features=out_features) 68 | 69 | def forward(self, x): 70 | batch_size = x.shape[0] 71 | x = F.relu(self.conv1(x)) 72 | x = F.max_pool2d(x, 2, 2) 73 | x = F.relu(self.conv2(x)) 74 | x = self.pool(x) 75 | x = x.view(batch_size, -1) 76 | x = F.relu(self.fc1(x)) 77 | x = self.fc2(x) 78 | return x 79 | 80 | 81 | class MLP(torch.nn.Module): 82 | def __init__(self, out_features, in_channels, hidden_layer_size): 83 | super().__init__() 84 | input_dim = in_channels * 40 * 20 85 | self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_layer_size) 86 | self.fc2 = nn.Linear(in_features=hidden_layer_size, out_features=out_features) 87 | 88 | def forward(self, x): 89 | batch_size = x.shape[0] 90 | x = x.view(batch_size, -1) 91 | x = torch.tanh(self.fc1(x)) 92 | x = self.fc2(x) 93 | return x 94 | 95 | 96 | class PureConvNet(torch.nn.Module): 97 | 98 | act_funcs = {"relu": F.relu, "tanh": F.tanh, "identity": lambda x: x} 99 | 100 | def __init__(self, out_features, pooling, use_second_conv, kernel_size, in_channels, 101 | channels_1=20, channels_2=20, act_func="relu"): 102 | super().__init__() 103 | self.use_second_conv = use_second_conv 104 | 105 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=channels_1, kernel_size=kernel_size, stride=1) 106 | self.conv2 = nn.Conv2d(in_channels=channels_1, out_channels=channels_2, kernel_size=kernel_size, stride=1) 107 | 108 | output_shape = (int(sqrt(out_features)), int(sqrt(out_features))) 109 | if pooling == "average": 110 | self.pool = nn.AdaptiveAvgPool2d(output_shape) 111 | elif pooling == "max": 112 | self.pool = nn.AdaptiveMaxPool2d(output_shape) 113 | 114 | self.conv3 = nn.Conv2d(in_channels=channels_2 if use_second_conv else channels_1, 115 | out_channels=1, kernel_size=1, stride=1) 116 | self.act_func = PureConvNet.act_funcs[act_func] 117 | 118 | def forward(self, x): 119 | x = self.act_func(self.conv1(x)) 120 | if self.use_second_conv: 121 | x = self.act_func(self.conv2(x)) 122 | x = self.pool(x) 123 | x = self.conv3(x) 124 | return x 125 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import itertools 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from PIL import Image, ImageDraw 6 | from mpl_toolkits.basemap import Basemap 7 | from matplotlib.offsetbox import OffsetImage, AnnotationBbox 8 | 9 | 10 | def neighbours_8(x, y, x_max, y_max): 11 | deltas_x = (-1, 0, 1) 12 | deltas_y = (-1, 0, 1) 13 | for (dx, dy) in itertools.product(deltas_x, deltas_y): 14 | x_new, y_new = x + dx, y + dy 15 | if 0 <= x_new < x_max and 0 <= y_new < y_max and (dx, dy) != (0, 0): 16 | yield x_new, y_new 17 | 18 | 19 | def neighbours_4(x, y, x_max, y_max): 20 | for (dx, dy) in [(1, 0), (0, 1), (0, -1), (-1, 0)]: 21 | x_new, y_new = x + dx, y + dy 22 | if 0 <= x_new < x_max and 0 <= y_new < y_max and (dx, dy) != (0, 0): 23 | yield x_new, y_new 24 | 25 | 26 | def get_neighbourhood_func(neighbourhood_fn): 27 | if neighbourhood_fn == "4-grid": 28 | return neighbours_4 29 | elif neighbourhood_fn == "8-grid": 30 | return neighbours_8 31 | else: 32 | raise Exception(f"neighbourhood_fn of {neighbourhood_fn} not possible") 33 | 34 | 
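# Quick illustration (values verifiable by hand): the corner (0, 0) of a
# 3x3 grid has two 4-neighbours and three 8-neighbours:
#   >>> sorted(neighbours_4(0, 0, 3, 3))
#   [(0, 1), (1, 0)]
#   >>> sorted(neighbours_8(0, 0, 3, 3))
#   [(0, 1), (1, 0), (1, 1)]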
35 | def edges_from_vertex(x, y, N, neighbourhood_fn): 36 | v = (x, y) 37 | neighbours = get_neighbourhood_func(neighbourhood_fn)(*v, x_max=N, y_max=N) 38 | v_edges = [ 39 | (*v, *vn) for vn in neighbours if vertex_index(v, N) < vertex_index(vn, N) 40 | ] # Enforce ordering on vertices 41 | return v_edges 42 | 43 | 44 | def vertex_index(v, dim): 45 | x, y = v 46 | return x * dim + y 47 | 48 | 49 | @functools.lru_cache(32) 50 | def edges_from_grid(N, neighbourhood_fn): 51 | all_vertices = itertools.product(range(N), range(N)) 52 | all_edges = [edges_from_vertex(x, y, N, neighbourhood_fn=neighbourhood_fn) for x, y in all_vertices] 53 | all_edges_flat = sum(all_edges, []) 54 | all_edges_flat_unique = list(set(all_edges_flat)) 55 | return np.asarray(all_edges_flat_unique) 56 | 57 | 58 | def perfect_matching_vis(grid_img, grid_dim, labels, color=(0, 255, 255), width=2, offset=0): 59 | edges = edges_from_grid(grid_dim, neighbourhood_fn='4-grid') 60 | pixels_per_cell = int(grid_img.shape[0] / grid_dim) 61 | 62 | img = Image.fromarray(np.uint8(grid_img.squeeze())).convert("RGB") 63 | for i, (y1, x1, y2, x2) in enumerate(edges): 64 | if labels[i]: 65 | draw = ImageDraw.Draw(img) 66 | if x1 == x2: 67 | draw.line( 68 | (x1 * pixels_per_cell + pixels_per_cell / 2, y1 * pixels_per_cell + pixels_per_cell / 2 + offset, 69 | x2 * pixels_per_cell + pixels_per_cell / 2, y2 * pixels_per_cell + pixels_per_cell / 2 - offset), 70 | fill=color, width=width) 71 | else: 72 | draw.line( 73 | (x1 * pixels_per_cell + pixels_per_cell / 2 + offset, y1 * pixels_per_cell + pixels_per_cell / 2, 74 | x2 * pixels_per_cell + pixels_per_cell / 2 - offset, y2 * pixels_per_cell + pixels_per_cell / 2), 75 | fill=color, width=width) 76 | del draw 77 | 78 | return np.asarray(img, dtype=np.uint8) 79 | 80 | 81 | def plot_tsp_path(gps, flags, tsp_tour): 82 | plt.figure(figsize=(24, 12)) 83 | m = Basemap(projection='ortho', lon_0=20.0, lat_0=20.0, resolution=None) 84 | m.shadedrelief() 85 | 86 | for (lon, lat), flag in zip(gps, flags): 87 | x, y = m(lon, lat) 88 | im = OffsetImage(flag[..., ::-1], zoom=0.8) 89 | ab = AnnotationBbox(im, (x, y), xycoords='data', frameon=False) 90 | m._check_ax().add_artist(ab) 91 | 92 | num_countries = len(tsp_tour) 93 | last_country = 0 94 | current_country = 0 95 | path_indices = [0] 96 | for _ in range(num_countries): 97 | for j in range(num_countries): 98 | if tsp_tour[current_country][j] and j != last_country: 99 | last_country = current_country 100 | current_country = j 101 | path_indices.append(current_country) 102 | break 103 | 104 | lat = [gps[i][1] for i in path_indices] 105 | lon = [gps[i][0] for i in path_indices] 106 | 107 | x, y = m(lon, lat) 108 | m.plot(x, y, 'o-', markersize=5, linewidth=3) 109 | 110 | plt.title("Country locations with TSP solution") 111 | plt.show() 112 | 113 | 114 | # helper functions, you need to install tqdm for progress bar feature 115 | import urllib.request 116 | import numpy as np 117 | from matplotlib import pyplot as plt 118 | 119 | try: 120 | 121 | from tqdm import tqdm 122 | class DownloadProgressBar(tqdm): 123 | def update_to(self, b=1, bsize=1, tsize=None): 124 | if tsize is not None: 125 | self.total = tsize 126 | self.update(b * bsize - self.n) 127 | 128 | 129 | def download_url(url, output_path): 130 | with DownloadProgressBar(unit='B', unit_scale=True, 131 | miniters=1, desc=url.split('/')[-1]) as t: 132 | urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to) 133 | except ModuleNotFoundError as e: 134 | print("Not using progress 
bar") 135 | def download_url(url, output_path): 136 | urllib.request.urlretrieve(url, filename=output_path) -------------------------------------------------------------------------------- /warcraft_shortest_path/trainers.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from comb_modules.utils import cached_vertex_grid_to_edges, cached_vertex_grid_to_edges_grid_coords 4 | import time 5 | from abc import ABC, abstractmethod 6 | 7 | import torch 8 | from comb_modules.losses import HammingLoss 9 | from comb_modules.dijkstra import ShortestPath 10 | from logger import Logger 11 | from models import get_model 12 | from utils import AverageMeter, optimizer_from_string, customdefaultdict 13 | from decorators import to_tensor, to_numpy 14 | from . import metrics 15 | from .metrics import compute_metrics 16 | import numpy as np 17 | from collections import defaultdict 18 | def get_trainer(trainer_name): 19 | trainers = {"Baseline": BaselineTrainer, "DijkstraOnFull": DijkstraOnFull} 20 | return trainers[trainer_name] 21 | from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR 22 | from .visualization import draw_paths_on_image 23 | 24 | class ShortestPathAbstractTrainer(ABC): 25 | def __init__( 26 | self, 27 | *, 28 | train_iterator, 29 | test_iterator, 30 | metadata, 31 | use_cuda, 32 | batch_size, 33 | optimizer_name, 34 | optimizer_params, 35 | model_params, 36 | fast_mode, 37 | neighbourhood_fn, 38 | preload_batch, 39 | lr_milestone_1, 40 | lr_milestone_2, 41 | use_lr_scheduling 42 | ): 43 | 44 | self.fast_mode = fast_mode 45 | self.use_cuda = use_cuda 46 | self.optimizer_params = optimizer_params 47 | self.batch_size = batch_size 48 | self.test_iterator = test_iterator 49 | self.train_iterator = train_iterator 50 | self.metadata = metadata 51 | self.grid_dim = int(np.sqrt(self.metadata["output_features"])) 52 | self.neighbourhood_fn = neighbourhood_fn 53 | self.preload_batch = preload_batch 54 | 55 | self.model = None 56 | self.build_model(**model_params) 57 | 58 | if self.use_cuda: 59 | self.model.to("cuda") 60 | self.optimizer = optimizer_from_string(optimizer_name)(self.model.parameters(), **optimizer_params) 61 | self.use_lr_scheduling = use_lr_scheduling 62 | if use_lr_scheduling: 63 | self.scheduler = MultiStepLR(self.optimizer, milestones=[lr_milestone_1, lr_milestone_2], gamma=0.1) 64 | self.epochs = 0 65 | self.train_logger = Logger(scope="training", default_output="tensorboard") 66 | self.val_logger = Logger(scope="validation", default_output="tensorboard") 67 | 68 | def train_epoch(self): 69 | self.epochs += 1 70 | batch_time = AverageMeter("Batch time") 71 | data_time = AverageMeter("Data time") 72 | cuda_time = AverageMeter("Cuda time") 73 | avg_loss = AverageMeter("Loss") 74 | avg_accuracy = AverageMeter("Accuracy") 75 | avg_perfect_accuracy = AverageMeter("Perfect Accuracy") 76 | 77 | avg_metrics = customdefaultdict(lambda k: AverageMeter("train_"+k)) 78 | 79 | 80 | self.model.train() 81 | 82 | end = time.time() 83 | 84 | iterator = self.train_iterator.get_epoch_iterator(batch_size=self.batch_size, number_of_epochs=1, device='cuda' if self.use_cuda else 'cpu', preload=self.preload_batch) 85 | for i, data in enumerate(iterator): 86 | input, true_path, true_weights = data["images"], data["labels"], data["true_weights"] 87 | 88 | if i == 0: 89 | self.log(data, train=True) 90 | cuda_begin = time.time() 91 | cuda_time.update(time.time()-cuda_begin) 92 | 93 | # measure data loading time 94 | 
data_time.update(time.time() - end)
95 | 
96 |             loss, accuracy, last_suggestion = self.forward_pass(input, true_path, train=True, i=i)
97 | 
98 |             suggested_path = last_suggestion["suggested_path"]
99 | 
100 |             batch_metrics = metrics.compute_metrics(true_paths=true_path,
101 |                 suggested_paths=suggested_path, true_vertex_costs=true_weights)
102 |             # update batch metrics
103 |             {avg_metrics[k].update(v, input.size(0)) for k, v in batch_metrics.items()}
104 |             assert len(avg_metrics.keys()) > 0
105 | 
106 |             avg_loss.update(loss.item(), input.size(0))
107 |             avg_accuracy.update(accuracy.item(), input.size(0))
108 | 
109 |             # compute gradient and do SGD step
110 |             self.optimizer.zero_grad()
111 |             loss.backward()
112 |             self.optimizer.step()
113 | 
114 |             # measure elapsed time
115 |             batch_time.update(time.time() - end)
116 |             end = time.time()
117 | 
118 |             if self.fast_mode:
119 |                 break
120 | 
121 |         meters = [batch_time, data_time, cuda_time, avg_loss, avg_accuracy]
122 |         meter_str = "\t".join([str(meter) for meter in meters])
123 |         print(f"Epoch: {self.epochs}\t{meter_str}")
124 | 
125 |         if self.use_lr_scheduling:
126 |             self.scheduler.step()
127 |         self.train_logger.log(avg_loss.avg, "loss")
128 |         self.train_logger.log(avg_accuracy.avg, "accuracy")
129 |         for key, avg_metric in avg_metrics.items():
130 |             self.train_logger.log(avg_metric.avg, key=key)
131 | 
132 |         return {
133 |             "train_loss": avg_loss.avg,
134 |             "train_accuracy": avg_accuracy.avg,
135 |             **{"train_"+k: avg_metrics[k].avg for k in avg_metrics.keys()}
136 |         }
137 | 
138 |     def evaluate(self):
139 |         avg_metrics = defaultdict(AverageMeter)
140 | 
141 |         self.model.eval()
142 | 
143 |         iterator = self.test_iterator.get_epoch_iterator(batch_size=self.batch_size, number_of_epochs=1, shuffle=False, device='cuda' if self.use_cuda else 'cpu', preload=self.preload_batch)
144 | 
145 |         for i, data in enumerate(iterator):
146 |             input, true_path, true_weights = (
147 |                 data["images"].contiguous(),
148 |                 data["labels"].contiguous(),
149 |                 data["true_weights"].contiguous(),
150 |             )
151 | 
152 |             if self.use_cuda:
153 |                 input = input.cuda(non_blocking=True)  # 'async' was renamed; it is a reserved word since Python 3.7
154 |                 true_path = true_path.cuda(non_blocking=True)
155 | 
156 |             loss, accuracy, last_suggestion = self.forward_pass(input, true_path, train=False, i=i)
157 |             suggested_path = last_suggestion["suggested_path"]
158 |             data.update(last_suggestion)
159 |             if i == 0:
160 |                 indices_in_batch = random.sample(range(self.batch_size), 4)
161 |                 for num, k in enumerate(indices_in_batch):
162 |                     self.log(data, train=False, k=k, num=num)
163 | 
164 |             evaluated_metrics = metrics.compute_metrics(true_paths=true_path,
165 |                 suggested_paths=suggested_path, true_vertex_costs=true_weights)
166 |             avg_metrics["loss"].update(loss.item(), input.size(0))
167 |             avg_metrics["accuracy"].update(accuracy.item(), input.size(0))
168 |             for key, value in evaluated_metrics.items():
169 |                 avg_metrics[key].update(value, input.size(0))
170 | 
171 |             if self.fast_mode:
172 |                 break
173 | 
174 |         for key, avg_metric in avg_metrics.items():
175 |             self.val_logger.log(avg_metric.avg, key=key)
176 |         avg_metrics_values = dict([(key, avg_metric.avg) for key, avg_metric in avg_metrics.items()])
177 |         return avg_metrics_values
178 | 
179 |     @abstractmethod
180 |     def build_model(self, **kwargs):
181 |         pass
182 | 
183 |     @abstractmethod
184 |     def forward_pass(self, input, true_shortest_paths, train, i):
185 |         pass
186 | 
187 |     def log(self, data, train, k=None, num=None):
188 |         logger = self.train_logger if train else self.val_logger
189 |         if not train:
190 |             image = 
self.metadata['denormalize'](data["images"][k]).squeeze().astype(np.uint8) 191 | suggested_path = data["suggested_path"][k].squeeze() 192 | labels = data["labels"][k].squeeze() 193 | 194 | suggested_path_im = torch.ones((3, *suggested_path.shape))*255*suggested_path.cpu() 195 | labels_im = torch.ones((3, *labels.shape))*255*labels.cpu() 196 | image_with_path = draw_paths_on_image(image=image, true_path=labels, suggested_path=suggested_path, scaling_factor=10) 197 | 198 | logger.log(labels_im.data.numpy().astype(np.uint8), key=f"shortest_path_{num}", data_type="image") 199 | logger.log(suggested_path_im.data.numpy().astype(np.uint8), key=f"suggested_path_{num}", data_type="image") 200 | logger.log(image_with_path, key=f"full_input_with_path{num}", data_type="image") 201 | 202 | 203 | 204 | class BaselineTrainer(ShortestPathAbstractTrainer): 205 | def build_model(self, model_name, arch_params): 206 | grid_dim = int(np.sqrt(self.metadata["output_features"])) 207 | self.model = get_model( 208 | model_name, out_features=self.metadata["output_features"], in_channels=self.metadata["num_channels"], arch_params=arch_params 209 | ) 210 | 211 | def forward_pass(self, input, label, train, i): 212 | output = self.model(input) 213 | output = torch.sigmoid(output) 214 | flat_target = label.view(label.size()[0], -1) 215 | 216 | criterion = torch.nn.BCELoss() 217 | loss = criterion(output, flat_target).mean() 218 | accuracy = (output.round() * flat_target).sum() / flat_target.sum() 219 | 220 | suggested_path = output.view(label.shape).round() 221 | last_suggestion = {"vertex_costs": None, "suggested_path": suggested_path} 222 | 223 | return loss, accuracy, last_suggestion 224 | 225 | 226 | class DijkstraOnFull(ShortestPathAbstractTrainer): 227 | def __init__(self, *, l1_regconst, lambda_val, **kwargs): 228 | super().__init__(**kwargs) 229 | self.l1_regconst = l1_regconst 230 | self.lambda_val = lambda_val 231 | self.solver = ShortestPath(lambda_val=lambda_val, neighbourhood_fn=self.neighbourhood_fn) 232 | self.loss_fn = HammingLoss() 233 | 234 | print("META:", self.metadata) 235 | def build_model(self, model_name, arch_params): 236 | self.model = get_model( 237 | model_name, out_features=self.metadata["output_features"], in_channels=self.metadata["num_channels"], arch_params=arch_params 238 | ) 239 | 240 | def forward_pass(self, input, true_shortest_paths, train, i): 241 | output = self.model(input) 242 | # make grid weights positive 243 | output = torch.abs(output) 244 | weights = output.reshape(-1, output.shape[-1], output.shape[-1]) 245 | 246 | if i == 0 and not train: 247 | print(output[0]) 248 | assert len(weights.shape) == 3, f"{str(weights.shape)}" 249 | shortest_paths = self.solver(weights) 250 | 251 | loss = self.loss_fn(shortest_paths, true_shortest_paths) 252 | 253 | logger = self.train_logger if train else self.val_logger 254 | 255 | last_suggestion = { 256 | "suggested_weights": weights, 257 | "suggested_path": shortest_paths 258 | } 259 | 260 | accuracy = (torch.abs(shortest_paths - true_shortest_paths) < 0.5).to(torch.float32).mean() 261 | extra_loss = self.l1_regconst * torch.mean(output) 262 | loss += extra_loss 263 | 264 | return loss, accuracy, last_suggestion 265 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | import random 5 | import torch 6 | import csv 7 | from PIL import Image 8 | import ray 9 | import 
itertools 10 | from collections import defaultdict, deque 11 | import time 12 | from functools import lru_cache 13 | from constants import * 14 | 15 | import ast 16 | import collections 17 | import json 18 | from copy import deepcopy 19 | from warnings import warn 20 | import numpy as np 21 | 22 | import inspect 23 | import re 24 | import shutil 25 | import tempfile 26 | from time import sleep 27 | 28 | 29 | 30 | class customdefaultdict(defaultdict): 31 | def __missing__(self, key): 32 | if self.default_factory: 33 | dict.__setitem__(self, key, self.default_factory(key)) 34 | return self[key] 35 | else: 36 | defaultdict.__missing__(self, key) 37 | 38 | @lru_cache(maxsize=128) 39 | def cached_np_load(path, **kwargs): 40 | return np.load(path, **kwargs) 41 | 42 | 43 | def efficient_from_numpy(x, device): 44 | if device == 'cpu': 45 | return torch.from_numpy(x).cpu() 46 | else: 47 | return torch.from_numpy(x).contiguous().pin_memory().to(device=device, non_blocking=True) 48 | 49 | 50 | class AverageMeter(object): 51 | """Computes and stores the average and current value""" 52 | 53 | def __init__(self, name=None, fmt=":f"): 54 | self.name = name 55 | self.fmt = fmt 56 | self.reset() 57 | 58 | def reset(self): 59 | self.val = 0 60 | self.avg = 0 61 | self.sum = 0 62 | self.count = 0 63 | 64 | def update(self, val, n=1): 65 | self.val = val 66 | self.sum += val * n 67 | self.count += n 68 | self.avg = self.sum / self.count 69 | 70 | def __str__(self): 71 | fmtstr = "{name} ({avg" + self.fmt + "})" 72 | return fmtstr.format(**self.__dict__) 73 | 74 | 75 | def set_seed(seed): 76 | if seed is not None: 77 | random.seed(seed) 78 | torch.manual_seed(seed) 79 | torch.cuda.manual_seed(seed) 80 | np.random.seed(seed) 81 | 82 | 83 | def save_pickle(data, path): 84 | with open(path, "wb") as fh: 85 | pickle.dump(data, fh) 86 | 87 | 88 | def load_pickle(path): 89 | with open(path, "rb") as fh: 90 | return pickle.load(fh) 91 | 92 | 93 | def load_pngs(path): 94 | def get_im(): 95 | for file in os.listdir(path): 96 | if file.endswith(".png"): 97 | im = Image.open(os.path.join(path, file)) 98 | np_im = np.array(im) 99 | yield np_im 100 | 101 | return np.stack(list(get_im())) 102 | 103 | 104 | def concat_2d(arr): 105 | rows, columns, channels, height, width = arr.shape 106 | return np.rollaxis(arr, 2, 0).swapaxes(2, 3).reshape(channels, height * rows, width * columns) 107 | 108 | 109 | def sample_image_grid(height, width, images, labels): 110 | num_images = images.shape[0] 111 | num_labels = labels.shape[0] 112 | assert num_images == num_labels 113 | 114 | indices = np.random.choice(num_images, size=(height, width)) 115 | return images[indices, ...], labels[indices] 116 | 117 | class TrainingIterator(object): 118 | def __init__(self, data_dict): 119 | zipped_data = list(zip(*data_dict.values())) 120 | 121 | self.dtype = [(key, "f4", value[0].shape) for key, value in data_dict.items()] 122 | # PyTorch works with 32-bit floats by default 123 | 124 | self.array = np.array(zipped_data, dtype=self.dtype) 125 | 126 | def get_epoch_iterator(self, batch_size, number_of_epochs, device='cpu', preload=False, shuffle=True): 127 | def iterator(): 128 | if preload: 129 | preload_deque = deque(maxlen=2) 130 | for i in range(number_of_epochs): 131 | if shuffle: 132 | np.random.shuffle(self.array) 133 | for j in range(1 + len(self.array) // batch_size): 134 | numpy_batch = self.array[j * batch_size : (j + 1) * batch_size] 135 | torch_batch = {key: efficient_from_numpy(numpy_batch[key], device=device) for key in 
152 | def detach_to_cpu_np(arrs):
153 |     detached = [arr.cpu().detach().numpy() for arr in arrs]
154 |     return detached
155 | 
156 | 
157 | def grid_to_im_coordinate(grid_x, grid_y, grid_x_max, grid_y_max, im_width, im_height):
158 |     x_spacing = im_width / grid_x_max
159 |     im_x = x_spacing * (0.5 + grid_x)
160 |     y_spacing = im_height / grid_y_max
161 |     im_y = y_spacing * (0.5 + grid_y)
162 |     return im_x, im_y, x_spacing, y_spacing
163 | 
164 | 
165 | def maybe_parallelize(function, arg_list):
166 |     if ray.is_initialized():
167 |         ray_fn = ray.remote(function)
168 |         return ray.get([ray_fn.remote(arg) for arg in arg_list])
169 |     else:
170 |         return [function(arg) for arg in arg_list]
171 | 
172 | 
173 | def optimizer_from_string(optimizer_name):
174 |     dct = {"Adam": torch.optim.Adam, "SGD": torch.optim.SGD}
175 |     return dct[optimizer_name]
176 | 
177 | 
178 | def all_accuracies(true_labels, suggested_labels, true_costs, is_valid_label_fn, num_thresholds, minimize=True):
179 |     num_examples = len(true_labels)
180 |     valid = 0
181 |     meets_threshold = [0] * num_thresholds
182 |     for true_label, suggested_label, true_cost in zip(true_labels, suggested_labels, true_costs):
183 |         if not is_valid_label_fn(suggested_label):
184 |             continue
185 |         valid += 1
186 |         cost_ratio = np.sum(suggested_label * true_cost) / np.sum(true_label * true_cost)
187 |         if not minimize:
188 |             cost_ratio = 1.0 / cost_ratio
189 | 
190 |         assert cost_ratio > 0.99  # cost is not better than optimal...
191 | 
192 |         for i in range(len(meets_threshold)):
193 |             if cost_ratio - 1.0 < 10.0 ** (-i - 1):
194 |                 meets_threshold[i] += 1
195 | 
196 |     threshold_dict = {f"below_{10. ** (1-i)}_percent_acc": val / num_examples for i, val in enumerate(meets_threshold)}
197 |     threshold_dict['valid_acc'] = valid / num_examples
198 |     return threshold_dict
199 | 
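# A small worked example of all_accuracies (hypothetical paths and costs): the
# second suggestion is 3% more expensive than optimal, so it clears the 10%
# threshold but not the 1% one.
def _all_accuracies_example():
    true_paths = [np.array([1.0, 0.0, 1.0])] * 2
    suggested = [np.array([1.0, 0.0, 1.0]), np.array([0.0, 1.0, 1.0])]
    costs = [np.array([1.0, 1.5, 1.0]), np.array([1.0, 1.06, 1.0])]
    return all_accuracies(true_paths, suggested, costs,
                          is_valid_label_fn=lambda label: True, num_thresholds=2)
    # -> {'below_10.0_percent_acc': 1.0, 'below_1.0_percent_acc': 0.5, 'valid_acc': 1.0}
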
200 | def shorten_string(string, max_len):
201 |     if len(string) > max_len - 3:
202 |         return '...' + string[-max_len + 3:]
203 |     return string
204 | 
205 | 
206 | def get_caller_file(depth=2):
207 |     _, filename, _, _, _, _ = inspect.stack()[depth]
208 |     return filename
209 | 
210 | 
211 | def check_valid_name(string):
212 |     pat = '[A-Za-z0-9_.-]*$'
213 |     if type(string) is not str:
214 |         raise TypeError('Parameter \'{}\' not valid. String expected.'.format(string))
215 |     if string in RESERVED_PARAMS:
216 |         raise ValueError('Parameter name {} is reserved'.format(string))
217 |     if string.endswith(STD_ENDING):
218 |         raise ValueError('Parameter name \'{}\' not valid. '
219 |                          'Ends with \'{}\' (may cause collisions)'.format(string, STD_ENDING))
220 |     if not bool(re.compile(pat).match(string)):
221 |         raise ValueError('Parameter name \'{}\' not valid. Only \'[0-9][a-z][A-Z]_-.\' allowed.'.format(string))
222 |     if string.endswith('.') or string.startswith('.'):
223 |         raise ValueError('Parameter name \'{}\' not valid. \'.\' not allowed at start/end'.format(string))
224 | 
225 | 
226 | def rm_dir_full(dir_name):
227 |     sleep(0.5)
228 |     if os.path.exists(dir_name):
229 |         shutil.rmtree(dir_name, ignore_errors=True)
230 | 
231 |     # filesystem is sometimes slow to respond
232 |     if os.path.exists(dir_name):
233 |         sleep(1.0)
234 |         shutil.rmtree(dir_name, ignore_errors=True)
235 | 
236 |     if os.path.exists(dir_name):
237 |         warn(f'Removal of dir {dir_name} failed')
238 | 
239 | 
240 | def create_dir(dir_name):
241 |     if not os.path.exists(dir_name):
242 |         os.makedirs(dir_name)
243 | 
244 | 
245 | def flatten_nested_string_dict(nested_dict, prepend=''):
246 |     for key, value in nested_dict.items():
247 |         if type(key) is not str:
248 |             raise TypeError('Only strings as keys expected')
249 |         if isinstance(value, dict):
250 |             for sub in flatten_nested_string_dict(value, prepend=prepend + str(key) + OBJECT_SEPARATOR):
251 |                 yield sub
252 |         else:
253 |             yield prepend + str(key), value
254 | 
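# A quick sketch of the flattening convention (hypothetical settings dict):
# nested keys are joined with OBJECT_SEPARATOR ('.'), matching the dotted
# parameter names accepted elsewhere in this module.
def _flatten_example():
    nested = {"optimizer_params": {"lr": 0.0005}, "seed": 1}
    return dict(flatten_nested_string_dict(nested))
    # -> {"optimizer_params.lr": 0.0005, "seed": 1}
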
255 | 
256 | def save_dict_as_one_line_csv(dct, filename):
257 |     with open(filename, 'w') as f:
258 |         writer = csv.DictWriter(f, fieldnames=dct.keys())
259 |         writer.writeheader()
260 |         writer.writerow(dct)
261 | 
262 | 
263 | def get_sample_generator(samples, hyperparam_dict, distribution_list, extra_settings=None):
264 |     if bool(hyperparam_dict) == bool(distribution_list):
265 |         raise TypeError('Exactly one of hyperparam_dict and distribution_list must be provided')
266 |     if distribution_list and not samples:
267 |         raise TypeError('Number of samples not specified')
268 |     if distribution_list:
269 |         ans = distribution_list_sampler(distribution_list, samples)
270 |     elif samples:
271 |         assert hyperparam_dict
272 |         ans = hyperparam_dict_samples(hyperparam_dict, samples)
273 |     else:
274 |         ans = hyperparam_dict_product(hyperparam_dict)
275 |     if extra_settings is not None:
276 |         return itertools.chain(extra_settings, ans)
277 |     else:
278 |         return ans
279 | 
280 | 
281 | 
282 | 
283 | 
284 | def process_other_params(other_params, hyperparam_dict, distribution_list):
285 |     if hyperparam_dict:
286 |         name_list = hyperparam_dict.keys()
287 |     else:
288 |         name_list = [distr.param_name for distr in distribution_list]
289 |     for name, value in other_params.items():
290 |         check_valid_name(name)
291 |         if name in name_list:
292 |             raise ValueError('Duplicate setting \'{}\' in other params!'.format(name))
293 |         if not any([isinstance(value, allowed_type) for allowed_type in PARAM_TYPES]):
294 |             raise TypeError('Settings must be from the following types: {}, not {}'.format(PARAM_TYPES, type(value)))
295 |     nested_items = [(name.split('.'), value) for name, value in other_params.items()]
296 |     return nested_to_dict(nested_items)
297 | 
298 | 
299 | def validate_hyperparam_dict(hyperparam_dict):
300 |     for name, option_list in hyperparam_dict.items():
301 |         check_valid_name(name)
302 |         if type(option_list) is not list:
303 |             raise TypeError('Entries in hyperparam dict must be of type list (not {}: {})'.format(name, type(option_list)))
304 |         for item in option_list:
305 |             if not any([isinstance(item, allowed_type) for allowed_type in PARAM_TYPES]):
306 |                 raise TypeError('Settings must be from the following types: {}, not {}'.format(PARAM_TYPES, type(item)))
307 | 
308 | 
309 | def hyperparam_dict_samples(hyperparam_dict, num_samples):
310 |     validate_hyperparam_dict(hyperparam_dict)
311 |     nested_items = [(name.split(OBJECT_SEPARATOR), options) for name, options in hyperparam_dict.items()]
312 | 
313 |     for i in range(num_samples):
314 |         nested_samples = [(nested_path, random.choice(options)) for nested_path, options in nested_items]
315 |         yield nested_to_dict(nested_samples)
316 | 
317 | 
318 | def hyperparam_dict_product(hyperparam_dict):
319 |     validate_hyperparam_dict(hyperparam_dict)
320 | 
321 |     nested_items = [(name.split(OBJECT_SEPARATOR), options) for name, options in hyperparam_dict.items()]
322 |     nested_names, option_lists = zip(*nested_items)
323 | 
324 |     for sample_from_product in itertools.product(*list(option_lists)):
325 |         yield nested_to_dict(zip(nested_names, sample_from_product))
326 | 
327 | 
328 | def default_to_regular(d):
329 |     if isinstance(d, defaultdict):
330 |         d = {k: default_to_regular(v) for k, v in d.items()}
331 |     return d
332 | 
333 | 
334 | def nested_to_dict(nested_items):
335 |     nested_dict = lambda: defaultdict(nested_dict)
336 |     result = nested_dict()
337 |     for nested_key, value in nested_items:
338 |         ptr = result
339 |         for key in nested_key[:-1]:
340 |             ptr = ptr[key]
341 |         ptr[nested_key[-1]] = value
342 |     return default_to_regular(result)
343 | 
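# A minimal sketch of how dotted hyperparameter names expand into nested
# settings dicts via nested_to_dict (hypothetical values):
def _hyperparam_product_example():
    grid = {"trainer_params.lambda_val": [10.0, 20.0], "seed": [1]}
    return list(hyperparam_dict_product(grid))
    # -> [{'trainer_params': {'lambda_val': 10.0}, 'seed': 1},
    #     {'trainer_params': {'lambda_val': 20.0}, 'seed': 1}]
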
344 | 
345 | def distribution_list_sampler(distribution_list, num_samples):
346 |     for distr in distribution_list:
347 |         distr.prepare_samples(howmany=num_samples)
348 |     for i in range(num_samples):
349 |         nested_items = [(distr.param_name.split(OBJECT_SEPARATOR), distr.sample()) for distr in distribution_list]
350 |         yield nested_to_dict(nested_items)
351 | 
352 | from pathlib2 import Path
353 | home = str(Path.home())
354 | 
355 | def mkdtemp(prefix='cluster_utils', suffix=''):
356 |     new_prefix = prefix + ('' if not suffix else '-' + suffix + '-')
357 |     return tempfile.mkdtemp(prefix=new_prefix, dir=os.path.join(home, '.cache'))
358 | 
359 | 
360 | def temp_directory(prefix='cluster_utils', suffix=''):
361 |     new_prefix = prefix + ('' if not suffix else '-' + suffix + '-')
362 |     return tempfile.TemporaryDirectory(prefix=new_prefix, dir=os.path.join(home, '.cache'))
363 | 
364 | 
365 | class ParamDict(dict):
366 |     """An immutable dict where elements can be accessed with a dot"""
367 |     __getattr__ = dict.__getitem__
368 | 
369 |     def __delattr__(self, item):
370 |         raise TypeError("Setting object not mutable after settings are fixed!")
371 | 
372 |     def __setattr__(self, key, value):
373 |         raise TypeError("Setting object not mutable after settings are fixed!")
374 | 
375 |     def __setitem__(self, key, value):
376 |         raise TypeError("Setting object not mutable after settings are fixed!")
377 | 
378 |     def __deepcopy__(self, memo):
379 |         """In order to support deepcopy"""
380 |         return ParamDict([(deepcopy(k, memo), deepcopy(v, memo)) for k, v in self.items()])
381 | 
382 |     def __repr__(self):
383 |         return json.dumps(self, indent=4, sort_keys=True)
384 | 
385 | 
386 | def recursive_objectify(nested_dict):
387 |     """Turns a nested_dict into a nested ParamDict"""
388 |     result = deepcopy(nested_dict)
389 |     for k, v in result.items():
390 |         if isinstance(v, collections.Mapping):
391 |             result[k] = recursive_objectify(v)
392 |     return ParamDict(result)
393 | 
394 | 
395 | class SafeDict(dict):
396 |     """A dict prohibiting init from a list of pairs containing duplicates"""
397 |     def __init__(self, *args, **kwargs):
398 |         if args and args[0] and not isinstance(args[0], dict):
399 |             keys, _ = zip(*args[0])
400 |             duplicates = [item for item, count in collections.Counter(keys).items() if count > 1]
401 |             if duplicates:
402 |                 raise TypeError("Keys {} repeated in json parsing".format(duplicates))
403 |         super().__init__(*args, **kwargs)
404 | 
405 | 
406 | def load_json(file):
407 |     """Safe load of a json file (doubled entries raise exception)"""
408 |     with open(file, 'r') as f:
409 |         data = json.load(f, object_pairs_hook=SafeDict)
410 |     return data
411 | 
412 | 
413 | def update_recursive(d, u, defensive=False):
414 |     for k, v in u.items():
415 |         if defensive and k not in d:
416 |             raise KeyError("Updating a non-existing key")
417 |         if isinstance(v, collections.Mapping):
418 |             d[k] = update_recursive(d.get(k, {}), v, defensive=defensive)  # propagate defensive into nested updates
419 |         else:
420 |             d[k] = v
421 |     return d
422 | 
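# A minimal sketch of the nested merge implemented above (hypothetical
# settings): values from `u` win, but untouched nested keys in `d` survive.
def _update_recursive_example():
    base = {"trainer_params": {"batch_size": 70, "use_cuda": True}}
    override = {"trainer_params": {"batch_size": 32}}
    return update_recursive(base, override)
    # -> {"trainer_params": {"batch_size": 32, "use_cuda": True}}
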
423 | 
424 | def save_settings_to_json(setting_dict, model_dir):
425 |     filename = os.path.join(model_dir, JSON_SETTINGS_FILE)
426 |     with open(filename, 'w') as file:
427 |         file.write(json.dumps(setting_dict, sort_keys=True, indent=4))
428 | 
429 | 
430 | def save_metrics_params(metrics, params, save_dir=None):
431 |     if save_dir is None:
432 |         save_dir = params.model_dir
433 |     create_dir(save_dir)
434 |     save_settings_to_json(params, save_dir)
435 | 
436 |     param_file = os.path.join(save_dir, CLUSTER_PARAM_FILE)
437 |     flattened_params = dict(flatten_nested_string_dict(params))
438 |     save_dict_as_one_line_csv(flattened_params, param_file)
439 | 
440 |     time_elapsed = time.time() - update_params_from_cmdline.start_time
441 |     if 'time_elapsed' not in metrics.keys():
442 |         metrics['time_elapsed'] = time_elapsed
443 |     else:
444 |         warn('\'time_elapsed\' metric already taken. Automatic time saving failed.')
445 |     metric_file = os.path.join(save_dir, CLUSTER_METRIC_FILE)
446 |     save_dict_as_one_line_csv(metrics, metric_file)
447 | 
448 | 
449 | def is_json_file(cmd_line):
450 |     try:
451 |         return os.path.isfile(cmd_line)
452 |     except Exception as e:
453 |         warn(f'JSON parsing suppressed exception: {e}')  # warn(msg, e) would misuse e as a warning category
454 |         return False
455 | 
456 | 
457 | def is_parseable_dict(cmd_line):
458 |     try:
459 |         res = ast.literal_eval(cmd_line)
460 |         return isinstance(res, dict)
461 |     except Exception as e:
462 |         warn(f'Dict literal eval suppressed exception: {e}')
463 |         return False
464 | 
465 | 
466 | def update_params_from_cmdline(cmd_line=None, default_params=None, custom_parser=None, verbose=True):
467 |     """Updates default settings based on command line input.
468 | 
469 |     :param cmd_line: Expecting (same format as) sys.argv
470 |     :param default_params: Dictionary of default params
471 |     :param custom_parser: callable that returns a dict of params on success
472 |                           and None on failure (suppress exceptions!)
473 |     :param verbose: Boolean to determine if final settings are pretty printed
474 |     :return: Immutable nested dict with (deep) dot access. Priority: default_params < default_json < cmd_line
475 |     """
476 |     if not cmd_line:
477 |         cmd_line = sys.argv
478 | 
479 |     if default_params is None:
480 |         default_params = {}
481 | 
482 |     if len(cmd_line) < 2:
483 |         cmd_params = {}
484 |     elif custom_parser and custom_parser(cmd_line):  # Custom parsing, typically for flags
485 |         cmd_params = custom_parser(cmd_line)
486 |     elif len(cmd_line) == 2 and is_json_file(cmd_line[1]):
487 |         cmd_params = load_json(cmd_line[1])
488 |     elif len(cmd_line) == 2 and is_parseable_dict(cmd_line[1]):
489 |         cmd_params = ast.literal_eval(cmd_line[1])
490 |     else:
491 |         raise ValueError('Failed to parse command line')
492 | 
493 |     update_recursive(default_params, cmd_params)
494 | 
495 |     if JSON_FILE_KEY in default_params:
496 |         json_params = load_json(default_params[JSON_FILE_KEY])
497 |         if 'default_json' in json_params:
498 |             json_base = load_json(json_params[JSON_FILE_KEY])
499 |         else:
500 |             json_base = {}
501 |         update_recursive(json_base, json_params)
502 |         update_recursive(default_params, json_base)
503 | 
504 |     update_recursive(default_params, cmd_params)
505 |     final_params = recursive_objectify(default_params)
506 |     if verbose:
507 |         print(final_params)
508 | 
509 |     update_params_from_cmdline.start_time = time.time()
510 |     return final_params
511 | 
512 | update_params_from_cmdline.start_time = None
513 | 
--------------------------------------------------------------------------------
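For reference, a minimal sketch of driving update_params_from_cmdline directly (hypothetical defaults; the JSON path is one of the settings files in this repository, and passing it as the single command-line argument is the assumed entry-point convention):

    from utils import update_params_from_cmdline

    params = update_params_from_cmdline(
        cmd_line=["main.py", "settings/warcraft_shortest_path/12x12_combresnet.json"],
        default_params={"num_epochs": 10},
    )
    print(params.trainer_params.lambda_val)  # nested dot access on the returned immutable ParamDict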