├── .github ├── dependabot.yml └── workflows │ ├── publish.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── PyExpUtils ├── FileSystemContext.py ├── __init__.py ├── collection │ ├── Collector.py │ ├── Sampler.py │ ├── __init__.py │ └── utils.py ├── models │ ├── Config.py │ ├── ExperimentDescription.py │ └── __init__.py ├── parallel_runner.py ├── py.typed ├── results │ ├── Collection.py │ ├── LazyCollection.py │ ├── __init__.py │ ├── _utils │ │ └── shared.py │ ├── indices.py │ ├── migrations.py │ ├── pandas.py │ ├── sqlite.py │ ├── sqlite_utils.py │ ├── tools.py │ └── voting.py ├── runner │ ├── Slurm.py │ ├── __init__.py │ ├── parallel.py │ ├── parallel_exec.py │ └── utils.py └── utils │ ├── NestedDict.py │ ├── __init__.py │ ├── arrays.py │ ├── asyncio.py │ ├── cache.py │ ├── cmdline.py │ ├── csv.py │ ├── dict.py │ ├── fp.py │ ├── generator.py │ ├── iterable.py │ ├── jit.py │ ├── pandas.py │ ├── path.py │ ├── permute.py │ ├── random.py │ ├── str.py │ └── types.py ├── README.md ├── config.json ├── dev-setup.sh ├── docs └── OrganizationPatterns.md ├── mock_repo └── experiments │ └── overfit │ ├── best │ ├── ann.json │ └── sdl.json │ └── sweeps │ ├── ann.json │ └── sdl.json ├── pyproject.toml ├── requirements.txt ├── scripts ├── generate_docs.py ├── publish.sh └── run_tests.sh ├── tests ├── __init__.py ├── _utils │ └── pandas.py ├── models │ ├── __init__.py │ └── test_ExperimentDescription.py ├── results │ ├── __init__.py │ ├── test_indices.py │ ├── test_tools.py │ └── test_voting.py ├── runner │ ├── __init__.py │ ├── test_parallel.py │ └── test_slurm.py ├── test_FileSystemContext.py └── utils │ ├── __init__.py │ ├── test_Collector.py │ ├── test_arrays.py │ ├── test_cmdline.py │ ├── test_csv.py │ ├── test_dict.py │ ├── test_generator.py │ ├── test_path.py │ ├── test_permute.py │ ├── test_random.py │ └── test_str.py └── typings ├── h5py └── __init__.pyi └── numba ├── __init__.pyi ├── experimental └── __init__.pyi └── typed └── __init__.pyi 
/.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | workflow_run: 5 | workflows: ['Test'] 6 | branches: [master] 7 | types: 8 | - completed 9 | 10 | jobs: 11 | build: 12 | if: ${{ github.event.workflow_run.conclusion == 'success' }} 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - run: git fetch --prune --unshallow 20 | 21 | - name: Set up Python 3.11 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: "3.11" 25 | cache: 'pip' 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m venv .venv 30 | source .venv/bin/activate 31 | python -m pip install --upgrade pip 32 | pip install -r requirements.txt 33 | echo PATH=$PATH >> $GITHUB_ENV 34 | 35 | - name: Publish 36 | env: 37 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 38 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 39 | run: ./scripts/publish.sh 40 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [ '*' ] 6 | pull_request: 7 | branches: [ 
master ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10", "3.11"] 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - uses: chartboost/ruff-action@v1 20 | 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | cache: 'pip' 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m venv .venv 30 | source .venv/bin/activate 31 | python -m pip install --upgrade pip 32 | pip install -r requirements.txt 33 | echo PATH=$PATH >> $GITHUB_ENV 34 | 35 | - name: Test 36 | env: 37 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 38 | run: ./scripts/run_tests.sh 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode/ 3 | *.pyc 4 | .mypy_cache 5 | .token 6 | 7 | .venv/ 8 | env/ 9 | dist/ 10 | PyExpUtils_andnp* 11 | .pdm-python 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/commitizen-tools/commitizen 3 | rev: v2.21.2 4 | hooks: 5 | - id: commitizen 6 | stages: 7 | - commit-msg 8 | - repo: https://github.com/astral-sh/ruff-pre-commit 9 | # Ruff version. 10 | rev: v0.2.2 11 | hooks: 12 | # Run the linter. 
13 | - id: ruff 14 | -------------------------------------------------------------------------------- /PyExpUtils/FileSystemContext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import PyExpUtils.utils.path as Path 4 | 5 | class FileSystemContext: 6 | def __init__(self, path: str, base: str = ''): 7 | self._path = path 8 | self._base = base 9 | 10 | def getBase(self): 11 | return self._base 12 | 13 | def resolve(self, path: str = ''): 14 | base = Path.join(self._base, self._path) 15 | 16 | path = path.replace(base + '/', '') 17 | 18 | while path.startswith('../'): 19 | path = path[3:] 20 | base = Path.up(base) 21 | 22 | if path == '': 23 | return base 24 | 25 | return Path.join(base, path) 26 | 27 | def exists(self, path: str = ''): 28 | path = self.resolve(path) 29 | return os.path.exists(path) 30 | 31 | def ensureExists(self, path: str = '', is_file: bool = False): 32 | path = self.resolve(path) 33 | di = path 34 | if is_file: 35 | di = os.path.dirname(path) 36 | 37 | os.makedirs(di, exist_ok=True) 38 | return path 39 | 40 | def remove(self, path: str = ''): 41 | files = self.resolve(path) 42 | shutil.rmtree(files) 43 | -------------------------------------------------------------------------------- /PyExpUtils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andnp/PyExpUtils/5d076ff1196368a936b18998afd00c80d4699857/PyExpUtils/__init__.py -------------------------------------------------------------------------------- /PyExpUtils/collection/Collector.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List 2 | from PyExpUtils.collection.Sampler import Sampler, Ignore, Identity 3 | 4 | """doc 5 | A frame-based data collection utility. 
6 | The collector stores some context---which index is currently being run, what is the current timestep, etc.--- 7 | and associates collected data with this context. 8 | 9 | Example usage: 10 | ```python 11 | collector = Collector( 12 | config={ 13 | # a dictionary mapping keys -> data preprocessors 14 | # for instance performing fixed-window averaging 15 | 'return': Window(100), 16 | # or subsampling 1 of every 100 values 17 | 'reward': Subsample(100), 18 | # or moving averages 19 | 'error': MovingAverage(0.99), 20 | # or ignored entirely 21 | 'special': Ignore(), 22 | }, 23 | # by default, if a key is not mentioned above it is stored as-is 24 | # however this can be changed by passing a default preprocessor 25 | default=Identity() 26 | ) 27 | 28 | # tell the collector what idx of the experiment we are currently processing 29 | collector.setIdx(0) 30 | 31 | for step in range(exp.max_steps): 32 | # tell the collector to increment the frame 33 | collector.next_frame() 34 | 35 | # these values will be associated with the current idx and frame 36 | collector.collect('reward', r) 37 | collector.collect('error', delta) 38 | 39 | # not all values need to be stored at each frame 40 | if step % 100 == 0: 41 | collector.collect('special', 'test value') 42 | ``` 43 | """ 44 | class Collector: 45 | def __init__(self, config: Dict[str, Sampler | Ignore] = {}, idx: int | None = None, default: Identity | Ignore = Identity()): 46 | self._d: List[Dict[str, Any]] = [] 47 | self._c = config 48 | 49 | self._ignore = set(k for k, sampler in config.items() if isinstance(sampler, Ignore)) 50 | self._sampler: Dict[str, Sampler] = { 51 | k: sampler for k, sampler in config.items() if not isinstance(sampler, Ignore) 52 | } 53 | 54 | self._idx: int | None = idx 55 | self._frame: int = -1 56 | self._cur: Dict[str, Any] = {} 57 | self._con: Dict[str, Any] = {} 58 | 59 | # create this once and cache it since it is stateless 60 | # avoid recreating on every step 61 | self._def = default 62 | 
63 | # cache some useful metadata 64 | self._idxs = set[int]() 65 | self._keys = set[str]() 66 | 67 | # ------------- 68 | # -- Context -- 69 | # ------------- 70 | def setContext(self, context: Dict[str, Any]): 71 | self._con |= context 72 | self._keys |= set(context.keys()) 73 | 74 | def addContext(self, key: str, val: Any): 75 | self._con[key] = val 76 | self._keys.add(key) 77 | 78 | def setIdx(self, idx: int): 79 | if self._idx is not None: 80 | self.reset() 81 | 82 | self._idxs.add(idx) 83 | self._idx = idx 84 | self._frame = -1 85 | self._cur = {} 86 | 87 | def getIdx(self): 88 | assert self._idx is not None 89 | return self._idx 90 | 91 | def next_frame(self): 92 | self._frame += 1 93 | 94 | if self._cur: 95 | self._cur['idx'] = self.getIdx() 96 | self._cur['frame'] = self._frame - 1 97 | self._d.append(self._cur | self._con) 98 | self._cur = {} 99 | 100 | def reset(self): 101 | self.next_frame() 102 | for k in self._sampler: 103 | v = self._sampler[k].end() 104 | if v is None: continue 105 | 106 | self._cur[k] = v 107 | 108 | self.next_frame() 109 | self._frame = -1 110 | self._cur = {} 111 | 112 | # ------------- 113 | # -- Storing -- 114 | # ------------- 115 | def collect(self, name: str, value: Any): 116 | if name in self._ignore: 117 | return 118 | 119 | v = self._sampler.get(name, self._def).next(value) 120 | if v is None: 121 | return 122 | 123 | self._keys.add(name) 124 | self._cur[name] = v 125 | 126 | def evaluate(self, name: str, lmbda: Callable[[], Any]): 127 | if name in self._ignore: 128 | return 129 | 130 | v = self._sampler.get(name, self._def).next_eval(lmbda) 131 | if v is None: 132 | return 133 | 134 | self._keys.add(name) 135 | self._cur[name] = v 136 | 137 | # --------------- 138 | # -- Accessing -- 139 | # --------------- 140 | def get(self, name: str, idx: int): 141 | out = [ 142 | d[name] for d in self._d if d['idx'] == idx 143 | ] 144 | return out 145 | 146 | def get_frames(self, idx: int): 147 | return [ d for d in self._d if 
d['idx'] == idx ] 148 | 149 | def get_last(self, name: str): 150 | arr = self.get(name, self.getIdx()) 151 | return arr[-1] 152 | 153 | def keys(self): 154 | return self._keys 155 | 156 | def indices(self): 157 | return self._idxs 158 | -------------------------------------------------------------------------------- /PyExpUtils/collection/Sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from abc import abstractmethod 4 | from typing import Callable, Generator 5 | 6 | class Sampler: 7 | def next(self, v: float) -> float | None: ... 8 | def next_eval(self, v: Callable[[], float]) -> float | None: ... 9 | 10 | @abstractmethod 11 | def repeat(self, v: float, times: int) -> Generator[float, None, None]: ... 12 | def end(self) -> float | None: ... 13 | 14 | class Ignore: 15 | def __init__(self): ... 16 | def next(self, v): return None 17 | def next_eval(self, v): return None 18 | def repeat(self, v, times): yield None 19 | def end(self): return None 20 | 21 | class Identity(Sampler): 22 | def next(self, v: float): 23 | return v 24 | 25 | def next_eval(self, c: Callable[[], float]): 26 | return c() 27 | 28 | def repeat(self, v: float, times: int): 29 | for _ in range(times): yield v 30 | 31 | def end(self): 32 | return None 33 | 34 | 35 | # -------------- 36 | # -- Samplers -- 37 | # -------------- 38 | class Window(Sampler): 39 | def __init__(self, size: int): 40 | self._b = np.empty(size, dtype=np.float64) 41 | self._clock = 0 42 | self._size = size 43 | 44 | def next(self, v: float): 45 | self._b[self._clock] = v 46 | self._clock += 1 47 | 48 | if self._clock == self._size: 49 | m = self._b.mean() 50 | self._clock = 0 51 | return m 52 | 53 | def next_eval(self, c: Callable[[], float]): 54 | return self.next(c()) 55 | 56 | def repeat(self, v: float, times: int): 57 | while times > 0: 58 | r = self._size - self._clock 59 | r = min(times, r) 60 | 61 | # I can save a good chunk of compute if I know 
the whole window 62 | # is filled with v. Then the mean is clearly also v. 63 | if self._clock == 0 and r == self._size: 64 | times -= r 65 | yield v 66 | continue 67 | 68 | e = self._clock + r 69 | self._b[self._clock:e] = v 70 | self._clock = (self._clock + r) % self._size 71 | 72 | times -= r 73 | 74 | if self._clock == 0: 75 | yield self._b.mean() 76 | 77 | def end(self): 78 | out = None 79 | if self._clock > 0: 80 | out = self._b[:self._clock].mean() 81 | 82 | self._clock = 0 83 | return out 84 | 85 | class Subsample(Sampler): 86 | def __init__(self, freq: int, trailing_edge: bool = False, first: bool = True): 87 | self._clock = 0 88 | self._freq = freq 89 | 90 | self._first = first 91 | self._target = 0 92 | if trailing_edge: 93 | self._target = freq - 1 94 | 95 | def next(self, v: float): 96 | tick = self._clock % self._freq == self._target or (self._first and self._clock == 0) 97 | self._clock += 1 98 | 99 | if tick: 100 | return v 101 | 102 | def next_eval(self, c: Callable[[], float]): 103 | tick = self._clock % self._freq == self._target or (self._first and self._clock == 0) 104 | self._clock += 1 105 | 106 | if tick: 107 | return c() 108 | 109 | def repeat(self, v: float, times: int): 110 | if self._clock % self._freq == self._target or (self._first and self._clock == 0): 111 | yield v 112 | 113 | r = self._clock + times 114 | reps = int(r // self._freq) 115 | for _ in range(reps): 116 | yield v 117 | 118 | self._clock += times 119 | 120 | def end(self): 121 | self._clock = 0 122 | return None 123 | 124 | class MovingAverage(Sampler): 125 | def __init__(self, decay: float): 126 | self._decay = decay 127 | self.z = 0. 128 | 129 | def next(self, v: float): 130 | self.z = self._decay * self.z + (1. 
- self._decay) * v 131 | return self.z 132 | 133 | def next_eval(self, c: Callable[[], float]): 134 | v = c() 135 | return self.next(v) 136 | 137 | def repeat(self, v: float, times: int): 138 | for _ in range(times): 139 | self.next(v) 140 | 141 | def end(self): 142 | return None 143 | -------------------------------------------------------------------------------- /PyExpUtils/collection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andnp/PyExpUtils/5d076ff1196368a936b18998afd00c80d4699857/PyExpUtils/collection/__init__.py -------------------------------------------------------------------------------- /PyExpUtils/collection/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from PyExpUtils.collection.Sampler import Sampler 3 | 4 | class Pipe(Sampler): 5 | def __init__(self, *args: Sampler) -> None: 6 | self._subs = args 7 | 8 | def next(self, v: float) -> float | None: 9 | out: float | None = v 10 | for sub in self._subs: 11 | if out is None: return None 12 | out = sub.next(out) 13 | 14 | return out 15 | 16 | def next_eval(self, v: Callable[[], float]) -> float | None: 17 | subs = iter(self._subs) 18 | first = next(subs) 19 | out = first.next_eval(v) 20 | 21 | for sub in subs: 22 | if out is None: return None 23 | out = sub.next(out) 24 | 25 | return out 26 | 27 | def repeat(self, v: float, times: int): 28 | for _ in range(times): 29 | self.next(v) 30 | 31 | def end(self): 32 | return None 33 | -------------------------------------------------------------------------------- /PyExpUtils/models/Config.py: -------------------------------------------------------------------------------- 1 | import json 2 | from dataclasses import dataclass 3 | from PyExpUtils.utils.fp import once 4 | 5 | """doc 6 | Experiment utility configuration file. 
7 | Specifies global configuration settings: 8 | - *save_path*: directory format where experimental results will be stored 9 | - *log_path*: directory where log files will be saved (e.g. stacktraces during experiments) 10 | - *experiment_directory*: root directory where all of the experiment description files are located 11 | 12 | The config file should be at the root level of the repository and should be named `config.json`. 13 | ``` 14 | .git 15 | .gitignore 16 | tests/ 17 | scripts/ 18 | src/ 19 | config.json 20 | ``` 21 | 22 | An example configuration file: 23 | ```json 24 | { 25 | "save_path": "results/{name}/{environment}/{agent}/{params}", 26 | "log_path": "~/scratch/.logs", 27 | "experiment_directory": "experiments" 28 | } 29 | ``` 30 | """ 31 | @dataclass 32 | class Config: 33 | save_path: str 34 | log_path: str = '.logs' 35 | experiment_directory: str = 'experiments' 36 | 37 | """doc 38 | Memoized global configuration loader. 39 | Will read `config.json` (only once) and return a Config object. 
40 | ```python 41 | config = getConfig() 42 | print(config.save_path) # -> 'results' 43 | ``` 44 | """ 45 | @once 46 | def getConfig(): 47 | with open('config.json', 'r') as f: 48 | d = json.load(f) 49 | 50 | return Config(**d) 51 | -------------------------------------------------------------------------------- /PyExpUtils/models/ExperimentDescription.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import copy 4 | import PyExpUtils.utils.path as Path 5 | from PyExpUtils.utils.arrays import unwrap 6 | from PyExpUtils.utils.permute import KVPair, Record, getCountFromPairs, _flattenToKeyValues, getPermutationFromPairs 7 | from PyExpUtils.utils.dict import merge, hyphenatedStringify, pick 8 | from PyExpUtils.utils.str import interpolate 9 | from PyExpUtils.models.Config import getConfig 10 | from PyExpUtils.FileSystemContext import FileSystemContext 11 | 12 | # type checking 13 | from typing import Optional, Union, List, Dict, Any, Type, TypeVar 14 | Keys = Union[str, List[str]] 15 | 16 | """doc 17 | Main workhorse class of the library. 18 | Takes a dictionary desribing all configurable options of an experiment and serializes that dictionary. 19 | Provides a set of utility methods to run parameter sweeps in parallel and for storing data during experiments. 
20 | ```python 21 | exp_dict = { 22 | 'algorithm': 'SARSA', 23 | 'environment': 'MountainCar', 24 | 'metaParameters': { 25 | 'alpha': [1.0, 0.5, 0.25, 0.125], 26 | 'lambda': [1.0, 0.99, 0.98, 0.96] 27 | } 28 | } 29 | exp = ExperimentDescription(d) 30 | ``` 31 | """ 32 | class ExperimentDescription: 33 | def __init__(self, d: Dict[str, Any], path: Optional[str] = None, keys: Keys = 'metaParameters', save_key: Optional[str] = None) -> None: 34 | # the raw serialized json 35 | self._d = d 36 | # a collection of keys to permute over 37 | self.keys = keys 38 | # path to the experiment description file 39 | self.path = path 40 | # interpolation key for saving 41 | self.save_key = save_key 42 | 43 | # cached data 44 | self._num_perms: Optional[int] = None 45 | self._pairs: Optional[List[KVPair]] = None 46 | 47 | # get the keys to permute over 48 | def getKeys(self, keys: Optional[Keys] = None): 49 | keys = keys if keys is not None else self.keys 50 | return keys if isinstance(keys, list) else [keys] 51 | 52 | def _getSaveKey(self, save_key: Optional[str] = None): 53 | if save_key is not None: 54 | return save_key 55 | 56 | if self.save_key is not None: 57 | return self.save_key 58 | 59 | config = getConfig() 60 | return config.save_path 61 | 62 | """doc 63 | Gives a list of parameters that can be swept over. 64 | 65 | Using above example dictionary: 66 | ```python 67 | params = exp.permutable() 68 | print(params) # -> { 'alpha': [1.0, 0.5, 0.25, 0.125], 'lambda': [1.0, 0.99, 0.98, 0.96] } 69 | ``` 70 | """ 71 | def permutable(self) -> Dict[str, Any]: 72 | keys = self.getKeys() 73 | 74 | sweeps: Record = {} 75 | for key in keys: 76 | sweeps[key] = self._d[key] 77 | 78 | return sweeps 79 | 80 | """doc 81 | Gives the `i`'th permutation of sweepable parameters. 82 | Handles wrapping indices, so can perform multiple runs of the same parameter setting by setting `i` large. 83 | 84 | In the above dictionary, there are 16 total parameter permutations. 
85 | ```python 86 | params = exp.getPermutation(0) 87 | print(params) # -> { 'alpha': 1.0, 'lambda': 1.0 } 88 | 89 | params = exp.getPermutation(1) 90 | print(params) # -> { 'alpha': 1.0, 'lambda': 0.99 } 91 | 92 | params = exp.getPermutation(15) 93 | print(params) # -> { 'alpha': 0.125, 'lambda': 0.96 } 94 | 95 | params = exp.getPermutation(16) 96 | print(params) # -> { 'alpha': 1.0, 'lambda': 1.0 } 97 | ``` 98 | """ 99 | def getPermutation(self, idx: int) -> Record: 100 | if self._pairs is None: 101 | sweeps = self.permutable() 102 | self._pairs = _flattenToKeyValues(sweeps) 103 | 104 | permutation = getPermutationFromPairs(self._pairs, idx) 105 | d = merge(self._d, permutation) 106 | 107 | # since we are caching, we need to guarantee modifications to the returned dict 108 | # are not propagated to the cached dict 109 | return copy.deepcopy(d) 110 | 111 | def get_hypers(self, idx: int): 112 | keys = self.getKeys() 113 | assert len(keys) == 1 114 | 115 | params = self.getPermutation(idx) 116 | return params[keys[0]] 117 | 118 | """doc 119 | Gives the total number of parameter permutations. 120 | 121 | ```python 122 | num_params = exp.numPermutations() 123 | print(num_params) # -> 16 124 | ``` 125 | """ 126 | def numPermutations(self): 127 | if self._num_perms is not None: 128 | return self._num_perms 129 | 130 | if self._pairs is None: 131 | sweeps = self.permutable() 132 | self._pairs = _flattenToKeyValues(sweeps) 133 | 134 | self._num_perms = getCountFromPairs(self._pairs) 135 | return self._num_perms 136 | 137 | """doc 138 | Get the run number based on wrapping the index. 139 | This is a count of how many times we've wrapped back around to the same parameter setting. 
140 | 141 | ```python 142 | num = exp.getRun(0) 143 | print(num) # -> 0 144 | 145 | num = exp.getRun(12) 146 | print(num) # -> 0 147 | 148 | num = exp.getRun(16) 149 | print(num) # -> 1 150 | 151 | num = exp.getRun(32) 152 | print(num) # -> 2 153 | ``` 154 | """ 155 | def getRun(self, idx: int): 156 | count = self.numPermutations() 157 | return idx // count 158 | 159 | """doc 160 | Returns the name of the experiment if stated in the dictionary: `{ 'name': 'MountainCar-v0', ... }`. 161 | If not stated, will try to determine the name of the experiment based on the path to the JSON it is stored in (assuming experiments are stored in JSON files). 162 | 163 | ```python 164 | path = 'experiments/MountainCar-v0/sarsa.json' 165 | with open(path, 'r') as f: 166 | d = json.load(path) 167 | 168 | exp = ExperimentDescription(d, path) 169 | 170 | name = exp.getExperimentName() 171 | print(name) # -> d['name'] if available, or 'MountainCar-v0' if not. 172 | ``` 173 | """ 174 | def getExperimentName(self): 175 | cwd = os.getcwd() 176 | exp_dir = getConfig().experiment_directory 177 | 178 | if exp_dir is None: 179 | exp_dir = '' 180 | 181 | if self.path is None: 182 | return str(self._d.get('name', 'unnamed')) 183 | 184 | path = self.path \ 185 | .replace(cwd + '/', '') \ 186 | .replace(exp_dir + '/', '') \ 187 | .replace('./', '') 188 | 189 | return Path.up(path) 190 | 191 | """doc 192 | Takes a parameter index and generates a path for saving results. 193 | The path depends on the configuration settings of the library (i.e. `config.json`). 194 | 195 | Note this uses an opinionated formatting for save paths and parameter string representations. 196 | The configuration file can specify ordering and high-level control over paths, but for more fine-tuned control over how these are saved, inherit from this class and overload this method. 
197 | 198 | `config.json`: 199 | ```json 200 | { 201 | "save_path": "results/{name}/{environment}/{agent}/{params}" 202 | } 203 | ``` 204 | 205 | ```python 206 | path = exp.interpolateSavePath(0) 207 | print(path) # -> 'results/MountainCar-v0/SARSA/alpha-1.0_lambda-1.0' 208 | ``` 209 | """ 210 | def interpolateSavePath(self, idx: int, key: Optional[str] = None): 211 | key = self._getSaveKey(key) 212 | 213 | permute = unwrap(self.getKeys()) 214 | params = pick(self.getPermutation(idx), permute) 215 | param_string = hyphenatedStringify(params) 216 | 217 | run = self.getRun(idx) 218 | 219 | special_keys = { 220 | 'params': param_string, 221 | 'run': str(run), 222 | 'name': self.getExperimentName() 223 | } 224 | d = merge(self.__dict__, special_keys) 225 | 226 | return interpolate(str(key), d) 227 | 228 | """doc 229 | Builds a `FileSystemContext` utility object that contains the save path for experimental results. 230 | 231 | ```python 232 | file_context = exp.buildSaveContext(0) 233 | 234 | # make sure folder structure is built 235 | file_context.ensureExists() 236 | 237 | # get the path where results should be saved 238 | path = file_context.resolve('returns.npy') 239 | print(path) # -> '/results/MountainCar-v0/SARSA/alpha-1.0_lambda-1.0/returns.npy' 240 | 241 | # save results 242 | np.save(path, returns) 243 | ``` 244 | """ 245 | def buildSaveContext(self, idx: int, base: str = '', key: Optional[str] = None): 246 | path = self.interpolateSavePath(idx, key) 247 | return FileSystemContext(path, base) 248 | 249 | 250 | Exp = TypeVar('Exp', bound=ExperimentDescription) 251 | 252 | """doc 253 | Loads an ExperimentDescription from a JSON file (preferred way to make ExperimentDescriptions). 
254 | 255 | ```python 256 | exp = loadExperiment('experiments/MountainCar-v0/sarsa.json') 257 | ``` 258 | """ 259 | def loadExperiment(path: str, Model: Type[Exp] = ExperimentDescription): 260 | with open(path, 'r') as f: 261 | d = json.load(f) 262 | 263 | return Model(d, path=path) 264 | -------------------------------------------------------------------------------- /PyExpUtils/models/__init__.py: -------------------------------------------------------------------------------- 1 | """doc 2 | A collection of JSON serialization classes with associated utility methods. 3 | """ 4 | -------------------------------------------------------------------------------- /PyExpUtils/parallel_runner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | from PyExpUtils.runner.parallel_exec import ParallelConfig, execute 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--parallel', type=int, required=True) 9 | parser.add_argument('--exec', type=str, required=True) 10 | parser.add_argument('--seq', type=int, required=False, default=1) 11 | parser.add_argument('--tasks', nargs='+', type=int, required=True) 12 | 13 | args = parser.parse_args() 14 | 15 | config = ParallelConfig( 16 | executable=args.exec, 17 | parallel=args.parallel, 18 | sequential=args.seq, 19 | tasks=args.tasks, 20 | ) 21 | 22 | execute(config) 23 | -------------------------------------------------------------------------------- /PyExpUtils/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andnp/PyExpUtils/5d076ff1196368a936b18998afd00c80d4699857/PyExpUtils/py.typed -------------------------------------------------------------------------------- /PyExpUtils/results/Collection.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import os 3 | import 
glob 4 | import importlib 5 | import dataclasses 6 | import pandas as pd 7 | 8 | from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, Type, TypeVar 9 | from multiprocessing.dummy import Pool 10 | 11 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription, loadExperiment 12 | from PyExpUtils.results.sqlite import loadAllResults 13 | from PyExpUtils.results.tools import getHeader 14 | 15 | 16 | Exp = TypeVar('Exp', bound=ExperimentDescription) 17 | CExp = TypeVar('CExp', bound=ExperimentDescription) 18 | 19 | 20 | @dataclasses.dataclass 21 | class Result(Generic[Exp]): 22 | exp: Exp 23 | df: pd.DataFrame 24 | path: str 25 | 26 | class ResultCollection(Generic[Exp]): 27 | def __init__(self, Model: Optional[Type[Exp]] = None): 28 | self._data: Dict[str, Result[Exp]] = {} 29 | self._Model = Model 30 | 31 | def apply(self, f: Callable[[pd.DataFrame], pd.DataFrame | None]): 32 | for item in self._data.values(): 33 | out = f(item.df) 34 | 35 | if out is not None: 36 | item.df = out 37 | 38 | return self 39 | 40 | def map(self, f: Callable[[pd.DataFrame], pd.DataFrame]): 41 | out = ResultCollection(self._Model) 42 | 43 | for key, item in self._data.items(): 44 | out._data[key] = Result( 45 | exp=item.exp, 46 | df=f(item.df), 47 | path=item.path, 48 | ) 49 | 50 | return out 51 | 52 | def combine(self, folder_columns: Sequence[str | None], file_col: str | None): 53 | out: pd.DataFrame | None = None 54 | for path in self._data.keys(): 55 | parts = path.split('/') 56 | assert len(parts) == len(folder_columns) + 1 57 | 58 | df = self._data[path].df 59 | 60 | for fcol, part in zip(folder_columns, parts): 61 | if fcol is None: continue 62 | 63 | df[fcol] = part 64 | 65 | if file_col is not None: 66 | df[file_col] = parts[-1].replace('.json', '') 67 | 68 | if out is None: 69 | out = df 70 | else: 71 | out = pd.concat((out, df), axis=0, ignore_index=True) 72 | 73 | if out is not None: 74 | out.reset_index(drop=True, inplace=True) 75 | 76 | 
return out 77 | 78 | def get_hyperparameter_columns(self): 79 | hypers = set[str]() 80 | 81 | for res in self._data.values(): 82 | sub = getHeader(res.exp) 83 | hypers |= set(sub) 84 | 85 | return list(sorted(hypers)) 86 | 87 | def get_any_exp(self): 88 | k = next(iter(self._data)) 89 | return self._data[k].exp 90 | 91 | def __iter__(self): 92 | return iter(self._data.values()) 93 | 94 | def __getitem__(self, key: str | Tuple[str, ...]) -> Result: 95 | if isinstance(key, str): 96 | return self._data[key] 97 | 98 | matches = [] 99 | for k in self._data: 100 | parts = k.split('/') 101 | parts[-1] = parts[-1].replace('.json', '') 102 | if all(any(query == part for part in parts) for query in key): 103 | matches.append(self._data[k]) 104 | 105 | if len(matches) == 0: 106 | raise KeyError('Could not find an experiment matching query') 107 | 108 | if len(matches) > 1: 109 | raise KeyError(f'Found too many experiments for query: {len(matches)}') 110 | 111 | return matches[0] 112 | 113 | @classmethod 114 | def fromExperiments(cls, metrics: Sequence[str] | None = None, path: Optional[str] = None, Model: Type[CExp] = ExperimentDescription) -> ResultCollection[CExp]: 115 | pool = Pool() 116 | paths = findExperiments(path) 117 | out: Any = cls(Model=Model) 118 | 119 | def load_path(p: str): 120 | exp = loadExperiment(p, Model) 121 | df = loadAllResults(exp, metrics=metrics) 122 | 123 | if df is not None: 124 | out._data[p] = Result( 125 | exp=exp, 126 | df=df, 127 | path=p, 128 | ) 129 | 130 | pool.map(load_path, paths) 131 | 132 | return out 133 | 134 | 135 | def findExperiments(path: Optional[str] = None): 136 | if path is None: 137 | main_file = importlib.import_module('__main__').__file__ 138 | assert main_file is not None 139 | path = os.path.dirname(main_file) 140 | 141 | paths = glob.glob(f'{path}/**/*.json', recursive=True) 142 | 143 | project = os.getcwd() 144 | return [ p.replace(f'{project}/', '') for p in paths ] 145 | 
-------------------------------------------------------------------------------- /PyExpUtils/results/LazyCollection.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import os 3 | import glob 4 | import importlib 5 | import pandas as pd 6 | 7 | from typing import Any, Generic, Sequence, Tuple, Type, TypeVar 8 | 9 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription, loadExperiment 10 | from PyExpUtils.results.sqlite import loadAllResults, loadHypersOnly, loadResultsOnly 11 | from PyExpUtils.results.tools import getHeader 12 | 13 | 14 | Exp = TypeVar('Exp', bound=ExperimentDescription) 15 | CExp = TypeVar('CExp', bound=ExperimentDescription) 16 | 17 | 18 | class LazyResult(Generic[Exp]): 19 | def __init__(self, exp: Exp, path: str, metrics: Sequence[str] | None): 20 | self.metrics = metrics 21 | self.exp = exp 22 | self.path = path 23 | 24 | def load(self) -> pd.DataFrame | None: 25 | return loadAllResults(self.exp, metrics=self.metrics) 26 | 27 | def load_metrics(self) -> pd.DataFrame | None: 28 | return loadResultsOnly(self.exp, metrics=self.metrics) 29 | 30 | def load_hypers(self) -> pd.DataFrame | None: 31 | return loadHypersOnly(self.exp) 32 | 33 | 34 | class GroupbyResult(LazyResult): 35 | def __init__(self, exp: Any, path: str, sub_path: str, metrics: Sequence[str] | None): 36 | super().__init__(exp, path, metrics) 37 | self.sub_path = sub_path 38 | 39 | class LazyResultCollection(Generic[Exp]): 40 | def __init__(self, path: str | None = None, metrics: Sequence[str] | None = None, Model: Type[Exp] | None = None): 41 | self._Model = Model or ExperimentDescription 42 | self._path = path 43 | self._metrics = metrics 44 | 45 | if self._path is None: 46 | main_file = importlib.import_module('__main__').__file__ 47 | assert main_file is not None 48 | self._path = os.path.dirname(main_file) 49 | 50 | project = os.getcwd() 51 | paths = 
glob.glob(f'{self._path}/**/*.json', recursive=True) 52 | paths = [ p.replace(f'{project}/', '') for p in paths ] 53 | self._paths = paths 54 | 55 | def result(self, path: str) -> LazyResult[Exp]: 56 | exp: Any = loadExperiment(path, self._Model) 57 | 58 | return LazyResult[Exp]( 59 | exp=exp, 60 | path=path, 61 | metrics=self._metrics, 62 | ) 63 | 64 | def groupby_directory(self, level: int): 65 | uniques = set([ 66 | p.split('/')[level] for p in self._paths 67 | ]) 68 | 69 | for group in uniques: 70 | group_paths = [p for p in self._paths if p.split('/')[level] == group] 71 | results = map(self.result, group_paths) 72 | 73 | yield group, [_sub_path(r, level) for r in results] 74 | 75 | def get_hyperparameter_columns(self): 76 | hypers = set[str]() 77 | 78 | for path in self._paths: 79 | exp = loadExperiment(path, Model=self._Model) 80 | sub = getHeader(exp) 81 | hypers |= set(sub) 82 | 83 | return list(sorted(hypers)) 84 | 85 | def get_hyperparameter_values(self, hyper: str): 86 | values = set() 87 | 88 | for path in self._paths: 89 | exp = loadExperiment(path, Model=self._Model) 90 | v = set(exp._d['metaParameters'].get(hyper, [])) 91 | 92 | values |= v 93 | 94 | return values 95 | 96 | def __iter__(self): 97 | return map(self.result, self._paths) 98 | 99 | def __getitem__(self, key: str | Tuple[str, ...]) -> LazyResult[Exp]: 100 | if isinstance(key, str): 101 | key = (key, ) 102 | 103 | matches = [] 104 | for k in self._paths: 105 | parts = k.split('/') 106 | parts[-1] = parts[-1].replace('.json', '') 107 | if all(any(query == part for part in parts) for query in key): 108 | matches.append(k) 109 | 110 | if len(matches) == 0: 111 | raise KeyError('Could not find an experiment matching query') 112 | 113 | if len(matches) > 1: 114 | raise KeyError(f'Found too many experiments for query: {len(matches)}') 115 | 116 | return self.result(matches[0]) 117 | 118 | 119 | def _sub_path(r: LazyResult, level: int): 120 | sub_parts = r.path.split('/')[level + 1:] 121 | sub = 
'/'.join(sub_parts) 122 | sub = sub.replace('.json', '') 123 | return GroupbyResult( 124 | exp=r.exp, 125 | path=r.path, 126 | sub_path=sub, 127 | metrics=r.metrics, 128 | ) 129 | -------------------------------------------------------------------------------- /PyExpUtils/results/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andnp/PyExpUtils/5d076ff1196368a936b18998afd00c80d4699857/PyExpUtils/results/__init__.py -------------------------------------------------------------------------------- /PyExpUtils/results/_utils/shared.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable 2 | 3 | def hash_values(vals: Iterable[Any]): 4 | return hash(','.join(map(str, vals))) 5 | -------------------------------------------------------------------------------- /PyExpUtils/results/indices.py: -------------------------------------------------------------------------------- 1 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription 2 | 3 | """doc 4 | Returns an iterator over indices for each parameter permutation. 5 | Can specify a number of runs and will cycle over the permutations `runs` number of times. 6 | 7 | ```python 8 | for i in listIndices(exp, runs=2): 9 | print(i, exp.getRun(i)) # -> "0 0", "1 0", "2 0", ... "0 1", "1 1", ... 
10 | ``` 11 | """ 12 | def listIndices(exp: ExperimentDescription, runs: int = 1): 13 | perms = exp.numPermutations() 14 | tasks = perms * runs 15 | return range(tasks) 16 | -------------------------------------------------------------------------------- /PyExpUtils/results/migrations.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import sqlite3 3 | import logging 4 | import pandas as pd 5 | import connectorx as cx 6 | import PyExpUtils.results.sqlite_utils as sqlu 7 | 8 | from glob import glob 9 | from typing import Iterable 10 | 11 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription 12 | from PyExpUtils.results.tools import getHeader 13 | from PyExpUtils.results._utils.shared import hash_values 14 | 15 | logger = logging.getLogger('PyExpUtils') 16 | 17 | def detect_version(cur: sqlite3.Cursor) -> str: 18 | tables = sqlu.get_tables(cur) 19 | 20 | if 'metadata' not in tables: 21 | return 'v1' 22 | 23 | res = cur.execute('SELECT version FROM metadata') 24 | version = res.fetchone() 25 | 26 | if version is None: 27 | return 'v1' 28 | 29 | return version[0] 30 | 31 | def maybe_migrate(db_name: str, exp: ExperimentDescription): 32 | con = sqlite3.connect(db_name) 33 | cur = con.cursor() 34 | 35 | version = detect_version(cur) 36 | 37 | if version == 'v1': 38 | logger.warning('Migrating from v1->v2 of data version') 39 | make_backup(db_name) 40 | 41 | try: 42 | v1_to_v2_migration(db_name, cur, exp) 43 | except Exception as e: 44 | restore_backup(db_name) 45 | raise e 46 | 47 | elif version == 'v2': 48 | ... 
49 | 50 | else: 51 | raise Exception('Cannot figure out how to migrate to latest data version') 52 | 53 | cur.close() 54 | con.close() 55 | 56 | def restore_backup(db_name: str): 57 | backups = glob(db_name + '.*.backup') 58 | latest = backups[-1] 59 | 60 | logger.warning(f'Attempting to restore {db_name} from backup {latest}') 61 | shutil.copyfile(latest, db_name) 62 | 63 | def make_backup(db_name: str): 64 | backups = glob(db_name + '.*.backup') 65 | num = len(backups) 66 | dst = f'{db_name}.{num}.backup' 67 | 68 | logger.warning(f'Making a backup of {db_name} at {dst}') 69 | shutil.copyfile(db_name, dst) 70 | 71 | def get_values(df: pd.DataFrame | pd.Series, hypers: Iterable[str]): 72 | hypers = sorted(hypers) 73 | 74 | vals = [] 75 | for h in hypers: 76 | vals.append(df[h]) 77 | 78 | return vals 79 | 80 | def v1_to_v2_migration(path: str, cur: sqlite3.Cursor, exp: ExperimentDescription): 81 | sqlu.make_table(cur, 'metadata', ['version']) 82 | cur.execute('INSERT INTO metadata(version) VALUES("v2")') 83 | 84 | hypers = getHeader(exp) 85 | hypers = sorted(hypers) 86 | sqlu.make_table(cur, 'hyperparameters', ['config_id'] + hypers) 87 | 88 | hypers_str = ','.join(map(sqlu.quote, hypers)) 89 | df = cx.read_sql(f'sqlite://{path}', f'SELECT DISTINCT {hypers_str} FROM results') 90 | 91 | cur.execute('ALTER TABLE results ADD COLUMN config_id') 92 | 93 | logger.warning('Updating existing rows') 94 | for _, row in df.iterrows(): 95 | values = get_values(row, hypers) 96 | cid = hash_values(values) 97 | 98 | q = ' AND '.join(f'"{k}"={sqlu.maybe_quote(v)}' for k, v in zip(hypers, values)) 99 | cur.execute(f'UPDATE results SET config_id={cid} WHERE {q}') 100 | 101 | cols = ','.join(map(sqlu.quote, ['config_id'] + hypers)) 102 | vals = ','.join(map(str, map(sqlu.maybe_quote, [cid] + values))) 103 | cur.execute(f'INSERT INTO hyperparameters({cols}) VALUES({vals})') 104 | 105 | cur.connection.commit() 106 | 107 | logger.warning('Creating new table') 108 | # can't just drop 
columns in sqlite 109 | # have to recreate table 110 | all_cols = sqlu.get_cols(cur, 'results') 111 | 112 | desired_cols = set(all_cols) - set(hypers) 113 | desired_cols |= { 'config_id' } 114 | 115 | cols = ','.join(map(sqlu.quote, desired_cols)) 116 | cur.execute(f'CREATE TEMPORARY TABLE results_backup({cols})') 117 | cur.execute(f'INSERT INTO results_backup SELECT {cols} FROM results') 118 | cur.execute('DROP TABLE results') 119 | cur.execute(f'CREATE TABLE results({cols})') 120 | cur.execute(f'INSERT INTO results SELECT {cols} FROM results_backup') 121 | cur.execute('DROP TABLE results_backup') 122 | cur.connection.commit() 123 | -------------------------------------------------------------------------------- /PyExpUtils/results/pandas.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import os 3 | import glob 4 | import pandas as pd 5 | import PyExpUtils.utils.pandas as pdu 6 | 7 | from filelock import FileLock 8 | from typing import Any, Dict, Iterable, Optional, Sequence, Union 9 | from PyExpUtils.FileSystemContext import FileSystemContext 10 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription 11 | from PyExpUtils.results.indices import listIndices 12 | from PyExpUtils.results.tools import subsetDF 13 | from PyExpUtils.collection.Collector import Collector 14 | from PyExpUtils.utils.dict import flatKeys, get 15 | from PyExpUtils.utils.types import NpList 16 | from PyExpUtils.utils.asyncio import threadMap 17 | from PyExpUtils.utils.iterable import filter_none 18 | 19 | class NoResultException(Exception): 20 | ... 
21 | 22 | def saveResults(exp: ExperimentDescription, idx: int, filename: str, data: NpList, base: str = './', batch_size: Optional[int] = 20000): 23 | context = exp.buildSaveContext(idx, base=base) 24 | context.ensureExists() 25 | 26 | params = exp.getPermutation(idx)['metaParameters'] 27 | header = getHeader(exp) 28 | pvalues = [get(params, k) for k in header] 29 | 30 | run = exp.getRun(idx) 31 | 32 | df = pd.DataFrame([pvalues + [run] + list(data)]) 33 | 34 | # -------------- 35 | # -- batching -- 36 | # -------------- 37 | data_file = _batchFile(context, filename, idx, batch_size) 38 | 39 | with FileLock(data_file + '.lock'): 40 | df.to_csv(data_file, mode='a+', header=False, index=False) 41 | 42 | return data_file 43 | 44 | def saveSequentialRuns(exp: ExperimentDescription, idx: int, filename: str, data: Any, base: str = './', batch_size: Optional[int] = 20000): 45 | context = exp.buildSaveContext(idx, base=base) 46 | context.ensureExists() 47 | 48 | params = exp.getPermutation(idx)['metaParameters'] 49 | header = getHeader(exp) 50 | pvalues = [get(params, k) for k in header] 51 | 52 | run = exp.getRun(idx) 53 | rows = [] 54 | for i in range(len(data)): 55 | if data[i] is None: 56 | continue 57 | 58 | rows.append(pvalues + [run + i] + list(data[i])) 59 | 60 | df = pd.DataFrame(rows) 61 | 62 | # -------------- 63 | # -- batching -- 64 | # -------------- 65 | data_file = _batchFile(context, filename, idx, batch_size) 66 | 67 | with FileLock(data_file + '.lock'): 68 | df.to_csv(data_file, mode='a+', header=False, index=False) 69 | 70 | return data_file 71 | 72 | def saveCollector(exp: ExperimentDescription, collector: Collector, base: str = './', batch_size: Optional[int] = 20000, keys: Optional[Sequence[str]] = None): 73 | context = exp.buildSaveContext(0, base=base) 74 | context.ensureExists() 75 | 76 | header = getHeader(exp) 77 | 78 | to_write = defaultdict(list) 79 | 80 | if keys is None: 81 | keys = list(collector.keys()) 82 | 83 | for filename in keys: 84 
| for idx in collector.indices(): 85 | data = collector.get(filename, idx) 86 | 87 | params = exp.getPermutation(idx)['metaParameters'] 88 | run = exp.getRun(idx) 89 | pvalues = [get(params, k) for k in header] 90 | 91 | row = pvalues + [run] + list(data) 92 | data_file = _batchFile(context, filename, idx, batch_size) 93 | 94 | to_write[data_file].append(row) 95 | 96 | for path in to_write: 97 | df = pd.DataFrame(to_write[path]) 98 | 99 | with FileLock(path + '.lock'): 100 | df.to_csv(path, mode='a+', header=False, index=False) 101 | 102 | def loadAllResults(exp: ExperimentDescription, metrics: Optional[Iterable[str]] = None, base: str = './', use_cache: bool = True) -> Union[pd.DataFrame, None]: 103 | if metrics is None: 104 | metrics = get_result_filenames(exp, base) 105 | 106 | parts = (loadResults(exp, f, base, col=f, use_cache=use_cache) for f in metrics) 107 | dfs: Iterable[pd.DataFrame] = filter_none(parts) 108 | dfs = list(dfs) 109 | 110 | if len(dfs) == 0: 111 | return None 112 | 113 | header = getHeader(exp) + ['run'] 114 | df = pdu.outer(dfs, on=header) 115 | 116 | return df 117 | 118 | def loadResults(exp: ExperimentDescription, filename: str, base: str = './', col: Optional[str] = None, use_cache: bool = True) -> Union[pd.DataFrame, None]: 119 | context = exp.buildSaveContext(0, base=base) 120 | 121 | files = glob.glob(context.resolve(f'{filename}.*.csv')) 122 | 123 | # this could be because we did not use batching 124 | # try again without batching 125 | if len(files) == 0: 126 | files = glob.glob(context.resolve(f'{filename}.csv')) 127 | 128 | # if still no files, then no results exist 129 | if len(files) == 0: 130 | return None 131 | 132 | # get latest modification time 133 | times = (os.path.getmtime(f) for f in files) 134 | latest = max(*times, 0, 0) 135 | 136 | cache_file = context.resolve(filename + '.pkl') 137 | if use_cache and os.path.exists(cache_file) and os.path.getmtime(cache_file) > latest: 138 | df = pd.read_pickle(cache_file) 139 | 
return _subsetDFbyExp(df, exp) 140 | 141 | partials = threadMap(_readUnevenCsv, files) 142 | df = pd.concat(partials, ignore_index=True) 143 | 144 | header = getHeader(exp) 145 | nparams = len(header) + 1 146 | new_df = df.iloc[:, :nparams] 147 | new_df.columns = header + ['run'] 148 | 149 | # figure out where to put the data 150 | if col is None: 151 | col = 'data' 152 | 153 | data_cols = df.iloc[:, nparams:].values 154 | if data_cols.shape[1] == 1: 155 | new_df[col] = data_cols[:, 0] 156 | else: 157 | new_df[col] = df.iloc[:, nparams:].values.tolist() 158 | 159 | if use_cache: 160 | new_df.to_pickle(cache_file) 161 | 162 | return _subsetDFbyExp(new_df, exp) 163 | 164 | 165 | def detectMissingIndices(exp: ExperimentDescription, runs: int, filename: Optional[str] = None, base: str = './'): # noqa: C901 166 | indices = listIndices(exp) 167 | nperms = exp.numPermutations() 168 | header = getHeader(exp) 169 | 170 | r_files = get_result_filenames(exp, base) 171 | if len(r_files) == 0: 172 | for idx in indices: 173 | for run in range(runs): 174 | yield idx + run * nperms 175 | return 176 | 177 | filename = list(r_files)[0] 178 | assert isinstance(filename, str) 179 | df = loadResults(exp, filename, base=base) 180 | # ---------------------------------- 181 | # -- first case: no existing data -- 182 | # ---------------------------------- 183 | if df is None: 184 | for idx in indices: 185 | for run in range(runs): 186 | yield idx + run * nperms 187 | 188 | return 189 | 190 | grouped = df.groupby(header) 191 | for idx in indices: 192 | params = exp.getPermutation(idx)['metaParameters'] 193 | pvals = tuple(get(params, k) for k in header) 194 | 195 | # get_group cannot handle singular tuples 196 | if len(pvals) == 1: 197 | pvals = pvals[0] 198 | 199 | # ------------------------------------ 200 | # -- second case: no existing group -- 201 | # ------------------------------------ 202 | assert grouped is not None 203 | try: 204 | group = grouped.get_group(pvals) 205 | except 
KeyError: 206 | for run in range(runs): 207 | yield idx + run * nperms 208 | 209 | continue 210 | 211 | # ------------------------------------------------- 212 | # -- final case: have data and group. check runs -- 213 | # ------------------------------------------------- 214 | for run in range(runs): 215 | if not (group['run'] == run).any(): 216 | yield idx + run * nperms 217 | 218 | def getHeader(exp: ExperimentDescription): 219 | params = exp.getPermutation(0)['metaParameters'] 220 | keys = flatKeys(params) 221 | return sorted(keys) 222 | 223 | def getParamValues(exp: ExperimentDescription, idx: int, header: Optional[Sequence[str]] = None): 224 | if header is None: 225 | header = getHeader(exp) 226 | 227 | params = exp.getPermutation(idx)['metaParameters'] 228 | return [get(params, k) for k in header] 229 | 230 | def get_result_filenames(exp: ExperimentDescription, base: str = './'): 231 | context = exp.buildSaveContext(0, base=base) 232 | files = glob.glob(context.resolve('*.*.csv')) + glob.glob(context.resolve('*.csv')) 233 | 234 | if len(files) == 0: 235 | return set[str]() 236 | 237 | return set(map(lambda x: os.path.basename(x).split('.')[0], files)) 238 | 239 | # --------------- 240 | # -- Utilities -- 241 | # --------------- 242 | 243 | # makes sure the dataframe only contains the data for a given experiment description 244 | def _subsetDFbyExp(df: pd.DataFrame, exp: ExperimentDescription): 245 | params = exp._d['metaParameters'] 246 | return subsetDF(df, _flattenKeys(params)) 247 | 248 | def _flattenKeys(d: Dict[str, Any]): 249 | out = {} 250 | for k, v in d.items(): 251 | if isinstance(v, dict): 252 | for sk, sv in _flattenKeys(v).items(): 253 | out[f'{k}.{sk}'] = sv 254 | 255 | # if we have a list of lists, add top-level list as key 256 | elif isinstance(v, list) and isinstance(v[0], list): 257 | for i, sv in enumerate(v): 258 | out[f'{k}.[{i}]'] = sv 259 | 260 | # if we have a list of objects, keep digging 261 | elif isinstance(v, list) and 
isinstance(v[0], dict): 262 | for i, sv in enumerate(v): 263 | for sk, ssv in _flattenKeys(sv).items(): 264 | out[f'{k}.[{i}].{sk}'] = ssv 265 | 266 | else: 267 | out[k] = v 268 | 269 | return out 270 | 271 | 272 | # if the csv contains ragged rows (i.e. rows have different numbers of columns) 273 | # then the native csv reader needs to know the max number of columns. 274 | # the resulting df will have NaNs for the shorter rows 275 | def _readUnevenCsv(f: str): 276 | with open(f, 'r') as temp_f: 277 | col_count = ( len(line.split(",")) for line in temp_f.readlines() ) 278 | m_cols = max(col_count) 279 | names = list(map(str, range(0, m_cols))) 280 | 281 | return pd.read_csv(f, header=None, names=names) 282 | 283 | def _batchFile(context: FileSystemContext, filename: str, idx: int, batch_size: Optional[int]): 284 | if batch_size is None: 285 | return context.resolve(f'{filename}.csv') 286 | 287 | batch_idx = int(idx // batch_size) 288 | return context.resolve(f'{filename}.{batch_idx}.csv') 289 | -------------------------------------------------------------------------------- /PyExpUtils/results/sqlite.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sqlite3 3 | import logging 4 | import pandas as pd 5 | import PyExpUtils.results.sqlite_utils as sqlu 6 | 7 | from filelock import FileLock 8 | from typing import Iterable, Sequence 9 | 10 | from PyExpUtils.collection.Collector import Collector 11 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription 12 | from PyExpUtils.results.indices import listIndices 13 | from PyExpUtils.results.migrations import maybe_migrate 14 | from PyExpUtils.results.tools import getHeader, getParamValues 15 | from PyExpUtils.results._utils.shared import hash_values 16 | 17 | logger = logging.getLogger('PyExpUtils') 18 | 19 | 20 | # ------------ 21 | # -- Saving -- 22 | # ------------ 23 | def saveCollector(exp: ExperimentDescription, collector: Collector, base: 
str = './', keys: Iterable[str] | None = None): 24 | context = exp.buildSaveContext(0, base=base) 25 | context.ensureExists() 26 | 27 | hypers = getHeader(exp) 28 | metrics = list(collector.keys()) 29 | 30 | res_cols = list(set(['config_id', 'seed', 'frame'] + metrics)) 31 | hyp_cols = list(set(hypers + ['config_id'])) 32 | 33 | db_file = context.resolve('results.db') 34 | with FileLock(db_file + '.lock'): 35 | if os.path.exists(db_file): 36 | maybe_migrate(db_file, exp) 37 | 38 | con = sqlite3.connect(db_file, timeout=30) 39 | cur = con.cursor() 40 | 41 | set_version(cur, 'v2') 42 | 43 | sqlu.maybe_make_table(cur, 'hyperparameters', hyp_cols) 44 | sqlu.ensure_table_compatible(cur, 'hyperparameters', hyp_cols) 45 | 46 | sqlu.maybe_make_table(cur, 'results', res_cols) 47 | sqlu.ensure_table_compatible(cur, 'results', res_cols) 48 | 49 | rows = [] 50 | for idx in collector.indices(): 51 | cid = get_cid(cur, hypers, exp, idx) 52 | seed = exp.getRun(idx) 53 | frames = collector.get_frames(idx) 54 | for frame in frames: 55 | row_dict = frame | {'seed': seed, 'config_id': cid} 56 | vals = tuple(row_dict.get(k, None) for k in res_cols) 57 | rows.append(vals) 58 | 59 | cols_str = ', '.join(map(sqlu.quote, res_cols)) 60 | v_inserter = ', '.join('?' 
* len(res_cols)) 61 | cur.executemany(f'INSERT INTO results({cols_str}) VALUES({v_inserter})', rows) 62 | 63 | con.commit() 64 | con.close() 65 | 66 | # ------------- 67 | # -- Loading -- 68 | # ------------- 69 | def loadResultsOnly(exp: ExperimentDescription, base: str = './', metrics: Sequence[str] | None = None): 70 | context = exp.buildSaveContext(0, base=base) 71 | if not context.exists('results.db'): 72 | return None 73 | 74 | path = context.resolve('results.db') 75 | maybe_migrate(path, exp) 76 | 77 | con = sqlite3.connect(path) 78 | cur = con.cursor() 79 | 80 | header = getHeader(exp) 81 | valid_cids = [ 82 | get_cid(cur, header, exp, i) for i in listIndices(exp) 83 | ] 84 | 85 | constraints = ','.join(map(str, valid_cids)) 86 | constraints = f'config_id IN ({constraints})' 87 | if metrics is None: 88 | df = sqlu.read_to_df(path, f'SELECT * FROM results WHERE {constraints}', part='config_id') 89 | else: 90 | cols = set(metrics) | { 'frame', 'seed', 'config_id' } 91 | col_str = ','.join(map(sqlu.quote, cols)) 92 | 93 | non_null = ' AND '.join(f'{m} IS NOT NULL' for m in metrics) 94 | df = sqlu.read_to_df(path, f'SELECT {col_str} FROM results WHERE {non_null} AND {constraints}', part='config_id') 95 | 96 | return df 97 | 98 | def loadHypersOnly(exp: ExperimentDescription, base: str = './') -> pd.DataFrame | None: 99 | context = exp.buildSaveContext(0, base=base) 100 | if not context.exists('results.db'): 101 | return None 102 | 103 | path = context.resolve('results.db') 104 | config_df = sqlu.read_to_df(path, 'SELECT * FROM hyperparameters') 105 | 106 | return config_df 107 | 108 | def loadAllResults(exp: ExperimentDescription, base: str = './', metrics: Sequence[str] | None = None) -> pd.DataFrame | None: 109 | context = exp.buildSaveContext(0, base=base) 110 | if not context.exists('results.db'): 111 | return None 112 | 113 | path = context.resolve('results.db') 114 | 115 | result_df = loadResultsOnly(exp, base, metrics) 116 | config_df = 
sqlu.read_to_df(path, 'SELECT * FROM hyperparameters') 117 | 118 | assert result_df is not None 119 | df = result_df.merge(config_df, on='config_id') 120 | 121 | return df 122 | 123 | def detectMissingIndices(exp: ExperimentDescription, runs: int, base: str = './'): # noqa: C901 124 | context = exp.buildSaveContext(0, base=base) 125 | nperms = exp.numPermutations() 126 | 127 | header = getHeader(exp) 128 | 129 | # first case: no data 130 | if not context.exists('results.db'): 131 | yield from listIndices(exp, runs) 132 | return 133 | 134 | db_file = context.resolve('results.db') 135 | maybe_migrate(db_file, exp) 136 | con = sqlite3.connect(db_file, timeout=30) 137 | cur = con.cursor() 138 | 139 | tables = sqlu.get_tables(cur) 140 | if 'results' not in tables: 141 | yield from listIndices(exp, runs) 142 | con.close() 143 | return 144 | 145 | expected_seeds = set(range(runs)) 146 | for idx in listIndices(exp): 147 | cid = get_cid(cur, header, exp, idx) 148 | 149 | rows = cur.execute(f'SELECT DISTINCT seed FROM results WHERE config_id={cid}').fetchall() 150 | seeds = set(d[0] for d in rows) 151 | 152 | needed = expected_seeds - seeds 153 | for seed in needed: 154 | yield idx + seed * nperms 155 | 156 | con.close() 157 | 158 | # --------------- 159 | # -- Utilities -- 160 | # --------------- 161 | def get_cid(cur: sqlite3.Cursor, header: Sequence[str], exp: ExperimentDescription, idx: int) -> int: 162 | values = getParamValues(exp, idx, header) 163 | 164 | # first see if a cid already exists 165 | if len(header) > 0: 166 | c = sqlu.constraints_from_lists(header, values) 167 | res = cur.execute(f'SELECT config_id FROM hyperparameters WHERE {c}') 168 | else: 169 | res = cur.execute('SELECT config_id FROM hyperparameters') 170 | 171 | cids = res.fetchall() 172 | if len(cids) > 0: 173 | return cids[0][0] 174 | 175 | # otherwise create and store a cid 176 | cid = hash_values(values) 177 | 178 | if len(header) > 0: 179 | c_str = ','.join(map(sqlu.maybe_quote, header)) 180 | 
v_str = ','.join(map(str, map(sqlu.maybe_quote, values))) 181 | cur.execute(f'INSERT INTO hyperparameters({c_str},config_id) VALUES({v_str},{cid})') 182 | else: 183 | cur.execute(f'INSERT INTO hyperparameters(config_id) VALUES({cid})') 184 | 185 | return cid 186 | 187 | 188 | def set_version(cur: sqlite3.Cursor, version: str): 189 | sqlu.maybe_make_table(cur, 'metadata', ['version']) 190 | 191 | res = cur.execute('SELECT version FROM metadata') 192 | v = res.fetchall() 193 | 194 | if len(v) == 0: 195 | cur.execute(f'INSERT INTO metadata(version) VALUES("{version}")') 196 | else: 197 | cur.execute(f'UPDATE metadata SET version="{version}" WHERE version="{v[0]}"') 198 | -------------------------------------------------------------------------------- /PyExpUtils/results/sqlite_utils.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import pandas as pd 3 | import connectorx as cx 4 | from typing import Any, Dict, Iterable, List 5 | 6 | 7 | def get_tables(cur: sqlite3.Cursor) -> List[str]: 8 | res = cur.execute("SELECT name FROM sqlite_master") 9 | return [r[0] for r in res.fetchall()] 10 | 11 | def make_table(cur: sqlite3.Cursor, name: str, columns: Iterable[str]): 12 | cols = ', '.join(map(quote, columns)) 13 | cur.execute(f'CREATE TABLE {name}({cols})') 14 | 15 | def maybe_make_table(cur: sqlite3.Cursor, name: str, columns: Iterable[str]): 16 | tables = get_tables(cur) 17 | 18 | if name not in tables: 19 | make_table(cur, name, columns) 20 | 21 | def get_cols(cur: sqlite3.Cursor, name: str): 22 | res = cur.execute(f'PRAGMA table_info({name})') 23 | rows = res.fetchall() 24 | 25 | return [r[1] for r in rows] 26 | 27 | def add_cols(cur: sqlite3.Cursor, columns: Iterable[str]): 28 | columns = map(quote, columns) 29 | for col in columns: 30 | cur.execute(f'ALTER TABLE results ADD COLUMN {col}') 31 | 32 | def ensure_table_compatible(cur: sqlite3.Cursor, name: str, columns: Iterable[str]): 33 | columns = 
set(columns) 34 | current_cols = set(get_cols(cur, name)) 35 | needed_cols = columns - current_cols 36 | 37 | if needed_cols: 38 | add_cols(cur, needed_cols) 39 | 40 | 41 | def query(cur: sqlite3.Cursor, what: str, where: Dict[str, Any]): 42 | constraints = ' and '.join([f'"{k}"={maybe_quote(v)}' for k, v in where.items()]) 43 | res = cur.execute(f'SELECT {what} FROM results WHERE {constraints}') 44 | rows = res.fetchall() 45 | return rows 46 | 47 | def constraints_from_lists(cols: Iterable[str], vals: Iterable[Any]): 48 | c = ' AND '.join(f'"{k}"={maybe_quote(v)}' for k, v in zip(cols, vals)) 49 | return c 50 | 51 | 52 | def maybe_quote(v: Any): 53 | if isinstance(v, str): 54 | return quote(v) 55 | return v 56 | 57 | def quote(s: str): 58 | return f'"{s}"' 59 | 60 | 61 | def read_to_df(db_name: str, query: str, part: str | None = None) -> pd.DataFrame: 62 | # it appears that connectorx has some bugs that cause it to periodically fail. 63 | # but when it _does_ work, it is 10x faster. So let's try connectorx first, then 64 | # fall back to the slower pandas for now. 
65 | try: 66 | n = 4 if part is not None else None 67 | df: Any = cx.read_sql(f'sqlite://{db_name}', query, partition_on=part, partition_num=n) 68 | return df 69 | except BaseException as e: 70 | print(db_name) 71 | print(query) 72 | print(e) 73 | con = sqlite3.connect(db_name) 74 | df = pd.read_sql_query(query, con) 75 | con.close() 76 | return df 77 | -------------------------------------------------------------------------------- /PyExpUtils/results/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from typing import Any, Dict, Optional, Sequence 5 | 6 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription 7 | from PyExpUtils.utils.dict import flatKeys, get 8 | 9 | 10 | def collapseRuns(df: pd.DataFrame): 11 | cols = list(df.columns) 12 | header = list(filter(lambda c: c not in ['data', 'run'], cols)) 13 | 14 | df = df.groupby(header)['data'].apply(lambda x: np.array(list(x))).reset_index() 15 | return df 16 | 17 | def subsetDF(df: pd.DataFrame, conds: Dict[str, Any]): 18 | mask = _buildMask(df, conds) 19 | return df[mask].reset_index(drop=True) 20 | 21 | def splitByValue(df: pd.DataFrame, col: str): 22 | values = df[col].unique() 23 | values.sort() 24 | 25 | for v in values: 26 | sub = df[df[col] == v] 27 | yield v, sub 28 | 29 | def getHeader(exp: ExperimentDescription): 30 | params = exp.getPermutation(0)['metaParameters'] 31 | keys = flatKeys(params) 32 | return sorted(keys) 33 | 34 | def getParamValues(exp: ExperimentDescription, idx: int, header: Optional[Sequence[str]] = None): 35 | if header is None: 36 | header = getHeader(exp) 37 | 38 | params = exp.getPermutation(idx)['metaParameters'] 39 | return [get(params, k) for k in header] 40 | 41 | def getParamsAsDict(exp: ExperimentDescription, idx: int, header: Optional[Sequence[str]] = None): 42 | if header is None: 43 | header = getHeader(exp) 44 | 45 | params = 
exp.getPermutation(idx)['metaParameters'] 46 | return { 47 | k: get(params, k) for k in header 48 | } 49 | 50 | # ------------------------ 51 | # -- Internal Utilities -- 52 | # ------------------------ 53 | def _buildMask(df: pd.DataFrame, conds: Dict[str, Any]): 54 | mask = np.ones(len(df), dtype=bool) 55 | for key, cond in conds.items(): 56 | if isinstance(cond, dict): 57 | mask = mask | _buildMask(df, cond) 58 | 59 | elif isinstance(cond, list): 60 | mask = mask & (df[key].isin(cond)) 61 | 62 | elif key in df: 63 | mask = mask & (df[key] == cond) 64 | 65 | return mask 66 | -------------------------------------------------------------------------------- /PyExpUtils/results/voting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | from PyExpUtils.utils.types import T 4 | from PyExpUtils.utils.jit import try2jit 5 | from typing import Dict, List, NamedTuple, Tuple, Union, cast 6 | 7 | Name = Union[int, str] 8 | 9 | class ScoredCandidate(NamedTuple): 10 | # meta-parameter permutation index for experiment description 11 | name: Name 12 | # the "score" of this meta-parameter 13 | score: float 14 | # the stderr of the point value 15 | stderr: float 16 | 17 | class RankedCandidate(NamedTuple): 18 | name: Name 19 | rank: int 20 | score: float = 0 21 | 22 | def confidenceInterval(scored: ScoredCandidate, c: float): 23 | lo = scored.score - c * scored.stderr 24 | hi = scored.score + c * scored.stderr 25 | 26 | return (lo, hi) 27 | 28 | def inRange(a: Tuple[float, float], b: Tuple[float, float]): 29 | if a[0] <= b[1] and a[0] >= b[0]: 30 | return True 31 | 32 | if a[1] <= b[1] and a[1] >= b[0]: 33 | return True 34 | 35 | return False 36 | 37 | def filterNans(scores: List[ScoredCandidate]): 38 | return list(filter(lambda s: not np.isnan(s.score), scores)) 39 | 40 | def confidenceRanking(scores: List[ScoredCandidate], stderrs: float = 2.0, prefer: str = 'big'): 41 | # this method just 
ignores null results 42 | scores = filterNans(scores) 43 | 44 | if prefer == 'big': 45 | ordered = sorted(scores, key=lambda x: x.score, reverse=True) 46 | else: 47 | ordered = sorted(scores, key=lambda x: x.score, reverse=False) 48 | 49 | rank = 0 50 | last_range = confidenceInterval(ordered[0], stderrs) 51 | 52 | ranks: List[RankedCandidate] = [] 53 | for score in ordered: 54 | rang = confidenceInterval(score, stderrs) 55 | if not inRange(rang, last_range): 56 | rank += 1 57 | last_range = rang 58 | 59 | ranks.append(RankedCandidate(score.name, rank, score.score)) 60 | 61 | return ranks 62 | 63 | def scoreRanking(scores: List[ScoredCandidate], prefer: str = 'big'): 64 | # this method just ignores null results 65 | scores = filterNans(scores) 66 | 67 | if prefer == 'big': 68 | ordered = sorted(scores, key=lambda x: x.score, reverse=True) 69 | else: 70 | ordered = sorted(scores, key=lambda x: x.score, reverse=False) 71 | 72 | rank = 0 73 | ranks: List[RankedCandidate] = [] 74 | for score in ordered: 75 | ranks.append(RankedCandidate(score.name, rank, score.score)) 76 | rank += 1 77 | 78 | return ranks 79 | 80 | RankedBallot = Dict[Name, RankedCandidate] 81 | def buildBallot(candidates: List[RankedCandidate]) -> RankedBallot: 82 | ballot: RankedBallot = {} 83 | for candidate in candidates: 84 | ballot[candidate.name] = candidate 85 | 86 | return ballot 87 | 88 | def countVotes(ballots: List[RankedBallot]): 89 | votes: Dict[Name, int] = {} 90 | 91 | for ballot in ballots: 92 | for name in ballot: 93 | v = votes.get(name, 0) 94 | 95 | if ballot[name].rank == 0: 96 | v += 1 97 | 98 | votes[name] = v 99 | 100 | return votes 101 | 102 | def highScore(ballots: List[RankedBallot], prefer: str = 'big') -> Name: 103 | names = list(ballots[0].keys()) 104 | 105 | scores = np.zeros(len(names)) 106 | 107 | for i, name in enumerate(names): 108 | for ballot in ballots: 109 | scores[i] += ballot[name].score 110 | 111 | if prefer == 'big': 112 | idx = np.argmax(scores) 113 | else: 
114 | idx = np.argmin(scores) 115 | 116 | # numpy types are getting worse 117 | i = cast(int, idx) 118 | 119 | return names[i] 120 | 121 | def firstPastPost(ballots: List[RankedBallot]) -> Name: 122 | votes = countVotes(ballots) 123 | 124 | return dictMax(votes)[1] 125 | 126 | def instantRunoff(ballots: List[RankedBallot]) -> Name: 127 | # the code is simpler if we modify in place 128 | # so create a copy so that we don't mess with the sender's object 129 | ballots = deepcopy(ballots) 130 | 131 | votes = countVotes(ballots) 132 | 133 | # check if we have a majority leader 134 | vals = list(votes.values()) 135 | ma: int = np.max(vals) 136 | 137 | # if we have a majority leader, return that candidate 138 | if ma > np.ceil(len(ballots) / 2): 139 | return findKey(votes, ma) 140 | 141 | # if everyone is equal, then we have a tie. 142 | # in this case, return the first candidate 143 | if np.sum(vals == ma) == len(vals): 144 | return findKey(votes, ma) 145 | 146 | # otherwise, redistribute the ballots from the last place candidate 147 | mi: int = np.min(vals) 148 | 149 | # if there's only one loser, this is easy 150 | if np.sum(vals == mi) == 1: 151 | loser = findKey(votes, mi) 152 | 153 | # otherwise, let's grab the lowest total ranked candidate among the losers 154 | else: 155 | mi_votes: Dict[Name, int] = {} 156 | for name in findAllKeys(votes, mi): 157 | s = 0 158 | for ballot in ballots: 159 | s += ballot[name].rank 160 | 161 | mi_votes[name] = s 162 | 163 | mi_vals = list(mi_votes.values()) 164 | max_val: int = np.max(mi_vals) # highest rank here means worst candidate 165 | 166 | # technically there can again be a tie, but at this point we can safely 167 | # ignore that and break the tie arbitrarily 168 | loser = findKey(mi_votes, max_val) 169 | 170 | # now remove the loser from all ballots 171 | for ballot in ballots: 172 | # if this ballot voted for the loser *and* there are no other rank 0 votes 173 | # then bump all ranks up one 174 | if ballot[loser].rank == 0 
and len(getCandidatesByRank(ballot, 0)) == 1: 175 | for name in ballot: 176 | ballot[name] = RankedCandidate(name, max(0, ballot[name].rank - 1), ballot[name].score) 177 | 178 | del ballot[loser] 179 | 180 | # run the vote again with the modified ballots 181 | return instantRunoff(ballots) 182 | # pairwise result for a single ballot: matrix[x, y] = 1 when x has a strictly better (lower) rank than y 183 | @try2jit 184 | def computeVoteMatrix(ranks: np.ndarray): 185 | n = len(ranks) 186 | matrix = np.zeros((n, n)) 187 | 188 | for i in range(n): 189 | for j in range(n): 190 | if i == j: 191 | continue 192 | 193 | # a loses to b 194 | if ranks[i] > ranks[j]: 195 | matrix[j, i] = 1 196 | 197 | # b loses to a 198 | elif ranks[j] > ranks[i]: 199 | matrix[i, j] = 1 200 | 201 | return matrix 202 | # Copeland score per candidate: +1 for each head-to-head win and +0.5 for each head-to-head tie in sum_matrix 203 | @try2jit 204 | def copelandScore(sum_matrix: np.ndarray): 205 | scores = np.zeros(sum_matrix.shape[0]) 206 | 207 | for i in range(len(scores)): 208 | for j in range(len(scores)): 209 | if i == j: 210 | continue 211 | 212 | if sum_matrix[i, j] > sum_matrix[j, i]: 213 | scores[i] += 1 214 | 215 | elif sum_matrix[i, j] == sum_matrix[j, i]: 216 | scores[i] += 0.5 217 | 218 | return scores 219 | # sum_matrix[x, y] = number of ballots ranking x strictly above y; rows/columns follow the order of `names` 220 | def sumMatrix(ballots: List[RankedBallot], names: List[Name]) -> np.ndarray: 221 | n = len(names) 222 | 223 | sum_matrix = np.zeros((n, n)) 224 | for ballot in ballots: 225 | # ensure we pass ranks in order *by name* 226 | ranks = np.array([ballot[name].rank for name in names]) 227 | sum_matrix += computeVoteMatrix(ranks) 228 | 229 | return sum_matrix 230 | # Copeland winner; a tie is narrowed by re-running on the tied candidates and finally broken by highScore 231 | def small(ballots: List[RankedBallot], prefer: str = 'big') -> Name: 232 | # the code is simpler if we modify in place 233 | # so create a copy so that we don't mess with the sender's object 234 | ballots = deepcopy(ballots) 235 | 236 | # order is arbitrary, but *must* be consistent 237 | names = list(ballots[0].keys()) 238 | sum_matrix = sumMatrix(ballots, names) 239 | copeland_scores = copelandScore(sum_matrix) 240 | 241 | winners = argsMax(copeland_scores) 242 | winner_names = [names[idx] for idx in winners] 243 | 244 | #
if we have a singular winner, we are done 245 | if len(winners) == 1: 246 | return winner_names[0] 247 | 248 | # we could end up with a tie, which needs to broken arbitrarily 249 | # we pick the highest (by default) score 250 | rest = [name for name in names if name not in winner_names] 251 | 252 | # there was a tie which could not be resolved 253 | if len(rest) == 0: 254 | return highScore(ballots, prefer=prefer) 255 | 256 | # otherwise, iterate over all of the worst performers and delete them 257 | # then try again 258 | for loser in rest: 259 | for ballot in ballots: 260 | del ballot[loser] 261 | 262 | return small(ballots) 263 | 264 | def raynaud(ballots: List[RankedBallot]) -> Name: 265 | ballots = deepcopy(ballots) 266 | 267 | names = list(ballots[0].keys()) 268 | 269 | if len(names) == 1: 270 | return names[0] 271 | 272 | sum_matrix = sumMatrix(ballots, names) 273 | 274 | ma = np.max(sum_matrix) 275 | _, col = np.where(sum_matrix == ma) 276 | 277 | loser = names[col[0]] 278 | for ballot in ballots: 279 | del ballot[loser] 280 | 281 | return raynaud(ballots) 282 | 283 | # --------------------- 284 | # Local utility methods 285 | # --------------------- 286 | def getCandidatesByRank(ballot: RankedBallot, rank: int): 287 | out: List[RankedCandidate] = [] 288 | 289 | for name in ballot: 290 | if ballot[name].rank == rank: 291 | out.append(ballot[name]) 292 | 293 | return out 294 | 295 | def findKey(obj: Dict[Name, T], val: T) -> Name: 296 | for key in obj: 297 | if obj[key] == val: 298 | return key 299 | 300 | raise Exception('uh-oh') 301 | 302 | def findAllKeys(obj: Dict[Name, T], val: T) -> List[Name]: 303 | ret: List[Name] = [] 304 | for key in obj: 305 | if obj[key] == val: 306 | ret.append(key) 307 | 308 | return ret 309 | 310 | def dictMax(d: Dict[Name, int]): 311 | vals = list(d.values()) 312 | ma: int = np.max(vals) 313 | 314 | amax = findKey(d, ma) 315 | 316 | return ma, amax 317 | 318 | def argsMax(arr: np.ndarray): 319 | ties: List[int] = [] 320 | ma: 
float = -np.inf 321 | 322 | for i, a in enumerate(arr): 323 | if a > ma: 324 | ties = [i] 325 | ma = a 326 | elif a == ma: 327 | ties.append(i) 328 | 329 | return ties 330 | -------------------------------------------------------------------------------- /PyExpUtils/runner/Slurm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, Iterable, Optional 6 | 7 | from PyExpUtils.utils.cmdline import flagString 8 | 9 | """doc 10 | Takes an integer number of hours and returns a well-formatted time string. 11 | ```python 12 | time = hours(3) 13 | print(time) # -> '2:59:59 14 | ``` 15 | """ 16 | def hours(n: int): 17 | return f'{n-1}:59:59' 18 | 19 | """doc 20 | Takes an integer number of gigabytes and returns a well-formatted memory string. 21 | ```python 22 | memory = gb(4) 23 | print(memory) # -> '4G' 24 | ``` 25 | """ 26 | def gb(n: int): 27 | return f'{n}G' 28 | 29 | 30 | @dataclass 31 | class SingleNodeOptions: 32 | account: str 33 | time: str 34 | cores: int 35 | mem_per_core: str | float 36 | 37 | # task management args 38 | sequential: int = 1 39 | threads_per_task: int = 1 40 | 41 | # job reporting args 42 | log_path: str = "$SCRATCH/job_output_%j.txt" 43 | 44 | @dataclass 45 | class MultiNodeOptions: 46 | account: str 47 | time: str 48 | cores: int 49 | mem_per_core: str | float 50 | 51 | # task management args 52 | sequential: int = 1 53 | 54 | # job reporting args 55 | log_path: str = "$SCRATCH/job_output_%j.txt" 56 | 57 | # ---------------- 58 | # -- Validation -- 59 | # ---------------- 60 | def check_account(account: str): 61 | assert account.startswith('rrg-') or account.startswith('def-') 62 | assert not account.endswith('_cpu') and not account.endswith('_gpu') 63 | 64 | def check_time(time: str): 65 | assert isinstance(time, str) 66 | 67 | # while technically slurm is more permissive, I find being more explicit 
removes 68 | # some common footguns. Example the "int:int" format is oft misunderstood as "hours:minutes" 69 | 70 | # "hour:minute:second" 71 | h_m_s = re.match(r'^\d+:\d+:\d+$', time) 72 | 73 | # "days-hours" 74 | d_h = re.match(r'^\d+-\d+$', time) 75 | 76 | # "days-hours:minutes:seconds" 77 | d_h_m_s = re.match(r'^\d+-\d+:\d+:\d+$', time) 78 | 79 | assert h_m_s or d_h or d_h_m_s 80 | 81 | def normalize_memory(memory: float | str) -> str: 82 | if isinstance(memory, (float, int)): 83 | mbs = int(memory * 1024) 84 | memory = f'{mbs}M' 85 | 86 | assert isinstance(memory, str) 87 | assert re.match(r'^\d+[G|M|K]$', memory) 88 | return memory 89 | 90 | def shared_validation(options: SingleNodeOptions | MultiNodeOptions): 91 | check_account(options.account) 92 | check_time(options.time) 93 | options.mem_per_core = normalize_memory(options.mem_per_core) 94 | 95 | 96 | def single_validation(options: SingleNodeOptions): 97 | shared_validation(options) 98 | # TODO: validate that the current cluster has nodes that can handle the specified request 99 | 100 | def multi_validation(options: MultiNodeOptions): 101 | shared_validation(options) 102 | 103 | def validate(options: SingleNodeOptions | MultiNodeOptions): 104 | if isinstance(options, SingleNodeOptions): single_validation(options) 105 | elif isinstance(options, MultiNodeOptions): multi_validation(options) 106 | 107 | # ------------------ 108 | # -- External API -- 109 | # ------------------ 110 | 111 | def memory_in_mb(memory: str | float) -> float: 112 | memory = normalize_memory(memory) 113 | 114 | if memory.endswith('M'): 115 | return int(memory[:-1]) 116 | 117 | if memory.endswith('G'): 118 | return int(memory[:-1]) * 1024 119 | 120 | if memory.endswith('K'): 121 | return int(memory[:-1]) / 1024 122 | 123 | raise Exception('Unknown memory unit') 124 | 125 | def to_cmdline_flags( 126 | options: SingleNodeOptions | MultiNodeOptions, 127 | skip_validation: bool = False, 128 | ) -> str: 129 | if not skip_validation: 130 | 
validate(options) 131 | # validation (when enabled) also normalizes options.mem_per_core in place 132 | args = [ 133 | ('--account', options.account), 134 | ('--time', options.time), 135 | ('--mem-per-cpu', options.mem_per_core), 136 | ('--output', options.log_path), 137 | ] 138 | # single-node jobs pin --nodes to 1; multi-node jobs leave node placement to Slurm 139 | if isinstance(options, SingleNodeOptions): 140 | args += [ 141 | ('--ntasks', options.cores), 142 | ('--nodes', 1), 143 | ('--cpus-per-task', 1), 144 | ] 145 | 146 | elif isinstance(options, MultiNodeOptions): 147 | args += [ 148 | ('--ntasks', options.cores), 149 | ('--cpus-per-task', 1), 150 | ] 151 | 152 | return flagString(args) 153 | # load options from a JSON file whose 'type' key selects 'single_node' or 'multi_node' 154 | def fromFile(path: str): 155 | with open(path, 'r') as f: 156 | d = json.load(f) 157 | 158 | assert 'type' in d, 'Need to specify scheduling strategy.' 159 | t = d['type'] 160 | del d['type'] 161 | 162 | if t == 'single_node': 163 | return SingleNodeOptions(**d) 164 | 165 | elif t == 'multi_node': 166 | return MultiNodeOptions(**d) 167 | 168 | raise Exception('Unknown scheduling strategy') 169 | # NOTE(review): parallelOpts is accepted but never used in this function 170 | def buildParallel(executable: str, tasks: Iterable[Any], opts: SingleNodeOptions | MultiNodeOptions, parallelOpts: Dict[str, Any] = {}): 171 | threads = 1 172 | if isinstance(opts, SingleNodeOptions): 173 | threads = opts.threads_per_task 174 | 175 | cores = int(opts.cores / threads) 176 | 177 | parallel_exec = f'srun -N1 -n{threads} --exclusive {executable}' 178 | if isinstance(opts, SingleNodeOptions): 179 | parallel_exec = executable 180 | 181 | task_str = ' '.join(map(str, tasks)) 182 | return f'run-parallel --parallel {cores} --exec "{parallel_exec}" --tasks {task_str}' 183 | # write the script to disk, submit it with sbatch plus the option flags, then optionally delete the file 184 | def schedule( 185 | script: str, 186 | opts: Optional[SingleNodeOptions | MultiNodeOptions] = None, 187 | script_name: str = 'auto_slurm.sh', 188 | cleanup: bool = True, 189 | skip_validation: bool = False, 190 | ) -> None: 191 | with open(script_name, 'w') as f: 192 | f.write(script) 193 | 194 | cmdArgs = '' 195 | if opts is not None: 196 | cmdArgs = to_cmdline_flags(opts, skip_validation=skip_validation) 197 |
os.system(f'sbatch {cmdArgs} {script_name}') 199 | # NOTE(review): the os.system exit status is ignored, so a failed sbatch submission goes unreported 200 | if cleanup: 201 | os.remove(script_name) 202 | -------------------------------------------------------------------------------- /PyExpUtils/runner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andnp/PyExpUtils/5d076ff1196368a936b18998afd00c80d4699857/PyExpUtils/runner/__init__.py -------------------------------------------------------------------------------- /PyExpUtils/runner/parallel.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | from PyExpUtils.utils.cmdline import flagString 3 | # build a `parallel ... ::: tasks` command line from a config dict; returns None when there are no tasks 4 | def build(d: Dict[str, Any]): 5 | # required 6 | ex = d['executable'] 7 | cores = d['cores'] 8 | tasks = d['tasks'] 9 | 10 | # make sure tasks is a string 11 | tasks = tasks if isinstance(tasks, str) else ' '.join(map(str, tasks)) 12 | 13 | # optional 14 | delay = d.get('delay') 15 | sshloginfile = d.get('sshloginfile') 16 | batch = d.get('batch') 17 | 18 | # build parameter pairs 19 | pairs = [ 20 | ('-j', cores), 21 | ('--delay', delay), 22 | ('--sshloginfile', sshloginfile), 23 | ('-n', batch), 24 | ] 25 | 26 | # build parallel options 27 | ops = flagString(pairs, joiner=' ') 28 | 29 | if len(tasks) == 0: 30 | return None 31 | 32 | return f'parallel {ops} {ex} ::: {tasks}' 33 | -------------------------------------------------------------------------------- /PyExpUtils/runner/parallel_exec.py: -------------------------------------------------------------------------------- 1 | import shlex 2 | import subprocess 3 | import signal 4 | 5 | from dataclasses import dataclass 6 | from typing import Dict, Sequence 7 | from multiprocessing.dummy import Pool 8 | from functools import partial 9 | 10 | from PyExpUtils.utils.generator import group 11 | # config for execute(): run `executable` over `tasks`, `sequential` tasks per process, at most `parallel` processes at once 12 | @dataclass 13 | class ParallelConfig: 14 | executable: str 15 | parallel: int 16 | tasks: Sequence[int] 17 | sequential: int
= 1 18 | 19 | 20 | def execute(c: ParallelConfig): 21 | task_seq = group(c.tasks, c.sequential) 22 | task_strs = map(_stringify_group, task_seq) 23 | execs = map(lambda t: f'{c.executable} {t}', task_strs) 24 | 25 | procs: Dict[int, subprocess.Popen] = {} 26 | 27 | def _handler(sig, frame): 28 | for p in procs.values(): 29 | p.send_signal(signal.SIGUSR1) 30 | 31 | signal.signal(signal.SIGUSR1, _handler) 32 | 33 | with Pool(c.parallel) as p: 34 | p.map(partial(_exec, procs=procs), execs) 35 | 36 | def _exec(cmd: str, procs): 37 | parts = shlex.split(cmd) 38 | process = subprocess.Popen(parts) 39 | procs[process.pid] = process 40 | process.wait() 41 | del procs[process.pid] 42 | 43 | def _stringify_group(g: Sequence[int]) -> str: 44 | return ', '.join(map(str, g)) 45 | -------------------------------------------------------------------------------- /PyExpUtils/runner/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Generator, Iterable, List, TypeVar 2 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription, loadExperiment 3 | from PyExpUtils.results.sqlite import detectMissingIndices 4 | 5 | T = TypeVar('T') 6 | def print_progress(size: int, it: Iterable[T]) -> Generator[T, Any, None]: 7 | m_width = 0 8 | for i, v in enumerate(it): 9 | msg = f'{i + 1}/{size}' 10 | m_width = max(m_width, len(msg)) 11 | print(' ' * m_width, end='\r') 12 | print(msg, end='\r') 13 | if i - 1 == size: 14 | print() 15 | yield v 16 | 17 | def approximate_cost(jobs: int, cores_per_job: int, mem_per_core: float, hours: float): 18 | total_cores = jobs * cores_per_job 19 | mem_in_gb = mem_per_core / 1024 20 | core_equivalents = total_cores * max(mem_in_gb / 4, 1) 21 | 22 | core_hours = core_equivalents * hours 23 | core_years = core_hours / (24 * 365) 24 | 25 | return core_years 26 | 27 | def gather_missing_indices(experiment_paths: Iterable[str], runs: int, loader: Callable[[str], 
ExperimentDescription] = loadExperiment, base: str = './'): 28 | path_to_indices: Dict[str, List[int]] = {} 29 | # for each experiment file: record the sorted indices reported by detectMissingIndices and print a 'missing / total' summary 30 | for path in experiment_paths: 31 | exp = loader(path) 32 | 33 | indices = detectMissingIndices(exp, runs, base=base) 34 | indices = sorted(indices) 35 | path_to_indices[path] = indices 36 | 37 | size = exp.numPermutations() * runs 38 | print(path, f'{len(indices)} / {size}') 39 | 40 | return path_to_indices 41 | -------------------------------------------------------------------------------- /PyExpUtils/utils/NestedDict.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Any, Callable, Dict, Generator, Generic, Optional, Tuple, TypeVar, Union 3 | 4 | K = TypeVar('K', bound=Union[int, str]) 5 | V = TypeVar('V') 6 | R = TypeVar('R') 7 | 8 | Key = Union[K, Tuple[K, ...]] 9 | class NestedDict(Generic[K, V]): 10 | def __init__(self, depth: int, default: Optional[Callable[[], V]] = None): 11 | self._depth = depth 12 | self._default = default 13 | self._data: Dict[K, Any] = {} 14 | 15 | def __getitem__(self, keys: Key[K]) -> Any: 16 | keys = self._normalize(keys) 17 | assert len(keys) <= self._depth 18 | 19 | has_ellipse = ... in keys 20 | # NOTE: a plain (no `...`) lookup creates any missing intermediate levels as a side effect 21 | # easy case: we directly access a single element 22 | if not has_ellipse: 23 | level: Any = self._data 24 | for key in keys[:-1]: 25 | if key not in level: 26 | level[key] = {} 27 | 28 | level = level[key] 29 | 30 | if keys[-1] not in level and self._default is not None: 31 | level[keys[-1]] = self._default() 32 | 33 | return level[keys[-1]] 34 | # `...` lookup: for every key at the ellipsis position, collect the value reached by the remaining keys 35 | out = {} 36 | idx = keys.index(...)
37 | level = self._data 38 | for i in range(idx): 39 | key = keys[i] 40 | level = level[key] 41 | 42 | start = level 43 | for dkey in start: 44 | level = start[dkey] 45 | for i in range(idx + 1, len(keys)): 46 | key = keys[i] 47 | level = level[key] 48 | 49 | out[dkey] = level 50 | 51 | return out 52 | # assignment requires a full-depth key and creates missing intermediate levels 53 | def __setitem__(self, keys: Key[K], val: V): 54 | keys = self._normalize(keys) 55 | assert len(keys) == self._depth 56 | 57 | level = self._data 58 | for key in keys[:-1]: 59 | nlevel = level.get(key, {}) 60 | level[key] = nlevel 61 | 62 | level = nlevel 63 | 64 | level[keys[-1]] = val 65 | # iteration yields full key tuples down to each non-dict leaf 66 | def __iter__(self) -> Generator[Tuple[K, ...], None, None]: 67 | return _walkKeys(self._data) 68 | 69 | def __contains__(self, keys: Key[K]): 70 | keys = self._normalize(keys) 71 | assert len(keys) <= self._depth 72 | 73 | level = self._data 74 | for key in keys: 75 | if key not in level: return False 76 | 77 | level = level[key] 78 | 79 | return True 80 | # apply f to every leaf value; the result has the same depth and no default factory 81 | def map(self, f: Callable[[V], R]) -> NestedDict[K, R]: 82 | out = NestedDict[K, R](self._depth) 83 | 84 | for key in self: 85 | out[key] = f(self[key]) 86 | 87 | return out 88 | 89 | def keys(self): 90 | return self._data.keys() 91 | 92 | def _normalize(self, key: Key[K]) -> Tuple[K, ...]: 93 | if isinstance(key, tuple): 94 | return key 95 | 96 | return (key, ) 97 | # depth is inferred by following the first key at each level; NOTE(review): an empty dict at any level raises StopIteration from next() 98 | @classmethod 99 | def fromDict(cls, d: Dict[K, Any]): 100 | depth = 0 101 | node = d 102 | while isinstance(node, dict): 103 | k = next(iter(node.keys())) 104 | node = node[k] 105 | depth += 1 106 | 107 | out = cls(depth) 108 | out._data = d 109 | 110 | return out 111 | # recursively yield the key path to every non-dict leaf 112 | def _walkKeys(d: Dict[K, Any]) -> Generator[Tuple[K, ...], None, None]: 113 | for k in d: 114 | if not isinstance(d[k], dict): 115 | yield (k, ) 116 | continue 117 | 118 | for tup in _walkKeys(d[k]): 119 | yield (k, ) + tup 120 | -------------------------------------------------------------------------------- /PyExpUtils/utils/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/andnp/PyExpUtils/5d076ff1196368a936b18998afd00c80d4699857/PyExpUtils/utils/__init__.py -------------------------------------------------------------------------------- /PyExpUtils/utils/arrays.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import tee, filterfalse 3 | from typing import Any, Callable, List, Sequence, Union, Iterator, Optional 4 | from PyExpUtils.utils.jit import try2jit 5 | from PyExpUtils.utils.generator import windowAverage 6 | from PyExpUtils.utils.types import AnyNumber, ForAble, T 7 | 8 | def npPadUneven(arr: Sequence[np.ndarray], val: float) -> np.ndarray: 9 | longest = len(arr[0]) 10 | for sub in arr: 11 | a = sub.shape[0] 12 | if a > longest: 13 | longest = a 14 | 15 | out = np.empty((len(arr), longest)) 16 | for i, sub in enumerate(arr): 17 | out[i] = np.pad(sub, (0, longest - sub.shape[0]), constant_values=val) 18 | 19 | return out 20 | 21 | def padUneven(arr: List[List[T]], val: T) -> List[List[T]]: 22 | longest = len(arr[0]) 23 | for sub in arr: 24 | a = len(sub) 25 | if a > longest: 26 | longest = a 27 | 28 | out: List[List[T]] = [] 29 | for sub in arr: 30 | out.append(fillRest(sub, val, longest)) 31 | 32 | return out 33 | 34 | def fillRest(arr: List[T], val: T, length: int) -> List[T]: 35 | if len(arr) >= length: 36 | return arr 37 | 38 | rem = length - len(arr) 39 | pad = [val] * rem 40 | 41 | return arr + pad 42 | 43 | def fillRest_(arr: List[T], val: T, length: int) -> List[T]: 44 | for _ in range(len(arr), length): 45 | arr.append(val) 46 | 47 | return arr 48 | 49 | def first(listOrGen: Union[Sequence[T], Iterator[T]]): 50 | if isinstance(listOrGen, Sequence): 51 | return listOrGen[0] 52 | 53 | return next(listOrGen) 54 | 55 | def last(a: Sequence[T]) -> T: 56 | return a[len(a) - 1] 57 | 58 | def partition(gen: ForAble[T], pred: 
Callable[[T], bool]): 59 | t1, t2 = tee(gen) 60 | 61 | return filter(pred, t2), filterfalse(pred, t1) 62 | # NOTE: the result follows set iteration order, not input order 63 | def deduplicate(arr: Sequence[T]) -> List[T]: 64 | return list(set(arr)) 65 | # a one-element list is unwrapped to its element; any other list is returned unchanged 66 | def unwrap(arr: List[T]) -> Union[T, List[T]]: 67 | if len(arr) == 1: 68 | return arr[0] 69 | 70 | return arr 71 | # stride that keeps roughly `num` (or len(arr) * percent) samples; NOTE(review): num == 0 (e.g. a tiny percent) raises ZeroDivisionError 72 | def sampleFrequency(arr: Sequence[Any], percent: Optional[float] = None, num: Optional[int] = None): 73 | if percent is None and num is None: 74 | raise Exception() 75 | 76 | if percent is not None: 77 | num = int(len(arr) * percent) 78 | 79 | if num is None: 80 | raise Exception('impossible to reach') 81 | 82 | every = int(len(arr) // num) 83 | 84 | return every 85 | # NOTE(review): when only `percent` is given, `num` stays None here, so the trailing-window trim below is skipped 86 | def downsample(arr: Sequence[AnyNumber], percent: Optional[float] = None, num: Optional[int] = None, method: str = 'window'): 87 | every = sampleFrequency(arr, percent, num) 88 | 89 | if every <= 1: 90 | return arr 91 | 92 | if method == 'subsample': 93 | return [arr[i] for i in range(0, len(arr), every)] 94 | 95 | elif method == 'window': 96 | out = list(windowAverage(arr, every)) 97 | 98 | # this case might occur if the array is not evenly divisible by num 99 | # then we should end up with exactly one additional element in out 100 | # which does not have a complete window average.
Just toss it 101 | if num is not None and len(out) > num: 102 | return out[:num] 103 | 104 | return out 105 | 106 | else: 107 | raise Exception() 108 | 109 | @try2jit 110 | def argsmax(arr: np.ndarray): 111 | ties: List[int] = [0 for _ in range(0)] # <-- trick njit into knowing the type of this empty list 112 | top: float = arr[0] 113 | 114 | for i in range(len(arr)): 115 | if arr[i] > top: 116 | ties = [i] 117 | top = arr[i] 118 | 119 | elif arr[i] == top: 120 | ties.append(i) 121 | 122 | if len(ties) == 0: 123 | ties = list(range(len(arr))) 124 | 125 | return ties 126 | 127 | @try2jit 128 | def argsmax2(arr: np.ndarray): 129 | ties: List[List[int]] = [] 130 | for i in range(arr.shape[0]): 131 | ties.append(argsmax(arr[i])) 132 | 133 | return ties 134 | -------------------------------------------------------------------------------- /PyExpUtils/utils/asyncio.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures as cf 2 | from typing import Callable, Generator, Iterable, TypeVar 3 | 4 | T = TypeVar('T') 5 | R = TypeVar('R') 6 | def threadMap(f: Callable[[T], R], arr: Iterable[T]) -> Generator[R, None, None]: 7 | with cf.ThreadPoolExecutor(max_workers=8) as executor: 8 | futures = (executor.submit(f, x) for x in arr) 9 | for future in cf.as_completed(futures): 10 | yield future.result() 11 | -------------------------------------------------------------------------------- /PyExpUtils/utils/cache.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Generic 2 | from PyExpUtils.utils.types import T 3 | 4 | Builder = Callable[[str], T] 5 | 6 | class Cache(Generic[T]): 7 | def __init__(self) -> None: 8 | self.cache: Dict[str, T] = {} 9 | 10 | def get(self, key: str, builder: Builder[T]) -> T: 11 | got = self.cache.get(key) 12 | 13 | if got is not None: 14 | return got 15 | 16 | d = builder(key) 17 | self.cache[key] = d 18 | 19 | return d 20 | 
21 | def set(self, key: str, val: T): 22 | self.cache[key] = val 23 | 24 | def delete(self, key: str): 25 | del self.cache[key] 26 | 27 | def empty(self): 28 | self.cache = {} 29 | -------------------------------------------------------------------------------- /PyExpUtils/utils/cmdline.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Iterable, Tuple 2 | # render (flag, value) pairs as 'flag{joiner}value' separated by single spaces 3 | def flagString(pairs: Iterable[Tuple[str, Optional[Any]]], joiner: str = '='): 4 | pairs = filter(lambda p: p[1] is not None, pairs) # pairs whose value is None are dropped 5 | pairs = sorted(pairs) # NOTE: flags are emitted sorted by name, not in input order 6 | 7 | s = '' 8 | for i, pair in enumerate(pairs): 9 | key, value = pair 10 | if i > 0: s += ' ' 11 | s += f'{key}{joiner}{str(value)}' 12 | 13 | return s 14 | -------------------------------------------------------------------------------- /PyExpUtils/utils/csv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.typing as npt 3 | from PyExpUtils.utils.arrays import unwrap 4 | from typing import Any, Callable, Iterable, List, Optional 5 | from PyExpUtils.utils.dict import flatKeys, get, pick 6 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription 7 | # comma-joined parameter values of permutation `idx`, in sorted flat-key order (same order as buildCsvHeader) 8 | def buildCsvParams(exp: ExperimentDescription, idx: int): 9 | params = pick(exp.getPermutation(idx), unwrap(exp.getKeys())) 10 | keys = flatKeys(params) 11 | keys = sorted(keys) 12 | 13 | values: List[str] = [] 14 | for key in keys: 15 | values.append(str(get(params, key))) 16 | 17 | return ','.join(values) 18 | 19 | def buildCsvHeader(exp: ExperimentDescription): 20 | params = pick(exp.getPermutation(0), unwrap(exp.getKeys())) 21 | keys = flatKeys(params) 22 | keys = sorted(keys) 23 | 24 | return ','.join(keys) 25 | # returns a formatter callable, e.g. buildPrecisionStr(2)(3.14159) -> '3.14' 26 | def buildPrecisionStr(p: int): 27 | return ('{:.'
+ str(p) + 'f}').format 28 | # render a scalar or an iterable as comma-separated text; precision 0 truncates each value via int() 29 | def arrayToCsv(data: npt.ArrayLike, precision: Optional[int] = None): 30 | if precision is None: 31 | toStr: Callable[[Any], str] = str 32 | elif precision == 0: 33 | toStr = lambda x: str(int(x)) 34 | else: 35 | toStr = buildPrecisionStr(precision) 36 | 37 | if np.ndim(data) == 0: 38 | return toStr(data) 39 | 40 | assert isinstance(data, Iterable) 41 | return ','.join(map(toStr, data)) 42 | -------------------------------------------------------------------------------- /PyExpUtils/utils/dict.py: -------------------------------------------------------------------------------- 1 | from PyExpUtils.utils.types import T 2 | import re 3 | from typing import Any, Dict, List, Sequence, overload, Union 4 | 5 | # making a type alias here just for readability 6 | # using NewType('DictPath', str) is a huge pita for all consumers 7 | DictPath = str 8 | # shallow merge of both dicts; d2's value wins when a key is in both 9 | def merge(d1: Dict[Any, T], d2: Dict[Any, T]) -> Dict[Any, T]: 10 | ret = d2.copy() 11 | for key in d1: 12 | ret[key] = d2.get(key, d1[key]) 13 | 14 | return ret 15 | # dotted paths of every leaf; list items appear as '[i]' segments (e.g. 'a.[0].b') 16 | def flatKeys(d: Dict[Any, Any]) -> List[DictPath]: 17 | keys = d.keys() 18 | out: List[DictPath] = [] 19 | for key in keys: 20 | sub_keys: List[str] = [] 21 | 22 | if isinstance(d[key], dict): 23 | sub_keys = flatKeys(d[key]) 24 | out += [f'{key}.{subkey}' for subkey in sub_keys] 25 | 26 | elif isinstance(d[key], list): 27 | sub_keys = [] 28 | for i in range(len(d[key])): 29 | sub_keys.append(f'[{i}]') 30 | 31 | if isinstance(d[key][0], dict): # NOTE(review): raises IndexError for an empty list value; only the first element decides the list-of-dicts case 32 | sub_keys = [] 33 | for i, sub in enumerate(d[key]): 34 | sub_keys += [ f'[{i}].{subkey}' for subkey in flatKeys(sub) ] 35 | 36 | out += [f'{key}.{subkey}' for subkey in sub_keys] 37 | 38 | else: 39 | out.append(key) 40 | 41 | return out 42 | # map every flattened path to its value (looked up with get) 43 | def flatDict(d: Dict[Any, Any]) -> Dict[DictPath, Any]: 44 | out: Dict[DictPath, Any] = {} 45 | 46 | for key in flatKeys(d): 47 | v = get(d, key) 48 | out[key] = v 49 | 50 | return out 51 | 52 | # 'key-value' pairs for every flattened key, sorted and joined with '_' 53 | def hyphenatedStringify(d:
Dict[Any, Any]): 54 | sorted_keys = sorted(flatKeys(d)) 55 | parts = [f'{key}-{get(d, key)}' for key in sorted_keys] 56 | 57 | return '_'.join(parts) 58 | # select `keys` from d: a single key (or one-element list) gives the bare value, a longer list gives a sub-dict 59 | @overload 60 | def pick(d: Dict[Any, T], keys: DictPath) -> T: 61 | ... 62 | @overload 63 | def pick(d: Dict[Any, T], keys: List[DictPath]) -> Dict[Any, T]: 64 | ... 65 | @overload 66 | def pick(d: Dict[Any, T], keys: Union[DictPath, List[DictPath]]) -> Union[T, Dict[Any, T]]: 67 | ... 68 | def pick(d: Dict[Any, T], keys: Union[DictPath, List[DictPath]]) -> Union[T, Dict[Any, T]]: 69 | if not isinstance(keys, list): 70 | return d[keys] 71 | 72 | if len(keys) == 1: # NOTE(review): returns the bare value, which disagrees with the List overload's Dict return type 73 | return d[keys[0]] 74 | 75 | r: Dict[Any, T] = {} 76 | for key in keys: 77 | r[key] = d[key] 78 | 79 | return r 80 | # dotted-path lookup; list elements are addressed by '[i]' segments (e.g. 'a.[0].b'); missing or None values give `default` 81 | def get(d: Dict[str, Union[Dict[Any, Any], List[Any], Any]], key: DictPath, default: Any = None) -> Any: 82 | if key == '': 83 | return d 84 | 85 | parts = key.split('.') 86 | 87 | el = d.get(parts[0]) 88 | if el is None: 89 | return default 90 | 91 | if isinstance(el, list) and len(parts) > 1: 92 | idx = re.findall(r'\[(\d+)\]', parts[1])[0] 93 | idx = int(idx) 94 | if len(el) <= idx: 95 | return default 96 | 97 | return get(el[idx], '.'.join(parts[2:]), default) 98 | 99 | if isinstance(el, dict): 100 | return get(el, '.'.join(parts[1:]), default) 101 | 102 | return el 103 | # top-level key-for-key equality, skipping keys listed in `ignore` 104 | def equal(d1: Dict[Any, Any], d2: Dict[Any, Any], ignore: Sequence[Any] = []): 105 | for k in list(d1.keys()) + list(d2.keys()): 106 | if k in ignore: 107 | continue 108 | 109 | if k not in d2 or k not in d1: 110 | return False 111 | 112 | if d1[k] != d2[k]: 113 | return False 114 | 115 | return True 116 | # True when every non-ignored key of d1 is in d2 with an equal value; nested dicts recurse (without `ignore`) 117 | def subset(d1: Dict[Any, Any], d2: Dict[Any, Any], ignore: Sequence[Any] = []): 118 | for k in d1: 119 | if k in ignore: 120 | continue 121 | 122 | if k not in d2: 123 | return False 124 | 125 | if isinstance(d1[k], dict) and isinstance(d2[k], dict): 126 | is_subsubset = subset(d1[k], d2[k]) 127 | if not is_subsubset: 128 | return False
129 | 130 | elif d1[k] != d2[k]: 131 | return False 132 | 133 | return True 134 | 135 | def partialEqual(d1: Dict[Any, Any], d2: Dict[Any, Any]): 136 | for k in d1: 137 | if k not in d2: 138 | continue 139 | 140 | if isinstance(d1[k], dict) and isinstance(d2[k], dict): 141 | is_subpartialEqual = partialEqual(d1[k], d2[k]) 142 | if not is_subpartialEqual: 143 | return False 144 | 145 | if d1[k] != d2[k]: 146 | return False 147 | 148 | return True 149 | -------------------------------------------------------------------------------- /PyExpUtils/utils/fp.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Any, Callable, Dict, TypeVar, cast 3 | 4 | F = TypeVar('F', bound=Callable[..., Any]) 5 | T = TypeVar('T') 6 | R = TypeVar('R') 7 | 8 | def memoize(f: F, cache: Dict[str, Any] = {}) -> F: 9 | def cacheKey(*args: Any, **kwargs: Any): 10 | s = '' 11 | for arg in args: 12 | s = s + '__' + str(arg) 13 | for arg in kwargs: 14 | s = s + '__' + str(arg) + '-' + str(kwargs[arg]) 15 | return s 16 | 17 | @functools.wraps(f) 18 | def wrapped(*args: Any, **kwargs: Any): 19 | nonlocal cache 20 | nonlocal cacheKey 21 | key = cacheKey(*args, **kwargs) 22 | if key in cache: 23 | return cache[key] 24 | ret = f(*args, **kwargs) 25 | cache[key] = ret 26 | return ret 27 | 28 | return cast(F, wrapped) 29 | 30 | def once(f: Callable[[], R]) -> Callable[[], R]: 31 | called = False 32 | ret = None 33 | 34 | def wrapped() -> R: 35 | nonlocal called 36 | nonlocal ret 37 | if not called: 38 | ret = f() 39 | called = True 40 | 41 | # have to cast this because control flow analysis thinks this is type R | None 42 | # since R _could_ contain None, there is no clean way to signal to the type-checker 43 | # that what we are doing is okay. So instead we override. 
44 | return cast(R, ret) 45 | 46 | return wrapped 47 | -------------------------------------------------------------------------------- /PyExpUtils/utils/generator.py: -------------------------------------------------------------------------------- 1 | from PyExpUtils.utils.types import AnyNumber, ForAble, T 2 | import numpy as np 3 | from typing import Generator, List, cast 4 | 5 | # takes a generator and a number of items to group together 6 | # returns a generator that yields `num` items in groups 7 | # example: 8 | # grouped = group(range(10), 3) 9 | # grouped == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 10 | def group(gen: ForAble[T], num: int) -> Generator[List[T], None, None]: 11 | coll: List[T] = [] 12 | for x in gen: 13 | coll.append(x) 14 | if len(coll) == num: 15 | yield coll 16 | coll = [] 17 | 18 | # if there was anything left over (generator was not perfectly divisible) 19 | # then go ahead an yield what was left and make sure to release it from memory 20 | if len(coll) > 0: 21 | yield coll 22 | coll = [] 23 | 24 | def windowAverage(arr: ForAble[AnyNumber], window: int) -> Generator[float, None, None]: 25 | for g in group(arr, window): 26 | yield cast(float, np.mean(g)) 27 | -------------------------------------------------------------------------------- /PyExpUtils/utils/iterable.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, TypeVar, cast 2 | 3 | T = TypeVar('T') 4 | def filter_none(it: Iterable[T | None]) -> Iterable[T]: 5 | out = filter(lambda x: x is not None, it) 6 | return cast(Iterable[T], out) 7 | -------------------------------------------------------------------------------- /PyExpUtils/utils/jit.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Callable, TypeVar 3 | 4 | _has_warned = False 5 | T = TypeVar('T', bound=Callable[..., Any]) 6 | 7 | def try2jit(f: T) -> T: 8 | try: 9 | from numba 
import njit 10 | return njit(f, cache=True, nogil=True, fastmath=True) 11 | except Exception: 12 | global _has_warned 13 | if not _has_warned: 14 | _has_warned = True 15 | logging.getLogger('PyExpUtils').warn('Could not jit compile --- expect slow performance') 16 | 17 | return f 18 | -------------------------------------------------------------------------------- /PyExpUtils/utils/pandas.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from functools import reduce 3 | from typing import Any, Dict, Iterable, Sequence 4 | 5 | def inner(dfs: Iterable[pd.DataFrame], on: str | Sequence[str]): 6 | return reduce(lambda l, r: pd.merge(l, r, how='inner', on=on), dfs) 7 | 8 | def outer(dfs: Iterable[pd.DataFrame], on: str | Sequence[str]): 9 | return reduce(lambda l, r: pd.merge(l, r, how='outer', on=on), dfs) 10 | 11 | 12 | def query(df: pd.DataFrame, d: Dict[str, Any]): 13 | if len(d) == 0: 14 | return df 15 | 16 | keys = d.keys() 17 | for k in keys: 18 | assert k in df, f"Can't query df. 
Unknown key {k} in {df.columns}" 19 | 20 | q = ' & '.join(f'`{k}`=={_maybe_quote(v)}' for k, v in d.items()) 21 | return df.query(q) 22 | 23 | def _maybe_quote(v: Any): 24 | if isinstance(v, str): 25 | return _quote(v) 26 | return v 27 | 28 | def _quote(s: str): 29 | return f'"{s}"' 30 | -------------------------------------------------------------------------------- /PyExpUtils/utils/path.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Iterable 3 | from PyExpUtils.utils.arrays import last 4 | 5 | def split(path: str): 6 | parts = path.split('/') 7 | 8 | # if path starts with leading slash, then make sure that doesn't go away from split('/') 9 | if parts[0] == '': 10 | parts[0] = '/' 11 | 12 | return parts 13 | 14 | def rest(path: str): 15 | parts = split(path) 16 | return join(*parts[1:]) 17 | 18 | def up(path: str): 19 | return join(*split(path)[:-1]) 20 | 21 | def fileName(path: str): 22 | parts = split(path) 23 | f = last(parts) 24 | return f 25 | 26 | def removeFirstAndLastSlash(s: str): 27 | if s.startswith('/'): 28 | s = s[1:] 29 | 30 | if s.endswith('/'): 31 | s = s[:-1] 32 | 33 | return s 34 | 35 | def remoteDuplicatedSlashes(s: str): 36 | return re.sub(r'/+', '/', s) 37 | 38 | def join(*argv: str): 39 | # remote empty strings 40 | gen: Iterable[str] = filter(lambda s: s != '', argv) 41 | # remote any duplicated slashes, e.g. 
this//is/a/path 42 | gen = map(remoteDuplicatedSlashes, gen) 43 | # get rid of leading/trailing slashes 44 | gen = map(removeFirstAndLastSlash, gen) 45 | # make sure there are no empty strings after cleaning up slashes 46 | gen = filter(lambda s: s != '', gen) 47 | 48 | path = '/'.join(gen) 49 | 50 | if argv[0].startswith('/'): 51 | path = '/' + path 52 | 53 | return path 54 | -------------------------------------------------------------------------------- /PyExpUtils/utils/permute.py: -------------------------------------------------------------------------------- 1 | import re 2 | from PyExpUtils.utils.dict import DictPath, flatKeys, get 3 | from PyExpUtils.utils.arrays import deduplicate, last 4 | from typing import Dict, Any, List, Tuple 5 | 6 | Record = Dict[str, Any] 7 | PathDict = Dict[DictPath, Any] 8 | KVPair = Tuple[DictPath, List[Any]] 9 | 10 | # ----------------------------------------------------------------------------- 11 | # clean public api 12 | 13 | def getParameterPermutation(sweeps: Record, index: int): 14 | pairs = _flattenToKeyValues(sweeps) 15 | return getPermutationFromPairs(pairs, index) 16 | 17 | 18 | def getNumberOfPermutations(sweeps: Record): 19 | pairs = _flattenToKeyValues(sweeps) 20 | return getCountFromPairs(pairs) 21 | 22 | # ----------------------------------------------------------------------------- 23 | 24 | def getPermutationFromPairs(pairs: List[KVPair], index: int): 25 | perm: PathDict = {} 26 | accum = 1 27 | 28 | for key, values in pairs: 29 | num = len(values) 30 | 31 | # if we have an empty array for a parameter, add that parameter back as an empty array 32 | if num == 0: 33 | perm[key] = [] 34 | continue 35 | 36 | perm[key] = values[(index // accum) % num] 37 | accum *= num 38 | 39 | return reconstructParameters(perm) 40 | 41 | def getCountFromPairs(pairs: List[KVPair]): 42 | accum = 1 43 | for pair in pairs: 44 | _, values = pair 45 | num = len(values) if len(values) > 0 else 1 46 | accum *= num 47 | 48 | return accum 
49 | 50 | def dropLastArray(key: str): # strip a trailing '[i]' segment, e.g. 'a.b.[2]' -> 'a.b' 51 | parts = key.split('.') 52 | 53 | # if the last part of the dict path is an element in an array 54 | # then just drop that part 55 | if re.match(r'\[\d+\]', last(parts)): 56 | return '.'.join(parts[:-1]) 57 | 58 | return '.'.join(parts) 59 | 60 | def reconstructParameters(perm: PathDict): # rebuild a nested dict from {dotted path: value} via set_at_path 61 | res: Record = {} 62 | for key in perm: 63 | set_at_path(res, key, perm[key]) 64 | 65 | return res 66 | 67 | def _flattenToKeyValues(sweeps: Record): # sorted, de-duplicated (path, candidate values) pairs; scalar values become one-element lists 68 | keys = flatKeys(sweeps) 69 | keys = list(map(dropLastArray, keys)) 70 | keys = deduplicate(keys) 71 | keys = sorted(keys) 72 | 73 | out: List[Tuple[DictPath, List[Any]]] = [] 74 | for key in keys: 75 | values = get(sweeps, key) 76 | 77 | # allow parameters to be set like "alpha": 0.1 as a shortcut 78 | if not isinstance(values, list): 79 | values = [values] 80 | 81 | out.append((key, values)) 82 | 83 | return out 84 | 85 | # TODO: move this to the utils.dict folder and try to compress/simplify it 86 | # then add unit tests 87 | def set_at_path(d: Record, path: DictPath, val: Any): # write val into d (mutated and returned) at a dotted path, creating intermediate dicts/lists as needed 88 | def inner(d: Record, path: DictPath, val: Any, last: str) -> Record: # `last` is the key in d holding the list that a '[i]' segment indexes into (shadows the imported last()) 89 | if len(path) == 0: return d 90 | split = path.split('.', maxsplit=1) 91 | 92 | part, rest = split if len(split) > 1 else [split[0], ''] 93 | nxt = rest.split('.')[0] 94 | 95 | # lists 96 | if part.startswith('['): 97 | num = int(re.sub(r'[\[,\]]', '', part)) 98 | 99 | if len(d[last]) > num: 100 | piece = inner(d[last][num], rest, val, '') if len(rest) > 0 else val 101 | d[last][num] = piece 102 | else: # index at or past the end: the new piece is appended, not placed at position num 103 | piece = inner({}, rest, val, '') if len(rest) > 0 else val 104 | d[last].append(piece) 105 | return d 106 | 107 | # objects 108 | elif len(rest) > 0: 109 | if nxt.startswith('['): 110 | piece = d.setdefault(part, []) # only the setdefault side effect matters; the next call reaches the list via d[part] 111 | return inner(d, rest, val, part) 112 | else: 113 | piece = d.setdefault(part, {}) 114 | return inner(piece, rest, val, part) 115 | 116 | # everything else 117 | else: 118 | d.setdefault(part, val) # NOTE(review): setdefault never overwrites a value already present at this path 119 | 
return d 120 | 121 | inner(d, path, val, '') 122 | return d 123 | -------------------------------------------------------------------------------- /PyExpUtils/utils/random.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Sequence, TypeVar 3 | from PyExpUtils.utils.arrays import argsmax 4 | from PyExpUtils.utils.jit import try2jit 5 | 6 | T = TypeVar('T') 7 | 8 | # way faster than np.random.choice 9 | # arr is an array of probabilities, should sum to 1 10 | @try2jit 11 | def sample(arr: np.ndarray, rng: np.random.Generator): # index of the first cumulative probability exceeding a single uniform draw 12 | r = rng.random() 13 | s = 0 14 | for i, p in enumerate(arr): 15 | s += p 16 | if s > r or s == 1: 17 | return i 18 | 19 | # worst case if we run into floating point error, just return the last element 20 | # we should never get here 21 | return len(arr) - 1 22 | 23 | # also much faster than np.random.choice 24 | # choose an element from a list with uniform random probability 25 | @try2jit 26 | def choice(arr: Sequence[T], rng: np.random.Generator) -> T: # NOTE(review): draws a full permutation (O(len(arr))) just to take its first index 27 | idxs = rng.permutation(len(arr)) 28 | return arr[idxs[0]] 29 | 30 | # argmax that breaks ties randomly 31 | @try2jit 32 | def argmax(vals: np.ndarray, rng: np.random.Generator): # uniform choice among the indices returned by argsmax (presumably every position attaining the max) 33 | ties = argsmax(vals) 34 | return choice(ties, rng) 35 | -------------------------------------------------------------------------------- /PyExpUtils/utils/str.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, Any 3 | 4 | def interpolate(s: str, d: Dict[str, Any]): # replace each '{key}' in s with str(d[key]); an unknown key raises KeyError 5 | keys = re.findall('{.*?}', s) 6 | 7 | final = s 8 | for key in keys: 9 | unwrapped = key[1:-1] 10 | value = str(d[unwrapped]) 11 | final = final.replace(key, value) 12 | 13 | return final 14 | -------------------------------------------------------------------------------- /PyExpUtils/utils/types.py: -------------------------------------------------------------------------------- 1 
import numpy as np 2 | from typing import Any, Callable, Iterable, Iterator, Optional, Sequence, TypeVar, Union 3 | 4 | # the most generic of generics 5 | T = TypeVar('T') 6 | 7 | ForAble = Union[Sequence[T], Iterable[T], Iterator[T]] # anything that can be looped over with for 8 | 9 | AnyNumber = Union[float, int] 10 | NpList = Union[np.ndarray, Sequence[AnyNumber]] # a numpy array or a plain sequence of numbers 11 | 12 | def optionalCast(typ: Callable[[Any], T], thing: Optional[Any]) -> Optional[T]: # apply typ to thing, passing None through unchanged 13 | if thing is None: 14 | return thing 15 | 16 | return typ(thing) 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyExpUtils 2 | 3 | [![Test](https://github.com/andnp/PyExpUtils/actions/workflows/test.yml/badge.svg?branch=master)](https://github.com/andnp/PyExpUtils/actions/workflows/test.yml) 4 | 5 | Short for python experiment utilities. 6 | This is a collection of scripts and machine learning experiment management tools that I use whenever I have to use python. 7 | 8 | For a more complete discussion on my organization patterns for research codebases, [look in the docs](docs/OrganizationPatterns.md). 9 | 10 | ## This lib 11 | Maintaining a rigorous experiment structure can be labor intensive. 12 | As such, I've automated out many of the common pieces that I use in my research. 13 | 14 | ### Parameter Permutations 15 | Experiments are encoded within JSON files. 16 | The JSON files should contain all of the information necessary to reproduce an experiment, including all parameters swept. 17 | Each of the parameter sweep specifications leads to a set of parameter permutations. 18 | Imagine the case where you are sweeping over 2 meta-parameters: 19 | ```json 20 | { 21 | "metaParameters": { 22 | "alpha": [0.01, 0.02, 0.04], 23 | "epsilon": [0.1, 0.2, 0.3] 24 | } 25 | } 26 | ``` 27 | Here there are 9 total possible permutations: `{alpha: 0.01, epsilon: 0.1}`, `{alpha: 0.01, epsilon: 0.2}`, ... 
28 | 29 | These are indexed by a single numeric value. 30 | To run each permutation once, simply execute indices `i \in [0..8]`. 31 | To run each permutation twice, multiply by 2: `i \in [0..17]`. 32 | In general, for `n` runs and `p` permutations: `i \in [0..(n*p - 1)]`. 33 | 34 | 35 | ## models 36 | A collection of JSON serialization classes with associated utility methods. 37 | ### PyExpUtils/models/Config.py 38 | **Config**: 39 | 40 | Experiment utility configuration file. 41 | Specifies global configuration settings: 42 | - *save_path*: directory template where experimental results will be stored 43 | - *log_path*: directory where log files will be saved (e.g. stacktraces during experiments) 44 | - *experiment_directory*: root directory where all of the experiment description files are located 45 | 46 | The config file should be at the root level of the repository and should be named `config.json`. 47 | ``` 48 | .git 49 | .gitignore 50 | tests/ 51 | scripts/ 52 | src/ 53 | config.json 54 | ``` 55 | 56 | An example configuration file: 57 | ```json 58 | { 59 | "save_path": "results/{name}/{environment}/{agent}/{params}", 60 | "log_path": "~/scratch/.logs", 61 | "experiment_directory": "experiments" 62 | } 63 | ``` 64 | 65 | 66 | **getConfig**: 67 | 68 | Memoized global configuration loader. 69 | Will read `config.json` (only once) and return a Config object. 70 | ```python 71 | config = getConfig() 72 | print(config.save_path) # -> 'results/{name}/{environment}/{agent}/{params}' 73 | ``` 74 | 75 | 76 | ### PyExpUtils/models/ExperimentDescription.py 77 | **ExperimentDescription**: 78 | 79 | Main workhorse class of the library. 80 | Takes a dictionary describing all configurable options of an experiment and serializes that dictionary. 81 | Provides a set of utility methods for running parameter sweeps in parallel and for storing data during experiments.
82 | ```python 83 | exp_dict = { 84 | 'algorithm': 'SARSA', 85 | 'environment': 'MountainCar', 86 | 'metaParameters': { 87 | 'alpha': [1.0, 0.5, 0.25, 0.125], 88 | 'lambda': [1.0, 0.99, 0.98, 0.96] 89 | } 90 | } 91 | exp = ExperimentDescription(exp_dict) 92 | ``` 93 | 94 | 95 | **permutable**: 96 | 97 | Gives the parameters that can be swept over, along with their candidate values. 98 | Using the above example dictionary: 99 | ```python 100 | params = exp.permutable() 101 | print(params) # -> { 'alpha': [1.0, 0.5, 0.25, 0.125], 'lambda': [1.0, 0.99, 0.98, 0.96] } 102 | ``` 103 | 104 | 105 | **getPermutation**: 106 | 107 | Gives the `i`'th permutation of sweepable parameters. 108 | Handles wrapping indices, so multiple runs of the same parameter setting can be performed by making `i` larger than the number of permutations. 109 | In the above dictionary, there are 16 total parameter permutations; the first parameter (in sorted key order) varies fastest. 110 | ```python 111 | params = exp.getPermutation(0) 112 | print(params) # -> { 'alpha': 1.0, 'lambda': 1.0 } 113 | params = exp.getPermutation(1) 114 | print(params) # -> { 'alpha': 0.5, 'lambda': 1.0 } 115 | params = exp.getPermutation(15) 116 | print(params) # -> { 'alpha': 0.125, 'lambda': 0.96 } 117 | params = exp.getPermutation(16) 118 | print(params) # -> { 'alpha': 1.0, 'lambda': 1.0 } 119 | ``` 120 | 121 | 122 | **numPermutations**: 123 | 124 | Gives the total number of parameter permutations. 125 | ```python 126 | num_params = exp.numPermutations() 127 | print(num_params) # -> 16 128 | ``` 129 | 130 | 131 | **getRun**: 132 | 133 | Get the run number based on wrapping the index. 134 | This is a count of how many times we've wrapped back around to the same parameter setting.
135 | ```python 136 | num = exp.getRun(0) 137 | print(num) # -> 0 138 | num = exp.getRun(12) 139 | print(num) # -> 0 140 | num = exp.getRun(16) 141 | print(num) # -> 1 142 | num = exp.getRun(32) 143 | print(num) # -> 2 144 | ``` 145 | 146 | 147 | **getExperimentName**: 148 | 149 | Returns the name of the experiment if stated in the dictionary: `{ 'name': 'MountainCar-v0', ... }`. 150 | If not stated, it will try to determine the name of the experiment based on the path to the JSON it is stored in (assuming experiments are stored in JSON files). 151 | ```python 152 | path = 'experiments/MountainCar-v0/sarsa.json' 153 | with open(path, 'r') as f: 154 | d = json.load(f) 155 | exp = ExperimentDescription(d, path) 156 | name = exp.getExperimentName() 157 | print(name) # -> d['name'] if available, or 'MountainCar-v0' if not. 158 | ``` 159 | 160 | 161 | **interpolateSavePath**: 162 | 163 | Takes a parameter index and generates a path for saving results. 164 | The path depends on the configuration settings of the library (i.e. `config.json`). 165 | Note this uses an opinionated formatting for save paths and parameter string representations. 166 | The configuration file can specify ordering and high-level control over paths, but for more fine-grained control over how these are saved, inherit from this class and override this method. 167 | `config.json`: 168 | ```json 169 | { 170 | "save_path": "results/{name}/{environment}/{agent}/{params}" 171 | } 172 | ``` 173 | ```python 174 | path = exp.interpolateSavePath(0) 175 | print(path) # -> 'results/MountainCar-v0/SARSA/alpha-1.0_lambda-1.0' 176 | ``` 177 | 178 | 179 | **buildSaveContext**: 180 | 181 | Builds a `FileSystemContext` utility object that contains the save path for experimental results.
182 | ```python 183 | file_context = exp.buildSaveContext(0) 184 | # make sure folder structure is built 185 | file_context.ensureExists() 186 | # get the path where results should be saved 187 | path = file_context.resolve('returns.npy') 188 | print(path) # -> '/results/MountainCar-v0/SARSA/alpha-1.0_lambda-1.0/returns.npy' 189 | # save results 190 | np.save(path, returns) 191 | ``` 192 | 193 | 194 | **loadExperiment**: 195 | 196 | Loads an ExperimentDescription from a JSON file (preferred way to make ExperimentDescriptions). 197 | 198 | ```python 199 | exp = loadExperiment('experiments/MountainCar-v0/sarsa.json') 200 | ``` 201 | 202 | 203 | ## collection 204 | ### PyExpUtils/collection/Collector.py 205 | **Collector**: 206 | 207 | A frame-based data collection utility. 208 | The collector stores some context---which index is currently being run, what is the current timestep, etc.--- 209 | and associates collected data with this context. 210 | 211 | Example usage: 212 | ```python 213 | collector = Collector( 214 | config={ 215 | # a dictionary mapping keys -> data preprocessors 216 | # for instance performing fixed-window averaging 217 | 'return': Window(100), 218 | # or subsampling 1 of every 100 values 219 | 'reward': Subsample(100), 220 | # or moving averages 221 | 'error': MovingAverage(0.99), 222 | # or ignored entirely 223 | 'special': Ignore(), 224 | }, 225 | # by default, if a key is not mentioned above it is stored as-is 226 | # however this can be changed by passing a default preprocessor 227 | default=Identity() 228 | ) 229 | 230 | # tell the collector what idx of the experiment we are currently processing 231 | collector.setIdx(0) 232 | 233 | for step in range(exp.max_steps): 234 | # tell the collector to increment the frame 235 | collector.next_frame() 236 | 237 | # these values will be associated with the current idx and frame 238 | collector.collect('reward', r) 239 | collector.collect('error', delta) 240 | 241 | # not all values need to be stored 
at each frame 242 | if step % 100 == 0: 243 | collector.collect('special', 'test value') 244 | ``` 245 | 246 | 247 | ## runner 248 | ### PyExpUtils/runner/Slurm.py 249 | **hours**: 250 | 251 | Takes an integer number of hours and returns a well-formatted time string. 252 | ```python 253 | time = hours(3) 254 | print(time) # -> '2:59:59' 255 | ``` 256 | 257 | 258 | **gb**: 259 | 260 | Takes an integer number of gigabytes and returns a well-formatted memory string. 261 | ```python 262 | memory = gb(4) 263 | print(memory) # -> '4G' 264 | ``` 265 | 266 | 267 | ## results 268 | ### PyExpUtils/results/indices.py 269 | **listIndices**: 270 | 271 | Returns an iterator over indices for each parameter permutation. 272 | Can specify a number of runs and will cycle over the permutations `runs` number of times. 273 | 274 | ```python 275 | for i in listIndices(exp, runs=2): 276 | print(i, exp.getRun(i)) # -> "0 0", "1 0", "2 0", ... "0 1", "1 1", ... 277 | ``` 278 | 279 | 280 | ## utils
4 | The structure specification has been a work in progress over the past few years, but has finally converged to a reasonably stable point. 5 | 6 | My experiment codebase always looks like: 7 | ``` 8 | src/ 9 | tests/ 10 | scripts/ 11 | config.json 12 | 13 | experiments/ 14 | results/ 15 | ``` 16 | 17 | The `src/`, `tests/`, and `scripts/` folders hold all of the common source code. 18 | This includes things like dataset loaders, RL environments, and ML model code. 19 | `scripts/` will hold things like job scheduling and exploratory data analysis scripts (though those sometimes go in another `analysis/` folder depending on the project). 20 | 21 | ### Experiments 22 | The `experiments/` folder contains all of the experiment source code and experiment description datafiles. 23 | It additionally contains brief experiment write-ups, usually in markdown. 24 | 25 | #### Experiment Folders 26 | The structure of this directory is generally flat. 27 | It contains a list of folders with experiment short names. 28 | For example: 29 | ``` 30 | experiments/ 31 | experiments/overfit 32 | experiments/covariateShift 33 | ... 34 | ``` 35 | 36 | I do **not** have version numbers in these folders. 37 | Inevitably, the experiments will change with time. 38 | These changes are checked into the version control software (git), and recorded in a log (more on that momentarily). 39 | 40 | Inside each of the individual experiment folders, I have the description datafiles and the experiment entry script. 41 | 42 | #### Entry Script 43 | Every experiment contains its own entry script. 44 | This is the one part of the codebase where copy/paste is accepted and encouraged. 45 | These scripts should be minimal, and should only "plug-in" the pieces developed in `src/`. 
46 | An example: 47 | ```python 48 | from src.dataloader import dataloader 49 | from src.ANN import ANN 50 | 51 | data = dataloader.MNIST() 52 | ann = ANN() 53 | 54 | ann.train(data) 55 | 56 | accuracy_train = ann.evaluate(data.train) 57 | accuracy_test = ann.evaluate(data.test) 58 | 59 | saveFile('test.csv', accuracy_test) 60 | saveFile('train.csv', accuracy_train) 61 | ``` 62 | This is the file that will be called at the command-line: `python experiments/overfit/overfit.py`. 63 | 64 | #### Experiment Descriptions 65 | All experiment description files should be stored in JSON format. 66 | These files describe the parameter sweeps that are inherent with most, if not all, machine learning experiments. 67 | They also describe manipulations of experiment level parameters (as opposed to only algorithm meta-parameters). 68 | 69 | An example: 70 | ```JSON 71 | { 72 | "algorithm": "ANN", 73 | "dataset": "MNIST", 74 | "metaParameters": { 75 | "layers": [ 76 | { "type": "dense", "units": 256, "transfer": ["linear", "relu", "sigmoid"] } 77 | ], 78 | "useDictionary": false, 79 | "dictionaryWeight": [0.01] 80 | }, 81 | "limitSamples": 1000, 82 | "optimizer": { 83 | "type": "rmsprop", 84 | "stepsize": 0.001 85 | } 86 | } 87 | ``` 88 | 89 | In the above example, `metaParameters` specifies algorithm meta-parameters. 90 | The values inside an array at the lowest level of the object will be swept over. 91 | In this case, `layers` and `transfer` are both array values. 92 | `layers` is not at the bottom of the object, but `transfer` is; only `transfer` will be swept over. 93 | This means that this experiment description describes 3 different models, one with a `linear` transfer, one with a `relu` transfer, and one with a `sigmoid` transfer. 94 | 95 | Bottom level parameters can be singleton values or length 1 arrays if they are not to be swept. 96 | That is 97 | ```JSON 98 | { "units": 256 } 99 | ``` 100 | and 101 | ```JSON 102 | { "units": [256] } 103 | ``` 104 | are equivalent. 
105 | The first is just syntactic sugar for the second, as it is a common case. 106 | 107 | #### Analysis 108 | Each of the experiment folders will contain a set of analysis scripts. 109 | These analysis files will read in the generated experiment data and produce the final plots and tables. 110 | 111 | They should take **no** command-line arguments. 112 | If I have a set of results, running one of these scripts should produce a near-publish-ready plot or table. 113 | The no-arguments restriction reduces cognitive load when running these scripts months later to reproduce results. 114 | 115 | These are another case where copy/paste is okay, as many of these scripts will be quite similar. 116 | As with the entry files, these should be minimal and high-level. 117 | Low-level functions should be (intelligently) abstracted out in the `src/` directory. 118 | 119 | #### Experiment Write-ups 120 | In every experiment directory, there should be a living write-up document. 121 | This document keeps a _complete_ history of all experiment trials, both failures and successes. 122 | In my experience, each of these write-ups will contain a dozen failures and a single success (because once you have a successful experiment, you generally stop running it!) 123 | 124 | The structure of these files is still in flux for me; however, there are consistent key details that should be included. 125 | 1) The trial number - how many times have you attempted this experiment? 126 | 2) A textual description of the experiment with enough detail to reproduce it without the code. 127 | 3) The hypothesis being tested. 128 | 4) The list of open questions, and open issues. 129 | 5) The list of follow-up questions after running a trial - often when running an experiment, auxiliary questions become clear. This is a great place to look when trying to come up with future research directions!
130 | 6) A textual description of the outcome for each trial - usually an explanation of what went wrong. 131 | 7) The outputs of the `analysis` scripts for each trial. 132 | 8) The commit hash for the repo (and any important sub-repos) when each trial was run. 133 | 134 | An example: 135 | ```markdown 136 | # Overfit 137 | A long description of this experiment. 138 | Why am I measuring overfitting, and how? 139 | What are my goals? 140 | How does this tie into the larger project? 141 | 142 | ## Hypothesis 143 | What do I expect will happen? 144 | What does success look like, and how do I measure that? 145 | 146 | ## Open Questions 147 | 1) How does my choice of optimizer affect the results? 148 | 2) Am I running enough epochs, or will the ANN stop overfitting after a certain point? 149 | 150 | ## Follow-up Questions 151 | 1) Do these effects hold across other datasets? 152 | 2) Would a model with skip connections overfit as badly? 153 | 154 | ## Trials 155 | ### Trial 00 156 | This trial tests with 300 epochs. 157 | It failed to run long enough, as the test accuracy continued to decrease sharply at epoch 300. 158 | I should run again with many more epochs, perhaps 600. 159 | #### Results 160 | [Learning curves](./trials/trial-00_test-train.svg) 161 | ``` 162 | 163 | #### Experiment Log 164 | The final piece to the experiment folder is a table-of-contents or experiment log. 165 | This is a top level markdown file in the `experiments/` folder that specifies what each of the subfolders contains. 166 | After a month or two, there will be old experiments that you forget about. 167 | This file should be a reminder of what has happened in the past. 
168 | 169 | It usually looks like: 170 | ```markdown 171 | # Experiment Log 172 | 173 | ## Experiments 174 | ### Overfit 3/1/19 175 | **status**: on-going 176 | **path**: `overfit/` 177 | 178 | ### Covariate Shift 9/15/18 179 | **status**: complete 180 | **path**: `covariateShift/` 181 | ``` 182 | Note that entries should be in reverse chronological order (newest first, oldest last). 183 | 184 | #### Putting it all together 185 | The `experiments` directory should look like: 186 | ``` 187 | experiments/ 188 | experiments/toc.md 189 | experiments/overfit 190 | experiments/overfit/results.md 191 | experiments/overfit/overfit.py 192 | experiments/overfit/learning_curve.py 193 | experiments/overfit/ann.json 194 | experiments/overfit/logistic_regression.json 195 | experiments/overfit/trials/trial-00_test-train.svg 196 | ``` 197 | 198 | ### Results 199 | The results directory contains all of the **raw** results from experiments. 200 | The subfolders should be the experiment short names (e.g. `results/overfit` using the above example). 201 | This is usually a git submodule. 202 | 203 | This folder can become unwieldy with a very large number of files. 204 | As such, a git hosting service like GitHub will not work. 205 | Usually I host a git repo on my personal server. 206 | I am actively investigating better solutions here.
207 | 208 | -------------------------------------------------------------------------------- /mock_repo/experiments/overfit/best/ann.json: -------------------------------------------------------------------------------- 1 | { 2 | "algorithm": "ann", 3 | "dataset": "cifar10", 4 | "samples": 1000, 5 | "metaParameters": { 6 | "layers": [ 7 | { "units": 10, "type": "dense", "transfer": "relu" }, 8 | { "units": 32, "type": "dense", "transfer": "softmax" } 9 | ], 10 | "useDictionary": false 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /mock_repo/experiments/overfit/best/sdl.json: -------------------------------------------------------------------------------- 1 | { 2 | "algorithm": "sdl", 3 | "dataset": "cifar10", 4 | "samples": 1000, 5 | "metaParameters": { 6 | "hiddenUnits": [100] 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /mock_repo/experiments/overfit/sweeps/ann.json: -------------------------------------------------------------------------------- 1 | { 2 | "algorithm": "ann", 3 | "dataset": "cifar10", 4 | "samples": 1000, 5 | "metaParameters": { 6 | "layers": [ 7 | { "units": [10, 100, 1000], "type": "dense", "transfer": "relu" }, 8 | { "units": 32, "type": "dense", "transfer": "softmax" } 9 | ], 10 | "useDictionary": false 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /mock_repo/experiments/overfit/sweeps/sdl.json: -------------------------------------------------------------------------------- 1 | { 2 | "algorithm": "sdl", 3 | "dataset": "cifar10", 4 | "samples": 1000, 5 | "metaParameters": { 6 | "hiddenUnits": [10, 100, 1000] 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool] 2 | [tool.commitizen] 3 | name = "cz_conventional_commits" 4 | version = 
"8.1.2" 5 | tag_format = "$version" 6 | version_files = ["pyproject.toml"] 7 | 8 | [tool.ruff.lint] 9 | ignore = ['E701', 'E731', 'E741'] 10 | 11 | [project] 12 | name = "PyExpUtils-andnp" 13 | version = "8.1.2" 14 | description = "A small set of utilities for RL and ML experiments" 15 | authors = [ 16 | {name = "Andy Patterson", email = "andnpatterson@gmail.com"}, 17 | ] 18 | dependencies = [ 19 | "numba>=0.57.0", 20 | "numpy>=1.21.5", 21 | "filelock>=3.0.0", 22 | "pandas", 23 | "connectorx", 24 | ] 25 | requires-python = ">=3.10" 26 | readme = "README.md" 27 | license = {text = "MIT"} 28 | 29 | [project.optional-dependencies] 30 | dev = [ 31 | "ruff", 32 | "commitizen", 33 | "pre-commit", 34 | "types-filelock", 35 | "build", 36 | "twine", 37 | "pyright>=1.1.324", 38 | ] 39 | 40 | [project.scripts] 41 | run-parallel = "PyExpUtils.parallel_runner:main" 42 | 43 | [build-system] 44 | requires = ["pdm-pep517>=1.0.0"] 45 | build-backend = "pdm.pep517.api" 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by uv via the following command: 2 | # uv pip compile --extra=dev pyproject.toml -o requirements.txt 3 | argcomplete==3.5.1 4 | # via commitizen 5 | backports-tarfile==1.2.0 6 | # via jaraco-context 7 | build==1.2.1 8 | # via pyexputils-andnp (pyproject.toml) 9 | certifi==2024.7.4 10 | # via requests 11 | cffi==1.16.0 12 | # via cryptography 13 | cfgv==3.4.0 14 | # via pre-commit 15 | charset-normalizer==3.3.2 16 | # via 17 | # commitizen 18 | # requests 19 | colorama==0.4.6 20 | # via commitizen 21 | commitizen==3.29.0 22 | # via pyexputils-andnp (pyproject.toml) 23 | connectorx==0.3.2 24 | # via pyexputils-andnp (pyproject.toml) 25 | cryptography==43.0.3 26 | # via secretstorage 27 | decli==0.6.1 28 | # via commitizen 29 | distlib==0.3.8 30 | # via virtualenv 31 | docutils==0.21.2 32 | # via 
readme-renderer 33 | filelock==3.13.3 34 | # via 35 | # pyexputils-andnp (pyproject.toml) 36 | # virtualenv 37 | identify==2.5.36 38 | # via pre-commit 39 | idna==3.7 40 | # via requests 41 | importlib-metadata==7.1.0 42 | # via 43 | # keyring 44 | # twine 45 | jaraco-classes==3.4.0 46 | # via keyring 47 | jaraco-context==6.0.1 48 | # via keyring 49 | jaraco-functools==4.1.0 50 | # via keyring 51 | jeepney==0.8.0 52 | # via 53 | # keyring 54 | # secretstorage 55 | jinja2==3.1.3 56 | # via commitizen 57 | keyring==25.5.0 58 | # via twine 59 | llvmlite==0.42.0 60 | # via numba 61 | markdown-it-py==3.0.0 62 | # via rich 63 | markupsafe==2.1.5 64 | # via jinja2 65 | mdurl==0.1.2 66 | # via markdown-it-py 67 | more-itertools==10.2.0 68 | # via 69 | # jaraco-classes 70 | # jaraco-functools 71 | nh3==0.2.18 72 | # via readme-renderer 73 | nodeenv==1.8.0 74 | # via 75 | # pre-commit 76 | # pyright 77 | numba==0.59.1 78 | # via pyexputils-andnp (pyproject.toml) 79 | numpy==1.26.4 80 | # via 81 | # pyexputils-andnp (pyproject.toml) 82 | # numba 83 | # pandas 84 | packaging==24.0 85 | # via 86 | # build 87 | # commitizen 88 | pandas==2.2.2 89 | # via pyexputils-andnp (pyproject.toml) 90 | pkginfo==1.11.1 91 | # via twine 92 | platformdirs==4.2.1 93 | # via virtualenv 94 | pre-commit==3.7.0 95 | # via pyexputils-andnp (pyproject.toml) 96 | prompt-toolkit==3.0.36 97 | # via questionary 98 | pycparser==2.22 99 | # via cffi 100 | pygments==2.18.0 101 | # via 102 | # readme-renderer 103 | # rich 104 | pyproject-hooks==1.1.0 105 | # via build 106 | pyright==1.1.365 107 | # via pyexputils-andnp (pyproject.toml) 108 | python-dateutil==2.9.0.post0 109 | # via pandas 110 | pytz==2024.1 111 | # via pandas 112 | pyyaml==6.0.1 113 | # via 114 | # commitizen 115 | # pre-commit 116 | questionary==2.0.1 117 | # via commitizen 118 | readme-renderer==43.0 119 | # via twine 120 | requests==2.32.3 121 | # via 122 | # requests-toolbelt 123 | # twine 124 | requests-toolbelt==1.0.0 125 | # via 
twine 126 | rfc3986==2.0.0 127 | # via twine 128 | rich==13.7.1 129 | # via twine 130 | ruff==0.4.4 131 | # via pyexputils-andnp (pyproject.toml) 132 | secretstorage==3.3.3 133 | # via keyring 134 | setuptools==70.0.0 135 | # via nodeenv 136 | six==1.16.0 137 | # via python-dateutil 138 | termcolor==2.4.0 139 | # via commitizen 140 | tomlkit==0.12.5 141 | # via commitizen 142 | twine==5.0.0 143 | # via pyexputils-andnp (pyproject.toml) 144 | types-filelock==3.2.7 145 | # via pyexputils-andnp (pyproject.toml) 146 | tzdata==2024.1 147 | # via pandas 148 | urllib3==2.2.3 149 | # via 150 | # requests 151 | # twine 152 | virtualenv==20.26.0 153 | # via pre-commit 154 | wcwidth==0.2.13 155 | # via prompt-toolkit 156 | zipp==3.20.0 157 | # via importlib-metadata 158 | -------------------------------------------------------------------------------- /scripts/generate_docs.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | 4 | MODULES = ['models', 'collection', 'runner', 'results', 'utils'] 5 | 6 | doc_str = """# PyExpUtils 7 | 8 | [![Test](https://github.com/andnp/PyExpUtils/actions/workflows/test.yml/badge.svg?branch=master)](https://github.com/andnp/PyExpUtils/actions/workflows/test.yml) 9 | 10 | Short for python experiment utilities. 11 | This is a collection of scripts and machine learning experiment management tools that I use whenever I have to use python. 12 | 13 | For a more complete discussion on my organization patterns for research codebases, [look in the docs](docs/OrganizationPatterns.md). 14 | 15 | ## This lib 16 | Maintaining a rigorous experiment structure can be labor intensive. 17 | As such, I've automated out many of the common pieces that I use in my research. 18 | 19 | ### Parameter Permutations 20 | Experiments are encoded within JSON files. 21 | The JSON files should contain all of the information necessary to reproduce an experiment, including all parameters swept. 
22 | Each of the parameter sweep specifications leads to a set of parameter permutations. 23 | Imagine the case where you are sweeping over 2 meta-parameters: 24 | ```json 25 | { 26 | "metaParameters": { 27 | "alpha": [0.01, 0.02, 0.04], 28 | "epsilon": [0.1, 0.2, 0.3] 29 | } 30 | } 31 | ``` 32 | Here there are 9 total possible permutations: `{alpha: 0.01, epsilon: 0.1}`, `{alpha: 0.01, epsilon: 0.2}`, ... 33 | 34 | These are indexed by a single numeric value. 35 | To run each permutation once, simply execute indices `i \\in [0..8]`. 36 | To run each permutation twice, multiply by 2: `i \\in [0..17]`. 37 | In general for `n` runs and `p` permutations: `i \\in [0..(n*p - 1)]`. 38 | 39 | 40 | """ 41 | 42 | def getName(line): 43 | line = line.strip() 44 | line = line.replace('def ', '') 45 | line = line.replace('class ', '') 46 | line = re.sub(r'\W*\(.*\).*:*', '', line) 47 | line = line.replace(':', '') 48 | return line 49 | 50 | def scanFile(f): 51 | in_doc = False 52 | tabs = 0 53 | get_method = False 54 | buffer = [] 55 | total = {} 56 | for line in f.readlines(): 57 | if get_method: 58 | if '@' in line: 59 | continue 60 | get_method = False 61 | name = getName(line) 62 | total[name] = buffer 63 | tabs = 0 64 | buffer = [] 65 | continue 66 | 67 | if not in_doc and '"""doc' in line: 68 | in_doc = True 69 | # count the number of whitespaces to offset all lines in docs by 70 | m = re.match(r'\W*', line) or [''] 71 | tabs = len(m[0]) - 3 72 | continue 73 | 74 | if not in_doc: 75 | continue 76 | 77 | if '"""' in line: 78 | in_doc = False 79 | get_method = True 80 | continue 81 | 82 | line = line[tabs:] 83 | buffer.append(line) 84 | 85 | return total 86 | 87 | 88 | py_paths = glob.glob('PyExpUtils/**/*.py', recursive=True) 89 | py_paths = filter(lambda path: '__init__.py' not in path, py_paths) 90 | 91 | split_paths = {} 92 | for module in MODULES: 93 | split_paths[module] = [] 94 | 95 | for path in py_paths: 96 | parts = path.split('/') 97 | module = parts[1] 98 | arr 
= split_paths.get(module, []) 99 | arr.append(path) 100 | 101 | for module in MODULES: 102 | doc_str += f"## {module}\n" 103 | 104 | init = open(f'PyExpUtils/{module}/__init__.py', 'r') 105 | init_str = '' 106 | start_read = False 107 | for line in init.readlines(): 108 | if '"""doc' in line: 109 | start_read = True 110 | continue 111 | 112 | if '"""' in line: 113 | start_read = False 114 | continue 115 | 116 | if start_read: 117 | init_str += line 118 | 119 | doc_str += init_str 120 | init.close() 121 | 122 | for path in split_paths[module]: 123 | f = open(path, 'r') 124 | docs = scanFile(f) 125 | if len(docs): 126 | doc_str += f"### {path}\n" 127 | for method in docs: 128 | doc_str += f"**{method}**:\n\n" 129 | doc_str += ''.join(docs[method]) + '\n\n' 130 | f.close() 131 | 132 | with open('README.md', 'w') as f: 133 | f.write(doc_str) 134 | -------------------------------------------------------------------------------- /scripts/publish.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | source .venv/bin/activate 4 | 5 | git config credential.helper "store --file=.git/credentials" 6 | echo "https://${GH_TOKEN}:@github.com" > .git/credentials 7 | 8 | git config user.email "andnpatterson@gmail.com" 9 | git config user.name "github-action" 10 | 11 | git fetch --all --tags 12 | 13 | git checkout -f master 14 | 15 | # bump the version 16 | cz bump --no-verify --yes --check-consistency 17 | 18 | # push to pypi repository 19 | python -m build 20 | python -m twine upload -u __token__ -p ${PYPI_TOKEN} --non-interactive dist/* 21 | 22 | pip install uv 23 | uv pip compile --extra=dev pyproject.toml -o requirements.txt 24 | git add requirements.txt 25 | git commit -m "ci: update requirements" || echo "No changes to commit" 26 | 27 | git push 28 | git push --tags 29 | -------------------------------------------------------------------------------- /scripts/run_tests.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | source .venv/bin/activate 5 | 6 | # pyright --stats 7 | 8 | export PYTHONPATH=PyExpUtils 9 | python3 -m unittest discover -p "*test_*.py" 10 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andnp/PyExpUtils/5d076ff1196368a936b18998afd00c80d4699857/tests/__init__.py -------------------------------------------------------------------------------- /tests/_utils/pandas.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | def check_equal(df1: pd.DataFrame, df2: pd.DataFrame): 4 | pd.testing.assert_frame_equal( 5 | df1.reset_index(drop=True), 6 | df2.reset_index(drop=True), 7 | check_like=True, 8 | ) 9 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andnp/PyExpUtils/5d076ff1196368a936b18998afd00c80d4699857/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/test_ExperimentDescription.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription, loadExperiment 4 | 5 | class TestSavingPath(unittest.TestCase): 6 | def test_path(self): 7 | desc = { 8 | 'name': 'test', 9 | 'algorithm': 'q', 10 | 'environment': 'mountaincar', 11 | 'metaParameters': { 12 | 'alpha': [0.01, 0.02], 13 | 'epsilon': 0.05, 14 | } 15 | } 16 | 17 | class MLExpDesc(ExperimentDescription): 18 | def __init__(self, d): 19 | super().__init__(d) 20 | self.algorithm = d['algorithm'] 21 | 
self.env = d['environment'] 22 | 23 | exp = MLExpDesc(desc) 24 | key = '{name}/{algorithm}/{env}/{params}/{run}' 25 | 26 | got = exp.interpolateSavePath(0, key=key) 27 | expected = 'test/q/mountaincar/alpha-0.01_epsilon-0.05/0' 28 | self.assertEqual(got, expected) 29 | 30 | def test_interpolateSavePath(self): 31 | desc = { 32 | 'metaParameters': { 33 | 'optimizer': { 34 | 'alpha': [0.1, 0.2], 35 | 'beta': [0.99, 0.999], 36 | }, 37 | 'epsilon': 0.05, 38 | }, 39 | } 40 | 41 | exp = ExperimentDescription(desc) 42 | key = '{params}' 43 | 44 | got = exp.interpolateSavePath(0, key=key) 45 | expected = 'epsilon-0.05_optimizer.alpha-0.1_optimizer.beta-0.99' 46 | self.assertEqual(got, expected) 47 | 48 | class TestPermutations(unittest.TestCase): 49 | def fakeDescription(self): 50 | return { 51 | 'metaParameters': { 52 | 'alpha': [0.01, 0.02, 0.04], 53 | 'epsilon': 0.05, 54 | 'gamma': [0.9] 55 | }, 56 | 'envParameters': { 57 | 'size': 30, 58 | 'noise': [0.01, 0.02], 59 | } 60 | } 61 | 62 | def test_permutable(self): 63 | desc = self.fakeDescription() 64 | # can specify a list of parameters to permute over 65 | exp = ExperimentDescription(desc, keys=['metaParameters']) 66 | 67 | # permutable defaults to 'metaParameters' being the only permutable key 68 | got = exp.permutable() 69 | self.assertDictEqual( 70 | got, 71 | { 72 | 'metaParameters': desc['metaParameters'] 73 | }, 74 | ) 75 | 76 | # can specify a list of parameters to permute over 77 | exp = ExperimentDescription(desc, keys=['metaParameters', 'envParameters']) 78 | got = exp.permutable() 79 | self.assertDictEqual( 80 | got, 81 | desc 82 | ) 83 | 84 | def test_getPermutations(self): 85 | desc = self.fakeDescription() 86 | exp = ExperimentDescription(desc) 87 | 88 | # can get permutation of metaParameters by default 89 | got = exp.getPermutation(0) 90 | expected = desc.copy() 91 | expected['metaParameters'] = { 92 | 'alpha': 0.01, 93 | 'epsilon': 0.05, 94 | 'gamma': 0.9, 95 | } 96 | self.assertDictEqual( 97 | got, 98 
| expected, 99 | ) 100 | 101 | # can get permutation of multiple parameters 102 | exp = ExperimentDescription(desc, keys=['metaParameters', 'envParameters']) 103 | got = exp.getPermutation(0) 104 | expected = { 105 | 'metaParameters': { 106 | 'alpha': 0.01, 107 | 'epsilon': 0.05, 108 | 'gamma': 0.9, 109 | }, 110 | 'envParameters': { 111 | 'size': 30, 112 | 'noise': 0.01, 113 | }, 114 | } 115 | self.assertDictEqual(got, expected) 116 | 117 | def test_permutations(self): 118 | desc = self.fakeDescription() 119 | exp = ExperimentDescription(desc) 120 | 121 | got = exp.numPermutations() 122 | expected = 3 123 | self.assertEqual(got, expected) 124 | 125 | exp = ExperimentDescription(desc, keys=['envParameters']) 126 | got = exp.numPermutations() 127 | expected = 2 128 | self.assertEqual(got, expected) 129 | 130 | exp = ExperimentDescription(desc, keys=['metaParameters', 'envParameters']) 131 | got = exp.numPermutations() 132 | expected = 6 133 | self.assertEqual(got, expected) 134 | 135 | 136 | class TestExperimentName(unittest.TestCase): 137 | def test_fromFile(self): 138 | exp = loadExperiment('mock_repo/experiments/overfit/best/ann.json') 139 | 140 | got = exp.getExperimentName() 141 | expected = 'overfit/best' 142 | 143 | self.assertEqual(got, expected) 144 | 145 | def test_withCWD(self): 146 | exp = loadExperiment(f'{os.getcwd()}/mock_repo/experiments/overfit/best/ann.json') 147 | 148 | got = exp.getExperimentName() 149 | expected = 'overfit/best' 150 | 151 | self.assertEqual(got, expected) 152 | 153 | def test_withDotSlash(self): 154 | exp = loadExperiment('./mock_repo/experiments/overfit/best/ann.json') 155 | 156 | got = exp.getExperimentName() 157 | expected = 'overfit/best' 158 | 159 | self.assertEqual(got, expected) 160 | 161 | class TestRegressions(unittest.TestCase): 162 | def fakeDescription(self): 163 | return { 164 | 'metaParameters': { 165 | 'alpha': [0.01, 0.02, 0.04], 166 | 'epsilon': 0.05, 167 | 'gamma': [0.9] 168 | }, 169 | } 170 | 171 | def 
test_mutatingPermutationsDoesNotInvalidateCache(self): 172 | exp = ExperimentDescription(self.fakeDescription()) 173 | 174 | params = exp.getPermutation(0)['metaParameters'] 175 | 176 | # mutate params 177 | params['new_param'] = 1 178 | 179 | # check again 180 | params2 = exp.getPermutation(0)['metaParameters'] 181 | 182 | # 'new_param' should not appear on dict 183 | self.assertDictEqual(params2, { 184 | 'alpha': 0.01, 185 | 'epsilon': 0.05, 186 | 'gamma': 0.9, 187 | }) 188 | -------------------------------------------------------------------------------- /tests/results/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andnp/PyExpUtils/5d076ff1196368a936b18998afd00c80d4699857/tests/results/__init__.py -------------------------------------------------------------------------------- /tests/results/test_indices.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import shutil 4 | from PyExpUtils.results.indices import listIndices 5 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription 6 | 7 | class RLExperiment(ExperimentDescription): 8 | def __init__(self, d): 9 | super().__init__(d) 10 | self.agent = d['agent'] 11 | self.environment = d['environment'] 12 | 13 | class TestIndices(unittest.TestCase): 14 | @classmethod 15 | def tearDownClass(cls): 16 | try: 17 | shutil.rmtree('.tmp') 18 | os.remove('.tmp.tar') 19 | except Exception: 20 | pass 21 | 22 | def test_listIndices(self): 23 | exp = ExperimentDescription({ 24 | 'metaParameters': { 25 | 'alpha': [0.01, 0.02, 0.04, 0.08, 0.16], 26 | 'lambda': [1.0, 0.99, 0.98, 0.96, 0.92], 27 | } 28 | }) 29 | 30 | expected = list(range(25)) 31 | got = list(listIndices(exp)) 32 | 33 | self.assertListEqual(got, expected) 34 | -------------------------------------------------------------------------------- /tests/results/test_tools.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | import pandas as pd 3 | import tests._utils.pandas as pdu 4 | from PyExpUtils.results.tools import subsetDF, splitByValue 5 | 6 | class TestTools(unittest.TestCase): 7 | def test_subsetDF(self): 8 | df = pd.DataFrame({ 9 | 'a': [1, 1, 1, 2, 2], 10 | 'b': [2, 4, 6, 8, 10], 11 | }) 12 | 13 | # check base case 14 | sub = subsetDF(df, {}) 15 | self.assertTrue(df.equals(sub)) 16 | 17 | # check simple case where a single value is specified 18 | sub = subsetDF(df, { 'a': 2 }) 19 | expect = pd.DataFrame({ 20 | 'a': [2, 2], 21 | 'b': [8, 10], 22 | }) 23 | self.assertTrue(sub.equals(expect)) 24 | 25 | # check can specify multiple columns 26 | sub = subsetDF(df, {'a': 1, 'b': 4}) 27 | expect = pd.DataFrame({ 28 | 'a': [1], 29 | 'b': [4], 30 | }) 31 | self.assertTrue(sub.equals(expect)) 32 | 33 | # check can specify a list of values for a column 34 | sub = subsetDF(df, {'b': [4, 6, 8]}) 35 | expect = pd.DataFrame({ 36 | 'a': [1, 1, 2], 37 | 'b': [4, 6, 8], 38 | }) 39 | self.assertTrue(sub.equals(expect)) 40 | 41 | # check can specify non-existent columns 42 | sub = subsetDF(df, {'a': 1, 'c': 22}) 43 | expect = pd.DataFrame({ 44 | 'a': [1, 1, 1], 45 | 'b': [2, 4, 6], 46 | }) 47 | self.assertTrue(sub.equals(expect)) 48 | 49 | def test_splitByValue(self): 50 | df = pd.DataFrame({ 51 | 'a': [1, 2, 3, 4, 5, 6], 52 | 'b': [0, 0, 1, 1, 1, 2], 53 | }) 54 | 55 | parts = list(splitByValue(df, 'b')) 56 | 57 | self.assertEqual(parts[0][0], 0) 58 | pdu.check_equal(parts[0][1], pd.DataFrame({ 59 | 'a': [1, 2], 60 | 'b': [0, 0], 61 | })) 62 | 63 | self.assertEqual(parts[1][0], 1) 64 | pdu.check_equal(parts[1][1], pd.DataFrame({ 65 | 'a': [3, 4, 5], 66 | 'b': [1, 1, 1], 67 | })) 68 | 69 | self.assertEqual(parts[2][0], 2) 70 | pdu.check_equal(parts[2][1], pd.DataFrame({ 71 | 'a': [6], 72 | 'b': [2], 73 | })) 74 | 
-------------------------------------------------------------------------------- /tests/results/test_voting.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import unittest 3 | import numpy as np 4 | import copy 5 | from PyExpUtils.results.voting import RankedBallot, RankedCandidate, ScoredCandidate, buildBallot, raynaud, small, confidenceRanking, firstPastPost, instantRunoff, scoreRanking 6 | 7 | def fakeElection1(): 8 | return [ 9 | buildBallot([ 10 | # name, rank, score 11 | RankedCandidate(3, 0, 63), 12 | RankedCandidate(4, 1, 44), 13 | RankedCandidate(5, 2, 32), 14 | RankedCandidate(0, 2, 20), 15 | RankedCandidate(8, 3, np.nan), 16 | ]), 17 | buildBallot([ 18 | RankedCandidate(4, 0, 63), 19 | RankedCandidate(5, 0, 59), 20 | RankedCandidate(0, 1, 20), 21 | RankedCandidate(3, 2, 18), 22 | RankedCandidate(8, 3, np.nan), 23 | ]), 24 | buildBallot([ 25 | RankedCandidate(5, 0, 32), 26 | RankedCandidate(4, 1, 28), 27 | RankedCandidate(3, 2, 25), 28 | RankedCandidate(0, 2, 20), 29 | RankedCandidate(8, 3, np.nan), 30 | ]), 31 | buildBallot([ 32 | RankedCandidate(0, 0, 66), 33 | RankedCandidate(3, 1, 34), 34 | RankedCandidate(4, 1, 33), 35 | RankedCandidate(5, 2, 32), 36 | RankedCandidate(8, 3, np.nan), 37 | ]), 38 | ] 39 | 40 | def fakeElection2(): 41 | return [ 42 | buildBallot([ 43 | RankedCandidate(3, 0, 63), 44 | RankedCandidate(4, 1, 44), 45 | RankedCandidate(5, 2, 32), 46 | RankedCandidate(0, 3, 20), 47 | RankedCandidate(8, 4, np.nan), 48 | ]), 49 | buildBallot([ 50 | RankedCandidate(4, 0, 63), 51 | RankedCandidate(5, 0, 59), 52 | RankedCandidate(0, 1, 20), 53 | RankedCandidate(3, 2, 18), 54 | RankedCandidate(8, 3, np.nan), 55 | ]), 56 | buildBallot([ 57 | RankedCandidate(5, 0, 32), 58 | RankedCandidate(4, 1, 28), 59 | RankedCandidate(3, 2, 25), 60 | RankedCandidate(0, 2, 20), 61 | RankedCandidate(8, 3, np.nan), 62 | ]), 63 | buildBallot([ 64 | RankedCandidate(0, 0, 66), 65 | 
RankedCandidate(3, 1, 34), 66 | RankedCandidate(4, 2, 33), 67 | RankedCandidate(5, 2, 32), 68 | RankedCandidate(8, 3, np.nan), 69 | ]), 70 | ] 71 | 72 | def buildByProportion(ballotPairs: List[Tuple[float, RankedBallot]], total: int): 73 | ballots: List[RankedBallot] = [] 74 | 75 | for pair in ballotPairs: 76 | proportion, ballot = pair 77 | 78 | num = int(proportion * total) 79 | for _ in range(num): 80 | ballots.append(copy.deepcopy(ballot)) 81 | 82 | return ballots 83 | 84 | # taken from http://www.cs.angelo.edu/~rlegrand/rbvote/desc.html 85 | def fakeElection3(): 86 | return buildByProportion([ 87 | 88 | (0.098, buildBallot([ 89 | RankedCandidate('Abby', 0), 90 | RankedCandidate('Cora', 1), 91 | RankedCandidate('Erin', 2), 92 | RankedCandidate('Dave', 3), 93 | RankedCandidate('Brad', 4), 94 | ])), 95 | (0.064, buildBallot([ 96 | RankedCandidate('Brad', 0), 97 | RankedCandidate('Abby', 1), 98 | RankedCandidate('Erin', 2), 99 | RankedCandidate('Cora', 3), 100 | RankedCandidate('Dave', 4), 101 | ])), 102 | (0.012, buildBallot([ 103 | RankedCandidate('Brad', 0), 104 | RankedCandidate('Abby', 1), 105 | RankedCandidate('Erin', 2), 106 | RankedCandidate('Dave', 3), 107 | RankedCandidate('Cora', 4), 108 | ])), 109 | (0.098, buildBallot([ 110 | RankedCandidate('Brad', 0), 111 | RankedCandidate('Erin', 1), 112 | RankedCandidate('Abby', 2), 113 | RankedCandidate('Cora', 3), 114 | RankedCandidate('Dave', 4), 115 | ])), 116 | (0.013, buildBallot([ 117 | RankedCandidate('Brad', 0), 118 | RankedCandidate('Erin', 1), 119 | RankedCandidate('Abby', 2), 120 | RankedCandidate('Dave', 3), 121 | RankedCandidate('Cora', 4), 122 | ])), 123 | (0.125, buildBallot([ 124 | RankedCandidate('Brad', 0), 125 | RankedCandidate('Erin', 1), 126 | RankedCandidate('Dave', 2), 127 | RankedCandidate('Abby', 3), 128 | RankedCandidate('Cora', 4), 129 | ])), 130 | (0.124, buildBallot([ 131 | RankedCandidate('Cora', 0), 132 | RankedCandidate('Abby', 1), 133 | RankedCandidate('Erin', 2), 134 | 
RankedCandidate('Dave', 3), 135 | RankedCandidate('Brad', 4), 136 | ])), 137 | (0.076, buildBallot([ 138 | RankedCandidate('Cora', 0), 139 | RankedCandidate('Erin', 1), 140 | RankedCandidate('Abby', 2), 141 | RankedCandidate('Dave', 3), 142 | RankedCandidate('Brad', 4), 143 | ])), 144 | (0.021, buildBallot([ 145 | RankedCandidate('Dave', 0), 146 | RankedCandidate('Abby', 1), 147 | RankedCandidate('Brad', 2), 148 | RankedCandidate('Erin', 3), 149 | RankedCandidate('Cora', 4), 150 | ])), 151 | (0.030, buildBallot([ 152 | RankedCandidate('Dave', 0), 153 | RankedCandidate('Brad', 1), 154 | RankedCandidate('Abby', 2), 155 | RankedCandidate('Erin', 3), 156 | RankedCandidate('Cora', 4), 157 | ])), 158 | (0.098, buildBallot([ 159 | RankedCandidate('Dave', 0), 160 | RankedCandidate('Brad', 1), 161 | RankedCandidate('Erin', 2), 162 | RankedCandidate('Cora', 3), 163 | RankedCandidate('Abby', 4), 164 | ])), 165 | (0.139, buildBallot([ 166 | RankedCandidate('Dave', 0), 167 | RankedCandidate('Cora', 1), 168 | RankedCandidate('Abby', 2), 169 | RankedCandidate('Brad', 3), 170 | RankedCandidate('Erin', 4), 171 | ])), 172 | (0.023, buildBallot([ 173 | RankedCandidate('Dave', 0), 174 | RankedCandidate('Cora', 1), 175 | RankedCandidate('Brad', 2), 176 | RankedCandidate('Abby', 3), 177 | RankedCandidate('Erin', 4), 178 | ])), 179 | 180 | ], 1000) 181 | 182 | class TestVoting(unittest.TestCase): 183 | def test_confidenceRanking(self): 184 | scores = [ 185 | ScoredCandidate(0, 20, 2), 186 | ScoredCandidate(3, 25, 1), 187 | ScoredCandidate(4, 63, 5), 188 | ScoredCandidate(5, 32, 8), 189 | ScoredCandidate(8, np.nan, np.nan), 190 | ] 191 | 192 | expected = [ 193 | RankedCandidate(4, 0, 63), 194 | RankedCandidate(5, 1, 32), 195 | RankedCandidate(3, 1, 25), 196 | RankedCandidate(0, 2, 20), 197 | ] 198 | 199 | got = confidenceRanking(scores, stderrs=1, prefer='big') 200 | self.assertEqual(expected, got) 201 | 202 | expected = [ 203 | RankedCandidate(0, 0, 20), 204 | RankedCandidate(3, 1, 25), 
205 | RankedCandidate(5, 1, 32), 206 | RankedCandidate(4, 2, 63), 207 | ] 208 | 209 | got = confidenceRanking(scores, stderrs=1, prefer='small') 210 | self.assertEqual(expected, got) 211 | 212 | def test_scoreRanking(self): 213 | scores = [ 214 | ScoredCandidate(0, 20, 2), 215 | ScoredCandidate(3, 25, 1), 216 | ScoredCandidate(4, 63, 5), 217 | ScoredCandidate(5, 32, 8), 218 | ScoredCandidate(8, np.nan, np.nan), 219 | ] 220 | 221 | expected = [ 222 | RankedCandidate(4, 0, 63), 223 | RankedCandidate(5, 1, 32), 224 | RankedCandidate(3, 2, 25), 225 | RankedCandidate(0, 3, 20), 226 | ] 227 | 228 | got = scoreRanking(scores, prefer='big') 229 | self.assertEqual(expected, got) 230 | 231 | def test_instantRunoff(self): 232 | ballots = fakeElection1() 233 | winner = instantRunoff(ballots) 234 | self.assertEqual(winner, 4) 235 | 236 | ballots = fakeElection2() 237 | winner = instantRunoff(ballots) 238 | self.assertEqual(winner, 3) 239 | 240 | ballots = fakeElection3() 241 | winner = instantRunoff(ballots) 242 | self.assertEqual(winner, 'Brad') 243 | 244 | def test_firstPastPost(self): 245 | ballots = fakeElection1() 246 | winner = firstPastPost(ballots) 247 | self.assertEqual(winner, 5) 248 | 249 | ballots = fakeElection2() 250 | winner = firstPastPost(ballots) 251 | self.assertEqual(winner, 5) 252 | 253 | ballots = fakeElection3() 254 | winner = firstPastPost(ballots) 255 | self.assertEqual(winner, 'Brad') 256 | 257 | def test_small(self): 258 | ballots = fakeElection1() 259 | winner = small(ballots) 260 | self.assertEqual(winner, 4) 261 | 262 | ballots = fakeElection2() 263 | winner = small(ballots) 264 | self.assertEqual(winner, 4) 265 | 266 | ballots = fakeElection3() 267 | winner = small(ballots) 268 | self.assertEqual(winner, 'Brad') 269 | 270 | def test_raynaud(self): 271 | ballots = fakeElection1() 272 | winner = raynaud(ballots) 273 | self.assertEqual(winner, 4) 274 | 275 | ballots = fakeElection2() 276 | winner = raynaud(ballots) 277 | self.assertEqual(winner, 3) 
278 | 279 | ballots = fakeElection3() 280 | winner = raynaud(ballots) 281 | self.assertEqual(winner, 'Abby') 282 | -------------------------------------------------------------------------------- /tests/runner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andnp/PyExpUtils/5d076ff1196368a936b18998afd00c80d4699857/tests/runner/__init__.py -------------------------------------------------------------------------------- /tests/runner/test_parallel.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from PyExpUtils.runner.parallel import build 3 | 4 | class TestParallel(unittest.TestCase): 5 | def test_build(self): 6 | d = { 7 | 'executable': 'thingDoer.exe', 8 | 'cores': 22, 9 | 'tasks': [1, 2, 3, 4, 5], 10 | } 11 | 12 | got = build(d) 13 | expected = 'parallel -j 22 thingDoer.exe ::: 1 2 3 4 5' 14 | self.assertEqual(got, expected) 15 | -------------------------------------------------------------------------------- /tests/runner/test_slurm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from PyExpUtils.runner.Slurm import MultiNodeOptions, SingleNodeOptions, to_cmdline_flags 3 | 4 | class TestSlurm(unittest.TestCase): 5 | def test_Options(self): 6 | opts = SingleNodeOptions( 7 | account='def-whitem', 8 | time='2:59:59', 9 | cores=2, 10 | mem_per_core=4, 11 | sequential=30, 12 | ) 13 | 14 | got = to_cmdline_flags(opts) 15 | expected = '--account=def-whitem --cpus-per-task=1 --mem-per-cpu=4096M --nodes=1 --ntasks=2 --output=$SCRATCH/job_output_%j.txt --time=2:59:59' 16 | self.assertEqual(got, expected) 17 | 18 | opts = MultiNodeOptions( 19 | account='def-whitem', 20 | time='2:59:59', 21 | cores=8, 22 | mem_per_core=4, 23 | sequential=30, 24 | ) 25 | 26 | got = to_cmdline_flags(opts) 27 | expected = '--account=def-whitem --cpus-per-task=1 --mem-per-cpu=4096M --ntasks=8 
--output=$SCRATCH/job_output_%j.txt --time=2:59:59' 28 | self.assertEqual(got, expected) 29 | -------------------------------------------------------------------------------- /tests/test_FileSystemContext.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import shutil 4 | from PyExpUtils.FileSystemContext import FileSystemContext 5 | 6 | class TestFileSystemContext(unittest.TestCase): 7 | @classmethod 8 | def tearDownClass(cls): 9 | try: 10 | shutil.rmtree('.tmp') 11 | os.remove('path.tar') 12 | except Exception: 13 | pass 14 | 15 | def test_getBase(self): 16 | ctx = FileSystemContext('path/to/results', 'scratch') 17 | 18 | got = ctx.getBase() 19 | expected = 'scratch' 20 | 21 | self.assertEqual(got, expected) 22 | 23 | def test_resolveEmpty(self): 24 | ctx = FileSystemContext('path/to/results', 'scratch') 25 | 26 | got = ctx.resolve() 27 | expected = 'scratch/path/to/results' 28 | 29 | self.assertEqual(got, expected) 30 | 31 | def test_resolveFile(self): 32 | ctx = FileSystemContext('path/to/results', 'scratch') 33 | 34 | got = ctx.resolve('rmsve.npy') 35 | expected = 'scratch/path/to/results/rmsve.npy' 36 | 37 | self.assertEqual(got, expected) 38 | 39 | def test_resolveFull(self): 40 | ctx = FileSystemContext('path/to/results', 'scratch') 41 | 42 | got = ctx.resolve('scratch/path/to/results/rmsve.npy') 43 | expected = 'scratch/path/to/results/rmsve.npy' 44 | 45 | self.assertEqual(got, expected) 46 | 47 | def test_resolveParentDir(self): 48 | ctx = FileSystemContext('path/to/results', 'scratch') 49 | 50 | got = ctx.resolve('../experiment.json') 51 | expected = 'scratch/path/to/experiment.json' 52 | 53 | self.assertEqual(got, expected) 54 | 55 | def test_remove(self): 56 | ctx = FileSystemContext('path/to/results', '.tmp/test_remove') 57 | 58 | ctx.ensureExists() 59 | path = ctx.resolve('test.txt') 60 | with open(path, 'w') as f: 61 | f.write('hey there') 62 | 63 | 
self.assertTrue(os.path.isfile(f'{ctx.getBase()}/path/to/results/test.txt')) 64 | 65 | ctx.remove() 66 | self.assertFalse(os.path.isfile(f'{ctx.getBase()}/path/to/results/test.txt')) 67 | 68 | class TestRegressions(unittest.TestCase): 69 | @classmethod 70 | def tearDownClass(cls): 71 | try: 72 | shutil.rmtree('.tmp') 73 | os.remove('path.tar') 74 | except Exception: 75 | pass 76 | 77 | def test_resolveNoBase(self): 78 | ctx = FileSystemContext('path/to/results') 79 | 80 | got = ctx.resolve('test.txt') 81 | expected = 'path/to/results/test.txt' 82 | 83 | self.assertEqual(got, expected) 84 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andnp/PyExpUtils/5d076ff1196368a936b18998afd00c80d4699857/tests/utils/__init__.py -------------------------------------------------------------------------------- /tests/utils/test_Collector.py: -------------------------------------------------------------------------------- 1 | from PyExpUtils.collection.Collector import Collector 2 | from PyExpUtils.collection.Sampler import Window, Subsample 3 | import unittest 4 | 5 | class TestCollector(unittest.TestCase): 6 | def test_collect(self): 7 | collector = Collector() 8 | collector.setIdx(0) 9 | 10 | for i in range(10): 11 | collector.collect('data', i * 2) 12 | collector.next_frame() 13 | 14 | got = collector.get('data', 0) 15 | expected = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] 16 | self.assertEqual(got, expected) 17 | 18 | def test_evaluate(self): 19 | collector = Collector(idx=0) 20 | 21 | for i in range(5): 22 | ev = lambda: i 23 | collector.evaluate('data', ev) 24 | collector.next_frame() 25 | 26 | expected = [0, 1, 2, 3, 4] 27 | got = collector.get('data', 0) 28 | 29 | self.assertEqual(got, expected) 30 | 31 | def test_window(self): 32 | collector = Collector( 33 | config={ 34 | 'a': Window(3), 35 | }, 36 | 
idx=0, 37 | ) 38 | 39 | collector.collect('a', 0) 40 | collector.next_frame() 41 | collector.collect('a', 1) 42 | collector.next_frame() 43 | collector.collect('a', 5) 44 | collector.next_frame() 45 | collector.collect('a', 3) 46 | collector.next_frame() 47 | 48 | self.assertEqual(collector.get('a', 0), [2.0]) 49 | 50 | collector.collect('a', 4) 51 | collector.next_frame() 52 | collector.collect('a', 5) 53 | collector.next_frame() 54 | 55 | self.assertEqual(collector.get('a', 0), [2.0, 4.0]) 56 | 57 | def test_subsample(self): 58 | collector = Collector( 59 | config={ 60 | 'a': Subsample(3), 61 | }, 62 | idx=0, 63 | ) 64 | 65 | collector.collect('a', 0) 66 | collector.next_frame() 67 | collector.collect('a', 1) 68 | collector.next_frame() 69 | collector.collect('a', 2) 70 | collector.next_frame() 71 | 72 | self.assertEqual(collector.get('a', 0), [0]) 73 | 74 | collector.collect('a', 3) 75 | collector.next_frame() 76 | 77 | self.assertEqual(collector.get('a', 0), [0, 3]) 78 | -------------------------------------------------------------------------------- /tests/utils/test_arrays.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from PyExpUtils.utils.arrays import argsmax, argsmax2, deduplicate, downsample, fillRest, first, last, npPadUneven, padUneven, partition, sampleFrequency 4 | 5 | class TestArrays(unittest.TestCase): 6 | def test_fillRest(self): 7 | # base functionality 8 | arr = [1, 2, 3, 4] 9 | 10 | got = fillRest(arr, 5, 10) 11 | expected = [1, 2, 3, 4, 5, 5, 5, 5, 5, 5] # length 10 12 | 13 | self.assertEqual(got, expected) 14 | 15 | # degenerate length 16 | arr = [1, 2, 3, 4] 17 | 18 | got = fillRest(arr, 5, 2) 19 | expected = [1, 2, 3, 4] 20 | 21 | self.assertEqual(got, expected) 22 | 23 | def test_first(self): 24 | # lists 25 | arr = [1, 2, 3, 4] 26 | 27 | got = first(arr) 28 | expected = 1 29 | 30 | self.assertEqual(got, expected) 31 | 32 | # iterators 33 | arr = ['a', 
'b', 'c'] 34 | it = arr.__iter__() 35 | 36 | got = first(it) 37 | expected = 'a' 38 | 39 | self.assertEqual(got, expected) 40 | 41 | def test_last(self): 42 | # base functionality 43 | arr = [1, 2, 3, 4] 44 | 45 | got = last(arr) 46 | expected = 4 47 | 48 | self.assertEqual(got, expected) 49 | 50 | def test_partition(self): 51 | # lists 52 | arr = [1, 2, 3, 4, 5, 6] 53 | 54 | l, r = partition(arr, lambda a: a > 3) 55 | self.assertEqual(list(l), [4, 5, 6]) 56 | self.assertEqual(list(r), [1, 2, 3]) 57 | 58 | # iterators 59 | arr = [1, 2, 3, 4, 5, 6] 60 | it = arr.__iter__() 61 | 62 | l, r = partition(it, lambda a: a > 3) 63 | self.assertEqual(list(l), [4, 5, 6]) 64 | self.assertEqual(list(r), [1, 2, 3]) 65 | 66 | def test_deduplicate(self): 67 | arr = [1, 2, 3, 2, 5, 7, 1, 2, 3] 68 | 69 | got = deduplicate(arr) 70 | expected = [1, 2, 3, 5, 7] 71 | 72 | # note that order isn't guaranteed so this is a brittle test... 73 | # maybe should fix this later 74 | self.assertListEqual(got, expected) 75 | 76 | def test_sampleFrequency(self): 77 | arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 78 | 79 | got = sampleFrequency(arr, percent=.1) 80 | expected = 10 81 | 82 | self.assertEqual(expected, got) 83 | 84 | got = sampleFrequency(arr, percent=.23) 85 | expected = 5 86 | 87 | self.assertEqual(expected, got) 88 | 89 | got = sampleFrequency(arr, num=4) 90 | expected = 2 91 | 92 | self.assertEqual(expected, got) 93 | 94 | def test_downsample(self): 95 | arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 96 | 97 | got = downsample(arr, percent=0.23, method='window') 98 | expected = [3, 8] 99 | 100 | self.assertEqual(got, expected) 101 | 102 | got = downsample(arr, num=3, method='window') 103 | expected = [2, 5, 8] 104 | 105 | self.assertEqual(got, expected) 106 | 107 | got = downsample(arr, percent=0.23, method='subsample') 108 | expected = [1, 6] 109 | 110 | self.assertEqual(got, expected) 111 | 112 | got = downsample(arr, num=4, method='subsample') 113 | expected = [2, 4, 6, 8] 114 | 115 | def 
test_argsmax(self): 116 | arr = np.array([0, 0, 1, 2]) 117 | 118 | got = argsmax(arr) 119 | expected = [3] 120 | 121 | self.assertEqual(got, expected) 122 | 123 | arr = np.array([0, 2, 1, 2]) 124 | 125 | got = argsmax(arr) 126 | expected = [1, 3] 127 | 128 | self.assertEqual(got, expected) 129 | 130 | arr = np.array([ 131 | [0, 1, 1, 0, 1], 132 | [2, 0, 0, 1, 2], 133 | ]) 134 | 135 | got = argsmax2(arr) 136 | expected = [ 137 | [1, 2, 4], 138 | [0, 4], 139 | ] 140 | 141 | self.assertEqual(got, expected) 142 | 143 | def test_padUneven(self): 144 | arr = [ 145 | [1., 2.], 146 | [2., 3., 4.], 147 | [1.], 148 | ] 149 | 150 | res = padUneven(arr, np.nan) 151 | 152 | e = [ 153 | [1., 2., np.nan], 154 | [2., 3., 4.], 155 | [1., np.nan, np.nan], 156 | ] 157 | 158 | self.assertEqual(res, e) 159 | 160 | def test_npPadUneven(self): 161 | arr = [ 162 | np.array([1., 2]), 163 | np.array([2., 3, 4]), 164 | np.array([1.]), 165 | ] 166 | 167 | res = npPadUneven(arr, np.nan) 168 | 169 | e = np.array([ 170 | [1., 2, np.nan], 171 | [2, 3, 4], 172 | [1, np.nan, np.nan], 173 | ]) 174 | 175 | self.assertTrue(np.allclose(res, e, equal_nan=True)) 176 | -------------------------------------------------------------------------------- /tests/utils/test_cmdline.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from PyExpUtils.utils.cmdline import flagString 3 | 4 | class TestCmdline(unittest.TestCase): 5 | def test_flagString(self): 6 | # base functionality 7 | pairs = [ 8 | ('--test', 'a'), 9 | ('--trial', 'b'), 10 | ('--exam', 'c'), 11 | ] 12 | 13 | got = flagString(pairs) 14 | expected = '--exam=c --test=a --trial=b' 15 | 16 | self.assertEqual(got, expected) 17 | 18 | # removes None entries 19 | pairs = [ 20 | ('--test', 'a'), 21 | ('--exam', None), 22 | ] 23 | 24 | got = flagString(pairs) 25 | expected = '--test=a' 26 | 27 | self.assertEqual(got, expected) 28 | 29 | # can join arguments with arbitrary string 30 | pairs = [ 31 | 
('--test', 'a'), 32 | ('--exam', 'b'), 33 | ] 34 | 35 | got = flagString(pairs, ' ') 36 | expected = '--exam b --test a' 37 | 38 | self.assertEqual(got, expected) 39 | -------------------------------------------------------------------------------- /tests/utils/test_csv.py: -------------------------------------------------------------------------------- 1 | from PyExpUtils.utils.csv import arrayToCsv, buildCsvHeader, buildCsvParams 2 | from PyExpUtils.models.ExperimentDescription import ExperimentDescription 3 | import unittest 4 | import numpy as np 5 | 6 | class TestCsv(unittest.TestCase): 7 | def fakeDoubleDescription(self): 8 | return { 9 | 'metaParameters': { 10 | 'alpha': [0.01, 0.02, 0.04], 11 | 'epsilon': 0.05, 12 | 'gamma': [0.9] 13 | }, 14 | 'envParameters': { 15 | 'size': 30, 16 | 'noise': [0.01, 0.02], 17 | } 18 | } 19 | 20 | def fakeDescription(self): 21 | return { 22 | 'metaParameters': { 23 | 'alpha': [0.01, 0.02, 0.04], 24 | 'epsilon': 0.05, 25 | 'gamma': [0.9] 26 | } 27 | } 28 | 29 | def test_buildCsvParams(self): 30 | exp = ExperimentDescription(self.fakeDescription()) 31 | 32 | got = buildCsvParams(exp, 0) 33 | expected = '0.01,0.05,0.9' 34 | 35 | self.assertEqual(got, expected) 36 | 37 | exp = ExperimentDescription(self.fakeDoubleDescription(), keys=['metaParameters', 'envParameters']) 38 | 39 | got = buildCsvParams(exp, 1) 40 | expected = '0.02,30,0.01,0.05,0.9' 41 | 42 | self.assertEqual(got, expected) 43 | 44 | def test_buildCsvHeader(self): 45 | exp = ExperimentDescription(self.fakeDescription()) 46 | 47 | got = buildCsvHeader(exp) 48 | expected = 'alpha,epsilon,gamma' 49 | 50 | self.assertEqual(got, expected) 51 | 52 | exp = ExperimentDescription(self.fakeDoubleDescription(), keys=['metaParameters', 'envParameters']) 53 | 54 | got = buildCsvHeader(exp) 55 | expected = 'envParameters.noise,envParameters.size,metaParameters.alpha,metaParameters.epsilon,metaParameters.gamma' 56 | 57 | self.assertEqual(got, expected) 58 | 59 | def 
test_arrayToCsv(self): 60 | data = np.arange(5) / 4 61 | 62 | got = arrayToCsv(data, 1) 63 | expected = '0.0,0.2,0.5,0.8,1.0' 64 | 65 | self.assertEqual(got, expected) 66 | -------------------------------------------------------------------------------- /tests/utils/test_dict.py: -------------------------------------------------------------------------------- 1 | from PyExpUtils.utils.dict import equal, flatKeys, get, hyphenatedStringify, merge, partialEqual, pick, subset 2 | import unittest 3 | 4 | class TestDict(unittest.TestCase): 5 | def test_merge(self): 6 | # base functionality 7 | d1 = { 8 | 'a': [1, 2, 3], 9 | 'b': False, 10 | 'c': { 11 | 'aa': [4, 5, 6], 12 | }, 13 | } 14 | 15 | d2 = { 16 | 'b': True, 17 | 'd': 22, 18 | } 19 | 20 | got = merge(d1, d2) 21 | expected = { 22 | 'a': [1, 2, 3], 23 | 'b': True, 24 | 'c': { 25 | 'aa': [4, 5, 6], 26 | }, 27 | 'd': 22, 28 | } 29 | 30 | self.assertDictEqual(got, expected) 31 | 32 | def test_hyphenateStringify(self): 33 | # base functionality 34 | d = { 35 | 'alpha': 0.1, 36 | 'beta': 0.99, 37 | 'gamma': 1, 38 | 'optimizer': 'SuperFast', 39 | } 40 | 41 | got = hyphenatedStringify(d) 42 | expected = 'alpha-0.1_beta-0.99_gamma-1_optimizer-SuperFast' 43 | 44 | self.assertEqual(got, expected) 45 | 46 | # keys are sorted alphabetically 47 | d = { 48 | 'beta': 0.95, 49 | 'alpha': -0.5, 50 | } 51 | 52 | got = hyphenatedStringify(d) 53 | expected = 'alpha--0.5_beta-0.95' 54 | 55 | self.assertEqual(got, expected) 56 | 57 | # nested dictionaries are iterated 58 | d = { 59 | 'epsilon': 0.1, 60 | 'optimizer': { 61 | 'alpha': 0.9, 62 | 'beta': 0.99, 63 | } 64 | } 65 | 66 | got = hyphenatedStringify(d) 67 | expected = 'epsilon-0.1_optimizer.alpha-0.9_optimizer.beta-0.99' 68 | self.assertEqual(got, expected) 69 | 70 | def test_pick(self): 71 | # base functionality 72 | d = { 73 | 'a': 1, 74 | 'b': 22, 75 | 'c': 333, 76 | } 77 | 78 | got = pick(d, 'a') 79 | expected = 1 80 | 81 | self.assertEqual(got, expected) 82 | 83 | # multiple 
keys 84 | d = { 85 | 'a': 1, 86 | 'b': 22, 87 | 'c': 333, 88 | } 89 | 90 | got = pick(d, ['a', 'b']) 91 | expected = { 92 | 'a': 1, 93 | 'b': 22, 94 | } 95 | 96 | self.assertDictEqual(got, expected) 97 | 98 | def test_get(self): 99 | # base functionality 100 | d = { 101 | 'a': { 102 | 'b': 2, 103 | }, 104 | 'b': 3, 105 | 'c': [{ 'd': 4 }], 106 | 'd': { 107 | 'e': [5, 4, 3, 2, 1, 0], 108 | }, 109 | } 110 | 111 | got = get(d, 'a') 112 | expected = { 'b': 2 } 113 | self.assertDictEqual(got, expected) 114 | 115 | got = get(d, 'a.b') 116 | expected = 2 117 | self.assertEqual(got, expected) 118 | 119 | got = get(d, 'b') 120 | expected = 3 121 | self.assertEqual(got, expected) 122 | 123 | got = get(d, 'c.[0].d') 124 | expected = 4 125 | self.assertEqual(got, expected) 126 | 127 | got = get(d, 'd.e.[3]') 128 | expected = 2 129 | self.assertEqual(got, expected) 130 | 131 | got = get(d, 'd.f.[3]', 'merp') 132 | expected = 'merp' 133 | self.assertEqual(got, expected) 134 | 135 | got = get(d, 'd.e.[10]', 'merp') 136 | expected = 'merp' 137 | self.assertEqual(got, expected) 138 | 139 | got = get(d, 'd.e') 140 | expected = [5, 4, 3, 2, 1, 0] 141 | self.assertListEqual(got, expected) 142 | 143 | def test_equal(self): 144 | # base functionality 145 | d1 = { 146 | 'a': [1, 2, 3], 147 | 'b': 'a', 148 | 'c': { 149 | 'aa': 22, 150 | }, 151 | } 152 | 153 | d2 = { 154 | 'a': [1, 2, 3], 155 | 'b': 'a', 156 | 'c': { 157 | 'aa': 22, 158 | }, 159 | } 160 | 161 | got = equal(d1, d2) 162 | self.assertTrue(got) 163 | 164 | # missing keys 165 | d1 = { 166 | 'a': 22, 167 | 'b': 'a', 168 | } 169 | 170 | d2 = { 171 | 'a': 22, 172 | } 173 | 174 | got = equal(d1, d2) 175 | self.assertFalse(got) 176 | 177 | # ignore keys 178 | d1 = { 179 | 'a': 22, 180 | 'b': 'a', 181 | } 182 | 183 | d2 = { 184 | 'a': 22, 185 | } 186 | 187 | got = equal(d1, d2, ['b']) 188 | self.assertTrue(got) 189 | 190 | # d1 missing keys 191 | d1 = { 192 | 'a': 22, 193 | } 194 | 195 | d2 = { 196 | 'a': 22, 197 | 'b': 'a', 198 | } 
199 | 200 | got = equal(d1, d2) 201 | self.assertFalse(got) 202 | 203 | def test_flatKeys(self): 204 | d = { 205 | 'a': 22, 206 | 'b': { 207 | 'a': 'hey', 208 | 'b': { 209 | 'a': 2, 210 | }, 211 | }, 212 | } 213 | 214 | got = flatKeys(d) 215 | expected = [ 'a', 'b.a', 'b.b.a' ] 216 | 217 | self.assertListEqual(got, expected) 218 | 219 | d = { 220 | 'a': 22, 221 | 'b': [ 222 | { 'a': 'hi' }, 223 | { 'a': 'there' }, 224 | { 'b': 'friend' }, 225 | ], 226 | } 227 | 228 | got = flatKeys(d) 229 | expected = [ 'a', 'b.[0].a', 'b.[1].a', 'b.[2].b' ] 230 | 231 | self.assertListEqual(got, expected) 232 | 233 | def test_subset(self): 234 | d1 = { 235 | 'a': 22, 236 | } 237 | 238 | d2 = { 239 | 'a': 22, 240 | 'b': 23, 241 | } 242 | 243 | got = subset(d1, d2) 244 | self.assertTrue(got) 245 | 246 | got = subset(d2, d1) 247 | self.assertFalse(got) 248 | 249 | d1 = { 250 | 'a': 22, 251 | 'b': { 252 | 'a': 21, 253 | }, 254 | } 255 | 256 | d2 = { 257 | 'a': 22, 258 | 'b': { 259 | 'a': 21, 260 | 'b': 'hey', 261 | }, 262 | } 263 | 264 | got = subset(d1, d2) 265 | self.assertTrue(got) 266 | 267 | got = subset(d2, d1) 268 | self.assertFalse(got) 269 | 270 | def test_partialEqual(self): 271 | d1 = {} 272 | d2 = { 'a': 22 } 273 | 274 | got = partialEqual(d1, d2) 275 | self.assertTrue(got) 276 | 277 | got = partialEqual(d2, d1) 278 | self.assertTrue(got) 279 | 280 | d1 = { 281 | 'a': 22, 282 | 'c': 31, 283 | } 284 | 285 | d2 = { 286 | 'a': 22, 287 | 'b': 31, 288 | } 289 | 290 | got = partialEqual(d1, d2) 291 | self.assertTrue(got) 292 | 293 | got = partialEqual(d2, d1) 294 | self.assertTrue(got) 295 | -------------------------------------------------------------------------------- /tests/utils/test_generator.py: -------------------------------------------------------------------------------- 1 | from PyExpUtils.utils.generator import group, windowAverage 2 | import unittest 3 | 4 | class TestGenerator(unittest.TestCase): 5 | def test_group(self): 6 | arr = [1, 2, 3, 4, 5, 6, 7, 8] 7 | 8 | 
got = list(group(arr, 3)) 9 | expected = [ 10 | [1, 2, 3], 11 | [4, 5, 6], 12 | [7, 8], # last group may not be same size as rest 13 | ] 14 | 15 | self.assertEqual(got, expected) 16 | 17 | def test_windowAverage(self): 18 | arr = [1, 2, 3, 4, 5, 6, 7, 8] 19 | 20 | got = list(windowAverage(arr, 3)) 21 | expected = [2, 5, 7.5] 22 | 23 | self.assertEqual(got, expected) 24 | -------------------------------------------------------------------------------- /tests/utils/test_path.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from PyExpUtils.utils.path import fileName, join, rest, up 3 | 4 | class TestPath(unittest.TestCase): 5 | def test_rest(self): 6 | test_path = 'this/is/a/test' 7 | 8 | got = rest(test_path) 9 | expected = 'is/a/test' 10 | 11 | self.assertEqual(got, expected) 12 | 13 | test_path = '/this/is/a/test' 14 | 15 | got = rest(test_path) 16 | expected = 'this/is/a/test' 17 | 18 | self.assertEqual(got, expected) 19 | 20 | def test_up(self): 21 | test_path = 'this/is/a/test' 22 | 23 | got = up(test_path) 24 | expected = 'this/is/a' 25 | self.assertEqual(got, expected) 26 | 27 | test_path = '/this/is/a/test' 28 | 29 | got = up(test_path) 30 | expected = '/this/is/a' 31 | self.assertEqual(got, expected) 32 | 33 | def test_fileName(self): 34 | test_path = 'this/is/a/test' 35 | 36 | got = fileName(test_path) 37 | expected = 'test' 38 | self.assertEqual(got, expected) 39 | 40 | test_path = '/this/is/a/test.txt' 41 | 42 | got = fileName(test_path) 43 | expected = 'test.txt' 44 | self.assertEqual(got, expected) 45 | 46 | test_path = 'test.txt' 47 | 48 | got = fileName(test_path) 49 | expected = 'test.txt' 50 | self.assertEqual(got, expected) 51 | 52 | def test_join(self): 53 | test_parts = ['/this', 'is', 'a/', 'test'] 54 | 55 | got = join(*test_parts) 56 | expected = '/this/is/a/test' 57 | self.assertEqual(got, expected) 58 | 59 | test_parts = ['this', '//is/', '/a/', 'test/'] 60 | 61 | got = 
join(*test_parts) 62 | expected = 'this/is/a/test' 63 | self.assertEqual(got, expected) 64 | 65 | test_parts = ['this/is', 'a', 'test'] 66 | 67 | got = join(*test_parts) 68 | expected = 'this/is/a/test' 69 | self.assertEqual(got, expected) 70 | -------------------------------------------------------------------------------- /tests/utils/test_permute.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | import unittest 3 | from PyExpUtils.utils.permute import PathDict, getNumberOfPermutations, getParameterPermutation, reconstructParameters 4 | 5 | class TestPermute(unittest.TestCase): 6 | def test_getParameterPermutation(self): 7 | # base functionality 8 | d = { 9 | 'alpha': [1.0, 0.5, 0.25, 0.125], 10 | 'beta': [0.2, 0.4, 0.6], 11 | } 12 | 13 | got = getParameterPermutation(d, 0) 14 | expected = { 15 | 'alpha': 1.0, 16 | 'beta': 0.2, 17 | } 18 | self.assertDictEqual(got, expected) 19 | 20 | got = getParameterPermutation(d, 1) 21 | expected = { 22 | 'alpha': 0.5, 23 | 'beta': 0.2, 24 | } 25 | self.assertDictEqual(got, expected) 26 | 27 | got = getParameterPermutation(d, 4) 28 | expected = { 29 | 'alpha': 1.0, 30 | 'beta': 0.4, 31 | } 32 | self.assertDictEqual(got, expected) 33 | 34 | # nested objects 35 | d = { 36 | 'alpha': [1.0, 0.5], 37 | 'optimizer': { 38 | 'type': ['SGD', 'SuperGood'], 39 | 'beta': 0.1, 40 | } 41 | } 42 | 43 | got = getParameterPermutation(d, 0) 44 | expected = { 45 | 'alpha': 1.0, 46 | 'optimizer': { 47 | 'type': 'SGD', 48 | 'beta': 0.1, 49 | }, 50 | } 51 | self.assertDictEqual(got, expected) 52 | 53 | got = getParameterPermutation(d, 2) 54 | expected = { 55 | 'alpha': 1.0, 56 | 'optimizer': { 57 | 'type': 'SuperGood', 58 | 'beta': 0.1, 59 | }, 60 | } 61 | self.assertDictEqual(got, expected) 62 | 63 | # array of objects 64 | d = { 65 | 'alpha': [1.0, 0.5], 66 | 'layers': [ 67 | { 'type': 'Linear', 'units': [2, 4, 8] }, 68 | { 'type': 'Tanh', 'units': [2, 3, 4] }, 69 | ], 70 | } 71 
| 72 | got = getParameterPermutation(d, 0) 73 | expected = { 74 | 'alpha': 1.0, 75 | 'layers': [ 76 | { 'type': 'Linear', 'units': 2 }, 77 | { 'type': 'Tanh', 'units': 2 }, 78 | ], 79 | } 80 | self.assertDictEqual(got, expected) 81 | 82 | got = getParameterPermutation(d, 2) 83 | expected = { 84 | 'alpha': 1.0, 85 | 'layers': [ 86 | { 'type': 'Linear', 'units': 4 }, 87 | { 'type': 'Tanh', 'units': 2 }, 88 | ], 89 | } 90 | self.assertDictEqual(got, expected) 91 | 92 | def test_reconstructParameters(self): 93 | # can reconstruct nested objects from paths 94 | path_dict = cast(PathDict, { 95 | 'metaParameters.alpha': 0.4, 96 | 'metaParameters.beta': 0.2, 97 | 'metaParameters.optimizer.type': 'SGD', 98 | }) 99 | 100 | got = reconstructParameters(path_dict) 101 | expected = { 102 | 'metaParameters': { 103 | 'alpha': 0.4, 104 | 'beta': 0.2, 105 | 'optimizer': { 106 | 'type': 'SGD', 107 | }, 108 | }, 109 | } 110 | self.assertDictEqual(got, expected) 111 | 112 | # can reconstruct lists from paths 113 | path_dict = cast(PathDict, { 114 | 'alpha': 0.1, 115 | 'layers.[0].type': 'SGD', 116 | }) 117 | 118 | got = reconstructParameters(path_dict) 119 | expected = { 120 | 'alpha': 0.1, 121 | 'layers': [ 122 | { 'type': 'SGD' }, 123 | ], 124 | } 125 | self.assertDictEqual(got, expected) 126 | 127 | def test_getNumberOfPermutations(self): 128 | d = { 129 | 'alpha': [1, 2, 3], 130 | 'beta': [4, 3, 2], 131 | 'optimizers': { 132 | 'type': 'momentum', 133 | 'beta': [0.99, 0.98, 0.975], 134 | }, 135 | } 136 | 137 | got = getNumberOfPermutations(d) 138 | expected = 27 139 | self.assertEqual(got, expected) 140 | -------------------------------------------------------------------------------- /tests/utils/test_random.py: -------------------------------------------------------------------------------- 1 | import numba.typed 2 | import unittest 3 | import numpy as np 4 | from PyExpUtils.utils.random import argmax, choice, sample 5 | 6 | class TestRandom(unittest.TestCase): 7 | def 
test_sample(self): 8 | rng = np.random.default_rng(0) 9 | # base functionality 10 | arr = np.array([.50, .20, .10, .10, .10]) 11 | 12 | got = sample(arr, rng) 13 | expected = 1 # an index from 0-4 14 | 15 | self.assertEqual(got, expected) 16 | 17 | arr = np.array([.01, .01, .08, .9]) 18 | 19 | got = sample(arr, rng) 20 | expected = 3 # an index from 0-3 21 | 22 | self.assertEqual(got, expected) 23 | 24 | def test_choice(self): 25 | rng = np.random.default_rng(0) 26 | 27 | arr = numba.typed.List(['a', 'b', 'c']) 28 | 29 | got = choice(arr, rng) 30 | expected = 'c' # one of the three elements 31 | 32 | self.assertEqual(got, expected) 33 | 34 | counts = {'a': 0, 'b': 0, 'c': 0} 35 | for _ in range(10000): 36 | element = choice(arr, rng) 37 | counts[element] += 1 38 | 39 | # super fragile 40 | # TODO: make this a statistical test for uniformity 41 | self.assertDictEqual(counts, { 'a': 3309, 'b': 3389, 'c': 3302 }) 42 | 43 | def test_argmax(self): 44 | rng = np.random.default_rng(0) 45 | 46 | arr = np.array([3, 2, 3]) 47 | 48 | got = argmax(arr, rng) 49 | expected = 0 # either 0 or 2 50 | 51 | self.assertEqual(got, expected) 52 | 53 | counts = [0, 0, 0] 54 | for _ in range(10000): 55 | got = argmax(arr, rng) 56 | counts[got] += 1 57 | 58 | # TODO: make this a statistical test for uniformity 59 | self.assertEqual(counts, [4971, 0, 5029]) 60 | -------------------------------------------------------------------------------- /tests/utils/test_str.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from PyExpUtils.utils.str import interpolate 3 | 4 | class TestStr(unittest.TestCase): 5 | def test_interpolate(self): 6 | key = 'results/{name}/{run}/data' 7 | d = { 8 | 'name': 'johnny', 9 | 'run': 0, 10 | } 11 | 12 | got = interpolate(key, d) 13 | expected = 'results/johnny/0/data' 14 | 15 | self.assertEqual(got, expected) 16 | -------------------------------------------------------------------------------- 
/typings/h5py/__init__.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | class File: 4 | def __init__(self, path: str, mode: str) -> None: ... 5 | 6 | # context manager 7 | def __enter__(self) -> File: ... # noqa: F821 8 | def __exit__(self, type: Any, value: Any, traceback: Any) -> None: ... 9 | 10 | # groups 11 | def create_group(self, name: str) -> Group: ... # noqa: F821 12 | 13 | # act like dict 14 | def __getitem__(self, name: str) -> Any: ... # noqa: F821 15 | def __contains__(self, name: str) -> bool: ... 16 | def keys(self) -> List[str]: ... 17 | 18 | # act like file 19 | def close(self) -> None: ... 20 | 21 | class Group: 22 | # datasets 23 | def create_dataset(self, name: str, data: Any = ..., compression: str = ...) -> None: ... 24 | 25 | class Dataset: 26 | pass 27 | -------------------------------------------------------------------------------- /typings/numba/__init__.pyi: -------------------------------------------------------------------------------- 1 | # from typing import Callable 2 | from typing import Callable, overload 3 | from PyExpUtils.utils.types import T 4 | 5 | @overload 6 | def njit(cache: bool = False, parallel: bool = False, nogil: bool = False, fastmath: bool = False, inline: str = 'always') -> Callable[[T], T]: ... 7 | @overload 8 | def njit(f: T, cache: bool = False, parallel: bool = False, nogil: bool = False, fastmath: bool = False) -> T: ... 9 | 10 | def jit(cache: bool, forceobj: bool) -> Callable[[T], T]: ... 11 | -------------------------------------------------------------------------------- /typings/numba/experimental/__init__.pyi: -------------------------------------------------------------------------------- 1 | from PyExpUtils.utils.types import T 2 | 3 | def jitclass(f: T) -> T: ... 
4 | -------------------------------------------------------------------------------- /typings/numba/typed/__init__.pyi: -------------------------------------------------------------------------------- 1 | from typing import List as PyList 2 | 3 | class List(PyList): 4 | def __init__(self, l: PyList): ... 5 | --------------------------------------------------------------------------------