├── .gitignore ├── .gitlab-ci.yml ├── .readthedocs.yml ├── Dockerfile ├── LICENSE ├── README.md ├── apprentice ├── __init__.py ├── agents │ ├── Memo.py │ ├── ModularAgent.py │ ├── RHS_LHS_Agent.py │ ├── RLAgent.py │ ├── Stub.py │ ├── WhereWhenHow.py │ ├── WhereWhenHowNoFoa.py │ ├── __init__.py │ ├── base.py │ ├── cre_agents │ │ ├── __init__.py │ │ ├── constraints.py │ │ ├── conv_funcs.py │ │ ├── cre_agent.py │ │ ├── debug_utils.py │ │ ├── dipl_base.py │ │ ├── environment.py │ │ ├── extending.py │ │ ├── feature_factory.py │ │ ├── funcs.py │ │ ├── learning_mechs │ │ │ ├── __init__.py │ │ │ ├── how │ │ │ │ ├── __init__.py │ │ │ │ ├── how.py │ │ │ │ └── nlp │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── analysis │ │ │ │ │ ├── .~lock.NLP_how_turk_data_cleaned.csv# │ │ │ │ │ ├── Analysis.ipynb │ │ │ │ │ ├── NLP_how_turk_data_cleaned.csv │ │ │ │ │ ├── NR_NLP_how_turk_data_cleaned.csv │ │ │ │ │ ├── Sandbox.ipynb │ │ │ │ │ ├── codex_madeup_formulae.py │ │ │ │ │ ├── codex_participant_probs.py │ │ │ │ │ ├── copilot │ │ │ │ │ │ ├── p1.py │ │ │ │ │ │ ├── p10.py │ │ │ │ │ │ ├── p11.py │ │ │ │ │ │ ├── p12.py │ │ │ │ │ │ ├── p13.py │ │ │ │ │ │ ├── p14.py │ │ │ │ │ │ ├── p2.py │ │ │ │ │ │ ├── p3.py │ │ │ │ │ │ ├── p4.py │ │ │ │ │ │ ├── p5.py │ │ │ │ │ │ ├── p6.py │ │ │ │ │ │ ├── p7.py │ │ │ │ │ │ ├── p8.py │ │ │ │ │ │ └── p9.py │ │ │ │ │ ├── garbage1.py │ │ │ │ │ └── output.pstats │ │ │ │ │ ├── examples.py │ │ │ │ │ ├── garbage1.py │ │ │ │ │ ├── nlp_sc_planner.py │ │ │ │ │ ├── parser.py │ │ │ │ │ ├── resources.py │ │ │ │ │ └── tests │ │ │ │ │ └── test_nlp_sc_planner.py │ │ │ ├── process │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── parse.py │ │ │ │ ├── process.py │ │ │ │ └── tests │ │ │ │ │ └── test.py │ │ │ ├── registers.py │ │ │ ├── when.py │ │ │ ├── where.py │ │ │ └── which.py │ │ ├── state.py │ │ ├── test.py │ │ ├── tests │ │ │ ├── test_agent.py │ │ │ ├── test_extending.py │ │ │ └── test_state.py │ │ └── two_mech_agent.py │ ├── diff_base.py │ ├── experta_agent.py │ ├── 
soartech_agent.py │ └── utils.py ├── custom_operators.py ├── explain │ ├── __init__.py │ ├── explanation.py │ ├── inspect_patch.py │ ├── kill_engine.py │ ├── lambda_test.py │ └── util.py ├── learners │ ├── Grammar.py │ ├── HowLearner.py │ ├── HowLearnerOld.py │ ├── IncrementalHeuristic.py │ ├── WhatLearner.py │ ├── WhenLearner.py │ ├── WhereLearner.py │ ├── WhichLearner.py │ ├── __init__.py │ ├── grammar.pickle │ ├── pyibl.py │ ├── utils.py │ └── when_learners │ │ ├── __init__.py │ │ ├── actor_critic.py │ │ ├── actor_critic_learner.py │ │ ├── dqn.py │ │ ├── dqn_learner.py │ │ ├── fractions_hasher.py │ │ ├── q_learner.py │ │ └── replay_memory.py ├── logging.yaml ├── planners │ ├── NumbaPlanner.py │ ├── VectorizedPlanner.py │ ├── __init__.py │ ├── action_planner.py │ ├── base_planner.py │ └── fo_planner.py ├── shared.py └── working_memory │ ├── __init__.py │ ├── adapters │ ├── __init__.py │ └── experta_ │ │ ├── __init__.py │ │ ├── factory.py │ │ └── workingmemory.py │ ├── base.py │ ├── experta_skills.py │ ├── fo_planner_operators.py │ ├── numba_operators.py │ ├── representation │ ├── __init__.py │ └── representation.py │ └── skills_test.py ├── django ├── agent_api │ ├── __init__.py │ ├── settings.py │ ├── urls.py │ └── wsgi.py ├── apprentice_learner │ ├── __init__.py │ ├── admin.py │ ├── apps.py │ ├── migrations │ │ └── __init__.py │ ├── models.py │ ├── templates │ │ └── apprentice_learner │ │ │ └── tester.html │ ├── tests.py │ ├── urls.py │ └── views.py └── manage.py ├── docker-compose.yml ├── docs ├── Makefile ├── agents.rst ├── conf.py ├── doc-requirements.txt ├── images │ └── batch_train_example.png ├── index.rst ├── learners.rst ├── make.bat ├── planners.rst └── working_memory.rst ├── examples ├── test_retract.py ├── tic-tac-toe.py ├── ttt_simple.py ├── ttt_wm.py └── ttt_wm_vs_human.py ├── requirements.txt ├── setup.cfg ├── setup.py ├── setup2.py ├── test-requirements.txt └── tox.ini /.gitignore: 
-------------------------------------------------------------------------------- 1 | builds/ 2 | codeclimate.json 3 | .tox 4 | output.xml 5 | .coverage 6 | .pytest_cache/ 7 | htmlcov/ 8 | *.pyc 9 | *.swp 10 | __pycache__/ 11 | 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | *.o 21 | 22 | # Distribution / packaging 23 | .Python 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | AUTHORS 40 | ChangeLog 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *,cover 61 | .hypothesis/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # IPython Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # dotenv 95 | .env 96 | 97 | # virtualenv 98 | venv/ 99 | ENV/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # Temporary vim files 108 | *.swp 109 | 110 | # PyCharm config 111 | .idea/ 112 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: 
-------------------------------------------------------------------------------- 1 | image: "python:3.7" 2 | 3 | variables: 4 | PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip" 5 | 6 | cache: 7 | paths: 8 | - .cache/pip 9 | 10 | stages: 11 | - lint 12 | - test 13 | - document 14 | - deploy 15 | 16 | flake8: 17 | stage: lint 18 | allow_failure: true 19 | script: 20 | - pip install flake8 flake8-junit-report 21 | - retval=0 22 | - flake8 --output-file flake8.txt apprentice/ || retval=$? 23 | - flake8_junit flake8.txt flake8_junit.xml 24 | - cat flake8.txt 25 | - exit "$retval" 26 | artifacts: 27 | when: always 28 | reports: 29 | junit: flake8_junit.xml 30 | tags: 31 | - base 32 | 33 | coverage: 34 | stage: test 35 | allow_failure: true 36 | script: 37 | - pip install -r requirements.txt --exists-action w 38 | - pip install -r test-requirements.txt 39 | - retval=0 40 | - coverage run --source apprentice -m pytest || retval=$? 41 | - coverage html -d coverage 42 | - coverage report 43 | - exit "$retval" 44 | coverage: '/\d+\%\s*$/' 45 | artifacts: 46 | paths: 47 | - coverage 48 | tags: 49 | - base 50 | 51 | pytest: 52 | stage: test 53 | allow_failure: false 54 | script: 55 | - pip install -r requirements.txt --exists-action w 56 | - pip install -r test-requirements.txt 57 | - python -m pytest 58 | artifacts: 59 | when: always 60 | reports: 61 | junit: output.xml 62 | tags: 63 | - base 64 | 65 | sphinx: 66 | stage: document 67 | dependencies: 68 | - pytest 69 | script: 70 | - pip install -r requirements.txt --exists-action w 71 | - pip install -r docs/doc-requirements.txt 72 | - apt-get update 73 | - apt-get install make 74 | - cd docs 75 | - make html 76 | - mv _build/html/ ../sphinx 77 | artifacts: 78 | paths: 79 | - sphinx 80 | tags: 81 | - base 82 | only: 83 | - master 84 | 85 | publish: 86 | stage: deploy 87 | dependencies: 88 | - sphinx 89 | script: 90 | - 'which ssh-agent || ( apt-get update -y && apt-get install openssh-client -y )' 91 | - eval $(ssh-agent -s) 92 | - 
echo "$SSH_PRIVATE_KEY" | tr -d '\r' | ssh-add - 93 | - mkdir -p ~/.ssh 94 | - chmod 700 ~/.ssh 95 | - git config user.email "chris.maclellan@soartech.com" 96 | - git config user.name "Chris MacLellan (automated triskele)" 97 | - git remote rm public 98 | - git remote add public git@github.com:apprenticelearner/AL_Core.git 99 | - ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts 100 | - git push public HEAD:soartech-dev 101 | tags: 102 | - base 103 | only: 104 | - master 105 | 106 | pages: 107 | stage: deploy 108 | dependencies: 109 | - sphinx 110 | - coverage 111 | script: 112 | - mv sphinx public/ 113 | - mv coverage public/coverage 114 | environment: 115 | name: pages 116 | url: https://hq-git.soartech.com/apprentice/apprentice 117 | artifacts: 118 | paths: 119 | - public 120 | tags: 121 | - base 122 | only: 123 | - master 124 | 125 | dockerize: 126 | stage: deploy 127 | script: 128 | - docker build -t hq-git.soartech.com:4567/apprentice/apprentice . 129 | - docker push hq-git.soartech.com:4567/apprentice/apprentice 130 | tags: 131 | - shell 132 | only: 133 | - master 134 | 135 | 136 | # pypi: 137 | # stage: deploy 138 | # dependencies: 139 | # - pytest 140 | # script: 141 | # - pip install twine 142 | # - python setup.py sdist bdist_wheel 143 | # - twine upload --repository-url https://nexus.soartech.com:8443/nexus/repository/pypi-internal/ dist/* 144 | # tags: 145 | # - base 146 | # only: 147 | # - master 148 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | #version: 0.0.1 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/conf.py 11 | 12 | # Build documentation with MkDocs 13 | #mkdocs: 14 | # configuration: 
mkdocs.yml 15 | 16 | # Optionally build your docs in additional formats such as PDF 17 | formats: 18 | - pdf 19 | 20 | # Optionally set the version of Python and requirements required to build your docs 21 | python: 22 | version: 3.7 23 | install: 24 | - requirements: docs/doc-requirements.txt -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 2 | 3 | RUN apt-get update && apt install -y software-properties-common && apt-get clean && rm -rf /var/lib/apt/lists/* 4 | RUN add-apt-repository ppa:deadsnakes/ppa 5 | 6 | RUN apt-get update && apt-get install -y --no-install-recommends \ 7 | python3.8 \ 8 | python3-pip \ 9 | git \ 10 | nano \ 11 | && \ 12 | apt-get clean && \ 13 | rm -rf /var/lib/apt/lists/* 14 | 15 | RUN python3.8 -m pip install --upgrade pip setuptools wheel 16 | RUN python3.8 -m pip install torch 17 | 18 | RUN mkdir -p /usr/local/apprentice 19 | WORKDIR /usr/local/apprentice 20 | COPY ./ /usr/local/apprentice 21 | 22 | RUN python3.8 -m pip install -r requirements.txt --exists-action=w 23 | RUN python3.8 -m pip install . 
24 | 25 | RUN pytest tests 26 | 27 | #CMD ["/usr/bin/python3.8", "/usr/local/apprentice/django/manage.py", "runserver"] 28 | #CMD ["/bin/sh", "-ec", "sleep 1000"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Christopher J MacLellan, Erik Harpstead, and Daniel Weitekamp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /apprentice/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | import os 4 | 5 | # import so "disable_loggers" can have effect 6 | #from experta import unwatch 7 | 8 | #unwatch() 9 | 10 | import coloredlogs 11 | import yaml 12 | 13 | from . 
class Memo(BaseAgent):
    """
    Memorizes state/action pairs and responds with the highest-reward
    demonstrated action for a given request.

    Made for testing the API.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Maps a frozen (hashable) state to a
        # (skill_label, sai, reward) tuple.
        self.lookup = {}

    def request(self, state: Dict, **kwargs) -> Dict:
        """
        Returns the memorized action for ``state``.

        :param state: a state represented as a dict (parsed from JSON)
        :return: a dict with the keys ``skill_label``, ``selection``,
            ``action`` and ``inputs``, or an empty dict when nothing has
            been memorized for this state.
        """
        resp = self.lookup.get(freeze(state))
        if resp is None:
            return {}

        skill_label, sai, _reward = resp
        return {'skill_label': skill_label,
                'selection': sai.selection,
                'action': sai.action,
                'inputs': sai.inputs}

    def train(self, state: Dict, sai: Sai, reward: float, **kwargs):
        """
        Memorizes, replaces or forgets the action stored for ``state``.

        :param state: a state represented as a dict (parsed from JSON)
        :param sai: the demonstrated selection/action/inputs
        :param reward: the reward received for ``sai`` in ``state``
        :param kwargs: may carry ``skill_label``; it defaults to ``None``
            when the caller does not supply one.
        """
        # The skill label is not part of the BaseAgent.train signature, so
        # it can only arrive as a keyword argument.
        skill_label = kwargs.get('skill_label')
        state = freeze(state)
        resp = self.lookup.get(state)

        # Store a first positive example, or overwrite the stored one when
        # the new reward is at least as high.
        if ((resp is None and reward > 0) or
                (resp is not None and reward >= resp[2])):
            self.lookup[state] = (skill_label, sai, reward)

        # Negative feedback on exactly the stored action forgets it.
        if (resp is not None and reward < 0 and skill_label == resp[0] and
                sai == resp[1]):
            del self.lookup[state]

    def check(self, state: Dict, sai: Sai, **kwargs) -> float:
        """
        Returns the reward memorized for ``state``, or 0.0 when the state
        is unknown.

        NOTE(review): ``sai`` is not compared against the memorized action,
        so any SAI gets the stored reward -- confirm this is intended.
        """
        resp = self.lookup.get(freeze(state))
        if resp is None:
            return 0.0
        return resp[2]
class BaseAgent(metaclass=ABCMeta):
    """
    Abstract interface shared by the agents in this package.

    Subclasses must implement :meth:`request`, :meth:`train` and
    :meth:`check`.
    """

    def __init__(self, **kwargs):
        """
        Creates an agent with the provided skills.

        The base implementation accepts arbitrary keyword arguments and
        ignores them.
        """
        pass

    @abstractmethod
    def request(self, state: Dict, **kwargs) -> Dict:
        """
        Returns a dict containing a Selection, Action, Input.

        :param state: a state represented as a dict (parsed from JSON)
        """
        pass

    @abstractmethod
    def train(self, state: Dict, sai: Sai, reward: float, **kwargs):
        """
        Trains the agent on one SAI applied in ``state`` and its reward.

        :param state: a state represented as a dict (parsed from JSON)
        :param sai: the selection, action and inputs that were applied
        :param reward: the reward for applying ``sai`` in ``state``
        :param kwargs: implementation-specific extras. NOTE(review): the
            previous docstring also listed the state after the SAI is
            invoked, a skill label and a list of foas; presumably those
            arrive as keyword arguments -- confirm against callers.
        """
        pass

    @abstractmethod
    def check(self, state: Dict, sai: Sai, **kwargs) -> float:
        """
        Checks the correctness (reward) of an SAI action in a given state.

        :param state: a state represented as a dict (parsed from JSON)
        :param sai: the selection, action and inputs to evaluate
        :return: the reward as a float
        """
        pass
def shorten_id(_id):
    """Abbreviate an id to its first ASCII letter followed by its first digit.

    Either part is omitted when the id contains no such character, so the
    result may be one character long or empty.
    """
    first_letter = re.search(r'([a-zA-Z])', _id)
    first_digit = re.search(r'(\d)', _id)

    letter_part = first_letter.group(0) if first_letter is not None else ""
    digit_part = first_digit.group(0) if first_digit is not None else ""
    return letter_part + digit_part
def shorthand_state_rel(rel_state):
    """Render the facts of a relative state as a one-line debug string.

    Each fact returned by ``rel_state.get_facts()`` is converted with
    ``str()`` and read in one of two textual forms:

    * ``<chain>.<attr> == <val>`` -- every element of the dotted chain is
      abbreviated with ``shorten_id`` and the pieces are concatenated to
      form the entry id.
    * ``(SkillCand:, ..., <id>, <args>) == <val>`` -- the entry id is
      ``<id>?`` followed by the abbreviated, comma-separated argument
      chains; the value is stored under ``'value'``.

    Facts matching neither form are skipped. The collected entries are
    passed to ``shorthand_state_dict`` with id shortening disabled, since
    the ids are already abbreviated here.
    """
    # Named groups are required: the code below reads 'id', 'rest', 'val'.
    skill_cand = re.compile(
        r"\(SkillCand:, .*, (?P<id>\d+), (?P<rest>.*)\) == (?P<val>.*)")

    d = {}
    for x in rel_state.get_facts():
        s = str(x)
        if not s:
            continue

        if not s.startswith("("):
            # Split off the value first so that a "." inside the value
            # (e.g. a decimal number) is not mistaken for a chain separator.
            lhs, sep, val = s.partition(" == ")
            if not sep:
                continue
            *chain, attr = lhs.split(".")
            _id = "".join(shorten_id(y) for y in chain)

        elif s[1:10] == "SkillCand":
            match = skill_cand.match(s)
            if match is None:
                continue
            arg_chains = match.group('rest').split(", ")
            _id = match.group('id') + "?" + ",".join(
                "".join(shorten_id(y) for y in chain.split("."))
                for chain in arg_chains)
            attr = 'value'
            val = match.group('val')

        else:
            # Some other parenthesised fact; nothing to show for it.
            continue

        if attr == 'value':
            val = val.strip("'")
        if attr == 'locked':
            val = (val == "True")

        # Always look the entry up for the current id; never reuse the
        # dict left over from a previous iteration.
        d.setdefault(_id, {"id": _id})[attr] = val

    return shorthand_state_dict(d, do_shorten_id=False)
# Concise replacement for the TextField proxy's __repr__.
def text_field_repr(x):
    """Return a short repr showing only the id, value and locked state."""
    ident, value, locked = x.id, x.value, x.locked
    return f"TextField(id={ident!r}, value={value!r}, locked={locked!r})"
class ActionType(object):
    """Describes one kind of action an agent can take on a selection.

    NOTE: work in progress.

    :param name: the action's name; this is also its ``str()`` form
    :param input_spec: a mapping describing the action's input; item
        access and ``get`` on the ActionType are forwarded to it
    :param apply_expected_change: a callable invoked as
        ``apply_expected_change(state, selection, input)`` that applies
        the action's expected effect to ``state`` in place
    """

    def __init__(self, name, input_spec, apply_expected_change):
        self.name = name
        self.input_spec = input_spec
        self.apply_expected_change = apply_expected_change

    def predict_state_change(self, state, sai):
        """Return a copy of ``state`` with the expected effect of ``sai``.

        ``state`` is copied with ``copy.copy`` before the change is
        applied; the selected fact is looked up on the copy with
        ``get_fact(id=sai.selection)``.

        NOTE(review): this reads ``sai.input`` (singular) and copies the
        result a second time before returning -- confirm both are intended.
        """
        cpy = copy(state)
        selection = cpy.get_fact(id=sai.selection)
        self.apply_expected_change(cpy, selection, sai.input)
        return copy(cpy)

    def __getitem__(self, attr):
        return self.input_spec[attr]

    def get(self, attr, default=None):
        """Look up ``attr`` in the input spec, returning ``default`` when
        it is absent (``None`` if no default is given, like ``dict.get``)."""
        return self.input_spec.get(attr, default)

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"ActionType(name={self.name}, spec={self.input_spec})"
# Define the HTML action types. Each decorated function below is replaced by
# the ActionType that define_action_type builds from the given name, the
# input spec and the function itself (used as apply_expected_change).
# The `with` target is the list that register_all_action_types fills on exit
# with the ActionType objects it registered from this block.
with register_all_action_types as HTML_action_type_set:
    # NOTE(review): the original note here read "NOTE need to" and was left
    # unfinished.

    # Button presses have no expected effect on the state.
    @define_action_type("PressButton",
        {'type' : int, "semantic" : False}
    )
    def PressButton(wm, selection, inp):
        pass

    @define_action_type("ButtonPressed",
        {'type' : int, "semantic" : False}
    )
    def ButtonPressed(wm, selection, inp):
        pass

    # The three update actions below share one effect: write the input into
    # the selection's 'value' and mark the selection as locked.
    @define_action_type("UpdateTextArea",
        {'type' : str, "semantic" : True}
    )
    def UpdateTextArea(wm, selection, inp):
        wm.modify(selection, 'value', inp)
        wm.modify(selection, 'locked', True)

    @define_action_type("UpdateTextField",
        {'type' : str, "semantic" : True}
    )
    def UpdateTextField(wm, selection, inp):
        wm.modify(selection, 'value', inp)
        wm.modify(selection, 'locked', True)

    @define_action_type("UpdateField",
        {'type' : str, "semantic" : True}
    )
    def UpdateField(wm, selection, inp):
        wm.modify(selection, 'value', inp)
        wm.modify(selection, 'locked', True)

# Re-key the collected list by action name before registering the set.
HTML_action_type_set = {x.name: x for x in HTML_action_type_set}
# HTML_action_type_set = {
#     "UpdateTextField" : UpdateTextField,
#     "UpdateField" : UpdateTextField,

#     "PressButton" : PressButton,
#     "ButtonPressed" : PressButton,
# }
register_action_type_set(name='html')(HTML_action_type_set)
class Registery():
    """A named mapping from registration names to registered objects.

    :param name: short name of the kind of thing being registered
    :param full_descr: human-readable description used in error messages;
        defaults to ``name``
    """

    def __init__(self, name, full_descr=None):
        self.name = name
        self.full_descr = name if full_descr is None else full_descr
        self.reg_dict = {}

    def __contains__(self, name):
        return name in self.reg_dict

    def __getitem__(self, name):
        # Raises ValueError (not KeyError) so the message can name the kind
        # of thing that was missing.
        if(name not in self.reg_dict):
            raise ValueError(f"No {self.full_descr} registered with name {name!r}.")
        return self.reg_dict[name]

    def __iter__(self):
        # Iterates over the registered objects, not their names.
        return iter(self.reg_dict.values())

    def __setitem__(self, name, val):
        self.reg_dict[name] = val

    def __len__(self):
        return len(self.reg_dict)

    def __str__(self):
        return str(self.reg_dict)

    __repr__ = __str__


def _resolve_name(obj, name_resolver=None):
    """Return the registration name for ``obj``.

    Uses ``name_resolver(obj)`` when a resolver is given, otherwise
    ``obj.__name__``.

    :raises ValueError: if no resolver is given and ``obj`` has no
        ``__name__`` attribute.
    """
    if(name_resolver is not None):
        return name_resolver(obj)
    elif(hasattr(obj,'__name__')):
        return obj.__name__
    else:
        raise ValueError(f"Cannot resolve name during registration: {obj}")


def _register(obj, registry, name=None, name_resolver=None,
        insert_func=None, full_descr=None, args=None, kwargs=None, stack_extra=0):
    """Insert ``obj`` into ``registry`` and return ``obj``.

    :param obj: the object to register
    :param registry: the registry (or any mapping) to insert into
    :param name: explicit registration name; resolved from ``obj`` via
        ``_resolve_name`` when ``None``
    :param name_resolver: optional callable used to resolve the name when
        ``name`` is not given
    :param insert_func: optional callable invoked as
        ``insert_func(registry, obj, name, *args, **kwargs)`` in place of
        the default insertion
    :param full_descr: description used in the redefinition warning
    :param args: extra positional arguments for ``insert_func``
    :param kwargs: extra keyword arguments for ``insert_func``
    :param stack_extra: extra stack levels to skip so the warning points
        at the caller's code
    """
    args = () if args is None else args
    kwargs = {} if kwargs is None else kwargs

    # Pass the resolver through; otherwise a custom name_resolver is
    # silently ignored and objects without __name__ cannot be registered.
    name = _resolve_name(obj, name_resolver) if name is None else name

    # Default keys are case- and underscore-insensitive.
    regular_name = name.lower().replace("_", "")
    if(regular_name in registry):
        warnings.warn(f"Redefinition of {full_descr} '{name}'.", stacklevel=2+stack_extra)

    if(insert_func is not None):
        insert_func(registry, obj, name, *args, **kwargs)
    else:
        registry[regular_name] = obj
    return obj
def register_whatever(*args,name=None,**kwargs): 66 | if(len(args) >= 1): 67 | return _register(args[0], registry, name, name_resolver, 68 | insert_func, full_descr, kwargs=kwargs, stack_extra=1) 69 | else: 70 | return lambda obj: _register(obj, registry, name, name_resolver, 71 | insert_func, full_descr, kwargs=kwargs, stack_extra=2) 72 | return register_whatever 73 | 74 | class RegisterAll(): 75 | def __init__(self, type_name, types=[], acceptor_funcs=[], 76 | name_resolver=None, insert_func=None, full_descr=None): 77 | self.type_name = type_name 78 | self.registry = registries[type_name] = registries.get(type_name, Registery(type_name, full_descr)) 79 | self.types = types 80 | self.acceptor_funcs = acceptor_funcs 81 | self.insert_func = insert_func 82 | self.name_resolver = name_resolver 83 | self.full_descr = self.type_name if full_descr is None else full_descr 84 | 85 | def __call__(self, n_back=1, *args, **kwargs): 86 | frame = b_frame = inspect.currentframe() 87 | try: 88 | for n in range(n_back): 89 | b_frame = b_frame.f_back 90 | locs = {**b_frame.f_locals} 91 | finally: 92 | del frame 93 | 94 | for name, obj in locs.items(): 95 | # Skip any builtins 96 | if(name[:2] == "__" and name[-2:] == "__"): 97 | continue 98 | 99 | if(isinstance(obj, self.types) or 100 | any([f(obj) for f in self.acceptor_funcs])): 101 | 102 | # The registered thing needs to be 103 | if(hasattr(self, 'enter_locs') and name in self.enter_locs): 104 | # print(name, id(self.enter_locs[name]), id(obj)) 105 | if(id(self.enter_locs[name]) == id(obj)): 106 | continue 107 | _register(obj, self.registry, name, self.name_resolver, 108 | self.insert_func, self.full_descr, args, kwargs, stack_extra=1) 109 | 110 | if(hasattr(self, 'collected')): 111 | self.collected.append(obj) 112 | return self 113 | 114 | def __enter__(self): 115 | frame = inspect.currentframe() 116 | try: 117 | self.enter_locs = {**frame.f_back.f_locals} 118 | finally: 119 | del frame 120 | # print("enter_locs:") 121 | # 
print(list(self.enter_locs)) 122 | # print() 123 | 124 | self.collected = [] 125 | return self.collected 126 | 127 | def __exit__(self,*args): 128 | self.__call__(n_back=2) 129 | del self.collected 130 | del self.enter_locs 131 | 132 | 133 | def new_register_all(type_name, types=[], acceptor_funcs=[], 134 | name_resolver=None, insert_func=None, full_descr=None): 135 | if(not isinstance(types,(list,tuple))): types = [types] 136 | if(not isinstance(acceptor_funcs,(list,tuple))): acceptor_funcs = [acceptor_funcs] 137 | 138 | types = tuple(types) 139 | acceptor_funcs = tuple(acceptor_funcs) 140 | 141 | return RegisterAll(type_name, types, acceptor_funcs, name_resolver, insert_func, full_descr) 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/funcs.py: -------------------------------------------------------------------------------- 1 | from numba.types import f8, string, boolean 2 | from apprentice.agents.cre_agents.extending import registries, new_register_decorator, new_register_all 3 | from apprentice.agents.cre_agents.environment import TextField 4 | from cre import CREFunc 5 | import numpy as np 6 | 7 | register_func = new_register_decorator("func", full_descr="CREFunc") 8 | register_all_funcs = new_register_all("func", types=[CREFunc], full_descr="CREFunc") 9 | 10 | @CREFunc(signature=boolean(string,string), 11 | shorthand = '{0} == {1}', 12 | commutes=True) 13 | def Equals(a, b): 14 | return a == b 15 | 16 | @CREFunc(signature=f8(f8,f8), 17 | shorthand = '{0} + {1}', 18 | commutes=True) 19 | def Add(a, b): 20 | return a + b 21 | 22 | @CREFunc(signature=f8(f8,f8)) 23 | def AddPositive(a, b): 24 | if(not (a >= 0 and b >= 0)): 25 | raise Exception 26 | return a + b 27 | 28 | @CREFunc(signature=f8(f8,f8,f8), 29 | shorthand = '{0} + {1} + {2}', 30 | commutes=True) 31 | def Add3(a, b, c): 32 | return a + b + c 33 | 34 | @CREFunc(signature=f8(f8,f8), 35 | shorthand = '{0} - {1}') 
36 | def Subtract(a, b): 37 | return a - b 38 | 39 | @CREFunc(signature=f8(f8,f8), 40 | shorthand = '{0} * {1}', 41 | commutes=True) 42 | def Multiply(a, b): 43 | return a * b 44 | 45 | @CREFunc(signature=f8(f8,f8), 46 | shorthand = '{0} / {1}' 47 | ) 48 | def Divide(a, b): 49 | return a / b 50 | 51 | @CREFunc(signature=f8(f8,f8), 52 | shorthand = '{0} // {1}') 53 | def FloorDivide(a, b): 54 | return a // b 55 | 56 | # @CREFunc(signature=f8(f8,f8), 57 | # shorthand = '{0} ** {1}') 58 | # def Power(a, b): 59 | # return a ** b 60 | 61 | @CREFunc(signature=f8(f8,f8), 62 | shorthand = '{0} % {1}') 63 | def Modulus(a, b): 64 | return a % b 65 | 66 | 67 | @CREFunc(signature=f8(f8), shorthand = '{0}^2') 68 | def Square(a): 69 | return a * a 70 | 71 | @CREFunc(signature=f8(f8, f8), shorthand = '{0}^{1}') 72 | def Power(a, b): 73 | return a ** b 74 | 75 | @CREFunc(signature=f8(f8), shorthand = '{0}+1') 76 | def Increment(a): 77 | return a + 1 78 | 79 | @CREFunc(signature=f8(f8), shorthand = '{0}-1') 80 | def Decrement(a): 81 | return a - 1 82 | 83 | @CREFunc(signature=f8(f8), shorthand = 'log2({0})') 84 | def Log2(a): 85 | return np.log2(a) 86 | 87 | @CREFunc(signature=f8(f8), shorthand = 'cos({0})') 88 | def Cos(a): 89 | return np.cos(a) 90 | 91 | @CREFunc(signature=f8(f8), shorthand = 'sin({0})') 92 | def Sin(a): 93 | return np.sin(a) 94 | 95 | @CREFunc(signature=f8(f8), 96 | shorthand = '{0} % 10') 97 | def Mod10(a): 98 | return a % 10 99 | 100 | @CREFunc(signature=f8(f8), 101 | shorthand = '{0} // 10') 102 | def Div10(a): 103 | return a // 10 104 | 105 | @CREFunc(signature=string(string), 106 | shorthand = '{0}') 107 | def Copy(a): 108 | return a 109 | 110 | @CREFunc(signature = string(string,string), 111 | shorthand = '{0} + {1}', 112 | commutes=False) 113 | def Concatenate(a, b): 114 | return a + b 115 | 116 | 117 | @CREFunc(signature=f8(f8), shorthand = '{0}/2') 118 | def Half(a): 119 | return a / 2 120 | 121 | @CREFunc(signature=f8(f8), shorthand = '{0}*2') 122 | 
def Double(a): 123 | return a * 2 124 | 125 | @CREFunc(signature=f8(f8), shorthand = 'OnesDigit({0})') 126 | def OnesDigit(a): 127 | return a % 10 128 | 129 | @CREFunc(signature=f8(f8), shorthand = 'TensDigit({0})') 130 | def TensDigit(a): 131 | return (a // 10) % 10 132 | 133 | 134 | ### Special Functions for Fractions 135 | ### --typically can be replaced with Multiply 136 | 137 | @CREFunc(signature=f8(f8,f8,f8), 138 | shorthand='({0} / {1}) * {2}') 139 | def ConvertNumerator(a, b, c): 140 | return (a / b) * c 141 | 142 | @CREFunc(signature=f8(TextField, TextField), 143 | shorthand='Cross({0} * {1})') 144 | def CrossMultiply(a, b): 145 | if('den' in a.id and 'den' in b.id): 146 | raise ValueError() 147 | if('num' in a.id and 'num' in b.id): 148 | raise ValueError() 149 | return (float(a.value) * float(b.value)) 150 | 151 | @CREFunc(signature=f8(TextField, TextField), 152 | shorthand='Across({0} * {1})') 153 | def AcrossMultiply(a, b): 154 | if('den' in a.id and 'den' not in b.id): 155 | raise ValueError() 156 | if('num' in a.id and 'num' not in b.id): 157 | raise ValueError() 158 | 159 | return (float(a.value) * float(b.value)) 160 | 161 | 162 | ##### Define all CREFuncs above this line ##### 163 | 164 | register_all_funcs() 165 | 166 | -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/learning_mechs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/learning_mechs/how/__init__.py: -------------------------------------------------------------------------------- 1 | from .how import BaseHow, SetChaining, ExplanationSet, register_how 2 | from .nlp.nlp_sc_planner import NLPSetChaining 3 | -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/learning_mechs/how/how.py: 
# --------------------------------------------------------------------------------
import warnings
from abc import ABCMeta
from abc import abstractmethod
# from ...extending import new_register_decorator, registries
from cre.utils import PrintElapse
from ..registers import register_how


# ------------------------------------------------------------------------
# : How Base


class BaseHow(metaclass=ABCMeta):
    """Abstract base class for how-learning mechanisms."""

    @abstractmethod
    def get_explanations(self, state, goal):
        """Return the explanations found for `goal` given `state`.

        :param state: the state to explain from
        :param goal: the value to be explained
        """
        pass


# ------------------------------------------------------------------------
# : How Learning Mechanisms


# --------------------------------------------------
# : SetChaining

from numba.types import string, f8
from cre.sc_planner import SetChainingPlanner
from cre.func import CREFunc
from ...conv_funcs import register_conversion

# func_registry = registries['func']

class ExplanationSet():
    """An ordered collection of `(func_comp, match)` explanations.

    `explanation_tree` may be None (empty set), a ready-made list of
    explanations (used as given), or an iterable of `(func_comp, match)`
    pairs to be filtered. At most `max_expls` candidates are read from
    such an iterable; -1 disables the limit. `post_func`, when given, is
    applied to each accepted `func_comp`.
    """

    def __init__(self, explanation_tree, arg_foci=None,
                 post_func=None, choice_func=None, max_expls=1000):
        self.explanation_tree = explanation_tree
        self.choice_func = choice_func
        self.post_func = post_func

        if(explanation_tree is not None):
            if(isinstance(explanation_tree, list)):
                self.explanations = explanation_tree
            else:
                self.explanations = []
                for i, (func_comp, match) in enumerate(explanation_tree):
                    # Fix: was `i >= max_expls-1`, which read one candidate
                    # too few and produced nothing at all for max_expls=1.
                    if(max_expls != -1 and i >= max_expls): break

                    # Skip candidates whose arity disagrees with the match
                    # (or with the supplied foci).
                    if(func_comp.n_args != len(match) or
                       (arg_foci is not None and func_comp.n_args != len(arg_foci))):
                        continue

                    if(self.post_func is not None):
                        func_comp = self.post_func(func_comp)

                    self.explanations.append((func_comp, match))

            # Sort by min depth, degree to which variables match unique goals,
            # and total funcs in the composition.
            # NOTE(review): assumed to apply to list input as well as to
            # iterated candidates - confirm against the upstream file.
            import numpy as np
            has_foci_match = np.array([0])
            def expl_key(tup):
                func_comp, match = tup

                # Tracks whether any explanation matches the foci exactly;
                # the result is recorded but is not part of the sort key.
                foci_match = False
                if(arg_foci is not None):
                    foci_match = all([a.id == m.id for a,m in zip(arg_foci, match)])
                    has_foci_match[0] = has_foci_match[0] | foci_match

                tup = (func_comp.depth, abs(func_comp.n_args-len(match)), func_comp.n_funcs)
                return tup
            self.explanations = sorted(self.explanations, key=expl_key)
        else:
            self.explanations = []


    def __len__(self):
        return len(self.explanations)

    def choose(self):
        """Return the first explanation, or `(None, [])` when empty."""
        if(len(self.explanations) > 0):
            return self.explanations[0]
        return None, []


    def __iter__(self):
        for func, match in self.explanations:
            yield (func, match)

# Renders a float as a string, dropping the fractional part for whole values.
@register_conversion(name="NumericalToStr")
@CREFunc(signature=string(f8),
    shorthand="s({0})")
def NumericalToStr(x):
    if(int(x) == x):
        return str(int(x))
    else:
        return str(x)


@register_how
class SetChaining(BaseHow):
    """How-learning mechanism backed by cre's SetChainingPlanner."""

    def __init__(self,
                 agent=None,
                 search_depth=2,
                 function_set=[],
                 float_to_str=True,
                 **kwargs):
        self.agent = agent
        self.function_set = function_set
        self.search_depth = search_depth
        self.float_to_str = float_to_str
        # An explicit 'fact_types' kwarg wins; otherwise use the agent's, if any.
        self.fact_types = kwargs.get('fact_types', self.agent.fact_types if (self.agent) else [])
        for fn in function_set:
            assert isinstance(fn, CREFunc), \
                "function_set must consist of CREFunc intances for SetChaining how-learning mechanism."


    def _search_for_explanations(self, goal, values, extra_consts=[], **kwargs):
        """Declare `values` (and constant `extra_consts`) to a fresh planner
        and return its explanations for `goal`.

        Also records the planner's `num_forward_inferences` on `self`.
        """
        # Fallback on any parameters set in __init__()
        kwargs['funcs'] = kwargs.get('function_set', self.function_set)
        if('function_set' in kwargs): del kwargs['function_set']
        kwargs['search_depth'] = kwargs.get('search_depth', self.search_depth)

        # Make a new planner instance and fill it with values
        planner = SetChainingPlanner(self.fact_types)
        for v in values:
            planner.declare(v)

        for v in extra_consts:
            planner.declare(v,is_const=True)

        # Search for explanations
        explanation_tree = planner.search_for_explanations(goal, **kwargs)
        self.num_forward_inferences = planner.num_forward_inferences
        return explanation_tree

    def get_explanations(self, state, goal, arg_foci=None, float_to_str=None,
                         extra_consts=[], **kwargs):
        """Return an ExplanationSet for `goal`.

        `state` is either a list of values or an object whose
        `get("working_memory")` supplies the facts. When `arg_foci` is
        given and `state` is not a list, the foci replace the facts as
        the search values. The goal is searched for as a float first,
        then as given when that is not possible or finds nothing.
        """
        # Prevent from learning too shallow when multiple foci
        if(arg_foci is not None and len(arg_foci) > 1):
            # Don't allow fallback to constant
            kwargs['min_stop_depth'] = kwargs.get('min_stop_depth', kwargs.get('search_depth',getattr(self, 'search_depth', 2)))
            kwargs['min_solution_depth'] = 1


        if(isinstance(state, list)):
            values = state
        else:
            wm = state.get("working_memory")
            values = list(wm.get_facts()) if arg_foci is None else arg_foci

        float_to_str = float_to_str if float_to_str is not None else self.float_to_str

        post_func = None
        try:
            flt_goal = float(goal)
        except (TypeError, ValueError):
            # Fix: float() raises TypeError (not ValueError) for goals that
            # are neither strings nor numbers; treat both as "not numeric".
            explanation_tree = None
        else:
            explanation_tree = self._search_for_explanations(flt_goal, values, extra_consts, **kwargs)
            post_func = NumericalToStr if (float_to_str) else None


        # Try to find the goal as a string
        if(explanation_tree is None):
            # TODO: Shouldn't full reset and run a second time here, should just query.
            explanation_tree = self._search_for_explanations(goal, values, extra_consts, **kwargs)
            post_func = None

        expl_set = ExplanationSet(explanation_tree, arg_foci, post_func=post_func)

        return expl_set

    def new_explanation_set(self, explanations, *args, **kwargs):
        '''Takes a list of explanations i.e. (func, match) and yields an ExplanationSet object'''
        return ExplanationSet(explanations,*args, **kwargs)

# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/apprenticelearner/AL_Core/9534e3e1d5f5439740e0c44196a4a4f91602aed3/apprentice/agents/cre_agents/learning_mechs/how/nlp/__init__.py
# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/.~lock.NLP_how_turk_data_cleaned.csv#:
# --------------------------------------------------------------------------------
# ,danny,danny-Yoga-7-15ITL5,28.06.2023 00:07,file:///home/danny/.config/libreoffice/4;
# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p1.py:
# --------------------------------------------------------------------------------
# 3 divided by 12 is .25.
# .25 multiplied by 100 is 25%
def calc1():
    print("This program will calculate the percentage of 3 out of 12")
    # NOTE(review): only prints a message; no percentage is computed.

# Form a ratio with 3 as the numerator, 12 as the denominator, and convert the decimal into a percent by multiplying 100
def calc2():
    print(3/12*100)

# Find the quotient of 12 divided by three and divide 100 by the quotient to find the percent
def calc3():
    print(12/3)
    # NOTE(review): prints the quotient only; 100 is never divided by it.

# divide 3 by 12 to get .25, then multiply by 100 which is 25 percent.
def calc4():
    print(3/12*100)

# Multiply 3 times 100. Then, divide the product by 12 to get the percentage.
def calc5():
    return (3 * 100) / 12

# Multiply the numerator 3 by 100 and divide that result by the numerator 12 to find the percentage that 3 is of 12.
def calc6():
    return (3 * 100) / 12



# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p10.py:
# --------------------------------------------------------------------------------
# NOTE(review): presumably recorded Copilot completions (comment = prompt,
# def = completion) - keep verbatim; confirm before editing.
# Divide 12 by 3 and multiply that value by 2
def calc1():
    return 2 * (12 / 3)

