├── .gitignore ├── LICENSE ├── NS3-VERSION ├── README.md ├── VERSION ├── doc ├── figures │ ├── cognitive-radio-learning.png │ └── interferer-pattern.png └── opengym.rst ├── examples ├── interference-pattern │ ├── cognitive-agent-v1.py │ ├── learning.pdf │ ├── mygym.cc │ ├── mygym.h │ ├── sim.cc │ └── simple_test.py ├── linear-mesh-2 │ ├── mygym.cc │ ├── mygym.h │ ├── sim.cc │ ├── simple_test.py │ └── test.py ├── linear-mesh │ ├── backpressure_v1.py │ ├── dqn-agent-v1.py │ ├── dqn-agent-v2.py │ ├── my_random.py │ ├── my_random2.py │ ├── no_op.py │ ├── no_op2.py │ ├── qfull.py │ ├── qfull_working.py │ ├── qlearn.py │ ├── qlearn_full.py │ └── sim.cc ├── multi-agent │ ├── agent1.py │ ├── agent2.py │ ├── mygym.cc │ ├── mygym.h │ ├── readme.md │ └── sim.cc ├── multigym │ ├── .idea │ │ ├── .gitignore │ │ ├── inspectionProfiles │ │ │ └── profiles_settings.xml │ │ ├── misc.xml │ │ ├── modules.xml │ │ ├── multigym.iml │ │ └── vcs.xml │ ├── mygym.cc │ ├── mygym.h │ ├── sim.cc │ └── test.py ├── opengym-2 │ ├── mygym.cc │ ├── mygym.h │ ├── sim.cc │ ├── simple_test.py │ └── test.py ├── opengym │ ├── sim.cc │ ├── simple_test.py │ └── test.py ├── rl-tcp │ ├── sim.cc │ ├── tcp-rl-env.cc │ ├── tcp-rl-env.h │ ├── tcp-rl.cc │ ├── tcp-rl.h │ ├── tcp_base.py │ ├── tcp_newreno.py │ ├── test.py │ └── test_tcp.py └── wscript ├── helper ├── opengym-helper.cc └── opengym-helper.h ├── model-single-agent ├── container.cc ├── container.h ├── messages.proto ├── ns3gym │ ├── LICENSE │ ├── MANIFEST.in │ ├── README.md │ ├── ns3gym │ │ ├── __init__.py │ │ ├── ns3env.py │ │ └── start_sim.py │ ├── requirements.txt │ └── setup.py ├── opengym_env.cc ├── opengym_env.h ├── opengym_interface.cc ├── opengym_interface.h ├── spaces.cc └── spaces.h ├── model ├── __init__.py ├── container.cc ├── container.h ├── messages.proto ├── ns3gym │ ├── .idea │ │ ├── .gitignore │ │ ├── inspectionProfiles │ │ │ └── profiles_settings.xml │ │ ├── misc.xml │ │ ├── modules.xml │ │ ├── ns3gym.iml │ │ └── vcs.xml │ ├── LICENSE │ 
├── MANIFEST.in │ ├── README.md │ ├── __init__.py │ ├── ns3gym │ │ ├── __init__.py │ │ ├── ns3_multiagent_env.py │ │ ├── ns3env.py │ │ └── start_sim.py │ ├── requirements.txt │ └── setup.py ├── opengym_env.cc ├── opengym_env.h ├── opengym_interface.cc ├── opengym_interface.h ├── opengym_multi_env.cc ├── opengym_multi_env.h ├── opengym_multi_interface.cc ├── opengym_multi_interface.h ├── spaces.cc └── spaces.h ├── test └── opengym-test-suite.cc └── wscript /.gitignore: -------------------------------------------------------------------------------- 1 | *.diff 2 | *.orig 3 | *.patch 4 | *.rej 5 | 6 | *.cwnd 7 | *.dat 8 | *.log 9 | *.mob 10 | *.pcap 11 | *.plt 12 | *.routes 13 | *.tr 14 | [D|U]l[A-Z][a-z]*Stats.txt 15 | seventh-packet-byte-count.png 16 | 17 | \#*# 18 | ~* 19 | 20 | testpy-output 21 | 22 | bindings/python/pybindgen/ 23 | 24 | ms_print.* 25 | massif.* 26 | coverity 27 | TAGS 28 | 29 | .lock-waf_*_build 30 | .waf* 31 | 32 | build-dir/ 33 | build/ 34 | /.cproject 35 | /.project 36 | 37 | # Protobuf files 38 | *.pb.cc 39 | *.pb.h 40 | *_pb2.py 41 | 42 | # Created by https://www.gitignore.io/api/python 43 | # Edit at https://www.gitignore.io/?templates=python 44 | 45 | ### Python ### 46 | # Byte-compiled / optimized / DLL files 47 | __pycache__/ 48 | *.py[cod] 49 | *$py.class 50 | 51 | # C extensions 52 | *.so 53 | 54 | # Distribution / packaging 55 | .Python 56 | build/ 57 | develop-eggs/ 58 | dist/ 59 | downloads/ 60 | eggs/ 61 | .eggs/ 62 | lib/ 63 | lib64/ 64 | parts/ 65 | sdist/ 66 | var/ 67 | wheels/ 68 | *.egg-info/ 69 | .installed.cfg 70 | *.egg 71 | MANIFEST 72 | 73 | # PyInstaller 74 | # Usually these files are written by a python script from a template 75 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
76 | *.manifest 77 | *.spec 78 | 79 | # Installer logs 80 | pip-log.txt 81 | pip-delete-this-directory.txt 82 | 83 | # Unit test / coverage reports 84 | htmlcov/ 85 | .tox/ 86 | .nox/ 87 | .coverage 88 | .coverage.* 89 | .cache 90 | nosetests.xml 91 | coverage.xml 92 | *.cover 93 | .hypothesis/ 94 | .pytest_cache/ 95 | 96 | # Translations 97 | *.mo 98 | *.pot 99 | 100 | # Django stuff: 101 | *.log 102 | local_settings.py 103 | db.sqlite3 104 | 105 | # Flask stuff: 106 | instance/ 107 | .webassets-cache 108 | 109 | # Scrapy stuff: 110 | .scrapy 111 | 112 | # Sphinx documentation 113 | docs/_build/ 114 | 115 | # PyBuilder 116 | target/ 117 | 118 | # Jupyter Notebook 119 | .ipynb_checkpoints 120 | 121 | # IPython 122 | profile_default/ 123 | ipython_config.py 124 | 125 | # pyenv 126 | .python-version 127 | 128 | # celery beat schedule file 129 | celerybeat-schedule 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | ### Python Patch ### 162 | .venv/ 163 | 164 | ### Python.VirtualEnv Stack ### 165 | # Virtualenv 166 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 167 | [Bb]in 168 | [Ii]nclude 169 | [Ll]ib 170 | [Ll]ib64 171 | [Ll]ocal 172 | [Ss]cripts 173 | pyvenv.cfg 174 | pip-selfcheck.json 175 | 176 | # End of https://www.gitignore.io/api/python 177 | -------------------------------------------------------------------------------- /NS3-VERSION: -------------------------------------------------------------------------------- 1 | release ns-3.29 or later 2 | -------------------------------------------------------------------------------- 
/README.md: -------------------------------------------------------------------------------- 1 | # ns3-gym for multi-agent 2 | 3 | MultiEnv is an extension of [ns3-gym](https://github.com/tkn-tub/ns3-gym), so that the nodes in the network can be completely regarded as independent agents, which have their own states, observations, and rewards. 4 | 5 | NOTE: We formalize the network problem as a multi-agent extension of Markov decision processes (MDPs) called Partially Observable Markov Games (POMGs). 6 | 7 | **Why do we use a multi-agent environment?** 8 | 9 | **Fully-distributed learning:** Algorithms with centralized learning process are not applicable in the real computer network. The centralized learning controller is usually unable to gather collected environment transitions from widely distributed routers once an action is executed somewhere and to update the parameters of each neural network simultaneously caused by the limited bandwidth. [FROM: You, Xinyu, et al. "Toward Packet Routing with Fully-distributed Multi-agent Deep Reinforcement Learning." arXiv preprint arXiv:1905.03494 (2019).] 10 | 11 | Code Reference: Lowe R, Wu Y, Tamar A, et al. Multi-agent actor-critic for mixed cooperative-competitive environments[C]//Advances in neural information processing systems. 2017: 6379-6390. 
https://github.com/openai/multiagent-particle-envs 12 | 13 | ## How (Python) 14 | ```python 15 | from ns3gym import ns3_multiagent_env as ns3env 16 | ``` 17 | 18 | ## How (ns-3) 19 | ```C++ 20 | ///\{ Each agent OpenGym Env 21 | virtual Ptr GetActionSpace(uint32_t agent_id) = 0; 22 | virtual Ptr GetObservationSpace(uint32_t agent_id) = 0; 23 | virtual Ptr GetObservation(uint32_t agent_id) = 0; 24 | virtual float GetReward(uint32_t agent_id) = 0; 25 | virtual bool GetDone(uint32_t agent_id) = 0; 26 | virtual std::string GetInfo(uint32_t agent_id) = 0; 27 | virtual bool ExecuteActions(uint32_t agent_id, Ptr action) = 0; 28 | ///\} 29 | ``` 30 | ## Example 31 | ### multi-agent example 1 32 | This example shows how to create an ns3-gym environment with multiple agents in one Python process, similar to 33 | [multiagent-particle-envs](https://github.com/openai/multiagent-particle-envs) 34 | 35 | [multi-agent example 1](https://github.com/zhangmwg/ns3-gym-multiagent/tree/master/examples/multigym) 36 | 37 | ### multi-agent example 2 38 | This example shows how to create an ns3-gym environment with multiple agents and connect them to multiple independent Python processes. 39 | 40 | [multi-agent example 2](https://github.com/zhangmwg/ns3-gym-multiagent/tree/master/examples/multi-agent) 41 | 42 | ns3-gym 43 | ============ 44 | 45 | [OpenAI Gym](https://gym.openai.com/) is a toolkit for reinforcement learning (RL) widely used in research. The network simulator [ns-3](https://www.nsnam.org/) is the de-facto standard for academic and industry studies in the areas of networking protocols and communication technologies. [ns3-gym](https://github.com/tkn-tub/ns3-gym) is a framework that integrates both OpenAI Gym and ns-3 in order to encourage usage of RL in networking research. 46 | 47 | Installation 48 | ============ 49 | 50 | https://github.com/tkn-tub/ns3-gym 51 | 52 | How to reference ns3-gym? 
53 | ============ 54 | 55 | Please use the following bibtex : 56 | 57 | ``` 58 | @inproceedings{ns3gym, 59 | Title = {{ns-3 meets OpenAI Gym: The Playground for Machine Learning in Networking Research}}, 60 | Author = {Gaw{\l}owicz, Piotr and Zubow, Anatolij}, 61 | Booktitle = {{ACM International Conference on Modeling, Analysis and Simulation of Wireless and Mobile Systems (MSWiM)}}, 62 | Year = {2019}, 63 | Location = {Miami Beach, USA}, 64 | Month = {November}, 65 | Url = {http://www.tkn.tu-berlin.de/fileadmin/fg112/Papers/2019/gawlowicz19_mswim.pdf} 66 | } 67 | ``` 68 | 69 | ``` 70 | @article{ns3gym, 71 | author = {Gawlowicz, Piotr and Zubow, Anatolij}, 72 | title = {{ns3-gym: Extending OpenAI Gym for Networking Research}}, 73 | journal = {CoRR}, 74 | year = {2018}, 75 | url = {https://arxiv.org/abs/1810.03943}, 76 | archivePrefix = {arXiv}, 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | release 1.1.0 2 | -------------------------------------------------------------------------------- /doc/figures/cognitive-radio-learning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangmwg/ns3-gym-multiagent/e7927cf70daae2001f996a76498251235c9338de/doc/figures/cognitive-radio-learning.png -------------------------------------------------------------------------------- /doc/figures/interferer-pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangmwg/ns3-gym-multiagent/e7927cf70daae2001f996a76498251235c9338de/doc/figures/interferer-pattern.png -------------------------------------------------------------------------------- /doc/opengym.rst: -------------------------------------------------------------------------------- 1 | Example Module Documentation 2 | 
---------------------------- 3 | 4 | .. include:: replace.txt 5 | .. highlight:: cpp 6 | 7 | .. heading hierarchy: 8 | ------------- Chapter 9 | ************* Section (#.#) 10 | ============= Subsection (#.#.#) 11 | ############# Paragraph (no number) 12 | 13 | This is a suggested outline for adding new module documentation to |ns3|. 14 | See ``src/click/doc/click.rst`` for an example. 15 | 16 | The introductory paragraph is for describing what this code is trying to 17 | model. 18 | 19 | For consistency (italicized formatting), please use |ns3| to refer to 20 | ns-3 in the documentation (and likewise, |ns2| for ns-2). These macros 21 | are defined in the file ``replace.txt``. 22 | 23 | Model Description 24 | ***************** 25 | 26 | The source code for the new module lives in the directory ``src/opengym``. 27 | 28 | Add here a basic description of what is being modeled. 29 | 30 | Design 31 | ====== 32 | 33 | Briefly describe the software design of the model and how it fits into 34 | the existing ns-3 architecture. 35 | 36 | Scope and Limitations 37 | ===================== 38 | 39 | What can the model do? What can it not do? Please use this section to 40 | describe the scope and limitations of the model. 41 | 42 | References 43 | ========== 44 | 45 | Add academic citations here, such as if you published a paper on this 46 | model, or if readers should read a particular specification or other work. 47 | 48 | Usage 49 | ***** 50 | 51 | This section is principally concerned with the usage of your model, using 52 | the public API. Focus first on most common usage patterns, then go 53 | into more advanced topics. 54 | 55 | Building New Module 56 | =================== 57 | 58 | Include this subsection only if there are special build instructions or 59 | platform limitations. 60 | 61 | Helpers 62 | ======= 63 | 64 | What helper API will users typically use? Describe it here. 
65 | 66 | Attributes 67 | ========== 68 | 69 | What classes hold attributes, and what are the key ones worth mentioning? 70 | 71 | Output 72 | ====== 73 | 74 | What kind of data does the model generate? What are the key trace 75 | sources? What kind of logging output can be enabled? 76 | 77 | Advanced Usage 78 | ============== 79 | 80 | Go into further details (such as using the API outside of the helpers) 81 | in additional sections, as needed. 82 | 83 | Examples 84 | ======== 85 | 86 | What examples using this new code are available? Describe them here. 87 | 88 | Troubleshooting 89 | =============== 90 | 91 | Add any tips for avoiding pitfalls, etc. 92 | 93 | Validation 94 | ********** 95 | 96 | Describe how the model has been tested/validated. What tests run in the 97 | test suite? How much API and code is covered by the tests? Again, 98 | references to outside published work may help here. 99 | -------------------------------------------------------------------------------- /examples/interference-pattern/cognitive-agent-v1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import gym 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | import numpy as np 8 | import matplotlib as mpl 9 | import matplotlib.pyplot as plt 10 | from tensorflow import keras 11 | from ns3gym import ns3env 12 | 13 | env = gym.make('ns3-v0') 14 | ob_space = env.observation_space 15 | ac_space = env.action_space 16 | print("Observation space: ", ob_space, ob_space.dtype) 17 | print("Action space: ", ac_space, ac_space.n) 18 | 19 | s_size = ob_space.shape[0] 20 | a_size = ac_space.n 21 | model = keras.Sequential() 22 | model.add(keras.layers.Dense(s_size, input_shape=(s_size,), activation='relu')) 23 | model.add(keras.layers.Dense(a_size, activation='softmax')) 24 | model.compile(optimizer=tf.train.AdamOptimizer(0.001), 25 | loss='categorical_crossentropy', 26 | 
metrics=['accuracy']) 27 | 28 | total_episodes = 200 29 | max_env_steps = 100 30 | env._max_episode_steps = max_env_steps 31 | 32 | epsilon = 1.0 # exploration rate 33 | epsilon_min = 0.01 34 | epsilon_decay = 0.999 35 | 36 | time_history = [] 37 | rew_history = [] 38 | 39 | for e in range(total_episodes): 40 | 41 | state = env.reset() 42 | state = np.reshape(state, [1, s_size]) 43 | rewardsum = 0 44 | for time in range(max_env_steps): 45 | 46 | # Choose action 47 | if np.random.rand(1) < epsilon: 48 | action = np.random.randint(a_size) 49 | else: 50 | action = np.argmax(model.predict(state)[0]) 51 | 52 | # Step 53 | next_state, reward, done, _ = env.step(action) 54 | 55 | if done: 56 | print("episode: {}/{}, time: {}, rew: {}, eps: {:.2}" 57 | .format(e, total_episodes, time, rewardsum, epsilon)) 58 | break 59 | 60 | next_state = np.reshape(next_state, [1, s_size]) 61 | 62 | # Train 63 | target = reward 64 | if not done: 65 | target = (reward + 0.95 * np.amax(model.predict(next_state)[0])) 66 | 67 | target_f = model.predict(state) 68 | target_f[0][action] = target 69 | model.fit(state, target_f, epochs=1, verbose=0) 70 | 71 | state = next_state 72 | rewardsum += reward 73 | if epsilon > epsilon_min: epsilon *= epsilon_decay 74 | 75 | time_history.append(time) 76 | rew_history.append(rewardsum) 77 | 78 | #for n in range(2 ** s_size): 79 | # state = [n >> i & 1 for i in range(0, 2)] 80 | # state = np.reshape(state, [1, s_size]) 81 | # print("state " + str(state) 82 | # + " -> prediction " + str(model.predict(state)[0]) 83 | # ) 84 | 85 | #print(model.get_config()) 86 | #print(model.to_json()) 87 | #print(model.get_weights()) 88 | 89 | print("Plot Learning Performance") 90 | mpl.rcdefaults() 91 | mpl.rcParams.update({'font.size': 16}) 92 | 93 | fig, ax = plt.subplots(figsize=(10,4)) 94 | plt.grid(True, linestyle='--') 95 | plt.title('Learning Performance') 96 | plt.plot(range(len(time_history)), time_history, label='Steps', marker="^", linestyle=":")#, color='red') 
97 | plt.plot(range(len(rew_history)), rew_history, label='Reward', marker="", linestyle="-")#, color='k') 98 | plt.xlabel('Episode') 99 | plt.ylabel('Time') 100 | plt.legend(prop={'size': 12}) 101 | 102 | plt.savefig('learning.pdf', bbox_inches='tight') 103 | plt.show() 104 | -------------------------------------------------------------------------------- /examples/interference-pattern/learning.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangmwg/ns3-gym-multiagent/e7927cf70daae2001f996a76498251235c9338de/examples/interference-pattern/learning.pdf -------------------------------------------------------------------------------- /examples/interference-pattern/mygym.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | #include "mygym.h" 22 | #include "ns3/object.h" 23 | #include "ns3/core-module.h" 24 | #include "ns3/wifi-module.h" 25 | #include "ns3/node-list.h" 26 | #include "ns3/log.h" 27 | #include 28 | #include 29 | 30 | namespace ns3 { 31 | 32 | NS_LOG_COMPONENT_DEFINE ("MyGymEnv"); 33 | 34 | NS_OBJECT_ENSURE_REGISTERED (MyGymEnv); 35 | 36 | MyGymEnv::MyGymEnv () 37 | { 38 | NS_LOG_FUNCTION (this); 39 | m_currentNode = 0; 40 | m_currentChannel = 0; 41 | m_collisionTh = 3; 42 | m_channelNum = 1; 43 | m_channelOccupation.clear(); 44 | } 45 | 46 | MyGymEnv::MyGymEnv (uint32_t channelNum) 47 | { 48 | NS_LOG_FUNCTION (this); 49 | m_currentNode = 0; 50 | m_currentChannel = 0; 51 | m_collisionTh = 3; 52 | m_channelNum = channelNum; 53 | m_channelOccupation.clear(); 54 | } 55 | 56 | MyGymEnv::~MyGymEnv () 57 | { 58 | NS_LOG_FUNCTION (this); 59 | } 60 | 61 | TypeId 62 | MyGymEnv::GetTypeId (void) 63 | { 64 | static TypeId tid = TypeId ("MyGymEnv") 65 | .SetParent () 66 | .SetGroupName ("OpenGym") 67 | .AddConstructor () 68 | ; 69 | return tid; 70 | } 71 | 72 | void 73 | MyGymEnv::DoDispose () 74 | { 75 | NS_LOG_FUNCTION (this); 76 | } 77 | 78 | Ptr 79 | MyGymEnv::GetActionSpace() 80 | { 81 | NS_LOG_FUNCTION (this); 82 | Ptr space = CreateObject (m_channelNum); 83 | NS_LOG_UNCOND ("GetActionSpace: " << space); 84 | return space; 85 | } 86 | 87 | Ptr 88 | MyGymEnv::GetObservationSpace() 89 | { 90 | NS_LOG_FUNCTION (this); 91 | float low = 0.0; 92 | float high = 1.0; 93 | std::vector shape = {m_channelNum,}; 94 | std::string dtype = TypeNameGet (); 95 | Ptr space = CreateObject (low, high, shape, dtype); 96 | NS_LOG_UNCOND ("GetObservationSpace: " << space); 97 | return space; 98 | } 99 | 100 | bool 101 | 
MyGymEnv::GetGameOver() 102 | { 103 | NS_LOG_FUNCTION (this); 104 | bool isGameOver = false; 105 | 106 | uint32_t collisionNum = 0; 107 | for (auto& v : m_collisions) 108 | collisionNum += v; 109 | 110 | if (collisionNum >= m_collisionTh){ 111 | isGameOver = true; 112 | } 113 | NS_LOG_UNCOND ("MyGetGameOver: " << isGameOver); 114 | return isGameOver; 115 | } 116 | 117 | Ptr 118 | MyGymEnv::GetObservation() 119 | { 120 | NS_LOG_FUNCTION (this); 121 | std::vector shape = {m_channelNum,}; 122 | Ptr > box = CreateObject >(shape); 123 | 124 | for (uint32_t i = 0; i < m_channelOccupation.size(); ++i) { 125 | uint32_t value = m_channelOccupation.at(i); 126 | box->AddValue(value); 127 | } 128 | 129 | NS_LOG_UNCOND ("MyGetObservation: " << box); 130 | return box; 131 | } 132 | 133 | float 134 | MyGymEnv::GetReward() 135 | { 136 | NS_LOG_FUNCTION (this); 137 | float reward = 1.0; 138 | if (m_channelOccupation.size() == 0){ 139 | return 0.0; 140 | } 141 | uint32_t occupied = m_channelOccupation.at(m_currentChannel); 142 | if (occupied == 1) { 143 | reward = -1.0; 144 | m_collisions.erase(m_collisions.begin()); 145 | m_collisions.push_back(1); 146 | } else { 147 | m_collisions.erase(m_collisions.begin()); 148 | m_collisions.push_back(0); 149 | } 150 | NS_LOG_UNCOND ("MyGetReward: " << reward); 151 | return reward; 152 | } 153 | 154 | std::string 155 | MyGymEnv::GetExtraInfo() 156 | { 157 | NS_LOG_FUNCTION (this); 158 | std::string myInfo = "info"; 159 | NS_LOG_UNCOND("MyGetExtraInfo: " << myInfo); 160 | return myInfo; 161 | } 162 | 163 | bool 164 | MyGymEnv::ExecuteActions(Ptr action) 165 | { 166 | NS_LOG_FUNCTION (this); 167 | Ptr discrete = DynamicCast(action); 168 | uint32_t nextChannel = discrete->GetValue(); 169 | m_currentChannel = nextChannel; 170 | 171 | NS_LOG_UNCOND ("Current Channel: " << m_currentChannel); 172 | return true; 173 | } 174 | 175 | void 176 | MyGymEnv::CollectChannelOccupation(uint32_t chanId, uint32_t occupied) 177 | { 178 | NS_LOG_FUNCTION (this); 
179 | m_channelOccupation.push_back(occupied); 180 | } 181 | 182 | bool 183 | MyGymEnv::CheckIfReady() 184 | { 185 | NS_LOG_FUNCTION (this); 186 | return m_channelOccupation.size() == m_channelNum; 187 | } 188 | 189 | void 190 | MyGymEnv::ClearObs() 191 | { 192 | NS_LOG_FUNCTION (this); 193 | m_channelOccupation.clear(); 194 | } 195 | 196 | void 197 | MyGymEnv::PerformCca (Ptr entity, uint32_t channelId, Ptr avgPowerSpectralDensity) 198 | { 199 | double power = Integral (*(avgPowerSpectralDensity)); 200 | double powerDbW = 10 * std::log10(power); 201 | double threshold = -60; 202 | uint32_t busy = powerDbW > threshold; 203 | NS_LOG_UNCOND("Channel: " << channelId << " CCA: " << busy << " RxPower: " << powerDbW); 204 | 205 | entity->CollectChannelOccupation(channelId, busy); 206 | 207 | if (entity->CheckIfReady()){ 208 | entity->Notify(); 209 | entity->ClearObs(); 210 | } 211 | } 212 | 213 | } // ns3 namespace -------------------------------------------------------------------------------- /examples/interference-pattern/mygym.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | 22 | #ifndef MY_GYM_ENTITY_H 23 | #define MY_GYM_ENTITY_H 24 | 25 | #include "ns3/stats-module.h" 26 | #include "ns3/opengym-module.h" 27 | #include "ns3/spectrum-module.h" 28 | 29 | namespace ns3 { 30 | 31 | class Node; 32 | class WifiMacQueue; 33 | class Packet; 34 | 35 | class MyGymEnv : public OpenGymEnv 36 | { 37 | public: 38 | MyGymEnv (); 39 | MyGymEnv (uint32_t channelNum); 40 | virtual ~MyGymEnv (); 41 | static TypeId GetTypeId (void); 42 | virtual void DoDispose (); 43 | 44 | Ptr GetActionSpace(); 45 | Ptr GetObservationSpace(); 46 | bool GetGameOver(); 47 | Ptr GetObservation(); 48 | float GetReward(); 49 | std::string GetExtraInfo(); 50 | bool ExecuteActions(Ptr action); 51 | 52 | // the function has to be static to work with MakeBoundCallback 53 | // that is why we pass pointer to MyGymEnv instance to be able to store the context (node, etc) 54 | static void PerformCca(Ptr entity, uint32_t channelId, Ptr avgPowerSpectralDensity); 55 | void CollectChannelOccupation(uint32_t chanId, uint32_t occupied); 56 | bool CheckIfReady(); 57 | void ClearObs(); 58 | 59 | private: 60 | void ScheduleNextStateRead(); 61 | Ptr GetQueue(Ptr node); 62 | bool SetCw(Ptr node, uint32_t cwMinValue=0, uint32_t cwMaxValue=0); 63 | 64 | Time m_interval = Seconds(0.1); 65 | Ptr m_currentNode; 66 | uint64_t m_rxPktNum; 67 | uint32_t m_channelNum; 68 | std::vector m_channelOccupation; 69 | uint32_t m_currentChannel; 70 | 71 | uint32_t m_collisionTh; 72 | std::vector m_collisions = {0,0,0,0,0,0,0,0,0,0,}; 73 | }; 74 | 75 | } 76 | 77 | 78 | #endif // MY_GYM_ENTITY_H 79 | -------------------------------------------------------------------------------- /examples/interference-pattern/simple_test.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import gym 5 | import argparse 6 | from ns3gym import ns3env 7 | 8 | __author__ = "Piotr Gawlowicz" 9 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 10 | __version__ = "0.1.0" 11 | __email__ = "gawlowicz@tkn.tu-berlin.de" 12 | 13 | 14 | env = gym.make('ns3-v0') 15 | env.reset() 16 | 17 | ob_space = env.observation_space 18 | ac_space = env.action_space 19 | print("Observation space: ", ob_space, ob_space.dtype) 20 | print("Action space: ", ac_space, ac_space.dtype) 21 | 22 | stepIdx = 0 23 | 24 | try: 25 | obs = env.reset() 26 | print("Step: ", stepIdx) 27 | print("---obs: ", obs) 28 | 29 | while True: 30 | stepIdx += 1 31 | 32 | action = env.action_space.sample() 33 | print("---action: ", action) 34 | obs, reward, done, info = env.step(action) 35 | 36 | print("Step: ", stepIdx) 37 | print("---obs, reward, done, info: ", obs, reward, done, info) 38 | 39 | if done: 40 | break 41 | 42 | except KeyboardInterrupt: 43 | print("Ctrl-C -> Exit") 44 | finally: 45 | env.close() 46 | print("Done") -------------------------------------------------------------------------------- /examples/linear-mesh-2/mygym.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | #include "mygym.h" 22 | #include "ns3/object.h" 23 | #include "ns3/core-module.h" 24 | #include "ns3/wifi-module.h" 25 | #include "ns3/node-list.h" 26 | #include "ns3/log.h" 27 | #include 28 | #include 29 | 30 | namespace ns3 { 31 | 32 | NS_LOG_COMPONENT_DEFINE ("MyGymEnv"); 33 | 34 | NS_OBJECT_ENSURE_REGISTERED (MyGymEnv); 35 | 36 | MyGymEnv::MyGymEnv () 37 | { 38 | NS_LOG_FUNCTION (this); 39 | m_currentNode = 0; 40 | m_rxPktNum = 0; 41 | } 42 | 43 | MyGymEnv::MyGymEnv (Time stepTime) 44 | { 45 | NS_LOG_FUNCTION (this); 46 | m_currentNode = 0; 47 | m_rxPktNum = 0; 48 | m_interval = stepTime; 49 | 50 | Simulator::Schedule (Seconds(0.0), &MyGymEnv::ScheduleNextStateRead, this); 51 | } 52 | 53 | void 54 | MyGymEnv::ScheduleNextStateRead () 55 | { 56 | NS_LOG_FUNCTION (this); 57 | Simulator::Schedule (m_interval, &MyGymEnv::ScheduleNextStateRead, this); 58 | Notify(); 59 | } 60 | 61 | MyGymEnv::~MyGymEnv () 62 | { 63 | NS_LOG_FUNCTION (this); 64 | } 65 | 66 | TypeId 67 | MyGymEnv::GetTypeId (void) 68 | { 69 | static TypeId tid = TypeId ("MyGymEnv") 70 | .SetParent () 71 | .SetGroupName ("OpenGym") 72 | .AddConstructor () 73 | ; 74 | return tid; 75 | } 76 | 77 | void 78 | MyGymEnv::DoDispose () 79 | { 80 | NS_LOG_FUNCTION (this); 81 | } 82 | 83 | Ptr 84 | MyGymEnv::GetActionSpace() 85 | { 86 | NS_LOG_FUNCTION (this); 87 | uint32_t nodeNum = NodeList::GetNNodes (); 88 | float low = 0.0; 89 | float high = 100.0; 90 | std::vector shape = {nodeNum,}; 91 | std::string dtype = TypeNameGet (); 92 | Ptr space = CreateObject (low, high, shape, dtype); 93 | NS_LOG_UNCOND ("GetActionSpace: " << space); 94 | return space; 95 | } 96 | 97 | Ptr 98 | MyGymEnv::GetObservationSpace() 99 | { 100 | 
NS_LOG_FUNCTION (this); 101 | uint32_t nodeNum = NodeList::GetNNodes (); 102 | float low = 0.0; 103 | float high = 100.0; 104 | std::vector shape = {nodeNum,}; 105 | std::string dtype = TypeNameGet (); 106 | Ptr space = CreateObject (low, high, shape, dtype); 107 | NS_LOG_UNCOND ("GetObservationSpace: " << space); 108 | return space; 109 | } 110 | 111 | bool 112 | MyGymEnv::GetGameOver() 113 | { 114 | NS_LOG_FUNCTION (this); 115 | bool isGameOver = false; 116 | NS_LOG_UNCOND ("MyGetGameOver: " << isGameOver); 117 | return isGameOver; 118 | } 119 | 120 | Ptr 121 | MyGymEnv::GetQueue(Ptr node) 122 | { 123 | Ptr dev = node->GetDevice (0); 124 | Ptr wifi_dev = DynamicCast (dev); 125 | Ptr wifi_mac = wifi_dev->GetMac (); 126 | Ptr rmac = DynamicCast (wifi_mac); 127 | PointerValue ptr; 128 | rmac->GetAttribute ("Txop", ptr); 129 | Ptr txop = ptr.Get (); 130 | Ptr queue = txop->GetWifiMacQueue (); 131 | return queue; 132 | } 133 | 134 | Ptr 135 | MyGymEnv::GetObservation() 136 | { 137 | NS_LOG_FUNCTION (this); 138 | uint32_t nodeNum = NodeList::GetNNodes (); 139 | std::vector shape = {nodeNum,}; 140 | Ptr > box = CreateObject >(shape); 141 | 142 | for (NodeList::Iterator i = NodeList::Begin (); i != NodeList::End (); ++i) { 143 | Ptr node = *i; 144 | Ptr queue = GetQueue (node); 145 | uint32_t value = queue->GetNPackets(); 146 | box->AddValue(value); 147 | } 148 | 149 | NS_LOG_UNCOND ("MyGetObservation: " << box); 150 | return box; 151 | } 152 | 153 | float 154 | MyGymEnv::GetReward() 155 | { 156 | NS_LOG_FUNCTION (this); 157 | static float lastValue = 0.0; 158 | float reward = m_rxPktNum - lastValue; 159 | lastValue = m_rxPktNum; 160 | NS_LOG_UNCOND ("MyGetReward: " << reward); 161 | return reward; 162 | } 163 | 164 | std::string 165 | MyGymEnv::GetExtraInfo() 166 | { 167 | NS_LOG_FUNCTION (this); 168 | std::string myInfo = "currentNodeId"; 169 | myInfo += "="; 170 | if (m_currentNode) { 171 | myInfo += std::to_string(m_currentNode->GetId()); 172 | } 173 | 
NS_LOG_UNCOND("MyGetExtraInfo: " << myInfo); 174 | return myInfo; 175 | } 176 | 177 | bool 178 | MyGymEnv::SetCw(Ptr node, uint32_t cwMinValue, uint32_t cwMaxValue) 179 | { 180 | Ptr dev = node->GetDevice (0); 181 | Ptr wifi_dev = DynamicCast (dev); 182 | Ptr wifi_mac = wifi_dev->GetMac (); 183 | Ptr rmac = DynamicCast (wifi_mac); 184 | PointerValue ptr; 185 | rmac->GetAttribute ("Txop", ptr); 186 | Ptr txop = ptr.Get (); 187 | 188 | // if both set to the same value then we have uniform backoff? 189 | if (cwMinValue != 0) { 190 | NS_LOG_DEBUG ("Set CW min: " << cwMinValue); 191 | txop->SetMinCw(cwMinValue); 192 | } 193 | 194 | if (cwMaxValue != 0) { 195 | NS_LOG_DEBUG ("Set CW max: " << cwMaxValue); 196 | txop->SetMaxCw(cwMaxValue); 197 | } 198 | return true; 199 | } 200 | 201 | bool 202 | MyGymEnv::ExecuteActions(Ptr action) 203 | { 204 | NS_LOG_FUNCTION (this); 205 | NS_LOG_UNCOND ("MyExecuteActions: " << action); 206 | Ptr > box = DynamicCast >(action); 207 | std::vector actionVector = box->GetData(); 208 | 209 | uint32_t nodeNum = NodeList::GetNNodes (); 210 | for (uint32_t i=0; i node = NodeList::GetNode(i); 213 | uint32_t cwSize = actionVector.at(i); 214 | SetCw(node, cwSize, cwSize); 215 | } 216 | 217 | return true; 218 | } 219 | 220 | void 221 | MyGymEnv::NotifyPktRxEvent(Ptr entity, Ptr node, Ptr packet) 222 | { 223 | NS_LOG_DEBUG ("Client received a packet of " << packet->GetSize () << " bytes"); 224 | entity->m_currentNode = node; 225 | entity->m_rxPktNum++; 226 | 227 | NS_LOG_UNCOND ("Node with ID " << entity->m_currentNode->GetId() << " received " << entity->m_rxPktNum << " packets"); 228 | 229 | entity->Notify(); 230 | } 231 | 232 | void 233 | MyGymEnv::CountRxPkts(Ptr entity, Ptr node, Ptr packet) 234 | { 235 | NS_LOG_DEBUG ("Client received a packet of " << packet->GetSize () << " bytes"); 236 | entity->m_currentNode = node; 237 | entity->m_rxPktNum++; 238 | } 239 | 240 | 241 | } // ns3 namespace 
-------------------------------------------------------------------------------- /examples/linear-mesh-2/mygym.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | 22 | #ifndef MY_GYM_ENTITY_H 23 | #define MY_GYM_ENTITY_H 24 | 25 | #include "ns3/stats-module.h" 26 | #include "ns3/opengym-module.h" 27 | 28 | namespace ns3 { 29 | 30 | class Node; 31 | class WifiMacQueue; 32 | class Packet; 33 | 34 | class MyGymEnv : public OpenGymEnv 35 | { 36 | public: 37 | MyGymEnv (); 38 | MyGymEnv (Time stepTime); 39 | virtual ~MyGymEnv (); 40 | static TypeId GetTypeId (void); 41 | virtual void DoDispose (); 42 | 43 | Ptr GetActionSpace(); 44 | Ptr GetObservationSpace(); 45 | bool GetGameOver(); 46 | Ptr GetObservation(); 47 | float GetReward(); 48 | std::string GetExtraInfo(); 49 | bool ExecuteActions(Ptr action); 50 | 51 | // the function has to be static to work with MakeBoundCallback 52 | // that is why we pass pointer to MyGymEnv instance to be able to store the context (node, etc) 53 | static void NotifyPktRxEvent(Ptr entity, Ptr node, Ptr packet); 54 | static void 
CountRxPkts(Ptr entity, Ptr node, Ptr packet); 55 | 56 | private: 57 | void ScheduleNextStateRead(); 58 | Ptr GetQueue(Ptr node); 59 | bool SetCw(Ptr node, uint32_t cwMinValue=0, uint32_t cwMaxValue=0); 60 | 61 | Time m_interval = Seconds(0.1); 62 | Ptr m_currentNode; 63 | uint64_t m_rxPktNum; 64 | 65 | }; 66 | 67 | } 68 | 69 | 70 | #endif // MY_GYM_ENTITY_H 71 | -------------------------------------------------------------------------------- /examples/linear-mesh-2/simple_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import gym 5 | import argparse 6 | from ns3gym import ns3env 7 | 8 | __author__ = "Piotr Gawlowicz" 9 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 10 | __version__ = "0.1.0" 11 | __email__ = "gawlowicz@tkn.tu-berlin.de" 12 | 13 | 14 | env = gym.make('ns3-v0') 15 | env.reset() 16 | 17 | ob_space = env.observation_space 18 | ac_space = env.action_space 19 | print("Observation space: ", ob_space, ob_space.dtype) 20 | print("Action space: ", ac_space, ac_space.dtype) 21 | 22 | stepIdx = 0 23 | 24 | try: 25 | obs = env.reset() 26 | print("Step: ", stepIdx) 27 | print("---obs: ", obs) 28 | 29 | while True: 30 | stepIdx += 1 31 | 32 | action = env.action_space.sample() 33 | print("---action: ", action) 34 | obs, reward, done, info = env.step(action) 35 | 36 | print("Step: ", stepIdx) 37 | print("---obs, reward, done, info: ", obs, reward, done, info) 38 | 39 | if done: 40 | break 41 | 42 | except KeyboardInterrupt: 43 | print("Ctrl-C -> Exit") 44 | finally: 45 | env.close() 46 | print("Done") -------------------------------------------------------------------------------- /examples/linear-mesh-2/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | from ns3gym import ns3env 6 | 7 | __author__ = "Piotr Gawlowicz" 
8 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 9 | __version__ = "0.1.0" 10 | __email__ = "gawlowicz@tkn.tu-berlin.de" 11 | 12 | 13 | parser = argparse.ArgumentParser(description='Start simulation script on/off') 14 | parser.add_argument('--start', 15 | type=int, 16 | default=1, 17 | help='Start ns-3 simulation script 0/1, Default: 1') 18 | parser.add_argument('--iterations', 19 | type=int, 20 | default=1, 21 | help='Number of iterations, Default: 1') 22 | args = parser.parse_args() 23 | startSim = bool(args.start) 24 | iterationNum = int(args.iterations) 25 | 26 | port = 5555 27 | simTime = 20 # seconds 28 | stepTime = 0.5 # seconds 29 | seed = 0 30 | simArgs = {"--simTime": simTime, 31 | "--testArg": 123} 32 | debug = False 33 | 34 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 35 | # simpler: 36 | #env = ns3env.Ns3Env() 37 | env.reset() 38 | 39 | ob_space = env.observation_space 40 | ac_space = env.action_space 41 | print("Observation space: ", ob_space, ob_space.dtype) 42 | print("Action space: ", ac_space, ac_space.dtype) 43 | 44 | stepIdx = 0 45 | currIt = 0 46 | 47 | try: 48 | while True: 49 | print("Start iteration: ", currIt) 50 | obs = env.reset() 51 | print("Step: ", stepIdx) 52 | print("---obs: ", obs) 53 | 54 | while True: 55 | stepIdx += 1 56 | action = env.action_space.sample() 57 | print("---action: ", action) 58 | 59 | print("Step: ", stepIdx) 60 | obs, reward, done, info = env.step(action) 61 | print("---obs, reward, done, info: ", obs, reward, done, info) 62 | 63 | if done: 64 | stepIdx = 0 65 | if currIt + 1 < iterationNum: 66 | env.reset() 67 | break 68 | 69 | currIt += 1 70 | if currIt == iterationNum: 71 | break 72 | 73 | except KeyboardInterrupt: 74 | print("Ctrl-C -> Exit") 75 | finally: 76 | env.close() 77 | print("Done") -------------------------------------------------------------------------------- /examples/linear-mesh/backpressure_v1.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import time 6 | import numpy as np 7 | from ns3gym import ns3env 8 | 9 | __author__ = "Piotr Gawlowicz" 10 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 11 | __version__ = "0.1.0" 12 | __email__ = "gawlowicz@tkn.tu-berlin.de" 13 | 14 | 15 | parser = argparse.ArgumentParser(description='Start simulation script on/off') 16 | parser.add_argument('--start', 17 | type=int, 18 | default=0, 19 | help='Start simulation script 0/1, Default: 0') 20 | parser.add_argument('--iterations', 21 | type=int, 22 | default=1, 23 | help='Number of iterations, Default: 1') 24 | args = parser.parse_args() 25 | startSim = bool(args.start) 26 | iterationNum = int(args.iterations) 27 | 28 | port = 5555 29 | simTime = 10 # seconds 30 | stepTime = 0.01 # seconds 31 | seed = 0 32 | simArgs = {"--simTime": simTime, 33 | "--testArg": 123, 34 | "--distance": 500} 35 | debug = False 36 | 37 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 38 | env.reset() 39 | 40 | ob_space = env.observation_space 41 | ac_space = env.action_space 42 | print("Observation space: ", ob_space, ob_space.dtype) 43 | print("Action space: ", ac_space, ac_space.dtype) 44 | 45 | stepIdx = 0 46 | currIt = 0 47 | allRxPkts = 0 48 | 49 | def calculate_cw_window(obs): 50 | diff = -np.diff(obs) 51 | maxDiff = np.argmax(diff) 52 | 53 | maxCw = 50 54 | action = np.ones(shape=len(obs), dtype=np.uint32) * maxCw 55 | action[maxDiff] = 5 56 | return action 57 | 58 | try: 59 | while True: 60 | obs = env.reset() 61 | reward = 0 62 | print("Start iteration: ", currIt) 63 | print("Step: ", stepIdx) 64 | print("---obs: ", obs) 65 | 66 | while True: 67 | stepIdx += 1 68 | 69 | allRxPkts += reward 70 | action = calculate_cw_window(obs) 71 | print("---action: ", action) 72 | 73 | obs, reward, 
done, info = env.step(action) 74 | print("Step: ", stepIdx) 75 | print("---obs, reward, done, info: ", obs, reward, done, info) 76 | 77 | if done: 78 | stepIdx = 0 79 | print("All rx pkts num: ", allRxPkts) 80 | allRxPkts = 0 81 | 82 | if currIt + 1 < iterationNum: 83 | env.reset() 84 | break 85 | 86 | currIt += 1 87 | if currIt == iterationNum: 88 | break 89 | 90 | except KeyboardInterrupt: 91 | print("Ctrl-C -> Exit") 92 | finally: 93 | env.close() 94 | print("Done") -------------------------------------------------------------------------------- /examples/linear-mesh/dqn-agent-v1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import scipy.io as io 5 | import gym 6 | import tensorflow as tf 7 | import tensorflow.contrib.slim as slim 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from tensorflow import keras 11 | from ns3gym import ns3env 12 | 13 | 14 | class DqnAgent(object): 15 | """docstring for DqnAgent""" 16 | def __init__(self, inNum, outNum): 17 | super(DqnAgent, self).__init__() 18 | self.model = keras.Sequential() 19 | self.model.add(keras.layers.Dense(inNum, input_shape=(inNum,), activation='relu')) 20 | self.model.add(keras.layers.Dense(outNum, activation='softmax')) 21 | self.model.compile(optimizer=tf.train.AdamOptimizer(0.001), 22 | loss='categorical_crossentropy', 23 | metrics=['accuracy']) 24 | 25 | def get_action(self, state): 26 | return np.argmax(self.model.predict(state)[0]) 27 | 28 | def predict(self, next_state): 29 | return self.model.predict(next_state)[0] 30 | 31 | def fit(self, state, target, action): 32 | target_f = self.model.predict(state) 33 | target_f[0][action] = target 34 | self.model.fit(state, target_f, epochs=1, verbose=0) 35 | 36 | 37 | # Environment initialization 38 | port = 5561 39 | simTime = 10 # seconds 40 | startSim = True 41 | stepTime = 0.05 # seconds 42 | seed = 132 43 | simArgs = {"--simTime": 
simTime, 44 | "--testArg": 123, 45 | "--nodeNum": 5, 46 | "--distance": 500} 47 | debug = False 48 | 49 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 50 | #env = gym.make('ns3-v0') 51 | 52 | ob_space = env.observation_space 53 | ac_space = env.action_space 54 | print("Observation space: ", ob_space, ob_space.dtype) 55 | print("Action space: ", ac_space, ac_space.dtype) 56 | s_size = ob_space.shape[0] 57 | a_size = ac_space.shape[0] 58 | 59 | inputQueues = 2 60 | cwSize = 100 61 | 62 | agent0 = DqnAgent(inputQueues, cwSize) 63 | agent1 = DqnAgent(inputQueues, cwSize) 64 | agent2 = DqnAgent(inputQueues, cwSize) 65 | agent3 = DqnAgent(inputQueues, cwSize) 66 | 67 | total_episodes = 50 68 | max_env_steps = 100 69 | env._max_episode_steps = max_env_steps 70 | 71 | epsilon = 1.0 # exploration rate 72 | epsilon_min = 0.01 73 | epsilon_decay = 0.999 74 | 75 | time_history = [] 76 | rew_history = [] 77 | 78 | for e in range(total_episodes): 79 | 80 | state = env.reset() 81 | state = np.reshape(state, [1, s_size]) 82 | rewardsum = 0 83 | for time in range(max_env_steps): 84 | 85 | # Choose action 86 | if np.random.rand(1) < epsilon: 87 | action0 = np.random.randint(cwSize) 88 | action1 = np.random.randint(cwSize) 89 | action2 = np.random.randint(cwSize) 90 | action3 = np.random.randint(cwSize) 91 | else: 92 | action0 = agent0.get_action(state[:,0:2]) 93 | action1 = agent1.get_action(state[:,1:3]) 94 | action2 = agent2.get_action(state[:,2:4]) 95 | action3 = agent3.get_action(state[:,3:5]) 96 | 97 | # Step 98 | actionVec = [action0, action1, action2, action3, 100] 99 | next_state, reward, done, _ = env.step(actionVec) 100 | 101 | if done: 102 | print("episode: {}/{}, time: {}, rew: {}, eps: {:.2}" 103 | .format(e, total_episodes, time, rewardsum, epsilon)) 104 | break 105 | 106 | next_state = np.reshape(next_state, [1, s_size]) 107 | 108 | # Train 109 | target0 = reward 110 | target1 = reward 111 | target2 = 
reward 112 | target3 = reward 113 | 114 | if not done: 115 | target0 = reward + 0.95 * np.amax(agent0.predict(next_state[:,0:2])) 116 | target1 = reward + 0.95 * np.amax(agent1.predict(next_state[:,1:3])) 117 | target2 = reward + 0.95 * np.amax(agent2.predict(next_state[:,2:4])) 118 | target3 = reward + 0.95 * np.amax(agent3.predict(next_state[:,3:5])) 119 | 120 | agent0.fit(state[:,0:2], target0, action0) 121 | agent1.fit(state[:,1:3], target1, action1) 122 | agent2.fit(state[:,2:4], target2, action2) 123 | agent3.fit(state[:,3:5], target3, action3) 124 | 125 | state = next_state 126 | rewardsum += reward 127 | if epsilon > epsilon_min: epsilon *= epsilon_decay 128 | 129 | time_history.append(time) 130 | rew_history.append(rewardsum) 131 | 132 | #for n in range(2 ** s_size): 133 | # state = [n >> i & 1 for i in range(0, 2)] 134 | # state = np.reshape(state, [1, s_size]) 135 | # print("state " + str(state) 136 | # + " -> prediction " + str(model.predict(state)[0]) 137 | # ) 138 | 139 | #print(model.get_config()) 140 | #print(model.to_json()) 141 | #print(model.get_weights()) 142 | 143 | plt.plot(range(len(time_history)), time_history) 144 | plt.plot(range(len(rew_history)), rew_history) 145 | plt.xlabel('Episode') 146 | plt.ylabel('Time') 147 | plt.show() 148 | 149 | 150 | curve0 = np.zeros(shape=(101,101)) 151 | curve1 = np.zeros(shape=(101,101)) 152 | curve2 = np.zeros(shape=(101,101)) 153 | curve3 = np.zeros(shape=(101,101)) 154 | 155 | for i in range(101): 156 | for j in range(101): 157 | state = np.array([i,j]) 158 | state = np.reshape(state, [1, 2]) 159 | 160 | curve0[i,j] = agent0.get_action(state) 161 | curve1[i,j] = agent1.get_action(state) 162 | curve2[i,j] = agent2.get_action(state) 163 | curve3[i,j] = agent3.get_action(state) 164 | 165 | print("Save curves to MATLAB file") 166 | io.savemat("curves_2d.mat", { 167 | '0':curve0, 168 | '1':curve1, 169 | '2':curve2, 170 | '3':curve3, 171 | } 172 | ) 
-------------------------------------------------------------------------------- /examples/linear-mesh/dqn-agent-v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import scipy.io as io 5 | import gym 6 | import tensorflow as tf 7 | import tensorflow.contrib.slim as slim 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from tensorflow import keras 11 | from ns3gym import ns3env 12 | 13 | 14 | class DqnAgent(object): 15 | """docstring for DqnAgent""" 16 | def __init__(self, inNum, outNum): 17 | super(DqnAgent, self).__init__() 18 | self.model = keras.Sequential() 19 | self.model.add(keras.layers.Dense(inNum, input_shape=(inNum,), activation='relu')) 20 | self.model.add(keras.layers.Dense(outNum, activation='softmax')) 21 | self.model.compile(optimizer=tf.train.AdamOptimizer(0.001), 22 | loss='categorical_crossentropy', 23 | metrics=['accuracy']) 24 | 25 | def get_action(self, state): 26 | return np.argmax(self.model.predict(state)[0]) 27 | 28 | def predict(self, next_state): 29 | return self.model.predict(next_state)[0] 30 | 31 | def fit(self, state, target, action): 32 | target_f = self.model.predict(state) 33 | target_f[0][action] = target 34 | self.model.fit(state, target_f, epochs=1, verbose=0) 35 | 36 | 37 | # Environment initialization 38 | port = 5562 39 | simTime = 10 # seconds 40 | startSim = True 41 | stepTime = 0.05 # seconds 42 | seed = 132 43 | simArgs = {"--simTime": simTime, 44 | "--testArg": 123, 45 | "--nodeNum": 5, 46 | "--distance": 500} 47 | debug = False 48 | 49 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 50 | #env = gym.make('ns3-v0') 51 | 52 | 53 | ob_space = env.observation_space 54 | ac_space = env.action_space 55 | print("Observation space: ", ob_space, ob_space.dtype) 56 | print("Action space: ", ac_space, ac_space.dtype) 57 | s_size = ob_space.shape[0] 58 
| a_size = ac_space.shape[0] 59 | 60 | inputQueues = 1 61 | cwSize = 100 62 | 63 | agent0 = DqnAgent(inputQueues, cwSize) 64 | agent1 = DqnAgent(inputQueues, cwSize) 65 | agent2 = DqnAgent(inputQueues, cwSize) 66 | agent3 = DqnAgent(inputQueues, cwSize) 67 | 68 | total_episodes = 50 69 | max_env_steps = 100 70 | env._max_episode_steps = max_env_steps 71 | 72 | epsilon = 1.0 # exploration rate 73 | epsilon_min = 0.01 74 | epsilon_decay = 0.999 75 | 76 | time_history = [] 77 | rew_history = [] 78 | 79 | for e in range(total_episodes): 80 | 81 | state = env.reset() 82 | state = np.reshape(state, [1, s_size]) 83 | rewardsum = 0 84 | for time in range(max_env_steps): 85 | 86 | # Choose action 87 | if np.random.rand(1) < epsilon: 88 | action0 = np.random.randint(cwSize) 89 | action1 = np.random.randint(cwSize) 90 | action2 = np.random.randint(cwSize) 91 | action3 = np.random.randint(cwSize) 92 | else: 93 | action0 = agent0.get_action(state[:,0]-state[:,1]) 94 | action1 = agent1.get_action(state[:,1]-state[:,2]) 95 | action2 = agent2.get_action(state[:,2]-state[:,3]) 96 | action3 = agent3.get_action(state[:,3]-state[:,4]) 97 | 98 | # Step 99 | actionVec = [action0, action1, action2, action3, 100] 100 | next_state, reward, done, _ = env.step(actionVec) 101 | 102 | if done: 103 | print("episode: {}/{}, time: {}, rew: {}, eps: {:.2}" 104 | .format(e, total_episodes, time, rewardsum, epsilon)) 105 | break 106 | 107 | next_state = np.reshape(next_state, [1, s_size]) 108 | 109 | # Train 110 | target0 = reward 111 | target1 = reward 112 | target2 = reward 113 | target3 = reward 114 | 115 | if not done: 116 | target0 = reward + 0.95 * np.amax(agent0.predict(next_state[:,0]-next_state[:,1])) 117 | target1 = reward + 0.95 * np.amax(agent1.predict(next_state[:,1]-next_state[:,2])) 118 | target2 = reward + 0.95 * np.amax(agent2.predict(next_state[:,2]-next_state[:,3])) 119 | target3 = reward + 0.95 * np.amax(agent3.predict(next_state[:,3]-next_state[:,4])) 120 | 121 | 
agent0.fit(state[:,0]-state[:,1], target0, action0) 122 | agent1.fit(state[:,1]-state[:,2], target1, action1) 123 | agent2.fit(state[:,2]-state[:,3], target2, action2) 124 | agent3.fit(state[:,3]-state[:,4], target3, action3) 125 | 126 | state = next_state 127 | rewardsum += reward 128 | if epsilon > epsilon_min: epsilon *= epsilon_decay 129 | 130 | time_history.append(time) 131 | rew_history.append(rewardsum) 132 | 133 | #for n in range(2 ** s_size): 134 | # state = [n >> i & 1 for i in range(0, 2)] 135 | # state = np.reshape(state, [1, s_size]) 136 | # print("state " + str(state) 137 | # + " -> prediction " + str(model.predict(state)[0]) 138 | # ) 139 | 140 | #print(model.get_config()) 141 | #print(model.to_json()) 142 | #print(model.get_weights()) 143 | 144 | plt.plot(range(len(time_history)), time_history) 145 | plt.plot(range(len(rew_history)), rew_history) 146 | plt.xlabel('Episode') 147 | plt.ylabel('Time') 148 | plt.show() 149 | 150 | 151 | curve0 = np.zeros(shape=(101)) 152 | curve1 = np.zeros(shape=(101)) 153 | curve2 = np.zeros(shape=(101)) 154 | curve3 = np.zeros(shape=(101)) 155 | 156 | for i in range(101): 157 | state = np.array([i]) 158 | state = np.reshape(state, [1, 1]) 159 | 160 | curve0[i] = agent0.get_action(state) 161 | curve1[i] = agent1.get_action(state) 162 | curve2[i] = agent2.get_action(state) 163 | curve3[i] = agent3.get_action(state) 164 | 165 | print("Save curves to MATLAB file") 166 | io.savemat("curves_1d.mat", { 167 | '0':curve0, 168 | '1':curve1, 169 | '2':curve2, 170 | '3':curve3, 171 | } 172 | ) -------------------------------------------------------------------------------- /examples/linear-mesh/my_random.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | from ns3gym import ns3env 6 | 7 | __author__ = "Piotr Gawlowicz" 8 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 9 | __version__ = "0.1.0" 
10 | __email__ = "gawlowicz@tkn.tu-berlin.de" 11 | 12 | 13 | parser = argparse.ArgumentParser(description='Start simulation script on/off') 14 | parser.add_argument('--start', 15 | type=int, 16 | default=0, 17 | help='Start simulation script 0/1, Default: 0') 18 | parser.add_argument('--iterations', 19 | type=int, 20 | default=1, 21 | help='Number of iterations, Default: 1') 22 | args = parser.parse_args() 23 | startSim = bool(args.start) 24 | iterationNum = int(args.iterations) 25 | 26 | port = 5550 27 | simTime = 10 # seconds 28 | stepTime = 0.01 # seconds 29 | seed = 0 30 | simArgs = {"--simTime": simTime, 31 | "--testArg": 123, 32 | "--nodeNum": 5, 33 | "--distance": 500} 34 | debug = False 35 | 36 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 37 | env.reset() 38 | 39 | ob_space = env.observation_space 40 | ac_space = env.action_space 41 | print("Observation space: ", ob_space, ob_space.dtype) 42 | print("Action space: ", ac_space, ac_space.dtype) 43 | 44 | stepIdx = 0 45 | currIt = 0 46 | allRxPkts = 0 47 | 48 | try: 49 | while True: 50 | print("Start iteration: ", currIt) 51 | obs = env.reset() 52 | reward = 0 53 | print("Step: ", stepIdx) 54 | print("---obs: ", obs) 55 | 56 | while True: 57 | stepIdx += 1 58 | 59 | allRxPkts += reward 60 | action = env.action_space.sample() 61 | action = action * 1 + 1 62 | print("---action: ", action) 63 | 64 | obs, reward, done, info = env.step(action) 65 | print("Step: ", stepIdx) 66 | print("---obs, reward, done, info: ", obs, reward, done, info) 67 | 68 | if done: 69 | stepIdx = 0 70 | print("All rx pkts num: ", allRxPkts) 71 | allRxPkts = 0 72 | 73 | if currIt + 1 < iterationNum: 74 | env.reset() 75 | break 76 | 77 | currIt += 1 78 | if currIt == iterationNum: 79 | break 80 | 81 | except KeyboardInterrupt: 82 | print("Ctrl-C -> Exit") 83 | finally: 84 | env.close() 85 | print("Done") 
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File: /examples/linear-mesh/my_random2.py
"""Random-action baseline for the linear-mesh example.

Runs ``episodes`` episodes against the ns-3 gym environment, sampling an
action from the action space at every step, and accumulates the reward per
episode.  Episodes whose total reward is zero are left out of the
statistics.  Afterwards the mean/std of the collected rewards is printed
and the per-chunk averages are plotted.
"""

import time
import gym
import os
import numpy as np
from ns3gym import ns3env
import matplotlib.pyplot as plt

# Environment initialization
port = 5550
simTime = 20  # seconds
startSim = True
stepTime = 0.1  # seconds
seed = 0
simArgs = {"--simTime": simTime,
           "--testArg": 123,
           "--nodeNum": 5,
           "--distance": 500}
debug = False

env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug)

rewards = []
iterations = []
episodes = 10

# Always close the environment, even when an episode raises or the user
# interrupts the run; the sibling scripts in this directory do the same.
try:
    for episode in range(episodes):
        # Refresh state
        state = env.reset()
        done = False
        t_reward = 0

        i = 0
        # Run episode until the environment reports it is over.
        while not done:
            i += 1
            # NOTE(review): the "* 5 + 1" scaling presumably keeps every
            # value non-zero (no_op.py describes a 0 value as "not applied")
            # -- confirm against the simulation script.
            action = env.action_space.sample() * 5 + 1
            state, reward, done, info = env.step(action)
            t_reward += reward

        print(episode, " Total reward:", t_reward)
        # Episodes with a zero total reward are skipped in the statistics.
        if t_reward:
            rewards.append(t_reward)
            iterations.append(i)
finally:
    # Close environment
    env.close()


# Plot results
def chunks_func(l, n):
    """Yield successive slices of *l* holding at most *n* items each.

    *n* is clamped to at least 1 so the step of ``range`` is never zero.
    """
    n = max(1, n)
    # range, not xrange: the script targets Python 3 (see shebang).
    return (l[i:i+n] for i in range(0, len(l), n))


size = episodes
rewards = np.array(rewards)
print("mean value: ", np.mean(rewards), " std:", np.std(rewards),)
chunks = np.array_split(rewards, size)
# array_split pads with empty chunks when fewer than `size` rewards were
# kept; skip those so the average never divides by zero.
averages = [sum(chunk) / len(chunk) for chunk in chunks if len(chunk)]


plt.plot(averages)
plt.xlabel('Episode')
plt.ylabel('Average Reward')
plt.show()
-------------------------------------------------------------------------------- /examples/linear-mesh/no_op.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import time 6 | import numpy as np 7 | from ns3gym import ns3env 8 | 9 | __author__ = "Piotr Gawlowicz" 10 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 11 | __version__ = "0.1.0" 12 | __email__ = "gawlowicz@tkn.tu-berlin.de" 13 | 14 | 15 | parser = argparse.ArgumentParser(description='Start simulation script on/off') 16 | parser.add_argument('--start', 17 | type=int, 18 | default=0, 19 | help='Start simulation script 0/1, Default: 0') 20 | parser.add_argument('--iterations', 21 | type=int, 22 | default=1, 23 | help='Number of iterations, Default: 1') 24 | args = parser.parse_args() 25 | startSim = bool(args.start) 26 | iterationNum = int(args.iterations) 27 | 28 | port = 5551 29 | simTime = 10 # seconds 30 | stepTime = 0.1 # seconds 31 | seed = 0 32 | simArgs = {"--simTime": simTime, 33 | "--testArg": 123, 34 | "--nodeNum": 5, 35 | "--distance": 500} 36 | debug = False 37 | 38 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 39 | env.reset() 40 | 41 | ob_space = env.observation_space 42 | ac_space = env.action_space 43 | print("Observation space: ", ob_space, ob_space.dtype) 44 | print("Action space: ", ac_space, ac_space.dtype) 45 | 46 | stepIdx = 0 47 | currIt = 0 48 | allRxPkts = 0 49 | 50 | def get_no_op_action(obs): 51 | # cwValue 0 is not applied, so no_op 52 | action = np.zeros(shape=len(obs), dtype=np.uint32) 53 | return action 54 | 55 | try: 56 | while True: 57 | print("Start iteration: ", currIt) 58 | obs = env.reset() 59 | reward = 0 60 | print("Step: ", stepIdx) 61 | print("---obs: ", obs) 62 | 63 | while True: 64 | stepIdx += 1 65 | 66 | allRxPkts += reward 67 | action = 
get_no_op_action(obs) 68 | print("---action: ", action) 69 | 70 | obs, reward, done, info = env.step(action) 71 | print("Step: ", stepIdx) 72 | print("---obs, reward, done, info: ", obs, reward, done, info) 73 | 74 | if done: 75 | stepIdx = 0 76 | print("All rx pkts num: ", allRxPkts) 77 | allRxPkts = 0 78 | 79 | if currIt + 1 < iterationNum: 80 | env.reset() 81 | break 82 | 83 | currIt += 1 84 | if currIt == iterationNum: 85 | break 86 | 87 | except KeyboardInterrupt: 88 | print("Ctrl-C -> Exit") 89 | finally: 90 | env.close() 91 | print("Done") -------------------------------------------------------------------------------- /examples/linear-mesh/no_op2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import time 5 | import gym 6 | import os 7 | import numpy as np 8 | from ns3gym import ns3env 9 | import matplotlib.pyplot as plt 10 | 11 | # Environment initialization 12 | port = 5551 13 | simTime = 10 # seconds 14 | startSim = True 15 | stepTime = 0.1 # seconds 16 | seed = 0 17 | simArgs = {"--simTime": simTime, 18 | "--testArg": 123, 19 | "--nodeNum": 5, 20 | "--distance": 500} 21 | debug = False 22 | 23 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 24 | 25 | rewards = [] 26 | iterations = [] 27 | episodes = 10 28 | 29 | def get_no_op_action(obs): 30 | # cwValue 0 is not applied, so no_op 31 | action = np.zeros(shape=len(obs), dtype=np.uint32) 32 | return action 33 | 34 | # Episodes 35 | for episode in range(episodes): 36 | # Refresh state 37 | state = env.reset() 38 | done = False 39 | t_reward = 0 40 | 41 | i = 0 42 | # Run episode 43 | while True: 44 | if done: 45 | break 46 | 47 | i += 1 48 | action = get_no_op_action(state) 49 | state, reward, done, info = env.step(action) 50 | t_reward += reward 51 | 52 | print(episode, " Total reward:", t_reward) 53 | if t_reward: 54 | 
rewards.append(t_reward) 55 | iterations.append(i) 56 | 57 | # Close environment 58 | env.close() 59 | 60 | # Plot results 61 | def chunks_func(l, n): 62 | n = max(1, n) 63 | return (l[i:i+n] for i in xrange(0, len(l), n)) 64 | 65 | size = episodes 66 | #chunks = list(chunk_list(rewards, size)) 67 | rewards = np.array(rewards) 68 | print("mean value: ", np.mean(rewards), " std:", np.std(rewards),) 69 | #chunks = np.array_split(rewards, size) 70 | #chunks = chunks_func(rewards, size) 71 | #averages = [sum(chunk) / len(chunk) for chunk in chunks] 72 | 73 | 74 | plt.plot(rewards) 75 | plt.xlabel('Episode') 76 | plt.ylabel('Average Reward') 77 | plt.show() -------------------------------------------------------------------------------- /examples/linear-mesh/qfull.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import time 5 | import gym 6 | import os 7 | import numpy as np 8 | from ns3gym import ns3env 9 | import matplotlib.pyplot as plt 10 | 11 | # Environment initialization 12 | port = 5553 13 | simTime = 20 # seconds 14 | startSim = True 15 | stepTime = 0.1 # seconds 16 | seed = 122 17 | simArgs = {"--simTime": simTime, 18 | "--testArg": 123, 19 | "--nodeNum": 5, 20 | "--distance": 500} 21 | debug = False 22 | 23 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 24 | 25 | # Q and rewards 26 | stateNum = 11 27 | actionNum = 10 28 | factor = np.uint(100/(stateNum-1)) 29 | print(stateNum, actionNum, factor) 30 | Q0 = np.zeros(shape=(stateNum, stateNum, stateNum, actionNum), dtype=np.float) 31 | Q1 = np.zeros(shape=(stateNum, stateNum, stateNum, actionNum), dtype=np.float) 32 | Q2 = np.zeros(shape=(stateNum, stateNum, stateNum, actionNum), dtype=np.float) 33 | Q3 = np.zeros(shape=(stateNum, stateNum, stateNum, actionNum), dtype=np.float) 34 | 35 | sigma = 0.1 36 | Q0 = np.random.randn(stateNum, stateNum, 
stateNum, actionNum) * sigma 37 | Q1 = np.random.randn(stateNum, stateNum, stateNum, actionNum) * sigma 38 | Q2 = np.random.randn(stateNum, stateNum, stateNum, actionNum) * sigma 39 | Q3 = np.random.randn(stateNum, stateNum, stateNum, actionNum) * sigma 40 | 41 | rewards = [] 42 | iterations = [] 43 | 44 | # Parameters 45 | alpha = 0.2 46 | discount = 0.6 47 | episodes = 25 48 | 49 | learning = True 50 | 51 | def calculate_q_diff(obs): 52 | obs = np.array(obs) 53 | diff = -np.diff(obs) 54 | diff[diff<0] = 0 55 | return diff 56 | 57 | # Episodes 58 | for episode in range(episodes): 59 | # Refresh state 60 | if episode == 15: 61 | env.simSeed = 0 # random seed 62 | learning = False 63 | 64 | state = env.reset() 65 | state = np.uint(np.array(state, dtype=np.uint32) / factor) 66 | state = state[1:-1] 67 | done = False 68 | t_reward = 0 69 | 70 | i = 0 71 | # Run episode 72 | while True: 73 | if done: 74 | break 75 | 76 | i += 1 77 | current = state 78 | divider = episode 79 | [a, b, c] = current 80 | if learning: 81 | tmpQ = Q0[a,b,c, :] + np.random.randn(actionNum) * (1 / float(divider + 1)) 82 | action0 = np.argmax(tmpQ) 83 | tmpQ = Q1[a,b,c, :] + np.random.randn(actionNum) * (1 / float(divider + 1)) 84 | action1 = np.argmax(tmpQ) 85 | tmpQ = Q2[a,b,c, :] + np.random.randn(actionNum) * (1 / float(divider + 1)) 86 | action2 = np.argmax(tmpQ) 87 | tmpQ = Q3[a,b,c, :] + np.random.randn(actionNum) * (1 / float(divider + 1)) 88 | action3 = np.argmax(tmpQ) 89 | else: 90 | action0 = np.argmax(Q0[a,b,c, :]) 91 | action1 = np.argmax(Q1[a,b,c, :]) 92 | action2 = np.argmax(Q2[a,b,c, :]) 93 | action3 = np.argmax(Q3[a,b,c, :]) 94 | 95 | action0 = np.unravel_index(action0, Q0.shape)[-1] 96 | action1 = np.unravel_index(action1, Q1.shape)[-1] 97 | action2 = np.unravel_index(action2, Q2.shape)[-1] 98 | action3 = np.unravel_index(action3, Q3.shape)[-1] 99 | 100 | action = np.array([action0, action1, action2, action3, 0], dtype=np.uint) 101 | action = action * 5*factor + 1 102 | 
print(action) 103 | 104 | state, reward, done, info = env.step(action) 105 | state = np.uint(np.array(state, dtype=np.uint32) / factor) 106 | state = state[1:-1] 107 | 108 | t_reward += reward 109 | [x,y,z] = state 110 | if learning: 111 | Q0[a,b,c, action0] += alpha * (reward + discount * np.max(Q0[x,y,z, :]) - Q0[a,b,c, action0]) 112 | Q1[a,b,c, action1] += alpha * (reward + discount * np.max(Q1[x,y,z, :]) - Q1[a,b,c, action1]) 113 | Q2[a,b,c, action2] += alpha * (reward + discount * np.max(Q2[x,y,z, :]) - Q2[a,b,c, action2]) 114 | Q3[a,b,c, action3] += alpha * (reward + discount * np.max(Q3[x,y,z, :]) - Q3[a,b,c, action3]) 115 | print(np.max(Q1), reward) 116 | 117 | 118 | print(episode, " Total reward:", t_reward) 119 | rewards.append(t_reward) 120 | iterations.append(i) 121 | 122 | # Close environment 123 | env.close() 124 | 125 | # Plot results 126 | def chunks_func(l, n): 127 | n = max(1, n) 128 | return (l[i:i+n] for i in xrange(0, len(l), n)) 129 | 130 | size = episodes 131 | #chunks = list(chunk_list(rewards, size)) 132 | rewards = np.array(rewards) 133 | print("mean value: ", np.mean(rewards[10:]), " std:", np.std(rewards[10:]),) 134 | chunks = np.array_split(rewards, size) 135 | #chunks = chunks_func(rewards, size) 136 | averages = [sum(chunk) / len(chunk) for chunk in chunks] 137 | 138 | 139 | plt.plot(averages) 140 | plt.xlabel('Episode') 141 | plt.ylabel('Average Reward') 142 | plt.show() 143 | 144 | #print("Q0",Q0) 145 | #print("Q1",Q1) 146 | #print("Q2",Q2) 147 | #print("Q3",Q3) -------------------------------------------------------------------------------- /examples/linear-mesh/qfull_working.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import time 5 | import gym 6 | import os 7 | import numpy as np 8 | from ns3gym import ns3env 9 | import matplotlib.pyplot as plt 10 | 11 | # Environment initialization 12 | port = 5553 13 | simTime = 10 # seconds 
14 | startSim = True 15 | stepTime = 0.1 # seconds 16 | seed = 122 17 | simArgs = {"--simTime": simTime, 18 | "--testArg": 123, 19 | "--nodeNum": 5, 20 | "--distance": 500} 21 | debug = False 22 | 23 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 24 | 25 | # Q and rewards 26 | Q0 = np.zeros(shape=(5, 5, 5, 4), dtype=np.float) 27 | Q1 = np.zeros(shape=(5, 5, 5, 4), dtype=np.float) 28 | Q2 = np.zeros(shape=(5, 5, 5, 4), dtype=np.float) 29 | Q3 = np.zeros(shape=(5, 5, 5, 4), dtype=np.float) 30 | 31 | rewards = [] 32 | iterations = [] 33 | 34 | # Parameters 35 | alpha = 0.75 36 | discount = 0.95 37 | episodes = 10 38 | 39 | def calculate_q_diff(obs): 40 | obs = np.array(obs) 41 | diff = -np.diff(obs) 42 | diff[diff<0] = 0 43 | return diff 44 | 45 | # Episodes 46 | for episode in range(episodes): 47 | # Refresh state 48 | state = env.reset() 49 | state = np.uint(np.array(state, dtype=np.uint32) / 25) 50 | state = state[1:-1] 51 | done = False 52 | t_reward = 0 53 | 54 | i = 0 55 | # Run episode 56 | while True: 57 | if done: 58 | break 59 | 60 | i += 1 61 | current = state 62 | divider = episode 63 | [a, b, c] = current 64 | action0 = np.argmax(Q0[a,b,c, :] + np.random.randn(1, 4) * (1 / float(divider + 1))) 65 | action1 = np.argmax(Q1[a,b,c, :] + np.random.randn(1, 4) * (1 / float(divider + 1))) 66 | action2 = np.argmax(Q2[a,b,c, :] + np.random.randn(1, 4) * (1 / float(divider + 1))) 67 | action3 = np.argmax(Q3[a,b,c, :] + np.random.randn(1, 4) * (1 / float(divider + 1))) 68 | 69 | action0 = np.unravel_index(action0, Q0.shape)[-1] 70 | action1 = np.unravel_index(action1, Q1.shape)[-1] 71 | action2 = np.unravel_index(action2, Q2.shape)[-1] 72 | action3 = np.unravel_index(action3, Q3.shape)[-1] 73 | 74 | action = np.array([action0, action1, action2, action3, 0], dtype=np.uint) 75 | action = action * 25 + 5 76 | #print(action) 77 | 78 | state, reward, done, info = env.step(action) 79 | state = 
np.uint(np.array(state, dtype=np.uint32) / 25) 80 | #state = calculate_q_diff(state) 81 | state = state[1:-1] 82 | 83 | t_reward += reward 84 | [x,y,z] = state 85 | Q0[a,b,c, action0] += alpha * (reward + discount * np.max(Q0[x,y,z, :]) - Q0[a,b,c, action0]) 86 | Q1[a,b,c, action1] += alpha * (reward + discount * np.max(Q1[x,y,z, :]) - Q1[a,b,c, action1]) 87 | Q2[a,b,c, action2] += alpha * (reward + discount * np.max(Q2[x,y,z, :]) - Q2[a,b,c, action2]) 88 | Q3[a,b,c, action3] += alpha * (reward + discount * np.max(Q3[x,y,z, :]) - Q3[a,b,c, action3]) 89 | 90 | 91 | print("Total reward:", t_reward) 92 | rewards.append(t_reward) 93 | iterations.append(i) 94 | 95 | # Close environment 96 | env.close() 97 | 98 | # Plot results 99 | def chunks_func(l, n): 100 | n = max(1, n) 101 | return (l[i:i+n] for i in xrange(0, len(l), n)) 102 | 103 | size = episodes 104 | #chunks = list(chunk_list(rewards, size)) 105 | rewards = np.array(rewards) 106 | chunks = np.array_split(rewards, size) 107 | #chunks = chunks_func(rewards, size) 108 | averages = [sum(chunk) / len(chunk) for chunk in chunks] 109 | 110 | plt.plot(averages) 111 | plt.xlabel('Episode') 112 | plt.ylabel('Average Reward') 113 | plt.show() 114 | 115 | print("Q0",Q0) 116 | print("Q1",Q1) 117 | print("Q2",Q2) 118 | print("Q3",Q3) -------------------------------------------------------------------------------- /examples/linear-mesh/qlearn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import time 5 | import gym 6 | import os 7 | import numpy as np 8 | from ns3gym import ns3env 9 | import matplotlib.pyplot as plt 10 | 11 | # Environment initialization 12 | port = 5552 13 | simTime = 10 # seconds 14 | startSim = True 15 | stepTime = 0.1 # seconds 16 | seed = 0 17 | simArgs = {"--simTime": simTime, 18 | "--testArg": 123, 19 | "--nodeNum": 5, 20 | "--distance": 500} 21 | debug = False 22 | 23 | env = ns3env.Ns3Env(port=port, 
stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 24 | 25 | # Q and rewards 26 | Q = np.zeros(shape=(5, 11, 10), dtype=np.float) 27 | action = np.zeros(shape=(5), dtype=np.uint) 28 | 29 | rewards = [] 30 | iterations = [] 31 | 32 | # Parameters 33 | alpha = 0.75 34 | discount = 0.95 35 | episodes = 10 36 | 37 | # Episodes 38 | for episode in range(episodes): 39 | # Refresh state 40 | state = env.reset() 41 | #print('state:', state) 42 | state = np.uint(np.array(state, dtype=np.uint32) / 10) 43 | done = False 44 | t_reward = 0 45 | 46 | i = 0 47 | # Run episode 48 | while True: 49 | if done: 50 | break 51 | 52 | i += 1 53 | current = state 54 | for n in range(5): 55 | action[n] = np.argmax(Q[n, current[n], :] + np.random.randn(1, 10) * (1 / float(episode + 1))) 56 | 57 | saction = np.uint(action * 100) + 1 58 | #print("action", saction) 59 | state, reward, done, info = env.step(saction) 60 | #print('state:', state, reward, done) 61 | state = np.uint(np.array(state) / 10 ) 62 | 63 | t_reward += reward 64 | for n in range(5): 65 | Q[n, current[n], action[n]] += alpha * (reward + discount * np.max(Q[n, state[n], :]) - Q[n, current[n], action[n]]) 66 | 67 | print("Total reward:", t_reward) 68 | rewards.append(t_reward) 69 | iterations.append(i) 70 | 71 | # Close environment 72 | env.close() 73 | 74 | # Plot results 75 | def chunks_func(l, n): 76 | n = max(1, n) 77 | return (l[i:i+n] for i in xrange(0, len(l), n)) 78 | 79 | size = episodes 80 | #chunks = list(chunk_list(rewards, size)) 81 | rewards = np.array(rewards) 82 | chunks = np.array_split(rewards, size) 83 | #chunks = chunks_func(rewards, size) 84 | averages = [sum(chunk) / len(chunk) for chunk in chunks] 85 | 86 | plt.plot(averages) 87 | plt.xlabel('Episode') 88 | plt.ylabel('Average Reward') 89 | plt.show() -------------------------------------------------------------------------------- /examples/linear-mesh/qlearn_full.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import time 5 | import gym 6 | import os 7 | import numpy as np 8 | from ns3gym import ns3env 9 | import matplotlib.pyplot as plt 10 | 11 | # Environment initialization 12 | port = 5556 13 | simTime = 10 # seconds 14 | startSim = True 15 | stepTime = 0.005 # seconds 16 | seed = 0 17 | simArgs = {"--simTime": simTime, 18 | "--testArg": 123, 19 | "--distance": 20} 20 | debug = False 21 | 22 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 23 | 24 | # Q and rewards 25 | Q = np.zeros(shape=(5, 101, 100), dtype=np.float) 26 | action = np.zeros(shape=(5), dtype=np.uint) 27 | 28 | rewards = [] 29 | iterations = [] 30 | 31 | # Parameters 32 | alpha = 0.75 33 | discount = 0.95 34 | episodes = 1000 35 | 36 | def calculate_q_diff(obs): 37 | obs = np.array(obs) 38 | diff = -np.diff(obs) 39 | obs[:-1] = diff 40 | obs[-1] = 0 41 | obs[obs<0] = 0 42 | return obs 43 | 44 | # Episodes 45 | for episode in range(episodes): 46 | # Refresh state 47 | state = env.reset() 48 | 49 | #state = calculate_q_diff(state) 50 | done = False 51 | t_reward = 0 52 | 53 | i = 0 54 | # Run episode 55 | while True: 56 | if done: 57 | break 58 | 59 | i += 1 60 | current = state 61 | for n in range(5): 62 | action[n] = np.argmax(Q[n, current[n], :] + np.random.randn(1, 100) * (1 / float(episode + 1))) 63 | 64 | state, reward, done, info = env.step(action) 65 | #state = calculate_q_diff(state) 66 | #print('state:', state, reward, done) 67 | #print("action:", action) 68 | 69 | t_reward += reward 70 | for n in range(5): 71 | Q[n, current[n], action[n]] += alpha * (reward + discount * np.max(Q[n, state[n], :]) - Q[n, current[n], action[n]]) 72 | 73 | print("Total reward:", t_reward) 74 | rewards.append(t_reward) 75 | iterations.append(i) 76 | 77 | # Close environment 78 | env.close() 79 | 80 | # Plot 
results 81 | def chunks_func(l, n): 82 | n = max(1, n) 83 | return (l[i:i+n] for i in xrange(0, len(l), n)) 84 | 85 | size = 5 86 | #chunks = list(chunk_list(rewards, size)) 87 | rewards = np.array(rewards) 88 | chunks = np.array_split(rewards, size) 89 | #chunks = chunks_func(rewards, size) 90 | averages = [sum(chunk) / len(chunk) for chunk in chunks] 91 | 92 | plt.plot(averages) 93 | plt.xlabel('Episode') 94 | plt.ylabel('Average Reward') 95 | plt.show() 96 | 97 | print(Q) -------------------------------------------------------------------------------- /examples/multi-agent/agent1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | from ns3gym import ns3env 6 | 7 | __author__ = "Piotr Gawlowicz" 8 | __copyright__ = "Copyright (c) 2020, Technische Universität Berlin" 9 | __version__ = "0.1.0" 10 | __email__ = "gawlowicz@tkn.tu-berlin.de" 11 | 12 | 13 | port = 5555 14 | env = ns3env.Ns3Env(port=port, startSim=False) 15 | env.reset() 16 | 17 | ob_space = env.observation_space 18 | ac_space = env.action_space 19 | print("Observation space: ", ob_space, ob_space.dtype) 20 | print("Action space: ", ac_space, ac_space.dtype) 21 | 22 | 23 | stepIdx = 0 24 | currIt = 0 25 | iterationNum = 3 26 | 27 | try: 28 | while True: 29 | obs = env.reset() 30 | print("Step: ", stepIdx) 31 | print("---obs: ", obs) 32 | 33 | while True: 34 | stepIdx += 1 35 | action = env.action_space.sample() 36 | print("---action: ", action) 37 | 38 | print("Step: ", stepIdx) 39 | obs, reward, done, info = env.step(action) 40 | print("---obs, reward, done, info: ", obs, reward, done, info) 41 | 42 | input("press enter....") 43 | 44 | if done: 45 | break 46 | 47 | currIt += 1 48 | if currIt == iterationNum: 49 | break 50 | 51 | 52 | except KeyboardInterrupt: 53 | print("Ctrl-C -> Exit") 54 | finally: 55 | env.close() 56 | print("Done") 
-------------------------------------------------------------------------------- /examples/multi-agent/agent2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | from ns3gym import ns3env 6 | 7 | __author__ = "Piotr Gawlowicz" 8 | __copyright__ = "Copyright (c) 2020, Technische Universität Berlin" 9 | __version__ = "0.1.0" 10 | __email__ = "gawlowicz@tkn.tu-berlin.de" 11 | 12 | 13 | port = 5556 14 | env = ns3env.Ns3Env(port=port, startSim=False) 15 | env.reset() 16 | 17 | ob_space = env.observation_space 18 | ac_space = env.action_space 19 | print("Observation space: ", ob_space, ob_space.dtype) 20 | print("Action space: ", ac_space, ac_space.dtype) 21 | 22 | 23 | stepIdx = 0 24 | currIt = 0 25 | iterationNum = 3 26 | 27 | try: 28 | while True: 29 | obs = env.reset() 30 | print("Step: ", stepIdx) 31 | print("---obs: ", obs) 32 | 33 | while True: 34 | stepIdx += 1 35 | action = env.action_space.sample() 36 | print("---action: ", action) 37 | 38 | print("Step: ", stepIdx) 39 | obs, reward, done, info = env.step(action) 40 | print("---obs, reward, done, info: ", obs, reward, done, info) 41 | 42 | input("press enter....") 43 | 44 | if done: 45 | break 46 | 47 | currIt += 1 48 | if currIt == iterationNum: 49 | break 50 | 51 | 52 | except KeyboardInterrupt: 53 | print("Ctrl-C -> Exit") 54 | finally: 55 | env.close() 56 | print("Done") -------------------------------------------------------------------------------- /examples/multi-agent/mygym.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 
| * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | #include "mygym.h" 22 | #include "ns3/object.h" 23 | #include "ns3/core-module.h" 24 | #include "ns3/wifi-module.h" 25 | #include "ns3/node-list.h" 26 | #include "ns3/log.h" 27 | #include 28 | #include 29 | 30 | namespace ns3 { 31 | 32 | NS_LOG_COMPONENT_DEFINE ("MyGymEnv"); 33 | 34 | NS_OBJECT_ENSURE_REGISTERED (MyGymEnv); 35 | 36 | MyGymEnv::MyGymEnv () 37 | { 38 | NS_LOG_FUNCTION (this); 39 | m_interval = Seconds(0.1); 40 | 41 | Simulator::Schedule (Seconds(0.0), &MyGymEnv::ScheduleNextStateRead, this); 42 | } 43 | 44 | MyGymEnv::MyGymEnv (uint32_t agentId, Time stepTime) 45 | { 46 | NS_LOG_FUNCTION (this); 47 | m_agentId = agentId; 48 | m_interval = stepTime; 49 | 50 | Simulator::Schedule (Seconds(0.0), &MyGymEnv::ScheduleNextStateRead, this); 51 | } 52 | 53 | void 54 | MyGymEnv::ScheduleNextStateRead () 55 | { 56 | NS_LOG_FUNCTION (this); 57 | Simulator::Schedule (m_interval, &MyGymEnv::ScheduleNextStateRead, this); 58 | Notify(); 59 | } 60 | 61 | MyGymEnv::~MyGymEnv () 62 | { 63 | NS_LOG_FUNCTION (this); 64 | } 65 | 66 | TypeId 67 | MyGymEnv::GetTypeId (void) 68 | { 69 | static TypeId tid = TypeId ("MyGymEnv") 70 | .SetParent () 71 | .SetGroupName ("OpenGym") 72 | .AddConstructor () 73 | ; 74 | return tid; 75 | } 76 | 77 | void 78 | MyGymEnv::DoDispose () 79 | { 80 | NS_LOG_FUNCTION (this); 81 | } 82 | 83 | /* 84 | Define observation space 85 | */ 86 | Ptr 87 | MyGymEnv::GetObservationSpace() 88 | { 89 | uint32_t nodeNum = 
5; 90 | float low = 0.0; 91 | float high = 10.0; 92 | std::vector shape = {nodeNum,}; 93 | std::string dtype = TypeNameGet (); 94 | 95 | Ptr discrete = CreateObject (nodeNum); 96 | Ptr box = CreateObject (low, high, shape, dtype); 97 | 98 | Ptr space = CreateObject (); 99 | space->Add("box", box); 100 | space->Add("discrete", discrete); 101 | 102 | NS_LOG_UNCOND ("AgentID: " << m_agentId << " MyGetObservationSpace: " << space); 103 | return space; 104 | } 105 | 106 | /* 107 | Define action space 108 | */ 109 | Ptr 110 | MyGymEnv::GetActionSpace() 111 | { 112 | uint32_t nodeNum = 5; 113 | float low = 0.0; 114 | float high = 10.0; 115 | std::vector shape = {nodeNum,}; 116 | std::string dtype = TypeNameGet (); 117 | 118 | Ptr discrete = CreateObject (nodeNum); 119 | Ptr box = CreateObject (low, high, shape, dtype); 120 | 121 | Ptr space = CreateObject (); 122 | space->Add("box", box); 123 | space->Add("discrete", discrete); 124 | 125 | NS_LOG_UNCOND ("AgentID: " << m_agentId << " MyGetActionSpace: " << space); 126 | return space; 127 | } 128 | 129 | /* 130 | Define game over condition 131 | */ 132 | bool 133 | MyGymEnv::GetGameOver() 134 | { 135 | bool isGameOver = false; 136 | bool test = false; 137 | static float stepCounter = 0.0; 138 | stepCounter += 1; 139 | if (stepCounter == 10 && test) { 140 | isGameOver = true; 141 | } 142 | NS_LOG_UNCOND ("AgentID: " << m_agentId << " MyGetGameOver: " << isGameOver); 143 | return isGameOver; 144 | } 145 | 146 | /* 147 | Collect observations 148 | */ 149 | Ptr 150 | MyGymEnv::GetObservation() 151 | { 152 | uint32_t nodeNum = 5; 153 | uint32_t low = 0.0; 154 | uint32_t high = 10.0; 155 | Ptr rngInt = CreateObject (); 156 | 157 | std::vector shape = {nodeNum,}; 158 | Ptr > box = CreateObject >(shape); 159 | 160 | // generate random data 161 | for (uint32_t i = 0; iGetInteger(low, high); 163 | box->AddValue(value); 164 | } 165 | 166 | Ptr discrete = CreateObject(nodeNum); 167 | uint32_t value = rngInt->GetInteger(low, high); 168 
| discrete->SetValue(value); 169 | 170 | Ptr data = CreateObject (); 171 | data->Add(box); 172 | data->Add(discrete); 173 | 174 | // Print data from tuple 175 | Ptr > mbox = DynamicCast >(data->Get(0)); 176 | Ptr mdiscrete = DynamicCast(data->Get(1)); 177 | NS_LOG_UNCOND ("AgentID: " << m_agentId << " MyGetObservation: " << data); 178 | NS_LOG_UNCOND ("---" << mbox); 179 | NS_LOG_UNCOND ("---" << mdiscrete); 180 | 181 | return data; 182 | } 183 | 184 | /* 185 | Define reward function 186 | */ 187 | float 188 | MyGymEnv::GetReward() 189 | { 190 | static float reward = 0.0; 191 | reward += 1; 192 | return reward; 193 | } 194 | 195 | /* 196 | Define extra info. Optional 197 | */ 198 | std::string 199 | MyGymEnv::GetExtraInfo() 200 | { 201 | std::string myInfo = "testInfo"; 202 | myInfo += "|123"; 203 | NS_LOG_UNCOND("AgentID: " << m_agentId << " MyGetExtraInfo: " << myInfo); 204 | return myInfo; 205 | } 206 | 207 | /* 208 | Execute received actions 209 | */ 210 | bool 211 | MyGymEnv::ExecuteActions(Ptr action) 212 | { 213 | Ptr dict = DynamicCast(action); 214 | Ptr > box = DynamicCast >(dict->Get("box")); 215 | Ptr discrete = DynamicCast(dict->Get("discrete")); 216 | 217 | NS_LOG_UNCOND ("AgentID: " << m_agentId << " MyExecuteActions: " << action); 218 | NS_LOG_UNCOND ("---" << box); 219 | NS_LOG_UNCOND ("---" << discrete); 220 | return true; 221 | } 222 | 223 | } // ns3 namespace -------------------------------------------------------------------------------- /examples/multi-agent/mygym.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be 
useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | 22 | #ifndef MY_GYM_ENTITY_H 23 | #define MY_GYM_ENTITY_H 24 | 25 | #include "ns3/opengym-module.h" 26 | #include "ns3/nstime.h" 27 | 28 | namespace ns3 { 29 | 30 | class MyGymEnv : public OpenGymEnv 31 | { 32 | public: 33 | MyGymEnv (); 34 | MyGymEnv (uint32_t agentId, Time stepTime); 35 | virtual ~MyGymEnv (); 36 | static TypeId GetTypeId (void); 37 | virtual void DoDispose (); 38 | 39 | Ptr GetActionSpace(); 40 | Ptr GetObservationSpace(); 41 | bool GetGameOver(); 42 | Ptr GetObservation(); 43 | float GetReward(); 44 | std::string GetExtraInfo(); 45 | bool ExecuteActions(Ptr action); 46 | 47 | private: 48 | void ScheduleNextStateRead(); 49 | 50 | uint32_t m_agentId; 51 | Time m_interval; 52 | }; 53 | 54 | } 55 | 56 | 57 | #endif // MY_GYM_ENTITY_H 58 | -------------------------------------------------------------------------------- /examples/multi-agent/readme.md: -------------------------------------------------------------------------------- 1 | multi-agent example 2 | =================== 3 | 4 | This example shows how to create an ns3-gym environment with multiple agents and connects them to multiple independent Python processes. 5 | Note that for each agent an independent ns3-gym gateway is created. 6 | Each gateway binds its socket on different port number. 7 | Here, agent 1 communicates over port number 5555, while agent 2 uses port number 5556. 
8 | 9 | In order to run the example: 10 | 11 | ``` 12 | # Terminal 1 13 | ./waf --run "multi-agent" 14 | 15 | # Terminal 2 16 | cd ./scratch/multi-agent 17 | ./agent1.py 18 | 19 | # Terminal 3 20 | cd ./scratch/multi-agent 21 | ./agent2.py 22 | ``` -------------------------------------------------------------------------------- /examples/multi-agent/sim.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #include "ns3/core-module.h" 23 | #include "ns3/opengym-module.h" 24 | #include "mygym.h" 25 | 26 | using namespace ns3; 27 | 28 | NS_LOG_COMPONENT_DEFINE ("OpenGym"); 29 | 30 | int 31 | main (int argc, char *argv[]) 32 | { 33 | // Parameters of the scenario 34 | uint32_t simSeed = 1; 35 | double simulationTime = 1; //seconds 36 | double envStepTime = 0.1; //seconds, ns3gym env step time interval 37 | uint32_t openGymPort = 5555; 38 | uint32_t testArg = 0; 39 | 40 | CommandLine cmd; 41 | // required parameters for OpenGym interface 42 | cmd.AddValue ("openGymPort", "Port number for OpenGym env. 
Default: 5555", openGymPort); 43 | cmd.AddValue ("simSeed", "Seed for random generator. Default: 1", simSeed); 44 | // optional parameters 45 | cmd.AddValue ("simTime", "Simulation time in seconds. Default: 10s", simulationTime); 46 | cmd.AddValue ("stepTime", "Gym Env step time in seconds. Default: 0.1s", envStepTime); 47 | cmd.AddValue ("testArg", "Extra simulation argument. Default: 0", testArg); 48 | cmd.Parse (argc, argv); 49 | 50 | NS_LOG_UNCOND("Ns3Env parameters:"); 51 | NS_LOG_UNCOND("--simulationTime: " << simulationTime); 52 | NS_LOG_UNCOND("--openGymPort: " << openGymPort); 53 | NS_LOG_UNCOND("--envStepTime: " << envStepTime); 54 | NS_LOG_UNCOND("--seed: " << simSeed); 55 | NS_LOG_UNCOND("--testArg: " << testArg); 56 | 57 | RngSeedManager::SetSeed (1); 58 | RngSeedManager::SetRun (simSeed); 59 | 60 | // OpenGym Env for agent 1 61 | uint32_t agentId = 1; 62 | openGymPort = 5555; 63 | Ptr openGymInterface1 = CreateObject (openGymPort); 64 | Ptr myGymEnv1 = CreateObject (agentId, Seconds(envStepTime)); 65 | myGymEnv1->SetOpenGymInterface(openGymInterface1); 66 | 67 | // OpenGym Env for agent 2 68 | agentId = 2; 69 | openGymPort = 5556; 70 | Ptr openGymInterface2 = CreateObject (openGymPort); 71 | Ptr myGymEnv2 = CreateObject (agentId, Seconds(envStepTime)); 72 | myGymEnv2->SetOpenGymInterface(openGymInterface2); 73 | 74 | NS_LOG_UNCOND ("Simulation start"); 75 | Simulator::Stop (Seconds (simulationTime)); 76 | Simulator::Run (); 77 | NS_LOG_UNCOND ("Simulation stop"); 78 | 79 | openGymInterface1->NotifySimulationEnd(); 80 | openGymInterface2->NotifySimulationEnd(); 81 | Simulator::Destroy (); 82 | 83 | } 84 | -------------------------------------------------------------------------------- /examples/multigym/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml -------------------------------------------------------------------------------- 
/examples/multigym/.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /examples/multigym/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /examples/multigym/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /examples/multigym/.idea/multigym.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /examples/multigym/.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /examples/multigym/mygym.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: ZhangminWang 19 | */ 20 | 21 | #include "mygym.h" 22 | #include "ns3/object.h" 23 | #include "ns3/core-module.h" 24 | #include "ns3/wifi-module.h" 25 | #include "ns3/node-list.h" 26 | #include "ns3/log.h" 27 | #include 28 | #include 29 | 30 | namespace ns3 { 31 | 32 | NS_LOG_COMPONENT_DEFINE ("MyGymEnv"); 33 | 34 | NS_OBJECT_ENSURE_REGISTERED (MyGymEnv); 35 | 36 | MyGymEnv::MyGymEnv () 37 | { 38 | NS_LOG_FUNCTION (this); 39 | m_interval = Seconds (0.1); 40 | 41 | Simulator::Schedule (Seconds (0.0), &MyGymEnv::ScheduleNextStateRead, this); 42 | } 43 | 44 | MyGymEnv::MyGymEnv (Time stepTime) 45 | { 46 | NS_LOG_FUNCTION (this); 47 | m_interval = stepTime; 48 | 49 | Simulator::Schedule (Seconds (0.0), &MyGymEnv::ScheduleNextStateRead, this); 50 | } 51 | 52 | void 53 | MyGymEnv::ScheduleNextStateRead () 54 | { 55 | NS_LOG_FUNCTION (this); 56 | Simulator::Schedule (m_interval, &MyGymEnv::ScheduleNextStateRead, this); 57 | // Notify(); 58 | Step (); 59 | } 60 | 61 | MyGymEnv::~MyGymEnv () 62 | { 63 | NS_LOG_FUNCTION (this); 64 | } 65 | 66 | TypeId 67 | MyGymEnv::GetTypeId (void) 68 | { 69 | static TypeId tid = TypeId ("MyGymEnv") 70 | .SetParent () 71 | .SetGroupName ("OpenGym") 72 | .AddConstructor (); 73 | return tid; 74 | } 75 | 76 | void 77 | MyGymEnv::DoDispose () 78 | { 79 | NS_LOG_FUNCTION (this); 80 | } 81 | 82 | /* 83 | Define observation space 84 | */ 85 | Ptr 86 | MyGymEnv::GetObservationSpace (uint32_t id) 87 | { 88 | // uint32_t nodeNum = 5; 89 | float low = 0.0; 90 | float high = 10.0; 91 | std::vector shape = {1}; 92 | std::string dtype = TypeNameGet (); 93 | 94 | Ptr box = CreateObject (low, high, shape, dtype); 95 | 96 | NS_LOG_UNCOND ("ID " << id << " MyGetObservationSpace: " << box); 97 | return box; 98 | } 99 
| 100 | /* 101 | Define action space 102 | */ 103 | Ptr 104 | MyGymEnv::GetActionSpace (uint32_t id) 105 | { 106 | Ptr discrete = CreateObject (5); 107 | 108 | NS_LOG_UNCOND ("ID " << id << " MyGetActionSpace: " << discrete); 109 | return discrete; 110 | } 111 | 112 | /* 113 | Define game over condition 114 | */ 115 | bool 116 | MyGymEnv::GetDone (uint32_t id) 117 | { 118 | bool isGameOver = false; 119 | bool test = false; 120 | static float stepCounter = 0.0; 121 | stepCounter += 1; 122 | if (stepCounter == 10 && test) 123 | { 124 | isGameOver = true; 125 | } 126 | NS_LOG_UNCOND ("ID " << id << " MyGetGameOver: " << isGameOver); 127 | return isGameOver; 128 | } 129 | 130 | /* 131 | Collect observations 132 | */ 133 | Ptr 134 | MyGymEnv::GetObservation (uint32_t id) 135 | { 136 | uint32_t low = 0.0; 137 | uint32_t high = 10.0; 138 | Ptr rngInt = CreateObject (); 139 | 140 | std::vector shape = {1}; 141 | Ptr> box = CreateObject> (shape); 142 | 143 | // generate random data 144 | uint32_t value = rngInt->GetInteger (low, high); 145 | box->AddValue (value); 146 | 147 | Ptr data = CreateObject (); 148 | data->Add (box); 149 | 150 | // Print data from tuple 151 | Ptr> mbox = 152 | DynamicCast> (data->Get (0)); 153 | NS_LOG_UNCOND ("ID " << id << " MyGetObservation: " << data); 154 | NS_LOG_UNCOND ("---" << mbox); 155 | 156 | return data; 157 | } 158 | 159 | /* 160 | Define reward function 161 | */ 162 | float 163 | MyGymEnv::GetReward (uint32_t id) 164 | { 165 | static float reward = 0.0; 166 | reward += 1; 167 | return reward; 168 | } 169 | 170 | /* 171 | Define extra info. 
Optional 172 | */ 173 | std::string 174 | MyGymEnv::GetInfo (uint32_t id) 175 | { 176 | std::string myInfo = "testInfo"; 177 | NS_LOG_UNCOND ("ID " << id << " MyGetExtraInfo: " << myInfo); 178 | return myInfo; 179 | } 180 | 181 | /* 182 | Execute received actions 183 | */ 184 | bool 185 | MyGymEnv::ExecuteActions (uint32_t id, Ptr action) 186 | { 187 | Ptr discrete = DynamicCast (action); 188 | 189 | NS_LOG_UNCOND ("ID " << id << " MyExecuteActions: " << discrete); 190 | return true; 191 | } 192 | 193 | } // namespace ns3 -------------------------------------------------------------------------------- /examples/multigym/mygym.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2019 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: ZhangminWang 19 | */ 20 | 21 | 22 | #ifndef MY_GYM_ENTITY_H 23 | #define MY_GYM_ENTITY_H 24 | 25 | #include "ns3/opengym-module.h" 26 | #include "ns3/nstime.h" 27 | 28 | namespace ns3 { 29 | 30 | /** 31 | * \note This is only a test, not use algorithm 32 | */ 33 | class MyGymEnv : public OpenGymMultiEnv 34 | { 35 | public: 36 | MyGymEnv (); 37 | MyGymEnv (Time stepTime); 38 | virtual ~MyGymEnv (); 39 | static TypeId GetTypeId (void); 40 | virtual void DoDispose (); 41 | 42 | Ptr GetActionSpace(uint32_t id); 43 | Ptr GetObservationSpace(uint32_t id); 44 | bool GetDone(uint32_t id); 45 | Ptr GetObservation(uint32_t id); 46 | float GetReward(uint32_t id); 47 | std::string GetInfo(uint32_t id); 48 | bool ExecuteActions(uint32_t id, Ptr action); 49 | 50 | private: 51 | void ScheduleNextStateRead(); 52 | 53 | Time m_interval; 54 | }; 55 | 56 | } 57 | 58 | 59 | #endif // MY_GYM_ENTITY_H 60 | -------------------------------------------------------------------------------- /examples/multigym/sim.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #include "ns3/core-module.h" 23 | #include "ns3/opengym-module.h" 24 | #include "mygym.h" 25 | 26 | using namespace ns3; 27 | 28 | NS_LOG_COMPONENT_DEFINE ("OpenGym"); 29 | 30 | int 31 | main (int argc, char *argv[]) 32 | { 33 | // Parameters of the scenario 34 | uint32_t simSeed = 1; 35 | double simulationTime = 1; //seconds 36 | double envStepTime = 0.1; //seconds, ns3gym env step time interval 37 | uint32_t testArg = 0; 38 | uint32_t openGymPort; 39 | 40 | CommandLine cmd; 41 | // required parameters for OpenGym interface 42 | cmd.AddValue ("openGymPort", "Port number for OpenGym env. Default: 5555", openGymPort); 43 | cmd.AddValue ("simSeed", "Seed for random generator. Default: 1", simSeed); 44 | // optional parameters 45 | cmd.AddValue ("simTime", "Simulation time in seconds. Default: 10s", simulationTime); 46 | cmd.AddValue ("stepTime", "Gym Env step time in seconds. Default: 0.1s", envStepTime); 47 | cmd.AddValue ("testArg", "Extra simulation argument. 
Default: 0", testArg); 48 | cmd.Parse (argc, argv); 49 | 50 | NS_LOG_UNCOND ("Ns3Env parameters:"); 51 | NS_LOG_UNCOND ("--simulationTime: " << simulationTime); 52 | NS_LOG_UNCOND ("--envStepTime: " << envStepTime); 53 | NS_LOG_UNCOND ("--seed: " << simSeed); 54 | NS_LOG_UNCOND ("--testArg: " << testArg); 55 | 56 | RngSeedManager::SetSeed (1); 57 | RngSeedManager::SetRun (simSeed); 58 | 59 | // OpenGym MultiEnv 60 | Ptr myGymEnv = CreateObject (Seconds (envStepTime)); 61 | for (uint32_t id = 1; id < 11; id++) 62 | { 63 | myGymEnv->AddAgentId (id); 64 | } 65 | 66 | NS_LOG_UNCOND ("Simulation start"); 67 | Simulator::Stop (Seconds (simulationTime)); 68 | Simulator::Run (); 69 | NS_LOG_UNCOND ("Simulation stop"); 70 | 71 | myGymEnv->NotifySimulationEnd (); 72 | Simulator::Destroy (); 73 | } 74 | -------------------------------------------------------------------------------- /examples/multigym/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | from ns3gym import ns3_multiagent_env as ns3env 6 | 7 | __author__ = "ZhangminWang" 8 | __copyright__ = "Copyright (c) 2019" 9 | __version__ = "0.1.1" 10 | 11 | # NOTE: This is only a test, not use algorithm 12 | 13 | parser = argparse.ArgumentParser(description='Start simulation script on/off') 14 | parser.add_argument('--start', 15 | type=int, 16 | default=1, 17 | help='Start ns-3 simulation script 0/1, Default: 1') 18 | parser.add_argument('--iterations', 19 | type=int, 20 | default=10, 21 | help='Number of iterations, Default: 10') 22 | args = parser.parse_args() 23 | startSim = True 24 | iterationNum = int(args.iterations) 25 | 26 | port = 5555 27 | simTime = 5 # seconds 28 | stepTime = 0.5 # seconds 29 | seed = 0 30 | simArgs = {"--simTime": simTime, 31 | "--stepTime": stepTime, 32 | "--testArg": 123} 33 | debug = False 34 | 35 | env = ns3env.MultiEnv(port=port, stepTime=stepTime, startSim=startSim, 
simSeed=seed, simArgs=simArgs, debug=debug) 36 | # simpler: 37 | #env = ns3env.Ns3Env() 38 | env.reset() 39 | ob_spaces = env.observation_space 40 | ac_spaces = env.action_space 41 | 42 | print("Observation space: ", ob_spaces, " size: ", len(ob_spaces)) 43 | print("Action space: ", ac_spaces, " size: ", len(ac_spaces)) 44 | 45 | stepIdx = 0 46 | currIt = 0 47 | agent_num = len(ob_spaces) 48 | 49 | try: 50 | while True: 51 | print("Start iteration: ", currIt) 52 | obs = env.reset() 53 | print("---obs: ", obs) 54 | 55 | while True: 56 | stepIdx += 1 57 | print("Step: ", stepIdx) 58 | actions = [] 59 | for ag in range(agent_num): 60 | action = env.action_space[ag].sample() 61 | print("---action: ", action) 62 | actions.append(action) 63 | 64 | obs, reward, _, _ = env.step(actions) 65 | print("---obs, reward: ", obs, reward) 66 | 67 | if stepIdx == 10: 68 | stepIdx = 0 69 | break 70 | 71 | currIt += 1 72 | if currIt == iterationNum: 73 | break 74 | 75 | except KeyboardInterrupt: 76 | print("Ctrl-C -> Exit") 77 | finally: 78 | env.close() 79 | print("Done") -------------------------------------------------------------------------------- /examples/opengym-2/mygym.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | #include "mygym.h" 22 | #include "ns3/object.h" 23 | #include "ns3/core-module.h" 24 | #include "ns3/wifi-module.h" 25 | #include "ns3/node-list.h" 26 | #include "ns3/log.h" 27 | #include 28 | #include 29 | 30 | namespace ns3 { 31 | 32 | NS_LOG_COMPONENT_DEFINE ("MyGymEnv"); 33 | 34 | NS_OBJECT_ENSURE_REGISTERED (MyGymEnv); 35 | 36 | MyGymEnv::MyGymEnv () 37 | { 38 | NS_LOG_FUNCTION (this); 39 | m_interval = Seconds(0.1); 40 | 41 | Simulator::Schedule (Seconds(0.0), &MyGymEnv::ScheduleNextStateRead, this); 42 | } 43 | 44 | MyGymEnv::MyGymEnv (Time stepTime) 45 | { 46 | NS_LOG_FUNCTION (this); 47 | m_interval = stepTime; 48 | 49 | Simulator::Schedule (Seconds(0.0), &MyGymEnv::ScheduleNextStateRead, this); 50 | } 51 | 52 | void 53 | MyGymEnv::ScheduleNextStateRead () 54 | { 55 | NS_LOG_FUNCTION (this); 56 | Simulator::Schedule (m_interval, &MyGymEnv::ScheduleNextStateRead, this); 57 | Notify(); 58 | } 59 | 60 | MyGymEnv::~MyGymEnv () 61 | { 62 | NS_LOG_FUNCTION (this); 63 | } 64 | 65 | TypeId 66 | MyGymEnv::GetTypeId (void) 67 | { 68 | static TypeId tid = TypeId ("MyGymEnv") 69 | .SetParent () 70 | .SetGroupName ("OpenGym") 71 | .AddConstructor () 72 | ; 73 | return tid; 74 | } 75 | 76 | void 77 | MyGymEnv::DoDispose () 78 | { 79 | NS_LOG_FUNCTION (this); 80 | } 81 | 82 | /* 83 | Define observation space 84 | */ 85 | Ptr 86 | MyGymEnv::GetObservationSpace() 87 | { 88 | uint32_t nodeNum = 5; 89 | float low = 0.0; 90 | float high = 10.0; 91 | std::vector shape = {nodeNum,}; 92 | std::string dtype = TypeNameGet (); 93 | 94 | Ptr discrete = CreateObject (nodeNum); 95 | Ptr box = CreateObject (low, high, shape, dtype); 96 | 97 | Ptr space = CreateObject (); 98 | space->Add("box", box); 99 | 
space->Add("discrete", discrete); 100 | 101 | NS_LOG_UNCOND ("MyGetObservationSpace: " << space); 102 | return space; 103 | } 104 | 105 | /* 106 | Define action space 107 | */ 108 | Ptr 109 | MyGymEnv::GetActionSpace() 110 | { 111 | uint32_t nodeNum = 5; 112 | float low = 0.0; 113 | float high = 10.0; 114 | std::vector shape = {nodeNum,}; 115 | std::string dtype = TypeNameGet (); 116 | 117 | Ptr discrete = CreateObject (nodeNum); 118 | Ptr box = CreateObject (low, high, shape, dtype); 119 | 120 | Ptr space = CreateObject (); 121 | space->Add("box", box); 122 | space->Add("discrete", discrete); 123 | 124 | NS_LOG_UNCOND ("MyGetActionSpace: " << space); 125 | return space; 126 | } 127 | 128 | /* 129 | Define game over condition 130 | */ 131 | bool 132 | MyGymEnv::GetGameOver() 133 | { 134 | bool isGameOver = false; 135 | bool test = false; 136 | static float stepCounter = 0.0; 137 | stepCounter += 1; 138 | if (stepCounter == 10 && test) { 139 | isGameOver = true; 140 | } 141 | NS_LOG_UNCOND ("MyGetGameOver: " << isGameOver); 142 | return isGameOver; 143 | } 144 | 145 | /* 146 | Collect observations 147 | */ 148 | Ptr 149 | MyGymEnv::GetObservation() 150 | { 151 | uint32_t nodeNum = 5; 152 | uint32_t low = 0.0; 153 | uint32_t high = 10.0; 154 | Ptr rngInt = CreateObject (); 155 | 156 | std::vector shape = {nodeNum,}; 157 | Ptr > box = CreateObject >(shape); 158 | 159 | // generate random data 160 | for (uint32_t i = 0; iGetInteger(low, high); 162 | box->AddValue(value); 163 | } 164 | 165 | Ptr discrete = CreateObject(nodeNum); 166 | uint32_t value = rngInt->GetInteger(low, high); 167 | discrete->SetValue(value); 168 | 169 | Ptr data = CreateObject (); 170 | data->Add(box); 171 | data->Add(discrete); 172 | 173 | // Print data from tuple 174 | Ptr > mbox = DynamicCast >(data->Get(0)); 175 | Ptr mdiscrete = DynamicCast(data->Get(1)); 176 | NS_LOG_UNCOND ("MyGetObservation: " << data); 177 | NS_LOG_UNCOND ("---" << mbox); 178 | NS_LOG_UNCOND ("---" << mdiscrete); 179 | 
180 | return data; 181 | } 182 | 183 | /* 184 | Define reward function 185 | */ 186 | float 187 | MyGymEnv::GetReward() 188 | { 189 | static float reward = 0.0; 190 | reward += 1; 191 | return reward; 192 | } 193 | 194 | /* 195 | Define extra info. Optional 196 | */ 197 | std::string 198 | MyGymEnv::GetExtraInfo() 199 | { 200 | std::string myInfo = "testInfo"; 201 | myInfo += "|123"; 202 | NS_LOG_UNCOND("MyGetExtraInfo: " << myInfo); 203 | return myInfo; 204 | } 205 | 206 | /* 207 | Execute received actions 208 | */ 209 | bool 210 | MyGymEnv::ExecuteActions(Ptr action) 211 | { 212 | Ptr dict = DynamicCast(action); 213 | Ptr > box = DynamicCast >(dict->Get("box")); 214 | Ptr discrete = DynamicCast(dict->Get("discrete")); 215 | 216 | NS_LOG_UNCOND ("MyExecuteActions: " << action); 217 | NS_LOG_UNCOND ("---" << box); 218 | NS_LOG_UNCOND ("---" << discrete); 219 | return true; 220 | } 221 | 222 | } // ns3 namespace -------------------------------------------------------------------------------- /examples/opengym-2/mygym.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | 22 | #ifndef MY_GYM_ENTITY_H 23 | #define MY_GYM_ENTITY_H 24 | 25 | #include "ns3/opengym-module.h" 26 | #include "ns3/nstime.h" 27 | 28 | namespace ns3 { 29 | 30 | class MyGymEnv : public OpenGymEnv 31 | { 32 | public: 33 | MyGymEnv (); 34 | MyGymEnv (Time stepTime); 35 | virtual ~MyGymEnv (); 36 | static TypeId GetTypeId (void); 37 | virtual void DoDispose (); 38 | 39 | Ptr GetActionSpace(); 40 | Ptr GetObservationSpace(); 41 | bool GetGameOver(); 42 | Ptr GetObservation(); 43 | float GetReward(); 44 | std::string GetExtraInfo(); 45 | bool ExecuteActions(Ptr action); 46 | 47 | private: 48 | void ScheduleNextStateRead(); 49 | 50 | Time m_interval; 51 | }; 52 | 53 | } 54 | 55 | 56 | #endif // MY_GYM_ENTITY_H 57 | -------------------------------------------------------------------------------- /examples/opengym-2/sim.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #include "ns3/core-module.h" 23 | #include "ns3/opengym-module.h" 24 | #include "mygym.h" 25 | 26 | using namespace ns3; 27 | 28 | NS_LOG_COMPONENT_DEFINE ("OpenGym"); 29 | 30 | int 31 | main (int argc, char *argv[]) 32 | { 33 | // Parameters of the scenario 34 | uint32_t simSeed = 1; 35 | double simulationTime = 1; //seconds 36 | double envStepTime = 0.1; //seconds, ns3gym env step time interval 37 | uint32_t openGymPort = 5555; 38 | uint32_t testArg = 0; 39 | 40 | CommandLine cmd; 41 | // required parameters for OpenGym interface 42 | cmd.AddValue ("openGymPort", "Port number for OpenGym env. Default: 5555", openGymPort); 43 | cmd.AddValue ("simSeed", "Seed for random generator. Default: 1", simSeed); 44 | // optional parameters 45 | cmd.AddValue ("simTime", "Simulation time in seconds. Default: 10s", simulationTime); 46 | cmd.AddValue ("stepTime", "Gym Env step time in seconds. Default: 0.1s", envStepTime); 47 | cmd.AddValue ("testArg", "Extra simulation argument. 
Default: 0", testArg); 48 | cmd.Parse (argc, argv); 49 | 50 | NS_LOG_UNCOND("Ns3Env parameters:"); 51 | NS_LOG_UNCOND("--simulationTime: " << simulationTime); 52 | NS_LOG_UNCOND("--openGymPort: " << openGymPort); 53 | NS_LOG_UNCOND("--envStepTime: " << envStepTime); 54 | NS_LOG_UNCOND("--seed: " << simSeed); 55 | NS_LOG_UNCOND("--testArg: " << testArg); 56 | 57 | RngSeedManager::SetSeed (1); 58 | RngSeedManager::SetRun (simSeed); 59 | 60 | // OpenGym Env 61 | Ptr openGymInterface = CreateObject (openGymPort); 62 | Ptr myGymEnv = CreateObject (Seconds(envStepTime)); 63 | myGymEnv->SetOpenGymInterface(openGymInterface); 64 | 65 | NS_LOG_UNCOND ("Simulation start"); 66 | Simulator::Stop (Seconds (simulationTime)); 67 | Simulator::Run (); 68 | NS_LOG_UNCOND ("Simulation stop"); 69 | 70 | openGymInterface->NotifySimulationEnd(); 71 | Simulator::Destroy (); 72 | 73 | } 74 | -------------------------------------------------------------------------------- /examples/opengym-2/simple_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import gym 5 | import argparse 6 | import ns3gym 7 | 8 | __author__ = "Piotr Gawlowicz" 9 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 10 | __version__ = "0.1.0" 11 | __email__ = "gawlowicz@tkn.tu-berlin.de" 12 | 13 | 14 | env = gym.make('ns3-v0') 15 | env.reset() 16 | 17 | ob_space = env.observation_space 18 | ac_space = env.action_space 19 | print("Observation space: ", ob_space, ob_space.dtype) 20 | print("Action space: ", ac_space, ac_space.dtype) 21 | 22 | stepIdx = 0 23 | 24 | try: 25 | obs = env.reset() 26 | print("Step: ", stepIdx) 27 | print("---obs: ", obs) 28 | 29 | while True: 30 | stepIdx += 1 31 | 32 | action = env.action_space.sample() 33 | print("---action: ", action) 34 | obs, reward, done, info = env.step(action) 35 | 36 | print("Step: ", stepIdx) 37 | print("---obs, reward, done, info: ", obs, reward, 
done, info) 38 | 39 | if done: 40 | break 41 | 42 | except KeyboardInterrupt: 43 | print("Ctrl-C -> Exit") 44 | finally: 45 | env.close() 46 | print("Done") -------------------------------------------------------------------------------- /examples/opengym-2/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | from ns3gym import ns3env 6 | 7 | __author__ = "Piotr Gawlowicz" 8 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 9 | __version__ = "0.1.0" 10 | __email__ = "gawlowicz@tkn.tu-berlin.de" 11 | 12 | 13 | parser = argparse.ArgumentParser(description='Start simulation script on/off') 14 | parser.add_argument('--start', 15 | type=int, 16 | default=1, 17 | help='Start ns-3 simulation script 0/1, Default: 1') 18 | parser.add_argument('--iterations', 19 | type=int, 20 | default=1, 21 | help='Number of iterations, Default: 1') 22 | args = parser.parse_args() 23 | startSim = bool(args.start) 24 | iterationNum = int(args.iterations) 25 | 26 | port = 5555 27 | simTime = 5 # seconds 28 | stepTime = 0.5 # seconds 29 | seed = 0 30 | simArgs = {"--simTime": simTime, 31 | "--stepTime": stepTime, 32 | "--testArg": 123} 33 | debug = False 34 | 35 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 36 | # simpler: 37 | #env = ns3env.Ns3Env() 38 | env.reset() 39 | 40 | ob_space = env.observation_space 41 | ac_space = env.action_space 42 | print("Observation space: ", ob_space, ob_space.dtype) 43 | print("Action space: ", ac_space, ac_space.dtype) 44 | 45 | # stepIdx = 0 46 | # currIt = 0 47 | 48 | # try: 49 | # while True: 50 | # print("Start iteration: ", currIt) 51 | # obs = env.reset() 52 | # print("Step: ", stepIdx) 53 | # print("---obs: ", obs) 54 | 55 | # while True: 56 | # stepIdx += 1 57 | # action = env.action_space.sample() 58 | # print("---action: ", action) 59 | 60 | 
# print("Step: ", stepIdx) 61 | # obs, reward, done, info = env.step(action) 62 | # print("---obs, reward, done, info: ", obs, reward, done, info) 63 | 64 | # if done: 65 | # stepIdx = 0 66 | # if currIt + 1 < iterationNum: 67 | # env.reset() 68 | # break 69 | 70 | # currIt += 1 71 | # if currIt == iterationNum: 72 | # break 73 | 74 | # except KeyboardInterrupt: 75 | # print("Ctrl-C -> Exit") 76 | # finally: 77 | # env.close() 78 | # print("Done") -------------------------------------------------------------------------------- /examples/opengym/sim.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #include "ns3/core-module.h" 23 | #include "ns3/opengym-module.h" 24 | 25 | using namespace ns3; 26 | 27 | NS_LOG_COMPONENT_DEFINE ("OpenGym"); 28 | 29 | /* 30 | Define observation space 31 | */ 32 | Ptr MyGetObservationSpace(void) 33 | { 34 | uint32_t nodeNum = 5; 35 | float low = 0.0; 36 | float high = 10.0; 37 | std::vector shape = {nodeNum,}; 38 | std::string dtype = TypeNameGet (); 39 | Ptr space = CreateObject (low, high, shape, dtype); 40 | NS_LOG_UNCOND ("MyGetObservationSpace: " << space); 41 | return space; 42 | } 43 | 44 | /* 45 | Define action space 46 | */ 47 | Ptr MyGetActionSpace(void) 48 | { 49 | uint32_t nodeNum = 5; 50 | 51 | Ptr space = CreateObject (nodeNum); 52 | NS_LOG_UNCOND ("MyGetActionSpace: " << space); 53 | return space; 54 | } 55 | 56 | /* 57 | Define game over condition 58 | */ 59 | bool MyGetGameOver(void) 60 | { 61 | 62 | bool isGameOver = false; 63 | bool test = false; 64 | static float stepCounter = 0.0; 65 | stepCounter += 1; 66 | if (stepCounter == 10 && test) { 67 | isGameOver = true; 68 | } 69 | NS_LOG_UNCOND ("MyGetGameOver: " << isGameOver); 70 | return isGameOver; 71 | } 72 | 73 | /* 74 | Collect observations 75 | */ 76 | Ptr MyGetObservation(void) 77 | { 78 | uint32_t nodeNum = 5; 79 | uint32_t low = 0.0; 80 | uint32_t high = 10.0; 81 | Ptr rngInt = CreateObject (); 82 | 83 | std::vector shape = {nodeNum,}; 84 | Ptr > box = CreateObject >(shape); 85 | 86 | // generate random data 87 | for (uint32_t i = 0; iGetInteger(low, high); 89 | box->AddValue(value); 90 | } 91 | 92 | NS_LOG_UNCOND ("MyGetObservation: " << box); 93 | return box; 94 | } 95 | 96 | /* 97 | Define reward function 98 | */ 99 | float MyGetReward(void) 100 | { 101 | static 
float reward = 0.0; 102 | reward += 1; 103 | return reward; 104 | } 105 | 106 | /* 107 | Define extra info. Optional 108 | */ 109 | std::string MyGetExtraInfo(void) 110 | { 111 | std::string myInfo = "testInfo"; 112 | myInfo += "|123"; 113 | NS_LOG_UNCOND("MyGetExtraInfo: " << myInfo); 114 | return myInfo; 115 | } 116 | 117 | 118 | /* 119 | Execute received actions 120 | */ 121 | bool MyExecuteActions(Ptr action) 122 | { 123 | Ptr discrete = DynamicCast(action); 124 | NS_LOG_UNCOND ("MyExecuteActions: " << action); 125 | return true; 126 | } 127 | 128 | void ScheduleNextStateRead(double envStepTime, Ptr openGym) 129 | { 130 | Simulator::Schedule (Seconds(envStepTime), &ScheduleNextStateRead, envStepTime, openGym); 131 | openGym->NotifyCurrentState(); 132 | } 133 | 134 | int 135 | main (int argc, char *argv[]) 136 | { 137 | // Parameters of the scenario 138 | uint32_t simSeed = 1; 139 | double simulationTime = 1; //seconds 140 | double envStepTime = 0.1; //seconds, ns3gym env step time interval 141 | uint32_t openGymPort = 5555; 142 | uint32_t testArg = 0; 143 | 144 | CommandLine cmd; 145 | // required parameters for OpenGym interface 146 | cmd.AddValue ("openGymPort", "Port number for OpenGym env. Default: 5555", openGymPort); 147 | cmd.AddValue ("simSeed", "Seed for random generator. Default: 1", simSeed); 148 | // optional parameters 149 | cmd.AddValue ("simTime", "Simulation time in seconds. Default: 10s", simulationTime); 150 | cmd.AddValue ("testArg", "Extra simulation argument. 
Default: 0", testArg); 151 | cmd.Parse (argc, argv); 152 | 153 | NS_LOG_UNCOND("Ns3Env parameters:"); 154 | NS_LOG_UNCOND("--simulationTime: " << simulationTime); 155 | NS_LOG_UNCOND("--openGymPort: " << openGymPort); 156 | NS_LOG_UNCOND("--envStepTime: " << envStepTime); 157 | NS_LOG_UNCOND("--seed: " << simSeed); 158 | NS_LOG_UNCOND("--testArg: " << testArg); 159 | 160 | RngSeedManager::SetSeed (1); 161 | RngSeedManager::SetRun (simSeed); 162 | 163 | // OpenGym Env 164 | Ptr openGym = CreateObject (openGymPort); 165 | openGym->SetGetActionSpaceCb( MakeCallback (&MyGetActionSpace) ); 166 | openGym->SetGetObservationSpaceCb( MakeCallback (&MyGetObservationSpace) ); 167 | openGym->SetGetGameOverCb( MakeCallback (&MyGetGameOver) ); 168 | openGym->SetGetObservationCb( MakeCallback (&MyGetObservation) ); 169 | openGym->SetGetRewardCb( MakeCallback (&MyGetReward) ); 170 | openGym->SetGetExtraInfoCb( MakeCallback (&MyGetExtraInfo) ); 171 | openGym->SetExecuteActionsCb( MakeCallback (&MyExecuteActions) ); 172 | Simulator::Schedule (Seconds(0.0), &ScheduleNextStateRead, envStepTime, openGym); 173 | 174 | NS_LOG_UNCOND ("Simulation start"); 175 | Simulator::Stop (Seconds (simulationTime)); 176 | Simulator::Run (); 177 | NS_LOG_UNCOND ("Simulation stop"); 178 | 179 | openGym->NotifySimulationEnd(); 180 | Simulator::Destroy (); 181 | 182 | } 183 | -------------------------------------------------------------------------------- /examples/opengym/simple_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import gym 5 | import argparse 6 | import ns3gym 7 | 8 | __author__ = "Piotr Gawlowicz" 9 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 10 | __version__ = "0.1.0" 11 | __email__ = "gawlowicz@tkn.tu-berlin.de" 12 | 13 | 14 | env = gym.make('ns3-v0') 15 | env.reset() 16 | 17 | ob_space = env.observation_space 18 | ac_space = env.action_space 19 | 
print("Observation space: ", ob_space, ob_space.dtype) 20 | print("Action space: ", ac_space, ac_space.dtype) 21 | 22 | stepIdx = 0 23 | 24 | try: 25 | obs = env.reset() 26 | print("Step: ", stepIdx) 27 | print("---obs: ", obs) 28 | 29 | while True: 30 | stepIdx += 1 31 | 32 | action = env.action_space.sample() 33 | print("---action: ", action) 34 | obs, reward, done, info = env.step(action) 35 | 36 | print("Step: ", stepIdx) 37 | print("---obs, reward, done, info: ", obs, reward, done, info) 38 | 39 | if done: 40 | break 41 | 42 | except KeyboardInterrupt: 43 | print("Ctrl-C -> Exit") 44 | finally: 45 | env.close() 46 | print("Done") -------------------------------------------------------------------------------- /examples/opengym/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | from ns3gym import ns3env 6 | 7 | __author__ = "Piotr Gawlowicz" 8 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 9 | __version__ = "0.1.0" 10 | __email__ = "gawlowicz@tkn.tu-berlin.de" 11 | 12 | 13 | parser = argparse.ArgumentParser(description='Start simulation script on/off') 14 | parser.add_argument('--start', 15 | type=int, 16 | default=1, 17 | help='Start ns-3 simulation script 0/1, Default: 1') 18 | parser.add_argument('--iterations', 19 | type=int, 20 | default=1, 21 | help='Number of iterations, Default: 1') 22 | args = parser.parse_args() 23 | startSim = bool(args.start) 24 | iterationNum = int(args.iterations) 25 | 26 | port = 5555 27 | simTime = 20 # seconds 28 | stepTime = 0.5 # seconds 29 | seed = 0 30 | simArgs = {"--simTime": simTime, 31 | "--testArg": 123} 32 | debug = False 33 | 34 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 35 | # simpler: 36 | #env = ns3env.Ns3Env() 37 | env.reset() 38 | 39 | ob_space = env.observation_space 40 | ac_space = 
env.action_space 41 | print("Observation space: ", ob_space, ob_space.dtype) 42 | print("Action space: ", ac_space, ac_space.dtype) 43 | 44 | stepIdx = 0 45 | currIt = 0 46 | 47 | try: 48 | while True: 49 | print("Start iteration: ", currIt) 50 | obs = env.reset() 51 | print("Step: ", stepIdx) 52 | print("---obs:", obs) 53 | 54 | while True: 55 | stepIdx += 1 56 | action = env.action_space.sample() 57 | print("---action: ", action) 58 | 59 | print("Step: ", stepIdx) 60 | obs, reward, done, info = env.step(action) 61 | print("---obs, reward, done, info: ", obs, reward, done, info) 62 | 63 | if done: 64 | stepIdx = 0 65 | if currIt + 1 < iterationNum: 66 | env.reset() 67 | break 68 | 69 | currIt += 1 70 | if currIt == iterationNum: 71 | break 72 | 73 | except KeyboardInterrupt: 74 | print("Ctrl-C -> Exit") 75 | finally: 76 | env.close() 77 | print("Done") -------------------------------------------------------------------------------- /examples/rl-tcp/tcp-rl-env.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | #ifndef TCP_RL_ENV_H 22 | #define TCP_RL_ENV_H 23 | 24 | #include "ns3/opengym-module.h" 25 | #include "ns3/tcp-socket-base.h" 26 | #include 27 | 28 | namespace ns3 { 29 | 30 | class Packet; 31 | class TcpHeader; 32 | class TcpSocketBase; 33 | class Time; 34 | 35 | 36 | class TcpGymEnv : public OpenGymEnv 37 | { 38 | public: 39 | TcpGymEnv (); 40 | virtual ~TcpGymEnv (); 41 | static TypeId GetTypeId (void); 42 | virtual void DoDispose (); 43 | 44 | void SetNodeId(uint32_t id); 45 | void SetSocketUuid(uint32_t id); 46 | 47 | std::string GetTcpCongStateName(const TcpSocketState::TcpCongState_t state); 48 | std::string GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event); 49 | 50 | // OpenGym interface 51 | virtual Ptr GetActionSpace(); 52 | virtual bool GetGameOver(); 53 | virtual float GetReward(); 54 | virtual std::string GetExtraInfo(); 55 | virtual bool ExecuteActions(Ptr action); 56 | 57 | virtual Ptr GetObservationSpace() = 0; 58 | virtual Ptr GetObservation() = 0; 59 | 60 | // trace packets, e.g. 
for calculating inter tx/rx time 61 | virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr) = 0; 62 | virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr) = 0; 63 | 64 | // TCP congestion control interface 65 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight) = 0; 66 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) = 0; 67 | // optional functions used to collect obs 68 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) = 0; 69 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState) = 0; 70 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) = 0; 71 | 72 | typedef enum 73 | { 74 | GET_SS_THRESH = 0, 75 | INCREASE_WINDOW, 76 | PKTS_ACKED, 77 | CONGESTION_STATE_SET, 78 | CWND_EVENT, 79 | } CalledFunc_t; 80 | 81 | protected: 82 | uint32_t m_nodeId; 83 | uint32_t m_socketUuid; 84 | 85 | // state 86 | // obs has to be implemented in child class 87 | 88 | // game over 89 | bool m_isGameOver; 90 | 91 | // reward 92 | float m_envReward; 93 | 94 | // extra info 95 | std::string m_info; 96 | 97 | // actions 98 | uint32_t m_new_ssThresh; 99 | uint32_t m_new_cWnd; 100 | }; 101 | 102 | 103 | class TcpEventGymEnv : public TcpGymEnv 104 | { 105 | public: 106 | TcpEventGymEnv (); 107 | virtual ~TcpEventGymEnv (); 108 | static TypeId GetTypeId (void); 109 | virtual void DoDispose (); 110 | 111 | void SetReward(float value); 112 | void SetPenalty(float value); 113 | 114 | // OpenGym interface 115 | virtual Ptr GetObservationSpace(); 116 | Ptr GetObservation(); 117 | 118 | // trace packets, e.g. 
for calculating inter tx/rx time 119 | virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr); 120 | virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr); 121 | 122 | // TCP congestion control interface 123 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); 124 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); 125 | // optional functions used to collect obs 126 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); 127 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); 128 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); 129 | 130 | private: 131 | // state 132 | CalledFunc_t m_calledFunc; 133 | Ptr m_tcb; 134 | uint32_t m_bytesInFlight; 135 | uint32_t m_segmentsAcked; 136 | Time m_rtt; 137 | TcpSocketState::TcpCongState_t m_newState; 138 | TcpSocketState::TcpCAEvent_t m_event; 139 | 140 | // reward 141 | float m_reward; 142 | float m_penalty; 143 | }; 144 | 145 | 146 | class TcpTimeStepGymEnv : public TcpGymEnv 147 | { 148 | public: 149 | TcpTimeStepGymEnv (); 150 | TcpTimeStepGymEnv (Time timeStep); 151 | virtual ~TcpTimeStepGymEnv (); 152 | static TypeId GetTypeId (void); 153 | virtual void DoDispose (); 154 | 155 | // OpenGym interface 156 | virtual Ptr GetObservationSpace(); 157 | Ptr GetObservation(); 158 | 159 | // trace packets, e.g. 
for calculating inter tx/rx time 160 | virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr); 161 | virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr); 162 | 163 | // TCP congestion control interface 164 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); 165 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); 166 | // optional functions used to collect obs 167 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); 168 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); 169 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); 170 | 171 | private: 172 | void ScheduleNextStateRead(); 173 | bool m_started {false}; 174 | Time m_timeStep; 175 | // state 176 | Ptr m_tcb; 177 | std::vector m_bytesInFlight; 178 | std::vector m_segmentsAcked; 179 | 180 | uint64_t m_rttSampleNum {0}; 181 | Time m_rttSum {MicroSeconds (0.0)}; 182 | 183 | Time m_lastPktTxTime {MicroSeconds(0.0)}; 184 | Time m_lastPktRxTime {MicroSeconds(0.0)}; 185 | uint64_t m_interTxTimeNum {0}; 186 | Time m_interTxTimeSum {MicroSeconds (0.0)}; 187 | uint64_t m_interRxTimeNum {0}; 188 | Time m_interRxTimeSum {MicroSeconds (0.0)}; 189 | 190 | // reward 191 | }; 192 | 193 | 194 | 195 | } // namespace ns3 196 | 197 | #endif /* TCP_RL_ENV_H */ -------------------------------------------------------------------------------- /examples/rl-tcp/tcp-rl.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 
11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | #ifndef TCP_RL_H 22 | #define TCP_RL_H 23 | 24 | #include "ns3/tcp-congestion-ops.h" 25 | #include "ns3/opengym-module.h" 26 | #include "ns3/tcp-socket-base.h" 27 | 28 | namespace ns3 { 29 | 30 | class TcpSocketBase; 31 | class Time; 32 | class TcpGymEnv; 33 | 34 | 35 | // used to get pointer to Congestion Algorithm 36 | class TcpSocketDerived : public TcpSocketBase 37 | { 38 | public: 39 | static TypeId GetTypeId (void); 40 | virtual TypeId GetInstanceTypeId () const; 41 | 42 | TcpSocketDerived (void); 43 | virtual ~TcpSocketDerived (void); 44 | 45 | Ptr GetCongestionControlAlgorithm (); 46 | }; 47 | 48 | 49 | class TcpRlBase : public TcpCongestionOps 50 | { 51 | public: 52 | /** 53 | * \brief Get the type ID. 54 | * \return the object TypeId 55 | */ 56 | static TypeId GetTypeId (void); 57 | 58 | TcpRlBase (); 59 | 60 | /** 61 | * \brief Copy constructor. 62 | * \param sock object to copy. 
63 | */ 64 | TcpRlBase (const TcpRlBase& sock); 65 | 66 | ~TcpRlBase (); 67 | 68 | virtual std::string GetName () const; 69 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); 70 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); 71 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); 72 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); 73 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); 74 | virtual Ptr Fork (); 75 | 76 | protected: 77 | static uint64_t GenerateUuid (); 78 | virtual void CreateGymEnv(); 79 | void ConnectSocketCallbacks(); 80 | 81 | // OpenGymEnv interface 82 | Ptr m_tcpSocket; 83 | Ptr m_tcpGymEnv; 84 | }; 85 | 86 | 87 | class TcpRl : public TcpRlBase 88 | { 89 | public: 90 | static TypeId GetTypeId (void); 91 | 92 | TcpRl (); 93 | TcpRl (const TcpRl& sock); 94 | ~TcpRl (); 95 | 96 | virtual std::string GetName () const; 97 | private: 98 | virtual void CreateGymEnv(); 99 | // OpenGymEnv env 100 | float m_reward {1.0}; 101 | float m_penalty {-100.0}; 102 | }; 103 | 104 | 105 | class TcpRlTimeBased : public TcpRlBase 106 | { 107 | public: 108 | static TypeId GetTypeId (void); 109 | 110 | TcpRlTimeBased (); 111 | TcpRlTimeBased (const TcpRlTimeBased& sock); 112 | ~TcpRlTimeBased (); 113 | 114 | virtual std::string GetName () const; 115 | 116 | private: 117 | virtual void CreateGymEnv(); 118 | // OpenGymEnv env 119 | Time m_timeStep {MilliSeconds (100)}; 120 | }; 121 | 122 | } // namespace ns3 123 | 124 | #endif /* TCP_RL_H */ -------------------------------------------------------------------------------- /examples/rl-tcp/tcp_base.py: -------------------------------------------------------------------------------- 1 | __author__ = "Piotr Gawlowicz" 2 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 3 | __version__ = "0.1.0" 4 | __email__ = "gawlowicz@tkn.tu-berlin.de" 5 | 6 | 7 | class Tcp(object): 8 | 
"""docstring for Tcp""" 9 | def __init__(self): 10 | super(Tcp, self).__init__() 11 | 12 | def set_spaces(self, obs, act): 13 | self.obsSpace = obs 14 | self.actSpace = act 15 | 16 | def get_action(self, obs, reward, done, info): 17 | pass 18 | 19 | 20 | class TcpEventBased(Tcp): 21 | """docstring for TcpEventBased""" 22 | def __init__(self): 23 | super(TcpEventBased, self).__init__() 24 | 25 | def get_action(self, obs, reward, done, info): 26 | # unique socket ID 27 | socketUuid = obs[0] 28 | # TCP env type: event-based = 0 / time-based = 1 29 | envType = obs[1] 30 | # sim time in us 31 | simTime_us = obs[2] 32 | # unique node ID 33 | nodeId = obs[3] 34 | # current ssThreshold 35 | ssThresh = obs[4] 36 | # current contention window size 37 | cWnd = obs[5] 38 | # segment size 39 | segmentSize = obs[6] 40 | # number of acked segments 41 | segmentsAcked = obs[7] 42 | # estimated bytes in flight 43 | bytesInFlight = obs[8] 44 | # last estimation of RTT 45 | lastRtt_us = obs[9] 46 | # min value of RTT 47 | minRtt_us = obs[10] 48 | # function from Congestion Algorithm (CA) interface: 49 | # GET_SS_THRESH = 0 (packet loss), 50 | # INCREASE_WINDOW (packet acked), 51 | # PKTS_ACKED (unused), 52 | # CONGESTION_STATE_SET (unused), 53 | # CWND_EVENT (unused), 54 | calledFunc = obs[11] 55 | # Congetsion Algorithm (CA) state: 56 | # CA_OPEN = 0, 57 | # CA_DISORDER, 58 | # CA_CWR, 59 | # CA_RECOVERY, 60 | # CA_LOSS, 61 | # CA_LAST_STATE 62 | caState = obs[12] 63 | # Congetsion Algorithm (CA) event: 64 | # CA_EVENT_TX_START = 0, 65 | # CA_EVENT_CWND_RESTART, 66 | # CA_EVENT_COMPLETE_CWR, 67 | # CA_EVENT_LOSS, 68 | # CA_EVENT_ECN_NO_CE, 69 | # CA_EVENT_ECN_IS_CE, 70 | # CA_EVENT_DELAYED_ACK, 71 | # CA_EVENT_NON_DELAYED_ACK, 72 | caEvent = obs[13] 73 | # ECN state: 74 | # ECN_DISABLED = 0, 75 | # ECN_IDLE, 76 | # ECN_CE_RCVD, 77 | # ECN_SENDING_ECE, 78 | # ECN_ECE_RCVD, 79 | # ECN_CWR_SENT 80 | ecnState = obs[14] 81 | 82 | # compute new values 83 | new_cWnd = 10 * segmentSize 84 | 
# Index of the segment size in both the event-based and the time-based
# observation vectors.
_SEGMENT_SIZE_IDX = 6


