├── hydra_plugins └── list_sweeper_plugin │ ├── py.typed │ ├── __init__.py │ └── list_sweeper.py ├── hydra_list_sweeper.egg-info ├── dependency_links.txt ├── requires.txt ├── top_level.txt ├── SOURCES.txt └── PKG-INFO ├── .gitignore ├── example ├── conf │ ├── algorithm │ │ ├── sgd.yaml │ │ └── adam.yaml │ ├── env │ │ ├── 5_clubs_juggling.yaml │ │ └── balancing_stick.yaml │ └── config.yaml └── my_app.py ├── tests ├── __init__.py └── test_example_sweeper_plugin.py ├── MANIFEST.in ├── setup.py └── README.md /hydra_plugins/list_sweeper_plugin/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hydra_list_sweeper.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | example/outputs 3 | example/multirun -------------------------------------------------------------------------------- /hydra_list_sweeper.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | hydra-core 2 | -------------------------------------------------------------------------------- /example/conf/algorithm/sgd.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.00001 2 | momentum: 0.9 -------------------------------------------------------------------------------- /hydra_list_sweeper.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | hydra_plugins 2 | -------------------------------------------------------------------------------- /example/conf/env/5_clubs_juggling.yaml: -------------------------------------------------------------------------------- 1 | num_props: 5 2 | 
num_jugglers: 3 3 | num_rounds: 3 -------------------------------------------------------------------------------- /example/conf/algorithm/adam.yaml: -------------------------------------------------------------------------------- 1 | lr: 0.001 2 | beta_1: 0.9 3 | beta_2: 0.999 4 | epsilon: 1e-8 -------------------------------------------------------------------------------- /example/conf/env/balancing_stick.yaml: -------------------------------------------------------------------------------- 1 | num_props: 1 2 | num_jugglers: 1 3 | num_rounds: 10 4 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | global-exclude *.pyc 2 | global-exclude __pycache__ 3 | recursive-include hydra_plugins/* *.yaml py.typed 4 | -------------------------------------------------------------------------------- /hydra_plugins/list_sweeper_plugin/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /example/my_app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import hydra 3 | from omegaconf import OmegaConf, DictConfig 4 | 5 | 6 | @hydra.main(version_base=None, config_path="conf", config_name="config") 7 | def my_app(cfg: DictConfig) -> None: 8 | print(OmegaConf.to_yaml(cfg)) 9 | 10 | 11 | if __name__ == "__main__": 12 | my_app() 13 | -------------------------------------------------------------------------------- /hydra_list_sweeper.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | MANIFEST.in 2 | README.md 3 | setup.py 4 | hydra_list_sweeper.egg-info/PKG-INFO 5 | hydra_list_sweeper.egg-info/SOURCES.txt 6 | hydra_list_sweeper.egg-info/dependency_links.txt 7 | hydra_list_sweeper.egg-info/requires.txt 8 | hydra_list_sweeper.egg-info/top_level.txt 9 | hydra_plugins/list_sweeper_plugin/__init__.py 10 | hydra_plugins/list_sweeper_plugin/list_sweeper.py 11 | hydra_plugins/list_sweeper_plugin/py.typed 12 | tests/test_example_sweeper_plugin.py -------------------------------------------------------------------------------- /example/conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - algorithm: adam 4 | - env: balancing_stick 5 | # override this to use the new list sweeper: 6 | - override hydra/sweeper: list 7 | 8 | hydra: 9 | mode: MULTIRUN 10 | sweeper: 11 | # standard grid search 12 | grid_params: 13 | env: 5_clubs_juggling, balancing_stick 14 | # additional list sweeper 15 | list_params: 16 | algorithm.lr: 0.001, 0.0001 17 | algorithm.beta_1: [0.9, 0.99] # both notations work 18 | ablative_params: 19 | - algorithm.beta_2: 0.5 20 | algorithm.epsilon: 1e-4 21 | - algorithm.beta_1: 0.3 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | # type: ignore 3 | from setuptools import find_namespace_packages, setup 4 | 5 | with open("README.md") as fh: 6 | LONG_DESC = fh.read() 7 | setup( 8 | name="hydra-list-sweeper", 9 | version="1.1.0", 10 | author="Philipp Dahlinger", 11 | author_email="philipp.dahlinger@kit.edu", 12 | description="List Hydra Sweeper plugin", 13 | long_description=LONG_DESC, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/facebookresearch/hydra/", 16 | packages=find_namespace_packages(include=["hydra_plugins.*"]), 17 | classifiers=[ 18 | # Feel free to use another license. 19 | "License :: OSI Approved :: MIT License", 20 | # Hydra uses Python version and Operating system to determine 21 | # In which environments to test this plugin 22 | "Programming Language :: Python :: 3.7", 23 | "Programming Language :: Python :: 3.8", 24 | "Programming Language :: Python :: 3.9", 25 | "Programming Language :: Python :: 3.10", 26 | "Programming Language :: Python :: 3.11", 27 | "Operating System :: OS Independent", 28 | ], 29 | install_requires=[ 30 | # consider pinning to a specific major version of Hydra to avoid unexpected problems 31 | # if a new major version of Hydra introduces breaking changes for plugins. 32 | # e.g: "hydra-core==1.0.*", 33 | "hydra-core", 34 | ], 35 | # If this plugin is providing configuration files, be sure to include them in the package. 36 | # See MANIFEST.in. 37 | # For configurations to be discoverable at runtime, they should also be added to the search path. 38 | include_package_data=True, 39 | ) 40 | -------------------------------------------------------------------------------- /tests/test_example_sweeper_plugin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from pytest import mark 3 | 4 | from hydra.core.plugins import Plugins 5 | from hydra.plugins.sweeper import Sweeper 6 | from hydra.test_utils.launcher_common_tests import ( 7 | BatchedSweeperTestSuite, 8 | IntegrationTestSuite, 9 | LauncherTestSuite, 10 | ) 11 | from hydra.test_utils.test_utils import TSweepRunner 12 | 13 | from hydra_plugins.example_sweeper_plugin.example_sweeper import ExampleSweeper 14 | 15 | 16 | def test_discovery() -> None: 17 | # Tests that this plugin can be discovered via the plugins subsystem when looking at the Sweeper plugins 18 | assert ExampleSweeper.__name__ in [ 19 | x.__name__ for x in Plugins.instance().discover(Sweeper) 20 | ] 21 | 22 | 23 | def test_launched_jobs(hydra_sweep_runner: TSweepRunner) -> None: 24 | sweep = hydra_sweep_runner( 25 | calling_file=None, 26 | calling_module="hydra.test_utils.a_module", 27 | config_path="configs", 28 | config_name="compose.yaml", 29 | task_function=None, 30 | overrides=["hydra/sweeper=example", "hydra/launcher=basic", "foo=1,2"], 31 | ) 32 | with sweep: 33 | assert sweep.returns is not None 34 | job_ret = sweep.returns[0] 35 | assert len(job_ret) == 2 36 | assert job_ret[0].overrides == ["foo=1"] 37 | assert job_ret[0].cfg == {"foo": 1, "bar": 100} 38 | assert job_ret[1].overrides == ["foo=2"] 39 | assert job_ret[1].cfg == {"foo": 2, "bar": 100} 40 | 41 | 42 | # Run launcher test suite with the basic launcher and this sweeper 43 | @mark.parametrize( 44 | "launcher_name, overrides", 45 | [ 46 | ( 47 | "basic", 48 | [ 49 | # CHANGE THIS TO YOUR SWEEPER CONFIG NAME 50 | "hydra/sweeper=example" 51 | ], 52 | ) 53 | ], 54 | ) 55 | class TestExampleSweeper(LauncherTestSuite): 56 | ... 57 | 58 | 59 | # Many sweepers are batching jobs in groups. 60 | # This test suite verifies that the spawned jobs are not overstepping the directories of one another. 
61 | @mark.parametrize( 62 | "launcher_name, overrides", 63 | [ 64 | ( 65 | "basic", 66 | [ 67 | # CHANGE THIS TO YOUR SWEEPER CONFIG NAME 68 | "hydra/sweeper=example", 69 | # This will cause the sweeper to split batches to at most 2 jobs each, which is what 70 | # the tests in BatchedSweeperTestSuite are expecting. 71 | "hydra.sweeper.max_batch_size=2", 72 | ], 73 | ) 74 | ], 75 | ) 76 | class TestExampleSweeperWithBatching(BatchedSweeperTestSuite): 77 | ... 78 | 79 | 80 | # Run integration test suite with the basic launcher and this sweeper 81 | @mark.parametrize( 82 | "task_launcher_cfg, extra_flags", 83 | [ 84 | ( 85 | {}, 86 | [ 87 | "-m", 88 | # CHANGE THIS TO YOUR SWEEPER CONFIG NAME 89 | "hydra/sweeper=example", 90 | "hydra/launcher=basic", 91 | ], 92 | ) 93 | ], 94 | ) 95 | class TestExampleSweeperIntegration(IntegrationTestSuite): 96 | pass 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # List Sweeper plugin for Hydra 2 | 3 | Sweeper plugin for Hydra which creates a list option additionally to the cartesian product ("grid"), 4 | which allows to sweep over the zipped list of parameters. 5 | This allows to test only a subset of the cartesian product and it is useful for small hyperparameter searches. 6 | 7 | ## Installation 8 | ```bash 9 | pip install hydra-list-sweeper 10 | ``` 11 | This will install the plugin in your current environment. 12 | 13 | You can check if the plugin is installed by adding `--info plugins` to your command line. 14 | The plugin should be listed in the output as `hydra_plugins.list_sweeper_plugin.list_sweeper`. 
15 | 16 | In order to enable this plugin, you need to override the default sweeper in your configuration file: 17 | ```yaml 18 | defaults: 19 | - _self_ 20 | # override this to use the new list sweeper: 21 | - override hydra/sweeper: list 22 | ``` 23 | 24 | ## Usage 25 | 26 | `List sweeper` uses a similar syntax as the standard sweeper, but instead of a `params` key, it uses `grid_params` and `list_params`: 27 | ```yaml 28 | hydra: 29 | mode: MULTIRUN 30 | sweeper: 31 | # standard grid search 32 | grid_params: 33 | env: 5_clubs_juggling, balancing_stick 34 | # additional list sweeper 35 | list_params: 36 | algorithm.lr: 0.001, 0.0001 37 | algorithm.beta_1: [0.9, 0.99] # both notations work 38 | ``` 39 | This configuration will create 4 jobs: 40 | ```text 41 | env=5_clubs_juggling, algorithm.lr=0.001, algorithm.beta_1=0.9 42 | env=5_clubs_juggling, algorithm.lr=0.0001, algorithm.beta_1=0.99 43 | env=balancing_stick, algorithm.lr=0.001, algorithm.beta_1=0.9 44 | env=balancing_stick, algorithm.lr=0.0001, algorithm.beta_1=0.99 45 | ``` 46 | 47 | Basically, it grids over all grid params, creating the standard cartesian product, 48 | and then for each of these combinations, it creates a job for each of the list params. 49 | You can additionally overwrite single values with command line arguments, and even define your grid_params in the command line: 50 | ```yaml 51 | hydra: 52 | mode: MULTIRUN 53 | sweeper: 54 | # additional list sweeper 55 | list_params: 56 | algorithm.lr: 0.001, 0.0001 57 | algorithm.beta_1: [0.9, 0.99] # both notations work 58 | ``` 59 | Combined with this command 60 | ```bash 61 | python my_app.py env=5_clubs_juggling,balancing_stick 62 | ``` 63 | will produce the same results as the first example. 
Also, you can override configs with the command line and the grid_params: 64 | 65 | ```yaml 66 | hydra: 67 | mode: MULTIRUN 68 | sweeper: 69 | # standard grid search 70 | grid_params: 71 | env: 5_clubs_juggling, balancing_stick 72 | # additional list sweeper 73 | list_params: 74 | algorithm.lr: 0.001, 0.0001 75 | algorithm.beta_1: [0.9, 0.99] # both notations work 76 | ``` 77 | Combined with this command: 78 | ```bash 79 | python my_app.py algorithm.epsilon=1.0e-4 80 | ``` 81 | 82 | will produce the same results as the first example, but epsilon will be set to 1.0e-4 for all jobs. 83 | 84 | If you remove the `list_params` section, it will behave exactly as the standard grid sweeper (at least it should; if you find a bug, please report it). 85 | 86 | # Ablative params 87 | You can additionally define an `ablative_params` section. This must be a list of dictionaries. For example 88 | ```yaml 89 | ablative_params: 90 | - algorithm.beta_2: 0.5 91 | algorithm.epsilon: 1e-4 92 | - algorithm.beta_1: 0.3 93 | ``` 94 | If the `ablative_params` are present, it will 95 | 1. sweep over all the jobs generated by list and grid ignoring the ablative params 96 | 2. For each dictionary in `ablative_params`, it will replace or add the key-value pairs in the dictionary to all the jobs generated in step 1. 97 | 98 | In the example above, it will generate 4 jobs from the list and grid params. Since 2 ablative dictionaries are present, it will in total generate $4 * (1 + 1 + 1) = 12$ jobs: 99 | 4 jobs from the list and grid params, and 4 jobs per dictionary in the `ablative_params` section. 100 | 101 | ## Limitations 102 | In the `ablative_params` section, you can only specify concrete parameters, changing a complete sub-config is not implemented currently. 
103 | -------------------------------------------------------------------------------- /hydra_list_sweeper.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: hydra-list-sweeper 3 | Version: 1.1.0 4 | Summary: List Hydra Sweeper plugin 5 | Home-page: https://github.com/facebookresearch/hydra/ 6 | Author: Philipp Dahlinger 7 | Author-email: philipp.dahlinger@kit.edu 8 | Classifier: License :: OSI Approved :: MIT License 9 | Classifier: Programming Language :: Python :: 3.7 10 | Classifier: Programming Language :: Python :: 3.8 11 | Classifier: Programming Language :: Python :: 3.9 12 | Classifier: Programming Language :: Python :: 3.10 13 | Classifier: Programming Language :: Python :: 3.11 14 | Classifier: Operating System :: OS Independent 15 | Description-Content-Type: text/markdown 16 | Requires-Dist: hydra-core 17 | 18 | # List Sweeper plugin for Hydra 19 | 20 | Sweeper plugin for Hydra which creates a list option additionally to the cartesian product ("grid"), 21 | which allows to sweep over the zipped list of parameters. 22 | This allows to test only a subset of the cartesian product and it is useful for small hyperparameter searches. 23 | 24 | ## Installation 25 | ```bash 26 | pip install hydra-list-sweeper 27 | ``` 28 | This will install the plugin in your current environment. 29 | 30 | You can check if the plugin is installed by adding `--info plugins` to your command line. 31 | The plugin should be listed in the output as `hydra_plugins.list_sweeper_plugin.list_sweeper`. 
32 | 33 | In order to enable this plugin, you need to override the default sweeper in your configuration file: 34 | ```yaml 35 | defaults: 36 | - _self_ 37 | # override this to use the new list sweeper: 38 | - override hydra/sweeper: list 39 | ``` 40 | 41 | ## Usage 42 | 43 | `List sweeper` uses the a similar syntax as the standard sweeper, but instead of a `params` key, it uses `grid_params` and `list_params`: 44 | ```yaml 45 | hydra: 46 | mode: MULTIRUN 47 | sweeper: 48 | # standard grid search 49 | grid_params: 50 | env: 5_clubs_juggling, balancing_stick 51 | # additional list sweeper 52 | list_params: 53 | algorithm.lr: 0.001, 0.0001 54 | algorithm.beta_1: [0.9, 0.99] # both notations work 55 | ``` 56 | This configuration will create 4 jobs: 57 | ```text 58 | env=5_clubs_juggling, algorithm.lr=0.001, algorithm.beta_1=0.9 59 | env=5_clubs_juggling, algorithm.lr=0.0001, algorithm.beta_1=0.99 60 | env=balancing_stick, algorithm.lr=0.001, algorithm.beta_1=0.9 61 | env=balancing_stick, algorithm.lr=0.0001, algorithm.beta_1=0.99 62 | ``` 63 | 64 | Basically, it grids over all grid params, creating the standard cartesian product, 65 | and then for each of these combinations, it creates a job for each of the list params. 66 | You can additionally overwrite single values with command line arguments, and even define your grid_params in the command line: 67 | ```yaml 68 | hydra: 69 | mode: MULTIRUN 70 | sweeper: 71 | # additional list sweeper 72 | list_params: 73 | algorithm.lr: 0.001, 0.0001 74 | algorithm.beta_1: [0.9, 0.99] # both notations work 75 | ``` 76 | Combined with this command 77 | ```bash 78 | python my_app.py env=5_clubs_juggling,balancing_stick 79 | ``` 80 | will produce the same results as the first example. 
Also, you can override configs with the command line and the grid_params: 81 | 82 | ```yaml 83 | hydra: 84 | mode: MULTIRUN 85 | sweeper: 86 | # standard grid search 87 | grid_params: 88 | env: 5_clubs_juggling, balancing_stick 89 | # additional list sweeper 90 | list_params: 91 | algorithm.lr: 0.001, 0.0001 92 | algorithm.beta_1: [0.9, 0.99] # both notations work 93 | ``` 94 | Combined with this command: 95 | ```bash 96 | python my_app.py algorithm.epsilon=1.0e-4 97 | ``` 98 | 99 | will produce the same results as the first example, but epsilon will be set to 1.0e-4 for all jobs. 100 | 101 | If you remove the `list_params` section, it will behave exactly as the standard grid sweeper (at least it should do, if you find a bug, please report it). 102 | 103 | # Ablative params 104 | You can additionaly define a `ablative_params` section. This must be a list of dictionaries. For example 105 | ```yaml 106 | ablative_params: 107 | - algorithm.beta_2: 0.5 108 | algorithm.epsilon: 1e-4 109 | - algorithm.beta_1: 0.3 110 | ``` 111 | If the `ablative_params` are present, it will 112 | 1. sweep over all the jobs generated by list and grid ignoring the ablative params 113 | 2. For each dictionary in `ablative_params`, it will replace or add the key-value pairs in the dictionary to all the jobs generated in step 1. 114 | 115 | In the example above, it will generate 4 jobs from the list and grid params. Since 2 ablative dictionaries are present, it will in total generate $4 * (1 +1+ 1)$ jobs. 116 | 4 jobs from the list and grid params, and 4 jobs per dictionary in the `ablative_params` section. 117 | 118 | ## Limitations 119 | In the `ablative_parmas` section, you can only specify concrete parameters, changing a complete sub-config is not implemented currently. 
120 | -------------------------------------------------------------------------------- /hydra_plugins/list_sweeper_plugin/list_sweeper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import copy 3 | from dataclasses import dataclass 4 | 5 | import itertools 6 | import logging 7 | from pathlib import Path 8 | from typing import Any, Iterable, List, Optional, Sequence 9 | 10 | from hydra.types import HydraContext 11 | from hydra.core.config_store import ConfigStore 12 | from hydra.core.override_parser.overrides_parser import OverridesParser 13 | from hydra.core.plugins import Plugins 14 | from hydra.plugins.launcher import Launcher 15 | from hydra.plugins.sweeper import Sweeper 16 | from hydra.types import TaskFunction 17 | from omegaconf import DictConfig, OmegaConf, ListConfig 18 | 19 | # IMPORTANT: 20 | # If your plugin imports any module that takes more than a fraction of a second to import, 21 | # Import the module lazily (typically inside sweep()). 22 | # Installed plugins are imported during Hydra initialization and plugins that are slow to import plugins will slow 23 | # the startup of ALL hydra applications. 24 | # Another approach is to place heavy includes in a file prefixed by _, such as _core.py: 25 | # Hydra will not look for plugin in such files and will not import them during plugin discovery. 
log = logging.getLogger(__name__)


@dataclass
class LauncherConfig:
    """Structured config registered in the ConfigStore as ``hydra/sweeper=list``."""

    # Import path of the sweeper class that Hydra instantiates.
    _target_: str = (
        "hydra_plugins.list_sweeper_plugin.list_sweeper.ListSweeper"
    )
    # Parameters swept in lock-step (zipped), one job per list position.
    list_params: DictConfig = DictConfig({})
    # Parameters swept as a cartesian product.
    grid_params: DictConfig = DictConfig({})
    # List of dictionaries; each one spawns an ablated copy of every job.
    ablative_params: ListConfig = ListConfig([])


ConfigStore.instance().store(group="hydra/sweeper", name="list", node=LauncherConfig)


def flatten_tuple(original_tuple: Iterable[Any]) -> tuple:
    """Flatten one level of nesting in a job tuple.

    String items are kept as they are, list items are spliced in place, and
    items of any other type are dropped.

    Args:
        original_tuple: Items of one job as produced by ``itertools.product``;
            each item is either an override string or a list of override
            strings.

    Returns:
        A flat tuple of override strings.
    """
    flat: List[Any] = []
    for item in original_tuple:
        if isinstance(item, str):
            flat.append(item)
        elif isinstance(item, list):
            flat.extend(item)
    return tuple(flat)


class ListSweeper(Sweeper):
    """Sweeper combining a grid sweep with zipped list params and ablations."""

    def __init__(self, list_params: DictConfig, grid_params: DictConfig, ablative_params: ListConfig):
        """Store the sweep definitions; launcher and config are set in ``setup``."""
        self.config: Optional[DictConfig] = None
        self.launcher: Optional[Launcher] = None
        self.hydra_context: Optional[HydraContext] = None
        self.job_results = None
        self.list_params = list_params
        self.grid_params = grid_params
        self.ablative_params = ablative_params

    def setup(
        self,
        *,
        hydra_context: HydraContext,
        task_function: TaskFunction,
        config: DictConfig,
    ) -> None:
        """Remember the job config and instantiate the configured launcher."""
        self.config = config
        self.launcher = Plugins.instance().instantiate_launcher(
            hydra_context=hydra_context, task_function=task_function, config=config
        )
        self.hydra_context = hydra_context

    def sweep(self, arguments: List[str]) -> Any:
        """Build all jobs and hand them to the launcher as a single batch.

        Args:
            arguments: Command line overrides passed to the application.

        Returns:
            A one-element list holding the launcher's result for the batch.

        Raises:
            ValueError: If the list params do not all have the same length or
                a param value cannot be parsed.
        """
        assert self.config is not None
        assert self.launcher is not None
        print(f"Sweep output dir : {self.config.hydra.sweep.dir}")

        # Save sweep run config in top level sweep working directory
        sweep_dir = Path(self.config.hydra.sweep.dir)
        sweep_dir.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(self.config, sweep_dir / "multirun.yaml")

        grid_lists, grid_keys = self._collect_grid(arguments)
        list_rows = self._collect_list_rows(grid_keys)

        if list_rows:
            # Each list row is one extra product axis entry; its overrides are
            # flattened so that they become part of the job tuple.
            batch = [
                flatten_tuple(job)
                for job in itertools.product(*grid_lists, list_rows)
            ]
        else:
            batch = list(itertools.product(*grid_lists))

        batch = self._apply_ablations(batch)

        initial_job_idx = 0
        returns = [self.launcher.launch(batch, initial_job_idx)]
        return returns

    def _collect_grid(self, arguments: List[str]):
        """Return ``(grid_lists, grid_keys)`` from CLI overrides and grid_params."""
        parser = OverridesParser.create()
        parsed = parser.parse_overrides(arguments)

        grid_lists: List[List[str]] = []
        grid_keys: List[str] = []
        for override in parsed:
            key = override.get_key_element()
            if override.is_sweep_override():
                # Sweepers must manipulate only overrides that return true to is_sweep_override()
                # This syntax is shared across all sweepers, so it may be limiting.
                # Sweeper must respect this though: failing to do so will cause all sorts of
                # hard to debug issues.
                choices = override.sweep_string_iterator()
                grid_lists.append([f"{key}={val}" for val in choices])
            else:
                value = override.get_value_element_as_str()
                grid_lists.append([f"{key}={value}"])
            grid_keys.append(key)

        # A null ``grid_params`` entry in the config is treated as empty.
        grid_params = self.grid_params if self.grid_params is not None else {}
        for key in grid_params:
            values = self.parse(key, grid_params[key])
            grid_lists.append([f"{key}={value}" for value in values])
            grid_keys.append(key)
        return grid_lists, grid_keys

    def _collect_list_rows(self, grid_keys: Sequence[str]) -> List[List[str]]:
        """Zip the list params into rows; row ``i`` holds the i-th value of every key."""
        rows: List[List[str]] = []
        expected_length = None
        # A null ``list_params`` entry in the config is treated as empty.
        list_params = self.list_params if self.list_params is not None else {}
        for key in list_params:
            if key in grid_keys:
                log.warning(f"List key {key} is also a grid key. The list key will be ignored.")
                continue
            values = self.parse(key, list_params[key])
            # check if all lists have the same length
            if expected_length is None:
                expected_length = len(values)
            elif len(values) != expected_length:
                raise ValueError(f"List key {key} has different length than other list keys")
            for idx, value in enumerate(values):
                if len(rows) <= idx:
                    rows.append([])
                rows[idx].append(f"{key}={value}")
        return rows

    def _apply_ablations(self, batch: List[Any]) -> List[Any]:
        """Append one ablated copy of every job per dictionary in ablative_params."""
        ablations = self.ablative_params
        if ablations is None or len(ablations) == 0:
            return batch

        # Starts with the original jobs and grows by len(batch) per ablation.
        complete_batch = list(batch)
        for ablative_dict in ablations:
            for base_job in batch:
                job = list(base_job)
                for key, value in ablative_dict.items():
                    override = f"{key}={value}"
                    found_key = False
                    for idx, existing in enumerate(job):
                        # Only the text before the first '=' is the key.
                        if existing.split("=", 1)[0] == key:
                            job[idx] = override
                            found_key = True
                    if not found_key:
                        job.append(override)
                complete_batch.append(tuple(job))
        return complete_batch

    def parse(self, key, values):
        """Normalise one configured param value into a list of values.

        Args:
            key: Name of the param, used only in the error message.
            values: A scalar, a comma separated string, or a list-like value.

        Returns:
            The list of values to sweep over for ``key``.

        Raises:
            ValueError: If ``values`` has an unsupported type.
        """
        # bool is a subclass of int, so it is covered here as well.
        if isinstance(values, (int, float)):
            return [values]
        if isinstance(values, str):
            if "," not in values:
                # only single string value
                return [values]
            # "a, b" and "[a, b]" both become ["a", "b"]
            for char in (" ", "[", "]"):
                values = values.replace(char, "")
            return values.split(",")
        if isinstance(values, (ListConfig, list, tuple)):
            # Iterate through the public API instead of the private _content.
            return list(values)
        raise ValueError(f"Cannot parse '{values}' for list key {key}")