├── assets ├── .gitignore ├── favicon.ico ├── logo-192x192.png ├── logo-512x512.png └── logo-no-background.png ├── mandala ├── __init__.py ├── deps │ ├── tracers │ │ ├── __init__.py │ │ ├── tracer_base.py │ │ └── sys_impl.py │ ├── crawler.py │ ├── viz.py │ ├── deep_versions.py │ ├── utils.py │ └── model.py ├── imports.py ├── config.py ├── common_imports.py ├── tests │ ├── test_cfs.py │ ├── test_configs.py │ ├── test_memoization.py │ └── test_versioning.py └── tps.py ├── runtime.txt ├── requirements.txt ├── .github ├── FUNDING.yml └── workflows │ └── deploy-docs.yml ├── docs_source ├── readme.md ├── make_docs.py └── topics │ └── 01_storage_and_ops.ipynb ├── c.py ├── docs └── docs │ ├── stylesheets │ └── extra.css │ ├── index.md │ ├── topics │ ├── 03_cf_files │ │ ├── 03_cf_6_1.svg │ │ ├── 03_cf_20_1.svg │ │ ├── 03_cf_22_0.svg │ │ ├── 03_cf_28_0.svg │ │ ├── 03_cf_14_1.svg │ │ └── 03_cf_29_0.svg │ ├── 06_advanced_cf.md │ ├── 05_collections.md │ ├── 01_storage_and_ops.md │ ├── 05_collections_files │ │ └── 05_collections_5_0.svg │ └── 02_retracing.md │ ├── tutorials │ ├── 01_hello_files │ │ ├── 01_hello_5_1.svg │ │ └── 01_hello_5_3.svg │ └── 02_ml_files │ │ ├── 02_ml_22_3.svg │ │ └── 02_ml_22_1.svg │ └── blog │ └── 01_cf_files │ ├── 01_cf_2_0.svg │ └── 01_cf_21_0.svg ├── console.py ├── .coveragerc ├── mkdocs.yml ├── .gitignore ├── setup.py └── LICENSE /assets/.gitignore: -------------------------------------------------------------------------------- 1 | !*.png -------------------------------------------------------------------------------- /mandala/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /runtime.txt: -------------------------------------------------------------------------------- 1 | python-3.10 -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | # for binder 2 | numpy >= 1.18 3 | pandas >= 1.0 4 | joblib >= 1.0 -------------------------------------------------------------------------------- /assets/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amakelov/mandala/HEAD/assets/favicon.ico -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [amakelov] 4 | -------------------------------------------------------------------------------- /assets/logo-192x192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amakelov/mandala/HEAD/assets/logo-192x192.png -------------------------------------------------------------------------------- /assets/logo-512x512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amakelov/mandala/HEAD/assets/logo-512x512.png -------------------------------------------------------------------------------- /assets/logo-no-background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amakelov/mandala/HEAD/assets/logo-no-background.png -------------------------------------------------------------------------------- /mandala/deps/tracers/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracer_base import TracerABC 2 | from .dec_impl import DecTracer 3 | from .sys_impl import SysTracer 4 | -------------------------------------------------------------------------------- /docs_source/readme.md: -------------------------------------------------------------------------------- 1 | To 
convert these notebooks to markdown docs, run them in the `jupyter notebook` 2 | server, save them, and run 3 | ```bash 4 | python make_docs.py 5 | ``` -------------------------------------------------------------------------------- /c.py: -------------------------------------------------------------------------------- 1 | import subprocess, argparse 2 | 3 | 4 | def get_parser() -> argparse.ArgumentParser: 5 | parser = argparse.ArgumentParser() 6 | return parser 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = get_parser() 11 | args = parser.parse_args() 12 | cmd = ["ipython", "-i", "console.py", "--"] 13 | subprocess.call(cmd) 14 | -------------------------------------------------------------------------------- /mandala/imports.py: -------------------------------------------------------------------------------- 1 | from .storage import Storage, noop 2 | from .model import op, Ignore, NewArgDefault, wrap_atom, ValuePointer 3 | from .tps import MList, MDict 4 | from .deps.tracers.dec_impl import track 5 | 6 | from .common_imports import sess 7 | 8 | 9 | def pprint_dict(d) -> str: 10 | return '\n'.join([f" {k}: {v}" for k, v in d.items()]) -------------------------------------------------------------------------------- /.github/workflows/deploy-docs.yml: -------------------------------------------------------------------------------- 1 | name: Deploy MkDocs 2 | on: 3 | push: 4 | branches: 5 | - master 6 | jobs: 7 | deploy: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | - uses: actions/setup-python@v2 12 | with: 13 | python-version: 3.x 14 | - run: pip install mkdocs-material 15 | - run: mkdocs gh-deploy --force -------------------------------------------------------------------------------- /docs/docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | :root { 2 | --md-primary-fg-color: #073642; 3 | --md-accent-fg-color: #268bd2; 4 | 5 | --md-default-bg-color: #fdf6e3; 6 | 
--md-default-fg-color: #657b83; 7 | --md-default-fg-color--light: #93a1a1; 8 | --md-default-fg-color--lighter: #eee8d5; 9 | 10 | --md-code-bg-color: #eee8d5; 11 | --md-code-fg-color: #657b83; 12 | } 13 | 14 | /* 15 | a { 16 | text-decoration: underline; 17 | } 18 | */ -------------------------------------------------------------------------------- /console.py: -------------------------------------------------------------------------------- 1 | # useful builtins 2 | import sys 3 | import argparse 4 | 5 | 6 | def get_parser() -> argparse.ArgumentParser: 7 | """ 8 | Copy of `get_parser` in c.py 9 | """ 10 | parser = argparse.ArgumentParser() 11 | return parser 12 | 13 | 14 | parser = get_parser() 15 | args = parser.parse_args() 16 | 17 | if __name__ == "__main__": 18 | from mandala.all import * 19 | from mandala.tests.test_stateful_slow import * 20 | 21 | # setup_logging(level='info') 22 | -------------------------------------------------------------------------------- /docs/docs/index.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | This is the documentation for [mandala](https://github.com/amakelov/mandala), a 3 | simple & elegant experiment tracking framework for Python. 4 | 5 | Most methods in `mandala` are provided by the `Storage` and `ComputationFrame` 6 | classes. In general, you'll probably find yourself only interacting with 5-10 7 | methods on a regular basis, and their docstrings provide detailed explanations. 8 | 9 | To complement this, this documentation contains a few short walkthroughs 10 | illustrating the use of these methods. 
-------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [run] 3 | branch = True 4 | omit = 5 | mandala_lite/tests/*.py 6 | mandala_lite/demos/* 7 | 8 | [report] 9 | # Regexes for lines to exclude from consideration 10 | exclude_lines = 11 | # Have to re-enable the standard pragma 12 | pragma: no cover 13 | 14 | # Don't complain if tests don't hit defensive assertion code: 15 | raise AssertionError 16 | raise NotImplementedError 17 | # Don't complain if you don't run into internal errors 18 | raise InternalError 19 | 20 | # Don't complain about abstract methods, they aren't run: 21 | @(abc\.)?abstractmethod 22 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Mandala Documentation 2 | 3 | plugins: 4 | - search 5 | 6 | theme: 7 | name: material 8 | font: 9 | text: Roboto 10 | code: Roboto Mono 11 | features: 12 | - content.code.copy # Adds a copy button to code blocks 13 | palette: 14 | - media: "(prefers-color-scheme: light)" 15 | scheme: default 16 | primary: indigo 17 | accent: indigo 18 | toggle: 19 | icon: material/brightness-7 20 | name: Switch to dark mode 21 | - media: "(prefers-color-scheme: dark)" 22 | scheme: slate 23 | primary: indigo 24 | accent: indigo 25 | toggle: 26 | icon: material/brightness-4 27 | name: Switch to light mode 28 | 29 | markdown_extensions: 30 | - pymdownx.highlight: 31 | anchor_linenums: true 32 | line_spans: __span 33 | pygments_lang_class: true 34 | - pymdownx.inlinehilite 35 | - pymdownx.snippets 36 | - pymdownx.superfences 37 | 38 | docs_dir: 'docs/docs' 39 | 40 | extra_css: 41 | - stylesheets/extra.css 42 | -------------------------------------------------------------------------------- 
/docs/docs/topics/03_cf_files/03_cf_6_1.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | v 15 | 16 | v 17 | 1 values (1 sources/1 sinks) 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # python 2 | *.pyc 3 | 4 | # logs 5 | *.log 6 | 7 | # pictures and other visualization-related 8 | *.jpg 9 | *.JPG 10 | *.png 11 | *.tif 12 | *.gif 13 | # *.svg 14 | *.dot 15 | *.gv 16 | # *.mp4 17 | 18 | # docs 19 | *.pdf 20 | *.djvu 21 | *.ps 22 | 23 | # vim 24 | *.swp 25 | *.swo 26 | *.swn 27 | 28 | # archives 29 | *.7z 30 | *.dmg 31 | *.gz 32 | *.iso 33 | *.jar 34 | *.rar 35 | *.tar 36 | *.zip 37 | 38 | # latex 39 | *.aux 40 | *.fdb_latexmk 41 | *.fls 42 | *.log 43 | *.out 44 | 45 | # web 46 | *.html 47 | # *.css 48 | 49 | # python objects 50 | *.pkl 51 | *.joblib 52 | 53 | ################################################################################ 54 | ### path-based 55 | ################################################################################ 56 | # eggs 57 | mandala.egg-info/* 58 | pymandala.egg-info/* 59 | 60 | scratchpad/ 61 | 62 | # visualizations 63 | # Ignore files cached by Hypothesis 64 | **/.hypothesis/* 65 | # Ignore files cached by vscode 66 | **/.vscode/* 67 | # ignore files cached by pytest 68 | **/.pytest_cache/* 69 | # ignore files cached by jupyter 70 | **/.ipynb_checkpoints 71 | # ignore coverage things 72 | .coverage 73 | **/htmlcov/* 74 | # ignore persistent storages 75 | temp_dbs/* 76 | mandala/tests/output/* 77 | *.db 78 | *.parquet 79 | 80 | # ignore docs build 81 | site/ 82 | 83 | # ignore PyPi build 84 | dist/ 85 | 86 | # Ignore previous version 87 | mandala/_prev/ -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | 4 | install_requires = [ 5 | "numpy >= 1.18", 6 | "pandas >= 1.0", 7 | "joblib >= 1.0", 8 | ] 9 | 10 | extras_require = { 11 | "base": ["prettytable", "graphviz"], 12 | "ui": [ 13 | "rich", 14 | ], 15 | "test": [ 16 | "pytest >= 6.0.0", 17 | "hypothesis >= 6.0.0", 18 | "ipython", 19 | ], 20 | "demos": [ 21 | "scikit-learn", 22 | ], 23 | } 24 | 25 | 26 | extras_require["complete"] = sorted({v for req in extras_require.values() for v in req}) 27 | 28 | packages = [ 29 | "mandala", 30 | "mandala.deps", 31 | "mandala.deps.tracers", 32 | "mandala.tests", 33 | ] 34 | 35 | setup( 36 | name="pymandala", 37 | version="v0.2.0-alpha", 38 | author="Aleksandar (Alex) Makelov", 39 | author_email="aleksandar.makelov@gmail.com", 40 | description="A powerful and easy to use experiment tracking framework", 41 | long_description=open("README.md").read(), 42 | long_description_content_type="text/markdown", 43 | url="https://github.com/amakelov/mandala", 44 | license="Apache 2.0", 45 | keywords="computational-experiments data-management machine-learning data-science", 46 | classifiers=[ 47 | "Development Status :: 3 - Alpha", 48 | "Intended Audience :: Science/Research", 49 | "Operating System :: OS Independent", 50 | "License :: OSI Approved :: Apache Software License", 51 | "Programming Language :: Python :: 3", 52 | "Programming Language :: Python :: 3.10", 53 | "Programming Language :: Python :: 3.9", 54 | "Programming Language :: Python :: 3.8", 55 | ], 56 | packages=packages, 57 | python_requires=">=3.8", 58 | install_requires=install_requires, 59 | extras_require=extras_require, 60 | ) 61 | -------------------------------------------------------------------------------- /mandala/config.py: -------------------------------------------------------------------------------- 1 | from .common_imports import * 2 | 3 | def get_mandala_path() -> Path: 4 | 
import mandala 5 | 6 | return Path(os.path.dirname(mandala.__file__)) 7 | 8 | class Config: 9 | func_interface_cls_name = "Op" 10 | mandala_path = get_mandala_path() 11 | module_name = "mandala" 12 | tests_module_name = "mandala.tests" 13 | 14 | try: 15 | import PIL 16 | 17 | has_pil = True 18 | except ImportError: 19 | has_pil = False 20 | 21 | try: 22 | import torch 23 | 24 | has_torch = True 25 | except ImportError: 26 | has_torch = False 27 | 28 | try: 29 | import rich 30 | 31 | has_rich = True 32 | except ImportError: 33 | has_rich = False 34 | 35 | try: 36 | import prettytable 37 | 38 | has_prettytable = True 39 | except ImportError: 40 | has_prettytable = False 41 | 42 | 43 | if Config.has_torch: 44 | import torch 45 | 46 | def tensor_to_numpy(obj: Union[torch.Tensor, dict, list, tuple, Any]) -> Any: 47 | """ 48 | Recursively convert PyTorch tensors in a data structure to numpy arrays. 49 | 50 | Parameters 51 | ---------- 52 | obj : any 53 | The input data structure. 54 | 55 | Returns 56 | ------- 57 | any 58 | The data structure with tensors converted to numpy arrays. 
59 | """ 60 | if isinstance(obj, torch.Tensor): 61 | return obj.detach().cpu().numpy() 62 | elif isinstance(obj, dict): 63 | return {k: tensor_to_numpy(v) for k, v in obj.items()} 64 | elif isinstance(obj, list): 65 | return [tensor_to_numpy(v) for v in obj] 66 | elif isinstance(obj, tuple): 67 | return tuple(tensor_to_numpy(v) for v in obj) 68 | else: 69 | return obj 70 | -------------------------------------------------------------------------------- /docs/docs/topics/03_cf_files/03_cf_20_1.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | test_acc 15 | 16 | test_acc 17 | 2 values (2 sources/2 sinks) 18 | 19 | 20 | 21 | model 22 | 23 | model 24 | 4 values (4 sources/4 sinks) 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /mandala/common_imports.py: -------------------------------------------------------------------------------- 1 | import time 2 | import traceback 3 | import random 4 | import logging 5 | import itertools 6 | import copy 7 | import hashlib 8 | import io 9 | import os 10 | import shutil 11 | import sys 12 | import joblib 13 | import inspect 14 | import binascii 15 | import asyncio 16 | import ast 17 | import types 18 | import tempfile 19 | from collections import defaultdict, OrderedDict 20 | from typing import ( 21 | Any, 22 | Dict, 23 | List, 24 | Callable, 25 | Tuple, 26 | Iterable, 27 | Optional, 28 | Set, 29 | Union, 30 | TypeVar, 31 | Literal, 32 | ) 33 | from pathlib import Path 34 | 35 | import pandas as pd 36 | import pyarrow as pa 37 | import numpy as np 38 | 39 | try: 40 | import rich 41 | 42 | has_rich = True 43 | except ImportError: 44 | has_rich = False 45 | 46 | if has_rich: 47 | from rich.logging import RichHandler 48 | 49 | logger = logging.getLogger("mandala") 50 | logging_handler = RichHandler(enable_link_path=False) 51 | FORMAT = "%(message)s" 52 | logging.basicConfig( 53 | level="INFO", 
format=FORMAT, datefmt="[%X]", handlers=[logging_handler] 54 | ) 55 | else: 56 | logger = logging.getLogger("mandala") 57 | # logger.addHandler(logging.StreamHandler()) 58 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s" 59 | logging.basicConfig(format=FORMAT) 60 | logger.setLevel(logging.INFO) 61 | 62 | class Session: 63 | # for debugging 64 | 65 | def __init__(self): 66 | self.items = [] 67 | self._scope = None 68 | 69 | def d(self): 70 | scope = inspect.currentframe().f_back.f_locals 71 | self._scope = scope 72 | 73 | def dump(self): 74 | # put the scope into the current locals 75 | assert self._scope is not None 76 | scope = inspect.currentframe().f_back.f_locals 77 | print(f"Dumping {self._scope.keys()} into local scope") 78 | scope.update(self._scope) 79 | 80 | sess = Session() -------------------------------------------------------------------------------- /mandala/tests/test_cfs.py: -------------------------------------------------------------------------------- 1 | from mandala.imports import * 2 | 3 | 4 | def test_single_func(): 5 | storage = Storage() 6 | 7 | @op 8 | def inc(x: int) -> int: 9 | return x + 1 10 | 11 | with storage: 12 | for i in range(10): 13 | inc(i) 14 | 15 | cf = storage.cf(inc) 16 | df = cf.df() 17 | assert df.shape == (10, 3) 18 | assert (df['var_0'] == df['x'] + 1).all() 19 | 20 | 21 | def test_composition(): 22 | storage = Storage() 23 | 24 | @op(output_names=['y']) 25 | def inc(x): 26 | return x + 1 27 | 28 | @op(output_names=['z']) 29 | def add(x, y): 30 | return x + y 31 | 32 | with storage: 33 | for x in range(5): 34 | y = inc(x) 35 | if x % 2 == 0: 36 | z = add(x, y) 37 | 38 | cf = storage.cf(add).expand_all() 39 | df = cf.df() 40 | assert df.shape[0] == 3 41 | assert (df['z'] == df['x'] + df['y']).all() 42 | 43 | cf = storage.cf(inc).expand_all() 44 | df = cf.df() 45 | assert df.shape[0] == 5 46 | assert (df['y'] == df['x'] + 1).all() 47 | assert (df['z'] == df['x'] + df['y'])[df['z'].notnull()].all() 48 
| 49 | def test_merge(): 50 | storage = Storage() 51 | 52 | @op(output_names=['y']) 53 | def inc(x): 54 | return x + 1 55 | 56 | @op(output_names=['z']) 57 | def add(x, y): 58 | return x + y 59 | 60 | @op(output_names=['w']) 61 | def mul(x, y): 62 | return x * y 63 | 64 | @op(output_names=['v']) 65 | def final(t): 66 | return t**2 67 | 68 | with storage: 69 | for x in range(10): 70 | y = inc(x) 71 | if x < 5: 72 | z = add(x, y) 73 | v = final(z) 74 | else: 75 | w = mul(x, y) 76 | v = final(w) 77 | 78 | 79 | cf = storage.cf(final).expand_all().merge_vars() 80 | df = cf.df() 81 | assert df.shape[0] == 10 -------------------------------------------------------------------------------- /docs_source/make_docs.py: -------------------------------------------------------------------------------- 1 | """ 2 | convert .ipynb files in this directory to .md files and move them to the docs 3 | directory for mkdocs to use 4 | """ 5 | 6 | import os 7 | import argparse 8 | 9 | # parse the command line arguments 10 | parser = argparse.ArgumentParser(description='Convert .ipynb files to .md files') 11 | parser.add_argument('--filenames', type=str, nargs='+', help='list of filenames to convert') 12 | args = parser.parse_args() 13 | 14 | 15 | if __name__ == '__main__': 16 | if args.filenames: 17 | ipynb_files = args.filenames 18 | # prepend "./" to the filenames 19 | ipynb_files = ['./' + f for f in ipynb_files if not f.startswith('./')] 20 | else: 21 | # find all .ipynb files recursively in the current directory and its subdirectories 22 | ipynb_files = [] 23 | for root, dirs, files in os.walk('.'): 24 | for f in files: 25 | if f.endswith('.ipynb'): 26 | ipynb_files.append(os.path.join(root, f)) 27 | 28 | for f in ipynb_files: 29 | os.system('jupyter nbconvert --to notebook --execute --inplace ' + f) 30 | os.system(f"jupyter nbconvert --to markdown {f}") 31 | 32 | DOCS_REL_PATH = '../docs/docs/' 33 | # now, move the .md files to the docs directory 34 | for f in ipynb_files: 35 | # 
find the relative directory and the filename 36 | relative_dir = os.path.dirname(f) 37 | fname = os.path.basename(f) 38 | # if the target dir doesn't exist, create it 39 | if not os.path.isdir(DOCS_REL_PATH + relative_dir): 40 | os.system("mkdir -p " + DOCS_REL_PATH + relative_dir) 41 | # move to the DOCS_REL_PATH, under the same directory structure 42 | mv_cmd = "mv " + f.replace('.ipynb', '.md') + " " + DOCS_REL_PATH + relative_dir + '/' + fname.replace('.ipynb', '.md') 43 | print(mv_cmd) 44 | os.system(mv_cmd) 45 | 46 | # also, move any directories named "{fname}_files" to the docs directory 47 | for f in ipynb_files: 48 | files_folder = f.replace('.ipynb', '_files') 49 | if os.path.isdir(files_folder): 50 | # first, remove the directory if it already exists 51 | target_files_path = DOCS_REL_PATH + files_folder 52 | if os.path.isdir(target_files_path): 53 | os.system(f"rm -r {DOCS_REL_PATH}" + files_folder) 54 | # then, move the directory 55 | os.system("mv " + f.replace('.ipynb', '_files') + " " + DOCS_REL_PATH + f.replace('.ipynb', '_files')) -------------------------------------------------------------------------------- /docs/docs/topics/06_advanced_cf.md: -------------------------------------------------------------------------------- 1 | # Advanced `ComputationFrame` tools 2 | 3 | Open In Colab 4 | 5 | This section of the documentation contains some more advanced `ComputationFrame` 6 | topics. 7 | 8 | ## Set-like and graph-like operations on `ComputationFrame`s 9 | A CF is like a "graph of sets", where the elements of each set are either `Ref`s 10 | (for variables) or `Call`s (for functions). 
As such, it supports both natural 11 | set-like operations applied node/edge-wise, and natural operations using the 12 | graph's connectivity: 13 | 14 | - **union**: given by `cf_1 | cf_2`, this takes the union of the two computation 15 | graphs (merging nodes/edges with the same names), and in case of a merge, the 16 | resulting set at the node is the union of the two sets of `Ref`s or `Call`s. 17 | - **intersection**: given by `cf_1 & cf_2`, this takes the intersection of the 18 | two computation graphs (leaving only nodes/edges with the same name in both), 19 | and the set at each node is the intersection of the two corresponding sets. 20 | - **`.downstream(varnames)`**: restrict the CF to computations that are 21 | downstream of the `Ref`s in chosen variables 22 | - **`.upstream(varnames)`**: dual to `downstream` 23 | 24 | Consider the following example: 25 | 26 | 27 | ```python 28 | # for Google Colab 29 | try: 30 | import google.colab 31 | !pip install git+https://github.com/amakelov/mandala 32 | except: 33 | pass 34 | ``` 35 | 36 | 37 | ```python 38 | from mandala.imports import * 39 | storage = Storage() 40 | 41 | @op 42 | def inc(x): return x + 1 43 | 44 | @op 45 | def add(y, z): return y + z 46 | 47 | @op 48 | def square(w): return w ** 2 49 | 50 | @op 51 | def divmod_(u, v): return divmod(u, v) 52 | 53 | with storage: 54 | xs = [inc(i) for i in range(5)] 55 | ys = [add(x, z=42) for x in xs] + [square(x) for x in range(5, 10)] 56 | zs = [divmod_(x, y) for x, y in zip(xs, ys[3:8])] 57 | ``` 58 | 59 | We have a "middle layer" in the computation that uses both `add` and `square`. 
60 | We can get a shared view of the entire computation by taking the union of the 61 | expanded CFs for these two ops: 62 | 63 | 64 | ```python 65 | cf = (storage.cf(add) | storage.cf(square)).expand_all() 66 | cf.draw(verbose=True) 67 | ``` 68 | 69 | 70 | 71 | ![svg](06_advanced_cf_files/06_advanced_cf_5_0.svg) 72 | 73 | 74 | 75 | ## Selection 76 | TODO 77 | -------------------------------------------------------------------------------- /mandala/tests/test_configs.py: -------------------------------------------------------------------------------- 1 | from mandala.imports import * 2 | 3 | 4 | def test_nesting(): 5 | storage = Storage() 6 | with storage(mode='noop'): 7 | assert storage.mode == 'noop' 8 | with storage(): 9 | assert storage.mode == 'run' 10 | assert storage.mode == 'noop' 11 | with storage(mode='noop'): 12 | assert storage.mode == 'noop' 13 | assert storage.mode == 'noop' 14 | assert storage.mode == 'run' 15 | 16 | 17 | def test_noop_simple(): 18 | @op 19 | def inc(x: int, *args, y: int = NewArgDefault(2), z: int = 23) -> int: 20 | return x + y + z + sum(args) 21 | 22 | storage = Storage() 23 | 24 | with storage(mode='noop'): 25 | # test various wrapped values 26 | res = inc(ValuePointer(id='one', obj=1), 1, 1, z=1) 27 | assert res == 6 28 | 29 | 30 | def test_noop_composition(): 31 | 32 | @op 33 | def inc(x: int) -> int: 34 | return x + 1 35 | 36 | @op 37 | def add(x: int, y: int) -> int: 38 | return x + y 39 | 40 | storage = Storage() 41 | 42 | 43 | # test that wrapped values are unwrapped 44 | with storage: 45 | x = inc(20) 46 | 47 | with storage: 48 | x = inc(20) 49 | with storage(mode='noop'): 50 | z = add(x, 21) 51 | assert z == 42 52 | 53 | 54 | def test_noop_standalone(): 55 | storage = Storage() 56 | 57 | @op 58 | def inc(x: int) -> int: 59 | return x + 1 60 | 61 | with storage: 62 | x = inc(20) 63 | assert storage.mode == 'run' 64 | with noop(): 65 | assert storage.mode == 'noop' 66 | y = inc(x) 67 | z = inc(20) 68 | assert storage.mode 
== 'run' 69 | assert y == 22 70 | assert z == 21 71 | 72 | # now, test it without a context 73 | with noop(): 74 | x = inc(20) 75 | assert x == 21 76 | 77 | 78 | def test_no_new_calls(): 79 | @op 80 | def inc(x: int) -> int: 81 | return x + 1 82 | 83 | storage = Storage() 84 | with storage: 85 | inc(20) 86 | 87 | storage.allow_new_calls(False) 88 | 89 | # memoized calls should still work 90 | with storage: 91 | inc(20) 92 | 93 | try: 94 | with storage: 95 | inc(21) 96 | except RuntimeError as e: 97 | assert str(e) == "Call to inc does not exist and new calls are not allowed." 98 | except Exception as e: 99 | raise e 100 | finally: 101 | storage.allow_new_calls(True) 102 | 103 | -------------------------------------------------------------------------------- /mandala/deps/tracers/tracer_base.py: -------------------------------------------------------------------------------- 1 | from ...common_imports import * 2 | from ...config import Config 3 | import importlib 4 | from ..model import DependencyGraph, CallableNode 5 | from abc import ABC, abstractmethod 6 | 7 | 8 | class TracerABC(ABC): 9 | def __init__( 10 | self, 11 | paths: List[Path], 12 | strict: bool = True, 13 | allow_methods: bool = False, 14 | track_globals: bool = True, 15 | ): 16 | self.call_stack: List[Optional[CallableNode]] = [] 17 | self.graph = DependencyGraph() 18 | self.paths = paths 19 | self.strict = strict 20 | self.allow_methods = allow_methods 21 | self.track_globals = track_globals 22 | 23 | @abstractmethod 24 | def __enter__(self): 25 | raise NotImplementedError 26 | 27 | @abstractmethod 28 | def __exit__(self, exc_type, exc_val, exc_tb): 29 | raise NotImplementedError 30 | 31 | @staticmethod 32 | @abstractmethod 33 | def get_active_trace_obj() -> Optional[Any]: 34 | raise NotImplementedError 35 | 36 | @staticmethod 37 | @abstractmethod 38 | def set_active_trace_obj(trace_obj: Any): 39 | raise NotImplementedError 40 | 41 | @staticmethod 42 | @abstractmethod 43 | def 
register_leaf_event(trace_obj: Any, data: Any): 44 | raise NotImplementedError 45 | 46 | 47 | BREAK = "break" # stop tracing (currently doesn't really work b/c python) 48 | CONTINUE = "continue" # continue tracing, but don't add call to dependencies 49 | KEEP = "keep" # continue tracing and add call to dependencies 50 | MAIN = "__main__" 51 | 52 | 53 | def get_closure_names(code_obj: types.CodeType, func_qualname: str) -> Tuple[str]: 54 | closure_vars = code_obj.co_freevars 55 | if "." in func_qualname and "__class__" in closure_vars: 56 | closure_vars = tuple([var for var in closure_vars if var != "__class__"]) 57 | return closure_vars 58 | 59 | 60 | def get_module_flow(module_name: Optional[str], paths: List[Path]) -> str: 61 | if module_name is None: 62 | return BREAK 63 | if module_name == MAIN: 64 | return KEEP 65 | try: 66 | module = importlib.import_module(module_name) 67 | is_importable = True 68 | except ModuleNotFoundError: 69 | is_importable = False 70 | if not is_importable: 71 | return BREAK 72 | try: 73 | module_path = Path(inspect.getfile(module)) 74 | except TypeError: 75 | # this happens when the module is a built-in module 76 | return BREAK 77 | if ( 78 | not any(root in module_path.parents for root in paths) 79 | and module_path not in paths 80 | ): 81 | # module is not in the paths we're inspecting; stop tracing 82 | logger.debug(f" Module {module_name} not in paths, BREAK") 83 | return BREAK 84 | elif module_name.startswith(Config.module_name) and not module_name.startswith( 85 | Config.tests_module_name 86 | ): 87 | # this function is part of `mandala` functionality. 
Continue tracing 88 | # but don't add it to the dependency state 89 | logger.debug(f" Module {module_name} is mandala, CONTINUE") 90 | return CONTINUE 91 | else: 92 | logger.debug(f" Module {module_name} is not mandala but in paths, KEEP") 93 | return KEEP 94 | -------------------------------------------------------------------------------- /docs/docs/topics/05_collections.md: -------------------------------------------------------------------------------- 1 | # Natively handling Python collections 2 | 3 | Open In Colab 4 | 5 | A key benefit of `mandala` over straightforward memoization is that it can make 6 | Python collections (lists, dicts, ...) a native & transparent part of the 7 | memoization process: 8 | 9 | - `@op`s can return collections where each item is a separate `Ref`, so that 10 | later `@op` calls can work with individual elements; 11 | - `@op`s can also accept as input collections where each item is a separate 12 | `Ref`, to e.g. implement aggregation operations over them. 13 | - collections can reuse the storage of their items: if two collections share 14 | some elements, each shared element is stored only once in storage. 15 | - the relationship between a collection and each of its items is a native part 16 | of the computational graph of `@op` calls, and can be propagated automatically 17 | by `ComputationFrame`s. Indeed, **collections are implemented as `@op`s 18 | internally**. 19 | 20 | ## Input/output collections must be explicitly annotated 21 | By default, any collection passed as `@op` input or output will be stored as a 22 | single `Ref` with no structure; the object is **opaque** to the `Storage` 23 | instance. 
To make the collection **transparent** to the `Storage`, you must 24 | override this behavior explicitly by using a custom type annotation, such as 25 | `MList` for lists, `MDict` for dicts, ...: 26 | 27 | 28 | ```python 29 | # for Google Colab 30 | try: 31 | import google.colab 32 | !pip install git+https://github.com/amakelov/mandala 33 | except: 34 | pass 35 | ``` 36 | 37 | 38 | ```python 39 | from mandala.imports import Storage, op, MList 40 | 41 | storage = Storage() 42 | 43 | @op 44 | def average(nums: MList[float]) -> float: 45 | return sum(nums) / len(nums) 46 | 47 | with storage: 48 | a = average([1, 2, 3, 4, 5]) 49 | ``` 50 | 51 | We can understand how the list was made transparent to the storage by inspecting 52 | the computation frame: 53 | 54 | 55 | ```python 56 | cf = storage.cf(average).expand_all(); cf.draw(verbose=True, orientation='LR') 57 | ``` 58 | 59 | 60 | 61 | ![svg](05_collections_files/05_collections_5_0.svg) 62 | 63 | 64 | 65 | We see that the internal `__make_list__` operation was automatically applied to 66 | create a list, which is then the `Ref` passed to `average`. 67 | 68 | ## How collections interact with `ComputationFrame`s 69 | In general, CFs are turned into dataframes that capture the joint history of the 70 | final `Ref`s in the CF. When there are collection `@op`s in the CF, a single 71 | `Ref` (such as the element of `nums` above) can depend on multiple `Ref`s in 72 | another variable (such as the `Ref`s in the `elts` variable). 73 | 74 | We can observe this by taking the dataframe of the above CF: 75 | 76 | 77 | ```python 78 | print(cf.df(values='objs').to_markdown()) 79 | ``` 80 | 81 | | | elts | __make_list__ | nums | average | var_0 | 82 | |---:|:---------------------------------|:--------------------------------|:----------------|:--------------------------|--------:| 83 | | 0 | ValueCollection([2, 4, 1, 3, 5]) | Call(__make_list__, hid=172...) | [1, 2, 3, 4, 5] | Call(average, hid=38e...) 
| 3 | 84 | 85 | 86 | There's only a single row, but in the `elts` column we see a `ValueCollection` 87 | object, indicating that there are multiple `Ref`s in `elts` that are 88 | dependencies of `output_0`. 89 | -------------------------------------------------------------------------------- /docs/docs/tutorials/01_hello_files/01_hello_5_1.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | var_0 15 | 16 | var_0 17 | 5 values (5 sinks) 18 | 19 | 20 | 21 | x 22 | 23 | x 24 | 5 values (5 sources) 25 | 26 | 27 | 28 | inc 29 | 30 | inc 31 | @op:inc 32 | 5 calls 33 | 34 | 35 | 36 | x->inc 37 | 38 | 39 | x 40 | (5 values) 41 | 42 | 43 | 44 | inc->var_0 45 | 46 | 47 | output_0 48 | (5 values) 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /docs/docs/topics/03_cf_files/03_cf_22_0.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | test_acc 15 | 16 | test_acc 17 | 2 values (2 sinks) 18 | 19 | 20 | 21 | model 22 | 23 | model 24 | 4 values (4 sources/2 sinks) 25 | 26 | 27 | 28 | eval_model 29 | 30 | eval_model 31 | @op:eval_model 32 | 2 calls 33 | 34 | 35 | 36 | model->eval_model 37 | 38 | 39 | model 40 | (2 values) 41 | 42 | 43 | 44 | eval_model->test_acc 45 | 46 | 47 | output_0 48 | (2 values) 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /mandala/deps/crawler.py: -------------------------------------------------------------------------------- 1 | import types 2 | from ..common_imports import * 3 | from ..utils import unwrap_decorators 4 | import importlib 5 | from .model import ( 6 | DepKey, 7 | CallableNode, 8 | ) 9 | from .utils import ( 10 | is_callable_obj, 11 | extract_func_obj, 12 | unknown_function, 13 | ) 14 | 15 | 16 | def crawl_obj( 17 | obj: Any, 18 | module_name: str, 19 | 
def crawl_obj(
    obj: Any,
    module_name: str,
    include_methods: bool,
    result: Dict[DepKey, CallableNode],
    strict: bool,
    objs_result: Dict[DepKey, Callable],
) -> None:
    """
    Find functions and optionally methods native to the module of this object.

    Results are accumulated in-place:
    - `result`: maps a dependency key `(module_name, qualified name)` to the
      `CallableNode` built from the extracted function object;
    - `objs_result`: maps the same keys to the *original* (possibly
      decorated/wrapped) object, not the extracted function.

    If `obj` is a class defined in `module_name` and `include_methods` is
    True, its `__dict__` entries are crawled recursively so methods are
    picked up as well.
    """
    if is_callable_obj(obj=obj, strict=strict):
        # built-in (C-implemented) callables are skipped entirely
        if isinstance(unwrap_decorators(obj, strict=False), types.BuiltinFunctionType):
            return
        v = extract_func_obj(obj=obj, strict=strict)
        if v is not unknown_function and v.__module__ != module_name:
            # exclude non-local functions
            return
        dep_key = (module_name, v.__qualname__)
        node = CallableNode.from_obj(obj=v, dep_key=dep_key)
        result[dep_key] = node
        # keep the pre-extraction object so callers can access the wrapper
        objs_result[dep_key] = obj
    if isinstance(obj, type):
        if include_methods:
            # only crawl classes that are themselves native to this module
            if obj.__module__ != module_name:
                return
            for k in obj.__dict__.keys():
                v = obj.__dict__[k]
                crawl_obj(
                    obj=v,
                    module_name=module_name,
                    include_methods=include_methods,
                    result=result,
                    strict=strict,
                    objs_result=objs_result,
                )
def crawl_static(
    root: Optional[Path],
    strict: bool,
    package_name: Optional[str] = None,
    include_methods: bool = False,
) -> Tuple[Dict[DepKey, CallableNode], Dict[DepKey, Callable]]:
    """
    Find all python files in the root directory, import them with importlib,
    look for callable objects, and create callable nodes from them.

    Parameters:
    - `root`: a directory to crawl recursively, a single `.py` file (then
      `package_name` is required so the module can be imported), or None to
      crawl only `__main__`;
    - `strict`: if True, a failed import raises `ValueError`; otherwise it is
      logged and skipped;
    - `package_name`: optional package prefix for the derived module names;
    - `include_methods`: forwarded to `crawl_obj` (crawl class methods too).

    Returns the pair of dicts accumulated by `crawl_obj`:
    (dep_key -> CallableNode, dep_key -> original object).
    """
    result: Dict[DepKey, CallableNode] = {}
    objs_result: Dict[DepKey, Callable] = {}
    # NOTE: `paths` mixes `Path` objects with the sentinel string "__main__",
    # which stands for the interactive/main module.
    paths = []
    if root is not None:
        if root.is_file():
            assert package_name is not None  # needs this to be able to import
            paths = [root]
        else:
            paths.extend(list(root.rglob("*.py")))
    paths.append("__main__")
    for path in paths:
        filename = path.name if path != "__main__" else "__main__"
        if filename in ("setup.py", "console.py"):
            # hard-coded exclusions
            continue
        if path != "__main__" and root is not None:
            if root.is_file():
                module_name = root.stem
            else:
                # turn the path relative to `root` into a dotted module name
                module_name = (
                    path.with_suffix("").relative_to(root).as_posix().replace("/", ".")
                )
            if package_name is not None:
                module_name = ".".join([package_name, module_name])
        else:
            module_name = "__main__"
        try:
            module = importlib.import_module(module_name)
        except Exception as e:
            # Fix: was a bare `except:` (which also swallowed SystemExit and
            # KeyboardInterrupt) and the message ended in a dangling colon
            # with the actual error dropped. Catch only `Exception`, include
            # the error text, and chain the cause when raising.
            msg = f"Failed to import {module_name}: {e}"
            if strict:
                raise ValueError(msg) from e
            else:
                logger.warning(msg)
                continue
        keys = list(module.__dict__.keys())
        for k in keys:
            v = module.__dict__[k]
            crawl_obj(
                obj=v,
                module_name=module_name,
                strict=strict,
                include_methods=include_methods,
                result=result,
                objs_result=objs_result,
            )
    return result, objs_result
T = TypeVar("T")


class MList(List[T], Generic[T]):
    """List annotation marking a list as transparent to `mandala` storage."""

    def identify(self):
        return "Type annotation for `mandala` lists"


_KT = TypeVar("_KT")
_VT = TypeVar("_VT")


class MDict(Dict[_KT, _VT], Generic[_KT, _VT]):
    """Dict annotation marking a dict as transparent to `mandala` storage."""

    def identify(self):
        return "Type annotation for `mandala` dictionaries"


class MSet(Set[T], Generic[T]):
    """Set annotation marking a set as transparent to `mandala` storage."""

    def identify(self):
        return "Type annotation for `mandala` sets"


class MTuple(Tuple, Generic[T]):
    """Tuple annotation marking a tuple as transparent to `mandala` storage."""

    def identify(self):
        return "Type annotation for `mandala` tuples"


class Type:
    """
    Base of `mandala`'s internal type system. `AtomType` stands for opaque
    values; the structured subclasses mirror the `MList`/`MDict`/`MSet`/
    `MTuple` annotations with recursively-converted element types.
    """

    @staticmethod
    def from_annotation(annotation: Any) -> "Type":
        """
        Translate a raw annotation (`MList[int]`, `int`, `None`, a `Type`
        instance, ...) into a `Type`; anything unrecognized is an `AtomType`.
        """
        # missing / empty / Any annotations are all opaque
        if (
            annotation is None
            or annotation is inspect._empty
            or annotation is typing.Any
        ):
            return AtomType()
        if hasattr(annotation, "__origin__"):
            origin = annotation.__origin__
            recurse = Type.from_annotation
            if origin is MList:
                return ListType(elt=recurse(annotation.__args__[0]))
            if origin is MDict:
                key_ann, val_ann = annotation.__args__
                return DictType(key=recurse(key_ann), val=recurse(val_ann))
            if origin is MSet:
                return SetType(elt=recurse(annotation.__args__[0]))
            if origin is MTuple:
                args = annotation.__args__
                if len(args) == 2 and args[1] == Ellipsis:
                    # variadic form MTuple[X, ...]: a single element type
                    return TupleType(recurse(args[0]))
                return TupleType(*(recurse(a) for a in args))
            # any other generic alias (e.g. plain List[int]) is opaque
            return AtomType()
        if isinstance(annotation, Type):
            # already converted: pass through unchanged
            return annotation
        return AtomType()

    def __eq__(self, other: Any) -> bool:
        # NOTE: only atomic types support equality; comparing two structured
        # types of the same kind deliberately raises.
        if type(self) is not type(other):
            return False
        if isinstance(self, AtomType):
            return True
        raise NotImplementedError


class AtomType(Type):
    """An opaque value with no structure visible to the storage."""

    def __repr__(self):
        return "AnyType()"


class ListType(Type):
    struct_id = "__list__"
    model = list

    def __init__(self, elt: Type):
        # the type of each list element
        self.elt = elt

    def __repr__(self):
        return f"ListType(elt_type={self.elt})"


class DictType(Type):
    struct_id = "__dict__"
    model = dict

    def __init__(self, val: Type, key: Optional[Type] = None):
        self.key = key
        self.val = val

    def __repr__(self):
        return f"DictType(val_type={self.val})"


class SetType(Type):
    struct_id = "__set__"
    model = set

    def __init__(self, elt: Type):
        # the type of each set element
        self.elt = elt

    def __repr__(self):
        return f"SetType(elt_type={self.elt})"


class TupleType(Type):
    struct_id = "__tuple__"
    model = tuple

    def __init__(self, *elt_types: Type):
        # one Type per position (a single Type for the variadic form)
        self.elt_types = elt_types

    def __repr__(self):
        return f"TupleType(elt_types={self.elt_types})"
def to_string(graph: "model.DependencyGraph") -> str:
    """
    Render the dependency graph as human-readable text, one section per
    module: a header, then its global variables, then its functions with
    methods grouped under their class.
    """
    by_module: Dict[str, List["model.Node"]] = {}
    for (mod, _), node in graph.nodes.items():
        by_module.setdefault(mod, []).append(node)
    out: List[str] = []
    indent4 = 4 * " "
    indent8 = 8 * " "
    for mod, mod_nodes in by_module.items():
        globals_ = [n for n in mod_nodes if isinstance(n, model.GlobalVarNode)]
        callables = [n for n in mod_nodes if isinstance(n, model.CallableNode)]
        header = f"MODULE: {mod}"
        out.extend([header, "-" * len(header), "===Global Variables==="])
        for n in globals_:
            out.append(
                textwrap.indent(f"{n.obj_name} = {n.readable_content()}", indent4)
            )
        out.append("")
        out.append("===Functions===")
        # methods are printed grouped by their class; free functions after
        methods = [n for n in callables if n.is_method]
        functions = [n for n in callables if not n.is_method]
        by_class: Dict[str, List["model.CallableNode"]] = {}
        for m in methods:
            by_class.setdefault(m.class_name, []).append(m)
        for cls_name, cls_methods in by_class.items():
            out.append(textwrap.indent(f"class {cls_name}:", indent4))
            for m in cls_methods:
                out.append(
                    textwrap.indent(textwrap.dedent(m.readable_content()), indent8)
                )
            out.append("")
        for n in functions:
            out.append(
                textwrap.indent(textwrap.dedent(n.readable_content()), indent4)
            )
        out.append("")
    return "\n".join(out)
def to_dot(graph: "model.DependencyGraph") -> str:
    """
    Render the dependency graph as a graphviz dot string.

    Nodes are color-coded by kind (global variables red, functions blue,
    methods violet, anything else base03), grouped into one cluster per
    module, with methods additionally grouped into per-class sub-clusters.
    """
    nodes: Dict[DepKey, DotNode] = {}
    module_groups: Dict[str, DotGroup] = {}  # module name -> Group
    class_groups: Dict[str, DotGroup] = {}  # class name -> Group
    for key, node in graph.nodes.items():
        module_name, obj_addr = key
        if module_name not in module_groups:
            module_groups[module_name] = DotGroup(
                label=module_name, nodes=[], parent=None
            )
        if isinstance(node, model.GlobalVarNode):
            color = SOLARIZED_LIGHT["red"]
        elif isinstance(node, model.CallableNode):
            color = (
                SOLARIZED_LIGHT["blue"]
                if not node.is_method
                else SOLARIZED_LIGHT["violet"]
            )
        else:
            color = SOLARIZED_LIGHT["base03"]
        # the dep key (module, qualname) joined with "." is globally unique
        dot_node = DotNode(
            internal_name=".".join(key), label=node.obj_name, color=color
        )
        nodes[key] = dot_node
        module_groups[module_name].nodes.append(dot_node)
        if isinstance(node, model.CallableNode) and node.is_method:
            class_name = node.class_name
            # NOTE(review): class groups are keyed by bare class name, so
            # same-named classes in different modules would share one group
            # (with the parent of whichever module came first) — confirm.
            class_groups.setdefault(
                class_name,
                DotGroup(
                    label=class_name,
                    nodes=[],
                    parent=module_groups[module_name],
                ),
            ).nodes.append(dot_node)
    edges: Dict[Tuple[DotNode, DotNode], DotEdge] = {}
    for source, target in graph.edges:
        source_node = nodes[source]
        target_node = nodes[target]
        edge = DotEdge(source_node=source_node, target_node=target_node)
        edges[(source_node, target_node)] = edge
    # bottom-to-top rank direction: dependencies point upward
    dot_string = to_dot_string(
        nodes=list(nodes.values()),
        edges=list(edges.values()),
        groups=list(module_groups.values()) + list(class_groups.values()),
        rankdir="BT",
    )
    return dot_string
This way, the calls to memoized 8 | functions create a queriable web of interlinked objects. 9 | 10 | 11 | ```python 12 | # for Google Colab 13 | try: 14 | import google.colab 15 | !pip install git+https://github.com/amakelov/mandala 16 | except: 17 | pass 18 | ``` 19 | 20 | ## Creating a `Storage` 21 | 22 | When creating a storage, you must decide if it will be in-memory or persisted on 23 | disk, and whether the storage will automatically version the `@op`s used with 24 | it: 25 | 26 | 27 | ```python 28 | from mandala.imports import Storage 29 | import os 30 | 31 | DB_PATH = 'my_persistent_storage.db' 32 | if os.path.exists(DB_PATH): 33 | os.remove(DB_PATH) 34 | 35 | storage = Storage( 36 | # omit for an in-memory storage 37 | db_path=DB_PATH, 38 | # omit to disable automatic dependency tracking & versioning 39 | # use "__main__" to only track functions defined in the current session 40 | deps_path='__main__', 41 | ) 42 | ``` 43 | 44 | ## Creating `@op`s and saving calls to them 45 | **Any Python function can be decorated with `@op`**: 46 | 47 | 48 | ```python 49 | from mandala.imports import op 50 | 51 | @op 52 | def sum_args(a, *args, b=1, **kwargs): 53 | return a + sum(args) + b + sum(kwargs.values()) 54 | ``` 55 | 56 | In general, calling `sum_args` will behave as if the `@op` decorator is not 57 | there. `@op`-decorated functions will interact with a `Storage` instance **only 58 | when** called inside a `with storage:` block: 59 | 60 | 61 | ```python 62 | with storage: # all `@op` calls inside this block use `storage` 63 | s = sum_args(6, 7, 8, 9, c=11,) 64 | print(s) 65 | ``` 66 | 67 | AtomRef(42, hid=168...) 68 | 69 | 70 | This code runs the call to `sum_args`, and saves the inputs and outputs in the 71 | `storage` object, so that doing the same call later will directly load the saved 72 | outputs. 73 | 74 | ### When should something be an `@op`? 75 | As a general guide, you should make something an `@op` if you want to save its 76 | outputs, e.g. 
if they take a long time to compute but you need them for later 77 | analysis. Since `@op` [encourages 78 | composition](https://amakelov.github.io/mandala/02_retracing/#how-op-encourages-composition), 79 | you should aim to have `@op`s work on the outputs of other `@op`s, or on the 80 | [collections and/or items](https://amakelov.github.io/mandala/05_collections/) 81 | of outputs of other `@op`s. 82 | 83 | ## Working with `@op` outputs (`Ref`s) 84 | The objects (e.g. `s`) returned by `@op`s are always instances of a subclass of 85 | `Ref` (e.g., `AtomRef`), i.e. **references to objects in the storage**. Every 86 | `Ref` contains two metadata fields: 87 | 88 | - `cid`: a hash of the **content** of the object 89 | - `hid`: a hash of the **computational history** of the object, which is the precise 90 | composition of `@op`s that created this ref. 91 | 92 | Two `Ref`s with the same `cid` may have different `hid`s, and `hid` is the 93 | unique identifier of `Ref`s in the storage. However, only 1 copy per unique 94 | `cid` is stored to avoid duplication in the storage. 95 | 96 | ### `Ref`s can be in memory or not 97 | Additionally, `Ref`s have the `in_memory` property, which indicates if the 98 | underlying object is present in the `Ref` or if this is a "lazy" `Ref` which 99 | only contains metadata. **`Ref`s are only loaded in memory when needed for a new 100 | call to an `@op`**. 
For example, re-running the last code block: 101 | 102 | 103 | ```python 104 | with storage: 105 | s = sum_args(6, 7, 8, 9, c=11,) 106 | print(s) 107 | ``` 108 | 109 | AtomRef(hid=168..., in_memory=False) 110 | 111 | 112 | To get the object wrapped by a `Ref`, call `storage.unwrap`: 113 | 114 | 115 | ```python 116 | storage.unwrap(s) # loads from storage only if necessary 117 | ``` 118 | 119 | 120 | 121 | 122 | 42 123 | 124 | 125 | 126 | ### Other useful `Storage` methods 127 | 128 | - `Storage.attach(inplace: bool)`: like `unwrap`, but puts the objects in the 129 | `Ref`s if they are not in-memory. 130 | - `Storage.load_ref(hid: str, in_memory: bool)`: load a `Ref` by its history ID, 131 | optionally also loading the underlying object. 132 | 133 | 134 | ```python 135 | print(storage.attach(obj=s, inplace=False)) 136 | print(storage.load_ref(s.hid)) 137 | ``` 138 | 139 | AtomRef(42, hid=168...) 140 | AtomRef(42, hid=168...) 141 | 142 | 143 | ## Working with `Call` objects 144 | Besides `Ref`s, the other kind of object in the storage is the `Call`, which 145 | stores references to the inputs and outputs of a call to an `@op`, together with 146 | metadata that mirrors the `Ref` metadata: 147 | 148 | - `Call.cid`: a content ID for the call, based on the `@op`'s identity, its 149 | version at the time of the call, and the `cid`s of the inputs 150 | - `Call.hid`: a history ID for the call, the same as `Call.cid`, but using the 151 | `hid`s of the inputs. 152 | 153 | **For every `Ref` history ID, there's at most one `Call` that has an output with 154 | this history ID**, and if it exists, this call can be found by calling 155 | `storage.get_ref_creator()`: 156 | 157 | 158 | ```python 159 | call = storage.get_ref_creator(ref=s) 160 | print(call) 161 | display(call.inputs) 162 | display(call.outputs) 163 | ``` 164 | 165 | Call(sum_args, hid=f99...) 
166 | 167 | 168 | 169 | {'a': AtomRef(hid=c6a..., in_memory=False), 170 | 'args_0': AtomRef(hid=e0f..., in_memory=False), 171 | 'args_1': AtomRef(hid=479..., in_memory=False), 172 | 'args_2': AtomRef(hid=c37..., in_memory=False), 173 | 'b': AtomRef(hid=610..., in_memory=False), 174 | 'c': AtomRef(hid=a33..., in_memory=False)} 175 | 176 | 177 | 178 | {'output_0': AtomRef(hid=168..., in_memory=False)} 179 | 180 | -------------------------------------------------------------------------------- /docs/docs/topics/05_collections_files/05_collections_5_0.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | var_0 15 | 16 | var_0 17 | 1 values (1 sinks) 18 | 19 | 20 | 21 | nums 22 | 23 | nums 24 | 1 values 25 | 26 | 27 | 28 | average 29 | 30 | average 31 | @op:average 32 | 1 calls 33 | 34 | 35 | 36 | nums->average 37 | 38 | 39 | nums 40 | (1 values) 41 | 42 | 43 | 44 | elts 45 | 46 | elts 47 | 5 values (5 sources) 48 | 49 | 50 | 51 | __make_list__ 52 | 53 | __make_list__ 54 | @op:__make_list__ 55 | 1 calls 56 | 57 | 58 | 59 | elts->__make_list__ 60 | 61 | 62 | *elts 63 | (5 values) 64 | 65 | 66 | 67 | __make_list__->nums 68 | 69 | 70 | list 71 | (1 values) 72 | 73 | 74 | 75 | average->var_0 76 | 77 | 78 | output_0 79 | (1 values) 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /docs/docs/topics/03_cf_files/03_cf_28_0.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | y_test 15 | 16 | y_test 17 | 1 values (1 sources) 18 | 19 | 20 | 21 | eval_model 22 | 23 | eval_model 24 | @op:eval_model 25 | 1 calls 26 | 27 | 28 | 29 | y_test->eval_model 30 | 31 | 32 | y_test 33 | (1 values) 34 | 35 | 36 | 37 | v 38 | 39 | v 40 | 1 values (1 sinks) 41 | 42 | 43 | 44 | model 45 | 46 | model 47 | 1 values (1 sources) 48 | 49 | 50 | 51 | model->eval_model 52 | 53 | 54 | 
def test_storage():
    """
    Basic memoization: equal values share a content ID (cid) while refs with
    different computational histories get different history IDs (hid).
    """
    storage = Storage()

    @op
    def inc(x: int) -> int:
        return x + 1

    with storage:
        x = 1
        y = inc(x)  # 2
        z = inc(2)  # 3, computed from a raw value
        w = inc(y)  # 3, computed from a Ref -> different history than z

    # same content => same cid
    assert w.cid == z.cid
    # different histories => different hid
    assert w.hid != y.hid
    # different values (3 vs 2) => different cid
    assert w.cid != y.cid
    assert storage.unwrap(y) == 2
    assert storage.unwrap(z) == 3
    assert storage.unwrap(w) == 3
    for ref in (y, z, w):
        # attach() yields an in-memory Ref wrapping the stored value
        assert storage.attach(ref).in_memory
        assert storage.attach(ref).obj == storage.unwrap(ref)


def test_signatures():
    """
    @op must handle arbitrary signatures: *args, keyword defaults and
    **kwargs, including branches returning None or a tuple.
    """
    storage = Storage()

    @op  # a function with a wild input/output signature
    def add(x, *args, y: int = 1, **kwargs):
        # just sum everything
        res = x + sum(args) + y + sum(kwargs.values())
        if kwargs:
            return res, kwargs
        elif args:
            return None
        else:
            return res

    with storage:
        # call the func in all the ways
        sum_1 = add(1)
        sum_2 = add(1, 2, 3, 4, )
        sum_3 = add(1, 2, 3, 4, y=5)
        sum_4 = add(1, 2, 3, 4, y=5, z=6)
        sum_5 = add(1, 2, 3, 4, z=5, w=7)

    assert storage.unwrap(sum_1) == 2
    assert storage.unwrap(sum_2) == None
    assert storage.unwrap(sum_3) == None
    assert storage.unwrap(sum_4) == (21, {'z': 6})
    assert storage.unwrap(sum_5) == (23, {'z': 5, 'w': 7})
def test_retracing():
    """
    Re-running identical @op compositions (iteration and two-stage
    composition) must retrace memoized calls without error.
    """
    storage = Storage()

    @op
    def inc(x):
        return x + 1

    ### iterating a function
    with storage:
        start = 1
        for i in range(10):
            start = inc(start)

    # second pass: every call retraces the memoized one
    with storage:
        start = 1
        for i in range(10):
            start = inc(start)

    ### composing functions
    @op
    def add(x, y):
        return x + y

    with storage:
        inp = [1, 2, 3, 4, 5]
        stage_1 = [inc(x) for x in inp]
        stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)]

    # retrace the two-stage pipeline
    with storage:
        inp = [1, 2, 3, 4, 5]
        stage_1 = [inc(x) for x in inp]
        stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)]


