├── BandwidthEstimator.py
├── BandwidthEstimator_gcc.py
├── README.md
├── azure-pipelines.yml
├── deep_rl
│   ├── actor_critic.py
│   └── ppo_agent.py
├── onnx-model.onnx
├── packet_info.py
├── packet_record.py
└── ppo_2021_07_25_04_57_11_with500trace.pth

--------------------------------------------------------------------------------
/BandwidthEstimator.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from deep_rl.ppo_agent import PPO
import torch
from packet_info import PacketInfo
from packet_record import PacketRecord
from BandwidthEstimator_gcc import GCCEstimator


class Estimator(object):
    def __init__(self, model_path="./ppo_2021_07_25_04_57_11_with500trace.pth", step_time=200):
        '''
        Import the existing pretrained model
        '''
        # 1. Define model-related parameters
        exploration_param = 0.1  # the std dev of the action distribution
        K_epochs = 37  # update the policy for K_epochs
        ppo_clip = 0.1  # clip parameter of PPO
        gamma = 0.99  # discount factor
        lr = 3e-5  # Adam parameters
        betas = (0.9, 0.999)
        self.state_dim = 6
        self.state_length = 10
        action_dim = 1
        # 2. Load the model (model_path defaults to the checkpoint shipped at the repository root)
        self.device = torch.device("cpu")
        self.ppo = PPO(self.state_dim, self.state_length, action_dim, exploration_param, lr, betas, gamma, K_epochs, ppo_clip)
        self.ppo.policy.load_state_dict(torch.load(model_path))
        self.packet_record = PacketRecord()
        self.packet_record.reset()
        self.step_time = step_time
        # 3. Initialization
        self.state = torch.zeros((1, self.state_dim, self.state_length))
        self.time_to_guide = False
        self.counter = 0
        self.bandwidth_prediction = 300000
        self.gcc_estimator = GCCEstimator()
        self.receiving_rate_list = []
        self.delay_list = []
        self.loss_ratio_list = []
        self.bandwidth_prediction_list = []
        self.overuse_flag = 'NORMAL'
        self.overuse_distance = 5
        self.last_overuse_cap = 1000000
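
The state initialized above is a 1 x 6 x 10 sliding window: each step shifts every feature row one slot to the left and writes the newest 200 ms statistics into the last column (see get_estimated_bandwidth below). A minimal standalone sketch of this update pattern, with made-up feature values:

import torch

state = torch.zeros((1, 6, 10))             # batch x features x history length
for step in range(3):
    state = torch.roll(state, -1, dims=-1)  # shift the window, dropping the oldest column
    state[0, 0, -1] = 0.1 * (step + 1)      # e.g. a normalized receiving rate
# the three newest values now occupy the last three columns
print(state[0, 0, -3:])                     # tensor([0.1000, 0.2000, 0.3000])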
    def report_states(self, stats: dict):
        '''
        stats is a dict with the following items
        {
            "send_time_ms": uint,
            "arrival_time_ms": uint,
            "payload_type": int,
            "sequence_number": uint,
            "ssrc": int,
            "padding_length": uint,
            "header_length": uint,
            "payload_size": uint
        }
        '''
        packet_info = PacketInfo()
        packet_info.payload_type = stats["payload_type"]
        packet_info.ssrc = stats["ssrc"]
        packet_info.sequence_number = stats["sequence_number"]
        packet_info.send_timestamp = stats["send_time_ms"]
        packet_info.receive_timestamp = stats["arrival_time_ms"]
        packet_info.padding_length = stats["padding_length"]
        packet_info.header_length = stats["header_length"]
        packet_info.payload_size = stats["payload_size"]
        packet_info.bandwidth_prediction = self.bandwidth_prediction

        self.packet_record.on_receive(packet_info)
        self.gcc_estimator.report_states(stats)

    def get_estimated_bandwidth(self) -> int:
        '''
        Calculate the estimated bandwidth
        '''
        # 1. Calculate the state
        self.receiving_rate = self.packet_record.calculate_receiving_rate(interval=self.step_time)
        self.receiving_rate_list.append(self.receiving_rate)
        self.delay = self.packet_record.calculate_average_delay(interval=self.step_time)
        self.delay_list.append(self.delay)

        self.loss_ratio = self.packet_record.calculate_loss_ratio(interval=self.step_time)
        self.loss_ratio_list.append(self.loss_ratio)

        self.gcc_decision, self.overuse_flag = self.gcc_estimator.get_estimated_bandwidth()
        if self.overuse_flag == 'OVERUSE':
            self.overuse_distance = 0
            self.last_overuse_cap = self.receiving_rate
        else:
            self.overuse_distance += 1
        self.state = self.state.clone().detach()
        self.state = torch.roll(self.state, -1, dims=-1)

        self.state[0, 0, -1] = self.receiving_rate / 6000000.0
        self.state[0, 1, -1] = self.delay / 1000.0
        self.state[0, 2, -1] = self.loss_ratio
        self.state[0, 3, -1] = self.bandwidth_prediction / 6000000.0
        self.state[0, 4, -1] = self.overuse_distance / 100.0
        self.state[0, 5, -1] = self.last_overuse_cap / 6000000.0

        if len(self.receiving_rate_list) == self.state_length:
            self.receiving_rate_list.pop(0)
            self.delay_list.pop(0)
            self.loss_ratio_list.pop(0)

        self.counter += 1

        if self.counter % 4 == 0:
            self.time_to_guide = True
            self.counter = 0

        # 2. The RL agent tunes the bandwidth estimated by the heuristic scheme
        if self.time_to_guide:
            with torch.no_grad():  # inference only, no gradients needed
                action, _, _, _ = self.ppo.policy.forward(self.state)
            self.bandwidth_prediction = self.gcc_decision * pow(2, (2 * action.item() - 1))
            self.gcc_estimator.change_bandwidth_estimation(self.bandwidth_prediction)
            self.time_to_guide = False
        else:
            self.bandwidth_prediction = self.gcc_decision

        return self.bandwidth_prediction
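
Two calls drive the estimator: report_states() once per received packet and get_estimated_bandwidth() once per step. Because action_mean is squashed by a sigmoid into (0, 1), the factor pow(2, 2 * action - 1) scales the GCC decision by between 0.5x and 2x. A minimal driver sketch, assuming the default checkpoint is present at the repository root; the stats dict below is fabricated:

from BandwidthEstimator import Estimator

estimator = Estimator()
fake_stats = {
    "send_time_ms": 0, "arrival_time_ms": 5, "payload_type": 126,
    "sequence_number": 1, "ssrc": 12345, "padding_length": 0,
    "header_length": 12, "payload_size": 1000,
}
estimator.report_states(fake_stats)         # called for every received packet
bwe = estimator.get_estimated_bandwidth()   # called once per 200 ms step
print(bwe)                                  # estimated bandwidth in bps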
--------------------------------------------------------------------------------
/BandwidthEstimator_gcc.py:
--------------------------------------------------------------------------------
import collections

kMinNumDeltas = 60
threshold_gain_ = 4
kBurstIntervalMs = 5
kTrendlineWindowSize = 20
kTrendlineSmoothingCoeff = 0.9
kOverUsingTimeThreshold = 10
kMaxAdaptOffsetMs = 15.0
eta = 1.08
alpha = 0.85
k_up_ = 0.0087
k_down_ = 0.039
Time_Interval = 200


class GCCEstimator(object):
    def __init__(self):
        self.packets_list = []
        self.packet_group = []
        self.first_group_complete_time = -1

        self.acc_delay = 0
        self.smoothed_delay = 0
        self.acc_delay_list = collections.deque([])
        self.smoothed_delay_list = collections.deque([])

        self.state = 'Hold'
        self.last_bandwidth_estimation = 300 * 1000
        self.avg_max_bitrate_kbps_ = -1
        self.var_max_bitrate_kbps_ = -1
        self.rate_control_region_ = "kRcMaxUnknown"
        self.time_last_bitrate_change_ = -1

        self.gamma1 = 12.5
        self.num_of_deltas_ = 0
        self.time_over_using = -1
        self.prev_trend = 0.0
        self.overuse_counter = 0
        self.overuse_flag = 'NORMAL'
        self.last_update_ms = -1
        self.last_update_threshold_ms = -1
        self.now_ms = -1

    # reset the estimator according to rtc_env_gcc
    def reset(self):
        self.packets_list = []
        self.packet_group = []
        self.first_group_complete_time = -1

        self.acc_delay = 0
        self.smoothed_delay = 0
        self.acc_delay_list = collections.deque([])
        self.smoothed_delay_list = collections.deque([])

        self.state = 'Hold'
        self.last_bandwidth_estimation = 300 * 1000
        self.avg_max_bitrate_kbps_ = -1
        self.var_max_bitrate_kbps_ = -1
        self.rate_control_region_ = "kRcMaxUnknown"
        self.time_last_bitrate_change_ = -1

        self.gamma1 = 12.5
        self.num_of_deltas_ = 0
        self.time_over_using = -1
        self.prev_trend = 0.0
        self.overuse_counter = 0
        self.overuse_flag = 'NORMAL'
        self.last_update_ms = -1
        self.last_update_threshold_ms = -1
        self.now_ms = -1

    def report_states(self, stats: dict):
        '''
        Store the header information of every packet received within the last 200 ms in packets_list
        '''
        pkt = stats
        packet_info = PacketInfo()
        packet_info.payload_type = pkt["payload_type"]
        packet_info.ssrc = pkt["ssrc"]
        packet_info.sequence_number = pkt["sequence_number"]
        packet_info.send_timestamp = pkt["send_time_ms"]
        packet_info.receive_timestamp = pkt["arrival_time_ms"]
        packet_info.padding_length = pkt["padding_length"]
        packet_info.header_length = pkt["header_length"]
        packet_info.payload_size = pkt["payload_size"]
        packet_info.size = pkt["header_length"] + pkt["payload_size"] + pkt["padding_length"]
        packet_info.bandwidth_prediction = self.last_bandwidth_estimation
        self.now_ms = packet_info.receive_timestamp  # use the arrival time of the last packet as the system time

        self.packets_list.append(packet_info)

    def get_estimated_bandwidth(self) -> int:
        '''
        Calculate the estimated bandwidth
        '''
        BWE_by_delay, flag = self.get_estimated_bandwidth_by_delay()
        BWE_by_loss = self.get_estimated_bandwidth_by_loss()
        bandwidth_estimation = min(BWE_by_delay, BWE_by_loss)
        if flag:
            self.packets_list = []
        self.last_bandwidth_estimation = bandwidth_estimation
        return bandwidth_estimation, self.overuse_flag

    def get_inner_estimation(self):
        BWE_by_delay, flag = self.get_estimated_bandwidth_by_delay()
        BWE_by_loss = self.get_estimated_bandwidth_by_loss()
        bandwidth_estimation = min(BWE_by_delay, BWE_by_loss)
        if flag:
            self.packets_list = []
        return BWE_by_delay, BWE_by_loss

    def change_bandwidth_estimation(self, bandwidth_prediction):
        self.last_bandwidth_estimation = bandwidth_prediction

    def get_estimated_bandwidth_by_delay(self):
        '''
        Bandwidth estimation based on delay
        '''
        if len(self.packets_list) == 0:  # no packet is received within this time interval
            return self.last_bandwidth_estimation, False

        # 1. Divide packets into groups
        pkt_group_list = self.divide_packet_group()
        if len(pkt_group_list) < 2:
            return self.last_bandwidth_estimation, False

        # 2. Calculate the inter-group delay gradients
        send_time_delta_list, _, _, delay_gradient_list = self.compute_deltas_for_pkt_group(pkt_group_list)

        # 3. Calculate the trendline
        trendline = self.trendline_filter(delay_gradient_list, pkt_group_list)
        if trendline is None:
            return self.last_bandwidth_estimation, False

        # 4. Determine the current network status
        self.overuse_detector(trendline, sum(send_time_delta_list))

        # 5. Determine the direction of the bandwidth adjustment
        state = self.ChangeState()

        # 6. Adjust the estimated bandwidth
        bandwidth_estimation = self.rate_adaptation_by_delay(state)

        return bandwidth_estimation, True
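
The quantity filtered in steps 2 and 3 is the inter-group delay gradient d(i) = (arrival_i - arrival_{i-1}) - (send_i - send_{i-1}), which is positive when packet groups start queuing. A toy calculation, independent of the classes in this file:

# two consecutive packet groups: last send time and last arrival time, in ms
prev_send, prev_arrival = 100, 140
curr_send, curr_arrival = 105, 152
send_time_delta = curr_send - prev_send            # 5 ms
arrival_time_delta = curr_arrival - prev_arrival   # 12 ms
delay_gradient = arrival_time_delta - send_time_delta
print(delay_gradient)  # 7 ms -> one-way delay is growing, a sign of queuing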
    def get_estimated_bandwidth_by_loss(self) -> int:
        '''
        Bandwidth estimation based on packet loss
        '''
        loss_rate = self.calculate_loss_rate()
        if loss_rate == -1:
            return self.last_bandwidth_estimation

        bandwidth_estimation = self.rate_adaptation_by_loss(loss_rate)
        return bandwidth_estimation

    def calculate_loss_rate(self):
        '''
        Calculate the packet loss rate in this time interval
        '''
        flag = False
        valid_packets_num = 0
        min_sequence_number, max_sequence_number = 0, 0
        if len(self.packets_list) == 0:  # no packet is received within this time interval
            return -1
        for i in range(len(self.packets_list)):
            if self.packets_list[i].payload_type == 126:
                if not flag:
                    min_sequence_number = self.packets_list[i].sequence_number
                    max_sequence_number = self.packets_list[i].sequence_number
                    flag = True
                valid_packets_num += 1
                min_sequence_number = min(min_sequence_number, self.packets_list[i].sequence_number)
                max_sequence_number = max(max_sequence_number, self.packets_list[i].sequence_number)
        if (max_sequence_number - min_sequence_number) == 0:
            return -1
        receive_rate = valid_packets_num / (max_sequence_number - min_sequence_number)
        loss_rate = 1 - receive_rate
        return loss_rate

    def rate_adaptation_by_loss(self, loss_rate) -> int:
        '''
        Calculate the bandwidth estimation based on the packet loss rate
        '''
        bandwidth_estimation = self.last_bandwidth_estimation
        if loss_rate > 0.1:
            bandwidth_estimation = self.last_bandwidth_estimation * (1 - 0.5 * loss_rate)
        elif loss_rate < 0.02:
            bandwidth_estimation = 1.05 * self.last_bandwidth_estimation
        return bandwidth_estimation

    def divide_packet_group(self):
        '''
        Divide packets into burst groups by send time
        '''
        pkt_group_list = []
        first_send_time_in_group = self.packets_list[0].send_timestamp

        pkt_group = [self.packets_list[0]]
        for pkt in self.packets_list[1:]:
            if pkt.send_timestamp - first_send_time_in_group <= kBurstIntervalMs:
                pkt_group.append(pkt)
            else:
                pkt_group_list.append(PacketGroup(pkt_group))
                if self.first_group_complete_time == -1:
                    self.first_group_complete_time = pkt_group[-1].receive_timestamp
                first_send_time_in_group = pkt.send_timestamp
                pkt_group = [pkt]
        # note: the trailing, possibly incomplete burst group is not appended
        return pkt_group_list

    def compute_deltas_for_pkt_group(self, pkt_group_list):
        '''
        Calculate the time deltas between consecutive packet groups
        '''
        send_time_delta_list, arrival_time_delta_list, group_size_delta_list, delay_gradient_list = [], [], [], []
        for idx in range(1, len(pkt_group_list)):
            send_time_delta = pkt_group_list[idx].send_time_list[-1] - pkt_group_list[idx - 1].send_time_list[-1]
            arrival_time_delta = pkt_group_list[idx].arrival_time_list[-1] - pkt_group_list[idx - 1].arrival_time_list[-1]
            group_size_delta = pkt_group_list[idx].pkt_group_size - pkt_group_list[idx - 1].pkt_group_size
            delay = arrival_time_delta - send_time_delta
            self.num_of_deltas_ += 1
            send_time_delta_list.append(send_time_delta)
            arrival_time_delta_list.append(arrival_time_delta)
            group_size_delta_list.append(group_size_delta)
            delay_gradient_list.append(delay)

        return send_time_delta_list, arrival_time_delta_list, group_size_delta_list, delay_gradient_list
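
The loss-based controller above is a simple threshold rule: cut the rate when loss exceeds 10%, probe upward by 5% when loss is under 2%, and hold in between. A tiny standalone version for illustration:

def rate_from_loss(last_bwe_bps, loss_rate):
    # mirrors rate_adaptation_by_loss: cut on heavy loss, probe on low loss
    if loss_rate > 0.1:
        return last_bwe_bps * (1 - 0.5 * loss_rate)
    if loss_rate < 0.02:
        return 1.05 * last_bwe_bps
    return last_bwe_bps

print(rate_from_loss(1_000_000, 0.20))  # 900000.0  (20% loss -> cut by 10%)
print(rate_from_loss(1_000_000, 0.01))  # 1050000.0 (low loss -> probe +5%)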
    def trendline_filter(self, delay_gradient_list, pkt_group_list):
        '''
        Calculate the trendline from the delay gradients of the packet groups
        '''
        for i, delay_gradient in enumerate(delay_gradient_list):
            accumulated_delay = self.acc_delay + delay_gradient
            smoothed_delay = kTrendlineSmoothingCoeff * self.smoothed_delay + (
                    1 - kTrendlineSmoothingCoeff) * accumulated_delay

            self.acc_delay = accumulated_delay
            self.smoothed_delay = smoothed_delay

            arrival_time_ms = pkt_group_list[i + 1].complete_time
            self.acc_delay_list.append(arrival_time_ms - self.first_group_complete_time)

            self.smoothed_delay_list.append(smoothed_delay)
            if len(self.acc_delay_list) > kTrendlineWindowSize:
                self.acc_delay_list.popleft()
                self.smoothed_delay_list.popleft()
        if len(self.acc_delay_list) == kTrendlineWindowSize:
            avg_acc_delay = sum(self.acc_delay_list) / len(self.acc_delay_list)
            avg_smoothed_delay = sum(self.smoothed_delay_list) / len(self.smoothed_delay_list)

            numerator = 0
            denominator = 0
            for i in range(kTrendlineWindowSize):
                numerator += (self.acc_delay_list[i] - avg_acc_delay) * (
                        self.smoothed_delay_list[i] - avg_smoothed_delay)
                denominator += (self.acc_delay_list[i] - avg_acc_delay) * (self.acc_delay_list[i] - avg_acc_delay)

            trendline = numerator / (denominator + 1e-05)
        else:
            trendline = None
            self.acc_delay_list.clear()
            self.smoothed_delay_list.clear()
            self.acc_delay = 0
            self.smoothed_delay = 0
        return trendline

    def overuse_detector(self, trendline, ts_delta):
        """
        Determine the current network status
        """
        now_ms = self.now_ms
        if self.num_of_deltas_ < 2:
            return

        modified_trend = trendline * min(self.num_of_deltas_, kMinNumDeltas) * threshold_gain_

        if modified_trend > self.gamma1:
            if self.time_over_using == -1:
                self.time_over_using = ts_delta / 2
            else:
                self.time_over_using += ts_delta
            self.overuse_counter += 1
            if self.time_over_using > kOverUsingTimeThreshold and self.overuse_counter > 1:
                if trendline > self.prev_trend:
                    self.time_over_using = 0
                    self.overuse_counter = 0
                    self.overuse_flag = 'OVERUSE'
        elif modified_trend < -self.gamma1:
            self.time_over_using = -1
            self.overuse_counter = 0
            self.overuse_flag = 'UNDERUSE'
        else:
            self.time_over_using = -1
            self.overuse_counter = 0
            self.overuse_flag = 'NORMAL'

        self.prev_trend = trendline
        self.update_threshold(modified_trend, now_ms)

    def update_threshold(self, modified_trend, now_ms):
        '''
        Update the threshold for determining overuse
        '''
        if self.last_update_threshold_ms == -1:
            self.last_update_threshold_ms = now_ms
        if abs(modified_trend) > self.gamma1 + kMaxAdaptOffsetMs:
            self.last_update_threshold_ms = now_ms
            return
        if abs(modified_trend) < self.gamma1:
            k = k_down_
        else:
            k = k_up_
        kMaxTimeDeltaMs = 100
        time_delta_ms = min(now_ms - self.last_update_threshold_ms, kMaxTimeDeltaMs)
        self.gamma1 += k * (abs(modified_trend) - self.gamma1) * time_delta_ms
        if self.gamma1 < 6:
            self.gamma1 = 6
        elif self.gamma1 > 600:
            self.gamma1 = 600
        self.last_update_threshold_ms = now_ms
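
The trendline computed above is the least-squares slope of the smoothed accumulated delay against the relative arrival time. Up to the 1e-05 regularizer in the denominator, numpy.polyfit over the same window yields the same value, which makes a handy sanity check (the sample values below are fabricated):

import numpy as np

acc_delay_ms = [0, 10, 20, 30, 40]              # relative arrival times (x)
smoothed_delay_ms = [0.0, 1.0, 2.5, 3.9, 5.2]   # smoothed accumulated delay (y)
slope = np.polyfit(acc_delay_ms, smoothed_delay_ms, 1)[0]
print(slope)  # ~0.133; the detector compares a scaled version against gamma1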
    def state_transfer(self):
        '''
        Update the direction of the estimated bandwidth adjustment
        (note: not called in this file; get_estimated_bandwidth_by_delay uses ChangeState below)
        '''
        newstate = None
        overuse_flag = self.overuse_flag
        if self.state == 'Decrease' and overuse_flag == 'OVERUSE':
            newstate = 'Decrease'
        elif self.state == 'Decrease' and (overuse_flag == 'NORMAL' or overuse_flag == 'UNDERUSE'):
            newstate = 'Hold'
        elif self.state == 'Hold' and overuse_flag == 'OVERUSE':
            newstate = 'Decrease'
        elif self.state == 'Hold' and overuse_flag == 'NORMAL':
            newstate = 'Increase'
        elif self.state == 'Hold' and overuse_flag == 'UNDERUSE':
            newstate = 'Hold'
        elif self.state == 'Increase' and overuse_flag == 'OVERUSE':
            newstate = 'Decrease'
        elif self.state == 'Increase' and overuse_flag == 'NORMAL':
            newstate = 'Increase'
        elif self.state == 'Increase' and overuse_flag == 'UNDERUSE':
            newstate = 'Hold'
        else:
            print('Wrong state!')
        self.state = newstate
        return newstate

    def ChangeState(self):
        overuse_flag = self.overuse_flag
        if overuse_flag == 'NORMAL':
            if self.state == 'Hold':
                self.state = 'Increase'
        elif overuse_flag == 'OVERUSE':
            if self.state != 'Decrease':
                self.state = 'Decrease'
        elif overuse_flag == 'UNDERUSE':
            self.state = 'Hold'
        return self.state

    def rate_adaptation_by_delay(self, state):
        '''
        Determine the final bandwidth estimation
        '''
        estimated_throughput = 0
        for pkt in self.packets_list:
            estimated_throughput += pkt.size
        if len(self.packets_list) == 0:
            estimated_throughput_bps = 0
        else:
            time_delta = self.now_ms - self.packets_list[0].receive_timestamp
            time_delta = max(time_delta, Time_Interval)
            estimated_throughput_bps = 1000 * 8 * estimated_throughput / time_delta
        estimated_throughput_kbps = estimated_throughput_bps / 1000

        throughput_based_limit = 3 * estimated_throughput_bps + 10
        # Calculate the standard deviation of the maximum throughput
        self.UpdateMaxThroughputEstimate(estimated_throughput_kbps)
        std_max_bit_rate = pow(self.var_max_bitrate_kbps_ * self.avg_max_bitrate_kbps_, 0.5)

        if state == 'Increase':
            if self.avg_max_bitrate_kbps_ >= 0 and \
                    estimated_throughput_kbps > self.avg_max_bitrate_kbps_ + 3 * std_max_bit_rate:
                self.avg_max_bitrate_kbps_ = -1.0
                self.rate_control_region_ = "kRcMaxUnknown"

            if self.rate_control_region_ == "kRcNearMax":
                # Already close to the maximum: additive increase
                additive_increase_bps = self.AdditiveRateIncrease(self.now_ms, self.time_last_bitrate_change_)
                bandwidth_estimation = self.last_bandwidth_estimation + additive_increase_bps
            elif self.rate_control_region_ == "kRcMaxUnknown":
                # Maximum value unknown: multiplicative increase
                multiplicative_increase_bps = self.MultiplicativeRateIncrease(self.now_ms,
                                                                              self.time_last_bitrate_change_)
                bandwidth_estimation = self.last_bandwidth_estimation + multiplicative_increase_bps
            else:
                print("error!")
            bandwidth_estimation = min(bandwidth_estimation, throughput_based_limit)
            self.time_last_bitrate_change_ = self.now_ms
        elif state == 'Decrease':
            beta = 0.85
            bandwidth_estimation = beta * estimated_throughput_bps + 0.5
            if bandwidth_estimation > self.last_bandwidth_estimation:
                if self.rate_control_region_ != "kRcMaxUnknown":
                    bandwidth_estimation = (beta * self.avg_max_bitrate_kbps_ * 1000 + 0.5)
                bandwidth_estimation = min(bandwidth_estimation, self.last_bandwidth_estimation)
            self.rate_control_region_ = "kRcNearMax"

            if estimated_throughput_kbps < self.avg_max_bitrate_kbps_ - 3 * std_max_bit_rate:
                self.avg_max_bitrate_kbps_ = -1
            self.UpdateMaxThroughputEstimate(estimated_throughput_kbps)
            self.state = 'Hold'
            self.time_last_bitrate_change_ = self.now_ms
        elif state == 'Hold':
            bandwidth_estimation = self.last_bandwidth_estimation
        else:
            print('Wrong State!')
        return bandwidth_estimation
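
ChangeState above implements the usual GCC finite-state machine. The same transitions can be written as a lookup table, which makes the behavior easier to audit (a rewrite for illustration, not used by the estimator):

# (current_state, overuse_signal) -> next_state, mirroring ChangeState above
TRANSITIONS = {
    ('Hold', 'NORMAL'): 'Increase',
    ('Hold', 'OVERUSE'): 'Decrease',
    ('Hold', 'UNDERUSE'): 'Hold',
    ('Increase', 'NORMAL'): 'Increase',
    ('Increase', 'OVERUSE'): 'Decrease',
    ('Increase', 'UNDERUSE'): 'Hold',
    # ChangeState leaves 'Decrease' unchanged on NORMAL; rate_adaptation_by_delay
    # resets the state to 'Hold' after every decrease anyway
    ('Decrease', 'NORMAL'): 'Decrease',
    ('Decrease', 'OVERUSE'): 'Decrease',
    ('Decrease', 'UNDERUSE'): 'Hold',
}

def change_state(state, signal):
    return TRANSITIONS[(state, signal)]

print(change_state('Hold', 'OVERUSE'))  # Decrease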
    def AdditiveRateIncrease(self, now_ms, last_ms):
        """
        Implementation of the additive rate growth algorithm
        """
        sum_packet_size, avg_packet_size = 0, 0
        for pkt in self.packets_list:
            sum_packet_size += pkt.size
        avg_packet_size = 8 * sum_packet_size / len(self.packets_list)

        beta = 0.0
        RTT = 2 * (self.packets_list[-1].receive_timestamp - self.packets_list[-1].send_timestamp)  # computed but not used below
        response_time = 200

        if last_ms > 0:
            beta = min(((now_ms - last_ms) / response_time), 1.0)
        additive_increase_bps = max(800, beta * avg_packet_size)
        return additive_increase_bps

    def MultiplicativeRateIncrease(self, now_ms, last_ms):
        """
        Implementation of the multiplicative rate growth algorithm
        """
        alpha = 1.08
        if last_ms > -1:
            time_since_last_update_ms = min(now_ms - last_ms, 1000)
            alpha = pow(alpha, time_since_last_update_ms / 1000)
        multiplicative_increase_bps = max(self.last_bandwidth_estimation * (alpha - 1.0), 1000.0)
        return multiplicative_increase_bps

    def UpdateMaxThroughputEstimate(self, estimated_throughput_kbps):
        """
        Update the estimate of the maximum throughput and its variance
        """
        alpha = 0.05
        if self.avg_max_bitrate_kbps_ == -1:
            self.avg_max_bitrate_kbps_ = estimated_throughput_kbps
        else:
            self.avg_max_bitrate_kbps_ = (1 - alpha) * self.avg_max_bitrate_kbps_ + alpha * estimated_throughput_kbps
        norm = max(self.avg_max_bitrate_kbps_, 1.0)
        var_value = pow((self.avg_max_bitrate_kbps_ - estimated_throughput_kbps), 2) / norm
        self.var_max_bitrate_kbps_ = (1 - alpha) * self.var_max_bitrate_kbps_ + alpha * var_value
        if self.var_max_bitrate_kbps_ < 0.4:
            self.var_max_bitrate_kbps_ = 0.4
        if self.var_max_bitrate_kbps_ > 2.5:
            self.var_max_bitrate_kbps_ = 2.5


class PacketInfo:
    def __init__(self):
        self.payload_type = None
        self.sequence_number = None  # int
        self.send_timestamp = None  # int, ms
        self.ssrc = None  # int
        self.padding_length = None  # int, B
        self.header_length = None  # int, B
        self.receive_timestamp = None  # int, ms
        self.payload_size = None  # int, B
        self.size = None  # int, B (header + payload + padding), filled in report_states
        self.bandwidth_prediction = None  # int, bps


class PacketGroup:
    def __init__(self, pkt_group):
        self.pkts = pkt_group
        self.arrival_time_list = [pkt.receive_timestamp for pkt in pkt_group]
        self.send_time_list = [pkt.send_timestamp for pkt in pkt_group]
        self.pkt_group_size = sum([pkt.size for pkt in pkt_group])
        self.pkt_num_in_group = len(pkt_group)
        self.complete_time = self.arrival_time_list[-1]
        self.transfer_duration = self.arrival_time_list[-1] - self.arrival_time_list[0]
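
For a feel of the increase schedule: MultiplicativeRateIncrease compounds a factor of 1.08 per second, scaled by the time since the last rate change, so a 200 ms step grows the rate by roughly 1.55%. A worked check with fabricated inputs:

last_bwe_bps = 300_000
time_since_last_update_ms = 200
alpha = pow(1.08, time_since_last_update_ms / 1000)     # ~1.0155
increase_bps = max(last_bwe_bps * (alpha - 1.0), 1000.0)
print(round(increase_bps))  # ~4653 bps added this step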
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HRCC

HRCC is a hybrid bandwidth estimator for real-time communications: a GCC-style heuristic produces a base estimate and a PPO agent periodically tunes it (see BandwidthEstimator.py).
--------------------------------------------------------------------------------
/azure-pipelines.yml:
--------------------------------------------------------------------------------
# Starter pipeline
# Start with a minimal pipeline that you can customize to build and deploy your code.
# Add steps that build, run tests, deploy, and more:
# https://aka.ms/yaml

trigger:
  branches:
    include:
    - "*"
    exclude:
    - upstream/*


pool:
  vmImage: 'ubuntu-latest'

steps:

- checkout: self

- script: |
    sudo apt install -y python3 python3-pip
    python3 -m pip install torch pytest
  displayName: "Install runtime dependencies"

- script: |
    python3 -m pytest tests
  displayName: "Run test"

- script: docker pull opennetlab.azurecr.io/challenge-env
  displayName: 'Get challenge environment'

- script: |
    wget https://raw.githubusercontent.com/OpenNetLab/AlphaRTC/main/examples/peerconnection/serverless/corpus/receiver_pyinfer.json -O receiver_pyinfer.json
    wget https://raw.githubusercontent.com/OpenNetLab/AlphaRTC/main/examples/peerconnection/serverless/corpus/sender_pyinfer.json -O sender_pyinfer.json
    mkdir testmedia
    wget https://github.com/OpenNetLab/AlphaRTC/raw/main/examples/peerconnection/serverless/corpus/testmedia/test.wav -O testmedia/test.wav
    wget https://raw.githubusercontent.com/OpenNetLab/AlphaRTC/main/examples/peerconnection/serverless/corpus/testmedia/test.yuv -O testmedia/test.yuv
  displayName: 'Fetch corpus'

- script: docker run -d --rm -v `pwd`:/app -w /app --name alphartc_pyinfer opennetlab.azurecr.io/challenge-env peerconnection_serverless receiver_pyinfer.json
    && docker exec alphartc_pyinfer peerconnection_serverless sender_pyinfer.json
  displayName: 'Run this example'
--------------------------------------------------------------------------------
/deep_rl/actor_critic.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from torch import nn
from torch.distributions import MultivariateNormal
import torch.nn.functional as F

import numpy as np


class ActorCritic(nn.Module):
    def __init__(self, state_dim, state_length, action_dim, exploration_param=0.05, device="cpu"):
        super(ActorCritic, self).__init__()

        self.layer1_shape = 128
        self.layer2_shape = 128
        self.numFcInput = 6144  # 6 features x 128 channels x (state_length - 2)

        # one Conv1d per state feature for the actor
        self.rConv1d = nn.Conv1d(1, self.layer1_shape, 3)
        self.dConv1d = nn.Conv1d(1, self.layer1_shape, 3)
        self.lConv1d = nn.Conv1d(1, self.layer1_shape, 3)
        self.pConv1d = nn.Conv1d(1, self.layer1_shape, 3)
        self.oConv1d = nn.Conv1d(1, self.layer1_shape, 3)
        self.cConv1d = nn.Conv1d(1, self.layer1_shape, 3)

        # separate feature extractors for the critic
        self.rConv1d_critic = nn.Conv1d(1, self.layer1_shape, 3)
        self.dConv1d_critic = nn.Conv1d(1, self.layer1_shape, 3)
        self.lConv1d_critic = nn.Conv1d(1, self.layer1_shape, 3)
        self.pConv1d_critic = nn.Conv1d(1, self.layer1_shape, 3)
        self.oConv1d_critic = nn.Conv1d(1, self.layer1_shape, 3)
        self.cConv1d_critic = nn.Conv1d(1, self.layer1_shape, 3)

        self.fc = nn.Linear(self.numFcInput, self.layer2_shape)
        self.actor_output = nn.Linear(self.layer2_shape, action_dim)
        self.critic_output = nn.Linear(self.layer2_shape, 1)
        self.device = device
        self.action_var = torch.full((action_dim,), exploration_param**2).to(self.device)
        self.random_action = False

    def forward(self, inputs):
        # actor
        receivingConv = F.relu(self.rConv1d(inputs[:, 0:1, :]), inplace=True)
        delayConv = F.relu(self.dConv1d(inputs[:, 1:2, :]), inplace=True)
        lossConv = F.relu(self.lConv1d(inputs[:, 2:3, :]), inplace=True)
        predictionConv = F.relu(self.pConv1d(inputs[:, 3:4, :]), inplace=True)
        overusedisConv = F.relu(self.oConv1d(inputs[:, 4:5, :]), inplace=True)
        overusecapConv = F.relu(self.cConv1d(inputs[:, 5:6, :]), inplace=True)
        receiving_flatten = receivingConv.view(receivingConv.shape[0], -1)
        delay_flatten = delayConv.view(delayConv.shape[0], -1)
        loss_flatten = lossConv.view(lossConv.shape[0], -1)
        prediction_flatten = predictionConv.view(predictionConv.shape[0], -1)
        overusedis_flatten = overusedisConv.view(overusedisConv.shape[0], -1)
        overusecap_flatten = overusecapConv.view(overusecapConv.shape[0], -1)

        merge = torch.cat([receiving_flatten, delay_flatten, loss_flatten, prediction_flatten, overusedis_flatten, overusecap_flatten], 1)
        fcOut = F.relu(self.fc(merge), inplace=True)
        action_mean = torch.sigmoid(self.actor_output(fcOut))
        cov_mat = torch.diag(self.action_var).to(self.device)
        dist = MultivariateNormal(action_mean, cov_mat)
        if not self.random_action:
            action = action_mean
        else:
            action = dist.sample()
        action_logprobs = dist.log_prob(action)
        # critic
        receivingConv_critic = F.relu(self.rConv1d_critic(inputs[:, 0:1, :]), inplace=True)
        delayConv_critic = F.relu(self.dConv1d_critic(inputs[:, 1:2, :]), inplace=True)
        lossConv_critic = F.relu(self.lConv1d_critic(inputs[:, 2:3, :]), inplace=True)
        predictionConv_critic = F.relu(self.pConv1d_critic(inputs[:, 3:4, :]), inplace=True)
        overusedisConv_critic = F.relu(self.oConv1d_critic(inputs[:, 4:5, :]), inplace=True)
        overusecapConv_critic = F.relu(self.cConv1d_critic(inputs[:, 5:6, :]), inplace=True)
        receiving_flatten_critic = receivingConv_critic.view(receivingConv_critic.shape[0], -1)
        delay_flatten_critic = delayConv_critic.view(delayConv_critic.shape[0], -1)
        loss_flatten_critic = lossConv_critic.view(lossConv_critic.shape[0], -1)
        prediction_flatten_critic = predictionConv_critic.view(predictionConv_critic.shape[0], -1)
        overusedis_flatten_critic = overusedisConv_critic.view(overusedisConv_critic.shape[0], -1)
        overusecap_flatten_critic = overusecapConv_critic.view(overusecapConv_critic.shape[0], -1)
        merge_critic = torch.cat([receiving_flatten_critic, delay_flatten_critic, loss_flatten_critic, prediction_flatten_critic, overusedis_flatten_critic, overusecap_flatten_critic], 1)

        fcOut_critic = F.relu(self.fc(merge_critic), inplace=True)  # note: shares self.fc with the actor
        value = self.critic_output(fcOut_critic)

        return action.detach(), action_logprobs, value, action_mean

    def evaluate(self, state, action):
        _, _, value, action_mean = self.forward(state)
        cov_mat = torch.diag(self.action_var).to(self.device)
        dist = MultivariateNormal(action_mean, cov_mat)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()

        return action_logprobs, torch.squeeze(value), dist_entropy
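
numFcInput = 6144 follows directly from the shapes: each Conv1d(1, 128, 3) maps a length-10 feature history to length 8, giving 128 x 8 = 1024 values per feature and 6 x 1024 = 6144 after concatenation. A quick shape check, run from the repository root with the same 1 x 6 x 10 state the estimator uses:

import torch
from deep_rl.actor_critic import ActorCritic

net = ActorCritic(state_dim=6, state_length=10, action_dim=1)
action, logprob, value, action_mean = net.forward(torch.zeros((1, 6, 10)))
print(action.shape, value.shape)  # torch.Size([1, 1]) torch.Size([1, 1])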
--------------------------------------------------------------------------------
/deep_rl/ppo_agent.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import numpy as np
import time
from torch import nn
from .actor_critic import ActorCritic


class PPO:
    def __init__(self, state_dim, state_length, action_dim, exploration_param, lr, betas, gamma, ppo_epoch, ppo_clip, use_gae=False):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.ppo_clip = ppo_clip
        self.ppo_epoch = ppo_epoch
        self.use_gae = use_gae

        self.device = torch.device("cpu")
        self.policy = ActorCritic(state_dim, state_length, action_dim, exploration_param, self.device).to(self.device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

        self.policy_old = ActorCritic(state_dim, state_length, action_dim, exploration_param, self.device).to(self.device)
        self.policy_old.load_state_dict(self.policy.state_dict())

    def select_action(self, state, storage):
        state = torch.FloatTensor(state).to(self.device)
        action, action_logprobs, value, action_mean = self.policy_old.forward(state)

        storage.logprobs.append(action_logprobs)
        storage.values.append(value)
        storage.states.append(state)
        storage.actions.append(action)
        return action

    def get_value(self, state):
        action, action_logprobs, value, action_mean = self.policy_old.forward(state)
        return value

    def update(self, storage, state):
        episode_policy_loss = 0
        episode_value_loss = 0
        if self.use_gae:
            raise NotImplementedError
        advantages = (torch.tensor(storage.returns) - torch.tensor(storage.values)).detach()
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        old_states = torch.squeeze(torch.stack(storage.states).to(self.device), 1).detach()
        old_actions = torch.squeeze(torch.stack(storage.actions).to(self.device), 1).detach()
        old_action_logprobs = torch.squeeze(torch.stack(storage.logprobs), 1).to(self.device).detach()
        old_returns = torch.squeeze(torch.stack(storage.returns), 1).to(self.device).detach()

        for t in range(self.ppo_epoch):
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
            ratios = torch.exp(logprobs - old_action_logprobs)

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.ppo_clip, 1 + self.ppo_clip) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = 0.5 * (state_values - old_returns).pow(2).mean()
            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            episode_policy_loss += policy_loss.detach()
            episode_value_loss += value_loss.detach()

        self.policy_old.load_state_dict(self.policy.state_dict())
        return episode_policy_loss / self.ppo_epoch, episode_value_loss / self.ppo_epoch

    def save_model(self, data_path):
        torch.save(self.policy.state_dict(), '{}ppo_{}.pth'.format(data_path, time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())))
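
update() implements the standard PPO clipped surrogate: with the probability ratio r = exp(logpi_new - logpi_old), the policy loss is -min(r * A, clip(r, 1 - eps, 1 + eps) * A). A scalar illustration of the clipping, using eps = 0.1 as configured in BandwidthEstimator.py and fabricated ratios:

import torch

eps = 0.1
ratios = torch.tensor([0.7, 1.0, 1.3])      # new/old action probability ratios
advantages = torch.tensor([1.0, 1.0, 1.0])
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - eps, 1 + eps) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
print(policy_loss)  # tensor(-0.9333): gains beyond the 1 +/- 0.1 band are clipped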
--------------------------------------------------------------------------------
/onnx-model.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thegreatwb/HRCC/2f7844da456e42b4a0b0eabe5f3a9be86d39356e/onnx-model.onnx
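
The repository ships onnx-model.onnx alongside the PyTorch checkpoint. How it was exported is not shown here; assuming its graph takes the same 1 x 6 x 10 state as input, it could be exercised with onnxruntime like this (a sketch, not part of the repository):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx-model.onnx")
input_name = sess.get_inputs()[0].name  # input layout assumed to match the 1x6x10 state
outputs = sess.run(None, {input_name: np.zeros((1, 6, 10), dtype=np.float32)})
print([o.shape for o in outputs])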
--------------------------------------------------------------------------------
/packet_info.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

class PacketInfo:
    def __init__(self):
        self.payload_type = None
        self.sequence_number = None  # int
        self.send_timestamp = None  # int, ms
        self.ssrc = None  # int
        self.padding_length = None  # int, B
        self.header_length = None  # int, B
        self.receive_timestamp = None  # int, ms
        self.payload_size = None  # int, B
        self.bandwidth_prediction = None  # int, bps

    def __str__(self):
        return (
            f"receive_timestamp: { self.receive_timestamp }ms"
            f", send_timestamp: { self.send_timestamp }ms"
            f", payload_size: { self.payload_size }B"
        )
--------------------------------------------------------------------------------
/packet_record.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np


class PacketRecord:
    # feature_interval can be modified
    def __init__(self, base_delay_ms=0):
        self.base_delay_ms = base_delay_ms
        self.reset()

    def reset(self):
        self.packet_num = 0
        self.packet_list = []  # ms
        self.last_seqNo = {}
        self.timer_delta = None  # ms
        self.min_seen_delay = self.base_delay_ms  # ms
        # ms, the receive time of the last packet in the last interval
        self.last_interval_rtime = None

    def clear(self):
        self.packet_num = 0
        if self.packet_list:
            self.last_interval_rtime = self.packet_list[-1]['timestamp']
        self.packet_list = []

    def on_receive(self, packet_info):
        assert (len(self.packet_list) == 0
                or packet_info.receive_timestamp
                >= self.packet_list[-1]['timestamp']), \
            "The incoming packets' receive_timestamp is out of order"

        # Calculate the loss count
        loss_count = 0
        if packet_info.ssrc in self.last_seqNo:
            loss_count = max(0,
                             packet_info.sequence_number - self.last_seqNo[packet_info.ssrc] - 1)
        self.last_seqNo[packet_info.ssrc] = packet_info.sequence_number

        # Calculate the packet delay
        if self.timer_delta is None:
            # shift the delay of the first packet to the base delay
            self.timer_delta = self.base_delay_ms - (packet_info.receive_timestamp - packet_info.send_timestamp)
        delay = self.timer_delta + packet_info.receive_timestamp - packet_info.send_timestamp

        self.min_seen_delay = min(delay, self.min_seen_delay)

        # Check the last interval rtime
        if self.last_interval_rtime is None:
            self.last_interval_rtime = packet_info.receive_timestamp

        # Record the result for the current packet
        packet_result = {
            'timestamp': packet_info.receive_timestamp,  # ms
            'delay': delay,  # ms
            'payload_byte': packet_info.payload_size,  # B
            'loss_count': loss_count,  # p
            'bandwidth_prediction': packet_info.bandwidth_prediction  # bps
        }
        self.packet_list.append(packet_result)
        self.packet_num += 1

    def _get_result_list(self, interval, key):
        if self.packet_num == 0:
            return []

        result_list = []
        if interval == 0:
            interval = self.packet_list[-1]['timestamp'] - \
                self.last_interval_rtime
        start_time = self.packet_list[-1]['timestamp'] - interval
        index = self.packet_num - 1
        while index >= 0 and self.packet_list[index]['timestamp'] > start_time:
            result_list.append(self.packet_list[index][key])
            index -= 1
        return result_list

    def calculate_average_delay(self, interval=0):
        '''
        Calculate the average delay over the last interval;
        interval=0 means over all recorded packets.
        The unit of the return value: ms
        '''
        delay_list = self._get_result_list(interval=interval, key='delay')
        if delay_list:
            return np.mean(delay_list) - self.base_delay_ms
        else:
            return 0

    def calculate_loss_ratio(self, interval=0):
        '''
        Calculate the loss ratio over the last interval;
        interval=0 means over all recorded packets.
        The unit of the return value: packet/packet
        '''
        loss_list = self._get_result_list(interval=interval, key='loss_count')
        if loss_list:
            loss_num = np.sum(loss_list)
            received_num = len(loss_list)
            return loss_num / (loss_num + received_num)
        else:
            return 0

    def calculate_receiving_rate(self, interval=0):
        '''
        Calculate the receiving rate over the last interval;
        interval=0 means over all recorded packets.
        The unit of the return value: bps
        '''
        received_size_list = self._get_result_list(interval=interval, key='payload_byte')
        if received_size_list:
            received_nbytes = np.sum(received_size_list)
            if interval == 0:
                interval = self.packet_list[-1]['timestamp'] - \
                    self.last_interval_rtime
            return received_nbytes * 8 / interval * 1000
        else:
            return 0

    def calculate_latest_prediction(self):
        '''
        The unit of the return value: bps
        '''
        if self.packet_num > 0:
            return self.packet_list[-1]['bandwidth_prediction']
        else:
            return 0
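
PacketRecord turns raw per-packet fields into the three interval statistics the RL state consumes. A minimal sketch with two hand-made packets (all field values below are fabricated):

from packet_info import PacketInfo
from packet_record import PacketRecord

record = PacketRecord()
record.reset()
for seq, (send_ms, recv_ms) in enumerate([(0, 10), (20, 35)], start=1):
    pkt = PacketInfo()
    pkt.ssrc, pkt.sequence_number = 1, seq
    pkt.send_timestamp, pkt.receive_timestamp = send_ms, recv_ms
    pkt.payload_size = 1000              # B
    pkt.bandwidth_prediction = 300000    # bps
    record.on_receive(pkt)

print(record.calculate_average_delay(interval=200))   # 2.5 ms relative to the base delay
print(record.calculate_loss_ratio(interval=200))      # 0.0, no gaps in sequence numbers
print(record.calculate_receiving_rate(interval=200))  # 80000.0 bps over the 200 ms window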
--------------------------------------------------------------------------------
/ppo_2021_07_25_04_57_11_with500trace.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thegreatwb/HRCC/2f7844da456e42b4a0b0eabe5f3a9be86d39356e/ppo_2021_07_25_04_57_11_with500trace.pth