├── tests ├── examples │ ├── __init__.py │ ├── behaviors │ │ ├── __init__.py │ │ ├── binary_sum.py │ │ ├── loop_with_alternative.py │ │ ├── report_example.py │ │ ├── basic_empty.py │ │ ├── loop_without_alternative.py │ │ └── basic.py │ └── feature_conditions.py ├── __init__.py ├── test_metaheuristics.py ├── test_behavior_empty.py ├── test_node.py ├── test_loop.py ├── test_behavior.py ├── test_pet_a_cat.py ├── test_call_graph.py ├── test_paper_basic_example.py └── test_code_generation.py ├── setup.py ├── commands └── coverage.ps1 ├── requirements.txt ├── .coveragerc ├── docs └── images │ ├── PetACatGraph.png │ └── PetACatGraphUnrolled.png ├── src └── hebg │ ├── metrics │ ├── __init__.py │ ├── utility │ │ ├── __init__.py │ │ └── binary_utility.py │ ├── complexity │ │ ├── __init__.py │ │ ├── utils.py │ │ └── complexities.py │ └── histograms.py │ ├── layouts │ ├── __init__.py │ ├── deterministic.py │ ├── metaheuristics.py │ └── metabased.py │ ├── __init__.py │ ├── behavior.py │ ├── node.py │ ├── requirements_graph.py │ ├── heb_graph.py │ ├── draw.py │ ├── graph.py │ ├── unrolling.py │ ├── call_graph.py │ └── codegen.py ├── .github └── workflows │ ├── python-tests.yml │ ├── python-pypi.yml │ └── python-coverage.yml ├── .pre-commit-config.yaml ├── shell.nix ├── pyproject.toml ├── CONTRIBUTING.rst ├── .gitignore └── README.rst /tests/examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Build shim: setup() takes no arguments; the package metadata lives in pyproject.toml. 2 | from setuptools import setup 3 | 4 | 5 | setup() 6 | -------------------------------------------------------------------------------- /commands/coverage.ps1: -------------------------------------------------------------------------------- 1 | # Run pytest with coverage measured on src/, reporting as HTML and in the terminal. 2 | pytest --cov=src --cov-report=html --cov-report=term 3 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | networkx >= 2.5.1 2 | matplotlib 3 | numpy 4 | tqdm 5 | scipy -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | show_missing = True 3 | omit = 4 | */hebg/layouts/* 5 | -------------------------------------------------------------------------------- /docs/images/PetACatGraph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IRLL/HEB_graphs/HEAD/docs/images/PetACatGraph.png -------------------------------------------------------------------------------- /docs/images/PetACatGraphUnrolled.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IRLL/HEB_graphs/HEAD/docs/images/PetACatGraphUnrolled.png -------------------------------------------------------------------------------- /src/hebg/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Module containing HEBGraph metrics.""" 5 | -------------------------------------------------------------------------------- /src/hebg/metrics/utility/__init__.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Module for utility computation methods.""" 5 | 6 | from hebg.metrics.utility.binary_utility import binary_graphbased_utility 7 | 8 | __all__ = ["binary_graphbased_utility"] 9 | 
-------------------------------------------------------------------------------- /src/hebg/metrics/complexity/__init__.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Module for complexity computation methods.""" 5 | 6 | from hebg.metrics.complexity.complexities import general_complexity, learning_complexity 7 | 8 | __all__ = ["general_complexity", "learning_complexity"] 9 | -------------------------------------------------------------------------------- /src/hebg/layouts/__init__.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Module containing layouts to draw graphs.""" 5 | 6 | from hebg.layouts.deterministic import staircase_layout 7 | from hebg.layouts.metabased import leveled_layout_energy 8 | 9 | __all__ = ["staircase_layout", "leveled_layout_energy"] 10 | -------------------------------------------------------------------------------- /src/hebg/metrics/complexity/utils.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Utility functions for complexity computation.""" 5 | 6 | from copy import deepcopy 7 | 8 | 9 | def update_sum_dict(dict1: dict, dict2: dict) -> dict: 10 | """Give the sum of two dictionaries. 11 | 12 | Each value of dict2 is added to dict1's value under the same key; keys 13 | missing from dict1 take dict2's value as-is. Values of dict2 that are 14 | themselves dicts are skipped (no recursive sum). 15 | 16 | Args: 17 | dict1: Base dictionary. 18 | dict2: Dictionary whose non-dict values are added into dict1. 19 | 20 | Returns: 21 | A new dictionary; neither argument is mutated. 22 | 23 | """ 24 | # Deep-copy both so in-place `+=` on mutable values never touches the caller's objects. 25 | dict1, dict2 = deepcopy(dict1), deepcopy(dict2) 26 | for key, val in dict2.items(): 27 | if not isinstance(val, dict): 28 | try: 29 | dict1[key] += val 30 | except KeyError: 31 | dict1[key] = val 32 | return dict1 33 | -------------------------------------------------------------------------------- /tests/__init__.py: 
-------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Tests for the heb_graph package.""" 5 | 6 | from typing import Protocol 7 | from matplotlib import pyplot as plt 8 | 9 | 10 | class Graph(Protocol): 11 | def draw(self, ax, pos): 12 | """Draw the graph on a matplotlib axes.""" 13 | 14 | def nodes(self) -> list: 15 | """Return a list of nodes""" 16 | 17 | 18 | def plot_graph(graph: Graph, **kwargs): 19 | _, ax = plt.subplots() 20 | graph.draw(ax, **kwargs) 21 | plt.axis("off") # turn off axis 22 | plt.show() 23 | -------------------------------------------------------------------------------- /src/hebg/__init__.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """A structure for explainable hierarchical reinforcement learning""" 5 | 6 | from hebg.behavior import Behavior 7 | from hebg.heb_graph import HEBGraph 8 | from hebg.node import Action, EmptyNode, FeatureCondition, Node, StochasticAction 9 | from hebg.requirements_graph import build_requirement_graph 10 | 11 | __all__ = [ 12 | "HEBGraph", 13 | "Behavior", 14 | "Action", 15 | "EmptyNode", 16 | "FeatureCondition", 17 | "Node", 18 | "StochasticAction", 19 | "build_requirement_graph", 20 | ] 21 | -------------------------------------------------------------------------------- /.github/workflows/python-tests.yml: -------------------------------------------------------------------------------- 1 | name: Python tests 2 | 3 | on: ["push"] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: windows-latest 9 | strategy: 10 | matrix: 11 | python-version: ['3.10', '3.11', '3.12'] 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v1 16 | with: 17 
| python-version: ${{ matrix.python-version }} 18 | - name: Install dependencies 19 | run: | 20 | git submodule update --init --recursive 21 | python -m pip install --upgrade pip 22 | pip install .[dev] 23 | - name: Test with pytest 24 | run: | 25 | pytest tests 26 | -------------------------------------------------------------------------------- /.github/workflows/python-pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPi 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | release: 8 | types: [published] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | with: 18 | fetch-depth: 0 # full history and tags, so setuptools_scm can derive the version 19 | - name: Set up Python 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: '3.x' 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install build twine 27 | - name: Build 28 | run: | 29 | python -m build 30 | - name: Publish 31 | env: 32 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 33 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 34 | run: | 35 | # This workflow can run twice per version (tag push, then published release); 36 | # --skip-existing makes the second run a no-op instead of a duplicate-upload error. 37 | twine upload --skip-existing dist/* -------------------------------------------------------------------------------- /.github/workflows/python-coverage.yml: -------------------------------------------------------------------------------- 1 | name: Python coverage 2 | 3 | on: ["push"] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: Set up Python 3.10 11 | uses: actions/setup-python@v5 12 | with: 13 | python-version: '3.10' 14 | - name: Install dependencies 15 | run: | 16 | git submodule update --init --recursive 17 | python -m pip install --upgrade pip 18 | pip install -e .[dev] 19 | - name: Build coverage using pytest-cov 20 | run: | 21 | pytest --cov=src --cov-report=xml tests 22 | - name: Codacy Coverage Reporter 23 | uses: codacy/codacy-coverage-reporter-action@v1.3.0 24 | with: 25 | project-token: ${{ secrets.CODACY_PROJECT_TOKEN }} 26 | 
coverage-reports: coverage.xml 27 | -------------------------------------------------------------------------------- /tests/test_metaheuristics.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Unit tests for the hebg.metaheuristics module.""" 5 | 6 | import numpy as np 7 | 8 | import pytest_check as check 9 | 10 | from hebg.layouts.metaheuristics import simulated_annealing 11 | 12 | 13 | def test_simulated_annealing(): 14 | """Simulated annealing must work on the simple x**2 case.""" 15 | 16 | def energy(x): 17 | return x**2 18 | 19 | step_size = 0.05 20 | 21 | def neighbor(x): 22 | return x + np.random.choice([-1, 1]) * step_size 23 | 24 | optimal_x = simulated_annealing( 25 | -1, energy, neighbor, max_iterations=1000, initial_temperature=5, verbose=1 26 | ) 27 | 28 | check.less_equal(abs(optimal_x), 3.1 * step_size) 29 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/astral-sh/ruff-pre-commit 9 | rev: v0.1.11 10 | hooks: 11 | - id: ruff 12 | args: [ --fix ] 13 | - id: ruff-format 14 | - repo: local 15 | hooks: 16 | - id: pytest-fast-check 17 | name: pytest-fast-check 18 | entry: pytest -m "not slow" 19 | stages: ["pre-commit"] 20 | language: system 21 | pass_filenames: false 22 | always_run: true 23 | - repo: local 24 | hooks: 25 | - id: pytest-check 26 | name: pytest-check 27 | entry: pytest 28 | stages: ["pre-push"] 29 | language: system 30 | pass_filenames: false 31 | always_run: true 32 | 
-------------------------------------------------------------------------------- /tests/examples/behaviors/__init__.py: -------------------------------------------------------------------------------- 1 | from tests.examples.behaviors.basic import ( 2 | FundamentalBehavior, 3 | F_A_Behavior, 4 | F_AA_Behavior, 5 | F_F_A_Behavior, 6 | ) 7 | from tests.examples.behaviors.basic_empty import ( 8 | E_A_Behavior, 9 | E_F_A_Behavior, 10 | F_E_A_Behavior, 11 | E_E_A_Behavior, 12 | ) 13 | from tests.examples.behaviors.binary_sum import build_binary_sum_behavior 14 | from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors 15 | from tests.examples.behaviors.loop_without_alternative import ( 16 | build_looping_behaviors_without_direct_alternatives, 17 | ) 18 | 19 | 20 | __all__ = [ 21 | "FundamentalBehavior", 22 | "F_A_Behavior", 23 | "F_AA_Behavior", 24 | "F_F_A_Behavior", 25 | "E_A_Behavior", 26 | "E_F_A_Behavior", 27 | "F_E_A_Behavior", 28 | "E_E_A_Behavior", 29 | "build_binary_sum_behavior", 30 | "build_looping_behaviors", 31 | "build_looping_behaviors_without_direct_alternatives", 32 | ] 33 | -------------------------------------------------------------------------------- /shell.nix: -------------------------------------------------------------------------------- 1 | with import <nixpkgs> { }; 2 | 3 | let 4 | pythonPackages = python3Packages; 5 | in pkgs.mkShell rec { 6 | name = "localDevPythonEnv"; 7 | venvDir = "./.venv"; 8 | buildInputs = [ 9 | # A Python interpreter including the 'venv' module is required to bootstrap 10 | # the environment. 11 | pythonPackages.python 12 | 13 | # This executes some shell code to initialize a venv in $venvDir before 14 | # dropping into the shell 15 | pythonPackages.venvShellHook 16 | 17 | # These are dependencies that we would like to use from nixpkgs, which will 18 | # add them to PYTHONPATH and thus make them accessible from within the venv. 
19 | pythonPackages.numpy 20 | pythonPackages.networkx 21 | pythonPackages.matplotlib 22 | pythonPackages.numpy 23 | pythonPackages.tqdm 24 | pythonPackages.scipy 25 | ]; 26 | 27 | # Run this command, only after creating the virtual environment 28 | postVenvCreation = '' 29 | pip install -e '.[dev]' 30 | ''; 31 | 32 | # Now we can execute any commands within the virtual environment. 33 | # This is optional and can be left out to run pip manually. 34 | postShellHook = '' 35 | ''; 36 | 37 | } 38 | -------------------------------------------------------------------------------- /tests/examples/behaviors/binary_sum.py: -------------------------------------------------------------------------------- 1 | from hebg import Action 2 | from tests.examples.behaviors import F_A_Behavior 3 | from tests.examples.feature_conditions import IsDivisibleFeatureCondition 4 | 5 | 6 | def build_binary_sum_behavior() -> F_A_Behavior: 7 | feature_condition = IsDivisibleFeatureCondition(2) 8 | actions = {0: Action(0), 1: Action(1)} 9 | binary_1 = F_A_Behavior("Is x1 in binary ?", feature_condition, actions) 10 | 11 | feature_condition = IsDivisibleFeatureCondition(2) 12 | actions = {0: Action(1), 1: Action(0)} 13 | binary_0 = F_A_Behavior("Is x0 in binary ?", feature_condition, actions) 14 | 15 | feature_condition = IsDivisibleFeatureCondition(4) 16 | actions = {0: Action(0), 1: binary_1} 17 | binary_11 = F_A_Behavior("Is x11 in binary ?", feature_condition, actions) 18 | 19 | feature_condition = IsDivisibleFeatureCondition(4) 20 | actions = {0: binary_0, 1: binary_1} 21 | binary_10_01 = F_A_Behavior("Is x01 or x10 in binary ?", feature_condition, actions) 22 | 23 | feature_condition = IsDivisibleFeatureCondition(8) 24 | actions = {0: binary_11, 1: binary_10_01} 25 | 26 | return F_A_Behavior("Is sum (of last 3 binary) 2 ?", feature_condition, actions) 27 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] 3 | 4 | [project] 5 | name = "hebg" 6 | authors = [ 7 | { name = "Mathïs Fédérico", email = "mathfederico@gmail.com" }, 8 | ] 9 | description = "HEBG: Hierarchical Explanations of Behavior as Graph" 10 | dynamic = ["version", "readme", "dependencies"] 11 | license = { text = "GPLv3 license" } 12 | requires-python = ">=3.7" 13 | 14 | [project.urls] 15 | repository = "https://github.com/IRLL/HEB_graphs" 16 | 17 | 18 | [tool.setuptools] 19 | license-files = ['LICEN[CS]E*', 'COPYING*', 'NOTICE*', 'AUTHORS*'] 20 | 21 | [project.optional-dependencies] 22 | dev = ["ruff", "pytest", "pytest-check", "pytest-mock", "pytest-cov", "mypy", "pre-commit"] 23 | 24 | [project.scripts] 25 | 26 | [tool.setuptools.dynamic] 27 | readme = { file = ["README.rst"] } 28 | dependencies = { file = ["requirements.txt"] } 29 | 30 | [tool.setuptools_scm] 31 | 32 | [tool.mypy] 33 | files = "src/hebg" 34 | check_untyped_defs = true 35 | disallow_any_generics = false 36 | disallow_incomplete_defs = true 37 | no_implicit_optional = true 38 | no_implicit_reexport = false 39 | strict_equality = true 40 | warn_redundant_casts = true 41 | warn_unused_ignores = true 42 | ignore_missing_imports = true 43 | -------------------------------------------------------------------------------- /tests/test_behavior_empty.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Behavior of HEBGraphs with empty nodes.""" 5 | 6 | import pytest_check as check 7 | 8 | from hebg.node import Action 9 | from tests.examples.behaviors import ( 10 | E_F_A_Behavior, 11 | E_A_Behavior, 12 | F_E_A_Behavior, 13 | E_E_A_Behavior, 14 | ) 15 | 16 | 17 | def test_e_a_graph(): 18 | 
"""(E-A) Empty nodes should skip to successor.""" 19 | action_id = 42 20 | behavior = E_A_Behavior("E_A", Action(action_id)) 21 | check.equal(behavior(None), action_id) 22 | 23 | 24 | def test_e_f_a_graph(): 25 | """(E-F-A) Empty should orient path properly in chain with Feature condition.""" 26 | behavior = E_F_A_Behavior("E_F_A") 27 | check.equal(behavior(-1), 0) 28 | check.equal(behavior(1), 1) 29 | 30 | 31 | def test_f_e_a_graph(): 32 | """(F-E-A) Feature condition should orient path properly in chain with Empty.""" 33 | behavior = F_E_A_Behavior("F_E_A") 34 | check.equal(behavior(1), 0) 35 | check.equal(behavior(-1), 1) 36 | 37 | 38 | def test_e_e_a_graph(): 39 | """(E-E-A) Empty should orient path properly in double chain.""" 40 | behavior = E_E_A_Behavior("E_E_A") 41 | check.equal(behavior(None), 0) 42 | -------------------------------------------------------------------------------- /src/hebg/metrics/utility/binary_utility.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Simplest binary utility for HEBGraph.""" 5 | 6 | from typing import Dict, List 7 | 8 | from hebg import Behavior 9 | 10 | 11 | def binary_graphbased_utility( 12 | behavior: Behavior, 13 | solving_behaviors: List[Behavior], 14 | used_nodes: Dict[str, Dict[str, int]], 15 | ) -> bool: 16 | """Return whether behavior is, or is used by, any of the solving_behaviors. 17 | 18 | Args: 19 | behavior: Behavior of which we want to compute the utility. 20 | solving_behaviors: list of behaviors that solve the task of interest. 21 | used_nodes: maps each solving behavior to its {node: usage count} dict. 22 | 23 | Returns: 24 | True if behavior is, or is used by, any solving_behavior. False otherwise. 
25 | 26 | """ 27 | 28 | for solving_behavior in solving_behaviors: 29 | if behavior == solving_behavior: 30 | return True 31 | if behavior in solving_behavior.graph.nodes(): 32 | return True 33 | if ( 34 | behavior in used_nodes[solving_behavior] 35 | and used_nodes[solving_behavior][behavior] > 0 36 | ): 37 | return True 38 | return False 39 | -------------------------------------------------------------------------------- /tests/examples/feature_conditions.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from enum import Enum 3 | 4 | from hebg.node import FeatureCondition 5 | 6 | 7 | class ThresholdFeatureCondition(FeatureCondition): 8 | class Relation(Enum): 9 | GREATER_OR_EQUAL_TO = ">=" 10 | LESSER_OR_EQUAL_TO = "<=" 11 | GREATER_THAN = ">" 12 | LESSER_THAN = "<" 13 | 14 | def __init__( 15 | self, relation: Union[Relation, str] = ">=", threshold: float = 0, **kwargs 16 | ) -> None: 17 | """Threshold-based feature condition for scalar feature.""" 18 | self.relation = relation 19 | self.threshold = threshold 20 | self._relation = self.Relation(relation) 21 | display_name = self._relation.name.capitalize().replace("_", " ") 22 | name = f"{display_name} {threshold} ?" 23 | super().__init__(name=name, **kwargs) 24 | 25 | def __call__(self, observation: float) -> int: 26 | conditions = { 27 | self.Relation.GREATER_OR_EQUAL_TO: int(observation >= self.threshold), 28 | self.Relation.LESSER_OR_EQUAL_TO: int(observation <= self.threshold), 29 | self.Relation.GREATER_THAN: int(observation > self.threshold), 30 | self.Relation.LESSER_THAN: int(observation < self.threshold), 31 | } 32 | if self._relation in conditions: 33 | return conditions[self._relation] 34 | 35 | 36 | class IsDivisibleFeatureCondition(FeatureCondition): 37 | def __init__(self, number: int = 0) -> None: 38 | """Is divisible feature condition for scalar feature.""" 39 | self.number = number 40 | name = f"Is divisible by {number} ?" 
41 | super().__init__(name=name, image=None) 42 | 43 | def __call__(self, observation: float) -> int: 44 | return int(observation // self.number == 1) 45 | -------------------------------------------------------------------------------- /src/hebg/behavior.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Module for base Behavior.""" 5 | 6 | from __future__ import annotations 7 | 8 | from typing import TYPE_CHECKING 9 | 10 | from hebg.graph import compute_levels 11 | from hebg.node import Node 12 | 13 | if TYPE_CHECKING: 14 | from hebg.heb_graph import HEBGraph 15 | 16 | 17 | class Behavior(Node): 18 | """Abstract class for a Behavior as Node""" 19 | 20 | def __init__(self, name: str, image=None, **kwargs) -> None: 21 | super().__init__(name, "behavior", image=image, **kwargs) 22 | self._graph = None 23 | 24 | def __call__(self, observation, *args, **kwargs) -> None: 25 | """Use the behavior to get next actions. 26 | 27 | By default, uses the HEBGraph if it can be built. 28 | 29 | Args: 30 | observation: Observations of the environment. 31 | greedy: If true, the agent should act greedily. 32 | 33 | Returns: 34 | action: Action given by the behavior with current observation. 35 | 36 | """ 37 | return self.graph.__call__(observation, *args, **kwargs) 38 | 39 | def build_graph(self) -> HEBGraph: 40 | """Build the HEBGraph of this Behavior. 41 | 42 | Returns: 43 | The built HEBGraph. 44 | 45 | """ 46 | raise NotImplementedError 47 | 48 | @property 49 | def graph(self) -> HEBGraph: 50 | """Access to the Behavior's graph. 51 | 52 | Only build's the graph the first time called for efficiency. 53 | 54 | Returns: 55 | This Behavior's HEBGraph. 
56 | 57 | """ 58 | if self._graph is None: 59 | self._graph = self.build_graph() 60 | compute_levels(self._graph) 61 | return self._graph 62 | -------------------------------------------------------------------------------- /tests/examples/behaviors/loop_with_alternative.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | from hebg import HEBGraph, Action, FeatureCondition, Behavior 4 | 5 | 6 | class HasItem(FeatureCondition): 7 | def __init__(self, item_name: str) -> None: 8 | self.item_name = item_name 9 | super().__init__(name=f"Has {item_name} ?", complexity=1.0) 10 | 11 | def __call__(self, observation: Any) -> int: 12 | return self.item_name in observation 13 | 14 | 15 | class GatherWood(Behavior): 16 | """Gather wood""" 17 | 18 | def __init__(self) -> None: 19 | """Gather wood""" 20 | super().__init__("Gather wood") 21 | 22 | def build_graph(self) -> HEBGraph: 23 | graph = HEBGraph(self) 24 | has_axe = HasItem("axe") 25 | graph.add_edge(has_axe, Action("Punch tree", complexity=2.0), index=False) 26 | graph.add_edge(has_axe, Behavior("Get new axe", complexity=1.0), index=False) 27 | graph.add_edge(has_axe, Action("Use axe on tree", complexity=1.0), index=True) 28 | return graph 29 | 30 | 31 | class GetNewAxe(Behavior): 32 | """Get new axe with wood""" 33 | 34 | def __init__(self) -> None: 35 | """Get new axe with wood""" 36 | super().__init__("Get new axe") 37 | 38 | def build_graph(self) -> HEBGraph: 39 | graph = HEBGraph(self) 40 | has_wood = HasItem("wood") 41 | graph.add_edge(has_wood, Behavior("Gather wood", complexity=1.0), index=False) 42 | graph.add_edge( 43 | has_wood, Action("Summon axe out of thin air", complexity=10.0), index=False 44 | ) 45 | graph.add_edge(has_wood, Action("Craft axe", complexity=1.0), index=True) 46 | return graph 47 | 48 | 49 | def build_looping_behaviors() -> List[Behavior]: 50 | behaviors: List[Behavior] = [GatherWood(), GetNewAxe()] 51 | all_behaviors 
= {behavior.name: behavior for behavior in behaviors} 52 | for behavior in behaviors: 53 | behavior.graph.all_behaviors = all_behaviors # presumably lets each graph resolve its placeholder Behavior nodes by name; confirm in HEBGraph 54 | behavior.complexity = 5 55 | return behaviors 56 | -------------------------------------------------------------------------------- /tests/examples/behaviors/report_example.py: -------------------------------------------------------------------------------- 1 | from hebg import HEBGraph, Action, FeatureCondition, Behavior 2 | 3 | 4 | class Behavior0(Behavior): 5 | """Behavior 0: feature 0 chooses Action 0 (False) or Action 1 (True).""" 6 | 7 | def __init__(self) -> None: 8 | super().__init__("behavior 0") 9 | 10 | def build_graph(self) -> HEBGraph: 11 | graph = HEBGraph(self) 12 | feature = FeatureCondition("feature 0", complexity=1) 13 | graph.add_edge(feature, Action(0, complexity=1), index=False) 14 | graph.add_edge(feature, Action(1, complexity=1), index=True) 15 | return graph 16 | 17 | 18 | class Behavior1(Behavior): 19 | """Behavior 1: feature 1 chooses Behavior0 (False) or feature 2 (True).""" 20 | 21 | def __init__(self) -> None: 22 | super().__init__("behavior 1") 23 | 24 | def build_graph(self) -> HEBGraph: 25 | graph = HEBGraph(self) 26 | feature_1 = FeatureCondition("feature 1", complexity=1) 27 | feature_2 = FeatureCondition("feature 2", complexity=1) 28 | graph.add_edge(feature_1, Behavior0(), index=False) 29 | graph.add_edge(feature_1, feature_2, index=True) 30 | graph.add_edge(feature_2, Action(0, complexity=1), index=False) 31 | graph.add_edge(feature_2, Action(2, complexity=1), index=True) 32 | return graph 33 | 34 | 35 | class Behavior2(Behavior): 36 | """Behavior 2: feature 3 chooses feature 4 (False) or feature 5 (True).""" 37 | 38 | def __init__(self) -> None: 39 | super().__init__("behavior 2") 40 | 41 | def build_graph(self) -> HEBGraph: 42 | graph = HEBGraph(self) 43 | feature_3 = FeatureCondition("feature 3", complexity=1) 44 | feature_4 = FeatureCondition("feature 4", complexity=1) 45 | feature_5 = FeatureCondition("feature 5", complexity=1) 46 | graph.add_edge(feature_3, feature_4, index=False) 47 | graph.add_edge(feature_3, feature_5, index=True) 
48 | graph.add_edge(feature_4, Action(0, complexity=1), index=False) 49 | graph.add_edge(feature_4, Behavior1(), index=True) 50 | graph.add_edge(feature_5, Behavior1(), index=False) 51 | graph.add_edge(feature_5, Behavior0(), index=True) 52 | return graph 53 | -------------------------------------------------------------------------------- /src/hebg/layouts/deterministic.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | # pylint: disable=protected-access 4 | 5 | """Deterministic layouts""" 6 | 7 | import networkx as nx 8 | import numpy as np 9 | 10 | from hebg.graph import get_roots 11 | 12 | 13 | def staircase_layout(graph: nx.DiGraph, center=None) -> dict: 14 | """Compute specific default positions for a DiGraph. 15 | 16 | Requires graph to have a 'nodes_by_level' attribute. 17 | 18 | Args: 19 | graph: A networkx DiGraph with 'nodes_by_level' attribute (see compute_levels). 20 | center (Optional): Center of the graph layout. NOTE(review): the center returned by nx's _process_params is discarded, so this does not currently shift positions. 21 | 22 | Returns: 23 | pos: Dict mapping each placed node to an [x, y] list, with y = -level. 
24 | 25 | """ 26 | 27 | def place_successors(pos: dict, pos_by_level: dict, node, level: int) -> int: 28 | if level not in pos_by_level: 29 | pos_by_level[level] = pos_by_level[level - 1] # first visit of this level: start at the previous level's cursor 30 | pos_by_level[level] = max(pos[node][0], pos_by_level[level]) # never place a child left of its parent 31 | succs = list(graph.successors(node)) 32 | if len(succs) == 0: 33 | return 1 # a leaf reserves a single column 34 | succs_order = np.argsort([graph.edges[node, succ]["index"] for succ in succs]) # visit successors by increasing edge "index" 35 | for index, succ_id in enumerate(succs_order): 36 | succ = succs[succ_id] 37 | if succ in pos: 38 | continue # already has a position (e.g. reached through another predecessor) 39 | pos[succ] = [pos_by_level[level], -level] # x = this level's cursor, y = -level 40 | if index == 0: 41 | pos[node][0] = max(pos[node][0], pos[succ][0]) # move the parent right so it sits above its first successor 42 | pos_by_level[level - 1] = max(pos_by_level[level - 1], pos[node][0]) # keep the parent level's cursor past the moved parent 43 | n_succs = place_successors(pos, pos_by_level, succ, level + 1) 44 | pos_by_level[level] += n_succs # reserve the successor's width on this level 45 | return len(succs) 46 | 47 | graph, _ = nx.drawing.layout._process_params(graph, center, dim=2) # NOTE(review): the returned center is discarded 48 | pos = {} 49 | pos_by_level = {0: 0} 50 | for node in get_roots(graph): 51 | pos[node] = [pos_by_level[0], 0] 52 | pos_by_level[0] += place_successors(pos, pos_by_level, node, 1) 53 | return pos 54 | -------------------------------------------------------------------------------- /src/hebg/layouts/metaheuristics.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Metaheuristics used for building layouts""" 5 | 6 | import numpy as np 7 | 8 | 9 | def simulated_annealing( 10 | initial, 11 | energy, 12 | neighbor, 13 | max_iterations: int = 1000, 14 | initial_temperature: float = 5, 15 | max_iters_without_new: float = np.inf, 16 | verbose: int = 1, 17 | ): 18 | """Minimize an energy function by simulated annealing over neighbor moves. 19 | 20 | See https://en.wikipedia.org/wiki/Simulated_annealing for more details. 21 | 22 | Args: 23 | initial: Initial variable position. 
24 | energy: Function giving the energy (aka cost) of a given variable position. 25 | neighbor: Function giving a neighbor of a given variable position. 26 | max_iterations: Maximum number of iterations. 27 | initial_temperature: Initial temperature parameter, more is more random search. 28 | max_iters_without_new: Maximum number of iterations without a new best position. 29 | 30 | """ 31 | 32 | def prob_keep(temperature, delta_e): 33 | return min(1, np.exp(delta_e / temperature)) 34 | 35 | state = initial 36 | energy_pos = energy(state) 37 | iters_without_new = 0 38 | for k in range(max_iterations): 39 | new_state = neighbor(state) 40 | new_energy = energy(new_state) 41 | temperature = initial_temperature / (k + 1) 42 | iters_without_new += 1 43 | prob = prob_keep(temperature, energy_pos - new_energy) 44 | if np.random.random() < prob: 45 | if verbose == 1: 46 | print( 47 | f"{k}\t({prob:.0%})\t{energy_pos:.2f}->{new_energy:.2f}", end="\r" 48 | ) 49 | state, energy_pos = new_state, new_energy 50 | iters_without_new = 0 51 | 52 | if iters_without_new >= max_iters_without_new: 53 | break 54 | if verbose == 1: 55 | print() 56 | return state 57 | -------------------------------------------------------------------------------- /tests/examples/behaviors/basic_empty.py: -------------------------------------------------------------------------------- 1 | from hebg.node import Action, EmptyNode 2 | from hebg.behavior import Behavior 3 | from hebg.heb_graph import HEBGraph 4 | 5 | from tests.examples.feature_conditions import ThresholdFeatureCondition 6 | 7 | 8 | class E_A_Behavior(Behavior): 9 | """Empty behavior""" 10 | 11 | def __init__(self, name: str, action: Action) -> None: 12 | super().__init__(name) 13 | self.action = action 14 | 15 | def build_graph(self) -> HEBGraph: 16 | graph = HEBGraph(self) 17 | graph.add_edge(EmptyNode("empty"), self.action) 18 | return graph 19 | 20 | 21 | class E_F_A_Behavior(Behavior): 22 | """Double layer empty then feature conditions 
behavior""" 23 | 24 | def build_graph(self) -> HEBGraph: 25 | graph = HEBGraph(self) 26 | empty = EmptyNode("empty") 27 | feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0) 28 | 29 | graph.add_edge(empty, feature_condition) 30 | for i, edge_index in zip(range(2), (0, 1)): 31 | action = Action(i) 32 | graph.add_edge(feature_condition, action, index=edge_index) 33 | 34 | return graph 35 | 36 | 37 | class F_E_A_Behavior(Behavior): 38 | """Double layer feature conditions then empty behavior""" 39 | 40 | def build_graph(self) -> HEBGraph: 41 | graph = HEBGraph(self) 42 | 43 | feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0) 44 | empty_0 = EmptyNode("empty_0") 45 | empty_1 = EmptyNode("empty_1") 46 | 47 | graph.add_edge(feature_condition, empty_0, index=int(True)) 48 | graph.add_edge(feature_condition, empty_1, index=int(False)) 49 | 50 | graph.add_edge(empty_0, Action(0)) 51 | graph.add_edge(empty_1, Action(1)) 52 | 53 | return graph 54 | 55 | 56 | class E_E_A_Behavior(Behavior): 57 | """Double layer empty behavior""" 58 | 59 | def build_graph(self) -> HEBGraph: 60 | graph = HEBGraph(self) 61 | 62 | empty_0 = EmptyNode("empty_0") 63 | empty_1 = EmptyNode("empty_1") 64 | 65 | graph.add_edge(empty_0, empty_1) 66 | graph.add_edge(empty_1, Action(0)) 67 | 68 | return graph 69 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Contributing to HEBG 2 | ==================== 3 | 4 | We are happy to receive contributions in the form of **pull requests** via Github. 5 | Feel free to fork the repository, implement your changes and create a merge request to the **main** branch. 6 | 7 | Git Commit Messages 8 | ~~~~~~~~~~~~~~~~~~~ 9 | 10 | Commits should start with a Capital letter and should be written in present tense (e.g. ``:tada: Add cool new feature`` instead of ``:tada: Added cool new feature``). 
11 | You should also start your commit message with **one** applicable emoji. This not only looks great but also makes you rethink what to add to a commit. Make many small commits! 12 | 13 | 14 | .. list-table:: Commit emojis 15 | :header-rows: 1 16 | 17 | * - Emoji 18 | - Description 19 | * - `:tada:` ``:tada:`` 20 | - When you added a cool new feature. 21 | * - `:bug:` ``:bug:`` 22 | - When you fixed a bug. 23 | * - `:fire:` ``:fire:`` 24 | - When you removed something. 25 | * - `:truck:` ``:truck:`` 26 | - When you moved / renamed something. 27 | * - `:wrench:` ``:wrench:`` 28 | - When you refactored / improved a small piece of code. 29 | * - `:hammer:` ``:hammer:`` 30 | - When you refactored / improved large parts of the code. 31 | * - `:sparkles:` ``:sparkles:`` 32 | - When you improved code quality (pylint, PEP, ...). 33 | * - `:art:` ``:art:`` 34 | - When you improved / added design assets. 35 | * - `:rocket:` ``:rocket:`` 36 | - When you improved performance. 37 | * - `:memo:` ``:memo:`` 38 | - When you wrote documentation. 39 | * - `:umbrella:` ``:umbrella:`` 40 | - When you improved coverage. 41 | * - `:twisted_rightwards_arrows:` ``:twisted_rightwards_arrows:`` 42 | - When you merged a branch. 43 | 44 | This section was inspired by `This repository `_. 45 | 46 | Version Numbers 47 | --------------- 48 | 49 | Version numbers will be assigned according to the `Semantic Versioning <https://semver.org/>`_ scheme. 50 | This means, given a version number MAJOR.MINOR.PATCH, we will increment the: 51 | 52 | 1. MAJOR version when we make incompatible API changes, 53 | 2. MINOR version when we add functionality in a backwards compatible manner, and 54 | 3. PATCH version when we make backwards compatible bug fixes.
55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # VsCode 132 | .vscode 133 | -------------------------------------------------------------------------------- /tests/examples/behaviors/loop_without_alternative.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from hebg import HEBGraph, Action, FeatureCondition, Behavior 4 | 5 | 6 | class ReachForest(Behavior): 7 | """Reach forest""" 8 | 9 | def __init__(self) -> None: 10 | """Reach forest""" 11 | super().__init__("Reach forest") 12 | 13 | def build_graph(self) -> HEBGraph: 14 | graph = HEBGraph(self) 15 | is_in_other_zone = FeatureCondition("Is in other zone ?") 16 | graph.add_edge(is_in_other_zone, Behavior("Reach other zone"), index=False) 17 | graph.add_edge(is_in_other_zone, Action("> forest"), index=True) 18 | is_in_other_zone = FeatureCondition("Is in meadow ?") 19 | graph.add_edge(is_in_other_zone, Behavior("Reach meadow"), index=False) 20 | graph.add_edge(is_in_other_zone, Action("> forest"), index=True) 21 | return graph 22 | 23 | 24 | class ReachOtherZone(Behavior): 25 | """Reach other zone""" 26 | 27 | def __init__(self) -> None: 28 | """Reach other zone""" 29 | super().__init__("Reach other zone") 30 | 31 | def build_graph(self) -> HEBGraph: 32 | graph = HEBGraph(self) 33 | is_in_forest = FeatureCondition("Is in forest ?") 34 | 
graph.add_edge(is_in_forest, Behavior("Reach forest"), index=False) 35 | graph.add_edge(is_in_forest, Action("> other zone"), index=True) 36 | is_in_other_zone = FeatureCondition("Is in meadow ?") 37 | graph.add_edge(is_in_other_zone, Behavior("Reach meadow"), index=False) 38 | graph.add_edge(is_in_other_zone, Action("> other zone"), index=True) 39 | return graph 40 | 41 | 42 | class ReachMeadow(Behavior): 43 | """Reach meadow""" 44 | 45 | def __init__(self) -> None: 46 | """Reach meadow""" 47 | super().__init__("Reach meadow") 48 | 49 | def build_graph(self) -> HEBGraph: 50 | graph = HEBGraph(self) 51 | is_in_forest = FeatureCondition("Is in forest ?") 52 | graph.add_edge(is_in_forest, Behavior("Reach forest"), index=False) 53 | graph.add_edge(is_in_forest, Action("> meadow"), index=True) 54 | is_in_other_zone = FeatureCondition("Is in other zone ?") 55 | graph.add_edge(is_in_other_zone, Behavior("Reach other zone"), index=False) 56 | graph.add_edge(is_in_other_zone, Action("> meadow"), index=True) 57 | return graph 58 | 59 | 60 | def build_looping_behaviors_without_direct_alternatives() -> List[Behavior]: 61 | behaviors: List[Behavior] = [ 62 | ReachForest(), 63 | ReachOtherZone(), 64 | ReachMeadow(), 65 | ] 66 | all_behaviors = {behavior.name: behavior for behavior in behaviors} 67 | for behavior in behaviors: 68 | behavior.graph.all_behaviors = all_behaviors 69 | return behaviors 70 | -------------------------------------------------------------------------------- /tests/test_node.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Unit tests for the hebg.node module.""" 5 | 6 | import pytest 7 | import pytest_check as check 8 | 9 | from hebg.node import Node, Action, FeatureCondition, EmptyNode 10 | 11 | 12 | class TestNode: 13 | """Node""" 14 | 15 | @pytest.fixture(autouse=True) 16 | def setup(self): 17 | 
"""Initialize variables.""" 18 | 19 | def test_node_type(self): 20 | """should have correct node_type and raise ValueError otherwise.""" 21 | with pytest.raises(ValueError): 22 | Node("", "") 23 | for node_type in ("action", "feature_condition", "behavior", "empty"): 24 | node = Node("", node_type) 25 | check.equal(node.type, node_type) 26 | 27 | def test_node_name(self): 28 | """should have name as attribute.""" 29 | name = "node_name" 30 | node = Node(name, "empty") 31 | check.equal(node.name, name) 32 | 33 | def test_node_call(self): 34 | """should raise NotImplementedError on call.""" 35 | node = Node("", "empty") 36 | with pytest.raises(NotImplementedError): 37 | node(None) 38 | 39 | 40 | class TestAction: 41 | """Action""" 42 | 43 | @pytest.fixture(autouse=True) 44 | def setup(self): 45 | """Initialize variables.""" 46 | 47 | def test_node_type(self): 48 | """should have 'action' as node_type.""" 49 | node = Action("", "") 50 | check.equal(node.type, "action") 51 | 52 | def test_node_call(self): 53 | """should return Action.action when called.""" 54 | action = "action_action" 55 | node = Action(action, "action_name") 56 | check.equal(node(None), action) 57 | check.equal(node.action, action) 58 | 59 | 60 | class TestFeatureCondition: 61 | """FeatureCondition""" 62 | 63 | @pytest.fixture(autouse=True) 64 | def setup(self): 65 | """Initialize variables.""" 66 | 67 | def test_node_type(self): 68 | """should have 'feature_condition' as node_type.""" 69 | node = FeatureCondition("") 70 | check.equal(node.type, "feature_condition") 71 | 72 | def test_node_call(self): 73 | """should raise NotImplementedError on call.""" 74 | node = FeatureCondition("") 75 | with pytest.raises(NotImplementedError): 76 | node(None) 77 | 78 | 79 | class TestEmptyNode: 80 | """EmptyNode""" 81 | 82 | @pytest.fixture(autouse=True) 83 | def setup(self): 84 | """Initialize variables.""" 85 | 86 | def test_node_type(self): 87 | """should have 'empty' as node_type.""" 88 | node = EmptyNode("") 
89 | check.equal(node.type, "empty") 90 | 91 | def test_node_call(self): 92 | """should return 1 when called.""" 93 | node = EmptyNode("") 94 | check.equal(node(None), 1) 95 | -------------------------------------------------------------------------------- /src/hebg/node.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Module for base Node classes.""" 5 | 6 | import dis 7 | from typing import Any, List 8 | 9 | import numpy as np 10 | 11 | 12 | def bytecode_complexity(obj): 13 | """Compute the number of instructions in the bytecode of a given obj.""" 14 | return len(list(dis.get_instructions(obj))) 15 | 16 | 17 | class Node: 18 | NODE_TYPES = ("action", "feature_condition", "behavior", "empty") 19 | 20 | def __init__( 21 | self, 22 | name: str, 23 | node_type: str, 24 | complexity: int = None, 25 | image=None, 26 | ) -> None: 27 | """Base Node class for any HEBGraph. 28 | 29 | Args: 30 | name (str): A UNIQUE name representing this node. 31 | node_type (str): One of {action, feature_condition, behavior, empty}. 32 | complexity (int, optional): Given individual complexity of the node. 33 | If None, uses the number of bytecode instructions of init and call. 34 | Defaults to None. 35 | image (2d array, optional): Image to represent the node. Defaults to None. 36 | 37 | Raises: 38 | ValueError: If node_type has an unexpected value. 39 | """ 40 | self.name = name 41 | self.image = image 42 | if node_type not in self.NODE_TYPES: 43 | raise ValueError( 44 | f"node_type ({node_type})" 45 | f"not in authorised node_types ({self.NODE_TYPES})." 
46 | ) 47 | self.type = node_type 48 | self.complexity = complexity 49 | 50 | def __call__(self, observation: Any) -> Any: 51 | raise NotImplementedError 52 | 53 | def __str__(self) -> str: 54 | return self.name 55 | 56 | def __eq__(self, o: object) -> bool: 57 | return self.name == str(o) 58 | 59 | def __hash__(self) -> int: 60 | return self.name.__hash__() 61 | 62 | def __repr__(self) -> str: 63 | return self.name 64 | 65 | 66 | class Action(Node): 67 | """Node representing an action in an HEBGraph.""" 68 | 69 | def __init__(self, action: Any, name: str = None, **kwargs) -> None: 70 | self.action = action 71 | super().__init__(self._get_name(name), "action", **kwargs) 72 | 73 | def _get_name(self, name): 74 | """Get the default name of the action if None is given.""" 75 | return f"Action({self.action})" if name is None else name 76 | 77 | def __call__(self, observation: Any) -> Any: 78 | return self.action 79 | 80 | 81 | class StochasticAction(Action): 82 | """Node representing a stochastic choice between actions in an HEBGraph.""" 83 | 84 | def __init__( 85 | self, actions: List[Action], probs: list, name: str, image=None 86 | ) -> None: 87 | super().__init__(actions, name=name, image=image) 88 | self.probs = probs 89 | 90 | def __call__(self, observation): 91 | selected_action = np.random.choice(self.action, p=self.probs) 92 | return selected_action(observation) 93 | 94 | 95 | class FeatureCondition(Node): 96 | """Node representing a feature condition in an HEBGraph.""" 97 | 98 | def __init__(self, name: str = None, **kwargs) -> None: 99 | super().__init__(name, "feature_condition", **kwargs) 100 | 101 | def __call__(self, observation: Any) -> int: 102 | raise NotImplementedError 103 | 104 | 105 | class EmptyNode(Node): 106 | """Node representing an empty node in an HEBGraph.""" 107 | 108 | def __init__(self, name: str) -> None: 109 | super().__init__(name, "empty") 110 | 111 | def __call__(self, observation: Any) -> int: 112 | return int(True) 113 | 
-------------------------------------------------------------------------------- /src/hebg/layouts/metabased.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | # pylint: disable=protected-access 4 | 5 | """Metaheuristics based layouts""" 6 | 7 | from copy import deepcopy 8 | 9 | import networkx as nx 10 | import numpy as np 11 | 12 | from hebg.layouts.metaheuristics import simulated_annealing 13 | 14 | 15 | def leveled_layout_energy( 16 | graph: nx.DiGraph, center=None, metaheuristic=simulated_annealing 17 | ): 18 | """Compute positions for a leveled DiGraph using a metaheuristic to minimize energy. 19 | 20 | Requires each node to have a 'level' attribute. 21 | 22 | Args: 23 | graph: A networkx DiGraph. 24 | center (Optional): Center of the graph layout. 25 | 26 | Returns: 27 | pos: Positions of each node. 28 | nodes_by_level: List of nodes by levels. 
29 | 30 | """ 31 | graph, center = nx.drawing.layout._process_params(graph, center, dim=2) 32 | 33 | nodes_by_level = graph.graph["nodes_by_level"] 34 | pos = {} 35 | step_size = 1 / max(len(nodes_by_level[level]) for level in nodes_by_level) 36 | spacing = np.arange(0, 1, step=step_size) 37 | for level in nodes_by_level: 38 | n_nodes_in_level = len(nodes_by_level[level]) 39 | if n_nodes_in_level > 1: 40 | positions = np.linspace( 41 | 0, len(spacing) - 1, n_nodes_in_level, endpoint=True, dtype=np.int32 42 | ) 43 | positions = spacing[positions] 44 | else: 45 | positions = [spacing[(len(spacing) - 1) // 2]] 46 | 47 | for i, node in enumerate(nodes_by_level[level]): 48 | pos[node] = [level, positions[i]] 49 | 50 | def energy(pos, nodes_strenght=1, edges_strenght=2): 51 | def dist(start, stop): 52 | x_arr, y_arr = np.array(start), np.array(stop) 53 | return np.linalg.norm(x_arr - y_arr) 54 | 55 | energy = 0 56 | for level in nodes_by_level: 57 | for node in nodes_by_level[level]: 58 | energy += nodes_strenght * sum( 59 | np.square(dist(pos[node], pos[n])) 60 | for n in nodes_by_level[level] 61 | if n != node 62 | ) 63 | energy -= sum( 64 | edges_strenght 65 | / abs( 66 | max(1, graph.nodes[node]["level"] - graph.nodes[pred]["level"]) 67 | ) 68 | / max(1e-6, dist(pos[node], pos[pred])) 69 | for pred in graph.predecessors(node) 70 | ) 71 | energy -= sum( 72 | edges_strenght 73 | / abs( 74 | max(1, graph.nodes[node]["level"] - graph.nodes[succ]["level"]) 75 | ) 76 | / max(1e-6, dist(pos[node], pos[succ])) 77 | for succ in graph.successors(node) 78 | ) 79 | 80 | return energy 81 | 82 | def neighbor(pos: dict): 83 | pos_copy = deepcopy(pos) 84 | nodes_list = list(pos.keys()) 85 | choosen_node_id = int(np.random.randint(len(nodes_list))) 86 | choosen_node = nodes_list[choosen_node_id] 87 | choosen_level = graph.nodes(data="level")[choosen_node] 88 | new_pos = [pos_copy[choosen_node][0], np.random.choice(spacing)] 89 | for node in nodes_by_level[choosen_level]: 90 | if node 
!= choosen_node and np.all(np.isclose(new_pos, pos_copy[node])): 91 | pos_copy[choosen_node], pos_copy[node] = ( 92 | pos_copy[node], 93 | pos_copy[choosen_node], 94 | ) 95 | return pos_copy 96 | pos_copy[choosen_node] = new_pos 97 | return pos_copy 98 | 99 | return metaheuristic(pos, energy, neighbor) 100 | -------------------------------------------------------------------------------- /src/hebg/requirements_graph.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | # pylint: disable=arguments-differ 4 | 5 | """Module for building underlying requirement graphs based on a set of behaviors.""" 6 | 7 | from __future__ import annotations 8 | 9 | from copy import deepcopy 10 | from typing import Dict, List 11 | 12 | from networkx import DiGraph, descendants 13 | 14 | from hebg.behavior import Behavior 15 | from hebg.graph import compute_levels 16 | from hebg.heb_graph import HEBGraph 17 | from hebg.node import EmptyNode 18 | 19 | 20 | def build_requirement_graph(behaviors: List[Behavior]) -> DiGraph: 21 | """Builds a DiGraph of the requirements induced by a list of behaviors. 22 | 23 | Args: 24 | behaviors: List of Behaviors to build the requirement graph from. 25 | 26 | Returns: 27 | The requirement graph induced by the given list of behaviors. 
28 | 29 | """ 30 | 31 | try: 32 | heb_graphs = [behavior.graph for behavior in behaviors] 33 | except NotImplementedError as error: 34 | user_msg = "All behaviors given must be able to build an HEBGraph" 35 | raise NotImplementedError(user_msg) from error 36 | 37 | requirements_graph = DiGraph() 38 | for behavior in behaviors: 39 | requirements_graph.add_node(behavior) 40 | 41 | requirement_degree = {} 42 | 43 | for graph in heb_graphs: 44 | requirement_degree[graph.behavior] = {} 45 | for node in graph.nodes(): 46 | if not isinstance(node, EmptyNode): 47 | continue 48 | requirement_degree = _cut_alternatives_to_empty_node( 49 | graph, node, requirement_degree 50 | ) 51 | 52 | for graph in heb_graphs: 53 | for node in graph.nodes(): 54 | if not isinstance(node, Behavior): 55 | continue 56 | if node not in requirement_degree[graph.behavior]: 57 | requirement_degree[graph.behavior][node] = 0 58 | requirement_degree[graph.behavior][node] += 1 59 | 60 | index = 0 61 | for graph in heb_graphs: 62 | for node in graph.nodes(): 63 | if ( 64 | not isinstance(node, Behavior) 65 | or requirement_degree[graph.behavior][node] == 0 66 | ): 67 | continue 68 | if node not in requirements_graph.nodes(): 69 | requirements_graph.add_node(node) 70 | index = len(list(requirements_graph.successors(node))) + 1 71 | requirements_graph.add_edge(node, graph.behavior, index=index) 72 | 73 | compute_levels(requirements_graph) 74 | return requirements_graph 75 | 76 | 77 | def _cut_alternatives_to_empty_node( 78 | graph: HEBGraph, 79 | node: EmptyNode, 80 | requirement_degree: Dict[Behavior, Dict[Behavior, int]], 81 | ) -> Dict[Behavior, Dict[Behavior, int]]: 82 | successor = list(graph.successors(node))[0] 83 | empty_index = graph.edges[node, successor]["index"] 84 | alternatives = graph.predecessors(successor) 85 | alternatives = [ 86 | alt_node 87 | for alt_node in alternatives 88 | if graph.edges[alt_node, successor]["index"] == empty_index 89 | ] 90 | cut_graph = deepcopy(graph) 91 | for 
alternative in alternatives: 92 | cut_graph.remove_edge(alternative, successor) 93 | for alternative in alternatives: 94 | following_behaviors = [ 95 | following_node 96 | for following_node in descendants(cut_graph, alternative) 97 | if isinstance(following_node, Behavior) 98 | ] 99 | for following_behavior in following_behaviors: 100 | if following_behavior not in requirement_degree[graph.behavior]: 101 | requirement_degree[graph.behavior][following_behavior] = 0 102 | requirement_degree[graph.behavior][following_behavior] -= 1 103 | 104 | return requirement_degree 105 | -------------------------------------------------------------------------------- /tests/examples/behaviors/basic.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from hebg.node import Action, FeatureCondition 4 | from hebg.behavior import Behavior 5 | from hebg.heb_graph import HEBGraph 6 | 7 | from tests.examples.feature_conditions import ThresholdFeatureCondition 8 | 9 | 10 | class FundamentalBehavior(Behavior): 11 | """Fundamental behavior based on an Action.""" 12 | 13 | def __init__(self, action: Action) -> None: 14 | self.action = action 15 | name = action.name + "_behavior" 16 | super().__init__(name, image=action.image) 17 | 18 | def build_graph(self) -> HEBGraph: 19 | graph = HEBGraph(self) 20 | graph.add_node(self.action) 21 | return graph 22 | 23 | 24 | class AA_Behavior(Behavior): 25 | """Double root fundamental behavior""" 26 | 27 | def __init__(self, name: str, any_mode: str) -> None: 28 | super().__init__(name, image=None) 29 | self.any_mode = any_mode 30 | 31 | def build_graph(self) -> HEBGraph: 32 | graph = HEBGraph(self, any_mode=self.any_mode) 33 | 34 | graph.add_node(Action(0)) 35 | graph.add_node(Action(1)) 36 | 37 | return graph 38 | 39 | 40 | class F_A_Behavior(Behavior): 41 | """Single feature condition behavior""" 42 | 43 | def __init__( 44 | self, 45 | name: str, 46 | feature_condition: 
FeatureCondition, 47 | actions: Dict[int, Action], 48 | ) -> None: 49 | """Single feature condition behavior 50 | 51 | Args: 52 | name (str): Name of the behavior. 53 | feature_condition (FeatureCondition): Feature_condition used in behavior. 54 | actions (Dict[int, Action]): Mapping from feature_condition output to actions. 55 | """ 56 | super().__init__(name) 57 | self.actions = actions 58 | self.feature_condition = feature_condition 59 | 60 | def build_graph(self) -> HEBGraph: 61 | graph = HEBGraph(self) 62 | for fc_output, action in self.actions.items(): 63 | graph.add_edge(self.feature_condition, action, index=fc_output) 64 | return graph 65 | 66 | 67 | class F_F_A_Behavior(Behavior): 68 | """Double layer feature conditions behavior""" 69 | 70 | def __init__(self, name: str = "F_F_A", *args, **kwargs) -> None: 71 | super().__init__(name=name, *args, **kwargs) 72 | 73 | def build_graph(self) -> HEBGraph: 74 | graph = HEBGraph(self) 75 | 76 | feature_condition_1 = ThresholdFeatureCondition(relation=">=", threshold=0) 77 | feature_condition_2 = ThresholdFeatureCondition(relation="<=", threshold=1) 78 | feature_condition_3 = ThresholdFeatureCondition(relation=">=", threshold=-1) 79 | 80 | graph.add_edge(feature_condition_1, feature_condition_2, index=True) 81 | graph.add_edge(feature_condition_1, feature_condition_3, index=False) 82 | 83 | for action, edge_index in zip(range(2, 4), (1, 0)): 84 | graph.add_edge(feature_condition_2, Action(action), index=edge_index) 85 | 86 | for action, edge_index in zip(range(2), (0, 1)): 87 | graph.add_edge(feature_condition_3, Action(action), index=edge_index) 88 | 89 | return graph 90 | 91 | 92 | class F_AA_Behavior(Behavior): 93 | """Feature condition with mutliple actions on same index.""" 94 | 95 | def __init__(self, name: str = "F_AA") -> None: 96 | super().__init__(name, image=None) 97 | 98 | def build_graph(self) -> HEBGraph: 99 | graph = HEBGraph(self) 100 | feature_condition = ThresholdFeatureCondition(relation=">=", 
threshold=0) 101 | 102 | graph.add_edge(feature_condition, Action(0), index=int(True)) 103 | graph.add_edge(feature_condition, Action(1), index=int(False)) 104 | graph.add_edge(feature_condition, Action(2), index=int(False)) 105 | 106 | return graph 107 | 108 | 109 | class AF_A_Behavior(Behavior): 110 | """Double root with feature condition and action""" 111 | 112 | def __init__(self, name: str, any_mode: str) -> None: 113 | super().__init__(name, image=None) 114 | self.any_mode = any_mode 115 | 116 | def build_graph(self) -> HEBGraph: 117 | graph = HEBGraph(self, any_mode=self.any_mode) 118 | 119 | graph.add_node(Action(0)) 120 | feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0) 121 | 122 | graph.add_edge(feature_condition, Action(1), index=int(True)) 123 | graph.add_edge(feature_condition, Action(2), index=int(False)) 124 | 125 | return graph 126 | -------------------------------------------------------------------------------- /src/hebg/metrics/complexity/complexities.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """General complexity.""" 5 | 6 | from typing import TYPE_CHECKING, Dict, Tuple 7 | 8 | from hebg.behavior import Behavior 9 | from hebg.metrics.complexity.utils import update_sum_dict 10 | from hebg.node import Action 11 | 12 | if TYPE_CHECKING: 13 | from hebg.node import Node 14 | 15 | 16 | def general_complexity( 17 | behavior: Behavior, 18 | used_nodes_all: Dict["Node", Dict["Node", int]], 19 | scomplexity, 20 | kcomplexity, 21 | previous_used_nodes: Dict["Node", int] = None, 22 | ) -> Tuple[float]: 23 | """Compute the general complexity of a Behavior with used nodes. 
24 | 25 | Using the number of time each node is used in its HEBGraph and based on the increase of 26 | complexity given by 'kcomplexity', and the saved complexity using behaviors given by 27 | 'saved_complexity', we sum the general complexity of an behavior and the total saved complexity. 28 | 29 | Args: 30 | behavior: The Behavior for which we compute the general complexity. 31 | used_nodes_all: Dictionary of dictionary of the number of times each nodes was used in the 32 | past, and thus for each node. Not being in a dictionary is counted as not being used. 33 | scomplexity: Callable taking as input (node, n_used, n_previous_used), where node is 34 | the node concerned, n_used is the number of time this node is used and n_previous_used 35 | is the number of time this node was used before that, returning the saved complexity 36 | thanks to using a behavior. 37 | kcomplexity: Callable taking as input (node, n_used), where node is the node concerned and 38 | n_used is the number of time this node is used, returning the accumulated complexity. 39 | previous_used_nodes: Dictionary of the number of times each nodes was used in the past, 40 | not being in the dictionary is counted as 0. 41 | 42 | Returns: 43 | Tuple composed of the general complexity and the total saved complexity. 
44 | 45 | """ 46 | 47 | previous_used_nodes = previous_used_nodes if previous_used_nodes else {} 48 | 49 | total_complexity = 0 50 | saved_complexity = 0 51 | 52 | for node, n_used in used_nodes_all[behavior].items(): 53 | n_previous_used = ( 54 | previous_used_nodes[node] if node in previous_used_nodes else 0 55 | ) 56 | 57 | if isinstance(node, Behavior) and node in used_nodes_all: 58 | node_complexity, saved_node_complexity = general_complexity( 59 | node, 60 | used_nodes_all, 61 | scomplexity=scomplexity, 62 | kcomplexity=kcomplexity, 63 | previous_used_nodes=previous_used_nodes.copy(), 64 | ) 65 | previous_used_nodes = update_sum_dict( 66 | previous_used_nodes, used_nodes_all[node] 67 | ) 68 | total_complexity += saved_node_complexity * kcomplexity(node, n_used) 69 | saved_complexity += saved_node_complexity * kcomplexity(node, n_used) 70 | else: 71 | node_complexity = node.complexity 72 | 73 | total_complexity += node_complexity * kcomplexity(node, n_used) 74 | 75 | if isinstance(node, (Behavior, Action)): 76 | saved_complexity += node_complexity * scomplexity( 77 | node, n_used, n_previous_used 78 | ) 79 | 80 | previous_used_nodes = update_sum_dict(previous_used_nodes, {node: n_used}) 81 | 82 | return total_complexity - saved_complexity, saved_complexity 83 | 84 | 85 | def learning_complexity( 86 | behavior: Behavior, 87 | used_nodes_all: Dict["Node", Dict["Node", int]], 88 | previous_used_nodes=None, 89 | ): 90 | """Compute the learning complexity of a Behavior with used nodes. 91 | 92 | Using the number of time each node is used in its HEBGraph we compute the learning 93 | complexity of a behavior and the total saved complexity. 94 | 95 | Args: 96 | behavior: The Behavior for which we compute the learning complexity. 97 | used_nodes_all: Dictionary of dictionary of the number of times each nodes was used in the 98 | past, and thus for each node. Not being in a dictionary is counted as not being used. 
99 | previous_used_nodes: Dictionary of the number of times each nodes was used in the past, 100 | not being in the dictionary is counted as not being used. 101 | 102 | Returns: 103 | Tuple composed of the learning complexity and the total saved complexity. 104 | 105 | """ 106 | return general_complexity( 107 | behavior=behavior, 108 | used_nodes_all=used_nodes_all, 109 | previous_used_nodes=previous_used_nodes, 110 | scomplexity=lambda node, k, p: max(0, min(k, p + k - 1)), 111 | kcomplexity=lambda node, k: k, 112 | ) 113 | -------------------------------------------------------------------------------- /tests/test_loop.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytest_check as check 3 | 4 | import networkx as nx 5 | from hebg.unrolling import unroll_graph 6 | from tests import plot_graph 7 | 8 | from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors 9 | from tests.examples.behaviors.loop_without_alternative import ( 10 | build_looping_behaviors_without_direct_alternatives, 11 | ) 12 | 13 | 14 | class TestLoopAlternative: 15 | """Tests for the loop with alternative example""" 16 | 17 | @pytest.fixture(autouse=True) 18 | def setup_method(self): 19 | self.gather_wood, self.get_new_axe = build_looping_behaviors() 20 | 21 | def test_unroll_gather_wood(self): 22 | draw = False 23 | unrolled_graph = unroll_graph(self.gather_wood.graph, add_prefix=True) 24 | if draw: 25 | plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) 26 | 27 | expected_graph = nx.DiGraph() 28 | expected_graph.add_edge("Has axe", "Punch tree") 29 | expected_graph.add_edge("Has axe", "Cut tree with axe") 30 | expected_graph.add_edge("Has axe", "Has wood") 31 | 32 | # Expected sub-behavior 33 | expected_graph.add_edge("Has wood", "Gather wood") 34 | expected_graph.add_edge("Has wood", "Craft axe") 35 | expected_graph.add_edge("Has wood", "Summon axe out of thin air") 36 | 
check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) 37 | 38 | def test_unroll_get_new_axe(self): 39 | draw = False 40 | unrolled_graph = unroll_graph(self.get_new_axe.graph, add_prefix=True) 41 | if draw: 42 | plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) 43 | 44 | expected_graph = nx.DiGraph() 45 | expected_graph.add_edge("Has wood", "Has axe") 46 | expected_graph.add_edge("Has wood", "Craft new axe") 47 | expected_graph.add_edge("Has wood", "Summon axe out of thin air") 48 | 49 | # Expected sub-behavior 50 | expected_graph.add_edge("Has axe", "Punch tree") 51 | expected_graph.add_edge("Has axe", "Cut tree with axe") 52 | expected_graph.add_edge("Has axe", "Get new axe") 53 | check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) 54 | 55 | def test_unroll_gather_wood_cutting_alternatives(self): 56 | draw = False 57 | unrolled_graph = unroll_graph( 58 | self.gather_wood.graph, add_prefix=True, cut_looping_alternatives=True 59 | ) 60 | if draw: 61 | plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) 62 | 63 | expected_graph = nx.DiGraph() 64 | expected_graph.add_edge("Has axe", "Punch tree") 65 | expected_graph.add_edge("Has axe", "Has wood") 66 | expected_graph.add_edge("Has axe", "Use axe") 67 | 68 | # Expected sub-behavior 69 | expected_graph.add_edge("Has wood", "Summon axe of out thin air") 70 | expected_graph.add_edge("Has wood", "Craft axe") 71 | 72 | check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) 73 | 74 | def test_unroll_get_new_axe_cutting_alternatives(self): 75 | draw = False 76 | unrolled_graph = unroll_graph( 77 | self.get_new_axe.graph, add_prefix=True, cut_looping_alternatives=True 78 | ) 79 | if draw: 80 | plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) 81 | 82 | expected_graph = nx.DiGraph( 83 | [ 84 | ("Has wood", "Has axe"), 85 | ("Has wood", "Craft new axe"), 86 | ("Has wood", "Summon axe out of thin air"), 87 | # Expected sub-behavior 88 | ("Has axe", "Punch 
tree"), 89 | ("Has axe", "Cut tree with axe"), 90 | ] 91 | ) 92 | check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) 93 | 94 | @pytest.mark.xfail 95 | def test_unroll_root_alternative_reach_forest(self): 96 | ( 97 | reach_forest, 98 | _reach_other_zone, 99 | _reach_meadow, 100 | ) = build_looping_behaviors_without_direct_alternatives() 101 | draw = False 102 | unrolled_graph = unroll_graph( 103 | reach_forest.graph, 104 | add_prefix=True, 105 | cut_looping_alternatives=True, 106 | ) 107 | if draw: 108 | plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) 109 | 110 | expected_graph = nx.DiGraph( 111 | [ 112 | # ("Root", "Is in other zone ?"), 113 | # ("Root", "Is in meadow ?"), 114 | ("Is in other zone ?", "Reach other zone"), 115 | ("Is in other zone ?", "Go to forest"), 116 | ("Is in meadow ?", "Go to forest"), 117 | ("Is in meadow ?", "Reach meadow>Is in other zones ?"), 118 | ("Reach meadow>Is in other zone ?", "Reach meadow>Reach other zone"), 119 | ("Reach meadow>Is in other zone ?", "Reach meadow>Go to forest"), 120 | ] 121 | ) 122 | check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) 123 | -------------------------------------------------------------------------------- /tests/test_behavior.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Behavior of HEBGraphs when called.""" 5 | 6 | import pytest 7 | import pytest_check as check 8 | from pytest_mock import MockerFixture 9 | 10 | from hebg.behavior import Behavior 11 | from hebg.heb_graph import HEBGraph 12 | from hebg.node import Action 13 | 14 | from tests.examples.behaviors import FundamentalBehavior, F_A_Behavior, F_F_A_Behavior 15 | from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors 16 | from tests.examples.feature_conditions import ThresholdFeatureCondition 17 | 18 | 19 | class 
TestBehavior: 20 | 21 | """Behavior""" 22 | 23 | @pytest.fixture(autouse=True) 24 | def setup(self): 25 | """Initialize variables.""" 26 | self.node = Behavior("behavior_name") 27 | 28 | def test_node_type(self): 29 | """should have 'behavior' as node_type.""" 30 | check.equal(self.node.type, "behavior") 31 | 32 | def test_node_call(self, mocker: MockerFixture): 33 | """should use graph on call.""" 34 | mocker.patch("hebg.behavior.Behavior.graph") 35 | self.node(None) 36 | check.is_true(self.node.graph.called) 37 | 38 | def test_build_graph(self): 39 | """should raise NotImplementedError when build_graph is called.""" 40 | with pytest.raises(NotImplementedError): 41 | self.node.build_graph() 42 | 43 | def test_graph(self, mocker: MockerFixture): 44 | """should build graph and compute its levels if, and only if, 45 | the graph is not yet built. 46 | """ 47 | mocker.patch("hebg.behavior.Behavior.build_graph") 48 | mocker.patch("hebg.behavior.compute_levels") 49 | self.node.graph 50 | check.is_true(self.node.build_graph.called) 51 | check.is_true(self.node.build_graph.called) 52 | 53 | mocker.patch("hebg.behavior.Behavior.build_graph") 54 | mocker.patch("hebg.behavior.compute_levels") 55 | self.node.graph 56 | check.is_false(self.node.build_graph.called) 57 | check.is_false(self.node.build_graph.called) 58 | 59 | 60 | class TestPathfinding: 61 | def test_fundamental_behavior(self): 62 | """Fundamental behavior (single action) should return its action.""" 63 | action_id = 42 64 | behavior = FundamentalBehavior(Action(action_id)) 65 | check.equal(behavior(None), action_id) 66 | 67 | def test_feature_condition_single(self): 68 | """Feature condition should orient path properly.""" 69 | feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0) 70 | actions = {0: Action(0), 1: Action(1)} 71 | behavior = F_A_Behavior("F_A", feature_condition, actions) 72 | check.equal(behavior(1), 1) 73 | check.equal(behavior(-1), 0) 74 | 75 | def 
test_feature_conditions_chained(self): 76 | """Feature condition should orient path properly in double chain.""" 77 | behavior = F_F_A_Behavior("F_F_A") 78 | check.equal(behavior(-2), 0) 79 | check.equal(behavior(-1), 1) 80 | check.equal(behavior(1), 2) 81 | check.equal(behavior(2), 3) 82 | 83 | def test_looping_resolve(self): 84 | """Loops with alternatives should be ignored.""" 85 | _gather_wood, get_axe = build_looping_behaviors() 86 | check.equal(get_axe({}), "Punch tree") 87 | 88 | 89 | class TestCostBehavior: 90 | def test_choose_root_of_lesser_cost(self): 91 | """Should choose root of lesser cost.""" 92 | 93 | expected_action = "EXPECTED" 94 | 95 | class AAA_Behavior(Behavior): 96 | def __init__(self) -> None: 97 | super().__init__("AAA") 98 | 99 | def build_graph(self) -> HEBGraph: 100 | graph = HEBGraph(self) 101 | graph.add_node(Action(0, complexity=2)) 102 | graph.add_node(Action(expected_action, complexity=1)) 103 | graph.add_node(Action(2, complexity=3)) 104 | return graph 105 | 106 | behavior = AAA_Behavior() 107 | check.equal(behavior(None), expected_action) 108 | 109 | def test_not_path_of_least_cost(self): 110 | """Should choose path of larger complexity if individual costs lead to it.""" 111 | 112 | class AF_A_Behavior(Behavior): 113 | 114 | """Double root with feature condition and action""" 115 | 116 | def __init__(self) -> None: 117 | super().__init__("AF_A") 118 | 119 | def build_graph(self) -> HEBGraph: 120 | graph = HEBGraph(self) 121 | 122 | graph.add_node(Action(0, complexity=1.5)) 123 | feature_condition = ThresholdFeatureCondition( 124 | relation=">=", threshold=0, complexity=1.0 125 | ) 126 | 127 | graph.add_edge( 128 | feature_condition, Action(1, complexity=1.0), index=int(True) 129 | ) 130 | graph.add_edge( 131 | feature_condition, Action(2, complexity=1.0), index=int(False) 132 | ) 133 | 134 | return graph 135 | 136 | behavior = AF_A_Behavior() 137 | check.equal(behavior(1), 1) 138 | check.equal(behavior(-1), 2) 139 | 
-------------------------------------------------------------------------------- /src/hebg/heb_graph.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | # pylint: disable=arguments-differ 4 | 5 | """Module containing the HEBGraph base class.""" 6 | 7 | from __future__ import annotations 8 | 9 | from typing import Any, Dict, List, Optional, Tuple 10 | 11 | from matplotlib.axes import Axes 12 | from networkx import DiGraph 13 | 14 | from hebg.behavior import Behavior 15 | from hebg.call_graph import CallGraph 16 | from hebg.codegen import get_hebg_source 17 | from hebg.draw import draw_hebgraph 18 | from hebg.graph import get_roots 19 | from hebg.node import Node 20 | from hebg.unrolling import unroll_graph 21 | 22 | 23 | class HEBGraph(DiGraph): 24 | """Base class for Hierchical Explanation of Behavior as Graphs. 25 | 26 | An HEBGraph is a DiGraph, and as such stores nodes and directed edges with 27 | optional data, or attributes. 28 | 29 | But nodes of an HEBGraph are not arbitrary. 30 | Leaf nodes can either be an Action or a Behavior. 31 | Other nodes can either be a FeatureCondition or an EmptyNode. 32 | 33 | An HEBGraph determines a behavior, it can be called with an observation 34 | to return the action given by this behavior. 35 | 36 | An HEBGraph edges are directed and indexed, 37 | this indexing for path making when calling the graph. 38 | 39 | As in a DiGraph loops are allowed but multiple (parallel) edges are not. 40 | 41 | Args: 42 | behavior: The Behavior object from which this graph is built. 43 | all_behaviors: A dictionary of behavior, this can be used to avoid cirular definitions using 44 | the behavior names as anchor instead of the behavior object itself. 45 | incoming_graph_data: Additional data to include in the graph. 
46 | 47 | """ 48 | 49 | NODES_COLORS = {"feature_condition": "blue", "action": "red", "behavior": "orange"} 50 | EDGES_COLORS = { 51 | 0: "red", 52 | 1: "green", 53 | 2: "blue", 54 | 3: "yellow", 55 | 4: "purple", 56 | 5: "cyan", 57 | 6: "gray", 58 | } 59 | 60 | def __init__( 61 | self, 62 | behavior: Behavior, 63 | all_behaviors: Dict[str, Behavior] = None, 64 | incoming_graph_data=None, 65 | **attr, 66 | ): 67 | self.behavior = behavior 68 | self.all_behaviors = all_behaviors if all_behaviors is not None else {} 69 | 70 | self._unrolled_graph = None 71 | self.call_graph: Optional[CallGraph] = None 72 | 73 | super().__init__(incoming_graph_data=incoming_graph_data, **attr) 74 | 75 | def add_node(self, node_for_adding: Node, **attr): 76 | node = node_for_adding 77 | color = attr.pop("color", None) 78 | attr.pop("type", None) 79 | attr.pop("image", None) 80 | if color is None: 81 | try: 82 | color = self.NODES_COLORS[node.type] 83 | except KeyError: 84 | color = None 85 | super().add_node(node, type=node.type, color=color, image=node.image, **attr) 86 | 87 | def add_edge(self, u_of_edge: Node, v_of_edge: Node, index: int = 1, **attr): 88 | for node in (u_of_edge, v_of_edge): 89 | if node not in self.nodes(): 90 | self.add_node(node) 91 | 92 | color = attr.pop("color", None) 93 | if color is None: 94 | try: 95 | color = self.EDGES_COLORS[index] 96 | except KeyError: 97 | color = "black" 98 | super().add_edge(u_of_edge, v_of_edge, index=index, color=color, **attr) 99 | 100 | @property 101 | def unrolled_graph(self) -> HEBGraph: 102 | """Access to the unrolled behavior graph. 103 | 104 | The unrolled behavior graph as the same behavior but every behavior node is recursively replaced 105 | by it's behavior graph if it can be computed. 106 | 107 | Only build's the graph the first time called for efficiency. 108 | 109 | Returns: 110 | This HEBGraph's unrolled HEBGraph. 
111 | 112 | """ 113 | if self._unrolled_graph is None: 114 | self._unrolled_graph = unroll_graph(self) 115 | return self._unrolled_graph 116 | 117 | def __call__( 118 | self, 119 | observation, 120 | call_graph: Optional[CallGraph] = None, 121 | ) -> Any: 122 | if call_graph is None: 123 | call_graph = CallGraph() 124 | call_graph.add_root(heb_node=self.behavior, heb_graph=self) 125 | self.call_graph = call_graph 126 | return self.call_graph.call_nodes(self.roots, observation, heb_graph=self) 127 | 128 | @property 129 | def roots(self) -> List[Node]: 130 | """Roots of the behavior graph (nodes without predecessors).""" 131 | return get_roots(self) 132 | 133 | def generate_source_code(self) -> str: 134 | """Generated source code of the behavior from graph.""" 135 | return get_hebg_source(self) 136 | 137 | def draw( 138 | self, ax: "Axes", **kwargs 139 | ) -> Tuple["Axes", Dict[Node, Tuple[float, float]]]: 140 | """Draw the HEBGraph on the given Axis. 141 | 142 | Args: 143 | ax: The matplotlib ax to draw on. 144 | 145 | Kwargs: 146 | fontcolor: Font color to use for all texts. 147 | 148 | Returns: 149 | The resulting matplotlib Axis drawn on and a dictionary of each node position. 
150 | 151 | """ 152 | return draw_hebgraph(self, ax, **kwargs) 153 | -------------------------------------------------------------------------------- /src/hebg/draw.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | import math 5 | from typing import TYPE_CHECKING, Dict, Optional, Tuple 6 | 7 | import matplotlib.patches as mpatches 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from matplotlib.axes import Axes 11 | from matplotlib.legend import Legend 12 | from matplotlib.legend_handler import HandlerPatch 13 | from networkx import draw_networkx_edges, spring_layout 14 | from scipy.spatial import ConvexHull # pylint: disable=no-name-in-module 15 | 16 | from hebg.graph import draw_networkx_nodes_images 17 | from hebg.layouts import staircase_layout 18 | from hebg.unrolling import group_behaviors_points 19 | 20 | if TYPE_CHECKING: 21 | from hebg.heb_graph import HEBGraph 22 | from hebg.node import Node 23 | 24 | 25 | def draw_hebgraph( 26 | graph: "HEBGraph", 27 | ax: Axes, 28 | pos: Optional[Dict["Node", Tuple[float, float]]] = None, 29 | fontcolor: str = "black", 30 | draw_hulls: bool = False, 31 | show_all_hulls: bool = False, 32 | ) -> Tuple["Axes", Dict["Node", Tuple[float, float]]]: 33 | if len(list(graph.nodes())) == 0: 34 | return 35 | 36 | ax.set_title(graph.behavior.name, fontdict={"color": "orange"}) 37 | plt.setp(ax.spines.values(), color="orange") 38 | 39 | if pos is None: 40 | if len(graph.roots) > 0: 41 | pos = staircase_layout(graph) 42 | else: 43 | pos = spring_layout(graph) 44 | draw_networkx_nodes_images(graph, pos, ax=ax, img_zoom=0.5) 45 | 46 | draw_networkx_edges( 47 | graph, 48 | pos, 49 | ax=ax, 50 | arrowsize=20, 51 | arrowstyle="-|>", 52 | min_source_margin=0, 53 | min_target_margin=10, 54 | node_shape="s", 55 | node_size=1500, 56 | edge_color=[color for _, _, color in 
graph.edges(data="color")], 57 | ) 58 | 59 | legend = draw_graph_legend(graph, ax) 60 | plt.setp(legend.get_texts(), color=fontcolor) 61 | 62 | if draw_hulls: 63 | group_and_draw_hulls(graph, pos, ax, show_all_hulls=show_all_hulls) 64 | 65 | 66 | def draw_graph_legend(graph: "HEBGraph", ax: Axes) -> Legend: 67 | used_node_types = [node_type for _, node_type in graph.nodes(data="type")] 68 | legend_patches = [ 69 | mpatches.Patch(facecolor="none", edgecolor=color, label=node_type.capitalize()) 70 | for node_type, color in graph.NODES_COLORS.items() 71 | if node_type in used_node_types and node_type in graph.NODES_COLORS 72 | ] 73 | used_edge_indexes = [index for _, _, index in graph.edges(data="index")] 74 | legend_arrows = [ 75 | mpatches.FancyArrow( 76 | *(0, 0, 1, 0), 77 | facecolor=color, 78 | edgecolor="none", 79 | label=str(index) if index > 1 else f"{str(bool(index))} ({index})", 80 | ) 81 | for index, color in graph.EDGES_COLORS.items() 82 | if index in used_edge_indexes and index in graph.EDGES_COLORS 83 | ] 84 | 85 | # Draw the legend 86 | legend = ax.legend( 87 | fancybox=True, 88 | framealpha=0, 89 | fontsize="x-large", 90 | loc="upper right", 91 | handles=legend_patches + legend_arrows, 92 | handler_map={ 93 | # Patch arrows with fancy arrows in legend 94 | mpatches.FancyArrow: HandlerPatch( 95 | patch_func=lambda width, height, **kwargs: mpatches.FancyArrow( 96 | *(0, 0.5 * height, width, 0), 97 | width=0.2 * height, 98 | length_includes_head=True, 99 | head_width=height, 100 | overhang=0.5, 101 | ) 102 | ), 103 | }, 104 | ) 105 | 106 | return legend 107 | 108 | 109 | def group_and_draw_hulls(graph: "HEBGraph", pos, ax: Axes, show_all_hulls: bool): 110 | grouped_points = group_behaviors_points(pos, graph) 111 | if not show_all_hulls: 112 | key_count = {key[-1]: 0 for key in grouped_points} 113 | for key in grouped_points: 114 | key_count[key[-1]] += 1 115 | grouped_points = { 116 | key: points 117 | for key, points in grouped_points.items() 118 | if 
key_count[key[-1]] > 1 and (len(key) == 1 or key[-1] != key[-2]) 119 | } 120 | 121 | for group_key, points in grouped_points.items(): 122 | stretch = 0.5 - 0.05 * (len(group_key) - 1) 123 | if len(points) >= 3: 124 | draw_convex_hull(points, ax, stretch=stretch, lw=3, color="orange") 125 | 126 | 127 | def draw_convex_hull(points, ax: "Axes", stretch=0.3, n_points=30, **kwargs): 128 | points = np.array(points) 129 | convh = ConvexHull(points) # Get the first convexHull (speeds up the next process) 130 | points = buffer_points(points[convh.vertices], stretch=stretch, samples=n_points) 131 | 132 | hull = ConvexHull(points) 133 | hull_cycle = np.concatenate((hull.vertices, hull.vertices[:1])) 134 | ax.plot(points[hull_cycle, 0], points[hull_cycle, 1], **kwargs) 135 | 136 | 137 | def buffer_points(inside_points, stretch, samples): 138 | new_points = [] 139 | for point in inside_points: 140 | new_points += points_in_circum(point, stretch, samples) 141 | new_points = np.array(new_points) 142 | hull = ConvexHull(new_points) 143 | return new_points[hull.vertices] 144 | 145 | 146 | def points_in_circum(points, radius, samples=100): 147 | return [ 148 | ( 149 | points[0] + math.cos(2 * math.pi / samples * x) * radius, 150 | points[1] + math.sin(2 * math.pi / samples * x) * radius, 151 | ) 152 | for x in range(0, samples + 1) 153 | ] 154 | -------------------------------------------------------------------------------- /tests/test_pet_a_cat.py: -------------------------------------------------------------------------------- 1 | """This examples shows how could we hierarchicaly build a behavior to pet a cat. 2 | 3 | Here is the hierarchical structure that we would want: 4 | 5 | ``` 6 | PetACat: 7 | IsThereACatAround ? 8 | -> Yes: 9 | PetNearbyCat 10 | -> No: 11 | LookForACat 12 | 13 | PetNearbyCat: 14 | IsYourHandNearTheCat ? 
15 | -> Yes: 16 | Pet 17 | -> No: 18 | MoveYourHandNearTheCat 19 | ``` 20 | 21 | """ 22 | 23 | import pytest 24 | import pytest_check as check 25 | 26 | import matplotlib.pyplot as plt 27 | 28 | from hebg import HEBGraph, Action, FeatureCondition, Behavior 29 | from hebg.unrolling import unroll_graph 30 | from tests.test_code_generation import _unidiff_output 31 | 32 | 33 | class Pet(Action): 34 | def __init__(self) -> None: 35 | super().__init__(action="Pet") 36 | 37 | 38 | class IsYourHandNearTheCat(FeatureCondition): 39 | def __init__(self, hand) -> None: 40 | super().__init__(name="Is hand near the cat ?") 41 | self.hand = hand 42 | 43 | def __call__(self, observation): 44 | # Could be a very complex function that returns 1 is the hand is near the cat else 0. 45 | if observation["cat"] == observation[self.hand]: 46 | return int(True) # 1 47 | return int(False) # 0 48 | 49 | 50 | class MoveYourHandNearTheCat(Behavior): 51 | def __init__(self) -> None: 52 | super().__init__(name="Move slowly your hand near the cat") 53 | 54 | def __call__(self, observation, *args, **kwargs) -> Action: 55 | # Could be a very complex function that returns actions from any given observation 56 | return Action("Move hand to cat") 57 | 58 | 59 | class PetNearbyCat(Behavior): 60 | def __init__(self) -> None: 61 | super().__init__(name="Pet nearby cat") 62 | 63 | def build_graph(self) -> HEBGraph: 64 | graph = HEBGraph(self) 65 | is_hand_near_cat = IsYourHandNearTheCat(hand="hand") 66 | graph.add_edge(is_hand_near_cat, MoveYourHandNearTheCat(), index=int(False)) 67 | graph.add_edge(is_hand_near_cat, Pet(), index=int(True)) 68 | 69 | return graph 70 | 71 | 72 | class IsThereACatAround(FeatureCondition): 73 | def __init__(self) -> None: 74 | super().__init__(name="Is there a cat around ?") 75 | 76 | def __call__(self, observation): 77 | # Could be a very complex function that returns 1 is there is a cat around else 0. 
78 | if "cat" in observation: 79 | return int(True) # 1 80 | return int(False) # 0 81 | 82 | 83 | class LookForACat(Behavior): 84 | def __init__(self) -> None: 85 | super().__init__(name="Look for a nearby cat") 86 | 87 | def __call__(self, observation, *args, **kwargs) -> Action: 88 | # Could be a very complex function that returns actions from any given observation 89 | return Action("Move to a cat") 90 | 91 | 92 | class PetACat(Behavior): 93 | def __init__(self) -> None: 94 | super().__init__(name="Pet a cat") 95 | 96 | def build_graph(self) -> HEBGraph: 97 | graph = HEBGraph(self) 98 | is_a_cat_around = IsThereACatAround() 99 | graph.add_edge(is_a_cat_around, LookForACat(), index=int(False)) 100 | graph.add_edge(is_a_cat_around, PetNearbyCat(), index=int(True)) 101 | return graph 102 | 103 | 104 | class TestPetACat: 105 | """PetACat example""" 106 | 107 | @pytest.fixture(autouse=True) 108 | def setup_method(self): 109 | self.pet_nearby_cat_behavior = PetNearbyCat() 110 | self.pet_a_cat_behavior = PetACat() 111 | 112 | def test_call(self): 113 | """should give expected call""" 114 | observation = { 115 | "cat": "sofa", 116 | "hand": "computer", 117 | } 118 | 119 | # Call on observation 120 | action = self.pet_a_cat_behavior(observation) 121 | check.equal(action, Action("Move hand to cat")) 122 | 123 | def test_pet_a_cat_graph_edges(self): 124 | """should give expected edges for PetACat""" 125 | # Obtain networkx graph 126 | graph = self.pet_a_cat_behavior.graph 127 | check.equal( 128 | set(graph.edges(data="index")), 129 | { 130 | ("Is there a cat around ?", "Look for a nearby cat", 0), 131 | ("Is there a cat around ?", "Pet nearby cat", 1), 132 | }, 133 | ) 134 | 135 | def test_pet_nearby_cat_graph_edges(self): 136 | """should give expected edges for PetNearbyCat""" 137 | # Obtain networkx graph 138 | graph = self.pet_nearby_cat_behavior.graph 139 | check.equal( 140 | set(graph.edges(data="index")), 141 | { 142 | ("Is hand near the cat ?", "Move slowly your 
hand near the cat", 0), 143 | ("Is hand near the cat ?", "Action(Pet)", 1), 144 | }, 145 | ) 146 | 147 | def test_draw(self): 148 | """should be able to draw without error""" 149 | fig, ax = plt.subplots() 150 | self.pet_a_cat_behavior.graph.draw(ax) 151 | plt.close(fig) 152 | 153 | def test_draw_unrolled(self): 154 | """should be able to draw without error""" 155 | fig, ax = plt.subplots() 156 | unrolled_graph = unroll_graph(self.pet_a_cat_behavior.graph) 157 | unrolled_graph.draw(ax) 158 | plt.close(fig) 159 | 160 | @pytest.mark.filterwarnings("ignore:Could not load graph for behavior") 161 | def test_codegen(self): 162 | """should generate expected source code""" 163 | code = self.pet_a_cat_behavior.graph.generate_source_code() 164 | expected_code = "\n".join( 165 | ( 166 | "from hebg.codegen import GeneratedBehavior", 167 | "", 168 | "# Require 'Look for a nearby cat' behavior to be given.", 169 | "# Require 'Move slowly your hand near the cat' behavior to be given.", 170 | "class PetACat(GeneratedBehavior):", 171 | " def __call__(self, observation):", 172 | " edge_index = self.feature_conditions['Is there a cat around ?'](observation)", 173 | " if edge_index == 0:", 174 | " return self.known_behaviors['Look for a nearby cat'](observation)", 175 | " if edge_index == 1:", 176 | " edge_index_1 = self.feature_conditions['Is hand near the cat ?'](observation)", 177 | " if edge_index_1 == 0:", 178 | " return self.known_behaviors['Move slowly your hand near the cat'](observation)", 179 | " if edge_index_1 == 1:", 180 | " return self.actions['Action(Pet)'](observation)", 181 | ) 182 | ) 183 | check.equal(code, expected_code, _unidiff_output(code, expected_code)) 184 | -------------------------------------------------------------------------------- /src/hebg/graph.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | # 
pylint: disable=protected-access 4 | 5 | """Additional utility functions for networkx graphs.""" 6 | 7 | from typing import TYPE_CHECKING, Any, Dict, List 8 | 9 | from matplotlib.axes import Axes 10 | from matplotlib.offsetbox import AnnotationBbox, OffsetImage 11 | from networkx import DiGraph 12 | 13 | if TYPE_CHECKING: 14 | from hebg.heb_graph import HEBGraph 15 | from hebg.node import Node 16 | 17 | 18 | def get_roots(graph: DiGraph): 19 | """Finds roots in a DiGraph. 20 | 21 | Args: 22 | graph: A networkx DiGraph. 23 | 24 | Returns: 25 | List of root nodes. 26 | 27 | """ 28 | roots = [] 29 | for node in graph.nodes(): 30 | if len(list(graph.predecessors(node))) == 0: 31 | roots.append(node) 32 | return roots 33 | 34 | 35 | def get_nodes_by_level(graph: DiGraph) -> Dict[int, Any]: 36 | """Get the dictionary of nodes by level. 37 | 38 | Requires nodes to have a 'level' attribute. 39 | 40 | Args: 41 | graph: A networkx DiGraph. 42 | 43 | Returns: 44 | Dictionary of nodes by level. 45 | 46 | """ 47 | nodes_by_level = {} 48 | for node in graph.nodes(): 49 | level = graph.nodes[node]["level"] 50 | try: 51 | nodes_by_level[level].append(node) 52 | except KeyError: 53 | nodes_by_level[level] = [node] 54 | 55 | graph.graph["nodes_by_level"] = nodes_by_level 56 | graph.graph["depth"] = max(level for level in nodes_by_level) 57 | return nodes_by_level 58 | 59 | 60 | def compute_levels(graph: DiGraph): 61 | """Compute the hierachical levels of all DiGraph nodes. 62 | 63 | Adds the attribute 'level' to each node in the given graph. 64 | Adds the attribute 'nodes_by_level' to the given graph. 65 | Adds the attribute 'depth' to the given graph. 66 | 67 | Args: 68 | graph: A networkx DiGraph. 69 | 70 | Returns: 71 | Dictionary of nodes by level. 
72 | 73 | """ 74 | 75 | def _compute_level_dependencies(graph: DiGraph, node): 76 | predecessors = list(graph.predecessors(node)) 77 | if len(predecessors) == 0: 78 | graph.nodes[node]["level"] = 0 79 | return True 80 | 81 | pred_level_by_index = {} 82 | for pred in predecessors: 83 | index = graph.edges[pred, node]["index"] 84 | try: 85 | pred_level = graph.nodes[pred]["level"] 86 | except KeyError: 87 | pred_level = None 88 | 89 | if index in pred_level_by_index: 90 | pred_level_by_index[index].append(pred_level) 91 | else: 92 | pred_level_by_index[index] = [pred_level] 93 | 94 | min_level_by_index = [] 95 | for index, level_list in pred_level_by_index.items(): 96 | level_list_wo_none = [lvl for lvl in level_list if lvl is not None] 97 | if len(level_list_wo_none) == 0: 98 | return False 99 | min_level_by_index.append(min(level_list_wo_none)) 100 | level = 1 + max(min_level_by_index) 101 | graph.nodes[node]["level"] = level 102 | return True 103 | 104 | for _ in range(len(graph.nodes())): 105 | all_nodes_have_level = True 106 | incomplete_nodes = [] 107 | for node in graph.nodes(): 108 | incomplete = not _compute_level_dependencies(graph, node) 109 | if incomplete: 110 | incomplete_nodes.append(node) 111 | all_nodes_have_level = False 112 | if all_nodes_have_level: 113 | break 114 | 115 | if not all_nodes_have_level: 116 | raise ValueError( 117 | "Could not attribute levels to all nodes. " 118 | f"Incomplete nodes: {incomplete_nodes}" 119 | ) 120 | 121 | return get_nodes_by_level(graph) 122 | 123 | 124 | def compute_edges_color(graph: DiGraph): 125 | """Compute the edges colors of a leveled graph for readability. 126 | 127 | Requires nodes to have a 'level' attribute. 128 | Adds the attribute 'color' and 'linestyle' to each edge in the given graph. 129 | Nodes with a lot of successors will have more transparent edges. 130 | Edges going from high to low level will be dashed. 131 | 132 | Args: 133 | graph: A networkx DiGraph. 
134 | 135 | """ 136 | alphas = [1, 1, 0.9, 0.8, 0.7, 0.5, 0.4, 0.3] 137 | for node in graph.nodes(): 138 | successors = list(graph.successors(node)) 139 | for succ in successors: 140 | alpha = 0.2 141 | if graph.nodes[node]["level"] < graph.nodes[succ]["level"]: 142 | if len(successors) < len(alphas): 143 | alpha = alphas[len(successors) - 1] 144 | else: 145 | graph.edges[node, succ]["linestyle"] = "dashed" 146 | if isinstance(graph.edges[node, succ]["color"], list): 147 | graph.edges[node, succ]["color"][3] = alpha 148 | 149 | 150 | def draw_networkx_nodes_images( 151 | graph: DiGraph, pos, ax: Axes, img_zoom: float = 1, **kwargs 152 | ): 153 | """Draw nodes images of a networkx DiGraph on a given matplotlib ax. 154 | 155 | Requires nodes to have attributes 'image' for the node image and 'color' for the border color. 156 | 157 | Args: 158 | graph: A networkx DiGraph. 159 | pos: Layout positions of the graph. 160 | ax: A matplotlib Axes. 161 | img_zoom (Optional): Zoom to apply to images. 
162 | 163 | """ 164 | for n in graph: 165 | img = graph.nodes(data="image", default=None)[n] 166 | color = graph.nodes(data="color", default="black")[n] 167 | if img is not None: 168 | min_dim = min(img.shape[:2]) 169 | min_ax_shape = min(ax._position.width, ax._position.height) 170 | zoom = 100 * img_zoom * min_ax_shape / min_dim 171 | imagebox = OffsetImage(img, zoom=zoom) 172 | imagebox = AnnotationBbox( 173 | imagebox, pos[n], frameon=True, box_alignment=(0.5, 0.5) 174 | ) 175 | 176 | imagebox.patch.set_facecolor("None") 177 | imagebox.patch.set_edgecolor(color) 178 | imagebox.patch.set_linewidth(3) 179 | imagebox.patch.set_boxstyle("round", pad=0.15) 180 | ax.add_artist(imagebox) 181 | else: 182 | # If no image is found, draw label instead 183 | (x, y) = pos[n] 184 | label = str(n) 185 | ax.text( 186 | x, 187 | y, 188 | label, 189 | size=kwargs.get("font_size", 12), 190 | color=kwargs.get("font_color", "k"), 191 | family=kwargs.get("font_family", "sans-serif"), 192 | weight=kwargs.get("font_weight", "normal"), 193 | alpha=kwargs.get("alpha"), 194 | horizontalalignment=kwargs.get("horizontalalignment", "center"), 195 | verticalalignment=kwargs.get("verticalalignment", "center"), 196 | transform=ax.transData, 197 | bbox=dict(facecolor="none", edgecolor=color), 198 | clip_on=True, 199 | ) 200 | 201 | 202 | def get_successors_with_index( 203 | graph: "HEBGraph", node: "Node", next_edge_index: int 204 | ) -> List["Node"]: 205 | succs = graph.successors(node) 206 | next_nodes = [] 207 | for next_node in succs: 208 | if int(graph.edges[node, next_node]["index"]) == next_edge_index: 209 | next_nodes.append(next_node) 210 | if len(next_nodes) == 0: 211 | raise ValueError( 212 | f"FeatureCondition {node} returned index {next_edge_index}" 213 | f" but {next_edge_index} was not found as an edge index" 214 | ) 215 | return next_nodes 216 | -------------------------------------------------------------------------------- /README.rst: 
-------------------------------------------------------------------------------- 1 | HEBG - Hierachical Explainable Behaviors using Graphs 2 | ===================================================== 3 | 4 | .. image:: https://badge.fury.io/py/hebg.svg 5 | :alt: [Fury - PyPi stable version] 6 | :target: https://badge.fury.io/py/hebg 7 | 8 | .. image:: https://static.pepy.tech/badge/hebg 9 | :alt: [PePy - Downloads] 10 | :target: https://pepy.tech/project/hebg 11 | 12 | .. image:: https://static.pepy.tech/badge/hebg/week 13 | :alt: [PePy - Downloads per week] 14 | :target: https://pepy.tech/project/hebg 15 | 16 | .. image:: https://app.codacy.com/project/badge/Grade/ec4b296d18f4412398d64a66224c66dd 17 | :alt: [Codacy - Grade] 18 | :target: https://www.codacy.com/gh/IRLL/HEB_graphs/dashboard?utm_source=github.com&utm_medium=referral&utm_content=IRLL/HEB_graphs&utm_campaign=Badge_Grade 19 | 20 | .. image:: https://app.codacy.com/project/badge/Coverage/ec4b296d18f4412398d64a66224c66dd 21 | :alt: [Codacy - Coverage] 22 | :target: https://www.codacy.com/gh/IRLL/HEB_graphs/dashboard?utm_source=github.com&utm_medium=referral&utm_content=IRLL/HEB_graphs&utm_campaign=Badge_Coverage 23 | 24 | .. image:: https://img.shields.io/badge/code%20style-black-000000.svg 25 | :alt: [CodeStyle - Black] 26 | :target: https://github.com/psf/black 27 | 28 | .. image:: https://img.shields.io/github/license/MathisFederico/Crafting?style=plastic 29 | :alt: [Licence - GPLv3] 30 | :target: https://www.gnu.org/licenses/ 31 | 32 | 33 | This package is meant to build programatic hierarchical behaviors as graphs 34 | to compare them to human explanations of behavior. 35 | 36 | We take the definition of "behavior" as a function from observation to action. 37 | 38 | 39 | Installation 40 | ------------ 41 | 42 | 43 | .. 
code-block:: sh 44 | 45 | pip install hebg 46 | 47 | 48 | Usage 49 | ----- 50 | 51 | Build a HEBGraph 52 | ~~~~~~~~~~~~~~~~ 53 | 54 | Here is an example showing how we could hierarchically build an explainable behavior to pet a cat. 55 | 56 | .. code-block:: py3 57 | 58 | """ 59 | 60 | Here is the hierarchical structure that we would want: 61 | 62 | ``` 63 | PetACat: 64 | IsThereACatAround ? 65 | -> Yes: 66 | PetNearbyCat 67 | -> No: 68 | LookForACat 69 | 70 | PetNearbyCat: 71 | IsYourHandNearTheCat ? 72 | -> Yes: 73 | Pet 74 | -> No: 75 | MoveYourHandNearTheCat 76 | ``` 77 | 78 | """ 79 | 80 | from hebg import HEBGraph, Action, FeatureCondition, Behavior 81 | from hebg.unrolling import unroll_graph 82 | 83 | # Add a fundamental action 84 | class Pet(Action): 85 | def __init__(self) -> None: 86 | super().__init__(action="Pet") 87 | 88 | # Add a condition on the observation 89 | class IsYourHandNearTheCat(FeatureCondition): 90 | def __init__(self, hand) -> None: 91 | super().__init__(name="Is hand near the cat ?") 92 | self.hand = hand 93 | def __call__(self, observation) -> int: 94 | # Could be a very complex function that returns 1 if the hand is near the cat else 0. 95 | if observation["cat"] == observation[self.hand]: 96 | return int(True) # 1 97 | return int(False) # 0 98 | 99 | # Add an unexplainable Behavior (without a graph, but a function that can be called).
100 | class MoveYourHandNearTheCat(Behavior): 101 | def __init__(self) -> None: 102 | super().__init__(name="Move slowly your hand near the cat") 103 | def __call__(self, observation, *args, **kwargs) -> Action: 104 | # Could be a very complex function that returns actions from any given observation 105 | return Action("Move hand to cat") 106 | 107 | # Add a sub-behavior 108 | class PetNearbyCat(Behavior): 109 | def __init__(self) -> None: 110 | super().__init__(name="Pet nearby cat") 111 | def build_graph(self) -> HEBGraph: 112 | graph = HEBGraph(self) 113 | is_hand_near_cat = IsYourHandNearTheCat(hand="hand") 114 | graph.add_edge(is_hand_near_cat, MoveYourHandNearTheCat(), index=int(False)) 115 | graph.add_edge(is_hand_near_cat, Pet(), index=int(True)) 116 | return graph 117 | 118 | # Add another condition on the observation 119 | class IsThereACatAround(FeatureCondition): 120 | def __init__(self) -> None: 121 | super().__init__(name="Is there a cat around ?") 122 | def __call__(self, observation) -> int: 123 | # Could be a very complex function that returns 1 if there is a cat around else 0. 124 | if "cat" in observation: 125 | return int(True) # 1 126 | return int(False) # 0 127 | 128 | # Add another unexplainable Behavior (without a graph, but a function that can be called).
129 | class LookForACat(Behavior): 130 | def __init__(self) -> None: 131 | super().__init__(name="Look for a nearby cat") 132 | def __call__(self, observation, *args, **kwargs) -> Action: 133 | # Could be a very complex function that returns actions from any given observation 134 | return Action("Move to a cat") 135 | 136 | # Finally, add the main Behavior 137 | class PetACat(Behavior): 138 | def __init__(self) -> None: 139 | super().__init__(name="Pet a cat") 140 | def build_graph(self) -> HEBGraph: 141 | graph = HEBGraph(self) 142 | is_a_cat_around = IsThereACatAround() 143 | graph.add_edge(is_a_cat_around, LookForACat(), index=int(False)) 144 | graph.add_edge(is_a_cat_around, PetNearbyCat(), index=int(True)) 145 | return graph 146 | 147 | if __name__ == "__main__": 148 | pet_a_cat_behavior = PetACat() 149 | observation = { 150 | "cat": "sofa", 151 | "hand": "computer", 152 | } 153 | 154 | # Call on observation 155 | action = pet_a_cat_behavior(observation) 156 | print(action) # Action("Move hand to cat") 157 | 158 | # Obtain networkx graph 159 | graph = pet_a_cat_behavior.graph 160 | print(list(graph.edges(data="index"))) 161 | 162 | # Draw graph using matplotlib 163 | import matplotlib.pyplot as plt 164 | fig, ax = plt.subplots() 165 | graph.draw(ax) 166 | plt.show() 167 | 168 | 169 | .. image:: docs/images/PetACatGraph.png 170 | :align: center 171 | 172 | Unrolling HEBGraph 173 | ~~~~~~~~~~~~~~~~~~ 174 | 175 | When plotting an HEBGraph of a behavior, only the graph of the behavior itself is shown. 176 | To see the full hierarchical graph (including sub-behaviors), we need to unroll the graph as such: 177 | 178 | ..
code-block:: py3 179 | 180 | from hebg.unrolling import unroll_graph 181 | 182 | unrolled_graph = unroll_graph(pet_a_cat_behavior.graph, add_prefix=False) 183 | 184 | # Is also a networkx graph 185 | print(list(unrolled_graph.edges(data="index"))) 186 | 187 | # Draw graph using matplotlib 188 | import matplotlib.pyplot as plt 189 | fig, ax = plt.subplots() 190 | unrolled_graph.draw(ax) 191 | plt.show() 192 | 193 | 194 | .. image:: docs/images/PetACatGraphUnrolled.png 195 | :align: center 196 | 197 | Note that unexplainable behaviors (the ones without graphs) are kept as is. 198 | 199 | Python code generation from graph 200 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 201 | 202 | Once you have a HEBGraph, you can use it to generate working Python code that 203 | replicates the HEBGraph's behavior: 204 | 205 | .. code-block:: py3 206 | 207 | code = pet_a_cat_behavior.graph.generate_source_code() 208 | with open("pet_a_cat.py", "w") as pyfile: 209 | pyfile.write(code) 210 | 211 | This will generate the code below: 212 | 213 | .. code-block:: py3 214 | 215 | from hebg.codegen import GeneratedBehavior 216 | 217 | # Require 'Look for a nearby cat' behavior to be given. 218 | # Require 'Move slowly your hand near the cat' behavior to be given.
219 | class PetTheCat(GeneratedBehavior): 220 | def __call__(self, observation) -> Any: 221 | edge_index = self.feature_conditions['Is there a cat around ?'](observation) 222 | if edge_index == 0: 223 | return self.known_behaviors['Look for a nearby cat'](observation) 224 | if edge_index == 1: 225 | edge_index_1 = self.feature_conditions['Is hand near the cat ?'](observation) 226 | if edge_index_1 == 0: 227 | return self.known_behaviors['Move slowly your hand near the cat'](observation) 228 | if edge_index_1 == 1: 229 | return self.actions['Action(Pet)'](observation) 230 | 231 | 232 | Contributing to HEBG 233 | -------------------- 234 | 235 | Whenever you encounter a :bug: **bug** or have :tada: **feature request**, 236 | report this via `Github issues `_. 237 | 238 | If you wish to contribute directly, see `CONTRIBUTING `_ 239 | -------------------------------------------------------------------------------- /src/hebg/unrolling.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Module to unroll HEBGraph. 5 | 6 | Unrolling means expanding each sub-behavior node as it's own graph in the global HEBGraph. 7 | Behaviors that do not have a graph (Unexplainable behaviors) should stay as is in the graph. 8 | 9 | """ 10 | 11 | from copy import copy 12 | from typing import TYPE_CHECKING, Dict, List, Tuple, Optional, Union 13 | 14 | from networkx import relabel_nodes 15 | 16 | from hebg.behavior import Behavior 17 | 18 | BEHAVIOR_SEPARATOR = ">" 19 | 20 | if TYPE_CHECKING: 21 | from hebg import HEBGraph, Node, Action 22 | 23 | 24 | def unroll_graph( 25 | graph: "HEBGraph", 26 | add_prefix: bool = False, 27 | cut_looping_alternatives: bool = False, 28 | ) -> "HEBGraph": 29 | """Build the the unrolled HEBGraph. 
30 | 31 | The HEBGraph as the same behavior but every behavior node is recursively replaced 32 | by it's own HEBGraph if it can be computed. 33 | 34 | Args: 35 | graph (HEBGraph): HEBGraph to unroll the behavior in. 36 | add_prefix (bool, optional): If True, adds a name prefix to keep nodes different. 37 | Defaults to False. 38 | cut_looping_alternatives (bool, optional): If True, cut the looping alternatives. 39 | Defaults to False. 40 | 41 | Returns: 42 | HEBGraph: This HEBGraph's unrolled HEBGraph. 43 | """ 44 | unrolled_graph, _is_looping = _unroll_graph( 45 | graph, 46 | add_prefix=add_prefix, 47 | cut_looping_alternatives=cut_looping_alternatives, 48 | ) 49 | return unrolled_graph 50 | 51 | 52 | def _unroll_graph( 53 | graph: "HEBGraph", 54 | add_prefix: bool = False, 55 | cut_looping_alternatives: bool = False, 56 | _current_alternatives: Optional[List[Union["Action", "Behavior"]]] = None, 57 | _unrolled_behaviors: Optional[Dict[str, Optional["HEBGraph"]]] = None, 58 | ) -> Tuple["HEBGraph", bool]: 59 | if _unrolled_behaviors is None: 60 | _unrolled_behaviors = {} 61 | if _current_alternatives is None: 62 | _current_alternatives = {0: []} 63 | 64 | is_looping = False 65 | _unrolled_behaviors[graph.behavior.name] = None 66 | 67 | unrolled_graph: "HEBGraph" = copy(graph) 68 | for node in list(unrolled_graph.nodes()): 69 | if not isinstance(node, Behavior): 70 | continue 71 | 72 | _current_alternatives[0] = _direct_alternatives(node, graph) 73 | _current_alternatives[1] = _roots_alternatives(node, graph) 74 | unrolled_graph, behavior_is_looping = _unroll_behavior( 75 | unrolled_graph, 76 | node, 77 | add_prefix, 78 | cut_looping_alternatives, 79 | _current_alternatives, 80 | _unrolled_behaviors, 81 | ) 82 | if behavior_is_looping: 83 | is_looping = True 84 | 85 | return unrolled_graph, is_looping 86 | 87 | 88 | def _direct_alternatives(node: "Node", graph: "HEBGraph") -> list["Node"]: 89 | alternatives = [] 90 | for pred, _node, data in graph.in_edges(node, 
data=True): 91 | index = data["index"] 92 | for _pred, alternative, alt_index in graph.out_edges(pred, data="index"): 93 | if index != alt_index or alternative == node: 94 | continue 95 | alternatives.append(alternative) 96 | return alternatives 97 | 98 | 99 | def _roots_alternatives(node: "Node", graph: "HEBGraph") -> list["Node"]: 100 | alternatives = [] 101 | for pred, _node, _data in graph.in_edges(node, data=True): 102 | if pred in graph.roots: 103 | alternatives.extend([r for r in graph.roots if r != pred]) 104 | return alternatives 105 | 106 | 107 | def _unroll_behavior( 108 | graph: "HEBGraph", 109 | behavior: "Behavior", 110 | add_prefix: bool, 111 | cut_looping_alternatives: bool, 112 | _current_alternatives: List[Union["Action", "Behavior"]], 113 | _unrolled_behaviors: Dict[str, Optional["HEBGraph"]], 114 | ) -> Tuple["HEBGraph", bool]: 115 | """Unroll a behavior node in a given HEBGraph 116 | 117 | Args: 118 | graph (HEBGraph): HEBGraph to unroll the behavior in. 119 | behavior (Behavior): Behavior node to unroll, must be in the given graph. 120 | add_prefix (bool): If True, adds a name prefix to keep nodes different. 121 | cut_looping_alternatives (bool): If True, cut the looping alternatives. 122 | 123 | Returns: 124 | HEBGraph: Initial graph with unrolled behavior. 125 | """ 126 | # Look for name reference. 
127 | if behavior.name in graph.all_behaviors: 128 | behavior = graph.all_behaviors[behavior.name] 129 | 130 | node_graph, is_looping = _unrolled_behavior_graph( 131 | behavior, 132 | add_prefix, 133 | cut_looping_alternatives, 134 | _current_alternatives, 135 | _unrolled_behaviors, 136 | ) 137 | 138 | if is_looping and cut_looping_alternatives: 139 | for alternative in _current_alternatives[0]: 140 | for last_condition, _, data in graph.in_edges(behavior, data=True): 141 | graph.add_edge(last_condition, alternative, **data) 142 | if _current_alternatives[0]: 143 | graph.remove_node(behavior) 144 | return graph, False 145 | if _current_alternatives[1]: 146 | predecessors = list(graph.predecessors(behavior)) 147 | for last_condition in predecessors: 148 | successors = list(graph.successors(last_condition)) 149 | for descendant in successors: 150 | graph.remove_edge(last_condition, descendant) 151 | if graph.neighbors(descendant) == 0: 152 | graph.remove_node(descendant) 153 | graph.remove_node(last_condition) 154 | graph.remove_node(behavior) 155 | return graph, False 156 | raise NotImplementedError() 157 | 158 | if node_graph is None: 159 | # If we cannot get the node's graph, we keep it as is. 160 | return graph, is_looping 161 | 162 | # Relabel graph nodes to obtain disjoint node labels (if more that one node). 
163 | if add_prefix and len(node_graph.nodes()) > 1: 164 | _add_prefix_to_graph(node_graph, behavior.name + BEHAVIOR_SEPARATOR) 165 | 166 | # Replace the behavior node by the unrolled behavior's graph 167 | graph = compose_heb_graphs(graph, node_graph) 168 | for edge_u, _, data in graph.in_edges(behavior, data=True): 169 | for root in node_graph.roots: 170 | graph.add_edge(edge_u, root, **data) 171 | 172 | graph.remove_node(behavior) 173 | return graph, is_looping 174 | 175 | 176 | def _unrolled_behavior_graph( 177 | behavior: "Behavior", 178 | add_prefix: bool, 179 | cut_looping_alternatives: bool, 180 | _current_alternatives: List[Union["Action", "Behavior"]], 181 | _unrolled_behaviors: Dict[str, Optional["HEBGraph"]], 182 | ) -> Tuple[Optional["HEBGraph"], bool]: 183 | """Get the unrolled sub-graph of a behavior. 184 | 185 | Args: 186 | behavior (Behavior): Behavior to get the unrolled graph of. 187 | add_prefix (bool): If True, adds a prefix in sub-hierarchies to have distinct nodes. 188 | cut_looping_alternatives (bool): If True, cut the looping alternatives. 189 | _unrolled_behaviors (Dict[str, Optional[HEBGraph]]): Dictionary of already computed 190 | unrolled graphs, both to save compute and prevent recursion loops. 191 | 192 | Returns: 193 | Optional[HEBGraph]: Unrolled graph of a behavior, None if it cannot be computed. 
194 | """ 195 | if behavior.name in _unrolled_behaviors: 196 | # If we have aleardy unrolled this behavior, we reuse it's graph 197 | is_looping = _unrolled_behaviors[behavior.name] is None 198 | return _unrolled_behaviors[behavior.name], is_looping 199 | 200 | try: 201 | node_graph, is_looping = _unroll_graph( 202 | behavior.graph, 203 | add_prefix=add_prefix, 204 | cut_looping_alternatives=cut_looping_alternatives, 205 | _current_alternatives=_current_alternatives, 206 | _unrolled_behaviors=_unrolled_behaviors, 207 | ) 208 | _unrolled_behaviors[behavior.name] = node_graph 209 | return node_graph, is_looping 210 | except NotImplementedError: 211 | return None, False 212 | 213 | 214 | def _add_prefix_to_graph(graph: "HEBGraph", prefix: str) -> None: 215 | """Rename graph to obtain disjoint node labels.""" 216 | if prefix is None: 217 | return graph 218 | 219 | def rename(node: "Node") -> None: 220 | new_node = copy(node) 221 | new_node.name = prefix + node.name 222 | return new_node 223 | 224 | return relabel_nodes(graph, rename, copy=False) 225 | 226 | 227 | def group_behaviors_points( 228 | pos: Dict["Node", tuple], 229 | graph: "HEBGraph", 230 | ) -> Dict[tuple, list]: 231 | """Group nodes positions of an HEBGraph by sub-behavior. 232 | 233 | Args: 234 | pos (Dict[Node, tuple]): Positions of nodes. 235 | graph (HEBGraph): Graph. 236 | 237 | Returns: 238 | Dict[tuple, list]: A dictionary of nodes grouped by their behavior's hierarchy. 
239 | """ 240 | points_grouped_by_behavior: Dict[tuple, list] = {} 241 | for node in graph.nodes(): 242 | groups = str(node).split(BEHAVIOR_SEPARATOR) 243 | if len(groups) > 1: 244 | for i in range(len(groups[:-1])): 245 | key = tuple(groups[: -1 - i]) 246 | point = pos[node] 247 | if key in points_grouped_by_behavior: 248 | points_grouped_by_behavior[key].append(point) 249 | else: 250 | points_grouped_by_behavior[key] = [point] 251 | return points_grouped_by_behavior 252 | 253 | 254 | def compose_heb_graphs(graph_of_reference: "HEBGraph", other_graph: "HEBGraph") -> None: 255 | """Returns a new_graph of graph_of_reference composed with other_graph. 256 | 257 | Composition is the simple union of the node sets and edge sets. 258 | The node sets of the graph_of_reference and other_graph do not need to be disjoint. 259 | 260 | Args: 261 | graph_of_reference, other_graph : HEBGraphs to compose. 262 | 263 | Returns: 264 | A new HEBGraph with the same type as graph_of_reference. 265 | 266 | """ 267 | new_graph = graph_of_reference.__class__( 268 | graph_of_reference.behavior, all_behaviors=graph_of_reference.all_behaviors 269 | ) 270 | # add graph attributes, H attributes take precedent over G attributes 271 | new_graph.graph.update(graph_of_reference.graph) 272 | new_graph.graph.update(other_graph.graph) 273 | 274 | new_graph.add_nodes_from(graph_of_reference.nodes(data=True)) 275 | new_graph.add_nodes_from(other_graph.nodes(data=True)) 276 | 277 | if graph_of_reference.is_multigraph(): 278 | new_graph.add_edges_from(graph_of_reference.edges(keys=True, data=True)) 279 | else: 280 | new_graph.add_edges_from(graph_of_reference.edges(data=True)) 281 | if other_graph.is_multigraph(): 282 | new_graph.add_edges_from(other_graph.edges(keys=True, data=True)) 283 | else: 284 | new_graph.add_edges_from(other_graph.edges(data=True)) 285 | return new_graph 286 | -------------------------------------------------------------------------------- /src/hebg/call_graph.py: 
-------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import ( 3 | TYPE_CHECKING, 4 | Any, 5 | Dict, 6 | List, 7 | NamedTuple, 8 | Optional, 9 | Tuple, 10 | TypeVar, 11 | Union, 12 | ) 13 | from matplotlib import pyplot as plt 14 | from matplotlib.axes import Axes 15 | 16 | from networkx import ( 17 | DiGraph, 18 | all_simple_paths, 19 | draw_networkx_edges, 20 | draw_networkx_labels, 21 | draw_networkx_nodes, 22 | ancestors, 23 | ) 24 | import numpy as np 25 | from hebg.behavior import Behavior 26 | from hebg.graph import get_successors_with_index 27 | from hebg.node import Action, FeatureCondition, Node 28 | 29 | if TYPE_CHECKING: 30 | from hebg.heb_graph import HEBGraph 31 | 32 | EnvAction = TypeVar("EnvAction") 33 | 34 | 35 | class CallGraph(DiGraph): 36 | def __init__(self, **attr) -> None: 37 | super().__init__(incoming_graph_data=None, **attr) 38 | self.graph["n_branches"] = 0 39 | self.graph["n_calls"] = 0 40 | self.graph["frontiere"] = [] 41 | self._known_fc: Dict[FeatureCondition, Any] = {} 42 | self._current_node = CallNode(0, 0) 43 | 44 | def add_root(self, heb_node: "Node", heb_graph: "HEBGraph", **kwargs) -> None: 45 | self.add_node( 46 | self._current_node, heb_node=heb_node, heb_graph=heb_graph, **kwargs 47 | ) 48 | 49 | def call_nodes( 50 | self, nodes: List["Node"], observation, heb_graph: "HEBGraph" 51 | ) -> EnvAction: 52 | self._extend_frontiere(nodes, heb_graph) 53 | action = None 54 | 55 | while len(self.graph["frontiere"]) > 0 and action is None: 56 | next_call_node = self._pop_from_frontiere() 57 | if next_call_node is None: 58 | break 59 | 60 | node: "Node" = self.nodes[next_call_node]["heb_node"] 61 | heb_graph: "HEBGraph" = self.nodes[next_call_node]["heb_graph"] 62 | 63 | if node.type == "behavior": 64 | # Search for name reference in all_behaviors 65 | if node.name in heb_graph.all_behaviors: 66 | node = heb_graph.all_behaviors[node.name] 67 | if not hasattr(node, 
"_graph") or node._graph is None: 68 | action = node(observation) 69 | break 70 | self._extend_frontiere(node.graph.roots, heb_graph=node.graph) 71 | elif node.type == "action": 72 | action = node(observation) 73 | break 74 | elif node.type == "feature_condition": 75 | if node in self._known_fc: 76 | next_edge_index = self._known_fc[node] 77 | else: 78 | next_edge_index = int(node(observation)) 79 | self._known_fc[node] = next_edge_index 80 | next_nodes = get_successors_with_index(heb_graph, node, next_edge_index) 81 | self._extend_frontiere(next_nodes, heb_graph) 82 | elif node.type == "empty": 83 | self._extend_frontiere(list(heb_graph.successors(node)), heb_graph) 84 | else: 85 | raise ValueError( 86 | f"Unknowed value {node.type} for node.type with node: {node}." 87 | ) 88 | 89 | if action is None: 90 | raise ValueError("No valid frontiere left in call_graph") 91 | 92 | return action 93 | 94 | def call_edge_labels(self) -> list[tuple]: 95 | return [ 96 | (self.nodes[u]["label"], self.nodes[v]["label"]) for u, v in self.edges() 97 | ] 98 | 99 | def add_node( 100 | self, node_for_adding, heb_node: "Node", heb_graph: "HEBGraph", **attr 101 | ): 102 | super().add_node( 103 | node_for_adding, 104 | heb_graph=heb_graph, 105 | heb_node=heb_node, 106 | label=heb_node.name, 107 | **attr, 108 | ) 109 | 110 | def add_edge( 111 | self, 112 | u_of_edge, 113 | v_of_edge, 114 | status: "CallEdgeStatus", 115 | **attr, 116 | ): 117 | return super().add_edge(u_of_edge, v_of_edge, status=status.value, **attr) 118 | 119 | def _make_new_branch(self) -> int: 120 | self.graph["n_branches"] += 1 121 | return self.graph["n_branches"] 122 | 123 | def _extend_frontiere(self, nodes: List["Node"], heb_graph: "HEBGraph") -> None: 124 | frontiere: List[CallNode] = self.graph["frontiere"] 125 | 126 | parent = self._current_node 127 | call_nodes = [] 128 | 129 | for i, node in enumerate(nodes): 130 | if i > 0: 131 | branch_id = self._make_new_branch() 132 | else: 133 | branch_id = 
parent.branch 134 | call_node = CallNode(branch_id, parent.rank + 1) 135 | 136 | if node.name in heb_graph.all_behaviors: 137 | node = heb_graph.all_behaviors[node.name] 138 | self.add_node(call_node, heb_node=node, heb_graph=heb_graph) 139 | self.add_edge(parent, call_node, CallEdgeStatus.UNEXPLORED) 140 | call_nodes.append(call_node) 141 | 142 | frontiere.extend(call_nodes) 143 | 144 | def _heb_node_from_call_node(self, node: "CallNode") -> "Node": 145 | return self.nodes[node]["heb_node"] 146 | 147 | def _pop_from_frontiere(self) -> Optional["CallNode"]: 148 | frontiere: List["CallNode"] = self.graph["frontiere"] 149 | 150 | next_node = None 151 | 152 | while next_node is None: 153 | if not frontiere: 154 | return None 155 | 156 | next_call_node = frontiere.pop( 157 | np.argmin( 158 | [ 159 | self._heb_node_from_call_node(node).complexity 160 | for node in frontiere 161 | ] 162 | ) 163 | ) 164 | maybe_next_node = self._heb_node_from_call_node(next_call_node) 165 | # Nodes should only have one parent 166 | parent = list(self.predecessors(next_call_node))[0] 167 | 168 | if isinstance(maybe_next_node, Behavior) and maybe_next_node in [ 169 | self._heb_node_from_call_node(node) 170 | for node in ancestors(self, next_call_node) 171 | ]: 172 | self._update_edge_status(parent, next_call_node, CallEdgeStatus.FAILURE) 173 | continue 174 | 175 | next_node = maybe_next_node 176 | 177 | self.graph["n_calls"] += 1 178 | self.nodes[next_call_node]["call_rank"] = 1 179 | self._update_edge_status(parent, next_call_node, CallEdgeStatus.CALLED) 180 | self._current_node = next_call_node 181 | return next_call_node 182 | 183 | def _update_edge_status( 184 | self, start: "Node", end: "Node", status: Union["CallEdgeStatus", str] 185 | ): 186 | status = CallEdgeStatus(status) 187 | self.edges[start, end]["status"] = status.value 188 | 189 | def draw( 190 | self, 191 | ax: Optional[Axes] = None, 192 | pos: Optional[Dict[str, Tuple[float, float]]] = None, 193 | nodes_kwargs: 
Optional[dict] = None, 194 | label_kwargs: Optional[dict] = None, 195 | edges_kwargs: Optional[dict] = None, 196 | ): 197 | if pos is None: 198 | pos = _call_graph_pos(self) 199 | if nodes_kwargs is None: 200 | nodes_kwargs = {} 201 | 202 | if ax is None: 203 | ax = plt.gca() 204 | 205 | pos_arr = np.array(list(pos.values())) 206 | max_x, max_y = pos_arr.max(axis=0) 207 | min_x, min_y = pos_arr.min(axis=0) 208 | y_range = max_y - min_y 209 | x_range = max_x - min_x 210 | ax.set_ylim([min_y - 0.1 * y_range, max_y + 0.1 * y_range]) 211 | ax.set_xlim([min_x - 0.1 * x_range, max_x + 0.1 * x_range]) 212 | 213 | nodes_complexity = np.array( 214 | [node_data["heb_node"].complexity for _, node_data in self.nodes(data=True)] 215 | ) 216 | complexity_range = nodes_complexity.max() - nodes_complexity.min() 217 | 218 | nodes_complexity_scaled = ( 219 | 50 + 600 * (nodes_complexity - nodes_complexity.min()) / complexity_range 220 | ) 221 | 222 | draw_networkx_nodes( 223 | self, 224 | node_color=[ 225 | _node_color(node_data["heb_node"]) 226 | for _, node_data in self.nodes(data=True) 227 | ], 228 | node_size=nodes_complexity_scaled, 229 | ax=ax, 230 | pos=pos, 231 | **nodes_kwargs, 232 | ) 233 | if label_kwargs is None: 234 | label_kwargs = {} 235 | draw_networkx_labels( 236 | self, 237 | labels={ 238 | node: f"{node_data['label']}" 239 | for node, node_data in self.nodes(data=True) 240 | }, 241 | ax=ax, 242 | horizontalalignment="center", 243 | verticalalignment="center", 244 | font_size=8, 245 | pos=pos, 246 | **nodes_kwargs, 247 | ) 248 | if edges_kwargs is None: 249 | edges_kwargs = {} 250 | if "connectionstyle" not in edges_kwargs: 251 | edges_kwargs.update(connectionstyle="angle,angleA=0,angleB=90,rad=5") 252 | draw_networkx_edges( 253 | self, 254 | ax=ax, 255 | pos=pos, 256 | arrowstyle="-", 257 | alpha=0.5, 258 | width=3, 259 | node_size=1, 260 | edge_color=[ 261 | _call_status_to_color(status) 262 | for _, _, status in self.edges(data="status") 263 | ], 264 | 
**edges_kwargs, 265 | ) 266 | 267 | 268 | class CallNode(NamedTuple): 269 | branch: int 270 | rank: int 271 | 272 | 273 | class CallEdgeStatus(Enum): 274 | UNEXPLORED = "unexplored" 275 | CALLED = "called" 276 | FAILURE = "failure" 277 | 278 | 279 | def _node_color(node: Union[Action, FeatureCondition, Behavior]) -> str: 280 | if isinstance(node, Action): 281 | return "red" 282 | if isinstance(node, FeatureCondition): 283 | return "blue" 284 | if isinstance(node, Behavior): 285 | return "orange" 286 | raise NotImplementedError 287 | 288 | 289 | def _call_status_to_color(status: Union[str, "CallEdgeStatus"]) -> str: 290 | status = CallEdgeStatus(status) 291 | if status is CallEdgeStatus.UNEXPLORED: 292 | return "black" 293 | if status is CallEdgeStatus.CALLED: 294 | return "green" 295 | if status is CallEdgeStatus.FAILURE: 296 | return "red" 297 | raise NotImplementedError 298 | 299 | 300 | def _call_graph_pos(call_graph: "CallGraph") -> Dict[str, Tuple[float, float]]: 301 | pos = {} 302 | 303 | roots = [n for (n, d) in call_graph.in_degree if d == 0] 304 | leafs = [n for (n, d) in call_graph.out_degree if d == 0] 305 | 306 | branches = all_simple_paths(call_graph, roots[0], leafs) 307 | branches = sorted(branches, key=lambda x: -len(x)) 308 | 309 | for branch_id, nodes_in_branch in enumerate(branches): 310 | for node in nodes_in_branch: 311 | if node in pos: 312 | continue 313 | rank = node.rank 314 | pos[node] = [branch_id, -rank] 315 | return pos 316 | -------------------------------------------------------------------------------- /tests/test_call_graph.py: -------------------------------------------------------------------------------- 1 | from networkx import DiGraph 2 | import pytest 3 | 4 | from hebg.behavior import Behavior 5 | from hebg.call_graph import CallEdgeStatus, CallGraph, CallNode, _call_graph_pos 6 | from hebg.heb_graph import HEBGraph 7 | from hebg.node import Action, FeatureCondition 8 | 9 | from pytest_mock import MockerFixture 10 | import 
pytest_check as check 11 | 12 | from tests import plot_graph 13 | 14 | from tests.examples.behaviors import F_F_A_Behavior 15 | from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors 16 | from tests.examples.feature_conditions import ThresholdFeatureCondition 17 | 18 | 19 | class TestCall: 20 | """Ensure that the call graph is faithful for debugging and efficient breadth first search.""" 21 | 22 | def test_call_stack_without_branches(self) -> None: 23 | """When there is no branches, the graph should be a simple sequence of the call stack.""" 24 | f_f_a_behavior = F_F_A_Behavior() 25 | 26 | draw = False 27 | if draw: 28 | plot_graph(f_f_a_behavior.graph.unrolled_graph) 29 | f_f_a_behavior(observation=-2) 30 | 31 | expected_graph = DiGraph( 32 | [ 33 | ("F_F_A", "Greater or equal to 0 ?"), 34 | ("Greater or equal to 0 ?", "Greater or equal to -1 ?"), 35 | ("Greater or equal to -1 ?", "Action(0)"), 36 | ] 37 | ) 38 | 39 | call_graph = f_f_a_behavior.graph.call_graph 40 | assert set(call_graph.call_edge_labels()) == set(expected_graph.edges()) 41 | 42 | def test_split_on_same_fc_index(self, mocker: MockerFixture) -> None: 43 | """When there are multiple indexes on the same feature condition, 44 | a branch should be created.""" 45 | 46 | expected_action = Action("EXPECTED", complexity=1) 47 | 48 | forbidden_value = "FORBIDDEN" 49 | forbidden_action = Action(forbidden_value, complexity=2) 50 | forbidden_action.__call__ = mocker.MagicMock(return_value=forbidden_value) 51 | 52 | class F_AA_Behavior(Behavior): 53 | """Feature condition with mutliple actions on same index.""" 54 | 55 | def __init__(self) -> None: 56 | super().__init__("F_AA") 57 | 58 | def build_graph(self) -> HEBGraph: 59 | graph = HEBGraph(self) 60 | feature_condition = ThresholdFeatureCondition( 61 | relation=">=", threshold=0 62 | ) 63 | graph.add_edge( 64 | feature_condition, Action(0, complexity=0), index=int(True) 65 | ) 66 | graph.add_edge(feature_condition, 
forbidden_action, index=int(False)) 67 | graph.add_edge(feature_condition, expected_action, index=int(False)) 68 | 69 | return graph 70 | 71 | f_aa_behavior = F_AA_Behavior() 72 | draw = False 73 | if draw: 74 | plot_graph(f_aa_behavior.graph.unrolled_graph) 75 | 76 | # Sanity check that the right action should be called and not the forbidden one. 77 | assert f_aa_behavior(observation=-1) == expected_action.action 78 | forbidden_action.__call__.assert_not_called() 79 | 80 | # Graph should have the good split 81 | call_graph = f_aa_behavior.graph.call_graph 82 | expected_graph = DiGraph( 83 | [ 84 | ("F_AA", "Greater or equal to 0 ?"), 85 | ("Greater or equal to 0 ?", "Action(EXPECTED)"), 86 | ("Greater or equal to 0 ?", "Action(FORBIDDEN)"), 87 | ] 88 | ) 89 | assert set(call_graph.call_edge_labels()) == set(expected_graph.edges()) 90 | 91 | @pytest.mark.xfail 92 | def test_multiple_call_to_same_fc(self, mocker: MockerFixture) -> None: 93 | """Call graph should allow for the same feature condition 94 | to be called multiple times in the same branch (in different behaviors).""" 95 | expected_action = Action("EXPECTED") 96 | unexpected_action = Action("UNEXPECTED") 97 | 98 | feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0) 99 | 100 | class SubBehavior(Behavior): 101 | def __init__(self) -> None: 102 | super().__init__("SubBehavior") 103 | 104 | def build_graph(self) -> HEBGraph: 105 | graph = HEBGraph(self) 106 | graph.add_edge(feature_condition, expected_action, index=int(True)) 107 | graph.add_edge(feature_condition, unexpected_action, index=int(False)) 108 | return graph 109 | 110 | class RootBehavior(Behavior): 111 | """Feature condition with mutliple actions on same index.""" 112 | 113 | def __init__(self) -> None: 114 | super().__init__("RootBehavior") 115 | 116 | def build_graph(self) -> HEBGraph: 117 | graph = HEBGraph(self) 118 | graph.add_edge(feature_condition, SubBehavior(), index=int(True)) 119 | graph.add_edge(feature_condition, 
unexpected_action, index=int(False)) 120 | 121 | return graph 122 | 123 | root_behavior = RootBehavior() 124 | draw = False 125 | if draw: 126 | plot_graph(root_behavior.graph.unrolled_graph) 127 | 128 | # Sanity check that the right action should be called and not the forbidden one. 129 | assert root_behavior(observation=2) == expected_action.action 130 | 131 | # Graph should have the good split 132 | call_graph = root_behavior.graph.call_graph 133 | expected_graph = DiGraph( 134 | [ 135 | ("RootBehavior", "Greater or equal to 0 ?"), 136 | ("Greater or equal to 0 ?", "SubBehavior"), 137 | ("SubBehavior", "Greater or equal to 0 ?"), 138 | ("Greater or equal to 0 ?", "Action(EXPECTED)"), 139 | ] 140 | ) 141 | assert set(call_graph.call_edge_labels()) == set(expected_graph.edges()) 142 | 143 | expected_labels = { 144 | CallNode(0, 0): "RootBehavior", 145 | CallNode(0, 1): "Greater or equal to 0 ?", 146 | CallNode(0, 2): "SubBehavior", 147 | CallNode(0, 3): "Greater or equal to 0 ?", 148 | CallNode(0, 4): "Action(EXPECTED)", 149 | } 150 | for node, label in call_graph.nodes(data="label"): 151 | check.equal(label, expected_labels[node]) 152 | 153 | def test_chain_behaviors(self, mocker: MockerFixture) -> None: 154 | """When sub-behaviors with a graph are called recursively, 155 | the call graph should still find their nodes.""" 156 | 157 | expected_action = "EXPECTED" 158 | 159 | class DummyBehavior(Behavior): 160 | __call__ = mocker.MagicMock(return_value=expected_action) 161 | 162 | class SubBehavior(Behavior): 163 | def __init__(self) -> None: 164 | super().__init__("SubBehavior") 165 | 166 | def build_graph(self) -> HEBGraph: 167 | graph = HEBGraph(self) 168 | graph.add_node(DummyBehavior("Dummy")) 169 | return graph 170 | 171 | class RootBehavior(Behavior): 172 | def __init__(self) -> None: 173 | super().__init__("RootBehavior") 174 | 175 | def build_graph(self) -> HEBGraph: 176 | graph = HEBGraph(self) 177 | graph.add_node(Behavior("SubBehavior")) 178 | return 
graph 179 | 180 | sub_behavior = SubBehavior() 181 | sub_behavior.graph 182 | 183 | root_behavior = RootBehavior() 184 | root_behavior.graph.all_behaviors["SubBehavior"] = sub_behavior 185 | 186 | # Sanity check that the right action should be called. 187 | assert root_behavior(observation=-1) == expected_action 188 | 189 | call_graph = root_behavior.graph.call_graph 190 | expected_graph = DiGraph( 191 | [ 192 | ("RootBehavior", "SubBehavior"), 193 | ("SubBehavior", "Dummy"), 194 | ] 195 | ) 196 | assert set(call_graph.call_edge_labels()) == set(expected_graph.edges()) 197 | 198 | def test_looping_goback(self) -> None: 199 | """Loops with alternatives should be ignored.""" 200 | draw = False 201 | _gather_wood, get_axe = build_looping_behaviors() 202 | assert get_axe({}) == "Punch tree" 203 | 204 | call_graph = get_axe.graph.call_graph 205 | 206 | if draw: 207 | plot_graph(call_graph) 208 | 209 | expected_labels = { 210 | CallNode(0, 0): "Get new axe", 211 | CallNode(0, 1): "Has wood ?", 212 | CallNode(1, 2): "Action(Summon axe out of thin air)", 213 | CallNode(0, 2): "Gather wood", 214 | CallNode(0, 3): "Has axe ?", 215 | CallNode(2, 4): "Get new axe", 216 | CallNode(0, 4): "Action(Punch tree)", 217 | } 218 | for node, label in call_graph.nodes(data="label"): 219 | check.equal(label, expected_labels[node]) 220 | 221 | expected_graph = DiGraph( 222 | [ 223 | ("Get new axe", "Has wood ?"), 224 | ("Has wood ?", "Action(Summon axe out of thin air)"), 225 | ("Has wood ?", "Gather wood"), 226 | ("Gather wood", "Has axe ?"), 227 | ("Has axe ?", "Get new axe"), 228 | ("Has axe ?", "Action(Punch tree)"), 229 | ] 230 | ) 231 | 232 | assert set(call_graph.call_edge_labels()) == set(expected_graph.edges()) 233 | 234 | 235 | class TestDraw: 236 | """Ensures that the graph is readable even in complex situations.""" 237 | 238 | def test_result_on_first_branch(self) -> None: 239 | """Resulting action should always be on the first branch.""" 240 | draw = False 241 | root_behavior 
= Behavior("Root", complexity=20) 242 | call_graph = CallGraph() 243 | call_graph.add_root(root_behavior, None) 244 | 245 | nodes = [ 246 | (CallNode(0, 1), FeatureCondition("FC1", complexity=1)), 247 | (CallNode(0, 2), FeatureCondition("FC2", complexity=1)), 248 | (CallNode(0, 3), root_behavior), 249 | (CallNode(1, 1), FeatureCondition("FC3", complexity=1)), 250 | (CallNode(1, 2), FeatureCondition("FC4", complexity=1)), 251 | (CallNode(1, 3), FeatureCondition("FC5", complexity=1)), 252 | (CallNode(1, 4), Action("A", complexity=1)), 253 | ] 254 | 255 | for node, heb_node in nodes: 256 | call_graph.add_node(node, heb_node, None) 257 | 258 | edges = [ 259 | (CallNode(0, 0), CallNode(0, 1), CallEdgeStatus.CALLED), 260 | (CallNode(0, 1), CallNode(0, 2), CallEdgeStatus.CALLED), 261 | (CallNode(0, 2), CallNode(0, 3), CallEdgeStatus.FAILURE), 262 | (CallNode(0, 0), CallNode(1, 1), CallEdgeStatus.CALLED), 263 | (CallNode(1, 1), CallNode(1, 2), CallEdgeStatus.CALLED), 264 | (CallNode(1, 2), CallNode(1, 3), CallEdgeStatus.CALLED), 265 | (CallNode(1, 3), CallNode(1, 4), CallEdgeStatus.CALLED), 266 | ] 267 | 268 | for start, end, status in edges: 269 | call_graph.add_edge(start, end, status) 270 | 271 | expected_poses = { 272 | CallNode(0, 0): [0, 0], 273 | CallNode(1, 1): [0, -1], 274 | CallNode(1, 2): [0, -2], 275 | CallNode(1, 3): [0, -3], 276 | CallNode(1, 4): [0, -4], 277 | CallNode(0, 1): [1, -1], 278 | CallNode(0, 2): [1, -2], 279 | CallNode(0, 3): [1, -3], 280 | } 281 | if draw: 282 | plot_graph(call_graph) 283 | 284 | assert _call_graph_pos(call_graph) == expected_poses 285 | -------------------------------------------------------------------------------- /src/hebg/codegen.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Module for code generation from HEBGraph.""" 5 | 6 | from re import sub 7 | from typing 
import TYPE_CHECKING, Dict, List, Set, Tuple 8 | 9 | from hebg.behavior import Behavior 10 | from hebg.graph import get_roots, get_successors_with_index 11 | from hebg.metrics.histograms import cumulated_hebgraph_histogram 12 | from hebg.node import Action, FeatureCondition, Node 13 | from hebg.unrolling import BEHAVIOR_SEPARATOR 14 | 15 | if TYPE_CHECKING: 16 | from hebg.heb_graph import HEBGraph 17 | 18 | 19 | class GeneratedBehavior: 20 | """Base class for generated behaviors. 21 | 22 | Used to reduce the overhead of abstracting a behavior.""" 23 | 24 | def __init__( 25 | self, 26 | actions: Dict[str, "Action"] = None, 27 | feature_conditions: Dict[str, "FeatureCondition"] = None, 28 | behaviors: Dict[str, "Behavior"] = None, 29 | ): 30 | self.actions = actions if actions is not None else {} 31 | self.feature_conditions = feature_conditions if actions is not None else {} 32 | self.known_behaviors = behaviors if behaviors is not None else {} 33 | 34 | 35 | def get_hebg_source(graph: "HEBGraph") -> str: 36 | """Builds the generated source code corresponding to the HEBGraph behavior. 37 | 38 | Args: 39 | graph (HEBGraph): HEBGraph to generate the source code from. 40 | 41 | Returns: 42 | str: Python source code corresponding the behavior of the HEBGraph. 43 | """ 44 | class_codelines = get_behavior_class_codelines(graph) 45 | source = "\n".join(class_codelines) 46 | return source 47 | 48 | 49 | def get_behavior_class_codelines( 50 | graph: "HEBGraph", 51 | behaviors_histogram: Dict["Behavior", Dict["Node", int]] = None, 52 | add_dependencies: bool = True, 53 | ) -> List[str]: 54 | """Generate codelines of the whole GeneratedBehavior from a HEBGraph. 55 | 56 | Args: 57 | graph (HEBGraph): HEBGraph to generate the behavior from. 58 | behaviors_histogram (Dict[Behavior, Dict[Node, int]], optional): Histogram of uses 59 | of all behaviors needed for the given graph. Defaults to None. 
60 | add_dependencies (bool, optional): If True, codelines will include other GeneratedBehavior 61 | from sub-behaviors of the given HEBGraph and a hashmap to map behavior name 62 | to coresponding GeneratedBehavior. Defaults to True. 63 | 64 | Returns: 65 | List[str]: Codelines generated of the HEBGraph GeneratedBehavior. 66 | """ 67 | if behaviors_histogram is None: 68 | behaviors_histogram = cumulated_hebgraph_histogram(graph) 69 | 70 | ( 71 | dependencies_codelines, 72 | behaviors_incall_codelines, 73 | hashmap_codelines, 74 | ) = get_dependencies_codelines( 75 | graph, 76 | behaviors_histogram, 77 | add_dependencies_codelines=add_dependencies, 78 | ) 79 | 80 | class_codelines = [] 81 | # Other behaviors dependencies 82 | if add_dependencies: 83 | class_codelines += ["from hebg.codegen import GeneratedBehavior", ""] 84 | for _behavior, dependency_codelines in dependencies_codelines.items(): 85 | class_codelines += dependency_codelines 86 | 87 | # Class overhead 88 | behavior_class_name = _to_camel_case(graph.behavior.name.capitalize()) 89 | class_codelines.append(f"class {behavior_class_name}(GeneratedBehavior):") 90 | # Call function 91 | class_codelines += get_behavior_call_codelines( 92 | graph, 93 | behaviors_incall_codelines=behaviors_incall_codelines, 94 | ) 95 | # Dependencies hashmap 96 | if add_dependencies: 97 | class_codelines += hashmap_codelines 98 | return class_codelines 99 | 100 | 101 | def get_dependencies_codelines( 102 | graph: "HEBGraph", 103 | behaviors_histogram: Dict["Behavior", Dict["Node", int]], 104 | add_dependencies_codelines: bool = True, 105 | ) -> Tuple[Dict["Behavior", List[str]], Set["Behavior"], Dict[str, str]]: 106 | """Parse dependencies of the given HEBGraph behavior's. 107 | 108 | Args: 109 | graph (HEBGraph): HEBGraph to parse dependecies from. 110 | behaviors_histogram (Dict[Behavior, Dict[Node, int]]): _description_ 111 | add_dependencies_codelines (bool, optional): _description_. Defaults to True. 
112 | 113 | Returns: 114 | Tuple of three elements: 115 | - dependencies_codelines (Dict[Behavior, List[str]]): Codelines of the GeneratedBehavior for 116 | each of the behavior used in the HEBGraph if they can be computed, else a comment. 117 | - behaviors_incall_codelines (Set[Behavior]): Set of behaviors that should be directly 118 | unrolled in the call function (because the abstraction is not worth it). 119 | - hashmap_codelines (List[str]): Codelines of the map between behavior names 120 | and coresponding GeneratedBehavior. 121 | """ 122 | dependencies_codelines: Dict["Behavior", List[str]] = {} 123 | behaviors_incall_codelines: Set["Behavior"] = set() 124 | dependencies_hashmap: Dict[str, str] = {} 125 | for behavior, n_used in sorted( 126 | behaviors_histogram.items(), key=lambda x: x[1], reverse=True 127 | ): 128 | if not isinstance(behavior, Behavior): 129 | continue 130 | if behavior.name in graph.all_behaviors: 131 | behavior = graph.all_behaviors[behavior.name] 132 | 133 | try: 134 | sub_graph = behavior.graph 135 | except NotImplementedError: 136 | # If subgraph cannot be computed, we simply have a ref 137 | if add_dependencies_codelines: 138 | dependencies_codelines[behavior] = [ 139 | f"# Require '{behavior.name}' behavior to be given." 
140 | ] 141 | continue 142 | 143 | if n_used == 1 or len(sub_graph.nodes()) == 1: 144 | dependencies_codelines[behavior] = [] 145 | behaviors_incall_codelines.add(behavior) 146 | continue 147 | 148 | if add_dependencies_codelines: 149 | dependencies_codelines[behavior] = get_behavior_class_codelines( 150 | sub_graph, behaviors_histogram, add_dependencies=False 151 | ) 152 | dependencies_hashmap[behavior.name] = _to_camel_case( 153 | behavior.name.capitalize() 154 | ) 155 | 156 | hashmap_codelines = [] 157 | if dependencies_hashmap: 158 | hashmap_codelines = ["BEHAVIOR_TO_NAME = {"] 159 | hashmap_codelines += [ 160 | f" '{name}': {class_name}," 161 | for name, class_name in dependencies_hashmap.items() 162 | ] 163 | hashmap_codelines += ["}"] 164 | 165 | return dependencies_codelines, behaviors_incall_codelines, hashmap_codelines 166 | 167 | 168 | def get_behavior_call_codelines( 169 | graph: "HEBGraph", 170 | behaviors_incall_codelines: Set["Behavior"], 171 | indent: int = 1, 172 | with_overhead=True, 173 | ) -> List[str]: 174 | """Generate the codelines of a GeneratedBehavior call function. 175 | 176 | Args: 177 | graph (HEBGraph): HEBGraph from which to generate the behavior of. 178 | behaviors_incall_codelines (Set[Behavior]): Set of behavior to unroll directly 179 | instead of refering to. 180 | indent (int, optional): Indentation level. Defaults to 1. 181 | with_overhead (bool, optional): If True, adds the call function definition. 182 | Defaults to True. 183 | 184 | Returns: 185 | List[str]: Codelines of the GeneratedBehavior call function. 
186 | """ 187 | call_codelines = [] 188 | if with_overhead: 189 | call_codelines.append(indent_str(indent) + "def __call__(self, observation):") 190 | indent += 1 191 | roots = get_roots(graph) 192 | return call_codelines + get_node_call_codelines( 193 | graph, 194 | roots[0], 195 | indent, 196 | behaviors_incall_codelines=behaviors_incall_codelines, 197 | ) 198 | 199 | 200 | def get_node_call_codelines( 201 | graph: "HEBGraph", 202 | node: Node, 203 | indent: int, 204 | behaviors_incall_codelines: Set["Behavior"], 205 | ) -> List[str]: 206 | """Generate codelines for an HEBGraph node recursively using the succesors. 207 | 208 | Args: 209 | graph (HEBGraph): HEBGraph containing the node. 210 | node (Node): Node to generate the call of. 211 | indent (int): Indentation level. 212 | behaviors_incall_codelines (Set[Behavior]): Set of behavior to unroll directly 213 | instead of refering to. 214 | Raises: 215 | NotImplementedError: Node is not an Action, FeatureCondition or Behavior. 216 | 217 | Returns: 218 | List[str]: Codelines of the node call. 
219 | """ 220 | node_codelines = [] 221 | if isinstance(node, Action): 222 | action_name = node.name.split(BEHAVIOR_SEPARATOR)[-1] 223 | node_codelines.append( 224 | indent_str(indent) + f"return self.actions['{action_name}'](observation)" 225 | ) 226 | return node_codelines 227 | if isinstance(node, FeatureCondition): 228 | var_name = f"edge_index_{indent-2}" if indent > 2 else "edge_index" 229 | fc_name = node.name.split(BEHAVIOR_SEPARATOR)[-1] 230 | node_codelines.append( 231 | indent_str(indent) 232 | + f"{var_name} = self.feature_conditions['{fc_name}'](observation)" 233 | ) 234 | for i in [0, 1]: 235 | node_codelines.append(indent_str(indent) + f"if {var_name} == {i}:") 236 | successors = get_successors_with_index(graph, node, i) 237 | for succ_node in successors: 238 | node_codelines += get_node_call_codelines( 239 | graph, 240 | node=succ_node, 241 | indent=indent + 1, 242 | behaviors_incall_codelines=behaviors_incall_codelines, 243 | ) 244 | return node_codelines 245 | if isinstance(node, Behavior): 246 | if node in behaviors_incall_codelines: 247 | if node.name in graph.all_behaviors: 248 | node = graph.all_behaviors[node.name] 249 | return get_behavior_call_codelines( 250 | node.graph, 251 | behaviors_incall_codelines=behaviors_incall_codelines, 252 | indent=indent, 253 | with_overhead=False, 254 | ) 255 | default_line = f"return self.known_behaviors['{node.name}'](observation)" 256 | return [indent_str(indent) + default_line] 257 | raise NotImplementedError 258 | 259 | 260 | def indent_str(indent_level: int, indent_amount: int = 4) -> str: 261 | """Gives a string indentation from a given indent level. 262 | 263 | Args: 264 | indent_level (int): Level of indentation. 265 | indent_amount (int, optional): Number of spaces per indent. Defaults to 4. 266 | 267 | Returns: 268 | str: Indentation string. 
269 | """ 270 | return " " * indent_level * indent_amount 271 | 272 | 273 | def _to_camel_case(text: str) -> str: 274 | s = ( 275 | text.replace("-", " ") 276 | .replace("_", " ") 277 | .replace("?", "") 278 | .replace("[", "") 279 | .replace("]", "") 280 | .replace("(", "") 281 | .replace(")", "") 282 | .replace(",", "") 283 | ) 284 | s = s.split() 285 | if len(text) == 0: 286 | return text 287 | return s[0] + "".join(i.capitalize() for i in s[1:]) 288 | 289 | 290 | def _to_snake_case(text: str) -> str: 291 | text = text.replace("-", " ").replace("?", "") 292 | return "_".join( 293 | sub("([A-Z][a-z]+)", r" \1", sub("([A-Z]+)", r" \1", text)).split() 294 | ).lower() 295 | -------------------------------------------------------------------------------- /src/hebg/metrics/histograms.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """HEBGraph used nodes histograms computation.""" 5 | 6 | from typing import TYPE_CHECKING, Dict, List, Tuple 7 | from warnings import warn 8 | 9 | import numpy as np 10 | 11 | from hebg.behavior import Behavior 12 | from hebg.metrics.complexity.utils import update_sum_dict 13 | from hebg.node import Action, FeatureCondition 14 | 15 | if TYPE_CHECKING: 16 | from hebg import HEBGraph, Node 17 | 18 | 19 | def behaviors_histograms( 20 | behaviors: List["Behavior"], 21 | default_node_complexity: float = 1.0, 22 | ) -> Dict["Behavior", Dict["Node", int]]: 23 | """Compute the used nodes histograms for a list of Behavior. 24 | 25 | Args: 26 | behaviors: List of Behavior to compute histograms of. 27 | default_node_complexity: Default node complexity if Node has no attribute complexity. 28 | 29 | Return: 30 | Dictionary of dictionaries of the number of use for each used node, for each behavior. 
31 | """ 32 | return behaviors_histograms_and_complexites(behaviors, default_node_complexity)[0] 33 | 34 | 35 | def behaviors_histograms_and_complexites( 36 | behaviors: List["Behavior"], 37 | default_node_complexity: float = 1.0, 38 | ) -> Tuple[Dict["Behavior", Dict["Node", int]], Dict["Behavior", float]]: 39 | """Compute the used nodes histograms for a list of Behavior. 40 | 41 | Args: 42 | behaviors: List of Behavior to compute histograms of. 43 | default_node_complexity: Default node complexity if Node has no attribute complexity. 44 | 45 | Return: 46 | Tuple of two elements: 47 | - Dictionary of dictionaries of the number of use for each used node, for each behavior. 48 | - Dictionary computed complexity for each behavior. 49 | """ 50 | histograms: Dict["Behavior", Dict["Node", int]] = {} 51 | complexities: Dict["Behavior", float] = {} 52 | for behavior in behaviors: 53 | try: 54 | graph = behavior.graph 55 | except NotImplementedError: 56 | warn( 57 | f"Could not load graph for behavior: {behavior}." 58 | "Skipping histogram computation." 59 | ) 60 | continue 61 | histogram, complexity = hebgraph_histogram_and_complexity( 62 | graph, 63 | default_node_complexity=default_node_complexity, 64 | ) 65 | histograms[behavior] = histogram 66 | complexities[behavior] = complexity 67 | return histograms, complexities 68 | 69 | 70 | def hebgraph_histogram_and_complexity( 71 | graph: "HEBGraph", default_node_complexity: float = 1.0 72 | ) -> Tuple[Dict["Node", int], float]: 73 | """Compute the used nodes histogram for a Behavior. 74 | 75 | Args: 76 | behavior: Behavior to compute histogram of. 77 | default_node_complexity: Default node complexity if Node has no attribute complexity. 78 | 79 | Return: 80 | Tuple composed of two element: 81 | - Dictionary of the number of use for each used node in the graph. 82 | - Computed total complexity of the graph. 
83 | 84 | """ 85 | nodes_histograms, nodes_complexities = nodes_histograms_and_complexities( 86 | graph, default_node_complexity 87 | ) 88 | root = graph.graph["nodes_by_level"][0][0] # Assumes a single root 89 | return nodes_histograms[root], nodes_complexities[root] 90 | 91 | 92 | def cumulated_hebgraph_histogram( 93 | graph: "HEBGraph", default_node_complexity: float = 1.0 94 | ) -> Dict["Node", int]: 95 | """Unroll the hebgraph histogram by accumulating sub-behaviors histograms. 96 | 97 | Args: 98 | graph (HEBGraph): The HEBgraph to compute the cumulated histogram of. 99 | default_node_complexity (float, optional): Default node complexity. Defaults to 1.0. 100 | 101 | Returns: 102 | Dict[Node, int]: Cumulated histogram of used nodes with unrolled behaviors. 103 | """ 104 | histogram, _ = hebgraph_histogram_and_complexity(graph, default_node_complexity) 105 | histograms = {graph.behavior: histogram} 106 | sub_behaviors = [ 107 | node 108 | for node in histogram 109 | if isinstance(node, Behavior) and node != graph.behavior 110 | ] 111 | for i, behavior in enumerate(sub_behaviors): 112 | if behavior.name in graph.all_behaviors: 113 | sub_behaviors[i] = graph.all_behaviors[behavior.name] 114 | 115 | histograms.update(behaviors_histograms(sub_behaviors, default_node_complexity)) 116 | 117 | done = False 118 | behavior_iteration = {} # Stoping condition 119 | while not done: 120 | behaviors = [node for node in histogram if isinstance(node, Behavior)] 121 | behavior_iteration, histogram, histograms = _iterate_cumulation( 122 | graph=graph, 123 | behaviors=behaviors, 124 | histogram=histogram, 125 | histograms=histograms, 126 | behavior_iteration=behavior_iteration, 127 | default_node_complexity=default_node_complexity, 128 | ) 129 | 130 | # Recompute behaviors because some might have been added 131 | behaviors = [node for node in histogram if isinstance(node, Behavior)] 132 | 133 | # We stop when all behaviors have been iterated on enough. 
134 | if all(behavior in behavior_iteration for behavior in behaviors): 135 | done = all( 136 | behavior_iteration[behavior] == histogram[behavior] 137 | for behavior in behaviors 138 | ) 139 | 140 | return histogram 141 | 142 | 143 | def _iterate_cumulation( 144 | graph: "HEBGraph", 145 | behaviors: List["Behavior"], 146 | histogram: Dict["Node", int], 147 | histograms: Dict["Behavior", Dict["Node", int]], 148 | behavior_iteration: Dict["Behavior", int], 149 | default_node_complexity: float = 1.0, 150 | ): 151 | """Iterate once on every behavior accumulating histograms. 152 | 153 | Args: 154 | graph (HEBGraph): Current graph being iterated on. 155 | behaviors (List[Behavior]): List of behaviors to iterate. 156 | histogram (Dict[Node, int]): Histogram being accumulated. 157 | histograms (Dict[Behavior, Dict[Node, int]]): Histograms of knowed behaviors. 158 | behavior_iteration (Dict[Behavior, int]): Number of iterations already done 159 | for each behavior. 160 | 161 | Returns: 162 | Tuple of three updated dictionaries. (behavior_iteration, histogram, histograms). 
163 | """ 164 | for behavior in behaviors: 165 | already_iterated = ( 166 | behavior_iteration[behavior] if behavior in behavior_iteration else 0 167 | ) 168 | n_used = max(0, histogram[behavior] - already_iterated) 169 | if n_used == 0: 170 | continue 171 | if behavior not in histograms: 172 | if behavior.name in graph.all_behaviors: 173 | behavior = graph.all_behaviors[behavior.name] 174 | try: 175 | behavior_graph = behavior.graph 176 | except NotImplementedError: 177 | if behavior not in behavior_iteration: 178 | behavior_iteration[behavior] = 0 179 | behavior_iteration[behavior] += 1 180 | continue 181 | sub_histogram, _ = hebgraph_histogram_and_complexity( 182 | behavior_graph, default_node_complexity 183 | ) 184 | histograms[behavior] = sub_histogram 185 | for _ in range(n_used): 186 | if behavior in histograms[behavior]: 187 | histograms[behavior].pop(behavior) 188 | histogram = update_sum_dict(histogram, histograms[behavior]) 189 | if behavior not in behavior_iteration: 190 | behavior_iteration[behavior] = 0 191 | behavior_iteration[behavior] += 1 192 | return behavior_iteration, histogram, histograms 193 | 194 | 195 | def nodes_histograms_and_complexities( 196 | graph: "HEBGraph", 197 | default_node_complexity: float = 1.0, 198 | _behaviors_in_search=None, 199 | ): 200 | """Compute the number of times each node of the graph is present 201 | while computing complexities to find the least complex path. 202 | 203 | Args: 204 | graph (HEBGraph): HEBGraph to compute the histogram of. 205 | default_node_complexity (float, optional): Default node complexity if undefined. 206 | Defaults to 1.0. 207 | _behaviors_in_search (List[Behavior], optional): List of behaviors already in reccursive 208 | pile to avoid infinite cycle. Defaults to None. 209 | 210 | Returns: 211 | Tuple[Dict[Node, Dict[Node, int]], Dict[Node, float]]: Tuple of two values: 212 | - Dictionary of subnode used for each node. 
(nodes_used_nodes) 213 | - Dictionary of computed smallest path complexities for each node. (complexities) 214 | """ 215 | nodes_by_level = graph.graph["nodes_by_level"] 216 | depth = graph.graph["depth"] 217 | 218 | _behaviors_in_search = [] if _behaviors_in_search is None else _behaviors_in_search 219 | _behaviors_in_search.append(str(graph.behavior)) 220 | 221 | complexities = {} 222 | nodes_used_nodes = {} 223 | for node in graph.nodes: 224 | if isinstance(node, (Action, FeatureCondition)): 225 | complexities[node] = node.complexity 226 | nodes_used_nodes[node] = {node: 1} 227 | 228 | for level in range(depth + 1)[::-1]: 229 | for node in nodes_by_level[level]: 230 | node_complexity = 0 231 | node_used_nodes = {} 232 | 233 | # Best successors accumulated histograms and complexity 234 | succ_by_index, complexities_by_index = _successors_by_index( 235 | graph, node, complexities 236 | ) 237 | for index, values in complexities_by_index.items(): 238 | min_index = np.argmin(values) 239 | choosen_succ = succ_by_index[index][min_index] 240 | node_used_nodes = update_sum_dict( 241 | node_used_nodes, nodes_used_nodes[choosen_succ] 242 | ) 243 | node_complexity += values[min_index] 244 | 245 | # Node only histogram and complexity 246 | ( 247 | node_only_used_behaviors, 248 | node_only_complexity, 249 | ) = _get_node_histogram_complexity( 250 | node, 251 | default_node_complexity=default_node_complexity, 252 | behaviors_in_search=_behaviors_in_search, 253 | ) 254 | node_used_nodes = update_sum_dict(node_used_nodes, node_only_used_behaviors) 255 | node_complexity += node_only_complexity 256 | 257 | complexities[node] = node_complexity 258 | nodes_used_nodes[node] = node_used_nodes 259 | return nodes_used_nodes, complexities 260 | 261 | 262 | def _successors_by_index( 263 | graph: "HEBGraph", node: "Node", complexities: Dict["Node", float] 264 | ) -> Tuple[Dict[int, List["Node"]], Dict[int, List[float]]]: 265 | """Group successors and their complexities by index. 
266 | 267 | Args: 268 | graph: The HEBGraph to use. 269 | node: The Node from which we want to group successors. 270 | complexities: Dictionary of complexities for each potential successor node. 271 | 272 | Return: 273 | Tuple composed of a dictionary of successors for each index 274 | and a dictionary of complexities for each index. 275 | 276 | """ 277 | complexities_by_index = {} 278 | succ_by_index = {} 279 | for succ in graph.successors(node): 280 | succ_complexity = complexities[succ] 281 | index = int(graph.edges[node, succ]["index"]) 282 | if index not in complexities_by_index: 283 | complexities_by_index[index] = [] 284 | if index not in succ_by_index: 285 | succ_by_index[index] = [] 286 | complexities_by_index[index].append(succ_complexity) 287 | succ_by_index[index].append(succ) 288 | return succ_by_index, complexities_by_index 289 | 290 | 291 | def _get_node_histogram_complexity( 292 | node: "Node", behaviors_in_search=None, default_node_complexity: float = 1.0 293 | ) -> Tuple[Dict["Node", int], float]: 294 | """Compute the used nodes histogram and complexity of a single node. 295 | 296 | Args: 297 | node: The Node from which we want to compute the complexity. 298 | behaviors_in_search: List of Behavior already in search to avoid circular search. 299 | default_node_complexity: Default node complexity if Node has no attribute complexity. 300 | 301 | Return: 302 | Tuple composed of a dictionary of the number of use for each used Node by the given node 303 | and the given node complexity. 
304 | 305 | """ 306 | 307 | if node.type == "behavior": 308 | if behaviors_in_search is not None and str(node) in behaviors_in_search: 309 | return {}, np.inf 310 | if node.type in ("action", "feature_condition", "behavior"): 311 | if node.complexity is not None: 312 | node_complexity = node.complexity 313 | else: 314 | node_complexity = default_node_complexity 315 | return {node: 1}, node_complexity 316 | if node.type == "empty": 317 | return {}, 0 318 | raise ValueError(f"Unkowned node type {node.type}") 319 | -------------------------------------------------------------------------------- /tests/test_paper_basic_example.py: -------------------------------------------------------------------------------- 1 | # HEBGraph for explainable hierarchical reinforcement learning 2 | # Copyright (C) 2021-2024 Mathïs FEDERICO 3 | 4 | """Integration tests for the initial paper examples.""" 5 | 6 | from typing import Dict, List 7 | from copy import deepcopy 8 | 9 | # import matplotlib.pyplot as plt 10 | 11 | 12 | import pytest 13 | import pytest_check as check 14 | 15 | from itertools import permutations 16 | from networkx.classes.digraph import DiGraph 17 | from networkx import is_isomorphic 18 | 19 | from hebg import Action, Behavior, FeatureCondition, HEBGraph 20 | from hebg.metrics.histograms import behaviors_histograms, cumulated_hebgraph_histogram 21 | from hebg.metrics.complexity.complexities import learning_complexity 22 | from hebg.requirements_graph import build_requirement_graph 23 | from hebg.unrolling import BEHAVIOR_SEPARATOR, unroll_graph 24 | 25 | from tests.examples.behaviors.report_example import Behavior0, Behavior1, Behavior2 26 | 27 | 28 | class TestPaperBasicExamples: 29 | """Basic examples from the initial paper""" 30 | 31 | @pytest.fixture(autouse=True) 32 | def setup(self): 33 | """Initialize variables.""" 34 | 35 | self.actions: List[Action] = [Action(i, complexity=1) for i in range(3)] 36 | self.feature_conditions: List[FeatureCondition] = [ 37 | 
FeatureCondition(f"feature {i}", complexity=1) for i in range(6) 38 | ] 39 | self.behaviors: List[Behavior] = [Behavior0(), Behavior1(), Behavior2()] 40 | 41 | self.expected_behavior_histograms: Dict[Behavior, Dict[Action, int]] = { 42 | self.behaviors[0]: { 43 | self.actions[0]: 1, 44 | self.actions[1]: 1, 45 | self.feature_conditions[0]: 1, 46 | }, 47 | self.behaviors[1]: { 48 | self.actions[0]: 1, 49 | self.actions[2]: 1, 50 | self.behaviors[0]: 1, 51 | self.feature_conditions[1]: 1, 52 | self.feature_conditions[2]: 1, 53 | }, 54 | self.behaviors[2]: { 55 | self.actions[0]: 1, 56 | self.behaviors[0]: 1, 57 | self.behaviors[1]: 2, 58 | self.feature_conditions[3]: 1, 59 | self.feature_conditions[4]: 1, 60 | self.feature_conditions[5]: 1, 61 | }, 62 | } 63 | 64 | def test_histograms(self): 65 | """should give expected histograms.""" 66 | check.equal( 67 | behaviors_histograms(self.behaviors), self.expected_behavior_histograms 68 | ) 69 | 70 | def test_cumulated_histograms(self): 71 | """should give expected cumulated histograms.""" 72 | expected_cumulated_histograms = { 73 | self.behaviors[0]: { 74 | self.actions[0]: 1, 75 | self.actions[1]: 1, 76 | self.feature_conditions[0]: 1, 77 | }, 78 | self.behaviors[1]: { 79 | self.actions[0]: 2, 80 | self.actions[2]: 1, 81 | self.actions[1]: 1, 82 | self.feature_conditions[0]: 1, 83 | self.feature_conditions[1]: 1, 84 | self.feature_conditions[2]: 1, 85 | self.behaviors[0]: 1, 86 | }, 87 | self.behaviors[2]: { 88 | self.actions[0]: 6, 89 | self.actions[1]: 3, 90 | self.actions[2]: 2, 91 | self.feature_conditions[0]: 3, 92 | self.feature_conditions[1]: 2, 93 | self.feature_conditions[2]: 2, 94 | self.feature_conditions[3]: 1, 95 | self.feature_conditions[4]: 1, 96 | self.feature_conditions[5]: 1, 97 | self.behaviors[0]: 3, 98 | self.behaviors[1]: 2, 99 | }, 100 | } 101 | for behavior in self.behaviors: 102 | check.equal( 103 | cumulated_hebgraph_histogram(behavior.graph), 104 | expected_cumulated_histograms[behavior], 105 | 
) 106 | 107 | def test_learning_complexity(self): 108 | """should give expected learning_complexity.""" 109 | expected_learning_complexities = { 110 | self.behaviors[0]: 3, 111 | self.behaviors[1]: 6, 112 | self.behaviors[2]: 9, 113 | } 114 | expected_saved_complexities = { 115 | self.behaviors[0]: 0, 116 | self.behaviors[1]: 1, 117 | self.behaviors[2]: 12, 118 | } 119 | 120 | for behavior in self.behaviors: 121 | c_learning, saved_complexity = learning_complexity( 122 | behavior, used_nodes_all=self.expected_behavior_histograms 123 | ) 124 | 125 | print( 126 | f"{behavior}: {c_learning}|{expected_learning_complexities[behavior]}" 127 | f" {saved_complexity}|{expected_saved_complexities[behavior]}" 128 | ) 129 | 130 | check.almost_equal(c_learning, expected_learning_complexities[behavior]) 131 | check.almost_equal(saved_complexity, expected_saved_complexities[behavior]) 132 | 133 | def test_codegen(self): 134 | expected_code = "\n".join( 135 | ( 136 | "from hebg.codegen import GeneratedBehavior", 137 | "", 138 | "class Behavior0(GeneratedBehavior):", 139 | " def __call__(self, observation):", 140 | " edge_index = self.feature_conditions['feature 0'](observation)", 141 | " if edge_index == 0:", 142 | " return self.actions['Action(0)'](observation)", 143 | " if edge_index == 1:", 144 | " return self.actions['Action(1)'](observation)", 145 | "class Behavior1(GeneratedBehavior):", 146 | " def __call__(self, observation):", 147 | " edge_index = self.feature_conditions['feature 1'](observation)", 148 | " if edge_index == 0:", 149 | " return self.known_behaviors['behavior 0'](observation)", 150 | " if edge_index == 1:", 151 | " edge_index_1 = self.feature_conditions['feature 2'](observation)", 152 | " if edge_index_1 == 0:", 153 | " return self.actions['Action(0)'](observation)", 154 | " if edge_index_1 == 1:", 155 | " return self.actions['Action(2)'](observation)", 156 | "class Behavior2(GeneratedBehavior):", 157 | " def __call__(self, observation):", 158 | " edge_index 
= self.feature_conditions['feature 3'](observation)", 159 | " if edge_index == 0:", 160 | " edge_index_1 = self.feature_conditions['feature 4'](observation)", 161 | " if edge_index_1 == 0:", 162 | " return self.actions['Action(0)'](observation)", 163 | " if edge_index_1 == 1:", 164 | " return self.known_behaviors['behavior 1'](observation)", 165 | " if edge_index == 1:", 166 | " edge_index_1 = self.feature_conditions['feature 5'](observation)", 167 | " if edge_index_1 == 0:", 168 | " return self.known_behaviors['behavior 1'](observation)", 169 | " if edge_index_1 == 1:", 170 | " return self.known_behaviors['behavior 0'](observation)", 171 | "BEHAVIOR_TO_NAME = {", 172 | " 'behavior 0': Behavior0,", 173 | " 'behavior 1': Behavior1,", 174 | "}", 175 | ) 176 | ) 177 | generated_code = self.behaviors[2].graph.generate_source_code() 178 | check.equal(generated_code, expected_code) 179 | 180 | def test_requirement_graph_edges(self): 181 | """should give expected requirement_graph edges.""" 182 | expected_requirement_graph = DiGraph() 183 | for behavior in self.behaviors: 184 | expected_requirement_graph.add_node(behavior) 185 | expected_requirement_graph.add_edge(self.behaviors[0], self.behaviors[1]) 186 | expected_requirement_graph.add_edge(self.behaviors[0], self.behaviors[2]) 187 | expected_requirement_graph.add_edge(self.behaviors[1], self.behaviors[2]) 188 | 189 | requirements_graph = build_requirement_graph(self.behaviors) 190 | for behavior, other_behavior in permutations(self.behaviors, 2): 191 | print(behavior, other_behavior) 192 | req_has_edge = requirements_graph.has_edge(behavior, other_behavior) 193 | expected_req_has_edge = expected_requirement_graph.has_edge( 194 | behavior, other_behavior 195 | ) 196 | check.equal(req_has_edge, expected_req_has_edge) 197 | 198 | def test_requirement_graph_levels(self): 199 | """should give expected requirement_graph node levels (requirement depth).""" 200 | expected_levels = { 201 | self.behaviors[0]: 0, 202 | 
self.behaviors[1]: 1, 203 | self.behaviors[2]: 2, 204 | } 205 | requirements_graph = build_requirement_graph(self.behaviors) 206 | for behavior, level in requirements_graph.nodes(data="level"): 207 | check.equal(level, expected_levels[behavior]) 208 | 209 | def test_unrolled_behaviors_graphs(self): 210 | """should give expected unrolled_behaviors_graphs for each example behaviors.""" 211 | 212 | def lname(*args): 213 | return BEHAVIOR_SEPARATOR.join([str(arg) for arg in args]) 214 | 215 | expected_graph_0 = deepcopy(self.behaviors[0].graph) 216 | 217 | expected_graph_1 = HEBGraph(self.behaviors[1]) 218 | feature_0 = FeatureCondition(lname(self.behaviors[0], "feature 0")) 219 | expected_graph_1.add_edge( 220 | feature_0, Action(0, lname(self.behaviors[0], "Action(0)")), index=False 221 | ) 222 | expected_graph_1.add_edge( 223 | feature_0, Action(1, lname(self.behaviors[0], "Action(1)")), index=True 224 | ) 225 | feature_1 = FeatureCondition("feature 1") 226 | feature_2 = FeatureCondition("feature 2") 227 | expected_graph_1.add_edge(feature_1, feature_0, index=False) 228 | expected_graph_1.add_edge(feature_1, feature_2, index=True) 229 | expected_graph_1.add_edge(feature_2, Action(0), index=False) 230 | expected_graph_1.add_edge(feature_2, Action(2), index=True) 231 | 232 | expected_graph_2 = HEBGraph(self.behaviors[2]) 233 | feature_3 = FeatureCondition("feature 3") 234 | feature_4 = FeatureCondition("feature 4") 235 | feature_5 = FeatureCondition("feature 5") 236 | expected_graph_2.add_edge(feature_3, feature_4, index=False) 237 | expected_graph_2.add_edge(feature_3, feature_5, index=True) 238 | expected_graph_2.add_edge(feature_4, Action(0), index=False) 239 | 240 | feature_0 = FeatureCondition( 241 | lname(self.behaviors[1], self.behaviors[0], "feature 0") 242 | ) 243 | expected_graph_2.add_edge( 244 | feature_0, 245 | Action(0, lname(self.behaviors[1], self.behaviors[0], "Action(0)")), 246 | index=False, 247 | ) 248 | expected_graph_2.add_edge( 249 | feature_0, 
250 | Action(1, lname(self.behaviors[1], self.behaviors[0], "Action(1)")), 251 | index=True, 252 | ) 253 | feature_1 = FeatureCondition(lname(self.behaviors[1], "feature 1")) 254 | feature_2 = FeatureCondition(lname(self.behaviors[1], "feature 2")) 255 | expected_graph_2.add_edge(feature_1, feature_0, index=False) 256 | expected_graph_2.add_edge(feature_1, feature_2, index=True) 257 | expected_graph_2.add_edge( 258 | feature_2, Action(0, lname(self.behaviors[1], "Action(0)")), index=False 259 | ) 260 | expected_graph_2.add_edge( 261 | feature_2, Action(2, lname(self.behaviors[1], "Action(2)")), index=True 262 | ) 263 | 264 | expected_graph_2.add_edge(feature_4, feature_1, index=True) 265 | 266 | feature_0_0 = FeatureCondition(lname(self.behaviors[0], "feature 0")) 267 | expected_graph_2.add_edge( 268 | feature_0_0, 269 | Action(0, lname(self.behaviors[0], "Action(0)")), 270 | index=False, 271 | ) 272 | expected_graph_2.add_edge( 273 | feature_0_0, 274 | Action(1, lname(self.behaviors[0], "Action(1)")), 275 | index=True, 276 | ) 277 | 278 | expected_graph_2.add_edge(feature_5, feature_1, index=False) 279 | expected_graph_2.add_edge(feature_5, feature_0_0, index=True) 280 | 281 | expected_graph = { 282 | self.behaviors[0]: expected_graph_0, 283 | self.behaviors[1]: expected_graph_1, 284 | self.behaviors[2]: expected_graph_2, 285 | } 286 | for behavior in self.behaviors: 287 | unrolled_graph = unroll_graph(behavior.graph, add_prefix=True) 288 | check.is_true(is_isomorphic(unrolled_graph, expected_graph[behavior])) 289 | 290 | # fig, axes = plt.subplots(1, 2) 291 | # unrolled_graph = behavior.graph.unrolled_graph 292 | # unrolled_graph.draw(axes[0], draw_behaviors_hulls=True) 293 | # expected_graph[behavior].draw(axes[1], draw_behaviors_hulls=True) 294 | # plt.show() 295 | -------------------------------------------------------------------------------- /tests/test_code_generation.py: -------------------------------------------------------------------------------- 1 | 
"""Tests for the Python source code generated from HEBGraphs (``generate_source_code``)."""

import difflib
from typing import Any, Dict, Optional, Tuple

import pytest
import pytest_check as check

from hebg import Action, FeatureCondition, Behavior

from tests.examples.behaviors import (
    FundamentalBehavior,
    F_A_Behavior,
    F_F_A_Behavior,
    build_binary_sum_behavior,
)
from tests.examples.feature_conditions import ThresholdFeatureCondition


class TestABehavior:
    """(A) Fundamental behaviors (single Action node) should return action call."""

    @pytest.fixture(autouse=True)
    def setup(self):
        self.behavior = FundamentalBehavior(Action(42))

    def test_source_codegen(self):
        source_code = self.behavior.graph.generate_source_code()
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class Action42Behavior(GeneratedBehavior):",
                "    def __call__(self, observation):",
                "        return self.actions['Action(42)'](observation)",
            )
        )
        check.equal(
            source_code,
            expected_source_code,
            msg=_unidiff_output(expected_source_code, source_code),
        )

    def test_exec_codegen(self):
        check_execution_for_values(self.behavior, "Action42Behavior", (1, -1))


class TestFABehavior:
    """(F-A) Feature condition should generate if/else condition."""

    @pytest.fixture(autouse=True)
    def setup(self):
        feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0)
        actions = {0: Action(0), 1: Action(1)}
        self.behavior = F_A_Behavior("Is above_zero", feature_condition, actions)

    def test_source_codegen(self):
        source_code = self.behavior.graph.generate_source_code()
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class IsAboveZero(GeneratedBehavior):",
                "    def __call__(self, observation):",
                "        edge_index = self.feature_conditions['Greater or equal to 0 ?'](observation)",
                "        if edge_index == 0:",
                "            return self.actions['Action(0)'](observation)",
                "        if edge_index == 1:",
                "            return self.actions['Action(1)'](observation)",
            )
        )

        check.equal(
            source_code,
            expected_source_code,
            msg=_unidiff_output(expected_source_code, source_code),
        )

    def test_exec_codegen(self):
        check_execution_for_values(self.behavior, "IsAboveZero", (1, -1))


class TestFFABehavior:
    """(F-F-A) Chained FeatureConditions should generate nested if/else."""

    @pytest.fixture(autouse=True)
    def setup(self):
        self.behavior = F_F_A_Behavior("scalar classification ]-1,0,1[ ?")

    def test_source_codegen(self):
        source_code = self.behavior.graph.generate_source_code()
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class ScalarClassification101(GeneratedBehavior):",
                "    def __call__(self, observation):",
                "        edge_index = self.feature_conditions['Greater or equal to 0 ?'](observation)",
                "        if edge_index == 0:",
                "            edge_index_1 = self.feature_conditions['Greater or equal to -1 ?'](observation)",
                "            if edge_index_1 == 0:",
                "                return self.actions['Action(0)'](observation)",
                "            if edge_index_1 == 1:",
                "                return self.actions['Action(1)'](observation)",
                "        if edge_index == 1:",
                "            edge_index_1 = self.feature_conditions['Lesser or equal to 1 ?'](observation)",
                "            if edge_index_1 == 0:",
                "                return self.actions['Action(3)'](observation)",
                "            if edge_index_1 == 1:",
                "                return self.actions['Action(2)'](observation)",
            )
        )

        check.equal(
            source_code,
            expected_source_code,
            msg=_unidiff_output(expected_source_code, source_code),
        )

    def test_exec_codegen(self):
        check_execution_for_values(
            self.behavior, "ScalarClassification101", (2, 1, -1, -2)
        )


class TestFBBehavior:
    """(F-BA) Behaviors should be unrolled by default if they appear only once and have a graph."""

    @pytest.fixture(autouse=True)
    def setup(self):
        feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0)
        actions = {0: Action(0), 1: Action(1)}
        sub_behavior = F_A_Behavior("Is above_zero", feature_condition, actions)

        feature_condition = ThresholdFeatureCondition(relation="<=", threshold=1)
        actions = {0: Action(0), 1: sub_behavior}
        self.behavior = F_A_Behavior("Is between 0 and 1 ?", feature_condition, actions)

    def test_source_codegen(self):
        source_code = self.behavior.graph.generate_source_code()
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class IsBetween0And1(GeneratedBehavior):",
                "    def __call__(self, observation):",
                "        edge_index = self.feature_conditions['Lesser or equal to 1 ?'](observation)",
                "        if edge_index == 0:",
                "            return self.actions['Action(0)'](observation)",
                "        if edge_index == 1:",
                "            edge_index_1 = self.feature_conditions['Greater or equal to 0 ?'](observation)",
                "            if edge_index_1 == 0:",
                "                return self.actions['Action(0)'](observation)",
                "            if edge_index_1 == 1:",
                "                return self.actions['Action(1)'](observation)",
            )
        )

        check.equal(
            source_code,
            expected_source_code,
            msg=_unidiff_output(expected_source_code, source_code),
        )

    def test_exec_codegen(self):
        check_execution_for_values(self.behavior, "IsBetween0And1", (-1, 0, 1, 2))


class TestFBBehaviorNameRef:
    """(F-BA) Behaviors should work with only a name reference to a behavior,
    but will expect that behavior to be given, even when unrolled."""

    @pytest.fixture(autouse=True)
    def setup(self):
        feature_condition = ThresholdFeatureCondition(relation="<=", threshold=1)
        actions = {0: Action(0), 1: Behavior("Is above_zero")}
        self.behavior = F_A_Behavior("Is between 0 and 1 ?", feature_condition, actions)

    @pytest.mark.filterwarnings("ignore:Could not load graph for behavior")
    def test_source_codegen_by_ref(self):
        # Same call order as before: the graph first, then its unrolled version.
        source_code = self.behavior.graph.generate_source_code()
        unrolled_source_code = (
            self.behavior.graph.unrolled_graph.generate_source_code()
        )
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "# Require 'Is above_zero' behavior to be given.",
                "class IsBetween0And1(GeneratedBehavior):",
                "    def __call__(self, observation):",
                "        edge_index = self.feature_conditions['Lesser or equal to 1 ?'](observation)",
                "        if edge_index == 0:",
                "            return self.actions['Action(0)'](observation)",
                "        if edge_index == 1:",
                "            return self.known_behaviors['Is above_zero'](observation)",
            )
        )

        check.equal(
            source_code,
            expected_source_code,
            msg=_unidiff_output(expected_source_code, source_code),
        )
        check.equal(
            unrolled_source_code,
            expected_source_code,
            msg=_unidiff_output(expected_source_code, unrolled_source_code),
        )

    def test_source_codegen_in_all_behavior(self):
        """When the behavior is found in 'all_behaviors'
        it should use the found behavior for codegen."""
        feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0)
        actions = {0: Action(0), 1: Action(1)}
        sub_behavior = F_A_Behavior("Is above_zero", feature_condition, actions)
        self.behavior.graph.all_behaviors["Is above_zero"] = sub_behavior
        source_code = self.behavior.graph.generate_source_code()
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class IsBetween0And1(GeneratedBehavior):",
                "    def __call__(self, observation):",
                "        edge_index = self.feature_conditions['Lesser or equal to 1 ?'](observation)",
                "        if edge_index == 0:",
                "            return self.actions['Action(0)'](observation)",
                "        if edge_index == 1:",
                "            edge_index_1 = self.feature_conditions['Greater or equal to 0 ?'](observation)",
                "            if edge_index_1 == 0:",
                "                return self.actions['Action(0)'](observation)",
                "            if edge_index_1 == 1:",
                "                return self.actions['Action(1)'](observation)",
            )
        )

        check.equal(
            source_code,
            expected_source_code,
            msg=_unidiff_output(expected_source_code, source_code),
        )

    def test_exec_codegen(self):
        """When the behavior is found in 'all_behaviors'
        it should use the found behavior for graph call."""
        feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0)
        actions = {0: Action(0), 1: Action(1)}
        sub_behavior = F_A_Behavior("Is above_zero", feature_condition, actions)
        self.behavior.graph.all_behaviors["Is above_zero"] = sub_behavior
        check_execution_for_values(
            self.behavior,
            "IsBetween0And1",
            (-1, 0, 1, 2),
            known_behaviors={"Is above_zero": sub_behavior},
        )


class TestNestedBehaviorReuse:
    """Behaviors should be rolled when they are used multiple times in nested subgraphs."""

    @pytest.fixture(autouse=True)
    def setup(self):
        feature_condition = FeatureCondition(name="fc1")
        actions = {0: Action(0), 1: Action(1)}
        behavior_0 = F_A_Behavior("behavior 0", feature_condition, actions)

        feature_condition = FeatureCondition(name="fc2")
        actions = {0: Action(0), 1: behavior_0}
        behavior_1 = F_A_Behavior("behavior 1", feature_condition, actions)

        feature_condition = FeatureCondition(name="fc3")
        actions = {0: behavior_0, 1: behavior_1}
        self.behavior = F_A_Behavior("behavior 2", feature_condition, actions)

    def test_nested_reuse_codegen(self):
        source_code = self.behavior.graph.generate_source_code()
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class Behavior0(GeneratedBehavior):",
                "    def __call__(self, observation):",
                "        edge_index = self.feature_conditions['fc1'](observation)",
                "        if edge_index == 0:",
                "            return self.actions['Action(0)'](observation)",
                "        if edge_index == 1:",
                "            return self.actions['Action(1)'](observation)",
                "class Behavior2(GeneratedBehavior):",
                "    def __call__(self, observation):",
                "        edge_index = self.feature_conditions['fc3'](observation)",
                "        if edge_index == 0:",
                "            return self.known_behaviors['behavior 0'](observation)",
                "        if edge_index == 1:",
                "            edge_index_1 = self.feature_conditions['fc2'](observation)",
                "            if edge_index_1 == 0:",
                "                return self.actions['Action(0)'](observation)",
                "            if edge_index_1 == 1:",
                "                return self.known_behaviors['behavior 0'](observation)",
                "BEHAVIOR_TO_NAME = {",
                "    'behavior 0': Behavior0,",
                "}",
            )
        )

        check.equal(
            source_code,
            expected_source_code,
            msg=_unidiff_output(expected_source_code, source_code),
        )


class TestFundamentalBehaviorReuse:
    """Fundamental Behaviors should never be abstracted."""

    @pytest.fixture(autouse=True)
    def setup(self):
        behavior_0 = FundamentalBehavior(Action(1))

        feature_condition = FeatureCondition(name="fc2")
        actions = {0: Action(0), 1: behavior_0}
        behavior_1 = F_A_Behavior("behavior 1", feature_condition, actions)

        feature_condition = FeatureCondition(name="fc3")
        actions = {0: behavior_0, 1: behavior_1}
        self.behavior = F_A_Behavior("behavior 2", feature_condition, actions)

    def test_nested_reuse_codegen(self):
        source_code = self.behavior.graph.generate_source_code()
        expected_source_code = "\n".join(
            (
                "from hebg.codegen import GeneratedBehavior",
                "",
                "class Behavior2(GeneratedBehavior):",
                "    def __call__(self, observation):",
                "        edge_index = self.feature_conditions['fc3'](observation)",
                "        if edge_index == 0:",
                "            return self.actions['Action(1)'](observation)",
                "        if edge_index == 1:",
                "            edge_index_1 = self.feature_conditions['fc2'](observation)",
                "            if edge_index_1 == 0:",
                "                return self.actions['Action(0)'](observation)",
                "            if edge_index_1 == 1:",
                "                return self.actions['Action(1)'](observation)",
            )
        )

        check.equal(
            source_code,
            expected_source_code,
            msg=_unidiff_output(expected_source_code, source_code),
        )


class TestFBBBehavior:
    """(F-B-B) Behaviors should only be added once as a class."""

    @pytest.fixture(autouse=True)
    def setup(self):
        self.behavior = build_binary_sum_behavior()

    def test_classes_in_codegen(self):
        source_code = self.behavior.graph.generate_source_code()
        expected_classes = [
            "IsX1InBinary",  # Only this one is used twice
            "IsSumOfLast3Binary2",
        ]

        for expected_class in expected_classes:
            check.equal(
                source_code.count(f"class {expected_class}"),
                1,
                msg=f"Missing or duplicated class: {expected_class}\n{source_code}",
            )

    def test_exec_codegen(self):
        check_execution_for_values(
            self.behavior, "IsSumOfLast3Binary2", (0, 1, 3, 5, 15)
        )


def check_execution_for_values(
    behavior: Behavior,
    class_name: str,
    values: Tuple[Any, ...],
    known_behaviors: Optional[dict] = None,
) -> None:
    """Check that the code generated for ``behavior`` acts like ``behavior`` itself.

    The generated source is executed and the class ``class_name`` it defines is
    instantiated with every action, feature condition and sub-behavior found by
    walking ``behavior``'s graph (resolving sub-behaviors via ``all_behaviors``).
    Both callables are then compared on every observation in ``values``.

    Args:
        behavior: Behavior whose graph generates the source code.
        class_name: Name of the class defined by the generated source code.
        values: Observations on which both behaviors are called and compared.
        known_behaviors: Extra behaviors, by name, given to the generated class;
            they override behaviors of the same name found in the graph.
    """
    generated_source_code = behavior.graph.generate_source_code()
    # Execute in an explicit namespace: since Python 3.13 (PEP 667), names bound by
    # a bare exec() inside a function are not visible through a later locals().
    namespace: Dict[str, Any] = {}
    exec(generated_source_code, namespace)
    policy_class = namespace[class_name]

    actions, feature_conditions, behaviors = separate_nodes_by_type(behavior)

    # Collect nodes of every (transitively) referenced sub-behavior.
    # `visited` guards against behaviors that reference themselves or loop,
    # which would otherwise keep re-queueing the same names forever.
    to_visit = behaviors.copy()
    visited = {behavior.name}
    while to_visit:
        name, sub_behavior = to_visit.popitem()
        if name in visited:
            continue
        visited.add(name)
        if sub_behavior in behavior.graph.all_behaviors:
            sub_behavior = behavior.graph.all_behaviors[sub_behavior]
        sub_actions, sub_feature_conditions, sub_behaviors = separate_nodes_by_type(
            sub_behavior
        )
        actions.update(sub_actions)
        feature_conditions.update(sub_feature_conditions)
        to_visit.update(sub_behaviors)
        behaviors.update(sub_behaviors)

    if known_behaviors:
        behaviors.update(known_behaviors)

    behavior_rebuilt = policy_class(
        actions=actions,
        feature_conditions=feature_conditions,
        behaviors=behaviors,
    )

    for val in values:
        check.equal(
            behavior(val),
            behavior_rebuilt(val),
            msg=f"Generated {class_name} differs from original for observation {val!r}",
        )


def separate_nodes_by_type(behavior: Behavior) -> Tuple[dict, dict, dict]:
    """Split the nodes of ``behavior``'s graph by node type.

    Returns:
        Three dicts keyed by node name: actions, feature conditions and behaviors.
    """
    actions = {
        node.name: node for node in behavior.graph.nodes if isinstance(node, Action)
    }
    feature_conditions = {
        node.name: node
        for node in behavior.graph.nodes
        if isinstance(node, FeatureCondition)
    }
    behaviors = {
        node.name: node for node in behavior.graph.nodes if isinstance(node, Behavior)
    }
    return actions, feature_conditions, behaviors


def _unidiff_output(expected: str, actual: str) -> str:
    """Return the unified diff from ``expected`` to ``actual`` (multiline strings).

    Used as an assertion message so failures show exactly which generated lines differ.
    """
    diff = difflib.unified_diff(
        expected.split("\n"),
        actual.split("\n"),
        fromfile="expected",
        tofile="actual",
        # Lines carry no trailing newline, so headers must not add one either.
        lineterm="",
    )
    return "\n".join(diff)