def _fixed_window_action(obs, cwnd_segments, ssthresh_segments):
    """Return ``[new_ssThresh, new_cWnd]`` as fixed multiples of the segment size."""
    segmentSize = obs[_SEGMENT_SIZE_IDX]
    new_cWnd = cwnd_segments * segmentSize
    new_ssThresh = ssthresh_segments * segmentSize
    return [new_ssThresh, new_cWnd]


class Tcp(object):
    """Base class for TCP agents that turn an observation into an action.

    An action is the two-element list ``[new_ssThresh, new_cWnd]``.
    """

    def __init__(self):
        super(Tcp, self).__init__()
        # Filled in by set_spaces(); initialised here so that reading them
        # before set_spaces() yields None instead of raising AttributeError.
        self.obsSpace = None
        self.actSpace = None

    def set_spaces(self, obs, act):
        """Store the observation space *obs* and the action space *act*."""
        self.obsSpace = obs
        self.actSpace = act

    def get_action(self, obs, reward, done, info):
        """Return the next action; the base class has no policy and returns None."""
        pass


class TcpEventBased(Tcp):
    """Agent for the event-based TCP environment using a fixed window.

    Observation layout (index: meaning):
        0: unique socket ID
        1: TCP env type (event-based = 0 / time-based = 1)
        2: sim time in us
        3: unique node ID
        4: current ssThreshold
        5: current congestion window size
        6: segment size
        7: number of acked segments
        8: estimated bytes in flight
        9: last estimation of RTT (us)
        10: min value of RTT (us)
        11: called function of the Congestion Algorithm (CA) interface:
            GET_SS_THRESH = 0 (packet loss), INCREASE_WINDOW (packet acked),
            PKTS_ACKED (unused), CONGESTION_STATE_SET (unused),
            CWND_EVENT (unused)
        12: CA state: CA_OPEN = 0, CA_DISORDER, CA_CWR, CA_RECOVERY, CA_LOSS,
            CA_LAST_STATE
        13: CA event: CA_EVENT_TX_START = 0, CA_EVENT_CWND_RESTART,
            CA_EVENT_COMPLETE_CWR, CA_EVENT_LOSS, CA_EVENT_ECN_NO_CE,
            CA_EVENT_ECN_IS_CE, CA_EVENT_DELAYED_ACK,
            CA_EVENT_NON_DELAYED_ACK
        14: ECN state: ECN_DISABLED = 0, ECN_IDLE, ECN_CE_RCVD,
            ECN_SENDING_ECE, ECN_ECE_RCVD, ECN_CWR_SENT
    """

    def __init__(self, cwnd_segments=10, ssthresh_segments=5):
        """Create the agent.

        Args:
            cwnd_segments: congestion window to request, in segments.
            ssthresh_segments: slow-start threshold to request, in segments.
        """
        super(TcpEventBased, self).__init__()
        self.cwnd_segments = cwnd_segments
        self.ssthresh_segments = ssthresh_segments

    def get_action(self, obs, reward, done, info):
        """Return ``[new_ssThresh, new_cWnd]`` derived from the segment size only."""
        return _fixed_window_action(obs, self.cwnd_segments, self.ssthresh_segments)


