├── ml-agents ├── tests │ ├── __init__.py │ ├── envs │ │ ├── __init__.py │ │ ├── test_rpc_communicator.py │ │ └── test_envs.py │ ├── trainers │ │ ├── __init__.py │ │ ├── test.demo │ │ ├── test_demo_loader.py │ │ ├── test_buffer.py │ │ ├── test_curriculum.py │ │ ├── test_meta_curriculum.py │ │ └── test_bc.py │ └── mock_communicator.py ├── mlagents │ ├── __init__.py │ ├── envs │ │ ├── utilities.py │ │ ├── __init__.py │ │ ├── communicator_objects │ │ │ ├── __init__.py │ │ │ ├── unity_to_external_pb2_grpc.py │ │ │ ├── command_proto_pb2.py │ │ │ ├── unity_to_external_pb2.py │ │ │ ├── space_type_proto_pb2.py │ │ │ ├── unity_rl_initialization_input_pb2.py │ │ │ ├── header_pb2.py │ │ │ ├── resolution_proto_pb2.py │ │ │ ├── agent_action_proto_pb2.py │ │ │ ├── unity_input_pb2.py │ │ │ ├── unity_output_pb2.py │ │ │ ├── demonstration_meta_proto_pb2.py │ │ │ ├── unity_message_pb2.py │ │ │ ├── engine_configuration_proto_pb2.py │ │ │ ├── environment_parameters_proto_pb2.py │ │ │ └── unity_rl_initialization_output_pb2.py │ │ ├── communicator.py │ │ ├── exception.py │ │ ├── socket_communicator.py │ │ └── rpc_communicator.py │ └── trainers │ │ ├── ppo │ │ └── __init__.py │ │ ├── bc │ │ ├── __init__.py │ │ ├── offline_trainer.py │ │ ├── models.py │ │ ├── policy.py │ │ └── online_trainer.py │ │ ├── __init__.py │ │ ├── exception.py │ │ ├── benchmark.py │ │ ├── demo_loader.py │ │ ├── curriculum.py │ │ ├── learn.py │ │ └── meta_curriculum.py ├── README.md └── setup.py ├── gym-unity ├── gym_unity │ ├── __init__.py │ └── envs │ │ └── __init__.py ├── setup.py ├── tests │ └── test_gym.py └── README.md ├── docs ├── images │ ├── 1.PNG │ ├── 2.PNG │ ├── 3.PNG │ ├── 4.PNG │ ├── 5.PNG │ ├── 6.PNG │ ├── LOD.jpg │ ├── time.jpg │ ├── width.jpg │ ├── width15.PNG │ ├── width3.PNG │ ├── birdview.jpg │ ├── car-camera.png │ ├── overview.PNG │ ├── thumbnail.jpg │ ├── weather1.jpg │ ├── weather2.jpg │ ├── weather3.jpg │ ├── weather4.jpg │ ├── architecture.png │ ├── tensorboard1.PNG │ ├── tensorboard2.PNG │ ├── image-overview.jpg │ └── forward-backward.jpg ├── FAQ.md ├── Requirements.md ├── Installation-and-Setup.md ├── Readme.md ├── Benchmark.md ├── ML-ImageSynthesis.md ├── Training-Process.md ├── AgentInfos-Obs-Action-Reward.md ├── Environment-Details.md ├── Training-and-Environment-Configuration.md ├── Pretrain-Model-Details.md └── Setup-Configuration-Files.md ├── unity-volume └── .gitignore ├── .gitattributes ├── protobuf-definitions ├── proto │ └── mlagents │ │ └── envs │ │ └── communicator_objects │ │ ├── header.proto │ │ ├── unity_rl_initialization_input.proto │ │ ├── command_proto.proto │ │ ├── environment_parameters_proto.proto │ │ ├── resolution_proto.proto │ │ ├── space_type_proto.proto │ │ ├── agent_action_proto.proto │ │ ├── unity_to_external.proto │ │ ├── demonstration_meta_proto.proto │ │ ├── engine_configuration_proto.proto │ │ ├── unity_rl_output.proto │ │ ├── unity_input.proto │ │ ├── unity_message.proto │ │ ├── unity_output.proto │ │ ├── agent_info_proto.proto │ │ ├── unity_rl_initialization_output.proto │ │ ├── unity_rl_input.proto │ │ └── brain_parameters_proto.proto ├── README.md ├── make.bat └── make_for_win.bat ├── config ├── curricula │ ├── wall-jump │ │ ├── SmallWallJumpLearning.json │ │ └── BigWallJumpLearning.json │ ├── test │ │ └── TestBrain.json │ └── autobench │ │ └── AutoBenchBrain.json ├── offline_bc_config.yaml ├── online_bc_config.yaml └── trainer_config.yaml ├── .gitignore ├── learn_gym.py ├── README.md ├── learn_rl.py └── learn_ml.py /ml-agents/tests/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ml-agents/mlagents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ml-agents/tests/envs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/utilities.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ml-agents/tests/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gym-unity/gym_unity/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | -------------------------------------------------------------------------------- /docs/images/1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/1.PNG -------------------------------------------------------------------------------- /docs/images/2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/2.PNG -------------------------------------------------------------------------------- /docs/images/3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/3.PNG -------------------------------------------------------------------------------- /docs/images/4.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/4.PNG -------------------------------------------------------------------------------- /docs/images/5.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/5.PNG -------------------------------------------------------------------------------- /docs/images/6.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/6.PNG -------------------------------------------------------------------------------- /docs/images/LOD.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/LOD.jpg -------------------------------------------------------------------------------- /docs/images/time.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/time.jpg -------------------------------------------------------------------------------- /docs/images/width.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/width.jpg 
-------------------------------------------------------------------------------- /docs/images/width15.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/width15.PNG -------------------------------------------------------------------------------- /docs/images/width3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/width3.PNG -------------------------------------------------------------------------------- /docs/images/birdview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/birdview.jpg -------------------------------------------------------------------------------- /docs/images/car-camera.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/car-camera.png -------------------------------------------------------------------------------- /docs/images/overview.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/overview.PNG -------------------------------------------------------------------------------- /docs/images/thumbnail.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/thumbnail.jpg -------------------------------------------------------------------------------- /docs/images/weather1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/weather1.jpg -------------------------------------------------------------------------------- /docs/images/weather2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/weather2.jpg -------------------------------------------------------------------------------- /docs/images/weather3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/weather3.jpg -------------------------------------------------------------------------------- /docs/images/weather4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/weather4.jpg -------------------------------------------------------------------------------- /gym-unity/gym_unity/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gym_unity.envs.unity_env import UnityEnv, UnityGymException 2 | -------------------------------------------------------------------------------- /docs/images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/architecture.png -------------------------------------------------------------------------------- /docs/images/tensorboard1.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/tensorboard1.PNG -------------------------------------------------------------------------------- /docs/images/tensorboard2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/tensorboard2.PNG -------------------------------------------------------------------------------- /docs/images/image-overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/image-overview.jpg -------------------------------------------------------------------------------- /docs/images/forward-backward.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/docs/images/forward-backward.jpg -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import * 2 | from .brain import * 3 | from .exception import * 4 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .trainer import * 3 | from .policy import * 4 | -------------------------------------------------------------------------------- /ml-agents/tests/trainers/test.demo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stevenh-tw/AutoBench/HEAD/ml-agents/tests/trainers/test.demo -------------------------------------------------------------------------------- /unity-volume/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | # Ignore everything in this directory except for .gitignore. This directory is for illustrative purposes 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/bc/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .online_trainer import * 3 | from .offline_trainer import * 4 | from .policy import * 5 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.asset binary 2 | *.unity binary 3 | *.prefab binary 4 | *.meta binary 5 | */CommunicatorObjects/* binary 6 | */communicator_objects/* binary 7 | *.md text 8 | -------------------------------------------------------------------------------- /docs/FAQ.md: -------------------------------------------------------------------------------- 1 | # FAQ 2 | 3 | ### Why isn't the Unity environment source code provided? 4 | 5 | Because this project uses many assets from the Unity Asset Store, according to the EULA we can only provide the built instance of the environment.
6 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/header.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 4 | package communicator_objects; 5 | 6 | message Header { 7 | int32 status = 1; 8 | string message = 2; 9 | } 10 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_initialization_input.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 4 | package communicator_objects; 5 | 6 | 7 | message UnityRLInitializationInput { 8 | int32 seed = 1; 9 | } 10 | -------------------------------------------------------------------------------- /config/curricula/wall-jump/SmallWallJumpLearning.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "progress", 3 | "thresholds" : [0.1, 0.3, 0.5], 4 | "min_lesson_length": 100, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "small_wall_height" : [1.5, 2.0, 2.5, 4.0] 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/command_proto.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 4 | package communicator_objects; 5 | 6 | enum CommandProto { 7 | STEP = 0; 8 | RESET = 1; 9 | QUIT = 2; 10 | } 11 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/environment_parameters_proto.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 4 | package communicator_objects; 5 | 6 | message EnvironmentParametersProto { 7 | map<string, float> float_parameters = 1; 8 | } 9 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/resolution_proto.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 4 | package communicator_objects; 5 | 6 | message ResolutionProto { 7 | int32 width = 1; 8 | int32 height = 2; 9 | bool gray_scale = 3; 10 | } 11 | 12 | -------------------------------------------------------------------------------- /config/curricula/test/TestBrain.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "reward", 3 | "thresholds" : [10, 20, 50], 4 | "min_lesson_length" : 100, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "param1" : [0.7, 0.5, 0.3, 0.1], 9 | "param2" : [100, 50, 20, 15], 10 | "param3" : [0.2, 0.3, 0.7, 0.9] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /config/curricula/wall-jump/BigWallJumpLearning.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure" : "progress", 3 | "thresholds" : [0.1, 0.3, 0.5], 4 | "min_lesson_length":
100, 5 | "signal_smoothing" : true, 6 | "parameters" : 7 | { 8 | "big_wall_min_height" : [0.0, 4.0, 6.0, 8.0], 9 | "big_wall_max_height" : [4.0, 7.0, 8.0, 8.0] 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /docs/Requirements.md: -------------------------------------------------------------------------------- 1 | # Requirements 2 | The Unity executable instance only runs on Windows. 3 | 4 | ### Python Requirements 5 | ``` 6 | python>=3.6,<3.7 7 | 8 | tensorflow>=1.7,<1.8 9 | Pillow>=4.2.1 10 | matplotlib 11 | numpy>=1.13.3,<=1.14.5 12 | jupyter 13 | pytest>=3.2.2,<4.0.0 14 | docopt 15 | pyyaml 16 | protobuf>=3.6,<3.7 17 | grpcio>=1.11.0,<1.12.0 18 | ``` 19 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/space_type_proto.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "mlagents/envs/communicator_objects/resolution_proto.proto"; 4 | 5 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 6 | package communicator_objects; 7 | 8 | enum SpaceTypeProto { 9 | discrete = 0; 10 | continuous = 1; 11 | } 12 | 13 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_action_proto.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 4 | package communicator_objects; 5 | 6 | message AgentActionProto { 7 | repeated float vector_actions = 1; 8 | string text_actions = 2; 9 | repeated float memories = 3; 10 | float value = 4; 11 | } 12 | -------------------------------------------------------------------------------- /docs/Installation-and-Setup.md: -------------------------------------------------------------------------------- 1 | # Installation and Setup 2 | * Clone this repository 3 | * Download [Unity Environment Executable](https://goo.gl/sQ2avA) 4 | * Unzip the file 5 | * Copy and paste it into the ```AutoBench``` folder 6 | ``` 7 | # In the AutoBench folder 8 | cd ml-agents 9 | python setup.py install 10 | cd .. 11 | python learn_rl.py 12 | ``` 13 | * Notice: this environment is only compatible with Windows.
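
To verify the setup, the snippet below opens the environment through the low-level ```mlagents.envs``` API and resets it once. This is a minimal sketch: the executable name ```AutoBench``` is an assumption; substitute the actual name of your unzipped build.
```
# Minimal smoke test (hypothetical executable name "AutoBench";
# adjust to match your unzipped build).
from mlagents.envs import UnityEnvironment

env = UnityEnvironment(file_name='AutoBench', worker_id=0)
try:
    # reset() returns a dict mapping brain names to BrainInfo objects
    info = env.reset(train_mode=True)
    print('Connected brains:', list(info.keys()))
finally:
    env.close()
```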
14 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_to_external.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "mlagents/envs/communicator_objects/unity_message.proto"; 4 | 5 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 6 | package communicator_objects; 7 | 8 | service UnityToExternal { 9 | // Sends the academy parameters 10 | rpc Exchange(UnityMessage) returns (UnityMessage) {} 11 | } 12 | 13 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/demonstration_meta_proto.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 4 | package communicator_objects; 5 | 6 | message DemonstrationMetaProto { 7 | int32 api_version = 1; 8 | string demonstration_name = 2; 9 | int32 number_steps = 3; 10 | int32 number_episodes = 4; 11 | float mean_reward = 5; 12 | } 13 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/engine_configuration_proto.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 4 | package communicator_objects; 5 | 6 | message EngineConfigurationProto { 7 | int32 width = 1; 8 | int32 height = 2; 9 | int32 quality_level = 3; 10 | float time_scale = 4; 11 | int32 target_frame_rate = 5; 12 | bool show_monitor = 6; 13 | } 14 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .buffer import * 2 | from .curriculum import * 3 | from .meta_curriculum import * 4 | from .models import * 5 | from .trainer_controller import * 6 | from .bc.models import * 7 | from .bc.offline_trainer import * 8 | from .bc.online_trainer import * 9 | from .bc.policy import * 10 | from .ppo.models import * 11 | from .ppo.trainer import * 12 | from .ppo.policy import * 13 | from .exception import * 14 | from .policy import * 15 | from .demo_loader import * 16 | -------------------------------------------------------------------------------- /gym-unity/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup(name='gym_unity', 6 | version='0.2.0', 7 | description='Unity Machine Learning Agents Gym Interface', 8 | license='Apache License 2.0', 9 | author='Unity Technologies', 10 | author_email='ML-Agents@unity3d.com', 11 | url='https://github.com/Unity-Technologies/ml-agents', 12 | packages=find_packages(), 13 | install_requires=['gym', 'mlagents'] 14 | ) 15 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_output.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "mlagents/envs/communicator_objects/agent_info_proto.proto"; 4 | 5 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 6 | package communicator_objects; 7 | 8 | message 
UnityRLOutput { 9 | message ListAgentInfoProto { 10 | repeated AgentInfoProto value = 1; 11 | } 12 | bool global_done = 1; 13 | map<string, ListAgentInfoProto> agentInfos = 2; 14 | } 15 | 16 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/exception.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains exceptions for the trainers package. 3 | """ 4 | 5 | class TrainerError(Exception): 6 | """ 7 | Any error related to the trainers in the ML-Agents Toolkit. 8 | """ 9 | pass 10 | 11 | class CurriculumError(TrainerError): 12 | """ 13 | Any error related to training with a curriculum. 14 | """ 15 | pass 16 | 17 | class MetaCurriculumError(TrainerError): 18 | """ 19 | Any error related to the configuration of a metacurriculum. 20 | """ 21 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_input.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "mlagents/envs/communicator_objects/unity_rl_input.proto"; 4 | import "mlagents/envs/communicator_objects/unity_rl_initialization_input.proto"; 5 | 6 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 7 | package communicator_objects; 8 | 9 | message UnityInput { 10 | UnityRLInput rl_input = 1; 11 | UnityRLInitializationInput rl_initialization_input = 2; 12 | 13 | //More messages can be added here 14 | } 15 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_message.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "mlagents/envs/communicator_objects/unity_output.proto"; 4 | import "mlagents/envs/communicator_objects/unity_input.proto"; 5 | import "mlagents/envs/communicator_objects/header.proto"; 6 | 7 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 8 | package communicator_objects; 9 | 10 | message UnityMessage { 11 | Header header = 1; 12 | UnityOutput unity_output = 2; 13 | UnityInput unity_input = 3; 14 | } 15 | 16 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_output.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "mlagents/envs/communicator_objects/unity_rl_output.proto"; 4 | import "mlagents/envs/communicator_objects/unity_rl_initialization_output.proto"; 5 | 6 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 7 | package communicator_objects; 8 | 9 | message UnityOutput { 10 | UnityRLOutput rl_output = 1; 11 | UnityRLInitializationOutput rl_initialization_output = 2; 12 | 13 | //More messages can be added here 14 | } 15 | 16 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/agent_info_proto.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 4 | package communicator_objects; 5 | 6 | message AgentInfoProto { 7 | repeated float stacked_vector_observation = 1; 8 | repeated bytes visual_observations = 2; 9 | string text_observation = 3; 10 | repeated float stored_vector_actions = 4; 11 | 
string stored_text_actions = 5; 12 | repeated float memories = 6; 13 | float reward = 7; 14 | bool done = 8; 15 | bool max_step_reached = 9; 16 | int32 id = 10; 17 | repeated bool action_mask = 11; 18 | } 19 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_initialization_output.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "mlagents/envs/communicator_objects/brain_parameters_proto.proto"; 4 | import "mlagents/envs/communicator_objects/environment_parameters_proto.proto"; 5 | 6 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 7 | package communicator_objects; 8 | 9 | // The request message containing the academy's parameters. 10 | message UnityRLInitializationOutput { 11 | string name = 1; 12 | string version = 2; 13 | string log_path = 3; 14 | repeated BrainParametersProto brain_parameters = 5; 15 | EnvironmentParametersProto environment_parameters = 6; 16 | } 17 | -------------------------------------------------------------------------------- /docs/Readme.md: -------------------------------------------------------------------------------- 1 | # AutoBench Documentation 2 | 3 | - ### [Installation and Setup](Installation-and-Setup.md) 4 | 5 | - ### [Environment Details](Environment-Details.md) 6 | 7 | - ### [AgentsInfo - Observation, Action, Reward](AgentInfos-Obs-Action-Reward.md) 8 | 9 | - ### [Pretrain Model Details](Pretrain-Model-Details.md) 10 | 11 | - ### [Benchmark](Benchmark.md) 12 | 13 | - ### [Training and Environment Configuration](Training-and-Environment-Configuration.md) 14 | 15 | - ### [Setup Configuration File](Setup-Configuration-Files.md) 16 | 17 | - ### [ML-ImageSynthesis](ML-ImageSynthesis.md) 18 | 19 | - ### [Training Process](Training-Process.md) 20 | 21 | - ### [Requirements](Requirements.md) 22 | 23 | - ### [FAQ](FAQ.md) 24 | -------------------------------------------------------------------------------- /ml-agents/tests/trainers/test_demo_loader.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | import pytest 3 | import os 4 | 5 | from mlagents.trainers.demo_loader import load_demonstration, make_demo_buffer 6 | 7 | 8 | def test_load_demo(): 9 | path_prefix = os.path.dirname(os.path.abspath(__file__)) 10 | brain_parameters, brain_infos, total_expected = load_demonstration(path_prefix+'/test.demo') 11 | assert (brain_parameters.brain_name == "Ball3DBrain") 12 | assert (brain_parameters.vector_observation_space_size == 8) 13 | assert (len(brain_infos) == total_expected) 14 | 15 | demo_buffer = make_demo_buffer(brain_infos, brain_parameters, 1) 16 | assert (len(demo_buffer.update_buffer['actions']) == total_expected - 1) 17 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/unity_rl_input.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "mlagents/envs/communicator_objects/agent_action_proto.proto"; 4 | import "mlagents/envs/communicator_objects/environment_parameters_proto.proto"; 5 | import "mlagents/envs/communicator_objects/command_proto.proto"; 6 | 7 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 8 | package communicator_objects; 9 | 10 | message UnityRLInput { 11 | message ListAgentActionProto { 12 | repeated 
AgentActionProto value = 1; 13 | } 14 | map<string, ListAgentActionProto> agent_actions = 1; 15 | EnvironmentParametersProto environment_parameters = 2; 16 | bool is_training = 3; 17 | CommandProto command = 4; 18 | } 19 | -------------------------------------------------------------------------------- /protobuf-definitions/proto/mlagents/envs/communicator_objects/brain_parameters_proto.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "mlagents/envs/communicator_objects/resolution_proto.proto"; 4 | import "mlagents/envs/communicator_objects/space_type_proto.proto"; 5 | 6 | option csharp_namespace = "MLAgents.CommunicatorObjects"; 7 | package communicator_objects; 8 | 9 | message BrainParametersProto { 10 | int32 vector_observation_size = 1; 11 | int32 num_stacked_vector_observations = 2; 12 | repeated int32 vector_action_size = 3; 13 | repeated ResolutionProto camera_resolutions = 4; 14 | repeated string vector_action_descriptions = 5; 15 | SpaceTypeProto vector_action_space_type = 6; 16 | string brain_name = 7; 17 | bool is_training = 8; 18 | } 19 | -------------------------------------------------------------------------------- /config/offline_bc_config.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | trainer: offline_bc 3 | batch_size: 64 4 | summary_freq: 1000 5 | max_steps: 5.0e4 6 | batches_per_epoch: 10 7 | use_recurrent: false 8 | hidden_units: 128 9 | learning_rate: 3.0e-4 10 | num_layers: 2 11 | sequence_length: 32 12 | memory_size: 256 13 | demo_path: ./UnitySDK/Assets/Demonstrations/.demo 14 | 15 | HallwayLearning: 16 | trainer: offline_bc 17 | max_steps: 5.0e5 18 | num_epoch: 5 19 | batch_size: 64 20 | batches_per_epoch: 5 21 | num_layers: 2 22 | hidden_units: 128 23 | sequence_length: 16 24 | use_recurrent: true 25 | memory_size: 256 26 | sequence_length: 32 27 | demo_path: ./UnitySDK/Assets/Demonstrations/Hallway.demo 28 | -------------------------------------------------------------------------------- /protobuf-definitions/README.md: -------------------------------------------------------------------------------- 1 | # Unity ML-Agents Protobuf Definitions 2 | 3 | Contains the definitions needed to generate protobuf files used in the [ML-Agents Toolkit](https://github.com/Unity-Technologies/ml-agents). 4 | 5 | ## Requirements 6 | 7 | * grpc 1.10.1 8 | * protobuf 3.6.0 9 | 10 | ## Set-up & Installation 11 | 12 | `pip install protobuf==3.6.0 --force` 13 | 14 | `pip install grpcio-tools` 15 | 16 | `nuget install Grpc.Tools` into a known directory. 17 | 18 | ### Installing Protobuf Compiler 19 | 20 | On Mac: `brew install protobuf` 21 | 22 | On Windows & Linux: [See here](https://github.com/google/protobuf/blob/master/src/README.md). 23 | 24 | ## Running 25 | 26 | 1. Install pre-requisites. 27 | 2. Un-comment line 4 in `make.bat`, and set it to the correct Grpc.Tools sub-directory. 28 | 3.
Run `make.bat` 29 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent_action_proto_pb2 import * 2 | from .agent_info_proto_pb2 import * 3 | from .brain_parameters_proto_pb2 import * 4 | from .command_proto_pb2 import * 5 | from .demonstration_meta_proto_pb2 import * 6 | from .engine_configuration_proto_pb2 import * 7 | from .environment_parameters_proto_pb2 import * 8 | from .header_pb2 import * 9 | from .resolution_proto_pb2 import * 10 | from .space_type_proto_pb2 import * 11 | from .unity_input_pb2 import * 12 | from .unity_message_pb2 import * 13 | from .unity_output_pb2 import * 14 | from .unity_rl_initialization_input_pb2 import * 15 | from .unity_rl_initialization_output_pb2 import * 16 | from .unity_rl_input_pb2 import * 17 | from .unity_rl_output_pb2 import * 18 | from .unity_to_external_pb2 import * 19 | from .unity_to_external_pb2_grpc import * 20 | -------------------------------------------------------------------------------- /docs/Benchmark.md: -------------------------------------------------------------------------------- 1 | # Benchmark 2 | BenchmarkManager provides a set of APIs
3 | ### Sample Code: 4 | ``` 5 | BenchmarkManager(agent_amount, benchmark_episode, success_threshold, verbose) # Initialize 6 | 7 | while True: 8 | 9 | action = decide(curr_info) 10 | new_info = env.step(action) 11 | 12 | if use_benchmark: 13 | BenchmarkManager.add_result(new_info) # Add every brain info for analysis 14 | 15 | if BenchmarkManager.is_complete(): 16 | BenchmarkManager.analyze() # Analyze and print the result 17 | break 18 | 19 | curr_info = new_info 20 | ``` 21 | #### Agent_amount 22 | Number of agents in the environment 23 | 24 | #### Benchmark_episode 25 | Number of episodes needed for benchmarking 26 | 27 | #### Success_threshold 28 | Minimum reward at the final time step of an episode for it to count as a success 29 | 30 | #### Benchmark_verbose 31 | Whether to print episode information when an episode ends 32 | 33 | 34 | ### Limitation: 35 | * Only one benchmark is supported at a time 36 | -------------------------------------------------------------------------------- /ml-agents/tests/envs/test_rpc_communicator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from mlagents.envs import RpcCommunicator 4 | from mlagents.envs import UnityWorkerInUseException 5 | 6 | 7 | def test_rpc_communicator_checks_port_on_create(): 8 | first_comm = RpcCommunicator() 9 | with pytest.raises(UnityWorkerInUseException): 10 | second_comm = RpcCommunicator() 11 | second_comm.close() 12 | first_comm.close() 13 | 14 | 15 | def test_rpc_communicator_close(): 16 | # Ensures it is possible to open a new RPC Communicator 17 | # after closing one on the same worker_id 18 | first_comm = RpcCommunicator() 19 | first_comm.close() 20 | second_comm = RpcCommunicator() 21 | second_comm.close() 22 | 23 | 24 | def test_rpc_communicator_create_multiple_workers(): 25 | # Ensures multiple RPC communicators can be created with 26 | # different worker_ids without causing an error. 27 | first_comm = RpcCommunicator() 28 | second_comm = RpcCommunicator(worker_id=1) 29 | first_comm.close() 30 | second_comm.close() 31 | 32 | -------------------------------------------------------------------------------- /docs/ML-ImageSynthesis.md: -------------------------------------------------------------------------------- 1 | # ML-ImageSynthesis 2 | 3 | ### Raw Image 4 | 5 | 6 | ### Object ID Segmentation 7 | 8 | 9 | ### Object Type Segmentation 10 | 11 | 12 | ### Depth Map 13 | 14 | 15 | ### Normals 16 | 17 | 18 | ### Optical Flow 19 | 20 | 21 | 22 | Users can choose between raw images and segmentation for different difficulty levels, and between depth and optical flow for further research. 23 | 24 | * Source repo: [ML-ImageSynthesis](https://bitbucket.org/Unity-Technologies/ml-imagesynthesis)
25 | * Supported image types: Raw Image, Segmentation, Depth, Optical Flow and Normals. 26 | * Type code: 0=Disable, 1=Raw, 2=ObjectID, 3=ObjectType, 4=Depth, 5=Normals, 6=Optical Flow 27 | * More info on [how to configure cameras](Setup-Configuration-Files.md#environment-config) 28 | 29 | ### Limitation 30 | ObjectID mode doesn't guarantee label consistency across episodes: the same object may receive different labels in different episodes, because GameObjects inside the scene may be destroyed and re-instantiated. Within a single episode, however, no two objects share the same label. 31 | -------------------------------------------------------------------------------- /ml-agents/README.md: -------------------------------------------------------------------------------- 1 | # Unity ML-Agents Python Interface and Trainers 2 | 3 | The `mlagents` Python package is part of the 4 | [ML-Agents Toolkit](https://github.com/Unity-Technologies/ml-agents). 5 | `mlagents` provides a Python API that allows direct interaction with the Unity 6 | game engine as well as a collection of trainers and algorithms to train agents 7 | in Unity environments. 8 | 9 | The `mlagents` Python package contains two sub packages: 10 | 11 | * `mlagents.envs`: A low level API which allows you to interact directly with a 12 | Unity Environment. See 13 | [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Python-API.md) 14 | for more information on using this package. 15 | 16 | * `mlagents.trainers`: A set of Reinforcement Learning algorithms designed to be 17 | used with Unity environments. Access them using the `mlagents-learn` access 18 | point. See 19 | [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-ML-Agents.md) 20 | for more information on using this package. 21 | 22 | ## Installation 23 | 24 | Install the `mlagents` package with: 25 | 26 | ```sh 27 | pip install mlagents 28 | ``` 29 | 30 | ## Usage & More Information 31 | 32 | For more detailed documentation, check out the 33 | [ML-Agents Toolkit documentation.](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Readme.md) 34 | -------------------------------------------------------------------------------- /config/curricula/autobench/AutoBenchBrain.json: -------------------------------------------------------------------------------- 1 | { 2 | "measure": "progress", 3 | "thresholds": [], 4 | "min_lesson_length": 0, 5 | "signal_smoothing": false, 6 | "parameters": { 7 | "camera1_type": [ 8 | 0 9 | ], 10 | "camera2_type": [ 11 | 3 12 | ], 13 | "camera3_type": [ 14 | 0 15 | ], 16 | "camera1_res_x": [ 17 | 0 18 | ], 19 | "camera2_res_x": [ 20 | 50 21 | ], 22 | "camera3_res_x": [ 23 | 0 24 | ], 25 | "camera1_res_y": [ 26 | 0 27 | ], 28 | "camera2_res_y": [ 29 | 50 30 | ], 31 | "camera3_res_y": [ 32 | 0 33 | ], 34 | "weather_id": [ 35 | 1 36 | ], 37 | "time_id": [ 38 | 9 39 | ], 40 | "road_width": [ 41 | 8 42 | ], 43 | "forward": [ 44 | true 45 | ], 46 | "detail": [ 47 | false 48 | ], 49 | "goal_reward": [ 50 | 500 51 | ], 52 | "time_penalty": [ 53 | -1 54 | ], 55 | "collision_penalty": [ 56 | -300 57 | ], 58 | "position_reward": [ 59 | 300 60 | ], 61 | "velocity_reward": [ 62 | 1 63 | ] 64 | } 65 | } -------------------------------------------------------------------------------- /docs/Training-Process.md: -------------------------------------------------------------------------------- 1 | # Training Process 2 | ### Reinforcement Learning
3 | Uses Proximal Policy Optimization (PPO) 4 | * Locate ```learn_rl.py``` 5 | * Modify the parameters [(more info)](Setup-Configuration-Files.md#python-script) 6 | * Run ```learn_rl.py``` 7 | 8 | ### General Machine Learning 9 | * Locate ```learn_ml.py``` 10 | * Modify the parameters [(more info)](Setup-Configuration-Files.md#python-script) 11 | * Implement your own decision algorithm in the ```def decide(brain_info: BrainInfo)``` function 12 | * Run ```learn_ml.py``` 13 | 14 | ### OpenAI Gym Compatible 15 | * Sample Code: ```learn_gym.py``` 16 | ``` 17 | from gym_unity.envs import UnityEnv 18 | 19 | env = UnityEnv(environment_filename, worker_id=0, use_visual, multiagent, env_config, camera_res_overwrite) 20 | ``` 21 | * Limitation: 22 | By default the first visual observation is provided as the observation, if present. Otherwise vector observations are provided.
23 | All BrainInfo output from the environment can still be accessed from the info provided by ```env.step(action)``` 24 | 25 | 26 | ### Inference 27 | * Set ```load_model = True``` to load the pre-trained model 28 | * Set ```train_model = False``` so no learning algorithm is run 29 | * Set ```fast_simulation = False``` to enable inference mode, which lets you use the WASD-controlled observer camera 30 | * Run ```learn_rl.py``` 31 | 32 | ### Running the Pre-trained Model 33 | * [Details about the pre-trained model](Pretrain-Model-Details.md) 34 | -------------------------------------------------------------------------------- /docs/AgentInfos-Obs-Action-Reward.md: -------------------------------------------------------------------------------- 1 | # AgentsInfo - Observation, Action, Reward 2 | 3 | ## Vector Observations: 4 | * Vehicle velocity (Vector3) 5 | * Vehicle relative position to success area (Vector3) 6 | * Vehicle Y-axis1 rotation angle (float) 7 | 8 | ## Visual Observations: 9 | * Image output from up to 3 cameras 10 | * More info about [Image output type](ML-ImageSynthesis.md) and [How to configure camera](Setup-Configuration-Files.md#environment-config) 11 | 12 | ## Actions: 13 | * +30, 0, -30 degrees of steering angle 14 | * +1, 0, -0.3 scale of throttle power. 15 | In total, 9 discrete action combinations 16 | 17 | ## Rewards: 18 | * Time penalty: -1 / timestep 19 | * Drive on the grass: -300 20 | * Collide with barriers: -300 21 | * Distance2 between agent and success area:
22 | ```max((last_timestep_distance - current_timestep_distance), 0) * 300``` 23 | * Velocity reward: +1 per m/s 24 | * Success: +500 25 | * More info about [Customizing reward values](Setup-Configuration-Files.md#environment-config) 26 | 27 | ## Episode End Conditions: 28 | * Time step > 1500 (30s environment time) 29 | * Collide with barriers 30 | * Drive on the grass 31 | * Success 32 | 33 | ## Definition of Success: 34 | * All 4 wheels in the success area, and 35 | ```Vector3.Dot(VehicleForwardUnitVector, SuccessAreaForwardUnitVector) < 0.5``` 36 | 37 | 38 | 39 | 1 Y-axis is the vertical axis in Unity
40 | 2 Unity uses SI units
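
Putting the reward terms above together, a single step's reward reduces to a short expression. The sketch below restates it in Python for clarity; the function and its arguments are hypothetical, since the actual calculation happens inside the Unity environment.
```
# Hypothetical restatement of the per-step reward described above
# (default values; see Customizing reward values for overrides).
def step_reward(last_distance, distance, speed, collided, on_grass, succeeded):
    reward = -1.0                                      # time penalty per timestep
    reward += max(last_distance - distance, 0) * 300   # progress toward success area
    reward += speed                                    # velocity reward, +1 per m/s
    if collided or on_grass:
        reward -= 300                                  # barrier / grass penalty
    if succeeded:
        reward += 500                                  # success bonus
    return reward
```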
41 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .communicator_objects import UnityOutput, UnityInput 4 | 5 | logger = logging.getLogger("mlagents.envs") 6 | 7 | 8 | class Communicator(object): 9 | def __init__(self, worker_id=0, base_port=5005): 10 | """ 11 | Python side of the communication. Must be used in pair with the right Unity Communicator equivalent. 12 | 13 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 14 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 15 | """ 16 | 17 | def initialize(self, inputs: UnityInput) -> UnityOutput: 18 | """ 19 | Used to exchange initialization parameters between Python and the Environment 20 | :param inputs: The initialization input that will be sent to the environment. 21 | :return: UnityOutput: The initialization output sent by Unity 22 | """ 23 | 24 | def exchange(self, inputs: UnityInput) -> UnityOutput: 25 | """ 26 | Used to send an input and receive an output from the Environment 27 | :param inputs: The UnityInput that needs to be sent to the Environment 28 | :return: The UnityOutputs generated by the Environment 29 | """ 30 | 31 | def close(self): 32 | """ 33 | Sends a shutdown signal to the unity environment, and closes the connection. 34 | """ 35 | 36 | -------------------------------------------------------------------------------- /docs/Environment-Details.md: -------------------------------------------------------------------------------- 1 | # Environment Details 2 | 3 | 4 | This environment consists of a 50x50m ground surrounded by 3m walls; the car model measures approximately 4.3x1.8x1.4m. 5 | 6 | **The goal is to navigate through the S-curve path and then enter the finish area1 without colliding with barriers.** 7 | 8 | Road geometry is based on part of the actual driving license exam in Taiwan, which has both a forward and a backward phase. AutoBench provides a forward/backward option as part of the configuration suite, along with a 3-camera setup for each agent's visual observations.
9 | 10 | 11 | ### Vehicle Cameras 12 | 13 | 14 | One front-facing camera is mounted on the top center of the vehicle, and two rear-facing cameras simulate the vehicle's side mirrors. Users can assign any type of visual observation to each camera, or simply disable its output. 15 | 16 | To improve training efficiency and speed, AutoBench trains 10 agents in parallel; the agents cannot collide with or see each other. 17 | 18 | ### Environment Configurations include: 19 | * Road Width 20 | * Forward and Backward 21 | * Level of Details 22 | * Camera Image Types and Resolutions 23 | * Weather Conditions 24 | * Time of Day 25 | * Reward Values (Reward Shaping) 26 | 27 | Learn more about [Training and Environment Configuration](Training-and-Environment-Configuration.md) 28 | 29 | 30 | 31 | 1 The blue box only visualizes the finish area; agents cannot see it. 32 | -------------------------------------------------------------------------------- /protobuf-definitions/make.bat: -------------------------------------------------------------------------------- 1 | # variables 2 | 3 | # GRPC-TOOLS required. Install with `nuget install Grpc.Tools`. 4 | # Then un-comment and replace [DIRECTORY] with location of files. 5 | # For example, on macOS, you might have something like: 6 | # COMPILER=Grpc.Tools.1.14.1/tools/macosx_x64 7 | # COMPILER=[DIRECTORY] 8 | 9 | SRC_DIR=proto/mlagents/envs/communicator_objects 10 | DST_DIR_C=../UnitySDK/Assets/ML-Agents/Scripts/CommunicatorObjects 11 | DST_DIR_P=../ml-agents 12 | PROTO_PATH=proto 13 | 14 | PYTHON_PACKAGE=mlagents/envs/communicator_objects 15 | 16 | # clean 17 | rm -rf $DST_DIR_C 18 | rm -rf $DST_DIR_P/$PYTHON_PACKAGE 19 | mkdir -p $DST_DIR_C 20 | mkdir -p $DST_DIR_P/$PYTHON_PACKAGE 21 | 22 | # generate proto objects in python and C# 23 | 24 | protoc --proto_path=proto --csharp_out=$DST_DIR_C $SRC_DIR/*.proto 25 | protoc --proto_path=proto --python_out=$DST_DIR_P $SRC_DIR/*.proto 26 | 27 | # grpc 28 | 29 | GRPC=unity_to_external.proto 30 | 31 | $COMPILER/protoc --proto_path=proto --csharp_out $DST_DIR_C --grpc_out $DST_DIR_C $SRC_DIR/$GRPC --plugin=protoc-gen-grpc=$COMPILER/grpc_csharp_plugin 32 | python3 -m grpc_tools.protoc --proto_path=proto --python_out=$DST_DIR_P --grpc_python_out=$DST_DIR_P $SRC_DIR/$GRPC 33 | 34 | 35 | # Generate the init file for the python module 36 | # rm -f $DST_DIR_P/$PYTHON_PACKAGE/__init__.py 37 | for FILE in $DST_DIR_P/$PYTHON_PACKAGE/*.py 38 | do 39 | FILE=${FILE##*/} 40 | # echo from .$(basename $FILE) import \* >> $DST_DIR_P/$PYTHON_PACKAGE/__init__.py 41 | echo from .${FILE%.py} import \* >> $DST_DIR_P/$PYTHON_PACKAGE/__init__.py 42 | done 43 | 44 | -------------------------------------------------------------------------------- /ml-agents/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from os import path 3 | from io import open 4 | 5 | here = path.abspath(path.dirname(__file__)) 6 | 7 | # Get the long description from the README file 8 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 9 | long_description = f.read() 10 | 11 | setup( 12 | name='mlagents', 13 | version='0.6.0', 14 | description='Unity Machine Learning Agents', 15 | long_description=long_description, 16 | long_description_content_type='text/markdown', 17 | url='https://github.com/Unity-Technologies/ml-agents', 18 | author='Unity Technologies', 19 | author_email='ML-Agents@unity3d.com', 20 | 21 | classifiers=[ 22 | 'Intended 
Audience :: Developers', 23 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 24 | 'License :: OSI Approved :: Apache Software License', 25 | 'Programming Language :: Python :: 3.6' 26 | ], 27 | 28 | packages=find_packages(exclude=['tests', 'tests.*', '*.tests', '*.tests.*']), # Required 29 | 30 | install_requires=[ 31 | 'tensorflow>=1.7,<1.8', 32 | 'Pillow>=4.2.1', 33 | 'matplotlib', 34 | 'numpy>=1.13.3,<=1.14.5', 35 | 'jupyter', 36 | 'pytest>=3.2.2,<4.0.0', 37 | 'docopt', 38 | 'pyyaml', 39 | 'protobuf>=3.6,<3.7', 40 | 'grpcio>=1.11.0,<1.12.0'], 41 | 42 | python_requires=">=3.6,<3.7", 43 | 44 | entry_points={ 45 | 'console_scripts': [ 46 | 'mlagents-learn=mlagents.trainers.learn:main', 47 | ], 48 | }, 49 | ) 50 | -------------------------------------------------------------------------------- /protobuf-definitions/make_for_win.bat: -------------------------------------------------------------------------------- 1 | rem variables 2 | 3 | rem GRPC-TOOLS required. Install with `nuget install Grpc.Tools`. 4 | rem Then un-comment and replace [DIRECTORY] with location of files. 5 | rem For example, on Windows, you might have something like: 6 | rem set COMPILER=Grpc.Tools.1.14.1/tools/windows_x64 7 | 8 | set SRC_DIR=proto\mlagents\envs\communicator_objects 9 | set DST_DIR_C=..\UnitySDK\Assets\ML-Agents\Scripts\CommunicatorObjects 10 | set DST_DIR_P=..\ml-agents 11 | set PROTO_PATH=proto 12 | 13 | set PYTHON_PACKAGE=mlagents\envs\communicator_objects 14 | 15 | rem clean 16 | rd /s /q %DST_DIR_C% 17 | rd /s /q %DST_DIR_P%\%PYTHON_PACKAGE% 18 | mkdir %DST_DIR_C% 19 | mkdir %DST_DIR_P%\%PYTHON_PACKAGE% 20 | 21 | rem generate proto objects in python and C# 22 | 23 | for %%i in (%SRC_DIR%\*.proto) do ( 24 | protoc --proto_path=proto --csharp_out=%DST_DIR_C% %%i 25 | protoc --proto_path=proto --python_out=%DST_DIR_P% %%i 26 | ) 27 | 28 | rem grpc 29 | 30 | set GRPC=unity_to_external.proto 31 | 32 | %COMPILER%\protoc --proto_path=proto --csharp_out %DST_DIR_C% --grpc_out %DST_DIR_C% %SRC_DIR%\%GRPC% --plugin=protoc-gen-grpc=%COMPILER%\grpc_csharp_plugin.exe 33 | python3 -m grpc_tools.protoc --proto_path=proto --python_out=%DST_DIR_P% --grpc_python_out=%DST_DIR_P% %SRC_DIR%\%GRPC% 34 | 35 | rem Generate the init file for the python module 36 | rem rm -f $DST_DIR_P/$PYTHON_PACKAGE/__init__.py 37 | setlocal enabledelayedexpansion 38 | for %%i in (%DST_DIR_P%\%PYTHON_PACKAGE%\*.py) do ( 39 | set FILE=%%~ni 40 | rem echo from .$(basename $FILE) import * >> $DST_DIR_P/$PYTHON_PACKAGE/__init__.py 41 | echo from .!FILE! import * >> %DST_DIR_P%\%PYTHON_PACKAGE%\__init__.py 42 | ) 43 | 44 | -------------------------------------------------------------------------------- /docs/Training-and-Environment-Configuration.md: -------------------------------------------------------------------------------- 1 | # Training and Environment Configuration 2 | 3 | ## Road Width 4 | 5 | 6 | Road width can be set between 0.1 and 18, controlling the difficulty of the environment 7 | 8 | ## Forward and Backward 9 | 10 | Road geometry is based on part of the actual driving license exam in Taiwan, which has both a forward and a backward phase. 11 | 12 | ## Level of Details 13 | 14 | 15 | ## Camera Image Types and Resolutions 16 | 17 | 18 | Each camera can be assigned any of several image types and any preferred resolution.
19 | 20 | More info about [Camera Output Type](ML-ImageSynthesis.md) 21 | 22 | ## Weather and Time 23 | Weather conditions include: 24 | 25 | 26 | 27 | 28 | 29 | 30 | with the following weather codes: 31 | * 0=ClearSky, 1=Cloudy1, 2=Cloudy2, 3=Cloudy3, 4=Cloudy4, 5=Foggy 32 | * 6=LightRain, 7=HeavyRain, 8=LightSnow, 9=HeavySnow, 10=Storm 33 | 34 | Time can be set to any float value between 0 and 24
35 | 36 | 37 | 38 | For instance, 9.5 = 9:30am, 18.25 = 6:15pm 39 | 40 | ### Notice: weather also affects the environment lighting 41 | 42 | ## Reward 43 | AutoBench comes with the default reward values used to train the pretrained baseline; users can modify them for further reward shaping. 44 | More info about [Reward Calculation](AgentInfos-Obs-Action-Reward.md#rewards) 45 | 46 | ## Learn more about [Setup Configuration Files](Setup-Configuration-Files.md) 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | 3 | /AutoBenchExecutable 4 | 5 | /UnitySDK/[Ll]ibrary/ 6 | /UnitySDK/[Tt]emp/ 7 | /UnitySDK/[Oo]bj/ 8 | /UnitySDK/[Bb]uild/ 9 | /UnitySDK/[Bb]uilds/ 10 | /UnitySDK/[Pp]ackages/ 11 | /UnitySDK/[Uu]nity[Pp]ackage[Mm]anager/ 12 | /UnitySDK/Assets/AssetStoreTools* 13 | /UnitySDK/Assets/Plugins* 14 | /UnitySDK/Assets/Gizmos* 15 | /UnitySDK/Assets/Demonstrations* 16 | 17 | # Training environments 18 | /envs 19 | 20 | # Environment logfile 21 | *UnitySDK.log 22 | 23 | # Visual Studio 2015 cache directory 24 | /UnitySDK/.vs/ 25 | 26 | # Autogenerated VS/MD/Consulo solution and project files 27 | /UnitySDKExportedObj/ 28 | /UnitySDK.consulo/ 29 | *.csproj 30 | *.unityproj 31 | *.sln 32 | *.suo 33 | *.tmp 34 | *.user 35 | *.userprefs 36 | *.pidb 37 | *.booproj 38 | *.svd 39 | *.pdb 40 | 41 | # Unity3D generated meta files 42 | *.pidb.meta 43 | 44 | # Unity3D Generated File On Crash Reports 45 | /UnitySDK/sysinfo.txt 46 | 47 | # Builds 48 | *.apk 49 | *.unitypackage 50 | *.app 51 | *.exe 52 | *.x86_64 53 | *.x86 54 | 55 | # Tensorflow Sharp Files 56 | /UnitySDK/Assets/ML-Agents/Plugins/Android* 57 | /UnitySDK/Assets/ML-Agents/Plugins/iOS* 58 | /UnitySDK/Assets/ML-Agents/Plugins/Computer* 59 | /UnitySDK/Assets/ML-Agents/Plugins/System.Numerics* 60 | /UnitySDK/Assets/ML-Agents/Plugins/System.ValueTuple* 61 | 62 | # Generated doc folders 63 | /docs/html 64 | 65 | # Mac hidden files 66 | *.DS_Store 67 | */.ipynb_checkpoints 68 | */.idea 69 | *.pyc 70 | *.idea/misc.xml 71 | *.idea/modules.xml 72 | *.iml 73 | *.cache 74 | */build/ 75 | */dist/ 76 | *.egg-info* 77 | *.eggs* 78 | *.gitignore.swp 79 | 80 | # VSCode hidden files 81 | *.vscode/ 82 | 83 | .DS_Store 84 | .ipynb_checkpoints 85 | 86 | # pytest cache 87 | *.pytest_cache/ 88 | 89 | # Ignore compiled protobuf files. 90 | ml-agents-protobuf/cs 91 | ml-agents-protobuf/python 92 | ml-agents-protobuf/Grpc* 93 | 94 | # Ignore PyPi build files. 95 | dist/ 96 | build/ 97 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/unity_to_external_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | import grpc 3 | 4 | from mlagents.envs.communicator_objects import unity_message_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2 5 | 6 | 7 | class UnityToExternalStub(object): 8 | # missing associated documentation comment in .proto file 9 | pass 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 
16 | """ 17 | self.Exchange = channel.unary_unary( 18 | '/communicator_objects.UnityToExternal/Exchange', 19 | request_serializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 20 | response_deserializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 21 | ) 22 | 23 | 24 | class UnityToExternalServicer(object): 25 | # missing associated documentation comment in .proto file 26 | pass 27 | 28 | def Exchange(self, request, context): 29 | """Sends the academy parameters 30 | """ 31 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 32 | context.set_details('Method not implemented!') 33 | raise NotImplementedError('Method not implemented!') 34 | 35 | 36 | def add_UnityToExternalServicer_to_server(servicer, server): 37 | rpc_method_handlers = { 38 | 'Exchange': grpc.unary_unary_rpc_method_handler( 39 | servicer.Exchange, 40 | request_deserializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 41 | response_serializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 42 | ), 43 | } 44 | generic_handler = grpc.method_handlers_generic_handler( 45 | 'communicator_objects.UnityToExternal', rpc_method_handlers) 46 | server.add_generic_rpc_handlers((generic_handler,)) 47 | -------------------------------------------------------------------------------- /gym-unity/tests/test_gym.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | import pytest 3 | import numpy as np 4 | 5 | from gym_unity.envs import UnityEnv, UnityGymException 6 | from tests.mock_communicator import MockCommunicator 7 | 8 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 9 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 10 | def test_gym_wrapper(mock_communicator, mock_launcher): 11 | mock_communicator.return_value = MockCommunicator( 12 | discrete_action=False, visual_inputs=0, stack=False, num_agents=1) 13 | 14 | # Test for incorrect number of agents. 15 | with pytest.raises(UnityGymException): 16 | UnityEnv(' ', use_visual=False, multiagent=True) 17 | 18 | env = UnityEnv(' ', use_visual=False) 19 | assert isinstance(env, UnityEnv) 20 | assert isinstance(env.reset(), np.ndarray) 21 | actions = env.action_space.sample() 22 | assert actions.shape[0] == 2 23 | obs, rew, done, info = env.step(actions) 24 | assert isinstance(obs, np.ndarray) 25 | assert isinstance(rew, float) 26 | assert isinstance(done, bool) 27 | assert isinstance(info, dict) 28 | 29 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 30 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 31 | def test_multi_agent(mock_communicator, mock_launcher): 32 | mock_communicator.return_value = MockCommunicator( 33 | discrete_action=False, visual_inputs=0, stack=False, num_agents=2) 34 | 35 | # Test for incorrect number of agents. 
36 | with pytest.raises(UnityGymException): 37 | UnityEnv(' ', multiagent=False) 38 | 39 | env = UnityEnv(' ', use_visual=False, multiagent=True) 40 | assert isinstance(env.reset(), list) 41 | actions = [env.action_space.sample() for i in range(env.number_agents)] 42 | obs, rew, done, info = env.step(actions) 43 | assert isinstance(obs, list) 44 | assert isinstance(rew, list) 45 | assert isinstance(done, list) 46 | assert isinstance(info, dict) 47 | -------------------------------------------------------------------------------- /ml-agents/tests/trainers/test_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mlagents.trainers.buffer import Buffer 3 | 4 | 5 | def assert_array(a, b): 6 | assert a.shape == b.shape 7 | la = list(a.flatten()) 8 | lb = list(b.flatten()) 9 | for i in range(len(la)): 10 | assert la[i] == lb[i] 11 | 12 | 13 | def test_buffer(): 14 | b = Buffer() 15 | for fake_agent_id in range(4): 16 | for step in range(9): 17 | b[fake_agent_id]['vector_observation'].append( 18 | [100 * fake_agent_id + 10 * step + 1, 19 | 100 * fake_agent_id + 10 * step + 2, 20 | 100 * fake_agent_id + 10 * step + 3] 21 | ) 22 | b[fake_agent_id]['action'].append([100 * fake_agent_id + 10 * step + 4, 23 | 100 * fake_agent_id + 10 * step + 5]) 24 | a = b[1]['vector_observation'].get_batch(batch_size=2, training_length=1, sequential=True) 25 | assert_array(a, np.array([[171, 172, 173], [181, 182, 183]])) 26 | a = b[2]['vector_observation'].get_batch(batch_size=2, training_length=3, sequential=True) 27 | assert_array(a, np.array([ 28 | [[231, 232, 233], [241, 242, 243], [251, 252, 253]], 29 | [[261, 262, 263], [271, 272, 273], [281, 282, 283]] 30 | ])) 31 | a = b[2]['vector_observation'].get_batch(batch_size=2, training_length=3, sequential=False) 32 | assert_array(a, np.array([ 33 | [[251, 252, 253], [261, 262, 263], [271, 272, 273]], 34 | [[261, 262, 263], [271, 272, 273], [281, 282, 283]] 35 | ])) 36 | b[4].reset_agent() 37 | assert len(b[4]) == 0 38 | b.append_update_buffer(3, batch_size=None, training_length=2) 39 | b.append_update_buffer(2, batch_size=None, training_length=2) 40 | assert len(b.update_buffer['action']) == 10 41 | assert np.array(b.update_buffer['action']).shape == (10, 2, 2) 42 | 43 | c = b.update_buffer.make_mini_batch(start=0, end=1) 44 | assert c.keys() == b.update_buffer.keys() 45 | assert c['action'].shape == (1, 2, 2) 46 | -------------------------------------------------------------------------------- /docs/Pretrain-Model-Details.md: -------------------------------------------------------------------------------- 1 | # Pretrain Model Details 2 | ### Setup 3 | * Download the [Pretrain Model](https://goo.gl/HunBMB) 4 | * Unzip the file 5 | * Copy the ```Pretrain``` folder into the ```AutoBench/models``` folder (create it if it does not exist) 6 | * Set these in ```learn_rl.py``` 7 | ``` 8 | run_id = 'Pretrain' 9 | load_model = True 10 | ``` 11 | 12 | ### Architecture 13 | 14 | 15 | The architecture follows the standard [PPO](https://arxiv.org/abs/1707.06347) implementation with the following modifications: 16 | * Adding a Convolutional Neural Network (CNN) as a visual encoder 17 | * Adding a Long Short-Term Memory (LSTM) layer after the concatenation of the visual and vector latent features to implement the recurrent functionality of [Deep Recurrent Q-Networks](https://arxiv.org/abs/1507.06527), as sketched below.
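As a concrete illustration of these two modifications, here is a minimal TensorFlow 1.x sketch of the encoder pipeline (CNN visual encoder, concatenation with the vector latent, LSTM, then the policy head). This is an assumption-laden sketch, not the actual trainer code: the observation shapes, layer parameters, and the 2-dimensional continuous action are hypothetical, while the sizes follow the configuration shown below (hidden_units=128, memory_size=256, sequence_length=64, one 50x50 camera).

```
# Illustrative sketch only (hypothetical names, NOT the AutoBench/ML-Agents code).
import tensorflow as tf

H_SIZE, M_SIZE, SEQ_LEN = 128, 256, 64  # hidden_units, memory_size, sequence_length

visual_in = tf.placeholder(tf.float32, [None, 50, 50, 3], name="visual_obs")
vector_in = tf.placeholder(tf.float32, [None, 8], name="vector_obs")  # vector size assumed

# 1) CNN visual encoder
conv = tf.layers.conv2d(visual_in, 16, kernel_size=8, strides=4, activation=tf.nn.elu)
conv = tf.layers.conv2d(conv, 32, kernel_size=4, strides=2, activation=tf.nn.elu)
visual_latent = tf.layers.dense(tf.layers.flatten(conv), H_SIZE, activation=tf.nn.elu)

# Vector observations get their own dense encoder; the two latents are concatenated
vector_latent = tf.layers.dense(vector_in, H_SIZE, activation=tf.nn.elu)
hidden = tf.concat([visual_latent, vector_latent], axis=1)  # [batch*seq, 2*H_SIZE]

# 2) LSTM after the concatenation (DRQN-style recurrence over 64-step sequences;
#    assumes the flat batch dimension is a multiple of SEQ_LEN)
batch = tf.shape(hidden)[0] // SEQ_LEN
seq = tf.reshape(hidden, [batch, SEQ_LEN, 2 * H_SIZE])
cell = tf.nn.rnn_cell.BasicLSTMCell(M_SIZE // 2)  # 256 = cell state + hidden state
rnn_out, _ = tf.nn.dynamic_rnn(cell, seq, dtype=tf.float32)
rnn_out = tf.reshape(rnn_out, [-1, M_SIZE // 2])

# Continuous policy head (2 actions assumed, e.g. steering and throttle)
action = tf.layers.dense(rnn_out, 2, activation=tf.nn.tanh, name="action")
```

The PPO value head and the recurrent_in/recurrent_out plumbing used at inference time are omitted to keep the sketch short.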
18 | 19 | ### Tensorboard 20 | 21 | 22 | 23 | ### Difficulty and Agents Settings 24 | The following settings are the training configuration of this pretrained baseline: 25 | 26 | ``` 27 | AutoBenchBrain: 28 | batch_size: 512 29 | beta: 1.0e-1 30 | buffer_size: 4096 31 | epsilon: 0.2 32 | gamma: 0.99 33 | hidden_units: 128 34 | lambd: 0.95 35 | learning_rate: 1.0e-4 36 | max_steps: 1.0e7 37 | memory_size: 256 38 | normalize: true 39 | num_epoch: 5 40 | num_layers: 2 41 | time_horizon: 512 42 | sequence_length: 64 43 | summary_freq: 1000 44 | use_recurrent: true 45 | 46 | "parameters": { 47 | "camera1_type": [0], 48 | "camera2_type": [3], 49 | "camera3_type": [0], 50 | "camera1_res_x": [0], 51 | "camera2_res_x": [50], 52 | "camera3_res_x": [0], 53 | "camera1_res_y": [0], 54 | "camera2_res_y": [50], 55 | "camera3_res_y": [0], 56 | "weather_id": [1], 57 | "time_id": [9], 58 | "road_width": [8], 59 | "forward": [true], 60 | "detail": [false], 61 | "goal_reward": [500], 62 | "time_penalty": [-1], 63 | "collision_penalty": [-300], 64 | "position_reward": [300], 65 | "velocity_reward": [1] 66 | } 67 | 68 | ``` 69 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/benchmark.py: -------------------------------------------------------------------------------- 1 | from mlagents.envs.brain import BrainInfo 2 | import numpy as np 3 | 4 | class BenchmarkManager(object): 5 | 6 | agent_status = [[]] 7 | agent_amount = None 8 | agent_benchmark_result = [] #[episode_len, cumulative_reward, success_goal] 9 | success_threshold = None 10 | # agent_benchmark_result = [[5,100,True],[6,120,False], [10,255,True]] 11 | benchmark_episode = None 12 | verbose = False 13 | 14 | @staticmethod 15 | def is_complete(): 16 | return len(BenchmarkManager.agent_benchmark_result) >= BenchmarkManager.benchmark_episode 17 | 18 | @staticmethod 19 | def add_result(info: BrainInfo): 20 | 21 | for agent_index in range(BenchmarkManager.agent_amount): 22 | 23 | BenchmarkManager.agent_status[agent_index][0] += 1 24 | BenchmarkManager.agent_status[agent_index][1] += info.rewards[agent_index] 25 | 26 | if info.local_done[agent_index]: 27 | if BenchmarkManager.verbose: 28 | print('Episode Length', BenchmarkManager.agent_status[agent_index][0]) 29 | print('Cumulative Reward', BenchmarkManager.agent_status[agent_index][1]) 30 | 31 | u = BenchmarkManager.agent_status[agent_index][:] 32 | 33 | if info.rewards[agent_index] >= BenchmarkManager.success_threshold: 34 | u.append(True) 35 | else: 36 | u.append(False) 37 | 38 | BenchmarkManager.agent_benchmark_result.append(u[:]) 39 | BenchmarkManager.agent_status[agent_index][0] = 0 40 | BenchmarkManager.agent_status[agent_index][1] = 0 41 | 42 | @staticmethod 43 | def analyze(): 44 | 45 | result = np.array(BenchmarkManager.agent_benchmark_result) 46 | 47 | print('Episode Length: Avg = %.2f, Std = %.2f' % (np.average(result[:,0]), np.std(result[:,0]))) 48 | print('Reward: Avg = %.2f, Std = %.2f' % (np.average(result[:,1]), np.std(result[:,1]))) 49 | print( 50 | 'Success Rate: ', 51 | '{:.0%}'.format(np.sum(result[:,2]) / len(BenchmarkManager.agent_benchmark_result)) 52 | ) 53 | 54 | 55 | def __init__(self, agent_amount, benchmark_episode, success_threshold, verbose): 56 | 57 | BenchmarkManager.agent_status = [[0 for x in range(2)] for y in range(agent_amount)] 58 | BenchmarkManager.agent_amount = agent_amount 59 | BenchmarkManager.benchmark_episode = benchmark_episode 60 | BenchmarkManager.success_threshold = success_threshold 61 |
BenchmarkManager.verbose = verbose 62 | 63 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/command_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: mlagents/envs/communicator_objects/command_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='mlagents/envs/communicator_objects/command_proto.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n6mlagents/envs/communicator_objects/command_proto.proto\x12\x14\x63ommunicator_objects*-\n\x0c\x43ommandProto\x12\x08\n\x04STEP\x10\x00\x12\t\n\x05RESET\x10\x01\x12\x08\n\x04QUIT\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | _COMMANDPROTO = _descriptor.EnumDescriptor( 27 | name='CommandProto', 28 | full_name='communicator_objects.CommandProto', 29 | filename=None, 30 | file=DESCRIPTOR, 31 | values=[ 32 | _descriptor.EnumValueDescriptor( 33 | name='STEP', index=0, number=0, 34 | options=None, 35 | type=None), 36 | _descriptor.EnumValueDescriptor( 37 | name='RESET', index=1, number=1, 38 | options=None, 39 | type=None), 40 | _descriptor.EnumValueDescriptor( 41 | name='QUIT', index=2, number=2, 42 | options=None, 43 | type=None), 44 | ], 45 | containing_type=None, 46 | options=None, 47 | serialized_start=80, 48 | serialized_end=125, 49 | ) 50 | _sym_db.RegisterEnumDescriptor(_COMMANDPROTO) 51 | 52 | CommandProto = enum_type_wrapper.EnumTypeWrapper(_COMMANDPROTO) 53 | STEP = 0 54 | RESET = 1 55 | QUIT = 2 56 | 57 | 58 | DESCRIPTOR.enum_types_by_name['CommandProto'] = _COMMANDPROTO 59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 | 61 | 62 | DESCRIPTOR.has_options = True 63 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/exception.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger("mlagents.envs") 3 | 4 | class UnityException(Exception): 5 | """ 6 | Any error related to ml-agents environment. 7 | """ 8 | pass 9 | 10 | class UnityEnvironmentException(UnityException): 11 | """ 12 | Related to errors starting and closing environment. 13 | """ 14 | pass 15 | 16 | 17 | class UnityActionException(UnityException): 18 | """ 19 | Related to errors with sending actions. 20 | """ 21 | pass 22 | 23 | class UnityTimeOutException(UnityException): 24 | """ 25 | Related to errors with communication timeouts. 
26 | """ 27 | def __init__(self, message, log_file_path = None): 28 | if log_file_path is not None: 29 | try: 30 | with open(log_file_path, "r") as f: 31 | printing = False 32 | unity_error = '\n' 33 | for l in f: 34 | l=l.strip() 35 | if (l == 'Exception') or (l=='Error'): 36 | printing = True 37 | unity_error += '----------------------\n' 38 | if (l == ''): 39 | printing = False 40 | if printing: 41 | unity_error += l + '\n' 42 | logger.info(unity_error) 43 | logger.error("An error might have occured in the environment. " 44 | "You can check the logfile for more information at {}".format(log_file_path)) 45 | except: 46 | logger.error("An error might have occured in the environment. " 47 | "No UnitySDK.log file could be found.") 48 | super(UnityTimeOutException, self).__init__(message) 49 | 50 | 51 | class UnityWorkerInUseException(UnityException): 52 | """ 53 | This error occurs when the port for a certain worker ID is already reserved. 54 | """ 55 | 56 | MESSAGE_TEMPLATE = ( 57 | "Couldn't start socket communication because worker number {} is still in use. " 58 | "You may need to manually close a previously opened environment " 59 | "or use a different worker number.") 60 | 61 | def __init__(self, worker_id): 62 | message = self.MESSAGE_TEMPLATE.format(str(worker_id)) 63 | super(UnityWorkerInUseException, self).__init__(message) 64 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/unity_to_external_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: mlagents/envs/communicator_objects/unity_to_external.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from mlagents.envs.communicator_objects import unity_message_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/unity_to_external.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 23 | serialized_pb=_b('\n:mlagents/envs/communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents/envs/communicator_objects/unity_message.proto2g\n\x0fUnityToExternal\x12T\n\x08\x45xchange\x12\".communicator_objects.UnityMessage\x1a\".communicator_objects.UnityMessage\"\x00\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | , 25 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 30 | 31 | 32 | DESCRIPTOR._options = None 33 | 34 | _UNITYTOEXTERNAL = _descriptor.ServiceDescriptor( 35 | name='UnityToExternal', 36 | full_name='communicator_objects.UnityToExternal', 37 | file=DESCRIPTOR, 38 | index=0, 39 | serialized_options=None, 40 | serialized_start=140, 41 | serialized_end=243, 42 | methods=[ 43 | _descriptor.MethodDescriptor( 44 | name='Exchange', 45 | full_name='communicator_objects.UnityToExternal.Exchange', 46 | 
index=0, 47 | containing_service=None, 48 | input_type=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 49 | output_type=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 50 | serialized_options=None, 51 | ), 52 | ]) 53 | _sym_db.RegisterServiceDescriptor(_UNITYTOEXTERNAL) 54 | 55 | DESCRIPTOR.services_by_name['UnityToExternal'] = _UNITYTOEXTERNAL 56 | 57 | # @@protoc_insertion_point(module_scope) 58 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/space_type_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: mlagents/envs/communicator_objects/space_type_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | from mlagents.envs.communicator_objects import resolution_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='mlagents/envs/communicator_objects/space_type_proto.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n9mlagents/envs/communicator_objects/space_type_proto.proto\x12\x14\x63ommunicator_objects\x1a\x39mlagents/envs/communicator_objects/resolution_proto.proto*.\n\x0eSpaceTypeProto\x12\x0c\n\x08\x64iscrete\x10\x00\x12\x0e\n\ncontinuous\x10\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) 27 | 28 | _SPACETYPEPROTO = _descriptor.EnumDescriptor( 29 | name='SpaceTypeProto', 30 | full_name='communicator_objects.SpaceTypeProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | values=[ 34 | _descriptor.EnumValueDescriptor( 35 | name='discrete', index=0, number=0, 36 | options=None, 37 | type=None), 38 | _descriptor.EnumValueDescriptor( 39 | name='continuous', index=1, number=1, 40 | options=None, 41 | type=None), 42 | ], 43 | containing_type=None, 44 | options=None, 45 | serialized_start=142, 46 | serialized_end=188, 47 | ) 48 | _sym_db.RegisterEnumDescriptor(_SPACETYPEPROTO) 49 | 50 | SpaceTypeProto = enum_type_wrapper.EnumTypeWrapper(_SPACETYPEPROTO) 51 | discrete = 0 52 | continuous = 1 53 | 54 | 55 | DESCRIPTOR.enum_types_by_name['SpaceTypeProto'] = _SPACETYPEPROTO 56 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 57 | 58 | 59 | DESCRIPTOR.has_options = True 60 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 61 | # @@protoc_insertion_point(module_scope) 62 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/bc/offline_trainer.py: -------------------------------------------------------------------------------- 1 | # # Unity ML-Agents Toolkit 2 | # ## ML-Agent Learning (Behavioral Cloning) 3 | # Contains an implementation of Behavioral 
Cloning Algorithm 4 | 5 | import logging 6 | import copy 7 | 8 | from mlagents.trainers.bc.trainer import BCTrainer 9 | from mlagents.trainers.demo_loader import demo_to_buffer 10 | from mlagents.trainers.trainer import UnityTrainerException 11 | 12 | logger = logging.getLogger("mlagents.trainers") 13 | 14 | 15 | class OfflineBCTrainer(BCTrainer): 16 | """The OfflineBCTrainer is an implementation of Offline Behavioral Cloning.""" 17 | 18 | def __init__(self, brain, trainer_parameters, training, load, seed, run_id): 19 | """ 20 | Responsible for collecting experiences and training the behavioral cloning model. 21 | :param trainer_parameters: The parameters for the trainer (dictionary). 22 | :param training: Whether the trainer is set for training. 23 | :param load: Whether the model should be loaded. 24 | :param seed: The seed the model will be initialized with 25 | :param run_id: The identifier of the current run 26 | """ 27 | super(OfflineBCTrainer, self).__init__( 28 | brain, trainer_parameters, training, load, seed, run_id) 29 | 30 | self.param_keys = ['batch_size', 'summary_freq', 'max_steps', 31 | 'batches_per_epoch', 'use_recurrent', 32 | 'hidden_units', 'learning_rate', 'num_layers', 33 | 'sequence_length', 'memory_size', 'model_path', 34 | 'demo_path'] 35 | 36 | self.check_param_keys() 37 | self.batches_per_epoch = trainer_parameters['batches_per_epoch'] 38 | self.n_sequences = max(int(trainer_parameters['batch_size'] / self.policy.sequence_length), 39 | 1) 40 | 41 | brain_params, self.demonstration_buffer = demo_to_buffer( 42 | trainer_parameters['demo_path'], 43 | self.policy.sequence_length) 44 | 45 | policy_brain = copy.deepcopy(brain.__dict__) 46 | expert_brain = copy.deepcopy(brain_params.__dict__) 47 | policy_brain.pop('brain_name') 48 | expert_brain.pop('brain_name') 49 | if expert_brain != policy_brain: 50 | raise UnityTrainerException("The provided demonstration is not compatible with the " 51 | "brain being used for performance evaluation.") 52 | 53 | def __str__(self): 54 | return '''Hyperparameters for the Imitation Trainer of brain {0}: \n{1}'''.format( 55 | self.brain_name, '\n'.join( 56 | ['\t{0}:\t{1}'.format(x, self.trainer_parameters[x]) for x in self.param_keys])) 57 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/unity_rl_initialization_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: mlagents/envs/communicator_objects/unity_rl_initialization_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/unity_rl_initialization_input.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\nFmlagents/envs/communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\"*\n\x1aUnityRLInitializationInput\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _UNITYRLINITIALIZATIONINPUT = _descriptor.Descriptor( 29 | name='UnityRLInitializationInput', 30 | full_name='communicator_objects.UnityRLInitializationInput', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='seed', full_name='communicator_objects.UnityRLInitializationInput.seed', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | ], 43 | extensions=[ 44 | ], 45 | nested_types=[], 46 | enum_types=[ 47 | ], 48 | options=None, 49 | is_extendable=False, 50 | syntax='proto3', 51 | extension_ranges=[], 52 | oneofs=[ 53 | ], 54 | serialized_start=96, 55 | serialized_end=138, 56 | ) 57 | 58 | DESCRIPTOR.message_types_by_name['UnityRLInitializationInput'] = _UNITYRLINITIALIZATIONINPUT 59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 | 61 | UnityRLInitializationInput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationInput', (_message.Message,), dict( 62 | DESCRIPTOR = _UNITYRLINITIALIZATIONINPUT, 63 | __module__ = 'mlagents.envs.communicator_objects.unity_rl_initialization_input_pb2' 64 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationInput) 65 | )) 66 | _sym_db.RegisterMessage(UnityRLInitializationInput) 67 | 68 | 69 | DESCRIPTOR.has_options = True 70 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 71 | # @@protoc_insertion_point(module_scope) 72 | -------------------------------------------------------------------------------- /config/online_bc_config.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | trainer: online_bc 3 | brain_to_imitate: 4 | batch_size: 64 5 | time_horizon: 64 6 | summary_freq: 1000 7 | max_steps: 5.0e4 8 | batches_per_epoch: 10 9 | use_recurrent: false 10 | hidden_units: 128 11 | learning_rate: 3.0e-4 12 | num_layers: 2 13 | sequence_length: 32 14 | memory_size: 256 15 | 16 | BananaLearning: 17 | trainer: online_bc 18 | max_steps: 10000 19 | summary_freq: 1000 20 | brain_to_imitate: BananaPlayer 21 | batch_size: 16 22 | batches_per_epoch: 5 23 | num_layers: 4 24 | hidden_units: 64 25 | use_recurrent: false 26 | sequence_length: 16 27 | 28 | BouncerLearning: 29 | trainer: online_bc 30 | 
max_steps: 10000 31 | summary_freq: 10 32 | brain_to_imitate: BouncerPlayer 33 | batch_size: 16 34 | batches_per_epoch: 1 35 | num_layers: 1 36 | hidden_units: 64 37 | use_recurrent: false 38 | sequence_length: 16 39 | 40 | HallwayLearning: 41 | trainer: online_bc 42 | max_steps: 10000 43 | summary_freq: 1000 44 | brain_to_imitate: HallwayPlayer 45 | batch_size: 16 46 | batches_per_epoch: 5 47 | num_layers: 4 48 | hidden_units: 64 49 | use_recurrent: false 50 | sequence_length: 16 51 | 52 | PushBlockLearning: 53 | trainer: online_bc 54 | max_steps: 10000 55 | summary_freq: 1000 56 | brain_to_imitate: PushBlockPlayer 57 | batch_size: 16 58 | batches_per_epoch: 5 59 | num_layers: 4 60 | hidden_units: 64 61 | use_recurrent: false 62 | sequence_length: 16 63 | 64 | PyramidsLearning: 65 | trainer: online_bc 66 | max_steps: 10000 67 | summary_freq: 1000 68 | brain_to_imitate: PyramidsPlayer 69 | batch_size: 16 70 | batches_per_epoch: 5 71 | num_layers: 4 72 | hidden_units: 64 73 | use_recurrent: false 74 | sequence_length: 16 75 | 76 | TennisLearning: 77 | trainer: online_bc 78 | max_steps: 10000 79 | summary_freq: 1000 80 | brain_to_imitate: TennisPlayer 81 | batch_size: 16 82 | batches_per_epoch: 5 83 | num_layers: 4 84 | hidden_units: 64 85 | use_recurrent: false 86 | sequence_length: 16 87 | 88 | StudentBrain: 89 | trainer: online_bc 90 | max_steps: 10000 91 | summary_freq: 1000 92 | brain_to_imitate: TeacherBrain 93 | batch_size: 16 94 | batches_per_epoch: 5 95 | num_layers: 4 96 | hidden_units: 64 97 | use_recurrent: false 98 | sequence_length: 16 99 | 100 | StudentRecurrentBrain: 101 | trainer: online_bc 102 | max_steps: 10000 103 | summary_freq: 1000 104 | brain_to_imitate: TeacherBrain 105 | batch_size: 16 106 | batches_per_epoch: 5 107 | num_layers: 4 108 | hidden_units: 64 109 | use_recurrent: true 110 | sequence_length: 32 -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/header_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: mlagents/envs/communicator_objects/header.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/header.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n/mlagents/envs/communicator_objects/header.proto\x12\x14\x63ommunicator_objects\")\n\x06Header\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _HEADER = _descriptor.Descriptor( 29 | name='Header', 30 | full_name='communicator_objects.Header', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='status', full_name='communicator_objects.Header.status', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='message', full_name='communicator_objects.Header.message', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=None, 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=73, 62 | serialized_end=114, 63 | ) 64 | 65 | DESCRIPTOR.message_types_by_name['Header'] = _HEADER 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | Header = _reflection.GeneratedProtocolMessageType('Header', (_message.Message,), dict( 69 | DESCRIPTOR = _HEADER, 70 | __module__ = 'mlagents.envs.communicator_objects.header_pb2' 71 | # @@protoc_insertion_point(class_scope:communicator_objects.Header) 72 | )) 73 | _sym_db.RegisterMessage(Header) 74 | 75 | 76 | DESCRIPTOR.has_options = True 77 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 78 | # @@protoc_insertion_point(module_scope) 79 | -------------------------------------------------------------------------------- /learn_gym.py: -------------------------------------------------------------------------------- 1 | import json 2 | from gym_unity.envs import UnityEnv 3 | 4 | def extract_camera_config_gym(config_file): 5 | 6 | config = [] 7 | use_visual = True 8 | with open(config_file, 'r') as data_file: 9 | data = json.load(data_file) 10 | params = data['parameters'] 11 | if params['camera1_type'][0] != 0: 12 | config += [{ 13 | "height": params['camera1_res_y'][0], 14 | "width": params['camera1_res_x'][0], 15 | "blackAndWhite": False 16 | }] 17 | if params['camera2_type'][0] != 0: 18 | config += [{ 19 | "height": params['camera2_res_y'][0], 20 | "width": 
params['camera2_res_x'][0], 21 | "blackAndWhite": False 22 | }] 23 | if params['camera3_type'][0] != 0: 24 | config += [{ 25 | "height": params['camera3_res_y'][0], 26 | "width": params['camera3_res_x'][0], 27 | "blackAndWhite": False 28 | }] 29 | 30 | if len(config) == 0: 31 | use_visual = False 32 | 33 | if len(config) > 1: 34 | config = config[0:1] 35 | 36 | return config, use_visual 37 | 38 | def check_config_validity_gym(config): 39 | 40 | if config['camera1_type'] < 0 or config['camera1_type'] > 6: 41 | raise ValueError('camera1_type') 42 | if config['camera2_type'] < 0 or config['camera2_type'] > 6: 43 | raise ValueError('camera2_type') 44 | if config['camera3_type'] < 0 or config['camera3_type'] > 6: 45 | raise ValueError('camera3_type') 46 | 47 | if config['camera1_type'] != 0 and (config['camera1_res_x'] == 0 or config['camera1_res_y'] == 0): 48 | raise ValueError('camera1_res') 49 | if config['camera2_type'] != 0 and (config['camera2_res_x'] == 0 or config['camera2_res_y'] == 0): 50 | raise ValueError('camera2_res') 51 | if config['camera3_type'] != 0 and (config['camera3_res_x'] == 0 or config['camera3_res_y'] == 0): 52 | raise ValueError('camera3_res') 53 | 54 | if config['weather_id'] < 0 or config['weather_id'] > 10: 55 | raise ValueError('weather_id') 56 | if config['time_id'] < 0 or config['time_id'] >= 24: 57 | raise ValueError('time_id') 58 | if config['road_width'] <= 0: 59 | raise ValueError('road_width') 60 | 61 | if (config['camera1_type'] != 0) + (config['camera2_type'] != 0) + (config['camera3_type'] != 0) > 1: 62 | raise ValueError('Gym only support 1 visual observation') 63 | 64 | def get_env_config(curriculum_folder): 65 | 66 | try: 67 | with open(curriculum_folder) as data_file: 68 | data = json.load(data_file) 69 | except IOError: 70 | raise IOError() 71 | 72 | config = {} 73 | parameters = data['parameters'] 74 | for key in parameters: 75 | config[key] = parameters[key][0] 76 | 77 | check_config_validity_gym(config) 78 | 79 | return config 80 | 81 | def main(): 82 | 83 | env_path = 'AutoBenchExecutable/AutoBenchExecutable' 84 | curriculum_file = 'config/curricula/autobench/AutoBenchBrain.json' 85 | camera_res_overwrite, use_visual = extract_camera_config_gym(curriculum_file) 86 | # Setup the Unity Environment 87 | env = UnityEnv(environment_filename=env_path, worker_id=0, use_visual=use_visual, 88 | multiagent=True, env_config=get_env_config(curriculum_file), 89 | camera_res_overwrite=camera_res_overwrite) 90 | 91 | if __name__ == '__main__': 92 | main() -------------------------------------------------------------------------------- /ml-agents/tests/trainers/test_curriculum.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import json 3 | from unittest.mock import patch, mock_open 4 | 5 | from mlagents.trainers.exception import CurriculumError 6 | from mlagents.trainers import Curriculum 7 | 8 | 9 | dummy_curriculum_json_str = ''' 10 | { 11 | "measure" : "reward", 12 | "thresholds" : [10, 20, 50], 13 | "min_lesson_length" : 3, 14 | "signal_smoothing" : true, 15 | "parameters" : 16 | { 17 | "param1" : [0.7, 0.5, 0.3, 0.1], 18 | "param2" : [100, 50, 20, 15], 19 | "param3" : [0.2, 0.3, 0.7, 0.9] 20 | } 21 | } 22 | ''' 23 | 24 | 25 | bad_curriculum_json_str = ''' 26 | { 27 | "measure" : "reward", 28 | "thresholds" : [10, 20, 50], 29 | "min_lesson_length" : 3, 30 | "signal_smoothing" : false, 31 | "parameters" : 32 | { 33 | "param1" : [0.7, 0.5, 0.3, 0.1], 34 | "param2" : [100, 50, 20], 35 | "param3" : 
[0.2, 0.3, 0.7, 0.9] 36 | } 37 | } 38 | ''' 39 | 40 | @pytest.fixture 41 | def location(): 42 | return 'TestBrain.json' 43 | 44 | 45 | @pytest.fixture 46 | def default_reset_parameters(): 47 | return {"param1": 1, "param2": 1, "param3": 1} 48 | 49 | 50 | @patch('builtins.open', new_callable=mock_open, read_data=dummy_curriculum_json_str) 51 | def test_init_curriculum_happy_path(mock_file, location, default_reset_parameters): 52 | curriculum = Curriculum(location, default_reset_parameters) 53 | 54 | assert curriculum._brain_name == 'TestBrain' 55 | assert curriculum.lesson_num == 0 56 | assert curriculum.measure == 'reward' 57 | 58 | 59 | @patch('builtins.open', new_callable=mock_open, read_data=bad_curriculum_json_str) 60 | def test_init_curriculum_bad_curriculum_raises_error(mock_file, location, default_reset_parameters): 61 | with pytest.raises(CurriculumError): 62 | Curriculum(location, default_reset_parameters) 63 | 64 | 65 | @patch('builtins.open', new_callable=mock_open, read_data=dummy_curriculum_json_str) 66 | def test_increment_lesson(mock_file, location, default_reset_parameters): 67 | curriculum = Curriculum(location, default_reset_parameters) 68 | assert curriculum.lesson_num == 0 69 | 70 | curriculum.lesson_num = 1 71 | assert curriculum.lesson_num == 1 72 | 73 | assert not curriculum.increment_lesson(10) 74 | assert curriculum.lesson_num == 1 75 | 76 | assert curriculum.increment_lesson(30) 77 | assert curriculum.lesson_num == 2 78 | 79 | assert not curriculum.increment_lesson(30) 80 | assert curriculum.lesson_num == 2 81 | 82 | assert curriculum.increment_lesson(10000) 83 | assert curriculum.lesson_num == 3 84 | 85 | 86 | @patch('builtins.open', new_callable=mock_open, read_data=dummy_curriculum_json_str) 87 | def test_get_config(mock_file): 88 | curriculum = Curriculum('TestBrain.json', {"param1": 1, "param2": 1, "param3": 1}) 89 | assert curriculum.get_config() == {"param1": 0.7, "param2": 100, "param3": 0.2} 90 | 91 | curriculum.lesson_num = 2 92 | assert curriculum.get_config() == {'param1': 0.3, 'param2': 20, 'param3': 0.7} 93 | assert curriculum.get_config(0) == {"param1": 0.7, "param2": 100, "param3": 0.2} 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoBench - Autonomous vehicle training and benchmark environment with configurable difficulty 2 | 3 | 4 | [![AutoBench Video](docs/images/thumbnail.jpg)](https://www.youtube.com/watch?v=Ptg1hnLxy9U) 5 | 6 | **AutoBench** is an open-source project based on the [Unity ML-Agents Toolkit](https://github.com/Unity-Technologies/ml-agents) featuring high configurability, including difficulty, rewards, weather conditions, and visual observation types. It uses the **REAL** driving license exam in Taiwan as an example to showcase the applicability of the reinforcement learning approach, with the configurable difficulty technique, to autonomous vehicles. 7 | 8 | The environment supports: 9 | * General Machine Learning - letting users implement their own decision algorithm within the out-of-the-box standard training loop.
10 | * Reinforcement Learning - Contains the entire RL training loop, based in part on the ML-Agents trainer, with the built-in [Proximal Policy Optimization](https://arxiv.org/abs/1707.06347) algorithm and recurrent support using [DRQN](https://arxiv.org/abs/1507.06527) 11 | 12 | * OpenAI Gym Compatible - Based on the Gym-Unity wrapper, which supports the OpenAI Gym interface for further integration, meaning users are able to test any Gym-compatible algorithm within AutoBench. [Limitation](docs/Training-Process.md#openai-gym-compitable) 13 | 14 | * Benchmark Mode - Complements every training mode above; ```BenchmarkManager``` is a static class handling benchmark tracking and analysis. Sample code can be found in ```learn_ml.py``` and ```learn_rl.py``` 15 | 16 | 17 | ## Features 18 | * Road Geometry based on the real license exam. 19 | * Supports Reinforcement Learning, General Machine Learning, and the OpenAI Gym interface. [(learn more)](docs/Training-Process.md) 20 | * Configurable Difficulty through Road Width, Visual Details, and Image Types. [(learn more)](docs/Training-and-Environment-Configuration.md) 21 | * Support for various Visual Observation Types. [(learn more)](docs/ML-ImageSynthesis.md) 22 | * WASD-controlled Observation Camera in inference mode 23 | * Configurable Weather Conditions and Time of Day. [(learn more)](docs/Training-and-Environment-Configuration.md#weather-and-time) 24 | * Configurable Rewards for reward shaping [(learn more)](docs/AgentInfos-Obs-Action-Reward.md#rewards) 25 | * Out-of-the-box Benchmark system [(learn more)](docs/Benchmark.md) 26 | * A pretrained model [(learn more)](docs/Pretrain-Model-Details.md) 27 | 28 | (future work) 29 | * More configurable road specs (curvature, bumpiness) 30 | * More tasks (parallel parking, reverse parking) 31 | * Configurable number of agents 32 | 33 | ## -> Tutorial and Documentation <- 34 | More info about the [Tutorial and Documentation](docs/) 35 | 36 | ## References 37 | - [Unity ML-Agents Repository](https://github.com/Unity-Technologies/ml-agents) 38 | - [ML-ImageSynthesis Repository](https://bitbucket.org/Unity-Technologies/ml-imagesynthesis) 39 | - [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) 40 | - [Deep Recurrent Q-Learning for Partially Observable MDPs](https://arxiv.org/abs/1507.06527) 41 | 42 | ## Citation 43 | Paper Citation (TBA) 44 | 45 | ## License 46 | [Apache License 2.0](LICENSE) 47 | 48 | ## Buy me a coffee 49 | Bitcoin: 1ErZXAEoQVzFSkarXKpxfTRzYrp9SALVp2
50 | Ethereum: 0x312ADcc92c3ff549001ea4437A767c512C9546E3 51 | 52 | ## Feedback 53 | Feel free to give us feedback using the "Issue" section 54 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/resolution_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: mlagents/envs/communicator_objects/resolution_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/resolution_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n9mlagents/envs/communicator_objects/resolution_proto.proto\x12\x14\x63ommunicator_objects\"D\n\x0fResolutionProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x12\n\ngray_scale\x18\x03 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _RESOLUTIONPROTO = _descriptor.Descriptor( 29 | name='ResolutionProto', 30 | full_name='communicator_objects.ResolutionProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='width', full_name='communicator_objects.ResolutionProto.width', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='communicator_objects.ResolutionProto.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='gray_scale', full_name='communicator_objects.ResolutionProto.gray_scale', index=2, 51 | number=3, type=8, cpp_type=7, label=1, 52 | has_default_value=False, default_value=False, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto3', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=83, 69 | serialized_end=151, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['ResolutionProto'] = _RESOLUTIONPROTO 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | ResolutionProto = _reflection.GeneratedProtocolMessageType('ResolutionProto', (_message.Message,), dict( 76 | DESCRIPTOR = _RESOLUTIONPROTO, 77 | __module__ = 'mlagents.envs.communicator_objects.resolution_proto_pb2' 78 | # @@protoc_insertion_point(class_scope:communicator_objects.ResolutionProto) 79 | )) 80 | 
_sym_db.RegisterMessage(ResolutionProto) 81 | 82 | 83 | DESCRIPTOR.has_options = True 84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /docs/Setup-Configuration-Files.md: -------------------------------------------------------------------------------- 1 | # Setup Configuration Files 2 | 3 | ### Training Config 4 | Located in config/trainer_config.yaml
5 | This file configures the reinforcement learning (PPO) trainer 6 | ``` 7 | AutoBenchBrain: 8 | batch_size: 1024 9 | beta: 1.0e-1 10 | buffer_size: 1024 11 | epsilon: 0.2 12 | gamma: 0.99 13 | hidden_units: 128 14 | lambd: 0.95 15 | learning_rate: 3.0e-4 16 | max_steps: 1.0e7 17 | memory_size: 256 18 | normalize: true 19 | num_epoch: 5 20 | num_layers: 2 21 | time_horizon: 512 22 | sequence_length: 64 23 | summary_freq: 3000 24 | use_recurrent: true 25 | ``` 26 | ### Environment Config 27 | Located in config/curricula/autobench/AutoBenchBrain.json 28 | This file configures the Unity environment 29 | ``` 30 | { 31 | "measure": "progress", #Ignore 32 | "thresholds": [], #Ignore 33 | "min_lesson_length": 0, #Ignore 34 | "signal_smoothing": false, #Ignore 35 | "parameters": { 36 | "camera1_type": [0], 37 | "camera2_type": [3], 38 | "camera3_type": [0], 39 | "camera1_res_x": [0], 40 | "camera2_res_x": [50], 41 | "camera3_res_x": [0], 42 | "camera1_res_y": [0], 43 | "camera2_res_y": [50], 44 | "camera3_res_y": [0], 45 | "weather_id": [1], 46 | "time_id": [9], 47 | "road_width": [7], 48 | "forward": [true], 49 | "detail": [false], 50 | "goal_reward": [500], 51 | "time_penalty": [-1], 52 | "collision_penalty": [-300], 53 | "position_reward": [300], 54 | "velocity_reward": [1] 55 | } 56 | } 57 | ``` 58 | You only need to modify the ```parameters``` section 59 | 60 | ### Python Script 61 | Located in learn_rl.py, learn_ml.py, learn_gym.py
62 | The following uses learn_rl.py as an example 63 | ``` 64 | env_path = 'AutoBenchExecutable/AutoBenchExecutable' #Default executable path 65 | run_id = '1' 66 | load_model = False 67 | train_model = True 68 | save_freq = 10000 69 | keep_checkpoints = 1000 70 | worker_id = 0 71 | run_seed = 0 72 | curriculum_folder = 'config/curricula/autobench/' 73 | curriculum_file = 'config/curricula/autobench/AutoBenchBrain.json' 74 | lesson = 0 75 | fast_simulation = True 76 | no_graphics = False 77 | trainer_config_path = 'config/trainer_config.yaml' 78 | benchmark = False 79 | benchmark_episode = 100 80 | benchmark_verbose = True 81 | ``` 82 | #### Env_path 83 | Path of the Unity executable 84 | #### Run_id 85 | Identifier for each run, useful when fine-tuning parameters 86 | #### Load_model 87 | Whether to load the TensorFlow model 88 | #### Train_model 89 | Whether to train the TensorFlow model 90 | #### Save_freq 91 | How often the TensorFlow model is saved 92 | #### Keep_checkpoints 93 | Maximum number of checkpoints to keep 94 | #### Worker_id 95 | Ignore and leave set to 0 96 | #### Run_seed 97 | Random seed of the Unity executable 98 | #### Curriculum_folder 99 | Folder containing the environment config file 100 | #### Curriculum_file 101 | Location of the environment config file 102 | #### Lesson 103 | Ignore and leave set to 0 104 | #### Fast_simulation 105 | If set to True: small window, 100x time scale, 10 agents
106 | If set to False: large window, 1x time scale, 1 agent, and a WASD-controlled observation camera 107 | #### No_graphics 108 | Whether to hide the Unity environment window 109 | #### Trainer_config_path 110 | Location of the trainer config file 111 | #### Benchmark 112 | Whether to benchmark the current model 113 | #### Benchmark_episode 114 | Number of episodes used for benchmarking 115 | #### Benchmark_verbose 116 | Whether to print episode information when an episode ends 117 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/bc/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as c_layers 3 | from mlagents.trainers.models import LearningModel 4 | 5 | 6 | class BehavioralCloningModel(LearningModel): 7 | def __init__(self, brain, h_size=128, lr=1e-4, n_layers=2, m_size=128, 8 | normalize=False, use_recurrent=False, seed=0): 9 | LearningModel.__init__(self, m_size, normalize, use_recurrent, brain, seed) 10 | num_streams = 1 11 | hidden_streams = self.create_observation_streams(num_streams, h_size, n_layers) 12 | hidden = hidden_streams[0] 13 | self.dropout_rate = tf.placeholder(dtype=tf.float32, shape=[], name="dropout_rate") 14 | hidden_reg = tf.layers.dropout(hidden, self.dropout_rate) 15 | if self.use_recurrent: 16 | tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32) 17 | self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in') 18 | hidden_reg, self.memory_out = self.create_recurrent_encoder(hidden_reg, self.memory_in, 19 | self.sequence_length) 20 | self.memory_out = tf.identity(self.memory_out, name='recurrent_out') 21 | 22 | if brain.vector_action_space_type == "discrete": 23 | policy_branches = [] 24 | for size in self.act_size: 25 | policy_branches.append( 26 | tf.layers.dense( 27 | hidden, 28 | size, 29 | activation=None, 30 | use_bias=False, 31 | kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01))) 32 | self.action_probs = tf.concat( 33 | [tf.nn.softmax(branch) for branch in policy_branches], axis=1, name="action_probs") 34 | self.action_masks = tf.placeholder(shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks") 35 | self.sample_action_float, normalized_logits = self.create_discrete_action_masking_layer( 36 | tf.concat(policy_branches, axis=1), self.action_masks, self.act_size) 37 | tf.identity(normalized_logits, name='action') 38 | self.sample_action = tf.cast(self.sample_action_float, tf.int32) 39 | self.true_action = tf.placeholder(shape=[None, len(policy_branches)], dtype=tf.int32, name="teacher_action") 40 | self.action_oh = tf.concat([ 41 | tf.one_hot(self.true_action[:, i], self.act_size[i]) for i in range(len(self.act_size))], axis=1) 42 | self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh) 43 | self.action_percent = tf.reduce_mean(tf.cast( 44 | tf.equal(tf.cast(tf.argmax(self.action_probs, axis=1), tf.int32), self.sample_action), tf.float32)) 45 | else: 46 | self.policy = tf.layers.dense(hidden_reg, self.act_size[0], activation=None, use_bias=False, name='pre_action', 47 | kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01)) 48 | self.clipped_sample_action = tf.clip_by_value(self.policy, -1, 1) 49 | self.sample_action = tf.identity(self.clipped_sample_action, name="action") 50 | self.true_action = tf.placeholder(shape=[None, self.act_size[0]],
dtype=tf.float32, name="teacher_action") 51 | self.clipped_true_action = tf.clip_by_value(self.true_action, -1, 1) 52 | self.loss = tf.reduce_sum(tf.squared_difference(self.clipped_true_action, self.sample_action)) 53 | 54 | optimizer = tf.train.AdamOptimizer(learning_rate=lr) 55 | self.update = optimizer.minimize(self.loss) 56 | -------------------------------------------------------------------------------- /ml-agents/tests/mock_communicator.py: -------------------------------------------------------------------------------- 1 | from mlagents.envs.communicator import Communicator 2 | from mlagents.envs.communicator_objects import UnityMessage, UnityOutput, UnityInput, \ 3 | ResolutionProto, BrainParametersProto, UnityRLInitializationOutput, \ 4 | AgentInfoProto, UnityRLOutput 5 | 6 | 7 | class MockCommunicator(Communicator): 8 | def __init__(self, discrete_action=False, visual_inputs=0, stack=True, num_agents=3, 9 | brain_name="RealFakeBrain", vec_obs_size=3): 10 | """ 11 | Python side of the grpc communication. Python is the client and Unity the server 12 | 13 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 14 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 15 | """ 16 | self.is_discrete = discrete_action 17 | self.steps = 0 18 | self.visual_inputs = visual_inputs 19 | self.has_been_closed = False 20 | self.num_agents = num_agents 21 | self.brain_name = brain_name 22 | self.vec_obs_size = vec_obs_size 23 | if stack: 24 | self.num_stacks = 2 25 | else: 26 | self.num_stacks = 1 27 | 28 | def initialize(self, inputs: UnityInput) -> UnityOutput: 29 | resolutions = [ResolutionProto( 30 | width=30, 31 | height=40, 32 | gray_scale=False) for i in range(self.visual_inputs)] 33 | bp = BrainParametersProto( 34 | vector_observation_size=self.vec_obs_size, 35 | num_stacked_vector_observations=self.num_stacks, 36 | vector_action_size=[2], 37 | camera_resolutions=resolutions, 38 | vector_action_descriptions=["", ""], 39 | vector_action_space_type=int(not self.is_discrete), 40 | brain_name=self.brain_name, 41 | is_training=True 42 | ) 43 | rl_init = UnityRLInitializationOutput( 44 | name="RealFakeAcademy", 45 | version="API-6", 46 | log_path="", 47 | brain_parameters=[bp] 48 | ) 49 | return UnityOutput( 50 | rl_initialization_output=rl_init 51 | ) 52 | 53 | def exchange(self, inputs: UnityInput) -> UnityOutput: 54 | dict_agent_info = {} 55 | if self.is_discrete: 56 | vector_action = [1] 57 | else: 58 | vector_action = [1, 2] 59 | list_agent_info = [] 60 | if self.num_stacks == 1: 61 | observation = [1, 2, 3] 62 | else: 63 | observation = [1, 2, 3, 1, 2, 3] 64 | 65 | for i in range(self.num_agents): 66 | list_agent_info.append( 67 | AgentInfoProto( 68 | stacked_vector_observation=observation, 69 | reward=1, 70 | stored_vector_actions=vector_action, 71 | stored_text_actions="", 72 | text_observation="", 73 | memories=[], 74 | done=(i == 2), 75 | max_step_reached=False, 76 | id=i 77 | )) 78 | dict_agent_info["RealFakeBrain"] = \ 79 | UnityRLOutput.ListAgentInfoProto(value=list_agent_info) 80 | global_done = False 81 | try: 82 | fake_brain = inputs.rl_input.agent_actions["RealFakeBrain"] 83 | global_done = (fake_brain.value[0].vector_actions[0] == -1) 84 | except: 85 | pass 86 | result = UnityRLOutput( 87 | global_done=global_done, 88 | agentInfos=dict_agent_info 89 | ) 90 | return UnityOutput( 91 | rl_output=result 92 | ) 93 | 94 | def close(self): 95 | """ 96 | Sends a shutdown signal to 
the unity environment, and closes the grpc connection. 97 | """ 98 | self.has_been_closed = True 99 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/agent_action_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: mlagents/envs/communicator_objects/agent_action_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/agent_action_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n;mlagents/envs/communicator_objects/agent_action_proto.proto\x12\x14\x63ommunicator_objects\"a\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x12\r\n\x05value\x18\x04 \x01(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _AGENTACTIONPROTO = _descriptor.Descriptor( 29 | name='AgentActionProto', 30 | full_name='communicator_objects.AgentActionProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='vector_actions', full_name='communicator_objects.AgentActionProto.vector_actions', index=0, 37 | number=1, type=2, cpp_type=6, label=3, 38 | has_default_value=False, default_value=[], 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='text_actions', full_name='communicator_objects.AgentActionProto.text_actions', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='memories', full_name='communicator_objects.AgentActionProto.memories', index=2, 51 | number=3, type=2, cpp_type=6, label=3, 52 | has_default_value=False, default_value=[], 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | name='value', full_name='communicator_objects.AgentActionProto.value', index=3, 58 | number=4, type=2, cpp_type=6, label=1, 59 | has_default_value=False, default_value=float(0), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None, file=DESCRIPTOR), 63 | ], 64 | extensions=[ 65 | ], 66 | nested_types=[], 67 | enum_types=[ 68 | ], 69 | options=None, 70 | is_extendable=False, 71 | syntax='proto3', 72 | extension_ranges=[], 73 | oneofs=[ 74 | ], 75 | serialized_start=85, 76 | serialized_end=182, 77 | ) 78 | 79 | 
DESCRIPTOR.message_types_by_name['AgentActionProto'] = _AGENTACTIONPROTO 80 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 81 | 82 | AgentActionProto = _reflection.GeneratedProtocolMessageType('AgentActionProto', (_message.Message,), dict( 83 | DESCRIPTOR = _AGENTACTIONPROTO, 84 | __module__ = 'mlagents.envs.communicator_objects.agent_action_proto_pb2' 85 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentActionProto) 86 | )) 87 | _sym_db.RegisterMessage(AgentActionProto) 88 | 89 | 90 | DESCRIPTOR.has_options = True 91 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 92 | # @@protoc_insertion_point(module_scope) 93 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/unity_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: mlagents/envs/communicator_objects/unity_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import unity_rl_input_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__input__pb2 17 | from mlagents.envs.communicator_objects import unity_rl_initialization_input_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__input__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='mlagents/envs/communicator_objects/unity_input.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n4mlagents/envs/communicator_objects/unity_input.proto\x12\x14\x63ommunicator_objects\x1a\x37mlagents/envs/communicator_objects/unity_rl_input.proto\x1a\x46mlagents/envs/communicator_objects/unity_rl_initialization_input.proto\"\x95\x01\n\nUnityInput\x12\x34\n\x08rl_input\x18\x01 \x01(\x0b\x32\".communicator_objects.UnityRLInput\x12Q\n\x17rl_initialization_input\x18\x02 \x01(\x0b\x32\x30.communicator_objects.UnityRLInitializationInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__input__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__input__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYINPUT = _descriptor.Descriptor( 32 | name='UnityInput', 33 | full_name='communicator_objects.UnityInput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='rl_input', full_name='communicator_objects.UnityInput.rl_input', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='rl_initialization_input', full_name='communicator_objects.UnityInput.rl_initialization_input', index=1, 47 | number=2, type=11, cpp_type=10, 
label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | options=None, 59 | is_extendable=False, 60 | syntax='proto3', 61 | extension_ranges=[], 62 | oneofs=[ 63 | ], 64 | serialized_start=208, 65 | serialized_end=357, 66 | ) 67 | 68 | _UNITYINPUT.fields_by_name['rl_input'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__input__pb2._UNITYRLINPUT 69 | _UNITYINPUT.fields_by_name['rl_initialization_input'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__input__pb2._UNITYRLINITIALIZATIONINPUT 70 | DESCRIPTOR.message_types_by_name['UnityInput'] = _UNITYINPUT 71 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 72 | 73 | UnityInput = _reflection.GeneratedProtocolMessageType('UnityInput', (_message.Message,), dict( 74 | DESCRIPTOR = _UNITYINPUT, 75 | __module__ = 'mlagents.envs.communicator_objects.unity_input_pb2' 76 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityInput) 77 | )) 78 | _sym_db.RegisterMessage(UnityInput) 79 | 80 | 81 | DESCRIPTOR.has_options = True 82 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 83 | # @@protoc_insertion_point(module_scope) 84 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/unity_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: mlagents/envs/communicator_objects/unity_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import unity_rl_output_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__output__pb2 17 | from mlagents.envs.communicator_objects import unity_rl_initialization_output_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__output__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='mlagents/envs/communicator_objects/unity_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n5mlagents/envs/communicator_objects/unity_output.proto\x12\x14\x63ommunicator_objects\x1a\x38mlagents/envs/communicator_objects/unity_rl_output.proto\x1aGmlagents/envs/communicator_objects/unity_rl_initialization_output.proto\"\x9a\x01\n\x0bUnityOutput\x12\x36\n\trl_output\x18\x01 \x01(\x0b\x32#.communicator_objects.UnityRLOutput\x12S\n\x18rl_initialization_output\x18\x02 \x01(\x0b\x32\x31.communicator_objects.UnityRLInitializationOutputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__output__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__output__pb2.DESCRIPTOR,]) 27 | 
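# NOTE (hand-written annotation in generated code): a hedged sketch of how this
# wrapper message is composed; CopyFrom mirrors the pattern the communicators in
# mlagents.envs use for submessages (illustrative only):
#
#     from mlagents.envs.communicator_objects.unity_rl_output_pb2 import UnityRLOutput
#     out = UnityOutput()   # the class exists once the registration code below has run
#     out.rl_output.CopyFrom(UnityRLOutput())
#     assert out.HasField('rl_output')
#     assert not out.HasField('rl_initialization_output')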
28 | 29 | 30 | 31 | _UNITYOUTPUT = _descriptor.Descriptor( 32 | name='UnityOutput', 33 | full_name='communicator_objects.UnityOutput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='rl_output', full_name='communicator_objects.UnityOutput.rl_output', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='rl_initialization_output', full_name='communicator_objects.UnityOutput.rl_initialization_output', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | options=None, 59 | is_extendable=False, 60 | syntax='proto3', 61 | extension_ranges=[], 62 | oneofs=[ 63 | ], 64 | serialized_start=211, 65 | serialized_end=365, 66 | ) 67 | 68 | _UNITYOUTPUT.fields_by_name['rl_output'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__output__pb2._UNITYRLOUTPUT 69 | _UNITYOUTPUT.fields_by_name['rl_initialization_output'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__output__pb2._UNITYRLINITIALIZATIONOUTPUT 70 | DESCRIPTOR.message_types_by_name['UnityOutput'] = _UNITYOUTPUT 71 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 72 | 73 | UnityOutput = _reflection.GeneratedProtocolMessageType('UnityOutput', (_message.Message,), dict( 74 | DESCRIPTOR = _UNITYOUTPUT, 75 | __module__ = 'mlagents.envs.communicator_objects.unity_output_pb2' 76 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityOutput) 77 | )) 78 | _sym_db.RegisterMessage(UnityOutput) 79 | 80 | 81 | DESCRIPTOR.has_options = True 82 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 83 | # @@protoc_insertion_point(module_scope) 84 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/bc/policy.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from mlagents.trainers.bc.models import BehavioralCloningModel 5 | from mlagents.trainers.policy import Policy 6 | 7 | logger = logging.getLogger("mlagents.trainers") 8 | 9 | 10 | class BCPolicy(Policy): 11 | def __init__(self, seed, brain, trainer_parameters, load): 12 | """ 13 | :param seed: Random seed. 14 | :param brain: Assigned Brain object. 15 | :param trainer_parameters: Defined training parameters. 16 | :param load: Whether a pre-trained model will be loaded or a new one created. 
17 | """ 18 | super(BCPolicy, self).__init__(seed, brain, trainer_parameters) 19 | 20 | with self.graph.as_default(): 21 | with self.graph.as_default(): 22 | self.model = BehavioralCloningModel( 23 | h_size=int(trainer_parameters['hidden_units']), 24 | lr=float(trainer_parameters['learning_rate']), 25 | n_layers=int(trainer_parameters['num_layers']), 26 | m_size=self.m_size, 27 | normalize=False, 28 | use_recurrent=trainer_parameters['use_recurrent'], 29 | brain=brain, 30 | seed=seed) 31 | 32 | if load: 33 | self._load_graph() 34 | else: 35 | self._initialize_graph() 36 | 37 | self.inference_dict = {'action': self.model.sample_action} 38 | self.update_dict = {'policy_loss': self.model.loss, 39 | 'update_batch': self.model.update} 40 | if self.use_recurrent: 41 | self.inference_dict['memory_out'] = self.model.memory_out 42 | 43 | self.evaluate_rate = 1.0 44 | self.update_rate = 0.5 45 | 46 | def evaluate(self, brain_info): 47 | """ 48 | Evaluates policy for the agent experiences provided. 49 | :param brain_info: BrainInfo input to network. 50 | :return: Results of evaluation. 51 | """ 52 | feed_dict = {self.model.dropout_rate: self.evaluate_rate, 53 | self.model.sequence_length: 1} 54 | 55 | feed_dict = self._fill_eval_dict(feed_dict, brain_info) 56 | if self.use_recurrent: 57 | if brain_info.memories.shape[1] == 0: 58 | brain_info.memories = self.make_empty_memory(len(brain_info.agents)) 59 | feed_dict[self.model.memory_in] = brain_info.memories 60 | run_out = self._execute_model(feed_dict, self.inference_dict) 61 | return run_out 62 | 63 | def update(self, mini_batch, num_sequences): 64 | """ 65 | Performs update on model. 66 | :param mini_batch: Batch of experiences. 67 | :param num_sequences: Number of sequences to process. 68 | :return: Results of update. 69 | """ 70 | 71 | feed_dict = {self.model.dropout_rate: self.update_rate, 72 | self.model.batch_size: num_sequences, 73 | self.model.sequence_length: self.sequence_length} 74 | if self.use_continuous_act: 75 | feed_dict[self.model.true_action] = mini_batch['actions']. 
\ 76 | reshape([-1, self.brain.vector_action_space_size[0]]) 77 | else: 78 | feed_dict[self.model.true_action] = mini_batch['actions'].reshape( 79 | [-1, len(self.brain.vector_action_space_size)]) 80 | feed_dict[self.model.action_masks] = np.ones( 81 | (num_sequences, sum(self.brain.vector_action_space_size))) 82 | if self.use_vec_obs: 83 | apparent_obs_size = self.brain.vector_observation_space_size * \ 84 | self.brain.num_stacked_vector_observations 85 | feed_dict[self.model.vector_in] = mini_batch['vector_obs'] \ 86 | .reshape([-1,apparent_obs_size]) 87 | for i, _ in enumerate(self.model.visual_in): 88 | visual_obs = mini_batch['visual_obs%d' % i] 89 | feed_dict[self.model.visual_in[i]] = visual_obs 90 | if self.use_recurrent: 91 | feed_dict[self.model.memory_in] = np.zeros([num_sequences, self.m_size]) 92 | run_out = self._execute_model(feed_dict, self.update_dict) 93 | return run_out 94 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/demo_loader.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import logging 3 | import os 4 | from mlagents.trainers.buffer import Buffer 5 | from mlagents.envs.brain import BrainParameters, BrainInfo 6 | from mlagents.envs.communicator_objects import * 7 | from google.protobuf.internal.decoder import _DecodeVarint32 8 | 9 | logger = logging.getLogger("mlagents.trainers") 10 | 11 | 12 | def make_demo_buffer(brain_infos, brain_params, sequence_length): 13 | # Create and populate buffer using experiences 14 | demo_buffer = Buffer() 15 | for idx, experience in enumerate(brain_infos): 16 | if idx > len(brain_infos) - 2: 17 | break 18 | current_brain_info = brain_infos[idx] 19 | next_brain_info = brain_infos[idx + 1] 20 | demo_buffer[0].last_brain_info = current_brain_info 21 | demo_buffer[0]['done'].append(next_brain_info.local_done[0]) 22 | demo_buffer[0]['rewards'].append(next_brain_info.rewards[0]) 23 | for i in range(brain_params.number_visual_observations): 24 | demo_buffer[0]['visual_obs%d' % i] \ 25 | .append(current_brain_info.visual_observations[i][0]) 26 | if brain_params.vector_observation_space_size > 0: 27 | demo_buffer[0]['vector_obs'] \ 28 | .append(current_brain_info.vector_observations[0]) 29 | demo_buffer[0]['actions'].append(next_brain_info.previous_vector_actions[0]) 30 | if next_brain_info.local_done[0]: 31 | demo_buffer.append_update_buffer(0, batch_size=None, 32 | training_length=sequence_length) 33 | demo_buffer.reset_local_buffers() 34 | demo_buffer.append_update_buffer(0, batch_size=None, 35 | training_length=sequence_length) 36 | return demo_buffer 37 | 38 | 39 | def demo_to_buffer(file_path, sequence_length): 40 | """ 41 | Loads demonstration file and uses it to fill training buffer. 42 | :param file_path: Location of demonstration file (.demo). 43 | :param sequence_length: Length of trajectories to fill buffer. 44 | :return: 45 | """ 46 | brain_params, brain_infos, _ = load_demonstration(file_path) 47 | demo_buffer = make_demo_buffer(brain_infos, brain_params, sequence_length) 48 | return brain_params, demo_buffer 49 | 50 | 51 | def load_demonstration(file_path): 52 | """ 53 | Loads and parses a demonstration file. 54 | :param file_path: Location of demonstration file (.demo). 55 | :return: BrainParameter and list of BrainInfos containing demonstration data. 56 | """ 57 | 58 | # First 32 bytes of file dedicated to meta-data. 
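# The remainder of the file is a stream of varint-length-prefixed protobuf
# blobs: one DemonstrationMetaProto, then one BrainParametersProto, then one
# AgentInfoProto per recorded step (see the decode loop below). INITIAL_POS
# jumps past the fixed-size meta-data block to where the brain parameters begin.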
59 | INITIAL_POS = 33 60 | 61 | if not os.path.isfile(file_path): 62 | raise FileNotFoundError("The demonstration file {} does not exist.".format(file_path)) 63 | file_extension = pathlib.Path(file_path).suffix 64 | if file_extension != '.demo': 65 | raise ValueError("The file is not a '.demo' file. Please provide a file with the " 66 | "correct extension.") 67 | 68 | brain_params = None 69 | brain_infos = [] 70 | data = open(file_path, "rb").read() 71 | next_pos, pos, obs_decoded = 0, 0, 0 72 | total_expected = 0 73 | while pos < len(data): 74 | next_pos, pos = _DecodeVarint32(data, pos) 75 | if obs_decoded == 0: 76 | meta_data_proto = DemonstrationMetaProto() 77 | meta_data_proto.ParseFromString(data[pos:pos + next_pos]) 78 | total_expected = meta_data_proto.number_steps 79 | pos = INITIAL_POS 80 | if obs_decoded == 1: 81 | brain_param_proto = BrainParametersProto() 82 | brain_param_proto.ParseFromString(data[pos:pos + next_pos]) 83 | brain_params = BrainParameters.from_proto(brain_param_proto) 84 | pos += next_pos 85 | if obs_decoded > 1: 86 | agent_info = AgentInfoProto() 87 | agent_info.ParseFromString(data[pos:pos + next_pos]) 88 | brain_info = BrainInfo.from_agent_proto([agent_info], brain_params) 89 | brain_infos.append(brain_info) 90 | if len(brain_infos) == total_expected: 91 | break 92 | pos += next_pos 93 | obs_decoded += 1 94 | return brain_params, brain_infos, total_expected 95 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/socket_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import socket 3 | import struct 4 | 5 | from .communicator import Communicator 6 | from .communicator_objects import UnityMessage, UnityOutput, UnityInput 7 | from .exception import UnityTimeOutException 8 | 9 | 10 | logger = logging.getLogger("mlagents.envs") 11 | 12 | 13 | class SocketCommunicator(Communicator): 14 | def __init__(self, worker_id=0, 15 | base_port=5005): 16 | """ 17 | Python side of the socket communication 18 | 19 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 20 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 21 | """ 22 | 23 | self.port = base_port + worker_id 24 | self._buffer_size = 12000 25 | self.worker_id = worker_id 26 | self._socket = None 27 | self._conn = None 28 | 29 | def initialize(self, inputs: UnityInput) -> UnityOutput: 30 | try: 31 | # Establish communication socket 32 | self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 33 | self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 34 | self._socket.bind(("localhost", self.port)) 35 | except: 36 | raise UnityTimeOutException("Couldn't start socket communication because worker number {} is still in use. " 37 | "You may need to manually close a previously opened environment " 38 | "or use a different worker number.".format(str(self.worker_id))) 39 | try: 40 | self._socket.settimeout(30) 41 | self._socket.listen(1) 42 | self._conn, _ = self._socket.accept() 43 | self._conn.settimeout(30) 44 | except : 45 | raise UnityTimeOutException( 46 | "The Unity environment took too long to respond. 
Make sure that :\n" 47 | "\t The environment does not need user interaction to launch\n" 48 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" 49 | "\t The environment and the Python interface have compatible versions.") 50 | message = UnityMessage() 51 | message.header.status = 200 52 | message.unity_input.CopyFrom(inputs) 53 | self._communicator_send(message.SerializeToString()) 54 | initialization_output = UnityMessage() 55 | initialization_output.ParseFromString(self._communicator_receive()) 56 | return initialization_output.unity_output 57 | 58 | def _communicator_receive(self): 59 | try: 60 | s = self._conn.recv(self._buffer_size) 61 | message_length = struct.unpack("I", bytearray(s[:4]))[0] 62 | s = s[4:] 63 | while len(s) != message_length: 64 | s += self._conn.recv(self._buffer_size) 65 | except socket.timeout as e: 66 | raise UnityTimeOutException("The environment took too long to respond.") 67 | return s 68 | 69 | def _communicator_send(self, message): 70 | self._conn.send(struct.pack("I", len(message)) + message) 71 | 72 | def exchange(self, inputs: UnityInput) -> UnityOutput: 73 | message = UnityMessage() 74 | message.header.status = 200 75 | message.unity_input.CopyFrom(inputs) 76 | self._communicator_send(message.SerializeToString()) 77 | outputs = UnityMessage() 78 | outputs.ParseFromString(self._communicator_receive()) 79 | if outputs.header.status != 200: 80 | return None 81 | return outputs.unity_output 82 | 83 | def close(self): 84 | """ 85 | Sends a shutdown signal to the unity environment, and closes the socket connection. 86 | """ 87 | if self._socket is not None and self._conn is not None: 88 | message_input = UnityMessage() 89 | message_input.header.status = 400 90 | self._communicator_send(message_input.SerializeToString()) 91 | if self._socket is not None: 92 | self._socket.close() 93 | self._socket = None 94 | if self._conn is not None: 95 | self._conn.close() 96 | self._conn = None 97 | 98 | -------------------------------------------------------------------------------- /ml-agents/tests/envs/test_envs.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | import pytest 3 | import struct 4 | 5 | import numpy as np 6 | 7 | from mlagents.envs import UnityEnvironment, UnityEnvironmentException, UnityActionException, \ 8 | BrainInfo 9 | from tests.mock_communicator import MockCommunicator 10 | 11 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 12 | def test_handles_bad_filename(get_communicator): 13 | with pytest.raises(UnityEnvironmentException): 14 | UnityEnvironment(' ') 15 | 16 | 17 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 18 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 19 | def test_initialization(mock_communicator, mock_launcher): 20 | mock_communicator.return_value = MockCommunicator( 21 | discrete_action=False, visual_inputs=0) 22 | env = UnityEnvironment(' ') 23 | with pytest.raises(UnityActionException): 24 | env.step([0]) 25 | assert env.brain_names[0] == 'RealFakeBrain' 26 | env.close() 27 | 28 | 29 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 30 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 31 | def test_reset(mock_communicator, mock_launcher): 32 | mock_communicator.return_value = MockCommunicator( 33 | discrete_action=False, visual_inputs=0) 34 | env = UnityEnvironment(' ') 35 | brain = env.brains['RealFakeBrain'] 36 | brain_info = env.reset() 37 
| env.close() 38 | assert not env.global_done 39 | assert isinstance(brain_info, dict) 40 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo) 41 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) 42 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) 43 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations 44 | assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ 45 | len(brain_info['RealFakeBrain'].agents) 46 | assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ 47 | brain.vector_observation_space_size * brain.num_stacked_vector_observations 48 | 49 | 50 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 51 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 52 | def test_step(mock_communicator, mock_launcher): 53 | mock_communicator.return_value = MockCommunicator( 54 | discrete_action=False, visual_inputs=0) 55 | env = UnityEnvironment(' ') 56 | brain = env.brains['RealFakeBrain'] 57 | brain_info = env.reset() 58 | brain_info = env.step([0] * brain.vector_action_space_size[0] * len(brain_info['RealFakeBrain'].agents)) 59 | with pytest.raises(UnityActionException): 60 | env.step([0]) 61 | brain_info = env.step([-1] * brain.vector_action_space_size[0] * len(brain_info['RealFakeBrain'].agents)) 62 | with pytest.raises(UnityActionException): 63 | env.step([0] * brain.vector_action_space_size[0] * len(brain_info['RealFakeBrain'].agents)) 64 | env.close() 65 | assert env.global_done 66 | assert isinstance(brain_info, dict) 67 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo) 68 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) 69 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) 70 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations 71 | assert brain_info['RealFakeBrain'].vector_observations.shape[0] == \ 72 | len(brain_info['RealFakeBrain'].agents) 73 | assert brain_info['RealFakeBrain'].vector_observations.shape[1] == \ 74 | brain.vector_observation_space_size * brain.num_stacked_vector_observations 75 | 76 | print("\n\n\n\n\n\n\n" + str(brain_info['RealFakeBrain'].local_done)) 77 | assert not brain_info['RealFakeBrain'].local_done[0] 78 | assert brain_info['RealFakeBrain'].local_done[2] 79 | 80 | 81 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 82 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 83 | def test_close(mock_communicator, mock_launcher): 84 | comm = MockCommunicator( 85 | discrete_action=False, visual_inputs=0) 86 | mock_communicator.return_value = comm 87 | env = UnityEnvironment(' ') 88 | assert env._loaded 89 | env.close() 90 | assert not env._loaded 91 | assert comm.has_been_closed 92 | 93 | 94 | if __name__ == '__main__': 95 | pytest.main() 96 | -------------------------------------------------------------------------------- /learn_rl.py: -------------------------------------------------------------------------------- 1 | import json 2 | from mlagents.trainers.trainer_controller import TrainerController 3 | 4 | 5 | def extract_camera_config(config_file): 6 | 7 | config = [] 8 | with open(config_file, 'r') as data_file: 9 | data = json.load(data_file) 10 | params = data['parameters'] 11 | if params['camera1_type'][0] != 0: 12 | config += [{ 13 | "height": params['camera1_res_y'][0], 14 | "width": params['camera1_res_x'][0], 15 | "blackAndWhite": False 16 | }] 17 
| if params['camera2_type'][0] != 0: 18 | config += [{ 19 | "height": params['camera2_res_y'][0], 20 | "width": params['camera2_res_x'][0], 21 | "blackAndWhite": False 22 | }] 23 | if params['camera3_type'][0] != 0: 24 | config += [{ 25 | "height": params['camera3_res_y'][0], 26 | "width": params['camera3_res_x'][0], 27 | "blackAndWhite": False 28 | }] 29 | 30 | return config 31 | 32 | def get_env_config(curriculum_file): 33 | 34 | try: 35 | with open(curriculum_file) as data_file: 36 | data = json.load(data_file) 37 | except IOError: 38 | raise IOError() 39 | 40 | config = {} 41 | parameters = data['parameters'] 42 | for key in parameters: 43 | config[key] = parameters[key][0] 44 | 45 | check_config_validity(config) 46 | 47 | return config 48 | 49 | def check_config_validity(config): 50 | 51 | if config['camera1_type'] < 0 or config['camera1_type'] > 6: 52 | raise ValueError('camera1_type') 53 | if config['camera2_type'] < 0 or config['camera2_type'] > 6: 54 | raise ValueError('camera2_type') 55 | if config['camera3_type'] < 0 or config['camera3_type'] > 6: 56 | raise ValueError('camera3_type') 57 | 58 | if config['camera1_type'] != 0 and (config['camera1_res_x'] == 0 or config['camera1_res_y'] == 0): 59 | raise ValueError('camera1_res') 60 | if config['camera2_type'] != 0 and (config['camera2_res_x'] == 0 or config['camera2_res_y'] == 0): 61 | raise ValueError('camera2_res') 62 | if config['camera3_type'] != 0 and (config['camera3_res_x'] == 0 or config['camera3_res_y'] == 0): 63 | raise ValueError('camera3_res') 64 | 65 | if config['weather_id'] < 0 or config['weather_id'] > 10: 66 | raise ValueError('weather_id') 67 | if config['time_id'] < 0 or config['time_id'] >= 24: 68 | raise ValueError('time_id') 69 | if config['road_width'] <= 0: 70 | raise ValueError('road_width') 71 | 72 | def main(): 73 | try: 74 | print(''' 75 | 76 | ▄▄▄▓▓▓▓ 77 | ╓▓▓▓▓▓▓█▓▓▓▓▓ 78 | ,▄▄▄m▀▀▀' ,▓▓▓▀▓▓▄ ▓▓▓ ▓▓▌ 79 | ▄▓▓▓▀' ▄▓▓▀ ▓▓▓ ▄▄ ▄▄ ,▄▄ ▄▄▄▄ ,▄▄ ▄▓▓▌▄ ▄▄▄ ,▄▄ 80 | ▄▓▓▓▀ ▄▓▓▀ ▐▓▓▌ ▓▓▌ ▐▓▓ ▐▓▓▓▀▀▀▓▓▌ ▓▓▓ ▀▓▓▌▀ ^▓▓▌ ╒▓▓▌ 81 | ▄▓▓▓▓▓▄▄▄▄▄▄▄▄▓▓▓ ▓▀ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▄ ▓▓▌ 82 | ▀▓▓▓▓▀▀▀▀▀▀▀▀▀▀▓▓▄ ▓▓ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▐▓▓ 83 | ^█▓▓▓ ▀▓▓▄ ▐▓▓▌ ▓▓▓▓▄▓▓▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▓▄ ▓▓▓▓` 84 | '▀▓▓▓▄ ^▓▓▓ ▓▓▓ └▀▀▀▀ ▀▀ ^▀▀ `▀▀ `▀▀ '▀▀ ▐▓▓▌ 85 | ▀▀▀▀▓▄▄▄ ▓▓▓▓▓▓, ▓▓▓▓▀ 86 | `▀█▓▓▓▓▓▓▓▓▓▌ 87 | ¬`▀▀▀█▓ 88 | 89 | ''') 90 | except: 91 | print('\n\n\tUnity Technologies\n') 92 | 93 | # Docker Parameters 94 | docker_target_name = None 95 | 96 | # General parameters 97 | env_path = 'AutoBenchExecutable/AutoBenchExecutable' 98 | #env_path = None 99 | run_id = '1' 100 | load_model = True 101 | train_model = True 102 | save_freq = 10000 103 | keep_checkpoints = 10000 104 | worker_id = 0 105 | run_seed = 0 106 | curriculum_folder = 'config/curricula/autobench/' 107 | curriculum_file = 'config/curricula/autobench/AutoBenchBrain.json' 108 | lesson = 0 109 | fast_simulation = True 110 | no_graphics = False 111 | trainer_config_path = 'config/trainer_config.yaml' 112 | camera_res_overwrite = extract_camera_config(curriculum_file) 113 | benchmark = False 114 | benchmark_episode = 100 115 | benchmark_verbose = True 116 | env_config = get_env_config(curriculum_file) 117 | 118 | # Create controller and launch environment. 
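# Note: the trailing arguments passed below (camera_res_overwrite, the
    # benchmark switches, and a success threshold computed as goal_reward +
    # time_penalty from the curriculum's reset parameters) extend the upstream
    # ml-agents TrainerController signature and are specific to this repository.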
119 | tc = TrainerController(env_path, run_id, 120 | save_freq, curriculum_folder, fast_simulation, 121 | load_model, train_model, worker_id, 122 | keep_checkpoints, lesson, run_seed, 123 | docker_target_name, trainer_config_path, no_graphics, 124 | camera_res_overwrite, benchmark, benchmark_episode, 125 | env_config['goal_reward'] + env_config['time_penalty'], benchmark_verbose) 126 | 127 | # Begin training 128 | tc.start_learning() 129 | 130 | if __name__ == '__main__': 131 | main() 132 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/demonstration_meta_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: mlagents/envs/communicator_objects/demonstration_meta_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor.FileDescriptor( 18 | name='mlagents/envs/communicator_objects/demonstration_meta_proto.proto', 19 | package='communicator_objects', 20 | syntax='proto3', 21 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 22 | serialized_pb=_b('\nAmlagents/envs/communicator_objects/demonstration_meta_proto.proto\x12\x14\x63ommunicator_objects\"\x8d\x01\n\x16\x44\x65monstrationMetaProto\x12\x13\n\x0b\x61pi_version\x18\x01 \x01(\x05\x12\x1a\n\x12\x64\x65monstration_name\x18\x02 \x01(\t\x12\x14\n\x0cnumber_steps\x18\x03 \x01(\x05\x12\x17\n\x0fnumber_episodes\x18\x04 \x01(\x05\x12\x13\n\x0bmean_reward\x18\x05 \x01(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _DEMONSTRATIONMETAPROTO = _descriptor.Descriptor( 29 | name='DemonstrationMetaProto', 30 | full_name='communicator_objects.DemonstrationMetaProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='api_version', full_name='communicator_objects.DemonstrationMetaProto.api_version', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | serialized_options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='demonstration_name', full_name='communicator_objects.DemonstrationMetaProto.demonstration_name', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | serialized_options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='number_steps', full_name='communicator_objects.DemonstrationMetaProto.number_steps', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=False, default_value=0, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | serialized_options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | 
name='number_episodes', full_name='communicator_objects.DemonstrationMetaProto.number_episodes', index=3, 58 | number=4, type=5, cpp_type=1, label=1, 59 | has_default_value=False, default_value=0, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | serialized_options=None, file=DESCRIPTOR), 63 | _descriptor.FieldDescriptor( 64 | name='mean_reward', full_name='communicator_objects.DemonstrationMetaProto.mean_reward', index=4, 65 | number=5, type=2, cpp_type=6, label=1, 66 | has_default_value=False, default_value=float(0), 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | serialized_options=None, file=DESCRIPTOR), 70 | ], 71 | extensions=[ 72 | ], 73 | nested_types=[], 74 | enum_types=[ 75 | ], 76 | serialized_options=None, 77 | is_extendable=False, 78 | syntax='proto3', 79 | extension_ranges=[], 80 | oneofs=[ 81 | ], 82 | serialized_start=92, 83 | serialized_end=233, 84 | ) 85 | 86 | DESCRIPTOR.message_types_by_name['DemonstrationMetaProto'] = _DEMONSTRATIONMETAPROTO 87 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 88 | 89 | DemonstrationMetaProto = _reflection.GeneratedProtocolMessageType('DemonstrationMetaProto', (_message.Message,), dict( 90 | DESCRIPTOR = _DEMONSTRATIONMETAPROTO, 91 | __module__ = 'mlagents.envs.communicator_objects.demonstration_meta_proto_pb2' 92 | # @@protoc_insertion_point(class_scope:communicator_objects.DemonstrationMetaProto) 93 | )) 94 | _sym_db.RegisterMessage(DemonstrationMetaProto) 95 | 96 | 97 | DESCRIPTOR._options = None 98 | # @@protoc_insertion_point(module_scope) 99 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/unity_message_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
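# NOTE (hand-written annotation in generated code): UnityMessage is the
# envelope both communicators exchange. A hedged sketch of the send-side
# pattern used in socket_communicator.py and rpc_communicator.py:
#
#     from mlagents.envs.communicator_objects import UnityMessage, UnityInput
#     msg = UnityMessage()
#     msg.header.status = 200           # 200 = OK, 400 = shutdown in this codebase
#     msg.unity_input.CopyFrom(UnityInput())
#     wire = msg.SerializeToString()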
2 | # source: mlagents/envs/communicator_objects/unity_message.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import unity_output_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__output__pb2 17 | from mlagents.envs.communicator_objects import unity_input_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__input__pb2 18 | from mlagents.envs.communicator_objects import header_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_header__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='mlagents/envs/communicator_objects/unity_message.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_pb=_b('\n6mlagents/envs/communicator_objects/unity_message.proto\x12\x14\x63ommunicator_objects\x1a\x35mlagents/envs/communicator_objects/unity_output.proto\x1a\x34mlagents/envs/communicator_objects/unity_input.proto\x1a/mlagents/envs/communicator_objects/header.proto\"\xac\x01\n\x0cUnityMessage\x12,\n\x06header\x18\x01 \x01(\x0b\x32\x1c.communicator_objects.Header\x12\x37\n\x0cunity_output\x18\x02 \x01(\x0b\x32!.communicator_objects.UnityOutput\x12\x35\n\x0bunity_input\x18\x03 \x01(\x0b\x32 .communicator_objects.UnityInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_unity__output__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_unity__input__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_header__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _UNITYMESSAGE = _descriptor.Descriptor( 33 | name='UnityMessage', 34 | full_name='communicator_objects.UnityMessage', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='header', full_name='communicator_objects.UnityMessage.header', index=0, 41 | number=1, type=11, cpp_type=10, label=1, 42 | has_default_value=False, default_value=None, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='unity_output', full_name='communicator_objects.UnityMessage.unity_output', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='unity_input', full_name='communicator_objects.UnityMessage.unity_input', index=2, 55 | number=3, type=11, cpp_type=10, label=1, 56 | has_default_value=False, default_value=None, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None, file=DESCRIPTOR), 60 | ], 61 | extensions=[ 62 | ], 63 | nested_types=[], 64 | enum_types=[ 65 | ], 66 | options=None, 67 | is_extendable=False, 68 | syntax='proto3', 69 | extension_ranges=[], 70 | oneofs=[ 71 | ], 72 | serialized_start=239, 73 | 
serialized_end=411, 74 | ) 75 | 76 | _UNITYMESSAGE.fields_by_name['header'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_header__pb2._HEADER 77 | _UNITYMESSAGE.fields_by_name['unity_output'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__output__pb2._UNITYOUTPUT 78 | _UNITYMESSAGE.fields_by_name['unity_input'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__input__pb2._UNITYINPUT 79 | DESCRIPTOR.message_types_by_name['UnityMessage'] = _UNITYMESSAGE 80 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 81 | 82 | UnityMessage = _reflection.GeneratedProtocolMessageType('UnityMessage', (_message.Message,), dict( 83 | DESCRIPTOR = _UNITYMESSAGE, 84 | __module__ = 'mlagents.envs.communicator_objects.unity_message_pb2' 85 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityMessage) 86 | )) 87 | _sym_db.RegisterMessage(UnityMessage) 88 | 89 | 90 | DESCRIPTOR.has_options = True 91 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 92 | # @@protoc_insertion_point(module_scope) 93 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/rpc_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import grpc 3 | 4 | import socket 5 | from multiprocessing import Pipe 6 | from concurrent.futures import ThreadPoolExecutor 7 | 8 | from .communicator import Communicator 9 | from .communicator_objects import UnityToExternalServicer, add_UnityToExternalServicer_to_server 10 | from .communicator_objects import UnityMessage, UnityInput, UnityOutput 11 | from .exception import UnityTimeOutException, UnityWorkerInUseException 12 | 13 | logger = logging.getLogger("mlagents.envs") 14 | 15 | 16 | class UnityToExternalServicerImplementation(UnityToExternalServicer): 17 | def __init__(self): 18 | self.parent_conn, self.child_conn = Pipe() 19 | 20 | def Initialize(self, request, context): 21 | self.child_conn.send(request) 22 | return self.child_conn.recv() 23 | 24 | def Exchange(self, request, context): 25 | self.child_conn.send(request) 26 | return self.child_conn.recv() 27 | 28 | 29 | class RpcCommunicator(Communicator): 30 | def __init__(self, worker_id=0, base_port=5005): 31 | """ 32 | Python side of the grpc communication. Python is the server and Unity the client 33 | 34 | 35 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 36 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 37 | """ 38 | self.port = base_port + worker_id 39 | self.worker_id = worker_id 40 | self.server = None 41 | self.unity_to_external = None 42 | self.is_open = False 43 | self.create_server() 44 | 45 | def create_server(self): 46 | """ 47 | Creates the GRPC server. 48 | """ 49 | self.check_port(self.port) 50 | 51 | try: 52 | # Establish communication grpc 53 | self.server = grpc.server(ThreadPoolExecutor(max_workers=10)) 54 | self.unity_to_external = UnityToExternalServicerImplementation() 55 | add_UnityToExternalServicer_to_server(self.unity_to_external, self.server) 56 | # Using unspecified address, which means that grpc is communicating on all IPs 57 | # This is so that the docker container can connect. 
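# Note that check_port() above only probes "localhost"; the unspecified
            # address below makes the gRPC server listen on every interface,
            # which is what lets a Unity instance inside a Docker container
            # reach a trainer running outside it.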
58 | self.server.add_insecure_port('[::]:' + str(self.port)) 59 | self.server.start() 60 | self.is_open = True 61 | except: 62 | raise UnityWorkerInUseException(self.worker_id) 63 | 64 | def check_port(self, port): 65 | """ 66 | Attempts to bind to the requested communicator port, checking if it is already in use. 67 | """ 68 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 69 | try: 70 | s.bind(("localhost", port)) 71 | except socket.error: 72 | raise UnityWorkerInUseException(self.worker_id) 73 | finally: 74 | s.close() 75 | 76 | def initialize(self, inputs: UnityInput) -> UnityOutput: 77 | if not self.unity_to_external.parent_conn.poll(30): 78 | raise UnityTimeOutException( 79 | "The Unity environment took too long to respond. Make sure that :\n" 80 | "\t The environment does not need user interaction to launch\n" 81 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" 82 | "\t The environment and the Python interface have compatible versions.") 83 | aca_param = self.unity_to_external.parent_conn.recv().unity_output 84 | message = UnityMessage() 85 | message.header.status = 200 86 | message.unity_input.CopyFrom(inputs) 87 | self.unity_to_external.parent_conn.send(message) 88 | self.unity_to_external.parent_conn.recv() 89 | return aca_param 90 | 91 | def exchange(self, inputs: UnityInput) -> UnityOutput: 92 | message = UnityMessage() 93 | message.header.status = 200 94 | message.unity_input.CopyFrom(inputs) 95 | self.unity_to_external.parent_conn.send(message) 96 | output = self.unity_to_external.parent_conn.recv() 97 | if output.header.status != 200: 98 | return None 99 | return output.unity_output 100 | 101 | def close(self): 102 | """ 103 | Sends a shutdown signal to the unity environment, and closes the grpc connection. 104 | """ 105 | if self.is_open: 106 | message_input = UnityMessage() 107 | message_input.header.status = 400 108 | self.unity_to_external.parent_conn.send(message_input) 109 | self.unity_to_external.parent_conn.close() 110 | self.server.stop(False) 111 | self.is_open = False 112 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/curriculum.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | 5 | from .exception import CurriculumError 6 | 7 | import logging 8 | 9 | logger = logging.getLogger('mlagents.trainers') 10 | 11 | 12 | class Curriculum(object): 13 | def __init__(self, location, default_reset_parameters): 14 | """ 15 | Initializes a Curriculum object. 16 | :param location: Path to JSON defining curriculum. 17 | :param default_reset_parameters: Set of reset parameters for 18 | environment. 19 | """ 20 | self.max_lesson_num = 0 21 | self.measure = None 22 | self._lesson_num = 0 23 | # The name of the brain should be the basename of the file without the 24 | # extension. 25 | self._brain_name = os.path.basename(location).split('.')[0] 26 | 27 | try: 28 | with open(location) as data_file: 29 | self.data = json.load(data_file) 30 | except IOError: 31 | raise CurriculumError( 32 | 'The file {0} could not be found.'.format(location)) 33 | except UnicodeDecodeError: 34 | raise CurriculumError('There was an error decoding {}' 35 | .format(location)) 36 | self.smoothing_value = 0 37 | for key in ['parameters', 'measure', 'thresholds', 38 | 'min_lesson_length', 'signal_smoothing']: 39 | if key not in self.data: 40 | raise CurriculumError("{0} does not contain a " 41 | "{1} field." 
42 | .format(location, key)) 43 | self.smoothing_value = 0 44 | self.measure = self.data['measure'] 45 | self.min_lesson_length = self.data['min_lesson_length'] 46 | self.max_lesson_num = len(self.data['thresholds']) 47 | 48 | parameters = self.data['parameters'] 49 | for key in parameters: 50 | if key not in default_reset_parameters: 51 | raise CurriculumError( 52 | 'The parameter {0} in Curriculum {1} is not present in ' 53 | 'the Environment'.format(key, location)) 54 | if len(parameters[key]) != self.max_lesson_num + 1: 55 | raise CurriculumError( 56 | 'The parameter {0} in Curriculum {1} must have {2} values ' 57 | 'but {3} were found'.format(key, location, 58 | self.max_lesson_num + 1, 59 | len(parameters[key]))) 60 | 61 | @property 62 | def lesson_num(self): 63 | return self._lesson_num 64 | 65 | @lesson_num.setter 66 | def lesson_num(self, lesson_num): 67 | self._lesson_num = max(0, min(lesson_num, self.max_lesson_num)) 68 | 69 | def increment_lesson(self, measure_val): 70 | """ 71 | Increments the lesson number depending on the progress given. 72 | :param measure_val: Measure of progress (either reward or percentage 73 | steps completed). 74 | :return Whether the lesson was incremented. 75 | """ 76 | if not self.data or not measure_val or math.isnan(measure_val): 77 | return False 78 | if self.data['signal_smoothing']: 79 | measure_val = self.smoothing_value * 0.25 + 0.75 * measure_val 80 | self.smoothing_value = measure_val 81 | if self.lesson_num < self.max_lesson_num: 82 | if measure_val > self.data['thresholds'][self.lesson_num]: 83 | self.lesson_num += 1 84 | config = {} 85 | parameters = self.data['parameters'] 86 | for key in parameters: 87 | config[key] = parameters[key][self.lesson_num] 88 | logger.info('{0} lesson changed. Now in lesson {1}: {2}' 89 | .format(self._brain_name, 90 | self.lesson_num, 91 | ', '.join([str(x) + ' -> ' + str(config[x]) 92 | for x in config]))) 93 | return True 94 | return False 95 | 96 | def get_config(self, lesson=None): 97 | """ 98 | Returns reset parameters which correspond to the lesson. 99 | :param lesson: The lesson you want to get the config of. If None, the 100 | current lesson is returned. 101 | :return: The configuration of the reset parameters. 102 | """ 103 | if not self.data: 104 | return {} 105 | if lesson is None: 106 | lesson = self.lesson_num 107 | lesson = max(0, min(lesson, self.max_lesson_num)) 108 | config = {} 109 | parameters = self.data['parameters'] 110 | for key in parameters: 111 | config[key] = parameters[key][lesson] 112 | return config 113 | -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/engine_configuration_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
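# NOTE (hand-written annotation in generated code): these fields carry the
# engine settings the trainer sends at startup. A hedged construction sketch;
# the values are hypothetical examples, not the trainer's actual defaults:
#
#     from mlagents.envs.communicator_objects.engine_configuration_proto_pb2 import (
#         EngineConfigurationProto)
#     cfg = EngineConfigurationProto()
#     cfg.width, cfg.height = 80, 80    # render resolution
#     cfg.quality_level = 0             # lowest quality for throughput
#     cfg.time_scale = 100.0            # e.g. a fast-simulation setting
#     cfg.target_frame_rate = -1        # uncapped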
2 | # source: mlagents/envs/communicator_objects/engine_configuration_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/engine_configuration_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\nCmlagents/envs/communicator_objects/engine_configuration_proto.proto\x12\x14\x63ommunicator_objects\"\x95\x01\n\x18\x45ngineConfigurationProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x15\n\rquality_level\x18\x03 \x01(\x05\x12\x12\n\ntime_scale\x18\x04 \x01(\x02\x12\x19\n\x11target_frame_rate\x18\x05 \x01(\x05\x12\x14\n\x0cshow_monitor\x18\x06 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _ENGINECONFIGURATIONPROTO = _descriptor.Descriptor( 29 | name='EngineConfigurationProto', 30 | full_name='communicator_objects.EngineConfigurationProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='width', full_name='communicator_objects.EngineConfigurationProto.width', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='communicator_objects.EngineConfigurationProto.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='quality_level', full_name='communicator_objects.EngineConfigurationProto.quality_level', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=False, default_value=0, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | name='time_scale', full_name='communicator_objects.EngineConfigurationProto.time_scale', index=3, 58 | number=4, type=2, cpp_type=6, label=1, 59 | has_default_value=False, default_value=float(0), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None, file=DESCRIPTOR), 63 | _descriptor.FieldDescriptor( 64 | name='target_frame_rate', full_name='communicator_objects.EngineConfigurationProto.target_frame_rate', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=False, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None, file=DESCRIPTOR), 70 | _descriptor.FieldDescriptor( 71 | name='show_monitor', full_name='communicator_objects.EngineConfigurationProto.show_monitor', index=5, 72 | number=6, type=8, 
cpp_type=7, label=1, 73 | has_default_value=False, default_value=False, 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None, file=DESCRIPTOR), 77 | ], 78 | extensions=[ 79 | ], 80 | nested_types=[], 81 | enum_types=[ 82 | ], 83 | options=None, 84 | is_extendable=False, 85 | syntax='proto3', 86 | extension_ranges=[], 87 | oneofs=[ 88 | ], 89 | serialized_start=94, 90 | serialized_end=243, 91 | ) 92 | 93 | DESCRIPTOR.message_types_by_name['EngineConfigurationProto'] = _ENGINECONFIGURATIONPROTO 94 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 95 | 96 | EngineConfigurationProto = _reflection.GeneratedProtocolMessageType('EngineConfigurationProto', (_message.Message,), dict( 97 | DESCRIPTOR = _ENGINECONFIGURATIONPROTO, 98 | __module__ = 'mlagents.envs.communicator_objects.engine_configuration_proto_pb2' 99 | # @@protoc_insertion_point(class_scope:communicator_objects.EngineConfigurationProto) 100 | )) 101 | _sym_db.RegisterMessage(EngineConfigurationProto) 102 | 103 | 104 | DESCRIPTOR.has_options = True 105 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 106 | # @@protoc_insertion_point(module_scope) 107 | -------------------------------------------------------------------------------- /ml-agents/tests/trainers/test_meta_curriculum.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest.mock import patch, call, Mock 3 | 4 | from mlagents.trainers.meta_curriculum import MetaCurriculum 5 | from mlagents.trainers.exception import MetaCurriculumError 6 | 7 | 8 | class MetaCurriculumTest(MetaCurriculum): 9 | """This class allows us to test MetaCurriculum objects without calling 10 | MetaCurriculum's __init__ function. 
11 | """ 12 | def __init__(self, brains_to_curriculums): 13 | self._brains_to_curriculums = brains_to_curriculums 14 | 15 | 16 | @pytest.fixture 17 | def default_reset_parameters(): 18 | return {'param1' : 1, 'param2' : 2, 'param3' : 3} 19 | 20 | 21 | @pytest.fixture 22 | def more_reset_parameters(): 23 | return {'param4' : 4, 'param5' : 5, 'param6' : 6} 24 | 25 | 26 | @pytest.fixture 27 | def measure_vals(): 28 | return {'Brain1' : 0.2, 'Brain2' : 0.3} 29 | 30 | 31 | @pytest.fixture 32 | def reward_buff_sizes(): 33 | return {'Brain1' : 7, 'Brain2' : 8} 34 | 35 | 36 | @patch('mlagents.trainers.Curriculum.get_config', return_value={}) 37 | @patch('mlagents.trainers.Curriculum.__init__', return_value=None) 38 | @patch('os.listdir', return_value=['Brain1.json', 'Brain2.json']) 39 | def test_init_meta_curriculum_happy_path(listdir, mock_curriculum_init, 40 | mock_curriculum_get_config, 41 | default_reset_parameters): 42 | meta_curriculum = MetaCurriculum('test/', default_reset_parameters) 43 | 44 | assert len(meta_curriculum.brains_to_curriculums) == 2 45 | 46 | assert 'Brain1' in meta_curriculum.brains_to_curriculums 47 | assert 'Brain2' in meta_curriculum.brains_to_curriculums 48 | 49 | calls = [call('test/Brain1.json', default_reset_parameters), 50 | call('test/Brain2.json', default_reset_parameters)] 51 | 52 | mock_curriculum_init.assert_has_calls(calls) 53 | 54 | 55 | @patch('os.listdir', side_effect=NotADirectoryError()) 56 | def test_init_meta_curriculum_bad_curriculum_folder_raises_error(listdir): 57 | with pytest.raises(MetaCurriculumError): 58 | MetaCurriculum('test/', default_reset_parameters) 59 | 60 | 61 | @patch('mlagents.trainers.Curriculum') 62 | @patch('mlagents.trainers.Curriculum') 63 | def test_set_lesson_nums(curriculum_a, curriculum_b): 64 | meta_curriculum = MetaCurriculumTest({'Brain1' : curriculum_a, 65 | 'Brain2' : curriculum_b}) 66 | 67 | meta_curriculum.lesson_nums = {'Brain1' : 1, 'Brain2' : 3} 68 | 69 | assert curriculum_a.lesson_num == 1 70 | assert curriculum_b.lesson_num == 3 71 | 72 | 73 | @patch('mlagents.trainers.Curriculum') 74 | @patch('mlagents.trainers.Curriculum') 75 | def test_increment_lessons(curriculum_a, curriculum_b, measure_vals): 76 | meta_curriculum = MetaCurriculumTest({'Brain1' : curriculum_a, 77 | 'Brain2' : curriculum_b}) 78 | 79 | meta_curriculum.increment_lessons(measure_vals) 80 | 81 | curriculum_a.increment_lesson.assert_called_with(0.2) 82 | curriculum_b.increment_lesson.assert_called_with(0.3) 83 | 84 | 85 | @patch('mlagents.trainers.Curriculum') 86 | @patch('mlagents.trainers.Curriculum') 87 | def test_increment_lessons_with_reward_buff_sizes(curriculum_a, curriculum_b, 88 | measure_vals, 89 | reward_buff_sizes): 90 | curriculum_a.min_lesson_length = 5 91 | curriculum_b.min_lesson_length = 10 92 | meta_curriculum = MetaCurriculumTest({'Brain1' : curriculum_a, 93 | 'Brain2' : curriculum_b}) 94 | 95 | meta_curriculum.increment_lessons(measure_vals, 96 | reward_buff_sizes=reward_buff_sizes) 97 | 98 | curriculum_a.increment_lesson.assert_called_with(0.2) 99 | curriculum_b.increment_lesson.assert_not_called() 100 | 101 | 102 | @patch('mlagents.trainers.Curriculum') 103 | @patch('mlagents.trainers.Curriculum') 104 | def test_set_all_curriculums_to_lesson_num(curriculum_a, curriculum_b): 105 | meta_curriculum = MetaCurriculumTest({'Brain1' : curriculum_a, 106 | 'Brain2' : curriculum_b}) 107 | 108 | meta_curriculum.set_all_curriculums_to_lesson_num(2) 109 | 110 | assert curriculum_a.lesson_num == 2 111 | assert curriculum_b.lesson_num == 2 
112 | 113 | 114 | @patch('mlagents.trainers.Curriculum') 115 | @patch('mlagents.trainers.Curriculum') 116 | def test_get_config(curriculum_a, curriculum_b, default_reset_parameters, 117 | more_reset_parameters): 118 | curriculum_a.get_config.return_value = default_reset_parameters 119 | curriculum_b.get_config.return_value = default_reset_parameters 120 | meta_curriculum = MetaCurriculumTest({'Brain1' : curriculum_a, 121 | 'Brain2' : curriculum_b}) 122 | 123 | assert meta_curriculum.get_config() == default_reset_parameters 124 | 125 | curriculum_b.get_config.return_value = more_reset_parameters 126 | 127 | new_reset_parameters = dict(default_reset_parameters) 128 | new_reset_parameters.update(more_reset_parameters) 129 | 130 | assert meta_curriculum.get_config() == new_reset_parameters 131 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/learn.py: -------------------------------------------------------------------------------- 1 | # # Unity ML-Agents Toolkit 2 | 3 | import logging 4 | 5 | from multiprocessing import Process, Queue 6 | import numpy as np 7 | from docopt import docopt 8 | 9 | from mlagents.trainers.trainer_controller import TrainerController 10 | from mlagents.trainers.exception import TrainerError 11 | 12 | 13 | def run_training(sub_id, run_seed, run_options, process_queue): 14 | """ 15 | Launches training session. 16 | :param process_queue: Queue used to send signal back to main. 17 | :param sub_id: Unique id for training session. 18 | :param run_seed: Random seed used for training. 19 | :param run_options: Command line arguments for training. 20 | """ 21 | # Docker Parameters 22 | docker_target_name = (run_options['--docker-target-name'] 23 | if run_options['--docker-target-name'] != 'None' else None) 24 | 25 | # General parameters 26 | env_path = None 27 | run_id = run_options['--run-id'] 28 | load_model = run_options['--load'] 29 | train_model = run_options['--train'] 30 | save_freq = int(run_options['--save-freq']) 31 | keep_checkpoints = int(run_options['--keep-checkpoints']) 32 | worker_id = 0 33 | curriculum_file = None 34 | lesson = 0 35 | fast_simulation = False 36 | no_graphics = False 37 | trainer_config_path = run_options['<trainer-config-path>'] 38 | 39 | # Create controller and launch environment. 40 | tc = TrainerController(env_path, run_id + '-' + str(sub_id), 41 | save_freq, curriculum_file, fast_simulation, 42 | load_model, train_model, worker_id + sub_id, 43 | keep_checkpoints, lesson, run_seed, 44 | docker_target_name, trainer_config_path, no_graphics) 45 | 46 | # Signal that environment has been launched. 47 | # process_queue.put(True) 48 | 49 | # Begin training 50 | tc.start_learning() 51 | 52 | 53 | def main(): 54 | try: 55 | print(''' 56 | 57 | ▄▄▄▓▓▓▓ 58 | ╓▓▓▓▓▓▓█▓▓▓▓▓ 59 | ,▄▄▄m▀▀▀' ,▓▓▓▀▓▓▄ ▓▓▓ ▓▓▌ 60 | ▄▓▓▓▀' ▄▓▓▀ ▓▓▓ ▄▄ ▄▄ ,▄▄ ▄▄▄▄ ,▄▄ ▄▓▓▌▄ ▄▄▄ ,▄▄ 61 | ▄▓▓▓▀ ▄▓▓▀ ▐▓▓▌ ▓▓▌ ▐▓▓ ▐▓▓▓▀▀▀▓▓▌ ▓▓▓ ▀▓▓▌▀ ^▓▓▌ ╒▓▓▌ 62 | ▄▓▓▓▓▓▄▄▄▄▄▄▄▄▓▓▓ ▓▀ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▄ ▓▓▌ 63 | ▀▓▓▓▓▀▀▀▀▀▀▀▀▀▀▓▓▄ ▓▓ ▓▓▌ ▐▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▌ ▐▓▓▐▓▓ 64 | ^█▓▓▓ ▀▓▓▄ ▐▓▓▌ ▓▓▓▓▄▓▓▓▓ ▐▓▓ ▓▓▓ ▓▓▓ ▓▓▓▄ ▓▓▓▓` 65 | '▀▓▓▓▄ ^▓▓▓ ▓▓▓ └▀▀▀▀ ▀▀ ^▀▀ `▀▀ `▀▀ '▀▀ ▐▓▓▌ 66 | ▀▀▀▀▓▄▄▄ ▓▓▓▓▓▓, ▓▓▓▓▀ 67 | `▀█▓▓▓▓▓▓▓▓▓▌ 68 | ¬`▀▀▀█▓ 69 | ''') 70 | except: 71 | print('\n\n\tUnity Technologies\n') 72 | 73 | logger = logging.getLogger('mlagents.trainers') 74 | _USAGE = ''' 75 | Usage: 76 | mlagents-learn <trainer-config-path> [options] 77 | mlagents-learn --help 78 | 79 | Options: 80 | --env=<file> Name of the Unity executable [default: None].
81 | --curriculum=<directory> Curriculum json directory for environment [default: None]. 82 | --keep-checkpoints=<n> How many model checkpoints to keep [default: 5]. 83 | --lesson=<n> Start learning from this lesson [default: 0]. 84 | --load Whether to load the model or randomly initialize [default: False]. 85 | --run-id=<path> The directory name for model and summary statistics [default: ppo]. 86 | --num-runs=<n> Number of concurrent training sessions [default: 1]. 87 | --save-freq=<n> Frequency at which to save model [default: 50000]. 88 | --seed=<n> Random seed used for training [default: -1]. 89 | --slow Whether to run the game at training speed [default: False]. 90 | --train Whether to train model, or only run inference [default: False]. 91 | --worker-id=<id> Number to add to communication port (5005) [default: 0]. 92 | --docker-target-name=<dt>
Docker volume to store training-specific files [default: None]. 93 | --no-graphics Whether to run the environment in no-graphics mode [default: False]. 94 | ''' 95 | 96 | options = docopt(_USAGE) 97 | logger.info(options) 98 | num_runs = int(options['--num-runs']) 99 | seed = int(options['--seed']) 100 | 101 | if options['--env'] == 'None' and num_runs > 1: 102 | raise TrainerError('It is not possible to launch more than one concurrent training session ' 103 | 'when training from the editor.') 104 | 105 | jobs = [] 106 | run_seed = seed 107 | 108 | # options['--env'] = None 109 | # Launch a single training session in-process; the multiprocessing loop below is kept for reference. 110 | 111 | run_training(1, run_seed, options, None) 112 | 113 | # for i in range(num_runs): 114 | # if seed == -1: 115 | # run_seed = np.random.randint(0, 10000) 116 | # process_queue = Queue() 117 | # p = Process(target=run_training, args=(i, run_seed, options, process_queue)) 118 | # jobs.append(p) 119 | # p.start() 120 | # # Wait for signal that environment has successfully launched 121 | # while process_queue.get() is not True: 122 | # continue 123 | 124 | if __name__ == '__main__': 125 | main() -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/environment_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: mlagents/envs/communicator_objects/environment_parameters_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/environment_parameters_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\nEmlagents/envs/communicator_objects/environment_parameters_proto.proto\x12\x14\x63ommunicator_objects\"\xb5\x01\n\x1a\x45nvironmentParametersProto\x12_\n\x10\x66loat_parameters\x18\x01 \x03(\x0b\x32\x45.communicator_objects.EnvironmentParametersProto.FloatParametersEntry\x1a\x36\n\x14\x46loatParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY = _descriptor.Descriptor( 29 | name='FloatParametersEntry', 30 | full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='key', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.key', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='value', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.value', index=1, 44 | number=2, type=2, cpp_type=6,
label=1, 45 | has_default_value=False, default_value=float(0), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=223, 62 | serialized_end=277, 63 | ) 64 | 65 | _ENVIRONMENTPARAMETERSPROTO = _descriptor.Descriptor( 66 | name='EnvironmentParametersProto', 67 | full_name='communicator_objects.EnvironmentParametersProto', 68 | filename=None, 69 | file=DESCRIPTOR, 70 | containing_type=None, 71 | fields=[ 72 | _descriptor.FieldDescriptor( 73 | name='float_parameters', full_name='communicator_objects.EnvironmentParametersProto.float_parameters', index=0, 74 | number=1, type=11, cpp_type=10, label=3, 75 | has_default_value=False, default_value=[], 76 | message_type=None, enum_type=None, containing_type=None, 77 | is_extension=False, extension_scope=None, 78 | options=None, file=DESCRIPTOR), 79 | ], 80 | extensions=[ 81 | ], 82 | nested_types=[_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, ], 83 | enum_types=[ 84 | ], 85 | options=None, 86 | is_extendable=False, 87 | syntax='proto3', 88 | extension_ranges=[], 89 | oneofs=[ 90 | ], 91 | serialized_start=96, 92 | serialized_end=277, 93 | ) 94 | 95 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.containing_type = _ENVIRONMENTPARAMETERSPROTO 96 | _ENVIRONMENTPARAMETERSPROTO.fields_by_name['float_parameters'].message_type = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY 97 | DESCRIPTOR.message_types_by_name['EnvironmentParametersProto'] = _ENVIRONMENTPARAMETERSPROTO 98 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 99 | 100 | EnvironmentParametersProto = _reflection.GeneratedProtocolMessageType('EnvironmentParametersProto', (_message.Message,), dict( 101 | 102 | FloatParametersEntry = _reflection.GeneratedProtocolMessageType('FloatParametersEntry', (_message.Message,), dict( 103 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, 104 | __module__ = 'mlagents.envs.communicator_objects.environment_parameters_proto_pb2' 105 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto.FloatParametersEntry) 106 | )) 107 | , 108 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO, 109 | __module__ = 'mlagents.envs.communicator_objects.environment_parameters_proto_pb2' 110 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto) 111 | )) 112 | _sym_db.RegisterMessage(EnvironmentParametersProto) 113 | _sym_db.RegisterMessage(EnvironmentParametersProto.FloatParametersEntry) 114 | 115 | 116 | DESCRIPTOR.has_options = True 117 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 118 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.has_options = True 119 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 120 | # @@protoc_insertion_point(module_scope) 121 | -------------------------------------------------------------------------------- /learn_ml.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import numpy as np 4 | from mlagents.envs.environment import UnityEnvironment 5 | from 
mlagents.envs.brain import BrainInfo 6 | from mlagents.trainers.benchmark import BenchmarkManager 7 | 8 | 9 | def extract_camera_config(config_file): 10 | 11 | config = [] 12 | with open(config_file, 'r') as data_file: 13 | data = json.load(data_file) 14 | params = data['parameters'] 15 | if params['camera1_type'][0] != 0: 16 | config += [{ 17 | "height": params['camera1_res_y'][0], 18 | "width": params['camera1_res_x'][0], 19 | "blackAndWhite": False 20 | }] 21 | if params['camera2_type'][0] != 0: 22 | config += [{ 23 | "height": params['camera2_res_y'][0], 24 | "width": params['camera2_res_x'][0], 25 | "blackAndWhite": False 26 | }] 27 | if params['camera3_type'][0] != 0: 28 | config += [{ 29 | "height": params['camera3_res_y'][0], 30 | "width": params['camera3_res_x'][0], 31 | "blackAndWhite": False 32 | }] 33 | 34 | return config 35 | 36 | def get_env_config(curriculum_folder): 37 | 38 | try: 39 | with open(curriculum_folder) as data_file: 40 | data = json.load(data_file) 41 | except IOError: 42 | raise IOError('Cannot read curriculum config file: ' + curriculum_folder) 43 | 44 | config = {} 45 | parameters = data['parameters'] 46 | for key in parameters: 47 | config[key] = parameters[key][0] 48 | 49 | check_config_validity(config) 50 | 51 | return config 52 | 53 | def check_config_validity(config): 54 | 55 | if config['camera1_type'] < 0 or config['camera1_type'] > 6: 56 | raise ValueError('camera1_type') 57 | if config['camera2_type'] < 0 or config['camera2_type'] > 6: 58 | raise ValueError('camera2_type') 59 | if config['camera3_type'] < 0 or config['camera3_type'] > 6: 60 | raise ValueError('camera3_type') 61 | 62 | if config['camera1_type'] != 0 and (config['camera1_res_x'] == 0 or config['camera1_res_y'] == 0): 63 | raise ValueError('camera1_res') 64 | if config['camera2_type'] != 0 and (config['camera2_res_x'] == 0 or config['camera2_res_y'] == 0): 65 | raise ValueError('camera2_res') 66 | if config['camera3_type'] != 0 and (config['camera3_res_x'] == 0 or config['camera3_res_y'] == 0): 67 | raise ValueError('camera3_res') 68 | 69 | if config['weather_id'] < 0 or config['weather_id'] > 10: 70 | raise ValueError('weather_id') 71 | if config['time_id'] < 0 or config['time_id'] >= 24: 72 | raise ValueError('time_id') 73 | if config['road_width'] <= 0: 74 | raise ValueError('road_width') 75 | 76 | # Refer to brain.py (BrainInfo) for the commented data structure details 77 | def decide(brain_info: BrainInfo): 78 | 79 | # Given the brain_info, decide which action each agent takes. 80 | 81 | ################################## 82 | # Action Index 83 | # (rows: steering angle; columns: throttle power) 84 | # +1 0 -0.3 throttle power 85 | # -30 0 1 2 86 | # 0 3 4 5 87 | # +30 6 7 8 88 | # steering 89 | # angle 90 | # 91 | ################################## 92 | 93 | # Sample placeholder, type = ndarray, shape = (agent_amount,); note .transpose() is a no-op on a 1-D array 94 | sample_action = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 0]).transpose() 95 | 96 | # sample_action = np.array([5]).transpose() 97 | # TODO Design your own decide algorithm here 98 | 99 | return sample_action 100 | 101 | def main(): 102 | 103 | env_path = 'AutoBenchExecutable/AutoBenchExecutable' 104 | # env_path = None 105 | curriculum_file = 'config/curricula/autobench/AutoBenchBrain.json' 106 | no_graphics = False # True if you want the Unity environment to train in the background 107 | max_step = 1e10 # total training steps 108 | 109 | # Set True: 100x time scale, small window, 10 agents 110 | # Set False: 1x time scale, big window, 1 agent with observation camera 111 | fast_simulation = False 112 | benchmark = False 113 | benchmark_episode = 1000 114 | 115 | 116 | #
Setup the Unity Environment 117 | env_config = get_env_config(curriculum_file) 118 | env = UnityEnvironment(file_name=env_path, 119 | no_graphics=no_graphics, 120 | camera_res_overwrite=extract_camera_config(curriculum_file)) 121 | brain_name = env.brain_names[0] # Get brain_name; assume there is only 1 brain 122 | 123 | curr_info = env.reset(config=env_config, train_mode=fast_simulation)[brain_name] 124 | agent_size = len(curr_info.agents) 125 | 126 | BenchmarkManager(agent_amount=agent_size, benchmark_episode=benchmark_episode, 127 | success_threshold=env_config['goal_reward'] + env_config['time_penalty'], 128 | verbose=False) 129 | 130 | last_update_time = time.clock() 131 | 132 | ### Standard RL training loop 133 | for global_step in range(int(max_step)): 134 | 135 | # Implement your own decide algorithm 136 | action = decide(curr_info) 137 | 138 | # Send action into the Unity environment and return new_info, type = BrainInfo 139 | # You can refer to brain.py; the data structure is commented there in detail 140 | new_info = env.step(vector_action={brain_name: action}, 141 | memory={brain_name: None}, 142 | text_action={brain_name: None})[brain_name] 143 | if benchmark: 144 | BenchmarkManager.add_result(new_info) 145 | if BenchmarkManager.is_complete(): 146 | BenchmarkManager.analyze() 147 | break 148 | 149 | # Calculate and print training speed 150 | if global_step % 100 == 0: 151 | print("Steps:{:,}".format(global_step), " || Speed:", 152 | format(100 / (time.clock() - last_update_time), ".2f")) 153 | last_update_time = time.clock() 154 | 155 | # Assign new_info to curr_info for the next timestep 156 | curr_info = new_info 157 | 158 | env.close() 159 | 160 | 161 | if __name__ == '__main__': 162 | main() -------------------------------------------------------------------------------- /ml-agents/mlagents/envs/communicator_objects/unity_rl_initialization_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: mlagents/envs/communicator_objects/unity_rl_initialization_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import brain_parameters_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_brain__parameters__proto__pb2 17 | from mlagents.envs.communicator_objects import environment_parameters_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='mlagents/envs/communicator_objects/unity_rl_initialization_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\nGmlagents/envs/communicator_objects/unity_rl_initialization_output.proto\x12\x14\x63ommunicator_objects\x1a?mlagents/envs/communicator_objects/brain_parameters_proto.proto\x1a\x45mlagents/envs/communicator_objects/environment_parameters_proto.proto\"\xe6\x01\n\x1bUnityRLInitializationOutput\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x10\n\x08log_path\x18\x03 \x01(\t\x12\x44\n\x10\x62rain_parameters\x18\x05 \x03(\x0b\x32*.communicator_objects.BrainParametersProto\x12P\n\x16\x65nvironment_parameters\x18\x06 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_brain__parameters__proto__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYRLINITIALIZATIONOUTPUT = _descriptor.Descriptor( 32 | name='UnityRLInitializationOutput', 33 | full_name='communicator_objects.UnityRLInitializationOutput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='name', full_name='communicator_objects.UnityRLInitializationOutput.name', index=0, 40 | number=1, type=9, cpp_type=9, label=1, 41 | has_default_value=False, default_value=_b("").decode('utf-8'), 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='version', full_name='communicator_objects.UnityRLInitializationOutput.version', index=1, 47 | number=2, type=9, cpp_type=9, label=1, 48 | has_default_value=False, default_value=_b("").decode('utf-8'), 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='log_path', full_name='communicator_objects.UnityRLInitializationOutput.log_path', index=2, 54 | number=3, type=9, cpp_type=9, label=1, 55 | has_default_value=False, default_value=_b("").decode('utf-8'), 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | _descriptor.FieldDescriptor( 60 | name='brain_parameters', 
full_name='communicator_objects.UnityRLInitializationOutput.brain_parameters', index=3, 61 | number=5, type=11, cpp_type=10, label=3, 62 | has_default_value=False, default_value=[], 63 | message_type=None, enum_type=None, containing_type=None, 64 | is_extension=False, extension_scope=None, 65 | options=None, file=DESCRIPTOR), 66 | _descriptor.FieldDescriptor( 67 | name='environment_parameters', full_name='communicator_objects.UnityRLInitializationOutput.environment_parameters', index=4, 68 | number=6, type=11, cpp_type=10, label=1, 69 | has_default_value=False, default_value=None, 70 | message_type=None, enum_type=None, containing_type=None, 71 | is_extension=False, extension_scope=None, 72 | options=None, file=DESCRIPTOR), 73 | ], 74 | extensions=[ 75 | ], 76 | nested_types=[], 77 | enum_types=[ 78 | ], 79 | options=None, 80 | is_extendable=False, 81 | syntax='proto3', 82 | extension_ranges=[], 83 | oneofs=[ 84 | ], 85 | serialized_start=234, 86 | serialized_end=464, 87 | ) 88 | 89 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['brain_parameters'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_brain__parameters__proto__pb2._BRAINPARAMETERSPROTO 90 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['environment_parameters'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO 91 | DESCRIPTOR.message_types_by_name['UnityRLInitializationOutput'] = _UNITYRLINITIALIZATIONOUTPUT 92 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 93 | 94 | UnityRLInitializationOutput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationOutput', (_message.Message,), dict( 95 | DESCRIPTOR = _UNITYRLINITIALIZATIONOUTPUT, 96 | __module__ = 'mlagents.envs.communicator_objects.unity_rl_initialization_output_pb2' 97 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationOutput) 98 | )) 99 | _sym_db.RegisterMessage(UnityRLInitializationOutput) 100 | 101 | 102 | DESCRIPTOR.has_options = True 103 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 104 | # @@protoc_insertion_point(module_scope) 105 | -------------------------------------------------------------------------------- /ml-agents/tests/trainers/test_bc.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | import pytest 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | import yaml 7 | 8 | from mlagents.trainers.bc.models import BehavioralCloningModel 9 | from mlagents.trainers.bc.policy import BCPolicy 10 | from mlagents.envs import UnityEnvironment 11 | from tests.mock_communicator import MockCommunicator 12 | 13 | 14 | @pytest.fixture 15 | def dummy_config(): 16 | return yaml.load( 17 | ''' 18 | hidden_units: 128 19 | learning_rate: 3.0e-4 20 | num_layers: 2 21 | use_recurrent: false 22 | sequence_length: 32 23 | memory_size: 32 24 | ''') 25 | 26 | 27 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 28 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 29 | def test_bc_policy_evaluate(mock_communicator, mock_launcher): 30 | tf.reset_default_graph() 31 | mock_communicator.return_value = MockCommunicator( 32 | discrete_action=False, visual_inputs=0) 33 | env = UnityEnvironment(' ') 34 | brain_infos = env.reset() 35 | brain_info = brain_infos[env.brain_names[0]] 36 | 37 | trainer_parameters = dummy_config() 38 | model_path = env.brain_names[0] 39 | 
trainer_parameters['model_path'] = model_path 40 | trainer_parameters['keep_checkpoints'] = 3 41 | policy = BCPolicy(0, env.brains[env.brain_names[0]], trainer_parameters, False) 42 | run_out = policy.evaluate(brain_info) 43 | assert run_out['action'].shape == (3, 2) 44 | 45 | env.close() 46 | 47 | 48 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 49 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 50 | def test_cc_bc_model(mock_communicator, mock_launcher): 51 | tf.reset_default_graph() 52 | with tf.Session() as sess: 53 | with tf.variable_scope("FakeGraphScope"): 54 | mock_communicator.return_value = MockCommunicator( 55 | discrete_action=False, visual_inputs=0) 56 | env = UnityEnvironment(' ') 57 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 58 | init = tf.global_variables_initializer() 59 | sess.run(init) 60 | 61 | run_list = [model.sample_action, model.policy] 62 | feed_dict = {model.batch_size: 2, 63 | model.sequence_length: 1, 64 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 65 | [3, 4, 5, 3, 4, 5]])} 66 | sess.run(run_list, feed_dict=feed_dict) 67 | env.close() 68 | 69 | 70 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 71 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 72 | def test_dc_bc_model(mock_communicator, mock_launcher): 73 | tf.reset_default_graph() 74 | with tf.Session() as sess: 75 | with tf.variable_scope("FakeGraphScope"): 76 | mock_communicator.return_value = MockCommunicator( 77 | discrete_action=True, visual_inputs=0) 78 | env = UnityEnvironment(' ') 79 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 80 | init = tf.global_variables_initializer() 81 | sess.run(init) 82 | 83 | run_list = [model.sample_action, model.action_probs] 84 | feed_dict = {model.batch_size: 2, 85 | model.dropout_rate: 1.0, 86 | model.sequence_length: 1, 87 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 88 | [3, 4, 5, 3, 4, 5]]), 89 | model.action_masks: np.ones([2, 2])} 90 | sess.run(run_list, feed_dict=feed_dict) 91 | env.close() 92 | 93 | 94 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 95 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 96 | def test_visual_dc_bc_model(mock_communicator, mock_launcher): 97 | tf.reset_default_graph() 98 | with tf.Session() as sess: 99 | with tf.variable_scope("FakeGraphScope"): 100 | mock_communicator.return_value = MockCommunicator( 101 | discrete_action=True, visual_inputs=2) 102 | env = UnityEnvironment(' ') 103 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 104 | init = tf.global_variables_initializer() 105 | sess.run(init) 106 | 107 | run_list = [model.sample_action, model.action_probs] 108 | feed_dict = {model.batch_size: 2, 109 | model.dropout_rate: 1.0, 110 | model.sequence_length: 1, 111 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 112 | [3, 4, 5, 3, 4, 5]]), 113 | model.visual_in[0]: np.ones([2, 40, 30, 3]), 114 | model.visual_in[1]: np.ones([2, 40, 30, 3]), 115 | model.action_masks: np.ones([2, 2])} 116 | sess.run(run_list, feed_dict=feed_dict) 117 | env.close() 118 | 119 | 120 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 121 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 122 | def test_visual_cc_bc_model(mock_communicator, mock_launcher): 123 | tf.reset_default_graph() 124 | with tf.Session() as sess: 125 | with tf.variable_scope("FakeGraphScope"): 126 | mock_communicator.return_value = MockCommunicator( 127 | discrete_action=False, 
visual_inputs=2) 128 | env = UnityEnvironment(' ') 129 | model = BehavioralCloningModel(env.brains["RealFakeBrain"]) 130 | init = tf.global_variables_initializer() 131 | sess.run(init) 132 | 133 | run_list = [model.sample_action, model.policy] 134 | feed_dict = {model.batch_size: 2, 135 | model.sequence_length: 1, 136 | model.vector_in: np.array([[1, 2, 3, 1, 2, 3], 137 | [3, 4, 5, 3, 4, 5]]), 138 | model.visual_in[0]: np.ones([2, 40, 30, 3]), 139 | model.visual_in[1]: np.ones([2, 40, 30, 3])} 140 | sess.run(run_list, feed_dict=feed_dict) 141 | env.close() 142 | 143 | 144 | if __name__ == '__main__': 145 | pytest.main() 146 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/meta_curriculum.py: -------------------------------------------------------------------------------- 1 | """Contains the MetaCurriculum class.""" 2 | 3 | import os 4 | from mlagents.trainers.curriculum import Curriculum 5 | from mlagents.trainers.exception import MetaCurriculumError 6 | 7 | import logging 8 | 9 | logger = logging.getLogger('mlagents.trainers') 10 | 11 | 12 | class MetaCurriculum(object): 13 | """A MetaCurriculum holds curriculums. Each curriculum is associated with a 14 | particular brain in the environment. 15 | """ 16 | 17 | def __init__(self, curriculum_folder, default_reset_parameters): 18 | """Initializes a MetaCurriculum object. 19 | 20 | Args: 21 | curriculum_folder (str): The relative or absolute path of the 22 | folder which holds the curriculums for this environment. 23 | The folder should contain JSON files whose names are the 24 | brains that the curriculums belong to. 25 | default_reset_parameters (dict): The default reset parameters 26 | of the environment. 27 | """ 28 | used_reset_parameters = set() 29 | self._brains_to_curriculums = {} 30 | 31 | try: 32 | for curriculum_filename in os.listdir(curriculum_folder): 33 | brain_name = curriculum_filename.split('.')[0] 34 | curriculum_filepath = \ 35 | os.path.join(curriculum_folder, curriculum_filename) 36 | curriculum = Curriculum(curriculum_filepath, 37 | default_reset_parameters) 38 | 39 | # Check if any two curriculums use the same reset params. 40 | if any([(parameter in curriculum.get_config().keys()) 41 | for parameter in used_reset_parameters]): 42 | logger.warning('Two or more curriculums will ' 43 | 'attempt to change the same reset ' 44 | 'parameter. The result will be ' 45 | 'non-deterministic.') 46 | 47 | used_reset_parameters.update(curriculum.get_config().keys()) 48 | self._brains_to_curriculums[brain_name] = curriculum 49 | except NotADirectoryError: 50 | raise MetaCurriculumError(curriculum_folder + ' is not a ' 51 | 'directory. 
Refer to the ML-Agents ' 52 | 'curriculum learning docs.') 53 | 54 | 55 | @property 56 | def brains_to_curriculums(self): 57 | """A dict from brain_name to the brain's curriculum.""" 58 | return self._brains_to_curriculums 59 | 60 | @property 61 | def lesson_nums(self): 62 | """A dict from brain name to the brain's curriculum's lesson number.""" 63 | lesson_nums = {} 64 | for brain_name, curriculum in self.brains_to_curriculums.items(): 65 | lesson_nums[brain_name] = curriculum.lesson_num 66 | 67 | return lesson_nums 68 | 69 | @lesson_nums.setter 70 | def lesson_nums(self, lesson_nums): 71 | for brain_name, lesson in lesson_nums.items(): 72 | self.brains_to_curriculums[brain_name].lesson_num = lesson 73 | 74 | def _lesson_ready_to_increment(self, brain_name, reward_buff_size): 75 | """Determines whether the curriculum of a specified brain is ready 76 | to attempt an increment. 77 | 78 | Args: 79 | brain_name (str): The name of the brain whose curriculum will be 80 | checked for readiness. 81 | reward_buff_size (int): The size of the reward buffer of the trainer 82 | that corresponds to the specified brain. 83 | 84 | Returns: 85 | Whether the curriculum of the specified brain should attempt to 86 | increment its lesson. 87 | """ 88 | return reward_buff_size >= (self.brains_to_curriculums[brain_name] 89 | .min_lesson_length) 90 | 91 | def increment_lessons(self, measure_vals, reward_buff_sizes=None): 92 | """Attempts to increment all the lessons of all the curriculums in this 93 | MetaCurriculum. Note that calling this method does not guarantee the 94 | lesson of a curriculum will increment. The lesson of a curriculum will 95 | only increment if the specified measure threshold defined in the 96 | curriculum has been reached and the minimum number of episodes in the 97 | lesson have been completed. 98 | 99 | Args: 100 | measure_vals (dict): A dict of brain name to measure value. 101 | reward_buff_sizes (dict): A dict of brain names to the size of their 102 | corresponding reward buffers. 103 | 104 | Returns: 105 | A dict from brain name to whether that brain's lesson number was 106 | incremented. 107 | """ 108 | ret = {} 109 | if reward_buff_sizes: 110 | for brain_name, buff_size in reward_buff_sizes.items(): 111 | if self._lesson_ready_to_increment(brain_name, buff_size): 112 | measure_val = measure_vals[brain_name] 113 | ret[brain_name] = (self.brains_to_curriculums[brain_name] 114 | .increment_lesson(measure_val)) 115 | else: 116 | for brain_name, measure_val in measure_vals.items(): 117 | ret[brain_name] = (self.brains_to_curriculums[brain_name] 118 | .increment_lesson(measure_val)) 119 | return ret 120 | 121 | 122 | def set_all_curriculums_to_lesson_num(self, lesson_num): 123 | """Sets all the curriculums in this meta curriculum to a specified 124 | lesson number. 125 | 126 | Args: 127 | lesson_num (int): The lesson number which all the curriculums will 128 | be set to. 129 | """ 130 | for _, curriculum in self.brains_to_curriculums.items(): 131 | curriculum.lesson_num = lesson_num 132 | 133 | 134 | def get_config(self): 135 | """Get the combined configuration of all curriculums in this 136 | MetaCurriculum. 137 | 138 | Returns: 139 | A dict from parameter to value. 
140 | """ 141 | config = {} 142 | 143 | for _, curriculum in self.brains_to_curriculums.items(): 144 | curr_config = curriculum.get_config() 145 | config.update(curr_config) 146 | 147 | return config 148 | -------------------------------------------------------------------------------- /gym-unity/README.md: -------------------------------------------------------------------------------- 1 | # Unity ML-Agents Gym Wrapper 2 | 3 | A common way in which machine learning researchers interact with simulation 4 | environments is via a wrapper provided by OpenAI called `gym`. For more 5 | information on the gym interface, see [here](https://github.com/openai/gym). 6 | 7 | We provide a a gym wrapper, and instructions for using it with existing machine 8 | learning algorithms which utilize gyms. Both wrappers provide interfaces on top 9 | of our `UnityEnvironment` class, which is the default way of interfacing with a 10 | Unity environment via Python. 11 | 12 | ## Installation 13 | 14 | The gym wrapper can be installed using: 15 | 16 | ```sh 17 | pip install gym_unity 18 | ``` 19 | 20 | or by running the following from the `/gym-unity` directory of the repository: 21 | 22 | ```sh 23 | pip install . 24 | ``` 25 | 26 | ## Using the Gym Wrapper 27 | 28 | The gym interface is available from `gym_unity.envs`. To launch an environmnent 29 | from the root of the project repository use: 30 | 31 | ```python 32 | from gym_unity.envs import UnityEnv 33 | 34 | env = UnityEnv(environment_filename, worker_id, default_visual, multiagent) 35 | ``` 36 | 37 | * `environment_filename` refers to the path to the Unity environment. 38 | * `worker_id` refers to the port to use for communication with the environment. 39 | Defaults to `0`. 40 | * `use_visual` refers to whether to use visual observations (True) or vector 41 | observations (False) as the default observation provided by the `reset` and 42 | `step` functions. Defaults to `False`. 43 | * `multiagent` refers to whether you intent to launch an environment which 44 | contains more than one agent. Defaults to `False`. 45 | 46 | The returned environment `env` will function as a gym. 47 | 48 | For more on using the gym interface, see our 49 | [Jupyter Notebook tutorial](../notebooks/getting-started-gym.ipynb). 50 | 51 | ## Limitation 52 | 53 | * It is only possible to use an environment with a single Brain. 54 | * By default the first visual observation is provided as the `observation`, if 55 | present. Otherwise vector observations are provided. 56 | * All `BrainInfo` output from the environment can still be accessed from the 57 | `info` provided by `env.step(action)`. 58 | * Stacked vector observations are not supported. 59 | * Environment registration for use with `gym.make()` is currently not supported. 60 | 61 | ## Running OpenAI Baselines Algorithms 62 | 63 | OpenAI provides a set of open-source maintained and tested Reinforcement 64 | Learning algorithms called the [Baselines](https://github.com/openai/baselines). 65 | 66 | Using the provided Gym wrapper, it is possible to train ML-Agents environments 67 | using these algorithms. This requires the creation of custom training scripts to 68 | launch each algorithm. In most cases these scripts can be created by making 69 | slightly modifications to the ones provided for Atari and Mujoco environments. 
70 | 71 | ### Example - DQN Baseline 72 | 73 | In order to train an agent to play the `GridWorld` environment using the 74 | Baselines DQN algorithm, you first need to install the baselines package using 75 | pip: 76 | 77 | ``` 78 | pip install git+git://github.com/openai/baselines 79 | ``` 80 | 81 | Next, create a file called `train_unity.py`. Then create an `/envs/` directory 82 | and build the GridWorld environment to that directory. For more information on 83 | building Unity environments, see 84 | [here](../docs/Learning-Environment-Executable.md). Add the following code to 85 | the `train_unity.py` file: 86 | 87 | ```python 88 | import gym 89 | 90 | from baselines import deepq 91 | from gym_unity.envs import UnityEnv 92 | 93 | def main(): 94 | env = UnityEnv("./envs/GridWorld", 0, use_visual=True) 95 | act = deepq.learn( 96 | env, 97 | "mlp", 98 | lr=1e-3, 99 | total_timesteps=100000, 100 | buffer_size=50000, 101 | exploration_fraction=0.1, 102 | exploration_final_eps=0.02, 103 | print_freq=10 104 | ) 105 | print("Saving model to unity_model.pkl") 106 | act.save("unity_model.pkl") 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | ``` 112 | 113 | To start the training process, run the following from the root of the baselines 114 | repository: 115 | 116 | ```sh 117 | python -m train_unity 118 | ``` 119 | 120 | ### Other Algorithms 121 | 122 | Other algorithms in the Baselines repository can be run using scripts similar to 123 | the examples from the baselines package. In most cases, the primary changes needed 124 | to use a Unity environment are to import `UnityEnv`, and to replace the environment 125 | creation code, typically `gym.make()`, with a call to `UnityEnv(env_path)` 126 | passing the environment binary path. 127 | 128 | A typical rule of thumb is that for vision-based environments, modification 129 | should be done to Atari training scripts, and for vector observation 130 | environments, modification should be done to Mujoco scripts. 131 | 132 | Some algorithms will make use of `make_env()` or `make_mujoco_env()` 133 | functions. You can define a similar function for Unity environments. An example of 134 | such a method using the PPO2 baseline: 135 | 136 | ```python 137 | from gym_unity.envs import UnityEnv 138 | from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 139 | from baselines.bench import Monitor 140 | from baselines import logger 141 | import baselines.ppo2.ppo2 as ppo2 142 | 143 | import os 144 | 145 | try: 146 | from mpi4py import MPI 147 | except ImportError: 148 | MPI = None 149 | 150 | def make_unity_env(env_directory, num_env, visual, start_index=0): 151 | """ 152 | Create a wrapped, monitored Unity environment. 
153 | """ 154 | def make_env(rank, use_visual=True): # pylint: disable=C0111 155 | def _thunk(): 156 | env = UnityEnv(env_directory, rank, use_visual=use_visual) 157 | env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank))) 158 | return env 159 | return _thunk 160 | if visual: 161 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 162 | else: 163 | rank = MPI.COMM_WORLD.Get_rank() if MPI else 0 164 | return make_env(rank, use_visual=False) 165 | 166 | def main(): 167 | env = make_unity_env('./envs/GridWorld', 4, True) 168 | ppo2.learn( 169 | network="mlp", 170 | env=env, 171 | total_timesteps=100000, 172 | lr=1e-3, 173 | ) 174 | 175 | if __name__ == '__main__': 176 | main() 177 | ``` 178 | -------------------------------------------------------------------------------- /config/trainer_config.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | trainer: ppo 3 | batch_size: 1024 4 | beta: 5.0e-3 5 | buffer_size: 10240 6 | epsilon: 0.2 7 | gamma: 0.99 8 | hidden_units: 128 9 | lambd: 0.95 10 | learning_rate: 3.0e-4 11 | max_steps: 5.0e4 12 | memory_size: 256 13 | normalize: false 14 | num_epoch: 3 15 | num_layers: 2 16 | time_horizon: 64 17 | sequence_length: 64 18 | summary_freq: 1000 19 | use_recurrent: false 20 | use_curiosity: false 21 | curiosity_strength: 0.01 22 | curiosity_enc_size: 128 23 | 24 | AutoBenchBrain: 25 | trainer: ppo 26 | batch_size: 512 27 | beta: 1.0e-1 28 | buffer_size: 4096 29 | epsilon: 0.2 30 | gamma: 0.99 31 | hidden_units: 128 32 | lambd: 0.95 33 | learning_rate: 1.0e-4 34 | max_steps: 1.0e7 35 | memory_size: 256 36 | normalize: true 37 | num_epoch: 5 38 | num_layers: 2 39 | time_horizon: 512 40 | sequence_length: 64 41 | summary_freq: 1000 42 | use_recurrent: true 43 | use_curiosity: false 44 | curiosity_strength: 0.01 45 | curiosity_enc_size: 128 46 | 47 | BananaLearning: 48 | normalize: false 49 | batch_size: 1024 50 | beta: 5.0e-3 51 | buffer_size: 10240 52 | max_steps: 1.0e5 53 | 54 | BouncerLearning: 55 | normalize: true 56 | max_steps: 5.0e5 57 | num_layers: 2 58 | hidden_units: 64 59 | 60 | PushBlockLearning: 61 | max_steps: 5.0e4 62 | batch_size: 128 63 | buffer_size: 2048 64 | beta: 1.0e-2 65 | hidden_units: 256 66 | summary_freq: 2000 67 | time_horizon: 64 68 | num_layers: 2 69 | 70 | SmallWallJumpLearning: 71 | max_steps: 1.0e6 72 | batch_size: 128 73 | buffer_size: 2048 74 | beta: 5.0e-3 75 | hidden_units: 256 76 | summary_freq: 2000 77 | time_horizon: 128 78 | num_layers: 2 79 | normalize: false 80 | 81 | BigWallJumpLearning: 82 | max_steps: 1.0e6 83 | batch_size: 128 84 | buffer_size: 2048 85 | beta: 5.0e-3 86 | hidden_units: 256 87 | summary_freq: 2000 88 | time_horizon: 128 89 | num_layers: 2 90 | normalize: false 91 | 92 | StrikerLearning: 93 | max_steps: 5.0e5 94 | learning_rate: 1e-3 95 | batch_size: 128 96 | num_epoch: 3 97 | buffer_size: 2000 98 | beta: 1.0e-2 99 | hidden_units: 256 100 | summary_freq: 2000 101 | time_horizon: 128 102 | num_layers: 2 103 | normalize: false 104 | 105 | GoalieLearning: 106 | max_steps: 5.0e5 107 | learning_rate: 1e-3 108 | batch_size: 320 109 | num_epoch: 3 110 | buffer_size: 2000 111 | beta: 1.0e-2 112 | hidden_units: 256 113 | summary_freq: 2000 114 | time_horizon: 128 115 | num_layers: 2 116 | normalize: false 117 | 118 | PyramidsLearning: 119 | use_curiosity: true 120 | summary_freq: 2000 121 | curiosity_strength: 0.01 122 | curiosity_enc_size: 256 123 | time_horizon: 128 124 | batch_size: 128 125 | 
buffer_size: 2048 126 | hidden_units: 512 127 | num_layers: 2 128 | beta: 1.0e-2 129 | max_steps: 5.0e5 130 | num_epoch: 3 131 | 132 | VisualPyramidsLearning: 133 | use_curiosity: true 134 | curiosity_strength: 0.01 135 | curiosity_enc_size: 256 136 | time_horizon: 128 137 | batch_size: 64 138 | buffer_size: 2024 139 | hidden_units: 256 140 | num_layers: 1 141 | beta: 1.0e-2 142 | max_steps: 5.0e5 143 | num_epoch: 3 144 | 145 | 3DBallLearning: 146 | normalize: true 147 | batch_size: 64 148 | buffer_size: 12000 149 | summary_freq: 1000 150 | time_horizon: 1000 151 | lambd: 0.99 152 | gamma: 0.995 153 | beta: 0.001 154 | use_curiosity: true 155 | 156 | 3DBallHardLearning: 157 | normalize: true 158 | batch_size: 1200 159 | buffer_size: 12000 160 | summary_freq: 1000 161 | time_horizon: 1000 162 | max_steps: 5.0e5 163 | gamma: 0.995 164 | beta: 0.001 165 | 166 | TennisLearning: 167 | normalize: true 168 | max_steps: 2e5 169 | 170 | CrawlerStaticLearning: 171 | normalize: true 172 | num_epoch: 3 173 | time_horizon: 1000 174 | batch_size: 2024 175 | buffer_size: 20240 176 | gamma: 0.995 177 | max_steps: 1e6 178 | summary_freq: 3000 179 | num_layers: 3 180 | hidden_units: 512 181 | 182 | CrawlerDynamicLearning: 183 | normalize: true 184 | num_epoch: 3 185 | time_horizon: 1000 186 | batch_size: 2024 187 | buffer_size: 20240 188 | gamma: 0.995 189 | max_steps: 1e6 190 | summary_freq: 3000 191 | num_layers: 3 192 | hidden_units: 512 193 | 194 | WalkerLearning: 195 | normalize: true 196 | num_epoch: 3 197 | time_horizon: 1000 198 | batch_size: 2048 199 | buffer_size: 20480 200 | gamma: 0.995 201 | max_steps: 2e6 202 | summary_freq: 3000 203 | num_layers: 3 204 | hidden_units: 512 205 | 206 | ReacherLearning: 207 | normalize: true 208 | num_epoch: 3 209 | time_horizon: 1000 210 | batch_size: 2024 211 | buffer_size: 20240 212 | gamma: 0.995 213 | max_steps: 1e6 214 | summary_freq: 3000 215 | 216 | HallwayLearning: 217 | use_recurrent: true 218 | sequence_length: 64 219 | num_layers: 2 220 | hidden_units: 128 221 | memory_size: 256 222 | beta: 1.0e-2 223 | gamma: 0.99 224 | num_epoch: 3 225 | buffer_size: 1024 226 | batch_size: 128 227 | max_steps: 5.0e5 228 | summary_freq: 100 229 | time_horizon: 64 230 | 231 | VisualHallwayLearning: 232 | use_recurrent: true 233 | sequence_length: 64 234 | num_layers: 1 235 | hidden_units: 128 236 | memory_size: 256 237 | beta: 1.0e-2 238 | gamma: 0.99 239 | num_epoch: 3 240 | buffer_size: 1024 241 | batch_size: 64 242 | max_steps: 5.0e5 243 | summary_freq: 1000 244 | time_horizon: 64 245 | 246 | VisualPushBlockLearning: 247 | use_recurrent: true 248 | sequence_length: 32 249 | num_layers: 1 250 | hidden_units: 128 251 | memory_size: 256 252 | beta: 1.0e-2 253 | gamma: 0.99 254 | num_epoch: 3 255 | buffer_size: 1024 256 | batch_size: 64 257 | max_steps: 5.0e5 258 | summary_freq: 1000 259 | time_horizon: 64 260 | 261 | GridWorldLearning: 262 | batch_size: 32 263 | normalize: false 264 | num_layers: 1 265 | hidden_units: 256 266 | beta: 5.0e-3 267 | gamma: 0.9 268 | buffer_size: 256 269 | max_steps: 5.0e5 270 | summary_freq: 2000 271 | time_horizon: 5 272 | 273 | BasicLearning: 274 | batch_size: 32 275 | normalize: false 276 | num_layers: 1 277 | hidden_units: 20 278 | beta: 5.0e-3 279 | gamma: 0.9 280 | buffer_size: 256 281 | max_steps: 5.0e5 282 | summary_freq: 2000 283 | time_horizon: 3 284 | -------------------------------------------------------------------------------- /ml-agents/mlagents/trainers/bc/online_trainer.py: 
-------------------------------------------------------------------------------- 1 | # # Unity ML-Agents Toolkit 2 | # ## ML-Agent Learning (Behavioral Cloning) 3 | # Contains an implementation of the Behavioral Cloning algorithm 4 | 5 | import logging 6 | import numpy as np 7 | 8 | from mlagents.envs import AllBrainInfo 9 | from mlagents.trainers.bc.trainer import BCTrainer 10 | 11 | logger = logging.getLogger("mlagents.trainers") 12 | 13 | 14 | class OnlineBCTrainer(BCTrainer): 15 | """The OnlineBCTrainer is an implementation of Online Behavioral Cloning.""" 16 | 17 | def __init__(self, brain, trainer_parameters, training, load, seed, run_id): 18 | """ 19 | Responsible for collecting experiences and training the behavioral cloning model for the given brain. 20 | :param trainer_parameters: The parameters for the trainer (dictionary). 21 | :param training: Whether the trainer is set for training. 22 | :param load: Whether the model should be loaded. 23 | :param seed: The seed the model will be initialized with. 24 | :param run_id: The identifier of the current run. 25 | """ 26 | super(OnlineBCTrainer, self).__init__(brain, trainer_parameters, training, load, seed, 27 | run_id) 28 | 29 | self.param_keys = ['brain_to_imitate', 'batch_size', 'time_horizon', 30 | 'summary_freq', 'max_steps', 31 | 'batches_per_epoch', 'use_recurrent', 32 | 'hidden_units', 'learning_rate', 'num_layers', 33 | 'sequence_length', 'memory_size', 'model_path'] 34 | 35 | self.check_param_keys() 36 | self.brain_to_imitate = trainer_parameters['brain_to_imitate'] 37 | self.batches_per_epoch = trainer_parameters['batches_per_epoch'] 38 | self.n_sequences = max(int(trainer_parameters['batch_size'] / self.policy.sequence_length), 39 | 1) 40 | 41 | def __str__(self): 42 | return '''Hyperparameters for the Imitation Trainer of brain {0}: \n{1}'''.format( 43 | self.brain_name, '\n'.join( 44 | ['\t{0}:\t{1}'.format(x, self.trainer_parameters[x]) for x in self.param_keys])) 45 | 46 | def add_experiences(self, curr_info: AllBrainInfo, next_info: AllBrainInfo, 47 | take_action_outputs): 48 | """ 49 | Adds experiences to each agent's experience history. 50 | :param curr_info: Current AllBrainInfo (Dictionary of all current brains and corresponding BrainInfo). 51 | :param next_info: Next AllBrainInfo (Dictionary of all current brains and corresponding BrainInfo). 52 | :param take_action_outputs: The outputs of the take action method. 53 | """ 54 | 55 | # Used to collect teacher experience into training buffer 56 | info_teacher = curr_info[self.brain_to_imitate] 57 | next_info_teacher = next_info[self.brain_to_imitate] 58 | for agent_id in info_teacher.agents: 59 | self.demonstration_buffer[agent_id].last_brain_info = info_teacher 60 | 61 | for agent_id in next_info_teacher.agents: 62 | stored_info_teacher = self.demonstration_buffer[agent_id].last_brain_info 63 | if stored_info_teacher is None: 64 | continue 65 | else: 66 | idx = stored_info_teacher.agents.index(agent_id) 67 | next_idx = next_info_teacher.agents.index(agent_id) 68 | if stored_info_teacher.text_observations[idx] != "": 69 | info_teacher_record, info_teacher_reset = \ 70 | stored_info_teacher.text_observations[idx].lower().split(",") 71 | next_info_teacher_record, next_info_teacher_reset = \ 72 | next_info_teacher.text_observations[idx]. 
\ 73 | lower().split(",") 74 | if next_info_teacher_reset == "true": 75 | self.demonstration_buffer.reset_update_buffer() 76 | else: 77 | info_teacher_record, next_info_teacher_record = "true", "true" 78 | if info_teacher_record == "true" and next_info_teacher_record == "true": 79 | if not stored_info_teacher.local_done[idx]: 80 | for i in range(self.policy.vis_obs_size): 81 | self.demonstration_buffer[agent_id]['visual_obs%d' % i] \ 82 | .append(stored_info_teacher.visual_observations[i][idx]) 83 | if self.policy.use_vec_obs: 84 | self.demonstration_buffer[agent_id]['vector_obs'] \ 85 | .append(stored_info_teacher.vector_observations[idx]) 86 | if self.policy.use_recurrent: 87 | if stored_info_teacher.memories.shape[1] == 0: 88 | stored_info_teacher.memories = np.zeros( 89 | (len(stored_info_teacher.agents), 90 | self.policy.m_size)) 91 | self.demonstration_buffer[agent_id]['memory'].append( 92 | stored_info_teacher.memories[idx]) 93 | self.demonstration_buffer[agent_id]['actions'].append( 94 | next_info_teacher.previous_vector_actions[next_idx]) 95 | 96 | super(OnlineBCTrainer, self).add_experiences(curr_info, next_info, take_action_outputs) 97 | 98 | def process_experiences(self, current_info: AllBrainInfo, next_info: AllBrainInfo): 99 | """ 100 | Checks agent histories for processing condition, and processes them as necessary. 101 | Processing involves appending the agent's demonstration history to the update buffer. 102 | :param current_info: Current AllBrainInfo 103 | :param next_info: Next AllBrainInfo 104 | """ 105 | info_teacher = next_info[self.brain_to_imitate] 106 | for l in range(len(info_teacher.agents)): 107 | num_teacher_actions = len(self.demonstration_buffer[info_teacher.agents[l]]['actions']) 108 | horizon_reached = num_teacher_actions > self.trainer_parameters['time_horizon'] 109 | teacher_filled = num_teacher_actions > 0 110 | if (info_teacher.local_done[l] or horizon_reached) and teacher_filled: 111 | agent_id = info_teacher.agents[l] 112 | self.demonstration_buffer.append_update_buffer( 113 | agent_id, batch_size=None, training_length=self.policy.sequence_length) 114 | self.demonstration_buffer[agent_id].reset_agent() 115 | 116 | super(OnlineBCTrainer, self).process_experiences(current_info, next_info) 117 | --------------------------------------------------------------------------------