# The larger triangle has a ratio of 3 over 12. The smaller triangle has a ratio of 2 over x. If you set the ratios equal to each other (because they are similar triangles), you can use cross-multiplication to solve for x. Multiply 2 and 12. Divide that value by 3.
def calc2():
    return (2*12)/3

# Divide 3 by 2 and use the quotient to divide 12 by.
def calc3():
    # Divide 3 by 2
    quotient = 3 / 2
    # Use the quotient to divide 12 by
    quotient2 = 12 / quotient
    # Print the result
    print(quotient2)

# 12 divided by 3 is 4 so the green triangle's longer side is 4 times larger than its smaller one. Now you take the 2 on the smaller side of the pink triangle and multiply it by 4 to get 8 yd.
def calc4():
    # NOTE(review): echoes the prompt text; nothing is computed.
    print("12 divided by 3 is 4 so the green triangle's longer side is 4 times larger than its smaller one. Now you take the 2 on the smaller side of the pink triangle and multiply it by 4 to get 8 yd.")
    print("The answer is 8 yd.")
    print("")

# Use the length of the two similar sides to make a ratio. In this case, use 2 and 3 yards to make the ratio, 2/3. Make the ratio with the other two similar sides, in this case x and 12, to make the ratio x/12. Cross-multiply the two ratios and solve for x.
def calc5():
    # NOTE(review): the printed algebra is inconsistent - cross-multiplying
    # 2/3 = x/12 gives 3x = 24 (x = 8), not "2x = 36" / "x = 18".
    print("2/3 = x/12")
    print("2x = 36")
    print("x = 18")
    print("The length of the side is 18 yards.")

# To find the length of side x, divide 3 by 2 to get 1.5. Take 1.5 and multiply it with 12
def calc6():
    x = (3/2)*12
    print(x)

# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p11.py:
# --------------------------------------------------------------------------------
# NOTE(review): presumably recorded Copilot completions (comment = prompt,
# def = completion) - keep verbatim; confirm before editing.
# Divide 144 by the sum of 1 and .2 to get the number of pears Type B produced
def calc1():
    return 144/(1+.2)

# The total percentage value is 100 plus 20. Divide 144, the number of Type A pears, by 1.2.
def calc2():
    # NOTE(review): evaluates as 100 + ((20 / 144) / 1.2) by operator
    # precedence, which is not 144 divided by 1.2.
    return 100 + 20 / 144 / 1.2

# a equals b plus 20 percent so 120 percent. b equals 100 percent. 10 divided by 120 equals 83.33 percent. So b produces 83.33 percent (.8333) of a (144) which equals 120
def calc3():
    a = 144
    b = a * .8333
    return b

# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p12.py:
# --------------------------------------------------------------------------------
# Divide 10 by 2 first and then add 3.
def calc1():
    return (10/2)+3 # return the result of the calculation

# Divide 10/2 and then add 3
def calc2():
    return (10/2) + 3

# PEMDAS dictates that division takes place before addition. So, 10 should be divided by 2 first. Then, the quotient should be added to 3.
def calc3():
    return 10 / 2 + 3 # 10 / 2 = 5, 5 + 3 = 8

# Divide 10 by 2. Then add that value to 3.
def calc4():
    return (10/2)+3

# Divide 10 by 2 then add 3.
def calc5():
    print(10/2+3)

# divide 10 by 2 then add 3
def calc6():
    return 10 / 2 + 3

# Divide 10 by 2. Then add the answer to 3.
def calc7():
    return 10/2 + 3

# First divide the numerator 10 by the numerator 2, because division comes before addition. Then sum that result (5) and the numerator 3 to find the final answer.
def calc8():
    print(10/2+3)

# Take 10 divided by 2 to get 5, and then find the sum of 5 and 3
def calc9():
    return 10 / 2 + 3

# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p13.py:
# --------------------------------------------------------------------------------
# NOTE(review): presumably recorded Copilot completions (comment = prompt,
# def = completion) - keep verbatim; confirm before editing.
# Add 8 and 4. Since it is more than 10, take just the ones digit.
def calc1():
    return (8+4)%10

# Add the digits in the tens place, 8 and 4. Since the sum exceeds 10, write down the ones of the sum (2), then carry the tens (1) to the hundreds place.
def calc2():
    x = 8
    y = 4
    z = x + y
    print(z) # 12
    # NOTE(review): prints the full sum (12); the ones digit is not extracted.

# The sum of 8 and 4 without the 1.
def calc3():
    print(8 + 4 - 1)
    # NOTE(review): prints 11 (subtracts 1 from the sum).

# add 8 to 4 which gives you 12 and you put down the ones column, which is a two and carry the one over to the next column to the left.
def calc4():
    print(8+4)

# Add 8 plus 4. Bring down the answer's ones place digit, next to the 9.
# Carry the tens place digit up above the 3 to be added in the hundreds place.
def calc5():
    print(8 + 4)
    print(" Bring down the answer's ones place digit, next to the 9. Carry the tens place digit up above the 3 to be added in the hundreds place.")

# Add the numeral 8 plus the numeral 4, and then use the last digit of the resulting sum (12) to find the missing value (2).
def calc6():
    return 8 + 4 - 12 # 2
    # NOTE(review): the expression evaluates to 0, not the 2 stated in the
    # trailing comment.

# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p14.py:
# --------------------------------------------------------------------------------
# NOTE(review): presumably recorded Copilot completions (comment = prompt,
# def = completion) - keep verbatim; confirm before editing.
# Subtract 8 by 4 and Subtract 7 by 5. Divide the first sum by the second sum.
def calc1():
    return (8-4)/(7-5) # 1.0
    # NOTE(review): the expression evaluates to 2.0, not the 1.0 stated in
    # the trailing comment.

# The slope is the quotient of the difference between the y-values 8 and 4, and the difference between the x-values 7 and 5.
def calc2():
    print((8-4)/(7-5)) # 1.0
    # NOTE(review): prints 2.0, not the 1.0 stated in the trailing comment.

# subtract 7 and 5, then divide the product by 8 minus 4
def calc3():
    return (7-5)/(8-4)

# You take y2 which is 8 and subtract y1 which is 4 and get 4. Put that aside for a minute and take x2 (7) and subtract x1 (5) which is 2. Take the four from the differences in the y values and put that over the two from the difference of x values and you get 4 over 2 which is 2.
def calc4():
    y2 = 8
    y1 = 4
    x2 = 7
    x1 = 5
    return (y2 - y1) / (x2 - x1)

# The change in our x points if found by subtracting 5 from 7. The change in our y points is found by subtracting 4 from 8. Divide the change in y by the change x to find the slope.
def calc5():
    x1 = 5
    x2 = 7
    y1 = 4
    y2 = 8
    slope = (y2-y1)/(x2-x1)
    print(slope)

# Subtract the first y-coordinate 4 from the second y-coordinate 8, then subtract the first x-coordinate 5 from the second x-coordinate 7, and then divide the difference of the y-coordinates (4) by the difference of the x-coordinates (2).
def calc6():
    y = 8 - 4
    x = 7 - 5
    return y / x

# Take 4 and subtract it from 8 to get 4. Take 5 and subtract it from 7. Take 4 divided by 2 to get 2
def calc7():
    a = 8 - 4
    b = 7 - 5
    c = 4 / 2
    return a, b, c # return a tuple of values
    # NOTE(review): returns all three intermediate values (4, 2, 2.0)
    # rather than a single result.

# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p2.py:
# --------------------------------------------------------------------------------
# NOTE(review): presumably recorded Copilot completions (comment = prompt,
# def = completion) - keep verbatim; confirm before editing.
# Multiply 3x and 2x, the first two terms of each binomial and take the constant
def calc1():
    return 6

# a is the product of integer 3 and 2, which are both in front of variable x.
def calc2():
    x = 3 * 2
    return x

# Multiply 3 and 2
def calc3():
    return 3 * 2

# The value of a is the product of 3 and 2.
def calc4():
    a = 3 * 2
    return a

# To find a you would use the foil method and multiply the first numbers in each set, which are 3x and 2x which would give you 6x² so a is 6.
def calc5():
    return 6

# Multiply 3 and 2 together because they are both combined with x. Write the product as the value of a.
def calc6():
    a = 3 * 2
    return a

# Multiply the numerator 3 by the numerator 2 to find the value of a.
def calc7():
    a = 3 * 2
    return a

# In order to find the value of a in the polynomial, the variables 3x and 2x will be multiplied together
def calc8():
    # NOTE(review): `x` is not defined in this function or file section, so
    # calling this raises NameError unless a global `x` exists.
    a = (3*x) * (2*x)
    return a

# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p3.py:
# --------------------------------------------------------------------------------
# NOTE(review): presumably recorded Copilot completions (comment = prompt,
# def = completion) - keep verbatim; confirm before editing.
# Multiply 3 and 5. Multiply 2 and 4. Add those two values together
def calc1():
    return 3 * 5 + 2 * 4

# Using foil and multiply the inside and outside values (4 times 2x, which is 8x) and (3x times 5 which is 15x) and then adding these together we get 23x so the value for b is 23.
def calc2():
    return 23

# Multiply the 3 of the 3x by the 5. Multiply the 2 of the 2x by the 4. Add each of those products together to fine the value of b.
def calc3():
    return 3 * 5 + 2 * 4

# Multiply the numerator 3 by the numerator 5, then multiply the numerator 2 by the numerator 4, and then add the results to find the value of b.
def calc4():
    return 3 * 5 + 2 * 4

# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p4.py:
# --------------------------------------------------------------------------------
# NOTE(review): presumably recorded Copilot completions (comment = prompt,
# def = completion) - keep verbatim; confirm before editing.
# Multiply 4 and 5 to get c
def calc1():
    return 4 * 5

# c is calculated by multiplying 4 and 5, the constant terms
def calc2():
    return 4 * 5

# Multiply 4 and 5.
def calc3():
    return 4 * 5

# Multipy 4 and 5 to find the product that is the value for c.
def calc4():
    return 4 * 5

# You multiply 4 times 5 and get 20 and that's the answer because they have no variables they can't be added to any other numbers.
def calc5():
    return 4 * 5

# Multiply 4 times 5 to find the value of c.
def calc6():
    return 4 * 5

# Multiply the numerator 4 by the numerator 5 to find the value of c.
def calc7():
    return 4 * 5

# To get the value of c, 4 and 5 will be multiplied together
def calc8():
    return 4 * 5


# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p5.py:
# --------------------------------------------------------------------------------
# NOTE(review): presumably recorded Copilot completions (comment = prompt,
# def = completion) - keep verbatim; confirm before editing.
# Multiply 4 and 3 to find the denominator of the new fractions
def calc1():
    return 4 * 3

# Multiply 4 x 3 and use the product 12 as your denominator.
def calc2():
    return 4 * 3

# Multiply the denominator 4 by the denominator 3 to find the common denominator to be used for the converted fractions.
def calc3():
    return 4 * 3

def calc4():
    # NOTE(review): this function has no prompt comment above it.
    return 9/12 + 8/12

# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p6.py:
# --------------------------------------------------------------------------------
# NOTE(review): presumably recorded Copilot completions (comment = prompt,
# def = completion) - keep verbatim; confirm before editing.
# Multiply the 3 of the first fraction by the 3 needed to turn the "4" denominator into 12
def calc1():
    return 3 * 3

# Since the denominator 4 is converted to denominator 12 by multiplying 3, also multiply the numerator 3 by 3
def calc2():
    num = 3
    den = 4
    return num * (den/12)
    # NOTE(review): evaluates to 3 * (4/12) = 1.0, not the 3 * 3 the prompt
    # describes.

# Multiply 3 and 3
def calc3():
    return 3 * 3

# Multiply 3 by the multiple used to convert the denominator 3.
def calc4():
    return 3 * 3

# To get a denominator of 12 from 4, you multiply by 3 on the bottom, so you have to do this on the top as well. 3 times 3 equals 9. three-fourths equals nine-twelfths.
def calc5():
    return 9/12 + 8/12
    # NOTE(review): returns the sum of two fractions rather than 3 * 3.

# Your denominator 4 goes into 12 3 times, so multiply your numerator, 3, by 3 as well.
# The product is your new numerator.
def calc6():
    return 3 * 3

# Multiply the numerator 3 by the denominator 3 to find the numerator of the left converted fraction.
def calc7():
    return 3 * 3

# To find the numerator of the left converted fraction, multiply the left unconverted fraction, 3/4, by the variable 3. This will multiply with the unconverted numerator of 3 to result in the converted numerator of 9.
def calc8():
    return 3 * 3

# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p7.py:
# --------------------------------------------------------------------------------
# NOTE(review): presumably recorded Copilot completions (comment = prompt,
# def = completion) - keep verbatim; confirm before editing.
# Add 9 and 8 to get the new numerator
def calc1():
    return 9 + 8

# Add the values of the numerators in the fractions 9/12 and 8/12.
def calc2():
    x = 9/12 + 8/12
    return x
    # NOTE(review): adds the two fractions rather than the numerators 9 and 8.

# Since the two fractions share the same denominator 12, simply add the numerators 9 and 8 to calculate the sum of the fraction
def calc3():
    print(9 + 8)

# Add 9 and 8
def calc4():
    return 9 + 8

# Add 9 plus 8.
def calc5():
    return 9 + 8

# 9 plus 8 equals 17
def calc6():
    return 17

# Add the numerator 9, to the numerator 8. The sum is the numerator of the new fraction.
def calc7():
    return 9 + 8

# Sum the numerator 9 and the numerator 8 to find the numerator of the simplified fraction.
def calc8():
    return 9 + 8

# To find the numerator of the simplified fraction, add the numerators from the converted fractions, 9 and 8, together to get 17.
def calc9():
    return 17 # 9 + 8

# --------------------------------------------------------------------------------
# /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p8.py:
# --------------------------------------------------------------------------------
# NOTE(review): presumably recorded Copilot completions (comment = prompt,
# def = completion) - keep verbatim; confirm before editing.
# Find the area of the shaded region, (1/2)*(12km*20km), and then subtract the area of the unshaded region (1/2)*(6km*10km), from the total
def calc1():
    return (1/2)*(12*20) - (1/2)*(6*10)

# Find the area of the large triangle by multiplying the base 20km by height 12km and 1/2. Then find the area of the small triangle by multiplying base 10km by height 6km and 1/2. Subtract the area of the small triangle from the area of the big triangle.
def calc2():
    # `20 * 12 * 1/2` groups left to right as (20 * 12 * 1) / 2.
    large_triangle = 20 * 12 * 1/2
    small_triangle = 10 * 6 * 1/2
    return large_triangle - small_triangle

# The area of the larger is half of 20 times 12. The area of the smaller triangle is half of 6 times 10. Subtract the area of the smaller triangle from the area of the larger triangle.
def calc3():
    return (20 * 12) / 2 - (6 * 10) / 2

# Multiply 12 and 20 and divide by two. Then multiply 10 and 6 and divide that product by two. Then subtract the smaller triangle area from the larger triangle area.
def calc4():
    return (12*20/2) - (10*6/2)

# The area for the larger triangle is (one half times 20 km times 12 km which equals 120km squared) the smaller triangle's area is (one half times 6 km times 10 km which equals 30 km squared.) Then you subtract the small white area to leave the shade region which is 120 km squared minus 30 km squared.
def calc5():
    print(120-30)