class TcpTimeBased(Tcp):
    """Agent for the time-based TCP environment using a fixed window.

    Observation layout (index: meaning):
        0: unique socket ID
        1: TCP env type (event-based = 0 / time-based = 1)
        2: sim time in us
        3: unique node ID
        4: current ssThreshold
        5: current congestion window size
        6: segment size
        7: bytesInFlightSum
        8: bytesInFlightAvg
        9: segmentsAckedSum
        10: segmentsAckedAvg
        11: avgRtt
        12: minRtt
        13: avgInterTx
        14: avgInterRx
        15: throughput
    """

    def __init__(self, cwnd_segments=10, ssthresh_segments=5):
        """Create the agent.

        Args:
            cwnd_segments: congestion window to request, in segments.
            ssthresh_segments: slow-start threshold to request, in segments.
        """
        super(TcpTimeBased, self).__init__()
        self.cwnd_segments = cwnd_segments
        self.ssthresh_segments = ssthresh_segments

    def get_action(self, obs, reward, done, info):
        """Return ``[new_ssThresh, new_cWnd]`` derived from the segment size only."""
        return _fixed_window_action(obs, self.cwnd_segments, self.ssthresh_segments)
class TcpNewReno(TcpEventBased):
    """NewReno-like agent for the event-based TCP environment.

    Only these observation fields are used:
        4: current ssThreshold
        5: current congestion window size
        6: segment size
        7: number of acked segments
        8: estimated bytes in flight
    """

    def __init__(self):
        super(TcpNewReno, self).__init__()

    def get_action(self, obs, reward, done, info):
        """Return ``[new_ssThresh, new_cWnd]`` for the observation *obs*.

        The window grows by one segment per call while below ssThresh (slow
        start) and by roughly ``segmentSize**2 / cWnd`` bytes otherwise
        (congestion avoidance). ``reward``, ``done`` and ``info`` are unused.
        """
        ssThresh = obs[4]
        cWnd = obs[5]
        segmentSize = obs[6]
        segmentsAcked = obs[7]
        bytesInFlight = obs[8]

        # IncreaseWindow: with nothing acked the window is left as it is.
        # The previous default of 1 would have requested a 1-byte window.
        new_cWnd = cWnd
        if segmentsAcked > 0:
            if cWnd < ssThresh:
                # slow start
                new_cWnd = cWnd + segmentSize
            else:
                # congestion avoidance: grow by at least one byte
                adder = 1.0 * (segmentSize * segmentSize) / cWnd
                new_cWnd = cWnd + int(max(1.0, adder))

        # GetSsThresh: half of the data in flight, but at least two segments
        new_ssThresh = int(max(2 * segmentSize, bytesInFlight / 2))

        return [new_ssThresh, new_cWnd]
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Drive the rl-tcp ns-3 simulation with randomly sampled actions."""

import argparse

from ns3gym import ns3env

__author__ = "Piotr Gawlowicz"
__copyright__ = "Copyright (c) 2018, Technische Universität Berlin"
__version__ = "0.1.0"
__email__ = "gawlowicz@tkn.tu-berlin.de"


def parse_args():
    """Parse the command-line options of the script."""
    parser = argparse.ArgumentParser(description='Start simulation script on/off')
    parser.add_argument('--start',
                        type=int,
                        default=1,
                        help='Start ns-3 simulation script 0/1, Default: 1')
    parser.add_argument('--iterations',
                        type=int,
                        default=1,
                        help='Number of iterations, Default: 1')
    return parser.parse_args()


def main():
    """Run the requested number of episodes with a random agent."""
    args = parse_args()
    startSim = bool(args.start)
    iterationNum = int(args.iterations)

    port = 5555
    simTime = 10  # seconds
    stepTime = 0.5  # seconds
    seed = 0
    simArgs = {"--duration": simTime}
    debug = False

    # simpler alternative: env = ns3env.Ns3Env()
    env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim,
                        simSeed=seed, simArgs=simArgs, debug=debug)
    env.reset()

    ob_space = env.observation_space
    ac_space = env.action_space
    print("Observation space: ", ob_space, ob_space.dtype)
    print("Action space: ", ac_space, ac_space.dtype)

    stepIdx = 0
    currIt = 0

    try:
        # "<" rather than "==" so that --iterations 0 (or a negative value)
        # terminates instead of looping forever.
        while currIt < iterationNum:
            print("Start iteration: ", currIt)
            obs = env.reset()
            print("Step: ", stepIdx)
            print("---obs: ", obs)

            done = False
            while not done:
                stepIdx += 1
                action = env.action_space.sample()
                print("---action: ", action)

                print("Step: ", stepIdx)
                obs, reward, done, info = env.step(action)
                print("---obs, reward, done, info: ", obs, reward, done, info)

            stepIdx = 0
            currIt += 1
            # Same call sequence as before: reset once the episode is over,
            # but only when another iteration follows.
            if currIt < iterationNum:
                env.reset()

    except KeyboardInterrupt:
        print("Ctrl-C -> Exit")
    finally:
        env.close()
        print("Done")


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Drive the rl-tcp ns-3 simulation with one TCP agent per socket."""

