├── LICENSE ├── README.md ├── Schaferct-reproducibility.pdf ├── code ├── detail_evaluate_on_24_sessions.py ├── evaluate_all.py ├── make_training_dataset_pickle.py ├── reorganize_by_policy_id.py └── v14_iql.py ├── onnx_model └── Schaferct_model.onnx ├── onnx_model_for_evaluation └── baseline.onnx ├── requirements.txt └── training_dataset_pickle └── README.md /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Institute of Computing Technology, Chinese Academy of Sciences 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Schaferct 2 | This is the official repo of team **Schaferct** in the [2nd Bandwidth Prediction Challenge of MMSys'24](https://www.microsoft.com/en-us/research/academic-program/bandwidth-estimation-challenge/overview/). 3 | 4 | You can find the full paper [here](https://dl.acm.org/doi/10.1145/3625468.3652183). 5 | 6 | ## Datasets 7 | 8 | Use the following scripts to download the training and evaluation datasets: 9 | 10 | training dataset: [https://github.com/microsoft/RL4BandwidthEstimationChallenge/blob/main/download-testbed-dataset.sh](https://github.com/microsoft/RL4BandwidthEstimationChallenge/blob/main/download-testbed-dataset.sh) 11 | evaluation dataset: [https://github.com/microsoft/RL4BandwidthEstimationChallenge/blob/main/download-emulated-dataset.sh](https://github.com/microsoft/RL4BandwidthEstimationChallenge/blob/main/download-emulated-dataset.sh) 12 | 13 | ## Hardware info 14 | 15 | Ubuntu 20.04.6 LTS with a 12GB GPU (we use an NVIDIA GeForce RTX 3080 Ti) 16 | 17 | ## Experimentation 18 | 19 | 1. Clone the repo and install packages 20 | 21 | ```bash 22 | git clone https://github.com/n13eho/Schaferct.git 23 | cd Schaferct 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | 2. Download the datasets (links mentioned above) 28 | 3. Remake training datasets 29 | 30 | You can use the pickle we made during our training at `./training_dataset_pickle/v8.pickle` (you need to download it first, see [here](https://github.com/n13eho/Schaferct/blob/main/training_dataset_pickle/README.md) for details), or remake a new training dataset by: 31 | 32 | 1. Rearrange the datasets by behavior policy type.
Before running this script (`reorganize_by_policy_id.py`), you need to modify the two paths to your downloaded datasets. 33 | 34 | ```bash 35 | mkdir ALLdatasets 36 | mkdir ALLdatasets/train 37 | mkdir ALLdatasets/evaluate 38 | cd ./code 39 | python reorganize_by_policy_id.py 40 | ``` 41 | 42 | 2. Make a new pickle dataset. You can modify `K` in `make_training_dataset_pickle.py` to put more or fewer sessions into the training dataset. 43 | 44 | ```bash 45 | python make_training_dataset_pickle.py 46 | ``` 47 | 48 | The dataset-making process takes about 1.5 hours. 49 | 50 | 4. Train the model 51 | 52 | *The code we use is modified from [CORL](https://github.com/tinkoff-ai/CORL/blob/main/algorithms/offline/iql.py) 53 | 54 | You can modify the variables below: 55 | 56 | 1. `pickle_path`: path to the training dataset; you can use your own dataset 57 | 2. `ENUM`: how many sessions from each policy type of the evaluation dataset are used to evaluate the model every `eval_freq` steps 58 | 3. `USE_WANDB`: whether to turn on [wandb](https://wandb.ai/site). You can turn it on to monitor the training process via mse, error_rate, q_score, loss, etc. 59 | 60 | Then you can run: 61 | 62 | ```bash 63 | python v14_iql.py 64 | ``` 65 | 66 | The training process takes about 4 hours. 67 | 68 | 5. Evaluate models 69 | 70 | You can use the official [evaluation script](https://github.com/microsoft/RL4BandwidthEstimationChallenge/blob/main/run_baseline_model.py) to evaluate your model, and here we offer two other scripts to help with evaluation; please modify the paths and variable names before running. 71 | 72 | Detailed instructions are in the comments of the scripts. 73 | 74 | 1. To run a small evaluation on a [small dataset](https://github.com/microsoft/RL4BandwidthEstimationChallenge/tree/main/data) (download the 24 sessions and modify their path first): 75 | 76 | ```bash 77 | python detail_evaluate_on_24_sessions.py 78 | ``` 79 | 80 | 2. To evaluate the metrics (mse, error rate) over the whole evaluation dataset: 81 | 82 | ```bash 83 | python evaluate_all.py 84 | ``` 85 | 86 | The whole evaluation process takes about 2 hours. 87 | 88 | 89 | !! Once again, remember to modify/adjust/rename the paths and variable names mentioned above or in the code’s comments. 90 | -------------------------------------------------------------------------------- /Schaferct-reproducibility.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/n13eho/Schaferct/da49600ba1fb915181081cd8183c7cb13f278bc9/Schaferct-reproducibility.pdf -------------------------------------------------------------------------------- /code/detail_evaluate_on_24_sessions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : n13eho 3 | # @Time : 2024.03.25 4 | 5 | """ 6 | Evaluate models on evaluation datasets in detail.
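For each session json, it plots the listed onnx models' predictions against the behavior policy estimate ('bandwidth_predictions') and the true capacity, and saves one PDF figure per session under onnx_model_for_evaluation/.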
7 | """ 8 | 9 | import glob 10 | import json 11 | import os 12 | import numpy as np 13 | from tqdm import tqdm 14 | import onnxruntime as ort 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | current_dir = os.path.split(os.path.abspath(__file__))[0] 19 | project_root_path = current_dir.rsplit('/', 1)[0] 20 | 21 | plt.rcParams.clear() 22 | plt.rcParams['font.size'] = 10 23 | plt.rcParams['font.family'] = 'Times New Roman' 24 | plt.rcParams['xtick.labelsize'] = 17 25 | plt.rcParams['ytick.labelsize'] = 17 26 | plt.rcParams['axes.labelsize'] = 17 27 | plt.rcParams['legend.fontsize'] = 12 28 | plt.rcParams['pdf.fonttype'] = 42 29 | plt.rcParams['ps.fonttype'] = 42 30 | 31 | 32 | if __name__ == "__main__": 33 | 34 | data_dir = "./data" # < modify the path to your data 35 | onnx_models = ['baseline', 'iql_v14_520k'] # < modify your onnx model names 36 | onnx_models_dir = os.path.join(project_root_path, 'onnx_model_for_evaluation') 37 | figs_dir = os.path.join(project_root_path, 'onnx_model_for_evaluation', ('_'.join(onnx_models[1:]))) 38 | if not os.path.exists(figs_dir): 39 | os.mkdir(figs_dir) 40 | data_files = glob.glob(os.path.join(data_dir, f'*.json'), recursive=True) 41 | ort_sessions = [] 42 | for m in onnx_models: 43 | m_path = os.path.join(onnx_models_dir, m + '.onnx') 44 | ort_sessions.append(ort.InferenceSession(m_path)) 45 | 46 | for filename in tqdm(data_files, desc="Processing"): 47 | with open(filename, "r") as file: 48 | call_data = json.load(file) 49 | 50 | observations = np.asarray(call_data['observations'], dtype=np.float32) 51 | bandwidth_predictions = np.asarray(call_data['bandwidth_predictions'], dtype=np.float32) 52 | true_capacity = np.asarray(call_data['true_capacity'], dtype=np.float32) 53 | 54 | baseline_model_predictions = {} 55 | for m in onnx_models: 56 | baseline_model_predictions[m] = [] 57 | hidden_state, cell_state = np.zeros((1, 1), dtype=np.float32), np.zeros((1, 1), dtype=np.float32) 58 | for t in range(observations.shape[0]): 59 | obss = observations[t:t+1,:].reshape(1,1,-1) 60 | feed_dict = {'obs': obss, 61 | 'hidden_states': hidden_state, 62 | 'cell_states': cell_state 63 | } 64 | for idx, orts in enumerate(ort_sessions): 65 | bw_prediction, hidden_state, cell_state = orts.run(None, feed_dict) 66 | baseline_model_predictions[onnx_models[idx]].append(bw_prediction[0,0,0]) 67 | 68 | 69 | for m in onnx_models: 70 | baseline_model_predictions[m] = np.asarray(baseline_model_predictions[m], dtype=np.float32) 71 | 72 | fig = plt.figure(figsize=(6, 3)) 73 | time_s = np.arange(0, observations.shape[0]*60,60)/1000 74 | for idx, m in enumerate(onnx_models): 75 | plt.plot(time_s, baseline_model_predictions[m] / 1000, linestyle='-', label=['Baseline', 'Our model'][idx], color='C' + str(idx)) 76 | plt.plot(time_s, bandwidth_predictions/1000, linestyle='--', label='Estimator ' + call_data['policy_id'], color='C' + str(len(onnx_models))) 77 | plt.plot(time_s, true_capacity/1000, label='True Capacity', color='black') 78 | plt.xlim(0, 125) 79 | plt.ylim(0) 80 | plt.ylabel("Bandwidth (Kbps)") 81 | plt.xlabel("Duration (second)") 82 | plt.grid(True) 83 | 84 | plt.legend(bbox_to_anchor=(0.5, 1.05), ncol=4, handletextpad=0.1, columnspacing=0.5, 85 | loc='center', frameon=False) 86 | 87 | plt.tight_layout() 88 | plt.savefig(os.path.join(figs_dir, os.path.basename(filename).replace(".json",".pdf")), dpi=300) 89 | plt.close() -------------------------------------------------------------------------------- /code/evaluate_all.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : n13eho 3 | # @Time : 2024.03.25 4 | 5 | """ 6 | Evaluate all models (baseline, new model, ...) on whole evaluation dataset 7 | Metircs: error rate, mse, over-estimated rate 8 | """ 9 | 10 | import json 11 | import os 12 | import numpy as np 13 | from tqdm import tqdm 14 | import onnxruntime as ort 15 | import matplotlib.pyplot as plt 16 | 17 | humm = 1e6 18 | 19 | current_dir = os.path.split(os.path.abspath(__file__))[0] 20 | project_root_path = current_dir.rsplit('/', 1)[0] 21 | emulate_dataset_dir_path = os.path.join(project_root_path, 'ALLdatasets', 'evaluate') 22 | onnx_models_dir_path = os.path.join(project_root_path, 'onnx_model_for_evaluation') 23 | onnx_model_name = ['iql_v14_520k'] # < modify your onnx model names 24 | 25 | 26 | def get_over_estimated_rate(x, y): 27 | l = [max((xx - yy) / yy, 0) for xx, yy in zip(x, y)] 28 | l = np.asarray(l, dtype=np.float32) 29 | return np.nanmean(l) 30 | 31 | def get_mse(x, y): 32 | l = [(xx - yy) ** 2 for xx, yy in zip(x, y)] 33 | l = np.asarray(l, dtype=np.float32) 34 | return np.nanmean(l) 35 | 36 | def get_error_rate(x, y): 37 | # error rate = min(1, |x-y| / y) 38 | l = [min(1, abs(xx - yy) / yy) for xx, yy in zip(x, y)] 39 | l = np.asarray(l, dtype=np.float32) 40 | return np.nanmean(l) 41 | 42 | def evaluate_every_f(e_f_path): 43 | # [behavior policy, baseline, m1, m2, m3, ...] 44 | er_perf = [] 45 | mse_perf = [] 46 | oer_perf = [] 47 | with open(e_f_path, "r") as file: 48 | call_data = json.load(file) 49 | 50 | observations = np.asarray(call_data['observations'], dtype=np.float32) 51 | 52 | behavior_policy = np.asarray(call_data['bandwidth_predictions'], dtype=np.float32) / humm 53 | true_capacity = np.asarray(call_data['true_capacity'], dtype=np.float32) / humm 54 | 55 | # first go with behavior policy 56 | er_perf.append(get_error_rate(behavior_policy, true_capacity)) 57 | mse_perf.append(get_mse(behavior_policy, true_capacity)) 58 | oer_perf.append(get_over_estimated_rate(behavior_policy, true_capacity)) 59 | 60 | # then go with these models 61 | for onnx_name in onnx_model_name: 62 | onnx_m_path = os.path.join(onnx_models_dir_path, onnx_name + '.onnx') 63 | ort_session = ort.InferenceSession(onnx_m_path) 64 | predictions = [] 65 | hc = np.zeros((1, 1), dtype=np.float32) 66 | 67 | for t in range(observations.shape[0]): 68 | feed_dict = {'obs': observations[t:t+1,:].reshape(1,1,-1), 69 | 'hidden_states': hc, 70 | 'cell_states': hc 71 | } 72 | bw_prediction, _, _ = ort_session.run(None, feed_dict) 73 | predictions.append(bw_prediction[0,0,0]) 74 | predictions = np.asarray(predictions, dtype=np.float32) / humm 75 | 76 | er_perf.append(get_error_rate(predictions, true_capacity)) 77 | mse_perf.append(get_mse(predictions, true_capacity)) 78 | oer_perf.append(get_over_estimated_rate(predictions, true_capacity)) 79 | return er_perf, mse_perf, oer_perf 80 | 81 | 82 | if __name__ == "__main__": 83 | 84 | # put all evaluation json in a list 85 | all_e_file_path = [] 86 | for sub_dir_name in os.listdir(emulate_dataset_dir_path): 87 | sub_dir_path = os.path.join(emulate_dataset_dir_path, sub_dir_name) 88 | for e_file in os.listdir(sub_dir_path): 89 | e_file_path = os.path.join(sub_dir_path, e_file) 90 | all_e_file_path.append(e_file_path) 91 | 92 | # prepare data 93 | # [behavior policy, baseline, m1, m2, m3, ...] 
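    # evaluate_every_f() returns one row per metric in the layout above; the per-session
    # rows are stacked and averaged column-wise below, then plotted as grouped bars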
94 | error_rate = [] 95 | mse = [] 96 | over_estimated_rate = [] 97 | for e_file in tqdm(all_e_file_path, desc='Evaluating'): 98 | er_perf, mse_perf, oer_perf = evaluate_every_f(e_file) 99 | error_rate.append(er_perf) 100 | mse.append(mse_perf) 101 | over_estimated_rate.append(oer_perf) 102 | # get avg 103 | error_rate = np.asarray(error_rate) 104 | mse = np.asarray(mse) 105 | over_estimated_rate = np.asarray(over_estimated_rate) 106 | error_rate = np.average(error_rate, axis=0) 107 | mse = np.average(mse, axis=0) 108 | over_estimated_rate = np.average(over_estimated_rate, axis=0) 109 | 110 | # plot 111 | models_names = ['behavior policy'] 112 | models_names.extend(onnx_model_name) 113 | type_num = len(onnx_model_name) + 1 114 | x = np.arange(type_num) 115 | bar_width = 0.3 116 | 117 | fig, bar_rate = plt.subplots(figsize=(9, 5)) 118 | bar_mse = bar_rate.twinx() 119 | rects1 = bar_rate.bar(x, error_rate, bar_width, color='c') 120 | rects2 = bar_rate.bar(x + bar_width, over_estimated_rate, bar_width, color='orange') 121 | rects3 = bar_mse.bar(x + bar_width * 2, mse, bar_width, color='green') 122 | 123 | bar_rate.bar_label(rects1, padding=1, fmt='%.2f') 124 | bar_rate.bar_label(rects2, padding=1, fmt='%.2f') 125 | bar_mse.bar_label(rects3, padding=1, fmt='%.2f') 126 | 127 | bar_rate.set_ylabel('Error Rate / Over-estimated Rate (%)') 128 | bar_mse.set_ylabel('MSE (Mbps^2)') 129 | 130 | bar_rate.set_xlabel("preditive models") 131 | plt.xticks(x + bar_width, models_names) 132 | plt.legend([rects1, rects2, rects3], ['error rate', 'over-estimated rate', 'mse'], 133 | bbox_to_anchor=(0.5, 1.03), ncol=3, loc='center', frameon=False) 134 | 135 | plt.tight_layout() 136 | plt.savefig(os.path.join(current_dir, '+'.join(models_names[1:]) + ".png")) 137 | plt.close() 138 | -------------------------------------------------------------------------------- /code/make_training_dataset_pickle.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : n13eho 3 | # @Time : 2024.03.25 4 | 5 | """ 6 | Make training dataset out of K sessions from each policy type, you can modify K to get more or less data 7 | 8 | output: new_training_dataset_pickle.pickle 9 | """ 10 | 11 | 12 | import json 13 | import os 14 | current_dir = os.path.split(os.path.abspath(__file__))[0] 15 | project_root_path = current_dir.rsplit('/', 1)[0] 16 | import numpy as np 17 | import math 18 | import pickle 19 | from tqdm import tqdm 20 | import random 21 | 22 | 23 | pickle_name = 'new_training_dataset_pickle.pickle' 24 | K = 300 # every 6 chunk, 6*K sessions 25 | 26 | TESTBED_POLICY_TYPE_NUM = 6 27 | testbed_dataset_dir_path = os.path.join(project_root_path, 'ALLdatasets', 'train') 28 | 29 | def load_bwec_dataset(): 30 | obs_ = [] 31 | action_ = [] 32 | next_obs_ = [] 33 | reward_ = [] 34 | done_ = [] 35 | 36 | for testbed_chunk_idx in range(TESTBED_POLICY_TYPE_NUM): 37 | sub_dir_path = os.path.join(testbed_dataset_dir_path, 'v' + str(testbed_chunk_idx)) 38 | for session in tqdm(random.sample(os.listdir(sub_dir_path), K), desc='Processing policy_type ' + str(testbed_chunk_idx)): 39 | session_path = os.path.join(sub_dir_path, session) 40 | with open(session_path, 'r', encoding='utf-8') as jf: 41 | single_session_trajectory = json.load(jf) 42 | observations = single_session_trajectory['observations'] 43 | actions = single_session_trajectory['bandwidth_predictions'] 44 | quality_videos = single_session_trajectory['video_quality'] 45 | quality_audios = 
single_session_trajectory['audio_quality'] 46 | 47 | avg_q_v = np.nanmean(np.asarray(quality_videos, dtype=np.float32)) 48 | avg_q_a = np.nanmean(np.asarray(quality_audios, dtype=np.float32)) 49 | 50 | obs = [] 51 | next_obs = [] 52 | action = [] 53 | reward = [] 54 | for idx in range(len(observations)): 55 | 56 | r_v = quality_videos[idx] 57 | r_a = quality_audios[idx] 58 | if math.isnan(quality_videos[idx]): 59 | r_v = avg_q_v 60 | if math.isnan(quality_audios[idx]): 61 | r_a = avg_q_a 62 | reward.append(r_v * 1.8 + r_a * 0.2) 63 | 64 | obs.append(observations[idx]) 65 | if idx + 1 >= len(observations): 66 | next_obs.append([-1] * len(observations[0])) # s_terminal 67 | else: 68 | next_obs.append(observations[idx + 1]) 69 | action.append([actions[idx]]) 70 | 71 | done_bool = [False] * (len(obs) - 1) + [True] 72 | 73 | # check dim 74 | assert len(obs) == len(next_obs) == len(action) == len(reward) == len(done_bool), 'DIM not match' 75 | 76 | # expaned into x_ 77 | obs_.extend(obs) 78 | action_.extend(action) 79 | next_obs_.extend(next_obs) 80 | reward_.extend(reward) 81 | done_.extend(done_bool) 82 | # break 83 | 84 | return { 85 | 'observations': np.array(obs_), 86 | 'actions': np.array(action_), 87 | 'next_observations': np.array(next_obs_), 88 | 'rewards': np.array(reward_), 89 | 'terminals': np.array(done_), 90 | } 91 | 92 | if __name__ == '__main__': 93 | dataset = load_bwec_dataset() 94 | 95 | print('dumping...') 96 | dataset_file_path = os.path.join(project_root_path, 'training_dataset_pickle', pickle_name) 97 | dataset_file = open(dataset_file_path, 'wb') 98 | pickle.dump(dataset, dataset_file) 99 | dataset_file.close() -------------------------------------------------------------------------------- /code/reorganize_by_policy_id.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : n13eho 3 | # @Time : 2024.03.25 4 | 5 | """ 6 | Rearrange the datasets by different type of behavior policy 7 | 8 | Line 18 and line 20 (`src_dataset_dir_path`) need to be modified according to the train/evaluation datasets' path 9 | """ 10 | 11 | import os 12 | from tqdm import tqdm 13 | import json 14 | import shutil 15 | 16 | current_dir = os.path.split(os.path.abspath(__file__))[0] 17 | project_root_path = current_dir.rsplit('/', 1)[0] 18 | # src_dataset_dir_path = os.path.join(project_root_path, 'testbed_dataset') # < modeify me 19 | # target_dataset_dir_path = os.path.join(project_root_path, 'ALLdatasets', 'train') 20 | src_dataset_dir_path = os.path.join(project_root_path, 'emulate_dataset') # < modeify me 21 | target_dataset_dir_path = os.path.join(project_root_path, 'ALLdatasets', 'evaluate') 22 | 23 | 24 | if __name__ == "__main__": 25 | 26 | # put all json in a list 27 | all_file_path = [] 28 | all_file_name = [] 29 | for sub_dir_name in os.listdir(src_dataset_dir_path): 30 | sub_dir_path = os.path.join(src_dataset_dir_path, sub_dir_name) 31 | for call_file in os.listdir(sub_dir_path): 32 | all_file_name.append(call_file) 33 | e_file_path = os.path.join(sub_dir_path, call_file) 34 | all_file_path.append(e_file_path) 35 | 36 | print(src_dataset_dir_path) 37 | print(target_dataset_dir_path) 38 | print(len(all_file_path)) 39 | for i in tqdm(range(len(all_file_name)), desc='Moving'): 40 | call_file_name = all_file_name[i] 41 | call_file_path = all_file_path[i] 42 | with open(call_file_path, "r") as file: 43 | call_data = json.load(file) 44 | policy_id = call_data['policy_id'] 45 | target_sub_dir_path = 
os.path.join(target_dataset_dir_path, policy_id) 46 | target_file_path = os.path.join(target_sub_dir_path, call_file_name) 47 | # copy 48 | shutil.copyfile(call_file_path, target_file_path) 49 | 50 | -------------------------------------------------------------------------------- /code/v14_iql.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/gwthomas/IQL-PyTorch 2 | # https://arxiv.org/pdf/2110.06169.pdf 3 | import copy 4 | import os 5 | import random 6 | import uuid 7 | from dataclasses import asdict, dataclass 8 | from pathlib import Path 9 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 10 | from tqdm import tqdm 11 | 12 | import numpy as np 13 | import pyrallis 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import wandb 18 | from torch.distributions import Normal 19 | from torch.optim.lr_scheduler import CosineAnnealingLR 20 | import pickle 21 | import json 22 | import onnxruntime as ort 23 | 24 | TensorBatch = List[torch.Tensor] 25 | pickle_path = '../training_dataset_pickle/v8.pickle' 26 | evaluation_dataset_path = '../ALLdatasets/evaluate' 27 | ENUM = 20 # every 5 evaluation set 28 | small_evaluation_datasets = [] 29 | policy_dir_names = os.listdir(evaluation_dataset_path) 30 | for p_t in policy_dir_names: 31 | policy_type_dir = os.path.join(evaluation_dataset_path, p_t) 32 | for e_f_name in os.listdir(policy_type_dir)[:ENUM]: 33 | e_f_path = os.path.join(policy_type_dir, e_f_name) 34 | small_evaluation_datasets.append(e_f_path) 35 | 36 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 37 | USE_WANDB = 1 38 | b_in_Mb = 1e6 39 | 40 | MAX_ACTION = 20 # Mbps 41 | STATE_DIM = 150 42 | ACTION_DIM = 1 43 | 44 | EXP_ADV_MAX = 100.0 45 | LOG_STD_MIN = -20.0 46 | LOG_STD_MAX = 2.0 47 | 48 | @dataclass 49 | class TrainConfig: 50 | # Experiment 51 | device: str = "cuda" 52 | env: str = "v14" 53 | seed: int = 0 # Sets Gym, PyTorch and Numpy seeds 54 | eval_freq: int = int(5e3) # How often (time steps) we evaluate 55 | max_timesteps: int = int(1e6) # Max time steps to run environment 56 | checkpoints_path: Optional[str] = './checkpoints_iql' # Save path 57 | load_model: str = "" # Model load file name, "" doesn't load 58 | # IQL 59 | buffer_size: int = 6_538_000 # Replay buffer size 60 | batch_size: int = 512 # Batch size for all networks 61 | discount: float = 0.99 # Discount factor 62 | tau: float = 0.005 # Target network update rate 63 | beta: float = 3.0 # Inverse temperature. 
Small beta -> BC, big beta -> maximizing Q 64 | iql_tau: float = 0.7 # Coefficient for asymmetric loss 65 | iql_deterministic: bool = False # Use deterministic actor 66 | vf_lr: float = 3e-4 # V function learning rate 67 | qf_lr: float = 3e-4 # Critic learning rate 68 | actor_lr: float = 3e-4 # Actor learning rate 69 | actor_dropout: Optional[float] = None # Adroit uses dropout for policy network 70 | # Wandb logging 71 | project: str = "BWEC-Schaferct" 72 | group: str = "IQL" 73 | name: str = "IQL" 74 | 75 | def __post_init__(self): 76 | self.name = f"{self.name}-{self.env}-{str(uuid.uuid4())[:8]}" 77 | if self.checkpoints_path is not None: 78 | self.checkpoints_path = os.path.join(self.checkpoints_path, self.name) 79 | 80 | def soft_update(target: nn.Module, source: nn.Module, tau: float): 81 | for target_param, source_param in zip(target.parameters(), source.parameters()): 82 | target_param.data.copy_((1 - tau) * target_param.data + tau * source_param.data) 83 | 84 | class ReplayBuffer: 85 | def __init__( 86 | self, 87 | state_dim: int, 88 | action_dim: int, 89 | buffer_size: int, 90 | device: str = "cpu", 91 | ): 92 | self._buffer_size = buffer_size 93 | self._pointer = 0 94 | self._size = 0 95 | 96 | self._states = torch.zeros( 97 | (buffer_size, state_dim), dtype=torch.float32, device=device 98 | ) 99 | self._actions = torch.zeros( 100 | (buffer_size, action_dim), dtype=torch.float32, device=device 101 | ) 102 | self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) 103 | self._next_states = torch.zeros( 104 | (buffer_size, state_dim), dtype=torch.float32, device=device 105 | ) 106 | self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device) 107 | self._device = device 108 | 109 | def _to_tensor(self, data: np.ndarray) -> torch.Tensor: 110 | return torch.tensor(data, dtype=torch.float32, device=self._device) 111 | 112 | # Loads data in d4rl format, i.e. from Dict[str, np.array]. 113 | def load_dataset(self, data: Dict[str, np.ndarray]): 114 | if self._size != 0: 115 | raise ValueError("Trying to load data into non-empty replay buffer") 116 | n_transitions = data["observations"].shape[0] 117 | print(f"Dataset size: {n_transitions}") 118 | if n_transitions > self._buffer_size: 119 | raise ValueError( 120 | "Replay buffer is smaller than the dataset you are trying to load!" 121 | ) 122 | self._states[:n_transitions] = self._to_tensor(data["observations"]) 123 | self._actions[:n_transitions] = self._to_tensor(data["actions"] / b_in_Mb) 124 | self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None]) 125 | self._next_states[:n_transitions] = self._to_tensor(data["next_observations"]) 126 | self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None]) 127 | self._size += n_transitions 128 | self._pointer = min(self._size, n_transitions) 129 | 130 | def sample(self, batch_size: int) -> TensorBatch: 131 | indices = np.random.randint(0, min(self._size, self._pointer), size=batch_size) 132 | states = self._states[indices] 133 | actions = self._actions[indices] 134 | rewards = self._rewards[indices] 135 | next_states = self._next_states[indices] 136 | dones = self._dones[indices] 137 | # states = torch.unsqueeze(states, 0) 138 | return [states, actions, rewards, next_states, dones] 139 | 140 | def add_transition(self): 141 | # Use this method to add new data into the replay buffer during fine-tuning. 142 | # I left it unimplemented since now we do not do fine-tuning. 
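        # A possible ring-buffer sketch (untested), with a hypothetical signature that takes
        # one transition; note that sample() above draws indices from [0, min(_size, _pointer)),
        # so a wrapping pointer would need a matching change there:
        #   def add_transition(self, state, action, reward, next_state, done):
        #       i = self._pointer % self._buffer_size
        #       self._states[i] = self._to_tensor(state)
        #       self._actions[i] = self._to_tensor(action)
        #       self._rewards[i] = self._to_tensor([reward])
        #       self._next_states[i] = self._to_tensor(next_state)
        #       self._dones[i] = self._to_tensor([float(done)])
        #       self._pointer = (i + 1) % self._buffer_size
        #       self._size = min(self._size + 1, self._buffer_size)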
143 | raise NotImplementedError 144 | 145 | def set_seed( 146 | seed: int, deterministic_torch: bool = False 147 | ): 148 | os.environ["PYTHONHASHSEED"] = str(seed) 149 | np.random.seed(seed) 150 | random.seed(seed) 151 | torch.manual_seed(seed) 152 | torch.use_deterministic_algorithms(deterministic_torch) 153 | 154 | def wandb_init(config: dict) -> None: 155 | wandb.init( 156 | config=config, 157 | project=config["project"], 158 | group=config["group"], 159 | name=config["name"], 160 | id=str(uuid.uuid4()), 161 | ) 162 | wandb.run.save() 163 | 164 | def asymmetric_l2_loss(u: torch.Tensor, tau: float) -> torch.Tensor: 165 | return torch.mean(torch.abs(tau - (u < 0).float()) * u**2) 166 | 167 | class Squeeze(nn.Module): 168 | def __init__(self, dim=-1): 169 | super().__init__() 170 | self.dim = dim 171 | 172 | def forward(self, x: torch.Tensor) -> torch.Tensor: 173 | return x.squeeze(dim=self.dim) 174 | 175 | class MLP(nn.Module): 176 | def __init__( 177 | self, 178 | dims, 179 | activation_fn: Callable[[], nn.Module] = nn.ReLU, 180 | output_activation_fn: Callable[[], nn.Module] = None, 181 | squeeze_output: bool = False, 182 | dropout: Optional[float] = None, 183 | ): 184 | super().__init__() 185 | n_dims = len(dims) 186 | if n_dims < 2: 187 | raise ValueError("MLP requires at least two dims (input and output)") 188 | 189 | layers = [] 190 | for i in range(n_dims - 2): 191 | layers.append(nn.Linear(dims[i], dims[i + 1])) 192 | layers.append(activation_fn()) 193 | 194 | if dropout is not None: 195 | layers.append(nn.Dropout(dropout)) 196 | 197 | layers.append(nn.Linear(dims[-2], dims[-1])) 198 | if output_activation_fn is not None: 199 | layers.append(output_activation_fn()) 200 | if squeeze_output: 201 | if dims[-1] != 1: 202 | raise ValueError("Last dim must be 1 when squeezing") 203 | layers.append(Squeeze(-1)) 204 | self.net = nn.Sequential(*layers) 205 | 206 | def forward(self, x: torch.Tensor) -> torch.Tensor: 207 | return self.net(x) 208 | 209 | class GaussianPolicy(nn.Module): 210 | def __init__( 211 | self, 212 | state_dim: int, 213 | act_dim: int, 214 | max_action: float, 215 | hidden_dim: int = 256, 216 | n_hidden: int = 2, 217 | dropout: Optional[float] = None, 218 | ): 219 | super().__init__() 220 | self.log_std = nn.Parameter(torch.zeros(act_dim, dtype=torch.float32)) 221 | self.max_action = max_action 222 | self.encoder0 = nn.Parameter(torch.tensor([1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 223 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 224 | 1e-4, 1e-4, 1e-4, 1e-4, 1e-4, 1e-5, 1e-5, 1e-5, 1e-5, 1e-5, 225 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 226 | 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 227 | 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 228 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 229 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 230 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 231 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 232 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 233 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 234 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 235 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 236 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 237 | ], dtype=torch.float32)) 238 | self.encoder0.requires_grad_(False) 239 | 240 | # encoder 1 241 | self.encoder1 = nn.Sequential( 242 | # encoder 1 243 | nn.Linear(150, 256), 244 | # nn.LayerNorm(256), 245 | nn.ReLU() 246 | ) 247 | # GRU 248 | self.gru = nn.GRU(256, 256, 2) 249 | # FC 250 | self.fc_mid = nn.Sequential( 251 | 
nn.Linear(256, 256), 252 | nn.ReLU() 253 | ) 254 | # Recisual Block 1(rb1) 255 | self.rb1 = nn.Sequential( 256 | nn.Linear(256, 256), 257 | nn.LeakyReLU(), 258 | nn.Linear(256, 256), 259 | nn.LeakyReLU() 260 | ) 261 | # Recisual Block 2(rb2) 262 | self.rb2 = nn.Sequential( 263 | nn.Linear(256, 256), 264 | nn.LeakyReLU(), 265 | nn.Linear(256, 256), 266 | nn.LeakyReLU() 267 | ) 268 | # final 'gmm' 269 | self.final = nn.Sequential( 270 | nn.Linear(256, 1), 271 | nn.Tanh() 272 | ) 273 | 274 | def forward(self, obs: torch.Tensor, h, c): 275 | obs_ = torch.squeeze(obs, 0) 276 | obs_ = obs_ * self.encoder0 277 | ### 278 | mean = self.encoder1(obs_) 279 | mean, _ = self.gru(mean) 280 | mean = self.fc_mid(mean) 281 | mem1 = mean 282 | mean = self.rb1(mean) + mem1 283 | mem2 = mean 284 | mean = self.rb2(mean) + mem2 285 | mean = self.final(mean) 286 | ### 287 | mean = mean * self.max_action * 1e6 # Mbps -> bps 288 | mean = mean.clamp(min = 10) # larger than 10bps 289 | std = torch.exp(self.log_std.clamp(LOG_STD_MIN, LOG_STD_MAX)) 290 | std = std.expand(mean.shape[0], 1) 291 | ret = torch.cat((mean, std), 1) 292 | ret = torch.unsqueeze(ret, 0) # (1, bs, 2) 293 | return ret, h, c 294 | 295 | class DeterministicPolicy(nn.Module): 296 | def __init__( 297 | self, 298 | state_dim: int, 299 | act_dim: int, 300 | max_action: float, 301 | hidden_dim: int = 256, 302 | n_hidden: int = 2, 303 | dropout: Optional[float] = None, 304 | ): 305 | super().__init__() 306 | self.net = MLP( 307 | [state_dim, *([hidden_dim] * n_hidden), act_dim], 308 | output_activation_fn=nn.Tanh, 309 | dropout=dropout, 310 | ) 311 | self.max_action = max_action 312 | 313 | def forward(self, obs: torch.Tensor) -> torch.Tensor: 314 | return self.net(obs) 315 | 316 | @torch.no_grad() 317 | def act(self, state: np.ndarray, device: str = "cpu"): 318 | state = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32) 319 | return ( 320 | torch.clamp(self(state) * self.max_action, -self.max_action, self.max_action) 321 | .cpu() 322 | .data.numpy() 323 | .flatten() 324 | ) 325 | 326 | class TwinQ(nn.Module): 327 | def __init__( 328 | self, state_dim: int, action_dim: int, hidden_dim: int = 256, n_hidden: int = 2 329 | ): 330 | super().__init__() 331 | dims = [state_dim + action_dim, *([hidden_dim] * n_hidden), 1] 332 | self.q1 = MLP(dims, squeeze_output=True) 333 | self.q2 = MLP(dims, squeeze_output=True) 334 | self.encoder = nn.Parameter(torch.tensor([1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 335 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 336 | 1e-4, 1e-4, 1e-4, 1e-4, 1e-4, 1e-5, 1e-5, 1e-5, 1e-5, 1e-5, 337 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 338 | 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 339 | 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 340 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 341 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 342 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 343 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 344 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 345 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 346 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 347 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 348 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 349 | ], dtype=torch.float32)) 350 | self.encoder.requires_grad_(False) 351 | 352 | def both(self, state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 353 | state_ = state * self.encoder 354 | sa = torch.cat([state_, action], 1) 355 | return self.q1(sa), self.q2(sa) 
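        # forward() below returns the element-wise minimum of the two Q-heads (clipped double-Q);
        # the IQL trainer queries the target copy of this network through forward(), so value
        # targets use the conservative estimate.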
356 | 357 | def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 358 | return torch.min(*self.both(state, action)) 359 | 360 | class ValueFunction(nn.Module): 361 | def __init__(self, state_dim: int, hidden_dim: int = 256, n_hidden: int = 2): 362 | super().__init__() 363 | dims = [state_dim, *([hidden_dim] * n_hidden), 1] 364 | self.v = MLP(dims, squeeze_output=True) 365 | self.encoder = nn.Parameter(torch.tensor([1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 1e-6, 366 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 367 | 1e-4, 1e-4, 1e-4, 1e-4, 1e-4, 1e-5, 1e-5, 1e-5, 1e-5, 1e-5, 368 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 369 | 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 370 | 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 1e-2, 371 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 372 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 373 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 374 | 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 1e-1, 375 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 376 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 377 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 378 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 379 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 380 | ], dtype=torch.float32)) 381 | self.encoder.requires_grad_(False) 382 | 383 | def forward(self, state: torch.Tensor) -> torch.Tensor: 384 | state_ = state * self.encoder 385 | return self.v(state_) 386 | 387 | 388 | class ImplicitQLearning: 389 | def __init__( 390 | self, 391 | max_action: float, 392 | actor: nn.Module, 393 | actor_optimizer: torch.optim.Optimizer, 394 | q_network: nn.Module, 395 | q_optimizer: torch.optim.Optimizer, 396 | v_network: nn.Module, 397 | v_optimizer: torch.optim.Optimizer, 398 | iql_tau: float = 0.7, 399 | beta: float = 3.0, 400 | max_steps: int = 1000000, 401 | discount: float = 0.99, 402 | tau: float = 0.005, 403 | device: str = "cpu", 404 | ): 405 | self.max_action = max_action 406 | self.qf = q_network 407 | self.q_target = copy.deepcopy(self.qf).requires_grad_(False).to(device) 408 | self.vf = v_network 409 | self.actor = actor 410 | self.v_optimizer = v_optimizer 411 | self.q_optimizer = q_optimizer 412 | self.actor_optimizer = actor_optimizer 413 | self.actor_lr_schedule = CosineAnnealingLR(self.actor_optimizer, max_steps) 414 | self.iql_tau = iql_tau 415 | self.beta = beta 416 | self.discount = discount 417 | self.tau = tau 418 | 419 | self.total_it = 0 420 | self.device = device 421 | 422 | def _update_v(self, observations, actions, log_dict) -> torch.Tensor: 423 | # Update value function 424 | with torch.no_grad(): 425 | target_q = self.q_target(observations, actions) 426 | 427 | v = self.vf(observations) 428 | adv = target_q - v 429 | v_loss = asymmetric_l2_loss(adv, self.iql_tau) 430 | log_dict["value_loss"] = v_loss.item() 431 | self.v_optimizer.zero_grad() 432 | v_loss.backward() 433 | self.v_optimizer.step() 434 | return adv 435 | 436 | def _update_q( 437 | self, 438 | next_v: torch.Tensor, 439 | observations: torch.Tensor, 440 | actions: torch.Tensor, 441 | rewards: torch.Tensor, 442 | terminals: torch.Tensor, 443 | log_dict: Dict, 444 | ): 445 | targets = rewards + (1.0 - terminals.float()) * self.discount * next_v.detach() 446 | qs = self.qf.both(observations, actions) 447 | q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs) 448 | # log_dict["q_rewards"] = torch.mean(rewards).item() 449 | # log_dict["q_targets"] = torch.mean(targets).item() 450 | log_dict["q_score"] = (torch.mean(qs[0]).item() + 
torch.mean(qs[1]).item()) / 2 451 | log_dict["q_loss"] = q_loss.item() 452 | self.q_optimizer.zero_grad() 453 | q_loss.backward() 454 | self.q_optimizer.step() 455 | 456 | # Update target Q network 457 | soft_update(self.q_target, self.qf, self.tau) 458 | 459 | def _update_policy( 460 | self, 461 | adv: torch.Tensor, 462 | observations: torch.Tensor, 463 | actions: torch.Tensor, 464 | log_dict: Dict, 465 | ): 466 | exp_adv = torch.exp(self.beta * adv.detach()).clamp(max=EXP_ADV_MAX) 467 | out_, _, _ = self.actor(observations, torch.zeros((1, 1)), torch.zeros((1, 1))) 468 | out_ = torch.squeeze(out_, 0) 469 | mean = out_[:, :1] 470 | mean = mean / 1e6 471 | std = out_[0, 1:] 472 | policy_out = Normal(mean, std) 473 | if isinstance(policy_out, torch.distributions.Distribution): 474 | bc_losses = -policy_out.log_prob(actions).sum(-1, keepdim=False) 475 | elif torch.is_tensor(policy_out): 476 | if policy_out.shape != actions.shape: 477 | raise RuntimeError("Actions shape missmatch") 478 | bc_losses = torch.sum((policy_out - actions) ** 2, dim=1) 479 | else: 480 | raise NotImplementedError 481 | policy_loss = torch.mean(exp_adv * bc_losses) 482 | log_dict["actor_all_loss"] = policy_loss.item() 483 | # log_dict["actor_bc_loss"] = torch.mean(bc_losses).item() 484 | # log_dict["actor_expadv_loss"] = torch.mean(exp_adv).item() 485 | self.actor_optimizer.zero_grad() 486 | policy_loss.backward() 487 | self.actor_optimizer.step() 488 | self.actor_lr_schedule.step() 489 | 490 | def train(self, batch: TensorBatch) -> Dict[str, float]: 491 | self.total_it += 1 492 | ( 493 | observations, 494 | actions, 495 | rewards, 496 | next_observations, 497 | dones, 498 | ) = batch 499 | log_dict = {} 500 | 501 | # next state's score 502 | with torch.no_grad(): 503 | next_v = self.vf(next_observations) 504 | # Update value function 505 | adv = self._update_v(observations, actions, log_dict) 506 | rewards = rewards.squeeze(dim=-1) 507 | dones = dones.squeeze(dim=-1) 508 | # Update Q function 509 | self._update_q(next_v, observations, actions, rewards, dones, log_dict) 510 | # Update actor 511 | self._update_policy(adv, observations, actions, log_dict) 512 | 513 | return log_dict 514 | 515 | def state_dict(self) -> Dict[str, Any]: 516 | return { 517 | "qf": self.qf.state_dict(), 518 | "q_optimizer": self.q_optimizer.state_dict(), 519 | "vf": self.vf.state_dict(), 520 | "v_optimizer": self.v_optimizer.state_dict(), 521 | "actor": self.actor.state_dict(), 522 | "actor_optimizer": self.actor_optimizer.state_dict(), 523 | "actor_lr_schedule": self.actor_lr_schedule.state_dict(), 524 | "total_it": self.total_it, 525 | } 526 | 527 | def load_state_dict(self, state_dict: Dict[str, Any]): 528 | self.qf.load_state_dict(state_dict["qf"]) 529 | self.q_optimizer.load_state_dict(state_dict["q_optimizer"]) 530 | self.q_target = copy.deepcopy(self.qf) 531 | 532 | self.vf.load_state_dict(state_dict["vf"]) 533 | self.v_optimizer.load_state_dict(state_dict["v_optimizer"]) 534 | 535 | self.actor.load_state_dict(state_dict["actor"]) 536 | self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"]) 537 | self.actor_lr_schedule.load_state_dict(state_dict["actor_lr_schedule"]) 538 | 539 | self.total_it = state_dict["total_it"] 540 | 541 | 542 | def get_input_from_file(): 543 | # dummy -> real input 544 | evaluation_file = '../evaluation/data/02560.json' 545 | with open(evaluation_file, "r") as file: 546 | call_data = json.load(file) 547 | observations = np.asarray(call_data['observations'], dtype=np.float32) 548 | observations = 
observations.reshape(1, -1, STATE_DIM) 549 | return observations 550 | 551 | def export2onnx(pt_path, onnx_path): 552 | """ 553 | trans pt to onnx 554 | """ 555 | BS = 1 # batch size 556 | hidden_size = 1 # number of hidden units in the LSTM 557 | 558 | # instantiate the ML BW estimator 559 | torchBwModel = GaussianPolicy(STATE_DIM, ACTION_DIM, MAX_ACTION) 560 | torchBwModel.load_state_dict(torch.load(pt_path)) 561 | # create inputs: 1 episode x T timesteps x obs_dim features 562 | dummy_inputs = get_input_from_file() 563 | torch_dummy_inputs = torch.as_tensor(dummy_inputs) 564 | torch_initial_hidden_state = torch.zeros((BS, hidden_size)) 565 | torch_initial_cell_state = torch.zeros((BS, hidden_size)) 566 | # predict dummy outputs: 1 episode x T timesteps x 2 (mean and std) 567 | dummy_outputs, final_hidden_state, final_cell_state = torchBwModel(torch_dummy_inputs, torch_initial_hidden_state, torch_initial_cell_state) 568 | # save onnx model 569 | os.makedirs(os.path.dirname(onnx_path), exist_ok=True) 570 | torchBwModel.to("cpu") 571 | torchBwModel.eval() 572 | torch.onnx.export( 573 | torchBwModel, 574 | (torch_dummy_inputs[0:1, 0:1, :], torch_initial_hidden_state, torch_initial_cell_state), 575 | onnx_path, 576 | opset_version=11, 577 | input_names=['obs', 'hidden_states', 'cell_states'], # the model's input names 578 | output_names=['output', 'state_out', 'cell_out'], # the model's output names 579 | ) 580 | 581 | # verify torch and onnx models outputs 582 | ort_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider']) 583 | onnx_hidden_state, onnx_cell_state = (np.zeros((1, hidden_size), dtype=np.float32), np.zeros((1, hidden_size), dtype=np.float32)) 584 | torch_hidden_state, torch_cell_state = (torch.as_tensor(onnx_hidden_state), torch.as_tensor(onnx_cell_state)) 585 | # online interaction: step through the environment 1 time step at a time 586 | with torch.no_grad(): 587 | for i in tqdm(range(dummy_inputs.shape[1]), desc="Verifing "): 588 | torch_estimate, torch_hidden_state, torch_cell_state = torchBwModel(torch_dummy_inputs[0:1, i:i+1, :], torch_hidden_state, torch_cell_state) 589 | feed_dict= {'obs': dummy_inputs[0:1, i:i+1, :], 'hidden_states': onnx_hidden_state, 'cell_states': onnx_cell_state} 590 | onnx_estimate, onnx_hidden_state, onnx_cell_state = ort_session.run(None, feed_dict) 591 | assert np.allclose(torch_estimate.numpy(), onnx_estimate, atol=10), 'Failed to match model outputs!, {}, {}'.format(torch_estimate.numpy(), onnx_estimate) 592 | assert np.allclose(torch_hidden_state, onnx_hidden_state, atol=1e-7), 'Failed to match hidden state1' 593 | assert np.allclose(torch_cell_state, onnx_cell_state, atol=1e-7), 'Failed to match cell state!' 594 | 595 | assert np.allclose(torch_hidden_state, final_hidden_state, atol=1e-7), 'Failed to match final hidden state!' 596 | assert np.allclose(torch_cell_state, final_cell_state, atol=1e-7), 'Failed to match final cell state!' 
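    # note: the bandwidth outputs above are compared with a loose atol=10 (values are in bps),
    # while the pass-through hidden/cell states are compared with atol=1e-7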
597 | # print("Torch and Onnx models outputs have been verified successfully!") 598 | 599 | def evaluate(onnx_path): 600 | ort_session = ort.InferenceSession(onnx_path) 601 | 602 | every_call_mse = [] 603 | every_call_accuracy = [] 604 | for f_path in tqdm(small_evaluation_datasets, desc="Evaluating"): 605 | with open(f_path, 'r') as file: 606 | call_data = json.load(file) 607 | 608 | observations = np.asarray(call_data['observations'], dtype=np.float32) 609 | true_capacity = np.asarray(call_data['true_capacity'], dtype=np.float32) 610 | 611 | model_predictions = [] 612 | hidden_state, cell_state = np.zeros((1, 1), dtype=np.float32), np.zeros((1, 1), dtype=np.float32) 613 | for t in range(observations.shape[0]): 614 | obss = observations[t:t+1,:].reshape(1,1,-1) 615 | feed_dict = {'obs': obss, 616 | 'hidden_states': hidden_state, 617 | 'cell_states': cell_state 618 | } 619 | bw_prediction, hidden_state, cell_state = ort_session.run(None, feed_dict) 620 | model_predictions.append(bw_prediction[0,0,0]) 621 | # mse and accuracy of this call 622 | model_predictions = np.asarray(model_predictions, dtype=np.float32) 623 | true_capacity = true_capacity / 1e6 624 | model_predictions = model_predictions / 1e6 625 | call_mse = [] 626 | call_accuracy = [] 627 | for true_bw, pre_bw in zip(true_capacity, model_predictions): 628 | if np.isnan(true_bw) or np.isnan(pre_bw): 629 | continue 630 | else: 631 | mse_ = (true_bw - pre_bw) ** 2 632 | call_mse.append(mse_) 633 | accuracy_ = max(0, 1 - abs(pre_bw - true_bw) / true_bw) 634 | call_accuracy.append(accuracy_) 635 | call_mse = np.asarray(call_mse, dtype=np.float32) 636 | every_call_mse.append(np.mean(call_mse)) 637 | call_accuracy = np.asarray(call_accuracy, dtype=np.float32) 638 | every_call_accuracy.append(np.mean(call_accuracy)) 639 | every_call_mse = np.asarray(every_call_mse, dtype=np.float32) 640 | every_call_accuracy = np.asarray(every_call_accuracy, dtype=np.float32) 641 | return np.mean(every_call_mse), np.mean(every_call_accuracy) 642 | 643 | 644 | @pyrallis.wrap() 645 | def train(config: TrainConfig): 646 | state_dim = STATE_DIM 647 | action_dim = ACTION_DIM 648 | 649 | testdataset_file = open(pickle_path, 'rb') 650 | dataset = pickle.load(testdataset_file) 651 | print('dataset loaded') 652 | 653 | replay_buffer = ReplayBuffer( 654 | state_dim, 655 | action_dim, 656 | config.buffer_size, 657 | config.device, 658 | ) 659 | replay_buffer.load_dataset(dataset) 660 | 661 | max_action = MAX_ACTION 662 | 663 | if config.checkpoints_path is not None: 664 | print(f"Checkpoints path: {config.checkpoints_path}") 665 | os.makedirs(config.checkpoints_path, exist_ok=True) 666 | with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: 667 | pyrallis.dump(config, f) 668 | 669 | # Set seeds 670 | seed = config.seed 671 | set_seed(seed) 672 | 673 | q_network = TwinQ(state_dim, action_dim).to(config.device) 674 | v_network = ValueFunction(state_dim).to(config.device) 675 | actor = ( 676 | DeterministicPolicy( 677 | state_dim, action_dim, max_action, dropout=config.actor_dropout 678 | ) 679 | if config.iql_deterministic 680 | else GaussianPolicy( 681 | state_dim, action_dim, max_action, dropout=config.actor_dropout 682 | ) 683 | ).to(config.device) 684 | v_optimizer = torch.optim.Adam(v_network.parameters(), lr=config.vf_lr) 685 | q_optimizer = torch.optim.Adam(q_network.parameters(), lr=config.qf_lr) 686 | actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_lr) 687 | 688 | kwargs = { 689 | "max_action": max_action, 690 | 
"actor": actor, 691 | "actor_optimizer": actor_optimizer, 692 | "q_network": q_network, 693 | "q_optimizer": q_optimizer, 694 | "v_network": v_network, 695 | "v_optimizer": v_optimizer, 696 | "discount": config.discount, 697 | "tau": config.tau, 698 | "device": config.device, 699 | # IQL 700 | "beta": config.beta, 701 | "iql_tau": config.iql_tau, 702 | "max_steps": config.max_timesteps, 703 | } 704 | 705 | print("---------------------------------------") 706 | print(f"Training IQL, Env: {config.env}, Seed: {seed}") 707 | print("---------------------------------------") 708 | 709 | # Initialize actor 710 | trainer = ImplicitQLearning(**kwargs) 711 | 712 | if config.load_model != "": 713 | policy_file = Path(config.load_model) 714 | trainer.load_state_dict(torch.load(policy_file)) 715 | actor = trainer.actor 716 | 717 | if USE_WANDB: 718 | wandb_init(asdict(config)) 719 | 720 | for t in range(int(config.max_timesteps)): 721 | batch = replay_buffer.sample(config.batch_size) 722 | batch = [b.to(config.device) for b in batch] 723 | log_dict = trainer.train(batch) 724 | if USE_WANDB: 725 | wandb.log(log_dict, step=trainer.total_it) 726 | # Evaluate episode 727 | if (t + 1) % config.eval_freq == 0: 728 | print(f"Time steps: {t + 1}") 729 | 730 | pt_path = os.path.join(config.checkpoints_path, f"checkpoint_{t + 1}.pt") 731 | onnx_path = os.path.join(config.checkpoints_path, f"checkpoint_{t + 1}.onnx") 732 | # save pt 733 | if config.checkpoints_path is not None: 734 | torch.save(trainer.state_dict()["actor"], pt_path) 735 | # save onnx 736 | export2onnx(pt_path, onnx_path) 737 | # evaluate 738 | mse_, accuracy_ = evaluate(onnx_path) 739 | if USE_WANDB and trainer.total_it > 1000: 740 | wandb.log({"mse": mse_, "error_rate": 1 - accuracy_}, step=trainer.total_it) 741 | 742 | 743 | if __name__ == "__main__": 744 | train() 745 | -------------------------------------------------------------------------------- /onnx_model/Schaferct_model.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/n13eho/Schaferct/da49600ba1fb915181081cd8183c7cb13f278bc9/onnx_model/Schaferct_model.onnx -------------------------------------------------------------------------------- /onnx_model_for_evaluation/baseline.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/n13eho/Schaferct/da49600ba1fb915181081cd8183c7cb13f278bc9/onnx_model_for_evaluation/baseline.onnx -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.8.2 2 | numpy==1.26.4 3 | onnxruntime==1.16.3 4 | pandas==2.2.1 5 | pyrallis==0.3.1 6 | scikit_learn==1.4.1.post1 7 | torch==2.1.2 8 | tqdm==4.66.1 9 | -------------------------------------------------------------------------------- /training_dataset_pickle/README.md: -------------------------------------------------------------------------------- 1 | Download the `v8.pickle` from [link](https://drive.google.com/file/d/1I1XvvM5lYX21pbqnuahOW9BFdC7KXUQ7/view?usp=sharing) and put it in this directory. --------------------------------------------------------------------------------