├── test ├── __init__.py ├── auto │ ├── __init__.py │ ├── env │ │ └── __init__.py │ ├── test_dataset.py │ ├── _pl_plugin_runner.py │ ├── test_pl_logger.py │ └── test_pl_plugin.py ├── frame │ ├── __init__.py │ ├── noise │ │ ├── __init__.py │ │ └── test_generator.py │ ├── algorithms │ │ ├── __init__.py │ │ └── utils.py │ └── buffers │ │ └── __init__.py ├── utils │ ├── __init__.py │ ├── test_learning_rate.py │ ├── test_tensor_board.py │ ├── test_visualize.py │ ├── test_checker.py │ ├── test_conf.py │ ├── test_media.py │ ├── test_helper_classes.py │ ├── test_save_env.py │ └── test_prepare.py ├── parallel │ ├── __init__.py │ ├── server │ │ ├── __init__.py │ │ └── test_ordered_server.py │ ├── distributed │ │ └── __init__.py │ └── test_pickle.py ├── env │ └── wrappers │ │ └── __init__.py ├── data │ ├── generators │ │ └── __init__.py │ ├── __init__.py │ ├── __main__.py │ ├── archive.py │ └── all.py ├── conftest.py ├── util_platforms.py ├── util_create_ma_env.py └── util_fixtures.py ├── machin ├── frame │ ├── helpers │ │ └── __init__.py │ ├── __init__.py │ ├── noise │ │ └── __init__.py │ ├── buffers │ │ └── __init__.py │ └── algorithms │ │ └── __init__.py ├── model │ ├── algorithms │ │ └── __init__.py │ ├── __init__.py │ └── nets │ │ └── __init__.py ├── auto │ ├── envs │ │ └── __init__.py │ ├── __init__.py │ ├── __main__.py │ └── dataset.py ├── env │ ├── utils │ │ ├── __init__.py │ │ └── openai_gym.py │ ├── __init__.py │ └── wrappers │ │ ├── __init__.py │ │ └── base.py ├── __init__.py ├── utils │ ├── __init__.py │ ├── visualize.py │ ├── tensor_board.py │ ├── logging.py │ ├── learning_rate.py │ └── conf.py └── parallel │ ├── __init__.py │ ├── server │ └── __init__.py │ ├── distributed │ └── __init__.py │ ├── exception.py │ ├── thread.py │ ├── process.py │ ├── event.py │ └── util.py ├── pytest.ini ├── test_lib └── multiagent-particle-envs │ ├── bin │ ├── __init__.py │ └── interactive.py │ ├── multiagent.egg-info │ ├── not-zip-safe │ ├── dependency_links.txt │ ├── requires.txt │ ├── top_level.txt │ ├── PKG-INFO │ └── SOURCES.txt │ ├── multiagent │ ├── scenarios │ │ ├── __init__.py │ │ ├── simple.py │ │ ├── simple_reference.py │ │ └── simple_speaker_listener.py │ ├── scenario.py │ ├── __init__.py │ ├── policy.py │ └── multi_discrete.py │ ├── setup.py │ ├── LICENSE.txt │ └── make_env.py ├── .gitconfig ├── docs ├── theme │ ├── theme.conf │ └── static │ │ └── pygments.css ├── source │ ├── static │ │ ├── icon.png │ │ ├── favicon.png │ │ ├── icon_title.png │ │ ├── advance │ │ │ └── algorithm_apis │ │ │ │ └── category.png │ │ ├── tutorials │ │ │ ├── recurrent_networks │ │ │ │ ├── drqn.png │ │ │ │ ├── rppo.png │ │ │ │ ├── dqn_his=4.png │ │ │ │ ├── ppo_his=1.png │ │ │ │ └── ppo_his=4.png │ │ │ ├── your_first_program │ │ │ │ └── cartpole.gif │ │ │ └── unleash_distributed_power │ │ │ │ ├── a3c_pcode.png │ │ │ │ ├── impala_pcode.png │ │ │ │ └── dqn_apex_pcode.png │ │ └── icon_title.svg │ ├── api │ │ ├── index.rst │ │ ├── machin.model.rst │ │ ├── machin.env.rst │ │ ├── machin.parallel.rst │ │ ├── machin.auto.rst │ │ └── machin.utils.rst │ ├── advance │ │ └── index.rst │ ├── tutorials │ │ └── index.rst │ ├── index.rst │ ├── about.rst │ └── conf.py ├── requirements.txt ├── Makefile ├── make.bat └── misc │ └── contribute.md ├── examples ├── tutorials │ ├── as_fast_as_lightning │ │ ├── automatic │ │ │ ├── launch.cmd │ │ │ ├── qnet.py │ │ │ └── config.json │ │ └── programmatic │ │ │ ├── nni │ │ │ ├── launch.cmd │ │ │ ├── search_space.json │ │ │ ├── config.yml │ │ │ └── nni_main.py │ │ │ └── simple │ │ │ └── main.py │ 
├── recurrent_networks │ │ ├── util.py │ │ ├── history.py │ │ ├── dqn.py │ │ ├── rppo.py │ │ ├── ppo.py │ │ └── drqn.py │ ├── parallel_distributed │ │ ├── mth_exception.py │ │ ├── mpr_exception.py │ │ ├── dist_coll.py │ │ ├── mth_event.py │ │ ├── assign.py │ │ ├── dist_rpc.py │ │ ├── dist_oserver.py │ │ └── mpr_pickle.py │ └── your_first_program │ │ └── main.py └── framework_examples │ ├── dqn.py │ ├── dqn_per.py │ ├── ars.py │ ├── a2c.py │ ├── ppo.py │ ├── rainbow.py │ ├── ddpg.py │ ├── hddpg.py │ └── ddpg_per.py ├── .pylintrc ├── .pre-commit-config.yaml ├── .git-ignore-revs ├── .github └── ISSUE_TEMPLATE │ ├── alternation.md │ ├── feature_request.md │ └── bug_report.md ├── run_linux_test.sh ├── run_macos_test.sh ├── run_win_test.bat ├── DEVELOPMENT.md ├── LICENSE ├── .circleci-archive └── config.yml └── setup.py /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/auto/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/frame/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/auto/env/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/frame/noise/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /machin/frame/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/env/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/frame/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/frame/buffers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/parallel/server/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /machin/model/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/parallel/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | # log_cli = 1 3 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/bin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitconfig: -------------------------------------------------------------------------------- 1 | [blame] 2 | ignoreRevsFile = .git-ignore-revs 3 | -------------------------------------------------------------------------------- /test/data/generators/__init__.py: -------------------------------------------------------------------------------- 1 | from . import generate_gail 2 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | gym 2 | numpy-stl 3 | -------------------------------------------------------------------------------- /docs/theme/theme.conf: -------------------------------------------------------------------------------- 1 | [theme] 2 | inherit = sphinx_rtd_theme 3 | stylesheet = css/machin.css -------------------------------------------------------------------------------- /machin/auto/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from . import openai_gym 2 | 3 | __all__ = ["openai_gym"] 4 | -------------------------------------------------------------------------------- /machin/env/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import openai_gym 2 | 3 | __all__ = ["openai_gym"] 4 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | bin 2 | multiagent 3 | -------------------------------------------------------------------------------- /machin/env/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import utils, wrappers 2 | 3 | __all__ = ["utils", "wrappers"] 4 | -------------------------------------------------------------------------------- /test/data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT = os.path.dirname(os.path.abspath(__file__)) 4 | -------------------------------------------------------------------------------- /docs/source/static/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/icon.png -------------------------------------------------------------------------------- /docs/source/static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/favicon.png -------------------------------------------------------------------------------- /machin/env/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base, openai_gym 2 | 3 | __all__ = ["base", "openai_gym"] 4 | -------------------------------------------------------------------------------- /examples/tutorials/as_fast_as_lightning/automatic/launch.cmd: -------------------------------------------------------------------------------- 1 | python -m machin.auto launch --config config.json -------------------------------------------------------------------------------- /docs/source/static/icon_title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/icon_title.png -------------------------------------------------------------------------------- /machin/model/__init__.py: -------------------------------------------------------------------------------- 1 | from . import nets 2 | from . 
import algorithms 3 | 4 | __all__ = ["nets", "algorithms"] 5 | -------------------------------------------------------------------------------- /test/data/__main__.py: -------------------------------------------------------------------------------- 1 | from .all import generate_all 2 | 3 | if __name__ == "__main__": 4 | generate_all() 5 | -------------------------------------------------------------------------------- /examples/tutorials/as_fast_as_lightning/programmatic/nni/launch.cmd: -------------------------------------------------------------------------------- 1 | nnictl create --config config.yml --port 8088 --debug -------------------------------------------------------------------------------- /docs/source/static/advance/algorithm_apis/category.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/advance/algorithm_apis/category.png -------------------------------------------------------------------------------- /docs/source/static/tutorials/recurrent_networks/drqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/tutorials/recurrent_networks/drqn.png -------------------------------------------------------------------------------- /docs/source/static/tutorials/recurrent_networks/rppo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/tutorials/recurrent_networks/rppo.png -------------------------------------------------------------------------------- /machin/__init__.py: -------------------------------------------------------------------------------- 1 | from . import env, frame, model, parallel, utils 2 | 3 | __version__ = "0.4.2" 4 | __all__ = ["env", "frame", "model", "parallel", "utils"] 5 | -------------------------------------------------------------------------------- /docs/source/static/tutorials/your_first_program/cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/tutorials/your_first_program/cartpole.gif -------------------------------------------------------------------------------- /docs/source/static/tutorials/recurrent_networks/dqn_his=4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/tutorials/recurrent_networks/dqn_his=4.png -------------------------------------------------------------------------------- /docs/source/static/tutorials/recurrent_networks/ppo_his=1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/tutorials/recurrent_networks/ppo_his=1.png -------------------------------------------------------------------------------- /docs/source/static/tutorials/recurrent_networks/ppo_his=4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/tutorials/recurrent_networks/ppo_his=4.png -------------------------------------------------------------------------------- /machin/frame/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import algorithms, buffers, helpers, noise, transition 2 | 3 | __all__ = ["algorithms", "buffers", "helpers", "noise", "transition"] 4 | -------------------------------------------------------------------------------- /machin/frame/noise/__init__.py: -------------------------------------------------------------------------------- 1 | from . import action_space_noise, generator, param_space_noise 2 | 3 | __all__ = ["action_space_noise", "generator", "param_space_noise"] 4 | -------------------------------------------------------------------------------- /docs/source/static/tutorials/unleash_distributed_power/a3c_pcode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/tutorials/unleash_distributed_power/a3c_pcode.png -------------------------------------------------------------------------------- /docs/source/static/tutorials/unleash_distributed_power/impala_pcode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/tutorials/unleash_distributed_power/impala_pcode.png -------------------------------------------------------------------------------- /docs/source/static/tutorials/unleash_distributed_power/dqn_apex_pcode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iffiX/machin/HEAD/docs/source/static/tutorials/unleash_distributed_power/dqn_apex_pcode.png -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | extension-pkg-whitelist=numpy,torch 3 | 4 | [TYPECHECK] 5 | ignored-modules=numpy,torch 6 | ignored-classes=numpy,torch 7 | generated-members=numpy.*,torch.*,*.data,*.grad 8 | -------------------------------------------------------------------------------- /examples/tutorials/recurrent_networks/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as t 3 | 4 | 5 | def convert(mem: np.ndarray): 6 | return t.tensor(mem.reshape(1, 128).astype(np.float32) / 255) 7 | -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | API 2 | ================== 3 | .. 
toctree:: 4 | 5 | machin.env.rst 6 | machin.auto.rst 7 | machin.frame.rst 8 | machin.model.rst 9 | machin.utils.rst 10 | machin.parallel.rst -------------------------------------------------------------------------------- /examples/tutorials/as_fast_as_lightning/programmatic/nni/search_space.json: -------------------------------------------------------------------------------- 1 | { 2 | "lr":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]}, 3 | "upd":{"_type": "choice", "_value": [0.005, 0.002, 0.001, 0.0001]} 4 | } -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent/scenarios/__init__.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os.path as osp 3 | 4 | 5 | def load(name): 6 | pathname = osp.join(osp.dirname(__file__), name) 7 | return imp.load_source('', pathname) 8 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | gym 2 | psutil 3 | numpy 4 | torch>=1.6.0 5 | torchviz 6 | moviepy 7 | matplotlib 8 | colorlog 9 | dill 10 | GPUtil 11 | Pillow 12 | tensorboardX 13 | sphinx==3.0.3 14 | sphinx-autodoc-typehints==1.10.3 15 | sphinx-rtd-theme==0.4.3 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 20.8b1 # Replace by any tag/version: https://github.com/psf/black/tags 4 | hooks: 5 | - id: black 6 | language_version: python3 # Should be a command that runs python3.6+ 7 | -------------------------------------------------------------------------------- /docs/source/advance/index.rst: -------------------------------------------------------------------------------- 1 | Advance 2 | ================== 3 | .. 
toctree:: 4 | :maxdepth: 1 5 | 6 | Architecture overview 7 | Algorithm APIs 8 | Algorithm model requirements 9 | -------------------------------------------------------------------------------- /test/utils/test_learning_rate.py: -------------------------------------------------------------------------------- 1 | from machin.utils.logging import default_logger 2 | from machin.utils.learning_rate import gen_learning_rate_func 3 | 4 | 5 | def test_gen_learning_rate_func(): 6 | func = gen_learning_rate_func([(0, 1e-3), (20000, 1e-3)], default_logger) 7 | func(10000) 8 | func(20001) 9 | -------------------------------------------------------------------------------- /.git-ignore-revs: -------------------------------------------------------------------------------- 1 | # Black formatting 2 | e16d3b2094455bae3fe9b601c6249c6a2977a4e3 3 | 4 | # Pyupgrade for class and inheritance modernization 5 | 9107aaf2b2d64e6b9ff70b32fb3e9a5e23b5cd7f 6 | 7 | # F-strings 8 | 188ae7e3965fa0a7850b78a515b3823cb5a0a1ab 9 | 10 | # Format examples folder - black, f-strings, pyupgrade 11 | a68529f2bc021abf29a36eab5405f456d801751d 12 | -------------------------------------------------------------------------------- /test/utils/test_tensor_board.py: -------------------------------------------------------------------------------- 1 | from machin.utils.tensor_board import default_board 2 | 3 | import pytest 4 | 5 | 6 | class TestTensorBoard: 7 | def test_tensor_board(self): 8 | default_board.init() 9 | with pytest.raises(RuntimeError, match="has been initialized"): 10 | default_board.init() 11 | assert default_board.is_inited() 12 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: multiagent 3 | Version: 0.0.1 4 | Summary: Multi-Agent Goal-Driven Communication Environment 5 | Home-page: https://github.com/openai/multiagent-public 6 | Author: Igor Mordatch 7 | Author-email: mordatch@openai.com 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /machin/model/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import NeuralNetworkModule, dynamic_module_wrapper, static_module_wrapper 2 | 3 | from .resnet import ResNet 4 | 5 | from . import base 6 | from . 
import resnet 7 | 8 | __all__ = [ 9 | "NeuralNetworkModule", 10 | "dynamic_module_wrapper", 11 | "static_module_wrapper", 12 | "ResNet", 13 | "base", 14 | "resnet", 15 | ] 16 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | def pytest_addoption(parser): 2 | parser.addoption( 3 | "--gpu_device", 4 | action="store", 5 | default=None, 6 | help="GPU device descriptor in pytorch", 7 | ) 8 | parser.addoption( 9 | "--multiprocess_method", 10 | default="forkserver", 11 | help="spawn or forkserver, default is forkserver", 12 | ) 13 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent/scenario.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # defines scenario upon which the world is built 4 | class BaseScenario(object): 5 | # create elements of the world 6 | def make_world(self): 7 | raise NotImplementedError() 8 | # create initial conditions of the world 9 | def reset_world(self, world): 10 | raise NotImplementedError() 11 | -------------------------------------------------------------------------------- /machin/frame/buffers/__init__.py: -------------------------------------------------------------------------------- 1 | from .buffer import Buffer 2 | from .prioritized_buffer import WeightTree, PrioritizedBuffer 3 | from .buffer_d import DistributedBuffer 4 | from .prioritized_buffer_d import DistributedPrioritizedBuffer 5 | 6 | 7 | __all__ = [ 8 | "Buffer", 9 | "DistributedBuffer", 10 | "PrioritizedBuffer", 11 | "DistributedPrioritizedBuffer", 12 | "WeightTree", 13 | ] 14 | -------------------------------------------------------------------------------- /test/utils/test_visualize.py: -------------------------------------------------------------------------------- 1 | from machin.utils.visualize import visualize_graph 2 | from unittest import mock 3 | 4 | import torch as t 5 | 6 | 7 | def mock_exit(_exit_code): 8 | pass 9 | 10 | 11 | def test_visualize_graph(): 12 | tensor = t.ones([2, 2]) 13 | tensor = tensor * t.ones([2, 2]) 14 | with mock.patch("machin.utils.visualize.exit", mock_exit): 15 | visualize_graph(tensor) 16 | -------------------------------------------------------------------------------- /docs/source/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ================== 3 | .. toctree:: 4 | :maxdepth: 1 5 | 6 | Your first program 7 | As fast as lightning 8 | Data flow in Machin 9 | Parallel, distributed 10 | Unleash distributed power 11 | Recurrent networks 12 | -------------------------------------------------------------------------------- /machin/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import ( 2 | checker, 3 | conf, 4 | helper_classes, 5 | learning_rate, 6 | logging, 7 | media, 8 | prepare, 9 | save_env, 10 | tensor_board, 11 | visualize, 12 | ) 13 | 14 | __all__ = [ 15 | "checker", 16 | "conf", 17 | "helper_classes", 18 | "learning_rate", 19 | "logging", 20 | "media", 21 | "prepare", 22 | "save_env", 23 | "tensor_board", 24 | "visualize", 25 | ] 26 | -------------------------------------------------------------------------------- /machin/env/utils/openai_gym.py: -------------------------------------------------------------------------------- 1 | def disable_view_window(): 2 | # Disable pop-up windows and render in background 3 | # by injecting a custom viewer constructor. 4 | from gym.envs.classic_control import rendering 5 | 6 | org_constructor = rendering.Viewer.__init__ 7 | 8 | def constructor(self, *args, **kwargs): 9 | org_constructor(self, *args, **kwargs) 10 | self.window.set_visible(visible=False) 11 | 12 | rendering.Viewer.__init__ = constructor 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/alternation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Alternation 3 | about: Something you find inconvenient and that needs to be altered. 4 | title: "[ALTER]" 5 | labels: alter 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Where to alter** 11 | 12 | **Why to alter** 13 | Poorly designed? 14 | Bad naming / strange acronyms? 15 | Obscured function description? 16 | Lack of inline comments? 17 | Unnecessary documentation? 18 | 19 | **How to alter** 20 | Your suggested alteration, preferably with a detailed example. 21 | -------------------------------------------------------------------------------- /machin/auto/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from . import envs 3 | from . import config 4 | from . import dataset 5 | from . import launcher 6 | from . import pl_logger 7 | 8 | try: 9 | from . import pl_plugin 10 | except Exception as _: 11 | warnings.warn( 12 | "Failed to import pytorch_lightning plugins relying on torch.distributed." 13 | " Set them to None."
14 | ) 15 | pl_plugin = None 16 | 17 | __all__ = ["envs", "config", "dataset", "launcher", "pl_logger", "pl_plugin"] 18 | -------------------------------------------------------------------------------- /examples/tutorials/as_fast_as_lightning/automatic/qnet.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import torch.nn as nn 3 | 4 | 5 | class QNet(nn.Module): 6 | def __init__(self, state_dim, action_num): 7 | super().__init__() 8 | 9 | self.fc1 = nn.Linear(state_dim, 16) 10 | self.fc2 = nn.Linear(16, 16) 11 | self.fc3 = nn.Linear(16, action_num) 12 | 13 | def forward(self, state): 14 | a = t.relu(self.fc1(state)) 15 | a = t.relu(self.fc2(a)) 16 | return self.fc3(a) 17 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='multiagent', 4 | version='0.0.1', 5 | description='Multi-Agent Goal-Driven Communication Environment', 6 | url='https://github.com/openai/multiagent-public', 7 | author='Igor Mordatch', 8 | author_email='mordatch@openai.com', 9 | packages=find_packages(), 10 | include_package_data=True, 11 | zip_safe=False, 12 | install_requires=['gym', 'numpy-stl'] 13 | ) 14 | -------------------------------------------------------------------------------- /run_linux_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # if you are using any non-standard python 3 installation (>= 3.6) 3 | # please replace the first three python commands with your absolute path 4 | python --version 5 | python -m pip install virtualenv 6 | python -m virtualenv venv 7 | source venv/bin/activate 8 | pip install . 9 | pip install ./test_lib/multiagent-particle-envs/ 10 | pip install mock pytest==6.0.1 pytest-html==1.22.1 pytest-repeat==0.8.0 11 | python -m pytest -s --assert=plain -k "not full_train" --html=test_api.html --self-contained-html ./test/ -------------------------------------------------------------------------------- /run_macos_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # if you are using any non-standard python 3 installation (>= 3.6) 3 | # please replace the first three python commands with your absolute path 4 | python --version 5 | python -m pip install virtualenv 6 | python -m virtualenv venv 7 | source venv/bin/activate 8 | pip install . 9 | pip install ./test_lib/multiagent-particle-envs/ 10 | pip install mock pytest==6.0.1 pytest-html==1.22.1 pytest-repeat==0.8.0 11 | python -m pytest -s --assert=plain -k "not full_train" --html=test_api.html --self-contained-html ./test/ -------------------------------------------------------------------------------- /run_win_test.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | REM if you are using any non-standard python 3 installation (>= 3.6) 3 | REM please replace the first three python commands with your absolute path 4 | 5 | python --version 6 | python -m pip install virtualenv 7 | python -m virtualenv venv 8 | call .\venv\Scripts\activate 9 | pip install .
10 | pip install .\test_lib\multiagent-particle-envs\ 11 | pip install mock pytest==6.0.1 pytest-html==1.22.1 pytest-repeat==0.8.0 12 | python -m pytest -s --assert=plain -k "not full_train" --html=test_api.html --self-contained-html .\test\ 13 | pause -------------------------------------------------------------------------------- /DEVELOPMENT.md: -------------------------------------------------------------------------------- 1 | # Code style 2 | The code follows the code style enforced by [black](https://github.com/psf/black). 3 | 4 | To automate code formatting, [pre-commit](https://github.com/pre-commit/pre-commit) is used to run code checks before committing changes. 5 | If you have pre-commit installed from requirements-dev.txt, simply run ``pre-commit install`` to install the hooks for this repo. 6 | 7 | # Setup 8 | 9 | ### git setup for ignoring refactoring commits 10 | Run `git config --local include.path ../.gitconfig` inside your clone of the repo. 11 | Update git to `2.23` or higher to be able to use this feature. 12 | -------------------------------------------------------------------------------- /examples/tutorials/parallel_distributed/mth_exception.py: -------------------------------------------------------------------------------- 1 | from machin.parallel.thread import Thread, ThreadException 2 | import time 3 | 4 | 5 | def test1(): 6 | time.sleep(1) 7 | print(f"Exception occurred at {time.time()}") 8 | raise RuntimeError("Error") 9 | 10 | 11 | if __name__ == "__main__": 12 | t1 = Thread(target=test1) 13 | t1.start() 14 | while True: 15 | try: 16 | t1.watch() 17 | except ThreadException as e: 18 | print(f"Exception caught at {time.time()}") 19 | print(f"Exception is: {e}") 20 | break 21 | t1.join() 22 | -------------------------------------------------------------------------------- /examples/tutorials/parallel_distributed/mpr_exception.py: -------------------------------------------------------------------------------- 1 | from machin.parallel.process import Process, ProcessException 2 | import time 3 | 4 | 5 | def test1(): 6 | time.sleep(1) 7 | print(f"Exception occurred at {time.time()}") 8 | raise RuntimeError("Error") 9 | 10 | 11 | if __name__ == "__main__": 12 | t1 = Process(target=test1) 13 | t1.start() 14 | while True: 15 | try: 16 | t1.watch() 17 | except ProcessException as e: 18 | print(f"Exception caught at {time.time()}") 19 | print(f"Exception is: {e}") 20 | break 21 | t1.join() 22 | -------------------------------------------------------------------------------- /docs/source/api/machin.model.rst: -------------------------------------------------------------------------------- 1 | machin.model 2 | ============= 3 | 4 | algorithms 5 | +++++++++++++ 6 | ``machin.model.algorithms`` provides implementations for various network 7 | architectures used by algorithms. 8 | 9 | .. automodule:: machin.model.algorithms 10 | :members: 11 | :undoc-members: 12 | :show-inheritance: 13 | :member-order: bysource 14 | 15 | nets 16 | +++++++++++++ 17 | ``machin.model.nets`` provides implementations for various popular network 18 | architectures. 19 | 20 | ..
automodule:: machin.model.nets 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | :member-order: bysource -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | # Multiagent envs 4 | # ---------------------------------------- 5 | 6 | # obsolete 7 | 8 | # register( 9 | # id='MultiagentSimple-v0', 10 | # entry_point='multiagent.envs:SimpleEnv', 11 | # # FIXME(cathywu) currently has to be exactly max_path_length parameters in 12 | # # rllab run script 13 | # max_episode_steps=100, 14 | # ) 15 | # 16 | # register( 17 | # id='MultiagentSimpleSpeakerListener-v0', 18 | # entry_point='multiagent.envs:SimpleSpeakerListenerEnv', 19 | # max_episode_steps=100, 20 | # ) 21 | -------------------------------------------------------------------------------- /machin/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from multiprocessing import get_context, get_start_method 3 | from . import assigner, exception, pickle, thread, pool, queue 4 | 5 | try: 6 | from . import distributed, server 7 | except Exception as _: 8 | warnings.warn( 9 | "Failed to import distributed and server modules relying on torch.distributed." 10 | " Set them to None." 11 | ) 12 | distributed = None 13 | server = None 14 | 15 | __all__ = [ 16 | "get_context", 17 | "get_start_method", 18 | "distributed", 19 | "server", 20 | "assigner", 21 | "exception", 22 | "pickle", 23 | "pool", 24 | "queue", 25 | ] 26 | -------------------------------------------------------------------------------- /machin/parallel/server/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ordered_server 2 | from . import param_server 3 | from .ordered_server import ( 4 | OrderedServerBase, 5 | OrderedServerSimple, 6 | OrderedServerSimpleImpl, 7 | ) 8 | from .param_server import ( 9 | PushPullGradServer, 10 | PushPullGradServerImpl, 11 | PushPullModelServer, 12 | PushPullModelServerImpl, 13 | ) 14 | 15 | __all__ = [ 16 | "OrderedServerBase", 17 | "OrderedServerSimple", 18 | "OrderedServerSimpleImpl", 19 | "PushPullGradServer", 20 | "PushPullGradServerImpl", 21 | "PushPullModelServer", 22 | "PushPullModelServerImpl", 23 | "ordered_server", 24 | "param_server", 25 | ] 26 | -------------------------------------------------------------------------------- /examples/tutorials/recurrent_networks/history.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | 3 | 4 | class History: 5 | def __init__(self, history_depth, state_shape): 6 | self.history = [t.zeros(state_shape) for _ in range(history_depth)] 7 | self.state_shape = state_shape 8 | 9 | def append(self, state): 10 | assert ( 11 | t.is_tensor(state) 12 | and state.dtype == t.float32 13 | and tuple(state.shape) == self.state_shape 14 | ) 15 | self.history.append(state) 16 | self.history.pop(0) 17 | return self 18 | 19 | def get(self): 20 | # size: (1, history_depth, ...) 
21 | return t.cat(self.history, dim=0).unsqueeze(0) 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[FEATURE]" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /test/frame/algorithms/utils.py: -------------------------------------------------------------------------------- 1 | from gym.wrappers.time_limit import TimeLimit 2 | 3 | 4 | def unwrap_time_limit(env): 5 | # some environments come with a time limit, which we must remove 6 | if isinstance(env, TimeLimit): 7 | return env.unwrapped 8 | else: 9 | return env 10 | 11 | 12 | class Smooth: 13 | def __init__(self): 14 | self._value = None 15 | 16 | def update(self, new_value, update_rate=0.2): 17 | if self._value is None: 18 | self._value = new_value 19 | else: 20 | self._value = self._value * (1 - update_rate) + new_value * update_rate 21 | return self._value 22 | 23 | @property 24 | def value(self): 25 | return self._value 26 | -------------------------------------------------------------------------------- /docs/source/api/machin.env.rst: -------------------------------------------------------------------------------- 1 | machin.env 2 | ============= 3 | 4 | utils 5 | +++++++++++++ 6 | ``machin.env.utils`` provides utilities to deal with various environments. 7 | 8 | .. automodule:: machin.env.utils.openai_gym 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | :member-order: bysource 13 | 14 | wrappers 15 | +++++++++++++ 16 | ``machin.env.wrappers`` provides parallel execution wrappers for various 17 | environments. 18 | 19 | .. automodule:: machin.env.wrappers.base 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | :member-order: bysource 24 | 25 | ..
automodule:: machin.env.wrappers.openai_gym 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | :member-order: bysource -------------------------------------------------------------------------------- /examples/tutorials/as_fast_as_lightning/programmatic/nni/config.yml: -------------------------------------------------------------------------------- 1 | authorName: default 2 | experimentName: example_nni_auto 3 | trialConcurrency: 1 4 | maxExecDuration: 1h 5 | maxTrialNum: 10 6 | #choice: local, remote, pai 7 | trainingServicePlatform: local 8 | searchSpacePath: search_space.json 9 | #choice: true, false 10 | useAnnotation: false 11 | tuner: 12 | #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner 13 | #SMAC (SMAC should be installed through nnictl) 14 | builtinTunerName: TPE 15 | classArgs: 16 | #choice: maximize, minimize 17 | optimize_mode: maximize 18 | trial: 19 | command: python nni_main.py 20 | codeDir: . 21 | gpuNum: 1 22 | localConfig: 23 | useActiveGpu: true 24 | maxTrialNumPerGpu: 2 25 | gpuIndices: "0" -------------------------------------------------------------------------------- /test/util_platforms.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pytest 3 | 4 | linux_only = pytest.mark.skipif("not sys.platform.startswith('linux')") 5 | windows_only = pytest.mark.skipif("not sys.platform.startswith('win')") 6 | macos_only = pytest.mark.skipif("not sys.platform.startswith('darwin')") 7 | 8 | 9 | def linux_only_forall(): 10 | if not sys.platform.startswith("linux"): 11 | pytest.skip("Requires Linux platform", allow_module_level=True) 12 | 13 | 14 | def windows_only_forall(): 15 | if not sys.platform.startswith("win"): 16 | pytest.skip("Requires Windows platform", allow_module_level=True) 17 | 18 | 19 | def macos_only_forall(): 20 | if not sys.platform.startswith("darwin"): 21 | pytest.skip("Requires MacOS platform", allow_module_level=True) 22 | -------------------------------------------------------------------------------- /examples/tutorials/parallel_distributed/dist_coll.py: -------------------------------------------------------------------------------- 1 | from machin.parallel.distributed import World 2 | from torch.multiprocessing import spawn 3 | import torch as t 4 | 5 | 6 | def main(rank): 7 | world = World(world_size=3, rank=rank, name=str(rank), rpc_timeout=20) 8 | # all sub processes must enter this function, including non-group members 9 | group = world.create_collective_group(ranks=[0, 1, 2]) 10 | # test broadcast 11 | # process 0 will broadcast a tensor filled with 1 12 | # to process 1 and 2 13 | if rank == 0: 14 | a = t.ones([5]) 15 | else: 16 | a = t.zeros([5]) 17 | group.broadcast(a, 0) 18 | print(a) 19 | group.destroy() 20 | return True 21 | 22 | 23 | if __name__ == "__main__": 24 | # spawn 3 sub processes 25 | spawn(main, nprocs=3) 26 | -------------------------------------------------------------------------------- /machin/utils/visualize.py: -------------------------------------------------------------------------------- 1 | from torchviz import make_dot 2 | 3 | """ 4 | The visualization module. Currently it only contains the pytorch 5 | flow graph visualization; more visualizations for cnn, resnet, lstm & rnn, 6 | and attention layers will be added in the future if there is any feature request.
7 | """ 8 | 9 | 10 | def visualize_graph(final_tensor, visualize_dir="", exit_after_vis=True): 11 | """ 12 | Visualize a pytorch flow graph 13 | 14 | Args: 15 | final_tensor: The last output tensor of the flow graph 16 | visualize_dir: Directory to place the visualized files 17 | exit_after_vis: Whether to exit the whole program 18 | after visualization. 19 | """ 20 | g = make_dot(final_tensor) 21 | g.render(directory=visualize_dir, view=False, quiet=True) 22 | if exit_after_vis: 23 | exit(0) 24 | -------------------------------------------------------------------------------- /machin/utils/tensor_board.py: -------------------------------------------------------------------------------- 1 | """ 2 | Attributes: 3 | default_board: The default global board. 4 | """ 5 | import numpy as np 6 | from tensorboardX import SummaryWriter 7 | 8 | 9 | class TensorBoard: 10 | """ 11 | Create a tensor board object. 12 | 13 | Attributes: 14 | writer: ``SummaryWriter`` of package ``tensorboardX``. 15 | """ 16 | 17 | def __init__(self): 18 | self.writer = None 19 | 20 | def init(self, *writer_args): 21 | if self.writer is None: 22 | self.writer = SummaryWriter(*writer_args) 23 | else: 24 | raise RuntimeError("Writer has been initialized!") 25 | 26 | def is_inited(self) -> bool: 27 | """ 28 | Returns: whether the board has been initialized with a writer. 29 | """ 30 | return self.writer is not None 31 | 32 | 33 | default_board = TensorBoard() 34 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | setup.py 3 | bin/__init__.py 4 | bin/interactive.py 5 | multiagent/__init__.py 6 | multiagent/core.py 7 | multiagent/environment.py 8 | multiagent/multi_discrete.py 9 | multiagent/policy.py 10 | multiagent/rendering.py 11 | multiagent/scenario.py 12 | multiagent.egg-info/PKG-INFO 13 | multiagent.egg-info/SOURCES.txt 14 | multiagent.egg-info/dependency_links.txt 15 | multiagent.egg-info/not-zip-safe 16 | multiagent.egg-info/requires.txt 17 | multiagent.egg-info/top_level.txt 18 | multiagent/scenarios/__init__.py 19 | multiagent/scenarios/simple.py 20 | multiagent/scenarios/simple_adversary.py 21 | multiagent/scenarios/simple_crypto.py 22 | multiagent/scenarios/simple_push.py 23 | multiagent/scenarios/simple_reference.py 24 | multiagent/scenarios/simple_speaker_listener.py 25 | multiagent/scenarios/simple_spread.py 26 | multiagent/scenarios/simple_tag.py 27 | multiagent/scenarios/simple_world_comm.py -------------------------------------------------------------------------------- /docs/source/api/machin.parallel.rst: -------------------------------------------------------------------------------- 1 | machin.parallel 2 | =============== 3 | 4 | distributed 5 | +++++++++++++++++ 6 | .. automodule:: machin.parallel.distributed 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | 11 | server 12 | +++++++++++++++++ 13 | .. automodule:: machin.parallel.server 14 | :members: 15 | :undoc-members: 16 | :show-inheritance: 17 | 18 | assigner 19 | +++++++++++++++++ 20 | .. automodule:: machin.parallel.assigner 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | pickle 26 | +++++++++++++++++ 27 | .. automodule:: machin.parallel.pickle 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | 32 | pool 33 | +++++++++++++++++ 34 | .. automodule:: machin.parallel.pool 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | queue 40 | +++++++++++++++++ 41 | .. 
automodule:: machin.parallel.queue 42 | :members: 43 | :undoc-members: 44 | :show-inheritance: -------------------------------------------------------------------------------- /machin/frame/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from .base import TorchFramework 3 | 4 | from .dqn import DQN 5 | from .dqn_per import DQNPer 6 | from .rainbow import RAINBOW 7 | 8 | from .ddpg import DDPG 9 | from .hddpg import HDDPG 10 | from .td3 import TD3 11 | from .ddpg_per import DDPGPer 12 | 13 | from .a2c import A2C 14 | from .ppo import PPO 15 | from .trpo import TRPO 16 | from .sac import SAC 17 | 18 | from .maddpg import MADDPG 19 | 20 | from .gail import GAIL 21 | 22 | 23 | from .a3c import A3C 24 | from .apex import DQNApex, DDPGApex 25 | from .impala import IMPALA 26 | from .ars import ARS 27 | 28 | 29 | __all__ = [ 30 | "TorchFramework", 31 | "DQN", 32 | "DQNPer", 33 | "RAINBOW", 34 | "DDPG", 35 | "HDDPG", 36 | "TD3", 37 | "DDPGPer", 38 | "A2C", 39 | "A3C", 40 | "PPO", 41 | "TRPO", 42 | "SAC", 43 | "DQNApex", 44 | "DDPGApex", 45 | "IMPALA", 46 | "ARS", 47 | "MADDPG", 48 | "GAIL", 49 | ] 50 | -------------------------------------------------------------------------------- /test/util_create_ma_env.py: -------------------------------------------------------------------------------- 1 | from multiagent.environment import MultiAgentEnv 2 | import multiagent.scenarios as scenarios 3 | import os 4 | 5 | _root_dir = os.path.dirname(scenarios.__file__) 6 | _all_files = [ 7 | f.split(".")[0] 8 | for f in os.listdir(_root_dir) 9 | if (not f.startswith("__") and f.endswith(".py")) 10 | ] 11 | 12 | 13 | def all_envs(): 14 | return _all_files 15 | 16 | 17 | def create_env(env_name): 18 | if env_name not in all_envs(): 19 | raise RuntimeError("Invalid multi-agent environment: " + env_name) 20 | # load scenario from script 21 | scenario = scenarios.load(env_name + ".py").Scenario() 22 | # create world 23 | world = scenario.make_world() 24 | # create multiagent environment 25 | env = MultiAgentEnv( 26 | world, 27 | scenario.reset_world, 28 | scenario.reward, 29 | scenario.observation, 30 | info_callback=None, 31 | shared_viewer=False, 32 | ) 33 | return env 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 
39 | -------------------------------------------------------------------------------- /docs/source/api/machin.auto.rst: -------------------------------------------------------------------------------- 1 | machin.auto 2 | ============= 3 | 4 | ``machin.auto`` provides utilities to simplify your programming with PyTorch Lightning. 5 | 6 | envs 7 | +++++++++++++++++ 8 | .. automodule:: machin.auto.envs 9 | :members: 10 | :undoc-members: 11 | :member-order: bysource 12 | 13 | config 14 | +++++++++++++++++ 15 | .. automodule:: machin.auto.config 16 | :members: 17 | :undoc-members: 18 | :member-order: bysource 19 | 20 | dataset 21 | +++++++++++++++++ 22 | .. automodule:: machin.auto.dataset 23 | :members: 24 | :undoc-members: 25 | :member-order: bysource 26 | 27 | launcher 28 | +++++++++++++++++ 29 | .. automodule:: machin.auto.launcher 30 | :members: 31 | :undoc-members: 32 | :member-order: bysource 33 | 34 | pl_logger 35 | +++++++++++++++++ 36 | .. automodule:: machin.auto.pl_logger 37 | :members: 38 | :undoc-members: 39 | :member-order: bysource 40 | 41 | pl_plugin 42 | +++++++++++++++++ 43 | .. automodule:: machin.auto.pl_plugin 44 | :members: 45 | :undoc-members: 46 | :member-order: bysource -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Iffi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 17 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 18 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 19 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 20 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE 21 | OR OTHER DEALINGS IN THE SOFTWARE.
-------------------------------------------------------------------------------- /examples/tutorials/as_fast_as_lightning/programmatic/simple/main.py: -------------------------------------------------------------------------------- 1 | from machin.auto.config import ( 2 | generate_algorithm_config, 3 | generate_env_config, 4 | generate_training_config, 5 | launch, 6 | ) 7 | 8 | import torch as t 9 | import torch.nn as nn 10 | 11 | 12 | class SomeQNet(nn.Module): 13 | def __init__(self, state_dim, action_num): 14 | super().__init__() 15 | 16 | self.fc1 = nn.Linear(state_dim, 16) 17 | self.fc2 = nn.Linear(16, 16) 18 | self.fc3 = nn.Linear(16, action_num) 19 | 20 | def forward(self, state): 21 | a = t.relu(self.fc1(state)) 22 | a = t.relu(self.fc2(a)) 23 | return self.fc3(a) 24 | 25 | 26 | if __name__ == "__main__": 27 | config = generate_algorithm_config("DQN") 28 | config = generate_env_config("openai_gym", config) 29 | config = generate_training_config( 30 | root_dir="trial", episode_per_epoch=10, max_episodes=10000, config=config 31 | ) 32 | config["frame_config"]["models"] = ["SomeQNet", "SomeQNet"] 33 | config["frame_config"]["model_kwargs"] = [{"state_dim": 4, "action_num": 2}] * 2 34 | launch(config) 35 | -------------------------------------------------------------------------------- /examples/tutorials/parallel_distributed/mth_event.py: -------------------------------------------------------------------------------- 1 | from machin.parallel.event import * 2 | from machin.parallel.thread import Thread 3 | import time 4 | 5 | event1 = Event() 6 | event2 = Event() 7 | event3 = Event() 8 | 9 | # wait() will block until its value might have changed (due to a sub event) 10 | # wait() returns a bool value 11 | 12 | 13 | def test1(): 14 | global event1, event2, event3 15 | event = OrEvent(event1, event2, event3) 16 | while not event.wait(): 17 | continue 18 | # will print if any one of these events is set 19 | print("hello1") 20 | 21 | 22 | def test2(): 23 | global event1, event2, event3 24 | event = AndEvent(AndEvent(event1, event3), event2) 25 | while not event.wait(): 26 | continue 27 | # will print if event1, event2 and event3 are all set 28 | print("hello2") 29 | 30 | 31 | if __name__ == "__main__": 32 | t1 = Thread(target=test1) 33 | t2 = Thread(target=test2) 34 | t1.start() 35 | t2.start() 36 | print("set event1") 37 | event1.set() 38 | 39 | time.sleep(1) 40 | print("set event2") 41 | event2.set() 42 | print("set event3") 43 | event3.set() 44 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /machin/parallel/distributed/__init__.py: --------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | try:
4 |     from ._world import (
5 |         World,
6 |         CollectiveGroup,
7 |         RpcGroup,
8 |         get_world,
9 |         get_cur_rank,
10 |         get_cur_name,
11 |         is_world_initialized,
12 |         debug_with_process,
13 |     )
14 | 
15 |     from . import _world as world
16 | except Exception as e:
17 |     warnings.warn(
18 |         f"""
19 | 
20 | Importing world failed.
21 | Exception: {str(e)}
22 | This might be because you are using a platform other than Linux.
23 | All exported symbols will be set to `None`, so please do not use
24 | any of the distributed framework features.
25 | """
26 |     )
27 |     World = None
28 |     CollectiveGroup = None
29 |     RpcGroup = None
30 |     get_world = None
31 |     get_cur_rank = None
32 |     get_cur_name = None
33 |     is_world_initialized = None
34 |     debug_with_process = None
35 | 
36 | __all__ = [
37 |     "World",
38 |     "CollectiveGroup",
39 |     "RpcGroup",
40 |     "get_world",
41 |     "get_cur_rank",
42 |     "get_cur_name",
43 |     "is_world_initialized",
44 |     "debug_with_process",
45 | ]
46 | 
-------------------------------------------------------------------------------- /test/auto/test_dataset.py: --------------------------------------------------------------------------------
1 | from machin.auto.dataset import determine_precision, DatasetResult
2 | import pytest
3 | import torch as t
4 | import torch.nn as nn
5 | 
6 | 
7 | class QNet(nn.Module):
8 |     def __init__(self, state_dim, action_num):
9 |         super().__init__()
10 | 
11 |         self.fc1 = nn.Linear(state_dim, 16)
12 |         self.fc2 = nn.Linear(16, 16)
13 |         self.fc3 = nn.Linear(16, action_num)
14 | 
15 |     def forward(self, state):
16 |         a = t.relu(self.fc1(state))
17 |         a = t.relu(self.fc2(a))
18 |         return self.fc3(a)
19 | 
20 | 
21 | def test_determine_precision():
22 |     assert determine_precision([QNet(10, 2)]) == t.float32
23 |     mixed_qnet = QNet(10, 2)
24 |     mixed_qnet.fc2 = mixed_qnet.fc2.type(t.float64)
25 |     with pytest.raises(RuntimeError, match="Multiple data types of parameters"):
26 |         determine_precision([mixed_qnet])
27 | 
28 | 
29 | class TestDatasetResult:
30 |     def test_add_observation(self):
31 |         dr = DatasetResult()
32 |         dr.add_observation({})
33 |         assert len(dr.observations) == 1
34 |         assert len(dr) == 1
35 | 
36 |     def test_add_log(self):
37 |         dr = DatasetResult()
38 |         dr.add_log({})
39 |         assert len(dr.logs) == 1
40 | 
-------------------------------------------------------------------------------- /machin/utils/logging.py: --------------------------------------------------------------------------------
1 | """
2 | Attributes:
3 |     default_logger: The default global logger.
4 | 
5 | TODO: maybe add logging utilities for distributed scenarios?
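
Example (a minimal usage sketch added for illustration; both names are
the module attributes defined below)::

    from machin.utils.logging import default_logger, fake_logger

    default_logger.info("colored, timestamped output")
    fake_logger.info("accepted but silently discarded")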
6 | """ 7 | import colorlog 8 | from logging import INFO 9 | 10 | 11 | class FakeLogger: 12 | def setLevel(self, level): 13 | pass 14 | 15 | def debug(self, msg, *args, **kwargs): 16 | pass 17 | 18 | def info(self, msg, *args, **kwargs): 19 | pass 20 | 21 | def warning(self, msg, *args, **kwargs): 22 | pass 23 | 24 | def warn(self, msg, *args, **kwargs): 25 | pass 26 | 27 | def error(self, msg, *args, **kwargs): 28 | pass 29 | 30 | def exception(self, msg, *args, exc_info=True, **kwargs): 31 | pass 32 | 33 | def critical(self, msg, *args, **kwargs): 34 | pass 35 | 36 | def log(self, level, msg, *args, **kwargs): 37 | pass 38 | 39 | 40 | _default_handler = colorlog.StreamHandler() 41 | _default_handler.setFormatter( 42 | colorlog.ColoredFormatter( 43 | "%(log_color)s[%(asctime)s] <%(levelname)s>:%(name)s:%(message)s" 44 | ) 45 | ) 46 | 47 | default_logger = colorlog.getLogger("default_logger") 48 | default_logger.addHandler(_default_handler) 49 | default_logger.setLevel(INFO) 50 | fake_logger = FakeLogger() 51 | -------------------------------------------------------------------------------- /machin/parallel/exception.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | 4 | class RemoteTraceback(Exception): # pragma: no cover 5 | """ 6 | Remote traceback, rebuilt by ``_rebuild_exc`` from pickled original 7 | traceback ExceptionWithTraceback, should be thrown on the master 8 | side. 9 | """ 10 | 11 | def __init__(self, tb): 12 | self.tb = tb 13 | 14 | def __str__(self): 15 | return self.tb 16 | 17 | 18 | def _rebuild_exc(exc, tb): 19 | exc.__cause__ = RemoteTraceback(tb) 20 | return exc 21 | 22 | 23 | class ExceptionWithTraceback: # pragma: no cover 24 | def __init__(self, exc: Exception, tb: str = None): 25 | """ 26 | This exception class is used by slave processes to capture 27 | exceptions thrown during execution and send back throw queues 28 | to their master. 29 | 30 | Args: 31 | exc: Your exception. 32 | tb: An optional traceback, by default is is set to 33 | ``exc.__traceback__`` 34 | """ 35 | if tb is None: 36 | tb = exc.__traceback__ 37 | tb = traceback.format_exception(type(exc), exc, tb) 38 | tb = "".join(tb) 39 | self.exc = exc 40 | self.tb = f'\n"""\n{tb}"""' 41 | 42 | def __reduce__(self): 43 | # Used by pickler 44 | return _rebuild_exc, (self.exc, self.tb) 45 | -------------------------------------------------------------------------------- /test/data/archive.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import datetime 4 | import torch as t 5 | 6 | 7 | class Archive: 8 | def __init__(self, path=None, find_in=None, match=None): 9 | if path is not None: 10 | self.path = path 11 | elif find_in is not None and match is not None: 12 | for f in os.listdir(find_in): 13 | if ( 14 | os.path.isfile(os.path.join(find_in, f)) 15 | and re.match(match, f) is not None 16 | ): 17 | self.path = os.path.join(find_in, f) 18 | break 19 | else: 20 | raise ValueError(f"Could not find a file in {find_in} matching {match}") 21 | else: 22 | raise ValueError( 23 | "You can either specify a path, or a find path and match pattern." 
24 |             )
25 |         self.data = {}
26 | 
27 |     def add_item(self, key, obj):
28 |         self.data[key] = obj
29 |         return self
30 | 
31 |     def load(self):
32 |         self.data = t.load(self.path)
33 |         return self
34 | 
35 |     def save(self):
36 |         t.save(self.data, self.path, pickle_protocol=3)
37 |         return self
38 | 
39 |     def item(self, key):
40 |         return self.data[key]
41 | 
42 | 
43 | def get_time_string():
44 |     return datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
45 | 
-------------------------------------------------------------------------------- /docs/source/index.rst: --------------------------------------------------------------------------------
1 | .. Machin documentation master file, created by
2 |    sphinx-quickstart on Mon Jun 1 14:47:38 2020.
3 |    You can adapt this file completely to your liking, but it should at least
4 |    contain the root `toctree` directive.
5 | 
6 | ==================
7 | Welcome
8 | ==================
9 | Welcome to the main documentation of the **Machin** library.
10 | 
11 | ----
12 | 
13 | About
14 | ++++++++++++++++++
15 | .. toctree::
16 |    :maxdepth: 1
17 | 
18 |    about.rst
19 | 
20 | Installation
21 | ++++++++++++++++++
22 | Machin is hosted on `PyPI `_; currently it
23 | requires:
24 | 
25 | 1. python >= 3.6
26 | 2. torch >= 1.6.0
27 | 
28 | If you are using pip to manage your python packages, you may directly type::
29 | 
30 |     pip install machin
31 | 
32 | If you are using conda to manage your python packages, we suggest creating a
33 | virtual environment first, to prevent pip from changing your packages without
34 | letting conda know::
35 | 
36 |     conda create -n some_env pip
37 |     conda activate some_env
38 |     pip install machin
39 | 
40 | Tutorials and examples
41 | ++++++++++++++++++++++
42 | .. toctree::
43 |    :maxdepth: 1
44 | 
45 |    tutorials/index.rst
46 |    advance/index.rst
47 | 
48 | 
49 | API
50 | ++++++++++++++++++
51 | .. toctree::
52 |    :maxdepth: 2
53 | 
54 |    api/index.rst
55 | 
56 | 
57 | Indices and tables
58 | ++++++++++++++++++
59 | 
60 | * :ref:`genindex`
61 | * :ref:`modindex`
62 | * :ref:`search`
63 | 
-------------------------------------------------------------------------------- /docs/source/about.rst: --------------------------------------------------------------------------------
1 | About
2 | ==================================
3 | .. toctree::
4 |    :maxdepth: 1
5 |    :caption: Contents:
6 | 
7 | Machin is a reinforcement learning library based purely on PyTorch.
8 | It is designed with three things in mind:
9 | 
10 | 1. **Easy to understand.**
11 | 2. **Easy to extend.**
12 | 3. **Easy to reuse.**
13 | 
14 | The first goal is achieved through clear structural design, robust documentation,
15 | and concise descriptions of use cases. The second goal is achieved through
16 | adding an extra layer upon the basic APIs provided in the distributed module of
17 | PyTorch; this layer offers additional fault tolerance mechanisms and
18 | eliminates hassles occurring in distributed programming. The last goal is
19 | the result of modular designs, careful API arrangements, and experience
20 | gathered from other similar projects.
21 | 
22 | Compared to other versatile and powerful reinforcement learning frameworks,
23 | Machin tries to offer a pleasant programming experience, smoothing out
24 | as many obstacles involved in reinforcement learning and distributed
25 | programming as possible.
Some essential functions, such as automated tuning and
26 | neural architecture search, are not offered in this package; we strongly
27 | recommend you take a look at these amazing projects and take a piggyback ride:
28 | 
29 | * `ray tune `_
30 | * `tpot `_
31 | * `nni `_
-------------------------------------------------------------------------------- /test/frame/noise/test_generator.py: --------------------------------------------------------------------------------
1 | from machin.frame.noise.generator import (
2 |     NormalNoiseGen,
3 |     UniformNoiseGen,
4 |     ClippedNormalNoiseGen,
5 |     OrnsteinUhlenbeckNoiseGen,
6 | )
7 | 
8 | import torch as t
9 | 
10 | 
11 | class TestAllNoiseGen:
12 |     ########################################################################
13 |     # Test for all noise generators
14 |     ########################################################################
15 |     def test_normal_noise_gen(self, pytestconfig):
16 |         noise_gen = NormalNoiseGen([1, 2])
17 |         noise_gen()
18 |         noise_gen(pytestconfig.getoption("gpu_device"))
19 |         str(noise_gen)
20 | 
21 |     def test_clipped_normal_noise_gen(self, pytestconfig):
22 |         noise_gen = ClippedNormalNoiseGen([1, 2])
23 |         noise_gen()
24 |         noise_gen(pytestconfig.getoption("gpu_device"))
25 |         str(noise_gen)
26 | 
27 |     def test_uniform_noise_gen(self, pytestconfig):
28 |         noise_gen = UniformNoiseGen([1, 2])
29 |         noise_gen()
30 |         noise_gen(pytestconfig.getoption("gpu_device"))
31 |         str(noise_gen)
32 | 
33 |     def test_ou_noise_gen(self, pytestconfig):
34 |         noise_gen = OrnsteinUhlenbeckNoiseGen([1, 2])
35 |         noise_gen2 = OrnsteinUhlenbeckNoiseGen([1, 2], x0=t.ones([1, 2]))
36 |         noise_gen()
37 |         noise_gen.reset()
38 |         noise_gen2.reset()
39 |         noise_gen(pytestconfig.getoption("gpu_device"))
40 |         str(noise_gen)
41 | 
-------------------------------------------------------------------------------- /machin/utils/learning_rate.py: --------------------------------------------------------------------------------
1 | """
2 | This module is the place for all learning rate functions. Currently, only
3 | manual learning rate changes according to the global step are implemented.
4 | """
5 | from typing import List, Tuple
6 | from logging import Logger
7 | 
8 | 
9 | def gen_learning_rate_func(lr_map: List[Tuple[int, float]], logger: Logger = None):
10 |     """
11 |     Example::
12 | 
13 |         from torch.optim.lr_scheduler import LambdaLR
14 | 
15 |         # 0 <= step < 200000, lr=1e-3, 200000 <= step, lr=3e-4
16 |         lr_func = gen_learning_rate_func([(0, 1e-3), (200000, 3e-4)],)
17 |         lr_sch = LambdaLR(optimizer, lr_func)
18 | 
19 |     Args:
20 |         lr_map: A 2D learning rate map. The first element of each row is the
21 |             step, the second is the learning rate.
22 |         logger: A logger used to log the current learning rate.
23 | 
24 |     Returns:
25 |         A learning rate generation function with signature `lr_gen(step)->lr`;
26 |         it accepts an int and returns a float. Use it in your PyTorch LR scheduler.
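
    A worked illustration for the example map above (added for clarity; the
    values follow directly from the mapping rule, and any step below the
    first entry also falls through to the last entry)::

        lr_func(0)       # 1e-3
        lr_func(199999)  # 1e-3
        lr_func(200000)  # 3e-4, and 3e-4 for every larger step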
27 | """ 28 | 29 | def learning_rate_func(step): 30 | for i in range(len(lr_map) - 1): 31 | if lr_map[i][0] <= step < lr_map[i + 1][0]: 32 | if logger is not None: 33 | logger.info(f"Current learning rate:{lr_map[i][1]}") 34 | return lr_map[i][1] 35 | if logger is not None: 36 | logger.info(f"Current learning rate:{lr_map[-1][1]}") 37 | return lr_map[-1][1] 38 | 39 | return learning_rate_func 40 | -------------------------------------------------------------------------------- /docs/source/api/machin.utils.rst: -------------------------------------------------------------------------------- 1 | machin.utils 2 | ============= 3 | 4 | checker 5 | +++++++++++++++++ 6 | .. automodule:: machin.utils.checker 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | 11 | conf 12 | +++++++++++++++++ 13 | .. automodule:: machin.utils.conf 14 | :members: 15 | :undoc-members: 16 | :show-inheritance: 17 | 18 | helper_classes 19 | +++++++++++++++++ 20 | .. automodule:: machin.utils.helper_classes 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | learning_rate 26 | +++++++++++++++++ 27 | .. automodule:: machin.utils.learning_rate 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | 32 | logging 33 | +++++++++++++++++ 34 | .. automodule:: machin.utils.logging 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | media 40 | +++++++++++++++++ 41 | .. automodule:: machin.utils.media 42 | :members: 43 | :undoc-members: 44 | :show-inheritance: 45 | 46 | prepare 47 | +++++++++++++++++ 48 | .. automodule:: machin.utils.prepare 49 | :members: 50 | :undoc-members: 51 | :show-inheritance: 52 | 53 | save_env 54 | +++++++++++++++++ 55 | .. automodule:: machin.utils.save_env 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | tensor_board 61 | +++++++++++++++++ 62 | .. automodule:: machin.utils.tensor_board 63 | :members: 64 | :undoc-members: 65 | :show-inheritance: 66 | 67 | visualize 68 | +++++++++++++++++ 69 | .. 
automodule:: machin.utils.visualize
70 |     :members:
71 |     :undoc-members:
72 |     :show-inheritance:
73 | 
74 | 
-------------------------------------------------------------------------------- /examples/tutorials/parallel_distributed/assign.py: --------------------------------------------------------------------------------
1 | from machin.parallel.assigner import ModelAssigner
2 | from machin.model.nets.resnet import ResNet
3 | 
4 | if __name__ == "__main__":
5 |     models = [
6 |         ResNet(in_planes=9, depth=152, out_planes=1, out_pool_size=[20, 20])
7 |         for _ in range(4)
8 |     ]
9 | 
10 |     # create 4 ResNet as example models
11 |     assigner = ModelAssigner(
12 |         models=models,
13 |         # `model_connection` (A, B): 1 means the amount of data transmitted
14 |         # from A to B is 1
15 |         model_connection={(0, 1): 1, (2, 3): 1},
16 |         # available devices
17 |         devices=["cuda:0", "cpu"],
18 |         model_size_multiplier=1,
19 |         max_mem_ratio=0.7,
20 |         # computing power compared to GPU, 0 means none, 1 means same
21 |         cpu_weight=1,
22 |         # larger connection weight will force assigner to place
23 |         # models with larger quantities of data transmission on the
24 |         # same device
25 |         connection_weight=1e3,
26 |         # try this and see what will happen
27 |         # connection_weight=1e1,
28 |         size_match_weight=1e-2,
29 |         complexity_match_weight=10,
30 |         entropy_weight=1,
31 |         iterations=500,
32 |         update_rate=0.01,
33 |         gpu_gpu_distance=1,
34 |         cpu_gpu_distance=10,
35 |         move_models=False,
36 |     )
37 |     real_assignment = [str(dev) for dev in assigner.assignment]
38 |     # should be "cuda:0", "cuda:0", "cpu", "cpu"
39 |     # or "cpu", "cpu", "cuda:0", "cuda:0"
40 |     print(f"Assignment: {real_assignment}")
41 | 
-------------------------------------------------------------------------------- /test/util_fixtures.py: --------------------------------------------------------------------------------
1 | from test.data.all import generate_all, get_all
2 | import random
3 | import numpy as np
4 | import torch as t
5 | import pytest
6 | 
7 | 
8 | @pytest.fixture()
9 | def gpu(pytestconfig):
10 |     dev = pytestconfig.getoption("gpu_device")
11 |     if dev is not None and dev.startswith("cuda"):
12 |         return dev
13 |     pytest.skip(f"Requiring GPU but provided `gpu_device` is {dev}")
14 | 
15 | 
16 | @pytest.fixture(params=["cpu", "gpu"])
17 | def device(pytestconfig, request):
18 |     if request.param == "cpu":
19 |         return "cpu"
20 |     else:
21 |         dev = pytestconfig.getoption("gpu_device")
22 |         if dev is not None and dev.startswith("cuda"):
23 |             return dev
24 |         pytest.skip(f"Requiring GPU but provided `gpu_device` is {dev}")
25 | 
26 | 
27 | @pytest.fixture(params=["float32", "float64"])
28 | def dtype(request):
29 |     if request.param == "float32":
30 |         return t.float32
31 |     return t.float64
32 | 
33 | 
34 | @pytest.fixture()
35 | def mp_tmpdir(tmpdir):
36 |     """
37 |     For multiprocessing, sharing the same tmpdir across all processes
38 |     """
39 |     return tmpdir.make_numbered_dir()
40 | 
41 | 
42 | @pytest.fixture(scope="session")
43 | def archives():
44 |     # prepare all test data archives
45 |     generate_all()
46 |     return get_all()
47 | 
48 | 
49 | @pytest.fixture(scope="session", autouse=True)
50 | def fix_random():
51 |     t.manual_seed(0)
52 |     np.random.seed(0)
53 |     random.seed(0)
54 |     return None
55 | 
56 | 
57 | __all__ = ["gpu", "device", "dtype", "mp_tmpdir", "archives", "fix_random"]
58 | 
-------------------------------------------------------------------------------- /examples/tutorials/as_fast_as_lightning/automatic/config.json:
-------------------------------------------------------------------------------- 1 | { 2 | "early_stopping_patience": 3, 3 | "env": "openai_gym", 4 | "episode_per_epoch": 10, 5 | "frame": "DQN", 6 | "frame_config": { 7 | "batch_size": 100, 8 | "criterion": "MSELoss", 9 | "criterion_args": [], 10 | "criterion_kwargs": {}, 11 | "discount": 0.99, 12 | "epsilon_decay": 0.9999, 13 | "gradient_max": Infinity, 14 | "learning_rate": 0.001, 15 | "lr_scheduler": null, 16 | "lr_scheduler_args": null, 17 | "lr_scheduler_kwargs": null, 18 | "mode": "double", 19 | "model_args": [ 20 | [], 21 | [] 22 | ], 23 | "model_kwargs": [ 24 | {"state_dim": 4, "action_num": 2}, 25 | {"state_dim": 4, "action_num": 2} 26 | ], 27 | "models": [ 28 | "qnet.QNet", 29 | "qnet.QNet" 30 | ], 31 | "optimizer": "Adam", 32 | "replay_buffer": null, 33 | "replay_device": "cpu", 34 | "replay_size": 500000, 35 | "update_rate": 0.005, 36 | "update_steps": null, 37 | "visualize": false, 38 | "visualize_dir": "" 39 | }, 40 | "gpus": [ 41 | 0 42 | ], 43 | "max_episodes": 10000, 44 | "root_dir": "trial", 45 | "test_env_config": { 46 | "act_kwargs": {}, 47 | "env_name": "CartPole-v1", 48 | "render_every_episode": 100 49 | }, 50 | "train_env_config": { 51 | "act_kwargs": {}, 52 | "env_name": "CartPole-v1", 53 | "render_every_episode": 100 54 | } 55 | } -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/bin/interactive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os,sys 3 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 4 | import argparse 5 | 6 | from multiagent.environment import MultiAgentEnv 7 | from multiagent.policy import InteractivePolicy 8 | import multiagent.scenarios as scenarios 9 | 10 | if __name__ == '__main__': 11 | # parse arguments 12 | parser = argparse.ArgumentParser(description=None) 13 | parser.add_argument('-s', '--scenario', default='simple.py', help='Path of the scenario Python script.') 14 | args = parser.parse_args() 15 | 16 | # load scenario from script 17 | scenario = scenarios.load(args.scenario).Scenario() 18 | # create world 19 | world = scenario.make_world() 20 | # create multiagent environment 21 | env = MultiAgentEnv(world, scenario.reset_world, scenario.reward, scenario.observation, info_callback=None, shared_viewer = False) 22 | # render call to create viewer window (necessary only for interactive policies) 23 | env.render() 24 | # create interactive policies for each agent 25 | policies = [InteractivePolicy(env,i) for i in range(env.n)] 26 | # execution loop 27 | obs_n = env.reset() 28 | while True: 29 | # query for action from each agent's policy 30 | act_n = [] 31 | for i, policy in enumerate(policies): 32 | act_n.append(policy.action(obs_n[i])) 33 | # step environment 34 | obs_n, reward_n, done_n, _ = env.step(act_n) 35 | # render all agent views 36 | env.render() 37 | # display rewards 38 | #for agent in env.world.agents: 39 | # print(agent.name + " reward: %0.3f" % env._get_reward(agent)) 40 | -------------------------------------------------------------------------------- /.circleci-archive/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | win: circleci/windows@2.2.0 5 | 6 | jobs: 7 | win-test: 8 | executor: win/default 9 | steps: 10 | - checkout 11 | - restore_cache: 12 | key: deps-{{ .Branch }} 13 | - run: 14 | name: install dependencies 15 | command: | 16 | python 
--version
17 |             pip install virtualenv
18 |             virtualenv venv
19 |             .\venv\Scripts\activate
20 |             pip install .
21 |             pip install .\test_lib\multiagent-particle-envs\
22 |             pip install gym
23 |             pip install mock pytest==6.0.0 pytest-cov==2.10.0 allure-pytest==2.8.16 pytest-html==1.22.1 pytest-repeat==0.8.0
24 |       - save_cache:
25 |           key: deps-{{ .Branch }}
26 |           paths:
27 |             - "venv"
28 |       - run:
29 |           name: make test directory
30 |           command: mkdir test_results -Force
31 |       - run:
32 |           name: test API
33 |           command:
34 |             # each step runs in its own shell and venv activate does
35 |             # not persist.
36 |             >-
37 |             .\venv\Scripts\activate &&
38 |             python -m pytest
39 |             -s --assert=plain
40 |             --cov-report term-missing
41 |             --cov=machin
42 |             -k "not full_train"
43 |             -o junit_family=xunit1
44 |             --junitxml test_results\test_api.xml
45 |             --cov-report xml:test_results\cov_report.xml
46 |             --html=test_results\test_api.html
47 |             --self-contained-html
48 |             .\test\
49 |       - store_artifacts:
50 |           path: test_results\test_api.xml
51 |       - store_test_results:
52 |           path: test_results\
53 | 
54 | workflows:
55 |   main:
56 |     jobs:
57 |       - win-test
58 | 
-------------------------------------------------------------------------------- /examples/tutorials/as_fast_as_lightning/programmatic/nni/nni_main.py: --------------------------------------------------------------------------------
1 | from machin.auto.config import (
2 |     generate_algorithm_config,
3 |     generate_env_config,
4 |     generate_training_config,
5 |     launch,
6 | )
7 | from pytorch_lightning.callbacks import Callback
8 | 
9 | import nni
10 | import torch as t
11 | import torch.nn as nn
12 | 
13 | 
14 | class SomeQNet(nn.Module):
15 |     def __init__(self, state_dim, action_num):
16 |         super().__init__()
17 | 
18 |         self.fc1 = nn.Linear(state_dim, 16)
19 |         self.fc2 = nn.Linear(16, 16)
20 |         self.fc3 = nn.Linear(16, action_num)
21 | 
22 |     def forward(self, state):
23 |         a = t.relu(self.fc1(state))
24 |         a = t.relu(self.fc2(a))
25 |         return self.fc3(a)
26 | 
27 | 
28 | class InspectCallback(Callback):
29 |     def __init__(self):
30 |         self.total_reward = 0
31 | 
32 |     def on_train_batch_end(
33 |         self, trainer, pl_module, outputs, batch, _batch_idx, _dataloader_idx
34 |     ) -> None:
35 |         for l in batch[0].logs:
36 |             if "total_reward" in l:
37 |                 self.total_reward = l["total_reward"]
38 | 
39 | 
40 | if __name__ == "__main__":
41 |     param = nni.get_next_parameter()
42 |     cb = InspectCallback()
43 |     while param:
44 |         config = generate_algorithm_config("DQN")
45 |         config = generate_env_config("openai_gym", config)
46 |         config = generate_training_config(
47 |             root_dir="trial", episode_per_epoch=10, max_episodes=10000, config=config
48 |         )
49 |         config["frame_config"]["models"] = ["SomeQNet", "SomeQNet"]
50 |         config["frame_config"]["model_kwargs"] = [{"state_dim": 4, "action_num": 2}] * 2
51 |         config["frame_config"]["learning_rate"] = param["lr"]
52 |         config["frame_config"]["update_rate"] = param["upd"]
53 |         launch(config, pl_callbacks=[cb])
54 |         # we use total reward as "accuracy"
55 |         nni.report_final_result(cb.total_reward)
56 |         param = nni.get_next_parameter()
57 | 
-------------------------------------------------------------------------------- /test/parallel/test_pickle.py: --------------------------------------------------------------------------------
1 | from multiprocessing import Pipe, get_context
2 | from machin.parallel.pickle import dumps, loads
3 | from machin.parallel.process import Process
4 | from test.util_platforms import linux_only
5 | 
6 | import torch as t
7 | 
8 | 
9 | def
subproc_test_dumps_copy_tensor(pipe): 10 | pipe.send(dumps(t.zeros([10]), copy_tensor=True)) 11 | 12 | 13 | def test_dumps_copy_tensor(): 14 | pipe_0, pipe_1 = Pipe(duplex=True) 15 | ctx = get_context("spawn") 16 | process_0 = Process(target=subproc_test_dumps_copy_tensor, args=(pipe_0,), ctx=ctx) 17 | process_0.start() 18 | while process_0.is_alive(): 19 | process_0.watch() 20 | assert t.all(loads(pipe_1.recv()) == t.zeros([10])) 21 | process_0.join() 22 | 23 | 24 | def subproc_test_dumps_not_copy_tensor(pipe): 25 | tensor = t.zeros([10]) 26 | tensor.share_memory_() 27 | pipe.send(dumps(tensor, copy_tensor=False)) 28 | 29 | 30 | @linux_only 31 | def test_dumps_not_copy_tensor(): 32 | pipe_0, pipe_1 = Pipe(duplex=True) 33 | ctx = get_context("fork") 34 | process_0 = Process( 35 | target=subproc_test_dumps_not_copy_tensor, args=(pipe_0,), ctx=ctx 36 | ) 37 | process_0.start() 38 | while process_0.is_alive(): 39 | process_0.watch() 40 | assert t.all(loads(pipe_1.recv()) == t.zeros([10])) 41 | process_0.join() 42 | 43 | 44 | def subproc_test_dumps_local_func(pipe): 45 | tensor = t.zeros([10]) 46 | tensor.share_memory_() 47 | 48 | def local_func(): 49 | nonlocal tensor 50 | return tensor 51 | 52 | pipe.send(dumps(local_func, copy_tensor=False)) 53 | 54 | 55 | @linux_only 56 | def test_dumps_local_func(): 57 | pipe_0, pipe_1 = Pipe(duplex=True) 58 | ctx = get_context("fork") 59 | process_0 = Process(target=subproc_test_dumps_local_func, args=(pipe_0,), ctx=ctx) 60 | process_0.start() 61 | while process_0.is_alive(): 62 | process_0.watch() 63 | assert t.all(loads(pipe_1.recv())() == t.zeros([10])) 64 | process_0.join() 65 | -------------------------------------------------------------------------------- /machin/parallel/thread.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | import threading 3 | 4 | 5 | class ThreadException(Exception): 6 | pass 7 | 8 | 9 | class Thread(threading.Thread): 10 | """ 11 | Enhanced thread with exception tracing. 
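
    Example (a minimal sketch of how the tracing surfaces in the parent
    thread; added for illustration)::

        def fail():
            raise ValueError("boom")

        thread = Thread(target=fail)
        thread.start()
        thread.join()
        thread.watch()  # raises ThreadException carrying the formatted traceback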
12 | """ 13 | 14 | def __init__( 15 | self, 16 | group=None, 17 | target=None, 18 | name=None, 19 | args=(), 20 | kwargs={}, 21 | cleaner=None, 22 | *, 23 | daemon=None, 24 | ): 25 | threading.Thread.__init__( 26 | self, 27 | group=group, 28 | target=target, 29 | name=name, 30 | args=args, 31 | kwargs=kwargs, 32 | daemon=daemon, 33 | ) 34 | self._cleaner = cleaner 35 | self._exception_str = "" 36 | self._has_exception = False 37 | 38 | @property 39 | def exception(self): 40 | if not self._has_exception: 41 | return None 42 | exc = ThreadException(self._exception_str) 43 | return exc 44 | 45 | def watch(self): 46 | if self._has_exception: 47 | raise self.exception 48 | 49 | @staticmethod 50 | def format_exceptions(exceptions): 51 | all_tb = "" 52 | for exc, i in zip(exceptions, range(len(exceptions))): 53 | tb = exc.__traceback__ 54 | tb = traceback.format_exception(type(exc), exc, tb) 55 | tb = "".join(tb) 56 | all_tb = all_tb + f"\nException {i}:\n{tb}" 57 | return all_tb 58 | 59 | def run(self): 60 | exc = [] 61 | try: 62 | super().run() 63 | except BaseException as e: 64 | exc.append(e) 65 | finally: 66 | if self._cleaner is not None: 67 | try: 68 | self._cleaner() 69 | except BaseException as e: 70 | exc.append(e) 71 | if exc: 72 | self._exception_str = self.format_exceptions(exc) 73 | self._has_exception = True 74 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/make_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for creating a multiagent environment with one of the scenarios listed 3 | in ./scenarios/. 4 | Can be called by using, for example: 5 | env = make_env('simple_speaker_listener') 6 | After producing the env object, can be used similarly to an OpenAI gym 7 | environment. 8 | 9 | A policy using this environment must output actions in the form of a list 10 | for all agents. Each element of the list should be a numpy array, 11 | of size (env.world.dim_p + env.world.dim_c, 1). Physical actions precede 12 | communication actions in this array. See environment.py for more details. 13 | """ 14 | 15 | def make_env(scenario_name, benchmark=False): 16 | ''' 17 | Creates a MultiAgentEnv object as env. This can be used similar to a gym 18 | environment by calling env.reset() and env.step(). 19 | Use env.render() to view the environment on the screen. 
20 | 
21 |     Input:
22 |         scenario_name   :   name of the scenario from ./scenarios/ to be loaded
23 |                             (without the .py extension)
24 |         benchmark       :   whether you want to produce benchmarking data
25 |                             (usually only done during evaluation)
26 | 
27 |     Some useful env properties (see environment.py):
28 |         .observation_space  :   Returns the observation space for each agent
29 |         .action_space       :   Returns the action space for each agent
30 |         .n                  :   Returns the number of Agents
31 |     '''
32 |     from multiagent.environment import MultiAgentEnv
33 |     import multiagent.scenarios as scenarios
34 | 
35 |     # load scenario from script
36 |     scenario = scenarios.load(scenario_name + ".py").Scenario()
37 |     # create world
38 |     world = scenario.make_world()
39 |     # create multiagent environment
40 |     if benchmark:
41 |         env = MultiAgentEnv(world, scenario.reset_world, scenario.reward, scenario.observation, scenario.benchmark_data)
42 |     else:
43 |         env = MultiAgentEnv(world, scenario.reset_world, scenario.reward, scenario.observation)
44 |     return env
45 | 
-------------------------------------------------------------------------------- /test/auto/_pl_plugin_runner.py: --------------------------------------------------------------------------------
1 | from machin.parallel.distributed import get_world, get_cur_rank
2 | from machin.utils.helper_classes import Object
3 | from torch.utils.data import DataLoader, TensorDataset
4 | import os
5 | import sys
6 | import pickle
7 | import torch as t
8 | import torch.nn as nn
9 | import pytorch_lightning as pl
10 | 
11 | # necessary to patch PL DDP plugins
12 | import machin.auto
13 | 
14 | 
15 | class NNModule(nn.Module):
16 |     def __init__(self):
17 |         super().__init__()
18 |         self.fc1 = nn.Linear(10, 10)
19 | 
20 |     def forward(self, x):
21 |         return t.sum(x)
22 | 
23 | 
24 | class ParallelModule(pl.LightningModule):
25 |     def __init__(self):
26 |         super().__init__()
27 |         self.nn_model = NNModule()
28 |         self.frame = Object({"optimizers": None, "lr_schedulers": None})
29 | 
30 |     def train_dataloader(self):
31 |         return DataLoader(
32 |             dataset=TensorDataset(t.ones([5, 10])), collate_fn=lambda x: x
33 |         )
34 | 
35 |     def training_step(self, batch, _batch_idx):
36 |         world_inited = get_world() is not None
37 |         model_inited = isinstance(self.nn_model, NNModule)
38 | 
39 |         if world_inited and get_cur_rank() == 0:
40 |             with open(os.environ["TEST_SAVE_PATH"], "wb") as f:
41 |                 pickle.dump([model_inited], f)
42 |         if not world_inited:
43 |             raise RuntimeError("World not initialized.")
44 |         return None
45 | 
46 |     def init_frame(self):
47 |         pass
48 | 
49 |     def configure_optimizers(self):
50 |         return None
51 | 
52 | 
53 | if __name__ == "__main__":
54 |     os.environ["WORLD_SIZE"] = "3"
55 |     os.environ["NODE_RANK"] = "0"
56 |     os.environ["LOCAL_RANK"] = sys.argv[2]
57 |     print(os.environ["TEST_SAVE_PATH"])
58 |     print(sys.argv[1])
59 |     trainer = pl.Trainer(
60 |         num_nodes=1,
61 |         num_processes=3,
62 |         limit_train_batches=1,
63 |         max_steps=1,
64 |         accelerator="ddp" if sys.argv[1] == "ddp" else "ddp_spawn",
65 |     )
66 |     assert trainer.distributed_backend == sys.argv[1]
67 |     model = ParallelModule()
68 |     trainer.fit(model)
69 | 
-------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent/policy.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from pyglet.window import key
3 | 
4 | # individual agent policy
5 | class Policy(object):
6 |     def __init__(self):
7 |         pass
8 |     def action(self, obs):
9 |         raise
NotImplementedError() 10 | 11 | # interactive policy based on keyboard input 12 | # hard-coded to deal only with movement, not communication 13 | class InteractivePolicy(Policy): 14 | def __init__(self, env, agent_index): 15 | super(InteractivePolicy, self).__init__() 16 | self.env = env 17 | # hard-coded keyboard events 18 | self.move = [False for i in range(4)] 19 | self.comm = [False for i in range(env.world.dim_c)] 20 | # register keyboard events with this environment's window 21 | env.viewers[agent_index].window.on_key_press = self.key_press 22 | env.viewers[agent_index].window.on_key_release = self.key_release 23 | 24 | def action(self, obs): 25 | # ignore observation and just act based on keyboard events 26 | if self.env.discrete_action_input: 27 | u = 0 28 | if self.move[0]: u = 1 29 | if self.move[1]: u = 2 30 | if self.move[2]: u = 4 31 | if self.move[3]: u = 3 32 | else: 33 | u = np.zeros(5) # 5-d because of no-move action 34 | if self.move[0]: u[1] += 1.0 35 | if self.move[1]: u[2] += 1.0 36 | if self.move[3]: u[3] += 1.0 37 | if self.move[2]: u[4] += 1.0 38 | if True not in self.move: 39 | u[0] += 1.0 40 | return np.concatenate([u, np.zeros(self.env.world.dim_c)]) 41 | 42 | # keyboard event callbacks 43 | def key_press(self, k, mod): 44 | if k==key.LEFT: self.move[0] = True 45 | if k==key.RIGHT: self.move[1] = True 46 | if k==key.UP: self.move[2] = True 47 | if k==key.DOWN: self.move[3] = True 48 | def key_release(self, k, mod): 49 | if k==key.LEFT: self.move[0] = False 50 | if k==key.RIGHT: self.move[1] = False 51 | if k==key.UP: self.move[2] = False 52 | if k==key.DOWN: self.move[3] = False 53 | -------------------------------------------------------------------------------- /test/utils/test_checker.py: -------------------------------------------------------------------------------- 1 | from machin.utils.checker import ( 2 | CheckError, 3 | check_shape, 4 | check_nan, 5 | check_model, 6 | mark_as_atom_module, 7 | mark_module_output, 8 | p_chk_nan, 9 | p_chk_range, 10 | ) 11 | from machin.utils.tensor_board import TensorBoard 12 | 13 | import pytest 14 | import torch as t 15 | import torch.nn as nn 16 | 17 | 18 | def test_check_shape(): 19 | with pytest.raises(CheckError, match="has invalid shape"): 20 | tensor = t.zeros([10, 10]) 21 | check_shape(tensor, [5, 5]) 22 | 23 | 24 | def test_check_nan(): 25 | with pytest.raises(CheckError, match="contains nan"): 26 | tensor = t.full([10, 10], float("NaN")) 27 | check_nan(tensor) 28 | 29 | 30 | class SubModule1(nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | self.fc1 = nn.Linear(5, 10) 34 | self.fc2 = nn.Linear(10, 20) 35 | mark_as_atom_module(self) 36 | mark_module_output(self, ["output1_sub1"]) 37 | 38 | def forward(self, x): 39 | return self.fc2(self.fc1(x)), None 40 | 41 | 42 | class SubModule2(nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | self.fc1 = nn.Linear(20, 20) 46 | 47 | def forward(self, x): 48 | return self.fc1(x) 49 | 50 | 51 | class CheckedModel(nn.Module): 52 | def __init__(self): 53 | super().__init__() 54 | self.sub1 = SubModule1() 55 | self.sub2 = SubModule2() 56 | 57 | def forward(self, x): 58 | return self.sub2(self.sub1(x)[0]) 59 | 60 | 61 | param_checked = False 62 | 63 | 64 | def param_check_hook(*_): 65 | global param_checked 66 | param_checked = True 67 | 68 | 69 | def test_check_model(): 70 | global param_checked 71 | board = TensorBoard() 72 | board.init() 73 | model = CheckedModel() 74 | cancel = check_model( 75 | board.writer, 76 | model, 77 | 
param_check_interval=1, 78 | param_check_hooks=(param_check_hook, p_chk_nan, p_chk_range), 79 | name="checked_model", 80 | ) 81 | output = model(t.ones([1, 5])) 82 | output.sum().backward() 83 | cancel() 84 | assert param_checked 85 | -------------------------------------------------------------------------------- /test/utils/test_conf.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from machin.utils.conf import ( 3 | Config, 4 | load_config_cmd, 5 | load_config_file, 6 | save_config, 7 | merge_config, 8 | ) 9 | from machin.utils.helper_classes import Object 10 | 11 | import os 12 | import json 13 | from unittest import mock 14 | 15 | 16 | def get_config(): 17 | c = Config() 18 | c.conf1 = 1 19 | c.conf2 = 2 20 | return c 21 | 22 | 23 | @mock.patch( 24 | "machin.utils.conf.argparse.ArgumentParser.parse_args", 25 | return_value=Object(data={"conf": ["conf1=2", "conf3=3"]}), 26 | ) 27 | def test_load_config_cmd(*_mock_classes): 28 | conf = load_config_cmd() 29 | assert conf["conf1"] == 2 30 | assert conf["conf2"] is None 31 | assert conf["conf3"] == 3 32 | 33 | conf = load_config_cmd(get_config()) 34 | # configs from commandline precedes configs from the config file 35 | assert conf["conf1"] == 2 36 | assert conf["conf2"] == 2 37 | assert conf["conf3"] == 3 38 | 39 | 40 | def test_load_config_file(tmpdir): 41 | tmp_dir = str(tmpdir.make_numbered_dir()) 42 | with open(join(tmp_dir, "conf.json"), "w") as config_file: 43 | json.dump({"conf1": 2, "conf3": 3}, config_file, sort_keys=True, indent=4) 44 | 45 | conf = load_config_file(join(tmp_dir, "conf.json")) 46 | assert conf["conf1"] == 2 47 | assert conf["conf2"] is None 48 | assert conf["conf3"] == 3 49 | 50 | conf = load_config_file(join(tmp_dir, "conf.json"), get_config()) 51 | assert conf["conf1"] == 2 52 | assert conf["conf2"] == 2 53 | assert conf["conf3"] == 3 54 | 55 | 56 | def test_save_config(tmpdir): 57 | conf = get_config() 58 | tmp_dir = str(tmpdir.make_numbered_dir()) 59 | save_config(conf, join(tmp_dir, "conf.json")) 60 | assert os.path.exists(join(tmp_dir, "conf.json")) 61 | 62 | 63 | def test_merge_config(): 64 | conf = get_config() 65 | conf = merge_config(conf, {"conf1": 2, "conf3": 3}) 66 | assert conf.conf1 == 2 67 | assert conf.conf2 == 2 68 | assert conf.conf3 == 3 69 | 70 | conf = get_config() 71 | conf2 = Config(conf1=2, conf3=3) 72 | conf = merge_config(conf, conf2) 73 | assert conf.conf1 == 2 74 | assert conf.conf2 == 2 75 | assert conf.conf3 == 3 76 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent/scenarios/simple.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from multiagent.core import World, Agent, Landmark 3 | from multiagent.scenario import BaseScenario 4 | 5 | class Scenario(BaseScenario): 6 | def make_world(self): 7 | world = World() 8 | # add agents 9 | world.agents = [Agent() for i in range(1)] 10 | for i, agent in enumerate(world.agents): 11 | agent.name = 'agent %d' % i 12 | agent.collide = False 13 | agent.silent = True 14 | # add landmarks 15 | world.landmarks = [Landmark() for i in range(1)] 16 | for i, landmark in enumerate(world.landmarks): 17 | landmark.name = 'landmark %d' % i 18 | landmark.collide = False 19 | landmark.movable = False 20 | # make initial conditions 21 | self.reset_world(world) 22 | return world 23 | 24 | def reset_world(self, world): 25 | # random properties for agents 
26 | for i, agent in enumerate(world.agents): 27 | agent.color = np.array([0.25,0.25,0.25]) 28 | # random properties for landmarks 29 | for i, landmark in enumerate(world.landmarks): 30 | landmark.color = np.array([0.75,0.75,0.75]) 31 | world.landmarks[0].color = np.array([0.75,0.25,0.25]) 32 | # set random initial states 33 | for agent in world.agents: 34 | agent.state.p_pos = np.random.uniform(-1,+1, world.dim_p) 35 | agent.state.p_vel = np.zeros(world.dim_p) 36 | agent.state.c = np.zeros(world.dim_c) 37 | for i, landmark in enumerate(world.landmarks): 38 | landmark.state.p_pos = np.random.uniform(-1,+1, world.dim_p) 39 | landmark.state.p_vel = np.zeros(world.dim_p) 40 | 41 | def reward(self, agent, world): 42 | dist2 = np.sum(np.square(agent.state.p_pos - world.landmarks[0].state.p_pos)) 43 | return -dist2 44 | 45 | def observation(self, agent, world): 46 | # get positions of all entities in this agent's reference frame 47 | entity_pos = [] 48 | for entity in world.landmarks: 49 | entity_pos.append(entity.state.p_pos - agent.state.p_pos) 50 | return np.concatenate([agent.state.p_vel] + entity_pos) 51 | -------------------------------------------------------------------------------- /test/auto/test_pl_logger.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from pytorch_lightning.loggers.base import DummyExperiment 3 | from machin.auto.pl_logger import LocalMediaLogger 4 | import os 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | class TestLocalMediaLogger: 9 | def test_all(self, tmpdir): 10 | tmp_dir = str(tmpdir.make_numbered_dir()) 11 | lm_logger = LocalMediaLogger(tmp_dir, tmp_dir) 12 | 13 | assert lm_logger.name == "LocalMediaLogger" 14 | assert isinstance(lm_logger.experiment, DummyExperiment) 15 | assert lm_logger.version == "0.1" 16 | 17 | # nothing happens 18 | lm_logger.log_hyperparams({"a": 0.1}) 19 | lm_logger.log_metrics({"b": 0.1}, step=1) 20 | lm_logger.save() 21 | lm_logger.finalize("") 22 | 23 | # test logging artifact 24 | artifact_path = str(os.path.join(tmp_dir, "test.txt")) 25 | new_artifact_path = str(os.path.join(tmp_dir, "test1.txt")) 26 | 27 | with open(artifact_path, "w") as f: 28 | f.write("1" * 1000) 29 | lm_logger.log_artifact(artifact_path) 30 | assert os.path.exists(artifact_path) 31 | os.remove(artifact_path) 32 | 33 | with open(artifact_path, "w") as f: 34 | f.write("1" * 1000) 35 | lm_logger.log_artifact(artifact_path, "test1.txt") 36 | assert os.path.exists(new_artifact_path) 37 | os.remove(new_artifact_path) 38 | 39 | # test logging image 40 | image_path = str(os.path.join(tmp_dir, "test_1.png")) 41 | 42 | # PIL Image test 43 | im = Image.new("RGB", (100, 100)) 44 | lm_logger.log_image("test", im, step=1) 45 | assert os.path.exists(image_path) 46 | os.remove(image_path) 47 | 48 | # Matplotlib figure test 49 | fig = plt.figure() 50 | lm_logger.log_image("test", fig, step=1) 51 | assert os.path.exists(image_path) 52 | os.remove(image_path) 53 | 54 | # Image file test 55 | source_path = str(os.path.join(tmp_dir, "x.jpg")) 56 | target_path = str(os.path.join(tmp_dir, "test_1.jpg")) 57 | im.save(source_path) 58 | lm_logger.log_image("test", source_path, step=1) 59 | assert os.path.exists(target_path) 60 | os.remove(target_path) 61 | -------------------------------------------------------------------------------- /test/auto/test_pl_plugin.py: -------------------------------------------------------------------------------- 1 | from test.util_platforms import linux_only_forall 2 | import os 3 | 
import sys
4 | import pytest
5 | import pickle
6 | import os.path as p
7 | import subprocess as sp
8 | 
9 | 
10 | linux_only_forall()
11 | 
12 | 
13 | class TestDDPPlugin:
14 |     def test_all(self, tmpdir):
15 |         test_save_path = str(p.join(tmpdir.make_numbered_dir(), "test.save"))
16 |         env = os.environ.copy()
17 |         env["TEST_SAVE_PATH"] = test_save_path
18 |         processes = [
19 |             sp.Popen(
20 |                 [
21 |                     sys.executable,
22 |                     p.join(p.dirname(p.abspath(__file__)), "_pl_plugin_runner.py"),
23 |                     "ddp",
24 |                     str(i),
25 |                 ],
26 |                 env=env,
27 |             )
28 |             for i in range(3)
29 |         ]
30 |         try:
31 |             for process in processes:
32 |                 process.wait(timeout=20)
33 |         except sp.TimeoutExpired:
34 |             pytest.fail("Timeout on waiting for the DDPPlugin script to end.")
35 | 
36 |         with open(test_save_path, "rb") as f:
37 |             flags = pickle.load(f)
38 |         assert flags == [True], f"Not properly_inited, flags are: {flags}"
39 | 
40 | 
41 | class TestDDPSpawnPlugin:
42 |     def test_all(self, tmpdir):
43 |         test_save_path = str(p.join(tmpdir.make_numbered_dir(), "test.save"))
44 |         env = os.environ.copy()
45 |         env["TEST_SAVE_PATH"] = test_save_path
46 |         processes = [
47 |             sp.Popen(
48 |                 [
49 |                     sys.executable,
50 |                     p.join(p.dirname(p.abspath(__file__)), "_pl_plugin_runner.py"),
51 |                     "ddp_spawn",
52 |                     str(i),
53 |                 ],
54 |                 env=env,
55 |             )
56 |             for i in range(3)
57 |         ]
58 |         try:
59 |             for process in processes:
60 |                 process.wait(timeout=20)
61 |         except sp.TimeoutExpired:
62 |             pytest.fail("Timeout on waiting for the DDPSpawnPlugin script to end.")
63 | 
64 |         with open(test_save_path, "rb") as f:
65 |             flags = pickle.load(f)
66 |         assert flags == [True], f"Not properly_inited, flags are: {flags}"
67 | 
-------------------------------------------------------------------------------- /examples/tutorials/parallel_distributed/dist_rpc.py: --------------------------------------------------------------------------------
1 | from machin.parallel.distributed import World
2 | from torch.multiprocessing import spawn
3 | from time import sleep
4 | 
5 | 
6 | # an example service class
7 | class WorkerService:
8 |     counter = 0
9 | 
10 |     def count(self):
11 |         self.counter += 1
12 |         return self.counter
13 | 
14 |     def get_count(self):
15 |         return self.counter
16 | 
17 | 
18 | def main(rank):
19 |     world = World(world_size=3, rank=rank, name=str(rank), rpc_timeout=20)
20 |     service = WorkerService()
21 |     if rank == 0:
22 |         # only group members need to enter this function
23 |         group = world.create_rpc_group("group", ["0", "1"])
24 | 
25 |         service.counter = 20
26 | 
27 |         # register two services and share a value by pairing
28 |         group.pair("value", service.counter)
29 |         group.register("count", service.count)
30 |         group.register("get_count", service.get_count)
31 | 
32 |         # cannot register an already used key
33 |         # KeyError will be raised
34 |         # group.register("count", service.count)
35 | 
36 |         # wait for process 1 to finish
37 |         sleep(4)
38 |         assert service.get_count() == 23
39 | 
40 |         # deregister service and unpair value
41 |         group.unpair("value")
42 |         group.deregister("count")
43 |         group.deregister("get_count")
44 |         sleep(4)
45 |         group.destroy()
46 |     elif rank == 1:
47 |         group = world.create_rpc_group("group", ["0", "1"])
48 |         sleep(0.5)
49 |         assert group.is_registered("count") and group.is_registered("get_count")
50 |         print("Process 1: service 'count' and 'get_count' correctly " "registered.")
51 | 
52 |         assert group.registered_sync("count") == 21
53 |         assert group.registered_async("count").wait() == 22
54 |         assert group.registered_remote("count").to_here() == 23
55 |         print("Process 1:
service 'count' and 'get_count' correctly " "called") 56 | sleep(4) 57 | assert not group.is_registered("count") and not group.is_registered("get_count") 58 | print("Process 1: service 'count' and 'get_count' correctly " "unregistered.") 59 | group.destroy() 60 | 61 | 62 | if __name__ == "__main__": 63 | # spawn 3 sub processes 64 | spawn(main, nprocs=3) 65 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import re 3 | import setuptools 4 | 5 | root = path.abspath(path.dirname(__file__)) 6 | 7 | with open(path.join(root, "machin", "__init__.py")) as f: 8 | version = re.search(r"__version__ = \"(.*?)\"", f.read()).group(1) 9 | 10 | with open("README.md", mode="r", encoding="utf8") as desc: 11 | long_description = desc.read() 12 | 13 | setuptools.setup( 14 | name="machin", 15 | version=version, 16 | description="Reinforcement learning library", 17 | long_description=long_description, 18 | long_description_content_type="text/markdown", 19 | url="https://github.com/iffiX/machin", 20 | author="Iffi", 21 | author_email="iffi@mail.beyond-infinity.com", 22 | license="MIT", 23 | python_requires=">=3.5", 24 | packages=setuptools.find_packages( 25 | exclude=[ 26 | "test", 27 | "test.*", 28 | "test_lib", 29 | "test_lib.*", 30 | "examples", 31 | "examples.*", 32 | "docs", 33 | "docs.*", 34 | ] 35 | ), 36 | classifiers=[ 37 | # How mature is this project? Common values are 38 | # 3 - Alpha 39 | # 4 - Beta 40 | # 5 - Production/Stable 41 | "Development Status :: 4 - Beta", 42 | # Indicate who your project is intended for 43 | "Intended Audience :: Developers", 44 | "Intended Audience :: Science/Research", 45 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 46 | "Topic :: Software Development :: Libraries :: Python Modules", 47 | # Pick your license as you wish (should match "license" above) 48 | "License :: OSI Approved :: MIT License", 49 | # Specify the Python versions you support here. In particular, ensure 50 | # that you indicate whether you support Python 2, Python 3 or both. 51 | "Programming Language :: Python :: 3.6", 52 | "Programming Language :: Python :: 3.7", 53 | "Programming Language :: Python :: 3.8", 54 | ], 55 | install_requires=[ 56 | "gym", 57 | "psutil", 58 | "numpy", 59 | "torch>=1.6.0", 60 | "pytorch-lightning>=1.2.0", 61 | "torchviz", 62 | "moviepy", 63 | "matplotlib", 64 | "colorlog", 65 | "dill", 66 | "GPUtil", 67 | "Pillow", 68 | "tensorboardX", 69 | ], 70 | ) 71 | -------------------------------------------------------------------------------- /test/data/all.py: -------------------------------------------------------------------------------- 1 | from . import generators, ROOT 2 | from .archive import Archive 3 | import os 4 | import re 5 | 6 | 7 | def first(iterable, condition=lambda x: True): 8 | """ 9 | Returns the first item in the `iterable` that 10 | satisfies the `condition`. 11 | 12 | If the condition is not given, returns the first item of 13 | the iterable. 14 | 15 | Raises `StopIteration` if no item satisfying the condition is found. 
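
    Example (a short illustration of both call forms, added for clarity)::

        first([1, 2, 3], lambda x: x > 1)  # returns 2
        first([1, 2, 3])                   # returns 1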
16 | """ 17 | return next(x for x in iterable if condition(x)) 18 | 19 | 20 | def generate_all(): 21 | print("Generating all needed data...") 22 | os.makedirs(os.path.join(ROOT, "generated"), exist_ok=True) 23 | for gen in dir(generators): 24 | method = getattr(getattr(generators, gen), "generate", None) 25 | name = getattr(getattr(generators, gen), "generated_name", None) 26 | if method is not None and name is not None: 27 | match = [ 28 | f if re.match(name, f) is not None else None 29 | for f in os.listdir(os.path.join(ROOT, "generated")) 30 | ] 31 | try: 32 | file = first(match, lambda m: m is not None) 33 | print(f"Skipping {gen} because file {file} already exists.") 34 | except StopIteration: 35 | print(f"Generating {gen}...") 36 | method() 37 | else: 38 | print( 39 | f"Skipping {gen} because its method({method}) " 40 | f"or generated name({name}) is None" 41 | ) 42 | 43 | 44 | def get_all(): 45 | archives = {} 46 | os.makedirs(os.path.join(ROOT, "generated"), exist_ok=True) 47 | for gen in dir(generators): 48 | method = getattr(getattr(generators, gen), "generate", None) 49 | name = getattr(getattr(generators, gen), "generated_name", None) 50 | if method is not None and name is not None: 51 | match = [ 52 | f if re.match(name, f) is not None else None 53 | for f in os.listdir(os.path.join(ROOT, "generated")) 54 | ] 55 | try: 56 | file = first(match, lambda m: m is not None) 57 | archives[name] = Archive(path=os.path.join(ROOT, "generated", file)) 58 | except StopIteration: 59 | raise ValueError( 60 | f"Missing generated file for {gen}, please re-run generate_all" 61 | ) 62 | return archives 63 | -------------------------------------------------------------------------------- /examples/tutorials/parallel_distributed/dist_oserver.py: -------------------------------------------------------------------------------- 1 | from machin.parallel.distributed import World 2 | from machin.parallel.server import OrderedServerSimpleImpl 3 | from torch.multiprocessing import spawn 4 | from time import sleep 5 | 6 | 7 | def main(rank): 8 | world = World(world_size=3, rank=rank, name=str(rank), rpc_timeout=20) 9 | # Usually, distributed services in Machin are seperated 10 | # into: 11 | # An accessor: OrderedServerSimple 12 | # An implementation: OrderedServerSimpleImpl 13 | # 14 | # Except DistributedBuffer and DistributedPrioritizedBuffer 15 | # 16 | # Accessor is a handle, which records the name of internal 17 | # service handles. Usually paired as an accessible resource 18 | # to the group, so any group members can get this accessor 19 | # and use the internal resources & services. 20 | # 21 | # Implementation is the thing that actually starts on the 22 | # provider process, some implementations may contain backend 23 | # threads, etc. 
24 | 25 | group = world.create_rpc_group("group", ["0", "1", "2"]) 26 | if rank == 0: 27 | _server = OrderedServerSimpleImpl("server", group) 28 | sleep(5) 29 | else: 30 | sleep(2) 31 | server = group.get_paired("server").to_here() 32 | 33 | # change key "a", set new value to "value" 34 | # change version from `None` to `1` 35 | if server.push("a", "value", 1, None): 36 | print(rank, "push 1 success") 37 | else: 38 | print(rank, "push 1 failed") 39 | 40 | # change key "a", set new value to "value2" 41 | # change version from `1` to `2` 42 | if server.push("a", "value2", 2, 1): 43 | print(rank, "push 2 success") 44 | else: 45 | print(rank, "push 2 failed") 46 | 47 | # change key "a", set new value to "value3" 48 | # change version from `2` to `3` 49 | if server.push("a", "value3", 3, 2): 50 | print(rank, "push 3 success") 51 | else: 52 | print(rank, "push 3 failed") 53 | 54 | assert server.pull("a", None) == ("value3", 3) 55 | assert server.pull("a", 2) == ("value2", 2) 56 | assert server.pull("a", 1) is None 57 | assert server.pull("b", None) is None 58 | print("Ordered server check passed") 59 | group.destroy() 60 | 61 | 62 | if __name__ == "__main__": 63 | # spawn 3 sub processes 64 | spawn(main, nprocs=3) 65 | -------------------------------------------------------------------------------- /test/utils/test_media.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from machin.utils.media import ( 3 | show_image, 4 | create_video, 5 | create_video_subproc, 6 | create_image, 7 | create_image_subproc, 8 | ) 9 | 10 | import os 11 | import pytest 12 | import numpy as np 13 | 14 | 15 | @pytest.fixture(scope="function") 16 | def images(): 17 | images = [ 18 | np.random.randint(0, 255, size=[128, 128], dtype=np.uint8) for _ in range(120) 19 | ] 20 | return images 21 | 22 | 23 | @pytest.fixture(scope="function") 24 | def images_f(): 25 | images = [np.random.rand(128, 128) for _ in range(120)] 26 | return images 27 | 28 | 29 | def test_show_image(images): 30 | show_image(images[0], show_normalized=True) 31 | show_image(images[0], show_normalized=False) 32 | 33 | 34 | def test_create_video(images, tmpdir): 35 | tmp_dir = str(tmpdir.make_numbered_dir()) 36 | create_video(images, tmp_dir, "vid", extension=".gif") 37 | assert os.path.exists(join(tmp_dir, "vid.gif")) 38 | create_video(images, tmp_dir, "vid", extension=".mp4") 39 | assert os.path.exists(join(tmp_dir, "vid.mp4")) 40 | 41 | 42 | def test_create_video_float(images_f, tmpdir): 43 | tmp_dir = str(tmpdir.make_numbered_dir()) 44 | create_video(images_f, tmp_dir, "vid", extension=".gif") 45 | assert os.path.exists(join(tmp_dir, "vid.gif")) 46 | create_video(images_f, tmp_dir, "vid", extension=".mp4") 47 | assert os.path.exists(join(tmp_dir, "vid.mp4")) 48 | 49 | 50 | def test_create_video_subproc(images, tmpdir): 51 | tmp_dir = str(tmpdir.make_numbered_dir()) 52 | create_video_subproc([], tmp_dir, "empty", extension=".gif")() 53 | create_video_subproc(images, tmp_dir, "vid", extension=".gif")() 54 | assert os.path.exists(join(tmp_dir, "vid.gif")) 55 | 56 | 57 | def test_create_image(images, tmpdir): 58 | tmp_dir = str(tmpdir.make_numbered_dir()) 59 | create_image(images[0], tmp_dir, "img", extension=".png") 60 | assert os.path.exists(join(tmp_dir, "img.png")) 61 | 62 | 63 | def test_create_image_float(images_f, tmpdir): 64 | tmp_dir = str(tmpdir.make_numbered_dir()) 65 | create_image(images_f[0], tmp_dir, "img", extension=".png") 66 | assert os.path.exists(join(tmp_dir, 
"img.png")) 67 | 68 | 69 | def test_create_image_subproc(images, tmpdir): 70 | tmp_dir = str(tmpdir.make_numbered_dir()) 71 | create_image_subproc(images[0], tmp_dir, "img", extension=".png")() 72 | assert os.path.exists(join(tmp_dir, "img.png")) 73 | -------------------------------------------------------------------------------- /machin/parallel/process.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pipe 2 | from multiprocessing.process import BaseProcess 3 | from multiprocessing.context import _default_context 4 | import traceback 5 | 6 | 7 | class ProcessException(Exception): 8 | pass 9 | 10 | 11 | class Process(BaseProcess): 12 | def __init__( 13 | self, 14 | group=None, 15 | target=None, 16 | name=None, 17 | args=(), 18 | kwargs={}, 19 | cleaner=None, 20 | ctx=_default_context, 21 | *, 22 | daemon=None, 23 | ): 24 | self._exc_pipe = Pipe() 25 | super().__init__( 26 | group=group, 27 | target=target, 28 | name=name, 29 | args=args, 30 | kwargs=kwargs, 31 | daemon=daemon, 32 | ) 33 | self._cleaner = cleaner 34 | self._ctx = ctx 35 | self._start_method = ctx.Process._start_method 36 | self._exception = None 37 | 38 | @staticmethod 39 | def _Popen(process_obj): 40 | assert isinstance(process_obj, Process) 41 | return process_obj._ctx.Process._Popen(process_obj) 42 | 43 | @property 44 | def exception(self): 45 | if self._exception is not None: 46 | return self._exception 47 | elif not self._exc_pipe[0].poll(timeout=1e-4): 48 | return None 49 | self._exception = ProcessException(self._exc_pipe[0].recv()) 50 | return self._exception 51 | 52 | def watch(self): 53 | if self._exc_pipe[0].poll(timeout=1e-4): 54 | self.join() 55 | raise self.exception 56 | 57 | @staticmethod 58 | def format_exceptions(exceptions): 59 | all_tb = "" 60 | for exc, i in zip(exceptions, range(len(exceptions))): 61 | tb = exc.__traceback__ 62 | tb = traceback.format_exception(type(exc), exc, tb) 63 | tb = "".join(tb) 64 | all_tb = all_tb + f"\nException {i}:\n{tb}" 65 | return all_tb 66 | 67 | def run(self): 68 | exc = [] 69 | try: 70 | super().run() 71 | except BaseException as e: 72 | exc.append(e) 73 | finally: 74 | if self._cleaner is not None: 75 | try: 76 | self._cleaner() 77 | except BaseException as e: 78 | exc.append(e) 79 | if exc: 80 | self._exc_pipe[1].send(self.format_exceptions(exc)) 81 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent/multi_discrete.py: -------------------------------------------------------------------------------- 1 | # An old version of OpenAI Gym's multi_discrete.py. (Was getting affected by Gym updates) 2 | # (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py) 3 | 4 | import numpy as np 5 | 6 | import gym 7 | import numpy as np 8 | 9 | class MultiDiscrete(gym.Space): 10 | """ 11 | - The multi-discrete action space consists of a series of discrete action spaces with different parameters 12 | - It can be adapted to both a Discrete action space or a continuous (Box) action space 13 | - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space 14 | - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space 15 | where the discrete action space can take any integers from `min` to `max` (both inclusive) 16 | Note: A value of 0 always need to represent the NOOP action. 17 | e.g. 
Nintendo Game Controller 18 | - Can be conceptualized as 3 discrete action spaces: 19 | 1) Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] - params: min: 0, max: 4 20 | 2) Button A: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 21 | 3) Button B: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 22 | - Can be initialized as 23 | MultiDiscrete([ [0,4], [0,1], [0,1] ]) 24 | """ 25 | def __init__(self, array_of_param_array): 26 | self.low = np.array([x[0] for x in array_of_param_array]) 27 | self.high = np.array([x[1] for x in array_of_param_array]) 28 | self.num_discrete_space = self.low.shape[0] 29 | 30 | def sample(self): 31 | """ Returns a array with one sample from each discrete action space """ 32 | # For each row: round(random .* (max - min) + min, 0) 33 | np_random = np.random.RandomState() 34 | random_array = np_random.rand(self.num_discrete_space) 35 | return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.), random_array) + self.low)] 36 | def contains(self, x): 37 | return len(x) == self.num_discrete_space and (np.array(x) >= self.low).all() and (np.array(x) <= self.high).all() 38 | 39 | @property 40 | def shape(self): 41 | return self.num_discrete_space 42 | def __repr__(self): 43 | return "MultiDiscrete" + str(self.num_discrete_space) 44 | def __eq__(self, other): 45 | return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high) -------------------------------------------------------------------------------- /machin/parallel/event.py: -------------------------------------------------------------------------------- 1 | from threading import Event, Lock 2 | import time 3 | 4 | _lock = Lock() 5 | 6 | 7 | def init_or_register(event, parent_event): 8 | if hasattr(event, "_magic_parent_events"): 9 | event._magic_parent_events.append(parent_event) 10 | else: 11 | org_set = event.set 12 | event._magic_parent_events = [parent_event] 13 | 14 | def new_set(): 15 | nonlocal event, org_set 16 | for pe in event._magic_parent_events: 17 | with pe._cond: 18 | pe._cond.notify_all() 19 | org_set() 20 | 21 | event.set = new_set 22 | 23 | 24 | class MultiEvent(Event): 25 | # can only be used with Event from threading and not from multiprocessing 26 | def __init__(self, *events): 27 | global _lock 28 | super().__init__() 29 | 30 | with _lock: 31 | self.is_leaf = all([type(e) == Event for e in events]) 32 | for event in events: 33 | if type(event) == Event: 34 | init_or_register(event, self) 35 | elif isinstance(event, MultiEvent): 36 | for le in event.get_leaf_events(): 37 | init_or_register(le, self) 38 | else: 39 | raise ValueError( 40 | f"Type {type(event)} is not a valid event type, " 41 | "requires threading.Event or an instance of MultiEvent." 
42 | ) 43 | self._events = events 44 | 45 | def set(self): 46 | pass 47 | 48 | def clear(self): 49 | pass 50 | 51 | def is_set(self): 52 | return False 53 | 54 | def get_leaf_events(self): 55 | for event in self._events: 56 | if type(event) == Event: 57 | yield event 58 | else: 59 | for event in event.get_leaf_events(): 60 | yield event 61 | 62 | def wait(self, timeout=None): 63 | begin = time.time() 64 | while True: 65 | with self._cond: 66 | self._cond.wait(timeout) 67 | if timeout is None or ( 68 | timeout is not None and time.time() - begin >= timeout 69 | ): 70 | break 71 | return self.is_set() 72 | 73 | 74 | class OrEvent(MultiEvent): 75 | def is_set(self): 76 | return any([event.is_set() for event in self._events]) 77 | 78 | 79 | class AndEvent(MultiEvent): 80 | def is_set(self): 81 | return all([event.is_set() for event in self._events]) 82 | -------------------------------------------------------------------------------- /test/parallel/server/test_ordered_server.py: -------------------------------------------------------------------------------- 1 | from machin.parallel.server import OrderedServerSimpleImpl 2 | from test.util_run_multi import * 3 | from test.util_platforms import linux_only_forall 4 | 5 | linux_only_forall() 6 | 7 | 8 | def _log(rank, msg): 9 | default_logger.info(f"Client {rank}: {msg}") 10 | 11 | 12 | class Object: 13 | pass 14 | 15 | 16 | class TestOrderedServerSimple: 17 | def test__push_pull_service(self): 18 | fake_group = Object() 19 | fake_group.pair = lambda *_: None 20 | fake_group.register = lambda *_: None 21 | fake_group.destroy = lambda: None 22 | fake_group.is_member = lambda *_: True 23 | server = OrderedServerSimpleImpl("fake_server", fake_group) 24 | assert server._push_service("a", "value", 1, None) 25 | assert not server._push_service("a", "value1", 2, 0) 26 | assert server._push_service("a", "value2", 2, 1) 27 | assert server._push_service("a", "value3", -1, 2) 28 | 29 | assert server._pull_service("a", None) == ("value3", -1) 30 | assert server._pull_service("a", 2) == ("value2", 2) 31 | assert server._pull_service("a", 1) is None 32 | assert server._pull_service("b", None) is None 33 | 34 | @staticmethod 35 | @run_multi(expected_results=[True, True, True]) 36 | @setup_world 37 | def test_push_pull(rank): 38 | world = get_world() 39 | if rank == 0: 40 | group = world.create_rpc_group("group", ["0", "1"]) 41 | _server = OrderedServerSimpleImpl("server", group) 42 | group.barrier() 43 | group.barrier() 44 | elif rank == 1: 45 | group = world.create_rpc_group("group", ["0", "1"]) 46 | group.barrier() 47 | server = group.get_paired("server").to_here() 48 | 49 | if server.push("a", "value", 1, None): 50 | _log(rank, "push 1 success") 51 | else: 52 | _log(rank, "push 1 failed") 53 | if server.push("a", "value2", 2, 1): 54 | _log(rank, "push 2 success") 55 | else: 56 | _log(rank, "push 2 failed") 57 | if server.push("a", "value3", 3, 2): 58 | _log(rank, "push 3 success") 59 | else: 60 | _log(rank, "push 3 failed") 61 | 62 | assert server.pull("a", None) == ("value3", 3) 63 | assert server.pull("a", 2) == ("value2", 2) 64 | assert server.pull("a", 1) is None 65 | assert server.pull("b", None) is None 66 | group.barrier() 67 | return True 68 | -------------------------------------------------------------------------------- /test/utils/test_helper_classes.py: -------------------------------------------------------------------------------- 1 | from machin.utils.helper_classes import Counter, Trigger, Timer, Switch, Object 2 | import pytest 3 | 4 | 5 | 
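# Semantics exercised by the tests below (inferred from the assertions,
# not a formal specification): Counter supports rich comparisons against
# plain integers, Switch holds its state until toggled, and Trigger
# resets itself to False after a single get().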
class TestCounter: 6 | def test_counter(self): 7 | c = Counter(start=0, step=1) 8 | c.count() 9 | assert c.get() == 1 10 | c.reset() 11 | assert c.get() == 0 12 | assert c < 1 13 | assert c <= 1 14 | assert c == 0 15 | assert c > -1 16 | assert c >= -1 17 | str(c) 18 | 19 | 20 | class TestSwitch: 21 | def test_switch(self): 22 | s = Switch() 23 | s.on() 24 | assert s.get() 25 | s.off() 26 | assert not s.get() 27 | s.flip() 28 | assert s.get() 29 | 30 | 31 | class TestTrigger: 32 | def test_trigger(self): 33 | t = Trigger() 34 | t.on() 35 | assert t.get() 36 | assert not t.get() 37 | 38 | 39 | class TestTimer: 40 | def test_timer(self): 41 | t = Timer() 42 | t.begin() 43 | t.end() 44 | 45 | 46 | class TestObject: 47 | def test_init(self): 48 | obj = Object() 49 | assert obj.data == {} 50 | obj = Object({"a": 1}) 51 | assert obj.data == {"a": 1} 52 | 53 | def test_call(self): 54 | obj = Object() 55 | obj("original_call") 56 | obj.call = lambda _: "pong" 57 | assert obj("ping") == "pong" 58 | 59 | def test_get_attr(self): 60 | obj = Object({"a": 1}) 61 | with pytest.raises(AttributeError, match="Failed to find"): 62 | _ = obj.__some_invalid_special_attr__ 63 | assert obj.a == 1 64 | 65 | def test_get_item(self): 66 | obj = Object({"a": 1}) 67 | assert obj["a"] == 1 68 | 69 | def test_set_attr(self): 70 | # set data keys 71 | obj = Object({"a": 1, "const": 0}, const_attrs={"const"}) 72 | obj.a = 1 73 | assert obj.a == 1 74 | obj.b = 1 75 | assert obj.b == 1 76 | 77 | # set const keys 78 | with pytest.raises(RuntimeError, match="is const"): 79 | obj.const = 1 80 | 81 | # set .call attribute 82 | obj.call = lambda _: "pong" 83 | assert obj("ping") == "pong" 84 | 85 | # set .data attribute 86 | obj.data = {} 87 | assert obj.a is None 88 | with pytest.raises(ValueError, match="must be a dictionary"): 89 | obj.data = None 90 | 91 | # set other attributes 92 | with pytest.raises(RuntimeError, match="should not set"): 93 | obj.__dict__ = {} 94 | 95 | def test_set_item(self): 96 | obj = Object({"a": 1}) 97 | obj["a"] = 2 98 | assert obj.a == 2 99 | -------------------------------------------------------------------------------- /test/utils/test_save_env.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from machin.utils.save_env import SaveEnv 3 | 4 | import os 5 | import time 6 | 7 | 8 | class TestSaveEnv: 9 | def test_all(self, tmpdir): 10 | tmp_dir = str(tmpdir.make_numbered_dir()) 11 | save_env = SaveEnv(env_root=tmp_dir) 12 | save_env = SaveEnv( 13 | env_root=tmp_dir, 14 | restart_from_trial=os.path.basename(save_env.get_trial_root()), 15 | ) 16 | 17 | # check directories 18 | t_root = save_env.get_trial_root() 19 | assert ( 20 | os.path.exists(join(t_root, "model")) 21 | and os.path.isdir(join(t_root, "model")) 22 | and os.path.exists(join(t_root, "config")) 23 | and os.path.isdir(join(t_root, "config")) 24 | and os.path.exists(join(t_root, "log", "images")) 25 | and os.path.isdir(join(t_root, "log", "images")) 26 | and os.path.exists(join(t_root, "log", "train_log")) 27 | and os.path.isdir(join(t_root, "log", "train_log")) 28 | ) 29 | 30 | save_env.create_dirs(["some_custom_dir"]) 31 | assert os.path.exists(join(t_root, "some_custom_dir")) and os.path.isdir( 32 | join(t_root, "some_custom_dir") 33 | ) 34 | 35 | save_env.get_trial_time() 36 | assert save_env.get_trial_config_dir() == join(t_root, "config") 37 | assert save_env.get_trial_model_dir() == join(t_root, "model") 38 | assert save_env.get_trial_image_dir() == 
join(t_root, "log", "images") 39 | assert save_env.get_trial_train_log_dir() == join(t_root, "log", "train_log") 40 | 41 | with open(join(t_root, "config", "conf.json"), "w") as _: 42 | pass 43 | with open(join(t_root, "model", "model.pt"), "w") as _: 44 | pass 45 | with open(join(t_root, "log", "images", "image.png"), "w") as _: 46 | pass 47 | with open(join(t_root, "log", "train_log", "log.txt"), "w") as _: 48 | pass 49 | save_env.clear_trial_config_dir() 50 | assert not os.path.exists(join(t_root, "config", "conf.json")) 51 | save_env.clear_trial_model_dir() 52 | assert not os.path.exists(join(t_root, "model", "model.pt")) 53 | save_env.clear_trial_image_dir() 54 | assert not os.path.exists(join(t_root, "log", "images", "image.png")) 55 | save_env.clear_trial_train_log_dir() 56 | assert not os.path.exists(join(t_root, "log", "train_log", "log.txt")) 57 | 58 | time.sleep(2) 59 | os.mkdir(join(tmp_dir, "some_dir_not_trial")) 60 | save_env2 = SaveEnv(env_root=tmp_dir) 61 | save_env2.remove_trials_older_than(0, 0, 0, 1) 62 | assert not os.path.exists(t_root) 63 | -------------------------------------------------------------------------------- /docs/misc/contribute.md: -------------------------------------------------------------------------------- 1 | ### Contribute to Machin 2 | 3 | #### Prepare your editing environment 4 | --- 5 | The `virtualenv` package facilitates local editing; you should install it first: 6 | ``` 7 | python3 -m pip install virtualenv 8 | ``` 9 | Then you should clone the repository and create a new virtual environment named 10 | `venv` in the root directory using: 11 | ``` 12 | git clone https://github.com/iffiX/machin.git 13 | cd machin 14 | virtualenv --no-site-packages venv 15 | ``` 16 | Finally you can switch to this new virtual environment and install the Machin 17 | library in local edit mode. This way, all edits on Machin files will be 18 | effective immediately. 19 | ``` 20 | source venv/bin/activate 21 | pip3 install -e . 22 | ``` 23 | 24 | #### Polish your code 25 | All code in Machin must be **readable**, therefore we require you to write your 26 | code in the [google python style](http://google.github.io/styleguide/pyguide.html). 27 | 28 | You also must document your code: we require you to give detailed signatures, 29 | using the builtin `typing` library, for all of your arguments as well as 30 | keyword arguments. A great example is the 31 | [google style docstring](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) 32 | from napoleon, an extension of sphinx, the doc builder.
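33 | 34 | For instance, a signature annotated with `typing` hints and documented with a 35 | google style docstring might look like this (an illustrative sketch, not code 36 | taken from the library): 37 | ``` 38 | from typing import List 39 | 40 | def smooth_rewards(rewards: List[float], ratio: float = 0.9) -> float: 41 | """Exponentially smooth a sequence of episode rewards. 42 | 43 | Args: 44 | rewards: Total reward of each episode, oldest first. 45 | ratio: Smoothing factor, must be in ``[0, 1)``. 46 | 47 | Returns: 48 | The final smoothed reward. 49 | """ 50 | smoothed = 0.0 51 | for reward in rewards: 52 | smoothed = smoothed * ratio + reward * (1 - ratio) 53 | return smoothed 54 | ```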
55 | 56 | Finally, after so much hard work, do not forget to use pylint to check your 57 | code; it must conform to the `PEP8` style. 58 | 59 | #### Test your code 60 | The great test at last! You can run the following command to run all existing 61 | tests that come along with the Machin library. This test command also includes 62 | training networks. 63 | ``` 64 | pytest test --cov machin --capture=no --durations=0 -v 65 | ``` 66 | If you just want to test (train) a specific algorithm, use the `-k` option; 67 | the pattern is `<algorithm name> and full_train`: 68 | ``` 69 | pytest test -k "DDPG and full_train" 70 | ``` 71 | Or exclude all training by using: 72 | ``` 73 | pytest test -k "not full_train" 74 | ``` 75 | 76 | You should group your tests in a single class, like the example given by pytest, 77 | or insert your test cases into an existing test class: 78 | ``` 79 | # content of test_class.py 80 | class TestClass: 81 | def test_one(self): 82 | x = "this" 83 | assert "h" in x 84 | 85 | def test_two(self): 86 | x = "hello" 87 | assert hasattr(x, "check") 88 | ``` 89 | #### Submit a pull request 90 | Submit a pull request in the end! We truly appreciate your help. 91 | Travis will automatically test your code, and we will review it as soon as possible. 92 | 93 | #### Build the documents 94 | In order to build the documents, you must install `requirements.txt` in `/docs` 95 | in your venv, and execute: 96 | ``` 97 | cd docs 98 | make html 99 | ``` 100 | This command will build your documents in `docs/build`. -------------------------------------------------------------------------------- /examples/framework_examples/dqn.py: -------------------------------------------------------------------------------- 1 | from machin.frame.algorithms import DQN 2 | from machin.utils.logging import default_logger as logger 3 | import torch as t 4 | import torch.nn as nn 5 | import gym 6 | 7 | # configurations 8 | env = gym.make("CartPole-v0") 9 | observe_dim = 4 10 | action_num = 2 11 | max_episodes = 1000 12 | max_steps = 200 13 | solved_reward = 190 14 | solved_repeat = 5 15 | 16 | 17 | # model definition 18 | class QNet(nn.Module): 19 | def __init__(self, state_dim, action_num): 20 | super().__init__() 21 | 22 | self.fc1 = nn.Linear(state_dim, 16) 23 | self.fc2 = nn.Linear(16, 16) 24 | self.fc3 = nn.Linear(16, action_num) 25 | 26 | def forward(self, state): 27 | a = t.relu(self.fc1(state)) 28 | a = t.relu(self.fc2(a)) 29 | return self.fc3(a) 30 | 31 | 32 | if __name__ == "__main__": 33 | q_net = QNet(observe_dim, action_num) 34 | q_net_t = QNet(observe_dim, action_num) 35 | 36 | dqn = DQN(q_net, q_net_t, t.optim.Adam, nn.MSELoss(reduction="sum")) 37 | 38 | episode, step, reward_fulfilled = 0, 0, 0 39 | smoothed_total_reward = 0 40 | 41 | while episode < max_episodes: 42 | episode += 1 43 | total_reward = 0 44 | terminal = False 45 | step = 0 46 | state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim) 47 | tmp_observations = [] 48 | 49 | while not terminal and step <= max_steps: 50 | step += 1 51 | with t.no_grad(): 52 | old_state = state 53 | # agent model inference 54 | action = dqn.act_discrete_with_noise({"state": old_state}) 55 | state, reward, terminal, _ = env.step(action.item()) 56 | state = t.tensor(state, dtype=t.float32).view(1, observe_dim) 57 | total_reward += reward 58 | 59 | tmp_observations.append( 60 | { 61 | "state": {"state": old_state}, 62 | "action": {"action": action}, 63 | "next_state": {"state": state}, 64 | "reward": reward, 65 | "terminal": terminal or step == max_steps, 66 | } 67 | ) 68 | 69 | dqn.store_episode(tmp_observations) 70 | # update, update more if episode is longer, else less 71 | if episode > 100: 72 | for _ in range(step): 73 | dqn.update() 74 | 75 | # show reward 76 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 77 | logger.info(f"Episode {episode} total 
reward={smoothed_total_reward:.2f}") 78 | 79 | if smoothed_total_reward > solved_reward: 80 | reward_fulfilled += 1 81 | if reward_fulfilled >= solved_repeat: 82 | logger.info("Environment solved!") 83 | exit(0) 84 | else: 85 | reward_fulfilled = 0 86 | -------------------------------------------------------------------------------- /examples/framework_examples/dqn_per.py: -------------------------------------------------------------------------------- 1 | from machin.frame.algorithms import DQNPer 2 | from machin.utils.logging import default_logger as logger 3 | import torch as t 4 | import torch.nn as nn 5 | import gym 6 | 7 | # configurations 8 | env = gym.make("CartPole-v0") 9 | observe_dim = 4 10 | action_num = 2 11 | max_episodes = 1000 12 | max_steps = 200 13 | solved_reward = 190 14 | solved_repeat = 5 15 | 16 | 17 | # model definition 18 | class QNet(nn.Module): 19 | def __init__(self, state_dim, action_num): 20 | super().__init__() 21 | 22 | self.fc1 = nn.Linear(state_dim, 16) 23 | self.fc2 = nn.Linear(16, 16) 24 | self.fc3 = nn.Linear(16, action_num) 25 | 26 | def forward(self, state): 27 | a = t.relu(self.fc1(state)) 28 | a = t.relu(self.fc2(a)) 29 | return self.fc3(a) 30 | 31 | 32 | if __name__ == "__main__": 33 | q_net = QNet(observe_dim, action_num) 34 | q_net_t = QNet(observe_dim, action_num) 35 | 36 | dqn_per = DQNPer(q_net, q_net_t, t.optim.Adam, nn.MSELoss(reduction="sum")) 37 | 38 | episode, step, reward_fulfilled = 0, 0, 0 39 | smoothed_total_reward = 0 40 | 41 | while episode < max_episodes: 42 | episode += 1 43 | total_reward = 0 44 | terminal = False 45 | step = 0 46 | state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim) 47 | tmp_observations = [] 48 | 49 | while not terminal and step <= max_steps: 50 | step += 1 51 | with t.no_grad(): 52 | old_state = state 53 | # agent model inference 54 | action = dqn_per.act_discrete_with_noise({"state": old_state}) 55 | state, reward, terminal, _ = env.step(action.item()) 56 | state = t.tensor(state, dtype=t.float32).view(1, observe_dim) 57 | total_reward += reward 58 | 59 | tmp_observations.append( 60 | { 61 | "state": {"state": old_state}, 62 | "action": {"action": action}, 63 | "next_state": {"state": state}, 64 | "reward": reward, 65 | "terminal": terminal or step == max_steps, 66 | } 67 | ) 68 | 69 | dqn_per.store_episode(tmp_observations) 70 | # update, update more if episode is longer, else less 71 | if episode > 100: 72 | for _ in range(step): 73 | dqn_per.update() 74 | 75 | # show reward 76 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 77 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 78 | 79 | if smoothed_total_reward > solved_reward: 80 | reward_fulfilled += 1 81 | if reward_fulfilled >= solved_repeat: 82 | logger.info("Environment solved!") 83 | exit(0) 84 | else: 85 | reward_fulfilled = 0 86 | -------------------------------------------------------------------------------- /examples/tutorials/recurrent_networks/dqn.py: -------------------------------------------------------------------------------- 1 | from machin.env.utils.openai_gym import disable_view_window 2 | from machin.frame.algorithms import DQNPer 3 | from machin.utils.logging import default_logger as logger 4 | 5 | import gym 6 | import torch as t 7 | import torch.nn as nn 8 | 9 | from util import convert 10 | from history import History 11 | 12 | # configurations 13 | env = gym.make("Frostbite-ram-v0") 14 | action_num = env.action_space.n 15 | max_episodes = 20000 16 | 
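# history_depth sets how many of the most recent RAM observations are
# stacked as the Q network input (QNet below consumes 128 * history_depth
# features), so the plain MLP can infer short-term motion; History is
# assumed to behave as a fixed-length FIFO of the latest frames.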
history_depth = 4 17 | 18 | # disable view window in rendering 19 | disable_view_window() 20 | 21 | 22 | # Q network model definition 23 | # for atari games 24 | class QNet(nn.Module): 25 | def __init__(self, history_depth, action_num): 26 | super().__init__() 27 | self.fc1 = nn.Linear(128 * history_depth, 256) 28 | self.fc2 = nn.Linear(256, 256) 29 | self.fc3 = nn.Linear(256, action_num) 30 | 31 | def forward(self, mem): 32 | return self.fc3(t.relu(self.fc2(t.relu(self.fc1(mem.flatten(start_dim=1)))))) 33 | 34 | 35 | if __name__ == "__main__": 36 | q_net = QNet(history_depth, action_num).to("cuda:0") 37 | q_net_t = QNet(history_depth, action_num).to("cuda:0") 38 | 39 | dqn = DQNPer( 40 | q_net, q_net_t, t.optim.Adam, nn.MSELoss(reduction="sum"), learning_rate=5e-4 41 | ) 42 | 43 | episode, step, reward_fulfilled = 0, 0, 0 44 | smoothed_total_reward = 0 45 | 46 | while episode < max_episodes: 47 | episode += 1 48 | total_reward = 0 49 | terminal = False 50 | step = 0 51 | state = convert(env.reset()) 52 | history = History(history_depth, (1, 128)) 53 | 54 | while not terminal: 55 | step += 1 56 | with t.no_grad(): 57 | history.append(state) 58 | # agent model inference 59 | action = dqn.act_discrete_with_noise({"mem": history.get()}) 60 | 61 | # info is {"ale.lives": self.ale.lives()}, not used here 62 | state, reward, terminal, _ = env.step(action.item()) 63 | state = convert(state) 64 | total_reward += reward 65 | old_history = history.get() 66 | new_history = history.append(state).get() 67 | dqn.store_transition( 68 | { 69 | "state": {"mem": old_history}, 70 | "action": {"action": action}, 71 | "next_state": {"mem": new_history}, 72 | "reward": reward, 73 | "terminal": terminal, 74 | } 75 | ) 76 | 77 | # update, update more if episode is longer, else less 78 | if episode > 20: 79 | for _ in range(step // 10): 80 | dqn.update() 81 | 82 | # show reward 83 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 84 | 85 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 86 | -------------------------------------------------------------------------------- /test/utils/test_prepare.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from machin.utils.prepare import ( 3 | prep_clear_dirs, 4 | prep_create_dirs, 5 | prep_load_state_dict, 6 | prep_load_model, 7 | ) 8 | 9 | import os 10 | import pytest 11 | import torch as t 12 | 13 | 14 | def create_file(file_path): 15 | base_dir = os.path.dirname(file_path) 16 | if not os.path.exists(base_dir): 17 | os.makedirs(base_dir) 18 | open(file_path, "a").close() 19 | 20 | 21 | def test_prep_clear_dirs(tmpdir): 22 | tmp_dir = str(tmpdir.make_numbered_dir()) 23 | create_file(join(tmp_dir, "some_dir", "some_file")) 24 | create_file(join(tmp_dir, "some_file2")) 25 | os.symlink(join(tmp_dir, "some_dir", "some_file"), join(tmp_dir, "some_file3")) 26 | prep_clear_dirs([tmp_dir]) 27 | assert not os.path.exists(join(tmp_dir, "some_dir", "some_file")) 28 | assert not os.path.exists(join(tmp_dir, "some_file2")) 29 | assert not os.path.exists(join(tmp_dir, "some_file3")) 30 | 31 | 32 | def test_prep_create_dirs(tmpdir): 33 | tmp_dir = str(tmpdir.make_numbered_dir()) 34 | prep_create_dirs([join(tmp_dir, "some_dir")]) 35 | assert os.path.exists(join(tmp_dir, "some_dir")) and os.path.isdir( 36 | join(tmp_dir, "some_dir") 37 | ) 38 | 39 | 40 | def test_prep_load_state_dict(pytestconfig): 41 | model = t.nn.Linear(100, 100) 42 | model2 = t.nn.Linear(100, 
100).to(pytestconfig.getoption("gpu_device")) 43 | state_dict = model2.state_dict() 44 | prep_load_state_dict(model, state_dict) 45 | assert t.all(model.weight == model2.weight.to("cpu")) 46 | assert t.all(model.bias == model2.bias.to("cpu")) 47 | 48 | 49 | def test_prep_load_model(tmpdir): 50 | tmp_dir = str(tmpdir.make_numbered_dir()) 51 | tmp_dir2 = str(tmpdir.make_numbered_dir()) 52 | 53 | # create example model directory 54 | with t.no_grad(): 55 | model = t.nn.Linear(100, 100, bias=False) 56 | model.weight.fill_(0) 57 | t.save(model, join(tmp_dir, "model_0.pt")) 58 | model.weight.fill_(1) 59 | t.save(model, join(tmp_dir, "model_100.pt")) 60 | 61 | with pytest.raises(RuntimeError, match="Model directory doesn't exist"): 62 | prep_load_model(join(tmp_dir, "not_exist_dir"), {"model": model}) 63 | 64 | # load a specific version 65 | prep_load_model(tmp_dir, {"model": model}, version=0) 66 | assert t.all(model.weight == 0) 67 | 68 | # load a non-exist version in a directory with valid models 69 | # will load version 100 70 | prep_load_model(tmp_dir, {"model": model}, version=50) 71 | assert t.all(model.weight == 1) 72 | 73 | # load the newest version 74 | prep_load_model(tmp_dir, {"model": model}) 75 | assert t.all(model.weight == 1) 76 | 77 | # load a non-exist version in a directory with invalid models 78 | # eg: cannot find the same version for all models in the model map 79 | with pytest.raises(RuntimeError, match="Cannot find a valid version"): 80 | prep_load_model(tmp_dir2, {"model": model}) 81 | prep_load_model(tmp_dir2, {"model": model}, quiet=True) 82 | -------------------------------------------------------------------------------- /machin/env/wrappers/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Union, List, Any 3 | 4 | 5 | class ParallelWrapperBase(ABC): 6 | def __init__(self, *_, **__): 7 | """ 8 | Note: 9 | Parallel wrappers are designed to wrap the same kind of environments; 10 | they may have different parameters, but must have the same action 11 | and observation spaces. 12 | """ 13 | pass 14 | 15 | @abstractmethod 16 | def reset(self, idx: Union[int, List[int], None] = None) -> Any: 17 | """ 18 | Reset all environments if ``idx`` is ``None``, otherwise reset the 19 | specific environment(s) with the given index(es). 20 | 21 | Args: 22 | idx: Environment index(es) to be reset. 23 | 24 | Returns: 25 | Initial observation of all environments. Format is unspecified. 26 | """ 27 | pass 28 | 29 | @abstractmethod 30 | def step(self, action, idx: Union[int, List[int], None] = None) -> Any: 31 | """ 32 | Let the specified environment(s) run one time step. The specified 33 | environments must be active and must not have reached terminal states. 34 | 35 | Args: 36 | action: Actions to take. 37 | idx: Environment index(es) to be run. 38 | 39 | Returns: 40 | New states of environments. 41 | """ 42 | pass 43 | 44 | @abstractmethod 45 | def seed(self, seed: Union[int, List[int], None] = None) -> List[int]: 46 | """ 47 | Set seed(s) for all environment(s). 48 | 49 | Args: 50 | seed: A single integer seed for all environments, 51 | or a list of integers for each environment, 52 | or None for the default seed. 53 | 54 | Returns: 55 | New seed of each environment. 56 | """ 57 | pass 58 | 59 | @abstractmethod 60 | def render(self, *args, **kwargs) -> Any: 61 | """ 62 | Render all environments. 63 | """ 64 | pass 65 | 66 | @abstractmethod 67 | def close(self) -> Any: 68 | """ 69 | Close all environments. 
70 | """ 71 | pass 72 | 73 | @abstractmethod 74 | def active(self) -> List[int]: 75 | """ 76 | Returns: 77 | Indexes of active environments. 78 | """ 79 | pass 80 | 81 | @abstractmethod 82 | def size(self) -> int: 83 | """ 84 | Returns: 85 | Number of environments. 86 | """ 87 | pass 88 | 89 | @property 90 | @abstractmethod 91 | def action_space(self) -> Any: 92 | """ 93 | Returns: 94 | Action space descriptor. 95 | """ 96 | pass 97 | 98 | @property 99 | @abstractmethod 100 | def observation_space(self) -> Any: 101 | """ 102 | Returns: 103 | Observation space descriptor. 104 | """ 105 | pass 106 | -------------------------------------------------------------------------------- /examples/tutorials/parallel_distributed/mpr_pickle.py: -------------------------------------------------------------------------------- 1 | from machin.parallel.pickle import dumps, loads 2 | from machin.parallel.process import Process 3 | from machin.parallel import get_context 4 | import torch as t 5 | 6 | 7 | def print_tensor_sub_proc(tens): 8 | print(loads(tens)) 9 | 10 | 11 | def exec_sub_proc(func): 12 | loads(func)() 13 | 14 | 15 | if __name__ == "__main__": 16 | spawn_ctx = get_context("spawn") 17 | fork_ctx = get_context("fork") 18 | # cpu tensor, not in shared memory 19 | # If you would like to pass this tensor to a sub process 20 | # set copy_tensor to `True`, otherwise only a pointer to 21 | # memory will be passed to the subprocess. 22 | # However, if you do this in the same process, no SEGFAULT 23 | # will happen, because memory map is the same. 24 | tensor = t.ones([10]) 25 | p = Process( 26 | target=print_tensor_sub_proc, 27 | args=(dumps(tensor, copy_tensor=True),), 28 | ctx=fork_ctx, 29 | ) 30 | p.start() 31 | p.join() 32 | # cpu tensor, in shared memory 33 | 34 | # If you would like to pass this tensor to a sub process 35 | # set copy_tensor to `False` is more efficient, because 36 | # only a pointer to the shared memory will be passed, and 37 | # not all data in the tensor. 38 | tensor.share_memory_() 39 | p = Process( 40 | target=print_tensor_sub_proc, 41 | args=(dumps(tensor, copy_tensor=False),), 42 | ctx=fork_ctx, 43 | ) 44 | p.start() 45 | p.join() 46 | print( 47 | "Dumped length of shm tensor if copy: {}".format( 48 | len(dumps(tensor, copy_tensor=True)) 49 | ) 50 | ) 51 | print( 52 | "Dumped length of shm tensor if not copy: {}".format( 53 | len(dumps(tensor, copy_tensor=False)) 54 | ) 55 | ) 56 | 57 | # gpu tensor 58 | # If you would like to pass this tensor to a sub process 59 | # set copy_tensor to `False` is more efficient, because 60 | # only a pointer to the CUDA memory will be passed, and 61 | # not all data in the tensor. 62 | # You should use "spawn" context instead of "fork" as well. 63 | tensor = tensor.to("cuda:0") 64 | p = Process( 65 | target=print_tensor_sub_proc, 66 | args=(dumps(tensor, copy_tensor=False),), 67 | ctx=spawn_ctx, 68 | ) 69 | p.start() 70 | p.join() 71 | print( 72 | "Dumped length of gpu tensor if copy: {}".format( 73 | len(dumps(tensor, copy_tensor=True)) 74 | ) 75 | ) 76 | print( 77 | "Dumped length of gpu tensor if not copy: {}".format( 78 | len(dumps(tensor, copy_tensor=False)) 79 | ) 80 | ) 81 | 82 | # in order to pass a local function / lambda function 83 | # to the subprocess, set recursive to `true` 84 | # then refered nonlocal&global variable will also be serialized. 
85 | def local_func(): 86 | global tensor 87 | tensor.fill_(3) 88 | 89 | print(f"Before:{tensor}") 90 | p = Process( 91 | target=exec_sub_proc, 92 | args=(dumps(local_func, recurse=True, copy_tensor=False),), 93 | ctx=spawn_ctx, 94 | ) 95 | p.start() 96 | p.join() 97 | print(f"After:{tensor}") 98 | -------------------------------------------------------------------------------- /machin/auto/__main__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from pprint import pprint 4 | from machin.auto.config import ( 5 | get_available_algorithms, 6 | get_available_environments, 7 | generate_algorithm_config, 8 | generate_env_config, 9 | generate_training_config, 10 | launch, 11 | ) 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | subparsers = parser.add_subparsers(dest="command") 16 | 17 | p_list = subparsers.add_parser( 18 | "list", help="List available algorithms or environments." 19 | ) 20 | 21 | p_list.add_argument( 22 | "--algo", action="store_true", help="List available algorithms.", 23 | ) 24 | 25 | p_list.add_argument( 26 | "--env", action="store_true", help="List available environments." 27 | ) 28 | 29 | p_generate = subparsers.add_parser("generate", help="Generate configuration.") 30 | 31 | p_generate.add_argument( 32 | "--algo", type=str, required=True, help="Algorithm name to use." 33 | ) 34 | p_generate.add_argument( 35 | "--env", type=str, required=True, help="Environment name to use." 36 | ) 37 | p_generate.add_argument( 38 | "--print", action="store_true", help="Direct config output to screen." 39 | ) 40 | p_generate.add_argument( 41 | "--output", 42 | type=str, 43 | default="config.json", 44 | help="JSON config file output path.", 45 | ) 46 | 47 | p_launch = subparsers.add_parser( 48 | "launch", help="Launch training with pytorch-lightning." 49 | ) 50 | 51 | p_launch.add_argument( 52 | "--config", type=str, default="config.json", help="JSON config file path.", 53 | ) 54 | 55 | args = parser.parse_args() 56 | if args.command == "list": 57 | if args.env: 58 | print("Available environments are:") 59 | for env in get_available_environments(): 60 | print(env) 61 | elif args.algo: 62 | print("Available algorithms are:") 63 | for algo in get_available_algorithms(): 64 | print(algo) 65 | else: 66 | print("You can list --algo or --env.") 67 | 68 | elif args.command == "generate": 69 | if args.algo not in get_available_algorithms(): 70 | print( 71 | f"{args.algo} is not a valid algorithm, use list " 72 | "--algo to get a list of available algorithms." 73 | ) 74 | exit(0) 75 | if args.env not in get_available_environments(): 76 | print( 77 | f"{args.env} is not a valid environment, use list " 78 | "--env to get a list of available environments." 
79 | ) 80 | exit(0) 81 | config = {} 82 | config = generate_env_config(args.env, config=config) 83 | config = generate_algorithm_config(args.algo, config=config) 84 | config = generate_training_config(config=config) 85 | 86 | if args.print: 87 | pprint(config) 88 | 89 | with open(args.output, "w") as f: 90 | json.dump(config, f, indent=4, sort_keys=True) 91 | print(f"Config saved to {args.output}") 92 | 93 | elif args.command == "launch": 94 | with open(args.config, "r") as f: 95 | conf = json.load(f) 96 | launch(conf) 97 | -------------------------------------------------------------------------------- /machin/auto/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import tempfile 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | from typing import Iterator, List, Dict, Union, Any, Tuple, Callable 7 | from torch.utils.data import IterableDataset 8 | from pytorch_lightning.loggers.base import LoggerCollection 9 | from machin.utils.media import create_video, numpy_array_to_pil_image 10 | 11 | 12 | Scalar = Any 13 | 14 | 15 | def determine_precision(models): 16 | dtype = set() 17 | for model in models: 18 | for k, v in model.named_parameters(): 19 | dtype.add(v.dtype) 20 | dtype = list(dtype) 21 | if len(dtype) > 1: 22 | raise RuntimeError( 23 | "Multiple data types of parameters detected " 24 | f"in models: {dtype}, this is currently not supported " 25 | "since we need to determine the data type of your " 26 | "model input from your model parameter data type." 27 | ) 28 | return dtype[0] 29 | 30 | 31 | def get_loggers_as_list(module: pl.LightningModule): 32 | if isinstance(module.logger, LoggerCollection): 33 | return module.logger._logger_iterable 34 | else: 35 | return [module.logger] 36 | 37 | 38 | def log_image(module, name, image: np.ndarray): 39 | for logger in get_loggers_as_list(module): 40 | if hasattr(logger, "log_image") and callable(logger.log_image): 41 | logger.log_image(name, numpy_array_to_pil_image(image)) 42 | 43 | 44 | def log_video(module, name, video_frames: List[np.ndarray]): 45 | # create video temp file 46 | fd, path = tempfile.mkstemp(suffix=".gif") 47 | os.close(fd) 48 | try: 49 | create_video( 50 | video_frames, 51 | os.path.dirname(path), 52 | os.path.basename(os.path.splitext(path)[0]), 53 | extension=".gif", 54 | ) 55 | except Exception as e: 56 | print(e) 57 | os.remove(path) 58 | return 59 | 60 | size = os.path.getsize(path) 61 | while True: 62 | time.sleep(1) 63 | new_size = os.path.getsize(path) 64 | if size != 0 and new_size == size: 65 | break 66 | size = new_size 67 | 68 | for logger in get_loggers_as_list(module): 69 | if hasattr(logger, "log_artifact") and callable(logger.log_artifact): 70 | logger.log_artifact(path, name + ".gif") 71 | if os.path.exists(path): 72 | os.remove(path) 73 | 74 | 75 | class DatasetResult: 76 | def __init__( 77 | self, 78 | observations: List[Dict[str, Any]] = None, 79 | logs: List[Dict[str, Union[Scalar, Tuple[Scalar, str]]]] = None, 80 | ): 81 | self.observations = observations or [] 82 | self.logs = logs or [] 83 | 84 | def add_observation(self, obs: Dict[str, Any]): 85 | self.observations.append(obs) 86 | 87 | def add_log(self, log: Dict[str, Union[Scalar, Tuple[Any, Callable]]]): 88 | self.logs.append(log) 89 | 90 | def __len__(self): 91 | return len(self.observations) 92 | 93 | 94 | class RLDataset(IterableDataset): 95 | """ 96 | Base class for all RL Datasets. 
97 | """ 98 | 99 | early_stopping_monitor = "" 100 | 101 | def __init__(self, **_kwargs): 102 | super().__init__() 103 | 104 | def __iter__(self) -> Iterator: 105 | return self 106 | 107 | def __next__(self) -> DatasetResult: 108 | raise StopIteration() 109 | -------------------------------------------------------------------------------- /examples/framework_examples/ars.py: -------------------------------------------------------------------------------- 1 | from machin.model.nets.base import dynamic_module_wrapper as dmw 2 | from machin.frame.helpers.servers import model_server_helper 3 | from machin.frame.algorithms import ARS 4 | from machin.parallel.distributed import World 5 | from machin.utils.logging import default_logger as logger 6 | from torch.multiprocessing import spawn 7 | import gym 8 | import torch as t 9 | import torch.nn as nn 10 | 11 | 12 | class ActorDiscrete(nn.Module): 13 | def __init__(self, state_dim, action_dim): 14 | super().__init__() 15 | self.fc = nn.Linear(state_dim, action_dim, bias=False) 16 | 17 | def forward(self, state): 18 | a = t.argmax(self.fc(state), dim=1).item() 19 | return a 20 | 21 | 22 | def main(rank): 23 | env = gym.make("CartPole-v0") 24 | observe_dim = 4 25 | action_num = 2 26 | max_episodes = 2000 27 | max_steps = 200 28 | solved_reward = 190 29 | solved_repeat = 5 30 | 31 | # initlize distributed world first 32 | world = World(world_size=3, rank=rank, name=str(rank), rpc_timeout=20) 33 | 34 | actor = dmw(ActorDiscrete(observe_dim, action_num)) 35 | servers = model_server_helper(model_num=1) 36 | ars_group = world.create_rpc_group("ars", ["0", "1", "2"]) 37 | ars = ARS( 38 | actor, 39 | t.optim.SGD, 40 | ars_group, 41 | servers, 42 | noise_std_dev=0.1, 43 | learning_rate=0.1, 44 | noise_size=1000000, 45 | rollout_num=6, 46 | used_rollout_num=6, 47 | normalize_state=True, 48 | ) 49 | 50 | # begin training 51 | episode, step, reward_fulfilled = 0, 0, 0 52 | smoothed_total_reward = 0 53 | 54 | while episode < max_episodes: 55 | episode += 1 56 | all_reward = 0 57 | for at in ars.get_actor_types(): 58 | total_reward = 0 59 | terminal = False 60 | step = 0 61 | 62 | # batch size = 1 63 | state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim) 64 | while not terminal and step <= max_steps: 65 | step += 1 66 | with t.no_grad(): 67 | # agent model inference 68 | action = ars.act({"state": state}, at) 69 | state, reward, terminal, __ = env.step(action) 70 | state = t.tensor(state, dtype=t.float32).view(1, observe_dim) 71 | total_reward += reward 72 | 73 | ars.store_reward(total_reward, at) 74 | all_reward += total_reward 75 | 76 | # update 77 | ars.update() 78 | 79 | # show reward 80 | smoothed_total_reward = ( 81 | smoothed_total_reward * 0.9 + all_reward / len(ars.get_actor_types()) * 0.1 82 | ) 83 | logger.info( 84 | f"Process {rank} Episode {episode} total reward={smoothed_total_reward:.2f}" 85 | ) 86 | 87 | if smoothed_total_reward > solved_reward: 88 | reward_fulfilled += 1 89 | if reward_fulfilled >= solved_repeat: 90 | logger.info("Environment solved!") 91 | # will cause torch RPC to complain 92 | # since other processes may have not finished yet. 93 | # just for demonstration. 
94 | exit(0) 95 | else: 96 | reward_fulfilled = 0 97 | 98 | 99 | if __name__ == "__main__": 100 | # spawn 3 sub processes 101 | spawn(main, nprocs=3) 102 | -------------------------------------------------------------------------------- /examples/framework_examples/a2c.py: -------------------------------------------------------------------------------- 1 | from machin.frame.algorithms import A2C 2 | from machin.utils.logging import default_logger as logger 3 | from torch.distributions import Categorical 4 | import torch as t 5 | import torch.nn as nn 6 | import gym 7 | 8 | # configurations 9 | env = gym.make("CartPole-v0") 10 | observe_dim = 4 11 | action_num = 2 12 | max_episodes = 1000 13 | max_steps = 200 14 | solved_reward = 190 15 | solved_repeat = 5 16 | 17 | 18 | # model definition 19 | class Actor(nn.Module): 20 | def __init__(self, state_dim, action_num): 21 | super().__init__() 22 | 23 | self.fc1 = nn.Linear(state_dim, 16) 24 | self.fc2 = nn.Linear(16, 16) 25 | self.fc3 = nn.Linear(16, action_num) 26 | 27 | def forward(self, state, action=None): 28 | a = t.relu(self.fc1(state)) 29 | a = t.relu(self.fc2(a)) 30 | probs = t.softmax(self.fc3(a), dim=1) 31 | dist = Categorical(probs=probs) 32 | act = action if action is not None else dist.sample() 33 | act_entropy = dist.entropy() 34 | act_log_prob = dist.log_prob(act.flatten()) 35 | return act, act_log_prob, act_entropy 36 | 37 | 38 | class Critic(nn.Module): 39 | def __init__(self, state_dim): 40 | super().__init__() 41 | 42 | self.fc1 = nn.Linear(state_dim, 16) 43 | self.fc2 = nn.Linear(16, 16) 44 | self.fc3 = nn.Linear(16, 1) 45 | 46 | def forward(self, state): 47 | v = t.relu(self.fc1(state)) 48 | v = t.relu(self.fc2(v)) 49 | v = self.fc3(v) 50 | return v 51 | 52 | 53 | if __name__ == "__main__": 54 | actor = Actor(observe_dim, action_num) 55 | critic = Critic(observe_dim) 56 | 57 | a2c = A2C(actor, critic, t.optim.Adam, nn.MSELoss(reduction="sum")) 58 | 59 | episode, step, reward_fulfilled = 0, 0, 0 60 | smoothed_total_reward = 0 61 | 62 | while episode < max_episodes: 63 | episode += 1 64 | total_reward = 0 65 | terminal = False 66 | step = 0 67 | state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim) 68 | 69 | tmp_observations = [] 70 | while not terminal and step <= max_steps: 71 | step += 1 72 | with t.no_grad(): 73 | old_state = state 74 | # agent model inference 75 | action = a2c.act({"state": old_state})[0] 76 | state, reward, terminal, _ = env.step(action.item()) 77 | state = t.tensor(state, dtype=t.float32).view(1, observe_dim) 78 | total_reward += reward 79 | 80 | tmp_observations.append( 81 | { 82 | "state": {"state": old_state}, 83 | "action": {"action": action}, 84 | "next_state": {"state": state}, 85 | "reward": reward, 86 | "terminal": terminal or step == max_steps, 87 | } 88 | ) 89 | 90 | # update 91 | a2c.store_episode(tmp_observations) 92 | a2c.update() 93 | 94 | # show reward 95 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 96 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 97 | 98 | if smoothed_total_reward > solved_reward: 99 | reward_fulfilled += 1 100 | if reward_fulfilled >= solved_repeat: 101 | logger.info("Environment solved!") 102 | exit(0) 103 | else: 104 | reward_fulfilled = 0 105 | -------------------------------------------------------------------------------- /examples/framework_examples/ppo.py: -------------------------------------------------------------------------------- 1 | from machin.frame.algorithms import PPO 2 | from 
machin.utils.logging import default_logger as logger 3 | from torch.distributions import Categorical 4 | import torch as t 5 | import torch.nn as nn 6 | import gym 7 | 8 | # configurations 9 | env = gym.make("CartPole-v0") 10 | observe_dim = 4 11 | action_num = 2 12 | max_episodes = 1000 13 | max_steps = 200 14 | solved_reward = 190 15 | solved_repeat = 5 16 | 17 | 18 | # model definition 19 | class Actor(nn.Module): 20 | def __init__(self, state_dim, action_num): 21 | super().__init__() 22 | 23 | self.fc1 = nn.Linear(state_dim, 16) 24 | self.fc2 = nn.Linear(16, 16) 25 | self.fc3 = nn.Linear(16, action_num) 26 | 27 | def forward(self, state, action=None): 28 | a = t.relu(self.fc1(state)) 29 | a = t.relu(self.fc2(a)) 30 | probs = t.softmax(self.fc3(a), dim=1) 31 | dist = Categorical(probs=probs) 32 | act = action if action is not None else dist.sample() 33 | act_entropy = dist.entropy() 34 | act_log_prob = dist.log_prob(act.flatten()) 35 | return act, act_log_prob, act_entropy 36 | 37 | 38 | class Critic(nn.Module): 39 | def __init__(self, state_dim): 40 | super().__init__() 41 | 42 | self.fc1 = nn.Linear(state_dim, 16) 43 | self.fc2 = nn.Linear(16, 16) 44 | self.fc3 = nn.Linear(16, 1) 45 | 46 | def forward(self, state): 47 | v = t.relu(self.fc1(state)) 48 | v = t.relu(self.fc2(v)) 49 | v = self.fc3(v) 50 | return v 51 | 52 | 53 | if __name__ == "__main__": 54 | actor = Actor(observe_dim, action_num) 55 | critic = Critic(observe_dim) 56 | 57 | ppo = PPO(actor, critic, t.optim.Adam, nn.MSELoss(reduction="sum")) 58 | 59 | episode, step, reward_fulfilled = 0, 0, 0 60 | smoothed_total_reward = 0 61 | 62 | while episode < max_episodes: 63 | episode += 1 64 | total_reward = 0 65 | terminal = False 66 | step = 0 67 | state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim) 68 | 69 | tmp_observations = [] 70 | while not terminal and step <= max_steps: 71 | step += 1 72 | with t.no_grad(): 73 | old_state = state 74 | # agent model inference 75 | action = ppo.act({"state": old_state})[0] 76 | state, reward, terminal, _ = env.step(action.item()) 77 | state = t.tensor(state, dtype=t.float32).view(1, observe_dim) 78 | total_reward += reward 79 | 80 | tmp_observations.append( 81 | { 82 | "state": {"state": old_state}, 83 | "action": {"action": action}, 84 | "next_state": {"state": state}, 85 | "reward": reward, 86 | "terminal": terminal or step == max_steps, 87 | } 88 | ) 89 | 90 | # update 91 | ppo.store_episode(tmp_observations) 92 | ppo.update() 93 | 94 | # show reward 95 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 96 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 97 | 98 | if smoothed_total_reward > solved_reward: 99 | reward_fulfilled += 1 100 | if reward_fulfilled >= solved_repeat: 101 | logger.info("Environment solved!") 102 | exit(0) 103 | else: 104 | reward_fulfilled = 0 105 | -------------------------------------------------------------------------------- /examples/framework_examples/rainbow.py: -------------------------------------------------------------------------------- 1 | from machin.frame.algorithms import RAINBOW 2 | from machin.utils.logging import default_logger as logger 3 | import torch as t 4 | import torch.nn as nn 5 | import gym 6 | 7 | # configurations 8 | env = gym.make("CartPole-v0") 9 | observe_dim = 4 10 | action_num = 2 11 | # maximum and minimum of reward value 12 | # since reward is 1 for every step, maximum q value should be 13 | # below 20(reward_future_steps) * (1 + discount ** n_steps) < 40 14 | 
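# Worked out (assuming the framework's default discount factor of 0.99,
# an assumption since it is not set explicitly here): twenty unit rewards
# sum to at most 20, and the bracketed factor is (1 + 0.99 ** 20) ≈ 1.82,
# giving roughly 20 * 1.82 ≈ 36.4 < 40, so 40 is a safe ceiling and 0 a
# safe floor for the value distribution's support.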
value_max = 40 15 | value_min = 0 16 | reward_future_steps = 20 17 | max_episodes = 1000 18 | max_steps = 200 19 | solved_reward = 190 20 | solved_repeat = 5 21 | 22 | 23 | # model definition 24 | class QNet(nn.Module): 25 | # this test setup lacks the noisy linear layer and dueling structure. 26 | def __init__(self, state_dim, action_num, atom_num=10): 27 | super().__init__() 28 | 29 | self.fc1 = nn.Linear(state_dim, 16) 30 | self.fc2 = nn.Linear(16, 16) 31 | self.fc3 = nn.Linear(16, action_num * atom_num) 32 | self.action_num = action_num 33 | self.atom_num = atom_num 34 | 35 | def forward(self, state): 36 | a = t.relu(self.fc1(state)) 37 | a = t.relu(self.fc2(a)) 38 | return t.softmax(self.fc3(a).view(-1, self.action_num, self.atom_num), dim=-1) 39 | 40 | 41 | if __name__ == "__main__": 42 | q_net = QNet(observe_dim, action_num) 43 | q_net_t = QNet(observe_dim, action_num) 44 | 45 | rainbow = RAINBOW( 46 | q_net, 47 | q_net_t, 48 | t.optim.Adam, 49 | value_min, 50 | value_max, 51 | reward_future_steps=reward_future_steps, 52 | ) 53 | 54 | episode, step, reward_fulfilled = 0, 0, 0 55 | smoothed_total_reward = 0 56 | 57 | while episode < max_episodes: 58 | episode += 1 59 | total_reward = 0 60 | terminal = False 61 | step = 0 62 | state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim) 63 | 64 | tmp_observations = [] 65 | while not terminal and step <= max_steps: 66 | step += 1 67 | with t.no_grad(): 68 | old_state = state 69 | # agent model inference 70 | action = rainbow.act_discrete_with_noise({"state": old_state}) 71 | state, reward, terminal, _ = env.step(action.item()) 72 | state = t.tensor(state, dtype=t.float32).view(1, observe_dim) 73 | total_reward += reward 74 | 75 | tmp_observations.append( 76 | { 77 | "state": {"state": old_state}, 78 | "action": {"action": action}, 79 | "next_state": {"state": state}, 80 | "reward": reward, 81 | "terminal": terminal or step == max_steps, 82 | } 83 | ) 84 | 85 | rainbow.store_episode(tmp_observations) 86 | 87 | # update, update more if episode is longer, else less 88 | if episode > 100: 89 | for _ in range(step): 90 | rainbow.update() 91 | 92 | # show reward 93 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 94 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 95 | 96 | if smoothed_total_reward > solved_reward: 97 | reward_fulfilled += 1 98 | if reward_fulfilled >= solved_repeat: 99 | logger.info("Environment solved!") 100 | exit(0) 101 | else: 102 | reward_fulfilled = 0 103 | -------------------------------------------------------------------------------- /examples/tutorials/your_first_program/main.py: -------------------------------------------------------------------------------- 1 | from machin.frame.algorithms import DQN 2 | from machin.utils.logging import default_logger as logger 3 | from machin.model.nets import static_module_wrapper, dynamic_module_wrapper 4 | import torch as t 5 | import torch.nn as nn 6 | import gym 7 | 8 | # configurations 9 | env = gym.make("CartPole-v0") 10 | observe_dim = 4 11 | action_num = 2 12 | max_episodes = 1000 13 | max_steps = 200 14 | solved_reward = 190 15 | solved_repeat = 5 16 | 17 | 18 | # model definition 19 | class QNet(nn.Module): 20 | def __init__(self, state_dim, action_num): 21 | super().__init__() 22 | 23 | self.fc1 = nn.Linear(state_dim, 16) 24 | self.fc2 = nn.Linear(16, 16) 25 | self.fc3 = nn.Linear(16, action_num) 26 | 27 | def forward(self, some_state): 28 | a = t.relu(self.fc1(some_state)) 29 | a = t.relu(self.fc2(a)) 30 
| return self.fc3(a) 31 | 32 | 33 | if __name__ == "__main__": 34 | # let the framework determine the input/output device based on parameter location; 35 | # a warning will be thrown. 36 | q_net = QNet(observe_dim, action_num) 37 | q_net_t = QNet(observe_dim, action_num) 38 | 39 | # to mark the input/output device manually; 40 | # this will not work if you move your model to other devices 41 | # after wrapping 42 | 43 | # q_net = static_module_wrapper(q_net, "cpu", "cpu") 44 | # q_net_t = static_module_wrapper(q_net_t, "cpu", "cpu") 45 | 46 | # to mark the input/output device automatically; 47 | # this will not work if your model is located on multiple devices 48 | 49 | # q_net = dynamic_module_wrapper(q_net) 50 | # q_net_t = dynamic_module_wrapper(q_net_t) 51 | 52 | dqn = DQN(q_net, q_net_t, t.optim.Adam, nn.MSELoss(reduction="sum")) 53 | 54 | episode, step, reward_fulfilled = 0, 0, 0 55 | smoothed_total_reward = 0 56 | 57 | while episode < max_episodes: 58 | episode += 1 59 | total_reward = 0 60 | terminal = False 61 | step = 0 62 | state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim) 63 | 64 | while not terminal and step <= max_steps: 65 | step += 1 66 | with t.no_grad(): 67 | old_state = state 68 | # agent model inference 69 | action = dqn.act_discrete_with_noise({"some_state": old_state}) 70 | state, reward, terminal, _ = env.step(action.item()) 71 | state = t.tensor(state, dtype=t.float32).view(1, observe_dim) 72 | total_reward += reward 73 | 74 | dqn.store_transition( 75 | { 76 | "state": {"some_state": old_state}, 77 | "action": {"action": action}, 78 | "next_state": {"some_state": state}, 79 | "reward": reward, 80 | "terminal": terminal or step == max_steps, 81 | } 82 | ) 83 | 84 | # update, update more if episode is longer, else less 85 | if episode > 100: 86 | for _ in range(step): 87 | dqn.update() 88 | 89 | # show reward 90 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 91 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 92 | 93 | if smoothed_total_reward > solved_reward: 94 | reward_fulfilled += 1 95 | if reward_fulfilled >= solved_repeat: 96 | logger.info("Environment solved!") 97 | exit(0) 98 | else: 99 | reward_fulfilled = 0 100 | -------------------------------------------------------------------------------- /examples/tutorials/recurrent_networks/rppo.py: -------------------------------------------------------------------------------- 1 | from machin.env.utils.openai_gym import disable_view_window 2 | from machin.frame.algorithms import PPO 3 | from machin.utils.logging import default_logger as logger 4 | from torch.distributions import Categorical 5 | 6 | import gym 7 | import torch as t 8 | import torch.nn as nn 9 | 10 | from util import convert 11 | 12 | # configurations 13 | env = gym.make("Frostbite-ram-v0") 14 | action_num = env.action_space.n 15 | max_episodes = 20000 16 | 17 | # disable view window in rendering 18 | disable_view_window() 19 | 20 | 21 | class RecurrentActor(nn.Module): 22 | def __init__(self, action_num): 23 | super().__init__() 24 | self.gru = nn.GRU(128, 256, batch_first=True) 25 | self.fc1 = nn.Linear(256, 256) 26 | self.fc2 = nn.Linear(256, action_num) 27 | 28 | def forward(self, mem, hidden, action=None): 29 | hidden = hidden.transpose(0, 1) 30 | a, hidden = self.gru(mem.unsqueeze(1), hidden) 31 | a = self.fc2(t.relu(self.fc1(t.relu(a.flatten(start_dim=1))))) 32 | probs = t.softmax(a, dim=1) 33 | dist = Categorical(probs=probs) 34 | act = action if action is not None else dist.sample() 35 | 
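# When PPO later replays stored transitions during its update, the saved
# action is passed back in through `action`, so the log-probability and
# entropy are re-evaluated under the current policy; during rollout
# `action` is None and a fresh action is sampled.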
act_entropy = dist.entropy() 36 | act_log_prob = dist.log_prob(act.flatten()) 37 | return act, act_log_prob, act_entropy, hidden 38 | 39 | 40 | class Critic(nn.Module): 41 | def __init__(self): 42 | super().__init__() 43 | self.fc1 = nn.Linear(128, 256) 44 | self.fc2 = nn.Linear(256, 256) 45 | self.fc3 = nn.Linear(256, 1) 46 | 47 | def forward(self, mem): 48 | v = t.relu(self.fc1(mem)) 49 | v = t.relu(self.fc2(v)) 50 | v = self.fc3(v) 51 | return v 52 | 53 | 54 | if __name__ == "__main__": 55 | actor = RecurrentActor(action_num).to("cuda:0") 56 | critic = Critic().to("cuda:0") 57 | 58 | rppo = PPO( 59 | actor, 60 | critic, 61 | t.optim.Adam, 62 | nn.MSELoss(reduction="sum"), 63 | actor_learning_rate=1e-5, 64 | critic_learning_rate=1e-4, 65 | ) 66 | 67 | episode, step, reward_fulfilled = 0, 0, 0 68 | smoothed_total_reward = 0 69 | 70 | while episode < max_episodes: 71 | episode += 1 72 | total_reward = 0 73 | terminal = False 74 | step = 0 75 | hidden = t.zeros([1, 1, 256]) 76 | state = convert(env.reset()) 77 | 78 | tmp_observations = [] 79 | while not terminal: 80 | step += 1 81 | with t.no_grad(): 82 | old_state = state 83 | # agent model inference 84 | old_hidden = hidden 85 | action, _, _, hidden = rppo.act({"mem": state, "hidden": hidden}) 86 | state, reward, terminal, _ = env.step(action.item()) 87 | state = convert(state) 88 | total_reward += reward 89 | 90 | tmp_observations.append( 91 | { 92 | "state": {"mem": old_state, "hidden": old_hidden}, 93 | "action": {"action": action}, 94 | "next_state": {"mem": state, "hidden": hidden}, 95 | "reward": reward, 96 | "terminal": terminal, 97 | } 98 | ) 99 | 100 | # update 101 | rppo.store_episode(tmp_observations) 102 | rppo.update() 103 | 104 | # show reward 105 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 106 | 107 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 108 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent/scenarios/simple_reference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from multiagent.core import World, Agent, Landmark 3 | from multiagent.scenario import BaseScenario 4 | 5 | class Scenario(BaseScenario): 6 | def make_world(self): 7 | world = World() 8 | # set any world properties first 9 | world.dim_c = 10 10 | world.collaborative = True # whether agents share rewards 11 | # add agents 12 | world.agents = [Agent() for i in range(2)] 13 | for i, agent in enumerate(world.agents): 14 | agent.name = 'agent %d' % i 15 | agent.collide = False 16 | # add landmarks 17 | world.landmarks = [Landmark() for i in range(3)] 18 | for i, landmark in enumerate(world.landmarks): 19 | landmark.name = 'landmark %d' % i 20 | landmark.collide = False 21 | landmark.movable = False 22 | # make initial conditions 23 | self.reset_world(world) 24 | return world 25 | 26 | def reset_world(self, world): 27 | # assign goals to agents 28 | for agent in world.agents: 29 | agent.goal_a = None 30 | agent.goal_b = None 31 | # want other agent to go to the goal landmark 32 | world.agents[0].goal_a = world.agents[1] 33 | world.agents[0].goal_b = np.random.choice(world.landmarks) 34 | world.agents[1].goal_a = world.agents[0] 35 | world.agents[1].goal_b = np.random.choice(world.landmarks) 36 | # random properties for agents 37 | for i, agent in enumerate(world.agents): 38 | agent.color = np.array([0.25,0.25,0.25]) 39 | # random properties for landmarks 40 | 
world.landmarks[0].color = np.array([0.75,0.25,0.25]) 41 | world.landmarks[1].color = np.array([0.25,0.75,0.25]) 42 | world.landmarks[2].color = np.array([0.25,0.25,0.75]) 43 | # special colors for goals 44 | world.agents[0].goal_a.color = world.agents[0].goal_b.color 45 | world.agents[1].goal_a.color = world.agents[1].goal_b.color 46 | # set random initial states 47 | for agent in world.agents: 48 | agent.state.p_pos = np.random.uniform(-1,+1, world.dim_p) 49 | agent.state.p_vel = np.zeros(world.dim_p) 50 | agent.state.c = np.zeros(world.dim_c) 51 | for i, landmark in enumerate(world.landmarks): 52 | landmark.state.p_pos = np.random.uniform(-1,+1, world.dim_p) 53 | landmark.state.p_vel = np.zeros(world.dim_p) 54 | 55 | def reward(self, agent, world): 56 | if agent.goal_a is None or agent.goal_b is None: 57 | return 0.0 58 | dist2 = np.sum(np.square(agent.goal_a.state.p_pos - agent.goal_b.state.p_pos)) 59 | return -dist2 60 | 61 | def observation(self, agent, world): 62 | # goal color 63 | goal_color = [np.zeros(world.dim_color), np.zeros(world.dim_color)] 64 | if agent.goal_b is not None: 65 | goal_color[1] = agent.goal_b.color 66 | 67 | # get positions of all entities in this agent's reference frame 68 | entity_pos = [] 69 | for entity in world.landmarks: 70 | entity_pos.append(entity.state.p_pos - agent.state.p_pos) 71 | # entity colors 72 | entity_color = [] 73 | for entity in world.landmarks: 74 | entity_color.append(entity.color) 75 | # communication of all other agents 76 | comm = [] 77 | for other in world.agents: 78 | if other is agent: continue 79 | comm.append(other.state.c) 80 | return np.concatenate([agent.state.p_vel] + entity_pos + [goal_color[1]] + comm) 81 | -------------------------------------------------------------------------------- /examples/tutorials/recurrent_networks/ppo.py: -------------------------------------------------------------------------------- 1 | from machin.env.utils.openai_gym import disable_view_window 2 | from machin.frame.algorithms import PPO 3 | from machin.utils.logging import default_logger as logger 4 | from torch.distributions import Categorical 5 | 6 | import gym 7 | import torch as t 8 | import torch.nn as nn 9 | 10 | from util import convert 11 | from history import History 12 | 13 | # configurations 14 | env = gym.make("Frostbite-ram-v0") 15 | action_num = env.action_space.n 16 | max_episodes = 20000 17 | history_depth = 4 18 | 19 | # disable view window in rendering 20 | disable_view_window() 21 | 22 | 23 | class Actor(nn.Module): 24 | def __init__(self, history_depth, action_num): 25 | super().__init__() 26 | self.fc1 = nn.Linear(128 * history_depth, 256) 27 | self.fc2 = nn.Linear(256, 256) 28 | self.fc3 = nn.Linear(256, action_num) 29 | 30 | def forward(self, mem, action=None): 31 | a = t.relu(self.fc1(mem.flatten(start_dim=1))) 32 | a = t.relu(self.fc2(a)) 33 | probs = t.softmax(self.fc3(a), dim=1) 34 | dist = Categorical(probs=probs) 35 | act = action if action is not None else dist.sample() 36 | act_entropy = dist.entropy() 37 | act_log_prob = dist.log_prob(act.flatten()) 38 | return act, act_log_prob, act_entropy 39 | 40 | 41 | class Critic(nn.Module): 42 | def __init__(self, history_depth): 43 | super().__init__() 44 | 45 | self.fc1 = nn.Linear(128 * history_depth, 256) 46 | self.fc2 = nn.Linear(256, 256) 47 | self.fc3 = nn.Linear(256, 1) 48 | 49 | def forward(self, mem): 50 | v = t.relu(self.fc1(mem.flatten(start_dim=1))) 51 | v = t.relu(self.fc2(v)) 52 | v = self.fc3(v) 53 | return v 54 | 55 | 56 | if __name__ == "__main__": 57 
| actor = Actor(history_depth, action_num).to("cuda:0") 58 | critic = Critic(history_depth).to("cuda:0") 59 | 60 | ppo = PPO( 61 | actor, 62 | critic, 63 | t.optim.Adam, 64 | nn.MSELoss(reduction="sum"), 65 | actor_learning_rate=1e-5, 66 | critic_learning_rate=1e-4, 67 | ) 68 | 69 | episode, step, reward_fulfilled = 0, 0, 0 70 | smoothed_total_reward = 0 71 | 72 | while episode < max_episodes: 73 | episode += 1 74 | total_reward = 0 75 | terminal = False 76 | step = 0 77 | state = convert(env.reset()) 78 | history = History(history_depth, (1, 128)) 79 | 80 | tmp_observations = [] 81 | while not terminal: 82 | step += 1 83 | with t.no_grad(): 84 | history.append(state) 85 | # agent model inference 86 | action = ppo.act({"mem": history.get()})[0] 87 | state, reward, terminal, _ = env.step(action.item()) 88 | state = convert(state) 89 | total_reward += reward 90 | 91 | old_history = history.get() 92 | new_history = history.append(state).get() 93 | tmp_observations.append( 94 | { 95 | "state": {"mem": old_history}, 96 | "action": {"action": action}, 97 | "next_state": {"mem": new_history}, 98 | "reward": reward, 99 | "terminal": terminal, 100 | } 101 | ) 102 | 103 | # update 104 | ppo.store_episode(tmp_observations) 105 | ppo.update() 106 | 107 | # show reward 108 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 109 | 110 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 111 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # http://www.sphinx-doc.org/en/master/config 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath(os.path.join('..', '..'))) 16 | 17 | import machin 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'Machin' 22 | copyright = '2020, Iffi' 23 | author = 'Iffi' 24 | 25 | # The full version, including alpha/beta/rc tags 26 | release = machin.__version__ 27 | 28 | 29 | # -- General configuration --------------------------------------------------- 30 | master_doc = 'index' 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 35 | extensions = [ 36 | 'sphinx.ext.todo', 37 | 'sphinx.ext.autodoc', 38 | 'sphinx.ext.doctest', 39 | 'sphinx.ext.intersphinx', 40 | 'sphinx.ext.coverage', 41 | 'sphinx.ext.mathjax', 42 | 'sphinx.ext.ifconfig', 43 | 'sphinx.ext.viewcode', 44 | 'sphinx.ext.githubpages', 45 | 'sphinx.ext.napoleon', 46 | 'sphinx.ext.autosectionlabel', 47 | ] 48 | 49 | # Add any paths that contain templates here, relative to this directory. 50 | # templates_path = ['_templates'] 51 | 52 | # List of patterns, relative to source directory, that match files and 53 | # directories to ignore when looking for source files. 
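# (illustrative, not from the original file: sphinx-quickstart typically
# generates e.g. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'])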
54 | # This pattern also affects html_static_path and html_extra_path. 55 | exclude_patterns = [] 56 | 57 | # Both the class’ and the __init__ method’s docstring are concatenated 58 | # and inserted. 59 | autoclass_content = 'both' 60 | autodoc_default_options = { 61 | #'special-members': '__call__, __getitem__, __len__' 62 | } 63 | autodoc_member_order = 'groupwise' # 'bysource', 'alphabetical' 64 | autodoc_typehints = "description" 65 | autodoc_mock_imports = [""] 66 | # autodoc_dumb_docstring = True 67 | 68 | # same as autoclass_content = 'both', 69 | # but __init__ signature is also documented, not beautiful. 70 | # napoleon_include_init_with_doc = True 71 | # napoleon_use_admonition_for_examples = True 72 | 73 | # -- Options for HTML output ------------------------------------------------- 74 | 75 | # The theme to use for HTML and HTML Help pages. See the documentation for 76 | # a list of builtin themes. 77 | # 78 | html_theme = 'theme' 79 | html_logo = 'static/icon_title.png' 80 | html_favicon = 'static/favicon.png' 81 | html_theme_path = ['../'] 82 | 83 | html_theme_options = { 84 | 'canonical_url': '', 85 | 'logo_only': True, 86 | 'display_version': True, 87 | 'prev_next_buttons_location': 'bottom', 88 | 'style_external_links': False, 89 | # Toc options 90 | 'collapse_navigation': True, 91 | 'sticky_navigation': True, 92 | 'navigation_depth': 4, 93 | 'includehidden': True, 94 | 'titles_only': False 95 | } 96 | 97 | # Add any paths that contain custom static files (such as style sheets) here, 98 | # relative to this directory. They are copied after the builtin static files, 99 | # so a file named "default.css" will overwrite the builtin "default.css". 100 | html_static_path = ['static'] 101 | 102 | # automatic section reference label 103 | autosectionlabel_prefix_document = True 104 | autosectionlabel_maxdepth = 3 105 | numfig = True 106 | -------------------------------------------------------------------------------- /test_lib/multiagent-particle-envs/multiagent/scenarios/simple_speaker_listener.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from multiagent.core import World, Agent, Landmark 3 | from multiagent.scenario import BaseScenario 4 | 5 | class Scenario(BaseScenario): 6 | def make_world(self): 7 | world = World() 8 | # set any world properties first 9 | world.dim_c = 3 10 | num_landmarks = 3 11 | world.collaborative = True 12 | # add agents 13 | world.agents = [Agent() for i in range(2)] 14 | for i, agent in enumerate(world.agents): 15 | agent.name = 'agent %d' % i 16 | agent.collide = False 17 | agent.size = 0.075 18 | # speaker 19 | world.agents[0].movable = False 20 | # listener 21 | world.agents[1].silent = True 22 | # add landmarks 23 | world.landmarks = [Landmark() for i in range(num_landmarks)] 24 | for i, landmark in enumerate(world.landmarks): 25 | landmark.name = 'landmark %d' % i 26 | landmark.collide = False 27 | landmark.movable = False 28 | landmark.size = 0.04 29 | # make initial conditions 30 | self.reset_world(world) 31 | return world 32 | 33 | def reset_world(self, world): 34 | # assign goals to agents 35 | for agent in world.agents: 36 | agent.goal_a = None 37 | agent.goal_b = None 38 | # want listener to go to the goal landmark 39 | world.agents[0].goal_a = world.agents[1] 40 | world.agents[0].goal_b = np.random.choice(world.landmarks) 41 | # random properties for agents 42 | for i, agent in enumerate(world.agents): 43 | agent.color = np.array([0.25,0.25,0.25]) 44 | # random properties 
for landmarks 45 | world.landmarks[0].color = np.array([0.65,0.15,0.15]) 46 | world.landmarks[1].color = np.array([0.15,0.65,0.15]) 47 | world.landmarks[2].color = np.array([0.15,0.15,0.65]) 48 | # special colors for goals 49 | world.agents[0].goal_a.color = world.agents[0].goal_b.color + np.array([0.45, 0.45, 0.45]) 50 | # set random initial states 51 | for agent in world.agents: 52 | agent.state.p_pos = np.random.uniform(-1,+1, world.dim_p) 53 | agent.state.p_vel = np.zeros(world.dim_p) 54 | agent.state.c = np.zeros(world.dim_c) 55 | for i, landmark in enumerate(world.landmarks): 56 | landmark.state.p_pos = np.random.uniform(-1,+1, world.dim_p) 57 | landmark.state.p_vel = np.zeros(world.dim_p) 58 | 59 | def benchmark_data(self, agent, world): 60 | # returns data for benchmarking purposes 61 | return self.reward(agent, world) 62 | 63 | def reward(self, agent, world): 64 | # squared distance from listener to landmark 65 | a = world.agents[0] 66 | dist2 = np.sum(np.square(a.goal_a.state.p_pos - a.goal_b.state.p_pos)) 67 | return -dist2 68 | 69 | def observation(self, agent, world): 70 | # goal color 71 | goal_color = np.zeros(world.dim_color) 72 | if agent.goal_b is not None: 73 | goal_color = agent.goal_b.color 74 | 75 | # get positions of all entities in this agent's reference frame 76 | entity_pos = [] 77 | for entity in world.landmarks: 78 | entity_pos.append(entity.state.p_pos - agent.state.p_pos) 79 | 80 | # communication of all other agents 81 | comm = [] 82 | for other in world.agents: 83 | if other is agent or (other.state.c is None): continue 84 | comm.append(other.state.c) 85 | 86 | # speaker 87 | if not agent.movable: 88 | return np.concatenate([goal_color]) 89 | # listener 90 | if agent.silent: 91 | return np.concatenate([agent.state.p_vel] + entity_pos + comm) 92 | 93 | -------------------------------------------------------------------------------- /examples/framework_examples/ddpg.py: -------------------------------------------------------------------------------- 1 | from machin.frame.algorithms import DDPG 2 | from machin.utils.logging import default_logger as logger 3 | import torch as t 4 | import torch.nn as nn 5 | import gym 6 | 7 | # configurations 8 | env = gym.make("Pendulum-v0") 9 | observe_dim = 3 10 | action_dim = 1 11 | action_range = 2 12 | max_episodes = 1000 13 | max_steps = 200 14 | noise_param = (0, 0.2) 15 | noise_mode = "normal" 16 | solved_reward = -150 17 | solved_repeat = 5 18 | 19 | 20 | # model definition 21 | class Actor(nn.Module): 22 | def __init__(self, state_dim, action_dim, action_range): 23 | super().__init__() 24 | 25 | self.fc1 = nn.Linear(state_dim, 16) 26 | self.fc2 = nn.Linear(16, 16) 27 | self.fc3 = nn.Linear(16, action_dim) 28 | self.action_range = action_range 29 | 30 | def forward(self, state): 31 | a = t.relu(self.fc1(state)) 32 | a = t.relu(self.fc2(a)) 33 | a = t.tanh(self.fc3(a)) * self.action_range 34 | return a 35 | 36 | 37 | class Critic(nn.Module): 38 | def __init__(self, state_dim, action_dim): 39 | super().__init__() 40 | 41 | self.fc1 = nn.Linear(state_dim + action_dim, 16) 42 | self.fc2 = nn.Linear(16, 16) 43 | self.fc3 = nn.Linear(16, 1) 44 | 45 | def forward(self, state, action): 46 | state_action = t.cat([state, action], 1) 47 | q = t.relu(self.fc1(state_action)) 48 | q = t.relu(self.fc2(q)) 49 | q = self.fc3(q) 50 | return q 51 | 52 | 53 | if __name__ == "__main__": 54 | actor = Actor(observe_dim, action_dim, action_range) 55 | actor_t = Actor(observe_dim, action_dim, action_range) 56 | critic = Critic(observe_dim,
action_dim) 57 | critic_t = Critic(observe_dim, action_dim) 58 | 59 | ddpg = DDPG( 60 | actor, actor_t, critic, critic_t, t.optim.Adam, nn.MSELoss(reduction="sum") 61 | ) 62 | 63 | episode, step, reward_fulfilled = 0, 0, 0 64 | smoothed_total_reward = 0 65 | 66 | while episode < max_episodes: 67 | episode += 1 68 | total_reward = 0 69 | terminal = False 70 | step = 0 71 | state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim) 72 | tmp_observations = [] 73 | 74 | while not terminal and step <= max_steps: 75 | step += 1 76 | with t.no_grad(): 77 | old_state = state 78 | # agent model inference 79 | action = ddpg.act_with_noise( 80 | {"state": old_state}, noise_param=noise_param, mode=noise_mode 81 | ) 82 | state, reward, terminal, _ = env.step(action.numpy()) 83 | state = t.tensor(state, dtype=t.float32).view(1, observe_dim) 84 | total_reward += reward[0] 85 | 86 | tmp_observations.append( 87 | { 88 | "state": {"state": old_state}, 89 | "action": {"action": action}, 90 | "next_state": {"state": state}, 91 | "reward": reward[0], 92 | "terminal": terminal or step == max_steps, 93 | } 94 | ) 95 | 96 | ddpg.store_episode(tmp_observations) 97 | # update, update more if episode is longer, else less 98 | if episode > 100: 99 | for _ in range(step): 100 | ddpg.update() 101 | 102 | # show reward 103 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 104 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 105 | 106 | if smoothed_total_reward > solved_reward: 107 | reward_fulfilled += 1 108 | if reward_fulfilled >= solved_repeat: 109 | logger.info("Environment solved!") 110 | exit(0) 111 | else: 112 | reward_fulfilled = 0 113 | -------------------------------------------------------------------------------- /examples/framework_examples/hddpg.py: -------------------------------------------------------------------------------- 1 | from machin.frame.algorithms import HDDPG 2 | from machin.utils.logging import default_logger as logger 3 | import torch as t 4 | import torch.nn as nn 5 | import gym 6 | 7 | # configurations 8 | env = gym.make("Pendulum-v0") 9 | observe_dim = 3 10 | action_dim = 1 11 | action_range = 2 12 | max_episodes = 1000 13 | max_steps = 200 14 | noise_param = (0, 0.2) 15 | noise_mode = "normal" 16 | solved_reward = -150 17 | solved_repeat = 5 18 | 19 | 20 | # model definition 21 | class Actor(nn.Module): 22 | def __init__(self, state_dim, action_dim, action_range): 23 | super().__init__() 24 | 25 | self.fc1 = nn.Linear(state_dim, 16) 26 | self.fc2 = nn.Linear(16, 16) 27 | self.fc3 = nn.Linear(16, action_dim) 28 | self.action_range = action_range 29 | 30 | def forward(self, state): 31 | a = t.relu(self.fc1(state)) 32 | a = t.relu(self.fc2(a)) 33 | a = t.tanh(self.fc3(a)) * self.action_range 34 | return a 35 | 36 | 37 | class Critic(nn.Module): 38 | def __init__(self, state_dim, action_dim): 39 | super().__init__() 40 | 41 | self.fc1 = nn.Linear(state_dim + action_dim, 16) 42 | self.fc2 = nn.Linear(16, 16) 43 | self.fc3 = nn.Linear(16, 1) 44 | 45 | def forward(self, state, action): 46 | state_action = t.cat([state, action], 1) 47 | q = t.relu(self.fc1(state_action)) 48 | q = t.relu(self.fc2(q)) 49 | q = self.fc3(q) 50 | return q 51 | 52 | 53 | if __name__ == "__main__": 54 | actor = Actor(observe_dim, action_dim, action_range) 55 | actor_t = Actor(observe_dim, action_dim, action_range) 56 | critic = Critic(observe_dim, action_dim) 57 | critic_t = Critic(observe_dim, action_dim) 58 | 59 | hddpg = HDDPG( 60 | actor, actor_t, 
critic, critic_t, t.optim.Adam, nn.MSELoss(reduction="sum") 61 | ) 62 | 63 | episode, step, reward_fulfilled = 0, 0, 0 64 | smoothed_total_reward = 0 65 | 66 | while episode < max_episodes: 67 | episode += 1 68 | total_reward = 0 69 | terminal = False 70 | step = 0 71 | state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim) 72 | tmp_observations = [] 73 | 74 | while not terminal and step <= max_steps: 75 | step += 1 76 | with t.no_grad(): 77 | old_state = state 78 | # agent model inference 79 | action = hddpg.act_with_noise( 80 | {"state": old_state}, noise_param=noise_param, mode=noise_mode 81 | ) 82 | state, reward, terminal, _ = env.step(action.numpy()) 83 | state = t.tensor(state, dtype=t.float32).view(1, observe_dim) 84 | total_reward += reward[0] 85 | 86 | tmp_observations.append( 87 | { 88 | "state": {"state": old_state}, 89 | "action": {"action": action}, 90 | "next_state": {"state": state}, 91 | "reward": reward[0], 92 | "terminal": terminal or step == max_steps, 93 | } 94 | ) 95 | 96 | hddpg.store_episode(tmp_observations) 97 | # update, update more if episode is longer, else less 98 | if episode > 100: 99 | for _ in range(step): 100 | hddpg.update() 101 | 102 | # show reward 103 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 104 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 105 | 106 | if smoothed_total_reward > solved_reward: 107 | reward_fulfilled += 1 108 | if reward_fulfilled >= solved_repeat: 109 | logger.info("Environment solved!") 110 | exit(0) 111 | else: 112 | reward_fulfilled = 0 113 | -------------------------------------------------------------------------------- /examples/framework_examples/ddpg_per.py: -------------------------------------------------------------------------------- 1 | from machin.frame.algorithms import DDPGPer 2 | from machin.utils.logging import default_logger as logger 3 | import torch as t 4 | import torch.nn as nn 5 | import gym 6 | 7 | # configurations 8 | env = gym.make("Pendulum-v0") 9 | observe_dim = 3 10 | action_dim = 1 11 | action_range = 2 12 | max_episodes = 1000 13 | max_steps = 200 14 | noise_param = (0, 0.2) 15 | noise_mode = "normal" 16 | solved_reward = -150 17 | solved_repeat = 5 18 | 19 | 20 | # model definition 21 | class Actor(nn.Module): 22 | def __init__(self, state_dim, action_dim, action_range): 23 | super().__init__() 24 | 25 | self.fc1 = nn.Linear(state_dim, 16) 26 | self.fc2 = nn.Linear(16, 16) 27 | self.fc3 = nn.Linear(16, action_dim) 28 | self.action_range = action_range 29 | 30 | def forward(self, state): 31 | a = t.relu(self.fc1(state)) 32 | a = t.relu(self.fc2(a)) 33 | a = t.tanh(self.fc3(a)) * self.action_range 34 | return a 35 | 36 | 37 | class Critic(nn.Module): 38 | def __init__(self, state_dim, action_dim): 39 | super().__init__() 40 | 41 | self.fc1 = nn.Linear(state_dim + action_dim, 16) 42 | self.fc2 = nn.Linear(16, 16) 43 | self.fc3 = nn.Linear(16, 1) 44 | 45 | def forward(self, state, action): 46 | state_action = t.cat([state, action], 1) 47 | q = t.relu(self.fc1(state_action)) 48 | q = t.relu(self.fc2(q)) 49 | q = self.fc3(q) 50 | return q 51 | 52 | 53 | if __name__ == "__main__": 54 | actor = Actor(observe_dim, action_dim, action_range) 55 | actor_t = Actor(observe_dim, action_dim, action_range) 56 | critic = Critic(observe_dim, action_dim) 57 | critic_t = Critic(observe_dim, action_dim) 58 | 59 | ddpg_per = DDPGPer( 60 | actor, actor_t, critic, critic_t, t.optim.Adam, nn.MSELoss(reduction="sum") 61 | ) 62 | 63 | episode, step, 
reward_fulfilled = 0, 0, 0 64 | smoothed_total_reward = 0 65 | 66 | while episode < max_episodes: 67 | episode += 1 68 | total_reward = 0 69 | terminal = False 70 | step = 0 71 | state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim) 72 | tmp_observations = [] 73 | 74 | while not terminal and step <= max_steps: 75 | step += 1 76 | with t.no_grad(): 77 | old_state = state 78 | # agent model inference 79 | action = ddpg_per.act_with_noise( 80 | {"state": old_state}, noise_param=noise_param, mode=noise_mode 81 | ) 82 | state, reward, terminal, _ = env.step(action.numpy()) 83 | state = t.tensor(state, dtype=t.float32).view(1, observe_dim) 84 | total_reward += reward[0] 85 | 86 | tmp_observations.append( 87 | { 88 | "state": {"state": old_state}, 89 | "action": {"action": action}, 90 | "next_state": {"state": state}, 91 | "reward": reward[0], 92 | "terminal": terminal or step == max_steps, 93 | } 94 | ) 95 | 96 | ddpg_per.store_episode(tmp_observations) 97 | # update, update more if episode is longer, else less 98 | if episode > 100: 99 | for _ in range(step): 100 | ddpg_per.update() 101 | 102 | # show reward 103 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 104 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 105 | 106 | if smoothed_total_reward > solved_reward: 107 | reward_fulfilled += 1 108 | if reward_fulfilled >= solved_repeat: 109 | logger.info("Environment solved!") 110 | exit(0) 111 | else: 112 | reward_fulfilled = 0 113 | -------------------------------------------------------------------------------- /examples/tutorials/recurrent_networks/drqn.py: -------------------------------------------------------------------------------- 1 | from machin.env.utils.openai_gym import disable_view_window 2 | from machin.frame.algorithms import DQNPer 3 | from machin.utils.logging import default_logger as logger 4 | 5 | import gym 6 | import torch as t 7 | import torch.nn as nn 8 | 9 | from util import convert 10 | from history import History 11 | 12 | # configurations 13 | env = gym.make("Frostbite-ram-v0") 14 | action_num = env.action_space.n 15 | max_episodes = 20000 16 | history_depth = 4 17 | 18 | # disable view window in rendering 19 | disable_view_window() 20 | 21 | 22 | # Q network model definition 23 | # for atari games 24 | class RecurrentQNet(nn.Module): 25 | def __init__(self, action_num): 26 | super().__init__() 27 | self.gru = nn.GRU(128, 256, batch_first=True) 28 | self.fc1 = nn.Linear(256, 256) 29 | self.fc2 = nn.Linear(256, action_num) 30 | 31 | def forward(self, mem=None, hidden=None, history_mem=None): 32 | if mem is not None: 33 | # in sampling 34 | a, h = self.gru(mem.unsqueeze(1), hidden) 35 | return self.fc2(t.relu(self.fc1(t.relu(a.flatten(start_dim=1))))), h 36 | else: 37 | # in updating 38 | batch_size = history_mem.shape[0] 39 | seq_length = history_mem.shape[1] 40 | hidden = t.zeros([1, batch_size, 256], device=history_mem.device) 41 | for i in range(seq_length): 42 | _, hidden = self.gru(history_mem[:, i].unsqueeze(1), hidden) 43 | # a[:, -1] = h 44 | return self.fc2( 45 | t.relu(self.fc1(t.relu(hidden.transpose(0, 1).flatten(start_dim=1)))) 46 | ) 47 | 48 | 49 | if __name__ == "__main__": 50 | r_q_net = RecurrentQNet(action_num).to("cuda:0") 51 | r_q_net_t = RecurrentQNet(action_num).to("cuda:0") 52 | 53 | drqn = DQNPer( 54 | r_q_net, 55 | r_q_net_t, 56 | t.optim.Adam, 57 | nn.MSELoss(reduction="sum"), 58 | learning_rate=5e-4, 59 | ) 60 | 61 | episode, step, reward_fulfilled = 0, 0, 0 62 | 
smoothed_total_reward = 0 63 | 64 | while episode < max_episodes: 65 | episode += 1 66 | total_reward = 0 67 | terminal = False 68 | step = 0 69 | hidden = t.zeros([1, 1, 256]) 70 | state = convert(env.reset()) 71 | history = History(history_depth, (1, 128)) 72 | 73 | while not terminal: 74 | step += 1 75 | with t.no_grad(): 76 | old_state = state 77 | history.append(state) 78 | # agent model inference 79 | action, hidden = drqn.act_discrete_with_noise( 80 | {"mem": old_state, "hidden": hidden} 81 | ) 82 | 83 | # info is {"ale.lives": self.ale.lives()}, not used here 84 | state, reward, terminal, _ = env.step(action.item()) 85 | state = convert(state) 86 | total_reward += reward 87 | 88 | # history mem includes current state 89 | old_history = history.get() 90 | new_history = history.append(state).get() 91 | drqn.store_transition( 92 | { 93 | "state": {"history_mem": old_history}, 94 | "action": {"action": action}, 95 | "next_state": {"history_mem": new_history}, 96 | "reward": reward, 97 | "terminal": terminal, 98 | } 99 | ) 100 | 101 | # update, update more if episode is longer, else less 102 | if episode > 20: 103 | for _ in range(step // 10): 104 | drqn.update() 105 | 106 | # show reward 107 | smoothed_total_reward = smoothed_total_reward * 0.9 + total_reward * 0.1 108 | 109 | logger.info(f"Episode {episode} total reward={smoothed_total_reward:.2f}") 110 | -------------------------------------------------------------------------------- /machin/parallel/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import weakref 3 | import itertools 4 | from threading import Lock 5 | from machin.utils.logging import default_logger 6 | 7 | _finalizer_lock = Lock() 8 | _finalizer_registry = {} 9 | _finalizer_counter = itertools.count() 10 | 11 | 12 | class Finalize(object): 13 | """ 14 | Class which supports object finalization using weakrefs. 15 | Adapted from python 3.7.3 multiprocessing.util. 
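Example (an illustrative sketch added for clarity; ``conn`` and
``open_connection`` are hypothetical stand-ins for any resource that
needs cleanup)::

    conn = open_connection()
    # run conn.close() once ``conn`` is garbage collected, or when
    # registered finalizers are flushed at interpreter exit
    finalizer = Finalize(conn, conn.close, exitpriority=10)
    assert finalizer.still_active()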
16 | """ 17 | 18 | def __init__(self, obj, callback, args=(), kwargs=None, exitpriority=None): 19 | if (exitpriority is not None) and not isinstance(exitpriority, int): 20 | raise TypeError( 21 | "Exitpriority ({0!r}) must be None or int, not {1!s}".format( 22 | exitpriority, type(exitpriority) 23 | ) 24 | ) 25 | 26 | if obj is not None: 27 | # weakref is just used to track the object in __repr__ 28 | self._weakref = weakref.ref(obj, self) 29 | elif exitpriority is None: 30 | raise ValueError("Without object, exitpriority cannot be None") 31 | 32 | self._callback = callback 33 | self._args = args 34 | self._kwargs = kwargs or {} 35 | with _finalizer_lock: 36 | self._key = (exitpriority, next(_finalizer_counter)) 37 | self._pid = os.getpid() 38 | 39 | _finalizer_registry[self._key] = self 40 | 41 | def __call__( 42 | self, 43 | # Need to bind these locally because the globals could have 44 | # been cleared at shutdown 45 | finalizer_registry=None, 46 | debug_logger=None, 47 | getpid=None, 48 | ): 49 | """ 50 | Run the callback unless it has already been called or cancelled 51 | """ 52 | finalizer_registry = finalizer_registry or _finalizer_registry 53 | debug_logger = debug_logger or default_logger 54 | getpid = getpid or os.getpid 55 | 56 | try: 57 | del finalizer_registry[self._key] 58 | except KeyError: 59 | debug_logger.debug("finalizer no longer registered") 60 | else: 61 | if self._pid != getpid(): 62 | debug_logger.debug("finalizer ignored because different process") 63 | res = None 64 | else: 65 | debug_logger.debug( 66 | f"finalizer calling {self._callback} " 67 | f"with args {self._args} " 68 | f"and kwargs {self._kwargs}" 69 | ) 70 | res = self._callback(*self._args, **self._kwargs) 71 | self._weakref = ( 72 | self._callback 73 | ) = self._args = self._kwargs = self._key = None 74 | return res 75 | 76 | def cancel(self): 77 | """ 78 | Cancel finalization of the object 79 | """ 80 | try: 81 | del _finalizer_registry[self._key] 82 | except KeyError: 83 | pass 84 | else: 85 | self._weakref = ( 86 | self._callback 87 | ) = self._args = self._kwargs = self._key = None 88 | 89 | def still_active(self): 90 | """ 91 | Return whether this finalizer is still waiting to invoke callback 92 | """ 93 | return self._key in _finalizer_registry 94 | 95 | def __repr__(self): 96 | try: 97 | obj = self._weakref() 98 | except (AttributeError, TypeError): 99 | obj = None 100 | 101 | if obj is None: 102 | return f"<{self.__class__.__name__} object, dead>" 103 | 104 | x = ( 105 | f"<{self.__class__.__name__} object, " 106 | f"callback={getattr(self._callback, '__name__', self._callback)}" 107 | ) 108 | if self._args: 109 | x += ", args=" + str(self._args) 110 | if self._kwargs: 111 | x += ", kwargs=" + str(self._kwargs) 112 | if self._key[0] is not None: 113 | x += ", exitprority=" + str(self._key[0]) 114 | return x + ">" 115 | -------------------------------------------------------------------------------- /docs/theme/static/pygments.css: -------------------------------------------------------------------------------- 1 | .highlight { 2 | background-color: #eeeeec; 3 | padding: 0.5em 2px; 4 | } 5 | .highlight .hll { background-color: #ffffcc} 6 | .highlight .c { color: #999988; font-style: italic } /* Comment */ 7 | .highlight .err { color: #9c2c21; background-color: #e3d2d2 } /* Error */ 8 | .highlight .o { color: #aaa; font-weight: bold } /* Operator */ 9 | .highlight .cm { color: #ff7777; font-style: italic } /* Comment.Multiline */ 10 | .highlight .cp { color: #999999; font-weight: bold; 
font-style: italic } /* Comment.Preproc */ 11 | .highlight .c1 { color: #ff7777; font-style: italic } /* Comment.Single */ 12 | .highlight .cs { color: #999999; font-weight: bold; font-style: italic } /* Comment.Special */ 13 | .highlight .gd { color: #aaa; background-color: #ffdddd } /* Generic.Deleted */ 14 | .highlight .ge { color: #aaa; font-style: italic } /* Generic.Emph */ 15 | .highlight .gr { color: #9c2c21 } /* Generic.Error */ 16 | .highlight .gh { color: #999999 } /* Generic.Heading */ 17 | .highlight .gi { color: #aaa; background-color: #ddffdd } /* Generic.Inserted */ 18 | .highlight .go { color: #888888 } /* Generic.Output */ 19 | .highlight .gp { color: #555555 } /* Generic.Prompt */ 20 | .highlight .gs { font-weight: bold } /* Generic.Strong */ 21 | .highlight .gu { color: #aaaaaa } /* Generic.Subheading */ 22 | .highlight .gt { color: #9c2c21 } /* Generic.Traceback */ 23 | .highlight .k { color: #f8ac3f; font-weight: bold } /* Keyword */ 24 | .highlight .kc { color: #f8ac3f; font-weight: bold } /* Keyword.Constant */ 25 | .highlight .kd { color: #f8ac3f; font-weight: bold } /* Keyword.Declaration */ 26 | .highlight .kn { color: #f8ac3f; font-weight: bold } /* Keyword.Namespace */ 27 | .highlight .kp { color: #f8ac3f; font-weight: bold } /* Keyword.Pseudo */ 28 | .highlight .kr { color: #f8ac3f; font-weight: bold } /* Keyword.Reserved */ 29 | .highlight .kt { color: #f8ac3f; font-weight: bold } /* Keyword.Type */ 30 | .highlight .m { color: #077077 } /* Literal.Number */ 31 | .highlight .s { color: #077077 } /* Literal.String */ 32 | .highlight .na { color: #6cca51 } /* Name.Attribute */ 33 | .highlight .nb { color: #6cca51 } /* Name.Builtin */ 34 | .highlight .nc { color: #ee6622; font-weight: bold } /* Name.Class */ 35 | .highlight .no { color: #ee6622 } /* Name.Constant */ 36 | .highlight .nd { color: #ee6622; font-weight: bold } /* Name.Decorator */ 37 | .highlight .ni { color: #ee6622 } /* Name.Entity */ 38 | .highlight .ne { color: #ee6622; font-weight: bold } /* Name.Exception */ 39 | .highlight .nf { color: #ee6622; font-weight: bold } /* Name.Function */ 40 | .highlight .nl { color: #ee6622; font-weight: bold } /* Name.Label */ 41 | .highlight .nn { color: #777 } /* Name.Namespace */ 42 | .highlight .nt { color: #f8ac3f } /* Name.Tag */ 43 | .highlight .nv { color: #008080 } /* Name.Variable */ 44 | .highlight .ow { color: #aaa; font-weight: bold } /* Operator.Word */ 45 | .highlight .w { color: #aaa } /* Text.Whitespace */ 46 | .highlight .mf { color: #009999 } /* Literal.Number.Float */ 47 | .highlight .mh { color: #009999 } /* Literal.Number.Hex */ 48 | .highlight .mi { color: #009999 } /* Literal.Number.Integer */ 49 | .highlight .mo { color: #009999 } /* Literal.Number.Oct */ 50 | .highlight .sb { color: #9c2c21 } /* Literal.String.Backtick */ 51 | .highlight .sc { color: #9c2c21 } /* Literal.String.Char */ 52 | .highlight .sd { color: #9c2c21 } /* Literal.String.Doc */ 53 | .highlight .s2 { color: #9c2c21 } /* Literal.String.Double */ 54 | .highlight .se { color: #9c2c21 } /* Literal.String.Escape */ 55 | .highlight .sh { color: #9c2c21 } /* Literal.String.Heredoc */ 56 | .highlight .si { color: #9c2c21 } /* Literal.String.Interpol */ 57 | .highlight .sx { color: #9c2c21 } /* Literal.String.Other */ 58 | .highlight .sr { color: #6cca51 } /* Literal.String.Regex */ 59 | .highlight .s1 { color: #9c2c21 } /* Literal.String.Single */ 60 | .highlight .ss { color: #ffffff } /* Literal.String.Symbol */ 61 | .highlight .bp { color: #aaa } /* Name.Builtin.Pseudo */ 
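/* illustrative note, not part of the original stylesheet: each rule maps one
   Pygments token class to a style, e.g. ".highlight .nf" above renders
   function names in bold #ee6622 */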
62 | .highlight .vc { color: #077077 } /* Name.Variable.Class */ 63 | .highlight .vg { color: #077077 } /* Name.Variable.Global */ 64 | .highlight .vi { color: #077077 } /* Name.Variable.Instance */ 65 | .highlight .il { color: #077077 } /* Literal.Number.Integer.Long */ -------------------------------------------------------------------------------- /docs/source/static/icon_title.svg: -------------------------------------------------------------------------------- [SVG image (image/svg+xml): the Machin vector logo with the visible text "M A C H I N"; raw markup not recoverable from this dump] -------------------------------------------------------------------------------- /machin/utils/conf.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import copy 3 | import json 4 | import argparse 5 | 6 | from .helper_classes import Object 7 | 8 | 9 | class Config(Object): 10 | """ 11 | A simple replacement for python dict. 12 | """ 13 | 14 | def __init__(self, **configs): 15 | super().__init__(configs) 16 | 17 | def get(self, key, default=None): 18 | if key in self: 19 | return self[key] 20 | return default 21 | 22 | def __iter__(self): 23 | for key in self.__dict__: 24 | if not key.startswith("__"): 25 | yield key 26 | 27 | def __contains__(self, key): 28 | assert not key.startswith("__") 29 | return hasattr(self, key) 30 | 31 | def __getitem__(self, key): 32 | assert not key.startswith("__") 33 | return getattr(self, key) 34 | 35 | def __setitem__(self, key, value): 36 | assert not key.startswith("__") 37 | setattr(self, key, value) 38 | 39 | 40 | def load_config_cmd(merge_conf: Config = None) -> Config: 41 | """ 42 | Get configs from the commandline by using "--conf". 43 | 44 | ``--conf a=b`` will set ``.a = b`` 45 | 46 | Example:: 47 | 48 | python3 test.py --conf device=\"cuda:1\" 49 | --conf some_dict={\"some_key\":1} 50 | 51 | Example:: 52 | 53 | from machin.utils.conf import Config 54 | from machin.utils.save_env import SaveEnv 55 | 56 | # set some config attributes 57 | c = Config( 58 | model_save_int = 100, 59 | root_dir = "some_directory", 60 | restart_from_trial = "2020_05_09_15_00_31" 61 | ) 62 | 63 | load_config_cmd(c) 64 | 65 | # restart_from_trial specifies the trial name in your root 66 | # directory. 67 | # If it is set, then the SaveEnv constructor will 68 | # load arguments from that trial record and overwrite them. 69 | # If not, then the SaveEnv constructor will save configurations 70 | # as: ``//config/config.json`` 71 | 72 | save_env = SaveEnv(c) 73 | 74 | Args: 75 | merge_conf: Config to merge. 76 | """ 77 | parser = argparse.ArgumentParser(description=__doc__) 78 | parser.add_argument("--conf", action="append") 79 | args = parser.parse_args() 80 | 81 | config_dict = {} 82 | if args.conf is not None: 83 | for env_str in args.conf: 84 | name, value = env_str.split("=") 85 | value = eval(value) 86 | config_dict[name] = value 87 | 88 | return merge_config((Config() if merge_conf is None else merge_conf), config_dict) 89 | 90 | 91 | def load_config_file(json_file: str, merge_conf: Config = None) -> Config: 92 | """ 93 | Get configs from a json file. 94 | 95 | Args: 96 | json_file: Path to the json config file. 97 | merge_conf: Config to merge.
98 | 99 | Returns: 100 | The loaded configuration. 101 | """ 102 | # parse the configurations from the config json file provided 103 | with open(json_file) as config_file: 104 | config_dict = json.load(config_file) 105 | 106 | return merge_config((Config() if merge_conf is None else merge_conf), config_dict) 107 | 108 | 109 | def save_config(conf: Config, json_file: str): 110 | """ 111 | Dump the config object to a json file. 112 | """ 113 | with open(json_file, "w") as config_file: 114 | json.dump(conf.data, config_file, sort_keys=True, indent=4) 115 | 116 | 117 | def merge_config(conf: Config, merge: Union[dict, Config]) -> Config: 118 | """ 119 | Merge a config object with a dictionary or another Config object; 120 | keys already present in ``conf`` 121 | are overwritten by the values in ``merge``. 122 | """ 123 | new_conf = copy.deepcopy(conf) 124 | if isinstance(merge, dict): 125 | for k, v in merge.items(): 126 | new_conf[k] = v 127 | else: 128 | for k, v in merge.data.items(): 129 | if k not in new_conf.const_attrs: 130 | new_conf[k] = v 131 | return new_conf 132 | --------------------------------------------------------------------------------
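For reference, a minimal usage sketch of the ``machin.utils.conf`` helpers above (illustrative only; the key names ``device``, ``batch_size``, ``lr`` and ``momentum`` are made up):

    from machin.utils.conf import Config, merge_config

    base = Config(device="cpu", batch_size=32)

    # merge_config deep-copies ``base``; keys in the second argument win
    merged = merge_config(base, {"device": "cuda:0", "lr": 1e-3})

    assert merged["device"] == "cuda:0"        # overwritten by the merge dict
    assert merged["batch_size"] == 32          # untouched keys survive
    assert merged.get("momentum", 0.9) == 0.9  # absent keys fall back to the default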