├── test
├── __init__.py
├── config
│ ├── __init__.py
│ └── test_meta.py
├── entry
│ ├── __init__.py
│ ├── cli
│ │ ├── __init__.py
│ │ └── test_version.py
│ └── script
│ │ └── __init__.py
├── tree
│ ├── __init__.py
│ ├── common
│ │ ├── __init__.py
│ │ ├── test_base.py
│ │ └── test_delay.py
│ ├── func
│ │ ├── __init__.py
│ │ ├── test_inner.py
│ │ ├── test_left.py
│ │ ├── test_strict.py
│ │ └── test_outer.py
│ ├── tree
│ │ ├── __init__.py
│ │ ├── test_tree.py
│ │ ├── test_io.py
│ │ └── test_graph.py
│ ├── general
│ │ ├── __init__.py
│ │ ├── test_meta.py
│ │ ├── test_fast.py
│ │ ├── test_general_benchmark.py
│ │ └── test_general.py
│ └── integration
│ │ ├── __init__.py
│ │ ├── test_init.py
│ │ └── test_jax.py
├── utils
│ ├── __init__.py
│ ├── test_random.py
│ └── test_tree.py
├── testings
│ ├── __init__.py
│ └── mapping.py
└── tests
│ ├── __init__.py
│ ├── test_benchmark.py
│ └── utils.py
├── benchmark
├── __init__.py
├── jax
│ └── __init__.py
├── deepmind
│ └── __init__.py
├── facebook
│ ├── __init__.py
│ └── test_nest.py
├── tianshou
│ └── __init__.py
└── base.py
├── docs
├── source
│ ├── _libs
│ │ ├── .keep
│ │ └── dm.py
│ ├── _static
│ │ ├── .keep
│ │ ├── wechat.png
│ │ └── title-banner.png
│ ├── _templates
│ │ ├── .keep
│ │ ├── page.html
│ │ └── versions.html
│ ├── tutorials
│ │ ├── cli_usage
│ │ │ ├── help_demo.demo.sh
│ │ │ ├── version_demo.demo.sh
│ │ │ ├── graph_help_demo.demo.sh
│ │ │ ├── export_help_demo.demo.sh
│ │ │ ├── graph_demo_1.demo.sh
│ │ │ ├── graph_demo_2.demo.sh
│ │ │ ├── graph_demo_3.demo.sh
│ │ │ ├── export_demo_1.demo.sh
│ │ │ ├── export_demo_4.demo.sh
│ │ │ ├── export_demo_3.demo.sh
│ │ │ ├── export_demo_2.demo.sh
│ │ │ ├── graph_demo_7.demo.sh
│ │ │ ├── graph_demo_8.demo.sh
│ │ │ ├── graph_demo_5.demo.sh
│ │ │ ├── graph_demo_6.demo.sh
│ │ │ ├── graph_demo_9.demo.sh
│ │ │ ├── graph_demo_4.demo.sh
│ │ │ ├── graph_demo_10.demo.sh
│ │ │ ├── export_demo_5.demo.sh
│ │ │ ├── tree_demo.py
│ │ │ ├── large_tree_demo.py
│ │ │ ├── node_share_demo.py
│ │ │ └── share_demo.py
│ │ ├── installation
│ │ │ ├── cli_demo.demo.sh
│ │ │ └── index.rst
│ │ ├── quick_start
│ │ │ ├── display_tree.demo.sh
│ │ │ ├── display_complex_tree.demo.sh
│ │ │ ├── create_a_tree.demo.py
│ │ │ ├── complex_demo.py
│ │ │ ├── simple_demo.py
│ │ │ ├── create_a_complex_tree.demo.py
│ │ │ └── index.rst
│ │ ├── advanced_usage
│ │ │ ├── inherit_numpy_demo.demo.py
│ │ │ ├── diy_class_x_tv.demo.py
│ │ │ ├── functional_python_demo.demo.py
│ │ │ ├── mapping_demo.demo.py
│ │ │ ├── unflatten_demo.demo.py
│ │ │ ├── flatten_demo.demo.py
│ │ │ ├── pickle_demo_1.demo.py
│ │ │ ├── walk_demo.demo.py
│ │ │ ├── jsonify_demo.demo.py
│ │ │ ├── dump_demo_2.demo.py
│ │ │ ├── reduce_demo_2.demo.py
│ │ │ ├── diy_class_x_demo_1.demo.py
│ │ │ ├── filter_demo.demo.py
│ │ │ ├── mask_demo.demo.py
│ │ │ ├── outer_demo.demo.py
│ │ │ ├── filter_eq_demo.demo.py
│ │ │ ├── reduce_demo_1.demo.py
│ │ │ ├── left_demo.demox.py
│ │ │ ├── inner_demo.demo.py
│ │ │ ├── strict_demo.demox.py
│ │ │ ├── reduce_demo_3.demo.py
│ │ │ ├── subside_demo.demo.py
│ │ │ ├── dump_demo_1.demo.py
│ │ │ ├── union_demo.demo.py
│ │ │ ├── inherit_demo.demox.py
│ │ │ ├── clone_demo.demo.py
│ │ │ ├── missing_demo.demo.py
│ │ │ ├── typetrans_demo.demo.py
│ │ │ ├── diy_class_self_demo.demo.py
│ │ │ ├── dump_compression_demo.demo.py
│ │ │ ├── strict_demo_show.demox.py
│ │ │ ├── diy_class_demo.demo.py
│ │ │ ├── rise_demo_1.demo.py
│ │ │ ├── diy_class_x_demo_2.demo.py
│ │ │ ├── left_demo_2.gv
│ │ │ ├── strict_demo_2.gv
│ │ │ ├── oo_demo.demo.py
│ │ │ ├── inherit_demo_2.gv
│ │ │ ├── rise_demo_2.demo.py
│ │ │ ├── diy_class_x_demo_3.demo.py
│ │ │ ├── inherit_demo_1.gv
│ │ │ ├── left_demo_1.gv
│ │ │ ├── inner_demo_1.gv
│ │ │ ├── missing_demo_2.gv
│ │ │ ├── strict_demo_1.gv
│ │ │ ├── missing_demo_1.gv
│ │ │ └── outer_demo_1.gv
│ │ ├── main_idea
│ │ │ ├── treevalue_demo_2.demo.py
│ │ │ ├── treevalue_demo_1.demo.py
│ │ │ ├── native_python_demo.demo.py
│ │ │ ├── treevalue_demo.gv
│ │ │ └── treelize_demo.gv
│ │ ├── basic_usage
│ │ │ ├── tree_support_primitive.demo.py
│ │ │ ├── index_and_slice.demo.py
│ │ │ ├── tree_support_2.gv
│ │ │ ├── calculation.demo.py
│ │ │ ├── tree_support.demo.py
│ │ │ ├── calculation_self.demo.py
│ │ │ ├── slice_index_operation.gv
│ │ │ ├── edit_tree.demo.py
│ │ │ ├── tree_support_1.gv
│ │ │ ├── index_operation.gv
│ │ │ ├── calculation_sub_and_xor.gv
│ │ │ ├── edit_tree_1.gv
│ │ │ ├── calculation_add.gv
│ │ │ └── edit_tree_2.gv
│ │ └── plugins
│ │ │ ├── potc_demo.demo.py
│ │ │ └── index.rst
│ ├── api_doc
│ │ ├── config
│ │ │ ├── index.rst
│ │ │ └── meta.rst
│ │ ├── utils
│ │ │ ├── index.rst
│ │ │ ├── formattree.rst
│ │ │ ├── build_graph.demo.py
│ │ │ ├── color.rst
│ │ │ ├── random.rst
│ │ │ ├── build_graph_complex.demo.py
│ │ │ └── tree.rst
│ │ └── tree
│ │ │ ├── index.rst
│ │ │ ├── graphics.demo.py
│ │ │ ├── func.rst
│ │ │ ├── general.rst
│ │ │ ├── graphics_dup_value.demo.py
│ │ │ ├── integration.rst
│ │ │ └── common.rst
│ ├── best_practice
│ │ ├── sklearn
│ │ │ ├── heading_of_pca.jpg
│ │ │ ├── sklearn.demo.py
│ │ │ └── index.rst
│ │ └── numpy
│ │ │ ├── with_treevalue.py
│ │ │ ├── numpy.demo.py
│ │ │ ├── without_treevalue.py
│ │ │ └── index.rst
│ ├── _shims
│ │ └── treevalue
│ ├── contribute
│ │ └── architecture
│ │ │ ├── get_tree_demo.demo.py
│ │ │ ├── tree_demo.demo.py
│ │ │ └── architecture.puml
│ ├── graphviz.mk
│ ├── diagrams.mk
│ ├── notebook.mk
│ ├── comparison
│ │ ├── generic.rst
│ │ └── environment.ipynb
│ ├── index.rst
│ └── demos.mk
├── main_page.html
└── Makefile
├── runs
├── artifacts
│ └── .keep
└── Makefile
├── treevalue
├── config
│ ├── __init__.py
│ └── meta.py
├── entry
│ ├── script
│ │ └── __init__.py
│ ├── __init__.py
│ └── cli
│ │ ├── __init__.py
│ │ ├── dispatch.py
│ │ ├── base.py
│ │ ├── io.py
│ │ └── utils.py
├── tree
│ ├── general
│ │ ├── __init__.py
│ │ └── fast.py
│ ├── func
│ │ ├── __init__.py
│ │ ├── cfunc.pxd
│ │ └── modes.pxd
│ ├── __init__.py
│ ├── integration
│ │ ├── base.pxd
│ │ ├── cjax.pxd
│ │ ├── ctorch.pxd
│ │ ├── base.pyx
│ │ ├── __init__.py
│ │ ├── jax.py
│ │ ├── general.pxd
│ │ ├── torch.py
│ │ ├── cjax.pyx
│ │ └── ctorch.pyx
│ ├── common
│ │ ├── __init__.py
│ │ ├── base.pxd
│ │ ├── delay.pxd
│ │ ├── storage.pxd
│ │ └── base.pyx
│ └── tree
│ │ ├── service.pxd
│ │ ├── __init__.py
│ │ ├── flatten.pxd
│ │ ├── structural.pxd
│ │ ├── functional.pxd
│ │ ├── tree.pxd
│ │ └── constraint.pxd
├── __init__.py
└── utils
│ ├── __init__.py
│ └── random.py
├── requirements-potc.txt
├── .coveragerc
├── requirements.txt
├── pytest.ini
├── requirements-test-extra.txt
├── requirements-benchmark.txt
├── MANIFEST.in
├── requirements-build.txt
├── codecov.yml
├── install_test.sh
├── CONTRIBUTING.md
├── .github
├── PULL_REQUEST_TEMPLATE.md
├── ISSUE_TEMPLATE
│ └── custom.md
└── workflows
│ ├── badge.yml
│ ├── doc.yml
│ └── run.yml
├── examples
├── Makefile
└── README.md
├── requirements-doc.txt
├── requirements-test.txt
├── pyproject.toml
├── bmtrans.py
├── cloc.sh
├── setup.py
└── Makefile
/test/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/_libs/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/runs/artifacts/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/config/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/entry/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/tree/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/benchmark/jax/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/_static/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/source/_templates/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/entry/cli/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/tree/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/tree/func/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/tree/tree/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/treevalue/config/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/benchmark/deepmind/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/benchmark/facebook/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/benchmark/tianshou/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/entry/script/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/tree/general/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/tree/integration/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/treevalue/entry/script/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements-potc.txt:
--------------------------------------------------------------------------------
1 | potc-treevalue>=0.0.1
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | plugins = Cython.Coverage
3 |
--------------------------------------------------------------------------------
/treevalue/entry/__init__.py:
--------------------------------------------------------------------------------
1 | from .script import *
2 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/help_demo.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue -h
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/version_demo.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue -v
--------------------------------------------------------------------------------
/test/testings/__init__.py:
--------------------------------------------------------------------------------
1 | from .mapping import CustomMapping
2 |
--------------------------------------------------------------------------------
/test/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import float_eq, eq_extend
2 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/graph_help_demo.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph -h
--------------------------------------------------------------------------------
/treevalue/entry/cli/__init__.py:
--------------------------------------------------------------------------------
1 | from .dispatch import treevalue_cli
2 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/export_help_demo.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue export -h
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | enum_tools
2 | graphviz>=0.17
3 | dill>=0.3.4
4 | click>=7.1.0
5 | hbutils>=0.9.1
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/graph_demo_1.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph -t 'tree_demo.t1' -o 'only_t1.dat.svg'
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/graph_demo_2.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph -t 'tree_demo.*' -o 't1_t2_t3.dat.svg'
--------------------------------------------------------------------------------
/docs/source/tutorials/installation/cli_demo.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue -v &&
2 | echo '' &&
3 | treevalue -h
4 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | timeout = 60
3 | markers =
4 | unittest
5 | benchmark
6 | ignore
7 |
--------------------------------------------------------------------------------
/docs/source/_static/wechat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/treevalue/HEAD/docs/source/_static/wechat.png
--------------------------------------------------------------------------------
/treevalue/tree/general/__init__.py:
--------------------------------------------------------------------------------
1 | from .fast import FastTreeValue
2 | from .general import general_tree_value
3 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/graph_demo_3.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph -t 'node_share_demo.*' -o 'shared_nodes.dat.svg'
2 |
--------------------------------------------------------------------------------
/docs/source/_static/title-banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/treevalue/HEAD/docs/source/_static/title-banner.png
--------------------------------------------------------------------------------
/requirements-test-extra.txt:
--------------------------------------------------------------------------------
1 | jax[cpu]>=0.3.25; platform_system != 'Windows'
2 | torch>=1.1.0,<2.1.0; python_version < '3.12'
3 |
--------------------------------------------------------------------------------
/treevalue/__init__.py:
--------------------------------------------------------------------------------
1 | # noinspection PyPep8Naming
2 | from .config.meta import __VERSION__ as __version__
3 | from .tree import *
4 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/export_demo_1.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue export -t 'tree_demo.t1' -o export_t1.dat.btv
2 | ls -al export_t1.dat.btv
3 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/export_demo_4.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue export -t 'tree_demo.t1' -r -o raw_t1.dat.btv
2 | ls -al raw_t1.dat.btv
3 |
--------------------------------------------------------------------------------
/docs/source/api_doc/config/index.rst:
--------------------------------------------------------------------------------
1 | treevalue.config
2 | =====================
3 |
4 | .. toctree::
5 | :maxdepth: 3
6 |
7 | meta
8 |
--------------------------------------------------------------------------------
/docs/source/tutorials/quick_start/display_tree.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph -t 'simple_demo.*' -c 'bgcolor=#ffffff00' -d list -o simple_demo.dat.svg
2 |
--------------------------------------------------------------------------------
/requirements-benchmark.txt:
--------------------------------------------------------------------------------
1 | pytest-benchmark~=3.4.0
2 | dm-tree>=0.1.6
3 | tianshou>=0.4.5
4 | jax[cpu]>=0.2.17
5 | pandas
6 | click>=7.0.0
7 | numpy<2
--------------------------------------------------------------------------------
/treevalue/tree/func/__init__.py:
--------------------------------------------------------------------------------
1 | from .func import func_treelize, MISSING_NOT_ALLOW, AUTO_DETECT_RETURN_TYPE, method_treelize, classmethod_treelize
2 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/export_demo_3.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue export -t 'tree_demo.t1' -c gzip -o compress_t1.dat.btv
2 | ls -al compress_t1.dat.btv
3 |
--------------------------------------------------------------------------------
/docs/source/tutorials/quick_start/display_complex_tree.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph -t 'complex_demo.*' -c 'bgcolor=#ffffff00' -d list -o complex_demo.dat.svg
2 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include MANIFEST.in
3 | include requirements.txt
4 | include requirements-*.txt
5 | recursive-include treevalue *.pyx *.pxd
6 |
--------------------------------------------------------------------------------
/docs/source/best_practice/sklearn/heading_of_pca.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/treevalue/HEAD/docs/source/best_practice/sklearn/heading_of_pca.jpg
--------------------------------------------------------------------------------
/treevalue/tree/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import raw
2 | from .func import *
3 | from .general import *
4 | from .integration import *
5 | from .tree import *
6 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/export_demo_2.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue export -t 'tree_demo.*' -o 'me_t1.dat.btv' -o 'me_t2.dat.btv' -o 'me_t3.dat.btv'
2 | ls -al me_*.btv
3 |
--------------------------------------------------------------------------------
/requirements-build.txt:
--------------------------------------------------------------------------------
1 | # cython>=0.29; platform_system != 'Windows'
2 | # cython>=0.29; platform_system == 'Windows'
3 | cython>=3
4 | build>=0.7.0
5 | auditwheel>=4
--------------------------------------------------------------------------------
/docs/source/_shims/treevalue:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from treevalue.entry.cli import treevalue_cli
4 |
5 | if __name__ == '__main__':
6 | treevalue_cli()
7 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/graph_demo_7.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph \
2 | -t 'node_share_demo.*' \
3 | -O \
4 | -T 'PNG Formatted Graph' \
5 | -c 'bgcolor=#ffffff00'
6 |
--------------------------------------------------------------------------------
/benchmark/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | CMP_N = int(os.environ.get('CMP_N', None) or '5')
5 | HAS_CUDA = bool(os.environ.get('HAS_CUDA', shutil.which('nvidia-smi')))
6 |
--------------------------------------------------------------------------------
/runs/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: all run
2 |
3 | all: run
4 | run:
5 | echo "hello world"
6 | echo ${PYTHONPATH}
7 |
8 | clean:
9 | rm -rf artifacts/*
10 | touch artifacts/.keep
11 |
--------------------------------------------------------------------------------
/treevalue/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .color import Color
2 | from .formattree import format_tree
3 | from .random import random_hex, random_hex_with_timestamp
4 | from .tree import build_graph
5 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/graph_demo_8.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph \
2 | -t 'share_demo.*' \
3 | -o 'no_shared_values.dat.svg' \
4 | -T 'No Shared Values' \
5 | -c 'bgcolor=#ffffff00'
6 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/graph_demo_5.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph \
2 | -t 'node_share_demo.*' \
3 | -o 'shared_nodes.dat.png' \
4 | -T 'PNG Formatted Graph' \
5 | -c 'bgcolor=#ffffff00'
6 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/graph_demo_6.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph \
2 | -t 'node_share_demo.*' \
3 | -o 'shared_nodes.dat.gv' \
4 | -T 'PNG Formatted Graph' \
5 | -c 'bgcolor=#ffffff00'
6 |
--------------------------------------------------------------------------------
/docs/source/api_doc/utils/index.rst:
--------------------------------------------------------------------------------
1 | treevalue.utils
2 | ===========================
3 |
4 | .. toctree::
5 | :maxdepth: 3
6 |
7 | color
8 | formattree
9 | random
10 | tree
11 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/graph_demo_9.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph \
2 | -t 'share_demo.*' \
3 | -o 'shared_all_values.dat.svg' \
4 | -T 'Shared All Values' \
5 | -c 'bgcolor=#ffffff00' \
6 | -D
7 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | status:
3 | project:
4 | default:
5 | # basic
6 | target: auto
7 | threshold: 1%
8 | if_ci_failed: success # success, failure, error, ignore
9 |
--------------------------------------------------------------------------------
/docs/source/api_doc/tree/index.rst:
--------------------------------------------------------------------------------
1 | treevalue.tree
2 | =====================
3 |
4 | .. toctree::
5 | :maxdepth: 3
6 |
7 | common
8 | tree
9 | func
10 | general
11 | integration
12 |
--------------------------------------------------------------------------------
/install_test.sh:
--------------------------------------------------------------------------------
1 | mkdir -p .installs
2 |
3 | git clone --depth=1 https://github.com/facebookresearch/torchbeast.git .installs/torchbeast
4 | cd .installs/torchbeast/nest
5 | CXX=c++ pip install . -vv
6 | cd ../../..
7 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/graph_demo_4.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph \
2 | -t 'node_share_demo.*' \
3 | -o 'shared_nodes_with_cfg.dat.svg' \
4 | -T 'Graph to Show the Shared Nodes' \
5 | -c 'bgcolor=#ffffff00'
6 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/graph_demo_10.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue graph \
2 | -t 'share_demo.*' \
3 | -o 'shared_values.dat.svg' \
4 | -T 'Shared Values' \
5 | -c 'bgcolor=#ffffff00' \
6 | -d list -d numpy.ndarray
7 |
--------------------------------------------------------------------------------
/treevalue/tree/integration/base.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | cdef tuple _c_flatten_for_integration(object tv)
5 | cdef object _c_unflatten_for_integration(object values, tuple spec)
6 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/inherit_numpy_demo.demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | if __name__ == '__main__':
6 | ar1 = np.array([[1, 2], [3, 4]])
7 | print('ar1 + 9:', ar1 + 9, sep=os.linesep)
8 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/export_demo_5.demo.sh:
--------------------------------------------------------------------------------
1 | treevalue export -t 'large_tree_demo.t1' -r -o c_o_large_t1.dat.btv
2 | treevalue export -t 'large_tree_demo.t1' -o c_x_large_t1.dat.btv
3 | ls -al c_*_large_t1.dat.btv
4 |
--------------------------------------------------------------------------------
/treevalue/tree/common/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import raw, unraw, RawWrapper
2 | from .delay import DelayedProxy, delayed_partial, undelay, DelayedValueProxy, DelayedFuncProxy
3 | from .storage import TreeStorage, create_storage
4 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/tree_demo.py:
--------------------------------------------------------------------------------
1 | from treevalue import FastTreeValue
2 |
3 | t1 = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
4 | t2 = FastTreeValue({'a': 11, 'b': 24, 'x': {'c': 30, 'd': 47}})
5 | t3 = t1 + t2
6 |
--------------------------------------------------------------------------------
/test/config/test_meta.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from treevalue.config.meta import __TITLE__
4 |
5 |
6 | @pytest.mark.unittest
7 | class TestConfigMeta:
8 | def test_title(self):
9 | assert __TITLE__ == 'treevalue'
10 |
--------------------------------------------------------------------------------
/test/tree/tree/test_tree.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from treevalue import TreeValue
4 | from .base import get_treevalue_test
5 |
6 |
7 | @pytest.mark.unittest
8 | class TestTreeTreeTree(get_treevalue_test(TreeValue)):
9 | pass
10 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guide
2 |
3 | * [Architecture of TreeValue](https://opendilab.github.io/treevalue/main/contribute/architecture/index.html)
4 |
5 | The rest of this guide is still under development and will be completed soon.
6 |
--------------------------------------------------------------------------------
/treevalue/tree/integration/cjax.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | cdef tuple _c_flatten_for_jax(object tv)
5 | cdef object _c_unflatten_for_jax(tuple aux, tuple values)
6 | cpdef void register_for_jax(object cls) except*
7 |
--------------------------------------------------------------------------------
/treevalue/tree/general/fast.py:
--------------------------------------------------------------------------------
1 | from .general import general_tree_value
2 |
3 |
4 | class FastTreeValue(general_tree_value()):
5 | """
6 | Overview:
7 | Fast tree value, can do almost anything with this.
8 | """
9 | pass
10 |
--------------------------------------------------------------------------------
/treevalue/tree/integration/ctorch.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | cdef tuple _c_flatten_for_torch(object tv)
5 | cdef object _c_unflatten_for_torch(list values, tuple context)
6 | cpdef void register_for_torch(object cls) except*
--------------------------------------------------------------------------------
/test/tests/test_benchmark.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | @pytest.mark.benchmark(group="system")
5 | class TestSystemBenchmark:
6 | def test_empty_func(self, benchmark):
7 | def empty_func():
8 | pass
9 |
10 | benchmark(empty_func)
11 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/diy_class_x_tv.demo.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import sys
3 |
4 | from treevalue import FastTreeValue
5 |
6 | if __name__ == '__main__':
7 | _module = sys.modules[FastTreeValue.__module__]
8 | print(pathlib.Path(_module.__file__).read_text())
9 |
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/large_tree_demo.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue

# Demo tree with very large leaf values (used in the cli_usage tutorial).
t1 = FastTreeValue({
    'a': 1,
    'b': [2] * 1000,  # huge list
    'x': {
        'c': b'aklsdfj' * 2000,  # huge bytes
        'd': 4,
    },
})
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Description
2 |
3 |
4 | ## Related Issue
5 |
6 |
7 | ## TODO
8 |
9 |
10 | ## Check List
11 |
12 | - [ ] merge the latest version of the source branch/repo and resolve all conflicts
13 | - [ ] pass style check
14 | - [ ] pass all the tests
15 |
--------------------------------------------------------------------------------
/docs/main_page.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Redirecting to master branch
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/treevalue/tree/common/base.pxd:
--------------------------------------------------------------------------------
# distutils:language=c++
# cython:language_level=3

# Extension type holding a single wrapped object in ``val``.
cdef class RawWrapper:
    # the wrapped object; ``readonly`` exposes it to Python but forbids assignment
    cdef readonly object val

    # accessor for the wrapped object (presumably returns ``val`` -- confirm in base.pyx)
    cpdef object value(self)

# Presumably wraps ``obj`` so that it is stored as a single leaf instead of being
# expanded (the docs use ``raw({'x': 3, 'y': 4})`` this way) -- TODO confirm.
cpdef public object raw(object obj)
# Presumably the inverse of ``raw``: returns the object held by ``wrapped`` -- TODO confirm.
cpdef public object unraw(object wrapped)
--------------------------------------------------------------------------------
/docs/source/api_doc/utils/formattree.rst:
--------------------------------------------------------------------------------
1 | treevalue.utils.formattree
2 | ============================
3 |
4 | .. py:currentmodule:: treevalue.utils.formattree
5 |
6 | .. automodule:: treevalue.utils.formattree
7 |
8 |
9 | format_tree
10 | ---------------
11 |
12 | .. autofunction:: format_tree
13 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/functional_python_demo.demo.py:
--------------------------------------------------------------------------------
if __name__ == '__main__':
    primes = [2, 3, 5, 7]

    # map function: apply ``x + 1`` to every element
    mapped = list(map(lambda x: x + 1, primes))
    print("Result of map:", mapped)

    # filter function: keep the elements satisfying ``x % 4 == 3``
    filtered = list(filter(lambda x: x % 4 == 3, primes))
    print("Result of filter:", filtered)
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/mapping_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import mapping, FastTreeValue

if __name__ == '__main__':
    t = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})

    # map every leaf value ``x`` of ``t`` to ``x ** x + 2``
    print('mapping(t, lambda x: x ** x + 2):')
    mapped = mapping(t, lambda x: x ** x + 2)
    print(mapped)
--------------------------------------------------------------------------------
/docs/source/tutorials/main_idea/treevalue_demo_2.demo.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue

if __name__ == "__main__":
    d1 = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    d2 = FastTreeValue({'a': 11, 'b': 22, 'x': {'c': 30, 'd': 47}})

    # FastTreeValue supports the ``+`` operator directly
    print("d1 + d2:")
    total = d1 + d2
    print(total)
--------------------------------------------------------------------------------
/docs/source/api_doc/utils/build_graph.demo.py:
--------------------------------------------------------------------------------
from treevalue.utils import build_graph

if __name__ == '__main__':
    t = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}
    # build a graph of the nested dict ``t``, labelled as 't'
    graph = build_graph((t, 't'), graph_title="Demo of build_graph.")

    print(graph.source)  # generated DOT source
    rendered = graph.render('build_graph_demo.dat.gv', format='svg')
    print(rendered)  # value returned by render (the rendered file's path)
--------------------------------------------------------------------------------
/docs/source/tutorials/quick_start/create_a_tree.demo.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue

# a tree whose leaves have several different types
t = FastTreeValue({
    'a': 1,
    'b': 2.3,
    'x': {
        'c': 'str',
        'd': [1, 2, None],
        'e': b'bytes',
    },
})

if __name__ == '__main__':
    print(t)
--------------------------------------------------------------------------------
/treevalue/entry/cli/dispatch.py:
--------------------------------------------------------------------------------
from .base import _base_treevalue_cli
from .export import _export_cli
from .graph import _graph_cli
from .utils import _cli_builder

# Command line entry: ``_cli_builder`` combines the base CLI with the
# builders of the ``graph`` and ``export`` subcommands.
treevalue_cli = _cli_builder(
    _base_treevalue_cli,
    _graph_cli,  # subcommand: treevalue graph
    _export_cli,  # subcommand: treevalue export
)
--------------------------------------------------------------------------------
/docs/source/tutorials/quick_start/complex_demo.py:
--------------------------------------------------------------------------------
import os

from dm import get_module

from treevalue import TreeValue

# Re-export every TreeValue defined in ``create_a_complex_tree.demo.py`` as a
# global of this module.
#
# The demo file is resolved next to this file rather than against the current
# working directory. The dict comprehension keeps its helper names out of this
# namespace. A plain ``for`` loop would leave ``key`` and ``value`` here, and
# ``value`` would then be an extra, duplicated TreeValue.
globals().update({
    name: obj
    for name, obj in get_module(os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        'create_a_complex_tree.demo.py',
    )).__dict__.items()
    if isinstance(obj, TreeValue)
})
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/node_share_demo.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue

nt1 = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
nt2 = FastTreeValue({'a': 11, 'b': 24, 'x': {'c': 30, 'd': 47}})

# nt3 reuses nodes of nt1 and nt2: 'another' is the subtree ``nt1.x`` itself
nt3 = FastTreeValue({
    'first': nt1,
    'second': nt2,
    'another': nt1.x,
    'sum': nt1 + nt2,
})
--------------------------------------------------------------------------------
/examples/Makefile:
--------------------------------------------------------------------------------
.PHONY: clean

JUPYTER ?= $(shell which jupyter)
NBCONVERT ?= ${JUPYTER} nbconvert

SOURCE ?= .
# Quote the pattern so the shell cannot glob-expand it before find sees it,
# and escape the dot so grep matches a literal ".ipynb_checkpoints".
IPYNBS := $(shell find ${SOURCE} -name '*.ipynb' | grep -v '\.ipynb_checkpoints')

# Clear the outputs of all notebook files in place.
clean:
	for nb in ${IPYNBS}; do \
		if [ -f "$$nb" ]; then \
			$(NBCONVERT) --clear-output --inplace "$$nb"; \
		fi; \
	done;
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/unflatten_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import unflatten

if __name__ == '__main__':
    # each item is a (path tuple, value) pair
    flatted = [
        (('a',), 1),
        (('b',), 2),
        (('c',), {'x': 3, 'y': 4}),
        (('d', 'x'), 3),
        (('d', 'y'), 4),
    ]

    print('unflatten(flatted):')
    tree = unflatten(flatted)
    print(tree)
--------------------------------------------------------------------------------
/docs/source/tutorials/cli_usage/share_demo.py:
--------------------------------------------------------------------------------
import numpy as np

from treevalue import FastTreeValue

st1 = FastTreeValue({
    'a': 1,
    'b': [1, 2],
    'x': {
        'c': 3,
        'd': np.zeros((3, 2)),
    },
})

# st2 reuses objects held by st1 ('np', 'ar', 'a'), while 'arx' is a new
# list that equals, but is not the same object as, st1.b
st2 = FastTreeValue({
    'np': st1.x.d,
    'ar': st1.b,
    'a': st1.a,
    'arx': [1, 2],
})
--------------------------------------------------------------------------------
/docs/source/api_doc/utils/color.rst:
--------------------------------------------------------------------------------
1 | treevalue.utils.color
2 | ===========================
3 |
4 | .. py:currentmodule:: treevalue.utils.color
5 |
6 | .. automodule:: treevalue.utils.color
7 |
8 | Color
9 | ----------------
10 |
11 | .. autoclass:: Color
12 | :members: __init__, alpha, rgb, hsv, hls, __repr__, __str__, __getstate__, __setstate__, __hash__, __eq__, from_hsv, from_hls
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/flatten_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import TreeValue, raw, flatten

if __name__ == '__main__':
    t = TreeValue({
        'a': 1,
        'b': 2,
        'c': raw({'x': 3, 'y': 4}),  # wrapped by raw(), unlike the plain dict 'd'
        'd': {
            'x': 3,
            'y': 4,
        },
    })

    print('flatten(t):')
    flat = flatten(t)
    print(flat)
--------------------------------------------------------------------------------
/requirements-doc.txt:
--------------------------------------------------------------------------------
1 | Jinja2~=3.0.0
2 | sphinx~=3.2.0
3 | sphinx_rtd_theme~=0.4.3
4 | enum_tools~=0.9.0
5 | sphinx-toolbox
6 | plantumlcli>=0.0.4
7 | packaging
8 | sphinx-multiversion~=0.2.4
9 | where~=1.0.2
10 | numpy>=1.19,<2
11 | easydict>=1.7,<2
12 | scikit-learn>=0.24.2
13 | potc-treevalue>=0.0.1
14 | nbsphinx>=0.8.8
15 | ipython>=7.16.3
16 | psutil>=5.8.0
17 | ipykernel>=6.15
18 | py-cpuinfo>=8.0.0
--------------------------------------------------------------------------------
/test/testings/mapping.py:
--------------------------------------------------------------------------------
import collections.abc


class CustomMapping(collections.abc.Mapping):
    """
    Read-only :class:`collections.abc.Mapping` that is not a ``dict``,
    backed by the keyword arguments passed to the constructor.
    """

    def __init__(self, **kwargs):
        self._kwargs = kwargs

    def __len__(self):
        return len(self._kwargs)

    def __iter__(self):
        return iter(self._kwargs)

    def __getitem__(self, __key):
        return self._kwargs[__key]
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/pickle_demo_1.demo.py:
--------------------------------------------------------------------------------
import os
import pickle

from treevalue import TreeValue

if __name__ == '__main__':
    t = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': [3, 4]}})
    binary = pickle.dumps(t)
    print('t:', t, sep=os.linesep)

    # round trip: the unpickled tree equals the original one
    # (only unpickle trusted data, such as the bytes produced just above)
    tx = pickle.loads(binary)
    assert tx == t
    print('tx:', tx, sep=os.linesep)
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/walk_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import TreeValue, raw, walk

if __name__ == '__main__':
    t = TreeValue({
        'a': 1,
        'b': 2,
        'c': raw({'x': 3, 'y': 4}),
        'd': {
            'x': 3,
            'y': 4,
        },
    })

    # walk yields a (path, node) pair for each visited node
    for node_path, node_value in walk(t):
        print(node_path, '-->', node_value)
--------------------------------------------------------------------------------
/test/tree/general/test_meta.py:
--------------------------------------------------------------------------------
import pytest

from treevalue.tree import FastTreeValue
from .base import get_fasttreevalue_test


class _MyMetaClass(type):
    """Trivial custom metaclass, used to derive a tree class with a non-default metaclass."""


class MyMetaTreeValue(FastTreeValue, metaclass=_MyMetaClass):
    """FastTreeValue subclass whose metaclass is ``_MyMetaClass``."""


@pytest.mark.unittest
class TestTreeGeneralMeta(get_fasttreevalue_test(MyMetaTreeValue)):
    """Reuse the test class produced by ``get_fasttreevalue_test`` for ``MyMetaTreeValue``."""
--------------------------------------------------------------------------------
/docs/source/api_doc/utils/random.rst:
--------------------------------------------------------------------------------
1 | treevalue.utils.random
2 | ============================
3 |
4 | .. py:currentmodule:: treevalue.utils.random
5 |
6 | .. automodule:: treevalue.utils.random
7 |
8 | random_hex
9 | -------------------
10 |
11 | .. autofunction:: random_hex
12 |
13 |
14 | random_hex_with_timestamp
15 | ---------------------------
16 |
17 | .. autofunction:: random_hex_with_timestamp
18 |
19 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/jsonify_demo.demo.py:
--------------------------------------------------------------------------------
import json

from treevalue import TreeValue, raw, jsonify

if __name__ == '__main__':
    t = TreeValue({
        'a': 1,
        'b': [2, 3],
        'x': {'c': raw({'x': 1, 'y': 2}), 'd': "this is a string"},
    })

    print("Tree t:")
    print(t)

    # jsonify produces plain data that json.dumps can serialize
    print("Json data of t:")
    data = jsonify(t)
    print(json.dumps(data, indent=4, sort_keys=True))
--------------------------------------------------------------------------------
/docs/source/tutorials/main_idea/treevalue_demo_1.demo.py:
--------------------------------------------------------------------------------
from treevalue import TreeValue, func_treelize


@func_treelize()
def plus(a, b):
    """Add two values; func_treelize lets this also be called on tree values."""
    return a + b


if __name__ == "__main__":
    d1 = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    d2 = TreeValue({'a': 11, 'b': 22, 'x': {'c': 30, 'd': 47}})

    print("plus(d1, d2):")
    result = plus(d1, d2)
    print(result)
--------------------------------------------------------------------------------
/docs/source/tutorials/quick_start/simple_demo.py:
--------------------------------------------------------------------------------
import os

from dm import get_module

from treevalue import TreeValue

# Re-export every TreeValue defined in ``create_a_tree.demo.py`` as a global
# of this module.
#
# The demo file is resolved next to this file rather than against the current
# working directory. The dict comprehension keeps its helper names (loop
# variables, the loaded module) out of this namespace, so no ``del`` clean-up
# is needed afterwards.
globals().update({
    name: obj
    for name, obj in get_module(os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        'create_a_tree.demo.py',
    )).__dict__.items()
    if isinstance(obj, TreeValue)
})
--------------------------------------------------------------------------------
/docs/source/contribute/architecture/get_tree_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import TreeValue

if __name__ == '__main__':
    t1 = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    storage = t1._detach()  # the internal tree storage behind t1
    data = storage.detach()  # the data held by that storage (printed below)

    print('t1:')
    print(t1)

    print('tree storage:')
    print(storage)

    print('data:')
    print(data)
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/dump_demo_2.demo.py:
--------------------------------------------------------------------------------
import os

from treevalue import dumps, FastTreeValue, loads

if __name__ == '__main__':
    t = FastTreeValue({'a': 'ab', 'b': 'Cd', 'x': {'c': 'eF', 'd': 'GH'}})
    print('t:', t, sep=os.linesep)

    # serialize to bytes, then load back as a FastTreeValue
    dt = loads(dumps(t), type_=FastTreeValue)

    assert dt == t
    print('dt:', dt, sep=os.linesep)
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/tree_support_primitive.demo.py:
--------------------------------------------------------------------------------
from treevalue import func_treelize


@func_treelize()
def gcd(a, b):  # GCD calculation
    """
    Greatest common divisor of ``a`` and ``b`` (Euclidean algorithm).

    Decorated with ``func_treelize``; as shown below it works on plain
    (primitive) values. ``gcd(a, 0)`` returns ``a`` instead of raising
    ``ZeroDivisionError``.
    """
    while b:
        a, b = b, a % b

    return a


if __name__ == '__main__':
    print("gcd(6, 8):", gcd(6, 8))
    print("gcd(900, 768):", gcd(900, 768))
--------------------------------------------------------------------------------
/docs/source/api_doc/utils/build_graph_complex.demo.py:
--------------------------------------------------------------------------------
from treevalue.utils import build_graph

if __name__ == '__main__':
    t1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}
    # t2 shares the sub-dict t1['x'] under its key 'y'
    t2 = {'f': 4, 'y': t1['x'], 'z': {'e': [5, 7], 'f': "string"}}
    graph = build_graph((t1, 't1'), (t2, 't2'), graph_title="Complex demo of build_graph.")

    print(graph.source)  # generated DOT source
    rendered = graph.render('build_graph_complex_demo.dat.gv', format='svg')
    print(rendered)
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/reduce_demo_2.demo.py:
--------------------------------------------------------------------------------
from treevalue import reduce_, FastTreeValue, mapping

if __name__ == '__main__':
    t = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}, 'y': {'e': 6, 'f': 8}})

    # weight each leaf value ``v`` by ``len(p)``; ``p`` is presumably the
    # leaf's path, so deeper leaves get larger weights
    weights = mapping(t, lambda v, p: v * len(p))
    print("Weight tree:", weights)

    total = reduce_(weights, lambda **kwargs: sum(kwargs.values()))
    print("Huffman weight sum of t:", total)
--------------------------------------------------------------------------------
/treevalue/tree/tree/service.pxd:
--------------------------------------------------------------------------------
# distutils:language=c++
# cython:language_level=3

# jsonify, clone, typetrans, walk

# NOTE(review): ``bool`` is not referenced by the declarations in this file.
from libcpp cimport bool

from .tree cimport TreeValue

# internal helper; semantics are defined in service.pyx
cdef object _keep_object(object obj)
# convert a tree value into plain data that ``json.dumps`` can serialize
cpdef object jsonify(TreeValue val)
# copy the tree ``t``; ``copy_value= *`` declares that the argument has a
# default value, given in the implementation file
cpdef TreeValue clone(TreeValue t, object copy_value= *)
# presumably re-creates ``t`` as an instance of ``return_type`` -- confirm in .pyx
cpdef TreeValue typetrans(TreeValue t, object return_type)
# iterate over ``tree``, yielding (path, node) pairs
cpdef walk(TreeValue tree)
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | coverage>=5
2 | mock>=4.0.3
3 | flake8~=3.5
4 | pytest~=6.2.5
5 | pytest-cov~=3.0.0
6 | pytest-mock~=3.6.1
7 | pytest-xdist>=1.34.0
8 | pytest-rerunfailures~=10.2
9 | pytest-timeout~=2.0.2
10 | pytest-benchmark~=3.4.0
11 | testtools>=2
12 | hbutils>=0.6.13
13 | setuptools
14 | # setuptools<=59.5.0
15 | # numpy 2.x cannot be used with torch 1.7.x, so numpy<2 is required for python <= 3.9
16 | numpy>=1.10,<2; python_version <= '3.9'
17 | numpy>=1.10; python_version > '3.9'
18 | easydict>=1.7,<2
--------------------------------------------------------------------------------
/docs/source/_libs/dm.py:
--------------------------------------------------------------------------------
import importlib.util
import os
from types import ModuleType


def get_module(filename) -> ModuleType:
    """
    Load a python source file as a module and return it.

    The module is named after the part of the file's base name before its
    first dot (e.g. ``create_a_tree`` for ``create_a_tree.demo.py``). It is
    executed, but not registered in :data:`sys.modules`.

    :param filename: Path of the python source file to load.
    :return: The executed module object.
    :raises ImportError: If no module spec/loader can be created for the file
        (e.g. it does not have a python source suffix).
    """
    # ``partition`` never fails, whereas unpacking ``split('.', maxsplit=1)``
    # raises ValueError for file names without a dot.
    module_name, _, _ = os.path.basename(filename).partition('.')

    spec = importlib.util.spec_from_file_location(module_name, filename)
    if spec is None or spec.loader is None:
        raise ImportError(f'Cannot load a module from {filename!r}.')

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    return module
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/diy_class_x_demo_1.demo.py:
--------------------------------------------------------------------------------
import os

from treevalue import general_tree_value


class MyTreeValue(general_tree_value()):
    """Custom tree value class derived from the class returned by ``general_tree_value()``."""


if __name__ == '__main__':
    t1 = MyTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    t2 = MyTreeValue({'a': 11, 'b': 24, 'x': {'c': 30, 'd': 47}})

    # the ``+`` operator works without being defined in MyTreeValue
    total = t1 + t2
    print('t1 + t2:', total, sep=os.linesep)
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/filter_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue, filter_

if __name__ == '__main__':
    t = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}, 'y': {'e': 6, 'f': 8}})

    # keep only the odd leaf values
    print('filter_(t, lambda x: x % 2 == 1):')
    odd_only = filter_(t, lambda x: x % 2 == 1)
    print(odd_only)

    # same filter, but subtrees left empty are kept
    print('filter_(t, lambda x: x % 2 == 1, remove_empty=False):')
    odd_keep_empty = filter_(t, lambda x: x % 2 == 1, remove_empty=False)
    print(odd_keep_empty)
--------------------------------------------------------------------------------
/docs/source/tutorials/quick_start/create_a_complex_tree.demo.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue

t1 = FastTreeValue({'a': 's1', 'b': 2, 'x': {'c': 3, 'd': [1, 2]}})
# t2.x.d is the very same list object as t1.x.d
t2 = FastTreeValue({'a': 'str11', 'b': 2, 'x': {'c': 33, 'd': t1.x.d}})
t3 = FastTreeValue({'t1': t1, 't2': t2, 'sum': t1 + t2})

if __name__ == '__main__':
    for title, tree in (('t1:', t1), ('t2:', t2), ('t3:', t3)):
        print(title)
        print(tree)
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/mask_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue, mask

if __name__ == '__main__':
    t = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}, 'y': {'e': 6, 'f': 8}})
    # boolean tree with the same structure as ``t``
    m = FastTreeValue({
        'a': True, 'b': False,
        'x': {'c': True, 'd': True},
        'y': {'e': False, 'f': False},
    })

    print('mask(t, m):')
    masked = mask(t, m)
    print(masked)

    print('mask(t, m, remove_empty=False):')
    masked_keep_empty = mask(t, m, remove_empty=False)
    print(masked_keep_empty)
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/outer_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue, func_treelize


# In outer mode the ``missing`` value (0 here) stands in for keys that are
# absent from one of the trees, so it is essential in this mode.
@func_treelize(mode='outer', missing=0)
def plus(a, b):
    return a + b


if __name__ == '__main__':
    t1 = FastTreeValue({'b': 2, 'x': {'c': 3, 'd': 4, 'e': 5, 'f': 6}})
    t2 = FastTreeValue({'a': 11, 'b': 22, 'x': {'c': 30, 'd': 48, 'e': 54}})

    result = plus(t1, t2)
    print('plus(t1, t2):', result)
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/filter_eq_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue, mapping, mask

if __name__ == '__main__':
    t = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}, 'y': {'e': 6, 'f': 8}})

    # a boolean tree marking the odd leaves; mask() with it behaves like filter_()
    odd_mask = mapping(t, lambda x: x % 2 == 1)

    print('mask(t, mapping(t, lambda x: x % 2 == 1)):')
    print(mask(t, odd_mask))

    print('mask(t, mapping(t, lambda x: x % 2 == 1), remove_empty=False):')
    print(mask(t, odd_mask, remove_empty=False))
--------------------------------------------------------------------------------
/treevalue/tree/tree/__init__.py:
--------------------------------------------------------------------------------
1 | from .constraint import to_constraint, Constraint, NodeConstraint, ValueConstraint, cleaf, vval, vcheck, nval, ncheck
2 | from .flatten import flatten, unflatten, flatten_values, flatten_keys
3 | from .functional import mapping, filter_, mask, reduce_
4 | from .graph import graphics
5 | from .io import loads, load, dumps, dump
6 | from .service import jsonify, clone, typetrans, walk
7 | from .structural import subside, union, rise
8 | from .tree import TreeValue, delayed, ValidationError, register_dict_type
9 |
--------------------------------------------------------------------------------
/docs/source/contribute/architecture/tree_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import TreeValue

if __name__ == '__main__':
    t1 = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    t2 = TreeValue(t1)  # t2 shares the same memory with t1

    for title, tree in (("Initial t1:", t1), ("Initial t2:", t2)):
        print(title)
        print(tree)
    print()

    # only t1 is modified here, yet the change is visible through t2 as well
    t1.a, t1.x.c = 7, 5
    for title, tree in (("Updated t1:", t1), ("Updated t2:", t2)):
        print(title)
        print(tree)
    print()
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/reduce_demo_1.demo.py:
--------------------------------------------------------------------------------
from functools import reduce
from operator import __mul__

from treevalue import reduce_, FastTreeValue


def multi(items):
    """Return the product of ``items`` (``1`` for an empty iterable)."""
    return reduce(__mul__, items, 1)


if __name__ == '__main__':
    t = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}, 'y': {'e': 6, 'f': 8}})

    # the reducing functions receive their inputs as keyword arguments
    total = reduce_(t, lambda **kwargs: sum(kwargs.values()))
    print("Sum of t:", total)
    product = reduce_(t, lambda **kwargs: multi(kwargs.values()))
    print("Multiply of t:", product)
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/index_and_slice.demo.py:
--------------------------------------------------------------------------------
import os

from treevalue import FastTreeValue

if __name__ == '__main__':
    t = FastTreeValue({
        'a': [1, 2, 3],
        'b': [4, 9, 16],
        'x': {
            'c': [11, 13, 17],
            'd': [-2, -4, -8],
        },
    })

    # indexing and slicing (__getitem__) are applied to every leaf
    first = t[0]
    print("Result of t[0]:", first, sep=os.linesep)
    reversed_tree = t[::-1]
    print("Result of t[::-1]:", reversed_tree, sep=os.linesep)
    tail_of_x = t.x[1:]
    print("Result of t.x[1:]:", tail_of_x, sep=os.linesep)
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/left_demo.demox.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue, func_treelize


@func_treelize(mode='left')
def plus(a, b):
    """Add two values, treelized in 'left' mode."""
    return a + b


if __name__ == '__main__':
    t1 = FastTreeValue({'b': 2, 'x': {'c': 3, 'd': 4, 'e': 5}})
    t2 = FastTreeValue({'a': 11, 'b': 22, 'x': {'c': 30, 'd': 48, 'e': 54}})
    # t4 has a key 'x.f' that t2 lacks
    t4 = FastTreeValue({'b': 2, 'x': {'c': 3, 'd': 4, 'e': 5, 'f': 6}})

    print('plus(t1, t2):', plus(t1, t2))
    print('plus(t4, t2):', plus(t4, t2))
--------------------------------------------------------------------------------
/docs/source/tutorials/plugins/potc_demo.demo.py:
--------------------------------------------------------------------------------
from potc import transvars

from treevalue import FastTreeValue, raw

r = raw({'a': 1, 'b': 2, 'c': [3, 4]})
t = FastTreeValue({
    'a': 1, 'b': 'this is a string',
    'c': [], 'd': {
        'x': raw({'a': 1, 'b': (None, Ellipsis)}),
        'y': {3, 4, 5},
    },
})
st = t._detach()  # internal storage behind t

if __name__ == '__main__':
    # Translate the objects above into python source code and print it.
    # The module-level ``st`` is reused instead of detaching ``t`` a second time.
    _code = transvars(
        {'t': t, 'st': st, 'r': r},
        reformat='pep8',
    )
    print(_code)
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples of TreeValue
2 |
3 | ## Examples
4 |
5 | Here are some TreeValue examples that can be viewed on Colab.
6 |
7 | * Visualization: [Open in Colab](https://colab.research.google.com/github/opendilab/treevalue/blob/main/examples/visualization.ipynb)
8 |
9 | ## Makefile
10 |
11 | ```shell
12 | make clean # clean all output in notebook files
13 |
14 | ```
15 |
16 | **Please make sure the outputs of all notebook files are cleared with `make clean` before committing.**
17 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/inner_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue, func_treelize


@func_treelize(mode='inner')
def plus(a, b):
    """Add two values, treelized in 'inner' mode."""
    return a + b


if __name__ == '__main__':
    t1 = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4, 'e': 5}})
    t2 = FastTreeValue({'a': 11, 'b': 22, 'x': {'c': 30, 'd': 48, 'e': 54}})
    # t4 lacks 'a' and has an extra key 'x.f' compared with t2
    t4 = FastTreeValue({'b': 2, 'x': {'c': 3, 'd': 4, 'e': 5, 'f': 6}})

    print('plus(t1, t2):', plus(t1, t2))
    print('plus(t4, t2):', plus(t4, t2))
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/strict_demo.demox.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue, func_treelize


@func_treelize(mode='strict')
def plus(a, b):
    """Add two values, treelized in 'strict' mode."""
    return a + b


if __name__ == '__main__':
    t1 = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4, 'e': 5}})
    t2 = FastTreeValue({'a': 11, 'b': 22, 'x': {'c': 30, 'd': 48, 'e': 54}})
    # unlike t1, t4 does not have the same keys as t2
    t4 = FastTreeValue({'b': 2, 'x': {'c': 3, 'd': 4, 'e': 5, 'f': 6}})

    print('plus(t1, t2):', plus(t1, t2))
    print('plus(t4, t2):', plus(t4, t2))
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/reduce_demo_3.demo.py:
--------------------------------------------------------------------------------
import numpy as np

from treevalue import reduce_, FastTreeValue

if __name__ == '__main__':
    t = FastTreeValue({
        'a': np.identity(3),
        'b': np.array([[1, 2], [3, 4]]),
        'x': {
            'c': np.zeros(4),
            'd': np.array([[5, 6, 7], [8, 9, 10]]),
        },
    })

    # attribute access gives a tree holding the ``nbytes`` of every array
    sizes = t.nbytes
    print("Size tree:", sizes)
    print("Total bytes of arrays in t:",
          reduce_(sizes, lambda **kwargs: sum(kwargs.values())))
--------------------------------------------------------------------------------
/treevalue/config/meta.py:
--------------------------------------------------------------------------------
"""
Overview:
    Meta information (title, version, description and authors) of the
    treevalue package.
"""

#: Title of this project (should be ``treevalue``).
__TITLE__ = "treevalue"

#: Version of this project.
__VERSION__ = "1.5.0"

#: Short description of the project, will be included in ``setup.py``.
__DESCRIPTION__ = 'A flexible, generalized tree-based data structure.'

#: Author(s) of this project.
__AUTHOR__ = "HansBug, DI-engine's Contributors"

#: Contact email of the authors.
__AUTHOR_EMAIL__ = "hansbug@buaa.edu.cn"
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/subside_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import TreeValue, subside

if __name__ == '__main__':
    t1 = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    t2 = TreeValue({'a': -14, 'b': 9, 'x': {'c': 3, 'd': 8}})
    t3 = TreeValue({'a': 6, 'b': 0, 'x': {'c': -5, 'd': 17}})
    t4 = TreeValue({'a': 0, 'b': -17, 'x': {'c': -8, 'd': 15}})
    t5 = TreeValue({'a': 3, 'b': 9, 'x': {'c': 11, 'd': -17}})

    # a nested structure of dicts, tuples and lists whose innermost items are trees
    st = {
        'first': (t1, t2),
        'second': [t3, {'x': t4, 'y': t5}],
    }
    print("subside(st):")
    subsided = subside(st)
    print(subsided)
--------------------------------------------------------------------------------
/docs/source/graphviz.mk:
--------------------------------------------------------------------------------
DOT := $(shell which dot)

SOURCE ?= .
# Quote the patterns so the shell cannot glob-expand them before find sees them.
GVS := $(shell find ${SOURCE} -name '*.gv')
PNGS := $(addsuffix .gv.png, $(basename ${GVS}))
SVGS := $(addsuffix .gv.svg, $(basename ${GVS}))

.PHONY: build all clean

# Render each .gv graph into a png / svg file next to it.
%.gv.png: %.gv
	$(DOT) -Tpng -o"$(shell readlink -f $@)" "$(shell readlink -f $<)"

%.gv.svg: %.gv
	$(DOT) -Tsvg -o"$(shell readlink -f $@)" "$(shell readlink -f $<)"

build: ${SVGS} ${PNGS}

all: build

# Remove all rendered images (no dangling line continuation at the end).
clean:
	rm -rf \
		$(shell find ${SOURCE} -name '*.gv.svg') \
		$(shell find ${SOURCE} -name '*.gv.png')
--------------------------------------------------------------------------------
/test/utils/test_random.py:
--------------------------------------------------------------------------------
import re

import pytest

from treevalue.utils import random_hex, random_hex_with_timestamp


@pytest.mark.unittest
class TestUtilsRandom:
    # re.fullmatch already anchors both ends, so no ``^``/``$`` is needed

    def test_random_hex(self):
        default_hex = random_hex()
        assert re.fullmatch(r'[a-f0-9]{32}', default_hex)
        long_hex = random_hex(48)
        assert re.fullmatch(r'[a-f0-9]{48}', long_hex)

    def test_random_hex_with_timestamp(self):
        default_value = random_hex_with_timestamp()
        assert re.fullmatch(r'\d{8}_\d{12}_[a-f0-9]{12}', default_value)
        long_value = random_hex_with_timestamp(48)
        assert re.fullmatch(r'\d{8}_\d{12}_[a-f0-9]{48}', long_value)
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/dump_demo_1.demo.py:
--------------------------------------------------------------------------------
import os
import tempfile

from treevalue import TreeValue, dump, load

if __name__ == '__main__':
    t = TreeValue({'a': 'ab', 'b': 'Cd', 'x': {'c': 'eF', 'd': 'GH'}})
    print('t:', t, sep=os.linesep)

    # Use a file inside a temporary directory: reopening a still-open
    # NamedTemporaryFile by its name is not possible on Windows.
    with tempfile.TemporaryDirectory() as td:
        filename = os.path.join(td, 'tree.bin')
        with open(filename, 'wb') as wf:  # dump t to file
            dump(t, file=wf)

        with open(filename, 'rb') as rf:  # load dt from file
            dt = load(file=rf)

    assert dt == t
    print('dt:', dt, sep=os.linesep)
--------------------------------------------------------------------------------
/docs/source/tutorials/main_idea/native_python_demo.demo.py:
--------------------------------------------------------------------------------
1 | # native plus between dictionaries
def plus(a, b):
    """Recursively add two nested dicts that share the same key structure.

    Leaves are combined with ``+``; values that are dicts on both sides are
    added recursively. Keys are emitted in the insertion order of ``a`` so the
    printed result is deterministic (iterating a ``set`` of keys is not).

    :raises KeyError: if the two dicts do not have exactly the same keys.
    """
    mismatched = set(a).symmetric_difference(b)
    if mismatched:
        raise KeyError(f'Key mismatch between dicts: {sorted(map(repr, mismatched))}')

    _result = {}
    for key in a:
        if isinstance(a[key], dict) and isinstance(b[key], dict):
            _result[key] = plus(a[key], b[key])
        else:
            _result[key] = a[key] + b[key]

    return _result
11 |
12 |
if __name__ == "__main__":
    first = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}
    second = {'a': 11, 'b': 22, 'x': {'c': 30, 'd': 47}}

    # label first, then the recursively summed dict
    print('plus(d1, d2):')
    summed = plus(first, second)
    print(summed)
19 |
--------------------------------------------------------------------------------
/docs/source/best_practice/sklearn/sklearn.demo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.decomposition import PCA
3 |
4 | from treevalue import FastTreeValue
5 |
@FastTreeValue.func()
def fit_transform(x):
    """Fit a PCA on the 2-D array ``x`` and return the transformed data.

    The number of components is ``min(*x.shape)``, the smaller of its two dimensions.
    """
    return PCA(min(*x.shape)).fit_transform(x)
7 |
if __name__ == '__main__':
    # random integer matrices of different shapes, drawn in the order a, then x.c
    leaf_a = np.random.randint(-5, 15, (4, 3))
    leaf_c = np.random.randint(-15, 5, (5, 4))
    data = FastTreeValue({'a': leaf_a, 'x': {'c': leaf_c}})
    print("Original int data:")
    print(data)

    pdata = fit_transform(data)
    print("Fit transformed data:")
    print(pdata)
21 |
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/tree_support_2.gv:
--------------------------------------------------------------------------------
// Result tree of the treelized gcd(t1, t2) from tree_support.demo.py:
// each leaf is gcd of the leaves at the same position in t1 and t2.
digraph tree_support {
    graph [bgcolor = "#ffffff00"];  // fully transparent background (alpha 00)

    subgraph cluster_step2 {
        label = "Result of gcd(t1, t2)"
        root2 [label = "tr = gcd(t1, t2)"];
        n21 [label = "2 = gcd(2, 4)"];
        n22 [label = "6 = gcd(30, 48)"];
        n23 [label = "tr.x = gcd(t1.x, t2.x)"];
        n24 [label = "2 = gcd(4, 6)"];
        n25 [label = "9 = gcd(9, 54)"];
        // edge labels are the keys of the tree
        root2 -> n21 [label = "a"];
        root2 -> n22 [label = "b"];
        root2 -> n23 [label = "x"];
        n23 -> n24 [label = "c"];
        n23 -> n25 [label = "d"];
    }
}
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/calculation.demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from treevalue import FastTreeValue
4 |
if __name__ == '__main__':
    t1 = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    t2 = FastTreeValue({'a': 3, 'b': 7, 'x': {'c': 14, 'd': -5}})

    print("Result of t1 + t2:", t1 + t2, sep=os.linesep)  # __add__ operator
    print("Result of t1 - t2:", t1 - t2, sep=os.linesep)  # __sub__ operator
    print("Result of t1 ^ t2:", t1 ^ t2, sep=os.linesep)  # __xor__ operator
    # mathematics calculation; the label matches the expression actually evaluated
    print("Result of t1 + t2 * (-4 + t1 ** -t2):", t1 + t2 * (-4 + t1 ** -t2), sep=os.linesep)
13 |
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/tree_support.demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from treevalue import FastTreeValue, func_treelize
4 |
5 |
@func_treelize()
def gcd(a, b):  # GCD calculation
    """Greatest common divisor of ``a`` and ``b`` via the Euclidean algorithm.

    Treelized by ``func_treelize``, so it also accepts trees and is applied
    leaf by leaf. ``gcd(a, 0)`` returns ``a`` instead of raising
    ``ZeroDivisionError``.
    """
    while b:
        a, b = b, a % b

    return a
15 |
16 |
if __name__ == '__main__':
    t1 = FastTreeValue({'a': 2, 'b': 30, 'x': {'c': 4, 'd': 9}})
    t2 = FastTreeValue({'a': 4, 'b': 48, 'x': {'c': 6, 'd': 54}})

    # the same treelized function works on trees and on plain numbers
    tree_result = gcd(t1, t2)
    print("Result of gcd(t1, t2):", tree_result, sep=os.linesep)
    scalar_result = gcd(12, 9)
    print("Result of gcd(12, 9):", scalar_result)
23 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/union_demo.demo.py:
--------------------------------------------------------------------------------
1 | from treevalue import FastTreeValue, union
2 |
if __name__ == '__main__':
    t1 = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4, 'e': 5}})
    t2 = FastTreeValue({'a': 11, 'b': 22, 'x': {'c': 30, 'd': 48, 'e': 54}})
    t3 = FastTreeValue({'a': -13, 'b': -7, 'x': {'c': -5, 'd': -3, 'e': -2}})
    t4 = FastTreeValue({'a': -13, 'b': -7, 'x': 8})

    # each case prints its label before union() is evaluated
    cases = (
        ("union(t1, t2):", (t1, t2)),
        ("union(t1, t2, t3):", (t1, t2, t3)),
        ("union(t1, t2, t3, t4):", (t1, t2, t3, t4)),
    )
    for title, trees in cases:
        print(title)
        print(union(*trees))
        if trees is not cases[-1][1]:
            print()
17 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/inherit_demo.demox.py:
--------------------------------------------------------------------------------
1 | from treevalue import FastTreeValue, func_treelize
2 |
3 |
@func_treelize(mode='strict')
def plus(a, b):
    """Add two values; treelized in ``strict`` mode with the default ``inherit``."""
    total = a + b
    return total
7 |
8 |
@func_treelize(mode='strict', inherit=False)
def plusx(a, b):
    """Same addition as ``plus`` but treelized with ``inherit=False``."""
    total = a + b
    return total
12 |
13 |
if __name__ == '__main__':
    t1 = FastTreeValue({'a': 1, 'b': 2, 'x': 9})
    t2 = FastTreeValue({'a': 11, 'b': 22, 'x': {'c': 30, 'd': 48, 'e': 54}})

    # t1.x is a plain value while t2.x is a subtree
    inherited = plus(t1, t2)
    print('plus(t1, t2):', inherited)
    with_scalar = plus(t2, 5)
    print('plus(t2, 5):', with_scalar)
    print()

    # same call through the inherit=False variant
    print('plusx(t1, t2):', plusx(t1, t2))
    print()
24 |
--------------------------------------------------------------------------------
/docs/source/diagrams.mk:
--------------------------------------------------------------------------------
PLANTUMLCLI ?= $(shell which plantumlcli)

SOURCE ?= .
# The find(1) patterns are quoted so the shell cannot glob-expand them against
# the current directory before find receives them.
PUMLS := $(shell find ${SOURCE} -name '*.puml')
PNGS := $(addsuffix .puml.png, $(basename ${PUMLS}))
SVGS := $(addsuffix .puml.svg, $(basename ${PUMLS}))

# None of these targets produce a file named after themselves.
.PHONY: build all clean

# Render every PlantUML source (*.puml) to PNG and SVG next to the source file.
%.puml.png: %.puml
	$(PLANTUMLCLI) -t png -o "$(shell readlink -f $@)" "$(shell readlink -f $<)"

%.puml.svg: %.puml
	$(PLANTUMLCLI) -t svg -o "$(shell readlink -f $@)" "$(shell readlink -f $<)"

build: ${SVGS} ${PNGS}

all: build

# Remove every rendered diagram under SOURCE.
clean:
	rm -rf \
		$(shell find ${SOURCE} -name '*.puml.svg') \
		$(shell find ${SOURCE} -name '*.puml.png')
22 |
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/calculation_self.demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from treevalue import FastTreeValue
4 |
if __name__ == '__main__':
    t1 = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    t2 = FastTreeValue({'a': 3, 'b': 7, 'x': {'c': 14, 'd': -5}})

    print('t1:', t1, sep=os.linesep)
    print('t2:', t2, sep=os.linesep)
    print('t1 + t2:', t1 + t2, sep=os.linesep)
    ids_before = id(t1), id(t2)
    print()

    # in-place addition: the names t1 / t2 must still refer to the same objects
    t1 += t2
    print('After t1 += t2')
    for label, tree in (('t1:', t1), ('t2:', t2)):
        print(label, tree, sep=os.linesep)
    assert (id(t1), id(t2)) == ids_before
20 |
--------------------------------------------------------------------------------
/treevalue/tree/tree/flatten.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | # flatten, unflatten
5 |
6 | from .tree cimport TreeValue
7 | from ..common.storage cimport TreeStorage
8 |
# Recursive worker for flatten(): collects (path, value) pairs for the leaves of
# `st` into `res` (integration/base.pyx calls it with an empty list and path `()`).
# `except *` lets Python exceptions raised inside this void cdef function propagate.
cdef void _c_flatten(TreeStorage st, tuple path, list res) except *
cpdef list flatten(TreeValue tree)

# Value-only variant; presumably collects just the leaf values into `res` — confirm in flatten.pyx.
cdef void _c_flatten_values(TreeStorage st, list res) except *
cpdef list flatten_values(TreeValue tree)

# Key-only variant; presumably collects just the leaf paths into `res` — confirm in flatten.pyx.
cdef void _c_flatten_keys(TreeStorage st, tuple path, list res) except *
cpdef list flatten_keys(TreeValue tree)

# Inverse of flatten: builds a storage from an iterable of (path, value) pairs
# (integration/base.pyx passes zip(paths, values)).
cdef TreeStorage _c_unflatten(object pairs)
# `= *` means the default value of `return_type` is defined in flatten.pyx.
cpdef TreeValue unflatten(object pairs, object return_type= *)
20 |
--------------------------------------------------------------------------------
/test/entry/cli/test_version.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from click.testing import CliRunner
3 |
4 | from treevalue.config.meta import __TITLE__, __VERSION__
5 | from treevalue.entry.cli import treevalue_cli
6 |
7 |
@pytest.mark.unittest
class TestEntryCliVersion:
    def test_version(self):
        """`treevalue -v` exits cleanly and prints both the title and the version."""
        result = CliRunner().invoke(treevalue_cli, args=['-v'])

        assert result.exit_code == 0, (f'Runtime Error (exitcode {result.exit_code}), '
                                       f'The output is:\n{result.output}')
        # case-insensitive containment checks on the captured stdout
        output = result.stdout.lower()
        assert __TITLE__.lower() in output
        assert __VERSION__.lower() in output
18 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/clone_demo.demo.py:
--------------------------------------------------------------------------------
1 | from treevalue import TreeValue, clone, raw
2 |
if __name__ == '__main__':
    t = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4, 'y': raw({'e': 5, 'f': 6})}})

    print("Tree t:")
    print(t)
    print("Id of t.x.y: %x" % id(t.x.y))
    print()
    print()

    # Build each clone once, so the printed tree and the printed id refer to the same object.
    shallow = clone(t)
    print("clone(t):")
    print(shallow)
    print("Id of clone(t).x.y: %x" % id(shallow.x.y))
    print()
    print()

    copied = clone(t, copy_value=True)
    print('clone(t, copy_value=True):')
    print(copied)
    # id of the x.y value, as the label says (not the id of the whole cloned tree)
    print("Id of clone(t, copy_value=True).x.y: %x" % id(copied.x.y))
    print()
    print()
23 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/missing_demo.demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from treevalue import func_treelize, FastTreeValue
4 |
5 |
@func_treelize(mode='outer', missing=lambda: [])
def plus(a, b):
    """Concatenate/add two values; in ``outer`` mode a missing side is supplied as ``[]``."""
    combined = a + b
    return combined
9 |
10 |
if __name__ == '__main__':
    # t1 lacks key `a`, t2 lacks key `x.e`
    t1 = FastTreeValue({'b': [2, 3], 'x': {'c': [5], 'd': [7, 11, 13], 'e': [17, 19]}})
    t2 = FastTreeValue({'a': [23], 'b': [29, 31], 'x': {'c': [37], 'd': [41, 43]}})

    merged = plus(t1, t2)
    print('plus(t1, t2):', merged, sep=os.linesep)
30 |
--------------------------------------------------------------------------------
/treevalue/tree/integration/base.pyx:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | from ..tree.flatten cimport _c_flatten, _c_unflatten
5 |
cdef inline tuple _c_flatten_for_integration(object tv):
    # Split tree `tv` into (values, spec). `values` is the flat list of leaf values;
    # `spec` is (type(tv), paths) with paths in the same order as values, which is
    # what _c_unflatten_for_integration needs to rebuild a tree of the same type.
    # NOTE(review): `paths` is a list, so `spec` is not hashable — confirm that the
    # consumers of this spec (jax/torch registration) do not require hashable aux data.
    cdef list result = []
    # `_detach()` supplies the TreeStorage that _c_flatten expects; result gets (path, value) pairs
    _c_flatten(tv._detach(), (), result)

    cdef list paths = []
    cdef list values = []
    for path, value in result:
        paths.append(path)
        values.append(value)

    return values, (type(tv), paths)
17 |
cdef inline object _c_unflatten_for_integration(object values, tuple spec):
    # Inverse of _c_flatten_for_integration: pair each path in `spec` with the value
    # at the same position (zip stops at the shorter sequence) and wrap the rebuilt
    # storage in the tree type recorded in `spec`.
    cdef object type_
    cdef list paths  # typed target: Cython type-checks spec[1] as a list on unpacking
    type_, paths = spec
    return type_(_c_unflatten(zip(paths, values)))
23 |
--------------------------------------------------------------------------------
/docs/source/best_practice/numpy/with_treevalue.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from treevalue import FastTreeValue
4 |
# T: sequence length, B: batch size used by this demo
T = 3
B = 4
# numpy functions lifted to operate on FastTreeValue leaves
power = FastTreeValue.func()(np.power)
stack = FastTreeValue.func(subside=True)(np.stack)  # list of trees -> one tree
split = FastTreeValue.func(rise=True)(np.split)  # one tree -> list of trees
9 |
10 |
def with_treevalue(batch_):
    """Process a list of nested sample dicts using FastTreeValue.

    Stacks the samples into one tree, casts it to float32, replaces ``b`` with
    ``b ** 2 + 1``, attaches random noise under ``c.noise`` and splits the tree
    back into one tree per sample.

    The batch size is taken from ``len(batch_)`` rather than the module-level
    ``B``, so batches of any length work.

    :param batch_: list of sample dicts with keys ``a``, ``b`` and ``c``.
    :return: ``(per-sample trees, mean of b, a[:, ::2])``.
    """
    batch_size = len(batch_)  # read before stacking turns the list into a single tree
    batch_ = [FastTreeValue(b) for b in batch_]
    batch_ = stack(batch_)
    batch_ = batch_.astype(np.float32)
    batch_.b = power(batch_.b, 2) + 1.0
    batch_.c.noise = np.random.random(size=(batch_size, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = split(batch_, indices_or_sections=batch_size, axis=0)
    return batch_, mean_b, even_index_a
21 |
--------------------------------------------------------------------------------
/docs/source/api_doc/tree/graphics.demo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from treevalue import FastTreeValue, graphics, TreeValue
4 |
5 |
class MyFastTreeValue(FastTreeValue):
    """Empty FastTreeValue subclass used by this graphics demo."""
8 |
9 |
if __name__ == '__main__':
    b_leaf = np.array([[5, 6], [7, 8]])
    e_leaf = np.array([[1, 2], [3, 4]])
    t = MyFastTreeValue({'a': 1, 'b': b_leaf, 'x': {'c': 3, 'd': 4, 'e': e_leaf}})
    # t2 reuses subtrees of t (t.x appears twice, t itself once)
    t2 = TreeValue({'ppp': t.x, 'x': {'t': t, 'y': t.x}})

    named_trees = [(t, 't'), (t2, 't2')]
    g = graphics(*named_trees, title="This is a demo of 2 trees.", cfg={'bgcolor': '#ffffff00'})
    g.render('graphics.dat.gv', format='svg')
28 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/typetrans_demo.demo.py:
--------------------------------------------------------------------------------
1 | from treevalue import TreeValue, method_treelize, FastTreeValue, typetrans
2 |
3 |
class MyTreeValue(TreeValue):
    """TreeValue subclass adding a treelized ``pw`` method."""

    @method_treelize()
    def pw(self):
        # (x + 1) * (x + 2), evaluated through method_treelize
        plus_one = self + 1
        return plus_one * (self + 2)
8 |
9 |
if __name__ == '__main__':
    def _show(title, value):
        # title on one line, value on the next
        print(title)
        print(value)

    t1 = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    _show('t1:', t1)
    _show('t1 ** 2:', t1 ** 2)
    print()

    # Transform t1 to MyTreeValue,
    # __pow__ operator will be disabled and pw method will be enabled.
    t2 = typetrans(t1, MyTreeValue)
    _show('t2:', t2)
    _show('t2.pw():', t2.pw())
    print()
26 |
--------------------------------------------------------------------------------
/treevalue/tree/integration/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Type
2 |
3 | from .general import generic_flatten, generic_unflatten, register_integrate_container, generic_mapping
4 | from .jax import register_for_jax
5 | from .torch import register_for_torch
6 | from ..tree import TreeValue
7 |
8 |
def register_treevalue_class(cls: Type[TreeValue], r_jax: bool = True, r_torch: bool = True):
    """
    Overview:
        Register a TreeValue class into every supported integration (jax, then torch).

    :param cls: TreeValue class to register.
    :param r_jax: Whether to register for jax, default is `True`.
    :param r_torch: Whether to register for torch, default is `True`.
    """
    registrations = ((r_jax, register_for_jax), (r_torch, register_for_torch))
    for enabled, register in registrations:
        if enabled:
            register(cls)
20 | if r_torch:
21 | register_for_torch(cls)
22 |
--------------------------------------------------------------------------------
/treevalue/tree/common/delay.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | from libcpp cimport bool
5 |
# Base class of the delayed (lazily evaluated) proxies; `value()` is
# re-declared by both subclasses below.
cdef class DelayedProxy:
    cpdef object value(self)
    cpdef object fvalue(self)

# Delayed proxy around a callable `func`.
# NOTE(review): `calculated` / `val` look like a compute-once cache of the
# result — confirm against delay.pyx.
cdef class DelayedValueProxy(DelayedProxy):
    cdef readonly object func
    cdef readonly bool calculated  # C++ bool, cimported from libcpp
    cdef object val

    cpdef object value(self)

# Delayed proxy holding `func` together with positional `args` and keyword
# `kwargs`; presumably evaluated as func(*args, **kwargs) — confirm in delay.pyx.
cdef class DelayedFuncProxy(DelayedProxy):
    cdef readonly object func
    cdef readonly tuple args
    cdef readonly dict kwargs
    cdef readonly bool calculated
    cdef object val

    cpdef object value(self)

# Factory for a delayed proxy from (func, args, kwargs).
cdef DelayedProxy _c_delayed_partial(func, args, kwargs)
# Resolve `p` if it is a delayed proxy; `= *` means the default of `is_final` is set in delay.pyx.
cpdef object undelay(object p, bool is_final= *)
28 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/diy_class_self_demo.demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from treevalue import TreeValue, method_treelize
4 |
5 |
class MyTreeValue(TreeValue):
    """TreeValue subclass with a treelized ``append``.

    ``self_copy=True`` presumably writes the result back into ``self`` — see
    method_treelize for the exact semantics.
    """

    # return type will be automatically detected as `MyTreeValue`
    @method_treelize(self_copy=True)
    def append(self, b):
        result = self + b
        return result
11 |
12 |
if __name__ == '__main__':
    t1 = MyTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    t2 = TreeValue({'a': -14, 'b': 9, 'x': {'c': 3, 'd': 8}})
    t3 = TreeValue({'a': 6, 'b': 0, 'x': {'c': -5, 'd': 17}})

    print('t1:', t1, sep=os.linesep)
    t1_id_before = id(t1)

    chained = t1.append(t2).append(t3)
    print('t1.append(t2).append(t3):', chained, sep=os.linesep)
    assert id(t1) == t1_id_before
24 |
--------------------------------------------------------------------------------
/docs/source/api_doc/config/meta.rst:
--------------------------------------------------------------------------------
1 | treevalue.config.meta
2 | ==========================
3 |
4 | .. automodule:: treevalue.config.meta
5 |
6 | \_\_TITLE\_\_
7 | ------------------
8 |
9 | .. autodata:: treevalue.config.meta.__TITLE__
10 | :annotation:
11 |
12 |
13 | \_\_VERSION\_\_
14 | ------------------
15 |
16 | .. autodata:: treevalue.config.meta.__VERSION__
17 | :annotation:
18 |
19 |
20 | \_\_DESCRIPTION\_\_
21 | ----------------------
22 |
23 | .. autodata:: treevalue.config.meta.__DESCRIPTION__
24 | :annotation:
25 |
26 |
27 | \_\_AUTHOR\_\_
28 | ------------------
29 |
30 | .. autodata:: treevalue.config.meta.__AUTHOR__
31 | :annotation:
32 |
33 |
34 | \_\_AUTHOR_EMAIL\_\_
35 | ----------------------
36 |
37 | .. autodata:: treevalue.config.meta.__AUTHOR_EMAIL__
38 | :annotation:
39 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/dump_compression_demo.demo.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import os
3 |
4 | from treevalue import FastTreeValue, dumps, loads
5 |
if __name__ == '__main__':
    t = FastTreeValue({'a': 'ab', 'b': 'Cd', 'x': {'c': 'eF', 'd': 'GH'}})
    st = t.upper
    print('st:', st, sep=os.linesep)  # st is a function tree
    print('st():', st(), sep=os.linesep)  # st() is an upper-case-string tree
    print()

    # dump with gzip compression applied to the serialized data
    binary = dumps(st, compress=(gzip.compress, gzip.decompress))
    print('Length of binary:', len(binary))

    # compression function is not needed here
    dst = loads(binary, type_=FastTreeValue)
    print('dst:', dst, sep=os.linesep)
    assert st() == dst()  # st() should be equal to dst()
    print('dst():', dst(), sep=os.linesep)
21 |
--------------------------------------------------------------------------------
/treevalue/tree/integration/jax.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from functools import wraps
3 |
try:
    # only probed for availability; the names themselves are not used here
    import jax
    from jax.tree_util import register_pytree_node
except (ModuleNotFoundError, ImportError):
    from .cjax import register_for_jax as _original_register_for_jax


    @wraps(_original_register_for_jax)
    def register_for_jax(cls):
        # Fallback that keeps the public API callable without jax.tree_util.
        # stacklevel=2 attributes the warning to the caller of register_for_jax
        # instead of to this module.
        warnings.warn(f'Jax doesn\'t have tree_util module due to either not installed '
                      f'or the installed version is too low, '
                      f'so the registration of {cls!r} will be ignored.',
                      stacklevel=2)
else:
    from .cjax import register_for_jax
    from ..tree import TreeValue
    from ..general import FastTreeValue

    # jax is available: register the built-in tree classes at import time
    register_for_jax(TreeValue)
    register_for_jax(FastTreeValue)
22 | register_for_jax(FastTreeValue)
23 |
--------------------------------------------------------------------------------
/docs/source/notebook.mk:
--------------------------------------------------------------------------------
JUPYTER ?= $(shell which jupyter)
NBCONVERT ?= ${JUPYTER} nbconvert

SOURCE ?= .
# Source notebooks only; executed copies (*.result.ipynb) are build outputs.
# The find(1) patterns are quoted so the shell cannot glob-expand them first.
IPYNBS := $(shell find ${SOURCE} -name '*.ipynb' -not -name '*.result.ipynb')
RESULTS := $(addsuffix .result.ipynb, $(basename ${IPYNBS}))

# None of these targets produce a file named after themselves.
.PHONY: build all clean

# Copy the notebook, then execute the copy in place from the notebook's own
# directory, with that directory prepended to PYTHONPATH.
%.result.ipynb: %.ipynb
	cp "$(shell readlink -f $<)" "$(shell readlink -f $@)" && \
	cd "$(shell dirname $(shell readlink -f $<))" && \
	PYTHONPATH="$(shell dirname $(shell readlink -f $<)):${PYTHONPATH}" \
	$(NBCONVERT) --to notebook --inplace --execute "$(shell readlink -f $@)"

build: ${RESULTS}

all: build

# Delete executed copies and clear outputs stored in the source notebooks.
clean:
	rm -rf \
		$(shell find ${SOURCE} -name '*.result.ipynb')
	for nb in ${IPYNBS}; do \
		if [ -f $$nb ]; then \
			$(NBCONVERT) --clear-output --inplace $$nb; \
		fi; \
	done;
--------------------------------------------------------------------------------
/docs/source/api_doc/tree/func.rst:
--------------------------------------------------------------------------------
1 | treevalue.tree.func
2 | =============================
3 |
4 | .. py:currentmodule:: treevalue.tree.func
5 |
6 | .. _apidoc_tree_func_functreelize:
7 |
8 | func_treelize
9 | --------------------
10 |
11 | .. autofunction:: func_treelize
12 |
13 |
14 | .. _apidoc_tree_func_methodtreelize:
15 |
16 | method_treelize
17 | --------------------
18 |
19 | .. autofunction:: method_treelize
20 |
21 |
22 | .. _apidoc_tree_func_classmethodtreelize:
23 |
24 | classmethod_treelize
25 | --------------------
26 |
27 | .. autofunction:: classmethod_treelize
28 |
29 |
30 | MISSING_NOT_ALLOW
31 | -------------------------
32 |
33 | .. autodata:: MISSING_NOT_ALLOW
34 | :annotation:
35 |
36 |
37 | AUTO_DETECT_RETURN_TYPE
38 | ----------------------------
39 |
40 | .. autodata:: AUTO_DETECT_RETURN_TYPE
41 | :annotation:
42 |
--------------------------------------------------------------------------------
/docs/source/comparison/generic.rst:
--------------------------------------------------------------------------------
1 | What and Why to Compare
2 | ====================================
3 |
4 | In this part, TreeValue will be compared with several different similar libraries \
5 | in terms of functionality and performance. It contains the following sections:
6 |
7 | * Run Environment Information
8 | * Comparison to DM-Tree
9 |
More comparisons will be added soon.
11 |
12 | .. note::
13 | Please note that **the core advantage of treevalue is actually \
14 | to model and simplify the writing process of programs, \
15 | not to provide higher computing performance**. \
16 | In fact, in large-scale operations, the performance advantage of treevalue \
17 | will tend to be marginal. Therefore, this part of the comparison will not only \
18 | focus on the performance differences, but will also briefly discuss \
19 | the capabilities and potential of the different libraries.
--------------------------------------------------------------------------------
/test/tree/common/test_base.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | import pytest
4 |
5 | from treevalue.tree.common import raw, unraw, RawWrapper
6 |
7 |
@pytest.mark.unittest
class TestTreeBase:
    """Tests for raw()/unraw() and the RawWrapper they produce."""

    def test_raw(self):
        # non-dict values come back equal
        for plain in (1, 'sdklfgj'):
            assert raw(plain) == plain

        payload = {'a': 1, 'b': 2}
        wrapped = raw(payload)
        assert isinstance(wrapped, RawWrapper)
        assert wrapped.value() is payload

    def test_unraw(self):
        for plain in (1, 'sdklfgj'):
            assert unraw(plain) == plain

        payload = {'a': 1, 'b': 2}
        assert unraw(raw(payload)) is payload

    def test_pickle(self):
        wrapped = raw({'a': 1, 'b': 2})

        restored = pickle.loads(pickle.dumps(wrapped))

        assert isinstance(restored, RawWrapper)
        assert restored.value() == {'a': 1, 'b': 2}
37 |
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/slice_index_operation.gv:
--------------------------------------------------------------------------------
// Slicing a tree slices every leaf: results of t[::-1] and t.x[1:]
// (t as drawn in index_operation.gv).
digraph slice_index {
    graph [bgcolor = "#ffffff00"];  // fully transparent background (alpha 00)

    subgraph cluster_t_xx_1 {
        label = "Result of t[::-1]"
        root2 [label = "t[::-1]"];
        n21 [label = "[3, 2, 1] = t.a[::-1]"];
        n22 [label = "[16, 9, 4] = t.b[::-1]"];
        n23 [label = "t[::-1].x"];
        n24 [label = "[17, 13, 11] = t.x.c[::-1]"];
        n25 [label = "[-8, -4, -2] = t.x.d[::-1]"];
        root2 -> n21 [label = "a"];
        root2 -> n22 [label = "b"];
        root2 -> n23 [label = "x"];
        n23 -> n24 [label = "c"];
        n23 -> n25 [label = "d"];
    }

    subgraph cluster_t_x_1_0 {
        label = "Result of t.x[1:]"
        n33 [label = "t.x[1:]"];
        n34 [label = "[13, 17] = t.x.c[1:]"];
        n35 [label = "[-4, -8] = t.x.d[1:]"];
        n33 -> n34 [label = "c"];
        n33 -> n35 [label = "d"];
    }
}
--------------------------------------------------------------------------------
/test/tree/general/test_fast.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from treevalue.tree import FastTreeValue, method_treelize
4 | from .base import get_fasttreevalue_test
5 |
6 |
class MyFastTreeValue(FastTreeValue):
    """FastTreeValue whose ``+`` runs in ``outer`` mode, filling missing keys with 0."""

    @method_treelize(missing=0, mode='outer')
    def __add__(self, other):
        return self + other

    # Must be spelled ``__radd__``: the former ``__radd`` was name-mangled to
    # ``_MyFastTreeValue__radd``, so Python never used it for reflected addition.
    @method_treelize(missing=0, mode='outer')
    def __radd__(self, other):
        return other + self
15 |
16 |
@pytest.mark.unittest
class TestTreeGeneralFast(get_fasttreevalue_test(FastTreeValue)):
    def test_my_fast_tree_value(self):
        left = MyFastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4, 'e': 7}})
        right = MyFastTreeValue({'a': 11, 'b': 22, 'c': 4, 'x': {'c': 33, 'd': 5}})
        # keys present on only one side are kept thanks to the outer-mode __add__
        merged = left + right
        assert merged == MyFastTreeValue({'a': 12, 'b': 24, 'c': 4, 'x': {'c': 36, 'd': 9, 'e': 7}})

        # subtraction is not overridden, so the mismatched keys raise KeyError
        with pytest.raises(KeyError):
            _ = left - right
26 |
--------------------------------------------------------------------------------
/docs/source/_templates/page.html:
--------------------------------------------------------------------------------
{% extends "!page.html" %}
{# Version banner (sphinx-multiversion): shown on every page that is not the latest version. #}
{% block body %}
{% if current_version and latest_version and current_version != latest_version %}
<p>
  <strong>
    {% if current_version.is_released %}
    You're reading an old version of this documentation.
    If you want up-to-date information, please have a look at
    <a href="{{ vpathto(latest_version.name) }}">{{ latest_version.name }}</a>.
    {% else %}
    You're reading the documentation for a development version.
    For the latest released version, please have a look at
    <a href="{{ vpathto(latest_version.name) }}">{{ latest_version.name }}</a>.
    {% endif %}
  </strong>
</p>
{% endif %}
{{ super() }}
{% endblock %}
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/edit_tree.demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from treevalue import FastTreeValue
4 |
if __name__ == '__main__':
    def show(title, tree):
        # title and tree on separate lines
        print(title, tree, sep=os.linesep)

    t = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    show("Original tree:", t)

    # Get values
    print("Value of t.a: ", t.a)
    print("Value of t.x.c:", t.x.c)
    show("Value of t.x:", t.x)

    # Set values (leaf, nested leaf, whole subtree, new dict-valued key)
    t.a = 233
    show("Value after t.a = 233:", t)
    t.x.d = -1
    show("Value after t.x.d = -1:", t)
    t.x = FastTreeValue({'e': 5, 'f': 6})
    show("Value after t.x = FastTreeValue({'e': 5, 'f': 6}):", t)
    t.x.g = {'e': 5, 'f': 6}
    show("Value after t.x.g = {'e': 5, 'f': 6}:", t)

    # Delete values
    del t.x.g
    show("Value after del t.x.g:", t)
27 |
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/tree_support_1.gv:
--------------------------------------------------------------------------------
// The two input trees t1 and t2 of tree_support.demo.py, drawn side by side.
digraph tree_support {
    graph [bgcolor = "#ffffff00"];  // fully transparent background (alpha 00)

    subgraph cluster_t1 {
        label = "Tree 't1'"
        root0 [label = "t1"];
        n01 [label = "2"];
        n02 [label = "30"];
        n03 [label = "t1.x"];
        n04 [label = "4"];
        n05 [label = "9"];
        // edge labels are the keys of the tree
        root0 -> n01 [label = "a"];
        root0 -> n02 [label = "b"];
        root0 -> n03 [label = "x"];
        n03 -> n04 [label = "c"];
        n03 -> n05 [label = "d"];
    }

    subgraph cluster_t2 {
        label = "Tree 't2'"
        root1 [label = "t2"];
        n11 [label = "4"];
        n12 [label = "48"];
        n13 [label = "t2.x"];
        n14 [label = "6"];
        n15 [label = "54"];
        root1 -> n11 [label = "a"];
        root1 -> n12 [label = "b"];
        root1 -> n13 [label = "x"];
        n13 -> n14 [label = "c"];
        n13 -> n15 [label = "d"];
    }
}
--------------------------------------------------------------------------------
/docs/source/tutorials/main_idea/treevalue_demo.gv:
--------------------------------------------------------------------------------
// Example tree `t`: the diamond is the root, ellipses (the Graphviz default
// shape) are subtrees, boxes are leaf values; edge labels are keys.
digraph {
    label = "A simple example of TreeValue t";
    graph [bgcolor = "#ffffff00"]

    root [label = "t" shape = diamond];
    n1 [label = "array([[1, 2],\n [3, 4]])" shape = box];
    n2 [label = "2" shape = box];
    n3 [label = "t.x"];
    n4 [label = "[1, 2]" shape = box];
    n5 [label = "{3, 4}" shape = box];
    n6 [label = "t.x.e"];
    n7 [label = "{'a': 1, 'b': 2}" shape = box];
    n8 [label = "4" shape = box];
    n9 [label = "t.y"];
    n10 [label = "5" shape = box];
    n11 [label = "'string'" shape = box];

    root -> n1 [label = "a"];
    root -> n2 [label = "b"];
    root -> n3 [label = "x"];
    root -> n9 [label = "y"];
    n3 -> n4 [label = "c"];
    n3 -> n5 [label = "d"];
    n3 -> n6 [label = "e"];
    n6 -> n7 [label = "f"];
    n6 -> n8 [label = "g"];
    n9 -> n10 [label = "h"];
    n9 -> n11 [label = "a"];
}
--------------------------------------------------------------------------------
/test/tree/integration/test_init.py:
--------------------------------------------------------------------------------
1 | from unittest import skipUnless
2 |
3 | import pytest
4 |
5 | from treevalue import register_treevalue_class, FastTreeValue
6 |
7 | try:
8 | import torch
9 | except (ImportError, ModuleNotFoundError):
10 | torch = None
11 |
12 | try:
13 | import jax
14 | except (ModuleNotFoundError, ImportError):
15 | jax = None
16 |
17 |
@pytest.mark.unittest
class TestTreeIntegrationInit:
    @skipUnless(torch and jax, 'Torch and jax required.')
    def test_register_custom_class_all(self):
        """With torch and jax installed, registration must not fall back to a no-op."""
        import warnings

        class MyTreeValue(FastTreeValue):
            pass

        # `pytest.warns(None)` was deprecated in pytest 7 and removed in pytest 8,
        # so warnings are recorded with the stdlib instead. Only the registration
        # fallbacks are checked (the jax fallback message ends with "will be ignored.";
        # NOTE(review): confirm the torch fallback uses the same wording). Warnings
        # emitted by torch/jax themselves are not treated as failures.
        with warnings.catch_warnings(record=True) as records:
            warnings.simplefilter('always')
            register_treevalue_class(MyTreeValue)

        fallbacks = [str(w.message) for w in records if 'will be ignored' in str(w.message)]
        assert not fallbacks, fallbacks

    @skipUnless(not torch or not jax, 'Not all torch and jax required.')
    def test_register_custom_class_some(self):
        class MyTreeValue(FastTreeValue):
            pass

        with pytest.warns(UserWarning):
            register_treevalue_class(MyTreeValue)
35 |
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/index_operation.gv:
--------------------------------------------------------------------------------
// Indexing a tree indexes every leaf: tree t and the result of t[0].
digraph index {
    graph [bgcolor = "#ffffff00"];  // fully transparent background (alpha 00)

    subgraph cluster_t {
        label = "Tree 't'"
        root0 [label = "t"];
        n01 [label = "[1, 2, 3]"];
        n02 [label = "[4, 9, 16]"];
        n03 [label = "t.x"];
        n04 [label = "[11, 13, 17]"];
        n05 [label = "[-2, -4, -8]"];
        // edge labels are the keys of the tree
        root0 -> n01 [label = "a"];
        root0 -> n02 [label = "b"];
        root0 -> n03 [label = "x"];
        n03 -> n04 [label = "c"];
        n03 -> n05 [label = "d"];
    }

    subgraph cluster_t_0 {
        label = "Result of t[0]"
        root1 [label = "t[0]"];
        n11 [label = "1 = t.a[0]"];
        n12 [label = "4 = t.b[0]"];
        n13 [label = "t[0].x"];
        n14 [label = "11 = t.x.c[0]"];
        n15 [label = "-2 = t.x.d[0]"];
        root1 -> n11 [label = "a"];
        root1 -> n12 [label = "b"];
        root1 -> n13 [label = "x"];
        n13 -> n14 [label = "c"];
        n13 -> n15 [label = "d"];
    }

}
--------------------------------------------------------------------------------
/treevalue/tree/func/cfunc.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 | # Declarations for the function-treelize machinery. The implementing .pyx must define each name with exactly these signatures.
4 | from libcpp cimport bool
5 | # `_e_tree_mode` is the C-level tree-mode enum (the demos pass mode names such as 'strict' / 'outer' as strings to `func_treelize`).
6 | from .modes cimport _e_tree_mode
7 | # Internal runners: apply `func` over the tree-valued `args` / `kwargs` under `mode`; how the two variants differ is defined in the .pyx (TODO confirm).
8 | cdef object _c_wrap_func_treelize_run(object func, list args, dict kwargs, _e_tree_mode mode, bool inherit,
9 | bool allow_missing, object missing_func, bool delayed)
10 | cdef object _c_func_treelize_run(object func, list args, dict kwargs, _e_tree_mode mode, bool inherit,
11 | bool allow_missing, object missing_func, bool delayed)
12 | # `func_treelize` below is the public decorator factory; `= *` marks a parameter whose default value is given in the .pyx.
13 | cpdef object _d_func_treelize(object func, object mode, object return_type, bool inherit, object missing,
14 | bool delayed, object subside, object rise)
15 | cdef object _c_common_value(object item)
16 | cdef tuple _c_missing_process(object missing)  # presumably turns `missing` into the (allow_missing, missing_func) pair the runners take -- TODO confirm
17 | cpdef object func_treelize(object mode= *, object return_type= *, bool inherit= *, object missing= *,
18 | bool delayed= *, object subside= *, object rise= *)
19 |
--------------------------------------------------------------------------------
/docs/source/best_practice/numpy/numpy.demo.py:
--------------------------------------------------------------------------------
import copy

import numpy as np
from with_treevalue import with_treevalue
from without_treevalue import without_treevalue

T, B = 3, 4


def get_data():
    """
    Create one random sample.

    :return: A dict with ``'a'`` (floats, shape ``(T, 8)``), ``'b'`` (floats,
        shape ``(6,)``) and ``'c'``, a nested dict holding ``'d'`` (ints drawn
        from ``[0, 10)``, shape ``(1,)``).
    """
    a = np.random.random(size=(T, 8))
    b = np.random.random(size=(6,))
    d = np.random.randint(0, 10, size=(1,))
    return {'a': a, 'b': b, 'c': {'d': d}}


if __name__ == "__main__":
    batch = [get_data() for _ in range(B)]

    # Each implementation gets its own deep copy, because they may mutate their input.
    batch0, mean0, even_index_a0 = without_treevalue(copy.deepcopy(batch))
    batch1, mean1, even_index_a1 = with_treevalue(copy.deepcopy(batch))

    # The two pipelines must agree on the mean of the transformed ``b`` ...
    assert np.abs(mean0 - mean1) < 1e-6
    print('mean0 & mean1:', mean0, mean1)
    print('\n')

    # ... and on the even-indexed rows of ``a``.
    assert np.abs((even_index_a0 - even_index_a1).max()) < 1e-6
    print('even_index_a0:', even_index_a0)
    print('even_index_a1:', even_index_a1)

    # Both returned batches still hold B samples.
    for returned in (batch0, batch1):
        assert len(returned) == B
35 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/strict_demo_show.demox.py:
--------------------------------------------------------------------------------
import numpy as np

from treevalue import FastTreeValue, func_treelize


@func_treelize(mode='strict')
def plus(a, b):
    """
    Add ``a`` and ``b`` leaf by leaf in ``strict`` mode.

    Each leaf pair is printed before it is added, so the demo shows exactly
    which values reach the function.
    """
    print("Current a and b:", type(a), a, type(b), b)
    return a + b


def _x_branch():
    """Build a fresh copy of the ``x`` sub-tree used by both ``t2`` and ``t3``."""
    return {'c': 3, 'd': 4.8, 'e': np.array([[2, 3], [5, 6]])}


if __name__ == '__main__':
    t1 = FastTreeValue({'a': 11, 'b': 22, 'x': {'c': 30, 'd': 48, 'e': 54}})
    t2 = FastTreeValue({'a': np.array([1, 3, 5, 7]), 'b': np.array([2, 5, 8, 11]), 'x': _x_branch()})
    # t3 differs from t2 only in ``b``, which is a plain list instead of an ndarray.
    t3 = FastTreeValue({'a': np.array([1, 3, 5, 7]), 'b': [2, 5, 8, 11], 'x': _x_branch()})

    # The second call is expected to fail: 22 + [2, 5, 8, 11] (int + list) raises TypeError.
    for title, other in (('plus(t1, t2):', t2), ('plus(t1, t3):', t3)):
        print(title, plus(t1, other))
        print()
37 |
--------------------------------------------------------------------------------
/treevalue/tree/integration/general.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 | # Declarations for generic flattening of nested containers (dict, list/tuple, namedtuple, TreeValue).
4 | from libcpp cimport bool
5 | # Per-container flatten/unflatten pairs: `*_flatten` returns a tuple; `*_unflatten` rebuilds the container from a list of values plus a spec (presumably the one produced by the matching flatten).
6 | cdef tuple _dict_flatten(object d)
7 | cdef object _dict_unflatten(list values, tuple spec)
8 |
9 | cdef tuple _list_and_tuple_flatten(object l)
10 | cdef object _list_and_tuple_unflatten(list values, object spec)
11 |
12 | cdef tuple _namedtuple_flatten(object l)
13 | cdef object _namedtuple_unflatten(list values, object spec)
14 |
15 | cdef tuple _treevalue_flatten(object l)
16 | cdef object _treevalue_unflatten(list values, tuple spec)
17 |
18 | cdef bool _is_namedtuple_instance(pytree) except*  # `except*`: a C bool cannot signal errors, so Cython checks for a raised Python exception after each call
19 |
20 | cpdef void register_integrate_container(object type_, object flatten_func, object unflatten_func) except*  # Python-visible hook for registering an extra container type with its own flatten/unflatten functions
21 |
22 | cdef tuple _c_get_flatted_values_and_spec(object v)
23 | cdef object _c_get_object_from_flatted(object values, object type_, object spec)
24 | # Public entry points.
25 | cpdef object generic_flatten(object v)
26 | cpdef object generic_unflatten(object v, tuple gspec)
27 | cpdef object generic_mapping(object v, object func)
28 |
--------------------------------------------------------------------------------
/test/tests/utils.py:
--------------------------------------------------------------------------------
from functools import wraps
from typing import Callable


def eq_extend(func: Callable[..., bool]) -> Callable[..., bool]:
    """
    Lift a leaf-level comparison ``func(a, b, *args, **kwargs)`` to nested containers.

    * Two dicts match when they have the same key set and every value pair matches.
    * Two tuples (or two lists) match when they have the same length and every
      element pair matches.
    * Any other pair, including a tuple against a list, is handed to ``func``.

    Extra positional and keyword arguments are forwarded at every level.
    Comparison stops at the first mismatching pair.
    """
    @wraps(func)
    def _new_func(a, b, *args, **kwargs):
        if isinstance(a, dict) and isinstance(b, dict):
            if a.keys() != b.keys():
                return False
            # Generator, not list: stop at the first mismatch instead of comparing everything.
            return all(_new_func(a[key], b[key], *args, **kwargs) for key in a)

        if (isinstance(a, tuple) and isinstance(b, tuple)) \
                or (isinstance(a, list) and isinstance(b, list)):
            if len(a) != len(b):
                return False
            return all(_new_func(ai, bi, *args, **kwargs) for ai, bi in zip(a, b))

        return func(a, b, *args, **kwargs)

    return _new_func


@eq_extend
def float_eq(a, b, eps=1e-5):
    """Return True when ``|a - b|`` is strictly below ``|eps|`` (the sign of ``eps`` is ignored)."""
    return abs(a - b) < abs(eps)
30 |
--------------------------------------------------------------------------------
/docs/source/tutorials/installation/index.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ===================
3 |
4 | Treevalue is currently hosted on PyPI. It requires Python >= 3.6.
5 |
6 | You can simply install Treevalue from PyPI with the following command:
7 |
8 | .. code:: shell
9 |
10 | pip install treevalue
11 |
12 | You can also install the newest version from GitHub:
13 |
14 | .. code:: shell
15 |
16 | pip install -U git+https://github.com/opendilab/treevalue.git@main
17 |
18 | After installation, open your shell console and use \
19 | the CLI as shown in the script below.
20 |
21 | .. literalinclude:: cli_demo.demo.sh
22 | :language: shell
23 | :linenos:
24 |
25 | .. literalinclude:: cli_demo.demo.sh.txt
26 | :language: text
27 | :linenos:
28 |
29 | In the newest version of treevalue, the CLI also supports some \
30 | data processing. Here is the version and help display.
31 |
32 | Treevalue is still under development; you can also check out the documentation of the stable version at `https://opendilab.github.io/treevalue/ <https://opendilab.github.io/treevalue/>`_.
33 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/diy_class_demo.demo.py:
--------------------------------------------------------------------------------
import os

from treevalue import classmethod_treelize, TreeValue, method_treelize


class MyTreeValue(TreeValue):
    """Demo subclass with a tree-wise ``append`` method and a tree-wise ``sum`` classmethod."""

    # Treelized instance method; the return type is detected automatically as `MyTreeValue`.
    @method_treelize()
    def append(self, b):
        print("Append arguments:", self, b)
        return self + b

    # ``classmethod_treelize`` goes *under* ``classmethod``; the return type is again `MyTreeValue`.
    @classmethod
    @classmethod_treelize()
    def sum(cls, *args):
        print("Sum arguments:", cls, *args)
        return sum(args)  # the builtin ``sum``: class-level names are not visible inside methods


if __name__ == '__main__':
    t1 = MyTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    t2 = TreeValue({'a': -14, 'b': 9, 'x': {'c': 3, 'd': 8}})
    t3 = TreeValue({'a': 6, 'b': 0, 'x': {'c': -5, 'd': 17}})

    chained = t1.append(t2).append(t3)
    print('t1.append(t2).append(t3):', chained, sep=os.linesep)

    total = MyTreeValue.sum(t1, t2, t3)
    print('MyTreeValue.sum(t1, t2, t3):', total, sep=os.linesep)
31 |
--------------------------------------------------------------------------------
/treevalue/tree/integration/torch.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from functools import wraps
3 | # Hook TreeValue classes into torch's pytree utilities (torch.utils._pytree) when they are available.
4 | from ..tree import register_dict_type
5 |
6 | try:  # probe for torch and for the private pytree registration hook that .ctorch calls
7 | import torch
8 | from torch.utils._pytree import _register_pytree_node  # imported only to check that it exists; .ctorch performs the actual call
9 | except (ModuleNotFoundError, ImportError):
10 | from .ctorch import register_for_torch as _original_register_for_torch  # used only so @wraps can copy its name and docstring
11 |
12 | # Fallback used when torch or its pytree hook is unavailable: it only warns and registers nothing.
13 | @wraps(_original_register_for_torch)
14 | def register_for_torch(cls):
15 | warnings.warn(f'Pytree module is not included in the Torch installation '
16 | f'or the installed version is too low, '
17 | f'so the registration of {cls!r} will be ignored.')  # no category given, so this is a UserWarning
18 | else:
19 | from .ctorch import register_for_torch
20 | from ..tree import TreeValue
21 | from ..general import FastTreeValue
22 | # The built-in tree classes are registered eagerly, as a side effect of importing this module.
23 | register_for_torch(TreeValue)
24 | register_for_torch(FastTreeValue)
25 |
26 | try:  # torch.nn.ModuleDict is optional; when importable it is registered with register_dict_type, passing ModuleDict.items (presumably how its entries are enumerated)
27 | from torch.nn import ModuleDict
28 | except (ModuleNotFoundError, ImportError):
29 | pass
30 | else:
31 | register_dict_type(ModuleDict, ModuleDict.items)
32 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/rise_demo_1.demo.py:
--------------------------------------------------------------------------------
from treevalue import TreeValue, subside, rise

if __name__ == '__main__':
    # Same input structure as the ``subside`` demo.
    t1 = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    t2 = TreeValue({'a': -14, 'b': 9, 'x': {'c': 3, 'd': 8}})
    t3 = TreeValue({'a': 6, 'b': 0, 'x': {'c': -5, 'd': 17}})
    t4 = TreeValue({'a': 0, 'b': -17, 'x': {'c': -8, 'd': 15}})
    t5 = TreeValue({'a': 3, 'b': 9, 'x': {'c': 11, 'd': -17}})
    st = {'first': (t1, t2), 'second': (t3, {'x': t4, 'y': t5})}
    tx = subside(st)

    # Rising process: ``rise`` turns the subsided tree back into the original nested structure.
    st2 = rise(tx)
    assert st2 == st
    print('st2:', st2)

    # Print each leaf tree of the recovered structure, following its key path.
    for label, key_path in (
            ("st2['first'][0]:", ('first', 0)),
            ("st2['first'][1]:", ('first', 1)),
            ("st2['second'][0]:", ('second', 0)),
            ("st2['second'][1]['x']:", ('second', 1, 'x')),
            ("st2['second'][1]['y']:", ('second', 1, 'y')),
    ):
        print(label)
        node = st2
        for key in key_path:
            node = node[key]
        print(node)
32 |
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/calculation_sub_and_xor.gv:
--------------------------------------------------------------------------------
1 | digraph calculation_sub_and_xor {
2 | graph [bgcolor = "#ffffff00"];
3 |
4 | subgraph cluster_sub {
5 | label = "Result of t1 - t2"
6 | root1 [label = "t1 - t2"];
7 | n11 [label = "-2 = 1 - 3"];
8 | n12 [label = "-5 = 2 - 7"];
9 | n13 [label = "t1.x - t2.x"];
10 | n14 [label = "-11 = 3 - 14"];
11 | n15 [label = "9 = 4 - (-5)"];
12 | root1 -> n11 [label = "a"];
13 | root1 -> n12 [label = "b"];
14 | root1 -> n13 [label = "x"];
15 | n13 -> n14 [label = "c"];
16 | n13 -> n15 [label = "d"];
17 | }
18 |
19 | subgraph cluster_xor {
20 | label = "Result of t1 ^ t2"
21 | root2 [label = "t1 ^ t2"];
22 | n21 [label = "2 = 1 ^ 3"];
23 | n22 [label = "5 = 2 ^ 7"];
24 | n23 [label = "t1.x ^ t2.x"];
25 | n24 [label = "13 = 3 ^ 14"];
26 | n25 [label = "-1 = 4 ^ (-5)"];
27 | root2 -> n21 [label = "a"];
28 | root2 -> n22 [label = "b"];
29 | root2 -> n23 [label = "x"];
30 | n23 -> n24 [label = "c"];
31 | n23 -> n25 [label = "d"];
32 | }
33 | }
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/custom.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Custom issue template
3 | about: Describe this issue template's purpose here.
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | - [ ] I have marked all applicable categories:
11 | + [ ] installation bug
12 | + [ ] exception-raising bug
13 | + [ ] data model bug
14 | + [ ] tree utils bug
15 | + [ ] function treelize bug (including method and classmethod)
16 | + [ ] code design/refactor
17 | + [ ] documentation request
18 | + [ ] new feature request
19 | - [ ] I have visited the [readme](https://github.com/opendilab/treevalue/blob/main/README.md) and [doc](https://opendilab.github.io/treevalue/main/index.html)
20 | - [ ] I have searched through the [issue tracker](https://github.com/opendilab/treevalue/issues) and [pr tracker](https://github.com/opendilab/treevalue/pulls)
21 | - [ ] I have mentioned version numbers, operating system and environment, where applicable:
22 |
23 | ```python
24 | import sys
25 |
26 | import treevalue
27 |
28 | print(treevalue.__version__, sys.version, sys.platform)
29 | ```
30 |
--------------------------------------------------------------------------------
/treevalue/utils/random.py:
--------------------------------------------------------------------------------
import random
from datetime import datetime


def random_hex(length: int = 32) -> str:
    """
    Overview:
        Generate a string of random hexadecimal digits.

    Arguments:
        - length (:obj:`int`): Number of hex digits to produce, default is `32`.

    Returns:
        - string (:obj:`str`): Lower-case hex string of ``length`` characters \
            (empty when ``length <= 0``).

    Examples:
        >>> random_hex()  # 'ca7f14b25aa4498efdacb54e9ff72784'
    """
    # One ``randint(0, 15)`` draw per digit, rendered as a single lower-case hex character.
    digits = (format(random.randint(0, 15), 'x') for _ in range(length))
    return ''.join(digits)


def random_hex_with_timestamp(length: int = 12) -> str:
    """
    Overview:
        Generate a random hex string prefixed with the current local timestamp.

    Arguments:
        - length (:obj:`int`): Number of random hex digits after the timestamp, default is `12`.

    Returns:
        - string (:obj:`str`): ``<YYYYmmdd>_<HHMMSS><microseconds>_<hex digits>``.

    Examples:
        >>> random_hex_with_timestamp()  # '20210729_202059576266_69603d64afad'
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S%f")
    return f'{stamp}_{random_hex(length)}'
37 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/diy_class_x_demo_2.demo.py:
--------------------------------------------------------------------------------
import os

from treevalue import general_tree_value, FastTreeValue


def _outer_with_missing(default, *names):
    """Build per-operator configs: ``outer`` mode, with ``default`` filling in a missing side."""
    return {name: dict(missing=default, mode='outer') for name in names}


class AddZeroTreeValue(general_tree_value(methods=_outer_with_missing(0, '__add__', '__radd__', '__iadd__'))):
    """Tree value class whose ``+``, reflected ``+`` and ``+=`` use outer mode with 0 for missing keys."""


class MulOneTreeValue(general_tree_value(methods=_outer_with_missing(1, '__mul__', '__rmul__', '__imul__'))):
    """Tree value class whose ``*``, reflected ``*`` and ``*=`` use outer mode with 1 for missing keys."""


if __name__ == '__main__':
    t1 = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3}})
    t2 = FastTreeValue({'a': 11, 'x': {'c': 30, 'd': 47}})

    # __add__ with default value 0
    at1, at2 = t1.type(AddZeroTreeValue), t2.type(AddZeroTreeValue)
    print('at1 + at2:', at1 + at2, sep=os.linesep)

    # __mul__ with default value 1
    mt1, mt2 = t1.type(MulOneTreeValue), t2.type(MulOneTreeValue)
    print('mt1 * mt2:', mt1 * mt2, sep=os.linesep)
35 |
--------------------------------------------------------------------------------
/docs/source/_templates/versions.html:
--------------------------------------------------------------------------------
1 | {%- if current_version %}
2 |
3 |
4 | Other Versions
5 | v: {{ current_version.name }}
6 |
7 |
8 |
9 | {%- if versions.tags %}
10 |
11 | Tags
12 | {%- for item in versions.tags %}
13 | {{ item.name }}
14 | {%- endfor %}
15 |
16 | {%- endif %}
17 | {%- if versions.branches %}
18 |
19 | Branches
20 | {%- for item in versions.branches %}
21 | {{ item.name }}
22 | {%- endfor %}
23 |
24 | {%- endif %}
25 |
26 |
27 | {%- endif %}
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/left_demo_2.gv:
--------------------------------------------------------------------------------
1 | digraph {
2 | label = "left_demo_2 \n(KeyError raised due to t2.x.f's non-existence)";
3 | graph [bgcolor = "#ffffff00"];
4 |
5 | subgraph cluster_t1 {
6 | label = "Tree t1";
7 | t1 [label = t1];
8 | t1_b [label = 2];
9 | t1_x [label = "t1.x"];
10 | t1_x_c [label = 3];
11 | t1_x_d [label = 4];
12 | t1_x_e [label = 5];
13 | t1_x_f [label = 6];
14 |
15 | t1 -> t1_b [label = "b"];
16 | t1 -> t1_x [label = "x"];
17 | t1_x -> t1_x_c [label = "c"];
18 | t1_x -> t1_x_d [label = "d"];
19 | t1_x -> t1_x_e [label = "e"];
20 | t1_x -> t1_x_f [label = "f"];
21 | }
22 |
23 | subgraph cluster_t2 {
24 | label = "Tree t2";
25 | t2 [label = t2];
26 | t2_a [label = 11];
27 | t2_b [label = 22];
28 | t2_x [label = "t2.x"];
29 | t2_x_c [label = 30];
30 | t2_x_d [label = 48];
31 | t2_x_e [label = 54];
32 |
33 | t2 -> t2_a [label = "a"];
34 | t2 -> t2_b [label = "b"];
35 | t2 -> t2_x [label = "x"];
36 | t2_x -> t2_x_c [label = "c"];
37 | t2_x -> t2_x_d [label = "d"];
38 | t2_x -> t2_x_e [label = "e"];
39 | }
40 | }
--------------------------------------------------------------------------------
/treevalue/tree/tree/structural.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | # subside, union, rise
5 | # NOTE(review): only subside and rise helpers are declared below; nothing for `union` appears in this file, and `rise` itself has no cpdef declaration here (only `_c_rise`).
6 | from libcpp cimport bool
7 | # `_SubsideCall` exposes a single object attribute `run`; what it holds is defined in the .pyx.
8 | cdef class _SubsideCall:
9 | cdef object run
10 |
11 | cdef object _c_subside_process(tuple value, object it)
12 | cdef tuple _c_subside_build(object value, bool dict_, bool list_, bool tuple_)
13 | cdef void _c_subside_missing()
14 | cdef object _c_subside(object value, bool dict_, bool list_, bool tuple_, bool inherit,
15 | object mode, object missing, bool delayed)
16 | cdef object _c_subside_keep_type(object t)
17 | cpdef object subside(object value, bool dict_= *, bool list_= *, bool tuple_= *,
18 | object return_type= *, bool inherit= *, object mode= *, object missing= *, bool delayed= *)  # public; each `= *` default is defined in the .pyx
19 | # rise helpers: split into tree-side and struct-side builders/processors plus the `_c_rise` driver.
20 | cdef object _c_rise_tree_builder(tuple p, object it)
21 | cdef tuple _c_rise_tree_process(object t)
22 | cdef object _c_rise_struct_builder(tuple p, object it)
23 | cdef tuple _c_rise_struct_process(list objs, object template)
24 | cdef object _c_rise_keep_type(object t)
25 | cdef object _c_rise(object tree, bool dict_, bool list_, bool tuple_, object template_)
26 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/strict_demo_2.gv:
--------------------------------------------------------------------------------
1 | digraph {
2 | label = "strict_demo_2 \n(KeyError raised due to t1.a and t2.x.f's non-existence)";
3 | graph [bgcolor = "#ffffff00"];
4 |
5 | subgraph cluster_t1 {
6 | label = "Tree t1";
7 | t1 [label = t1];
8 | t1_b [label = 2];
9 | t1_x [label = "t1.x"];
10 | t1_x_c [label = 3];
11 | t1_x_d [label = 4];
12 | t1_x_e [label = 5];
13 | t1_x_f [label = 6];
14 |
15 | t1 -> t1_b [label = "b"];
16 | t1 -> t1_x [label = "x"];
17 | t1_x -> t1_x_c [label = "c"];
18 | t1_x -> t1_x_d [label = "d"];
19 | t1_x -> t1_x_e [label = "e"];
20 | t1_x -> t1_x_f [label = "f"];
21 | }
22 |
23 | subgraph cluster_t2 {
24 | label = "Tree t2";
25 | t2 [label = t2];
26 | t2_a [label = 11];
27 | t2_b [label = 22];
28 | t2_x [label = "t2.x"];
29 | t2_x_c [label = 30];
30 | t2_x_d [label = 48];
31 | t2_x_e [label = 54];
32 |
33 | t2 -> t2_a [label = "a"];
34 | t2 -> t2_b [label = "b"];
35 | t2 -> t2_x [label = "x"];
36 | t2_x -> t2_x_c [label = "c"];
37 | t2_x -> t2_x_d [label = "d"];
38 | t2_x -> t2_x_e [label = "e"];
39 | }
40 | }
--------------------------------------------------------------------------------
/test/tree/general/test_general_benchmark.py:
--------------------------------------------------------------------------------
from functools import lru_cache
from typing import Optional

from hbutils.testing import vpip

try:
    import torch
except ImportError:
    torch = None

import pytest

from treevalue import TreeValue, func_treelize, FastTreeValue


@lru_cache()
def _get_tree() -> Optional[FastTreeValue]:
    """
    Build the tree shared by the benchmarks (cached, so it is built once).

    Returns ``None`` when torch is not installed.
    """
    if torch is None:
        return None
    return FastTreeValue({'a': torch.randn(2, 3), 'x': {'c': torch.randn(3, 4)}})


# ``unittest.skipUnless`` on a *class* only takes effect for ``unittest.TestCase``
# subclasses. On a plain pytest class it is ignored, so these benchmarks would still
# run (and fail on ``torch.sin``) without a suitable torch. pytest's own marker is honoured.
@pytest.mark.benchmark(group='treevalue_dynamic')
@pytest.mark.skipif(not (vpip('torch') >= '1.1.0'), reason='Torch>=1.1.0 only')
class TestTreeGeneralBenchmark:
    def test_dynamic_execute(self, benchmark):
        """Benchmark calling ``.sin()`` on the FastTreeValue itself."""
        def sin(t):
            return t.sin()

        return benchmark(sin, _get_tree())

    def test_static_execute(self, benchmark):
        """Benchmark ``torch.sin`` wrapped once with ``func_treelize``."""
        sinf = func_treelize(return_type=TreeValue)(torch.sin)

        def sin(t):
            return sinf(t)

        return benchmark(sin, _get_tree())
42 |
--------------------------------------------------------------------------------
/treevalue/tree/tree/functional.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | # mapping, filter_, mask, reduce_
5 | # Each public cpdef entry point below is paired with a `_c_*` helper that works on a TreeStorage and receives a `path` tuple (presumably the keys from the root -- TODO confirm).
6 | from libcpp cimport bool
7 |
8 | from .tree cimport TreeValue
9 | from ..common.storage cimport TreeStorage
10 | # `_c_no_arg` / `_c_one_arg` / `_c_two_args` presumably adapt user callbacks of different arities, and `_c_wrap_mapping_func` selects among them -- confirm in the .pyx.
11 | cdef object _c_no_arg(object func, object v, object p)
12 | cdef object _c_one_arg(object func, object v, object p)
13 | cdef object _c_two_args(object func, object v, object p)
14 | cdef object _c_wrap_mapping_func(object func)
15 | cdef object _c_delayed_mapping(object so, object func, tuple path, bool delayed)
16 | cdef TreeStorage _c_mapping(TreeStorage st, object func, tuple path, bool delayed)
17 | cpdef TreeValue mapping(TreeValue tree, object func, bool delayed= *)  # public; `= *` defaults are defined in the .pyx
18 | cdef TreeStorage _c_filter_(TreeStorage st, object func, tuple path, bool remove_empty)
19 | cpdef TreeValue filter_(TreeValue tree, object func, bool remove_empty= *)
20 | cdef object _c_mask(TreeStorage st, object sm, tuple path, bool remove_empty)
21 | cpdef TreeValue mask(TreeValue tree, object mask_, bool remove_empty= *)
22 | cdef object _c_reduce(TreeStorage st, object func, tuple path, object return_type)
23 | cpdef object reduce_(TreeValue tree, object func)
24 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/oo_demo.demo.py:
--------------------------------------------------------------------------------
from treevalue import FastTreeValue


def _grow(x):
    """Leaf transform ``(x + 1) * (x + 2)`` used by the mapping examples."""
    return (x + 1) * (x + 2)


def _shrink(x):
    """Leaf transform ``(x - 2) ** (x - 1)`` applied to ``t2`` before the union."""
    return (x - 2) ** (x - 1)


def _is_multiple_of_4(x):
    return x % 4 == 0


def _sum_kwargs(**kwargs):
    return sum(kwargs.values())


def _describe(pair):
    """Render one union leaf, i.e. a pair of values taken from the two trees."""
    return 'first: %d, second: %d, which sum is %d' % (pair[0], pair[1], sum(pair))


if __name__ == '__main__':
    t1 = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    t2 = FastTreeValue({'a': 5, 'b': 6, 'x': {'c': 7, 'd': 8}})

    # operator support: operators apply leaf by leaf
    for title, compute in (('t1 + t2:', lambda: t1 + t2), ('t2 ** t1:', lambda: t2 ** t1)):
        print(title)
        print(compute())
    print()

    # utilities support
    print('t1.map(lambda x: (x + 1) * (x + 2)):')
    print(t1.map(_grow))
    print('t1.reduce(lambda **kwargs: sum(kwargs.values())):', t1.reduce(_sum_kwargs))
    print()
    print()

    # linking usage: methods chain
    print('t1.map(lambda x: (x + 1) * (x + 2)).filter(lambda x: x % 4 == 0):')
    print(t1.map(_grow).filter(_is_multiple_of_4))
    print()

    # structural support: union pairs up the leaves of both trees
    print("Union result:")
    united = FastTreeValue.union(t1.map(_grow), t2.map(_shrink))
    print(united.map(_describe))
    print()
33 | print()
34 |
--------------------------------------------------------------------------------
/docs/source/api_doc/tree/general.rst:
--------------------------------------------------------------------------------
1 | treevalue.tree.general
2 | =============================
3 |
4 | .. _apidoc_tree_general_fasttreevalue:
5 |
6 | FastTreeValue
7 | -------------------
8 |
9 | .. autoclass:: treevalue.tree.general.FastTreeValue
10 | :members: _attr_extern, json, clone, __add__, __radd__, __sub__, __rsub__, __mul__, __rmul__, __matmul__, __rmatmul__, __truediv__, __rtruediv__, __floordiv__, __rfloordiv__, __mod__, __rmod__, __pow__, __rpow__, __and__, __rand__, __or__, __ror__, __xor__, __rxor__, __lshift__, __rlshift__, __rshift__, __rrshift__, __pos__, __neg__, __invert__, __getitem__, __setitem__, __delitem__, __call__, __getattribute__, __setattr__, __delattr__, __repr__, __iter__, __hash__, __eq__, map, type, mask, filter, __str__, reduce, rise, union, subside, __getstate__, __setstate__, __iadd__, __isub__, __imul__, __imatmul__, __ifloordiv__, __itruediv__, __ipow__, __imod__, __iand__, __ior__, __ixor__, __ilshift__, __irshift__, graph, graphics, func, keys, values, items, walk, _getitem_extern, _setitem_extern, _delitem_extern
11 |
12 |
13 | .. _apidoc_tree_general_generaltreevalue:
14 |
15 | general_tree_value
16 | ---------------------
17 |
18 | .. autofunction:: treevalue.tree.general.general_tree_value
19 |
20 |
21 |
--------------------------------------------------------------------------------
/docs/source/tutorials/plugins/index.rst:
--------------------------------------------------------------------------------
1 | Plugins
2 | ===============
3 |
4 | Potc support
5 | ---------------------
6 |
7 | `Potc <https://github.com/potc-dev/potc>`_ is a package that can convert any object into executable source code.
8 | For ``treevalue``, potc supports transforming treevalue objects into source code through
9 | an additional plugin, which can be installed with the following command:
10 |
11 | .. code:: shell
12 |
13 | pip install treevalue[potc]
14 |
15 | After this installation, you will be able to convert treevalue objects directly into source code without any additional setup.
16 | Such as
17 |
18 | .. literalinclude:: ./potc_demo.demo.py
19 | :language: python
20 | :linenos:
21 |
22 | The output should be
23 |
24 | .. literalinclude:: ./potc_demo.demo.py.txt
25 | :language: text
26 | :linenos:
27 |
28 | Also, you can use the following CLI command to get the same output results as above.
29 |
30 | .. code:: shell
31 |
32 | potc export -v 'test_simple.t' -v 'test_simple.st' -v 'test_simple.r'
33 |
34 | For further information, you can refer to
35 |
36 | * `potc-dev/potc <https://github.com/potc-dev/potc>`_
37 | * `potc-dev/potc-treevalue <https://github.com/potc-dev/potc-treevalue>`_
38 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/inherit_demo_2.gv:
--------------------------------------------------------------------------------
1 | digraph {
2 | label = "inherit_demo_2 \n(Runnable, t2 is sum of t1 and 5 due to enablement of inheriting)";
3 | graph [bgcolor = "#ffffff00"];
4 |
5 | subgraph cluster_t1 {
6 | label = "Tree t1";
7 | t1 [label = t1];
8 | t1_a [label = 11];
9 | t1_b [label = 22];
10 | t1_x [label = "t1.x"];
11 | t1_x_c [label = 30];
12 | t1_x_d [label = 48];
13 | t1_x_e [label = 54];
14 |
15 | t1 -> t1_a [label = "a"];
16 | t1 -> t1_b [label = "b"];
17 | t1 -> t1_x [label = "x"];
18 | t1_x -> t1_x_c [label = "c"];
19 | t1_x -> t1_x_d [label = "d"];
20 | t1_x -> t1_x_e [label = "e"];
21 | }
22 |
23 | subgraph cluster_t2 {
24 | label = "Tree t2";
25 | t2 [label = "t2 = t1 + 5"];
26 | t2_a [label = "16 = 11 + 5"];
27 | t2_b [label = "27 = 22 + 5"];
28 | t2_x [label = "t2.x = t1.x + 5"];
29 | t2_x_c [label = "35 = 30 + 5"];
30 | t2_x_d [label = "53 = 48 + 5"];
31 | t2_x_e [label = "59 = 54 + 5"];
32 |
33 | t2 -> t2_a [label = "a"];
34 | t2 -> t2_b [label = "b"];
35 | t2 -> t2_x [label = "x"];
36 | t2_x -> t2_x_c [label = "c"];
37 | t2_x -> t2_x_d [label = "d"];
38 | t2_x -> t2_x_e [label = "e"];
39 | }
40 | }
--------------------------------------------------------------------------------
/docs/source/best_practice/numpy/without_treevalue.py:
--------------------------------------------------------------------------------
import numpy as np

T, B = 3, 4


def without_treevalue(batch_):
    """
    Plain-Python counterpart of the treevalue-based pipeline in this demo.

    Each sample of ``batch_`` is a dict with keys ``'a'``, ``'b'`` and ``'c'``,
    where ``'c'`` is itself a dict holding ``'d'``. Unknown keys are reported
    with a ``print`` and skipped.

    Steps:
        * cast ``a`` to float32 and keep its even-indexed rows (``a[::2]``);
        * cast ``b`` to float32, map it to ``b ** 2 + 1`` and keep its mean;
        * attach a random ``noise`` array of shape ``(3, 4, 5)`` under ``c``.

    :param batch_: List of sample dicts; it is mutated in place (``noise`` is added).
    :return: ``(batch_, mean_b, even_index_a)``: the mutated batch, the average of the
        per-sample means of the transformed ``b``, and the even-indexed rows of every
        ``a`` stacked along a new leading axis.
    """
    mean_b_list = []
    even_index_a_list = []
    for sample in batch_:
        for k, v in sample.items():
            if k == 'a':
                v = v.astype(np.float32)
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.astype(np.float32)
                transformed_v = np.power(v, 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        # NOTE(review): as with 'a' and 'b', the cast is not written back,
                        # so the returned batch keeps the original dtype of c.d.
                        v1 = v1.astype(np.float32)
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))

    # 'd' only exists nested under 'c', never as a top-level key, so the former
    # ``k == 'd'`` test never matched and no noise was ever attached.
    for sample in batch_:
        if 'c' in sample:
            sample['c']['noise'] = np.random.random(size=(3, 4, 5))

    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = np.stack(even_index_a_list, axis=0)
    return batch_, mean_b, even_index_a
34 |
--------------------------------------------------------------------------------
/treevalue/tree/integration/cjax.pyx:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 | # jax integration: registers TreeValue classes as jax pytree nodes, reusing the shared integration flatten/unflatten helpers.
4 | import cython
5 |
6 | from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration
7 | from ..tree.tree cimport TreeValue
8 | # Thin adapters from jax's callback signatures to the shared integration helpers.
9 | cdef inline tuple _c_flatten_for_jax(object tv):
10 | return _c_flatten_for_integration(tv)
11 |
12 | cdef inline object _c_unflatten_for_jax(tuple aux, tuple values):  # jax calls unflatten(aux_data, children)
13 | return _c_unflatten_for_integration(values, aux)  # the shared helper takes (values, spec), so the arguments are swapped here on purpose
14 |
15 | @cython.binding(True)  # expose a regular, introspectable Python function object (name/docstring/signature)
16 | cpdef void register_for_jax(object cls) except*:
17 | """
18 | Overview:
19 | Register treevalue class for jax.
20 |
21 | :param cls: TreeValue class (a subclass of ``TreeValue``; anything else raises ``TypeError``).
22 |
23 | Examples::
24 | >>> from treevalue import FastTreeValue, TreeValue, register_for_jax
25 | >>> register_for_jax(TreeValue)
26 | >>> register_for_jax(FastTreeValue)
27 |
28 | .. warning::
29 | This method will emit a warning and then do nothing when jax is not installed.
30 | """
31 | if isinstance(cls, type) and issubclass(cls, TreeValue):
32 | import jax  # imported lazily so this module loads without jax; here a missing jax raises ImportError -- the warn-and-skip path in the docstring presumably lives in a Python-level fallback wrapper (as integration/torch.py does for torch); confirm
33 | jax.tree_util.register_pytree_node(cls, _c_flatten_for_jax, _c_unflatten_for_jax)
34 | else:
35 | raise TypeError(f'Registered class should be a subclass of TreeValue, but {cls!r} found.')
36 |
--------------------------------------------------------------------------------
/docs/source/api_doc/tree/graphics_dup_value.demo.py:
--------------------------------------------------------------------------------
import numpy as np

from treevalue import FastTreeValue, graphics


class MyFastTreeValue(FastTreeValue):
    """Plain ``FastTreeValue`` subclass; every tree in this demo is built from it."""


if __name__ == '__main__':
    t = MyFastTreeValue({
        'a': [4, 3, 2, 1],
        'b': np.array([[5, 6], [7, 8]]),
        'x': {
            'c': np.array([[5, 7], [8, 6]]),
            'd': {'a', 'b', 'c'},
            'e': np.array([[1, 2], [3, 4]]),
        },
    })

    # ``t1`` re-uses three leaf objects of ``t`` (``t.a``, ``t.x.c``, ``t.x.d``);
    # ``bb`` and ``ee`` are new arrays that merely hold equal values.
    t1 = MyFastTreeValue({
        'aa': t.a,
        'bb': np.array([[5, 6], [7, 8]]),
        'xx': {
            'cc': t.x.c,
            'dd': t.x.d,
            'ee': np.array([[1, 2], [3, 4]]),
        },
    })

    # ``t2`` nests ``t``, ``t1`` and the sub-tree ``t1.xx``.
    t2 = MyFastTreeValue({'a': t, 'b': t1, 'c': [1, 2], 'd': t1.xx})

    g = graphics(
        (t, 't'), (t1, 't1'), (t2, 't2'),
        # ``dup_value`` lists the value types that may share one value node:
        # np.ndarray and list values do, while the set stored under ``x.d``
        # is not one of these types, so it never shares a node.
        dup_value=(np.ndarray, list),
        title="This is a demo of 2 trees with dup value.",
        cfg={'bgcolor': '#ffffff00'},
    )
    g.render('graphics_dup_value.dat.gv', format='svg')
41 |
--------------------------------------------------------------------------------
/treevalue/tree/integration/ctorch.pyx:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 | # torch integration: registers TreeValue classes with torch's pytree utilities, reusing the shared integration flatten/unflatten helpers.
4 | import cython
5 |
6 | from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration
7 | from ..tree.tree cimport TreeValue
8 | # Thin adapters from torch's callback signatures to the shared integration helpers.
9 | cdef inline tuple _c_flatten_for_torch(object tv):
10 | return _c_flatten_for_integration(tv)
11 |
12 | cdef inline object _c_unflatten_for_torch(list values, tuple context):  # torch calls unflatten(values, context), the reverse of jax's order; `values` typed as list assumes torch passes a list -- confirm
13 | return _c_unflatten_for_integration(values, context)
14 |
15 | @cython.binding(True)  # expose a regular, introspectable Python function object (integration/torch.py copies its metadata with @wraps)
16 | cpdef void register_for_torch(object cls) except*:
17 | """
18 | Overview:
19 | Register treevalue class for torch's pytree library.
20 |
21 | :param cls: TreeValue class (a subclass of ``TreeValue``; anything else raises ``TypeError``).
22 |
23 | Examples::
24 | >>> from treevalue import FastTreeValue, TreeValue, register_for_torch
25 | >>> register_for_torch(TreeValue)
26 | >>> register_for_torch(FastTreeValue)
27 |
28 | .. warning::
29 | This method will emit a warning and then do nothing when torch is not installed.
30 | """
31 | if isinstance(cls, type) and issubclass(cls, TreeValue):
32 | import torch  # imported lazily so this module loads without torch; the warn-and-skip fallback lives in integration/torch.py
33 | torch.utils._pytree._register_pytree_node(cls, _c_flatten_for_torch, _c_unflatten_for_torch)  # private torch API (leading underscore); it may be deprecated or renamed in newer torch releases -- check when upgrading
34 | else:
35 | raise TypeError(f'Registered class should be a subclass of TreeValue, but {cls!r} found.')
36 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/rise_demo_2.demo.py:
--------------------------------------------------------------------------------
1 | from treevalue import TreeValue, subside, rise
2 |
if __name__ == '__main__':
    # The same demo as the subside docs
    t11 = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    t12 = TreeValue({'a': 3, 'b': 9, 'x': {'c': 15, 'd': 41}})
    t21 = TreeValue({'a': -14, 'b': 9, 'x': {'c': 3, 'd': 8}})
    t22 = TreeValue({'a': -31, 'b': 82, 'x': {'c': 47, 'd': 32}})
    t3 = TreeValue({'a': 6, 'b': 0, 'x': {'c': -5, 'd': 17}})
    t4 = TreeValue({'a': 0, 'b': -17, 'x': {'c': -8, 'd': 15}})
    t5 = TreeValue({'a': 3, 'b': 9, 'x': {'c': 11, 'd': -17}})
    st = {'first': [(t11, t21), (t12, t22)], 'second': (t3, {'x': t4, 'y': t5})}
    tx = subside(st)

    # Rising process, with a template.
    # Only the containers described by `template` are extracted: the top-level
    # dict, the list under 'first' and the tuple under 'second'.
    # Their items are given as `object`, so the inner tuples (e.g. (t11, t21))
    # and the inner dict ({'x': t4, 'y': t5}) are NOT extracted, because their
    # structure is not defined in the `template` argument.
    st2 = rise(tx, template={'first': [object, ...], 'second': (object, ...)})
    print('st2:', st2)

    print("st2['first'][0]:")
    print(st2['first'][0])

    print("st2['first'][1]:")
    print(st2['first'][1])

    print("st2['second'][0]:")
    print(st2['second'][0])

    print("st2['second'][1]:")
    print(st2['second'][1])
33 |
--------------------------------------------------------------------------------
/docs/source/api_doc/tree/integration.rst:
--------------------------------------------------------------------------------
1 | treevalue.tree.integration
2 | =======================================
3 |
4 | .. py:currentmodule:: treevalue.tree.integration
5 |
6 |
7 | .. _apidoc_tree_integration_register_for_jax:
8 |
9 | register_for_jax
10 | ---------------------------
11 |
12 | .. autofunction:: register_for_jax
13 |
14 |
15 | .. _apidoc_tree_integration_register_for_torch:
16 |
17 | register_for_torch
18 | ---------------------------
19 |
20 | .. autofunction:: register_for_torch
21 |
22 |
23 | .. _apidoc_tree_integration_register_treevalue_class:
24 |
25 | register_treevalue_class
26 | ---------------------------
27 |
28 | .. autofunction:: register_treevalue_class
29 |
30 |
31 | .. _apidoc_tree_integration_register_integrate_container:
32 |
33 | register_integrate_container
34 | --------------------------------
35 |
36 | .. autofunction:: register_integrate_container
37 |
38 |
39 | .. _apidoc_tree_integration_generic_flatten:
40 |
41 | generic_flatten
42 | --------------------------------
43 |
44 | .. autofunction:: generic_flatten
45 |
46 |
47 | .. _apidoc_tree_integration_generic_unflatten:
48 |
49 | generic_unflatten
50 | --------------------------------
51 |
52 | .. autofunction:: generic_unflatten
53 |
54 |
55 | .. _apidoc_tree_integration_generic_mapping:
56 |
57 | generic_mapping
58 | --------------------------------
59 |
60 | .. autofunction:: generic_mapping
61 |
62 |
63 |
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/edit_tree_1.gv:
--------------------------------------------------------------------------------
1 | digraph edit_tree {
2 | graph [bgcolor = "#ffffff00"];
3 |
4 |
5 | subgraph cluster_original {
6 | label = "(0) Original tree 't'"
7 | root0 [label = "t"];
8 | n01 [label = "1"];
9 | n02 [label = "2"];
10 | n03 [label = "t.x"];
11 | n04 [label = "3"];
12 | n05 [label = "4"];
13 | root0 -> n01 [label = "a"];
14 | root0 -> n02 [label = "b"];
15 | root0 -> n03 [label = "x"];
16 | n03 -> n04 [label = "c"];
17 | n03 -> n05 [label = "d"];
18 | }
19 |
20 | subgraph cluster_step1 {
21 | label = "(1) After t.a = 233"
22 | root1 [label = "t"];
23 | n11 [label = "233"];
24 | n12 [label = "2"];
25 | n13 [label = "t.x"];
26 | n14 [label = "3"];
27 | n15 [label = "4"];
28 | root1 -> n11 [label = "a"];
29 | root1 -> n12 [label = "b"];
30 | root1 -> n13 [label = "x"];
31 | n13 -> n14 [label = "c"];
32 | n13 -> n15 [label = "d"];
33 | }
34 |
35 | subgraph cluster_step2 {
36 | label = "(2) After t.x.d = -1"
37 | root2 [label = "t"];
38 | n21 [label = "233"];
39 | n22 [label = "2"];
40 | n23 [label = "t.x"];
41 | n24 [label = "3"];
42 | n25 [label = "-1"];
43 | root2 -> n21 [label = "a"];
44 | root2 -> n22 [label = "b"];
45 | root2 -> n23 [label = "x"];
46 | n23 -> n24 [label = "c"];
47 | n23 -> n25 [label = "d"];
48 | }
49 | }
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to TreeValue's Documentation
2 | =====================================
3 |
4 | Overview
5 | -------------
6 |
7 | ``TreeValue`` is a generalized tree-based data structure.
8 | Almost all operations can be supported \
9 | in the form of trees in a convenient way, simplifying \
10 | structure processing when the calculation is tree-based.
11 |
12 | .. toctree::
13 | :maxdepth: 2
14 | :caption: Tutorials
15 |
16 | tutorials/installation/index
17 | tutorials/quick_start/index
18 | tutorials/main_idea/index
19 | tutorials/basic_usage/index
20 | tutorials/advanced_usage/index
21 | tutorials/cli_usage/index
22 | tutorials/plugins/index
23 |
24 |
25 | .. toctree::
26 | :maxdepth: 2
27 | :caption: Best Practice
28 |
29 | best_practice/numpy/index
30 | best_practice/sklearn/index
31 |
32 |
33 | .. toctree::
34 | :maxdepth: 2
35 | :caption: Comparison
36 |
37 | comparison/generic
38 | comparison/environment.result
39 | comparison/dmtree.result
40 | comparison/tianshou_batch.result
41 | comparison/jax_libtree.result
42 |
43 |
44 | .. toctree::
45 | :maxdepth: 2
46 | :caption: API Documentation
47 |
48 | api_doc/config/index
49 | api_doc/tree/index
50 | api_doc/utils/index
51 |
52 | .. toctree::
53 | :maxdepth: 2
54 | :caption: Contributor Guide
55 |
56 | contribute/architecture/index
57 |
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/calculation_add.gv:
--------------------------------------------------------------------------------
1 | digraph calculation_add {
2 | graph [bgcolor = "#ffffff00"];
3 |
4 | subgraph cluster_t1 {
5 | label = "Tree 't1'"
6 | root0 [label = "t1"];
7 | n01 [label = "1"];
8 | n02 [label = "2"];
9 | n03 [label = "t1.x"];
10 | n04 [label = "3"];
11 | n05 [label = "4"];
12 | root0 -> n01 [label = "a"];
13 | root0 -> n02 [label = "b"];
14 | root0 -> n03 [label = "x"];
15 | n03 -> n04 [label = "c"];
16 | n03 -> n05 [label = "d"];
17 | }
18 |
19 | subgraph cluster_t2 {
20 | label = "Tree 't2'"
21 | root1 [label = "t2"];
22 | n11 [label = "3"];
23 | n12 [label = "7"];
24 | n13 [label = "t2.x"];
25 | n14 [label = "14"];
26 | n15 [label = "-5"];
27 | root1 -> n11 [label = "a"];
28 | root1 -> n12 [label = "b"];
29 | root1 -> n13 [label = "x"];
30 | n13 -> n14 [label = "c"];
31 | n13 -> n15 [label = "d"];
32 | }
33 |
34 | subgraph cluster_step2 {
35 | label = "Result of t1 + t2"
36 | root2 [label = "t1 + t2"];
37 | n21 [label = "4 = 1 + 3"];
38 | n22 [label = "9 = 2 + 7"];
39 | n23 [label = "t1.x + t2.x"];
40 | n24 [label = "17 = 3 + 14"];
41 | n25 [label = "-1 = 4 + (-5)"];
42 | root2 -> n21 [label = "a"];
43 | root2 -> n22 [label = "b"];
44 | root2 -> n23 [label = "x"];
45 | n23 -> n24 [label = "c"];
46 | n23 -> n25 [label = "d"];
47 | }
48 | }
--------------------------------------------------------------------------------
/docs/source/tutorials/quick_start/index.rst:
--------------------------------------------------------------------------------
1 | Quick Start
2 | ==================
3 |
4 | Create a Simplest Tree
5 | -------------------------------------
6 |
7 | You can create the simplest tree like this
8 |
9 | .. literalinclude:: create_a_tree.demo.py
10 | :language: python
11 | :linenos:
12 |
13 | A tree value object with the following structure will be \
14 | created
15 |
16 | .. image:: simple_demo.dat.svg
17 | :align: center
18 |
19 | The output of the code above should be
20 |
21 | .. literalinclude:: create_a_tree.demo.py.txt
22 | :language: text
23 | :linenos:
24 |
25 |
26 | Create a Slightly Complex Tree
27 | -------------------------------------
28 |
29 | You can easily create a tree value object which is \
30 | slightly complex, based on ``FastTreeValue``.
31 |
32 | .. literalinclude:: create_a_complex_tree.demo.py
33 | :language: python
34 | :linenos:
35 |
36 | The result should be
37 |
38 | .. literalinclude:: create_a_complex_tree.demo.py.txt
39 | :language: text
40 | :linenos:
41 |
42 | Three simple tree value structures are created. \
43 | Save the code above to ``demo.py``, and then run \
44 | this shell command in your terminal.
45 |
46 | .. literalinclude:: display_complex_tree.demo.sh
47 | :language: shell
48 | :linenos:
49 |
50 | A graph named ``demo.dat.svg`` will be generated, like this.
51 |
52 | .. image:: complex_demo.dat.svg
53 | :align: center
54 |
55 | You have now successfully got started.
56 |
--------------------------------------------------------------------------------
/docs/source/tutorials/basic_usage/edit_tree_2.gv:
--------------------------------------------------------------------------------
1 | digraph edit_tree {
2 | graph [bgcolor = "#ffffff00"];
3 |
4 | subgraph cluster_step3 {
5 | label = "(3) After t.x = {'e': 5, 'f': 6}"
6 | root3 [label = "t"];
7 | n31 [label = "233"];
8 | n32 [label = "2"];
9 | n33 [label = "t.x"];
10 | n34 [label = "5"];
11 | n35 [label = "6"];
12 | root3 -> n31 [label = "a"];
13 | root3 -> n32 [label = "b"];
14 | root3 -> n33 [label = "x"];
15 | n33 -> n34 [label = "e"];
16 | n33 -> n35 [label = "f"];
17 | }
18 |
19 | subgraph cluster_step4 {
20 | label = "(4) After t.x.g = raw({'e': 5, 'f': 6})"
21 | root4 [label = "t"];
22 | n41 [label = "233"];
23 | n42 [label = "2"];
24 | n43 [label = "t.x"];
25 | n44 [label = "5"];
26 | n45 [label = "6"];
27 | n46 [label = "{'e': 5, 'f': 6}"];
28 | root4 -> n41 [label = "a"];
29 | root4 -> n42 [label = "b"];
30 | root4 -> n43 [label = "x"];
31 | n43 -> n44 [label = "e"];
32 | n43 -> n45 [label = "f"];
33 | n43 -> n46 [label = "g"];
34 | }
35 |
36 | subgraph cluster_step5 {
37 | label = "(5) After del t.x.g"
38 | root5 [label = "t"];
39 | n51 [label = "233"];
40 | n52 [label = "2"];
41 | n53 [label = "t.x"];
42 | n54 [label = "5"];
43 | n55 [label = "6"];
44 | root5 -> n51 [label = "a"];
45 | root5 -> n52 [label = "b"];
46 | root5 -> n53 [label = "x"];
47 | n53 -> n54 [label = "e"];
48 | n53 -> n55 [label = "f"];
49 | }
50 |
51 | }
--------------------------------------------------------------------------------
/treevalue/tree/func/modes.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | from libcpp cimport bool
5 |
# Declarations for the tree-function mode helpers; by Cython's .pxd/.pyx
# pairing, the implementations live in modes.pyx (not shown here).

# The four modes a treelized function can run in.
ctypedef enum _e_tree_mode:
    STRICT
    INNER
    OUTER
    LEFT

# Converts a mode name string into an _e_tree_mode (presumably rejecting unknown
# names -- confirm in modes.pyx); ``except *`` lets any raised Python exception
# propagate to the caller.
cdef _e_tree_mode _c_load_mode(str mode) except *
# Validation of the treelize options shared by all modes (assumed from the name;
# verify in modes.pyx).
cdef void _c_base_check(_e_tree_mode mode, object return_type,
                        bool inherit, bool allow_missing, object missing_func) except *

# Per-mode helper pairs follow: ``_c_<mode>_keyset`` builds a set from the
# positional (args) and keyword (kwargs) arguments -- presumably the keys to
# process in that mode -- and ``_c_<mode>_check`` validates the options for
# that mode, propagating exceptions via ``except *``.
cdef set _c_strict_keyset(list args, dict kwargs)
cdef void _c_strict_check(_e_tree_mode mode, object return_type,
                          bool inherit, bool allow_missing, object missing_func) except *

cdef set _c_inner_keyset(list args, dict kwargs)
cdef void _c_inner_check(_e_tree_mode mode, object return_type,
                         bool inherit, bool allow_missing, object missing_func) except *

cdef set _c_outer_keyset(list args, dict kwargs)
cdef void _c_outer_check(_e_tree_mode mode, object return_type,
                         bool inherit, bool allow_missing, object missing_func) except *

cdef set _c_left_keyset(list args, dict kwargs)
cdef void _c_left_check(_e_tree_mode mode, object return_type,
                        bool inherit, bool allow_missing, object missing_func) except *

# Entry points taking the mode as a parameter (presumably dispatching to the
# per-mode helpers above).
cdef set _c_keyset(_e_tree_mode mode, list args, dict kwargs)
cdef void _c_check(_e_tree_mode mode, object return_type,
                   bool inherit, bool allow_missing, object missing_func) except *
35 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/diy_class_x_demo_3.demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from treevalue import general_tree_value
4 |
5 |
class DemoTreeValue(general_tree_value(methods={
    # "-" in every form is switched off; NotImplemented makes Python raise TypeError
    '__sub__': NotImplemented,
    '__rsub__': NotImplemented,
    '__isub__': NotImplemented,

    # "/" in every form raises the given ArithmeticError
    '__truediv__': ArithmeticError('True div is not supported'),
    '__rtruediv__': ArithmeticError('True div is not supported'),
    '__itruediv__': ArithmeticError('True div is not supported'),

    # unary "+" is redefined so that +t behaves like t * 2
    '__pos__': lambda x: x * 2,
})):
    """
    Tree value class built by ``general_tree_value`` with one operator family
    disabled, one made to raise, and one redefined.
    """
21 |
22 |
if __name__ == '__main__':
    t1 = DemoTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
    t2 = DemoTreeValue({'a': 11, 'b': 24, 'x': {'c': 30, 'd': 47}})

    # __add__ can be used normally
    print('t1 + t2:', t1 + t2, sep=os.linesep)

    # __sub__ will cause TypeError due to the NotImplemented
    try:
        _ = t1 - t2
    except TypeError as err:
        print('t1 - t2:', err, sep=os.linesep)
    else:
        # Raise explicitly: a bare ``assert`` would be stripped under ``python -O``.
        raise AssertionError('Should not reach here!')

    # __truediv__ will cause ArithmeticError
    try:
        _ = t1 / t2
    except ArithmeticError as err:
        print('t1 / t2:', err, sep=os.linesep)
    else:
        raise AssertionError('Should not reach here!')

    # __pos__ will be like t1 * 2
    print('+t1:', +t1, sep=os.linesep)
48 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/inherit_demo_1.gv:
--------------------------------------------------------------------------------
1 | digraph {
2 | label = "inherit_demo_1 \n(Runnable, t3 is sum of t1 and t2 due to enablement of inheriting)";
3 | graph [bgcolor = "#ffffff00"];
4 |
5 | subgraph cluster_t1 {
6 | label = "Tree t1";
7 | t1 [label = t1];
8 | t1_a [label = 1];
9 | t1_b [label = 2];
10 | t1_x [label = "9"];
11 |
12 | t1 -> t1_a [label = "a"];
13 | t1 -> t1_b [label = "b"];
14 | t1 -> t1_x [label = "x"];
15 | }
16 |
17 | subgraph cluster_t2 {
18 | label = "Tree t2";
19 | t2 [label = t2];
20 | t2_a [label = 11];
21 | t2_b [label = 22];
22 | t2_x [label = "t2.x"];
23 | t2_x_c [label = 30];
24 | t2_x_d [label = 48];
25 | t2_x_e [label = 54];
26 |
27 | t2 -> t2_a [label = "a"];
28 | t2 -> t2_b [label = "b"];
29 | t2 -> t2_x [label = "x"];
30 | t2_x -> t2_x_c [label = "c"];
31 | t2_x -> t2_x_d [label = "d"];
32 | t2_x -> t2_x_e [label = "e"];
33 | }
34 |
35 | subgraph cluster_t3 {
36 | label = "Tree t3";
37 | t3 [label = "t3 = t1 + t2"];
38 | t3_a [label = "12 = 1 + 11"];
39 | t3_b [label = "24 = 2 + 22"];
40 | t3_x [label = "t3.x = t1.x + t2.x"];
41 | t3_x_c [label = "39 = 9 + 30"];
42 | t3_x_d [label = "57 = 9 + 48"];
43 | t3_x_e [label = "63 = 9 + 54"];
44 |
45 | t3 -> t3_a [label = "a"];
46 | t3 -> t3_b [label = "b"];
47 | t3 -> t3_x [label = "x"];
48 | t3_x -> t3_x_c [label = "c"];
49 | t3_x -> t3_x_d [label = "d"];
50 | t3_x -> t3_x_e [label = "e"];
51 | }
52 | }
--------------------------------------------------------------------------------
/docs/source/api_doc/utils/tree.rst:
--------------------------------------------------------------------------------
1 | treevalue.utils.tree
2 | ==========================
3 |
4 | .. py:currentmodule:: treevalue.utils.tree
5 |
6 | .. automodule:: treevalue.utils.tree
7 |
8 | build_graph
9 | -------------------
10 |
11 | .. autofunction:: build_graph
12 |
13 | Here is an example of ``build_graph`` function. The source code is
14 |
15 | .. literalinclude:: build_graph.demo.py
16 | :language: python
17 | :linenos:
18 |
19 | The generated graphviz source code should be
20 |
21 | .. literalinclude:: build_graph_demo.dat.gv
22 | :language: text
23 | :linenos:
24 |
25 | The graph should be
26 |
27 | .. image:: build_graph_demo.dat.gv.svg
28 | :align: center
29 |
30 | Graphs with multiple roots are also supported; this function detects
31 | objects by their pointers, as in the more complex source code below.
32 |
33 | .. literalinclude:: build_graph_complex.demo.py
34 | :language: python
35 | :linenos:
36 |
37 | The exported graph should be
38 |
39 | .. image:: build_graph_complex_demo.dat.gv.svg
40 | :align: center
41 |
42 | The return value of function ``build_graph`` is an instance of \
43 | class ``graphviz.dot.Digraph`` from the open-source \
44 | library ``graphviz``. For further information about \
45 | this project and the usage of ``graphviz.dot.Digraph``, \
46 | take a look at:
47 |
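48 | * `Official site of Graphviz <https://graphviz.org/>`_.
49 | * `User Guide of Graphviz <https://graphviz.readthedocs.io/en/stable/manual.html>`_.
50 | * `API Reference of Graphviz <https://graphviz.readthedocs.io/en/stable/api.html>`_.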
51 |
--------------------------------------------------------------------------------
/.github/workflows/badge.yml:
--------------------------------------------------------------------------------
1 | name: Badge Creation
2 |
3 | on:
4 | push:
5 | branches: [ main, 'badge/*', 'doc/*' ]
6 |
7 | jobs:
8 | update-badges:
9 | name: Update Badges
10 | runs-on: ubuntu-latest
11 | strategy:
12 | matrix:
13 | python-version:
14 | - '3.8'
15 |
16 | steps:
17 | - uses: actions/checkout@v3
18 | - name: Set up Python ${{ matrix.python-version }}
19 | uses: actions/setup-python@v4
20 | with:
21 | python-version: ${{ matrix.python-version }}
22 | - name: Download cloc
23 | run: |
24 | sudo apt-get update -y
25 | sudo apt-get install -y cloc
26 | - name: Get the Numbers
27 | run: |
28 | cloc .
29 | echo "CODE_LINES=$(./cloc.sh --loc)" >> $GITHUB_ENV
30 | echo "COMMENT_LINES=$(./cloc.sh --percentage)%" >> $GITHUB_ENV
31 | - name: Create Lines-of-Code-Badge
32 | uses: schneegans/dynamic-badges-action@v1.0.0
33 | with:
34 | auth: ${{ secrets.GIST_SECRET }}
35 | gistID: ${{ secrets.BADGE_GIST_ID }}
36 | filename: loc.json
37 | label: Lines of Code
38 | message: ${{ env.CODE_LINES }}
39 | color: lightgrey
40 | - name: Create Comments-Badge
41 | uses: schneegans/dynamic-badges-action@v1.0.0
42 | with:
43 | auth: ${{ secrets.GIST_SECRET }}
44 | gistID: ${{ secrets.BADGE_GIST_ID }}
45 | filename: comments.json
46 | label: Comments
47 | message: ${{ env.COMMENT_LINES }}
48 | color: green
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/left_demo_1.gv:
--------------------------------------------------------------------------------
1 | digraph {
2 | label = "left_demo_1 \n(Runnable, t3 is sum of t1 and t2, t2.a is ignored.)";
3 | graph [bgcolor = "#ffffff00"];
4 |
5 | subgraph cluster_t1 {
6 | label = "Tree t1";
7 | t1 [label = t1];
8 | t1_b [label = 2];
9 | t1_x [label = "t1.x"];
10 | t1_x_c [label = 3];
11 | t1_x_d [label = 4];
12 | t1_x_e [label = 5];
13 |
14 | t1 -> t1_b [label = "b"];
15 | t1 -> t1_x [label = "x"];
16 | t1_x -> t1_x_c [label = "c"];
17 | t1_x -> t1_x_d [label = "d"];
18 | t1_x -> t1_x_e [label = "e"];
19 | }
20 |
21 | subgraph cluster_t2 {
22 | label = "Tree t2";
23 | t2 [label = t2];
24 | t2_a [label = 11];
25 | t2_b [label = 22];
26 | t2_x [label = "t2.x"];
27 | t2_x_c [label = 30];
28 | t2_x_d [label = 48];
29 | t2_x_e [label = 54];
30 |
31 | t2 -> t2_a [label = "a"];
32 | t2 -> t2_b [label = "b"];
33 | t2 -> t2_x [label = "x"];
34 | t2_x -> t2_x_c [label = "c"];
35 | t2_x -> t2_x_d [label = "d"];
36 | t2_x -> t2_x_e [label = "e"];
37 | }
38 |
39 | subgraph cluster_t3 {
40 | label = "Tree t3";
41 | t3 [label = "t3 = t1 + t2"];
42 | t3_b [label = "24 = 2 + 22"];
43 | t3_x [label = "t3.x = t1.x + t2.x"];
44 | t3_x_c [label = "33 = 3 + 30"];
45 | t3_x_d [label = "52 = 4 + 48"];
46 | t3_x_e [label = "59 = 5 + 54"];
47 |
48 | t3 -> t3_b [label = "b"];
49 | t3 -> t3_x [label = "x"];
50 | t3_x -> t3_x_c [label = "c"];
51 | t3_x -> t3_x_d [label = "d"];
52 | t3_x -> t3_x_e [label = "e"];
53 | }
54 | }
--------------------------------------------------------------------------------
/test/tree/func/test_inner.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from treevalue.tree import func_treelize, TreeValue
4 |
5 |
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestTreeFuncInner:
    """
    Tests for ``func_treelize`` in ``'inner'`` mode, where only the keys
    shared by every tree argument appear in the result.
    """

    def test_inner_raw(self):
        @func_treelize('inner', inherit=False)
        def ssum(*args):
            return sum(args)

        left = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
        right = TreeValue({'a': 11, 'b': 22, 'x': {'c': 33, 'd': 44}})
        # Extra keys ('c' at the root, 'e' under 'x') exist only in this tree.
        wide = TreeValue({'a': 11, 'b': 22, 'c': 33, 'x': {'c': 33, 'd': 44, 'e': 550}})

        # Non-tree arguments go straight to the wrapped function.
        assert ssum(1, 2, 3) == 6

        assert ssum(left, right) == TreeValue({'a': 12, 'b': 24, 'x': {'c': 36, 'd': 48}})
        assert ssum(left.x, right.x) == TreeValue({'c': 36, 'd': 48})

        # With inherit=False a plain value cannot be mixed with trees.
        with pytest.raises(TypeError):
            _ = ssum(left, right, 3)

        # Keys missing from any argument are dropped, whatever the argument order.
        triple_sum = TreeValue({'a': 23, 'b': 46, 'x': {'c': 69, 'd': 92}})
        assert ssum(left, right, wide) == triple_sum
        assert ssum(right, wide) == TreeValue({'a': 22, 'b': 44, 'x': {'c': 66, 'd': 88}})
        assert ssum(wide, left, right) == triple_sum

    def test_inner_inherit(self):
        @func_treelize('inner')
        def ssum(*args):
            return sum(args)

        left = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
        right = TreeValue({'a': 11, 'b': 22, 'x': {'c': 33, 'd': 44}})

        assert ssum(1, 2, 3) == 6
        assert ssum(left, right) == TreeValue({'a': 12, 'b': 24, 'x': {'c': 36, 'd': 48}})
        # inherit is not disabled here, so the plain 3 is added to every leaf.
        assert ssum(left, right, 3) == TreeValue({'a': 15, 'b': 27, 'x': {'c': 39, 'd': 51}})
37 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=42",
4 | "wheel",
5 | # "cython>=0.29; platform_system != 'Windows'",
6 | # "cython>=0.29,<3; platform_system == 'Windows'",
7 | "cython>=3",
8 | ]
9 |
10 | [tool.cibuildwheel]
11 | skip = ["pp*"] # Do not build for PyPy
12 |
13 | ## Windows build configuration
14 | [tool.cibuildwheel.windows]
15 | archs = ["x86", 'AMD64']
16 | before-test = [# Unittest for windows
17 | "pip install -r \"{project}\\requirements-test.txt\"",
18 | ]
19 | test-command = [
20 | "xcopy /e /i \"{project}\\test\" test",
21 | "copy \"{project}\\pytest.ini\" pytest.ini",
22 | "pytest test -sv -m unittest --log-level=DEBUG",
23 | "rmdir /s /q test",
24 | ]
25 |
26 | ## macOS build configuration
27 | [tool.cibuildwheel.macos]
28 | archs = ["x86_64", "arm64"] # Build for x86_64 and arm64
29 | before-test = [# Unittest for macos
30 | "pip install -r {project}/requirements-test.txt",
31 | ]
32 | test-command = [
33 | "cp -rf {project}/test test",
34 | "cp {project}/pytest.ini pytest.ini",
35 | "pytest test -sv -m unittest --log-level=DEBUG",
36 | "rm -rf test",
37 | ]
38 |
39 |
40 | ## Linux build configuration
41 | [tool.cibuildwheel.linux]
42 | archs = ["x86_64", "aarch64"] # Build for x86_64 and arm64
43 | skip = ["pp* *musllinux*"] # dependencies do not build for musl
44 | before-test = [# Unittest for linux
45 | "pip install -r {project}/requirements-test.txt",
46 | ]
47 | test-command = [
48 | "cp -rf {project}/test test",
49 | "cp {project}/pytest.ini pytest.ini",
50 | "pytest test -sv -m unittest --log-level=DEBUG",
51 | "rm -rf test",
52 | ]
53 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/inner_demo_1.gv:
--------------------------------------------------------------------------------
1 | digraph {
2 | label = "inner_demo_1 \n(Runnable, t3 is sum of t1 and t2, t2.a and t1.x.f is ignored.)";
3 | graph [bgcolor = "#ffffff00"];
4 |
5 | subgraph cluster_t1 {
6 | label = "Tree t1";
7 | t1 [label = t1];
8 | t1_b [label = 2];
9 | t1_x [label = "t1.x"];
10 | t1_x_c [label = 3];
11 | t1_x_d [label = 4];
12 | t1_x_e [label = 5];
13 | t1_x_f [label = 6];
14 |
15 | t1 -> t1_b [label = "b"];
16 | t1 -> t1_x [label = "x"];
17 | t1_x -> t1_x_c [label = "c"];
18 | t1_x -> t1_x_d [label = "d"];
19 | t1_x -> t1_x_e [label = "e"];
20 | t1_x -> t1_x_f [label = "f"];
21 | }
22 |
23 | subgraph cluster_t2 {
24 | label = "Tree t2";
25 | t2 [label = t2];
26 | t2_a [label = 11];
27 | t2_b [label = 22];
28 | t2_x [label = "t2.x"];
29 | t2_x_c [label = 30];
30 | t2_x_d [label = 48];
31 | t2_x_e [label = 54];
32 |
33 | t2 -> t2_a [label = "a"];
34 | t2 -> t2_b [label = "b"];
35 | t2 -> t2_x [label = "x"];
36 | t2_x -> t2_x_c [label = "c"];
37 | t2_x -> t2_x_d [label = "d"];
38 | t2_x -> t2_x_e [label = "e"];
39 | }
40 |
41 | subgraph cluster_t3 {
42 | label = "Tree t3";
43 | t3 [label = "t3 = t1 + t2"];
44 | t3_b [label = "24 = 2 + 22"];
45 | t3_x [label = "t3.x = t1.x + t2.x"];
46 | t3_x_c [label = "33 = 3 + 30"];
47 | t3_x_d [label = "52 = 4 + 48"];
48 | t3_x_e [label = "59 = 5 + 54"];
49 |
50 | t3 -> t3_b [label = "b"];
51 | t3 -> t3_x [label = "x"];
52 | t3_x -> t3_x_c [label = "c"];
53 | t3_x -> t3_x_d [label = "d"];
54 | t3_x -> t3_x_e [label = "e"];
55 | }
56 | }
--------------------------------------------------------------------------------
/docs/source/best_practice/sklearn/index.rst:
--------------------------------------------------------------------------------
1 | Apply into Scikit-Learn
2 | ===========================
3 |
4 | In practice, ``TreeValue`` can be used not only with ``numpy`` or ``torch``, but also with other libraries such as ``scikit-learn``.
5 | In the following part, a demo of PCA to tree-structured arrays will be shown.
6 |
7 | In the field of traditional machine learning, PCA (Principal Component Analysis) is often used to preprocess data
8 | by centering it and reducing its dimensionality, which lowers the complexity
9 | of the input data and improves the efficiency and quality of machine learning, as shown in the following image.
10 |
11 | .. figure:: heading_of_pca.jpg
12 | :alt: PCA Principle
13 |
14 | PCA in a nutshell. Source: Lavrenko and Sutton 2011, slide 13.
15 |
16 | The scikit-learn library provides the ``PCA`` class for this purpose, and its method ``fit_transform``
17 | can be used to simplify the data. For ``np.ndarray`` data organized in a tree structure,
18 | we can add tree support simply by wrapping the function ``fit_transform``.
19 | The specific code is as follows
20 |
21 | .. literalinclude:: sklearn.demo.py
22 | :language: python
23 | :linenos:
24 |
25 | The output should be
26 |
27 | .. literalinclude:: sklearn.demo.py.txt
28 | :language: text
29 | :linenos:
30 |
31 | For further information, see the links below:
32 |
33 | * `Official documentation of PCA in scikit-learn <https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html>`_
34 | * `Details of PCA `_
35 |
36 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/missing_demo_2.gv:
--------------------------------------------------------------------------------
1 | digraph {
2 | label = "missing_demo_2 \n(Runnable, t3 is sum of t1 and t2, t1.a and t2.x.e's values are actually treated as [].)";
3 | graph [bgcolor = "#ffffff00"];
4 |
5 | subgraph cluster_t1 {
6 | label = "Tree t1";
7 | t1 [label = t1];
8 | t1_b [label = "[2, 3]"];
9 | t1_x [label = "t1.x"];
10 | t1_x_c [label = "[5]"];
11 | t1_x_d [label = "[7, 11, 13]"];
12 | t1_x_e [label = "[17, 19]"];
13 |
14 | t1 -> t1_b [label = "b"];
15 | t1 -> t1_x [label = "x"];
16 | t1_x -> t1_x_c [label = "c"];
17 | t1_x -> t1_x_d [label = "d"];
18 | t1_x -> t1_x_e [label = "e"];
19 | }
20 |
21 | subgraph cluster_t2 {
22 | label = "Tree t2";
23 | t2 [label = t2];
24 | t2_a [label = "[23]"];
25 | t2_b [label = "[29, 31]"];
26 | t2_x [label = "t2.x"];
27 | t2_x_c [label = "[37]"];
28 | t2_x_d [label = "[41, 43]"];
29 |
30 | t2 -> t2_a [label = "a"];
31 | t2 -> t2_b [label = "b"];
32 | t2 -> t2_x [label = "x"];
33 | t2_x -> t2_x_c [label = "c"];
34 | t2_x -> t2_x_d [label = "d"];
35 | }
36 |
37 | subgraph cluster_t3 {
38 | label = "Tree t3";
39 | t3 [label = "t3 = t1 + t2"];
40 | t3_a [label = "[23]"];
41 | t3_b [label = "[2, 3, 29, 31]"];
42 | t3_x [label = "t3.x = t1.x + t2.x"];
43 | t3_x_c [label = "[5, 37]"];
44 | t3_x_d [label = "[7, 11, 13, 41, 43]"];
45 | t3_x_e [label = "[17, 19]"];
46 |
47 | t3 -> t3_a [label = "a"];
48 | t3 -> t3_b [label = "b"];
49 | t3 -> t3_x [label = "x"];
50 | t3_x -> t3_x_c [label = "c"];
51 | t3_x -> t3_x_d [label = "d"];
52 | t3_x -> t3_x_e [label = "e"];
53 | }
54 | }
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/strict_demo_1.gv:
--------------------------------------------------------------------------------
1 | digraph {
2 | label = "strict_demo_1 \n(Runnable, t3 is sum of t1 and t2)";
3 | graph [bgcolor = "#ffffff00"];
4 |
5 | subgraph cluster_t1 {
6 | label = "Tree t1";
7 | t1 [label = t1];
8 | t1_a [label = 1];
9 | t1_b [label = 2];
10 | t1_x [label = "t1.x"];
11 | t1_x_c [label = 3];
12 | t1_x_d [label = 4];
13 | t1_x_e [label = 5];
14 |
15 | t1 -> t1_a [label = "a"];
16 | t1 -> t1_b [label = "b"];
17 | t1 -> t1_x [label = "x"];
18 | t1_x -> t1_x_c [label = "c"];
19 | t1_x -> t1_x_d [label = "d"];
20 | t1_x -> t1_x_e [label = "e"];
21 | }
22 |
23 | subgraph cluster_t2 {
24 | label = "Tree t2";
25 | t2 [label = t2];
26 | t2_a [label = 11];
27 | t2_b [label = 22];
28 | t2_x [label = "t2.x"];
29 | t2_x_c [label = 30];
30 | t2_x_d [label = 48];
31 | t2_x_e [label = 54];
32 |
33 | t2 -> t2_a [label = "a"];
34 | t2 -> t2_b [label = "b"];
35 | t2 -> t2_x [label = "x"];
36 | t2_x -> t2_x_c [label = "c"];
37 | t2_x -> t2_x_d [label = "d"];
38 | t2_x -> t2_x_e [label = "e"];
39 | }
40 |
41 | subgraph cluster_t3 {
42 | label = "Tree t3";
43 | t3 [label = "t3 = t1 + t2"];
44 | t3_a [label = "12 = 1 + 11"];
45 | t3_b [label = "24 = 2 + 22"];
46 | t3_x [label = "t3.x = t1.x + t2.x"];
47 | t3_x_c [label = "33 = 3 + 30"];
48 | t3_x_d [label = "52 = 4 + 48"];
49 | t3_x_e [label = "59 = 5 + 54"];
50 |
51 | t3 -> t3_a [label = "a"];
52 | t3 -> t3_b [label = "b"];
53 | t3 -> t3_x [label = "x"];
54 | t3_x -> t3_x_c [label = "c"];
55 | t3_x -> t3_x_d [label = "d"];
56 | t3_x -> t3_x_e [label = "e"];
57 | }
58 | }
--------------------------------------------------------------------------------
/test/tree/common/test_delay.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from treevalue.tree.common import DelayedProxy, delayed_partial
4 |
5 |
@pytest.mark.unittest
class TestTreeDelay:
    """
    Tests for ``delayed_partial``: nothing runs at construction time, and the
    wrapped call is evaluated at most once.
    """

    def test_delayed_partial_simple(self):
        calls = []

        def f():
            calls.append(None)
            return 1

        proxy = delayed_partial(f)
        # Nothing is evaluated when the proxy is built.
        assert not calls
        assert isinstance(proxy, DelayedProxy)

        # The first .value() runs f; the second reuses the stored result.
        for _ in range(2):
            assert proxy.value() == 1
            assert len(calls) == 1

    def test_delayed_partial_func(self):
        calls = []

        def f(x, y):
            calls.append((x, y))
            return x + y * 2 + 1

        # Positional and keyword arguments are bound now and used on evaluation.
        proxy = delayed_partial(f, 2, y=3)
        assert not calls
        assert isinstance(proxy, DelayedProxy)

        for _ in range(2):
            assert proxy.value() == 9
            assert len(calls) == 1

    def test_delayed_partial_complex(self):
        inner_calls, outer_calls = [], []

        def f1():
            inner_calls.append(None)
            return 1

        def f2(x, y):
            outer_calls.append(None)
            return (x + 1) ** 2 + y + 2

        # Proxies passed as arguments are resolved to their values before f2 runs.
        proxy = delayed_partial(f2, delayed_partial(f1), delayed_partial(f1))
        assert (len(inner_calls), len(outer_calls)) == (0, 0)
        assert isinstance(proxy, DelayedProxy)

        for _ in range(2):
            assert proxy.value() == 7
            assert (len(inner_calls), len(outer_calls)) == (2, 1)
69 |
--------------------------------------------------------------------------------
/test/utils/test_tree.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from shutil import which
3 | from unittest.mock import patch
4 |
5 | import pytest
6 | from hbutils.testing import cmdv
7 |
8 | from treevalue.utils import build_graph
9 |
10 |
@pytest.fixture()
def no_dot():
    """
    Pretend that the Graphviz ``dot`` executable is not installed.

    While the test runs, ``shutil.which`` is patched so that looking up ``'dot'``
    returns ``None``. Every other lookup is delegated to the real ``which``, which
    was imported at module load time, before the patch is applied.
    """

    def _no_dot_which(cmd, *args, **kwargs):
        # Mirror the real ``shutil.which(cmd, mode=..., path=...)`` signature so callers
        # that pass ``mode``/``path`` (or ``cmd=`` by keyword) see a missing ``dot``
        # instead of a ``TypeError`` from this fake.
        if cmd == 'dot':
            return None
        return which(cmd, *args, **kwargs)

    with patch('shutil.which', _no_dot_which):
        yield
21 |
22 |
@pytest.mark.unittest
class TestUtilsTree:
    """Tests for :func:`treevalue.utils.build_graph`."""

    @unittest.skipUnless(cmdv('dot'), 'Dot installed only')
    def test_build_graph(self):
        t = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}

        # (first argument, graph title, text expected in the source, upper bound of the source length)
        cases = [
            ((t, 't'), "Demo of build_graph.", "t.x", 560),
            (t, "Demo 2 of build_graph.", "root_0", 580),
            ((t,), "Demo 3 of build_graph.", "root_0", 580),
        ]
        for arg, title, marker, max_length in cases:
            source = build_graph(arg, graph_title=title).source
            assert title in source
            assert marker in source
            assert len(source) <= max_length

        # An empty tuple still yields a titled graph, but without any "node" text.
        empty_source = build_graph((), graph_title="Demo 4 of build_graph.").source
        assert "Demo 4 of build_graph." in empty_source
        assert "node" not in empty_source
        assert len(empty_source) <= 110

    def test_build_graph_without_dot(self, no_dot):
        # The ``no_dot`` fixture hides the ``dot`` executable, so building must fail.
        tree = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}
        with pytest.raises(EnvironmentError):
            build_graph((tree, 't'), graph_title="Demo of build_graph.")
52 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/missing_demo_1.gv:
--------------------------------------------------------------------------------
1 | digraph {
2 | label = "missing_demo_1 \n(Runnable, t3 is sum of t1 and t2, t2.a is ignored, t2.x.f's values are actually treated as 0.)";
3 | graph [bgcolor = "#ffffff00"];
4 |
5 | subgraph cluster_t1 {
6 | label = "Tree t1";
7 | t1 [label = t1];
8 | t1_b [label = 2];
9 | t1_x [label = "t1.x"];
10 | t1_x_c [label = 3];
11 | t1_x_d [label = 4];
12 | t1_x_e [label = 5];
13 | t1_x_f [label = 6];
14 |
15 | t1 -> t1_b [label = "b"];
16 | t1 -> t1_x [label = "x"];
17 | t1_x -> t1_x_c [label = "c"];
18 | t1_x -> t1_x_d [label = "d"];
19 | t1_x -> t1_x_e [label = "e"];
20 | t1_x -> t1_x_f [label = "f"];
21 | }
22 |
23 | subgraph cluster_t2 {
24 | label = "Tree t2";
25 | t2 [label = t2];
26 | t2_a [label = 11];
27 | t2_b [label = 22];
28 | t2_x [label = "t2.x"];
29 | t2_x_c [label = 30];
30 | t2_x_d [label = 48];
31 | t2_x_e [label = 54];
32 |
33 | t2 -> t2_a [label = "a"];
34 | t2 -> t2_b [label = "b"];
35 | t2 -> t2_x [label = "x"];
36 | t2_x -> t2_x_c [label = "c"];
37 | t2_x -> t2_x_d [label = "d"];
38 | t2_x -> t2_x_e [label = "e"];
39 | }
40 |
41 | subgraph cluster_t3 {
42 | label = "Tree t3";
43 | t3 [label = "t3 = t1 + t2"];
44 | t3_b [label = "24 = 2 + 22"];
45 | t3_x [label = "t3.x = t1.x + t2.x"];
46 | t3_x_c [label = "33 = 3 + 30"];
47 | t3_x_d [label = "52 = 4 + 48"];
48 | t3_x_e [label = "59 = 5 + 54"];
49 | t3_x_f [label = "6 = 6 + 0"];
50 |
51 | t3 -> t3_b [label = "b"];
52 | t3 -> t3_x [label = "x"];
53 | t3_x -> t3_x_c [label = "c"];
54 | t3_x -> t3_x_d [label = "d"];
55 | t3_x -> t3_x_e [label = "e"];
56 | t3_x -> t3_x_f [label = "f"];
57 | }
58 | }
--------------------------------------------------------------------------------
/treevalue/tree/common/storage.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | from libcpp cimport bool
5 | cimport cython
6 |
7 | ctypedef unsigned char boolean
8 | ctypedef unsigned int uint
9 |
# Declaration of the storage class behind tree values. Only signatures live in this
# .pxd; ``cython.final`` forbids subclassing it.
@cython.final
cdef class TreeStorage:
    # The underlying key -> value mapping; ``readonly`` lets Python code read the
    # attribute but not rebind it.
    cdef readonly dict map

    # --- single-key access (keys are always ``str``) ---
    # ``except *`` makes Cython check for a raised Python exception after these
    # ``void`` methods return, so errors propagate to the caller.
    cpdef public void set(self, str key, object value) except *
    cpdef public object setdefault(self, str key, object default)
    cpdef public object get(self, str key)
    cpdef public object get_or_default(self, str key, object default)
    cpdef public object pop(self, str key)
    cpdef public object pop_or_default(self, str key, object default)
    cpdef public tuple popitem(self)
    cpdef public void del_(self, str key) except *
    # NOTE(review): unlike ``set``/``del_`` this ``void`` method (and the ``*_from``
    # methods below) has no ``except *``; under Cython 0.29 semantics an exception
    # raised inside would not propagate to C-level callers -- confirm this is intended.
    cpdef public void clear(self)

    # --- queries; ``boolean`` and ``uint`` are the unsigned typedefs declared above ---
    cpdef public boolean contains(self, str key)
    cpdef public uint size(self)
    cpdef public boolean empty(self)

    # --- export to plain dicts; the ``*x`` variants take a caller-supplied ``copy_func`` ---
    cpdef public dict dump(self)
    cpdef public dict deepdump(self)
    cpdef public dict deepdumpx(self, copy_func)
    # ``need_raw`` / ``allow_delayed`` presumably control raw-wrapping and whether delayed
    # values are kept unevaluated -- TODO confirm against storage.pyx.
    cpdef public dict jsondumpx(self, copy_func, bool need_raw, bool allow_delayed)

    # --- duplication into a new ``TreeStorage`` ---
    cpdef public TreeStorage copy(self)
    cpdef public TreeStorage deepcopy(self)
    cpdef public TreeStorage deepcopyx(self, copy_func, bool allow_delayed)
    # Returns a ``dict``; presumably hands out the internal mapping -- verify in storage.pyx.
    cpdef public dict detach(self)

    # --- in-place refill from another storage ---
    cpdef public void copy_from(self, TreeStorage ts)
    cpdef public void deepcopy_from(self, TreeStorage ts)
    cpdef public void deepcopyx_from(self, TreeStorage ts, copy_func, bool allow_delayed)
37 |
# Module-level factory exposed to Python; presumably builds a storage object from the
# given dict ``value`` -- see storage.pyx for the exact conversion rules.
cpdef public object create_storage(dict value)
# C-only helpers taking a container ``data``, a key ``k`` and its current value ``v``.
# Judging by the names they resolve delayed values (plain / not-None / checked
# variants) -- TODO confirm their exact contract in storage.pyx.
cdef object _c_undelay_data(dict data, object k, object v)
cdef object _c_undelay_not_none_data(dict data, object k, object v)
cdef object _c_undelay_check_data(dict data, object k, object v)
42 |
--------------------------------------------------------------------------------
/docs/source/tutorials/advanced_usage/outer_demo_1.gv:
--------------------------------------------------------------------------------
1 | digraph {
2 | label = "outer_demo_1 \n(Runnable, t3 is sum of t1 and t2, t1.a and t2.x.f's values are actually treated as 0.)";
3 | graph [bgcolor = "#ffffff00"];
4 |
5 | subgraph cluster_t1 {
6 | label = "Tree t1";
7 | t1 [label = t1];
8 | t1_b [label = 2];
9 | t1_x [label = "t1.x"];
10 | t1_x_c [label = 3];
11 | t1_x_d [label = 4];
12 | t1_x_e [label = 5];
13 | t1_x_f [label = 6];
14 |
15 | t1 -> t1_b [label = "b"];
16 | t1 -> t1_x [label = "x"];
17 | t1_x -> t1_x_c [label = "c"];
18 | t1_x -> t1_x_d [label = "d"];
19 | t1_x -> t1_x_e [label = "e"];
20 | t1_x -> t1_x_f [label = "f"];
21 | }
22 |
23 | subgraph cluster_t2 {
24 | label = "Tree t2";
25 | t2 [label = t2];
26 | t2_a [label = 11];
27 | t2_b [label = 22];
28 | t2_x [label = "t2.x"];
29 | t2_x_c [label = 30];
30 | t2_x_d [label = 48];
31 | t2_x_e [label = 54];
32 |
33 | t2 -> t2_a [label = "a"];
34 | t2 -> t2_b [label = "b"];
35 | t2 -> t2_x [label = "x"];
36 | t2_x -> t2_x_c [label = "c"];
37 | t2_x -> t2_x_d [label = "d"];
38 | t2_x -> t2_x_e [label = "e"];
39 | }
40 |
41 | subgraph cluster_t3 {
42 | label = "Tree t3";
43 | t3 [label = "t3 = t1 + t2"];
44 | t3_a [label = "11 = 0 + 11"];
45 | t3_b [label = "24 = 2 + 22"];
46 | t3_x [label = "t3.x = t1.x + t2.x"];
47 | t3_x_c [label = "33 = 3 + 30"];
48 | t3_x_d [label = "52 = 4 + 48"];
49 | t3_x_e [label = "59 = 5 + 54"];
50 | t3_x_f [label = "6 = 6 + 0"];
51 |
52 | t3 -> t3_a [label = "a"];
53 | t3 -> t3_b [label = "b"];
54 | t3 -> t3_x [label = "x"];
55 | t3_x -> t3_x_c [label = "c"];
56 | t3_x -> t3_x_d [label = "d"];
57 | t3_x -> t3_x_e [label = "e"];
58 | t3_x -> t3_x_f [label = "f"];
59 | }
60 | }
--------------------------------------------------------------------------------
/docs/source/contribute/architecture/architecture.puml:
--------------------------------------------------------------------------------
1 | @startuml
2 | skinparam backgroundcolor transparent
3 | skinparam rectangle<> {
4 | roundCorner 25
5 | }
6 | sprite $primitive jar:archimate/application-component
7 | sprite $data jar:archimate/application-service
8 | sprite $logic jar:archimate/business-process
9 |
10 | rectangle "Logic Layer" as logiclayer {
11 | rectangle "TreeValue\n(Basic logic framework)" as treevalue <<$logic>> #Business
12 | rectangle "**func_treelize**\n(Function wrapper)" as func_treelize <<$logic>> #Application
13 |
14 | rectangle "**FastTreeValue**\n(Common treevalue)" as fasttreevalue <<$logic>> #Business
15 | rectangle "method_treelize" as method_treelize <<$logic>> #Application
16 | rectangle "classmethod_treelize" as classmethod_treelize <<$logic>> #Application
17 |
18 | method_treelize <-- func_treelize: Special wrapper \nfor instance method
19 | classmethod_treelize <-- func_treelize: Special wrapper \nfor class method
20 | fasttreevalue <-down- treevalue: Logic extension
21 | method_treelize --> fasttreevalue: Wrapper support
22 | classmethod_treelize --> fasttreevalue: Wrapper support
23 | }
24 |
25 | rectangle "Data Layer" as datalayer {
26 | rectangle "TreeStorage" as tree_storage <<$data>><> #Application
27 |
28 | tree_storage -up-> treevalue: "TreeValue can \nbe based on Tree."
29 | }
30 |
31 | rectangle "Primitive Python" as python {
32 | rectangle "Python Dict" as dict <<$primitive>> #Application
33 | rectangle "Python Operators" as operators <<$primitive>> #Application
34 | python -up-> tree_storage : "Tree is based on \nprimitive python dict."
35 | }
36 |
37 |
38 |
39 | legend left
40 | Examples of Architecture
41 | ====
42 | <$logic> : Logic Layer
43 | ====
44 | <$data> : Data Layer
45 | ====
46 | <$primitive> : Primitive Python
47 | endlegend
48 | @enduml
49 |
--------------------------------------------------------------------------------
/docs/source/tutorials/main_idea/treelize_demo.gv:
--------------------------------------------------------------------------------
1 | digraph {
2 | label = "An example of tree-based function __add__ (f_T)";
3 | compound = true;
4 | graph [bgcolor = "#ffffff00"]
5 |
6 | subgraph cluster_t1 {
7 | label = "TreeValue t1";
8 | root1 [label = "t1" shape = diamond];
9 | root1_n1 [label = "1" shape = box];
10 | root1_n2 [label = "2" shape = box];
11 | root1_n3 [label = "t1.x"];
12 | root1_n4 [label = "3" shape = box];
13 | root1_n5 [label = "4" shape = box];
14 |
15 | root1 -> root1_n1 [label = "a"];
16 | root1 -> root1_n2 [label = "b"];
17 | root1 -> root1_n3 [label = "x"];
18 | root1_n3 -> root1_n4 [label = "c"];
19 | root1_n3 -> root1_n5 [label = "d"];
20 | }
21 |
22 | subgraph cluster_t2 {
23 | label = "TreeValue t2";
24 | root2 [label = "t2" shape = diamond];
25 | root2_n1 [label = "11" shape = box];
26 | root2_n2 [label = "22" shape = box];
27 | root2_n3 [label = "t2.x"];
28 | root2_n4 [label = "30" shape = box];
29 | root2_n5 [label = "48" shape = box];
30 |
31 | root2 -> root2_n1 [label = "a"];
32 | root2 -> root2_n2 [label = "b"];
33 | root2 -> root2_n3 [label = "x"];
34 | root2_n3 -> root2_n4 [label = "c"];
35 | root2_n3 -> root2_n5 [label = "d"];
36 | }
37 |
38 | subgraph cluster_sum {
39 | label = "Result of __add__(t1, t2)";
40 | root3 [label = "__add__(t1, t2)" shape = diamond];
41 | root3_n1 [label = "12 = __add__(1, 11)" shape = box];
42 | root3_n2 [label = "24 = __add__(2, 22)" shape = box];
43 | root3_n3 [label = "__add__(t1.x, t2.x)"];
44 | root3_n4 [label = "33 = __add__(3, 30)" shape = box];
45 | root3_n5 [label = "52 = __add__(4, 48)" shape = box];
46 |
47 | root3 -> root3_n1 [label = "a"];
48 | root3 -> root3_n2 [label = "b"];
49 | root3 -> root3_n3 [label = "x"];
50 | root3_n3 -> root3_n4 [label = "c"];
51 | root3_n3 -> root3_n5 [label = "d"];
52 | }
53 | }
--------------------------------------------------------------------------------
/docs/source/best_practice/numpy/index.rst:
--------------------------------------------------------------------------------
1 | Apply into Numpy
2 | ======================
3 |
4 | In following parts, we will show some demos about how to use ``TreeValue`` in practice.
5 |
6 | For example, suppose we have a group of structured data in python-dict type, and we want to apply different operations to different key-value pairs in place,
7 | get some statistics such as the mean value, and take some slices.
8 |
9 |
10 | Normally, we need to write multiple nested ``for`` loops and ``if-else`` branches to apply the corresponding operation to each value, and declare additional
11 | temporary variables to save the results. All of these operations are executed serially, as in the following code example:
12 |
13 | .. literalinclude:: without_treevalue.py
14 | :language: python
15 | :linenos:
16 |
17 | However, with the help of ``TreeValue``, all the contents mentioned above can be implemented gracefully and efficiently. Users only need to ``func_treelize`` the primitive
18 | numpy functions and pack data with ``FastTreeValue``, then execute the desired operations just like using standard numpy arrays.
19 |
20 | .. literalinclude:: with_treevalue.py
21 | :language: python
22 | :linenos:
23 |
24 | And we can run these two demos for comparison:
25 |
26 | .. literalinclude:: numpy.demo.py
27 | :language: python
28 | :linenos:
29 |
30 | The final output should be the text below, and all the assertions can be passed.
31 |
32 | .. literalinclude:: numpy.demo.py.txt
33 | :language: text
34 | :linenos:
35 |
36 | In this case, we can see that ``TreeValue`` can be properly applied to the ``numpy`` library.
37 | The tree-structured matrix calculation can be easily built with ``TreeValue``, just like using standard numpy arrays.
38 |
39 | Both the simplicity of the logic structure and the execution efficiency can be improved a lot.
40 |
41 | **Last but not least, the only thing you need to do is to wrap the functions in the Numpy library, and then use them painlessly just like primitive numpy.**
42 |
--------------------------------------------------------------------------------
/treevalue/entry/cli/base.py:
--------------------------------------------------------------------------------
1 | import click
2 | from click.core import Context, Option
3 |
4 | from ...config.meta import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
5 |
def _as_email_tuple(email) -> tuple:
    # One aligned slot is ``None`` (padding), a single address string, or a tuple of
    # surplus addresses; always return a tuple with the empty entries dropped.
    if isinstance(email, tuple):
        candidates = email
    elif email:
        candidates = (email,)
    else:
        candidates = ()
    return tuple(item for item in candidates if item)


def _describe_author(author: str, emails: tuple) -> str:
    # ``Name (a@x, b@y)`` when addresses are known, otherwise just ``Name``.
    if emails:
        return '{author} ({emails})'.format(author=author, emails=', '.join(emails))
    return author


# The comma-separated author and e-mail metadata are split and paired up position by position:
#   - blank author entries are dropped, e-mail entries are only stripped;
#   - missing e-mails are padded with ``None``;
#   - surplus e-mails are folded into one tuple attached to the last author.
_raw_authors = list(filter(None, map(str.strip, __AUTHOR__.split(','))))
_raw_emails = list(map(str.strip, __AUTHOR_EMAIL__.split(',')))
_missing_emails = len(_raw_authors) - len(_raw_emails)
if _missing_emails > 0:
    _raw_emails.extend([None] * _missing_emails)
elif _missing_emails < 0:
    _raw_emails[len(_raw_authors) - 1:] = [tuple(_raw_emails[len(_raw_authors) - 1:])]
    del _raw_emails[len(_raw_authors):]

_author_tuples = [(author, _as_email_tuple(email)) for author, email in zip(_raw_authors, _raw_emails)]
_authors = [_describe_author(author, emails) for author, emails in _author_tuples]
22 |
23 |
# noinspection PyUnusedLocal
def print_version(ctx: Context, param: Option, value: bool) -> None:
    """
    Click callback for ``--version``: echo the title, version and authors, then stop.

    Nothing happens when the flag is not set or when click is only resiliently
    parsing the command line (e.g. for shell completion).

    :param ctx: click context
    :param param: current parameter's metadata (unused)
    :param value: value of current parameter, i.e. whether the flag was given
    """
    if value and not ctx.resilient_parsing:
        click.echo('{title}, version {version}.'.format(title=__TITLE__.capitalize(), version=__VERSION__))
        if _authors:
            click.echo('Developed by {authors}.'.format(authors=', '.join(_authors)))
        ctx.exit()
38 |
39 |
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}


# Root ``click`` group of the command line tool. ``-v``/``--version`` is eager and is
# handled entirely by ``print_version``, so it is not passed into the function.
# Deliberately no docstring: click uses a command's docstring as its help text.
@click.group(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v', '--version',
    is_flag=True,
    is_eager=True,
    expose_value=False,
    callback=print_version,
    help="Show package's version information.",
)
def _base_treevalue_cli():
    pass
51 |
--------------------------------------------------------------------------------
/treevalue/entry/cli/io.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import pickle
4 | from string import Template
5 | from typing import Tuple, Iterator
6 |
7 | import dill
8 | from hbutils.reflection import dynamic_call, iter_import_objects
9 |
10 | from ...tree import TreeValue, load
11 |
12 |
@dynamic_call
def _import_trees_from_package(obj_pattern, title=None, *args,
                               default_template: str = '$name') -> Iterator[Tuple[TreeValue, str]]:
    """
    Yield every :class:`TreeValue` object found via ``obj_pattern`` together with a title.

    The title is ``title`` (or ``default_template`` when ``title`` is empty), rendered as a
    :class:`string.Template` with ``$module`` and ``$name``. Unknown placeholders are left
    untouched because ``safe_substitute`` is used.
    """

    def _is_tree(obj) -> bool:
        return isinstance(obj, TreeValue)

    template = Template(title or default_template)
    for tree, module, name in iter_import_objects(obj_pattern, _is_tree):
        yield tree, template.safe_substitute(module=module, name=name)
20 |
21 |
@dynamic_call
def _import_trees_from_binary(filename_pattern, title='', *args,
                              default_template: str = '$bodyname') -> Iterator[Tuple[TreeValue, str]]:
    """
    Load the tree stored in every readable file matching ``filename_pattern``.

    :param filename_pattern: Glob pattern passed to :func:`glob.glob`.
    :param title: Title template (:class:`string.Template`); ``default_template`` is used
        when it is empty. Placeholders: ``$fullname``, ``$dirname``, ``$basename``,
        ``$extname`` and ``$bodyname``.
    :param default_template: Fallback title template.
    :return: Iterator of ``(tree, title)`` pairs. Paths that are not readable regular files,
        or that cannot be opened or decoded as a tree, are skipped silently.
    """
    _title_template = Template(title or default_template)
    for filename in glob.glob(filename_pattern):
        # ``isfile`` is already False for a missing path, so no separate ``exists`` check.
        if not os.path.isfile(filename) or not os.access(filename, os.R_OK):
            continue

        filename = os.path.abspath(filename)
        _name_body, _name_ext = os.path.splitext(os.path.basename(filename))
        _name_ext = _name_ext[1:] if _name_ext.startswith('.') else _name_ext

        # NOTE(security): judging by the exceptions handled below, ``load`` unpickles
        # (pickle/dill) data, which can execute arbitrary code -- only use on trusted files.
        try:
            # Load and close the file *before* yielding. With the original layout, a
            # consumer that stopped iterating early left the handle open (this assumes
            # ``load`` reads the tree eagerly -- TODO confirm). ``open`` is inside the
            # ``try`` as well: a file that disappears or becomes unreadable after the
            # check above is skipped like other unreadable files instead of aborting.
            with open(filename, 'rb') as file:
                _tree = load(file)
        except (pickle.UnpicklingError, dill.UnpicklingError, EOFError, IOError):
            continue

        yield _tree, _title_template.safe_substitute(dict(
            fullname=filename,
            dirname=os.path.dirname(filename),
            basename=os.path.basename(filename),
            extname=_name_ext,
            bodyname=_name_body,
        ))
46 |
--------------------------------------------------------------------------------
/.github/workflows/doc.yml:
--------------------------------------------------------------------------------
1 | # This workflow builds the documentation and deploys it to GitHub Pages
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Docs Deploy
5 |
6 | on:
7 | push:
8 | branches: [ main, 'doc/*', 'dev/*' ]
9 | release:
10 | types: [ published ]
11 |
12 | jobs:
13 | doc:
14 | runs-on: ubuntu-latest
15 | if: "!contains(github.event.head_commit.message, 'ci skip')"
16 | strategy:
17 | matrix:
18 | python-version:
19 | - '3.8'
20 |
21 | services:
22 | plantuml:
23 | image: plantuml/plantuml-server
24 | ports:
25 | - 18080:8080
26 |
27 | steps:
28 | - uses: actions/checkout@v3
29 | - name: Set up Python ${{ matrix.python-version }}
30 | uses: actions/setup-python@v4
31 | with:
32 | python-version: ${{ matrix.python-version }}
33 | - name: Install dependencies
34 | run: |
35 | sudo apt-get update -y
36 | sudo apt-get install -y make wget curl cloc graphviz pandoc
37 | dot -V
38 | python -m pip install -r requirements.txt
39 | python -m pip install -r requirements-doc.txt
40 | python -m pip install -r requirements-benchmark.txt
41 | - name: Generate
42 | env:
43 | ENV_PROD: 'true'
44 | PLANTUML_HOST: http://localhost:18080
45 | run: |
46 | git fetch --all --tags
47 | git branch -av
48 | git remote -v
49 | git tag
50 | plantumlcli -c
51 | make pdocs
52 | mv ./docs/build/html ./public
53 | - name: Deploy to Github Page
54 | uses: JamesIves/github-pages-deploy-action@3.7.1
55 | with:
56 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
57 | BRANCH: gh-pages # The branch the action should deploy to.
58 | FOLDER: public # The folder the action should deploy.
59 | CLEAN: true # Automatically remove deleted files from the deploy branch
60 |
--------------------------------------------------------------------------------
/treevalue/tree/common/base.pyx:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | import cython
5 |
cdef class RawWrapper:
    """
    Wrapper class of the raw value.

    :func:`raw` boxes :class:`dict` objects in this class, and :func:`unraw` takes the
    original value back out.
    """

    @cython.binding(True)
    def __cinit__(self, object v):
        """
        Overview:
            C-leveled constructor of :class:`RawWrapper`; stores ``v`` unchanged.

        Arguments:
            - v (:obj:`object`): Value to be wrapped.
        """
        self.val = v

    @cython.binding(True)
    cpdef object value(self):
        """
        Overview:
            Get the wrapped original value.

        Returns:
            - obj: Original value.
        """
        return self.val

    # Pickle support. ``__cinit__`` requires one positional argument, so unpickling first
    # builds ``RawWrapper(None)`` and then restores the real value through ``__setstate__``.
    def __getnewargs_ex__(self):
        return (None,), {}

    def __getstate__(self):
        return self.val

    def __setstate__(self, state):
        self.val = state
40 |
@cython.binding(True)
cpdef inline object raw(object obj):
    """
    Overview:
        Try to wrap the given ``obj`` into a :class:`RawWrapper`.

    Arguments:
        - obj (:obj:`object`): The original object.

    Returns:
        - wrapped (:obj:`object`): A :class:`RawWrapper` when ``obj`` is a :class:`dict` \
            that is not wrapped yet; otherwise ``obj`` itself is returned unchanged.
    """
    if isinstance(obj, dict) and not isinstance(obj, RawWrapper):
        return RawWrapper(obj)
    return obj
58 |
@cython.binding(True)
cpdef inline object unraw(object wrapped):
    """
    Overview:
        Try to unwrap the given ``wrapped`` back to the original object.

    Arguments:
        - wrapped (:obj:`object`): Wrapped object.

    Returns:
        - obj (:obj:`object`): The value inside a :class:`RawWrapper`, or ``wrapped`` \
            itself when it is not a :class:`RawWrapper`.
    """
    if not isinstance(wrapped, RawWrapper):
        return wrapped
    return wrapped.value()
75 |
--------------------------------------------------------------------------------
/test/tree/integration/test_jax.py:
--------------------------------------------------------------------------------
1 | from unittest import skipUnless
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from treevalue import FastTreeValue, register_for_jax
7 |
8 | try:
9 | import jax
10 | from jax.tree_util import register_pytree_node
11 | except (ModuleNotFoundError, ImportError):
12 | jax = None
13 |
14 |
@pytest.mark.unittest
class TestTreeTreeIntegration:
    """Tests for using :class:`FastTreeValue` (and registered subclasses) with jax."""

    @staticmethod
    def _make_tree(tree_cls):
        # Same layout in every case: an int matrix, a 0-d float array and a float matrix.
        return tree_cls({
            'a': np.random.randint(0, 10, (2, 3)) + 1,
            'b': {
                'x': np.asarray(233.0),
                'y': np.random.randn(2, 3) + 100,
            }
        })

    @skipUnless(jax, 'Jax required.')
    def test_jax_double(self):
        @jax.jit
        def double(x):
            return x * 2 + 1.5

        def _check(tree_cls):
            # The jitted result keeps the exact tree class and matches the numpy computation.
            tree = self._make_tree(tree_cls)
            result = double(tree)
            assert type(result) is tree_cls
            all_close = tree_cls.func()(np.isclose)(result, tree * 2 + 1.5).all()
            assert all_close == tree_cls({'a': True, 'b': {'x': True, 'y': True}})

        _check(FastTreeValue)

        class MyTreeValue(FastTreeValue):
            pass

        # A subclass only works with jax after being registered explicitly.
        register_for_jax(MyTreeValue)
        _check(MyTreeValue)

    @skipUnless(jax, 'Jax required.')
    def test_error_register(self):
        for invalid in (None, list):
            with pytest.raises(TypeError):
                register_for_jax(invalid)

    @skipUnless(not jax, 'No jax required')
    def test_ignored_register(self):
        class MyTreeValueX(FastTreeValue):
            pass

        # Without jax installed, registration is a no-op that only warns.
        with pytest.warns(UserWarning):
            register_for_jax(MyTreeValueX)
66 |
--------------------------------------------------------------------------------
/docs/source/api_doc/tree/common.rst:
--------------------------------------------------------------------------------
1 | treevalue.tree.common
2 | ======================
3 |
4 | .. py:currentmodule:: treevalue.tree.common
5 |
6 | .. _apidoc_tree_common_treestorage:
7 |
8 | TreeStorage
9 | -------------
10 |
11 | .. autoclass:: TreeStorage
12 | :members: get, get_or_default, pop, pop_or_default, popitem, set, setdefault, del_, contains, size, empty, copy, deepcopy, deepcopyx, dump, deepdump, deepdumpx, jsondumpx, copy_from, deepcopy_from, deepcopyx_from, detach, clear, iter_keys, iter_rev_keys, iter_values, iter_rev_values, iter_items, iter_rev_items
13 |
14 | .. note::
15 | Please refer to the source code for method details in this section of the documentation \
16 | because adding method signatures will significantly decrease running speed.
17 |
18 |
19 | .. _apidoc_tree_common_create_storage:
20 |
21 | create_storage
22 | -------------------
23 |
24 | .. autofunction:: create_storage
25 |
26 |
27 | .. _apidoc_tree_common_raw:
28 |
29 | raw
30 | ----------
31 |
32 | .. autofunction:: raw
33 |
34 |
35 | .. _apidoc_tree_common_unraw:
36 |
37 | unraw
38 | ----------
39 |
40 | .. autofunction:: unraw
41 |
42 |
43 | .. _apidoc_tree_common_rawwrapper:
44 |
45 | RawWrapper
46 | ---------------
47 |
48 | .. autoclass:: RawWrapper
49 | :members: __init__, value
50 |
51 |
52 | .. _apidoc_tree_common_delayed_partial:
53 |
54 | delayed_partial
55 | ----------------------
56 |
57 | .. autofunction:: delayed_partial
58 |
59 |
60 | .. _apidoc_tree_common_undelay:
61 |
62 | undelay
63 | ---------------
64 |
65 | .. autofunction:: undelay
66 |
67 |
68 | .. _apidoc_tree_common_delayedproxy:
69 |
70 | DelayedProxy
71 | -------------------
72 |
73 | .. autoclass:: DelayedProxy
74 | :members: value, fvalue
75 |
76 |
77 | .. _apidoc_tree_common_delayedvalueproxy:
78 |
79 | DelayedValueProxy
80 | -------------------
81 |
82 | .. autoclass:: DelayedValueProxy
83 | :members: __cinit__, value, fvalue
84 |
85 |
86 | .. _apidoc_tree_common_delayedfuncproxy:
87 |
88 | DelayedFuncProxy
89 | -------------------
90 |
91 | .. autoclass:: DelayedFuncProxy
92 | :members: __cinit__, value, fvalue
93 |
94 |
--------------------------------------------------------------------------------
/.github/workflows/run.yml:
--------------------------------------------------------------------------------
1 | name: Code Script Run
2 |
3 | on:
4 | push:
5 | branches: [ 'run/*' ]
6 |
7 | jobs:
8 | unittest:
9 | name: Code Script Run
10 | runs-on: ${{ matrix.os }}
11 | if: "!contains(github.event.head_commit.message, 'ci skip')"
12 | strategy:
13 | fail-fast: false
14 | matrix:
15 | os:
16 | - 'ubuntu-20.04'
17 | python-version:
18 | - '3.8'
19 | - '3.9'
20 | - '3.10'
21 | - '3.11'
22 | - '3.12'
23 |
24 | steps:
25 | - name: Checkout code
26 | uses: actions/checkout@v3
27 | with:
28 | fetch-depth: 20
29 | - name: Set up system dependences on linux
30 | if: ${{ runner.os == 'Linux' }}
31 | run: |
32 | sudo apt-get update
33 | sudo apt-get install -y tree cloc wget curl make graphviz
34 | sudo apt-get install -y libxml2-dev libxslt-dev python-dev # need by pypy3
35 | dot -V
36 | - name: Set up python ${{ matrix.python-version }}
37 | uses: actions/setup-python@v4
38 | with:
39 | python-version: ${{ matrix.python-version }}
40 | - name: Install dependencies
41 | run: |
42 | python -m pip install --upgrade pip
43 | pip install --upgrade flake8 setuptools wheel twine
44 | pip install -r requirements.txt
45 | pip install -r requirements-build.txt
46 | pip install -r requirements-test.txt
47 | pip install .
48 | - name: Test the basic environment
49 | run: |
50 | python -V
51 | pip --version
52 | pip list
53 | tree .
54 | cloc treevalue
55 | cloc test
56 | - name: Run Script
57 | env:
58 | CI: 'true'
59 | LINETRACE: 1
60 | run: |
61 | make clean build run
62 | - name: Show the artifacts
63 | run: |
64 | tree runs/artifacts
65 | - uses: actions/upload-artifact@v3
66 | with:
67 | name: run-artifacts-${{ matrix.os }}-${{ matrix.python-version }}
68 | path: runs/artifacts
69 |
--------------------------------------------------------------------------------
/docs/source/demos.mk:
--------------------------------------------------------------------------------
PYTHON := $(shell which python)

SOURCE ?= .

# The ``-name`` patterns are quoted so the shell hands them to ``find`` verbatim.
# Unquoted, the shell would first expand them against the current directory: one
# match silently narrows the search to that file, several matches make ``find`` fail.
PYTHON_DEMOS := $(shell find ${SOURCE} -name '*.demo.py')
PYTHON_DEMOXS := $(shell find ${SOURCE} -name '*.demox.py')
PYTHON_RESULTS := $(addsuffix .py.txt, $(basename ${PYTHON_DEMOS} ${PYTHON_DEMOXS}))

SHELL_DEMOS := $(shell find ${SOURCE} -name '*.demo.sh')
SHELL_DEMOXS := $(shell find ${SOURCE} -name '*.demox.sh')
SHELL_RESULTS := $(addsuffix .sh.txt, $(basename ${SHELL_DEMOS} ${SHELL_DEMOXS}))
11 |
12 | %.demo.py.txt: %.demo.py
13 | cd "$(shell dirname $(shell readlink -f $<))" && \
14 | PYTHONPATH="$(shell dirname $(shell readlink -f $<)):${PYTHONPATH}" \
15 | $(PYTHON) "$(shell readlink -f $<)" > "$(shell readlink -f $@)"
16 |
17 | %.demox.py.txt: %.demox.py
18 | cd "$(shell dirname $(shell readlink -f $<))" && \
19 | PYTHONPATH="$(shell dirname $(shell readlink -f $<)):${PYTHONPATH}" \
20 | $(PYTHON) "$(shell readlink -f $<)" 1> "$(shell readlink -f $@)" \
21 | 2> "$(shell readlink -f $(addsuffix .err, $(basename $@)))"; \
22 | echo $$? > "$(shell readlink -f $(addsuffix .exitcode, $(basename $@)))"
23 |
24 | %.demo.sh.txt: %.demo.sh
25 | cd "$(shell dirname $(shell readlink -f $<))" && \
26 | PYTHONPATH="$(shell dirname $(shell readlink -f $<)):${PYTHONPATH}" \
27 | $(SHELL) "$(shell readlink -f $<)" > "$(shell readlink -f $@)"
28 |
29 | %.demox.sh.txt: %.demox.sh
30 | cd "$(shell dirname $(shell readlink -f $<))" && \
31 | PYTHONPATH="$(shell dirname $(shell readlink -f $<)):${PYTHONPATH}" \
32 | $(SHELL) "$(shell readlink -f $<)" 1> "$(shell readlink -f $@)" \
33 | 2> "$(shell readlink -f $(addsuffix .err, $(basename $@)))"; \
34 | echo $$? > "$(shell readlink -f $(addsuffix .exitcode, $(basename $@)))"
35 |
# These targets never produce files of the same name.
.PHONY: build all clean

build: ${PYTHON_RESULTS} ${SHELL_RESULTS}

all: build

# Remove every generated demo output under ${SOURCE}. The patterns are quoted so that
# ``find`` (not the shell) matches them. The recipe follows ``;`` on the rule line,
# which is standard GNU make syntax.
clean: ; rm -rf \
    $(shell find ${SOURCE} -name '*.py.txt') \
    $(shell find ${SOURCE} -name '*.py.err') \
    $(shell find ${SOURCE} -name '*.py.exitcode') \
    $(shell find ${SOURCE} -name '*.sh.txt') \
    $(shell find ${SOURCE} -name '*.sh.err') \
    $(shell find ${SOURCE} -name '*.sh.exitcode') \
    $(shell find ${SOURCE} -name '*.dat.*')
49 |
--------------------------------------------------------------------------------
/test/tree/func/test_left.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from treevalue.tree import func_treelize, TreeValue
4 |
5 |
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestTreeFuncLeft:
    """
    Tests for ``func_treelize('left')``: the result follows the key layout of the
    first tree argument.
    """

    @staticmethod
    def _ssum(**kwargs):
        # A tree-aware ``sum(*args)`` in left mode, with the given extra options.
        @func_treelize('left', **kwargs)
        def ssum(*args):
            return sum(args)

        return ssum

    @staticmethod
    def _trees():
        # t1 and t2 share one layout; t3 additionally has the keys ``c`` and ``x.e``.
        return (
            TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}),
            TreeValue({'a': 11, 'b': 22, 'x': {'c': 33, 'd': 44}}),
            TreeValue({'a': 11, 'b': 22, 'c': 33, 'x': {'c': 33, 'd': 44, 'e': 550}}),
        )

    def test_left_raw(self):
        ssum = self._ssum()
        t1, t2, t3 = self._trees()

        assert ssum(1, 2, 3) == 6
        assert ssum(t1, t2) == TreeValue({'a': 12, 'b': 24, 'x': {'c': 36, 'd': 48}})
        assert ssum(t1.x, t2.x) == TreeValue({'c': 36, 'd': 48})

        assert ssum(t1, t2, t3) == TreeValue({'a': 23, 'b': 46, 'x': {'c': 69, 'd': 92}})
        assert ssum(t2, t3) == TreeValue({'a': 22, 'b': 44, 'x': {'c': 66, 'd': 88}})
        # With t3 first, its extra keys are looked up in t1/t2, where they are missing.
        with pytest.raises(KeyError):
            _ = ssum(t3, t1, t2)

    def test_left_missing(self):
        ssum = self._ssum(missing=0)
        t1, t2, t3 = self._trees()

        assert ssum(t1, t2, t3) == TreeValue({'a': 23, 'b': 46, 'x': {'c': 69, 'd': 92}})
        assert ssum(t2, t3) == TreeValue({'a': 22, 'b': 44, 'x': {'c': 66, 'd': 88}})
        # ``missing=0`` turns the absent keys of t1/t2 into 0 instead of raising.
        assert ssum(t3, t1, t2) == TreeValue({'a': 23, 'b': 46, 'c': 33, 'x': {'c': 69, 'd': 92, 'e': 550}})

    def test_left_inherit(self):
        ssum = self._ssum(missing=0)
        t1, t2, t3 = self._trees()

        # The plain ``-10`` is combined with every leaf of the result.
        assert ssum(t3, t1, t2, -10) == TreeValue({'a': 13, 'b': 36, 'c': 23, 'x': {'c': 59, 'd': 82, 'e': 540}})
47 |
--------------------------------------------------------------------------------
/test/tree/func/test_strict.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from treevalue.tree import func_treelize, TreeValue
4 |
5 |
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestTreeFuncStrict:
    """Tests for ``func_treelize`` in its default (strict) mode."""

    def test_strict_raw(self):
        @func_treelize(inherit=False)
        def total(*args):
            return sum(args)

        first = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
        second = TreeValue({'a': 11, 'b': 22, 'x': {'c': 33, 'd': 44}})
        mismatched = TreeValue({'a': 11, 'b': 22, 'x': {'c': 33, 'dd': 44}})

        assert total(1, 2, 3) == 6
        assert total(first, second) == TreeValue({'a': 12, 'b': 24, 'x': {'c': 36, 'd': 48}})
        assert total(first.x, second.x) == TreeValue({'c': 36, 'd': 48})

        # strict mode needs identical key sets
        with pytest.raises(KeyError):
            _ = total(first, mismatched)
        # with inherit=False a plain value is not broadcast over a tree
        with pytest.raises(TypeError):
            _ = total(first, 1)

    def test_strict_inherit(self):
        @func_treelize()
        def total(*args):
            return sum(args)

        first = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
        second = TreeValue({'a': 11, 'b': 22, 'x': {'c': 33, 'd': 44}})

        assert total(1, 2, 3) == 6
        assert total(first, second) == TreeValue({'a': 12, 'b': 24, 'x': {'c': 36, 'd': 48}})
        assert total(first.x, second.x) == TreeValue({'c': 36, 'd': 48})
        # plain values and leaf values (``x: 80``) are inherited by whole subtrees
        assert total(first, 1) == TreeValue({'a': 2, 'b': 3, 'x': {'c': 4, 'd': 5}})
        partial = TreeValue({'a': 2, 'b': 3, 'x': 80})
        assert total(first, partial) == TreeValue({'a': 3, 'b': 5, 'x': {'c': 83, 'd': 84}})

    def test_strict_missing(self):
        def total(*args):
            return sum(args)

        # giving ``missing`` in strict mode is expected to warn; the KeyError
        # below shows it is not used to fill absent keys
        with pytest.warns(RuntimeWarning):
            total = func_treelize(missing=0)(total)

        first = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
        second = TreeValue({'a': 11, 'b': 22, 'x': {'c': 33, 'd': 44}})
        mismatched = TreeValue({'a': 11, 'b': 22, 'x': {'c': 33, 'dd': 44}})

        assert total(1, 2, 3) == 6
        assert total(first, second) == TreeValue({'a': 12, 'b': 24, 'x': {'c': 36, 'd': 48}})
        with pytest.raises(KeyError):
            _ = total(first, mismatched)
54 |
--------------------------------------------------------------------------------
/bmtrans.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os.path
3 | import re
4 |
5 | import click
6 | import pandas as pd
7 |
# pytest-benchmark names look like ``test_func`` or ``test_func[param-id]``;
# the named group ``func`` captures the bare function name.
TEST_FUNC_PATTERN = re.compile(r'^(?P<func>[a-zA-Z0-9_]+)(\[[^\]]*\])?$')


def get_testfunc_name(name: str) -> str:
    """
    Strip the pytest parametrization suffix from a benchmark name.

    :param name: Benchmark name, e.g. ``test_foo`` or ``test_foo[2-3]``.
    :return: The bare test function name, e.g. ``test_foo``.
    :raises ValueError: If ``name`` does not look like a (parametrized) test name.
    """
    matching = TEST_FUNC_PATTERN.fullmatch(name)
    if matching is None:
        raise ValueError(f'Invalid test function name - {name!r}.')
    return matching.group('func')


def open_benchmark_file(filename: str) -> pd.DataFrame:
    """
    Load a pytest-benchmark JSON report into a flat table.

    The table has one row per benchmark. Its columns are ``group_name``, ``func_name``,
    one ``param:<name>`` column per parameter seen in any benchmark, then one
    column per statistic seen in any benchmark. Entries a benchmark does not
    have are filled with ``None`` (pandas may display them as NaN).

    :param filename: Path of the JSON file created by pytest-benchmark.
    :return: The benchmark table.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        json_ = json.load(f)

    bms = json_['benchmarks']
    params, stats = set(), set()
    for item in bms:
        if item.get('params', None):
            params |= set(item['params'].keys())
        if item.get('stats', None):
            stats |= set(item['stats'].keys())

    # Sort so the column order is stable across runs; iterating a set of str
    # depends on hash randomisation.
    params, stats = sorted(params), sorted(stats)
    columns = [
        'group_name',
        'func_name',
        *(f'param:{p}' for p in params),
        *stats,
    ]
    data = []
    for item in bms:
        item_params = dict(item.get('params', None) or {})
        item_stats = dict(item.get('stats', None) or {})
        data.append((
            item.get('group', None),
            get_testfunc_name(item['name']),
            *(item_params.get(p, None) for p in params),
            *(item_stats.get(s, None) for s in stats),
        ))

    return pd.DataFrame(data=data, columns=columns)
47 |
48 |
@click.command(context_settings=dict(help_option_names=['-h', '--help']),
               help='Transform json format file created by pytest-benchmark to simple csv format.')
@click.option('-i', '--input', 'input_filename', type=click.Path(exists=True, dir_okay=False, readable=True),
              required=True, help='Input json file.')
@click.option('-o', '--output', 'output_filename', type=click.Path(dir_okay=False),
              required=True, help='Output csv file.')
def trans(input_filename: str, output_filename: str):
    """
    Convert a pytest-benchmark JSON report into a CSV file.

    Both options are required: without them the command used to crash with a
    TypeError (``open(None)`` / ``os.path.split(None)``). Now click reports a
    proper usage error instead. The parent directory of the output file is
    created when needed.

    :param input_filename: Existing JSON report.
    :param output_filename: Destination CSV path.
    """
    df = open_benchmark_file(input_filename)
    output_dir = os.path.dirname(output_filename)
    if output_dir:  # empty when writing into the current directory
        os.makedirs(output_dir, exist_ok=True)
    df.to_csv(output_filename)


if __name__ == '__main__':
    trans()
65 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
# All of the variables below use "?=", so each one can be overridden from
# the command line or from the environment.
SPHINXOPTS ?=
SPHINXBUILD ?= $(shell which sphinx-build)
SPHINXMULTIVERSION ?= $(shell which sphinx-multiversion)
SOURCEDIR ?= source
BUILDDIR ?= build

# Helper makefiles for generated documentation content. Each one is run as a
# sub-make with SOURCE=${SOURCEDIR} and a "build" or "clean" goal (see the
# "contents" and "clean" targets below).
DIAGRAMS_MK := ${SOURCEDIR}/diagrams.mk
DIAGRAMS := $(MAKE) -f "${DIAGRAMS_MK}" SOURCE=${SOURCEDIR}
GRAPHVIZ_MK := ${SOURCEDIR}/graphviz.mk
GRAPHVIZ := $(MAKE) -f "${GRAPHVIZ_MK}" SOURCE=${SOURCEDIR}
DEMOS_MK := ${SOURCEDIR}/demos.mk
DEMOS := $(MAKE) -f "${DEMOS_MK}" SOURCE=${SOURCEDIR}
NOTEBOOK_MK := ${SOURCEDIR}/notebook.mk
NOTEBOOK := $(MAKE) -f "${NOTEBOOK_MK}" SOURCE=${SOURCEDIR}

# Snapshot PATH before it is redefined below. This lets the new value prepend
# the shims directory without referring to itself.
_CURRENT_PATH := ${PATH}
_PROJ_DIR := $(shell readlink -f ${CURDIR}/..)
_LIBS_DIR := $(shell readlink -f ${SOURCEDIR}/_libs)
_SHIMS_DIR := $(shell readlink -f ${SOURCEDIR}/_shims)

# Export every variable to the environment of all recipes and sub-makes.
.EXPORT_ALL_VARIABLES:

# Make the project root and the docs' _libs directory importable, and put the
# shims directory first on PATH.
PYTHONPATH = ${_PROJ_DIR}:${_LIBS_DIR}
PATH = ${_SHIMS_DIR}:${_CURRENT_PATH}
# NOTE(review): presumably read by the Sphinx conf.py to skip regenerating
# contents, since the "contents" target already builds them. The "prod" target
# clears it. Confirm against conf.py.
NO_CONTENTS_BUILD = true

# "help" is the first real target, so a bare "make" prints Sphinx's help.
# $(O) is meant as a shortcut for $(SPHINXOPTS).
# NOTE(review): "Makefile" in .PHONY is left over from the Sphinx template's
# catch-all rule, which this file does not define.
.PHONY: help contents build html prod clean sourcedir builddir Makefile

help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

# Run the "build" goal of every helper makefile (diagrams, graphviz, demos, notebook).
contents:
	@$(DIAGRAMS) build
	@$(GRAPHVIZ) build
	@$(DEMOS) build
	@$(NOTEBOOK) build
build: html
# Build the HTML docs for the current checkout after generating the contents.
# The empty .nojekyll marker is presumably for GitHub Pages, so that it serves
# the output without Jekyll processing.
html: contents
	@$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
	@touch "$(BUILDDIR)/html/.nojekyll"
# Build every version with sphinx-multiversion into $(BUILDDIR)/html, then
# install main_page.html as the landing page. Contents are not prebuilt here;
# NO_CONTENTS_BUILD is cleared for this command instead.
prod:
	@NO_CONTENTS_BUILD='' $(SPHINXMULTIVERSION) "$(SOURCEDIR)" "$(BUILDDIR)/html" $(SPHINXOPTS) $(O)
	@cp main_page.html "$(BUILDDIR)/html/index.html"
	@touch "$(BUILDDIR)/html/.nojekyll"

# Clean every helper makefile's output, then Sphinx's build directory.
clean:
	@$(DIAGRAMS) clean
	@$(GRAPHVIZ) clean
	@$(DEMOS) clean
	@$(NOTEBOOK) clean
	@$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

# Print the absolute path of the source directory / the HTML output directory.
sourcedir:
	@echo $(shell readlink -f ${SOURCEDIR})
builddir:
	@echo $(shell readlink -f ${BUILDDIR}/html)
--------------------------------------------------------------------------------
/docs/source/comparison/environment.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": false,
7 | "pycharm": {
8 | "name": "#%% md\n"
9 | }
10 | },
11 | "source": [
12 | "# Run Environment Information"
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {
18 | "collapsed": false,
19 | "pycharm": {
20 | "name": "#%% md\n"
21 | }
22 | },
23 | "source": [
24 | "Here is the information from the running environment."
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {
31 | "pycharm": {
32 | "name": "#%%\n"
33 | }
34 | },
35 | "outputs": [],
36 | "source": [
37 | "import os\n",
38 | "import platform\n",
39 | "import shutil\n",
40 | "\n",
41 | "import cpuinfo\n",
42 | "import psutil\n",
43 | "from hbutils.scale import size_to_bytes_str\n",
44 | "\n",
45 | "print('OS:', platform.platform())\n",
46 | "print('Python:', platform.python_implementation(), platform.python_version())\n",
47 | "print('CPU Brand:', cpuinfo.get_cpu_info()[\"brand_raw\"])\n",
48 | "print('CPU Count:', os.cpu_count())\n",
49 | "print('CPU Freq:', psutil.cpu_freq().current, 'MHz')\n",
50 | "print('Memory Size:', size_to_bytes_str(psutil.virtual_memory().total, precision=3))\n",
51 | "print('Has CUDA:', 'Yes' if shutil.which('nvidia-smi') else 'No')"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {
57 | "collapsed": false,
58 | "pycharm": {
59 | "name": "#%% md\n"
60 | }
61 | },
62 | "source": [
63 | "Please note that the information in the deployed documentation is produced by running this notebook automatically on GitHub Actions, so some of the performance data may differ. The results in `README.md` come from our local tests."
64 | ]
65 | }
66 | ],
67 | "metadata": {
68 | "kernelspec": {
69 | "display_name": "Python 3",
70 | "language": "python",
71 | "name": "python3"
72 | },
73 | "language_info": {
74 | "codemirror_mode": {
75 | "name": "ipython",
76 | "version": 2
77 | },
78 | "file_extension": ".py",
79 | "mimetype": "text/x-python",
80 | "name": "python",
81 | "nbconvert_exporter": "python",
82 | "pygments_lexer": "ipython2",
83 | "version": "2.7.6"
84 | }
85 | },
86 | "nbformat": 4,
87 | "nbformat_minor": 0
88 | }
89 |
--------------------------------------------------------------------------------
/benchmark/facebook/test_nest.py:
--------------------------------------------------------------------------------
1 | from ..base import CMP_N
2 |
3 | try:
4 | import nest
5 | except ImportError:
6 | nest = None
7 | import pytest
8 |
9 | from treevalue import FastTreeValue, flatten, mapping, func_treelize, unflatten, union
10 |
_TREE_DATA_1 = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}
_TREE_1 = FastTreeValue(_TREE_DATA_1)

# Benchmark only when facebook's ``nest`` package could be imported;
# otherwise tag the class with the ``ignore`` marker instead.
_UMARK = pytest.mark.benchmark(group='facebook-nest') if nest is not None else pytest.mark.ignore


@_UMARK
class TestCompareFacebookNest:
    """
    Side-by-side benchmarks of ``nest`` (``test_nest_*``) and treevalue (``test_tv_*``).

    Each nest/treevalue pair must do the same per-leaf work so their timings
    are comparable.
    """
    N = CMP_N

    def __create_nested_tree_data(self, n):
        """Plain dict holding ``n`` references to ``_TREE_DATA_1`` under keys ``no_0001``, ``no_0002``, ..."""
        return {
            ('no_%04d' % (i + 1,)): _TREE_DATA_1 for i in range(n)
        }

    def __create_nested_tree(self, n):
        """The same data as ``__create_nested_tree_data`` wrapped in a ``FastTreeValue``."""
        return FastTreeValue(self.__create_nested_tree_data(n))

    @pytest.mark.parametrize('n', [2 ** i for i in range(N)])
    def test_nest_flatten(self, benchmark, n):
        benchmark(nest.flatten, self.__create_nested_tree_data(n))

    @pytest.mark.parametrize('n', [2 ** i for i in range(N)])
    def test_tv_flatten(self, benchmark, n):
        benchmark(flatten, self.__create_nested_tree(n))

    def test_nest_pack_as(self, benchmark):
        benchmark(nest.pack_as, _TREE_DATA_1, nest.flatten(_TREE_DATA_1))

    def test_tv_unflatten(self, benchmark):
        benchmark(unflatten, flatten(_TREE_1))

    @pytest.mark.parametrize('n', [2 ** i for i in range(N)])
    def test_nest_map(self, benchmark, n):
        benchmark(nest.map, lambda x: x ** 2, self.__create_nested_tree_data(n))

    @pytest.mark.parametrize('n', [2 ** i for i in range(N)])
    def test_tv_map(self, benchmark, n):
        benchmark(mapping, self.__create_nested_tree(n), lambda x: x ** 2)

    def test_nest_map_many2(self, benchmark):
        def f(a, b):
            return a ** b + a * b

        benchmark(nest.map_many2, f, _TREE_DATA_1, _TREE_DATA_1)

    def test_nest_map_many(self, benchmark):
        # ``a`` is the tuple of corresponding leaves from both structures
        def f(a):
            return a[0] ** a[1] + a[0] * a[1]

        benchmark(nest.map_many, f, _TREE_DATA_1, _TREE_DATA_1)

    def test_tv_treelize_call(self, benchmark):
        @func_treelize()
        def f(a, b):
            return a ** b + a * b

        benchmark(f, _TREE_1, _TREE_1)

    def test_tv_mapping_union(self, benchmark):
        # Must match test_nest_map_many's per-leaf work; the previous version
        # dropped the ``+ a[0] * a[1]`` term, which skewed the comparison.
        def f(a):
            return a[0] ** a[1] + a[0] * a[1]

        def _my_func(fx, *v):
            return mapping(union(*v), fx)

        benchmark(_my_func, f, _TREE_1, _TREE_1)
78 |
--------------------------------------------------------------------------------
/treevalue/entry/cli/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from contextlib import contextmanager
3 | from functools import wraps
4 | from typing import Callable, Union, Tuple
5 |
6 | import click
7 | from hbutils.reflection import dynamic_call, str_traceback
8 |
9 |
def validator(func):
    """
    Adapt ``func`` into a callback with the ``(ctx, param, value)`` signature click uses.

    ``func`` is wrapped with ``dynamic_call`` and then invoked with the keyword
    arguments ``ctx``, ``param`` and ``value``. Presumably ``dynamic_call``
    lets ``func`` declare only the ones it needs; see hbutils.
    """
    adaptive = dynamic_call(func)

    @wraps(adaptive)
    def _callback(ctx, param, value):
        return adaptive(ctx=ctx, param=param, value=value)

    return _callback
18 |
19 |
def multiple_validator(func):
    """
    Like ``validator``, but for multi-value parameters.

    The adapted ``func`` is applied to each item of ``value``, and the results
    are returned as a list in the same order.
    """
    single = validator(func)

    @wraps(single)
    def _callback(ctx, param, value):
        results = []
        for item in value:
            results.append(single(ctx, param, item))
        return results

    return _callback
28 |
29 |
# Exception types that CLI callbacks are expected to report as bad parameters.
# NOTE(review): not referenced in this file; presumably passed to err_validator
# by the CLI commands.
_EXPECTED_TREE_ERRORS = (
    ValueError, TypeError, ImportError, AttributeError, ModuleNotFoundError,
    FileNotFoundError, IsADirectoryError, PermissionError, FileExistsError,
)

# Attribute-name marker; not used within this file (presumably set on wrapped callables elsewhere).
_EXCEPTION_WRAPPED = '__exception_wrapped__'


def err_validator(types: Union[type, Tuple[type, ...]]):
    """
    Build a decorator that turns a validator into a click callback which reports
    the given exception types as ``click.BadParameter``.

    :param types: Exception type, or tuple of types, to convert.
    :return: Decorator that accepts a validator function (adapted with ``validator``).

    The message of the resulting ``BadParameter`` is the first string in the
    exception's ``args``. If there is none, the exception's own text is used,
    and if that is empty too, the exception class name. The original exception
    is chained as ``__cause__``. ``click.BadParameter`` raised by the validator
    itself propagates unchanged.
    """

    def _decorator(func):
        func = validator(func)

        @wraps(func)
        def _new_func(ctx, param, value):
            try:
                return func(ctx, param, value)
            except click.BadParameter:
                raise
            except types as err:
                _messages = [item for item in err.args if isinstance(item, str)]
                # The old fallback was ``str(_messages)``, which always produced "[]".
                _final_message = _messages[0] if _messages else (str(err) or type(err).__name__)
                raise click.BadParameter(_final_message) from err

        return _new_func

    return _decorator
56 |
57 |
def _cli_builder(base_cli, *wrappers):
    """
    Apply ``wrappers`` to ``base_cli`` in order and return the result.

    Each wrapper receives the output of the previous one; the first receives
    ``base_cli``. With no wrappers, ``base_cli`` itself is returned. The previous
    version returned ``None`` in that case, and it also silently fell back to
    ``base_cli`` whenever a wrapper returned a falsy value.
    """
    cli = base_cli
    for wrapper in wrappers:
        cli = wrapper(cli)
    return cli
63 |
64 |
@contextmanager
def _click_pending(text: str, ok: Union[Callable, str] = 'OK', error: Union[Callable, str] = 'ERROR'):
    """
    Print ``text`` without a newline, run the ``with`` body, then report the outcome.

    On success, the ``ok`` text is printed in green. On any exception, the
    ``error`` text is printed in red, the traceback is written to stderr and the
    exception is re-raised. A trailing ``.`` and newline are always printed.

    :param text: Message shown before the body runs.
    :param ok: Success text, or a callable producing it (invoked through ``dynamic_call``).
    :param error: Failure text, or a callable producing it; invoked through
        ``dynamic_call`` with the raised exception.
    """

    def _as_callable(value):
        # Non-callables become a zero-argument function returning their str().
        if hasattr(value, '__call__'):
            return value
        fixed = str(value)
        return lambda: fixed

    on_ok = dynamic_call(_as_callable(ok))
    on_error = dynamic_call(_as_callable(error))

    click.echo(text, nl=False)
    try:
        yield
    except BaseException as err:
        click.secho(click.style(on_error(err), fg='red'), nl=False)
        click.secho(str_traceback(err), file=sys.stderr)
        raise err
    else:
        click.secho(click.style(on_ok(), fg='green'), nl=False)
    finally:
        click.echo('.', nl=True)
89 |
--------------------------------------------------------------------------------
/cloc.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# This script counts the lines of code and comments in all source files
# and prints the results to the command line. It uses the command-line tool
# "cloc". You can pass --loc, --comments or --percentage to show only the
# respective value(s).
# Some parts below need to be adapted to your project!

# Get the location of this script.
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"

# Fail early with a clear message instead of printing meaningless numbers.
if ! command -v cloc >/dev/null 2>&1; then
  echo "Error: cloc is not installed or not on PATH." >&2
  exit 1
fi

# Run cloc - this counts code lines, blank lines and comment lines
# for the specified languages. You will need to change this accordingly.
# For C++, you could use "C++,C/C++ Header" for example.
# We are only interested in the summary, therefore the tail -n 1.
SUMMARY="$(cloc "${SCRIPT_DIR}" --include-lang="Python" --md | tail -n 1)"

# The $SUMMARY is one line of a markdown table and looks like this:
# SUM:|101|3123|2238|10783
# i.e. name|files|blank|comment|code.
# We use the following command to split it into an array.
IFS='|' read -r -a TOKENS <<<"$SUMMARY"

# Store the individual tokens for better readability.
NUMBER_OF_FILES=${TOKENS[1]}
COMMENT_LINES=${TOKENS[3]}
LINES_OF_CODE=${TOKENS[4]}

# To make the estimate of commented lines more accurate, we have to
# subtract any copyright header which is included in each file.
# In Fly-Pie (where this script comes from), this header is five lines long.
# All dumb comments like those /////////// or those // ------------
# are also subtracted. As cloc does not count inline comments,
# the overall estimate should be rather conservative.
# Change the lines below according to your project.
# DUMB_COMMENTS="$(grep -r -E '//////|// -----' "${SCRIPT_DIR}" | wc -l)"
# COMMENT_LINES=$(($COMMENT_LINES - 5 * $NUMBER_OF_FILES - $DUMB_COMMENTS))

# Print all results if no arguments are given.
# Note: a literal percent sign in awk's printf must be written as %%.
if [[ $# -eq 0 ]]; then
  awk -v a="$LINES_OF_CODE" \
    'BEGIN {printf "Lines of source code: %6.1fk\n", a/1000}'
  awk -v a="$COMMENT_LINES" \
    'BEGIN {printf "Lines of comments: %6.1fk\n", a/1000}'
  awk -v a="$COMMENT_LINES" -v b="$LINES_OF_CODE" \
    'BEGIN {printf "Comment Percentage: %6.1f%%\n", (b > 0 ? 100*a/b : 0)}'
  exit 0
fi

# Show lines of code if --loc is given.
if [[ $* == *--loc* ]]; then
  awk -v a="$LINES_OF_CODE" \
    'BEGIN {printf "%.1fk\n", a/1000}'
fi

# Show lines of comments if --comments is given.
if [[ $* == *--comments* ]]; then
  awk -v a="$COMMENT_LINES" \
    'BEGIN {printf "%.1fk\n", a/1000}'
fi

# Show percentage of comments if --percentage is given.
if [[ $* == *--percentage* ]]; then
  awk -v a="$COMMENT_LINES" -v b="$LINES_OF_CODE" \
    'BEGIN {printf "%.1f\n", (b > 0 ? 100*a/b : 0)}'
fi
66 |
--------------------------------------------------------------------------------
/test/tree/tree/test_io.py:
--------------------------------------------------------------------------------
1 | import zlib
2 |
3 | import pytest
4 |
5 | from treevalue.tree import loads, dumps, FastTreeValue
6 |
7 |
class MyTreeValue(FastTreeValue):
    """Trivial ``FastTreeValue`` subclass used to check that ``loads`` honours ``type_``."""
10 |
11 |
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestTreeTreeIo:
    """Round-trip tests for ``dumps`` / ``loads`` with the supported compression options."""

    @staticmethod
    def _sample_tree():
        # The long, repetitive list presumably makes compression visibly effective.
        return FastTreeValue({'a': 1, 'b': 2, 'x': {'c': [3] * 1000, 'd': 4}})

    @staticmethod
    def _assert_round_trip(data, expected):
        """Check that ``data`` loads back to ``expected`` as both FastTreeValue and MyTreeValue."""
        loaded = loads(data, type_=FastTreeValue)
        assert isinstance(loaded, FastTreeValue)
        assert loaded == expected

        loaded_sub = loads(data, type_=MyTreeValue)
        assert isinstance(loaded_sub, MyTreeValue)
        # A MyTreeValue does not compare equal to a FastTreeValue here,
        # but converting it back must give the original tree.
        assert loaded_sub != expected
        assert FastTreeValue(loaded_sub) == expected

    @staticmethod
    def _assert_explicit_decompress_warns(data, expected):
        """Passing ``decompress`` explicitly for this data must warn but still load correctly."""
        with pytest.warns(UserWarning):
            loaded = loads(data, decompress=zlib.decompress, type_=FastTreeValue)
        assert isinstance(loaded, FastTreeValue)
        assert loaded == expected

    def test_dumps_and_loads(self):
        t = self._sample_tree()
        dumped = dumps(t)
        assert len(dumped) < 2170
        self._assert_round_trip(dumped, t)
        self._assert_explicit_decompress_warns(dumped, t)

    def test_dumps_and_loads_with_zip(self):
        t = self._sample_tree()
        dumped = dumps(t, compress=zlib)
        # Check the size of the data actually under test, not a second dumps() call.
        assert len(dumped) < 240
        self._assert_round_trip(dumped, t)
        self._assert_explicit_decompress_warns(dumped, t)

    def test_dumps_and_loads_with_zip_tuple(self):
        t = self._sample_tree()
        dumped = dumps(t, compress=(zlib.compress, zlib.decompress))
        assert len(dumped) < 240
        self._assert_round_trip(dumped, t)

    def test_dumps_and_loads_with_compress_only(self):
        t = self._sample_tree()
        dumped = dumps(t, compress=zlib.compress)
        assert len(dumped) < 170

        loaded = loads(dumped, decompress=zlib.decompress, type_=FastTreeValue)
        assert isinstance(loaded, FastTreeValue)
        assert loaded == t

        # With only a compress function given to dumps, loading requires an explicit decompress.
        with pytest.raises(RuntimeError):
            loads(dumped, type_=FastTreeValue)
78 |
--------------------------------------------------------------------------------
/test/tree/tree/test_graph.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | import pytest
5 | from hbutils.testing import cmdv
6 |
7 | from treevalue import FastTreeValue, graphics
8 |
9 |
class MyFastTreeValue(FastTreeValue):
    """Plain ``FastTreeValue`` subclass used to check that ``graphics`` accepts subclasses."""
12 |
13 |
@pytest.mark.unittest
class TestTreeTreeGraph:
    """Size checks for the DOT source produced by ``graphics`` under different ``dup_value`` policies."""

    # pytest-native skip marker instead of unittest.skipUnless in a pytest-style class.
    @pytest.mark.skipif(not cmdv('dot'), reason='Dot installed only')
    def test_graphics(self):
        t = MyFastTreeValue({
            'a': [4, 3, 2, 1],
            'b': np.array([[5, 6], [7, 8]]),
            'x': {
                'c': np.array([[5, 7], [8, 6]]),
                'd': {'a', 'b', 'c'},
                'e': np.array([[1, 2], [3, 4]])
            },
        })
        t1 = MyFastTreeValue({
            'aa': t.a,
            'bb': np.array([[5, 6], [7, 8]]),
            'xx': {
                'cc': t.x.c,
                'dd': t.x.d,
                'ee': np.array([[1, 2], [3, 4]])
            },
        })

        # (dup_value policy, upper bound on len(graph.source))
        cases = [
            ((np.ndarray, list), 5000),
            (False, 5600),
            (id, 4760),
            (lambda x: type(x).__name__, 4000),
            (True, 4760),
        ]
        for dup_value, max_length in cases:
            graph = graphics(
                (t, 't'), (t1, 't1'),
                (MyFastTreeValue({'a': t, 'b': t1, 'c': [1, 2], 'd': t1.xx}), 't2'),
                dup_value=dup_value,
                title="This is a demo of 2 trees with dup value.",
                cfg={'bgcolor': '#ffffffff'},
            )
            assert len(graph.source) <= max_length, (dup_value, len(graph.source))
81 |
--------------------------------------------------------------------------------
/test/tree/general/test_general.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from treevalue.tree import general_tree_value, method_treelize
4 | from .base import get_fasttreevalue_test
5 |
6 |
class TreeNumber(general_tree_value()):
    """Default ``general_tree_value`` class with an extra tree-wise ``append``."""

    @method_treelize()
    def append(self, *args):
        # Applied leaf by leaf through method_treelize; returns 0 + self + each arg.
        return sum((self, *args))
11 |
12 |
class NonDefaultTreeNumber(general_tree_value(
    base={},
    methods={
        '__add__': {'missing': 0, 'mode': 'outer'},
        '__radd__': {'missing': 0, 'mode': 'outer'},
    },
)):
    """Tree number whose ``+`` (both operand orders) runs in outer mode with ``0`` for missing keys."""
18 |
19 |
class BanAndOverrideTreeNumber(general_tree_value(methods={
    # NotImplemented disables the operator (test_ban_add expects TypeError for +).
    '__add__': NotImplemented,
    '__radd__': NotImplemented,
    '__iadd__': NotImplemented,
    # An exception instance or class makes the operator raise it (KeyError in test_ban_add).
    '__mul__': KeyError("lksdjfkl"),
    '__rmul__': KeyError("lsdfjkldks"),
    '__imul__': KeyError("dklfgjsl"),
    '__matmul__': KeyError,
    '__rmatmul__': KeyError,
    '__imatmul__': KeyError,
    # A callable replaces the operator's implementation.
    '__pos__': (lambda sp: sp // 2),
    '__truediv__': (lambda sp: sp - 2),
})):
    """Tree number whose arithmetic operators are banned, made to raise, or overridden."""
34 |
35 |
@pytest.mark.unittest
class TestTreeGeneralGeneral(get_fasttreevalue_test(TreeNumber)):
    """
    Checks for ``general_tree_value`` subclasses, on top of the generic cases
    inherited from ``get_fasttreevalue_test(TreeNumber)``.
    """

    def test_numeric_append(self):
        base = TreeNumber({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
        other = TreeNumber({'a': 11, 'b': 22, 'x': {'c': 33, 'd': 5}})
        flat_x = TreeNumber({'a': 11, 'b': 22, 'x': 7})

        # Trees and plain numbers can be mixed; the leaf ``x: 7`` is spread over a subtree.
        expected_mixed = TreeNumber({'a': 15, 'b': 27, 'x': {'c': 39, 'd': 12}})
        expected_spread = TreeNumber({'a': 12, 'b': 24, 'x': {'c': 10, 'd': 11}})
        assert base.append(other, 3) == expected_mixed
        assert base.append(flat_x) == expected_spread

    def test_default_tree_number(self):
        left = TreeNumber({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4, 'e': 7}})
        right = TreeNumber({'a': 11, 'b': 22, 'c': 4, 'x': {'c': 33, 'd': 5}})

        # The default operators are strict, so mismatched keys raise.
        for operation in (lambda: left + right, lambda: left - right):
            with pytest.raises(KeyError):
                _ = operation()

    def test_non_default_tree_number(self):
        left = NonDefaultTreeNumber({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4, 'e': 7}})
        right = NonDefaultTreeNumber({'a': 11, 'b': 22, 'c': 4, 'x': {'c': 33, 'd': 5}})

        expected_sum = NonDefaultTreeNumber({'a': 12, 'b': 24, 'c': 4, 'x': {'c': 36, 'd': 9, 'e': 7}})
        assert (left + right) == expected_sum
        # Only + was configured; - keeps the strict default.
        with pytest.raises(KeyError):
            _ = left - right

    def test_ban_add(self):
        left = BanAndOverrideTreeNumber({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4, 'e': 7}})
        right = BanAndOverrideTreeNumber({'a': 11, 'b': 22, 'c': 4, 'x': {'c': 33, 'd': 5}})

        with pytest.raises(TypeError):
            _ = left + right
        with pytest.raises(KeyError):
            _ = left * right
        with pytest.raises(KeyError):
            _ = left @ right

        halved = BanAndOverrideTreeNumber({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4, 'e': 7}}) // 2
        assert +left == halved
        shifted = BanAndOverrideTreeNumber({'a': 1 - 2, 'b': 2 - 2, 'x': {'c': 3 - 2, 'd': 4 - 2, 'e': 7 - 2}})
        assert left / 3 == shifted
75 |
--------------------------------------------------------------------------------
/test/tree/func/test_outer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from treevalue.tree import func_treelize, TreeValue
4 |
5 |
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestTreeFuncOuter:
    """Tests for ``func_treelize`` in ``'outer'`` mode, where the result has the union of all keys."""

    def test_outer_inherit(self):
        @func_treelize('outer', missing=lambda: 0)
        def total(*args):
            return sum(args)

        left = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
        right = TreeValue({'a': 11, 'b': 22, 'x': {'c': 33, 'd': 44, 'p': 76, 'e': 45}, 'f': 344})
        extra = TreeValue({'a': 11, 'b': 22, 'c': 33, 'x': {'c': 33, 'd': 44, 'e': 550, 'v': -100}})

        assert total(1, 2, 3) == 6
        assert total(left, right) == TreeValue({'a': 12, 'b': 24, 'f': 344, 'x': {'c': 36, 'd': 48, 'p': 76, 'e': 45}})
        assert total(left.x, right.x) == TreeValue({'c': 36, 'd': 48, 'p': 76, 'e': 45})
        # Keys missing from a tree are filled by calling ``missing``, i.e. 0.
        assert total(left, right, extra) == TreeValue({
            'a': 23, 'b': 46, 'c': 33, 'f': 344,
            'x': {'c': 69, 'd': 92, 'p': 76, 'e': 595, 'v': -100},
        })

    def test_outer_inherit_without_missing(self):
        # Using outer mode without ``missing`` is expected to warn.
        with pytest.warns(RuntimeWarning):
            @func_treelize('outer')
            def total(*args):
                return sum(args)

        left = TreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
        right = TreeValue({'a': 11, 'b': 22, 'x': {'c': 33, 'd': 44}})
        wider = TreeValue({'a': 11, 'b': 22, 'x': {'c': 33, 'd': 44, 'p': 76, 'e': 45}, 'f': 344})

        assert total(1, 2, 3) == 6
        assert total(left, right) == TreeValue({'a': 12, 'b': 24, 'x': {'c': 36, 'd': 48}})
        assert total(left.x, right.x) == TreeValue({'c': 36, 'd': 48})
        # Without a ``missing`` value, keys absent from one tree cannot be filled.
        with pytest.raises(KeyError):
            _ = total(left, wider)

    def test_delayed_treelize(self):
        left = TreeValue({
            'a': 1, 'x': {'c': 3, 'd': 4},
        })
        right = TreeValue({
            'a': 11, 'b': 23, 'x': {'c': 35, },
        })
        calls = 0

        @func_treelize(delayed=True, mode='outer', missing=0)
        def total(a, b):
            nonlocal calls
            calls += 1
            return a + b

        # The same laziness guarantees must hold for positional and keyword calls.
        invocations = (
            lambda: total(left, right),
            lambda: total(a=left, b=right),
        )
        for invoke in invocations:
            calls = 0
            result = invoke()
            assert calls == 0  # nothing is computed at call time

            assert result.a == 12
            assert calls == 1
            assert result.x == TreeValue({'c': 38, 'd': 4})
            assert calls == 3
            assert result == TreeValue({
                'a': 12, 'b': 23, 'x': {'c': 38, 'd': 4}
            })
            assert calls == 4
84 |
--------------------------------------------------------------------------------
/treevalue/tree/tree/tree.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | from libcpp cimport bool
5 |
6 | from .constraint cimport Constraint
7 | from ..common.delay cimport DelayedProxy
8 | from ..common.storage cimport TreeStorage
9 |
# Empty base class; the keys/values/items view classes declared further down derive from it.
cdef class _CObject:
    pass

# Holder exposing a single Constraint as the read-only attribute ``cons``.
cdef class _SimplifiedConstraintProxy:
    cdef readonly Constraint cons

# Convert an arbitrary object ``cons`` into a Constraint (accepted inputs: see the .pyx).
cdef Constraint _c_get_constraint(object cons)
# Register an extra mapping type ``type_`` together with ``f_items``, a callable
# that presumably yields its (key, value) pairs -- TODO confirm against tree.pyx.
cpdef register_dict_type(object type_, object f_items)
18 |
# Exception type carrying details about a failed validation -- presumably raised
# by TreeValue.validate(); verify in tree.pyx. All fields are read-only from Python.
cdef class ValidationError(Exception):
    cdef readonly TreeValue _object     # the tree that failed validation
    cdef readonly Exception _error      # the underlying error
    cdef readonly tuple _path           # presumably the key path to the failing node
    cdef readonly Constraint _cons      # the constraint that was checked
24 |
# Core tree container. Its C-level attributes are readable (not writable) from Python.
cdef class TreeValue:
    cdef readonly TreeStorage _st           # underlying storage of the tree
    cdef readonly Constraint constraint     # constraint attached to this tree
    cdef readonly type _type                # presumably the class used for child trees -- TODO confirm
    cdef readonly dict _child_constraints   # presumably per-key child constraints; verify in tree.pyx

    # Returns a TreeStorage (by its name, the storage detached from this tree).
    cpdef TreeStorage _detach(self)
    # Conversions between stored values and user-facing values; ``_unraw``
    # also receives the key. See tree.pyx for the exact rules.
    cdef object _unraw(self, object obj, str key)
    cdef object _raw(self, object obj)
    # Presumably fallback hooks for attribute / item access, assignment and
    # deletion. Being cpdef, they can be overridden by Python subclasses.
    cpdef _attr_extern(self, str key)
    cpdef _getitem_extern(self, object key)
    cpdef _setitem_extern(self, object key, object value)
    cpdef _delitem_extern(self, object key)
    # Presumably a dict.update-style merge of ``d`` and ``kwargs``; ``except*``
    # lets Python exceptions propagate out of this void function.
    cdef void _update(self, object d, dict kwargs) except*
    # dict-like API (``default= *`` means the default is defined in the .pyx).
    cpdef public get(self, str key, object default= *)
    cpdef public pop(self, str key, object default= *)
    cpdef public popitem(self)
    cpdef public void clear(self)
    cpdef public object setdefault(self, str key, object default= *)

    # View objects, declared at the end of this file.
    cpdef public treevalue_keys keys(self)
    cpdef public treevalue_values values(self)
    cpdef public treevalue_items items(self)

    # Presumably fetches several keys at once, with an optional default.
    cdef tuple _unpack(self, tuple keys, object default=*)

    # Check the tree against its constraint; ``except*`` propagates errors from a void function.
    cpdef void validate(self) except*

    # Presumably builds the structure used to render the tree -- see tree.pyx.
    cdef object _get_tree_graph(self)
54 |
# Module-level helpers, presumably for rendering a tree as text (repr) --
# TODO confirm against tree.pyx.
cdef str _prefix_fix(object text, object prefix)
cdef str _title_repr(TreeStorage st, object type_)
# ``id_pool`` and ``path`` are presumably used to detect nodes shared between branches.
cdef object _build_tree(TreeStorage st, object type_, str prefix, dict id_pool, tuple path)
58 |
# View class returned by TreeValue.keys(); holds the storage and tree type.
# noinspection PyPep8Naming
cdef class treevalue_keys(_CObject):
    cdef readonly TreeStorage _st
    cdef readonly type _type

# View class returned by TreeValue.values(); it also carries constraint data.
# noinspection PyPep8Naming
cdef class treevalue_values(_CObject):
    cdef readonly TreeStorage _st
    cdef readonly type _type
    cdef readonly Constraint _constraint
    cdef readonly dict _child_constraints

    # Presumably returns the (child) constraint to use for ``key`` -- verify in tree.pyx.
    cdef _SimplifiedConstraintProxy _transact(self, str key)

# View class returned by TreeValue.items(); same fields as treevalue_values.
# noinspection PyPep8Naming
cdef class treevalue_items(_CObject):
    cdef readonly TreeStorage _st
    cdef readonly type _type
    cdef readonly Constraint _constraint
    cdef readonly dict _child_constraints

    # Presumably returns the (child) constraint to use for ``key`` -- verify in tree.pyx.
    cdef _SimplifiedConstraintProxy _transact(self, str key)
81 |
# DelayedProxy wrapping another DelayedProxy. By its fields, it presumably
# evaluates ``proxy`` once, caches the result in ``val`` and records that in
# ``calculated`` -- TODO confirm in tree.pyx.
cdef class DetachedDelayedProxy(DelayedProxy):
    cdef DelayedProxy proxy
    cdef readonly bool calculated   # C++ bool (libcpp), readable from Python
    cdef object val

    cpdef object value(self)
    cpdef object fvalue(self)
89 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from codecs import open
4 |
5 | from setuptools import find_packages, setup
6 | from Cython.Build import cythonize # this line should be after 'from setuptools import find_packages'
7 |
_package_name = "treevalue"

# Directory containing this setup.py.
here = os.path.dirname(os.path.abspath(__file__))

# Execute <package>/config/meta.py and collect its globals (__TITLE__, __VERSION__, ...).
meta = {}
_meta_file = os.path.join(here, _package_name, 'config', 'meta.py')
with open(_meta_file, 'r', 'utf-8') as _meta_fp:
    exec(_meta_fp.read(), meta)
14 |
15 |
def _load_req(file: str):
    """
    Read a pip requirements file and return its requirement specifiers.

    Blank lines and ``#`` comments (whole-line, or trailing after whitespace,
    as pip defines them) are dropped, so the result can be passed directly to
    ``install_requires`` / ``extras_require``. A ``#`` that is not preceded by
    whitespace (e.g. ``...#egg=name`` in a URL) is kept.

    :param file: Path of the UTF-8 encoded requirements file.
    :return: List of stripped, non-empty requirement lines, in file order.
    """
    reqs = []
    with open(file, 'r', encoding='utf-8') as f:
        for line in f:
            # Same comment rule as pip's requirements parser.
            line = re.sub(r'(^|\s+)#.*$', '', line).strip()
            if line:
                reqs.append(line)
    return reqs
19 |
20 |
requirements = _load_req(os.path.join(here, 'requirements.txt'))

# Every requirements-<group>.txt next to setup.py becomes an extra named
# <group> (e.g. requirements-test.txt -> the "test" extra). Paths are resolved
# against ``here`` so the script also works when invoked from another working
# directory; sorting keeps the extras order deterministic.
_REQ_PATTERN = re.compile(r'^requirements-([a-zA-Z0-9_]+)\.txt$')
group_requirements = {
    match.group(1): _load_req(os.path.join(here, match.group(0)))
    for match in map(_REQ_PATTERN.fullmatch, sorted(os.listdir(here)))
    if match
}
28 |
# Long description for the package index; resolved against ``here`` (like
# meta.py) instead of the current working directory.
with open(os.path.join(here, 'README.md'), 'r', 'utf-8') as f:
    readme = f.read()
31 |
32 |
def find_pyx(path=None):
    """
    Collect every Cython source file (``*.pyx``) below a directory.

    :param path: Directory to scan recursively. Any falsy value selects the
        package directory ``<here>/treevalue``.
    :return: List of joined file paths, in ``os.walk`` traversal order.
    """
    root = path or os.path.join(here, _package_name)
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(root)
        for name in names
        if name.endswith('.pyx')
    ]
41 |
42 |
# Any non-empty LINETRACE environment variable enables Cython's `linetrace`
# compiler directive for the extensions built below.
_LINETRACE = bool(os.environ.get('LINETRACE'))
44 |
setup(
    # information
    name=meta['__TITLE__'],
    version=meta['__VERSION__'],
    packages=find_packages(
        include=(_package_name, f"{_package_name}.*")
    ),
    description=meta['__DESCRIPTION__'],
    long_description=readme,
    long_description_content_type='text/markdown',
    author=meta['__AUTHOR__'],
    author_email=meta['__AUTHOR_EMAIL__'],
    license='Apache License, Version 2.0',
    keywords='Tree-structured Value Management',
    url='https://github.com/HansBug/treevalue',

    # environment
    python_requires=">=3.8",
    # Compile every .pyx found under the package; LINETRACE toggles the
    # `linetrace` directive.
    ext_modules=cythonize(
        find_pyx(),
        language_level=3,
        compiler_directives=dict(
            linetrace=_LINETRACE,
        )
    ),
    install_requires=requirements,
    # .get(): a checkout without requirements-test.txt must still install
    # instead of crashing with KeyError.
    tests_require=group_requirements.get('test', []),
    extras_require=group_requirements,
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
        'Programming Language :: Python :: 3.12',
    ],
    entry_points={
        'console_scripts': [
            'treevalue=treevalue.entry.cli:treevalue_cli'
        ]
    },
)
91 |
--------------------------------------------------------------------------------
/treevalue/tree/tree/constraint.pxd:
--------------------------------------------------------------------------------
1 | # distutils:language=c++
2 | # cython:language_level=3
3 |
4 | import cython
5 | from libcpp cimport bool
6 |
# Internal exception class, declared as an extension type deriving from
# Exception. NOTE(review): the name suggests it wraps an error raised inside a
# constraint so it can be told apart from other exceptions - confirm in the .pyx.
7 | cdef class _WrappedConstraintException(Exception):
8 | pass
9 |
# Base extension type for every constraint declared in this file.
# cpdef methods are also callable (and overridable) from Python; cdef methods
# are C-level only. `bool` is the C++ bool cimported from libcpp at the top.
10 | cdef class Constraint:
# Validation hooks; `except*` lets a Python exception raised inside these void
# methods propagate to the caller.
11 | cpdef void _validate_node(self, object instance) except*
12 | cpdef void _validate_value(self, object instance) except*
13 | cpdef object _features(self)
14 | cpdef bool _contains(self, Constraint other)
15 | cpdef Constraint _transaction(self, str key)
16 |
17 | cdef bool _feature_match(self, Constraint other)
18 | cdef bool _contains_check(self, Constraint other)
19 | cdef tuple _native_validate(self, object instance, type type_, list path)
20 | cpdef tuple check(self, object instance)
21 | cpdef bool equiv(self, object other)
22 |
# C-level helper returning a C++ bool; `except*` makes every call check for a
# pending Python exception.
23 | cdef bool _c_default_accessible(Constraint cons, object type_, tuple items, dict params) except*
24 |
# @cython.final forbids subclassing, which lets Cython call its methods without
# virtual dispatch.
25 | @cython.final
26 | cdef class EmptyConstraint(Constraint):
27 | pass
28 |
# Module-level EmptyConstraint instance. NOTE(review): presumably a shared
# singleton handed out instead of allocating new empty constraints.
29 | cdef EmptyConstraint _EMPTY_CONSTRAINT
30 |
# Conversion from arbitrary objects to Constraint: _r_parse_cons is C-level,
# to_constraint is Python-visible (cpdef). NOTE(review): presumably the latter
# delegates to the former - confirm in the .pyx.
31 | cdef Constraint _r_parse_cons(object obj)
32 | cpdef Constraint to_constraint(object obj)
33 |
# Category subclasses of Constraint; they declare no extra fields here.
34 | cdef class ValueConstraint(Constraint):
35 | pass
36 |
37 | cdef class NodeConstraint(Constraint):
38 | pass
39 |
# Final ValueConstraint exposing a read-only `type_` attribute to Python.
40 | @cython.final
41 | cdef class TypeConstraint(ValueConstraint):
42 | cdef readonly object type_
43 |
# C-level helper: takes any object and returns a str. NOTE(review): presumably
# a qualified function name used when naming func-based constraints.
44 | cdef str _c_func_fullname(object f)
45 |
# ValueConstraint storing a `func` object (presumably a callable) and a `name`
# string, both read-only from Python. The two final subclasses below share that
# layout and add no fields.
46 | cdef class ValueFuncConstraint(ValueConstraint):
47 | cdef readonly object func
48 | cdef readonly str name
49 |
50 | @cython.final
51 | cdef class ValueValidateConstraint(ValueFuncConstraint):
52 | pass
53 |
54 | @cython.final
55 | cdef class ValueCheckConstraint(ValueFuncConstraint):
56 | pass
57 |
# Python-visible factories for the two subclasses; `name= *` means the default
# for `name` is defined in the .pyx.
58 | cpdef ValueValidateConstraint vval(object func, object name= *)
59 | cpdef ValueCheckConstraint vcheck(object func, object name= *)
60 |
# Final Constraint subclass with no extra fields; cleaf() is the
# Python-visible, argument-less factory that returns one.
61 | @cython.final
62 | cdef class LeafConstraint(Constraint):
63 | pass
64 |
65 | cpdef LeafConstraint cleaf()
66 |
# NodeConstraint counterpart of ValueFuncConstraint: stores a `func` object
# (presumably a callable) and a `name` string, both read-only from Python.
# The two final subclasses below add no fields.
67 | cdef class NodeFuncConstraint(NodeConstraint):
68 | cdef readonly object func
69 | cdef readonly str name
70 |
71 | @cython.final
72 | cdef class NodeValidateConstraint(NodeFuncConstraint):
73 | pass
74 |
75 | @cython.final
76 | cdef class NodeCheckConstraint(NodeFuncConstraint):
77 | pass
78 |
# Python-visible factories for the two subclasses; `name= *` means the default
# for `name` is defined in the .pyx.
79 | cpdef NodeValidateConstraint nval(object func, object name= *)
80 | cpdef NodeCheckConstraint ncheck(object func, object name= *)
81 |
# Constraint holding a read-only dict of sub-constraints. NOTE(review):
# presumably keyed by child key - confirm in the .pyx.
82 | cdef class TreeConstraint(Constraint):
83 | cdef readonly dict _constraints
84 |
# C-level helpers that each return a Constraint: one takes a list of
# constraints, the other a single TreeConstraint. NOTE(review): names suggest
# merge/simplification steps (the `_s_` prefix is shared with _s_simplify).
85 | cdef Constraint _s_tree_merge(list constraints)
86 | cdef Constraint _s_tree(TreeConstraint constraint)
87 |
# Constraint holding a read-only tuple of member constraints.
88 | cdef class CompositeConstraint(Constraint):
89 | cdef readonly tuple _constraints
90 |
# _rec_composite_iter returns nothing and receives an output list (presumably
# appended to recursively); _r_composite_iter returns a list. NOTE(review):
# confirm the latter is the entry point that allocates the list.
91 | cdef void _rec_composite_iter(Constraint constraint, list lst)
92 | cdef list _r_composite_iter(Constraint constraint)
93 |
# Further C-level helpers returning a Constraint. NOTE(review): by name a
# generic list merge, composite handling, and the top-level simplifier -
# confirm in the .pyx.
94 | cdef Constraint _s_generic_merge(list constraints)
95 | cdef Constraint _s_composite(CompositeConstraint constraint)
96 |
97 | cdef Constraint _s_simplify(Constraint constraint)
98 |
# Python-visible (cpdef): takes any object plus a str key and returns a
# Constraint. NOTE(review): presumably the module-level counterpart of
# Constraint._transaction.
99 | cpdef Constraint transact(object cons, str key)
100 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Every target in this Makefile is a command, not a file it produces. Declaring
# them phony stops make from skipping one when a file or directory of the same
# name exists (e.g. ./docs).
.PHONY: docs pdocs test unittest benchmark compare bmtrans build clean package zip run
2 |
# Optional interpreter flags: NO_DOCSTRING=1 runs python with -OO,
# otherwise NO_DEBUG=1 runs it with -O.
NO_DEBUG ?=
NO_DOCSTRING ?=
NO_DEBUG_CMD := $(if ${NO_DOCSTRING},-OO,$(if ${NO_DEBUG},-O,))
PYTHON := $(shell which python) ${NO_DEBUG_CMD}

DOC_DIR := ./docs
DIST_DIR := ./dist
WHEELHOUSE_DIR := ./wheelhouse
TEST_DIR := ./test
BENCHMARK_DIR := ./benchmark
SRC_DIR := ./treevalue
RUNS_DIR := ./runs

# RANGE_DIR narrows test, benchmark and coverage runs to one sub-directory,
# e.g. `make unittest RANGE_DIR=tree/tree`.
RANGE_DIR ?= .
RANGE_TEST_DIR := ${TEST_DIR}/${RANGE_DIR}
RANGE_BENCH_DIR := ${BENCHMARK_DIR}/${RANGE_DIR}
RANGE_SRC_DIR := ${SRC_DIR}/${RANGE_DIR}

CYTHON_FILES := $(shell find ${SRC_DIR} -name '*.pyx')

COV_TYPES ?= xml term-missing
COMPILE_PLATFORM ?= manylinux_2_24_x86_64

BENCHMARK_FILE ?=
BENCHMARK_OUTPUT_DIR ?= .benchmarks
# The output directory does not exist before the first benchmark run; silence
# find's "No such file or directory" so unrelated make targets stay quiet.
BM_FILES := $(shell find ${BENCHMARK_OUTPUT_DIR} -name '*.json' -type f 2> /dev/null)
BM_CSV_FILES := $(addsuffix .csv,$(basename ${BM_FILES}))
30 |
# Compile the Cython extensions in place. Setting LINETRACE also passes
# `--define CYTHON_TRACE`; setup.py reads the same LINETRACE variable to turn
# on the Cython `linetrace` directive.
31 | build:
32 | $(PYTHON) setup.py build_ext --inplace \
33 | $(if ${LINETRACE},--define CYTHON_TRACE,)
34 |
# Build only the source distribution into DIST_DIR.
35 | zip:
36 | $(PYTHON) -m build --sdist --outdir ${DIST_DIR}
37 |
# Build sdist + wheel, then repair each wheel with auditwheel for
# COMPILE_PLATFORM, copy the repaired wheel into DIST_DIR and drop the original.
# `|| exit 1` aborts on the first failing wheel; without it, a failure on any
# wheel but the last was masked by the loop's exit status.
package:
	$(PYTHON) -m build --sdist --wheel --outdir ${DIST_DIR}
	for whl in `ls ${DIST_DIR}/*.whl`; do \
		auditwheel repair $$whl -w ${WHEELHOUSE_DIR} --plat ${COMPILE_PLATFORM} && \
		cp `ls ${WHEELHOUSE_DIR}/*.whl` ${DIST_DIR} && \
		rm -rf $$whl ${WHEELHOUSE_DIR}/* \
		|| exit 1; \
	done
45 |
# Remove compiled extension modules (*.so under SRC_DIR) and the .c/.cpp/.h
# files generated next to each .pyx, then the dist and wheelhouse directories.
# `ls ... 2> /dev/null` keeps only the generated files that actually exist.
46 | clean:
47 | rm -rf $(shell find ${SRC_DIR} -name '*.so') \
48 | $(shell ls $(addsuffix .c, $(basename ${CYTHON_FILES})) \
49 | $(addsuffix .cpp, $(basename ${CYTHON_FILES})) \
50 | $(addsuffix .h, $(basename ${CYTHON_FILES})) \
51 | 2> /dev/null)
52 | rm -rf ${DIST_DIR} ${WHEELHOUSE_DIR}
53 |
# `test` runs the unit tests, then the benchmarks.
54 | test: unittest benchmark
55 |
# Run pytest tests marked `unittest` under RANGE_TEST_DIR with coverage of
# RANGE_SRC_DIR, emitting one --cov-report per entry of COV_TYPES.
# MIN_COVERAGE adds --cov-fail-under; WORKERS adds `-n` (pytest-xdist).
56 | unittest:
57 | $(PYTHON) -m pytest "${RANGE_TEST_DIR}" \
58 | -sv -m unittest \
59 | $(shell for type in ${COV_TYPES}; do echo "--cov-report=$$type"; done) \
60 | --cov="${RANGE_SRC_DIR}" \
61 | $(if ${MIN_COVERAGE},--cov-fail-under=${MIN_COVERAGE},) \
62 | $(if ${WORKERS},-n ${WORKERS},)
63 |
# pytest-benchmark flags shared by `benchmark` and `compare`. Defined with `=`
# (recursive) so WORKERS and BENCHMARK_FILE are read when a recipe expands it.
BENCHMARK_PYTEST_FLAGS = -sv -m benchmark \
    --benchmark-columns=min,max,mean,median,IQR,ops,rounds,iterations \
    --benchmark-disable-gc \
    --benchmark-sort=mean \
    $(if ${WORKERS},-n ${WORKERS},) \
    --benchmark-autosave \
    $(if ${BENCHMARK_FILE},--benchmark-save=${BENCHMARK_FILE},)

# Benchmark-marked tests under the test tree (RANGE_TEST_DIR).
benchmark:
	$(PYTHON) -m pytest "${RANGE_TEST_DIR}" ${BENCHMARK_PYTEST_FLAGS}

# The same benchmark run over the benchmark/ tree (RANGE_BENCH_DIR).
compare:
	$(PYTHON) -m pytest "${RANGE_BENCH_DIR}" ${BENCHMARK_PYTEST_FLAGS}
83 |
# Convert each saved benchmark JSON (BM_FILES, found under BENCHMARK_OUTPUT_DIR)
# into a sibling .csv via bmtrans.py; `make bmtrans` builds all of them.
84 | %.csv: %.json
85 | $(PYTHON) bmtrans.py -i "$(shell readlink -f $<)" -o "$(shell readlink -f $@)"
86 | bmtrans: ${BM_CSV_FILES}
87 |
# Documentation builds delegate to the Makefile in DOC_DIR: its `build` target
# for `docs`, its `prod` target for `pdocs`.
88 | docs:
89 | $(MAKE) -C "${DOC_DIR}" build
90 | pdocs:
91 | $(MAKE) -C "${DOC_DIR}" prod
92 |
# Invoke the `run` target of RUNS_DIR's Makefile with the repository root and
# RUNS_DIR prepended to PYTHONPATH.
93 | run:
94 | PYTHONPATH=$(shell readlink -f .):$(shell readlink -f ${RUNS_DIR}):${PYTHONPATH} $(MAKE) -C "${RUNS_DIR}" run
95 |
--------------------------------------------------------------------------------