def test_lists():
    """
    MList-annotated inputs/outputs are transparent to storage: elements are
    individual Refs, indexing/slicing works, and overlapping lists share
    element storage.
    """
    storage = Storage()

    @op
    def get_sum(elts: MList[int]) -> int:
        return sum(elts)

    @op
    def primes_below(n: int) -> MList[int]:
        primes = []
        for i in range(2, n):
            for p in primes:
                if i % p == 0:
                    break
            else:
                primes.append(i)
        return primes

    @op
    def chunked_square(elts: MList[int]) -> MList[int]:
        # a model for an op that does something on chunks of a big thing
        # to prevent OOM errors
        return [x*x for x in elts]

    with storage:
        n = 10
        primes = primes_below(n)
        sum_primes = get_sum(primes)
        assert len(primes) == 4  # primes below 10: 2, 3, 5, 7
        # check indexing
        assert storage.unwrap(primes[0]) == 2
        assert storage.unwrap(primes[:2]) == [2, 3]

    ### lists w/ overlapping elements
    with storage:
        n = 100
        primes = primes_below(n)
        for i in range(0, len(primes), 2):
            sum_primes = get_sum(primes[:i+1])

    with storage:
        elts = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        squares = chunked_square(elts)


def test_ignore():
    """
    Arguments listed in `ignore_args` must not affect call identity: varying
    them does not create a new memoized call.
    """
    storage = Storage()

    @op(ignore_args=('irrelevant',))
    def inc(x, irrelevant):
        return x + 1

    with storage:
        inc(23, 0)

    df = storage.cf(inc).df()
    assert len(df) == 1

    with storage:
        inc(23, 1)

    # same x, different ignored arg -> still a single recorded call
    df = storage.cf(inc).df()
    assert len(df) == 1
