├── tests
├── examples
│ ├── __init__.py
│ ├── behaviors
│ │ ├── __init__.py
│ │ ├── binary_sum.py
│ │ ├── loop_with_alternative.py
│ │ ├── report_example.py
│ │ ├── basic_empty.py
│ │ ├── loop_without_alternative.py
│ │ └── basic.py
│ └── feature_conditions.py
├── __init__.py
├── test_metaheuristics.py
├── test_behavior_empty.py
├── test_node.py
├── test_loop.py
├── test_behavior.py
├── test_pet_a_cat.py
├── test_call_graph.py
├── test_paper_basic_example.py
└── test_code_generation.py
├── setup.py
├── commands
└── coverage.ps1
├── requirements.txt
├── .coveragerc
├── docs
└── images
│ ├── PetACatGraph.png
│ └── PetACatGraphUnrolled.png
├── src
└── hebg
│ ├── metrics
│ ├── __init__.py
│ ├── utility
│ │ ├── __init__.py
│ │ └── binary_utility.py
│ ├── complexity
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ └── complexities.py
│ └── histograms.py
│ ├── layouts
│ ├── __init__.py
│ ├── deterministic.py
│ ├── metaheuristics.py
│ └── metabased.py
│ ├── __init__.py
│ ├── behavior.py
│ ├── node.py
│ ├── requirements_graph.py
│ ├── heb_graph.py
│ ├── draw.py
│ ├── graph.py
│ ├── unrolling.py
│ ├── call_graph.py
│ └── codegen.py
├── .github
└── workflows
│ ├── python-tests.yml
│ ├── python-pypi.yml
│ └── python-coverage.yml
├── .pre-commit-config.yaml
├── shell.nix
├── pyproject.toml
├── CONTRIBUTING.rst
├── .gitignore
└── README.rst
/tests/examples/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 |
4 | setup()
5 |
--------------------------------------------------------------------------------
/commands/coverage.ps1:
--------------------------------------------------------------------------------
1 | pytest --cov=src --cov-report=html --cov-report=term
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | networkx >= 2.5.1
2 | matplotlib
3 | numpy
4 | tqdm
5 | scipy
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [report]
2 | show_missing = True
3 | omit =
4 | heb_graph/layouts/*
5 |
--------------------------------------------------------------------------------
/docs/images/PetACatGraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRLL/HEB_graphs/HEAD/docs/images/PetACatGraph.png
--------------------------------------------------------------------------------
/docs/images/PetACatGraphUnrolled.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRLL/HEB_graphs/HEAD/docs/images/PetACatGraphUnrolled.png
--------------------------------------------------------------------------------
/src/hebg/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Module containing HEBGraph metrics."""
5 |
--------------------------------------------------------------------------------
/src/hebg/metrics/utility/__init__.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Module for utility computation methods."""
5 |
6 | from hebg.metrics.utility.binary_utility import binary_graphbased_utility
7 |
8 | __all__ = ["binary_graphbased_utility"]
9 |
--------------------------------------------------------------------------------
/src/hebg/metrics/complexity/__init__.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Module for complexity computation methods."""
5 |
6 | from hebg.metrics.complexity.complexities import general_complexity, learning_complexity
7 |
8 | __all__ = ["general_complexity", "learning_complexity"]
9 |
--------------------------------------------------------------------------------
/src/hebg/layouts/__init__.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Module containing layouts to draw graphs."""
5 |
6 | from hebg.layouts.deterministic import staircase_layout
7 | from hebg.layouts.metabased import leveled_layout_energy
8 |
9 | __all__ = ["staircase_layout", "leveled_layout_energy"]
10 |
--------------------------------------------------------------------------------
/src/hebg/metrics/complexity/utils.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Utility functions for complexity computation."""
5 |
6 | from copy import deepcopy
7 |
8 |
def update_sum_dict(dict1: dict, dict2: dict):
    """Return a new dictionary holding the key-wise sum of two dictionaries.

    Both inputs are deep-copied first, so neither argument is modified.
    Entries of ``dict2`` whose value is itself a dictionary are ignored.
    Keys that only exist in ``dict2`` are copied into the result unchanged.
    """
    summed = deepcopy(dict1)
    for key, increment in deepcopy(dict2).items():
        if isinstance(increment, dict):
            continue
        try:
            summed[key] += increment
        except KeyError:
            summed[key] = increment
    return summed
19 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Tests for the hebg package."""
5 |
6 | from typing import Protocol
7 | from matplotlib import pyplot as plt
8 |
9 |
class Graph(Protocol):
    """Structural type of the graph objects accepted by the test helpers."""

    def draw(self, ax, pos):
        """Render the graph onto the matplotlib axes ``ax``."""
        ...

    def nodes(self) -> list:
        """Give the nodes of the graph as a list."""
        ...
16 |
17 |
def plot_graph(graph: Graph, **kwargs):
    """Show ``graph`` in a new matplotlib figure with the axis hidden.

    Keyword arguments are passed through unchanged to ``graph.draw``.
    """
    figure_and_axes = plt.subplots()
    graph.draw(figure_and_axes[1], **kwargs)
    plt.axis("off")  # hide the axis
    plt.show()
23 |
--------------------------------------------------------------------------------
/src/hebg/__init__.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """A structure for explainable hierarchical reinforcement learning"""
5 |
6 | from hebg.behavior import Behavior
7 | from hebg.heb_graph import HEBGraph
8 | from hebg.node import Action, EmptyNode, FeatureCondition, Node, StochasticAction
9 | from hebg.requirements_graph import build_requirement_graph
10 |
11 | __all__ = [
12 | "HEBGraph",
13 | "Behavior",
14 | "Action",
15 | "EmptyNode",
16 | "FeatureCondition",
17 | "Node",
18 | "StochasticAction",
19 | "build_requirement_graph",
20 | ]
21 |
--------------------------------------------------------------------------------
/.github/workflows/python-tests.yml:
--------------------------------------------------------------------------------
1 | name: Python tests
2 |
3 | on: ["push"]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: windows-latest
9 | strategy:
10 | matrix:
11 | python-version: ['3.10', '3.11', '3.12']
12 | steps:
13 | - uses: actions/checkout@v2
14 | - name: Set up Python ${{ matrix.python-version }}
15 | uses: actions/setup-python@v1
16 | with:
17 | python-version: ${{ matrix.python-version }}
18 | - name: Install dependencies
19 | run: |
20 | git submodule update --init --recursive
21 | python -m pip install --upgrade pip
22 | pip install .[dev]
23 | - name: Test with pytest
24 | run: |
25 | pytest tests
26 |
--------------------------------------------------------------------------------
/.github/workflows/python-pypi.yml:
--------------------------------------------------------------------------------
1 | name: PyPi
2 |
3 | on:
4 | push:
5 | tags:
6 | - 'v*'
7 | release:
8 | types: [published]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v1
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install build twine
25 | - name: Build
26 | run: |
27 | python -m build
28 | - name: Publish
29 | env:
30 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
31 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
32 | run: |
33 | twine upload dist/*
--------------------------------------------------------------------------------
/.github/workflows/python-coverage.yml:
--------------------------------------------------------------------------------
1 | name: Python coverage
2 |
3 | on: ["push"]
4 |
5 | jobs:
6 | build:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v2
10 | - name: Set up Python 3.10
11 | uses: actions/setup-python@v1
12 | with:
13 | python-version: '3.10'
14 | - name: Install dependencies
15 | run: |
16 | git submodule update --init --recursive
17 | python -m pip install --upgrade pip
18 | pip install -e .[dev]
19 | - name: Build coverage using pytest-cov
20 | run: |
21 | pytest --cov=src --cov-report=xml tests
22 | - name: Codacy Coverage Reporter
23 | uses: codacy/codacy-coverage-reporter-action@v1.3.0
24 | with:
25 | project-token: ${{ secrets.CODACY_PROJECT_TOKEN }}
26 | coverage-reports: coverage.xml
27 |
--------------------------------------------------------------------------------
/tests/test_metaheuristics.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Unit tests for the hebg.layouts.metaheuristics module."""
5 |
6 | import numpy as np
7 |
8 | import pytest_check as check
9 |
10 | from hebg.layouts.metaheuristics import simulated_annealing
11 |
12 |
def test_simulated_annealing():
    """Simulated annealing must get close to the minimum of x**2."""
    step_size = 0.05

    def energy(x):
        return x**2

    def neighbor(x):
        direction = np.random.choice([-1, 1])
        return x + direction * step_size

    found_x = simulated_annealing(
        -1,
        energy,
        neighbor,
        max_iterations=1000,
        initial_temperature=5,
        verbose=1,
    )

    tolerance = 3.1 * step_size
    check.less_equal(abs(found_x), tolerance)
29 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v2.3.0
4 | hooks:
5 | - id: check-yaml
6 | - id: end-of-file-fixer
7 | - id: trailing-whitespace
8 | - repo: https://github.com/astral-sh/ruff-pre-commit
9 | rev: v0.1.11
10 | hooks:
11 | - id: ruff
12 | args: [ --fix ]
13 | - id: ruff-format
14 | - repo: local
15 | hooks:
16 | - id: pytest-fast-check
17 | name: pytest-fast-check
18 | entry: pytest -m "not slow"
19 | stages: ["pre-commit"]
20 | language: system
21 | pass_filenames: false
22 | always_run: true
23 | - repo: local
24 | hooks:
25 | - id: pytest-check
26 | name: pytest-check
27 | entry: pytest
28 | stages: ["pre-push"]
29 | language: system
30 | pass_filenames: false
31 | always_run: true
32 |
--------------------------------------------------------------------------------
/tests/examples/behaviors/__init__.py:
--------------------------------------------------------------------------------
1 | from tests.examples.behaviors.basic import (
2 | FundamentalBehavior,
3 | F_A_Behavior,
4 | F_AA_Behavior,
5 | F_F_A_Behavior,
6 | )
7 | from tests.examples.behaviors.basic_empty import (
8 | E_A_Behavior,
9 | E_F_A_Behavior,
10 | F_E_A_Behavior,
11 | E_E_A_Behavior,
12 | )
13 | from tests.examples.behaviors.binary_sum import build_binary_sum_behavior
14 | from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors
15 | from tests.examples.behaviors.loop_without_alternative import (
16 | build_looping_behaviors_without_direct_alternatives,
17 | )
18 |
19 |
20 | __all__ = [
21 | "FundamentalBehavior",
22 | "F_A_Behavior",
23 | "F_AA_Behavior",
24 | "F_F_A_Behavior",
25 | "E_A_Behavior",
26 | "E_F_A_Behavior",
27 | "F_E_A_Behavior",
28 | "E_E_A_Behavior",
29 | "build_binary_sum_behavior",
30 | "build_looping_behaviors",
31 | "build_looping_behaviors_without_direct_alternatives",
32 | ]
33 |
--------------------------------------------------------------------------------
/shell.nix:
--------------------------------------------------------------------------------
1 | with import <nixpkgs> { };
2 |
3 | let
4 | pythonPackages = python3Packages;
5 | in pkgs.mkShell rec {
6 | name = "localDevPythonEnv";
7 | venvDir = "./.venv";
8 | buildInputs = [
9 | # A Python interpreter including the 'venv' module is required to bootstrap
10 | # the environment.
11 | pythonPackages.python
12 |
13 | # This executes some shell code to initialize a venv in $venvDir before
14 | # dropping into the shell
15 | pythonPackages.venvShellHook
16 |
17 | # Those are dependencies that we would like to use from nixpkgs, which will
18 | # add them to PYTHONPATH and thus make them accessible from within the venv.
19 | pythonPackages.numpy
20 | pythonPackages.networkx
21 | pythonPackages.matplotlib
22 | pythonPackages.numpy
23 | pythonPackages.tqdm
24 | pythonPackages.scipy
25 | ];
26 |
27 | # Run this command, only after creating the virtual environment
28 | postVenvCreation = ''
29 | pip install -e '.[dev]'
30 | '';
31 |
32 | # Now we can execute any commands within the virtual environment.
33 | # This is optional and can be left out to run pip manually.
34 | postShellHook = ''
35 | '';
36 |
37 | }
38 |
--------------------------------------------------------------------------------
/tests/examples/behaviors/binary_sum.py:
--------------------------------------------------------------------------------
1 | from hebg import Action
2 | from tests.examples.behaviors import F_A_Behavior
3 | from tests.examples.feature_conditions import IsDivisibleFeatureCondition
4 |
5 |
def build_binary_sum_behavior() -> F_A_Behavior:
    """Build the nested example behavior "Is sum (of last 3 binary) 2 ?".

    The returned behavior branches on divisibility feature conditions and
    delegates to smaller F_A behaviors, which end on actions 0 or 1.
    """
    binary_1 = F_A_Behavior(
        "Is x1 in binary ?",
        IsDivisibleFeatureCondition(2),
        {0: Action(0), 1: Action(1)},
    )
    binary_0 = F_A_Behavior(
        "Is x0 in binary ?",
        IsDivisibleFeatureCondition(2),
        {0: Action(1), 1: Action(0)},
    )
    binary_11 = F_A_Behavior(
        "Is x11 in binary ?",
        IsDivisibleFeatureCondition(4),
        {0: Action(0), 1: binary_1},
    )
    binary_10_01 = F_A_Behavior(
        "Is x01 or x10 in binary ?",
        IsDivisibleFeatureCondition(4),
        {0: binary_0, 1: binary_1},
    )
    return F_A_Behavior(
        "Is sum (of last 3 binary) 2 ?",
        IsDivisibleFeatureCondition(8),
        {0: binary_11, 1: binary_10_01},
    )
27 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"]
3 |
4 | [project]
5 | name = "hebg"
6 | authors = [
7 | { name = "Mathïs Fédérico" },
8 | { name = "Mathïs Fédérico", email = "mathfederico@gmail.com" },
9 | ]
10 | description = "HEBG: Hierarchical Explanations of Behavior as Graph"
11 | dynamic = ["version", "readme", "dependencies"]
12 | license = { text = "GPLv3 license" }
13 | requires-python = ">=3.7"
14 |
15 | [project.urls]
16 | repository = "https://github.com/IRLL/HEB_graphs"
17 |
18 |
19 | [tool.setuptools]
20 | license-files = ['LICEN[CS]E*', 'COPYING*', 'NOTICE*', 'AUTHORS*']
21 |
22 | [project.optional-dependencies]
23 | dev = ["ruff", "pytest", "pytest-check", "pytest-mock", "pytest-cov", "mypy", "pre-commit"]
24 |
25 | [project.scripts]
26 |
27 | [tool.setuptools.dynamic]
28 | readme = { file = ["README.rst"] }
29 | dependencies = { file = ["requirements.txt"] }
30 |
31 | [tool.setuptools_scm]
32 |
33 | [tool.mypy]
34 | files = "hebg"
35 | check_untyped_defs = true
36 | disallow_any_generics = false
37 | disallow_incomplete_defs = true
38 | no_implicit_optional = true
39 | no_implicit_reexport = false
40 | strict_equality = true
41 | warn_redundant_casts = true
42 | warn_unused_ignores = true
43 | ignore_missing_imports = true
44 |
--------------------------------------------------------------------------------
/tests/test_behavior_empty.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Behavior of HEBGraphs with empty nodes."""
5 |
6 | import pytest_check as check
7 |
8 | from hebg.node import Action
9 | from tests.examples.behaviors import (
10 | E_F_A_Behavior,
11 | E_A_Behavior,
12 | F_E_A_Behavior,
13 | E_E_A_Behavior,
14 | )
15 |
16 |
def test_e_a_graph():
    """(E-A) An empty node should hand over to its successor."""
    expected_action = 42
    e_a = E_A_Behavior("E_A", Action(expected_action))
    check.equal(e_a(None), expected_action)
22 |
23 |
def test_e_f_a_graph():
    """(E-F-A) Empty node followed by a feature condition should pick the right action."""
    e_f_a = E_F_A_Behavior("E_F_A")
    for observation, expected_action in ((-1, 0), (1, 1)):
        check.equal(e_f_a(observation), expected_action)
29 |
30 |
def test_f_e_a_graph():
    """(F-E-A) Feature condition followed by empty nodes should pick the right action."""
    f_e_a = F_E_A_Behavior("F_E_A")
    for observation, expected_action in ((1, 0), (-1, 1)):
        check.equal(f_e_a(observation), expected_action)
36 |
37 |
def test_e_e_a_graph():
    """(E-E-A) Two chained empty nodes should lead to the final action."""
    e_e_a = E_E_A_Behavior("E_E_A")
    chosen_action = e_e_a(None)
    check.equal(chosen_action, 0)
42 |
--------------------------------------------------------------------------------
/src/hebg/metrics/utility/binary_utility.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Simplest binary utility for HEBGraph."""
5 |
6 | from typing import Dict, List
7 |
8 | from hebg import Behavior
9 |
10 |
def binary_graphbased_utility(
    behavior: "Behavior",
    solving_behaviors: "List[Behavior]",
    used_nodes: Dict[str, Dict[str, int]],
) -> bool:
    """Tell whether the behavior is in the HEBGraph of any solving behavior.

    Args:
        behavior: Behavior of which we want to compute the utility.
        solving_behaviors: Behaviors that solve the task of interest.
        used_nodes: Mapping from a solving behavior to a mapping of the nodes
            it uses to their usage count. A solving behavior without an entry
            is treated as using no nodes instead of raising a KeyError.

    Returns:
        True if the behavior is a solving behavior itself, is a node of a
        solving behavior's graph, or has a positive usage count for a solving
        behavior. False otherwise (including when there is no solving behavior).

    """
    for solving_behavior in solving_behaviors:
        if behavior == solving_behavior:
            return True
        if behavior in solving_behavior.graph.nodes():
            return True
        # Missing usage information simply means "not used" for this behavior.
        usage_of_solver = used_nodes.get(solving_behavior, {})
        if usage_of_solver.get(behavior, 0) > 0:
            return True
    return False
39 |
--------------------------------------------------------------------------------
/tests/examples/feature_conditions.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from enum import Enum
3 |
4 | from hebg.node import FeatureCondition
5 |
6 |
class ThresholdFeatureCondition(FeatureCondition):
    """Feature condition comparing a scalar observation to a fixed threshold."""

    class Relation(Enum):
        """Comparison operators supported by the condition."""

        GREATER_OR_EQUAL_TO = ">="
        LESSER_OR_EQUAL_TO = "<="
        GREATER_THAN = ">"
        LESSER_THAN = "<"

    def __init__(
        self, relation: Union[Relation, str] = ">=", threshold: float = 0, **kwargs
    ) -> None:
        """Threshold-based feature condition for scalar feature.

        Args:
            relation: Comparison to apply, as a Relation or its symbol.
            threshold: Value the observation is compared to.
            **kwargs: Passed through to the FeatureCondition constructor.
        """
        self.relation = relation
        self.threshold = threshold
        self._relation = self.Relation(relation)
        readable_relation = self._relation.name.capitalize().replace("_", " ")
        super().__init__(name=f"{readable_relation} {threshold} ?", **kwargs)

    def __call__(self, observation: float) -> int:
        """Return 1 if the observation satisfies the relation, 0 otherwise."""
        relations = self.Relation
        active, limit = self._relation, self.threshold
        if active == relations.GREATER_OR_EQUAL_TO:
            return int(observation >= limit)
        if active == relations.LESSER_OR_EQUAL_TO:
            return int(observation <= limit)
        if active == relations.GREATER_THAN:
            return int(observation > limit)
        if active == relations.LESSER_THAN:
            return int(observation < limit)
        return None
34 |
35 |
class IsDivisibleFeatureCondition(FeatureCondition):
    """Feature condition telling if a scalar observation is divisible by a number."""

    def __init__(self, number: int = 0) -> None:
        """Is divisible feature condition for scalar feature.

        Args:
            number: The divisor. The default of 0 is kept for backward
                compatibility, but calling the condition with a zero divisor
                raises ZeroDivisionError for built-in numbers.
        """
        self.number = number
        name = f"Is divisible by {number} ?"
        super().__init__(name=name, image=None)

    def __call__(self, observation: float) -> int:
        """Return 1 if the observation is divisible by ``number``, 0 otherwise."""
        # Divisibility means a zero remainder; the former check
        # `observation // number == 1` only tested that the quotient is one.
        return int(observation % self.number == 0)
45 |
--------------------------------------------------------------------------------
/src/hebg/behavior.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Module for base Behavior."""
5 |
6 | from __future__ import annotations
7 |
8 | from typing import TYPE_CHECKING
9 |
10 | from hebg.graph import compute_levels
11 | from hebg.node import Node
12 |
13 | if TYPE_CHECKING:
14 | from hebg.heb_graph import HEBGraph
15 |
16 |
class Behavior(Node):
    """Abstract class for a Behavior as Node.

    Subclasses are expected to override `build_graph`; the graph is then
    built on first access of the `graph` property and reused afterwards.
    """

    def __init__(self, name: str, image=None, **kwargs) -> None:
        """Create the behavior node.

        Args:
            name: Name of the behavior.
            image: Optional image passed to the Node constructor.
            **kwargs: Additional keyword arguments passed to the Node constructor.
        """
        super().__init__(name, "behavior", image=image, **kwargs)
        # Cache for the lazily built graph, see the `graph` property.
        self._graph = None

    def __call__(self, observation, *args, **kwargs):
        """Use the behavior to get next actions.

        By default, uses the HEBGraph if it can be built.

        Args:
            observation: Observations of the environment.
            *args: Extra positional arguments forwarded to the graph call.
            **kwargs: Extra keyword arguments forwarded to the graph call.

        Returns:
            Whatever calling this behavior's HEBGraph on the observation returns.

        Raises:
            NotImplementedError: If `build_graph` was not overridden.

        """
        return self.graph(observation, *args, **kwargs)

    def build_graph(self) -> HEBGraph:
        """Build the HEBGraph of this Behavior.

        Returns:
            The built HEBGraph.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement build_graph()"
        )

    @property
    def graph(self) -> HEBGraph:
        """Access to the Behavior's graph.

        Only builds the graph the first time it is accessed, for efficiency.
        Levels are computed on the graph right after it is built.

        Returns:
            This Behavior's HEBGraph.

        """
        if self._graph is None:
            self._graph = self.build_graph()
            compute_levels(self._graph)
        return self._graph
62 |
--------------------------------------------------------------------------------
/tests/examples/behaviors/loop_with_alternative.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List
2 |
3 | from hebg import HEBGraph, Action, FeatureCondition, Behavior
4 |
5 |
class HasItem(FeatureCondition):
    """Feature condition checking that an item name is contained in the observation."""

    def __init__(self, item_name: str) -> None:
        self.item_name = item_name
        label = f"Has {item_name} ?"
        super().__init__(name=label, complexity=1.0)

    def __call__(self, observation: Any) -> int:
        """Return whether the item name is in the observation."""
        is_present = self.item_name in observation
        return is_present
13 |
14 |
class GatherWood(Behavior):
    """Gather wood"""

    def __init__(self) -> None:
        """Gather wood"""
        super().__init__("Gather wood")

    def build_graph(self) -> HEBGraph:
        """Branch on having an axe: use it, or punch the tree / get a new axe."""
        wood_graph = HEBGraph(self)
        has_axe = HasItem("axe")
        branches = (
            (Action("Punch tree", complexity=2.0), False),
            (Behavior("Get new axe", complexity=1.0), False),
            (Action("Use axe on tree", complexity=1.0), True),
        )
        for successor, branch_index in branches:
            wood_graph.add_edge(has_axe, successor, index=branch_index)
        return wood_graph
29 |
30 |
class GetNewAxe(Behavior):
    """Get new axe with wood"""

    def __init__(self) -> None:
        """Get new axe with wood"""
        super().__init__("Get new axe")

    def build_graph(self) -> HEBGraph:
        """Branch on having wood: craft the axe, or gather wood / summon an axe."""
        axe_graph = HEBGraph(self)
        has_wood = HasItem("wood")
        branches = (
            (Behavior("Gather wood", complexity=1.0), False),
            (Action("Summon axe out of thin air", complexity=10.0), False),
            (Action("Craft axe", complexity=1.0), True),
        )
        for successor, branch_index in branches:
            axe_graph.add_edge(has_wood, successor, index=branch_index)
        return axe_graph
47 |
48 |
def build_looping_behaviors() -> List[Behavior]:
    """Build the two behaviors that reference each other by name.

    Each behavior's graph gets the shared name-to-behavior mapping, and each
    behavior's complexity is set to 5.
    """
    behaviors: List[Behavior] = [GatherWood(), GetNewAxe()]
    all_behaviors = {}
    for known_behavior in behaviors:
        all_behaviors[known_behavior.name] = known_behavior
    for looping_behavior in behaviors:
        looping_behavior.graph.all_behaviors = all_behaviors
        looping_behavior.complexity = 5
    return behaviors
56 |
--------------------------------------------------------------------------------
/tests/examples/behaviors/report_example.py:
--------------------------------------------------------------------------------
1 | from hebg import HEBGraph, Action, FeatureCondition, Behavior
2 |
3 |
class Behavior0(Behavior):
    """Behavior 0"""

    def __init__(self) -> None:
        super().__init__("behavior 0")

    def build_graph(self) -> HEBGraph:
        """Build a single feature condition choosing between actions 0 and 1."""
        graph_0 = HEBGraph(self)
        feature_0 = FeatureCondition("feature 0", complexity=1)
        for action_id, branch_index in ((0, False), (1, True)):
            graph_0.add_edge(
                feature_0, Action(action_id, complexity=1), index=branch_index
            )
        return graph_0
16 |
17 |
class Behavior1(Behavior):
    """Behavior 1"""

    def __init__(self) -> None:
        super().__init__("behavior 1")

    def build_graph(self) -> HEBGraph:
        """Build a two-feature graph using Behavior0 and actions 0 and 2."""
        graph_1 = HEBGraph(self)
        first_feature = FeatureCondition("feature 1", complexity=1)
        second_feature = FeatureCondition("feature 2", complexity=1)

        sub_behavior = Behavior0()
        graph_1.add_edge(first_feature, sub_behavior, index=False)
        graph_1.add_edge(first_feature, second_feature, index=True)

        action_when_false = Action(0, complexity=1)
        graph_1.add_edge(second_feature, action_when_false, index=False)
        action_when_true = Action(2, complexity=1)
        graph_1.add_edge(second_feature, action_when_true, index=True)
        return graph_1
33 |
34 |
class Behavior2(Behavior):
    """Behavior 2"""

    def __init__(self) -> None:
        super().__init__("behavior 2")

    def build_graph(self) -> HEBGraph:
        """Build a three-feature graph using action 0, Behavior1 and Behavior0."""
        graph_2 = HEBGraph(self)
        root_feature = FeatureCondition("feature 3", complexity=1)
        left_feature = FeatureCondition("feature 4", complexity=1)
        right_feature = FeatureCondition("feature 5", complexity=1)

        # Root feature dispatches to one of the two lower features.
        graph_2.add_edge(root_feature, left_feature, index=False)
        graph_2.add_edge(root_feature, right_feature, index=True)

        # Left feature ends on action 0 or delegates to Behavior1.
        graph_2.add_edge(left_feature, Action(0, complexity=1), index=False)
        graph_2.add_edge(left_feature, Behavior1(), index=True)

        # Right feature delegates to Behavior1 or Behavior0.
        graph_2.add_edge(right_feature, Behavior1(), index=False)
        graph_2.add_edge(right_feature, Behavior0(), index=True)
        return graph_2
53 |
--------------------------------------------------------------------------------
/src/hebg/layouts/deterministic.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 | # pylint: disable=protected-access
4 |
5 | """Deterministic layouts"""
6 |
7 | import networkx as nx
8 | import numpy as np
9 |
10 | from hebg.graph import get_roots
11 |
12 |
def staircase_layout(graph: nx.DiGraph, center=None):
    """Compute deterministic "staircase" positions for the nodes of a DiGraph.

    Nodes are placed level by level starting from the roots of the graph;
    the y coordinate of a node is minus its depth in that traversal.
    Successors of a node are visited by increasing edge 'index' attribute.

    Args:
        graph: A networkx DiGraph whose edges carry an 'index' attribute.
        center (Optional): Center of the graph layout, given to networkx.

    Returns:
        pos: Dictionary mapping each placed node to its [x, y] position.

    """

    def place_successors(pos, pos_by_level, node, level) -> int:
        # First visit of a level: start from the x cursor of the level above.
        if level not in pos_by_level:
            pos_by_level[level] = pos_by_level[level - 1]
        pos_by_level[level] = max(pos[node][0], pos_by_level[level])

        children = list(graph.successors(node))
        if not children:
            return 1

        edge_indexes = [graph.edges[node, child]["index"] for child in children]
        for rank, child_position in enumerate(np.argsort(edge_indexes)):
            child = children[child_position]
            if child in pos:
                # Already placed through another predecessor.
                continue
            pos[child] = [pos_by_level[level], -level]
            if rank == 0:
                # Align the parent over its first successor.
                pos[node][0] = max(pos[node][0], pos[child][0])
                pos_by_level[level - 1] = max(pos_by_level[level - 1], pos[node][0])
            # Keep the call separate from the update: the recursion may itself
            # change pos_by_level[level] before the width is added.
            width = place_successors(pos, pos_by_level, child, level + 1)
            pos_by_level[level] += width
        return len(children)

    graph, _ = nx.drawing.layout._process_params(graph, center, dim=2)
    pos = {}
    pos_by_level = {0: 0}
    for root in get_roots(graph):
        pos[root] = [pos_by_level[0], 0]
        # NOTE(review): here the current cursor value is read before the call,
        # so changes made to pos_by_level[0] during the call are overwritten.
        pos_by_level[0] += place_successors(pos, pos_by_level, root, 1)
    return pos
54 |
--------------------------------------------------------------------------------
/src/hebg/layouts/metaheuristics.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Metaheuristics used for building layouts"""
5 |
6 | import numpy as np
7 |
8 |
def simulated_annealing(
    initial,
    energy,
    neighbor,
    max_iterations: int = 1000,
    initial_temperature: float = 5,
    max_iters_without_new: float = np.inf,
    verbose: int = 1,
):
    """Perform simulated annealing metaheuristic on an energy using a neighboring.

    See https://en.wikipedia.org/wiki/Simulated_annealing for more details.

    The temperature at iteration k is ``initial_temperature / (k + 1)``.

    Args:
        initial: Initial variable position.
        energy: Function giving the energy (aka cost) of a given variable position.
        neighbor: Function giving a neighbor of a given variable position.
        max_iterations: Maximum number of iterations.
        initial_temperature: Initial temperature parameter, more is more random search.
        max_iters_without_new: Maximum number of consecutive iterations without
            accepting a new position before stopping early.
        verbose: If equal to 1, print the progress of accepted moves.

    Returns:
        The last accepted variable position.

    """

    def prob_keep(temperature, delta_e):
        # Only worsening moves need the exponential. Moves that do not worsen
        # are always kept, which also avoids overflowing exp() when the
        # improvement is large compared to the temperature.
        if delta_e < 0:
            return min(1.0, float(np.exp(delta_e / temperature)))
        return 1.0

    state = initial
    energy_pos = energy(state)
    iters_without_new = 0
    for k in range(max_iterations):
        new_state = neighbor(state)
        new_energy = energy(new_state)
        temperature = initial_temperature / (k + 1)
        iters_without_new += 1
        prob = prob_keep(temperature, energy_pos - new_energy)
        if np.random.random() < prob:
            if verbose == 1:
                print(
                    f"{k}\t({prob:.0%})\t{energy_pos:.2f}->{new_energy:.2f}", end="\r"
                )
            state, energy_pos = new_state, new_energy
            iters_without_new = 0

        if iters_without_new >= max_iters_without_new:
            break
    if verbose == 1:
        print()
    return state
57 |
--------------------------------------------------------------------------------
/tests/examples/behaviors/basic_empty.py:
--------------------------------------------------------------------------------
1 | from hebg.node import Action, EmptyNode
2 | from hebg.behavior import Behavior
3 | from hebg.heb_graph import HEBGraph
4 |
5 | from tests.examples.feature_conditions import ThresholdFeatureCondition
6 |
7 |
class E_A_Behavior(Behavior):
    """Behavior whose graph is an empty node followed by a given action."""

    def __init__(self, name: str, action: Action) -> None:
        super().__init__(name)
        self.action = action

    def build_graph(self) -> HEBGraph:
        """Link a single empty node to the stored action."""
        e_a_graph = HEBGraph(self)
        entry_node = EmptyNode("empty")
        e_a_graph.add_edge(entry_node, self.action)
        return e_a_graph
19 |
20 |
class E_F_A_Behavior(Behavior):
    """Double layer behavior: an empty node followed by a feature condition."""

    def build_graph(self) -> HEBGraph:
        """Build ``empty -> feature condition -> {Action(0), Action(1)}``."""
        heb_graph = HEBGraph(self)
        root = EmptyNode("empty")
        condition = ThresholdFeatureCondition(relation=">=", threshold=0)
        heb_graph.add_edge(root, condition)

        # Action(i) hangs on the edge of index i, for both outcomes 0 and 1.
        for outcome in (0, 1):
            heb_graph.add_edge(condition, Action(outcome), index=outcome)

        return heb_graph
35 |
36 |
class F_E_A_Behavior(Behavior):
    """Double layer behavior: a feature condition followed by empty nodes."""

    def build_graph(self) -> HEBGraph:
        """Build ``feature condition -> {empty_0, empty_1} -> actions``."""
        heb_graph = HEBGraph(self)
        condition = ThresholdFeatureCondition(relation=">=", threshold=0)
        empties = [EmptyNode("empty_0"), EmptyNode("empty_1")]

        # empty_0 hangs on the index-1 edge and empty_1 on the index-0 edge.
        for empty, edge_index in zip(empties, (1, 0)):
            heb_graph.add_edge(condition, empty, index=edge_index)

        # Each empty node leads to the action carrying its own number.
        for number, empty in enumerate(empties):
            heb_graph.add_edge(empty, Action(number))

        return heb_graph
54 |
55 |
class E_E_A_Behavior(Behavior):
    """Double layer behavior made of two chained empty nodes."""

    def build_graph(self) -> HEBGraph:
        """Build ``empty_0 -> empty_1 -> Action(0)``."""
        heb_graph = HEBGraph(self)
        first, second = EmptyNode("empty_0"), EmptyNode("empty_1")
        heb_graph.add_edge(first, second)
        heb_graph.add_edge(second, Action(0))
        return heb_graph
69 |
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
1 | Contributing to HEBG
2 | ====================
3 |
4 | We are happy to receive contributions in the form of **pull requests** via GitHub.
5 | Feel free to fork the repository, implement your changes and create a pull request to the **main** branch.
6 |
7 | Git Commit Messages
8 | ~~~~~~~~~~~~~~~~~~~
9 |
10 | Commit messages should start with a capital letter and be written in the present tense (e.g. ``:tada: Add cool new feature`` instead of ``:tada: Added cool new feature``).
11 | You should also start your commit message with **one** applicable emoji. Not only does this look great, it also makes you rethink what to add to a commit. Make many small commits!
12 |
13 |
14 | .. list-table:: Title
15 | :header-rows: 1
16 |
17 | * - Emoji
18 | - Description
19 | * - `:tada:` ``:tada:``
20 | - When you added a cool new feature
21 | * - `:bug:` ``:bug:``
22 | - When you fixed a bug
23 | * - `:fire:` ``:fire:``
24 | - When you removed something.
25 | * - `:truck:` ``:truck:``
26 | - When you moved / renamed something.
27 | * - `:wrench:` ``:wrench:``
28 | - When you refactored / improved a small piece of code.
29 | * - `:hammer:` ``:hammer:``
30 | - When you refactored / improved large parts of the code.
31 | * - `:sparkles:` ``:sparkles:``
32 | - When you improved code quality (pylint, PEP, ...).
33 | * - `:art:` ``:art:``
34 | - When you improved / added design assets.
35 | * - `:rocket:` ``:rocket:``
36 | - When you improved performance.
37 | * - `:memo:` ``:memo:``
38 | - When you wrote documentation.
39 | * - `:umbrella:` ``:umbrella:``
40 | - When you improve coverage.
41 | * - `:twisted_rightwards_arrows:` ``:twisted_rightwards_arrows:``
42 | - When you merged a branch.
43 |
44 | This section was inspired by `This repository `_.
45 |
46 | Version Numbers
47 | ---------------
48 |
49 | Version numbers will be assigned according to the `Semantic Versioning <https://semver.org/>`_ scheme.
50 | This means, given a version number MAJOR.MINOR.PATCH, we will increment the:
51 |
52 | 1. MAJOR version when we make incompatible API changes,
53 | 2. MINOR version when we add functionality in a backwards compatible manner, and
54 | 3. PATCH version when we make backwards compatible bug fixes.
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # VsCode
132 | .vscode
133 |
--------------------------------------------------------------------------------
/tests/examples/behaviors/loop_without_alternative.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from hebg import HEBGraph, Action, FeatureCondition, Behavior
4 |
5 |
class ReachForest(Behavior):
    """Reach forest"""

    def __init__(self) -> None:
        """Reach forest"""
        super().__init__("Reach forest")

    def build_graph(self) -> HEBGraph:
        """Build one feature condition per other zone.

        For each condition, the False edge points to the behavior reaching that
        zone and the True edge points to the "> forest" action.
        """
        heb_graph = HEBGraph(self)
        zones = (
            ("Is in other zone ?", "Reach other zone"),
            ("Is in meadow ?", "Reach meadow"),
        )
        for question, fallback_name in zones:
            condition = FeatureCondition(question)
            heb_graph.add_edge(condition, Behavior(fallback_name), index=False)
            heb_graph.add_edge(condition, Action("> forest"), index=True)
        return heb_graph
22 |
23 |
class ReachOtherZone(Behavior):
    """Reach other zone"""

    def __init__(self) -> None:
        """Reach other zone"""
        super().__init__("Reach other zone")

    def build_graph(self) -> HEBGraph:
        """Build one feature condition per other zone.

        For each condition, the False edge points to the behavior reaching that
        zone and the True edge points to the "> other zone" action.
        """
        heb_graph = HEBGraph(self)
        zones = (
            ("Is in forest ?", "Reach forest"),
            ("Is in meadow ?", "Reach meadow"),
        )
        for question, fallback_name in zones:
            condition = FeatureCondition(question)
            heb_graph.add_edge(condition, Behavior(fallback_name), index=False)
            heb_graph.add_edge(condition, Action("> other zone"), index=True)
        return heb_graph
40 |
41 |
class ReachMeadow(Behavior):
    """Reach meadow"""

    def __init__(self) -> None:
        """Reach meadow"""
        super().__init__("Reach meadow")

    def build_graph(self) -> HEBGraph:
        """Build one feature condition per other zone.

        For each condition, the False edge points to the behavior reaching that
        zone and the True edge points to the "> meadow" action.
        """
        heb_graph = HEBGraph(self)
        zones = (
            ("Is in forest ?", "Reach forest"),
            ("Is in other zone ?", "Reach other zone"),
        )
        for question, fallback_name in zones:
            condition = FeatureCondition(question)
            heb_graph.add_edge(condition, Behavior(fallback_name), index=False)
            heb_graph.add_edge(condition, Action("> meadow"), index=True)
        return heb_graph
58 |
59 |
def build_looping_behaviors_without_direct_alternatives() -> List[Behavior]:
    """Build the three zone-reaching behaviors and register them with each other.

    Returns:
        The list ``[ReachForest(), ReachOtherZone(), ReachMeadow()]``, where the
        graph of every behavior has its ``all_behaviors`` attribute set to one
        shared name-to-behavior mapping of the three.
    """
    behaviors: List[Behavior] = [ReachForest(), ReachOtherZone(), ReachMeadow()]

    by_name = {}
    for behavior in behaviors:
        by_name[behavior.name] = behavior

    # Every graph receives the same mapping object.
    for behavior in behaviors:
        behavior.graph.all_behaviors = by_name
    return behaviors
70 |
--------------------------------------------------------------------------------
/tests/test_node.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Unit tests for the hebg.node module."""
5 |
6 | import pytest
7 | import pytest_check as check
8 |
9 | from hebg.node import Node, Action, FeatureCondition, EmptyNode
10 |
11 |
class TestNode:
    """Node"""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Initialize variables."""

    def test_node_type(self):
        """should have correct node_type and raise ValueError otherwise."""
        with pytest.raises(ValueError):
            Node("", "")
        valid_types = ("action", "feature_condition", "behavior", "empty")
        for valid_type in valid_types:
            built = Node("", valid_type)
            check.equal(built.type, valid_type)

    def test_node_name(self):
        """should have name as attribute."""
        expected_name = "node_name"
        check.equal(Node(expected_name, "empty").name, expected_name)

    def test_node_call(self):
        """should raise NotImplementedError on call."""
        # Build outside the context manager so only the call itself is checked.
        base_node = Node("", "empty")
        with pytest.raises(NotImplementedError):
            base_node(None)
38 |
39 |
class TestAction:
    """Action"""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Initialize variables."""

    def test_node_type(self):
        """should have 'action' as node_type."""
        check.equal(Action("", "").type, "action")

    def test_node_call(self):
        """should return Action.action when called."""
        expected = "action_action"
        action_node = Action(expected, "action_name")
        check.equal(action_node(None), expected)
        check.equal(action_node.action, expected)
58 |
59 |
class TestFeatureCondition:
    """FeatureCondition"""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Initialize variables."""

    def test_node_type(self):
        """should have 'feature_condition' as node_type."""
        check.equal(FeatureCondition("").type, "feature_condition")

    def test_node_call(self):
        """should raise NotImplementedError on call."""
        # Build outside the context manager so only the call itself is checked.
        condition = FeatureCondition("")
        with pytest.raises(NotImplementedError):
            condition(None)
77 |
78 |
class TestEmptyNode:
    """EmptyNode"""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Initialize variables."""

    def test_node_type(self):
        """should have 'empty' as node_type."""
        check.equal(EmptyNode("").type, "empty")

    def test_node_call(self):
        """should return 1 when called."""
        empty = EmptyNode("")
        result = empty(None)
        check.equal(result, 1)
95 |
--------------------------------------------------------------------------------
/src/hebg/node.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Module for base Node classes."""
5 |
6 | import dis
7 | from typing import Any, List
8 |
9 | import numpy as np
10 |
11 |
def bytecode_complexity(obj):
    """Count the bytecode instructions that ``dis`` yields for the given obj."""
    return sum(1 for _ in dis.get_instructions(obj))
15 |
16 |
class Node:
    """Base node class for any HEBGraph.

    A node is identified by its name: equality compares the name with the string
    form of the other object, and the hash is the hash of the name. A node
    therefore compares equal to (and hashes like) its own name string.
    """

    # Closed set of values accepted for ``node_type``.
    NODE_TYPES = ("action", "feature_condition", "behavior", "empty")

    def __init__(
        self,
        name: str,
        node_type: str,
        complexity: int = None,
        image=None,
    ) -> None:
        """Base Node class for any HEBGraph.

        Args:
            name (str): A UNIQUE name representing this node.
            node_type (str): One of {action, feature_condition, behavior, empty}.
            complexity (int, optional): Given individual complexity of the node.
                Stored as given; this class computes no default when it is None.
                Defaults to None.
            image (2d array, optional): Image to represent the node. Defaults to None.

        Raises:
            ValueError: If node_type has an unexpected value.
        """
        self.name = name
        self.image = image
        if node_type not in self.NODE_TYPES:
            # The trailing space keeps the two message fragments from running together.
            raise ValueError(
                f"node_type ({node_type}) "
                f"not in authorised node_types ({self.NODE_TYPES})."
            )
        self.type = node_type
        self.complexity = complexity

    def __call__(self, observation: Any) -> Any:
        """Evaluate the node on an observation; subclasses must override this."""
        raise NotImplementedError

    def __str__(self) -> str:
        return self.name

    def __eq__(self, o: object) -> bool:
        # Compared through str(o), so a node equals any object whose string
        # form is this node's name (including the bare name string).
        return self.name == str(o)

    def __hash__(self) -> int:
        # Kept consistent with __eq__: same name, same hash.
        return self.name.__hash__()

    def __repr__(self) -> str:
        return self.name
64 |
65 |
class Action(Node):
    """Node representing an action in an HEBGraph."""

    def __init__(self, action: Any, name: str = None, **kwargs) -> None:
        """Store the action, then initialise the node with type "action".

        Args:
            action: Value returned whenever the node is called.
            name: Name of the node; when None, a name is derived from the action.
            **kwargs: Forwarded to the Node constructor.
        """
        # Set before calling _get_name, which reads self.action.
        self.action = action
        node_name = self._get_name(name)
        super().__init__(node_name, "action", **kwargs)

    def _get_name(self, name):
        """Return the given name, or ``Action(<action>)`` when name is None."""
        if name is not None:
            return name
        return f"Action({self.action})"

    def __call__(self, observation: Any) -> Any:
        """Return the stored action, whatever the observation."""
        return self.action
79 |
80 |
class StochasticAction(Action):
    """Node representing a stochastic choice between actions in an HEBGraph."""

    def __init__(
        self, actions: List[Action], probs: list, name: str, image=None
    ) -> None:
        """Store the candidate actions and the probabilities used to draw one.

        Args:
            actions: Candidate actions, kept in ``self.action``.
            probs: Passed as ``p`` to ``np.random.choice`` on each call
                (presumably one probability per action -- confirm with callers).
            name: Name of the node.
            image: Image to represent the node. Defaults to None.
        """
        super().__init__(actions, name=name, image=image)
        self.probs = probs

    def __call__(self, observation):
        """Draw one of the actions and return its result on the observation."""
        return np.random.choice(self.action, p=self.probs)(observation)
93 |
94 |
class FeatureCondition(Node):
    """Node representing a feature condition in an HEBGraph."""

    def __init__(self, name: str = None, **kwargs) -> None:
        """Initialise a node of type "feature_condition".

        Args:
            name: Name of the node.
            **kwargs: Forwarded to the Node constructor.
        """
        node_type = "feature_condition"
        super().__init__(name, node_type, **kwargs)

    def __call__(self, observation: Any) -> int:
        """Evaluate the condition on an observation; subclasses must override this."""
        raise NotImplementedError
103 |
104 |
class EmptyNode(Node):
    """Node representing an empty node in an HEBGraph."""

    def __init__(self, name: str) -> None:
        """Initialise a node of type "empty" with the given name."""
        node_type = "empty"
        super().__init__(name, node_type)

    def __call__(self, observation: Any) -> int:
        """Return 1, whatever the observation."""
        return 1
113 |
--------------------------------------------------------------------------------
/src/hebg/layouts/metabased.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 | # pylint: disable=protected-access
4 |
5 | """Metaheuristics based layouts"""
6 |
7 | from copy import deepcopy
8 |
9 | import networkx as nx
10 | import numpy as np
11 |
12 | from hebg.layouts.metaheuristics import simulated_annealing
13 |
14 |
def leveled_layout_energy(
    graph: nx.DiGraph, center=None, metaheuristic=simulated_annealing
):
    """Compute positions for a leveled DiGraph using a metaheuristic to minimize energy.

    Requires each node to have a 'level' attribute and the graph to carry a
    'nodes_by_level' entry in its graph attributes.

    Args:
        graph: A networkx DiGraph.
        center (Optional): Center of the graph layout. It is normalised by networkx
            but not used afterwards in this function.
        metaheuristic: Callable invoked as ``metaheuristic(pos, energy, neighbor)``
            with the initial positions, the energy function and the neighbor function.

    Returns:
        Whatever the metaheuristic returns; with the default simulated annealing
        this is the final state, i.e. a dict mapping each node to [level, slot].

    """
    graph, center = nx.drawing.layout._process_params(graph, center, dim=2)

    nodes_by_level = graph.graph["nodes_by_level"]
    pos = {}
    # One slot per node of the widest level; slots are spread over [0, 1).
    step_size = 1 / max(len(nodes_by_level[level]) for level in nodes_by_level)
    spacing = np.arange(0, 1, step=step_size)
    for level in nodes_by_level:
        n_nodes_in_level = len(nodes_by_level[level])
        if n_nodes_in_level > 1:
            # Spread the nodes of this level over slot indices, first to last.
            positions = np.linspace(
                0, len(spacing) - 1, n_nodes_in_level, endpoint=True, dtype=np.int32
            )
            positions = spacing[positions]
        else:
            # A lone node takes the middle slot.
            positions = [spacing[(len(spacing) - 1) // 2]]

        for i, node in enumerate(nodes_by_level[level]):
            pos[node] = [level, positions[i]]

    def energy(pos, nodes_strenght=1, edges_strenght=2):
        """Return the energy of a position mapping (the quantity to minimize)."""

        def dist(start, stop):
            """Euclidean distance between two positions."""
            x_arr, y_arr = np.array(start), np.array(stop)
            return np.linalg.norm(x_arr - y_arr)

        energy = 0
        for level in nodes_by_level:
            for node in nodes_by_level[level]:
                # Adds the squared distances to the other nodes of the same level.
                energy += nodes_strenght * sum(
                    np.square(dist(pos[node], pos[n]))
                    for n in nodes_by_level[level]
                    if n != node
                )
                # Subtracts a term per predecessor that grows as the two nodes get
                # closer; the distance is floored at 1e-6 to avoid dividing by zero.
                # NOTE(review): abs() is applied after max(1, ...), so any negative
                # level difference is clamped to 1 -- confirm max(1, abs(...)) was
                # not the intent.
                energy -= sum(
                    edges_strenght
                    / abs(
                        max(1, graph.nodes[node]["level"] - graph.nodes[pred]["level"])
                    )
                    / max(1e-6, dist(pos[node], pos[pred]))
                    for pred in graph.predecessors(node)
                )
                # Same term for every successor.
                energy -= sum(
                    edges_strenght
                    / abs(
                        max(1, graph.nodes[node]["level"] - graph.nodes[succ]["level"])
                    )
                    / max(1e-6, dist(pos[node], pos[succ]))
                    for succ in graph.successors(node)
                )

        return energy

    def neighbor(pos: dict):
        """Return a copy of pos where one random node moved to a random slot."""
        pos_copy = deepcopy(pos)
        nodes_list = list(pos.keys())
        choosen_node_id = int(np.random.randint(len(nodes_list)))
        choosen_node = nodes_list[choosen_node_id]
        choosen_level = graph.nodes(data="level")[choosen_node]
        # Keep the first coordinate, draw a new slot for the second one.
        new_pos = [pos_copy[choosen_node][0], np.random.choice(spacing)]
        for node in nodes_by_level[choosen_level]:
            # If the slot is taken by another node of the level, swap the two nodes.
            if node != choosen_node and np.all(np.isclose(new_pos, pos_copy[node])):
                pos_copy[choosen_node], pos_copy[node] = (
                    pos_copy[node],
                    pos_copy[choosen_node],
                )
                return pos_copy
        pos_copy[choosen_node] = new_pos
        return pos_copy

    return metaheuristic(pos, energy, neighbor)
100 |
--------------------------------------------------------------------------------
/src/hebg/requirements_graph.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 | # pylint: disable=arguments-differ
4 |
5 | """Module for building underlying requirement graphs based on a set of behaviors."""
6 |
7 | from __future__ import annotations
8 |
9 | from copy import deepcopy
10 | from typing import Dict, List
11 |
12 | from networkx import DiGraph, descendants
13 |
14 | from hebg.behavior import Behavior
15 | from hebg.graph import compute_levels
16 | from hebg.heb_graph import HEBGraph
17 | from hebg.node import EmptyNode
18 |
19 |
def build_requirement_graph(behaviors: List[Behavior]) -> DiGraph:
    """Build the DiGraph of the requirements induced by a list of behaviors.

    An edge goes from a required behavior to the behavior whose graph uses it,
    unless the counted uses cancel out against the alternatives cut around
    empty nodes.

    Args:
        behaviors: List of Behaviors to build the requirement graph from.

    Returns:
        The requirement graph induced by the given list of behaviors.

    Raises:
        NotImplementedError: If accessing a behavior's graph raises it.

    """
    try:
        heb_graphs = [behavior.graph for behavior in behaviors]
    except NotImplementedError as error:
        raise NotImplementedError(
            "All behaviors given must be able to build an HEBGraph"
        ) from error

    requirements_graph = DiGraph()
    for behavior in behaviors:
        requirements_graph.add_node(behavior)

    requirement_degree = {}

    # First pass: each empty node lowers the degree of the behaviors that are
    # reachable from the alternatives it shares a successor with.
    for graph in heb_graphs:
        requirement_degree[graph.behavior] = {}
        empty_nodes = [node for node in graph.nodes() if isinstance(node, EmptyNode)]
        for empty_node in empty_nodes:
            requirement_degree = _cut_alternatives_to_empty_node(
                graph, empty_node, requirement_degree
            )

    # Second pass: each behavior node present in a graph raises its degree by one.
    for graph in heb_graphs:
        degrees = requirement_degree[graph.behavior]
        for node in graph.nodes():
            if isinstance(node, Behavior):
                degrees[node] = degrees.get(node, 0) + 1

    # Third pass: behaviors whose degree is not zero become requirement edges.
    for graph in heb_graphs:
        degrees = requirement_degree[graph.behavior]
        for node in graph.nodes():
            if not isinstance(node, Behavior) or degrees[node] == 0:
                continue
            if node not in requirements_graph.nodes():
                requirements_graph.add_node(node)
            # Edge indexes count the successors the node already has, from 1.
            edge_index = len(list(requirements_graph.successors(node))) + 1
            requirements_graph.add_edge(node, graph.behavior, index=edge_index)

    compute_levels(requirements_graph)
    return requirements_graph
75 |
76 |
def _cut_alternatives_to_empty_node(
    graph: HEBGraph,
    node: EmptyNode,
    requirement_degree: Dict[Behavior, Dict[Behavior, int]],
) -> Dict[Behavior, Dict[Behavior, int]]:
    """Lower the degree of behaviors reachable from the alternatives of an empty node.

    The alternatives are the predecessors of the empty node's first successor
    whose edge to it carries the same index as the empty node's edge.

    Args:
        graph: HEBGraph containing the empty node.
        node: Empty node to consider; it must have at least one successor.
        requirement_degree: Degrees per behavior, updated in place for graph.behavior.

    Returns:
        The same requirement_degree mapping, after the update.

    """
    successor = list(graph.successors(node))[0]
    empty_index = graph.edges[node, successor]["index"]
    alternatives = [
        predecessor
        for predecessor in graph.predecessors(successor)
        if graph.edges[predecessor, successor]["index"] == empty_index
    ]

    # Work on a copy so the edges can be removed without touching the graph.
    cut_graph = deepcopy(graph)
    for alternative in alternatives:
        cut_graph.remove_edge(alternative, successor)

    for alternative in alternatives:
        for reachable in descendants(cut_graph, alternative):
            if not isinstance(reachable, Behavior):
                continue
            degrees = requirement_degree[graph.behavior]
            degrees[reachable] = degrees.get(reachable, 0) - 1

    return requirement_degree
105 |
--------------------------------------------------------------------------------
/tests/examples/behaviors/basic.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from hebg.node import Action, FeatureCondition
4 | from hebg.behavior import Behavior
5 | from hebg.heb_graph import HEBGraph
6 |
7 | from tests.examples.feature_conditions import ThresholdFeatureCondition
8 |
9 |
class FundamentalBehavior(Behavior):
    """Behavior wrapping a single Action."""

    def __init__(self, action: Action) -> None:
        """Name the behavior after the action and reuse the action's image.

        Args:
            action: Action that makes up the whole behavior graph.
        """
        self.action = action
        behavior_name = action.name + "_behavior"
        super().__init__(behavior_name, image=action.image)

    def build_graph(self) -> HEBGraph:
        """Build a graph whose only node is the action."""
        heb_graph = HEBGraph(self)
        heb_graph.add_node(self.action)
        return heb_graph
22 |
23 |
class AA_Behavior(Behavior):
    """Behavior with two root actions."""

    def __init__(self, name: str, any_mode: str) -> None:
        """Store the any_mode handed to the HEBGraph when the graph is built.

        Args:
            name: Name of the behavior.
            any_mode: Value passed as ``any_mode`` to the HEBGraph.
        """
        super().__init__(name, image=None)
        self.any_mode = any_mode

    def build_graph(self) -> HEBGraph:
        """Build a graph holding the unconnected nodes Action(0) and Action(1)."""
        heb_graph = HEBGraph(self, any_mode=self.any_mode)
        for action_value in (0, 1):
            heb_graph.add_node(Action(action_value))
        return heb_graph
38 |
39 |
class F_A_Behavior(Behavior):
    """Behavior made of one feature condition branching to actions."""

    def __init__(
        self,
        name: str,
        feature_condition: FeatureCondition,
        actions: Dict[int, Action],
    ) -> None:
        """Single feature condition behavior

        Args:
            name (str): Name of the behavior.
            feature_condition (FeatureCondition): Feature condition used as root.
            actions (Dict[int, Action]): Mapping from feature condition output to
                the action placed on the edge of that index.
        """
        super().__init__(name)
        self.feature_condition = feature_condition
        self.actions = actions

    def build_graph(self) -> HEBGraph:
        """Build one edge per entry of the actions mapping."""
        heb_graph = HEBGraph(self)
        root = self.feature_condition
        for edge_index, target in self.actions.items():
            heb_graph.add_edge(root, target, index=edge_index)
        return heb_graph
65 |
66 |
class F_F_A_Behavior(Behavior):
    """Behavior with two layers of feature conditions above the actions."""

    def __init__(self, name: str = "F_F_A", *args, **kwargs) -> None:
        """Forward everything to Behavior, with "F_F_A" as the default name."""
        super().__init__(name=name, *args, **kwargs)

    def build_graph(self) -> HEBGraph:
        """Build a root condition, one condition per outcome, then the actions."""
        heb_graph = HEBGraph(self)

        root = ThresholdFeatureCondition(relation=">=", threshold=0)
        on_true = ThresholdFeatureCondition(relation="<=", threshold=1)
        on_false = ThresholdFeatureCondition(relation=">=", threshold=-1)

        heb_graph.add_edge(root, on_true, index=True)
        heb_graph.add_edge(root, on_false, index=False)

        # For each second-layer condition: (action value, edge index) pairs.
        wiring = (
            (on_true, ((2, 1), (3, 0))),
            (on_false, ((0, 0), (1, 1))),
        )
        for condition, branches in wiring:
            for action_value, edge_index in branches:
                heb_graph.add_edge(condition, Action(action_value), index=edge_index)

        return heb_graph
90 |
91 |
class F_AA_Behavior(Behavior):
    """Feature condition with multiple actions on the same index."""

    def __init__(self, name: str = "F_AA") -> None:
        """Initialise the behavior without image, named "F_AA" by default."""
        super().__init__(name, image=None)

    def build_graph(self) -> HEBGraph:
        """Build a condition with one action on index 1 and two on index 0."""
        heb_graph = HEBGraph(self)
        condition = ThresholdFeatureCondition(relation=">=", threshold=0)

        # (action value, edge index): Action(1) and Action(2) share index 0.
        for action_value, edge_index in ((0, 1), (1, 0), (2, 0)):
            heb_graph.add_edge(condition, Action(action_value), index=edge_index)

        return heb_graph
107 |
108 |
class AF_A_Behavior(Behavior):
    """Behavior with two roots: an action and a feature condition."""

    def __init__(self, name: str, any_mode: str) -> None:
        """Store the any_mode handed to the HEBGraph when the graph is built.

        Args:
            name: Name of the behavior.
            any_mode: Value passed as ``any_mode`` to the HEBGraph.
        """
        super().__init__(name, image=None)
        self.any_mode = any_mode

    def build_graph(self) -> HEBGraph:
        """Build the lone Action(0) next to a condition leading to two actions."""
        heb_graph = HEBGraph(self, any_mode=self.any_mode)
        heb_graph.add_node(Action(0))

        condition = ThresholdFeatureCondition(relation=">=", threshold=0)
        # (action value, edge index) for the two outcomes of the condition.
        for action_value, edge_index in ((1, 1), (2, 0)):
            heb_graph.add_edge(condition, Action(action_value), index=edge_index)

        return heb_graph
126 |
--------------------------------------------------------------------------------
/src/hebg/metrics/complexity/complexities.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """General complexity."""
5 |
6 | from typing import TYPE_CHECKING, Dict, Tuple
7 |
8 | from hebg.behavior import Behavior
9 | from hebg.metrics.complexity.utils import update_sum_dict
10 | from hebg.node import Action
11 |
12 | if TYPE_CHECKING:
13 | from hebg.node import Node
14 |
15 |
def general_complexity(
    behavior: Behavior,
    used_nodes_all: Dict["Node", Dict["Node", int]],
    scomplexity,
    kcomplexity,
    previous_used_nodes: Dict["Node", int] = None,
) -> Tuple[float, float]:
    """Compute the general complexity of a Behavior with used nodes.

    Using the number of time each node is used in its HEBGraph and based on the increase of
    complexity given by 'kcomplexity', and the saved complexity using behaviors given by
    'scomplexity', we sum the general complexity of a behavior and the total saved complexity.

    Behaviors that have their own entry in 'used_nodes_all' are evaluated recursively;
    every other node contributes its 'complexity' attribute.

    Args:
        behavior: The Behavior for which we compute the general complexity.
        used_nodes_all: Dictionary of dictionary of the number of times each nodes was used in the
            past, and thus for each node. Not being in a dictionary is counted as not being used.
        scomplexity: Callable taking as input (node, n_used, n_previous_used), where node is
            the node concerned, n_used is the number of time this node is used and n_previous_used
            is the number of time this node was used before that, returning the saved complexity
            thanks to using a behavior.
        kcomplexity: Callable taking as input (node, n_used), where node is the node concerned and
            n_used is the number of time this node is used, returning the accumulated complexity.
        previous_used_nodes: Dictionary of the number of times each nodes was used in the past,
            not being in the dictionary is counted as 0. None or an empty dictionary is
            replaced by a new empty dictionary.

    Returns:
        Tuple composed of the general complexity (total minus saved complexity) and the
        total saved complexity.

    """

    previous_used_nodes = previous_used_nodes if previous_used_nodes else {}

    total_complexity = 0
    saved_complexity = 0

    for node, n_used in used_nodes_all[behavior].items():
        # Read before the updates below, so it excludes the uses counted in this iteration.
        n_previous_used = (
            previous_used_nodes[node] if node in previous_used_nodes else 0
        )

        if isinstance(node, Behavior) and node in used_nodes_all:
            # The recursion receives a copy of the usage counts.
            node_complexity, saved_node_complexity = general_complexity(
                node,
                used_nodes_all,
                scomplexity=scomplexity,
                kcomplexity=kcomplexity,
                previous_used_nodes=previous_used_nodes.copy(),
            )
            # Presumably merges the sub-behavior's usage counts into the running
            # counts -- confirm against update_sum_dict.
            previous_used_nodes = update_sum_dict(
                previous_used_nodes, used_nodes_all[node]
            )
            # The same amount is added to both sums, so it cancels out in the first
            # returned value and only raises the returned saved complexity.
            total_complexity += saved_node_complexity * kcomplexity(node, n_used)
            saved_complexity += saved_node_complexity * kcomplexity(node, n_used)
        else:
            node_complexity = node.complexity

        total_complexity += node_complexity * kcomplexity(node, n_used)

        # Only behaviors and actions can save complexity.
        if isinstance(node, (Behavior, Action)):
            saved_complexity += node_complexity * scomplexity(
                node, n_used, n_previous_used
            )

        previous_used_nodes = update_sum_dict(previous_used_nodes, {node: n_used})

    return total_complexity - saved_complexity, saved_complexity
83 |
84 |
def learning_complexity(
    behavior: Behavior,
    used_nodes_all: Dict["Node", Dict["Node", int]],
    previous_used_nodes=None,
):
    """Compute the learning complexity of a Behavior with used nodes.

    This is the general complexity where a node used k times accumulates k times
    its complexity, and saves ``max(0, min(k, p + k - 1))`` times its complexity
    when it was used p times before.

    Args:
        behavior: The Behavior for which we compute the learning complexity.
        used_nodes_all: Dictionary of dictionary of the number of times each nodes was used in the
            past, and thus for each node. Not being in a dictionary is counted as not being used.
        previous_used_nodes: Dictionary of the number of times each nodes was used in the past,
            not being in the dictionary is counted as not being used.

    Returns:
        Tuple composed of the learning complexity and the total saved complexity.

    """

    def saved_uses(node, k, p):
        return max(0, min(k, p + k - 1))

    def accumulated_uses(node, k):
        return k

    return general_complexity(
        behavior=behavior,
        used_nodes_all=used_nodes_all,
        previous_used_nodes=previous_used_nodes,
        scomplexity=saved_uses,
        kcomplexity=accumulated_uses,
    )
113 |
--------------------------------------------------------------------------------
/tests/test_loop.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import pytest_check as check
3 |
4 | import networkx as nx
5 | from hebg.unrolling import unroll_graph
6 | from tests import plot_graph
7 |
8 | from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors
9 | from tests.examples.behaviors.loop_without_alternative import (
10 | build_looping_behaviors_without_direct_alternatives,
11 | )
12 |
13 |
class TestLoopAlternative:
    """Tests for the loop with alternative example"""

    @pytest.fixture(autouse=True)
    def setup_method(self):
        """Build the looping example behaviors before each test."""
        self.gather_wood, self.get_new_axe = build_looping_behaviors()

    def test_unroll_gather_wood(self):
        """should unroll 'gather wood' into the expected graph structure."""
        draw = False  # Switch to True to inspect the unrolled graph visually.
        unrolled_graph = unroll_graph(self.gather_wood.graph, add_prefix=True)
        if draw:
            plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)

        expected_graph = nx.DiGraph()
        expected_graph.add_edge("Has axe", "Punch tree")
        expected_graph.add_edge("Has axe", "Cut tree with axe")
        expected_graph.add_edge("Has axe", "Has wood")

        # Expected sub-behavior
        expected_graph.add_edge("Has wood", "Gather wood")
        expected_graph.add_edge("Has wood", "Craft axe")
        expected_graph.add_edge("Has wood", "Summon axe out of thin air")
        check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))

    def test_unroll_get_new_axe(self):
        """should unroll 'get new axe' into the expected graph structure."""
        draw = False  # Switch to True to inspect the unrolled graph visually.
        unrolled_graph = unroll_graph(self.get_new_axe.graph, add_prefix=True)
        if draw:
            plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)

        expected_graph = nx.DiGraph()
        expected_graph.add_edge("Has wood", "Has axe")
        expected_graph.add_edge("Has wood", "Craft new axe")
        expected_graph.add_edge("Has wood", "Summon axe out of thin air")

        # Expected sub-behavior
        expected_graph.add_edge("Has axe", "Punch tree")
        expected_graph.add_edge("Has axe", "Cut tree with axe")
        expected_graph.add_edge("Has axe", "Get new axe")
        check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))

    def test_unroll_gather_wood_cutting_alternatives(self):
        """should drop the looping alternative when unrolling 'gather wood'."""
        draw = False  # Switch to True to inspect the unrolled graph visually.
        unrolled_graph = unroll_graph(
            self.gather_wood.graph, add_prefix=True, cut_looping_alternatives=True
        )
        if draw:
            plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)

        expected_graph = nx.DiGraph()
        expected_graph.add_edge("Has axe", "Punch tree")
        expected_graph.add_edge("Has axe", "Has wood")
        expected_graph.add_edge("Has axe", "Use axe")

        # Expected sub-behavior
        # Label typo fixed ("of out" -> "out of"); only the structure is
        # compared by is_isomorphic, so this does not change the outcome.
        expected_graph.add_edge("Has wood", "Summon axe out of thin air")
        expected_graph.add_edge("Has wood", "Craft axe")

        check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))

    def test_unroll_get_new_axe_cutting_alternatives(self):
        """should drop the looping alternative when unrolling 'get new axe'."""
        draw = False  # Switch to True to inspect the unrolled graph visually.
        unrolled_graph = unroll_graph(
            self.get_new_axe.graph, add_prefix=True, cut_looping_alternatives=True
        )
        if draw:
            plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)

        expected_graph = nx.DiGraph(
            [
                ("Has wood", "Has axe"),
                ("Has wood", "Craft new axe"),
                ("Has wood", "Summon axe out of thin air"),
                # Expected sub-behavior
                ("Has axe", "Punch tree"),
                ("Has axe", "Cut tree with axe"),
            ]
        )
        check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))

    @pytest.mark.xfail
    def test_unroll_root_alternative_reach_forest(self):
        """should unroll 'reach forest' when alternatives are at the roots.

        Marked as an expected failure.
        """
        (
            reach_forest,
            _reach_other_zone,
            _reach_meadow,
        ) = build_looping_behaviors_without_direct_alternatives()
        draw = False  # Switch to True to inspect the unrolled graph visually.
        unrolled_graph = unroll_graph(
            reach_forest.graph,
            add_prefix=True,
            cut_looping_alternatives=True,
        )
        if draw:
            plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)

        expected_graph = nx.DiGraph(
            [
                # ("Root", "Is in other zone ?"),
                # ("Root", "Is in meadow ?"),
                ("Is in other zone ?", "Reach other zone"),
                ("Is in other zone ?", "Go to forest"),
                ("Is in meadow ?", "Go to forest"),
                # The target below must be the very same node as the source of
                # the next two edges; a differently spelled label would create
                # an extra disconnected node in the expected graph.
                ("Is in meadow ?", "Reach meadow>Is in other zone ?"),
                ("Reach meadow>Is in other zone ?", "Reach meadow>Reach other zone"),
                ("Reach meadow>Is in other zone ?", "Reach meadow>Go to forest"),
            ]
        )
        check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))
123 |
--------------------------------------------------------------------------------
/tests/test_behavior.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Behavior of HEBGraphs when called."""
5 |
6 | import pytest
7 | import pytest_check as check
8 | from pytest_mock import MockerFixture
9 |
10 | from hebg.behavior import Behavior
11 | from hebg.heb_graph import HEBGraph
12 | from hebg.node import Action
13 |
14 | from tests.examples.behaviors import FundamentalBehavior, F_A_Behavior, F_F_A_Behavior
15 | from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors
16 | from tests.examples.feature_conditions import ThresholdFeatureCondition
17 |
18 |
class TestBehavior:

    """Behavior"""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Initialize variables."""
        self.node = Behavior("behavior_name")

    def test_node_type(self):
        """should have 'behavior' as node_type."""
        check.equal(self.node.type, "behavior")

    def test_node_call(self, mocker: MockerFixture):
        """should use graph on call."""
        mocker.patch("hebg.behavior.Behavior.graph")
        self.node(None)
        check.is_true(self.node.graph.called)

    def test_build_graph(self):
        """should raise NotImplementedError when build_graph is called."""
        with pytest.raises(NotImplementedError):
            self.node.build_graph()

    def test_graph(self, mocker: MockerFixture):
        """should build graph and compute its levels if, and only if,
        the graph is not yet built.
        """
        # Keep a handle on both mocks: the second assertion of each pair used
        # to re-check build_graph, leaving compute_levels unverified.
        build_graph_mock = mocker.patch("hebg.behavior.Behavior.build_graph")
        compute_levels_mock = mocker.patch("hebg.behavior.compute_levels")
        self.node.graph
        check.is_true(build_graph_mock.called)
        check.is_true(compute_levels_mock.called)

        # Fresh mocks: a second access must not rebuild the graph.
        build_graph_mock = mocker.patch("hebg.behavior.Behavior.build_graph")
        compute_levels_mock = mocker.patch("hebg.behavior.compute_levels")
        self.node.graph
        check.is_false(build_graph_mock.called)
        check.is_false(compute_levels_mock.called)
58 |
59 |
class TestPathfinding:
    """Path selection when a behavior is called on an observation."""

    def test_fundamental_behavior(self):
        """Fundamental behavior (single action) should return its action."""
        expected = 42
        fundamental = FundamentalBehavior(Action(expected))
        check.equal(fundamental(None), expected)

    def test_feature_condition_single(self):
        """Feature condition should orient path properly."""
        condition = ThresholdFeatureCondition(relation=">=", threshold=0)
        behavior = F_A_Behavior("F_A", condition, {0: Action(0), 1: Action(1)})
        for observation, expected in ((1, 1), (-1, 0)):
            check.equal(behavior(observation), expected)

    def test_feature_conditions_chained(self):
        """Feature condition should orient path properly in double chain."""
        behavior = F_F_A_Behavior("F_F_A")
        for observation, expected in ((-2, 0), (-1, 1), (1, 2), (2, 3)):
            check.equal(behavior(observation), expected)

    def test_looping_resolve(self):
        """Loops with alternatives should be ignored."""
        _gather_wood, get_axe = build_looping_behaviors()
        chosen_action = get_axe({})
        check.equal(chosen_action, "Punch tree")
87 |
88 |
class TestCostBehavior:
    """Influence of node complexities on the chosen path."""

    def test_choose_root_of_lesser_cost(self):
        """Should choose root of lesser cost."""

        expected_action = "EXPECTED"

        class AAA_Behavior(Behavior):
            """Three root actions with different complexities."""

            def __init__(self) -> None:
                super().__init__("AAA")

            def build_graph(self) -> HEBGraph:
                graph = HEBGraph(self)
                # Insertion order is kept: the cheapest action is the second one.
                for action, complexity in ((0, 2), (expected_action, 1), (2, 3)):
                    graph.add_node(Action(action, complexity=complexity))
                return graph

        chosen = AAA_Behavior()(None)
        check.equal(chosen, expected_action)

    def test_not_path_of_least_cost(self):
        """Should choose path of larger complexity if individual costs lead to it."""

        class AF_A_Behavior(Behavior):

            """Double root with feature condition and action"""

            def __init__(self) -> None:
                super().__init__("AF_A")

            def build_graph(self) -> HEBGraph:
                graph = HEBGraph(self)

                # First root: a lone action.
                graph.add_node(Action(0, complexity=1.5))

                # Second root: a feature condition leading to one action per outcome.
                condition = ThresholdFeatureCondition(
                    relation=">=", threshold=0, complexity=1.0
                )
                for action, outcome in ((1, True), (2, False)):
                    graph.add_edge(
                        condition, Action(action, complexity=1.0), index=int(outcome)
                    )

                return graph

        behavior = AF_A_Behavior()
        for observation, expected in ((1, 1), (-1, 2)):
            check.equal(behavior(observation), expected)
139 |
--------------------------------------------------------------------------------
/src/hebg/heb_graph.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 | # pylint: disable=arguments-differ
4 |
5 | """Module containing the HEBGraph base class."""
6 |
7 | from __future__ import annotations
8 |
9 | from typing import Any, Dict, List, Optional, Tuple
10 |
11 | from matplotlib.axes import Axes
12 | from networkx import DiGraph
13 |
14 | from hebg.behavior import Behavior
15 | from hebg.call_graph import CallGraph
16 | from hebg.codegen import get_hebg_source
17 | from hebg.draw import draw_hebgraph
18 | from hebg.graph import get_roots
19 | from hebg.node import Node
20 | from hebg.unrolling import unroll_graph
21 |
22 |
class HEBGraph(DiGraph):
    """Base class for Hierarchical Explanation of Behavior as Graphs.

    An HEBGraph is a DiGraph, and as such stores nodes and directed edges with
    optional data, or attributes.

    But nodes of an HEBGraph are not arbitrary.
    Leaf nodes can either be an Action or a Behavior.
    Other nodes can either be a FeatureCondition or an EmptyNode.

    An HEBGraph determines a behavior, it can be called with an observation
    to return the action given by this behavior.

    An HEBGraph's edges are directed and indexed,
    this indexing is used for path making when calling the graph.

    As in a DiGraph loops are allowed but multiple (parallel) edges are not.

    Args:
        behavior: The Behavior object from which this graph is built.
        all_behaviors: A dictionary of behaviors, this can be used to avoid circular
            definitions using the behavior names as anchor instead of the behavior
            object itself. Defaults to an empty dictionary.
        incoming_graph_data: Additional data to include in the graph.

    """

    # Default border color of a node, by node type.
    NODES_COLORS = {"feature_condition": "blue", "action": "red", "behavior": "orange"}
    # Default color of an edge, by edge index.
    EDGES_COLORS = {
        0: "red",
        1: "green",
        2: "blue",
        3: "yellow",
        4: "purple",
        5: "cyan",
        6: "gray",
    }

    def __init__(
        self,
        behavior: Behavior,
        all_behaviors: Optional[Dict[str, Behavior]] = None,
        incoming_graph_data=None,
        **attr,
    ):
        self.behavior = behavior
        self.all_behaviors = all_behaviors if all_behaviors is not None else {}

        # Cache for the unrolled graph, filled on first access.
        self._unrolled_graph = None
        # Call graph of the last call, None until the graph has been called.
        self.call_graph: Optional[CallGraph] = None

        super().__init__(incoming_graph_data=incoming_graph_data, **attr)

    def add_node(self, node_for_adding: Node, **attr):
        """Add a node, storing its type, color and image as node attributes.

        The 'type' and 'image' attributes are always taken from the node itself,
        any value given for them in attr is discarded. If no 'color' is given,
        the default color of the node type is used (None for unknown types).

        Args:
            node_for_adding: The node to add.
            attr: Additional node attributes.

        """
        node = node_for_adding
        color = attr.pop("color", None)
        attr.pop("type", None)
        attr.pop("image", None)
        if color is None:
            color = self.NODES_COLORS.get(node.type)
        super().add_node(node, type=node.type, color=color, image=node.image, **attr)

    def add_edge(self, u_of_edge: Node, v_of_edge: Node, index: int = 1, **attr):
        """Add an indexed edge, adding its end nodes first if needed.

        If no 'color' is given, the default color of the edge index is used
        ("black" for indexes without a default color).

        Args:
            u_of_edge: Source node of the edge.
            v_of_edge: Target node of the edge.
            index: Index of the edge.
            attr: Additional edge attributes.

        """
        # Go through self.add_node so that end nodes get their HEBGraph attributes.
        for node in (u_of_edge, v_of_edge):
            if node not in self.nodes():
                self.add_node(node)

        color = attr.pop("color", None)
        if color is None:
            color = self.EDGES_COLORS.get(index, "black")
        super().add_edge(u_of_edge, v_of_edge, index=index, color=color, **attr)

    @property
    def unrolled_graph(self) -> HEBGraph:
        """Access to the unrolled behavior graph.

        The unrolled behavior graph has the same behavior but every behavior node is
        recursively replaced by its behavior graph if it can be computed.

        Only builds the graph the first time called for efficiency.

        Returns:
            This HEBGraph's unrolled HEBGraph.

        """
        if self._unrolled_graph is None:
            self._unrolled_graph = unroll_graph(self)
        return self._unrolled_graph

    def __call__(
        self,
        observation,
        call_graph: Optional[CallGraph] = None,
    ) -> Any:
        """Call the graph on an observation, starting from its roots.

        Args:
            observation: Observation given to the nodes of the graph.
            call_graph: Call graph to continue. If None, a new one is created
                with this graph's behavior as root.

        Returns:
            The result of calling the roots through the call graph.

        """
        if call_graph is None:
            call_graph = CallGraph()
            call_graph.add_root(heb_node=self.behavior, heb_graph=self)
        self.call_graph = call_graph
        return self.call_graph.call_nodes(self.roots, observation, heb_graph=self)

    @property
    def roots(self) -> List[Node]:
        """Roots of the behavior graph (nodes without predecessors)."""
        return get_roots(self)

    def generate_source_code(self) -> str:
        """Generated source code of the behavior from graph."""
        return get_hebg_source(self)

    def draw(
        self, ax: "Axes", **kwargs
    ) -> Tuple["Axes", Dict[Node, Tuple[float, float]]]:
        """Draw the HEBGraph on the given Axis.

        Args:
            ax: The matplotlib ax to draw on.

        Kwargs:
            Forwarded as-is to draw_hebgraph, for example:
            fontcolor: Font color to use for all texts.

        Returns:
            The value returned by draw_hebgraph, annotated as the matplotlib Axis
            drawn on and a dictionary of each node position.

        """
        return draw_hebgraph(self, ax, **kwargs)
153 |
--------------------------------------------------------------------------------
/src/hebg/draw.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | import math
5 | from typing import TYPE_CHECKING, Dict, Optional, Tuple
6 |
7 | import matplotlib.patches as mpatches
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | from matplotlib.axes import Axes
11 | from matplotlib.legend import Legend
12 | from matplotlib.legend_handler import HandlerPatch
13 | from networkx import draw_networkx_edges, spring_layout
14 | from scipy.spatial import ConvexHull # pylint: disable=no-name-in-module
15 |
16 | from hebg.graph import draw_networkx_nodes_images
17 | from hebg.layouts import staircase_layout
18 | from hebg.unrolling import group_behaviors_points
19 |
20 | if TYPE_CHECKING:
21 | from hebg.heb_graph import HEBGraph
22 | from hebg.node import Node
23 |
24 |
def draw_hebgraph(
    graph: "HEBGraph",
    ax: "Axes",
    pos: Optional[Dict["Node", Tuple[float, float]]] = None,
    fontcolor: str = "black",
    draw_hulls: bool = False,
    show_all_hulls: bool = False,
) -> Tuple["Axes", Dict["Node", Tuple[float, float]]]:
    """Draw an HEBGraph on the given matplotlib Axes.

    Args:
        graph: The HEBGraph to draw.
        ax: The matplotlib Axes to draw on.
        pos: Position of each node. If None, a staircase layout is used when the
            graph has roots and a spring layout otherwise.
        fontcolor: Font color to use for the legend texts.
        draw_hulls: Whether to draw hulls around groups of nodes.
        show_all_hulls: Forwarded to group_and_draw_hulls when draw_hulls is True.

    Returns:
        The Axes drawn on and the dictionary of node positions that was used.
        For an empty graph nothing is drawn and the given positions (or an
        empty dictionary) are returned.

    """
    if len(graph.nodes()) == 0:
        # Nothing to draw, but still honor the declared return type.
        return ax, (pos if pos is not None else {})

    ax.set_title(graph.behavior.name, fontdict={"color": "orange"})
    plt.setp(ax.spines.values(), color="orange")

    if pos is None:
        if len(graph.roots) > 0:
            pos = staircase_layout(graph)
        else:
            pos = spring_layout(graph)
    draw_networkx_nodes_images(graph, pos, ax=ax, img_zoom=0.5)

    draw_networkx_edges(
        graph,
        pos,
        ax=ax,
        arrowsize=20,
        arrowstyle="-|>",
        min_source_margin=0,
        min_target_margin=10,
        node_shape="s",
        node_size=1500,
        edge_color=[color for _, _, color in graph.edges(data="color")],
    )

    legend = draw_graph_legend(graph, ax)
    plt.setp(legend.get_texts(), color=fontcolor)

    if draw_hulls:
        group_and_draw_hulls(graph, pos, ax, show_all_hulls=show_all_hulls)

    # The function used to fall off the end and return None despite its annotation.
    return ax, pos
64 |
65 |
def draw_graph_legend(graph: "HEBGraph", ax: Axes) -> Legend:
    """Draw a legend of the node types and edge indexes used in the graph.

    Only node types and edge indexes that are both used in the graph and have
    a default color in the graph get a legend entry.

    Args:
        graph: The HEBGraph whose legend is drawn.
        ax: The matplotlib Axes to draw the legend on.

    Returns:
        The legend created on the given Axes.

    """
    present_types = [node_type for _, node_type in graph.nodes(data="type")]
    type_handles = []
    for node_type, color in graph.NODES_COLORS.items():
        if node_type not in present_types:
            continue
        type_handles.append(
            mpatches.Patch(
                facecolor="none", edgecolor=color, label=node_type.capitalize()
            )
        )

    present_indexes = [index for _, _, index in graph.edges(data="index")]
    index_handles = []
    for index, color in graph.EDGES_COLORS.items():
        if index not in present_indexes:
            continue
        if index > 1:
            label = str(index)
        else:
            # Indexes 0 and 1 are also shown as booleans.
            label = f"{bool(index)} ({index})"
        index_handles.append(
            mpatches.FancyArrow(
                0, 0, 1, 0, facecolor=color, edgecolor="none", label=label
            )
        )

    def _legend_arrow(width, height, **kwargs):
        # Shape used to render FancyArrow handles inside the legend box.
        return mpatches.FancyArrow(
            0,
            0.5 * height,
            width,
            0,
            width=0.2 * height,
            length_includes_head=True,
            head_width=height,
            overhang=0.5,
        )

    # Draw the legend
    return ax.legend(
        fancybox=True,
        framealpha=0,
        fontsize="x-large",
        loc="upper right",
        handles=type_handles + index_handles,
        handler_map={mpatches.FancyArrow: HandlerPatch(patch_func=_legend_arrow)},
    )
107 |
108 |
def group_and_draw_hulls(graph: "HEBGraph", pos, ax: Axes, show_all_hulls: bool):
    """Group node positions with group_behaviors_points and draw a hull per group.

    Unless show_all_hulls is set, a group is only kept when the last element of
    its key is shared with at least one other group and differs from the
    element before it (keys of length one always pass that second check).
    Groups with fewer than three points are never drawn.

    Args:
        graph: The HEBGraph being drawn.
        pos: Layout positions of the graph nodes.
        ax: The matplotlib Axes to draw on.
        show_all_hulls: If True, no group is filtered out.

    """
    groups = group_behaviors_points(pos, graph)

    if not show_all_hulls:
        # Count how many groups end with the same key element.
        last_count = {}
        for key in groups:
            last_count[key[-1]] = last_count.get(key[-1], 0) + 1

        kept = {}
        for key, points in groups.items():
            if last_count[key[-1]] <= 1:
                continue
            if len(key) != 1 and key[-1] == key[-2]:
                continue
            kept[key] = points
        groups = kept

    for key, points in groups.items():
        if len(points) < 3:
            continue
        # Longer keys get a slightly tighter hull.
        stretch = 0.5 - 0.05 * (len(key) - 1)
        draw_convex_hull(points, ax, stretch=stretch, lw=3, color="orange")
125 |
126 |
def draw_convex_hull(points, ax: "Axes", stretch=0.3, n_points=30, **kwargs):
    """Plot the closed outline of a buffered convex hull around the given points.

    Args:
        points: Points to surround.
        ax: The matplotlib Axes to plot on.
        stretch: Radius of the circles sampled around each hull vertex.
        n_points: Number of samples used for each of those circles.
        kwargs: Forwarded to ax.plot.

    """
    raw = np.array(points)
    # Only the vertices of a first hull are buffered, which keeps the
    # buffering step small.
    corners = raw[ConvexHull(raw).vertices]
    buffered = buffer_points(corners, stretch=stretch, samples=n_points)

    outline = ConvexHull(buffered).vertices
    # Repeat the first vertex so that the plotted line is closed.
    closed = np.concatenate((outline, outline[:1]))
    ax.plot(buffered[closed, 0], buffered[closed, 1], **kwargs)
135 |
136 |
def buffer_points(inside_points, stretch, samples):
    """Return the convex hull vertices of circles sampled around each point.

    Args:
        inside_points: Points to buffer.
        stretch: Radius of the circle sampled around each point.
        samples: Number of samples given to points_in_circum for each circle.

    Returns:
        Array of the sampled points that are vertices of their convex hull.

    """
    ring_points = []
    for center in inside_points:
        ring_points.extend(points_in_circum(center, stretch, samples))
    ring_points = np.array(ring_points)
    return ring_points[ConvexHull(ring_points).vertices]
144 |
145 |
def points_in_circum(points, radius, samples=100):
    """Sample points on the circle of given radius around a center.

    Args:
        points: The center, indexed as (x, y).
        radius: Radius of the circle.
        samples: Number of angular steps in a full turn.

    Returns:
        List of ``samples + 1`` (x, y) tuples, from angle 0 to a full turn
        included.

    """
    ring = []
    for step in range(0, samples + 1):
        angle = 2 * math.pi / samples * step
        ring.append(
            (
                points[0] + math.cos(angle) * radius,
                points[1] + math.sin(angle) * radius,
            )
        )
    return ring
154 |
--------------------------------------------------------------------------------
/tests/test_pet_a_cat.py:
--------------------------------------------------------------------------------
1 | """This example shows how we could hierarchically build a behavior to pet a cat.
2 |
3 | Here is the hierarchical structure that we would want:
4 |
5 | ```
6 | PetACat:
7 | IsThereACatAround ?
8 | -> Yes:
9 | PetNearbyCat
10 | -> No:
11 | LookForACat
12 |
13 | PetNearbyCat:
14 | IsYourHandNearTheCat ?
15 | -> Yes:
16 | Pet
17 | -> No:
18 | MoveYourHandNearTheCat
19 | ```
20 |
21 | """
22 |
23 | import pytest
24 | import pytest_check as check
25 |
26 | import matplotlib.pyplot as plt
27 |
28 | from hebg import HEBGraph, Action, FeatureCondition, Behavior
29 | from hebg.unrolling import unroll_graph
30 | from tests.test_code_generation import _unidiff_output
31 |
32 |
class Pet(Action):
    """Action of petting the cat."""

    def __init__(self) -> None:
        action_name = "Pet"
        super().__init__(action=action_name)
36 |
37 |
class IsYourHandNearTheCat(FeatureCondition):
    """Feature condition telling whether the given hand is where the cat is."""

    def __init__(self, hand) -> None:
        super().__init__(name="Is hand near the cat ?")
        # Observation key under which the location of the hand is stored.
        self.hand = hand

    def __call__(self, observation):
        # Could be a very complex function that returns 1 if the hand is near the cat else 0.
        hand_is_near = observation["cat"] == observation[self.hand]
        return int(bool(hand_is_near))  # 1 when near, 0 otherwise
48 |
49 |
class MoveYourHandNearTheCat(Behavior):
    """Behavior without a graph: it directly returns an action when called."""

    def __init__(self) -> None:
        super().__init__(name="Move slowly your hand near the cat")

    def __call__(self, observation, *args, **kwargs) -> Action:
        # Could be a very complex function that returns actions from any given observation
        chosen_action = Action("Move hand to cat")
        return chosen_action
57 |
58 |
class PetNearbyCat(Behavior):
    """Pet the cat if the hand is near it, otherwise move the hand first."""

    def __init__(self) -> None:
        super().__init__(name="Pet nearby cat")

    def build_graph(self) -> HEBGraph:
        """Build the graph: one feature condition leading to two outcomes."""
        condition = IsYourHandNearTheCat(hand="hand")
        graph = HEBGraph(self)
        # Edge index 0 is followed when the condition is False, 1 when True.
        graph.add_edge(condition, MoveYourHandNearTheCat(), index=0)
        graph.add_edge(condition, Pet(), index=1)
        return graph
70 |
71 |
class IsThereACatAround(FeatureCondition):
    """Feature condition telling whether the observation mentions a cat."""

    def __init__(self) -> None:
        super().__init__(name="Is there a cat around ?")

    def __call__(self, observation):
        # Could be a very complex function that returns 1 if there is a cat around else 0.
        cat_is_around = "cat" in observation
        return int(bool(cat_is_around))  # 1 when around, 0 otherwise
81 |
82 |
class LookForACat(Behavior):
    """Behavior without a graph: it directly returns an action when called."""

    def __init__(self) -> None:
        super().__init__(name="Look for a nearby cat")

    def __call__(self, observation, *args, **kwargs) -> Action:
        # Could be a very complex function that returns actions from any given observation
        chosen_action = Action("Move to a cat")
        return chosen_action
90 |
91 |
class PetACat(Behavior):
    """Pet a nearby cat if there is one around, otherwise look for a cat."""

    def __init__(self) -> None:
        super().__init__(name="Pet a cat")

    def build_graph(self) -> HEBGraph:
        """Build the graph: one feature condition leading to two sub-behaviors."""
        condition = IsThereACatAround()
        graph = HEBGraph(self)
        # Edge index 0 is followed when the condition is False, 1 when True.
        graph.add_edge(condition, LookForACat(), index=0)
        graph.add_edge(condition, PetNearbyCat(), index=1)
        return graph
102 |
103 |
class TestPetACat:
    """PetACat example"""

    @pytest.fixture(autouse=True)
    def setup_method(self):
        """Create fresh PetNearbyCat and PetACat behaviors before each test."""
        self.pet_nearby_cat_behavior = PetNearbyCat()
        self.pet_a_cat_behavior = PetACat()

    def test_call(self):
        """should give expected call"""
        # The cat and the hand are at different places, so the hand has to move.
        observation = {
            "cat": "sofa",
            "hand": "computer",
        }

        # Call on observation
        action = self.pet_a_cat_behavior(observation)
        check.equal(action, Action("Move hand to cat"))

    def test_pet_a_cat_graph_edges(self):
        """should give expected edges for PetACat"""
        # Obtain networkx graph
        graph = self.pet_a_cat_behavior.graph
        # Edges are compared as (source, target, index) against plain strings.
        check.equal(
            set(graph.edges(data="index")),
            {
                ("Is there a cat around ?", "Look for a nearby cat", 0),
                ("Is there a cat around ?", "Pet nearby cat", 1),
            },
        )

    def test_pet_nearby_cat_graph_edges(self):
        """should give expected edges for PetNearbyCat"""
        # Obtain networkx graph
        graph = self.pet_nearby_cat_behavior.graph
        # Edges are compared as (source, target, index) against plain strings.
        check.equal(
            set(graph.edges(data="index")),
            {
                ("Is hand near the cat ?", "Move slowly your hand near the cat", 0),
                ("Is hand near the cat ?", "Action(Pet)", 1),
            },
        )

    def test_draw(self):
        """should be able to draw without error"""
        fig, ax = plt.subplots()
        self.pet_a_cat_behavior.graph.draw(ax)
        # Close the figure so that it does not stay open after the test.
        plt.close(fig)

    def test_draw_unrolled(self):
        """should be able to draw without error"""
        fig, ax = plt.subplots()
        unrolled_graph = unroll_graph(self.pet_a_cat_behavior.graph)
        unrolled_graph.draw(ax)
        # Close the figure so that it does not stay open after the test.
        plt.close(fig)

    @pytest.mark.filterwarnings("ignore:Could not load graph for behavior")
    def test_codegen(self):
        """should generate expected source code"""
        code = self.pet_a_cat_behavior.graph.generate_source_code()
        # Expected source, one line per element.
        expected_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "# Require 'Look for a nearby cat' behavior to be given.",
                "# Require 'Move slowly your hand near the cat' behavior to be given.",
                "class PetACat(GeneratedBehavior):",
                " def __call__(self, observation):",
                " edge_index = self.feature_conditions['Is there a cat around ?'](observation)",
                " if edge_index == 0:",
                " return self.known_behaviors['Look for a nearby cat'](observation)",
                " if edge_index == 1:",
                " edge_index_1 = self.feature_conditions['Is hand near the cat ?'](observation)",
                " if edge_index_1 == 0:",
                " return self.known_behaviors['Move slowly your hand near the cat'](observation)",
                " if edge_index_1 == 1:",
                " return self.actions['Action(Pet)'](observation)",
            )
        )
        # The third argument is the message shown on failure: a diff of both sources.
        check.equal(code, expected_code, _unidiff_output(code, expected_code))
184 |
--------------------------------------------------------------------------------
/src/hebg/graph.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 | # pylint: disable=protected-access
4 |
5 | """Additional utility functions for networkx graphs."""
6 |
7 | from typing import TYPE_CHECKING, Any, Dict, List
8 |
9 | from matplotlib.axes import Axes
10 | from matplotlib.offsetbox import AnnotationBbox, OffsetImage
11 | from networkx import DiGraph
12 |
13 | if TYPE_CHECKING:
14 | from hebg.heb_graph import HEBGraph
15 | from hebg.node import Node
16 |
17 |
def get_roots(graph: DiGraph):
    """Find the nodes of a DiGraph that have no predecessor.

    Args:
        graph: A networkx DiGraph.

    Returns:
        List of root nodes, in the iteration order of the graph nodes.

    """
    return [node for node in graph.nodes() if not list(graph.predecessors(node))]
33 |
34 |
def get_nodes_by_level(graph: DiGraph) -> Dict[int, Any]:
    """Group the nodes of a graph by their 'level' attribute.

    Requires nodes to have a 'level' attribute.
    Also stores the result in the 'nodes_by_level' graph attribute and the
    highest level in the 'depth' graph attribute.

    Args:
        graph: A networkx DiGraph.

    Returns:
        Dictionary mapping each level to the list of nodes of that level.

    """
    grouped = {}
    for node in graph.nodes():
        node_level = graph.nodes[node]["level"]
        grouped.setdefault(node_level, []).append(node)

    graph.graph["nodes_by_level"] = grouped
    graph.graph["depth"] = max(grouped)
    return grouped
58 |
59 |
def compute_levels(graph: DiGraph):
    """Compute the hierarchical levels of all DiGraph nodes.

    Adds the attribute 'level' to each node in the given graph.
    Adds the attribute 'nodes_by_level' to the given graph.
    Adds the attribute 'depth' to the given graph.

    Nodes without predecessors get level 0. Any other node gets one more than
    the maximum, over the indexes of its incoming edges, of the minimum known
    level among the predecessors connected with that index.

    Args:
        graph: A networkx DiGraph whose edges have an 'index' attribute.

    Returns:
        Dictionary of nodes by level.

    Raises:
        ValueError: If some nodes still have no level after as many passes as
            there are nodes in the graph.

    """

    def _compute_level_dependencies(graph: DiGraph, node):
        # Try to set the level of a node; return False when it cannot be set yet.
        predecessors = list(graph.predecessors(node))
        if len(predecessors) == 0:
            graph.nodes[node]["level"] = 0
            return True

        # Known levels (None when not computed yet) of predecessors, by edge index.
        pred_level_by_index = {}
        for pred in predecessors:
            index = graph.edges[pred, node]["index"]
            try:
                pred_level = graph.nodes[pred]["level"]
            except KeyError:
                pred_level = None

            if index in pred_level_by_index:
                pred_level_by_index[index].append(pred_level)
            else:
                pred_level_by_index[index] = [pred_level]

        min_level_by_index = []
        for index, level_list in pred_level_by_index.items():
            level_list_wo_none = [lvl for lvl in level_list if lvl is not None]
            if len(level_list_wo_none) == 0:
                # No predecessor with this index has a level yet: retry in a later pass.
                return False
            min_level_by_index.append(min(level_list_wo_none))
        level = 1 + max(min_level_by_index)
        graph.nodes[node]["level"] = level
        return True

    # Repeat passes over all nodes until every node has a level,
    # with at most one pass per node of the graph.
    # NOTE(review): for a graph without nodes the loop body never runs, so
    # all_nodes_have_level is not assigned before being read below.
    for _ in range(len(graph.nodes())):
        all_nodes_have_level = True
        incomplete_nodes = []
        for node in graph.nodes():
            incomplete = not _compute_level_dependencies(graph, node)
            if incomplete:
                incomplete_nodes.append(node)
                all_nodes_have_level = False
        if all_nodes_have_level:
            break

    if not all_nodes_have_level:
        raise ValueError(
            "Could not attribute levels to all nodes. "
            f"Incomplete nodes: {incomplete_nodes}"
        )

    return get_nodes_by_level(graph)
122 |
123 |
def compute_edges_color(graph: DiGraph):
    """Adjust edge transparency and line style of a leveled graph for readability.

    Requires nodes to have a 'level' attribute and edges a 'color' attribute.
    Edges that do not go to a strictly higher level get a 'dashed' linestyle.
    When the edge color is a list, its fourth component (alpha) is overwritten:
    the more successors the source node has, the more transparent the edge.

    Args:
        graph: A networkx DiGraph.

    """
    # Alpha by number of successors (position 0 is for a single successor).
    alpha_by_fanout = [1, 1, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3]
    default_alpha = 0.2

    for source in graph.nodes():
        targets = list(graph.successors(source))
        fanout = len(targets)
        for target in targets:
            goes_up = graph.nodes[source]["level"] < graph.nodes[target]["level"]
            edge_attrs = graph.edges[source, target]

            edge_alpha = default_alpha
            if not goes_up:
                edge_attrs["linestyle"] = "dashed"
            elif fanout < len(alpha_by_fanout):
                edge_alpha = alpha_by_fanout[fanout - 1]

            edge_color = edge_attrs["color"]
            if isinstance(edge_color, list):
                edge_color[3] = edge_alpha
148 |
149 |
def draw_networkx_nodes_images(
    graph: DiGraph, pos, ax: Axes, img_zoom: float = 1, **kwargs
):
    """Draw each node of a networkx DiGraph as its image, or as its label.

    The 'image' node attribute is drawn in a rounded box whose border uses the
    'color' node attribute ("black" when missing). Nodes without an image are
    drawn as a text label in a box of that same color instead.

    Args:
        graph: A networkx DiGraph.
        pos: Layout positions of the graph.
        ax: A matplotlib Axes.
        img_zoom (Optional): Zoom to apply to images.
        kwargs: Optional text settings for labels (font_size, font_color,
            font_family, font_weight, alpha, horizontalalignment,
            verticalalignment).

    """
    images = graph.nodes(data="image", default=None)
    colors = graph.nodes(data="color", default="black")

    for node in graph:
        image = images[node]
        border_color = colors[node]

        if image is None:
            # If no image is found, draw label instead
            x, y = pos[node]
            ax.text(
                x,
                y,
                str(node),
                size=kwargs.get("font_size", 12),
                color=kwargs.get("font_color", "k"),
                family=kwargs.get("font_family", "sans-serif"),
                weight=kwargs.get("font_weight", "normal"),
                alpha=kwargs.get("alpha"),
                horizontalalignment=kwargs.get("horizontalalignment", "center"),
                verticalalignment=kwargs.get("verticalalignment", "center"),
                transform=ax.transData,
                bbox={"facecolor": "none", "edgecolor": border_color},
                clip_on=True,
            )
            continue

        # Scale the image by its smallest side and the smallest side of the axes.
        smallest_side = min(image.shape[:2])
        smallest_ax_side = min(ax._position.width, ax._position.height)
        zoom = 100 * img_zoom * smallest_ax_side / smallest_side

        box = AnnotationBbox(
            OffsetImage(image, zoom=zoom),
            pos[node],
            frameon=True,
            box_alignment=(0.5, 0.5),
        )
        frame = box.patch
        frame.set_facecolor("None")
        frame.set_edgecolor(border_color)
        frame.set_linewidth(3)
        frame.set_boxstyle("round", pad=0.15)
        ax.add_artist(box)
200 |
201 |
def get_successors_with_index(
    graph: "HEBGraph", node: "Node", next_edge_index: int
) -> List["Node"]:
    """Get the successors of a node reached through edges of the given index.

    Args:
        graph: The graph containing the node.
        node: The node whose successors are searched.
        next_edge_index: Edge index to match; edge indexes are converted with
            int before being compared to it.

    Returns:
        List of matching successors, in the order given by the graph.

    Raises:
        ValueError: If no outgoing edge of the node has the given index.

    """
    matching = [
        successor
        for successor in graph.successors(node)
        if int(graph.edges[node, successor]["index"]) == next_edge_index
    ]
    if not matching:
        raise ValueError(
            f"FeatureCondition {node} returned index {next_edge_index}"
            f" but {next_edge_index} was not found as an edge index"
        )
    return matching
216 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
HEBG - Hierarchical Explainable Behaviors using Graphs
======================================================
3 |
4 | .. image:: https://badge.fury.io/py/hebg.svg
5 | :alt: [Fury - PyPi stable version]
6 | :target: https://badge.fury.io/py/hebg
7 |
8 | .. image:: https://static.pepy.tech/badge/hebg
9 | :alt: [PePy - Downloads]
10 | :target: https://pepy.tech/project/hebg
11 |
12 | .. image:: https://static.pepy.tech/badge/hebg/week
13 | :alt: [PePy - Downloads per week]
14 | :target: https://pepy.tech/project/hebg
15 |
16 | .. image:: https://app.codacy.com/project/badge/Grade/ec4b296d18f4412398d64a66224c66dd
17 | :alt: [Codacy - Grade]
18 | :target: https://www.codacy.com/gh/IRLL/HEB_graphs/dashboard?utm_source=github.com&utm_medium=referral&utm_content=IRLL/HEB_graphs&utm_campaign=Badge_Grade
19 |
20 | .. image:: https://app.codacy.com/project/badge/Coverage/ec4b296d18f4412398d64a66224c66dd
21 | :alt: [Codacy - Coverage]
22 | :target: https://www.codacy.com/gh/IRLL/HEB_graphs/dashboard?utm_source=github.com&utm_medium=referral&utm_content=IRLL/HEB_graphs&utm_campaign=Badge_Coverage
23 |
24 | .. image:: https://img.shields.io/badge/code%20style-black-000000.svg
25 | :alt: [CodeStyle - Black]
26 | :target: https://github.com/psf/black
27 |
28 | .. image:: https://img.shields.io/github/license/MathisFederico/Crafting?style=plastic
29 | :alt: [Licence - GPLv3]
30 | :target: https://www.gnu.org/licenses/
31 |
32 |
This package is meant to build programmatic hierarchical behaviors as graphs
to compare them to human explanations of behavior.
35 |
36 | We take the definition of "behavior" as a function from observation to action.
37 |
38 |
39 | Installation
40 | ------------
41 |
42 |
43 | .. code-block:: sh
44 |
45 | pip install hebg
46 |
47 |
48 | Usage
49 | -----
50 |
51 | Build a HEBGraph
52 | ~~~~~~~~~~~~~~~~
53 |
Here is an example showing how we could hierarchically build an explainable behavior to pet a cat.
55 |
56 | .. code-block:: py3
57 |
58 | """
59 |
60 | Here is the hierarchical structure that we would want:
61 |
62 | ```
63 | PetACat:
64 | IsThereACatAround ?
65 | -> Yes:
66 | PetNearbyCat
67 | -> No:
68 | LookForACat
69 |
70 | PetNearbyCat:
71 | IsYourHandNearTheCat ?
72 | -> Yes:
73 | Pet
74 | -> No:
75 | MoveYourHandNearTheCat
76 | ```
77 |
78 | """
79 |
80 | from hebg import HEBGraph, Action, FeatureCondition, Behavior
81 | from hebg.unrolling import unroll_graph
82 |
83 | # Add a fundamental action
84 | class Pet(Action):
85 | def __init__(self) -> None:
86 | super().__init__(action="Pet")
87 |
88 | # Add a condition on the observation
89 | class IsYourHandNearTheCat(FeatureCondition):
90 | def __init__(self, hand) -> None:
91 | super().__init__(name="Is hand near the cat ?")
92 | self.hand = hand
93 | def __call__(self, observation) -> int:
            # Could be a very complex function that returns 1 if the hand is near the cat else 0.
95 | if observation["cat"] == observation[self.hand]:
96 | return int(True) # 1
97 | return int(False) # 0
98 |
99 | # Add an unexplainable Behavior (without a graph, but a function that can be called).
100 | class MoveYourHandNearTheCat(Behavior):
101 | def __init__(self) -> None:
102 | super().__init__(name="Move slowly your hand near the cat")
103 | def __call__(self, observation, *args, **kwargs) -> Action:
104 | # Could be a very complex function that returns actions from any given observation
105 | return Action("Move hand to cat")
106 |
107 | # Add a sub-behavior
108 | class PetNearbyCat(Behavior):
109 | def __init__(self) -> None:
110 | super().__init__(name="Pet nearby cat")
111 | def build_graph(self) -> HEBGraph:
112 | graph = HEBGraph(self)
113 | is_hand_near_cat = IsYourHandNearTheCat(hand="hand")
114 | graph.add_edge(is_hand_near_cat, MoveYourHandNearTheCat(), index=int(False))
115 | graph.add_edge(is_hand_near_cat, Pet(), index=int(True))
116 | return graph
117 |
118 | # Add an other condition on observation
119 | class IsThereACatAround(FeatureCondition):
120 | def __init__(self) -> None:
121 | super().__init__(name="Is there a cat around ?")
122 | def __call__(self, observation) -> int:
            # Could be a very complex function that returns 1 if there is a cat around else 0.
124 | if "cat" in observation:
125 | return int(True) # 1
126 | return int(False) # 0
127 |
128 | # Add an other unexplainable Behavior (without a graph, but a function that can be called).
129 | class LookForACat(Behavior):
130 | def __init__(self) -> None:
131 | super().__init__(name="Look for a nearby cat")
132 | def __call__(self, observation, *args, **kwargs) -> Action:
133 | # Could be a very complex function that returns actions from any given observation
134 | return Action("Move to a cat")
135 |
136 | # Finally, add the main Behavior
137 | class PetACat(Behavior):
138 | def __init__(self) -> None:
139 | super().__init__(name="Pet a cat")
140 | def build_graph(self) -> HEBGraph:
141 | graph = HEBGraph(self)
142 | is_a_cat_around = IsThereACatAround()
143 | graph.add_edge(is_a_cat_around, LookForACat(), index=int(False))
144 | graph.add_edge(is_a_cat_around, PetNearbyCat(), index=int(True))
145 | return graph
146 |
147 | if __name__ == "__main__":
148 | pet_a_cat_behavior = PetACat()
149 | observation = {
150 | "cat": "sofa",
151 | "hand": "computer",
152 | }
153 |
154 | # Call on observation
155 | action = pet_a_cat_behavior(observation)
156 | print(action) # Action("Move hand to cat")
157 |
158 | # Obtain networkx graph
159 | graph = pet_a_cat_behavior.graph
160 | print(list(graph.edges(data="index")))
161 |
162 | # Draw graph using matplotlib
163 | import matplotlib.pyplot as plt
164 | fig, ax = plt.subplots()
165 | graph.draw(ax)
166 | plt.show()
167 |
168 |
169 | .. image:: docs/images/PetACatGraph.png
170 | :align: center
171 |
172 | Unrolling HEBGraph
173 | ~~~~~~~~~~~~~~~~~~
174 |
When plotting the HEBGraph of a behavior, only the graph of the behavior itself is shown.
To see the full hierarchical graph (including sub-behaviors), we need to unroll the graph as such:
177 |
178 | .. code-block:: py3
179 |
180 | from hebg.unrolling import unroll_graph
181 |
182 | unrolled_graph = unroll_graph(pet_a_cat_behavior.graph, add_prefix=False)
183 |
184 | # Is also a networkx graph
185 | print(list(unrolled_graph.edges(data="index")))
186 |
187 | # Draw graph using matplotlib
188 | import matplotlib.pyplot as plt
189 | fig, ax = plt.subplots()
190 | unrolled_graph.draw(ax)
191 | plt.show()
192 |
193 |
194 | .. image:: docs/images/PetACatGraphUnrolled.png
195 | :align: center
196 |
Note that unexplainable behaviors (the ones without graphs) are kept as is.
198 |
199 | Python code generation from graph
200 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
201 |
Once you have a HEBGraph, you can use it to generate working Python code that
replicates the HEBGraph's behavior:
204 |
205 | .. code-block:: py3
206 |
207 | code = pet_a_cat_behavior.graph.generate_source_code()
208 | with open("pet_a_cat.py", "w") as pyfile:
209 | pyfile.write(code)
210 |
This will generate the code below:
212 |
213 | .. code-block:: py3
214 |
215 | from hebg.codegen import GeneratedBehavior
216 |
217 | # Require 'Look for a nearby cat' behavior to be given.
218 | # Require 'Move slowly your hand near the cat' behavior to be given.
219 | class PetTheCat(GeneratedBehavior):
220 | def __call__(self, observation) -> Any:
221 | edge_index = self.feature_conditions['Is there a cat around ?'](observation)
222 | if edge_index == 0:
223 | return self.known_behaviors['Look for a nearby cat'](observation)
224 | if edge_index == 1:
225 | edge_index_1 = self.feature_conditions['Is hand near the cat ?'](observation)
226 | if edge_index_1 == 0:
227 | return self.known_behaviors['Move slowly your hand near the cat'](observation)
228 | if edge_index_1 == 1:
229 | return self.actions['Action(Pet)'](observation)
230 |
231 |
232 | Contributing to HEBG
233 | --------------------
234 |
Whenever you encounter a :bug: **bug** or have a :tada: **feature request**,
report this via `Github issues <https://github.com/IRLL/HEB_graphs/issues>`_.

If you wish to contribute directly, see `CONTRIBUTING <CONTRIBUTING.rst>`_
239 |
--------------------------------------------------------------------------------
/src/hebg/unrolling.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Module to unroll HEBGraph.
5 |
Unrolling means expanding each sub-behavior node as its own graph in the global HEBGraph.
7 | Behaviors that do not have a graph (Unexplainable behaviors) should stay as is in the graph.
8 |
9 | """
10 |
11 | from copy import copy
12 | from typing import TYPE_CHECKING, Dict, List, Tuple, Optional, Union
13 |
14 | from networkx import relabel_nodes
15 |
16 | from hebg.behavior import Behavior
17 |
18 | BEHAVIOR_SEPARATOR = ">"
19 |
20 | if TYPE_CHECKING:
21 | from hebg import HEBGraph, Node, Action
22 |
23 |
def unroll_graph(
    graph: "HEBGraph",
    add_prefix: bool = False,
    cut_looping_alternatives: bool = False,
) -> "HEBGraph":
    """Build the unrolled version of a HEBGraph.

    The unrolled graph has the same behavior as the given one, but every behavior
    node is recursively replaced by its own HEBGraph whenever it can be computed.

    Args:
        graph (HEBGraph): HEBGraph to unroll.
        add_prefix (bool, optional): If True, adds a name prefix to keep nodes different.
            Defaults to False.
        cut_looping_alternatives (bool, optional): If True, cut the looping alternatives.
            Defaults to False.

    Returns:
        HEBGraph: The unrolled HEBGraph.
    """
    # The recursive helper also reports whether a loop was met; only the
    # unrolled graph is exposed here.
    unrolled, _looping = _unroll_graph(
        graph,
        add_prefix=add_prefix,
        cut_looping_alternatives=cut_looping_alternatives,
    )
    return unrolled
50 |
51 |
def _unroll_graph(
    graph: "HEBGraph",
    add_prefix: bool = False,
    cut_looping_alternatives: bool = False,
    _current_alternatives: Optional[
        Dict[int, List[Union["Action", "Behavior"]]]
    ] = None,
    _unrolled_behaviors: Optional[Dict[str, Optional["HEBGraph"]]] = None,
) -> Tuple["HEBGraph", bool]:
    """Recursively unroll every Behavior node of a HEBGraph.

    Args:
        graph: HEBGraph to unroll; it is shallow-copied, not modified directly.
        add_prefix: If True, adds a name prefix to keep nodes different.
        cut_looping_alternatives: If True, cut the looping alternatives.
        _current_alternatives: Dictionary shared across the recursion. Key 0 holds
            the direct alternatives and key 1 the roots alternatives of the
            behavior node currently being unrolled.
        _unrolled_behaviors: Dictionary shared across the recursion, mapping a
            behavior name to its unrolled graph, or to None while the behavior is
            still being unrolled.

    Returns:
        The unrolled graph and a boolean that is True if unrolling any of the
        behavior nodes was reported as looping.
    """
    if _unrolled_behaviors is None:
        _unrolled_behaviors = {}
    if _current_alternatives is None:
        _current_alternatives = {0: []}

    is_looping = False
    # A None entry marks this behavior as "being unrolled": meeting the same name
    # again deeper in the recursion is reported as a loop by _unrolled_behavior_graph.
    _unrolled_behaviors[graph.behavior.name] = None

    unrolled_graph: "HEBGraph" = copy(graph)
    # Iterate over a snapshot of the nodes: unrolled_graph is rebound on each
    # unrolled behavior.
    for node in list(unrolled_graph.nodes()):
        if not isinstance(node, Behavior):
            continue

        # Alternatives are computed on the original graph, not on the graph being
        # unrolled, and overwrite the shared entries for this node.
        _current_alternatives[0] = _direct_alternatives(node, graph)
        _current_alternatives[1] = _roots_alternatives(node, graph)
        unrolled_graph, behavior_is_looping = _unroll_behavior(
            unrolled_graph,
            node,
            add_prefix,
            cut_looping_alternatives,
            _current_alternatives,
            _unrolled_behaviors,
        )
        if behavior_is_looping:
            is_looping = True

    return unrolled_graph, is_looping
86 |
87 |
def _direct_alternatives(node: "Node", graph: "HEBGraph") -> list["Node"]:
    """List the siblings of a node that share its incoming edge index.

    For every incoming edge of ``node``, the other successors of the same
    predecessor that are reached through an edge with the same 'index' are
    collected, in iteration order.

    Args:
        node: Node to find the alternatives of.
        graph: Graph containing the node.

    Returns:
        Alternatives of the node (may contain duplicates).
    """
    siblings = []
    for parent, _, edge_data in graph.in_edges(node, data=True):
        wanted_index = edge_data["index"]
        siblings.extend(
            sibling
            for _, sibling, sibling_index in graph.out_edges(parent, data="index")
            if not (wanted_index != sibling_index or sibling == node)
        )
    return siblings
97 |
98 |
def _roots_alternatives(node: "Node", graph: "HEBGraph") -> list["Node"]:
    """List the other roots of the graph for each root that precedes a node.

    Args:
        node: Node to find the alternatives of.
        graph: Graph containing the node; must expose ``roots``.

    Returns:
        For every predecessor of ``node`` that is a root, all the roots other
        than that predecessor (may contain duplicates).
    """
    other_roots = []
    for parent, _, _ in graph.in_edges(node, data=True):
        if parent not in graph.roots:
            continue
        other_roots += [root for root in graph.roots if root != parent]
    return other_roots
105 |
106 |
def _unroll_behavior(
    graph: "HEBGraph",
    behavior: "Behavior",
    add_prefix: bool,
    cut_looping_alternatives: bool,
    _current_alternatives: Dict[int, List[Union["Action", "Behavior"]]],
    _unrolled_behaviors: Dict[str, Optional["HEBGraph"]],
) -> Tuple["HEBGraph", bool]:
    """Unroll a behavior node in a given HEBGraph

    Args:
        graph (HEBGraph): HEBGraph to unroll the behavior in.
        behavior (Behavior): Behavior node to unroll, must be in the given graph.
        add_prefix (bool): If True, adds a name prefix to keep nodes different.
        cut_looping_alternatives (bool): If True, cut the looping alternatives.
        _current_alternatives: Shared dictionary; key 0 holds the direct
            alternatives and key 1 the roots alternatives of the behavior node.
        _unrolled_behaviors: Shared dictionary of already unrolled behaviors,
            indexed by behavior name.

    Returns:
        Tuple[HEBGraph, bool]: The graph with the behavior unrolled (either the
        given graph modified in place or a new composed graph), and whether the
        behavior was found to be looping.

    Raises:
        NotImplementedError: If the behavior is looping, looping alternatives must
            be cut, and neither direct nor roots alternatives are available.
    """
    # Look for name reference: prefer the behavior registered under the same name.
    # NOTE(review): the rebound `behavior` is then used to address the node in
    # `graph`; this presumably relies on behaviors of the same name comparing
    # equal -- confirm against Behavior's equality.
    if behavior.name in graph.all_behaviors:
        behavior = graph.all_behaviors[behavior.name]

    node_graph, is_looping = _unrolled_behavior_graph(
        behavior,
        add_prefix,
        cut_looping_alternatives,
        _current_alternatives,
        _unrolled_behaviors,
    )

    if is_looping and cut_looping_alternatives:
        # First option: reroute every incoming edge of the looping behavior
        # towards each of its direct alternatives, then drop the behavior.
        for alternative in _current_alternatives[0]:
            for last_condition, _, data in graph.in_edges(behavior, data=True):
                graph.add_edge(last_condition, alternative, **data)
        if _current_alternatives[0]:
            graph.remove_node(behavior)
            return graph, False
        # Second option: other roots exist, so the predecessors of the looping
        # behavior are removed altogether.
        if _current_alternatives[1]:
            predecessors = list(graph.predecessors(behavior))
            for last_condition in predecessors:
                successors = list(graph.successors(last_condition))
                for descendant in successors:
                    graph.remove_edge(last_condition, descendant)
                    # NOTE(review): networkx `neighbors` returns an iterator, so
                    # this comparison with 0 looks like it can never be True and
                    # the orphaned descendant is not removed -- confirm the intent
                    # before changing it (`behavior` is itself among `successors`
                    # and is removed below).
                    if graph.neighbors(descendant) == 0:
                        graph.remove_node(descendant)
                graph.remove_node(last_condition)
            graph.remove_node(behavior)
            return graph, False
        # No alternative at all: signal it to the caller; when raised during a
        # recursive call this is caught in _unrolled_behavior_graph.
        raise NotImplementedError()

    if node_graph is None:
        # If we cannot get the node's graph, we keep it as is.
        return graph, is_looping

    # Relabel graph nodes to obtain disjoint node labels (if more that one node).
    if add_prefix and len(node_graph.nodes()) > 1:
        _add_prefix_to_graph(node_graph, behavior.name + BEHAVIOR_SEPARATOR)

    # Replace the behavior node by the unrolled behavior's graph:
    # every edge that pointed to the behavior now points to each root of its graph.
    graph = compose_heb_graphs(graph, node_graph)
    for edge_u, _, data in graph.in_edges(behavior, data=True):
        for root in node_graph.roots:
            graph.add_edge(edge_u, root, **data)

    graph.remove_node(behavior)
    return graph, is_looping
174 |
175 |
def _unrolled_behavior_graph(
    behavior: "Behavior",
    add_prefix: bool,
    cut_looping_alternatives: bool,
    _current_alternatives: List[Union["Action", "Behavior"]],
    _unrolled_behaviors: Dict[str, Optional["HEBGraph"]],
) -> Tuple[Optional["HEBGraph"], bool]:
    """Get the unrolled sub-graph of a behavior.

    Args:
        behavior (Behavior): Behavior to get the unrolled graph of.
        add_prefix (bool): If True, adds a prefix in sub-hierarchies to have distinct nodes.
        cut_looping_alternatives (bool): If True, cut the looping alternatives.
        _current_alternatives: Alternatives of the node being unrolled, forwarded
            untouched to the recursive unrolling.
        _unrolled_behaviors (Dict[str, Optional[HEBGraph]]): Dictionary of already computed
            unrolled graphs, both to save compute and prevent recursion loops.

    Returns:
        Tuple[Optional[HEBGraph], bool]: The unrolled graph of the behavior (None if
        it cannot be computed) and whether the behavior is looping.
    """
    name = behavior.name

    if name in _unrolled_behaviors:
        # Known behavior: reuse its graph. A None entry means the behavior has no
        # unrolled graph stored (e.g. it is still being unrolled), which is
        # reported as a loop.
        known_graph = _unrolled_behaviors[name]
        return known_graph, known_graph is None

    try:
        unrolled, looping = _unroll_graph(
            behavior.graph,
            add_prefix=add_prefix,
            cut_looping_alternatives=cut_looping_alternatives,
            _current_alternatives=_current_alternatives,
            _unrolled_behaviors=_unrolled_behaviors,
        )
    except NotImplementedError:
        # The graph of the behavior could not be obtained or unrolled.
        return None, False

    _unrolled_behaviors[name] = unrolled
    return unrolled, looping
212 |
213 |
def _add_prefix_to_graph(graph: "HEBGraph", prefix: Optional[str]) -> "HEBGraph":
    """Rename the nodes of a graph to obtain disjoint node labels.

    Nodes are relabeled in place (``copy=False``): each node is replaced by a
    shallow copy whose name is prepended with the prefix.

    Args:
        graph: Graph whose nodes should be renamed.
        prefix: Prefix to prepend to every node name. If None, nothing is done.

    Returns:
        The given graph when prefix is None, otherwise the result of
        ``networkx.relabel_nodes`` called on it with ``copy=False``.
    """
    if prefix is None:
        return graph

    def rename(node: "Node") -> "Node":
        # Copy the node so that the original node object keeps its name.
        new_node = copy(node)
        new_node.name = prefix + node.name
        return new_node

    return relabel_nodes(graph, rename, copy=False)
225 |
226 |
def group_behaviors_points(
    pos: Dict["Node", tuple],
    graph: "HEBGraph",
) -> Dict[tuple, list]:
    """Group nodes positions of an HEBGraph by sub-behavior.

    The string form of each node is split on BEHAVIOR_SEPARATOR. Every part but
    the last is a level of the behavior hierarchy, and the position of the node
    is added to the group of each of its enclosing levels.

    Args:
        pos (Dict[Node, tuple]): Positions of nodes.
        graph (HEBGraph): Graph.

    Returns:
        Dict[tuple, list]: A dictionary of nodes grouped by their behavior's hierarchy.
    """
    grouped: Dict[tuple, list] = {}
    for node in graph.nodes():
        *hierarchy, _own_name = str(node).split(BEHAVIOR_SEPARATOR)
        # From the deepest enclosing behavior up to the outermost one.
        for depth in range(len(hierarchy), 0, -1):
            grouped.setdefault(tuple(hierarchy[:depth]), []).append(pos[node])
    return grouped
252 |
253 |
def compose_heb_graphs(
    graph_of_reference: "HEBGraph", other_graph: "HEBGraph"
) -> "HEBGraph":
    """Return a new graph of graph_of_reference composed with other_graph.

    Composition is the simple union of the node sets and edge sets.
    The node sets of the graph_of_reference and other_graph do not need to be disjoint.

    Args:
        graph_of_reference: HEBGraph giving the class, behavior and all_behaviors
            of the composed graph.
        other_graph: HEBGraph to compose with the graph of reference.

    Returns:
        A new HEBGraph with the same type as graph_of_reference.

    """
    new_graph = graph_of_reference.__class__(
        graph_of_reference.behavior, all_behaviors=graph_of_reference.all_behaviors
    )
    sources = (graph_of_reference, other_graph)

    # Graph attributes: other_graph is applied last, so its attributes take
    # precedence over those of graph_of_reference.
    for source in sources:
        new_graph.graph.update(source.graph)

    for source in sources:
        new_graph.add_nodes_from(source.nodes(data=True))

    for source in sources:
        if source.is_multigraph():
            new_graph.add_edges_from(source.edges(keys=True, data=True))
        else:
            new_graph.add_edges_from(source.edges(data=True))
    return new_graph
286 |
--------------------------------------------------------------------------------
/src/hebg/call_graph.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import (
3 | TYPE_CHECKING,
4 | Any,
5 | Dict,
6 | List,
7 | NamedTuple,
8 | Optional,
9 | Tuple,
10 | TypeVar,
11 | Union,
12 | )
13 | from matplotlib import pyplot as plt
14 | from matplotlib.axes import Axes
15 |
16 | from networkx import (
17 | DiGraph,
18 | all_simple_paths,
19 | draw_networkx_edges,
20 | draw_networkx_labels,
21 | draw_networkx_nodes,
22 | ancestors,
23 | )
24 | import numpy as np
25 | from hebg.behavior import Behavior
26 | from hebg.graph import get_successors_with_index
27 | from hebg.node import Action, FeatureCondition, Node
28 |
29 | if TYPE_CHECKING:
30 | from hebg.heb_graph import HEBGraph
31 |
32 | EnvAction = TypeVar("EnvAction")
33 |
34 |
class CallGraph(DiGraph):
    """Directed graph tracing the nodes explored while calling a HEBGraph.

    Nodes are CallNode(branch, rank) tuples carrying the attributes 'heb_node',
    'heb_graph' and 'label' (the name of the HEB node). Edges carry a 'status'
    attribute holding the value of a CallEdgeStatus member.

    Graph attributes:
        n_branches: Number of extra branches created so far.
        n_calls: Number of call nodes popped from the frontiere and called.
        frontiere: List of CallNode waiting to be explored.
    """

    def __init__(self, **attr) -> None:
        super().__init__(incoming_graph_data=None, **attr)
        self.graph["n_branches"] = 0
        self.graph["n_calls"] = 0
        self.graph["frontiere"] = []
        # Cache of feature condition results: a feature condition is only
        # evaluated once per CallGraph, whatever the observation given later.
        self._known_fc: Dict[FeatureCondition, Any] = {}
        self._current_node = CallNode(0, 0)

    def add_root(self, heb_node: "Node", heb_graph: "HEBGraph", **kwargs) -> None:
        """Register a HEB node on the current call node (initially CallNode(0, 0))."""
        self.add_node(
            self._current_node, heb_node=heb_node, heb_graph=heb_graph, **kwargs
        )

    def call_nodes(
        self, nodes: List["Node"], observation, heb_graph: "HEBGraph"
    ) -> EnvAction:
        """Explore the call graph from the given nodes until an action is obtained.

        Args:
            nodes: HEB nodes to start exploring from.
            observation: Observation given to the nodes when they are called.
            heb_graph: HEBGraph that contains the given nodes.

        Returns:
            The first non-None result of an action or of a behavior without graph.

        Raises:
            ValueError: If a node has an unknown type, or if the frontiere is
                exhausted without any action being found.
        """
        self._extend_frontiere(nodes, heb_graph)
        action = None

        while len(self.graph["frontiere"]) > 0 and action is None:
            next_call_node = self._pop_from_frontiere()
            if next_call_node is None:
                break

            node: "Node" = self.nodes[next_call_node]["heb_node"]
            heb_graph: "HEBGraph" = self.nodes[next_call_node]["heb_graph"]

            if node.type == "behavior":
                # Search for name reference in all_behaviors
                if node.name in heb_graph.all_behaviors:
                    node = heb_graph.all_behaviors[node.name]
                if not hasattr(node, "_graph") or node._graph is None:
                    # Behavior without a built graph: call it directly.
                    action = node(observation)
                    break
                self._extend_frontiere(node.graph.roots, heb_graph=node.graph)
            elif node.type == "action":
                action = node(observation)
                break
            elif node.type == "feature_condition":
                if node in self._known_fc:
                    next_edge_index = self._known_fc[node]
                else:
                    next_edge_index = int(node(observation))
                    self._known_fc[node] = next_edge_index
                next_nodes = get_successors_with_index(heb_graph, node, next_edge_index)
                self._extend_frontiere(next_nodes, heb_graph)
            elif node.type == "empty":
                self._extend_frontiere(list(heb_graph.successors(node)), heb_graph)
            else:
                raise ValueError(
                    f"Unknowed value {node.type} for node.type with node: {node}."
                )

        if action is None:
            raise ValueError("No valid frontiere left in call_graph")

        return action

    def call_edge_labels(self) -> list[tuple]:
        """Return every edge of the call graph as a (parent label, child label) pair."""
        return [
            (self.nodes[u]["label"], self.nodes[v]["label"]) for u, v in self.edges()
        ]

    def add_node(
        self, node_for_adding, heb_node: "Node", heb_graph: "HEBGraph", **attr
    ):
        """Add a call node bound to a HEB node, labeled with the name of the HEB node."""
        super().add_node(
            node_for_adding,
            heb_graph=heb_graph,
            heb_node=heb_node,
            label=heb_node.name,
            **attr,
        )

    def add_edge(
        self,
        u_of_edge,
        v_of_edge,
        status: "CallEdgeStatus",
        **attr,
    ):
        """Add an edge between two call nodes; the status is stored by its value."""
        return super().add_edge(u_of_edge, v_of_edge, status=status.value, **attr)

    def _make_new_branch(self) -> int:
        """Increment the branch counter and return the new branch identifier."""
        self.graph["n_branches"] += 1
        return self.graph["n_branches"]

    def _extend_frontiere(self, nodes: List["Node"], heb_graph: "HEBGraph") -> None:
        """Add the given HEB nodes to the frontiere as children of the current node."""
        frontiere: List[CallNode] = self.graph["frontiere"]

        parent = self._current_node
        call_nodes = []

        for i, node in enumerate(nodes):
            # The first child continues the branch of its parent,
            # every other child opens a new branch.
            if i > 0:
                branch_id = self._make_new_branch()
            else:
                branch_id = parent.branch
            call_node = CallNode(branch_id, parent.rank + 1)

            if node.name in heb_graph.all_behaviors:
                node = heb_graph.all_behaviors[node.name]
            self.add_node(call_node, heb_node=node, heb_graph=heb_graph)
            self.add_edge(parent, call_node, CallEdgeStatus.UNEXPLORED)
            call_nodes.append(call_node)

        frontiere.extend(call_nodes)

    def _heb_node_from_call_node(self, node: "CallNode") -> "Node":
        """Return the HEB node bound to a call node."""
        return self.nodes[node]["heb_node"]

    def _pop_from_frontiere(self) -> Optional["CallNode"]:
        """Pop the next call node to explore from the frontiere.

        The call node whose HEB node has the lowest complexity is popped first.
        Behaviors that are already bound to one of the ancestors of the popped
        call node are skipped and their edge is marked as a failure.

        Returns:
            The next call node, which becomes the current node, or None if the
            frontiere is exhausted.
        """
        frontiere: List["CallNode"] = self.graph["frontiere"]

        next_node = None

        while next_node is None:
            if not frontiere:
                return None

            next_call_node = frontiere.pop(
                np.argmin(
                    [
                        self._heb_node_from_call_node(node).complexity
                        for node in frontiere
                    ]
                )
            )
            maybe_next_node = self._heb_node_from_call_node(next_call_node)
            # Nodes should only have one parent
            parent = list(self.predecessors(next_call_node))[0]

            if isinstance(maybe_next_node, Behavior) and maybe_next_node in [
                self._heb_node_from_call_node(node)
                for node in ancestors(self, next_call_node)
            ]:
                # Calling this behavior again would loop: discard it.
                self._update_edge_status(parent, next_call_node, CallEdgeStatus.FAILURE)
                continue

            next_node = maybe_next_node

        self.graph["n_calls"] += 1
        self.nodes[next_call_node]["call_rank"] = 1
        self._update_edge_status(parent, next_call_node, CallEdgeStatus.CALLED)
        self._current_node = next_call_node
        return next_call_node

    def _update_edge_status(
        self,
        start: "CallNode",
        end: "CallNode",
        status: Union["CallEdgeStatus", str],
    ):
        """Set the status of an existing edge, given as a CallEdgeStatus or its value."""
        status = CallEdgeStatus(status)
        self.edges[start, end]["status"] = status.value

    def draw(
        self,
        ax: Optional[Axes] = None,
        pos: Optional[Dict[str, Tuple[float, float]]] = None,
        nodes_kwargs: Optional[dict] = None,
        label_kwargs: Optional[dict] = None,
        edges_kwargs: Optional[dict] = None,
    ):
        """Draw the call graph on a matplotlib Axes.

        Nodes are colored by kind of HEB node and sized by complexity, edges are
        colored by call status.

        Args:
            ax: Axes to draw on. Defaults to the current axes.
            pos: Positions of the call nodes. Defaults to one column per branch.
            nodes_kwargs: Extra keyword arguments for networkx draw_networkx_nodes.
            label_kwargs: Extra keyword arguments for networkx draw_networkx_labels,
                overriding the default alignment and font size.
            edges_kwargs: Extra keyword arguments for networkx draw_networkx_edges.
                The given dictionary is not modified.
        """
        if pos is None:
            pos = _call_graph_pos(self)
        if nodes_kwargs is None:
            nodes_kwargs = {}

        if ax is None:
            ax = plt.gca()

        # Fit the axes limits to the positions, with a 10% margin on each side.
        pos_arr = np.array(list(pos.values()))
        max_x, max_y = pos_arr.max(axis=0)
        min_x, min_y = pos_arr.min(axis=0)
        y_range = max_y - min_y
        x_range = max_x - min_x
        ax.set_ylim([min_y - 0.1 * y_range, max_y + 0.1 * y_range])
        ax.set_xlim([min_x - 0.1 * x_range, max_x + 0.1 * x_range])

        nodes_complexity = np.array(
            [node_data["heb_node"].complexity for _, node_data in self.nodes(data=True)]
        )
        min_complexity = nodes_complexity.min()
        complexity_range = nodes_complexity.max() - min_complexity

        if complexity_range > 0:
            nodes_complexity_scaled = (
                50 + 600 * (nodes_complexity - min_complexity) / complexity_range
            )
        else:
            # All nodes have the same complexity: use the minimal size for every
            # node instead of dividing by zero.
            nodes_complexity_scaled = np.full(len(nodes_complexity), 50.0)

        draw_networkx_nodes(
            self,
            node_color=[
                _node_color(node_data["heb_node"])
                for _, node_data in self.nodes(data=True)
            ],
            node_size=nodes_complexity_scaled,
            ax=ax,
            pos=pos,
            **nodes_kwargs,
        )

        # Labels get their own options (label_kwargs), not the nodes options.
        label_options = {
            "horizontalalignment": "center",
            "verticalalignment": "center",
            "font_size": 8,
        }
        label_options.update(label_kwargs or {})
        draw_networkx_labels(
            self,
            labels={
                node: f"{node_data['label']}"
                for node, node_data in self.nodes(data=True)
            },
            ax=ax,
            pos=pos,
            **label_options,
        )

        # Work on a copy so that the dictionary of the caller is left untouched.
        edge_options = {"connectionstyle": "angle,angleA=0,angleB=90,rad=5"}
        edge_options.update(edges_kwargs or {})
        draw_networkx_edges(
            self,
            ax=ax,
            pos=pos,
            arrowstyle="-",
            alpha=0.5,
            width=3,
            node_size=1,
            edge_color=[
                _call_status_to_color(status)
                for _, _, status in self.edges(data="status")
            ],
            **edge_options,
        )
266 |
267 |
class CallNode(NamedTuple):
    """Identifier of a node of a CallGraph."""

    # Branch identifier: CallGraph._extend_frontiere reuses the branch of the parent
    # for the first child and creates a new branch for every additional child.
    branch: int
    # Depth in the call graph: rank of the parent call node plus one.
    rank: int
271 |
272 |
class CallEdgeStatus(Enum):
    """Status of an edge of a CallGraph (stored on edges by its string value)."""

    # Status given to an edge when its child is added to the frontiere.
    UNEXPLORED = "unexplored"
    # Status set when the child has been popped from the frontiere to be called.
    CALLED = "called"
    # Status set when the child was discarded because it is a behavior already
    # bound to one of its ancestors in the call graph.
    FAILURE = "failure"
277 |
278 |
def _node_color(node: Union[Action, FeatureCondition, Behavior]) -> str:
    """Return the drawing color matching the kind of a HEB node.

    Actions are red, feature conditions blue and behaviors orange.

    Raises:
        NotImplementedError: If the node is none of the kinds above.
    """
    # Checked in this order; the first matching kind wins.
    color_by_kind = (
        (Action, "red"),
        (FeatureCondition, "blue"),
        (Behavior, "orange"),
    )
    for kind, color in color_by_kind:
        if isinstance(node, kind):
            return color
    raise NotImplementedError
287 |
288 |
def _call_status_to_color(status: Union[str, "CallEdgeStatus"]) -> str:
    """Return the drawing color of an edge from its call status.

    Args:
        status: A CallEdgeStatus member or its string value.

    Returns:
        "black" for unexplored, "green" for called and "red" for failure.

    Raises:
        ValueError: If the status is not a valid CallEdgeStatus.
        NotImplementedError: If the status has no color associated.
    """
    color_by_status = {
        CallEdgeStatus.UNEXPLORED: "black",
        CallEdgeStatus.CALLED: "green",
        CallEdgeStatus.FAILURE: "red",
    }
    color = color_by_status.get(CallEdgeStatus(status))
    if color is None:
        raise NotImplementedError
    return color
298 |
299 |
def _call_graph_pos(call_graph: "CallGraph") -> Dict[str, Tuple[float, float]]:
    """Compute drawing positions for the nodes of a call graph.

    Every simple path from the first root to a leaf is a branch. Branches are
    laid out in columns, longest first, and each node is placed in the column of
    the first branch it appears in, at a height of minus its rank.

    Args:
        call_graph: Call graph to compute the positions of.

    Returns:
        Positions ``[column, -rank]`` indexed by call node.
    """
    roots = [node for node, degree in call_graph.in_degree if degree == 0]
    leafs = [node for node, degree in call_graph.out_degree if degree == 0]

    # Stable sort: branches of equal length keep their enumeration order.
    longest_first = sorted(
        all_simple_paths(call_graph, roots[0], leafs), key=len, reverse=True
    )

    pos = {}
    for column, branch in enumerate(longest_first):
        for call_node in branch:
            # Keep the position given by the first (longest) branch.
            pos.setdefault(call_node, [column, -call_node.rank])
    return pos
316 |
--------------------------------------------------------------------------------
/tests/test_call_graph.py:
--------------------------------------------------------------------------------
1 | from networkx import DiGraph
2 | import pytest
3 |
4 | from hebg.behavior import Behavior
5 | from hebg.call_graph import CallEdgeStatus, CallGraph, CallNode, _call_graph_pos
6 | from hebg.heb_graph import HEBGraph
7 | from hebg.node import Action, FeatureCondition
8 |
9 | from pytest_mock import MockerFixture
10 | import pytest_check as check
11 |
12 | from tests import plot_graph
13 |
14 | from tests.examples.behaviors import F_F_A_Behavior
15 | from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors
16 | from tests.examples.feature_conditions import ThresholdFeatureCondition
17 |
18 |
19 | class TestCall:
20 | """Ensure that the call graph is faithful for debugging and efficient breadth first search."""
21 |
22 | def test_call_stack_without_branches(self) -> None:
23 | """When there is no branches, the graph should be a simple sequence of the call stack."""
24 | f_f_a_behavior = F_F_A_Behavior()
25 |
26 | draw = False
27 | if draw:
28 | plot_graph(f_f_a_behavior.graph.unrolled_graph)
29 | f_f_a_behavior(observation=-2)
30 |
31 | expected_graph = DiGraph(
32 | [
33 | ("F_F_A", "Greater or equal to 0 ?"),
34 | ("Greater or equal to 0 ?", "Greater or equal to -1 ?"),
35 | ("Greater or equal to -1 ?", "Action(0)"),
36 | ]
37 | )
38 |
39 | call_graph = f_f_a_behavior.graph.call_graph
40 | assert set(call_graph.call_edge_labels()) == set(expected_graph.edges())
41 |
42 | def test_split_on_same_fc_index(self, mocker: MockerFixture) -> None:
43 | """When there are multiple indexes on the same feature condition,
44 | a branch should be created."""
45 |
46 | expected_action = Action("EXPECTED", complexity=1)
47 |
48 | forbidden_value = "FORBIDDEN"
49 | forbidden_action = Action(forbidden_value, complexity=2)
50 | forbidden_action.__call__ = mocker.MagicMock(return_value=forbidden_value)
51 |
52 | class F_AA_Behavior(Behavior):
53 | """Feature condition with mutliple actions on same index."""
54 |
55 | def __init__(self) -> None:
56 | super().__init__("F_AA")
57 |
58 | def build_graph(self) -> HEBGraph:
59 | graph = HEBGraph(self)
60 | feature_condition = ThresholdFeatureCondition(
61 | relation=">=", threshold=0
62 | )
63 | graph.add_edge(
64 | feature_condition, Action(0, complexity=0), index=int(True)
65 | )
66 | graph.add_edge(feature_condition, forbidden_action, index=int(False))
67 | graph.add_edge(feature_condition, expected_action, index=int(False))
68 |
69 | return graph
70 |
71 | f_aa_behavior = F_AA_Behavior()
72 | draw = False
73 | if draw:
74 | plot_graph(f_aa_behavior.graph.unrolled_graph)
75 |
76 | # Sanity check that the right action should be called and not the forbidden one.
77 | assert f_aa_behavior(observation=-1) == expected_action.action
78 | forbidden_action.__call__.assert_not_called()
79 |
80 | # Graph should have the good split
81 | call_graph = f_aa_behavior.graph.call_graph
82 | expected_graph = DiGraph(
83 | [
84 | ("F_AA", "Greater or equal to 0 ?"),
85 | ("Greater or equal to 0 ?", "Action(EXPECTED)"),
86 | ("Greater or equal to 0 ?", "Action(FORBIDDEN)"),
87 | ]
88 | )
89 | assert set(call_graph.call_edge_labels()) == set(expected_graph.edges())
90 |
    @pytest.mark.xfail
    def test_multiple_call_to_same_fc(self, mocker: MockerFixture) -> None:
        """Call graph should allow for the same feature condition
        to be called multiple times in the same branch (in different behaviors).

        The same ``feature_condition`` instance is used by both the root
        behavior and its sub-behavior. The test is marked ``xfail``, i.e. it is
        currently expected to fail.
        """
        expected_action = Action("EXPECTED")
        unexpected_action = Action("UNEXPECTED")

        # Shared by RootBehavior and SubBehavior through the closure.
        feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0)

        class SubBehavior(Behavior):
            """Sub-behavior branching on the shared feature condition."""

            def __init__(self) -> None:
                super().__init__("SubBehavior")

            def build_graph(self) -> HEBGraph:
                graph = HEBGraph(self)
                graph.add_edge(feature_condition, expected_action, index=int(True))
                graph.add_edge(feature_condition, unexpected_action, index=int(False))
                return graph

        class RootBehavior(Behavior):
            """Root behavior calling SubBehavior behind the shared feature condition."""

            def __init__(self) -> None:
                super().__init__("RootBehavior")

            def build_graph(self) -> HEBGraph:
                graph = HEBGraph(self)
                graph.add_edge(feature_condition, SubBehavior(), index=int(True))
                graph.add_edge(feature_condition, unexpected_action, index=int(False))

                return graph

        root_behavior = RootBehavior()
        # Flip to True locally to inspect the unrolled graph while debugging.
        draw = False
        if draw:
            plot_graph(root_behavior.graph.unrolled_graph)

        # Sanity check that the expected action is returned for an observation
        # above the threshold (2 >= 0).
        assert root_behavior(observation=2) == expected_action.action

        # The feature condition label should appear twice along the called path.
        call_graph = root_behavior.graph.call_graph
        expected_graph = DiGraph(
            [
                ("RootBehavior", "Greater or equal to 0 ?"),
                ("Greater or equal to 0 ?", "SubBehavior"),
                ("SubBehavior", "Greater or equal to 0 ?"),
                ("Greater or equal to 0 ?", "Action(EXPECTED)"),
            ]
        )
        assert set(call_graph.call_edge_labels()) == set(expected_graph.edges())

        # Every call node is expected on branch 0, one rank deeper at each step.
        expected_labels = {
            CallNode(0, 0): "RootBehavior",
            CallNode(0, 1): "Greater or equal to 0 ?",
            CallNode(0, 2): "SubBehavior",
            CallNode(0, 3): "Greater or equal to 0 ?",
            CallNode(0, 4): "Action(EXPECTED)",
        }
        for node, label in call_graph.nodes(data="label"):
            check.equal(label, expected_labels[node])
152 |
    def test_chain_behaviors(self, mocker: MockerFixture) -> None:
        """When sub-behaviors with a graph are called recursively,
        the call graph should still find their nodes.

        RootBehavior only references ``Behavior("SubBehavior")`` by name; the
        concrete SubBehavior is registered afterwards in ``all_behaviors``.
        """

        expected_action = "EXPECTED"

        class DummyBehavior(Behavior):
            # Class-level mock: calling any instance returns ``expected_action``.
            __call__ = mocker.MagicMock(return_value=expected_action)

        class SubBehavior(Behavior):
            """Behavior whose graph holds a single DummyBehavior node."""

            def __init__(self) -> None:
                super().__init__("SubBehavior")

            def build_graph(self) -> HEBGraph:
                graph = HEBGraph(self)
                graph.add_node(DummyBehavior("Dummy"))
                return graph

        class RootBehavior(Behavior):
            """Behavior whose graph holds a by-name reference to SubBehavior."""

            def __init__(self) -> None:
                super().__init__("RootBehavior")

            def build_graph(self) -> HEBGraph:
                graph = HEBGraph(self)
                # Plain Behavior carrying only the name, not the SubBehavior class.
                graph.add_node(Behavior("SubBehavior"))
                return graph

        sub_behavior = SubBehavior()
        # Bare attribute access; presumably forces the graph to be built -- TODO confirm.
        sub_behavior.graph

        root_behavior = RootBehavior()
        # Map the by-name reference to the concrete sub-behavior instance.
        root_behavior.graph.all_behaviors["SubBehavior"] = sub_behavior

        # Sanity check that the right action should be called.
        assert root_behavior(observation=-1) == expected_action

        call_graph = root_behavior.graph.call_graph
        expected_graph = DiGraph(
            [
                ("RootBehavior", "SubBehavior"),
                ("SubBehavior", "Dummy"),
            ]
        )
        assert set(call_graph.call_edge_labels()) == set(expected_graph.edges())
197 |
198 | def test_looping_goback(self) -> None:
199 | """Loops with alternatives should be ignored."""
200 | draw = False
201 | _gather_wood, get_axe = build_looping_behaviors()
202 | assert get_axe({}) == "Punch tree"
203 |
204 | call_graph = get_axe.graph.call_graph
205 |
206 | if draw:
207 | plot_graph(call_graph)
208 |
209 | expected_labels = {
210 | CallNode(0, 0): "Get new axe",
211 | CallNode(0, 1): "Has wood ?",
212 | CallNode(1, 2): "Action(Summon axe out of thin air)",
213 | CallNode(0, 2): "Gather wood",
214 | CallNode(0, 3): "Has axe ?",
215 | CallNode(2, 4): "Get new axe",
216 | CallNode(0, 4): "Action(Punch tree)",
217 | }
218 | for node, label in call_graph.nodes(data="label"):
219 | check.equal(label, expected_labels[node])
220 |
221 | expected_graph = DiGraph(
222 | [
223 | ("Get new axe", "Has wood ?"),
224 | ("Has wood ?", "Action(Summon axe out of thin air)"),
225 | ("Has wood ?", "Gather wood"),
226 | ("Gather wood", "Has axe ?"),
227 | ("Has axe ?", "Get new axe"),
228 | ("Has axe ?", "Action(Punch tree)"),
229 | ]
230 | )
231 |
232 | assert set(call_graph.call_edge_labels()) == set(expected_graph.edges())
233 |
234 |
class TestDraw:
    """Ensures that the graph is readable even in complex situations."""

    def test_result_on_first_branch(self) -> None:
        """Resulting action should always be on the first branch.

        Builds a call graph by hand with a failing branch (index 0) and a
        branch ending on an action (index 1), then checks the drawn positions.
        """
        root = Behavior("Root", complexity=20)
        hand_built_graph = CallGraph()
        hand_built_graph.add_root(root, None)

        # Call node -> HEBGraph node it stands for, in insertion order.
        heb_node_by_call_node = {
            CallNode(0, 1): FeatureCondition("FC1", complexity=1),
            CallNode(0, 2): FeatureCondition("FC2", complexity=1),
            CallNode(0, 3): root,
            CallNode(1, 1): FeatureCondition("FC3", complexity=1),
            CallNode(1, 2): FeatureCondition("FC4", complexity=1),
            CallNode(1, 3): FeatureCondition("FC5", complexity=1),
            CallNode(1, 4): Action("A", complexity=1),
        }
        for call_node, heb_node in heb_node_by_call_node.items():
            hand_built_graph.add_node(call_node, heb_node, None)

        # (start, end, status) triplets; branch 0 ends on a failure.
        status_edges = (
            (CallNode(0, 0), CallNode(0, 1), CallEdgeStatus.CALLED),
            (CallNode(0, 1), CallNode(0, 2), CallEdgeStatus.CALLED),
            (CallNode(0, 2), CallNode(0, 3), CallEdgeStatus.FAILURE),
            (CallNode(0, 0), CallNode(1, 1), CallEdgeStatus.CALLED),
            (CallNode(1, 1), CallNode(1, 2), CallEdgeStatus.CALLED),
            (CallNode(1, 2), CallNode(1, 3), CallEdgeStatus.CALLED),
            (CallNode(1, 3), CallNode(1, 4), CallEdgeStatus.CALLED),
        )
        for status_edge in status_edges:
            hand_built_graph.add_edge(*status_edge)

        # The successful branch (1) is expected in column 0 and the failing
        # branch (0) in column 1; rows go down by one per rank.
        expected_poses = {CallNode(0, 0): [0, 0]}
        expected_poses.update({CallNode(1, rank): [0, -rank] for rank in range(1, 5)})
        expected_poses.update({CallNode(0, rank): [1, -rank] for rank in range(1, 4)})

        # Flip to True locally to inspect the call graph while debugging.
        draw = False
        if draw:
            plot_graph(hand_built_graph)

        assert _call_graph_pos(hand_built_graph) == expected_poses
285 |
--------------------------------------------------------------------------------
/src/hebg/codegen.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Module for code generation from HEBGraph."""
5 |
6 | from re import sub
7 | from typing import TYPE_CHECKING, Dict, List, Set, Tuple
8 |
9 | from hebg.behavior import Behavior
10 | from hebg.graph import get_roots, get_successors_with_index
11 | from hebg.metrics.histograms import cumulated_hebgraph_histogram
12 | from hebg.node import Action, FeatureCondition, Node
13 | from hebg.unrolling import BEHAVIOR_SEPARATOR
14 |
15 | if TYPE_CHECKING:
16 | from hebg.heb_graph import HEBGraph
17 |
18 |
class GeneratedBehavior:
    """Base class for generated behaviors.

    Used to reduce the overhead of abstracting a behavior.

    Stores the three name-indexed lookup tables that the generated
    ``__call__`` code reads: ``self.actions``, ``self.feature_conditions``
    and ``self.known_behaviors``.
    """

    def __init__(
        self,
        actions: Dict[str, "Action"] = None,
        feature_conditions: Dict[str, "FeatureCondition"] = None,
        behaviors: Dict[str, "Behavior"] = None,
    ):
        """Store the given lookup tables, defaulting each one to an empty dict.

        Args:
            actions: Actions by name. Defaults to None (empty dict).
            feature_conditions: Feature conditions by name.
                Defaults to None (empty dict).
            behaviors: Known behaviors by name, stored as ``known_behaviors``.
                Defaults to None (empty dict).
        """
        self.actions = actions if actions is not None else {}
        # Each table is guarded by its own argument: the guard used to test
        # `actions`, which dropped the given feature conditions whenever
        # `actions` was omitted (and stored None when only `actions` was given).
        self.feature_conditions = (
            feature_conditions if feature_conditions is not None else {}
        )
        self.known_behaviors = behaviors if behaviors is not None else {}
33 |
34 |
def get_hebg_source(graph: "HEBGraph") -> str:
    """Render the generated source code of a HEBGraph behavior as one string.

    Args:
        graph (HEBGraph): HEBGraph to generate the source code from.

    Returns:
        str: The generated class codelines joined with newlines.
    """
    return "\n".join(get_behavior_class_codelines(graph))
47 |
48 |
def get_behavior_class_codelines(
    graph: "HEBGraph",
    behaviors_histogram: Dict["Behavior", Dict["Node", int]] = None,
    add_dependencies: bool = True,
) -> List[str]:
    """Build the codelines of the GeneratedBehavior class for a HEBGraph.

    Args:
        graph (HEBGraph): HEBGraph to generate the behavior from.
        behaviors_histogram (optional): Histogram of uses forwarded to
            `get_dependencies_codelines`. When None, it is computed with
            `cumulated_hebgraph_histogram(graph)`. Defaults to None.
        add_dependencies (bool, optional): If True, the output starts with the
            GeneratedBehavior import followed by the codelines of the
            dependencies, and ends with the behavior-name hashmap.
            Defaults to True.

    Returns:
        List[str]: Codelines of the generated class (and its dependencies).
    """
    histogram = (
        cumulated_hebgraph_histogram(graph)
        if behaviors_histogram is None
        else behaviors_histogram
    )
    dependencies, inlined_behaviors, hashmap = get_dependencies_codelines(
        graph,
        histogram,
        add_dependencies_codelines=add_dependencies,
    )

    codelines = []
    # Import line and code of the other behaviors this one depends on.
    if add_dependencies:
        codelines.extend(["from hebg.codegen import GeneratedBehavior", ""])
        for dependency_lines in dependencies.values():
            codelines.extend(dependency_lines)

    # Class header.
    class_name = _to_camel_case(graph.behavior.name.capitalize())
    codelines.append(f"class {class_name}(GeneratedBehavior):")

    # Body of the __call__ method.
    codelines.extend(
        get_behavior_call_codelines(
            graph,
            behaviors_incall_codelines=inlined_behaviors,
        )
    )

    # Mapping from behavior names to generated class names.
    if add_dependencies:
        codelines.extend(hashmap)
    return codelines
99 |
100 |
def get_dependencies_codelines(
    graph: "HEBGraph",
    behaviors_histogram: Dict["Node", int],
    add_dependencies_codelines: bool = True,
) -> Tuple[Dict["Behavior", List[str]], Set["Behavior"], List[str]]:
    """Parse the dependencies of the given HEBGraph's behavior.

    Args:
        graph (HEBGraph): HEBGraph to parse dependencies from.
        behaviors_histogram (Dict[Node, int]): Number of uses of each node.
            Entries whose key is not a Behavior are ignored.
        add_dependencies_codelines (bool, optional): If True, generate the class
            codelines of abstracted sub-behaviors and a placeholder comment for
            behaviors whose graph is not available. Defaults to True.

    Returns:
        Tuple of three elements:
        - dependencies_codelines (Dict[Behavior, List[str]]): Codelines of the GeneratedBehavior for
            each of the behavior used in the HEBGraph if they can be computed, else a comment.
        - behaviors_incall_codelines (Set[Behavior]): Set of behaviors that should be directly
            unrolled in the call function (because the abstraction is not worth it).
        - hashmap_codelines (List[str]): Codelines of the map between behavior names
            and corresponding GeneratedBehavior.
    """
    dependencies_codelines: Dict["Behavior", List[str]] = {}
    behaviors_incall_codelines: Set["Behavior"] = set()
    dependencies_hashmap: Dict[str, str] = {}
    # Most used nodes first.
    for behavior, n_used in sorted(
        behaviors_histogram.items(), key=lambda x: x[1], reverse=True
    ):
        if not isinstance(behavior, Behavior):
            continue
        # Prefer the instance registered in the graph under the same name.
        if behavior.name in graph.all_behaviors:
            behavior = graph.all_behaviors[behavior.name]

        try:
            sub_graph = behavior.graph
        except NotImplementedError:
            # If subgraph cannot be computed, we simply have a ref
            if add_dependencies_codelines:
                dependencies_codelines[behavior] = [
                    f"# Require '{behavior.name}' behavior to be given."
                ]
            continue

        # Used once, or made of a single node: unroll it in the caller
        # instead of generating a dedicated class.
        if n_used == 1 or len(sub_graph.nodes()) == 1:
            dependencies_codelines[behavior] = []
            behaviors_incall_codelines.add(behavior)
            continue

        if add_dependencies_codelines:
            # Dependencies of the sub-behavior are not emitted again.
            dependencies_codelines[behavior] = get_behavior_class_codelines(
                sub_graph, behaviors_histogram, add_dependencies=False
            )
            dependencies_hashmap[behavior.name] = _to_camel_case(
                behavior.name.capitalize()
            )

    # The hashmap is only emitted when at least one class was generated.
    hashmap_codelines = []
    if dependencies_hashmap:
        hashmap_codelines = ["BEHAVIOR_TO_NAME = {"]
        hashmap_codelines += [
            f"    '{name}': {class_name},"
            for name, class_name in dependencies_hashmap.items()
        ]
        hashmap_codelines += ["}"]

    return dependencies_codelines, behaviors_incall_codelines, hashmap_codelines
166 |
167 |
def get_behavior_call_codelines(
    graph: "HEBGraph",
    behaviors_incall_codelines: Set["Behavior"],
    indent: int = 1,
    with_overhead=True,
) -> List[str]:
    """Build the codelines of the call function of a GeneratedBehavior.

    Generation starts from the first root returned by `get_roots(graph)`.

    Args:
        graph (HEBGraph): HEBGraph to generate the call function from.
        behaviors_incall_codelines (Set[Behavior]): Behaviors to unroll directly
            instead of referring to them.
        indent (int, optional): Indentation level. Defaults to 1.
        with_overhead (bool, optional): If True, emit the ``def __call__`` line
            and generate the body one level deeper. Defaults to True.

    Returns:
        List[str]: Codelines of the GeneratedBehavior call function.
    """
    header_codelines = []
    if with_overhead:
        header_codelines = [indent_str(indent) + "def __call__(self, observation):"]
        indent += 1

    first_root = get_roots(graph)[0]
    body_codelines = get_node_call_codelines(
        graph,
        first_root,
        indent,
        behaviors_incall_codelines=behaviors_incall_codelines,
    )
    return header_codelines + body_codelines
198 |
199 |
def get_node_call_codelines(
    graph: "HEBGraph",
    node: Node,
    indent: int,
    behaviors_incall_codelines: Set["Behavior"],
) -> List[str]:
    """Build the codelines calling a HEBGraph node, recursing on its successors.

    Args:
        graph (HEBGraph): HEBGraph containing the node.
        node (Node): Node to generate the call of.
        indent (int): Indentation level.
        behaviors_incall_codelines (Set[Behavior]): Behaviors to unroll directly
            instead of referring to them.

    Raises:
        NotImplementedError: Node is not an Action, FeatureCondition or Behavior.

    Returns:
        List[str]: Codelines of the node call.
    """
    if isinstance(node, Action):
        # Only the last part of a separator-qualified name is used as the key.
        short_name = node.name.split(BEHAVIOR_SEPARATOR)[-1]
        return [
            indent_str(indent) + f"return self.actions['{short_name}'](observation)"
        ]

    if isinstance(node, FeatureCondition):
        # Nested conditions get a distinct variable name per indent level.
        var_name = "edge_index" if indent <= 2 else f"edge_index_{indent-2}"
        short_name = node.name.split(BEHAVIOR_SEPARATOR)[-1]
        padding = indent_str(indent)
        codelines = [
            padding + f"{var_name} = self.feature_conditions['{short_name}'](observation)"
        ]
        for edge_index in (0, 1):
            # NOTE(review): the `if` line is emitted even when no successor
            # exists for this index, which would leave it without a body --
            # confirm whether such graphs can reach this function.
            codelines.append(padding + f"if {var_name} == {edge_index}:")
            for successor in get_successors_with_index(graph, node, edge_index):
                codelines.extend(
                    get_node_call_codelines(
                        graph,
                        node=successor,
                        indent=indent + 1,
                        behaviors_incall_codelines=behaviors_incall_codelines,
                    )
                )
        return codelines

    if not isinstance(node, Behavior):
        raise NotImplementedError

    if node not in behaviors_incall_codelines:
        # Not unrolled: delegate to the behavior known under that name.
        return [
            indent_str(indent)
            + f"return self.known_behaviors['{node.name}'](observation)"
        ]

    # Unrolled: generate the body of the behavior in place, using the instance
    # registered in the graph when there is one.
    if node.name in graph.all_behaviors:
        node = graph.all_behaviors[node.name]
    return get_behavior_call_codelines(
        node.graph,
        behaviors_incall_codelines=behaviors_incall_codelines,
        indent=indent,
        with_overhead=False,
    )
258 |
259 |
def indent_str(indent_level: int, indent_amount: int = 4) -> str:
    """Build the whitespace prefix for a given indentation level.

    Args:
        indent_level (int): Level of indentation.
        indent_amount (int, optional): Number of spaces per indent. Defaults to 4.

    Returns:
        str: Indentation string.
    """
    one_space_per_level = " " * indent_level
    return one_space_per_level * indent_amount
271 |
272 |
def _to_camel_case(text: str) -> str:
    """Concatenate the words of a text, capitalizing every word but the first.

    Dashes and underscores count as word separators; the characters
    ``?[](),`` are dropped. The first word is kept exactly as given.

    Args:
        text (str): Text to convert.

    Returns:
        str: Converted text, or an empty string when the text has no word
        left once separators and dropped characters are removed.
    """
    # One pass: '-' and '_' become spaces, '?[](),' are deleted.
    cleaned = text.translate(str.maketrans("-_", "  ", "?[](),"))
    words = cleaned.split()
    # Guard on the words rather than on the raw text: inputs such as "?" or
    # "  " are not empty but contain no word, and used to raise IndexError.
    if not words:
        return ""
    return words[0] + "".join(word.capitalize() for word in words[1:])
288 |
289 |
def _to_snake_case(text: str) -> str:
    """Convert a text to lowercase words joined by underscores.

    Dashes become spaces and question marks are removed. A space is inserted
    before every run of capital letters, then before every capitalized word,
    so that they are split apart.
    """
    cleaned = text.replace("-", " ").replace("?", "")
    # Space before every run of capital letters.
    spaced_runs = sub("([A-Z]+)", r" \1", cleaned)
    # Space before every capital letter followed by lowercase letters.
    spaced_words = sub("([A-Z][a-z]+)", r" \1", spaced_runs)
    words = spaced_words.split()
    return "_".join(words).lower()
295 |
--------------------------------------------------------------------------------
/src/hebg/metrics/histograms.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """HEBGraph used nodes histograms computation."""
5 |
6 | from typing import TYPE_CHECKING, Dict, List, Tuple
7 | from warnings import warn
8 |
9 | import numpy as np
10 |
11 | from hebg.behavior import Behavior
12 | from hebg.metrics.complexity.utils import update_sum_dict
13 | from hebg.node import Action, FeatureCondition
14 |
15 | if TYPE_CHECKING:
16 | from hebg import HEBGraph, Node
17 |
18 |
def behaviors_histograms(
    behaviors: List["Behavior"],
    default_node_complexity: float = 1.0,
) -> Dict["Behavior", Dict["Node", int]]:
    """Compute the used nodes histograms for a list of Behavior.

    Thin wrapper keeping only the histograms returned by
    `behaviors_histograms_and_complexites`.

    Args:
        behaviors: List of Behavior to compute histograms of.
        default_node_complexity: Default node complexity if Node has no attribute complexity.

    Return:
        Dictionary of dictionaries of the number of use for each used node, for each behavior.
    """
    histograms, _complexities = behaviors_histograms_and_complexites(
        behaviors, default_node_complexity
    )
    return histograms
33 |
34 |
def behaviors_histograms_and_complexites(
    behaviors: List["Behavior"],
    default_node_complexity: float = 1.0,
) -> Tuple[Dict["Behavior", Dict["Node", int]], Dict["Behavior", float]]:
    """Compute the used nodes histograms and complexities for a list of Behavior.

    Behaviors whose ``graph`` attribute raises NotImplementedError are skipped
    with a warning and are absent from both returned dictionaries.

    Args:
        behaviors: List of Behavior to compute histograms of.
        default_node_complexity: Default node complexity if Node has no attribute complexity.

    Return:
        Tuple of two elements:
        - Dictionary of dictionaries of the number of use for each used node, for each behavior.
        - Dictionary of the computed complexity for each behavior.
    """
    histograms: Dict["Behavior", Dict["Node", int]] = {}
    complexities: Dict["Behavior", float] = {}
    for behavior in behaviors:
        try:
            graph = behavior.graph
        except NotImplementedError:
            # The two literals are concatenated: the trailing space keeps the
            # sentences apart (the message used to read "...<behavior>.Skipping").
            warn(
                f"Could not load graph for behavior: {behavior}. "
                "Skipping histogram computation."
            )
            continue
        histogram, complexity = hebgraph_histogram_and_complexity(
            graph,
            default_node_complexity=default_node_complexity,
        )
        histograms[behavior] = histogram
        complexities[behavior] = complexity
    return histograms, complexities
68 |
69 |
def hebgraph_histogram_and_complexity(
    graph: "HEBGraph", default_node_complexity: float = 1.0
) -> Tuple[Dict["Node", int], float]:
    """Compute the used nodes histogram and the complexity of a HEBGraph.

    Both values are the ones computed for the first node of the first level
    of the graph, which is taken as its single root.

    Args:
        graph: HEBGraph to compute the histogram of.
        default_node_complexity: Default node complexity if Node has no attribute complexity.

    Return:
        Tuple composed of two element:
        - Dictionary of the number of use for each used node in the graph.
        - Computed total complexity of the graph.

    """
    histogram_by_node, complexity_by_node = nodes_histograms_and_complexities(
        graph, default_node_complexity
    )
    # Assumes a single root.
    first_level = graph.graph["nodes_by_level"][0]
    root = first_level[0]
    return histogram_by_node[root], complexity_by_node[root]
90 |
91 |
def cumulated_hebgraph_histogram(
    graph: "HEBGraph", default_node_complexity: float = 1.0
) -> Dict["Node", int]:
    """Unroll the hebgraph histogram by accumulating sub-behaviors histograms.

    Starting from the histogram of the graph itself, the histogram of every
    behavior it contains is added once per use, repeatedly, until every
    behavior of the histogram has been accumulated as many times as it is used.

    Args:
        graph (HEBGraph): The HEBgraph to compute the cumulated histogram of.
        default_node_complexity (float, optional): Default node complexity. Defaults to 1.0.

    Returns:
        Dict[Node, int]: Cumulated histogram of used nodes with unrolled behaviors.
    """
    histogram, _ = hebgraph_histogram_and_complexity(graph, default_node_complexity)
    # Cache of the (non-cumulated) histogram of each known behavior.
    histograms = {graph.behavior: histogram}
    sub_behaviors = [
        node
        for node in histogram
        if isinstance(node, Behavior) and node != graph.behavior
    ]
    # Prefer the instances registered in the graph under the same name.
    for i, behavior in enumerate(sub_behaviors):
        if behavior.name in graph.all_behaviors:
            sub_behaviors[i] = graph.all_behaviors[behavior.name]

    histograms.update(behaviors_histograms(sub_behaviors, default_node_complexity))

    done = False
    behavior_iteration = {}  # Stopping condition: accumulations done per behavior
    while not done:
        behaviors = [node for node in histogram if isinstance(node, Behavior)]
        behavior_iteration, histogram, histograms = _iterate_cumulation(
            graph=graph,
            behaviors=behaviors,
            histogram=histogram,
            histograms=histograms,
            behavior_iteration=behavior_iteration,
            default_node_complexity=default_node_complexity,
        )

        # Recompute behaviors because some might have been added
        behaviors = [node for node in histogram if isinstance(node, Behavior)]

        # We stop when all behaviors have been iterated on enough.
        # `done` stays False while some behavior has no iteration count yet.
        if all(behavior in behavior_iteration for behavior in behaviors):
            done = all(
                behavior_iteration[behavior] == histogram[behavior]
                for behavior in behaviors
            )

    return histogram
141 |
142 |
def _iterate_cumulation(
    graph: "HEBGraph",
    behaviors: List["Behavior"],
    histogram: Dict["Node", int],
    histograms: Dict["Behavior", Dict["Node", int]],
    behavior_iteration: Dict["Behavior", int],
    default_node_complexity: float = 1.0,
) -> Tuple[
    Dict["Behavior", int], Dict["Node", int], Dict["Behavior", Dict["Node", int]]
]:
    """Iterate once on every behavior accumulating histograms.

    Args:
        graph (HEBGraph): Current graph being iterated on.
        behaviors (List[Behavior]): List of behaviors to iterate.
        histogram (Dict[Node, int]): Histogram being accumulated.
        histograms (Dict[Behavior, Dict[Node, int]]): Histograms of known behaviors.
            Updated in place with the histogram of newly met behaviors.
        behavior_iteration (Dict[Behavior, int]): Number of iterations already done
            for each behavior. Updated in place.
        default_node_complexity (float, optional): Default node complexity used when
            computing the histogram of a newly met behavior. Defaults to 1.0.

    Returns:
        Tuple of three updated dictionaries. (behavior_iteration, histogram, histograms).
    """
    for behavior in behaviors:
        already_iterated = (
            behavior_iteration[behavior] if behavior in behavior_iteration else 0
        )
        # Uses of this behavior that were not accumulated yet.
        n_used = max(0, histogram[behavior] - already_iterated)
        if n_used == 0:
            continue
        if behavior not in histograms:
            # Prefer the instance registered in the graph under the same name.
            if behavior.name in graph.all_behaviors:
                behavior = graph.all_behaviors[behavior.name]
            try:
                behavior_graph = behavior.graph
            except NotImplementedError:
                # No graph available: nothing to accumulate, but the iteration
                # is still counted (once per pass, whatever n_used is).
                if behavior not in behavior_iteration:
                    behavior_iteration[behavior] = 0
                behavior_iteration[behavior] += 1
                continue
            sub_histogram, _ = hebgraph_histogram_and_complexity(
                behavior_graph, default_node_complexity
            )
            histograms[behavior] = sub_histogram
        for _ in range(n_used):
            # Drop the behavior from its own histogram before adding it,
            # so it does not count itself.
            if behavior in histograms[behavior]:
                histograms[behavior].pop(behavior)
            histogram = update_sum_dict(histogram, histograms[behavior])
            if behavior not in behavior_iteration:
                behavior_iteration[behavior] = 0
            behavior_iteration[behavior] += 1
    return behavior_iteration, histogram, histograms
193 |
194 |
def nodes_histograms_and_complexities(
    graph: "HEBGraph",
    default_node_complexity: float = 1.0,
    _behaviors_in_search=None,
):
    """Compute the number of times each node of the graph is present
    while computing complexities to find the least complex path.

    Nodes are processed level by level, from the deepest level up to level 0.
    For every node and every edge index, only the successor with the smallest
    complexity contributes its histogram and complexity.

    Args:
        graph (HEBGraph): HEBGraph to compute the histogram of.
        default_node_complexity (float, optional): Default node complexity if undefined.
            Defaults to 1.0.
        _behaviors_in_search (List[Behavior], optional): List of behaviors already in recursive
            pile to avoid infinite cycle. The name of the graph's behavior is appended
            to it in place. Defaults to None.

    Returns:
        Tuple[Dict[Node, Dict[Node, int]], Dict[Node, float]]: Tuple of two values:
            - Dictionary of subnode used for each node. (nodes_used_nodes)
            - Dictionary of computed smallest path complexities for each node. (complexities)
    """
    nodes_by_level = graph.graph["nodes_by_level"]
    depth = graph.graph["depth"]

    _behaviors_in_search = [] if _behaviors_in_search is None else _behaviors_in_search
    # Stored as a string: membership is later tested with `str(node)`.
    _behaviors_in_search.append(str(graph.behavior))

    complexities = {}
    nodes_used_nodes = {}
    # Seed actions and feature conditions with their own complexity.
    for node in graph.nodes:
        if isinstance(node, (Action, FeatureCondition)):
            complexities[node] = node.complexity
            nodes_used_nodes[node] = {node: 1}

    # Deepest level first, so successors are computed before their predecessors
    # (assumes successors sit on deeper levels -- TODO confirm).
    for level in range(depth + 1)[::-1]:
        for node in nodes_by_level[level]:
            node_complexity = 0
            node_used_nodes = {}

            # Best successors accumulated histograms and complexity
            succ_by_index, complexities_by_index = _successors_by_index(
                graph, node, complexities
            )
            for index, values in complexities_by_index.items():
                # Keep the least complex successor of this index.
                min_index = np.argmin(values)
                choosen_succ = succ_by_index[index][min_index]
                node_used_nodes = update_sum_dict(
                    node_used_nodes, nodes_used_nodes[choosen_succ]
                )
                node_complexity += values[min_index]

            # Node only histogram and complexity
            (
                node_only_used_behaviors,
                node_only_complexity,
            ) = _get_node_histogram_complexity(
                node,
                default_node_complexity=default_node_complexity,
                behaviors_in_search=_behaviors_in_search,
            )
            node_used_nodes = update_sum_dict(node_used_nodes, node_only_used_behaviors)
            node_complexity += node_only_complexity

            # Overwrites the seeded values of actions and feature conditions
            # that are listed in `nodes_by_level`.
            complexities[node] = node_complexity
            nodes_used_nodes[node] = node_used_nodes
    return nodes_used_nodes, complexities
260 |
261 |
def _successors_by_index(
    graph: "HEBGraph", node: "Node", complexities: Dict["Node", float]
) -> Tuple[Dict[int, List["Node"]], Dict[int, List[float]]]:
    """Group the successors of a node, and their complexities, by edge index.

    Both returned dictionaries share the same keys, and their lists are aligned:
    the i-th complexity of an index belongs to the i-th successor of that index.

    Args:
        graph: The HEBGraph to use.
        node: The Node from which we want to group successors.
        complexities: Dictionary of complexities for each potential successor node.

    Return:
        Tuple composed of a dictionary of successors for each index
        and a dictionary of complexities for each index.

    """
    grouped_successors = {}
    grouped_complexities = {}
    for successor in graph.successors(node):
        successor_complexity = complexities[successor]
        edge_index = int(graph.edges[node, successor]["index"])
        grouped_complexities.setdefault(edge_index, []).append(successor_complexity)
        grouped_successors.setdefault(edge_index, []).append(successor)
    return grouped_successors, grouped_complexities
289 |
290 |
def _get_node_histogram_complexity(
    node: "Node", behaviors_in_search=None, default_node_complexity: float = 1.0
) -> Tuple[Dict["Node", int], float]:
    """Compute the used nodes histogram and complexity of a single node.

    Args:
        node: The Node from which we want to compute the complexity.
        behaviors_in_search: Names of the behaviors already in search, to avoid circular search.
            A behavior found in it gets an empty histogram and an infinite complexity.
        default_node_complexity: Default node complexity if the node complexity is None.

    Return:
        Tuple composed of a dictionary of the number of use for each used Node by the given node
        and the given node complexity.

    Raises:
        ValueError: The node type is not one of "action", "feature_condition",
            "behavior" or "empty".

    """
    node_type = node.type
    if node_type == "empty":
        return {}, 0
    if node_type not in ("action", "feature_condition", "behavior"):
        raise ValueError(f"Unkowned node type {node.type}")

    # A behavior already being searched is made infinitely complex.
    is_circular = (
        node_type == "behavior"
        and behaviors_in_search is not None
        and str(node) in behaviors_in_search
    )
    if is_circular:
        return {}, np.inf

    if node.complexity is None:
        return {node: 1}, default_node_complexity
    return {node: 1}, node.complexity
319 |
--------------------------------------------------------------------------------
/tests/test_paper_basic_example.py:
--------------------------------------------------------------------------------
1 | # HEBGraph for explainable hierarchical reinforcement learning
2 | # Copyright (C) 2021-2024 Mathïs FEDERICO
3 |
4 | """Integration tests for the initial paper examples."""
5 |
6 | from typing import Dict, List
7 | from copy import deepcopy
8 |
9 | # import matplotlib.pyplot as plt
10 |
11 |
12 | import pytest
13 | import pytest_check as check
14 |
15 | from itertools import permutations
16 | from networkx.classes.digraph import DiGraph
17 | from networkx import is_isomorphic
18 |
19 | from hebg import Action, Behavior, FeatureCondition, HEBGraph
20 | from hebg.metrics.histograms import behaviors_histograms, cumulated_hebgraph_histogram
21 | from hebg.metrics.complexity.complexities import learning_complexity
22 | from hebg.requirements_graph import build_requirement_graph
23 | from hebg.unrolling import BEHAVIOR_SEPARATOR, unroll_graph
24 |
25 | from tests.examples.behaviors.report_example import Behavior0, Behavior1, Behavior2
26 |
27 |
class TestPaperBasicExamples:
    """Check metrics, code generation and unrolling on the initial paper's basic examples."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build the example nodes, the example behaviors and their expected histograms."""
        self.actions: List[Action] = [
            Action(index, complexity=1) for index in range(3)
        ]
        self.feature_conditions: List[FeatureCondition] = [
            FeatureCondition(f"feature {index}", complexity=1) for index in range(6)
        ]
        self.behaviors: List[Behavior] = [Behavior0(), Behavior1(), Behavior2()]

        # Short aliases keep the expected tables below readable.
        action_0, action_1, action_2 = self.actions
        fc_0, fc_1, fc_2, fc_3, fc_4, fc_5 = self.feature_conditions
        behavior_0, behavior_1, behavior_2 = self.behaviors

        self.expected_behavior_histograms: Dict[Behavior, Dict[Action, int]] = {
            behavior_0: {action_0: 1, action_1: 1, fc_0: 1},
            behavior_1: {
                action_0: 1,
                action_2: 1,
                behavior_0: 1,
                fc_1: 1,
                fc_2: 1,
            },
            behavior_2: {
                action_0: 1,
                behavior_0: 1,
                behavior_1: 2,
                fc_3: 1,
                fc_4: 1,
                fc_5: 1,
            },
        }

    def test_histograms(self):
        """should give expected histograms."""
        computed_histograms = behaviors_histograms(self.behaviors)
        check.equal(computed_histograms, self.expected_behavior_histograms)

    def test_cumulated_histograms(self):
        """should give expected cumulated histograms."""
        action_0, action_1, action_2 = self.actions
        fc_0, fc_1, fc_2, fc_3, fc_4, fc_5 = self.feature_conditions
        behavior_0, behavior_1, behavior_2 = self.behaviors

        expected_cumulated_histograms = {
            behavior_0: {action_0: 1, action_1: 1, fc_0: 1},
            behavior_1: {
                action_0: 2,
                action_2: 1,
                action_1: 1,
                fc_0: 1,
                fc_1: 1,
                fc_2: 1,
                behavior_0: 1,
            },
            behavior_2: {
                action_0: 6,
                action_1: 3,
                action_2: 2,
                fc_0: 3,
                fc_1: 2,
                fc_2: 2,
                fc_3: 1,
                fc_4: 1,
                fc_5: 1,
                behavior_0: 3,
                behavior_1: 2,
            },
        }
        for behavior in self.behaviors:
            cumulated_histogram = cumulated_hebgraph_histogram(behavior.graph)
            check.equal(cumulated_histogram, expected_cumulated_histograms[behavior])

    def test_learning_complexity(self):
        """should give expected learning_complexity."""
        behavior_0, behavior_1, behavior_2 = self.behaviors
        expected_learning_complexities = {behavior_0: 3, behavior_1: 6, behavior_2: 9}
        expected_saved_complexities = {behavior_0: 0, behavior_1: 1, behavior_2: 12}

        for behavior in self.behaviors:
            c_learning, saved_complexity = learning_complexity(
                behavior, used_nodes_all=self.expected_behavior_histograms
            )

            # Printed as "computed|expected" pairs for the learning and saved complexities.
            print(
                f"{behavior}: {c_learning}|{expected_learning_complexities[behavior]}"
                f" {saved_complexity}|{expected_saved_complexities[behavior]}"
            )

            check.almost_equal(c_learning, expected_learning_complexities[behavior])
            check.almost_equal(saved_complexity, expected_saved_complexities[behavior])

    def test_codegen(self):
        """should generate the expected source code from the graph of the last behavior."""
        expected_lines = [
            "from hebg.codegen import GeneratedBehavior",
            "",
            "class Behavior0(GeneratedBehavior):",
            " def __call__(self, observation):",
            " edge_index = self.feature_conditions['feature 0'](observation)",
            " if edge_index == 0:",
            " return self.actions['Action(0)'](observation)",
            " if edge_index == 1:",
            " return self.actions['Action(1)'](observation)",
            "class Behavior1(GeneratedBehavior):",
            " def __call__(self, observation):",
            " edge_index = self.feature_conditions['feature 1'](observation)",
            " if edge_index == 0:",
            " return self.known_behaviors['behavior 0'](observation)",
            " if edge_index == 1:",
            " edge_index_1 = self.feature_conditions['feature 2'](observation)",
            " if edge_index_1 == 0:",
            " return self.actions['Action(0)'](observation)",
            " if edge_index_1 == 1:",
            " return self.actions['Action(2)'](observation)",
            "class Behavior2(GeneratedBehavior):",
            " def __call__(self, observation):",
            " edge_index = self.feature_conditions['feature 3'](observation)",
            " if edge_index == 0:",
            " edge_index_1 = self.feature_conditions['feature 4'](observation)",
            " if edge_index_1 == 0:",
            " return self.actions['Action(0)'](observation)",
            " if edge_index_1 == 1:",
            " return self.known_behaviors['behavior 1'](observation)",
            " if edge_index == 1:",
            " edge_index_1 = self.feature_conditions['feature 5'](observation)",
            " if edge_index_1 == 0:",
            " return self.known_behaviors['behavior 1'](observation)",
            " if edge_index_1 == 1:",
            " return self.known_behaviors['behavior 0'](observation)",
            "BEHAVIOR_TO_NAME = {",
            " 'behavior 0': Behavior0,",
            " 'behavior 1': Behavior1,",
            "}",
        ]
        expected_code = "\n".join(expected_lines)
        generated_code = self.behaviors[2].graph.generate_source_code()
        check.equal(generated_code, expected_code)

    def test_requirement_graph_edges(self):
        """should give expected requirement_graph edges."""
        behavior_0, behavior_1, behavior_2 = self.behaviors

        expected_requirement_graph = DiGraph()
        expected_requirement_graph.add_nodes_from(self.behaviors)
        expected_requirement_graph.add_edges_from(
            [
                (behavior_0, behavior_1),
                (behavior_0, behavior_2),
                (behavior_1, behavior_2),
            ]
        )

        requirements_graph = build_requirement_graph(self.behaviors)
        # Every ordered pair is compared, so unexpected extra edges are caught too.
        for behavior, other_behavior in permutations(self.behaviors, 2):
            print(behavior, other_behavior)
            check.equal(
                requirements_graph.has_edge(behavior, other_behavior),
                expected_requirement_graph.has_edge(behavior, other_behavior),
            )

    def test_requirement_graph_levels(self):
        """should give expected requirement_graph node levels (requirement depth)."""
        expected_levels = dict(zip(self.behaviors, (0, 1, 2)))
        requirements_graph = build_requirement_graph(self.behaviors)
        for behavior, level in requirements_graph.nodes(data="level"):
            check.equal(level, expected_levels[behavior])

    def test_unrolled_behaviors_graphs(self):
        """should give expected unrolled_behaviors_graphs for each example behaviors."""

        def prefixed(*parts):
            # Node names of the expected graphs: parts joined with BEHAVIOR_SEPARATOR.
            return BEHAVIOR_SEPARATOR.join(map(str, parts))

        behavior_0, behavior_1, behavior_2 = self.behaviors

        # --- Expected graph for the first behavior: a copy of its own graph.
        unrolled_0 = deepcopy(behavior_0.graph)

        # --- Expected graph for the second behavior.
        unrolled_1 = HEBGraph(behavior_1)
        b0_feature_0 = FeatureCondition(prefixed(behavior_0, "feature 0"))
        unrolled_1.add_edge(
            b0_feature_0, Action(0, prefixed(behavior_0, "Action(0)")), index=False
        )
        unrolled_1.add_edge(
            b0_feature_0, Action(1, prefixed(behavior_0, "Action(1)")), index=True
        )
        own_feature_1 = FeatureCondition("feature 1")
        own_feature_2 = FeatureCondition("feature 2")
        unrolled_1.add_edge(own_feature_1, b0_feature_0, index=False)
        unrolled_1.add_edge(own_feature_1, own_feature_2, index=True)
        unrolled_1.add_edge(own_feature_2, Action(0), index=False)
        unrolled_1.add_edge(own_feature_2, Action(2), index=True)

        # --- Expected graph for the third behavior.
        unrolled_2 = HEBGraph(behavior_2)
        own_feature_3 = FeatureCondition("feature 3")
        own_feature_4 = FeatureCondition("feature 4")
        own_feature_5 = FeatureCondition("feature 5")
        unrolled_2.add_edge(own_feature_3, own_feature_4, index=False)
        unrolled_2.add_edge(own_feature_3, own_feature_5, index=True)
        unrolled_2.add_edge(own_feature_4, Action(0), index=False)

        # Nodes carrying the (second behavior, first behavior) prefix chain.
        b1_b0_feature_0 = FeatureCondition(
            prefixed(behavior_1, behavior_0, "feature 0")
        )
        unrolled_2.add_edge(
            b1_b0_feature_0,
            Action(0, prefixed(behavior_1, behavior_0, "Action(0)")),
            index=False,
        )
        unrolled_2.add_edge(
            b1_b0_feature_0,
            Action(1, prefixed(behavior_1, behavior_0, "Action(1)")),
            index=True,
        )
        b1_feature_1 = FeatureCondition(prefixed(behavior_1, "feature 1"))
        b1_feature_2 = FeatureCondition(prefixed(behavior_1, "feature 2"))
        unrolled_2.add_edge(b1_feature_1, b1_b0_feature_0, index=False)
        unrolled_2.add_edge(b1_feature_1, b1_feature_2, index=True)
        unrolled_2.add_edge(
            b1_feature_2, Action(0, prefixed(behavior_1, "Action(0)")), index=False
        )
        unrolled_2.add_edge(
            b1_feature_2, Action(2, prefixed(behavior_1, "Action(2)")), index=True
        )

        unrolled_2.add_edge(own_feature_4, b1_feature_1, index=True)

        # Nodes carrying only the first behavior's prefix.
        direct_b0_feature_0 = FeatureCondition(prefixed(behavior_0, "feature 0"))
        unrolled_2.add_edge(
            direct_b0_feature_0,
            Action(0, prefixed(behavior_0, "Action(0)")),
            index=False,
        )
        unrolled_2.add_edge(
            direct_b0_feature_0,
            Action(1, prefixed(behavior_0, "Action(1)")),
            index=True,
        )

        unrolled_2.add_edge(own_feature_5, b1_feature_1, index=False)
        unrolled_2.add_edge(own_feature_5, direct_b0_feature_0, index=True)

        expected_graphs = (unrolled_0, unrolled_1, unrolled_2)
        for behavior, expected_graph in zip(self.behaviors, expected_graphs):
            unrolled_graph = unroll_graph(behavior.graph, add_prefix=True)
            check.is_true(is_isomorphic(unrolled_graph, expected_graph))

            # Visual debugging helper (needs the matplotlib import at the top of the file):
            # fig, axes = plt.subplots(1, 2)
            # unrolled_graph = behavior.graph.unrolled_graph
            # unrolled_graph.draw(axes[0], draw_behaviors_hulls=True)
            # expected_graph.draw(axes[1], draw_behaviors_hulls=True)
            # plt.show()
295 |
--------------------------------------------------------------------------------
/tests/test_code_generation.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Any
2 |
3 | import pytest
4 | import pytest_check as check
5 |
6 | from hebg import Action, FeatureCondition, Behavior
7 |
8 | from tests.examples.behaviors import (
9 | FundamentalBehavior,
10 | F_A_Behavior,
11 | F_F_A_Behavior,
12 | build_binary_sum_behavior,
13 | )
14 | from tests.examples.feature_conditions import ThresholdFeatureCondition
15 |
16 |
class TestABehavior:
    """(A) Fundamental behaviors (single Action node) should return action call."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build a fundamental behavior around the single action 42."""
        self.behavior = FundamentalBehavior(Action(42))

    def test_source_codegen(self):
        """Generated source should be one class whose call returns the action call."""
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class Action42Behavior(GeneratedBehavior):",
                " def __call__(self, observation):",
                " return self.actions['Action(42)'](observation)",
            )
        )
        source_code = self.behavior.graph.generate_source_code()
        check.equal(
            source_code,
            expected_source_code,
            # _unidiff_output takes (expected, actual): keep that order so the
            # '-' lines of the diff are the expectation and '+' the generated code.
            msg=_unidiff_output(expected_source_code, source_code),
        )

    def test_exec_codegen(self):
        """Executing the generated class should match the behavior on sample values."""
        check_execution_for_values(self.behavior, "Action42Behavior", (1, -1))
44 |
class TestFABehavior:
    """(F-A) Feature condition should generate if/else condition."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build a behavior made of one threshold feature condition and two actions."""
        above_zero = ThresholdFeatureCondition(relation=">=", threshold=0)
        actions = {0: Action(0), 1: Action(1)}
        self.behavior = F_A_Behavior("Is above_zero", above_zero, actions)

    def test_source_codegen(self):
        """Generated source should branch on the feature condition's edge index."""
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class IsAboveZero(GeneratedBehavior):",
                " def __call__(self, observation):",
                " edge_index = self.feature_conditions['Greater or equal to 0 ?'](observation)",
                " if edge_index == 0:",
                " return self.actions['Action(0)'](observation)",
                " if edge_index == 1:",
                " return self.actions['Action(1)'](observation)",
            )
        )
        source_code = self.behavior.graph.generate_source_code()

        check.equal(
            source_code,
            expected_source_code,
            # _unidiff_output takes (expected, actual): keep that order so the
            # '-' lines of the diff are the expectation and '+' the generated code.
            msg=_unidiff_output(expected_source_code, source_code),
        )

    def test_exec_codegen(self):
        """Executing the generated class should match the behavior on sample values."""
        check_execution_for_values(self.behavior, "IsAboveZero", (1, -1))
78 |
79 |
class TestFFABehavior:
    """(F-F-A) Chained FeatureConditions should generate nested if/else."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build the example behavior chaining two levels of feature conditions."""
        self.behavior = F_F_A_Behavior("scalar classification ]-1,0,1[ ?")

    def test_source_codegen(self):
        """Generated source should nest a second edge-index test under each branch."""
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class ScalarClassification101(GeneratedBehavior):",
                " def __call__(self, observation):",
                " edge_index = self.feature_conditions['Greater or equal to 0 ?'](observation)",
                " if edge_index == 0:",
                " edge_index_1 = self.feature_conditions['Greater or equal to -1 ?'](observation)",
                " if edge_index_1 == 0:",
                " return self.actions['Action(0)'](observation)",
                " if edge_index_1 == 1:",
                " return self.actions['Action(1)'](observation)",
                " if edge_index == 1:",
                " edge_index_1 = self.feature_conditions['Lesser or equal to 1 ?'](observation)",
                " if edge_index_1 == 0:",
                " return self.actions['Action(3)'](observation)",
                " if edge_index_1 == 1:",
                " return self.actions['Action(2)'](observation)",
            )
        )
        source_code = self.behavior.graph.generate_source_code()

        check.equal(
            source_code,
            expected_source_code,
            # _unidiff_output takes (expected, actual): keep that order so the
            # '-' lines of the diff are the expectation and '+' the generated code.
            msg=_unidiff_output(expected_source_code, source_code),
        )

    def test_exec_codegen(self):
        """Executing the generated class should match the behavior on sample values."""
        check_execution_for_values(
            self.behavior, "ScalarClassification101", (2, 1, -1, -2)
        )
122 |
class TestFBBehavior:
    """(F-BA) Behaviors should be unrolled by default if they appear only once and have a graph."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build a behavior whose second branch is a sub-behavior with its own graph."""
        above_zero = ThresholdFeatureCondition(relation=">=", threshold=0)
        sub_behavior = F_A_Behavior(
            "Is above_zero", above_zero, {0: Action(0), 1: Action(1)}
        )

        below_one = ThresholdFeatureCondition(relation="<=", threshold=1)
        self.behavior = F_A_Behavior(
            "Is between 0 and 1 ?", below_one, {0: Action(0), 1: sub_behavior}
        )

    def test_source_codegen(self):
        """Generated source should inline the sub-behavior as a nested condition."""
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class IsBetween0And1(GeneratedBehavior):",
                " def __call__(self, observation):",
                " edge_index = self.feature_conditions['Lesser or equal to 1 ?'](observation)",
                " if edge_index == 0:",
                " return self.actions['Action(0)'](observation)",
                " if edge_index == 1:",
                " edge_index_1 = self.feature_conditions['Greater or equal to 0 ?'](observation)",
                " if edge_index_1 == 0:",
                " return self.actions['Action(0)'](observation)",
                " if edge_index_1 == 1:",
                " return self.actions['Action(1)'](observation)",
            )
        )
        source_code = self.behavior.graph.generate_source_code()

        check.equal(
            source_code,
            expected_source_code,
            # _unidiff_output takes (expected, actual): keep that order so the
            # '-' lines of the diff are the expectation and '+' the generated code.
            msg=_unidiff_output(expected_source_code, source_code),
        )

    def test_exec_codegen(self):
        """Executing the generated class should match the behavior on sample values."""
        check_execution_for_values(self.behavior, "IsBetween0And1", (-1, 0, 1, 2))
164 |
165 |
class TestFBBehaviorNameRef:
    """(F-BA) Behaviors should work with only name reference to behavior,
    but will expect behavior to be given, even when unrolled."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build a behavior whose second branch is a graph-less behavior given by name."""
        below_one = ThresholdFeatureCondition(relation="<=", threshold=1)
        actions = {0: Action(0), 1: Behavior("Is above_zero")}
        self.behavior = F_A_Behavior("Is between 0 and 1 ?", below_one, actions)

    @pytest.mark.filterwarnings("ignore:Could not load graph for behavior")
    def test_source_codegen_by_ref(self):
        """Both the graph and its unrolled graph should generate a call to the
        referenced behavior, along with a comment requiring it to be given."""
        # Each local is named after the graph its source code is generated from.
        source_code = self.behavior.graph.generate_source_code()
        unrolled_source_code = self.behavior.graph.unrolled_graph.generate_source_code()
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "# Require 'Is above_zero' behavior to be given.",
                "class IsBetween0And1(GeneratedBehavior):",
                " def __call__(self, observation):",
                " edge_index = self.feature_conditions['Lesser or equal to 1 ?'](observation)",
                " if edge_index == 0:",
                " return self.actions['Action(0)'](observation)",
                " if edge_index == 1:",
                " return self.known_behaviors['Is above_zero'](observation)",
            )
        )

        # _unidiff_output takes (expected, actual): keep that order so the
        # '-' lines of the diff are the expectation and '+' the generated code.
        check.equal(
            unrolled_source_code,
            expected_source_code,
            msg=_unidiff_output(expected_source_code, unrolled_source_code),
        )
        check.equal(
            source_code,
            expected_source_code,
            msg=_unidiff_output(expected_source_code, source_code),
        )

    def test_source_codegen_in_all_behavior(self):
        """When the behavior is found in 'all_behaviors'
        it should use the found behavior for codegen."""
        above_zero = ThresholdFeatureCondition(relation=">=", threshold=0)
        sub_behavior = F_A_Behavior(
            "Is above_zero", above_zero, {0: Action(0), 1: Action(1)}
        )
        self.behavior.graph.all_behaviors["Is above_zero"] = sub_behavior

        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class IsBetween0And1(GeneratedBehavior):",
                " def __call__(self, observation):",
                " edge_index = self.feature_conditions['Lesser or equal to 1 ?'](observation)",
                " if edge_index == 0:",
                " return self.actions['Action(0)'](observation)",
                " if edge_index == 1:",
                " edge_index_1 = self.feature_conditions['Greater or equal to 0 ?'](observation)",
                " if edge_index_1 == 0:",
                " return self.actions['Action(0)'](observation)",
                " if edge_index_1 == 1:",
                " return self.actions['Action(1)'](observation)",
            )
        )
        source_code = self.behavior.graph.generate_source_code()

        check.equal(
            source_code,
            expected_source_code,
            # Same (expected, actual) order as the helper's signature.
            msg=_unidiff_output(expected_source_code, source_code),
        )

    def test_exec_codegen(self):
        """When the behavior is found in 'all_behaviors'
        it should use the found behavior for graph call."""
        above_zero = ThresholdFeatureCondition(relation=">=", threshold=0)
        sub_behavior = F_A_Behavior(
            "Is above_zero", above_zero, {0: Action(0), 1: Action(1)}
        )
        self.behavior.graph.all_behaviors["Is above_zero"] = sub_behavior
        check_execution_for_values(
            self.behavior,
            "IsBetween0And1",
            (-1, 0, 1, 2),
            known_behaviors={"Is above_zero": sub_behavior},
        )
251 |
252 |
class TestNestedBehaviorReuse:
    """Behaviors should be rolled when they are used multiple times in nested subgraphs."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build 'behavior 2', which uses 'behavior 0' both directly and through 'behavior 1'."""
        behavior_0 = F_A_Behavior(
            "behavior 0", FeatureCondition(name="fc1"), {0: Action(0), 1: Action(1)}
        )
        behavior_1 = F_A_Behavior(
            "behavior 1", FeatureCondition(name="fc2"), {0: Action(0), 1: behavior_0}
        )
        self.behavior = F_A_Behavior(
            "behavior 2", FeatureCondition(name="fc3"), {0: behavior_0, 1: behavior_1}
        )

    def test_nested_reuse_codegen(self):
        """The reused behavior should get its own class and be called through
        'known_behaviors', while the behavior used once is inlined."""
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class Behavior0(GeneratedBehavior):",
                " def __call__(self, observation):",
                " edge_index = self.feature_conditions['fc1'](observation)",
                " if edge_index == 0:",
                " return self.actions['Action(0)'](observation)",
                " if edge_index == 1:",
                " return self.actions['Action(1)'](observation)",
                "class Behavior2(GeneratedBehavior):",
                " def __call__(self, observation):",
                " edge_index = self.feature_conditions['fc3'](observation)",
                " if edge_index == 0:",
                " return self.known_behaviors['behavior 0'](observation)",
                " if edge_index == 1:",
                " edge_index_1 = self.feature_conditions['fc2'](observation)",
                " if edge_index_1 == 0:",
                " return self.actions['Action(0)'](observation)",
                " if edge_index_1 == 1:",
                " return self.known_behaviors['behavior 0'](observation)",
                "BEHAVIOR_TO_NAME = {",
                " 'behavior 0': Behavior0,",
                "}",
            )
        )
        source_code = self.behavior.graph.generate_source_code()

        check.equal(
            source_code,
            expected_source_code,
            # _unidiff_output takes (expected, actual): keep that order so the
            # '-' lines of the diff are the expectation and '+' the generated code.
            msg=_unidiff_output(expected_source_code, source_code),
        )
305 |
306 |
class TestFundamentalBehaviorReuse:
    """Fundamental Behaviors should never be abstracted."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build 'behavior 2', which uses one fundamental behavior directly and through 'behavior 1'."""
        fundamental = FundamentalBehavior(Action(1))
        behavior_1 = F_A_Behavior(
            "behavior 1", FeatureCondition(name="fc2"), {0: Action(0), 1: fundamental}
        )
        self.behavior = F_A_Behavior(
            "behavior 2", FeatureCondition(name="fc3"), {0: fundamental, 1: behavior_1}
        )

    def test_nested_reuse_codegen(self):
        """The fundamental behavior should be emitted as a plain action call
        everywhere it is used, without a class of its own."""
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class Behavior2(GeneratedBehavior):",
                " def __call__(self, observation):",
                " edge_index = self.feature_conditions['fc3'](observation)",
                " if edge_index == 0:",
                " return self.actions['Action(1)'](observation)",
                " if edge_index == 1:",
                " edge_index_1 = self.feature_conditions['fc2'](observation)",
                " if edge_index_1 == 0:",
                " return self.actions['Action(0)'](observation)",
                " if edge_index_1 == 1:",
                " return self.actions['Action(1)'](observation)",
            )
        )
        source_code = self.behavior.graph.generate_source_code()

        check.equal(
            source_code,
            expected_source_code,
            # _unidiff_output takes (expected, actual): keep that order so the
            # '-' lines of the diff are the expectation and '+' the generated code.
            msg=_unidiff_output(expected_source_code, source_code),
        )
347 |
348 |
class TestFBBBehavior:
    """(F-B-B) Behaviors should only be added once as a class."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build the binary sum example behavior."""
        self.behavior = build_binary_sum_behavior()

    def test_classes_in_codegen(self):
        """Each expected class should be declared exactly once in the generated source."""
        source_code = self.behavior.graph.generate_source_code()
        expected_classes = (
            "IsX1InBinary",  # Only this one is used twice
            "IsSumOfLast3Binary2",
        )

        for expected_class in expected_classes:
            declaration_count = source_code.count(f"class {expected_class}")
            check.equal(
                declaration_count,
                1,
                msg=f"Missing or duplicated class: {expected_class}\n{source_code}",
            )

    def test_exec_codegen(self):
        """Executing the generated class should match the behavior on sample values."""
        sample_values = (0, 1, 3, 5, 15)
        check_execution_for_values(self.behavior, "IsSumOfLast3Binary2", sample_values)
374 |
375 |
def check_execution_for_values(
    behavior: Behavior,
    class_name: str,
    values: Tuple[Any, ...],
    known_behaviors: Optional[dict] = None,
):
    """Check that the code generated for a behavior gives the same results as the behavior.

    The source code generated from ``behavior.graph`` is executed, the class named
    ``class_name`` is instantiated with the actions, feature conditions and behaviors
    found in the behavior's graph and in the graphs of its sub-behaviors, and both
    the original behavior and the rebuilt one are called on every value.

    Args:
        behavior: Behavior whose graph is used to generate the source code.
        class_name: Name of the generated class to instantiate.
        values: Observations given to both the behavior and the generated class.
        known_behaviors: Optional extra behaviors, by name, added to (and overriding)
            the behaviors found in the graphs.
    """
    generated_source_code = behavior.graph.generate_source_code()

    # Run the generated module in its own namespace: since Python 3.13 (PEP 667),
    # names created by a bare exec() inside a function are not visible through a
    # later locals() call, so the generated class has to be read from an explicit
    # dictionary. A single dictionary also acts as the generated code's globals.
    namespace: dict = {}
    exec(generated_source_code, namespace)
    CodeGenPolicy = namespace[class_name]

    actions, feature_conditions, behaviors = separate_nodes_by_type(behavior)

    # Walk through sub-behaviors, collecting the nodes of each of their graphs.
    _behaviors = behaviors.copy()
    while _behaviors:
        _, sub_behavior = _behaviors.popitem()
        # Prefer the behavior registered in 'all_behaviors' when there is one.
        if sub_behavior in behavior.graph.all_behaviors:
            sub_behavior = behavior.graph.all_behaviors[sub_behavior]
        sub_actions, sub_feature_conditions, sub_behaviors = separate_nodes_by_type(
            sub_behavior
        )
        actions.update(sub_actions)
        feature_conditions.update(sub_feature_conditions)
        _behaviors.update(sub_behaviors)
        behaviors.update(sub_behaviors)

    if known_behaviors is not None:
        behaviors.update(known_behaviors)

    behavior_rebuilt = CodeGenPolicy(
        actions=actions,
        feature_conditions=feature_conditions,
        behaviors=behaviors,
    )

    for val in values:
        check.equal(behavior(val), behavior_rebuilt(val))
412 |
413 |
def separate_nodes_by_type(behavior: Behavior):
    """Sort the nodes of a behavior's graph into three name-indexed dictionaries.

    Args:
        behavior: Behavior whose ``graph.nodes`` are inspected.

    Returns:
        A tuple ``(actions, feature_conditions, behaviors)`` of dictionaries mapping
        node names to the nodes that are instances of Action, FeatureCondition and
        Behavior respectively.
    """
    actions, feature_conditions, behaviors = {}, {}, {}
    for node in behavior.graph.nodes:
        # Independent tests: a node is stored in every category it is an instance of.
        if isinstance(node, Action):
            actions[node.name] = node
        if isinstance(node, FeatureCondition):
            feature_conditions[node.name] = node
        if isinstance(node, Behavior):
            behaviors[node.name] = node
    return actions, feature_conditions, behaviors
427 |
428 |
429 | def _unidiff_output(expected: str, actual: str):
430 | """
431 | Helper function. Returns a string containing the unified diff of two multiline strings.
432 | """
433 |
434 | import difflib
435 |
436 | expected = [line for line in expected.split("\n")]
437 | actual = [line for line in actual.split("\n")]
438 |
439 | diff = difflib.unified_diff(expected, actual)
440 |
441 | return "\n".join(diff)
442 |
--------------------------------------------------------------------------------