├── test_score_plot.png ├── best_model.checkpoint ├── training_score_plot.png ├── training_100avgscore_plot.png ├── install_requirements ├── unityagents │ ├── __init__.py │ ├── communicator.py │ ├── exception.py │ ├── brain.py │ ├── rpc_communicator.py │ ├── socket_communicator.py │ ├── curriculum.py │ └── environment.py ├── requirements.txt ├── communicator_objects │ ├── __init__.py │ ├── unity_to_external_pb2_grpc.py │ ├── command_proto_pb2.py │ ├── unity_to_external_pb2.py │ ├── space_type_proto_pb2.py │ ├── brain_type_proto_pb2.py │ ├── unity_rl_initialization_input_pb2.py │ ├── header_pb2.py │ ├── resolution_proto_pb2.py │ ├── agent_action_proto_pb2.py │ ├── unity_input_pb2.py │ ├── unity_output_pb2.py │ ├── unity_message_pb2.py │ ├── engine_configuration_proto_pb2.py │ ├── environment_parameters_proto_pb2.py │ ├── unity_rl_initialization_output_pb2.py │ ├── agent_info_proto_pb2.py │ ├── unity_rl_output_pb2.py │ ├── brain_parameters_proto_pb2.py │ └── unity_rl_input_pb2.py └── setup.py ├── utils.py ├── LICENSE ├── .gitignore ├── test.py ├── models.py ├── README.md ├── train.py ├── Report.md └── agent.py /test_score_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaochen/DDPG_MultiAgent/HEAD/test_score_plot.png -------------------------------------------------------------------------------- /best_model.checkpoint: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaochen/DDPG_MultiAgent/HEAD/best_model.checkpoint -------------------------------------------------------------------------------- /training_score_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaochen/DDPG_MultiAgent/HEAD/training_score_plot.png -------------------------------------------------------------------------------- /training_100avgscore_plot.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaochen/DDPG_MultiAgent/HEAD/training_100avgscore_plot.png -------------------------------------------------------------------------------- /install_requirements/unityagents/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import * 2 | from .brain import * 3 | from .exception import * 4 | from .curriculum import * 5 | -------------------------------------------------------------------------------- /install_requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.7.1 2 | Pillow>=4.2.1 3 | matplotlib 4 | numpy>=1.11.0 5 | jupyter 6 | pytest>=3.2.2 7 | docopt 8 | pyyaml 9 | protobuf==3.5.2 10 | grpcio==1.11.0 11 | torch==0.4.0 12 | pandas 13 | scipy 14 | ipykernel 15 | tqdm 16 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | def draw(scores, path="fig.png", title="Performance", xlabel="Episode #", ylabel="Score"): 6 | fig = plt.figure() 7 | ax = fig.add_subplot(111) 8 | plt.title(title) 9 | plt.plot(np.arange(len(scores)), scores) 10 | plt.ylabel(ylabel) 11 | plt.xlabel(xlabel) 12 | plt.savefig(path) 13 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent_action_proto_pb2 import * 2 | from .agent_info_proto_pb2 import * 3 | from .brain_parameters_proto_pb2 import * 4 | from .brain_type_proto_pb2 import * 5 | from .command_proto_pb2 import * 6 | from .engine_configuration_proto_pb2 import * 7 | from 
.environment_parameters_proto_pb2 import * 8 | from .header_pb2 import * 9 | from .resolution_proto_pb2 import * 10 | from .space_type_proto_pb2 import * 11 | from .unity_input_pb2 import * 12 | from .unity_message_pb2 import * 13 | from .unity_output_pb2 import * 14 | from .unity_rl_initialization_input_pb2 import * 15 | from .unity_rl_initialization_output_pb2 import * 16 | from .unity_rl_input_pb2 import * 17 | from .unity_rl_output_pb2 import * 18 | from .unity_to_external_pb2 import * 19 | from .unity_to_external_pb2_grpc import * 20 | -------------------------------------------------------------------------------- /install_requirements/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, Command, find_packages 4 | 5 | 6 | with open('requirements.txt') as f: 7 | required = f.read().splitlines() 8 | 9 | setup(name='unityagents', 10 | version='0.4.0', 11 | description='Unity Machine Learning Agents', 12 | license='Apache License 2.0', 13 | author='Unity Technologies', 14 | author_email='ML-Agents@unity3d.com', 15 | url='https://github.com/Unity-Technologies/ml-agents', 16 | packages=find_packages(), 17 | install_requires = required, 18 | long_description= ("Unity Machine Learning Agents allows researchers and developers " 19 | "to transform games and simulations created using the Unity Editor into environments " 20 | "where intelligent agents can be trained using reinforcement learning, evolutionary " 21 | "strategies, or other machine learning methods through a simple to use Python API.") 22 | ) 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 
2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /install_requirements/unityagents/communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from communicator_objects import UnityOutput, UnityInput 4 | 5 | logging.basicConfig(level=logging.INFO) 6 | logger = logging.getLogger("unityagents") 7 | 8 | 9 | class Communicator(object): 10 | def __init__(self, worker_id=0, 11 | base_port=5005): 12 | """ 13 | Python side of the communication. Must be used in pair with the right Unity Communicator equivalent. 14 | 15 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 
16 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 17 | """ 18 | 19 | def initialize(self, inputs: UnityInput) -> UnityOutput: 20 | """ 21 | Used to exchange initialization parameters between Python and the Environment 22 | :param inputs: The initialization input that will be sent to the environment. 23 | :return: UnityOutput: The initialization output sent by Unity 24 | """ 25 | 26 | def exchange(self, inputs: UnityInput) -> UnityOutput: 27 | """ 28 | Used to send an input and receive an output from the Environment 29 | :param inputs: The UnityInput that needs to be sent the Environment 30 | :return: The UnityOutputs generated by the Environment 31 | """ 32 | 33 | def close(self): 34 | """ 35 | Sends a shutdown signal to the unity environment, and closes the connection. 36 | """ 37 | 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /install_requirements/unityagents/exception.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger("unityagents") 3 | 4 | class UnityException(Exception): 5 | """ 6 | Any error related to ml-agents environment. 7 | """ 8 | pass 9 | 10 | class UnityEnvironmentException(UnityException): 11 | """ 12 | Related to errors starting and closing environment. 13 | """ 14 | pass 15 | 16 | 17 | class UnityActionException(UnityException): 18 | """ 19 | Related to errors with sending actions. 20 | """ 21 | pass 22 | 23 | class UnityTimeOutException(UnityException): 24 | """ 25 | Related to errors with communication timeouts. 
26 | """ 27 | def __init__(self, message, log_file_path = None): 28 | if log_file_path is not None: 29 | try: 30 | with open(log_file_path, "r") as f: 31 | printing = False 32 | unity_error = '\n' 33 | for l in f: 34 | l=l.strip() 35 | if (l == 'Exception') or (l=='Error'): 36 | printing = True 37 | unity_error += '----------------------\n' 38 | if (l == ''): 39 | printing = False 40 | if printing: 41 | unity_error += l + '\n' 42 | logger.info(unity_error) 43 | logger.error("An error might have occured in the environment. " 44 | "You can check the logfile for more information at {}".format(log_file_path)) 45 | except: 46 | logger.error("An error might have occured in the environment. " 47 | "No unity-environment.log file could be found.") 48 | super(UnityTimeOutException, self).__init__(message) 49 | 50 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/unity_to_external_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | import grpc 3 | 4 | from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2 5 | 6 | 7 | class UnityToExternalStub(object): 8 | # missing associated documentation comment in .proto file 9 | pass 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 
16 | """ 17 | self.Exchange = channel.unary_unary( 18 | '/communicator_objects.UnityToExternal/Exchange', 19 | request_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 20 | response_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 21 | ) 22 | 23 | 24 | class UnityToExternalServicer(object): 25 | # missing associated documentation comment in .proto file 26 | pass 27 | 28 | def Exchange(self, request, context): 29 | """Sends the academy parameters 30 | """ 31 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 32 | context.set_details('Method not implemented!') 33 | raise NotImplementedError('Method not implemented!') 34 | 35 | 36 | def add_UnityToExternalServicer_to_server(servicer, server): 37 | rpc_method_handlers = { 38 | 'Exchange': grpc.unary_unary_rpc_method_handler( 39 | servicer.Exchange, 40 | request_deserializer=communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 41 | response_serializer=communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 42 | ), 43 | } 44 | generic_handler = grpc.method_handlers_generic_handler( 45 | 'communicator_objects.UnityToExternal', rpc_method_handlers) 46 | server.add_generic_rpc_handlers((generic_handler,)) 47 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/command_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/command_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/command_proto.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n(communicator_objects/command_proto.proto\x12\x14\x63ommunicator_objects*-\n\x0c\x43ommandProto\x12\x08\n\x04STEP\x10\x00\x12\t\n\x05RESET\x10\x01\x12\x08\n\x04QUIT\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | _COMMANDPROTO = _descriptor.EnumDescriptor( 27 | name='CommandProto', 28 | full_name='communicator_objects.CommandProto', 29 | filename=None, 30 | file=DESCRIPTOR, 31 | values=[ 32 | _descriptor.EnumValueDescriptor( 33 | name='STEP', index=0, number=0, 34 | options=None, 35 | type=None), 36 | _descriptor.EnumValueDescriptor( 37 | name='RESET', index=1, number=1, 38 | options=None, 39 | type=None), 40 | _descriptor.EnumValueDescriptor( 41 | name='QUIT', index=2, number=2, 42 | options=None, 43 | type=None), 44 | ], 45 | containing_type=None, 46 | options=None, 47 | serialized_start=66, 48 | serialized_end=111, 49 | ) 50 | _sym_db.RegisterEnumDescriptor(_COMMANDPROTO) 51 | 52 | CommandProto = enum_type_wrapper.EnumTypeWrapper(_COMMANDPROTO) 53 | STEP = 0 54 | RESET = 1 55 | QUIT = 2 56 | 57 | 58 | DESCRIPTOR.enum_types_by_name['CommandProto'] = _COMMANDPROTO 59 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 | 61 | 62 | DESCRIPTOR.has_options = 
True 63 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from unityagents import UnityEnvironment 2 | import numpy as np 3 | from tqdm import tqdm 4 | from agent import MADDPG 5 | from utils import draw 6 | 7 | unity_environment_path = "./Tennis_Linux/Tennis.x86_64" 8 | best_model_path = "./best_model.checkpoint" 9 | 10 | if __name__ == "__main__": 11 | # prepare environment 12 | env = UnityEnvironment(file_name=unity_environment_path) 13 | brain_name = env.brain_names[0] 14 | brain = env.brains[brain_name] 15 | env_info = env.reset(train_mode=True)[brain_name] 16 | 17 | num_agents = len(env_info.agents) 18 | print('Number of agents:', num_agents) 19 | 20 | # dim of each action 21 | action_size = brain.vector_action_space_size 22 | print('Size of each action:', action_size) 23 | 24 | # dim of the state space 25 | states = env_info.vector_observations 26 | state_size = states.shape[1] 27 | 28 | agent = MADDPG(state_size, action_size) 29 | 30 | agent.load(best_model_path) 31 | 32 | test_scores = [] 33 | for i_episode in tqdm(range(1, 101)): 34 | scores = np.zeros(num_agents) # initialize the scores 35 | env_info = env.reset(train_mode=True)[brain_name] # reset the environment 36 | states = env_info.vector_observations # get the current states 37 | dones = [False]*num_agents 38 | while not np.any(dones): 39 | actions = agent.act(states) # select actions 40 | env_info = env.step(actions)[brain_name] # send the actions to the environment 41 | next_states = env_info.vector_observations # get the next states 42 | rewards = env_info.rewards # get the rewards 43 | dones = env_info.local_done # see if episode has finished 44 | scores += rewards # update the scores 45 | 
states = next_states # roll over the states to next time step 46 | 47 | test_scores.append(np.max(scores)) 48 | 49 | avg_score = sum(test_scores)/len(test_scores) 50 | print("Test Score: {}".format(avg_score)) 51 | draw(test_scores, "./test_score_plot.png", "Test Scores of 100 Episodes (Avg. score {})".format(avg_score)) 52 | env.close() 53 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/unity_to_external_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_to_external.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_message_pb2 as communicator__objects_dot_unity__message__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/unity_to_external.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n,communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a(communicator_objects/unity_message.proto2g\n\x0fUnityToExternal\x12T\n\x08\x45xchange\x12\".communicator_objects.UnityMessage\x1a\".communicator_objects.UnityMessage\"\x00\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | , 25 | dependencies=[communicator__objects_dot_unity__message__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 30 | 31 | 32 | DESCRIPTOR.has_options = True 33 | 
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 34 | 35 | _UNITYTOEXTERNAL = _descriptor.ServiceDescriptor( 36 | name='UnityToExternal', 37 | full_name='communicator_objects.UnityToExternal', 38 | file=DESCRIPTOR, 39 | index=0, 40 | options=None, 41 | serialized_start=112, 42 | serialized_end=215, 43 | methods=[ 44 | _descriptor.MethodDescriptor( 45 | name='Exchange', 46 | full_name='communicator_objects.UnityToExternal.Exchange', 47 | index=0, 48 | containing_service=None, 49 | input_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 50 | output_type=communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 51 | options=None, 52 | ), 53 | ]) 54 | _sym_db.RegisterServiceDescriptor(_UNITYTOEXTERNAL) 55 | 56 | DESCRIPTOR.services_by_name['UnityToExternal'] = _UNITYTOEXTERNAL 57 | 58 | # @@protoc_insertion_point(module_scope) 59 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/space_type_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/space_type_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/space_type_proto.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n+communicator_objects/space_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*.\n\x0eSpaceTypeProto\x12\x0c\n\x08\x64iscrete\x10\x00\x12\x0e\n\ncontinuous\x10\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) 27 | 28 | _SPACETYPEPROTO = _descriptor.EnumDescriptor( 29 | name='SpaceTypeProto', 30 | full_name='communicator_objects.SpaceTypeProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | values=[ 34 | _descriptor.EnumValueDescriptor( 35 | name='discrete', index=0, number=0, 36 | options=None, 37 | type=None), 38 | _descriptor.EnumValueDescriptor( 39 | name='continuous', index=1, number=1, 40 | options=None, 41 | type=None), 42 | ], 43 | containing_type=None, 44 | options=None, 45 | serialized_start=114, 46 | serialized_end=160, 47 | ) 48 | _sym_db.RegisterEnumDescriptor(_SPACETYPEPROTO) 49 | 50 | SpaceTypeProto = enum_type_wrapper.EnumTypeWrapper(_SPACETYPEPROTO) 51 | discrete = 0 52 | continuous = 1 53 | 54 | 55 | 
DESCRIPTOR.enum_types_by_name['SpaceTypeProto'] = _SPACETYPEPROTO 56 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 57 | 58 | 59 | DESCRIPTOR.has_options = True 60 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 61 | # @@protoc_insertion_point(module_scope) 62 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/brain_type_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/brain_type_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf.internal import enum_type_wrapper 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | from google.protobuf import descriptor_pb2 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/brain_type_proto.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n+communicator_objects/brain_type_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto*G\n\x0e\x42rainTypeProto\x12\n\n\x06Player\x10\x00\x12\r\n\tHeuristic\x10\x01\x12\x0c\n\x08\x45xternal\x10\x02\x12\x0c\n\x08Internal\x10\x03\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) 27 | 28 | _BRAINTYPEPROTO = 
_descriptor.EnumDescriptor( 29 | name='BrainTypeProto', 30 | full_name='communicator_objects.BrainTypeProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | values=[ 34 | _descriptor.EnumValueDescriptor( 35 | name='Player', index=0, number=0, 36 | options=None, 37 | type=None), 38 | _descriptor.EnumValueDescriptor( 39 | name='Heuristic', index=1, number=1, 40 | options=None, 41 | type=None), 42 | _descriptor.EnumValueDescriptor( 43 | name='External', index=2, number=2, 44 | options=None, 45 | type=None), 46 | _descriptor.EnumValueDescriptor( 47 | name='Internal', index=3, number=3, 48 | options=None, 49 | type=None), 50 | ], 51 | containing_type=None, 52 | options=None, 53 | serialized_start=114, 54 | serialized_end=185, 55 | ) 56 | _sym_db.RegisterEnumDescriptor(_BRAINTYPEPROTO) 57 | 58 | BrainTypeProto = enum_type_wrapper.EnumTypeWrapper(_BRAINTYPEPROTO) 59 | Player = 0 60 | Heuristic = 1 61 | External = 2 62 | Internal = 3 63 | 64 | 65 | DESCRIPTOR.enum_types_by_name['BrainTypeProto'] = _BRAINTYPEPROTO 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | 69 | DESCRIPTOR.has_options = True 70 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 71 | # @@protoc_insertion_point(module_scope) 72 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/unity_rl_initialization_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_rl_initialization_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/unity_rl_initialization_input.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n8communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\"*\n\x1aUnityRLInitializationInput\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _UNITYRLINITIALIZATIONINPUT = _descriptor.Descriptor( 29 | name='UnityRLInitializationInput', 30 | full_name='communicator_objects.UnityRLInitializationInput', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='seed', full_name='communicator_objects.UnityRLInitializationInput.seed', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | ], 43 | extensions=[ 44 | ], 45 | nested_types=[], 46 | enum_types=[ 47 | ], 48 | options=None, 49 | is_extendable=False, 50 | syntax='proto3', 51 | extension_ranges=[], 52 | oneofs=[ 53 | ], 54 | serialized_start=82, 55 | serialized_end=124, 56 | ) 57 | 58 | DESCRIPTOR.message_types_by_name['UnityRLInitializationInput'] = _UNITYRLINITIALIZATIONINPUT 59 | 
_sym_db.RegisterFileDescriptor(DESCRIPTOR) 60 | 61 | UnityRLInitializationInput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationInput', (_message.Message,), dict( 62 | DESCRIPTOR = _UNITYRLINITIALIZATIONINPUT, 63 | __module__ = 'communicator_objects.unity_rl_initialization_input_pb2' 64 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationInput) 65 | )) 66 | _sym_db.RegisterMessage(UnityRLInitializationInput) 67 | 68 | 69 | DESCRIPTOR.has_options = True 70 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 71 | # @@protoc_insertion_point(module_scope) 72 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/header_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/header.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/header.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n!communicator_objects/header.proto\x12\x14\x63ommunicator_objects\")\n\x06Header\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _HEADER = _descriptor.Descriptor( 29 | name='Header', 30 | 
full_name='communicator_objects.Header', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='status', full_name='communicator_objects.Header.status', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='message', full_name='communicator_objects.Header.message', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=None, 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=59, 62 | serialized_end=100, 63 | ) 64 | 65 | DESCRIPTOR.message_types_by_name['Header'] = _HEADER 66 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 67 | 68 | Header = _reflection.GeneratedProtocolMessageType('Header', (_message.Message,), dict( 69 | DESCRIPTOR = _HEADER, 70 | __module__ = 'communicator_objects.header_pb2' 71 | # @@protoc_insertion_point(class_scope:communicator_objects.Header) 72 | )) 73 | _sym_db.RegisterMessage(Header) 74 | 75 | 76 | DESCRIPTOR.has_options = True 77 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 78 | # @@protoc_insertion_point(module_scope) 79 | -------------------------------------------------------------------------------- /install_requirements/unityagents/brain.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | class BrainInfo: 5 | 
def __init__(self, visual_observation, vector_observation, text_observations, memory=None, 6 | reward=None, agents=None, local_done=None, 7 | vector_action=None, text_action=None, max_reached=None): 8 | """ 9 | Describes experience at current step of all agents linked to a brain. 10 | """ 11 | self.visual_observations = visual_observation 12 | self.vector_observations = vector_observation 13 | self.text_observations = text_observations 14 | self.memories = memory 15 | self.rewards = reward 16 | self.local_done = local_done 17 | self.max_reached = max_reached 18 | self.agents = agents 19 | self.previous_vector_actions = vector_action 20 | self.previous_text_actions = text_action 21 | 22 | 23 | AllBrainInfo = Dict[str, BrainInfo] 24 | 25 | 26 | class BrainParameters: 27 | def __init__(self, brain_name, brain_param): 28 | """ 29 | Contains all brain-specific parameters. 30 | :param brain_name: Name of brain. 31 | :param brain_param: Dictionary of brain parameters. 32 | """ 33 | self.brain_name = brain_name 34 | self.vector_observation_space_size = brain_param["vectorObservationSize"] 35 | self.num_stacked_vector_observations = brain_param["numStackedVectorObservations"] 36 | self.number_visual_observations = len(brain_param["cameraResolutions"]) 37 | self.camera_resolutions = brain_param["cameraResolutions"] 38 | self.vector_action_space_size = brain_param["vectorActionSize"] 39 | self.vector_action_descriptions = brain_param["vectorActionDescriptions"] 40 | self.vector_action_space_type = ["discrete", "continuous"][brain_param["vectorActionSpaceType"]] 41 | self.vector_observation_space_type = ["discrete", "continuous"][brain_param["vectorObservationSpaceType"]] 42 | 43 | def __str__(self): 44 | return '''Unity brain name: {0} 45 | Number of Visual Observations (per agent): {1} 46 | Vector Observation space type: {2} 47 | Vector Observation space size (per agent): {3} 48 | Number of stacked Vector Observation: {4} 49 | Vector Action space type: {5} 50 | Vector Action 
space size (per agent): {6} 51 | Vector Action descriptions: {7}'''.format(self.brain_name, 52 | str(self.number_visual_observations), 53 | self.vector_observation_space_type, 54 | str(self.vector_observation_space_size), 55 | str(self.num_stacked_vector_observations), 56 | self.vector_action_space_type, 57 | str(self.vector_action_space_size), 58 | ', '.join(self.vector_action_descriptions)) 59 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torchsummary import summary 6 | 7 | torch.manual_seed(999) 8 | 9 | def hidden_init(layer): 10 | fan_in = layer.weight.data.size()[0] 11 | lim = 1. / np.sqrt(fan_in) 12 | return (-lim, lim) 13 | 14 | class ActorNetwork(nn.Module): 15 | """ 16 | Actor (Policy) Network. 17 | """ 18 | 19 | def __init__(self, state_dim, action_dim): 20 | """Initialize parameters and build model. 21 | :state_dim (int): Dimension of each state 22 | :action_dim (int): Dimension of each action 23 | """ 24 | super(ActorNetwork, self).__init__() 25 | self.fc1 = nn.Linear(state_dim, 64) 26 | self.fc2 = nn.Linear(64, 128) 27 | self.fc3 = nn.Linear(128, action_dim) 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | """ 32 | Initialize parameters 33 | """ 34 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1)) 35 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 36 | self.fc3.weight.data.uniform_(-3e-3, 3e-3) 37 | 38 | def forward(self, x): 39 | """ 40 | Maps a state to actions 41 | """ 42 | x = F.relu(self.fc1(x)) 43 | x = F.relu(self.fc2(x)) 44 | return F.tanh(self.fc3(x)) 45 | 46 | 47 | class CriticNetwork(nn.Module): 48 | """ 49 | Critic (State-Value) Network. 
50 | """ 51 | 52 | def __init__(self, state_dim, action_dim): 53 | """ 54 | Initialize parameters and build model 55 | :state_dim (int): Dimension of each state 56 | :action_dim (int): Dimension of each action 57 | """ 58 | super(CriticNetwork, self).__init__() 59 | self.state_fc = nn.Linear(state_dim, 64) 60 | self.fc1 = nn.Linear(action_dim+64, 128) 61 | self.fc2 = nn.Linear(128, 64) 62 | self.fc3 = nn.Linear(64, 1) 63 | self.reset_parameters() 64 | 65 | def reset_parameters(self): 66 | """ 67 | Initialize parameters 68 | """ 69 | self.state_fc.weight.data.uniform_(*hidden_init(self.state_fc)) 70 | self.fc1.weight.data.uniform_(*hidden_init(self.fc1)) 71 | self.fc2.weight.data.uniform_(*hidden_init(self.fc2)) 72 | self.fc3.weight.data.uniform_(-3e-3, 3e-3) 73 | 74 | def forward(self, state, action): 75 | """ 76 | Maps a state-action pair to Q-values 77 | """ 78 | state, action = state.squeeze(), action.squeeze() 79 | x = F.relu(self.state_fc(state)) 80 | x = torch.cat((x, action), dim=1) 81 | x = F.relu(self.fc1(x)) 82 | x = F.relu(self.fc2(x)) 83 | return self.fc3(x) 84 | 85 | if __name__ == "__main__": 86 | # summarize network structures using torchsummary 87 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 88 | state_dim, action_dim = 24, 2 89 | actor = ActorNetwork(state_dim, action_dim).to(device) 90 | critic = CriticNetwork(state_dim, action_dim).to(device) 91 | sum_res = summary(actor, (1, state_dim)) 92 | sum_res = summary(critic, [(1, state_dim), (1, action_dim)]) 93 | 94 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/resolution_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/resolution_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/resolution_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n+communicator_objects/resolution_proto.proto\x12\x14\x63ommunicator_objects\"D\n\x0fResolutionProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x12\n\ngray_scale\x18\x03 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _RESOLUTIONPROTO = _descriptor.Descriptor( 29 | name='ResolutionProto', 30 | full_name='communicator_objects.ResolutionProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='width', full_name='communicator_objects.ResolutionProto.width', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='communicator_objects.ResolutionProto.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 
| name='gray_scale', full_name='communicator_objects.ResolutionProto.gray_scale', index=2, 51 | number=3, type=8, cpp_type=7, label=1, 52 | has_default_value=False, default_value=False, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto3', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=69, 69 | serialized_end=137, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['ResolutionProto'] = _RESOLUTIONPROTO 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | ResolutionProto = _reflection.GeneratedProtocolMessageType('ResolutionProto', (_message.Message,), dict( 76 | DESCRIPTOR = _RESOLUTIONPROTO, 77 | __module__ = 'communicator_objects.resolution_proto_pb2' 78 | # @@protoc_insertion_point(class_scope:communicator_objects.ResolutionProto) 79 | )) 80 | _sym_db.RegisterMessage(ResolutionProto) 81 | 82 | 83 | DESCRIPTOR.has_options = True 84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Project Details 2 | --- 3 | In this project, two Actor-Critic agents were trained using Deep Deterministic Policy Gradients (DDPG) to play the tennis game: 4 | 5 | ![](https://user-images.githubusercontent.com/10624937/42135623-e770e354-7d12-11e8-998d-29fc74429ca2.gif) 6 | 7 | ## The Environment 8 | The environment for this project involves two agents controlling rackets to bounce a ball over a net. 
For this project, I use the environment from **Udacity**. The links to modules for different system environments are copied here for convenience:
python train.py
2 | # source: communicator_objects/agent_action_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/agent_action_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n-communicator_objects/agent_action_proto.proto\x12\x14\x63ommunicator_objects\"R\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _AGENTACTIONPROTO = _descriptor.Descriptor( 29 | name='AgentActionProto', 30 | full_name='communicator_objects.AgentActionProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='vector_actions', full_name='communicator_objects.AgentActionProto.vector_actions', index=0, 37 | number=1, type=2, cpp_type=6, label=3, 38 | has_default_value=False, default_value=[], 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='text_actions', full_name='communicator_objects.AgentActionProto.text_actions', index=1, 44 | number=2, type=9, cpp_type=9, label=1, 45 | has_default_value=False, default_value=_b("").decode('utf-8'), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, 
extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='memories', full_name='communicator_objects.AgentActionProto.memories', index=2, 51 | number=3, type=2, cpp_type=6, label=3, 52 | has_default_value=False, default_value=[], 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto3', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=71, 69 | serialized_end=153, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['AgentActionProto'] = _AGENTACTIONPROTO 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | AgentActionProto = _reflection.GeneratedProtocolMessageType('AgentActionProto', (_message.Message,), dict( 76 | DESCRIPTOR = _AGENTACTIONPROTO, 77 | __module__ = 'communicator_objects.agent_action_proto_pb2' 78 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentActionProto) 79 | )) 80 | _sym_db.RegisterMessage(AgentActionProto) 81 | 82 | 83 | DESCRIPTOR.has_options = True 84 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /install_requirements/unityagents/rpc_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import grpc 3 | 4 | from multiprocessing import Pipe 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | from .communicator import Communicator 8 | from communicator_objects import UnityToExternalServicer, add_UnityToExternalServicer_to_server 9 | from communicator_objects import UnityMessage, UnityInput, UnityOutput 10 | from .exception import 
UnityTimeOutException 11 | 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger("unityagents") 15 | 16 | 17 | class UnityToExternalServicerImplementation(UnityToExternalServicer): 18 | parent_conn, child_conn = Pipe() 19 | 20 | def Initialize(self, request, context): 21 | self.child_conn.send(request) 22 | return self.child_conn.recv() 23 | 24 | def Exchange(self, request, context): 25 | self.child_conn.send(request) 26 | return self.child_conn.recv() 27 | 28 | 29 | class RpcCommunicator(Communicator): 30 | def __init__(self, worker_id=0, 31 | base_port=5005): 32 | """ 33 | Python side of the grpc communication. Python is the server and Unity the client 34 | 35 | 36 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 37 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 38 | """ 39 | self.port = base_port + worker_id 40 | self.worker_id = worker_id 41 | self.server = None 42 | self.unity_to_external = None 43 | self.is_open = False 44 | 45 | def initialize(self, inputs: UnityInput) -> UnityOutput: 46 | try: 47 | # Establish communication grpc 48 | self.server = grpc.server(ThreadPoolExecutor(max_workers=10)) 49 | self.unity_to_external = UnityToExternalServicerImplementation() 50 | add_UnityToExternalServicer_to_server(self.unity_to_external, self.server) 51 | self.server.add_insecure_port('[::]:'+str(self.port)) 52 | self.server.start() 53 | except : 54 | raise UnityTimeOutException( 55 | "Couldn't start socket communication because worker number {} is still in use. " 56 | "You may need to manually close a previously opened environment " 57 | "or use a different worker number.".format(str(self.worker_id))) 58 | if not self.unity_to_external.parent_conn.poll(30): 59 | raise UnityTimeOutException( 60 | "The Unity environment took too long to respond. 
Make sure that :\n" 61 | "\t The environment does not need user interaction to launch\n" 62 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" 63 | "\t The environment and the Python interface have compatible versions.") 64 | aca_param = self.unity_to_external.parent_conn.recv().unity_output 65 | self.is_open = True 66 | message = UnityMessage() 67 | message.header.status = 200 68 | message.unity_input.CopyFrom(inputs) 69 | self.unity_to_external.parent_conn.send(message) 70 | self.unity_to_external.parent_conn.recv() 71 | return aca_param 72 | 73 | def exchange(self, inputs: UnityInput) -> UnityOutput: 74 | message = UnityMessage() 75 | message.header.status = 200 76 | message.unity_input.CopyFrom(inputs) 77 | self.unity_to_external.parent_conn.send(message) 78 | output = self.unity_to_external.parent_conn.recv() 79 | if output.header.status != 200: 80 | return None 81 | return output.unity_output 82 | 83 | def close(self): 84 | """ 85 | Sends a shutdown signal to the unity environment, and closes the grpc connection. 86 | """ 87 | if self.is_open: 88 | message_input = UnityMessage() 89 | message_input.header.status = 400 90 | self.unity_to_external.parent_conn.send(message_input) 91 | self.unity_to_external.parent_conn.close() 92 | self.server.stop(False) 93 | self.is_open = False 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/unity_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_rl_input_pb2 as communicator__objects_dot_unity__rl__input__pb2 17 | from communicator_objects import unity_rl_initialization_input_pb2 as communicator__objects_dot_unity__rl__initialization__input__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_input.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n&communicator_objects/unity_input.proto\x12\x14\x63ommunicator_objects\x1a)communicator_objects/unity_rl_input.proto\x1a\x38\x63ommunicator_objects/unity_rl_initialization_input.proto\"\xb0\x01\n\nUnityInput\x12\x34\n\x08rl_input\x18\x01 \x01(\x0b\x32\".communicator_objects.UnityRLInput\x12Q\n\x17rl_initialization_input\x18\x02 \x01(\x0b\x32\x30.communicator_objects.UnityRLInitializationInput\x12\x19\n\x11\x63ustom_data_input\x18\x03 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_unity__rl__input__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__input__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYINPUT = _descriptor.Descriptor( 32 | name='UnityInput', 33 | full_name='communicator_objects.UnityInput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='rl_input', full_name='communicator_objects.UnityInput.rl_input', index=0, 40 | number=1, 
type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='rl_initialization_input', full_name='communicator_objects.UnityInput.rl_initialization_input', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='custom_data_input', full_name='communicator_objects.UnityInput.custom_data_input', index=2, 54 | number=3, type=5, cpp_type=1, label=1, 55 | has_default_value=False, default_value=0, 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | ], 60 | extensions=[ 61 | ], 62 | nested_types=[], 63 | enum_types=[ 64 | ], 65 | options=None, 66 | is_extendable=False, 67 | syntax='proto3', 68 | extension_ranges=[], 69 | oneofs=[ 70 | ], 71 | serialized_start=166, 72 | serialized_end=342, 73 | ) 74 | 75 | _UNITYINPUT.fields_by_name['rl_input'].message_type = communicator__objects_dot_unity__rl__input__pb2._UNITYRLINPUT 76 | _UNITYINPUT.fields_by_name['rl_initialization_input'].message_type = communicator__objects_dot_unity__rl__initialization__input__pb2._UNITYRLINITIALIZATIONINPUT 77 | DESCRIPTOR.message_types_by_name['UnityInput'] = _UNITYINPUT 78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 79 | 80 | UnityInput = _reflection.GeneratedProtocolMessageType('UnityInput', (_message.Message,), dict( 81 | DESCRIPTOR = _UNITYINPUT, 82 | __module__ = 'communicator_objects.unity_input_pb2' 83 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityInput) 84 | )) 85 | _sym_db.RegisterMessage(UnityInput) 86 | 87 | 88 | 
DESCRIPTOR.has_options = True 89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 90 | # @@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /install_requirements/unityagents/socket_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import socket 3 | import struct 4 | 5 | from .communicator import Communicator 6 | from communicator_objects import UnityMessage, UnityOutput, UnityInput 7 | from .exception import UnityTimeOutException 8 | 9 | 10 | logging.basicConfig(level=logging.INFO) 11 | logger = logging.getLogger("unityagents") 12 | 13 | 14 | class SocketCommunicator(Communicator): 15 | def __init__(self, worker_id=0, 16 | base_port=5005): 17 | """ 18 | Python side of the socket communication 19 | 20 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 21 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 22 | """ 23 | 24 | self.port = base_port + worker_id 25 | self._buffer_size = 12000 26 | self.worker_id = worker_id 27 | self._socket = None 28 | self._conn = None 29 | 30 | def initialize(self, inputs: UnityInput) -> UnityOutput: 31 | try: 32 | # Establish communication socket 33 | self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 34 | self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 35 | self._socket.bind(("localhost", self.port)) 36 | except: 37 | raise UnityTimeOutException("Couldn't start socket communication because worker number {} is still in use. 
" 38 | "You may need to manually close a previously opened environment " 39 | "or use a different worker number.".format(str(self.worker_id))) 40 | try: 41 | self._socket.settimeout(30) 42 | self._socket.listen(1) 43 | self._conn, _ = self._socket.accept() 44 | self._conn.settimeout(30) 45 | except : 46 | raise UnityTimeOutException( 47 | "The Unity environment took too long to respond. Make sure that :\n" 48 | "\t The environment does not need user interaction to launch\n" 49 | "\t The Academy and the External Brain(s) are attached to objects in the Scene\n" 50 | "\t The environment and the Python interface have compatible versions.") 51 | message = UnityMessage() 52 | message.header.status = 200 53 | message.unity_input.CopyFrom(inputs) 54 | self._communicator_send(message.SerializeToString()) 55 | initialization_output = UnityMessage() 56 | initialization_output.ParseFromString(self._communicator_receive()) 57 | return initialization_output.unity_output 58 | 59 | def _communicator_receive(self): 60 | try: 61 | s = self._conn.recv(self._buffer_size) 62 | message_length = struct.unpack("I", bytearray(s[:4]))[0] 63 | s = s[4:] 64 | while len(s) != message_length: 65 | s += self._conn.recv(self._buffer_size) 66 | except socket.timeout as e: 67 | raise UnityTimeOutException("The environment took too long to respond.") 68 | return s 69 | 70 | def _communicator_send(self, message): 71 | self._conn.send(struct.pack("I", len(message)) + message) 72 | 73 | def exchange(self, inputs: UnityInput) -> UnityOutput: 74 | message = UnityMessage() 75 | message.header.status = 200 76 | message.unity_input.CopyFrom(inputs) 77 | self._communicator_send(message.SerializeToString()) 78 | outputs = UnityMessage() 79 | outputs.ParseFromString(self._communicator_receive()) 80 | if outputs.header.status != 200: 81 | return None 82 | return outputs.unity_output 83 | 84 | def close(self): 85 | """ 86 | Sends a shutdown signal to the unity environment, and closes the socket connection. 
87 | """ 88 | if self._socket is not None and self._conn is not None: 89 | message_input = UnityMessage() 90 | message_input.header.status = 400 91 | self._communicator_send(message_input.SerializeToString()) 92 | if self._socket is not None: 93 | self._socket.close() 94 | self._socket = None 95 | if self._socket is not None: 96 | self._conn.close() 97 | self._conn = None 98 | 99 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/unity_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_rl_output_pb2 as communicator__objects_dot_unity__rl__output__pb2 17 | from communicator_objects import unity_rl_initialization_output_pb2 as communicator__objects_dot_unity__rl__initialization__output__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n\'communicator_objects/unity_output.proto\x12\x14\x63ommunicator_objects\x1a*communicator_objects/unity_rl_output.proto\x1a\x39\x63ommunicator_objects/unity_rl_initialization_output.proto\"\xb6\x01\n\x0bUnityOutput\x12\x36\n\trl_output\x18\x01 \x01(\x0b\x32#.communicator_objects.UnityRLOutput\x12S\n\x18rl_initialization_output\x18\x02 
\x01(\x0b\x32\x31.communicator_objects.UnityRLInitializationOutput\x12\x1a\n\x12\x63ustom_data_output\x18\x03 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_unity__rl__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__rl__initialization__output__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYOUTPUT = _descriptor.Descriptor( 32 | name='UnityOutput', 33 | full_name='communicator_objects.UnityOutput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='rl_output', full_name='communicator_objects.UnityOutput.rl_output', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='rl_initialization_output', full_name='communicator_objects.UnityOutput.rl_initialization_output', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='custom_data_output', full_name='communicator_objects.UnityOutput.custom_data_output', index=2, 54 | number=3, type=9, cpp_type=9, label=1, 55 | has_default_value=False, default_value=_b("").decode('utf-8'), 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | ], 60 | extensions=[ 61 | ], 62 | nested_types=[], 63 | enum_types=[ 64 | ], 65 | options=None, 66 | is_extendable=False, 67 | syntax='proto3', 68 | extension_ranges=[], 69 | oneofs=[ 70 | ], 71 | serialized_start=169, 72 | serialized_end=351, 73 | ) 74 | 75 | 
_UNITYOUTPUT.fields_by_name['rl_output'].message_type = communicator__objects_dot_unity__rl__output__pb2._UNITYRLOUTPUT 76 | _UNITYOUTPUT.fields_by_name['rl_initialization_output'].message_type = communicator__objects_dot_unity__rl__initialization__output__pb2._UNITYRLINITIALIZATIONOUTPUT 77 | DESCRIPTOR.message_types_by_name['UnityOutput'] = _UNITYOUTPUT 78 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 79 | 80 | UnityOutput = _reflection.GeneratedProtocolMessageType('UnityOutput', (_message.Message,), dict( 81 | DESCRIPTOR = _UNITYOUTPUT, 82 | __module__ = 'communicator_objects.unity_output_pb2' 83 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityOutput) 84 | )) 85 | _sym_db.RegisterMessage(UnityOutput) 86 | 87 | 88 | DESCRIPTOR.has_options = True 89 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 90 | # @@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/unity_message_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_message.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import unity_output_pb2 as communicator__objects_dot_unity__output__pb2 17 | from communicator_objects import unity_input_pb2 as communicator__objects_dot_unity__input__pb2 18 | from communicator_objects import header_pb2 as communicator__objects_dot_header__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='communicator_objects/unity_message.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_pb=_b('\n(communicator_objects/unity_message.proto\x12\x14\x63ommunicator_objects\x1a\'communicator_objects/unity_output.proto\x1a&communicator_objects/unity_input.proto\x1a!communicator_objects/header.proto\"\xac\x01\n\x0cUnityMessage\x12,\n\x06header\x18\x01 \x01(\x0b\x32\x1c.communicator_objects.Header\x12\x37\n\x0cunity_output\x18\x02 \x01(\x0b\x32!.communicator_objects.UnityOutput\x12\x35\n\x0bunity_input\x18\x03 \x01(\x0b\x32 .communicator_objects.UnityInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[communicator__objects_dot_unity__output__pb2.DESCRIPTOR,communicator__objects_dot_unity__input__pb2.DESCRIPTOR,communicator__objects_dot_header__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _UNITYMESSAGE = _descriptor.Descriptor( 33 | name='UnityMessage', 34 | full_name='communicator_objects.UnityMessage', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | 
name='header', full_name='communicator_objects.UnityMessage.header', index=0, 41 | number=1, type=11, cpp_type=10, label=1, 42 | has_default_value=False, default_value=None, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='unity_output', full_name='communicator_objects.UnityMessage.unity_output', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='unity_input', full_name='communicator_objects.UnityMessage.unity_input', index=2, 55 | number=3, type=11, cpp_type=10, label=1, 56 | has_default_value=False, default_value=None, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None, file=DESCRIPTOR), 60 | ], 61 | extensions=[ 62 | ], 63 | nested_types=[], 64 | enum_types=[ 65 | ], 66 | options=None, 67 | is_extendable=False, 68 | syntax='proto3', 69 | extension_ranges=[], 70 | oneofs=[ 71 | ], 72 | serialized_start=183, 73 | serialized_end=355, 74 | ) 75 | 76 | _UNITYMESSAGE.fields_by_name['header'].message_type = communicator__objects_dot_header__pb2._HEADER 77 | _UNITYMESSAGE.fields_by_name['unity_output'].message_type = communicator__objects_dot_unity__output__pb2._UNITYOUTPUT 78 | _UNITYMESSAGE.fields_by_name['unity_input'].message_type = communicator__objects_dot_unity__input__pb2._UNITYINPUT 79 | DESCRIPTOR.message_types_by_name['UnityMessage'] = _UNITYMESSAGE 80 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 81 | 82 | UnityMessage = _reflection.GeneratedProtocolMessageType('UnityMessage', (_message.Message,), dict( 83 | DESCRIPTOR = _UNITYMESSAGE, 84 | __module__ = 'communicator_objects.unity_message_pb2' 85 | # 
@@protoc_insertion_point(class_scope:communicator_objects.UnityMessage) 86 | )) 87 | _sym_db.RegisterMessage(UnityMessage) 88 | 89 | 90 | DESCRIPTOR.has_options = True 91 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 92 | # @@protoc_insertion_point(module_scope) 93 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from unityagents import UnityEnvironment 2 | import numpy as np 3 | from agent import MADDPG 4 | from utils import draw 5 | 6 | unity_environment_path = "./Tennis_Linux/Tennis.x86_64" 7 | best_model_path = "./best_model.checkpoint" 8 | rollout_length = 3 9 | 10 | if __name__ == "__main__": 11 | # prepare environment 12 | env = UnityEnvironment(file_name=unity_environment_path) 13 | brain_name = env.brain_names[0] 14 | brain = env.brains[brain_name] 15 | env_info = env.reset(train_mode=True)[brain_name] 16 | 17 | num_agents = len(env_info.agents) 18 | print('Number of agents:', num_agents) 19 | 20 | # size of each action 21 | action_size = brain.vector_action_space_size 22 | print('Size of each action:', action_size) 23 | 24 | # examine the state space 25 | states = env_info.vector_observations 26 | state_size = states.shape[1] 27 | print('There are {} agents. 
Each observes a state with length: {}'.format(states.shape[0], state_size)) 28 | print('The state for the first agent looks like:', states[0]) 29 | 30 | num_episodes = 2500 31 | agent = MADDPG(state_size, 32 | action_size, 33 | lr_actor = 1e-5, 34 | lr_critic = 1e-4, 35 | lr_decay = .995, 36 | replay_buff_size = int(1e6), 37 | gamma = .95, 38 | batch_size = 64, 39 | random_seed = 999, 40 | soft_update_tau = 1e-3 41 | ) 42 | 43 | total_rewards = [] 44 | avg_scores = [] 45 | max_avg_score = -1 46 | max_score = -1 47 | threshold_init = 20 48 | noise_t = 1.0 49 | noise_decay = .995 50 | worsen_tolerance = threshold_init # for early-stopping training if consistently worsen for # episodes 51 | for i_episode in range(1, num_episodes+1): 52 | env_inst = env.reset(train_mode=True)[brain_name] # reset the environment 53 | states = env_inst.vector_observations # get the current state 54 | scores = np.zeros(num_agents) # initialize score array 55 | dones = [False]*num_agents 56 | while not np.any(dones): 57 | actions = agent.act(states,noise_t) # select an action 58 | env_inst = env.step(actions)[brain_name] # send the action to the environment 59 | next_states = env_inst.vector_observations # get the next state 60 | rewards = env_inst.rewards # get the reward 61 | dones = env_inst.local_done # see if episode has finished 62 | agent.update(states, actions, rewards, next_states, dones) 63 | 64 | noise_t *= noise_decay 65 | scores += rewards # update scores 66 | states = next_states 67 | 68 | episode_score = np.max(scores) 69 | total_rewards.append(episode_score) 70 | print("Episodic {} Score: {:.4f}".format(i_episode, episode_score)) 71 | 72 | if max_score <= episode_score: 73 | max_score = episode_score 74 | agent.save(best_model_path) # save best model so far 75 | 76 | if len(total_rewards) >= 100: # record avg score for the latest 100 steps 77 | latest_avg_score = sum(total_rewards[(len(total_rewards)-100):]) / 100 78 | print("100 Episodic Everage Score: 
{:.4f}".format(latest_avg_score)) 79 | avg_scores.append(latest_avg_score) 80 | 81 | if max_avg_score <= latest_avg_score: # record better results 82 | worsen_tolerance = threshold_init # re-count tolerance 83 | max_avg_score = latest_avg_score 84 | else: 85 | if max_avg_score > 0.5: 86 | worsen_tolerance -= 1 # count worsening counts 87 | print("Loaded from last best model.") 88 | agent.load(best_model_path) # continue from last best-model 89 | if worsen_tolerance <= 0: # earliy stop training 90 | print("Early Stop Training.") 91 | break 92 | 93 | draw(total_rewards,"./training_score_plot.png", "Training Scores (Per Episode)") 94 | draw(avg_scores,"./training_100avgscore_plot.png", "Training Scores (Average of Latest 100 Episodes)", ylabel="Avg. Score") 95 | env.close() 96 | 97 | -------------------------------------------------------------------------------- /Report.md: -------------------------------------------------------------------------------- 1 | # Project Report 2 | 3 | In this project, a multi-agent for tennis player control is trained using Deep Deterministic Policy Gradient (DDPG). Through learning by self-playing, the agent obtained good performance in the evaluation test. 
4 | 5 | 6 | ## Learning Algorithm 7 | 8 | - Network Architecture 9 | The architectures of the Actor and Critic networks are summarized using the project [pytorch-summary](https://github.com/sksq96/pytorch-summary/) as follows: 10 | - Actor Network 11 | ``` 12 | ---------------------------------------------------------------- 13 | Layer (type) Output Shape Param # 14 | ================================================================ 15 | Linear-1 [-1, 1, 64] 1,600 16 | Linear-2 [-1, 1, 128] 8,320 17 | Linear-3 [-1, 1, 2] 258 18 | ================================================================ 19 | Total params: 10,178 20 | Trainable params: 10,178 21 | Non-trainable params: 0 22 | ---------------------------------------------------------------- 23 | Input size (MB): 0.00 24 | Forward/backward pass size (MB): 0.00 25 | Params size (MB): 0.04 26 | Estimated Total Size (MB): 0.04 27 | ---------------------------------------------------------------- 28 | ``` 29 | - Critic Network 30 | ``` 31 | ---------------------------------------------------------------- 32 | Layer (type) Output Shape Param # 33 | ================================================================ 34 | Linear-1 [-1, 64] 1,600 35 | Linear-2 [-1, 128] 8,576 36 | Linear-3 [-1, 64] 8,256 37 | Linear-4 [-1, 1] 65 38 | ================================================================ 39 | Total params: 18,497 40 | Trainable params: 18,497 41 | Non-trainable params: 0 42 | ---------------------------------------------------------------- 43 | Input size (MB): 0.00 44 | Forward/backward pass size (MB): 0.00 45 | Params size (MB): 0.07 46 | Estimated Total Size (MB): 0.07 47 | ---------------------------------------------------------------- 48 | ``` 49 | - Hyper-parameters 50 | - learning rate for actor network: 1e-5 51 | - learning rate for the critic network: 1e-4 52 | - learning rate decay rate: 0.995 53 | - replay buffer size: 1e6 54 | - long term reward discount rate: 0.95 55 | - soft update tau: 0.001 56 | - 
Training Strategy 57 | - Adam is used as the optimizer 58 | - An `early-stop` scheme is applied to stop training if the 100-episode-average score continues decreasing over `20` consecutive episodes. 59 | - Each time the model gets worse regarding avg scores, the model recovers from the last best model and the learning rate of Adam is decreased: `new learning rate = old learning rate * learning rate decay rate` 60 | 61 | ## Performance Evaluation 62 | ### Training 63 | During training, the performance jumped to the best level and stabilized there after about **1700** episodes. Before that, the first time the performance surpassed 0.5 occurred at around episode 800. The episodic and average (over 100 latest episodes) scores are plotted as follows: 64 | - Reward per-episode during training 65 | 66 | ![img](https://raw.githubusercontent.com/qiaochen/DDPG_MultiAgent/master/training_score_plot.png) 67 | 68 | - Average reward over latest 100 episodes during training 69 | 70 | ![img](https://raw.githubusercontent.com/qiaochen/DDPG_MultiAgent/master/training_100avgscore_plot.png) 71 | 72 | As can be seen from the plot, the average score gradually passed **0.5** and reached **2.0** during training, before the early-stopping scheme terminated the training process. 73 | 74 | ### Testing 75 | The scores of 100 testing episodes are visualized as follows: 76 | 77 | ![img](https://raw.githubusercontent.com/qiaochen/DDPG_MultiAgent/master/test_score_plot.png) 78 | 79 | The model obtained an average score of **+1.95** during testing, which is over **+0.5**. 80 | 81 | ## Conclusion 82 | The trained model has successfully solved the tennis play task. The performance: 83 | 1. an average score of `+1.95` over `100` consecutive episodes 84 | 2. the best model was trained using around `1700` episodes 85 | 86 | has fulfilled the passing threshold of solving the problem: obtain an average score of higher than `+0.5` over `100` consecutive episodes. 
87 | 88 | ## Ideas for Future Work 89 | - Use prioritized replay buffer, or Rainbow to improve the Critic network 90 | - Use methods like GAE or PPO in the calculation of policy loss, to improve the training performance of the Actor network. 91 | - See if A2C and other algorithms could perform better. 92 | -------------------------------------------------------------------------------- /install_requirements/unityagents/curriculum.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from .exception import UnityEnvironmentException 4 | 5 | import logging 6 | 7 | logger = logging.getLogger("unityagents") 8 | 9 | 10 | class Curriculum(object): 11 | def __init__(self, location, default_reset_parameters): 12 | """ 13 | Initializes a Curriculum object. 14 | :param location: Path to JSON defining curriculum. 15 | :param default_reset_parameters: Set of reset parameters for environment. 16 | """ 17 | self.lesson_length = 0 18 | self.max_lesson_number = 0 19 | self.measure_type = None 20 | if location is None: 21 | self.data = None 22 | else: 23 | try: 24 | with open(location) as data_file: 25 | self.data = json.load(data_file) 26 | except IOError: 27 | raise UnityEnvironmentException( 28 | "The file {0} could not be found.".format(location)) 29 | except UnicodeDecodeError: 30 | raise UnityEnvironmentException("There was an error decoding {}".format(location)) 31 | self.smoothing_value = 0 32 | for key in ['parameters', 'measure', 'thresholds', 33 | 'min_lesson_length', 'signal_smoothing']: 34 | if key not in self.data: 35 | raise UnityEnvironmentException("{0} does not contain a " 36 | "{1} field.".format(location, key)) 37 | parameters = self.data['parameters'] 38 | self.measure_type = self.data['measure'] 39 | self.max_lesson_number = len(self.data['thresholds']) 40 | for key in parameters: 41 | if key not in default_reset_parameters: 42 | raise UnityEnvironmentException( 43 | "The parameter {0} in Curriculum {1} is 
not present in " 44 | "the Environment".format(key, location)) 45 | for key in parameters: 46 | if len(parameters[key]) != self.max_lesson_number + 1: 47 | raise UnityEnvironmentException( 48 | "The parameter {0} in Curriculum {1} must have {2} values " 49 | "but {3} were found".format(key, location, 50 | self.max_lesson_number + 1, len(parameters[key]))) 51 | self.set_lesson_number(0) 52 | 53 | @property 54 | def measure(self): 55 | return self.measure_type 56 | 57 | @property 58 | def get_lesson_number(self): 59 | return self.lesson_number 60 | 61 | def set_lesson_number(self, value): 62 | self.lesson_length = 0 63 | self.lesson_number = max(0, min(value, self.max_lesson_number)) 64 | 65 | def increment_lesson(self, progress): 66 | """ 67 | Increments the lesson number depending on the progree given. 68 | :param progress: Measure of progress (either reward or percentage steps completed). 69 | """ 70 | if self.data is None or progress is None: 71 | return 72 | if self.data["signal_smoothing"]: 73 | progress = self.smoothing_value * 0.25 + 0.75 * progress 74 | self.smoothing_value = progress 75 | self.lesson_length += 1 76 | if self.lesson_number < self.max_lesson_number: 77 | if ((progress > self.data['thresholds'][self.lesson_number]) and 78 | (self.lesson_length > self.data['min_lesson_length'])): 79 | self.lesson_length = 0 80 | self.lesson_number += 1 81 | config = {} 82 | parameters = self.data["parameters"] 83 | for key in parameters: 84 | config[key] = parameters[key][self.lesson_number] 85 | logger.info("\nLesson changed. Now in Lesson {0} : \t{1}" 86 | .format(self.lesson_number, 87 | ', '.join([str(x) + ' -> ' + str(config[x]) for x in config]))) 88 | 89 | def get_config(self, lesson=None): 90 | """ 91 | Returns reset parameters which correspond to the lesson. 92 | :param lesson: The lesson you want to get the config of. If None, the current lesson is returned. 93 | :return: The configuration of the reset parameters. 
94 | """ 95 | if self.data is None: 96 | return {} 97 | if lesson is None: 98 | lesson = self.lesson_number 99 | lesson = max(0, min(lesson, self.max_lesson_number)) 100 | config = {} 101 | parameters = self.data["parameters"] 102 | for key in parameters: 103 | config[key] = parameters[key][lesson] 104 | return config 105 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/engine_configuration_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/engine_configuration_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/engine_configuration_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n5communicator_objects/engine_configuration_proto.proto\x12\x14\x63ommunicator_objects\"\x95\x01\n\x18\x45ngineConfigurationProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x15\n\rquality_level\x18\x03 \x01(\x05\x12\x12\n\ntime_scale\x18\x04 \x01(\x02\x12\x19\n\x11target_frame_rate\x18\x05 \x01(\x05\x12\x14\n\x0cshow_monitor\x18\x06 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _ENGINECONFIGURATIONPROTO = _descriptor.Descriptor( 29 | name='EngineConfigurationProto', 30 | 
full_name='communicator_objects.EngineConfigurationProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='width', full_name='communicator_objects.EngineConfigurationProto.width', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='height', full_name='communicator_objects.EngineConfigurationProto.height', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='quality_level', full_name='communicator_objects.EngineConfigurationProto.quality_level', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=False, default_value=0, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | name='time_scale', full_name='communicator_objects.EngineConfigurationProto.time_scale', index=3, 58 | number=4, type=2, cpp_type=6, label=1, 59 | has_default_value=False, default_value=float(0), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None, file=DESCRIPTOR), 63 | _descriptor.FieldDescriptor( 64 | name='target_frame_rate', full_name='communicator_objects.EngineConfigurationProto.target_frame_rate', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=False, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | 
options=None, file=DESCRIPTOR), 70 | _descriptor.FieldDescriptor( 71 | name='show_monitor', full_name='communicator_objects.EngineConfigurationProto.show_monitor', index=5, 72 | number=6, type=8, cpp_type=7, label=1, 73 | has_default_value=False, default_value=False, 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None, file=DESCRIPTOR), 77 | ], 78 | extensions=[ 79 | ], 80 | nested_types=[], 81 | enum_types=[ 82 | ], 83 | options=None, 84 | is_extendable=False, 85 | syntax='proto3', 86 | extension_ranges=[], 87 | oneofs=[ 88 | ], 89 | serialized_start=80, 90 | serialized_end=229, 91 | ) 92 | 93 | DESCRIPTOR.message_types_by_name['EngineConfigurationProto'] = _ENGINECONFIGURATIONPROTO 94 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 95 | 96 | EngineConfigurationProto = _reflection.GeneratedProtocolMessageType('EngineConfigurationProto', (_message.Message,), dict( 97 | DESCRIPTOR = _ENGINECONFIGURATIONPROTO, 98 | __module__ = 'communicator_objects.engine_configuration_proto_pb2' 99 | # @@protoc_insertion_point(class_scope:communicator_objects.EngineConfigurationProto) 100 | )) 101 | _sym_db.RegisterMessage(EngineConfigurationProto) 102 | 103 | 104 | DESCRIPTOR.has_options = True 105 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 106 | # @@protoc_insertion_point(module_scope) 107 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/environment_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/environment_parameters_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/environment_parameters_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n7communicator_objects/environment_parameters_proto.proto\x12\x14\x63ommunicator_objects\"\xb5\x01\n\x1a\x45nvironmentParametersProto\x12_\n\x10\x66loat_parameters\x18\x01 \x03(\x0b\x32\x45.communicator_objects.EnvironmentParametersProto.FloatParametersEntry\x1a\x36\n\x14\x46loatParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY = _descriptor.Descriptor( 29 | name='FloatParametersEntry', 30 | full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='key', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.key', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | _descriptor.FieldDescriptor( 43 | name='value', 
full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.value', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=False, default_value=float(0), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 56 | is_extendable=False, 57 | syntax='proto3', 58 | extension_ranges=[], 59 | oneofs=[ 60 | ], 61 | serialized_start=209, 62 | serialized_end=263, 63 | ) 64 | 65 | _ENVIRONMENTPARAMETERSPROTO = _descriptor.Descriptor( 66 | name='EnvironmentParametersProto', 67 | full_name='communicator_objects.EnvironmentParametersProto', 68 | filename=None, 69 | file=DESCRIPTOR, 70 | containing_type=None, 71 | fields=[ 72 | _descriptor.FieldDescriptor( 73 | name='float_parameters', full_name='communicator_objects.EnvironmentParametersProto.float_parameters', index=0, 74 | number=1, type=11, cpp_type=10, label=3, 75 | has_default_value=False, default_value=[], 76 | message_type=None, enum_type=None, containing_type=None, 77 | is_extension=False, extension_scope=None, 78 | options=None, file=DESCRIPTOR), 79 | ], 80 | extensions=[ 81 | ], 82 | nested_types=[_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, ], 83 | enum_types=[ 84 | ], 85 | options=None, 86 | is_extendable=False, 87 | syntax='proto3', 88 | extension_ranges=[], 89 | oneofs=[ 90 | ], 91 | serialized_start=82, 92 | serialized_end=263, 93 | ) 94 | 95 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.containing_type = _ENVIRONMENTPARAMETERSPROTO 96 | _ENVIRONMENTPARAMETERSPROTO.fields_by_name['float_parameters'].message_type = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY 97 | DESCRIPTOR.message_types_by_name['EnvironmentParametersProto'] = _ENVIRONMENTPARAMETERSPROTO 98 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 99 
| 100 | EnvironmentParametersProto = _reflection.GeneratedProtocolMessageType('EnvironmentParametersProto', (_message.Message,), dict( 101 | 102 | FloatParametersEntry = _reflection.GeneratedProtocolMessageType('FloatParametersEntry', (_message.Message,), dict( 103 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, 104 | __module__ = 'communicator_objects.environment_parameters_proto_pb2' 105 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto.FloatParametersEntry) 106 | )) 107 | , 108 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO, 109 | __module__ = 'communicator_objects.environment_parameters_proto_pb2' 110 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto) 111 | )) 112 | _sym_db.RegisterMessage(EnvironmentParametersProto) 113 | _sym_db.RegisterMessage(EnvironmentParametersProto.FloatParametersEntry) 114 | 115 | 116 | DESCRIPTOR.has_options = True 117 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 118 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.has_options = True 119 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 120 | # @@protoc_insertion_point(module_scope) 121 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/unity_rl_initialization_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/unity_rl_initialization_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import brain_parameters_proto_pb2 as communicator__objects_dot_brain__parameters__proto__pb2 17 | from communicator_objects import environment_parameters_proto_pb2 as communicator__objects_dot_environment__parameters__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='communicator_objects/unity_rl_initialization_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_pb=_b('\n9communicator_objects/unity_rl_initialization_output.proto\x12\x14\x63ommunicator_objects\x1a\x31\x63ommunicator_objects/brain_parameters_proto.proto\x1a\x37\x63ommunicator_objects/environment_parameters_proto.proto\"\xe6\x01\n\x1bUnityRLInitializationOutput\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x10\n\x08log_path\x18\x03 \x01(\t\x12\x44\n\x10\x62rain_parameters\x18\x05 \x03(\x0b\x32*.communicator_objects.BrainParametersProto\x12P\n\x16\x65nvironment_parameters\x18\x06 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[communicator__objects_dot_brain__parameters__proto__pb2.DESCRIPTOR,communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYRLINITIALIZATIONOUTPUT = _descriptor.Descriptor( 32 | name='UnityRLInitializationOutput', 33 | 
full_name='communicator_objects.UnityRLInitializationOutput', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='name', full_name='communicator_objects.UnityRLInitializationOutput.name', index=0, 40 | number=1, type=9, cpp_type=9, label=1, 41 | has_default_value=False, default_value=_b("").decode('utf-8'), 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='version', full_name='communicator_objects.UnityRLInitializationOutput.version', index=1, 47 | number=2, type=9, cpp_type=9, label=1, 48 | has_default_value=False, default_value=_b("").decode('utf-8'), 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='log_path', full_name='communicator_objects.UnityRLInitializationOutput.log_path', index=2, 54 | number=3, type=9, cpp_type=9, label=1, 55 | has_default_value=False, default_value=_b("").decode('utf-8'), 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None, file=DESCRIPTOR), 59 | _descriptor.FieldDescriptor( 60 | name='brain_parameters', full_name='communicator_objects.UnityRLInitializationOutput.brain_parameters', index=3, 61 | number=5, type=11, cpp_type=10, label=3, 62 | has_default_value=False, default_value=[], 63 | message_type=None, enum_type=None, containing_type=None, 64 | is_extension=False, extension_scope=None, 65 | options=None, file=DESCRIPTOR), 66 | _descriptor.FieldDescriptor( 67 | name='environment_parameters', full_name='communicator_objects.UnityRLInitializationOutput.environment_parameters', index=4, 68 | number=6, type=11, cpp_type=10, label=1, 69 | has_default_value=False, default_value=None, 70 | message_type=None, 
enum_type=None, containing_type=None, 71 | is_extension=False, extension_scope=None, 72 | options=None, file=DESCRIPTOR), 73 | ], 74 | extensions=[ 75 | ], 76 | nested_types=[], 77 | enum_types=[ 78 | ], 79 | options=None, 80 | is_extendable=False, 81 | syntax='proto3', 82 | extension_ranges=[], 83 | oneofs=[ 84 | ], 85 | serialized_start=192, 86 | serialized_end=422, 87 | ) 88 | 89 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['brain_parameters'].message_type = communicator__objects_dot_brain__parameters__proto__pb2._BRAINPARAMETERSPROTO 90 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['environment_parameters'].message_type = communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO 91 | DESCRIPTOR.message_types_by_name['UnityRLInitializationOutput'] = _UNITYRLINITIALIZATIONOUTPUT 92 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 93 | 94 | UnityRLInitializationOutput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationOutput', (_message.Message,), dict( 95 | DESCRIPTOR = _UNITYRLINITIALIZATIONOUTPUT, 96 | __module__ = 'communicator_objects.unity_rl_initialization_output_pb2' 97 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationOutput) 98 | )) 99 | _sym_db.RegisterMessage(UnityRLInitializationOutput) 100 | 101 | 102 | DESCRIPTOR.has_options = True 103 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 104 | # @@protoc_insertion_point(module_scope) 105 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/agent_info_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: communicator_objects/agent_info_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='communicator_objects/agent_info_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_pb=_b('\n+communicator_objects/agent_info_proto.proto\x12\x14\x63ommunicator_objects\"\xfd\x01\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x1b\n\x13visual_observations\x18\x02 \x03(\x0c\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x10\n\x08memories\x18\x06 \x03(\x02\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 23 | ) 24 | 25 | 26 | 27 | 28 | _AGENTINFOPROTO = _descriptor.Descriptor( 29 | name='AgentInfoProto', 30 | full_name='communicator_objects.AgentInfoProto', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='stacked_vector_observation', full_name='communicator_objects.AgentInfoProto.stacked_vector_observation', index=0, 37 | number=1, type=2, cpp_type=6, label=3, 38 | has_default_value=False, default_value=[], 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None, file=DESCRIPTOR), 42 | 
_descriptor.FieldDescriptor( 43 | name='visual_observations', full_name='communicator_objects.AgentInfoProto.visual_observations', index=1, 44 | number=2, type=12, cpp_type=9, label=3, 45 | has_default_value=False, default_value=[], 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None, file=DESCRIPTOR), 49 | _descriptor.FieldDescriptor( 50 | name='text_observation', full_name='communicator_objects.AgentInfoProto.text_observation', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None, file=DESCRIPTOR), 56 | _descriptor.FieldDescriptor( 57 | name='stored_vector_actions', full_name='communicator_objects.AgentInfoProto.stored_vector_actions', index=3, 58 | number=4, type=2, cpp_type=6, label=3, 59 | has_default_value=False, default_value=[], 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None, file=DESCRIPTOR), 63 | _descriptor.FieldDescriptor( 64 | name='stored_text_actions', full_name='communicator_objects.AgentInfoProto.stored_text_actions', index=4, 65 | number=5, type=9, cpp_type=9, label=1, 66 | has_default_value=False, default_value=_b("").decode('utf-8'), 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None, file=DESCRIPTOR), 70 | _descriptor.FieldDescriptor( 71 | name='memories', full_name='communicator_objects.AgentInfoProto.memories', index=5, 72 | number=6, type=2, cpp_type=6, label=3, 73 | has_default_value=False, default_value=[], 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None, file=DESCRIPTOR), 77 | _descriptor.FieldDescriptor( 78 | name='reward', 
full_name='communicator_objects.AgentInfoProto.reward', index=6, 79 | number=7, type=2, cpp_type=6, label=1, 80 | has_default_value=False, default_value=float(0), 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None, file=DESCRIPTOR), 84 | _descriptor.FieldDescriptor( 85 | name='done', full_name='communicator_objects.AgentInfoProto.done', index=7, 86 | number=8, type=8, cpp_type=7, label=1, 87 | has_default_value=False, default_value=False, 88 | message_type=None, enum_type=None, containing_type=None, 89 | is_extension=False, extension_scope=None, 90 | options=None, file=DESCRIPTOR), 91 | _descriptor.FieldDescriptor( 92 | name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=8, 93 | number=9, type=8, cpp_type=7, label=1, 94 | has_default_value=False, default_value=False, 95 | message_type=None, enum_type=None, containing_type=None, 96 | is_extension=False, extension_scope=None, 97 | options=None, file=DESCRIPTOR), 98 | _descriptor.FieldDescriptor( 99 | name='id', full_name='communicator_objects.AgentInfoProto.id', index=9, 100 | number=10, type=5, cpp_type=1, label=1, 101 | has_default_value=False, default_value=0, 102 | message_type=None, enum_type=None, containing_type=None, 103 | is_extension=False, extension_scope=None, 104 | options=None, file=DESCRIPTOR), 105 | ], 106 | extensions=[ 107 | ], 108 | nested_types=[], 109 | enum_types=[ 110 | ], 111 | options=None, 112 | is_extendable=False, 113 | syntax='proto3', 114 | extension_ranges=[], 115 | oneofs=[ 116 | ], 117 | serialized_start=70, 118 | serialized_end=323, 119 | ) 120 | 121 | DESCRIPTOR.message_types_by_name['AgentInfoProto'] = _AGENTINFOPROTO 122 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 123 | 124 | AgentInfoProto = _reflection.GeneratedProtocolMessageType('AgentInfoProto', (_message.Message,), dict( 125 | DESCRIPTOR = _AGENTINFOPROTO, 126 | __module__ = 
'communicator_objects.agent_info_proto_pb2' 127 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentInfoProto) 128 | )) 129 | _sym_db.RegisterMessage(AgentInfoProto) 130 | 131 | 132 | DESCRIPTOR.has_options = True 133 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 134 | # @@protoc_insertion_point(module_scope) 135 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/unity_rl_output_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_rl_output.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import agent_info_proto_pb2 as communicator__objects_dot_agent__info__proto__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='communicator_objects/unity_rl_output.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_pb=_b('\n*communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/agent_info_proto.proto\"\xa3\x02\n\rUnityRLOutput\x12\x13\n\x0bglobal_done\x18\x01 \x01(\x08\x12G\n\nagentInfos\x18\x02 \x03(\x0b\x32\x33.communicator_objects.UnityRLOutput.AgentInfosEntry\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1ai\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 
\x01(\t\x12\x45\n\x05value\x18\x02 \x01(\x0b\x32\x36.communicator_objects.UnityRLOutput.ListAgentInfoProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | , 25 | dependencies=[communicator__objects_dot_agent__info__proto__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | 30 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO = _descriptor.Descriptor( 31 | name='ListAgentInfoProto', 32 | full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto', 33 | filename=None, 34 | file=DESCRIPTOR, 35 | containing_type=None, 36 | fields=[ 37 | _descriptor.FieldDescriptor( 38 | name='value', full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto.value', index=0, 39 | number=1, type=11, cpp_type=10, label=3, 40 | has_default_value=False, default_value=[], 41 | message_type=None, enum_type=None, containing_type=None, 42 | is_extension=False, extension_scope=None, 43 | options=None, file=DESCRIPTOR), 44 | ], 45 | extensions=[ 46 | ], 47 | nested_types=[], 48 | enum_types=[ 49 | ], 50 | options=None, 51 | is_extendable=False, 52 | syntax='proto3', 53 | extension_ranges=[], 54 | oneofs=[ 55 | ], 56 | serialized_start=225, 57 | serialized_end=298, 58 | ) 59 | 60 | _UNITYRLOUTPUT_AGENTINFOSENTRY = _descriptor.Descriptor( 61 | name='AgentInfosEntry', 62 | full_name='communicator_objects.UnityRLOutput.AgentInfosEntry', 63 | filename=None, 64 | file=DESCRIPTOR, 65 | containing_type=None, 66 | fields=[ 67 | _descriptor.FieldDescriptor( 68 | name='key', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.key', index=0, 69 | number=1, type=9, cpp_type=9, label=1, 70 | has_default_value=False, default_value=_b("").decode('utf-8'), 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=None, file=DESCRIPTOR), 74 | _descriptor.FieldDescriptor( 75 | name='value', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.value', index=1, 76 | number=2, type=11, cpp_type=10, label=1, 77 | 
has_default_value=False, default_value=None, 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=None, file=DESCRIPTOR), 81 | ], 82 | extensions=[ 83 | ], 84 | nested_types=[], 85 | enum_types=[ 86 | ], 87 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 88 | is_extendable=False, 89 | syntax='proto3', 90 | extension_ranges=[], 91 | oneofs=[ 92 | ], 93 | serialized_start=300, 94 | serialized_end=405, 95 | ) 96 | 97 | _UNITYRLOUTPUT = _descriptor.Descriptor( 98 | name='UnityRLOutput', 99 | full_name='communicator_objects.UnityRLOutput', 100 | filename=None, 101 | file=DESCRIPTOR, 102 | containing_type=None, 103 | fields=[ 104 | _descriptor.FieldDescriptor( 105 | name='global_done', full_name='communicator_objects.UnityRLOutput.global_done', index=0, 106 | number=1, type=8, cpp_type=7, label=1, 107 | has_default_value=False, default_value=False, 108 | message_type=None, enum_type=None, containing_type=None, 109 | is_extension=False, extension_scope=None, 110 | options=None, file=DESCRIPTOR), 111 | _descriptor.FieldDescriptor( 112 | name='agentInfos', full_name='communicator_objects.UnityRLOutput.agentInfos', index=1, 113 | number=2, type=11, cpp_type=10, label=3, 114 | has_default_value=False, default_value=[], 115 | message_type=None, enum_type=None, containing_type=None, 116 | is_extension=False, extension_scope=None, 117 | options=None, file=DESCRIPTOR), 118 | ], 119 | extensions=[ 120 | ], 121 | nested_types=[_UNITYRLOUTPUT_LISTAGENTINFOPROTO, _UNITYRLOUTPUT_AGENTINFOSENTRY, ], 122 | enum_types=[ 123 | ], 124 | options=None, 125 | is_extendable=False, 126 | syntax='proto3', 127 | extension_ranges=[], 128 | oneofs=[ 129 | ], 130 | serialized_start=114, 131 | serialized_end=405, 132 | ) 133 | 134 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.fields_by_name['value'].message_type = communicator__objects_dot_agent__info__proto__pb2._AGENTINFOPROTO 135 | 
_UNITYRLOUTPUT_LISTAGENTINFOPROTO.containing_type = _UNITYRLOUTPUT 136 | _UNITYRLOUTPUT_AGENTINFOSENTRY.fields_by_name['value'].message_type = _UNITYRLOUTPUT_LISTAGENTINFOPROTO 137 | _UNITYRLOUTPUT_AGENTINFOSENTRY.containing_type = _UNITYRLOUTPUT 138 | _UNITYRLOUTPUT.fields_by_name['agentInfos'].message_type = _UNITYRLOUTPUT_AGENTINFOSENTRY 139 | DESCRIPTOR.message_types_by_name['UnityRLOutput'] = _UNITYRLOUTPUT 140 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 141 | 142 | UnityRLOutput = _reflection.GeneratedProtocolMessageType('UnityRLOutput', (_message.Message,), dict( 143 | 144 | ListAgentInfoProto = _reflection.GeneratedProtocolMessageType('ListAgentInfoProto', (_message.Message,), dict( 145 | DESCRIPTOR = _UNITYRLOUTPUT_LISTAGENTINFOPROTO, 146 | __module__ = 'communicator_objects.unity_rl_output_pb2' 147 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.ListAgentInfoProto) 148 | )) 149 | , 150 | 151 | AgentInfosEntry = _reflection.GeneratedProtocolMessageType('AgentInfosEntry', (_message.Message,), dict( 152 | DESCRIPTOR = _UNITYRLOUTPUT_AGENTINFOSENTRY, 153 | __module__ = 'communicator_objects.unity_rl_output_pb2' 154 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.AgentInfosEntry) 155 | )) 156 | , 157 | DESCRIPTOR = _UNITYRLOUTPUT, 158 | __module__ = 'communicator_objects.unity_rl_output_pb2' 159 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput) 160 | )) 161 | _sym_db.RegisterMessage(UnityRLOutput) 162 | _sym_db.RegisterMessage(UnityRLOutput.ListAgentInfoProto) 163 | _sym_db.RegisterMessage(UnityRLOutput.AgentInfosEntry) 164 | 165 | 166 | DESCRIPTOR.has_options = True 167 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 168 | _UNITYRLOUTPUT_AGENTINFOSENTRY.has_options = True 169 | _UNITYRLOUTPUT_AGENTINFOSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 170 | # 
@@protoc_insertion_point(module_scope) 171 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/brain_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/brain_parameters_proto.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import resolution_proto_pb2 as communicator__objects_dot_resolution__proto__pb2 17 | from communicator_objects import brain_type_proto_pb2 as communicator__objects_dot_brain__type__proto__pb2 18 | from communicator_objects import space_type_proto_pb2 as communicator__objects_dot_space__type__proto__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='communicator_objects/brain_parameters_proto.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_pb=_b('\n1communicator_objects/brain_parameters_proto.proto\x12\x14\x63ommunicator_objects\x1a+communicator_objects/resolution_proto.proto\x1a+communicator_objects/brain_type_proto.proto\x1a+communicator_objects/space_type_proto.proto\"\xc6\x03\n\x14\x42rainParametersProto\x12\x1f\n\x17vector_observation_size\x18\x01 \x01(\x05\x12\'\n\x1fnum_stacked_vector_observations\x18\x02 \x01(\x05\x12\x1a\n\x12vector_action_size\x18\x03 \x01(\x05\x12\x41\n\x12\x63\x61mera_resolutions\x18\x04 \x03(\x0b\x32%.communicator_objects.ResolutionProto\x12\"\n\x1avector_action_descriptions\x18\x05 
\x03(\t\x12\x46\n\x18vector_action_space_type\x18\x06 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12K\n\x1dvector_observation_space_type\x18\x07 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12\x12\n\nbrain_name\x18\x08 \x01(\t\x12\x38\n\nbrain_type\x18\t \x01(\x0e\x32$.communicator_objects.BrainTypeProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,communicator__objects_dot_brain__type__proto__pb2.DESCRIPTOR,communicator__objects_dot_space__type__proto__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _BRAINPARAMETERSPROTO = _descriptor.Descriptor( 33 | name='BrainParametersProto', 34 | full_name='communicator_objects.BrainParametersProto', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='vector_observation_size', full_name='communicator_objects.BrainParametersProto.vector_observation_size', index=0, 41 | number=1, type=5, cpp_type=1, label=1, 42 | has_default_value=False, default_value=0, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='num_stacked_vector_observations', full_name='communicator_objects.BrainParametersProto.num_stacked_vector_observations', index=1, 48 | number=2, type=5, cpp_type=1, label=1, 49 | has_default_value=False, default_value=0, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='vector_action_size', full_name='communicator_objects.BrainParametersProto.vector_action_size', index=2, 55 | number=3, type=5, cpp_type=1, label=1, 56 | has_default_value=False, default_value=0, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 
| options=None, file=DESCRIPTOR), 60 | _descriptor.FieldDescriptor( 61 | name='camera_resolutions', full_name='communicator_objects.BrainParametersProto.camera_resolutions', index=3, 62 | number=4, type=11, cpp_type=10, label=3, 63 | has_default_value=False, default_value=[], 64 | message_type=None, enum_type=None, containing_type=None, 65 | is_extension=False, extension_scope=None, 66 | options=None, file=DESCRIPTOR), 67 | _descriptor.FieldDescriptor( 68 | name='vector_action_descriptions', full_name='communicator_objects.BrainParametersProto.vector_action_descriptions', index=4, 69 | number=5, type=9, cpp_type=9, label=3, 70 | has_default_value=False, default_value=[], 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | options=None, file=DESCRIPTOR), 74 | _descriptor.FieldDescriptor( 75 | name='vector_action_space_type', full_name='communicator_objects.BrainParametersProto.vector_action_space_type', index=5, 76 | number=6, type=14, cpp_type=8, label=1, 77 | has_default_value=False, default_value=0, 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | options=None, file=DESCRIPTOR), 81 | _descriptor.FieldDescriptor( 82 | name='vector_observation_space_type', full_name='communicator_objects.BrainParametersProto.vector_observation_space_type', index=6, 83 | number=7, type=14, cpp_type=8, label=1, 84 | has_default_value=False, default_value=0, 85 | message_type=None, enum_type=None, containing_type=None, 86 | is_extension=False, extension_scope=None, 87 | options=None, file=DESCRIPTOR), 88 | _descriptor.FieldDescriptor( 89 | name='brain_name', full_name='communicator_objects.BrainParametersProto.brain_name', index=7, 90 | number=8, type=9, cpp_type=9, label=1, 91 | has_default_value=False, default_value=_b("").decode('utf-8'), 92 | message_type=None, enum_type=None, containing_type=None, 93 | is_extension=False, extension_scope=None, 94 | 
options=None, file=DESCRIPTOR), 95 | _descriptor.FieldDescriptor( 96 | name='brain_type', full_name='communicator_objects.BrainParametersProto.brain_type', index=8, 97 | number=9, type=14, cpp_type=8, label=1, 98 | has_default_value=False, default_value=0, 99 | message_type=None, enum_type=None, containing_type=None, 100 | is_extension=False, extension_scope=None, 101 | options=None, file=DESCRIPTOR), 102 | ], 103 | extensions=[ 104 | ], 105 | nested_types=[], 106 | enum_types=[ 107 | ], 108 | options=None, 109 | is_extendable=False, 110 | syntax='proto3', 111 | extension_ranges=[], 112 | oneofs=[ 113 | ], 114 | serialized_start=211, 115 | serialized_end=665, 116 | ) 117 | 118 | _BRAINPARAMETERSPROTO.fields_by_name['camera_resolutions'].message_type = communicator__objects_dot_resolution__proto__pb2._RESOLUTIONPROTO 119 | _BRAINPARAMETERSPROTO.fields_by_name['vector_action_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO 120 | _BRAINPARAMETERSPROTO.fields_by_name['vector_observation_space_type'].enum_type = communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO 121 | _BRAINPARAMETERSPROTO.fields_by_name['brain_type'].enum_type = communicator__objects_dot_brain__type__proto__pb2._BRAINTYPEPROTO 122 | DESCRIPTOR.message_types_by_name['BrainParametersProto'] = _BRAINPARAMETERSPROTO 123 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 124 | 125 | BrainParametersProto = _reflection.GeneratedProtocolMessageType('BrainParametersProto', (_message.Message,), dict( 126 | DESCRIPTOR = _BRAINPARAMETERSPROTO, 127 | __module__ = 'communicator_objects.brain_parameters_proto_pb2' 128 | # @@protoc_insertion_point(class_scope:communicator_objects.BrainParametersProto) 129 | )) 130 | _sym_db.RegisterMessage(BrainParametersProto) 131 | 132 | 133 | DESCRIPTOR.has_options = True 134 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 135 | # 
@@protoc_insertion_point(module_scope) 136 | -------------------------------------------------------------------------------- /install_requirements/communicator_objects/unity_rl_input_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: communicator_objects/unity_rl_input.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from communicator_objects import agent_action_proto_pb2 as communicator__objects_dot_agent__action__proto__pb2 17 | from communicator_objects import environment_parameters_proto_pb2 as communicator__objects_dot_environment__parameters__proto__pb2 18 | from communicator_objects import command_proto_pb2 as communicator__objects_dot_command__proto__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='communicator_objects/unity_rl_input.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_pb=_b('\n)communicator_objects/unity_rl_input.proto\x12\x14\x63ommunicator_objects\x1a-communicator_objects/agent_action_proto.proto\x1a\x37\x63ommunicator_objects/environment_parameters_proto.proto\x1a(communicator_objects/command_proto.proto\"\xb4\x03\n\x0cUnityRLInput\x12K\n\ragent_actions\x18\x01 \x03(\x0b\x32\x34.communicator_objects.UnityRLInput.AgentActionsEntry\x12P\n\x16\x65nvironment_parameters\x18\x02 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProto\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x33\n\x07\x63ommand\x18\x04 
\x01(\x0e\x32\".communicator_objects.CommandProto\x1aM\n\x14ListAgentActionProto\x12\x35\n\x05value\x18\x01 \x03(\x0b\x32&.communicator_objects.AgentActionProto\x1al\n\x11\x41gentActionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x46\n\x05value\x18\x02 \x01(\x0b\x32\x37.communicator_objects.UnityRLInput.ListAgentActionProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[communicator__objects_dot_agent__action__proto__pb2.DESCRIPTOR,communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,communicator__objects_dot_command__proto__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _UNITYRLINPUT_LISTAGENTACTIONPROTO = _descriptor.Descriptor( 33 | name='ListAgentActionProto', 34 | full_name='communicator_objects.UnityRLInput.ListAgentActionProto', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='value', full_name='communicator_objects.UnityRLInput.ListAgentActionProto.value', index=0, 41 | number=1, type=11, cpp_type=10, label=3, 42 | has_default_value=False, default_value=[], 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None, file=DESCRIPTOR), 46 | ], 47 | extensions=[ 48 | ], 49 | nested_types=[], 50 | enum_types=[ 51 | ], 52 | options=None, 53 | is_extendable=False, 54 | syntax='proto3', 55 | extension_ranges=[], 56 | oneofs=[ 57 | ], 58 | serialized_start=463, 59 | serialized_end=540, 60 | ) 61 | 62 | _UNITYRLINPUT_AGENTACTIONSENTRY = _descriptor.Descriptor( 63 | name='AgentActionsEntry', 64 | full_name='communicator_objects.UnityRLInput.AgentActionsEntry', 65 | filename=None, 66 | file=DESCRIPTOR, 67 | containing_type=None, 68 | fields=[ 69 | _descriptor.FieldDescriptor( 70 | name='key', full_name='communicator_objects.UnityRLInput.AgentActionsEntry.key', index=0, 71 | number=1, type=9, cpp_type=9, label=1, 72 | has_default_value=False, 
default_value=_b("").decode('utf-8'), 73 | message_type=None, enum_type=None, containing_type=None, 74 | is_extension=False, extension_scope=None, 75 | options=None, file=DESCRIPTOR), 76 | _descriptor.FieldDescriptor( 77 | name='value', full_name='communicator_objects.UnityRLInput.AgentActionsEntry.value', index=1, 78 | number=2, type=11, cpp_type=10, label=1, 79 | has_default_value=False, default_value=None, 80 | message_type=None, enum_type=None, containing_type=None, 81 | is_extension=False, extension_scope=None, 82 | options=None, file=DESCRIPTOR), 83 | ], 84 | extensions=[ 85 | ], 86 | nested_types=[], 87 | enum_types=[ 88 | ], 89 | options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')), 90 | is_extendable=False, 91 | syntax='proto3', 92 | extension_ranges=[], 93 | oneofs=[ 94 | ], 95 | serialized_start=542, 96 | serialized_end=650, 97 | ) 98 | 99 | _UNITYRLINPUT = _descriptor.Descriptor( 100 | name='UnityRLInput', 101 | full_name='communicator_objects.UnityRLInput', 102 | filename=None, 103 | file=DESCRIPTOR, 104 | containing_type=None, 105 | fields=[ 106 | _descriptor.FieldDescriptor( 107 | name='agent_actions', full_name='communicator_objects.UnityRLInput.agent_actions', index=0, 108 | number=1, type=11, cpp_type=10, label=3, 109 | has_default_value=False, default_value=[], 110 | message_type=None, enum_type=None, containing_type=None, 111 | is_extension=False, extension_scope=None, 112 | options=None, file=DESCRIPTOR), 113 | _descriptor.FieldDescriptor( 114 | name='environment_parameters', full_name='communicator_objects.UnityRLInput.environment_parameters', index=1, 115 | number=2, type=11, cpp_type=10, label=1, 116 | has_default_value=False, default_value=None, 117 | message_type=None, enum_type=None, containing_type=None, 118 | is_extension=False, extension_scope=None, 119 | options=None, file=DESCRIPTOR), 120 | _descriptor.FieldDescriptor( 121 | name='is_training', full_name='communicator_objects.UnityRLInput.is_training', 
index=2, 122 | number=3, type=8, cpp_type=7, label=1, 123 | has_default_value=False, default_value=False, 124 | message_type=None, enum_type=None, containing_type=None, 125 | is_extension=False, extension_scope=None, 126 | options=None, file=DESCRIPTOR), 127 | _descriptor.FieldDescriptor( 128 | name='command', full_name='communicator_objects.UnityRLInput.command', index=3, 129 | number=4, type=14, cpp_type=8, label=1, 130 | has_default_value=False, default_value=0, 131 | message_type=None, enum_type=None, containing_type=None, 132 | is_extension=False, extension_scope=None, 133 | options=None, file=DESCRIPTOR), 134 | ], 135 | extensions=[ 136 | ], 137 | nested_types=[_UNITYRLINPUT_LISTAGENTACTIONPROTO, _UNITYRLINPUT_AGENTACTIONSENTRY, ], 138 | enum_types=[ 139 | ], 140 | options=None, 141 | is_extendable=False, 142 | syntax='proto3', 143 | extension_ranges=[], 144 | oneofs=[ 145 | ], 146 | serialized_start=214, 147 | serialized_end=650, 148 | ) 149 | 150 | _UNITYRLINPUT_LISTAGENTACTIONPROTO.fields_by_name['value'].message_type = communicator__objects_dot_agent__action__proto__pb2._AGENTACTIONPROTO 151 | _UNITYRLINPUT_LISTAGENTACTIONPROTO.containing_type = _UNITYRLINPUT 152 | _UNITYRLINPUT_AGENTACTIONSENTRY.fields_by_name['value'].message_type = _UNITYRLINPUT_LISTAGENTACTIONPROTO 153 | _UNITYRLINPUT_AGENTACTIONSENTRY.containing_type = _UNITYRLINPUT 154 | _UNITYRLINPUT.fields_by_name['agent_actions'].message_type = _UNITYRLINPUT_AGENTACTIONSENTRY 155 | _UNITYRLINPUT.fields_by_name['environment_parameters'].message_type = communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO 156 | _UNITYRLINPUT.fields_by_name['command'].enum_type = communicator__objects_dot_command__proto__pb2._COMMANDPROTO 157 | DESCRIPTOR.message_types_by_name['UnityRLInput'] = _UNITYRLINPUT 158 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 159 | 160 | UnityRLInput = _reflection.GeneratedProtocolMessageType('UnityRLInput', (_message.Message,), dict( 161 | 162 | 
ListAgentActionProto = _reflection.GeneratedProtocolMessageType('ListAgentActionProto', (_message.Message,), dict( 163 | DESCRIPTOR = _UNITYRLINPUT_LISTAGENTACTIONPROTO, 164 | __module__ = 'communicator_objects.unity_rl_input_pb2' 165 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput.ListAgentActionProto) 166 | )) 167 | , 168 | 169 | AgentActionsEntry = _reflection.GeneratedProtocolMessageType('AgentActionsEntry', (_message.Message,), dict( 170 | DESCRIPTOR = _UNITYRLINPUT_AGENTACTIONSENTRY, 171 | __module__ = 'communicator_objects.unity_rl_input_pb2' 172 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput.AgentActionsEntry) 173 | )) 174 | , 175 | DESCRIPTOR = _UNITYRLINPUT, 176 | __module__ = 'communicator_objects.unity_rl_input_pb2' 177 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput) 178 | )) 179 | _sym_db.RegisterMessage(UnityRLInput) 180 | _sym_db.RegisterMessage(UnityRLInput.ListAgentActionProto) 181 | _sym_db.RegisterMessage(UnityRLInput.AgentActionsEntry) 182 | 183 | 184 | DESCRIPTOR.has_options = True 185 | DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\252\002\034MLAgents.CommunicatorObjects')) 186 | _UNITYRLINPUT_AGENTACTIONSENTRY.has_options = True 187 | _UNITYRLINPUT_AGENTACTIONSENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b('8\001')) 188 | # @@protoc_insertion_point(module_scope) 189 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from models import ActorNetwork, CriticNetwork 6 | from collections import deque, namedtuple 7 | import random 8 | import copy 9 | 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | eps = np.finfo(np.float32).eps.item() 
class MADDPG:
    """
    Multi-agent wrapper around a single DDPG agent that self-plays.

    Both "adversarial" slots reference the *same* DDPGAgent instance, so
    the two sides of the game share one actor/critic pair and one replay
    buffer.
    """

    def __init__(self,
                 *args,
                 **kargs
                 ):
        """
        Initialize constituent agents.

        :args - tuple of parameters forwarded verbatim to DDPGAgent:
                (state_dim,
                 action_dim,
                 lr_actor,
                 lr_critic,
                 lr_decay,
                 replay_buff_size,
                 gamma,
                 batch_size,
                 random_seed,
                 soft_update_tau)
        """
        super(MADDPG, self).__init__()

        agent = DDPGAgent(*args, **kargs)
        self.adversarial_agents = [agent, agent]  # the agent self-plays with itself

    def get_actors(self):
        """
        get actors of all the agents in the MADDPG object
        """
        actors = [ddpg_agent.actor_local for ddpg_agent in self.adversarial_agents]
        return actors

    def get_target_actors(self):
        """
        get target_actors of all the agents in the MADDPG object
        """
        target_actors = [ddpg_agent.actor_target for ddpg_agent in self.adversarial_agents]
        return target_actors

    def act(self, states_all_agents, add_noise=False):
        """
        get actions from all agents in the MADDPG object

        :states_all_agents - per-agent observations, indexed by agent along
                             the first axis
        :return ndarray of shape (num_agents, action_dim)
        """
        actions = [agent.act(state, add_noise) for agent, state in zip(self.adversarial_agents, states_all_agents)]
        return np.stack(actions, axis=0)

    def update(self, *experiences):
        """
        update the critics and actors of all the agents

        :experiences - (states, actions, rewards, next_states, dones),
                       each indexed by agent along the first axis
        """
        states, actions, rewards, next_states, dones = experiences
        for agent_idx, agent in enumerate(self.adversarial_agents):
            state = states[agent_idx, :]
            action = actions[agent_idx, :]
            reward = rewards[agent_idx]
            next_state = next_states[agent_idx, :]
            done = dones[agent_idx]
            agent.update_model(state, action, reward, next_state, done)

    def save(self, path):
        """
        Save the (shared) actor and critic weights to `path`.
        """
        agent = self.adversarial_agents[0]
        torch.save((agent.actor_local.state_dict(), agent.critic_local.state_dict()), path)

    def load(self, path):
        """
        Load model and decay learning rate.

        NOTE: loading deliberately multiplies both learning rates by
        `lr_decay` so resumed training proceeds more conservatively.
        """
        actor_state_dict, critic_state_dict = torch.load(path)
        agent = self.adversarial_agents[0]
        agent.actor_local.load_state_dict(actor_state_dict)
        agent.actor_target.load_state_dict(actor_state_dict)
        agent.critic_local.load_state_dict(critic_state_dict)
        agent.critic_target.load_state_dict(critic_state_dict)
        agent.lr_actor *= agent.lr_decay
        agent.lr_critic *= agent.lr_decay
        for group in agent.actor_optimizer.param_groups:
            group['lr'] = agent.lr_actor
        for group in agent.critic_optimizer.param_groups:
            group['lr'] = agent.lr_critic

        # All slots alias the same agent; keep the aliasing explicit.
        for i in range(1, len(self.adversarial_agents)):
            self.adversarial_agents[i] = agent

        print("Loaded models!")


class DDPGAgent:
    """
    A DDPG Agent: local/target actor-critic pairs, an Ornstein-Uhlenbeck
    exploration-noise process and a replay buffer.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 lr_actor=1e-4,
                 lr_critic=1e-4,
                 lr_decay=.95,
                 replay_buff_size=10000,
                 gamma=.99,
                 batch_size=128,
                 random_seed=42,
                 soft_update_tau=1e-3
                 ):
        """
        Initialize model.

        :state_dim (int): dimension of each observation vector
        :action_dim (int): dimension of each action vector
        :lr_actor / :lr_critic (float): optimizer learning rates
        :lr_decay (float): multiplicative decay applied when a checkpoint is loaded
        :replay_buff_size (int): replay memory capacity
        :gamma (float): discount factor
        :batch_size (int): SGD minibatch size
        :random_seed (int): seed for noise process and replay sampling
        :soft_update_tau (float): Polyak-averaging coefficient for target nets
        """
        self.lr_actor = lr_actor
        self.gamma = gamma
        self.lr_critic = lr_critic
        self.lr_decay = lr_decay
        self.tau = soft_update_tau

        self.actor_local = ActorNetwork(state_dim, action_dim).to(device=device)
        self.actor_target = ActorNetwork(state_dim, action_dim).to(device=device)
        self.actor_optimizer = optim.Adam(self.actor_local.parameters(), lr=self.lr_actor)

        self.critic_local = CriticNetwork(state_dim, action_dim).to(device=device)
        self.critic_target = CriticNetwork(state_dim, action_dim).to(device=device)
        self.critic_optimizer = optim.Adam(self.critic_local.parameters(), lr=self.lr_critic)

        self.noise = OUNoise(action_dim, random_seed)

        self.memory = ReplayBuffer(action_dim, replay_buff_size, batch_size, random_seed)

    def update_model(self, state, action, reward, next_state, done):
        """
        Update policy and value parameters using given batch of experience tuples.
        Q_targets = r + γ * critic_target(next_state, actor_target(next_state))
        where:
            actor_target(state) -> action
            critic_target(state, action) -> Q-value

        The incoming transition is stored first; learning only starts once
        the buffer can supply a full minibatch.
        """
        self.memory.add(state, action, reward, next_state, done)
        if not self.memory.is_ready():
            return

        experiences = self.memory.sample()
        states, actions, rewards, next_states, dones = experiences

        # ---------------------------- update critic ---------------------------- #
        # Get predicted next-state actions and Q values from target models
        actions_next = self.actor_target(next_states)
        Q_targets_next = self.critic_target(next_states, actions_next)
        # Compute Q targets for current states (y_i); detach so no gradient
        # flows back through the target networks.
        Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones)).detach()
        # Compute critic loss
        Q_expected = self.critic_local(states, actions)
        critic_loss = F.smooth_l1_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # ---------------------------- update actor ---------------------------- #
        # Deterministic policy gradient: ascend the critic's估 value of the
        # locally predicted actions (minimize its negation).
        actions_pred = self.actor_local(states)
        actor_loss = -self.critic_local(states, actions_pred).mean()
        # Minimize the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----------------------- update target networks ----------------------- #
        self.soft_update(self.critic_local, self.critic_target, self.tau)
        self.soft_update(self.actor_local, self.actor_target, self.tau)

    def act(self, state, noise_t=0.0):
        """
        Returns actions for given state as per current policy.

        :state: numpy observation, 1-D or (batch, state_dim)
        :noise_t (float): scale applied to the exploration noise sample
        """
        if len(np.shape(state)) == 1:
            state = state.reshape(1, -1)
        state = torch.from_numpy(state).float().to(device=device)
        # eval/no_grad: pure inference pass, keep batch-norm/dropout frozen
        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state).cpu().data.numpy()
        self.actor_local.train()
        action += self.noise.sample() * noise_t
        return np.clip(action, -1, 1).squeeze()

    def reset(self):
        """Reset the exploration-noise process to its mean."""
        self.noise.reset()

    def soft_update(self, local_model, target_model, tau):
        """
        Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        :local_model: PyTorch model (weights will be copied from)
        :target_model: PyTorch model (weights will be copied to)
        :tau (float): interpolation parameter
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)


class OUNoise:
    """
    Ornstein-Uhlenbeck process.
    """

    def __init__(self, size, seed, mu=0., theta=0.15, sigma=0.2):
        """
        Initialize parameters and noise process."""
        self.mu = mu * np.ones(size)
        self.theta = theta
        self.sigma = sigma
        self.seed = random.seed(seed)
        # Private generator for the Gaussian increments; seeded for
        # reproducibility independent of the global numpy RNG.
        self._rng = np.random.RandomState(seed)
        self.reset()

    def reset(self):
        """
        Reset the internal state (= noise) to mean (mu).
        """
        self.state = copy.copy(self.mu)

    def sample(self):
        """
        Update internal state and return it as a noise sample.
        """
        x = self.state
        # BUGFIX: the increments must be zero-mean Gaussian draws. The
        # previous uniform random.random() draws lie in [0, 1) and have
        # mean +0.5, which biased the exploration noise to one side.
        dx = self.theta * (self.mu - x) + self.sigma * self._rng.standard_normal(len(x))
        self.state = x + dx
        return self.state


class ReplayBuffer:
    """
    Fixed-size buffer to store experience tuples.
    """

    def __init__(self, action_size, buffer_size, batch_size, seed):
        """
        Initialize a ReplayBuffer object.

        :buffer_size (int): maximum size of buffer
        :batch_size (int): size of each training batch
        """
        self.action_size = action_size
        self.memory = deque(maxlen=buffer_size)  # internal memory (deque)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        self.seed = random.seed(seed)

    def add(self, state, action, reward, next_state, done):
        """
        Add a new experience to memory.
        """
        e = self.experience(state, action, reward, next_state, done)
        self.memory.append(e)

    def sample(self):
        """
        Randomly sample a batch of experiences from memory.
        """
        experiences = random.sample(self.memory, k=self.batch_size)

        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).float().to(device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)

        return (states, actions, rewards, next_states, dones)

    def is_ready(self):
        """
        True once a full minibatch can be drawn.

        BUGFIX: random.sample only requires len(memory) >= k, so using a
        strict '>' delayed the first learning step by one transition.
        """
        return len(self.memory) >= self.batch_size

    def __len__(self):
        """
        Return the current size of internal memory.
        """
        return len(self.memory)
293 | """ 294 | return len(self.memory) 295 | 296 | 297 | -------------------------------------------------------------------------------- /install_requirements/unityagents/environment.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import glob 3 | import io 4 | import logging 5 | import numpy as np 6 | import os 7 | import subprocess 8 | 9 | from .brain import BrainInfo, BrainParameters, AllBrainInfo 10 | from .exception import UnityEnvironmentException, UnityActionException, UnityTimeOutException 11 | from .curriculum import Curriculum 12 | 13 | from communicator_objects import UnityRLInput, UnityRLOutput, AgentActionProto,\ 14 | EnvironmentParametersProto, UnityRLInitializationInput, UnityRLInitializationOutput,\ 15 | UnityInput, UnityOutput 16 | 17 | from .rpc_communicator import RpcCommunicator 18 | from .socket_communicator import SocketCommunicator 19 | 20 | 21 | from sys import platform 22 | from PIL import Image 23 | 24 | logging.basicConfig(level=logging.INFO) 25 | logger = logging.getLogger("unityagents") 26 | 27 | 28 | class UnityEnvironment(object): 29 | def __init__(self, file_name=None, worker_id=0, 30 | base_port=5005, curriculum=None, 31 | seed=0, docker_training=False, no_graphics=False): 32 | """ 33 | Starts a new unity environment and establishes a connection with the environment. 34 | Notice: Currently communication between Unity and Python takes place over an open socket without authentication. 35 | Ensure that the network where training takes place is secure. 36 | 37 | :string file_name: Name of Unity environment binary. 38 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 39 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 40 | :param docker_training: Informs this class whether the process is being run within a container. 
        :param no_graphics: Whether to run the Unity simulator in no-graphics mode
        """

        # Make sure the child process / communicator are torn down even on
        # abnormal interpreter exit.
        atexit.register(self._close)
        self.port = base_port + worker_id
        self._buffer_size = 12000
        self._version_ = "API-4"
        self._loaded = False  # If true, this means the environment was successfully loaded
        self.proc1 = None  # The process that is started. If None, no process was started
        self.communicator = self.get_communicator(worker_id, base_port)

        # If the environment name is 'editor', a new environment will not be launched
        # and the communicator will directly try to connect to an existing unity environment.
        if file_name is not None:
            self.executable_launcher(file_name, docker_training, no_graphics)
        else:
            logger.info("Start training by pressing the Play button in the Unity Editor.")
        self._loaded = True

        rl_init_parameters_in = UnityRLInitializationInput(
            seed=seed
        )
        try:
            aca_params = self.send_academy_parameters(rl_init_parameters_in)
        except UnityTimeOutException:
            self._close()
            raise
        # TODO : think of a better way to expose the academyParameters
        # Version handshake: refuse to run against a Unity build speaking a
        # different protocol revision than this client ("API-4").
        self._unity_version = aca_params.version
        if self._unity_version != self._version_:
            raise UnityEnvironmentException(
                "The API number is not compatible between Unity and python. Python API : {0}, Unity API : "
                "{1}.\nPlease go to https://github.com/Unity-Technologies/ml-agents to download the latest version "
                "of ML-Agents.".format(self._version_, self._unity_version))
        self._n_agents = {}
        self._global_done = None
        self._academy_name = aca_params.name
        self._log_path = aca_params.log_path
        self._brains = {}
        self._brain_names = []
        self._external_brain_names = []
        # Build a BrainParameters object per brain advertised by the academy.
        for brain_param in aca_params.brain_parameters:
            self._brain_names += [brain_param.brain_name]
            resolution = [{
                "height": x.height,
                "width": x.width,
                "blackAndWhite": x.gray_scale
            } for x in brain_param.camera_resolutions]
            self._brains[brain_param.brain_name] = \
                BrainParameters(brain_param.brain_name, {
                    "vectorObservationSize": brain_param.vector_observation_size,
                    "numStackedVectorObservations": brain_param.num_stacked_vector_observations,
                    "cameraResolutions": resolution,
                    "vectorActionSize": brain_param.vector_action_size,
                    "vectorActionDescriptions": brain_param.vector_action_descriptions,
                    "vectorActionSpaceType": brain_param.vector_action_space_type,
                    "vectorObservationSpaceType": brain_param.vector_observation_space_type
                })
            # brain_type 2 presumably means "External" (Python-controlled);
            # TODO confirm against BrainTypeProto.
            if brain_param.brain_type == 2:
                self._external_brain_names += [brain_param.brain_name]
        self._num_brains = len(self._brain_names)
        self._num_external_brains = len(self._external_brain_names)
        self._resetParameters = dict(aca_params.environment_parameters.float_parameters)  # TODO
        self._curriculum = Curriculum(curriculum, self._resetParameters)
        logger.info("\n'{0}' started successfully!\n{1}".format(self._academy_name, str(self)))
        if self._num_external_brains == 0:
            logger.warning(" No External Brains found in the Unity Environment. "
                           "You will not be able to pass actions to your agent(s).")

    @property
    def curriculum(self):
        """The Curriculum instance controlling reset parameters per lesson."""
        return self._curriculum

    @property
    def logfile_path(self):
        """Path of the Unity-side log file, as reported by the academy."""
        return self._log_path

    @property
    def brains(self):
        """Mapping of brain name -> BrainParameters."""
        return self._brains

    @property
    def global_done(self):
        """Last global done flag (None until the first reset())."""
        return self._global_done

    @property
    def academy_name(self):
        """Name of the Unity academy, as reported at initialization."""
        return self._academy_name

    @property
    def number_brains(self):
        """Total number of brains in the environment."""
        return self._num_brains

    @property
    def number_external_brains(self):
        """Number of externally (Python-) controlled brains."""
        return self._num_external_brains

    @property
    def brain_names(self):
        """Names of all brains."""
        return self._brain_names

    @property
    def external_brain_names(self):
        """Names of the externally controlled brains only."""
        return self._external_brain_names

    def executable_launcher(self, file_name, docker_training, no_graphics):
        """
        Locate the environment binary for the current platform and start it
        as a subprocess listening on self.port. Raises
        UnityEnvironmentException if no matching binary is found.
        """
        cwd = os.getcwd()
        # Strip any platform suffix the caller may have included; it is
        # re-derived per platform below.
        file_name = (file_name.strip()
                     .replace('.app', '').replace('.exe', '').replace('.x86_64', '').replace('.x86', ''))
        true_filename = os.path.basename(os.path.normpath(file_name))
        logger.debug('The true file name is {}'.format(true_filename))
        launch_string = None
        if platform == "linux" or platform == "linux2":
            # Try cwd-relative then raw path, 64-bit before 32-bit.
            candidates = glob.glob(os.path.join(cwd, file_name) + '.x86_64')
            if len(candidates) == 0:
                candidates = glob.glob(os.path.join(cwd, file_name) + '.x86')
            if len(candidates) == 0:
                candidates = glob.glob(file_name + '.x86_64')
            if len(candidates) == 0:
                candidates = glob.glob(file_name + '.x86')
            if len(candidates) > 0:
                launch_string = candidates[0]

        elif platform == 'darwin':
            # macOS: the executable lives inside the .app bundle.
            candidates = glob.glob(os.path.join(cwd, file_name + '.app', 'Contents', 'MacOS', true_filename))
            if len(candidates) == 0:
                candidates = glob.glob(os.path.join(file_name + '.app', 'Contents', 'MacOS', true_filename))
            if len(candidates) == 0:
                candidates = glob.glob(os.path.join(cwd, file_name + '.app', 'Contents', 'MacOS', '*'))
            if len(candidates) == 0:
                candidates = glob.glob(os.path.join(file_name + '.app', 'Contents', 'MacOS', '*'))
            if len(candidates) > 0:
                launch_string = candidates[0]
        elif platform == 'win32':
            candidates = glob.glob(os.path.join(cwd, file_name + '.exe'))
            if len(candidates) == 0:
                candidates = glob.glob(file_name + '.exe')
            if len(candidates) > 0:
                launch_string = candidates[0]
        if launch_string is None:
            self._close()
            raise UnityEnvironmentException("Couldn't launch the {0} environment. "
                                            "Provided filename does not match any environments."
                                            .format(true_filename))
        else:
            logger.debug("This is the launch string {}".format(launch_string))
            # Launch Unity environment
            if not docker_training:
                if no_graphics:
                    self.proc1 = subprocess.Popen(
                        [launch_string, '-nographics', '-batchmode',
                         '--port', str(self.port)])
                else:
                    self.proc1 = subprocess.Popen(
                        [launch_string, '--port', str(self.port)])
            else:
                """
                Comments for future maintenance:
                xvfb-run is a wrapper around Xvfb, a virtual xserver where all
                rendering is done to virtual memory. It automatically creates a
                new virtual server automatically picking a server number `auto-servernum`.
                The server is passed the arguments using `server-args`, we are telling
                Xvfb to create Screen number 0 with width 640, height 480 and depth 24 bits.
                Note that 640 X 480 are the default width and height. The main reason for
                us to add this is because we'd like to change the depth from the default
                of 8 bits to 24.
                Unfortunately, this means that we will need to pass the arguments through
                a shell which is why we set `shell=True`. Now, this adds its own
                complications. E.g SIGINT can bounce off the shell and not get propagated
                to the child processes. This is why we add `exec`, so that the shell gets
                launched, the arguments are passed to `xvfb-run`. `exec` replaces the shell
                we created with `xvfb`.
                """
                docker_ls = ("exec xvfb-run --auto-servernum"
                             " --server-args='-screen 0 640x480x24'"
                             " {0} --port {1}").format(launch_string, str(self.port))
                self.proc1 = subprocess.Popen(docker_ls,
                                              stdout=subprocess.PIPE,
                                              stderr=subprocess.PIPE,
                                              shell=True)

    def get_communicator(self, worker_id, base_port):
        """Return the transport used to talk to Unity (gRPC by default)."""
        return RpcCommunicator(worker_id, base_port)
        # return SocketCommunicator(worker_id, base_port)

    def __str__(self):
        # NOTE: this also refreshes _resetParameters from the current
        # curriculum lesson as a side effect.
        _new_reset_param = self._curriculum.get_config()
        for k in _new_reset_param:
            self._resetParameters[k] = _new_reset_param[k]
        return '''Unity Academy name: {0}
        Number of Brains: {1}
        Number of External Brains : {2}
        Lesson number : {3}
        Reset Parameters :\n\t\t{4}'''.format(self._academy_name, str(self._num_brains),
                                              str(self._num_external_brains), self._curriculum.get_lesson_number,
                                              "\n\t\t".join([str(k) + " -> " + str(self._resetParameters[k])
                                                             for k in self._resetParameters])) + '\n' + \
               '\n'.join([str(self._brains[b]) for b in self._brains])

    def reset(self, train_mode=True, config=None, lesson=None) -> AllBrainInfo:
        """
        Sends a signal to reset the unity environment.
        :return: AllBrainInfo : A Data structure corresponding to the initial reset state of the environment.
        """
        # Explicit config takes precedence over the curriculum's lesson
        # config; the info log only fires for an explicitly passed config.
        if config is None:
            config = self._curriculum.get_config(lesson)
        elif config != {}:
            logger.info("\nAcademy Reset with parameters : \t{0}"
                        .format(', '.join([str(x) + ' -> ' + str(config[x]) for x in config])))
        for k in config:
            if (k in self._resetParameters) and (isinstance(config[k], (int, float))):
                self._resetParameters[k] = config[k]
            elif not isinstance(config[k], (int, float)):
                raise UnityEnvironmentException(
                    "The value for parameter '{0}'' must be an Integer or a Float.".format(k))
            else:
                raise UnityEnvironmentException("The parameter '{0}' is not a valid parameter.".format(k))

        if self._loaded:
            outputs = self.communicator.exchange(
                self._generate_reset_input(train_mode, config)
            )
            # NOTE(review): a None exchange result appears to signal that the
            # Unity side closed the connection — confirm before relying on it.
            if outputs is None:
                raise KeyboardInterrupt
            rl_output = outputs.rl_output
            s = self._get_state(rl_output)
            self._global_done = s[1]
            # Cache the agent count per brain for action-size validation in step().
            for _b in self._external_brain_names:
                self._n_agents[_b] = len(s[0][_b].agents)
            return s[0]
        else:
            raise UnityEnvironmentException("No Unity environment is loaded.")

    def step(self, vector_action=None, memory=None, text_action=None) -> AllBrainInfo:
        """
        Provides the environment with an action, moves the environment dynamics forward accordingly, and returns
        observation, state, and reward information to the agent.
        :param vector_action: Agent's vector action to send to environment. Can be a scalar or vector of int/floats.
        :param memory: Vector corresponding to memory used for RNNs, frame-stacking, or other auto-regressive process.
        :param text_action: Text action to send to environment for.
        :return: AllBrainInfo : A Data structure corresponding to the new state of the environment.
        """
        vector_action = {} if vector_action is None else vector_action
        memory = {} if memory is None else memory
        text_action = {} if text_action is None else text_action
        if self._loaded and not self._global_done and self._global_done is not None:
            # Normalize bare scalars/lists into a {brain_name: value} dict
            # when there is exactly one external brain.
            if isinstance(vector_action, (int, np.int_, float, np.float_, list, np.ndarray)):
                if self._num_external_brains == 1:
                    vector_action = {self._external_brain_names[0]: vector_action}
                elif self._num_external_brains > 1:
                    raise UnityActionException(
                        "You have {0} brains, you need to feed a dictionary of brain names a keys, "
                        "and vector_actions as values".format(self._num_brains))
                else:
                    raise UnityActionException(
                        "There are no external brains in the environment, "
                        "step cannot take a vector_action input")

            if isinstance(memory, (int, np.int_, float, np.float_, list, np.ndarray)):
                if self._num_external_brains == 1:
                    memory = {self._external_brain_names[0]: memory}
                elif self._num_external_brains > 1:
                    raise UnityActionException(
                        "You have {0} brains, you need to feed a dictionary of brain names as keys "
                        "and memories as values".format(self._num_brains))
                else:
                    raise UnityActionException(
                        "There are no external brains in the environment, "
                        "step cannot take a memory input")
            if isinstance(text_action, (str, list, np.ndarray)):
                if self._num_external_brains == 1:
                    text_action = {self._external_brain_names[0]: text_action}
                elif self._num_external_brains > 1:
                    raise UnityActionException(
                        "You have {0} brains, you need to feed a dictionary of brain names as keys "
                        "and text_actions as values".format(self._num_brains))
                else:
                    raise UnityActionException(
                        "There are no external brains in the environment, "
                        "step cannot take a value input")

            for brain_name in list(vector_action.keys()) + list(memory.keys()) + list(text_action.keys()):
                if brain_name not in self._external_brain_names:
                    raise UnityActionException(
                        "The name {0} does not correspond to an external brain "
                        "in the environment".format(brain_name))

            # Fill defaults for any brain the caller omitted, then validate
            # lengths against the cached per-brain agent counts.
            for b in self._external_brain_names:
                n_agent = self._n_agents[b]
                if b not in vector_action:
                    # raise UnityActionException("You need to input an action for the brain {0}".format(b))
                    if self._brains[b].vector_action_space_type == "discrete":
                        vector_action[b] = [0.0] * n_agent
                    else:
                        vector_action[b] = [0.0] * n_agent * self._brains[b].vector_action_space_size
                else:
                    vector_action[b] = self._flatten(vector_action[b])
                if b not in memory:
                    memory[b] = []
                else:
                    if memory[b] is None:
                        memory[b] = []
                    else:
                        memory[b] = self._flatten(memory[b])
                if b not in text_action:
                    text_action[b] = [""] * n_agent
                else:
                    if text_action[b] is None:
                        text_action[b] = [""] * n_agent
                    if isinstance(text_action[b], str):
                        text_action[b] = [text_action[b]] * n_agent
                if not ((len(text_action[b]) == n_agent) or len(text_action[b]) == 0):
                    raise UnityActionException(
                        "There was a mismatch between the provided text_action and environment's expectation: "
                        "The brain {0} expected {1} text_action but was given {2}".format(
                            b, n_agent, len(text_action[b])))
                if not ((self._brains[b].vector_action_space_type == "discrete" and len(vector_action[b]) == n_agent) or
                        (self._brains[b].vector_action_space_type == "continuous" and len(
                            vector_action[b]) == self._brains[b].vector_action_space_size * n_agent)):
                    raise UnityActionException(
                        "There was a mismatch between the provided action and environment's expectation: "
                        "The brain {0} expected {1} {2} action(s), but was provided: {3}"
                        .format(b, n_agent if self._brains[b].vector_action_space_type == "discrete" else
                                str(self._brains[b].vector_action_space_size * n_agent),
                                self._brains[b].vector_action_space_type,
                                str(vector_action[b])))

            outputs = self.communicator.exchange(
                self._generate_step_input(vector_action, memory, text_action)
            )
            # NOTE(review): None exchange result appears to mean the Unity side
            # closed the connection — confirm before relying on it.
            if outputs is None:
                raise KeyboardInterrupt
            rl_output = outputs.rl_output
            s = self._get_state(rl_output)
            self._global_done = s[1]
            for _b in self._external_brain_names:
                self._n_agents[_b] = len(s[0][_b].agents)
            return s[0]
        elif not self._loaded:
            raise UnityEnvironmentException("No Unity environment is loaded.")
        elif self._global_done:
            raise UnityActionException("The episode is completed. Reset the environment with 'reset()'")
        elif self.global_done is None:
            raise UnityActionException(
                "You cannot conduct step without first calling reset. Reset the environment with 'reset()'")

    def close(self):
        """
        Sends a shutdown signal to the unity environment, and closes the socket connection.
        """
        if self._loaded:
            self._close()
        else:
            raise UnityEnvironmentException("No Unity environment is loaded.")

    def _close(self):
        # Idempotent teardown: also registered with atexit in __init__.
        self._loaded = False
        self.communicator.close()
        if self.proc1 is not None:
            self.proc1.kill()

    @staticmethod
    def _flatten(arr):
        """
        Converts arrays to list.
        :param arr: numpy vector.
        :return: flattened list.
        """
        if isinstance(arr, (int, np.int_, float, np.float_)):
            arr = [float(arr)]
        if isinstance(arr, np.ndarray):
            arr = arr.tolist()
        if len(arr) == 0:
            return arr
        # Flatten one level of nesting (list of arrays / list of lists).
        if isinstance(arr[0], np.ndarray):
            arr = [item for sublist in arr for item in sublist.tolist()]
        if isinstance(arr[0], list):
            arr = [item for sublist in arr for item in sublist]
        arr = [float(x) for x in arr]
        return arr

    @staticmethod
    def _process_pixels(image_bytes, gray_scale):
        """
        Converts byte array observation image into numpy array, re-sizes it, and optionally converts it to grey scale
        :param image_bytes: input byte array corresponding to image
        :return: processed numpy array of observation from environment
        """
        s = bytearray(image_bytes)
        image = Image.open(io.BytesIO(s))
        s = np.array(image) / 255.0  # normalize to [0, 1]
        if gray_scale:
            s = np.mean(s, axis=2)
            s = np.reshape(s, [s.shape[0], s.shape[1], 1])
        return s

    def _get_state(self, output: UnityRLOutput) -> (AllBrainInfo, bool):
        """
        Collects experience information from all external brains in environment at current step.
        :return: a dictionary of BrainInfo objects.
        """
        _data = {}
        global_done = output.global_done
        for b in output.agentInfos:
            agent_info_list = output.agentInfos[b].value
            vis_obs = []
            for i in range(self.brains[b].number_visual_observations):
                obs = [self._process_pixels(x.visual_observations[i],
                                            self.brains[b].camera_resolutions[i]['blackAndWhite'])
                       for x in agent_info_list]
                vis_obs += [np.array(obs)]
            if len(agent_info_list) == 0:
                memory_size = 0
            else:
                memory_size = max([len(x.memories) for x in agent_info_list])
            if memory_size == 0:
                memory = np.zeros((0, 0))
            else:
                # Right-pad each agent's memory with zeros to the max length
                # so they stack into one rectangular array.
                [x.memories.extend([0] * (memory_size - len(x.memories))) for x in agent_info_list]
                memory = np.array([x.memories for x in agent_info_list])
            _data[b] = BrainInfo(
                visual_observation=vis_obs,
                vector_observation=np.array([x.stacked_vector_observation for x in agent_info_list]),
                text_observations=[x.text_observation for x in agent_info_list],
                memory=memory,
                reward=[x.reward for x in agent_info_list],
                agents=[x.id for x in agent_info_list],
                local_done=[x.done for x in agent_info_list],
                vector_action=np.array([x.stored_vector_actions for x in agent_info_list]),
                text_action=[x.stored_text_actions for x in agent_info_list],
                max_reached=[x.max_step_reached for x in agent_info_list]
            )
        return _data, global_done

    def _generate_step_input(self, vector_action, memory, text_action) -> UnityRLInput:
        """Pack per-brain flat action/memory lists into a UnityRLInput, one
        AgentActionProto slice per agent."""
        rl_in = UnityRLInput()
        for b in vector_action:
            n_agents = self._n_agents[b]
            if n_agents == 0:
                continue
            # Per-agent slice sizes of the flattened action/memory vectors.
            _a_s = len(vector_action[b]) // n_agents
            _m_s = len(memory[b]) // n_agents
            for i in range(n_agents):
                action = AgentActionProto(
                    vector_actions=vector_action[b][i * _a_s: (i + 1) * _a_s],
                    memories=memory[b][i * _m_s: (i + 1) * _m_s],
                    text_actions=text_action[b][i]
                )
                rl_in.agent_actions[b].value.extend([action])
        # command 0 presumably == STEP in CommandProto (1 == RESET below) —
        # TODO confirm against command_proto_pb2.
        rl_in.command = 0
        return self.wrap_unity_input(rl_in)

    def _generate_reset_input(self, training, config) -> UnityRLInput:
        """Build the RESET message carrying train-mode flag and reset parameters."""
        rl_in = UnityRLInput()
        rl_in.is_training = training
        rl_in.environment_parameters.CopyFrom(EnvironmentParametersProto())
        for key in config:
            rl_in.environment_parameters.float_parameters[key] = config[key]
        rl_in.command = 1
        return self.wrap_unity_input(rl_in)

    def send_academy_parameters(self, init_parameters: UnityRLInitializationInput) -> UnityRLInitializationOutput:
        """Perform the initialization handshake and return the academy parameters."""
        inputs = UnityInput()
        inputs.rl_initialization_input.CopyFrom(init_parameters)
        return self.communicator.initialize(inputs).rl_initialization_output

    def wrap_unity_input(self, rl_input: UnityRLInput) -> UnityInput:
        # Annotation corrected: this wraps rl_input into (and returns) a
        # UnityInput; the previous "-> UnityOutput" annotation was wrong.
        result = UnityInput()
        result.rl_input.CopyFrom(rl_input)
        return result