def test_clear_uncommitted():
    """
    Clearing the atoms cache with uncommitted results must raise ValueError;
    after `commit()` the clear succeeds.
    """
    storage = Storage()

    @op
    def inc(x):
        return x + 1

    with storage:
        for i in range(10):
            inc(i)
    # attempt to clear the atoms cache without having committed; this should
    # fail by default
    try:
        storage.atoms.clear()
        assert False
    except ValueError:
        pass

    # now clear the atoms cache after committing
    storage.commit()
    storage.atoms.clear()


def test_newargdefault():
    """
    NewArgDefault makes adding a parameter backward-compatible: calls that
    omit the new argument or pass a value equal to the default reuse the old
    memoized call; any other value creates a new call.
    """
    storage = Storage()

    @op
    def add(x,):
        return x + 1

    with storage:
        add(1)

    # extend the signature with a NewArgDefault-wrapped parameter
    @op
    def add(x, y=NewArgDefault(1)):
        return x + y

    with storage:
        add(1)
    # check that we didn't make a new call
    assert len(storage.cf(add).calls) == 1

    with storage:
        add(1, 1)
    # check that we didn't make a new call
    assert len(storage.cf(add).calls) == 1

    with storage:
        add(1, 2)
def test_newargdefault_compound_types():
    """
    Smoke test (no assertions): NewArgDefault(None) with array-valued
    arguments, passed both raw and pre-wrapped via `wrap_atom`.
    """
    storage = Storage()

    @op
    def add_array(x: np.ndarray):
        return x

    with storage:
        add_array(np.array([1, 2, 3]))

    # extend the signature with a NewArgDefault-wrapped array parameter
    @op
    def add_array(x: np.ndarray, y=NewArgDefault(None)):
        return x + y

    # test passing a raw value
    with storage:
        add_array(np.array([1, 2, 3]), y=np.array([4, 5, 6]))

    # now test passing a wrapped value
    with storage:
        add_array(np.array([1, 2, 3]), y=wrap_atom(np.array([7, 8, 9])))


