├── Agent_test.py
├── README.md
├── rl.png
├── sim.cc
├── tcp-rl-env.cc
├── tcp-rl-env.h
├── tcp-rl.cc
├── tcp-rl.h
└── tcp_base.py
/Agent_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import gym
4 | import tensorflow as tf
5 | import tensorflow.contrib.slim as slim
6 | import numpy as np
7 | import matplotlib as mpl
8 | import matplotlib.pyplot as plt
9 | from tensorflow import keras
10 | from ns3gym import ns3env
11 |
12 |
13 | import argparse
14 | from ns3gym import ns3env
15 | from tcp_base import TcpTimeBased
16 | from tcp_newreno import TcpNewReno
17 |
18 | __author__ = "Kenan and Sharif, Modified the code by Piotr Gawlowicz"
19 | __copyright__ = "Copyright (c) 2020, Technische Universität Berlin"
20 | __version__ = "0.1.0"
21 | __email__ = "gawlowicz@tkn.tu-berlin.de"
22 |
23 | parser = argparse.ArgumentParser(description='Start simulation script on/off')  # command-line options for this agent script
24 | parser.add_argument('--start',
25 | type=int,
26 | default=1,
27 | help='Start ns-3 simulation script 0/1, Default: 1')
28 | parser.add_argument('--iterations',
29 | type=int,
30 | default=1,
31 | help='Number of iterations, Default: 1')
32 | args = parser.parse_args()
33 | startSim = bool(args.start)  # forwarded to Ns3Env below; with --start=0 the ns-3 script is expected to be launched separately
34 | iterationNum = int(args.iterations)  # NOTE(review): parsed but never read in the visible script; the loop uses total_episodes instead
35 |
36 | port = 5555  # port handed to Ns3Env; presumably must match the port the ns-3 side opens -- confirm
37 | simTime = 10 # seconds
38 | stepTime = 0.5 # seconds
39 | seed = 12  # passed to Ns3Env as simSeed
40 | simArgs = {"--duration": simTime,}  # extra arguments for the ns-3 script (presumably passed on its command line -- confirm in ns3-gym)
41 | debug = False
42 |
43 | env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim, simSeed=seed, simArgs=simArgs, debug=debug)
44 | # simpler:
45 | #env = ns3env.Ns3Env()
46 | env.reset()  # initial reset; the returned observation is discarded (every episode resets again in the training loop)
47 |
48 | ob_space = env.observation_space
49 | ac_space = env.action_space
50 |
51 | print("Observation space: ", ob_space, ob_space.dtype)
52 | print("Action space: ", ac_space, ac_space.dtype)
53 |
54 |
55 | def get_agent(obs):  # Return the per-socket TCP agent for this observation, creating and caching it on first use.
56 | socketUuid = obs[0]  # observation entry 0: socket unique ID (cache key)
57 | tcpEnvType = obs[1]  # observation entry 1: 0 = event-based env, otherwise time-based
58 | tcpAgent = get_agent.tcpAgents.get(socketUuid, None)  # cache lives in a function attribute initialised right after this def
59 | if tcpAgent is None:
60 | if tcpEnvType == 0:
61 | # event-based = 0
62 | tcpAgent = TcpNewReno()
63 | else:
64 | # time-based = 1
65 | tcpAgent = TcpTimeBased()
66 | tcpAgent.set_spaces(get_agent.ob_space, get_agent.ac_space)
67 | get_agent.tcpAgents[socketUuid] = tcpAgent
68 |
69 | return tcpAgent  # NOTE(review): get_agent is not called anywhere in the visible part of this script
70 |
71 | # initialize the function attributes that get_agent() reads (agent cache and the env spaces)
72 | get_agent.tcpAgents = {}
73 | get_agent.ob_space = ob_space
74 | get_agent.ac_space = ac_space
75 |
76 | s_size = ob_space.shape[0]  # number of observation entries; used as network input width and hidden-layer width
77 | print("State size: ",ob_space.shape[0])
78 |
79 | a_size = 3  # number of discrete actions; must equal the number of entries in action_mapping below
80 | print("Action size: ", a_size)
81 |
82 | model = keras.Sequential()  # small fully connected network: s_size -> s_size -> s_size -> a_size
83 | model.add(keras.layers.Dense(s_size, input_shape=(s_size,), activation='relu'))
84 | model.add(keras.layers.Dense(s_size, input_shape=(s_size,), activation='relu'))  # NOTE(review): input_shape is presumably only meaningful on the first layer -- confirm
85 | model.add(keras.layers.Dense(a_size, activation='softmax'))  # NOTE(review): softmax outputs are later used as Q-value targets in the training loop
86 | model.compile(optimizer=tf.train.AdamOptimizer(0.001),
87 | loss='categorical_crossentropy',
88 | metrics=['accuracy'])  # tf.train.AdamOptimizer is the TensorFlow 1.x optimizer API
89 |
90 | total_episodes = 6
91 | max_env_steps = 100  # upper bound on steps per episode (inner loop range)
92 | env._max_episode_steps = max_env_steps
93 |
94 | epsilon = 1.0 # exploration rate
95 | epsilon_min = 0.01  # epsilon stops decaying once it is no longer above this value
96 | epsilon_decay = 0.999  # multiplicative decay applied in the training loop
97 |
98 | time_history = []
99 | rew_history = []  # the only series plotted at the end of the script
100 | cWnd_history=[]
101 | cWnd_history2=[]  # not used in the visible script
102 | Rtt=[]
103 | segAkc=[]
104 | No_step = 0  # global step counter across all episodes
105 | t2 =10  # not used in the visible script
106 | t =[]  # not used in the visible script
107 | reward = 0
108 | done = False
109 | info = None
110 | action_mapping = {}  # action index -> change applied to cWnd
111 | action_mapping[0] = 0  # keep cWnd
112 | action_mapping[1] = 600  # increase cWnd
113 | action_mapping[2] = -60  # decrease cWnd
114 | U_new =0  # utility of the current observation
115 | U =0  # utility difference U_new - U_old
116 | U_old=0  # utility of the previous observation
117 | reward=0  # duplicate of the initialisation a few lines above
118 | for e in range(total_episodes):  # epsilon-greedy training loop; NOTE(review): indentation is lost in this listing, so the nesting of the trailing statements is uncertain
119 | obs = env.reset()
120 | cWnd = obs[5]  # entry 5 is cWnd in both observation layouts built in tcp-rl-env.cc
121 | obs = np.reshape(obs, [1, s_size])  # batch of one for model.predict / model.fit
122 | rewardsum = 0
123 | for time in range(max_env_steps):
124 | # Choose action
125 | if np.random.rand(1) < epsilon:  # explore: uniformly random action index
126 | action_index = np.random.randint(3)
127 | print (action_index)
128 | print("Value Initialization ...")
129 | else:  # exploit: action with the largest network output
130 | action_index = np.argmax(model.predict(obs)[0])
131 | print(action_index)
132 | new_cWnd = cWnd + action_mapping[action_index]  # NOTE(review): not clamped; the action space in GetActionSpace() is declared as [0, 65535]
133 | new_ssThresh = np.int(cWnd/2)  # NOTE(review): np.int was removed in NumPy 1.24; the builtin int() is the replacement
134 | actions = [new_ssThresh, new_cWnd]  # order matches ExecuteActions(): element 0 = ssThresh, element 1 = cWnd
135 | U_new=0.7*(np.log(obs[0,2]))-0.7*(np.log(obs[0,9] ))  # entry 2 is sim time in us; entry 9 is rtt (event env) or segmentsAckedSum (time-step env); log(0) yields -inf
136 | U=U_new-U_old
137 | if U <-0.05:
138 | reward=-5
139 | elif U >0.05:
140 | reward=1
141 | else:
142 | reward=0
143 | # Step
144 | next_state, reward, done, info = env.step(actions)  # NOTE(review): overwrites the utility-based reward computed just above, which is therefore never used
145 | cWnd = next_state[5]
146 | print("cWnd:",cWnd)
147 | if done:  # terminal step: report and leave the episode without training on it
148 | print("episode: {}/{}, time: {}, rew: {}, eps: {:.2}"
149 | .format(e, total_episodes, time, rewardsum, epsilon))
150 | break
151 | U_old=0.7*(np.log(obs[0,2]))-0.7*(np.log(obs[0,9] ))  # same expression on the pre-step observation; becomes the baseline for the next iteration
152 | next_state = np.reshape(next_state, [1, s_size])
153 | # Train
154 | target = reward
155 | if not done:  # always true here, because the done case has already hit break above
156 | target = (reward + 0.95 * np.amax(model.predict(next_state)[0]))  # one-step bootstrap target with discount 0.95
157 | target_f = model.predict(obs)
158 | print("target :", target_f)
159 | target_f[0][action_index] = target  # only the taken action's output is moved towards the target
160 | model.fit(obs, target_f, epochs=1, verbose=0)
161 |
162 | obs = next_state
163 | seg=obs[0,5]  # NOTE(review): entry 5 is cWnd, not a segment count, despite the name
164 | rtt=obs[0,9]  # NOTE(review): entry 9 is rtt only for the event-based env; for the time-step env it is segmentsAckedSum
165 | rewardsum += reward
166 | if epsilon > epsilon_min: epsilon *= epsilon_decay
167 | No_step += 1
168 |
169 | print("number of steps :", No_step)
170 | print("espsilon :",epsilon)
171 |
172 |
173 | print("reward sum", rewardsum)
174 | segAkc.append(seg)  # NOTE(review): seg/rtt are only assigned after a non-terminal step; if this runs before any such step it raises NameError
175 | Rtt.append(rtt)
176 |
177 | cWnd_history.append(cWnd)
178 | time_history.append(time)
179 | rew_history.append(rewardsum)
180 |
181 |
182 |
183 | print("Plot Learning Performance")  # plot the collected reward sums and save the figure to learning.pdf
184 | mpl.rcdefaults()
185 | mpl.rcParams.update({'font.size': 16})
186 | fig, ax = plt.subplots(figsize=(10,4))
187 | plt.grid(True, linestyle='--')
188 | plt.title('Learning Performance')
189 | #plt.plot(range(len(rew_history)), rew_history, label='Reward', marker="^", linestyle=":")#, color='red')
190 | plt.plot(range(len(rew_history)), rew_history, label='Reward', marker="", linestyle="-")#, color='k')
191 |
192 | #plt.plot(range(len(segAkc)), segAkc, label='segAkc', marker="", linestyle="-"),# color='b')
193 | #plt.plot(range(len(Rtt)),Rtt, label='Rtt', marker="", linestyle="-")#, color='y')
194 | plt.xlabel('Episode')  # x axis is the index into rew_history
195 | plt.ylabel('Steps')  # NOTE(review): the plotted series is the reward sum, so this label looks wrong
196 | plt.legend(prop={'size': 12})
197 | plt.savefig('learning.pdf', bbox_inches='tight')
198 | plt.show()
199 |
200 |
201 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RL-TCP
2 | For years, TCP has been explored and discussed as an end-to-end congestion control mechanism, and various optimization efforts have tried to find an approach that copes with all kinds of network architectures and scenarios. However, each TCP variant only follows the rules by which it is defined. This rule-based approach takes no account of previously gathered information about the links each time a flow begins, whereas better performance can be achieved if TCP adapts its behavior based on what it learned when the same path was used before. The rise of machine learning and deep learning has paved the way for many researchers in academia: these approaches learn from history to find the best methods to get the job done.
3 | This project implements Reinforcement Learning (RL) based TCP congestion control using a Python agent that applies RL techniques to help overcome the shortcomings of rule-based TCP.
4 |
5 | ## RL Tools
6 | The following RL tools are used for the implementation of the project.
7 |
8 | ### NS-3
9 | ns-3 is a discrete-event network simulator that is primarily used for research and education. It has become a standard in networking research in recent years, as its results are accepted by the scientific community. In this project, we use ns-3 to generate traffic and to simulate our environment. It can be found at https://www.nsnam.org
10 |
11 | ### Ns-3 gym
12 | ns3-gym is a middleware that interconnects the ns-3 network simulator and the OpenAI Gym framework. It works as a proxy between the ns-3 environment, which is written in C++, and the reinforcement learning agent, which is written in Python. It can be installed from https://github.com/tkn-tub/ns3-gym
13 |
14 | ## Agent
15 | Our agent interacts with the ns-3 network environment and keeps exploring for the optimal policy by taking actions such as increasing or decreasing the cWnd size. The environment setup is a combination of the ns-3 network simulator, the ns3-gym proxy, and a Python-based agent. Our mapping approach is as follows:
16 |
17 | * State: The state space is the profile of our networking simulation that is provided by the ns-3 simulator.
18 | * Action: The actions increase, decrease, or leave unchanged the cWnd size, by +600, -60, and 0 respectively (the values used in `Agent_test.py`).
19 | * Reward: The reward is a utility function that considers the RTT and the number of acknowledged segments: +2 in case of increasing cWnd without overwhelming the receiver side, otherwise -30.
20 |
21 | 
22 | As the result indicates, our agent adaptively adjusts the cWnd within a constant range of values, according to the objective the agent has learned.
23 |
24 | To run the agent, first install ns-3 and ns3-gym and build the project, then execute:
25 | ```
26 | ./Agent_test.py
27 | ```
28 | Or in two Terminals:
29 | ```
30 | # Terminal 1:
31 | ./waf --run "rl-tcp --transport_prot=TcpRlTimeBased"
32 |
33 | # Terminal 2:
34 | ./Agent_test.py --start=0
35 |
36 | ```
37 |
--------------------------------------------------------------------------------
/rl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharif-Rezaie/RL-TCP/9dc1e3893f11e3e67d1fae7f59d975548cce9126/rl.png
--------------------------------------------------------------------------------
/sim.cc:
--------------------------------------------------------------------------------
1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
2 | /*
3 | * Copyright (c) 2018 Piotr Gawlowicz
4 | *
5 | * This program is free software; you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License version 2 as
7 | * published by the Free Software Foundation;
8 | *
9 | * This program is distributed in the hope that it will be useful,
10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | * GNU General Public License for more details.
13 | *
14 | * You should have received a copy of the GNU General Public License
15 | * along with this program; if not, write to the Free Software
16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 | *
18 | * Author: Piotr Gawlowicz
19 | * Based on script: ./examples/tcp/tcp-variants-comparison.cc
20 | *
21 | * Topology:
22 | *
23 | * Right Leafs (Clients) Left Leafs (Sinks)
24 | * | \ / |
25 | * | \ bottleneck / |
26 | * | R0--------------R1 |
27 | * | / \ |
28 | * | access / \ access |
29 | * N ----------- --------N
30 | */
31 |
32 | #include
33 | #include
34 | #include
35 |
36 | #include "ns3/core-module.h"
37 | #include "ns3/network-module.h"
38 | #include "ns3/internet-module.h"
39 | #include "ns3/point-to-point-module.h"
40 | #include "ns3/point-to-point-layout-module.h"
41 | #include "ns3/applications-module.h"
42 | #include "ns3/error-model.h"
43 | #include "ns3/tcp-header.h"
44 | #include "ns3/enum.h"
45 | #include "ns3/event-id.h"
46 | #include "ns3/flow-monitor-helper.h"
47 | #include "ns3/ipv4-global-routing-helper.h"
48 | #include "ns3/traffic-control-module.h"
49 |
50 | #include "ns3/opengym-module.h"
51 | #include "tcp-rl.h"
52 |
53 | using namespace ns3;
54 |
55 | NS_LOG_COMPONENT_DEFINE ("TcpVariantsComparison");
56 |
57 | static std::vector rxPkts;
58 |
59 | static void  // Rx trace sink: count one received packet for the given sink
60 | CountRxPkts(uint32_t sinkId, Ptr packet, const Address & srcAddr)  // packet and srcAddr are not used
61 | {
62 | rxPkts[sinkId]++;  // no bounds check: rxPkts must already have an entry for sinkId
63 | }
64 |
65 | static void
66 | PrintRxCount()
67 | {
68 | uint32_t size = rxPkts.size();
69 | NS_LOG_UNCOND("RxPkts:");
70 | for (uint32_t i=0; i openGymInterface;
144 | if (transport_prot.compare ("ns3::TcpRl") == 0)
145 | {
146 | openGymInterface = OpenGymInterface::Get(openGymPort);
147 | Config::SetDefault ("ns3::TcpRl::Reward", DoubleValue (2.0)); // Reward when increasing congestion window
148 | Config::SetDefault ("ns3::TcpRl::Penalty", DoubleValue (-30.0)); // Penalty when decreasing congestion window
149 | }
150 |
151 | if (transport_prot.compare ("ns3::TcpRlTimeBased") == 0)
152 | {
153 | openGymInterface = OpenGymInterface::Get(openGymPort);
154 | Config::SetDefault ("ns3::TcpRlTimeBased::StepTime", TimeValue (Seconds(tcpEnvTimeStep))); // Time step of TCP env
155 | }
156 |
157 | // Calculate the ADU size
158 | Header* temp_header = new Ipv4Header ();
159 | uint32_t ip_header = temp_header->GetSerializedSize ();
160 | NS_LOG_LOGIC ("IP Header size is: " << ip_header);
161 | delete temp_header;
162 | temp_header = new TcpHeader ();
163 | uint32_t tcp_header = temp_header->GetSerializedSize ();
164 | NS_LOG_LOGIC ("TCP Header size is: " << tcp_header);
165 | delete temp_header;
166 | uint32_t tcp_adu_size = mtu_bytes - 20 - (ip_header + tcp_header);
167 | NS_LOG_LOGIC ("TCP ADU size is: " << tcp_adu_size);
168 |
169 | // Set the simulation start and stop time
170 | double start_time = 0.1;
171 | double stop_time = start_time + duration;
172 |
173 | // 2 MiB of TCP send/receive buffer (1 << 21 bytes)
174 | Config::SetDefault ("ns3::TcpSocket::RcvBufSize", UintegerValue (1 << 21));
175 | Config::SetDefault ("ns3::TcpSocket::SndBufSize", UintegerValue (1 << 21));
176 | Config::SetDefault ("ns3::TcpSocketBase::Sack", BooleanValue (sack));
177 | Config::SetDefault ("ns3::TcpSocket::DelAckCount", UintegerValue (2));
178 |
179 |
180 | Config::SetDefault ("ns3::TcpL4Protocol::RecoveryType",
181 | TypeIdValue (TypeId::LookupByName (recovery)));
182 | // Select TCP variant
183 | if (transport_prot.compare ("ns3::TcpWestwoodPlus") == 0)
184 | {
185 | // TcpWestwoodPlus is not an actual TypeId name; we need TcpWestwood here
186 | Config::SetDefault ("ns3::TcpL4Protocol::SocketType", TypeIdValue (TcpWestwood::GetTypeId ()));
187 | // the default protocol type in ns3::TcpWestwood is WESTWOOD
188 | Config::SetDefault ("ns3::TcpWestwood::ProtocolType", EnumValue (TcpWestwood::WESTWOODPLUS));
189 | }
190 | else
191 | {
192 | TypeId tcpTid;
193 | NS_ABORT_MSG_UNLESS (TypeId::LookupByNameFailSafe (transport_prot, &tcpTid), "TypeId " << transport_prot << " not found");
194 | Config::SetDefault ("ns3::TcpL4Protocol::SocketType", TypeIdValue (TypeId::LookupByName (transport_prot)));
195 | }
196 |
197 | // Configure the error model
198 | // Here we use RateErrorModel with packet error rate
199 | Ptr uv = CreateObject ();
200 | uv->SetStream (50);
201 | RateErrorModel error_model;
202 | error_model.SetRandomVariable (uv);
203 | error_model.SetUnit (RateErrorModel::ERROR_UNIT_PACKET);
204 | error_model.SetRate (error_p);
205 |
206 | // Create the point-to-point link helpers
207 | PointToPointHelper bottleNeckLink;
208 | bottleNeckLink.SetDeviceAttribute ("DataRate", StringValue (bottleneck_bandwidth));
209 | bottleNeckLink.SetChannelAttribute ("Delay", StringValue (bottleneck_delay));
210 | //bottleNeckLink.SetDeviceAttribute ("ReceiveErrorModel", PointerValue (&error_model));
211 |
212 | PointToPointHelper pointToPointLeaf;
213 | pointToPointLeaf.SetDeviceAttribute ("DataRate", StringValue (access_bandwidth));
214 | pointToPointLeaf.SetChannelAttribute ("Delay", StringValue (access_delay));
215 |
216 | PointToPointDumbbellHelper d (nLeaf, pointToPointLeaf,
217 | nLeaf, pointToPointLeaf,
218 | bottleNeckLink);
219 |
220 | // Install IP stack
221 | InternetStackHelper stack;
222 | stack.InstallAll ();
223 |
224 | // Traffic Control
225 | TrafficControlHelper tchPfifo;
226 | tchPfifo.SetRootQueueDisc ("ns3::PfifoFastQueueDisc");
227 |
228 | TrafficControlHelper tchCoDel;
229 | tchCoDel.SetRootQueueDisc ("ns3::CoDelQueueDisc");
230 |
231 | DataRate access_b (access_bandwidth);
232 | DataRate bottle_b (bottleneck_bandwidth);
233 | Time access_d (access_delay);
234 | Time bottle_d (bottleneck_delay);
235 |
236 | uint32_t size = static_cast((std::min (access_b, bottle_b).GetBitRate () / 8) *
237 | ((access_d + bottle_d + access_d) * 2).GetSeconds ());
238 |
239 | Config::SetDefault ("ns3::PfifoFastQueueDisc::MaxSize",
240 | QueueSizeValue (QueueSize (QueueSizeUnit::PACKETS, size / mtu_bytes)));
241 | Config::SetDefault ("ns3::CoDelQueueDisc::MaxSize",
242 | QueueSizeValue (QueueSize (QueueSizeUnit::BYTES, size)));
243 |
244 | if (queue_disc_type.compare ("ns3::PfifoFastQueueDisc") == 0)
245 | {
246 | tchPfifo.Install (d.GetLeft()->GetDevice(1));
247 | tchPfifo.Install (d.GetRight()->GetDevice(1));
248 | }
249 | else if (queue_disc_type.compare ("ns3::CoDelQueueDisc") == 0)
250 | {
251 | tchCoDel.Install (d.GetLeft()->GetDevice(1));
252 | tchCoDel.Install (d.GetRight()->GetDevice(1));
253 | }
254 | else
255 | {
256 | NS_FATAL_ERROR ("Queue not recognized. Allowed values are ns3::CoDelQueueDisc or ns3::PfifoFastQueueDisc");
257 | }
258 |
259 | // Assign IP Addresses
260 | d.AssignIpv4Addresses (Ipv4AddressHelper ("10.1.1.0", "255.255.255.0"),
261 | Ipv4AddressHelper ("10.2.1.0", "255.255.255.0"),
262 | Ipv4AddressHelper ("10.3.1.0", "255.255.255.0"));
263 |
264 |
265 | NS_LOG_INFO ("Initialize Global Routing.");
266 | Ipv4GlobalRoutingHelper::PopulateRoutingTables ();
267 |
268 | // Install apps in left and right nodes
269 | uint16_t port = 50000;
270 | Address sinkLocalAddress (InetSocketAddress (Ipv4Address::GetAny (), port));
271 | PacketSinkHelper sinkHelper ("ns3::TcpSocketFactory", sinkLocalAddress);
272 | ApplicationContainer sinkApps;
273 | for (uint32_t i = 0; i < d.RightCount (); ++i)
274 | {
275 | sinkHelper.SetAttribute ("Protocol", TypeIdValue (TcpSocketFactory::GetTypeId ()));
276 | sinkApps.Add (sinkHelper.Install (d.GetRight (i)));
277 | }
278 | sinkApps.Start (Seconds (0.0));
279 | sinkApps.Stop (Seconds (stop_time));
280 |
281 | for (uint32_t i = 0; i < d.LeftCount (); ++i)
282 | {
283 | // Create a bulk-send app on left leaf i that sends to right leaf i
284 | AddressValue remoteAddress (InetSocketAddress (d.GetRightIpv4Address (i), port));
285 | Config::SetDefault ("ns3::TcpSocket::SegmentSize", UintegerValue (tcp_adu_size));
286 | BulkSendHelper ftp ("ns3::TcpSocketFactory", Address ());
287 | ftp.SetAttribute ("Remote", remoteAddress);
288 | ftp.SetAttribute ("SendSize", UintegerValue (tcp_adu_size));
289 | ftp.SetAttribute ("MaxBytes", UintegerValue (data_mbytes * 1000000));
290 |
291 | ApplicationContainer clientApp = ftp.Install (d.GetLeft (i));
292 | clientApp.Start (Seconds (start_time * i)); // Start after sink
293 | clientApp.Stop (Seconds (stop_time - 3)); // Stop before the sink
294 | }
295 |
296 | // Flow monitor
297 | FlowMonitorHelper flowHelper;
298 | if (flow_monitor)
299 | {
300 | flowHelper.InstallAll ();
301 | }
302 |
303 | // Count RX packets
304 | for (uint32_t i = 0; i < d.RightCount (); ++i)
305 | {
306 | rxPkts.push_back(0);
307 | Ptr pktSink = DynamicCast(sinkApps.Get(i));
308 | pktSink->TraceConnectWithoutContext ("Rx", MakeBoundCallback (&CountRxPkts, i));
309 | }
310 |
311 | Simulator::Stop (Seconds (stop_time));
312 | Simulator::Run ();
313 |
314 | if (flow_monitor)
315 | {
316 | flowHelper.SerializeToXmlFile (prefix_file_name + ".flowmonitor", true, true);
317 | }
318 |
319 | if (transport_prot.compare ("ns3::TcpRl") == 0 or transport_prot.compare ("ns3::TcpRlTimeBased") == 0)
320 | {
321 | openGymInterface->NotifySimulationEnd();
322 | }
323 |
324 | PrintRxCount();
325 | Simulator::Destroy ();
326 | return 0;
327 | }
328 |
--------------------------------------------------------------------------------
/tcp-rl-env.cc:
--------------------------------------------------------------------------------
1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
2 | /*
3 | * Copyright (c) 2018 Technische Universität Berlin
4 | *
5 | * This program is free software; you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License version 2 as
7 | * published by the Free Software Foundation;
8 | *
9 | * This program is distributed in the hope that it will be useful,
10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | * GNU General Public License for more details.
13 | *
14 | * You should have received a copy of the GNU General Public License
15 | * along with this program; if not, write to the Free Software
16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 | *
18 | * Author: Piotr Gawlowicz
19 | */
20 |
21 | #include "tcp-rl-env.h"
22 | #include "ns3/tcp-header.h"
23 | #include "ns3/object.h"
24 | #include "ns3/core-module.h"
25 | #include "ns3/log.h"
26 | #include "ns3/simulator.h"
27 | #include "ns3/tcp-socket-base.h"
28 | #include
29 | #include
30 |
31 |
32 | namespace ns3 {
33 |
34 | NS_LOG_COMPONENT_DEFINE ("ns3::TcpGymEnv");
35 | NS_OBJECT_ENSURE_REGISTERED (TcpGymEnv);
36 |
37 | TcpGymEnv::TcpGymEnv ()  // Constructor: attach this environment to the interface returned by OpenGymInterface::Get()
38 | {
39 | NS_LOG_FUNCTION (this);
40 | SetOpenGymInterface(OpenGymInterface::Get());  // Get() is called without arguments here
41 | }
42 |
43 | TcpGymEnv::~TcpGymEnv ()  // Destructor: only logs the call
44 | {
45 | NS_LOG_FUNCTION (this);
46 | }
47 |
48 | TypeId  // Register and return the ns-3 TypeId "ns3::TcpGymEnv"
49 | TcpGymEnv::GetTypeId (void)
50 | {
51 | static TypeId tid = TypeId ("ns3::TcpGymEnv")  // function-local static: built once, reused on later calls
52 | .SetParent ()  // NOTE(review): the template argument naming the parent type is missing in this listing
53 | .SetGroupName ("OpenGym")
54 | ;  // no AddConstructor here, unlike the two derived env classes below
55 |
56 | return tid;
57 | }
58 |
59 | void  // Dispose hook: only logs the call
60 | TcpGymEnv::DoDispose ()
61 | {
62 | NS_LOG_FUNCTION (this);
63 | }
64 |
65 | void  // Store the node ID that is reported in observations and log messages
66 | TcpGymEnv::SetNodeId(uint32_t id)
67 | {
68 | NS_LOG_FUNCTION (this);
69 | m_nodeId = id;
70 | }
71 |
72 | void  // Store the socket unique ID that is reported as the first observation entry
73 | TcpGymEnv::SetSocketUuid(uint32_t id)
74 | {
75 | NS_LOG_FUNCTION (this);
76 | m_socketUuid = id;
77 | }
78 |
79 | std::string  // Map a TCP congestion state enum value to its name; returns "UNKNOWN" for unlisted values
80 | TcpGymEnv::GetTcpCongStateName(const TcpSocketState::TcpCongState_t state)
81 | {
82 | std::string stateName = "UNKNOWN";
83 | switch(state) {
84 | case TcpSocketState::CA_OPEN:
85 | stateName = "CA_OPEN";
86 | break;
87 | case TcpSocketState::CA_DISORDER:
88 | stateName = "CA_DISORDER";
89 | break;
90 | case TcpSocketState::CA_CWR:
91 | stateName = "CA_CWR";
92 | break;
93 | case TcpSocketState::CA_RECOVERY:
94 | stateName = "CA_RECOVERY";
95 | break;
96 | case TcpSocketState::CA_LOSS:
97 | stateName = "CA_LOSS";
98 | break;
99 | case TcpSocketState::CA_LAST_STATE:
100 | stateName = "CA_LAST_STATE";
101 | break;
102 | default:  // same value as the initialiser above
103 | stateName = "UNKNOWN";
104 | break;
105 | }
106 | return stateName;
107 | }
108 |
109 | std::string  // Map a congestion-avoidance event enum value to its name; returns "UNKNOWN" for unlisted values
110 | TcpGymEnv::GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event)
111 | {
112 | std::string eventName = "UNKNOWN";
113 | switch(event) {
114 | case TcpSocketState::CA_EVENT_TX_START:
115 | eventName = "CA_EVENT_TX_START";
116 | break;
117 | case TcpSocketState::CA_EVENT_CWND_RESTART:
118 | eventName = "CA_EVENT_CWND_RESTART";
119 | break;
120 | case TcpSocketState::CA_EVENT_COMPLETE_CWR:
121 | eventName = "CA_EVENT_COMPLETE_CWR";
122 | break;
123 | case TcpSocketState::CA_EVENT_LOSS:
124 | eventName = "CA_EVENT_LOSS";
125 | break;
126 | case TcpSocketState::CA_EVENT_ECN_NO_CE:
127 | eventName = "CA_EVENT_ECN_NO_CE";
128 | break;
129 | case TcpSocketState::CA_EVENT_ECN_IS_CE:
130 | eventName = "CA_EVENT_ECN_IS_CE";
131 | break;
132 | case TcpSocketState::CA_EVENT_DELAYED_ACK:
133 | eventName = "CA_EVENT_DELAYED_ACK";
134 | break;
135 | case TcpSocketState::CA_EVENT_NON_DELAYED_ACK:
136 | eventName = "CA_EVENT_NON_DELAYED_ACK";
137 | break;
138 | default:  // same value as the initialiser above
139 | eventName = "UNKNOWN";
140 | break;
141 | }
142 | return eventName;
143 | }
144 |
145 | /*
146 | Define action space: a box of 2 values, element 0 = new_ssThresh and element 1 = new_cWnd, each bounded to [0, 65535]
147 | */
148 | Ptr
149 | TcpGymEnv::GetActionSpace()
150 | {
151 | // new_ssThresh
152 | // new_cWnd
153 | uint32_t parameterNum = 2;
154 | float low = 0.0;
155 | float high = 65535;
156 | std::vector shape = {parameterNum,};  // one-dimensional shape with parameterNum entries
157 | std::string dtype = TypeNameGet ();  // NOTE(review): the template argument selecting the element type is missing in this listing
158 |
159 | Ptr box = CreateObject (low, high, shape, dtype);
160 | NS_LOG_INFO ("MyGetActionSpace: " << box);
161 | return box;
162 | }
163 |
164 | /*
165 | Define game over condition. As written this always reports false, because the test flag below is hard-coded to false.
166 | */
167 | bool
168 | TcpGymEnv::GetGameOver()
169 | {
170 | m_isGameOver = false;
171 | bool test = false;  // debugging switch; while false the condition below can never hold
172 | static float stepCounter = 0.0;  // function-local static: one counter shared by every call, not per instance
173 | stepCounter += 1;
174 | if (stepCounter == 10 && test) {
175 | m_isGameOver = true;
176 | }
177 | NS_LOG_INFO ("MyGetGameOver: " << m_isGameOver);
178 | return m_isGameOver;
179 | }
180 |
181 | /*
182 | Define reward function: return the value currently stored in m_envReward
183 | */
184 | float
185 | TcpGymEnv::GetReward()
186 | {
187 | NS_LOG_INFO("MyGetReward: " << m_envReward);
188 | return m_envReward;  // no computation here; the value is whatever was last assigned elsewhere
189 | }
190 |
191 | /*
192 | Define extra info. Optional. Returns the string currently stored in m_info.
193 | */
194 | std::string
195 | TcpGymEnv::GetExtraInfo()
196 | {
197 | NS_LOG_INFO("MyGetExtraInfo: " << m_info);
198 | return m_info;
199 | }
200 |
201 | /*
202 | Execute received actions: store element 0 as the new ssThresh and element 1 as the new cWnd. Always returns true.
203 | */
204 | bool
205 | TcpGymEnv::ExecuteActions(Ptr action)
206 | {
207 | Ptr > box = DynamicCast >(action);  // NOTE(review): template arguments are missing in this listing; the cast result is not null-checked
208 | m_new_ssThresh = box->GetValue(0);
209 | m_new_cWnd = box->GetValue(1);
210 |
211 | NS_LOG_INFO ("MyExecuteActions: " << action);
212 | return true;
213 | }
214 |
215 |
216 | NS_OBJECT_ENSURE_REGISTERED (TcpEventGymEnv);
217 |
218 | TcpEventGymEnv::TcpEventGymEnv () : TcpGymEnv()  // Constructor: delegates to the base class and logs the call
219 | {
220 | NS_LOG_FUNCTION (this);
221 | }
222 |
223 | TcpEventGymEnv::~TcpEventGymEnv ()  // Destructor: only logs the call
224 | {
225 | NS_LOG_FUNCTION (this);
226 | }
227 |
228 | TypeId  // Register and return the ns-3 TypeId "ns3::TcpEventGymEnv"
229 | TcpEventGymEnv::GetTypeId (void)
230 | {
231 | static TypeId tid = TypeId ("ns3::TcpEventGymEnv")  // function-local static: built once, reused on later calls
232 | .SetParent ()  // NOTE(review): template argument missing in this listing
233 | .SetGroupName ("OpenGym")
234 | .AddConstructor ()  // NOTE(review): template argument missing in this listing
235 | ;
236 |
237 | return tid;
238 | }
239 |
240 | void  // Dispose hook: only logs the call
241 | TcpEventGymEnv::DoDispose ()
242 | {
243 | NS_LOG_FUNCTION (this);
244 | }
245 |
246 | void  // Set the reward value that IncreaseWindow() reports to the agent
247 | TcpEventGymEnv::SetReward(float value)
248 | {
249 | NS_LOG_FUNCTION (this);
250 | m_reward = value;
251 | }
252 |
253 | void  // Set the penalty value that GetSsThresh() reports to the agent
254 | TcpEventGymEnv::SetPenalty(float value)
255 | {
256 | NS_LOG_FUNCTION (this);
257 | m_penalty = value;
258 | }
259 |
260 | /*
261 | Define observation space: a box of 10 values bounded to [0, 1e9]. Fifteen fields are listed below, but parameterNum is 10, so only the first ten are part of the space.
262 | */
263 | Ptr
264 | TcpEventGymEnv::GetObservationSpace()
265 | {
266 | // [0] socket unique ID
267 | // [1] tcp env type: event-based = 0 / time-based = 1
268 | // [2] sim time in us
269 | // [3] node ID
270 | // [4] ssThresh
271 | // [5] cWnd
272 | // [6] segmentSize
273 | // [7] segmentsAcked
274 | // [8] bytesInFlight
275 | // [9] rtt in us
276 | // min rtt in us (listed, but beyond the 10 reported entries)
277 | // called func (listed, but beyond the 10 reported entries)
278 | // congestion algorithm (CA) state (listed, but beyond the 10 reported entries)
279 | // CA event (listed, but beyond the 10 reported entries)
280 | // ECN state (listed, but beyond the 10 reported entries)
281 | uint32_t parameterNum = 10;
282 | float low = 0.0;
283 | float high = 1000000000.0;
284 | std::vector shape = {parameterNum,};
285 | std::string dtype = TypeNameGet ();  // NOTE(review): template argument missing in this listing
286 |
287 | Ptr box = CreateObject (low, high, shape, dtype);
288 | NS_LOG_INFO ("MyGetObservationSpace: " << box);
289 | return box;
290 | }
291 |
292 | /*
293 | Collect observations: build a 10-entry box from the cached socket state, in the order of the AddValue calls below
294 | */
295 | Ptr
296 | TcpEventGymEnv::GetObservation()
297 | {
298 | uint32_t parameterNum = 10;
299 | std::vector shape = {parameterNum,};
300 |
301 | Ptr > box = CreateObject >(shape);  // NOTE(review): template arguments missing in this listing
302 |
303 | box->AddValue(m_socketUuid);  // [0]
304 | box->AddValue(0);  // [1] env type: event-based = 0
305 | box->AddValue(Simulator::Now().GetMicroSeconds ());  // [2] current simulation time in us
306 | box->AddValue(m_nodeId);  // [3]
307 | box->AddValue(m_tcb->m_ssThresh);  // [4] m_tcb is dereferenced without a null check; it is assigned in the TCP callbacks of this class
308 | box->AddValue(m_tcb->m_cWnd);  // [5]
309 | box->AddValue(m_tcb->m_segmentSize);  // [6]
310 | box->AddValue(m_segmentsAcked);  // [7]
311 | box->AddValue(m_bytesInFlight);  // [8]
312 | box->AddValue(m_rtt.GetMicroSeconds ());  // [9] rtt in us
313 | //box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ());
314 | //box->AddValue(m_calledFunc);
315 | //box->AddValue(m_tcb->m_congState);
316 | //box->AddValue(m_event);
317 | //box->AddValue(m_tcb->m_ecnState);
318 |
319 | // Print data
320 | NS_LOG_INFO ("MyGetObservation: " << box);
321 | return box;
322 | }
323 |
324 | void  // Packet-transmit trace hook: arguments are ignored, only the call is logged
325 | TcpEventGymEnv::TxPktTrace(Ptr, const TcpHeader&, Ptr)
326 | {
327 | NS_LOG_FUNCTION (this);
328 | }
329 |
330 | void  // Packet-receive trace hook: arguments are ignored, only the call is logged
331 | TcpEventGymEnv::RxPktTrace(Ptr, const TcpHeader&, Ptr)
332 | {
333 | NS_LOG_FUNCTION (this);
334 | }
335 |
336 | uint32_t  // Record a penalty and the current socket state, notify, then return the stored new ssThresh
337 | TcpEventGymEnv::GetSsThresh (Ptr tcb, uint32_t bytesInFlight)
338 | {
339 | NS_LOG_FUNCTION (this);
340 | // pkt was lost, so penalty
341 | m_envReward = m_penalty;
342 |
343 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " GetSsThresh, BytesInFlight: " << bytesInFlight);
344 | m_calledFunc = CalledFunc_t::GET_SS_THRESH;
345 | m_info = "GetSsThresh";
346 | m_tcb = tcb;  // cache the socket state so GetObservation() can read it
347 | m_bytesInFlight = bytesInFlight;
348 | Notify();  // presumably runs the exchange with the agent, which updates m_new_ssThresh via ExecuteActions() -- confirm in OpenGymEnv
349 | return m_new_ssThresh;
350 | }
351 |
352 | void  // Record a reward and the current socket state, notify, then overwrite the socket's cWnd with the stored new cWnd
353 | TcpEventGymEnv::IncreaseWindow (Ptr tcb, uint32_t segmentsAcked)
354 | {
355 | NS_LOG_FUNCTION (this);
356 | // pkt was acked, so reward
357 | m_envReward = m_reward;
358 |
359 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " IncreaseWindow, SegmentsAcked: " << segmentsAcked);
360 | m_calledFunc = CalledFunc_t::INCREASE_WINDOW;
361 | m_info = "IncreaseWindow";
362 | m_tcb = tcb;  // cache the socket state so GetObservation() can read it
363 | m_segmentsAcked = segmentsAcked;
364 | Notify();  // presumably runs the exchange with the agent, which updates m_new_cWnd via ExecuteActions() -- confirm in OpenGymEnv
365 | tcb->m_cWnd = m_new_cWnd;  // the window is set to the stored value outright, not incremented
366 | }
367 |
368 | void  // Cache the acked-segment count and RTT sample for the next observation; does not call Notify()
369 | TcpEventGymEnv::PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt)
370 | {
371 | NS_LOG_FUNCTION (this);
372 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " PktsAcked, SegmentsAcked: " << segmentsAcked << " Rtt: " << rtt);
373 | m_calledFunc = CalledFunc_t::PKTS_ACKED;
374 | m_info = "PktsAcked";
375 | m_tcb = tcb;
376 | m_segmentsAcked = segmentsAcked;
377 | m_rtt = rtt;  // reported by GetObservation() in microseconds
378 | }
379 |
380 | void  // Log and cache a congestion state change; does not call Notify()
381 | TcpEventGymEnv::CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState)
382 | {
383 | NS_LOG_FUNCTION (this);
384 | std::string stateName = GetTcpCongStateName(newState);  // used only for the log line below
385 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CongestionStateSet: " << newState << " " << stateName);
386 |
387 | m_calledFunc = CalledFunc_t::CONGESTION_STATE_SET;
388 | m_info = "CongestionStateSet";
389 | m_tcb = tcb;
390 | m_newState = newState;
391 | }
392 |
393 | void  // Log and cache a congestion-avoidance event; does not call Notify()
394 | TcpEventGymEnv::CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event)
395 | {
396 | NS_LOG_FUNCTION (this);
397 | std::string eventName = GetTcpCAEventName(event);  // used only for the log line below
398 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CwndEvent: " << event << " " << eventName);
399 |
400 | m_calledFunc = CalledFunc_t::CWND_EVENT;
401 | m_info = "CwndEvent";
402 | m_tcb = tcb;
403 | m_event = event;
404 | }
405 |
406 |
407 | NS_OBJECT_ENSURE_REGISTERED (TcpTimeStepGymEnv);
408 |
409 | TcpTimeStepGymEnv::TcpTimeStepGymEnv () : TcpGymEnv()  // Default constructor: reward starts at 0; m_timeStep is not assigned here
410 | {
411 | NS_LOG_FUNCTION (this);
412 | m_envReward = 0.0;
413 | }
414 |
415 | TcpTimeStepGymEnv::TcpTimeStepGymEnv (Time timeStep) : TcpGymEnv()  // Constructor taking the interval used by ScheduleNextStateRead()
416 | {
417 | NS_LOG_FUNCTION (this);
418 | m_timeStep = timeStep;
419 | m_envReward = 0.0;
420 | }
421 |
422 | void  // Periodic step: reschedule this method after m_timeStep, then call Notify()
423 | TcpTimeStepGymEnv::ScheduleNextStateRead ()
424 | {
425 | NS_LOG_FUNCTION (this);
426 | Simulator::Schedule (m_timeStep, &TcpTimeStepGymEnv::ScheduleNextStateRead, this);  // the next call is queued before this one notifies
427 | Notify();
428 | }
429 |
430 | TcpTimeStepGymEnv::~TcpTimeStepGymEnv ()  // Destructor: only logs the call
431 | {
432 | NS_LOG_FUNCTION (this);
433 | }
434 |
435 | TypeId  // Register and return the ns-3 TypeId "ns3::TcpTimeStepGymEnv"
436 | TcpTimeStepGymEnv::GetTypeId (void)
437 | {
438 | static TypeId tid = TypeId ("ns3::TcpTimeStepGymEnv")  // function-local static: built once, reused on later calls
439 | .SetParent ()  // NOTE(review): template argument missing in this listing
440 | .SetGroupName ("OpenGym")
441 | .AddConstructor ()  // NOTE(review): template argument missing in this listing
442 | ;
443 |
444 | return tid;
445 | }
446 |
447 | void  // Dispose hook: only logs the call
448 | TcpTimeStepGymEnv::DoDispose ()
449 | {
450 | NS_LOG_FUNCTION (this);
451 | }
452 |
453 | /*
454 | Define observation space: a box of 16 values bounded to [0, 1e9], one per field listed below
455 | */
456 | Ptr
457 | TcpTimeStepGymEnv::GetObservationSpace()
458 | {
459 | // socket unique ID
460 | // tcp env type: event-based = 0 / time-based = 1
461 | // sim time in us
462 | // node ID
463 | // ssThresh
464 | // cWnd
465 | // segmentSize
466 | // bytesInFlightSum
467 | // bytesInFlightAvg
468 | // segmentsAckedSum
469 | // segmentsAckedAvg
470 | // avgRtt
471 | // minRtt
472 | // avgInterTx
473 | // avgInterRx
474 | // throughput
475 | uint32_t parameterNum = 16;  // equals the number of fields listed above
476 | float low = 0.0;
477 | float high = 1000000000.0;
478 | std::vector shape = {parameterNum,};
479 | std::string dtype = TypeNameGet ();  // NOTE(review): template argument missing in this listing
480 |
481 | Ptr box = CreateObject (low, high, shape, dtype);
482 | NS_LOG_INFO ("MyGetObservationSpace: " << box);
483 | return box;
484 | }
485 |
486 | /*
487 | Collect observations
488 | */
489 | Ptr
490 | TcpTimeStepGymEnv::GetObservation()
491 | {
492 | uint32_t parameterNum = 16;
493 | std::vector shape = {parameterNum,};
494 |
495 | Ptr > box = CreateObject >(shape);
496 |
497 | box->AddValue(m_socketUuid);
498 | box->AddValue(1);
499 | box->AddValue(Simulator::Now().GetMicroSeconds ());
500 | box->AddValue(m_nodeId);
501 | box->AddValue(m_tcb->m_ssThresh);
502 | box->AddValue(m_tcb->m_cWnd);
503 | box->AddValue(m_tcb->m_segmentSize);
504 |
505 | //bytesInFlightSum
506 | uint64_t bytesInFlightSum = std::accumulate(m_bytesInFlight.begin(), m_bytesInFlight.end(), 0);
507 | box->AddValue(bytesInFlightSum);
508 |
509 | //bytesInFlightAvg
510 | uint64_t bytesInFlightAvg = 0;
511 | if (m_bytesInFlight.size()) {
512 | bytesInFlightAvg = bytesInFlightSum / m_bytesInFlight.size();
513 | }
514 | box->AddValue(bytesInFlightAvg);
515 |
516 | //segmentsAckedSum
517 | uint64_t segmentsAckedSum = std::accumulate(m_segmentsAcked.begin(), m_segmentsAcked.end(), 0);
518 | box->AddValue(segmentsAckedSum);
519 |
520 | //segmentsAckedAvg
521 | uint64_t segmentsAckedAvg = 0;
522 | if (m_segmentsAcked.size()) {
523 | segmentsAckedAvg = segmentsAckedSum / m_segmentsAcked.size();
524 | }
525 | box->AddValue(segmentsAckedAvg);
526 |
527 | //avgRtt
528 | Time avgRtt = Seconds(0.0);
529 | if(m_rttSampleNum) {
530 | avgRtt = m_rttSum / m_rttSampleNum;
531 | }
532 | box->AddValue(avgRtt.GetMicroSeconds ());
533 |
534 | //m_minRtt
535 | box->AddValue(m_tcb->m_minRtt.GetMicroSeconds ());
536 |
537 | //avgInterTx
538 | Time avgInterTx = Seconds(0.0);
539 | if (m_interTxTimeNum) {
540 | avgInterTx = m_interTxTimeSum / m_interTxTimeNum;
541 | }
542 | box->AddValue(avgInterTx.GetMicroSeconds ());
543 |
544 | //avgInterRx
545 | Time avgInterRx = Seconds(0.0);
546 | if (m_interRxTimeNum) {
547 | avgInterRx = m_interRxTimeSum / m_interRxTimeNum;
548 | }
549 | box->AddValue(avgInterRx.GetMicroSeconds ());
550 |
551 | //throughput bytes/s
552 | float throughput = (segmentsAckedSum * m_tcb->m_segmentSize) / m_timeStep.GetSeconds();
553 | box->AddValue(throughput);
554 |
555 | // Print data
556 | NS_LOG_INFO ("MyGetObservation: " << box);
557 |
558 | m_bytesInFlight.clear();
559 | m_segmentsAcked.clear();
560 |
561 | m_rttSampleNum = 0;
562 | m_rttSum = MicroSeconds (0.0);
563 |
564 | m_interTxTimeNum = 0;
565 | m_interTxTimeSum = MicroSeconds (0.0);
566 |
567 | m_interRxTimeNum = 0;
568 | m_interRxTimeSum = MicroSeconds (0.0);
569 |
570 | return box;
571 | }
572 |
573 | void
574 | TcpTimeStepGymEnv::TxPktTrace(Ptr, const TcpHeader&, Ptr)
575 | {
576 | NS_LOG_FUNCTION (this);
577 | if ( m_lastPktTxTime > MicroSeconds(0.0) ) {
578 | Time interTxTime = Simulator::Now() - m_lastPktTxTime;
579 | m_interTxTimeSum += interTxTime;
580 | m_interTxTimeNum++;
581 | }
582 |
583 | m_lastPktTxTime = Simulator::Now();
584 | }
585 |
586 | void
587 | TcpTimeStepGymEnv::RxPktTrace(Ptr, const TcpHeader&, Ptr)
588 | {
589 | NS_LOG_FUNCTION (this);
590 | if ( m_lastPktRxTime > MicroSeconds(0.0) ) {
591 | Time interRxTime = Simulator::Now() - m_lastPktRxTime;
592 | m_interRxTimeSum += interRxTime;
593 | m_interRxTimeNum++;
594 | }
595 |
596 | m_lastPktRxTime = Simulator::Now();
597 | }
598 |
599 | uint32_t
600 | TcpTimeStepGymEnv::GetSsThresh (Ptr tcb, uint32_t bytesInFlight)
601 | {
602 | NS_LOG_FUNCTION (this);
603 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " GetSsThresh, BytesInFlight: " << bytesInFlight);
604 | m_tcb = tcb;
605 | m_bytesInFlight.push_back(bytesInFlight);
606 |
607 | if (!m_started) {
608 | m_started = true;
609 | Notify();
610 | ScheduleNextStateRead();
611 | }
612 |
613 | // action
614 | return m_new_ssThresh;
615 | }
616 |
617 | void
618 | TcpTimeStepGymEnv::IncreaseWindow (Ptr tcb, uint32_t segmentsAcked)
619 | {
620 | NS_LOG_FUNCTION (this);
621 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " IncreaseWindow, SegmentsAcked: " << segmentsAcked);
622 | m_tcb = tcb;
623 | m_segmentsAcked.push_back(segmentsAcked);
624 | m_bytesInFlight.push_back(tcb->m_bytesInFlight);
625 |
626 | if (!m_started) {
627 | m_started = true;
628 | Notify();
629 | ScheduleNextStateRead();
630 | }
631 | // action
632 | tcb->m_cWnd = m_new_cWnd;
633 | }
634 |
635 | void
636 | TcpTimeStepGymEnv::PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt)
637 | {
638 | NS_LOG_FUNCTION (this);
639 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " PktsAcked, SegmentsAcked: " << segmentsAcked << " Rtt: " << rtt);
640 | m_tcb = tcb;
641 | m_rttSum += rtt;
642 | m_rttSampleNum++;
643 | }
644 |
645 | void
646 | TcpTimeStepGymEnv::CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState)
647 | {
648 | NS_LOG_FUNCTION (this);
649 | std::string stateName = GetTcpCongStateName(newState);
650 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CongestionStateSet: " << newState << " " << stateName);
651 | m_tcb = tcb;
652 | }
653 |
654 | void
655 | TcpTimeStepGymEnv::CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event)
656 | {
657 | NS_LOG_FUNCTION (this);
658 | std::string eventName = GetTcpCAEventName(event);
659 | NS_LOG_INFO(Simulator::Now() << " Node: " << m_nodeId << " CwndEvent: " << event << " " << eventName);
660 | m_tcb = tcb;
661 | }
662 |
663 | } // namespace ns3
664 |
--------------------------------------------------------------------------------
/tcp-rl-env.h:
--------------------------------------------------------------------------------
1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
2 | /*
3 | * Copyright (c) 2018 Technische Universität Berlin
4 | *
5 | * This program is free software; you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License version 2 as
7 | * published by the Free Software Foundation;
8 | *
9 | * This program is distributed in the hope that it will be useful,
10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | * GNU General Public License for more details.
13 | *
14 | * You should have received a copy of the GNU General Public License
15 | * along with this program; if not, write to the Free Software
16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 | *
18 | * Author: Piotr Gawlowicz
19 | */
20 |
21 | #ifndef TCP_RL_ENV_H
22 | #define TCP_RL_ENV_H
23 |
#include "ns3/opengym-module.h"
#include "ns3/tcp-socket-base.h"
#include <string>
#include <vector>
27 |
28 | namespace ns3 {
29 |
30 | class Packet;
31 | class TcpHeader;
32 | class TcpSocketBase;
33 | class Time;
34 |
35 |
36 | class TcpGymEnv : public OpenGymEnv
37 | {
38 | public:
39 | TcpGymEnv ();
40 | virtual ~TcpGymEnv ();
41 | static TypeId GetTypeId (void);
42 | virtual void DoDispose ();
43 |
44 | void SetNodeId(uint32_t id);
45 | void SetSocketUuid(uint32_t id);
46 |
47 | std::string GetTcpCongStateName(const TcpSocketState::TcpCongState_t state);
48 | std::string GetTcpCAEventName(const TcpSocketState::TcpCAEvent_t event);
49 |
50 | // OpenGym interface
51 | virtual Ptr GetActionSpace();
52 | virtual bool GetGameOver();
53 | virtual float GetReward();
54 | virtual std::string GetExtraInfo();
55 | virtual bool ExecuteActions(Ptr action);
56 |
57 | virtual Ptr GetObservationSpace() = 0;
58 | virtual Ptr GetObservation() = 0;
59 |
60 | // trace packets, e.g. for calculating inter tx/rx time
61 | virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr) = 0;
62 | virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr) = 0;
63 |
64 | // TCP congestion control interface
65 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight) = 0;
66 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked) = 0;
67 | // optional functions used to collect obs
68 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt) = 0;
69 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState) = 0;
70 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event) = 0;
71 |
72 | typedef enum
73 | {
74 | GET_SS_THRESH = 0,
75 | INCREASE_WINDOW,
76 | PKTS_ACKED,
77 | CONGESTION_STATE_SET,
78 | CWND_EVENT,
79 | } CalledFunc_t;
80 |
81 | protected:
82 | uint32_t m_nodeId;
83 | uint32_t m_socketUuid;
84 |
85 | // state
86 | // obs has to be implemented in child class
87 |
88 | // game over
89 | bool m_isGameOver;
90 |
91 | // reward
92 | float m_envReward;
93 |
94 | // extra info
95 | std::string m_info;
96 |
97 | // actions
98 | uint32_t m_new_ssThresh;
99 | uint32_t m_new_cWnd;
100 | };
101 |
102 |
103 | class TcpEventGymEnv : public TcpGymEnv
104 | {
105 | public:
106 | TcpEventGymEnv ();
107 | virtual ~TcpEventGymEnv ();
108 | static TypeId GetTypeId (void);
109 | virtual void DoDispose ();
110 |
111 | void SetReward(float value);
112 | void SetPenalty(float value);
113 |
114 | // OpenGym interface
115 | virtual Ptr GetObservationSpace();
116 | Ptr GetObservation();
117 |
118 | // trace packets, e.g. for calculating inter tx/rx time
119 | virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr);
120 | virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr);
121 |
122 | // TCP congestion control interface
123 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight);
124 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked);
125 | // optional functions used to collect obs
126 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt);
127 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState);
128 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event);
129 |
130 | private:
131 | // state
132 | CalledFunc_t m_calledFunc;
133 | Ptr m_tcb;
134 | uint32_t m_bytesInFlight;
135 | uint32_t m_segmentsAcked;
136 | Time m_rtt;
137 | TcpSocketState::TcpCongState_t m_newState;
138 | TcpSocketState::TcpCAEvent_t m_event;
139 |
140 | // reward
141 | float m_reward;
142 | float m_penalty;
143 | };
144 |
145 |
146 | class TcpTimeStepGymEnv : public TcpGymEnv
147 | {
148 | public:
149 | TcpTimeStepGymEnv ();
150 | TcpTimeStepGymEnv (Time timeStep);
151 | virtual ~TcpTimeStepGymEnv ();
152 | static TypeId GetTypeId (void);
153 | virtual void DoDispose ();
154 |
155 | // OpenGym interface
156 | virtual Ptr GetObservationSpace();
157 | Ptr GetObservation();
158 |
159 | // trace packets, e.g. for calculating inter tx/rx time
160 | virtual void TxPktTrace(Ptr, const TcpHeader&, Ptr);
161 | virtual void RxPktTrace(Ptr, const TcpHeader&, Ptr);
162 |
163 | // TCP congestion control interface
164 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight);
165 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked);
166 | // optional functions used to collect obs
167 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt);
168 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState);
169 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event);
170 |
171 | private:
172 | void ScheduleNextStateRead();
173 | bool m_started {false};
174 | Time m_timeStep;
175 | // state
176 | Ptr m_tcb;
177 | std::vector m_bytesInFlight;
178 | std::vector m_segmentsAcked;
179 |
180 | uint64_t m_rttSampleNum {0};
181 | Time m_rttSum {MicroSeconds (0.0)};
182 |
183 | Time m_lastPktTxTime {MicroSeconds(0.0)};
184 | Time m_lastPktRxTime {MicroSeconds(0.0)};
185 | uint64_t m_interTxTimeNum {0};
186 | Time m_interTxTimeSum {MicroSeconds (0.0)};
187 | uint64_t m_interRxTimeNum {0};
188 | Time m_interRxTimeSum {MicroSeconds (0.0)};
189 |
190 | // reward
191 | };
192 |
193 |
194 |
195 | } // namespace ns3
196 |
197 | #endif /* TCP_RL_ENV_H */
198 |
--------------------------------------------------------------------------------
/tcp-rl.cc:
--------------------------------------------------------------------------------
1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
2 | /*
3 | * Copyright (c) 2018 Technische Universität Berlin
4 | *
5 | * This program is free software; you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License version 2 as
7 | * published by the Free Software Foundation;
8 | *
9 | * This program is distributed in the hope that it will be useful,
10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | * GNU General Public License for more details.
13 | *
14 | * You should have received a copy of the GNU General Public License
15 | * along with this program; if not, write to the Free Software
16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 | *
18 | * Author: Piotr Gawlowicz
19 | */
20 |
21 | #include "tcp-rl.h"
22 | #include "tcp-rl-env.h"
23 | #include "ns3/tcp-header.h"
24 | #include "ns3/object.h"
25 | #include "ns3/node-list.h"
26 | #include "ns3/core-module.h"
27 | #include "ns3/log.h"
28 | #include "ns3/simulator.h"
29 | #include "ns3/tcp-socket-base.h"
30 | #include "ns3/tcp-l4-protocol.h"
31 |
32 |
33 | namespace ns3 {
34 |
35 |
36 | NS_OBJECT_ENSURE_REGISTERED (TcpSocketDerived);
37 |
38 | TypeId
39 | TcpSocketDerived::GetTypeId (void)
40 | {
41 | static TypeId tid = TypeId ("ns3::TcpSocketDerived")
42 | .SetParent ()
43 | .SetGroupName ("Internet")
44 | .AddConstructor ()
45 | ;
46 | return tid;
47 | }
48 |
49 | TypeId
50 | TcpSocketDerived::GetInstanceTypeId () const
51 | {
52 | return TcpSocketDerived::GetTypeId ();
53 | }
54 |
55 | TcpSocketDerived::TcpSocketDerived (void)
56 | {
57 | }
58 |
59 | Ptr
60 | TcpSocketDerived::GetCongestionControlAlgorithm ()
61 | {
62 | return m_congestionControl;
63 | }
64 |
65 | TcpSocketDerived::~TcpSocketDerived (void)
66 | {
67 | }
68 |
69 |
70 | NS_LOG_COMPONENT_DEFINE ("ns3::TcpRlBase");
71 | NS_OBJECT_ENSURE_REGISTERED (TcpRlBase);
72 |
73 | TypeId
74 | TcpRlBase::GetTypeId (void)
75 | {
76 | static TypeId tid = TypeId ("ns3::TcpRlBase")
77 | .SetParent ()
78 | .SetGroupName ("Internet")
79 | .AddConstructor ()
80 | ;
81 | return tid;
82 | }
83 |
84 | TcpRlBase::TcpRlBase (void)
85 | : TcpCongestionOps ()
86 | {
87 | NS_LOG_FUNCTION (this);
88 | m_tcpSocket = 0;
89 | m_tcpGymEnv = 0;
90 | }
91 |
92 | TcpRlBase::TcpRlBase (const TcpRlBase& sock)
93 | : TcpCongestionOps (sock)
94 | {
95 | NS_LOG_FUNCTION (this);
96 | m_tcpSocket = 0;
97 | m_tcpGymEnv = 0;
98 | }
99 |
100 | TcpRlBase::~TcpRlBase (void)
101 | {
102 | m_tcpSocket = 0;
103 | m_tcpGymEnv = 0;
104 | }
105 |
106 | uint64_t
107 | TcpRlBase::GenerateUuid ()
108 | {
109 | static uint64_t uuid = 0;
110 | uuid++;
111 | return uuid;
112 | }
113 |
114 | void
115 | TcpRlBase::CreateGymEnv()
116 | {
117 | NS_LOG_FUNCTION (this);
118 | // should never be called, only child classes: TcpRl and TcpRlTimeBased
119 | }
120 |
121 | void
122 | TcpRlBase::ConnectSocketCallbacks()
123 | {
124 | NS_LOG_FUNCTION (this);
125 |
126 | bool foundSocket = false;
127 | for (NodeList::Iterator i = NodeList::Begin (); i != NodeList::End (); ++i) {
128 | Ptr node = *i;
129 | Ptr tcp = node->GetObject ();
130 |
131 | ObjectVectorValue socketVec;
132 | tcp->GetAttribute ("SocketList", socketVec);
133 | NS_LOG_DEBUG("Node: " << node->GetId() << " TCP socket num: " << socketVec.GetN());
134 |
135 | uint32_t sockNum = socketVec.GetN();
136 | for (uint32_t j=0; j sockObj = socketVec.Get(j);
138 | Ptr tcpSocket = DynamicCast (sockObj);
139 | NS_LOG_DEBUG("Node: " << node->GetId() << " TCP Socket: " << tcpSocket);
140 | if(!tcpSocket) { continue; }
141 |
142 | Ptr dtcpSocket = StaticCast(tcpSocket);
143 | Ptr ca = dtcpSocket->GetCongestionControlAlgorithm();
144 | NS_LOG_DEBUG("CA name: " << ca->GetName());
145 | Ptr rlCa = DynamicCast(ca);
146 | if (rlCa == this) {
147 | NS_LOG_DEBUG("Found TcpRl CA!");
148 | foundSocket = true;
149 | m_tcpSocket = tcpSocket;
150 | break;
151 | }
152 | }
153 |
154 | if (foundSocket) {
155 | break;
156 | }
157 | }
158 |
159 | NS_ASSERT_MSG(m_tcpSocket, "TCP socket was not found.");
160 |
161 | if(m_tcpSocket) {
162 | NS_LOG_DEBUG("Found TCP Socket: " << m_tcpSocket);
163 | m_tcpSocket->TraceConnectWithoutContext ("Tx", MakeCallback (&TcpGymEnv::TxPktTrace, m_tcpGymEnv));
164 | m_tcpSocket->TraceConnectWithoutContext ("Rx", MakeCallback (&TcpGymEnv::RxPktTrace, m_tcpGymEnv));
165 | NS_LOG_DEBUG("Connect socket callbacks " << m_tcpSocket->GetNode()->GetId());
166 | m_tcpGymEnv->SetNodeId(m_tcpSocket->GetNode()->GetId());
167 | }
168 | }
169 |
170 | std::string
171 | TcpRlBase::GetName () const
172 | {
173 | return "TcpRlBase";
174 | }
175 |
176 | uint32_t
177 | TcpRlBase::GetSsThresh (Ptr state,
178 | uint32_t bytesInFlight)
179 | {
180 | NS_LOG_FUNCTION (this << state << bytesInFlight);
181 |
182 | if (!m_tcpGymEnv) {
183 | CreateGymEnv();
184 | }
185 |
186 | uint32_t newSsThresh = 0;
187 | if (m_tcpGymEnv) {
188 | newSsThresh = m_tcpGymEnv->GetSsThresh(state, bytesInFlight);
189 | }
190 |
191 | return newSsThresh;
192 | }
193 |
194 | void
195 | TcpRlBase::IncreaseWindow (Ptr tcb, uint32_t segmentsAcked)
196 | {
197 | NS_LOG_FUNCTION (this << tcb << segmentsAcked);
198 |
199 | if (!m_tcpGymEnv) {
200 | CreateGymEnv();
201 | }
202 |
203 | if (m_tcpGymEnv) {
204 | m_tcpGymEnv->IncreaseWindow(tcb, segmentsAcked);
205 | }
206 | }
207 |
208 | void
209 | TcpRlBase::PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt)
210 | {
211 | NS_LOG_FUNCTION (this);
212 |
213 | if (!m_tcpGymEnv) {
214 | CreateGymEnv();
215 | }
216 |
217 | if (m_tcpGymEnv) {
218 | m_tcpGymEnv->PktsAcked(tcb, segmentsAcked, rtt);
219 | }
220 | }
221 |
222 | void
223 | TcpRlBase::CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState)
224 | {
225 | NS_LOG_FUNCTION (this);
226 |
227 | if (!m_tcpGymEnv) {
228 | CreateGymEnv();
229 | }
230 |
231 | if (m_tcpGymEnv) {
232 | m_tcpGymEnv->CongestionStateSet(tcb, newState);
233 | }
234 | }
235 |
236 | void
237 | TcpRlBase::CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event)
238 | {
239 | NS_LOG_FUNCTION (this);
240 |
241 | if (!m_tcpGymEnv) {
242 | CreateGymEnv();
243 | }
244 |
245 | if (m_tcpGymEnv) {
246 | m_tcpGymEnv->CwndEvent(tcb, event);
247 | }
248 | }
249 |
250 | Ptr
251 | TcpRlBase::Fork ()
252 | {
253 | return CopyObject (this);
254 | }
255 |
256 |
257 | NS_OBJECT_ENSURE_REGISTERED (TcpRl);
258 |
259 | TypeId
260 | TcpRl::GetTypeId (void)
261 | {
262 | static TypeId tid = TypeId ("ns3::TcpRl")
263 | .SetParent ()
264 | .SetGroupName ("Internet")
265 | .AddConstructor ()
266 | .AddAttribute ("Reward", "Reward when increasing congestion window.",
267 | DoubleValue (4.0),
268 | MakeDoubleAccessor (&TcpRl::m_reward),
269 | MakeDoubleChecker ())
270 | .AddAttribute ("Penalty", "Reward when increasing congestion window.",
271 | DoubleValue (-10.0),
272 | MakeDoubleAccessor (&TcpRl::m_penalty),
273 | MakeDoubleChecker ())
274 | ;
275 | return tid;
276 | }
277 |
278 | TcpRl::TcpRl (void)
279 | : TcpRlBase ()
280 | {
281 | NS_LOG_FUNCTION (this);
282 | }
283 |
284 | TcpRl::TcpRl (const TcpRl& sock)
285 | : TcpRlBase (sock)
286 | {
287 | NS_LOG_FUNCTION (this);
288 | }
289 |
290 | TcpRl::~TcpRl (void)
291 | {
292 | }
293 |
294 | std::string
295 | TcpRl::GetName () const
296 | {
297 | return "TcpRl";
298 | }
299 |
300 | void
301 | TcpRl::CreateGymEnv()
302 | {
303 | NS_LOG_FUNCTION (this);
304 | Ptr env = CreateObject();
305 | env->SetSocketUuid(TcpRlBase::GenerateUuid());
306 | env->SetReward(m_reward);
307 | env->SetPenalty(m_penalty);
308 | m_tcpGymEnv = env;
309 |
310 | ConnectSocketCallbacks();
311 | }
312 |
313 |
314 | NS_OBJECT_ENSURE_REGISTERED (TcpRlTimeBased);
315 |
316 | TypeId
317 | TcpRlTimeBased::GetTypeId (void)
318 | {
319 | static TypeId tid = TypeId ("ns3::TcpRlTimeBased")
320 | .SetParent ()
321 | .SetGroupName ("Internet")
322 | .AddConstructor ()
323 | .AddAttribute ("StepTime",
324 | "Step interval used in TCP env. Default: 100ms",
325 | TimeValue (MilliSeconds (100)),
326 | MakeTimeAccessor (&TcpRlTimeBased::m_timeStep),
327 | MakeTimeChecker ())
328 | ;
329 | return tid;
330 | }
331 |
332 | TcpRlTimeBased::TcpRlTimeBased (void) : TcpRlBase ()
333 | {
334 | NS_LOG_FUNCTION (this);
335 | }
336 |
337 | TcpRlTimeBased::TcpRlTimeBased (const TcpRlTimeBased& sock)
338 | : TcpRlBase (sock)
339 | {
340 | NS_LOG_FUNCTION (this);
341 | }
342 |
343 | TcpRlTimeBased::~TcpRlTimeBased (void)
344 | {
345 | }
346 |
347 | std::string
348 | TcpRlTimeBased::GetName () const
349 | {
350 | return "TcpRlTimeBased";
351 | }
352 |
353 | void
354 | TcpRlTimeBased::CreateGymEnv()
355 | {
356 | NS_LOG_FUNCTION (this);
357 | Ptr env = CreateObject (m_timeStep);
358 | env->SetSocketUuid(TcpRlBase::GenerateUuid());
359 | m_tcpGymEnv = env;
360 |
361 | ConnectSocketCallbacks();
362 | }
363 |
364 | } // namespace ns3
365 |
--------------------------------------------------------------------------------
/tcp-rl.h:
--------------------------------------------------------------------------------
1 | /* -*- Mode: C++; c-file-style: "gnu"; indent-tabs-mode:nil; -*- */
2 | /*
3 | * Copyright (c) 2018 Technische Universität Berlin
4 | *
5 | * This program is free software; you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License version 2 as
7 | * published by the Free Software Foundation;
8 | *
9 | * This program is distributed in the hope that it will be useful,
10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | * GNU General Public License for more details.
13 | *
14 | * You should have received a copy of the GNU General Public License
15 | * along with this program; if not, write to the Free Software
16 | * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 | *
18 | * Author: Piotr Gawlowicz
19 | */
20 |
21 | #ifndef TCP_RL_H
22 | #define TCP_RL_H
23 |
24 | #include "ns3/tcp-congestion-ops.h"
25 | #include "ns3/opengym-module.h"
26 | #include "ns3/tcp-socket-base.h"
27 |
28 | namespace ns3 {
29 |
30 | class TcpSocketBase;
31 | class Time;
32 | class TcpGymEnv;
33 |
34 |
35 | // used to get pointer to Congestion Algorithm
36 | class TcpSocketDerived : public TcpSocketBase
37 | {
38 | public:
39 | static TypeId GetTypeId (void);
40 | virtual TypeId GetInstanceTypeId () const;
41 |
42 | TcpSocketDerived (void);
43 | virtual ~TcpSocketDerived (void);
44 |
45 | Ptr GetCongestionControlAlgorithm ();
46 | };
47 |
48 |
49 | class TcpRlBase : public TcpCongestionOps
50 | {
51 | public:
52 | /**
53 | * \brief Get the type ID.
54 | * \return the object TypeId
55 | */
56 | static TypeId GetTypeId (void);
57 |
58 | TcpRlBase ();
59 |
60 | /**
61 | * \brief Copy constructor.
62 | * \param sock object to copy.
63 | */
64 | TcpRlBase (const TcpRlBase& sock);
65 |
66 | ~TcpRlBase ();
67 |
68 | virtual std::string GetName () const;
69 | virtual uint32_t GetSsThresh (Ptr tcb, uint32_t bytesInFlight);
70 | virtual void IncreaseWindow (Ptr tcb, uint32_t segmentsAcked);
71 | virtual void PktsAcked (Ptr tcb, uint32_t segmentsAcked, const Time& rtt);
72 | virtual void CongestionStateSet (Ptr tcb, const TcpSocketState::TcpCongState_t newState);
73 | virtual void CwndEvent (Ptr tcb, const TcpSocketState::TcpCAEvent_t event);
74 | virtual Ptr Fork ();
75 |
76 | protected:
77 | static uint64_t GenerateUuid ();
78 | virtual void CreateGymEnv();
79 | void ConnectSocketCallbacks();
80 |
81 | // OpenGymEnv interface
82 | Ptr m_tcpSocket;
83 | Ptr m_tcpGymEnv;
84 | };
85 |
86 |
87 | class TcpRl : public TcpRlBase
88 | {
89 | public:
90 | static TypeId GetTypeId (void);
91 |
92 | TcpRl ();
93 | TcpRl (const TcpRl& sock);
94 | ~TcpRl ();
95 |
96 | virtual std::string GetName () const;
97 | private:
98 | virtual void CreateGymEnv();
99 | // OpenGymEnv env
100 | float m_reward {1.0};
101 | float m_penalty {-100.0};
102 | };
103 |
104 |
105 | class TcpRlTimeBased : public TcpRlBase
106 | {
107 | public:
108 | static TypeId GetTypeId (void);
109 |
110 | TcpRlTimeBased ();
111 | TcpRlTimeBased (const TcpRlTimeBased& sock);
112 | ~TcpRlTimeBased ();
113 |
114 | virtual std::string GetName () const;
115 |
116 | private:
117 | virtual void CreateGymEnv();
118 | // OpenGymEnv env
119 | Time m_timeStep {MilliSeconds (100)};
120 | };
121 |
122 | } // namespace ns3
123 |
124 | #endif /* TCP_RL_H */
--------------------------------------------------------------------------------
/tcp_base.py:
--------------------------------------------------------------------------------
1 | __author__ = "Piotr Gawlowicz"
2 | __copyright__ = "Copyright (c) 2018, Technische Universität Berlin"
3 | __version__ = "0.1.0"
4 | __email__ = "gawlowicz@tkn.tu-berlin.de"
5 |
6 |
class Tcp(object):
    """Base class for TCP agents.

    Stores the observation and action spaces and defines the
    ``get_action`` hook that subclasses override.
    """

    def __init__(self):
        super().__init__()

    def set_spaces(self, obs, act):
        """Remember the observation space ``obs`` and action space ``act``."""
        self.obsSpace, self.actSpace = obs, act

    def get_action(self, obs, reward, done, info):
        """Return the action for one step; the base class returns ``None``."""
        return None