import argparse

from ns3gym import ns3env
from tcp_base import TcpTimeBased
from tcp_newreno import TcpNewReno

__author__ = "Piotr Gawlowicz"
__copyright__ = "Copyright (c) 2018, Technische Universität Berlin"
__version__ = "0.1.0"
__email__ = "gawlowicz@tkn.tu-berlin.de"


def get_agent(obs):
    """Return the agent that owns the socket in *obs*, creating it on first use.

    ``obs[0]`` is the socket UUID and ``obs[1]`` the environment type
    (event-based = 0, anything else is treated as time-based). Agents are
    cached in ``get_agent.tcpAgents`` and receive the spaces stored in
    ``get_agent.ob_space`` / ``get_agent.ac_space`` when they are created.
    """
    socketUuid = obs[0]
    tcpEnvType = obs[1]
    tcpAgent = get_agent.tcpAgents.get(socketUuid, None)
    if tcpAgent is None:
        if tcpEnvType == 0:
            # event-based = 0
            tcpAgent = TcpNewReno()
        else:
            # time-based = 1
            tcpAgent = TcpTimeBased()
        tcpAgent.set_spaces(get_agent.ob_space, get_agent.ac_space)
        get_agent.tcpAgents[socketUuid] = tcpAgent

    return tcpAgent


