├── Agent_test.py ├── README.md ├── rl.png ├── sim.cc ├── tcp-rl-env.cc ├── tcp-rl-env.h ├── tcp-rl.cc ├── tcp-rl.h └── tcp_base.py /Agent_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import gym 4 | import tensorflow as tf 5 | import tensorflow.contrib.slim as slim 6 | import numpy as np 7 | import matplotlib as mpl 8 | import matplotlib.pyplot as plt 9 | from tensorflow import keras 10 | from ns3gym import ns3env 11 | 12 | 13 | import argparse 14 | from ns3gym import ns3env 15 | from tcp_base import TcpTimeBased 16 | from tcp_newreno import TcpNewReno 17 | 18 | __author__ = "Kenan and Sharif, Modified the code by Piotr Gawlowicz" 19 | __copyright__ = "Copyright (c) 2020, Technische Universität Berlin" 20 | __version__ = "0.1.0" 21 | __email__ = "gawlowicz@tkn.tu-berlin.de" 22 | 23 | parser = argparse.ArgumentParser(description='Start simulation script on/off') 24 | parser.add_argument('--start', 25 | type=int, 26 | default=1, 27 | help='Start ns-3 simulation script 0/1, Default: 1') 28 | parser.add_argument('--iterations', 29 | type=int, 30 | default=1, 31 | help='Number of iterations, Default: 1') 32 | args = parser.parse_args() 33 | startSim = bool(args.start) 34 | iterationNum = int(args.iterations) 35 | 36 | port = 5555 37 | simTime = 10 # seconds 38 | stepTime = 0.5 # seconds 39 | seed = 12 40 | simArgs = {"--duration": simTime,} 41 | debug = False 42 | 43 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug) 44 | # simpler: 45 | #env = ns3env.Ns3Env() 46 | env.reset() 47 | 48 | ob_space = env.observation_space 49 | ac_space = env.action_space 50 | 51 | print("Observation space: ", ob_space, ob_space.dtype) 52 | print("Action space: ", ac_space, ac_space.dtype) 53 | 54 | 55 | def get_agent(obs): 56 | socketUuid = obs[0] 57 | tcpEnvType = obs[1] 58 | tcpAgent = get_agent.tcpAgents.get(socketUuid, 
None) 59 | if tcpAgent is None: 60 | if tcpEnvType == 0: 61 | # event-based = 0 62 | tcpAgent = TcpNewReno() 63 | else: 64 | # time-based = 1 65 | tcpAgent = TcpTimeBased() 66 | tcpAgent.set_spaces(get_agent.ob_space, get_agent.ac_space) 67 | get_agent.tcpAgents[socketUuid] = tcpAgent 68 | 69 | return tcpAgent 70 | 71 | # initialize variable 72 | get_agent.tcpAgents = {} 73 | get_agent.ob_space = ob_space 74 | get_agent.ac_space = ac_space 75 | 76 | s_size = ob_space.shape[0] 77 | print("State size: ",ob_space.shape[0]) 78 | 79 | a_size = 3 80 | print("Action size: ", a_size) 81 | 82 | model = keras.Sequential() 83 | model.add(keras.layers.Dense(s_size, input_shape=(s_size,), activation='relu')) 84 | model.add(keras.layers.Dense(s_size, input_shape=(s_size,), activation='relu')) 85 | model.add(keras.layers.Dense(a_size, activation='softmax')) 86 | model.compile(optimizer=tf.train.AdamOptimizer(0.001), 87 | loss='categorical_crossentropy', 88 | metrics=['accuracy']) 89 | 90 | total_episodes = 6 91 | max_env_steps = 100 92 | env._max_episode_steps = max_env_steps 93 | 94 | epsilon = 1.0 # exploration rate 95 | epsilon_min = 0.01 96 | epsilon_decay = 0.999 97 | 98 | time_history = [] 99 | rew_history = [] 100 | cWnd_history=[] 101 | cWnd_history2=[] 102 | Rtt=[] 103 | segAkc=[] 104 | No_step = 0 105 | t2 =10 106 | t =[] 107 | reward = 0 108 | done = False 109 | info = None 110 | action_mapping = {} 111 | action_mapping[0] = 0 112 | action_mapping[1] = 600 113 | action_mapping[2] = -60 114 | U_new =0 115 | U =0 116 | U_old=0 117 | reward=0 118 | for e in range(total_episodes): 119 | obs = env.reset() 120 | cWnd = obs[5] 121 | obs = np.reshape(obs, [1, s_size]) 122 | rewardsum = 0 123 | for time in range(max_env_steps): 124 | # Choose action 125 | if np.random.rand(1) < epsilon: 126 | action_index = np.random.randint(3) 127 | print (action_index) 128 | print("Value Initialization ...") 129 | else: 130 | action_index = np.argmax(model.predict(obs)[0]) 131 | 
print(action_index) 132 | new_cWnd = cWnd + action_mapping[action_index] 133 | new_ssThresh = np.int(cWnd/2) 134 | actions = [new_ssThresh, new_cWnd] 135 | U_new=0.7*(np.log(obs[0,2]))-0.7*(np.log(obs[0,9] )) 136 | U=U_new-U_old 137 | if U <-0.05: 138 | reward=-5 139 | elif U >0.05: 140 | reward=1 141 | else: 142 | reward=0 143 | # Step 144 | next_state, reward, done, info = env.step(actions) 145 | cWnd = next_state[5] 146 | print("cWnd:",cWnd) 147 | if done: 148 | print("episode: {}/{}, time: {}, rew: {}, eps: {:.2}" 149 | .format(e, total_episodes, time, rewardsum, epsilon)) 150 | break 151 | U_old=0.7*(np.log(obs[0,2]))-0.7*(np.log(obs[0,9] )) 152 | next_state = np.reshape(next_state, [1, s_size]) 153 | # Train 154 | target = reward 155 | if not done: 156 | target = (reward + 0.95 * np.amax(model.predict(next_state)[0])) 157 | target_f = model.predict(obs) 158 | print("target :", target_f) 159 | target_f[0][action_index] = target 160 | model.fit(obs, target_f, epochs=1, verbose=0) 161 | 162 | obs = next_state 163 | seg=obs[0,5] 164 | rtt=obs[0,9] 165 | rewardsum += reward 166 | if epsilon > epsilon_min: epsilon *= epsilon_decay 167 | No_step += 1 168 | 169 | print("number of steps :", No_step) 170 | print("espsilon :",epsilon) 171 | 172 | 173 | print("reward sum", rewardsum) 174 | segAkc.append(seg) 175 | Rtt.append(rtt) 176 | 177 | cWnd_history.append(cWnd) 178 | time_history.append(time) 179 | rew_history.append(rewardsum) 180 | 181 | 182 | 183 | print("Plot Learning Performance") 184 | mpl.rcdefaults() 185 | mpl.rcParams.update({'font.size': 16}) 186 | fig, ax = plt.subplots(figsize=(10,4)) 187 | plt.grid(True, linestyle='--') 188 | plt.title('Learning Performance') 189 | #plt.plot(range(len(rew_history)), rew_history, label='Reward', marker="^", linestyle=":")#, color='red') 190 | plt.plot(range(len(rew_history)), rew_history, label='Reward', marker="", linestyle="-")#, color='k') 191 | 192 | #plt.plot(range(len(segAkc)), segAkc, label='segAkc', marker="", 
linestyle="-"),# color='b') 193 | #plt.plot(range(len(Rtt)),Rtt, label='Rtt', marker="", linestyle="-")#, color='y') 194 | plt.xlabel('Episode') 195 | plt.ylabel('Steps') 196 | plt.legend(prop={'size': 12}) 197 | plt.savefig('learning.pdf', bbox_inches='tight') 198 | plt.show() 199 | 200 | 201 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RL-TCP 2 | For years, TCP has been explored and discussed for end-to-end congestion control mechanisms. Although, various efforts and optimization techniques have taken place to find an optimal approach in dealing with all kinds of network architecture and scenarios. But the TCP version only follows the rules which they are defined by. The rule-based approach takes no consideration of the previous information of the links every time a flow begins,whereas a better performance can be accomplished if TCP can adapt its attitude based on prior learned information when the same path was before experienced. The rising of machine learning and deep learning paves the way for many researchers in academia. These approaches rather learn from history to find the best way and methods to get the job done.
3 | This project implements Reinforcement Learning (RL) based TCP congestion control with a Python agent that uses RL techniques to overcome some of the shortcomings of rule-based TCP. 4 | 5 | ## RL Tools 6 | The following tools are used to implement the project. 7 | 8 | ### ns-3 9 | ns-3 is a discrete-event network simulator that is primarily used for research and education. In recent years it has become a standard tool in networking research, and results obtained with it are widely accepted by the scientific community. In this project, we use ns-3 to generate traffic and to simulate our environment. It can be found at https://www.nsnam.org 10 | 11 | ### ns3-gym 12 | ns3-gym is middleware that interconnects the ns-3 network simulator with the OpenAI Gym framework. It acts as a proxy between the ns-3 environment, which is written in C++, and the reinforcement-learning agent, which is written in Python. It can be installed from https://github.com/tkn-tub/ns3-gym 13 | 14 | ## Agent 15 | Our agent interacts with the ns-3 network environment and keeps exploring for the optimal policy by taking actions such as increasing or decreasing the congestion window (cWnd). The environment setup combines the ns-3 network simulator, the ns3-gym proxy, and a Python-based agent. The mapping is as follows: 16 | 17 | * State: the state space is the profile of the network simulation, as provided by the ns-3 simulator. 18 | * Action: the actions increase, decrease, or leave unchanged the cWnd size, by +600, -200, and 0 respectively (note: `Agent_test.py` currently uses -60 for the decrease step). 19 | * Reward: the reward is a utility function that considers the RTT and the number of acknowledged segments: +2 when cWnd is increased without overwhelming the receiver side, and -30 otherwise. 20 | 21 | ![](rl.png) 22 | As the result indicates, our agent adaptively keeps the cWnd within a stable range of values according to the objective it has learned.
23 | 24 | To run the agent, first install ns-3 and ns3-gym, build the project, and then execute: 25 | ``` 26 | ./Agent_test.py 27 | ``` 28 | Or, in two terminals: 29 | ``` 30 | # Terminal 1: 31 | ./waf --run "rl-tcp --transport_prot=TcpRlTimeBased" 32 | 33 | # Terminal 2: 34 | ./Agent_test.py --start=0 35 | 36 | ``` 37 | -------------------------------------------------------------------------------- /rl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharif-Rezaie/RL-TCP/9dc1e3893f11e3e67d1fae7f59d975548cce9126/rl.png -------------------------------------------------------------------------------- /sim.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Piotr Gawlowicz 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details.
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | * Based on script: ./examples/tcp/tcp-variants-comparison.cc 20 | * 21 | * Topology: 22 | * 23 | * Right Leafs (Clients) Left Leafs (Sinks) 24 | * | \ / | 25 | * | \ bottleneck / | 26 | * | R0--------------R1 | 27 | * | / \ | 28 | * | access / \ access | 29 | * N ----------- --------N 30 | */ 31 | 32 | #include 33 | #include 34 | #include 35 | 36 | #include "ns3/core-module.h" 37 | #include "ns3/network-module.h" 38 | #include "ns3/internet-module.h" 39 | #include "ns3/point-to-point-module.h" 40 | #include "ns3/point-to-point-layout-module.h" 41 | #include "ns3/applications-module.h" 42 | #include "ns3/error-model.h" 43 | #include "ns3/tcp-header.h" 44 | #include "ns3/enum.h" 45 | #include "ns3/event-id.h" 46 | #include "ns3/flow-monitor-helper.h" 47 | #include "ns3/ipv4-global-routing-helper.h" 48 | #include "ns3/traffic-control-module.h" 49 | 50 | #include "ns3/opengym-module.h" 51 | #include "tcp-rl.h" 52 | 53 | using namespace ns3; 54 | 55 | NS_LOG_COMPONENT_DEFINE ("TcpVariantsComparison"); 56 | 57 | static std::vector rxPkts; 58 | 59 | static void 60 | CountRxPkts(uint32_t sinkId, Ptr packet, const Address & srcAddr) 61 | { 62 | rxPkts[sinkId]++; 63 | } 64 | 65 | static void 66 | PrintRxCount() 67 | { 68 | uint32_t size = rxPkts.size(); 69 | NS_LOG_UNCOND("RxPkts:"); 70 | for (uint32_t i=0; i openGymInterface; 144 | if (transport_prot.compare ("ns3::TcpRl") == 0) 145 | { 146 | openGymInterface = OpenGymInterface::Get(openGymPort); 147 | Config::SetDefault ("ns3::TcpRl::Reward", DoubleValue (2.0)); // Reward when increasing congestion window 148 | Config::SetDefault ("ns3::TcpRl::Penalty", DoubleValue (-30.0)); // Penalty when decreasing congestion window 149 | } 150 | 151 | if 
(transport_prot.compare ("ns3::TcpRlTimeBased") == 0) 152 | { 153 | openGymInterface = OpenGymInterface::Get(openGymPort); 154 | Config::SetDefault ("ns3::TcpRlTimeBased::StepTime", TimeValue (Seconds(tcpEnvTimeStep))); // Time step of TCP env 155 | } 156 | 157 | // Calculate the ADU size 158 | Header* temp_header = new Ipv4Header (); 159 | uint32_t ip_header = temp_header->GetSerializedSize (); 160 | NS_LOG_LOGIC ("IP Header size is: " << ip_header); 161 | delete temp_header; 162 | temp_header = new TcpHeader (); 163 | uint32_t tcp_header = temp_header->GetSerializedSize (); 164 | NS_LOG_LOGIC ("TCP Header size is: " << tcp_header); 165 | delete temp_header; 166 | uint32_t tcp_adu_size = mtu_bytes - 20 - (ip_header + tcp_header); 167 | NS_LOG_LOGIC ("TCP ADU size is: " << tcp_adu_size); 168 | 169 | // Set the simulation start and stop time 170 | double start_time = 0.1; 171 | double stop_time = start_time + duration; 172 | 173 | // 4 MB of TCP buffer 174 | Config::SetDefault ("ns3::TcpSocket::RcvBufSize", UintegerValue (1 << 21)); 175 | Config::SetDefault ("ns3::TcpSocket::SndBufSize", UintegerValue (1 << 21)); 176 | Config::SetDefault ("ns3::TcpSocketBase::Sack", BooleanValue (sack)); 177 | Config::SetDefault ("ns3::TcpSocket::DelAckCount", UintegerValue (2)); 178 | 179 | 180 | Config::SetDefault ("ns3::TcpL4Protocol::RecoveryType", 181 | TypeIdValue (TypeId::LookupByName (recovery))); 182 | // Select TCP variant 183 | if (transport_prot.compare ("ns3::TcpWestwoodPlus") == 0) 184 | { 185 | // TcpWestwoodPlus is not an actual TypeId name; we need TcpWestwood here 186 | Config::SetDefault ("ns3::TcpL4Protocol::SocketType", TypeIdValue (TcpWestwood::GetTypeId ())); 187 | // the default protocol type in ns3::TcpWestwood is WESTWOOD 188 | Config::SetDefault ("ns3::TcpWestwood::ProtocolType", EnumValue (TcpWestwood::WESTWOODPLUS)); 189 | } 190 | else 191 | { 192 | TypeId tcpTid; 193 | NS_ABORT_MSG_UNLESS (TypeId::LookupByNameFailSafe (transport_prot, &tcpTid), "TypeId 
" << transport_prot << " not found"); 194 | Config::SetDefault ("ns3::TcpL4Protocol::SocketType", TypeIdValue (TypeId::LookupByName (transport_prot))); 195 | } 196 | 197 | // Configure the error model 198 | // Here we use RateErrorModel with packet error rate 199 | Ptr uv = CreateObject (); 200 | uv->SetStream (50); 201 | RateErrorModel error_model; 202 | error_model.SetRandomVariable (uv); 203 | error_model.SetUnit (RateErrorModel::ERROR_UNIT_PACKET); 204 | error_model.SetRate (error_p); 205 | 206 | // Create the point-to-point link helpers 207 | PointToPointHelper bottleNeckLink; 208 | bottleNeckLink.SetDeviceAttribute ("DataRate", StringValue (bottleneck_bandwidth)); 209 | bottleNeckLink.SetChannelAttribute ("Delay", StringValue (bottleneck_delay)); 210 | //bottleNeckLink.SetDeviceAttribute ("ReceiveErrorModel", PointerValue (&error_model)); 211 | 212 | PointToPointHelper pointToPointLeaf; 213 | pointToPointLeaf.SetDeviceAttribute ("DataRate", StringValue (access_bandwidth)); 214 | pointToPointLeaf.SetChannelAttribute ("Delay", StringValue (access_delay)); 215 | 216 | PointToPointDumbbellHelper d (nLeaf, pointToPointLeaf, 217 | nLeaf, pointToPointLeaf, 218 | bottleNeckLink); 219 | 220 | // Install IP stack 221 | InternetStackHelper stack; 222 | stack.InstallAll (); 223 | 224 | // Traffic Control 225 | TrafficControlHelper tchPfifo; 226 | tchPfifo.SetRootQueueDisc ("ns3::PfifoFastQueueDisc"); 227 | 228 | TrafficControlHelper tchCoDel; 229 | tchCoDel.SetRootQueueDisc ("ns3::CoDelQueueDisc"); 230 | 231 | DataRate access_b (access_bandwidth); 232 | DataRate bottle_b (bottleneck_bandwidth); 233 | Time access_d (access_delay); 234 | Time bottle_d (bottleneck_delay); 235 | 236 | uint32_t size = static_cast((std::min (access_b, bottle_b).GetBitRate () / 8) * 237 | ((access_d + bottle_d + access_d) * 2).GetSeconds ()); 238 | 239 | Config::SetDefault ("ns3::PfifoFastQueueDisc::MaxSize", 240 | QueueSizeValue (QueueSize (QueueSizeUnit::PACKETS, size / mtu_bytes))); 241 | 
Config::SetDefault ("ns3::CoDelQueueDisc::MaxSize", 242 | QueueSizeValue (QueueSize (QueueSizeUnit::BYTES, size))); 243 | 244 | if (queue_disc_type.compare ("ns3::PfifoFastQueueDisc") == 0) 245 | { 246 | tchPfifo.Install (d.GetLeft()->GetDevice(1)); 247 | tchPfifo.Install (d.GetRight()->GetDevice(1)); 248 | } 249 | else if (queue_disc_type.compare ("ns3::CoDelQueueDisc") == 0) 250 | { 251 | tchCoDel.Install (d.GetLeft()->GetDevice(1)); 252 | tchCoDel.Install (d.GetRight()->GetDevice(1)); 253 | } 254 | else 255 | { 256 | NS_FATAL_ERROR ("Queue not recognized. Allowed values are ns3::CoDelQueueDisc or ns3::PfifoFastQueueDisc"); 257 | } 258 | 259 | // Assign IP Addresses 260 | d.AssignIpv4Addresses (Ipv4AddressHelper ("10.1.1.0", "255.255.255.0"), 261 | Ipv4AddressHelper ("10.2.1.0", "255.255.255.0"), 262 | Ipv4AddressHelper ("10.3.1.0", "255.255.255.0")); 263 | 264 | 265 | NS_LOG_INFO ("Initialize Global Routing."); 266 | Ipv4GlobalRoutingHelper::PopulateRoutingTables (); 267 | 268 | // Install apps in left and right nodes 269 | uint16_t port = 50000; 270 | Address sinkLocalAddress (InetSocketAddress (Ipv4Address::GetAny (), port)); 271 | PacketSinkHelper sinkHelper ("ns3::TcpSocketFactory", sinkLocalAddress); 272 | ApplicationContainer sinkApps; 273 | for (uint32_t i = 0; i < d.RightCount (); ++i) 274 | { 275 | sinkHelper.SetAttribute ("Protocol", TypeIdValue (TcpSocketFactory::GetTypeId ())); 276 | sinkApps.Add (sinkHelper.Install (d.GetRight (i))); 277 | } 278 | sinkApps.Start (Seconds (0.0)); 279 | sinkApps.Stop (Seconds (stop_time)); 280 | 281 | for (uint32_t i = 0; i < d.LeftCount (); ++i) 282 | { 283 | // Create an on/off app sending packets to the left side 284 | AddressValue remoteAddress (InetSocketAddress (d.GetRightIpv4Address (i), port)); 285 | Config::SetDefault ("ns3::TcpSocket::SegmentSize", UintegerValue (tcp_adu_size)); 286 | BulkSendHelper ftp ("ns3::TcpSocketFactory", Address ()); 287 | ftp.SetAttribute ("Remote", remoteAddress); 288 | 
ftp.SetAttribute ("SendSize", UintegerValue (tcp_adu_size)); 289 | ftp.SetAttribute ("MaxBytes", UintegerValue (data_mbytes * 1000000)); 290 | 291 | ApplicationContainer clientApp = ftp.Install (d.GetLeft (i)); 292 | clientApp.Start (Seconds (start_time * i)); // Start after sink 293 | clientApp.Stop (Seconds (stop_time - 3)); // Stop before the sink 294 | } 295 | 296 | // Flow monitor 297 | FlowMonitorHelper flowHelper; 298 | if (flow_monitor) 299 | { 300 | flowHelper.InstallAll (); 301 | } 302 | 303 | // Count RX packets 304 | for (uint32_t i = 0; i < d.RightCount (); ++i) 305 | { 306 | rxPkts.push_back(0); 307 | Ptr pktSink = DynamicCast(sinkApps.Get(i)); 308 | pktSink->TraceConnectWithoutContext ("Rx", MakeBoundCallback (&CountRxPkts, i)); 309 | } 310 | 311 | Simulator::Stop (Seconds (stop_time)); 312 | Simulator::Run (); 313 | 314 | if (flow_monitor) 315 | { 316 | flowHelper.SerializeToXmlFile (prefix_file_name + ".flowmonitor", true, true); 317 | } 318 | 319 | if (transport_prot.compare ("ns3::TcpRl") == 0 or transport_prot.compare ("ns3::TcpRlTimeBased") == 0) 320 | { 321 | openGymInterface->NotifySimulationEnd(); 322 | } 323 | 324 | PrintRxCount(); 325 | Simulator::Destroy (); 326 | return 0; 327 | } 328 | -------------------------------------------------------------------------------- /tcp-rl-env.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | #include "tcp-rl-env.h" 22 | #include "ns3/tcp-header.h" 23 | #include "ns3/object.h" 24 | #include "ns3/core-module.h" 25 | #include "ns3/log.h" 26 | #include "ns3/simulator.h" 27 | #include "ns3/tcp-socket-base.h" 28 | #include 29 | #include 30 | 31 | 32 | namespace ns3 { 33 | 34 | NS_LOG_COMPONENT_DEFINE ("ns3::TcpGymEnv"); 35 | NS_OBJECT_ENSURE_REGISTERED (TcpGymEnv); 36 | 37 | TcpGymEnv::TcpGymEnv () 38 | { 39 | NS_LOG_FUNCTION (this); 40 | SetOpenGymInterface(OpenGymInterface::Get()); 41 | } 42 | 43 | TcpGymEnv::~TcpGymEnv () 44 | { 45 | NS_LOG_FUNCTION (this); 46 | } 47 | 48 | TypeId 49 | TcpGymEnv::GetTypeId (void) 50 | { 51 | static TypeId tid = TypeId ("ns3::TcpGymEnv") 52 | .SetParent () 53 | .SetGroupName ("OpenGym") 54 | ; 55 | 56 | return tid; 57 | } 58 | 59 | void 60 | TcpGymEnv::DoDispose () 61 | { 62 | NS_LOG_FUNCTION (this); 63 | } 64 | 65 | void 66 | TcpGymEnv::SetNodeId(uint32_t id) 67 | { 68 | NS_LOG_FUNCTION (this); 69 | m_nodeId = id; 70 | } 71 | 72 | void 73 | TcpGymEnv::SetSocketUuid(uint32_t id) 74 | { 75 | NS_LOG_FUNCTION (this); 76 | m_socketUuid = id; 77 | } 78 | 79 | std::string 80 | TcpGymEnv::GetTcpCongStateName(const TcpSocketState::TcpCongState_t state) 81 | { 82 | std::string stateName = "UNKNOWN"; 83 | switch(state) { 84 | case TcpSocketState::CA_OPEN: 85 | stateName = "CA_OPEN"; 86 | break; 87 | case TcpSocketState::CA_DISORDER: 88 | stateName = "CA_DISORDER"; 89 | break; 90 | case TcpSocketState::CA_CWR: 91 | stateName = "CA_CWR"; 92 | break; 93 | case TcpSocketState::CA_RECOVERY: 94 | stateName = "CA_RECOVERY"; 95 | break; 96 | case TcpSocketState::CA_LOSS: 97 | stateName = "CA_LOSS"; 98 | 
break; 99 | case TcpSocketState::CA_LAST_STATE: 100 | stateName = "CA_LAST_STATE"; 101 | break; 102 | default: 103 | stateName = "UNKNOWN"; 104 | break; 105 | } 106 | return stateName; 107 | } 108 | 109 | std::string 110 | TcpGymEnv::GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event) 111 | { 112 | std::string eventName = "UNKNOWN"; 113 | switch(event) { 114 | case TcpSocketState::CA_EVENT_TX_START: 115 | eventName = "CA_EVENT_TX_START"; 116 | break; 117 | case TcpSocketState::CA_EVENT_CWND_RESTART: 118 | eventName = "CA_EVENT_CWND_RESTART"; 119 | break; 120 | case TcpSocketState::CA_EVENT_COMPLETE_CWR: 121 | eventName = "CA_EVENT_COMPLETE_CWR"; 122 | break; 123 | case TcpSocketState::CA_EVENT_LOSS: 124 | eventName = "CA_EVENT_LOSS"; 125 | break; 126 | case TcpSocketState::CA_EVENT_ECN_NO_CE: 127 | eventName = "CA_EVENT_ECN_NO_CE"; 128 | break; 129 | case TcpSocketState::CA_EVENT_ECN_IS_CE: 130 | eventName = "CA_EVENT_ECN_IS_CE"; 131 | break; 132 | case TcpSocketState::CA_EVENT_DELAYED_ACK: 133 | eventName = "CA_EVENT_DELAYED_ACK"; 134 | break; 135 | case TcpSocketState::CA_EVENT_NON_DELAYED_ACK: 136 | eventName = "CA_EVENT_NON_DELAYED_ACK"; 137 | break; 138 | default: 139 | eventName = "UNKNOWN"; 140 | break; 141 | } 142 | return eventName; 143 | } 144 | 145 | /* 146 | Define action space 147 | */ 148 | Ptr 149 | TcpGymEnv::GetActionSpace() 150 | { 151 | // new_ssThresh 152 | // new_cWnd 153 | uint32_t parameterNum = 2; 154 | float low = 0.0; 155 | float high = 65535; 156 | std::vector shape = {parameterNum,}; 157 | std::string dtype = TypeNameGet (); 158 | 159 | Ptr box = CreateObject (low, high, shape, dtype); 160 | NS_LOG_INFO ("MyGetActionSpace: " << box); 161 | return box; 162 | } 163 | 164 | /* 165 | Define game over condition 166 | */ 167 | bool 168 | TcpGymEnv::GetGameOver() 169 | { 170 | m_isGameOver = false; 171 | bool test = false; 172 | static float stepCounter = 0.0; 173 | stepCounter += 1; 174 | if (stepCounter == 10 && test) { 175 | 
m_isGameOver = true; 176 | } 177 | NS_LOG_INFO ("MyGetGameOver: " << m_isGameOver); 178 | return m_isGameOver; 179 | } 180 | 181 | /* 182 | Define reward function 183 | */ 184 | float 185 | TcpGymEnv::GetReward() 186 | { 187 | NS_LOG_INFO("MyGetReward: " << m_envReward); 188 | return m_envReward; 189 | } 190 | 191 | /* 192 | Define extra info. Optional 193 | */ 194 | std::string 195 | TcpGymEnv::GetExtraInfo() 196 | { 197 | NS_LOG_INFO("MyGetExtraInfo: " << m_info); 198 | return m_info; 199 | } 200 | 201 | /* 202 | Execute received actions 203 | */ 204 | bool 205 | TcpGymEnv::ExecuteActions(Ptr action) 206 | { 207 | Ptr > box = DynamicCast >(action); 208 | m_new_ssThresh = box->GetValue(0); 209 | m_new_cWnd = box->GetValue(1); 210 | 211 | NS_LOG_INFO ("MyExecuteActions: " << action); 212 | return true; 213 | } 214 | 215 | 216 | NS_OBJECT_ENSURE_REGISTERED (TcpEventGymEnv); 217 | 218 | TcpEventGymEnv::TcpEventGymEnv () : TcpGymEnv() 219 | { 220 | NS_LOG_FUNCTION (this); 221 | } 222 | 223 | TcpEventGymEnv::~TcpEventGymEnv () 224 | { 225 | NS_LOG_FUNCTION (this); 226 | } 227 | 228 | TypeId 229 | TcpEventGymEnv::GetTypeId (void) 230 | { 231 | static TypeId tid = TypeId ("ns3::TcpEventGymEnv") 232 | .SetParent () 233 | .SetGroupName ("OpenGym") 234 | .AddConstructor () 235 | ; 236 | 237 | return tid; 238 | } 239 | 240 | void 241 | TcpEventGymEnv::DoDispose () 242 | { 243 | NS_LOG_FUNCTION (this); 244 | } 245 | 246 | void 247 | TcpEventGymEnv::SetReward(float value) 248 | { 249 | NS_LOG_FUNCTION (this); 250 | m_reward = value; 251 | } 252 | 253 | void 254 | TcpEventGymEnv::SetPenalty(float value) 255 | { 256 | NS_LOG_FUNCTION (this); 257 | m_penalty = value; 258 | } 259 | 260 | /* 261 | Define observation space 262 | */ 263 | Ptr 264 | TcpEventGymEnv::GetObservationSpace() 265 | { 266 | // socket unique ID 267 | // tcp env type: event-based = 0 / time-based = 1 268 | // sim time in us 269 | // node ID 270 | // ssThresh 271 | // cWnd 272 | // segmentSize 273 | // 
segmentsAcked 274 | // bytesInFlight 275 | // rtt in us 276 | // min rtt in us 277 | // called func 278 | // congetsion algorithm (CA) state 279 | // CA event 280 | // ECN state 281 | uint32_t parameterNum = 10; 282 | float low = 0.0; 283 | float high = 1000000000.0; 284 | std::vector shape = {parameterNum,}; 285 | std::string dtype = TypeNameGet (); 286 | 287 | Ptr box = CreateObject (low, high, shape, dtype); 288 | NS_LOG_INFO ("MyGetObservationSpace: " << box); 289 | return box; 290 | } 291 | 292 | /* 293 | Collect observations 294 | */ 295 | Ptr 296 | TcpEventGymEnv::GetObservation() 297 | { 298 | uint32_t parameterNum = 10; 299 | std::vector shape = {parameterNum,}; 300 | 301 | Ptr > box = CreateObject >(shape); 302 | 303 | box->AddValue(m_socketUuid); 304 | box->AddValue(0); 305 | box->AddValue(Simulator::Now().GetMicroSeconds ()); 306 | box->AddValue(m_nodeId); 307 | box->AddValue(m_tcb->m_ssThresh); 308 | box->AddValue(m_tcb->m_cWnd); 309 | box->AddValue(m_tcb->m_segmentSize); 310 | box->AddValue(m_segmentsAcked); 311 | box->AddValue(m_bytesInFlight); 312 | box->AddValue(m_rtt.GetMicroSeconds ()); 313 | //box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ()); 314 | //box->AddValue(m_calledFunc); 315 | //box->AddValue(m_tcb->m_congState); 316 | //box->AddValue(m_event); 317 | //box->AddValue(m_tcb->m_ecnState); 318 | 319 | // Print data 320 | NS_LOG_INFO ("MyGetObservation: " << box); 321 | return box; 322 | } 323 | 324 | void 325 | TcpEventGymEnv::TxPktTrace(Ptr, const TcpHeader&, Ptr) 326 | { 327 | NS_LOG_FUNCTION (this); 328 | } 329 | 330 | void 331 | TcpEventGymEnv::RxPktTrace(Ptr, const TcpHeader&, Ptr) 332 | { 333 | NS_LOG_FUNCTION (this); 334 | } 335 | 336 | uint32_t 337 | TcpEventGymEnv::GetSsThresh (Ptr tcb, uint32_t bytesInFlight) 338 | { 339 | NS_LOG_FUNCTION (this); 340 | // pkt was lost, so penalty 341 | m_envReward = m_penalty; 342 | 343 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " GetSsThresh, BytesInFlight: " << bytesInFlight); 
344 | m_calledFunc = CalledFunc_t::GET_SS_THRESH; 345 | m_info = "GetSsThresh"; 346 | m_tcb = tcb; 347 | m_bytesInFlight = bytesInFlight; 348 | Notify(); 349 | return m_new_ssThresh; 350 | } 351 | 352 | void 353 | TcpEventGymEnv::IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) 354 | { 355 | NS_LOG_FUNCTION (this); 356 | // pkt was acked, so reward 357 | m_envReward = m_reward; 358 | 359 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " IncreaseWindow, SegmentsAcked: " << segmentsAcked); 360 | m_calledFunc = CalledFunc_t::INCREASE_WINDOW; 361 | m_info = "IncreaseWindow"; 362 | m_tcb = tcb; 363 | m_segmentsAcked = segmentsAcked; 364 | Notify(); 365 | tcb->m_cWnd = m_new_cWnd; 366 | } 367 | 368 | void 369 | TcpEventGymEnv::PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) 370 | { 371 | NS_LOG_FUNCTION (this); 372 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " PktsAcked, SegmentsAcked: " << segmentsAcked << " Rtt: " << rtt); 373 | m_calledFunc = CalledFunc_t::PKTS_ACKED; 374 | m_info = "PktsAcked"; 375 | m_tcb = tcb; 376 | m_segmentsAcked = segmentsAcked; 377 | m_rtt = rtt; 378 | } 379 | 380 | void 381 | TcpEventGymEnv::CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState) 382 | { 383 | NS_LOG_FUNCTION (this); 384 | std::string stateName = GetTcpCongStateName(newState); 385 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CongestionStateSet: " << newState << " " << stateName); 386 | 387 | m_calledFunc = CalledFunc_t::CONGESTION_STATE_SET; 388 | m_info = "CongestionStateSet"; 389 | m_tcb = tcb; 390 | m_newState = newState; 391 | } 392 | 393 | void 394 | TcpEventGymEnv::CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) 395 | { 396 | NS_LOG_FUNCTION (this); 397 | std::string eventName = GetTcpCAEventName(event); 398 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CwndEvent: " << event << " " << eventName); 399 | 400 | m_calledFunc = CalledFunc_t::CWND_EVENT; 401 | m_info = 
"CwndEvent"; 402 | m_tcb = tcb; 403 | m_event = event; 404 | } 405 | 406 | 407 | NS_OBJECT_ENSURE_REGISTERED (TcpTimeStepGymEnv); 408 | 409 | TcpTimeStepGymEnv::TcpTimeStepGymEnv () : TcpGymEnv() 410 | { 411 | NS_LOG_FUNCTION (this); 412 | m_envReward = 0.0; 413 | } 414 | 415 | TcpTimeStepGymEnv::TcpTimeStepGymEnv (Time timeStep) : TcpGymEnv() 416 | { 417 | NS_LOG_FUNCTION (this); 418 | m_timeStep = timeStep; 419 | m_envReward = 0.0; 420 | } 421 | 422 | void 423 | TcpTimeStepGymEnv::ScheduleNextStateRead () 424 | { 425 | NS_LOG_FUNCTION (this); 426 | Simulator::Schedule (m_timeStep, &TcpTimeStepGymEnv::ScheduleNextStateRead, this); 427 | Notify(); 428 | } 429 | 430 | TcpTimeStepGymEnv::~TcpTimeStepGymEnv () 431 | { 432 | NS_LOG_FUNCTION (this); 433 | } 434 | 435 | TypeId 436 | TcpTimeStepGymEnv::GetTypeId (void) 437 | { 438 | static TypeId tid = TypeId ("ns3::TcpTimeStepGymEnv") 439 | .SetParent () 440 | .SetGroupName ("OpenGym") 441 | .AddConstructor () 442 | ; 443 | 444 | return tid; 445 | } 446 | 447 | void 448 | TcpTimeStepGymEnv::DoDispose () 449 | { 450 | NS_LOG_FUNCTION (this); 451 | } 452 | 453 | /* 454 | Define observation space 455 | */ 456 | Ptr 457 | TcpTimeStepGymEnv::GetObservationSpace() 458 | { 459 | // socket unique ID 460 | // tcp env type: event-based = 0 / time-based = 1 461 | // sim time in us 462 | // node ID 463 | // ssThresh 464 | // cWnd 465 | // segmentSize 466 | // bytesInFlightSum 467 | // bytesInFlightAvg 468 | // segmentsAckedSum 469 | // segmentsAckedAvg 470 | // avgRtt 471 | // minRtt 472 | // avgInterTx 473 | // avgInterRx 474 | // throughput 475 | uint32_t parameterNum = 16; 476 | float low = 0.0; 477 | float high = 1000000000.0; 478 | std::vector shape = {parameterNum,}; 479 | std::string dtype = TypeNameGet (); 480 | 481 | Ptr box = CreateObject (low, high, shape, dtype); 482 | NS_LOG_INFO ("MyGetObservationSpace: " << box); 483 | return box; 484 | } 485 | 486 | /* 487 | Collect observations 488 | */ 489 | Ptr 490 | 
TcpTimeStepGymEnv::GetObservation() 491 | { 492 | uint32_t parameterNum = 16; 493 | std::vector shape = {parameterNum,}; 494 | 495 | Ptr > box = CreateObject >(shape); 496 | 497 | box->AddValue(m_socketUuid); 498 | box->AddValue(1); 499 | box->AddValue(Simulator::Now().GetMicroSeconds ()); 500 | box->AddValue(m_nodeId); 501 | box->AddValue(m_tcb->m_ssThresh); 502 | box->AddValue(m_tcb->m_cWnd); 503 | box->AddValue(m_tcb->m_segmentSize); 504 | 505 | //bytesInFlightSum (accumulate in uint64_t: a plain 0 init would make std::accumulate sum in int and overflow) 506 | uint64_t bytesInFlightSum = std::accumulate(m_bytesInFlight.begin(), m_bytesInFlight.end(), uint64_t (0)); 507 | box->AddValue(bytesInFlightSum); 508 | 509 | //bytesInFlightAvg 510 | uint64_t bytesInFlightAvg = 0; 511 | if (m_bytesInFlight.size()) { 512 | bytesInFlightAvg = bytesInFlightSum / m_bytesInFlight.size(); 513 | } 514 | box->AddValue(bytesInFlightAvg); 515 | 516 | //segmentsAckedSum (accumulate in uint64_t for the same reason as bytesInFlightSum) 517 | uint64_t segmentsAckedSum = std::accumulate(m_segmentsAcked.begin(), m_segmentsAcked.end(), uint64_t (0)); 518 | box->AddValue(segmentsAckedSum); 519 | 520 | //segmentsAckedAvg 521 | uint64_t segmentsAckedAvg = 0; 522 | if (m_segmentsAcked.size()) { 523 | segmentsAckedAvg = segmentsAckedSum / m_segmentsAcked.size(); 524 | } 525 | box->AddValue(segmentsAckedAvg); 526 | 527 | //avgRtt 528 | Time avgRtt = Seconds(0.0); 529 | if(m_rttSampleNum) { 530 | avgRtt = m_rttSum / m_rttSampleNum; 531 | } 532 | box->AddValue(avgRtt.GetMicroSeconds ()); 533 | 534 | //m_minRtt 535 | box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ()); 536 | 537 | //avgInterTx 538 | Time avgInterTx = Seconds(0.0); 539 | if (m_interTxTimeNum) { 540 | avgInterTx = m_interTxTimeSum / m_interTxTimeNum; 541 | } 542 | box->AddValue(avgInterTx.GetMicroSeconds ()); 543 | 544 | //avgInterRx 545 | Time avgInterRx = Seconds(0.0); 546 | if (m_interRxTimeNum) { 547 | avgInterRx = m_interRxTimeSum / m_interRxTimeNum; 548 | } 549 | box->AddValue(avgInterRx.GetMicroSeconds ()); 550 | 551 | //throughput bytes/s 552 | float throughput = (segmentsAckedSum *
m_tcb->m_segmentSize) / m_timeStep.GetSeconds(); 553 | box->AddValue(throughput); 554 | 555 | // Print data 556 | NS_LOG_INFO ("MyGetObservation: " << box); 557 | 558 | m_bytesInFlight.clear(); 559 | m_segmentsAcked.clear(); 560 | 561 | m_rttSampleNum = 0; 562 | m_rttSum = MicroSeconds (0.0); 563 | 564 | m_interTxTimeNum = 0; 565 | m_interTxTimeSum = MicroSeconds (0.0); 566 | 567 | m_interRxTimeNum = 0; 568 | m_interRxTimeSum = MicroSeconds (0.0); 569 | 570 | return box; 571 | } 572 | 573 | void 574 | TcpTimeStepGymEnv::TxPktTrace(Ptr, const TcpHeader&, Ptr) 575 | { 576 | NS_LOG_FUNCTION (this); 577 | if ( m_lastPktTxTime > MicroSeconds(0.0) ) { 578 | Time interTxTime = Simulator::Now() - m_lastPktTxTime; 579 | m_interTxTimeSum += interTxTime; 580 | m_interTxTimeNum++; 581 | } 582 | 583 | m_lastPktTxTime = Simulator::Now(); 584 | } 585 | 586 | void 587 | TcpTimeStepGymEnv::RxPktTrace(Ptr, const TcpHeader&, Ptr) 588 | { 589 | NS_LOG_FUNCTION (this); 590 | if ( m_lastPktRxTime > MicroSeconds(0.0) ) { 591 | Time interRxTime = Simulator::Now() - m_lastPktRxTime; 592 | m_interRxTimeSum += interRxTime; 593 | m_interRxTimeNum++; 594 | } 595 | 596 | m_lastPktRxTime = Simulator::Now(); 597 | } 598 | 599 | uint32_t 600 | TcpTimeStepGymEnv::GetSsThresh (Ptr tcb, uint32_t bytesInFlight) 601 | { 602 | NS_LOG_FUNCTION (this); 603 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " GetSsThresh, BytesInFlight: " << bytesInFlight); 604 | m_tcb = tcb; 605 | m_bytesInFlight.push_back(bytesInFlight); 606 | 607 | if (!m_started) { 608 | m_started = true; 609 | Notify(); 610 | ScheduleNextStateRead(); 611 | } 612 | 613 | // action 614 | return m_new_ssThresh; 615 | } 616 | 617 | void 618 | TcpTimeStepGymEnv::IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) 619 | { 620 | NS_LOG_FUNCTION (this); 621 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " IncreaseWindow, SegmentsAcked: " << segmentsAcked); 622 | m_tcb = tcb; 623 | m_segmentsAcked.push_back(segmentsAcked); 624 
| m_bytesInFlight.push_back(tcb->m_bytesInFlight); 625 | 626 | if (!m_started) { 627 | m_started = true; 628 | Notify(); 629 | ScheduleNextStateRead(); 630 | } 631 | // action 632 | tcb->m_cWnd = m_new_cWnd; 633 | } 634 | 635 | void 636 | TcpTimeStepGymEnv::PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) 637 | { 638 | NS_LOG_FUNCTION (this); 639 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " PktsAcked, SegmentsAcked: " << segmentsAcked << " Rtt: " << rtt); 640 | m_tcb = tcb; 641 | m_rttSum += rtt; 642 | m_rttSampleNum++; 643 | } 644 | 645 | void 646 | TcpTimeStepGymEnv::CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState) 647 | { 648 | NS_LOG_FUNCTION (this); 649 | std::string stateName = GetTcpCongStateName(newState); 650 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CongestionStateSet: " << newState << " " << stateName); 651 | m_tcb = tcb; 652 | } 653 | 654 | void 655 | TcpTimeStepGymEnv::CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) 656 | { 657 | NS_LOG_FUNCTION (this); 658 | std::string eventName = GetTcpCAEventName(event); 659 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CwndEvent: " << event << " " << eventName); 660 | m_tcb = tcb; 661 | } 662 | 663 | } // namespace ns3 664 | -------------------------------------------------------------------------------- /tcp-rl-env.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A 
PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | #ifndef TCP_RL_ENV_H 22 | #define TCP_RL_ENV_H 23 | 24 | #include "ns3/opengym-module.h" 25 | #include "ns3/tcp-socket-base.h" 26 | #include 27 | 28 | namespace ns3 { 29 | 30 | class Packet; 31 | class TcpHeader; 32 | class TcpSocketBase; 33 | class Time; 34 | 35 | 36 | class TcpGymEnv : public OpenGymEnv 37 | { 38 | public: 39 | TcpGymEnv (); 40 | virtual ~TcpGymEnv (); 41 | static TypeId GetTypeId (void); 42 | virtual void DoDispose (); 43 | 44 | void SetNodeId(uint32_t id); 45 | void SetSocketUuid(uint32_t id); 46 | 47 | std::string GetTcpCongStateName(const TcpSocketState::TcpCongState_t state); 48 | std::string GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event); 49 | 50 | // OpenGym interface 51 | virtual Ptr GetActionSpace(); 52 | virtual bool GetGameOver(); 53 | virtual float GetReward(); 54 | virtual std::string GetExtraInfo(); 55 | virtual bool ExecuteActions(Ptr action); 56 | 57 | virtual Ptr GetObservationSpace() = 0; 58 | virtual Ptr GetObservation() = 0; 59 | 60 | // trace packets, e.g. 
for calculating inter tx/rx time 61 | virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr) = 0; 62 | virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr) = 0; 63 | 64 | // TCP congestion control interface 65 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight) = 0; 66 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) = 0; 67 | // optional functions used to collect obs 68 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) = 0; 69 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState) = 0; 70 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) = 0; 71 | 72 | typedef enum 73 | { 74 | GET_SS_THRESH = 0, 75 | INCREASE_WINDOW, 76 | PKTS_ACKED, 77 | CONGESTION_STATE_SET, 78 | CWND_EVENT, 79 | } CalledFunc_t; 80 | 81 | protected: 82 | uint32_t m_nodeId; 83 | uint32_t m_socketUuid; 84 | 85 | // state 86 | // obs has to be implemented in child class 87 | 88 | // game over 89 | bool m_isGameOver; 90 | 91 | // reward 92 | float m_envReward; 93 | 94 | // extra info 95 | std::string m_info; 96 | 97 | // actions 98 | uint32_t m_new_ssThresh; 99 | uint32_t m_new_cWnd; 100 | }; 101 | 102 | 103 | class TcpEventGymEnv : public TcpGymEnv 104 | { 105 | public: 106 | TcpEventGymEnv (); 107 | virtual ~TcpEventGymEnv (); 108 | static TypeId GetTypeId (void); 109 | virtual void DoDispose (); 110 | 111 | void SetReward(float value); 112 | void SetPenalty(float value); 113 | 114 | // OpenGym interface 115 | virtual Ptr GetObservationSpace(); 116 | Ptr GetObservation(); 117 | 118 | // trace packets, e.g. 
for calculating inter tx/rx time 119 | virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr); 120 | virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr); 121 | 122 | // TCP congestion control interface 123 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); 124 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); 125 | // optional functions used to collect obs 126 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); 127 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); 128 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); 129 | 130 | private: 131 | // state 132 | CalledFunc_t m_calledFunc; 133 | Ptr m_tcb; 134 | uint32_t m_bytesInFlight; 135 | uint32_t m_segmentsAcked; 136 | Time m_rtt; 137 | TcpSocketState::TcpCongState_t m_newState; 138 | TcpSocketState::TcpCAEvent_t m_event; 139 | 140 | // reward 141 | float m_reward; 142 | float m_penalty; 143 | }; 144 | 145 | 146 | class TcpTimeStepGymEnv : public TcpGymEnv 147 | { 148 | public: 149 | TcpTimeStepGymEnv (); 150 | TcpTimeStepGymEnv (Time timeStep); 151 | virtual ~TcpTimeStepGymEnv (); 152 | static TypeId GetTypeId (void); 153 | virtual void DoDispose (); 154 | 155 | // OpenGym interface 156 | virtual Ptr GetObservationSpace(); 157 | Ptr GetObservation(); 158 | 159 | // trace packets, e.g. 
for calculating inter tx/rx time 160 | virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr); 161 | virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr); 162 | 163 | // TCP congestion control interface 164 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); 165 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); 166 | // optional functions used to collect obs 167 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); 168 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); 169 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); 170 | 171 | private: 172 | void ScheduleNextStateRead(); 173 | bool m_started {false}; 174 | Time m_timeStep; 175 | // state 176 | Ptr m_tcb; 177 | std::vector m_bytesInFlight; 178 | std::vector m_segmentsAcked; 179 | 180 | uint64_t m_rttSampleNum {0}; 181 | Time m_rttSum {MicroSeconds (0.0)}; 182 | 183 | Time m_lastPktTxTime {MicroSeconds(0.0)}; 184 | Time m_lastPktRxTime {MicroSeconds(0.0)}; 185 | uint64_t m_interTxTimeNum {0}; 186 | Time m_interTxTimeSum {MicroSeconds (0.0)}; 187 | uint64_t m_interRxTimeNum {0}; 188 | Time m_interRxTimeSum {MicroSeconds (0.0)}; 189 | 190 | // reward 191 | }; 192 | 193 | 194 | 195 | } // namespace ns3 196 | 197 | #endif /* TCP_RL_ENV_H */ 198 | -------------------------------------------------------------------------------- /tcp-rl.cc: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * 
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | #include "tcp-rl.h" 22 | #include "tcp-rl-env.h" 23 | #include "ns3/tcp-header.h" 24 | #include "ns3/object.h" 25 | #include "ns3/node-list.h" 26 | #include "ns3/core-module.h" 27 | #include "ns3/log.h" 28 | #include "ns3/simulator.h" 29 | #include "ns3/tcp-socket-base.h" 30 | #include "ns3/tcp-l4-protocol.h" 31 | 32 | 33 | namespace ns3 { 34 | 35 | 36 | NS_OBJECT_ENSURE_REGISTERED (TcpSocketDerived); 37 | 38 | TypeId 39 | TcpSocketDerived::GetTypeId (void) 40 | { 41 | static TypeId tid = TypeId ("ns3::TcpSocketDerived") 42 | .SetParent () 43 | .SetGroupName ("Internet") 44 | .AddConstructor () 45 | ; 46 | return tid; 47 | } 48 | 49 | TypeId 50 | TcpSocketDerived::GetInstanceTypeId () const 51 | { 52 | return TcpSocketDerived::GetTypeId (); 53 | } 54 | 55 | TcpSocketDerived::TcpSocketDerived (void) 56 | { 57 | } 58 | 59 | Ptr 60 | TcpSocketDerived::GetCongestionControlAlgorithm () 61 | { 62 | return m_congestionControl; 63 | } 64 | 65 | TcpSocketDerived::~TcpSocketDerived (void) 66 | { 67 | } 68 | 69 | 70 | NS_LOG_COMPONENT_DEFINE ("ns3::TcpRlBase"); 71 | NS_OBJECT_ENSURE_REGISTERED (TcpRlBase); 72 | 73 | TypeId 74 | TcpRlBase::GetTypeId (void) 75 | { 76 | static TypeId tid = TypeId ("ns3::TcpRlBase") 77 | .SetParent () 78 | .SetGroupName ("Internet") 79 | .AddConstructor () 80 | ; 81 | return tid; 82 | } 83 | 84 | TcpRlBase::TcpRlBase (void) 85 | : TcpCongestionOps () 86 | { 87 | NS_LOG_FUNCTION (this); 88 | m_tcpSocket = 0; 89 | m_tcpGymEnv = 0; 90 | } 91 | 92 | TcpRlBase::TcpRlBase (const TcpRlBase& sock) 93 | : TcpCongestionOps (sock) 94 | { 95 | NS_LOG_FUNCTION (this); 96 
| m_tcpSocket = 0; 97 | m_tcpGymEnv = 0; 98 | } 99 | 100 | TcpRlBase::~TcpRlBase (void) 101 | { 102 | m_tcpSocket = 0; 103 | m_tcpGymEnv = 0; 104 | } 105 | 106 | uint64_t 107 | TcpRlBase::GenerateUuid () 108 | { 109 | static uint64_t uuid = 0; 110 | uuid++; 111 | return uuid; 112 | } 113 | 114 | void 115 | TcpRlBase::CreateGymEnv() 116 | { 117 | NS_LOG_FUNCTION (this); 118 | // should never be called, only child classes: TcpRl and TcpRlTimeBased 119 | } 120 | 121 | void 122 | TcpRlBase::ConnectSocketCallbacks() 123 | { 124 | NS_LOG_FUNCTION (this); 125 | 126 | bool foundSocket = false; 127 | for (NodeList::Iterator i = NodeList::Begin (); i != NodeList::End (); ++i) { 128 | Ptr node = *i; 129 | Ptr tcp = node->GetObject (); 130 | 131 | ObjectVectorValue socketVec; 132 | tcp->GetAttribute ("SocketList", socketVec); 133 | NS_LOG_DEBUG("Node: " << node->GetId() << " TCP socket num: " << socketVec.GetN()); 134 | 135 | uint32_t sockNum = socketVec.GetN(); 136 | for (uint32_t j=0; j sockObj = socketVec.Get(j); 138 | Ptr tcpSocket = DynamicCast (sockObj); 139 | NS_LOG_DEBUG("Node: " << node->GetId() << " TCP Socket: " << tcpSocket); 140 | if(!tcpSocket) { continue; } 141 | 142 | Ptr dtcpSocket = StaticCast(tcpSocket); 143 | Ptr ca = dtcpSocket->GetCongestionControlAlgorithm(); 144 | NS_LOG_DEBUG("CA name: " << ca->GetName()); 145 | Ptr rlCa = DynamicCast(ca); 146 | if (rlCa == this) { 147 | NS_LOG_DEBUG("Found TcpRl CA!"); 148 | foundSocket = true; 149 | m_tcpSocket = tcpSocket; 150 | break; 151 | } 152 | } 153 | 154 | if (foundSocket) { 155 | break; 156 | } 157 | } 158 | 159 | NS_ASSERT_MSG(m_tcpSocket, "TCP socket was not found."); 160 | 161 | if(m_tcpSocket) { 162 | NS_LOG_DEBUG("Found TCP Socket: " << m_tcpSocket); 163 | m_tcpSocket->TraceConnectWithoutContext ("Tx", MakeCallback (&TcpGymEnv::TxPktTrace, m_tcpGymEnv)); 164 | m_tcpSocket->TraceConnectWithoutContext ("Rx", MakeCallback (&TcpGymEnv::RxPktTrace, m_tcpGymEnv)); 165 | NS_LOG_DEBUG("Connect socket callbacks 
" << m_tcpSocket->GetNode()->GetId()); 166 | m_tcpGymEnv->SetNodeId(m_tcpSocket->GetNode()->GetId()); 167 | } 168 | } 169 | 170 | std::string 171 | TcpRlBase::GetName () const 172 | { 173 | return "TcpRlBase"; 174 | } 175 | 176 | uint32_t 177 | TcpRlBase::GetSsThresh (Ptr state, 178 | uint32_t bytesInFlight) 179 | { 180 | NS_LOG_FUNCTION (this << state << bytesInFlight); 181 | 182 | if (!m_tcpGymEnv) { 183 | CreateGymEnv(); 184 | } 185 | 186 | uint32_t newSsThresh = 0; 187 | if (m_tcpGymEnv) { 188 | newSsThresh = m_tcpGymEnv->GetSsThresh(state, bytesInFlight); 189 | } 190 | 191 | return newSsThresh; 192 | } 193 | 194 | void 195 | TcpRlBase::IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) 196 | { 197 | NS_LOG_FUNCTION (this << tcb << segmentsAcked); 198 | 199 | if (!m_tcpGymEnv) { 200 | CreateGymEnv(); 201 | } 202 | 203 | if (m_tcpGymEnv) { 204 | m_tcpGymEnv->IncreaseWindow(tcb, segmentsAcked); 205 | } 206 | } 207 | 208 | void 209 | TcpRlBase::PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) 210 | { 211 | NS_LOG_FUNCTION (this); 212 | 213 | if (!m_tcpGymEnv) { 214 | CreateGymEnv(); 215 | } 216 | 217 | if (m_tcpGymEnv) { 218 | m_tcpGymEnv->PktsAcked(tcb, segmentsAcked, rtt); 219 | } 220 | } 221 | 222 | void 223 | TcpRlBase::CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState) 224 | { 225 | NS_LOG_FUNCTION (this); 226 | 227 | if (!m_tcpGymEnv) { 228 | CreateGymEnv(); 229 | } 230 | 231 | if (m_tcpGymEnv) { 232 | m_tcpGymEnv->CongestionStateSet(tcb, newState); 233 | } 234 | } 235 | 236 | void 237 | TcpRlBase::CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) 238 | { 239 | NS_LOG_FUNCTION (this); 240 | 241 | if (!m_tcpGymEnv) { 242 | CreateGymEnv(); 243 | } 244 | 245 | if (m_tcpGymEnv) { 246 | m_tcpGymEnv->CwndEvent(tcb, event); 247 | } 248 | } 249 | 250 | Ptr 251 | TcpRlBase::Fork () 252 | { 253 | return CopyObject (this); 254 | } 255 | 256 | 257 | NS_OBJECT_ENSURE_REGISTERED (TcpRl); 258 | 259 | TypeId 260 | 
TcpRl::GetTypeId (void) 261 | { 262 | static TypeId tid = TypeId ("ns3::TcpRl") 263 | .SetParent () 264 | .SetGroupName ("Internet") 265 | .AddConstructor () 266 | .AddAttribute ("Reward", "Reward when increasing congestion window.", 267 | DoubleValue (4.0), 268 | MakeDoubleAccessor (&TcpRl::m_reward), 269 | MakeDoubleChecker ()) 270 | .AddAttribute ("Penalty", "Reward when increasing congestion window.", 271 | DoubleValue (-10.0), 272 | MakeDoubleAccessor (&TcpRl::m_penalty), 273 | MakeDoubleChecker ()) 274 | ; 275 | return tid; 276 | } 277 | 278 | TcpRl::TcpRl (void) 279 | : TcpRlBase () 280 | { 281 | NS_LOG_FUNCTION (this); 282 | } 283 | 284 | TcpRl::TcpRl (const TcpRl& sock) 285 | : TcpRlBase (sock) 286 | { 287 | NS_LOG_FUNCTION (this); 288 | } 289 | 290 | TcpRl::~TcpRl (void) 291 | { 292 | } 293 | 294 | std::string 295 | TcpRl::GetName () const 296 | { 297 | return "TcpRl"; 298 | } 299 | 300 | void 301 | TcpRl::CreateGymEnv() 302 | { 303 | NS_LOG_FUNCTION (this); 304 | Ptr env = CreateObject(); 305 | env->SetSocketUuid(TcpRlBase::GenerateUuid()); 306 | env->SetReward(m_reward); 307 | env->SetPenalty(m_penalty); 308 | m_tcpGymEnv = env; 309 | 310 | ConnectSocketCallbacks(); 311 | } 312 | 313 | 314 | NS_OBJECT_ENSURE_REGISTERED (TcpRlTimeBased); 315 | 316 | TypeId 317 | TcpRlTimeBased::GetTypeId (void) 318 | { 319 | static TypeId tid = TypeId ("ns3::TcpRlTimeBased") 320 | .SetParent () 321 | .SetGroupName ("Internet") 322 | .AddConstructor () 323 | .AddAttribute ("StepTime", 324 | "Step interval used in TCP env. 
Default: 100ms", 325 | TimeValue (MilliSeconds (100)), 326 | MakeTimeAccessor (&TcpRlTimeBased::m_timeStep), 327 | MakeTimeChecker ()) 328 | ; 329 | return tid; 330 | } 331 | 332 | TcpRlTimeBased::TcpRlTimeBased (void) : TcpRlBase () 333 | { 334 | NS_LOG_FUNCTION (this); 335 | } 336 | 337 | TcpRlTimeBased::TcpRlTimeBased (const TcpRlTimeBased& sock) 338 | : TcpRlBase (sock) 339 | { 340 | NS_LOG_FUNCTION (this); 341 | } 342 | 343 | TcpRlTimeBased::~TcpRlTimeBased (void) 344 | { 345 | } 346 | 347 | std::string 348 | TcpRlTimeBased::GetName () const 349 | { 350 | return "TcpRlTimeBased"; 351 | } 352 | 353 | void 354 | TcpRlTimeBased::CreateGymEnv() 355 | { 356 | NS_LOG_FUNCTION (this); 357 | Ptr env = CreateObject (m_timeStep); 358 | env->SetSocketUuid(TcpRlBase::GenerateUuid()); 359 | m_tcpGymEnv = env; 360 | 361 | ConnectSocketCallbacks(); 362 | } 363 | 364 | } // namespace ns3 365 | -------------------------------------------------------------------------------- /tcp-rl.h: -------------------------------------------------------------------------------- 1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */ 2 | /* 3 | * Copyright (c) 2018 Technische Universität Berlin 4 | * 5 | * This program is free software; you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License version 2 as 7 | * published by the Free Software Foundation; 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 
13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, write to the Free Software 16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 17 | * 18 | * Author: Piotr Gawlowicz 19 | */ 20 | 21 | #ifndef TCP_RL_H 22 | #define TCP_RL_H 23 | 24 | #include "ns3/tcp-congestion-ops.h" 25 | #include "ns3/opengym-module.h" 26 | #include "ns3/tcp-socket-base.h" 27 | 28 | namespace ns3 { 29 | 30 | class TcpSocketBase; 31 | class Time; 32 | class TcpGymEnv; 33 | 34 | 35 | // used to get pointer to Congestion Algorithm 36 | class TcpSocketDerived : public TcpSocketBase 37 | { 38 | public: 39 | static TypeId GetTypeId (void); 40 | virtual TypeId GetInstanceTypeId () const; 41 | 42 | TcpSocketDerived (void); 43 | virtual ~TcpSocketDerived (void); 44 | 45 | Ptr GetCongestionControlAlgorithm (); 46 | }; 47 | 48 | 49 | class TcpRlBase : public TcpCongestionOps 50 | { 51 | public: 52 | /** 53 | * \brief Get the type ID. 54 | * \return the object TypeId 55 | */ 56 | static TypeId GetTypeId (void); 57 | 58 | TcpRlBase (); 59 | 60 | /** 61 | * \brief Copy constructor. 62 | * \param sock object to copy. 
63 | */ 64 | TcpRlBase (const TcpRlBase& sock); 65 | 66 | ~TcpRlBase (); 67 | 68 | virtual std::string GetName () const; 69 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight); 70 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked); 71 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt); 72 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState); 73 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event); 74 | virtual Ptr Fork (); 75 | 76 | protected: 77 | static uint64_t GenerateUuid (); 78 | virtual void CreateGymEnv(); 79 | void ConnectSocketCallbacks(); 80 | 81 | // OpenGymEnv interface 82 | Ptr m_tcpSocket; 83 | Ptr m_tcpGymEnv; 84 | }; 85 | 86 | 87 | class TcpRl : public TcpRlBase 88 | { 89 | public: 90 | static TypeId GetTypeId (void); 91 | 92 | TcpRl (); 93 | TcpRl (const TcpRl& sock); 94 | ~TcpRl (); 95 | 96 | virtual std::string GetName () const; 97 | private: 98 | virtual void CreateGymEnv(); 99 | // OpenGymEnv env 100 | float m_reward {1.0}; 101 | float m_penalty {-100.0}; 102 | }; 103 | 104 | 105 | class TcpRlTimeBased : public TcpRlBase 106 | { 107 | public: 108 | static TypeId GetTypeId (void); 109 | 110 | TcpRlTimeBased (); 111 | TcpRlTimeBased (const TcpRlTimeBased& sock); 112 | ~TcpRlTimeBased (); 113 | 114 | virtual std::string GetName () const; 115 | 116 | private: 117 | virtual void CreateGymEnv(); 118 | // OpenGymEnv env 119 | Time m_timeStep {MilliSeconds (100)}; 120 | }; 121 | 122 | } // namespace ns3 123 | 124 | #endif /* TCP_RL_H */ -------------------------------------------------------------------------------- /tcp_base.py: -------------------------------------------------------------------------------- 1 | __author__ = "Piotr Gawlowicz" 2 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin" 3 | __version__ = "0.1.0" 4 | __email__ = "gawlowicz@tkn.tu-berlin.de" 5 | 6 | 7 | class Tcp(object): 8 | """docstring for 
Tcp""" 9 | def __init__(self): 10 | super(Tcp, self).__init__() 11 | 12 | def set_spaces(self, obs, act): 13 | self.obsSpace = obs 14 | self.actSpace = act 15 | 16 | def get_action(self, obs, reward, done, info): 17 | pass 18 | 19 | 20 | class TcpEventBased(Tcp): 21 | """docstring for TcpEventBased""" 22 | def __init__(self): 23 | super(TcpEventBased, self).__init__() 24 | 25 | def get_action(self, obs, reward, done, info): 26 | # unique socket ID 27 | socketUuid = obs[0] 28 | # TCP env type: event-based = 0 / time-based = 1 29 | envType = obs[1] 30 | # sim time in us 31 | simTime_us = obs[2] 32 | # unique node ID 33 | nodeId = obs[3] 34 | # current ssThreshold 35 | ssThresh = obs[4] 36 | # current contention window size 37 | cWnd = obs[5] 38 | # segment size 39 | segmentSize = obs[6] 40 | # number of acked segments 41 | segmentsAcked = obs[7] 42 | # estimated bytes in flight 43 | bytesInFlight = obs[8] 44 | # last estimation of RTT 45 | lastRtt_us = obs[9] 46 | # min value of RTT 47 | minRtt_us = obs[10] 48 | # function from Congestion Algorithm (CA) interface: 49 | # GET_SS_THRESH = 0 (packet loss), 50 | # INCREASE_WINDOW (packet acked), 51 | # PKTS_ACKED (unused), 52 | # CONGESTION_STATE_SET (unused), 53 | # CWND_EVENT (unused), 54 | calledFunc = obs[11] 55 | # Congetsion Algorithm (CA) state: 56 | # CA_OPEN = 0, 57 | # CA_DISORDER, 58 | # CA_CWR, 59 | # CA_RECOVERY, 60 | # CA_LOSS, 61 | # CA_LAST_STATE 62 | caState = obs[12] 63 | # Congetsion Algorithm (CA) event: 64 | # CA_EVENT_TX_START = 0, 65 | # CA_EVENT_CWND_RESTART, 66 | # CA_EVENT_COMPLETE_CWR, 67 | # CA_EVENT_LOSS, 68 | # CA_EVENT_ECN_NO_CE, 69 | # CA_EVENT_ECN_IS_CE, 70 | # CA_EVENT_DELAYED_ACK, 71 | # CA_EVENT_NON_DELAYED_ACK, 72 | caEvent = obs[13] 73 | # ECN state: 74 | # ECN_DISABLED = 0, 75 | # ECN_IDLE, 76 | # ECN_CE_RCVD, 77 | # ECN_SENDING_ECE, 78 | # ECN_ECE_RCVD, 79 | # ECN_CWR_SENT 80 | ecnState = obs[14] 81 | 82 | # compute new values 83 | new_cWnd = 10 * segmentSize 84 | new_ssThresh = 5 
* segmentSize 85 | 86 | # return actions 87 | # actions = [new_ssThresh, new_cWnd] 88 | 89 | return actions 90 | 91 | 92 | class TcpTimeBased(Tcp): 93 | """docstring for TcpTimeBased""" 94 | def __init__(self): 95 | super(TcpTimeBased, self).__init__() 96 | 97 | def get_action(self, obs, reward, done, info): 98 | # unique socket ID 99 | socketUuid = obs[0] 100 | # TCP env type: event-based = 0 / time-based = 1 101 | envType = obs[1] 102 | # sim time in us 103 | simTime_us = obs[2] 104 | # unique node ID 105 | nodeId = obs[3] 106 | # current ssThreshold 107 | ssThresh = obs[4] 108 | # current contention window size 109 | cWnd = obs[5] 110 | # segment size 111 | segmentSize = obs[6] 112 | # bytesInFlightSum 113 | bytesInFlightSum = obs[7] 114 | # bytesInFlightAvg 115 | bytesInFlightAvg = obs[8] 116 | # segmentsAckedSum 117 | segmentsAckedSum = obs[9] 118 | # segmentsAckedAvg 119 | segmentsAckedAvg = obs[10] 120 | # avgRtt 121 | avgRtt = obs[11] 122 | # minRtt 123 | minRtt = obs[12] 124 | # avgInterTx 125 | avgInterTx = obs[13] 126 | # avgInterRx 127 | avgInterRx = obs[14] 128 | # throughput 129 | throughput = obs[15] 130 | 131 | # compute new values 132 | new_cWnd = 10 * segmentSize 133 | # new_cWnd2 = cWnd - 10 134 | new_ssThresh = 5 * segmentSize 135 | 136 | # return actions 137 | actions = [new_cWnd2, new_ssThresh,] 138 | 139 | return actions 140 | --------------------------------------------------------------------------------