├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── quinine ├── __init__.py ├── common │ ├── argparse.py │ ├── cerberus.py │ ├── gin.py │ ├── sweep.py │ └── utils.py ├── examples │ ├── base.yaml │ ├── config.yaml │ ├── run.py │ └── simple.py ├── quinfig.py └── tests │ ├── base.yaml │ ├── bugs │ ├── test-2.yaml │ └── test.yaml │ ├── derived-1-1.yaml │ ├── derived-1-2.yaml │ ├── derived-1.yaml │ └── derived-2.yaml ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | build/ 4 | dist/ 5 | quinine.egg-info/ 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Karan Goel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Quinine 2 | Quinine is a no-nonsense, feature-rich library to create and manage configuration files (called _quinfigs_). 3 | 4 | It's especially well suited to machine learning projects (designed by an ML PhD student @ Stanford aka me) where, 5 | - hyperparameters can be numerous and are naturally nested 6 | - projects are always expanding, so hyperparameters grow 7 | - complicated manual hyperparameter sweeps are the norm 8 | 9 | ## Installation 10 | Install using pip, 11 | ```shell script 12 | pip install quinine 13 | ``` 14 | For the latest version, 15 | ```shell script 16 | pip install git+https://github.com/krandiash/quinine.git --upgrade 17 | ``` 18 | 19 | ## Features 20 | Quinine is simple, powerful and extensible: let's go over all of the features with lots of examples. 21 | 22 | ### Configuration in YAML 23 | Configs are called Quinfigs. The most basic thing you can do is create a _Quinfig_ from a yaml file. 24 | 25 | Here's an example where we use a `config.yaml` file to create a Quinfig. The only rule is that you can't prefix any key 26 | with the `~` character, because that's reserved for sweeps. 27 | 28 | #### **`config.yaml`** 29 | ```yaml 30 | general: 31 | seed: 2 32 | module: test.py 33 | 34 | model: 35 | pretrained: true 36 | ~architecture: resnet50 # <-- no ~ prefix allowed! 37 | architecture: resnet50 # <-- do this instead!
38 | 39 | dataset: 40 | - name: cifar10 41 | - name: imagenet 42 | ``` 43 | 44 | #### **`main.py`** 45 | ```python 46 | # Import the Quinfig class from quinine 47 | from quinine import Quinfig 48 | 49 | # Use the Quinfig class to create a quinfig 50 | # ...you can just pass in the path to the yaml 51 | quinfig = Quinfig(config_path='path/to/config.yaml') 52 | 53 | # Access parameters as keys 54 | assert quinfig['general']['seed'] == 2 55 | 56 | # or use dot access, making your code cleaner 57 | assert quinfig.general.seed == 2 58 | 59 | # dot access works to arbitrary levels of nesting, including through lists 60 | assert quinfig.dataset[0].name == 'cifar10' 61 | 62 | # you can also create a Quinfig directly from a dictionary 63 | quinfig = Quinfig(config={'key': 'value'}) 64 | ``` 65 | 66 | YAMLs are great for writing large, nested configs cleanly, and provide a nice separation from your code. This configuration workflow 67 | (feed yaml to python script) is pretty popular, and if all you wanted was that, Quinine has you covered. 68 | 69 | Read on to see more! 70 | 71 | ### Inheritance for configs 72 | A common use-case in machine learning is performing sweeps or variants on an experiment. 73 | It's often convenient to have to specify _only_ the parameters that need to be changed from some 'base' or template configs. 74 | 75 | 76 | Quinine provides support for inheritance, using the special `inherit` key in the config. 77 | 78 | Here's an example, where we 79 | - first specify a base config called `grandparent.yaml`, 80 | - inherit this config in `parent.yaml` and change a single parameter, 81 | - then inherit _that_ in `config.yaml`, changing another parameter. 
82 | 83 | 84 | #### **`grandparent.yaml`** 85 | ```yaml 86 | general: 87 | seed: 2 88 | module: test.py 89 | 90 | model: 91 | pretrained: true 92 | architecture: resnet50 93 | 94 | dataset: 95 | - name: cifar10 96 | - name: imagenet 97 | ``` 98 | 99 | #### **`parent.yaml (how you write it)`** 100 | ```yaml 101 | inherit: path/to/grandparent.yaml 102 | 103 | # Overwrites the dataset configuration in grandparent.yaml to only train on CIFAR-10 104 | dataset: 105 | - name: cifar10 106 | 107 | # All other configuration options are inherited from grandparent.yaml 108 | ``` 109 | 110 | #### **`parent.yaml (how it actually is)`** 111 | ```yaml 112 | inherit: path/to/grandparent.yaml 113 | 114 | general: 115 | seed: 2 116 | module: test.py 117 | 118 | model: 119 | pretrained: true 120 | architecture: resnet50 121 | 122 | dataset: 123 | - name: cifar10 124 | ``` 125 | 126 | #### **`config.yaml (how you write it)`** 127 | ```yaml 128 | inherit: path/to/parent.yaml 129 | 130 | # Overwrites the model configuration in parent.yaml (which equals its value in grandparent.yaml) to set pretrained to False 131 | model: 132 | pretrained: false 133 | 134 | # All other configuration options are inherited from parent.yaml 135 | ``` 136 | 137 | #### **`config.yaml (how it actually is)`** 138 | ```yaml 139 | inherit: path/to/parent.yaml 140 | 141 | general: 142 | seed: 2 143 | module: test.py 144 | 145 | model: 146 | pretrained: false 147 | architecture: resnet50 148 | 149 | dataset: 150 | - name: cifar10 151 | ``` 152 | 153 | #### **`main.py`** 154 | ```python 155 | # Nothing special needed: just create a quinfig normally 156 | quinfig = Quinfig(config_path='path/to/config.yaml') 157 | 158 | # and things will be resolved correctly 159 | assert quinfig.model.pretrained == False 160 | assert quinfig.model.architecture == 'resnet50' 161 | ``` 162 | 163 | You can also inherit from multiple configs simultaneously (later configs take precedence). 
Here's an example, 164 | 165 | #### **`config.yaml`** 166 | ```yaml 167 | inherit: 168 | - path/to/parent_1.yaml 169 | - path/to/parent_2.yaml 170 | - path/to/parent_3.yaml # later parameters take precedence 171 | 172 | general: 173 | seed: 2 174 | module: test.py 175 | 176 | model: 177 | pretrained: false 178 | architecture: resnet50 179 | 180 | dataset: 181 | - name: cifar10 182 | ``` 183 | 184 | 185 | ### Cerberus schemas for validation 186 | A nice-to-have feature is the ability to validate your config file against a schema. 187 | 188 | If you've ever used `argparse` to configure your scripts, you've been doing this already. In a nutshell, 189 | the schema specifies which hyperparameters the program will accept, and if you pass in something 190 | unexpected (e.g. `architectur` instead of `architecture`), the error is caught (that's called _schema validation_). 191 | 192 | Quinine uses an external library called `Cerberus` to support schema validation for your config files. 193 | Cerberus is great, but it has a bit of a learning curve and a lot of features you'll never actually use. 194 | So to make things easy, Quinine comes with syntactic sugar that will help you write schemas very quickly. 195 | All the functionality available in Cerberus is supported, 196 | but most scenarios are covered with the syntactic sugar provided. 197 | 198 | Another reason to use schemas: you can mark parameters as required, and specify defaults or allowed choices for a parameter's values.
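To see what the sugar expands to: each helper is just a small dictionary fragment, and `merge` layers fragments into an ordinary Cerberus rule. A minimal sketch (plain-dict merging stands in for `funcy.merge`; the helper definitions follow `quinine/common/cerberus.py`):

```python
# Sketch of quinine's schema sugar: each helper is a tiny dict fragment
# (definitions mirror quinine.common.cerberus), and merging fragments
# yields a plain Cerberus rule. A dict comprehension replaces funcy.merge.
tinteger = {'type': 'integer'}
tstring = {'type': 'string'}
tdict = {'type': 'dict'}
required = {'required': True}
default = lambda d: {'default': d}
merge = lambda *ds: {k: v for d in ds for k, v in d.items()}
stdict = lambda s: merge(tdict, {'schema': s})

# 'seed' is an integer defaulting to 0: the fragments compose into one rule
assert merge(tinteger, default(0)) == {'type': 'integer', 'default': 0}
assert merge(tstring, required) == {'type': 'string', 'required': True}

# stdict nests a sub-schema under a dict-typed rule
assert stdict({'seed': merge(tinteger, default(0))}) == {
    'type': 'dict',
    'schema': {'seed': {'type': 'integer', 'default': 0}},
}
```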
199 | 200 | -- 201 | 202 | ```python 203 | from quinine import Quinfig, tstring, tboolean, tinteger, stdict, stlist, default, nullable, required 204 | from funcy import merge 205 | # You should write schemas in Python for reusability (recommended) 206 | 207 | # The model schema contains a single 'pretrained' bool parameter that is required 208 | model_schema = {'pretrained': merge(tboolean, required)} 209 | 210 | # The schema for a single dataset contains its name 211 | dataset_schema = {'name': tstring} 212 | 213 | # The general schema consists of the seed (defaults to 0) and a module name (defaults to None) 214 | general_schema = {'seed': merge(tinteger, default(0)), 215 | 'module': merge(tstring, nullable, default(None))} 216 | 217 | # The overall schema is composed of these three reusable schemas 218 | # Notice that you don't need to provide a schema for templating, Quinine will take care of that 219 | schema = {'general': stdict(general_schema), 220 | 'model': stdict(model_schema), 221 | 'dataset': stlist(dataset_schema)} 222 | 223 | # Just pass in the schema while instantiating the Quinfig: validation happens automatically 224 | quinfig = Quinfig(config_path='path/to/config.yaml', schema=schema) 225 | 226 | # You could also define schemas in YAML, but we recommend using Python to take advantage of the syntactic sugar 227 | quinfig = Quinfig(config_path='path/to/config.yaml', schema_path='path/to/schema') 228 | ``` 229 | 230 | ### QuinineArgumentParser: Override Command-Line Arguments 231 | Quinine also comes with an argument parser that can be used to perform command-line 232 | overrides on top of arguments specified in a config `.yaml` file. 
233 | 234 | ```python 235 | from quinine import QuinineArgumentParser 236 | parser = QuinineArgumentParser(schema=your_schema) # a schema is required to override config parameters from the command line 237 | quinfig = parser.parse_quinfig() 238 | # Do stuff 239 | ``` 240 | 241 | To use this, you can run 242 | ```shell script 243 | # Load config from `your_config.yaml` and override `nested_arg.nesting.parameter` with 244 | # a new value = 'abc' 245 | > python your_file.py --config your_config.yaml --nested_arg.nesting.parameter abc 246 | # ...and so on 247 | > python your_file.py --config your_config.yaml --arg1 2 --arg2 'abc' --nested.arg a 248 | ``` 249 | 250 | Note that `your_config.yaml` can inherit from an arbitrary number of configs. 251 | 252 | ### QuinSweeps: YAML Sweeping on Steroids 253 | Quinine has a _very_ powerful syntax for sweeps. It addresses a common tension: sweeps are convenient 254 | to write in Python, where you can use operations such as products, zips and chains, 255 | but managing parameters in Python is ugly and cumbersome, and I personally like the separation that YAML provides. 256 | 257 | With Quinine, you can write complex sweeps with nested logic without leaving the comfort of your YAML file. 258 | 259 | Note that Quinine only generates the swept configs: it will not run or manage your swept runs, or do 'smart' hyperparameter optimization (hyperband-style). 260 | 261 | We'll go through a few examples to see how this works. 262 | 263 | Scenario: sweep over 4 learning rates 264 | ```yaml 265 | # This YAML specifies fixed values for all but one parameter: 266 | # optimizer.learning_rate takes on 4 values.
267 | model: 268 | pretrained: false 269 | architecture: resnet50 270 | 271 | optimizer: 272 | learning_rate: 273 | # Sweep over 4 separate learning rates 274 | ~disjoint: # you could also have used the ~product key here -- note the use of the special ~ character 275 | - 0.01 276 | - 0.001 277 | - 0.0001 278 | - 0.00001 279 | scheduler: cosine 280 | ``` 281 | 282 | ```python 283 | from quinine import QuinSweep 284 | 285 | # Generate a QuinSweep using this YAML 286 | quinsweep = QuinSweep(sweep_config_path='path/to/sweep_config.yaml') 287 | 288 | # Index into the quinsweep to get the i^th Quinfig 289 | i = 3 290 | quinfig_3 = quinsweep[i] # this quinfig sets learning_rate to 0.00001 291 | 292 | # Iterate over the quinsweep 293 | for quinfig in quinsweep: 294 | # Do something with the quinfig (e.g. run a job) 295 | your_fn_that_does_something(quinfig) 296 | ``` 297 | 298 | Scenario: sweep over 4 distinct parameter settings that specify learning rate and architecture 299 | ```yaml 300 | model: 301 | pretrained: false 302 | architecture: 303 | # Sweep over 4 separate architectures 304 | ~disjoint: 305 | - resnet18 306 | - resnet50 307 | - vgg19 308 | - inceptionv3 309 | 310 | optimizer: 311 | learning_rate: 312 | # Sweep over 4 separate learning rates 313 | ~disjoint: 314 | - 0.01 315 | - 0.001 316 | - 0.0001 317 | - 0.00001 318 | scheduler: cosine 319 | ``` 320 | 321 | Scenario: sweep over all possible combinations of 4 learning rates and 4 architectures 322 | ```yaml 323 | model: 324 | pretrained: false 325 | architecture: 326 | # Sweep over 4 separate architectures 327 | ~product: 328 | - resnet18 329 | - resnet50 330 | - vgg19 331 | - inceptionv3 332 | 333 | optimizer: 334 | learning_rate: 335 | # Sweep over 4 separate learning rates 336 | ~product: 337 | - 0.01 338 | - 0.001 339 | - 0.0001 340 | - 0.00001 341 | scheduler: cosine 342 | ``` 343 | 344 | Scenario: sweep over all possible combinations of 4 learning rates and 4 architectures and if architecture is
resnet50, 345 | additionally sweep over 2 learning rate schedulers 346 | ```yaml 347 | model: 348 | pretrained: false 349 | architecture: 350 | # Sweep over 4 separate architectures 351 | ~product: 352 | - resnet18 353 | - resnet50 354 | - vgg19 355 | - inceptionv3 356 | 357 | optimizer: 358 | learning_rate: 359 | # Sweep over 4 separate learning rates 360 | ~product: 361 | - 0.01 362 | - 0.001 363 | - 0.0001 364 | - 0.00001 365 | scheduler: 366 | # By default use the cosine scheduler 367 | ~default: cosine 368 | ~disjoint: 369 | # But, when architecture takes on index 1 (i.e. resnet50), sweep over 2 scheduler values 370 | architecture.1: 371 | - cosine 372 | - linear 373 | ``` 374 | 375 | ### Gin for sophisticated configuration 376 | `Gin` is a feature-rich configuration library that gives users the ability to directly force a function argument 377 | in their code to take on some value. 378 | 379 | This can be especially useful when configuration files have nested dependencies: 380 | e.g. consider a config with an `optimizer` key that dictates which optimizer is built and used. 381 | Each optimizer (e.g. SGD or Adam) has its own configuration options (e.g. momentum for SGD or beta_1, beta_2 for Adam). 382 | 383 | With gin, you avoid having to create a schema that specifies every parameter for every possible optimizer in your 384 | config file (and/or writing boilerplate code to parse all of this). 385 | 386 | Instead, you can mark functions as gin configurable (e.g. torch.optim.Adam and torch.optim.SGD) and 387 | simply set the arguments for the one you'll be using, directly in the config e.g. `torch.optim.Adam.beta_1 = 0.5`. 388 | When you need to use the optimizer, just use `torch.optim.Adam()` (and gin will take care of specifying the parameters). 389 | No need to parse this gin configuration manually!
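To make the idea concrete, here is a toy, stdlib-only illustration of what "gin configurable" means: binding a function's keyword arguments from a config instead of threading them through your code. Everything below (`CONFIG`, the `configurable` decorator, the fake `Adam` constructor) is made up for illustration; gin's real mechanism is `@gin.configurable` plus `.gin` config files.

```python
import functools

# Hypothetical parsed config: "<function>.<argument>" bindings, in the
# spirit of the `torch.optim.Adam.beta_1 = 0.5` example above
CONFIG = {'Adam.lr': 0.001, 'Adam.beta_1': 0.5}

def configurable(fn):
    """Toy stand-in for @gin.configurable: bind the CONFIG entries whose
    prefix matches this function's name as keyword arguments."""
    bound = {key.split('.', 1)[1]: value
             for key, value in CONFIG.items()
             if key.split('.', 1)[0] == fn.__name__}
    return functools.partial(fn, **bound)

@configurable
def Adam(params=None, lr=0.01, beta_1=0.9):
    # A fake optimizer constructor; it just reports the arguments it received
    return {'lr': lr, 'beta_1': beta_1}

# Call with no arguments: the config supplies lr and beta_1
assert Adam() == {'lr': 0.001, 'beta_1': 0.5}
```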
390 | 391 | Quinine provides a thin wrapper on gin that allows users to perform gin configuration in YAML, 392 | without having to commit to gin completely (which can be cumbersome). 393 | 394 | With Quinine you can choose not to perform any gin configuration, use it only a little, or use gin exclusively, 395 | all from the convenience of YAML. 396 | 397 | You can also make your codebase gin configurable without having to manually decorate every function with `@gin.configurable`. 398 | This lets you switch to (or away from) gin without any hassle. 399 | 400 | 401 | ### About 402 | If you use `quinine` in a research paper, please use the following BibTeX entry 403 | ``` 404 | @misc{Goel2021, 405 | author = {Karan Goel}, 406 | title = {Quinine: Configuration for Machine Learning Projects}, 407 | year = {2021}, 408 | publisher = {GitHub}, 409 | journal = {GitHub repository}, 410 | howpublished = {\url{https://github.com/krandiash/quinine}}, 411 | } 412 | ``` 413 | 414 | ### Acknowledgments 415 | Thanks to Tri Dao and Albert Gu for initial discussions that led to the development 416 | of `quinine`, as well as Kabir Goel, Shreya Rajpal, Laurel Orr and Sidd 417 | Karamcheti for providing valuable feedback.
418 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .quinine.quinfig import Quinfig -------------------------------------------------------------------------------- /quinine/__init__.py: -------------------------------------------------------------------------------- 1 | from .quinfig import Quinfig 2 | from .common.sweep import QuinSweep 3 | from .common.argparse import QuinineArgumentParser 4 | from .common.cerberus import tstring, tinteger, tfloat, tboolean, tlist, tdict, required, nullable, \ 5 | default, excludes, schema, allowed, stlist, stdict -------------------------------------------------------------------------------- /quinine/common/argparse.py: -------------------------------------------------------------------------------- 1 | import sys as _sys 2 | from argparse import ArgumentParser 3 | from distutils.util import strtobool 4 | 5 | import cytoolz as tz 6 | from funcy import * 7 | 8 | from quinine.common.cerberus import tstring, tlist 9 | from quinine.common.utils import rmerge 10 | from quinine.quinfig import Quinfig 11 | from quinine.quinfig import get_all_leaf_paths 12 | 13 | 14 | class QuinineArgumentParser(ArgumentParser): 15 | """ 16 | Class for replacing the standard argparse.ArgumentParser. 
17 | 18 | In addition to the standard functionality of ArgumentParser, 19 | - includes an argument for '--config', which passes in a YAML file containing the configuration parameters 20 | - automatically adds arguments to the ArgumentParser from a provided schema, 21 | allowing you to override arguments in your YAML directly from the command line 22 | """ 23 | types = {'string': str, 24 | 'integer': int, 25 | 'float': float, 26 | 'dict': dict, 27 | 'boolean': lambda x: bool(strtobool(x)), 28 | 'list': list, 29 | } 30 | 31 | def __init__(self, schema=None, **kwargs): 32 | super(QuinineArgumentParser, self).__init__(**kwargs) 33 | 34 | # Add a default argument for the path to the YAML configuration file that will be passed in 35 | self.add_argument('--config', type=str, required=True, help="YAML configuration file.") 36 | self.schema = schema 37 | 38 | if self.schema is not None: 39 | # Populate the argument parser with arguments from the schema 40 | paths_to_type = list(filter(lambda l: l[-1] == 'type', get_all_leaf_paths(self.schema))) 41 | type_lookup = dict([(tuple(filter(lambda e: e != 'schema', e[:-1])), tz.get_in(e, schema)) 42 | for e in paths_to_type]) 43 | 44 | valid_params = self.get_all_params(schema) 45 | for param in valid_params: 46 | self.add_argument(f'--{".".join(param)}', 47 | type=self.types[type_lookup[param]]) 48 | 49 | self.schema = merge(self.schema, {'config': tstring, 'inherit': tlist}) 50 | 51 | @staticmethod 52 | def get_all_params(schema): 53 | # Find all leaf paths in the schema, then truncate the last key from each path 54 | # and remove the 'schema' key if it occurs anywhere in the path 55 | # TODO: expand the list of criteria in the inner lambda 56 | candidate_parameters = list(set(map(lambda l: tuple(filter(lambda e: e != 'schema' and e != 'allowed', l[:-1])), 57 | get_all_leaf_paths(schema)) 58 | ) 59 | ) 60 | 61 | # Remove prefix paths from the candidate parameters, 62 | # e.g. 
when ['general', 'seed'] and ['general'] both occur, remove ['general'] 63 | valid_parameters = set() 64 | all_subpaths = set() 65 | for path in sorted(candidate_parameters, key=lambda l: len(l), reverse=True): 66 | # If the path isn't in the set of subpaths seen so far, it's a valid param (because paths are sorted) 67 | if path not in all_subpaths: 68 | valid_parameters.add(path) 69 | # Add all subpaths in for this 70 | for i in range(1, len(path)): 71 | all_subpaths.add(path[:i]) 72 | 73 | return valid_parameters 74 | 75 | def parse_quinfig(self): 76 | # Parse all the arguments from the command line, overriding defaults in the argparse 77 | args = self.parse_args() 78 | 79 | cli_keys = [] 80 | for cli_arg in _sys.argv[1:]: 81 | if cli_arg.startswith('--'): 82 | if str(cli_arg) == '--config': 83 | continue 84 | cli_keys.append(cli_arg[2:].replace("-", "_")) 85 | elif cli_arg.startswith('-'): 86 | raise NotImplementedError("QuinineArgumentParser doesn't support abbreviated arguments.") 87 | else: 88 | continue 89 | 90 | # Get parameters which need to be overridden from command line 91 | override_args = project(args.__dict__, cli_keys) 92 | 93 | # Trick: first load the config without a schema as a base config 94 | # quinfig = Quinfig(config_path=args.config) 95 | 96 | # Override all the defaults using the yaml config 97 | # quinfig = rmerge(Quinfig(config=args.__dict__), quinfig) 98 | # print(quinfig) 99 | 100 | # Use all the defaults in args to populate a dictionary 101 | quinfig = {} 102 | for param, val in args.__dict__.items(): 103 | param_path = param.split(".") 104 | quinfig = tz.assoc_in(quinfig, param_path, val) 105 | 106 | # Override all the defaults using the yaml config 107 | quinfig = rmerge(quinfig, Quinfig(config_path=args.config)) 108 | 109 | # Replace all the arguments passed into command line 110 | if len(override_args) > 0: 111 | print(f"\n(quinine) Overriding parameters in {args.config} from command line (___ is unspecified).") 112 | 113 | for 
param, val in override_args.items(): 114 | param_path = param.split(".") 115 | old_val = tz.get_in(param_path, quinfig) 116 | 117 | if old_val != val: 118 | print(f"> ({param}): {old_val} --> {val}") 119 | else: 120 | print(f"> ({param}): ___ --> {val}") 121 | quinfig = tz.assoc_in(quinfig, param_path, val) 122 | 123 | # Load the config again 124 | return Quinfig(config=quinfig, schema=self.schema) 125 | -------------------------------------------------------------------------------- /quinine/common/cerberus.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import cytoolz.curried as tz 5 | from cerberus import schema_registry, Validator 6 | from quinine.common.utils import * 7 | from quinine.common.gin import nested_scope_datagroup_gin_dict 8 | 9 | tstring = {'type': 'string'} 10 | tinteger = {'type': 'integer'} 11 | tfloat = {'type': 'float'} 12 | tboolean = {'type': 'boolean'} 13 | tlist = {'type': 'list'} 14 | tlistorstring = {'type': ['list', 'string']} 15 | tdict = {'type': 'dict'} 16 | 17 | required = {'required': True} 18 | nullable = {'nullable': True} 19 | default = lambda d: {'default': d} 20 | excludes = lambda e: {'excludes': e} 21 | schema = lambda s: {'schema': s} 22 | allowed = lambda a: {'allowed': a} 23 | 24 | stlist = lambda s: merge(tlist, schema(merge(tdict, schema(s)))) 25 | stlistorstring = lambda s: merge(tlistorstring, schema(merge(tdict, schema(s)))) 26 | stdict = lambda s: merge(tdict, schema(s)) 27 | 28 | 29 | def create_and_register_schemas(): 30 | # A schema for gin data 31 | gin = {'gin': tdict} 32 | 33 | # Loader schema for creating e.g. 
DataLoaders in torch 34 | loader = {'loader': stdict({ 35 | 'strategy': tstring, 36 | 'batch_size': tinteger, 37 | 'num_workers': tinteger, 38 | 'seed': merge(tinteger, default(0)), 39 | })} 40 | 41 | # A standard block for data loaders that is used to create a hierarchical schema for dataflows 42 | datablock = {'source': merge(tstring, allowed(['torchvision', 'imagefolder', 'tfds', 'tfrecord'])), 43 | 'name': tstring, 44 | 'shuffle_and_split': tboolean, 45 | 'load_path': tstring, 46 | 'seed': tinteger, 47 | 'heads': merge(tlist, nullable, default(None)), 48 | } 49 | datablock = merge(datablock, loader, gin) 50 | 51 | # Schema for a group: a group is like a slice of data from a particular dataset 52 | datagroup = merge(datablock, 53 | {'split_cmd': merge(tstring, required), 54 | 'alias': merge(tstring, required), 55 | }) 56 | datagroups = {'groups': stlist(datagroup)} 57 | 58 | # A datasets schema for building a list of datasets, with each dataset containing multiple groups 59 | datasets = {'datasets': stlist(merge(datablock, datagroups))} 60 | 61 | # A dataflow: complete description of all the datasets and groups, their loaders as well as where they are used 62 | dataflow = merge(datablock, 63 | {'train': stdict(merge(datablock, 64 | datasets, 65 | {'loader_level': merge(tstring, 66 | required, 67 | allowed(['all', 'dataset', 'group']))})), 68 | 'val': stdict(merge(datablock, datasets, 69 | {'loader_level': merge(tstring, 70 | default('group'), 71 | allowed(['all', 'dataset', 'group']))})), 72 | 'test': stdict(merge(datablock, datasets, 73 | {'loader_level': merge(tstring, 74 | default('group'), 75 | allowed(['all', 'dataset', 'group']))})) 76 | }) 77 | 78 | # Schema for the duration of training 79 | duration = {'epochs': merge(tinteger, required, excludes('steps')), 80 | 'steps': merge(tinteger, required, excludes('epochs'))} 81 | 82 | # Schema for the optimizer 83 | optimizer = {'optimizer': stdict(merge({'name': merge(tstring, required)}, gin))} 84 | 85 | # 
Schema for the learning rate scheduler 86 | lr = {'lr': stdict(merge({'scheduler': merge(tstring, required)}, gin))} 87 | 88 | # Trainer: complete description of how to update the model 89 | trainer = merge(duration, optimizer, lr, gin) 90 | 91 | # Schema for constructing the model 92 | model = {'source': merge(tstring, required, allowed(['torchvision', 'torch', 'cm', 'keras'])), 93 | 'architecture': merge(tstring, required), 94 | 'finetuning': merge(tboolean, required), 95 | 'pretrained': merge(tboolean, required), 96 | 'heads': merge(tlist, required), 97 | } 98 | model = merge(model, gin) 99 | 100 | # Schema for the checkpointer 101 | checkpointer = {'ckpt_path': merge(tstring, nullable, default(None)), 102 | 'save_freq': merge(tinteger, required)} 103 | checkpointer = merge(checkpointer, gin) 104 | 105 | # Schema for resuming model training from Weights and Biases 106 | wandb_resumer = {'resume': merge(tboolean, default(False)), 107 | 'run_id': merge(tstring, default(None), nullable), 108 | 'project': merge(tstring, default(None), nullable), 109 | 'entity': merge(tstring, default(None), nullable), 110 | } 111 | wandb_resumer = merge(wandb_resumer, gin) 112 | 113 | # Schema for Weights and Biases 114 | wandb = {'entity': merge(tstring, required), 115 | 'project': merge(tstring, required), 116 | 'group': merge(tstring, default('default')), 117 | 'job_type': merge(tstring, default('training')), 118 | 'ckpt_dir': merge(tstring, default('checkpoints')), 119 | 'dryrun': merge(tboolean, default(False)), 120 | } 121 | wandb = merge(wandb, gin) 122 | 123 | # Schema for general settings 124 | general = {'seed': merge(tinteger, default(0)), 125 | 'module': merge(tstring, required), 126 | } 127 | 128 | # Schema for Kubernetes 129 | kubernetes = {} 130 | 131 | # Schema for templating 132 | templating = {'parent': merge(tstring, nullable, default(None))} 133 | 134 | # Schema for inheritance 135 | inherit = stlistorstring(merge(tstring, nullable, default(None))) 136 | 137 | # 
Collect all the schemas that are reusable 138 | schemas = { 139 | 'general': general, 140 | 141 | 'duration': duration, 142 | 'optimizer': optimizer, 143 | 'lr': lr, 144 | 'trainer': trainer, 145 | 146 | 'checkpointer': checkpointer, 147 | 'wandb_resumer': wandb_resumer, 148 | 149 | 'model': model, 150 | 151 | 'kubernetes': kubernetes, 152 | 'templating': templating, 153 | 'inherit': inherit, 154 | 155 | 'loader': loader, 156 | 'datablock': datablock, 157 | 'datagroup': datagroup, 158 | 'datagroups': datagroups, 159 | 'datasets': datasets, 160 | 'dataflow': dataflow, 161 | 162 | 'gin': gin, 163 | 164 | 'wandb': wandb, 165 | } 166 | 167 | # Register the schemas 168 | register_schemas(*list(zip(*schemas.items())), verbose=False) 169 | 170 | return schemas 171 | 172 | 173 | def register_schemas(schema_names, schemas, verbose=True): 174 | """ 175 | Register a list of schemas, with corresponding names. 176 | """ 177 | # Register the schemas 178 | list(map(lambda n, s: schema_registry.add(n, s), schema_names, schemas)) 179 | 180 | if verbose: 181 | # Print 182 | print("Registered schemas in Cerberus: ") 183 | list(map(lambda n: print(f'- {n}'), schema_names)) 184 | 185 | 186 | def register_yaml_schemas(path): 187 | """ 188 | Register all schemas located in a directory. 189 | Schemas are assumed to be defined in yaml files at path. 190 | """ 191 | # Get the schema files 192 | schema_files = glob.glob(os.path.join(path, '*')) 193 | schema_names = list(map(lambda f: os.path.basename(f).replace(".yaml", ""), schema_files)) 194 | schemas = list(map(compose(autocurry(yaml.load)(Loader=yaml.FullLoader), open), schema_files)) 195 | 196 | # Register them 197 | register_schemas(schema_names, schemas) 198 | 199 | 200 | def normalize_config(config, schema=None, base_path=''): 201 | """ 202 | Execute a series of functions on the config that modify it. 
203 | """ 204 | if schema: 205 | config = Validator(schema).normalized(config) 206 | config = resolve_templating(config) 207 | config = resolve_inheritance(config, base_path=base_path) 208 | if 'dataflow' in config: 209 | dataflow_config = propagate_parameters_to_datagroups(config.dataflow) 210 | config = set_in(config, ['dataflow'], dataflow_config) 211 | config = nested_scope_datagroup_gin_dict(config) 212 | return config 213 | 214 | 215 | def resolve_inheritance(config, base_path=''): 216 | """ 217 | Takes in a config and resolves any inheritance. 218 | If inheriting, the config will have information about one or more parent configs that should be overwritten 219 | (those configs may in turn inherit from others). 220 | This inheritance chain is resolved by recursively merging all the relevant configs. 221 | """ 222 | if 'inherit' not in config or ('inherit' in config and config['inherit'] is None): 223 | return config 224 | 225 | inherit_paths = [config['inherit']] if isinstance(config['inherit'], str) else config['inherit'] 226 | inherit_paths = [os.path.abspath(os.path.join(base_path, inherit_path)) for inherit_path in inherit_paths] 227 | config['inherit'] = inherit_paths 228 | 229 | inherit_configs = [ 230 | # Recurse to resolve inheritance for each inherited config 231 | resolve_inheritance( 232 | yaml.load(open(inherit_path), 233 | Loader=yaml.FullLoader), 234 | base_path=os.path.dirname(inherit_path), 235 | ) 236 | for inherit_path in inherit_paths 237 | ] 238 | 239 | config = rmerge(*inherit_configs, config) 240 | return config 241 | 242 | 243 | def resolve_templating(config, base_path=''): 244 | """ 245 | Takes in a config and resolves any templating. 246 | If templating, the config will have information about a parent config that it is overwriting 247 | (which in turn may itself be templating). 248 | This templating chain is resolved by recursively merging all the relevant configs. 
249 | """ 250 | if 'templating' not in config or ('templating' in config and config['templating']['parent_yaml'] is None): 251 | return config 252 | 253 | append_parent = lambda l: [yaml.load(open(os.path.join(base_path, l[0]['templating']['parent_yaml'])), 254 | Loader=yaml.FullLoader)] + l 255 | construct_hierarchy = lambda l: ignore(errors=Exception, 256 | default=append_parent(l) 257 | )(construct_hierarchy)(append_parent(l)) 258 | config_hierarchy = construct_hierarchy([config]) 259 | config = rmerge(*config_hierarchy) 260 | return config 261 | 262 | 263 | def validate_config(config, schema): 264 | """ 265 | Check if a config file adheres to a schema. 266 | """ 267 | validator = Validator(schema) 268 | valid = validator.validate(config) 269 | if valid: 270 | return True 271 | else: 272 | print("CerberusError: config could not be validated against schema. The errors are,") 273 | print(validator.errors) 274 | exit() 275 | 276 | 277 | def expand_schema_for_gin_configuration(schema): 278 | """ 279 | Allows the schema to support gin configurability. 280 | The schema supports gin keys at any level of nesting. 281 | """ 282 | # Insert a 'gin' key into the dictionaries that don't contain a 'type' key 283 | predicate = lambda p: 'type' not in p 284 | 285 | # Merge a gin schema into the schema passed in, at every level of nesting 286 | return nested_dict_walker(iffy(predicate, 287 | lambda v: merge({'gin': tdict}, v)), 288 | schema) 289 | 290 | 291 | def expand_schema_for_inheritance(schema): 292 | """ 293 | Allows the schema to support inheritance. 294 | The schema supports configs that (optionally) point to zero or more parent YAML configs 295 | (using a path) that will be taken as base configurations to be overwritten. 296 | """ 297 | return merge({'inherit': stlistorstring(merge(tstring, 298 | nullable, 299 | default(None)) 300 | ) 301 | }, schema) 302 | 303 | 304 | def expand_schema_for_templating(schema): 305 | """ 306 | Allows the schema to support templating. 
307 | The schema supports configs that (optionally) point to a parent YAML config (using a path) that is being overwritten. 308 | TODO: Deprecate. 309 | """ 310 | return merge({'templating': stdict({'parent_yaml': merge(tstring, 311 | nullable, 312 | default(None)) 313 | }) 314 | }, schema) 315 | 316 | 317 | def autoexpand_schema(schema): 318 | """ 319 | Automatically expands the schema to support 320 | - gin configuration 321 | - templating 322 | """ 323 | schema = expand_schema_for_gin_configuration(schema) 324 | schema = expand_schema_for_templating(schema) 325 | schema = expand_schema_for_inheritance(schema) 326 | return schema 327 | 328 | 329 | def propagate_parameters_to_datagroups(dataflow_config): 330 | """ 331 | Given a dataflow config, it is likely that parameters were defined 'globally' e.g. setting the dataset's source 332 | for all datasets and groups. This function propagates parameters from higher levels down to the lowest, group level. 333 | The propagation consolidates parameters at the following levels: 334 | - dataflow 335 | - train(/val/test) 336 | - datasets[i] 337 | - groups[j] 338 | into the parameters that are applicable to groups[j]. 339 | """ 340 | 341 | def construct_group_dict(group_path, config): 342 | """ 343 | Given a config and a path that points to a data group, compute the data group's updated parameters. 344 | The group_path is a list of keys and indices e.g. ['train', 'datasets', 1, 'groups', 0] 345 | that can be followed to reach a group's config. 
346 | """ 347 | # Find (almost) all prefixes of the group path 348 | all_paths = list(map(compose(list, 349 | tz.take(seq=group_path)), 350 | range(1, len(group_path)) 351 | ) 352 | ) 353 | 354 | # Filter to exclude paths that point to lists 355 | paths_to_merge = list(filter(lambda p: isinstance(last(p[1]), str), 356 | pairwise(all_paths) 357 | ) 358 | ) 359 | # Find all the (mid-level) dicts that the filtered paths point to 360 | mid_level_dicts = list(map(lambda p: tz.keyfilter(lambda k: k != last(p[1]), 361 | tz.get_in(p[0], config)), 362 | paths_to_merge)) 363 | 364 | # Merge parameters at all levels to get a single parameter set for the group 365 | def dmerge(*args): 366 | if all(is_mapping, *args): 367 | return Munch(tz.merge(*args)) 368 | else: 369 | return tz.last(*args) 370 | 371 | group_dict = tz.merge_with( 372 | dmerge, 373 | tz.keyfilter(lambda k: k not in ['train', 'val', 'test'], config), # top-level dict 374 | *mid_level_dicts, # mid-level dicts 375 | tz.get_in(group_path, config) # bottom-level dict 376 | ) 377 | 378 | return group_dict 379 | 380 | def get_all_group_paths(config, following=()): 381 | """ 382 | Given a config, constructs paths to all the leaf nodes, truncating them one level below the 'groups' key. 
383 |         """
384 |         if isinstance(config, dict) or isinstance(config, Munch):
385 |             if 'groups' in list(butlast(following)):
386 |                 return [[]]
387 |             return list(cat(map(lambda t: list(map(lambda p: [t[0]] + p,
388 |                                                    get_all_group_paths(t[1], list(following) + [t[0]])
389 |                                                    )),
390 |                                 iteritems(config)))
391 |                         )
392 | 
393 |         elif isinstance(config, list):
394 |             return list(cat(map(lambda t: list(map(lambda p: [t[0]] + p,
395 |                                                    get_all_group_paths(t[1], list(following) + [t[0]])
396 |                                                    )),
397 |                                 enumerate(config)))
398 |                         )
399 |         else:
400 |             return [[]]
401 | 
402 |     # Find all the group paths
403 |     group_paths = list(filter(lambda p: 'groups' in p,
404 |                               get_all_group_paths(dataflow_config)
405 |                               )
406 |                        )
407 | 
408 |     # Construct the group dict for each group path
409 |     group_dicts = list(map(lambda p: construct_group_dict(p,
410 |                                                           dataflow_config),
411 |                            group_paths)
412 |                        )
413 | 
414 |     # Update the dataflow_config with all the group dicts
415 |     updated_dataflow_config = compose(*list(map(autocurry(lambda p, d, c: set_in(c, p, d)),
416 |                                                 group_paths,
417 |                                                 group_dicts,
418 |                                                 )
419 |                                             ))(dataflow_config)
420 | 
421 |     return updated_dataflow_config
422 | 
423 | 
424 | if __name__ == '__main__':
425 |     register_yaml_schemas(path='configs/schemas/base/')
426 |     print()
427 | 
--------------------------------------------------------------------------------
/quinine/common/gin.py:
--------------------------------------------------------------------------------
 1 | import gin
 2 | from funcy import *
 3 | from quinine.common.utils import nested_dict_walker, prefix
 4 | 
 5 | 
 6 | def register_module_with_gin(module, module_name=None):
 7 |     """
 8 |     Register all the callables in a single module with gin.
 9 | 
10 |     A useful way to add gin configurability to a codebase without explicitly using the @gin.configurable decorator.
11 | """ 12 | module_name = module.__name__ if module_name is None else module_name 13 | 14 | for attr in dir(module): 15 | if callable(getattr(module, attr)): 16 | setattr(module, attr, gin.configurable(getattr(module, attr), module=module_name)) 17 | 18 | 19 | def scope_datagroup_gin_dict(coll): 20 | """ 21 | Rename the augmentations gin dict with the alias of the dataset. 22 | """ 23 | 24 | if 'alias' in coll and 'gin' in coll and is_mapping(coll['gin']): 25 | coll['gin'] = walk_keys(prefix(p=f"{coll['alias'].replace('.', '_')}/"), coll['gin']) 26 | return coll 27 | 28 | 29 | def nested_scope_datagroup_gin_dict(coll): 30 | """ 31 | Apply the renamer over a nested dict, e.g. derived from a yaml. 32 | """ 33 | return nested_dict_walker(scope_datagroup_gin_dict, coll) 34 | 35 | 36 | def gin_dict_parser(coll): 37 | """ 38 | Use for parsing collections that may contain a 'gin' key. 39 | The 'gin' key is assumed to map to either a dict or str value that contains gin bindings. 40 | e.g. 41 | {'gin': {'Classifier.n_layers': 2, 'Classifier.width': 3}} 42 | or 43 | {'gin': 'Classifier.n_layers = 2\nClassifier.width = 3'} 44 | """ 45 | if 'gin' in coll: 46 | if is_mapping(coll['gin']): 47 | gin.parse_config("".join(map(lambda t: f'{t[0]} = {t[1]}\n', iteritems(coll['gin'])))) 48 | elif isinstance(coll['gin'], str): 49 | gin.parse_config(coll['gin']) 50 | return coll 51 | 52 | 53 | def nested_gin_dict_parser(coll): 54 | """ 55 | Use for parsing nested collections that may contain a 'gin' key. 56 | The 'gin' key is assumed to map to a dict value that contains gin bindings (see gin_dict_parser). 57 | 58 | Enables support for gin keys in yaml files. 
59 | """ 60 | return nested_dict_walker(gin_dict_parser, coll) 61 | -------------------------------------------------------------------------------- /quinine/common/sweep.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import re 3 | import os 4 | import typing as typ 5 | from collections import namedtuple 6 | from copy import deepcopy 7 | 8 | import cytoolz as tz 9 | import yaml 10 | from funcy import * 11 | from toposort import toposort 12 | 13 | from quinine.common.utils import get_only_paths, allequal 14 | from quinine.quinfig import Quinfig 15 | 16 | # Parameter = namedtuple('Parameter', 'path dotpath value') 17 | # SweptParameter = namedtuple('SweptParameter', 'path sweep') 18 | SweptDisjointParameter = namedtuple('SweptDisjointParameter', 'path disjoint') 19 | SweptProductParameter = namedtuple('SweptProductParameter', 'path product') 20 | SweptDefaultParameter = namedtuple('SweptDefaultParameter', 'path default') 21 | 22 | 23 | class Parameter: 24 | 25 | def __init__(self, path=None, dotpath=None, value=None): 26 | self.path = path 27 | self.dotpath = dotpath 28 | self.value = value 29 | 30 | def __eq__(self, other): 31 | if hasattr(other, 'dotpath'): 32 | return self.dotpath == other.dotpath 33 | else: 34 | return self.dotpath == other 35 | 36 | def __repr__(self): 37 | return f'Parameter(path={self.path}, dotpath={self.dotpath}, value={self.value})' 38 | 39 | 40 | class SweptParameter: 41 | 42 | def __init__(self, path=None, sweep=None): 43 | self.path = path 44 | self.sweep = sweep 45 | 46 | def __getitem__(self, item): 47 | if item == 0: 48 | return self.path 49 | elif item == 1: 50 | return self.sweep 51 | 52 | def __repr__(self): 53 | return f'SweptParameter(path={self.path}, sweep={self.sweep})' 54 | 55 | 56 | class QuinSweep: 57 | """ 58 | The QuinSweep class can be used to 59 | - process a parameter sweep, with support for complex conditional sweeps 60 | - generate Quinfigs for each 
setting of the parameter sweep
61 | 
62 |     e.g. a config containing {'lr': {'~disjoint': [0.1, 0.01]}} generates one Quinfig per learning rate.
63 | 
64 |     """
65 |     # inside any parameter, specify a sweep with the special ~ character
66 |     SWEEP_PREFIX = '~'
67 | 
68 |     # 'default' *must* be the last entry in this list
69 |     SWEEP_TOKENS = ['disjoint', 'product'] + ['default']
70 | 
71 |     @staticmethod
72 |     def get_parameter_path(token_path: typ.List[str]) -> typ.List[str]:
73 |         """
74 |         Given the path to a sweep token
75 |         e.g.
76 |         ['model', 'architecture', 'n_layers', '~product']
77 |         which points to the sweep token '~product' for the 'n_layers' parameter,
78 | 
79 |         get_parameter_path returns the path to the parameter on which the sweep token was applied,
80 |         i.e.
81 |         ['model', 'architecture', 'n_layers'].
82 |         """
83 |         return compose(tuple, butlast)(token_path)
84 | 
85 |     @staticmethod
86 |     def parse_sweep_config(sweep_config) -> typ.List[typ.List[str]]:
87 |         """
88 |         Takes in a sweep config and determines paths to the sweep tokens in the config.
89 | 
90 |         Uses the fact that all sweep tokens in the config are prefixed by the SWEEP_PREFIX (~) character.
91 | 
92 |         Each token path ends in one of the sweep tokens listed in SWEEP_TOKENS.
93 |         """
94 |         # To identify the swept parameters,
95 |         # use the fact that any parameter sweep must use the special prefix (~ by default)
96 |         token_paths = get_only_paths(sweep_config,
97 |                                      pred=lambda p: any(lambda s: QuinSweep.SWEEP_PREFIX in str(s), p),
98 |                                      stop_at=QuinSweep.SWEEP_PREFIX)
99 | 
100 |         # As output, we get a list of paths that point to locations of all the ~ tokens in the config
101 |         token_paths = list(map(tuple, token_paths))
102 | 
103 |         # Confirm that all the tokens prefixed by ~ are correctly specified
104 |         all_tokens_ok = all(lambda s: s.startswith(QuinSweep.SWEEP_PREFIX) and
105 |                                       s.strip(QuinSweep.SWEEP_PREFIX) in QuinSweep.SWEEP_TOKENS,
106 |                             map(last, token_paths)
107 |                             )
108 |         assert all_tokens_ok, f'Unknown token: sweep config failed parsing. ' \
109 |                               f'Only tokens {QuinSweep.SWEEP_TOKENS} are allowed.'
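To make the token-path extraction above concrete, here is a minimal, self-contained sketch of the idea. This is not the library's `get_only_paths` implementation; the helper name and the example sweep config are hypothetical:

```python
# Simplified stand-in for the token-path extraction in parse_sweep_config:
# walk a nested dict and collect the path to every '~'-prefixed sweep token.
SWEEP_PREFIX = "~"

def find_token_paths(config, prefix=()):
    paths = []
    if isinstance(config, dict):
        for key, value in config.items():
            if isinstance(key, str) and key.startswith(SWEEP_PREFIX):
                # Stop descending once a sweep token is reached
                paths.append(prefix + (key,))
            else:
                paths.extend(find_token_paths(value, prefix + (key,)))
    return paths

sweep = {"model": {"n_layers": {"~product": [2, 4, 8]}},
         "optimizer": {"lr": {"~disjoint": [0.1, 0.01, 0.001]}}}
print(find_token_paths(sweep))
# [('model', 'n_layers', '~product'), ('optimizer', 'lr', '~disjoint')]
```

Each returned path ends in a sweep token, which is exactly the property the assertion above checks.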
110 | 111 | return token_paths 112 | 113 | @staticmethod 114 | def parse_fixed_parameters(sweep_config): 115 | # Extract the locations of the other non-swept, fixed parameters 116 | # use the fact that the other parameters must not be prefixed by the special prefix 117 | fixed_parameters = get_only_paths(sweep_config, 118 | pred=lambda p: all(lambda s: QuinSweep.SWEEP_PREFIX not in str(s), p)) 119 | 120 | # Make Parameter objects 121 | fixed_parameters = list(map(lambda tp: Parameter(tp, 122 | f'{".".join(map(str, tp))}.0', 123 | get_in(sweep_config, tp)), 124 | fixed_parameters) 125 | ) 126 | return fixed_parameters 127 | 128 | @staticmethod 129 | def path_to_dotpath(path: typ.List) -> str: 130 | """ 131 | Takes in any path 132 | e.g. 133 | ['model', 'architecture', 'n_layers'] 134 | and converts it to a dotpath 135 | i.e. 136 | 'model.architecture.n_layers'. 137 | """ 138 | return ".".join(path) 139 | 140 | @staticmethod 141 | def dotpath_to_path(dotpath): 142 | """ 143 | Takes in a dotpath 144 | e.g. 145 | 'model.architecture.n_layers' 146 | and converts it to a path 147 | e.g. 148 | ['model', 'architecture', 'n_layers']. 149 | 150 | Only gin parameters are allowed to have dots in their name, and 151 | this fn correctly handles conversion for those cases: 152 | e.g. 153 | taking 'model.gin.architecture.n_layers' and converting it into ['model', 'gin', 'architecture.n_layers']. 154 | """ 155 | # Split up the dotpath 156 | split = dotpath.split(".") 157 | 158 | # Handle the case where the gin parameter has a dot 159 | if len(split) > 2 and split[-3] == 'gin': 160 | split = split[:-2] + [".".join(split[-2:])] 161 | 162 | return split 163 | 164 | @staticmethod 165 | def parse_ref_dotpath(dotpath): 166 | """ 167 | Reference dotpaths consist of this.is.a.parameter.idx.this.is.another.parameter._.parameter.idx (_ indicates any idx) 168 | 169 | Returns a list of (dotpath, idx) tuples. 
170 |         """
171 |         return re.findall(r'((?:\w[^.]*)+(?:\.\w[^.]*)*?)\.(\d+|_)', dotpath)
172 | 
173 |     @staticmethod
174 |     def param_comparator(p, q):
175 |         """
176 |         A comparison operator that checks which of two sweeps to apply first.
177 | 
178 |         The comparison relies on the conditional dependencies expressed in both sweeps.
179 |         """
180 |         # Create dotpaths for both sweeps
181 |         p_dotpath = ".".join(p.path)
182 |         q_dotpath = ".".join(q.path)
183 | 
184 |         # Get the keys that each sweep refers to (these are the conditional dependencies that the sweep has)
185 |         if isinstance(p, SweptDisjointParameter):
186 |             p_ref_dotpaths = p.disjoint.keys()
187 |         elif isinstance(p, SweptProductParameter):
188 |             p_ref_dotpaths = p.product.keys()
189 |         else:
190 |             raise NotImplementedError
191 |         if isinstance(q, SweptDisjointParameter):
192 |             q_ref_dotpaths = q.disjoint.keys()
193 |         elif isinstance(q, SweptProductParameter):
194 |             q_ref_dotpaths = q.product.keys()
195 |         else:
196 |             raise NotImplementedError
197 | 
198 |         # Each key is a combination of dotpaths and index references: parse them to extract the dotpaths
199 |         p_ref_parsed_dotpaths = list(
200 |             mapcat(compose(list, autocurry(map)(0), QuinSweep.parse_ref_dotpath), p_ref_dotpaths))
201 |         q_ref_parsed_dotpaths = list(
202 |             mapcat(compose(list, autocurry(map)(0), QuinSweep.parse_ref_dotpath), q_ref_dotpaths))
203 | 
204 |         # Check if any reference dotpath in q matches p or vice-versa
205 |         if any(lambda e: p_dotpath == e or p_dotpath.endswith(f'.{e}'), q_ref_parsed_dotpaths):
206 |             return 1
207 |         elif any(lambda e: q_dotpath == e or q_dotpath.endswith(f'.{e}'), p_ref_parsed_dotpaths):
208 |             return -1
209 |         else:
210 |             return 0
211 | 
212 |     @staticmethod
213 |     def is_product_subtype(subtype):
214 |         return 'SweptProductParameter' in subtype.__name__
215 | 
216 |     @staticmethod
217 |     def is_disjoint_subtype(subtype):
218 |         return 'SweptDisjointParameter' in subtype.__name__
219 | 
220 |     def expand_partial_dotpath(self, partial_dotpath):
221
| """ 222 | Given a partial dotpath, expand the dotpath to yield a full dotpath from the root to the parameter. 223 | """ 224 | # This assertion checks if the reference made by the partial dotpath points to another parameter *uniquely*. 225 | # If not, it's impossible to know which parameter was being referenced. 226 | assert one(lambda k: k.endswith(partial_dotpath), 227 | self.swept_parameters_dict.keys()) # swept_parameter_dict.keys() -> all_dotpaths 228 | 229 | # Complete the dotpath, yielding the full dotpath that points to the location of the parameter 230 | return list(filter(lambda t: t[1] is True, 231 | map(lambda k: (k, k.endswith(partial_dotpath)), 232 | self.swept_parameters_dict.keys()) 233 | ) 234 | )[0][0] 235 | 236 | def expand_reference(self, reference): 237 | # Parse the reference dotpath 238 | parsed_ref_dotpath = self.parse_ref_dotpath(reference) 239 | 240 | # Map to expand the dotpaths in the reference 241 | parsed_ref_dotpath = list(map(lambda t: (self.expand_partial_dotpath(t[0]), t[1]), parsed_ref_dotpath)) 242 | 243 | # Join back 244 | return ".".join(list(cat(parsed_ref_dotpath))) 245 | 246 | def __init__(self, 247 | sweep_config_path=None, 248 | schema_path=None, 249 | sweep_config=None, 250 | schema=None 251 | ): 252 | 253 | # Load up the sweep config 254 | if sweep_config is None: 255 | assert sweep_config_path is not None, 'Please pass in either sweep_config or sweep_config_path.' 256 | assert sweep_config_path.endswith('.yaml'), 'Must use a YAML file for the sweep_config.' 
257 | sweep_config = yaml.load(open(sweep_config_path), 258 | Loader=yaml.FullLoader) 259 | 260 | # First, extract the locations of the sweeps being performed 261 | self.sweep_paths = QuinSweep.parse_sweep_config(sweep_config) 262 | 263 | if len(self.sweep_paths) == 0: 264 | # No sweep: just return the single Quinfig 265 | self.quinfigs = [Quinfig(config=sweep_config)] 266 | print(f"Generated {len(self.quinfigs)} quinfig(s) successfully.") 267 | return 268 | 269 | # Create list of paths to all the parameters that are being swept 270 | self.swept_parameter_paths = list(distinct(map(QuinSweep.get_parameter_path, 271 | self.sweep_paths), 272 | key=tuple) 273 | ) 274 | 275 | # Next, extract the fixed parameters from the sweep config 276 | self.fixed_parameters = QuinSweep.parse_fixed_parameters(sweep_config) 277 | 278 | # Next, fetch the SweptParameter named tuples after creating them 279 | self.swept_parameters, \ 280 | self.swept_disjoint_parameters, \ 281 | self.swept_product_parameters, \ 282 | self.swept_default_parameters = \ 283 | self.fetch_swept_parameters(sweep_config, self.sweep_paths, self.swept_parameter_paths) 284 | 285 | # For convenience, create a lookup table for the swept parameters from their dotpaths 286 | self.swept_parameters_dict = dict(zip(map(lambda sp: ".".join(sp.path), 287 | self.swept_parameters), 288 | self.swept_parameters) 289 | ) 290 | 291 | # Expand all the dotpaths in any conditional 292 | self.expand_all_condition_dotpaths() 293 | 294 | # Filter out the unconditional sweeps and then process them 295 | uncond_disjoint_sweeps = list(filter(compose(is_seq, 1), 296 | self.swept_disjoint_parameters)) 297 | uncond_product_sweeps = list(filter(compose(is_seq, 1), 298 | self.swept_product_parameters)) 299 | 300 | self.all_combinations = self.process_unconditional_sweeps(uncond_disjoint_sweeps, 301 | uncond_product_sweeps) 302 | 303 | # Filter out the conditional sweeps and then process them 304 | cond_disjoint_sweeps = 
list(filter(compose(is_mapping, 1), self.swept_disjoint_parameters)) 305 | cond_product_sweeps = list(filter(compose(is_mapping, 1), self.swept_product_parameters)) 306 | 307 | self.process_conditional_sweeps(cond_disjoint_sweeps, 308 | cond_product_sweeps) 309 | 310 | # Generate all the Quinfigs 311 | self.quinfigs = [] 312 | for combination in self.all_combinations: 313 | coll = deepcopy(sweep_config) 314 | for parameter in combination: 315 | coll = tz.assoc_in(coll, parameter.path, parameter.value) 316 | self.quinfigs.append(Quinfig(config=coll, base_path=os.path.dirname(os.path.abspath(sweep_config_path)))) 317 | 318 | print(f"Generated {len(self.quinfigs)} quinfig(s) successfully.") 319 | 320 | def __getitem__(self, idx): 321 | """ 322 | Implement indexing into the QuinSweep to fetch particular Quinfigs. 323 | """ 324 | return self.quinfigs[idx] 325 | 326 | def expand_all_condition_dotpaths(self): 327 | """ 328 | Expands the paths 329 | """ 330 | 331 | # Function that expands condition dotpaths for SweptParameters 332 | expand_condition_dotpaths = lambda sp: \ 333 | SweptParameter(sp.path, 334 | walk_values(lambda sweep: iffy(is_mapping, 335 | autocurry(walk_keys)( 336 | self.expand_reference))(sweep), 337 | sp[1]) 338 | ) 339 | 340 | # Apply the function to the list of SweptParameters 341 | self.swept_parameters = \ 342 | list( 343 | map(expand_condition_dotpaths, 344 | self.swept_parameters) 345 | ) 346 | 347 | # Function that expands condition dotpaths for Swept___Parameters 348 | expand_condition_dotpaths = lambda subtype, sp: \ 349 | subtype(sp.path, 350 | iffy(is_mapping, 351 | autocurry(walk_keys)( 352 | self.expand_reference))(sp[1]), 353 | ) 354 | expand_condition_dotpaths = autocurry(expand_condition_dotpaths) 355 | 356 | self.swept_disjoint_parameters = \ 357 | list( 358 | 359 | map(expand_condition_dotpaths(SweptDisjointParameter), 360 | self.swept_disjoint_parameters) 361 | ) 362 | self.swept_product_parameters = \ 363 | list( 364 | 
map(expand_condition_dotpaths(SweptProductParameter),
365 |                     self.swept_product_parameters)
366 | 
367 |             )
368 | 
369 |         # Recreate the lookup table
370 |         self.swept_parameters_dict = dict(zip(map(lambda sp: ".".join(sp.path),
371 |                                                   self.swept_parameters),
372 |                                               self.swept_parameters)
373 |                                           )
374 | 
375 |     def replace_underscores(self, swept_parameter):
376 |         """
377 |         Replace all the underscore references in the sweep of swept_parameter.
378 |         """
379 | 
380 |         # Find all the references (i.e. dependencies) made by the swept_parameter
381 |         references = []
382 |         for token in QuinSweep.SWEEP_TOKENS[:-1]:  # omit default since its value is never a dict
383 |             if f"~{token}" in swept_parameter.sweep and is_mapping(swept_parameter.sweep[f"~{token}"]):
384 |                 references.extend(list(swept_parameter.sweep[f"~{token}"].keys()))
385 | 
386 |         # Find all the referred parameters
387 |         parsed_references = list(map(QuinSweep.parse_ref_dotpath, references))
388 |         dotpaths = list(cat(parsed_references))
389 |         ref_dict = merge_with(compose(list, cat), *list(map(lambda e: dict([e]), dotpaths)))
390 | 
391 |         # TODO: there's a bug here potentially
392 |         assert all(map(lambda l: len(l) == len(set(l)), list(itervalues(ref_dict)))), \
393 |             'All conditions must be distinct.'
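The reference-dotpath grammar consumed here (via `parse_ref_dotpath`) can be exercised on its own. A small sketch using the same regular expression; the example dotpath is hypothetical:

```python
import re

# Same pattern as parse_ref_dotpath: alternating parameter dotpaths and
# indices, where '_' is a wildcard index standing for "any value index".
REF_PATTERN = re.compile(r'((?:\w[^.]*)+(?:\.\w[^.]*)*?)\.(\d+|_)')

def parse_ref_dotpath(dotpath):
    # Returns a list of (dotpath, idx) tuples
    return REF_PATTERN.findall(dotpath)

print(parse_ref_dotpath('model.architecture.0.optimizer.lr._'))
# [('model.architecture', '0'), ('optimizer.lr', '_')]
```

The lazy quantifier in the first group keeps each parameter dotpath as short as possible, so the trailing `.0` or `._` is always captured as the index rather than swallowed into the name.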
394 | 
395 |         ref_dict_no_underscores = walk_values(compose(set,
396 |                                                       autocurry(map)(int),
397 |                                                       autocurry(filter)(lambda e: e != '_')),
398 |                                               ref_dict)
399 | 
400 |         if not references:
401 |             return swept_parameter
402 | 
403 |         def compute_possibilities(full_dotpath, reference):
404 |             # Look up the parameter using the dotpath
405 |             parameter = self.swept_parameters_dict[full_dotpath]
406 | 
407 |             # Use the reference to figure out how many possibilities exist for the underscore
408 |             if len(reference) > 0:
409 |                 # Merge all the sweeps performed for this parameter
410 |                 merged_sweep = merge(*list(filter(is_mapping, itervalues(parameter.sweep))))
411 |                 # Look up the reference
412 |                 return len(merged_sweep[reference])
413 | 
414 |             assert len(parameter.sweep) == 1, 'If no reference, must be a single unconditional sweep.'
415 |             # The number of possibilities is simply the number of values specified
416 |             # in the (product/disjoint) unconditional sweep
417 |             return len(list(parameter.sweep.values())[0])
418 | 
419 |         # Update the sweep by replacing underscores
420 |         updated_sweep = swept_parameter.sweep
421 | 
422 |         # Loop over all the parsed references
423 |         for parsed_ref in parsed_references:
424 | 
425 |             # Expand all the partial dotpaths
426 |             # TODO: remove? expanding all the partial dotpaths in the beginning?
427 |             parsed_ref = list(
428 |                 map(lambda t: (self.expand_partial_dotpath(t[0]), t[1]), parsed_ref))
429 | 
430 |             # For each parsed reference, there will be multiple (dotpath, idx) pairs
431 |             for i, (full_dotpath, ref_idx) in enumerate(parsed_ref):
432 | 
433 |                 # If the reference index is not an underscore, continue
434 |                 if ref_idx != '_':
435 |                     continue
436 | 
437 |                 # Compute the prefix reference
438 |                 prefix_reference = ".".join(list(cat(parsed_ref[:i])))
439 | 
440 |                 # Compute the number of possible ways to replace the underscore
441 |                 n_possibilities = compute_possibilities(full_dotpath, prefix_reference)
442 |                 replacements = set(range(n_possibilities)) - ref_dict_no_underscores[full_dotpath]
443 | 
444 |                 # Find the path to the underscore condition
445 |                 path_to_condition = get_only_paths(updated_sweep, lambda p: any(lambda e: '_' in e, p),
446 |                                                    stop_at=full_dotpath)[0]
447 | 
448 |                 # Find the value of the underscore condition
449 |                 value = tz.get_in(
450 |                     path_to_condition,
451 |                     updated_sweep)
452 | 
453 |                 # Construct keys that are substitutes for the underscore
454 |                 keys = list(map(lambda s: f'{full_dotpath}.{s}', replacements))
455 |                 keys = list(map(lambda k: path_to_condition[:-1] + [k], keys))
456 | 
457 |                 # Update by adding those keys in
458 |                 for k in keys:
459 |                     updated_sweep = tz.assoc_in(updated_sweep, k, value)
460 | 
461 |         # Create a new swept parameter with the updated sweep
462 |         swept_parameter = SweptParameter(swept_parameter.path,
463 |                                          walk_values(
464 |                                              iffy(is_mapping,
465 |                                                   autocurry(select_keys)(lambda k: '_' not in k)
466 |                                                   ),
467 |                                              updated_sweep)
468 |                                          )
469 |         return swept_parameter
470 | 
471 |     def process_conditional_sweeps(self,
472 |                                    cond_disjoint_sweeps,
473 |                                    cond_product_sweeps,
474 |                                    ):
475 |         """
476 |         Function to process conditional sweeps: these sweeps are applied when a conditional is satisfied.
477 | 
478 |         As an example, consider a sweep where the learning rate takes different values depending on
479 |         which value an earlier sweep chose for the model architecture.
480 |         """
481 |         # The complete matrix of all pairwise comparisons over the set of conditional sweeps:
482 |         # Note that our comparison op isn't transitive (e.g. A = B and B = C but A > C is possible)
483 |         # so we want a topological sort over the set of sweeps: O(n^2) is incurred to generate the DAG
484 |         # over which the toposort is applied
485 |         all_comparisons = list(map(lambda t: (t, QuinSweep.param_comparator(*t)),
486 |                                    itertools.product(cond_product_sweeps + cond_disjoint_sweeps,
487 |                                                      cond_product_sweeps + cond_disjoint_sweeps)))
488 | 
489 |         dependencies = merge(
490 |             # first pretend that there are no dependencies
491 |             dict(map(lambda t: ((t.path, type(t)), set()), cond_disjoint_sweeps)),
492 |             dict(map(lambda t: ((t.path, type(t)), set()), cond_product_sweeps)),
493 |             # then merge in the dependencies
494 |             dict(map(lambda t: ((t[0][0].path, type(t[0][0])), set([(t[0][1].path, type(t[0][1])), ])),
495 |                      filter(lambda t: t[1] == -1, all_comparisons)
496 |                      )
497 |                  )
498 |         )
499 | 
500 |         # Topological sort to produce the partial ordering: a list of sets for the partial ordering
501 |         sweep_posets = list(toposort(dependencies))
502 | 
503 |         # Map to extract the dotpaths of the sweeps
504 |         dotpath_posets = list(map(lambda s: set(map(0, s)), sweep_posets))
505 | 
506 |         # A dotpath could occur in more than one poset
507 |         # e.g.
when it contains both a product sweep and a disjoint sweep with different conditionals 508 | # Ensure that the dotpath occurs exactly once, in its earliest location 509 | dotpath_poset_subs = list( 510 | reversed(list(accumulate([set()] + list(reversed(dotpath_posets))[:-1], lambda a, b: a.union(b))))) 511 | dotpath_posets = list(map(lambda t: t[0] - t[1], zip(dotpath_posets, dotpath_poset_subs))) 512 | 513 | # Loop over all the sets in the partial order 514 | for poset in dotpath_posets: 515 | # For each path, convert it to a dotpath and then replace any underscores in it 516 | for path in poset: 517 | dotpath = QuinSweep.path_to_dotpath(path) 518 | self.swept_parameters_dict[dotpath] = self.replace_underscores(self.swept_parameters_dict[dotpath]) 519 | 520 | for poset in sweep_posets: 521 | # Split the poset into the disjoint and product sweeps 522 | disjoint_poset = [p for p in poset if QuinSweep.is_disjoint_subtype(p[1])] 523 | product_poset = [p for p in poset if QuinSweep.is_product_subtype(p[1])] 524 | 525 | # Process the product sweeps first 526 | for path, subtype in product_poset: 527 | dotpath = QuinSweep.path_to_dotpath(path) 528 | # Each parameter combination needs to be expanded 529 | # print("Product Expansion") 530 | # for thing in self.all_combinations: 531 | # print([e.dotpath for e in thing]) 532 | self.all_combinations = self.product_expansion(self.all_combinations, 533 | dotpath, 534 | self.swept_parameters_dict[dotpath].sweep['~product']) 535 | 536 | # Process the poset 537 | # For disjoint, process the entire poset together (restricted to the disjoint sweeps) 538 | self.all_combinations = self.disjoint_expansion(self.all_combinations, disjoint_poset) 539 | 540 | pass 541 | 542 | # Apply the defaults 543 | for poset in dotpath_posets: 544 | # For each path, convert it to a dotpath and then replace any underscores in it 545 | for path in poset: 546 | for i, combo in enumerate(self.all_combinations): 547 | dotpath = 
QuinSweep.path_to_dotpath(path)
548 |                     if dotpath not in list(map(lambda c: self.path_to_dotpath(c.path), combo)):
549 |                         combo.append(Parameter(path,
550 |                                                f'{dotpath}.0',
551 |                                                self.swept_parameters_dict[dotpath].sweep['~default']))
552 |                         self.all_combinations[i] = combo
553 | 
554 |     def product_expansion(self,
555 |                           combinations,
556 |                           dotpath,
557 |                           sweep):
558 |         """
559 |         Takes as input the current list of parameter combinations, the dotpath of the swept parameter,
560 |         and its product sweep; returns the list of combinations expanded with the product sweep's values.
561 |         """
562 |         new_combinations = []
563 |         # Iterate over the parameter combinations
564 |         for i, combo in enumerate(combinations):
565 |             # print(i, combo, sweep)
566 |             flag = False
567 |             # The parameter combination contains a setting of the previously configured parameters
568 |             for key in sweep:
569 |                 # First parse the key, which is a reference dotpath, to a list of dotpaths
570 |                 # Check if all the ref dotpaths are satisfied by the combo
571 |                 if all(map(lambda t: ".".join(t) in combo, self.parse_ref_dotpath(key))):
572 |                     # Perform the product expansion
573 |                     new_combos = list(itertools.product(*list(map(lambda e: [e], combo)),
574 |                                                         list(
575 |                                                             map(lambda e: Parameter(self.dotpath_to_path(dotpath),
576 |                                                                                     f'{dotpath}.{e[0]}',
577 |                                                                                     e[1]),
578 |                                                                 enumerate(sweep[key])
579 |                                                                 )
580 |                                                             ))
581 |                                       )
582 |                     new_combos = list(map(compose(list, concat), new_combos))
583 |                     new_combinations.extend(new_combos)
584 |                     flag = True
585 |                     break
586 |             if not flag:
587 |                 new_combinations.append(combo)
588 | 
589 |         return new_combinations
590 | 
591 |     def disjoint_expansion(self,
592 |                            combinations,
593 |                            poset):
594 |         """
595 |         Expand each combination with the disjoint sweep values whose conditionals the combination satisfies.
596 |         """
597 | 
598 |         ref_dotpath_to_dotpath = {}
599 |         for path, _ in poset:
600 |             # Look up the SweptParameter using the path
601 |             swept_param = self.swept_parameters_dict[self.path_to_dotpath(path)]
602 |             # Loop over the conditionals in the SweptParameter to create the lookup table
603 |             for ref_dotpath in swept_param.sweep['~disjoint']:
604 |                 if ref_dotpath not in ref_dotpath_to_dotpath:
605 |                     ref_dotpath_to_dotpath[ref_dotpath] = set()
606 | 
ref_dotpath_to_dotpath[ref_dotpath].add(swept_param) 607 | 608 | new_combinations = [] 609 | 610 | # Iterate over the parameter combinations 611 | for i, combo in enumerate(combinations): 612 | # print(i, combo) 613 | flag = False 614 | # The parameter combination contains a setting of the previously configured parameters 615 | for ref_dotpath in ref_dotpath_to_dotpath: 616 | # First parse the key, which is a reference dotpath, to a list of dotpaths 617 | # Check if all the ref dotpaths are satisfied by the combo 618 | if all(map(lambda t: ".".join(t) in combo, self.parse_ref_dotpath(ref_dotpath))): 619 | disjoint_parameters = list(zip(*map(lambda sp: list(map(lambda v: Parameter(sp.path, 620 | f'{self.path_to_dotpath(sp.path)}.{v[0]}', 621 | v[1]), 622 | enumerate( 623 | sp.sweep['~disjoint'][ref_dotpath]) 624 | ) 625 | ), 626 | ref_dotpath_to_dotpath[ref_dotpath]) 627 | )) 628 | 629 | new_combos = list(map(compose(list, cat), zip([combo] * len(disjoint_parameters), 630 | disjoint_parameters) 631 | ) 632 | ) 633 | new_combinations.extend(new_combos) 634 | flag = True 635 | break 636 | 637 | if not flag: 638 | new_combinations.append(combo) 639 | 640 | return new_combinations 641 | 642 | def process_unconditional_sweeps(self, 643 | uncond_disjoint_sweeps, 644 | uncond_product_sweeps): 645 | """ 646 | Parameters with unconditional sweeps are always processed first, and unconditional sweeps specify how to vary 647 | these parameters. 648 | 649 | Sweeps should be thought of as trees, with the root node containing assignments to all unswept parameters. 650 | Each time one or a group of swept parameters are processed, the tree grows in depth. The path from the root 651 | of the tree to any leaf indicates a particular assignment to the swept parameters. 
652 | 
653 |         Unconditional sweeps can be
654 | 
655 |         - disjoint over r parameters, each parameter taking exactly k parameter values
656 | 
657 |           Unconditional disjoint sweeps can be thought of as generating k disjoint sweeps,
658 |           where each sweep contains assignments to the r parameters.
659 |           The k parameter values specified for each of the r parameters are assumed to be in alignment.
660 | 
661 |           In terms of the tree structure, this operation adds k children to the root node,
662 |           with each child containing a simultaneous assignment to the r parameters.
663 | 
664 |         - product over p parameters, with the parameters taking k_1, k_2, ..., k_p parameter values
665 | 
666 |           Unconditional product sweeps generate k_i new sweeps for the ith (of p) parameter. Each of these product sweeps
667 |           adds a new layer to the tree, leading to a total depth increase of p. The ith such layer contains assignments
668 |           to the ith parameter.
669 | 
670 |         In total, unconditional sweeps generate k * k_1 * k_2 * ... * k_p leaf nodes in the tree.
671 |         If the sweep only contains unconditional sweeps, the leaf nodes will contain complete assignments to all
672 |         parameters, and can be used to create a Quinfig.
673 |         """
674 | 
675 |         # All of the disjoint unconditional sweeps must have identical length
676 |         assert allequal(
677 |             map(compose(len, 1), uncond_disjoint_sweeps)), 'All disjoint unconditional sweeps must have same length.'
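The leaf-count arithmetic described in the docstring above can be checked with a small sketch. Parameter names and values here are hypothetical, and plain `(name, value)` tuples stand in for `Parameter` objects:

```python
import itertools

# Disjoint sweep: r=2 parameters with k=2 aligned values each -> k combinations
disjoint = {"optimizer.lr": [0.1, 0.01], "optimizer.wd": [1e-4, 1e-5]}
# Product sweep: p=2 parameters with k_1=3 and k_2=2 values -> k_1 * k_2 combinations
product = {"general.seed": [0, 1, 2], "model.n_layers": [2, 4]}

# zip aligns the disjoint values index-by-index (the 'alignment' assumption)
disjoint_combos = list(zip(*[[(name, v) for v in values]
                             for name, values in disjoint.items()]))
# itertools.product crosses every value of every product-swept parameter
product_combos = list(itertools.product(*[[(name, v) for v in values]
                                          for name, values in product.items()]))
# Every disjoint combination is crossed with every product combination
all_combos = [d + p for d, p in itertools.product(disjoint_combos, product_combos)]
print(len(all_combos))  # k * k_1 * k_2 = 2 * 3 * 2 = 12
```

Each entry of `all_combos` is a complete assignment to all four swept parameters, mirroring one leaf of the sweep tree.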
678 | disjoint_uncond_combinations = list(zip( 679 | *map( 680 | lambda t: list( 681 | map(lambda e: Parameter(t.path, f'{".".join(t.path)}.{e[0]}', e[1]), enumerate(t.disjoint))), 682 | uncond_disjoint_sweeps)) 683 | ) 684 | all_combinations = disjoint_uncond_combinations 685 | 686 | # Next, process the sequential product sweeps 687 | product_uncond_combinations = list( 688 | itertools.product( 689 | *map(lambda t: list( 690 | map(lambda e: Parameter(t.path, 691 | f'{".".join(t.path)}.{e[0]}', 692 | e[1]), 693 | enumerate(t.product) 694 | ) 695 | ), 696 | uncond_product_sweeps) 697 | ) 698 | ) 699 | all_combinations = list( 700 | map(compose(list, cat), 701 | itertools.product(all_combinations, 702 | product_uncond_combinations) 703 | ) 704 | ) 705 | 706 | return all_combinations 707 | 708 | def fetch_swept_parameters(self, 709 | config, 710 | sweep_paths, 711 | swept_parameter_paths): 712 | """ 713 | Construct named tuples for all the parameters being swept. 714 | """ 715 | # Filter sweep paths to all 3 sweep types 716 | disjoint_sweep_paths = list(filter(lambda s: 'disjoint' in last(s), sweep_paths)) 717 | product_sweep_paths = list(filter(lambda s: 'product' in last(s), sweep_paths)) 718 | default_sweep_paths = list(filter(lambda s: 'default' in last(s), sweep_paths)) 719 | 720 | # Construct SweptParameter and Swept{Disjoint,Product,Default}Parameter namedtuples, 721 | # each 722 | # consisting of the path to the parameter and its sweep configuration 723 | construct_swept_parameters = lambda subtype, paths, wrapper: list( 724 | map(lambda p: subtype(wrapper(p), 725 | tz.get_in(p, coll=config) 726 | ), 727 | paths) 728 | ) 729 | 730 | swept_parameters = construct_swept_parameters(SweptParameter, 731 | swept_parameter_paths, 732 | identity) 733 | swept_disjoint_parameters = construct_swept_parameters(SweptDisjointParameter, 734 | disjoint_sweep_paths, 735 | QuinSweep.get_parameter_path) 736 | swept_product_parameters = construct_swept_parameters(SweptProductParameter, 737 | 
product_sweep_paths, 738 | QuinSweep.get_parameter_path) 739 | swept_default_parameters = construct_swept_parameters(SweptDefaultParameter, 740 | default_sweep_paths, 741 | QuinSweep.get_parameter_path) 742 | 743 | return swept_parameters, swept_disjoint_parameters, swept_product_parameters, swept_default_parameters 744 | 745 | 746 | if __name__ == '__main__': 747 | import os 748 | 749 | print(os.getcwd()) 750 | # sweep_config = Quinfig( 751 | # config_path='quinine/tests/derived-1-2.yaml') 752 | sweep_config = Quinfig( 753 | config_path='tests/bugs/test-2.yaml') 754 | # sweep_config = Quinfig( 755 | # config_path='/Users/krandiash/Desktop/workspace/projects/quinine/tests/derived-2.yaml') 756 | 757 | quin_sweep = QuinSweep(sweep_config=sweep_config) 758 | print(quin_sweep.quinfigs) 759 | -------------------------------------------------------------------------------- /quinine/common/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some common utilities. 3 | """ 4 | 5 | import yaml 6 | from funcy import * 7 | from munch import Munch 8 | import cytoolz as tz 9 | 10 | 11 | def difference(*colls): 12 | """ 13 | Find the keys that have different values in an arbitrary number of (nested) collections. Any key 14 | that differs in at least 2 collections is considered to fit this criterion. 
15 | """ 16 | 17 | # Get all the leaf paths for each collection: make each path a tuple 18 | leaf_paths_by_coll = list(map(lambda c: list(map(tuple, get_all_leaf_paths(c))), colls)) 19 | 20 | # Find the union of all leaf paths: merge all the paths and keep only the unique paths 21 | union_leaf_paths = list(distinct(concat(*leaf_paths_by_coll))) 22 | 23 | # Get the values corresponding to these leaf paths in every collection: if a leaf path doesn't exist, assumes None 24 | values_by_coll = list(map(lambda lp: list(map(lambda coll: tz.get_in(lp, coll), colls)), union_leaf_paths)) 25 | 26 | # Filter out the leaf paths that have identical values across the collections 27 | keep_leaf_paths = list(map(0, filter(lambda t: not allequal(t[1]), zip(union_leaf_paths, values_by_coll)))) 28 | keep_values = list(map(1, filter(lambda t: not allequal(t[1]), zip(union_leaf_paths, values_by_coll)))) 29 | 30 | # Rearrange to construct a list of dictionaries -- one per original collection. 31 | # Each of these dictionaries maps a 'kept' leaf path to its corresponding 32 | # value in the collection 33 | differences = list(map(lambda vals: dict(zip(keep_leaf_paths, vals)), list(zip(*keep_values)))) 34 | 35 | return differences 36 | 37 | 38 | 39 | def rmerge(*colls): 40 | """ 41 | Recursively merge an arbitrary number of collections. 42 | For conflicting values, later collections to the right are given priority. 43 | Note that this function treats sequences as a normal value and sequences are not merged. 
44 | 45 | Uses: 46 | - merging config files 47 | """ 48 | if isinstance(colls, tuple) and len(colls) == 1: 49 | # A squeeze operation since merge_with generates tuple(list_of_objs,) 50 | colls = colls[0] 51 | if all(is_mapping, colls): 52 | # Merges all the collections, recursively applies merging to the combined values 53 | return merge_with(rmerge, *colls) 54 | else: 55 | # If colls does not contain mappings, simply pick the last one 56 | return last(colls) 57 | 58 | 59 | def prettyprint(s): 60 | if hasattr(s, '__dict__'): 61 | print(yaml.dump(s.__dict__)) 62 | elif isinstance(s, dict): 63 | print(yaml.dump(s)) 64 | else: 65 | print(s) 66 | 67 | 68 | def allequal(seq): 69 | return len(set(seq)) <= 1 70 | 71 | 72 | @autocurry 73 | def listmap(fn, seq): 74 | return list(map(fn, seq)) 75 | 76 | 77 | @autocurry 78 | def prefix(s, p): 79 | if isinstance(s, str): 80 | return f'{p}{s}' 81 | elif isinstance(s, list): 82 | return list(map(prefix(p=p), s)) 83 | else: 84 | raise NotImplementedError 85 | 86 | 87 | @autocurry 88 | def postfix(s, p): 89 | if isinstance(s, str): 90 | return f'{s}{p}' 91 | elif isinstance(s, list): 92 | return list(map(postfix(p=p), s)) 93 | else: 94 | raise NotImplementedError 95 | 96 | 97 | @autocurry 98 | def surround(s, pre, post): 99 | return postfix(prefix(s, pre), post) 100 | 101 | 102 | def nested_map(f, *args): 103 | """ Recursively transpose a nested structure of tuples, lists, and dicts """ 104 | assert len(args) > 0, 'Must have at least one argument.' 105 | 106 | arg = args[0] 107 | if isinstance(arg, tuple) or isinstance(arg, list): 108 | return [nested_map(f, *a) for a in zip(*args)] 109 | elif isinstance(arg, dict): 110 | return { 111 | k: nested_map(f, *[a[k] for a in args]) 112 | for k in arg 113 | } 114 | else: 115 | return f(*args) 116 | 117 | 118 | @autocurry 119 | def walk_values_rec(f, coll): 120 | """ 121 | Similar to funcy's walk_values, but does so recursively, including mapping f over lists. 
122 | """ 123 | if is_mapping(coll): 124 | return f(walk_values(walk_values_rec(f), coll)) 125 | elif is_list(coll): 126 | return f(list(map(walk_values_rec(f), coll))) 127 | else: 128 | return f(coll) 129 | 130 | 131 | @autocurry 132 | def nested_dict_walker(fn, coll): 133 | """ 134 | Apply a function over the mappings contained in coll. 135 | """ 136 | return walk_values_rec(iffy(is_mapping, fn), coll) 137 | 138 | 139 | def get_all_leaf_paths(coll): 140 | """ 141 | Returns a list of paths to all leaf nodes in a nested dict. 142 | Paths can travel through lists and the index is inserted into the path. 143 | """ 144 | if isinstance(coll, dict) or isinstance(coll, Munch): 145 | return list(cat(map(lambda t: list(map(lambda p: [t[0]] + p, 146 | get_all_leaf_paths(t[1]) 147 | )), 148 | iteritems(coll))) 149 | ) 150 | 151 | elif isinstance(coll, list): 152 | return list(cat(map(lambda t: list(map(lambda p: [t[0]] + p, 153 | get_all_leaf_paths(t[1]) 154 | )), 155 | enumerate(coll))) 156 | ) 157 | else: 158 | return [[]] 159 | 160 | 161 | def get_all_paths(coll, prefix_path=(), stop_at=None, stop_below=None): 162 | """ 163 | Given a collection, by default returns paths to all the leaf nodes. 164 | Use stop_at to truncate paths at the given key. 165 | Use stop_below to truncate paths one level below the given key. 166 | """ 167 | assert stop_at is None or stop_below is None, 'Only one of stop_at or stop_below can be used.' 
168 | if stop_below is not None and stop_below in str(last(butlast(prefix_path))): 169 | return [[]] 170 | if stop_at is not None and stop_at in str(last(prefix_path)): 171 | return [[]] 172 | if isinstance(coll, dict) or isinstance(coll, Munch) or isinstance(coll, list): 173 | if isinstance(coll, dict) or isinstance(coll, Munch): 174 | items = iteritems(coll) 175 | else: 176 | items = enumerate(coll) 177 | 178 | return list(cat(map(lambda t: list(map(lambda p: [t[0]] + p, 179 | get_all_paths(t[1], 180 | prefix_path=list(prefix_path) + [t[0]], 181 | stop_at=stop_at, 182 | stop_below=stop_below) 183 | )), 184 | items)) 185 | ) 186 | else: 187 | return [[]] 188 | 189 | 190 | def get_only_paths(coll, pred, prefix_path=(), stop_at=None, stop_below=None): 191 | """ 192 | Get all paths that satisfy the predicate fn pred. 193 | First gets all paths and then filters them based on pred. 194 | """ 195 | all_paths = get_all_paths(coll, prefix_path=prefix_path, stop_at=stop_at, stop_below=stop_below) 196 | return list(filter(pred, all_paths)) 197 | 198 | if __name__ == '__main__': 199 | coll1 = {'a': 1, 200 | 'b': 2, 201 | 'c': {'d': 12}} 202 | coll2 = {'a': 1, 203 | 'b': 2, 204 | 'c': {'d': 13}} 205 | coll3 = {'a': 1, 206 | 'b': 3, 207 | 'c': {'d': 14}, 208 | 'e': 4} 209 | 210 | difference(coll1, coll2, coll3) -------------------------------------------------------------------------------- /quinine/examples/base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | architecture: resnet18 -------------------------------------------------------------------------------- /quinine/examples/config.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | seed: 2 3 | module: test.py 4 | 5 | model: 6 | pretrained: true 7 | 8 | dataset: 9 | - name: cifar10 10 | - name: imagenet 11 | 12 | 13 | gin: 14 | 'a_gin_configurable_fn.print_yes': true 15 | 16 | templating: 17 | parent_yaml: 
quinine/examples/base.yaml 18 | -------------------------------------------------------------------------------- /quinine/examples/run.py: -------------------------------------------------------------------------------- 1 | from quinine.common.cerberus import * 2 | from quinine.common.gin import * 3 | import quinine.examples.simple 4 | from quinine.quinfig import Quinfig 5 | from quinine.examples.simple import simple_program 6 | 7 | 8 | def simple_example(): 9 | # Create a simple schema using Cerberus: shortcuts make it easy to write complex schemas with reusable components 10 | schema = {'general': stdict({'seed': merge(tinteger, default(0)), 11 | 'module': merge(tstring, required), 12 | }), 13 | 'model': stdict({'architecture': merge(tstring, allowed(['resnet18', 'resnet50'])), 14 | 'pretrained': merge(tboolean, required) 15 | }), 16 | } 17 | 18 | # Internally, Quinfig will autoupdate this schema using the autoexpand_schema function to support gin 19 | prettyprint(autoexpand_schema(schema)) 20 | 21 | # Write out the config: you could also have written this in a yaml file 22 | config = {'general': {'seed': 2, 23 | 'module': 'test.py'}, 24 | 'model': {'pretrained': True}, 25 | 'gin': 26 | # below, we set the print_yes argument in a_gin_configurable_fn to True 27 | # you can use as much or as little gin configuration as you like 28 | # e.g. you could write your entire configuration in gin or not use gin at all 29 | {'a_gin_configurable_fn.print_yes': True}, 30 | 'templating': 31 | # first inherit all the configuration settings in tests/base.yaml and then overwrite them with this config 32 | {'parent_yaml': 'examples/base.yaml'} 33 | } 34 | 35 | # Register the module that we want to configure with gin 36 | register_module_with_gin(quinine.examples.simple, 'examples.simple') 37 | 38 | # Create the quinfig 39 | quinfig = Quinfig(config=config, 40 | schema=schema) 41 | # Voila! 
42 | simple_program(quinfig) 43 | 44 | # Or you could have used the yaml 45 | quinfig = Quinfig(config_path='examples/config.yaml', 46 | schema=schema) 47 | simple_program(quinfig) 48 | 49 | 50 | if __name__ == '__main__': 51 | simple_example() 52 | -------------------------------------------------------------------------------- /quinine/examples/simple.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple example program that we want to control using a configuration file. 3 | """ 4 | 5 | 6 | def a_gin_configurable_fn(print_yes): 7 | if print_yes: 8 | print("Yes") 9 | else: 10 | print("No") 11 | 12 | 13 | def simple_program(quinfig): 14 | print("Running simple program.") 15 | print(quinfig) 16 | print("Pay attention to what's being printed below.") 17 | a_gin_configurable_fn() 18 | -------------------------------------------------------------------------------- /quinine/quinfig.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines the main Quinfig class. 3 | """ 4 | 5 | from quinine.common.cerberus import * 6 | from quinine.common.gin import * 7 | from quinine.common.utils import * 8 | 9 | # Register all schemas 10 | create_and_register_schemas() 11 | 12 | 13 | class Quinfig(Munch): 14 | """ 15 | The Quinfig class for creating configuration files. 
16 | 17 | Quinfig provides a simple abstraction for configuration files with several useful features: 18 | - write (arbitrarily nested) configuration files in YAML or just specify them as dictionaries in Python 19 | - access keys with dot access 20 | - validate your configs against schemas using Cerberus 21 | - set fn arguments using gin 22 | """ 23 | 24 | def __init__(self, 25 | config_path=None, 26 | schema_path=None, 27 | config=None, 28 | schema=None, 29 | base_path=None, 30 | ): 31 | # Prepare the config 32 | config = prepare_config(config_path=config_path, 33 | schema_path=schema_path, 34 | config=config, 35 | schema=schema, 36 | base_path=base_path) 37 | 38 | # Create the Quinfig 39 | super(Quinfig, self).__init__(config) 40 | 41 | def __repr__(self): 42 | """ 43 | Use a yaml dump to create a formatted representation for the Quinfig. 44 | """ 45 | return f'Quinfig\n' \ 46 | f'-------\n' \ 47 | f'{yaml.dump(self.__dict__)}' 48 | 49 | 50 | def prepare_config(config_path=None, 51 | schema_path=None, 52 | config=None, 53 | schema=None, 54 | base_path=None) -> Munch: 55 | """ 56 | Takes in paths to config and schema files. 57 | Validates the config against the schema, normalizes the config, parses gin and converts the config to a Munch. 58 | """ 59 | # Load up the config 60 | if config is None: 61 | assert config_path is not None, 'Please pass in either config or config_path.' 62 | assert config_path.endswith('.yaml'), 'Must use a YAML file for the config.' 
63 | config = yaml.load(open(config_path), 64 | Loader=yaml.FullLoader) 65 | 66 | # If the config is a Quinfig object, just grab the __dict__ for convenience 67 | if isinstance(config, Quinfig): 68 | config = config.__dict__ 69 | 70 | # Convert config to Munch: iffy ensures that the Munch fn is only applied to mappings 71 | config = walk_values_rec(iffy(is_mapping, lambda c: Munch(**c)), config) 72 | 73 | # Load up the schema 74 | if schema is None: 75 | if schema_path is not None: 76 | assert schema_path.endswith('.yaml'), 'Must use a YAML file for the schema.' 77 | schema = yaml.load(open(schema_path), 78 | Loader=yaml.FullLoader) 79 | 80 | if schema is not None: 81 | # Allow gin configuration at any level of nesting: put a gin tag at every level of the schema 82 | schema = autoexpand_schema(schema) 83 | 84 | # Validate the config against the schema 85 | validate_config(config, schema) 86 | 87 | # Normalize the config 88 | if not base_path: 89 | base_path = os.path.dirname(os.path.abspath(config_path)) if config_path else '' 90 | else: 91 | base_path = os.path.abspath(base_path) 92 | config = normalize_config(config, schema, base_path=base_path) 93 | 94 | # Convert config to Munch: iffy ensures that the Munch fn is only applied to mappings 95 | config = walk_values_rec(iffy(is_mapping, lambda c: Munch(**c)), config) 96 | 97 | # Parse and load the gin configuration 98 | nested_gin_dict_parser(config) 99 | 100 | return config 101 | -------------------------------------------------------------------------------- /quinine/tests/base.yaml: -------------------------------------------------------------------------------- 1 | a: 2 | b: 3 | c: 4 | d: 5 5 | e: 6 | ~product: 7 | - 2 8 | - 4 9 | - 6 10 | f: 12 -------------------------------------------------------------------------------- /quinine/tests/bugs/test-2.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | module: train.augment 3 | 4 | dataset: 5 | name: imdb 6 | 
data_dir: /home/workspace/datasets/nlp/tensorflow/ 7 | version: null 8 | 9 | augmentation: 10 | name: 11 | ~disjoint: 12 | - backtranslation 13 | - substitution 14 | - eda 15 | variant: a 16 | model: 17 | ~disjoint: 18 | name.0: 19 | - en2de 20 | name.1: 21 | - glove 22 | - roberta-base 23 | - ppdb-s 24 | name.2: 25 | - null 26 | 27 | model_dir: 28 | ~disjoint: 29 | name.0: 30 | - /home/workspace/models/torchhub/ 31 | name.1: 32 | - /home/workspace/models/nlpaug/ 33 | - null 34 | - /home/workspace/models/nlpaug/ 35 | name.2: 36 | - null 37 | batch_size: 38 | ~product: 39 | name.0: 40 | - 64 41 | ~default: null 42 | num_aug: 43 | ~product: 44 | - 4 45 | - 9 46 | store_dir: /home/workspace/datasets/nlp/augmented/tensorflow/ 47 | 48 | wandb: 49 | group: augmentation-imdb -------------------------------------------------------------------------------- /quinine/tests/bugs/test.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | module: train.trainer 3 | 4 | dataset: 5 | name: cimdb 6 | version: orig 7 | data_dir: /home/workspace/datasets/nlp/from_source/counterfactually-augmented-data/sentiment/ 8 | train: 9 | batch_size: 16 10 | shuffle_buffer: 1000 11 | 12 | augmentation: 13 | name: 14 | ~disjoint: 15 | - null 16 | - backtranslation 17 | - substitution 18 | - eda 19 | variant: 20 | ~product: 21 | name.0: 22 | - null 23 | name._: 24 | - a 25 | - b 26 | - both 27 | model: 28 | ~product: 29 | name.1: 30 | - en2de 31 | name.2: 32 | - glove 33 | - roberta-base 34 | - ppdb-s 35 | name._: 36 | - null 37 | 38 | num_aug: 39 | ~product: 40 | name.0: 41 | - null 42 | name._: 43 | - 4 44 | - 9 45 | store_dir: /home/workspace/datasets/nlp/augmented/from_source/ 46 | 47 | 48 | trainer: 49 | epochs: 20 50 | 51 | optimizer: 52 | lr: 0.00001 53 | 54 | model: 55 | architecture: bert-base-uncased 56 | 57 | features: 58 | max_length: 350 59 | 60 | wandb: 61 | job_type: training 62 | 
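The test configs above exercise quinine's sweep tags. Setting aside the conditional keys (`name.0`, `name._`, which select values per disjoint branch), the core `~disjoint`/`~product` expansion can be sketched in plain Python. The `expand_sweeps` helper below is hypothetical and works over flat dotted keys; it illustrates the semantics, not quinine's actual implementation:

```python
# Hypothetical mini-expander (not quinine's code) illustrating sweep semantics:
# ~disjoint lists are advanced in lockstep (aligned), while ~product lists are
# crossed with every existing combination.
def expand_sweeps(base, disjoint=None, product=None):
    configs = [dict(base)]
    if disjoint:
        lengths = {len(v) for v in disjoint.values()}
        assert len(lengths) == 1, 'All disjoint sweeps must have the same length.'
        # Zip the disjoint value lists together: one combined setting per index
        configs = [{**c, **dict(zip(disjoint, values))}
                   for c in configs
                   for values in zip(*disjoint.values())]
    # Each product key multiplies the number of configs by its value count
    for key, values in (product or {}).items():
        configs = [{**c, key: v} for c in configs for v in values]
    return configs

configs = expand_sweeps(
    base={'general.module': 'train.augment'},
    disjoint={'augmentation.name': ['backtranslation', 'eda'],
              'augmentation.model': ['en2de', None]},
    product={'augmentation.num_aug': [4, 9]},
)
print(len(configs))  # 2 aligned disjoint settings x 2 product values = 4 configs
```

In the real library, conditional sweeps additionally key the value lists by a reference dotpath (e.g. `name.0`), so a value list only applies to combinations that took the corresponding disjoint branch, as in `process_conditional_sweeps` above.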
-------------------------------------------------------------------------------- /quinine/tests/derived-1-1.yaml: -------------------------------------------------------------------------------- 1 | templating: 2 | parent_yaml: quinine/tests/derived-1.yaml 3 | a: 4 | b: 5 | c: 6 | d: 7 | ~disjoint: 8 | - 17 9 | - 23 10 | - 33 11 | - 39 12 | e: 13 | ~product: 14 | - 2 15 | - 4 16 | - 6 17 | f: 18 | ~disjoint: 19 | - 1 20 | - 2 21 | - 3 22 | - 4 23 | h: 24 | j: [-1, -2, -3] 25 | l: 26 | ~product: 27 | - 1 28 | - 2 29 | - 3 -------------------------------------------------------------------------------- /quinine/tests/derived-1-2.yaml: -------------------------------------------------------------------------------- 1 | templating: 2 | parent_yaml: quinine/tests/derived-1.yaml 3 | a: 4 | b: 5 | c: 6 | d: 7 | ~disjoint: 8 | - 17 9 | - 23 10 | - 33 11 | - 39 12 | e: 13 | ~product: 14 | d.0.m.1: 15 | - 2 16 | - 4 17 | - 6 18 | d._: 19 | - 1 20 | - 2 21 | - 3 22 | ~default: 1 23 | f: 24 | ~disjoint: 25 | - 1 26 | - 2 27 | - 3 28 | - 4 29 | h: 30 | j: [-1, -2, -3] 31 | l: 32 | ~product: 33 | - 1 34 | - 2 35 | - 3 36 | m: 37 | ~disjoint: 38 | d.0: 39 | - 0 40 | - 1 41 | d.1: 42 | - 1 43 | - 2 44 | ~default: 2 45 | n: 46 | ~disjoint: 47 | d.0: 48 | - 0 49 | - 1 50 | ~product: 51 | d._: 52 | - 1 53 | - 2 54 | o: 55 | ~product: 56 | - 10 57 | - 11 -------------------------------------------------------------------------------- /quinine/tests/derived-1.yaml: -------------------------------------------------------------------------------- 1 | templating: 2 | parent_yaml: quinine/tests/base.yaml 3 | a: 4 | b: 5 | c: 6 | d: 5 7 | e: 8 | ~product: 9 | - 2 10 | - 4 11 | - 6 12 | f: 7 13 | g: 1 -------------------------------------------------------------------------------- /quinine/tests/derived-2.yaml: -------------------------------------------------------------------------------- 1 | templating: 2 | parent_yaml: quinine/tests/derived-1.yaml 3 | a: 4 | b: 5 | c: 6 | d: 7 | 
~disjoint: 8 | - 17 9 | - 23 10 | - 33 11 | - 39 12 | e: 13 | ~product: 14 | d.0: 15 | - 2 16 | - 4 17 | - 6 18 | d._: 19 | - 1 20 | - 2 21 | - 3 22 | ~default: 1 23 | f: 24 | ~disjoint: 25 | - 1 26 | - 2 27 | - 3 28 | - 4 29 | n: 30 | ~disjoint: 31 | d.0: 32 | - 0 33 | - 1 34 | ~product: 35 | d._: 36 | - 1 37 | - 2 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Automatically generated by https://github.com/damnever/pigar. 2 | Cerberus==1.3.2 3 | pyyaml==5.4 4 | cytoolz==0.11.0 5 | funcy==1.15 6 | gin_config==0.3.0 7 | munch==2.5.0 8 | toposort==1.5 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setup script for the Quinine library. 3 | """ 4 | import os 5 | from setuptools import setup, find_packages 6 | 7 | 8 | # Utility function to read the README file. 9 | # Used for the long_description. It's nice, because now 1) we have a top level 10 | # README file and 2) it's easier to type in the README file than to put a raw 11 | # string in below ... 
12 | def read(fname): 13 | return open(os.path.join(os.path.dirname(__file__), fname)).read() 14 | 15 | 16 | def req_file(filename): 17 | with open(filename) as f: 18 | content = f.readlines() 19 | return [x.strip() for x in content] 20 | 21 | 22 | install_requires = req_file("requirements.txt") 23 | 24 | setup( 25 | name="quinine", 26 | version="0.3.0", 27 | author="Karan Goel", 28 | author_email="kgoel93@gmail.com", 29 | license="MIT", 30 | description="quinine is a library for configuring machine learning projects.", 31 | keywords="configuration yaml machine learning ml ai nlp cv vision deep learning", 32 | # url="http://packages.python.org/an_example_pypi_project", 33 | packages=['quinine', 'quinine.common'], 34 | # long_description=read('README.md'), 35 | install_requires=install_requires, 36 | ) 37 | --------------------------------------------------------------------------------
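Note that `req_file` in setup.py returns every line verbatim, including the generator comment at the top of requirements.txt. A common hardening step is to skip blank and comment lines before handing the list to `install_requires`. A self-contained sketch of that variant (the `req_lines` name is illustrative, and an in-memory file stands in for requirements.txt):

```python
import io

# Like req_file in setup.py above, but skipping blank lines and comments so
# that install_requires only receives real requirement specifiers.
def req_lines(f):
    return [line.strip() for line in f
            if line.strip() and not line.lstrip().startswith('#')]

sample = io.StringIO(
    "# Automatically generated by https://github.com/damnever/pigar.\n"
    "Cerberus==1.3.2\n"
    "pyyaml==5.4\n"
    "\n"
    "munch==2.5.0\n"
)
print(req_lines(sample))  # ['Cerberus==1.3.2', 'pyyaml==5.4', 'munch==2.5.0']
```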