18 |
19 |
class TcpEventBased(Tcp):
    """Template agent for the event-based TCP environment (env type 0)."""

    def __init__(self):
        super(TcpEventBased, self).__init__()

    def get_action(self, obs, reward, done, info):
        """Return ``[new_ssThresh, new_cWnd]`` for one observation.

        The observation is unpacked into named locals so derived agents can
        see the layout; this template only uses the segment size and returns
        fixed multiples of it. ``reward``, ``done`` and ``info`` are unused.
        """
        # unique socket ID
        socketUuid = obs[0]
        # TCP env type: event-based = 0 / time-based = 1
        envType = obs[1]
        # sim time in us
        simTime_us = obs[2]
        # unique node ID
        nodeId = obs[3]
        # current ssThreshold
        ssThresh = obs[4]
        # current congestion window size
        cWnd = obs[5]
        # segment size
        segmentSize = obs[6]
        # number of acked segments
        segmentsAcked = obs[7]
        # estimated bytes in flight
        bytesInFlight = obs[8]
        # last estimation of RTT
        lastRtt_us = obs[9]
        # min value of RTT
        minRtt_us = obs[10]
        # function from Congestion Algorithm (CA) interface:
        #  GET_SS_THRESH = 0 (packet loss),
        #  INCREASE_WINDOW (packet acked),
        #  PKTS_ACKED (unused),
        #  CONGESTION_STATE_SET (unused),
        #  CWND_EVENT (unused),
        calledFunc = obs[11]
        # Congestion Algorithm (CA) state:
        #  CA_OPEN = 0,
        #  CA_DISORDER,
        #  CA_CWR,
        #  CA_RECOVERY,
        #  CA_LOSS,
        #  CA_LAST_STATE
        caState = obs[12]
        # Congestion Algorithm (CA) event:
        #  CA_EVENT_TX_START = 0,
        #  CA_EVENT_CWND_RESTART,
        #  CA_EVENT_COMPLETE_CWR,
        #  CA_EVENT_LOSS,
        #  CA_EVENT_ECN_NO_CE,
        #  CA_EVENT_ECN_IS_CE,
        #  CA_EVENT_DELAYED_ACK,
        #  CA_EVENT_NON_DELAYED_ACK,
        caEvent = obs[13]
        # ECN state:
        #  ECN_DISABLED = 0,
        #  ECN_IDLE,
        #  ECN_CE_RCVD,
        #  ECN_SENDING_ECE,
        #  ECN_ECE_RCVD,
        #  ECN_CWR_SENT
        ecnState = obs[14]

        # compute new values
        new_cWnd = 10 * segmentSize
        new_ssThresh = 5 * segmentSize

        # return actions
        # This assignment used to be commented out, which made the return
        # below raise NameError.
        actions = [new_ssThresh, new_cWnd]

        return actions