# state shared between get_agent() calls; the spaces are set in main()
get_agent.tcpAgents = {}
get_agent.ob_space = None
get_agent.ac_space = None


def parse_args():
    """Parse the command-line options of the script."""
    parser = argparse.ArgumentParser(description='Start simulation script on/off')
    parser.add_argument('--start',
                        type=int,
                        default=1,
                        help='Start ns-3 simulation script 0/1, Default: 1')
    parser.add_argument('--iterations',
                        type=int,
                        default=1,
                        help='Number of iterations, Default: 1')
    return parser.parse_args()


def main():
    """Run the requested number of episodes, dispatching to per-socket agents."""
    args = parse_args()
    startSim = bool(args.start)
    iterationNum = int(args.iterations)

    port = 5555
    simTime = 10  # seconds
    stepTime = 0.5  # seconds
    seed = 12
    simArgs = {"--duration": simTime}
    debug = False

    # simpler alternative: env = ns3env.Ns3Env()
    env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim,
                        simSeed=seed, simArgs=simArgs, debug=debug)
    env.reset()

    ob_space = env.observation_space
    ac_space = env.action_space
    print("Observation space: ", ob_space, ob_space.dtype)
    print("Action space: ", ac_space, ac_space.dtype)

    get_agent.ob_space = ob_space
    get_agent.ac_space = ac_space

    stepIdx = 0
    currIt = 0

    try:
        # "<" rather than "==" so that --iterations 0 (or a negative value)
        # terminates instead of looping forever.
        while currIt < iterationNum:
            print("Start iteration: ", currIt)
            obs = env.reset()
            reward = 0
            done = False
            info = None
            print("Step: ", stepIdx)
            print("---obs: ", obs)

            # get existing agent or create new TCP agent if needed
            tcpAgent = get_agent(obs)

            while not done:
                stepIdx += 1
                action = tcpAgent.get_action(obs, reward, done, info)
                print("---action: ", action)

                print("Step: ", stepIdx)
                obs, reward, done, info = env.step(action)
                print("---obs, reward, done, info: ", obs, reward, done, info)

                # get existing agent or create new TCP agent if needed
                tcpAgent = get_agent(obs)

            stepIdx = 0
            currIt += 1
            # Same call sequence as before: reset once the episode is over,
            # but only when another iteration follows.
            if currIt < iterationNum:
                env.reset()

    except KeyboardInterrupt:
        print("Ctrl-C -> Exit")
    finally:
        env.close()
        print("Done")


