├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE
├── README.md
├── clean_data.py
├── conda_env.yaml
├── conda_env_cpu.yaml
├── config_file.py
├── ddpg
│   └── ddpg.py
├── mlagents
│   └── envs
│       ├── __init__.py
│       ├── base_unity_environment.py
│       ├── brain.py
│       ├── communicator.py
│       ├── communicator_objects
│       │   ├── __init__.py
│       │   ├── agent_action_proto_pb2.py
│       │   ├── agent_info_proto_pb2.py
│       │   ├── brain_parameters_proto_pb2.py
│       │   ├── command_proto_pb2.py
│       │   ├── custom_action_pb2.py
│       │   ├── custom_observation_pb2.py
│       │   ├── custom_reset_parameters_pb2.py
│       │   ├── demonstration_meta_proto_pb2.py
│       │   ├── engine_configuration_proto_pb2.py
│       │   ├── environment_parameters_proto_pb2.py
│       │   ├── header_pb2.py
│       │   ├── resolution_proto_pb2.py
│       │   ├── space_type_proto_pb2.py
│       │   ├── unity_input_pb2.py
│       │   ├── unity_message_pb2.py
│       │   ├── unity_output_pb2.py
│       │   ├── unity_rl_initialization_input_pb2.py
│       │   ├── unity_rl_initialization_output_pb2.py
│       │   ├── unity_rl_input_pb2.py
│       │   ├── unity_rl_output_pb2.py
│       │   ├── unity_to_external_pb2.py
│       │   └── unity_to_external_pb2_grpc.py
│       ├── environment.py
│       ├── exception.py
│       ├── mock_communicator.py
│       ├── rpc_communicator.py
│       ├── socket_communicator.py
│       ├── subprocess_environment.py
│       └── tests
│           ├── __init__.py
│           ├── test_envs.py
│           ├── test_rpc_communicator.py
│           └── test_subprocess_unity_environment.py
├── ppo
│   ├── ppo.py
│   ├── ppo_base.py
│   └── ppopac.py
├── run.py
├── sac
│   ├── sac.py
│   └── sac_no_v.py
├── simple_run.py
├── td3
│   └── td3.py
├── test.py
└── utils
    ├── recorder.py
    ├── replay_buffer.py
    └── sth.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | pg/
107 | trial/
108 | run.py
109 | config_file.py
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |     "python.pythonPath": "C:\\A\\Anaconda\\envs\\mfpy\\python.exe"
3 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Keavnn
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RLwithUnity
2 | 
3 | This project includes some state-of-the-art and classic RL (reinforcement learning) algorithms used for training agents by interacting with Unity through [ml-agents](https://github.com/Unity-Technologies/ml-agents) v0.8.1.
4 | 
5 | The algorithms in this repository are written completely independently of one another, because I want each algorithm to stand on its own: each one lives in its own `.py` file, so you never have to jump to another file to find the implementation, which could be confusing. These algorithms will never be encapsulated into a common base algorithm class.
6 | 
7 | This framework supports switching the training mechanism between on-policy and off-policy for actor-critic algorithms. Just set the variable `use_replay_buffer` in `config_file.py` (`True` for off-policy, `False` for on-policy).
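For reference, here is the relevant (trimmed) fragment of the `config` dict from `config_file.py`; the full file appears later in this document:

```python
# Trimmed fragment of config_file.py (only the keys relevant to the switch).
config = {
    # ...
    'train config': {
        # ...
        # some sets about using replay_buffer
        'use_replay_buffer': True,   # True -> off-policy, False -> on-policy
        'use_priority': False,
        'buffer_size': 10000,
        'buffer_batch_size': 100,
    },
}
```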
8 | 
9 | You can run each algorithm in this repository with `python simple_run.py`. I did not put any recording functionality in it (such as Excel, MongoDB, logger, checkpoint, summary, ...).
10 | 
11 | I am very grateful to my best friend - [BlueFisher](https://github.com/BlueFisher) - who always looks down on my coding style and competence (although he **is** right, at least for now, that will soon be **was**).
12 | 
13 | If you have any questions about this project, or spot errors caused by my bad grammar, please let me know in [the issues](https://github.com/StepNeverStop/RLwithUnity/issues).
--------------------------------------------------------------------------------
/clean_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | from config_file import config
5 | 
6 | 
7 | def clean(config):
8 |     length = len(config['clean_list'])
9 |     count = 0
10 |     for i in range(length):
11 |         cp_dir = config['record config']['checkpoint_basic_dir'] + \
12 |             config['clean_list'][i]
13 |         log_dir = config['record config']['log_basic_dir'] + \
14 |             config['clean_list'][i]
15 |         excel_dir = config['record config']['excel_basic_dir'] + \
16 |             config['clean_list'][i]
17 |         config_dir = config['record config']['config_basic_dir'] + \
18 |             config['clean_list'][i]
19 |         if os.path.exists(log_dir):
20 |             print('-' * 10 + str(i) + '-' * 10)
21 |             print('C.L.E.A.N : {0}'.format(config['clean_list'][i]))
22 |             try:
23 |                 shutil.rmtree(log_dir)
24 |                 print('remove LOG success.')
25 |                 shutil.rmtree(cp_dir)
26 |                 print('remove CHECKPOINT success.')
27 |                 shutil.rmtree(config_dir)
28 |                 print('remove CONFIG success.')
29 |                 shutil.rmtree(excel_dir)
30 |                 print('remove EXCEL success.')
31 |                 count += 1
32 |             except Exception as e:
33 |                 print(e)
34 |                 sys.exit()
35 |         else:
36 |             print('{0} does not exist, please check and run again...'.format(
37 |                 config['clean_list'][i]))
38 |     print(f'total: {length}, clean: {count}')
39 | 
40 | 
41 | clean(config)
42 | 
--------------------------------------------------------------------------------
/conda_env.yaml:
--------------------------------------------------------------------------------
1 | name: mlgpu
2 | channels:
3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/
9 | dependencies:
10 | - _tflow_select=2.1.0=gpu
11 | - absl-py=0.7.0=py36_0
12 | - astor=0.7.1=py36_0
13 | - blas=1.0=mkl
14 | - ca-certificates=2019.1.23=0
15 | - certifi=2019.3.9=py36_0
16 | - cudatoolkit=9.0=1
17 | - cudnn=7.3.1=cuda9.0_0
18 | - et_xmlfile=1.0.1=py36h3d2d736_0
19 | - freetype=2.9.1=ha9979f8_1
20 | - gast=0.2.2=py36_0
21 | - grpcio=1.16.1=py36h351948d_1
22 | - h5py=2.9.0=py36h5e291fa_0
23 | - hdf5=1.10.4=h7ebc959_0
24 | - icc_rt=2019.0.0=h0cc432a_1
25 | - intel-openmp=2019.1=144
26 | - jdcal=1.4=py36_0
27 | - jpeg=9b=hb83a4c4_2
28 | - keras-applications=1.0.6=py36_0
29 | - keras-preprocessing=1.0.5=py36_0
30 | -
libpng=1.6.36=h2a8f88b_0 31 | - libprotobuf=3.6.1=h7bd577a_0 32 | - libtiff=4.0.10=hb898794_2 33 | - markdown=3.0.1=py36_0 34 | - mkl=2019.1=144 35 | - mkl_fft=1.0.10=py36h14836fe_0 36 | - mkl_random=1.0.2=py36h343c172_0 37 | - numpy=1.16.2=py36h19fb1c0_0 38 | - numpy-base=1.16.2=py36hc3f5095_0 39 | - olefile=0.46=py36_0 40 | - openpyxl=2.6.1=py_0 41 | - openssl=1.1.1b=he774522_1 42 | - pandas=0.24.2=py36ha925a31_0 43 | - pillow=5.4.1=py36hdc69c19_0 44 | - pip=19.0.3=py36_0 45 | - protobuf=3.6.1=py36h33f27b4_0 46 | - pymongo=3.7.2=py36ha925a31_0 47 | - pyreadline=2.1=py36_1 48 | - python=3.6.8=h9f7ef89_7 49 | - python-dateutil=2.8.0=py36_0 50 | - pytz=2018.9=py36_0 51 | - pywin32=223=py36hfa6e2cd_1 52 | - pyyaml=5.1=py36he774522_0 53 | - scipy=1.2.1=py36h29ff71c_0 54 | - setuptools=40.8.0=py36_0 55 | - six=1.12.0=py36_0 56 | - sqlite=3.26.0=he774522_0 57 | - tensorboard=1.12.2=py36h33f27b4_0 58 | - tensorflow=1.12.0=gpu_py36ha5f9131_0 59 | - tensorflow-base=1.12.0=gpu_py36h6e53903_0 60 | - tensorflow-gpu=1.12.0=h0d30ee6_0 61 | - termcolor=1.1.0=py36_1 62 | - tk=8.6.8=hfa6e2cd_0 63 | - vc=14.1=h0510ff6_4 64 | - vs2015_runtime=14.15.26706=h3a45250_0 65 | - werkzeug=0.14.1=py36_0 66 | - wheel=0.33.1=py36_0 67 | - wincertstore=0.2=py36h7fe50ca_0 68 | - xz=5.2.4=h2fa13f4_4 69 | - yaml=0.1.7=hc54c509_2 70 | - zlib=1.2.11=h62dcd97_3 71 | - zstd=1.3.7=h508b16e_0 72 | prefix: C:\A\Anaconda\envs\mlgpu 73 | 74 | -------------------------------------------------------------------------------- /conda_env_cpu.yaml: -------------------------------------------------------------------------------- 1 | name: mlcpu 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/ 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/ 9 | dependencies: 10 | - _tflow_select=2.3.0=mkl 11 | - absl-py=0.7.0=py36_0 12 | - astor=0.7.1=py36_0 13 | - blas=1.0=mkl 14 | - ca-certificates=2019.1.23=0 15 | - certifi=2019.3.9=py36_0 16 | - et_xmlfile=1.0.1=py36h3d2d736_0 17 | - freetype=2.9.1=ha9979f8_1 18 | - gast=0.2.2=py36_0 19 | - grpcio=1.16.1=py36h351948d_1 20 | - h5py=2.9.0=py36h5e291fa_0 21 | - hdf5=1.10.4=h7ebc959_0 22 | - icc_rt=2019.0.0=h0cc432a_1 23 | - intel-openmp=2019.1=144 24 | - jdcal=1.4=py36_0 25 | - jpeg=9b=hb83a4c4_2 26 | - keras-applications=1.0.6=py36_0 27 | - keras-preprocessing=1.0.5=py36_0 28 | - libmklml=2019.0.3=0 29 | - libpng=1.6.36=h2a8f88b_0 30 | - libprotobuf=3.6.1=h7bd577a_0 31 | - libtiff=4.0.10=hb898794_2 32 | - markdown=3.0.1=py36_0 33 | - mkl=2019.1=144 34 | - mkl_fft=1.0.10=py36h14836fe_0 35 | - mkl_random=1.0.2=py36h343c172_0 36 | - numpy=1.16.2=py36h19fb1c0_0 37 | - numpy-base=1.16.2=py36hc3f5095_0 38 | - olefile=0.46=py36_0 39 | - openpyxl=2.6.1=py36_1 40 | - openssl=1.1.1b=he774522_1 41 | - pandas=0.24.2=py36ha925a31_0 42 | - pillow=5.4.1=py36hdc69c19_0 43 | - pip=19.0.3=py36_0 44 | - protobuf=3.6.1=py36h33f27b4_0 45 | - pymongo=3.7.2=py36ha925a31_0 46 | - pyreadline=2.1=py36_1 47 | - python=3.6.8=h9f7ef89_7 48 | - python-dateutil=2.8.0=py36_0 49 | - pytz=2018.9=py36_0 50 | - pywin32=223=py36hfa6e2cd_1 51 | - pyyaml=5.1=py36he774522_0 52 | - scipy=1.2.1=py36h29ff71c_0 53 | - setuptools=40.8.0=py36_0 54 | - six=1.12.0=py36_0 55 | - sqlite=3.26.0=he774522_0 56 | - tensorboard=1.12.2=py36h33f27b4_0 57 | 
- tensorflow=1.12.0=mkl_py36h4f00353_0 58 | - tensorflow-base=1.12.0=mkl_py36h81393da_0 59 | - termcolor=1.1.0=py36_1 60 | - tk=8.6.8=hfa6e2cd_0 61 | - vc=14.1=h0510ff6_4 62 | - vs2015_runtime=14.15.26706=h3a45250_0 63 | - werkzeug=0.14.1=py36_0 64 | - wheel=0.33.1=py36_0 65 | - wincertstore=0.2=py36h7fe50ca_0 66 | - xz=5.2.4=h2fa13f4_4 67 | - yaml=0.1.7=hc54c509_2 68 | - zlib=1.2.11=h62dcd97_3 69 | - zstd=1.3.7=h508b16e_0 70 | prefix: C:\A\Anaconda\envs\mlcpu 71 | 72 | -------------------------------------------------------------------------------- /config_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | from enum import Enum 4 | 5 | # judge the current operating system 6 | base = r'C:' if platform.system() == "Windows" else r'/data/wjs' 7 | 8 | env_list = [ 9 | 'RollerBall', 10 | '3DBall', 11 | 'Boat' 12 | ] 13 | 14 | reset_config = [None, { 15 | 'copy': 10 16 | }] 17 | 18 | class algorithms(Enum): 19 | ppo_sep_ac = 1 # AC, stochastic 20 | ppo_com = 2 # AC, stochastic 21 | # boundary, about the way of calculate `discounted reward` 22 | sac = 3 # AC+Q, stochastic, off-policy 23 | sac_no_v = 4 24 | ddpg = 5 # AC+Q, deterministic, off-policy 25 | td3 = 6 # AC+Q, deterministic, off-policy 26 | 27 | 28 | unity_file = [ 29 | r'C:/UnityBuild/RollerBall/OneFloor/RollerBall-custom.exe',#0 30 | r'3dball',#1 31 | r'C:/UnityBuild/Boat/first/BoatTrain.exe',#2 32 | r'C:/UnityBuild/Boat/second/BoatTrain.exe',#3 33 | r'C:/UnityBuild/Boat/interval1/BoatTrain.exe',#4 34 | r'C:/UnityBuild/Boat/no_border/BoatTrain.exe',#5 35 | r'C:/UnityBuild/Boat/no_border2/BoatTrain.exe'#6 36 | ] 37 | 38 | max_episode = 50000 # max episode num or step num, depend on whether episode-update or step-update 39 | 40 | config = { 41 | 'hyper parameters': { 42 | # set the temperature of SAC, auto adjust or not 43 | 'alpha': 0.2, 44 | 'auto_adaption': True, 45 | 46 | 'ployak': 0.995, # range from 0. to 1. 47 | 'epsilon': 0.2, # control the learning stepsize of clip-ppo 48 | 'beta': 1.0e-3, # coefficient of entropy regularizatione 49 | 'lr': 5.0e-4, 50 | 'actor_lr': 0.0001, 51 | 'critic_lr': 0.0002, 52 | 'tp_lr': 0.001, 53 | 'reward_lr': 0.001, 54 | 'gamma': 0.99, 55 | 'lambda': 0.95, 56 | 'action_bound': 1, 57 | 'decay_rate': 0.7, 58 | 'decay_steps': 100, 59 | 'stair': False, 60 | 'max_episode': max_episode, 61 | 'base_sigma': 0.1, # only work on stochastic policy 62 | 'assign_interval': 4 # not use yet 63 | }, 64 | 65 | 'train config': { 66 | # choose algorithm 67 | 'algorithm': algorithms.sac, 68 | 'init_max_step': 300, 69 | 'max_step': 1000, # use for both on-policy and off-policy, control the max step within one episode. 
70 | 'max_episode': max_episode, 71 | 'max_sample_time': 20, 72 | 'till_all_done': True, # use for on-policy leanring 73 | 'start_continuous_done': False, 74 | # train mode, .exe or unity-client && train or inference 75 | 'train': True, 76 | 'unity_mode': False, 77 | 'unity_file': unity_file[0].replace('C:',f'{base}'), 78 | 'port': 5006, 79 | # trick 80 | 'use_trick': True, 81 | # excel 82 | 'excel_record': False, 83 | 'excel_record_frequency': 10, 84 | # mongodb 85 | 'mongo_record': False, 86 | 'mongo_record_frequency': 10, 87 | 'mongo_record_all': False, 88 | # shuffle batch or not 89 | 'random_batch': True, 90 | 'batchsize': 100, 91 | 'epoch': 1, 92 | # checkpoint 93 | 'save_frequency': 20, 94 | # set the agents' number and control mode 95 | 'dynamic_allocation': True, 96 | 'reset_config': reset_config[1], 97 | # some sets about using replay_buffer 98 | 'use_replay_buffer': True, # on-policy or off-policy 99 | 'use_priority' : False, 100 | 'buffer_size' : 10000, 101 | 'buffer_batch_size': 100, 102 | 'max_learn_time' : 20 103 | }, 104 | 105 | 'record config': { 106 | 'basic_dir': r'C:/RLData/'.replace('C:',f'{base}'), 107 | 'log_basic_dir': r'C:/RLData/logs/'.replace('C:',f'{base}'), 108 | 'excel_basic_dir': r'C:/RLData/excels/'.replace('C:',f'{base}'), 109 | 'checkpoint_basic_dir': r'C:/RLData/models/'.replace('C:',f'{base}'), 110 | 'config_basic_dir': r'C:/RLData/config/'.replace('C:',f'{base}'), 111 | 'project_name': env_list[0], 112 | 'remark': r'sac_onefloor', 113 | 'run_id': r'3', 114 | 'logger2file' : False 115 | }, 116 | 117 | 'config_file': r"", 118 | 'ps': r"time penalty=-0.01", 119 | 120 | 'clean_list': [ 121 | r'Boat\sac_off_no_border0', 122 | r'Boat\sac_off_no_border1', 123 | r'Boat\sac_off_no_border2', 124 | r'RollerBall\sac_onefloor0', 125 | r'RollerBall\sac_onefloor1', 126 | r'RollerBall\sac_onefloor2', 127 | ] 128 | } 129 | -------------------------------------------------------------------------------- /ddpg/ddpg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.contrib.layers as c_layers 4 | 5 | initKernelAndBias = { 6 | 'kernel_initializer': tf.random_normal_initializer(0., .1), 7 | 'bias_initializer': tf.constant_initializer(0.1, dtype=tf.float32) 8 | } 9 | # initKernelAndBias={ 10 | # 'kernel_initializer' : c_layers.variance_scaling_initializer(1.0) 11 | # } 12 | 13 | 14 | class DDPG(object): 15 | def __init__(self, sess, s_dim, a_counts, hyper_config): 16 | self.sess = sess 17 | self.s_dim = s_dim 18 | self.a_counts = a_counts 19 | self.activation_fn = tf.nn.tanh 20 | self.action_bound = hyper_config['action_bound'] 21 | self.assign_interval = hyper_config['assign_interval'] 22 | 23 | self.episode = tf.Variable(tf.constant(0)) 24 | self.lr = tf.train.polynomial_decay( 25 | hyper_config['lr'], self.episode, hyper_config['max_episode'], 1e-10, power=1.0) 26 | self.s = tf.placeholder(tf.float32, [None, self.s_dim], 'state') 27 | self.a = tf.placeholder(tf.float32, [None, self.a_counts], 'action') 28 | self.r = tf.placeholder(tf.float32, [None, 1], 'reward') 29 | self.s_ = tf.placeholder(tf.float32, [None, self.s_dim], 'next_state') 30 | 31 | self.mu, self.action, self.actor_var = self._build_actor_net( 32 | 'actor', self.s, True) 33 | self.target_mu, self.action_target, self.actor_target_var = self._build_actor_net( 34 | 'actor_target', self.s_, False) 35 | 36 | self.s_a = tf.concat((self.s, self.a), axis=1) 37 | self.s_mu = tf.concat((self.s, self.mu), axis=1) 
38 | self.s_a_target = tf.concat((self.s_, self.target_mu), axis=1) 39 | 40 | self.q, self.q_var = self._build_q_net( 41 | 'q', self.s_a, True, reuse=False) 42 | self.q_actor, _ = self._build_q_net('q', self.s_mu, True, reuse=True) 43 | self.q_target, self.q_target_var = self._build_q_net( 44 | 'q_target', self.s_a_target, False, reuse=False) 45 | self.dc_r = tf.stop_gradient( 46 | self.r + hyper_config['gamma'] * self.q_target) 47 | 48 | self.q_loss = 0.5 * tf.reduce_mean( 49 | tf.squared_difference(self.q, self.dc_r)) 50 | self.actor_loss = -tf.reduce_mean(self.q_actor) 51 | 52 | q_var = tf.get_collection( 53 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='q') 54 | actor_vars = tf.get_collection( 55 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor') 56 | 57 | optimizer = tf.train.AdamOptimizer(self.lr) 58 | self.train_q = optimizer.minimize(self.q_loss, var_list=q_var) 59 | with tf.control_dependencies([self.train_q]): 60 | self.train_actor = optimizer.minimize( 61 | self.actor_loss, var_list=actor_vars) 62 | with tf.control_dependencies([self.train_actor]): 63 | self.assign_q_target = tf.group([tf.assign( 64 | r, hyper_config['ployak'] * v + (1 - hyper_config['ployak']) * r) for r, v in zip(self.q_target_var, self.q_var)]) 65 | self.assign_actor_target = tf.group([tf.assign( 66 | r, hyper_config['ployak'] * v + (1 - hyper_config['ployak']) * r) for r, v in zip(self.actor_target_var, self.actor_var)]) 67 | # self.assign_q_target = [ 68 | # tf.assign(r, 1/(self.episode+1) * v + (1-1/(self.episode+1)) * r) for r, v in zip(self.q_target_var, self.q_var)] 69 | # self.assign_q_target = [ 70 | # tf.assign(r, 1/(self.episode+1) * v + (1-1/(self.episode+1)) * r) for r, v in zip(self.actor_target_var, self.actor_var)] 71 | 72 | def _build_actor_net(self, name, input_vector, trainable): 73 | with tf.variable_scope(name): 74 | actor1 = tf.layers.dense( 75 | inputs=input_vector, 76 | units=128, 77 | activation=self.activation_fn, 78 | name='actor1', 79 | trainable=trainable, 80 | **initKernelAndBias 81 | ) 82 | actor2 = tf.layers.dense( 83 | inputs=actor1, 84 | units=64, 85 | activation=self.activation_fn, 86 | name='actor2', 87 | trainable=trainable, 88 | **initKernelAndBias 89 | ) 90 | mu = tf.layers.dense( 91 | inputs=actor2, 92 | units=self.a_counts, 93 | activation=tf.nn.tanh, 94 | name='mu', 95 | trainable=trainable, 96 | **initKernelAndBias 97 | ) 98 | e = tf.random_normal(tf.shape(mu)) 99 | action = tf.clip_by_value( 100 | mu + e, -self.action_bound, self.action_bound) / self.action_bound 101 | var = tf.get_variable_scope().global_variables() 102 | return mu, action, var 103 | 104 | def _build_q_net(self, name, input_vector, trainable, reuse=False): 105 | with tf.variable_scope(name): 106 | layer1 = tf.layers.dense( 107 | inputs=input_vector, 108 | units=256, 109 | activation=self.activation_fn, 110 | name='layer1', 111 | trainable=trainable, 112 | reuse=reuse, 113 | **initKernelAndBias 114 | ) 115 | layer2 = tf.layers.dense( 116 | inputs=layer1, 117 | units=256, 118 | activation=self.activation_fn, 119 | name='layer2', 120 | trainable=trainable, 121 | reuse=reuse, 122 | **initKernelAndBias 123 | ) 124 | q = tf.layers.dense( 125 | inputs=layer2, 126 | units=1, 127 | activation=None, 128 | name='q_value', 129 | trainable=trainable, 130 | reuse=reuse, 131 | **initKernelAndBias 132 | ) 133 | var = tf.get_variable_scope().global_variables() 134 | return q, var 135 | 136 | def decay_lr(self, episode, **kargs): 137 | return self.sess.run(self.lr, feed_dict={ 138 | self.episode: episode 139 | }) 140 | 
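    # Orientation for the methods below: `learn()` runs `train_q` (one critic
    # update) followed by `train_actor` (one actor update), and then the
    # `assign_q_target` / `assign_actor_target` ops, which blend the online
    # network weights into the target-network copies using the `ployak`
    # coefficient from `hyper_config` (a soft target-network update), all in a
    # single `sess.run` call. `choose_action` evaluates the noisy, clipped
    # `action` op for exploration, while `choose_inference_action` evaluates
    # the deterministic `mu`.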
141 | def choose_action(self, s, **kargs): 142 | return np.ones((s.shape[0], self.a_counts)), self.sess.run(self.action, feed_dict={ 143 | self.s: s 144 | }) 145 | 146 | def choose_inference_action(self, s, **kargs): 147 | return np.ones((s.shape[0], self.a_counts)), self.sess.run(self.mu, feed_dict={ 148 | self.s: s 149 | }) 150 | 151 | def get_state_value(self, s, **kargs): 152 | return np.squeeze(np.zeros(np.array(s).shape[0])) 153 | 154 | def learn(self, s, a, r, s_, episode, **kargs): 155 | self.sess.run([self.train_q, self.train_actor, self.assign_q_target, self.assign_actor_target], feed_dict={ 156 | self.s: s, 157 | self.a: a, 158 | self.r: r, 159 | self.s_: s_, 160 | self.episode: episode 161 | }) 162 | 163 | def get_actor_loss(self, s, **kargs): 164 | return self.sess.run(self.actor_loss, feed_dict={ 165 | self.s: s 166 | }) 167 | 168 | def get_critic_loss(self, s, a, r, s_, **kargs): 169 | return self.sess.run(self.q_loss, feed_dict={ 170 | self.s: s, 171 | self.a: a, 172 | self.r: r, 173 | self.s_: s_, 174 | }) 175 | 176 | def get_entropy(self, s, **kargs): 177 | return np.zeros((np.array(s).shape[0], self.a_counts)) 178 | 179 | def get_sigma(self, s, **kargs): 180 | return np.zeros((np.array(s).shape[0], self.a_counts)) 181 | -------------------------------------------------------------------------------- /mlagents/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .brain import * 2 | from .environment import * 3 | from .exception import * 4 | -------------------------------------------------------------------------------- /mlagents/envs/base_unity_environment.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict 3 | 4 | from mlagents.envs import AllBrainInfo, BrainParameters 5 | 6 | 7 | class BaseUnityEnvironment(ABC): 8 | @abstractmethod 9 | def step(self, vector_action=None, memory=None, text_action=None, value=None) -> AllBrainInfo: 10 | pass 11 | 12 | @abstractmethod 13 | def reset(self, config=None, train_mode=True) -> AllBrainInfo: 14 | pass 15 | 16 | @property 17 | @abstractmethod 18 | def global_done(self): 19 | pass 20 | 21 | @property 22 | @abstractmethod 23 | def external_brains(self) -> Dict[str, BrainParameters]: 24 | pass 25 | 26 | @property 27 | @abstractmethod 28 | def reset_parameters(self) -> Dict[str, str]: 29 | pass 30 | 31 | @abstractmethod 32 | def close(self): 33 | pass 34 | -------------------------------------------------------------------------------- /mlagents/envs/brain.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import io 4 | 5 | from typing import Dict, List, Optional 6 | from PIL import Image 7 | 8 | logger = logging.getLogger("mlagents.envs") 9 | 10 | 11 | class BrainInfo: 12 | def __init__(self, visual_observation, vector_observation, text_observations, memory=None, 13 | reward=None, agents=None, local_done=None, 14 | vector_action=None, text_action=None, max_reached=None, action_mask=None, 15 | custom_observations=None): 16 | """ 17 | Describes experience at current step of all agents linked to a brain. 
18 | """ 19 | self.visual_observations = visual_observation 20 | self.vector_observations = vector_observation 21 | self.text_observations = text_observations 22 | self.memories = memory 23 | self.rewards = reward 24 | self.local_done = local_done 25 | self.max_reached = max_reached 26 | self.agents = agents 27 | self.previous_vector_actions = vector_action 28 | self.previous_text_actions = text_action 29 | self.action_masks = action_mask 30 | self.custom_observations = custom_observations 31 | 32 | def merge(self, other): 33 | for i in range(len(self.visual_observations)): 34 | self.visual_observations[i].extend(other.visual_observations[i]) 35 | self.vector_observations = np.append(self.vector_observations, other.vector_observations, axis=0) 36 | self.text_observations.extend(other.text_observations) 37 | self.memories = self.merge_memories(self.memories, other.memories, self.agents, other.agents) 38 | self.rewards = safe_concat_lists(self.rewards, other.rewards) 39 | self.local_done = safe_concat_lists(self.local_done, other.local_done) 40 | self.max_reached = safe_concat_lists(self.max_reached, other.max_reached) 41 | self.agents = safe_concat_lists(self.agents, other.agents) 42 | self.previous_vector_actions = safe_concat_np_ndarray( 43 | self.previous_vector_actions, other.previous_vector_actions 44 | ) 45 | self.previous_text_actions = safe_concat_lists( 46 | self.previous_text_actions, other.previous_text_actions 47 | ) 48 | self.action_masks = safe_concat_np_ndarray(self.action_masks, other.action_masks) 49 | self.custom_observations = safe_concat_lists(self.custom_observations, other.custom_observations) 50 | 51 | @staticmethod 52 | def merge_memories(m1, m2, agents1, agents2): 53 | if len(m1) == 0 and len(m2) != 0: 54 | m1 = np.zeros((len(agents1), m2.shape[1])) 55 | elif len(m2) == 0 and len(m1) != 0: 56 | m2 = np.zeros((len(agents2), m1.shape[1])) 57 | elif m2.shape[1] > m1.shape[1]: 58 | new_m1 = np.zeros((m1.shape[0], m2.shape[1])) 59 | new_m1[0:m1.shape[0], 0:m1.shape[1]] = m1 60 | return np.append(new_m1, m2, axis=0) 61 | elif m1.shape[1] > m2.shape[1]: 62 | new_m2 = np.zeros((m2.shape[0], m1.shape[1])) 63 | new_m2[0:m2.shape[0], 0:m2.shape[1]] = m2 64 | return np.append(m1, new_m2, axis=0) 65 | return np.append(m1, m2, axis=0) 66 | 67 | @staticmethod 68 | def process_pixels(image_bytes, gray_scale): 69 | """ 70 | Converts byte array observation image into numpy array, re-sizes it, 71 | and optionally converts it to grey scale 72 | :param gray_scale: Whether to convert the image to grayscale. 73 | :param image_bytes: input byte array corresponding to image 74 | :return: processed numpy array of observation from environment 75 | """ 76 | s = bytearray(image_bytes) 77 | image = Image.open(io.BytesIO(s)) 78 | s = np.array(image) / 255.0 79 | if gray_scale: 80 | s = np.mean(s, axis=2) 81 | s = np.reshape(s, [s.shape[0], s.shape[1], 1]) 82 | return s 83 | 84 | @staticmethod 85 | def from_agent_proto(agent_info_list, brain_params): 86 | """ 87 | Converts list of agent infos to BrainInfo. 
88 | """ 89 | vis_obs = [] 90 | for i in range(brain_params.number_visual_observations): 91 | obs = [BrainInfo.process_pixels(x.visual_observations[i], 92 | brain_params.camera_resolutions[i]['blackAndWhite']) 93 | for x in agent_info_list] 94 | vis_obs += [obs] 95 | if len(agent_info_list) == 0: 96 | memory_size = 0 97 | else: 98 | memory_size = max([len(x.memories) for x in agent_info_list]) 99 | if memory_size == 0: 100 | memory = np.zeros((0, 0)) 101 | else: 102 | [x.memories.extend([0] * (memory_size - len(x.memories))) for x in agent_info_list] 103 | memory = np.array([list(x.memories) for x in agent_info_list]) 104 | total_num_actions = sum(brain_params.vector_action_space_size) 105 | mask_actions = np.ones((len(agent_info_list), total_num_actions)) 106 | for agent_index, agent_info in enumerate(agent_info_list): 107 | if agent_info.action_mask is not None: 108 | if len(agent_info.action_mask) == total_num_actions: 109 | mask_actions[agent_index, :] = [ 110 | 0 if agent_info.action_mask[k] else 1 for k in range(total_num_actions)] 111 | if any([np.isnan(x.reward) for x in agent_info_list]): 112 | logger.warning("An agent had a NaN reward for brain " + brain_params.brain_name) 113 | if any([np.isnan(x.stacked_vector_observation).any() for x in agent_info_list]): 114 | logger.warning("An agent had a NaN observation for brain " + brain_params.brain_name) 115 | 116 | if len(agent_info_list) == 0: 117 | vector_obs = np.zeros( 118 | (0, brain_params.vector_observation_space_size * brain_params.num_stacked_vector_observations) 119 | ) 120 | else: 121 | vector_obs = np.nan_to_num( 122 | np.array([x.stacked_vector_observation for x in agent_info_list]) 123 | ) 124 | brain_info = BrainInfo( 125 | visual_observation=vis_obs, 126 | vector_observation=vector_obs, 127 | text_observations=[x.text_observation for x in agent_info_list], 128 | memory=memory, 129 | reward=[x.reward if not np.isnan(x.reward) else 0 for x in agent_info_list], 130 | agents=[x.id for x in agent_info_list], 131 | local_done=[x.done for x in agent_info_list], 132 | vector_action=np.array([x.stored_vector_actions for x in agent_info_list]), 133 | text_action=[list(x.stored_text_actions) for x in agent_info_list], 134 | max_reached=[x.max_step_reached for x in agent_info_list], 135 | custom_observations=[x.custom_observation for x in agent_info_list], 136 | action_mask=mask_actions 137 | ) 138 | return brain_info 139 | 140 | 141 | def safe_concat_lists(l1: Optional[List], l2: Optional[List]): 142 | if l1 is None and l2 is None: 143 | return None 144 | if l1 is None and l2 is not None: 145 | return l2.copy() 146 | if l1 is not None and l2 is None: 147 | return l1.copy() 148 | else: 149 | copy = l1.copy() 150 | copy.extend(l2) 151 | return copy 152 | 153 | 154 | def safe_concat_np_ndarray(a1: Optional[np.ndarray], a2: Optional[np.ndarray]): 155 | if a1 is not None and a1.size != 0: 156 | if a2 is not None and a2.size != 0: 157 | return np.append(a1, a2, axis=0) 158 | else: 159 | return a1.copy() 160 | elif a2 is not None and a2.size != 0: 161 | return a2.copy() 162 | return None 163 | 164 | 165 | # Renaming of dictionary of brain name to BrainInfo for clarity 166 | AllBrainInfo = Dict[str, BrainInfo] 167 | 168 | 169 | class BrainParameters: 170 | def __init__(self, 171 | brain_name: str, 172 | vector_observation_space_size: int, 173 | num_stacked_vector_observations: int, 174 | camera_resolutions: List[Dict], 175 | vector_action_space_size: List[int], 176 | vector_action_descriptions: List[str], 177 | vector_action_space_type: 
int): 178 | """ 179 | Contains all brain-specific parameters. 180 | """ 181 | self.brain_name = brain_name 182 | self.vector_observation_space_size = vector_observation_space_size 183 | self.num_stacked_vector_observations = num_stacked_vector_observations 184 | self.number_visual_observations = len(camera_resolutions) 185 | self.camera_resolutions = camera_resolutions 186 | self.vector_action_space_size = vector_action_space_size 187 | self.vector_action_descriptions = vector_action_descriptions 188 | self.vector_action_space_type = ["discrete", "continuous"][vector_action_space_type] 189 | 190 | def __str__(self): 191 | return '''Unity brain name: {} 192 | Number of Visual Observations (per agent): {} 193 | Vector Observation space size (per agent): {} 194 | Number of stacked Vector Observation: {} 195 | Vector Action space type: {} 196 | Vector Action space size (per agent): {} 197 | Vector Action descriptions: {}'''.format(self.brain_name, 198 | str(self.number_visual_observations), 199 | str(self.vector_observation_space_size), 200 | str(self.num_stacked_vector_observations), 201 | self.vector_action_space_type, 202 | str(self.vector_action_space_size), 203 | ', '.join(self.vector_action_descriptions)) 204 | 205 | @staticmethod 206 | def from_proto(brain_param_proto): 207 | """ 208 | Converts brain parameter proto to BrainParameter object. 209 | :param brain_param_proto: protobuf object. 210 | :return: BrainParameter object. 211 | """ 212 | resolution = [{ 213 | "height": x.height, 214 | "width": x.width, 215 | "blackAndWhite": x.gray_scale 216 | } for x in brain_param_proto.camera_resolutions] 217 | brain_params = BrainParameters(brain_param_proto.brain_name, 218 | brain_param_proto.vector_observation_size, 219 | brain_param_proto.num_stacked_vector_observations, 220 | resolution, 221 | list(brain_param_proto.vector_action_size), 222 | list(brain_param_proto.vector_action_descriptions), 223 | brain_param_proto.vector_action_space_type) 224 | return brain_params -------------------------------------------------------------------------------- /mlagents/envs/communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .communicator_objects import UnityOutput, UnityInput 4 | 5 | logger = logging.getLogger("mlagents.envs") 6 | 7 | 8 | class Communicator(object): 9 | def __init__(self, worker_id=0, base_port=5005): 10 | """ 11 | Python side of the communication. Must be used in pair with the right Unity Communicator equivalent. 12 | 13 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 14 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 15 | """ 16 | 17 | def initialize(self, inputs: UnityInput) -> UnityOutput: 18 | """ 19 | Used to exchange initialization parameters between Python and the Environment 20 | :param inputs: The initialization input that will be sent to the environment. 21 | :return: UnityOutput: The initialization output sent by Unity 22 | """ 23 | 24 | def exchange(self, inputs: UnityInput) -> UnityOutput: 25 | """ 26 | Used to send an input and receive an output from the Environment 27 | :param inputs: The UnityInput that needs to be sent the Environment 28 | :return: The UnityOutputs generated by the Environment 29 | """ 30 | 31 | def close(self): 32 | """ 33 | Sends a shutdown signal to the unity environment, and closes the connection. 
34 | """ 35 | 36 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent_action_proto_pb2 import * 2 | from .agent_info_proto_pb2 import * 3 | from .brain_parameters_proto_pb2 import * 4 | from .command_proto_pb2 import * 5 | from .custom_action_pb2 import * 6 | from .custom_observation_pb2 import * 7 | from .custom_reset_parameters_pb2 import * 8 | from .demonstration_meta_proto_pb2 import * 9 | from .engine_configuration_proto_pb2 import * 10 | from .environment_parameters_proto_pb2 import * 11 | from .header_pb2 import * 12 | from .resolution_proto_pb2 import * 13 | from .space_type_proto_pb2 import * 14 | from .unity_input_pb2 import * 15 | from .unity_message_pb2 import * 16 | from .unity_output_pb2 import * 17 | from .unity_rl_initialization_input_pb2 import * 18 | from .unity_rl_initialization_output_pb2 import * 19 | from .unity_rl_input_pb2 import * 20 | from .unity_rl_output_pb2 import * 21 | from .unity_to_external_pb2 import * 22 | from .unity_to_external_pb2_grpc import * 23 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/agent_action_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/agent_action_proto.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import custom_action_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='mlagents/envs/communicator_objects/agent_action_proto.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 24 | serialized_pb=_b('\n;mlagents/envs/communicator_objects/agent_action_proto.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents/envs/communicator_objects/custom_action.proto\"\x9c\x01\n\x10\x41gentActionProto\x12\x16\n\x0evector_actions\x18\x01 \x03(\x02\x12\x14\n\x0ctext_actions\x18\x02 \x01(\t\x12\x10\n\x08memories\x18\x03 \x03(\x02\x12\r\n\x05value\x18\x04 \x01(\x02\x12\x39\n\rcustom_action\x18\x05 \x01(\x0b\x32\".communicator_objects.CustomActionB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _AGENTACTIONPROTO = _descriptor.Descriptor( 32 | name='AgentActionProto', 33 | full_name='communicator_objects.AgentActionProto', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='vector_actions', full_name='communicator_objects.AgentActionProto.vector_actions', index=0, 40 | number=1, type=2, cpp_type=6, label=3, 41 | has_default_value=False, default_value=[], 42 | message_type=None, enum_type=None, 
containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | serialized_options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='text_actions', full_name='communicator_objects.AgentActionProto.text_actions', index=1, 47 | number=2, type=9, cpp_type=9, label=1, 48 | has_default_value=False, default_value=_b("").decode('utf-8'), 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | serialized_options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='memories', full_name='communicator_objects.AgentActionProto.memories', index=2, 54 | number=3, type=2, cpp_type=6, label=3, 55 | has_default_value=False, default_value=[], 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | serialized_options=None, file=DESCRIPTOR), 59 | _descriptor.FieldDescriptor( 60 | name='value', full_name='communicator_objects.AgentActionProto.value', index=3, 61 | number=4, type=2, cpp_type=6, label=1, 62 | has_default_value=False, default_value=float(0), 63 | message_type=None, enum_type=None, containing_type=None, 64 | is_extension=False, extension_scope=None, 65 | serialized_options=None, file=DESCRIPTOR), 66 | _descriptor.FieldDescriptor( 67 | name='custom_action', full_name='communicator_objects.AgentActionProto.custom_action', index=4, 68 | number=5, type=11, cpp_type=10, label=1, 69 | has_default_value=False, default_value=None, 70 | message_type=None, enum_type=None, containing_type=None, 71 | is_extension=False, extension_scope=None, 72 | serialized_options=None, file=DESCRIPTOR), 73 | ], 74 | extensions=[ 75 | ], 76 | nested_types=[], 77 | enum_types=[ 78 | ], 79 | serialized_options=None, 80 | is_extendable=False, 81 | syntax='proto3', 82 | extension_ranges=[], 83 | oneofs=[ 84 | ], 85 | serialized_start=142, 86 | serialized_end=298, 87 | ) 88 | 89 | _AGENTACTIONPROTO.fields_by_name['custom_action'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_custom__action__pb2._CUSTOMACTION 90 | DESCRIPTOR.message_types_by_name['AgentActionProto'] = _AGENTACTIONPROTO 91 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 92 | 93 | AgentActionProto = _reflection.GeneratedProtocolMessageType('AgentActionProto', (_message.Message,), dict( 94 | DESCRIPTOR = _AGENTACTIONPROTO, 95 | __module__ = 'mlagents.envs.communicator_objects.agent_action_proto_pb2' 96 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentActionProto) 97 | )) 98 | _sym_db.RegisterMessage(AgentActionProto) 99 | 100 | 101 | DESCRIPTOR._options = None 102 | # @@protoc_insertion_point(module_scope) 103 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/agent_info_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/agent_info_proto.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import custom_observation_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='mlagents/envs/communicator_objects/agent_info_proto.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 24 | serialized_pb=_b('\n9mlagents/envs/communicator_objects/agent_info_proto.proto\x12\x14\x63ommunicator_objects\x1a;mlagents/envs/communicator_objects/custom_observation.proto\"\xd7\x02\n\x0e\x41gentInfoProto\x12\"\n\x1astacked_vector_observation\x18\x01 \x03(\x02\x12\x1b\n\x13visual_observations\x18\x02 \x03(\x0c\x12\x18\n\x10text_observation\x18\x03 \x01(\t\x12\x1d\n\x15stored_vector_actions\x18\x04 \x03(\x02\x12\x1b\n\x13stored_text_actions\x18\x05 \x01(\t\x12\x10\n\x08memories\x18\x06 \x03(\x02\x12\x0e\n\x06reward\x18\x07 \x01(\x02\x12\x0c\n\x04\x64one\x18\x08 \x01(\x08\x12\x18\n\x10max_step_reached\x18\t \x01(\x08\x12\n\n\x02id\x18\n \x01(\x05\x12\x13\n\x0b\x61\x63tion_mask\x18\x0b \x03(\x08\x12\x43\n\x12\x63ustom_observation\x18\x0c \x01(\x0b\x32\'.communicator_objects.CustomObservationB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _AGENTINFOPROTO = _descriptor.Descriptor( 32 | name='AgentInfoProto', 33 | full_name='communicator_objects.AgentInfoProto', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='stacked_vector_observation', full_name='communicator_objects.AgentInfoProto.stacked_vector_observation', index=0, 40 | number=1, type=2, cpp_type=6, label=3, 41 | has_default_value=False, default_value=[], 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | serialized_options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='visual_observations', full_name='communicator_objects.AgentInfoProto.visual_observations', index=1, 47 | number=2, type=12, cpp_type=9, label=3, 48 | has_default_value=False, default_value=[], 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | serialized_options=None, file=DESCRIPTOR), 52 | _descriptor.FieldDescriptor( 53 | name='text_observation', full_name='communicator_objects.AgentInfoProto.text_observation', index=2, 54 | number=3, type=9, cpp_type=9, label=1, 55 | has_default_value=False, default_value=_b("").decode('utf-8'), 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | serialized_options=None, file=DESCRIPTOR), 59 | _descriptor.FieldDescriptor( 60 | name='stored_vector_actions', full_name='communicator_objects.AgentInfoProto.stored_vector_actions', index=3, 61 | number=4, type=2, cpp_type=6, label=3, 62 | 
has_default_value=False, default_value=[], 63 | message_type=None, enum_type=None, containing_type=None, 64 | is_extension=False, extension_scope=None, 65 | serialized_options=None, file=DESCRIPTOR), 66 | _descriptor.FieldDescriptor( 67 | name='stored_text_actions', full_name='communicator_objects.AgentInfoProto.stored_text_actions', index=4, 68 | number=5, type=9, cpp_type=9, label=1, 69 | has_default_value=False, default_value=_b("").decode('utf-8'), 70 | message_type=None, enum_type=None, containing_type=None, 71 | is_extension=False, extension_scope=None, 72 | serialized_options=None, file=DESCRIPTOR), 73 | _descriptor.FieldDescriptor( 74 | name='memories', full_name='communicator_objects.AgentInfoProto.memories', index=5, 75 | number=6, type=2, cpp_type=6, label=3, 76 | has_default_value=False, default_value=[], 77 | message_type=None, enum_type=None, containing_type=None, 78 | is_extension=False, extension_scope=None, 79 | serialized_options=None, file=DESCRIPTOR), 80 | _descriptor.FieldDescriptor( 81 | name='reward', full_name='communicator_objects.AgentInfoProto.reward', index=6, 82 | number=7, type=2, cpp_type=6, label=1, 83 | has_default_value=False, default_value=float(0), 84 | message_type=None, enum_type=None, containing_type=None, 85 | is_extension=False, extension_scope=None, 86 | serialized_options=None, file=DESCRIPTOR), 87 | _descriptor.FieldDescriptor( 88 | name='done', full_name='communicator_objects.AgentInfoProto.done', index=7, 89 | number=8, type=8, cpp_type=7, label=1, 90 | has_default_value=False, default_value=False, 91 | message_type=None, enum_type=None, containing_type=None, 92 | is_extension=False, extension_scope=None, 93 | serialized_options=None, file=DESCRIPTOR), 94 | _descriptor.FieldDescriptor( 95 | name='max_step_reached', full_name='communicator_objects.AgentInfoProto.max_step_reached', index=8, 96 | number=9, type=8, cpp_type=7, label=1, 97 | has_default_value=False, default_value=False, 98 | message_type=None, enum_type=None, containing_type=None, 99 | is_extension=False, extension_scope=None, 100 | serialized_options=None, file=DESCRIPTOR), 101 | _descriptor.FieldDescriptor( 102 | name='id', full_name='communicator_objects.AgentInfoProto.id', index=9, 103 | number=10, type=5, cpp_type=1, label=1, 104 | has_default_value=False, default_value=0, 105 | message_type=None, enum_type=None, containing_type=None, 106 | is_extension=False, extension_scope=None, 107 | serialized_options=None, file=DESCRIPTOR), 108 | _descriptor.FieldDescriptor( 109 | name='action_mask', full_name='communicator_objects.AgentInfoProto.action_mask', index=10, 110 | number=11, type=8, cpp_type=7, label=3, 111 | has_default_value=False, default_value=[], 112 | message_type=None, enum_type=None, containing_type=None, 113 | is_extension=False, extension_scope=None, 114 | serialized_options=None, file=DESCRIPTOR), 115 | _descriptor.FieldDescriptor( 116 | name='custom_observation', full_name='communicator_objects.AgentInfoProto.custom_observation', index=11, 117 | number=12, type=11, cpp_type=10, label=1, 118 | has_default_value=False, default_value=None, 119 | message_type=None, enum_type=None, containing_type=None, 120 | is_extension=False, extension_scope=None, 121 | serialized_options=None, file=DESCRIPTOR), 122 | ], 123 | extensions=[ 124 | ], 125 | nested_types=[], 126 | enum_types=[ 127 | ], 128 | serialized_options=None, 129 | is_extendable=False, 130 | syntax='proto3', 131 | extension_ranges=[], 132 | oneofs=[ 133 | ], 134 | serialized_start=145, 135 | serialized_end=488, 
136 | ) 137 | 138 | _AGENTINFOPROTO.fields_by_name['custom_observation'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_custom__observation__pb2._CUSTOMOBSERVATION 139 | DESCRIPTOR.message_types_by_name['AgentInfoProto'] = _AGENTINFOPROTO 140 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 141 | 142 | AgentInfoProto = _reflection.GeneratedProtocolMessageType('AgentInfoProto', (_message.Message,), dict( 143 | DESCRIPTOR = _AGENTINFOPROTO, 144 | __module__ = 'mlagents.envs.communicator_objects.agent_info_proto_pb2' 145 | # @@protoc_insertion_point(class_scope:communicator_objects.AgentInfoProto) 146 | )) 147 | _sym_db.RegisterMessage(AgentInfoProto) 148 | 149 | 150 | DESCRIPTOR._options = None 151 | # @@protoc_insertion_point(module_scope) 152 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/brain_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/brain_parameters_proto.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import resolution_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2 17 | from mlagents.envs.communicator_objects import space_type_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_space__type__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='mlagents/envs/communicator_objects/brain_parameters_proto.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 25 | serialized_pb=_b('\n?mlagents/envs/communicator_objects/brain_parameters_proto.proto\x12\x14\x63ommunicator_objects\x1a\x39mlagents/envs/communicator_objects/resolution_proto.proto\x1a\x39mlagents/envs/communicator_objects/space_type_proto.proto\"\xd4\x02\n\x14\x42rainParametersProto\x12\x1f\n\x17vector_observation_size\x18\x01 \x01(\x05\x12\'\n\x1fnum_stacked_vector_observations\x18\x02 \x01(\x05\x12\x1a\n\x12vector_action_size\x18\x03 \x03(\x05\x12\x41\n\x12\x63\x61mera_resolutions\x18\x04 \x03(\x0b\x32%.communicator_objects.ResolutionProto\x12\"\n\x1avector_action_descriptions\x18\x05 \x03(\t\x12\x46\n\x18vector_action_space_type\x18\x06 \x01(\x0e\x32$.communicator_objects.SpaceTypeProto\x12\x12\n\nbrain_name\x18\x07 \x01(\t\x12\x13\n\x0bis_training\x18\x08 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_space__type__proto__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _BRAINPARAMETERSPROTO = _descriptor.Descriptor( 33 | name='BrainParametersProto', 34 | full_name='communicator_objects.BrainParametersProto', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='vector_observation_size', 
full_name='communicator_objects.BrainParametersProto.vector_observation_size', index=0, 41 | number=1, type=5, cpp_type=1, label=1, 42 | has_default_value=False, default_value=0, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | serialized_options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='num_stacked_vector_observations', full_name='communicator_objects.BrainParametersProto.num_stacked_vector_observations', index=1, 48 | number=2, type=5, cpp_type=1, label=1, 49 | has_default_value=False, default_value=0, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | serialized_options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='vector_action_size', full_name='communicator_objects.BrainParametersProto.vector_action_size', index=2, 55 | number=3, type=5, cpp_type=1, label=3, 56 | has_default_value=False, default_value=[], 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | serialized_options=None, file=DESCRIPTOR), 60 | _descriptor.FieldDescriptor( 61 | name='camera_resolutions', full_name='communicator_objects.BrainParametersProto.camera_resolutions', index=3, 62 | number=4, type=11, cpp_type=10, label=3, 63 | has_default_value=False, default_value=[], 64 | message_type=None, enum_type=None, containing_type=None, 65 | is_extension=False, extension_scope=None, 66 | serialized_options=None, file=DESCRIPTOR), 67 | _descriptor.FieldDescriptor( 68 | name='vector_action_descriptions', full_name='communicator_objects.BrainParametersProto.vector_action_descriptions', index=4, 69 | number=5, type=9, cpp_type=9, label=3, 70 | has_default_value=False, default_value=[], 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | serialized_options=None, file=DESCRIPTOR), 74 | _descriptor.FieldDescriptor( 75 | name='vector_action_space_type', full_name='communicator_objects.BrainParametersProto.vector_action_space_type', index=5, 76 | number=6, type=14, cpp_type=8, label=1, 77 | has_default_value=False, default_value=0, 78 | message_type=None, enum_type=None, containing_type=None, 79 | is_extension=False, extension_scope=None, 80 | serialized_options=None, file=DESCRIPTOR), 81 | _descriptor.FieldDescriptor( 82 | name='brain_name', full_name='communicator_objects.BrainParametersProto.brain_name', index=6, 83 | number=7, type=9, cpp_type=9, label=1, 84 | has_default_value=False, default_value=_b("").decode('utf-8'), 85 | message_type=None, enum_type=None, containing_type=None, 86 | is_extension=False, extension_scope=None, 87 | serialized_options=None, file=DESCRIPTOR), 88 | _descriptor.FieldDescriptor( 89 | name='is_training', full_name='communicator_objects.BrainParametersProto.is_training', index=7, 90 | number=8, type=8, cpp_type=7, label=1, 91 | has_default_value=False, default_value=False, 92 | message_type=None, enum_type=None, containing_type=None, 93 | is_extension=False, extension_scope=None, 94 | serialized_options=None, file=DESCRIPTOR), 95 | ], 96 | extensions=[ 97 | ], 98 | nested_types=[], 99 | enum_types=[ 100 | ], 101 | serialized_options=None, 102 | is_extendable=False, 103 | syntax='proto3', 104 | extension_ranges=[], 105 | oneofs=[ 106 | ], 107 | serialized_start=208, 108 | serialized_end=548, 109 | ) 110 | 111 | _BRAINPARAMETERSPROTO.fields_by_name['camera_resolutions'].message_type = 
mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2._RESOLUTIONPROTO 112 | _BRAINPARAMETERSPROTO.fields_by_name['vector_action_space_type'].enum_type = mlagents_dot_envs_dot_communicator__objects_dot_space__type__proto__pb2._SPACETYPEPROTO 113 | DESCRIPTOR.message_types_by_name['BrainParametersProto'] = _BRAINPARAMETERSPROTO 114 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 115 | 116 | BrainParametersProto = _reflection.GeneratedProtocolMessageType('BrainParametersProto', (_message.Message,), dict( 117 | DESCRIPTOR = _BRAINPARAMETERSPROTO, 118 | __module__ = 'mlagents.envs.communicator_objects.brain_parameters_proto_pb2' 119 | # @@protoc_insertion_point(class_scope:communicator_objects.BrainParametersProto) 120 | )) 121 | _sym_db.RegisterMessage(BrainParametersProto) 122 | 123 | 124 | DESCRIPTOR._options = None 125 | # @@protoc_insertion_point(module_scope) 126 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/command_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/command_proto.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf.internal import enum_type_wrapper 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='mlagents/envs/communicator_objects/command_proto.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 24 | serialized_pb=_b('\n6mlagents/envs/communicator_objects/command_proto.proto\x12\x14\x63ommunicator_objects*-\n\x0c\x43ommandProto\x12\x08\n\x04STEP\x10\x00\x12\t\n\x05RESET\x10\x01\x12\x08\n\x04QUIT\x10\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | ) 26 | 27 | _COMMANDPROTO = _descriptor.EnumDescriptor( 28 | name='CommandProto', 29 | full_name='communicator_objects.CommandProto', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | values=[ 33 | _descriptor.EnumValueDescriptor( 34 | name='STEP', index=0, number=0, 35 | serialized_options=None, 36 | type=None), 37 | _descriptor.EnumValueDescriptor( 38 | name='RESET', index=1, number=1, 39 | serialized_options=None, 40 | type=None), 41 | _descriptor.EnumValueDescriptor( 42 | name='QUIT', index=2, number=2, 43 | serialized_options=None, 44 | type=None), 45 | ], 46 | containing_type=None, 47 | serialized_options=None, 48 | serialized_start=80, 49 | serialized_end=125, 50 | ) 51 | _sym_db.RegisterEnumDescriptor(_COMMANDPROTO) 52 | 53 | CommandProto = enum_type_wrapper.EnumTypeWrapper(_COMMANDPROTO) 54 | STEP = 0 55 | RESET = 1 56 | QUIT = 2 57 | 58 | 59 | DESCRIPTOR.enum_types_by_name['CommandProto'] = _COMMANDPROTO 60 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 61 | 62 | 63 | DESCRIPTOR._options = None 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/custom_action_pb2.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/custom_action.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/custom_action.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 23 | serialized_pb=_b('\n6mlagents/envs/communicator_objects/custom_action.proto\x12\x14\x63ommunicator_objects\"\x0e\n\x0c\x43ustomActionB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | 27 | 28 | 29 | _CUSTOMACTION = _descriptor.Descriptor( 30 | name='CustomAction', 31 | full_name='communicator_objects.CustomAction', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | ], 37 | extensions=[ 38 | ], 39 | nested_types=[], 40 | enum_types=[ 41 | ], 42 | serialized_options=None, 43 | is_extendable=False, 44 | syntax='proto3', 45 | extension_ranges=[], 46 | oneofs=[ 47 | ], 48 | serialized_start=80, 49 | serialized_end=94, 50 | ) 51 | 52 | DESCRIPTOR.message_types_by_name['CustomAction'] = _CUSTOMACTION 53 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 54 | 55 | CustomAction = _reflection.GeneratedProtocolMessageType('CustomAction', (_message.Message,), dict( 56 | DESCRIPTOR = _CUSTOMACTION, 57 | __module__ = 'mlagents.envs.communicator_objects.custom_action_pb2' 58 | # @@protoc_insertion_point(class_scope:communicator_objects.CustomAction) 59 | )) 60 | _sym_db.RegisterMessage(CustomAction) 61 | 62 | 63 | DESCRIPTOR._options = None 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/custom_observation_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/custom_observation.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/custom_observation.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 23 | serialized_pb=_b('\n;mlagents/envs/communicator_objects/custom_observation.proto\x12\x14\x63ommunicator_objects\"\x13\n\x11\x43ustomObservationB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | 27 | 28 | 29 | _CUSTOMOBSERVATION = _descriptor.Descriptor( 30 | name='CustomObservation', 31 | full_name='communicator_objects.CustomObservation', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | ], 37 | extensions=[ 38 | ], 39 | nested_types=[], 40 | enum_types=[ 41 | ], 42 | serialized_options=None, 43 | is_extendable=False, 44 | syntax='proto3', 45 | extension_ranges=[], 46 | oneofs=[ 47 | ], 48 | serialized_start=85, 49 | serialized_end=104, 50 | ) 51 | 52 | DESCRIPTOR.message_types_by_name['CustomObservation'] = _CUSTOMOBSERVATION 53 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 54 | 55 | CustomObservation = _reflection.GeneratedProtocolMessageType('CustomObservation', (_message.Message,), dict( 56 | DESCRIPTOR = _CUSTOMOBSERVATION, 57 | __module__ = 'mlagents.envs.communicator_objects.custom_observation_pb2' 58 | # @@protoc_insertion_point(class_scope:communicator_objects.CustomObservation) 59 | )) 60 | _sym_db.RegisterMessage(CustomObservation) 61 | 62 | 63 | DESCRIPTOR._options = None 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/custom_reset_parameters_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/custom_reset_parameters.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/custom_reset_parameters.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 23 | serialized_pb=_b('\n@mlagents/envs/communicator_objects/custom_reset_parameters.proto\x12\x14\x63ommunicator_objects\"\x17\n\x15\x43ustomResetParametersB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | 27 | 28 | 29 | _CUSTOMRESETPARAMETERS = _descriptor.Descriptor( 30 | name='CustomResetParameters', 31 | full_name='communicator_objects.CustomResetParameters', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | ], 37 | extensions=[ 38 | ], 39 | nested_types=[], 40 | enum_types=[ 41 | ], 42 | serialized_options=None, 43 | is_extendable=False, 44 | syntax='proto3', 45 | extension_ranges=[], 46 | oneofs=[ 47 | ], 48 | serialized_start=90, 49 | serialized_end=113, 50 | ) 51 | 52 | DESCRIPTOR.message_types_by_name['CustomResetParameters'] = _CUSTOMRESETPARAMETERS 53 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 54 | 55 | CustomResetParameters = _reflection.GeneratedProtocolMessageType('CustomResetParameters', (_message.Message,), dict( 56 | DESCRIPTOR = _CUSTOMRESETPARAMETERS, 57 | __module__ = 'mlagents.envs.communicator_objects.custom_reset_parameters_pb2' 58 | # @@protoc_insertion_point(class_scope:communicator_objects.CustomResetParameters) 59 | )) 60 | _sym_db.RegisterMessage(CustomResetParameters) 61 | 62 | 63 | DESCRIPTOR._options = None 64 | # @@protoc_insertion_point(module_scope) 65 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/demonstration_meta_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/demonstration_meta_proto.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/demonstration_meta_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 23 | serialized_pb=_b('\nAmlagents/envs/communicator_objects/demonstration_meta_proto.proto\x12\x14\x63ommunicator_objects\"\x8d\x01\n\x16\x44\x65monstrationMetaProto\x12\x13\n\x0b\x61pi_version\x18\x01 \x01(\x05\x12\x1a\n\x12\x64\x65monstration_name\x18\x02 \x01(\t\x12\x14\n\x0cnumber_steps\x18\x03 \x01(\x05\x12\x17\n\x0fnumber_episodes\x18\x04 \x01(\x05\x12\x13\n\x0bmean_reward\x18\x05 \x01(\x02\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | 27 | 28 | 29 | _DEMONSTRATIONMETAPROTO = _descriptor.Descriptor( 30 | name='DemonstrationMetaProto', 31 | full_name='communicator_objects.DemonstrationMetaProto', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name='api_version', full_name='communicator_objects.DemonstrationMetaProto.api_version', index=0, 38 | number=1, type=5, cpp_type=1, label=1, 39 | has_default_value=False, default_value=0, 40 | message_type=None, enum_type=None, containing_type=None, 41 | is_extension=False, extension_scope=None, 42 | serialized_options=None, file=DESCRIPTOR), 43 | _descriptor.FieldDescriptor( 44 | name='demonstration_name', full_name='communicator_objects.DemonstrationMetaProto.demonstration_name', index=1, 45 | number=2, type=9, cpp_type=9, label=1, 46 | has_default_value=False, default_value=_b("").decode('utf-8'), 47 | message_type=None, enum_type=None, containing_type=None, 48 | is_extension=False, extension_scope=None, 49 | serialized_options=None, file=DESCRIPTOR), 50 | _descriptor.FieldDescriptor( 51 | name='number_steps', full_name='communicator_objects.DemonstrationMetaProto.number_steps', index=2, 52 | number=3, type=5, cpp_type=1, label=1, 53 | has_default_value=False, default_value=0, 54 | message_type=None, enum_type=None, containing_type=None, 55 | is_extension=False, extension_scope=None, 56 | serialized_options=None, file=DESCRIPTOR), 57 | _descriptor.FieldDescriptor( 58 | name='number_episodes', full_name='communicator_objects.DemonstrationMetaProto.number_episodes', index=3, 59 | number=4, type=5, cpp_type=1, label=1, 60 | has_default_value=False, default_value=0, 61 | message_type=None, enum_type=None, containing_type=None, 62 | is_extension=False, extension_scope=None, 63 | serialized_options=None, file=DESCRIPTOR), 64 | _descriptor.FieldDescriptor( 65 | name='mean_reward', full_name='communicator_objects.DemonstrationMetaProto.mean_reward', index=4, 66 | number=5, type=2, cpp_type=6, label=1, 67 | has_default_value=False, default_value=float(0), 68 | message_type=None, enum_type=None, containing_type=None, 69 | is_extension=False, extension_scope=None, 70 | serialized_options=None, file=DESCRIPTOR), 71 | ], 72 | extensions=[ 73 | ], 74 | nested_types=[], 75 | enum_types=[ 
76 | ], 77 | serialized_options=None, 78 | is_extendable=False, 79 | syntax='proto3', 80 | extension_ranges=[], 81 | oneofs=[ 82 | ], 83 | serialized_start=92, 84 | serialized_end=233, 85 | ) 86 | 87 | DESCRIPTOR.message_types_by_name['DemonstrationMetaProto'] = _DEMONSTRATIONMETAPROTO 88 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 89 | 90 | DemonstrationMetaProto = _reflection.GeneratedProtocolMessageType('DemonstrationMetaProto', (_message.Message,), dict( 91 | DESCRIPTOR = _DEMONSTRATIONMETAPROTO, 92 | __module__ = 'mlagents.envs.communicator_objects.demonstration_meta_proto_pb2' 93 | # @@protoc_insertion_point(class_scope:communicator_objects.DemonstrationMetaProto) 94 | )) 95 | _sym_db.RegisterMessage(DemonstrationMetaProto) 96 | 97 | 98 | DESCRIPTOR._options = None 99 | # @@protoc_insertion_point(module_scope) 100 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/engine_configuration_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/engine_configuration_proto.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/engine_configuration_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 23 | serialized_pb=_b('\nCmlagents/envs/communicator_objects/engine_configuration_proto.proto\x12\x14\x63ommunicator_objects\"\x95\x01\n\x18\x45ngineConfigurationProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x15\n\rquality_level\x18\x03 \x01(\x05\x12\x12\n\ntime_scale\x18\x04 \x01(\x02\x12\x19\n\x11target_frame_rate\x18\x05 \x01(\x05\x12\x14\n\x0cshow_monitor\x18\x06 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | 27 | 28 | 29 | _ENGINECONFIGURATIONPROTO = _descriptor.Descriptor( 30 | name='EngineConfigurationProto', 31 | full_name='communicator_objects.EngineConfigurationProto', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name='width', full_name='communicator_objects.EngineConfigurationProto.width', index=0, 38 | number=1, type=5, cpp_type=1, label=1, 39 | has_default_value=False, default_value=0, 40 | message_type=None, enum_type=None, containing_type=None, 41 | is_extension=False, extension_scope=None, 42 | serialized_options=None, file=DESCRIPTOR), 43 | _descriptor.FieldDescriptor( 44 | name='height', full_name='communicator_objects.EngineConfigurationProto.height', index=1, 45 | number=2, type=5, cpp_type=1, label=1, 46 | has_default_value=False, default_value=0, 47 | message_type=None, enum_type=None, containing_type=None, 48 | is_extension=False, extension_scope=None, 49 | serialized_options=None, file=DESCRIPTOR), 50 | _descriptor.FieldDescriptor( 51 | name='quality_level', 
full_name='communicator_objects.EngineConfigurationProto.quality_level', index=2, 52 | number=3, type=5, cpp_type=1, label=1, 53 | has_default_value=False, default_value=0, 54 | message_type=None, enum_type=None, containing_type=None, 55 | is_extension=False, extension_scope=None, 56 | serialized_options=None, file=DESCRIPTOR), 57 | _descriptor.FieldDescriptor( 58 | name='time_scale', full_name='communicator_objects.EngineConfigurationProto.time_scale', index=3, 59 | number=4, type=2, cpp_type=6, label=1, 60 | has_default_value=False, default_value=float(0), 61 | message_type=None, enum_type=None, containing_type=None, 62 | is_extension=False, extension_scope=None, 63 | serialized_options=None, file=DESCRIPTOR), 64 | _descriptor.FieldDescriptor( 65 | name='target_frame_rate', full_name='communicator_objects.EngineConfigurationProto.target_frame_rate', index=4, 66 | number=5, type=5, cpp_type=1, label=1, 67 | has_default_value=False, default_value=0, 68 | message_type=None, enum_type=None, containing_type=None, 69 | is_extension=False, extension_scope=None, 70 | serialized_options=None, file=DESCRIPTOR), 71 | _descriptor.FieldDescriptor( 72 | name='show_monitor', full_name='communicator_objects.EngineConfigurationProto.show_monitor', index=5, 73 | number=6, type=8, cpp_type=7, label=1, 74 | has_default_value=False, default_value=False, 75 | message_type=None, enum_type=None, containing_type=None, 76 | is_extension=False, extension_scope=None, 77 | serialized_options=None, file=DESCRIPTOR), 78 | ], 79 | extensions=[ 80 | ], 81 | nested_types=[], 82 | enum_types=[ 83 | ], 84 | serialized_options=None, 85 | is_extendable=False, 86 | syntax='proto3', 87 | extension_ranges=[], 88 | oneofs=[ 89 | ], 90 | serialized_start=94, 91 | serialized_end=243, 92 | ) 93 | 94 | DESCRIPTOR.message_types_by_name['EngineConfigurationProto'] = _ENGINECONFIGURATIONPROTO 95 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 96 | 97 | EngineConfigurationProto = _reflection.GeneratedProtocolMessageType('EngineConfigurationProto', (_message.Message,), dict( 98 | DESCRIPTOR = _ENGINECONFIGURATIONPROTO, 99 | __module__ = 'mlagents.envs.communicator_objects.engine_configuration_proto_pb2' 100 | # @@protoc_insertion_point(class_scope:communicator_objects.EngineConfigurationProto) 101 | )) 102 | _sym_db.RegisterMessage(EngineConfigurationProto) 103 | 104 | 105 | DESCRIPTOR._options = None 106 | # @@protoc_insertion_point(module_scope) 107 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/environment_parameters_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/environment_parameters_proto.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import custom_reset_parameters_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_custom__reset__parameters__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='mlagents/envs/communicator_objects/environment_parameters_proto.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 24 | serialized_pb=_b('\nEmlagents/envs/communicator_objects/environment_parameters_proto.proto\x12\x14\x63ommunicator_objects\x1a@mlagents/envs/communicator_objects/custom_reset_parameters.proto\"\x83\x02\n\x1a\x45nvironmentParametersProto\x12_\n\x10\x66loat_parameters\x18\x01 \x03(\x0b\x32\x45.communicator_objects.EnvironmentParametersProto.FloatParametersEntry\x12L\n\x17\x63ustom_reset_parameters\x18\x02 \x01(\x0b\x32+.communicator_objects.CustomResetParameters\x1a\x36\n\x14\x46loatParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_custom__reset__parameters__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY = _descriptor.Descriptor( 32 | name='FloatParametersEntry', 33 | full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='key', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.key', index=0, 40 | number=1, type=9, cpp_type=9, label=1, 41 | has_default_value=False, default_value=_b("").decode('utf-8'), 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | serialized_options=None, file=DESCRIPTOR), 45 | _descriptor.FieldDescriptor( 46 | name='value', full_name='communicator_objects.EnvironmentParametersProto.FloatParametersEntry.value', index=1, 47 | number=2, type=2, cpp_type=6, label=1, 48 | has_default_value=False, default_value=float(0), 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | serialized_options=None, file=DESCRIPTOR), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | serialized_options=_b('8\001'), 59 | is_extendable=False, 60 | syntax='proto3', 61 | extension_ranges=[], 62 | oneofs=[ 63 | ], 64 | serialized_start=367, 65 | serialized_end=421, 66 | ) 67 | 68 | _ENVIRONMENTPARAMETERSPROTO = _descriptor.Descriptor( 69 | name='EnvironmentParametersProto', 70 | full_name='communicator_objects.EnvironmentParametersProto', 71 | filename=None, 72 | file=DESCRIPTOR, 73 | containing_type=None, 74 | fields=[ 75 | _descriptor.FieldDescriptor( 76 | name='float_parameters', full_name='communicator_objects.EnvironmentParametersProto.float_parameters', index=0, 77 | number=1, type=11, cpp_type=10, 
label=3, 78 | has_default_value=False, default_value=[], 79 | message_type=None, enum_type=None, containing_type=None, 80 | is_extension=False, extension_scope=None, 81 | serialized_options=None, file=DESCRIPTOR), 82 | _descriptor.FieldDescriptor( 83 | name='custom_reset_parameters', full_name='communicator_objects.EnvironmentParametersProto.custom_reset_parameters', index=1, 84 | number=2, type=11, cpp_type=10, label=1, 85 | has_default_value=False, default_value=None, 86 | message_type=None, enum_type=None, containing_type=None, 87 | is_extension=False, extension_scope=None, 88 | serialized_options=None, file=DESCRIPTOR), 89 | ], 90 | extensions=[ 91 | ], 92 | nested_types=[_ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, ], 93 | enum_types=[ 94 | ], 95 | serialized_options=None, 96 | is_extendable=False, 97 | syntax='proto3', 98 | extension_ranges=[], 99 | oneofs=[ 100 | ], 101 | serialized_start=162, 102 | serialized_end=421, 103 | ) 104 | 105 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY.containing_type = _ENVIRONMENTPARAMETERSPROTO 106 | _ENVIRONMENTPARAMETERSPROTO.fields_by_name['float_parameters'].message_type = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY 107 | _ENVIRONMENTPARAMETERSPROTO.fields_by_name['custom_reset_parameters'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_custom__reset__parameters__pb2._CUSTOMRESETPARAMETERS 108 | DESCRIPTOR.message_types_by_name['EnvironmentParametersProto'] = _ENVIRONMENTPARAMETERSPROTO 109 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 110 | 111 | EnvironmentParametersProto = _reflection.GeneratedProtocolMessageType('EnvironmentParametersProto', (_message.Message,), dict( 112 | 113 | FloatParametersEntry = _reflection.GeneratedProtocolMessageType('FloatParametersEntry', (_message.Message,), dict( 114 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY, 115 | __module__ = 'mlagents.envs.communicator_objects.environment_parameters_proto_pb2' 116 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto.FloatParametersEntry) 117 | )) 118 | , 119 | DESCRIPTOR = _ENVIRONMENTPARAMETERSPROTO, 120 | __module__ = 'mlagents.envs.communicator_objects.environment_parameters_proto_pb2' 121 | # @@protoc_insertion_point(class_scope:communicator_objects.EnvironmentParametersProto) 122 | )) 123 | _sym_db.RegisterMessage(EnvironmentParametersProto) 124 | _sym_db.RegisterMessage(EnvironmentParametersProto.FloatParametersEntry) 125 | 126 | 127 | DESCRIPTOR._options = None 128 | _ENVIRONMENTPARAMETERSPROTO_FLOATPARAMETERSENTRY._options = None 129 | # @@protoc_insertion_point(module_scope) 130 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/header_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/header.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/header.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 23 | serialized_pb=_b('\n/mlagents/envs/communicator_objects/header.proto\x12\x14\x63ommunicator_objects\")\n\x06Header\x12\x0e\n\x06status\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\tB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | 27 | 28 | 29 | _HEADER = _descriptor.Descriptor( 30 | name='Header', 31 | full_name='communicator_objects.Header', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name='status', full_name='communicator_objects.Header.status', index=0, 38 | number=1, type=5, cpp_type=1, label=1, 39 | has_default_value=False, default_value=0, 40 | message_type=None, enum_type=None, containing_type=None, 41 | is_extension=False, extension_scope=None, 42 | serialized_options=None, file=DESCRIPTOR), 43 | _descriptor.FieldDescriptor( 44 | name='message', full_name='communicator_objects.Header.message', index=1, 45 | number=2, type=9, cpp_type=9, label=1, 46 | has_default_value=False, default_value=_b("").decode('utf-8'), 47 | message_type=None, enum_type=None, containing_type=None, 48 | is_extension=False, extension_scope=None, 49 | serialized_options=None, file=DESCRIPTOR), 50 | ], 51 | extensions=[ 52 | ], 53 | nested_types=[], 54 | enum_types=[ 55 | ], 56 | serialized_options=None, 57 | is_extendable=False, 58 | syntax='proto3', 59 | extension_ranges=[], 60 | oneofs=[ 61 | ], 62 | serialized_start=73, 63 | serialized_end=114, 64 | ) 65 | 66 | DESCRIPTOR.message_types_by_name['Header'] = _HEADER 67 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 68 | 69 | Header = _reflection.GeneratedProtocolMessageType('Header', (_message.Message,), dict( 70 | DESCRIPTOR = _HEADER, 71 | __module__ = 'mlagents.envs.communicator_objects.header_pb2' 72 | # @@protoc_insertion_point(class_scope:communicator_objects.Header) 73 | )) 74 | _sym_db.RegisterMessage(Header) 75 | 76 | 77 | DESCRIPTOR._options = None 78 | # @@protoc_insertion_point(module_scope) 79 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/resolution_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/resolution_proto.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/resolution_proto.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 23 | serialized_pb=_b('\n9mlagents/envs/communicator_objects/resolution_proto.proto\x12\x14\x63ommunicator_objects\"D\n\x0fResolutionProto\x12\r\n\x05width\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\x12\n\ngray_scale\x18\x03 \x01(\x08\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | 27 | 28 | 29 | _RESOLUTIONPROTO = _descriptor.Descriptor( 30 | name='ResolutionProto', 31 | full_name='communicator_objects.ResolutionProto', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name='width', full_name='communicator_objects.ResolutionProto.width', index=0, 38 | number=1, type=5, cpp_type=1, label=1, 39 | has_default_value=False, default_value=0, 40 | message_type=None, enum_type=None, containing_type=None, 41 | is_extension=False, extension_scope=None, 42 | serialized_options=None, file=DESCRIPTOR), 43 | _descriptor.FieldDescriptor( 44 | name='height', full_name='communicator_objects.ResolutionProto.height', index=1, 45 | number=2, type=5, cpp_type=1, label=1, 46 | has_default_value=False, default_value=0, 47 | message_type=None, enum_type=None, containing_type=None, 48 | is_extension=False, extension_scope=None, 49 | serialized_options=None, file=DESCRIPTOR), 50 | _descriptor.FieldDescriptor( 51 | name='gray_scale', full_name='communicator_objects.ResolutionProto.gray_scale', index=2, 52 | number=3, type=8, cpp_type=7, label=1, 53 | has_default_value=False, default_value=False, 54 | message_type=None, enum_type=None, containing_type=None, 55 | is_extension=False, extension_scope=None, 56 | serialized_options=None, file=DESCRIPTOR), 57 | ], 58 | extensions=[ 59 | ], 60 | nested_types=[], 61 | enum_types=[ 62 | ], 63 | serialized_options=None, 64 | is_extendable=False, 65 | syntax='proto3', 66 | extension_ranges=[], 67 | oneofs=[ 68 | ], 69 | serialized_start=83, 70 | serialized_end=151, 71 | ) 72 | 73 | DESCRIPTOR.message_types_by_name['ResolutionProto'] = _RESOLUTIONPROTO 74 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 75 | 76 | ResolutionProto = _reflection.GeneratedProtocolMessageType('ResolutionProto', (_message.Message,), dict( 77 | DESCRIPTOR = _RESOLUTIONPROTO, 78 | __module__ = 'mlagents.envs.communicator_objects.resolution_proto_pb2' 79 | # @@protoc_insertion_point(class_scope:communicator_objects.ResolutionProto) 80 | )) 81 | _sym_db.RegisterMessage(ResolutionProto) 82 | 83 | 84 | DESCRIPTOR._options = None 85 | # @@protoc_insertion_point(module_scope) 86 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/space_type_proto_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by 
the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/space_type_proto.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf.internal import enum_type_wrapper 8 | from google.protobuf import descriptor as _descriptor 9 | from google.protobuf import message as _message 10 | from google.protobuf import reflection as _reflection 11 | from google.protobuf import symbol_database as _symbol_database 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | 17 | from mlagents.envs.communicator_objects import resolution_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='mlagents/envs/communicator_objects/space_type_proto.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 25 | serialized_pb=_b('\n9mlagents/envs/communicator_objects/space_type_proto.proto\x12\x14\x63ommunicator_objects\x1a\x39mlagents/envs/communicator_objects/resolution_proto.proto*.\n\x0eSpaceTypeProto\x12\x0c\n\x08\x64iscrete\x10\x00\x12\x0e\n\ncontinuous\x10\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_resolution__proto__pb2.DESCRIPTOR,]) 28 | 29 | _SPACETYPEPROTO = _descriptor.EnumDescriptor( 30 | name='SpaceTypeProto', 31 | full_name='communicator_objects.SpaceTypeProto', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | values=[ 35 | _descriptor.EnumValueDescriptor( 36 | name='discrete', index=0, number=0, 37 | serialized_options=None, 38 | type=None), 39 | _descriptor.EnumValueDescriptor( 40 | name='continuous', index=1, number=1, 41 | serialized_options=None, 42 | type=None), 43 | ], 44 | containing_type=None, 45 | serialized_options=None, 46 | serialized_start=142, 47 | serialized_end=188, 48 | ) 49 | _sym_db.RegisterEnumDescriptor(_SPACETYPEPROTO) 50 | 51 | SpaceTypeProto = enum_type_wrapper.EnumTypeWrapper(_SPACETYPEPROTO) 52 | discrete = 0 53 | continuous = 1 54 | 55 | 56 | DESCRIPTOR.enum_types_by_name['SpaceTypeProto'] = _SPACETYPEPROTO 57 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 58 | 59 | 60 | DESCRIPTOR._options = None 61 | # @@protoc_insertion_point(module_scope) 62 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_input_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/unity_input.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import unity_rl_input_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__input__pb2 17 | from mlagents.envs.communicator_objects import unity_rl_initialization_input_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__input__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='mlagents/envs/communicator_objects/unity_input.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 25 | serialized_pb=_b('\n4mlagents/envs/communicator_objects/unity_input.proto\x12\x14\x63ommunicator_objects\x1a\x37mlagents/envs/communicator_objects/unity_rl_input.proto\x1a\x46mlagents/envs/communicator_objects/unity_rl_initialization_input.proto\"\x95\x01\n\nUnityInput\x12\x34\n\x08rl_input\x18\x01 \x01(\x0b\x32\".communicator_objects.UnityRLInput\x12Q\n\x17rl_initialization_input\x18\x02 \x01(\x0b\x32\x30.communicator_objects.UnityRLInitializationInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__input__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__input__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _UNITYINPUT = _descriptor.Descriptor( 33 | name='UnityInput', 34 | full_name='communicator_objects.UnityInput', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='rl_input', full_name='communicator_objects.UnityInput.rl_input', index=0, 41 | number=1, type=11, cpp_type=10, label=1, 42 | has_default_value=False, default_value=None, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | serialized_options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='rl_initialization_input', full_name='communicator_objects.UnityInput.rl_initialization_input', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | serialized_options=None, file=DESCRIPTOR), 53 | ], 54 | extensions=[ 55 | ], 56 | nested_types=[], 57 | enum_types=[ 58 | ], 59 | serialized_options=None, 60 | is_extendable=False, 61 | syntax='proto3', 62 | extension_ranges=[], 63 | oneofs=[ 64 | ], 65 | serialized_start=208, 66 | serialized_end=357, 67 | ) 68 | 69 | _UNITYINPUT.fields_by_name['rl_input'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__input__pb2._UNITYRLINPUT 70 | _UNITYINPUT.fields_by_name['rl_initialization_input'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__input__pb2._UNITYRLINITIALIZATIONINPUT 71 | DESCRIPTOR.message_types_by_name['UnityInput'] = _UNITYINPUT 72 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 73 | 74 | UnityInput = 
_reflection.GeneratedProtocolMessageType('UnityInput', (_message.Message,), dict( 75 | DESCRIPTOR = _UNITYINPUT, 76 | __module__ = 'mlagents.envs.communicator_objects.unity_input_pb2' 77 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityInput) 78 | )) 79 | _sym_db.RegisterMessage(UnityInput) 80 | 81 | 82 | DESCRIPTOR._options = None 83 | # @@protoc_insertion_point(module_scope) 84 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_message_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/unity_message.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import unity_output_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__output__pb2 17 | from mlagents.envs.communicator_objects import unity_input_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__input__pb2 18 | from mlagents.envs.communicator_objects import header_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_header__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='mlagents/envs/communicator_objects/unity_message.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 26 | serialized_pb=_b('\n6mlagents/envs/communicator_objects/unity_message.proto\x12\x14\x63ommunicator_objects\x1a\x35mlagents/envs/communicator_objects/unity_output.proto\x1a\x34mlagents/envs/communicator_objects/unity_input.proto\x1a/mlagents/envs/communicator_objects/header.proto\"\xac\x01\n\x0cUnityMessage\x12,\n\x06header\x18\x01 \x01(\x0b\x32\x1c.communicator_objects.Header\x12\x37\n\x0cunity_output\x18\x02 \x01(\x0b\x32!.communicator_objects.UnityOutput\x12\x35\n\x0bunity_input\x18\x03 \x01(\x0b\x32 .communicator_objects.UnityInputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 27 | , 28 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_unity__output__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_unity__input__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_header__pb2.DESCRIPTOR,]) 29 | 30 | 31 | 32 | 33 | _UNITYMESSAGE = _descriptor.Descriptor( 34 | name='UnityMessage', 35 | full_name='communicator_objects.UnityMessage', 36 | filename=None, 37 | file=DESCRIPTOR, 38 | containing_type=None, 39 | fields=[ 40 | _descriptor.FieldDescriptor( 41 | name='header', full_name='communicator_objects.UnityMessage.header', index=0, 42 | number=1, type=11, cpp_type=10, label=1, 43 | has_default_value=False, default_value=None, 44 | message_type=None, enum_type=None, containing_type=None, 45 | is_extension=False, extension_scope=None, 46 | serialized_options=None, file=DESCRIPTOR), 47 | _descriptor.FieldDescriptor( 48 | name='unity_output', full_name='communicator_objects.UnityMessage.unity_output', index=1, 49 | number=2, type=11, cpp_type=10, label=1, 50 | has_default_value=False, default_value=None, 51 
| message_type=None, enum_type=None, containing_type=None, 52 | is_extension=False, extension_scope=None, 53 | serialized_options=None, file=DESCRIPTOR), 54 | _descriptor.FieldDescriptor( 55 | name='unity_input', full_name='communicator_objects.UnityMessage.unity_input', index=2, 56 | number=3, type=11, cpp_type=10, label=1, 57 | has_default_value=False, default_value=None, 58 | message_type=None, enum_type=None, containing_type=None, 59 | is_extension=False, extension_scope=None, 60 | serialized_options=None, file=DESCRIPTOR), 61 | ], 62 | extensions=[ 63 | ], 64 | nested_types=[], 65 | enum_types=[ 66 | ], 67 | serialized_options=None, 68 | is_extendable=False, 69 | syntax='proto3', 70 | extension_ranges=[], 71 | oneofs=[ 72 | ], 73 | serialized_start=239, 74 | serialized_end=411, 75 | ) 76 | 77 | _UNITYMESSAGE.fields_by_name['header'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_header__pb2._HEADER 78 | _UNITYMESSAGE.fields_by_name['unity_output'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__output__pb2._UNITYOUTPUT 79 | _UNITYMESSAGE.fields_by_name['unity_input'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__input__pb2._UNITYINPUT 80 | DESCRIPTOR.message_types_by_name['UnityMessage'] = _UNITYMESSAGE 81 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 82 | 83 | UnityMessage = _reflection.GeneratedProtocolMessageType('UnityMessage', (_message.Message,), dict( 84 | DESCRIPTOR = _UNITYMESSAGE, 85 | __module__ = 'mlagents.envs.communicator_objects.unity_message_pb2' 86 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityMessage) 87 | )) 88 | _sym_db.RegisterMessage(UnityMessage) 89 | 90 | 91 | DESCRIPTOR._options = None 92 | # @@protoc_insertion_point(module_scope) 93 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_output_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/unity_output.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import unity_rl_output_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__output__pb2 17 | from mlagents.envs.communicator_objects import unity_rl_initialization_output_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__output__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='mlagents/envs/communicator_objects/unity_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 25 | serialized_pb=_b('\n5mlagents/envs/communicator_objects/unity_output.proto\x12\x14\x63ommunicator_objects\x1a\x38mlagents/envs/communicator_objects/unity_rl_output.proto\x1aGmlagents/envs/communicator_objects/unity_rl_initialization_output.proto\"\x9a\x01\n\x0bUnityOutput\x12\x36\n\trl_output\x18\x01 \x01(\x0b\x32#.communicator_objects.UnityRLOutput\x12S\n\x18rl_initialization_output\x18\x02 \x01(\x0b\x32\x31.communicator_objects.UnityRLInitializationOutputB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__output__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__output__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _UNITYOUTPUT = _descriptor.Descriptor( 33 | name='UnityOutput', 34 | full_name='communicator_objects.UnityOutput', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='rl_output', full_name='communicator_objects.UnityOutput.rl_output', index=0, 41 | number=1, type=11, cpp_type=10, label=1, 42 | has_default_value=False, default_value=None, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | serialized_options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='rl_initialization_output', full_name='communicator_objects.UnityOutput.rl_initialization_output', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | serialized_options=None, file=DESCRIPTOR), 53 | ], 54 | extensions=[ 55 | ], 56 | nested_types=[], 57 | enum_types=[ 58 | ], 59 | serialized_options=None, 60 | is_extendable=False, 61 | syntax='proto3', 62 | extension_ranges=[], 63 | oneofs=[ 64 | ], 65 | serialized_start=211, 66 | serialized_end=365, 67 | ) 68 | 69 | _UNITYOUTPUT.fields_by_name['rl_output'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__output__pb2._UNITYRLOUTPUT 70 | _UNITYOUTPUT.fields_by_name['rl_initialization_output'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_unity__rl__initialization__output__pb2._UNITYRLINITIALIZATIONOUTPUT 71 | DESCRIPTOR.message_types_by_name['UnityOutput'] = _UNITYOUTPUT 72 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 73 | 74 | 
UnityOutput = _reflection.GeneratedProtocolMessageType('UnityOutput', (_message.Message,), dict( 75 | DESCRIPTOR = _UNITYOUTPUT, 76 | __module__ = 'mlagents.envs.communicator_objects.unity_output_pb2' 77 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityOutput) 78 | )) 79 | _sym_db.RegisterMessage(UnityOutput) 80 | 81 | 82 | DESCRIPTOR._options = None 83 | # @@protoc_insertion_point(module_scope) 84 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_rl_initialization_input_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/unity_rl_initialization_input.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/unity_rl_initialization_input.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 23 | serialized_pb=_b('\nFmlagents/envs/communicator_objects/unity_rl_initialization_input.proto\x12\x14\x63ommunicator_objects\"*\n\x1aUnityRLInitializationInput\x12\x0c\n\x04seed\x18\x01 \x01(\x05\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | ) 25 | 26 | 27 | 28 | 29 | _UNITYRLINITIALIZATIONINPUT = _descriptor.Descriptor( 30 | name='UnityRLInitializationInput', 31 | full_name='communicator_objects.UnityRLInitializationInput', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | containing_type=None, 35 | fields=[ 36 | _descriptor.FieldDescriptor( 37 | name='seed', full_name='communicator_objects.UnityRLInitializationInput.seed', index=0, 38 | number=1, type=5, cpp_type=1, label=1, 39 | has_default_value=False, default_value=0, 40 | message_type=None, enum_type=None, containing_type=None, 41 | is_extension=False, extension_scope=None, 42 | serialized_options=None, file=DESCRIPTOR), 43 | ], 44 | extensions=[ 45 | ], 46 | nested_types=[], 47 | enum_types=[ 48 | ], 49 | serialized_options=None, 50 | is_extendable=False, 51 | syntax='proto3', 52 | extension_ranges=[], 53 | oneofs=[ 54 | ], 55 | serialized_start=96, 56 | serialized_end=138, 57 | ) 58 | 59 | DESCRIPTOR.message_types_by_name['UnityRLInitializationInput'] = _UNITYRLINITIALIZATIONINPUT 60 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 61 | 62 | UnityRLInitializationInput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationInput', (_message.Message,), dict( 63 | DESCRIPTOR = _UNITYRLINITIALIZATIONINPUT, 64 | __module__ = 'mlagents.envs.communicator_objects.unity_rl_initialization_input_pb2' 65 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationInput) 66 | )) 67 | _sym_db.RegisterMessage(UnityRLInitializationInput) 68 | 69 | 70 | DESCRIPTOR._options = None 71 | # @@protoc_insertion_point(module_scope) 72 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_rl_initialization_output_pb2.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: mlagents/envs/communicator_objects/unity_rl_initialization_output.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import brain_parameters_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_brain__parameters__proto__pb2 17 | from mlagents.envs.communicator_objects import environment_parameters_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='mlagents/envs/communicator_objects/unity_rl_initialization_output.proto', 22 | package='communicator_objects', 23 | syntax='proto3', 24 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 25 | serialized_pb=_b('\nGmlagents/envs/communicator_objects/unity_rl_initialization_output.proto\x12\x14\x63ommunicator_objects\x1a?mlagents/envs/communicator_objects/brain_parameters_proto.proto\x1a\x45mlagents/envs/communicator_objects/environment_parameters_proto.proto\"\xe6\x01\n\x1bUnityRLInitializationOutput\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x10\n\x08log_path\x18\x03 \x01(\t\x12\x44\n\x10\x62rain_parameters\x18\x05 \x03(\x0b\x32*.communicator_objects.BrainParametersProto\x12P\n\x16\x65nvironment_parameters\x18\x06 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProtoB\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 26 | , 27 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_brain__parameters__proto__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _UNITYRLINITIALIZATIONOUTPUT = _descriptor.Descriptor( 33 | name='UnityRLInitializationOutput', 34 | full_name='communicator_objects.UnityRLInitializationOutput', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='name', full_name='communicator_objects.UnityRLInitializationOutput.name', index=0, 41 | number=1, type=9, cpp_type=9, label=1, 42 | has_default_value=False, default_value=_b("").decode('utf-8'), 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | serialized_options=None, file=DESCRIPTOR), 46 | _descriptor.FieldDescriptor( 47 | name='version', full_name='communicator_objects.UnityRLInitializationOutput.version', index=1, 48 | number=2, type=9, cpp_type=9, label=1, 49 | has_default_value=False, default_value=_b("").decode('utf-8'), 50 | message_type=None, enum_type=None, containing_type=None, 51 | is_extension=False, extension_scope=None, 52 | serialized_options=None, file=DESCRIPTOR), 53 | _descriptor.FieldDescriptor( 54 | name='log_path', full_name='communicator_objects.UnityRLInitializationOutput.log_path', index=2, 55 | number=3, type=9, cpp_type=9, label=1, 56 | has_default_value=False, default_value=_b("").decode('utf-8'), 57 | message_type=None, 
enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | serialized_options=None, file=DESCRIPTOR), 60 | _descriptor.FieldDescriptor( 61 | name='brain_parameters', full_name='communicator_objects.UnityRLInitializationOutput.brain_parameters', index=3, 62 | number=5, type=11, cpp_type=10, label=3, 63 | has_default_value=False, default_value=[], 64 | message_type=None, enum_type=None, containing_type=None, 65 | is_extension=False, extension_scope=None, 66 | serialized_options=None, file=DESCRIPTOR), 67 | _descriptor.FieldDescriptor( 68 | name='environment_parameters', full_name='communicator_objects.UnityRLInitializationOutput.environment_parameters', index=4, 69 | number=6, type=11, cpp_type=10, label=1, 70 | has_default_value=False, default_value=None, 71 | message_type=None, enum_type=None, containing_type=None, 72 | is_extension=False, extension_scope=None, 73 | serialized_options=None, file=DESCRIPTOR), 74 | ], 75 | extensions=[ 76 | ], 77 | nested_types=[], 78 | enum_types=[ 79 | ], 80 | serialized_options=None, 81 | is_extendable=False, 82 | syntax='proto3', 83 | extension_ranges=[], 84 | oneofs=[ 85 | ], 86 | serialized_start=234, 87 | serialized_end=464, 88 | ) 89 | 90 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['brain_parameters'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_brain__parameters__proto__pb2._BRAINPARAMETERSPROTO 91 | _UNITYRLINITIALIZATIONOUTPUT.fields_by_name['environment_parameters'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO 92 | DESCRIPTOR.message_types_by_name['UnityRLInitializationOutput'] = _UNITYRLINITIALIZATIONOUTPUT 93 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 94 | 95 | UnityRLInitializationOutput = _reflection.GeneratedProtocolMessageType('UnityRLInitializationOutput', (_message.Message,), dict( 96 | DESCRIPTOR = _UNITYRLINITIALIZATIONOUTPUT, 97 | __module__ = 'mlagents.envs.communicator_objects.unity_rl_initialization_output_pb2' 98 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInitializationOutput) 99 | )) 100 | _sym_db.RegisterMessage(UnityRLInitializationOutput) 101 | 102 | 103 | DESCRIPTOR._options = None 104 | # @@protoc_insertion_point(module_scope) 105 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_rl_input_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/unity_rl_input.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import agent_action_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_agent__action__proto__pb2 17 | from mlagents.envs.communicator_objects import environment_parameters_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2 18 | from mlagents.envs.communicator_objects import command_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_command__proto__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='mlagents/envs/communicator_objects/unity_rl_input.proto', 23 | package='communicator_objects', 24 | syntax='proto3', 25 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 26 | serialized_pb=_b('\n7mlagents/envs/communicator_objects/unity_rl_input.proto\x12\x14\x63ommunicator_objects\x1a;mlagents/envs/communicator_objects/agent_action_proto.proto\x1a\x45mlagents/envs/communicator_objects/environment_parameters_proto.proto\x1a\x36mlagents/envs/communicator_objects/command_proto.proto\"\xb4\x03\n\x0cUnityRLInput\x12K\n\ragent_actions\x18\x01 \x03(\x0b\x32\x34.communicator_objects.UnityRLInput.AgentActionsEntry\x12P\n\x16\x65nvironment_parameters\x18\x02 \x01(\x0b\x32\x30.communicator_objects.EnvironmentParametersProto\x12\x13\n\x0bis_training\x18\x03 \x01(\x08\x12\x33\n\x07\x63ommand\x18\x04 \x01(\x0e\x32\".communicator_objects.CommandProto\x1aM\n\x14ListAgentActionProto\x12\x35\n\x05value\x18\x01 \x03(\x0b\x32&.communicator_objects.AgentActionProto\x1al\n\x11\x41gentActionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x46\n\x05value\x18\x02 \x01(\x0b\x32\x37.communicator_objects.UnityRLInput.ListAgentActionProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 27 | , 28 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_agent__action__proto__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2.DESCRIPTOR,mlagents_dot_envs_dot_communicator__objects_dot_command__proto__pb2.DESCRIPTOR,]) 29 | 30 | 31 | 32 | 33 | _UNITYRLINPUT_LISTAGENTACTIONPROTO = _descriptor.Descriptor( 34 | name='ListAgentActionProto', 35 | full_name='communicator_objects.UnityRLInput.ListAgentActionProto', 36 | filename=None, 37 | file=DESCRIPTOR, 38 | containing_type=None, 39 | fields=[ 40 | _descriptor.FieldDescriptor( 41 | name='value', full_name='communicator_objects.UnityRLInput.ListAgentActionProto.value', index=0, 42 | number=1, type=11, cpp_type=10, label=3, 43 | has_default_value=False, default_value=[], 44 | message_type=None, enum_type=None, containing_type=None, 45 | is_extension=False, extension_scope=None, 46 | serialized_options=None, file=DESCRIPTOR), 47 | ], 48 | extensions=[ 49 | ], 50 | nested_types=[], 51 | enum_types=[ 52 | ], 53 | serialized_options=None, 54 | is_extendable=False, 55 | syntax='proto3', 56 | extension_ranges=[], 57 | oneofs=[ 58 | ], 59 | serialized_start=519, 60 | serialized_end=596, 61 | ) 62 | 63 | _UNITYRLINPUT_AGENTACTIONSENTRY = _descriptor.Descriptor( 64 | 
name='AgentActionsEntry', 65 | full_name='communicator_objects.UnityRLInput.AgentActionsEntry', 66 | filename=None, 67 | file=DESCRIPTOR, 68 | containing_type=None, 69 | fields=[ 70 | _descriptor.FieldDescriptor( 71 | name='key', full_name='communicator_objects.UnityRLInput.AgentActionsEntry.key', index=0, 72 | number=1, type=9, cpp_type=9, label=1, 73 | has_default_value=False, default_value=_b("").decode('utf-8'), 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | serialized_options=None, file=DESCRIPTOR), 77 | _descriptor.FieldDescriptor( 78 | name='value', full_name='communicator_objects.UnityRLInput.AgentActionsEntry.value', index=1, 79 | number=2, type=11, cpp_type=10, label=1, 80 | has_default_value=False, default_value=None, 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | serialized_options=None, file=DESCRIPTOR), 84 | ], 85 | extensions=[ 86 | ], 87 | nested_types=[], 88 | enum_types=[ 89 | ], 90 | serialized_options=_b('8\001'), 91 | is_extendable=False, 92 | syntax='proto3', 93 | extension_ranges=[], 94 | oneofs=[ 95 | ], 96 | serialized_start=598, 97 | serialized_end=706, 98 | ) 99 | 100 | _UNITYRLINPUT = _descriptor.Descriptor( 101 | name='UnityRLInput', 102 | full_name='communicator_objects.UnityRLInput', 103 | filename=None, 104 | file=DESCRIPTOR, 105 | containing_type=None, 106 | fields=[ 107 | _descriptor.FieldDescriptor( 108 | name='agent_actions', full_name='communicator_objects.UnityRLInput.agent_actions', index=0, 109 | number=1, type=11, cpp_type=10, label=3, 110 | has_default_value=False, default_value=[], 111 | message_type=None, enum_type=None, containing_type=None, 112 | is_extension=False, extension_scope=None, 113 | serialized_options=None, file=DESCRIPTOR), 114 | _descriptor.FieldDescriptor( 115 | name='environment_parameters', full_name='communicator_objects.UnityRLInput.environment_parameters', index=1, 116 | number=2, type=11, cpp_type=10, label=1, 117 | has_default_value=False, default_value=None, 118 | message_type=None, enum_type=None, containing_type=None, 119 | is_extension=False, extension_scope=None, 120 | serialized_options=None, file=DESCRIPTOR), 121 | _descriptor.FieldDescriptor( 122 | name='is_training', full_name='communicator_objects.UnityRLInput.is_training', index=2, 123 | number=3, type=8, cpp_type=7, label=1, 124 | has_default_value=False, default_value=False, 125 | message_type=None, enum_type=None, containing_type=None, 126 | is_extension=False, extension_scope=None, 127 | serialized_options=None, file=DESCRIPTOR), 128 | _descriptor.FieldDescriptor( 129 | name='command', full_name='communicator_objects.UnityRLInput.command', index=3, 130 | number=4, type=14, cpp_type=8, label=1, 131 | has_default_value=False, default_value=0, 132 | message_type=None, enum_type=None, containing_type=None, 133 | is_extension=False, extension_scope=None, 134 | serialized_options=None, file=DESCRIPTOR), 135 | ], 136 | extensions=[ 137 | ], 138 | nested_types=[_UNITYRLINPUT_LISTAGENTACTIONPROTO, _UNITYRLINPUT_AGENTACTIONSENTRY, ], 139 | enum_types=[ 140 | ], 141 | serialized_options=None, 142 | is_extendable=False, 143 | syntax='proto3', 144 | extension_ranges=[], 145 | oneofs=[ 146 | ], 147 | serialized_start=270, 148 | serialized_end=706, 149 | ) 150 | 151 | _UNITYRLINPUT_LISTAGENTACTIONPROTO.fields_by_name['value'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_agent__action__proto__pb2._AGENTACTIONPROTO 152 | 
_UNITYRLINPUT_LISTAGENTACTIONPROTO.containing_type = _UNITYRLINPUT 153 | _UNITYRLINPUT_AGENTACTIONSENTRY.fields_by_name['value'].message_type = _UNITYRLINPUT_LISTAGENTACTIONPROTO 154 | _UNITYRLINPUT_AGENTACTIONSENTRY.containing_type = _UNITYRLINPUT 155 | _UNITYRLINPUT.fields_by_name['agent_actions'].message_type = _UNITYRLINPUT_AGENTACTIONSENTRY 156 | _UNITYRLINPUT.fields_by_name['environment_parameters'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_environment__parameters__proto__pb2._ENVIRONMENTPARAMETERSPROTO 157 | _UNITYRLINPUT.fields_by_name['command'].enum_type = mlagents_dot_envs_dot_communicator__objects_dot_command__proto__pb2._COMMANDPROTO 158 | DESCRIPTOR.message_types_by_name['UnityRLInput'] = _UNITYRLINPUT 159 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 160 | 161 | UnityRLInput = _reflection.GeneratedProtocolMessageType('UnityRLInput', (_message.Message,), dict( 162 | 163 | ListAgentActionProto = _reflection.GeneratedProtocolMessageType('ListAgentActionProto', (_message.Message,), dict( 164 | DESCRIPTOR = _UNITYRLINPUT_LISTAGENTACTIONPROTO, 165 | __module__ = 'mlagents.envs.communicator_objects.unity_rl_input_pb2' 166 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput.ListAgentActionProto) 167 | )) 168 | , 169 | 170 | AgentActionsEntry = _reflection.GeneratedProtocolMessageType('AgentActionsEntry', (_message.Message,), dict( 171 | DESCRIPTOR = _UNITYRLINPUT_AGENTACTIONSENTRY, 172 | __module__ = 'mlagents.envs.communicator_objects.unity_rl_input_pb2' 173 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput.AgentActionsEntry) 174 | )) 175 | , 176 | DESCRIPTOR = _UNITYRLINPUT, 177 | __module__ = 'mlagents.envs.communicator_objects.unity_rl_input_pb2' 178 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLInput) 179 | )) 180 | _sym_db.RegisterMessage(UnityRLInput) 181 | _sym_db.RegisterMessage(UnityRLInput.ListAgentActionProto) 182 | _sym_db.RegisterMessage(UnityRLInput.AgentActionsEntry) 183 | 184 | 185 | DESCRIPTOR._options = None 186 | _UNITYRLINPUT_AGENTACTIONSENTRY._options = None 187 | # @@protoc_insertion_point(module_scope) 188 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_rl_output_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: mlagents/envs/communicator_objects/unity_rl_output.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from mlagents.envs.communicator_objects import agent_info_proto_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_agent__info__proto__pb2 17 | 18 | 19 | DESCRIPTOR = _descriptor.FileDescriptor( 20 | name='mlagents/envs/communicator_objects/unity_rl_output.proto', 21 | package='communicator_objects', 22 | syntax='proto3', 23 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 24 | serialized_pb=_b('\n8mlagents/envs/communicator_objects/unity_rl_output.proto\x12\x14\x63ommunicator_objects\x1a\x39mlagents/envs/communicator_objects/agent_info_proto.proto\"\xa3\x02\n\rUnityRLOutput\x12\x13\n\x0bglobal_done\x18\x01 \x01(\x08\x12G\n\nagentInfos\x18\x02 \x03(\x0b\x32\x33.communicator_objects.UnityRLOutput.AgentInfosEntry\x1aI\n\x12ListAgentInfoProto\x12\x33\n\x05value\x18\x01 \x03(\x0b\x32$.communicator_objects.AgentInfoProto\x1ai\n\x0f\x41gentInfosEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x45\n\x05value\x18\x02 \x01(\x0b\x32\x36.communicator_objects.UnityRLOutput.ListAgentInfoProto:\x02\x38\x01\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 25 | , 26 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_agent__info__proto__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO = _descriptor.Descriptor( 32 | name='ListAgentInfoProto', 33 | full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='value', full_name='communicator_objects.UnityRLOutput.ListAgentInfoProto.value', index=0, 40 | number=1, type=11, cpp_type=10, label=3, 41 | has_default_value=False, default_value=[], 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | serialized_options=None, file=DESCRIPTOR), 45 | ], 46 | extensions=[ 47 | ], 48 | nested_types=[], 49 | enum_types=[ 50 | ], 51 | serialized_options=None, 52 | is_extendable=False, 53 | syntax='proto3', 54 | extension_ranges=[], 55 | oneofs=[ 56 | ], 57 | serialized_start=253, 58 | serialized_end=326, 59 | ) 60 | 61 | _UNITYRLOUTPUT_AGENTINFOSENTRY = _descriptor.Descriptor( 62 | name='AgentInfosEntry', 63 | full_name='communicator_objects.UnityRLOutput.AgentInfosEntry', 64 | filename=None, 65 | file=DESCRIPTOR, 66 | containing_type=None, 67 | fields=[ 68 | _descriptor.FieldDescriptor( 69 | name='key', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.key', index=0, 70 | number=1, type=9, cpp_type=9, label=1, 71 | has_default_value=False, default_value=_b("").decode('utf-8'), 72 | message_type=None, enum_type=None, containing_type=None, 73 | is_extension=False, extension_scope=None, 74 | serialized_options=None, file=DESCRIPTOR), 75 | _descriptor.FieldDescriptor( 76 | name='value', full_name='communicator_objects.UnityRLOutput.AgentInfosEntry.value', index=1, 77 | number=2, type=11, cpp_type=10, label=1, 78 | has_default_value=False, default_value=None, 79 | message_type=None, enum_type=None, 
containing_type=None, 80 | is_extension=False, extension_scope=None, 81 | serialized_options=None, file=DESCRIPTOR), 82 | ], 83 | extensions=[ 84 | ], 85 | nested_types=[], 86 | enum_types=[ 87 | ], 88 | serialized_options=_b('8\001'), 89 | is_extendable=False, 90 | syntax='proto3', 91 | extension_ranges=[], 92 | oneofs=[ 93 | ], 94 | serialized_start=328, 95 | serialized_end=433, 96 | ) 97 | 98 | _UNITYRLOUTPUT = _descriptor.Descriptor( 99 | name='UnityRLOutput', 100 | full_name='communicator_objects.UnityRLOutput', 101 | filename=None, 102 | file=DESCRIPTOR, 103 | containing_type=None, 104 | fields=[ 105 | _descriptor.FieldDescriptor( 106 | name='global_done', full_name='communicator_objects.UnityRLOutput.global_done', index=0, 107 | number=1, type=8, cpp_type=7, label=1, 108 | has_default_value=False, default_value=False, 109 | message_type=None, enum_type=None, containing_type=None, 110 | is_extension=False, extension_scope=None, 111 | serialized_options=None, file=DESCRIPTOR), 112 | _descriptor.FieldDescriptor( 113 | name='agentInfos', full_name='communicator_objects.UnityRLOutput.agentInfos', index=1, 114 | number=2, type=11, cpp_type=10, label=3, 115 | has_default_value=False, default_value=[], 116 | message_type=None, enum_type=None, containing_type=None, 117 | is_extension=False, extension_scope=None, 118 | serialized_options=None, file=DESCRIPTOR), 119 | ], 120 | extensions=[ 121 | ], 122 | nested_types=[_UNITYRLOUTPUT_LISTAGENTINFOPROTO, _UNITYRLOUTPUT_AGENTINFOSENTRY, ], 123 | enum_types=[ 124 | ], 125 | serialized_options=None, 126 | is_extendable=False, 127 | syntax='proto3', 128 | extension_ranges=[], 129 | oneofs=[ 130 | ], 131 | serialized_start=142, 132 | serialized_end=433, 133 | ) 134 | 135 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.fields_by_name['value'].message_type = mlagents_dot_envs_dot_communicator__objects_dot_agent__info__proto__pb2._AGENTINFOPROTO 136 | _UNITYRLOUTPUT_LISTAGENTINFOPROTO.containing_type = _UNITYRLOUTPUT 137 | _UNITYRLOUTPUT_AGENTINFOSENTRY.fields_by_name['value'].message_type = _UNITYRLOUTPUT_LISTAGENTINFOPROTO 138 | _UNITYRLOUTPUT_AGENTINFOSENTRY.containing_type = _UNITYRLOUTPUT 139 | _UNITYRLOUTPUT.fields_by_name['agentInfos'].message_type = _UNITYRLOUTPUT_AGENTINFOSENTRY 140 | DESCRIPTOR.message_types_by_name['UnityRLOutput'] = _UNITYRLOUTPUT 141 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 142 | 143 | UnityRLOutput = _reflection.GeneratedProtocolMessageType('UnityRLOutput', (_message.Message,), dict( 144 | 145 | ListAgentInfoProto = _reflection.GeneratedProtocolMessageType('ListAgentInfoProto', (_message.Message,), dict( 146 | DESCRIPTOR = _UNITYRLOUTPUT_LISTAGENTINFOPROTO, 147 | __module__ = 'mlagents.envs.communicator_objects.unity_rl_output_pb2' 148 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.ListAgentInfoProto) 149 | )) 150 | , 151 | 152 | AgentInfosEntry = _reflection.GeneratedProtocolMessageType('AgentInfosEntry', (_message.Message,), dict( 153 | DESCRIPTOR = _UNITYRLOUTPUT_AGENTINFOSENTRY, 154 | __module__ = 'mlagents.envs.communicator_objects.unity_rl_output_pb2' 155 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput.AgentInfosEntry) 156 | )) 157 | , 158 | DESCRIPTOR = _UNITYRLOUTPUT, 159 | __module__ = 'mlagents.envs.communicator_objects.unity_rl_output_pb2' 160 | # @@protoc_insertion_point(class_scope:communicator_objects.UnityRLOutput) 161 | )) 162 | _sym_db.RegisterMessage(UnityRLOutput) 163 | _sym_db.RegisterMessage(UnityRLOutput.ListAgentInfoProto) 164 | 
_sym_db.RegisterMessage(UnityRLOutput.AgentInfosEntry) 165 | 166 | 167 | DESCRIPTOR._options = None 168 | _UNITYRLOUTPUT_AGENTINFOSENTRY._options = None 169 | # @@protoc_insertion_point(module_scope) 170 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_to_external_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: mlagents/envs/communicator_objects/unity_to_external.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from mlagents.envs.communicator_objects import unity_message_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='mlagents/envs/communicator_objects/unity_to_external.proto', 20 | package='communicator_objects', 21 | syntax='proto3', 22 | serialized_options=_b('\252\002\034MLAgents.CommunicatorObjects'), 23 | serialized_pb=_b('\n:mlagents/envs/communicator_objects/unity_to_external.proto\x12\x14\x63ommunicator_objects\x1a\x36mlagents/envs/communicator_objects/unity_message.proto2g\n\x0fUnityToExternal\x12T\n\x08\x45xchange\x12\".communicator_objects.UnityMessage\x1a\".communicator_objects.UnityMessage\"\x00\x42\x1f\xaa\x02\x1cMLAgents.CommunicatorObjectsb\x06proto3') 24 | , 25 | dependencies=[mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 30 | 31 | 32 | DESCRIPTOR._options = None 33 | 34 | _UNITYTOEXTERNAL = _descriptor.ServiceDescriptor( 35 | name='UnityToExternal', 36 | full_name='communicator_objects.UnityToExternal', 37 | file=DESCRIPTOR, 38 | index=0, 39 | serialized_options=None, 40 | serialized_start=140, 41 | serialized_end=243, 42 | methods=[ 43 | _descriptor.MethodDescriptor( 44 | name='Exchange', 45 | full_name='communicator_objects.UnityToExternal.Exchange', 46 | index=0, 47 | containing_service=None, 48 | input_type=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 49 | output_type=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2._UNITYMESSAGE, 50 | serialized_options=None, 51 | ), 52 | ]) 53 | _sym_db.RegisterServiceDescriptor(_UNITYTOEXTERNAL) 54 | 55 | DESCRIPTOR.services_by_name['UnityToExternal'] = _UNITYTOEXTERNAL 56 | 57 | # @@protoc_insertion_point(module_scope) 58 | -------------------------------------------------------------------------------- /mlagents/envs/communicator_objects/unity_to_external_pb2_grpc.py: -------------------------------------------------------------------------------- 1 | # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 2 | import grpc 3 | 4 | from mlagents.envs.communicator_objects import unity_message_pb2 as mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2 5 | 6 | 7 | class UnityToExternalStub(object): 8 | # missing associated documentation comment in .proto file 9 | pass 10 | 11 | def __init__(self, channel): 12 | """Constructor. 13 | 14 | Args: 15 | channel: A grpc.Channel. 
16 | """ 17 | self.Exchange = channel.unary_unary( 18 | '/communicator_objects.UnityToExternal/Exchange', 19 | request_serializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 20 | response_deserializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 21 | ) 22 | 23 | 24 | class UnityToExternalServicer(object): 25 | # missing associated documentation comment in .proto file 26 | pass 27 | 28 | def Exchange(self, request, context): 29 | """Sends the academy parameters 30 | """ 31 | context.set_code(grpc.StatusCode.UNIMPLEMENTED) 32 | context.set_details('Method not implemented!') 33 | raise NotImplementedError('Method not implemented!') 34 | 35 | 36 | def add_UnityToExternalServicer_to_server(servicer, server): 37 | rpc_method_handlers = { 38 | 'Exchange': grpc.unary_unary_rpc_method_handler( 39 | servicer.Exchange, 40 | request_deserializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.FromString, 41 | response_serializer=mlagents_dot_envs_dot_communicator__objects_dot_unity__message__pb2.UnityMessage.SerializeToString, 42 | ), 43 | } 44 | generic_handler = grpc.method_handlers_generic_handler( 45 | 'communicator_objects.UnityToExternal', rpc_method_handlers) 46 | server.add_generic_rpc_handlers((generic_handler,)) 47 | -------------------------------------------------------------------------------- /mlagents/envs/exception.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger("mlagents.envs") 3 | 4 | class UnityException(Exception): 5 | """ 6 | Any error related to ml-agents environment. 7 | """ 8 | pass 9 | 10 | class UnityEnvironmentException(UnityException): 11 | """ 12 | Related to errors starting and closing environment. 13 | """ 14 | pass 15 | 16 | 17 | class UnityActionException(UnityException): 18 | """ 19 | Related to errors with sending actions. 20 | """ 21 | pass 22 | 23 | class UnityTimeOutException(UnityException): 24 | """ 25 | Related to errors with communication timeouts. 26 | """ 27 | def __init__(self, message, log_file_path = None): 28 | if log_file_path is not None: 29 | try: 30 | with open(log_file_path, "r") as f: 31 | printing = False 32 | unity_error = '\n' 33 | for l in f: 34 | l=l.strip() 35 | if (l == 'Exception') or (l=='Error'): 36 | printing = True 37 | unity_error += '----------------------\n' 38 | if (l == ''): 39 | printing = False 40 | if printing: 41 | unity_error += l + '\n' 42 | logger.info(unity_error) 43 | logger.error("An error might have occured in the environment. " 44 | "You can check the logfile for more information at {}".format(log_file_path)) 45 | except: 46 | logger.error("An error might have occured in the environment. " 47 | "No UnitySDK.log file could be found.") 48 | super(UnityTimeOutException, self).__init__(message) 49 | 50 | 51 | class UnityWorkerInUseException(UnityException): 52 | """ 53 | This error occurs when the port for a certain worker ID is already reserved. 54 | """ 55 | 56 | MESSAGE_TEMPLATE = ( 57 | "Couldn't start socket communication because worker number {} is still in use. 
" 58 | "You may need to manually close a previously opened environment " 59 | "or use a different worker number.") 60 | 61 | def __init__(self, worker_id): 62 | message = self.MESSAGE_TEMPLATE.format(str(worker_id)) 63 | super(UnityWorkerInUseException, self).__init__(message) 64 | -------------------------------------------------------------------------------- /mlagents/envs/mock_communicator.py: -------------------------------------------------------------------------------- 1 | from .communicator import Communicator 2 | from .communicator_objects import UnityOutput, UnityInput, \ 3 | ResolutionProto, BrainParametersProto, UnityRLInitializationOutput, \ 4 | AgentInfoProto, UnityRLOutput 5 | 6 | 7 | class MockCommunicator(Communicator): 8 | def __init__(self, discrete_action=False, visual_inputs=0, stack=True, num_agents=3, 9 | brain_name="RealFakeBrain", vec_obs_size=3): 10 | """ 11 | Python side of the grpc communication. Python is the client and Unity the server 12 | 13 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 14 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 15 | """ 16 | self.is_discrete = discrete_action 17 | self.steps = 0 18 | self.visual_inputs = visual_inputs 19 | self.has_been_closed = False 20 | self.num_agents = num_agents 21 | self.brain_name = brain_name 22 | self.vec_obs_size = vec_obs_size 23 | if stack: 24 | self.num_stacks = 2 25 | else: 26 | self.num_stacks = 1 27 | 28 | def initialize(self, inputs: UnityInput) -> UnityOutput: 29 | resolutions = [ResolutionProto( 30 | width=30, 31 | height=40, 32 | gray_scale=False) for i in range(self.visual_inputs)] 33 | bp = BrainParametersProto( 34 | vector_observation_size=self.vec_obs_size, 35 | num_stacked_vector_observations=self.num_stacks, 36 | vector_action_size=[2], 37 | camera_resolutions=resolutions, 38 | vector_action_descriptions=["", ""], 39 | vector_action_space_type=int(not self.is_discrete), 40 | brain_name=self.brain_name, 41 | is_training=True 42 | ) 43 | rl_init = UnityRLInitializationOutput( 44 | name="RealFakeAcademy", 45 | version="API-8", 46 | log_path="", 47 | brain_parameters=[bp] 48 | ) 49 | return UnityOutput( 50 | rl_initialization_output=rl_init 51 | ) 52 | 53 | def exchange(self, inputs: UnityInput) -> UnityOutput: 54 | dict_agent_info = {} 55 | if self.is_discrete: 56 | vector_action = [1] 57 | else: 58 | vector_action = [1, 2] 59 | list_agent_info = [] 60 | if self.num_stacks == 1: 61 | observation = [1, 2, 3] 62 | else: 63 | observation = [1, 2, 3, 1, 2, 3] 64 | 65 | for i in range(self.num_agents): 66 | list_agent_info.append( 67 | AgentInfoProto( 68 | stacked_vector_observation=observation, 69 | reward=1, 70 | stored_vector_actions=vector_action, 71 | stored_text_actions="", 72 | text_observation="", 73 | memories=[], 74 | done=(i == 2), 75 | max_step_reached=False, 76 | id=i 77 | )) 78 | dict_agent_info["RealFakeBrain"] = \ 79 | UnityRLOutput.ListAgentInfoProto(value=list_agent_info) 80 | global_done = False 81 | try: 82 | fake_brain = inputs.rl_input.agent_actions["RealFakeBrain"] 83 | global_done = (fake_brain.value[0].vector_actions[0] == -1) 84 | except: 85 | pass 86 | result = UnityRLOutput( 87 | global_done=global_done, 88 | agentInfos=dict_agent_info 89 | ) 90 | return UnityOutput( 91 | rl_output=result 92 | ) 93 | 94 | def close(self): 95 | """ 96 | Sends a shutdown signal to the unity environment, and closes the grpc connection. 
97 | """ 98 | self.has_been_closed = True 99 | -------------------------------------------------------------------------------- /mlagents/envs/rpc_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import grpc 3 | 4 | import socket 5 | from multiprocessing import Pipe 6 | from concurrent.futures import ThreadPoolExecutor 7 | 8 | from .communicator import Communicator 9 | from .communicator_objects import UnityToExternalServicer, add_UnityToExternalServicer_to_server 10 | from .communicator_objects import UnityMessage, UnityInput, UnityOutput 11 | from .exception import UnityTimeOutException, UnityWorkerInUseException 12 | 13 | logger = logging.getLogger("mlagents.envs") 14 | 15 | 16 | class UnityToExternalServicerImplementation(UnityToExternalServicer): 17 | def __init__(self): 18 | self.parent_conn, self.child_conn = Pipe() 19 | 20 | def Initialize(self, request, context): 21 | self.child_conn.send(request) 22 | return self.child_conn.recv() 23 | 24 | def Exchange(self, request, context): 25 | self.child_conn.send(request) 26 | return self.child_conn.recv() 27 | 28 | 29 | class RpcCommunicator(Communicator): 30 | def __init__(self, worker_id=0, base_port=5005, timeout_wait=30): 31 | """ 32 | Python side of the grpc communication. Python is the server and Unity the client 33 | 34 | 35 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 36 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 37 | """ 38 | self.port = base_port + worker_id 39 | self.worker_id = worker_id 40 | self.timeout_wait = timeout_wait 41 | self.server = None 42 | self.unity_to_external = None 43 | self.is_open = False 44 | self.create_server() 45 | 46 | def create_server(self): 47 | """ 48 | Creates the GRPC server. 49 | """ 50 | self.check_port(self.port) 51 | 52 | try: 53 | # Establish communication grpc 54 | self.server = grpc.server(ThreadPoolExecutor(max_workers=10)) 55 | self.unity_to_external = UnityToExternalServicerImplementation() 56 | add_UnityToExternalServicer_to_server(self.unity_to_external, self.server) 57 | # Using unspecified address, which means that grpc is communicating on all IPs 58 | # This is so that the docker container can connect. 59 | self.server.add_insecure_port('[::]:' + str(self.port)) 60 | self.server.start() 61 | self.is_open = True 62 | except: 63 | raise UnityWorkerInUseException(self.worker_id) 64 | 65 | def check_port(self, port): 66 | """ 67 | Attempts to bind to the requested communicator port, checking if it is already in use. 68 | """ 69 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 70 | try: 71 | s.bind(("localhost", port)) 72 | except socket.error: 73 | raise UnityWorkerInUseException(self.worker_id) 74 | finally: 75 | s.close() 76 | 77 | def initialize(self, inputs: UnityInput) -> UnityOutput: 78 | if not self.unity_to_external.parent_conn.poll(self.timeout_wait): 79 | raise UnityTimeOutException( 80 | "The Unity environment took too long to respond. 
Make sure that :\n" 81 | "\t The environment does not need user interaction to launch\n" 82 | "\t The Academy's Broadcast Hub is configured correctly\n" 83 | "\t The Agents are linked to the appropriate Brains\n" 84 | "\t The environment and the Python interface have compatible versions.") 85 | aca_param = self.unity_to_external.parent_conn.recv().unity_output 86 | message = UnityMessage() 87 | message.header.status = 200 88 | message.unity_input.CopyFrom(inputs) 89 | self.unity_to_external.parent_conn.send(message) 90 | self.unity_to_external.parent_conn.recv() 91 | return aca_param 92 | 93 | def exchange(self, inputs: UnityInput) -> UnityOutput: 94 | message = UnityMessage() 95 | message.header.status = 200 96 | message.unity_input.CopyFrom(inputs) 97 | self.unity_to_external.parent_conn.send(message) 98 | output = self.unity_to_external.parent_conn.recv() 99 | if output.header.status != 200: 100 | return None 101 | return output.unity_output 102 | 103 | def close(self): 104 | """ 105 | Sends a shutdown signal to the unity environment, and closes the grpc connection. 106 | """ 107 | if self.is_open: 108 | message_input = UnityMessage() 109 | message_input.header.status = 400 110 | self.unity_to_external.parent_conn.send(message_input) 111 | self.unity_to_external.parent_conn.close() 112 | self.server.stop(False) 113 | self.is_open = False 114 | -------------------------------------------------------------------------------- /mlagents/envs/socket_communicator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import socket 3 | import struct 4 | 5 | from .communicator import Communicator 6 | from .communicator_objects import UnityMessage, UnityOutput, UnityInput 7 | from .exception import UnityTimeOutException 8 | 9 | 10 | logger = logging.getLogger("mlagents.envs") 11 | 12 | 13 | class SocketCommunicator(Communicator): 14 | def __init__(self, worker_id=0, 15 | base_port=5005): 16 | """ 17 | Python side of the socket communication 18 | 19 | :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. 20 | :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios. 21 | """ 22 | 23 | self.port = base_port + worker_id 24 | self._buffer_size = 12000 25 | self.worker_id = worker_id 26 | self._socket = None 27 | self._conn = None 28 | 29 | def initialize(self, inputs: UnityInput) -> UnityOutput: 30 | try: 31 | # Establish communication socket 32 | self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 33 | self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 34 | self._socket.bind(("localhost", self.port)) 35 | except: 36 | raise UnityTimeOutException("Couldn't start socket communication because worker number {} is still in use. " 37 | "You may need to manually close a previously opened environment " 38 | "or use a different worker number.".format(str(self.worker_id))) 39 | try: 40 | self._socket.settimeout(30) 41 | self._socket.listen(1) 42 | self._conn, _ = self._socket.accept() 43 | self._conn.settimeout(30) 44 | except : 45 | raise UnityTimeOutException( 46 | "The Unity environment took too long to respond. 
Make sure that :\n" 47 | "\t The environment does not need user interaction to launch\n" 48 | "\t The Academy's Broadcast Hub is configured correctly\n" 49 | "\t The Agents are linked to the appropriate Brains\n" 50 | "\t The environment and the Python interface have compatible versions.") 51 | message = UnityMessage() 52 | message.header.status = 200 53 | message.unity_input.CopyFrom(inputs) 54 | self._communicator_send(message.SerializeToString()) 55 | initialization_output = UnityMessage() 56 | initialization_output.ParseFromString(self._communicator_receive()) 57 | return initialization_output.unity_output 58 | 59 | def _communicator_receive(self): 60 | try: 61 | s = self._conn.recv(self._buffer_size) 62 | message_length = struct.unpack("I", bytearray(s[:4]))[0] 63 | s = s[4:] 64 | while len(s) != message_length: 65 | s += self._conn.recv(self._buffer_size) 66 | except socket.timeout as e: 67 | raise UnityTimeOutException("The environment took too long to respond.") 68 | return s 69 | 70 | def _communicator_send(self, message): 71 | self._conn.send(struct.pack("I", len(message)) + message) 72 | 73 | def exchange(self, inputs: UnityInput) -> UnityOutput: 74 | message = UnityMessage() 75 | message.header.status = 200 76 | message.unity_input.CopyFrom(inputs) 77 | self._communicator_send(message.SerializeToString()) 78 | outputs = UnityMessage() 79 | outputs.ParseFromString(self._communicator_receive()) 80 | if outputs.header.status != 200: 81 | return None 82 | return outputs.unity_output 83 | 84 | def close(self): 85 | """ 86 | Sends a shutdown signal to the unity environment, and closes the socket connection. 87 | """ 88 | if self._socket is not None and self._conn is not None: 89 | message_input = UnityMessage() 90 | message_input.header.status = 400 91 | self._communicator_send(message_input.SerializeToString()) 92 | if self._socket is not None: 93 | self._socket.close() 94 | self._socket = None 95 | if self._conn is not None: 96 | self._conn.close() 97 | self._conn = None 98 | 99 | -------------------------------------------------------------------------------- /mlagents/envs/subprocess_environment.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import copy 3 | import numpy as np 4 | import cloudpickle 5 | 6 | from mlagents.envs import UnityEnvironment 7 | from multiprocessing import Process, Pipe 8 | from multiprocessing.connection import Connection 9 | from mlagents.envs.base_unity_environment import BaseUnityEnvironment 10 | from mlagents.envs import AllBrainInfo, UnityEnvironmentException 11 | 12 | 13 | class EnvironmentCommand(NamedTuple): 14 | name: str 15 | payload: Any = None 16 | 17 | 18 | class EnvironmentResponse(NamedTuple): 19 | name: str 20 | worker_id: int 21 | payload: Any 22 | 23 | 24 | class UnityEnvWorker(NamedTuple): 25 | process: Process 26 | worker_id: int 27 | conn: Connection 28 | 29 | def send(self, name: str, payload=None): 30 | try: 31 | cmd = EnvironmentCommand(name, payload) 32 | self.conn.send(cmd) 33 | except (BrokenPipeError, EOFError): 34 | raise KeyboardInterrupt 35 | 36 | def recv(self) -> EnvironmentResponse: 37 | try: 38 | response: EnvironmentResponse = self.conn.recv() 39 | return response 40 | except (BrokenPipeError, EOFError): 41 | raise KeyboardInterrupt 42 | 43 | def close(self): 44 | try: 45 | self.conn.send(EnvironmentCommand('close')) 46 | except (BrokenPipeError, EOFError): 47 | pass 48 | self.process.join() 49 | 50 | 51 | def worker(parent_conn: Connection, 
pickled_env_factory: str, worker_id: int): 52 | env_factory: Callable[[int], UnityEnvironment] = cloudpickle.loads(pickled_env_factory) 53 | env = env_factory(worker_id) 54 | 55 | def _send_response(cmd_name, payload): 56 | parent_conn.send( 57 | EnvironmentResponse(cmd_name, worker_id, payload) 58 | ) 59 | try: 60 | while True: 61 | cmd: EnvironmentCommand = parent_conn.recv() 62 | if cmd.name == 'step': 63 | vector_action, memory, text_action, value = cmd.payload 64 | all_brain_info = env.step(vector_action, memory, text_action, value) 65 | _send_response('step', all_brain_info) 66 | elif cmd.name == 'external_brains': 67 | _send_response('external_brains', env.external_brains) 68 | elif cmd.name == 'reset_parameters': 69 | _send_response('reset_parameters', env.reset_parameters) 70 | elif cmd.name == 'reset': 71 | all_brain_info = env.reset(cmd.payload[0], cmd.payload[1]) 72 | _send_response('reset', all_brain_info) 73 | elif cmd.name == 'global_done': 74 | _send_response('global_done', env.global_done) 75 | elif cmd.name == 'close': 76 | break 77 | except KeyboardInterrupt: 78 | print('UnityEnvironment worker: keyboard interrupt') 79 | finally: 80 | env.close() 81 | 82 | 83 | class SubprocessUnityEnvironment(BaseUnityEnvironment): 84 | def __init__(self, 85 | env_factory: Callable[[int], BaseUnityEnvironment], 86 | n_env: int = 1): 87 | self.envs = [] 88 | self.env_agent_counts = {} 89 | self.waiting = False 90 | for worker_id in range(n_env): 91 | self.envs.append(self.create_worker(worker_id, env_factory)) 92 | 93 | @staticmethod 94 | def create_worker( 95 | worker_id: int, 96 | env_factory: Callable[[int], BaseUnityEnvironment] 97 | ) -> UnityEnvWorker: 98 | parent_conn, child_conn = Pipe() 99 | 100 | # Need to use cloudpickle for the env factory function since function objects aren't picklable 101 | # on Windows as of Python 3.6. 102 | pickled_env_factory = cloudpickle.dumps(env_factory) 103 | child_process = Process(target=worker, args=(child_conn, pickled_env_factory, worker_id)) 104 | child_process.start() 105 | return UnityEnvWorker(child_process, worker_id, parent_conn) 106 | 107 | def step_async(self, vector_action, memory=None, text_action=None, value=None) -> None: 108 | if self.waiting: 109 | raise UnityEnvironmentException( 110 | 'Tried to take an environment step before the previous step has completed.' 111 | ) 112 | 113 | agent_counts_cum = {} 114 | for brain_name in self.env_agent_counts.keys(): 115 | agent_counts_cum[brain_name] = np.cumsum(self.env_agent_counts[brain_name]) 116 | 117 | # Split the actions provided by the previous set of agent counts, and send the step 118 | # commands to the workers. 
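# For example, with three workers and env_agent_counts = {'MockBrain': [1, 3, 5]} (the setup
# used in test_subprocess_unity_environment.py), np.cumsum gives [1, 4, 9], so worker 0 is sent
# vector_action['MockBrain'][0:1], worker 1 gets [1:4], and worker 2 gets [4:9].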
119 | for worker_id, env in enumerate(self.envs): 120 | env_actions = {} 121 | env_memory = {} 122 | env_text_action = {} 123 | env_value = {} 124 | for brain_name in self.env_agent_counts.keys(): 125 | start_ind = 0 126 | if worker_id > 0: 127 | start_ind = agent_counts_cum[brain_name][worker_id - 1] 128 | end_ind = agent_counts_cum[brain_name][worker_id] 129 | if vector_action.get(brain_name) is not None: 130 | env_actions[brain_name] = vector_action[brain_name][start_ind:end_ind] 131 | if memory and memory.get(brain_name) is not None: 132 | env_memory[brain_name] = memory[brain_name][start_ind:end_ind] 133 | if text_action and text_action.get(brain_name) is not None: 134 | env_text_action[brain_name] = text_action[brain_name][start_ind:end_ind] 135 | if value and value.get(brain_name) is not None: 136 | env_value[brain_name] = value[brain_name][start_ind:end_ind] 137 | 138 | env.send('step', (env_actions, env_memory, env_text_action, env_value)) 139 | self.waiting = True 140 | 141 | def step_await(self) -> AllBrainInfo: 142 | if not self.waiting: 143 | raise UnityEnvironmentException('Tried to await an environment step, but no async step was taken.') 144 | 145 | steps = [self.envs[i].recv() for i in range(len(self.envs))] 146 | self._get_agent_counts(map(lambda s: s.payload, steps)) 147 | combined_brain_info = self._merge_step_info(steps) 148 | self.waiting = False 149 | return combined_brain_info 150 | 151 | def step(self, vector_action=None, memory=None, text_action=None, value=None) -> AllBrainInfo: 152 | self.step_async(vector_action, memory, text_action, value) 153 | return self.step_await() 154 | 155 | def reset(self, config=None, train_mode=True) -> AllBrainInfo: 156 | self._broadcast_message('reset', (config, train_mode)) 157 | reset_results = [self.envs[i].recv() for i in range(len(self.envs))] 158 | self._get_agent_counts(map(lambda r: r.payload, reset_results)) 159 | 160 | return self._merge_step_info(reset_results) 161 | 162 | @property 163 | def global_done(self): 164 | self._broadcast_message('global_done') 165 | dones: List[EnvironmentResponse] = [ 166 | self.envs[i].recv().payload for i in range(len(self.envs)) 167 | ] 168 | return all(dones) 169 | 170 | @property 171 | def external_brains(self): 172 | self.envs[0].send('external_brains') 173 | return self.envs[0].recv().payload 174 | 175 | @property 176 | def reset_parameters(self): 177 | self.envs[0].send('reset_parameters') 178 | return self.envs[0].recv().payload 179 | 180 | def close(self): 181 | for env in self.envs: 182 | env.close() 183 | 184 | def _get_agent_counts(self, step_list: Iterable[AllBrainInfo]): 185 | for i, step in enumerate(step_list): 186 | for brain_name, brain_info in step.items(): 187 | if brain_name not in self.env_agent_counts.keys(): 188 | self.env_agent_counts[brain_name] = [0] * len(self.envs) 189 | self.env_agent_counts[brain_name][i] = len(brain_info.agents) 190 | 191 | @staticmethod 192 | def _merge_step_info(env_steps: List[EnvironmentResponse]) -> AllBrainInfo: 193 | accumulated_brain_info: AllBrainInfo = None 194 | for env_step in env_steps: 195 | all_brain_info: AllBrainInfo = env_step.payload 196 | for brain_name, brain_info in all_brain_info.items(): 197 | for i in range(len(brain_info.agents)): 198 | brain_info.agents[i] = str(env_step.worker_id) + '-' + str(brain_info.agents[i]) 199 | if accumulated_brain_info: 200 | accumulated_brain_info[brain_name].merge(brain_info) 201 | if not accumulated_brain_info: 202 | accumulated_brain_info = copy.deepcopy(all_brain_info) 203 | return 
accumulated_brain_info 204 | 205 | def _broadcast_message(self, name: str, payload = None): 206 | for env in self.envs: 207 | env.send(name, payload) -------------------------------------------------------------------------------- /mlagents/envs/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StepNeverStop/RLwithUnity/f82be902e327b25c9884d8afe8b892f013197c22/mlagents/envs/tests/__init__.py -------------------------------------------------------------------------------- /mlagents/envs/tests/test_envs.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | import pytest 3 | 4 | import numpy as np 5 | 6 | from mlagents.envs import UnityEnvironment, UnityEnvironmentException, UnityActionException, \ 7 | BrainInfo 8 | from mlagents.envs.mock_communicator import MockCommunicator 9 | 10 | 11 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 12 | def test_handles_bad_filename(get_communicator): 13 | with pytest.raises(UnityEnvironmentException): 14 | UnityEnvironment(' ') 15 | 16 | 17 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 18 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 19 | def test_initialization(mock_communicator, mock_launcher): 20 | mock_communicator.return_value = MockCommunicator( 21 | discrete_action=False, visual_inputs=0) 22 | env = UnityEnvironment(' ') 23 | with pytest.raises(UnityActionException): 24 | env.step([0]) 25 | assert env.brain_names[0] == 'RealFakeBrain' 26 | env.close() 27 | 28 | 29 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 30 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 31 | def test_reset(mock_communicator, mock_launcher): 32 | mock_communicator.return_value = MockCommunicator( 33 | discrete_action=False, visual_inputs=0) 34 | env = UnityEnvironment(' ') 35 | brain = env.brains['RealFakeBrain'] 36 | brain_info = env.reset() 37 | env.close() 38 | assert not env.global_done 39 | assert isinstance(brain_info, dict) 40 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo) 41 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) 42 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) 43 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations 44 | assert len(brain_info['RealFakeBrain'].vector_observations) == \ 45 | len(brain_info['RealFakeBrain'].agents) 46 | assert len(brain_info['RealFakeBrain'].vector_observations[0]) == \ 47 | brain.vector_observation_space_size * brain.num_stacked_vector_observations 48 | 49 | 50 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 51 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 52 | def test_step(mock_communicator, mock_launcher): 53 | mock_communicator.return_value = MockCommunicator( 54 | discrete_action=False, visual_inputs=0) 55 | env = UnityEnvironment(' ') 56 | brain = env.brains['RealFakeBrain'] 57 | brain_info = env.reset() 58 | brain_info = env.step([0] * brain.vector_action_space_size[0] * len(brain_info['RealFakeBrain'].agents)) 59 | with pytest.raises(UnityActionException): 60 | env.step([0]) 61 | brain_info = env.step([-1] * brain.vector_action_space_size[0] * len(brain_info['RealFakeBrain'].agents)) 62 | with pytest.raises(UnityActionException): 63 | env.step([0] * brain.vector_action_space_size[0] * 
len(brain_info['RealFakeBrain'].agents)) 64 | env.close() 65 | assert env.global_done 66 | assert isinstance(brain_info, dict) 67 | assert isinstance(brain_info['RealFakeBrain'], BrainInfo) 68 | assert isinstance(brain_info['RealFakeBrain'].visual_observations, list) 69 | assert isinstance(brain_info['RealFakeBrain'].vector_observations, np.ndarray) 70 | assert len(brain_info['RealFakeBrain'].visual_observations) == brain.number_visual_observations 71 | assert len(brain_info['RealFakeBrain'].vector_observations) == \ 72 | len(brain_info['RealFakeBrain'].agents) 73 | assert len(brain_info['RealFakeBrain'].vector_observations[0]) == \ 74 | brain.vector_observation_space_size * brain.num_stacked_vector_observations 75 | 76 | print("\n\n\n\n\n\n\n" + str(brain_info['RealFakeBrain'].local_done)) 77 | assert not brain_info['RealFakeBrain'].local_done[0] 78 | assert brain_info['RealFakeBrain'].local_done[2] 79 | 80 | 81 | @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') 82 | @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') 83 | def test_close(mock_communicator, mock_launcher): 84 | comm = MockCommunicator( 85 | discrete_action=False, visual_inputs=0) 86 | mock_communicator.return_value = comm 87 | env = UnityEnvironment(' ') 88 | assert env._loaded 89 | env.close() 90 | assert not env._loaded 91 | assert comm.has_been_closed 92 | 93 | 94 | if __name__ == '__main__': 95 | pytest.main() 96 | -------------------------------------------------------------------------------- /mlagents/envs/tests/test_rpc_communicator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from mlagents.envs import RpcCommunicator 4 | from mlagents.envs import UnityWorkerInUseException 5 | 6 | 7 | def test_rpc_communicator_checks_port_on_create(): 8 | first_comm = RpcCommunicator() 9 | with pytest.raises(UnityWorkerInUseException): 10 | second_comm = RpcCommunicator() 11 | second_comm.close() 12 | first_comm.close() 13 | 14 | 15 | def test_rpc_communicator_close(): 16 | # Ensures it is possible to open a new RPC Communicators 17 | # after closing one on the same worker_id 18 | first_comm = RpcCommunicator() 19 | first_comm.close() 20 | second_comm = RpcCommunicator() 21 | second_comm.close() 22 | 23 | 24 | def test_rpc_communicator_create_multiple_workers(): 25 | # Ensures multiple RPC communicators can be created with 26 | # different worker_ids without causing an error. 
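# Each RpcCommunicator binds base_port + worker_id (5005 by default), so the second
# communicator below listens on 5006 and does not collide with the first.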
27 | first_comm = RpcCommunicator() 28 | second_comm = RpcCommunicator(worker_id=1) 29 | first_comm.close() 30 | second_comm.close() 31 | 32 | -------------------------------------------------------------------------------- /mlagents/envs/tests/test_subprocess_unity_environment.py: -------------------------------------------------------------------------------- 1 | import unittest.mock as mock 2 | from unittest.mock import MagicMock 3 | import unittest 4 | 5 | from mlagents.envs.subprocess_environment import * 6 | from mlagents.envs import UnityEnvironmentException, BrainInfo 7 | 8 | 9 | def mock_env_factory(worker_id: int): 10 | return mock.create_autospec(spec=BaseUnityEnvironment) 11 | 12 | 13 | class MockEnvWorker: 14 | def __init__(self, worker_id): 15 | self.worker_id = worker_id 16 | self.process = None 17 | self.conn = None 18 | self.send = MagicMock() 19 | self.recv = MagicMock() 20 | 21 | 22 | class SubprocessEnvironmentTest(unittest.TestCase): 23 | def test_environments_are_created(self): 24 | SubprocessUnityEnvironment.create_worker = MagicMock() 25 | env = SubprocessUnityEnvironment(mock_env_factory, 2) 26 | # Creates two processes 27 | self.assertEqual(env.create_worker.call_args_list, [ 28 | mock.call(0, mock_env_factory), 29 | mock.call(1, mock_env_factory) 30 | ]) 31 | self.assertEqual(len(env.envs), 2) 32 | 33 | def test_step_async_fails_when_waiting(self): 34 | env = SubprocessUnityEnvironment(mock_env_factory, 0) 35 | env.waiting = True 36 | with self.assertRaises(UnityEnvironmentException): 37 | env.step_async(vector_action=[]) 38 | 39 | @staticmethod 40 | def test_step_async_splits_input_by_agent_count(): 41 | env = SubprocessUnityEnvironment(mock_env_factory, 0) 42 | env.env_agent_counts = { 43 | 'MockBrain': [1, 3, 5] 44 | } 45 | env.envs = [ 46 | MockEnvWorker(0), 47 | MockEnvWorker(1), 48 | MockEnvWorker(2), 49 | ] 50 | env_0_actions = [[1.0, 2.0]] 51 | env_1_actions = ([[3.0, 4.0]] * 3) 52 | env_2_actions = ([[5.0, 6.0]] * 5) 53 | vector_action = { 54 | 'MockBrain': env_0_actions + env_1_actions + env_2_actions 55 | } 56 | env.step_async(vector_action=vector_action) 57 | env.envs[0].send.assert_called_with('step', ({'MockBrain': env_0_actions}, {}, {}, {})) 58 | env.envs[1].send.assert_called_with('step', ({'MockBrain': env_1_actions}, {}, {}, {})) 59 | env.envs[2].send.assert_called_with('step', ({'MockBrain': env_2_actions}, {}, {}, {})) 60 | 61 | def test_step_async_sets_waiting(self): 62 | env = SubprocessUnityEnvironment(mock_env_factory, 0) 63 | env.step_async(vector_action=[]) 64 | self.assertTrue(env.waiting) 65 | 66 | def test_step_await_fails_if_not_waiting(self): 67 | env = SubprocessUnityEnvironment(mock_env_factory, 0) 68 | with self.assertRaises(UnityEnvironmentException): 69 | env.step_await() 70 | 71 | def test_step_await_combines_brain_info(self): 72 | all_brain_info_env0 = { 73 | 'MockBrain': BrainInfo([], [[1.0, 2.0], [1.0, 2.0]], [], agents=[1, 2], memory=np.zeros((0,0))) 74 | } 75 | all_brain_info_env1 = { 76 | 'MockBrain': BrainInfo([], [[3.0, 4.0]], [], agents=[3], memory=np.zeros((0,0))) 77 | } 78 | env_worker_0 = MockEnvWorker(0) 79 | env_worker_0.recv.return_value = EnvironmentResponse('step', 0, all_brain_info_env0) 80 | env_worker_1 = MockEnvWorker(1) 81 | env_worker_1.recv.return_value = EnvironmentResponse('step', 1, all_brain_info_env1) 82 | env = SubprocessUnityEnvironment(mock_env_factory, 0) 83 | env.envs = [env_worker_0, env_worker_1] 84 | env.waiting = True 85 | combined_braininfo = env.step_await()['MockBrain'] 86 | 
self.assertEqual( 87 | combined_braininfo.vector_observations.tolist(), 88 | [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0]] 89 | ) 90 | self.assertEqual(combined_braininfo.agents, ['0-1', '0-2', '1-3']) 91 | -------------------------------------------------------------------------------- /ppo/ppo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from mlagents.envs import UnityEnvironment 4 | 5 | initKernelAndBias={ 6 | 'kernel_initializer' : tf.random_normal_initializer(0., .1), 7 | 'bias_initializer' : tf.constant_initializer(0.1, dtype=tf.float32) 8 | } 9 | 10 | EPSILON = 0.2 11 | ACTOR_LEARNING_RATE = 0.001 12 | CRITIC_LEARNING_RATE = 0.001 13 | GAMMA = 0.99 14 | LAMBDA = 1.0 15 | MAX_EP = 5000 16 | MAX_STEP = 200 17 | BATCHSIZE = 4 18 | LEARN_COUNTS = 10 19 | 20 | class Actor(object): 21 | def __init__(self, sess, s_dim, a_counts): 22 | self.sess = sess 23 | 24 | self.s_dim = s_dim 25 | self.a_counts = a_counts 26 | self.s = tf.placeholder(tf.float32, [None, s_dim], "state") 27 | self.a = tf.placeholder(tf.float32, [None, a_counts], "action") 28 | self.advantage = tf.placeholder(tf.float32, [None, 1], "advantage") 29 | 30 | self.new_norm_dist, self.new_actor_vars = self._build_actor_net('ActorNew', trainable=True) 31 | self.old_norm_dist, self.old_actor_vars = self._build_actor_net('ActorOld', trainable=False) 32 | 33 | self.sample_op = self.new_norm_dist.sample() 34 | # self.sample_op = self.old_norm_dist.sample() 35 | 36 | self.old_action_prob = self.old_norm_dist.prob(self.a) 37 | self.new_action_prob = self.new_norm_dist.prob(self.a) 38 | 39 | ratio = self.new_action_prob / self.old_action_prob 40 | surrogate = ratio * self.advantage 41 | self.actor_loss = -tf.reduce_mean(tf.minimum( 42 | surrogate, 43 | tf.clip_by_value(ratio, 1 - EPSILON, 1 + EPSILON) * self.advantage 44 | )) 45 | 46 | self.train_op = tf.train.AdamOptimizer(ACTOR_LEARNING_RATE).minimize(self.actor_loss) 47 | 48 | self.assign_new_to_old = [p.assign(oldp) for p, oldp in zip(self.new_actor_vars, self.old_actor_vars)] 49 | # self.assign_new_to_old = [tf.assign(oldp,p) for p, oldp in zip(self.new_actor_vars, self.old_actor_vars)] 50 | 51 | def _build_actor_net(self, name, trainable): 52 | with tf.variable_scope(name): 53 | layer1 = tf.layers.dense( 54 | inputs = self.s, 55 | units = 20, 56 | activation=tf.nn.relu, 57 | name='layer1', 58 | **initKernelAndBias, 59 | trainable=trainable 60 | ) 61 | layer2 = tf.layers.dense( 62 | inputs=layer1, 63 | units=20, 64 | activation=tf.nn.relu, 65 | name='layer2', 66 | **initKernelAndBias, 67 | trainable=trainable 68 | ) 69 | mu = tf.layers.dense( 70 | inputs=layer2, 71 | units=self.a_counts, 72 | activation=tf.nn.tanh, 73 | name='mu', 74 | **initKernelAndBias, 75 | trainable=trainable 76 | ) 77 | sigma = tf.layers.dense( 78 | inputs=layer2, 79 | units=self.a_counts, 80 | activation=tf.nn.softplus, 81 | name='delta', 82 | **initKernelAndBias, 83 | trainable=trainable 84 | ) 85 | norm_dist = tf.distributions.Normal(loc=mu, scale=sigma) 86 | var = tf.get_variable_scope().global_variables() 87 | return norm_dist, var 88 | 89 | def choose_action(self, s): 90 | return self.sess.run(self.sample_op, feed_dict={ 91 | self.s : s 92 | }) 93 | 94 | def learn(self, s, a, advantage): 95 | advantage = advantage[:, np.newaxis] 96 | self.sess.run(self.train_op, feed_dict={ 97 | self.s : s, 98 | self.a : a, 99 | self.advantage : advantage 100 | }) 101 | 102 | def assign_params(self): 103 | 
self.sess.run(self.assign_new_to_old) 104 | 105 | class Critic(object): 106 | def __init__(self, sess, s_dim): 107 | self.sess = sess 108 | 109 | self.s = tf.placeholder(tf.float32, [None, s_dim], "state") 110 | self.dc_r = tf.placeholder(tf.float32, [None, 1], 'discounted_r') 111 | 112 | self._build_critic_net('Critic') 113 | self.advantage = self.dc_r - self.v 114 | self.loss = tf.reduce_mean(tf.square(self.advantage)) 115 | self.train_op = tf.train.AdamOptimizer(CRITIC_LEARNING_RATE).minimize(self.loss) 116 | 117 | def _build_critic_net(self, name): 118 | with tf.variable_scope(name): 119 | layer1 = tf.layers.dense( 120 | inputs=self.s, 121 | units=30, 122 | activation=tf.nn.relu, 123 | name='layer1', 124 | **initKernelAndBias 125 | ) 126 | layer2 = tf.layers.dense( 127 | inputs=layer1, 128 | units=10, 129 | activation=tf.nn.relu, 130 | name='layer2', 131 | **initKernelAndBias 132 | ) 133 | self.v = tf.layers.dense( 134 | inputs=layer2, 135 | units=1, 136 | activation=None, 137 | name='values', 138 | **initKernelAndBias 139 | ) 140 | 141 | def get_state_value(self, s): 142 | return self.sess.run(self.v, feed_dict={ 143 | self.s : s 144 | }) 145 | 146 | #---------------deprecated 147 | def get_advantage(self, s): 148 | self.values = get_state_value(s) 149 | sub_advantage = tf.zeros_like(self.r) 150 | advantage = tf.zeros_like(self.r) 151 | for index in reversed(range(tf.shape(sub_advantage)[0])): 152 | sub_advantage[index]=self.r[index] + GAMMA * self.values[index+1] - self.values[index] 153 | tmp = 0 154 | for index in reversed(range(tf.shape(sub_advantage)[0])): 155 | tmp = tmp * LAMBDA * GAMMA + sub_advantage[index] 156 | advantage[index] = tmp 157 | return advantage 158 | 159 | def learn(self, s, dc_r): 160 | self.sess.run(self.train_op,feed_dict={ 161 | self.s : s, 162 | self.dc_r : dc_r 163 | }) 164 | 165 | def main(): 166 | env = UnityEnvironment() 167 | brain_name = env.brain_names[0] 168 | brain = env.brains[brain_name] 169 | print(brain.vector_observation_space_size) 170 | print(brain.vector_action_space_size) 171 | 172 | sess = tf.Session() 173 | 174 | actor = Actor( 175 | sess=sess, 176 | s_dim=brain.vector_observation_space_size, 177 | a_counts=brain.vector_action_space_size[0] 178 | ) 179 | critic = Critic( 180 | sess=sess, 181 | s_dim=brain.vector_observation_space_size 182 | ) 183 | 184 | sess.run(tf.global_variables_initializer()) 185 | 186 | for episode in range(MAX_EP): 187 | step = 0 188 | total_reward = 0. 
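# Roll out one episode; once it ends (or MAX_STEP is reached) with at least BATCHSIZE
# transitions, compute discounted returns and GAE-style advantages over sliding BATCHSIZE
# windows, then run LEARN_COUNTS actor and critic updates on each window.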
189 | discounted_reward = 0 190 | s, a, r, dc_r= [], [], [], [] 191 | obs = env.reset(train_mode=True)[brain_name] 192 | state = obs.vector_observations 193 | s.append(state[0]) 194 | while True: 195 | action = actor.choose_action(state) 196 | a.append(action[0]) 197 | obs = env.step(action)[brain_name] 198 | step += 1 199 | reward = obs.rewards 200 | r.append(reward[0]) 201 | state = obs.vector_observations 202 | done = obs.local_done[0] 203 | if done or step >= MAX_STEP: 204 | if len(s) < BATCHSIZE: 205 | break 206 | else: 207 | length = len(s) 208 | for index in reversed(range(length)): 209 | discounted_reward = discounted_reward * GAMMA + r[index] 210 | dc_r.append(discounted_reward) 211 | total_reward = dc_r[-1] 212 | s_ = list(reversed([s[index:index+BATCHSIZE] for index in range(length-BATCHSIZE+1)])) 213 | a_ = list(reversed([a[index:index+BATCHSIZE] for index in range(length-BATCHSIZE+1)])) 214 | r_ = list(reversed([r[index:index+BATCHSIZE] for index in range(length-BATCHSIZE+1)])) 215 | dc_r_ = list(reversed([dc_r[index:index+BATCHSIZE] for index in range(length-BATCHSIZE+1)])) 216 | for index in range(len(s_)): 217 | actor.assign_params() 218 | ss = np.array(s_[index]) 219 | aa = np.array(a_[index]) 220 | rr = np.array(r_[index]) 221 | dc_rr =np.array(dc_r_[index])[:, np.newaxis] 222 | values = critic.get_state_value(ss) 223 | value_ = critic.get_state_value(state) 224 | sub_advantage=np.zeros_like(rr) 225 | for index in reversed(range(np.shape(rr)[0])): 226 | sub_advantage[index] = rr[index] + GAMMA * value_ - values[index] 227 | value_ = values[index] 228 | tmp = 0 229 | advantage=np.zeros_like(sub_advantage) 230 | for index in reversed(range(np.shape(sub_advantage)[0])): 231 | tmp = tmp * LAMBDA * GAMMA + sub_advantage[index] 232 | advantage[index] = tmp 233 | 234 | [actor.learn(ss, aa, advantage) for _ in range(LEARN_COUNTS)] 235 | [critic.learn(ss, dc_rr) for _ in range(LEARN_COUNTS)] 236 | if done: 237 | break 238 | 239 | s, a, r, dc_r= [], [], [], [] 240 | s.append(state[0]) 241 | else: 242 | s.append(state[0]) 243 | print('episode: {0} steps: {1} reward: {2}'.format(episode, step, total_reward)) 244 | if __name__ == '__main__': 245 | main() -------------------------------------------------------------------------------- /ppo/ppopac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from mlagents.envs import UnityEnvironment 4 | 5 | initKernelAndBias={ 6 | 'kernel_initializer' : tf.random_normal_initializer(0., .1), 7 | 'bias_initializer' : tf.constant_initializer(0.1) 8 | } 9 | 10 | class Actor(object): 11 | def __init__(self, sess, observationDim, actionDim, learning_rate=0.001, update_frequency=10): 12 | self.sess=sess 13 | self.s=tf.placeholder(tf.float32, [1,observationDim],"state") 14 | self.a=tf.placeholder(tf.float32, [1,actionDim],"action") 15 | self.advantage=tf.placeholder(tf.float32,[1,1],"advantage") 16 | self.update_frequency=update_frequency 17 | 18 | with tf.variable_scope("ActorMain"): 19 | layer1 = tf.layers.dense( 20 | inputs=self.s, 21 | units=20, 22 | activation=tf.nn.relu, 23 | name='layer1', 24 | **initKernelAndBias 25 | ) 26 | 27 | layer2=tf.layers.dense( 28 | inputs=layer1, 29 | units=20, 30 | activation=tf.nn.relu, 31 | name='layer2', 32 | **initKernelAndBias 33 | ) 34 | 35 | self.mu = tf.layers.dense( 36 | inputs=layer2, 37 | units=actionDim, 38 | activation=None, 39 | name='mu', 40 | **initKernelAndBias 41 | ) 42 | self.norm_dist = 
tf.distributions.Normal(loc=self.mu,scale=[1.]*actionDim) 43 | self.var1=tf.get_variable_scope().global_variables() 44 | 45 | with tf.variable_scope("Actor2"): 46 | layer1 = tf.layers.dense( 47 | inputs=self.s, 48 | units=20, 49 | activation=tf.nn.relu, 50 | name='layer1', 51 | **initKernelAndBias, 52 | trainable=False 53 | ) 54 | 55 | layer2=tf.layers.dense( 56 | inputs=layer1, 57 | units=20, 58 | activation=tf.nn.relu, 59 | name='layer2', 60 | **initKernelAndBias, 61 | trainable=False 62 | ) 63 | self.mu = tf.layers.dense( 64 | inputs=layer2, 65 | units=actionDim, 66 | activation=None, 67 | name='mu', 68 | **initKernelAndBias, 69 | trainable=False 70 | ) 71 | self.norm_dist_behavior = tf.distributions.Normal(loc=self.mu,scale=[1.]*actionDim) 72 | self.sample_op = self.norm_dist_behavior.sample() 73 | self.var2=tf.get_variable_scope().global_variables() 74 | 75 | with tf.variable_scope('exp_v'): 76 | self.log_prob = self.norm_dist.log_prob(self.a) 77 | self.exp_v = tf.reduce_mean(self.log_prob*self.advantage) 78 | with tf.variable_scope('train'): 79 | self.train_op = tf.train.AdamOptimizer(learning_rate).minimize(-self.exp_v) 80 | with tf.variable_scope('assign'): 81 | self.assign_target_to_behavior=[tf.assign(r, v) for r, v in zip(self.var2, self.var1)] 82 | 83 | def choose_action(self, s): 84 | return self.sess.run(self.sample_op,feed_dict={ 85 | self.s:s 86 | }) 87 | def learn(self, s, a, advantage, step): 88 | if step % self.update_frequency == 0: 89 | self.sess.run([self.train_op, self.assign_target_to_behavior],feed_dict={ 90 | self.s:s, 91 | self.a:a, 92 | self.advantage:advantage 93 | }) 94 | else: 95 | self.sess.run(self.train_op,feed_dict={ 96 | self.s:s, 97 | self.a:a, 98 | self.advantage:advantage 99 | }) 100 | 101 | class Critic(object): 102 | def __init__(self, sess, observationDim, learning_rate=0.01, gamma=0.95): 103 | self.sess= sess 104 | 105 | self.s = tf.placeholder(tf.float32, [1,observationDim],"state") 106 | self.r = tf.placeholder(tf.float32, [1,1],"reward") 107 | self.v_ = tf.placeholder(tf.float32, [1,1], "value_of_next") 108 | 109 | with tf.variable_scope('Critic'): 110 | layer1 = tf.layers.dense( 111 | inputs=self.s, 112 | units=30, 113 | activation=tf.nn.relu, 114 | name='layer1', 115 | **initKernelAndBias 116 | ) 117 | layer2 = tf.layers.dense( 118 | inputs=layer1, 119 | units=10, 120 | activation=tf.nn.relu, 121 | name='layer2', 122 | **initKernelAndBias 123 | ) 124 | self.v = tf.layers.dense( 125 | inputs=layer2, 126 | units=1, 127 | activation=None, 128 | name='Value', 129 | **initKernelAndBias 130 | ) 131 | with tf.variable_scope('square_advantage'): 132 | self.advantage = tf.reduce_mean(self.r + gamma*self.v_-self.v) 133 | self.loss = tf.square(self.advantage) 134 | 135 | with tf.variable_scope('train'): 136 | self.train_op = tf.train.AdamOptimizer(learning_rate).minimize(self.loss) 137 | 138 | def learn(self, s, r, s_): 139 | v_ = self.sess.run(self.v, feed_dict={ 140 | self.s: s_ 141 | }) 142 | advantage, _ = self.sess.run([self.advantage, self.train_op], feed_dict={ 143 | self.s: s, 144 | self.v_: v_, 145 | self.r: r 146 | }) 147 | return advantage 148 | 149 | env = UnityEnvironment() 150 | brain_name = env.brain_names[0] 151 | brain = env.brains[brain_name] 152 | print(brain.vector_observation_space_size) 153 | print(brain.vector_action_space_size) 154 | 155 | sess = tf.Session() 156 | 157 | actor = Actor( 158 | sess=sess, 159 | observationDim=brain.vector_observation_space_size, 160 | actionDim=brain.vector_action_space_size[0], 161 | 
learning_rate=0.02, 162 | update_frequency=10 163 | ) 164 | critic = Critic( 165 | sess=sess, 166 | observationDim=brain.vector_observation_space_size, 167 | learning_rate=0.01, 168 | gamma=0.95 169 | ) 170 | 171 | sess.run(tf.global_variables_initializer()) 172 | 173 | time=0 174 | gamma=0.9 175 | for i_episode in range(5000): 176 | step=0 177 | discounted_reward=0 178 | observation = env.reset(train_mode=True)[brain_name] 179 | s=observation.vector_observations 180 | while True: 181 | time+=1 182 | step+=1 183 | action = np.squeeze(actor.choose_action(s), axis=0) 184 | # print(action) 185 | observation=env.step(action)[brain_name] 186 | 187 | reward=np.array(observation.rewards) 188 | discounted_reward*=gamma #有错 189 | discounted_reward+=reward[0] 190 | advantage = critic.learn(s,reward[np.newaxis,:],observation.vector_observations) 191 | advantage=[[advantage]] 192 | # print(advantage) 193 | actor.learn(s,action[np.newaxis,:],advantage, time) 194 | 195 | s=observation.vector_observations 196 | 197 | if observation.local_done[0]: 198 | print("episode:", i_episode," steps:", step," rewards:", discounted_reward) 199 | break -------------------------------------------------------------------------------- /sac/sac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.contrib.layers as c_layers 4 | 5 | initKernelAndBias = { 6 | 'kernel_initializer': tf.random_normal_initializer(0., .1), 7 | 'bias_initializer': tf.constant_initializer(0.1, dtype=tf.float32) 8 | } 9 | # initKernelAndBias={ 10 | # 'kernel_initializer' : c_layers.variance_scaling_initializer(1.0) 11 | # } 12 | 13 | 14 | class SAC(object): 15 | def __init__(self, sess, s_dim, a_counts, hyper_config): 16 | self.sess = sess 17 | self.s_dim = s_dim 18 | self.a_counts = a_counts 19 | self.activation_fn = tf.nn.tanh 20 | self.log_alpha = tf.get_variable( 21 | 'log_alpha', dtype=tf.float32, initializer=0.0) 22 | self.alpha = hyper_config['alpha'] if not hyper_config['auto_adaption'] else tf.exp( 23 | self.log_alpha) 24 | self.action_bound = hyper_config['action_bound'] 25 | self.assign_interval = hyper_config['assign_interval'] 26 | 27 | self.episode = tf.Variable(tf.constant(0)) 28 | self.lr = tf.train.polynomial_decay( 29 | hyper_config['lr'], self.episode, hyper_config['max_episode'], 1e-10, power=1.0) 30 | self.s = tf.placeholder(tf.float32, [None, self.s_dim], 'state') 31 | self.a = tf.placeholder(tf.float32, [None, self.a_counts], 'action') 32 | self.r = tf.placeholder(tf.float32, [None, 1], 'reward') 33 | self.s_ = tf.placeholder(tf.float32, [None, self.s_dim], 'next_state') 34 | self.sigma_offset = tf.placeholder( 35 | tf.float32, [self.a_counts, ], 'sigma_offset') 36 | 37 | self.norm_dist, self.a_new, self.log_prob = self._build_actor_net( 38 | 'actor_net') 39 | self.prob = self.norm_dist.prob(self.a_new) 40 | self.entropy = self.norm_dist.entropy() 41 | self.s_a = tf.concat((self.s, self.a), axis=1) 42 | self.s_a_new = tf.concat((self.s, self.a_new), axis=1) 43 | self.q1 = self._build_q_net('q1', self.s_a, False) 44 | self.q2 = self._build_q_net('q2', self.s_a, False) 45 | self.q1_anew = self._build_q_net('q1', self.s_a_new, True) 46 | self.q2_anew = self._build_q_net('q2', self.s_a_new, True) 47 | self.v_from_q = tf.minimum( 48 | self.q1_anew, self.q2_anew) - self.alpha * self.log_prob 49 | self.v_from_q_stop = tf.stop_gradient(self.v_from_q) 50 | self.v, self.v_var = self._build_v_net( 51 | 'v', input_vector=self.s, 
trainable=True) 52 | self.v_target, self.v_target_var = self._build_v_net( 53 | 'v_target', input_vector=self.s_, trainable=False) 54 | self.dc_r = tf.stop_gradient( 55 | self.r + hyper_config['gamma'] * self.v_target) 56 | 57 | self.q1_loss = tf.reduce_mean( 58 | tf.squared_difference(self.q1, self.dc_r)) 59 | self.q2_loss = tf.reduce_mean( 60 | tf.squared_difference(self.q2, self.dc_r)) 61 | self.v_loss = tf.reduce_mean( 62 | tf.squared_difference(self.v, self.v_from_q)) 63 | self.v_loss_stop = tf.reduce_mean( 64 | tf.squared_difference(self.v, self.v_from_q_stop)) 65 | self.critic_loss = 0.5 * self.q1_loss + 0.5 * \ 66 | self.q2_loss + 0.5 * self.v_loss_stop 67 | self.actor_loss = -tf.reduce_mean( 68 | self.q1_anew - self.alpha * self.log_prob) 69 | self.alpha_loss = -tf.reduce_mean( 70 | self.log_alpha * tf.stop_gradient(self.log_prob - self.a_counts)) 71 | 72 | q1_vars = tf.get_collection( 73 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='q1') 74 | q2_vars = tf.get_collection( 75 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='q2') 76 | value_vars = tf.get_collection( 77 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='v') 78 | actor_vars = tf.get_collection( 79 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor_net') 80 | 81 | optimizer = tf.train.AdamOptimizer(self.lr) 82 | self.train_q1 = optimizer.minimize(self.q1_loss, var_list=q1_vars) 83 | self.train_q2 = optimizer.minimize(self.q2_loss, var_list=q2_vars) 84 | self.train_v = optimizer.minimize(self.v_loss, var_list=value_vars) 85 | 86 | self.assign_v_target = tf.group([tf.assign( 87 | r, hyper_config['ployak'] * v + (1 - hyper_config['ployak']) * r) for r, v in zip(self.v_target_var, self.v_var)]) 88 | # self.assign_v_target = [ 89 | # tf.assign(r, 1/(self.episode+1) * v + (1-1/(self.episode+1)) * r) for r, v in zip(self.v_target_var, self.v_var)] 90 | with tf.control_dependencies([self.assign_v_target]): 91 | self.train_critic = optimizer.minimize(self.critic_loss) 92 | with tf.control_dependencies([self.train_critic]): 93 | self.train_actor = optimizer.minimize( 94 | self.actor_loss, var_list=actor_vars) 95 | with tf.control_dependencies([self.train_actor]): 96 | self.train_alpha = optimizer.minimize( 97 | self.alpha_loss, var_list=[self.log_alpha]) 98 | 99 | 100 | def _build_actor_net(self, name): 101 | with tf.variable_scope(name): 102 | actor1 = tf.layers.dense( 103 | inputs=self.s, 104 | units=128, 105 | activation=self.activation_fn, 106 | name='actor1', 107 | **initKernelAndBias 108 | ) 109 | actor2 = tf.layers.dense( 110 | inputs=actor1, 111 | units=64, 112 | activation=self.activation_fn, 113 | name='actor2', 114 | **initKernelAndBias 115 | ) 116 | self.mu = tf.layers.dense( 117 | inputs=actor2, 118 | units=self.a_counts, 119 | activation=tf.nn.tanh, 120 | name='mu', 121 | **initKernelAndBias 122 | ) 123 | sigma1 = tf.layers.dense( 124 | inputs=actor1, 125 | units=64, 126 | activation=self.activation_fn, 127 | name='simga1', 128 | **initKernelAndBias 129 | ) 130 | self.sigma = tf.layers.dense( 131 | inputs=sigma1, 132 | units=self.a_counts, 133 | activation=tf.nn.sigmoid, 134 | name='sigma', 135 | **initKernelAndBias 136 | ) 137 | norm_dist = tf.distributions.Normal( 138 | loc=self.mu, scale=self.sigma + self.sigma_offset) 139 | # action = tf.tanh(norm_dist.sample()) 140 | action = tf.clip_by_value( 141 | norm_dist.sample(), -self.action_bound, self.action_bound) / self.action_bound 142 | log_prob = norm_dist.log_prob(action) 143 | return norm_dist, action, log_prob 144 | 145 | def _build_q_net(self, name, input_vector, reuse): 146 | 
with tf.variable_scope(name): 147 | layer1 = tf.layers.dense( 148 | inputs=input_vector, 149 | units=256, 150 | activation=self.activation_fn, 151 | name='layer1', 152 | reuse=reuse, 153 | **initKernelAndBias 154 | ) 155 | layer2 = tf.layers.dense( 156 | inputs=layer1, 157 | units=256, 158 | activation=self.activation_fn, 159 | name='layer2', 160 | reuse=reuse, 161 | **initKernelAndBias 162 | ) 163 | q = tf.layers.dense( 164 | inputs=layer2, 165 | units=1, 166 | activation=None, 167 | name='q_value', 168 | reuse=reuse, 169 | **initKernelAndBias 170 | ) 171 | return q 172 | 173 | def _build_v_net(self, name, input_vector, trainable): 174 | with tf.variable_scope(name): 175 | layer1 = tf.layers.dense( 176 | inputs=input_vector, 177 | units=256, 178 | activation=self.activation_fn, 179 | name='layer1', 180 | trainable=trainable, 181 | **initKernelAndBias 182 | ) 183 | layer2 = tf.layers.dense( 184 | inputs=layer1, 185 | units=256, 186 | activation=self.activation_fn, 187 | name='layer2', 188 | trainable=trainable, 189 | **initKernelAndBias 190 | ) 191 | v = tf.layers.dense( 192 | inputs=layer2, 193 | units=1, 194 | activation=None, 195 | name='value', 196 | trainable=trainable, 197 | **initKernelAndBias 198 | ) 199 | var = tf.get_variable_scope().global_variables() 200 | return v, var 201 | 202 | def decay_lr(self, episode, **kargs): 203 | return self.sess.run(self.lr, feed_dict={ 204 | self.episode: episode 205 | }) 206 | 207 | def choose_action(self, s, sigma_offset, **kargs): 208 | return self.sess.run([self.prob, self.a_new], feed_dict={ 209 | self.s: s, 210 | self.sigma_offset: sigma_offset 211 | }) 212 | 213 | def choose_inference_action(self, s, sigma_offset, **kargs): 214 | return self.sess.run([self.prob, self.mu], feed_dict={ 215 | self.s: s, 216 | self.sigma_offset: sigma_offset 217 | }) 218 | 219 | def get_state_value(self, s, sigma_offset, **kargs): 220 | return np.squeeze(np.zeros(np.array(s).shape[0])) 221 | 222 | def learn(self, s, a, r, s_, episode, sigma_offset, **kargs): 223 | # self.sess.run([self.assign_v_target, self.train_q1, self.train_q2, self.train_v, self.train_actor], feed_dict={ 224 | # self.s: s, 225 | # self.a: a, 226 | # self.r: r, 227 | # self.s_: s_, 228 | # self.episode: episode, 229 | # self.sigma_offset: sigma_offset 230 | # }) 231 | self.sess.run([self.assign_v_target, self.train_critic, self.train_actor, self.train_alpha], feed_dict={ 232 | self.s: s, 233 | self.a: a, 234 | self.r: r, 235 | self.s_: s_, 236 | self.episode: episode, 237 | self.sigma_offset: sigma_offset 238 | }) 239 | 240 | def get_entropy(self, s, sigma_offset, **kargs): 241 | return self.sess.run(self.entropy, feed_dict={ 242 | self.s: s, 243 | self.sigma_offset: sigma_offset 244 | }) 245 | 246 | def get_actor_loss(self, s, sigma_offset, **kargs): 247 | return self.sess.run(self.actor_loss, feed_dict={ 248 | self.s: s, 249 | self.sigma_offset: sigma_offset 250 | }) 251 | 252 | def get_critic_loss(self, s, a, r, s_, sigma_offset, **kargs): 253 | return self.sess.run(self.critic_loss, feed_dict={ 254 | self.s: s, 255 | self.a: a, 256 | self.r: r, 257 | self.s_: s_, 258 | self.sigma_offset: sigma_offset 259 | }) 260 | 261 | def get_sigma(self, s, sigma_offset, **kargs): 262 | return self.sess.run(self.sigma, feed_dict={ 263 | self.s: s, 264 | self.sigma_offset: sigma_offset 265 | }) 266 | -------------------------------------------------------------------------------- /sac/sac_no_v.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 
| import tensorflow as tf 3 | import tensorflow.contrib.layers as c_layers 4 | 5 | initKernelAndBias = { 6 | 'kernel_initializer': tf.random_normal_initializer(0., .1), 7 | 'bias_initializer': tf.constant_initializer(0.1, dtype=tf.float32) 8 | } 9 | # initKernelAndBias={ 10 | # 'kernel_initializer' : c_layers.variance_scaling_initializer(1.0) 11 | # } 12 | 13 | 14 | class SAC_NO_V(object): 15 | def __init__(self, sess, s_dim, a_counts, hyper_config): 16 | self.sess = sess 17 | self.s_dim = s_dim 18 | self.a_counts = a_counts 19 | self.activation_fn = tf.nn.tanh 20 | self.log_alpha = tf.get_variable( 21 | 'log_alpha', dtype=tf.float32, initializer=0.0) 22 | self.alpha = hyper_config['alpha'] if not hyper_config['auto_adaption'] else tf.exp( 23 | self.log_alpha) 24 | self.action_bound = hyper_config['action_bound'] 25 | self.assign_interval = hyper_config['assign_interval'] 26 | 27 | self.episode = tf.Variable(tf.constant(0)) 28 | self.lr = tf.train.polynomial_decay( 29 | hyper_config['lr'], self.episode, hyper_config['max_episode'], 1e-10, power=1.0) 30 | self.s = tf.placeholder(tf.float32, [None, self.s_dim], 'state') 31 | self.a = tf.placeholder(tf.float32, [None, self.a_counts], 'action') 32 | self.r = tf.placeholder(tf.float32, [None, 1], 'reward') 33 | self.s_ = tf.placeholder(tf.float32, [None, self.s_dim], 'next_state') 34 | self.sigma_offset = tf.placeholder( 35 | tf.float32, [self.a_counts, ], 'sigma_offset') 36 | 37 | self.norm_dist, self.mu, self.sigma, self.a_s, self.a_s_log_prob = self._build_actor_net( 38 | 'actor_net', self.s, reuse=False) 39 | _, _, _, self.a_s_, self.a_s_log_prob_ = self._build_actor_net( 40 | 'actor_net', self.s_, reuse=True) 41 | self.prob = self.norm_dist.prob(self.a_s) 42 | self.new_log_prob = self.norm_dist.log_prob(self.a) 43 | self.entropy = self.norm_dist.entropy() 44 | self.s_a = tf.concat((self.s, self.a), axis=1) 45 | self.s_a_ = tf.concat((self.s_, self.a_s_), axis=1) 46 | self.s_a_s = tf.concat((self.s, self.a_s), axis=1) 47 | self.q1, self.q1_vars = self._build_q_net( 48 | 'q1', self.s_a, trainable=True, reuse=False) 49 | self.q1_target, self.q1_target_vars = self._build_q_net( 50 | 'q1_target', self.s_a_, trainable=False, reuse=False) 51 | self.q2, self.q2_vars = self._build_q_net( 52 | 'q2', self.s_a, trainable=True, reuse=False) 53 | self.q2_target, self.q2_target_vars = self._build_q_net( 54 | 'q2_target', self.s_a_, trainable=False, reuse=False) 55 | self.q1_s_a, _ = self._build_q_net( 56 | 'q1', self.s_a_s, trainable=True, reuse=True) 57 | self.q2_s_a, _ = self._build_q_net( 58 | 'q2', self.s_a_s, trainable=True, reuse=True) 59 | 60 | self.dc_r_q1 = tf.stop_gradient( 61 | self.r + hyper_config['gamma'] * (self.q1_target - self.alpha * tf.reduce_mean(self.a_s_log_prob_))) 62 | self.dc_r_q2 = tf.stop_gradient( 63 | self.r + hyper_config['gamma'] * (self.q2_target - self.alpha * tf.reduce_mean(self.a_s_log_prob_))) 64 | self.q1_loss = tf.reduce_mean( 65 | tf.squared_difference(self.q1, self.dc_r_q1)) 66 | self.q2_loss = tf.reduce_mean( 67 | tf.squared_difference(self.q2, self.dc_r_q2)) 68 | self.critic_loss = 0.5 * self.q1_loss + 0.5 * self.q2_loss 69 | # self.actor_loss = -tf.reduce_mean( 70 | # tf.minimum(self.q1_s_a, self.q2_s_a) - self.alpha * (self.a_s_log_prob + self.new_log_prob)) 71 | self.actor_loss = -tf.reduce_mean( 72 | tf.minimum(self.q1_s_a, self.q2_s_a) - self.alpha * self.a_s_log_prob) 73 | 74 | self.alpha_loss = -tf.reduce_mean( 75 | self.log_alpha * tf.stop_gradient(self.a_s_log_prob - self.a_counts)) 76 | 77 | q1_vars = 
tf.get_collection( 78 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='q1') 79 | q2_vars = tf.get_collection( 80 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='q2') 81 | actor_vars = tf.get_collection( 82 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor_net') 83 | 84 | optimizer = tf.train.AdamOptimizer(self.lr) 85 | self.train_q1 = optimizer.minimize(self.q1_loss, var_list=q1_vars) 86 | self.train_q2 = optimizer.minimize(self.q2_loss, var_list=q2_vars) 87 | 88 | self.assign_q1_target = tf.group([tf.assign( 89 | r, hyper_config['ployak'] * v + (1 - hyper_config['ployak']) * r) for r, v in zip(self.q1_target_vars, self.q1_vars)]) 90 | self.assign_q2_target = tf.group([tf.assign( 91 | r, hyper_config['ployak'] * v + (1 - hyper_config['ployak']) * r) for r, v in zip(self.q2_target_vars, self.q2_vars)]) 92 | with tf.control_dependencies([self.assign_q1_target, self.assign_q2_target]): 93 | self.train_critic = optimizer.minimize(self.critic_loss) 94 | with tf.control_dependencies([self.train_critic]): 95 | self.train_actor = optimizer.minimize( 96 | self.actor_loss, var_list=actor_vars) 97 | with tf.control_dependencies([self.train_actor]): 98 | self.train_alpha = optimizer.minimize( 99 | self.alpha_loss, var_list=[self.log_alpha]) 100 | # self.assign_q1_target = [ 101 | # tf.assign(r, 1/(self.episode+1) * v + (1-1/(self.episode+1)) * r) for r, v in zip(self.v_target_var, self.v_var)] 102 | # self.assign_q2_target = [ 103 | # tf.assign(r, 1/(self.episode+1) * v + (1-1/(self.episode+1)) * r) for r, v in zip(self.v_target_var, self.v_var)] 104 | 105 | def _build_actor_net(self, name, input_vector, reuse=False): 106 | with tf.variable_scope(name): 107 | actor1 = tf.layers.dense( 108 | inputs=input_vector, 109 | units=128, 110 | activation=self.activation_fn, 111 | name='actor1', 112 | reuse=reuse, 113 | **initKernelAndBias 114 | ) 115 | actor2 = tf.layers.dense( 116 | inputs=actor1, 117 | units=64, 118 | activation=self.activation_fn, 119 | name='actor2', 120 | reuse=reuse, 121 | **initKernelAndBias 122 | ) 123 | mu = tf.layers.dense( 124 | inputs=actor2, 125 | units=self.a_counts, 126 | activation=tf.nn.tanh, 127 | name='mu', 128 | reuse=reuse, 129 | **initKernelAndBias 130 | ) 131 | sigma1 = tf.layers.dense( 132 | inputs=actor1, 133 | units=64, 134 | activation=self.activation_fn, 135 | name='simga1', 136 | reuse=reuse, 137 | **initKernelAndBias 138 | ) 139 | sigma = tf.layers.dense( 140 | inputs=sigma1, 141 | units=self.a_counts, 142 | activation=tf.nn.sigmoid, 143 | name='sigma', 144 | reuse=reuse, 145 | **initKernelAndBias 146 | ) 147 | norm_dist = tf.distributions.Normal( 148 | loc=mu, scale=sigma + self.sigma_offset) 149 | # action = tf.tanh(norm_dist.sample()) 150 | action = tf.clip_by_value( 151 | norm_dist.sample(), -self.action_bound, self.action_bound) / self.action_bound 152 | log_prob = norm_dist.log_prob(action) 153 | return norm_dist, mu, sigma, action, log_prob 154 | 155 | def _build_q_net(self, name, input_vector, trainable, reuse): 156 | with tf.variable_scope(name): 157 | layer1 = tf.layers.dense( 158 | inputs=input_vector, 159 | units=256, 160 | activation=self.activation_fn, 161 | name='layer1', 162 | trainable=trainable, 163 | reuse=reuse, 164 | **initKernelAndBias 165 | ) 166 | layer2 = tf.layers.dense( 167 | inputs=layer1, 168 | units=256, 169 | activation=self.activation_fn, 170 | name='layer2', 171 | trainable=trainable, 172 | reuse=reuse, 173 | **initKernelAndBias 174 | ) 175 | q = tf.layers.dense( 176 | inputs=layer2, 177 | units=1, 178 | activation=None, 179 | name='q_value', 
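# --- Editor's aside (illustrative sketch, not part of the original file) -----
# The `assign_q1_target` / `assign_q2_target` ops defined above are Polyak (soft)
# target-network updates, target <- tau * online + (1 - tau) * target, with tau
# read from hyper_config['ployak'] (apparently a transposed spelling of "polyak").
# A framework-free sketch of the same rule, with illustrative names:
def soft_update(target_params, online_params, tau):
    """Return Polyak-averaged parameters: tau * online + (1 - tau) * target."""
    return [tau * o + (1.0 - tau) * t for t, o in zip(target_params, online_params)]
# e.g. soft_update([0.0, 0.0], [1.0, 2.0], tau=0.01) -> [0.01, 0.02], so the target
# network trails the online network and stabilises the bootstrapped Q targets.
# --- end of editor's aside ----------------------------------------------------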
180 | trainable=trainable, 181 | reuse=reuse, 182 | **initKernelAndBias 183 | ) 184 | var = tf.get_variable_scope().global_variables() 185 | return q, var 186 | 187 | def decay_lr(self, episode): 188 | return self.sess.run(self.lr, feed_dict={ 189 | self.episode: episode 190 | }) 191 | 192 | def choose_action(self, s, sigma_offset, **kargs): 193 | return self.sess.run([self.prob, self.a_s], feed_dict={ 194 | self.s: s, 195 | self.sigma_offset: sigma_offset 196 | }) 197 | 198 | def choose_inference_action(self, s, sigma_offset, **kargs): 199 | return self.sess.run([self.prob, self.mu], feed_dict={ 200 | self.s: s, 201 | self.sigma_offset: sigma_offset 202 | }) 203 | 204 | def get_state_value(self, s, sigma_offset, **kargs): 205 | return np.squeeze(np.zeros(np.array(s).shape[0])) 206 | 207 | def learn(self, s, a, r, s_, episode, sigma_offset, **kargs): 208 | # self.sess.run([self.train_q1, self.train_q2, self.train_actor, self.train_alpha, self.assign_q1_target, self.assign_q2_target], feed_dict={ 209 | # self.s: s, 210 | # self.a: a, 211 | # self.r: r, 212 | # self.s_: s_, 213 | # self.episode: episode, 214 | # self.sigma_offset: sigma_offset 215 | # }) 216 | self.sess.run([self.assign_q1_target, self.assign_q2_target, self.train_critic, self.train_actor, self.train_alpha], feed_dict={ 217 | self.s: s, 218 | self.a: a, 219 | self.r: r, 220 | self.s_: s_, 221 | self.episode: episode, 222 | self.sigma_offset: sigma_offset 223 | }) 224 | 225 | def get_entropy(self, s, sigma_offset, **kargs): 226 | return self.sess.run(self.entropy, feed_dict={ 227 | self.s: s, 228 | self.sigma_offset: sigma_offset 229 | }) 230 | 231 | def get_actor_loss(self, s, a, sigma_offset, **kargs): 232 | return self.sess.run(self.actor_loss, feed_dict={ 233 | self.s: s, 234 | self.a: a, 235 | self.sigma_offset: sigma_offset 236 | }) 237 | 238 | def get_critic_loss(self, s, a, r, s_, sigma_offset, **kargs): 239 | return self.sess.run(self.critic_loss, feed_dict={ 240 | self.s: s, 241 | self.a: a, 242 | self.r: r, 243 | self.s_: s_, 244 | self.sigma_offset: sigma_offset 245 | }) 246 | 247 | def get_sigma(self, s, sigma_offset, **kargs): 248 | return self.sess.run(self.sigma, feed_dict={ 249 | self.s: s, 250 | self.sigma_offset: sigma_offset 251 | }) 252 | -------------------------------------------------------------------------------- /td3/td3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.contrib.layers as c_layers 4 | 5 | initKernelAndBias = { 6 | 'kernel_initializer': tf.random_normal_initializer(0., .1), 7 | 'bias_initializer': tf.constant_initializer(0.1, dtype=tf.float32) 8 | } 9 | # initKernelAndBias={ 10 | # 'kernel_initializer' : c_layers.variance_scaling_initializer(1.0) 11 | # } 12 | 13 | 14 | class TD3(object): 15 | def __init__(self, sess, s_dim, a_counts, hyper_config): 16 | self.sess = sess 17 | self.s_dim = s_dim 18 | self.a_counts = a_counts 19 | self.activation_fn = tf.nn.tanh 20 | self.action_bound = hyper_config['action_bound'] 21 | self.assign_interval = hyper_config['assign_interval'] 22 | 23 | self.episode = tf.Variable(tf.constant(0)) 24 | self.lr = tf.train.polynomial_decay( 25 | hyper_config['lr'], self.episode, hyper_config['max_episode'], 1e-10, power=1.0) 26 | self.s = tf.placeholder(tf.float32, [None, self.s_dim], 'state') 27 | self.a = tf.placeholder(tf.float32, [None, self.a_counts], 'action') 28 | self.r = tf.placeholder(tf.float32, [None, 1], 'reward') 29 | self.s_ = 
tf.placeholder(tf.float32, [None, self.s_dim], 'next_state') 30 | 31 | self.mu, self.action, self.actor_var = self._build_actor_net( 32 | 'actor', self.s, trainable=True) 33 | self.target_mu, self.action_target, self.actor_target_var = self._build_actor_net( 34 | 'actor_target', self.s_, trainable=False) 35 | 36 | self.s_a = tf.concat((self.s, self.a), axis=1) 37 | self.s_mu = tf.concat((self.s, self.mu), axis=1) 38 | self.s_a_target = tf.concat((self.s_, self.action_target), axis=1) 39 | 40 | self.q1, self.q1_var = self._build_q_net( 41 | 'q1', self.s_a, True, reuse=False) 42 | self.q1_actor, _ = self._build_q_net('q1', self.s_mu, True, reuse=True) 43 | self.q1_target, self.q1_target_var = self._build_q_net( 44 | 'q1_target', self.s_a_target, False, reuse=False) 45 | 46 | self.q2, self.q2_var = self._build_q_net( 47 | 'q2', self.s_a, True, reuse=False) 48 | self.q2_target, self.q2_target_var = self._build_q_net( 49 | 'q2_target', self.s_a_target, False, reuse=False) 50 | 51 | self.q_target = tf.minimum(self.q1_target, self.q2_target) 52 | self.dc_r = tf.stop_gradient( 53 | self.r + hyper_config['gamma'] * self.q_target) 54 | 55 | self.q1_loss = tf.reduce_mean( 56 | tf.squared_difference(self.q1, self.dc_r)) 57 | self.q2_loss = tf.reduce_mean( 58 | tf.squared_difference(self.q2, self.dc_r)) 59 | self.critic_loss = 0.5 * self.q1_loss + 0.5 * \ 60 | self.q2_loss 61 | self.actor_loss = -tf.reduce_mean(self.q1_actor) 62 | 63 | q1_var = tf.get_collection( 64 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='q1') 65 | q2_var = tf.get_collection( 66 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='q2') 67 | actor_vars = tf.get_collection( 68 | tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor') 69 | 70 | optimizer = tf.train.AdamOptimizer(self.lr) 71 | self.train_q1 = optimizer.minimize(self.q1_loss, var_list=q1_var) 72 | self.train_q2 = optimizer.minimize(self.q2_loss, var_list=q2_var) 73 | self.train_value = optimizer.minimize(self.critic_loss) 74 | with tf.control_dependencies([self.train_value]): 75 | self.train_actor = optimizer.minimize( 76 | self.actor_loss, var_list=actor_vars) 77 | with tf.control_dependencies([self.train_actor]): 78 | self.assign_q1_target = tf.group([tf.assign( 79 | r, hyper_config['ployak'] * v + (1 - hyper_config['ployak']) * r) for r, v in zip(self.q1_target_var, self.q1_var)]) 80 | self.assign_q2_target = tf.group([tf.assign( 81 | r, hyper_config['ployak'] * v + (1 - hyper_config['ployak']) * r) for r, v in zip(self.q2_target_var, self.q2_var)]) 82 | self.assign_actor_target = tf.group([tf.assign( 83 | r, hyper_config['ployak'] * v + (1 - hyper_config['ployak']) * r) for r, v in zip(self.actor_target_var, self.actor_var)]) 84 | # self.assign_q1_target = [ 85 | # tf.assign(r, 1/(self.episode+1) * v + (1-1/(self.episode+1)) * r) for r, v in zip(self.q1_target_var, self.q1_var)] 86 | # self.assign_q2_target = [ 87 | # tf.assign(r, 1/(self.episode+1) * v + (1-1/(self.episode+1)) * r) for r, v in zip(self.q2_target_var, self.q2_var)] 88 | # self.assign_actor_target = [ 89 | # tf.assign(r, 1/(self.episode+1) * v + (1-1/(self.episode+1)) * r) for r, v in zip(self.actor_target_var, self.actor_var)] 90 | 91 | def _build_actor_net(self, name, input_vector, trainable): 92 | with tf.variable_scope(name): 93 | actor1 = tf.layers.dense( 94 | inputs=input_vector, 95 | units=128, 96 | activation=self.activation_fn, 97 | name='actor1', 98 | trainable=trainable, 99 | **initKernelAndBias 100 | ) 101 | actor2 = tf.layers.dense( 102 | inputs=actor1, 103 | units=64, 104 | activation=self.activation_fn, 
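# --- Editor's aside (illustrative sketch, not part of the original file) -----
# Above, q_target = minimum(q1_target, q2_target) and dc_r = stop_gradient(r +
# gamma * q_target) form TD3's clipped double-Q target
# y = r + gamma * min(Q1'(s', a'), Q2'(s', a')). Note that the target action a'
# here reuses the noisy sampling from `_build_actor_net`, whereas the TD3 paper
# clips the smoothing noise added to the target action. A NumPy sketch with
# illustrative names (the `dones` mask is an addition, not in this file):
import numpy as np

def td3_target(rewards, q1_next, q2_next, gamma=0.99, dones=None):
    """Clipped double-Q target; all arguments are 1-D arrays of equal length."""
    if dones is None:
        dones = np.zeros_like(rewards)
    q_min = np.minimum(q1_next, q2_next)              # pessimistic value estimate
    return rewards + gamma * (1.0 - dones) * q_min    # no bootstrap past terminal steps
# e.g. td3_target(np.array([1.0]), np.array([5.0]), np.array([4.0])) -> array([4.96])
# --- end of editor's aside ----------------------------------------------------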
105 | name='actor2', 106 | trainable=trainable, 107 | **initKernelAndBias 108 | ) 109 | mu = tf.layers.dense( 110 | inputs=actor2, 111 | units=self.a_counts, 112 | activation=tf.nn.tanh, 113 | name='mu', 114 | trainable=trainable, 115 | **initKernelAndBias 116 | ) 117 | e = tf.random_normal(tf.shape(mu)) 118 | action = tf.clip_by_value( 119 | mu + e, -self.action_bound, self.action_bound) / self.action_bound 120 | var = tf.get_variable_scope().global_variables() 121 | return mu, action, var 122 | 123 | def _build_q_net(self, name, input_vector, trainable, reuse=False): 124 | with tf.variable_scope(name): 125 | layer1 = tf.layers.dense( 126 | inputs=input_vector, 127 | units=256, 128 | activation=self.activation_fn, 129 | name='layer1', 130 | trainable=trainable, 131 | reuse=reuse, 132 | **initKernelAndBias 133 | ) 134 | layer2 = tf.layers.dense( 135 | inputs=layer1, 136 | units=256, 137 | activation=self.activation_fn, 138 | name='layer2', 139 | trainable=trainable, 140 | reuse=reuse, 141 | **initKernelAndBias 142 | ) 143 | q1 = tf.layers.dense( 144 | inputs=layer2, 145 | units=1, 146 | activation=None, 147 | name='q_value', 148 | trainable=trainable, 149 | reuse=reuse, 150 | **initKernelAndBias 151 | ) 152 | var = tf.get_variable_scope().global_variables() 153 | return q1, var 154 | 155 | def decay_lr(self, episode, **kargs): 156 | return self.sess.run(self.lr, feed_dict={ 157 | self.episode: episode 158 | }) 159 | 160 | def choose_action(self, s, **kargs): 161 | return np.ones((s.shape[0], self.a_counts)), self.sess.run(self.action, feed_dict={ 162 | self.s: s 163 | }) 164 | 165 | def choose_inference_action(self, s, **kargs): 166 | return np.ones((s.shape[0], self.a_counts)), self.sess.run(self.mu, feed_dict={ 167 | self.s: s 168 | }) 169 | 170 | def get_state_value(self, s, **kargs): 171 | # return np.squeeze( 172 | # self.sess.run(self.q_target, feed_dict={ 173 | # self.s: s 174 | # })) 175 | return np.squeeze(np.zeros(np.array(s).shape[0])) 176 | 177 | def learn(self, s, a, r, s_, episode, **kargs): 178 | self.sess.run(self.train_value, feed_dict={ 179 | self.s: s, 180 | self.a: a, 181 | self.r: r, 182 | self.s_: s_, 183 | self.episode: episode 184 | }) 185 | self.sess.run([self.train_value, self.train_actor, self.assign_q1_target, self.assign_q2_target, self.assign_actor_target], feed_dict={ 186 | self.s: s, 187 | self.a: a, 188 | self.r: r, 189 | self.s_: s_, 190 | self.episode: episode 191 | }) 192 | 193 | def get_actor_loss(self, s, **kargs): 194 | return self.sess.run(self.actor_loss, feed_dict={ 195 | self.s: s 196 | }) 197 | 198 | def get_critic_loss(self, s, a, r, s_, **kargs): 199 | return self.sess.run(self.critic_loss, feed_dict={ 200 | self.s: s, 201 | self.a: a, 202 | self.r: r, 203 | self.s_: s_, 204 | }) 205 | 206 | def get_entropy(self, s, **kargs): 207 | return np.zeros((np.array(s).shape[0], self.a_counts)) 208 | 209 | def get_sigma(self, s, **kargs): 210 | return np.zeros((np.array(s).shape[0], self.a_counts)) 211 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import platform 2 | print(platform.platform()) 3 | print(platform.system()) 4 | print(platform.python_version()) -------------------------------------------------------------------------------- /utils/recorder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tensorflow as tf 3 | import pandas as pd 4 | from pymongo 
import MongoClient 5 | 6 | 7 | class Recorder(object): 8 | def __init__(self, log_dir, excel_dir, record_config, logger, max_to_keep=5, pad_step_number=True, graph=None): 9 | self.saver = tf.train.Saver( 10 | max_to_keep=max_to_keep, pad_step_number=pad_step_number) 11 | self.writer = tf.summary.FileWriter(log_dir, graph=graph) 12 | self.excel_writer = pd.ExcelWriter(excel_dir + '/data.xlsx') 13 | self.mongocon = MongoClient('127.0.0.1', 27017) 14 | self.mongodb = self.mongocon[ 15 | record_config['project_name'] + r'_' 16 | + record_config['remark'] 17 | + record_config['run_id'] 18 | ] 19 | self.logger = logger 20 | 21 | def close(self): 22 | self.mongocon.close() 23 | -------------------------------------------------------------------------------- /utils/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class ReplayBuffer(object): 4 | def __init__(self, s_dim, a_counts, buffer_size, use_priority=False): 5 | self.use_priority = use_priority 6 | self.state = np.zeros([buffer_size, s_dim], dtype=np.float32) 7 | self.next_state = np.zeros([buffer_size, s_dim], dtype=np.float32) 8 | self.action = np.zeros([buffer_size, a_counts], dtype=np.float32) 9 | self.prob = np.zeros([buffer_size, a_counts], dtype=np.float32) 10 | self.reward = np.zeros(buffer_size, dtype=np.float32) 11 | self.discounted_reward = np.zeros(buffer_size, dtype=np.float32) 12 | self.sum_tree_n = np.int32(np.ceil(np.log2(buffer_size))) + 1 13 | if self.use_priority: 14 | self.td_error = [np.zeros(np.int32( 15 | np.ceil(buffer_size / np.power(2, i)))) for i in range(self.sum_tree_n)] 16 | else: 17 | self.td_error = np.zeros(buffer_size, dtype=np.float32) 18 | self.advantage = np.zeros(buffer_size, dtype=np.float32) 19 | self.done = np.zeros(buffer_size, dtype=np.float32) 20 | self.now, self.buffer_size, self.max_size = 0, 0, buffer_size 21 | 22 | def store(self, state, action, prob, reward, discounted_reward, td_error, advantage, next_state, done): 23 | self.state[self.now] = state 24 | self.next_state[self.now] = next_state 25 | self.action[self.now] = action 26 | self.prob[self.now] = prob 27 | self.reward[self.now] = reward 28 | self.discounted_reward[self.now] = discounted_reward 29 | if self.use_priority: 30 | diff = (np.abs(td_error) if np.abs(td_error) > np.max( 31 | self.td_error[0]) else np.max(self.td_error[0]) + 1) - self.td_error[0][self.now] 32 | for i in range(self.sum_tree_n): 33 | self.td_error[i][self.now // np.power(2, i)] += diff 34 | else: 35 | self.td_error[self.now] = td_error 36 | self.advantage[self.now] = advantage 37 | self.done[self.now] = done 38 | self.now = (self.now + 1) % self.max_size 39 | self.buffer_size = min(self.buffer_size + 1, self.max_size) 40 | 41 | def sample_batch(self, batch_size=32): 42 | if self.use_priority: 43 | temp_indexs = np.random.random_sample( 44 | batch_size) * self.td_error[-1][0] 45 | indexs = np.zeros(batch_size, dtype=np.int32) 46 | for index, i in enumerate(temp_indexs): 47 | k = 0 48 | for j in reversed(range(self.sum_tree_n - 1)): 49 | k *= 2 50 | if self.td_error[j][k] < i: 51 | i -= self.td_error[j][k] 52 | k += 1 53 | indexs[index] = k 54 | else: 55 | indexs = np.random.randint(0, self.buffer_size, size=batch_size) 56 | return indexs, dict( 57 | state=self.state[indexs], 58 | next_state=self.next_state[indexs], 59 | action=self.action[indexs], 60 | old_prob=self.prob[indexs], 61 | reward=self.reward[indexs], 62 | discounted_reward=self.discounted_reward[indexs], 63 | 
td_error=self.td_error[0][indexs] if self.use_priority else self.td_error[indexs], 64 | weights=np.power(self.td_error[-1][0] / (self.buffer_size * self.td_error[0][indexs]), 0.04) / np.max( 65 | self.td_error[0]) if self.use_priority else np.zeros(batch_size), 66 | advantage=self.advantage[indexs], 67 | done=self.done[indexs] 68 | ) 69 | 70 | def update(self, indexs, td_error): 71 | for i, j in zip(indexs, td_error): 72 | if self.use_priority: 73 | diff = np.abs(j) - self.td_error[0][i] 74 | for k in range(self.sum_tree_n): 75 | self.td_error[k][i // np.power(2, k)] += diff 76 | else: 77 | self.td_error[i] = td_error 78 | -------------------------------------------------------------------------------- /utils/sth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | class sth(object): 7 | @staticmethod 8 | def discounted_sum(x, gamma, init_value, done_index, length=0, z=None): 9 | """ 10 | x: list of data 11 | gamma: discounted factor 12 | init_value: the initiate value of the last data 13 | return: a list of discounted numbers 14 | 15 | Examples: 16 | x: [1, 2, 3] 17 | gamma: 0.5 18 | init_value: 0.0 19 | return: [2.75, 3.5, 3] 20 | """ 21 | y = [] 22 | if length == 0: 23 | length = len(x) 24 | if len(done_index) > 0: 25 | for i in reversed(range(length)): 26 | for j in reversed(done_index): 27 | if i == j: 28 | init_value = 0 29 | init_value = init_value * gamma + x[i] 30 | y.append(init_value) 31 | if z is not None: 32 | init_value = z[i] 33 | else: 34 | for i in reversed(range(length)): 35 | init_value = init_value * gamma + x[i] 36 | y.append(init_value) 37 | if z is not None: 38 | init_value = z[i] 39 | y.reverse() 40 | return y 41 | 42 | @staticmethod 43 | def discounted_sum_minus(x, gamma, init_value, done_index, z, length=0): 44 | """ 45 | x: list of data 46 | gamma: discounted factor 47 | init_value: the initiate value of the last data 48 | return: a list of discounted numbers 49 | : 50 | Examples: 51 | x: [1, 2, 3] 52 | gamma: 0.5 53 | init_value: 0.0 54 | z: [1, 2, 3] 55 | return: [1, 1.5, 0] 56 | """ 57 | y = [] 58 | if length == 0: 59 | length = len(x) 60 | if len(done_index) > 0: 61 | for i in reversed(range(length)): 62 | for j in reversed(done_index): 63 | if i == j: 64 | init_value = 0 65 | y.append(init_value * gamma + x[i] - z[i]) 66 | init_value = z[i] 67 | else: 68 | for i in reversed(range(length)): 69 | y.append(init_value * gamma + x[i] - z[i]) 70 | init_value = z[i] 71 | y.reverse() 72 | return y 73 | 74 | @staticmethod 75 | def get_discounted_sum(x, gamma, init_value, length=0): 76 | """ 77 | x: list of data 78 | gamma: discounted factor 79 | init_value: the initiate value of the last data 80 | return: the value of discounted sum, type 1-D 81 | : 82 | Examples: 83 | x: [1, 2, 3] 84 | gamma: 1 85 | init_value: 0.0 86 | return: 6 87 | """ 88 | if length == 0: 89 | length = len(x) 90 | for i in reversed(range(length)): 91 | init_value = init_value * gamma + x[i] 92 | return init_value 93 | 94 | @staticmethod 95 | def split_batchs(x, batchsize, length=0, cross=False, reverse=True): 96 | """ 97 | x: list of date 98 | batchsize: size of each block that be splited 99 | length: the last index of data 100 | cross: if TRUE, the index will minus 1, otherwise it will minus batchsize 101 | : 102 | Examples: 103 | x: [1, 2, 3, 4] 104 | batchsize: 2 105 | reverse: False 106 | cross: F return: [[1,2],[3,4]] 107 | cross: T return: [[1,2],[2,3],[3,4]] 108 | """ 
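# --- Editor's aside (illustrative usage, not part of the original file) ------
# With the default reverse=True the windows come back newest-first. Assuming
# x = [1, 2, 3, 4] and batchsize = 2:
#   sth.split_batchs([1, 2, 3, 4], 2)             -> [[3, 4], [1, 2]]
#   sth.split_batchs([1, 2, 3, 4], 2, cross=True) -> [[3, 4], [2, 3], [1, 2]]
# The cross=True form is the sliding-window slicing used for the PPO minibatches
# earlier in this repository.
# --- end of editor's aside ----------------------------------------------------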
109 | if length == 0: 110 | length = len(x) 111 | if length < batchsize: 112 | return [x] 113 | if reverse: 114 | if cross: 115 | return list(reversed([x[i:i+batchsize] for i in range(length-batchsize+1)])) 116 | else: 117 | return list(reversed([x[i:i+batchsize] for i in range(0, length, batchsize)])) 118 | else: 119 | if cross: 120 | return list([x[i:i+batchsize] for i in range(length-batchsize+1)]) 121 | else: 122 | return list([x[i:i+batchsize] for i in range(0, length, batchsize)]) 123 | 124 | @staticmethod 125 | def check_or_create(dicpath, name=''): 126 | if not os.path.exists(dicpath): 127 | os.makedirs(dicpath) 128 | print(f'create {name} directory:', dicpath) 129 | @staticmethod 130 | def save_config(filename, config): 131 | fw = open(os.path.join(filename, 'config.yaml'), 'w', encoding='utf-8') 132 | yaml.dump(config, fw) 133 | fw.close() 134 | print(f'save config to {filename}') 135 | @staticmethod 136 | def load_config(filename): 137 | f = open(os.path.join(filename, 'config.yaml'), 'r', encoding='utf-8') 138 | x = yaml.safe_load(f.read()) 139 | f.close() 140 | print(f'load config from {filename}') 141 | return x --------------------------------------------------------------------------------
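# --- Editor's aside (illustrative sketch, not part of the original repository) -----
# The ReplayBuffer in utils/replay_buffer.py samples transitions in proportion to
# |TD error| with a binary sum tree: self.td_error[0] holds the per-item priorities,
# each higher level holds pairwise sums, and the top level acts as the running total.
# The same sampling distribution can be sketched with a plain cumulative sum instead
# of a tree; the names below are illustrative only.
import numpy as np

def proportional_sample(priorities, batch_size, rng=None):
    """Draw indices with probability priorities[i] / priorities.sum()."""
    if rng is None:
        rng = np.random.default_rng()
    cumulative = np.cumsum(priorities)                   # prefix sums of the priorities
    targets = rng.random(batch_size) * cumulative[-1]    # uniform points in [0, total)
    return np.searchsorted(cumulative, targets, side='right')
# e.g. proportional_sample(np.array([1.0, 3.0, 6.0]), 4) returns index 2 most often,
# since it carries 60% of the total priority. The sum tree in replay_buffer.py keeps
# single-priority updates O(log n) instead of recomputing the prefix sums each time.
# --- end of editor's aside ----------------------------------------------------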