90 |
91 |
class TcpTimeBased(Tcp):
    """Template agent for the time-based TCP environment (env type 1)."""

    def __init__(self):
        super(TcpTimeBased, self).__init__()

    def get_action(self, obs, reward, done, info):
        """Return ``[new_ssThresh, new_cWnd]`` for one observation.

        The observation is unpacked into named locals so derived agents can
        see the layout; this template only uses the segment size and returns
        fixed multiples of it. ``reward``, ``done`` and ``info`` are unused.
        """
        # unique socket ID
        socketUuid = obs[0]
        # TCP env type: event-based = 0 / time-based = 1
        envType = obs[1]
        # sim time in us
        simTime_us = obs[2]
        # unique node ID
        nodeId = obs[3]
        # current ssThreshold
        ssThresh = obs[4]
        # current congestion window size
        cWnd = obs[5]
        # segment size
        segmentSize = obs[6]
        # bytesInFlightSum
        bytesInFlightSum = obs[7]
        # bytesInFlightAvg
        bytesInFlightAvg = obs[8]
        # segmentsAckedSum
        segmentsAckedSum = obs[9]
        # segmentsAckedAvg
        segmentsAckedAvg = obs[10]
        # avgRtt
        avgRtt = obs[11]
        # minRtt
        minRtt = obs[12]
        # avgInterTx
        avgInterTx = obs[13]
        # avgInterRx
        avgInterRx = obs[14]
        # throughput
        throughput = obs[15]

        # compute new values
        new_cWnd = 10 * segmentSize
        new_ssThresh = 5 * segmentSize

        # return actions
        # The previous list referenced new_cWnd2, whose assignment was
        # commented out, so every call raised NameError.
        # NOTE(review): order is [ssThresh, cWnd] - confirm against the
        # environment's ExecuteActions.
        actions = [new_ssThresh, new_cWnd]

        return actions
140 |
--------------------------------------------------------------------------------