if __name__ == "__main__":
    main()
# -*- Mode: python; py-indent-offset: 4; indent-tabs-mode: nil; coding: utf-8; -*-

# One entry per example program, in registration order:
# (program name, ns-3 module dependencies, source file or list of source files)
# NOTE(review): the two linear-mesh entries depend on "application" while
# rl-tcp uses "applications" -- confirm which module name is intended.
_EXAMPLE_PROGRAMS = (
    ("opengym", ("core", "opengym"), "opengym/sim.cc"),
    ("opengym-2", ("core", "opengym"),
     ("opengym-2/sim.cc", "opengym-2/mygym.cc")),
    ("linear-mesh", ("core", "internet", "application", "wifi", "opengym"),
     ("linear-mesh/sim.cc",)),
    ("linear-mesh-2", ("core", "internet", "application", "wifi", "opengym"),
     ("linear-mesh-2/sim.cc", "linear-mesh-2/mygym.cc")),
    ("interference-pattern", ("core", "internet", "wifi", "opengym"),
     ("interference-pattern/sim.cc", "interference-pattern/mygym.cc")),
    ("rl-tcp", ("core", "internet", "point-to-point", "point-to-point-layout",
                "applications", "flow-monitor", "opengym"),
     ("rl-tcp/sim.cc", "rl-tcp/tcp-rl-env.cc", "rl-tcp/tcp-rl.cc")),
    ("multigym", ("core", "opengym"),
     ("multigym/sim.cc", "multigym/mygym.cc")),
)


def build(bld):
    """Register every example program with the build context *bld*."""
    for name, deps, sources in _EXAMPLE_PROGRAMS:
        program = bld.create_ns3_program(name, list(deps))
        # A plain string is passed through unchanged; everything else is
        # handed over as a fresh list.
        if isinstance(sources, str):
            program.source = sources
        else:
            program.source = list(sources)
| /* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */ 2 | 3 | #include "opengym-helper.h" 4 | 5 | namespace ns3 { 6 | 7 | /* ... */ 8 | 9 | 10 | } 11 | 12 | -------------------------------------------------------------------------------- /helper/opengym-helper.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */ 2 | #ifndef OPENGYM_HELPER_H 3 | #define OPENGYM_HELPER_H 4 | 5 | namespace ns3 { 6 | 7 | /* ... */ 8 | 9 | } 10 | 11 | #endif /* OPENGYM_HELPER_H */ 12 | 13 | -------------------------------------------------------------------------------- /model-single-agent/messages.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | import "google/protobuf/any.proto"; 3 | 4 | package ns3opengym; 5 | 6 | //---------Types----------// 7 | enum MsgType { 8 | Unknown = 0; 9 | Init = 1; 10 | ActionSpace = 2; 11 | ObservationSpace = 3; 12 | IsGameOver = 4; 13 | Observation = 5; 14 | Reward = 6; 15 | ExtraInfo = 7; 16 | Action = 8; 17 | StopEnv = 9; 18 | } 19 | 20 | enum SpaceType { 21 | NoSpaceType = 0; 22 | Discrete = 1; 23 | Box = 2; 24 | Tuple = 3; 25 | Dict = 4; 26 | } 27 | 28 | enum Dtype { 29 | NoDType = 0; 30 | INT = 1; 31 | UINT = 2; 32 | FLOAT = 3; 33 | DOUBLE = 4; 34 | } 35 | //------------------------// 36 | 37 | //---Space Descriptions---// 38 | message SpaceDescription { 39 | SpaceType type = 1; 40 | google.protobuf.Any space = 2; 41 | string name = 3; //optional 42 | } 43 | 44 | message DiscreteSpace { 45 | int32 n = 1; 46 | } 47 | 48 | message BoxSpace { 49 | float low = 1; 50 | float high = 2; 51 | Dtype dtype = 3; 52 | repeated uint32 shape = 4; 53 | } 54 | 55 | message TupleSpace { 56 | repeated SpaceDescription element = 1; 57 | } 58 | 59 | message DictSpace { 60 | repeated SpaceDescription element = 1; 61 | } 62 | //------------------------// 63 | 64 | //----Data Containers-----// 
65 | message DataContainer { 66 | SpaceType type = 1; 67 | google.protobuf.Any data = 2; 68 | string name = 3; //optional 69 | } 70 | 71 | message DiscreteDataContainer { 72 | int32 data = 1; 73 | } 74 | 75 | message BoxDataContainer { 76 | Dtype dtype = 1; 77 | repeated uint32 shape = 2; 78 | 79 | repeated int32 intData = 3; 80 | repeated uint32 uintData = 4; 81 | repeated float floatData = 5; 82 | repeated double doubleData = 6; 83 | } 84 | 85 | message TupleDataContainer { 86 | repeated DataContainer element = 1; 87 | } 88 | 89 | message DictDataContainer { 90 | repeated DataContainer element = 1; 91 | } 92 | //------------------------// 93 | 94 | //--------Messages--------// 95 | message SimInitMsg { 96 | uint64 simProcessId = 1; 97 | uint64 wafShellProcessId = 2; 98 | SpaceDescription obsSpace = 3; 99 | SpaceDescription actSpace = 4; 100 | } 101 | 102 | message SimInitAck { 103 | bool done = 1; 104 | bool stopSimReq = 2; 105 | } 106 | 107 | message EnvStateMsg { 108 | DataContainer obsData = 1; 109 | float reward = 2; 110 | bool isGameOver = 3; 111 | 112 | enum Reason { 113 | SimulationEnd = 0; 114 | GameOver = 1; 115 | } 116 | Reason reason = 4; 117 | string info = 5; 118 | } 119 | 120 | message EnvActMsg { 121 | DataContainer actData = 1; 122 | bool stopSimReq = 2; 123 | } 124 | //------------------------// -------------------------------------------------------------------------------- /model-single-agent/ns3gym/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018 Piotr Gawlowicz (gawlowicz.p@gmail.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the 
Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /model-single-agent/ns3gym/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | -------------------------------------------------------------------------------- /model-single-agent/ns3gym/README.md: -------------------------------------------------------------------------------- 1 | ns3gym 2 | ====== 3 | 4 | OpenAI Gym meets ns-3 -------------------------------------------------------------------------------- /model-single-agent/ns3gym/ns3gym/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='ns3-v0', 5 | entry_point='ns3gym.ns3env:Ns3Env', 6 | ) -------------------------------------------------------------------------------- /model-single-agent/ns3gym/ns3gym/start_sim.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import sys 3 | import os 4 | import time 5 | import subprocess 6 | 7 | 8 | def find_waf_path(cwd): 9 | wafPath = cwd 10 | 11 | found = False 12 | myDir = cwd 13 | while (not found): 14 | for fname in os.listdir(myDir): 15 | if fname == "waf": 16 | found = True 17 | 
def find_waf_path(cwd):
    """Return the path of the nearest ``waf`` entry in *cwd* or one of its parents.

    Raises:
        FileNotFoundError: if the filesystem root is reached without finding
            ``waf`` (the search used to loop forever in that case).
    """
    myDir = cwd
    while True:
        if "waf" in os.listdir(myDir):
            return os.path.join(myDir, "waf")

        parentDir = os.path.dirname(myDir)
        if parentDir == myDir:
            # dirname() of the root is the root itself: nothing left to search
            raise FileNotFoundError("waf not found in " + str(cwd) + " or any parent directory")
        myDir = parentDir


def _relay_build_output(lines, forceEcho=False, stopMarker=None):
    """Copy waf output to stdout once a build has been detected.

    Lines are held back until one contains "Compiling" or "Linking" (or from
    the first line on when *forceEcho* is set); then a header, the held-back
    lines and all following lines are written to stdout. Reading stops early
    after a line containing *stopMarker*, if one is given.

    Returns:
        True if anything was echoed, False otherwise.
    """
    echoing = False
    lineHistory = []
    for line in lines:
        if not echoing and (forceEcho or "Compiling" in line or "Linking" in line):
            echoing = True
            print("Build ns-3 project if required")
            sys.stdout.writelines(lineHistory)
            lineHistory = []

        if echoing:
            sys.stdout.write(line)
        else:
            lineHistory.append(line)

        if stopMarker is not None and stopMarker in line:
            break

    return echoing


def build_ns3_project(debug=True):
    """Run ``waf build`` in the ns-3 base directory and echo its output.

    The ns-3 base directory is the one containing ``waf``, searched upwards
    from the current working directory. Output is always echoed, whether or
    not a rebuild was needed. The working directory is restored afterwards,
    also on error.

    Args:
        debug: accepted for backward compatibility; currently has no effect.
    """
    cwd = os.getcwd()
    wafPath = find_waf_path(cwd)
    baseNs3Dir = os.path.dirname(wafPath)
    wafString = wafPath + ' build'

    os.chdir(baseNs3Dir)
    try:
        ns3Proc = subprocess.Popen(wafString, shell=True, stdout=subprocess.PIPE, stderr=None, universal_newlines=True)
        echoed = _relay_build_output(ns3Proc.stdout, forceEcho=True)
        p_status = ns3Proc.wait()
        if echoed:
            print("(Re-)Build of ns-3 finished with status: ", p_status)
    finally:
        os.chdir(cwd)


def start_sim_script(port=5555, simSeed=0, simArgs=None, debug=False):
    """Start the ns-3 simulation script named after the current directory.

    The script is launched through ``waf --run`` from the ns-3 base directory;
    the working directory is restored afterwards, also on error.

    Args:
        port: value for ``--openGymPort``; omitted when falsy.
        simSeed: value for ``--simSeed``; omitted when falsy.
        simArgs: mapping of extra ``key=value`` arguments (default: none).
        debug: if true, the child inherits stdout/stderr and the start command
            is printed; otherwise stderr is discarded and stdout is only
            echoed when a build is detected.

    Returns:
        The ``subprocess.Popen`` object of the started shell command.
    """
    cwd = os.getcwd()
    simScriptName = os.path.basename(cwd)
    wafPath = find_waf_path(cwd)
    baseNs3Dir = os.path.dirname(wafPath)

    runArgs = [simScriptName]
    if port:
        runArgs.append('--openGymPort=' + str(port))
    if simSeed:
        runArgs.append('--simSeed=' + str(simSeed))
    for k, v in (simArgs or {}).items():
        runArgs.append(str(k) + "=" + str(v))
    wafString = wafPath + ' --run "' + " ".join(runArgs) + '"'

    os.chdir(baseNs3Dir)
    try:
        if debug:
            ns3Proc = subprocess.Popen(wafString, shell=True, stdout=None, stderr=None)
        else:
            # Users complained about a long silent start-up while ns-3 was
            # being built. It is hard to tell in advance whether waf will
            # rebuild, so the output is scanned: if it mentions "Compiling"
            # or "Linking" a build is running and the output is shown,
            # otherwise it stays hidden.
            ns3Proc = subprocess.Popen(wafString, shell=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, universal_newlines=True)
            _relay_build_output(ns3Proc.stdout, stopMarker="Waf: Leaving directory")

        if debug:
            print("Start command: ", wafString)
            print("Started ns3 simulation script, Process Id: ", ns3Proc.pid)
    finally:
        # go back to my dir
        os.chdir(cwd)

    return ns3Proc