def test_value_pointer():
    """
    ValuePointer attaches an explicit id to a value; that id (here "X")
    identifies the input in storage instead of a content hash.
    """
    storage = Storage()

    @op
    def get_mean(x: np.ndarray) -> float:
        return x.mean()

    with storage:
        X = np.array([1, 2, 3, 4, 5])
        X_pointer = ValuePointer("X", X)
        mean = get_mean(X_pointer)

    assert storage.unwrap(mean) == 3.0
    df = storage.cf(get_mean).df()
    assert len(df) == 1
    # the recorded input carries the explicit id, not a hash
    assert df['x'].item().id == "X"
17 | 18 | ## How `@op` encourages composition 19 | There are several ways in which the `@op` decorator encourages (and even 20 | enforces) composition of `@op`s: 21 | 22 | - **`@op`s return special objects**, `Ref`s, which prevents accidentally calling 23 | a non-`@op` on the output of an `@op` 24 | - If the inputs to an `@op` call are already `Ref`s, this **speeds up the cache 25 | lookups**. 26 | - If the call can be reused, the **input `Ref`s don't even need to be in memory** 27 | (because the lookup is based only on `Ref` metadata). 28 | - When `@op`s are composed, **computational history propagates** through this 29 | composition. This is automatically leveraged by `ComputationFrame`s when 30 | querying the storage. 31 | - Though not documented here, **`@op`s can natively handle Python 32 | collections** like lists and dicts. This 33 | 34 | When `@op`s are composed in this way, the entire computation becomes "end-to-end 35 | [memoized](https://en.wikipedia.org/wiki/Memoization)". 36 | 37 | ## Toy ML pipeline example 38 | Here's a small example of a machine learning pipeline: 39 | 40 | 41 | ```python 42 | # for Google Colab 43 | try: 44 | import google.colab 45 | !pip install git+https://github.com/amakelov/mandala 46 | except: 47 | pass 48 | ``` 49 | 50 | 51 | ```python 52 | from mandala.imports import * 53 | from sklearn.datasets import load_digits 54 | from sklearn.ensemble import RandomForestClassifier 55 | from sklearn.metrics import accuracy_score 56 | 57 | @op 58 | def load_data(n_class=2): 59 | print("Loading data") 60 | return load_digits(n_class=n_class, return_X_y=True) 61 | 62 | @op 63 | def train_model(X, y, n_estimators=5): 64 | print("Training model") 65 | return RandomForestClassifier(n_estimators=n_estimators, 66 | max_depth=2).fit(X, y) 67 | 68 | @op 69 | def get_acc(model, X, y): 70 | print("Getting accuracy") 71 | return round(accuracy_score(y_pred=model.predict(X), y_true=y), 2) 72 | 73 | storage = Storage() 74 | 75 | with storage: 76 | X, 
y = load_data() 77 | model = train_model(X, y) 78 | acc = get_acc(model, X, y) 79 | print(acc) 80 | ``` 81 | 82 | Loading data 83 | Training model 84 | Getting accuracy 85 | AtomRef(1.0, hid=d16...) 86 | 87 | 88 | ## Retracing your steps with memoization 89 | Running the computation again will not execute any calls, because it will 90 | exactly **retrace** calls that happened in the past. Moreover, the retracing is 91 | **lazy**: none of the values along the way are actually loaded from storage: 92 | 93 | 94 | ```python 95 | with storage: 96 | X, y = load_data() 97 | print(X, y) 98 | model = train_model(X, y) 99 | print(model) 100 | acc = get_acc(model, X, y) 101 | print(acc) 102 | ``` 103 | 104 | AtomRef(hid=d0f..., in_memory=False) AtomRef(hid=f1a..., in_memory=False) 105 | AtomRef(hid=caf..., in_memory=False) 106 | AtomRef(hid=d16..., in_memory=False) 107 | 108 | 109 | This puts all the `Ref`s along the way in your local variables (as if you've 110 | just ran the computation), which lets you easily inspect any intermediate 111 | variables in this `@op` composition: 112 | 113 | 114 | ```python 115 | storage.unwrap(acc) 116 | ``` 117 | 118 | 119 | 120 | 121 | 1.0 122 | 123 | 124 | 125 | ## Adding new calls "in-place" in `@op`-based programs 126 | With `mandala`, you don't need to think about what's already been computed and 127 | split up code based on that. 
All past results are automatically reused, so you can 128 | directly build upon the existing composition of `@op`s when you want to add new 129 | functions and/or run old ones with different parameters: 130 | 131 | 132 | ```python 133 | # reuse the previous code to loop over more values of n_class and n_estimators 134 | with storage: 135 | for n_class in (2, 5,): 136 | X, y = load_data(n_class) 137 | for n_estimators in (5, 10): 138 | model = train_model(X, y, n_estimators=n_estimators) 139 | acc = get_acc(model, X, y) 140 | print(acc) 141 | ``` 142 | 143 | AtomRef(hid=d16..., in_memory=False) 144 | Training model 145 | Getting accuracy 146 | AtomRef(1.0, hid=6fd...) 147 | Loading data 148 | Training model 149 | Getting accuracy 150 | AtomRef(0.88, hid=158...) 151 | Training model 152 | Getting accuracy 153 | AtomRef(0.88, hid=214...) 154 | 155 | 156 | Note that the first value of `acc` from the nested loop is with 157 | `in_memory=False`, because it was reused from the call we did before; the other 158 | values are in memory, as they were freshly computed. 159 | 160 | This pattern lets you incrementally build towards the final computations you 161 | want without worrying about how results will be reused. 162 | 163 | ## Using control flow efficiently with `@op`s 164 | Because the unit of storage is the function call (as opposed to an entire script 165 | or notebook), you can transparently use Pythonic control flow. 
If the control 166 | flow depends on a `Ref`, you can explicitly load just this `Ref` in memory 167 | using `storage.unwrap`: 168 | 169 | 170 | ```python 171 | with storage: 172 | for n_class in (2, 5,): 173 | X, y = load_data(n_class) 174 | for n_estimators in (5, 10): 175 | model = train_model(X, y, n_estimators=n_estimators) 176 | acc = get_acc(model, X, y) 177 | if storage.unwrap(acc) > 0.9: # load only the `Ref`s needed for control flow 178 | print(n_class, n_estimators, storage.unwrap(acc)) 179 | ``` 180 | 181 | 2 5 1.0 182 | 2 10 1.0 183 | 184 | 185 | ## Memoized code as storage interface 186 | An end-to-end memoized composition of `@op`s is like an "imperative" storage 187 | interface. You can modify the code to only focus on particular results of 188 | interest: 189 | 190 | 191 | ```python 192 | with storage: 193 | for n_class in (5,): 194 | X, y = load_data(n_class) 195 | for n_estimators in (5,): 196 | model = train_model(X, y, n_estimators=n_estimators) 197 | acc = get_acc(model, X, y) 198 | print(storage.unwrap(acc), storage.unwrap(model)) 199 | ``` 200 | 201 | 0.88 RandomForestClassifier(max_depth=2, n_estimators=5) 202 | 203 | -------------------------------------------------------------------------------- /docs/docs/tutorials/01_hello_files/01_hello_5_3.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | var_0 15 | 16 | var_0 17 | 5 values (2 sinks) 18 | 19 | 20 | 21 | add 22 | 23 | add 24 | @op:add 25 | 3 calls 26 | 27 | 28 | 29 | var_0->add 30 | 31 | 32 | y 33 | (3 values) 34 | 35 | 36 | 37 | x 38 | 39 | x 40 | 5 values (5 sources) 41 | 42 | 43 | 44 | inc 45 | 46 | inc 47 | @op:inc 48 | 5 calls 49 | 50 | 51 | 52 | x->inc 53 | 54 | 55 | x 56 | (5 values) 57 | 58 | 59 | 60 | x->add 61 | 62 | 63 | x 64 | (3 values) 65 | 66 | 67 | 68 | var_1 69 | 70 | var_1 71 | 3 values (3 sinks) 72 | 73 | 74 | 75 | inc->var_0 76 | 77 | 78 | output_0 79 | (5 values) 80 | 81 | 82 | 83 
from ..common_imports import *
from .utils import DepKey, hash_dict
from .model import (
    Node,
    CallableNode,
    GlobalVarNode,
    TerminalNode,
)


