├── assets
│   ├── .gitignore
│   ├── favicon.ico
│   ├── logo-192x192.png
│   ├── logo-512x512.png
│   └── logo-no-background.png
├── mandala
│   ├── __init__.py
│   ├── deps
│   │   ├── tracers
│   │   │   ├── __init__.py
│   │   │   ├── tracer_base.py
│   │   │   └── sys_impl.py
│   │   ├── crawler.py
│   │   ├── viz.py
│   │   ├── deep_versions.py
│   │   ├── utils.py
│   │   └── model.py
│   ├── imports.py
│   ├── config.py
│   ├── common_imports.py
│   ├── tests
│   │   ├── test_cfs.py
│   │   ├── test_configs.py
│   │   ├── test_memoization.py
│   │   └── test_versioning.py
│   └── tps.py
├── runtime.txt
├── requirements.txt
├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── deploy-docs.yml
├── docs_source
│   ├── readme.md
│   ├── make_docs.py
│   └── topics
│       └── 01_storage_and_ops.ipynb
├── c.py
├── docs
│   └── docs
│       ├── stylesheets
│       │   └── extra.css
│       ├── index.md
│       ├── topics
│       │   ├── 03_cf_files
│       │   │   ├── 03_cf_6_1.svg
│       │   │   ├── 03_cf_20_1.svg
│       │   │   ├── 03_cf_22_0.svg
│       │   │   ├── 03_cf_28_0.svg
│       │   │   ├── 03_cf_14_1.svg
│       │   │   └── 03_cf_29_0.svg
│       │   ├── 06_advanced_cf.md
│       │   ├── 05_collections.md
│       │   ├── 01_storage_and_ops.md
│       │   ├── 05_collections_files
│       │   │   └── 05_collections_5_0.svg
│       │   └── 02_retracing.md
│       ├── tutorials
│       │   ├── 01_hello_files
│       │   │   ├── 01_hello_5_1.svg
│       │   │   └── 01_hello_5_3.svg
│       │   └── 02_ml_files
│       │       ├── 02_ml_22_3.svg
│       │       └── 02_ml_22_1.svg
│       └── blog
│           └── 01_cf_files
│               ├── 01_cf_2_0.svg
│               └── 01_cf_21_0.svg
├── console.py
├── .coveragerc
├── mkdocs.yml
├── .gitignore
├── setup.py
└── LICENSE
/assets/.gitignore:
--------------------------------------------------------------------------------
1 | !*.png
--------------------------------------------------------------------------------
/mandala/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/runtime.txt:
--------------------------------------------------------------------------------
1 | python-3.10
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # for binder
2 | numpy >= 1.18
3 | pandas >= 1.0
4 | joblib >= 1.0
--------------------------------------------------------------------------------
/assets/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amakelov/mandala/HEAD/assets/favicon.ico
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: [amakelov]
4 |
--------------------------------------------------------------------------------
/assets/logo-192x192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amakelov/mandala/HEAD/assets/logo-192x192.png
--------------------------------------------------------------------------------
/assets/logo-512x512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amakelov/mandala/HEAD/assets/logo-512x512.png
--------------------------------------------------------------------------------
/assets/logo-no-background.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amakelov/mandala/HEAD/assets/logo-no-background.png
--------------------------------------------------------------------------------
/mandala/deps/tracers/__init__.py:
--------------------------------------------------------------------------------
1 | from .tracer_base import TracerABC
2 | from .dec_impl import DecTracer
3 | from .sys_impl import SysTracer
4 |
--------------------------------------------------------------------------------
/docs_source/readme.md:
--------------------------------------------------------------------------------
1 | To convert these notebooks to markdown docs, run them in the `jupyter notebook`
2 | server, save them, and run
3 | ```bash
4 | python make_docs.py
5 | ```
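6 | 
7 | To rebuild only specific notebooks, `make_docs.py` also accepts a
8 | `--filenames` argument (paths relative to this directory), e.g.
9 | ```bash
10 | python make_docs.py --filenames topics/01_storage_and_ops.ipynb
11 | ```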
--------------------------------------------------------------------------------
/c.py:
--------------------------------------------------------------------------------
1 | import subprocess, argparse
2 |
3 |
4 | def get_parser() -> argparse.ArgumentParser:
5 | parser = argparse.ArgumentParser()
6 | return parser
7 |
8 |
9 | if __name__ == "__main__":
10 | parser = get_parser()
11 | args = parser.parse_args()
12 | cmd = ["ipython", "-i", "console.py", "--"]
13 | subprocess.call(cmd)
14 |
--------------------------------------------------------------------------------
/mandala/imports.py:
--------------------------------------------------------------------------------
1 | from .storage import Storage, noop
2 | from .model import op, Ignore, NewArgDefault, wrap_atom, ValuePointer
3 | from .tps import MList, MDict
4 | from .deps.tracers.dec_impl import track
5 |
6 | from .common_imports import sess
7 |
8 |
9 | def pprint_dict(d) -> str:
10 | return '\n'.join([f" {k}: {v}" for k, v in d.items()])
--------------------------------------------------------------------------------
/.github/workflows/deploy-docs.yml:
--------------------------------------------------------------------------------
1 | name: Deploy MkDocs
2 | on:
3 | push:
4 | branches:
5 | - master
6 | jobs:
7 | deploy:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@v2
11 | - uses: actions/setup-python@v2
12 | with:
13 | python-version: 3.x
14 | - run: pip install mkdocs-material
15 | - run: mkdocs gh-deploy --force
--------------------------------------------------------------------------------
/docs/docs/stylesheets/extra.css:
--------------------------------------------------------------------------------
1 | :root {
2 | --md-primary-fg-color: #073642;
3 | --md-accent-fg-color: #268bd2;
4 |
5 | --md-default-bg-color: #fdf6e3;
6 | --md-default-fg-color: #657b83;
7 | --md-default-fg-color--light: #93a1a1;
8 | --md-default-fg-color--lighter: #eee8d5;
9 |
10 | --md-code-bg-color: #eee8d5;
11 | --md-code-fg-color: #657b83;
12 | }
13 |
14 | /*
15 | a {
16 | text-decoration: underline;
17 | }
18 | */
--------------------------------------------------------------------------------
/console.py:
--------------------------------------------------------------------------------
1 | # useful builtins
2 | import sys
3 | import argparse
4 |
5 |
6 | def get_parser() -> argparse.ArgumentParser:
7 | """
8 | Copy of `get_parser` in c.py
9 | """
10 | parser = argparse.ArgumentParser()
11 | return parser
12 |
13 |
14 | parser = get_parser()
15 | args = parser.parse_args()
16 |
17 | if __name__ == "__main__":
18 | from mandala.all import *
19 | from mandala.tests.test_stateful_slow import *
20 |
21 | # setup_logging(level='info')
22 |
--------------------------------------------------------------------------------
/docs/docs/index.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | This is the documentation for [mandala](https://github.com/amakelov/mandala), a
3 | simple & elegant experiment tracking framework for Python.
4 |
5 | Most methods in `mandala` are provided by the `Storage` and `ComputationFrame`
6 | classes. In general, you'll probably find yourself interacting with only 5-10
7 | methods on a regular basis, and their docstrings provide detailed explanations.
8 | 
9 | To complement the docstrings, this documentation contains a few short
10 | walkthroughs illustrating how these methods are used.
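11 | 
12 | As a minimal sketch of what this looks like (assuming the `mandala.imports`
13 | API used throughout these docs):
14 | 
15 | ```python
16 | from mandala.imports import Storage, op
17 | 
18 | storage = Storage()  # in-memory storage
19 | 
20 | @op
21 | def inc(x: int) -> int:
22 |     return x + 1
23 | 
24 | with storage:  # calls inside this block are memoized
25 |     y = inc(41)
26 | 
27 | print(storage.unwrap(y))  # 42
28 | ```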
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | # .coveragerc to control coverage.py
2 | [run]
3 | branch = True
4 | omit =
5 | mandala_lite/tests/*.py
6 | mandala_lite/demos/*
7 |
8 | [report]
9 | # Regexes for lines to exclude from consideration
10 | exclude_lines =
11 | # Have to re-enable the standard pragma
12 | pragma: no cover
13 |
14 | # Don't complain if tests don't hit defensive assertion code:
15 | raise AssertionError
16 | raise NotImplementedError
17 | # Don't complain if you don't run into internal errors
18 | raise InternalError
19 |
20 | # Don't complain about abstract methods, they aren't run:
21 | @(abc\.)?abstractmethod
22 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Mandala Documentation
2 |
3 | plugins:
4 | - search
5 |
6 | theme:
7 | name: material
8 | font:
9 | text: Roboto
10 | code: Roboto Mono
11 | features:
12 | - content.code.copy # Adds a copy button to code blocks
13 | palette:
14 | - media: "(prefers-color-scheme: light)"
15 | scheme: default
16 | primary: indigo
17 | accent: indigo
18 | toggle:
19 | icon: material/brightness-7
20 | name: Switch to dark mode
21 | - media: "(prefers-color-scheme: dark)"
22 | scheme: slate
23 | primary: indigo
24 | accent: indigo
25 | toggle:
26 | icon: material/brightness-4
27 | name: Switch to light mode
28 |
29 | markdown_extensions:
30 | - pymdownx.highlight:
31 | anchor_linenums: true
32 | line_spans: __span
33 | pygments_lang_class: true
34 | - pymdownx.inlinehilite
35 | - pymdownx.snippets
36 | - pymdownx.superfences
37 |
38 | docs_dir: 'docs/docs'
39 |
40 | extra_css:
41 | - stylesheets/extra.css
42 |
--------------------------------------------------------------------------------
/docs/docs/topics/03_cf_files/03_cf_6_1.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amakelov/mandala/HEAD/docs/docs/topics/03_cf_files/03_cf_6_1.svg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # python
2 | *.pyc
3 |
4 | # logs
5 | *.log
6 |
7 | # pictures and other visualization-related
8 | *.jpg
9 | *.JPG
10 | *.png
11 | *.tif
12 | *.gif
13 | # *.svg
14 | *.dot
15 | *.gv
16 | # *.mp4
17 |
18 | # docs
19 | *.pdf
20 | *.djvu
21 | *.ps
22 |
23 | # vim
24 | *.swp
25 | *.swo
26 | *.swn
27 |
28 | # archives
29 | *.7z
30 | *.dmg
31 | *.gz
32 | *.iso
33 | *.jar
34 | *.rar
35 | *.tar
36 | *.zip
37 |
38 | # latex
39 | *.aux
40 | *.fdb_latexmk
41 | *.fls
42 | *.log
43 | *.out
44 |
45 | # web
46 | *.html
47 | # *.css
48 |
49 | # python objects
50 | *.pkl
51 | *.joblib
52 |
53 | ################################################################################
54 | ### path-based
55 | ################################################################################
56 | # eggs
57 | mandala.egg-info/*
58 | pymandala.egg-info/*
59 |
60 | scratchpad/
61 |
62 | # visualizations
63 | # Ignore files cached by Hypothesis
64 | **/.hypothesis/*
65 | # Ignore files cached by vscode
66 | **/.vscode/*
67 | # ignore files cached by pytest
68 | **/.pytest_cache/*
69 | # ignore files cached by jupyter
70 | **/.ipynb_checkpoints
71 | # ignore coverage things
72 | .coverage
73 | **/htmlcov/*
74 | # ignore persistent storages
75 | temp_dbs/*
76 | mandala/tests/output/*
77 | *.db
78 | *.parquet
79 |
80 | # ignore docs build
81 | site/
82 |
83 | # ignore PyPi build
84 | dist/
85 |
86 | # Ignore previous version
87 | mandala/_prev/
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 |
4 | install_requires = [
5 | "numpy >= 1.18",
6 | "pandas >= 1.0",
7 | "joblib >= 1.0",
8 | ]
9 |
10 | extras_require = {
11 | "base": ["prettytable", "graphviz"],
12 | "ui": [
13 | "rich",
14 | ],
15 | "test": [
16 | "pytest >= 6.0.0",
17 | "hypothesis >= 6.0.0",
18 | "ipython",
19 | ],
20 | "demos": [
21 | "scikit-learn",
22 | ],
23 | }
24 |
25 |
26 | extras_require["complete"] = sorted({v for req in extras_require.values() for v in req})
27 |
28 | packages = [
29 | "mandala",
30 | "mandala.deps",
31 | "mandala.deps.tracers",
32 | "mandala.tests",
33 | ]
34 |
35 | setup(
36 | name="pymandala",
37 | version="v0.2.0-alpha",
38 | author="Aleksandar (Alex) Makelov",
39 | author_email="aleksandar.makelov@gmail.com",
40 | description="A powerful and easy to use experiment tracking framework",
41 | long_description=open("README.md").read(),
42 | long_description_content_type="text/markdown",
43 | url="https://github.com/amakelov/mandala",
44 | license="Apache 2.0",
45 | keywords="computational-experiments data-management machine-learning data-science",
46 | classifiers=[
47 | "Development Status :: 3 - Alpha",
48 | "Intended Audience :: Science/Research",
49 | "Operating System :: OS Independent",
50 | "License :: OSI Approved :: Apache Software License",
51 | "Programming Language :: Python :: 3",
52 | "Programming Language :: Python :: 3.10",
53 | "Programming Language :: Python :: 3.9",
54 | "Programming Language :: Python :: 3.8",
55 | ],
56 | packages=packages,
57 | python_requires=">=3.8",
58 | install_requires=install_requires,
59 | extras_require=extras_require,
60 | )
61 |
--------------------------------------------------------------------------------
/mandala/config.py:
--------------------------------------------------------------------------------
1 | from .common_imports import *
2 |
3 | def get_mandala_path() -> Path:
4 | import mandala
5 |
6 | return Path(os.path.dirname(mandala.__file__))
7 |
8 | class Config:
9 | func_interface_cls_name = "Op"
10 | mandala_path = get_mandala_path()
11 | module_name = "mandala"
12 | tests_module_name = "mandala.tests"
13 |
14 | try:
15 | import PIL
16 |
17 | has_pil = True
18 | except ImportError:
19 | has_pil = False
20 |
21 | try:
22 | import torch
23 |
24 | has_torch = True
25 | except ImportError:
26 | has_torch = False
27 |
28 | try:
29 | import rich
30 |
31 | has_rich = True
32 | except ImportError:
33 | has_rich = False
34 |
35 | try:
36 | import prettytable
37 |
38 | has_prettytable = True
39 | except ImportError:
40 | has_prettytable = False
41 |
42 |
43 | if Config.has_torch:
44 | import torch
45 |
46 | def tensor_to_numpy(obj: Union[torch.Tensor, dict, list, tuple, Any]) -> Any:
47 | """
48 | Recursively convert PyTorch tensors in a data structure to numpy arrays.
49 |
50 | Parameters
51 | ----------
52 | obj : any
53 | The input data structure.
54 |
55 | Returns
56 | -------
57 | any
58 | The data structure with tensors converted to numpy arrays.
59 | """
60 | if isinstance(obj, torch.Tensor):
61 | return obj.detach().cpu().numpy()
62 | elif isinstance(obj, dict):
63 | return {k: tensor_to_numpy(v) for k, v in obj.items()}
64 | elif isinstance(obj, list):
65 | return [tensor_to_numpy(v) for v in obj]
66 | elif isinstance(obj, tuple):
67 | return tuple(tensor_to_numpy(v) for v in obj)
68 | else:
69 | return obj
70 |
--------------------------------------------------------------------------------
/docs/docs/topics/03_cf_files/03_cf_20_1.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amakelov/mandala/HEAD/docs/docs/topics/03_cf_files/03_cf_20_1.svg
--------------------------------------------------------------------------------
/mandala/common_imports.py:
--------------------------------------------------------------------------------
1 | import time
2 | import traceback
3 | import random
4 | import logging
5 | import itertools
6 | import copy
7 | import hashlib
8 | import io
9 | import os
10 | import shutil
11 | import sys
12 | import joblib
13 | import inspect
14 | import binascii
15 | import asyncio
16 | import ast
17 | import types
18 | import tempfile
19 | from collections import defaultdict, OrderedDict
20 | from typing import (
21 | Any,
22 | Dict,
23 | List,
24 | Callable,
25 | Tuple,
26 | Iterable,
27 | Optional,
28 | Set,
29 | Union,
30 | TypeVar,
31 | Literal,
32 | )
33 | from pathlib import Path
34 |
35 | import pandas as pd
36 | import pyarrow as pa
37 | import numpy as np
38 |
39 | try:
40 | import rich
41 |
42 | has_rich = True
43 | except ImportError:
44 | has_rich = False
45 |
46 | if has_rich:
47 | from rich.logging import RichHandler
48 |
49 | logger = logging.getLogger("mandala")
50 | logging_handler = RichHandler(enable_link_path=False)
51 | FORMAT = "%(message)s"
52 | logging.basicConfig(
53 | level="INFO", format=FORMAT, datefmt="[%X]", handlers=[logging_handler]
54 | )
55 | else:
56 | logger = logging.getLogger("mandala")
57 | # logger.addHandler(logging.StreamHandler())
58 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
59 | logging.basicConfig(format=FORMAT)
60 | logger.setLevel(logging.INFO)
61 |
62 | class Session:
63 | # for debugging
64 |
65 | def __init__(self):
66 | self.items = []
67 | self._scope = None
68 |
69 | def d(self):
70 | scope = inspect.currentframe().f_back.f_locals
71 | self._scope = scope
72 |
73 | def dump(self):
74 | # put the scope into the current locals
75 | assert self._scope is not None
76 | scope = inspect.currentframe().f_back.f_locals
77 | print(f"Dumping {self._scope.keys()} into local scope")
78 | scope.update(self._scope)
79 |
80 | sess = Session()
--------------------------------------------------------------------------------
/mandala/tests/test_cfs.py:
--------------------------------------------------------------------------------
1 | from mandala.imports import *
2 |
3 |
4 | def test_single_func():
5 | storage = Storage()
6 |
7 | @op
8 | def inc(x: int) -> int:
9 | return x + 1
10 |
11 | with storage:
12 | for i in range(10):
13 | inc(i)
14 |
15 | cf = storage.cf(inc)
16 | df = cf.df()
17 | assert df.shape == (10, 3)
18 | assert (df['var_0'] == df['x'] + 1).all()
19 |
20 |
21 | def test_composition():
22 | storage = Storage()
23 |
24 | @op(output_names=['y'])
25 | def inc(x):
26 | return x + 1
27 |
28 | @op(output_names=['z'])
29 | def add(x, y):
30 | return x + y
31 |
32 | with storage:
33 | for x in range(5):
34 | y = inc(x)
35 | if x % 2 == 0:
36 | z = add(x, y)
37 |
38 | cf = storage.cf(add).expand_all()
39 | df = cf.df()
40 | assert df.shape[0] == 3
41 | assert (df['z'] == df['x'] + df['y']).all()
42 |
43 | cf = storage.cf(inc).expand_all()
44 | df = cf.df()
45 | assert df.shape[0] == 5
46 | assert (df['y'] == df['x'] + 1).all()
47 | assert (df['z'] == df['x'] + df['y'])[df['z'].notnull()].all()
48 |
49 | def test_merge():
50 | storage = Storage()
51 |
52 | @op(output_names=['y'])
53 | def inc(x):
54 | return x + 1
55 |
56 | @op(output_names=['z'])
57 | def add(x, y):
58 | return x + y
59 |
60 | @op(output_names=['w'])
61 | def mul(x, y):
62 | return x * y
63 |
64 | @op(output_names=['v'])
65 | def final(t):
66 | return t**2
67 |
68 | with storage:
69 | for x in range(10):
70 | y = inc(x)
71 | if x < 5:
72 | z = add(x, y)
73 | v = final(z)
74 | else:
75 | w = mul(x, y)
76 | v = final(w)
77 |
78 |
79 | cf = storage.cf(final).expand_all().merge_vars()
80 | df = cf.df()
81 | assert df.shape[0] == 10
--------------------------------------------------------------------------------
/docs_source/make_docs.py:
--------------------------------------------------------------------------------
1 | """
2 | convert .ipynb files in this directory to .md files and move them to the docs
3 | directory for mkdocs to use
4 | """
5 |
6 | import os
7 | import argparse
8 |
9 | # parse the command line arguments
10 | parser = argparse.ArgumentParser(description='Convert .ipynb files to .md files')
11 | parser.add_argument('--filenames', type=str, nargs='+', help='list of filenames to convert')
12 | args = parser.parse_args()
13 |
14 |
15 | if __name__ == '__main__':
16 | if args.filenames:
17 | ipynb_files = args.filenames
18 |         # prepend "./" to filenames that don't already start with it
19 |         ipynb_files = [f if f.startswith('./') else './' + f for f in ipynb_files]
20 | else:
21 | # find all .ipynb files recursively in the current directory and its subdirectories
22 | ipynb_files = []
23 | for root, dirs, files in os.walk('.'):
24 | for f in files:
25 | if f.endswith('.ipynb'):
26 | ipynb_files.append(os.path.join(root, f))
27 |
28 | for f in ipynb_files:
29 | os.system('jupyter nbconvert --to notebook --execute --inplace ' + f)
30 | os.system(f"jupyter nbconvert --to markdown {f}")
31 |
32 | DOCS_REL_PATH = '../docs/docs/'
33 | # now, move the .md files to the docs directory
34 | for f in ipynb_files:
35 | # find the relative directory and the filename
36 | relative_dir = os.path.dirname(f)
37 | fname = os.path.basename(f)
38 | # if the target dir doesn't exist, create it
39 | if not os.path.isdir(DOCS_REL_PATH + relative_dir):
40 | os.system("mkdir -p " + DOCS_REL_PATH + relative_dir)
41 | # move to the DOCS_REL_PATH, under the same directory structure
42 | mv_cmd = "mv " + f.replace('.ipynb', '.md') + " " + DOCS_REL_PATH + relative_dir + '/' + fname.replace('.ipynb', '.md')
43 | print(mv_cmd)
44 | os.system(mv_cmd)
45 |
46 | # also, move any directories named "{fname}_files" to the docs directory
47 | for f in ipynb_files:
48 | files_folder = f.replace('.ipynb', '_files')
49 | if os.path.isdir(files_folder):
50 | # first, remove the directory if it already exists
51 | target_files_path = DOCS_REL_PATH + files_folder
52 | if os.path.isdir(target_files_path):
53 | os.system(f"rm -r {DOCS_REL_PATH}" + files_folder)
54 | # then, move the directory
55 | os.system("mv " + f.replace('.ipynb', '_files') + " " + DOCS_REL_PATH + f.replace('.ipynb', '_files'))
--------------------------------------------------------------------------------
/docs/docs/topics/06_advanced_cf.md:
--------------------------------------------------------------------------------
1 | # Advanced `ComputationFrame` tools
2 |
3 |
4 |
5 | This section of the documentation contains some more advanced `ComputationFrame`
6 | topics.
7 |
8 | ## Set-like and graph-like operations on `ComputationFrame`s
9 | A CF is like a "graph of sets", where the elements of each set are either `Ref`s
10 | (for variables) or `Call`s (for functions). As such, it supports both natural
11 | set-like operations applied node/edge-wise, and natural operations using the
12 | graph's connectivity:
13 |
14 | - **union**: given by `cf_1 | cf_2`, this takes the union of the two computation
15 | graphs (merging nodes/edges with the same names), and in case of a merge, the
16 | resulting set at the node is the union of the two sets of `Ref`s or `Call`s.
17 | - **intersection**: given by `cf_1 & cf_2`, this takes the intersection of the
18 | two computation graphs (leaving only nodes/edges with the same name in both),
19 | and the set at each node is the intersection of the two corresponding sets.
20 | - **`.downstream(varnames)`**: restrict the CF to computations that are
21 | downstream of the `Ref`s in chosen variables
22 | - **`.upstream(varnames)`**: dual to `downstream` (see the sketch after the example below)
23 |
24 | Consider the following example:
25 |
26 |
27 | ```python
28 | # for Google Colab
29 | try:
30 | import google.colab
31 | !pip install git+https://github.com/amakelov/mandala
32 | except:
33 | pass
34 | ```
35 |
36 |
37 | ```python
38 | from mandala.imports import *
39 | storage = Storage()
40 |
41 | @op
42 | def inc(x): return x + 1
43 |
44 | @op
45 | def add(y, z): return y + z
46 |
47 | @op
48 | def square(w): return w ** 2
49 |
50 | @op
51 | def divmod_(u, v): return divmod(u, v)
52 |
53 | with storage:
54 | xs = [inc(i) for i in range(5)]
55 | ys = [add(x, z=42) for x in xs] + [square(x) for x in range(5, 10)]
56 | zs = [divmod_(x, y) for x, y in zip(xs, ys[3:8])]
57 | ```
58 |
59 | We have a "middle layer" in the computation that uses both `add` and `square`.
60 | We can get a shared view of the entire computation by taking the union of the
61 | expanded CFs for these two ops:
62 |
63 |
64 | ```python
65 | cf = (storage.cf(add) | storage.cf(square)).expand_all()
66 | cf.draw(verbose=True)
67 | ```
68 |
69 |
70 |
71 | 
72 |
73 |
74 |
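75 | The expanded CF also supports the graph-like restriction operators from the
76 | list above. A minimal sketch (the variable name `var_0` is illustrative; use
77 | `cf.draw(verbose=True)` to see the actual variable names in your CF):
78 | 
79 | ```python
80 | # keep only computations downstream of the chosen variable
81 | downstream_cf = cf.downstream('var_0')
82 | # dually, keep only computations upstream of it
83 | upstream_cf = cf.upstream('var_0')
84 | ```
85 | 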
86 | ## Selection
87 | TODO
88 | 
--------------------------------------------------------------------------------
/mandala/tests/test_configs.py:
--------------------------------------------------------------------------------
1 | from mandala.imports import *
2 |
3 |
4 | def test_nesting():
5 | storage = Storage()
6 | with storage(mode='noop'):
7 | assert storage.mode == 'noop'
8 | with storage():
9 | assert storage.mode == 'run'
10 | assert storage.mode == 'noop'
11 | with storage(mode='noop'):
12 | assert storage.mode == 'noop'
13 | assert storage.mode == 'noop'
14 | assert storage.mode == 'run'
15 |
16 |
17 | def test_noop_simple():
18 | @op
19 | def inc(x: int, *args, y: int = NewArgDefault(2), z: int = 23) -> int:
20 | return x + y + z + sum(args)
21 |
22 | storage = Storage()
23 |
24 | with storage(mode='noop'):
25 | # test various wrapped values
26 | res = inc(ValuePointer(id='one', obj=1), 1, 1, z=1)
27 | assert res == 6
28 |
29 |
30 | def test_noop_composition():
31 |
32 | @op
33 | def inc(x: int) -> int:
34 | return x + 1
35 |
36 | @op
37 | def add(x: int, y: int) -> int:
38 | return x + y
39 |
40 | storage = Storage()
41 |
42 |
43 | # test that wrapped values are unwrapped
44 | with storage:
45 | x = inc(20)
46 |
47 | with storage:
48 | x = inc(20)
49 | with storage(mode='noop'):
50 | z = add(x, 21)
51 | assert z == 42
52 |
53 |
54 | def test_noop_standalone():
55 | storage = Storage()
56 |
57 | @op
58 | def inc(x: int) -> int:
59 | return x + 1
60 |
61 | with storage:
62 | x = inc(20)
63 | assert storage.mode == 'run'
64 | with noop():
65 | assert storage.mode == 'noop'
66 | y = inc(x)
67 | z = inc(20)
68 | assert storage.mode == 'run'
69 | assert y == 22
70 | assert z == 21
71 |
72 | # now, test it without a context
73 | with noop():
74 | x = inc(20)
75 | assert x == 21
76 |
77 |
78 | def test_no_new_calls():
79 | @op
80 | def inc(x: int) -> int:
81 | return x + 1
82 |
83 | storage = Storage()
84 | with storage:
85 | inc(20)
86 |
87 | storage.allow_new_calls(False)
88 |
89 | # memoized calls should still work
90 | with storage:
91 | inc(20)
92 |
93 | try:
94 | with storage:
95 | inc(21)
96 | except RuntimeError as e:
97 | assert str(e) == "Call to inc does not exist and new calls are not allowed."
98 | except Exception as e:
99 | raise e
100 | finally:
101 | storage.allow_new_calls(True)
102 |
103 |
--------------------------------------------------------------------------------
/mandala/deps/tracers/tracer_base.py:
--------------------------------------------------------------------------------
1 | from ...common_imports import *
2 | from ...config import Config
3 | import importlib
4 | from ..model import DependencyGraph, CallableNode
5 | from abc import ABC, abstractmethod
6 |
7 |
8 | class TracerABC(ABC):
9 | def __init__(
10 | self,
11 | paths: List[Path],
12 | strict: bool = True,
13 | allow_methods: bool = False,
14 | track_globals: bool = True,
15 | ):
16 | self.call_stack: List[Optional[CallableNode]] = []
17 | self.graph = DependencyGraph()
18 | self.paths = paths
19 | self.strict = strict
20 | self.allow_methods = allow_methods
21 | self.track_globals = track_globals
22 |
23 | @abstractmethod
24 | def __enter__(self):
25 | raise NotImplementedError
26 |
27 | @abstractmethod
28 | def __exit__(self, exc_type, exc_val, exc_tb):
29 | raise NotImplementedError
30 |
31 | @staticmethod
32 | @abstractmethod
33 | def get_active_trace_obj() -> Optional[Any]:
34 | raise NotImplementedError
35 |
36 | @staticmethod
37 | @abstractmethod
38 | def set_active_trace_obj(trace_obj: Any):
39 | raise NotImplementedError
40 |
41 | @staticmethod
42 | @abstractmethod
43 | def register_leaf_event(trace_obj: Any, data: Any):
44 | raise NotImplementedError
45 |
46 |
47 | BREAK = "break" # stop tracing (currently doesn't really work b/c python)
48 | CONTINUE = "continue" # continue tracing, but don't add call to dependencies
49 | KEEP = "keep" # continue tracing and add call to dependencies
50 | MAIN = "__main__"
51 |
52 |
53 | def get_closure_names(code_obj: types.CodeType, func_qualname: str) -> Tuple[str]:
54 | closure_vars = code_obj.co_freevars
55 | if "." in func_qualname and "__class__" in closure_vars:
56 | closure_vars = tuple([var for var in closure_vars if var != "__class__"])
57 | return closure_vars
58 |
59 |
60 | def get_module_flow(module_name: Optional[str], paths: List[Path]) -> str:
61 | if module_name is None:
62 | return BREAK
63 | if module_name == MAIN:
64 | return KEEP
65 | try:
66 | module = importlib.import_module(module_name)
67 | is_importable = True
68 | except ModuleNotFoundError:
69 | is_importable = False
70 | if not is_importable:
71 | return BREAK
72 | try:
73 | module_path = Path(inspect.getfile(module))
74 | except TypeError:
75 | # this happens when the module is a built-in module
76 | return BREAK
77 | if (
78 | not any(root in module_path.parents for root in paths)
79 | and module_path not in paths
80 | ):
81 | # module is not in the paths we're inspecting; stop tracing
82 | logger.debug(f" Module {module_name} not in paths, BREAK")
83 | return BREAK
84 | elif module_name.startswith(Config.module_name) and not module_name.startswith(
85 | Config.tests_module_name
86 | ):
87 | # this function is part of `mandala` functionality. Continue tracing
88 | # but don't add it to the dependency state
89 | logger.debug(f" Module {module_name} is mandala, CONTINUE")
90 | return CONTINUE
91 | else:
92 | logger.debug(f" Module {module_name} is not mandala but in paths, KEEP")
93 | return KEEP
94 |
--------------------------------------------------------------------------------
/docs/docs/topics/05_collections.md:
--------------------------------------------------------------------------------
1 | # Natively handling Python collections
2 |
3 |
4 |
5 | A key benefit of `mandala` over straightforward memoization is that it can make
6 | Python collections (lists, dicts, ...) a native & transparent part of the
7 | memoization process:
8 |
9 | - `@op`s can return collections where each item is a separate `Ref`, so that
10 | later `@op` calls can work with individual elements;
11 | - `@op`s can also accept as input collections where each item is a separate
12 |   `Ref`, to e.g. implement aggregation operations over them;
13 | - collections can reuse the storage of their items: if two collections share
14 |   some elements, each shared element is stored only once in storage;
15 | - the relationship between a collection and each of its items is a native part
16 | of the computational graph of `@op` calls, and can be propagated automatically
17 | by `ComputationFrame`s. Indeed, **collections are implemented as `@op`s
18 | internally**.
19 |
20 | ## Input/output collections must be explicitly annotated
21 | By default, any collection passed as `@op` input or output will be stored as a
22 | single `Ref` with no structure; the object is **opaque** to the `Storage`
23 | instance. To make the collection **transparent** to the `Storage`, you must
24 | override this behavior explicitly by using a custom type annotation, such as
25 | `MList` for lists, `MDict` for dicts, ...:
26 |
27 |
28 | ```python
29 | # for Google Colab
30 | try:
31 | import google.colab
32 | !pip install git+https://github.com/amakelov/mandala
33 | except:
34 | pass
35 | ```
36 |
37 |
38 | ```python
39 | from mandala.imports import Storage, op, MList
40 |
41 | storage = Storage()
42 |
43 | @op
44 | def average(nums: MList[float]) -> float:
45 | return sum(nums) / len(nums)
46 |
47 | with storage:
48 | a = average([1, 2, 3, 4, 5])
49 | ```
50 |
51 | We can understand how the list was made transparent to the storage by inspecting
52 | the computation frame:
53 |
54 |
55 | ```python
56 | cf = storage.cf(average).expand_all(); cf.draw(verbose=True, orientation='LR')
57 | ```
58 |
59 |
60 |
61 | 
62 |
63 |
64 |
65 | We see that the internal `__make_list__` operation was automatically applied to
66 | create a list, which is then the `Ref` passed to `average`.
67 |
68 | ## How collections interact with `ComputationFrame`s
69 | In general, CFs are turned into dataframes that capture the joint history of the
70 | final `Ref`s in the CF. When there are collection `@op`s in the CF, a single
71 | `Ref` (such as the list `nums` above) can depend on multiple `Ref`s in
72 | another variable (such as the `Ref`s in the `elts` variable).
73 |
74 | We can observe this by taking the dataframe of the above CF:
75 |
76 |
77 | ```python
78 | print(cf.df(values='objs').to_markdown())
79 | ```
80 |
81 | | | elts | __make_list__ | nums | average | var_0 |
82 | |---:|:---------------------------------|:--------------------------------|:----------------|:--------------------------|--------:|
83 | | 0 | ValueCollection([2, 4, 1, 3, 5]) | Call(__make_list__, hid=172...) | [1, 2, 3, 4, 5] | Call(average, hid=38e...) | 3 |
84 |
85 |
86 | There's only a single row, but in the `elts` column we see a `ValueCollection`
87 | object, indicating that there are multiple `Ref`s in `elts` that are
88 | dependencies of `var_0`.
89 |
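90 | Dicts work analogously via the `MDict` annotation. A minimal sketch (the op
91 | name `dict_mean` is illustrative; it assumes the same `storage` as above):
92 | 
93 | ```python
94 | from mandala.imports import MDict
95 | 
96 | @op
97 | def dict_mean(scores: MDict[str, float]) -> float:
98 |     # each value in `scores` becomes a separate `Ref` in storage
99 |     return sum(scores.values()) / len(scores)
100 | 
101 | with storage:
102 |     m = dict_mean({'a': 1.0, 'b': 2.0})
103 | ```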
--------------------------------------------------------------------------------
/docs/docs/tutorials/01_hello_files/01_hello_5_1.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amakelov/mandala/HEAD/docs/docs/tutorials/01_hello_files/01_hello_5_1.svg
--------------------------------------------------------------------------------
/docs/docs/topics/03_cf_files/03_cf_22_0.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amakelov/mandala/HEAD/docs/docs/topics/03_cf_files/03_cf_22_0.svg
--------------------------------------------------------------------------------
/mandala/deps/crawler.py:
--------------------------------------------------------------------------------
1 | import types
2 | from ..common_imports import *
3 | from ..utils import unwrap_decorators
4 | import importlib
5 | from .model import (
6 | DepKey,
7 | CallableNode,
8 | )
9 | from .utils import (
10 | is_callable_obj,
11 | extract_func_obj,
12 | unknown_function,
13 | )
14 |
15 |
16 | def crawl_obj(
17 | obj: Any,
18 | module_name: str,
19 | include_methods: bool,
20 | result: Dict[DepKey, CallableNode],
21 | strict: bool,
22 | objs_result: Dict[DepKey, Callable],
23 | ):
24 | """
25 | Find functions and optionally methods native to the module of this object.
26 | """
27 | if is_callable_obj(obj=obj, strict=strict):
28 | if isinstance(unwrap_decorators(obj, strict=False), types.BuiltinFunctionType):
29 | return
30 | v = extract_func_obj(obj=obj, strict=strict)
31 | if v is not unknown_function and v.__module__ != module_name:
32 | # exclude non-local functions
33 | return
34 | dep_key = (module_name, v.__qualname__)
35 | node = CallableNode.from_obj(obj=v, dep_key=dep_key)
36 | result[dep_key] = node
37 | objs_result[dep_key] = obj
38 | if isinstance(obj, type):
39 | if include_methods:
40 | if obj.__module__ != module_name:
41 | return
42 | for k in obj.__dict__.keys():
43 | v = obj.__dict__[k]
44 | crawl_obj(
45 | obj=v,
46 | module_name=module_name,
47 | include_methods=include_methods,
48 | result=result,
49 | strict=strict,
50 | objs_result=objs_result,
51 | )
52 |
53 |
54 | def crawl_static(
55 | root: Optional[Path],
56 | strict: bool,
57 | package_name: Optional[str] = None,
58 | include_methods: bool = False,
59 | ) -> Tuple[Dict[DepKey, CallableNode], Dict[DepKey, Callable]]:
60 | """
61 | Find all python files in the root directory, and use importlib to import
62 | them, look for callable objects, and create callable nodes from them.
63 | """
64 | result: Dict[DepKey, CallableNode] = {}
65 | objs_result: Dict[DepKey, Callable] = {}
66 | paths = []
67 | if root is not None:
68 | if root.is_file():
69 | assert package_name is not None # needs this to be able to import
70 | paths = [root]
71 | else:
72 | paths.extend(list(root.rglob("*.py")))
73 | paths.append("__main__")
74 | for path in paths:
75 | filename = path.name if path != "__main__" else "__main__"
76 | if filename in ("setup.py", "console.py"):
77 | continue
78 | if path != "__main__" and root is not None:
79 | if root.is_file():
80 | module_name = root.stem
81 | else:
82 | module_name = (
83 | path.with_suffix("").relative_to(root).as_posix().replace("/", ".")
84 | )
85 | if package_name is not None:
86 | module_name = ".".join([package_name, module_name])
87 | else:
88 | module_name = "__main__"
89 | try:
90 | module = importlib.import_module(module_name)
91 | except:
92 | msg = f"Failed to import {module_name}:"
93 | if strict:
94 | raise ValueError(msg)
95 | else:
96 | logger.warning(msg)
97 | continue
98 | keys = list(module.__dict__.keys())
99 | for k in keys:
100 | v = module.__dict__[k]
101 | crawl_obj(
102 | obj=v,
103 | module_name=module_name,
104 | strict=strict,
105 | include_methods=include_methods,
106 | result=result,
107 | objs_result=objs_result,
108 | )
109 | return result, objs_result
110 |
--------------------------------------------------------------------------------
/mandala/tps.py:
--------------------------------------------------------------------------------
1 | from .common_imports import *
2 | import typing
3 | from typing import Hashable
4 |
5 | ################################################################################
6 | ### types
7 | ################################################################################
8 | from typing import Generic
9 |
10 | T = TypeVar("T")
11 | # Subclassing List
12 | class MList(List[T], Generic[T]):
13 | def identify(self):
14 | return "Type annotation for `mandala` lists"
15 |
16 |
17 | _KT = TypeVar("_KT")
18 | _VT = TypeVar("_VT")
19 | class MDict(Dict[_KT, _VT], Generic[_KT, _VT]):
20 | def identify(self):
21 | return "Type annotation for `mandala` dictionaries"
22 |
23 |
24 | class MSet(Set[T], Generic[T]):
25 | def identify(self):
26 | return "Type annotation for `mandala` sets"
27 |
28 |
29 | class MTuple(Tuple, Generic[T]):
30 | def identify(self):
31 | return "Type annotation for `mandala` tuples"
32 |
33 |
34 | class Type:
35 | @staticmethod
36 | def from_annotation(annotation: Any) -> "Type":
37 | if (annotation is None) or (annotation is inspect._empty):
38 | return AtomType()
39 | elif annotation is typing.Any:
40 | return AtomType()
41 | elif hasattr(annotation, "__origin__"):
42 | if annotation.__origin__ is MList:
43 | elt_annotation = annotation.__args__[0]
44 | return ListType(elt=Type.from_annotation(annotation=elt_annotation))
45 | elif annotation.__origin__ is MDict:
46 | key_annotation = annotation.__args__[0]
47 | value_annotation = annotation.__args__[1]
48 | return DictType(
49 | key=Type.from_annotation(annotation=key_annotation),
50 | val=Type.from_annotation(annotation=value_annotation),
51 | )
52 | elif annotation.__origin__ is MSet:
53 | elt_annotation = annotation.__args__[0]
54 | return SetType(elt=Type.from_annotation(annotation=elt_annotation))
55 | elif annotation.__origin__ is MTuple:
56 | if len(annotation.__args__) == 2 and annotation.__args__[1] == Ellipsis:
57 | return TupleType(
58 | Type.from_annotation(annotation=annotation.__args__[0])
59 | )
60 | else:
61 | return TupleType(
62 | *(
63 | Type.from_annotation(annotation=elt_annotation)
64 | for elt_annotation in annotation.__args__
65 | )
66 | )
67 | else:
68 | return AtomType()
69 | elif isinstance(annotation, Type):
70 | return annotation
71 | else:
72 | return AtomType()
73 |
74 | def __eq__(self, other: Any) -> bool:
75 | if type(self) != type(other):
76 | return False
77 | elif isinstance(self, AtomType):
78 | return True
79 | else:
80 | raise NotImplementedError
81 |
82 |
83 | class AtomType(Type):
84 | def __repr__(self):
85 | return "AnyType()"
86 |
87 |
88 | class ListType(Type):
89 | struct_id = "__list__"
90 | model = list
91 |
92 | def __init__(self, elt: Type):
93 | self.elt = elt
94 |
95 | def __repr__(self):
96 | return f"ListType(elt_type={self.elt})"
97 |
98 |
99 | class DictType(Type):
100 | struct_id = "__dict__"
101 | model = dict
102 |
103 | def __init__(self, val: Type, key: Type = None):
104 | self.key = key
105 | self.val = val
106 |
107 | def __repr__(self):
108 | return f"DictType(val_type={self.val})"
109 |
110 |
111 | class SetType(Type):
112 | struct_id = "__set__"
113 | model = set
114 |
115 | def __init__(self, elt: Type):
116 | self.elt = elt
117 |
118 | def __repr__(self):
119 | return f"SetType(elt_type={self.elt})"
120 |
121 |
122 | class TupleType(Type):
123 | struct_id = "__tuple__"
124 | model = tuple
125 |
126 | def __init__(self, *elt_types: Type):
127 | self.elt_types = elt_types
128 |
129 | def __repr__(self):
130 | return f"TupleType(elt_types={self.elt_types})"
131 |
--------------------------------------------------------------------------------
/mandala/deps/viz.py:
--------------------------------------------------------------------------------
1 | import textwrap
2 | from ..common_imports import *
3 | from ..viz import (
4 | Node as DotNode,
5 | Edge as DotEdge,
6 | Group as DotGroup,
7 | to_dot_string,
8 | SOLARIZED_LIGHT,
9 | )
10 |
11 | from .utils import (
12 | DepKey,
13 | )
14 |
15 |
16 | def to_string(graph: "model.DependencyGraph") -> str:
17 | """
18 | Get a string for pretty-printing.
19 | """
20 | # group the nodes by module
21 | module_groups: Dict[str, List["model.Node"]] = {}
22 | for key, node in graph.nodes.items():
23 | module_name, _ = key
24 | module_groups.setdefault(module_name, []).append(node)
25 | lines = []
26 | for module_name, nodes in module_groups.items():
27 | global_nodes = [node for node in nodes if isinstance(node, model.GlobalVarNode)]
28 | callable_nodes = [
29 | node for node in nodes if isinstance(node, model.CallableNode)
30 | ]
31 | module_desc = f"MODULE: {module_name}"
32 | lines.append(module_desc)
33 | lines.append("-" * len(module_desc))
34 | lines.append("===Global Variables===")
35 | for node in global_nodes:
36 | desc = f"{node.obj_name} = {node.readable_content()}"
37 | lines.append(textwrap.indent(desc, 4 * " "))
38 | # lines.append(f" {node.diff_representation()}")
39 | lines.append("")
40 | lines.append("===Functions===")
41 | # group the methods by class
42 | method_nodes = [node for node in callable_nodes if node.is_method]
43 | func_nodes = [node for node in callable_nodes if not node.is_method]
44 | methods_by_class: Dict[str, List["model.CallableNode"]] = {}
45 | for method_node in method_nodes:
46 | methods_by_class.setdefault(method_node.class_name, []).append(method_node)
47 | for class_name, method_nodes in methods_by_class.items():
48 | lines.append(textwrap.indent(f"class {class_name}:", 4 * " "))
49 | for node in method_nodes:
50 | desc = node.readable_content()
51 | lines.append(textwrap.indent(textwrap.dedent(desc), 8 * " "))
52 | lines.append("")
53 | for node in func_nodes:
54 | desc = node.readable_content()
55 | lines.append(textwrap.indent(textwrap.dedent(desc), 4 * " "))
56 | lines.append("")
57 | return "\n".join(lines)
58 |
59 |
60 | def to_dot(graph: "model.DependencyGraph") -> str:
61 | nodes: Dict[DepKey, DotNode] = {}
62 | module_groups: Dict[str, DotGroup] = {} # module name -> Group
63 | class_groups: Dict[str, DotGroup] = {} # class name -> Group
64 | for key, node in graph.nodes.items():
65 | module_name, obj_addr = key
66 | if module_name not in module_groups:
67 | module_groups[module_name] = DotGroup(
68 | label=module_name, nodes=[], parent=None
69 | )
70 | if isinstance(node, model.GlobalVarNode):
71 | color = SOLARIZED_LIGHT["red"]
72 | elif isinstance(node, model.CallableNode):
73 | color = (
74 | SOLARIZED_LIGHT["blue"]
75 | if not node.is_method
76 | else SOLARIZED_LIGHT["violet"]
77 | )
78 | else:
79 | color = SOLARIZED_LIGHT["base03"]
80 | dot_node = DotNode(
81 | internal_name=".".join(key), label=node.obj_name, color=color
82 | )
83 | nodes[key] = dot_node
84 | module_groups[module_name].nodes.append(dot_node)
85 | if isinstance(node, model.CallableNode) and node.is_method:
86 | class_name = node.class_name
87 | class_groups.setdefault(
88 | class_name,
89 | DotGroup(
90 | label=class_name,
91 | nodes=[],
92 | parent=module_groups[module_name],
93 | ),
94 | ).nodes.append(dot_node)
95 | edges: Dict[Tuple[DotNode, DotNode], DotEdge] = {}
96 | for source, target in graph.edges:
97 | source_node = nodes[source]
98 | target_node = nodes[target]
99 | edge = DotEdge(source_node=source_node, target_node=target_node)
100 | edges[(source_node, target_node)] = edge
101 | dot_string = to_dot_string(
102 | nodes=list(nodes.values()),
103 | edges=list(edges.values()),
104 | groups=list(module_groups.values()) + list(class_groups.values()),
105 | rankdir="BT",
106 | )
107 | return dot_string
108 |
109 |
110 | from . import model
111 |
--------------------------------------------------------------------------------
/docs/docs/topics/01_storage_and_ops.md:
--------------------------------------------------------------------------------
1 | # `Storage` & the `@op` Decorator
2 |
3 |
4 |
5 | A `Storage` object holds all data (saved calls, code and dependencies) for a
6 | collection of memoized functions. In a given project, you should have just one
7 | `Storage` and many `@op`s connected to it. This way, the calls to memoized
8 | functions create a queryable web of interlinked objects.
9 |
10 |
11 | ```python
12 | # for Google Colab
13 | try:
14 | import google.colab
15 | !pip install git+https://github.com/amakelov/mandala
16 | except:
17 | pass
18 | ```
19 |
20 | ## Creating a `Storage`
21 |
22 | When creating a storage, you must decide if it will be in-memory or persisted on
23 | disk, and whether the storage will automatically version the `@op`s used with
24 | it:
25 |
26 |
27 | ```python
28 | from mandala.imports import Storage
29 | import os
30 |
31 | DB_PATH = 'my_persistent_storage.db'
32 | if os.path.exists(DB_PATH):
33 | os.remove(DB_PATH)
34 |
35 | storage = Storage(
36 | # omit for an in-memory storage
37 | db_path=DB_PATH,
38 | # omit to disable automatic dependency tracking & versioning
39 | # use "__main__" to only track functions defined in the current session
40 | deps_path='__main__',
41 | )
42 | ```
43 |
44 | ## Creating `@op`s and saving calls to them
45 | **Any Python function can be decorated with `@op`**:
46 |
47 |
48 | ```python
49 | from mandala.imports import op
50 |
51 | @op
52 | def sum_args(a, *args, b=1, **kwargs):
53 | return a + sum(args) + b + sum(kwargs.values())
54 | ```
55 |
56 | In general, calling `sum_args` will behave as if the `@op` decorator is not
57 | there. `@op`-decorated functions will interact with a `Storage` instance **only
58 | when** called inside a `with storage:` block:
59 |
60 |
61 | ```python
62 | with storage: # all `@op` calls inside this block use `storage`
63 | s = sum_args(6, 7, 8, 9, c=11,)
64 | print(s)
65 | ```
66 |
67 | AtomRef(42, hid=168...)
68 |
69 |
70 | This code runs the call to `sum_args`, and saves the inputs and outputs in the
71 | `storage` object, so that doing the same call later will directly load the saved
72 | outputs.
73 |
74 | ### When should something be an `@op`?
75 | As a general guide, you should make something an `@op` if you want to save its
76 | outputs, e.g. if they take a long time to compute but you need them for later
77 | analysis. Since `@op` [encourages
78 | composition](https://amakelov.github.io/mandala/02_retracing/#how-op-encourages-composition),
79 | you should aim to have `@op`s work on the outputs of other `@op`s, or on the
80 | [collections and/or items](https://amakelov.github.io/mandala/05_collections/)
81 | of outputs of other `@op`s.
82 |
83 | ## Working with `@op` outputs (`Ref`s)
84 | The objects (e.g. `s`) returned by `@op`s are always instances of a subclass of
85 | `Ref` (e.g., `AtomRef`), i.e. **references to objects in the storage**. Every
86 | `Ref` contains two metadata fields:
87 |
88 | - `cid`: a hash of the **content** of the object
89 | - `hid`: a hash of the **computational history** of the object, which is the precise
90 | composition of `@op`s that created this ref.
91 |
92 | Two `Ref`s with the same `cid` may have different `hid`s, and `hid` is the
93 | unique identifier of `Ref`s in the storage. However, only one copy per unique
94 | `cid` is stored, to avoid duplication in the storage.
95 |
96 | ### `Ref`s can be in memory or not
97 | Additionally, `Ref`s have the `in_memory` property, which indicates if the
98 | underlying object is present in the `Ref` or if this is a "lazy" `Ref` which
99 | only contains metadata. **`Ref`s are only loaded in memory when needed for a new
100 | call to an `@op`**. For example, re-running the last code block:
101 |
102 |
103 | ```python
104 | with storage:
105 | s = sum_args(6, 7, 8, 9, c=11,)
106 | print(s)
107 | ```
108 |
109 | AtomRef(hid=168..., in_memory=False)
110 |
111 |
112 | To get the object wrapped by a `Ref`, call `storage.unwrap`:
113 |
114 |
115 | ```python
116 | storage.unwrap(s) # loads from storage only if necessary
117 | ```
118 |
119 |
120 |
121 |
122 | 42
123 |
124 |
125 |
126 | ### Other useful `Storage` methods
127 |
128 | - `Storage.attach(inplace: bool)`: like `unwrap`, but puts the objects in the
129 | `Ref`s if they are not in-memory.
130 | - `Storage.load_ref(hid: str, in_memory: bool)`: load a `Ref` by its history ID,
131 | optionally also loading the underlying object.
132 |
133 |
134 | ```python
135 | print(storage.attach(obj=s, inplace=False))
136 | print(storage.load_ref(s.hid))
137 | ```
138 |
139 | AtomRef(42, hid=168...)
140 | AtomRef(42, hid=168...)
141 |
142 |
143 | ## Working with `Call` objects
144 | Besides `Ref`s, the other kind of object in the storage is the `Call`, which
145 | stores references to the inputs and outputs of a call to an `@op`, together with
146 | metadata that mirrors the `Ref` metadata:
147 |
148 | - `Call.cid`: a content ID for the call, based on the `@op`'s identity, its
149 | version at the time of the call, and the `cid`s of the inputs
150 | - `Call.hid`: a history ID for the call, the same as `Call.cid`, but using the
151 | `hid`s of the inputs.
152 |
153 | **For every `Ref` history ID, there's at most one `Call` that has an output with
154 | this history ID**, and if it exists, this call can be found by calling
155 | `storage.get_ref_creator()`:
156 |
157 |
158 | ```python
159 | call = storage.get_ref_creator(ref=s)
160 | print(call)
161 | display(call.inputs)
162 | display(call.outputs)
163 | ```
164 |
165 | Call(sum_args, hid=f99...)
166 |
167 |
168 |
169 | {'a': AtomRef(hid=c6a..., in_memory=False),
170 | 'args_0': AtomRef(hid=e0f..., in_memory=False),
171 | 'args_1': AtomRef(hid=479..., in_memory=False),
172 | 'args_2': AtomRef(hid=c37..., in_memory=False),
173 | 'b': AtomRef(hid=610..., in_memory=False),
174 | 'c': AtomRef(hid=a33..., in_memory=False)}
175 |
176 |
177 |
178 | {'output_0': AtomRef(hid=168..., in_memory=False)}
179 |
180 |
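181 | To recover the underlying values of a call's inputs and outputs, each `Ref`
182 | can be unwrapped individually. A minimal sketch, continuing from the call
183 | above:
184 | 
185 | ```python
186 | # `storage.unwrap` loads each object from storage only when needed
187 | print({name: storage.unwrap(ref) for name, ref in call.inputs.items()})
188 | print({name: storage.unwrap(ref) for name, ref in call.outputs.items()})
189 | ```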
--------------------------------------------------------------------------------
/docs/docs/topics/05_collections_files/05_collections_5_0.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amakelov/mandala/HEAD/docs/docs/topics/05_collections_files/05_collections_5_0.svg
--------------------------------------------------------------------------------
/docs/docs/topics/03_cf_files/03_cf_28_0.svg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amakelov/mandala/HEAD/docs/docs/topics/03_cf_files/03_cf_28_0.svg
--------------------------------------------------------------------------------
/mandala/tests/test_memoization.py:
--------------------------------------------------------------------------------
1 | from mandala.imports import *
2 | import numpy as np
3 |
4 |
5 | def test_storage():
6 | storage = Storage()
7 |
8 | @op
9 | def inc(x: int) -> int:
10 | return x + 1
11 |
12 | with storage:
13 | x = 1
14 | y = inc(x)
15 | z = inc(2)
16 | w = inc(y)
17 |
18 | assert w.cid == z.cid
19 | assert w.hid != y.hid
20 | assert w.cid != y.cid
21 | assert storage.unwrap(y) == 2
22 | assert storage.unwrap(z) == 3
23 | assert storage.unwrap(w) == 3
24 | for ref in (y, z, w):
25 | assert storage.attach(ref).in_memory
26 | assert storage.attach(ref).obj == storage.unwrap(ref)
27 |
28 |
29 | def test_signatures():
30 | storage = Storage()
31 |
32 | @op # a function with a wild input/output signature
33 | def add(x, *args, y: int = 1, **kwargs):
34 | # just sum everything
35 | res = x + sum(args) + y + sum(kwargs.values())
36 | if kwargs:
37 | return res, kwargs
38 | elif args:
39 | return None
40 | else:
41 | return res
42 |
43 | with storage:
44 | # call the func in all the ways
45 | sum_1 = add(1)
46 | sum_2 = add(1, 2, 3, 4, )
47 | sum_3 = add(1, 2, 3, 4, y=5)
48 | sum_4 = add(1, 2, 3, 4, y=5, z=6)
49 | sum_5 = add(1, 2, 3, 4, z=5, w=7)
50 |
51 | assert storage.unwrap(sum_1) == 2
52 | assert storage.unwrap(sum_2) == None
53 | assert storage.unwrap(sum_3) == None
54 | assert storage.unwrap(sum_4) == (21, {'z': 6})
55 | assert storage.unwrap(sum_5) == (23, {'z': 5, 'w': 7})
56 |
57 |
58 | def test_retracing():
59 | storage = Storage()
60 |
61 | @op
62 | def inc(x):
63 | return x + 1
64 |
65 | ### iterating a function
66 | with storage:
67 | start = 1
68 | for i in range(10):
69 | start = inc(start)
70 |
71 | with storage:
72 | start = 1
73 | for i in range(10):
74 | start = inc(start)
75 |
76 | ### composing functions
77 | @op
78 | def add(x, y):
79 | return x + y
80 |
81 | with storage:
82 | inp = [1, 2, 3, 4, 5]
83 | stage_1 = [inc(x) for x in inp]
84 | stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)]
85 |
86 | with storage:
87 | inp = [1, 2, 3, 4, 5]
88 | stage_1 = [inc(x) for x in inp]
89 | stage_2 = [add(x, y) for x, y in zip(stage_1, stage_1)]
90 |
91 |
92 | def test_lists():
93 | storage = Storage()
94 |
95 | @op
96 | def get_sum(elts: MList[int]) -> int:
97 | return sum(elts)
98 |
99 | @op
100 | def primes_below(n: int) -> MList[int]:
101 | primes = []
102 | for i in range(2, n):
103 | for p in primes:
104 | if i % p == 0:
105 | break
106 | else:
107 | primes.append(i)
108 | return primes
109 |
110 | @op
111 | def chunked_square(elts: MList[int]) -> MList[int]:
112 | # a model for an op that does something on chunks of a big thing
113 | # to prevent OOM errors
114 | return [x*x for x in elts]
115 |
116 | with storage:
117 | n = 10
118 | primes = primes_below(n)
119 | sum_primes = get_sum(primes)
120 | assert len(primes) == 4
121 | # check indexing
122 | assert storage.unwrap(primes[0]) == 2
123 | assert storage.unwrap(primes[:2]) == [2, 3]
124 |
125 | ### lists w/ overlapping elements
126 | with storage:
127 | n = 100
128 | primes = primes_below(n)
129 | for i in range(0, len(primes), 2):
130 | sum_primes = get_sum(primes[:i+1])
131 |
132 | with storage:
133 | elts = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
134 | squares = chunked_square(elts)
135 |
136 |
137 |
138 | def test_ignore():
139 |
140 | storage = Storage()
141 |
142 | @op(ignore_args=('irrelevant',))
143 | def inc(x, irrelevant):
144 | return x + 1
145 |
146 |
147 | with storage:
148 | inc(23, 0)
149 |
150 | df = storage.cf(inc).df()
151 | assert len(df) == 1
152 |
153 | with storage:
154 | inc(23, 1)
155 |
156 | df = storage.cf(inc).df()
157 | assert len(df) == 1
158 |
159 |
160 | def test_clear_uncommitted():
161 | storage = Storage()
162 |
163 | @op
164 | def inc(x):
165 | return x + 1
166 |
167 | with storage:
168 | for i in range(10):
169 | inc(i)
170 | # attempt to clear the atoms cache without having committed; this should
171 | # fail by default
172 | try:
173 | storage.atoms.clear()
174 | assert False
175 | except ValueError:
176 | pass
177 |
178 | # now clear the atoms cache after committing
179 | storage.commit()
180 | storage.atoms.clear()
181 |
182 |
183 |
184 | def test_newargdefault():
185 | storage = Storage()
186 |
187 | @op
188 | def add(x,):
189 | return x + 1
190 |
191 | with storage:
192 | add(1)
193 |
194 | @op
195 | def add(x, y=NewArgDefault(1)):
196 | return x + y
197 |
198 | with storage:
199 | add(1)
200 | # check that we didn't make a new call
201 | assert len(storage.cf(add).calls) == 1
202 |
203 | with storage:
204 | add(1, 1)
205 | # check that we didn't make a new call
206 | assert len(storage.cf(add).calls) == 1
207 |
208 | with storage:
209 | add(1, 2)
210 | # now this should have made a new call!
211 | assert len(storage.cf(add).calls) == 2
212 |
213 | def test_newargdefault_compound_types():
214 | storage = Storage()
215 |
216 | @op
217 | def add_array(x:np.ndarray):
218 | return x
219 | with storage:
220 | add_array(np.array([1, 2, 3]))
221 |
222 | @op
223 | def add_array(x:np.ndarray, y=NewArgDefault(None)):
224 | return x + y
225 | # test passing a raw value
226 | with storage:
227 | add_array(np.array([1, 2, 3]), y=np.array([4, 5, 6]))
228 |
229 | # now test passing a wrapped value
230 | with storage:
231 | add_array(np.array([1, 2, 3]), y=wrap_atom(np.array([7, 8, 9])))
232 |
233 |
234 |
235 |
236 | def test_value_pointer():
237 | storage = Storage()
238 |
239 | @op
240 | def get_mean(x: np.ndarray) -> float:
241 | return x.mean()
242 |
243 | with storage:
244 | X = np.array([1, 2, 3, 4, 5])
245 | X_pointer = ValuePointer("X", X)
246 | mean = get_mean(X_pointer)
247 |
248 | assert storage.unwrap(mean) == 3.0
249 | df = storage.cf(get_mean).df()
250 | assert len(df) == 1
251 | assert df['x'].item().id == "X"
252 |
--------------------------------------------------------------------------------
/docs/docs/topics/02_retracing.md:
--------------------------------------------------------------------------------
1 | # Patterns for Incremental Computation & Development
2 |
3 |
4 |
5 | **`@op`-decorated functions are designed to be composed** with one another. This
6 | enables the same piece of imperative code to adapt to multiple goals depending
7 | on the situation:
8 |
9 | - saving new `@op` calls and/or loading previous ones;
10 | - cheaply resuming an `@op` program after a failure;
11 | - incrementally adding more logic and computations to the same code without
12 | re-doing work.
13 |
14 | **This section of the documentation does not introduce new methods or classes**.
15 | Instead, it demonstrates the programming patterns needed to make effective use
16 | of `mandala`'s memoization capabilities.
17 |
18 | ## How `@op` encourages composition
19 | There are several ways in which the `@op` decorator encourages (and even
20 | enforces) composition of `@op`s:
21 |
22 | - **`@op`s return special objects**, `Ref`s, which prevents you from
23 | accidentally calling a non-`@op` function on the output of an `@op`.
24 | - If the inputs to an `@op` call are already `Ref`s, this **speeds up the cache
25 | lookups**.
26 | - If the call can be reused, the **input `Ref`s don't even need to be in memory**
27 | (because the lookup is based only on `Ref` metadata).
28 | - When `@op`s are composed, **computational history propagates** through this
29 | composition. This is automatically leveraged by `ComputationFrame`s when
30 | querying the storage.
31 | - Though not documented here, **`@op`s can natively handle Python
32 | collections** like lists and dicts (see the collections topic for details).
33 |
34 | When `@op`s are composed in this way, the entire computation becomes "end-to-end
35 | [memoized](https://en.wikipedia.org/wiki/Memoization)".
36 |
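As a minimal sketch of such a composition (using only the `@op`, `Storage` and
`storage.unwrap` APIs that appear throughout these docs):


```python
from mandala.imports import *

@op
def inc(x):
    return x + 1

@op
def double(x):
    return 2 * x

storage = Storage()

with storage:
    y = inc(1)     # returns a `Ref`, not the raw value 2
    z = double(y)  # `Ref`s flow directly into other `@op`s

storage.unwrap(z)  # == 4
```
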
37 | ## Toy ML pipeline example
38 | Here's a small example of a machine learning pipeline:
39 |
40 |
41 | ```python
42 | # for Google Colab
43 | try:
44 | import google.colab
45 | !pip install git+https://github.com/amakelov/mandala
46 | except ImportError:
47 | pass
48 | ```
49 |
50 |
51 | ```python
52 | from mandala.imports import *
53 | from sklearn.datasets import load_digits
54 | from sklearn.ensemble import RandomForestClassifier
55 | from sklearn.metrics import accuracy_score
56 |
57 | @op
58 | def load_data(n_class=2):
59 | print("Loading data")
60 | return load_digits(n_class=n_class, return_X_y=True)
61 |
62 | @op
63 | def train_model(X, y, n_estimators=5):
64 | print("Training model")
65 | return RandomForestClassifier(n_estimators=n_estimators,
66 | max_depth=2).fit(X, y)
67 |
68 | @op
69 | def get_acc(model, X, y):
70 | print("Getting accuracy")
71 | return round(accuracy_score(y_pred=model.predict(X), y_true=y), 2)
72 |
73 | storage = Storage()
74 |
75 | with storage:
76 | X, y = load_data()
77 | model = train_model(X, y)
78 | acc = get_acc(model, X, y)
79 | print(acc)
80 | ```
81 |
82 | Loading data
83 | Training model
84 | Getting accuracy
85 | AtomRef(1.0, hid=d16...)
86 |
87 |
88 | ## Retracing your steps with memoization
89 | Running the computation again will not execute any calls, because it will
90 | exactly **retrace** calls that happened in the past. Moreover, the retracing is
91 | **lazy**: none of the values along the way are actually loaded from storage:
92 |
93 |
94 | ```python
95 | with storage:
96 | X, y = load_data()
97 | print(X, y)
98 | model = train_model(X, y)
99 | print(model)
100 | acc = get_acc(model, X, y)
101 | print(acc)
102 | ```
103 |
104 | AtomRef(hid=d0f..., in_memory=False) AtomRef(hid=f1a..., in_memory=False)
105 | AtomRef(hid=caf..., in_memory=False)
106 | AtomRef(hid=d16..., in_memory=False)
107 |
108 |
109 | This puts all the `Ref`s along the way in your local variables (as if you'd
110 | just run the computation), which lets you easily inspect any intermediate
111 | variables in this `@op` composition:
112 |
113 |
114 | ```python
115 | storage.unwrap(acc)
116 | ```
117 |
118 |
119 |
120 |
121 | 1.0
122 |
123 |
124 |
125 | ## Adding new calls "in-place" in `@op`-based programs
126 | With `mandala`, you don't need to keep track of what has already been computed
127 | and split up your code accordingly. All past results are automatically reused,
128 | so you can directly build on the existing composition of `@op`s when you want
129 | to add new functions and/or run old ones with different parameters:
130 |
131 |
132 | ```python
133 | # reuse the previous code to loop over more values of n_class and n_estimators
134 | with storage:
135 | for n_class in (2, 5,):
136 | X, y = load_data(n_class)
137 | for n_estimators in (5, 10):
138 | model = train_model(X, y, n_estimators=n_estimators)
139 | acc = get_acc(model, X, y)
140 | print(acc)
141 | ```
142 |
143 | AtomRef(hid=d16..., in_memory=False)
144 | Training model
145 | Getting accuracy
146 | AtomRef(1.0, hid=6fd...)
147 | Loading data
148 | Training model
149 | Getting accuracy
150 | AtomRef(0.88, hid=158...)
151 | Training model
152 | Getting accuracy
153 | AtomRef(0.88, hid=214...)
154 |
155 |
156 | Note that the first value of `acc` in the nested loop is printed with
157 | `in_memory=False`, because it was reused from the earlier call; the other
158 | values are in memory, as they were freshly computed.
159 |
160 | This pattern lets you incrementally build towards the final computations you
161 | want without worrying about how results will be reused.
162 |
163 | ## Using control flow efficiently with `@op`s
164 | Because the unit of storage is the function call (as opposed to an entire script
165 | or notebook), you can transparently use Pythonic control flow. If the control
166 | flow depends on a `Ref`, you can explicitly load just that `Ref` into memory
167 | using `storage.unwrap`:
168 |
169 |
170 | ```python
171 | with storage:
172 | for n_class in (2, 5,):
173 | X, y = load_data(n_class)
174 | for n_estimators in (5, 10):
175 | model = train_model(X, y, n_estimators=n_estimators)
176 | acc = get_acc(model, X, y)
177 | if storage.unwrap(acc) > 0.9: # load only the `Ref`s needed for control flow
178 | print(n_class, n_estimators, storage.unwrap(acc))
179 | ```
180 |
181 | 2 5 1.0
182 | 2 10 1.0
183 |
184 |
185 | ## Memoized code as storage interface
186 | An end-to-end memoized composition of `@op`s is like an "imperative" storage
187 | interface. You can modify the code to focus only on particular results of
188 | interest:
189 |
190 |
191 | ```python
192 | with storage:
193 | for n_class in (5,):
194 | X, y = load_data(n_class)
195 | for n_estimators in (5,):
196 | model = train_model(X, y, n_estimators=n_estimators)
197 | acc = get_acc(model, X, y)
198 | print(storage.unwrap(acc), storage.unwrap(model))
199 | ```
200 |
201 | 0.88 RandomForestClassifier(max_depth=2, n_estimators=5)
202 |
203 |
--------------------------------------------------------------------------------
/docs/docs/tutorials/01_hello_files/01_hello_5_3.svg:
--------------------------------------------------------------------------------
[SVG image omitted]
--------------------------------------------------------------------------------
/docs/docs/blog/01_cf_files/01_cf_2_0.svg:
--------------------------------------------------------------------------------
[SVG image omitted]
--------------------------------------------------------------------------------
/mandala/deps/deep_versions.py:
--------------------------------------------------------------------------------
1 | from ..common_imports import *
2 | from .utils import DepKey, hash_dict
3 | from .model import (
4 | Node,
5 | CallableNode,
6 | GlobalVarNode,
7 | TerminalNode,
8 | )
9 |
10 |
11 | from .shallow_versions import DAG
12 |
13 |
14 | class Version:
15 | """
16 | Model of a "deep" version of a component that includes versions of its
17 | dependencies.
18 | """
19 |
20 | def __init__(
21 | self,
22 | component: DepKey,
23 | dynamic_deps_commits: Dict[DepKey, str],
24 | memoized_deps_content_versions: Dict[DepKey, Set[str]],
25 | ):
26 | ### raw data from the trace
27 | # the component whose dependencies are traced
28 | self.component = component
29 | # the content hashes of the direct dependencies
30 | self.direct_deps_commits = dynamic_deps_commits
31 | # pointers to content hashes of versions of memoized calls
32 | self.memoized_deps_content_versions = memoized_deps_content_versions
33 |
34 | ### cached data. These are set against a dependency state
35 | self._is_synced = False
36 | # the expanded set of dependencies, including all transitive
37 | # dependencies. Note this is a set of *content* hashes per dependency
38 | self._content_expansion: Dict[DepKey, Set[str]] = None
39 | # a hash uniquely identifying the content of dependencies of this version
40 | self._content_version: str = None
41 | # the semantic hashes of all dependencies for this version;
42 | # the system enforces that the semantic hash of a dependency is the same
43 | # for all commits of a component referenced by this version
44 | self._semantic_expansion: Dict[DepKey, str] = None
45 | # overall semantic hash of this version
46 | self._semantic_version: str = None
47 |
48 | @property
49 | def presentation(self) -> str:
50 | return f'Version of "{self.component[1]}" from module "{self.component[0]}" (content: {self.content_version}, semantic: {self.semantic_version})'
51 |
52 | @staticmethod
53 | def from_trace(
54 | component: DepKey, nodes: Dict[DepKey, Node], strict: bool = True
55 | ) -> "Version":
56 | dynamic_deps_commits = {}
57 | memoized_deps_content_versions = defaultdict(set)
58 | for dep_key, node in nodes.items():
59 | if isinstance(node, (CallableNode, GlobalVarNode)):
60 | dynamic_deps_commits[dep_key] = node.content_hash
61 | elif isinstance(node, TerminalNode):
62 | terminal_data = node.representation
63 | pointer_dep_key = terminal_data.dep_key
64 | version_content_hash = terminal_data.call_content_version
65 | memoized_deps_content_versions[pointer_dep_key].add(
66 | version_content_hash
67 | )
68 | else:
69 | raise ValueError(f"Unexpected node type {type(node)}")
70 | return Version(
71 | component=component,
72 | dynamic_deps_commits=dynamic_deps_commits,
73 | memoized_deps_content_versions=dict(memoized_deps_content_versions),
74 | )
75 |
76 | ############################################################################
77 | ### methods for setting cached data from a versioning state
78 | ############################################################################
79 | def _set_content_expansion(self, all_versions: Dict[DepKey, Dict[str, "Version"]]):
80 | result = defaultdict(set)
81 | for dep_key, content_hash in self.direct_deps_commits.items():
82 | result[dep_key].add(content_hash)
83 | for (
84 | dep_key,
85 | memoized_content_versions,
86 | ) in self.memoized_deps_content_versions.items():
87 | for memoized_content_version in memoized_content_versions:
88 | referenced_version = all_versions[dep_key][memoized_content_version]
89 | for (
90 | referenced_dep_key,
91 | referenced_content_hashes,
92 | ) in referenced_version.content_expansion.items():
93 | result[referenced_dep_key].update(referenced_content_hashes)
94 | self._content_expansion = dict(result)
95 |
96 | def _set_content_version(self):
97 | self._content_version = hash_dict(
98 | {
99 | dep_key: tuple(sorted(self.content_expansion[dep_key]))
100 | for dep_key in self.content_expansion
101 | }
102 | )
103 |
104 | def _set_semantic_expansion(
105 | self,
106 | component_dags: Dict[DepKey, DAG],
107 | all_versions: Dict[DepKey, Dict[str, "Version"]],
108 | ):
109 | result = {}
110 | # from own deps
111 | for dep_key, dep_content_hash in self.direct_deps_commits.items():
112 | dag = component_dags[dep_key]
113 | semantic_hash = dag.commits[dep_content_hash].semantic_hash
114 | result[dep_key] = semantic_hash
115 | # from pointers
116 | for (
117 | dep_key,
118 | memoized_content_versions,
119 | ) in self.memoized_deps_content_versions.items():
120 | for memoized_content_version in memoized_content_versions:
121 | dep_version_semantic_hashes = all_versions[dep_key][
122 | memoized_content_version
123 | ].semantic_expansion
124 | overlap = set(result.keys()).intersection(
125 | dep_version_semantic_hashes.keys()
126 | )
127 | if any(result[k] != dep_version_semantic_hashes[k] for k in overlap):
128 | raise ValueError(
129 | f"Version {self} has conflicting semantic hashes for {overlap}"
130 | )
131 | result.update(dep_version_semantic_hashes)
132 | self._semantic_expansion = result
133 | self._semantic_version = hash_dict(result)
134 |
135 | def sync(
136 | self,
137 | component_dags: Dict[DepKey, DAG],
138 | all_versions: Dict[DepKey, Dict[str, "Version"]],
139 | ):
140 | """
141 | Set all the cached things in the correct order
142 | """
143 | self._set_content_expansion(all_versions=all_versions)
144 | self._set_content_version()
145 | self._set_semantic_expansion(
146 | component_dags=component_dags, all_versions=all_versions
147 | )
148 | self.set_synced()
149 |
150 | @property
151 | def content_version(self) -> str:
152 | assert self._content_version is not None
153 | return self._content_version
154 |
155 | @property
156 | def semantic_version(self) -> str:
157 | assert self._semantic_version is not None
158 | return self._semantic_version
159 |
160 | @property
161 | def semantic_expansion(self) -> Dict[DepKey, str]:
162 | assert self._semantic_expansion is not None
163 | return self._semantic_expansion
164 |
165 | @property
166 | def content_expansion(self) -> Dict[DepKey, Set[str]]:
167 | assert self._content_expansion is not None
168 | return self._content_expansion
169 |
170 | @property
171 | def support(self) -> Iterable[DepKey]:
172 | return self.content_expansion.keys()
173 |
174 | @property
175 | def is_synced(self) -> bool:
176 | return self._is_synced
177 |
178 | def set_synced(self):
179 | # it can only go from unsynced to synced
180 | if self._is_synced:
181 | raise ValueError("Version is already synced")
182 | self._is_synced = True
183 |
184 | def __repr__(self) -> str:
185 | return f"""
186 | Version(
187 | dependencies={['.'.join(elt) for elt in self.support]},
188 | )"""
189 |
--------------------------------------------------------------------------------
/docs/docs/topics/03_cf_files/03_cf_14_1.svg:
--------------------------------------------------------------------------------
[SVG image omitted]
--------------------------------------------------------------------------------
/mandala/deps/utils.py:
--------------------------------------------------------------------------------
1 | import types
2 | import dis
3 | import importlib
4 | import gc
5 | from typing import Literal
6 |
7 | from ..common_imports import *
8 | from ..model import Ref
9 | from ..utils import get_content_hash, unwrap_decorators
10 | from ..config import Config
11 |
12 | DepKey = Tuple[str, str] # (module name, object address in module)
13 |
14 |
15 | class GlobalClassifier:
16 | """
17 | Try to bucket Python objects into categories for the sake of tracking
18 | global state.
19 | """
20 | SCALARS = "scalars"
21 | DATA = "data"
22 | ALL = "all"
23 |
24 | @staticmethod
25 | def is_excluded(obj: Any) -> bool:
26 | return (
27 | inspect.ismodule(obj) # exclude modules
28 | or isinstance(obj, type) # exclude classes
29 | or inspect.isfunction(obj) # exclude functions
30 | # or callable(obj) # exclude callables... this is very questionable
31 | or type(obj).__name__ == Config.func_interface_cls_name #! a hack to exclude memoized functions
32 | )
33 |
34 | @staticmethod
35 | def is_scalar(obj: Any) -> bool:
36 | result = isinstance(obj, (int, float, str, bool, type(None)))
37 | return result
38 |
39 | @staticmethod
40 | def is_data(obj: Any) -> bool:
41 | if GlobalClassifier.is_scalar(obj):
42 | result = True
43 | elif type(obj) in (tuple, list):
44 | result = all(GlobalClassifier.is_data(x) for x in obj)
45 | elif type(obj) is dict:
46 | result = all(GlobalClassifier.is_data((x, y)) for (x, y) in obj.items())
47 | elif type(obj) in (np.ndarray, pd.DataFrame, pd.Series, pd.Index):
48 | result = True
49 | else:
50 | result = False
51 | # if not result and not GlobalsStrictness.is_callable(obj):
52 | # logger.warning(f'Access to global variable "{obj}" is not tracked because it is not a scalar or a data structure')
53 | return result
54 |
55 |
56 | def is_global_val(obj: Any, allow_only: str = "all") -> bool:
57 | """
58 | Determine whether the given Python object should be treated as a global
59 | variable whose *content* should be tracked.
60 |
61 | The alternative is that this is a callable object whose *dependencies*
62 | should be tracked.
63 |
64 | However, the distinction is not always clear, making this method somewhat
65 | heuristic. For example, a callable object could be either a function or a
66 | global variable we want to track.
67 | """
68 | if isinstance(obj, Ref): # easy case; we always track globals that we explicitly wrapped
69 | return True
70 | if allow_only == GlobalClassifier.SCALARS:
71 | return GlobalClassifier.is_scalar(obj=obj)
72 | elif allow_only == GlobalClassifier.DATA:
73 | return GlobalClassifier.is_data(obj=obj)
74 | elif allow_only == GlobalClassifier.ALL:
75 | return not (
76 | inspect.ismodule(obj) # exclude modules
77 | or isinstance(obj, type) # exclude classes
78 | or inspect.isfunction(obj) # exclude functions
79 | # or callable(obj) # exclude callables ### this is very questionable
80 | or type(obj).__name__
81 | == Config.func_interface_cls_name #! a hack to exclude memoized functions
82 | )
83 | else:
84 | raise ValueError(
85 | f"Unknown strictness level for tracking global variables: {allow_only}"
86 | )
87 |
88 |
89 | def is_callable_obj(obj: Any, strict: bool) -> bool:
90 | if type(obj).__name__ == Config.func_interface_cls_name:
91 | return True
92 | if isinstance(obj, types.FunctionType):
93 | return True
94 | if not strict and callable(obj): # quite permissive
95 | return True
96 | return False
97 |
98 |
99 | def extract_func_obj(obj: Any, strict: bool) -> types.FunctionType:
100 | if type(obj).__name__ == Config.func_interface_cls_name:
101 | return obj.f
102 | obj = unwrap_decorators(obj, strict=strict)
103 | if isinstance(obj, types.BuiltinFunctionType):
104 | raise ValueError(f"Expected a non-built-in function, but got {obj}")
105 | if not isinstance(obj, types.FunctionType):
106 | if not strict:
107 | if (
108 | isinstance(obj, type)
109 | and hasattr(obj, "__init__")
110 | and isinstance(obj.__init__, types.FunctionType)
111 | ):
112 | return obj.__init__
113 | else:
114 | return unknown_function
115 | else:
116 | raise ValueError(f"Expected a function, but got {obj} of type {type(obj)}")
117 | return obj
118 |
119 |
120 | def extract_code(obj: Callable) -> types.CodeType:
121 | if type(obj).__name__ == Config.func_interface_cls_name:
122 | obj = obj.f
123 | if isinstance(obj, property):
124 | obj = obj.fget
125 | obj = unwrap_decorators(obj, strict=True)
126 | if not isinstance(obj, (types.FunctionType, types.MethodType)):
127 | logger.debug(f"Expected a function or method, but got {type(obj)}")
128 | # raise ValueError(f"Expected a function or method, but got {obj}")
129 | return obj.__code__
130 |
131 |
132 | def get_runtime_description(code: types.CodeType) -> Any:
133 | assert isinstance(code, types.CodeType)
134 | return get_sanitized_bytecode_representation(code=code)
135 |
136 |
137 | def get_global_names_candidates(code: types.CodeType) -> Set[str]:
138 | result = set()
139 | instructions = list(dis.get_instructions(code))
140 | for instr in instructions:
141 | if instr.opname == "LOAD_GLOBAL":
142 | result.add(instr.argval)
143 | if isinstance(instr.argval, types.CodeType):
144 | result.update(get_global_names_candidates(instr.argval))
145 | return result
146 |
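# Illustrative example (not part of the library): given
#
#     A = 42
#     def f(x):
#         return x + A
#
# get_global_names_candidates(f.__code__) inspects LOAD_GLOBAL instructions
# and would return {"A"} here (builtins loaded globally would also appear).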
147 |
148 | def get_sanitized_bytecode_representation(
149 | code: types.CodeType,
150 | ) -> List[dis.Instruction]:
151 | instructions = list(dis.get_instructions(code))
152 | result = []
153 | for instr in instructions:
154 | if isinstance(instr.argval, types.CodeType):
155 | result.append(
156 | dis.Instruction(
157 | instr.opname,
158 | instr.opcode,
159 | instr.arg,
160 | get_sanitized_bytecode_representation(instr.argval),
161 | "",
162 | instr.offset,
163 | instr.starts_line,
164 | is_jump_target=instr.is_jump_target,
165 | )
166 | )
167 | else:
168 | result.append(instr)
169 | return result
170 |
171 |
172 | def unknown_function():
173 | # this is a placeholder function that we use to get the source of
174 | # functions that we can't get the source of
175 | pass
176 |
177 |
178 | UNKNOWN_GLOBAL_VAR = "UNKNOWN_GLOBAL_VAR"
179 |
180 |
181 | def get_bytecode(f: Union[types.FunctionType, types.CodeType, str]) -> str:
182 | if isinstance(f, str):
183 | f = compile(f, "<string>", "exec")
184 | instructions = dis.get_instructions(f)
185 | return "\n".join([str(i) for i in instructions])
186 |
187 |
188 | def hash_dict(d: dict) -> str:
189 | return get_content_hash(obj=[(k, d[k]) for k in sorted(d.keys())])
190 |
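# Note that keys are sorted before hashing, so for example
# hash_dict({"a": 1, "b": 2}) == hash_dict({"b": 2, "a": 1});
# the result is independent of insertion order.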
191 |
192 | def load_obj(module_name: str, obj_name: str) -> Tuple[Any, bool]:
193 | module = importlib.import_module(module_name)
194 | parts = obj_name.split(".")
195 | current = module
196 | found = True
197 | for part in parts:
198 | if not hasattr(current, part):
199 | found = False
200 | break
201 | else:
202 | current = getattr(current, part)
203 | return current, found
204 |
205 |
206 | def get_dep_key_from_func(func: types.FunctionType) -> DepKey:
207 | module_name = func.__module__
208 | qualname = func.__qualname__
209 | return module_name, qualname
210 |
211 |
212 | def get_func_qualname(
213 | func_name: str,
214 | code: types.CodeType,
215 | frame: types.FrameType,
216 | ) -> str:
217 | # this is evil
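# gc.get_referrers returns the objects referencing `code`; among them we
# look for function objects, and if exactly one matches by name, its
# __qualname__ is authoritative. Otherwise we fall back to a heuristic.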
218 | referrers = gc.get_referrers(code)
219 | func_referrers = [r for r in referrers if isinstance(r, types.FunctionType)]
220 | matching_name = [r for r in func_referrers if r.__name__ == func_name]
221 | if len(matching_name) != 1:
222 | return get_func_qualname_fallback(func_name=func_name, code=code, frame=frame)
223 | else:
224 | return matching_name[0].__qualname__
225 |
226 |
227 | def get_func_qualname_fallback(
228 | func_name: str, code: types.CodeType, frame: types.FrameType
229 | ) -> str:
230 | # get the argument names to *try* to tell if the function is a method
231 | arg_names = code.co_varnames[: code.co_argcount]
232 | # a necessary but not sufficient condition for this to
233 | # be a method
234 | is_probably_method = (
235 | len(arg_names) > 0
236 | and arg_names[0] == "self"
237 | and hasattr(frame.f_locals["self"].__class__, func_name)
238 | )
239 | if is_probably_method:
240 | # handle nested classes via __qualname__
241 | cls_qualname = frame.f_locals["self"].__class__.__qualname__
242 | func_qualname = f"{cls_qualname}.{func_name}"
243 | else:
244 | func_qualname = func_name
245 | return func_qualname
246 |
--------------------------------------------------------------------------------
/mandala/tests/test_versioning.py:
--------------------------------------------------------------------------------
1 | from mandala.imports import *
2 | from mandala.deps.shallow_versions import DAG
3 | from mandala.deps.tracers import DecTracer, SysTracer
4 | from pathlib import Path
5 | import os
6 | import uuid
7 | import pytest
8 |
9 | def test_dag():
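# shallow-version DAG semantics, as exercised below: committing with
# is_semantic_change=True yields a new semantic hash, while
# is_semantic_change=False reuses the previous one; sync() of already-known
# content returns the existing commit and moves the head back to it.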
10 | d = DAG(content_type="code")
11 | try:
12 | d.commit("something")
13 | except AssertionError:
14 | pass
15 |
16 | content_hash_1 = d.init(initial_content="something")
17 | assert len(d.commits) == 1
18 | assert d.head == content_hash_1
19 | content_hash_2 = d.commit(content="something else", is_semantic_change=True)
20 | assert (
21 | d.commits[content_hash_2].semantic_hash
22 | != d.commits[content_hash_1].semantic_hash
23 | )
24 | assert d.head == content_hash_2
25 | content_hash_3 = d.commit(content="something else #2", is_semantic_change=False)
26 | assert (
27 | d.commits[content_hash_3].semantic_hash
28 | == d.commits[content_hash_2].semantic_hash
29 | )
30 |
31 | content_hash_4 = d.sync(content="something else")
32 | assert content_hash_4 == content_hash_2
33 | assert d.head == content_hash_2
34 |
35 | d.show()
36 |
37 | def generate_path(ext: str) -> Path:
38 | output_dir = Path(os.path.dirname(os.path.abspath(__file__)) + "/output")
39 | fname = str(uuid.uuid4()) + ext
40 | return output_dir / fname
41 |
42 | # MODULE_NAME = "mandala.tests.test_versioning"
43 | # DEPS_PACKAGE = "mandala.tests"
44 | DEPS_PATH = Path(__file__).parent.absolute().resolve()
45 | MODULE_NAME = "test_versioning"
46 | # MODULE_NAME = '__main__'
47 |
48 |
49 | def _test_version_reprs(storage: Storage):
50 | for dag in storage.get_versioner().component_dags.values():
51 | for compact in [True, False]:
52 | dag.show(compact=compact)
53 | for version in storage.get_versioner().get_flat_versions().values():
54 | storage.get_versioner().present_dependencies(commits=version.semantic_expansion)
55 | storage.get_versioner().global_topology.show(path=generate_path(ext=".png"))
56 | repr(storage.get_versioner().global_topology)
57 |
58 |
59 | @pytest.mark.parametrize("tracer_impl", [DecTracer])
60 | def test_unit(tracer_impl):
61 | storage = Storage(
62 | deps_path=DEPS_PATH, tracer_impl=tracer_impl
63 | )
64 |
65 | # to be able to import this name
66 | global f_1, A
67 |
68 | A = 42
69 |
70 | @op
71 | def f_1(x) -> int:
72 | return 23 + A
73 |
74 | with storage:
75 | f_1(1)
76 |
77 | vs = storage.get_versioner()
78 | print(vs.versions)
79 | f_1_versions = vs.versions[MODULE_NAME, "f_1"]
80 | assert len(f_1_versions) == 1
81 | version = f_1_versions[list(f_1_versions.keys())[0]]
82 | assert set(version.support) == {(MODULE_NAME, "f_1"), (MODULE_NAME, "A")}
83 | _test_version_reprs(storage=storage)
84 |
85 | @pytest.mark.parametrize("tracer_impl", [DecTracer])
86 | def test_deps(tracer_impl):
87 | storage = Storage(
88 | deps_path=DEPS_PATH, tracer_impl=tracer_impl
89 | )
90 | global dep_1, f_2, A
91 | if tracer_impl is SysTracer:
92 | track = lambda x: x
93 | else:
94 | from mandala.deps.tracers.dec_impl import track
95 |
96 | A = 42
97 |
98 | @track
99 | def dep_1(x) -> int:
100 | return 23
101 |
102 | @track
103 | @op
104 | def f_2(x) -> int:
105 | return dep_1(x) + A
106 |
107 | with storage:
108 | f_2(1)
109 |
110 | vs = storage.get_versioner()
111 | f_2_versions = vs.versions[MODULE_NAME, "f_2"]
112 | assert len(f_2_versions) == 1
113 | version = f_2_versions[list(f_2_versions.keys())[0]]
114 | assert set(version.support) == {
115 | (MODULE_NAME, "f_2"),
116 | (MODULE_NAME, "dep_1"),
117 | (MODULE_NAME, "A"),
118 | }
119 | _test_version_reprs(storage=storage)
120 |
121 |
122 | @pytest.mark.parametrize("tracer_impl", [DecTracer])
123 | def test_changes(tracer_impl):
124 | storage = Storage(
125 | deps_path=DEPS_PATH, tracer_impl=tracer_impl
126 | )
127 |
128 | global f
129 |
130 | @op
131 | def f(x) -> int:
132 | return x + 1
133 |
134 | with storage:
135 | f(1)
136 | commit_1 = storage.sync_component(
137 | component=f,
138 | is_semantic_change=None,
139 | )
140 |
141 | @op
142 | def f(x) -> int:
143 | return x + 2
144 |
145 | commit_2 = storage.sync_component(component=f, is_semantic_change=True)
146 | assert commit_1 != commit_2
147 | with storage:
148 | f(1)
149 |
150 | @op
151 | def f(x) -> int:
152 | return x + 1
153 |
154 | # confirm we reverted to the previous version
155 | commit_3 = storage.sync_component(
156 | component=f,
157 | is_semantic_change=None,
158 | )
159 | assert commit_3 == commit_1
160 | with storage:
161 | f(1)
162 |
163 | # create a new branch
164 | @op
165 | def f(x) -> int:
166 | return x + 3
167 |
168 | commit_4 = storage.sync_component(
169 | component=f,
170 | is_semantic_change=True,
171 | )
172 | assert commit_4 not in (commit_1, commit_2)
173 | with storage:
174 | f(1)
175 |
176 | f_versions = storage.get_versioner().versions[MODULE_NAME, "f"]
177 | assert len(f_versions) == 3
178 | semantic_versions = [v.semantic_version for v in f_versions.values()]
179 | assert len(set(semantic_versions)) == 3
180 | _test_version_reprs(storage=storage)
181 |
182 |
183 | @pytest.mark.parametrize("tracer_impl", [DecTracer])
184 | def _test_dependency_patterns(tracer_impl):
185 | # this test is currently broken; the leading underscore keeps pytest from running it
186 | storage = Storage(
187 | deps_path=DEPS_PATH, tracer_impl=tracer_impl
188 | )
189 | global A, B, f_1, f_2, f_3, f_4, f_5, f_6
190 | if tracer_impl is SysTracer:
191 | track = lambda x: x
192 | else:
193 | from mandala.deps.tracers.dec_impl import track
194 |
195 | # global vars
196 | A = 23
197 | B = [1, 2, 3]
198 |
199 | # using a global var
200 | @track
201 | def f_1(x) -> int:
202 | return x + A
203 |
204 | # calling another function
205 | @track
206 | def f_2(x) -> int:
207 | return f_1(x) + B[0]
208 |
209 | # different dependencies per call
210 | @track
211 | @op
212 | def f_3(x) -> int:
213 | if x % 2 == 0:
214 | return f_2(2 * x)
215 | else:
216 | return f_1(x + 1)
217 |
218 | with storage:
219 | x = f_3(0)
220 |
221 | call = storage.get_ref_creator(x)
222 | version = storage.get_versioner().get_flat_versions()[call.content_version]
223 | assert version.support == {
224 | (MODULE_NAME, "f_3"),
225 | (MODULE_NAME, "f_2"),
226 | (MODULE_NAME, "A"),
227 | (MODULE_NAME, "B"),
228 | (MODULE_NAME, "f_1"),
229 | }
230 | with storage:
231 | x = f_3(1)
232 | call = storage.get_ref_creator(x)
233 | version = storage.get_versioner().get_flat_versions()[call.content_version]
234 | assert version.support == {
235 | (MODULE_NAME, "f_3"),
236 | (MODULE_NAME, "f_1"),
237 | (MODULE_NAME, "A"),
238 | }
239 |
240 | # using a lambda
241 | @track
242 | @op
243 | def f_4(x) -> int:
244 | f = lambda y: f_1(y) + B[0]
245 | return f(x)
246 |
247 | # make sure the call in the lambda is detected
248 | with storage:
249 | x = f_4(10)
250 | call = storage.get_ref_creator(x)
251 | version = storage.get_versioner().get_flat_versions()[call.content_version]
252 | assert version.support == {
253 | (MODULE_NAME, "f_4"),
254 | (MODULE_NAME, "f_1"),
255 | (MODULE_NAME, "A"),
256 | (MODULE_NAME, "B"),
257 | }
258 |
259 | # using comprehensions and generators
260 | @op
261 | def f_5(x) -> int:
262 | x = storage.unwrap(x)
263 | a = [f_1.f(y) for y in range(x)]
264 | b = {f_2.f(y) for y in range(x)}
265 | c = {y: f_3.f(y) for y in range(x)}
266 | return sum(storage.unwrap(f_4.f(y)) for y in range(x))
267 |
268 | with storage:
269 | f_5(10)
270 |
271 | f_5_versions = storage.get_versioner().versions[MODULE_NAME, "f_5"]
272 | assert len(f_5_versions) == 1
273 | version = f_5_versions[list(f_5_versions.keys())[0]]
274 | assert set(version.support) == {
275 | (MODULE_NAME, "f_5"),
276 | (MODULE_NAME, "f_4"),
277 | (MODULE_NAME, "f_3"),
278 | (MODULE_NAME, "f_2"),
279 | (MODULE_NAME, "f_1"),
280 | (MODULE_NAME, "A"),
281 | (MODULE_NAME, "B"),
282 | }
283 |
284 | # nested comprehensions and generators
285 | @op
286 | def f_6(x) -> int:
287 | x = storage.unwrap(x)
288 | # nested list comprehension
289 | a = sum([sum([f_1.f(y) for y in range(x)]) for z in range(x)])
290 | # nested comprehension with generator
291 | b = sum(sum(f_2.f(y) for y in range(x)) for z in range(storage.unwrap(f_3.f(x))))
292 | return a + b
293 |
294 | with storage:
295 | f_6(2)
296 |
297 | f_6_versions = storage.get_versioner().versions[MODULE_NAME, "f_6"]
298 | assert len(f_6_versions) == 1
299 | version = f_6_versions[list(f_6_versions.keys())[0]]
300 | assert set(version.support) == {
301 | (MODULE_NAME, "f_6"),
302 | (MODULE_NAME, "f_3"),
303 | (MODULE_NAME, "f_2"),
304 | (MODULE_NAME, "f_1"),
305 | (MODULE_NAME, "A"),
306 | (MODULE_NAME, "B"),
307 | }
308 | _test_version_reprs(storage=storage)
309 | storage.versions(f_6)
310 | storage.get_code(version_id=version.content_version)
311 |
312 |
--------------------------------------------------------------------------------
/docs/docs/topics/03_cf_files/03_cf_29_0.svg:
--------------------------------------------------------------------------------
[SVG image omitted]
--------------------------------------------------------------------------------
/mandala/deps/tracers/sys_impl.py:
--------------------------------------------------------------------------------
1 | import types
2 | from ...common_imports import *
3 | from ..utils import (
4 | get_func_qualname,
5 | is_global_val,
6 | get_global_names_candidates,
7 | )
8 | from ..model import (
9 | DependencyGraph,
10 | CallableNode,
11 | TerminalData,
12 | TerminalNode,
13 | GlobalVarNode,
14 | )
15 | import sys
16 | import importlib
17 | from .tracer_base import TracerABC, get_closure_names
18 |
19 | ################################################################################
20 | ### tracer
21 | ################################################################################
22 | # control flow constants
23 | from .tracer_base import BREAK, CONTINUE, KEEP, MAIN, get_module_flow
24 |
25 | LEAF_SIGNAL = "leaf_signal"
26 |
27 | # constants for special Python function names
28 | LAMBDA = "<lambda>"
29 | COMPREHENSIONS = ("<listcomp>", "<setcomp>", "<dictcomp>", "<genexpr>")
30 | SKIP_FRAMES = tuple(list(COMPREHENSIONS) + [LAMBDA])
31 |
32 |
33 | class SysTracer(TracerABC):
34 | def __init__(
35 | self,
36 | paths: List[Path],
37 | graph: Optional[DependencyGraph] = None,
38 | strict: bool = True,
39 | allow_methods: bool = False,
40 | ):
41 | self.call_stack: List[Optional[CallableNode]] = []
42 | self.graph = DependencyGraph() if graph is None else graph
43 | self.paths = paths
44 | self.path_strs = [str(path) for path in paths]
45 | self.strict = strict
46 | self.allow_methods = allow_methods
47 |
48 | @staticmethod
49 | def leaf_signal(data):
50 | # a way to detect the end of a trace
51 | pass
52 |
53 | @staticmethod
54 | def register_leaf_event(trace_obj: types.FunctionType, data: Any):
55 | SysTracer.leaf_signal(data)
56 |
57 | @staticmethod
58 | def get_active_trace_obj() -> Optional[Any]:
59 | return sys.gettrace()
60 |
61 | @staticmethod
62 | def set_active_trace_obj(trace_obj: Any):
63 | sys.settrace(trace_obj)
64 |
65 | def _process_failure(self, msg: str):
66 | if self.strict:
67 | raise RuntimeError(msg)
68 | else:
69 | logger.warning(msg)
70 |
71 | def find_most_recent_call(self) -> Optional[CallableNode]:
72 | if len(self.call_stack) == 0:
73 | return None
74 | else:
75 | # return the most recent non-None obj on the stack
76 | for i in range(len(self.call_stack) - 1, -1, -1):
77 | call = self.call_stack[i]
78 | if isinstance(call, CallableNode):
79 | return call
80 | return None
81 |
82 | def __enter__(self):
83 | if sys.gettrace() is not None:
84 | # pre-check this is used correctly
85 | raise RuntimeError("Another tracer is already active")
86 |
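# standard `sys.settrace` protocol: the callback receives "call",
# "return" (and other) events; returning the tracer itself from a
# "call" event keeps tracing enabled inside the new frame.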
87 | def tracer(frame: types.FrameType, event: str, arg: Any):
88 | if event not in ("call", "return"):
89 | return
90 | module_name = frame.f_globals.get("__name__")
91 | # fast check to rule out non-user code
92 | if event == "call":
93 | try:
94 | module = importlib.import_module(module_name)
95 | if not any(
96 | [
97 | module.__file__.startswith(path_str)
98 | for path_str in self.path_strs
99 | ]
100 | ):
101 | return
102 | except Exception:
103 | if module_name != MAIN:
104 | return
105 | code_obj = frame.f_code
106 | func_name = code_obj.co_name
107 | if event == "return":
108 | logger.debug(f"Returning from {func_name}")
109 | if len(self.call_stack) > 0:
110 | popped = self.call_stack.pop()
111 | logger.debug(f"Popped {popped} from call stack")
112 | # some sanity checks
113 | if func_name in SKIP_FRAMES:
114 | if popped != func_name:
115 | self._process_failure(
116 | f"Expected to pop {func_name} from call stack, but popped {popped}"
117 | )
118 | else:
119 | if popped.obj_name.split(".")[-1] != func_name:
120 | self._process_failure(
121 | f"Expected to pop {func_name} from call stack, but popped {popped.obj_name}"
122 | )
123 | else:
124 | # something went wrong
125 | raise RuntimeError("Call stack is empty")
126 | return
127 |
128 | if func_name == LEAF_SIGNAL:
129 | data: TerminalData = frame.f_locals["data"]
130 | unique_id = "_".join(
131 | [
132 | data.op_internal_name,
133 | str(data.op_version),
134 | data.call_content_version,
135 | data.call_semantic_version,
136 | ]
137 | )
138 | node = TerminalNode(
139 | module_name=module_name, obj_name=unique_id, representation=data
140 | )
141 | most_recent_option = self.find_most_recent_call()
142 | if most_recent_option is not None:
143 | self.graph.add_edge(source=most_recent_option, target=node)
144 | # self.call_stack.append(None)
145 | return
146 |
147 | module_control_flow = get_module_flow(
148 | paths=self.paths, module_name=module_name
149 | )
150 | if module_control_flow in (BREAK, CONTINUE):
151 | frame.f_trace = None
152 | return
153 |
154 | logger.debug(f"Tracing call to {module_name}.{func_name}")
155 |
156 | ### get the qualified name of the function/method
157 | func_qualname = get_func_qualname(
158 | func_name=func_name, code=code_obj, frame=frame
159 | )
160 | if "." in func_qualname:
161 | if not self.allow_methods:
162 | raise RuntimeError(
163 | f"Methods are currently not supported: {func_qualname} from {module_name}"
164 | )
165 |
166 | ### detect use of closure variables
167 | closure_names = get_closure_names(
168 | code_obj=code_obj, func_qualname=func_qualname
169 | )
170 | if len(closure_names) > 0 and func_name not in SKIP_FRAMES:
171 | closure_values = {
172 | var: frame.f_locals.get(var, frame.f_globals.get(var, None))
173 | for var in closure_names
174 | }
175 | msg = f"Found closure variables accessed by function {module_name}.{func_name}:\n{closure_values}"
176 | self._process_failure(msg=msg)
177 |
178 | ### get the global variables used by the function
179 | globals_nodes = []
180 | for name in get_global_names_candidates(code=code_obj):
181 | # names used by the function; not all of them are global variables
182 | if name in frame.f_globals:
183 | global_val = frame.f_globals[name]
184 | if not is_global_val(global_val):
185 | continue
186 | node = GlobalVarNode.from_obj(
187 | obj=global_val, dep_key=(module_name, name)
188 | )
189 | globals_nodes.append(node)
190 |
191 | ### if this is a comprehension call, add the globals to the most
192 | ### recent tracked call
193 | if func_name in SKIP_FRAMES:
194 | most_recent_tracked_call = self.find_most_recent_call()
195 | assert most_recent_tracked_call is not None
196 | for global_node in globals_nodes:
197 | self.graph.add_edge(
198 | source=most_recent_tracked_call, target=global_node
199 | )
200 | self.call_stack.append(func_name)
201 | return tracer
202 |
203 | ### manage the call stack
204 | call_node = CallableNode.from_runtime(
205 | module_name=module_name, obj_name=func_qualname, code_obj=code_obj
206 | )
207 | self.graph.add_node(node=call_node)
208 | ### global variable edges from this function always exist
209 | for global_node in globals_nodes:
210 | self.graph.add_edge(source=call_node, target=global_node)
211 | ### call edges exist only if there is a caller on the stack
212 | if len(self.call_stack) > 0:
213 | # find the most recent tracked call
214 | most_recent_tracked_call = self.find_most_recent_call()
215 | if most_recent_tracked_call is not None:
216 | self.graph.add_edge(
217 | source=most_recent_tracked_call, target=call_node
218 | )
219 | self.call_stack.append(call_node)
220 | if len(self.call_stack) == 1:
221 | self.graph.roots.add(call_node.key)
222 | return tracer
223 |
224 | sys.settrace(tracer)
225 |
226 | def __exit__(self, *exc_info):
227 | sys.settrace(None) # Stop tracing
228 |
229 |
230 | class SuspendSysTraceContext:
231 | def __init__(self):
232 | self.suspended_trace = None
233 |
234 | def __enter__(self) -> "SuspendSysTraceContext":
235 | if sys.gettrace() is not None:
236 | self.suspended_trace = sys.gettrace()
237 | sys.settrace(None)
238 | return self
239 |
240 | def __exit__(self, *exc_info):
241 | if self.suspended_trace is not None:
242 | sys.settrace(self.suspended_trace)
243 | self.suspended_trace = None
244 |
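# Illustrative usage sketch (not from the library's docs): wrap any work that
# must not be traced, e.g.
#
#     with SuspendSysTraceContext():
#         ...  # sys.settrace-based tracing is paused here
#
# the previous tracer, if any, is restored on exit.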
--------------------------------------------------------------------------------
/docs/docs/tutorials/02_ml_files/02_ml_22_3.svg:
--------------------------------------------------------------------------------
[SVG image omitted]
--------------------------------------------------------------------------------
/docs/docs/blog/01_cf_files/01_cf_21_0.svg:
--------------------------------------------------------------------------------
[SVG image omitted]
--------------------------------------------------------------------------------
/mandala/deps/model.py:
--------------------------------------------------------------------------------
1 | import textwrap
2 | from abc import abstractmethod, ABC
3 | import types
4 |
5 | from ..common_imports import *
6 | from ..utils import get_content_hash
7 | from ..viz import (
8 | write_output,
9 | )
10 | from ..model import Ref
11 |
12 | from .utils import (
13 | DepKey,
14 | load_obj,
15 | get_runtime_description,
16 | extract_code,
17 | unknown_function,
18 | UNKNOWN_GLOBAL_VAR,
19 | )
20 |
21 |
22 | class Node(ABC):
23 | def __init__(self, module_name: str, obj_name: str, representation: Any):
24 | self.module_name = module_name
25 | self.obj_name = obj_name
26 | self.representation = representation
27 |
28 | @property
29 | def key(self) -> DepKey:
30 | return (self.module_name, self.obj_name)
31 |
32 | def present_key(self) -> str:
33 | raise NotImplementedError()
34 |
35 | @staticmethod
36 | @abstractmethod
37 | def represent(obj: Any) -> Any:
38 | raise NotImplementedError
39 |
40 | @abstractmethod
41 | def content(self) -> Any:
42 | raise NotImplementedError
43 |
44 | @abstractmethod
45 | def readable_content(self) -> str:
46 | raise NotImplementedError
47 |
48 | @property
49 | @abstractmethod
50 | def content_hash(self) -> str:
51 | raise NotImplementedError
52 |
53 | def load_obj(self, skip_missing: bool, skip_silently: bool) -> Any:
54 | obj, found = load_obj(module_name=self.module_name, obj_name=self.obj_name)
55 | if not found:
56 | msg = f"{self.present_key()} not found"
57 | if skip_missing:
58 | if skip_silently:
59 | logger.debug(msg)
60 | else:
61 | logger.warning(msg)
62 | if hasattr(self, "FALLBACK_OBJ"):
63 | return self.FALLBACK_OBJ
64 | else:
65 | raise ValueError(f"No fallback object defined for {self.__class__}")
66 | else:
67 | raise ValueError(msg)
68 | return obj
69 |
70 |
71 | class CallableNode(Node):
72 | FALLBACK_OBJ = unknown_function
73 |
74 | def __init__(
75 | self,
76 | module_name: str,
77 | obj_name: str,
78 | representation: Optional[str],
79 | runtime_description: str,
80 | ):
81 | self.module_name = module_name
82 | self.obj_name = obj_name
83 | self.runtime_description = runtime_description
84 | if representation is not None:
85 | self._set_representation(value=representation)
86 | else:
87 | self._representation = None
88 | self._content_hash = None
89 |
90 | @staticmethod
91 | def from_obj(obj: Any, dep_key: DepKey) -> "CallableNode":
92 | representation = CallableNode.represent(obj=obj)
93 | code_obj = extract_code(obj)
94 | runtime_description = get_runtime_description(code=code_obj)
95 | return CallableNode(
96 | module_name=dep_key[0],
97 | obj_name=dep_key[1],
98 | representation=representation,
99 | runtime_description=runtime_description,
100 | )
101 |
102 | @staticmethod
103 | def from_runtime(
104 | module_name: str,
105 | obj_name: str,
106 | code_obj: types.CodeType,
107 | ) -> "CallableNode":
108 | return CallableNode(
109 | module_name=module_name,
110 | obj_name=obj_name,
111 | representation=None,
112 | runtime_description=get_runtime_description(code=code_obj),
113 | )
114 |
115 | @property
116 | def representation(self) -> str:
117 | return self._representation
118 |
119 | def _set_representation(self, value: str):
120 | assert isinstance(value, str)
121 | self._representation = value
122 | self._content_hash = get_content_hash(value)
123 |
124 | @representation.setter
125 | def representation(self, value: str):
126 | self._set_representation(value)
127 |
128 | @property
129 | def is_method(self) -> bool:
130 | return "." in self.obj_name
131 |
132 | def present_key(self) -> str:
133 | return f"function {self.obj_name} from module {self.module_name}"
134 |
135 | @property
136 | def class_name(self) -> str:
137 | assert self.is_method
138 | return ".".join(self.obj_name.split(".")[:-1])
139 |
140 | @staticmethod
141 | def represent(
142 | obj: Union[types.FunctionType, types.CodeType, Callable],
143 | allow_fallback: bool = False,
144 | ) -> str:
145 | if type(obj).__name__ == "Op":
146 | obj = obj.f
147 | if not isinstance(obj, (types.FunctionType, types.MethodType, types.CodeType)):
148 | logger.warning(f"Found {obj} of type {type(obj)}")
149 | try:
150 | source = inspect.getsource(obj)
151 | except Exception as e:
152 | msg = f"Could not get source for {obj} because {e}"
153 | if allow_fallback:
154 | source = inspect.getsource(CallableNode.FALLBACK_OBJ)
155 | logger.warning(msg)
156 | else:
157 | raise RuntimeError(msg)
158 | # strip trailing whitespace so that sources which look identical in
159 | # the UI are not treated as different versions
160 | lines = source.splitlines()
161 | lines = [line.rstrip() for line in lines]
162 | source = "\n".join(lines)
163 | return source
164 |
165 | def content(self) -> str:
166 | return self.representation
167 |
168 | def readable_content(self) -> str:
169 | return self.representation
170 |
171 | @property
172 | def content_hash(self) -> str:
173 | assert isinstance(self._content_hash, str)
174 | return self._content_hash
175 |
176 |
177 | class GlobalVarNode(Node):
178 | FALLBACK_OBJ = UNKNOWN_GLOBAL_VAR
179 |
180 | def __init__(
181 | self,
182 | module_name: str,
183 | obj_name: str,
184 | # (content hash, truncated repr)
185 | representation: Tuple[str, str],
186 | ):
187 | self.module_name = module_name
188 | self.obj_name = obj_name
189 | self._representation = representation
190 |
191 | @staticmethod
192 | def from_obj(obj: Any, dep_key: DepKey,
193 | skip_unhashable: bool = False,
194 | skip_silently: bool = False,) -> "GlobalVarNode":
195 | representation = GlobalVarNode.represent(obj=obj, skip_unhashable=skip_unhashable, skip_silently=skip_silently)
196 | return GlobalVarNode(
197 | module_name=dep_key[0],
198 | obj_name=dep_key[1],
199 | representation=representation,
200 | )
201 |
202 | @property
203 | def representation(self) -> Tuple[str, str]:
204 | return self._representation
205 |
206 | @staticmethod
207 | def represent(obj: Any, skip_unhashable: bool = True,
208 | skip_silently: bool = False,
209 | ) -> Tuple[str, str]:
210 | """
211 | Return a hash of this global variable's value + a truncated
212 | representation useful for debugging/printing.
213 |
214 | If `obj` is a `Ref`, the content hash is reused from the `Ref` object.
215 | This is so that you can avoid repeatedly hashing the same (potentially
216 | large) object any time the code state needs to be synced.
217 | """
218 | truncated_repr = textwrap.shorten(text=repr(obj), width=80)
219 | if isinstance(obj, Ref):
220 | content_hash = obj.cid
221 | else:
222 | try:
223 | content_hash = get_content_hash(obj=obj)
224 | except Exception as e:
225 | shortened_exception = textwrap.shorten(text=str(e), width=80)
226 | msg = f"Failed to hash global variable {truncated_repr} of type {type(obj)}, because {shortened_exception}"
227 | if skip_unhashable:
228 | content_hash = UNKNOWN_GLOBAL_VAR
229 | if skip_silently:
230 | logger.debug(msg)
231 | else:
232 | logger.warning(msg)
233 | else:
234 | raise RuntimeError(msg)
235 | return content_hash, truncated_repr
236 |
237 | def present_key(self) -> str:
238 | return f"global variable {self.obj_name} from module {self.module_name}"
239 |
240 | def content(self) -> str:
241 | return self.representation
242 |
243 | def readable_content(self) -> str:
244 | return self.representation[1]
245 |
246 | @property
247 | def content_hash(self) -> str:
248 | assert isinstance(self.representation, tuple)
249 | return self.representation[0]
250 |
251 |
252 | class TerminalData:
253 | def __init__(
254 | self,
255 | op_internal_name: str,
256 | op_version: int,
257 | call_content_version: str,
258 | call_semantic_version: str,
259 | # data: Tuple[Tuple[str, int], Tuple[str, str]],
260 | dep_key: DepKey,
261 | ):
262 | # ((internal name, version), (content_version, semantic_version))
263 | self.op_internal_name = op_internal_name
264 | self.op_version = op_version
265 | self.call_content_version = call_content_version
266 | self.call_semantic_version = call_semantic_version
267 | self.dep_key = dep_key
268 |
269 |
270 | class TerminalNode(Node):
271 | def __init__(self, module_name: str, obj_name: str, representation: TerminalData):
272 | self.module_name = module_name
273 | self.obj_name = obj_name
274 | self.representation = representation
275 |
276 | @property
277 | def key(self) -> DepKey:
278 | return self.module_name, self.obj_name
279 |
280 | def present_key(self) -> str:
281 | raise NotImplementedError
282 |
283 | @property
284 | def content_hash(self) -> str:
285 | raise NotImplementedError
286 |
287 | def content(self) -> Any:
288 | raise NotImplementedError
289 |
290 | def readable_content(self) -> str:
291 | raise NotImplementedError
292 |
293 | @staticmethod
294 | def represent(obj: Any) -> Any:
295 | raise NotImplementedError
296 |
297 |
298 | class DependencyGraph:
299 | def __init__(self):
300 | self.nodes: Dict[DepKey, Node] = {}
301 | self.roots: Set[DepKey] = set()
302 | self.edges: Set[Tuple[DepKey, DepKey]] = set()
303 |
304 | def get_trace_state(self) -> Tuple[DepKey, Dict[DepKey, Node]]:
305 | if len(self.roots) != 1:
306 | raise ValueError(f"Expected exactly one root, got {len(self.roots)}")
307 | component = list(self.roots)[0]
308 | return component, self.nodes
309 |
310 | def show(self, path: Optional[Path] = None, how: str = "none"):
311 | dot = to_dot(self)
312 | output_ext = "svg" if how in ["browser"] else "png"
313 | return write_output(
314 | dot_string=dot, output_path=path, output_ext=output_ext, show_how=how
315 | )
316 |
317 | def __repr__(self) -> str:
318 | if len(self.nodes) == 0:
319 | return "DependencyGraph()"
320 | return to_string(self)
321 |
322 | def add_node(self, node: Node):
323 | self.nodes[node.key] = node
324 |
325 | def add_edge(self, source: Node, target: Node):
326 | if source.key not in self.nodes:
327 | self.add_node(source)
328 | if target.key not in self.nodes:
329 | self.add_node(target)
330 | self.edges.add((source.key, target.key))
331 |
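# Illustrative usage (assuming two nodes built elsewhere, e.g. via
# CallableNode.from_obj):
#
#     graph = DependencyGraph()
#     graph.add_edge(source=caller_node, target=callee_node)
#
# add_edge inserts any missing endpoints into `nodes` before recording the edge.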
332 |
333 | from .viz import to_dot, to_string
334 |
--------------------------------------------------------------------------------
/docs/docs/tutorials/02_ml_files/02_ml_22_1.svg:
--------------------------------------------------------------------------------
[SVG image omitted]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2022 mandala authors
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/docs_source/topics/01_storage_and_ops.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# `Storage` & the `@op` Decorator\n",
8 | " \n",
9 | "
\n",
10 | "\n",
11 | "A `Storage` object holds all data (saved calls, code and dependencies) for a\n",
12 | "collection of memoized functions. In a given project, you should have just one\n",
13 | "`Storage` and many `@op`s connected to it. This way, the calls to memoized\n",
14 | "functions create a queriable web of interlinked objects. "
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {
21 | "execution": {
22 | "iopub.execute_input": "2024-07-11T14:31:35.163619Z",
23 | "iopub.status.busy": "2024-07-11T14:31:35.162794Z",
24 | "iopub.status.idle": "2024-07-11T14:31:35.173665Z",
25 | "shell.execute_reply": "2024-07-11T14:31:35.172980Z"
26 | }
27 | },
28 | "outputs": [],
29 | "source": [
30 | "# for Google Colab\n",
31 | "try:\n",
32 | " import google.colab\n",
33 | " !pip install git+https://github.com/amakelov/mandala\n",
34 | "except:\n",
35 | " pass"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "## Creating a `Storage`\n",
43 | "\n",
44 | "When creating a storage, you must decide if it will be in-memory or persisted on\n",
45 | "disk, and whether the storage will automatically version the `@op`s used with\n",
46 | "it:"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 2,
52 | "metadata": {
53 | "execution": {
54 | "iopub.execute_input": "2024-07-11T14:31:35.176299Z",
55 | "iopub.status.busy": "2024-07-11T14:31:35.176085Z",
56 | "iopub.status.idle": "2024-07-11T14:31:37.876054Z",
57 | "shell.execute_reply": "2024-07-11T14:31:37.875495Z"
58 | }
59 | },
60 | "outputs": [],
61 | "source": [
62 | "from mandala.imports import Storage\n",
63 | "import os\n",
64 | "\n",
65 | "DB_PATH = 'my_persistent_storage.db'\n",
66 | "if os.path.exists(DB_PATH):\n",
67 | " os.remove(DB_PATH)\n",
68 | "\n",
69 | "storage = Storage(\n",
70 | " # omit for an in-memory storage\n",
71 | " db_path=DB_PATH,\n",
72 | " # omit to disable automatic dependency tracking & versioning\n",
73 | " # use \"__main__\" to only track functions defined in the current session\n",
74 | " deps_path='__main__', \n",
75 | ")"
76 | ]
77 | },
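{
"cell_type": "markdown",
"metadata": {},
"source": [
"For quick experiments where nothing needs to survive the session, a minimal\n",
"sketch of the in-memory variant, per the comments above, is to omit both\n",
"arguments:\n",
"\n",
"```python\n",
"# a throwaway storage: in-memory, no dependency tracking or versioning\n",
"temp_storage = Storage()\n",
"```"
]
},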
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {},
81 | "source": [
82 | "## Creating `@op`s and saving calls to them\n",
83 | "**Any Python function can be decorated with `@op`**:"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 3,
89 | "metadata": {
90 | "execution": {
91 | "iopub.execute_input": "2024-07-11T14:31:37.880287Z",
92 | "iopub.status.busy": "2024-07-11T14:31:37.879681Z",
93 | "iopub.status.idle": "2024-07-11T14:31:37.914152Z",
94 | "shell.execute_reply": "2024-07-11T14:31:37.913319Z"
95 | }
96 | },
97 | "outputs": [],
98 | "source": [
99 | "from mandala.imports import op\n",
100 | "\n",
101 | "@op \n",
102 | "def sum_args(a, *args, b=1, **kwargs):\n",
103 | " return a + sum(args) + b + sum(kwargs.values())"
104 | ]
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "metadata": {},
109 | "source": [
110 | "In general, calling `sum_args` will behave as if the `@op` decorator is not\n",
111 | "there. `@op`-decorated functions will interact with a `Storage` instance **only\n",
112 | "when** called inside a `with storage:` block:"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": 4,
118 | "metadata": {
119 | "execution": {
120 | "iopub.execute_input": "2024-07-11T14:31:37.918828Z",
121 | "iopub.status.busy": "2024-07-11T14:31:37.918430Z",
122 | "iopub.status.idle": "2024-07-11T14:31:38.022162Z",
123 | "shell.execute_reply": "2024-07-11T14:31:38.021195Z"
124 | }
125 | },
126 | "outputs": [
127 | {
128 | "name": "stdout",
129 | "output_type": "stream",
130 | "text": [
131 | "AtomRef(42, hid=168...)\n"
132 | ]
133 | }
134 | ],
135 | "source": [
136 | "with storage: # all `@op` calls inside this block use `storage`\n",
137 | " s = sum_args(6, 7, 8, 9, c=11,)\n",
138 | " print(s)"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "metadata": {},
144 | "source": [
145 | "This code runs the call to `sum_args`, and saves the inputs and outputs in the\n",
146 | "`storage` object, so that doing the same call later will directly load the saved\n",
147 | "outputs."
148 | ]
149 | },
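{
"cell_type": "markdown",
"metadata": {},
"source": [
"By contrast, here is a minimal sketch of the same call made *outside* any\n",
"`with storage:` block; as noted above, the decorator then gets out of the\n",
"way and nothing is saved:\n",
"\n",
"```python\n",
"# no `with storage:` block -> an ordinary Python call, no memoization\n",
"print(sum_args(6, 7, 8, 9, c=11))  # a plain value, not a `Ref`\n",
"```"
]
},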
150 | {
151 | "cell_type": "markdown",
152 | "metadata": {},
153 | "source": [
154 | "### When should something be an `@op`?\n",
155 | "As a general guide, you should make something an `@op` if you want to save its\n",
156 | "outputs, e.g. if they take a long time to compute but you need them for later\n",
157 | "analysis. Since `@op` [encourages\n",
158 | "composition](https://amakelov.github.io/mandala/02_retracing/#how-op-encourages-composition),\n",
159 | "you should aim to have `@op`s work on the outputs of other `@op`s, or on the\n",
160 | "[collections and/or items](https://amakelov.github.io/mandala/05_collections/)\n",
161 | "of outputs of other `@op`s."
162 | ]
163 | },
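{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of such composition (the `@op` names and bodies\n",
"are hypothetical, for illustration only): each step consumes the `Ref`\n",
"outputs of the previous one, so the whole pipeline ends up in the saved\n",
"web of calls:\n",
"\n",
"```python\n",
"@op\n",
"def load_data(path):\n",
"    ...  # hypothetical expensive loading step\n",
"\n",
"@op\n",
"def train_model(data, n_epochs=10):\n",
"    ...  # hypothetical expensive training step\n",
"\n",
"with storage:\n",
"    data = load_data('data.csv')  # returns a `Ref`\n",
"    model = train_model(data)     # `Ref`s can be passed directly to `@op`s\n",
"```"
]
},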
164 | {
165 | "cell_type": "markdown",
166 | "metadata": {},
167 | "source": [
168 | "## Working with `@op` outputs (`Ref`s)\n",
169 | "The objects (e.g. `s`) returned by `@op`s are always instances of a subclass of\n",
170 | "`Ref` (e.g., `AtomRef`), i.e. **references to objects in the storage**. Every\n",
171 | "`Ref` contains two metadata fields:\n",
172 | "\n",
173 | "- `cid`: a hash of the **content** of the object\n",
174 | "- `hid`: a hash of the **computational history** of the object, which is the precise\n",
175 | "composition of `@op`s that created this ref. \n",
176 | "\n",
177 | "Two `Ref`s with the same `cid` may have different `hid`s, and `hid` is the\n",
178 | "unique identifier of `Ref`s in the storage. However, only 1 copy per unique\n",
179 | "`cid` is stored to avoid duplication in the storage.\n",
180 | "\n",
181 | "### `Ref`s can be in memory or not\n",
182 | "Additionally, `Ref`s have the `in_memory` property, which indicates if the\n",
183 | "underlying object is present in the `Ref` or if this is a \"lazy\" `Ref` which\n",
184 | "only contains metadata. **`Ref`s are only loaded in memory when needed for a new\n",
185 | "call to an `@op`**. For example, re-running the last code block:"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 5,
191 | "metadata": {
192 | "execution": {
193 | "iopub.execute_input": "2024-07-11T14:31:38.032711Z",
194 | "iopub.status.busy": "2024-07-11T14:31:38.032185Z",
195 | "iopub.status.idle": "2024-07-11T14:31:38.077234Z",
196 | "shell.execute_reply": "2024-07-11T14:31:38.076082Z"
197 | }
198 | },
199 | "outputs": [
200 | {
201 | "name": "stdout",
202 | "output_type": "stream",
203 | "text": [
204 | "AtomRef(hid=168..., in_memory=False)\n"
205 | ]
206 | }
207 | ],
208 | "source": [
209 | "with storage: \n",
210 | " s = sum_args(6, 7, 8, 9, c=11,)\n",
211 | " print(s)"
212 | ]
213 | },
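{
"cell_type": "markdown",
"metadata": {},
"source": [
"Returning to `cid` vs `hid`, here is a minimal sketch of how they can\n",
"diverge: both calls below produce the content `42` (hence the same `cid`),\n",
"but via different compositions of inputs, so we'd expect distinct `hid`s:\n",
"\n",
"```python\n",
"with storage:\n",
"    x = sum_args(20, 21)  # 20 + 21 + (b = 1) == 42\n",
"    y = sum_args(19, 22)  # 19 + 22 + (b = 1) == 42\n",
"# same content, different computational histories\n",
"assert x.cid == y.cid and x.hid != y.hid\n",
"```"
]
},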
214 | {
215 | "cell_type": "markdown",
216 | "metadata": {},
217 | "source": [
218 | "To get the object wrapped by a `Ref`, call `storage.unwrap`:"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": 6,
224 | "metadata": {
225 | "execution": {
226 | "iopub.execute_input": "2024-07-11T14:31:38.081995Z",
227 | "iopub.status.busy": "2024-07-11T14:31:38.081258Z",
228 | "iopub.status.idle": "2024-07-11T14:31:38.115821Z",
229 | "shell.execute_reply": "2024-07-11T14:31:38.114805Z"
230 | }
231 | },
232 | "outputs": [
233 | {
234 | "data": {
235 | "text/plain": [
236 | "42"
237 | ]
238 | },
239 | "execution_count": 6,
240 | "metadata": {},
241 | "output_type": "execute_result"
242 | }
243 | ],
244 | "source": [
245 | "storage.unwrap(s) # loads from storage only if necessary"
246 | ]
247 | },
248 | {
249 | "cell_type": "markdown",
250 | "metadata": {},
251 | "source": [
252 | "### Other useful `Storage` methods\n",
253 | "\n",
254 | "- `Storage.attach(inplace: bool)`: like `unwrap`, but puts the objects in the\n",
255 | "`Ref`s if they are not in-memory.\n",
256 | "- `Storage.load_ref(hid: str, in_memory: bool)`: load a `Ref` by its history ID,\n",
257 | "optionally also loading the underlying object."
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 7,
263 | "metadata": {
264 | "execution": {
265 | "iopub.execute_input": "2024-07-11T14:31:38.120583Z",
266 | "iopub.status.busy": "2024-07-11T14:31:38.119825Z",
267 | "iopub.status.idle": "2024-07-11T14:31:38.150984Z",
268 | "shell.execute_reply": "2024-07-11T14:31:38.150207Z"
269 | }
270 | },
271 | "outputs": [
272 | {
273 | "name": "stdout",
274 | "output_type": "stream",
275 | "text": [
276 | "AtomRef(42, hid=168...)\n",
277 | "AtomRef(42, hid=168...)\n"
278 | ]
279 | }
280 | ],
281 | "source": [
282 | "print(storage.attach(obj=s, inplace=False))\n",
283 | "print(storage.load_ref(s.hid))"
284 | ]
285 | },
286 | {
287 | "cell_type": "markdown",
288 | "metadata": {},
289 | "source": [
290 | "## Working with `Call` objects\n",
291 | "Besides `Ref`s, the other kind of object in the storage is the `Call`, which\n",
292 | "stores references to the inputs and outputs of a call to an `@op`, together with\n",
293 | "metadata that mirrors the `Ref` metadata:\n",
294 | "\n",
295 | "- `Call.cid`: a content ID for the call, based on the `@op`'s identity, its\n",
296 | "version at the time of the call, and the `cid`s of the inputs\n",
297 | "- `Call.hid`: a history ID for the call, the same as `Call.cid`, but using the \n",
298 | "`hid`s of the inputs.\n",
299 | "\n",
300 | "**For every `Ref` history ID, there's at most one `Call` that has an output with\n",
301 | "this history ID**, and if it exists, this call can be found by calling\n",
302 | "`storage.get_ref_creator()`: "
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": 8,
308 | "metadata": {
309 | "execution": {
310 | "iopub.execute_input": "2024-07-11T14:31:38.154659Z",
311 | "iopub.status.busy": "2024-07-11T14:31:38.154326Z",
312 | "iopub.status.idle": "2024-07-11T14:31:38.205594Z",
313 | "shell.execute_reply": "2024-07-11T14:31:38.203966Z"
314 | }
315 | },
316 | "outputs": [
317 | {
318 | "name": "stdout",
319 | "output_type": "stream",
320 | "text": [
321 | "Call(sum_args, hid=f99...)\n"
322 | ]
323 | },
324 | {
325 | "data": {
326 | "text/plain": [
327 | "{'a': AtomRef(hid=c6a..., in_memory=False),\n",
328 | " 'args_0': AtomRef(hid=e0f..., in_memory=False),\n",
329 | " 'args_1': AtomRef(hid=479..., in_memory=False),\n",
330 | " 'args_2': AtomRef(hid=c37..., in_memory=False),\n",
331 | " 'b': AtomRef(hid=610..., in_memory=False),\n",
332 | " 'c': AtomRef(hid=a33..., in_memory=False)}"
333 | ]
334 | },
335 | "metadata": {},
336 | "output_type": "display_data"
337 | },
338 | {
339 | "data": {
340 | "text/plain": [
341 | "{'output_0': AtomRef(hid=168..., in_memory=False)}"
342 | ]
343 | },
344 | "metadata": {},
345 | "output_type": "display_data"
346 | }
347 | ],
348 | "source": [
349 | "call = storage.get_ref_creator(ref=s)\n",
350 | "print(call)\n",
351 | "display(call.inputs)\n",
352 | "display(call.outputs)"
353 | ]
354 | }
355 | ],
356 | "metadata": {
357 | "language_info": {
358 | "codemirror_mode": {
359 | "name": "ipython",
360 | "version": 3
361 | },
362 | "file_extension": ".py",
363 | "mimetype": "text/x-python",
364 | "name": "python",
365 | "nbconvert_exporter": "python",
366 | "pygments_lexer": "ipython3",
367 | "version": "3.10.8"
368 | }
369 | },
370 | "nbformat": 4,
371 | "nbformat_minor": 2
372 | }
373 |
--------------------------------------------------------------------------------