"ns3-gym/src/opengym/model/ns3gym/ns3gym/messages_pb2.py", " was not found.") 10 | sys.exit('Protocol Buffer messages are missing. Please run ./waf configure to generate the file') 11 | 12 | 13 | def readme(): 14 | with open('README.md') as f: 15 | return f.read() 16 | 17 | 18 | setup( 19 | name='ns3gym', 20 | version='0.1.0', 21 | packages=find_packages(), 22 | scripts=[], 23 | url='', 24 | license='MIT', 25 | author='Piotr Gawlowicz', 26 | author_email='gawlowicz.p@gmail.com', 27 | description='OpenAI Gym meets ns-3', 28 | long_description='OpenAI Gym meets ns-3', 29 | keywords='openAI gym, ML, RL, ns-3', 30 | install_requires=['pyzmq', 'numpy', 'protobuf', 'gym'], 31 | extras_require={}, 32 | ) 33 | -------------------------------------------------------------------------------- /model-single-agent/opengym_env.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #include "ns3/log.h" 23 | #include "ns3/object.h" 24 | #include "opengym_env.h" 25 | #include "container.h" 26 | #include "spaces.h" 27 | #include "opengym_interface.h" 28 | 29 | namespace ns3 { 30 | 31 | NS_OBJECT_ENSURE_REGISTERED (OpenGymEnv); 32 | 33 | NS_LOG_COMPONENT_DEFINE ("OpenGymEnv"); 34 | 35 | TypeId 36 | OpenGymEnv::GetTypeId (void) 37 | { 38 | static TypeId tid = TypeId ("ns3::OpenGymEnv") 39 | .SetParent () 40 | .SetGroupName ("OpenGym") 41 | ; 42 | return tid; 43 | } 44 | 45 | OpenGymEnv::OpenGymEnv() 46 | { 47 | NS_LOG_FUNCTION (this); 48 | } 49 | 50 | OpenGymEnv::~OpenGymEnv () 51 | { 52 | NS_LOG_FUNCTION (this); 53 | } 54 | 55 | void 56 | OpenGymEnv::DoDispose (void) 57 | { 58 | NS_LOG_FUNCTION (this); 59 | } 60 | 61 | void 62 | OpenGymEnv::DoInitialize (void) 63 | { 64 | NS_LOG_FUNCTION (this); 65 | } 66 | 67 | void 68 | OpenGymEnv::SetOpenGymInterface(Ptr openGymInterface) 69 | { 70 | NS_LOG_FUNCTION (this); 71 | m_openGymInterface = openGymInterface; 72 | openGymInterface->SetGetActionSpaceCb( MakeCallback (&OpenGymEnv::GetActionSpace, this) ); 73 | openGymInterface->SetGetObservationSpaceCb( MakeCallback (&OpenGymEnv::GetObservationSpace, this) ); 74 | openGymInterface->SetGetGameOverCb( MakeCallback (&OpenGymEnv::GetGameOver, this) ); 75 | openGymInterface->SetGetObservationCb( MakeCallback (&OpenGymEnv::GetObservation, this) ); 76 | openGymInterface->SetGetRewardCb( MakeCallback (&OpenGymEnv::GetReward, this) ); 77 | openGymInterface->SetGetExtraInfoCb( MakeCallback (&OpenGymEnv::GetExtraInfo, this) ); 78 | openGymInterface->SetExecuteActionsCb( MakeCallback (&OpenGymEnv::ExecuteActions, this) ); 79 | } 80 | 81 | /** 82 | * \brief Notify Current State 83 | 
* 1. Set Callback (SetGetGameOverCb,SetGetObservationCb, SetGetRewardCb, 84 | * SetGetExtraInfoCb, SetExecuteActionsCb) 85 | * 2. Collect current env state 86 | * 3. Execute Actions 87 | */ 88 | void 89 | OpenGymEnv::Notify() 90 | { 91 | NS_LOG_FUNCTION (this); 92 | if (m_openGymInterface) 93 | { 94 | m_openGymInterface->Notify(this); 95 | } 96 | } 97 | 98 | void 99 | OpenGymEnv::NotifySimulationEnd() 100 | { 101 | NS_LOG_FUNCTION (this); 102 | if (m_openGymInterface) 103 | { 104 | m_openGymInterface->NotifySimulationEnd(); 105 | } 106 | } 107 | 108 | } -------------------------------------------------------------------------------- /model-single-agent/opengym_env.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #ifndef OPENGYM_ENV_H 23 | #define OPENGYM_ENV_H 24 | 25 | #include "ns3/object.h" 26 | 27 | namespace ns3 { 28 | 29 | class OpenGymSpace; 30 | class OpenGymDataContainer; 31 | class OpenGymInterface; 32 | 33 | class OpenGymEnv : public Object 34 | { 35 | public: 36 | OpenGymEnv (); 37 | virtual ~OpenGymEnv (); 38 | 39 | static TypeId GetTypeId (); 40 | 41 | virtual Ptr GetActionSpace() = 0; 42 | virtual Ptr GetObservationSpace() = 0; 43 | // TODO: get all in one function like below, do we need it? 44 | //virtual void GetEnvState(Ptr &obs, float &reward, bool &done, std::string &info) = 0; 45 | virtual bool GetGameOver() = 0; 46 | virtual Ptr GetObservation() = 0; 47 | virtual float GetReward() = 0; 48 | virtual std::string GetExtraInfo() = 0; 49 | virtual bool ExecuteActions(Ptr action) = 0; 50 | 51 | void SetOpenGymInterface(Ptr openGymInterface); 52 | void Notify(); 53 | void NotifySimulationEnd(); 54 | 55 | 56 | protected: 57 | // Inherited 58 | virtual void DoInitialize (void); 59 | virtual void DoDispose (void); 60 | 61 | Ptr m_openGymInterface; 62 | private: 63 | 64 | }; 65 | 66 | } // end of namespace ns3 67 | 68 | #endif /* OPENGYM_ENV_H */ -------------------------------------------------------------------------------- /model-single-agent/opengym_interface.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This 
program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #ifndef OPENGYM_INTERFACE_H 23 | #define OPENGYM_INTERFACE_H 24 | 25 | #include "ns3/object.h" 26 | #include 27 | 28 | namespace ns3 { 29 | 30 | class OpenGymSpace; 31 | class OpenGymDataContainer; 32 | class OpenGymEnv; 33 | 34 | class OpenGymInterface : public Object 35 | { 36 | public: 37 | static Ptr Get (uint32_t port=5555); 38 | 39 | OpenGymInterface (uint32_t port=5555); 40 | virtual ~OpenGymInterface (); 41 | 42 | static TypeId GetTypeId (); 43 | 44 | void Init(); 45 | void NotifyCurrentState(); 46 | void WaitForStop(); 47 | 48 | void NotifySimulationEnd(); 49 | 50 | Ptr GetActionSpace(); 51 | Ptr GetObservationSpace(); 52 | Ptr GetObservation(); 53 | float GetReward(); 54 | bool IsGameOver(); 55 | std::string GetExtraInfo(); 56 | bool ExecuteActions(Ptr action); 57 | 58 | void SetGetActionSpaceCb(Callback< Ptr > cb); 59 | void SetGetObservationSpaceCb(Callback< Ptr > cb); 60 | void SetGetObservationCb(Callback< Ptr > cb); 61 | void SetGetRewardCb(Callback cb); 62 | void SetGetGameOverCb(Callback< bool > cb); 63 | void SetGetExtraInfoCb(Callback cb); 64 | void SetExecuteActionsCb(Callback > cb); 65 | 66 | void Notify(Ptr entity); 67 | 68 | protected: 69 | // Inherited 70 | virtual void DoInitialize (void); 71 | virtual void DoDispose (void); 72 | 73 | private: 74 | static Ptr *DoGet (uint32_t port=5555); 75 | static void Delete (void); 76 | 77 | uint32_t m_port; 78 | zmq::context_t m_zmq_context; 79 | zmq::socket_t m_zmq_socket; 
80 | 81 | bool m_simEnd; 82 | bool m_stopEnvRequested; 83 | bool m_initSimMsgSent; 84 | 85 | Callback< Ptr > m_actionSpaceCb; 86 | Callback< Ptr > m_observationSpaceCb; 87 | Callback< bool > m_gameOverCb; 88 | Callback< Ptr > m_obsCb; 89 | Callback m_rewardCb; 90 | Callback m_extraInfoCb; 91 | Callback > m_actionCb; 92 | }; 93 | 94 | } // end of namespace ns3 95 | 96 | #endif /* OPENGYM_INTERFACE_H */ 97 | 98 | -------------------------------------------------------------------------------- /model-single-agent/spaces.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #ifndef OPENGYM_SPACES_H 23 | #define OPENGYM_SPACES_H 24 | 25 | #include "ns3/object.h" 26 | #include "messages.pb.h" 27 | 28 | namespace ns3 { 29 | 30 | class OpenGymSpace : public Object 31 | { 32 | public: 33 | OpenGymSpace (); 34 | virtual ~OpenGymSpace (); 35 | 36 | static TypeId GetTypeId (); 37 | 38 | virtual ns3opengym::SpaceDescription GetSpaceDescription() = 0; 39 | virtual void Print(std::ostream& where) const = 0; 40 | protected: 41 | // Inherited 42 | virtual void DoInitialize (void); 43 | virtual void DoDispose (void); 44 | }; 45 | 46 | 47 | class OpenGymDiscreteSpace : public OpenGymSpace 48 | { 49 | public: 50 | OpenGymDiscreteSpace (); 51 | OpenGymDiscreteSpace (int n); 52 | virtual ~OpenGymDiscreteSpace (); 53 | 54 | static TypeId GetTypeId (); 55 | 56 | virtual ns3opengym::SpaceDescription GetSpaceDescription(); 57 | 58 | int GetN(void); 59 | virtual void Print(std::ostream& where) const; 60 | friend std::ostream& operator<< (std::ostream& os, const Ptr space) 61 | { 62 | space->Print(os); 63 | return os; 64 | } 65 | 66 | protected: 67 | // Inherited 68 | virtual void DoInitialize (void); 69 | virtual void DoDispose (void); 70 | 71 | private: 72 | int m_n; 73 | }; 74 | 75 | class OpenGymBoxSpace : public OpenGymSpace 76 | { 77 | public: 78 | OpenGymBoxSpace (); 79 | OpenGymBoxSpace (float low, float high, std::vector shape, std::string dtype); 80 | OpenGymBoxSpace (std::vector low, std::vector high, std::vector shape, std::string dtype); 81 | virtual ~OpenGymBoxSpace (); 82 | 83 | static TypeId GetTypeId (); 84 | 85 | virtual ns3opengym::SpaceDescription GetSpaceDescription(); 86 | 87 | float GetLow(); 88 | float GetHigh(); 89 | std::vector GetShape(); 90 | 91 | 
virtual void Print(std::ostream& where) const; 92 | friend std::ostream& operator<< (std::ostream& os, const Ptr space) 93 | { 94 | space->Print(os); 95 | return os; 96 | } 97 | 98 | protected: 99 | // Inherited 100 | virtual void DoInitialize (void); 101 | virtual void DoDispose (void); 102 | 103 | private: 104 | void SetDtype (); 105 | 106 | float m_low; 107 | float m_high; 108 | std::vector m_shape; 109 | std::string m_dtypeName; 110 | std::vector m_lowVec; 111 | std::vector m_highVec; 112 | 113 | ns3opengym::Dtype m_dtype; 114 | }; 115 | 116 | 117 | class OpenGymTupleSpace : public OpenGymSpace 118 | { 119 | public: 120 | OpenGymTupleSpace (); 121 | virtual ~OpenGymTupleSpace (); 122 | 123 | static TypeId GetTypeId (); 124 | 125 | virtual ns3opengym::SpaceDescription GetSpaceDescription(); 126 | 127 | bool Add(Ptr space); 128 | Ptr Get(uint32_t idx); 129 | 130 | virtual void Print(std::ostream& where) const; 131 | friend std::ostream& operator<< (std::ostream& os, const Ptr space) 132 | { 133 | space->Print(os); 134 | return os; 135 | } 136 | 137 | protected: 138 | // Inherited 139 | virtual void DoInitialize (void); 140 | virtual void DoDispose (void); 141 | 142 | private: 143 | std::vector< Ptr > m_tuple; 144 | }; 145 | 146 | 147 | class OpenGymDictSpace : public OpenGymSpace 148 | { 149 | public: 150 | OpenGymDictSpace (); 151 | virtual ~OpenGymDictSpace (); 152 | 153 | static TypeId GetTypeId (); 154 | 155 | virtual ns3opengym::SpaceDescription GetSpaceDescription(); 156 | 157 | bool Add(std::string key, Ptr value); 158 | Ptr Get(std::string key); 159 | 160 | virtual void Print(std::ostream& where) const; 161 | friend std::ostream& operator<< (std::ostream& os, const Ptr space) 162 | { 163 | space->Print(os); 164 | return os; 165 | } 166 | 167 | protected: 168 | // Inherited 169 | virtual void DoInitialize (void); 170 | virtual void DoDispose (void); 171 | 172 | private: 173 | std::map< std::string, Ptr > m_dict; 174 | }; 175 | 176 | } // end of namespace 
syntax = "proto3";
import "google/protobuf/any.proto";

package ns3opengym;

// ======================= Types =======================

// Kind of message.
enum MsgType {
  Unknown = 0;
  Init = 1;
  ActionSpace = 2;
  ObservationSpace = 3;
  IsGameOver = 4;
  Observation = 5;
  Reward = 6;
  ExtraInfo = 7;
  Action = 8;
  StopEnv = 9;
}

// Kind of space / data container.
enum SpaceType {
  NoSpaceType = 0;
  Discrete = 1;
  Box = 2;
  Tuple = 3;
  Dict = 4;
}

// Element type of a box.
enum Dtype {
  NoDType = 0;
  INT = 1;
  UINT = 2;
  FLOAT = 3;
  DOUBLE = 4;
}

// ================= Space descriptions =================

// Envelope: `space` carries the concrete message selected by `type`.
message SpaceDescription {
  SpaceType type = 1;
  google.protobuf.Any space = 2;
  string name = 3;  // optional
}

message DiscreteSpace {
  int32 n = 1;
}

message BoxSpace {
  float low = 1;
  float high = 2;
  Dtype dtype = 3;
  repeated uint32 shape = 4;
}

message TupleSpace {
  repeated SpaceDescription element = 1;
}

message DictSpace {
  repeated SpaceDescription element = 1;
}

// =================== Data containers ===================

// Envelope: `data` carries the concrete container selected by `type`.
message DataContainer {
  SpaceType type = 1;
  google.protobuf.Any data = 2;
  string name = 3;  // optional
}

message DiscreteDataContainer {
  int32 data = 1;
}

// One repeated field per Dtype value.
message BoxDataContainer {
  Dtype dtype = 1;
  repeated uint32 shape = 2;

  repeated int32 intData = 3;
  repeated uint32 uintData = 4;
  repeated float floatData = 5;
  repeated double doubleData = 6;
}

message TupleDataContainer {
  repeated DataContainer element = 1;
}

message DictDataContainer {
  repeated DataContainer element = 1;
}

// =============== Single-agent messages ===============

message SimInitMsg {
  uint64 simProcessId = 1;
  uint64 wafShellProcessId = 2;
  SpaceDescription obsSpace = 3;
  SpaceDescription actSpace = 4;
}

message SimInitAck {
  bool done = 1;
  bool stopSimReq = 2;
}

message EnvStateMsg {
  DataContainer obsData = 1;
  float reward = 2;
  bool isGameOver = 3;

  enum Reason {
    SimulationEnd = 0;
    GameOver = 1;
  }
  Reason reason = 4;
  string info = 5;
}

message EnvActMsg {
  DataContainer actData = 1;
  bool stopSimReq = 2;
}

// ================ Multi-agent messages ================

// NOTE:
// The multi-agent handshake reuses SimInitAck above; there is no separate
// multi-agent acknowledgement message.

message AgentInitMsg {
  uint32 agentId = 1;
  SpaceDescription obsSpace = 2;
  SpaceDescription actSpace = 3;
}

message MultiAgentInitMsg {
  uint64 simProcessId = 1;
  uint64 wafShellProcessId = 2;
  repeated AgentInitMsg agentInitMsg = 3;
}

// Per-agent state; unlike EnvStateMsg it carries no Reason.
message AgentStateMsg {
  uint32 agentId = 1;
  DataContainer obsData = 2;
  float reward = 3;
  bool done = 4;
  string info = 5;
}

message MultiAgentStateMsg {
  repeated AgentStateMsg agentStateMsg = 1;
  bool ns3SimulationEnd = 2;
}

message AgentActMsg {
  uint32 agentId = 1;
  DataContainer actData = 2;
}

message MultiAgentActMsg {
  repeated AgentActMsg agentActMsg = 1;
  bool stopSimReq = 2;
}
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /model/ns3gym/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | -------------------------------------------------------------------------------- /model/ns3gym/README.md: -------------------------------------------------------------------------------- 1 | ns3gym 2 | ====== 3 | 4 | OpenAI Gym meets ns-3 5 | 6 | ## Multi-agent env (add) 7 | add multiagent_env.py 8 | 9 | @author Zhangmin Wang 10 | -------------------------------------------------------------------------------- /model/ns3gym/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangmwg/ns3-gym-multiagent/e7927cf70daae2001f996a76498251235c9338de/model/ns3gym/__init__.py -------------------------------------------------------------------------------- /model/ns3gym/ns3gym/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='ns3-v0', 5 | entry_point='ns3gym.ns3env:Ns3Env', 6 | ) 
import sys
import os
import time
import subprocess


def find_waf_path(cwd):
    """Return the path of the nearest ``waf`` entry at or above *cwd*.

    Starting in *cwd* and moving up one directory at a time, the first
    directory that contains an entry named ``waf`` wins.

    Args:
        cwd: directory to start searching from.

    Returns:
        ``os.path.join(<directory>, "waf")`` for the first match.

    Raises:
        FileNotFoundError: if the filesystem root is reached without a match
            (the previous implementation looped forever in that case).
    """
    myDir = cwd
    while True:
        if "waf" in os.listdir(myDir):
            return os.path.join(myDir, "waf")

        parent = os.path.dirname(myDir)
        if parent == myDir:
            # dirname() of the root is the root itself: nothing left to search.
            raise FileNotFoundError(
                "waf not found in " + str(cwd) + " or any parent directory")
        myDir = parent


def build_ns3_project(debug=True):
    """Run ``waf build`` in the ns-3 tree that contains the current directory.

    The build output is echoed to stdout, preceded by a one-line banner and
    followed by the exit status; nothing is printed when the build produces
    no output. The working directory is restored on return, also on error.

    Args:
        debug: accepted for compatibility; it does not influence the build
            (the original only assigned an unused local from it).
    """
    cwd = os.getcwd()
    wafPath = find_waf_path(cwd)
    baseNs3Dir = os.path.dirname(wafPath)
    wafString = wafPath + ' build'

    os.chdir(baseNs3Dir)
    try:
        ns3Proc = subprocess.Popen(wafString, shell=True, stdout=subprocess.PIPE,
                                   stderr=None, universal_newlines=True)

        # The original condition was `True or ...`, i.e. the banner is printed
        # on the first output line and every line is echoed.
        bannerPrinted = False
        for line in ns3Proc.stdout:
            if not bannerPrinted:
                bannerPrinted = True
                print("Build ns-3 project if required")
            sys.stdout.write(line)

        p_status = ns3Proc.wait()
        if bannerPrinted:
            print("(Re-)Build of ns-3 finished with status: ", p_status)
    finally:
        os.chdir(cwd)


def start_sim_script(port=5555, simSeed=0, simArgs=None, debug=False):
    """Start the ns-3 simulation script named after the current directory.

    The command is ``<waf> --run "<script> <program args>"``; in debug mode
    the program is started under gdb through waf's ``--command-template``.

    Args:
        port: value for ``--openGymPort``; omitted when falsy.
        simSeed: value for ``--simSeed``; omitted when falsy.
        simArgs: mapping of extra ``key=value`` program arguments
            (``None`` means no extra arguments).
        debug: run under gdb with inherited stdout/stderr.

    Returns:
        The ``subprocess.Popen`` object of the started waf command.

    Raises:
        FileNotFoundError: if no ``waf`` is found above the current directory.
    """
    if simArgs is None:
        simArgs = {}

    cwd = os.getcwd()
    simScriptName = os.path.basename(cwd)
    wafPath = find_waf_path(cwd)
    baseNs3Dir = os.path.dirname(wafPath)

    # Arguments handed to the simulation program itself.
    programArgs = ''
    if port:
        programArgs += ' --openGymPort=' + str(port)
    if simSeed:
        programArgs += ' --simSeed=' + str(simSeed)
    for k, v in simArgs.items():
        programArgs += " " + str(k) + "=" + str(v)

    if debug:
        # Program arguments must live inside the gdb command template; the
        # original appended them (plus a stray closing quote) after it, which
        # produced an unbalanced shell command.
        wafString = (wafPath + ' --run ' + simScriptName
                     + ' --command-template="gdb --args %s' + programArgs + '"')
    else:
        wafString = wafPath + ' --run "' + simScriptName + programArgs + '"'

    os.chdir(baseNs3Dir)
    try:
        if debug:
            ns3Proc = subprocess.Popen(wafString, shell=True, stdout=None, stderr=None)
            print("Start command: ", wafString)
            print("Started ns3 simulation script, Process Id: ", ns3Proc.pid)
            print("DEBUG ...\n")
        else:
            # waf may (re)build ns-3 before the script starts, which can take
            # minutes. Build output is buffered and only shown once a line
            # containing "Compiling" or "Linking" proves a build is running.
            print("Start command: ", wafString, "\n")
            ns3Proc = subprocess.Popen(wafString, shell=True, stdout=subprocess.PIPE,
                                       stderr=subprocess.DEVNULL,
                                       universal_newlines=True)

            buildRequired = False
            lineHistory = []
            for line in ns3Proc.stdout:
                if ("Compiling" in line or "Linking" in line) and not buildRequired:
                    buildRequired = True
                    print("Build ns-3 project if required")
                    for pastLine in lineHistory:
                        sys.stdout.write(pastLine)
                    lineHistory = []

                if buildRequired:
                    sys.stdout.write(line)
                else:
                    lineHistory.append(line)

                # waf is done building; the simulation runs from here on.
                # NOTE(review): stdout remains a pipe that is no longer read
                # after this break.
                if "Waf: Leaving directory" in line:
                    break
    finally:
        # go back to my dir
        os.chdir(cwd)
    return ns3Proc
from setuptools import setup, find_packages
import sys
import os.path

# The generated protobuf module must exist before the package can be built.
cwd = os.getcwd()
protobufFile = os.path.join(cwd, 'ns3gym', 'messages_pb2.py')

if not os.path.isfile(protobufFile):
    print("File: ", "src/opengym/model/ns3gym/ns3gym/messages_pb2.py", " was not found.")
    sys.exit('Protocol Buffer messages are missing. Please run ./waf configure to generate the file')


def readme():
    """Return the contents of README.md (read as UTF-8)."""
    with open('README.md', encoding='utf-8') as f:
        return f.read()


