├── mle_logging ├── _version.py ├── load │ ├── __init__.py │ ├── load_model.py │ └── load_log.py ├── merge │ ├── __init__.py │ ├── merge_logs.py │ ├── merge_hdf5.py │ └── aggregate.py ├── save │ ├── __init__.py │ ├── figure_log.py │ ├── extra_log.py │ ├── tboard_log.py │ ├── stats_log.py │ ├── wandb_log.py │ └── model_log.py ├── __init__.py ├── utils │ ├── __init__.py │ ├── comms.py │ └── helpers.py ├── meta_log.py └── mle_logger.py ├── requirements ├── requirements-examples.txt ├── requirements-test.txt └── requirements.txt ├── tests ├── fixtures │ ├── eval_0.yaml │ └── logs │ │ └── seed_aggregated.hdf5 ├── test_utils.py ├── test_logger.py ├── test_reload.py ├── test_model.py └── test_load.py ├── docs ├── logo_transparent.png ├── mle_logger_structure.png ├── config_1.json └── doc_snippet.py ├── examples ├── config_1.json └── config_2.json ├── .codecov.yml ├── .github └── workflows │ ├── pypi_publish.yaml │ └── run_tests.yaml ├── LICENSE ├── .gitignore ├── setup.py ├── CONTRIBUTING.md ├── CHANGELOG.md └── README.md /mle_logging/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.6" 2 | -------------------------------------------------------------------------------- /requirements/requirements-examples.txt: -------------------------------------------------------------------------------- 1 | torch 2 | matplotlib 3 | seaborn -------------------------------------------------------------------------------- /tests/fixtures/eval_0.yaml: -------------------------------------------------------------------------------- 1 | arch: mlp 2 | batch_size: 3 3 | lrate: 0.360379148648584 4 | -------------------------------------------------------------------------------- /docs/logo_transparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mle-infrastructure/mle-logging/HEAD/docs/logo_transparent.png -------------------------------------------------------------------------------- /docs/mle_logger_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mle-infrastructure/mle-logging/HEAD/docs/mle_logger_structure.png -------------------------------------------------------------------------------- /requirements/requirements-test.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | seaborn 3 | torch 4 | tensorboard 5 | tensorflow 6 | dm-haiku 7 | scikit-learn 8 | jax -------------------------------------------------------------------------------- /tests/fixtures/logs/seed_aggregated.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mle-infrastructure/mle-logging/HEAD/tests/fixtures/logs/seed_aggregated.hdf5 -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | commentjson 2 | pyyaml>=5.1 3 | numpy 4 | h5py 5 | dotmap 6 | pickle5; python_version < '3.8' 7 | rich 8 | pandas 9 | wandb 10 | -------------------------------------------------------------------------------- /mle_logging/load/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_log import load_log, load_meta_log 2 | from .load_model import load_model 3 | 4 | __all__ = ["load_log", "load_meta_log", "load_model"] 5 | 
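The `load` helpers above are also re-exported at the package root. A minimal usage sketch of reloading a finished run (the experiment directory and checkpoint filename are placeholders that mirror the paths asserted in tests/test_logger.py):

from mle_logging import load_log, load_model

# Reload the merged .hdf5 log of a finished run and plot a tracked statistic
log = load_log("experiment_dir/", aggregate_seeds=True)
log.plot("train_loss", "num_updates")

# Reload a stored checkpoint; with model_type="torch" and no model instance,
# the raw checkpoint object is returned instead of a loaded model
ckpt = load_model(
    "experiment_dir/models/final/final_no_seed_provided.pt", model_type="torch"
)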
-------------------------------------------------------------------------------- /mle_logging/merge/__init__.py: -------------------------------------------------------------------------------- 1 | from .merge_hdf5 import merge_hdf5_files 2 | from .merge_logs import merge_seed_logs, merge_config_logs 3 | 4 | __all__ = ["merge_hdf5_files", "merge_seed_logs", "merge_config_logs"] 5 | -------------------------------------------------------------------------------- /docs/config_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_config": {"lrate": 0.1}, 3 | "model_config": {"num_layers": 5}, 4 | "log_config": {"time_to_track": ["step_counter"], 5 | "what_to_track": ["loss"], 6 | "time_to_print": ["step_counter"], 7 | "what_to_print": ["loss"], 8 | "print_every_k_updates": 10, 9 | "overwrite_experiment_dir": 1} 10 | } 11 | -------------------------------------------------------------------------------- /examples/config_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_config": {"lrate": 0.1}, 3 | "model_config": {"num_layers": 5}, 4 | "log_config": {"time_to_track": ["step_counter"], 5 | "what_to_track": ["loss"], 6 | "time_to_print": ["step_counter"], 7 | "what_to_print": ["loss"], 8 | "print_every_k_updates": 10, 9 | "overwrite_experiment_dir": 1} 10 | } 11 | -------------------------------------------------------------------------------- /examples/config_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_config": {"lrate": 0.1}, 3 | "model_config": {"num_layers": 5}, 4 | "log_config": {"time_to_track": ["step_counter"], 5 | "what_to_track": ["loss"], 6 | "time_to_print": ["step_counter"], 7 | "what_to_print": ["loss"], 8 | "print_every_k_updates": 10, 9 | "overwrite_experiment_dir": 1} 10 | } 11 | -------------------------------------------------------------------------------- /mle_logging/save/__init__.py: -------------------------------------------------------------------------------- 1 | from .stats_log import StatsLog 2 | from .tboard_log import TboardLog 3 | from .wandb_log import WandbLog 4 | from .model_log import ModelLog 5 | from .extra_log import ExtraLog 6 | from .figure_log import FigureLog 7 | 8 | 9 | __all__ = [ 10 | "StatsLog", 11 | "TboardLog", 12 | "WandbLog", 13 | "ModelLog", 14 | "ExtraLog", 15 | "FigureLog", 16 | ] 17 | -------------------------------------------------------------------------------- /mle_logging/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import __version__ 2 | from .mle_logger import MLELogger 3 | from .load import load_log, load_model 4 | from .utils import load_config 5 | from .merge import merge_config_logs, merge_seed_logs 6 | 7 | 8 | __all__ = [ 9 | "__version__", 10 | "MLELogger", 11 | "load_log", 12 | "load_model", 13 | "load_config", 14 | "merge_config_logs", 15 | "merge_seed_logs", 16 | ] 17 | -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | 4 | coverage: 5 | precision: 2 6 | round: down 7 | range: "30...70" 8 | 9 | coverage: 10 | status: 11 | project: # settings affecting project coverage 12 | default: 13 | threshold: 5% # allow for 5% reduction of coverage without failing 14 | 15 | comment: 16 | layout: "reach, diff, files" 17 | behavior: 
default 18 | require_changes: true 19 | 20 | ignore: 21 | - "tests/.*" 22 | - "_version.py" 23 | - "setup.py" 24 | - "examples/*" 25 | - "docs/" 26 | -------------------------------------------------------------------------------- /mle_logging/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .helpers import ( 2 | load_config, 3 | write_to_hdf5, 4 | visualize_1D_lcurves, 5 | load_pkl_object, 6 | save_pkl_object, 7 | ) 8 | from .comms import ( 9 | print_welcome, 10 | print_startup, 11 | print_update, 12 | print_reload, 13 | print_storage, 14 | ) 15 | 16 | __all__ = [ 17 | "load_config", 18 | "write_to_hdf5", 19 | "visualize_1D_lcurves", 20 | "load_pkl_object", 21 | "save_pkl_object", 22 | "print_welcome", 23 | "print_startup", 24 | "print_update", 25 | "print_reload", 26 | "print_storage", 27 | ] 28 | -------------------------------------------------------------------------------- /.github/workflows/pypi_publish.yaml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v1 12 | - name: Set up Python 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: '3.x' 16 | - name: Install, Build and publish 17 | env: 18 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 19 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install setuptools wheel twine 23 | python setup.py sdist bdist_wheel 24 | twine upload dist/* 25 | -------------------------------------------------------------------------------- /docs/doc_snippet.py: -------------------------------------------------------------------------------- 1 | from mle_logging import MLELogger 2 | import torch.nn as nn 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | class DummyModel(nn.Module): 7 | def __init__(self): 8 | super(DummyModel, self).__init__() 9 | self.fc1 = nn.Linear(28*28, 10) 10 | 11 | def forward(self, x): 12 | x = self.fc1(x) 13 | return x 14 | 15 | 16 | model = DummyModel() 17 | fig, ax = plt.subplots() 18 | extra = {"hi": "there"} 19 | 20 | 21 | def run(): 22 | log = MLELogger(time_to_track=['num_updates', 'num_epochs'], 23 | what_to_track=['train_loss', 'test_loss'], 24 | experiment_dir="experiment_dir/", 25 | model_type="torch", 26 | config_fname="config_1.json", 27 | seed_id=1, 28 | verbose=True) 29 | log.update({'num_updates': 10, 'num_epochs': 1}, 30 | {'train_loss': 0.1234, 'test_loss': 0.1235}, 31 | model, fig, extra, save=False) 32 | 33 | 34 | 35 | if __name__ == "__main__": 36 | run() 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 The Python Packaging Authority 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /mle_logging/load/load_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ..utils import load_pkl_object 3 | 4 | 5 | def load_model(ckpt_path: str, model_type: str, model=None): 6 | """Helper to reload stored checkpoint/pkl & return trained model.""" 7 | if model_type == "torch": 8 | try: 9 | import torch 10 | except ModuleNotFoundError as err: 11 | raise ModuleNotFoundError( 12 | f"{err}. You need to install " 13 | "`torch` if you want to save a model " 14 | "checkpoint." 15 | ) 16 | 17 | checkpoint = torch.load(ckpt_path, map_location="cpu") 18 | if model is not None: 19 | # raise ValueError("Please provide a torch model instance.") 20 | model.load_state_dict(checkpoint) 21 | return model 22 | else: 23 | return checkpoint 24 | elif model_type == "tensorflow": 25 | model.load_weights(ckpt_path) 26 | elif model_type in ["jax", "sklearn"]: 27 | model = load_pkl_object(ckpt_path) 28 | return model 29 | elif model_type == "numpy": 30 | model = np.load(ckpt_path, allow_pickle=True) 31 | return model 32 | else: 33 | raise ValueError( 34 | "Please provide a valid model type ('torch', 'jax'," " 'sklearn', 'numpy')." 35 | ) 36 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yaml: -------------------------------------------------------------------------------- 1 | 2 | name: Python tests 3 | 4 | on: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | branches: [ main ] 9 | 10 | jobs: 11 | test: 12 | name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }} 13 | strategy: 14 | matrix: 15 | os: [ubuntu-latest] 16 | python: ['3.7', '3.8', '3.9'] 17 | runs-on: ${{ matrix.os }} 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python }} 22 | uses: actions/setup-python@v1 23 | with: 24 | python-version: ${{ matrix.python }} 25 | - name: Install testing and linting dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install pytest pytest-timeout pytest-cov 29 | pip install flake8 black 30 | pip install -r requirements/requirements-test.txt 31 | pip install -e . 32 | - name: Lint with flake8 33 | run: | 34 | # stop the build if there are Python syntax errors or undefined names 35 | flake8 ./mle_logging --count --select=E9,F63,F7,F82 --show-source --statistics 36 | # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide 37 | flake8 ./mle_logging --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 38 | - name: Run unit/integration tests 39 | run: | 40 | pytest -vv --durations=0 --cov=./ --cov-report=term-missing --cov-report=xml 41 | - name: "Upload coverage to Codecov" 42 | uses: codecov/codecov-action@v2 43 | with: 44 | token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos 45 | fail_ci_if_error: true # optional (default = false) 46 | verbose: true # optional (default = false) 47 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import matplotlib.pyplot as plt 4 | from mle_logging import MLELogger, load_log, load_config 5 | from mle_logging.utils import visualize_1D_lcurves 6 | 7 | 8 | log_config = { 9 | "time_to_track": ["num_updates", "num_epochs"], 10 | "what_to_track": ["train_loss", "test_loss"], 11 | "experiment_dir": "experiment_dir/", 12 | "config_fname": None, 13 | "use_tboard": True, 14 | "model_type": "torch", 15 | } 16 | 17 | time_tic = {"num_updates": 10, "num_epochs": 1} 18 | stats_tic = {"train_loss": 0.1234, "test_loss": 0.1235} 19 | 20 | 21 | class DummyModel(nn.Module): 22 | def __init__(self): 23 | super(DummyModel, self).__init__() 24 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 25 | self.fc2 = nn.Linear(120, 84) 26 | self.fc3 = nn.Linear(84, 10) 27 | 28 | def forward(self, x): 29 | x = self.fc1(x) 30 | x = self.fc2(x) 31 | x = self.fc3(x) 32 | return x 33 | 34 | 35 | model = DummyModel() 36 | 37 | fig, ax = plt.subplots() 38 | ax.plot(np.random.normal(0, 1, 20)) 39 | 40 | some_dict = {"hi": "there"} 41 | 42 | 43 | def test_comms(): 44 | """Test functional verbose statements.""" 45 | log = MLELogger(**log_config, verbose=True) 46 | log.update(time_tic, stats_tic, model, fig, some_dict, save=True) 47 | 48 | 49 | def test_load_config(): 50 | config = load_config("tests/fixtures/eval_0.yaml", True) 51 | assert config.lrate == 0.360379148648584 52 | 53 | 54 | def test_plot_lcurves(): 55 | # Load the merged log - Individual seeds can be accessed via log.seed_1, etc. 56 | log = load_log("tests/fixtures") 57 | log.plot("train_loss", "num_updates") 58 | 59 | log = load_log("tests/fixtures", aggregate_seeds=True) 60 | log.plot("train_loss", "num_updates") 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.gif 2 | *.zip 3 | .vim-arsync 4 | .DS_Store 5 | __pycache__ 6 | .sync-config.cson 7 | .ipynb_checkpoints 8 | *.egg-info 9 | data/ 10 | *.key 11 | tboards/ 12 | *.ckpt 13 | .pytest_cache/ 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | experiment_dir 19 | # C extensions 20 | *.so 21 | examples/multi_seed_dir 22 | examples/multi_config_dir 23 | examples/every_k_dir 24 | examples/top_k_dir 25 | examples/post_plot_dir 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | -------------------------------------------------------------------------------- /mle_logging/merge/merge_logs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import Union 4 | from .merge_hdf5 import merge_hdf5_files 5 | 6 | 7 | def merge_seed_logs( 8 | merged_path: str, 9 | experiment_dir: str, 10 | num_logs: Union[int, None] = None, 11 | delete_files: bool = True, 12 | ) -> None: 13 | """Merge all .hdf5 files for different seeds into single log.""" 14 | # Collect paths in log dir until the num_logs is found 15 | log_dir = os.path.join(experiment_dir, "logs") 16 | while True: 17 | log_paths = [os.path.join(log_dir, log) for log in os.listdir(log_dir)] 18 | if num_logs is not None: 19 | if len(log_paths) == num_logs: 20 | # Delete joined log if at some point over-eagerly merged 21 | if merged_path in log_paths: 22 | os.remove(merged_path) 23 | break 24 | else: 25 | time.sleep(1) 26 | else: 27 | break 28 | merge_hdf5_files(merged_path, log_paths, delete_files=delete_files) 29 | 30 | 31 | def merge_config_logs(experiment_dir: str, all_run_ids: list) -> None: 32 | """Scavenge the experiment dictonaries & load in logs.""" 33 | all_folders = [x[0] for x in os.walk(experiment_dir)][1:] 34 | # Get rid of timestring in beginning & collect all folders/hdf5 files 35 | hyperp_results_folder = [] 36 | # Need to make sure that run_ids & experiment folder match! 
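    # The nested loop below pairs every requested run_id with the experiment
    # sub-folder whose basename matches it, so result folders are collected in
    # the order given by `all_run_ids` rather than the os.walk() order.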
37 | for run_id in all_run_ids: 38 | for f in all_folders: 39 | path, file = os.path.split(f) 40 | if file == run_id: 41 | hyperp_results_folder.append(f) 42 | continue 43 | # Collect all paths to the .hdf5 file 44 | log_paths = [] 45 | 46 | for i in range(len(hyperp_results_folder)): 47 | log_d_t = os.path.join(hyperp_results_folder[i], "logs/") 48 | for file in os.listdir(log_d_t): 49 | fname, fext = os.path.splitext(file) 50 | if file.endswith(".hdf5"): 51 | if fname in all_run_ids or fname == "log": 52 | log_paths.append(os.path.join(log_d_t, file)) 53 | 54 | # Merge individual run results into a single hdf5 file 55 | assert len(log_paths) == len(all_run_ids) 56 | 57 | meta_log_fname = os.path.join(experiment_dir, "meta_log.hdf5") 58 | merge_hdf5_files(meta_log_fname, log_paths, file_ids=all_run_ids) 59 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from setuptools import setup, find_packages 3 | except ImportError: 4 | from distutils.core import setup, find_packages 5 | 6 | import re 7 | import os 8 | from typing import List 9 | 10 | CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) 11 | 12 | with open(os.path.join(CURRENT_DIR, "README.md"), encoding="utf-8") as f: 13 | long_description = f.read() 14 | 15 | 16 | def parse_requirements(path: str) -> List[str]: 17 | with open(os.path.join(CURRENT_DIR, path)) as f: 18 | return [ 19 | line.rstrip() 20 | for line in f 21 | if not (line.isspace() or line.startswith("#")) 22 | ] 23 | 24 | 25 | VERSIONFILE = "mle_logging/_version.py" 26 | verstrline = open(VERSIONFILE, "rt").read() 27 | VSRE = r"^__version__ = ['\"]([^'\"]*)['\"]" 28 | mo = re.search(VSRE, verstrline, re.M) 29 | if mo: 30 | verstr = mo.group(1) 31 | else: 32 | raise RuntimeError("Unable to find version string in %s." 
% (VERSIONFILE,)) 33 | git_tar = f"https://github.com/mle-infrastructure/mle-logging/archive/v{verstr}.tar.gz" 34 | 35 | 36 | setup( 37 | name="mle_logging", 38 | version=verstr, 39 | author="Robert Tjarko Lange", 40 | author_email="robertlange0@gmail.com", 41 | description="Machine Learning Experiment Logging", 42 | long_description=long_description, 43 | long_description_content_type="text/markdown", 44 | url="https://github.com/mle-infrastructure/mle-logging", 45 | download_url=git_tar, 46 | classifiers=[ 47 | "Programming Language :: Python :: 3.6", 48 | "Programming Language :: Python :: 3.7", 49 | "Programming Language :: Python :: 3.8", 50 | "Programming Language :: Python :: 3.9", 51 | "License :: OSI Approved :: MIT License", 52 | "Operating System :: OS Independent", 53 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 54 | ], 55 | packages=find_packages(), 56 | include_package_data=True, 57 | zip_safe=False, 58 | platforms="any", 59 | python_requires=">=3.6", 60 | install_requires=parse_requirements( 61 | os.path.join(CURRENT_DIR, "requirements", "requirements.txt") 62 | ), 63 | tests_require=parse_requirements( 64 | os.path.join(CURRENT_DIR, "requirements", "requirements-test.txt") 65 | ), 66 | extras_require={ 67 | "examples": parse_requirements( 68 | os.path.join( 69 | CURRENT_DIR, "requirements", "requirements-examples.txt" 70 | ) 71 | ) 72 | }, 73 | ) 74 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to `mle-logging` 2 | We love your input! We want to make contributing to this project as easy and transparent as possible, whether it's: 3 | 4 | - Reporting a bug 5 | - Discussing the current state of the code 6 | - Submitting a fix 7 | - Proposing new features 8 | - Becoming a maintainer 9 | 10 | ## We Develop with Github 11 | We use github to host code, to track issues and feature requests, as well as accept pull requests. 12 | 13 | ## We Use [Github Flow](https://guides.github.com/introduction/flow/index.html), So All Code Changes Happen Through Pull Requests 14 | Pull requests are the best way to propose changes to the codebase (we use [Github Flow](https://guides.github.com/introduction/flow/index.html)). We actively welcome your pull requests: 15 | 16 | 1. Fork the repo and create your branch from `master`. 17 | 2. If you've added code that should be tested, add tests. 18 | 3. If you've changed APIs, update the documentation. 19 | 4. Ensure the test suite passes. 20 | 5. Make sure your code lints. 21 | 6. Issue that pull request! 22 | 23 | ## Any contributions you make will be under the MIT Software License 24 | In short, when you submit code changes, your submissions are understood to be under the same [MIT License](http://choosealicense.com/licenses/mit/) that covers the project. Feel free to contact the maintainers if that's a concern. 25 | 26 | ## Report bugs using Github's [issues](https://github.com/mle-infrastructure/mle-logging/issues) 27 | We use GitHub issues to track public bugs. Report a bug by [opening a new issue](); it's that easy! 28 | 29 | ## Write bug reports with detail, background, and sample code 30 | 31 | **Great Bug Reports** tend to have: 32 | 33 | - A quick summary and/or background 34 | - Steps to reproduce 35 | - Be specific! 36 | - Give sample code if you can. 
37 | - What you expected would happen 38 | - What actually happens 39 | - Notes (possibly including why you think this might be happening, or stuff you tried that didn't work) 40 | 41 | ## Use the Black Coding Style 42 | The codebase follows the [Black](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html) coding style. Using a autoformatter can make your life easier! 43 | 44 | ## License 45 | By contributing, you agree that your contributions will be licensed under its MIT License. 46 | 47 | ## References 48 | This document was adapted from the open-source contribution guidelines for [Facebook's Draft](https://github.com/facebook/draft-js/blob/a9316a723f9e918afde44dea68b5f9f39b7d9b00/CONTRIBUTING.md) and from the [Transcriptase adapted version](https://gist.github.com/briandk/3d2e8b3ec8daf5a27a62). 49 | -------------------------------------------------------------------------------- /mle_logging/save/figure_log.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import isfile, join 3 | from typing import Union, List 4 | 5 | 6 | class FigureLog(object): 7 | """Figure Logger Class Instance.""" 8 | 9 | def __init__( 10 | self, 11 | experiment_dir: str = "/", 12 | seed_id: str = "no_seed_provided", 13 | reload: bool = False, 14 | ): 15 | # Setup figure logging directories 16 | self.experiment_dir = experiment_dir 17 | self.figures_dir = os.path.join(self.experiment_dir, "figures/") 18 | self.seed_id = seed_id 19 | 20 | # Reload filenames and counter from previous execution 21 | if reload: 22 | self.reload() 23 | else: 24 | self.fig_save_counter = 0 25 | self.fig_storage_paths: List[str] = [] 26 | 27 | def save(self, fig, fig_fname: Union[str, None] = None) -> None: 28 | """Store a matplotlib figure.""" 29 | # Create new directory to store figures - if it doesn't exist yet 30 | self.fig_save_counter += 1 31 | if self.fig_save_counter == 1: 32 | os.makedirs(self.figures_dir, exist_ok=True) 33 | 34 | # Tick up counter, save figure, store new path to figure 35 | if fig_fname is None: 36 | figure_fname = os.path.join( 37 | self.figures_dir, 38 | "fig_" 39 | + str(self.fig_save_counter) 40 | + "_" 41 | + str(self.seed_id) 42 | + ".png", 43 | ) 44 | else: 45 | self.fig_save_counter -= 1 46 | figure_fname = os.path.join( 47 | self.figures_dir, 48 | fig_fname, 49 | ) 50 | 51 | # Create all subfolders if needed! 
52 | try: 53 | os.makedirs(os.path.dirname(figure_fname), exist_ok=True) 54 | except Exception: 55 | pass 56 | 57 | fig.savefig(figure_fname, dpi=300) 58 | self.fig_storage_paths.append(figure_fname) 59 | 60 | def reload(self): 61 | """Reload results from previous experiment run.""" 62 | # Go into figures directory, get list of figure files and set counter 63 | try: 64 | fig_paths = [ 65 | join(self.figures_dir, f) 66 | for f in os.listdir(self.figures_dir) 67 | if isfile(join(self.figures_dir, f)) 68 | ] 69 | self.fig_storage_paths = [ 70 | f for f in fig_paths if f.endswith(str(self.seed_id) + ".png") 71 | ] 72 | self.fig_save_counter = len(self.fig_storage_paths) 73 | except FileNotFoundError: 74 | self.fig_save_counter = 0 75 | self.fig_storage_paths: List[str] = [] 76 | -------------------------------------------------------------------------------- /mle_logging/merge/merge_hdf5.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Union 3 | import h5py 4 | import numpy as np 5 | 6 | 7 | def merge_hdf5_files( 8 | new_filename: str, 9 | log_paths: List[str], 10 | file_ids: Union[None, List[str]] = None, 11 | delete_files: bool = False, 12 | ) -> None: 13 | """Merges a set of hdf5 files into a new hdf5 file with more groups.""" 14 | file_to = h5py.File(new_filename, "w") 15 | for i, log_p in enumerate(log_paths): 16 | file_from = h5py.File(log_p, "r") 17 | datasets = get_datasets("/", file_from) 18 | if file_ids is None: 19 | write_data_to_file(file_to, file_from, datasets) 20 | else: 21 | # Maintain unique config id even if they have same random seed 22 | write_data_to_file(file_to, file_from, datasets, file_ids[i]) 23 | file_from.close() 24 | 25 | # Delete individual log file if desired 26 | if delete_files: 27 | os.remove(log_p) 28 | file_to.close() 29 | 30 | 31 | def get_datasets(key: str, archive: h5py.File): 32 | """Collects different paths to datasets in recursive fashion.""" 33 | if key[-1] != "/": 34 | key += "/" 35 | out = [] 36 | for name in archive[key]: 37 | path = key + name 38 | if isinstance(archive[path], h5py.Dataset): 39 | out += [path] 40 | else: 41 | out += get_datasets(path, archive) 42 | return out 43 | 44 | 45 | def write_data_to_file( 46 | file_to: h5py.File, 47 | file_from: h5py.File, 48 | datasets: List[str], 49 | file_id: Union[str, None] = None, 50 | ): 51 | """Writes the datasets from-to file.""" 52 | # get the group-names from the lists of datasets 53 | groups = list(set([i[::-1].split("/", 1)[1][::-1] for i in datasets])) 54 | if file_id is None: 55 | groups = [i for i in groups if len(i) > 0] 56 | else: 57 | groups = [i[0] + file_id + "/" + i[1:] for i in groups if len(i) > 0] 58 | 59 | # sort groups based on depth 60 | idx = np.argsort(np.array([len(i.split("/")) for i in groups])) 61 | groups = [groups[i] for i in idx] 62 | 63 | # create all groups that contain dataset that will be copied 64 | for group in groups: 65 | file_to.create_group(group) 66 | 67 | # copy datasets 68 | for path in datasets: 69 | # - get group name // - minimum group name // - copy data 70 | group = path[::-1].split("/", 1)[1][::-1] 71 | if len(group) == 0: 72 | group = "/" 73 | if file_id is not None: 74 | group_to_index = group[0] + file_id + "/" + group[1:] 75 | else: 76 | group_to_index = group 77 | file_from.copy(path, file_to[group_to_index]) 78 | 79 | file_from.close() 80 | -------------------------------------------------------------------------------- /mle_logging/save/extra_log.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from os.path import isfile, join 3 | from typing import Union, List 4 | from ..utils import save_pkl_object 5 | 6 | 7 | class ExtraLog(object): 8 | """Extra .pkl Object Logger Class Instance.""" 9 | 10 | def __init__( 11 | self, 12 | experiment_dir: str = "/", 13 | seed_id: str = "no_seed_provided", 14 | reload: bool = False, 15 | ): 16 | # Setup extra logging directories 17 | self.experiment_dir = experiment_dir 18 | self.extra_dir = os.path.join(self.experiment_dir, "extra/") 19 | self.seed_id = seed_id 20 | 21 | # Reload filenames and counter from previous execution 22 | if reload: 23 | self.reload() 24 | else: 25 | self.extra_save_counter = 0 26 | self.extra_storage_paths: List[str] = [] 27 | 28 | def save(self, obj, obj_fname: Union[str, None] = None): 29 | """Store a .pkl object.""" 30 | # Create new directory to store objects - if it doesn't exist yet 31 | self.extra_save_counter += 1 32 | if self.extra_save_counter == 1: 33 | os.makedirs(self.extra_dir, exist_ok=True) 34 | 35 | # Tick up counter, save figure, store new path to figure 36 | if obj_fname is None: 37 | obj_fname = os.path.join( 38 | self.extra_dir, 39 | "extra_" 40 | + str(self.extra_save_counter) 41 | + "_" 42 | + str(self.seed_id) 43 | + ".pkl", 44 | ) 45 | else: 46 | self.extra_save_counter -= 1 47 | obj_fname = os.path.join( 48 | self.extra_dir, 49 | obj_fname, 50 | ) 51 | 52 | # Create all subfolders if needed! 53 | try: 54 | os.makedirs(os.path.dirname(obj_fname), exist_ok=True) 55 | except Exception: 56 | pass 57 | 58 | save_pkl_object(obj, obj_fname) 59 | self.extra_storage_paths.append(obj_fname) 60 | 61 | def reload(self): 62 | """Reload results from previous experiment run.""" 63 | # Go into extra directory, get list of files and set counter 64 | try: 65 | extra_paths = [ 66 | join(self.extra_dir, f) 67 | for f in os.listdir(self.extra_dir) 68 | if isfile(join(self.extra_dir, f)) 69 | ] 70 | self.extra_storage_paths = [ 71 | f for f in extra_paths if f.endswith(str(self.seed_id) + ".pkl") 72 | ] 73 | self.extra_save_counter = len(self.extra_storage_paths) 74 | except FileNotFoundError: 75 | self.extra_save_counter = 0 76 | self.extra_storage_paths: List[str] = [] 77 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## [v0.0.6] - [08/2024] 2 | 3 | ### Fixed 4 | 5 | - Adds wandb requirement 6 | 7 | ## [v0.0.5] - [03/2023] 8 | 9 | ### Added 10 | 11 | - Adds new case to `MetaLog` loading: Single configuration with explicit seed. 12 | - Adds test coverage for `comms` and `utils`. 13 | - Adds support for storage of vector-valued stats in `log.update`. 14 | - Adds wandb backend. Based on options `use_wandb` and `wandb_config`. 
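A minimal sketch of enabling the backend, assuming `use_wandb` and `wandb_config` are passed straight to the `MLELogger` constructor (the required keys mirror the assertions in `save/wandb_log.py`; all values are placeholders):

```python
from mle_logging import MLELogger

log = MLELogger(time_to_track=['num_updates'],
                what_to_track=['train_loss'],
                experiment_dir="experiment_dir/",
                use_wandb=True,
                wandb_config={"key": "<wandb-api-key>",
                              "entity": "<wandb-entity>",
                              "project": "prototyping",
                              "group": None,
                              "name": "seed0"})
```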
15 | 16 | ### Changed 17 | 18 | - Changes and reduces requirements 19 | 20 | ### Fixed 21 | 22 | - Fixes different data types issues (#3) 23 | - Fixes string decoding and merging for `MetaLog` 24 | - Log aggregation for single seed/single configuration 25 | - Fixed package dependencies to include pandas 26 | 27 | ## [v0.0.4] - [12/07/2021] 28 | 29 | ### Added 30 | 31 | - Add plot detail options (title, labels) to `meta_log.plot()` 32 | 33 | ### Changed 34 | 35 | - Get rid of time string in sub directories 36 | 37 | ### Fixed 38 | 39 | - Makes log merging more robust 40 | - Small fixes for `mle-monitor` release 41 | - Fix `overwrite` and make verbose warning (delete `log.hdf5` for merged case) 42 | 43 | ## [v0.0.3] - [09/11/2021] 44 | 45 | ### Added 46 | 47 | - Adds function to store initial model checkpoint for post-processing via `log.save_init_model(model)`. 48 | - `MLELogger` got a new optional argument: `config_dict`, which allows you to provide a (nested) configuration of your experiment. It will be stored as a `.yaml` file if you don't provide a path to an alternative configuration file. The file can either be a `.json` or a `.yaml`: 49 | 50 | ```python 51 | log = MLELogger(time_to_track=['num_updates', 'num_epochs'], 52 | what_to_track=['train_loss', 'test_loss'], 53 | experiment_dir="experiment_dir/", 54 | config_dict={"train_config": {"lrate": 0.01}}, 55 | model_type='torch', 56 | verbose=True) 57 | ``` 58 | 59 | - The `config_dict`/ loaded `config_fname` data will be stored in the `meta` data of the loaded log and can be easily retrieved: 60 | 61 | ```python 62 | log = load_log("experiment_dir/") 63 | log.meta.config_dict 64 | ``` 65 | 66 | ### Fixed 67 | 68 | - Fix byte decoding for strings stored as arrays in `.hdf5` log file. Previously this only worked for multi seed/config settings. 69 | 70 | ## [v0.0.2] - [08/23/2021] 71 | 72 | ### Added 73 | 74 | - Enhances verbosity and nice rich layout printing. 75 | 76 | ## [v0.0.1] - [08/18/2021] 77 | 78 | ### Added 79 | 80 | - Basic `mle-logging` API: 81 | 82 | ```python 83 | from mle_logging import MLELogger 84 | 85 | # Instantiate logging to experiment_dir 86 | log = MLELogger(time_to_track=['num_updates', 'num_epochs'], 87 | what_to_track=['train_loss', 'test_loss'], 88 | experiment_dir="experiment_dir/", 89 | model_type='torch') 90 | 91 | time_tic = {'num_updates': 10, 'num_epochs': 1} 92 | stats_tic = {'train_loss': 0.1234, 'test_loss': 0.1235} 93 | 94 | # Update the log with collected data & save it to .hdf5 95 | log.update(time_tic, stats_tic) 96 | log.save() 97 | ``` 98 | -------------------------------------------------------------------------------- /mle_logging/save/tboard_log.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict 3 | 4 | 5 | class TboardLog(object): 6 | """Tensorboard Logger Class Instance.""" 7 | 8 | def __init__( 9 | self, 10 | experiment_dir: str, 11 | seed_id: str, 12 | ): 13 | # Setup figure logging directories 14 | try: 15 | from torch.utils.tensorboard import SummaryWriter 16 | except ModuleNotFoundError as err: 17 | raise ModuleNotFoundError( 18 | f"{err}. You need to install " 19 | "`torch` if you want that " 20 | "MLELogger logs to tensorboard." 
21 | ) 22 | self.writer = SummaryWriter( 23 | experiment_dir + "/tboards/" + "tboard" + "_" + seed_id 24 | ) 25 | 26 | def update( 27 | self, 28 | time_to_track: list, 29 | clock_tick: Dict[str, int], 30 | stats_tick: Dict[str, float], 31 | model_type: str, 32 | model=None, 33 | grads=None, 34 | plot_to_tboard=None, 35 | ): 36 | """Update the tensorboard with the newest events""" 37 | # Set the x-axis time variable to first key provided in time key dict 38 | time_var_id = clock_tick[time_to_track[1]] 39 | 40 | # Add performance & step counters 41 | for k in stats_tick.keys(): 42 | self.writer.add_scalar( 43 | "performance/" + k, np.mean(stats_tick[k]), time_var_id 44 | ) 45 | 46 | # Log the model params & gradients 47 | if model is not None: 48 | if model_type == "torch": 49 | for name, param in model.named_parameters(): 50 | try: 51 | self.writer.add_histogram( 52 | "weights/" + name, 53 | param.clone().cpu().data.numpy(), 54 | time_var_id, 55 | ) 56 | except Exception: 57 | continue 58 | # Try getting gradients from torch model 59 | try: 60 | self.writer.add_histogram( 61 | "gradients/" + name, 62 | param.grad.clone().cpu().data.numpy(), 63 | time_var_id, 64 | ) 65 | except Exception: 66 | continue 67 | elif model_type == "jax": 68 | # Try to add parameters from nested dict first - then simple 69 | for layer in model.keys(): 70 | try: 71 | for w in model[layer].keys(): 72 | self.writer.add_histogram( 73 | "weights/" + layer + "/" + w, 74 | np.array(model[layer][w]), 75 | time_var_id, 76 | ) 77 | except Exception: 78 | try: 79 | self.writer.add_histogram( 80 | "weights/" + layer, 81 | np.array(model[layer]), 82 | time_var_id, 83 | ) 84 | except Exception: 85 | pass 86 | 87 | # Add the plot of interest to tboard 88 | if plot_to_tboard is not None: 89 | self.writer.add_figure("plot", plot_to_tboard, time_var_id) 90 | 91 | # Flush the log event 92 | self.writer.flush() 93 | -------------------------------------------------------------------------------- /mle_logging/save/stats_log.py: -------------------------------------------------------------------------------- 1 | import time 2 | from datetime import datetime 3 | from typing import List, Dict, Union 4 | from ..load import load_log 5 | 6 | 7 | class StatsLog(object): 8 | """Time-Series Statistics Logger Class Instance.""" 9 | 10 | def __init__( 11 | self, 12 | experiment_dir: str, 13 | seed_id: str, 14 | time_to_track: List[str] = [], 15 | what_to_track: List[str] = [], 16 | reload: bool = False, 17 | freeze_keys: bool = False, # Freeze keys that are stored in time/stats 18 | ): 19 | self.experiment_dir = experiment_dir 20 | self.seed_id = seed_id 21 | # Create empty dataframes to log statistics in 22 | self.time_to_track = ["time", "time_elapsed", "num_updates"] + time_to_track 23 | self.what_to_track = what_to_track 24 | self.clock_tracked = {k: [] for k in self.time_to_track} 25 | self.stats_tracked = {k: [] for k in self.what_to_track} 26 | self.freeze_keys = freeze_keys 27 | # Set update counter & start stop-watch/clock of experiment 28 | if reload: 29 | self.reload() 30 | else: 31 | self.stats_update_counter = 0 32 | # Regardless of reloading - start time counter at 0 33 | self.start_time = time.time() 34 | 35 | def extend_tracking( 36 | self, 37 | stats_keys: Union[List[str], None] = None, 38 | time_keys: Union[List[str], None] = None, 39 | ) -> None: 40 | """Add string names of variables to track.""" 41 | if stats_keys is not None: 42 | self.what_to_track += stats_keys 43 | for k in stats_keys: 44 | self.stats_tracked[k] = [] 45 | 
if time_keys is not None: 46 | self.time_to_track += time_keys 47 | for k in time_keys: 48 | self.clock_tracked[k] = [] 49 | 50 | def update(self, clock_tick: Dict[str, int], stats_tick: Dict[str, float]) -> None: 51 | # Check all keys do exist in data dicts to log [exclude time time_elapsed num_updates] 52 | if self.freeze_keys: 53 | for k in self.time_to_track[3:]: 54 | assert k in clock_tick.keys(), f"{k} not in clock_tick keys." 55 | for k in self.what_to_track: 56 | assert k in stats_tick.keys(), f"{k} not in stats_tick keys." 57 | else: 58 | # Update time logged first 59 | self.stats_update_counter += 1 60 | clock_tick["time"] = datetime.today().strftime("%y-%m-%d/%H:%M") 61 | clock_tick["time_elapsed"] = time.time() - self.start_time 62 | clock_tick["num_updates"] = self.stats_update_counter 63 | 64 | for k in clock_tick.keys(): 65 | if k in self.time_to_track: 66 | self.clock_tracked[k].append(clock_tick[k]) 67 | else: 68 | self.time_to_track.append(k) 69 | self.clock_tracked[k] = [clock_tick[k]] 70 | 71 | # Update stats logged next 72 | for k in stats_tick.keys(): 73 | if k in self.what_to_track: 74 | self.stats_tracked[k].append(stats_tick[k]) 75 | else: 76 | self.what_to_track.append(k) 77 | self.stats_tracked[k] = [stats_tick[k]] 78 | 79 | return clock_tick, stats_tick 80 | 81 | def reload(self): 82 | """Reload results from previous experiment run.""" 83 | reloaded_log = load_log(self.experiment_dir, 84 | aggregate_seeds=False, 85 | reload_log=True) 86 | self.clock_tracked, self.stats_tracked = {}, {} 87 | self.what_to_track, self.time_to_track = [], [] 88 | # Make sure to reload in results for correct seed 89 | if reloaded_log.eval_ids[0] == "no_seed_provided": 90 | for k in reloaded_log["no_seed_provided"].time.keys(): 91 | self.time_to_track.append(k) 92 | self.clock_tracked[k] = reloaded_log["no_seed_provided"].time[k].tolist() 93 | for k in reloaded_log["no_seed_provided"].stats.keys(): 94 | self.what_to_track.append(k) 95 | self.stats_tracked[k] = reloaded_log["no_seed_provided"].stats[k].tolist() 96 | else: 97 | for k in reloaded_log[self.seed_id].time.keys(): 98 | self.time_to_track.append(k) 99 | self.clock_tracked[k] = reloaded_log[self.seed_id].time[k].tolist() 100 | for k in reloaded_log[self.seed_id].stats.keys(): 101 | self.what_to_track.append(k) 102 | self.stats_tracked[k] = reloaded_log[self.seed_id].stats[k].tolist() 103 | self.stats_update_counter = self.clock_tracked["num_updates"][-1] 104 | -------------------------------------------------------------------------------- /mle_logging/load/load_log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | from dotmap import DotMap 4 | import collections 5 | from ..meta_log import MetaLog 6 | 7 | 8 | def load_meta_log(log_fname: str, 9 | aggregate_seeds: bool = True, 10 | reload_log: bool = False) -> MetaLog: 11 | """Load in logging results & mean the results over different runs""" 12 | assert os.path.exists(log_fname), f"File {log_fname} does not exist." 13 | # Open File & Get array names to load in 14 | h5f = h5py.File(log_fname, mode="r", swmr=True) 15 | # Get all ids of all runs (b_1_eval_0, b_1_eval_1, ...) 16 | run_names = list(h5f.keys()) 17 | # Get all main data source keys (single vs multi-seed) 18 | data_sources = list(h5f[run_names[0]].keys()) 19 | data_types = ["meta", "stats", "time"] 20 | 21 | """ 22 | 5 Possible Cases: 23 | 1. Single config - single seed = no aggregation - 'no_seed_provided' 24 | 2. 
Single config - multi seed = aggregation - seed_id -> meta, stats, time 25 | 3. Multi config - multi seed = aggregation - config_id -> seed_id -> ... 26 | 4. Single config - single seed = aggregation - seed_id provided 27 | 5. Only multi seed -> no config_id key 28 | """ 29 | case_1 = len(run_names) == 1 and collections.Counter( 30 | h5f[run_names[0]].keys() 31 | ) == collections.Counter(data_types) 32 | case_2 = len(run_names) == 1 and collections.Counter( 33 | h5f[run_names[0]].keys() 34 | ) != collections.Counter(data_types) 35 | case_3 = len(run_names) > 1 and collections.Counter( 36 | h5f[run_names[0]].keys() 37 | ) != collections.Counter(data_types) 38 | case_4 = len(run_names) == 1 and collections.Counter( 39 | h5f[run_names[0]].keys() 40 | ) == collections.Counter(data_types) 41 | case_5 = len(run_names) > 1 and collections.Counter( 42 | h5f[run_names[0]].keys() 43 | ) == collections.Counter(data_types) 44 | 45 | result_dict = {key: {} for key in run_names} 46 | # Shallow versus deep aggregation 47 | if case_1 or case_5: 48 | data_items = { 49 | data_types[i]: list(h5f[run_names[0]][data_types[i]].keys()) 50 | for i in range(len(data_types)) 51 | } 52 | for rn in run_names: 53 | run = h5f[rn] 54 | source_to_store = {key: {} for key in data_types} 55 | for ds in data_items: 56 | data_to_store = {key: {} for key in data_items[ds]} 57 | for i, o_name in enumerate(data_items[ds]): 58 | data_to_store[o_name] = run[ds][o_name][:] 59 | source_to_store[ds] = data_to_store 60 | result_dict[rn] = source_to_store 61 | elif case_2 or case_3 or case_4: 62 | data_items = { 63 | data_types[i]: list( 64 | h5f[run_names[0]][data_sources[0]][data_types[i]].keys() 65 | ) 66 | for i in range(len(data_types)) 67 | } 68 | for rn in run_names: 69 | run = h5f[rn] 70 | result_dict[rn] = {} 71 | for seed_id in data_sources: 72 | source_to_store = {key: {} for key in data_types} 73 | for ds in data_items: 74 | data_to_store = {key: {} for key in data_items[ds]} 75 | for i, o_name in enumerate(data_items[ds]): 76 | try: 77 | data_to_store[o_name] = run[seed_id][ds][o_name][:] 78 | except KeyError: 79 | pass 80 | source_to_store[ds] = data_to_store 81 | result_dict[rn][seed_id] = source_to_store 82 | # Return as dot-callable dictionary 83 | if aggregate_seeds and (case_2 or case_3 or case_4 or case_5): 84 | # Important aggregation helper & compute mean/median/10p/50p/etc. 
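        # Lazy import: the aggregation helper is only needed when per-seed
        # results are collapsed into mean/median/percentile summaries.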
85 | from ..merge.aggregate import aggregate_over_seeds 86 | 87 | result_dict = aggregate_over_seeds( 88 | result_dict, batch_case=case_2 or case_3 or case_4 89 | ) 90 | meta_log = MetaLog( 91 | DotMap(result_dict, _dynamic=False), 92 | non_aggregated=(not aggregate_seeds and case_3), 93 | ) 94 | if meta_log.eval_ids is not None: 95 | if meta_log.eval_ids[0] == "no_seed_provided" and not reload_log: 96 | meta_log = meta_log.no_seed_provided 97 | return meta_log 98 | 99 | 100 | def load_log(experiment_dir: str, 101 | aggregate_seeds: bool = False, 102 | reload_log: bool = False) -> MetaLog: 103 | """Load a single .hdf5 log from /logs.""" 104 | if experiment_dir.endswith(".hdf5"): 105 | log_path = experiment_dir 106 | else: 107 | log_dir = os.path.join(experiment_dir, "logs/") 108 | log_paths = [] 109 | for file in os.listdir(log_dir): 110 | if file.endswith(".hdf5"): 111 | log_paths.append(os.path.join(log_dir, file)) 112 | if len(log_paths) > 1: 113 | print(f"Multiple .hdf5 files available: {log_paths}") 114 | print(f"Continue using: {log_paths[0]}") 115 | log_path = log_paths[0] 116 | run_log = load_meta_log(log_path, aggregate_seeds, reload_log) 117 | return run_log 118 | -------------------------------------------------------------------------------- /tests/test_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | import torch.nn as nn 5 | import matplotlib.pyplot as plt 6 | from mle_logging import MLELogger 7 | 8 | 9 | log_config = { 10 | "time_to_track": ["num_updates", "num_epochs"], 11 | "what_to_track": ["train_loss", "test_loss"], 12 | "experiment_dir": "experiment_dir/", 13 | "config_fname": None, 14 | "use_tboard": True, 15 | "model_type": "torch", 16 | } 17 | 18 | time_tic = {"num_updates": 10, "num_epochs": 1} 19 | stats_tic = {"train_loss": 0.1234, "test_loss": 0.1235} 20 | 21 | 22 | class DummyModel(nn.Module): 23 | def __init__(self): 24 | super(DummyModel, self).__init__() 25 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 26 | self.fc2 = nn.Linear(120, 84) 27 | self.fc3 = nn.Linear(84, 10) 28 | 29 | def forward(self, x): 30 | x = self.fc1(x) 31 | x = self.fc2(x) 32 | x = self.fc3(x) 33 | return x 34 | 35 | 36 | model = DummyModel() 37 | 38 | fig, ax = plt.subplots() 39 | ax.plot(np.random.normal(0, 1, 20)) 40 | 41 | some_dict = {"hi": "there"} 42 | 43 | 44 | def test_update_log(): 45 | # Remove experiment dir at start of test 46 | if os.path.exists(log_config["experiment_dir"]) and os.path.isdir( 47 | log_config["experiment_dir"] 48 | ): 49 | shutil.rmtree(log_config["experiment_dir"]) 50 | 51 | # Instantiate logging to experiment_dir 52 | log = MLELogger(**log_config) 53 | # Update the log with collected data & save it to .hdf5 54 | log.update(time_tic, stats_tic) 55 | log.save() 56 | 57 | # Assert the existence of the files 58 | assert os.path.exists(os.path.join(log_config["experiment_dir"], "logs")) 59 | assert os.path.exists(os.path.join(log_config["experiment_dir"], "tboards")) 60 | file_to_check = os.path.join( 61 | log_config["experiment_dir"], "logs", "log_no_seed_provided.hdf5" 62 | ) 63 | assert os.path.exists(file_to_check) 64 | 65 | # Finally -- clean up 66 | shutil.rmtree(log_config["experiment_dir"]) 67 | 68 | 69 | def test_save_plot(): 70 | # Remove experiment dir at start of test 71 | if os.path.exists(log_config["experiment_dir"]) and os.path.isdir( 72 | log_config["experiment_dir"] 73 | ): 74 | shutil.rmtree(log_config["experiment_dir"]) 75 | 76 | # Instantiate 
logging to experiment_dir 77 | log = MLELogger(**log_config) 78 | 79 | # Save a matplotlib figure as .png 80 | log.save_plot(fig) 81 | 82 | # Assert the existence of the files 83 | file_to_check = os.path.join( 84 | log_config["experiment_dir"], "figures", "fig_1_no_seed_provided.png" 85 | ) 86 | assert os.path.exists(file_to_check) 87 | 88 | # Finally -- clean up 89 | shutil.rmtree(log_config["experiment_dir"]) 90 | 91 | 92 | def test_save_extra(): 93 | # Remove experiment dir at start of test 94 | if os.path.exists(log_config["experiment_dir"]) and os.path.isdir( 95 | log_config["experiment_dir"] 96 | ): 97 | shutil.rmtree(log_config["experiment_dir"]) 98 | 99 | # Instantiate logging to experiment_dir 100 | log = MLELogger(**log_config) 101 | 102 | # Save a dict as a .pkl object 103 | log.save_extra(some_dict) 104 | 105 | # Assert the existence of the files 106 | file_to_check = os.path.join( 107 | log_config["experiment_dir"], "extra", "extra_1_no_seed_provided.pkl" 108 | ) 109 | assert os.path.exists(file_to_check) 110 | 111 | # Finally -- clean up 112 | shutil.rmtree(log_config["experiment_dir"]) 113 | 114 | 115 | def test_all_in_one(): 116 | # Remove experiment dir at start of test 117 | if os.path.exists(log_config["experiment_dir"]) and os.path.isdir( 118 | log_config["experiment_dir"] 119 | ): 120 | shutil.rmtree(log_config["experiment_dir"]) 121 | 122 | # Instantiate logging to experiment_dir 123 | log = MLELogger(**log_config) 124 | 125 | # Save a dict as a .pkl object 126 | log.save_init_model(model) 127 | log.update(time_tic, stats_tic, model, fig, some_dict, save=True) 128 | 129 | # Assert the existence of the files 130 | assert os.path.exists(os.path.join(log_config["experiment_dir"], "logs")) 131 | assert os.path.exists(os.path.join(log_config["experiment_dir"], "tboards")) 132 | file_to_check = os.path.join( 133 | log_config["experiment_dir"], "logs", "log_no_seed_provided.hdf5" 134 | ) 135 | assert os.path.exists(file_to_check) 136 | 137 | file_to_check = os.path.join( 138 | log_config["experiment_dir"], "models/init", "init_no_seed_provided.pt" 139 | ) 140 | assert os.path.exists(file_to_check) 141 | 142 | file_to_check = os.path.join( 143 | log_config["experiment_dir"], "models/final", "final_no_seed_provided.pt" 144 | ) 145 | assert os.path.exists(file_to_check) 146 | 147 | file_to_check = os.path.join( 148 | log_config["experiment_dir"], "figures", "fig_1_no_seed_provided.png" 149 | ) 150 | assert os.path.exists(file_to_check) 151 | 152 | file_to_check = os.path.join( 153 | log_config["experiment_dir"], "extra", "extra_1_no_seed_provided.pkl" 154 | ) 155 | assert os.path.exists(file_to_check) 156 | 157 | # Finally -- clean up 158 | shutil.rmtree(log_config["experiment_dir"]) 159 | -------------------------------------------------------------------------------- /tests/test_reload.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import collections 3 | import os 4 | import numpy as np 5 | import torch.nn as nn 6 | import matplotlib.pyplot as plt 7 | from mle_logging import MLELogger 8 | 9 | time_tic1 = {"num_steps": 10, "num_epochs": 1} 10 | stats_tic1 = {"train_loss": 0.1234, "test_loss": 0.1235} 11 | time_tic2 = {"num_steps": 20, "num_epochs": 1} 12 | stats_tic2 = {"train_loss": 0.2, "test_loss": 0.1} 13 | time_tic3 = {"num_steps": 30, "num_epochs": 1} 14 | stats_tic3 = {"train_loss": 0.223, "test_loss": 0.097} 15 | time_tic4 = {"num_steps": 40, "num_epochs": 1} 16 | stats_tic4 = {"train_loss": 0.123, 
"test_loss": 0.085} 17 | 18 | 19 | class DummyModel(nn.Module): 20 | def __init__(self): 21 | super(DummyModel, self).__init__() 22 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 23 | self.fc2 = nn.Linear(120, 84) 24 | self.fc3 = nn.Linear(84, 10) 25 | 26 | def forward(self, x): 27 | x = self.fc1(x) 28 | x = self.fc2(x) 29 | x = self.fc3(x) 30 | return x 31 | 32 | 33 | model = DummyModel() 34 | 35 | fig, ax = plt.subplots() 36 | ax.plot(np.random.normal(0, 1, 20)) 37 | 38 | some_dict = {"hi": "there"} 39 | 40 | log_config = { 41 | "time_to_track": ["num_steps", "num_epochs"], 42 | "what_to_track": ["train_loss", "test_loss"], 43 | "experiment_dir": "reload_dir/", 44 | "model_type": "torch", 45 | "ckpt_time_to_track": "num_steps", 46 | "save_every_k_ckpt": 2, 47 | "save_top_k_ckpt": 2, 48 | "top_k_metric_name": "test_loss", 49 | "top_k_minimize_metric": True, 50 | } 51 | 52 | 53 | def test_reload(): 54 | """Test reloading/continuation of previous log with top/every k.""" 55 | if os.path.exists(log_config["experiment_dir"]) and os.path.isdir( 56 | log_config["experiment_dir"] 57 | ): 58 | shutil.rmtree(log_config["experiment_dir"]) 59 | 60 | log = MLELogger(**log_config) 61 | log.update(time_tic1, stats_tic1, model, fig, some_dict, save=True) 62 | log.update(time_tic2, stats_tic2, model, fig, some_dict, save=True) 63 | log.update(time_tic3, stats_tic3, model, fig, some_dict, save=True) 64 | 65 | # Reload the previously instantiated logger from the directory 66 | relog = MLELogger(**log_config, reload=True) 67 | # Check correctness of checkpoints 68 | assert collections.Counter(relog.model_log.top_k_ckpt_list) == collections.Counter( 69 | [ 70 | "reload_dir/models/top_k/top_k_no_seed_provided_top_0.pt", 71 | "reload_dir/models/top_k/top_k_no_seed_provided_top_1.pt", 72 | ] 73 | ) 74 | assert collections.Counter( 75 | relog.model_log.top_k_storage_time 76 | ) == collections.Counter([20, 30]) 77 | assert np.allclose(relog.model_log.top_k_performance, [0.097, 0.1]) 78 | assert collections.Counter( 79 | relog.model_log.every_k_storage_time 80 | ) == collections.Counter([20]) 81 | assert collections.Counter( 82 | relog.model_log.every_k_ckpt_list 83 | ) == collections.Counter( 84 | ["reload_dir/models/every_k/every_k_no_seed_provided_k_2.pt"] 85 | ) 86 | 87 | # Check correctness of figure paths 88 | assert collections.Counter( 89 | relog.figure_log.fig_storage_paths 90 | ) == collections.Counter( 91 | [ 92 | "reload_dir/figures/fig_1_no_seed_provided.png", 93 | "reload_dir/figures/fig_2_no_seed_provided.png", 94 | "reload_dir/figures/fig_3_no_seed_provided.png", 95 | ] 96 | ) 97 | # Check correctness of extra paths 98 | assert collections.Counter( 99 | relog.extra_log.extra_storage_paths 100 | ) == collections.Counter( 101 | [ 102 | "reload_dir/extra/extra_1_no_seed_provided.pkl", 103 | "reload_dir/extra/extra_2_no_seed_provided.pkl", 104 | "reload_dir/extra/extra_3_no_seed_provided.pkl", 105 | ] 106 | ) 107 | 108 | # Check correctness of reloaded statistics 109 | assert np.allclose( 110 | relog.stats_log.stats_tracked["test_loss"], np.array([0.1235, 0.1, 0.097]) 111 | ) 112 | assert np.allclose( 113 | relog.stats_log.clock_tracked["num_steps"], np.array([10, 20, 30]) 114 | ) 115 | 116 | # Add new result to log 117 | relog.update(time_tic4, stats_tic4, model, fig, some_dict, save=True) 118 | 119 | # Check correctness of figure paths 120 | assert collections.Counter( 121 | relog.figure_log.fig_storage_paths 122 | ) == collections.Counter( 123 | [ 124 | "reload_dir/figures/fig_1_no_seed_provided.png", 125 
| "reload_dir/figures/fig_2_no_seed_provided.png", 126 | "reload_dir/figures/fig_3_no_seed_provided.png", 127 | "reload_dir/figures/fig_4_no_seed_provided.png", 128 | ] 129 | ) 130 | # Check correctness of extra paths 131 | assert collections.Counter( 132 | relog.extra_log.extra_storage_paths 133 | ) == collections.Counter( 134 | [ 135 | "reload_dir/extra/extra_1_no_seed_provided.pkl", 136 | "reload_dir/extra/extra_2_no_seed_provided.pkl", 137 | "reload_dir/extra/extra_3_no_seed_provided.pkl", 138 | "reload_dir/extra/extra_4_no_seed_provided.pkl", 139 | ] 140 | ) 141 | 142 | # Check correctness of reloaded statistics 143 | assert np.allclose( 144 | np.array(relog.stats_log.stats_tracked["test_loss"]), 145 | np.array([0.1235, 0.1, 0.097, 0.085]), 146 | ) 147 | assert np.allclose( 148 | np.array(relog.stats_log.clock_tracked["num_steps"]), 149 | np.array([10, 20, 30, 40]), 150 | ) 151 | 152 | # Clean up/delete files 153 | shutil.rmtree(log_config["experiment_dir"]) 154 | -------------------------------------------------------------------------------- /mle_logging/save/wandb_log.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Optional 3 | import numpy as np 4 | 5 | 6 | def setup_wandb_env(wandb_config: dict): 7 | """Set up environment variables for W&B logging.""" 8 | if "key" in wandb_config.keys(): 9 | os.environ["WANDB_API_KEY"] = wandb_config["key"] 10 | if "entity" in wandb_config.keys(): 11 | os.environ["WANDB_ENTITY"] = wandb_config["entity"] 12 | if "project" in wandb_config.keys(): 13 | os.environ["WANDB_PROJECT"] = wandb_config["project"] 14 | else: 15 | os.environ["WANDB_PROJECT"] = "prototyping" 16 | if "name" in wandb_config.keys(): 17 | os.environ["WANDB_NAME"] = wandb_config["name"] 18 | if "group" in wandb_config.keys(): 19 | if wandb_config["group"] is not None: 20 | os.environ["WANDB_RUN_GROUP"] = wandb_config["group"] 21 | if "job_type" in wandb_config.keys(): 22 | os.environ["WANDB_JOB_TYPE"] = wandb_config["job_type"] 23 | os.environ["WANDB_TAGS"] = "{}, {}".format( 24 | wandb_config["name"], wandb_config["job_type"] 25 | ) 26 | os.environ["WANDB_SILENT"] = "true" 27 | os.environ["WANDB_DISABLE_SERVICE"] = "true" 28 | 29 | 30 | class WandbLog(object): 31 | """Weights&Biases Logger Class Instance.""" 32 | 33 | def __init__( 34 | self, 35 | config_dict: Optional[dict], 36 | config_fname: Optional[str], 37 | seed_id: str, 38 | wandb_config: Optional[dict], 39 | ): 40 | # Setup figure logging directories 41 | try: 42 | import wandb 43 | 44 | global wandb 45 | except ModuleNotFoundError as err: 46 | raise ModuleNotFoundError( 47 | f"{err}. You need to install " 48 | "`wandb` if you want that " 49 | "MLELogger logs to Weights&Biases." 50 | ) 51 | self.wandb_config = wandb_config 52 | # config should contain - key, entity, project, group (experiment) 53 | for k in ["key", "entity", "project", "group"]: 54 | assert k in self.wandb_config.keys() 55 | 56 | # Setup the environment variables for W&B logging. 
57 | if config_fname is None: 58 | config_fname = "pholder_config" 59 | else: 60 | path = os.path.normpath(config_fname) 61 | path_norm = path.split(os.sep) 62 | config_fname, _ = os.path.splitext(path_norm[-1]) 63 | 64 | if config_dict is None: 65 | config_dict = {} 66 | self.setup(config_dict, config_fname, seed_id) 67 | 68 | def setup(self, config_dict: dict, config_fname: str, seed_id: str): 69 | """Setup wandb process for logging.""" 70 | if self.wandb_config["group"] is None: 71 | self.wandb_config["job_type"] = seed_id 72 | else: 73 | self.wandb_config["job_type"] = config_fname 74 | # Replace name by seed if not otherwise specified 75 | if self.wandb_config["name"] == "seed0": 76 | self.wandb_config["name"] = seed_id 77 | setup_wandb_env(self.wandb_config) 78 | 79 | # Try opening port 10 times 80 | for _ in range(10): 81 | try: 82 | wandb.init( 83 | config=config_dict, 84 | group=self.wandb_config["group"], 85 | job_type=self.wandb_config["job_type"], 86 | ) 87 | self.correct_setup = True 88 | break 89 | except Exception: 90 | self.correct_setup = False 91 | pass 92 | self.step_counter = 0 93 | 94 | def update( 95 | self, 96 | clock_tick: Dict[str, int], 97 | stats_tick: Dict[str, float], 98 | model_type: Optional[str] = None, 99 | model=None, 100 | grads=None, 101 | plot_to_wandb: Optional[str] = None, 102 | ): 103 | """Update the wandb with the newest events""" 104 | if self.correct_setup: 105 | log_dict = {} 106 | for k, v in clock_tick.items(): 107 | log_dict["time/" + k] = v 108 | for k, v in stats_tick.items(): 109 | log_dict["stats/" + k] = v 110 | if plot_to_wandb is not None: 111 | log_dict["img"] = wandb.Image(plot_to_wandb) 112 | # Log stats to W&B log 113 | wandb.log( 114 | log_dict, 115 | step=self.step_counter, 116 | ) 117 | 118 | # Log model parameters and gradients if provided 119 | if model is not None and model_type == "jax": 120 | w_norm, w_hist = get_jax_norm_hist(model) 121 | wandb.log( 122 | {"params_norm/": w_norm, "params_hist/": w_hist}, 123 | step=self.step_counter, 124 | ) 125 | if grads is not None and model_type == "jax": 126 | g_norm, g_hist = get_jax_norm_hist(grads) 127 | wandb.log( 128 | {"grads_norm/": g_norm, "grads_hist/": g_hist}, 129 | step=self.step_counter, 130 | ) 131 | # Log model gradients if provided 132 | self.step_counter += 1 133 | 134 | def upload_gif(self, gif_path: str, video_name: Optional[str] = "video"): 135 | """Upload a gif file to W&B based on path""" 136 | wandb.log( 137 | {"video_name": wandb.Video(gif_path)}, 138 | step=self.step_counter, 139 | ) 140 | 141 | 142 | def get_jax_norm_hist(model): 143 | """Get norm of modules in jax model.""" 144 | import jax 145 | from flax.core import unfreeze 146 | 147 | def norm(val): 148 | return jax.tree_map(lambda x: np.linalg.norm(x), val) 149 | 150 | def histogram(val): 151 | return jax.tree_map(lambda x: np.histogram(x, density=True), val) 152 | 153 | w_norm = unfreeze(norm(model)) 154 | hist = histogram(model) 155 | hist = jax.tree_map(lambda x: jax.device_get(x), unfreeze(hist)) 156 | w_hist = jax.tree_map( 157 | lambda x: wandb.Histogram(np_histogram=x), 158 | hist, 159 | is_leaf=lambda x: isinstance(x, tuple), 160 | ) 161 | return w_norm, w_hist 162 | -------------------------------------------------------------------------------- /mle_logging/meta_log.py: -------------------------------------------------------------------------------- 1 | import ast 2 | from dotmap import DotMap 3 | from typing import Union, List 4 | from .utils import visualize_1D_lcurves 5 | 6 | 7 | class 
MetaLog(object): 8 | meta_vars: List[str] 9 | stats_vars: List[str] 10 | time_vars: List[str] 11 | num_configs: int 12 | 13 | def __init__(self, meta_log: DotMap, non_aggregated: bool = False): 14 | """Class wrapper for meta_log dictionary w. additional functionality. 15 | 16 | Args: 17 | meta_log (DotMap): Raw reloaded meta-log dotmap dictionary. 18 | non_aggregated (bool, optional): 19 | Whether the meta-log has previously been aggregated across 20 | seeds. Defaults to False. 21 | """ 22 | self.meta_log = meta_log 23 | 24 | # Return shallow log if there is only a single experiment stored 25 | self.num_configs = len(list(meta_log.keys())) 26 | ph_run = list(meta_log.keys())[0] 27 | ph_seed = list(meta_log[ph_run].keys())[0] 28 | 29 | # Extract different variable names from meta log 30 | if not non_aggregated and ph_seed in ["meta", "stats", "time"]: 31 | self.meta_vars = list(meta_log[ph_run].meta.keys()) 32 | self.stats_vars = list(meta_log[ph_run].stats.keys()) 33 | self.time_vars = list(meta_log[ph_run].time.keys()) 34 | else: 35 | self.meta_vars = list(meta_log[ph_run][ph_seed].meta.keys()) 36 | self.stats_vars = list(meta_log[ph_run][ph_seed].stats.keys()) 37 | self.time_vars = list(meta_log[ph_run][ph_seed].time.keys()) 38 | 39 | # Decode all byte strings in meta data 40 | for run_id in self.meta_log.keys(): 41 | if "meta" in self.meta_log[run_id].keys(): 42 | try: 43 | self.meta_log[run_id] = decode_meta_strings( 44 | self.meta_log[run_id] 45 | ) 46 | except Exception: 47 | pass 48 | else: 49 | for seed_id in self.meta_log[run_id].keys(): 50 | self.meta_log[run_id][seed_id] = decode_meta_strings( 51 | self.meta_log[run_id][seed_id] 52 | ) 53 | 54 | # Make possible that all runs are accessible via attribute as in pd 55 | for key in self.meta_log: 56 | setattr(self, key, self.meta_log[key]) 57 | 58 | def filter(self, run_ids: List[str]): 59 | """Subselect the meta log dict based on a list of run ids.""" 60 | sub_dict = subselect_meta_log(self.meta_log, run_ids) 61 | return MetaLog(sub_dict) 62 | 63 | def plot( 64 | self, 65 | target_to_plot: str, 66 | iter_to_plot: Union[str, None] = None, 67 | smooth_window: int = 1, 68 | plot_title: Union[str, None] = None, 69 | xy_labels: Union[list, None] = None, 70 | base_label: str = "{}", 71 | run_ids: Union[list, None] = None, 72 | curve_labels: list = [], 73 | every_nth_tick: Union[int, None] = None, 74 | plot_std_bar: bool = False, 75 | fname: Union[None, str] = None, 76 | num_legend_cols: Union[int, None] = 1, 77 | fig=None, 78 | ax=None, 79 | figsize: tuple = (9, 6), 80 | plot_labels: bool = True, 81 | legend_title: Union[None, str] = None, 82 | ax_lims: Union[None, list] = None, 83 | ): 84 | """Plot all runs in meta-log for variable 'target_to_plot'.""" 85 | if iter_to_plot is None: 86 | iter_to_plot = self.time_vars[0] 87 | assert iter_to_plot in self.time_vars 88 | if run_ids is None: 89 | run_ids = self.eval_ids 90 | fig, ax = visualize_1D_lcurves( 91 | self.meta_log, 92 | iter_to_plot, 93 | target_to_plot, 94 | smooth_window=smooth_window, 95 | every_nth_tick=every_nth_tick, 96 | num_legend_cols=num_legend_cols, 97 | run_ids=run_ids, 98 | plot_title=plot_title, 99 | xy_labels=xy_labels, 100 | base_label=base_label, 101 | curve_labels=curve_labels, 102 | plot_std_bar=plot_std_bar, 103 | fig=fig, 104 | ax=ax, 105 | figsize=figsize, 106 | plot_labels=plot_labels, 107 | legend_title=legend_title, 108 | ax_lims=ax_lims, 109 | ) 110 | # Save the figure if a filename was provided 111 | if fname is not None: 112 | fig.savefig(fname, dpi=300) 
113 | else: 114 | return fig, ax 115 | 116 | @property 117 | def eval_ids(self) -> Union[int, None]: 118 | """Get ids of runs stored in meta_log instance.""" 119 | return list(self.meta_log.keys()) 120 | 121 | def __len__(self) -> int: 122 | """Return number of runs stored in meta_log.""" 123 | return len(self.eval_ids) 124 | 125 | def __getitem__(self, item): 126 | """Get run log via string subscription.""" 127 | return self.meta_log[item] 128 | 129 | 130 | def subselect_meta_log(meta_log: DotMap, run_ids: List[str]) -> DotMap: 131 | """Subselect the meta log dict based on a list of run ids.""" 132 | sub_log = DotMap() 133 | for run_id in run_ids: 134 | sub_log[run_id] = meta_log[run_id] 135 | return sub_log 136 | 137 | 138 | def decode_meta_strings(log: DotMap): 139 | """Decode all bytes encoded strings.""" 140 | for k in log.meta.keys(): 141 | temp_list = [] 142 | if type(log.meta[k]) != str and type(log.meta[k]) != dict: 143 | list_to_loop = ( 144 | log.meta[k].tolist() 145 | if type(log.meta[k]) != list 146 | else log.meta[k] 147 | ) 148 | 149 | if type(list_to_loop) in [str, bytes]: 150 | list_to_loop = [list_to_loop] 151 | for i in list_to_loop: 152 | if type(i) == bytes: 153 | if len(i) > 0: 154 | temp_list.append(i.decode()) 155 | else: 156 | temp_list.append(i) 157 | else: 158 | temp_list.append(log.meta[k]) 159 | 160 | if len(temp_list) == 1: 161 | if k == "config_dict": 162 | # Convert config into dict 163 | config_dict = ast.literal_eval(str(temp_list[0])) 164 | log.meta[k] = config_dict 165 | else: 166 | log.meta[k] = temp_list[0] 167 | else: 168 | log.meta[k] = temp_list 169 | 170 | return log 171 | -------------------------------------------------------------------------------- /mle_logging/merge/aggregate.py: -------------------------------------------------------------------------------- 1 | from dotmap import DotMap 2 | import numpy as np 3 | from typing import List, Tuple, Any 4 | 5 | 6 | def aggregate_over_seeds( 7 | result_dict: DotMap, batch_case: bool = False 8 | ) -> DotMap: 9 | """Mean all individual runs over their respective seeds. 
10 | BATCH EVAL CASE: 11 | IN: {'b_1_eval_0': {'seed_0': {'meta': {}, 'stats': {}, 'time': {}} 12 | 'seed_1': {'meta': {}, 'stats': {}, 'time': {}}, 13 | ...} 14 | OUT: {'b_1_eval_0': {'meta': {}, 'stats': {}, 'time': {}, 15 | 'b_1_eval_1': {'meta': {}, 'stats': {}, 'time': {}} 16 | SINGLE EVAL CASE: 17 | IN: {'seed_0': {'meta': {}, 'stats': {}, 'time': {}}, 18 | 'seed_1': {'meta': {}, 'stats': {}, 'time': {}}, 19 | ...} 20 | OUT: {'eval': {'meta': {}, 'stats': {}, 'time': {}} 21 | """ 22 | all_runs = list(result_dict.keys()) 23 | if batch_case: 24 | # Perform seed aggregation for all evaluations 25 | new_results_dict = aggregate_batch_evals(result_dict, all_runs) 26 | else: 27 | new_results_dict = aggregate_single_eval(result_dict, all_runs, "eval") 28 | return DotMap(new_results_dict, _dynamic=False) 29 | 30 | 31 | def aggregate_single_eval( 32 | result_dict: dict, all_seeds_for_run: list, eval_name: str 33 | ) -> dict: 34 | """Mean over seeds of single config run.""" 35 | new_results_dict = {} 36 | data_temp = result_dict[all_seeds_for_run[0]] 37 | # Get all main data source keys ("meta", "stats", "time") 38 | data_sources = list(data_temp.keys()) 39 | # Get all variables within the data sources 40 | data_items = { 41 | data_sources[i]: list(data_temp[data_sources[i]].keys()) 42 | for i in range(len(data_sources)) 43 | } 44 | # Collect all runs together - data at this point is not modified 45 | source_to_store = {key: {} for key in data_sources} 46 | for ds in data_sources: 47 | data_to_store = {key: [] for key in data_items[ds]} 48 | for i, o_name in enumerate(data_items[ds]): 49 | for i, seed_id in enumerate(all_seeds_for_run): 50 | seed_run = result_dict[seed_id] 51 | try: 52 | data_to_store[o_name].append(seed_run[ds][o_name][:]) 53 | except TypeError: 54 | pass 55 | source_to_store[ds] = data_to_store 56 | new_results_dict[eval_name] = source_to_store 57 | 58 | # Aggregate over the collected runs 59 | aggregate_sources = {key: {} for key in data_sources} 60 | for ds in data_sources: 61 | if ds in ["time"]: 62 | aggregate_dict = {key: {} for key in data_items[ds]} 63 | for i, o_name in enumerate(data_items[ds]): 64 | aggregate_dict[o_name] = new_results_dict[eval_name][ds][ 65 | o_name 66 | ][0] 67 | # Mean over stats data 68 | elif ds in ["stats"]: 69 | aggregate_dict = {key: {} for key in data_items[ds]} 70 | for i, o_name in enumerate(data_items[ds]): 71 | if type(new_results_dict[eval_name][ds][o_name][0][0]) not in [ 72 | str, 73 | bytes, 74 | np.bytes_, 75 | np.str_, 76 | ]: 77 | # Compute mean and standard deviation over seeds 78 | mean_tol, std_tol = tolerant_mean( 79 | new_results_dict[eval_name][ds][o_name] 80 | ) 81 | aggregate_dict[o_name]["mean"] = mean_tol 82 | aggregate_dict[o_name]["std"] = std_tol 83 | 84 | # Compute 10, 25, 50, 75, 90 percentiles over seeds 85 | p50, p10, p25, p75, p90 = tolerant_median( 86 | new_results_dict[eval_name][ds][o_name] 87 | ) 88 | aggregate_dict[o_name]["p50"] = p50 89 | aggregate_dict[o_name]["p10"] = p10 90 | aggregate_dict[o_name]["p25"] = p25 91 | aggregate_dict[o_name]["p75"] = p75 92 | aggregate_dict[o_name]["p90"] = p90 93 | else: 94 | aggregate_dict[o_name] = new_results_dict[eval_name][ds][ 95 | o_name 96 | ] 97 | # Append over all meta data (strings, seeds nothing to mean) 98 | elif ds == "meta": 99 | aggregate_dict = {} 100 | for i, o_name in enumerate(data_items[ds]): 101 | temp = np.array( 102 | new_results_dict[eval_name][ds][o_name], 103 | dtype=object, 104 | ).squeeze() 105 | 106 | # Get rid of duplicate experiment dir 
strings 107 | if o_name in [ 108 | "experiment_dir", 109 | "eval_id", 110 | "config_fname", 111 | "model_type", 112 | "config_dict", 113 | ]: 114 | aggregate_dict[o_name] = np.unique(temp)[0].decode() 115 | else: 116 | aggregate_dict[o_name] = temp 117 | 118 | # Add seeds as clean array of integers to dict 119 | aggregate_dict["seeds"] = [ 120 | int(s.split("_")[1]) for s in all_seeds_for_run 121 | ] 122 | else: 123 | raise ValueError 124 | aggregate_sources[ds] = aggregate_dict 125 | new_results_dict[eval_name] = aggregate_sources 126 | return new_results_dict 127 | 128 | 129 | def aggregate_batch_evals(result_dict: dict, all_runs: list) -> dict: 130 | """Mean over seeds for all batches and evals.""" 131 | # Loop over all evals (e.g. b_1_eval_0) and merge + aggregate data 132 | new_results_dict = {} 133 | for eval in all_runs: 134 | all_seeds_for_run = list(result_dict[eval].keys()) 135 | eval_dict = aggregate_single_eval( 136 | result_dict[eval], all_seeds_for_run, eval 137 | ) 138 | new_results_dict[eval] = eval_dict[eval] 139 | return new_results_dict 140 | 141 | 142 | def tolerant_mean(arrs: List[Any]) -> Tuple[Any]: 143 | """Helper function for case where data to mean has different lengths.""" 144 | lens = [len(i) for i in arrs] 145 | if len(arrs[0].shape) == 1: 146 | arr = np.ma.empty((np.max(lens), len(arrs))) 147 | arr.mask = True 148 | for idx, l in enumerate(arrs): 149 | arr[: len(l), idx] = l 150 | else: 151 | arr = np.ma.empty((np.max(lens), arrs[0].shape[1], len(arrs))) 152 | arr.mask = True 153 | for idx, l in enumerate(arrs): 154 | arr[: len(l), :, idx] = l 155 | return arr.mean(axis=-1), arr.std(axis=-1) 156 | 157 | 158 | def tolerant_median(arrs: List[Any]) -> Tuple[Any]: 159 | """Helper function for case data to median has different lengths.""" 160 | lens = [len(i) for i in arrs] 161 | if len(arrs[0].shape) == 1: 162 | arr = np.ma.empty((np.max(lens), len(arrs))) 163 | arr.mask = True 164 | for idx, l in enumerate(arrs): 165 | arr[: len(l), idx] = l 166 | else: 167 | arr = np.ma.empty((np.max(lens), arrs[0].shape[1], len(arrs))) 168 | arr.mask = True 169 | for idx, l in enumerate(arrs): 170 | arr[: len(l), :, idx] = l 171 | return ( 172 | np.percentile(arr, 50, axis=-1), 173 | np.percentile(arr, 10, axis=-1), 174 | np.percentile(arr, 25, axis=-1), 175 | np.percentile(arr, 75, axis=-1), 176 | np.percentile(arr, 90, axis=-1), 177 | ) 178 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | import torch.nn as nn 5 | from sklearn.svm import SVC 6 | from mle_logging import MLELogger, load_model, load_log 7 | 8 | 9 | log_config = { 10 | "time_to_track": ["num_updates", "num_epochs"], 11 | "what_to_track": ["train_loss", "test_loss"], 12 | "experiment_dir": "experiment_dir/", 13 | "config_fname": None, 14 | "use_tboard": True, 15 | "model_type": "torch", 16 | } 17 | 18 | time_tic = {"num_updates": 10, "num_epochs": 1} 19 | stats_tic = {"train_loss": 0.1234, "test_loss": 0.1235} 20 | 21 | 22 | class DummyModel(nn.Module): 23 | def __init__(self): 24 | super(DummyModel, self).__init__() 25 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 26 | self.fc2 = nn.Linear(120, 84) 27 | self.fc3 = nn.Linear(84, 10) 28 | 29 | def forward(self, x): 30 | x = self.fc1(x) 31 | x = self.fc2(x) 32 | x = self.fc3(x) 33 | return x 34 | 35 | 36 | def create_tensorflow_model(): 37 | import tensorflow as tf 38 | from 
tensorflow import keras 39 | 40 | model = tf.keras.models.Sequential( 41 | [ 42 | keras.layers.Dense(512, activation="relu", input_shape=(784,)), 43 | keras.layers.Dropout(0.2), 44 | keras.layers.Dense(10), 45 | ] 46 | ) 47 | model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]) 48 | 49 | return model 50 | 51 | 52 | def test_save_load_torch(): 53 | """Test saving and loading of torch model.""" 54 | # Remove experiment dir at start of test 55 | if os.path.exists(log_config["experiment_dir"]) and os.path.isdir( 56 | log_config["experiment_dir"] 57 | ): 58 | shutil.rmtree(log_config["experiment_dir"]) 59 | 60 | # Instantiate logging to experiment_dir 61 | log_config["model_type"] = "torch" 62 | log = MLELogger(**log_config) 63 | 64 | # Save a torch model 65 | model = DummyModel() 66 | log.update(time_tic, stats_tic, model, save=True) 67 | # Assert the existence of the files 68 | file_to_check = os.path.join( 69 | log_config["experiment_dir"], "models/final", "final_no_seed_provided.pt" 70 | ) 71 | assert os.path.exists(file_to_check) 72 | 73 | # Load log and afterwards the model 74 | relog = load_log(log_config["experiment_dir"]) 75 | remodel = load_model(relog.meta.model_ckpt, log_config["model_type"], model) 76 | assert type(remodel) == DummyModel 77 | # Finally -- clean up 78 | shutil.rmtree(log_config["experiment_dir"]) 79 | 80 | 81 | # def test_save_load_tf(): 82 | # """Test saving and loading of tensorflow model.""" 83 | # # Remove experiment dir at start of test 84 | # if os.path.exists(log_config["experiment_dir"]) and os.path.isdir( 85 | # log_config["experiment_dir"] 86 | # ): 87 | # shutil.rmtree(log_config["experiment_dir"]) 88 | 89 | # # Instantiate logging to experiment_dir 90 | # log_config["model_type"] = "tensorflow" 91 | # log = MLELogger(**log_config) 92 | 93 | # # Save a torch model 94 | # model = create_tensorflow_model() 95 | # log.update(time_tic, stats_tic, model, save=True) 96 | # # Assert the existence of the files 97 | # file_to_check = os.path.join( 98 | # log_config["experiment_dir"], 99 | # "models/final", 100 | # "final_no_seed_provided.pt" + ".data-00000-of-00001", 101 | # ) 102 | # assert os.path.exists(file_to_check) 103 | # file_to_check = os.path.join( 104 | # log_config["experiment_dir"], 105 | # "models/final", 106 | # "final_no_seed_provided.pt" + ".index", 107 | # ) 108 | # assert os.path.exists(file_to_check) 109 | # file_to_check = os.path.join( 110 | # log_config["experiment_dir"], "models/final", "checkpoint" 111 | # ) 112 | # assert os.path.exists(file_to_check) 113 | 114 | # # Load log and afterwards the model 115 | # relog = load_log(log_config["experiment_dir"]) 116 | # _ = load_model(relog.meta.model_ckpt, log_config["model_type"], model) 117 | 118 | # # Finally -- clean up 119 | # shutil.rmtree(log_config["experiment_dir"]) 120 | 121 | 122 | # def test_save_load_jax(): 123 | # """Test saving and loading of jax model.""" 124 | # # Remove experiment dir at start of test 125 | # if os.path.exists(log_config["experiment_dir"]) and os.path.isdir( 126 | # log_config["experiment_dir"] 127 | # ): 128 | # shutil.rmtree(log_config["experiment_dir"]) 129 | 130 | # # Instantiate logging to experiment_dir 131 | # log_config["model_type"] = "jax" 132 | # log = MLELogger(**log_config) 133 | 134 | # # Save a torch model 135 | # import jax 136 | # import haiku as hk 137 | 138 | # def lenet_fn(x): 139 | # """Standard LeNet-300-100 MLP network.""" 140 | # mlp = hk.Sequential( 141 | # [ 142 | # hk.Flatten(), 143 | # hk.Linear(300), 144 
| # jax.nn.relu, 145 | # hk.Linear(100), 146 | # jax.nn.relu, 147 | # hk.Linear(10), 148 | # ] 149 | # ) 150 | # return mlp(x) 151 | 152 | # lenet = hk.without_apply_rng(hk.transform(lenet_fn)) 153 | # params = lenet.init(jax.random.PRNGKey(42), np.zeros((32, 784))) 154 | 155 | # log.update(time_tic, stats_tic, params, save=True) 156 | # # Assert the existence of the files 157 | # file_to_check = os.path.join( 158 | # log_config["experiment_dir"], "models/final", "final_no_seed_provided.pkl" 159 | # ) 160 | # assert os.path.exists(file_to_check) 161 | 162 | # # Load log and afterwards the model 163 | # relog = load_log(log_config["experiment_dir"]) 164 | # _ = load_model(relog.meta.model_ckpt, log_config["model_type"]) 165 | 166 | # # Finally -- clean up 167 | # shutil.rmtree(log_config["experiment_dir"]) 168 | 169 | 170 | def test_save_load_sklearn(): 171 | """Test saving and loading of sklearn model.""" 172 | # Remove experiment dir at start of test 173 | if os.path.exists(log_config["experiment_dir"]) and os.path.isdir( 174 | log_config["experiment_dir"] 175 | ): 176 | shutil.rmtree(log_config["experiment_dir"]) 177 | 178 | # Instantiate logging to experiment_dir 179 | log_config["model_type"] = "sklearn" 180 | log = MLELogger(**log_config) 181 | 182 | # Save a torch model 183 | model = SVC(gamma="auto") 184 | log.update(time_tic, stats_tic, model, save=True) 185 | 186 | # Assert the existence of the files 187 | file_to_check = os.path.join( 188 | log_config["experiment_dir"], "models/final", "final_no_seed_provided.pkl" 189 | ) 190 | assert os.path.exists(file_to_check) 191 | 192 | # Load log and afterwards the model 193 | relog = load_log(log_config["experiment_dir"]) 194 | remodel = load_model(relog.meta.model_ckpt, log_config["model_type"], model) 195 | assert type(remodel) == SVC 196 | # Finally -- clean up 197 | shutil.rmtree(log_config["experiment_dir"]) 198 | 199 | 200 | def test_save_load_numpy(): 201 | """Test saving and loading of numpy model/array.""" 202 | # Remove experiment dir at start of test 203 | if os.path.exists(log_config["experiment_dir"]) and os.path.isdir( 204 | log_config["experiment_dir"] 205 | ): 206 | shutil.rmtree(log_config["experiment_dir"]) 207 | 208 | # Instantiate logging to experiment_dir 209 | log_config["model_type"] = "numpy" 210 | log = MLELogger(**log_config) 211 | 212 | # Save a torch model 213 | model = np.array([1, 2, 3, 4]) 214 | log.update(time_tic, stats_tic, model, save=True) 215 | 216 | # Assert the existence of the files 217 | file_to_check = os.path.join( 218 | log_config["experiment_dir"], "models/final", "final_no_seed_provided.pkl" 219 | ) 220 | assert os.path.exists(file_to_check) 221 | 222 | # Load log and afterwards the model 223 | relog = load_log(log_config["experiment_dir"]) 224 | remodel = load_model(relog.meta.model_ckpt, log_config["model_type"], model) 225 | assert (remodel == model).all() 226 | # Finally -- clean up 227 | shutil.rmtree(log_config["experiment_dir"]) 228 | -------------------------------------------------------------------------------- /tests/test_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import collections 4 | import numpy as np 5 | import torch.nn as nn 6 | import matplotlib.pyplot as plt 7 | from mle_logging import MLELogger 8 | from mle_logging import merge_seed_logs, load_log 9 | from mle_logging import merge_config_logs 10 | 11 | 12 | log_config = { 13 | "time_to_track": ["num_updates", "num_epochs"], 14 | 
"what_to_track": ["train_loss", "test_loss"], 15 | "experiment_dir": "experiment_dir/", 16 | "model_type": "torch", 17 | } 18 | 19 | log_config1_seed1 = { 20 | "time_to_track": ["num_updates", "num_epochs"], 21 | "what_to_track": ["train_loss", "test_loss"], 22 | "experiment_dir": "experiment_dir/", 23 | "config_fname": "examples/config_1.json", 24 | "model_type": "torch", 25 | "seed_id": "seed_1", 26 | } 27 | 28 | log_config1_seed2 = { 29 | "time_to_track": ["num_updates", "num_epochs"], 30 | "what_to_track": ["train_loss", "test_loss"], 31 | "experiment_dir": "experiment_dir/", 32 | "config_fname": "examples/config_1.json", 33 | "model_type": "torch", 34 | "seed_id": "seed_2", 35 | } 36 | 37 | log_config2_seed1 = { 38 | "time_to_track": ["num_updates", "num_epochs"], 39 | "what_to_track": ["train_loss", "test_loss"], 40 | "experiment_dir": "experiment_dir/", 41 | "config_fname": "examples/config_2.json", 42 | "model_type": "torch", 43 | "seed_id": "seed_1", 44 | } 45 | 46 | log_config2_seed2 = { 47 | "time_to_track": ["num_steps", "num_epochs"], 48 | "what_to_track": ["train_loss", "test_loss"], 49 | "experiment_dir": "experiment_dir/", 50 | "config_fname": "examples/config_2.json", 51 | "model_type": "torch", 52 | "seed_id": "seed_2", 53 | } 54 | 55 | time_tic1 = {"num_steps": 10, "num_epochs": 1} 56 | stats_tic1 = {"train_loss": 0.1234, "test_loss": 0.1235} 57 | 58 | time_tic2 = {"num_steps": 10, "num_epochs": 1} 59 | stats_tic2 = {"train_loss": 0.2, "test_loss": 0.1} 60 | 61 | 62 | class DummyModel(nn.Module): 63 | def __init__(self): 64 | super(DummyModel, self).__init__() 65 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 66 | self.fc2 = nn.Linear(120, 84) 67 | self.fc3 = nn.Linear(84, 10) 68 | 69 | def forward(self, x): 70 | x = self.fc1(x) 71 | x = self.fc2(x) 72 | x = self.fc3(x) 73 | return x 74 | 75 | 76 | model = DummyModel() 77 | 78 | fig, ax = plt.subplots() 79 | ax.plot(np.random.normal(0, 1, 20)) 80 | 81 | some_dict = {"hi": "there"} 82 | 83 | 84 | def test_load_single(): 85 | """Test loading of single seed/config.""" 86 | # Remove experiment dir at start of test 87 | if os.path.exists(log_config["experiment_dir"]) and os.path.isdir( 88 | log_config["experiment_dir"] 89 | ): 90 | shutil.rmtree(log_config["experiment_dir"]) 91 | 92 | # Log some data 93 | log = MLELogger(**log_config) 94 | log.update(time_tic1, stats_tic1, model, 95 | plot_fig=fig, extra_obj=some_dict, save=True) 96 | 97 | # Reload log and check correctness of results 98 | relog = load_log(log_config["experiment_dir"]) 99 | 100 | meta_keys = [ 101 | "config_fname", 102 | "config_dict", 103 | "eval_id", 104 | "experiment_dir", 105 | "extra_storage_paths", 106 | "fig_storage_paths", 107 | "log_paths", 108 | "model_ckpt", 109 | "model_type", 110 | ] 111 | assert collections.Counter(list(relog.meta.keys())) == collections.Counter( 112 | meta_keys 113 | ) 114 | 115 | assert relog.stats.train_loss == 0.1234 116 | assert relog.time.num_steps == 10 117 | assert ( 118 | relog.meta.fig_storage_paths 119 | == "experiment_dir/figures/fig_1_no_seed_provided.png" 120 | ) 121 | assert ( 122 | relog.meta.extra_storage_paths 123 | == "experiment_dir/extra/extra_1_no_seed_provided.pkl" 124 | ) 125 | assert ( 126 | relog.meta.model_ckpt == "experiment_dir/models/final/final_no_seed_provided.pt" 127 | ) 128 | # Finally -- clean up 129 | shutil.rmtree(log_config["experiment_dir"]) 130 | 131 | 132 | def test_merge_load_seeds(): 133 | """Test merging of multiple seeds and loading.""" 134 | if 
os.path.exists(log_config1_seed1["experiment_dir"]) and os.path.isdir( 135 | log_config1_seed1["experiment_dir"] 136 | ): 137 | shutil.rmtree(log_config1_seed1["experiment_dir"]) 138 | 139 | # Log some data for both seeds 140 | log_seed1 = MLELogger(**log_config1_seed1) 141 | log_seed1.update(time_tic1, stats_tic1, model, 142 | plot_fig=fig, extra_obj=some_dict, save=True) 143 | 144 | log_seed2 = MLELogger(**log_config1_seed2) 145 | log_seed2.update(time_tic2, stats_tic2, model, 146 | plot_fig=fig, extra_obj=some_dict, save=True) 147 | 148 | experiment_dir = log_config["experiment_dir"] + "config_1/" 149 | merged_path = os.path.join(experiment_dir, "logs", "config_1.hdf5") 150 | 151 | # Merge different random seeds into one .hdf5 file 152 | merge_seed_logs(merged_path, experiment_dir) 153 | assert os.path.exists(os.path.join(experiment_dir, "config_1.json")) 154 | 155 | # Load the merged log - Individual seeds can be accessed via log.seed_1, etc. 156 | log = load_log(experiment_dir) 157 | assert log.seed_1.stats.train_loss == 0.1234 158 | assert log.seed_2.stats.train_loss == 0.2 159 | 160 | # Load the merged & aggregated log 161 | log = load_log(experiment_dir, aggregate_seeds=True).eval 162 | assert np.isclose(log.stats.train_loss.mean[0], np.mean([0.1234, 0.2])) 163 | assert np.isclose(log.stats.train_loss.std[0], np.std([0.1234, 0.2])) 164 | 165 | # Finally -- clean up 166 | shutil.rmtree(log_config1_seed1["experiment_dir"]) 167 | 168 | 169 | def test_merge_load_configs(): 170 | """Test merging of multiple configs and loading.""" 171 | if os.path.exists(log_config1_seed1["experiment_dir"]) and os.path.isdir( 172 | log_config1_seed1["experiment_dir"] 173 | ): 174 | shutil.rmtree(log_config1_seed1["experiment_dir"]) 175 | 176 | # Log some data for both seeds and both configs 177 | log_c1_s1 = MLELogger(**log_config1_seed1) 178 | log_c1_s2 = MLELogger(**log_config1_seed2) 179 | log_c2_s1 = MLELogger(**log_config2_seed1) 180 | log_c2_s2 = MLELogger(**log_config2_seed2) 181 | log_c1_s1.update(time_tic1, stats_tic1, model, fig, some_dict, save=True) 182 | log_c1_s2.update(time_tic2, stats_tic2, model, fig, some_dict, save=True) 183 | log_c2_s1.update(time_tic1, stats_tic1, model, fig, some_dict, save=True) 184 | log_c2_s2.update(time_tic2, stats_tic2, model, fig, some_dict, save=True) 185 | 186 | # Merge different random seeds for each config into separate .hdf5 file 187 | merge_seed_logs( 188 | f"{log_config1_seed1['experiment_dir']}/config_1/logs/config_1.hdf5", 189 | f"{log_config1_seed1['experiment_dir']}/config_1/", 190 | ) 191 | merge_seed_logs( 192 | f"{log_config1_seed1['experiment_dir']}/config_2/logs/config_2.hdf5", 193 | f"{log_config1_seed1['experiment_dir']}/config_2/", 194 | ) 195 | 196 | # Aggregate the different merged configuration .hdf5 files into single meta log 197 | eval_ids = ["config_1", "config_2"] 198 | seed_ids = ["seed_1", "seed_2"] 199 | merge_config_logs( 200 | experiment_dir=f"{log_config1_seed1['experiment_dir']}", all_run_ids=eval_ids 201 | ) 202 | meta_path = f"{log_config1_seed1['experiment_dir']}/meta_log.hdf5" 203 | meta_log = load_log(meta_path, aggregate_seeds=True) 204 | 205 | assert collections.Counter(meta_log.eval_ids) == collections.Counter(eval_ids) 206 | 207 | aggreg_keys = ["mean", "std", "p50", "p10", "p25", "p75", "p90"] 208 | assert collections.Counter( 209 | list(meta_log.config_1.stats.test_loss.keys()) 210 | ) == collections.Counter(aggreg_keys) 211 | 212 | meta_log = load_log(meta_path, aggregate_seeds=False) 213 | assert 
collections.Counter(meta_log.eval_ids) == collections.Counter(eval_ids) 214 | assert collections.Counter(meta_log.config_1.keys()) == collections.Counter( 215 | seed_ids 216 | ) 217 | # Finally -- clean up 218 | shutil.rmtree(log_config1_seed1["experiment_dir"]) 219 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Lightweight Logger for ML Experiments 📖 2 | [![Pyversions](https://img.shields.io/pypi/pyversions/mle-logging.svg?style=flat-square)](https://pypi.python.org/pypi/mle-logging) 3 | [![PyPI version](https://badge.fury.io/py/mle-logging.svg)](https://badge.fury.io/py/mle-logging) 4 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 5 | [![codecov](https://codecov.io/gh/mle-infrastructure/mle-logging/branch/main/graph/badge.svg)](https://codecov.io/gh/mle-infrastructure/mle-logging) 6 | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mle-infrastructure/mle-logging/blob/main/examples/getting_started.ipynb) 7 | 8 | 9 | Simple logging of statistics, model checkpoints, plots and other objects for your Machine Learning Experiments (MLE). Furthermore, the `MLELogger` comes with smooth multi-seed result aggregation and combination of multi-configuration runs. For a quickstart check out the [notebook blog](https://github.com/mle-infrastructure/mle-logging/blob/main/examples/getting_started.ipynb) 🚀 10 | 11 | ## The API 🎮 12 | 13 | ```python 14 | from mle_logging import MLELogger 15 | 16 | # Instantiate logging to experiment_dir 17 | log = MLELogger(time_to_track=['num_updates', 'num_epochs'], 18 | what_to_track=['train_loss', 'test_loss'], 19 | experiment_dir="experiment_dir/", 20 | model_type='torch') 21 | 22 | time_tic = {'num_updates': 10, 'num_epochs': 1} 23 | stats_tic = {'train_loss': 0.1234, 'test_loss': 0.1235} 24 | 25 | # Update the log with collected data & save it to .hdf5 26 | log.update(time_tic, stats_tic) 27 | log.save() 28 | ``` 29 | 30 | You can also log model checkpoints, matplotlib figures and other `.pkl` compatible objects. 31 | 32 | ```python 33 | # Save a model (torch, tensorflow, sklearn, jax, numpy) 34 | import torchvision.models as models 35 | model = models.resnet18() 36 | log.save_model(model) 37 | 38 | # Save a matplotlib figure as .png 39 | fig, ax = plt.subplots() 40 | log.save_plot(fig) 41 | 42 | # You can also save (somewhat) arbitrary objects .pkl 43 | some_dict = {"hi" : "there"} 44 | log.save_extra(some_dict) 45 | ``` 46 | 47 | 48 | Or do everything in a single line... 
49 | ```python 50 | log.update(time_tic, stats_tic, model, fig, extra, save=True) 51 | ``` 52 | 53 | ### File Structure & Re-Loading 📚 54 | 55 | ![](https://github.com/mle-infrastructure/mle-logging/blob/main/docs/mle_logger_structure.png?raw=true) 56 | 57 | The `MLELogger` will create a nested directory, which looks as follows: 58 | 59 | ``` 60 | experiment_dir 61 | ├── extra: Stores saved .pkl object files 62 | ├── figures: Stores saved .png figures 63 | ├── logs: Stores .hdf5 log files (meta, stats, time) 64 | ├── models: Stores different model checkpoints 65 | ├── init: Stores initial checkpoint 66 | ├── final: Stores most recent checkpoint 67 | ├── every_k: Stores every k-th checkpoint provided in update 68 | ├── top_k: Stores portfolio of top-k checkpoints based on performance 69 | ├── tboards: Stores tensorboards for model checkpointing 70 | ├── .json: Copy of configuration file (if provided) 71 | ``` 72 | 73 | For visualization and post-processing load the results via 74 | ```python 75 | from mle_logging import load_log 76 | log_out = load_log("experiment_dir/") 77 | 78 | # The results can be accessed via meta, stats and time keys 79 | # >>> log_out.meta.keys() 80 | # odict_keys(['experiment_dir', 'extra_storage_paths', 'fig_storage_paths', 'log_paths', 'model_ckpt', 'model_type']) 81 | # >>> log_out.stats.keys() 82 | # odict_keys(['test_loss', 'train_loss']) 83 | # >>> log_out.time.keys() 84 | # odict_keys(['time', 'num_epochs', 'num_updates', 'time_elapsed']) 85 | ``` 86 | 87 | If an experiment was aborted, you can reload and continue the previous run via the `reload=True` option: 88 | 89 | ```python 90 | log = MLELogger(time_to_track=['num_updates', 'num_epochs'], 91 | what_to_track=['train_loss', 'test_loss'], 92 | experiment_dir="experiment_dir/", 93 | model_type='torch', 94 | reload=True) 95 | ``` 96 | 97 | ## Installation ⏳ 98 | 99 | A PyPI installation is available via: 100 | 101 | ``` 102 | pip install mle-logging 103 | ``` 104 | 105 | If you want to get the most recent commit, please install directly from the repository: 106 | 107 | ``` 108 | pip install git+https://github.com/mle-infrastructure/mle-logging.git@main 109 | ``` 110 | 111 | 112 | ## Advanced Options 🚴 113 | 114 | ### Merging Multiple Logs 👫 115 | 116 | **Merging Multiple Random Seeds** 🌱 + 🌱 117 | 118 | ```python 119 | from mle_logging import merge_seed_logs 120 | merge_seed_logs("multi_seed.hdf", "experiment_dir/") 121 | log_out = load_log("experiment_dir/") 122 | # >>> log.eval_ids 123 | # ['seed_1', 'seed_2'] 124 | ``` 125 | 126 | **Merging Multiple Configurations** 🔖 + 🔖 127 | 128 | ```python 129 | from mle_logging import merge_config_logs, load_meta_log 130 | merge_config_logs(experiment_dir="experiment_dir/", 131 | all_run_ids=["config_1", "config_2"]) 132 | meta_log = load_meta_log("multi_config_dir/meta_log.hdf5") 133 | # >>> log.eval_ids 134 | # ['config_2', 'config_1'] 135 | # >>> meta_log.config_1.stats.test_loss.keys() 136 | # odict_keys(['mean', 'std', 'p50', 'p10', 'p25', 'p75', 'p90'])) 137 | ``` 138 | 139 | 140 | ### Plotting of Logs 🧑‍🎨 141 | 142 | ```python 143 | meta_log = load_meta_log("multi_config_dir/meta_log.hdf5") 144 | meta_log.plot("train_loss", "num_updates") 145 | ``` 146 | 147 | ### Storing Checkpoint Portfolios 📂 148 | 149 | **Logging every k-th checkpoint update** ❗ ⏩ ... 
⏩ ❗ 150 | 151 | ```python 152 | # Save every second checkpoint provided in log.update (stored in models/every_k) 153 | log = MLELogger(time_to_track=['num_updates', 'num_epochs'], 154 | what_to_track=['train_loss', 'test_loss'], 155 | experiment_dir='every_k_dir/', 156 | model_type='torch', 157 | ckpt_time_to_track='num_updates', 158 | save_every_k_ckpt=2) 159 | ``` 160 | 161 | **Logging top-k checkpoints based on metric** 🔱 162 | 163 | ```python 164 | # Save top-3 checkpoints provided in log.update (stored in models/top_k) 165 | # Based on minimizing the test_loss metric 166 | log = MLELogger(time_to_track=['num_updates', 'num_epochs'], 167 | what_to_track=['train_loss', 'test_loss'], 168 | experiment_dir="top_k_dir/", 169 | model_type='torch', 170 | ckpt_time_to_track='num_updates', 171 | save_top_k_ckpt=3, 172 | top_k_metric_name="test_loss", 173 | top_k_minimize_metric=True) 174 | ``` 175 | 176 | 177 | ### Weights&Biases Backend Integration 🧑‍🎨 178 | 179 | You can also use W&B as a backend for logging. All results are stored as before but additionally we report to the W&B server: 180 | 181 | ```python 182 | # Provide all configuration details as option 183 | log = MLELogger(time_to_track=['num_updates', 'num_epochs'], 184 | what_to_track=['train_loss', 'test_loss'], 185 | use_wandb=True, 186 | wandb_config={ 187 | "key": "sadfasd", # Only needed if not logged in 188 | "entity": "roberttlange", # Only needed if not logged in 189 | "project": "some-project-name", 190 | "group": "some-group-name" 191 | }) 192 | ``` 193 | 194 | ### Citing the MLE-Infrastructure ✏️ 195 | 196 | If you use `mle-logging` in your research, please cite it as follows: 197 | 198 | ``` 199 | @software{mle_infrastructure2021github, 200 | author = {Robert Tjarko Lange}, 201 | title = {{MLE-Infrastructure}: A Set of Lightweight Tools for Distributed Machine Learning Experimentation}, 202 | url = {http://github.com/mle-infrastructure}, 203 | year = {2021}, 204 | } 205 | ``` 206 | 207 | ## Development 👷 208 | 209 | You can run the test suite via `python -m pytest -vv tests/`. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start [contributing](CONTRIBUTING.md) 🤗. 
210 | -------------------------------------------------------------------------------- /mle_logging/save/model_log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from typing import Union, List 4 | from ..utils import save_pkl_object 5 | from ..load import load_log 6 | 7 | 8 | class ModelLog(object): 9 | """Model Logger Class Instance.""" 10 | 11 | def __init__( 12 | self, 13 | experiment_dir: str = "/", 14 | seed_id: str = "no_seed_provided", 15 | model_type: str = "no-model-type-provided", 16 | ckpt_time_to_track: Union[str, None] = None, 17 | save_every_k_ckpt: Union[int, None] = None, 18 | save_top_k_ckpt: Union[int, None] = None, 19 | top_k_metric_name: Union[str, None] = None, 20 | top_k_minimize_metric: Union[bool, None] = None, 21 | reload: bool = False, 22 | ): 23 | # Setup model logging 24 | self.experiment_dir = experiment_dir 25 | assert model_type in [ 26 | "torch", 27 | "tensorflow", 28 | "jax", 29 | "sklearn", 30 | "numpy", 31 | "no-model-type", 32 | ] 33 | self.model_type = model_type 34 | self.save_every_k_ckpt = save_every_k_ckpt 35 | self.save_top_k_ckpt = save_top_k_ckpt 36 | self.ckpt_time_to_track = ckpt_time_to_track 37 | self.top_k_metric_name = top_k_metric_name 38 | self.top_k_minimize_metric = top_k_minimize_metric 39 | self.seed_id = seed_id 40 | 41 | if self.save_every_k_ckpt: 42 | assert self.ckpt_time_to_track is not None 43 | if self.save_top_k_ckpt: 44 | assert self.ckpt_time_to_track is not None 45 | assert self.top_k_metric_name is not None 46 | assert self.top_k_minimize_metric is not None 47 | 48 | # Create separate filenames for checkpoints & final trained model 49 | self.ckpt_dir = os.path.join(self.experiment_dir, "models/") 50 | self.final_model_save_fname = os.path.join( 51 | self.ckpt_dir, "final", "final_" + seed_id 52 | ) 53 | self.init_model_save_fname = os.path.join( 54 | self.ckpt_dir, "init", "init_" + seed_id 55 | ) 56 | self.init_model_saved = False 57 | if self.save_every_k_ckpt is not None: 58 | self.every_k_ckpt_list: List[str] = [] 59 | self.every_k_dir = os.path.join( 60 | self.experiment_dir, "models/every_k/" 61 | ) 62 | self.every_k_model_save_fname = os.path.join( 63 | self.every_k_dir, "every_k_" + seed_id + "_k_" 64 | ) 65 | if self.save_top_k_ckpt is not None: 66 | self.top_k_ckpt_list: List[str] = [] 67 | self.top_k_dir = os.path.join(self.experiment_dir, "models/top_k/") 68 | self.top_k_model_save_fname = os.path.join( 69 | self.top_k_dir, "top_k_" + seed_id + "_top_" 70 | ) 71 | 72 | # Different extensions to model checkpoints based on model type 73 | if self.model_type in [ 74 | "torch", 75 | "tensorflow", 76 | "jax", 77 | "sklearn", 78 | "numpy", 79 | ]: 80 | if self.model_type in ["torch", "tensorflow"]: 81 | self.model_fname_ext = ".pt" 82 | elif self.model_type in ["jax", "sklearn", "numpy"]: 83 | self.model_fname_ext = ".pkl" 84 | self.final_model_save_fname += self.model_fname_ext 85 | self.init_model_save_fname += self.model_fname_ext 86 | 87 | # Initialize counter & lists for top k scores and storage time to track 88 | if reload: 89 | self.reload() 90 | else: 91 | self.model_save_counter = 0 92 | if self.save_every_k_ckpt is not None: 93 | self.every_k_storage_time: List[int] = [] 94 | if self.save_top_k_ckpt is not None: 95 | self.top_k_performance: List[float] = [] 96 | self.top_k_storage_time: List[int] = [] 97 | 98 | def setup_model_ckpt_dir(self): 99 | """Create separate sub-dirs for checkpoints & final trained model.""" 100 | 
os.makedirs(self.ckpt_dir, exist_ok=True) 101 | if self.save_every_k_ckpt is not None: 102 | os.makedirs(self.every_k_dir, exist_ok=True) 103 | if self.save_top_k_ckpt is not None: 104 | os.makedirs(self.top_k_dir, exist_ok=True) 105 | 106 | def save( 107 | self, model, clock_tracked: dict, stats_tracked: dict 108 | ): # noqa: C901 109 | """Save current state of the model as a checkpoint.""" 110 | # If first model ckpt is saved - generate necessary directories 111 | self.model_save_counter += 1 112 | if self.model_save_counter == 1: 113 | os.makedirs(os.path.join(self.ckpt_dir, "final"), exist_ok=True) 114 | self.stored_every_k = False 115 | self.stored_top_k = False 116 | if self.model_save_counter == 1: 117 | self.setup_model_ckpt_dir() 118 | 119 | # CASE 1: SIMPLE STORAGE OF MOST RECENTLY LOGGED MODEL STATE 120 | self.save_final_model(model) 121 | 122 | # CASE 2: SEPARATE STORAGE OF EVERY K-TH LOGGED MODEL STATE 123 | if self.save_every_k_ckpt is not None: 124 | self.save_every_k_model(model, clock_tracked) 125 | 126 | # CASE 3: STORE TOP-K MODEL STATES BY SOME SCORE 127 | if self.save_top_k_ckpt is not None: 128 | self.save_top_k_model(model, clock_tracked, stats_tracked) 129 | 130 | def save_init_model(self, model): 131 | """Store the initial model checkpoint and replace old ckpt.""" 132 | os.makedirs(os.path.join(self.ckpt_dir, "init"), exist_ok=True) 133 | save_model_ckpt(model, self.init_model_save_fname, self.model_type) 134 | self.init_model_saved = True 135 | 136 | def save_final_model(self, model): 137 | """Store the most recent model checkpoint and replace old ckpt.""" 138 | save_model_ckpt(model, self.final_model_save_fname, self.model_type) 139 | 140 | def save_every_k_model(self, model, clock_tracked: dict): 141 | """Store every kth provided checkpoint.""" 142 | if self.model_save_counter % self.save_every_k_ckpt == 0: 143 | ckpt_path = ( 144 | self.every_k_model_save_fname 145 | + str(self.model_save_counter) 146 | + self.model_fname_ext 147 | ) 148 | save_model_ckpt(model, ckpt_path, self.model_type) 149 | # Use latest update performance for last checkpoint 150 | time = clock_tracked[self.ckpt_time_to_track][-1] 151 | self.every_k_storage_time.append(time) 152 | self.every_k_ckpt_list.append(ckpt_path) 153 | self.stored_every_k = True 154 | 155 | def save_top_k_model(self, model, clock_tracked: dict, stats_tracked: dict): 156 | """Store top-k checkpoints by performance.""" 157 | # Use latest update performance for last checkpoint 158 | score = stats_tracked[self.top_k_metric_name][-1] 159 | time = clock_tracked[self.ckpt_time_to_track][-1] 160 | 161 | # Fill up empty top k slots 162 | if len(self.top_k_performance) < self.save_top_k_ckpt: 163 | ckpt_path = ( 164 | self.top_k_model_save_fname 165 | + str(len(self.top_k_performance)) 166 | + self.model_fname_ext 167 | ) 168 | save_model_ckpt(model, ckpt_path, self.model_type) 169 | self.top_k_performance.append(score) 170 | self.top_k_storage_time.append(time) 171 | self.top_k_ckpt_list.append(ckpt_path) 172 | self.stored_top_k = True 173 | return 174 | 175 | # If minimize = replace worst performing model (max score) 176 | # Note: The archive of checkpoints is not sorted by performance! 
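        # For maximization metrics the scores are negated below, so the
        # comparison can always treat the largest entry as the current worst
        # checkpoint and overwrite that slot in place.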
177 | if not self.top_k_minimize_metric: 178 | top_k_scores = [-1 * s for s in self.top_k_performance] 179 | score_to_eval = -1 * score 180 | else: 181 | top_k_scores = [s for s in self.top_k_performance] 182 | score_to_eval = score 183 | if max(top_k_scores) > score_to_eval: 184 | id_to_replace = np.argmax(top_k_scores) 185 | self.top_k_performance[id_to_replace] = score 186 | self.top_k_storage_time[id_to_replace] = time 187 | ckpt_path = ( 188 | self.top_k_model_save_fname 189 | + str(id_to_replace) 190 | + self.model_fname_ext 191 | ) 192 | save_model_ckpt(model, ckpt_path, self.model_type) 193 | self.stored_top_k = True 194 | 195 | def reload(self): 196 | """Reload results from previous experiment run.""" 197 | reloaded_log = load_log(self.experiment_dir, 198 | aggregate_seeds=False, 199 | reload_log=True) 200 | # Make sure to reload in results for correct seed 201 | if reloaded_log.eval_ids[0] == "no_seed_provided": 202 | meta_data = reloaded_log["no_seed_provided"].meta 203 | else: 204 | meta_data = reloaded_log[self.seed_id].meta 205 | 206 | # Reload counter & lists for top k scores and storage time to track 207 | if self.save_every_k_ckpt is not None: 208 | if type(meta_data.every_k_ckpt_list) == list: 209 | self.every_k_ckpt_list = [ 210 | ck for ck in meta_data.every_k_ckpt_list 211 | ] 212 | self.every_k_storage_time = meta_data.every_k_storage_time 213 | else: 214 | self.every_k_ckpt_list = [meta_data.every_k_ckpt_list] 215 | self.every_k_storage_time = [meta_data.every_k_storage_time] 216 | 217 | self.model_save_counter = int( 218 | self.every_k_ckpt_list[-1].split(self.model_fname_ext)[0][-1] 219 | ) 220 | else: 221 | self.model_save_counter = 0 222 | if self.save_top_k_ckpt is not None: 223 | self.top_k_ckpt_list = [ck for ck in meta_data.top_k_ckpt_list] 224 | self.top_k_storage_time = meta_data.top_k_storage_time 225 | self.top_k_performance = meta_data.top_k_performance 226 | 227 | 228 | def save_model_ckpt(model, model_save_fname: str, model_type: str) -> None: 229 | """Save the most recent model checkpoint.""" 230 | if model_type == "torch": 231 | # Torch model case - save model state dict as .pt checkpoint 232 | save_torch_model(model_save_fname, model) 233 | elif model_type == "tensorflow": 234 | model.save_weights(model_save_fname) 235 | elif model_type in ["jax", "sklearn", "numpy"]: 236 | # JAX/sklearn save parameter dict/model as dictionary 237 | save_pkl_object(model, model_save_fname) 238 | else: 239 | raise ValueError( 240 | "Provide valid model_type [torch, jax, sklearn, numpy]." 241 | ) 242 | 243 | 244 | def save_torch_model(path_to_store: str, model) -> None: 245 | """Store a torch checkpoint for a model.""" 246 | try: 247 | import torch 248 | except ModuleNotFoundError as err: 249 | raise ModuleNotFoundError( 250 | f"{err}. You need to install " 251 | "`torch` if you want to save a model " 252 | "checkpoint." 253 | ) 254 | # Update the saved weights in a single file! 
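    # Only the state_dict (parameter tensors) is written; load_model therefore
    # needs an instantiated model of the same architecture to copy the
    # reloaded weights back into.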
255 | torch.save(model.state_dict(), path_to_store) 256 | 257 | 258 | def save_tensorflow_model(path_to_store: str, model) -> None: 259 | """Store a tensorflow checkpoint for a model.""" 260 | model.save_weights(path_to_store) 261 | -------------------------------------------------------------------------------- /mle_logging/utils/comms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Union, List, Dict 3 | from rich.console import Console 4 | from rich.panel import Panel 5 | from rich.table import Table 6 | from rich import box 7 | import datetime 8 | from .._version import __version__ 9 | 10 | 11 | console_width = 80 12 | 13 | 14 | def print_welcome() -> None: 15 | """Display header with clock and general logging configurations.""" 16 | welcome_ascii = r""" 17 | __ __ __ ______ __ ______ ______ 18 | /\ "-./ \/\ \ /\ ___\/\ \ /\ __ \/\ ___\ 19 | \ \ \-./\ \ \ \___\ \ __\ \ \___\ \ \/\ \ \ \__ \ 20 | \ \_\ \ \_\ \_____\ \_____\ \_____\ \_____\ \_____\ 21 | \/_/ \/_/\/_____/\/_____/\/_____/\/_____/\/_____/ 22 | """.splitlines() 23 | grid = Table.grid(expand=True) 24 | grid.add_column(justify="left") 25 | grid.add_column(justify="right") 26 | grid.add_row( 27 | welcome_ascii[1], 28 | datetime.datetime.now().strftime("%d/%m/%y %H:%M:%S"), 29 | ) 30 | grid.add_row(welcome_ascii[2], f"Logger v{__version__} :lock_with_ink_pen:") 31 | grid.add_row( 32 | welcome_ascii[3], 33 | " [link=https://twitter.com/RobertTLange]@RobertTLange[/link] :bird:", 34 | ) 35 | grid.add_row( 36 | welcome_ascii[4], 37 | " [link=https://github.com/RobertTLange/mle-logging/blob/main/examples/getting_started.ipynb]MLE-Log" 38 | " Docs[/link] [not italic]:notebook:[/]", # noqa: E501 39 | ) 40 | grid.add_row( 41 | welcome_ascii[5], 42 | " [link=https://github.com/RobertTLange/mle-logging/]MLE-Log" 43 | " Repo[/link] [not italic]:pencil:[/]", # noqa: E501 44 | ) 45 | panel = Panel(grid, style="white on blue", expand=True) 46 | Console(width=console_width).print(panel) 47 | 48 | 49 | def print_startup( 50 | experiment_dir: str, 51 | config_fname: Union[str, None], 52 | time_to_track: Union[List[str], None], 53 | what_to_track: Union[List[str], None], 54 | model_type: str, 55 | seed_id: Union[str, int], 56 | use_tboard: bool, 57 | reload: bool, 58 | print_every_k_updates: Union[int, None], 59 | ckpt_time_to_track: Union[str, None], 60 | save_every_k_ckpt: Union[int, None], 61 | save_top_k_ckpt: Union[int, None], 62 | top_k_metric_name: Union[str, None], 63 | top_k_minimize_metric: Union[bool, None], 64 | ) -> None: 65 | """Rich print statement at logger startup. 66 | 67 | Args: 68 | experiment_dir (str): Base experiment directory. 69 | config_fname (Union[str, None]): Name of job configuration. 70 | time_to_track (Union[List[str], None]): Time variable names to store. 71 | what_to_track (Union[List[str], None]): Stats variable names to store. 72 | model_type (str): Model type to store ("jax", "torch", "tensorflow"). 73 | seed_id (Union[str, int]): Random seed used in experiment 74 | use_tboard (bool): Whether to also create tensorboard log. 75 | reload (bool): Whether to use reloaded previous log. 76 | print_every_k_updates (Union[int, None]): How often to print out log update. 77 | ckpt_time_to_track (Union[str, None]): Which time var to log with model ckpt. 78 | save_every_k_ckpt (Union[int, None]): How often to log model ckpt. 79 | save_top_k_ckpt (Union[int, None]): How many top performing ckpts to store. 
80 | top_k_metric_name (Union[str, None]): Which stats var to use for performance. 81 | top_k_minimize_metric (Union[bool, None]): Whether to minimize the stats var. 82 | """ 83 | grid = Table.grid(expand=True) 84 | grid.add_column(justify="left") 85 | grid.add_column(justify="left") 86 | 87 | def format_content(title, value): 88 | if type(value) == list: 89 | base = f"[b]{title}[/b]: " 90 | for i, v in enumerate(value): 91 | base += f"{v}" 92 | if i < len(value) - 1: 93 | base += ", " 94 | return base 95 | else: 96 | return f"[b]{title}[/b]: {value}" 97 | 98 | time_to_print = [ 99 | t for t in time_to_track if t not in ["time", "time_elapsed"] 100 | ] 101 | renderables = [ 102 | Panel(format_content(":book: Log Dir", experiment_dir), expand=True), 103 | Panel( 104 | format_content(":page_facing_up: Config", config_fname), expand=True 105 | ), 106 | Panel(format_content(":watch: Time", time_to_print), expand=True), 107 | Panel( 108 | format_content(":chart_with_downwards_trend: Stats", what_to_track), 109 | expand=True, 110 | ), 111 | Panel(format_content(":seedling: Seed ID", seed_id), expand=True), 112 | Panel( 113 | format_content( 114 | ":chart_with_upwards_trend: Tensorboard", use_tboard 115 | ), 116 | expand=True, 117 | ), 118 | Panel(format_content(":rocket: Model", model_type), expand=True), 119 | Panel( 120 | format_content("Tracked ckpt Time", ckpt_time_to_track), expand=True 121 | ), 122 | Panel( 123 | format_content(":clock1130: Every k-th ckpt", save_every_k_ckpt), 124 | expand=True, 125 | ), 126 | Panel( 127 | format_content(":trident: Top k ckpt", save_top_k_ckpt), expand=True 128 | ), 129 | Panel( 130 | format_content("Top k-th metric", top_k_metric_name), expand=True 131 | ), 132 | Panel( 133 | format_content("Top k-th minimization", top_k_minimize_metric), 134 | expand=True, 135 | ), 136 | ] 137 | 138 | grid.add_row(renderables[0], renderables[1]) 139 | grid.add_row(renderables[2], renderables[3]) 140 | grid.add_row(renderables[4], renderables[6]) 141 | if save_every_k_ckpt is None and save_top_k_ckpt is not None: 142 | grid.add_row( 143 | renderables[9], 144 | ) 145 | elif save_every_k_ckpt is not None and save_top_k_ckpt is None: 146 | grid.add_row( 147 | renderables[8], 148 | ) 149 | elif save_every_k_ckpt is not None and save_top_k_ckpt is not None: 150 | grid.add_row(renderables[8], renderables[9]) 151 | # grid.add_row(renderables[10], renderables[11]) 152 | panel = Panel(grid, expand=True) 153 | Console(width=console_width).print(panel) 154 | 155 | 156 | def print_update( 157 | time_to_print: List[str], 158 | what_to_print: List[str], 159 | c_tick: Dict[str, Union[str, float]], 160 | s_tick: Dict[str, float], 161 | print_header: bool, 162 | ) -> None: 163 | """Rich print statement for logger update. 164 | 165 | Args: 166 | time_to_print (List[str]): List of time variable names to print. 167 | what_to_print (List[str]): List of stats variable names to print. 168 | c_tick (Dict[str, Union[str, float]]): Dict of time variable values. 169 | s_tick (Dict[str, float]): Dict of stats variable values. 170 | print_header (bool): Whether to print table header. 
171 | """ 172 | table = Table( 173 | show_header=print_header, 174 | row_styles=["none"], 175 | border_style="white", 176 | box=box.SIMPLE, 177 | ) 178 | # Add watch and book emoji 179 | for i, c_label in enumerate(time_to_print): 180 | if i == 0: 181 | table.add_column( 182 | ":watch: [red]" + c_label + "[/red]", 183 | style="red", 184 | width=14, 185 | justify="left", 186 | ) 187 | else: 188 | table.add_column( 189 | "[red]" + c_label + "[/red]", 190 | style="red", 191 | width=12, 192 | justify="center", 193 | ) 194 | for i, c_label in enumerate(what_to_print): 195 | if i == 0: 196 | table.add_column( 197 | ":chart_with_downwards_trend: [blue]" + c_label + "[/blue]", 198 | width=14, 199 | justify="center", 200 | ) 201 | else: 202 | table.add_column( 203 | "[blue]" + c_label + "[/blue]", 204 | width=12, 205 | justify="center", 206 | ) 207 | row_list_time = [] 208 | for c in time_to_print: 209 | if c in c_tick.keys(): 210 | row_list_time.append(c_tick[c]) 211 | else: 212 | row_list_time.append("---") 213 | row_list_stats = [] 214 | for s in what_to_print: 215 | if s in s_tick.keys(): 216 | row_list_stats.append(np.round_(s_tick[s], 3)) 217 | else: 218 | row_list_stats.append("---") 219 | row_list = row_list_time + row_list_stats 220 | row_str_list = [str(v) for v in row_list] 221 | table.add_row(*row_str_list) 222 | 223 | # Print statistics update 224 | Console(width=console_width).print(table, justify="center") 225 | 226 | 227 | def print_reload(experiment_dir: str) -> None: 228 | """Rich print statement for logger reloading. 229 | 230 | Args: 231 | experiment_dir (str): Base experiment directory. 232 | """ 233 | Console().log(f"Reloaded log from {experiment_dir}") 234 | 235 | 236 | def print_storage( 237 | fig_path: Union[str, None] = None, 238 | extra_path: Union[str, None] = None, 239 | init_model_path: Union[str, None] = None, 240 | final_model_path: Union[str, None] = None, 241 | every_k_model_path: Union[str, None] = None, 242 | top_k_model_path: Union[str, None] = None, 243 | print_first: bool = False, 244 | ): 245 | """Rich print statement for object saving log. 246 | 247 | Args: 248 | fig_path (Union[str, None], optional): 249 | Path figure was stored at. Defaults to None. 250 | extra_path (Union[str, None], optional): 251 | Path extra object was stored at. Defaults to None. 252 | init_model_path (Union[str, None], optional): 253 | Path initial model ckpt was stored at. Defaults to None. 254 | final_model_path (Union[str, None], optional): 255 | Path most recent model ckpt was stored at. Defaults to None. 256 | every_k_model_path (Union[str, None], optional): 257 | Path last k-th update model ckpt was stored at. Defaults to None. 258 | top_k_model_path (Union[str, None], optional): 259 | Path top-k model ckpt was stored at. Defaults to None. 260 | print_first (bool, optional): 261 | Whether to always print init/final ckpt path. Defaults to False. 
262 | """ 263 | table = Table( 264 | show_header=False, 265 | row_styles=["none"], 266 | border_style="white", 267 | box=box.SIMPLE, 268 | ) 269 | 270 | table.add_column( 271 | "---", 272 | style="red", 273 | width=16, 274 | justify="left", 275 | ) 276 | 277 | table.add_column( 278 | "---", 279 | style="red", 280 | width=64, 281 | justify="left", 282 | ) 283 | 284 | if fig_path is not None: 285 | table.add_row(":envelope_with_arrow: - Figure", f"{fig_path}") 286 | if extra_path is not None: 287 | table.add_row(":envelope_with_arrow: - Extra", f"{extra_path}") 288 | if init_model_path is not None and print_first: 289 | table.add_row(":envelope_with_arrow: - Model", f"{init_model_path}") 290 | if final_model_path is not None and print_first: 291 | table.add_row(":envelope_with_arrow: - Model", f"{final_model_path}") 292 | if every_k_model_path is not None: 293 | table.add_row( 294 | ":envelope_with_arrow: - Every-K", f"{every_k_model_path}" 295 | ) 296 | if top_k_model_path is not None: 297 | table.add_row(":envelope_with_arrow: - Top-K", f"{top_k_model_path}") 298 | 299 | to_print = ( 300 | (fig_path is not None) 301 | + (extra_path is not None) 302 | + (init_model_path is not None and print_first) 303 | + (final_model_path is not None and print_first) 304 | + (every_k_model_path is not None) 305 | + (top_k_model_path is not None) 306 | ) > 0 307 | # Print storage update 308 | if to_print: 309 | Console(width=console_width).print(table, justify="left") 310 | -------------------------------------------------------------------------------- /mle_logging/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Any, Union, List, Tuple 4 | import h5py 5 | import yaml 6 | import commentjson 7 | import numpy as np 8 | import pandas as pd 9 | import re 10 | from dotmap import DotMap 11 | import collections 12 | 13 | if sys.version_info < (3, 8): 14 | # Load with pickle5 for python version compatibility 15 | import pickle5 as pickle 16 | else: 17 | import pickle 18 | 19 | 20 | def save_pkl_object(obj: Any, filename: str) -> None: 21 | """Store objects as pickle files. 22 | 23 | Args: 24 | obj (Any): Object to pickle. 25 | filename (str): File path to store object in. 26 | """ 27 | with open(filename, "wb") as output: 28 | # Overwrites any existing file. 29 | pickle.dump(obj, output, pickle.HIGHEST_PROTOCOL) 30 | 31 | 32 | def load_pkl_object(filename: str) -> Any: 33 | """Reload pickle objects from path. 34 | 35 | Args: 36 | filename (str): File path to load object from. 37 | 38 | Returns: 39 | Any: Reloaded object. 40 | """ 41 | with open(filename, "rb") as input: 42 | obj = pickle.load(input) 43 | return obj 44 | 45 | 46 | def load_config( 47 | config_fname: str, return_dotmap: bool = False 48 | ) -> Union[dict, DotMap]: 49 | """Load JSON/YAML config depending on file ending. 50 | 51 | Args: 52 | config_fname (str): 53 | File path to YAML/JSON configuration file. 54 | return_dotmap (bool, optional): 55 | Option to return dot indexable dictionary. Defaults to False. 56 | 57 | Raises: 58 | ValueError: Only YAML/JSON files can be loaded. 59 | 60 | Returns: 61 | Union[dict, DotMap]: Loaded dictionary from file. 
62 | """ 63 | fname, fext = os.path.splitext(config_fname) 64 | if fext == ".yaml": 65 | config = load_yaml_config(config_fname, return_dotmap) 66 | elif fext == ".json": 67 | config = load_json_config(config_fname, return_dotmap) 68 | else: 69 | raise ValueError("Only YAML & JSON configuration can be loaded.") 70 | return config 71 | 72 | 73 | def load_yaml_config( 74 | config_fname: str, return_dotmap: bool = False 75 | ) -> Union[dict, DotMap]: 76 | """Load in YAML config file. 77 | 78 | Args: 79 | config_fname (str): 80 | File path to YAML configuration file. 81 | return_dotmap (bool, optional): 82 | Option to return dot indexable dictionary. Defaults to False. 83 | 84 | Returns: 85 | Union[dict, DotMap]: Loaded dictionary from YAML file. 86 | """ 87 | loader = yaml.SafeLoader 88 | loader.add_implicit_resolver( 89 | "tag:yaml.org,2002:float", 90 | re.compile( 91 | """^(?: 92 | [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? 93 | |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) 94 | |\\.[0-9_]+(?:[eE][-+][0-9]+)? 95 | |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* 96 | |[-+]?\\.(?:inf|Inf|INF) 97 | |\\.(?:nan|NaN|NAN))$""", 98 | re.X, 99 | ), 100 | list("-+0123456789."), 101 | ) 102 | with open(config_fname) as file: 103 | yaml_config = yaml.load(file, Loader=loader) 104 | if not return_dotmap: 105 | return yaml_config 106 | else: 107 | return DotMap(yaml_config) 108 | 109 | 110 | def load_json_config( 111 | config_fname: str, return_dotmap: bool = False 112 | ) -> Union[dict, DotMap]: 113 | """Load in JSON config file. 114 | 115 | Args: 116 | config_fname (str): 117 | File path to JSON configuration file. 118 | return_dotmap (bool, optional): 119 | Option to return dot indexable dictionary. Defaults to False. 120 | 121 | Returns: 122 | Union[dict, DotMap]: Loaded dictionary from JSON file. 123 | """ 124 | json_config = commentjson.loads(open(config_fname, "r").read()) 125 | if not return_dotmap: 126 | return json_config 127 | else: 128 | return DotMap(json_config) 129 | 130 | 131 | def write_to_hdf5( 132 | log_fname: str, log_path: str, data_to_log: Any, dtype: str = "S5000" 133 | ) -> None: 134 | """Writes data to an hdf5 file and specified log path within. 135 | 136 | Args: 137 | log_fname (str): Path of hdf5 file. 138 | log_path (str): Path within hdf5 file to store data at. 139 | data_to_log (Any): Data (array, list, etc.) to store at `log_path` 140 | dtype (str, optional): Data type to store as. Defaults to "S5000". 141 | """ 142 | # Store figure paths if any where created 143 | if dtype == "S5000": 144 | try: 145 | data_to_store = [t.encode("ascii", "ignore") for t in data_to_log] 146 | except AttributeError: 147 | data_to_store = data_to_log 148 | else: 149 | data_to_store = np.array(data_to_log) 150 | 151 | h5f = h5py.File(log_fname, "a") 152 | if h5f.get(log_path): 153 | del h5f[log_path] 154 | h5f.create_dataset( 155 | name=log_path, 156 | data=data_to_store, 157 | compression="gzip", 158 | compression_opts=4, 159 | dtype=dtype, 160 | ) 161 | h5f.flush() 162 | h5f.close() 163 | 164 | 165 | def moving_smooth_ts( 166 | ts, window_size: int = 20 167 | ) -> Tuple[pd.core.series.Series, pd.core.series.Series]: 168 | """Smoothes a time series using a moving average filter. 169 | 170 | Args: 171 | ts: 172 | Time series to smooth. 173 | window_size (int, optional): 174 | Window size to apply for moving average. Defaults to 20. 175 | 176 | Returns: 177 | Tuple[pd.core.series.Series, pd.core.series.Series]: 178 | Smoothed mean and standard deviation of time series. 
179 | """ 180 | smooth_df = pd.DataFrame(ts) 181 | mean_ts = smooth_df[0].rolling(window_size, min_periods=1).mean() 182 | std_ts = smooth_df[0].rolling(window_size, min_periods=1).std() 183 | return mean_ts, std_ts 184 | 185 | 186 | def visualize_1D_lcurves( # noqa: C901 187 | main_log: dict, 188 | iter_to_plot: str = "num_updates", 189 | target_to_plot: Union[List[str], str] = "loss", 190 | smooth_window: int = 1, 191 | plot_title: Union[str, None] = None, 192 | xy_labels: Union[List[str], None] = None, 193 | base_label: str = "{}", 194 | curve_labels: list = [], 195 | every_nth_tick: Union[int, None] = None, 196 | plot_std_bar: bool = False, 197 | run_ids: Union[None, List[str]] = None, 198 | rgb_tuples: Union[List[tuple], None] = None, 199 | num_legend_cols: Union[int, None] = 1, 200 | fig=None, 201 | ax=None, 202 | figsize: tuple = (9, 6), 203 | plot_labels: bool = True, 204 | legend_title: Union[None, str] = None, 205 | ax_lims: Union[None, list] = None, 206 | ) -> tuple: 207 | """Plot stats curves over time from meta_log. Select data and customize plot. 208 | 209 | Args: 210 | iter_to_plot (str, optional): 211 | Time variable to plot in log `time`. Defaults to "num_updates". 212 | target_to_plot (Union[List[str], str], optional): 213 | Stats variable to plot in log `stats`. Defaults to "loss". 214 | smooth_window (int, optional): 215 | Time series moving average smoothing window. Defaults to 1. 216 | plot_title (Union[str, None], optional): 217 | Title for plot. Defaults to None. 218 | xy_labels (Union[List[str], None], optional): 219 | List of x & y plot labels. Defaults to None. 220 | base_label (str, optional): 221 | Base start of line labels. Defaults to "{}". 222 | curve_labels (list, optional): 223 | Explicit labels for individual lines. Defaults to []. 224 | every_nth_tick (Union[int, None], optional): 225 | Only plot every nth tick. Leave others out. Defaults to None. 226 | plot_std_bar (bool, optional): 227 | Whether to also plot standard deviation. Defaults to False. 228 | run_ids (Union[None, List[str]], optional): 229 | Explicit string id of runs to plot from log. Defaults to None. 230 | rgb_tuples (Union[List[tuple], None], optional): 231 | Color tuple to use in color palette. Defaults to None. 232 | num_legend_cols (Union[int, None], optional): 233 | Number of columns to split legend in. Defaults to 1. 234 | fig (Union[matplotlib.figure.Figure, None], optional): 235 | Matplotlib figure to modify. Defaults to None. 236 | ax (Union[matplotlib.axes._subplots.AxesSubplot, None], optional): 237 | Matplotlib axis to modify. Defaults to None. 238 | figsize (tuple, optional): 239 | Desired figure size. Defaults to (9, 6). 240 | plot_labels (bool): 241 | Whether to plot curve labels 242 | legend_title (str, optional): 243 | Title of legend. Defaults to None. 244 | ax_lims (list, optional): 245 | Max/min axis range. Defaults to None. 246 | 247 | Returns: 248 | Tuple[matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot]: 249 | Modified matplotlib figure and axis. 250 | """ 251 | 252 | try: 253 | import matplotlib.pyplot as plt 254 | import seaborn as sns 255 | 256 | sns.set( 257 | context="poster", 258 | style="white", 259 | palette="Paired", 260 | font="sans-serif", 261 | font_scale=1.0, 262 | color_codes=True, 263 | rc=None, 264 | ) 265 | except ImportError: 266 | raise ImportError( 267 | "You need to install `matplotlib` & `seaborn` to use `mle-logging`" 268 | " visualization utilities." 
269 | ) 270 | 271 | if fig is None or ax is None: 272 | fig, ax = plt.subplots(1, 1, figsize=figsize) 273 | 274 | # Make robust for list/str target variable name input 275 | if type(target_to_plot) is str: 276 | target_to_plot = [target_to_plot] 277 | multi_target = False 278 | else: 279 | multi_target = True 280 | 281 | # If single run - add placeholder key run_id 282 | if run_ids is None: 283 | run_ids = ["ph_run"] 284 | log_to_plot = {"ph_run": main_log} 285 | else: 286 | log_to_plot = main_log 287 | run_ids.sort(key=tokenize) 288 | 289 | # Plot all curves if not subselected 290 | single_level = collections.Counter( 291 | log_to_plot[run_ids[0]].keys() 292 | ) == collections.Counter(["stats", "time", "meta"]) 293 | 294 | # If single seed/aggregated - add placeholder key seed_id 295 | if single_level: 296 | for run_id in run_ids: 297 | log_to_plot[run_id] = {"ph_seed": log_to_plot[run_id]} 298 | seed_ids = ["ph_seed"] 299 | single_seed = True 300 | else: 301 | seed_ids = list(log_to_plot[run_ids[0]].keys()) 302 | single_seed = False 303 | 304 | if len(curve_labels) == 0: 305 | curve_labels = [] 306 | for r_id in run_ids: 307 | for s_id in seed_ids: 308 | for target in target_to_plot: 309 | c_label = f"{r_id}" 310 | if multi_target: 311 | c_label = f"{target}: " + c_label 312 | if not single_seed: 313 | c_label += f"/{s_id}" 314 | curve_labels.append(c_label) 315 | 316 | if rgb_tuples is None: 317 | # Default colormap is blue to red diverging seaborn palette 318 | color_by = sns.diverging_palette( 319 | 240, 10, sep=1, n=len(run_ids) * len(seed_ids) * len(target_to_plot) 320 | ) 321 | # color_by = sns.light_palette("navy", len(run_ids), reverse=False) 322 | else: 323 | color_by = rgb_tuples 324 | 325 | plot_counter = 0 326 | for i in range(len(run_ids)): 327 | run_id = run_ids[i] 328 | for j in range(len(seed_ids)): 329 | seed_id = seed_ids[j] 330 | for target in target_to_plot: 331 | label = curve_labels[plot_counter] 332 | if ( 333 | type(log_to_plot[run_id][seed_id].stats[target]) == dict 334 | or type(log_to_plot[run_id][seed_id].stats[target]) 335 | == DotMap 336 | ): 337 | plot_mean = True 338 | mean_to_plot = log_to_plot[run_id][seed_id].stats[target][ 339 | "mean" 340 | ] 341 | std_to_plot = log_to_plot[run_id][seed_id].stats[target][ 342 | "std" 343 | ] 344 | smooth_std, _ = moving_smooth_ts(std_to_plot, smooth_window) 345 | else: 346 | plot_mean = False 347 | mean_to_plot = log_to_plot[run_id][seed_id].stats[target] 348 | 349 | # Smooth the curve to plot for a specified window (1 = no smoothing) 350 | smooth_mean, _ = moving_smooth_ts(mean_to_plot, smooth_window) 351 | ax.plot( 352 | log_to_plot[run_id][seed_id].time[iter_to_plot], 353 | smooth_mean, 354 | color=color_by[plot_counter], 355 | label=base_label.format(label), 356 | alpha=0.85, 357 | ) 358 | 359 | if plot_std_bar and plot_mean: 360 | ax.fill_between( 361 | log_to_plot[run_id][seed_id].time[iter_to_plot], 362 | smooth_mean - smooth_std, 363 | smooth_mean + smooth_std, 364 | color=color_by[plot_counter], 365 | alpha=0.25, 366 | ) 367 | plot_counter += 1 368 | 369 | full_range_x = log_to_plot[run_id][seed_id].time[iter_to_plot] 370 | # Either plot every nth time tic or 5 equally spaced ones 371 | if every_nth_tick is not None: 372 | ax.set_xticks(full_range_x) 373 | ax.set_xticklabels([str(int(label)) for label in full_range_x]) 374 | for n, label in enumerate(ax.xaxis.get_ticklabels()): 375 | if n % every_nth_tick != 0: 376 | label.set_visible(False) 377 | else: 378 | idx = np.round(np.linspace(0, len(full_range_x) - 
1, 5)).astype(int) 379 | range_x = full_range_x[idx] 380 | ax.set_xticks(range_x) 381 | ax.set_xticklabels([str(int(label)) for label in range_x]) 382 | 383 | if len(curve_labels) > 1 and plot_labels: 384 | if legend_title is None: 385 | ax.legend(fontsize=7, ncol=num_legend_cols) 386 | else: 387 | lg = ax.legend(fontsize=7, ncol=num_legend_cols, title=legend_title) 388 | title = lg.get_title() 389 | title.set_fontsize(10) 390 | 391 | ax.spines["top"].set_visible(False) 392 | ax.spines["right"].set_visible(False) 393 | 394 | if ax_lims is not None: 395 | ax.set_ylim(ax_lims) 396 | if plot_title is None: 397 | plot_title = ", ".join(target_to_plot) 398 | ax.set_title(plot_title) 399 | if xy_labels is None: 400 | xy_labels = [iter_to_plot, ", ".join(target_to_plot)] 401 | ax.set_xlabel(xy_labels[0]) 402 | ax.set_ylabel(xy_labels[1]) 403 | fig.tight_layout() 404 | return fig, ax 405 | 406 | 407 | def tokenize(filename: str): 408 | """Helper to sort the log files alphanumerically. 409 | 410 | Args: 411 | filename (str): Name of run. 412 | """ 413 | digits = re.compile(r"(\d+)") 414 | return tuple( 415 | int(token) if match else token 416 | for token, match in ( 417 | (fragment, digits.search(fragment)) 418 | for fragment in digits.split(filename) 419 | ) 420 | ) 421 | -------------------------------------------------------------------------------- /mle_logging/mle_logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import shutil 4 | import yaml 5 | from typing import Optional, Union, List, Dict 6 | from rich.console import Console 7 | from .utils import ( 8 | write_to_hdf5, 9 | load_config, 10 | print_welcome, 11 | print_startup, 12 | print_update, 13 | print_reload, 14 | print_storage, 15 | ) 16 | from .save import StatsLog, TboardLog, WandbLog, ModelLog, FigureLog, ExtraLog 17 | 18 | 19 | class MLELogger(object): 20 | """ 21 | Logging object for Machine Learning experiments 22 | 23 | Args: 24 | ======= TRACKING AND PRINTING VARIABLE NAMES 25 | time_to_track (List[str]): column names of pandas df - time 26 | what_to_track (List[str]): column names of pandas df - statistics 27 | time_to_print (List[str]): subset columns of time df to print out 28 | what_to_print (List[str]): subset columns of stats df to print out 29 | ======= TRACKING AND PRINTING VARIABLE NAMES 30 | config_fname (str): file path of configuration of experiment 31 | config_dict(dict): dictionary of experiment config to store in yaml 32 | experiment_dir (str): base experiment directory 33 | seed_id (str): seed id to distinguish logs with (e.g. 
seed_0) 34 | overwrite (bool): delete old log file/tboard dir 35 | ======= VERBOSITY/TBOARD LOGGING 36 | use_tboard (bool): whether to log to tensorboard 37 | use_wandb (bool): whether to log to wandb 38 | log_every_j_steps (int): steps between log updates 39 | print_every_k_updates (int): after how many log updates - verbose 40 | ======= MODEL STORAGE 41 | model_type (str): ["torch", "jax", "sklearn", "numpy"] 42 | ckpt_time_to_track (str): Variable name/score key to save 43 | save_every_k_ckpt (int): save every other checkpoint 44 | save_top_k_ckpt (int): save top k performing checkpoints 45 | top_k_metric_name (str): Variable name/score key to save 46 | top_k_minimize_metric (str): Boolean for min/max score in top k logging 47 | """ 48 | 49 | def __init__( 50 | self, 51 | experiment_dir: str = "/", 52 | time_to_track: List[str] = [], 53 | what_to_track: List[str] = [], 54 | time_to_print: Optional[List[str]] = None, 55 | what_to_print: Optional[List[str]] = None, 56 | config_fname: Optional[str] = None, 57 | config_dict: Optional[dict] = None, 58 | seed_id: Union[str, int] = "no_seed_provided", 59 | overwrite: bool = False, 60 | use_tboard: bool = False, 61 | use_wandb: bool = False, 62 | wandb_config: Optional[dict] = None, 63 | log_every_j_steps: Optional[int] = None, 64 | print_every_k_updates: Optional[int] = 1, 65 | model_type: str = "no-model-type", 66 | ckpt_time_to_track: Optional[str] = None, 67 | save_every_k_ckpt: Optional[int] = None, 68 | save_top_k_ckpt: Optional[int] = None, 69 | top_k_metric_name: Optional[str] = None, 70 | top_k_minimize_metric: Optional[bool] = None, 71 | reload: bool = False, 72 | verbose: bool = False, 73 | ): 74 | # Set os hdf file to non locking mode 75 | os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" 76 | # Set up tensorboard when/where to log and when to print 77 | self.use_tboard = use_tboard 78 | self.use_wandb = use_wandb 79 | self.log_every_j_steps = log_every_j_steps 80 | self.print_every_k_updates = print_every_k_updates 81 | self.log_save_counter = reload 82 | self.log_setup_counter = reload 83 | self.seed_id = "seed_" + str(seed_id) if type(seed_id) == int else seed_id 84 | self.config_fname = config_fname 85 | self.config_dict = config_dict 86 | 87 | self.get_configs_ready(self.config_fname, self.config_dict) 88 | 89 | # Set up the logging directories - copy timestamped config file 90 | self.setup_experiment( 91 | experiment_dir, 92 | config_fname, 93 | self.seed_id, 94 | overwrite, 95 | reload, 96 | ) 97 | 98 | # STATS & TENSORBOARD LOGGING SETUP 99 | self.stats_log = StatsLog( 100 | self.experiment_dir, 101 | self.seed_id, 102 | time_to_track, 103 | what_to_track, 104 | reload, 105 | ) 106 | if self.use_tboard: 107 | self.tboard_log = TboardLog( 108 | self.experiment_dir, 109 | self.seed_id, 110 | ) 111 | if self.use_wandb: 112 | self.wandb_log = WandbLog( 113 | self.config_dict, self.config_fname, self.seed_id, wandb_config 114 | ) 115 | 116 | # MODEL, FIGURE & EXTRA LOGGING SETUP 117 | self.model_log = ModelLog( 118 | self.experiment_dir, 119 | self.seed_id, 120 | model_type, 121 | ckpt_time_to_track, 122 | save_every_k_ckpt, 123 | save_top_k_ckpt, 124 | top_k_metric_name, 125 | top_k_minimize_metric, 126 | reload, 127 | ) 128 | 129 | self.figure_log = FigureLog( 130 | self.experiment_dir, 131 | self.seed_id, 132 | reload, 133 | ) 134 | self.extra_log = ExtraLog( 135 | self.experiment_dir, 136 | self.seed_id, 137 | reload, 138 | ) 139 | 140 | # VERBOSITY SETUP: Set up what to print 141 | self.verbose = verbose 142 | self.print_counter 
= 0 143 | self.time_to_print = time_to_print 144 | self.what_to_print = what_to_print 145 | 146 | if not reload and verbose: 147 | print_welcome() 148 | print_startup( 149 | self.experiment_dir, 150 | self.config_fname, 151 | time_to_track, 152 | what_to_track, 153 | model_type, 154 | seed_id, 155 | use_tboard, 156 | reload, 157 | print_every_k_updates, 158 | ckpt_time_to_track, 159 | save_every_k_ckpt, 160 | save_top_k_ckpt, 161 | top_k_metric_name, 162 | top_k_minimize_metric, 163 | ) 164 | elif reload and verbose: 165 | print_reload( 166 | self.experiment_dir, 167 | ) 168 | 169 | def setup_experiment( 170 | self, 171 | base_exp_dir: str, 172 | config_fname: Union[str, None], 173 | seed_id: str, 174 | overwrite_experiment_dir: bool = False, 175 | reload: bool = False, 176 | ) -> None: 177 | """Setup directory name and clean up previous logging data.""" 178 | # Get timestamp of experiment & create new directories 179 | if config_fname is not None: 180 | self.base_str = os.path.split(config_fname)[1].split(".")[0] 181 | if not reload: 182 | self.experiment_dir = os.path.join(base_exp_dir, self.base_str) 183 | else: 184 | # Don't redefine experiment directory but get already existing 185 | exp_dir = [ 186 | f for f in os.listdir(base_exp_dir) if f.endswith(self.base_str) 187 | ][0] 188 | self.experiment_dir = os.path.join(base_exp_dir, exp_dir) 189 | else: 190 | self.base_str = "" 191 | self.experiment_dir = base_exp_dir 192 | 193 | self.log_save_fname = os.path.join( 194 | self.experiment_dir, "logs/", "log_" + seed_id + ".hdf5" 195 | ) 196 | aggregated_log_save_fname = os.path.join( 197 | self.experiment_dir, "logs/", "log.hdf5" 198 | ) 199 | 200 | # Delete old experiment logging directory 201 | if overwrite_experiment_dir and not reload: 202 | if os.path.exists(self.log_save_fname): 203 | Console().log("Be careful - you are overwriting an existing log.") 204 | os.remove(self.log_save_fname) 205 | if os.path.exists(aggregated_log_save_fname): 206 | Console().log( 207 | "Be careful - you are overwriting an existing aggregated" " log." 
208 | ) 209 | os.remove(aggregated_log_save_fname) 210 | if self.use_tboard: 211 | Console().log("Be careful - you are overwriting existing tboards.") 212 | if os.path.exists(os.path.join(self.experiment_dir, "tboards/")): 213 | shutil.rmtree(os.path.join(self.experiment_dir, "tboards/")) 214 | 215 | def get_configs_ready( 216 | self, config_fname: Union[str, None], config_dict: Union[dict, None] 217 | ): 218 | """Load configuration if provided and set config_dict.""" 219 | if config_fname is not None: 220 | self.config_dict = load_config(config_fname) 221 | elif config_dict is not None: 222 | self.config_dict = config_dict 223 | else: 224 | self.config_dict = {} 225 | 226 | def create_logging_dir( 227 | self, 228 | config_fname: Union[str, None], 229 | config_dict: Union[dict, None], 230 | ): 231 | """Create new empty dir for experiment (if not existing).""" 232 | os.makedirs(self.experiment_dir, exist_ok=True) 233 | 234 | # Copy over json configuration file if it exists 235 | if config_fname is not None: 236 | fname, fext = os.path.splitext(config_fname) 237 | else: 238 | fname, fext = "pholder", ".yaml" 239 | 240 | if config_fname is not None: 241 | config_copy = os.path.join(self.experiment_dir, self.base_str + fext) 242 | shutil.copy(config_fname, config_copy) 243 | self.config_copy = config_copy 244 | elif config_dict is not None: 245 | config_copy = os.path.join(self.experiment_dir, "config_dict" + fext) 246 | with open(config_copy, "w") as outfile: 247 | yaml.dump(config_dict, outfile, default_flow_style=False) 248 | self.config_copy = config_copy 249 | else: 250 | self.config_copy = "config-not-provided" 251 | 252 | # Create .hdf5 logging sub-directory 253 | os.makedirs(os.path.join(self.experiment_dir, "logs/"), exist_ok=True) 254 | 255 | def update( 256 | self, 257 | clock_tick: Dict[str, int], 258 | stats_tick: Dict[str, float], 259 | model=None, 260 | plot_fig=None, 261 | extra_obj=None, 262 | grads=None, 263 | save=False, 264 | ) -> None: 265 | """Update with the newest tick of performance stats, net weights""" 266 | # Make sure that timeseries data consists of floats 267 | stats_tick = { 268 | key: float(value) if type(value) != np.ndarray else value 269 | for (key, value) in stats_tick.items() 270 | } 271 | 272 | # Update the stats log with newest timeseries data 273 | c_tick, s_tick = self.stats_log.update(clock_tick, stats_tick) 274 | # Update the tensorboard log with the newest event 275 | if self.use_tboard: 276 | self.tboard_log.update( 277 | self.stats_log.time_to_track, 278 | clock_tick, 279 | stats_tick, 280 | self.model_log.model_type, 281 | model, 282 | grads, 283 | plot_fig, 284 | ) 285 | if self.use_wandb: 286 | self.wandb_log.update( 287 | clock_tick, 288 | stats_tick, 289 | self.model_log.model_type, 290 | model, 291 | grads, 292 | plot_fig, 293 | ) 294 | # Save the most recent model checkpoint 295 | if model is not None: 296 | self.save_model(model) 297 | # Save fig from matplotlib 298 | if plot_fig is not None: 299 | self.save_plot(plot_fig) 300 | # Save .pkl object 301 | if extra_obj is not None: 302 | self.save_extra(extra_obj) 303 | # Save the .hdf5 log if boolean says so 304 | if save: 305 | self.save() 306 | 307 | # Print the most current results 308 | if self.verbose and self.print_every_k_updates is not None: 309 | if ( 310 | self.stats_log.stats_update_counter % self.print_every_k_updates == 0 311 | ) or self.stats_log.stats_update_counter == 1: 312 | # Print storage paths generated/updated 313 | print_storage( 314 | fig_path=( 315 | 
self.figure_log.fig_storage_paths[-1] 316 | if plot_fig is not None 317 | else None 318 | ), 319 | extra_path=( 320 | self.extra_log.extra_storage_paths[-1] 321 | if extra_obj is not None 322 | else None 323 | ), 324 | init_model_path=( 325 | self.model_log.init_model_save_fname 326 | if model is not None and self.model_log.init_model_saved 327 | else None 328 | ), 329 | final_model_path=( 330 | self.model_log.final_model_save_fname 331 | if model is not None 332 | else None 333 | ), 334 | every_k_model_path=( 335 | self.model_log.every_k_ckpt_list[-1] 336 | if model is not None and self.model_log.stored_every_k 337 | else None 338 | ), 339 | top_k_model_path=( 340 | self.model_log.top_k_ckpt_list[-1] 341 | if model is not None and self.model_log.stored_top_k 342 | else None 343 | ), 344 | print_first=self.print_counter == 0, 345 | ) 346 | # Only print column name header at 1st print! 347 | if self.time_to_print is None: 348 | time_to_p = self.stats_log.time_to_track 349 | else: 350 | time_to_p = ["time", "time_elapsed", "num_updates"] 351 | if self.what_to_print is None: 352 | what_to_p = self.stats_log.what_to_track 353 | else: 354 | what_to_p = self.what_to_print 355 | print_update( 356 | time_to_p, 357 | what_to_p, 358 | c_tick, 359 | s_tick, 360 | self.print_counter == 0, 361 | ) 362 | self.print_counter += 1 363 | 364 | def save_init_model(self, model): 365 | """Save initial model checkpoint.""" 366 | self.model_log.save_init_model(model) 367 | 368 | def save_model(self, model): 369 | """Save a model checkpoint.""" 370 | self.model_log.save( 371 | model, self.stats_log.clock_tracked, self.stats_log.stats_tracked 372 | ) 373 | 374 | def save_plot(self, fig, fig_fname: Union[str, None] = None): 375 | """Store a figure in a experiment_id/figures directory.""" 376 | # Create main logging dir and .hdf5 sub-directory 377 | if not self.log_setup_counter: 378 | self.create_logging_dir(self.config_fname, self.config_dict) 379 | self.log_setup_counter += 1 380 | self.figure_log.save(fig, fig_fname) 381 | write_to_hdf5( 382 | self.log_save_fname, 383 | self.seed_id + "/meta/fig_storage_paths", 384 | self.figure_log.fig_storage_paths, 385 | ) 386 | 387 | def save_extra(self, obj, obj_fname: Union[str, None] = None): 388 | """Helper fct. to save object (dict/etc.) as .pkl in exp. 
subdir.""" 389 | # Create main logging dir and .hdf5 sub-directory 390 | if not self.log_setup_counter: 391 | self.create_logging_dir(self.config_fname, self.config_dict) 392 | self.log_setup_counter += 1 393 | self.extra_log.save(obj, obj_fname) 394 | write_to_hdf5( 395 | self.log_save_fname, 396 | self.seed_id + "/meta/extra_storage_paths", 397 | self.extra_log.extra_storage_paths, 398 | ) 399 | 400 | def save(self): 401 | """Create compressed .hdf5 file containing group """ 402 | # Create main logging dir and .hdf5 sub-directory 403 | if not self.log_setup_counter: 404 | self.create_logging_dir(self.config_fname, self.config_dict) 405 | self.log_setup_counter += 1 406 | 407 | # Create "datasets" to store in the hdf5 file [time, stats] 408 | # Store all relevant meta data (log filename, checkpoint filename) 409 | if self.log_save_counter == 0: 410 | data_paths = [ 411 | self.seed_id + "/meta/log_paths", 412 | self.seed_id + "/meta/experiment_dir", 413 | self.seed_id + "/meta/config_fname", 414 | self.seed_id + "/meta/eval_id", 415 | self.seed_id + "/meta/model_type", 416 | self.seed_id + "/meta/config_dict", 417 | ] 418 | 419 | data_to_log = [ 420 | [self.log_save_fname], 421 | [self.experiment_dir], 422 | [self.config_copy], 423 | [self.base_str], 424 | [self.model_log.model_type], 425 | [str(self.config_dict)], 426 | ] 427 | 428 | for i in range(len(data_paths)): 429 | write_to_hdf5(self.log_save_fname, data_paths[i], data_to_log[i]) 430 | 431 | if self.model_log.save_top_k_ckpt or self.model_log.save_every_k_ckpt: 432 | write_to_hdf5( 433 | self.log_save_fname, 434 | self.seed_id + "/meta/ckpt_time_to_track", 435 | [self.model_log.ckpt_time_to_track], 436 | ) 437 | 438 | if self.model_log.save_top_k_ckpt: 439 | write_to_hdf5( 440 | self.log_save_fname, 441 | self.seed_id + "/meta/top_k_metric_name", 442 | [self.model_log.top_k_metric_name], 443 | ) 444 | 445 | # Store final and initial checkpoint if provided 446 | if self.model_log.model_save_counter > 0: 447 | write_to_hdf5( 448 | self.log_save_fname, 449 | self.seed_id + "/meta/model_ckpt", 450 | [self.model_log.final_model_save_fname], 451 | ) 452 | 453 | if self.model_log.init_model_saved: 454 | write_to_hdf5( 455 | self.log_save_fname, 456 | self.seed_id + "/meta/init_ckpt", 457 | [self.model_log.init_model_save_fname], 458 | ) 459 | 460 | # Store all time_to_track variables 461 | for o_name in self.stats_log.time_to_track: 462 | if o_name != "time": 463 | write_to_hdf5( 464 | self.log_save_fname, 465 | self.seed_id + "/time/" + o_name, 466 | self.stats_log.clock_tracked[o_name], 467 | dtype="float32", 468 | ) 469 | else: 470 | write_to_hdf5( 471 | self.log_save_fname, 472 | self.seed_id + "/time/" + o_name, 473 | self.stats_log.clock_tracked[o_name], 474 | ) 475 | 476 | # Store all what_to_track variables 477 | for o_name in self.stats_log.what_to_track: 478 | data_to_store = self.stats_log.stats_tracked[o_name] 479 | data_to_store = np.array(data_to_store) 480 | if len(data_to_store) > 0: 481 | if type(data_to_store[0]) == np.ndarray: 482 | data_to_store = np.stack(data_to_store) 483 | dtype = np.dtype("float32") 484 | if type(data_to_store[0]) in [np.str_, str]: 485 | dtype = "S5000" 486 | if type(data_to_store[0]) in [bytes, np.str_]: 487 | dtype = np.dtype("S5000") 488 | elif type(data_to_store[0]) == int: 489 | dtype = np.dtype("int32") 490 | else: 491 | dtype = np.dtype("float32") 492 | write_to_hdf5( 493 | self.log_save_fname, 494 | self.seed_id + "/stats/" + o_name, 495 | data_to_store, 496 | dtype, 497 | ) 498 | 499 | # 
Store data on stored checkpoints - stored every k updates 500 | if self.model_log.save_every_k_ckpt is not None: 501 | data_paths = [ 502 | self.seed_id + "/meta/" + "every_k_storage_time", 503 | self.seed_id + "/meta/" + "every_k_ckpt_list", 504 | ] 505 | data_to_log = [ 506 | self.model_log.every_k_storage_time, 507 | self.model_log.every_k_ckpt_list, 508 | ] 509 | data_types = ["int32", "S5000"] 510 | for i in range(len(data_paths)): 511 | write_to_hdf5( 512 | self.log_save_fname, 513 | data_paths[i], 514 | data_to_log[i], 515 | data_types[i], 516 | ) 517 | 518 | # Store data on stored checkpoints - stored top k ckpt 519 | if self.model_log.save_top_k_ckpt is not None: 520 | data_paths = [ 521 | self.seed_id + "/meta/" + "top_k_storage_time", 522 | self.seed_id + "/meta/" + "top_k_ckpt_list", 523 | self.seed_id + "/meta/" + "top_k_performance", 524 | ] 525 | data_to_log = [ 526 | self.model_log.top_k_storage_time, 527 | self.model_log.top_k_ckpt_list, 528 | self.model_log.top_k_performance, 529 | ] 530 | data_types = ["int32", "S5000", "float32"] 531 | for i in range(len(data_paths)): 532 | write_to_hdf5( 533 | self.log_save_fname, 534 | data_paths[i], 535 | data_to_log[i], 536 | data_types[i], 537 | ) 538 | 539 | # Tick the log save counter 540 | self.log_save_counter += 1 541 | 542 | def extend_tracking(self, add_track_vars: List[str]) -> None: 543 | """Add string names of variables to track.""" 544 | self.stats_log.extend_tracking(add_track_vars) 545 | 546 | def ready_to_log(self, update_counter: int) -> bool: 547 | """Check whether update_counter is modulo of log_every_k_steps.""" 548 | assert ( 549 | self.log_every_j_steps is not None 550 | ), "Provide `log_every_j_steps` in your `log_config`" 551 | return (update_counter + 1) % self.log_every_j_steps == 0 or update_counter == 0 552 | --------------------------------------------------------------------------------
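
A minimal usage sketch of `MLELogger`, put together from the constructor and `update()` signatures in `mle_logging/mle_logger.py` above. The experiment directory, the tracked variable names, and the toy training loop are illustrative assumptions for this sketch, not code taken from the repository.

from mle_logging import MLELogger

# Instantiate the logger (keyword arguments as in MLELogger.__init__ above).
# Directory and variable names below are made up for illustration.
log = MLELogger(
    experiment_dir="experiments/toy_run/",       # hypothetical path
    time_to_track=["num_updates", "num_epochs"],  # time columns to store
    what_to_track=["train_loss"],                 # stats columns to store
    seed_id=1,                                    # stored internally as "seed_1"
    print_every_k_updates=1,
    verbose=True,                                 # prints welcome/startup panels
)

# Toy training loop: each update() call logs one tick of time/stats data.
# With save=True the compressed log is written to
# experiments/toy_run/logs/log_seed_1.hdf5 (see setup_experiment above).
for step in range(1, 11):
    log.update(
        clock_tick={"num_updates": step, "num_epochs": 1},
        stats_tick={"train_loss": 1.0 / step},
        save=True,
    )

The stored .hdf5 log can afterwards be read back with the exported `load_log` helper (see `mle_logging/load/__init__.py`), and per-seed or per-config logs can be combined via the `merge_seed_logs` / `merge_config_logs` utilities exported in `mle_logging/__init__.py`.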