from .shallow_versions import DAG


class Version:
    """
    Model of a "deep" version of a component that includes versions of its
    dependencies.

    A `Version` is built in two stages:
      - raw data collected from a dependency trace (`__init__` / `from_trace`);
      - cached data (content/semantic expansions and hashes), computed against
        a versioning state by calling `sync`.
    Accessing any cached property before `sync` fails an assertion.
    """

    def __init__(
        self,
        component: DepKey,
        dynamic_deps_commits: Dict[DepKey, str],
        memoized_deps_content_versions: Dict[DepKey, Set[str]],
    ):
        ### raw data from the trace
        # the component whose dependencies are traced
        self.component = component
        # the content hashes of the direct dependencies
        self.direct_deps_commits = dynamic_deps_commits
        # pointers to content hashes of versions of memoized calls
        self.memoized_deps_content_versions = memoized_deps_content_versions

        ### cached data. These are set against a dependency state via `sync`
        self._is_synced = False
        # the expanded set of dependencies, including all transitive
        # dependencies. Note this is a set of *content* hashes per dependency
        self._content_expansion: Dict[DepKey, Set[str]] = None
        # a hash uniquely identifying the content of dependencies of this version
        self._content_version: str = None
        # the semantic hashes of all dependencies for this version;
        # the system enforces that the semantic hash of a dependency is the same
        # for all commits of a component referenced by this version
        self._semantic_expansion: Dict[DepKey, str] = None
        # overall semantic hash of this version
        self._semantic_version: str = None

    @property
    def presentation(self) -> str:
        """One-line human-readable description; requires a synced version."""
        return f'Version of "{self.component[1]}" from module "{self.component[0]}" (content: {self.content_version}, semantic: {self.semantic_version})'

    @staticmethod
    def from_trace(
        component: DepKey, nodes: Dict[DepKey, Node], strict: bool = True
    ) -> "Version":
        """
        Build a raw (unsynced) `Version` from the nodes of a dependency trace.

        NOTE(review): `strict` is currently unused in this method; it is kept
        for backward compatibility with existing callers.
        """
        dynamic_deps_commits = {}
        memoized_deps_content_versions = defaultdict(set)
        for dep_key, node in nodes.items():
            if isinstance(node, (CallableNode, GlobalVarNode)):
                # a directly traced dependency: record its content commit
                dynamic_deps_commits[dep_key] = node.content_hash
            elif isinstance(node, TerminalNode):
                # a memoized call: record a pointer to its content version
                terminal_data = node.representation
                pointer_dep_key = terminal_data.dep_key
                version_content_hash = terminal_data.call_content_version
                memoized_deps_content_versions[pointer_dep_key].add(
                    version_content_hash
                )
            else:
                raise ValueError(f"Unexpected node type {type(node)}")
        return Version(
            component=component,
            dynamic_deps_commits=dynamic_deps_commits,
            memoized_deps_content_versions=dict(memoized_deps_content_versions),
        )

    ############################################################################
    ### methods for setting cached data from a versioning state
    ############################################################################
    def _set_content_expansion(self, all_versions: Dict[DepKey, Dict[str, "Version"]]):
        """Expand direct + memoized deps into all transitive content hashes."""
        result = defaultdict(set)
        for dep_key, content_hash in self.direct_deps_commits.items():
            result[dep_key].add(content_hash)
        for (
            dep_key,
            memoized_content_versions,
        ) in self.memoized_deps_content_versions.items():
            for memoized_content_version in memoized_content_versions:
                referenced_version = all_versions[dep_key][memoized_content_version]
                for (
                    referenced_dep_key,
                    referenced_content_hashes,
                ) in referenced_version.content_expansion.items():
                    result[referenced_dep_key].update(referenced_content_hashes)
        self._content_expansion = dict(result)

    def _set_content_version(self):
        """Hash the (sorted) content expansion into a single content version."""
        self._content_version = hash_dict(
            {
                dep_key: tuple(sorted(self.content_expansion[dep_key]))
                for dep_key in self.content_expansion
            }
        )

    def _set_semantic_expansion(
        self,
        component_dags: Dict[DepKey, DAG],
        all_versions: Dict[DepKey, Dict[str, "Version"]],
    ):
        """
        Compute the semantic hash of every dependency and the overall
        semantic version.

        Raises:
            ValueError: if a dependency is assigned two different semantic
                hashes by different referenced versions.
        """
        result = {}
        # from own deps
        for dep_key, dep_content_hash in self.direct_deps_commits.items():
            dag = component_dags[dep_key]
            semantic_hash = dag.commits[dep_content_hash].semantic_hash
            result[dep_key] = semantic_hash
        # from pointers
        for (
            dep_key,
            memoized_content_versions,
        ) in self.memoized_deps_content_versions.items():
            for memoized_content_version in memoized_content_versions:
                dep_version_semantic_hashes = all_versions[dep_key][
                    memoized_content_version
                ].semantic_expansion
                overlap = set(result.keys()).intersection(
                    dep_version_semantic_hashes.keys()
                )
                # fix: report only the keys that actually disagree; the full
                # `overlap` also contains keys whose hashes agree, which made
                # the old error message misleading
                conflicting = [
                    k for k in overlap if result[k] != dep_version_semantic_hashes[k]
                ]
                if conflicting:
                    raise ValueError(
                        f"Version {self} has conflicting semantic hashes for {conflicting}"
                    )
                result.update(dep_version_semantic_hashes)
        self._semantic_expansion = result
        self._semantic_version = hash_dict(result)

    def sync(
        self,
        component_dags: Dict[DepKey, DAG],
        all_versions: Dict[DepKey, Dict[str, "Version"]],
    ):
        """
        Set all the cached things in the correct order
        """
        self._set_content_expansion(all_versions=all_versions)
        self._set_content_version()
        self._set_semantic_expansion(
            component_dags=component_dags, all_versions=all_versions
        )
        self.set_synced()

    @property
    def content_version(self) -> str:
        assert self._content_version is not None
        return self._content_version

    @property
    def semantic_version(self) -> str:
        assert self._semantic_version is not None
        return self._semantic_version

    @property
    def semantic_expansion(self) -> Dict[DepKey, str]:
        assert self._semantic_expansion is not None
        return self._semantic_expansion

    @property
    def content_expansion(self) -> Dict[DepKey, Set[str]]:
        assert self._content_expansion is not None
        return self._content_expansion

    @property
    def support(self) -> Iterable[DepKey]:
        # the set of components this version depends on (incl. itself)
        return self.content_expansion.keys()

    @property
    def is_synced(self) -> bool:
        return self._is_synced

    def set_synced(self):
        # it can only go from unsynced to synced
        if self._is_synced:
            raise ValueError("Version is already synced")
        self._is_synced = True

    def __repr__(self) -> str:
        return f"""
        Version(
            dependencies={['.'.join(elt) for elt in self.support]},
        )"""