setup(
    name='ns3gym',
    version='0.1.1',
    packages=find_packages(),
    scripts=[],
    url='',
    license='MIT',
    author='Piotr Gawlowicz, Zhangmin Wang',
    author_email='gawlowicz.p@gmail.com, zhangmwg@gmail.com',
    description='OpenAI Gym meets ns-3',
    # readme() was defined but never used; use it so the package metadata
    # carries the real README instead of repeating the one-line description.
    long_description=readme(),
    long_description_content_type='text/markdown',
    keywords='openAI gym, ML, RL, ns-3',
    install_requires=['pyzmq', 'numpy', 'protobuf', 'gym'],
    extras_require={},
)
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #include "ns3/log.h" 23 | #include "ns3/object.h" 24 | #include "opengym_env.h" 25 | #include "container.h" 26 | #include "spaces.h" 27 | #include "opengym_interface.h" 28 | 29 | namespace ns3 { 30 | 31 | NS_OBJECT_ENSURE_REGISTERED (OpenGymEnv); 32 | 33 | NS_LOG_COMPONENT_DEFINE ("OpenGymEnv"); 34 | 35 | TypeId 36 | OpenGymEnv::GetTypeId (void) 37 | { 38 | static TypeId tid = TypeId ("ns3::OpenGymEnv") 39 | .SetParent () 40 | .SetGroupName ("OpenGym") 41 | ; 42 | return tid; 43 | } 44 | 45 | OpenGymEnv::OpenGymEnv() 46 | { 47 | NS_LOG_FUNCTION (this); 48 | } 49 | 50 | OpenGymEnv::~OpenGymEnv () 51 | { 52 | NS_LOG_FUNCTION (this); 53 | } 54 | 55 | void 56 | OpenGymEnv::DoDispose (void) 57 | { 58 | NS_LOG_FUNCTION (this); 59 | } 60 | 61 | void 62 | OpenGymEnv::DoInitialize (void) 63 | { 64 | NS_LOG_FUNCTION (this); 65 | } 66 | 67 | void 68 | OpenGymEnv::SetOpenGymInterface(Ptr openGymInterface) 69 | { 70 | NS_LOG_FUNCTION (this); 71 | // NS_LOG_UNCOND("Set OpenGym Interface"); 72 | m_openGymInterface = openGymInterface; 73 | openGymInterface->SetGetActionSpaceCb( MakeCallback (&OpenGymEnv::GetActionSpace, this) ); 74 | openGymInterface->SetGetObservationSpaceCb( MakeCallback (&OpenGymEnv::GetObservationSpace, this) ); 75 | openGymInterface->SetGetGameOverCb( MakeCallback (&OpenGymEnv::GetGameOver, this) ); 76 | openGymInterface->SetGetObservationCb( MakeCallback (&OpenGymEnv::GetObservation, this) ); 77 | openGymInterface->SetGetRewardCb( MakeCallback (&OpenGymEnv::GetReward, this) ); 78 | openGymInterface->SetGetExtraInfoCb( MakeCallback (&OpenGymEnv::GetExtraInfo, this) ); 79 | openGymInterface->SetExecuteActionsCb( MakeCallback (&OpenGymEnv::ExecuteActions, this) ); 80 | } 81 | 
82 | /** 83 | * \brief Notify Current State 84 | * 1. Set Callback (SetGetGameOverCb,SetGetObservationCb, SetGetRewardCb, 85 | * SetGetExtraInfoCb, SetExecuteActionsCb) 86 | * 2. Collect current env state 87 | * 3. Execute Actions 88 | */ 89 | void 90 | OpenGymEnv::Notify() 91 | { 92 | NS_LOG_FUNCTION (this); 93 | if (m_openGymInterface) 94 | { 95 | m_openGymInterface->Notify(this); 96 | } 97 | } 98 | 99 | void 100 | OpenGymEnv::NotifySimulationEnd() 101 | { 102 | NS_LOG_FUNCTION (this); 103 | NS_LOG_UNCOND ("OpenGymEnv NotifySimulationEnd."); 104 | if (m_openGymInterface) 105 | { 106 | m_openGymInterface->NotifySimulationEnd(); 107 | } 108 | } 109 | 110 | } -------------------------------------------------------------------------------- /model/opengym_env.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #ifndef OPENGYM_ENV_H 23 | #define OPENGYM_ENV_H 24 | 25 | #include "ns3/object.h" 26 | 27 | namespace ns3 { 28 | 29 | class OpenGymSpace; 30 | class OpenGymDataContainer; 31 | class OpenGymInterface; 32 | 33 | class OpenGymEnv : public Object 34 | { 35 | public: 36 | OpenGymEnv (); 37 | virtual ~OpenGymEnv (); 38 | 39 | static TypeId GetTypeId (); 40 | 41 | virtual Ptr GetActionSpace () = 0; 42 | virtual Ptr GetObservationSpace () = 0; 43 | // TODO: get all in one function like below, do we need it? 44 | //virtual void GetEnvState(Ptr &obs, float &reward, bool &done, std::string &info) = 0; 45 | virtual bool GetGameOver () = 0; 46 | virtual Ptr GetObservation () = 0; 47 | virtual float GetReward () = 0; 48 | virtual std::string GetExtraInfo () = 0; 49 | virtual bool ExecuteActions (Ptr action) = 0; 50 | 51 | void SetOpenGymInterface (Ptr openGymInterface); 52 | /** 53 | * \brief Notify Current State 54 | * 1. Set Callback (SetGetGameOverCb,SetGetObservationCb, SetGetRewardCb, 55 | * SetGetExtraInfoCb, SetExecuteActionsCb) 56 | * 2. Collect current env state 57 | * 3. 
Execute Actions 58 | */ 59 | void Notify (); 60 | void NotifySimulationEnd (); 61 | 62 | protected: 63 | // Inherited 64 | virtual void DoInitialize (void); 65 | virtual void DoDispose (void); 66 | 67 | Ptr m_openGymInterface; 68 | 69 | private: 70 | }; 71 | 72 | } // end of namespace ns3 73 | 74 | #endif /* OPENGYM_ENV_H */ -------------------------------------------------------------------------------- /model/opengym_interface.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #ifndef OPENGYM_INTERFACE_H 23 | #define OPENGYM_INTERFACE_H 24 | 25 | #include "ns3/object.h" 26 | #include 27 | 28 | namespace ns3 { 29 | 30 | class OpenGymSpace; 31 | class OpenGymDataContainer; 32 | class OpenGymEnv; 33 | 34 | class OpenGymInterface : public Object 35 | { 36 | public: 37 | static Ptr Get (uint32_t port=5555); 38 | 39 | OpenGymInterface (uint32_t port=5555); 40 | virtual ~OpenGymInterface (); 41 | 42 | static TypeId GetTypeId (); 43 | 44 | void Init(); 45 | /** 46 | * 1. Collect current env state 47 | * 2. 
Execute Actions 48 | */ 49 | void NotifyCurrentState(); 50 | void WaitForStop(); 51 | 52 | void NotifySimulationEnd(); 53 | 54 | Ptr GetActionSpace(); 55 | Ptr GetObservationSpace(); 56 | Ptr GetObservation(); 57 | float GetReward(); 58 | bool IsGameOver(); 59 | std::string GetExtraInfo(); 60 | bool ExecuteActions(Ptr action); 61 | 62 | void SetGetActionSpaceCb(Callback< Ptr > cb); 63 | void SetGetObservationSpaceCb(Callback< Ptr > cb); 64 | void SetGetObservationCb(Callback< Ptr > cb); 65 | void SetGetRewardCb(Callback cb); 66 | void SetGetGameOverCb(Callback< bool > cb); 67 | void SetGetExtraInfoCb(Callback cb); 68 | void SetExecuteActionsCb(Callback > cb); 69 | 70 | /** 71 | * \brief Notify current state 72 | * 1. Collect current env state 73 | * 2. Execute Actions 74 | */ 75 | void Notify(Ptr entity); 76 | 77 | protected: 78 | // Inherited 79 | virtual void DoInitialize (void); 80 | virtual void DoDispose (void); 81 | 82 | private: 83 | static Ptr *DoGet (uint32_t port=5555); 84 | static void Delete (void); 85 | 86 | uint32_t m_port; 87 | zmq::context_t m_zmq_context; 88 | zmq::socket_t m_zmq_socket; 89 | 90 | bool m_simEnd; 91 | bool m_stopEnvRequested; 92 | bool m_initSimMsgSent; 93 | 94 | Callback< Ptr > m_actionSpaceCb; 95 | Callback< Ptr > m_observationSpaceCb; 96 | Callback< bool > m_gameOverCb; 97 | Callback< Ptr > m_obsCb; 98 | Callback m_rewardCb; 99 | Callback m_extraInfoCb; 100 | Callback > m_actionCb; 101 | }; 102 | 103 | } // end of namespace ns3 104 | 105 | #endif /* OPENGYM_INTERFACE_H */ 106 | 107 | -------------------------------------------------------------------------------- /model/opengym_multi_env.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2019 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * 
published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Multi-agent Environment 19 | * 20 | * Author: Zhangmin Wang 21 | */ 22 | 23 | #include "ns3/log.h" 24 | #include "ns3/object.h" 25 | #include "ns3/uinteger.h" 26 | #include "opengym_multi_env.h" 27 | #include "container.h" 28 | #include "spaces.h" 29 | #include "opengym_multi_interface.h" 30 | 31 | namespace ns3 { 32 | 33 | NS_OBJECT_ENSURE_REGISTERED (OpenGymMultiEnv); 34 | 35 | NS_LOG_COMPONENT_DEFINE ("OpenGymMultiEnv"); 36 | 37 | TypeId 38 | OpenGymMultiEnv::GetTypeId (void) 39 | { 40 | static TypeId tid = 41 | TypeId ("ns3::OpenGymMultiEnv") 42 | .SetParent<Object> () 43 | .SetGroupName ("OpenGym") 44 | .AddAttribute ("OpenGymPort", "OpenGymPort, default 5555", UintegerValue (5555), 45 | MakeUintegerAccessor (&OpenGymMultiEnv::m_openGymPort), 46 | MakeUintegerChecker<uint32_t> ()); 47 | return tid; 48 | } 49 | 50 | OpenGymMultiEnv::OpenGymMultiEnv () 51 | { 52 | NS_LOG_FUNCTION (this); 53 | // The env automatically creates and attaches its interface. NOTE(review): ns-3 applies attribute values after the constructor returns, so m_openGymPort presumably still holds its in-class default here and the OpenGymPort attribute would not affect this port -- confirm. 54 | Ptr<OpenGymMultiInterface> openGymMultiInterface = 55 | CreateObject<OpenGymMultiInterface> (m_openGymPort); 56 | SetOpenGymMultiInterface (openGymMultiInterface); 57 | } 58 | 59 | OpenGymMultiEnv::~OpenGymMultiEnv () 60 | { 61 | NS_LOG_FUNCTION (this); 62 | } 63 | 64 | void 65 | OpenGymMultiEnv::DoDispose (void) 66 | { 67 | NS_LOG_FUNCTION (this); 68 | } 69 | 70 | void 71 | OpenGymMultiEnv::DoInitialize (void) 72 | { 73 | NS_LOG_FUNCTION (this); 74 | } 75 | 76 | void 77 | OpenGymMultiEnv::AddAgentId (uint32_t agent_id) 78 | { 79 | NS_LOG_FUNCTION
(this); 80 | m_openGymMultiInterface->AddAgent (agent_id); 81 | } 82 | 83 | void 84 | OpenGymMultiEnv::SetOpenGymMultiInterface (Ptr<OpenGymMultiInterface> multiInterface) 85 | { 86 | NS_LOG_FUNCTION (this); 87 | m_openGymMultiInterface = multiInterface; 88 | multiInterface->SetGetActionSpaceCb (MakeCallback (&OpenGymMultiEnv::GetActionSpace, this)); 89 | multiInterface->SetGetObservationSpaceCb ( 90 | MakeCallback (&OpenGymMultiEnv::GetObservationSpace, this)); 91 | multiInterface->SetGetObservationCb (MakeCallback (&OpenGymMultiEnv::GetObservation, this)); 92 | multiInterface->SetGetRewardCb (MakeCallback (&OpenGymMultiEnv::GetReward, this)); 93 | multiInterface->SetGetDoneCb (MakeCallback (&OpenGymMultiEnv::GetDone, this)); 94 | multiInterface->SetGetInfoCb (MakeCallback (&OpenGymMultiEnv::GetInfo, this)); 95 | multiInterface->SetExecuteActionsCb (MakeCallback (&OpenGymMultiEnv::ExecuteActions, this)); 96 | } 97 | 98 | /** 99 | * \brief Notify the current state, similar to a gym step. 100 | * 1. Set Callback (SetGetDoneCb, SetGetObservationCb, SetGetRewardCb, 101 | * SetGetInfoCb, SetExecuteActionsCb) 102 | * 2. Collect current env state 103 | * 3. Execute Actions 104 | * 105 | * Notify the network environment of the current state of every agent.
106 | * Set the Get callbacks for each agent by agent_id: [observation, reward, done, info, execute_action] 107 | * Notify current state of agents 108 | */ 109 | void 110 | OpenGymMultiEnv::Step () 111 | { 112 | NS_LOG_FUNCTION (this); 113 | if (m_openGymMultiInterface) 114 | { 115 | m_openGymMultiInterface->Notify (this); 116 | } 117 | } 118 | 119 | void 120 | OpenGymMultiEnv::NotifySimulationEnd () 121 | { 122 | NS_LOG_FUNCTION (this); 123 | if (m_openGymMultiInterface) 124 | { 125 | m_openGymMultiInterface->NotifySimulationEnd (); 126 | } 127 | } 128 | 129 | } // namespace ns3 -------------------------------------------------------------------------------- /model/opengym_multi_env.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2019 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Multi-agent Environment 19 | * Based on: 20 | * opengym_env 21 | * 22 | * Author: Zhangmin Wang 23 | */ 24 | 25 | /************************************************************************************* 26 | * NOTE: We formalize the network problem as a multi-agent extension of Markov decision 27 | * processes (MDPs) called Partially Observable Markov Games (POMGs). 28 | * 29 | * Why do we use a multi-agent environment?
30 | * 31 | * Fully-distributed learning: 32 | * Algorithms with a centralized learning process are not applicable 33 | * in a real computer network. The centralized learning controller 34 | * is usually unable to gather collected environment transitions 35 | * from widely distributed routers once an action is executed somewhere 36 | * and to update the parameters of each neural network simultaneously, 37 | * because of the limited bandwidth. 38 | * \ref You, Xinyu, et al. "Toward Packet Routing with Fully-distributed 39 | * Multi-agent Deep Reinforcement Learning." arXiv preprint 40 | * arXiv:1905.03494 (2019). 41 | *************************************************************************************/ 42 | 43 | #ifndef OPENGYM_MULTI_ENV_H 44 | #define OPENGYM_MULTI_ENV_H 45 | 46 | #include "ns3/object.h" 47 | 48 | namespace ns3{ 49 | 50 | class OpenGymSpace; 51 | class OpenGymDataContainer; 52 | class OpenGymMultiInterface; 53 | 54 | class OpenGymMultiEnv : public Object 55 | { 56 | public: 57 | OpenGymMultiEnv(); 58 | virtual ~OpenGymMultiEnv(); 59 | 60 | static TypeId GetTypeId(); 61 | 62 | // Add agent ID 63 | void AddAgentId(uint32_t agent_id); 64 | 65 | ///\{ Per-agent OpenGym Env API, implemented by subclasses 66 | virtual Ptr<OpenGymSpace> GetActionSpace(uint32_t agent_id) = 0; 67 | virtual Ptr<OpenGymSpace> GetObservationSpace(uint32_t agent_id) = 0; 68 | virtual Ptr<OpenGymDataContainer> GetObservation(uint32_t agent_id) = 0; 69 | virtual float GetReward(uint32_t agent_id) = 0; 70 | virtual bool GetDone(uint32_t agent_id) = 0; 71 | virtual std::string GetInfo(uint32_t agent_id) = 0; 72 | virtual bool ExecuteActions(uint32_t agent_id, Ptr<OpenGymDataContainer> action) = 0; 73 | ///\} 74 | 75 | /** 76 | * \brief Notify the current state, similar to a gym step. 77 | * 1. Set Callback (SetGetDoneCb, SetGetObservationCb, SetGetRewardCb, 78 | * SetGetInfoCb, SetExecuteActionsCb) 79 | * 2. Collect current env state 80 | * 3. Execute Actions 81 | * 82 | * Notify the network environment of the current state of every agent.
83 | * Set Get Callback for one agent by agent_id: [observation, reward, done, info, execute_action] 84 | * Notify current state of agents 85 | * \note Same as in OpenGymEnv::Notify() 86 | * \note We rename step, because the function is the same as openAI Gym step() 87 | */ 88 | void Step(); 89 | void NotifySimulationEnd(); 90 | 91 | protected: 92 | // Inherited 93 | virtual void DoInitialize(void); 94 | virtual void DoDispose(void); 95 | Ptr m_openGymMultiInterface; 96 | uint32_t m_openGymPort = 5555; 97 | private: 98 | void SetOpenGymMultiInterface(Ptr multiInterface); 99 | }; 100 | 101 | } // namespace ns3 102 | 103 | #endif /* OPENGYM_MULTI_ENV */ -------------------------------------------------------------------------------- /model/opengym_multi_interface.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2019 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * 19 | * ******************************************************************************** 20 | * 21 | * Multi-agent interface, different from single-agent interface (opengym_interface). 
22 | * We formalize the network problem as a multi-agent extension of Markov decision 23 | * processes (MDPs) called Partially Observable Markov Games (POMGs). 24 | * 25 | * Based on: 26 | * opengym_interface 27 | * 28 | * Author: Zhangmin Wang 29 | */ 30 | 31 | #ifndef OPENGYM_MULTI_INTERFACE_H 32 | #define OPENGYM_MULTI_INTERFACE_H 33 | 34 | #include "ns3/object.h" 35 | #include <zmq.hpp> 36 | 37 | namespace ns3 { 38 | 39 | class OpenGymSpace; 40 | class OpenGymDataContainer; 41 | class OpenGymMultiEnv; 42 | 43 | /** 44 | * \note This class should only be used by OpenGymMultiEnv. 45 | */ 46 | class OpenGymMultiInterface : public Object 47 | { 48 | public: 49 | static Ptr<OpenGymMultiInterface> Get (uint32_t port = 5555); 50 | 51 | OpenGymMultiInterface (uint32_t port = 5555); 52 | virtual ~OpenGymMultiInterface (); 53 | 54 | static TypeId GetTypeId (); 55 | 56 | void Init (); 57 | 58 | /** 59 | * \note Notify, similar to the OpenAI Gym step 60 | * Notify the network environment of the current state of every agent. 61 | * Set the Get callbacks for each agent by agent_id: [observation, reward, done, info, execute_action] 62 | * Notify current state of agents 63 | */ 64 | void Notify (Ptr<OpenGymMultiEnv> entity); 65 | /** 66 | * Notify current state of agents 67 | * Send env state msg to python 68 | * Receive multi-agent actions msg from python 69 | * 70 | * NOTE: first step after reset is called without actions, 71 | * just to get current state 72 | * 73 | * 1. Collect current env state 74 | * 2. Execute Actions 75 | * \note Similar to a gym step; this function should only be called by Notify.
76 | */ 77 | void NotifyCurrentState (); 78 | void WaitForStop (); 79 | void NotifySimulationEnd (); 80 | 81 | void AddAgent(uint32_t agent_id); 82 | 83 | // Each agent 84 | Ptr GetActionSpace (uint32_t agent_id); 85 | Ptr GetObservationSpace (uint32_t agent_id); 86 | Ptr GetObservation (uint32_t agent_id); 87 | float GetReward (uint32_t agent_id); 88 | bool GetDone (uint32_t agent_id); 89 | std::string GetInfo (uint32_t agent_id); 90 | bool ExecuteActions (uint32_t agent_id, Ptr action); 91 | 92 | // Each agent OpenGym Interface Callback 93 | void SetGetActionSpaceCb (Callback, uint32_t> cb); 94 | void SetGetObservationSpaceCb (Callback, uint32_t> cb); 95 | void SetGetObservationCb (Callback, uint32_t> cb); 96 | void SetGetRewardCb (Callback cb); 97 | void SetGetDoneCb (Callback cb); 98 | void SetGetInfoCb (Callback cb); 99 | void SetExecuteActionsCb (Callback> cb); 100 | 101 | protected: 102 | // Inherited 103 | virtual void DoInitialize (void); 104 | virtual void DoDispose (void); 105 | 106 | private: 107 | static Ptr *DoGet (uint32_t port = 5555); 108 | static void Delete (void); 109 | 110 | uint32_t m_port; 111 | zmq::context_t m_zmq_context; 112 | zmq::socket_t m_zmq_socket; 113 | 114 | bool m_simEnd; 115 | bool m_stopEnvRequested; 116 | bool m_initSimMsgSent; 117 | 118 | // agent ID vector 119 | std::vector m_agentIdVec; 120 | 121 | Callback, uint32_t> m_actionSpaceCb; 122 | Callback, uint32_t> m_observationSpaceCb; 123 | Callback, uint32_t> m_obsCb; 124 | Callback m_rewardCb; 125 | Callback m_doneCb; 126 | Callback m_infoCb; 127 | Callback> m_actionCb; 128 | }; 129 | 130 | } // namespace ns3 131 | 132 | #endif /* OPENGYM_MULTI_INTERFACE_H */ -------------------------------------------------------------------------------- /model/spaces.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This 
program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * 20 | */ 21 | 22 | #ifndef OPENGYM_SPACES_H 23 | #define OPENGYM_SPACES_H 24 | 25 | #include "ns3/object.h" 26 | #include "messages.pb.h" 27 | 28 | namespace ns3 { 29 | 30 | class OpenGymSpace : public Object 31 | { 32 | public: 33 | OpenGymSpace (); 34 | virtual ~OpenGymSpace (); 35 | 36 | static TypeId GetTypeId (); 37 | 38 | virtual ns3opengym::SpaceDescription GetSpaceDescription() = 0; 39 | virtual void Print(std::ostream& where) const = 0; 40 | protected: 41 | // Inherited 42 | virtual void DoInitialize (void); 43 | virtual void DoDispose (void); 44 | }; 45 | 46 | 47 | class OpenGymDiscreteSpace : public OpenGymSpace 48 | { 49 | public: 50 | OpenGymDiscreteSpace (); 51 | OpenGymDiscreteSpace (int n); 52 | virtual ~OpenGymDiscreteSpace (); 53 | 54 | static TypeId GetTypeId (); 55 | 56 | virtual ns3opengym::SpaceDescription GetSpaceDescription(); 57 | 58 | int GetN(void); 59 | virtual void Print(std::ostream& where) const; 60 | friend std::ostream& operator<< (std::ostream& os, const Ptr space) 61 | { 62 | space->Print(os); 63 | return os; 64 | } 65 | 66 | protected: 67 | // Inherited 68 | virtual void DoInitialize (void); 69 | virtual void DoDispose (void); 70 | 71 | private: 72 | int m_n; 73 | }; 74 | 75 | class OpenGymBoxSpace : public OpenGymSpace 76 
| { 77 | public: 78 | OpenGymBoxSpace (); 79 | OpenGymBoxSpace (float low, float high, std::vector shape, std::string dtype); 80 | OpenGymBoxSpace (std::vector low, std::vector high, std::vector shape, std::string dtype); 81 | virtual ~OpenGymBoxSpace (); 82 | 83 | static TypeId GetTypeId (); 84 | 85 | virtual ns3opengym::SpaceDescription GetSpaceDescription(); 86 | 87 | float GetLow(); 88 | float GetHigh(); 89 | std::vector GetShape(); 90 | 91 | virtual void Print(std::ostream& where) const; 92 | friend std::ostream& operator<< (std::ostream& os, const Ptr space) 93 | { 94 | space->Print(os); 95 | return os; 96 | } 97 | 98 | protected: 99 | // Inherited 100 | virtual void DoInitialize (void); 101 | virtual void DoDispose (void); 102 | 103 | private: 104 | void SetDtype (); 105 | 106 | float m_low; 107 | float m_high; 108 | std::vector m_shape; 109 | std::string m_dtypeName; 110 | std::vector m_lowVec; 111 | std::vector m_highVec; 112 | 113 | ns3opengym::Dtype m_dtype; 114 | }; 115 | 116 | 117 | class OpenGymTupleSpace : public OpenGymSpace 118 | { 119 | public: 120 | OpenGymTupleSpace (); 121 | virtual ~OpenGymTupleSpace (); 122 | 123 | static TypeId GetTypeId (); 124 | 125 | virtual ns3opengym::SpaceDescription GetSpaceDescription(); 126 | 127 | bool Add(Ptr space); 128 | Ptr Get(uint32_t idx); 129 | 130 | virtual void Print(std::ostream& where) const; 131 | friend std::ostream& operator<< (std::ostream& os, const Ptr space) 132 | { 133 | space->Print(os); 134 | return os; 135 | } 136 | 137 | protected: 138 | // Inherited 139 | virtual void DoInitialize (void); 140 | virtual void DoDispose (void); 141 | 142 | private: 143 | std::vector< Ptr > m_tuple; 144 | }; 145 | 146 | 147 | class OpenGymDictSpace : public OpenGymSpace 148 | { 149 | public: 150 | OpenGymDictSpace (); 151 | virtual ~OpenGymDictSpace (); 152 | 153 | static TypeId GetTypeId (); 154 | 155 | virtual ns3opengym::SpaceDescription GetSpaceDescription(); 156 | 157 | bool Add(std::string key, Ptr 
value); 158 | Ptr Get(std::string key); 159 | 160 | virtual void Print(std::ostream& where) const; 161 | friend std::ostream& operator<< (std::ostream& os, const Ptr space) 162 | { 163 | space->Print(os); 164 | return os; 165 | } 166 | 167 | protected: 168 | // Inherited 169 | virtual void DoInitialize (void); 170 | virtual void DoDispose (void); 171 | 172 | private: 173 | std::map< std::string, Ptr > m_dict; 174 | }; 175 | 176 | } // end of namespace ns3 177 | 178 | #endif /* OPENGYM_SPACES_H */ 179 | 180 | -------------------------------------------------------------------------------- /test/opengym-test-suite.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */ 2 | 3 | // Include a header file from your module to test. 4 | #include "ns3/opengym-module.h" 5 | 6 | // An essential include is test.h 7 | #include "ns3/test.h" 8 | 9 | // Do not put your test classes in namespace ns3. You may find it useful 10 | // to use the using directive to access the ns3 namespace directly 11 | using namespace ns3; 12 | 13 | // This is an example TestCase. 
14 | class OpengymTestCase1 : public TestCase 15 | { 16 | public: 17 | OpengymTestCase1 (); 18 | virtual ~OpengymTestCase1 (); 19 | 20 | private: 21 | virtual void DoRun (void); 22 | }; 23 | 24 | // Add some help text to this case to describe what it is intended to test 25 | OpengymTestCase1::OpengymTestCase1 () 26 | : TestCase ("Opengym test case (does nothing)") 27 | { 28 | } 29 | 30 | // This destructor does nothing but we include it as a reminder that 31 | // the test case should clean up after itself 32 | OpengymTestCase1::~OpengymTestCase1 () 33 | { 34 | } 35 | 36 | // 37 | // This method is the pure virtual method from class TestCase that every 38 | // TestCase must implement 39 | // 40 | void 41 | OpengymTestCase1::DoRun (void) 42 | { 43 | // A wide variety of test macros are available in src/core/test.h 44 | NS_TEST_ASSERT_MSG_EQ (true, true, "true doesn't equal true for some reason"); 45 | // Use this one for floating point comparisons 46 | NS_TEST_ASSERT_MSG_EQ_TOL (0.01, 0.01, 0.001, "Numbers are not equal within tolerance"); 47 | } 48 | 49 | // The TestSuite class names the TestSuite, identifies what type of TestSuite, 50 | // and enables the TestCases to be run. 
Typically, only the constructor for 51 | // this class must be defined 52 | // 53 | class OpengymTestSuite : public TestSuite 54 | { 55 | public: 56 | OpengymTestSuite (); 57 | }; 58 | 59 | OpengymTestSuite::OpengymTestSuite () 60 | : TestSuite ("opengym", UNIT) 61 | { 62 | // TestDuration for TestCase can be QUICK, EXTENSIVE or TAKES_FOREVER 63 | AddTestCase (new OpengymTestCase1, TestCase::QUICK); 64 | } 65 | 66 | // Do not forget to allocate an instance of this TestSuite 67 | static OpengymTestSuite opengymTestSuite; 68 | 69 | -------------------------------------------------------------------------------- /wscript: -------------------------------------------------------------------------------- 1 | # -*- Mode: python; py-indent-offset: 4; indent-tabs-mode: nil; coding: utf-8; -*- 2 | 3 | import os 4 | import sys 5 | import subprocess 6 | 7 | from waflib import Options 8 | from waflib.Errors import WafError 9 | 10 | #def options(opt): 11 | # pass 12 | 13 | def configure(conf): 14 | conf.env['ENABLE_ZMQ'] = conf.check(mandatory=False, lib='zmq', define_name='HAVE_ZMQ', uselib='ZMQ') 15 | conf.env['ENABLE_PROTOBUF'] = conf.check(mandatory=False, lib='protobuf', define_name='HAVE_PROTOBUF', uselib='PROTOBUF') 16 | 17 | # check if protoc is installed 18 | conf.env['ENABLE_PROTOC'] = False 19 | 20 | protoc_min_version = (3,0,0) 21 | protoc = conf.find_program('protoc', var='PROTOC') 22 | 23 | try: 24 | cmd = [protoc[0], '--version'] 25 | output = subprocess.check_output(cmd).decode("utf-8") 26 | output = output.split(" ")[1].rstrip() 27 | protoc_version = tuple(output.split('.')) 28 | protoc_version = tuple(map(int, protoc_version)) 29 | conf.msg('Checking for protoc version', output) 30 | conf.env['ENABLE_PROTOC'] = True 31 | except Exception: 32 | conf.fatal('Could not determine the protoc version %r'%protoc) 33 | 34 | if not conf.env['ENABLE_ZMQ'] or not conf.env['ENABLE_PROTOBUF'] or not conf.env['ENABLE_PROTOC']: 35 | 
conf.env['MODULES_NOT_BUILT'].append('opengym') 36 | return 37 | 38 | # if protoc was found, check if version >= 3.0.0 39 | if protoc_version < protoc_min_version: 40 | conf.fatal('protoc version %s older than minimum supported version %s' % 41 | ('.'.join(map(str, protoc_version)), '.'.join(map(str, protoc_min_version)) )) 42 | 43 | conf.env.append_value("LINKFLAGS", ["-lzmq", "-lprotobuf"]) 44 | conf.env.append_value("LIB", ["zmq", "protobuf"]) 45 | 46 | # build protobuff messages 47 | try: 48 | pbSrcDir = str(conf.path) + "/model/" 49 | protoc = protoc[0] 50 | rc = subprocess.call(protoc+" -I="+pbSrcDir+" --cpp_out="+pbSrcDir+" "+pbSrcDir+"messages.proto", shell=True) 51 | 52 | if rc == 0: 53 | conf.msg("Build ns3gym Protobuf C++ messages", "Done") 54 | else: 55 | conf.fatal('Build ns3gym Protobuf C++ messages failed') 56 | 57 | rc = subprocess.call(protoc+" -I="+pbSrcDir+" --python_out="+pbSrcDir+"ns3gym/ns3gym/ "+pbSrcDir+"messages.proto", shell=True) 58 | 59 | if rc == 0: 60 | conf.msg("Build ns3gym Protobuf Python messages", "Done") 61 | else: 62 | conf.fatal('Build ns3gym Protobuf Python messages failed') 63 | 64 | except Exception as e: 65 | conf.env['MODULES_NOT_BUILT'].append('opengym') 66 | conf.fatal('Build of ns3gym Protobuf messages failed') 67 | 68 | def build(bld): 69 | # Don't do anything for this module if click should not be built. 
70 | if 'opengym' in bld.env['MODULES_NOT_BUILT']: 71 | return 72 | 73 | module = bld.create_ns3_module('opengym', ['core']) 74 | module.source = [ 75 | 'model/opengym_interface.cc', 76 | 'model/messages.pb.cc', 77 | 'model/container.cc', 78 | 'model/spaces.cc', 79 | 'model/opengym_env.cc', 80 | 'model/opengym_multi_interface.cc', 81 | 'model/opengym_multi_env.cc', 82 | 'helper/opengym-helper.cc', 83 | ] 84 | 85 | module_test = bld.create_ns3_module_test_library('opengym') 86 | module_test.source = [ 87 | 'test/opengym-test-suite.cc', 88 | ] 89 | 90 | headers = bld(features='ns3header') 91 | headers.module = 'opengym' 92 | headers.source = [ 93 | 'model/opengym_interface.h', 94 | 'model/messages.pb.h', 95 | 'model/container.h', 96 | 'model/spaces.h', 97 | 'model/opengym_env.h', 98 | 'model/opengym_multi_interface.h', 99 | 'model/opengym_multi_env.h', 100 | 'helper/opengym-helper.h', 101 | ] 102 | 103 | if bld.env['ENABLE_ZMQ']: 104 | module.use.extend(['lzmq']) 105 | module.use.extend(['lprotobuf']) 106 | 107 | if bld.env['ENABLE_EXAMPLES']: 108 | bld.recurse('examples') 109 | 110 | # bld.ns3_python_bindings() 111 | 112 | --------------------------------------------------------------------------------