# Find the area of the larger triangle by multiplying the base (20) by the width (12) and dividing that in half. Then do the same for the smaller triangle by multiplying the base (10) by the width (6) and dividing in half. Here, the rectangles have areas of 120 and 30.
To find the area of the shaded region, subtract the area 30 from the area 120. 24 | def calc6(): 25 | area1 = (20 * 12) / 2 26 | area2 = (10 * 6) / 2 27 | print(area1 - area2) -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/copilot/p9.py: -------------------------------------------------------------------------------- 1 | # Calculate the ratio of the sector by dividing the 135 degree arc of the sector by 360, the total degree of a circle. Then, multiply this product by the square of the radius 6 to calculate the value multiplied by pi for the area of the sector 2 | def calc1(): 3 | ratio = (135/360) 4 | area = ratio * (6**2) 5 | return area * 3.14 # 3.14 is the value of pi 6 | 7 | # Using the formula and plugging in our info we get 135 over 360 times pi 6 squared which equals 13.5 pi cm squared 8 | def calc2(): 9 | return 135/360*3.14*6**2 -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/output.pstats: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apprenticelearner/AL_Core/9534e3e1d5f5439740e0c44196a4a4f91602aed3/apprentice/agents/cre_agents/learning_mechs/how/nlp/analysis/output.pstats -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/learning_mechs/how/nlp/garbage1.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | def foo(p_args, e_args): 4 | p_cnts, e_cnts = Counter(p_args), Counter(e_args) 5 | arg_scr = 0 6 | for arg, p_cnt in p_cnts.items(): 7 | arg_scr += p_cnt-abs(p_cnt-e_cnts.get(arg, 0)) 8 | return arg_scr / max(len(p_args), len(e_args)) 9 | 10 | # print(f"A_SCR: {arg_scr / max(len(p_args), len(e_args)):.2}", e_args, p_args ) 11 | 12 | print(foo(['a','a'], ['a'])) 
13 | -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/learning_mechs/how/nlp/resources.py: -------------------------------------------------------------------------------- 1 | 2 | dictionary = { 3 | "add" : "+", 4 | "sum" : "+", 5 | "total" : "+", 6 | "count" : "+", 7 | "plus" : "+", 8 | "half" : "/2", 9 | 10 | "subtract" : "-", 11 | "difference": "-", 12 | 13 | "reduction": "-", 14 | "minus": "-", 15 | 16 | "multiply" : "x", 17 | "product" : "x", 18 | "times" : "x", 19 | "double" : "x", 20 | 21 | "divide" : "/", 22 | "split": "/", 23 | "equals" : "=", 24 | 25 | "half" : "/2", 26 | 27 | "ones" : "[0]", 28 | "ones-digit" : "[0]", 29 | "tens" : "[1]", 30 | "tens-digit" : "[1]", 31 | 32 | "square" : "**2", 33 | "squared" : "**2", 34 | } 35 | 36 | 37 | special_patterns = { 38 | r"(\S+)\sdivided\sby\s(\S+)" : "/", 39 | r"(\S+)\stimes\s(\S+)" : "x", 40 | r"(\S+)\sminus\s(\S+)" : "-", 41 | r"(\S+)\splus\s(\S+)" : "+", 42 | 43 | 44 | r"ones\sdigit" : "[0]", 45 | r"ones'\sdigit" : "[0]", 46 | r"one's\sdigit" : "[0]", 47 | 48 | r"tens\sdigit" : "[1]", 49 | r"tens'\sdigit" : "[1]", 50 | r"ten's\sdigit" : "[1]", 51 | } 52 | 53 | not_main = { 54 | "take", 55 | "set", 56 | "find", 57 | "calculate", 58 | "apply" 59 | } 60 | 61 | #substitute numbers in the sentence with nouns for more accurate parsing 62 | noun = "dog" 63 | -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/learning_mechs/how/nlp/tests/test_nlp_sc_planner.py: -------------------------------------------------------------------------------- 1 | from apprentice.agents.cre_agents.how.nlp.nlp_sc_planner import NLPSetChaining, func_to_policy 2 | from apprentice.agents.cre_agents.funcs import ( 3 | Add, Add3, Subtract, Multiply, OnesDigit, TensDigit, CastFloat, 4 | Divide, Multiply) 5 | from cre import CREFunc, define_fact, MemSet, Var 6 | from cre.utils import PrintElapse 7 | from numba.types import 
f8, unicode_type 8 | import numpy as np 9 | 10 | IE = define_fact("IE", { 11 | "id" : str, 12 | "value" : {"type" : str, "visible" : True, "semantic" : True, 13 | 'conversions' : {float : CastFloat}}, 14 | }) 15 | IE._fact_proxy.__str__ = lambda x: f"{x.value}@{x.id})" 16 | 17 | 18 | def state_w_values(vals): 19 | state = MemSet() 20 | for i, val in enumerate(vals): 21 | state.declare(IE(str(i), str(val))) 22 | return {"working_memory" : state} 23 | 24 | def test_func_to_policy(): 25 | # 1 26 | a, b = Var(f8,'a'), Var(f8,'b') 27 | cf = Add(Subtract(a,7.0),b) 28 | policy = func_to_policy(cf, [1,3]) 29 | assert str(policy) == "[[(Subtract(a, b), [1, 7.0])], [(Add(a, b), [3])]]" 30 | 31 | # 2 32 | cf = Add(a,a) 33 | policy = func_to_policy(cf, [1]) 34 | assert str(policy) == "[[(Add(a, a), [1])]]" 35 | 36 | 37 | # 3 38 | BOOP = define_fact("BOOP", {"A" :unicode_type, "B" :f8}) 39 | v = Var(BOOP, 'v') 40 | cf = Add(v.B,v.B) 41 | policy = func_to_policy(cf, [BOOP("1", 1)]) 42 | assert str(policy) == "[[(Add(a, a), [1.0])]]" 43 | 44 | Float = CastFloat(unicode_type) 45 | 46 | # 4 47 | a, b = Var(IE,'a'), Var(IE,'b') 48 | cf = TensDigit((Float(a.value) + Float(b.value))) 49 | policy = func_to_policy(cf, [IE("7", "7"), IE("6", "6")], conv_funcs=[Float]) 50 | assert str(policy) == "[[(Add(a, b), [7.0, 6.0])], [(TensDigit(a), [])]]" 51 | 52 | # 5 53 | cf = TensDigit((Float(a.value) + Float(a.value))) 54 | policy = func_to_policy(cf, [IE("7", "7")], conv_funcs=[Float]) 55 | assert str(policy) == "[[(Add(a, a), [7.0])], [(TensDigit(a), [])]]" 56 | 57 | def test_basic_searches(): 58 | planner = NLPSetChaining( 59 | fact_types=(IE,), float_to_str=False, 60 | function_set=[Add, Add3, OnesDigit, TensDigit] 61 | ) 62 | # If multiple args are stated then expl1 should not include 63 | # any explanations like Add(a,a), only Add(a,b) 64 | 65 | expls1 = planner.get_explanations( 66 | state_w_values([7,7,7,7,7,7]), 67 | 1, 68 | "Add 7 and 7" 69 | ) 70 | 71 | expls2 = 
planner.get_explanations( 72 | state_w_values([7,7,7,7,7,7]), 73 | 1, 74 | "Add" 75 | ) 76 | 77 | assert len(expls1) < len(expls2) 78 | 79 | # If arg_foci are given either of these cases then there should only 80 | # be one explanation 81 | 82 | state = state_w_values([7,7,7,7,7,7]) 83 | wm = state.get('working_memory') 84 | arg_foci = [wm.get_fact(id="0"), wm.get_fact(id="1")] 85 | expls1 = planner.get_explanations( 86 | state, 87 | 1, 88 | "Add 7 and 7", 89 | arg_foci=arg_foci 90 | ) 91 | assert len(expls1) == 1 92 | 93 | expls2 = planner.get_explanations( 94 | state, 95 | 1, 96 | "Add", 97 | arg_foci=arg_foci 98 | ) 99 | assert len(expls2) == 1 100 | 101 | def test_const_searches(): 102 | planner = NLPSetChaining( 103 | fact_types=(IE,), float_to_str=False 104 | ) 105 | # If multiple args are stated then expl1 should not include 106 | # any explanations like Add(a,a), only Add(a,b) 107 | 108 | expls1 = planner.get_explanations( 109 | state_w_values([3,12]), 110 | 25.0, 111 | "Divide 3 and 12 and multiply by 100." 112 | ) 113 | 114 | expls1 = planner.get_explanations( 115 | state_w_values([3.0,12.0]), 116 | 25.0, 117 | "Form a ratio with 3 as the numerator, 12 as the denominator, and convert the decimal into a percent by multiplying 100", 118 | ) 119 | 120 | # A useless hint... 
need to ensure ares are not doubled 121 | with PrintElapse("elapse"): 122 | expls1 = planner.get_explanations( 123 | state_w_values([3,4,2,3]), 124 | 12, 125 | "List out the multiples of denominator 4 and denominator 3 to identify the least common denominator 12" 126 | ) 127 | 128 | def test_code_sections(): 129 | planner = NLPSetChaining( 130 | fact_types=(IE,), float_to_str=False 131 | ) 132 | 133 | expls1 = planner.get_explanations( 134 | state_w_values([3,12]), 135 | 12, 136 | "(3/12)*100" 137 | ) 138 | 139 | for expl in expls1: 140 | print(expl) 141 | 142 | if __name__ == "__main__": 143 | # test_func_to_policy() 144 | # test_basic_searches() 145 | # test_const_searches() 146 | test_code_sections() 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/learning_mechs/process/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apprenticelearner/AL_Core/9534e3e1d5f5439740e0c44196a4a4f91602aed3/apprentice/agents/cre_agents/learning_mechs/process/__init__.py -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/learning_mechs/registers.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from ..extending import new_register_decorator, registries 4 | register_how = new_register_decorator("how", full_descr="how-learning mechanism") 5 | register_when = new_register_decorator("when", full_descr="when-learning mechanism") 6 | register_where = new_register_decorator("where", full_descr="where-learning mechanism") 7 | register_which = new_register_decorator("which", full_descr="which-learning mechanism") 8 | register_process = new_register_decorator("process", full_descr="process-learning mechanism") 9 | -------------------------------------------------------------------------------- 
/apprentice/agents/cre_agents/learning_mechs/which.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from abc import ABCMeta 3 | from abc import abstractmethod 4 | from .registers import register_which 5 | 6 | # ------------------------------------------------------------------------ 7 | # : BaseWhich 8 | 9 | # TODO: COMMENTS 10 | class BaseWhich(metaclass=ABCMeta): 11 | def __init__(self, skill,**kwargs): 12 | self.skill = skill 13 | self.agent = skill.agent 14 | 15 | @staticmethod 16 | def sort(state, skill_applications): 17 | # Sort in descending order of utility (i.e. best first) 18 | def key_func(skill_app): 19 | which_lrn_mech = skill_app.skill.which_lrn_mech 20 | return which_lrn_mech.get_utility(state, skill_app) 21 | 22 | return sorted(skill_applications, key=key_func, reverse=True) 23 | 24 | def ifit(self, state, skill_app, reward): 25 | """ 26 | 27 | :param state: 28 | """ 29 | raise NotImplemented() 30 | 31 | 32 | def remove(self, state, skill_app): 33 | # TODO: Haven't actually implemented anywhere below 34 | return 35 | 36 | def get_utility(self, state, skill_app): 37 | """ 38 | 39 | """ 40 | raise NotImplemented() 41 | 42 | def get_info(self, **kwargs): 43 | return {} 44 | 45 | 46 | @register_which 47 | class TotalCorrect(BaseWhich): 48 | def __init__(self, skill, **kwargs): 49 | super().__init__(skill, **kwargs) 50 | self.num_correct = 0 51 | self.num_incorrect = 0 52 | def ifit(self, state, skill_app, reward): 53 | if(reward > 0): 54 | self.num_correct += 1 55 | else: 56 | self.num_incorrect += 1 57 | 58 | def get_utility(self, state, skill_app): 59 | return self.num_correct 60 | 61 | 62 | 63 | @register_which 64 | class WhenPrediction(BaseWhich): 65 | def __init__(self, skill, **kwargs): 66 | super().__init__(skill, **kwargs) 67 | def ifit(self, state, skill_app, reward): 68 | pass 69 | 70 | def get_utility(self, state, skill_app): 71 | when_pred = getattr(skill_app,'when_pred', 0) 72 | 
return when_pred if when_pred is not None else 0 73 | 74 | # @staticmethod 75 | # def sort(state, skill_applications): 76 | # return sorted(skill_applications, key=lambda s: getattr(s,'when_prob', 0)) 77 | 78 | 79 | @register_which 80 | class ProportionCorrect(TotalCorrect): 81 | def get_utility(self, state, skill_app): 82 | p,n = self.num_correct, self.num_incorrect 83 | s = p + n 84 | return (p / s if s > 0 else 0, s) 85 | 86 | @register_which 87 | class WeightedProportionCorrect(TotalCorrect): 88 | def get_utility(self,state, skill_app, w=2.0): 89 | p,n = self.num_correct, w*self.num_incorrect 90 | s = p + n 91 | return (p / s if s > 0 else 0, s) 92 | 93 | @register_which 94 | class NonLinearProportionCorrect(TotalCorrect): 95 | def get_utility(self,state, skill_app, a=1.0,b=.25): 96 | p,n = self.num_correct, self.num_incorrect 97 | n = a*n + b*(n*n) 98 | s = p + n 99 | return (p / s if s > 0 else 0, s) 100 | 101 | -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/test.py: -------------------------------------------------------------------------------- 1 | from apprentice.agents.CRE_Agent.extending import new_register_all, new_register_decorator, registries 2 | 3 | if __name__ == "__main__": 4 | register_poop = new_register_decorator("poop") 5 | register_all_things = new_register_all("things", types=[int,str], acceptor_funcs=[lambda x : isinstance(x,tuple)]) 6 | 7 | @register_poop 8 | def foo(): 9 | return "foo" 10 | 11 | print(registries) 12 | 13 | with register_all_things() as things: 14 | a = 1 15 | b = 'q' 16 | c = ('eggo',) 17 | d = ['noodles'] 18 | e = ['noodlez'] 19 | 20 | with register_all_things() as things: 21 | x = 1 22 | b = 'q' 23 | z = ('eggo',) 24 | q = ['noodles'] 25 | v = ['noodlez'] 26 | 27 | print(things) 28 | print(registries) 29 | 30 | 31 | -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/tests/test_agent.py: 
-------------------------------------------------------------------------------- 1 | from apprentice.agents.cre_agents.environment import Button, TextField, Container 2 | from apprentice.agents.cre_agents.state import encode_neighbors, State 3 | from apprentice.agents.cre_agents.dipl_base import BaseDIPLAgent 4 | from apprentice.agents.cre_agents.cre_agent import CREAgent, SAI 5 | from apprentice.agents.cre_agents.tests.test_state import new_mc_addition_state 6 | 7 | from numba.types import unicode_type 8 | from cre import MemSet 9 | from cre.transform import MemSetBuilder, Flattener, FeatureApplier, RelativeEncoder, Vectorizer 10 | from cre.default_funcs import Equals 11 | 12 | 13 | 14 | def test_init_base_dipl(): 15 | agent = BaseDIPLAgent() 16 | 17 | def test_cre_agent(): 18 | function_set = ["Add3", "Mod10", "Add", "Div10", "Copy"] 19 | # Test no FOCI 20 | # agent = CREAgent(feature_set=[], function_set=function_set, 21 | # where="antiunify") 22 | 23 | # py_dict = new_mc_addition_state(567,491) 24 | # print(agent.act(py_dict)) 25 | 26 | # agent.train(py_dict, ( "0_answer", "UpdateTextField", {"value": 8} ), 1) 27 | # py_dict['0_answer'].update({"value" : '8', "locked" : True}) 28 | 29 | # agent.train(py_dict, ( "1_answer", "UpdateTextField", {"value": 5} ), 1) 30 | # py_dict['1_answer'].update({"value" : '5', "locked" : True}) 31 | 32 | # agent.train(py_dict, ( "2_carry", "UpdateTextField", {"value": 1} ), 1) 33 | # py_dict['2_carry'].update({"value" : '1', "locked" : True}) 34 | 35 | # agent.train(py_dict, ( "2_answer", "UpdateTextField", {"value": 0} ), 1) 36 | # py_dict['2_answer'].update({"value" : '0', "locked" : True}) 37 | 38 | # agent.train(py_dict, ( "3_carry", "UpdateTextField", {"value": 1} ), 1) 39 | # py_dict['3_carry'].update({"value" : '1', "locked" : True}) 40 | 41 | # agent.train(py_dict, ( "3_answer", "UpdateTextField", {"value": 1} ), 1) 42 | # py_dict['3_answer'].update({"value" : '1', "locked" : True}) 43 | 44 | # print(agent.act(py_dict)) 
45 | 46 | # py_dict = new_mc_addition_state(456,582) 47 | 48 | # print(agent.act(py_dict)) 49 | # act_all = agent.act_all(py_dict) 50 | # print("L", act_all) 51 | # # print(agent.act_all(py_dict)) 52 | # # print(agent.act_rollout(py_dict)) 53 | 54 | # print("---------------------------") 55 | 56 | # # Test w/ foci 57 | # agent = CREAgent(feature_set=[], function_set=function_set, 58 | # where="antiunify") 59 | 60 | # py_dict = new_mc_addition_state(567,491) 61 | # print(agent.act(py_dict)) 62 | 63 | # agent.train(py_dict, ( "0_answer", "UpdateTextField", {"value": 8} ), 1, 64 | # arg_foci=["0_upper", '0_lower']) 65 | # py_dict['0_answer'].update({"value" : '8', "locked" : True}) 66 | 67 | # agent.train(py_dict, ( "1_answer", "UpdateTextField", {"value": 5} ), 1, 68 | # arg_foci=["1_upper", '1_lower']) 69 | # py_dict['1_answer'].update({"value" : '5', "locked" : True},) 70 | 71 | # agent.train(py_dict, ( "2_carry", "UpdateTextField", {"value": 1} ), 1, 72 | # arg_foci=["1_upper", '1_lower']) 73 | # py_dict['2_carry'].update({"value" : '1', "locked" : True}) 74 | 75 | # agent.train(py_dict, ( "2_answer", "UpdateTextField", {"value": 0} ), 1, 76 | # arg_foci=['2_carry', "2_upper", '2_lower']) 77 | # py_dict['2_answer'].update({"value" : '0', "locked" : True}) 78 | 79 | # agent.train(py_dict, ( "3_carry", "UpdateTextField", {"value": 1} ), 1, 80 | # arg_foci=['2_carry', "2_upper", '2_lower']) 81 | # py_dict['3_carry'].update({"value" : '1', "locked" : True}) 82 | 83 | # agent.train(py_dict, ( "3_answer", "UpdateTextField", {"value": 1} ), 1, 84 | # arg_foci=['3_carry']) 85 | # py_dict['3_answer'].update({"value" : '1', "locked" : True}) 86 | 87 | 88 | # py_dict = new_mc_addition_state(456,582) 89 | 90 | # print(agent.act(py_dict)) 91 | 92 | agent = CREAgent(feature_set=[], function_set=function_set, 93 | where="antiunify") 94 | 95 | py_dict = new_mc_addition_state(333,333) 96 | print(agent.act_all(py_dict)) 97 | 98 | agent.train(py_dict, ( "0_answer", 
"UpdateTextField", {"value": 6} ), 1, 99 | arg_foci=["0_upper", '0_lower']) 100 | py_dict['0_answer'].update({"value" : '6', "locked" : True}) 101 | 102 | agent.train(py_dict, ( "1_answer", "UpdateTextField", {"value": 6} ), 1, 103 | arg_foci=["1_upper", '1_lower']) 104 | py_dict['1_answer'].update({"value" : '6', "locked" : True}) 105 | 106 | agent.train(py_dict, ( "2_answer", "UpdateTextField", {"value": 6} ), 1, 107 | arg_foci=["2_upper", '2_lower']) 108 | py_dict['2_answer'].update({"value" : '6', "locked" : True}) 109 | 110 | py_dict = new_mc_addition_state(333,333) 111 | print(agent.act_all(py_dict)) 112 | 113 | # py_dict = new_mc_addition_state(456,582) 114 | 115 | # print(agent.act(py_dict)) 116 | # act_all = agent.act_all(py_dict) 117 | # print("L", act_all) 118 | 119 | def test_feedback_updating(): 120 | function_set = ["Add3", "Mod10", "Add", "Div10", "Copy"] 121 | # Test no FOCI 122 | agent = CREAgent(feature_set=[], function_set=function_set, 123 | where="antiunify") 124 | 125 | py_dict = new_mc_addition_state(567,491) 126 | print(agent.act(py_dict)) 127 | 128 | # Same as example above, but add preceeding negative examples that 129 | # should be overridden. 
130 | agent.train(py_dict, ( "0_answer", "UpdateTextField", {"value": 8} ), 1, arg_foci=["0_upper", '0_lower']) 131 | agent.train(py_dict, ( "0_answer", "UpdateTextField", {"value": 8} ), -1, arg_foci=["0_upper", '0_lower']) 132 | agent.train(py_dict, ( "0_answer", "UpdateTextField", {"value": 8} ), -1, arg_foci=["0_upper", '0_lower']) 133 | agent.train(py_dict, ( "0_answer", "UpdateTextField", {"value": 8} ), 1, arg_foci=["0_upper", '0_lower']) 134 | # py_dict['0_answer'].update({"value" : '8', "locked" : True}) 135 | 136 | # agent.train(py_dict, ( "1_answer", "UpdateTextField", {"value": 5} ), 1) 137 | # py_dict['1_answer'].update({"value" : '5', "locked" : True}) 138 | 139 | # agent.train(py_dict, ( "2_carry", "UpdateTextField", {"value": 1} ), 1) 140 | # py_dict['2_carry'].update({"value" : '1', "locked" : True}) 141 | 142 | # agent.train(py_dict, ( "2_answer", "UpdateTextField", {"value": 0} ), 1) 143 | # py_dict['2_answer'].update({"value" : '0', "locked" : True}) 144 | 145 | # agent.train(py_dict, ( "3_carry", "UpdateTextField", {"value": 1} ), 1) 146 | # py_dict['3_carry'].update({"value" : '1', "locked" : True}) 147 | 148 | # agent.train(py_dict, ( "3_answer", "UpdateTextField", {"value": 1} ), 1) 149 | # py_dict['3_answer'].update({"value" : '1', "locked" : True}) 150 | 151 | 152 | if __name__ == "__main__": 153 | import faulthandler; faulthandler.enable() 154 | # test_init_base_dipl() 155 | test_cre_agent() 156 | # test_feedback_updating() 157 | -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/tests/test_extending.py: -------------------------------------------------------------------------------- 1 | from apprentice.agents.cre_agents.extending import new_register_all, new_register_decorator, registries 2 | 3 | def test_register(): 4 | register_funcs = new_register_decorator("my_funcs") 5 | register_all_things = new_register_all("things", types=[dict], acceptor_funcs=[lambda x : 
isinstance(x,tuple)]) 6 | 7 | @register_funcs 8 | def foo(): 9 | return "foo" 10 | 11 | @register_funcs 12 | def foo(): 13 | return "foo" 14 | 15 | assert registries['my_funcs']['foo'] == foo 16 | 17 | with register_all_things() as things: 18 | a = {1:0} 19 | b = {'q':0} 20 | c = ('eggo',) 21 | d = ['noodles'] 22 | e = ['noodlez'] 23 | 24 | assert len(things) == 3 25 | assert {1:0} in things 26 | assert {'q':0} in things 27 | assert ('eggo',) in things 28 | 29 | print("things:", things) 30 | 31 | print(registries) 32 | 33 | with register_all_things as things: 34 | x = {7:0} 35 | b = {'q' : 0} 36 | z = ('eggo',) 37 | q = ['noodles'] 38 | v = ['noodlez'] 39 | 40 | print("things:", things) 41 | assert len(things) == 3 42 | assert {7:0} in things 43 | assert {'q':0} in things 44 | assert ('eggo',) in things 45 | 46 | register_all_stuff = new_register_all("stuff", types=list) 47 | 48 | def reg(): 49 | r = ["r"] 50 | q = ["Q"] 51 | register_all_stuff() 52 | 53 | reg() 54 | print(registries) 55 | stuff_registry = registries['stuff'] 56 | assert len(stuff_registry) == 2 57 | assert stuff_registry['r'] == ['r'] 58 | assert stuff_registry['q'] == ["Q"] 59 | 60 | print(things) 61 | print(registries) 62 | 63 | if __name__ == "__main__": 64 | test_register() 65 | 66 | 67 | -------------------------------------------------------------------------------- /apprentice/agents/cre_agents/tests/test_state.py: -------------------------------------------------------------------------------- 1 | from apprentice.agents.cre_agents.environment import Component, Button, TextField, Container 2 | from apprentice.agents.cre_agents.state import encode_neighbors, State 3 | 4 | from numba.types import unicode_type 5 | from cre import MemSet 6 | from cre.transform import MemSetBuilder, Flattener, FeatureApplier, RelativeEncoder, Vectorizer 7 | from cre.default_funcs import Equals 8 | 9 | def new_mc_addition_state(upper, lower): 10 | upper, lower = str(upper), str(lower) 11 | n = 
max(len(upper),len(lower)) 12 | 13 | tf_config = {"type": "TextField", "width" : 100, "height" : 100, "value" : ""} 14 | # comp_config = tf_config 15 | # comp_config = {"type": "Component", "width" : 100, "height" : 100} 16 | hidden_config = {**tf_config, 'locked' : True} 17 | comp_config = hidden_config 18 | 19 | d_state = { 20 | "operator" : {"id" : "operator", "x" :-110,"y" : 220 , **comp_config}, 21 | # "line" : {"id" : "line", "x" :0, "y" : 325 , **comp_config, "height" : 5}, 22 | "done" : {"id" : "done", "x" :0, "y" : 440 , **comp_config, "type": "Button"}, 23 | "hidey1" : {"id" : "hidey1", "x" :n * 110, "y" : 0 , **hidden_config}, 24 | "hidey2" : {"id" : "hidey2", "x" :0, "y" : 110 , **hidden_config}, 25 | "hidey3" : {"id" : "hidey3", "x" :0, "y" : 220 , **hidden_config}, 26 | } 27 | 28 | for i in range(n): 29 | offset = (n - i) * 110 30 | d_state.update({ 31 | f"{i}_carry": {"id" : f"{i}_carry", "x" :offset, "y" : 0 , **tf_config}, 32 | f"{i}_upper": {"id" : f"{i}_upper", "x" :offset, "y" : 110 , "locked" : True, **tf_config}, 33 | f"{i}_lower": {"id" : f"{i}_lower", "x" :offset, "y" : 220 , "locked" : True, **tf_config}, 34 | f"{i}_answer": {"id" : f"{i}_answer", "x" :offset, "y" : 330 , **tf_config}, 35 | }) 36 | 37 | del d_state["0_carry"] 38 | 39 | d_state.update({ 40 | f"{n}_carry": {"id" : f"{n}_carry", "x" :0, "y" : 0 , **tf_config}, 41 | f"{n}_answer": {"id" : f"{n}_answer", "x" :0, "y" : 330 , **tf_config}, 42 | }) 43 | 44 | for i,c in enumerate(reversed(upper)): 45 | d_state[f'{i}_upper']['value'] = c 46 | 47 | for i,c in enumerate(reversed(lower)): 48 | d_state[f'{i}_lower']['value'] = c 49 | 50 | 51 | # d_state = encode_neighbors(d_state) 52 | 53 | # pprint(d_state) 54 | return d_state 55 | 56 | 57 | def test_encode_neighbors(): 58 | pass 59 | 60 | def test_flatten_featurize(): 61 | agent = object() 62 | state_cls = State(agent) 63 | 64 | 65 | fl = Flattener((Component, Button, TextField, Container)) 66 | fe = 
FeatureApplier([Equals(unicode_type, unicode_type)]) 67 | 68 | @state_cls.register_transform(is_incremental=True, prereqs=['working_memory']) 69 | def flat(state): 70 | wm = state.get('working_memory') 71 | return fl(wm) 72 | 73 | @state_cls.register_transform(is_incremental=True, prereqs=['flat']) 74 | def flat_featurized(state): 75 | flat = state.get('flat') 76 | return fe(flat) 77 | 78 | # print(transfrom_registry) 79 | 80 | a = TextField(id="a",value="a") 81 | b = TextField(id="b",value="b") 82 | c = TextField(id="c",value="c") 83 | d = TextField(id="d",value="a") 84 | wm = MemSet() 85 | wm.declare(a) 86 | wm.declare(b) 87 | wm.declare(c) 88 | wm.declare(d) 89 | 90 | state = state_cls() 91 | state.set("working_memory", wm) 92 | flat = state.get("flat") 93 | feat = state.get("flat_featurized") 94 | 95 | assert set([fact.val for fact in flat]) == {False, "a", "b", "c"} 96 | assert set([fact.val for fact in feat]) == {True, False, "a", "b", "c"} 97 | 98 | print(flat) 99 | print(feat) 100 | 101 | a = TextField(id="a",value="A") 102 | b = TextField(id="b",value="B") 103 | c = TextField(id="c",value="A") 104 | d = TextField(id="d",value="A") 105 | wm = MemSet() 106 | wm.declare(a) 107 | wm.declare(b) 108 | wm.declare(c) 109 | wm.declare(d) 110 | 111 | state.set("working_memory", wm) 112 | flat = state.get("flat") 113 | feat = state.get("flat_featurized") 114 | 115 | assert set([fact.val for fact in flat]) == {False, "A", "B"} 116 | assert set([fact.val for fact in feat]) == {True, False, "A", "B"} 117 | 118 | print(flat) 119 | print(feat) 120 | 121 | 122 | def test_full_when_pipeline(): 123 | from numba.types import f8, boolean, string 124 | from cre import Var 125 | py_dicts = new_mc_addition_state(567,891) 126 | print(py_dicts) 127 | 128 | agent = object() 129 | state_cls = State(agent) 130 | 131 | fact_types = (Component, Button, TextField, Container) 132 | val_types = [f8, string, boolean] 133 | 134 | msb = MemSetBuilder() 135 | fl = Flattener(fact_types) 136 | 
fe = FeatureApplier([Equals(string, string)]) 137 | re = RelativeEncoder(fact_types) 138 | ve = Vectorizer(val_types) 139 | 140 | @state_cls.register_transform(is_incremental=True, prereqs=['working_memory']) 141 | def flat(state): 142 | wm = state.get('working_memory') 143 | return fl(wm) 144 | 145 | @state_cls.register_transform(is_incremental=True, prereqs=['flat']) 146 | def flat_featurized(state): 147 | flat = state.get('flat') 148 | return fe(flat) 149 | 150 | state = state_cls() 151 | py_dicts = encode_neighbors(py_dicts) 152 | wm = msb(py_dicts) 153 | state.set("working_memory", wm) 154 | 155 | wm = state.get('working_memory') 156 | flat_featurized = state.get('flat_featurized') 157 | 158 | _vars = [Var(TextField,"sel")] 159 | facts = [wm.get_fact(id='0_answer')] 160 | 161 | re.set_in_memset(wm) 162 | rel = re.encode_relative_to(flat_featurized, facts, _vars) 163 | 164 | vec = ve(rel) 165 | 166 | print(vec) 167 | 168 | 169 | 170 | if __name__ == "__main__": 171 | test_flatten_featurize() 172 | test_full_when_pipeline() 173 | -------------------------------------------------------------------------------- /apprentice/agents/diff_base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Dict 3 | from jsondiff import diff 4 | 5 | from apprentice.agents.base import BaseAgent 6 | from apprentice.working_memory.representation import Sai 7 | 8 | 9 | class DiffBaseAgent(BaseAgent): 10 | prior_state = {} 11 | 12 | def request(self, state: Dict, **kwargs) -> Dict: 13 | """ 14 | Returns a dict containing a Selection, Action, Input. 
15 | 16 | :param state: a state represented as a dict (parsed from JSON) 17 | """ 18 | d = diff(self.prior_state, state) 19 | self.prior_state = state 20 | return self.request_diff(d) 21 | 22 | @abstractmethod 23 | def request_diff(self, state_diff: Dict) -> Dict: 24 | """ 25 | :param diff: a diff object that is the output of JSON diff 26 | """ 27 | pass 28 | 29 | def train(self, state: Dict, sai: Sai, reward: float, next_state: Dict, 30 | **kwargs): 31 | """ 32 | Accepts a JSON/Dict object representing the state, 33 | a JSON/Dict object representing the state after the SAI is invoked, 34 | a string representing the skill label, 35 | a list of strings representing the foas, 36 | a string representation the selection action and inputs, 37 | a reward 38 | """ 39 | state_diff = diff(self.prior_state, state) 40 | next_state_diff = diff(state, next_state) 41 | self.prior_state = next_state 42 | return self.train_diff(state_diff, next_state_diff, sai, reward) 43 | 44 | @abstractmethod 45 | def train_diff(self, state_diff, next_state_diff, sai, reward): 46 | """ 47 | Updates the state by some provided diff, then trains on the provided 48 | demonstration in this state. 
49 | """ 50 | pass 51 | 52 | 53 | if __name__ == "__main__": 54 | pass 55 | -------------------------------------------------------------------------------- /apprentice/agents/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from itertools import product 3 | from random import uniform 4 | 5 | def weighted_choice(choices): 6 | total = sum(w for c, w in choices) 7 | r = uniform(0, total) 8 | upto = 0 9 | for c, w in choices: 10 | if upto + w >= r: 11 | return c, w 12 | upto += w 13 | assert False, "Shouldn't get here" 14 | 15 | def gen_varnames(start=0, end=float('inf')): 16 | while start < end: 17 | var = "" 18 | val = start 19 | while val > 25: 20 | r = val % 26 21 | val = val // 26 22 | var = chr(r + ord('A')) + var 23 | if var == "": 24 | var = chr(val + ord('A')) + var 25 | else: 26 | var = chr(val-1 + ord('A')) + var 27 | yield var 28 | start += 1 29 | 30 | def tup_sai(selection,action,inputs): 31 | sai = ['sai'] 32 | sai.append(action) 33 | sai.append(selection) 34 | 35 | if inputs is None: 36 | pass 37 | elif isinstance(inputs, list): 38 | sai.extend(inputs) 39 | else: 40 | sai.append(inputs) 41 | 42 | return tuple(sai) 43 | 44 | def compute_features(state, features): 45 | original_state = {a: state[a] for a in state} 46 | for feature in features: 47 | num_args = len(inspect.getargspec(features[feature]).args) 48 | if num_args < 1: 49 | raise Exception("Features must accept at least 1 argument") 50 | 51 | possible_args = [attr for attr in original_state] 52 | 53 | for tupled_args in product(possible_args, repeat=num_args): 54 | new_feature = (feature,) + tupled_args 55 | values = [state[attr] for attr in tupled_args] 56 | try: 57 | yield new_feature, features[feature](*values) 58 | except Exception as e: 59 | pass 60 | 61 | def parse_foas(foas): 62 | return [{'name':foa.split('|')[1], 'value':foa.split('|')[2]} for foa in foas] 63 | 
-------------------------------------------------------------------------------- /apprentice/custom_operators.py: -------------------------------------------------------------------------------- 1 | from apprentice.planners.fo_planner import Operator 2 | # from planners.VectorizedPlanner import BaseOperator 3 | 4 | ''' USAGE INSTRUCTIONS 5 | FO Operator Structure: Operator(
<header>, [<conditions>...], [<effects>...]) 6 | 7 | <header>
: ('<name>', '?<var_1>', ... , '?<var_n>') 8 | example : ('Add', '?x', '?y') 9 | <conditions> : [(('<attribute>', '?<var_1>'),'?<value_1>'), ... , 10 | (<func>, '?<value_1>', ...), ... 11 | ] 12 | example : [ (('value', '?x'), '?xv'), 13 | (('value', '?y'), '?yv'), 14 | (lambda x, y: x <= y, '?x', '?y') 15 | ] 16 | <effects> : [(<out_attribute>, 17 | ('<name>', ('<in_attribute1>', '?<var_1>'), ...), 18 | (<func>, '?<value_1>', ...) 19 | ), ...] 20 | example :[(('value', ('Add', ('value', '?x'), ('value', '?y'))), 21 | (int_float_add, '?xv', '?yv'))]) 22 | Full Example: 23 | def int_float_add(x, y): 24 | z = float(x) + float(y) 25 | if z.is_integer(): 26 | z = int(z) 27 | return str(z) 28 | 29 | add_rule = Operator(('Add', '?x', '?y'), 30 | [(('value', '?x'), '?xv'), 31 | (('value', '?y'), '?yv'), 32 | (lambda x, y: x <= y, '?x', '?y') 33 | ], 34 | [(('value', ('Add', ('value', '?x'), ('value', '?y'))), 35 | (int_float_add, '?xv', '?yv'))]) 36 | 37 | Note: You should explicitly register your operators so you can 38 | refer to them in your training.json, otherwise the name will 39 | be the same as the local variable 40 | example: Operator.register("Add") 41 | 42 | vvvvvvvvvvvvvvvvvvvv WRITE YOUR OPERATORS BELOW vvvvvvvvvvvvvvvvvvvvvvv ''' 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | # ^^^^^^^^^^^^^^ DEFINE ALL YOUR OPERATORS ABOVE THIS LINE ^^^^^^^^^^^^^^^^ 65 | for name,op in locals().copy().items(): 66 | if(isinstance(op, Operator)): 67 | Operator.register(name,op) 68 | -------------------------------------------------------------------------------- /apprentice/explain/__init__.py: -------------------------------------------------------------------------------- 1 | from .explanation import Explanation -------------------------------------------------------------------------------- /apprentice/explain/kill_engine.py: -------------------------------------------------------------------------------- 1 | from experta import Fact, KnowledgeEngine, Rule, MATCH, TEST, DefFacts 2 | 3 | 4 | class Depressed(Fact): 5 | pass 6 | 7 | 8 | class Hate(Fact):
class KillEngine(KnowledgeEngine):
    """
    Small experta knowledge engine whose rules chain from the initial facts
    to a Kill fact.

    The initial facts are declared by ``first``. The four rules below each
    print a message and declare one new fact:

    * ``hate_rule``:    Depressed(w)                        -> Hate(w, w)
    * ``possess_rule``: Buy(u, v)                           -> Possess(u, v)
    * ``weapon_rule``:  Gun(z)                              -> Weapon(z)
    * ``kill_rule``:    Hate(a, b), Possess(a, c), Weapon(c) -> Kill(a, b)
    """

    @DefFacts()
    def first(self):
        # Initial working-memory contents, yielded one fact at a time.
        yield Depressed("JOHN")
        yield Buy("JOHN", "OBJ1")
        yield Gun("OBJ1")

    # Matches when the same `a` both hates `b` and possesses some `c`
    # that is also a Weapon; `c` is matched but not passed to the method.
    @Rule(
        Hate(MATCH.a, MATCH.b), Possess(MATCH.a, MATCH.c), Weapon(MATCH.c))
    def kill_rule(self, a, b):
        print(a + ' kills ' + b)
        self.declare(Kill(a, b))

    # Every Depressed(w) fact produces a self-directed Hate(w, w) fact.
    @Rule(
        Depressed(MATCH.w))
    def hate_rule(self, w):
        print(w + ' hates ' + w)
        self.declare(Hate(w, w))

    # The TEST lambda always returns True, so it never filters a match.
    @Rule(
        Buy(MATCH.u, MATCH.v),
        TEST(lambda u: True))
    def possess_rule(self, u, v):
        print(u + ' possesses ' + v)
        self.declare(Possess(u, v))

    # As above, the TEST lambda is always True.
    @Rule(
        Gun(MATCH.z),
        TEST(lambda z: True))
    def weapon_rule(self, z):
        print(z + " is a weapon")
        self.declare(Weapon(z))
class IncrementalMany(object):
    """
    Incremental "how" learner that keeps a set of explanations, each paired
    with the list of examples it explains.

    ``explanations`` maps an explanation (as produced by the planner) to the
    examples it covers; ``examples`` is the list of every example seen since
    the last call to ``fit``.
    """

    def __init__(self, planner):
        self.planner = planner
        self.explanations = {}
        self.examples = []

    def ifit(self, example):
        """
        Incorporate a single example and return the explanation dictionary.

        The example is appended to every existing explanation that explains
        it. If none does and the example is marked correct, a new explanation
        is requested from the planner and seeded with this example plus any
        previously seen examples it also explains.
        """
        found = False
        for exp in self.explanations:
            if self.explains(exp, example):
                self.explanations[exp].append(example)
                found = True

        if not found and example['correct'] is True:
            # NOTE(review): `explains` passes a list of input values to
            # tup_sai while this passes example['inputs'] unchanged --
            # confirm both forms are accepted by tup_sai.
            sai = tup_sai(example['selection'], example['action'],
                          example['inputs'])
            exp = tuple(self.planner.explain_sai(example['limited_state'],
                                                 sai))[0]
            self.explanations[exp] = [example]
            for e in self.examples:
                if self.explains(exp, e):
                    self.explanations[exp].append(e)

        self.examples.append(example)
        self.remove_subsumed()
        return self.explanations

    def remove_subsumed(self):
        """
        Drop every explanation that is subsumed by another one.

        Each unordered pair is compared once; when both subsume each other
        the later one (in dictionary order) is the one removed.
        """
        unnecessary = set()

        explanations = list(self.explanations)
        for i, exp1 in enumerate(explanations):
            for exp2 in explanations[i+1:]:
                if self.subsumes(exp1, exp2):
                    unnecessary.add(exp2)
                elif self.subsumes(exp2, exp1):
                    unnecessary.add(exp1)

        for exp in unnecessary:
            del self.explanations[exp]

    def subsumes(self, exp1, exp2):
        """
        Checks if one explanation subsumes another. However, this only applies
        to the positive examples. The negative examples are kept around in
        self.examples, but we don't need explanations that only cover negative
        examples.
        """
        for e in self.explanations[exp2]:
            if e['correct'] is False:
                continue
            if e not in self.explanations[exp1]:
                return False
        return True

    def fit(self, examples):
        """
        Refit from scratch on ``examples`` and return the explanations.

        Both the explanations and the remembered examples are cleared first;
        otherwise examples from an earlier fit would be folded into the
        freshly built explanations by ``ifit``.
        """
        self.explanations = {}
        self.examples = []
        for e in examples:
            self.ifit(e)
        return self.explanations

    def explains(self, explanation, example):
        """
        Checks if an explanation successfully explains an example.

        Returns False (after printing diagnostics) when the plan cannot be
        executed against the example's limited state.
        """
        try:
            sai = tup_sai(example['selection'], example['action'],
                          [example['inputs'][a] for a in example['inputs']])
            grounded_plan = tuple([self.planner.execute_plan(ele,
                                   example['limited_state'])
                                   for ele in explanation])
            print()
            print(sai, 'VS', grounded_plan)
            print()
            return self.planner.is_sais_equal(grounded_plan, sai)
        except Exception as e:
            # Deliberate best-effort: a plan that fails to execute simply
            # does not explain the example.
            print(e)
            print('plan could not execute')
            pprint(explanation)
            pprint(example['limited_state'])
            return False
def get_how_learner(name):
    """
    Return the how-learner class registered under ``name``.

    The lookup key is ``name`` lower-cased with spaces and underscores
    removed; an unknown key raises KeyError from the registry lookup.
    """
    key = name.lower()
    for separator in (' ', '_'):
        key = key.replace(separator, '')
    return HOW_LEARNERS[key]
example): 14 | found = False 15 | for exp in self.explanations: 16 | if self.explains(exp, example): 17 | self.explanations[exp].append(example) 18 | found = True 19 | 20 | if not found and example['correct'] is True: 21 | sai = tup_sai(example['selection'], example['action'], 22 | example['inputs']) 23 | exp = tuple(self.planner.explain_sai(example['limited_state'], 24 | sai))[0] 25 | self.explanations[exp] = [example] 26 | for e in self.examples: 27 | if self.explains(exp, e): 28 | self.explanations[exp].append(e) 29 | 30 | self.examples.append(example) 31 | self.remove_subsumed() 32 | return self.explanations 33 | 34 | def remove_subsumed(self): 35 | unnecessary = set() 36 | 37 | explanations = list(self.explanations) 38 | for i, exp1 in enumerate(explanations): 39 | for exp2 in explanations[i+1:]: 40 | if self.subsumes(exp1, exp2): 41 | unnecessary.add(exp2) 42 | elif self.subsumes(exp2, exp1): 43 | unnecessary.add(exp1) 44 | 45 | for exp in unnecessary: 46 | del self.explanations[exp] 47 | 48 | def subsumes(self, exp1, exp2): 49 | """ 50 | Checks if one explanation subsumes another. However, this only applies 51 | to the positive examples. The negative examples are kept around in 52 | self.examples, but we don't need explainations that only cover negative 53 | examples. 
54 | """ 55 | for e in self.explanations[exp2]: 56 | if e['correct'] is False: 57 | continue 58 | if e not in self.explanations[exp1]: 59 | return False 60 | return True 61 | 62 | def fit(self, examples): 63 | self.explanations = {} 64 | for e in examples: 65 | self.ifit(e) 66 | return self.explanations 67 | 68 | def explains(self, explanation, example): 69 | """ 70 | Checks if an explanation successfully explains an example 71 | """ 72 | try: 73 | sai = tup_sai(example['selection'], example['action'], 74 | [example['inputs'][a] for a in example['inputs']]) 75 | grounded_plan = tuple([self.planner.execute_plan(ele, 76 | example['limited_state']) 77 | for ele in explanation]) 78 | print() 79 | print(sai, 'VS', grounded_plan) 80 | print() 81 | return self.planner.is_sais_equal(grounded_plan, sai) 82 | except Exception as e: 83 | print(e) 84 | print('plan could not execute') 85 | pprint(explanation) 86 | pprint(example['limited_state']) 87 | return False 88 | 89 | 90 | class SimStudentHow(IncrementalMany): 91 | 92 | def ifit(self, example): 93 | found = False 94 | 95 | for exp in list(self.explanations): 96 | if self.explains(exp, example): 97 | self.explanations[exp].append(example) 98 | found = True 99 | elif not found and example['correct'] is True: 100 | seed = self.explanations[exp][0] 101 | sai = tup_sai(seed['selection'], seed['action'], 102 | seed['inputs']) 103 | 104 | print("LIMITED STATE FOR HOW") 105 | pprint(seed['limited_state']) 106 | 107 | for new_exp in self.planner.explain_sai_iter(seed['limited_state'], 108 | sai): 109 | if not self.explains(new_exp, example): 110 | continue 111 | covers = True 112 | for e in self.explanations[exp][1:]: 113 | if e['correct'] and not self.explains(new_exp, e): 114 | covers = False 115 | break 116 | if covers: 117 | self.explanations[new_exp] = [example] 118 | for e in self.examples: 119 | if self.explains(new_exp, e): 120 | self.explanations[new_exp].append(e) 121 | found = True 122 | break 123 | if not found and 
class WhichLearner(object):
    """
    Tracks one utility sub-learner per right-hand side (rhs) and uses them
    to rank rhs candidates; also holds the rule used to cull explanations.

    When ``remove_utility_type`` is None or equal to ``utility_type``, the
    removal learners are the very same dictionary as the regular learners.
    """

    def __init__(self, agent, utility_type, explanation_choice,
                 remove_utility_type=None, **learner_kwargs):
        self.agent = agent
        rem_util_same = ((utility_type == remove_utility_type)
                         or remove_utility_type is None)
        self.utility_type = utility_type
        self.remove_utility_type = (utility_type if rem_util_same
                                    else remove_utility_type)
        self.learner_kwargs = learner_kwargs
        self.rhs_by_label = {}
        self.learners = {}
        # Share the dictionary when both utilities are the same type so a
        # single sub-learner serves both purposes.
        self.removal_learners = self.learners if rem_util_same else {}
        # Resolve the cull-rule name to its function once; the raw name is
        # not kept.
        self.explanation_choice = get_explanation_choice(explanation_choice)

    def add_rhs(self, rhs):
        """Register ``rhs``: create its sub-learner(s) and index it by label."""
        self.learners[rhs] = get_utility_sublearner(
            self, rhs, self.utility_type, **self.learner_kwargs)
        self.rhs_by_label.setdefault(rhs.label, []).append(rhs)
        if self.utility_type != self.remove_utility_type:
            self.removal_learners[rhs] = get_utility_sublearner(
                self, rhs, self.remove_utility_type, **self.learner_kwargs)

    def ifit(self, rhs, state, mapping, reward):
        """Forward one training signal to the sub-learner(s) of ``rhs``."""
        self.learners[rhs].ifit(state, mapping, reward)
        if self.utility_type != self.remove_utility_type:
            self.removal_learners[rhs].ifit(state, mapping, reward)

    def sort_by_utility(self, rhs_list, state):
        """Return ``rhs_list`` sorted from highest to lowest utility."""
        return sorted(rhs_list, reverse=True,
                      key=lambda x: self.get_utility(x, state))

    def select_how(self, expl_iter):
        """Apply the configured cull rule to the explanations."""
        return self.explanation_choice(expl_iter)

    def get_utility(self, rhs, state):
        """Utility of ``rhs`` in ``state`` according to its sub-learner."""
        return self.learners[rhs].utility(state)

    def get_removal_utility(self, rhs, state):
        """Utility of ``rhs`` in ``state`` according to its removal learner."""
        return self.removal_learners[rhs].utility(state)
def first(expl_iter):
    """
    Return a one-element list holding the first explanation.

    An empty iterable yields an empty list, matching the other cull rules,
    instead of leaking StopIteration to the caller.
    """
    for expl in expl_iter:
        return [expl]
    return []
def get_utility_sublearner(parent, rhs, name, **learner_kwargs):
    """
    Build the utility sub-learner registered under ``name`` for ``rhs``.

    The registry key is ``name`` lower-cased with spaces and underscores
    removed. The new sub-learner is constructed from ``learner_kwargs`` and
    then given ``rhs`` and the parent's ``agent`` as attributes.
    """
    key = name.lower().replace(' ', '').replace('_', '')
    learner_cls = WHICH_utility_AGENTS[key]
    sublearner = learner_cls(**learner_kwargs)
    sublearner.rhs = rhs
    sublearner.agent = parent.agent
    return sublearner
def weighted_choice(choices):
    """
    A weighted version of choice.

    ``choices`` is an iterable of ``(weight, item)`` pairs; one item is
    returned with probability proportional to its weight. Any iterable is
    accepted, including one-shot iterators and generators.

    Raises ValueError when ``choices`` is empty.
    """
    # Materialize once: the pairs are walked twice (sum, then selection).
    choices = list(choices)
    if not choices:
        raise ValueError("weighted_choice() needs at least one "
                         "(weight, choice) pair")

    total = sum(w for w, c in choices)
    r = uniform(0, total)
    upto = 0
    for w, c in choices:
        if upto + w >= r:
            return c
        upto += w
    # Explicit raise rather than `assert False`, which is stripped under -O
    # and would make the function silently return None.
    raise AssertionError("Shouldn't get here")
def rename(mapping, literal):
    """
    Return a copy of ``literal`` with every element found in ``mapping``
    replaced by its mapped value.

    Elements are looked up in ``mapping`` first; tuple elements that are not
    themselves keys are renamed recursively, and everything else is kept
    as-is. Unlike subst, this works with constants as well as variables.
    """
    renamed = []
    for element in literal:
        if element in mapping:
            renamed.append(mapping[element])
        elif isinstance(element, tuple):
            renamed.append(rename(mapping, element))
        else:
            renamed.append(element)
    return tuple(renamed)
class ValueNet(nn.Module):
    """
    State-value half of the actor critic network.

    Besides the state value it also hands back the hidden-layer activations
    so the action network can consume them.
    """

    def __init__(self, n_inputs: int, n_hidden: int = None):
        """
        Build a one-hidden-layer value network over ``n_inputs`` features.

        When ``n_hidden`` is omitted the hidden layer gets
        ``(n_inputs + 2) // 2`` units.
        """
        super().__init__()

        self.n_hidden = (n_inputs + 2) // 2 if n_hidden is None else n_hidden

        # Linear + ReLU feature layer shared with the action network.
        self.hidden = nn.Sequential(
            nn.Linear(n_inputs, self.n_hidden),
            nn.ReLU()
        )

        # Scalar state-value head on top of the hidden features.
        self.value = nn.Linear(self.n_hidden, 1)

    def forward(self, x):
        """
        Return a ``(state_value, hidden_activations)`` pair for ``x``.
        """
        hidden_out = self.hidden(x)
        state_value = self.value(hidden_out)
        return state_value, hidden_out
torch.cuda.is_available() else 20 | "cpu") 21 | self.gamma = gamma 22 | self.lr = lr 23 | 24 | self.state_size = state_size 25 | self.action_size = action_size 26 | self.hidden_size = hidden_size 27 | 28 | self.state_hasher = FeatureHasher(n_features=self.state_size) 29 | self.action_hasher = FeatureHasher(n_features=self.action_size) 30 | self.value_net = ACValueNet(self.state_size, self.hidden_size) 31 | self.action_net = ACActionNet(self.action_size, self.hidden_size, 32 | self.hidden_size) 33 | 34 | params = (list(self.value_net.parameters()) + 35 | list(self.action_net.parameters())) 36 | self.optimizer = torch.optim.Adam(params, lr=self.lr) 37 | 38 | def gen_state_vector(self, state: dict) -> np.ndarray: 39 | state = {str(a): state[a] for a in state} 40 | return self.state_hasher.transform([state]).toarray() 41 | 42 | def gen_action_vectors( 43 | self, actions: Collection[Activation]) -> np.ndarray: 44 | 45 | action_dicts = [] 46 | for action in actions: 47 | act_d = {} 48 | name = action.get_rule_name() 49 | act_d['rulename'] = name 50 | bindings = action.get_rule_bindings() 51 | for a, v in bindings.items(): 52 | if isinstance(v, bool): 53 | act_d[str(a)] = str(v) 54 | else: 55 | act_d[str(a)] = v 56 | action_dicts.append(act_d) 57 | 58 | return self.action_hasher.transform(action_dicts).toarray() 59 | 60 | def eval_all(self, state: dict, 61 | actions: Collection[Activation]) -> Collection[float]: 62 | pass 63 | 64 | def eval(self, state: dict, action: Activation) -> float: 65 | if state is None: 66 | return 0 67 | 68 | state_x = torch.from_numpy( 69 | self.gen_state_vector(state)).float().to(self.device) 70 | action_x = torch.from_numpy( 71 | self.gen_action_vectors([action])).float().to(self.device) 72 | 73 | with torch.no_grad(): 74 | state_val, state_hidden = self.value_net(state_x) 75 | action_val = self.action_net(action_x, state_hidden) 76 | return action_val[0].cpu().item() 77 | 78 | def update( 79 | self, 80 | state: dict, 81 | action: Activation, 
class DQN(nn.Module):
    """
    Feed-forward value estimator: one input vector in, a single value out.
    """

    def __init__(self, n_inputs: int, n_hidden: int = None):
        """
        Build a Linear -> ReLU -> Linear stack over ``n_inputs`` features.

        When ``n_hidden`` is omitted the hidden layer gets
        ``(n_inputs + 2) // 2`` units.
        """
        super().__init__()

        hidden_width = n_hidden if n_hidden is not None else (n_inputs + 2) // 2

        layers = [
            nn.Linear(n_inputs, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
        ]
        self.value = nn.Sequential(*layers)

    def forward(self, x):
        """Return the value estimate for ``x``."""
        estimate = self.value(x)
        return estimate
class FractionsActionHasher(object):
    """
    Vectorizes action dictionaries with a ``DictVectorizer`` that is fitted
    once, at construction time, on a fixed vocabulary of operators, field
    ids, derived add/multiply terms and rule names.
    """

    def __init__(self):
        operator_names = ['multiply', 'add']

        field_ids = [
            "JCommTable.R0C0", "JCommTable.R1C0", "JCommTable2.R0C0",
            "JCommTable3.R0C0", "JCommTable3.R1C0", "JCommTable4.R0C0",
            "JCommTable4.R1C0", "JCommTable5.R0C0", "JCommTable5.R1C0",
            "JCommTable6.R0C0", "JCommTable6.R1C0", "JCommTable7.R0C0",
            "JCommTable8.R0C0",
        ]

        rule_names = [
            "click_done",
            "check",
            "equal",
            "update_answer",
            "update_convert",
            "add",
            "multiply",
            "correct_multiply_num",
            "correct_multiply_denom",
            "correct_done",
            "correct_add_same_num",
            "correct_copy_same_denom",
            "correct_check",
            "correct_convert_num1",
            "correct_convert_num2",
            "correct_convert_denom1",
            "correct_convert_denom2",
            "correct_add_convert_num",
            "correct_copy_convert_denom",
        ]

        # One sample per operator, covering both operator slots.
        vocabulary = [{'fact-0: operator': op, 'fact-1: operator': op}
                      for op in operator_names]

        # One sample per field, covering every slot a field id can occupy.
        vocabulary.extend({'fact-0: id': field,
                           'fact-1: id': field,
                           'fact-0: ele1': field,
                           'fact-1: ele2': field}
                          for field in field_ids)

        # Derived terms: every ordered pair with left <= right, first as an
        # "add" term and then as a "multiply" term.
        for left in field_ids:
            for right in field_ids:
                if left > right:
                    continue
                for op in ('add', 'multiply'):
                    term = op + '(' + left + ', ' + right + ')'
                    vocabulary.append({'fact-0: id': term})

        vocabulary.extend({'rulename': rule} for rule in rule_names)

        self.dv = DictVectorizer()
        self.dv.fit(vocabulary)

    def transform(self, actions: List[dict]):
        """Return the fitted vectorizer's encoding of ``actions``."""
        encoded = self.dv.transform(actions)
        return encoded
class QLearner(WhenLearner):
    """
    Q-learning "when" learner that keeps one value approximator per rule
    name.

    Approximators are created lazily by calling ``func`` (``Tabular`` by
    default) and are fed the state dictionary extended with the bindings of
    the activation being scored.
    """

    def __init__(self, q_init=0.0, discount=0.99, learning_rate=0.9,
                 func=None):
        """
        :param q_init: value returned for rules that were never updated; also
            forwarded to every new approximator.
        :param discount: factor applied to the best next-state estimate.
        :param learning_rate: forwarded to every new approximator.
        :param func: callable accepting ``q_init`` and ``learning_rate``
            keyword arguments and returning an object with ``update`` and
            ``get_q`` methods. Defaults to ``Tabular``.
        """
        self.func = Tabular if func is None else func
        self.Q = {}
        self.q_init = q_init
        self.discount = discount
        self.learning_rate = learning_rate

    @staticmethod
    def _with_action_features(state: dict, action: Activation) -> dict:
        """Return a deep copy of ``state`` extended with the action bindings."""
        augmented = deepcopy(state)
        for attr, value in action.get_rule_bindings().items():
            # Booleans are stored as their string form; every other value is
            # stored unchanged.
            augmented[('ACTION_FEATURE', attr)] = (
                str(value) if isinstance(value, bool) else value)
        return augmented

    def evaluate(self, state: dict, action: Activation) -> float:
        """
        Estimate the value of taking ``action`` in ``state``.

        :returns: ``0`` when ``state`` is ``None``, ``q_init`` when the rule
            has no approximator yet, otherwise the approximator's estimate.
        """
        if state is None:
            return 0
        name = action.get_rule_name()
        if name not in self.Q:
            return self.q_init
        return self.Q[name].get_q(self._with_action_features(state, action))

    def update(
        self,
        state: dict,
        action: Activation,
        reward: float,
        next_state: dict,
        next_actions: Collection[Activation],
    ) -> None:
        """
        Apply one Q-learning backup for the observed transition.

        The target is ``reward + discount * max_a Q(next_state, a)``; the
        maximum is taken as 0 when ``next_actions`` is empty.
        """
        q_next_est = 0
        if len(next_actions) != 0:
            q_next_est = max(self.evaluate(next_state, a)
                             for a in next_actions)

        learned_reward = reward + self.discount * q_next_est
        name = action.get_rule_name()
        features = self._with_action_features(state, action)

        if name not in self.Q:
            self.Q[name] = self.func(q_init=self.q_init,
                                     learning_rate=self.learning_rate)

        self.Q[name].update(features, learned_reward)


class Tabular:
    """Lookup-table value function keyed by the full content of a state."""

    def __init__(self, q_init=0.6, learning_rate=0.9):
        self.row = {}
        self.q_init = q_init
        self.alpha = learning_rate

    @staticmethod
    def _key(state):
        """
        Build a hashable table key from ``state``.

        The key covers the (attribute, value) pairs. Hashing the dict itself
        would only capture its keys, so states that differ only in their
        values would share one table entry.
        """
        try:
            return frozenset(state.items())
        except TypeError:
            # Some value is unhashable; fall back to its repr so the state
            # can still be stored.
            return frozenset((k, repr(v)) for k, v in state.items())

    def update(self, state, learned_reward):
        """Move the stored value for ``state`` toward ``learned_reward``."""
        s = self._key(state)
        current = self.row.setdefault(s, self.q_init)
        self.row[s] = ((1 - self.alpha) * current +
                       self.alpha * learned_reward)

    def get_q(self, state):
        """Return the stored value for ``state``, creating it at ``q_init``."""
        return self.row.setdefault(self._key(state), self.q_init)

    def __str__(self):
        return str(self.row)


class LinearFunc:
    """
    Value function backed by an ``SGDRegressor`` that is refitted on every
    stored (state, target) pair each time ``update`` is called.
    """

    def __init__(self, q_init=0, learning_rate=0.9):
        self.clf = SGDRegressor(shuffle=False, max_iter=1,
                                learning_rate="constant", eta0=learning_rate)
        self.dv = DictVectorizer(sort=False)
        self.X = []
        self.Y = []
        self.q_init = q_init

    def update(self, state, learned_reward):
        """Record the pair and refit the regressor on the whole history."""
        self.X.append(state)
        self.Y.append(learned_reward)
        self.clf.fit(self.dv.fit_transform(self.X), self.Y)

    def get_q(self, state):
        """Return ``q_init`` before any update, else the regressor's output."""
        if len(self.X) == 0:
            return self.q_init
        x = self.dv.transform([state])
        return self.clf.predict(x)[0]
Transition = namedtuple(
    'Transition', ('state', 'action', 'reward', 'next_state', 'next_actions'))


class ReplayMemory(object):
    """
    Fixed-capacity ring buffer of ``Transition`` records.

    While the buffer is below capacity ``position`` equals ``len(memory)``,
    so new records are appended. Once full, ``position`` cycles and the
    record at that index is overwritten.
    """

    def __init__(self, capacity):
        """
        Constructor.

        :param capacity: maximum number of transitions kept.
        """
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """
        Push an experience into the replay memory.

        :param args: the five ``Transition`` fields, in order.
        """
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def set_capacity(self, capacity):
        """
        Grow or shrink the memory to the specified capacity.

        Shrinking keeps the first ``capacity`` stored records. No placeholder
        entries are added, so ``sample`` only ever returns real transitions.
        """
        self.capacity = capacity
        self.memory = self.memory[:capacity]
        if len(self.memory) < self.capacity:
            # Not full: the next push must append at the end rather than
            # overwrite an existing record.
            self.position = len(self.memory)
        elif self.position >= capacity:
            self.position = 0

    def sample(self, batch_size):
        """
        Retrieve ``batch_size`` distinct stored transitions at random.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """
        Number of stored transitions.
        """
        return len(self.memory)
w/ numba 21 | formatter: standard 22 | stream: ext://sys.stdout 23 | 24 | # info_file_handler: 25 | # class: logging.handlers.RotatingFileHandler 26 | # level: INFO 27 | # formatter: standard 28 | # filename: /tmp/info.log 29 | # maxBytes: 10485760 # 10MB 30 | # backupCount: 20 31 | # encoding: utf8 32 | # 33 | # error_file_handler: 34 | # class: logging.handlers.RotatingFileHandler 35 | # level: ERROR 36 | # formatter: error 37 | # filename: /tmp/errors.log 38 | # maxBytes: 10485760 # 10MB 39 | # backupCount: 20 40 | # encoding: utf8 41 | # 42 | # debug_file_handler: 43 | # class: logging.handlers.RotatingFileHandler 44 | # level: DEBUG 45 | # formatter: standard 46 | # filename: /tmp/debug.log 47 | # maxBytes: 10485760 # 10MB 48 | # backupCount: 20 49 | # encoding: utf8 50 | # 51 | # critical_file_handler: 52 | # class: logging.handlers.RotatingFileHandler 53 | # level: CRITICAL 54 | # formatter: standard 55 | # filename: /tmp/critical.log 56 | # maxBytes: 10485760 # 10MB 57 | # backupCount: 20 58 | # encoding: utf8 59 | # 60 | # warn_file_handler: 61 | # class: logging.handlers.RotatingFileHandler 62 | # level: WARN 63 | # formatter: standard 64 | # filename: /tmp/warn.log 65 | # maxBytes: 10485760 # 10MB 66 | # backupCount: 20 67 | # encoding: utf8 68 | 69 | root: 70 | level: NOTSET 71 | handlers: [console] 72 | propogate: yes 73 | 74 | loggers: 75 | : 76 | level: INFO 77 | handlers: [console] 78 | #, info_file_handler, error_file_handler, critical_file_handler, debug_file_handler, warn_file_handler] 79 | propogate: no 80 | 81 | # : 82 | # level: DEBUG 83 | # handlers: [info_file_handler, error_file_handler, critical_file_handler, debug_file_handler, warn_file_handler] 84 | # propogate: yes 85 | -------------------------------------------------------------------------------- /apprentice/planners/__init__.py: -------------------------------------------------------------------------------- 
# Registry of planner classes, keyed by normalised planner name.
PLANNERS = {}


class BasePlanner(object):
    """Interface for planners; every method raises ``NotImplementedError``."""

    def how_search(self, state, sai):
        raise NotImplementedError()

    def apply_featureset(self, state):
        raise NotImplementedError()

    def eval_expression(self, x, mapping, state):
        raise NotImplementedError()

    # NOTE(review): declared without ``self``; left as is because callers or
    # overrides elsewhere may rely on this signature - confirm before changing.
    def resolve_operators(operators):
        raise NotImplementedError()

    def unify_op(self, state, op, sai, foci_of_attention=None):
        raise NotImplementedError()


def get_planner_class(name):
    """
    Look up a planner class in ``PLANNERS`` by name.

    The name is lower-cased and stripped of spaces and underscores before the
    lookup, so "Fo Planner", "fo_planner" and "foplanner" all match.

    :param name: planner name in any of the accepted spellings.
    :returns: the entry registered in ``PLANNERS`` under the normalised name.
    :raises KeyError: if no planner is registered under that name.
    """
    key = name.lower().replace(' ', '').replace('_', '')

    if key == "vectorized":
        # Imported lazily; presumably the module registers itself in
        # PLANNERS on import - verify against VectorizedPlanner. Comparing
        # the normalised key means "Vectorized" triggers the import too.
        from apprentice.planners.VectorizedPlanner import VectorizedPlanner  # noqa: F401

    try:
        return PLANNERS[key]
    except KeyError:
        raise KeyError("Unknown planner %r; registered planners: %s"
                       % (name, sorted(PLANNERS))) from None
class Factory:
    """Common base class of the experta adapter factories."""
    pass


class ExpertaFactFactory(Factory):
    """Converts between plain dictionaries and ``ex.Fact`` objects."""

    def build(self, _dict: dict) -> ex.Fact:
        """Create a fact whose fields are the items of ``_dict``."""
        # Unpack the argument, not the builtin ``dict`` type.
        return ex.Fact(**_dict)

    def from_ex_fact(self, _fact: ex.Fact) -> dict:
        """Return the dictionary form of ``_fact``."""
        return _fact.as_dict()

    def to_ex_fact(self, _dict: dict) -> ex.Fact:
        """Alias of ``build``."""
        return self.build(_dict)


class ExpertaConditionFactory(Factory):
    """Converts rule conditions between experta objects and plain data."""

    def build(self, _tuple: tuple) -> Any:
        """Return ``_tuple`` unchanged."""
        return _tuple

    def validate(self):
        raise NotImplementedError

    def from_ex_condition(self, ex_condition: tuple) -> Any:
        """
        Recursively replace every ``ex.Fact`` in ``ex_condition`` with its
        dictionary form, keeping the surrounding tuple structure.

        :raises AssertionError: on an element that is neither a tuple, a
            fact nor a callable.
        """

        def c2r(c):
            # Exact-type check: plain tuples stay plain, while tuple
            # subclasses are rebuilt through their own constructor below.
            if type(c) is tuple:
                return tuple(c2r(_) for _ in c)
            if isinstance(c, tuple):
                r = tuple(c2r(_) for _ in c)
                return c.__class__(*r)
            if isinstance(c, ex.Fact):
                return c.as_dict()
            if callable(c):
                return c
            # Raised explicitly so the check survives ``python -O``.
            raise AssertionError(
                "unsupported condition element: %r" % (c,))

        return c2r(ex_condition)

    def to_ex_condition(self, ex_rule: ex.Rule) -> Any:
        """
        Inverse of ``from_ex_condition``: recursively replace every
        dictionary with an ``ex.Fact``, keeping the tuple structure.

        :raises AssertionError: on an element that is neither a tuple, a
            dictionary nor a callable.
        """

        def r2c(c):
            if type(c) is tuple:
                return tuple(r2c(_) for _ in c)
            if isinstance(c, tuple):
                r = tuple(r2c(_) for _ in c)
                return c.__class__(*r)
            if isinstance(c, dict):
                return ex.Fact(**c)
            if callable(c):
                return c
            raise AssertionError(
                "unsupported condition element: %r" % (c,))

        return r2c(ex_rule)
class ExpertaSkillFactory(Factory):
    """Converts between ``Skill`` objects and experta rules."""

    def __init__(self, _ke: ex.KnowledgeEngine):
        """
        :param _ke: knowledge engine that converted rules are attached to.
        """
        self._ke = _ke
        self.condition_factory = ExpertaConditionFactory()

    def build(self, _condition: Any,
              _function: Callable,
              _name: str = None) -> Skill:
        """
        Create a skill from a condition and a function.

        ``_name`` is accepted for interface compatibility but is not passed
        on to ``Skill``.
        """
        return Skill(_condition, _function)

    def from_ex_rule(self, _rule: ex.Rule) -> Skill:
        """Build a skill from the rule's conditions and wrapped function."""
        return self.build(_rule._args,
                          _rule._wrapped,
                          _name=_rule._wrapped.__name__)

    # The annotation names the class; the previous ``ex.Rule()`` built a
    # throwaway rule instance every time this method was defined.
    def to_ex_rule(self, _skill: Skill) -> ex.Rule:
        """Build an experta rule from ``_skill`` and bind it to the engine."""
        cond = self.condition_factory.to_ex_condition(_skill.conditions)
        # Create the rule from the conditions, then call it on the function
        # to wrap it, mirroring decorator-style use.
        rule = ex.Rule.__new__(ex.Rule, *cond)(_skill.function_)
        rule.ke = self._ke
        rule._wrapped_self = self._ke
        return rule


class ExpertaActivationFactory(Factory):
    """Converts between ``Activation`` objects and experta activations."""

    def __init__(self, _ke: ex.KnowledgeEngine):
        """
        :param _ke: knowledge engine handed to the nested skill factory.
        """
        self._ke = _ke
        self.skill_factory = ExpertaSkillFactory(_ke)

    def build(self, _skill: Skill, _context: dict):
        """Create an ``Activation`` from a skill and its context."""
        return Activation(_skill, _context)

    def from_ex_activation(
            self,
            ex_activation: ex.activation.Activation) -> Activation:
        """Convert an experta activation, including its rule."""
        return self.build(
            _skill=self.skill_factory.from_ex_rule(ex_activation.rule),
            _context=ex_activation.context)

    def to_ex_activation(
            self,
            _activation: Activation) -> ex.activation.Activation:
        """Convert an ``Activation`` back into an experta activation."""
        return ex.activation.Activation(
            self.skill_factory.to_ex_rule(_activation.skill),
            set(_activation.context.values()), _activation.context)

    def fire(self):
        # NOTE(review): ``self.base`` is not assigned anywhere in this class,
        # so this call looks like it raises AttributeError - confirm the
        # intended target before relying on it.
        self.base.fire()
class ExpertaWorkingMemory(WorkingMemory):
    """Working memory backed by an experta ``KnowledgeEngine``."""

    def __init__(self, ke=None, reset=True):
        """
        :param ke: knowledge engine to wrap; a new one is created if omitted.
        :param reset: when true, ``ke.reset()`` is called before use.
        """
        if ke is None:
            ke = ex.engine.KnowledgeEngine()
        self.ke = ke
        if reset:
            self.ke.reset()
        self.skill_factory = ExpertaSkillFactory(ke)
        self.activation_factory = ExpertaActivationFactory(ke)
        self.condition_factory = ExpertaConditionFactory()
        super().__init__()

    def step(self):
        self.ke.step()

    def output(self):
        raise NotImplementedError

    @property
    def facts(self):
        """Engine facts as ``{engine key: fact dictionary}``."""
        labeled_facts = {k: v.as_dict() for k, v in self.ke.facts.items()}
        return labeled_facts

    @property
    def state(self):
        """
        Flatten the engine facts into ``{(attribute, fact id): value}``.

        Special experta fields are dropped, booleans are stored as strings,
        and facts without an ``'id'`` field are skipped.
        """
        state = {}
        for fact in self.ke.facts.values():
            fields = {}
            for k, v in fact.as_dict().items():
                if ex.Fact.is_special(k):
                    continue
                fields[k] = str(v) if isinstance(v, bool) else v

            if 'id' not in fields:
                continue
            for k, v in fields.items():
                if k == 'id':
                    continue
                state[(k, fields['id'])] = v

        return state

    def add_fact(self, key: object, fact: dict) -> None:
        """Declare ``fact`` in the engine and remember it under ``key``."""
        f = ex.Fact(**fact)
        self.ke.declare(f)
        self.lookup[key] = f

    def remove_fact(self, key: object) -> bool:
        """
        Retract the fact stored under ``key``.

        :returns: ``False`` if ``key`` is unknown, ``True`` once the fact
            has been retracted.
        """
        if key not in self.lookup:
            return False
        f = self.lookup[key]
        self.ke.retract(f)
        del self.lookup[key]
        return True

    def update_fact(self, key: object, diff: dict) -> None:
        """Replace the fact under ``key`` with its diffed version."""
        old_fact = self.lookup[key]
        new_fact = apply_diff_to_fact(old_fact, diff)
        self.ke.retract(old_fact)
        self.ke.declare(new_fact)
        self.lookup[key] = new_fact

    @property
    def skills(self):
        """Yield a ``Skill`` for every rule known to the engine."""
        for rule in self.ke.get_rules():
            yield self.skill_factory.from_ex_rule(rule)

    def add_skill(self, skill: Skill):
        rule = self.skill_factory.to_ex_rule(skill)
        self.add_rule(rule)

    def add_rule(self, rule: ex.Rule):
        """
        Attach ``rule`` to the engine under its function name, then
        re-initialise the matcher and reset the engine.
        """
        setattr(self.ke, rule._wrapped.__name__, rule)
        rule.ke = self.ke
        rule._wrapped_self = self.ke
        self.ke.matcher.__init__(self.ke)
        self.ke.reset()  # todo: not sure if this is necessary

    def update_skill(self, skill: Skill):
        self.add_skill(skill)

    @property
    def activations(self):
        """Yield an ``Activation`` for every activation on the agenda."""
        for a in self.ke.agenda.activations:
            yield self.activation_factory.from_ex_activation(a)

    def run(self):
        self.ke.run()


def apply_diff_to_fact(fact: ex.Fact, diff: dict) -> ex.Fact:
    """
    Return a new fact equal to ``fact`` with the jsondiff ``diff`` applied.

    A ``replace`` entry wins outright; otherwise fields listed under
    ``delete`` are dropped and every other entry of ``diff`` overrides or
    adds a field.
    """
    if jsondiff.symbols.replace in diff:
        return ex.Fact(**diff[jsondiff.symbols.replace])

    new_fact = {}
    for k in fact:
        if (jsondiff.symbols.delete in diff and
                k in diff[jsondiff.symbols.delete]):
            continue
        new_fact[k] = fact[k]

    for k in diff:
        if k is not jsondiff.symbols.delete:
            new_fact[k] = diff[k]

    return ex.Fact(**new_fact)
class WorkingMemory(metaclass=ABCMeta):
    """
    Abstract interface every working-memory backend implements.
    """
    def __init__(self, ke=None, reset=True):
        # Facts currently held, indexed by the key they were added under.
        self.lookup = {}

    def build_skill(self, _condition: Any,
                    _function: Callable) -> Skill:
        """
        Create a skill through the backend's ``skill_factory``.
        """
        factory = self.skill_factory
        return factory.build(_condition, _function)

    def update(self, diff: Dict) -> None:
        """
        Apply a jsondiff-style diff to the stored facts.

        Simplifying assumptions:
            - the state maps each fact key to a dictionary of fact values
            - facts are not nested

        :param diff: a diff object generated by JSON diff
        """
        for key, change in diff.items():
            if key is jsondiff.symbols.replace:
                # Wholesale replacement: drop everything, then add the new
                # facts one by one.
                for stale in list(self.lookup):
                    self.remove_fact(stale)
                for new_key, fact in change.items():
                    self.add_fact(new_key, fact)
                continue

            if key is jsondiff.symbols.delete:
                for doomed in change:
                    self.remove_fact(doomed)
                continue

            if key in self.lookup:
                self.update_fact(key, change)
            else:
                self.add_fact(key, change)

    @property
    @abstractmethod
    def facts(self) -> Collection[dict]:
        """
        The facts currently in working memory.
        """
        pass

    @property
    @abstractmethod
    def skills(self) -> Collection[Skill]:
        """
        The skills currently in working memory.
        """
        pass

    def add_facts(self, facts: dict) -> None:
        """
        Add every ``key: fact`` pair of ``facts`` via ``add_fact``.
        """
        for key, fact in facts.items():
            self.add_fact(key, fact)

    @abstractmethod
    def add_fact(self, key: object, fact: dict) -> None:
        """
        Store a fact.

        :param key: a hashable key for storing the fact
        :param fact: the fact to be added
        """
        pass

    @abstractmethod
    def remove_fact(self, key: object) -> bool:
        """
        Discard a fact.

        :param key: the key the fact was stored under
        :returns false if fact does not exist, true if successfully removed
        """
        pass

    @abstractmethod
    def update_fact(self, key: object, diff: dict) -> None:
        """
        Modify a stored fact.

        :param key: a hashable key for the target fact
        :param diff: a jsondiff object to apply
        """
        pass

    def add_skills(self, skills: Collection[Skill]) -> None:
        """
        Add each skill of ``skills`` via ``add_skill``.

        :param skills: skills to be added
        """
        for skill in skills:
            self.add_skill(skill)

    @abstractmethod
    def add_skill(self, skill: Skill) -> None:
        """
        Store a skill.

        :param skill: skill to be added
        """
        pass

    @abstractmethod
    def update_skill(self, skill: Skill) -> None:
        """
        Replace a stored skill with an updated version.

        :param skill: the updated skill
        """
        pass

    @property
    @abstractmethod
    def activations(self) -> Collection[Activation]:
        """
        The matching rule activations currently under consideration.
        """
        pass

    @property
    @abstractmethod
    def output(self) -> object:
        """
        The object that will ultimately be sent back as an action.

        .. todo::

            Write a setter to set object.
        """
        pass

    @abstractmethod
    def run(self):
        """
        Update the engine / perform inference (under what conditions?)
        """
        pass

    @abstractmethod
    def step(self):
        """
        Update the engine / perform inference for one step (what is one
        step?)
        """
        pass
@njit(cache=True)
def is_prime(n):
    """
    Return True when ``n`` is a prime number.

    Values below 2 and non-integral values are never prime.
    """
    # Without this guard 0, 1 and fractional inputs fell through the trial
    # division below and were reported as prime.
    if n < 2 or n % 1 != 0:
        return False
    if n % 2 == 0 and n > 2:
        return False
    # Trial division by odd candidates up to sqrt(n).
    for i in range(3, int(math.sqrt(n)) + 1, 2):
        if n % i == 0:
            return False
    return True


class SquaresOfPrimes(BaseOperator):
    """Squares ``x``; applicable only when ``x`` is prime."""
    signature = 'float(float)'

    def condition(x):
        out = is_prime(x)
        return out

    def forward(x):
        return x**2


class EvenPowersOfPrimes(BaseOperator):
    """Raises a prime ``x`` to a positive, even, integral power ``y``."""
    signature = 'float(float,float)'

    def condition(x, y):
        b = is_prime(x)
        a = (y % 2 == 0) and (y > 0) and (y == int(y))
        return a and b

    def forward(x, y):
        return x**y


class Add(BaseOperator):
    """Sum of two numbers."""
    commutes = True
    signature = 'float(float,float)'

    def forward(x, y):
        return x + y


class AddOne(BaseOperator):
    """Increment by one."""
    commutes = True
    signature = 'float(float)'

    def forward(x):
        return x + 1


class Subtract(BaseOperator):
    """Difference ``x - y``."""
    commutes = False
    signature = 'float(float,float)'

    def forward(x, y):
        return x - y


class Multiply(BaseOperator):
    """Product of two numbers."""
    commutes = True
    signature = 'float(float,float)'

    def forward(x, y):
        return x * y


class Divide(BaseOperator):
    """Quotient ``x / y``; not applicable when ``y`` is zero."""
    commutes = False
    signature = 'float(float,float)'

    def condition(x, y):
        return y != 0

    def forward(x, y):
        return x / y


class Equals(BaseOperator):
    """Equality test of two numbers."""
    commutes = False
    signature = 'float(float,float)'

    def forward(x, y):
        return x == y


class Add3(BaseOperator):
    """Sum of three numbers."""
    commutes = True
    signature = 'float(float,float,float)'

    def forward(x, y, z):
        return x + y + z


class Mod10(BaseOperator):
    """Remainder of ``x`` modulo 10."""
    commutes = True
    signature = 'float(float)'

    def forward(x):
        return x % 10


class Div10(BaseOperator):
    """Floor division of ``x`` by 10."""
    commutes = True
    signature = 'float(float)'

    def forward(x):
        return x // 10


class Concatenate(BaseOperator):
    """Concatenation of two strings."""
    signature = 'string(string,string)'

    def forward(x, y):
        return x + y


class Append25(BaseOperator):
    """Appends the literal "25" to a string."""
    signature = 'string(string)'

    def forward(x):
        return x + "25"


class StrToFloat(BaseOperator):
    """Parses a string as a float; ``ValueError`` is muted."""
    signature = 'float(string)'
    muted_exceptions = [ValueError]
    nopython = False

    def forward(x):
        return float(x)


class FloatToStr(BaseOperator):
    """Formats a float with ``str``."""
    signature = 'string(float)'
    muted_exceptions = [ValueError]
    nopython = False

    def forward(x):
        return str(x)


class RipStrValue(BaseOperator):
    """Reads the ``value`` of a TextField as a string."""
    signature = 'string(TextField)'
    template = "{}.v"
    nopython = False
    muted_exceptions = [ValueError]

    def forward(x):
        return str(x.value)


class RipFloatValue(BaseOperator):
    """Reads the ``value`` of a TextField as a float."""
    signature = 'float(TextField)'
    template = "{}.v"
    nopython = False
    muted_exceptions = [ValueError]

    def forward(x):
        return float(x.value)


class RipFloatValueSymbol(BaseOperator):
    """Reads the ``value`` of a Symbol as a float."""
    signature = 'float(Symbol)'
    template = "{}.v"
    nopython = False
    muted_exceptions = [ValueError]

    def forward(x):
        return float(x.value)


class Numerator_Multiply(BaseOperator):
    """
    Product of two field values whose ids share the same text after ".R".
    """
    signature = 'float(TextField,TextField)'
    template = "Numerator_Multiply({}.v,{}.v)"
    nopython = False
    muted_exceptions = [ValueError]

    def condition(x, y):
        return x.id.split(".R")[1] == y.id.split(".R")[1]

    def forward(x, y):
        return float(x.value) * float(y.value)


class Cross_Multiply(BaseOperator):
    """
    Product of two field values whose ids differ in the text after ".R".
    """
    signature = 'float(TextField,TextField)'
    template = "Cross_Multiply({}.v,{}.v)"
    nopython = False
    muted_exceptions = [ValueError]

    def condition(x, y):
        return x.id.split(".R")[1] != y.id.split(".R")[1]

    def forward(x, y):
        return float(x.value) * float(y.value)


class Numerator_Multiply_symb(BaseOperator):
    """
    Variant of ``Numerator_Multiply`` that compares the id segment after
    the first "_" instead of the text after ".R".
    """
    signature = 'float(TextField,TextField)'
    template = "Numerator_Multiply({}.v,{}.v)"
    nopython = False
    muted_exceptions = [ValueError]

    def condition(x, y):
        return x.id.split("_")[1] == y.id.split("_")[1]

    def forward(x, y):
        return float(x.value) * float(y.value)


class Cross_Multiply_symb(BaseOperator):
    """
    Variant of ``Cross_Multiply`` that compares the id segment after the
    first "_" instead of the text after ".R".
    """
    signature = 'float(TextField,TextField)'
    template = "Cross_Multiply({}.v,{}.v)"
    nopython = False
    muted_exceptions = [ValueError]

    def condition(x, y):
        return x.id.split("_")[1] != y.id.split("_")[1]

    def forward(x, y):
        return float(x.value) * float(y.value)


class ConvertNumerator(BaseOperator):
    """
    Scales ``inum`` by ``cden / iden``; applicable when ``iden`` is non-zero
    and not larger than ``cden``.
    """
    commutes = False
    signature = 'float(float, float, float)'

    def condition(cden, iden, inum):
        return iden != 0 and iden <= cden

    def forward(cden, iden, inum):
        return (cden / iden) * inum
"""
Django settings for agent_api project.

Generated by 'django-admin startproject' using Django 2.1.1.

For more information on this file, see
https://docs.djangoproject.com/en/2.1/topics/settings/

For the full list of settings and their values, see
https://docs.djangoproject.com/en/2.1/ref/settings/

Two values can be overridden from the environment without editing this file:

* ``DJANGO_SECRET_KEY`` -- replaces the development secret key.
* ``DJANGO_DEBUG`` -- "1", "true", "yes" or "on" (any case) enables debug
  mode; any other value disables it.  Unset keeps the development default.
"""

import os

# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

PROFILE_LOG_BASE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'profile')
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/2.1/howto/deployment/checklist/

# SECURITY WARNING: keep the secret key used in production secret!
# The fallback below is checked into version control and therefore only
# suitable for local development; deployments should set DJANGO_SECRET_KEY.
SECRET_KEY = os.environ.get(
    'DJANGO_SECRET_KEY',
    'g1s7^%a)(u(sy2l_flc@#5tm(w-@qu9)v^s(1dwyz9(-si_ca@',
)

# SECURITY WARNING: don't run with debug turned on in production!
# Defaults to True (the previous hard-coded value) when DJANGO_DEBUG is unset.
DEBUG = os.environ.get('DJANGO_DEBUG', 'True').strip().lower() in (
    '1', 'true', 'yes', 'on',
)

ALLOWED_HOSTS = ["ai2t.site", "www.ai2t.site",
                 "ai2t.online", "www.ai2t.online",
                 "127.0.0.1"]


# Application definition

INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    'corsheaders',
    'apprentice_learner',
    'django_extensions'
]

MIDDLEWARE = [
    'corsheaders.middleware.CorsMiddleware',
    'django.middleware.security.SecurityMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

ROOT_URLCONF = 'agent_api.urls'

TEMPLATES = [
    {
        'BACKEND': 'django.template.backends.django.DjangoTemplates',
        'DIRS': [],
        'APP_DIRS': True,
        'OPTIONS': {
            'context_processors': [
                'django.template.context_processors.debug',
                'django.template.context_processors.request',
                'django.contrib.auth.context_processors.auth',
                'django.contrib.messages.context_processors.messages',
            ],
        },
    },
]

WSGI_APPLICATION = 'agent_api.wsgi.application'


# Database
# https://docs.djangoproject.com/en/2.1/ref/settings/#databases

DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.sqlite3',
        'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
        'OPTIONS': {'timeout': 100000}
    }
}


# Password validation
# https://docs.djangoproject.com/en/2.1/ref/settings/#auth-password-validators

AUTH_PASSWORD_VALIDATORS = [
    {
        'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
    },
    {
        'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
    },
]


# Internationalization
# https://docs.djangoproject.com/en/2.1/topics/i18n/

LANGUAGE_CODE = 'en-us'

TIME_ZONE = 'UTC'

USE_I18N = True

USE_L10N = True

USE_TZ = True


# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/2.1/howto/static-files/

STATIC_URL = '/static/'
# NOTE(review): relative path, so it resolves against the process's working
# directory rather than BASE_DIR -- confirm this is intended.
STATIC_ROOT = './static/'

# A flag used in the apprentice API for whether to load custom operators or not.
USE_CUSTOM_OPERATORS = True

# NOTE(review): allows cross-origin requests from any site; the commented-out
# whitelist below is the narrower alternative.
CORS_ORIGIN_ALLOW_ALL = True
# CORS_ORIGIN_REGEX_WHITELIST = (
#     r'http://localhost:\d+$',
#     r'http://127.0.0.1:\d+$',
#     r'http://0.0.0.0:\d+$',
#     r''
# )
# Root URLconf: the single live entry matches at the start of every request
# path (r'^') and hands the whole path to apprentice_learner.urls, registered
# under the instance namespace 'apprentice_learner'.  The admin and
# crossdomain.xml routes are kept only as commented-out history.
urlpatterns = [
    re_path(r'^',include('apprentice_learner.urls',namespace='apprentice_learner')),
    # url(r'^admin/', admin.site.urls),
    # url(r'^crossdomain.xml$', simple,
    #     {'domains': ['*']}),
]
class Agent(models.Model):
    """
    Database record for one Apprentice Learner agent.

    Stores the agent's identity (uid, name), usage counters for act/train/
    check, and creation/update timestamps.  Rows are ordered most recently
    updated first.
    """
    # instance = PickledObjectField()
    uid = models.CharField(max_length=50, primary_key=True)
    name = models.CharField(max_length=200, blank=True)
    num_act = models.IntegerField(default=0)
    num_train = models.IntegerField(default=0)
    num_check = models.IntegerField(default=0)
    created = models.DateTimeField(auto_now_add=True)
    updated = models.DateTimeField(auto_now=True)

    def inc_act(self):
        """Bump the in-memory act counter by one (does not save)."""
        self.num_act += 1

    def inc_train(self):
        """Bump the in-memory train counter by one (does not save)."""
        self.num_train += 1

    def inc_check(self):
        """Bump the in-memory check counter by one (does not save)."""
        self.num_check += 1

    # def __str__(self):
    #     return str(self.instance)
    #     skills = {}

    #     try:
    #         skill_dict = self.instance.skills
    #         for label in skill_dict:
    #             for i, how in enumerate(skill_dict[label]):
    #                 name = label
    #                 if i > 0:
    #                     name = "%s-%i" % (label, i+1)
    #                 skills[name] = {}
    #                 skills[name]['where'] = skill_dict[label][how]['where_classifier']
    #                 skills[name]['when'] = skill_dict[label][how]['when_classifier']
    #                 skills[name]['how'] = how

    #     except:
    #         pass

    #     return "Agent {0} - {1} : {2}".format(self.pk, self.name, len(skills))

    # def generate_trees(self):
    #     import pydotplus
    #     from sklearn import tree
    #     from sklearn.externals.six import StringIO

    #     for label in self.instance.skills:
    #         for n, how in enumerate(self.instance.skills[label]):
    #             pipeline = self.instance.skills[label][how]['when_classifier']

    #             dv = pipeline.steps[0][1]
    #             dt = pipeline.steps[1][1]

    #             dot_data = StringIO()
    #             tree.export_graphviz(dt, out_file=dot_data,
    #                                  feature_names=dv.feature_names_,
    #                                  class_names=["Don't Fire Rule",
    #                                               "Fire Rule"],
    #                                  filled=True, rounded=True,
    #                                  special_characters=True)
    #             graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    #             graph.write_png("decisiontrees/%s-%i.png" % (label, n))

    class Meta:
        # Newest-updated agents come first in default querysets.
        ordering = ('-updated',)

    # user specified domain
    # owner

# Users and User permissions
11 |

Apprentice API Tester

12 |
13 |
14 | 23 | 30 | 31 | 43 | 44 |
45 |
46 | 47 |
48 |
49 |
# Application namespace used when reversing these routes.
app_name = 'apprentice_api'

# One route per API endpoint, each bound to the view of the same name in
# apprentice_learner.views.
# NOTE(review): re_path() treats its first argument as a regular expression
# and none of these patterns is anchored with ^ or $, so e.g. 'act/' would
# also match a longer path that merely contains "act/".  No two patterns
# listed here collide today; confirm whether path() or anchored regexes were
# intended before adding routes.
urlpatterns = [
    re_path('list_agents/', views.list_agents, name="list_agents"),
    re_path('create/', views.create, name="create"),
    re_path('verify/', views.verify, name="verify"),
    re_path('get_active_agent/', views.get_active_agent, name="get_active_agent"),
    re_path('act/', views.act, name="act"),
    re_path('act_all/', views.act_all, name="act_all"),
    re_path('act_rollout/', views.act_rollout, name="act_rollout"),
    re_path('train/', views.train, name="train"),
    re_path('train_all/', views.train_all, name="train_all"),
    re_path('explain_demo/', views.explain_demo, name="explain_demo"),
    re_path('get_state_uid/', views.get_state_uid, name="get_state_uid"),
    re_path('predict_next_state/', views.predict_next_state, name="predict_next_state"),
    re_path('check/', views.check, name="check"),
    re_path('get_skills/', views.get_skills, name="get_skills"),

    re_path('gen_completeness_profile/', views.gen_completeness_profile, name="gen_completeness_profile"),
    re_path('eval_completeness/', views.eval_completeness, name="eval_completeness"),


    # url(r'^report/(?P[0-9]+)/$', views.report, name="report"),

    # url(r'^request/(?P[a-zA-Z0-9_-]{1,200})/$',
    #     views.request_by_name, name="request_by_name"),
    # url(r'^train/(?P[a-zA-Z0-9_-]{1,200})/$',
    #     views.train_by_name, name="train_by_name"),
    # url(r'^check/(?P[a-zA-Z0-9_-]{1,200})/$',
    #     views.check_by_name, name="check_by_name"),
    # url(r'^report/(?P[a-zA-Z0-9_-]{1,200})/$',
    #     views.report_by_name, name="report_by_name"),
    # url(r'^tester/$', views.test_view, name='tester'),

    # url(r'^get_skills/(?P[0-9]+)/$', views.get_skills, name='get_skills')
] #+ static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)


# DUMP
# Pattern for integers
# r'^report/(?P[0-9]+)/$
8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/agents.rst: -------------------------------------------------------------------------------- 1 | Agents 2 | ====== 3 | 4 | Base Agents 5 | ----------- 6 | 7 | .. automodule:: apprentice.agents.base 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | .. automodule:: apprentice.agents.diff_base 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: 16 | 17 | 18 | ExpertaAgent 19 | ------------- 20 | 21 | .. automodule:: apprentice.agents.experta_agent 22 | :members: 23 | :undoc-members: 24 | :show-inheritance: 25 | 26 | ModularAgent 27 | ------------ 28 | 29 | .. automodule:: apprentice.agents.ModularAgent 30 | :members: 31 | :undoc-members: 32 | :show-inheritance: 33 | 34 | Other Agents 35 | ------------ 36 | 37 | .. automodule:: apprentice.agents.RLAgent 38 | :members: 39 | :undoc-members: 40 | :show-inheritance: 41 | 42 | .. automodule:: apprentice.agents.WhereWhenHowNoFoa 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | .. automodule:: apprentice.agents.Memo 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | ..
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath('../'))


# -- Project information -----------------------------------------------------

project = 'Apprentice'
copyright = '2019, Chris MacLellan, Erik Harpstead, and Daniel Weitekamp'
author = 'Chris MacLellan, Erik Harpstead, and Daniel Weitekamp'

# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = ''


# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx_autodoc_typehints'
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
#
# An explicit code is used because Sphinx 5 and later reject ``None`` with a
# warning and fall back to English anyway.
language = 'en'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Theme options are theme-specific and customize the look and feel of a theme
# further.  For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself.  Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}


# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'ApprenticePythondoc'


# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',

    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'ApprenticePython.tex', 'Apprentice Python Documentation',
     'Chris MacLellan \\and Erik Harpstead \\and Daniel Weitekamp', 'manual'),
]


# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
# NOTE(review): the page name 'testgitlabcipython' does not match the
# 'ApprenticePython' naming used by the other builders -- confirm whether it
# is a leftover before renaming, since it determines the output file name.
man_pages = [
    (master_doc, 'testgitlabcipython', 'Apprentice Python Documentation',
     [author], 1)
]


# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'ApprenticePython', 'Apprentice Python Documentation',
     author, 'ApprenticePython', 'An architecture for building agents that '
     'learn from demonstrations and feedback.',
     'Miscellaneous'),
]


# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
epub_title = project

# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''

# A unique identification for the text.
#
# epub_uid = ''

# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']


# -- Extension configuration -------------------------------------------------
toctree:: 4 | :maxdepth: 2 5 | :caption: Contents: 6 | 7 | agents.rst 8 | learners.rst 9 | planners.rst 10 | working_memory.rst 11 | 12 | 13 | 14 | Indices and tables 15 | ================== 16 | 17 | * :ref:`genindex` 18 | * :ref:`modindex` 19 | * :ref:`search` 20 | -------------------------------------------------------------------------------- /docs/learners.rst: -------------------------------------------------------------------------------- 1 | Learners 2 | ======== 3 | 4 | HowLearner 5 | ---------- 6 | 7 | .. automodule:: apprentice.learners.HowLearner 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | WhereLearner 13 | ------------ 14 | 15 | .. automodule:: apprentice.learners.WhereLearner 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | WhenLearner 21 | ----------- 22 | 23 | .. automodule:: apprentice.learners.WhenLearner 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | WhichLearner 29 | ------------ 30 | 31 | .. automodule:: apprentice.learners.WhichLearner 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | WhatLearner 37 | ----------- 38 | 39 | .. automodule:: apprentice.learners.WhatLearner 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | Utils 45 | ----- 46 | 47 | .. automodule:: apprentice.learners.utils 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | HowLearnerOld 53 | -------------- 54 | 55 | .. automodule:: apprentice.learners.HowLearnerOld 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | IncrementalHeuristic 61 | -------------------- 62 | 63 | .. automodule:: apprentice.learners.IncrementalHeuristic 64 | :members: 65 | :undoc-members: 66 | :show-inheritance: 67 | 68 | Grammar 69 | ------- 70 | 71 | .. 
automodule:: apprentice.learners.Grammar 72 | :members: 73 | :undoc-members: 74 | :show-inheritance: 75 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/planners.rst: -------------------------------------------------------------------------------- 1 | Planners 2 | ======== 3 | 4 | Base Planner 5 | ------------ 6 | 7 | .. automodule:: apprentice.planners.base_planner 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | Numbert Planner 13 | --------------- 14 | 15 | .. automodule:: apprentice.planners.NumbaPlanner 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | FO Planner 21 | ---------- 22 | 23 | .. automodule:: apprentice.planners.fo_planner 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | VectorizedPlanner 29 | ----------------- 30 | 31 | .. 
class Number(Fact):
    """Fact wrapping a single mandatory integer in ``value``."""
    value = Field(int, mandatory=True)


class KB(KnowledgeEngine):
    """Engine whose only rule declares the successor of each Number below 3."""

    @Rule(Number(value=MATCH.x))
    def increment(self, x):
        """Declare Number(x + 1) unless x has already reached 3."""
        if x < 3:
            self.declare(Number(value=x + 1))


if __name__ == "__main__":

    # Seed the engine with 0, run the rule chain and show the facts.
    kb = KB()
    kb.reset()
    seed_fact = Number(value=0)
    kb.declare(seed_fact)
    kb.run()
    print(kb.facts)

    # Retract the seed, advance the engine and show the facts again.
    kb.retract(seed_fact)
    kb.step()
    print(kb.facts)
# Fact for one board square: row, col and player are mandatory strings.
class Square(Fact):
    row = Field(str, mandatory=True)
    col = Field(str, mandatory=True)
    player = Field(str, mandatory=True)


# Fact for a candidate move.
# NOTE(review): subclasses Square yet redeclares the same three fields
# verbatim -- confirm whether the redeclaration is required by experta or
# simply redundant.
class PossibleMove(Square):
    row = Field(str, mandatory=True)
    col = Field(str, mandatory=True)
    player = Field(str, mandatory=True)


# Fact naming the player to move; the field is declared with
# schema.Or("X", "O"), i.e. one of those two marks.
class CurrentPlayer(Fact):
    name = Field(schema.Or("X", "O"), mandatory=True)
class ttt_oracle:
    """
    Environment oracle for tic-tac-toe.

    Keeps a 3x3 board of marks (``""`` means the square is empty) together
    with the player whose turn it is, applies moves, and reports the
    outcome of the game.
    """

    def __init__(self, players=None):
        """
        Create an empty board.

        players -> optional sequence of player marks in turn order;
                   defaults to ["X", "O"].
        """
        # Build a fresh list per instance: the former ``players=["X", "O"]``
        # default was one list object shared by every oracle ever created.
        self.players = list(players) if players else ["X", "O"]
        # The first listed player opens the game ("X" with the defaults).
        self.current_player = self.players[0]
        self.board = [["" for _ in range(3)] for _ in range(3)]

    def move2(self, row, col, player):
        """
        Mark an empty square without changing whose turn it is.

        Unlike ``move``, ``row`` and ``col`` are used as list indices as
        given (no ``int`` conversion, no range check).

        Returns a pair of one-element lists of fact dicts for the square:
        the first carries the new mark, the second the empty value
        (presumably the fact to add and the fact to retract -- confirm
        against the caller).

        Raises ValueError if the square is already taken.
        """
        # Explicit check instead of ``assert``: asserts vanish under
        # ``python -O`` and would let a move overwrite a played square.
        if self.board[row][col] != "":
            raise ValueError("Move already played")
        self.board[row][col] = player
        return (
            [{"__class__": Square, "row": row, "col": col, "val": player}],
            [{"__class__": Square, "row": row, "col": col, "val": ""}],
        )

    def set_state(self, state):
        """
        Overwrite board squares from an iterable of fact dicts.

        Only facts whose ``"__class__"`` is named ``Square`` are used; each
        must provide ``"row"``, ``"col"`` and ``"val"`` (the layout returned
        by ``move2``, not the one produced by ``as_dict``).
        """
        for fact in state:
            if fact["__class__"].__name__ == "Square":
                self.board[fact["row"]][fact["col"]] = fact["val"]

    def as_dict(self):
        """
        Return the state as ``{id: fact}`` with integer ids starting at 1.

        Id 1 is the ``CurrentPlayer`` fact; ids 2-10 are the ``Square``
        facts in row-major order, with row and column given as strings.
        """
        state = {1: {"type": "CurrentPlayer", "player": self.current_player}}
        for row in range(3):
            for col in range(3):
                # len(state) + 1 continues the 1-based id sequence.
                state[len(state) + 1] = {
                    "type": "Square",
                    "row": str(row),
                    "col": str(col),
                    "player": self.board[row][col],
                }
        return state

    def __str__(self):
        """Render the board as a grid with row and column headers."""
        table = [["", "Col 0", "Col 1", "Col 2"]]
        for i, row in enumerate(self.board):
            table.append(["Row %i" % i] + list(row))

        return tabulate(table, tablefmt="fancy_grid", stralign="center")

    def move(self, row, col, player):
        """
        Play ``player`` at (row, col) and pass the turn to the next player.

        Row -> 0-2 range inclusive
        Col -> 0-2 range inclusive

        ``row`` and ``col`` may be ints or numeric strings.

        Raises ValueError if the square is off the board or already played.
        """
        row = int(row)
        col = int(col)
        if row < 0 or row > 2 or col < 0 or col > 2:
            raise ValueError("Move not on board")
        if self.board[row][col] != "":
            raise ValueError("Move already played")

        self.board[row][col] = player
        self._advance_turn()

    def _advance_turn(self):
        """Hand the turn to the player listed after the current one."""
        try:
            idx = self.players.index(self.current_player)
        except ValueError:
            # Unknown current player: restart from the first player, which
            # matches the old "anything but X becomes X" fallback.
            idx = -1
        self.current_player = self.players[(idx + 1) % len(self.players)]

    def check_winner(self):
        """
        Report the outcome of the game.

        Returns the winning mark when a row, column or diagonal holds three
        equal non-empty marks, ``"draw"`` when the board is full without a
        winner, and ``False`` while the game is still open.
        """
        b = self.board
        # Same precedence as before: rows top to bottom, columns left to
        # right, main diagonal, then anti-diagonal.
        lines = [list(row) for row in b]
        lines += [[b[0][col], b[1][col], b[2][col]] for col in range(3)]
        lines.append([b[0][0], b[1][1], b[2][2]])
        lines.append([b[2][0], b[1][1], b[0][2]])

        for first, second, third in lines:
            if first != "" and first == second == third:
                return first

        if all(cell != "" for row in b for cell in row):
            return "draw"

        return False


if __name__ == "__main__":
    t = ttt_oracle()
from apprentice.agents import SoarTechAgent
from apprentice.working_memory import ExpertaWorkingMemory
from apprentice.working_memory.representation import Sai
# from apprentice.learners.when_learners import q_learner
from ttt_simple import ttt_engine, ttt_oracle


def _read_move(board_game):
    """
    Prompt on stdin until the user names an empty square on the board.

    board_game -> object exposing a 3x3 ``board`` list of marks, where
                  ``""`` means the square is free.

    Returns the chosen ``(row, col)`` as integers.
    """
    while True:
        loc = input("Enter move as row and column integers "
                    "(e.g., 1,2):").split(',')
        try:
            row = int(loc[0])
            col = int(loc[1])
        except (ValueError, IndexError):
            # Non-numeric text or fewer than two comma-separated fields.
            print("error with input, try again.")
            continue

        # Reject squares the game would refuse, so the caller never has to
        # recover from a ValueError raised by ``move``.
        if (0 <= row <= 2 and 0 <= col <= 2
                and board_game.board[row][col] == ""):
            return row, col
        print("error with input, try again.")


def get_user_demo():
    """
    Ask the user to demonstrate a move for the current player.

    Reads the module-level ``game`` created by the ``__main__`` loop below,
    so it can only be called while a game is in progress.

    Returns a ``Sai`` with action ``"move"`` whose input holds the chosen
    row, column and the current player's mark.
    """
    print()
    print("Current Player: " + game.current_player)
    print(game)
    print("Don't know what to do.")
    print("Please provide example of correct behavior.")
    print()

    row, col = _read_move(game)
    player = game.current_player

    return Sai(None, "move", {"row": row, "col": col, "player": player})


if __name__ == "__main__":
    # with experta knowledge engine
    wm1 = ExpertaWorkingMemory(ke=ttt_engine())
    a1 = SoarTechAgent(
        # wm=wm1, when=q_learner.QLearner(func=q_learner.Cobweb, q_init=0.0)
        feature_set=[], function_set=[],
        wm=wm1,
        epsilon=0.5,
        # when=q_learner.QLearner(func=q_learner.LinearFunc, q_init=0.0),
        negative_actions=True,
        action_penalty=0.0
    )

    # NOTE: the loop is deliberately endless (the "play again?" prompt is
    # disabled); stop the session with Ctrl-C.
    new_game = True
    while new_game:
        game = ttt_oracle()
        winner = False
        last_state = None
        last_sai = None
        user_demo = False

        while not winner:
            print()
            print("Current Player: " + game.current_player)
            print(game)
            state = game.as_dict()

            if game.current_player == "X":
                # The agent plays "X". Before it moves again, train it on
                # its previous move: neutral reward for its own choice,
                # bonus reward when that move was a user demonstration.
                if last_state is not None and last_sai is not None:
                    if user_demo:
                        print('providing bonus reward for user demo!')
                        a1.train(last_state, state, last_sai, 1.0, "", [""])
                    else:
                        a1.train(last_state, state, last_sai, 0.0, "", [""])

                last_state = state
                sai = a1.request(state)

                # Anything other than a Sai means the agent has no move to
                # offer, so fall back to a demonstration from the user.
                user_demo = not isinstance(sai, Sai)
                if user_demo:
                    sai = get_user_demo()

                last_sai = sai

                getattr(game, sai.action)(**sai.input)
                print("AI's move", sai)

            else:
                # Human turn: _read_move only returns legal squares.
                row, col = _read_move(game)
                game.move(row, col, game.current_player)

            winner = game.check_winner()

        # Terminal reward for the agent's last move (next state is None).
        if winner == "X":
            a1.train(last_state, None, last_sai, 1.0, "", [""])
        elif winner == "O":
            a1.train(last_state, None, last_sai, -1.0, "", [""])
        else:
            a1.train(last_state, None, last_sai, 0, "", [""])

        print("WINNER = ", winner)
        print(game)
        print()

        new_game = True
Press enter to continue or type 'no' to" 113 | # " stop.") 114 | # new_game = new == "" 115 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | colorama==0.4.3 2 | concept-formation==0.3.4 3 | Django>=4.2.9 4 | django-extensions>=3.2.3 5 | django-picklefield @ git+https://github.com/eharpste/django-picklefield.git 6 | munkres==1.0.12 7 | numpy>=1.16.0 8 | ordered-set==3.0.2 9 | prettytable==0.7.2 10 | py-search==2.0.1 11 | pytz==2018.5 12 | six==1.11.0 13 | tabulate==0.8.2 14 | verlib==0.1 15 | jsondiff 16 | nltk 17 | coloredlogs 18 | PyYAML 19 | pytest 20 | multiprocess 21 | scikit-learn 22 | dill 23 | django-cors-headers 24 | numba==0.58.1 25 | cre 26 | stand @ git+https://github.com/DannyWeitekamp/STAND.git 27 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = apprentice 3 | author = Christopher J. MacLellan, Erik Harpstead, Daniel Weitekamp 4 | author-email = maclellan.christopher@gmail.com, whitill29@gmail.com, dannyweitekamp@gmail.com 5 | summary = A framework for creating apprentice agents that learn from demonstrations and feedback. 
6 | description-file = README.md 7 | description-content-type = text/markdown; charset=UTF-8 8 | home-page = https://github.com/apprenticelearner/ 9 | project_urls = 10 | Source Code = https://github.com/apprenticelearner/ 11 | license = MIT 12 | license_file = LICENSE 13 | classifier = 14 | Development Status :: 4 - Beta 15 | Intended Audience :: Science/Research 16 | Topic :: Scientific/Engineering :: Artificial Intelligence 17 | License :: OSI Approved :: MIT License 18 | Programming Language :: Python 19 | Programming Language :: Python :: 3 20 | Programming Language :: Python :: 3.7 21 | Programming Language :: Python :: Implementation :: PyPy 22 | 23 | [files] 24 | package = 25 | apprentice 26 | 27 | [bdist_wheel] 28 | universal=1 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | setup( 6 | python_requires='>=3.7.0', 7 | setup_requires=['pbr'], 8 | pbr=True, 9 | ) 10 | -------------------------------------------------------------------------------- /setup2.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="apprentice", 5 | version="0.0.1", 6 | packages=setuptools.find_packages(), 7 | python_requires='>=3.7', 8 | ) -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | pytest 3 | pytest-benchmark 4 | coverage 5 | flake8 6 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py37 3 | 4 | [testenv] 5 | commands = 6 | coverage run --source apprentice -m pytest 7 | coverage report 8 | flake8 
apprentice 9 | deps = 10 | -r test-requirements.txt 11 | 12 | [pytest] 13 | doctest_optionflags=ALLOW_UNICODE 14 | addopts = --junitxml output.xml 15 | testpaths = 16 | tests 17 | #benchmarks 18 | 19 | [coverage:run] 20 | branch = true 21 | omit = 22 | apprentice/__init__.py 23 | tests/* 24 | benchmarks/* 25 | --------------------------------------------------------------------------------