16 | """ 17 | Try to bucket Python objects into categories for the sake of tracking 18 | global state. 19 | """ 20 | SCALARS = "scalars" 21 | DATA = "data" 22 | ALL = "all" 23 | 24 | @staticmethod 25 | def is_excluded(obj: Any) -> bool: 26 | return ( 27 | inspect.ismodule(obj) # exclude modules 28 | or isinstance(obj, type) # exclude classes 29 | or inspect.isfunction(obj) # exclude functions 30 | # or callable(obj) # exclude callables... this is very questionable 31 | or type(obj).__name__ == Config.func_interface_cls_name #! a hack to exclude memoized functions 32 | ) 33 | 34 | @staticmethod 35 | def is_scalar(obj: Any) -> bool: 36 | result = isinstance(obj, (int, float, str, bool, type(None))) 37 | return result 38 | 39 | @staticmethod 40 | def is_data(obj: Any) -> bool: 41 | if GlobalClassifier.is_scalar(obj): 42 | result = True 43 | elif type(obj) in (tuple, list): 44 | result = all(GlobalClassifier.is_data(x) for x in obj) 45 | elif type(obj) is dict: 46 | result = all(GlobalClassifier.is_data((x, y)) for (x, y) in obj.items()) 47 | elif type(obj) in (np.ndarray, pd.DataFrame, pd.Series, pd.Index): 48 | result = True 49 | else: 50 | result = False 51 | # if not result and not GlobalsStrictness.is_callable(obj): 52 | # logger.warning(f'Access to global variable "{obj}" is not tracked because it is not a scalar or a data structure') 53 | return result 54 | 55 | 56 | def is_global_val(obj: Any, allow_only: str = "all") -> bool: 57 | """ 58 | Determine whether the given Python object should be treated as a global 59 | variable whose *content* should be tracked. 60 | 61 | The alternative is that this is a callable object whose *dependencies* 62 | should be tracked. 63 | 64 | However, the distinction is not always clear, making this method somewhat 65 | heuristic. For example, a callable object could be either a function or a 66 | global variable we want to track. 
67 | """ 68 | if isinstance(obj, Ref): # easy case; we always track globals that we explicitly wrapped 69 | return True 70 | if allow_only == GlobalClassifier.SCALARS: 71 | return GlobalClassifier.is_scalar(obj=obj) 72 | elif allow_only == GlobalClassifier.DATA: 73 | return GlobalClassifier.is_data(obj=obj) 74 | elif allow_only == GlobalClassifier.ALL: 75 | return not ( 76 | inspect.ismodule(obj) # exclude modules 77 | or isinstance(obj, type) # exclude classes 78 | or inspect.isfunction(obj) # exclude functions 79 | # or callable(obj) # exclude callables ### this is very questionable 80 | or type(obj).__name__ 81 | == Config.func_interface_cls_name #! a hack to exclude memoized functions 82 | ) 83 | else: 84 | raise ValueError( 85 | f"Unknown strictness level for tracking global variables: {allow_only}" 86 | ) 87 | 88 | 89 | def is_callable_obj(obj: Any, strict: bool) -> bool: 90 | if type(obj).__name__ == Config.func_interface_cls_name: 91 | return True 92 | if isinstance(obj, types.FunctionType): 93 | return True 94 | if not strict and callable(obj): # quite permissive 95 | return True 96 | return False 97 | 98 | 99 | def extract_func_obj(obj: Any, strict: bool) -> types.FunctionType: 100 | if type(obj).__name__ == Config.func_interface_cls_name: 101 | return obj.f 102 | obj = unwrap_decorators(obj, strict=strict) 103 | if isinstance(obj, types.BuiltinFunctionType): 104 | raise ValueError(f"Expected a non-built-in function, but got {obj}") 105 | if not isinstance(obj, types.FunctionType): 106 | if not strict: 107 | if ( 108 | isinstance(obj, type) 109 | and hasattr(obj, "__init__") 110 | and isinstance(obj.__init__, types.FunctionType) 111 | ): 112 | return obj.__init__ 113 | else: 114 | return unknown_function 115 | else: 116 | raise ValueError(f"Expected a function, but got {obj} of type {type(obj)}") 117 | return obj 118 | 119 | 120 | def extract_code(obj: Callable) -> types.CodeType: 121 | if type(obj).__name__ == Config.func_interface_cls_name: 122 | obj = 
obj.f 123 | if isinstance(obj, property): 124 | obj = obj.fget 125 | obj = unwrap_decorators(obj, strict=True) 126 | if not isinstance(obj, (types.FunctionType, types.MethodType)): 127 | logger.debug(f"Expected a function or method, but got {type(obj)}") 128 | # raise ValueError(f"Expected a function or method, but got {obj}") 129 | return obj.__code__ 130 | 131 | 132 | def get_runtime_description(code: types.CodeType) -> Any: 133 | assert isinstance(code, types.CodeType) 134 | return get_sanitized_bytecode_representation(code=code) 135 | 136 | 137 | def get_global_names_candidates(code: types.CodeType) -> Set[str]: 138 | result = set() 139 | instructions = list(dis.get_instructions(code)) 140 | for instr in instructions: 141 | if instr.opname == "LOAD_GLOBAL": 142 | result.add(instr.argval) 143 | if isinstance(instr.argval, types.CodeType): 144 | result.update(get_global_names_candidates(instr.argval)) 145 | return result 146 | 147 | 148 | def get_sanitized_bytecode_representation( 149 | code: types.CodeType, 150 | ) -> List[dis.Instruction]: 151 | instructions = list(dis.get_instructions(code)) 152 | result = [] 153 | for instr in instructions: 154 | if isinstance(instr.argval, types.CodeType): 155 | result.append( 156 | dis.Instruction( 157 | instr.opname, 158 | instr.opcode, 159 | instr.arg, 160 | get_sanitized_bytecode_representation(instr.argval), 161 | "", 162 | instr.offset, 163 | instr.starts_line, 164 | is_jump_target=instr.is_jump_target, 165 | ) 166 | ) 167 | else: 168 | result.append(instr) 169 | return result 170 | 171 | 172 | def unknown_function(): 173 | # this is a placeholder function that we use to get the source of 174 | # functions that we can't get the source of 175 | pass 176 | 177 | 178 | UNKNOWN_GLOBAL_VAR = "UNKNOWN_GLOBAL_VAR" 179 | 180 | 181 | def get_bytecode(f: Union[types.FunctionType, types.CodeType, str]) -> str: 182 | if isinstance(f, str): 183 | f = compile(f, "", "exec") 184 | instructions = dis.get_instructions(f) 185 | return 
"\n".join([str(i) for i in instructions]) 186 | 187 | 188 | def hash_dict(d: dict) -> str: 189 | return get_content_hash(obj=[(k, d[k]) for k in sorted(d.keys())]) 190 | 191 | 192 | def load_obj(module_name: str, obj_name: str) -> Tuple[Any, bool]: 193 | module = importlib.import_module(module_name) 194 | parts = obj_name.split(".") 195 | current = module 196 | found = True 197 | for part in parts: 198 | if not hasattr(current, part): 199 | found = False 200 | break 201 | else: 202 | current = getattr(current, part) 203 | return current, found 204 | 205 | 206 | def get_dep_key_from_func(func: types.FunctionType) -> DepKey: 207 | module_name = func.__module__ 208 | qualname = func.__qualname__ 209 | return module_name, qualname 210 | 211 | 212 | def get_func_qualname( 213 | func_name: str, 214 | code: types.CodeType, 215 | frame: types.FrameType, 216 | ) -> str: 217 | # this is evil 218 | referrers = gc.get_referrers(code) 219 | func_referrers = [r for r in referrers if isinstance(r, types.FunctionType)] 220 | matching_name = [r for r in func_referrers if r.__name__ == func_name] 221 | if len(matching_name) != 1: 222 | return get_func_qualname_fallback(func_name=func_name, code=code, frame=frame) 223 | else: 224 | return matching_name[0].__qualname__ 225 | 226 | 227 | def get_func_qualname_fallback( 228 | func_name: str, code: types.CodeType, frame: types.FrameType 229 | ) -> str: 230 | # get the argument names to *try* to tell if the function is a method 231 | arg_names = code.co_varnames[: code.co_argcount] 232 | # a necessary but not sufficient condition for this to 233 | # be a method 234 | is_probably_method = ( 235 | len(arg_names) > 0 236 | and arg_names[0] == "self" 237 | and hasattr(frame.f_locals["self"].__class__, func_name) 238 | ) 239 | if is_probably_method: 240 | # handle nested classes via __qualname__ 241 | cls_qualname = frame.f_locals["self"].__class__.__qualname__ 242 | func_qualname = f"{cls_qualname}.{func_name}" 243 | else: 244 | func_qualname 
from mandala.imports import *
from mandala.deps.shallow_versions import DAG
from mandala.deps.tracers import DecTracer, SysTracer
from pathlib import Path
import os
import uuid
import pytest


def test_dag():
    """Unit test for the shallow-version DAG: init, commits, head, sync."""
    d = DAG(content_type="code")
    # committing before `init` must fail.
    # fix: the old `try/except AssertionError: pass` silently passed even if
    # no error was raised; `pytest.raises` asserts the failure happens.
    with pytest.raises(AssertionError):
        d.commit("something")

    content_hash_1 = d.init(initial_content="something")
    assert len(d.commits) == 1
    assert d.head == content_hash_1
    # a semantic change gets a new semantic hash and moves the head
    content_hash_2 = d.commit(content="something else", is_semantic_change=True)
    assert (
        d.commits[content_hash_2].semantic_hash
        != d.commits[content_hash_1].semantic_hash
    )
    assert d.head == content_hash_2
    # a non-semantic change keeps the previous semantic hash
    content_hash_3 = d.commit(content="something else #2", is_semantic_change=False)
    assert (
        d.commits[content_hash_3].semantic_hash
        == d.commits[content_hash_2].semantic_hash
    )

    # syncing known content is a no-op that moves the head back
    content_hash_4 = d.sync(content="something else")
    assert content_hash_4 == content_hash_2
    assert d.head == content_hash_2

    d.show()


def generate_path(ext: str) -> Path:
    """Random file path under tests/output for artifacts produced by tests."""
    output_dir = Path(__file__).resolve().parent / "output"
    return output_dir / (str(uuid.uuid4()) + ext)


# MODULE_NAME = "mandala.tests.test_versioning"
# DEPS_PACKAGE = "mandala.tests"
DEPS_PATH = Path(__file__).parent.absolute().resolve()
MODULE_NAME = "test_versioning"
# MODULE_NAME = '__main__'


def _test_version_reprs(storage: Storage):
    """Smoke-test every repr/visualization entry point of the versioner."""
    for dag in storage.get_versioner().component_dags.values():
        for compact in [True, False]:
            dag.show(compact=compact)
    for version in storage.get_versioner().get_flat_versions().values():
        storage.get_versioner().present_dependencies(commits=version.semantic_expansion)
    storage.get_versioner().global_topology.show(path=generate_path(ext=".png"))
    repr(storage.get_versioner().global_topology)


@pytest.mark.parametrize("tracer_impl", [DecTracer])
def test_unit(tracer_impl):
    """A single @op reading a global var gets one version supporting both."""
    storage = Storage(deps_path=DEPS_PATH, tracer_impl=tracer_impl)

    # to be able to import this name
    global f_1, A

    A = 42

    @op
    def f_1(x) -> int:
        return 23 + A

    with storage:
        f_1(1)

    vs = storage.get_versioner()
    print(vs.versions)
    f_1_versions = vs.versions[MODULE_NAME, "f_1"]
    assert len(f_1_versions) == 1
    version = f_1_versions[list(f_1_versions.keys())[0]]
    assert set(version.support) == {(MODULE_NAME, "f_1"), (MODULE_NAME, "A")}
    _test_version_reprs(storage=storage)


@pytest.mark.parametrize("tracer_impl", [DecTracer])
def test_deps(tracer_impl):
    """Tracked helper functions and globals show up in an @op's support."""
    storage = Storage(deps_path=DEPS_PATH, tracer_impl=tracer_impl)
    global dep_1, f_2, A
    if tracer_impl is SysTracer:
        track = lambda x: x
    else:
        from mandala.deps.tracers.dec_impl import track

    A = 42

    @track
    def dep_1(x) -> int:
        return 23

    @track
    @op
    def f_2(x) -> int:
        return dep_1(x) + A

    with storage:
        f_2(1)

    vs = storage.get_versioner()
    f_2_versions = vs.versions[MODULE_NAME, "f_2"]
    assert len(f_2_versions) == 1
    version = f_2_versions[list(f_2_versions.keys())[0]]
    assert set(version.support) == {
        (MODULE_NAME, "f_2"),
        (MODULE_NAME, "dep_1"),
        (MODULE_NAME, "A"),
    }
    _test_version_reprs(storage=storage)


@pytest.mark.parametrize("tracer_impl", [DecTracer])
def test_changes(tracer_impl):
    """Semantic changes branch the version DAG; reverting reuses old commits."""
    storage = Storage(deps_path=DEPS_PATH, tracer_impl=tracer_impl)

    global f

    @op
    def f(x) -> int:
        return x + 1

    with storage:
        f(1)
    commit_1 = storage.sync_component(
        component=f,
        is_semantic_change=None,
    )

    @op
    def f(x) -> int:
        return x + 2

    commit_2 = storage.sync_component(component=f, is_semantic_change=True)
    assert commit_1 != commit_2
    with storage:
        f(1)

    @op
    def f(x) -> int:
        return x + 1

    # confirm we reverted to the previous version
    commit_3 = storage.sync_component(
        component=f,
        is_semantic_change=None,
    )
    assert commit_3 == commit_1
    with storage:
        f(1)

    # create a new branch
    @op
    def f(x) -> int:
        return x + 3

    commit_4 = storage.sync_component(
        component=f,
        is_semantic_change=True,
    )
    assert commit_4 not in (commit_1, commit_2)
    with storage:
        f(1)

    f_versions = storage.get_versioner().versions[MODULE_NAME, "f"]
    assert len(f_versions) == 3
    semantic_versions = [v.semantic_version for v in f_versions.values()]
    assert len(set(semantic_versions)) == 3
    _test_version_reprs(storage=storage)


@pytest.mark.parametrize("tracer_impl", [DecTracer])
def _test_dependency_patterns(tracer_impl):
    # this test is borked currently
    storage = Storage(deps_path=DEPS_PATH, tracer_impl=tracer_impl)
    global A, B, f_1, f_2, f_3, f_4, f_5, f_6
    if tracer_impl is SysTracer:
        track = lambda x: x
    else:
        from mandala.deps.tracers.dec_impl import track

    # global vars
    A = 23
    B = [1, 2, 3]

    # using a global var
    @track
    def f_1(x) -> int:
        return x + A

    # calling another function
    @track
    def f_2(x) -> int:
        return f_1(x) + B[0]

    # different dependencies per call
    @track
    @op
    def f_3(x) -> int:
        if x % 2 == 0:
            return f_2(2 * x)
        else:
            return f_1(x + 1)

    with storage:
        x = f_3(0)

    call = storage.get_ref_creator(x)
    version = storage.get_versioner().get_flat_versions()[call.content_version]
    assert version.support == {
        (MODULE_NAME, "f_3"),
        (MODULE_NAME, "f_2"),
        (MODULE_NAME, "A"),
        (MODULE_NAME, "B"),
        (MODULE_NAME, "f_1"),
    }
    with storage:
        x = f_3(1)
    call = storage.get_ref_creator(x)
    version = storage.get_versioner().get_flat_versions()[call.content_version]
    assert version.support == {
        (MODULE_NAME, "f_3"),
        (MODULE_NAME, "f_1"),
        (MODULE_NAME, "A"),
    }

    # using a lambda
    @track
    @op
    def f_4(x) -> int:
        f = lambda y: f_1(y) + B[0]
        return f(x)

    # make sure the call in the lambda is detected
    with storage:
        x = f_4(10)
    call = storage.get_ref_creator(x)
    version = storage.get_versioner().get_flat_versions()[call.content_version]
    assert version.support == {
        (MODULE_NAME, "f_4"),
        (MODULE_NAME, "f_1"),
        (MODULE_NAME, "A"),
        (MODULE_NAME, "B"),
    }

    # using comprehensions and generators
    @op
    def f_5(x) -> int:
        x = storage.unwrap(x)
        a = [f_1.f(y) for y in range(x)]
        b = {f_2.f(y) for y in range(x)}
        c = {y: f_3.f(y) for y in range(x)}
        return sum(storage.unwrap(f_4.f(y)) for y in range(x))

    with storage:
        f_5(10)

    f_5_versions = storage.get_versioner().versions[MODULE_NAME, "f_5"]
    assert len(f_5_versions) == 1
    version = f_5_versions[list(f_5_versions.keys())[0]]
    assert set(version.support) == {
        (MODULE_NAME, "f_5"),
        (MODULE_NAME, "f_4"),
        (MODULE_NAME, "f_3"),
        (MODULE_NAME, "f_2"),
        (MODULE_NAME, "f_1"),
        (MODULE_NAME, "A"),
        (MODULE_NAME, "B"),
    }

    # nested comprehensions and generators
    @op
    def f_6(x) -> int:
        x = storage.unwrap(x)
        # nested list comprehension
        a = sum([sum([f_1.f(y) for y in range(x)]) for z in range(x)])
        # nested comprehension with generator
        b = sum(sum(f_2.f(y) for y in range(x)) for z in range(storage.unwrap(f_3.f(x))))
        return a + b

    with storage:
        f_6(2)

    f_6_versions = storage.get_versioner().versions[MODULE_NAME, "f_6"]
    assert len(f_6_versions) == 1
    version = f_6_versions[list(f_6_versions.keys())[0]]
    assert set(version.support) == {
        (MODULE_NAME, "f_6"),
        (MODULE_NAME, "f_3"),
        (MODULE_NAME, "f_2"),
        (MODULE_NAME, "f_1"),
        (MODULE_NAME, "A"),
        (MODULE_NAME, "B"),
    }
    _test_version_reprs(storage=storage)
    storage.versions(f_6)
    storage.get_code(version_id=version.content_version)
X_test->eval_model 90 | 91 | 92 | X_test 93 | (1 values) 94 | 95 | 96 | 97 | eval_model->v 98 | 99 | 100 | output_0 101 | (1 values) 102 | 103 | 104 | 105 | generate_dataset->X_test 106 | 107 | 108 | output_1 109 | (1 values) 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /mandala/deps/tracers/sys_impl.py: -------------------------------------------------------------------------------- 1 | import types 2 | from ...common_imports import * 3 | from ..utils import ( 4 | get_func_qualname, 5 | is_global_val, 6 | get_global_names_candidates, 7 | ) 8 | from ..model import ( 9 | DependencyGraph, 10 | CallableNode, 11 | TerminalData, 12 | TerminalNode, 13 | GlobalVarNode, 14 | ) 15 | import sys 16 | import importlib 17 | from .tracer_base import TracerABC, get_closure_names 18 | 19 | ################################################################################ 20 | ### tracer 21 | ################################################################################ 22 | # control flow constants 23 | from .tracer_base import BREAK, CONTINUE, KEEP, MAIN, get_module_flow 24 | 25 | LEAF_SIGNAL = "leaf_signal" 26 | 27 | # constants for special Python function names 28 | LAMBDA = "" 29 | COMPREHENSIONS = ("", "", "", "") 30 | SKIP_FRAMES = tuple(list(COMPREHENSIONS) + [LAMBDA]) 31 | 32 | 33 | class SysTracer(TracerABC): 34 | def __init__( 35 | self, 36 | paths: List[Path], 37 | graph: Optional[DependencyGraph] = None, 38 | strict: bool = True, 39 | allow_methods: bool = False, 40 | ): 41 | self.call_stack: List[Optional[CallableNode]] = [] 42 | self.graph = DependencyGraph() if graph is None else graph 43 | self.paths = paths 44 | self.path_strs = [str(path) for path in paths] 45 | self.strict = strict 46 | self.allow_methods = allow_methods 47 | 48 | @staticmethod 49 | def leaf_signal(data): 50 | # a way to detect the end of a trace 51 | pass 52 | 53 | @staticmethod 54 | def register_leaf_event(trace_obj: 
types.FunctionType, data: Any): 55 | SysTracer.leaf_signal(data) 56 | 57 | @staticmethod 58 | def get_active_trace_obj() -> Optional[Any]: 59 | return sys.gettrace() 60 | 61 | @staticmethod 62 | def set_active_trace_obj(trace_obj: Any): 63 | sys.settrace(trace_obj) 64 | 65 | def _process_failure(self, msg: str): 66 | if self.strict: 67 | raise RuntimeError(msg) 68 | else: 69 | logger.warning(msg) 70 | 71 | def find_most_recent_call(self) -> Optional[CallableNode]: 72 | if len(self.call_stack) == 0: 73 | return None 74 | else: 75 | # return the most recent non-None obj on the stack 76 | for i in range(len(self.call_stack) - 1, -1, -1): 77 | call = self.call_stack[i] 78 | if isinstance(call, CallableNode): 79 | return call 80 | return None 81 | 82 | def __enter__(self): 83 | if sys.gettrace() is not None: 84 | # pre-check this is used correctly 85 | raise RuntimeError("Another tracer is already active") 86 | 87 | def tracer(frame: types.FrameType, event: str, arg: Any): 88 | if event not in ("call", "return"): 89 | return 90 | module_name = frame.f_globals.get("__name__") 91 | # fast check to rule out non-user code 92 | if event == "call": 93 | try: 94 | module = importlib.import_module(module_name) 95 | if not any( 96 | [ 97 | module.__file__.startswith(path_str) 98 | for path_str in self.path_strs 99 | ] 100 | ): 101 | return 102 | except: 103 | if module_name != MAIN: 104 | return 105 | code_obj = frame.f_code 106 | func_name = code_obj.co_name 107 | if event == "return": 108 | logging.debug(f"Returning from {func_name}") 109 | if len(self.call_stack) > 0: 110 | popped = self.call_stack.pop() 111 | logging.debug(f"Popped {popped} from call stack") 112 | # some sanity checks 113 | if func_name in SKIP_FRAMES: 114 | if popped != func_name: 115 | self._process_failure( 116 | f"Expected to pop {func_name} from call stack, but popped {popped}" 117 | ) 118 | else: 119 | if popped.obj_name.split(".")[-1] != func_name: 120 | self._process_failure( 121 | f"Expected to pop 
{func_name} from call stack, but popped {popped.obj_name}" 122 | ) 123 | else: 124 | # something went wrong 125 | raise RuntimeError("Call stack is empty") 126 | return 127 | 128 | if func_name == LEAF_SIGNAL: 129 | data: TerminalData = frame.f_locals["data"] 130 | unique_id = "_".join( 131 | [ 132 | data.op_internal_name, 133 | str(data.op_version), 134 | data.call_content_version, 135 | data.call_semantic_version, 136 | ] 137 | ) 138 | node = TerminalNode( 139 | module_name=module_name, obj_name=unique_id, representation=data 140 | ) 141 | most_recent_option = self.find_most_recent_call() 142 | if most_recent_option is not None: 143 | self.graph.add_edge(source=most_recent_option, target=node) 144 | # self.call_stack.append(None) 145 | return 146 | 147 | module_control_flow = get_module_flow( 148 | paths=self.paths, module_name=module_name 149 | ) 150 | if module_control_flow in (BREAK, CONTINUE): 151 | frame.f_trace = None 152 | return 153 | 154 | logger.debug(f"Tracing call to {module_name}.{func_name}") 155 | 156 | ### get the qualified name of the function/method 157 | func_qualname = get_func_qualname( 158 | func_name=func_name, code=code_obj, frame=frame 159 | ) 160 | if "." 
in func_qualname: 161 | if not self.allow_methods: 162 | raise RuntimeError( 163 | f"Methods are currently not supported: {func_qualname} from {module_name}" 164 | ) 165 | 166 | ### detect use of closure variables 167 | closure_names = get_closure_names( 168 | code_obj=code_obj, func_qualname=func_qualname 169 | ) 170 | if len(closure_names) > 0 and func_name not in SKIP_FRAMES: 171 | closure_values = { 172 | var: frame.f_locals.get(var, frame.f_globals.get(var, None)) 173 | for var in closure_names 174 | } 175 | msg = f"Found closure variables accessed by function {module_name}.{func_name}:\n{closure_values}" 176 | self._process_failure(msg=msg) 177 | 178 | ### get the global variables used by the function 179 | globals_nodes = [] 180 | for name in get_global_names_candidates(code=code_obj): 181 | # names used by the function; not all of them are global variables 182 | if name in frame.f_globals: 183 | global_val = frame.f_globals[name] 184 | if not is_global_val(global_val): 185 | continue 186 | node = GlobalVarNode.from_obj( 187 | obj=global_val, dep_key=(module_name, name) 188 | ) 189 | globals_nodes.append(node) 190 | 191 | ### if this is a comprehension call, add the globals to the most 192 | ### recent tracked call 193 | if func_name in SKIP_FRAMES: 194 | most_recent_tracked_call = self.find_most_recent_call() 195 | assert most_recent_tracked_call is not None 196 | for global_node in globals_nodes: 197 | self.graph.add_edge( 198 | source=most_recent_tracked_call, target=global_node 199 | ) 200 | self.call_stack.append(func_name) 201 | return tracer 202 | 203 | ### manage the call stack 204 | call_node = CallableNode.from_runtime( 205 | module_name=module_name, obj_name=func_qualname, code_obj=code_obj 206 | ) 207 | self.graph.add_node(node=call_node) 208 | ### global variable edges from this function always exist 209 | for global_node in globals_nodes: 210 | self.graph.add_edge(source=call_node, target=global_node) 211 | ### call edges exist only if there is 
a caller on the stack 212 | if len(self.call_stack) > 0: 213 | # find the most recent tracked call 214 | most_recent_tracked_call = self.find_most_recent_call() 215 | if most_recent_tracked_call is not None: 216 | self.graph.add_edge( 217 | source=most_recent_tracked_call, target=call_node 218 | ) 219 | self.call_stack.append(call_node) 220 | if len(self.call_stack) == 1: 221 | self.graph.roots.add(call_node.key) 222 | return tracer 223 | 224 | sys.settrace(tracer) 225 | 226 | def __exit__(self, *exc_info): 227 | sys.settrace(None) # Stop tracing 228 | 229 | 230 | class SuspendSysTraceContext: 231 | def __init__(self): 232 | self.suspended_trace = None 233 | 234 | def __enter__(self) -> "SuspendSysTraceContext": 235 | if sys.gettrace() is not None: 236 | self.suspended_trace = sys.gettrace() 237 | sys.settrace(None) 238 | return self 239 | 240 | def __exit__(self, *exc_info): 241 | if self.suspended_trace is not None: 242 | sys.settrace(self.suspended_trace) 243 | self.suspended_trace = None 244 | -------------------------------------------------------------------------------- /docs/docs/tutorials/02_ml_files/02_ml_22_3.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | var_0 15 | 16 | var_0 17 | 3 values (3 sinks) 18 | 19 | 20 | 21 | n_estimators 22 | 23 | n_estimators 24 | 2 values (2 sources) 25 | 26 | 27 | 28 | train_model 29 | 30 | train_model 31 | @op:train_model 32 | 3 calls 33 | 34 | 35 | 36 | n_estimators->train_model 37 | 38 | 39 | n_estimators 40 | (2 values) 41 | 42 | 43 | 44 | X_train 45 | 46 | X_train 47 | 1 values (1 sources) 48 | 49 | 50 | 51 | X_train->train_model 52 | 53 | 54 | X_train 55 | (1 values) 56 | 57 | 58 | 59 | train_acc 60 | 61 | train_acc 62 | 3 values 63 | 64 | 65 | 66 | eval_model 67 | 68 | eval_model 69 | @op:eval_model 70 | 3 calls 71 | 72 | 73 | 74 | train_acc->eval_model 75 | 76 | 77 | model 78 | (3 values) 79 | 80 | 81 | 82 | model 83 | 84 | 
model 85 | 3 values (3 sinks) 86 | 87 | 88 | 89 | y_train 90 | 91 | y_train 92 | 1 values (1 sources) 93 | 94 | 95 | 96 | y_train->train_model 97 | 98 | 99 | y_train 100 | (1 values) 101 | 102 | 103 | 104 | eval_model->var_0 105 | 106 | 107 | output_0 108 | (3 values) 109 | 110 | 111 | 112 | train_model->train_acc 113 | 114 | 115 | output_0 116 | (3 values) 117 | 118 | 119 | 120 | train_model->model 121 | 122 | 123 | output_1 124 | (3 values) 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /docs/docs/blog/01_cf_files/01_cf_21_0.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | y_test 15 | 16 | y_test 17 | 2 values (2 sources) 18 | 19 | 20 | 21 | eval_ensemble 22 | 23 | eval_ensemble 24 | @op:eval_ensemble 25 | 2 calls 26 | 27 | 28 | 29 | y_test->eval_ensemble 30 | 31 | 32 | y_test 33 | (2 values) 34 | 35 | 36 | 37 | eval_model 38 | 39 | eval_model 40 | @op:eval_model 41 | 12 calls 42 | 43 | 44 | 45 | y_test->eval_model 46 | 47 | 48 | y_test 49 | (2 values) 50 | 51 | 52 | 53 | X_test 54 | 55 | X_test 56 | 2 values (2 sources) 57 | 58 | 59 | 60 | X_test->eval_ensemble 61 | 62 | 63 | X_test 64 | (2 values) 65 | 66 | 67 | 68 | X_test->eval_model 69 | 70 | 71 | X_test 72 | (2 values) 73 | 74 | 75 | 76 | accuracy 77 | 78 | accuracy 79 | 14 values (14 sinks) 80 | 81 | 82 | 83 | model 84 | 85 | model 86 | 12 values (12 sources) 87 | 88 | 89 | 90 | model->eval_model 91 | 92 | 93 | model 94 | (12 values) 95 | 96 | 97 | 98 | models 99 | 100 | models 101 | 2 values (2 sources) 102 | 103 | 104 | 105 | models->eval_ensemble 106 | 107 | 108 | models 109 | (2 values) 110 | 111 | 112 | 113 | eval_ensemble->accuracy 114 | 115 | 116 | accuracy 117 | (2 values) 118 | 119 | 120 | 121 | eval_model->accuracy 122 | 123 | 124 | accuracy 125 | (12 values) 126 | 127 | 128 | 129 | 
-------------------------------------------------------------------------------- /mandala/deps/model.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | from abc import abstractmethod, ABC 3 | import types 4 | 5 | from ..common_imports import * 6 | from ..utils import get_content_hash 7 | from ..viz import ( 8 | write_output, 9 | ) 10 | from ..model import Ref 11 | 12 | from .utils import ( 13 | DepKey, 14 | load_obj, 15 | get_runtime_description, 16 | extract_code, 17 | unknown_function, 18 | UNKNOWN_GLOBAL_VAR, 19 | ) 20 | 21 | 22 | class Node(ABC): 23 | def __init__(self, module_name: str, obj_name: str, representation: Any): 24 | self.module_name = module_name 25 | self.obj_name = obj_name 26 | self.representation = representation 27 | 28 | @property 29 | def key(self) -> DepKey: 30 | return (self.module_name, self.obj_name) 31 | 32 | def present_key(self) -> str: 33 | raise NotImplementedError() 34 | 35 | @staticmethod 36 | @abstractmethod 37 | def represent(obj: Any) -> Any: 38 | raise NotImplementedError 39 | 40 | @abstractmethod 41 | def content(self) -> Any: 42 | raise NotImplementedError 43 | 44 | @abstractmethod 45 | def readable_content(self) -> str: 46 | raise NotImplementedError 47 | 48 | @property 49 | @abstractmethod 50 | def content_hash(self) -> str: 51 | raise NotImplementedError 52 | 53 | def load_obj(self, skip_missing: bool, skip_silently: bool) -> Any: 54 | obj, found = load_obj(module_name=self.module_name, obj_name=self.obj_name) 55 | if not found: 56 | msg = f"{self.present_key()} not found" 57 | if skip_missing: 58 | if skip_silently: 59 | logger.debug(msg) 60 | else: 61 | logger.warning(msg) 62 | if hasattr(self, "FALLBACK_OBJ"): 63 | return self.FALLBACK_OBJ 64 | else: 65 | raise ValueError(f"No fallback object defined for {self.__class__}") 66 | else: 67 | raise ValueError(msg) 68 | return obj 69 | 70 | 71 | class CallableNode(Node): 72 | FALLBACK_OBJ = unknown_function 73 | 74 | 
def __init__( 75 | self, 76 | module_name: str, 77 | obj_name: str, 78 | representation: Optional[str], 79 | runtime_description: str, 80 | ): 81 | self.module_name = module_name 82 | self.obj_name = obj_name 83 | self.runtime_description = runtime_description 84 | if representation is not None: 85 | self._set_representation(value=representation) 86 | else: 87 | self._representation = None 88 | self._content_hash = None 89 | 90 | @staticmethod 91 | def from_obj(obj: Any, dep_key: DepKey) -> "CallableNode": 92 | representation = CallableNode.represent(obj=obj) 93 | code_obj = extract_code(obj) 94 | runtime_description = get_runtime_description(code=code_obj) 95 | return CallableNode( 96 | module_name=dep_key[0], 97 | obj_name=dep_key[1], 98 | representation=representation, 99 | runtime_description=runtime_description, 100 | ) 101 | 102 | @staticmethod 103 | def from_runtime( 104 | module_name: str, 105 | obj_name: str, 106 | code_obj: types.CodeType, 107 | ) -> "CallableNode": 108 | return CallableNode( 109 | module_name=module_name, 110 | obj_name=obj_name, 111 | representation=None, 112 | runtime_description=get_runtime_description(code=code_obj), 113 | ) 114 | 115 | @property 116 | def representation(self) -> str: 117 | return self._representation 118 | 119 | def _set_representation(self, value: str): 120 | assert isinstance(value, str) 121 | self._representation = value 122 | self._content_hash = get_content_hash(value) 123 | 124 | @representation.setter 125 | def representation(self, value: str): 126 | self._set_representation(value) 127 | 128 | @property 129 | def is_method(self) -> bool: 130 | return "." 
in self.obj_name 131 | 132 | def present_key(self) -> str: 133 | return f"function {self.obj_name} from module {self.module_name}" 134 | 135 | @property 136 | def class_name(self) -> str: 137 | assert self.is_method 138 | return ".".join(self.obj_name.split(".")[:-1]) 139 | 140 | @staticmethod 141 | def represent( 142 | obj: Union[types.FunctionType, types.CodeType, Callable], 143 | allow_fallback: bool = False, 144 | ) -> str: 145 | if type(obj).__name__ == "Op": 146 | obj = obj.f 147 | if not isinstance(obj, (types.FunctionType, types.MethodType, types.CodeType)): 148 | logger.warning(f"Found {obj} of type {type(obj)}") 149 | try: 150 | source = inspect.getsource(obj) 151 | except Exception as e: 152 | msg = f"Could not get source for {obj} because {e}" 153 | if allow_fallback: 154 | source = inspect.getsource(CallableNode.FALLBACK_OBJ) 155 | logger.warning(msg) 156 | else: 157 | raise RuntimeError(msg) 158 | # strip whitespace to prevent different sources looking the same in the 159 | # ui 160 | lines = source.splitlines() 161 | lines = [line.rstrip() for line in lines] 162 | source = "\n".join(lines) 163 | return source 164 | 165 | def content(self) -> str: 166 | return self.representation 167 | 168 | def readable_content(self) -> str: 169 | return self.representation 170 | 171 | @property 172 | def content_hash(self) -> str: 173 | assert isinstance(self._content_hash, str) 174 | return self._content_hash 175 | 176 | 177 | class GlobalVarNode(Node): 178 | FALLBACK_OBJ = UNKNOWN_GLOBAL_VAR 179 | 180 | def __init__( 181 | self, 182 | module_name: str, 183 | obj_name: str, 184 | # (content hash, truncated repr) 185 | representation: Tuple[str, str], 186 | ): 187 | self.module_name = module_name 188 | self.obj_name = obj_name 189 | self._representation = representation 190 | 191 | @staticmethod 192 | def from_obj(obj: Any, dep_key: DepKey, 193 | skip_unhashable: bool = False, 194 | skip_silently: bool = False,) -> "GlobalVarNode": 195 | representation = 
GlobalVarNode.represent(obj=obj, skip_unhashable=skip_unhashable, skip_silently=skip_silently) 196 | return GlobalVarNode( 197 | module_name=dep_key[0], 198 | obj_name=dep_key[1], 199 | representation=representation, 200 | ) 201 | 202 | @property 203 | def representation(self) -> Tuple[str, str]: 204 | return self._representation 205 | 206 | @staticmethod 207 | def represent(obj: Any, skip_unhashable: bool = True, 208 | skip_silently: bool = False, 209 | ) -> Tuple[str, str]: 210 | """ 211 | Return a hash of this global variable's value + a truncated 212 | representation useful for debugging/printing. 213 | 214 | If `obj` is a `Ref`, the content hash is reused from the `Ref` object. 215 | This is so that you can avoid repeatedly hashing the same (potentially 216 | large) object any time the code state needs to be synced. 217 | """ 218 | truncated_repr = textwrap.shorten(text=repr(obj), width=80) 219 | if isinstance(obj, Ref): 220 | content_hash = obj.cid 221 | else: 222 | try: 223 | content_hash = get_content_hash(obj=obj) 224 | except Exception as e: 225 | shortened_exception = textwrap.shorten(text=str(e), width=80) 226 | msg = f"Failed to hash global variable {truncated_repr} of type {type(obj)}, because {shortened_exception}" 227 | if skip_unhashable: 228 | content_hash = UNKNOWN_GLOBAL_VAR 229 | if skip_silently: 230 | logger.debug(msg) 231 | else: 232 | logger.warning(msg) 233 | else: 234 | raise RuntimeError(msg) 235 | return content_hash, truncated_repr 236 | 237 | def present_key(self) -> str: 238 | return f"global variable {self.obj_name} from module {self.module_name}" 239 | 240 | def content(self) -> str: 241 | return self.representation 242 | 243 | def readable_content(self) -> str: 244 | return self.representation[1] 245 | 246 | @property 247 | def content_hash(self) -> str: 248 | assert isinstance(self.representation, tuple) 249 | return self.representation[0] 250 | 251 | 252 | class TerminalData: 253 | def __init__( 254 | self, 255 | 
op_internal_name: str, 256 | op_version: int, 257 | call_content_version: str, 258 | call_semantic_version: str, 259 | # data: Tuple[Tuple[str, int], Tuple[str, str]], 260 | dep_key: DepKey, 261 | ): 262 | # ((internal name, version), (content_version, semantic_version)) 263 | self.op_internal_name = op_internal_name 264 | self.op_version = op_version 265 | self.call_content_version = call_content_version 266 | self.call_semantic_version = call_semantic_version 267 | self.dep_key = dep_key 268 | 269 | 270 | class TerminalNode(Node): 271 | def __init__(self, module_name: str, obj_name: str, representation: TerminalData): 272 | self.module_name = module_name 273 | self.obj_name = obj_name 274 | self.representation = representation 275 | 276 | @property 277 | def key(self) -> DepKey: 278 | return self.module_name, self.obj_name 279 | 280 | def present_key(self) -> str: 281 | raise NotImplementedError 282 | 283 | @property 284 | def content_hash(self) -> str: 285 | raise NotImplementedError 286 | 287 | def content(self) -> Any: 288 | raise NotImplementedError 289 | 290 | def readable_content(self) -> str: 291 | raise NotImplementedError 292 | 293 | @staticmethod 294 | def represent(obj: Any) -> Any: 295 | raise NotImplementedError 296 | 297 | 298 | class DependencyGraph: 299 | def __init__(self): 300 | self.nodes: Dict[DepKey, Node] = {} 301 | self.roots: Set[DepKey] = set() 302 | self.edges: Set[Tuple[DepKey, DepKey]] = set() 303 | 304 | def get_trace_state(self) -> Tuple[DepKey, Dict[DepKey, Node]]: 305 | if len(self.roots) != 1: 306 | raise ValueError(f"Expected exactly one root, got {len(self.roots)}") 307 | component = list(self.roots)[0] 308 | return component, self.nodes 309 | 310 | def show(self, path: Optional[Path] = None, how: str = "none"): 311 | dot = to_dot(self) 312 | output_ext = "svg" if how in ["browser"] else "png" 313 | return write_output( 314 | dot_string=dot, output_path=path, output_ext=output_ext, show_how=how 315 | ) 316 | 317 | def 
__repr__(self) -> str: 318 | if len(self.nodes) == 0: 319 | return "DependencyGraph()" 320 | return to_string(self) 321 | 322 | def add_node(self, node: Node): 323 | self.nodes[node.key] = node 324 | 325 | def add_edge(self, source: Node, target: Node): 326 | if source.key not in self.nodes: 327 | self.add_node(source) 328 | if target.key not in self.nodes: 329 | self.add_node(target) 330 | self.edges.add((source.key, target.key)) 331 | 332 | 333 | from .viz import to_dot, to_string 334 | -------------------------------------------------------------------------------- /docs/docs/tutorials/02_ml_files/02_ml_22_1.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 9 | 10 | G 11 | 12 | 13 | 14 | random_seed 15 | 16 | random_seed 17 | 1 values (1 sources) 18 | 19 | 20 | 21 | generate_dataset 22 | 23 | generate_dataset 24 | @op:generate_dataset 25 | 1 calls 26 | 27 | 28 | 29 | random_seed->generate_dataset 30 | 31 | 32 | random_seed 33 | (1 values) 34 | 35 | 36 | 37 | n_estimators 38 | 39 | n_estimators 40 | 2 values (2 sources) 41 | 42 | 43 | 44 | train_model 45 | 46 | train_model 47 | @op:train_model 48 | 3 calls 49 | 50 | 51 | 52 | n_estimators->train_model 53 | 54 | 55 | n_estimators 56 | (2 values) 57 | 58 | 59 | 60 | X_train 61 | 62 | X_train 63 | 1 values 64 | 65 | 66 | 67 | X_train->train_model 68 | 69 | 70 | X_train 71 | (1 values) 72 | 73 | 74 | 75 | train_acc 76 | 77 | train_acc 78 | 3 values (3 sinks) 79 | 80 | 81 | 82 | model 83 | 84 | model 85 | 3 values (3 sinks) 86 | 87 | 88 | 89 | y_train 90 | 91 | y_train 92 | 1 values 93 | 94 | 95 | 96 | y_train->train_model 97 | 98 | 99 | y_train 100 | (1 values) 101 | 102 | 103 | 104 | generate_dataset->X_train 105 | 106 | 107 | output_0 108 | (1 values) 109 | 110 | 111 | 112 | generate_dataset->y_train 113 | 114 | 115 | output_2 116 | (1 values) 117 | 118 | 119 | 120 | train_model->train_acc 121 | 122 | 123 | output_0 124 | (3 values) 125 | 126 | 127 | 128 | 
train_model->model 129 | 130 | 131 | output_1 132 | (3 values) 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 mandala authors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /docs_source/topics/01_storage_and_ops.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# `Storage` & the `@op` Decorator\n", 8 | " \n", 9 | " \"Open \n", 10 | "\n", 11 | "A `Storage` object holds all data (saved calls, code and dependencies) for a\n", 12 | "collection of memoized functions. In a given project, you should have just one\n", 13 | "`Storage` and many `@op`s connected to it. This way, the calls to memoized\n", 14 | "functions create a queriable web of interlinked objects. 
" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": { 21 | "execution": { 22 | "iopub.execute_input": "2024-07-11T14:31:35.163619Z", 23 | "iopub.status.busy": "2024-07-11T14:31:35.162794Z", 24 | "iopub.status.idle": "2024-07-11T14:31:35.173665Z", 25 | "shell.execute_reply": "2024-07-11T14:31:35.172980Z" 26 | } 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# for Google Colab\n", 31 | "try:\n", 32 | " import google.colab\n", 33 | " !pip install git+https://github.com/amakelov/mandala\n", 34 | "except:\n", 35 | " pass" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Creating a `Storage`\n", 43 | "\n", 44 | "When creating a storage, you must decide if it will be in-memory or persisted on\n", 45 | "disk, and whether the storage will automatically version the `@op`s used with\n", 46 | "it:" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": { 53 | "execution": { 54 | "iopub.execute_input": "2024-07-11T14:31:35.176299Z", 55 | "iopub.status.busy": "2024-07-11T14:31:35.176085Z", 56 | "iopub.status.idle": "2024-07-11T14:31:37.876054Z", 57 | "shell.execute_reply": "2024-07-11T14:31:37.875495Z" 58 | } 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "from mandala.imports import Storage\n", 63 | "import os\n", 64 | "\n", 65 | "DB_PATH = 'my_persistent_storage.db'\n", 66 | "if os.path.exists(DB_PATH):\n", 67 | " os.remove(DB_PATH)\n", 68 | "\n", 69 | "storage = Storage(\n", 70 | " # omit for an in-memory storage\n", 71 | " db_path=DB_PATH,\n", 72 | " # omit to disable automatic dependency tracking & versioning\n", 73 | " # use \"__main__\" to only track functions defined in the current session\n", 74 | " deps_path='__main__', \n", 75 | ")" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "## Creating `@op`s and saving calls to them\n", 83 | "**Any Python function can be decorated with `@op`**:" 
84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 3, 89 | "metadata": { 90 | "execution": { 91 | "iopub.execute_input": "2024-07-11T14:31:37.880287Z", 92 | "iopub.status.busy": "2024-07-11T14:31:37.879681Z", 93 | "iopub.status.idle": "2024-07-11T14:31:37.914152Z", 94 | "shell.execute_reply": "2024-07-11T14:31:37.913319Z" 95 | } 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "from mandala.imports import op\n", 100 | "\n", 101 | "@op \n", 102 | "def sum_args(a, *args, b=1, **kwargs):\n", 103 | " return a + sum(args) + b + sum(kwargs.values())" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "In general, calling `sum_args` will behave as if the `@op` decorator is not\n", 111 | "there. `@op`-decorated functions will interact with a `Storage` instance **only\n", 112 | "when** called inside a `with storage:` block:" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 4, 118 | "metadata": { 119 | "execution": { 120 | "iopub.execute_input": "2024-07-11T14:31:37.918828Z", 121 | "iopub.status.busy": "2024-07-11T14:31:37.918430Z", 122 | "iopub.status.idle": "2024-07-11T14:31:38.022162Z", 123 | "shell.execute_reply": "2024-07-11T14:31:38.021195Z" 124 | } 125 | }, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "AtomRef(42, hid=168...)\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "with storage: # all `@op` calls inside this block use `storage`\n", 137 | " s = sum_args(6, 7, 8, 9, c=11,)\n", 138 | " print(s)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "This code runs the call to `sum_args`, and saves the inputs and outputs in the\n", 146 | "`storage` object, so that doing the same call later will directly load the saved\n", 147 | "outputs." 
148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "### When should something be an `@op`?\n", 155 | "As a general guide, you should make something an `@op` if you want to save its\n", 156 | "outputs, e.g. if they take a long time to compute but you need them for later\n", 157 | "analysis. Since `@op` [encourages\n", 158 | "composition](https://amakelov.github.io/mandala/02_retracing/#how-op-encourages-composition),\n", 159 | "you should aim to have `@op`s work on the outputs of other `@op`s, or on the\n", 160 | "[collections and/or items](https://amakelov.github.io/mandala/05_collections/)\n", 161 | "of outputs of other `@op`s." 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "## Working with `@op` outputs (`Ref`s)\n", 169 | "The objects (e.g. `s`) returned by `@op`s are always instances of a subclass of\n", 170 | "`Ref` (e.g., `AtomRef`), i.e. **references to objects in the storage**. Every\n", 171 | "`Ref` contains two metadata fields:\n", 172 | "\n", 173 | "- `cid`: a hash of the **content** of the object\n", 174 | "- `hid`: a hash of the **computational history** of the object, which is the precise\n", 175 | "composition of `@op`s that created this ref. \n", 176 | "\n", 177 | "Two `Ref`s with the same `cid` may have different `hid`s, and `hid` is the\n", 178 | "unique identifier of `Ref`s in the storage. However, only 1 copy per unique\n", 179 | "`cid` is stored to avoid duplication in the storage.\n", 180 | "\n", 181 | "### `Ref`s can be in memory or not\n", 182 | "Additionally, `Ref`s have the `in_memory` property, which indicates if the\n", 183 | "underlying object is present in the `Ref` or if this is a \"lazy\" `Ref` which\n", 184 | "only contains metadata. **`Ref`s are only loaded in memory when needed for a new\n", 185 | "call to an `@op`**. 
For example, re-running the last code block:" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 5, 191 | "metadata": { 192 | "execution": { 193 | "iopub.execute_input": "2024-07-11T14:31:38.032711Z", 194 | "iopub.status.busy": "2024-07-11T14:31:38.032185Z", 195 | "iopub.status.idle": "2024-07-11T14:31:38.077234Z", 196 | "shell.execute_reply": "2024-07-11T14:31:38.076082Z" 197 | } 198 | }, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "AtomRef(hid=168..., in_memory=False)\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "with storage: \n", 210 | " s = sum_args(6, 7, 8, 9, c=11,)\n", 211 | " print(s)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "To get the object wrapped by a `Ref`, call `storage.unwrap`:" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 6, 224 | "metadata": { 225 | "execution": { 226 | "iopub.execute_input": "2024-07-11T14:31:38.081995Z", 227 | "iopub.status.busy": "2024-07-11T14:31:38.081258Z", 228 | "iopub.status.idle": "2024-07-11T14:31:38.115821Z", 229 | "shell.execute_reply": "2024-07-11T14:31:38.114805Z" 230 | } 231 | }, 232 | "outputs": [ 233 | { 234 | "data": { 235 | "text/plain": [ 236 | "42" 237 | ] 238 | }, 239 | "execution_count": 6, 240 | "metadata": {}, 241 | "output_type": "execute_result" 242 | } 243 | ], 244 | "source": [ 245 | "storage.unwrap(s) # loads from storage only if necessary" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": {}, 251 | "source": [ 252 | "### Other useful `Storage` methods\n", 253 | "\n", 254 | "- `Storage.attach(inplace: bool)`: like `unwrap`, but puts the objects in the\n", 255 | "`Ref`s if they are not in-memory.\n", 256 | "- `Storage.load_ref(hid: str, in_memory: bool)`: load a `Ref` by its history ID,\n", 257 | "optionally also loading the underlying object." 
258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 7, 263 | "metadata": { 264 | "execution": { 265 | "iopub.execute_input": "2024-07-11T14:31:38.120583Z", 266 | "iopub.status.busy": "2024-07-11T14:31:38.119825Z", 267 | "iopub.status.idle": "2024-07-11T14:31:38.150984Z", 268 | "shell.execute_reply": "2024-07-11T14:31:38.150207Z" 269 | } 270 | }, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "AtomRef(42, hid=168...)\n", 277 | "AtomRef(42, hid=168...)\n" 278 | ] 279 | } 280 | ], 281 | "source": [ 282 | "print(storage.attach(obj=s, inplace=False))\n", 283 | "print(storage.load_ref(s.hid))" 284 | ] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": {}, 289 | "source": [ 290 | "## Working with `Call` objects\n", 291 | "Besides `Ref`s, the other kind of object in the storage is the `Call`, which\n", 292 | "stores references to the inputs and outputs of a call to an `@op`, together with\n", 293 | "metadata that mirrors the `Ref` metadata:\n", 294 | "\n", 295 | "- `Call.cid`: a content ID for the call, based on the `@op`'s identity, its\n", 296 | "version at the time of the call, and the `cid`s of the inputs\n", 297 | "- `Call.hid`: a history ID for the call, the same as `Call.cid`, but using the \n", 298 | "`hid`s of the inputs.\n", 299 | "\n", 300 | "**For every `Ref` history ID, there's at most one `Call` that has an output with\n", 301 | "this history ID**, and if it exists, this call can be found by calling\n", 302 | "`storage.get_ref_creator()`: " 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 8, 308 | "metadata": { 309 | "execution": { 310 | "iopub.execute_input": "2024-07-11T14:31:38.154659Z", 311 | "iopub.status.busy": "2024-07-11T14:31:38.154326Z", 312 | "iopub.status.idle": "2024-07-11T14:31:38.205594Z", 313 | "shell.execute_reply": "2024-07-11T14:31:38.203966Z" 314 | } 315 | }, 316 | "outputs": [ 317 | { 318 | "name": "stdout", 319 | 
"output_type": "stream", 320 | "text": [ 321 | "Call(sum_args, hid=f99...)\n" 322 | ] 323 | }, 324 | { 325 | "data": { 326 | "text/plain": [ 327 | "{'a': AtomRef(hid=c6a..., in_memory=False),\n", 328 | " 'args_0': AtomRef(hid=e0f..., in_memory=False),\n", 329 | " 'args_1': AtomRef(hid=479..., in_memory=False),\n", 330 | " 'args_2': AtomRef(hid=c37..., in_memory=False),\n", 331 | " 'b': AtomRef(hid=610..., in_memory=False),\n", 332 | " 'c': AtomRef(hid=a33..., in_memory=False)}" 333 | ] 334 | }, 335 | "metadata": {}, 336 | "output_type": "display_data" 337 | }, 338 | { 339 | "data": { 340 | "text/plain": [ 341 | "{'output_0': AtomRef(hid=168..., in_memory=False)}" 342 | ] 343 | }, 344 | "metadata": {}, 345 | "output_type": "display_data" 346 | } 347 | ], 348 | "source": [ 349 | "call = storage.get_ref_creator(ref=s)\n", 350 | "print(call)\n", 351 | "display(call.inputs)\n", 352 | "display(call.outputs)" 353 | ] 354 | } 355 | ], 356 | "metadata": { 357 | "language_info": { 358 | "codemirror_mode": { 359 | "name": "ipython", 360 | "version": 3 361 | }, 362 | "file_extension": ".py", 363 | "mimetype": "text/x-python", 364 | "name": "python", 365 | "nbconvert_exporter": "python", 366 | "pygments_lexer": "ipython3", 367 | "version": "3.10.8" 368 | } 369 | }, 370 | "nbformat": 4, 371 | "nbformat_minor": 2 372 | } 373 | --------------------------------------------------------------------------------